Source code for encoding.features.FIR_expander

import numpy as np
from dataclasses import dataclass
from typing import Iterable, Optional


@dataclass
class FIR:
    """
    Finite Impulse Response (FIR) expander for creating delayed feature matrices.

    Each delay ``d`` produces a copy of the stimulus shifted ``d`` samples
    along axis 0 (time): row ``t`` of that copy holds ``stim[t - d]``. The
    copies are stacked column-wise, so an ``(nt, ndim)`` stimulus expands to
    ``(nt, ndim * len(delays))``.

    Usage options:
    - Static/class usage: FIR.make_delayed(stim, delays, circpad=False)
    - Instance usage: FIR(delays, circpad).expand(stim)
    """

    # Integer sample shifts; required for instance usage (expand, valid_length).
    # Single-pass iterables (e.g. generators) are materialized to a list.
    delays: Optional[Iterable[int]] = None
    # If True, shifted-out samples wrap around instead of being zero-filled.
    circpad: bool = False

    def __post_init__(self) -> None:
        # A generator/iterator has no len() and would be exhausted by the
        # first expand()/summary() call; materialize it once so every method
        # sees the same delays. Sized collections (list, tuple, range,
        # ndarray, ...) are kept as-is.
        from collections.abc import Collection

        if self.delays is not None and not isinstance(self.delays, Collection):
            self.delays = list(self.delays)

    def expand(self, stim: np.ndarray) -> np.ndarray:
        """Apply this instance's delays to ``stim`` (see ``make_delayed``).

        Raises:
            ValueError: if ``delays`` was not provided.
        """
        if self.delays is None:
            raise ValueError("delays must be provided for instance usage of FIR")
        return FIR.make_delayed(stim, self.delays, self.circpad)

    @staticmethod
    def make_delayed(
        stim: np.ndarray, delays: Iterable[int], circpad: bool = False
    ) -> np.ndarray:
        """Stack time-shifted copies of ``stim`` column-wise.

        Args:
            stim: 2-D array of shape ``(nt, ndim)``; axis 0 is time.
            delays: Integer sample shifts. For delay ``d``, row ``t`` of the
                corresponding copy holds ``stim[t - d]``.
            circpad: If True, use a circular shift (``d`` is effectively
                taken modulo ``nt``). If False, vacated rows are zero-filled.

        Returns:
            Array of shape ``(nt, ndim * len(delays))`` with ``stim``'s dtype;
            columns ``[i*ndim, (i+1)*ndim)`` hold the copy for ``delays[i]``.

        Raises:
            ValueError: if ``stim`` is not 2-D.
        """
        stim = np.asarray(stim)
        if stim.ndim != 2:
            raise ValueError(
                f"stim must be a 2-D (time, features) array, got shape {stim.shape}"
            )
        nt, ndim = stim.shape
        delays = list(delays)
        # Preallocate once (instead of hstack-ing per-delay temporaries) and
        # keep the input dtype; zero-init provides the padding when
        # circpad=False.
        out = np.zeros((nt, ndim * len(delays)), dtype=stim.dtype)
        for i, d in enumerate(delays):
            block = out[:, i * ndim:(i + 1) * ndim]  # view into ``out``
            if circpad:
                # np.roll reduces the shift modulo nt, so |d| >= nt is correct.
                block[:] = np.roll(stim, d, axis=0)
            elif d > 0:
                block[d:] = stim[:-d]
            elif d < 0:
                block[:d] = stim[-d:]
            else:
                block[:] = stim
        return out

    def n_delays(self) -> int:
        """Return the number of delays used (0 when ``delays`` is None)."""
        return len(self.delays) if self.delays is not None else 0

    def output_dim(self, input_dim: int) -> int:
        """Return the output dimensionality after FIR expansion."""
        return input_dim * self.n_delays()

    def valid_length(self, nt: int) -> int:
        """
        Number of time points that are not padded in any delayed copy.

        With circpad=True, always nt. Without circpad, a positive delay d
        zero-fills the first d rows and a negative delay zero-fills the last
        |d| rows, so the largest lag and the largest lead both shrink the
        valid span.

        Raises:
            ValueError: if ``delays`` was not provided.
        """
        if self.delays is None:
            raise ValueError("delays must be provided")
        if self.circpad:
            return nt
        max_lag = max((d for d in self.delays if d > 0), default=0)
        max_lead = max((-d for d in self.delays if d < 0), default=0)
        return max(0, nt - max_lag - max_lead)

    def summary(self, input_dim: Optional[int] = None, nt: Optional[int] = None) -> str:
        """Return a readable summary of the FIR configuration.

        ``input_dim`` adds the output dimensionality; ``nt`` adds the valid
        length (which requires ``delays`` to be set).
        """
        delays = None if self.delays is None else list(self.delays)
        msg = f"FIR(delays={delays}, circpad={self.circpad})"
        if input_dim is not None:
            msg += f"\n- Output dim: {self.output_dim(input_dim)}"
        if nt is not None:
            msg += f"\n- Valid length: {self.valid_length(nt)}"
        return msg