# Source code for rtgym.dataclass.agent_state
import numpy as np
# [docs]
class AgentState:
"""Agent state data container for a single timestep.
Holds the state data of the agent including coordinates, head directions,
and displacements. This represents a snapshot of the agent's state at a
single timestep and can hold multiple batches of data.
Args:
coord (np.ndarray, optional): Integer coordinates of the agent.
hd (np.ndarray, optional): Head direction of the agent in radians.
disp (np.ndarray, optional): Displacement vector of the agent.
Raises:
ValueError: If all input parameters are None.
Note:
For unknown values, use np.nan as placeholder. This data class helps
standardize the data format for all sensory modalities.
"""
def __init__(
self,
coord: np.ndarray = None,
hd: np.ndarray = None,
disp: np.ndarray = None
):
# Ensure that at least one input is not None
if coord is None and hd is None and disp is None:
raise ValueError("At least one of coord, hd, or disp must be provided.")
# Determine batch size based on the first non-None input
batch_size = next(
x.shape[0] for x in (coord, hd, disp) if x is not None
)
# Replace None inputs with np.nan arrays of appropriate shapes
coord = coord if coord is not None else np.full((batch_size, 2), np.nan)
hd = hd if hd is not None else np.full((batch_size, 1), np.nan)
disp = disp if disp is not None else np.full((batch_size, 2), np.nan)
# Check if the input data are all numpy arrays
assert isinstance(coord, np.ndarray), "coord must be a numpy array"
assert isinstance(hd, np.ndarray), "hd must be a numpy array"
assert isinstance(disp, np.ndarray), "disp must be a numpy array"
# Check if the input data all have two dimensions, (n_batch, n_features)
assert coord.ndim == 2, "coord must have two dimensions, (n_batch, 2)"
assert hd.ndim == 2, "hd must have two dimensions, (n_batch, 2)"
assert disp.ndim == 2, "disp must have two dimensions, (n_batch, 2)"
self.coord = coord
self.hd = hd
self.disp = disp
@property
def int_coord(self):
return self.coord.astype(int)
@property
def float_coord(self):
return self.coord