import rtgym
import numpy as np
import rtgym.dataclass
from .displacement_abs import DisplacementAbs
# [docs]
class HeadDirectionCell(DisplacementAbs):
    """Head direction cells that respond to the agent's heading direction.

    Each cell has a preferred direction and responds most strongly when the
    agent faces it. The tuning curve is a Gaussian of the wrapped (circular)
    angular distance between the heading and the preferred direction. All
    responses are precomputed on a regular grid of ``n_bins`` directions
    spanning [-pi, pi] and looked up by nearest bin at query time.

    Args:
        arena (Arena): Arena environment object.
        **kwargs: Additional keyword arguments including:
            n_cells (int): Number of head direction cells.
            magnitude (float): Scale factor applied to the cell responses.
                Defaults to 1.
            normalize (bool): Whether to min-max normalize the response table
                to [0, 1] before scaling. Defaults to False.
            sorted (bool): Whether to sort cells by preferred directions.
                Defaults to True.
            n_bins (int): Number of direction bins for discretization.
                Defaults to 360.
            sigma (float): Standard deviation of the Gaussian tuning kernel,
                in radians. Defaults to 2.

    Attributes:
        sens_type (str): Sensory type identifier 'head_direction_cell'.
        dirs (np.ndarray): Bin directions in radians, shape (n_bins,).
        hd_dirs (np.ndarray): Preferred direction of each cell in radians,
            shape (n_cells,).
        mm_responses (np.ndarray): Precomputed response table, shape
            (n_cells, n_bins).
    """

    sens_type = 'head_direction_cell'

    def __init__(self, arena, **kwargs):
        # NOTE(review): self.n_cells and self.rng are not set in this class;
        # presumably the base-class constructor provides them -- confirm.
        super().__init__(arena, **kwargs)
        # parameters
        self.sigma = kwargs.get('sigma', 2)
        self.n_bins = kwargs.get('n_bins', 360)
        self.sorted = kwargs.get('sorted', True)
        self.normalize = kwargs.get('normalize', False)
        self.magnitude = kwargs.get('magnitude', 1)
        # check parameters, then build the lookup table
        self._check_params()
        self._init_mm_responses()

    def _check_params(self):
        """Validate the configuration before the response table is built.

        Raises:
            AssertionError: If ``n_cells`` is not positive.
            ValueError: If ``n_bins`` is smaller than 1.
        """
        # Raised explicitly (instead of a bare ``assert``) so the check also
        # runs under ``python -O``; exception type and message are unchanged.
        if not self.n_cells > 0:
            raise AssertionError("n_cells <= 0")
        # An empty direction grid would only fail later, inside get_response.
        if self.n_bins < 1:
            raise ValueError("n_bins must be >= 1")

    def _init_mm_responses(self):
        """Precompute the (n_cells, n_bins) table of tuning-curve responses.

        Sets ``self.dirs``, ``self.hd_dirs`` and ``self.mm_responses``.
        """
        # Regular, ordered grid of directions. The grid must be deterministic
        # and monotonic: ``vis`` labels column 0 as -pi and the last column as
        # pi, and nearest-bin lookup needs uniform angular resolution.
        # Both endpoints are included, so the first and last bins denote the
        # same physical direction.
        self.dirs = np.linspace(-np.pi, np.pi, self.n_bins)  # Shape: (n_bins,)

        # Assign random preferred directions to each cell
        self.hd_dirs = self.rng.uniform(-np.pi, np.pi, self.n_cells)  # Shape: (n_cells,)
        if self.sorted:
            self.hd_dirs = np.sort(self.hd_dirs)

        # Signed angular distance between every preferred direction and every
        # bin, wrapped into [-pi, pi) so tuning is continuous across +/-pi.
        angular_dist = np.mod(
            self.hd_dirs[:, np.newaxis] - self.dirs[np.newaxis, :] + np.pi,
            2 * np.pi,
        ) - np.pi

        # Gaussian response based on angular distance
        self.mm_responses = np.exp(-0.5 * (angular_dist / self.sigma) ** 2)  # Shape: (n_cells, n_bins)

        if self.normalize:
            # zero-one normalization over the whole table
            lowest = np.min(self.mm_responses)
            span = np.max(self.mm_responses) - lowest
            # A constant table has zero span; skip rescaling rather than
            # dividing by zero and filling the table with NaN.
            if span > 0:
                self.mm_responses = (self.mm_responses - lowest) / span

        self.mm_responses *= self.magnitude

    def get_response(self, agent_state: rtgym.dataclass.AgentState):
        """Get the movement modulated responses for the given trajectory.

        Args:
            agent_state: Agent state; its ``hds`` attribute holds the head
                directions in radians, of shape (n_batches, n_steps, 1) or
                (n_batches, n_steps).

        Returns:
            Movement modulated responses of shape (n_batches, n_steps, n_cells).
        """
        hd = np.asarray(agent_state.hds)

        # Drop the trailing singleton axis only for 3-D input; a 2-D
        # (n_batches, 1) array must keep its single time step.
        if hd.ndim == 3 and hd.shape[-1] == 1:
            hd = hd.squeeze(-1)  # Now shape: (n_batches, n_steps)

        # Wrap headings into [-pi, pi) so angles given in another convention
        # (e.g. [0, 2*pi)) land on the right bin instead of clamping to an
        # edge. A heading of exactly pi maps to -pi, i.e. the first bin, which
        # represents the same direction as the last one.
        hd = np.mod(hd + np.pi, 2 * np.pi) - np.pi

        # Index of the closest direction bin, shape (n_batches, n_steps)
        closest_indices = np.argmin(
            np.abs(hd[:, :, np.newaxis] - self.dirs[np.newaxis, np.newaxis, :]),
            axis=-1,
        )

        # Gather along the bin axis: (n_cells, n_batches, n_steps)
        gathered_responses = np.take(self.mm_responses, closest_indices, axis=1)

        # Reorder to (n_batches, n_steps, n_cells)
        return gathered_responses.transpose(1, 2, 0)

    def get_specs(self):
        """Return the base-class specs extended with 'magnitude' and 'normalize'."""
        specs = super().get_specs()
        specs['magnitude'] = self.magnitude
        specs['normalize'] = self.normalize
        return specs

    def vis(self, N, **kwargs):
        """Show the tuning curves of the first ``N`` cells as a heat map.

        Rows are cells and columns are direction bins running from -pi to pi.
        Calls ``plt.show()`` to display the figure.

        Args:
            N (int): Number of cells (rows of the response table) to display.
            **kwargs: Accepted but currently unused.
        """
        # Imported lazily so matplotlib is only needed when plotting.
        import matplotlib.pyplot as plt

        fig, ax = plt.subplots()
        ax.imshow(self.mm_responses[:N], aspect='auto', cmap='jet')
        ax.set_xticks([0, self.n_bins//4, self.n_bins//2, 3*self.n_bins//4, self.n_bins-1])
        ax.set_xticklabels(['-π', '-π/2', '0', 'π/2', 'π (wrap around)'])
        ax.set_xlabel('Direction (radian)')
        ax.set_ylabel('Cells')
        ax.set_title('Head Direction Cells')
        plt.show()