Refactor tool-speechtotext: extract sttlib shared library and add tests
Extract duplicated code (Whisper loading, audio recording, transcription, VAD processing) into reusable sttlib/ package. Rewrite all 3 scripts as thin wrappers. Add 24 unit tests with mocked hardware. Fix GPU fallback bug in assistant.py and args.system assignment bug.
This commit is contained in:
@@ -1,5 +1,11 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
"python-envs.pythonProjects": [],
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestArgs": [
|
||||
"tests",
|
||||
"-v"
|
||||
]
|
||||
}
|
||||
@@ -12,11 +12,20 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
||||
## Tools
|
||||
- `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama
|
||||
- `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling
|
||||
- `voice_to_xdotool.py` / `dotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
||||
- `voice_to_xdotool.py` / `xdotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
||||
|
||||
## Shared Library
|
||||
- `sttlib/` — shared package used by all scripts and importable by other projects
|
||||
- `whisper_loader.py` — model loading with GPU→CPU fallback
|
||||
- `audio.py` — press-enter recording, PCM conversion
|
||||
- `transcription.py` — Whisper transcribe wrapper, hallucination filter
|
||||
- `vad.py` — VADProcessor, audio callback, constants
|
||||
- Other projects import via: `sys.path.insert(0, "/path/to/tool-speechtotext")`
|
||||
|
||||
## Testing
|
||||
- To test scripts: `mamba run -n whisper-ollama python <script.py> --model-size base`
|
||||
- Run tests: `mamba run -n whisper-ollama python -m pytest tests/`
|
||||
- Use `--model-size base` for faster iteration during development
|
||||
- Tests mock hardware (Whisper model, VAD, mic) — no GPU/mic needed to run them
|
||||
- Audio device is available — live mic testing is possible
|
||||
- Test xdotool output by focusing a text editor window
|
||||
|
||||
@@ -24,12 +33,13 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
||||
- Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama
|
||||
- Pip (in conda env): webrtcvad
|
||||
- System: libportaudio2, xdotool
|
||||
- Dev: pytest
|
||||
|
||||
## Conventions
|
||||
- Shell wrappers go in .sh files using `mamba run -n whisper-ollama`
|
||||
- All scripts set `CT2_CUDA_ALLOW_FP16=1`
|
||||
- Shared code lives in `sttlib/` — scripts are thin entry points that import from it
|
||||
- Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback
|
||||
- Keep scripts self-contained (no shared module)
|
||||
- `CT2_CUDA_ALLOW_FP16=1` is set by `sttlib.whisper_loader` at import time
|
||||
- Don't print output for non-actionable events
|
||||
|
||||
## Preferences
|
||||
|
||||
@@ -1,57 +1,17 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import argparse
|
||||
import pyperclip
|
||||
import requests
|
||||
import sys
|
||||
import argparse
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import os
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_SIZE = "medium" # Options: "base", "small", "medium", "large-v3"
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate" # Default is 11434
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||
DEFAULT_OLLAMA_MODEL = "qwen3:latest"
|
||||
|
||||
# Load Whisper on GPU
|
||||
# float16 is faster and uses less VRAM on NVIDIA cards
|
||||
print("Loading Whisper model...")
|
||||
|
||||
|
||||
try:
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"Error loading GPU: {e}")
|
||||
print("Falling back to CPU (Check your CUDA/cuDNN installation)")
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="int16")
|
||||
|
||||
|
||||
def record_audio():
|
||||
fs = 16000
|
||||
print("\n[READY] Press Enter to START recording...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
|
||||
recording = []
|
||||
|
||||
def callback(indata, frames, time, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
recording.append(indata.copy())
|
||||
|
||||
with sd.InputStream(samplerate=fs, channels=1, callback=callback):
|
||||
input()
|
||||
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def main():
|
||||
# 1. Setup Parser
|
||||
print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}")
|
||||
parser = argparse.ArgumentParser(description="Whisper + Ollama CLI")
|
||||
|
||||
# Known Arguments (Hardcoded logic)
|
||||
parser.add_argument("--nollm", "-n", action='store_true',
|
||||
help="turn off llm")
|
||||
parser.add_argument("--system", "-s", default=None,
|
||||
@@ -65,30 +25,27 @@ def main():
|
||||
parser.add_argument(
|
||||
"--temp", default='0.7', help="temperature")
|
||||
|
||||
# 2. Capture "Unknown" arguments
|
||||
# args = known values, unknown = a list like ['--num_ctx', '4096', '--temp', '0.7']
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# Convert unknown list to a dictionary for the Ollama 'options' field
|
||||
# This logic pairs ['--key', 'value'] into {key: value}
|
||||
extra_options = {}
|
||||
for i in range(0, len(unknown), 2):
|
||||
key = unknown[i].lstrip('-') # remove the '--'
|
||||
key = unknown[i].lstrip('-')
|
||||
val = unknown[i+1]
|
||||
# Try to convert numbers to actual ints/floats
|
||||
try:
|
||||
val = float(val) if '.' in val else int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
extra_options[key] = val
|
||||
|
||||
model = load_whisper_model(args.model_size)
|
||||
|
||||
while True:
|
||||
try:
|
||||
audio_data = record_audio()
|
||||
audio_data = record_until_enter()
|
||||
|
||||
print("[TRANSCRIBING]...")
|
||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
||||
text = "".join([segment.text for segment in segments]).strip()
|
||||
text = transcribe(model, audio_data.flatten())
|
||||
|
||||
if not text:
|
||||
print("No speech detected. Try again.")
|
||||
@@ -97,8 +54,7 @@ def main():
|
||||
print(f"You said: {text}")
|
||||
pyperclip.copy(text)
|
||||
|
||||
if (args.nollm == False):
|
||||
# Send to Ollama
|
||||
if not args.nollm:
|
||||
print(f"[OLLAMA] Thinking...")
|
||||
payload = {
|
||||
"model": args.ollama_model,
|
||||
@@ -108,9 +64,9 @@ def main():
|
||||
}
|
||||
|
||||
if args.system:
|
||||
payload["system"] = args
|
||||
response = requests.post(OLLAMA_URL, json=payload)
|
||||
payload["system"] = args.system
|
||||
|
||||
response = requests.post(OLLAMA_URL, json=payload)
|
||||
result = response.json().get("response", "")
|
||||
print(f"\nLLM Response:\n{result}\n")
|
||||
else:
|
||||
|
||||
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sttlib.whisper_loader import load_whisper_model
|
||||
from sttlib.audio import record_until_enter, pcm_bytes_to_float32
|
||||
from sttlib.transcription import transcribe, is_hallucination, HALLUCINATION_PATTERNS
|
||||
from sttlib.vad import (
|
||||
VADProcessor, audio_callback, audio_queue,
|
||||
SAMPLE_RATE, CHANNELS, FRAME_DURATION_MS, FRAME_SIZE, MIN_UTTERANCE_FRAMES,
|
||||
)
|
||||
28
python/tool-speechtotext/sttlib/audio.py
Normal file
28
python/tool-speechtotext/sttlib/audio.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
|
||||
|
||||
def record_until_enter(sample_rate=16000):
|
||||
"""Record audio until user presses Enter. Returns float32 numpy array."""
|
||||
print("\n[READY] Press Enter to START recording...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
|
||||
recording = []
|
||||
|
||||
def callback(indata, frames, time, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
recording.append(indata.copy())
|
||||
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input()
|
||||
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def pcm_bytes_to_float32(pcm_bytes):
|
||||
"""Convert raw 16-bit PCM bytes to float32 array normalized to [-1, 1]."""
|
||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||
return audio_int16.astype(np.float32) / 32768.0
|
||||
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
@@ -0,0 +1,19 @@
|
||||
HALLUCINATION_PATTERNS = [
|
||||
"thank you", "thanks for watching", "subscribe",
|
||||
"bye", "the end", "thank you for watching",
|
||||
"please subscribe", "like and subscribe",
|
||||
]
|
||||
|
||||
|
||||
def transcribe(model, audio_float32):
|
||||
"""Transcribe audio using Whisper. Returns stripped text."""
|
||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||
return "".join(segment.text for segment in segments).strip()
|
||||
|
||||
|
||||
def is_hallucination(text):
|
||||
"""Return True if text looks like a Whisper hallucination."""
|
||||
lowered = text.lower().strip()
|
||||
if len(lowered) < 3:
|
||||
return True
|
||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||
58
python/tool-speechtotext/sttlib/vad.py
Normal file
58
python/tool-speechtotext/sttlib/vad.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import queue
|
||||
import collections
|
||||
import webrtcvad
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
FRAME_DURATION_MS = 30
|
||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
|
||||
def audio_callback(indata, frames, time_info, status):
|
||||
"""sounddevice callback that pushes raw bytes to the audio queue."""
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
audio_queue.put(bytes(indata))
|
||||
|
||||
|
||||
class VADProcessor:
|
||||
def __init__(self, aggressiveness, silence_threshold):
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.silence_threshold = silence_threshold
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.triggered = False
|
||||
self.utterance_frames = []
|
||||
self.silence_duration = 0.0
|
||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||
|
||||
def process_frame(self, frame_bytes):
|
||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||
|
||||
if not self.triggered:
|
||||
self.pre_buffer.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.triggered = True
|
||||
self.silence_duration = 0.0
|
||||
self.utterance_frames = list(self.pre_buffer)
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
else:
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.silence_duration = 0.0
|
||||
else:
|
||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||
if self.silence_duration >= self.silence_threshold:
|
||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||
self.reset()
|
||||
return None
|
||||
result = b"".join(self.utterance_frames)
|
||||
self.reset()
|
||||
return result
|
||||
return None
|
||||
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
|
||||
|
||||
def load_whisper_model(model_size):
|
||||
"""Load Whisper with GPU (cuda/float16) -> CPU (cpu/int8) fallback."""
|
||||
print(f"Loading Whisper model ({model_size})...")
|
||||
try:
|
||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"GPU loading failed: {e}")
|
||||
print("Falling back to CPU (int8)")
|
||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||
0
python/tool-speechtotext/tests/__init__.py
Normal file
0
python/tool-speechtotext/tests/__init__.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import struct
|
||||
import numpy as np
|
||||
from sttlib.audio import pcm_bytes_to_float32
|
||||
|
||||
|
||||
def test_known_value():
|
||||
# 16384 in int16 -> 0.5 in float32
|
||||
pcm = struct.pack("<h", 16384)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert abs(result[0] - 0.5) < 1e-5
|
||||
|
||||
|
||||
def test_silence():
|
||||
pcm = b"\x00\x00" * 10
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert np.all(result == 0.0)
|
||||
|
||||
|
||||
def test_full_scale():
|
||||
# max int16 = 32767 -> ~1.0
|
||||
pcm = struct.pack("<h", 32767)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert abs(result[0] - (32767 / 32768.0)) < 1e-5
|
||||
|
||||
|
||||
def test_negative():
|
||||
# min int16 = -32768 -> -1.0
|
||||
pcm = struct.pack("<h", -32768)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert result[0] == -1.0
|
||||
|
||||
|
||||
def test_round_trip_shape():
|
||||
# 100 samples worth of bytes
|
||||
pcm = b"\x00\x00" * 100
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert result.shape == (100,)
|
||||
assert result.dtype == np.float32
|
||||
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
from sttlib.transcription import transcribe, is_hallucination
|
||||
|
||||
|
||||
# --- is_hallucination tests ---
|
||||
|
||||
def test_known_hallucinations():
|
||||
assert is_hallucination("Thank you")
|
||||
assert is_hallucination("thanks for watching")
|
||||
assert is_hallucination("Subscribe")
|
||||
assert is_hallucination("the end")
|
||||
|
||||
|
||||
def test_short_text():
|
||||
assert is_hallucination("hi")
|
||||
assert is_hallucination("")
|
||||
assert is_hallucination("a")
|
||||
|
||||
|
||||
def test_normal_text():
|
||||
assert not is_hallucination("Hello how are you")
|
||||
assert not is_hallucination("Please open the terminal")
|
||||
|
||||
|
||||
def test_case_insensitivity():
|
||||
assert is_hallucination("THANK YOU")
|
||||
assert is_hallucination("Thank You For Watching")
|
||||
|
||||
|
||||
def test_substring_match():
|
||||
assert is_hallucination("I want to subscribe to your channel")
|
||||
|
||||
|
||||
def test_exactly_three_chars():
|
||||
assert not is_hallucination("hey")
|
||||
|
||||
|
||||
# --- transcribe tests ---
|
||||
|
||||
def _make_segment(text):
|
||||
seg = MagicMock()
|
||||
seg.text = text
|
||||
return seg
|
||||
|
||||
|
||||
def test_transcribe_joins_segments():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = (
|
||||
[_make_segment("Hello "), _make_segment("world")],
|
||||
None,
|
||||
)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
def test_transcribe_empty():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = ([], None)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_transcribe_strips_whitespace():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = (
|
||||
[_make_segment(" hello ")],
|
||||
None,
|
||||
)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_transcribe_passes_beam_size():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = ([], None)
|
||||
audio = MagicMock()
|
||||
transcribe(model, audio)
|
||||
model.transcribe.assert_called_once_with(audio, beam_size=5)
|
||||
151
python/tool-speechtotext/tests/test_vad.py
Normal file
151
python/tool-speechtotext/tests/test_vad.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sttlib.vad import VADProcessor, FRAME_DURATION_MS, MIN_UTTERANCE_FRAMES
|
||||
|
||||
|
||||
def _make_vad_processor(aggressiveness=3, silence_threshold=0.8):
|
||||
"""Create VADProcessor with a mocked webrtcvad.Vad."""
|
||||
with patch("sttlib.vad.webrtcvad.Vad") as mock_vad_cls:
|
||||
mock_vad = MagicMock()
|
||||
mock_vad_cls.return_value = mock_vad
|
||||
proc = VADProcessor(aggressiveness, silence_threshold)
|
||||
return proc, mock_vad
|
||||
|
||||
|
||||
def _frame(label="x"):
|
||||
"""Return a fake 30ms frame (just needs to be distinct bytes)."""
|
||||
return label.encode() * 960 # 480 samples * 2 bytes
|
||||
|
||||
|
||||
def test_no_speech_returns_none():
|
||||
proc, mock_vad = _make_vad_processor()
|
||||
mock_vad.is_speech.return_value = False
|
||||
|
||||
for _ in range(100):
|
||||
assert proc.process_frame(_frame()) is None
|
||||
|
||||
|
||||
def test_speech_then_silence_triggers_utterance():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# Feed enough speech frames
|
||||
speech_count = MIN_UTTERANCE_FRAMES + 5
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(speech_count):
|
||||
result = proc.process_frame(_frame("s"))
|
||||
assert result is None # not done yet
|
||||
|
||||
# Feed silence frames until threshold (0.3s = 10 frames at 30ms)
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_short_utterance_filtered():
|
||||
# Use very short silence threshold so silence frames don't push total
|
||||
# past MIN_UTTERANCE_FRAMES. With threshold=0.09s (3 frames of silence):
|
||||
# 0 pre-buffer + 1 speech + 3 silence = 4 total < MIN_UTTERANCE_FRAMES (10)
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.09)
|
||||
|
||||
# Single speech frame triggers VAD
|
||||
mock_vad.is_speech.return_value = True
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
# Immediately go silent — threshold reached in 3 frames
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
# Should be filtered (too short — only 4 total frames)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_pre_buffer_included():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# Fill pre-buffer with non-speech frames
|
||||
mock_vad.is_speech.return_value = False
|
||||
pre_frame = _frame("p")
|
||||
for _ in range(10):
|
||||
proc.process_frame(pre_frame)
|
||||
|
||||
# Speech starts
|
||||
mock_vad.is_speech.return_value = True
|
||||
speech_frame = _frame("s")
|
||||
for _ in range(MIN_UTTERANCE_FRAMES):
|
||||
proc.process_frame(speech_frame)
|
||||
|
||||
# Silence to trigger
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
# Result should contain pre-buffer frames
|
||||
assert pre_frame in result
|
||||
|
||||
|
||||
def test_reset_after_utterance():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# First utterance
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
mock_vad.is_speech.return_value = False
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
assert result is not None
|
||||
|
||||
# After reset, should be able to collect a second utterance
|
||||
assert not proc.triggered
|
||||
assert proc.utterance_frames == []
|
||||
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
mock_vad.is_speech.return_value = False
|
||||
result2 = None
|
||||
for _ in range(20):
|
||||
result2 = proc.process_frame(_frame("q"))
|
||||
if result2 is not None:
|
||||
break
|
||||
assert result2 is not None
|
||||
|
||||
|
||||
def test_silence_threshold_boundary():
|
||||
# Use 0.3s threshold: 0.3 / 0.03 = exactly 10 frames needed
|
||||
threshold = 0.3
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=threshold)
|
||||
|
||||
# Start with speech
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
frames_needed = 10 # 0.3s / 0.03s per frame
|
||||
mock_vad.is_speech.return_value = False
|
||||
|
||||
# Feed one less than needed — should NOT trigger
|
||||
for i in range(frames_needed - 1):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
assert result is None, f"Triggered too early at frame {i}"
|
||||
|
||||
# The 10th frame should trigger (silence_duration = 0.3 >= 0.3)
|
||||
result = proc.process_frame(_frame("q"))
|
||||
assert result is not None
|
||||
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sttlib.whisper_loader import load_whisper_model
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_gpu_success(mock_cls):
|
||||
mock_model = MagicMock()
|
||||
mock_cls.return_value = mock_model
|
||||
|
||||
result = load_whisper_model("base")
|
||||
|
||||
assert result is mock_model
|
||||
mock_cls.assert_called_once_with("base", device="cuda", compute_type="float16")
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_gpu_fails_cpu_fallback(mock_cls):
|
||||
mock_model = MagicMock()
|
||||
mock_cls.side_effect = [RuntimeError("no CUDA"), mock_model]
|
||||
|
||||
result = load_whisper_model("base")
|
||||
|
||||
assert result is mock_model
|
||||
assert mock_cls.call_count == 2
|
||||
_, kwargs = mock_cls.call_args
|
||||
assert kwargs == {"device": "cpu", "compute_type": "int8"}
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_both_fail_propagates(mock_cls):
|
||||
mock_cls.side_effect = RuntimeError("no device")
|
||||
|
||||
try:
|
||||
load_whisper_model("base")
|
||||
assert False, "Should have raised"
|
||||
except RuntimeError:
|
||||
pass
|
||||
@@ -1,31 +1,16 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import pyperclip
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import ollama
|
||||
import json
|
||||
from faster_whisper import WhisperModel
|
||||
import ollama
|
||||
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||
|
||||
# --- Configuration ---
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
MODEL_SIZE = "medium"
|
||||
OLLAMA_MODEL = "qwen2.5-coder:7b"
|
||||
CONFIRM_COMMANDS = True # Set to False to run commands instantly
|
||||
|
||||
# Load Whisper on GPU
|
||||
print("Loading Whisper model...")
|
||||
try:
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"Error loading GPU: {e}, falling back to CPU")
|
||||
model = WhisperModel(MODEL_SIZE, device="cpu", compute_type="int8")
|
||||
|
||||
# --- Terminal Tool ---
|
||||
|
||||
|
||||
def run_terminal_command(command: str):
|
||||
"""
|
||||
Executes a bash command in the Linux terminal.
|
||||
@@ -33,8 +18,7 @@ def run_terminal_command(command: str):
|
||||
"""
|
||||
if CONFIRM_COMMANDS:
|
||||
print(f"\n{'='*40}")
|
||||
print(f"⚠️ AI SUGGESTED: \033[1;32m{command}\033[0m")
|
||||
# Allow user to provide feedback if they say 'n'
|
||||
print(f"\u26a0\ufe0f AI SUGGESTED: \033[1;32m{command}\033[0m")
|
||||
choice = input(" Confirm? [Y/n] or provide feedback: ").strip()
|
||||
|
||||
if choice.lower() == 'n':
|
||||
@@ -57,22 +41,15 @@ def run_terminal_command(command: str):
|
||||
return f"Execution Error: {str(e)}"
|
||||
|
||||
|
||||
def record_audio():
|
||||
fs, recording = 16000, []
|
||||
print("\n[READY] Press Enter to START...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
def cb(indata, f, t, s): recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=fs, channels=1, callback=cb):
|
||||
input()
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default=OLLAMA_MODEL)
|
||||
parser.add_argument("--model-size", default="medium",
|
||||
help="Whisper model size")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
whisper_model = load_whisper_model(args.model_size)
|
||||
|
||||
# Initial System Prompt
|
||||
messages = [{
|
||||
'role': 'system',
|
||||
@@ -88,9 +65,8 @@ def main():
|
||||
while True:
|
||||
try:
|
||||
# 1. Voice Capture
|
||||
audio_data = record_audio()
|
||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
||||
user_text = "".join([s.text for s in segments]).strip()
|
||||
audio_data = record_until_enter()
|
||||
user_text = transcribe(whisper_model, audio_data.flatten())
|
||||
if not user_text:
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,91 +1,14 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import webrtcvad
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import threading
|
||||
import queue
|
||||
import collections
|
||||
import time
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
|
||||
# --- Constants ---
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
FRAME_DURATION_MS = 30
|
||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||
|
||||
HALLUCINATION_PATTERNS = [
|
||||
"thank you", "thanks for watching", "subscribe",
|
||||
"bye", "the end", "thank you for watching",
|
||||
"please subscribe", "like and subscribe",
|
||||
]
|
||||
|
||||
# --- Thread-safe audio queue ---
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
|
||||
def audio_callback(indata, frames, time_info, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
audio_queue.put(bytes(indata))
|
||||
|
||||
|
||||
# --- Whisper model loading (reused pattern from assistant.py) ---
|
||||
def load_whisper_model(model_size):
|
||||
print(f"Loading Whisper model ({model_size})...")
|
||||
try:
|
||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"GPU loading failed: {e}")
|
||||
print("Falling back to CPU (int8)")
|
||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||
|
||||
|
||||
# --- VAD State Machine ---
|
||||
class VADProcessor:
|
||||
def __init__(self, aggressiveness, silence_threshold):
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.silence_threshold = silence_threshold
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.triggered = False
|
||||
self.utterance_frames = []
|
||||
self.silence_duration = 0.0
|
||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||
|
||||
def process_frame(self, frame_bytes):
|
||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||
|
||||
if not self.triggered:
|
||||
self.pre_buffer.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.triggered = True
|
||||
self.silence_duration = 0.0
|
||||
self.utterance_frames = list(self.pre_buffer)
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
pass # silent until transcription confirms speech
|
||||
else:
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.silence_duration = 0.0
|
||||
else:
|
||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||
if self.silence_duration >= self.silence_threshold:
|
||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||
self.reset()
|
||||
return None
|
||||
result = b"".join(self.utterance_frames)
|
||||
self.reset()
|
||||
return result
|
||||
return None
|
||||
import sounddevice as sd
|
||||
from sttlib import (
|
||||
load_whisper_model, transcribe, is_hallucination, pcm_bytes_to_float32,
|
||||
VADProcessor, audio_callback, audio_queue,
|
||||
SAMPLE_RATE, CHANNELS, FRAME_SIZE,
|
||||
)
|
||||
|
||||
|
||||
# --- Typer Interface (xdotool) ---
|
||||
@@ -99,6 +22,7 @@ class Typer:
|
||||
except FileNotFoundError:
|
||||
print("ERROR: xdotool not found. Install it:")
|
||||
print(" sudo apt-get install xdotool")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
def type_text(self, text, submit_now=False):
|
||||
@@ -120,24 +44,6 @@ class Typer:
|
||||
pass
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
def pcm_bytes_to_float32(pcm_bytes):
|
||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||
return audio_int16.astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def transcribe(model, audio_float32):
|
||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||
return "".join(segment.text for segment in segments).strip()
|
||||
|
||||
|
||||
def is_hallucination(text):
|
||||
lowered = text.lower().strip()
|
||||
if len(lowered) < 3:
|
||||
return True
|
||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||
|
||||
|
||||
# --- CLI ---
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
Reference in New Issue
Block a user