Source code for encoding.features.speech_model

from __future__ import annotations

import functools
from typing import Dict, Optional, Tuple

import numpy as np
import torch
from tqdm import tqdm
from transformers import AutoFeatureExtractor, AutoModel, AutoProcessor

def import_torchaudio_gracefully():
    """Import torchaudio, raising a helpful error if it is not installed."""
    try:
        import torchaudio
        return torchaudio
    except ImportError:
        raise ImportError(
            "torchaudio is required for SpeechFeatureExtractor. "
            "Install it with:\n    pip install torchaudio"
        )

def auto_device(fn):
    """Run the wrapped method under torch.no_grad(); extraction never needs grads."""

    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        with torch.no_grad():
            return fn(self, *args, **kwargs)

    return wrapper


class SpeechFeatureExtractor:
    """
    Unified feature extractor for HF speech models (Whisper encoder, HuBERT, Wav2Vec2).

    - extract_features(wav_path, layer=None) -> (features [n_chunks, D], times [n_chunks])
    - extract_all_layers(wav_path) -> (layer_to_features {idx: [n_chunks, D]}, times [n_chunks])

    Notes:
        * Pooling over encoder time can be 'last' or 'mean'.
        * For Whisper, we call the ENCODER ONLY (model.get_encoder()).
        * 'layer' indices are 0-based over encoder blocks (embeddings excluded).
    """

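    # Layer-indexing sketch (assumes standard HF ``output_hidden_states=True``
    # semantics): for an encoder with N blocks, ``hidden_states`` is a tuple of
    # length N + 1:
    #   hidden_states[0]      -> embedding output (never exposed here)
    #   hidden_states[i + 1]  -> output of block i, i.e. layer=i, for 0 <= i < N
    #   hidden_states[-1]     -> output of the last block, i.e. layer="last"
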
    def __init__(
        self,
        model_name: str,
        chunk_size: float,          # seconds between chunk starts (stride)
        context_size: float,        # seconds of audio per window (window length)
        layer: str | int = "last",  # default layer for single-layer extraction
        pool: str = "last",         # 'last' or 'mean'
        device: Optional[str] = None,
        target_sample_rate: int = 16000,
        disable_tqdm: bool = False,
    ):
        import_torchaudio_gracefully()  # fail fast if torchaudio is missing
        assert pool in {"last", "mean"}, "pool must be 'last' or 'mean'"

        self.model_name = model_name
        self.chunk_size = float(chunk_size)
        self.context_size = float(context_size)
        self.layer = layer
        self.pool = pool
        self.device = device or (
            "cuda"
            if torch.cuda.is_available()
            else ("mps" if torch.backends.mps.is_available() else "cpu")
        )
        self.target_sample_rate = int(target_sample_rate)
        self.disable_tqdm = disable_tqdm

        # Load base model
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

        # Detect model type & set up feature extractor + forward key
        self.model_type = getattr(self.model.config, "model_type", "").lower()
        if self.model_type == "whisper":
            # Whisper expects log-mel features
            self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self._forward_key = "input_features"
            self._encoder = self.model.get_encoder()  # use encoder only
        else:
            # HuBERT/Wav2Vec2 expect raw PCM
            try:
                proc = AutoProcessor.from_pretrained(model_name)
                self.feature_extractor = getattr(proc, "feature_extractor", proc)
            except Exception:
                self.feature_extractor = AutoFeatureExtractor.from_pretrained(model_name)
            self._forward_key = "input_values"
            self._encoder = self.model  # whole model acts as encoder here

    # ---- Helpers ----

    def _prepare_inputs(self, waveform: torch.Tensor) -> Dict[str, torch.Tensor]:
        # waveform: 1D CPU torch tensor
        inputs = self.feature_extractor(
            waveform.cpu().numpy(),
            sampling_rate=self.target_sample_rate,
            return_tensors="pt",
        )
        return {k: v.to(self.device) for k, v in inputs.items()}

    def _resolve_n_layers(self, hidden_states: Tuple[torch.Tensor, ...]) -> int:
        """
        hidden_states usually has length n_layers + 1 (embeddings + each block).
        Expose 0..n_layers-1 over the *blocks* (embeddings excluded).
        """
        return len(hidden_states) - 1

    def _get_layer_tensor(
        self, hidden_states: Tuple[torch.Tensor, ...], layer: str | int
    ) -> torch.Tensor:
        """
        Return the hidden state for a given layer index (0-based over blocks),
        or 'last', meaning the last encoder block. Shift by +1 to skip embeddings.
        """
        if layer == "last":
            return hidden_states[-1]
        idx = int(layer)
        return hidden_states[idx + 1]

    def _pool_time(self, x: torch.Tensor) -> torch.Tensor:
        # x: [1, T, D]
        if x.dim() == 2:
            x = x.unsqueeze(1)  # [1, 1, D]
        return x[0, -1, :] if self.pool == "last" else x[0].mean(dim=0)

    def _load_and_resample(self, wav_path: str) -> torch.Tensor:
        torchaudio = import_torchaudio_gracefully()
        wav, sr = torchaudio.load(wav_path)
        if wav.shape[0] != 1:
            wav = wav.mean(0, keepdim=True)  # downmix to mono
        if sr != self.target_sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.target_sample_rate)
        return wav.squeeze(0)  # [num_samples], CPU

    # The important part for the library: we need the same interface for all
    # models, so that everything works out of the box.
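
    # Chunking sketch (hypothetical numbers, not part of the API): with
    # target_sample_rate=16000, context_size=2.0 s (32000 samples) and
    # chunk_size=0.5 s (8000 samples), a 10 s file (160000 samples) yields
    # n_chunks = (160000 - 32000) // 8000 + 1 = 17 windows, whose timestamps
    # (right edges) are 2.0, 2.5, ..., 10.0 seconds.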
    @auto_device
    def extract_features(self, wav_path: str, layer: str | int | None = None):
        """
        Single-layer extraction stacked over chunks.

        Returns:
            features: [n_chunks, D]
            times: [n_chunks]
        """
        layer = self.layer if layer is None else layer
        wav = self._load_and_resample(wav_path)
        chunk_samples = int(self.chunk_size * self.target_sample_rate)
        context_samples = int(self.context_size * self.target_sample_rate)
        total = wav.shape[0]

        if context_samples <= 0 or chunk_samples <= 0:
            raise ValueError("context_size and chunk_size must be > 0 seconds.")
        if total < context_samples:
            return np.empty((0, 0)), np.array([])

        n_chunks = (total - context_samples) // chunk_samples + 1
        features, times = [], []
        with tqdm(
            total=int(n_chunks), desc="Extracting features", disable=self.disable_tqdm
        ) as pbar:
            for i in range(int(n_chunks)):
                end = context_samples + i * chunk_samples
                start = max(0, end - context_samples)
                if end > total:
                    break
                window = wav[start:end]

                inputs = self._prepare_inputs(window)
                outputs = self._encoder(
                    **{self._forward_key: inputs[self._forward_key]},
                    output_hidden_states=True,
                )
                hs = outputs.hidden_states  # tuple of [1, T, D]

                layer_t = self._get_layer_tensor(hs, layer)
                if layer_t.shape[1] == 0:
                    pbar.update(1)
                    continue

                vec = self._pool_time(layer_t)  # [D]
                features.append(vec.detach().cpu().numpy())
                times.append(end / self.target_sample_rate)
                pbar.update(1)

        features = np.stack(features) if len(features) else np.empty((0, 0))
        times = np.array(times)
        return features, times

    @auto_device
    def extract_all_layers(self, wav_path: str):
        """
        All-layers extraction stacked over chunks.

        Returns:
            layer_to_features: {layer_idx: [n_chunks, D]}
            times: [n_chunks]
        """
        wav = self._load_and_resample(wav_path)
        chunk_samples = int(self.chunk_size * self.target_sample_rate)
        context_samples = int(self.context_size * self.target_sample_rate)
        total = wav.shape[0]

        if context_samples <= 0 or chunk_samples <= 0:
            raise ValueError("context_size and chunk_size must be > 0 seconds.")
        if total < context_samples:
            return {}, np.array([])

        n_chunks = (total - context_samples) // chunk_samples + 1
        layer_buffers: Dict[int, list[np.ndarray]] = {}
        times: list[float] = []
        with tqdm(
            total=int(n_chunks), desc="Extracting all layers", disable=self.disable_tqdm
        ) as pbar:
            for i in range(int(n_chunks)):
                end = context_samples + i * chunk_samples
                start = max(0, end - context_samples)
                if end > total:
                    break
                window = wav[start:end]

                inputs = self._prepare_inputs(window)
                outputs = self._encoder(
                    **{self._forward_key: inputs[self._forward_key]},
                    output_hidden_states=True,
                )
                hs = outputs.hidden_states  # tuple of [1, T, D]
                if hs[-1].shape[1] == 0:
                    pbar.update(1)
                    continue

                n_layers = self._resolve_n_layers(hs)
                if not layer_buffers:
                    for li in range(n_layers):
                        layer_buffers[li] = []

                for li in range(n_layers):
                    layer_t = hs[li + 1]  # skip embeddings at index 0
                    vec = self._pool_time(layer_t)
                    layer_buffers[li].append(vec.detach().cpu().numpy())

                times.append(end / self.target_sample_rate)
                pbar.update(1)

        layer_to_features = {
            li: (np.stack(buf) if len(buf) else np.empty((0, 0)))
            for li, buf in layer_buffers.items()
        }
        return layer_to_features, np.array(times)
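
# Minimal usage sketch. "facebook/hubert-base-ls960" is a public HF checkpoint;
# the wav path and parameter values are hypothetical.
if __name__ == "__main__":
    extractor = SpeechFeatureExtractor(
        model_name="facebook/hubert-base-ls960",
        chunk_size=0.5,    # a new window every 0.5 s
        context_size=2.0,  # each window sees 2 s of audio
        layer="last",
        pool="mean",
    )
    feats, times = extractor.extract_features("example.wav")     # [n_chunks, D], [n_chunks]
    layers, times = extractor.extract_all_layers("example.wav")  # {idx: [n_chunks, D]}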