Source code for rtgym.agent.sensory.spatial_modulated.grid_cell

import numpy as np
from .sm_base import SMBase
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
import pickle


[docs] class GridCell(SMBase): """ Grid cell responses for a spatially modulated grid-like field. """ sens_type = 'grid_cell' def __init__(self, arena, **kwargs): """Initialize the grid cell responses. Args: arena: The arena object defining the spatial environment. **kwargs: Additional keyword arguments including: n_cells (int): Number of cells. sigma2scale (float): Scale to sigma ratio. scale (int or float): Scale (periodicity) of the grid cells. magnitude (float): Maximum magnitude of the cell responses. normalize (bool): If True, normalize the responses. orientation (float): Orientation of the grid cells (degrees). """ super().__init__(arena, **kwargs) # Parameters self.scale = kwargs.get('scale', 40) self.sigma2scale = kwargs.get('sigma2scale', 3.26) self.magnitude = kwargs.get('magnitude', None) self.normalize = kwargs.get('normalize', False) self.orientation = kwargs.get('orientation', 0) self.orientation = int(self.orientation) # Validate and initialize parameters self._check_params() self._init_params() self._init_response_map() def _check_params(self): """Validate the parameters.""" assert self.n_cells > 0, "n_cells must be greater than 0." if isinstance(self.sigma2scale, (int, float)): assert self.sigma2scale > 1, "sigma2scale must be greater than 1." elif isinstance(self.sigma, list): assert all(s > 0 for s in self.sigma), "All sigma values must be greater than 0." def _init_params(self): """Initialize and scale parameters based on spatial resolution.""" if isinstance(self.scale, (int, float)): self.scale = self.scale / self.arena.spatial_resolution elif isinstance(self.sigma, list): self.scale = np.array(self.sigma) / self.arena.spatial_resolution else: raise ValueError("scale or sigma must be a float or a list.") self.sigma = self.scale / self.sigma2scale / 2 def _generate_grid_phase_shifts(self): """Generate uniformly distributed points within a triangular region. Returns: points (ndarray): Array of shape (n_cells, 2) with uniformly distributed points. """ # Generate random points u, v = self.rng.uniform(size=self.n_cells), self.rng.uniform(size=self.n_cells) # Reflect points to ensure uniformity in the triangular region mask = u + v > 1 u[mask], v[mask] = 1 - u[mask], 1 - v[mask] # Define vertices of the triangle s = self.scale A, B, C = (0, 0), (0, -s), (np.sqrt(3) / 2 * s, -s / 2) # Convert to Cartesian coordinates self.grid_shifts = (1 - u - v)[:, None] * np.array(A) + \ u[:, None] * np.array(B) + \ v[:, None] * np.array(C) def _generate_grid_centers(self, extended_arena_dims): """Generate grid centers for the extended arena dimensions with optional rotation. Args: extended_arena_dims (tuple): Dimensions of the extended arena. Returns: grid_centers (ndarray): Array of grid center coordinates. """ # Calculate grid dimensions and add padding for phase shift arc_length = extended_arena_dims[1] * 2 / np.sqrt(3) dim_0 = extended_arena_dims[0] + arc_length / 2 dim_1 = arc_length dim_0 += self.sigma * 4 dim_1 += self.sigma * 4 # Create ranges for grid centers x, y = np.arange(0, dim_0, self.scale), np.arange(0, dim_1, self.scale) XX, YY = np.meshgrid(x, y) grid_centers = np.column_stack((XX.ravel(), YY.ravel())) # Apply hexagonal transformation trans_mat = np.array([[1, 1/2], [0, np.sqrt(3)/2]]) grid_centers = grid_centers @ trans_mat.T grid_centers[:, 0] -= arc_length / 2 # Center the grid # Rotate the grid centers around the center of the dimensions if self.orientation != 0: # Convert rotation angle to radians angle_rad = np.radians(self.orientation) # Define rotation matrix rotation_matrix = np.array([ [np.cos(angle_rad), -np.sin(angle_rad)], [np.sin(angle_rad), np.cos(angle_rad)] ]) # Compute the center of the dimensions center = np.array(extended_arena_dims) / 2 # Translate grid centers to origin, apply rotation, and translate back grid_centers = (grid_centers - center) @ rotation_matrix.T + center return np.int32(grid_centers) def _generate_grid_map(self): """ Generate standard grid centers for all cells without shifts. """ diag = np.linalg.norm(self.arena.dimensions) max_shift = np.max(np.abs(self.grid_shifts)) extended_dims = (int(diag * 2 + max_shift * 2), int(diag * 2 + max_shift * 2)) # Extended arena size grid_centers = self._generate_grid_centers(extended_dims) # Filter grid centers within bounds is_within_bounds = ( (grid_centers[:, 0] >= 0) & (grid_centers[:, 0] < extended_dims[0]) & (grid_centers[:, 1] >= 0) & (grid_centers[:, 1] < extended_dims[1]) ) grid_centers = grid_centers[is_within_bounds] # Create sample map and apply Gaussian filter sample_map = np.zeros(extended_dims) sample_map[grid_centers[:, 0], grid_centers[:, 1]] = 1 self.sample_map = gaussian_filter(sample_map, self.sigma, mode='constant') def _compute_res(self): """ Compute the responses for the grid cells. """ # Compute a sample map self._generate_grid_map() # Generate responses for each cell dims = self.arena.dimensions x_start = (self.sample_map.shape[0] - dims[0]) // 2 y_start = (self.sample_map.shape[1] - dims[1]) // 2 cells = np.zeros((self.n_cells, *dims)) for cell_idx in range(self.n_cells): _x = int(x_start + self.grid_shifts[cell_idx, 0]) _y = int(y_start + self.grid_shifts[cell_idx, 1]) cell = self.sample_map[_x:_x + dims[0], _y:_y + dims[1]] if self.normalize: cell = (cell - cell.min()) / (cell.max() - cell.min()) cells[cell_idx] = cell # Scale cells to have desired mean magnitude instead of max magnitude if self.magnitude is not None: # Calculate current mean for each cell cell_means = cells.mean(axis=(1, 2)) # Create scaling factors to achieve target mean magnitude scaling_factors = self.magnitude / cell_means # Apply scaling factors to each cell (broadcasting across spatial dimensions) cells = cells * scaling_factors[:, np.newaxis, np.newaxis] self.response_map = cells def _init_response_map(self): """ Initialize spatially modulated responses for the grid cells. """ assert len(self.arena.dimensions) == 2, "Only 2D arenas are supported." super()._init_response_map() self._generate_grid_phase_shifts() self._compute_res()
[docs] def get_specs(self): """Retrieve statistics of the cell responses.""" specs = super().get_specs() specs.update({ 'cell_max_avg': self.response_map.max(axis=(1, 2)).mean(), 'cell_min_avg': self.response_map.min(axis=(1, 2)).mean(), 'cell_mean_avg': self.response_map.mean(axis=(1, 2)).mean(), }) return specs
[docs] def state_dict(self): """Get the essential attributes of the GridCell object. Returns: dict: Dictionary of the essential attributes. """ return { 'sens_type': self.sens_type, 'n_cells': self.n_cells, 'sigma': self.sigma, 'scale': self.scale, 'magnitude': self.magnitude, 'normalize': self.normalize, 'orientation': self.orientation, 'grid_shifts': self.grid_shifts, }
[docs] @classmethod def load_from_dict(cls, state_dict, arena): """Load the GridCell object from a dictionary and reconstruct it. Args: state_dict (dict): Dictionary containing the object's state. arena (Arena): Arena object to reinitialize the class. Returns: GridCell: Reconstructed GridCell object. """ obj = cls( arena, n_cells=state_dict['n_cells'], sigma=state_dict['sigma'], scale=state_dict['scale'], magnitude=state_dict['magnitude'], normalize=state_dict['normalize'], orientation=state_dict['orientation'] ) obj.grid_shifts = state_dict['grid_shifts'] obj._compute_res() return obj