Refactor tool-speechtotext: extract sttlib shared library and add tests
Extract duplicated code (Whisper loading, audio recording, transcription, VAD processing) into reusable sttlib/ package. Rewrite all 3 scripts as thin wrappers. Add 24 unit tests with mocked hardware. Fix GPU fallback bug in assistant.py and args.system assignment bug.
This commit is contained in:
@@ -1,5 +1,11 @@
|
|||||||
{
|
{
|
||||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||||
"python-envs.pythonProjects": []
|
"python-envs.pythonProjects": [],
|
||||||
|
"python.testing.pytestEnabled": true,
|
||||||
|
"python.testing.unittestEnabled": false,
|
||||||
|
"python.testing.pytestArgs": [
|
||||||
|
"tests",
|
||||||
|
"-v"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
@@ -12,11 +12,20 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
|||||||
## Tools
|
## Tools
|
||||||
- `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama
|
- `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama
|
||||||
- `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling
|
- `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling
|
||||||
- `voice_to_xdotool.py` / `dotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
- `voice_to_xdotool.py` / `xdotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
||||||
|
|
||||||
|
## Shared Library
|
||||||
|
- `sttlib/` — shared package used by all scripts and importable by other projects
|
||||||
|
- `whisper_loader.py` — model loading with GPU→CPU fallback
|
||||||
|
- `audio.py` — press-enter recording, PCM conversion
|
||||||
|
- `transcription.py` — Whisper transcribe wrapper, hallucination filter
|
||||||
|
- `vad.py` — VADProcessor, audio callback, constants
|
||||||
|
- Other projects import via: `sys.path.insert(0, "/path/to/tool-speechtotext")`
|
||||||
|
|
||||||
## Testing
|
## Testing
|
||||||
- To test scripts: `mamba run -n whisper-ollama python <script.py> --model-size base`
|
- Run tests: `mamba run -n whisper-ollama python -m pytest tests/`
|
||||||
- Use `--model-size base` for faster iteration during development
|
- Use `--model-size base` for faster iteration during development
|
||||||
|
- Tests mock hardware (Whisper model, VAD, mic) — no GPU/mic needed to run them
|
||||||
- Audio device is available — live mic testing is possible
|
- Audio device is available — live mic testing is possible
|
||||||
- Test xdotool output by focusing a text editor window
|
- Test xdotool output by focusing a text editor window
|
||||||
|
|
||||||
@@ -24,12 +33,13 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
|||||||
- Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama
|
- Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama
|
||||||
- Pip (in conda env): webrtcvad
|
- Pip (in conda env): webrtcvad
|
||||||
- System: libportaudio2, xdotool
|
- System: libportaudio2, xdotool
|
||||||
|
- Dev: pytest
|
||||||
|
|
||||||
## Conventions
|
## Conventions
|
||||||
- Shell wrappers go in .sh files using `mamba run -n whisper-ollama`
|
- Shell wrappers go in .sh files using `mamba run -n whisper-ollama`
|
||||||
- All scripts set `CT2_CUDA_ALLOW_FP16=1`
|
- Shared code lives in `sttlib/` — scripts are thin entry points that import from it
|
||||||
- Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback
|
- Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback
|
||||||
- Keep scripts self-contained (no shared module)
|
- `CT2_CUDA_ALLOW_FP16=1` is set by `sttlib.whisper_loader` at import time
|
||||||
- Don't print output for non-actionable events
|
- Don't print output for non-actionable events
|
||||||
|
|
||||||
## Preferences
|
## Preferences
|
||||||
|
|||||||
@@ -1,57 +1,17 @@
|
|||||||
import sounddevice as sd
|
import argparse
|
||||||
import numpy as np
|
|
||||||
import pyperclip
|
import pyperclip
|
||||||
import requests
|
import requests
|
||||||
import sys
|
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||||
import argparse
|
|
||||||
from faster_whisper import WhisperModel
|
|
||||||
|
|
||||||
import os
|
|
||||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
|
||||||
|
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
MODEL_SIZE = "medium" # Options: "base", "small", "medium", "large-v3"
|
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||||
OLLAMA_URL = "http://localhost:11434/api/generate" # Default is 11434
|
|
||||||
DEFAULT_OLLAMA_MODEL = "qwen3:latest"
|
DEFAULT_OLLAMA_MODEL = "qwen3:latest"
|
||||||
|
|
||||||
# Load Whisper on GPU
|
|
||||||
# float16 is faster and uses less VRAM on NVIDIA cards
|
|
||||||
print("Loading Whisper model...")
|
|
||||||
|
|
||||||
|
|
||||||
try:
|
|
||||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading GPU: {e}")
|
|
||||||
print("Falling back to CPU (Check your CUDA/cuDNN installation)")
|
|
||||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="int16")
|
|
||||||
|
|
||||||
|
|
||||||
def record_audio():
|
|
||||||
fs = 16000
|
|
||||||
print("\n[READY] Press Enter to START recording...")
|
|
||||||
input()
|
|
||||||
print("[RECORDING] Press Enter to STOP...")
|
|
||||||
|
|
||||||
recording = []
|
|
||||||
|
|
||||||
def callback(indata, frames, time, status):
|
|
||||||
if status:
|
|
||||||
print(status, file=sys.stderr)
|
|
||||||
recording.append(indata.copy())
|
|
||||||
|
|
||||||
with sd.InputStream(samplerate=fs, channels=1, callback=callback):
|
|
||||||
input()
|
|
||||||
|
|
||||||
return np.concatenate(recording, axis=0)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
# 1. Setup Parser
|
|
||||||
print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}")
|
print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}")
|
||||||
parser = argparse.ArgumentParser(description="Whisper + Ollama CLI")
|
parser = argparse.ArgumentParser(description="Whisper + Ollama CLI")
|
||||||
|
|
||||||
# Known Arguments (Hardcoded logic)
|
|
||||||
parser.add_argument("--nollm", "-n", action='store_true',
|
parser.add_argument("--nollm", "-n", action='store_true',
|
||||||
help="turn off llm")
|
help="turn off llm")
|
||||||
parser.add_argument("--system", "-s", default=None,
|
parser.add_argument("--system", "-s", default=None,
|
||||||
@@ -65,30 +25,27 @@ def main():
|
|||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--temp", default='0.7', help="temperature")
|
"--temp", default='0.7', help="temperature")
|
||||||
|
|
||||||
# 2. Capture "Unknown" arguments
|
|
||||||
# args = known values, unknown = a list like ['--num_ctx', '4096', '--temp', '0.7']
|
|
||||||
args, unknown = parser.parse_known_args()
|
args, unknown = parser.parse_known_args()
|
||||||
|
|
||||||
# Convert unknown list to a dictionary for the Ollama 'options' field
|
# Convert unknown list to a dictionary for the Ollama 'options' field
|
||||||
# This logic pairs ['--key', 'value'] into {key: value}
|
|
||||||
extra_options = {}
|
extra_options = {}
|
||||||
for i in range(0, len(unknown), 2):
|
for i in range(0, len(unknown), 2):
|
||||||
key = unknown[i].lstrip('-') # remove the '--'
|
key = unknown[i].lstrip('-')
|
||||||
val = unknown[i+1]
|
val = unknown[i+1]
|
||||||
# Try to convert numbers to actual ints/floats
|
|
||||||
try:
|
try:
|
||||||
val = float(val) if '.' in val else int(val)
|
val = float(val) if '.' in val else int(val)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
extra_options[key] = val
|
extra_options[key] = val
|
||||||
|
|
||||||
|
model = load_whisper_model(args.model_size)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
audio_data = record_audio()
|
audio_data = record_until_enter()
|
||||||
|
|
||||||
print("[TRANSCRIBING]...")
|
print("[TRANSCRIBING]...")
|
||||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
text = transcribe(model, audio_data.flatten())
|
||||||
text = "".join([segment.text for segment in segments]).strip()
|
|
||||||
|
|
||||||
if not text:
|
if not text:
|
||||||
print("No speech detected. Try again.")
|
print("No speech detected. Try again.")
|
||||||
@@ -97,8 +54,7 @@ def main():
|
|||||||
print(f"You said: {text}")
|
print(f"You said: {text}")
|
||||||
pyperclip.copy(text)
|
pyperclip.copy(text)
|
||||||
|
|
||||||
if (args.nollm == False):
|
if not args.nollm:
|
||||||
# Send to Ollama
|
|
||||||
print(f"[OLLAMA] Thinking...")
|
print(f"[OLLAMA] Thinking...")
|
||||||
payload = {
|
payload = {
|
||||||
"model": args.ollama_model,
|
"model": args.ollama_model,
|
||||||
@@ -108,9 +64,9 @@ def main():
|
|||||||
}
|
}
|
||||||
|
|
||||||
if args.system:
|
if args.system:
|
||||||
payload["system"] = args
|
payload["system"] = args.system
|
||||||
response = requests.post(OLLAMA_URL, json=payload)
|
|
||||||
|
|
||||||
|
response = requests.post(OLLAMA_URL, json=payload)
|
||||||
result = response.json().get("response", "")
|
result = response.json().get("response", "")
|
||||||
print(f"\nLLM Response:\n{result}\n")
|
print(f"\nLLM Response:\n{result}\n")
|
||||||
else:
|
else:
|
||||||
|
|||||||
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
from sttlib.whisper_loader import load_whisper_model
|
||||||
|
from sttlib.audio import record_until_enter, pcm_bytes_to_float32
|
||||||
|
from sttlib.transcription import transcribe, is_hallucination, HALLUCINATION_PATTERNS
|
||||||
|
from sttlib.vad import (
|
||||||
|
VADProcessor, audio_callback, audio_queue,
|
||||||
|
SAMPLE_RATE, CHANNELS, FRAME_DURATION_MS, FRAME_SIZE, MIN_UTTERANCE_FRAMES,
|
||||||
|
)
|
||||||
28
python/tool-speechtotext/sttlib/audio.py
Normal file
28
python/tool-speechtotext/sttlib/audio.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
import sys
|
||||||
|
import numpy as np
|
||||||
|
import sounddevice as sd
|
||||||
|
|
||||||
|
|
||||||
|
def record_until_enter(sample_rate=16000):
|
||||||
|
"""Record audio until user presses Enter. Returns float32 numpy array."""
|
||||||
|
print("\n[READY] Press Enter to START recording...")
|
||||||
|
input()
|
||||||
|
print("[RECORDING] Press Enter to STOP...")
|
||||||
|
|
||||||
|
recording = []
|
||||||
|
|
||||||
|
def callback(indata, frames, time, status):
|
||||||
|
if status:
|
||||||
|
print(status, file=sys.stderr)
|
||||||
|
recording.append(indata.copy())
|
||||||
|
|
||||||
|
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||||
|
input()
|
||||||
|
|
||||||
|
return np.concatenate(recording, axis=0)
|
||||||
|
|
||||||
|
|
||||||
|
def pcm_bytes_to_float32(pcm_bytes):
|
||||||
|
"""Convert raw 16-bit PCM bytes to float32 array normalized to [-1, 1]."""
|
||||||
|
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||||
|
return audio_int16.astype(np.float32) / 32768.0
|
||||||
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
HALLUCINATION_PATTERNS = [
|
||||||
|
"thank you", "thanks for watching", "subscribe",
|
||||||
|
"bye", "the end", "thank you for watching",
|
||||||
|
"please subscribe", "like and subscribe",
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def transcribe(model, audio_float32):
|
||||||
|
"""Transcribe audio using Whisper. Returns stripped text."""
|
||||||
|
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||||
|
return "".join(segment.text for segment in segments).strip()
|
||||||
|
|
||||||
|
|
||||||
|
def is_hallucination(text):
|
||||||
|
"""Return True if text looks like a Whisper hallucination."""
|
||||||
|
lowered = text.lower().strip()
|
||||||
|
if len(lowered) < 3:
|
||||||
|
return True
|
||||||
|
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||||
58
python/tool-speechtotext/sttlib/vad.py
Normal file
58
python/tool-speechtotext/sttlib/vad.py
Normal file
@@ -0,0 +1,58 @@
|
|||||||
|
import sys
|
||||||
|
import queue
|
||||||
|
import collections
|
||||||
|
import webrtcvad
|
||||||
|
|
||||||
|
SAMPLE_RATE = 16000
|
||||||
|
CHANNELS = 1
|
||||||
|
FRAME_DURATION_MS = 30
|
||||||
|
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||||
|
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||||
|
|
||||||
|
audio_queue = queue.Queue()
|
||||||
|
|
||||||
|
|
||||||
|
def audio_callback(indata, frames, time_info, status):
|
||||||
|
"""sounddevice callback that pushes raw bytes to the audio queue."""
|
||||||
|
if status:
|
||||||
|
print(status, file=sys.stderr)
|
||||||
|
audio_queue.put(bytes(indata))
|
||||||
|
|
||||||
|
|
||||||
|
class VADProcessor:
|
||||||
|
def __init__(self, aggressiveness, silence_threshold):
|
||||||
|
self.vad = webrtcvad.Vad(aggressiveness)
|
||||||
|
self.silence_threshold = silence_threshold
|
||||||
|
self.reset()
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.triggered = False
|
||||||
|
self.utterance_frames = []
|
||||||
|
self.silence_duration = 0.0
|
||||||
|
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||||
|
|
||||||
|
def process_frame(self, frame_bytes):
|
||||||
|
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||||
|
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||||
|
|
||||||
|
if not self.triggered:
|
||||||
|
self.pre_buffer.append(frame_bytes)
|
||||||
|
if is_speech:
|
||||||
|
self.triggered = True
|
||||||
|
self.silence_duration = 0.0
|
||||||
|
self.utterance_frames = list(self.pre_buffer)
|
||||||
|
self.utterance_frames.append(frame_bytes)
|
||||||
|
else:
|
||||||
|
self.utterance_frames.append(frame_bytes)
|
||||||
|
if is_speech:
|
||||||
|
self.silence_duration = 0.0
|
||||||
|
else:
|
||||||
|
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||||
|
if self.silence_duration >= self.silence_threshold:
|
||||||
|
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||||
|
self.reset()
|
||||||
|
return None
|
||||||
|
result = b"".join(self.utterance_frames)
|
||||||
|
self.reset()
|
||||||
|
return result
|
||||||
|
return None
|
||||||
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
import os
|
||||||
|
from faster_whisper import WhisperModel
|
||||||
|
|
||||||
|
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||||
|
|
||||||
|
|
||||||
|
def load_whisper_model(model_size):
|
||||||
|
"""Load Whisper with GPU (cuda/float16) -> CPU (cpu/int8) fallback."""
|
||||||
|
print(f"Loading Whisper model ({model_size})...")
|
||||||
|
try:
|
||||||
|
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"GPU loading failed: {e}")
|
||||||
|
print("Falling back to CPU (int8)")
|
||||||
|
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||||
0
python/tool-speechtotext/tests/__init__.py
Normal file
0
python/tool-speechtotext/tests/__init__.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
@@ -0,0 +1,38 @@
|
|||||||
|
import struct
|
||||||
|
import numpy as np
|
||||||
|
from sttlib.audio import pcm_bytes_to_float32
|
||||||
|
|
||||||
|
|
||||||
|
def test_known_value():
|
||||||
|
# 16384 in int16 -> 0.5 in float32
|
||||||
|
pcm = struct.pack("<h", 16384)
|
||||||
|
result = pcm_bytes_to_float32(pcm)
|
||||||
|
assert abs(result[0] - 0.5) < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_silence():
|
||||||
|
pcm = b"\x00\x00" * 10
|
||||||
|
result = pcm_bytes_to_float32(pcm)
|
||||||
|
assert np.all(result == 0.0)
|
||||||
|
|
||||||
|
|
||||||
|
def test_full_scale():
|
||||||
|
# max int16 = 32767 -> ~1.0
|
||||||
|
pcm = struct.pack("<h", 32767)
|
||||||
|
result = pcm_bytes_to_float32(pcm)
|
||||||
|
assert abs(result[0] - (32767 / 32768.0)) < 1e-5
|
||||||
|
|
||||||
|
|
||||||
|
def test_negative():
|
||||||
|
# min int16 = -32768 -> -1.0
|
||||||
|
pcm = struct.pack("<h", -32768)
|
||||||
|
result = pcm_bytes_to_float32(pcm)
|
||||||
|
assert result[0] == -1.0
|
||||||
|
|
||||||
|
|
||||||
|
def test_round_trip_shape():
|
||||||
|
# 100 samples worth of bytes
|
||||||
|
pcm = b"\x00\x00" * 100
|
||||||
|
result = pcm_bytes_to_float32(pcm)
|
||||||
|
assert result.shape == (100,)
|
||||||
|
assert result.dtype == np.float32
|
||||||
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
@@ -0,0 +1,78 @@
|
|||||||
|
from unittest.mock import MagicMock
|
||||||
|
from sttlib.transcription import transcribe, is_hallucination
|
||||||
|
|
||||||
|
|
||||||
|
# --- is_hallucination tests ---
|
||||||
|
|
||||||
|
def test_known_hallucinations():
|
||||||
|
assert is_hallucination("Thank you")
|
||||||
|
assert is_hallucination("thanks for watching")
|
||||||
|
assert is_hallucination("Subscribe")
|
||||||
|
assert is_hallucination("the end")
|
||||||
|
|
||||||
|
|
||||||
|
def test_short_text():
|
||||||
|
assert is_hallucination("hi")
|
||||||
|
assert is_hallucination("")
|
||||||
|
assert is_hallucination("a")
|
||||||
|
|
||||||
|
|
||||||
|
def test_normal_text():
|
||||||
|
assert not is_hallucination("Hello how are you")
|
||||||
|
assert not is_hallucination("Please open the terminal")
|
||||||
|
|
||||||
|
|
||||||
|
def test_case_insensitivity():
|
||||||
|
assert is_hallucination("THANK YOU")
|
||||||
|
assert is_hallucination("Thank You For Watching")
|
||||||
|
|
||||||
|
|
||||||
|
def test_substring_match():
|
||||||
|
assert is_hallucination("I want to subscribe to your channel")
|
||||||
|
|
||||||
|
|
||||||
|
def test_exactly_three_chars():
|
||||||
|
assert not is_hallucination("hey")
|
||||||
|
|
||||||
|
|
||||||
|
# --- transcribe tests ---
|
||||||
|
|
||||||
|
def _make_segment(text):
|
||||||
|
seg = MagicMock()
|
||||||
|
seg.text = text
|
||||||
|
return seg
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_joins_segments():
|
||||||
|
model = MagicMock()
|
||||||
|
model.transcribe.return_value = (
|
||||||
|
[_make_segment("Hello "), _make_segment("world")],
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
result = transcribe(model, MagicMock())
|
||||||
|
assert result == "Hello world"
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_empty():
|
||||||
|
model = MagicMock()
|
||||||
|
model.transcribe.return_value = ([], None)
|
||||||
|
result = transcribe(model, MagicMock())
|
||||||
|
assert result == ""
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_strips_whitespace():
|
||||||
|
model = MagicMock()
|
||||||
|
model.transcribe.return_value = (
|
||||||
|
[_make_segment(" hello ")],
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
result = transcribe(model, MagicMock())
|
||||||
|
assert result == "hello"
|
||||||
|
|
||||||
|
|
||||||
|
def test_transcribe_passes_beam_size():
|
||||||
|
model = MagicMock()
|
||||||
|
model.transcribe.return_value = ([], None)
|
||||||
|
audio = MagicMock()
|
||||||
|
transcribe(model, audio)
|
||||||
|
model.transcribe.assert_called_once_with(audio, beam_size=5)
|
||||||
151
python/tool-speechtotext/tests/test_vad.py
Normal file
151
python/tool-speechtotext/tests/test_vad.py
Normal file
@@ -0,0 +1,151 @@
|
|||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from sttlib.vad import VADProcessor, FRAME_DURATION_MS, MIN_UTTERANCE_FRAMES
|
||||||
|
|
||||||
|
|
||||||
|
def _make_vad_processor(aggressiveness=3, silence_threshold=0.8):
|
||||||
|
"""Create VADProcessor with a mocked webrtcvad.Vad."""
|
||||||
|
with patch("sttlib.vad.webrtcvad.Vad") as mock_vad_cls:
|
||||||
|
mock_vad = MagicMock()
|
||||||
|
mock_vad_cls.return_value = mock_vad
|
||||||
|
proc = VADProcessor(aggressiveness, silence_threshold)
|
||||||
|
return proc, mock_vad
|
||||||
|
|
||||||
|
|
||||||
|
def _frame(label="x"):
|
||||||
|
"""Return a fake 30ms frame (just needs to be distinct bytes)."""
|
||||||
|
return label.encode() * 960 # 480 samples * 2 bytes
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_speech_returns_none():
|
||||||
|
proc, mock_vad = _make_vad_processor()
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
|
||||||
|
for _ in range(100):
|
||||||
|
assert proc.process_frame(_frame()) is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_speech_then_silence_triggers_utterance():
|
||||||
|
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||||
|
|
||||||
|
# Feed enough speech frames
|
||||||
|
speech_count = MIN_UTTERANCE_FRAMES + 5
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
for _ in range(speech_count):
|
||||||
|
result = proc.process_frame(_frame("s"))
|
||||||
|
assert result is None # not done yet
|
||||||
|
|
||||||
|
# Feed silence frames until threshold (0.3s = 10 frames at 30ms)
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
result = None
|
||||||
|
for _ in range(20):
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
if result is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
assert len(result) > 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_short_utterance_filtered():
|
||||||
|
# Use very short silence threshold so silence frames don't push total
|
||||||
|
# past MIN_UTTERANCE_FRAMES. With threshold=0.09s (3 frames of silence):
|
||||||
|
# 0 pre-buffer + 1 speech + 3 silence = 4 total < MIN_UTTERANCE_FRAMES (10)
|
||||||
|
proc, mock_vad = _make_vad_processor(silence_threshold=0.09)
|
||||||
|
|
||||||
|
# Single speech frame triggers VAD
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
proc.process_frame(_frame("s"))
|
||||||
|
|
||||||
|
# Immediately go silent — threshold reached in 3 frames
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
result = None
|
||||||
|
for _ in range(20):
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
if result is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Should be filtered (too short — only 4 total frames)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_pre_buffer_included():
|
||||||
|
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||||
|
|
||||||
|
# Fill pre-buffer with non-speech frames
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
pre_frame = _frame("p")
|
||||||
|
for _ in range(10):
|
||||||
|
proc.process_frame(pre_frame)
|
||||||
|
|
||||||
|
# Speech starts
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
speech_frame = _frame("s")
|
||||||
|
for _ in range(MIN_UTTERANCE_FRAMES):
|
||||||
|
proc.process_frame(speech_frame)
|
||||||
|
|
||||||
|
# Silence to trigger
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
result = None
|
||||||
|
for _ in range(20):
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
if result is not None:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert result is not None
|
||||||
|
# Result should contain pre-buffer frames
|
||||||
|
assert pre_frame in result
|
||||||
|
|
||||||
|
|
||||||
|
def test_reset_after_utterance():
|
||||||
|
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||||
|
|
||||||
|
# First utterance
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||||
|
proc.process_frame(_frame("s"))
|
||||||
|
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
for _ in range(20):
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
if result is not None:
|
||||||
|
break
|
||||||
|
assert result is not None
|
||||||
|
|
||||||
|
# After reset, should be able to collect a second utterance
|
||||||
|
assert not proc.triggered
|
||||||
|
assert proc.utterance_frames == []
|
||||||
|
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||||
|
proc.process_frame(_frame("s"))
|
||||||
|
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
result2 = None
|
||||||
|
for _ in range(20):
|
||||||
|
result2 = proc.process_frame(_frame("q"))
|
||||||
|
if result2 is not None:
|
||||||
|
break
|
||||||
|
assert result2 is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_silence_threshold_boundary():
|
||||||
|
# Use 0.3s threshold: 0.3 / 0.03 = exactly 10 frames needed
|
||||||
|
threshold = 0.3
|
||||||
|
proc, mock_vad = _make_vad_processor(silence_threshold=threshold)
|
||||||
|
|
||||||
|
# Start with speech
|
||||||
|
mock_vad.is_speech.return_value = True
|
||||||
|
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||||
|
proc.process_frame(_frame("s"))
|
||||||
|
|
||||||
|
frames_needed = 10 # 0.3s / 0.03s per frame
|
||||||
|
mock_vad.is_speech.return_value = False
|
||||||
|
|
||||||
|
# Feed one less than needed — should NOT trigger
|
||||||
|
for i in range(frames_needed - 1):
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
assert result is None, f"Triggered too early at frame {i}"
|
||||||
|
|
||||||
|
# The 10th frame should trigger (silence_duration = 0.3 >= 0.3)
|
||||||
|
result = proc.process_frame(_frame("q"))
|
||||||
|
assert result is not None
|
||||||
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
from unittest.mock import patch, MagicMock
|
||||||
|
from sttlib.whisper_loader import load_whisper_model
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sttlib.whisper_loader.WhisperModel")
|
||||||
|
def test_gpu_success(mock_cls):
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_cls.return_value = mock_model
|
||||||
|
|
||||||
|
result = load_whisper_model("base")
|
||||||
|
|
||||||
|
assert result is mock_model
|
||||||
|
mock_cls.assert_called_once_with("base", device="cuda", compute_type="float16")
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sttlib.whisper_loader.WhisperModel")
|
||||||
|
def test_gpu_fails_cpu_fallback(mock_cls):
|
||||||
|
mock_model = MagicMock()
|
||||||
|
mock_cls.side_effect = [RuntimeError("no CUDA"), mock_model]
|
||||||
|
|
||||||
|
result = load_whisper_model("base")
|
||||||
|
|
||||||
|
assert result is mock_model
|
||||||
|
assert mock_cls.call_count == 2
|
||||||
|
_, kwargs = mock_cls.call_args
|
||||||
|
assert kwargs == {"device": "cpu", "compute_type": "int8"}
|
||||||
|
|
||||||
|
|
||||||
|
@patch("sttlib.whisper_loader.WhisperModel")
|
||||||
|
def test_both_fail_propagates(mock_cls):
|
||||||
|
mock_cls.side_effect = RuntimeError("no device")
|
||||||
|
|
||||||
|
try:
|
||||||
|
load_whisper_model("base")
|
||||||
|
assert False, "Should have raised"
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
@@ -1,31 +1,16 @@
|
|||||||
import sounddevice as sd
|
|
||||||
import numpy as np
|
|
||||||
import pyperclip
|
|
||||||
import sys
|
|
||||||
import argparse
|
import argparse
|
||||||
import os
|
|
||||||
import subprocess
|
import subprocess
|
||||||
import ollama
|
|
||||||
import json
|
import json
|
||||||
from faster_whisper import WhisperModel
|
import ollama
|
||||||
|
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||||
|
|
||||||
# --- Configuration ---
|
# --- Configuration ---
|
||||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
|
||||||
MODEL_SIZE = "medium"
|
|
||||||
OLLAMA_MODEL = "qwen2.5-coder:7b"
|
OLLAMA_MODEL = "qwen2.5-coder:7b"
|
||||||
CONFIRM_COMMANDS = True # Set to False to run commands instantly
|
CONFIRM_COMMANDS = True # Set to False to run commands instantly
|
||||||
|
|
||||||
# Load Whisper on GPU
|
|
||||||
print("Loading Whisper model...")
|
|
||||||
try:
|
|
||||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Error loading GPU: {e}, falling back to CPU")
|
|
||||||
model = WhisperModel(MODEL_SIZE, device="cpu", compute_type="int8")
|
|
||||||
|
|
||||||
# --- Terminal Tool ---
|
# --- Terminal Tool ---
|
||||||
|
|
||||||
|
|
||||||
def run_terminal_command(command: str):
|
def run_terminal_command(command: str):
|
||||||
"""
|
"""
|
||||||
Executes a bash command in the Linux terminal.
|
Executes a bash command in the Linux terminal.
|
||||||
@@ -33,8 +18,7 @@ def run_terminal_command(command: str):
|
|||||||
"""
|
"""
|
||||||
if CONFIRM_COMMANDS:
|
if CONFIRM_COMMANDS:
|
||||||
print(f"\n{'='*40}")
|
print(f"\n{'='*40}")
|
||||||
print(f"⚠️ AI SUGGESTED: \033[1;32m{command}\033[0m")
|
print(f"\u26a0\ufe0f AI SUGGESTED: \033[1;32m{command}\033[0m")
|
||||||
# Allow user to provide feedback if they say 'n'
|
|
||||||
choice = input(" Confirm? [Y/n] or provide feedback: ").strip()
|
choice = input(" Confirm? [Y/n] or provide feedback: ").strip()
|
||||||
|
|
||||||
if choice.lower() == 'n':
|
if choice.lower() == 'n':
|
||||||
@@ -57,22 +41,15 @@ def run_terminal_command(command: str):
|
|||||||
return f"Execution Error: {str(e)}"
|
return f"Execution Error: {str(e)}"
|
||||||
|
|
||||||
|
|
||||||
def record_audio():
|
|
||||||
fs, recording = 16000, []
|
|
||||||
print("\n[READY] Press Enter to START...")
|
|
||||||
input()
|
|
||||||
print("[RECORDING] Press Enter to STOP...")
|
|
||||||
def cb(indata, f, t, s): recording.append(indata.copy())
|
|
||||||
with sd.InputStream(samplerate=fs, channels=1, callback=cb):
|
|
||||||
input()
|
|
||||||
return np.concatenate(recording, axis=0)
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = argparse.ArgumentParser()
|
parser = argparse.ArgumentParser()
|
||||||
parser.add_argument("--model", default=OLLAMA_MODEL)
|
parser.add_argument("--model", default=OLLAMA_MODEL)
|
||||||
|
parser.add_argument("--model-size", default="medium",
|
||||||
|
help="Whisper model size")
|
||||||
args, _ = parser.parse_known_args()
|
args, _ = parser.parse_known_args()
|
||||||
|
|
||||||
|
whisper_model = load_whisper_model(args.model_size)
|
||||||
|
|
||||||
# Initial System Prompt
|
# Initial System Prompt
|
||||||
messages = [{
|
messages = [{
|
||||||
'role': 'system',
|
'role': 'system',
|
||||||
@@ -88,9 +65,8 @@ def main():
|
|||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
# 1. Voice Capture
|
# 1. Voice Capture
|
||||||
audio_data = record_audio()
|
audio_data = record_until_enter()
|
||||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
user_text = transcribe(whisper_model, audio_data.flatten())
|
||||||
user_text = "".join([s.text for s in segments]).strip()
|
|
||||||
if not user_text:
|
if not user_text:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
|||||||
@@ -1,91 +1,14 @@
|
|||||||
import sounddevice as sd
|
|
||||||
import numpy as np
|
|
||||||
import webrtcvad
|
|
||||||
import subprocess
|
|
||||||
import sys
|
|
||||||
import os
|
|
||||||
import argparse
|
import argparse
|
||||||
|
import subprocess
|
||||||
import threading
|
import threading
|
||||||
import queue
|
import queue
|
||||||
import collections
|
|
||||||
import time
|
import time
|
||||||
from faster_whisper import WhisperModel
|
import sounddevice as sd
|
||||||
|
from sttlib import (
|
||||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
load_whisper_model, transcribe, is_hallucination, pcm_bytes_to_float32,
|
||||||
|
VADProcessor, audio_callback, audio_queue,
|
||||||
# --- Constants ---
|
SAMPLE_RATE, CHANNELS, FRAME_SIZE,
|
||||||
SAMPLE_RATE = 16000
|
)
|
||||||
CHANNELS = 1
|
|
||||||
FRAME_DURATION_MS = 30
|
|
||||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
|
||||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
|
||||||
|
|
||||||
HALLUCINATION_PATTERNS = [
|
|
||||||
"thank you", "thanks for watching", "subscribe",
|
|
||||||
"bye", "the end", "thank you for watching",
|
|
||||||
"please subscribe", "like and subscribe",
|
|
||||||
]
|
|
||||||
|
|
||||||
# --- Thread-safe audio queue ---
|
|
||||||
audio_queue = queue.Queue()
|
|
||||||
|
|
||||||
|
|
||||||
def audio_callback(indata, frames, time_info, status):
|
|
||||||
if status:
|
|
||||||
print(status, file=sys.stderr)
|
|
||||||
audio_queue.put(bytes(indata))
|
|
||||||
|
|
||||||
|
|
||||||
# --- Whisper model loading (reused pattern from assistant.py) ---
|
|
||||||
def load_whisper_model(model_size):
|
|
||||||
print(f"Loading Whisper model ({model_size})...")
|
|
||||||
try:
|
|
||||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
|
||||||
except Exception as e:
|
|
||||||
print(f"GPU loading failed: {e}")
|
|
||||||
print("Falling back to CPU (int8)")
|
|
||||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
|
||||||
|
|
||||||
|
|
||||||
# --- VAD State Machine ---
|
|
||||||
class VADProcessor:
|
|
||||||
def __init__(self, aggressiveness, silence_threshold):
|
|
||||||
self.vad = webrtcvad.Vad(aggressiveness)
|
|
||||||
self.silence_threshold = silence_threshold
|
|
||||||
self.reset()
|
|
||||||
|
|
||||||
def reset(self):
|
|
||||||
self.triggered = False
|
|
||||||
self.utterance_frames = []
|
|
||||||
self.silence_duration = 0.0
|
|
||||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
|
||||||
|
|
||||||
def process_frame(self, frame_bytes):
|
|
||||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
|
||||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
|
||||||
|
|
||||||
if not self.triggered:
|
|
||||||
self.pre_buffer.append(frame_bytes)
|
|
||||||
if is_speech:
|
|
||||||
self.triggered = True
|
|
||||||
self.silence_duration = 0.0
|
|
||||||
self.utterance_frames = list(self.pre_buffer)
|
|
||||||
self.utterance_frames.append(frame_bytes)
|
|
||||||
pass # silent until transcription confirms speech
|
|
||||||
else:
|
|
||||||
self.utterance_frames.append(frame_bytes)
|
|
||||||
if is_speech:
|
|
||||||
self.silence_duration = 0.0
|
|
||||||
else:
|
|
||||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
|
||||||
if self.silence_duration >= self.silence_threshold:
|
|
||||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
|
||||||
self.reset()
|
|
||||||
return None
|
|
||||||
result = b"".join(self.utterance_frames)
|
|
||||||
self.reset()
|
|
||||||
return result
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
# --- Typer Interface (xdotool) ---
|
# --- Typer Interface (xdotool) ---
|
||||||
@@ -99,6 +22,7 @@ class Typer:
|
|||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
print("ERROR: xdotool not found. Install it:")
|
print("ERROR: xdotool not found. Install it:")
|
||||||
print(" sudo apt-get install xdotool")
|
print(" sudo apt-get install xdotool")
|
||||||
|
import sys
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
def type_text(self, text, submit_now=False):
|
def type_text(self, text, submit_now=False):
|
||||||
@@ -120,24 +44,6 @@ class Typer:
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
|
|
||||||
# --- Helpers ---
|
|
||||||
def pcm_bytes_to_float32(pcm_bytes):
|
|
||||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
|
||||||
return audio_int16.astype(np.float32) / 32768.0
|
|
||||||
|
|
||||||
|
|
||||||
def transcribe(model, audio_float32):
|
|
||||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
|
||||||
return "".join(segment.text for segment in segments).strip()
|
|
||||||
|
|
||||||
|
|
||||||
def is_hallucination(text):
|
|
||||||
lowered = text.lower().strip()
|
|
||||||
if len(lowered) < 3:
|
|
||||||
return True
|
|
||||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
|
||||||
|
|
||||||
|
|
||||||
# --- CLI ---
|
# --- CLI ---
|
||||||
def parse_args():
|
def parse_args():
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
|
|||||||
Reference in New Issue
Block a user