newton.actuators.ControllerNeuralLSTM#

class newton.actuators.ControllerNeuralLSTM(model_path)[source]#

Bases: Controller

LSTM-based neural network controller.

Uses a pre-trained LSTM network to compute joint efforts from position and velocity errors. The network maintains hidden and cell state across timesteps to capture temporal patterns.

The network must be callable as:

effort, (hidden_new, cell_new) = network(input, (hidden, cell))

where input has shape (batch, 1, 2) with features [pos_error, vel_error], and hidden/cell have shape (num_layers, batch, hidden_size).

The network is expected to have an lstm attribute (torch.nn.LSTM) so that num_layers and hidden_size can be inferred automatically.
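
A minimal sketch of a network satisfying this contract (the class name EffortLSTM and the linear output head are illustrative; only the lstm attribute and the calling convention above are required):

import torch

class EffortLSTM(torch.nn.Module):
    def __init__(self, hidden_size=64, num_layers=2):
        super().__init__()
        # The controller infers num_layers and hidden_size from this attribute.
        self.lstm = torch.nn.LSTM(input_size=2, hidden_size=hidden_size,
                                  num_layers=num_layers, batch_first=True)
        self.head = torch.nn.Linear(hidden_size, 1)

    def forward(self, x, state):
        # x: (batch, 1, 2) holding [pos_error, vel_error]
        # state: (hidden, cell), each of shape (num_layers, batch, hidden_size)
        out, (hidden_new, cell_new) = self.lstm(x, state)
        return self.head(out), (hidden_new, cell_new)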

Scale factors (pos_scale, vel_scale, effort_scale) are read from checkpoint metadata, falling back to 1.0 when absent. Supported checkpoint formats: TorchScript (.pt saved with torch.jit.save) and state-dict bundles ({"model": state_dict, "metadata": {...}} saved with torch.save).
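
For illustration, either supported format could be produced roughly as follows, assuming network is an instance of a module like the sketch above. File names and scale values are placeholders; how scale metadata travels with a TorchScript file is not specified here, so the state-dict bundle is shown carrying it explicitly:

import torch

# TorchScript checkpoint (weights baked into a single .pt file).
torch.jit.save(torch.jit.script(network), "controller.pt")

# State-dict bundle with the documented {"model": ..., "metadata": ...} layout.
torch.save(
    {
        "model": network.state_dict(),
        "metadata": {"pos_scale": 1.0, "vel_scale": 0.1, "effort_scale": 50.0},
    },
    "controller.pt",
)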

classmethod resolve_arguments(args)#
__init__(model_path)#

Initialize LSTM controller from a checkpoint file.

Supported checkpoint formats: TorchScript (.pt saved with torch.jit.save) and state-dict bundles ({"model": state_dict, "metadata": {...}} saved with torch.save).

Configuration is read from checkpoint metadata:

  • pos_scale (float): position-error scaling (default 1.0).

  • vel_scale (float): velocity-error scaling (default 1.0).

  • effort_scale (float): output effort scaling (default 1.0).

Parameters:

model_path (str) – Path to the checkpoint (.pt).
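
A minimal construction sketch; the import path follows the qualified name above, and the file name is a placeholder:

from newton.actuators import ControllerNeuralLSTM

controller = ControllerNeuralLSTM(model_path="controller.pt")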

compute(positions, velocities, target_pos, target_vel, feedforward, pos_indices, vel_indices, target_pos_indices, target_vel_indices, forces, state, dt, device=None)#
finalize(device, num_actuators)#
is_graphable()#
is_stateful()#
state(num_actuators, device)#
update_state(current_state, next_state)#
SHARED_PARAMS: ClassVar[set[str]] = {'model_path'}#