@ -0,0 +1,98 @@
# Quality-aware ASR pipeline using brouhaha VAD/SNR/C50

This document captures a two-stage pipeline design proposed for WhisperJAV that leverages brouhaha VAD/SNR/C50 to route segments to the optimal processing path.

## Goals

- Maximize ASR accuracy and throughput by routing clean speech directly to ASR
- Improve challenging segments via targeted enhancement prior to ASR
- Remain modular, configurable, and data-driven

## High-level flow

1. Diarize audio and extract segments (existing pyannote pipeline)
2. For each segment:
   - Run brouhaha inference to obtain per-frame (vad, snr, c50)
   - Aggregate per-segment metrics (e.g., mean/median over vad > threshold frames)
   - Classify the segment as clean or challenging based on configurable rules
3. Route:
   - Clean → Whisper ASR as-is
   - Challenging → enhancement stack → Whisper ASR
4. Merge ASR outputs and preserve provenance/metadata (path taken, metrics); see the sketch below
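
A minimal sketch of this per-segment loop, assuming the `select_route_for_segment` and `run_chain` interfaces proposed later in this document; `asr_fn` and the frame list are illustrative stand-ins, not existing code:

```python
from whisperjav.modules.audio_quality_router import select_route_for_segment
from whisperjav.modules.audio_enhancement import run_chain


def process_segment(frames, waveform, sample_rate, start, end, cfg, asr_fn):
    """Route one segment: direct ASR when clean, enhancement first otherwise."""
    metrics, decision = select_route_for_segment(frames, start, end, cfg)
    if decision.route == "enhance":
        waveform, sample_rate = run_chain(waveform, sample_rate, decision.params["chain"])
    text = asr_fn(waveform, sample_rate)
    # Keep provenance so the merge step knows which path each segment took
    return {"text": text, "route": decision.route, "label": decision.label, "metrics": metrics.__dict__}
```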

## Segment classification

- Inputs: list of frames with (vad, snr, c50)
- Steps:
  - Keep frames with vad >= VAD_MIN
  - Compute metrics: avg_snr, p50_snr, p10_snr, avg_c50, speech_coverage (ratio of speech frames)
- Decision policies (configurable):
  - Clean if avg_snr >= SNR_CLEAN && avg_c50 >= C50_CLEAN && speech_coverage >= MIN_COVERAGE
  - Challenging otherwise, with sub-classes (noisy, reverberant, low-speech-coverage)
  - Optional: score = w_snr * avg_snr + w_c50 * avg_c50; compare with SCORE_THRESH (worked example below)
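
Worked example using the template values (SNR_CLEAN=7.5, C50_CLEAN=0.5, MIN_COVERAGE=0.4, w_snr=0.7, w_c50=0.3, thresh=6.0); the segment metrics are hypothetical:

```python
avg_snr, avg_c50, coverage = 4.0, 0.6, 0.55                              # hypothetical aggregates
is_clean = (avg_snr >= 7.5) and (avg_c50 >= 0.5) and (coverage >= 0.4)   # False: SNR too low -> "noisy"
score = 0.7 * avg_snr + 0.3 * avg_c50                                    # 2.98
clean_by_score = score >= 6.0                                            # False -> enhancement route
```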

## Enhancement stack (challenging segments)

- Noisy: denoise → loudness norm → (optional) bandwidth EQ
- Reverberant: dereverb → denoise → norm
- Music/background: vocal separation → denoise → norm
- Low-energy: dynamic range compression/AGC → norm

## Libraries/algorithms (Python-first)

- Denoising:
  - Spectral gating (e.g., the noisereduce package)
  - RNNoise (via rnnoise-py)
  - NVIDIA Maxine/RTX Voice (Windows, optional, external)
  - SpeechBrain/asteroid (DPRNN/ConvTasNet/FullSubNet)
- Vocal separation:
  - demucs (facebookresearch/demucs)
  - spleeter (deezer/spleeter)
- Dereverberation / clarity:
  - Weighted prediction error (WPE) via nara_wpe
  - pyroomacoustics-based dereverberation filters
  - DCCRN/DCCRNet models
- Loudness/dynamics:
  - pyloudnorm for EBU R128 normalization (see the sketch after this list)
  - librosa effects (preemphasis), simple compressors (custom)
- Utility:
  - pyannote.audio for diarization/embedding
  - brouhaha-vad for VAD/SNR/C50
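
As a concrete example of one step, a `norm` implementation built on pyloudnorm might look like this minimal sketch (the `(waveform, sample_rate) -> (waveform, sample_rate)` signature follows the registry convention proposed below; the -23 LUFS target is an assumption):

```python
import numpy as np
import pyloudnorm as pyln


def norm(waveform: np.ndarray, sample_rate: int, target_lufs: float = -23.0):
    """Normalize integrated loudness to target_lufs (EBU R128 style)."""
    meter = pyln.Meter(sample_rate)                 # ITU-R BS.1770 loudness meter
    loudness = meter.integrated_loudness(waveform)  # measure current integrated loudness
    normalized = pyln.normalize.loudness(waveform, loudness, target_lufs)
    return normalized, sample_rate
```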

## Data structures

- SegmentMetrics
  - start: float
  - end: float
  - avg_snr: float
  - avg_c50: float
  - coverage: float
  - label: Literal["clean", "noisy", "reverberant", "music", "low_energy"]
- SegmentDecision
  - route: Literal["direct", "enhance"]
  - reasons: list[str]
  - params: dict[str, Any] (e.g., chosen enhancement chain)

## Config

- YAML/JSON with thresholds and weights
- Example:
  - VAD_MIN: 0.5
  - SNR_CLEAN: 7.5
  - C50_CLEAN: 0.5
  - MIN_COVERAGE: 0.4
  - SCORE: { w_snr: 0.7, w_c50: 0.3, thresh: 6.0 }
  - Chains:
    - noisy: [denoise, norm]
    - reverberant: [dereverb, denoise, norm]
    - music: [separate, denoise, norm]
    - low_energy: [agc, norm]

## Integration notes

- Keep current diarization/embedding flow unchanged
- Insert a segment routing phase before the embedding/ASR stage, or directly before Whisper
- Enhancement should be stateless, batch-friendly, and GPU-optional
- Cache enhanced waveforms for reproducibility; tag outputs with metadata (sketch below)
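
A minimal sketch of the caching/tagging idea (paths and field names are illustrative, not part of the existing code):

```python
import json
from pathlib import Path

import soundfile as sf


def cache_enhanced(waveform, sample_rate, seg_id: str, decision, metrics, cache_dir: Path):
    """Write the enhanced waveform plus a JSON sidecar recording how it was produced."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    wav_path = cache_dir / f"{seg_id}.enhanced.wav"
    sf.write(wav_path, waveform, sample_rate)
    sidecar = {
        "route": decision.route,
        "chain": decision.params.get("chain", []),
        "metrics": metrics.__dict__,
    }
    (cache_dir / f"{seg_id}.meta.json").write_text(json.dumps(sidecar, indent=2))
    return wav_path
```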

## Testing

- Unit tests for routing decisions given synthetic metrics (sketch below)
- Golden audio samples with known issues to verify improvement
- A/B test: baseline vs two-step routing on WER/CER and speaker attribution
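
For example, routing-decision unit tests against the router module added in this change (a sketch; threshold values taken from the template config):

```python
from whisperjav.modules.audio_quality_router import SegmentMetrics, classify_segment

CFG = {
    "SNR_CLEAN": 7.5, "C50_CLEAN": 0.5, "MIN_COVERAGE": 0.4,
    "SCORE": {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0},
    "CHAINS": {"noisy": ["denoise", "norm"]},
}


def test_clean_segment_routes_direct():
    m = SegmentMetrics(start=0.0, end=5.0, avg_snr=12.0, avg_c50=0.8, coverage=0.9, score=8.64)
    assert classify_segment(m, CFG).route == "direct"


def test_noisy_segment_routes_to_enhancement():
    m = SegmentMetrics(start=0.0, end=5.0, avg_snr=3.0, avg_c50=0.8, coverage=0.9, score=2.34)
    decision = classify_segment(m, CFG)
    assert decision.route == "enhance"
    assert decision.label == "noisy"
    assert decision.params["chain"] == ["denoise", "norm"]
```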

## Future

- Adaptive thresholds learned from data
- Confidence-driven re-segmentation for borderline cases
- On-device lightweight enhancement for real-time use
@ -0,0 +1,26 @@

# Brouhaha Pipeline & QualityAware ASR

This document outlines the architectural additions:

- A new ASR engine `QualityAwareASR` that uses brouhaha (VAD/SNR/C50) to route segments to enhancement or direct ASR.
- A new pipeline `BrouhahaPipeline` mirroring the balanced pipeline but using `QualityAwareASR`.

## Contracts

- `QualityAwareASR.transcribe(path) -> Dict` returns {segments, text, language, [quality_meta]}.
- `QualityAwareASR.transcribe_to_srt(path, out_path) -> Path` writes SRT similar to WhisperProASR (usage sketch below).
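
Usage sketch (the `model_config`/`params` shapes below mirror how `QualityAwareASR` reads them; the concrete values are illustrative):

```python
from pathlib import Path

from whisperjav.modules.quality_aware_asr import QualityAwareASR

asr = QualityAwareASR(
    model_config={"model_name": "large-v2", "device": "cuda"},
    params={"decoder": {}, "provider": {}, "quality_aware": {}},
    task="transcribe",
)
result = asr.transcribe("scene_0001.wav")
print(result["text"], result.get("quality_meta", {}).get("decision"))
srt_path = asr.transcribe_to_srt("scene_0001.wav", Path("scene_0001.ja.srt"))
```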

## Config

- Optional `params.quality_aware.routing_config_path` to point to a JSON like `whisperjav/config/quality_routing.template.json` (example fragment below).
- Optional `params.quality_aware.hf_token` for private models if needed.
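
Illustrative fragment of the resolved `params`, written as a Python dict (the surrounding resolved-config shape is an assumption):

```python
params = {
    "decoder": {},
    "provider": {},
    "quality_aware": {
        "routing_config_path": "whisperjav/config/quality_routing.template.json",
        "hf_token": None,  # set to a Hugging Face token string if the model requires auth
    },
}
```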

## Integration

- New pipeline class `BrouhahaPipeline` under `whisperjav/pipelines/brouhaha_pipeline.py`.
- To enable it as a CLI mode, wire it into `main.py` mode selection; the `main.py` hunks in this change add the import, the `--mode brouhaha` choice, and the pipeline dispatch.

## Enhancement registry

- Implement your enhancement steps, then register them by name in `audio_enhancement.registry` (example below).
- The router selects a chain via config and `run_chain` executes it.
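
For example, registering a real `denoise` step and running a chain; the denoiser body is a placeholder, and only `registry.register` and `run_chain` come from this change:

```python
import numpy as np

from whisperjav.modules.audio_enhancement import registry, run_chain


def my_denoise(waveform: np.ndarray, sample_rate: int):
    """Placeholder denoiser; swap in rnnoise, noisereduce, a neural model, etc."""
    return waveform, sample_rate


registry.register("denoise", my_denoise)  # replaces the pre-registered no-op

waveform, sample_rate = np.zeros(16000, dtype=np.float32), 16000  # stand-in audio
enhanced, sr = run_chain(waveform, sample_rate, ["denoise", "norm"])
```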

## Next steps

- Decide on enhancement implementations: rnnoise/nara_wpe/demucs/pyloudnorm.
- Exercise the new CLI mode (--mode brouhaha) end-to-end and add unit tests for routing decisions.
@ -0,0 +1,14 @@
{
  "VAD_MIN": 0.5,
  "SNR_CLEAN": 7.5,
  "C50_CLEAN": 0.5,
  "MIN_COVERAGE": 0.4,
  "SCORE": { "w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0 },
  "CHAINS": {
    "noisy": ["denoise", "norm"],
    "reverberant": ["dereverb", "denoise", "norm"],
    "music": ["separate", "denoise", "norm"],
    "low_energy": ["agc", "norm"],
    "unknown": ["denoise", "norm"]
  }
}
@ -37,6 +37,7 @@ from whisperjav.modules.media_discovery import MediaDiscovery
from whisperjav.pipelines.faster_pipeline import FasterPipeline
from whisperjav.pipelines.fast_pipeline import FastPipeline
from whisperjav.pipelines.balanced_pipeline import BalancedPipeline
+from whisperjav.pipelines.brouhaha_pipeline import BrouhahaPipeline
from whisperjav.config.transcription_tuner import TranscriptionTuner
from whisperjav.__version__ import __version__
@ -96,7 +97,7 @@ def parse_arguments():
    # Core arguments
    parser.add_argument("input", nargs="*", help="Input media file(s), directory, or wildcard pattern.")
-   parser.add_argument("--mode", choices=["balanced", "fast", "faster"], default="balanced",
+   parser.add_argument("--mode", choices=["balanced", "fast", "faster", "brouhaha"], default="balanced",
                        help="Processing mode (default: balanced)")
    parser.add_argument("--config", default=None, help="Path to a JSON configuration file")
    parser.add_argument("--subs-language", choices=["japanese", "english-direct"],
@ -241,6 +242,8 @@ def process_files_sync(media_files: List[Dict], args: argparse.Namespace, resolv
        pipeline = FasterPipeline(**pipeline_args)
    elif args.mode == "fast":
        pipeline = FastPipeline(**pipeline_args)
+   elif args.mode == "brouhaha":
+       pipeline = BrouhahaPipeline(**pipeline_args)
    else: # balanced
        pipeline = BalancedPipeline(**pipeline_args)
@ -0,0 +1,53 @@
"""
Audio enhancement orchestrator.

Provides a simple registry and runner for enhancement chains
such as denoise, dereverb, separation, normalization, etc.

This module is intentionally minimal; implement steps with
preferred libraries in your environment.
"""
from __future__ import annotations

from typing import Any, Callable, Dict, List, Tuple

EnhancementFn = Callable[[Any, int], Tuple[Any, int]]


class EnhancementRegistry:
    def __init__(self) -> None:
        self._fns: Dict[str, EnhancementFn] = {}

    def register(self, name: str, fn: EnhancementFn) -> None:
        self._fns[name] = fn

    def get(self, name: str) -> EnhancementFn | None:
        return self._fns.get(name)


registry = EnhancementRegistry()


def run_chain(waveform: Any, sample_rate: int, steps: List[str]) -> Tuple[Any, int]:
    """
    Run a sequence of enhancement steps registered in the registry.
    Each step is a function: (waveform, sample_rate) -> (waveform, sample_rate)
    """
    out_wav, out_sr = waveform, sample_rate
    for name in steps:
        fn = registry.get(name)
        if fn is None:
            # Skip unknown steps, allowing config-driven experimentation
            continue
        out_wav, out_sr = fn(out_wav, out_sr)
    return out_wav, out_sr


# Example placeholder steps (no-ops); replace with real implementations
def _noop(wav: Any, sr: int) -> Tuple[Any, int]:
    return wav, sr


# Pre-register common names to avoid KeyErrors during experimentation
for _name in ["denoise", "dereverb", "separate", "norm", "agc"]:
    registry.register(_name, _noop)
@ -0,0 +1,134 @@
"""
Quality-based segment routing using brouhaha VAD/SNR/C50.

This module computes per-segment quality metrics from brouhaha
inference outputs and classifies segments as 'direct' (ASR as-is)
or 'enhance' (run through enhancement chain before ASR).
"""
from __future__ import annotations

from dataclasses import dataclass
from typing import Any, Dict, List, Tuple


@dataclass
class SegmentMetrics:
    start: float
    end: float
    avg_snr: float
    avg_c50: float
    coverage: float  # ratio of frames with vad >= vad_min
    score: float


@dataclass
class SegmentDecision:
    route: str  # 'direct' | 'enhance'
    label: str  # 'clean' | 'noisy' | 'reverberant' | 'music' | 'low_energy' | 'unknown'
    reasons: List[str]
    params: Dict[str, Any]


def aggregate_metrics(
    frames: List[Tuple[float, Tuple[float, float, float]]],
    vad_min: float,
    start: float,
    end: float,
    score_weights: Tuple[float, float] = (0.7, 0.3),
) -> SegmentMetrics:
    """
    Aggregate per-frame (time, (vad, snr, c50)) into segment metrics.

    frames: list of (time, (vad, snr, c50)) from brouhaha Inference.
    vad_min: threshold to consider a frame as speech.
    start, end: segment boundaries.
    score_weights: (w_snr, w_c50) for composite score.
    """
    if not frames:
        return SegmentMetrics(start, end, avg_snr=0.0, avg_c50=0.0, coverage=0.0, score=0.0)

    # Import numpy lazily to avoid hard dependency at import time
    import numpy as np  # type: ignore

    vad = np.array([v for _, (v, _, _) in frames], dtype=np.float32)
    snr = np.array([s for _, (_, s, _) in frames], dtype=np.float32)
    c50 = np.array([c for _, (_, _, c) in frames], dtype=np.float32)

    speech_mask = vad >= vad_min
    coverage = float(np.mean(speech_mask)) if vad.size else 0.0

    if speech_mask.any():
        avg_snr = float(np.mean(snr[speech_mask]))
        avg_c50 = float(np.mean(c50[speech_mask]))
    else:
        avg_snr = 0.0
        avg_c50 = 0.0

    w_snr, w_c50 = score_weights
    score = w_snr * avg_snr + w_c50 * avg_c50

    return SegmentMetrics(start, end, avg_snr=avg_snr, avg_c50=avg_c50, coverage=coverage, score=score)


def classify_segment(
    m: SegmentMetrics,
    cfg: Dict[str, Any],
) -> SegmentDecision:
    """
    Classify a segment as 'direct' or 'enhance' with an issue label.

    cfg expects keys:
    - VAD_MIN, SNR_CLEAN, C50_CLEAN, MIN_COVERAGE
    - SCORE: { w_snr, w_c50, thresh }
    - CHAINS: mapping from label -> list of step names
    """
    reasons: List[str] = []
    label = "clean"

    snr_clean = float(cfg.get("SNR_CLEAN", 7.5))
    c50_clean = float(cfg.get("C50_CLEAN", 0.5))
    min_cov = float(cfg.get("MIN_COVERAGE", 0.4))
    score_cfg = cfg.get("SCORE", {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0})
    score_thresh = float(score_cfg.get("thresh", 6.0))

    if m.coverage < min_cov:
        label = "low_energy"
        reasons.append(f"coverage {m.coverage:.2f} < {min_cov}")

    if m.avg_snr < snr_clean:
        label = "noisy"
        reasons.append(f"avg_snr {m.avg_snr:.2f} < {snr_clean}")

    if m.avg_c50 < c50_clean:
        # Prefer 'reverberant' if nothing more critical was already flagged
        if label == "clean":
            label = "reverberant"
        reasons.append(f"avg_c50 {m.avg_c50:.2f} < {c50_clean}")

    # Composite score gate
    if m.score < score_thresh and label == "clean":
        label = "unknown"
        reasons.append(f"score {m.score:.2f} < {score_thresh}")

    if label == "clean":
        return SegmentDecision(route="direct", label=label, reasons=["meets clean criteria"], params={})

    chain = (cfg.get("CHAINS", {}) or {}).get(label) or []
    return SegmentDecision(route="enhance", label=label, reasons=reasons, params={"chain": chain})


def select_route_for_segment(
    frames: List[Tuple[float, Tuple[float, float, float]]],
    start: float,
    end: float,
    cfg: Dict[str, Any],
) -> Tuple[SegmentMetrics, SegmentDecision]:
    """
    Convenience wrapper to aggregate metrics then classify.
    """
    vad_min = float(cfg.get("VAD_MIN", 0.5))
    score_cfg = cfg.get("SCORE", {"w_snr": 0.7, "w_c50": 0.3})
    weights = (float(score_cfg.get("w_snr", 0.7)), float(score_cfg.get("w_c50", 0.3)))
    metrics = aggregate_metrics(frames, vad_min=vad_min, start=start, end=end, score_weights=weights)
    decision = classify_segment(metrics, cfg)
    return metrics, decision
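

# Illustrative self-check (an added usage sketch, not part of the original design).
# Frame format matches aggregate_metrics: (time, (vad, snr, c50)).
if __name__ == "__main__":
    demo_frames = [(0.1 * i, (0.9, 10.0, 0.8)) for i in range(50)]  # clean-ish synthetic speech
    demo_cfg = {
        "VAD_MIN": 0.5,
        "SNR_CLEAN": 7.5,
        "C50_CLEAN": 0.5,
        "MIN_COVERAGE": 0.4,
        "SCORE": {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0},
        "CHAINS": {"noisy": ["denoise", "norm"]},
    }
    demo_metrics, demo_decision = select_route_for_segment(demo_frames, 0.0, 5.0, demo_cfg)
    print(demo_decision.route, demo_decision.label)  # expected: direct clean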
@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
QualityAware ASR engine leveraging brouhaha VAD/SNR/C50 for quality-aware routing
and optional enhancement before Whisper transcription. Designed to mirror
WhisperProASR's high-level contract (transcribe, transcribe_to_srt).
"""
from __future__ import annotations

from pathlib import Path
from typing import Any, Dict, List, Tuple, Union

import logging
import json

import numpy as np
import soundfile as sf
import torch
import whisper
import srt
import datetime

from pyannote.audio import Model, Inference

from whisperjav.utils.logger import logger
from whisperjav.modules.audio_quality_router import select_route_for_segment
from whisperjav.modules.audio_enhancement import run_chain


class QualityAwareASR:
    """Whisper ASR with brouhaha-based quality routing and optional enhancement."""

    def __init__(self, model_config: Dict, params: Dict, task: str):
        # Whisper model/device
        self.model_name = model_config.get("model_name", "large-v2")
        self.device = model_config.get("device", "cuda" if torch.cuda.is_available() else "cpu")

        # Params
        self._decoder_params = params.get("decoder", {})
        self._provider_params = params.get("provider", {})
        self._qa_params = params.get("quality_aware", {})  # optional section

        # Whisper parameters consolidated
        self.whisper_params: Dict[str, Any] = {}
        self.whisper_params.update(self._decoder_params)
        self.whisper_params.update(self._provider_params)
        self.whisper_params["task"] = task

        # Optional language override for consistency with existing flows
        if "language" not in self.whisper_params:
            self.whisper_params["language"] = "ja"

        # Load quality routing config (JSON) if provided
        self.routing_cfg = self._load_routing_cfg(self._qa_params.get("routing_config_path"))

        # HF token optional; if missing, we degrade gracefully
        self.hf_token = self._qa_params.get("hf_token")

        # Initialize models
        self._init_models()

    def _init_models(self) -> None:
        logger.debug(f"Loading Whisper model: {self.model_name} on {self.device}")
        self.whisper_model = whisper.load_model(self.model_name, device=self.device)

        # Load brouhaha model if possible
        self.brouhaha_inference: Inference | None = None
        try:
            snr_model = Model.from_pretrained("pyannote/brouhaha", use_auth_token=self.hf_token)
            self.brouhaha_inference = Inference(snr_model)
            logger.debug("Loaded brouhaha model for VAD/SNR/C50")
        except Exception as e:
            logger.warning(f"brouhaha model not available; proceeding without quality routing: {e}")

    def _load_routing_cfg(self, path: Union[str, Path, None]) -> Dict[str, Any]:
        if not path:
            # Reasonable defaults aligned with the template
            return {
                "VAD_MIN": 0.5,
                "SNR_CLEAN": 7.5,
                "C50_CLEAN": 0.5,
                "MIN_COVERAGE": 0.4,
                "SCORE": {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0},
                "CHAINS": {
                    "noisy": ["denoise", "norm"],
                    "reverberant": ["dereverb", "denoise", "norm"],
                    "music": ["separate", "denoise", "norm"],
                    "low_energy": ["agc", "norm"],
                    "unknown": ["denoise", "norm"],
                },
            }
        try:
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except Exception as e:
            logger.warning(f"Failed to load routing config {path}: {e}; using defaults")
            return {}

    def _prepare_whisper_params(self) -> Dict[str, Any]:
        params = self.whisper_params.copy()
        # Normalize temperature list to tuple for whisper
        if isinstance(params.get("temperature"), list):
            params["temperature"] = tuple(params["temperature"])  # type: ignore
        params.setdefault("verbose", None)
        return params

    def _brouhaha_analyze(self, waveform: np.ndarray, sample_rate: int) -> List[Tuple[float, Tuple[float, float, float]]]:
        """Run brouhaha inference over the whole segment; return frames as (time, (vad, snr, c50))."""
        if self.brouhaha_inference is None:
            return []
        # pyannote's Inference expects {"waveform": torch.Tensor of shape (channel, time), "sample_rate": int}
        wav_tensor = torch.from_numpy(np.ascontiguousarray(waveform, dtype=np.float32)).unsqueeze(0)
        result = self.brouhaha_inference({"waveform": wav_tensor, "sample_rate": sample_rate})
        # Normalize to list of (time, (vad, snr, c50))
        frames: List[Tuple[float, Tuple[float, float, float]]] = []
        for frame, (vad, snr, c50) in result:
            # Frames may be sliding-window segments; use their midpoint as the timestamp when available
            t = frame.middle if hasattr(frame, "middle") else frame
            frames.append((float(t), (float(vad), float(snr), float(c50))))
        return frames

    def _route_and_enhance(self, waveform: np.ndarray, sample_rate: int) -> Tuple[np.ndarray, int, Dict[str, Any]]:
        """Decide direct vs enhance and apply chain if needed. Returns possibly modified audio and a metadata dict."""
        meta: Dict[str, Any] = {"routed": False, "decision": None}
        frames = self._brouhaha_analyze(waveform, sample_rate)
        if not frames:
            # No brouhaha available; pass-through
            return waveform, sample_rate, meta

        metrics, decision = select_route_for_segment(frames, 0.0, float(len(waveform) / sample_rate), self.routing_cfg)
        meta["routed"] = True
        meta["metrics"] = metrics.__dict__
        meta["decision"] = {"route": decision.route, "label": decision.label, "reasons": decision.reasons}

        if decision.route == "enhance":
            steps = (decision.params or {}).get("chain", [])
            try:
                enhanced_wav, enhanced_sr = run_chain(waveform, sample_rate, steps)
                meta["enhancement_chain"] = steps
                return enhanced_wav, enhanced_sr, meta
            except Exception as e:
                logger.warning(f"Enhancement failed ({steps}); falling back to direct: {e}")
                return waveform, sample_rate, meta
        return waveform, sample_rate, meta

    def transcribe(self, audio_path: Union[str, Path], **kwargs) -> Dict[str, Any]:
        """Transcribe a file with quality-aware routing and optional enhancement."""
        audio_path = Path(audio_path)
        # Load audio to mono float32 (Whisper expects 16 kHz audio when given a raw array;
        # the upstream extraction step is assumed to produce 16 kHz WAV scenes)
        data, sr = sf.read(str(audio_path), dtype="float32")
        if data.ndim > 1:
            data = np.mean(data, axis=1)

        routed_wav, routed_sr, meta = self._route_and_enhance(data, sr)

        params = self._prepare_whisper_params()
        try:
            result = self.whisper_model.transcribe(routed_wav, **params)
        except Exception as e:
            logger.error(f"QualityAwareASR transcribe failed: {e}")
            # Minimal fallback
            result = self.whisper_model.transcribe(routed_wav, task=params.get("task", "transcribe"), language=params.get("language", "ja"), verbose=None)

        # Pack minimal contract: segments, text, language, plus optional meta
        out = {
            "segments": result.get("segments", []),
            "text": result.get("text", ""),
            "language": params.get("language", "ja"),
        }
        if meta.get("routed"):
            out["quality_meta"] = meta
        return out

    def transcribe_to_srt(self, audio_path: Union[str, Path], output_srt_path: Union[str, Path], **kwargs) -> Path:
        audio_path = Path(audio_path)
        output_srt_path = Path(output_srt_path)

        result = self.transcribe(audio_path, **kwargs)
        segments = result.get("segments", []) or []

        srt_subs: List[srt.Subtitle] = []
        for idx, seg in enumerate(segments, 1):
            text = (seg.get("text") or "").strip()
            if not text:
                continue
            start = float(seg.get("start", 0.0))
            end = float(seg.get("end", start))
            srt_subs.append(
                srt.Subtitle(
                    index=idx,
                    start=datetime.timedelta(seconds=start),
                    end=datetime.timedelta(seconds=end),
                    content=text,
                )
            )

        output_srt_path.parent.mkdir(parents=True, exist_ok=True)
        with open(output_srt_path, "w", encoding="utf-8") as f:
            f.write(srt.compose(srt_subs))

        logger.debug(f"Saved SRT to: {output_srt_path}")
        return output_srt_path
@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""Brouhaha pipeline: scene detection + quality-aware ASR."""
from __future__ import annotations

from pathlib import Path
from typing import Dict, List, Tuple
from datetime import datetime
import time
import shutil

from whisperjav.pipelines.base_pipeline import BasePipeline
from whisperjav.modules.audio_extraction import AudioExtractor
from whisperjav.modules.quality_aware_asr import QualityAwareASR
from whisperjav.modules.srt_postprocessing import SRTPostProcessor as StandardPostProcessor
from whisperjav.modules.scene_detection import DynamicSceneDetector
from whisperjav.modules.srt_stitching import SRTStitcher
from whisperjav.utils.logger import logger


class BrouhahaPipeline(BasePipeline):
    """Scene detection with QualityAware ASR (brouhaha-based routing)."""

    def __init__(
        self,
        output_dir: str,
        temp_dir: str,
        keep_temp_files: bool,
        subs_language: str,
        resolved_config: Dict,
        progress_display=None,
        **kwargs,
    ):
        super().__init__(output_dir=output_dir, temp_dir=temp_dir, keep_temp_files=keep_temp_files, **kwargs)

        self.progress = progress_display
        self.subs_language = subs_language

        model_cfg = resolved_config["model"]
        params = resolved_config["params"]
        features = resolved_config["features"]
        task = resolved_config["task"]

        scene_opts = features.get("scene_detection", {})
        post_proc_opts = features.get("post_processing", {})

        self.audio_extractor = AudioExtractor()
        self.scene_detector = DynamicSceneDetector(**scene_opts)
        self.asr = QualityAwareASR(model_config=model_cfg, params=params, task=task)
        self.stitcher = SRTStitcher()

        lang_code = "en" if subs_language == "english-direct" else "ja"
        self.standard_postprocessor = StandardPostProcessor(language=lang_code, **post_proc_opts)

    def process(self, media_info: Dict) -> Dict:
        start_time = time.time()
        input_file = media_info["path"]
        media_basename = media_info["basename"]

        master_metadata = self.metadata_manager.create_master_metadata(
            input_file=input_file, mode=self.get_mode_name(), media_info=media_info
        )

        # Step 1: Extract audio
        if self.progress:
            self.progress.set_current_step("Transforming audio", 1, 5)
        audio_path = self.temp_dir / f"{media_basename}_extracted.wav"
        extracted_audio, duration = self.audio_extractor.extract(input_file, audio_path)
        master_metadata["input_info"]["processed_audio_file"] = str(extracted_audio)
        master_metadata["input_info"]["audio_duration_seconds"] = duration

        # Step 2: Detect scenes
        if self.progress:
            self.progress.set_current_step("Detecting audio scenes", 2, 5)
        scenes_dir = self.temp_dir / "scenes"
        scenes_dir.mkdir(exist_ok=True)
        scene_paths = self.scene_detector.detect_scenes(extracted_audio, scenes_dir, media_basename)

        # Step 3: Transcribe scenes using QualityAware ASR
        if self.progress:
            self.progress.set_current_step("Transcribing scenes (quality-aware)", 3, 5)
        scene_srts_dir = self.temp_dir / "scene_srts"
        scene_srts_dir.mkdir(exist_ok=True)
        scene_srt_info: List[Tuple[Path, float]] = []
        for idx, (scene_path, start_time_sec, _, _) in enumerate(scene_paths):
            scene_srt_path = scene_srts_dir / f"{scene_path.stem}.srt"
            try:
                self.asr.transcribe_to_srt(scene_path, scene_srt_path, task=self.asr.whisper_params.get("task"))
                if scene_srt_path.exists() and scene_srt_path.stat().st_size > 0:
                    scene_srt_info.append((scene_srt_path, start_time_sec))
            except Exception as e:
                logger.error(f"Scene {idx} transcription failed: {e}")

        # Step 4: Stitch
        if self.progress:
            self.progress.set_current_step("Combining scene transcriptions", 4, 5)
        stitched_srt_path = self.temp_dir / f"{media_basename}_stitched.srt"
        num_subtitles = self.stitcher.stitch(scene_srt_info, stitched_srt_path)

        # Step 5: Post-process
        if self.progress:
            self.progress.set_current_step("Post-processing subtitles", 5, 5)
        lang_code = "en" if self.subs_language == "english-direct" else "ja"
        final_srt_path = self.output_dir / f"{media_basename}.{lang_code}.whisperjav.srt"
        processed_srt_path, stats = self.standard_postprocessor.process(stitched_srt_path, final_srt_path)
        if processed_srt_path != final_srt_path:
            shutil.copy2(processed_srt_path, final_srt_path)

        total_time = time.time() - start_time
        master_metadata["summary"]["total_processing_time_seconds"] = round(total_time, 2)
        master_metadata["output_files"]["final_srt"] = str(final_srt_path)
        master_metadata["output_files"]["stitched_srt"] = str(stitched_srt_path)

        return master_metadata

    def get_mode_name(self) -> str:
        return "brouhaha"
@ -0,0 +1,276 @@
import logging

# Configure logging level and format
logging.basicConfig(
    level=logging.INFO,  # options: DEBUG, INFO, WARNING, ERROR, CRITICAL
    format='%(asctime)s [%(levelname)s] %(message)s',
    handlers=[
        logging.StreamHandler()  # prints to console
    ]
)

!pip install -qq pyannote.audio
!pip install openai-whisper
# Using a forked version of pyannote-whisper so that requirements.txt doesn't clash with brouhaha-vad.
# If running on Colab, the session will need restarting after these libraries are installed, due to the numpy downgrade.
!pip install https://github.com/alunkingusw/pyannote-whisper/archive/main.zip
!pip install https://github.com/marianne-m/brouhaha-vad/archive/main.zip


# LIBRARIES

# transcription
import whisper

# audio handling
import torch
import torchaudio

# diarisation
from pyannote.audio import Pipeline
from pyannote_whisper.utils import diarize_text
from pyannote.audio import Model
from pyannote.core import Segment, Annotation

# embedding
from pyannote.audio import Inference
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
import numpy

# other
import datetime
from collections import defaultdict

# file handling
from google.colab import drive, userdata

# Set up files to load
logging.info("Mounting Google Drive. Please follow the instructions to authenticate.")
drive.mount('/content/drive')

# Get Hugging Face token stored in Colab Secrets
HUGGING_FACE = userdata.get('HUGGING_FACE')

# Use the path to your audio file in Google Drive
input_file = '/content/drive/audio_to_process.wav'

# Organise the inputs for the transcription pipeline
NUM_SPEAKERS = None
language = 'English'
model_size = 'medium'
model_name = model_size
# Name according to the available models outlined on https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
if language == 'English' and model_size != 'large':
    model_name += '.en'

# Inputs for the embedding process
SNR_THRESHOLD = 5.0
MIN_SEGMENT_DURATION = 5.0
EMBEDDING_MATCH_THRESHOLD = 0.7

# MODELS
# diarisation by Pyannote
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HUGGING_FACE)
# transcription by Whisper
model = whisper.load_model(model_name)


# Move models to the GPU if available
if torch.cuda.is_available():
    pipeline.to(torch.device("cuda"))
    model.to(torch.device("cuda"))
    logging.info("Pyannote and Whisper models moved to GPU")
else:
    logging.info("GPU not available, Pyannote and Whisper models running on CPU")

# Perform the intensive stuff - transcription
asr_result = model.transcribe(input_file)
diarisation_result = pipeline(input_file, num_speakers=NUM_SPEAKERS)

# We could merge the results here if we didn't want to look at embeddings
# final_result = diarize_text(asr_result, diarisation_result)

# Print out the results or save to file
# for seg, spk, sent in final_result:
#     line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"'
#     print(line)

# ------ This next section deals with the embeddings and comparison used to label speakers ------

# Declare our functions used for various stages of embedding comparison
def crop_waveform(waveform, sample_rate, segment):
    """Return waveform cropped to the specified segment."""
    start_sample = int(segment.start * sample_rate)
    end_sample = int(segment.end * sample_rate)
    return waveform[:, start_sample:end_sample]


def get_reference_embeddings(ref_dict):
    """Return dictionary of normalized reference embeddings."""
    embeddings = {}
    for i, (name, path) in enumerate(ref_dict.items()):
        try:
            emb = embedding_inference(path).reshape(1, -1)
            emb = normalize(emb)
            embeddings[i] = emb
            logging.info(f"Loaded reference embedding for '{name}' as ID {i}")
        except Exception as e:
            logging.warning(f"Error loading reference '{name}': {e}")
    return embeddings


def match_speaker_by_embedding(embedding, speaker_embeddings, speaker_names, threshold=EMBEDDING_MATCH_THRESHOLD):
    """Return the best matching speaker name based on cosine similarity."""
    best_match = None
    highest_similarity = -1

    # Compare with reference embeddings
    for ref_id, ref_embedding in speaker_embeddings.items():
        similarity = cosine_similarity(embedding, ref_embedding)[0][0]
        if similarity > highest_similarity:
            highest_similarity = similarity
            best_match = speaker_names[ref_id]
    logging.info(f"Highest similarity with '{best_match}': {highest_similarity}")
    if highest_similarity >= threshold:
        return best_match
    else:
        return None


def remove_speaker_from_diarisation(diarisation, speaker_to_remove):
    """Return a copy of diarisation with the specified speaker removed."""
    new_diarisation = Annotation(uri=diarisation.uri)

    for segment, track, speaker in diarisation.itertracks(yield_label=True):
        if speaker != speaker_to_remove:
            new_diarisation[segment, track] = speaker

    return new_diarisation


def rename_speaker_in_diarisation(diarisation, old_label, new_label):
    """Return a copy of diarisation with one speaker label renamed."""
    updated_diarisation = Annotation(uri=diarisation.uri)

    for segment, track, speaker in diarisation.itertracks(yield_label=True):
        if speaker == old_label:
            updated_diarisation[segment, track] = new_label
        else:
            updated_diarisation[segment, track] = speaker

    return updated_diarisation

# Load speaker samples and generate the embeddings to be tested below
embedding_model = Model.from_pretrained("pyannote/embedding",
                                        use_auth_token=HUGGING_FACE)
# Move model to GPU if available
if torch.cuda.is_available():
    embedding_model.to(torch.device("cuda"))
    logging.info("Embedding model moved to GPU")
else:
    logging.info("GPU not available, embedding model running on CPU")

embedding_inference = Inference(embedding_model, window="whole")

# Load full audio for cropping our speaker clips
waveform, sample_rate = torchaudio.load(input_file)  # mono only

# Define the dictionary of known speaker names and the sample audio
speaker_samples = {
    "speaker_one": "/content/drive/speaker_one_sample.wav",
    "speaker_two": "/content/drive/speaker_two_sample.wav",
    "speaker_three": "/content/drive/speaker_three_sample.wav",
    "speaker_four": "/content/drive/speaker_four_sample.wav",
    "speaker_five": "/content/drive/speaker_five_sample.wav",
    "speaker_six": "/content/drive/speaker_six_sample.wav",
    "speaker_seven": "/content/drive/speaker_seven_sample.wav",
}
# Turn those files into embeddings
speaker_embeddings = get_reference_embeddings(speaker_samples)

# Use the diarisation from the previous code block (diarisation_result)
# to loop through identified speakers.

# Group segments by speaker so we can run comparisons
segments_by_speaker = defaultdict(list)
for turn, _, speaker in diarisation_result.itertracks(yield_label=True):
    segments_by_speaker[speaker].append(Segment(turn.start, turn.end))

# Load SNR model which will help identify the best audio clip to compare
snr_model = Model.from_pretrained("pyannote/brouhaha", use_auth_token=HUGGING_FACE)

# Move model to GPU if available
if torch.cuda.is_available():
    snr_model.to(torch.device("cuda"))
    logging.info("SNR model moved to GPU")
else:
    logging.info("GPU not available, SNR model running on CPU")

# Apply model
snr_inference = Inference(snr_model)

# Step through each speaker label and its associated segments from diarisation
for speaker, segments in segments_by_speaker.items():
    valid_embeddings = []

    # Process each segment for that speaker
    for segment in segments:
        duration = segment.end - segment.start

        # Ignore short segments if there are other longer ones
        if duration < MIN_SEGMENT_DURATION and len(segments) > 1:
            continue

        # Crop audio to just this segment
        cropped = crop_waveform(waveform, sample_rate, segment)

        # Apply SNR model to filter out low-quality audio
        snr_result = snr_inference({"waveform": cropped, "sample_rate": sample_rate})
        snr_values = [snr for frame, (vad, snr, c50) in snr_result if vad > 0.5]

        if not snr_values:
            continue  # skip if there's no speech

        avg_snr = sum(snr_values) / len(snr_values)

        if avg_snr < SNR_THRESHOLD:  # optional threshold to skip noisy clips
            continue

        # Compute embedding for this valid segment
        embedding = embedding_inference({
            "waveform": cropped,
            "sample_rate": sample_rate
        }).reshape(1, -1)

        embedding = normalize(embedding)
        valid_embeddings.append(embedding)

    # Average the embeddings for this speaker if any were collected
    if valid_embeddings:
        mean_embedding = numpy.mean(numpy.vstack(valid_embeddings), axis=0, keepdims=True)

        match = match_speaker_by_embedding(mean_embedding, speaker_embeddings, list(speaker_samples.keys()))

        if match:
            logging.info(f"Speaker '{speaker}' best matches reference speaker: {match}")
            diarisation_result = rename_speaker_in_diarisation(diarisation_result, speaker, match)
        else:
            logging.info(f"Speaker '{speaker}' could not be confidently matched and will be removed.")
            diarisation_result = remove_speaker_from_diarisation(diarisation_result, speaker)
    else:
        logging.info(f"No valid segments found for speaker '{speaker}', removing from diarisation.")
        diarisation_result = remove_speaker_from_diarisation(diarisation_result, speaker)

# Finally, merge the diarisation results with the Whisper output and export to the terminal.
final_result = diarize_text(asr_result, diarisation_result)

print("Results")
# Print out the results
for seg, spk, sent in final_result:
    line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"'
    print(line)

# Or save to file
output_path = "/content/drive/final_result.txt"
with open(output_path, "w") as f:
    for seg, spk, sent in final_result:
        line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"\n'
        f.write(line)