# Source code for rtgym.agent.sensory.spatial_modulated.place_cell

import warnings

import numpy as np
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt

from .sm_base import SMBase


class PlaceCell(SMBase):
    """Place cell sensory responses.

    Place cells are neurons that fire when an animal is in specific locations
    in an environment. Each place cell has a "place field" - a region where it
    fires maximally. The firing rate typically follows a Gaussian profile
    centered on the place field, decreasing with distance from the center.

    Place cells are a key component of the brain's spatial navigation system,
    discovered by O'Keefe and Dostrovsky in 1971. They are primarily found in
    the hippocampus and are thought to form a cognitive map of the environment.

    This implementation models place cells with Gaussian or
    difference-of-Gaussian tuning curves. The difference-of-Gaussian option
    gives place fields an inhibitory surround.

    Args:
        arena (Arena): Arena environment object defining the navigation space.
        **kwargs: Additional keyword arguments including:
            n_cells (int): Number of place cells to create.
            sigma (float or sequence): Place field width (Gaussian standard
                deviation) in spatial units. It is converted to grid units by
                dividing by ``arena.spatial_resolution`` and truncating to an
                int. A scalar is shared by all cells. A sequence gives one
                sigma per cell and must have length ``n_cells``. Defaults to 8.
            ssigma (float): Standard deviation of random per-cell variation in
                place field width, in spatial units. Only used when sigma is a
                scalar. Defaults to 0.
            dg_ratio (float): Surround-to-center width factor for
                difference-of-Gaussian fields. Values > 1 subtract a wider blur
                of each field to create an inhibitory surround. Defaults to 1
                (simple Gaussian).
            magnitude (float): Scaling factor applied to responses when
                ``normalize`` is True. It is ignored otherwise. Defaults to
                None (no scaling).
            normalize (bool): If True, post-process the response map.
                Simple-Gaussian maps are min-max scaled jointly over all cells
                to [0, 1] and then multiplied by ``magnitude``.
                Difference-of-Gaussian maps are multiplied by ``magnitude`` and
                then mean-centered per cell. Defaults to False.

    Attributes:
        sens_type (str): Sensory type identifier 'place_cell'.
        response_map (np.ndarray): Spatial response map of shape
            (n_cells, *arena_dimensions). Contains the firing rate map for
            each cell across the entire arena.
    """

    sens_type = 'place_cell'

    def __init__(self, arena, **kwargs):
        super().__init__(arena, **kwargs)
        # Parameters. sigma/ssigma are in spatial units until _init_params.
        self.sigma = kwargs.get('sigma', 8)              # Place field width
        self.ssigma = kwargs.get('ssigma', 0)            # Random variation in field width
        self.dg_ratio = kwargs.get('dg_ratio', 1)        # For difference-of-Gaussian fields
        self.magnitude = kwargs.get('magnitude', None)   # Response scaling
        self.normalize = kwargs.get('normalize', False)  # Post-process responses

        # Check parameters and initialize responses.
        self._check_params()
        self._init_params()
        self._init_response_map()

    @staticmethod
    def _is_scalar(value):
        """Return True if ``value`` is a scalar (Python or NumPy number)."""
        return np.ndim(value) == 0

    def _check_params(self):
        """Validate initialization parameters.

        Raises:
            AssertionError: If a parameter is invalid. Examples are a
                non-positive n_cells or sigma, a negative ssigma, or a
                per-cell sigma sequence whose length differs from n_cells.
        """
        assert self.n_cells > 0, "n_cells <= 0"
        # sigma is either a single scalar (int or float) or one value per cell.
        if self._is_scalar(self.sigma):
            assert self.sigma > 0, "sigma <= 0"
        else:
            assert len(self.sigma) == self.n_cells, "len(sigma) != n_cells"
            for s in self.sigma:
                assert s > 0, "sigma <= 0"
        assert self.ssigma >= 0, "ssigma < 0"

    def _init_params(self):
        """Convert sigma and ssigma from spatial units to grid units.

        Both values are divided by ``arena.spatial_resolution``. sigma is
        truncated to an int, so a width below one grid cell becomes 0 (no
        blur). ssigma stays a float because it only parameterizes a normal
        draw, and truncating it would silently remove small variations.
        """
        resolution = self.arena.spatial_resolution
        if self._is_scalar(self.sigma):
            self.sigma = int(self.sigma / resolution)
            # ssigma must be in the same (grid) units as sigma.
            self.ssigma = self.ssigma / resolution
        else:
            self.sigma = [int(s / resolution) for s in self.sigma]
            if self.ssigma != 0:
                warnings.warn("ssigma is not used when sigma is a list")

    def _cell_sigmas(self):
        """Return an array with one non-negative grid-unit sigma per cell."""
        if not self._is_scalar(self.sigma):
            sigmas = np.asarray(self.sigma, dtype=float)
        elif self.ssigma != 0:
            sigmas = self.rng.normal(self.sigma, self.ssigma, self.n_cells)
        else:
            sigmas = np.full(self.n_cells, self.sigma, dtype=float)
        # A normal draw can go negative. Clip it to 0, which gaussian_filter
        # treats as "no blur".
        return np.maximum(sigmas, 0)

    def _normalize_map(self, cells):
        """Post-process the raw response map. Called only when normalize is True."""
        if self.dg_ratio > 1:
            # Difference-of-Gaussian fields: scale, then zero-center each cell.
            if self.magnitude is not None:
                cells = cells * self.magnitude
            return cells - cells.mean(axis=(1, 2), keepdims=True)
        # Gaussian fields: joint min-max scaling over all cells to [0, 1].
        lo, hi = cells.min(), cells.max()
        span = hi - lo
        # Guard against a constant map, which would otherwise divide by zero.
        cells = (cells - lo) / span if span > 0 else np.zeros_like(cells)
        if self.magnitude is not None:
            cells = cells * self.magnitude
        return cells

    def _init_response_map(self):
        """Generate the spatial response fields for all place cells.

        Each cell's center is drawn without replacement from the arena's free
        space. A unit impulse at that center is blurred with a Gaussian of the
        cell's sigma. If ``dg_ratio > 1``, a wider blur is subtracted to form
        an inhibitory surround. The result is stored in
        ``self.response_map`` with shape ``(n_cells, *arena.dimensions)``.
        """
        super()._init_response_map()
        arena_dims = self.arena.dimensions
        arena_free_space = self.arena.free_space

        # Randomly select distinct cell centers from the available free space.
        # This requires n_cells <= the number of free positions.
        center_idx = self.rng.choice(arena_free_space.shape[0], size=self.n_cells, replace=False)
        place_centers = arena_free_space[center_idx]

        cells = np.zeros((self.n_cells, *arena_dims))
        # Drawn after the centers so the RNG consumption order is unchanged.
        sigma_list = self._cell_sigmas()

        for i, (center, sigma) in enumerate(zip(place_centers, sigma_list)):
            cells[i, center[0], center[1]] = 1
            cells[i] = gaussian_filter(cells[i], sigma)
            if self.dg_ratio > 1:
                # NOTE(review): the surround blur is applied to the already
                # blurred center field. The effective surround width is
                # therefore about sigma * sqrt(1 + dg_ratio**2), not
                # sigma * dg_ratio. Kept as-is to preserve existing maps;
                # confirm this is intended.
                cells[i] -= gaussian_filter(cells[i], sigma * self.dg_ratio)

        if self.normalize:
            cells = self._normalize_map(cells)
        self.response_map = cells

    def get_specs(self):
        """Get statistical specifications of the place cell population.

        Returns:
            dict: The base-class specs extended with:
                - cell_max_avg: Average of each cell's maximum response
                - cell_min_avg: Average of each cell's minimum response
                - cell_mean_avg: Average of each cell's mean response
        """
        specs = super().get_specs()
        specs['cell_max_avg'] = self.response_map.max(axis=(1, 2)).mean()
        specs['cell_min_avg'] = self.response_map.min(axis=(1, 2)).mean()
        specs['cell_mean_avg'] = self.response_map.mean(axis=(1, 2)).mean()
        return specs