newton.actuators.ControllerNeuralLSTM#
- class newton.actuators.ControllerNeuralLSTM(model_path)[source]#
Bases: Controller
LSTM-based neural network controller.
Uses a pre-trained LSTM network to compute joint effort from position error and velocity error. The network maintains hidden and cell state across timesteps to capture temporal patterns.
The network must be callable as:
effort, (hidden_new, cell_new) = network(input, (hidden, cell))
where input has shape (batch, 1, 2) with features [pos_error, vel_error], and hidden/cell have shape (num_layers, batch, hidden_size).
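As a minimal sketch of a network satisfying this calling convention (the class name, hidden size, and output head below are illustrative assumptions, not part of Newton):

```python
import torch
import torch.nn as nn

class PDLSTMNet(nn.Module):
    """Hypothetical network matching the expected call signature:
    effort, (h_new, c_new) = net(input, (h, c)) with input (batch, 1, 2)."""

    def __init__(self, hidden_size=32, num_layers=1):
        super().__init__()
        # The controller inspects this `lstm` attribute to infer
        # num_layers and hidden_size for state allocation.
        self.lstm = nn.LSTM(input_size=2, hidden_size=hidden_size,
                            num_layers=num_layers, batch_first=True)
        self.head = nn.Linear(hidden_size, 1)  # map hidden state to effort

    def forward(self, x, state):
        y, (h, c) = self.lstm(x, state)
        effort = self.head(y[:, -1, :])  # (batch, 1)
        return effort, (h, c)

net = PDLSTMNet()
batch = 4
h0 = torch.zeros(net.lstm.num_layers, batch, net.lstm.hidden_size)
c0 = torch.zeros_like(h0)
x = torch.zeros(batch, 1, 2)  # features: [pos_error, vel_error]
effort, (h1, c1) = net(x, (h0, c0))
```

Carrying (h1, c1) into the next call is what lets the network capture temporal patterns across timesteps.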
The network is expected to have an lstm attribute (torch.nn.LSTM) so that num_layers and hidden_size can be inferred automatically.
Scale factors (pos_scale, vel_scale, effort_scale) are read from checkpoint metadata, falling back to 1.0 when absent. Supported checkpoint formats: TorchScript (.pt saved with torch.jit.save) and state-dict bundles ({"model": state_dict, "metadata": {...}} saved with torch.save).
- classmethod resolve_arguments(args)#
- __init__(model_path)#
Initialize LSTM controller from a checkpoint file.
Supported checkpoint formats: TorchScript (.pt saved with torch.jit.save) and state-dict bundles ({"model": state_dict, "metadata": {...}} saved with torch.save).
Configuration is read from checkpoint metadata:
- pos_scale (float): position-error scaling (default 1.0).
- vel_scale (float): velocity-error scaling (default 1.0).
- effort_scale (float): output effort scaling (default 1.0).
- Parameters:
model_path (str) – Path to the checkpoint (.pt).
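A sketch of producing a state-dict bundle in the documented format, using a hypothetical stand-in network (the network class and the effort_scale value are illustrative assumptions; only the {"model": ..., "metadata": ...} layout and the metadata keys come from the documentation above):

```python
import os
import tempfile

import torch
import torch.nn as nn

class TinyNet(nn.Module):
    """Hypothetical stand-in network with the required `lstm` attribute."""

    def __init__(self):
        super().__init__()
        self.lstm = nn.LSTM(2, 16, num_layers=1, batch_first=True)
        self.head = nn.Linear(16, 1)

    def forward(self, x, state):
        y, state = self.lstm(x, state)
        return self.head(y[:, -1, :]), state

net = TinyNet()
path = os.path.join(tempfile.mkdtemp(), "controller.pt")

# State-dict bundle: weights under "model", configuration under "metadata".
torch.save({
    "model": net.state_dict(),
    "metadata": {"pos_scale": 1.0, "vel_scale": 1.0, "effort_scale": 50.0},
}, path)

# Round-trip check mirroring the documented fallback: absent keys default to 1.0.
ckpt = torch.load(path)
meta = ckpt.get("metadata", {})
pos_scale = float(meta.get("pos_scale", 1.0))
effort_scale = float(meta.get("effort_scale", 1.0))
```

The resulting file can then be passed as model_path. For a TorchScript checkpoint, the equivalent step would be torch.jit.save(torch.jit.script(net), path) instead of torch.save.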
- compute(positions, velocities, target_pos, target_vel, feedforward, pos_indices, vel_indices, target_pos_indices, target_vel_indices, forces, state, dt, device=None)#
- finalize(device, num_actuators)#
- is_graphable()#
- is_stateful()#
- state(num_actuators, device)#
- update_state(current_state, next_state)#