Compare commits

...

1 commit

Author: meizhong986
SHA1: 971858f13f
Message: feat: initial commit for brouhaha pipeline architecture experiment
Date: 2025-09-03 22:23:24 +01:00

9 changed files with 915 additions and 1 deletion

View File

@ -0,0 +1,98 @@
# Quality-aware ASR pipeline using brouhaha VAD/SNR/C50
This document captures a two-stage pipeline design proposed for WhisperJAV that leverages brouhaha VAD/SNR/C50 to route segments to the optimal processing path.
## Goals
- Maximize ASR accuracy and throughput by routing clean speech directly to ASR
- Improve challenging segments via targeted enhancement prior to ASR
- Remain modular, configurable, and data-driven
## High-level flow
1. Diarize audio and extract segments (existing pyannote pipeline)
2. For each segment:
- Run brouhaha inference to obtain per-frame (vad, snr, c50)
- Aggregate per-segment metrics (e.g., mean/median over vad>threshold frames)
- Classify segment as clean or challenging based on configurable rules
3. Route:
- Clean → Whisper ASR as-is
- Challenging → Enhancement stack → Whisper ASR
4. Merge ASR outputs and preserve provenance/metadata (path taken, metrics); a routing sketch follows this list
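A minimal sketch of this per-segment loop, using the router and enhancement helpers added in this commit. It assumes the brouhaha frames have already been computed for the segment, that the waveform is 16 kHz mono float32, and that a loaded Whisper `model` is passed in; the function name `process_segment` is illustrative only.

```python
from whisperjav.modules.audio_quality_router import select_route_for_segment
from whisperjav.modules.audio_enhancement import run_chain

def process_segment(waveform, sample_rate, frames, cfg, model):
    """Route one segment: clean audio goes straight to Whisper, the rest is enhanced first."""
    # frames: per-frame (time, (vad, snr, c50)) tuples from brouhaha inference
    duration = len(waveform) / sample_rate
    metrics, decision = select_route_for_segment(frames, 0.0, duration, cfg)
    if decision.route == "enhance":
        # apply the configured enhancement chain before ASR
        waveform, sample_rate = run_chain(waveform, sample_rate, decision.params.get("chain", []))
    result = model.transcribe(waveform, language="ja")
    # keep provenance alongside the ASR output
    result["quality_meta"] = {"metrics": metrics.__dict__, "decision": decision.__dict__}
    return result
```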
## Segment classification
- Inputs: list of frames with (vad, snr, c50)
- Steps:
- Keep frames with vad >= VAD_MIN
- Compute metrics: avg_snr, p50_snr, p10_snr, avg_c50, speech_coverage (ratio of speech frames)
- Decision policies (configurable):
- Clean if avg_snr >= SNR_CLEAN && avg_c50 >= C50_CLEAN && speech_coverage >= MIN_COVERAGE
- Challenging otherwise, with sub-classes (noisy, reverberant, low-speech-coverage)
- Optional: score = w_snr * avg_snr + w_c50 * avg_c50; compare against SCORE_THRESH (worked example below)
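For example, with the template weights (w_snr = 0.7, w_c50 = 0.3) and thresh = 6.0, a segment with avg_snr = 7.0 and avg_c50 = 0.4 scores below the threshold and is routed to enhancement even though its SNR is close to SNR_CLEAN (the metric values here are hypothetical):

```python
w_snr, w_c50, thresh = 0.7, 0.3, 6.0                # template values
avg_snr, avg_c50 = 7.0, 0.4                          # hypothetical segment metrics
score = w_snr * avg_snr + w_c50 * avg_c50            # 0.7*7.0 + 0.3*0.4 = 5.02
route = "direct" if score >= thresh else "enhance"   # -> "enhance"
```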
## Enhancement stack (challenging segments)
- Noisy: denoise → loudness norm → (optional) bandwidth EQ
- Reverberant: dereverb → denoise → norm
- Music/background: vocal separation → denoise → norm
- Low-energy: dynamic range compression/AGC → norm
## Libraries/algorithms (Python-first)
- Denoising:
- Spectral gating (e.g., the noisereduce package)
- RNNoise (via rnnoise-py)
- NVIDIA Maxine/RTX Voice (Windows, optional, external)
- SpeechBrain/asteroid (DPRNN/ConvTasNet/FullSubNet2)
- Vocal separation:
- demucs (facebookresearch/demucs)
- spleeter (deezer/spleeter)
- Dereverberation / clarity:
- Weighted prediction error (WPE) via nara_wpe
- Kaggle/pyroomacoustics-based dereverb filters
- DCCRN/DCCRNet models
- Loudness/dynamics:
- pyloudnorm for EBU R128 normalization (see the example step after this list)
- librosa effects (preemphasis), simple compressors (custom)
- Utility:
- pyannote.audio for diarization/embedding
- brouhaha-vad for VAD/SNR/C50
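As one concrete example, a `norm` step could be implemented with pyloudnorm (a sketch only; it assumes mono float32 input long enough for BS.1770 gating):

```python
import numpy as np
import pyloudnorm as pyln

def norm_step(wav: np.ndarray, sr: int, target_lufs: float = -23.0):
    """Loudness-normalize a mono float32 waveform to a target LUFS (EBU R128)."""
    meter = pyln.Meter(sr)                         # BS.1770 meter
    loudness = meter.integrated_loudness(wav)      # measured integrated loudness
    out = pyln.normalize.loudness(wav, loudness, target_lufs)
    return out.astype(np.float32), sr
```

The `(waveform, sample_rate) -> (waveform, sample_rate)` signature matches the enhancement-step convention used by the registry module in this commit, so the function can be registered under the name `norm`.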
## Data structures
- SegmentMetrics
- start: float
- end: float
- avg_snr: float
- avg_c50: float
- coverage: float
- label: Literal["clean", "noisy", "reverberant", "music", "low_energy"]
- SegmentDecision
- route: Literal["direct", "enhance"]
- reasons: list[str]
- params: dict[str, Any] (e.g., chosen enhancement chain); a dataclass sketch follows this list
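A minimal sketch of these structures as dataclasses. Field names follow the bullets above; the actual definitions in `audio_quality_router.py` differ slightly (they keep `label` on the decision and add a `score` field to the metrics):

```python
from dataclasses import dataclass, field
from typing import Any, Literal

@dataclass
class SegmentMetrics:
    start: float
    end: float
    avg_snr: float
    avg_c50: float
    coverage: float
    label: Literal["clean", "noisy", "reverberant", "music", "low_energy"] = "clean"

@dataclass
class SegmentDecision:
    route: Literal["direct", "enhance"]
    reasons: list[str] = field(default_factory=list)
    params: dict[str, Any] = field(default_factory=dict)  # e.g., chosen enhancement chain
```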
## Config
- YAML/JSON with thresholds and weights
- Example:
- VAD_MIN: 0.5
- SNR_CLEAN: 7.5
- C50_CLEAN: 0.5
- MIN_COVERAGE: 0.4
- SCORE: { w_snr: 0.7, w_c50: 0.3, thresh: 6.0 }
- Chains:
- noisy: [denoise, norm]
- reverberant: [dereverb, denoise, norm]
- music: [separate, denoise, norm]
- low_energy: [agc, norm]
## Integration notes
- Keep current diarization/embedding flow unchanged
- Insert a segment routing phase before the embedding/ASR stage, or directly before Whisper
- Enhancement should be stateless, batch-friendly, and GPU-optional
- Cache enhanced waveforms for reproducibility; tag outputs with metadata
## Testing
- Unit tests for routing decisions given synthetic metrics (example test after this list)
- Golden audio samples with known issues to verify improvement
- A/B test: baseline vs two-step routing on WER/CER and speaker attribution
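An example routing test with synthetic metrics, written against the `classify_segment`/`SegmentMetrics` API added in `whisperjav/modules/audio_quality_router.py` (threshold values taken from the config template):

```python
from whisperjav.modules.audio_quality_router import SegmentMetrics, classify_segment

CFG = {
    "SNR_CLEAN": 7.5, "C50_CLEAN": 0.5, "MIN_COVERAGE": 0.4,
    "SCORE": {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0},
    "CHAINS": {"noisy": ["denoise", "norm"]},
}

def test_clean_segment_goes_direct():
    m = SegmentMetrics(start=0.0, end=5.0, avg_snr=12.0, avg_c50=0.9, coverage=0.8, score=8.67)
    assert classify_segment(m, CFG).route == "direct"

def test_noisy_segment_gets_enhancement_chain():
    m = SegmentMetrics(start=0.0, end=5.0, avg_snr=3.0, avg_c50=0.9, coverage=0.8, score=2.37)
    decision = classify_segment(m, CFG)
    assert decision.route == "enhance"
    assert decision.label == "noisy"
    assert decision.params["chain"] == ["denoise", "norm"]
```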
## Future
- Adaptive thresholds learned from data
- Confidence-driven re-segmentation for borderline cases
- On-device lightweight enhancement for real-time use

View File

@ -0,0 +1,26 @@
# Brouhaha Pipeline & QualityAware ASR
This document outlines the architectural additions:
- A new ASR engine `QualityAwareASR` that uses brouhaha (VAD/SNR/C50) to route segments to enhancement or direct ASR.
- A new pipeline `BrouhahaPipeline` mirroring the balanced pipeline but using `QualityAwareASR`.
## Contracts
- `QualityAwareASR.transcribe(path) -> Dict` returns {segments, text, language, [quality_meta]}.
- `QualityAwareASR.transcribe_to_srt(path, out_path) -> Path` writes an SRT file, mirroring WhisperProASR (usage sketch below).
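A usage sketch of this contract (file names are illustrative; `params` follows the resolved-config layout consumed by the constructor):

```python
from whisperjav.modules.quality_aware_asr import QualityAwareASR

asr = QualityAwareASR(
    model_config={"model_name": "large-v2", "device": "cuda"},
    params={
        "decoder": {"language": "ja"},
        "provider": {},
        "quality_aware": {"routing_config_path": "whisperjav/config/quality_routing.template.json"},
    },
    task="transcribe",
)

result = asr.transcribe("scene_0001.wav")
print(result["language"], len(result["segments"]))
print(result.get("quality_meta", {}))  # present only when brouhaha routing actually ran

srt_path = asr.transcribe_to_srt("scene_0001.wav", "scene_0001.srt")
```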
## Config
- Optional `params.quality_aware.routing_config_path` to point to a JSON like `whisperjav/config/quality_routing.template.json`.
- Optional `params.quality_aware.hf_token` for private models if needed.
## Integration
- New pipeline class `BrouhahaPipeline` under `whisperjav/pipelines/brouhaha_pipeline.py`.
- Enabled as a CLI mode by wiring it into `main.py` mode selection (`--mode brouhaha`; see the `main.py` diff in this commit).
## Enhancement registry
- Implement your enhancement steps, then register them by name in `audio_enhancement.registry`.
- The router selects a chain via config and `run_chain` executes it (registration example below).
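A minimal registration example (the `my_denoise` function is a stand-in for whatever denoiser you choose):

```python
import numpy as np
from whisperjav.modules.audio_enhancement import registry, run_chain

def my_denoise(wav: np.ndarray, sr: int):
    """Placeholder denoiser: returns audio unchanged (swap in rnnoise, noisereduce, etc.)."""
    return wav, sr

registry.register("denoise", my_denoise)

# The router selects a chain such as ["denoise", "norm"] from the config;
# run_chain applies each registered step in order and skips unknown names.
wav, sr = run_chain(np.zeros(16000, dtype=np.float32), 16000, ["denoise", "norm"])
```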
## Next steps
- Decide on enhancement implementations: rnnoise/nara_wpe/demucs/pyloudnorm.
- Add unit tests for routing decisions (the `--mode brouhaha` CLI mode is already wired in `main.py`).

whisperjav/config/quality_routing.template.json (View File)

@ -0,0 +1,14 @@
{
"VAD_MIN": 0.5,
"SNR_CLEAN": 7.5,
"C50_CLEAN": 0.5,
"MIN_COVERAGE": 0.4,
"SCORE": { "w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0 },
"CHAINS": {
"noisy": ["denoise", "norm"],
"reverberant": ["dereverb", "denoise", "norm"],
"music": ["separate", "denoise", "norm"],
"low_energy": ["agc", "norm"],
"unknown": ["denoise", "norm"]
}
}

whisperjav/main.py (View File)

@ -37,6 +37,7 @@ from whisperjav.modules.media_discovery import MediaDiscovery
from whisperjav.pipelines.faster_pipeline import FasterPipeline
from whisperjav.pipelines.fast_pipeline import FastPipeline
from whisperjav.pipelines.balanced_pipeline import BalancedPipeline
from whisperjav.pipelines.brouhaha_pipeline import BrouhahaPipeline
from whisperjav.config.transcription_tuner import TranscriptionTuner
from whisperjav.__version__ import __version__
@ -96,7 +97,7 @@ def parse_arguments():
# Core arguments
parser.add_argument("input", nargs="*", help="Input media file(s), directory, or wildcard pattern.")
parser.add_argument("--mode", choices=["balanced", "fast", "faster"], default="balanced",
parser.add_argument("--mode", choices=["balanced", "fast", "faster", "brouhaha"], default="balanced",
help="Processing mode (default: balanced)")
parser.add_argument("--config", default=None, help="Path to a JSON configuration file")
parser.add_argument("--subs-language", choices=["japanese", "english-direct"],
@ -241,6 +242,8 @@ def process_files_sync(media_files: List[Dict], args: argparse.Namespace, resolv
pipeline = FasterPipeline(**pipeline_args)
elif args.mode == "fast":
pipeline = FastPipeline(**pipeline_args)
elif args.mode == "brouhaha":
pipeline = BrouhahaPipeline(**pipeline_args)
else: # balanced
pipeline = BalancedPipeline(**pipeline_args)

whisperjav/modules/audio_enhancement.py (View File)

@ -0,0 +1,53 @@
"""
Audio enhancement orchestrator.
Provides a simple registry and runner for enhancement chains
such as denoise, dereverb, separation, normalization, etc.
This module is intentionally minimal; implement steps with
preferred libraries in your environment.
"""
from __future__ import annotations
from typing import Any, Callable, Dict, List, Tuple
EnhancementFn = Callable[[Any, int], Tuple[Any, int]]
class EnhancementRegistry:
def __init__(self) -> None:
self._fns: Dict[str, EnhancementFn] = {}
def register(self, name: str, fn: EnhancementFn) -> None:
self._fns[name] = fn
def get(self, name: str) -> EnhancementFn | None:
return self._fns.get(name)
registry = EnhancementRegistry()
def run_chain(waveform: Any, sample_rate: int, steps: List[str]) -> Tuple[Any, int]:
"""
Run a sequence of enhancement steps registered in the registry.
Each step is a function: (waveform, sample_rate) -> (waveform, sample_rate)
"""
out_wav, out_sr = waveform, sample_rate
for name in steps:
fn = registry.get(name)
if fn is None:
# skip unknown steps, allowing config-driven experimentation
continue
out_wav, out_sr = fn(out_wav, out_sr)
return out_wav, out_sr
# Example placeholder steps (no-ops) — replace with real implementations
def _noop(wav: Any, sr: int) -> Tuple[Any, int]:
return wav, sr
# Pre-register common names to avoid KeyErrors during experimentation
for _name in ["denoise", "dereverb", "separate", "norm", "agc"]:
registry.register(_name, _noop)

whisperjav/modules/audio_quality_router.py (View File)

@ -0,0 +1,134 @@
"""
Quality-based segment routing using brouhaha VAD/SNR/C50.
This module computes per-segment quality metrics from brouhaha
inference outputs and classifies segments as 'direct' (ASR as-is)
or 'enhance' (run through enhancement chain before ASR).
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Any, Dict, List, Tuple
@dataclass
class SegmentMetrics:
start: float
end: float
avg_snr: float
avg_c50: float
coverage: float # ratio of frames with vad >= vad_min
score: float
@dataclass
class SegmentDecision:
route: str # 'direct' | 'enhance'
label: str # 'clean' | 'noisy' | 'reverberant' | 'music' | 'low_energy' | 'unknown'
reasons: List[str]
params: Dict[str, Any]
def aggregate_metrics(
frames: List[Tuple[float, Tuple[float, float, float]]],
vad_min: float,
start: float,
end: float,
score_weights: Tuple[float, float] = (0.7, 0.3),
) -> SegmentMetrics:
"""
Aggregate per-frame (time, (vad, snr, c50)) into segment metrics.
frames: list of (time, (vad, snr, c50)) from brouhaha Inference.
vad_min: threshold to consider a frame as speech.
start, end: segment boundaries.
score_weights: (w_snr, w_c50) for composite score.
"""
if not frames:
return SegmentMetrics(start, end, avg_snr=0.0, avg_c50=0.0, coverage=0.0, score=0.0)
# Import numpy lazily to avoid hard dependency at import time
import numpy as np # type: ignore
vad = np.array([v for _, (v, _, _) in frames], dtype=np.float32)
snr = np.array([s for _, (_, s, _) in frames], dtype=np.float32)
c50 = np.array([c for _, (_, _, c) in frames], dtype=np.float32)
speech_mask = vad >= vad_min
coverage = float(np.mean(speech_mask)) if vad.size else 0.0
if speech_mask.any():
avg_snr = float(np.mean(snr[speech_mask]))
avg_c50 = float(np.mean(c50[speech_mask]))
else:
avg_snr = 0.0
avg_c50 = 0.0
w_snr, w_c50 = score_weights
score = w_snr * avg_snr + w_c50 * avg_c50
return SegmentMetrics(start, end, avg_snr=avg_snr, avg_c50=avg_c50, coverage=coverage, score=score)
def classify_segment(
m: SegmentMetrics,
cfg: Dict[str, Any],
) -> SegmentDecision:
"""
Classify a segment as 'direct' or 'enhance' with an issue label.
cfg expects keys:
- VAD_MIN, SNR_CLEAN, C50_CLEAN, MIN_COVERAGE
- SCORE: { w_snr, w_c50, thresh }
- CHAINS: mapping from label -> list of step names
"""
reasons: List[str] = []
label = "clean"
snr_clean = float(cfg.get("SNR_CLEAN", 7.5))
c50_clean = float(cfg.get("C50_CLEAN", 0.5))
min_cov = float(cfg.get("MIN_COVERAGE", 0.4))
score_cfg = cfg.get("SCORE", {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0})
score_thresh = float(score_cfg.get("thresh", 6.0))
if m.coverage < min_cov:
label = "low_energy"
reasons.append(f"coverage {m.coverage:.2f} < {min_cov}")
if m.avg_snr < snr_clean:
label = "noisy"
reasons.append(f"avg_snr {m.avg_snr:.2f} < {snr_clean}")
if m.avg_c50 < c50_clean:
# prefer 'reverberant' if not already more critical
if label == "clean":
label = "reverberant"
reasons.append(f"avg_c50 {m.avg_c50:.2f} < {c50_clean}")
# Composite score gate
if m.score < score_thresh and label == "clean":
label = "unknown"
reasons.append(f"score {m.score:.2f} < {score_thresh}")
if label == "clean":
return SegmentDecision(route="direct", label=label, reasons=["meets clean criteria"], params={})
chain = (cfg.get("CHAINS", {}) or {}).get(label) or []
return SegmentDecision(route="enhance", label=label, reasons=reasons, params={"chain": chain})
def select_route_for_segment(
frames: List[Tuple[float, Tuple[float, float, float]]],
start: float,
end: float,
cfg: Dict[str, Any],
) -> Tuple[SegmentMetrics, SegmentDecision]:
"""
Convenience wrapper to aggregate metrics then classify.
"""
vad_min = float(cfg.get("VAD_MIN", 0.5))
score_cfg = cfg.get("SCORE", {"w_snr": 0.7, "w_c50": 0.3})
weights = (float(score_cfg.get("w_snr", 0.7)), float(score_cfg.get("w_c50", 0.3)))
metrics = aggregate_metrics(frames, vad_min=vad_min, start=start, end=end, score_weights=weights)
decision = classify_segment(metrics, cfg)
return metrics, decision

whisperjav/modules/quality_aware_asr.py (View File)

@ -0,0 +1,199 @@
#!/usr/bin/env python3
"""
QualityAware ASR engine leveraging brouhaha VAD/SNR/C50 for quality-aware routing
and optional enhancement before Whisper transcription. Designed to mirror
WhisperProASR's high-level contract (transcribe, transcribe_to_srt).
"""
from __future__ import annotations
from pathlib import Path
from typing import Any, Dict, List, Tuple, Union
import logging
import json
import numpy as np
import soundfile as sf
import torch
import whisper
import srt
import datetime
from pyannote.audio import Model, Inference
from whisperjav.utils.logger import logger
from whisperjav.modules.audio_quality_router import select_route_for_segment
from whisperjav.modules.audio_enhancement import run_chain
class QualityAwareASR:
"""Whisper ASR with brouhaha-based quality routing and optional enhancement."""
def __init__(self, model_config: Dict, params: Dict, task: str):
# Whisper model/device
self.model_name = model_config.get("model_name", "large-v2")
self.device = model_config.get("device", "cuda" if torch.cuda.is_available() else "cpu")
# Params
self._decoder_params = params.get("decoder", {})
self._provider_params = params.get("provider", {})
self._qa_params = params.get("quality_aware", {}) # optional section
# Whisper parameters consolidated
self.whisper_params: Dict[str, Any] = {}
self.whisper_params.update(self._decoder_params)
self.whisper_params.update(self._provider_params)
self.whisper_params["task"] = task
# Optional language override for consistency with existing flows
if "language" not in self.whisper_params:
self.whisper_params["language"] = "ja"
# Load quality routing config (JSON) if provided
self.routing_cfg = self._load_routing_cfg(self._qa_params.get("routing_config_path"))
# HF token optional if missing, we degrade gracefully
self.hf_token = self._qa_params.get("hf_token")
# Initialize models
self._init_models()
def _init_models(self) -> None:
logger.debug(f"Loading Whisper model: {self.model_name} on {self.device}")
self.whisper_model = whisper.load_model(self.model_name, device=self.device)
# Load brouhaha model if possible
self.brouhaha_inference: Inference | None = None
try:
snr_model = Model.from_pretrained("pyannote/brouhaha", use_auth_token=self.hf_token)
self.brouhaha_inference = Inference(snr_model)
logger.debug("Loaded brouhaha model for VAD/SNR/C50")
except Exception as e:
logger.warning(f"brouhaha model not available; proceeding without quality routing: {e}")
def _load_routing_cfg(self, path: Union[str, Path, None]) -> Dict[str, Any]:
if not path:
# Reasonable defaults aligned with template
return {
"VAD_MIN": 0.5,
"SNR_CLEAN": 7.5,
"C50_CLEAN": 0.5,
"MIN_COVERAGE": 0.4,
"SCORE": {"w_snr": 0.7, "w_c50": 0.3, "thresh": 6.0},
"CHAINS": {
"noisy": ["denoise", "norm"],
"reverberant": ["dereverb", "denoise", "norm"],
"music": ["separate", "denoise", "norm"],
"low_energy": ["agc", "norm"],
"unknown": ["denoise", "norm"],
},
}
try:
with open(path, "r", encoding="utf-8") as f:
return json.load(f)
except Exception as e:
logger.warning(f"Failed to load routing config {path}: {e}; using defaults")
return {}
def _prepare_whisper_params(self) -> Dict[str, Any]:
params = self.whisper_params.copy()
# Normalize temperature list to tuple for whisper
if isinstance(params.get("temperature"), list):
params["temperature"] = tuple(params["temperature"]) # type: ignore
params.setdefault("verbose", None)
return params
    def _brouhaha_analyze(self, waveform: np.ndarray, sample_rate: int) -> List[Tuple[float, Tuple[float, float, float]]]:
        """Run brouhaha inference over the whole segment; return frames as (time, (vad, snr, c50))."""
        if self.brouhaha_inference is None:
            return []
        # pyannote Inference expects a dict with a (channel, time) torch tensor and the sample rate
        tensor = torch.from_numpy(np.ascontiguousarray(waveform[np.newaxis, :]))
        result = self.brouhaha_inference({"waveform": tensor, "sample_rate": sample_rate})
        # The output iterates as (frame_window, (vad, snr, c50)); use each frame's midpoint as its timestamp
        frames: List[Tuple[float, Tuple[float, float, float]]] = []
        for window, (vad, snr, c50) in result:
            frames.append((float(window.middle), (float(vad), float(snr), float(c50))))
        return frames
def _route_and_enhance(self, waveform: np.ndarray, sample_rate: int) -> Tuple[np.ndarray, int, Dict[str, Any]]:
"""Decide direct vs enhance and apply chain if needed. Returns possibly modified audio and a metadata dict."""
meta: Dict[str, Any] = {"routed": False, "decision": None}
frames = self._brouhaha_analyze(waveform, sample_rate)
if not frames:
# No brouhaha available; pass-through
return waveform, sample_rate, meta
metrics, decision = select_route_for_segment(frames, 0.0, float(len(waveform) / sample_rate), self.routing_cfg)
meta["routed"] = True
meta["metrics"] = metrics.__dict__
meta["decision"] = {"route": decision.route, "label": decision.label, "reasons": decision.reasons}
if decision.route == "enhance":
steps = (decision.params or {}).get("chain", [])
try:
enhanced_wav, enhanced_sr = run_chain(waveform, sample_rate, steps)
meta["enhancement_chain"] = steps
return enhanced_wav, enhanced_sr, meta
except Exception as e:
logger.warning(f"Enhancement failed ({steps}); falling back to direct: {e}")
return waveform, sample_rate, meta
return waveform, sample_rate, meta
def transcribe(self, audio_path: Union[str, Path], **kwargs) -> Dict[str, Any]:
"""Transcribe a file with quality-aware routing and optional enhancement."""
audio_path = Path(audio_path)
# Load audio to mono float32
data, sr = sf.read(str(audio_path), dtype="float32")
if data.ndim > 1:
data = np.mean(data, axis=1)
        routed_wav, routed_sr, meta = self._route_and_enhance(data, sr)
        # Whisper expects 16 kHz mono float32 when given a raw array; resample if needed
        if routed_sr != 16000:
            import torchaudio  # imported lazily; already pulled in by pyannote/brouhaha
            routed_wav = torchaudio.functional.resample(
                torch.from_numpy(np.ascontiguousarray(routed_wav, dtype=np.float32)), routed_sr, 16000
            ).numpy()
            routed_sr = 16000
        params = self._prepare_whisper_params()
try:
result = self.whisper_model.transcribe(routed_wav, **params)
except Exception as e:
logger.error(f"QualityAwareASR transcribe failed: {e}")
# Minimal fallback
result = self.whisper_model.transcribe(routed_wav, task=params.get("task", "transcribe"), language=params.get("language", "ja"), verbose=None)
# Pack minimal contract: segments, text, language, plus optional meta
out = {
"segments": result.get("segments", []),
"text": result.get("text", ""),
"language": params.get("language", "ja"),
}
if meta.get("routed"):
out["quality_meta"] = meta
return out
def transcribe_to_srt(self, audio_path: Union[str, Path], output_srt_path: Union[str, Path], **kwargs) -> Path:
audio_path = Path(audio_path)
output_srt_path = Path(output_srt_path)
result = self.transcribe(audio_path, **kwargs)
segments = result.get("segments", []) or []
srt_subs: List[srt.Subtitle] = []
for idx, seg in enumerate(segments, 1):
text = (seg.get("text") or "").strip()
if not text:
continue
start = float(seg.get("start", 0.0))
end = float(seg.get("end", start))
srt_subs.append(
srt.Subtitle(
index=idx,
start=datetime.timedelta(seconds=start),
end=datetime.timedelta(seconds=end),
content=text,
)
)
output_srt_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_srt_path, "w", encoding="utf-8") as f:
f.write(srt.compose(srt_subs))
logger.debug(f"Saved SRT to: {output_srt_path}")
return output_srt_path

whisperjav/pipelines/brouhaha_pipeline.py (View File)

@ -0,0 +1,111 @@
#!/usr/bin/env python3
"""Brouhaha pipeline: scene detection + quality-aware ASR."""
from __future__ import annotations
from pathlib import Path
from typing import Dict, List, Tuple
from datetime import datetime
import time
import shutil
from whisperjav.pipelines.base_pipeline import BasePipeline
from whisperjav.modules.audio_extraction import AudioExtractor
from whisperjav.modules.quality_aware_asr import QualityAwareASR
from whisperjav.modules.srt_postprocessing import SRTPostProcessor as StandardPostProcessor
from whisperjav.modules.scene_detection import DynamicSceneDetector
from whisperjav.modules.srt_stitching import SRTStitcher
from whisperjav.utils.logger import logger
class BrouhahaPipeline(BasePipeline):
"""Scene detection with QualityAware ASR (brouhaha-based routing)."""
def __init__(
self,
output_dir: str,
temp_dir: str,
keep_temp_files: bool,
subs_language: str,
resolved_config: Dict,
progress_display=None,
**kwargs,
):
super().__init__(output_dir=output_dir, temp_dir=temp_dir, keep_temp_files=keep_temp_files, **kwargs)
self.progress = progress_display
self.subs_language = subs_language
model_cfg = resolved_config["model"]
params = resolved_config["params"]
features = resolved_config["features"]
task = resolved_config["task"]
scene_opts = features.get("scene_detection", {})
post_proc_opts = features.get("post_processing", {})
self.audio_extractor = AudioExtractor()
self.scene_detector = DynamicSceneDetector(**scene_opts)
self.asr = QualityAwareASR(model_config=model_cfg, params=params, task=task)
self.stitcher = SRTStitcher()
lang_code = "en" if subs_language == "english-direct" else "ja"
self.standard_postprocessor = StandardPostProcessor(language=lang_code, **post_proc_opts)
def process(self, media_info: Dict) -> Dict:
start_time = time.time()
input_file = media_info["path"]
media_basename = media_info["basename"]
master_metadata = self.metadata_manager.create_master_metadata(
input_file=input_file, mode=self.get_mode_name(), media_info=media_info
)
# Step 1: Extract audio
self.progress.set_current_step("Transforming audio", 1, 5) if self.progress else None
audio_path = self.temp_dir / f"{media_basename}_extracted.wav"
extracted_audio, duration = self.audio_extractor.extract(input_file, audio_path)
master_metadata["input_info"]["processed_audio_file"] = str(extracted_audio)
master_metadata["input_info"]["audio_duration_seconds"] = duration
# Step 2: Detect scenes
self.progress.set_current_step("Detecting audio scenes", 2, 5) if self.progress else None
scenes_dir = self.temp_dir / "scenes"
scenes_dir.mkdir(exist_ok=True)
scene_paths = self.scene_detector.detect_scenes(extracted_audio, scenes_dir, media_basename)
# Step 3: Transcribe scenes using QualityAware ASR
self.progress.set_current_step("Transcribing scenes (quality-aware)", 3, 5) if self.progress else None
scene_srts_dir = self.temp_dir / "scene_srts"
scene_srts_dir.mkdir(exist_ok=True)
scene_srt_info: List[Tuple[Path, float]] = []
for idx, (scene_path, start_time_sec, _, _) in enumerate(scene_paths):
scene_srt_path = scene_srts_dir / f"{scene_path.stem}.srt"
try:
self.asr.transcribe_to_srt(scene_path, scene_srt_path, task=self.asr.whisper_params.get("task"))
if scene_srt_path.exists() and scene_srt_path.stat().st_size > 0:
scene_srt_info.append((scene_srt_path, start_time_sec))
except Exception as e:
logger.error(f"Scene {idx} transcription failed: {e}")
# Step 4: Stitch
self.progress.set_current_step("Combining scene transcriptions", 4, 5) if self.progress else None
stitched_srt_path = self.temp_dir / f"{media_basename}_stitched.srt"
num_subtitles = self.stitcher.stitch(scene_srt_info, stitched_srt_path)
# Step 5: Post-process
self.progress.set_current_step("Post-processing subtitles", 5, 5) if self.progress else None
lang_code = "en" if self.subs_language == "english-direct" else "ja"
final_srt_path = self.output_dir / f"{media_basename}.{lang_code}.whisperjav.srt"
processed_srt_path, stats = self.standard_postprocessor.process(stitched_srt_path, final_srt_path)
if processed_srt_path != final_srt_path:
shutil.copy2(processed_srt_path, final_srt_path)
total_time = time.time() - start_time
master_metadata["summary"]["total_processing_time_seconds"] = round(total_time, 2)
master_metadata["output_files"]["final_srt"] = str(final_srt_path)
master_metadata["output_files"]["stitched_srt"] = str(stitched_srt_path)
return master_metadata
def get_mode_name(self) -> str:
return "brouhaha"

View File

@ -0,0 +1,276 @@
import logging
# Configure logging level and format
logging.basicConfig(
level=logging.INFO, # options: DEBUG, INFO, WARNING, ERROR, CRITICAL
format='%(asctime)s [%(levelname)s] %(message)s',
handlers=[
logging.StreamHandler() # prints to console
]
)
!pip install -qq pyannote.audio
!pip install openai-whisper
#using forked version of pyannote-whisper so that requirements.txt doesn't clash with brouhaha-vad
#if running on colab, the session will need restarting after these libraries have been installed due to numpy downgrade
!pip install https://github.com/alunkingusw/pyannote-whisper/archive/main.zip
!pip install https://github.com/marianne-m/brouhaha-vad/archive/main.zip
#LIBRARIES
#transcription
import whisper
#audio handling
import torch
import torchaudio
#diarisation
from pyannote.audio import Pipeline
from pyannote_whisper.utils import diarize_text
from pyannote.audio import Model
from pyannote.core import Segment, Annotation
#embedding
from pyannote.audio import Inference
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.preprocessing import normalize
import numpy
#other
import datetime
from collections import defaultdict
#file handling
from google.colab import drive, userdata
#set up files to load
logging.info("Mounting Google Drive. Please follow the instructions to authenticate.")
drive.mount('/content/drive')
# Get Hugging Face token stored in Colab Secrets
HUGGING_FACE = userdata.get('HUGGING_FACE')
# Use the path to your audio file in Google Drive
input_file = '/content/drive/audio_to_process.wav'
#organise the inputs for the transcription pipeline
NUM_SPEAKERS = None
language = 'English'
model_size = 'medium'
model_name = model_size
#name according to available models outlined on https://github.com/openai/whisper?tab=readme-ov-file#available-models-and-languages
if language == 'English' and model_size != 'large':
model_name += '.en'
#inputs for the embedding process
SNR_THRESHOLD = 5.0
MIN_SEGMENT_DURATION = 5.0
EMBEDDING_MATCH_THRESHOLD = 0.7
#MODELS
#diarisation by Pyannote
pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=HUGGING_FACE)
#transcription by Whisper
model = whisper.load_model(model_name)
# Move models to the GPU if available
if torch.cuda.is_available():
pipeline.to(torch.device("cuda"))
model.to(torch.device("cuda"))
logging.info("Pyannote and Whisper models moved to GPU")
else:
logging.info("GPU not available, Pyannote and Whisper models running on CPU")
# Perform the intensive stuff - transcription
asr_result = model.transcribe(input_file)
diarisation_result = pipeline(input_file, num_speakers=NUM_SPEAKERS)
# We could merge the results here if we didn't want to look at embeddings
# final_result = diarize_text(asr_result, diarisation_result)
# print out the results or save to file
# for seg, spk, sent in final_result:
#line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"'
#print(line)
# ------ This next section deals with the embeddings and comparison used to label speakers ------
# Declare our functions used for various stages of embedding comparison
def crop_waveform(waveform, sample_rate, segment):
"""Return waveform cropped to the specified segment."""
start_sample = int(segment.start * sample_rate)
end_sample = int(segment.end * sample_rate)
return waveform[:, start_sample:end_sample]
def get_reference_embeddings(ref_dict):
"""Return dictionary of normalized reference embeddings."""
embeddings = {}
for i, (name, path) in enumerate(ref_dict.items()):
try:
emb = embedding_inference(path).reshape(1, -1)
emb = normalize(emb)
embeddings[i] = emb
logging.info(f"Loaded reference embedding for '{name}' as ID {i}")
except Exception as e:
logging.warning(f"Error loading reference '{name}': {e}")
return embeddings
def match_speaker_by_embedding(embedding, speaker_embeddings, speaker_names, threshold=EMBEDDING_MATCH_THRESHOLD):
"""Return the best matching speaker name based on cosine similarity."""
best_match = None
highest_similarity = -1
# Compare with reference embeddings
for ref_id, ref_embedding in speaker_embeddings.items():
similarity = cosine_similarity(embedding, ref_embedding)[0][0]
if similarity > highest_similarity:
highest_similarity = similarity
best_match = speaker_names[ref_id]
logging.info(f"Highest similarity with '{best_match}': {highest_similarity}")
if highest_similarity >= threshold:
return best_match
else:
return None
def remove_speaker_from_diarisation(diarisation, speaker_to_remove):
"""Return a copy of diarisation with the specified speaker removed."""
new_diarisation = Annotation(uri=diarisation.uri)
for segment, track, speaker in diarisation.itertracks(yield_label=True):
if speaker != speaker_to_remove:
new_diarisation[segment, track] = speaker
return new_diarisation
def rename_speaker_in_diarisation(diarisation, old_label, new_label):
"""Return a copy of diarisation with one speaker label renamed."""
updated_diarisation = Annotation(uri=diarisation.uri)
for segment, track, speaker in diarisation.itertracks(yield_label=True):
if speaker == old_label:
updated_diarisation[segment, track] = new_label
else:
updated_diarisation[segment, track] = speaker
return updated_diarisation
#load speaker samples and generate the embeddings to be tested below
embedding_model = Model.from_pretrained("pyannote/embedding",
use_auth_token=HUGGING_FACE)
# Move model to GPU if available
if torch.cuda.is_available():
embedding_model.to(torch.device("cuda"))
logging.info("Embedding model moved to GPU")
else:
logging.info("GPU not available, embedding model running on CPU")
embedding_inference = Inference(embedding_model, window="whole")
# Load full audio for cropping our speaker clips
waveform, sample_rate = torchaudio.load(input_file) # mono only
# Define the dictionary of known speaker names and the sample audio
speaker_samples = {
"speaker_one": "/content/drive/speaker_one_sample.wav",
"speaker_two": "/content/drive/speaker_two_sample.wav",
"speaker_three": "/content/drive/speaker_three_sample.wav",
"speaker_four": "/content/drive/speaker_four_sample.wav",
"speaker_five": "/content/drive/speaker_five_sample.wav",
"speaker_six": "/content/drive/speaker_six_sample.wav",
"speaker_seven": "/content/drive/speaker_seven_sample.wav",
}
#turn those files into embeddings
speaker_embeddings = get_reference_embeddings(speaker_samples)
# Use diarisation from previous code block to loop through identified speakers
# diarisation_result
# Group segments by speaker so we can run comparisons
segments_by_speaker = defaultdict(list)
for turn, _, speaker in diarisation_result.itertracks(yield_label=True):
segments_by_speaker[speaker].append(Segment(turn.start, turn.end))
# Load SNR model which will help identify the best audio clip to compare
snr_model = Model.from_pretrained("pyannote/brouhaha", use_auth_token=HUGGING_FACE)
# Move model to GPU if available
if torch.cuda.is_available():
snr_model.to(torch.device("cuda"))
logging.info("SNR model moved to GPU")
else:
logging.info("GPU not available, SNR model running on CPU")
# apply model
snr_inference = Inference(snr_model)
# Step through each speaker label and its associated segments from diarisation
for speaker, segments in segments_by_speaker.items():
valid_embeddings = []
# Process each segment for that speaker
for segment in segments:
duration = segment.end - segment.start
# Ignore short segments if there are other longer ones
if duration < MIN_SEGMENT_DURATION and len(segments) > 1:
continue
# Crop audio to just this segment
cropped = crop_waveform(waveform, sample_rate, segment)
# Apply SNR model to filter out low-quality audio
snr_result = snr_inference({"waveform": cropped, "sample_rate": sample_rate})
snr_values = [snr for frame, (vad, snr, c50) in snr_result if vad > 0.5]
if not snr_values:
continue # skip if there's no speech
avg_snr = sum(snr_values) / len(snr_values)
if avg_snr < SNR_THRESHOLD: # optional threshold to skip noisy clips
continue
# Compute embedding for this valid segment
embedding = embedding_inference({
"waveform": cropped,
"sample_rate": sample_rate
}).reshape(1, -1)
embedding = normalize(embedding)
valid_embeddings.append(embedding)
# Average the embeddings for this speaker if any were collected
if valid_embeddings:
mean_embedding = numpy.mean(numpy.vstack(valid_embeddings), axis=0, keepdims=True)
match = match_speaker_by_embedding(mean_embedding, speaker_embeddings, list(speaker_samples.keys()))
if match:
logging.info(f"Speaker '{speaker}' best matches reference speaker: {match}")
diarisation_result = rename_speaker_in_diarisation(diarisation_result, speaker, match)
else:
logging.info(f"Speaker '{speaker}' could not be confidently matched and will be removed.")
diarisation_result = remove_speaker_from_diarisation(diarisation_result, speaker)
else:
logging.info(f"No valid segments found for speaker '{speaker}', removing from diarisation.")
diarisation_result = remove_speaker_from_diarisation(diarisation_result, speaker)
#finally, merge the diarisation results with the whisper output and export to terminal.
final_result = diarize_text(asr_result, diarisation_result)
print(f"Results")
# print out the results
for seg, spk, sent in final_result:
line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"'
print(line)
# or could save to file
output_path = "/content/drive/final_result.txt"
with open(output_path, "w") as f:
for seg, spk, sent in final_result:
line = f'{seg.start:.2f} {seg.end:.2f} {spk} "{sent}"\n'
f.write(line)