Source code for rtgym.dataclass.trajectory

import os
import numpy as np
from typing import Optional
from .agent_state import AgentState


class Trajectory:
    """
    Class to hold trial data for the agent.

    Stores up to three arrays that share the layout
    ``(n_batch, n_time, n_features)``: the agent's coordinates (``coord``),
    its displacements (``disp``) and its head directions (``hd``). Any of
    them may be ``None``, but at least one must be present, and all present
    arrays must agree on the batch and time dimensions.
    """

    def __init__(
        self,
        coord: Optional[np.ndarray] = None,
        disp: Optional[np.ndarray] = None,
        hd: Optional[np.ndarray] = None
    ):
        """
        Initialize the Trajectory object with the given data.

        Args:
            coord: Coordinates of the agent as float values,
                shape (n_batch, n_time, n_features).
            disp: Displacements of the agent between coordinates,
                shape (n_batch, n_time, n_features).
            hd: Head directions of the agent,
                shape (n_batch, n_time, n_features).

        At least one of these must be provided.

        Raises:
            ValueError: If coord, disp and hd are all None.
            AssertionError: If a provided value is not a numpy array, is not
                three-dimensional, or disagrees with the other provided arrays
                on the batch or time dimension.
        """
        # Reject an empty trajectory up front.
        if coord is None and disp is None and hd is None:
            raise ValueError("At least one of coord, disp, or hd must be provided")

        named = (("coord", coord), ("disp", disp), ("hd", hd))

        # Type-check every argument before any shape check.
        for name, data in named:
            assert isinstance(data, np.ndarray) or data is None, f"{name} must be a numpy array"

        # Every provided array must be (n_batch, n_time, n_features).
        for name, data in named:
            if data is not None:
                assert data.ndim == 3, f"{name} must have three dimensions (n_batch, n_time, n_features)"

        # Batch and time dimensions must agree across all provided arrays.
        present = [data for _, data in named if data is not None]
        reference = present[0]
        for data in present[1:]:
            assert data.shape[0] == reference.shape[0], "All provided data must have the same batch dimension"
            assert data.shape[1] == reference.shape[1], "All provided data must have the same time dimension"

        # coord is exposed through a property backed by _coord; disp and hd are plain attributes.
        self._coord = coord
        self.disp = disp
        self.hd = hd

    def copy(self):
        """
        Returns a copy of the Trajectory object.

        Every non-None array is copied with ``ndarray.copy``, so the new object
        does not share array buffers with this one.
        """
        return Trajectory(
            coord=self.coord.copy() if self.coord is not None else None,
            disp=self.disp.copy() if self.disp is not None else None,
            hd=self.hd.copy() if self.hd is not None else None
        )

    def __getitem__(self, index):
        """
        Index the trajectory along its time axis (axis 1).

        Args:
            index (int or slice): An integer (Python ``int`` or numpy integer)
                selects a single time step; a slice selects a time range.

        Returns:
            AgentState: For an integer index, an AgentState built from copies of
                each array at that time step (shape (n_batch, n_features)).
            Trajectory: For a slice, a new Trajectory holding copies of the
                selected time range.

        Raises:
            TypeError: If index is neither an integer nor a slice.
        """
        def take(data):
            # Copy so the result never aliases this trajectory's buffers.
            return data[:, index, :].copy() if data is not None else None

        # np.integer also accepts indices such as np.int64 produced by numpy code.
        if isinstance(index, (int, np.integer)):
            return AgentState(
                coord=take(self.coord),
                hd=take(self.hd),
                disp=take(self.disp)
            )
        if isinstance(index, slice):
            return Trajectory(
                coord=take(self.coord),
                hd=take(self.hd),
                disp=take(self.disp)
            )
        raise TypeError("Index must be an int or slice")

    def _time_length(self):
        """Return the time dimension of the first non-None array among coord, hd, disp."""
        for data in (self.coord, self.hd, self.disp):
            if data is not None:
                return data.shape[1]
        raise ValueError("Trajectory holds no data: coord, hd and disp are all None")

    def __len__(self):
        """
        Returns the number of time steps in the trajectory.
        """
        return self._time_length()

    def slice(self, start, end):
        """
        Return a new Trajectory covering time steps ``start`` (inclusive) to
        ``end`` (exclusive). Equivalent to ``self.t_range((start, end))``.
        """
        return self.t_range((start, end))

    @property
    def float_coord(self):
        """Alias for ``coord``: the coordinates as stored, without casting."""
        return self.coord

    @property
    def coord(self):
        """Coordinates of the agent, shape (n_batch, n_time, n_features), or None."""
        return self._coord

    @coord.setter
    def coord(self, value):
        # No validation is performed on assignment.
        self._coord = value

    @property
    def int_coord(self):
        """Coordinates cast with ``astype(int)``, i.e. truncated toward zero rather than rounded."""
        return self.coord.astype(int)

    @property
    def n_steps(self):
        """
        Returns the number of time steps in the trajectory.
        """
        return self._time_length()

    def reshape(self, shape):
        """
        Reshape the trajectory data in place to the specified shape.

        Args:
            shape (tuple): ``(n_batch, n_time)``. The trailing feature dimension of
                each array is preserved. numpy reshape rules apply, so the total
                number of elements of each array must stay the same.

        Returns:
            Trajectory: ``self``, to allow chaining.
        """
        assert isinstance(shape, tuple) and len(shape) == 2, "Shape must be a tuple of length 2"
        n_batch, n_time = shape
        # Skip absent arrays; calling reshape on None would raise AttributeError.
        if self.coord is not None:
            self.coord = self.coord.reshape(n_batch, n_time, self.coord.shape[2])
        if self.hd is not None:
            self.hd = self.hd.reshape(n_batch, n_time, self.hd.shape[2])
        if self.disp is not None:
            self.disp = self.disp.reshape(n_batch, n_time, self.disp.shape[2])
        return self

    def t_range(self, range_):
        """
        Returns a new Trajectory object with data trimmed to the specified range.

        Args:
            range_ (list or tuple): A 2-element list or tuple ``(start, end)``.
                Time steps ``start`` (inclusive) to ``end`` (exclusive) are kept.

        Returns:
            Trajectory: A new Trajectory object with trimmed data. Its arrays are
            basic numpy slices (views) of this object's arrays, not copies.
        """
        # Check if the range is valid
        assert len(range_) == 2, "range_ must be a tuple of length 2"
        start, end = range_
        assert start < end, "range_[0] must be less than range_[1]"
        # Measure the duration from whichever array is present, not only coord.
        assert end <= self.n_steps, "range_[1] must be less than or equal to the trial duration"
        assert start >= 0, "range_[0] must be greater than or equal to 0"
        # Trim the data
        return Trajectory(
            coord=self.coord[:, start:end] if self.coord is not None else None,
            hd=self.hd[:, start:end] if self.hd is not None else None,
            disp=self.disp[:, start:end] if self.disp is not None else None
        )

    @staticmethod
    def load(path):
        """
        Load a Trajectory object from a ``.npz`` file.

        Args:
            path (str): The file path to load the Trajectory object from. It must
                end in ``.npz``. The archive may contain any of the keys
                ``coord``, ``hd`` and ``disp``; missing keys become None.

        Returns:
            Trajectory: A Trajectory object loaded from the specified file.
        """
        assert os.path.splitext(path)[1] == '.npz', "File extension must be .npz"
        # np.load on an .npz returns an NpzFile that keeps the file open; read all
        # arrays inside the `with` so the handle is released deterministically.
        with np.load(path) as archive:
            arrays = {
                key: (archive[key] if key in archive else None)
                for key in ('coord', 'hd', 'disp')
            }
        return Trajectory(**arrays)

    @classmethod
    def from_dict(cls, state_dict):
        """
        Build a Trajectory from a mapping such as the one returned by ``state_dict``.

        Args:
            state_dict: A mapping with any of the keys ``coord``, ``hd`` and
                ``disp``. Missing keys are treated as None.

        Returns:
            Trajectory: A new Trajectory wrapping the given arrays (not copied).
        """
        return cls(
            coord=state_dict['coord'] if 'coord' in state_dict else None,
            hd=state_dict['hd'] if 'hd' in state_dict else None,
            disp=state_dict['disp'] if 'disp' in state_dict else None
        )

    def state_dict(self):
        """
        Return the stored arrays (not copies) keyed by ``'coord'``, ``'hd'`` and
        ``'disp'``. Absent arrays map to None.
        """
        return {
            'coord': self.coord,
            'hd': self.hd,
            'disp': self.disp
        }