import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
from typing import Optional
class Masking:
    """
    Base class for all masking utilities.

    Generates random keep-masks for inputs of shape (batch_size, T, D),
    optionally smoothed along the time and feature axes so that masked
    regions form contiguous blocks rather than isolated points.
    """

    def __init__(
        self,
        m_max: float = 0.5,
        m_min: Optional[float] = 0.0,
        sigma_t: float = 0.0,
        sigma_d: float = 0.0,
        t_warmup: int = 0,
        device: torch.device = torch.device('cpu'),
    ):
        self.m_max = m_max
        self.m_min = m_min if m_min is not None else m_max
        self.sigma_t = sigma_t  # Temporal smoothing
        self.sigma_d = sigma_d  # Spatial (feature-axis) smoothing
        self.t_warmup = t_warmup  # Number of initial time steps left unmasked
        self.device = device
        assert self.m_min <= self.m_max, "m_min must be <= m_max"
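
    # A minimal usage sketch (hypothetical values):
    #
    #   masker = Masking(m_max=0.5, m_min=0.2, sigma_t=2.0)
    #   masked = masker(torch.randn(8, 100, 16))  # input of shape (batch_size, T, D)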
    def to_tensor(self, x):
        """Converts a numpy array or torch tensor to a float tensor on self.device."""
        if isinstance(x, torch.Tensor):
            return x.float().to(self.device)
        elif isinstance(x, np.ndarray):
            return torch.from_numpy(x).float().to(self.device)
        else:
            raise ValueError(f"Unsupported input type: {type(x)}")

    def __call__(self, x, mask_idx=None):
        return self.mask(x, mask_idx)
    def to(self, device: torch.device):
        self.device = device
        return self
    def set_m_max(self, m_max: float):
        self.m_max = m_max
        return self
    def set_m_min(self, m_min: float):
        self.m_min = m_min
        return self
    def gaussian_kernel_1d(self, kernel_size: int, sigma: float):
        """Creates a normalized 1D Gaussian kernel."""
        x = torch.arange(kernel_size) - kernel_size // 2
        gauss = torch.exp(-x.float() ** 2 / (2 * sigma ** 2))
        gauss /= gauss.sum()
        return gauss
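
    # Example: gaussian_kernel_1d(5, 1.0) returns a symmetric length-5 tensor
    # that sums to 1, approximately [0.054, 0.244, 0.403, 0.244, 0.054].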
    def apply_gaussian_blur(self, mask: torch.Tensor):
        """Smooths a (batch_size, T, D) mask along the time and/or feature axes."""
        # Temporal blur: 1D Gaussian convolution along the T axis
        if self.sigma_t > 0:
            # Kernel spans roughly +/-3 sigma; `| 1` forces an odd size, min 3
            k_size_t = max(int(6 * self.sigma_t) | 1, 3)
            kernel_t = self.gaussian_kernel_1d(k_size_t, self.sigma_t).to(mask.device)
            kernel_t = kernel_t.view(1, 1, -1, 1)
            pad_t = k_size_t // 2
            m = mask.unsqueeze(1)  # (B, 1, T, D)
            m = F.pad(m, (0, 0, pad_t, pad_t), mode='reflect')  # pad the T axis
            mask = F.conv2d(m, kernel_t, padding=0).squeeze(1)
        # Spatial blur: 1D Gaussian convolution along the D axis
        if self.sigma_d > 0:
            k_size_d = max(int(6 * self.sigma_d) | 1, 3)
            kernel_d = self.gaussian_kernel_1d(k_size_d, self.sigma_d).to(mask.device)
            kernel_d = kernel_d.view(1, 1, 1, -1)
            pad_d = k_size_d // 2
            m = mask.unsqueeze(1)  # (B, 1, T, D)
            m = F.pad(m, (pad_d, pad_d, 0, 0), mode='reflect')  # pad the D axis
            mask = F.conv2d(m, kernel_d, padding=0).squeeze(1)
        return mask
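
    # Shape sketch: with sigma_t=2.0 the temporal kernel size is
    # max(int(12) | 1, 3) = 13, and a (B, T, D) input keeps its shape because
    # reflect-padding by k // 2 on each side offsets the valid convolution.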
    def new_mask(self, x):
        """
        Generates a binary keep-mask (1 = keep, 0 = masked) for the input.

        Random noise is (optionally) blurred, then thresholded so that each
        sample masks a fraction of its entries drawn uniformly from
        [m_min, m_max].

        Args:
            x (torch.Tensor): Input tensor of shape (batch_size, T, D)

        Returns:
            torch.Tensor: Mask tensor of the same shape as x
        """
        if self.m_max == 0:
            return torch.ones_like(x, device=self.device)  # Nothing to mask
        # Random noise, optionally smoothed so masked regions are contiguous
        mask = torch.rand(x.shape, device=self.device)
        mask = self.apply_gaussian_blur(mask)  # Shape: (batch_size, T, D)
        # Draw one masking ratio per sample, uniform in [m_min, m_max]
        masking_ratios = torch.empty(x.shape[0], 1, device=mask.device).uniform_(self.m_min, self.m_max)
        # Sort all T*D noise values per sample; each sample's threshold is the
        # quantile of its noise corresponding to its masking ratio
        sorted_mask, _ = mask.view(x.shape[0], -1).sort(dim=1)
        max_size = sorted_mask.size(1) - 1
        q_indices = (masking_ratios * max_size).long().clamp(max=max_size)
        thresholds = sorted_mask.gather(dim=1, index=q_indices).unsqueeze(1)  # (B, 1, 1)
        mask = mask > thresholds  # Keep entries whose noise exceeds the threshold
        if self.t_warmup > 0:
            mask[:, :self.t_warmup, :] = 1  # Leave the first t_warmup time steps unmasked
        return mask.float()
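
    # Behavior sketch: with m_min = m_max = 0.3 and no blur, roughly 30% of
    # the T*D entries of each sample are zeroed, since the threshold is that
    # sample's 30th-percentile noise value (warmup steps excepted).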
    def mask(self, x, mask_idx=None):
        """Masks x; where mask_idx is False, entries are left unmasked."""
        x = self.to_tensor(x)    # Ensure float tensor on self.device
        mask = self.new_mask(x)  # Generate keep-mask
        if mask_idx is None:
            return x * mask
        else:
            mask[~mask_idx] = 1.0  # Restrict masking to positions where mask_idx is True
            return x * mask
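

if __name__ == "__main__":
    # A minimal smoke-test sketch (not part of the library API): generate a
    # smoothed mask for random data and visualize the first sample's mask.
    torch.manual_seed(0)
    masker = Masking(m_max=0.5, m_min=0.2, sigma_t=2.0, sigma_d=1.0, t_warmup=5)
    x = torch.randn(4, 100, 16)  # (batch_size, T, D)
    m = masker.new_mask(x)
    print(f"masked fraction per sample: {(1 - m).mean(dim=(1, 2))}")
    plt.imshow(m[0].T.cpu().numpy(), aspect='auto', cmap='gray')
    plt.xlabel('time step')
    plt.ylabel('feature dimension')
    plt.title('Keep-mask for sample 0 (white = kept)')
    plt.show()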