# Source code for rtgym.rtgym

"""
RatatouGym is a Python package that provides a simple interface for generating sensory responses
from a virtual agent in a virtual environment.

This module is documented more extensively than the rest of the package because it serves as
the package's main API.
"""

import os
import numpy as np
from rtgym.agent import Agent
from rtgym.arena import Arena
import torch
import matplotlib.pyplot as plt
from IPython.display import HTML
import matplotlib.animation as animation


class RatatouGym:
    """
    Top-level object that integrates the environment, agent, behavior, and sensory streams.

    When initialized, the class creates an arena and an uninitialized agent. The arena is
    the environment the agent moves in, and the agent is the subject navigating the arena.
    After initialization, both the arena and the agent must be set up.

    For the arena, you can either choose from predefined shapes or provide a custom 2D
    NumPy array as the map. For the agent, you must configure sensory and behavioral
    parameters. Sensory is defined by a dictionary of the sensory cell types you want to
    simulate. Behavior is defined by parameters such as movement speed, how frequently the
    agent changes direction, and how it avoids obstacles.

    Args:
        temporal_resolution (float): Time resolution in milliseconds per timestep.
            E.g. 50 means 50 ms per timestep.
        spatial_resolution (float): Spatial resolution in units per pixel.
            E.g. 1 means 1 cm per pixel.
        **kwargs: Additional keyword arguments. Currently accepted but not used.

    Example:
        >>> from rtgym import RatatouGym
        >>> gym = RatatouGym(
        ...     temporal_resolution=50,
        ...     spatial_resolution=1,
        ... )
    """

    def __init__(self, temporal_resolution, spatial_resolution, **kwargs):
        self.temporal_resolution = temporal_resolution
        self.spatial_resolution = spatial_resolution
        self.arena = Arena(self)
        self.agent = Agent(self)
        # Keep the agent in sync whenever the arena layout changes.
        self.arena.subscribe(self._on_arena_change)

    # =========================================================================
    # Basic get methods without trial generation
    # =========================================================================
    @property
    def t_res(self):
        """
        Get the temporal resolution of the gym.

        The temporal resolution controls how continuous time is discretized into timesteps.

        Returns:
            float: Temporal resolution in milliseconds per timestep.
        """
        return self.temporal_resolution

    @property
    def s_res(self):
        """
        Get the spatial resolution of the gym.

        The spatial resolution controls how the arena is discretized into pixels.

        Returns:
            float: Spatial resolution in cm per pixel.
        """
        return self.spatial_resolution

    def to_ts(self, t):
        """
        Convert time in seconds to a number of timesteps.

        The result is rounded to the nearest integer using ``np.round`` / ``torch.round``,
        which round exact halves to the nearest even value.

        Args:
            t (int, float, NumPy scalar, np.ndarray, or torch.Tensor): Time value(s) in seconds.

        Returns:
            int, np.ndarray, or torch.Tensor: Corresponding number of timestep(s). Arrays are
            cast with ``astype(int)``; tensors are cast to ``torch.int32``.

        Raises:
            TypeError: If the input type is not supported.

        Example:
            >>> gym = RatatouGym(temporal_resolution=50, spatial_resolution=1)
            >>> gym.to_ts(1)  # 1 second = 20 timesteps at 50 ms per timestep
            20
            >>> gym.to_ts(np.array([1, 2, 3]))
            array([20, 40, 60])
            >>> gym.to_ts(torch.tensor([1, 2, 3]))
            tensor([20, 40, 60], dtype=torch.int32)
        """
        # NumPy scalars (e.g. np.int64) are accepted alongside Python numbers.
        if isinstance(t, (int, float, np.integer, np.floating)):
            return int(np.round(t * 1e3 / self.t_res))
        if isinstance(t, np.ndarray):
            return np.round(t * 1e3 / self.t_res).astype(int)
        if isinstance(t, torch.Tensor):
            return torch.round(t * 1e3 / self.t_res).int()
        raise TypeError("Unsupported type for t")

    def to_sec(self, ts):
        """
        Convert a number of timesteps to time in seconds.

        Args:
            ts (int, float, NumPy scalar, np.ndarray, or torch.Tensor): Timestep value(s).

        Returns:
            float, NumPy scalar, np.ndarray, or torch.Tensor: Corresponding time(s) in seconds.

        Raises:
            TypeError: If the input type is not supported.

        Example:
            >>> gym = RatatouGym(temporal_resolution=50, spatial_resolution=1)
            >>> gym.to_sec(20)  # 20 timesteps at 50 ms per timestep = 1 second
            1.0
            >>> gym.to_sec(np.array([20, 40, 60]))
            array([1., 2., 3.])
            >>> gym.to_sec(torch.tensor([20, 40, 60]))
            tensor([1., 2., 3.])
        """
        supported = (int, float, np.integer, np.floating, np.ndarray, torch.Tensor)
        if not isinstance(ts, supported):
            raise TypeError("Unsupported type for ts")
        # The same arithmetic works for every supported type.
        return ts * self.t_res / 1e3

    def to_coord(self, pos):
        """
        Convert positions to (fractional) arena grid coordinates.

        Positions are divided by the spatial resolution (cm per pixel), which turns
        positions in cm into positions in pixels. No rounding is applied.

        Args:
            pos (array-like): Position value(s) in the continuous space of the arena.
                Lists and tuples are converted to NumPy arrays first.

        Returns:
            np.ndarray: ``pos / s_res`` as a NumPy array.
        """
        if isinstance(pos, (list, tuple)):
            # Plain Python sequences do not support division by a scalar.
            pos = np.asarray(pos)
        return np.array(pos / self.s_res)

    @property
    def arena_map(self):
        """
        Get the arena map.

        Returns:
            np.ndarray: Arena map of shape (H, W) where 0 represents free space and
            1 represents walls.
        """
        return self.arena.arena_map

    @property
    def inv_arena_map(self):
        """
        Get the inverted arena map, convenient for visualization.

        Returns:
            np.ndarray: Inverted arena map of shape (H, W) where 1 represents free space
            and 0 represents walls.
        """
        return self.arena.inv_arena_map

    def random_pos(self, n):
        """
        Generate random positions in free space of the arena.

        Delegates to ``self.arena.generate_random_pos``.

        Args:
            n (int): Number of random positions to generate. Must be a positive integer.

        Returns:
            np.ndarray: Array of shape (n, 2) containing random positions.
        """
        return self.arena.generate_random_pos(n)

    # =========================================================================
    # Set methods to initialize the arena and agent
    # =========================================================================
    def _on_arena_change(self):
        """
        Forward an arena-change notification to the agent.

        Registered with ``self.arena.subscribe`` in ``__init__``. It is an internal
        callback and should not be called directly by users.
        """
        self.agent._on_arena_change()

    def load_arena(self, path):
        """
        Load an arena configuration from a file.

        The loaded arena replaces the current arena configuration. The agent is
        notified of the change through the arena's subscription callback.

        Args:
            path (str): Path to a previously saved arena file (``.npz``, as written by
                ``save_arena``).

        Raises:
            FileNotFoundError: If the specified path does not exist.
            ValueError: If the file format is incompatible or corrupted.

        Example:
            >>> gym = RatatouGym(temporal_resolution=50.0, spatial_resolution=1.0)
            >>> gym.load_arena('saved_environments/maze_1.npz')
        """
        self.arena.load(path)

    def save_arena(self, path):
        """
        Save the current arena configuration to a file.

        The saved arena can later be restored with ``load_arena``.

        Args:
            path (str): Destination path. Must have the ``.npz`` extension.

        Raises:
            AssertionError: If the file extension is not ``.npz``.
            PermissionError: If the specified path is not writable.
            OSError: If there are issues with file system operations.

        Example:
            >>> gym = RatatouGym(temporal_resolution=50.0, spatial_resolution=1.0)
            >>> gym.init_arena_map(shape='square', width=100, height=100)
            >>> gym.save_arena('my_environments/custom_maze.npz')
        """
        # Explicit raise instead of `assert`, so the check survives `python -O`.
        # AssertionError is kept because it is the documented exception type.
        if os.path.splitext(path)[1] != '.npz':
            raise AssertionError("file extension must be .npz")
        self.arena.save(path)

    def init_arena_map(self, **kwargs):
        """
        Initialize the arena map with a predefined geometric shape.

        Forwards all keyword arguments to ``self.arena.init_arena_map``. Each arena shape
        accepts its own parameters; see the ``rtgym.arena.arena_shapes`` module.

        Args:
            **kwargs: Shape-specific parameters for the selected arena geometry.

        Example:
            >>> gym = RatatouGym(temporal_resolution=50.0, spatial_resolution=1.0)
            >>> gym.init_arena_map(shape='rectangle', dimensions=[100, 100])
        """
        self.arena.init_arena_map(**kwargs)

    def set_arena_map(self, arena_map):
        """
        Set the arena map from a custom binary layout.

        This allows arbitrary environments beyond the predefined shapes, such as mazes
        or irregular boundaries.

        Args:
            arena_map (np.ndarray): 2D array of shape (H, W) where 0 represents free
                space and 1 represents walls (typically bool or int dtype).

        Raises:
            ValueError: If arena_map is not a 2D array or contains invalid values.
            TypeError: If arena_map is not a numpy array or compatible type.

        Example:
            >>> import numpy as np
            >>> custom_map = np.zeros((50, 50))  # 50x50 open field
            >>> custom_map[0, :] = 1   # top wall
            >>> custom_map[-1, :] = 1  # bottom wall
            >>> custom_map[:, 0] = 1   # left wall
            >>> custom_map[:, -1] = 1  # right wall
            >>> gym.set_arena_map(custom_map)
        """
        self.arena.set_arena_map(arena_map=arena_map)

    # ======================================================================================
    # Sensory methods
    # ======================================================================================
    def set_sensory_manually(self, sens_type, sensory):
        """
        Deprecated. Emits a warning and otherwise does nothing.

        Implement a separate sensory class instead.

        Args:
            sens_type: Ignored.
            sensory: Ignored.
        """
        import warnings
        # NOTE: the former delegation to `self.agent.set_sensory_manually(sens_type, sensory)`
        # is intentionally disabled; this method only warns.
        warnings.warn("DeprecationWarning: The method 'set_sensory_manually' is deprecated. "
                      "Potential implementation issues might cause unintended behavior. "
                      "It is recommended to implement a separate sensory class instead.",
                      stacklevel=2)

    # ======================================================================================
    # Visualization methods
    # ======================================================================================
    @staticmethod
    def _compute_plot_dimensions(arena_map, fig_height):
        """Return (width, height) in inches, with width scaled by the map's columns/rows ratio."""
        aspect_ratio = arena_map.shape[1] / arena_map.shape[0]
        fig_width = fig_height * aspect_ratio
        return fig_width, fig_height

    def vis_gif(self, traj, plot_w, plot_h, return_format='anim', stride=10, interval=50, max_frames=100):
        """
        Create an animated visualization of a trajectory and its heading direction.

        The path of the first trajectory in the batch is drawn as a red line that grows
        frame by frame. A blue arrow at the current tip shows the heading direction.

        Args:
            traj: Trajectory object exposing ``int_coord``, ``float_coord`` and ``hd``
                arrays indexed as (batch, timesteps, dims).
            plot_w (float): Ignored. The width is recomputed from ``plot_h`` and the
                arena's aspect ratio.
            plot_h (float): Height of the plot in inches.
            return_format (str): 'anim' returns the FuncAnimation object. 'html' returns
                an ``IPython.display.HTML`` object for notebook display and closes the
                source figure.
            stride (int): Approximate number of timesteps per frame. Values below 1 are
                treated as 1. Default is 10.
            interval (int): Delay between frames in milliseconds. Default is 50.
            max_frames (int): Upper bound on the number of frames. At least one frame is
                always produced. Default is 100.

        Returns:
            matplotlib.animation.FuncAnimation or IPython.display.HTML: Depends on
            ``return_format``.

        Raises:
            ValueError: If ``traj.int_coord`` or ``traj.hd`` has no timesteps.

        Example:
            >>> anim = gym.vis_gif(trajectory, plot_w=8, plot_h=6,
            ...                    stride=5, interval=100, max_frames=50)
            >>> anim.save('trajectory.gif', writer='pillow', fps=10)
        """
        if traj.int_coord.shape[1] == 0 or traj.hd.shape[1] == 0:
            raise ValueError("traj.int_coord or traj.hd has no data along the expected axis.")

        plot_w, plot_h = self._compute_plot_dimensions(self.arena.arena_map, plot_h)
        fig, ax = plt.subplots(figsize=(plot_w, plot_h))
        ax.imshow(self.inv_arena_map)
        ax.axis('off')
        line, = ax.plot([], [], 'r')
        arrow = ax.quiver(0, 0, 0.05, 0.05, angles='xy', scale_units='xy', scale=0.2, color='blue')

        traj_length = traj.int_coord.shape[1]
        # One frame per `stride` timesteps, capped at `max_frames`, never fewer than one
        # (guards against ZeroDivisionError / empty animations for degenerate arguments).
        n_frames = max(1, min(max_frames, int(np.ceil(traj_length / max(1, stride)))))
        # Frame timestep indices spread evenly from the first to the last timestep.
        frame_indices = np.linspace(0, traj_length - 1, n_frames, dtype=int)

        def init():
            line.set_data([], [])
            arrow.set_offsets([0, 0])
            arrow.set_UVC(0.05, 0.05)
            return line, arrow

        def animate(i):
            frame_idx = frame_indices[i]
            # Show the complete path up to (and including) the current timestep.
            line.set_data(traj.float_coord[0, :frame_idx + 1, 1],
                          traj.float_coord[0, :frame_idx + 1, 0])
            # Move the heading arrow to the tip of the path on every frame, including the
            # first, so it never lingers at the placeholder position (0, 0).
            tip_x, tip_y = traj.float_coord[0, frame_idx, 1], traj.float_coord[0, frame_idx, 0]
            angle = traj.hd[0, frame_idx, 0]
            arrow.set_offsets([tip_x, tip_y])
            arrow.set_UVC(np.cos(angle), np.sin(angle))
            return line, arrow

        anim = animation.FuncAnimation(fig, animate, init_func=init, frames=n_frames,
                                       interval=interval, blit=True)
        plt.tight_layout()
        if return_format == 'html':
            html = HTML(anim.to_jshtml())
            # The figure is no longer needed once rendered; don't leave it open.
            plt.close(fig)
            return html
        return anim

    def vis_traj(self, traj, cmap="viridis", linewidth=1, return_format=None, height=3, ax=None, vis_kwargs=None):
        """
        Visualize the trajectory of the agent in the arena.

        By default, draws a static plot of every trajectory in the batch, overlaid on the
        arena. With ``return_format`` 'anim' or 'html', it delegates to ``vis_gif``.

        Args:
            traj (rtgym.dataclass.Trajectory): Trajectory object whose ``float_coord`` is
                indexed as (batch, timesteps, dims).
            cmap (str, optional): Colormap for the arena background in the static plot.
                Defaults to "viridis".
            linewidth (float, optional): Width of the trajectory line. Defaults to 1.
            return_format (str, optional):
                - 'anim': return a matplotlib FuncAnimation object
                - 'html': return an HTML representation of the animation
                - None (or any other value): return a static plot as (fig, ax)
            height (float, optional): Figure height in inches. The width follows the
                arena's aspect ratio. Defaults to 3.
            ax (matplotlib.axes.Axes, optional): Existing axes for the static plot. When
                given, no new figure is created. Not used for animations.
            vis_kwargs (dict, optional): Extra keyword arguments forwarded to ``vis_gif``
                when animating (e.g. ``stride``, ``interval``, ``max_frames``).

        Returns:
            tuple or animation object:
                - return_format None: (fig, ax) tuple for the static plot
                - 'anim': matplotlib FuncAnimation object
                - 'html': IPython HTML object
        """
        plot_w, plot_h = self._compute_plot_dimensions(self.arena.arena_map, height)
        if return_format in ('anim', 'html'):
            # `vis_kwargs` defaults to None, and `**None` would raise TypeError.
            return self.vis_gif(traj, plot_w, plot_h, return_format, **(vis_kwargs or {}))

        if ax is None:
            fig, ax = plt.subplots(figsize=(plot_w, plot_h))
        else:
            fig = ax.figure
        ax.imshow(self.arena.inv_arena_map, cmap=cmap)
        ax.plot(traj.float_coord[:, :, 1].T, traj.float_coord[:, :, 0].T,
                color='red', linewidth=linewidth)
        ax.set_axis_off()
        return fig, ax