# Source code for rtgym.agent.sensory.spatial_modulated.weak_sm_cell

import numpy as np
from .sm_base import SMBase
from scipy.ndimage import gaussian_filter
import matplotlib.pyplot as plt
from scipy.stats import truncnorm


class WeakSMCell(SMBase):
    """Weakly spatially modulated cells with Gaussian-smoothed responses.

    These cells have broader, weaker spatial tuning compared to place cells,
    representing cells that are spatially modulated but not strongly
    place-specific. Each cell's response field is standard-normal noise over
    the arena grid, smoothed with a Gaussian filter (``mode='constant'``).

    Args:
        arena (Arena): Arena environment object.
        **kwargs: Additional keyword arguments including:
            n_cells (int): Number of weakly spatially modulated cells.
            sigma (float or list of float): Sigma of the Gaussian filter in
                spatial units. It is converted to grid units by dividing by
                ``arena.spatial_resolution`` and truncating to ``int``. A
                list gives one sigma per cell and must have length 1 (which
                is broadcast) or ``n_cells``. Defaults to 8.
            ssigma (float): If > 0, each cell's sigma is drawn from a normal
                distribution centred on ``sigma`` with this standard
                deviation, truncated to [0, 100] grid units. Defaults to 0,
                meaning every cell uses ``sigma`` as-is.
            magnitude (float or None): If given, each cell is rescaled so
                that its MEAN response equals this value. Defaults to None
                (no rescaling).
            normalize (bool): If True, min-max normalize each smoothed cell
                to [0, 1] before any magnitude rescaling. Defaults to False.

    Attributes:
        sens_type (str): Sensory type identifier 'weak_sm_cell'.
        response_map (np.ndarray): Spatial response map of shape
            (n_cells, *arena_dimensions).
    """
    sens_type = 'weak_sm_cell'

    def __init__(self, arena, **kwargs):
        super().__init__(arena, **kwargs)
        # parameters
        self.sigma = kwargs.get('sigma', 8)
        self.ssigma = kwargs.get('ssigma', 0)
        self.magnitude = kwargs.get('magnitude', None)
        self.normalize = kwargs.get('normalize', False)
        # check parameters and initialize responses
        self._check_params()
        self._init_params()
        self._init_response_map()

    @staticmethod
    def _is_scalar_number(value):
        """Return True for Python and NumPy int/float scalars."""
        return isinstance(value, (int, float, np.integer, np.floating))

    def _check_params(self):
        """Validate constructor parameters (raises AssertionError)."""
        assert self.n_cells > 0, "n_cells <= 0"
        # sigma can be a scalar number or a per-cell list/tuple
        if self._is_scalar_number(self.sigma):
            assert self.sigma > 0, "sigma <= 0"
        elif isinstance(self.sigma, (list, tuple)):
            for s in self.sigma:
                assert s > 0, "sigma <= 0"
            # per-cell sigmas are broadcast against n_cells in
            # _generate_smcells, so only these lengths can work
            assert len(self.sigma) in (1, self.n_cells), \
                "len(sigma) must be 1 or n_cells"

    def _init_params(self):
        """Convert sigma from spatial units to integer grid units."""
        if self._is_scalar_number(self.sigma):
            self.sigma = int(self.sigma / self.arena.spatial_resolution)
        elif isinstance(self.sigma, (list, tuple)):
            self.sigma = [int(s / self.arena.spatial_resolution)
                          for s in self.sigma]

    def _init_response_map(self):
        """ Initialize response_map """
        super()._init_response_map()
        # border padding is also included in the arena map
        self.response_map = self._generate_smcells()

    def _generate_smcells(self):
        """Generate spatially modulated non-grid/place cell response fields.

        Returns:
            np.ndarray: Response fields of shape (n_cells, *arena.dimensions).
        """
        arena_dims = self.arena.dimensions
        cells = self.rng.normal(0, 1, (self.n_cells, *arena_dims))

        # per-cell smoothing widths, in grid units
        if self.ssigma > 0:
            # asarray so that a per-cell sigma list also works (0 - list
            # would raise TypeError)
            mean = np.asarray(self.sigma, dtype=float)
            std = self.ssigma
            # NOTE(review): sigma was converted to grid units in _init_params
            # but ssigma is used as given -- confirm the intended units.
            a = (0 - mean) / std
            b = (100 - mean) / std
            sigma_distribution = truncnorm(a=a, b=b, loc=mean, scale=std)
            sigma_list = sigma_distribution.rvs(self.n_cells,
                                                random_state=self.rng)
        else:
            sigma_list = np.ones(self.n_cells) * self.sigma

        # filter each cell response field with a Gaussian filter
        for i in range(self.n_cells):
            cell = gaussian_filter(cells[i], sigma_list[i], mode='constant')
            if self.normalize:
                lo, hi = cell.min(), cell.max()
                span = hi - lo
                # a flat field would give 0/0 = NaN; map it to all zeros
                cell = (cell - lo) / span if span > 0 else np.zeros_like(cell)
            cells[i] = cell

        # Scale cells to have the desired MEAN magnitude (not max magnitude).
        # Without normalize, cell means can be near zero or negative, so the
        # scaling factors can become very large or flip the sign.
        if self.magnitude is not None:
            spatial_axes = tuple(range(1, cells.ndim))
            cell_means = cells.mean(axis=spatial_axes, keepdims=True)
            cells = cells * (self.magnitude / cell_means)
        return cells  # (n, *arena_dims)

    def get_specs(self):
        """Return base specs plus per-cell max/min/mean averaged over cells."""
        specs = super().get_specs()
        # reduce over every spatial axis, whatever the arena dimensionality
        spatial_axes = tuple(range(1, self.response_map.ndim))
        specs['cell_max_avg'] = self.response_map.max(axis=spatial_axes).mean()
        specs['cell_min_avg'] = self.response_map.min(axis=spatial_axes).mean()
        specs['cell_mean_avg'] = self.response_map.mean(axis=spatial_axes).mean()
        return specs