Source code for encoding.features.language_model

from typing import Any, Dict, List, Union, Optional
import numpy as np
import torch
from transformer_lens import HookedTransformer

from .base import BaseFeatureExtractor


class LanguageModelFeatureExtractor(BaseFeatureExtractor):
    """Feature extractor that uses HookedTransformer to extract embeddings from text.

    This extractor supports different language models and can extract features
    from either the last token or the average across all tokens. It also
    supports multi-layer extraction, caching all layers in a single forward
    pass (see ``extract_all_layers``).
    """
    def __init__(self, config: Dict[str, Any]):
        """Initialize the language model feature extractor.

        Args:
            config (Dict[str, Any]): Configuration dictionary containing:
                - model_name (str): Name of the language model to use
                - layer_idx (int): Index of the layer to extract features from
                  (kept for backward compatibility; default: -1, the last layer)
                - hook_type (str): Type of hook to use (default: "hook_resid_pre")
                - last_token (bool): Whether to use only the last token's
                  features (default: True)
                - device (str): Device to run the model on ("cuda", "mps", or
                  "cpu"); auto-detected if not given
                - context_type (str): Type of context to use ("fullcontext",
                  "nocontext", or "halfcontext"; default: "fullcontext")
        """
        super().__init__(config)
        self.model_name = config["model_name"]
        self.layer_idx = config.get("layer_idx", -1)  # For backward compatibility
        self.hook_type = config.get("hook_type", "hook_resid_pre")
        self.last_token = config.get("last_token", True)
        self.context_type = config.get("context_type", "fullcontext")

        # Honor an explicit device from the config; otherwise auto-detect.
        device = config.get("device")
        if device is None:
            if torch.backends.mps.is_available():
                device = "mps"
            elif torch.cuda.is_available():
                device = "cuda"
            else:
                device = "cpu"
        self.device = device

        # Initialize the model once; layer activations are read from the
        # cache on demand.
        self.model = HookedTransformer.from_pretrained(
            self.model_name, device=self.device
        )
        self.model.eval()
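    # A minimal usage sketch (not part of the module itself; "gpt2" is one
    # model name that HookedTransformer.from_pretrained accepts, and the
    # other keys illustrate the defaults documented above):
    #
    #   extractor = LanguageModelFeatureExtractor({
    #       "model_name": "gpt2",
    #       "layer_idx": 6,
    #       "hook_type": "hook_resid_pre",
    #       "last_token": True,
    #   })
    #   feats = extractor.extract_features(["The cat sat on the mat."])
    #   feats.shape  # (1, 768) for gpt2, since d_model == 768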
    def extract_features(
        self,
        stimuli: Union[str, List[str]],
        layer_idx: Optional[int] = None,
        **kwargs,
    ) -> np.ndarray:
        """Extract features from the input stimuli, one text at a time.

        Args:
            stimuli (Union[str, List[str]]): Input text or list of texts
            layer_idx (Optional[int]): Specific layer to extract from. If None,
                uses ``self.layer_idx``.
            **kwargs: Additional arguments for feature extraction

        Returns:
            np.ndarray: Extracted features, shape ``(n_texts, d_model)``
        """
        if layer_idx is None:
            layer_idx = self.layer_idx
        if isinstance(stimuli, str):
            stimuli = [stimuli]

        # Process each stimulus individually to keep memory usage low.
        all_features = []
        print(f"Processing {len(stimuli)} texts one at a time...")
        for i, text in enumerate(stimuli):
            if i % 10 == 0:
                print(f"Processing text {i + 1}/{len(stimuli)}")

            # Extract features for the current text.
            features = self._extract_single_features(text, layer_idx)
            all_features.append(features)

        # Stack the per-text feature rows into a single array.
        return np.vstack(all_features)
    def extract_all_layers(
        self, stimuli: Union[str, List[str]], **kwargs
    ) -> Dict[int, np.ndarray]:
        """Extract features from all layers for the input stimuli.

        Args:
            stimuli (Union[str, List[str]]): Input text or list of texts
            **kwargs: Additional arguments for feature extraction

        Returns:
            Dict[int, np.ndarray]: Dictionary mapping layer indices to features
        """
        if isinstance(stimuli, str):
            stimuli = [stimuli]

        # Process each stimulus individually.
        all_layer_features: Dict[int, List[np.ndarray]] = {}
        # TODO(Taha): use the logger instead of print
        print(f"Processing {len(stimuli)} texts for all layers...")
        for i, text in enumerate(stimuli):
            if i % 10 == 0:
                print(f"Processing text {i + 1}/{len(stimuli)}")

            # Extract every layer for the current text in one forward pass.
            layer_features = self._extract_single_text_all_layers(text)

            # Accumulate features across texts, keyed by layer index.
            for layer_idx, features in layer_features.items():
                all_layer_features.setdefault(layer_idx, []).append(features)

        # Stack the per-text rows for each layer into one array.
        return {
            layer_idx: np.vstack(rows)
            for layer_idx, rows in all_layer_features.items()
        }
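    # Sketch of the returned structure (an illustration, assuming a 12-layer
    # model such as gpt2 and two input texts):
    #
    #   per_layer = extractor.extract_all_layers(["text one", "text two"])
    #   sorted(per_layer.keys())  # [0, 1, ..., 11]
    #   per_layer[0].shape        # (2, 768): one row per text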
    def _extract_single_features(self, text: str, layer_idx: int) -> np.ndarray:
        """Extract features from a single text for a specific layer.

        Args:
            text (str): Input text
            layer_idx (int): Layer index to extract from

        Returns:
            np.ndarray: Extracted features for the text, shape ``(1, d_model)``
        """
        # For empty text, return a zero row with the model's feature width.
        if text == "":
            return np.zeros((1, self.model.cfg.d_model), dtype=np.float32)

        # Resolve negative indices (e.g. -1 for the last layer) so the hook
        # name below is valid.
        if layer_idx < 0:
            layer_idx += self.model.cfg.n_layers

        with torch.no_grad():
            # Process a single text; return_type=None skips computing the
            # logits, since only the activation cache is needed.
            _, cache = self.model.run_with_cache(
                text, prepend_bos=True, return_type=None
            )

            # Get activations from the specified hook and layer.
            hook_name = f"blocks.{layer_idx}.{self.hook_type}"
            features = cache[hook_name]

            # Use the last token's features, or average across all tokens.
            if self.last_token:
                token_features = features[0, -1].unsqueeze(0)  # Add batch dimension
            else:
                token_features = features[0].mean(dim=0).unsqueeze(0)

            # Convert to a numpy array.
            return token_features.cpu().numpy()

    def _extract_single_text_all_layers(self, text: str) -> Dict[int, np.ndarray]:
        """Extract features from all layers for a single text.

        Args:
            text (str): Input text

        Returns:
            Dict[int, np.ndarray]: Dictionary mapping layer indices to features
        """
        # For empty text, return zero rows for every layer.
        if text == "":
            empty_features = np.zeros((1, self.model.cfg.d_model), dtype=np.float32)
            return {i: empty_features for i in range(self.model.cfg.n_layers)}

        with torch.no_grad():
            # One forward pass caches the activations of every layer.
            _, cache = self.model.run_with_cache(
                text, prepend_bos=True, return_type=None
            )

            # Read the requested hook out of the cache for each layer.
            all_layer_features = {}
            for layer_idx in range(self.model.cfg.n_layers):
                hook_name = f"blocks.{layer_idx}.{self.hook_type}"
                features = cache[hook_name]

                # Use the last token's features, or average across all tokens.
                if self.last_token:
                    token_features = features[0, -1].unsqueeze(0)  # Add batch dimension
                else:
                    token_features = features[0].mean(dim=0).unsqueeze(0)

                # Convert to a numpy array and store.
                all_layer_features[layer_idx] = token_features.cpu().numpy()

        return all_layer_features

    def _validate_config(self) -> None:
        """Validate the configuration parameters."""
        required_params = ["model_name"]
        for param in required_params:
            if param not in self.config:
                raise ValueError(f"Missing required parameter: {param}")

        if "layer_idx" in self.config:
            if not isinstance(self.config["layer_idx"], int):
                raise ValueError("layer_idx must be an integer")

        if "device" in self.config:
            if self.config["device"] not in ["cuda", "mps", "cpu"]:
                raise ValueError("device must be one of 'cuda', 'mps', or 'cpu'")

        if "context_type" in self.config:
            valid_context_types = ["fullcontext", "nocontext", "halfcontext"]
            if self.config["context_type"] not in valid_context_types:
                raise ValueError(f"context_type must be one of {valid_context_types}")
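# Minimal smoke-test sketch (an added example, not part of the original
# module; assumes the TransformerLens "gpt2" weights can be downloaded):
if __name__ == "__main__":
    extractor = LanguageModelFeatureExtractor({"model_name": "gpt2"})

    # Single-layer extraction: one feature row per text.
    features = extractor.extract_features(["Hello world.", "Another sentence."])
    print(features.shape)  # (2, 768) for gpt2

    # All-layer extraction: one (n_texts, d_model) array per layer.
    per_layer = extractor.extract_all_layers("Hello world.")
    for layer_idx in sorted(per_layer):
        print(layer_idx, per_layer[layer_idx].shape)  # each (1, 768)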