Refactor tool-speechtotext: extract sttlib shared library and add tests
Extract duplicated code (Whisper loading, audio recording, transcription, VAD processing) into reusable sttlib/ package. Rewrite all 3 scripts as thin wrappers. Add 24 unit tests with mocked hardware. Fix GPU fallback bug in assistant.py and args.system assignment bug.
This commit is contained in:
@@ -1,91 +1,14 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import webrtcvad
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import threading
|
||||
import queue
|
||||
import collections
|
||||
import time
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
|
||||
# --- Constants ---
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
FRAME_DURATION_MS = 30
|
||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||
|
||||
HALLUCINATION_PATTERNS = [
|
||||
"thank you", "thanks for watching", "subscribe",
|
||||
"bye", "the end", "thank you for watching",
|
||||
"please subscribe", "like and subscribe",
|
||||
]
|
||||
|
||||
# --- Thread-safe audio queue ---
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
|
||||
def audio_callback(indata, frames, time_info, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
audio_queue.put(bytes(indata))
|
||||
|
||||
|
||||
# --- Whisper model loading (reused pattern from assistant.py) ---
|
||||
def load_whisper_model(model_size):
|
||||
print(f"Loading Whisper model ({model_size})...")
|
||||
try:
|
||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"GPU loading failed: {e}")
|
||||
print("Falling back to CPU (int8)")
|
||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||
|
||||
|
||||
# --- VAD State Machine ---
|
||||
class VADProcessor:
|
||||
def __init__(self, aggressiveness, silence_threshold):
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.silence_threshold = silence_threshold
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.triggered = False
|
||||
self.utterance_frames = []
|
||||
self.silence_duration = 0.0
|
||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||
|
||||
def process_frame(self, frame_bytes):
|
||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||
|
||||
if not self.triggered:
|
||||
self.pre_buffer.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.triggered = True
|
||||
self.silence_duration = 0.0
|
||||
self.utterance_frames = list(self.pre_buffer)
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
pass # silent until transcription confirms speech
|
||||
else:
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.silence_duration = 0.0
|
||||
else:
|
||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||
if self.silence_duration >= self.silence_threshold:
|
||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||
self.reset()
|
||||
return None
|
||||
result = b"".join(self.utterance_frames)
|
||||
self.reset()
|
||||
return result
|
||||
return None
|
||||
import sounddevice as sd
|
||||
from sttlib import (
|
||||
load_whisper_model, transcribe, is_hallucination, pcm_bytes_to_float32,
|
||||
VADProcessor, audio_callback, audio_queue,
|
||||
SAMPLE_RATE, CHANNELS, FRAME_SIZE,
|
||||
)
|
||||
|
||||
|
||||
# --- Typer Interface (xdotool) ---
|
||||
@@ -99,6 +22,7 @@ class Typer:
|
||||
except FileNotFoundError:
|
||||
print("ERROR: xdotool not found. Install it:")
|
||||
print(" sudo apt-get install xdotool")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
def type_text(self, text, submit_now=False):
|
||||
@@ -120,24 +44,6 @@ class Typer:
|
||||
pass
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
def pcm_bytes_to_float32(pcm_bytes):
|
||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||
return audio_int16.astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def transcribe(model, audio_float32):
|
||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||
return "".join(segment.text for segment in segments).strip()
|
||||
|
||||
|
||||
def is_hallucination(text):
|
||||
lowered = text.lower().strip()
|
||||
if len(lowered) < 3:
|
||||
return True
|
||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||
|
||||
|
||||
# --- CLI ---
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
Reference in New Issue
Block a user