"""Factory for creating feature extractors (encoding.features.factory)."""

from typing import Dict, Any, Union, Optional, Tuple
import numpy as np
from datetime import datetime
from .base import BaseFeatureExtractor
from .language_model import LanguageModelFeatureExtractor
from .speech_model import SpeechFeatureExtractor
from .simple_features import WordRateFeatureExtractor
from .embeddings import StaticEmbeddingFeatureExtractor
from ..utils import ActivationCache, SpeechActivationCache


class FeatureExtractorFactory:
    """Factory class for creating feature extractors with caching support."""

    # Registry mapping modality names to extractor classes.
    # Extended at runtime via register_extractor().
    _extractors = {
        "language_model": LanguageModelFeatureExtractor,
        "speech": SpeechFeatureExtractor,
        "wordrate": WordRateFeatureExtractor,
        "embeddings": StaticEmbeddingFeatureExtractor,
    }

    @classmethod
    def create_extractor(
        cls,
        modality: str,
        model_name: str,
        config: Dict[str, Any],
        cache_dir: str = "cache",
    ) -> BaseFeatureExtractor:
        """Create a feature extractor based on modality and model name.

        Args:
            modality: The type of feature extractor
                ('language_model', 'speech', 'wordrate', 'embeddings')
            model_name: The specific model name (e.g., 'gpt2-small',
                'word2vec', 'openai/whisper-tiny')
            config: Configuration dictionary for the extractor. The caller's
                dict is not modified; a shallow copy is made before
                'model_name' is filled in.
            cache_dir: Directory for caching

        Returns:
            BaseFeatureExtractor: The appropriate feature extractor instance

        Raises:
            ValueError: If modality is not supported
        """
        if modality not in cls._extractors:
            raise ValueError(
                f"Unsupported modality '{modality}'. "
                f"Supported modalities: {list(cls._extractors.keys())}"
            )

        extractor_class = cls._extractors[modality]

        # Work on a shallow copy so the caller's config dict is not mutated.
        config = dict(config)
        if "model_name" not in config:
            config["model_name"] = model_name

        # TODO: Change later to use **config for all extractors. For now only
        # speech unpacks **config; ideally they should all take a config
        # object, and that config should be a class.
        if modality == "speech":
            extractor = extractor_class(**config)
        else:
            # language_model and the remaining modalities all take the
            # config dict as a single positional argument.
            extractor = extractor_class(config)

        # Attach activation caches for the modalities that support caching.
        if modality in ("language_model", "speech"):
            extractor.cache_dir = cache_dir
            if modality == "speech":
                extractor.speech_cache = SpeechActivationCache(cache_dir=cache_dir)
            else:
                extractor.activation_cache = ActivationCache(cache_dir=cache_dir)

        return extractor

    @classmethod
    def extract_features_with_caching(
        cls,
        extractor: BaseFeatureExtractor,
        assembly: Any,
        story: str,
        idx: int,
        layer_idx: int = 9,
        lookback: int = 256,
        dataset_type: str = "narratives",
    ) -> Union[np.ndarray, Tuple[np.ndarray, np.ndarray]]:
        """Extract features with caching support.

        Args:
            extractor: The feature extractor instance
            assembly: The assembly containing data
            story: Story name
            idx: Story index
            layer_idx: Layer index for multi-layer extractors
            lookback: Number of tokens to look back (for language models)
            dataset_type: Type of dataset (e.g., 'narratives', 'lebel', etc.)

        Returns:
            Features array, or (features, times) tuple for speech

        Raises:
            ValueError: If the extractor's modality cannot be determined.
        """
        modality = cls._get_modality_from_extractor(extractor)

        if modality == "language_model":
            return cls._extract_language_model_features(
                extractor, assembly, story, idx, layer_idx, lookback, dataset_type
            )
        elif modality == "speech":
            return cls._extract_speech_features(
                extractor, assembly, story, idx, layer_idx, dataset_type
            )
        elif modality == "wordrate":
            word_rates = assembly.get_word_rates()[idx]
            return extractor.extract_features(word_rates)
        elif modality == "embeddings":
            words = assembly.get_words()[idx]
            return extractor.extract_features(words)
        else:
            raise ValueError(f"Unknown modality: {modality}")

    @classmethod
    def _get_modality_from_extractor(cls, extractor: BaseFeatureExtractor) -> str:
        """Get modality from extractor instance via isinstance dispatch."""
        if isinstance(extractor, LanguageModelFeatureExtractor):
            return "language_model"
        elif isinstance(extractor, SpeechFeatureExtractor):
            return "speech"
        elif isinstance(extractor, WordRateFeatureExtractor):
            return "wordrate"
        elif isinstance(extractor, StaticEmbeddingFeatureExtractor):
            return "embeddings"
        else:
            raise ValueError(f"Unknown extractor type: {type(extractor)}")

    @classmethod
    def _extract_language_model_features(
        cls,
        extractor: LanguageModelFeatureExtractor,
        assembly: Any,
        story: str,
        idx: int,
        layer_idx: int,
        lookback: int = 256,
        dataset_type: str = "narratives",
    ) -> np.ndarray:
        """Extract language model features with caching.

        On a cache hit the requested layer is read from the multi-layer
        activation cache; on a miss all layers are computed once, saved,
        and the requested layer is returned.
        """
        texts = assembly.get_stimuli()[idx]

        # NOTE(review): this reaches into a private method of
        # ActivationCache, while the speech path uses a public
        # get_cache_key() — consider exposing a public equivalent.
        cache_key = extractor.activation_cache._get_cache_key(
            story=story,
            lookback=lookback,  # You can make this configurable
            model_name=extractor.model_name,
            context_type=getattr(extractor, "context_type", "fullcontext"),
            last_token=getattr(extractor, "last_token", False),
            dataset_type=dataset_type,
            raw=True,
        )

        lazy_cache = extractor.activation_cache.load_multi_layer_activations(cache_key)
        if lazy_cache is not None:
            return lazy_cache.get_layer(layer_idx)

        # Cache miss: compute every layer once, persist alongside metadata,
        # then return only the requested layer.
        all_features = extractor.extract_all_layers(texts)
        metadata = {
            "model_name": extractor.model_name,
            "story": story,
            "lookback": lookback,
            "context_type": getattr(extractor, "context_type", "fullcontext"),
            "hook_type": extractor.hook_type,
            "last_token": getattr(extractor, "last_token", False),
            "dataset_type": dataset_type,
            "available_layers": list(all_features.keys()),
            "created_at": datetime.now().isoformat(),
        }
        extractor.activation_cache.save_multi_layer_activations(
            cache_key, all_features, metadata
        )
        return all_features[layer_idx]

    @classmethod
    def _extract_speech_features(
        cls,
        extractor: SpeechFeatureExtractor,
        assembly: Any,
        story: str,
        idx: int,
        layer_idx: int,
        dataset_type: str,
    ) -> Tuple[np.ndarray, np.ndarray]:
        """Extract speech features with caching.

        Returns:
            Tuple of (features for layer_idx, chunk times).

        Raises:
            RuntimeError: If extraction produced no layers (e.g. the audio
                was too short).
        """
        wav_path = assembly.get_audio_path()[idx]

        # Cache key covers every parameter that affects the activations.
        cache_key = extractor.speech_cache.get_cache_key(
            audio_id=wav_path,
            model_name=extractor.model_name,
            chunk_size=extractor.chunk_size,
            context_size=extractor.context_size,
            pool=extractor.pool,
            target_sample_rate=extractor.target_sample_rate,
            dataset_type=dataset_type,
            extra={"layer_mode": "all"},
        )

        lazy = extractor.speech_cache.load_multi_layer_activations(cache_key)
        if lazy is not None:
            # Validate that the cached data matches the current parameters.
            lazy.validate_params(
                expected={
                    "model_name": extractor.model_name,
                    "chunk_size": extractor.chunk_size,
                    "context_size": extractor.context_size,
                    "pool": extractor.pool,
                    "target_sample_rate": extractor.target_sample_rate,
                    "dataset_type": dataset_type,
                }
            )
            features = lazy.get_layer(layer_idx)
            times = lazy.get_times()
        else:
            # Cache miss: compute all layers once and persist them.
            layer_to_feats, times = extractor.extract_all_layers(wav_path)
            if len(layer_to_feats) == 0:
                raise RuntimeError(
                    "extract_all_layers returned no layers (audio too short?)."
                )
            metadata = {
                "modality": "speech",
                "audio_id": wav_path,
                "model_name": extractor.model_name,
                "chunk_size": extractor.chunk_size,
                "context_size": extractor.context_size,
                "pool": extractor.pool,
                "target_sample_rate": extractor.target_sample_rate,
                "dataset_type": dataset_type,
                "available_layers": sorted(layer_to_feats.keys()),
            }
            extractor.speech_cache.save_multi_layer_activations(
                cache_key,
                all_layer_activations=layer_to_feats,
                metadata=metadata,
                times=times,
            )
            features = layer_to_feats[layer_idx]

        return features, times

    @classmethod
    def get_supported_modalities(cls) -> list:
        """Get list of supported modalities."""
        return list(cls._extractors.keys())

    @classmethod
    def register_extractor(cls, modality: str, extractor_class: type):
        """Register a new feature extractor class.

        Args:
            modality: The modality name
            extractor_class: The extractor class to register
        """
        cls._extractors[modality] = extractor_class