"""Source code for rtgym.agent.sensory.movement_modulated.direction_cell."""

import numpy as np
import rtgym
import rtgym.dataclass
from .mm_base import MMBase
import matplotlib.pyplot as plt
from rtgym.dataclass import Trajectory, AgentState
from typing import Union

class DirectionCell(MMBase):
    """Direction cells with Gaussian tuning to movement direction.

    These cells respond to the direction of agent movement with Gaussian
    tuning curves. Each cell has a preferred direction and fires maximally
    when the agent moves in that direction.

    Args:
        arena (Arena): Arena environment object.
        **kwargs: Additional keyword arguments including:
            n_cells (int): Number of direction cells.
            magnitude (float): Maximum magnitude of cell responses.
            normalize (bool): Whether to min-max normalize the whole response
                table to [0, 1]. Defaults to False.
            sorted (bool): Whether to sort cells by preferred directions.
                Defaults to True.
            n_bins (int): Number of direction bins for discretization.
                Defaults to 360.
            sigma (float): Standard deviation of the Gaussian tuning kernel,
                in radians. Defaults to 2.
            msigma (float): Standard deviation of magnitude modulation (not
                read by this class; presumably consumed by ``MMBase``).

    Attributes:
        sens_type (str): Sensory type identifier 'direction_cell'.
        dirs (np.ndarray): Bin directions, shape (n_bins,), evenly spaced in
            [-π, π).
        pref_dirs (np.ndarray): Preferred direction of each cell, shape
            (n_cells,).
        mm_responses (np.ndarray): Precomputed tuning table, shape
            (n_cells, n_bins).
    """
    sens_type = 'direction_cell'

    def __init__(self, arena, **kwargs):
        super().__init__(arena, **kwargs)
        # NOTE(review): n_cells, magnitude and rng are read below but not set
        # here; presumably MMBase.__init__ provides them.
        self.sigma = kwargs.get('sigma', 2)
        self.n_bins = kwargs.get('n_bins', 360)
        self.sorted = kwargs.get('sorted', True)
        self.normalize = kwargs.get('normalize', False)

        self._check_params()
        self._init_mm_responses()

    def _check_params(self):
        """Validate constructor parameters."""
        assert self.n_cells > 0, "n_cells <= 0"
        assert self.n_bins > 0, "n_bins <= 0"

    def _init_mm_responses(self):
        """Build the (n_cells, n_bins) tuning table.

        Each entry is a Gaussian of the wrapped angular distance between a
        cell's preferred direction and a bin's direction.
        """
        # Bin directions, shape (n_bins,): -π + 2πk / n_bins for k = 0..n_bins-1.
        self.dirs = np.linspace(-np.pi, np.pi, self.n_bins, endpoint=False)

        # Random preferred direction per cell, shape (n_cells,).
        self.pref_dirs = self.rng.uniform(-np.pi, np.pi, self.n_cells)
        if self.sorted:
            self.pref_dirs = np.sort(self.pref_dirs)

        # Angular distance wrapped into [-π, π), shape (n_cells, n_bins).
        # Its sign is irrelevant because it is squared below.
        angular_dist = np.abs(self.pref_dirs[:, np.newaxis] - self.dirs[np.newaxis, :])
        angular_dist = np.mod(angular_dist + np.pi, 2 * np.pi) - np.pi

        self.mm_responses = np.exp(-0.5 * (angular_dist / self.sigma) ** 2)

        if self.normalize:
            # Zero-one normalisation over the whole table.
            lo = np.min(self.mm_responses)
            span = np.max(self.mm_responses) - lo
            # When every entry is identical (e.g. a single cell with a single
            # bin), dividing by a zero span would fill the table with NaN;
            # leave the table unchanged in that case.
            if span > 0:
                self.mm_responses = (self.mm_responses - lo) / span

    def get_response(self, agent_state: Union[AgentState, Trajectory]):
        """Look up direction-cell responses for each displacement.

        Args:
            agent_state: Object exposing ``disp``. It is indexed as
                ``disp[:, :, 0]`` and ``disp[:, :, 1]``, so it must be at least
                3-D (presumably (batch, time, 2) — TODO confirm).

        Returns:
            np.ndarray: Responses scaled by ``self.magnitude``. For 2-D
            movement-direction arrays of shape (B, T), the result has shape
            (B, T, n_cells).
        """
        disp = agent_state.disp
        # np.arctan2(y, x) takes y first, so disp[..., 1] plays the role of x.
        # NOTE(review): confirm this matches the arena's coordinate convention.
        mv_dir = np.arctan2(disp[:, :, 0], disp[:, :, 1])

        # Drop a trailing singleton axis if present.
        if mv_dir.ndim == 3 and mv_dir.shape[-1] == 1:
            mv_dir = mv_dir.squeeze(-1)

        # arctan2 returns angles in the *closed* interval [-π, π], so
        # mv_dir == π maps to index n_bins. π and -π are the same direction,
        # so wrap it back onto bin 0 instead of indexing out of bounds.
        indices = (((mv_dir + np.pi) / (2 * np.pi)) * self.n_bins).astype(np.int64)
        indices %= self.n_bins

        # (n_cells, B, T) -> (B, T, n_cells)
        responses = np.take(self.mm_responses, indices, axis=1)
        return responses.transpose(1, 2, 0) * self.magnitude

    def get_specs(self):
        """Return the base-class specs extended with this class's parameters.

        The added keys match the keyword arguments read in ``__init__``.
        """
        specs = super().get_specs()
        specs['magnitude'] = self.magnitude
        specs['normalize'] = self.normalize
        specs['sigma'] = self.sigma
        specs['n_bins'] = self.n_bins
        specs['sorted'] = self.sorted
        return specs

    def vis(self, N, **kwargs):
        """Plot the tuning curves of the first ``N`` cells as a heat map.

        Args:
            N (int): Number of cells (rows of ``mm_responses``) to display.
            **kwargs: Accepted for interface compatibility; not used.
        """
        fig, ax = plt.subplots()
        cax = ax.imshow(self.mm_responses[:N], aspect='auto', cmap='jet')
        fig.colorbar(cax, ax=ax)
        # Bin k corresponds to -π + 2πk / n_bins. The middle bin is therefore
        # (approximately) 0 rad, and the last bin sits just below π, which
        # wraps around to -π.
        ax.set_xticks([0, self.n_bins // 4, self.n_bins // 2, 3 * self.n_bins // 4, self.n_bins - 1])
        ax.set_xticklabels(['-π', '-π/2', '0', 'π/2', 'π (wrap around)'])
        ax.set_xlabel('Direction (radian)')
        ax.set_ylabel('Cells')
        ax.set_title('Direction Cells')
        plt.show()