Source code for rtgym.utils.visualization.plot_ratemap

import numpy as np
import matplotlib.pyplot as plt


def plot_ratemaps(ratemaps, n_rows, n_cols, cmap="jet", per_image_height=2, arena_map=None):
    """
    Plot a grid of rate maps, sizing the figure so each panel keeps the
    aspect ratio of a single map.

    Args:
        ratemaps: array-like of shape (n_maps, h, w). Axis 1 is drawn as image
            rows (vertical) and axis 2 as image columns (horizontal), which is
            how ``imshow`` lays out a 2-D array. Objects exposing ``cpu()`` or
            ``numpy()`` (e.g. tensors) are converted to numpy arrays first.
        n_rows: number of rows in the subplot grid.
        n_cols: number of columns in the subplot grid.
        cmap: colormap name (or Colormap instance) used for the images.
        per_image_height: height of each individual subplot in inches
            (default: 2). The width is derived from the map aspect ratio.
        arena_map: optional mask broadcastable to a single map, where 1 is
            free space and 0 is wall (the ``inv_arena_map`` convention noted
            in the original code). Wall cells are drawn in grey.

    Returns:
        (fig, axes): the created figure and a flat 1-D array holding all
        ``n_rows * n_cols`` axes in row-major order.
    """
    import copy

    # Convert tensor-like inputs to numpy. Detach first so tensors that track
    # gradients can be converted as well.
    if hasattr(ratemaps, "detach"):
        ratemaps = ratemaps.detach()
    if hasattr(ratemaps, "cpu"):
        ratemaps = ratemaps.cpu().numpy()
    elif hasattr(ratemaps, "numpy"):
        ratemaps = ratemaps.numpy()
    ratemaps = np.asarray(ratemaps)

    n_maps, h, w = ratemaps.shape
    total_plots = n_rows * n_cols

    if total_plots < n_maps:
        print(f"Warning: The grid ({n_rows}x{n_cols}) is smaller than the number of rate maps ({n_maps}). Some maps won't be displayed.")

    # Aspect ratio of a single image (width / height) determines the panel
    # width for the requested panel height (both in inches).
    image_aspect = w / h
    per_image_width = per_image_height * image_aspect
    total_width = n_cols * per_image_width
    total_height = n_rows * per_image_height

    # squeeze=False always yields a 2-D array of axes, so flattening also
    # works for a 1x1 grid (where subplots would otherwise return a bare Axes).
    fig, axes = plt.subplots(
        n_rows, n_cols, figsize=(total_width, total_height), squeeze=False
    )
    axes = axes.flatten()

    # Loop-invariant setup for wall masking: NaN cells are rendered with the
    # colormap's "bad" color. Work on a copy so neither the registered
    # colormap nor a caller-supplied instance is mutated.
    image_cmap = cmap
    wall_mask = None
    if arena_map is not None:
        image_cmap = copy.copy(plt.get_cmap(cmap))
        image_cmap.set_bad(color="grey")
        wall_mask = arena_map == 0

    for i, ax in enumerate(axes):
        if i < n_maps:
            ratemap = ratemaps[i]
            if wall_mask is not None:
                # Walls (arena_map == 0) become NaN and are shown in grey.
                ratemap = np.where(wall_mask, np.nan, ratemap)
            ax.imshow(ratemap, cmap=image_cmap)
        else:
            ax.axis("off")  # Hide axes without data

        # Remove ticks and spines so only the images remain visible.
        ax.set_xticks([])
        ax.set_yticks([])
        for spine in ax.spines.values():
            spine.set_visible(False)

    # Adjust layout on this figure explicitly to prevent overlap.
    fig.tight_layout()
    return fig, axes