r"""
Utility functions for decoding sensory responses to spatial trajectories.
This module contains algorithms for converting high-dimensional sensory responses
(e.g., from place cells, grid cells) back to spatial coordinates using various
optimization techniques including nearest neighbor search, spatial interpolation,
and approximate search methods.

Mathematical Background
-----------------------
Given a sensory response vector :math:`\mathbf{r} \in \mathbb{R}^d` and response
maps :math:`\mathbf{M} \in \mathbb{R}^{d \times H \times W}` (one :math:`H \times W`
map per response dimension, i.e. the ``(D, H, W)`` layout the decoders below
expect), we seek the coordinates :math:`(y, x)` such that:

.. math::

    (y, x) = \arg\min_{(i,j)} \|\mathbf{r} - \mathbf{M}[:, i, j]\|_2^2

where :math:`\|\cdot\|_2` denotes the Euclidean norm.
Authors: RatatouGym Development Team
"""
import numpy as np
import torch
from typing import Union, Optional, Tuple, Any
from sklearn.neighbors import KDTree
from sklearn.cluster import KMeans
import faiss
from rtgym.dataclass import AgentState, Trajectory
def decode_response_euclidean(
    response: np.ndarray,
    res_maps: np.ndarray,
    chunk_size: int = 256,
) -> Tuple[np.ndarray, bool]:
    r"""
    Decode sensory responses using brute-force Euclidean distance computation.

    This is the reference implementation. It computes the exact nearest
    neighbor by evaluating the squared Euclidean distance between each query
    response and every template response in the response maps.

    Mathematical Formulation
    ------------------------
    For a response vector :math:`\mathbf{r} \in \mathbb{R}^D` and response maps
    :math:`\mathbf{M} \in \mathbb{R}^{D \times H \times W}`, compute

    .. math::

        d^2(\mathbf{r}, \mathbf{M}[:, i, j]) = \sum_{k=1}^{D} (r_k - M[k, i, j])^2

    and return

    .. math::

        \arg\min_{(i,j)} d^2(\mathbf{r}, \mathbf{M}[:, i, j])

    Parameters
    ----------
    response : np.ndarray
        Sensory response array of shape:

        - ``(B, T, D)`` for trajectory decoding
        - ``(B, D)`` for single state decoding

        where ``B`` = batch size, ``T`` = time steps, ``D`` = feature dimensions.
    res_maps : np.ndarray
        Response template maps of shape ``(D, H, W)`` where ``H`` = arena height
        and ``W`` = arena width. A position whose template contains NaN gets an
        infinite distance and is never selected, unless every position does.
        In that case ``argmin`` falls back to ``(0, 0)``.
    chunk_size : int, default=256
        Number of query responses processed at once. This bounds the size of
        the intermediate ``(chunk_size, H*W, D)`` difference tensor.

    Returns
    -------
    tuple of (np.ndarray, bool)
        - **decoded_coordinates** : integer array of shape ``(B, T, 2)`` or
          ``(B, 2)`` holding ``(row, col)`` grid indices
        - **is_trajectory** : ``True`` if the input was 3-D (trajectory),
          ``False`` if it was 2-D (states)

    Raises
    ------
    ValueError
        If ``response`` is not 2-D or 3-D, if its last axis does not match the
        number of maps ``D``, or if ``chunk_size`` is not positive.

    Complexity
    ----------
    - **Time**: :math:`O(N \cdot H \cdot W \cdot D)`, where :math:`N = B \cdot T`
      (trajectory) or :math:`N = B` (states)
    - **Space**: :math:`O(C \cdot H \cdot W \cdot D)` for the intermediate
      difference tensor, where :math:`C` = ``chunk_size``

    Note
    ----
    This method is exact but can be expensive for large arenas or
    high-dimensional feature spaces. Consider the approximate decoders for
    better performance.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    response = np.asarray(response)
    res_maps = np.asarray(res_maps)
    if response.ndim not in (2, 3):
        raise ValueError(
            f"response must be 2-D (B, D) or 3-D (B, T, D), got shape {response.shape}"
        )

    n_cells, H, W = res_maps.shape
    if response.shape[-1] != n_cells:
        raise ValueError(
            f"response feature dimension {response.shape[-1]} does not match "
            f"the {n_cells} response maps"
        )

    # (D, H, W) -> (H*W, D): one row per spatial position in row-major (i, j) order.
    M = res_maps.reshape(n_cells, -1).T
    # NaN -> +inf pushes positions with a missing template value infinitely far away.
    M = np.nan_to_num(M, nan=np.inf)

    # coords[i * W + j] == (i, j), matching the row order of M.
    coords = np.stack(
        np.meshgrid(np.arange(H), np.arange(W), indexing="ij"), axis=-1
    ).reshape(-1, 2)

    is_trajectory = response.ndim == 3
    r_flat = response.reshape(-1, n_cells)  # (N, D)

    # Chunk over queries so the (chunk, H*W, D) broadcast stays bounded. Each
    # row's distances are computed exactly as in a single unchunked pass.
    n_queries = r_flat.shape[0]
    min_indices = np.empty(n_queries, dtype=np.intp)
    for start in range(0, n_queries, chunk_size):
        stop = start + chunk_size
        chunk = r_flat[start:stop]
        dists = np.sum((chunk[:, None, :] - M[None, :, :]) ** 2, axis=-1)
        min_indices[start:stop] = np.argmin(dists, axis=-1)

    pred_coords = coords[min_indices]
    if is_trajectory:
        pred_coords = pred_coords.reshape(response.shape[0], response.shape[1], 2)
    return pred_coords, is_trajectory
def decode_response_kdtree(
    response: np.ndarray,
    res_maps: np.ndarray
) -> Tuple[np.ndarray, bool]:
    r"""
    Decode sensory responses using a K-d tree for accelerated nearest neighbor search.

    The tree is built once over the template responses and then queried for
    every response vector. K-d trees are most effective when the feature
    dimension ``D`` is small. Their advantage over brute force shrinks as
    ``D`` grows (see Limitations).

    Algorithm
    ---------
    1. Flatten the response maps to one feature row per spatial position.
    2. Drop positions whose template contains NaN/inf. sklearn's ``KDTree``
       rejects non-finite data, and the brute-force decoder never selects
       such positions either.
    3. Build a K-d tree on the remaining rows and query it for each response.
    4. Map tree indices back to spatial ``(row, col)`` coordinates.

    Parameters
    ----------
    response : np.ndarray
        Sensory response array of shape ``(B, T, D)`` (trajectory) or
        ``(B, D)`` (states). See :func:`decode_response_euclidean`.
    res_maps : np.ndarray
        Response template maps of shape ``(D, H, W)``. See
        :func:`decode_response_euclidean`.

    Returns
    -------
    tuple of (np.ndarray, bool)
        Same format as :func:`decode_response_euclidean`.

    Raises
    ------
    ValueError
        If ``response`` is not 2-D or 3-D, if its last axis does not match
        ``D``, or if no position has a fully finite template.

    Complexity
    ----------
    - **Tree construction**: roughly :math:`O(D \cdot H W \log(H W))`
    - **Query**: roughly :math:`O(\log(H W))` per response in low dimensions,
      degrading towards :math:`O(H W)` as ``D`` grows
    - **Space**: :math:`O(H \cdot W \cdot D)`

    Advantages
    ----------
    - Much faster than brute force for large arenas with low-dimensional features
    - Exact nearest neighbor results

    Limitations
    -----------
    - Performance degrades in high dimensions (curse of dimensionality)
    - Tree construction overhead dominates for small query sets
    """
    response = np.asarray(response)
    res_maps = np.asarray(res_maps)
    if response.ndim not in (2, 3):
        raise ValueError(
            f"response must be 2-D (B, D) or 3-D (B, T, D), got shape {response.shape}"
        )

    n_cells, H, W = res_maps.shape
    if response.shape[-1] != n_cells:
        raise ValueError(
            f"response feature dimension {response.shape[-1]} does not match "
            f"the {n_cells} response maps"
        )

    # (D, H, W) -> (H*W, D); coords[i * W + j] == (i, j) matches the row order.
    M = res_maps.reshape(n_cells, -1).T
    coords = np.stack(
        np.meshgrid(np.arange(H), np.arange(W), indexing="ij"), axis=-1
    ).reshape(-1, 2)

    # KDTree's input validation raises on NaN/inf, so non-finite positions are
    # left out of the tree and tree indices are translated back via valid_rows.
    valid_rows = np.flatnonzero(np.isfinite(M).all(axis=1))
    if valid_rows.size == 0:
        raise ValueError("res_maps contains no position with a fully finite template")

    is_trajectory = response.ndim == 3
    r_flat = response.reshape(-1, n_cells)

    if r_flat.shape[0] == 0:
        # sklearn rejects empty query arrays; nothing to decode.
        min_indices = np.empty(0, dtype=np.intp)
    else:
        kdtree = KDTree(M[valid_rows])
        _, neighbor = kdtree.query(r_flat, k=1)  # neighbor: (N, 1) tree indices
        min_indices = valid_rows[neighbor[:, 0]]

    pred_coords = coords[min_indices]
    if is_trajectory:
        pred_coords = pred_coords.reshape(response.shape[0], response.shape[1], 2)
    return pred_coords, is_trajectory
def decode_response_torch(
    response: Union[np.ndarray, torch.Tensor],
    res_maps: Union[np.ndarray, torch.Tensor],
    device: Optional[Union[str, torch.device]] = None,
    chunk_size: int = 1024
) -> Tuple[np.ndarray, bool]:
    r"""
    Decode sensory responses using PyTorch for (optionally GPU-accelerated) computation.

    Distances are computed with dense matrix products. Memory usage is bounded
    by processing the queries in chunks.

    Mathematical Optimization
    -------------------------
    Instead of forming :math:`\mathbf{r} - \mathbf{m}` explicitly, the squared
    distance is expanded as

    .. math::

        \|\mathbf{r} - \mathbf{m}\|^2 = \|\mathbf{r}\|^2
            - 2\langle\mathbf{r},\mathbf{m}\rangle + \|\mathbf{m}\|^2

    :math:`\|\mathbf{r}\|^2` is the same for every candidate :math:`\mathbf{m}`,
    so it does not change the :math:`\arg\min` and is omitted:

    .. math::

        \text{score}(\mathbf{r}, \mathbf{m}) = \|\mathbf{m}\|^2
            - 2\langle\mathbf{r},\mathbf{m}\rangle

    Positions whose template contains NaN/inf are masked to a score of
    :math:`+\infty`, so they are never selected. If every position is masked,
    ``argmin`` falls back to ``(0, 0)``. This is the same behavior as
    :func:`decode_response_euclidean`. The masking is done explicitly because
    substituting ``inf`` into the expansion would give ``inf - inf = NaN``,
    which ``argmin`` would pick.

    Parameters
    ----------
    response : np.ndarray or torch.Tensor
        Sensory responses of shape ``(B, T, D)`` (trajectory) or ``(B, D)``.
    res_maps : np.ndarray or torch.Tensor
        Response template maps of shape ``(D, H, W)``.
    device : str, torch.device, or None
        Computation device (``'cpu'``, ``'cuda'``, or a ``torch.device``).
        If None, CUDA is used when available, otherwise the CPU.
    chunk_size : int, default=1024
        Number of responses processed at once. Larger values use more memory
        but may be faster.

    Returns
    -------
    tuple of (np.ndarray, bool)
        ``(decoded_coordinates, is_trajectory)``. The coordinates are a numpy
        integer array of shape ``(B, T, 2)`` or ``(B, 2)``.

    Raises
    ------
    ValueError
        If ``chunk_size`` is not positive, if ``response`` is not 2-D or 3-D,
        or if its last axis does not match ``D``.

    Notes
    -----
    - Both inputs are cast to their common floating dtype. Integer inputs are
      cast to the default torch dtype, since ``torch.mm`` requires matching
      dtypes.
    - Chunked processing keeps the distance matrix bounded.

    Memory Usage
    ------------
    Peak memory for the distance matrix is about
    ``chunk_size × H × W × itemsize`` bytes (4 for float32, 8 for float64).
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    else:
        device = torch.device(device)

    # as_tensor accepts numpy arrays and tensors alike and moves them to `device`.
    res_maps = torch.as_tensor(res_maps, device=device)
    response = torch.as_tensor(response, device=device)
    if response.dim() not in (2, 3):
        raise ValueError(
            f"response must be 2-D (B, D) or 3-D (B, T, D), got shape {tuple(response.shape)}"
        )

    # torch.mm needs both operands in the same (floating) dtype.
    dtype = torch.promote_types(res_maps.dtype, response.dtype)
    if not dtype.is_floating_point:
        dtype = torch.get_default_dtype()
    res_maps = res_maps.to(dtype)
    response = response.to(dtype)

    n_cells, H, W = res_maps.shape
    if response.shape[-1] != n_cells:
        raise ValueError(
            f"response feature dimension {response.shape[-1]} does not match "
            f"the {n_cells} response maps"
        )

    # coords[i * W + j] == (i, j), matching the row order of M below.
    y_coords, x_coords = torch.meshgrid(
        torch.arange(H, device=device),
        torch.arange(W, device=device),
        indexing="ij",
    )
    coords = torch.stack([y_coords, x_coords], dim=-1).reshape(-1, 2)

    # (D, H, W) -> (H*W, D). Rows with any non-finite value are zeroed so they
    # cannot poison the matmul with NaN, and are masked out of the argmin below.
    M = res_maps.reshape(n_cells, -1).t()
    invalid = ~torch.isfinite(M).all(dim=1)
    M = M.masked_fill(invalid[:, None], 0.0)
    m_norm = (M ** 2).sum(dim=1)  # ||m||^2, loop-invariant
    M_t = M.t()

    is_trajectory = response.dim() == 3
    r_flat = response.reshape(-1, n_cells)

    chunk_indices = []
    for start in range(0, r_flat.shape[0], chunk_size):
        chunk = r_flat[start:start + chunk_size]
        scores = m_norm - 2.0 * torch.mm(chunk, M_t)  # (chunk, H*W)
        scores = scores.masked_fill(invalid, float("inf"))
        chunk_indices.append(torch.argmin(scores, dim=1))

    if chunk_indices:
        min_indices = torch.cat(chunk_indices)
    else:
        # Empty input: torch.cat would fail on an empty list.
        min_indices = torch.empty(0, dtype=torch.long, device=device)

    pred_coords_np = coords[min_indices].cpu().numpy()
    if is_trajectory:
        pred_coords_np = pred_coords_np.reshape(response.shape[0], response.shape[1], 2)
    return pred_coords_np, is_trajectory
