encoding.trainer

Abstract trainer that accepts components as dependencies.

class encoding.trainer.AbstractTrainer(assembly: Any, feature_extractors: List[Any], downsampler: Any, model: Any, fir_delays: List[int], trimming_config: Dict, use_train_test_split: bool = False, layer_idx: int = 9, lookback: int = 256, dataset_type: str = 'unknown', logger_backend: str = 'wandb', wandb_project_name: str = 'abstract-trainer', results_dir: str = 'results', run_name: str | None = None, downsample_config: Dict | None = None, story_selection: List[str] | None = None)

A completely abstract trainer that accepts all components as dependencies.

This trainer doesn’t know about datasets, assemblies, or specific feature types. It just orchestrates the pipeline: extract → downsample → FIR → trim → train.
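The orchestration order above can be sketched with stand-in components; every callable below is a placeholder, not one of the library's actual classes:

```python
# Minimal sketch of the pipeline the trainer orchestrates:
# extract -> downsample -> FIR -> trim -> train.
# All components here are duck-typed stubs for illustration.
import numpy as np

def run_pipeline(raw, extractor, downsampler, apply_fir, trim, model):
    features = {story: extractor(x) for story, x in raw.items()}       # extract
    features = {story: downsampler(x) for story, x in features.items()}  # downsample
    features = {story: apply_fir(x) for story, x in features.items()}  # FIR delays
    features = {story: trim(x) for story, x in features.items()}       # trim
    X = np.vstack(list(features.values()))                             # stack stories
    return model(X)                                                    # train

out = run_pipeline(
    {"story1": np.ones((10, 4))},
    extractor=lambda x: x * 2.0,            # stand-in feature extraction
    downsampler=lambda x: x[::2],           # stand-in downsampling (10 -> 5 rows)
    apply_fir=lambda x: np.hstack([x, x]),  # stand-in FIR expansion (4 -> 8 cols)
    trim=lambda x: x[1:-1],                 # stand-in trimming (5 -> 3 rows)
    model=lambda X: X.shape,                # stand-in model fit
)
print(out)  # (3, 8)
```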

__init__(assembly: Any, feature_extractors: List[Any], downsampler: Any, model: Any, fir_delays: List[int], trimming_config: Dict, use_train_test_split: bool = False, layer_idx: int = 9, lookback: int = 256, dataset_type: str = 'unknown', logger_backend: str = 'wandb', wandb_project_name: str = 'abstract-trainer', results_dir: str = 'results', run_name: str | None = None, downsample_config: Dict | None = None, story_selection: List[str] | None = None)

Initialize with all components as dependencies.

Parameters:
  • assembly – Data assembly (has .stories, .get_brain_data(), etc.)

  • feature_extractors – List of feature extraction components

  • downsampler – Downsampling component

  • model – Model with fit_predict() method

  • fir_delays – List of FIR delays to apply

  • trimming_config – Dict specifying how to trim data

  • use_train_test_split – Whether to use a train/test split instead of concatenating all stories

  • layer_idx – Layer index for feature extraction

  • lookback – Context lookback for feature extraction

  • dataset_type – Dataset type for caching

  • logger_backend – “wandb” or “tensorboard”

  • wandb_project_name – Project name for wandb

  • results_dir – Directory for results

  • run_name – Custom run name

  • downsample_config – Downsampling parameters

  • story_selection – Specific stories to process (None = all)
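As an illustration of trimming_config, one plausible shape is a dict of row counts to drop from each end of a story's time axis; the key names here are assumptions, not documented fields:

```python
# Hypothetical trimming_config layout and how such a config might be applied.
# "trim_start" / "trim_end" are illustrative key names.
import numpy as np

trimming_config = {"trim_start": 5, "trim_end": 5}

def trim(X: np.ndarray, cfg: dict) -> np.ndarray:
    # Drop the first cfg["trim_start"] and last cfg["trim_end"] time points.
    end = X.shape[0] - cfg["trim_end"]
    return X[cfg["trim_start"]:end]

print(trim(np.ones((60, 3)), trimming_config).shape)  # (50, 3)
```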

apply_fir_delays(features: Dict[str, numpy.ndarray]) → Dict[str, numpy.ndarray]

Apply FIR delays to features.
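A common way to apply FIR delays is to shift each story's feature matrix down by each delay (zero-padding the top) and stack the shifted copies horizontally, so a [T, F] matrix becomes [T, F * len(fir_delays)]. This is a sketch of that convention, not necessarily the library's exact implementation:

```python
# Sketch: build delayed copies of each feature matrix and concatenate them
# along the feature axis. Delay d shifts rows down by d, zero-padding the top.
import numpy as np

def apply_fir_delays(features: dict, fir_delays: list) -> dict:
    delayed = {}
    for story, X in features.items():
        shifted = []
        for d in fir_delays:
            Xd = np.zeros_like(X)
            if d == 0:
                Xd = X.copy()
            elif d > 0:
                Xd[d:] = X[:-d]   # shift forward in time
            else:
                Xd[:d] = X[-d:]   # negative delay: shift backward
            shifted.append(Xd)
        delayed[story] = np.hstack(shifted)
    return delayed

feats = {"story1": np.arange(12, dtype=float).reshape(6, 2)}
out = apply_fir_delays(feats, [1, 2, 3])
print(out["story1"].shape)  # (6, 6): 2 features x 3 delays
```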

extract_and_downsample_features() → Dict[str, numpy.ndarray]

Extract and downsample features for all stories.
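The downsampling step resamples features from their native rate to the target sampling grid. A minimal sketch using linear interpolation follows; the actual downsampler component may use a different method (e.g. Lanczos filtering), and the function name and signature here are assumptions:

```python
# Sketch: resample a [T, F] feature matrix from old_times to new_times
# by linear interpolation, one feature column at a time.
import numpy as np

def downsample(X: np.ndarray, old_times: np.ndarray, new_times: np.ndarray) -> np.ndarray:
    return np.stack(
        [np.interp(new_times, old_times, X[:, j]) for j in range(X.shape[1])],
        axis=1,
    )

X = np.arange(20.0).reshape(10, 2)
old_t = np.linspace(0.0, 9.0, 10)   # original sample times
new_t = np.linspace(0.0, 9.0, 5)    # coarser target grid
print(downsample(X, old_t, new_t).shape)  # (5, 2)
```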

log_metrics(metrics: Dict)

Log metrics to the configured backend.

save_model(weights, best_alphas, metrics, model_kwargs)

Save model results.

setup_logger(backend: str, project_name: str, results_dir: str, run_name: str | None)

Set up the experiment logger.

structure_data(features: Dict[str, numpy.ndarray]) → Dict[str, numpy.ndarray]

Structure data according to training paradigm.
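To illustrate the two paradigms selected by use_train_test_split, here is a sketch: hold some stories out as a test set, or concatenate everything along the time axis. The story ordering, output keys, and held-out count are assumptions:

```python
# Sketch of the two structuring modes implied by use_train_test_split.
# Keys and the held-out story count are illustrative, not documented behavior.
import numpy as np

def structure_data(features: dict, use_train_test_split: bool, n_test_stories: int = 1) -> dict:
    stories = sorted(features)
    if use_train_test_split:
        test, train = stories[:n_test_stories], stories[n_test_stories:]
        return {
            "X_train": np.vstack([features[s] for s in train]),
            "X_test": np.vstack([features[s] for s in test]),
        }
    # Concatenation mode: stack every story along the time axis.
    return {"X": np.vstack([features[s] for s in stories])}

feats = {"a": np.ones((4, 3)), "b": np.zeros((6, 3))}
print(structure_data(feats, False)["X"].shape)       # (10, 3)
print(structure_data(feats, True)["X_train"].shape)  # (6, 3)
```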

train(**model_kwargs) → Dict[str, Any]

Run the complete training pipeline.