Source code for encoding.features.embeddings

# encoding/features/static_token_extractor.py
from typing import Any, Dict, List, Union, Optional, Iterable
import os
import re
import numpy as np
import torch
from gensim.models import KeyedVectors

from .base import BaseFeatureExtractor

try:
    from tqdm.auto import tqdm
except Exception:
    # Fallback: a no-op wrapper so the module still works without tqdm installed.
    def tqdm(x, **kwargs):
        return x


class StaticEmbeddingFeatureExtractor(BaseFeatureExtractor):
    """
    Local-only static *token* embedding extractor (Word2Vec / GloVe).

    Input (extract_features):
      - List[str]: list of tokens/words (preferred), order preserved
      - str: a raw string (tokenized using `tokenizer_pattern`)

    Output:
      - np.ndarray with shape [N, D], one row per input token.

    Config (Dict[str, Any]):
      - vector_path (str, required): local vectors path. Supported:
            *.kv              -> KeyedVectors.load (mmap capable)
            *.bin / *.bin.gz  -> word2vec binary (binary=True)
            *.w2v.txt         -> word2vec text WITH header (binary=False, no_header=False)
            *.txt / *.txt.gz  -> GloVe text WITHOUT header (binary=False, no_header=True)
      - lowercase (bool): lowercase tokens before lookup
            (GoogleNews: False; GloVe/Wiki-Giga: True) [default: True]
      - oov_handling (str): one of:
            "copy_prev" -> OOV copies the previous valid embedding (DEFAULT)
            "zero"      -> OOV becomes a zero vector (length preserved)
            "skip"      -> OOV is dropped (length may shrink)
            "error"     -> raise on first OOV
      - use_tqdm (bool): show progress bar for long inputs [default: True]
      - mmap (bool): memory-map .kv [default: True]
      - binary (Optional[bool]): force word2vec binary flag; auto-infer if None
      - no_header (Optional[bool]): force GloVe no-header; auto-infer if None
      - l2_normalize_tokens (bool): L2-normalize each token vector [default: False]
      - tokenizer_pattern (str): ONLY used if input is a single string.
            Default r"[A-Za-z0-9_']+" (keeps underscores)

    Note: This has also been tested with ENG1000. You just have to convert it
    to the .kv format first. We'll provide a script to do that (a minimal
    sketch appears at the end of this file).
    """
    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)

        # ---- Required path
        vector_path = config.get("vector_path", "")
        if not vector_path:
            raise ValueError("'vector_path' is required.")
        # expanduser + abspath for clearer logs
        self.vector_path: str = os.path.abspath(os.path.expanduser(vector_path))
        if not os.path.exists(self.vector_path):
            raise FileNotFoundError(f"Vector file not found: {self.vector_path}")

        self.lowercase: bool = bool(config.get("lowercase", True))
        self.oov_handling: str = config.get("oov_handling", "copy_prev")
        if self.oov_handling not in {"copy_prev", "zero", "skip", "error"}:
            raise ValueError(
                "oov_handling must be 'copy_prev', 'zero', 'skip', or 'error'"
            )
        self.use_tqdm: bool = bool(config.get("use_tqdm", True))
        self.mmap: bool = bool(config.get("mmap", True))
        self.l2_normalize_tokens: bool = bool(config.get("l2_normalize_tokens", False))
        self.tokenizer_pattern: str = config.get("tokenizer_pattern", r"[A-Za-z0-9_']+")
        self._force_binary: Optional[bool] = config.get("binary", None)
        self._force_no_header: Optional[bool] = config.get("no_header", None)

        # Device detection, kept for parity with other extractors; the static
        # lookup itself runs on CPU via gensim.
        if torch.backends.mps.is_available():
            self.device = "mps"
        elif torch.cuda.is_available():
            self.device = "cuda"
        else:
            self.device = "cpu"

        self._tok_re = re.compile(self.tokenizer_pattern)

        print(f"[StaticToken] Loading vectors: {self.vector_path}")
        self.kv = self._load_local_vectors(self.vector_path)
        self.dim = int(self.kv.vector_size)
        print(
            f"[StaticToken] Loaded ({self.dim}-D), vocab={len(self.kv.key_to_index):,}"
        )
    def extract_features(
        self,
        stimuli: Union[str, List[str]],
        **kwargs,
    ) -> np.ndarray:
        """
        Tokens -> [N, D], one row per input token.

        If `stimuli` is a string, it is tokenized with `tokenizer_pattern`.
        OOV handling per config (default: copy the previous valid embedding).
        """
        # Normalize input to a token list
        if isinstance(stimuli, str):
            text = stimuli.lower() if self.lowercase else stimuli
            tokens = self._tok_re.findall(text)
        elif isinstance(stimuli, list):
            tokens = []
            for t in stimuli:
                if isinstance(t, str):
                    tokens.append(t.lower() if self.lowercase else t)
                else:
                    tokens.append(t)  # will be handled below
        else:
            raise TypeError(
                "extract_features expects a List[str] of tokens or a single string."
            )

        N = len(tokens)
        if N == 0:
            return np.zeros((0, self.dim), dtype=np.float32)

        iterator: Iterable[str] = (
            tqdm(tokens, desc="Embedding tokens", total=N) if self.use_tqdm else tokens
        )

        vecs: List[np.ndarray] = []
        last_valid: Optional[np.ndarray] = None  # for copy_prev

        for i, tok in enumerate(iterator):
            v: Optional[np.ndarray] = None

            if not isinstance(tok, str):
                if self.oov_handling == "error":
                    raise ValueError(f"Non-string token at index {i}: {tok!r}")
                elif self.oov_handling == "skip":
                    continue  # skip may shrink length
                elif self.oov_handling == "copy_prev":
                    v = (
                        last_valid.copy()
                        if last_valid is not None
                        else np.zeros((self.dim,), dtype=np.float32)
                    )
                else:  # "zero"
                    v = np.zeros((self.dim,), dtype=np.float32)
            else:
                # String token: lookup
                if tok in self.kv.key_to_index:
                    v = self.kv.get_vector(tok).astype(np.float32, copy=False)
                    # only update last_valid when we have a real vector
                    last_valid = v.copy()
                else:
                    # OOV handling
                    if self.oov_handling == "error":
                        raise KeyError(f"OOV token at index {i}: {tok!r}")
                    elif self.oov_handling == "skip":
                        continue  # WARNING: length may shrink
                    elif self.oov_handling == "copy_prev":
                        v = (
                            last_valid.copy()
                            if last_valid is not None
                            else np.zeros((self.dim,), dtype=np.float32)
                        )
                    else:  # "zero"
                        v = np.zeros((self.dim,), dtype=np.float32)

            # Optional per-token L2 norm
            if self.l2_normalize_tokens:
                n = np.linalg.norm(v)
                if n > 0:
                    v = v / n

            vecs.append(v)

        if not vecs:
            return np.zeros((0, self.dim), dtype=np.float32)
        return np.stack([np.asarray(v, dtype=np.float32) for v in vecs], axis=0)
    def _load_local_vectors(self, path: str) -> KeyedVectors:
        ext = path.lower()
        if ext.endswith(".kv"):
            return KeyedVectors.load(path, mmap="r" if self.mmap else None)

        binary = (
            self._infer_binary(ext)
            if self._force_binary is None
            else bool(self._force_binary)
        )
        no_header = (
            self._infer_no_header(ext)
            if self._force_no_header is None
            else bool(self._force_no_header)
        )

        try:
            return KeyedVectors.load_word2vec_format(
                path, binary=binary, no_header=no_header
            )
        except Exception as e:
            # If *.txt was mis-detected, flip no_header once and retry
            if ext.endswith(".txt") or ext.endswith(".txt.gz"):
                try:
                    return KeyedVectors.load_word2vec_format(
                        path, binary=False, no_header=not no_header
                    )
                except Exception as e2:
                    raise RuntimeError(
                        f"Failed to load vectors from {path}.\n"
                        f"Attempt 1 (binary={binary}, no_header={no_header}) -> {e}\n"
                        f"Attempt 2 (binary=False, no_header={not no_header}) -> {e2}\n"
                        "If this is raw GloVe, use no_header=True. "
                        "If word2vec text, it must have a header."
                    )
            raise

    @staticmethod
    def _infer_binary(ext: str) -> bool:
        return ext.endswith(".bin") or ext.endswith(".bin.gz")

    @staticmethod
    def _infer_no_header(ext: str) -> bool:
        # Heuristics:
        #   *.w2v.txt       -> word2vec text WITH header    => no_header=False
        #   *.txt / .txt.gz -> assume GloVe text, NO header => no_header=True
        #   (binaries ignore no_header)
        if ext.endswith(".w2v.txt"):
            return False
        if ext.endswith(".txt") or ext.endswith(".txt.gz"):
            return True
        return False
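
# ---------------------------------------------------------------------------
# ENG1000 -> .kv conversion: a minimal sketch of the helper script promised in
# the class docstring (the actual script is not included here). File names are
# placeholders, and we assume the vectors have already been exported to
# word2vec text format with a header line. The resulting .kv loads quickly and
# supports mmap.
# ---------------------------------------------------------------------------
def convert_to_kv(src: str = "eng1000.w2v.txt", dst: str = "eng1000.kv") -> None:
    """One-off conversion: word2vec text (with header) -> gensim .kv."""
    kv = KeyedVectors.load_word2vec_format(src, binary=False, no_header=False)
    kv.save(dst)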
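
# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the extractor). The vector path
# below is a placeholder; point it at any local GloVe/word2vec file. With
# 300-D GloVe, the made-up token "qwzx" is OOV and copies the previous valid
# embedding under the default "copy_prev" policy, so the row count is kept.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    config = {
        "vector_path": "glove.6B.300d.txt",  # placeholder local path
        "lowercase": True,  # GloVe vocab is lowercased
        "oov_handling": "copy_prev",
        "use_tqdm": False,
    }
    extractor = StaticEmbeddingFeatureExtractor(config)
    feats = extractor.extract_features(["the", "quick", "qwzx", "fox"])
    print(feats.shape)  # -> (4, 300): one row per token, OOV row copied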