def decode_response_faiss(
    response: np.ndarray,
    res_maps: np.ndarray,
    n_clusters: int = 100,
    nprobe: int = 10
) -> Tuple[np.ndarray, bool]:
    r"""
    Decode sensory responses using Facebook AI Similarity Search (FAISS).

    Uses an Inverted File index (``IndexIVFFlat``) with exact L2 search inside
    the probed clusters. Results are approximate: a query's true nearest
    neighbor may lie in a cluster that was not probed.

    Algorithm
    ---------
    1. Drop positions whose template contains NaN/inf. This matches
       :func:`decode_response_euclidean`, which never selects such positions.
    2. Cluster the remaining template vectors into ``n_clusters`` groups with
       k-means and build an inverted index mapping clusters to their members.
    3. For each query, search the ``nprobe`` closest clusters exactly.
    4. If the probed clusters held no vectors (FAISS reports label ``-1``),
       repeat the search for that query over all clusters.

    Parameters
    ----------
    response : np.ndarray
        Sensory responses of shape ``(B, T, D)`` (trajectory) or ``(B, D)``.
    res_maps : np.ndarray
        Response template maps of shape ``(D, H, W)``.
    n_clusters : int, default=100
        Requested number of IVF clusters. It is capped at
        ``n_valid_positions // 39``, with a minimum of 1.

        - More clusters: faster search, potentially lower recall
        - Fewer clusters: slower search, higher recall
        - Common guidance: :math:`\sqrt{N}` to :math:`N / 39` for :math:`N` points
    nprobe : int, default=10
        Number of clusters probed per query, capped at the effective cluster
        count. Higher values mean better recall but slower search.

    Returns
    -------
    tuple of (np.ndarray, bool)
        ``(decoded_coordinates, is_trajectory)``, in the same format as
        :func:`decode_response_euclidean`.

    Raises
    ------
    ValueError
        If ``nprobe`` is not positive, if ``response`` is not 2-D or 3-D, if its
        last axis does not match ``D``, or if no position has a fully finite
        template.

    Performance Characteristics
    ---------------------------
    - **Index build**: dominated by k-means, roughly
      :math:`O(\text{iters} \cdot H W \cdot n_{\text{clusters}} \cdot D)`
    - **Query**: roughly :math:`O(D \cdot (n_{\text{clusters}}
      + \text{nprobe} \cdot H W / n_{\text{clusters}}))` per response

    Notes
    -----
    - Requires the FAISS library
    - Mostly beneficial for large template sets (many positions)
    """
    if nprobe <= 0:
        raise ValueError(f"nprobe must be positive, got {nprobe}")

    response = np.asarray(response)
    res_maps = np.asarray(res_maps)
    if response.ndim not in (2, 3):
        raise ValueError(
            f"response must be 2-D (B, D) or 3-D (B, T, D), got shape {response.shape}"
        )

    n_cells, H, W = res_maps.shape
    if response.shape[-1] != n_cells:
        raise ValueError(
            f"response feature dimension {response.shape[-1]} does not match "
            f"the {n_cells} response maps"
        )

    # (D, H, W) -> (H*W, D); coords[i * W + j] == (i, j) matches the row order.
    M = res_maps.reshape(n_cells, -1).T
    coords = np.stack(
        np.meshgrid(np.arange(H), np.arange(W), indexing="ij"), axis=-1
    ).reshape(-1, 2)

    # Exclude non-finite templates rather than zero-filling them. Zero vectors
    # would otherwise be decodable (e.g. for a quiet, all-zero response).
    valid_rows = np.flatnonzero(np.isfinite(M).all(axis=1))
    if valid_rows.size == 0:
        raise ValueError("res_maps contains no position with a fully finite template")
    # FAISS expects C-contiguous float32; M itself is a transposed view.
    M_valid = np.ascontiguousarray(M[valid_rows], dtype=np.float32)

    is_trajectory = response.ndim == 3
    r_flat = np.ascontiguousarray(response.reshape(-1, n_cells), dtype=np.float32)

    # FAISS k-means warns below 39 training points per centroid and needs at
    # least as many points as centroids.
    n_clusters = max(1, min(n_clusters, M_valid.shape[0] // 39))

    quantizer = faiss.IndexFlatL2(n_cells)
    index = faiss.IndexIVFFlat(quantizer, n_cells, n_clusters, faiss.METRIC_L2)
    index.train(M_valid)
    index.add(M_valid)
    index.nprobe = min(nprobe, n_clusters)

    if r_flat.shape[0] == 0:
        nearest = np.empty(0, dtype=np.int64)
    else:
        _, labels = index.search(r_flat, 1)  # labels: (N, 1)
        nearest = labels[:, 0].copy()
        # -1 means the probed clusters were empty. Probing every cluster
        # always finds a neighbor, because all vectors were added to the index.
        missing = nearest < 0
        if missing.any():
            index.nprobe = n_clusters
            _, retry = index.search(np.ascontiguousarray(r_flat[missing]), 1)
            nearest[missing] = retry[:, 0]

    pred_coords = coords[valid_rows[nearest]]
    if is_trajectory:
        pred_coords = pred_coords.reshape(response.shape[0], response.shape[1], 2)
    return pred_coords, is_trajectory
def decode_response_interpolation(
    response: np.ndarray,
    res_maps: np.ndarray,
    n_anchors: int = 1000,
    random_state: int = 42,
    n_neighbors: int = 5
) -> Tuple[np.ndarray, bool]:
    r"""
    Decode sensory responses using spatial interpolation between anchor points.

    The method assumes spatial continuity, i.e. that nearby locations have
    similar sensory responses. Instead of an exact nearest-neighbor search, it
    interpolates coordinates from several nearby anchor points.

    Algorithm
    ---------
    1. Drop positions whose template contains NaN/inf, consistent with
       :func:`decode_response_euclidean`.
    2. Cluster the remaining templates into ``n_anchors`` groups with k-means.
       Use the real position closest to each centroid as an anchor.
    3. For each query response, find the ``n_neighbors`` nearest anchors.
    4. Average the anchor coordinates with inverse-distance weights.
    5. Round the result to integers and clip it to the arena bounds.

    Mathematical Formulation
    ------------------------
    Given :math:`k` nearest anchors with coordinates
    :math:`\mathbf{c}_1,\ldots,\mathbf{c}_k` and feature-space distances
    :math:`d_1,\ldots,d_k`:

    .. math::

        w_i &= \frac{1}{d_i + \varepsilon}, \qquad \varepsilon = 10^{-10}

        \tilde{w}_i &= \frac{w_i}{\sum_{j=1}^k w_j}

        \hat{\mathbf{c}} &= \operatorname{round}\Big(\sum_{i=1}^k \tilde{w}_i \, \mathbf{c}_i\Big)

    Parameters
    ----------
    response : np.ndarray
        Sensory responses of shape ``(B, T, D)`` (trajectory) or ``(B, D)``.
    res_maps : np.ndarray
        Response template maps of shape ``(D, H, W)``.
    n_anchors : int, default=1000
        Number of anchor points, capped at the number of valid positions.

        - More anchors: better spatial resolution, slower computation
        - Fewer anchors: faster computation, lower spatial precision
    random_state : int, default=42
        Random seed for reproducible k-means anchor selection.
    n_neighbors : int, default=5
        Number of nearest anchors interpolated per query, capped at the number
        of anchors.

    Returns
    -------
    tuple of (np.ndarray, bool)
        ``(decoded_coordinates, is_trajectory)``. The coordinates are rounded to
        integers and clipped to ``[0, H-1] × [0, W-1]``.

    Raises
    ------
    ValueError
        If ``n_neighbors`` is not positive, if ``response`` is not 2-D or 3-D,
        if its last axis does not match ``D``, or if no position has a fully
        finite template.

    Advantages
    ----------
    - Weighted averaging smooths out noise
    - Queries only search the anchor set, not every position

    Limitations
    -----------
    - May not return the exact nearest neighbor
    - Requires tuning of ``n_anchors`` / ``n_neighbors``
    - Assumes local spatial smoothness in response patterns

    Complexity
    ----------
    - **Anchor selection** (k-means): roughly
      :math:`O(\text{n\_init} \cdot \text{iters} \cdot H W \cdot n_{\text{anchors}} \cdot D)`
    - **Query**: a K-d tree search over the anchors, typically close to
      :math:`O(\log n_{\text{anchors}})` per response in low dimensions
    """
    if n_neighbors <= 0:
        raise ValueError(f"n_neighbors must be positive, got {n_neighbors}")

    response = np.asarray(response)
    res_maps = np.asarray(res_maps)
    if response.ndim not in (2, 3):
        raise ValueError(
            f"response must be 2-D (B, D) or 3-D (B, T, D), got shape {response.shape}"
        )

    n_cells, H, W = res_maps.shape
    if response.shape[-1] != n_cells:
        raise ValueError(
            f"response feature dimension {response.shape[-1]} does not match "
            f"the {n_cells} response maps"
        )

    # (D, H, W) -> (H*W, D); coords[i * W + j] == (i, j) matches the row order.
    M = res_maps.reshape(n_cells, -1).T
    coords = np.stack(
        np.meshgrid(np.arange(H), np.arange(W), indexing="ij"), axis=-1
    ).reshape(-1, 2)

    # Exclude non-finite templates instead of zero-filling them, so those
    # positions can never become anchors.
    valid_rows = np.flatnonzero(np.isfinite(M).all(axis=1))
    if valid_rows.size == 0:
        raise ValueError("res_maps contains no position with a fully finite template")
    M_valid = M[valid_rows]
    coords_valid = coords[valid_rows]

    is_trajectory = response.ndim == 3
    r_flat = response.reshape(-1, n_cells)

    # Select anchors: for each k-means cluster, the real position nearest its centroid.
    n_anchors = min(n_anchors, M_valid.shape[0])
    kmeans = KMeans(n_clusters=n_anchors, random_state=random_state, n_init=10)
    kmeans.fit(M_valid)

    anchor_list = []
    for label, centroid in enumerate(kmeans.cluster_centers_):
        members = np.flatnonzero(kmeans.labels_ == label)
        if members.size == 0:  # guard against empty clusters
            continue
        sq_dist = np.sum((M_valid[members] - centroid) ** 2, axis=1)
        anchor_list.append(members[np.argmin(sq_dist)])
    anchor_indices = np.asarray(anchor_list, dtype=np.intp)

    anchor_features = M_valid[anchor_indices]
    anchor_coords = coords_valid[anchor_indices]

    if r_flat.shape[0] == 0:
        # sklearn rejects empty query arrays; nothing to decode.
        pred_coords = np.empty((0, 2), dtype=int)
    else:
        k = min(n_neighbors, anchor_indices.size)
        distances, neighbor_idx = KDTree(anchor_features).query(r_flat, k=k)

        # Inverse-distance weights; epsilon keeps an exact match from dividing by zero.
        weights = 1.0 / (distances + 1e-10)
        weights /= weights.sum(axis=1, keepdims=True)

        # (N, k, 1) * (N, k, 2) summed over k -> (N, 2) interpolated coordinates.
        interp = np.sum(weights[:, :, None] * anchor_coords[neighbor_idx], axis=1)

        pred_coords = np.round(interp).astype(int)
        pred_coords[:, 0] = np.clip(pred_coords[:, 0], 0, H - 1)
        pred_coords[:, 1] = np.clip(pred_coords[:, 1], 0, W - 1)

    if is_trajectory:
        pred_coords = pred_coords.reshape(response.shape[0], response.shape[1], 2)
    return pred_coords, is_trajectory
def create_dataclass_result(
    pred_coords: np.ndarray,
    is_trajectory: bool
) -> Union[Trajectory, AgentState]:
    r"""
    Wrap decoded coordinates in the matching rtgym dataclass.

    Parameters
    ----------
    pred_coords : np.ndarray
        Decoded coordinate array, e.g. the first element of the tuple returned
        by the ``decode_response_*`` functions in this module.
    is_trajectory : bool
        If truthy, a ``Trajectory`` is built; otherwise an ``AgentState``.

    Returns
    -------
    Trajectory or AgentState
        Instance constructed with ``coord=pred_coords``.
    """
    result_cls = Trajectory if is_trajectory else AgentState
    return result_cls(coord=pred_coords)