Fix bugs, N+1 queries, and wire settings in persian-tutor
- Replace inline __import__("datetime").timedelta hack with proper import
- Remove unused import random in anki_export.py
- Add error handling for Claude CLI subprocess failures in ai.py
- Fix hardcoded absolute path in stt.py with relative Path resolution
- Fix N+1 DB queries in vocab.get_flashcard_batch and dashboard.get_category_breakdown
by adding db.get_all_word_progress() batch query
- Wire Ollama model and Whisper size settings to actually update config
via ai.set_ollama_model() and stt.set_whisper_size()
Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
@@ -6,9 +6,18 @@ import ollama
|
||||
|
||||
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
|
||||
|
||||
_ollama_model = DEFAULT_OLLAMA_MODEL
|
||||
|
||||
def ask_ollama(prompt, system=None, model=DEFAULT_OLLAMA_MODEL):
|
||||
|
||||
def set_ollama_model(model):
|
||||
"""Change the Ollama model used for fast queries."""
|
||||
global _ollama_model
|
||||
_ollama_model = model
|
||||
|
||||
|
||||
def ask_ollama(prompt, system=None, model=None):
|
||||
"""Query Ollama with an optional system prompt."""
|
||||
model = model or _ollama_model
|
||||
messages = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
@@ -24,6 +33,8 @@ def ask_claude(prompt):
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Claude CLI failed (exit {result.returncode}): {result.stderr.strip()}")
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
@@ -34,8 +45,9 @@ def ask(prompt, system=None, quality="fast"):
|
||||
return ask_ollama(prompt, system=system)
|
||||
|
||||
|
||||
def chat_ollama(messages, system=None, model=DEFAULT_OLLAMA_MODEL):
|
||||
def chat_ollama(messages, system=None, model=None):
|
||||
"""Multi-turn conversation with Ollama."""
|
||||
model = model or _ollama_model
|
||||
all_messages = []
|
||||
if system:
|
||||
all_messages.append({"role": "system", "content": system})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Generate Anki .apkg decks from vocabulary data."""
|
||||
|
||||
import genanki
|
||||
import random
|
||||
|
||||
# Stable model/deck IDs (generated once, kept constant)
|
||||
_MODEL_ID = 1607392319
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import ai
|
||||
import db
|
||||
from modules import vocab, dashboard, essay, tutor, idioms
|
||||
from modules.essay import GCSE_THEMES
|
||||
@@ -214,6 +215,15 @@ def do_anki_export(cats_selected):
|
||||
return path
|
||||
|
||||
|
||||
def update_ollama_model(model):
|
||||
ai.set_ollama_model(model)
|
||||
|
||||
|
||||
def update_whisper_size(size):
|
||||
from stt import set_whisper_size
|
||||
set_whisper_size(size)
|
||||
|
||||
|
||||
def reset_progress():
|
||||
conn = db.get_connection()
|
||||
conn.execute("DELETE FROM word_progress")
|
||||
@@ -491,6 +501,10 @@ with gr.Blocks(title="Persian Language Tutor") as app:
|
||||
|
||||
export_btn.click(fn=do_anki_export, inputs=[export_cats], outputs=[export_file])
|
||||
|
||||
# Wire model settings
|
||||
ollama_model.change(fn=update_ollama_model, inputs=[ollama_model])
|
||||
whisper_size.change(fn=update_whisper_size, inputs=[whisper_size])
|
||||
|
||||
gr.Markdown("### Reset")
|
||||
reset_btn = gr.Button("Reset All Progress", variant="stop")
|
||||
reset_status = gr.Markdown("")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import fsrs
|
||||
@@ -148,6 +148,13 @@ def get_word_counts(total_vocab_size=0):
|
||||
}
|
||||
|
||||
|
||||
def get_all_word_progress():
|
||||
"""Return all word progress as a dict of word_id -> progress dict."""
|
||||
conn = get_connection()
|
||||
rows = conn.execute("SELECT * FROM word_progress").fetchall()
|
||||
return {row["word_id"]: dict(row) for row in rows}
|
||||
|
||||
|
||||
def record_quiz_session(category, total_questions, correct, duration_seconds):
|
||||
"""Log a completed flashcard session."""
|
||||
conn = get_connection()
|
||||
@@ -203,7 +210,7 @@ def get_stats():
|
||||
today = datetime.now(timezone.utc).date()
|
||||
for i, row in enumerate(days):
|
||||
day = datetime.fromisoformat(row["d"]).date() if isinstance(row["d"], str) else row["d"]
|
||||
expected = today - __import__("datetime").timedelta(days=i)
|
||||
expected = today - timedelta(days=i)
|
||||
if day == expected:
|
||||
streak += 1
|
||||
else:
|
||||
|
||||
@@ -19,17 +19,17 @@ def get_category_breakdown():
|
||||
"""Return progress per category as list of dicts."""
|
||||
vocab = load_vocab()
|
||||
categories = get_categories()
|
||||
all_progress = db.get_all_word_progress()
|
||||
|
||||
breakdown = []
|
||||
for cat in categories:
|
||||
cat_words = [e for e in vocab if e["category"] == cat]
|
||||
cat_ids = {e["id"] for e in cat_words}
|
||||
total = len(cat_words)
|
||||
|
||||
seen = 0
|
||||
mastered = 0
|
||||
for wid in cat_ids:
|
||||
progress = db.get_word_progress(wid)
|
||||
for e in cat_words:
|
||||
progress = all_progress.get(e["id"])
|
||||
if progress:
|
||||
seen += 1
|
||||
if progress["stability"] and progress["stability"] > 10:
|
||||
|
||||
@@ -84,8 +84,9 @@ def get_flashcard_batch(count=10, category=None):
|
||||
remaining = count - len(due_entries)
|
||||
if remaining > 0:
|
||||
seen_ids = {e["id"] for e in due_entries}
|
||||
all_progress = db.get_all_word_progress()
|
||||
# Prefer unseen words
|
||||
unseen = [e for e in pool if e["id"] not in seen_ids and not db.get_word_progress(e["id"])]
|
||||
unseen = [e for e in pool if e["id"] not in seen_ids and e["id"] not in all_progress]
|
||||
if len(unseen) >= remaining:
|
||||
fill = random.sample(unseen, remaining)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Persian speech-to-text wrapper using sttlib."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, "/home/ys/family-repo/Code/python/tool-speechtotext")
|
||||
# sttlib lives in sibling project tool-speechtotext
|
||||
_sttlib_path = str(Path(__file__).resolve().parent.parent / "tool-speechtotext")
|
||||
sys.path.insert(0, _sttlib_path)
|
||||
from sttlib import load_whisper_model, transcribe, is_hallucination
|
||||
|
||||
_model = None
|
||||
_whisper_size = "medium"
|
||||
|
||||
# Common Whisper hallucinations in Persian/silence
|
||||
PERSIAN_HALLUCINATIONS = [
|
||||
@@ -18,11 +22,19 @@ PERSIAN_HALLUCINATIONS = [
|
||||
]
|
||||
|
||||
|
||||
def get_model(size="medium"):
|
||||
def set_whisper_size(size):
|
||||
"""Change the Whisper model size. Reloads on next transcription."""
|
||||
global _whisper_size, _model
|
||||
if size != _whisper_size:
|
||||
_whisper_size = size
|
||||
_model = None
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Load Whisper model (cached singleton)."""
|
||||
global _model
|
||||
if _model is None:
|
||||
_model = load_whisper_model(size)
|
||||
_model = load_whisper_model(_whisper_size)
|
||||
return _model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user