# Source code for rtgym.agent.sensory.movement_modulated.speed_cell

import numpy as np
from typing import Union
import matplotlib.pyplot as plt
from rtgym.dataclass import AgentState, Trajectory
from .displacement_abs import DisplacementAbs
from typing import Any


class SpeedCell(DisplacementAbs):
    """Speed cells that encode movement speed.

    Generates sensory responses from the absolute speed of agent movement:
    the Euclidean norm of each (x, y) displacement divided by ``t_res``.
    Direction is discarded, so every cell starts from the same speed signal.

    Inherits from DisplacementAbs and converts displacement to speed.

    Args:
        arena (Arena): Arena environment object.
        **kwargs: Additional keyword arguments, including:
            n_cells (int): Number of speed cells; must be > 0.
            sigma_s (float): Temporal smoothing sigma in seconds. It is
                converted to time steps and stored as ``sigma_ts``.
                Defaults to 0.0.
            ssigma_s (float): Second smoothing sigma, also given in seconds.
                It is converted to time steps and stored as ``ssigma_ts``.
                Defaults to 0.0. NOTE(review): how the base class applies it
                is not visible here.
            normalize (bool): Whether to normalize cell responses
                (presumably handled by DisplacementAbs).

    Attributes:
        sens_type (str): Sensory type identifier 'speed_cell'.
    """

    sens_type = 'speed_cell'

    def __init__(self, arena, **kwargs):
        """Initialize speed cells and convert smoothing sigmas to time steps.

        Args:
            arena (Arena): Arena environment object.
            **kwargs: See the class docstring. They are forwarded unchanged
                to DisplacementAbs.
        """
        super().__init__(arena, **kwargs)
        # Parameters are given in seconds and stored in time steps.
        self.sigma_ts = self._t_to_ts(kwargs.get('sigma_s', 0.0))
        self.ssigma_ts = self._t_to_ts(kwargs.get('ssigma_s', 0.0))
        self._check_params()

    def _check_params(self):
        """Validate parameters.

        Raises:
            AssertionError: If ``n_cells`` is not positive.
        """
        assert self.n_cells > 0, "n_cells <= 0"

    @staticmethod
    def _displacement_to_velocity(displacement):
        """Convert displacement vectors to speed.

        Takes the Euclidean norm over the last (x, y) axis. That axis is kept
        with length 1, so the output has the same rank as the input.

        Args:
            displacement (np.ndarray): Displacement array whose last axis
                holds (x, y), e.g. shape (n_batch, n_time, 2).

        Returns:
            np.ndarray: Speed values with the last axis reduced to size 1,
            e.g. shape (n_batch, n_time, 1).
        """
        # Reduce over the last axis instead of a hard-coded axis=2. This gives
        # identical output for (n_batch, n_time, 2) input and also accepts
        # inputs that have no time axis, e.g. (n_batch, 2).
        return np.linalg.norm(displacement, axis=-1, keepdims=True)

    def get_response(self, agent_data: Union[AgentState, Trajectory, Any]):
        """Get speed responses for given agent data.

        For a Trajectory, the speed signal is duplicated across cells with
        ``_duplicate_res`` and then passed through ``_smooth_res``. Both
        helpers are inherited from DisplacementAbs. For an AgentState, the
        raw speed (a single channel on the last axis) is returned without
        duplication or smoothing.

        Args:
            agent_data (AgentState or Trajectory): Agent data containing
                displacement in ``.disp``.

        Returns:
            np.ndarray: Speed responses (displacement norm / ``t_res``).

        Raises:
            ValueError: If agent_data type is not supported.
        """
        if isinstance(agent_data, Trajectory):
            speed = self._displacement_to_velocity(agent_data.disp) / self.t_res
            return self._smooth_res(self._duplicate_res(speed))
        if isinstance(agent_data, AgentState):
            # NOTE(review): unlike the Trajectory path, this returns a single
            # speed channel rather than one per cell. Confirm callers expect it.
            return self._displacement_to_velocity(agent_data.disp) / self.t_res
        raise ValueError(f"Invalid agent_data type: {type(agent_data)}, must be rtgym.dataclass.Trajectory or rtgym.dataclass.AgentState")

    def get_specs(self):
        """Get specifications of this sensory modality.

        Returns:
            dict: The DisplacementAbs specifications, unchanged. No
            speed-cell-specific entries are added here.
        """
        return super().get_specs()

    def vis(self, N=5, max_vel=1.0, n_bins=10, cmap='jet', *args, **kwargs):
        """Visualize speed response curves.

        Evaluates ``n_bins`` evenly spaced speeds in [0, max_vel]. These pass
        through the same duplication and smoothing used for trajectories.
        Speed vs response is then plotted for the first ``N`` cells, one
        subplot per cell.

        Args:
            N (int): Number of cells to visualize. Clamped to the number of
                response channels actually produced.
            max_vel (float): Maximum speed for the visualization range.
            n_bins (int): Number of speed samples to plot.
            cmap (str): Matplotlib colormap name used to color the curves.
            *args: Unused; accepted for interface compatibility.
            **kwargs: Unused; accepted for interface compatibility.
        """
        # Lay speeds out as (batch=1, n_bins, 1) so they follow the same path
        # as a trajectory's speed signal. NOTE(review): _smooth_res therefore
        # smooths along the speed-sample axis here.
        vels = np.linspace(0, max_vel, n_bins).reshape(1, -1, 1)
        smooth_res = self._smooth_res(self._duplicate_res(vels))

        # Never index past the available response channels (IndexError before).
        n_plot = min(N, smooth_res.shape[-1])
        colors = plt.get_cmap(cmap)(np.linspace(0, 1, max(n_plot, 1)))
        # squeeze=False keeps a 2-D Axes array, so indexing also works when
        # n_plot == 1. Otherwise subplots returns a bare Axes.
        _, axs = plt.subplots(1, n_plot, figsize=(4 * n_plot, 4), squeeze=False)
        speeds = vels[0, :, 0]
        for i in range(n_plot):
            # x axis is the speed, y axis is the response of cell i.
            ax = axs[0, i]
            ax.plot(speeds, smooth_res[0, :, i], color=colors[i])
            ax.set_xlabel('speed')
            ax.set_ylabel('response')
            ax.set_title(f'cell {i}')
        plt.show()