encoding.trainer
Abstract trainer that accepts components as dependencies.
- class encoding.trainer.AbstractTrainer(assembly: Any, feature_extractors: List[Any], downsampler: Any, model: Any, fir_delays: List[int], trimming_config: Dict, use_train_test_split: bool = False, layer_idx: int = 9, lookback: int = 256, dataset_type: str = 'unknown', logger_backend: str = 'wandb', wandb_project_name: str = 'abstract-trainer', results_dir: str = 'results', run_name: str | None = None, downsample_config: Dict | None = None, story_selection: List[str] | None = None)
A completely abstract trainer that accepts all components as dependencies.
This trainer doesn’t know about datasets, assemblies, or specific feature types. It just orchestrates the pipeline: extract → downsample → FIR → trim → train.
- __init__(assembly: Any, feature_extractors: List[Any], downsampler: Any, model: Any, fir_delays: List[int], trimming_config: Dict, use_train_test_split: bool = False, layer_idx: int = 9, lookback: int = 256, dataset_type: str = 'unknown', logger_backend: str = 'wandb', wandb_project_name: str = 'abstract-trainer', results_dir: str = 'results', run_name: str | None = None, downsample_config: Dict | None = None, story_selection: List[str] | None = None)
Initialize the trainer with all components supplied as dependencies; a usage sketch follows the parameter list below.
- Parameters:
assembly – Data assembly (has .stories, .get_brain_data(), etc.)
feature_extractors – List of feature extraction components
downsampler – Downsampling component
model – Model with fit_predict() method
fir_delays – List of FIR delays to apply
trimming_config – Dict specifying how to trim data
use_train_test_split – Whether to use a train/test split rather than concatenating all data
layer_idx – Layer index for feature extraction
lookback – Context lookback for feature extraction
dataset_type – Dataset type for caching
logger_backend – “wandb” or “tensorboard”
wandb_project_name – Project name for wandb
results_dir – Directory for results
run_name – Custom run name
downsample_config – Downsampling parameters
story_selection – Specific stories to process (None = all)
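For orientation, a minimal construction-and-run sketch. Every component object below, and the `trimming_config` keys, are hypothetical stand-ins rather than names from this package; any objects satisfying the documented interfaces (e.g. a model exposing `fit_predict()`, an assembly exposing `.stories` and `.get_brain_data()`) should work:

```python
from encoding.trainer import AbstractTrainer

# my_assembly, my_text_extractor, my_downsampler, and my_ridge_model are
# hypothetical objects implementing the interfaces documented above.
trainer = AbstractTrainer(
    assembly=my_assembly,
    feature_extractors=[my_text_extractor],
    downsampler=my_downsampler,
    model=my_ridge_model,                               # must expose fit_predict()
    fir_delays=[1, 2, 3, 4],                            # example delays, not a recommendation
    trimming_config={"trim_start": 10, "trim_end": 5},  # keys are assumptions
    use_train_test_split=True,
    run_name="demo-run",
)
results = trainer.train()
```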
- apply_fir_delays(features: Dict[str, numpy.ndarray]) → Dict[str, numpy.ndarray]
Apply FIR delays to features.
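FIR delays in encoding models are typically realized by concatenating time-shifted copies of each feature matrix, so the regression can learn a lagged (hemodynamic-style) response. The following is a minimal sketch of that idea, not necessarily this class's exact implementation:

```python
import numpy as np

def apply_fir_delays_sketch(features: dict[str, np.ndarray],
                            fir_delays: list[int]) -> dict[str, np.ndarray]:
    """Stack time-shifted copies of each (time, features) matrix column-wise.

    A delay of d makes row t carry the stimulus from time t - d.
    """
    delayed = {}
    for story, mat in features.items():
        shifted = []
        for d in fir_delays:
            s = np.zeros_like(mat)
            if d > 0:
                s[d:] = mat[:-d]      # pad the first d rows with zeros
            elif d < 0:
                s[:d] = mat[-d:]      # negative delay shifts backward
            else:
                s = mat.copy()
            shifted.append(s)
        delayed[story] = np.concatenate(shifted, axis=1)
    return delayed
```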
- extract_and_downsample_features() → Dict[str, numpy.ndarray]
Extract and downsample features for all stories.
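The per-story loop plausibly resembles the sketch below; the extractor and downsampler method names (`extract`, `downsample`) are assumptions about the component interfaces, not documented API:

```python
import numpy as np

def extract_and_downsample_sketch(self) -> dict[str, np.ndarray]:
    # Fall back to all stories in the assembly when no selection is given.
    stories = self.story_selection or self.assembly.stories
    features = {}
    for story in stories:
        # Run every extractor on the story and join outputs column-wise.
        parts = [
            extractor.extract(story, layer_idx=self.layer_idx, lookback=self.lookback)
            for extractor in self.feature_extractors
        ]
        combined = np.concatenate(parts, axis=1)
        # Downsample to the brain-data sampling rate.
        features[story] = self.downsampler.downsample(
            combined, **(self.downsample_config or {})
        )
    return features
```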
- log_metrics(metrics: Dict)
Log metrics to configured backend.
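Given the two documented backends, dispatch presumably resembles this sketch. `self.writer` and `self.global_step` are assumed attributes; `wandb.log` and `SummaryWriter.add_scalar` are the standard calls of those libraries:

```python
def log_metrics_sketch(self, metrics: dict) -> None:
    if self.logger_backend == "wandb":
        import wandb
        wandb.log(metrics)  # wandb accepts a whole metrics dict per step
    else:
        # TensorBoard logs one scalar per tag; self.writer is assumed to be a
        # torch.utils.tensorboard.SummaryWriter created in setup_logger().
        for name, value in metrics.items():
            self.writer.add_scalar(name, value, self.global_step)
```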
- save_model(weights, best_alphas, metrics, model_kwargs)
Save model results.
- setup_logger(backend: str, project_name: str, results_dir: str, run_name: str | None)
Setup experiment logger.
- structure_data(features: Dict[str, numpy.ndarray]) → Dict[str, numpy.ndarray]
Structure data according to the training paradigm (train/test split or full concatenation).
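A plausible sketch of the two paradigms controlled by `use_train_test_split`; the whole-story holdout heuristic here is an assumption:

```python
import numpy as np

def structure_data_sketch(features: dict[str, np.ndarray],
                          use_train_test_split: bool,
                          test_fraction: float = 0.2) -> dict[str, np.ndarray]:
    stories = sorted(features)
    if use_train_test_split:
        # Hold out whole stories so test samples never share a story with
        # training samples (this heuristic is assumed, not documented).
        n_test = max(1, int(len(stories) * test_fraction))
        return {
            "X_train": np.concatenate([features[s] for s in stories[:-n_test]]),
            "X_test": np.concatenate([features[s] for s in stories[-n_test:]]),
        }
    # Concatenation paradigm: one long design matrix over all stories.
    return {"X": np.concatenate([features[s] for s in stories])}
```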
- train(**model_kwargs) → Dict[str, Any]
Run the complete training pipeline.
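Putting the documented stages in order, `train()` plausibly orchestrates them like this; the `trim()` helper and the `fit_predict()` return shape are assumptions:

```python
def train_sketch(self, **model_kwargs) -> dict:
    # extract → downsample → FIR → trim → train, as documented above
    features = self.extract_and_downsample_features()
    features = self.apply_fir_delays(features)
    features = self.trim(features)  # hypothetical helper driven by trimming_config
    data = self.structure_data(features)
    # Return shape of fit_predict() is assumed for illustration.
    weights, best_alphas, metrics = self.model.fit_predict(data, **model_kwargs)
    self.log_metrics(metrics)
    self.save_model(weights, best_alphas, metrics, model_kwargs)
    return metrics
```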