Compare commits

..

6 Commits

Author SHA1 Message Date
dl92
01d5532823 Add expression-evaluator: DAGs & state machines tutorial project
Educational calculator teaching FSMs (explicit transition table tokenizer)
and DAGs (recursive descent parser with AST evaluation). Includes CLI with
REPL, graphviz visualization, and 61 tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 18:09:42 +00:00
dl92
3a8705ece8 Fix bugs, N+1 queries, and wire settings in persian-tutor
- Replace inline __import__("datetime").timedelta hack with proper import
- Remove unused import random in anki_export.py
- Add error handling for Claude CLI subprocess failures in ai.py
- Fix hardcoded absolute path in stt.py with relative Path resolution
- Fix N+1 DB queries in vocab.get_flashcard_batch and dashboard.get_category_breakdown
  by adding db.get_all_word_progress() batch query
- Wire Ollama model and Whisper size settings to actually update config
  via ai.set_ollama_model() and stt.set_whisper_size()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 15:40:24 +00:00
local
8b5eb8797f working cva computation using quantlib 2026-02-08 13:40:35 +00:00
local
d484f9c236 working AMC algorithm tested against quantlib 2026-02-08 13:40:00 +00:00
local
2e8c2c11d0 Add persian-tutor: Gradio-based GCSE Persian language learning app
Vocabulary study with FSRS spaced repetition, AI tutoring (Ollama/Claude),
essay marking, idioms browser, Anki export, and dashboard. 918 vocabulary
entries across 39 categories. 41 tests passing.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 01:57:44 +00:00
local
104da381fb Refactor tool-speechtotext: extract sttlib shared library and add tests
Extract duplicated code (Whisper loading, audio recording, transcription,
VAD processing) into reusable sttlib/ package. Rewrite all 3 scripts as
thin wrappers. Add 24 unit tests with mocked hardware. Fix GPU fallback
bug in assistant.py and args.system assignment bug.
2026-02-08 00:40:31 +00:00
58 changed files with 14385 additions and 195 deletions

View File

@@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}

View File

@@ -0,0 +1,19 @@
Current Progress Summary:
The Baseline: We fixed a standard Longstaff-Schwartz (LSM) American Put pricer, correcting the cash-flow propagation logic and regression targets.
The Evolution: We moved to Bermudan Swaptions using the Hull-White One-Factor Model.
The "Gold Standard": We implemented a 100% exact simulation from scratch. Instead of Euler discretization, we used bivariate normal sampling to jointly simulate the short rate $r_t$ and the stochastic integral $\int_t^T r_s\,ds$. This accounts for the stochastic discount factor (the convexity adjustment) without approximation error.
The Current Frontier: We were debating the Risk-Neutral Measure ($Q$) vs. the Terminal Forward Measure ($Q^T$).
We concluded that while $Q^T$ simplifies European options, it makes Bermudan LSM "messy" because it introduces a time-dependent drift shift: $\text{Drift}_{Q^T} = \text{Drift}_{Q} - \frac{\sigma^2}{a}\left(1 - e^{-a(T-t)}\right)$.
Pending Topics:
Mathematical Proof: The derivation of the "Drift Shift" via Girsanov's Theorem.
Exercise Boundary Impact: How the choice of measure (and the resulting drift) visually shifts the optimal exercise boundary in simulation.
Beyond One-Factor: Potential move toward Two-Factor models or non-flat initial term structures.

145
python/american-mc/main.py Normal file
View File

@@ -0,0 +1,145 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# -----------------------------
# Black-Scholes European Put
# -----------------------------
def european_put_black_scholes(S0, K, r, sigma, T):
    """Closed-form Black-Scholes price of a European put.

    Args:
        S0: spot price.  K: strike.  r: risk-free rate.
        sigma: volatility.  T: time to maturity in years.

    Returns:
        The put price K*exp(-rT)*N(-d2) - S0*N(-d1).
    """
    sqrt_t = np.sqrt(T)
    d1 = (np.log(S0 / K) + (r + 0.5 * sigma ** 2) * T) / (sigma * sqrt_t)
    d2 = d1 - sigma * sqrt_t
    discounted_strike = K * np.exp(-r * T)
    return discounted_strike * norm.cdf(-d2) - S0 * norm.cdf(-d1)
# -----------------------------
# GBM Simulation
# -----------------------------
def simulate_gbm(S0, r, sigma, T, M, N, seed=None):
    """Simulate M geometric Brownian motion paths over N equal time steps.

    Returns an (M, N+1) array whose first column is S0. Draws one batch of
    M standard normals per step, so results are reproducible for a given seed.
    """
    if seed is not None:
        np.random.seed(seed)
    dt = T / N
    # Hoist the per-step drift and volatility scaling out of the loop
    drift = (r - 0.5 * sigma ** 2) * dt
    vol_step = sigma * np.sqrt(dt)
    paths = np.zeros((M, N + 1))
    paths[:, 0] = S0
    for step in range(1, N + 1):
        shocks = np.random.standard_normal(M)
        paths[:, step] = paths[:, step - 1] * np.exp(drift + vol_step * shocks)
    return paths
# -----------------------------
# Longstaff-Schwartz Monte Carlo (Fixed)
# -----------------------------
def lsm_american_put(S0, K, r, sigma, T, M=100_000, N=50, basis_deg=2, seed=42, return_boundary=False):
    """Longstaff-Schwartz least-squares Monte Carlo price of an American put.

    Args:
        S0, K, r, sigma, T: spot, strike, risk-free rate, volatility, maturity.
        M, N: number of simulated paths / exercise dates.
        basis_deg: degree of the polynomial regression basis.
        seed: RNG seed forwarded to simulate_gbm.
        return_boundary: if True, also return [(t, max exercising price), ...].

    Returns:
        price, or (price, boundary) when return_boundary is True.
    """
    S = simulate_gbm(S0, r, sigma, T, M, N, seed)
    dt = T / N
    df = np.exp(-r * dt)  # one-step discount factor
    # Immediate exercise value (payoff) at each step
    payoff = np.maximum(K - S, 0)
    # V stores the value of the option at each path
    # Initialize with the payoff at maturity
    V = payoff[:, -1]
    boundary = []
    # Backward induction over exercise dates t = N-1, ..., 1
    for t in range(N - 1, 0, -1):
        # Identify In-The-Money paths (only these may exercise)
        itm = payoff[:, t] > 0
        # Default: option value at t is just the discounted value from t+1
        V = V * df
        if np.any(itm):
            X = S[itm, t]
            Y = V[itm]  # These are already discounted from t+1
            # Regression: Estimate Continuation Value E[V_{t+1} | S_t]
            coeffs = np.polyfit(X, Y, basis_deg)
            continuation_value = np.polyval(coeffs, X)
            # Exercise if payoff > estimated continuation value
            exercise = payoff[itm, t] > continuation_value
            # Update V for paths where we exercise
            # Get the indices of the ITM paths that should exercise
            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]
            V[exercise_indices] = payoff[exercise_indices, t]
            # Boundary: The highest stock price at which we still exercise
            if len(exercise_indices) > 0:
                boundary.append((t * dt, np.max(S[exercise_indices, t])))
            else:
                boundary.append((t * dt, np.nan))
        else:
            boundary.append((t * dt, np.nan))
    # Final price: discount the t=1 values one more step back to t=0
    price = np.mean(V * df)
    if return_boundary:
        # boundary was built backward in time; reverse to chronological order
        return price, boundary[::-1]
    return price
# -----------------------------
# QuantLib American Put (v1.18)
# -----------------------------
def quantlib_american_put(S0, K, r, sigma, T):
    """Benchmark American put price via QuantLib's finite-difference engine.

    Flat rate and volatility, zero dividends; 200 time steps by 400 asset
    grid points. Used to validate the LSM Monte Carlo price.
    """
    import QuantLib as ql  # imported lazily, only when this benchmark is called
    today = ql.Date().todaysDate()
    ql.Settings.instance().evaluationDate = today
    # Maturity approximated as a whole number of days (ignores leap years)
    maturity = today + int(T*365)
    payoff = ql.PlainVanillaPayoff(ql.Option.Put, K)
    exercise = ql.AmericanExercise(today, maturity)
    spot = ql.QuoteHandle(ql.SimpleQuote(S0))
    dividend_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, 0.0, ql.Actual365Fixed()))
    riskfree_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, r, ql.Actual365Fixed()))
    vol_curve = ql.BlackVolTermStructureHandle(ql.BlackConstantVol(today, ql.NullCalendar(), sigma, ql.Actual365Fixed()))
    process = ql.BlackScholesMertonProcess(spot, dividend_curve, riskfree_curve, vol_curve)
    engine = ql.FdBlackScholesVanillaEngine(process, 200, 400)  # (time grid, asset grid)
    option = ql.VanillaOption(payoff, exercise)
    option.setPricingEngine(engine)
    return option.NPV()
# -----------------------------
# Main script
# -----------------------------
if __name__ == "__main__":
    # Model parameters: spot, strike, risk-free rate, volatility, maturity
    S0, K, r, sigma, T = 100, 100, 0.05, 0.2, 1.0
    M, N = 200_000, 50  # Monte Carlo paths / exercise dates
    # European Put (analytic reference; lower bound for the American price)
    eur_bs = european_put_black_scholes(S0, K, r, sigma, T)
    print(f"European Put (BS): {eur_bs:.4f}")
    # LSM American Put
    lsm_price, boundary = lsm_american_put(S0, K, r, sigma, T, M, N, return_boundary=True)
    print(f"LSM American Put (M={M}): {lsm_price:.4f}")
    # QuantLib American Put (finite-difference benchmark)
    ql_price = quantlib_american_put(S0, K, r, sigma, T)
    print(f"QuantLib American Put: {ql_price:.4f}")
    print(f"Lower bound (immediate exercise): {max(K-S0,0):.4f}")
    # Plot Exercise Boundary (NaN entries appear where no path exercised)
    times = [b[0] for b in boundary]
    boundaries = [b[1] for b in boundary]
    plt.figure(figsize=(8,5))
    plt.plot(times, boundaries, color='orange', label="LSM Exercise Boundary")
    plt.axhline(K, color='red', linestyle='--', label="Strike Price")
    plt.xlabel("Time to Maturity")
    plt.ylabel("Stock Price for Exercise")
    plt.title("American Put LSM Exercise Boundary")
    plt.gca().invert_xaxis()
    plt.legend()
    plt.grid(True)
    plt.show()

133
python/american-mc/main2.py Normal file
View File

@@ -0,0 +1,133 @@
import numpy as np
import matplotlib.pyplot as plt
# ---------------------------------------------------------
# 1. Analytical Hull-White Bond Pricing
# ---------------------------------------------------------
def bond_price(r_t, t, T, a, sigma, r0):
    """Zero-coupon bond price P(t, T) in Hull-White model for flat curve r0.

    Affine form P(t, T) = exp(log_A) * exp(-B * r_t), with B and log_A the
    standard Hull-White coefficients specialised to a flat initial curve.
    """
    tau = T - t
    B = (1 - np.exp(-a * tau)) / a
    log_A = (B - tau) * (a**2 * r0 - sigma**2/2) / a**2 - (sigma**2 * B**2 / (4 * a))
    return np.exp(log_A) * np.exp(-B * r_t)
def get_swap_npv(r_t, t, T_end, strike, a, sigma, r0):
    """Positive part of a payer-swap NPV (receives floating, pays fixed strike).

    NOTE: despite the name this returns max(NPV, 0), not the raw NPV --
    it is used as the immediate exercise payoff in the LSM backward loop.
    Fixed payments are assumed annual at whole-year dates up to T_end.
    """
    # Remaining annual fixed payment dates strictly after t
    payment_dates = np.arange(t + 1, T_end + 1, 1.0)
    if len(payment_dates) == 0:
        return np.zeros_like(r_t)
    # NPV = 1 - P(t, T_end) - strike * Sum[P(t, T_i)]
    p_end = bond_price(r_t, t, T_end, a, sigma, r0)
    fixed_leg = sum(bond_price(r_t, t, pd, a, sigma, r0) for pd in payment_dates)
    return np.maximum(1 - p_end - strike * fixed_leg, 0)
# ---------------------------------------------------------
# 2. Joint Exact Simulator (Short Rate & Stochastic Integral)
# ---------------------------------------------------------
def simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M):
    """Samples (r_t, Integral[r]) jointly to get exact discount factors.

    Exact (no Euler error) Hull-White simulation on the exercise-date grid:
    at each step the short rate and the integral I = int r ds over the step
    are drawn from their joint bivariate normal distribution, so the
    stochastic discount factor exp(-I) carries no discretization error.

    Returns:
        r_matrix[:, 1:]: short rate at each exercise date, shape (M, num_dates).
        d_matrix: per-step discounts; d_matrix[:, j] = exp(-int r ds) over
            (t_steps[j], t_steps[j+1]) where t_steps = [0] + exercise_dates.
    """
    M = int(M)
    num_dates = len(exercise_dates)
    r_matrix = np.zeros((M, num_dates + 1))
    d_matrix = np.zeros((M, num_dates))  # d[i] = exp(-integral from t_i to t_{i+1})
    r_matrix[:, 0] = r0
    t_steps = np.insert(exercise_dates, 0, 0.0)  # prepend t=0
    for i in range(len(t_steps) - 1):
        t, T = t_steps[i], t_steps[i+1]
        dt = T - t
        # Drift adjustments alpha(t) for a flat initial curve r0
        alpha_t = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * t))**2
        alpha_T = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * T))**2
        # Conditional mean of r(T) given r(t)
        mean_r = r_matrix[:, i] * np.exp(-a * dt) + alpha_T - alpha_t * np.exp(-a * dt)
        # The expected value of the integral is derived from the bond price: E[exp(-I)] = P(t,T)
        mean_I = -np.log(bond_price(r_matrix[:, i], t, T, a, sigma, r0))
        # Covariance matrix components of (r(T), I) conditional on r(t)
        var_r = (sigma**2 / (2 * a)) * (1 - np.exp(-2 * a * dt))
        B = (1 - np.exp(-a * dt)) / a
        var_I = (sigma**2 / a**2) * (dt - 2*B + (1 - np.exp(-2*a*dt))/(2*a))
        cov_rI = (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * dt))**2
        cov_matrix = [[var_r, cov_rI], [cov_rI, var_I]]
        # Sample joint normal innovations
        Z = np.random.multivariate_normal([0, 0], cov_matrix, M)
        r_matrix[:, i+1] = mean_r + Z[:, 0]
        # Convexity: subtract var_I/2 so that E[exp(-I)] matches the bond
        # price, since mean_I above is -log P(t,T) (log of an expectation).
        d_matrix[:, i] = np.exp(-(mean_I + Z[:, 1] - 0.5 * var_I))
    return r_matrix[:, 1:], d_matrix
# ---------------------------------------------------------
# 3. LSM Pricing Logic
# ---------------------------------------------------------
def price_bermudan_swaption(r0, a, sigma, strike, exercise_dates, T_end, M):
    """Price a Bermudan payer swaption by Longstaff-Schwartz on exact HW paths.

    Args:
        r0, a, sigma: Hull-White flat-curve level, mean reversion, volatility.
        strike: fixed rate paid on exercise.
        exercise_dates: increasing array of Bermudan exercise times (years).
        T_end: final swap maturity.
        M: number of Monte Carlo paths.

    Returns:
        The Monte Carlo price at t=0.

    Fix: discount indexing. simulate_hw_exact_joint returns d_matrix where
    column j discounts over (t_steps[j], t_steps[j+1]) with t_steps =
    [0] + exercise_dates. Pulling cash flows from exercise date i+1 back to
    date i therefore needs discounts[:, i+1]; the final step from date 0
    (time exercise_dates[0]) to t=0 needs the stochastic discounts[:, 0]
    (keeping the path-wise correlation between rates and discounting).
    The original used discounts[:, i] in the loop and an analytic
    P(0, t1) at the end, misaligning every discounting step.
    """
    # 1. Generate exact short-rate paths and per-step discount factors
    r_at_ex, discounts = simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M)
    # 2. Payoff at the final exercise date
    T_last = exercise_dates[-1]
    cash_flows = get_swap_npv(r_at_ex[:, -1], T_last, T_end, strike, a, sigma, r0)
    # 3. Backward induction over the earlier exercise dates
    for i in reversed(range(len(exercise_dates) - 1)):
        t_current = exercise_dates[i]
        # Pull cash flows back one step along each path using the stochastic
        # discount over (exercise_dates[i], exercise_dates[i+1]).
        cash_flows = cash_flows * discounts[:, i + 1]
        # Current intrinsic (immediate exercise) value
        X = r_at_ex[:, i]
        immediate_payoff = get_swap_npv(X, t_current, T_end, strike, a, sigma, r0)
        # Only regress In-The-Money paths
        itm = immediate_payoff > 0
        if np.any(itm):
            # Regression: basis functions [1, r, r^2]
            A = np.column_stack([np.ones_like(X[itm]), X[itm], X[itm]**2])
            coeffs = np.linalg.lstsq(A, cash_flows[itm], rcond=None)[0]
            continuation_value = A @ coeffs
            # Exercise decision
            exercise = immediate_payoff[itm] > continuation_value
            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]
            # Replace cash flows on paths that exercise early
            cash_flows[exercise_indices] = immediate_payoff[exercise_indices]
    # 4. Final stochastic discount from the first exercise date to t=0
    final_price = np.mean(cash_flows * discounts[:, 0])
    return final_price
# ---------------------------------------------------------
# Execution
# ---------------------------------------------------------
if __name__ == "__main__":
    # Hull-White parameters (flat curve) and trade terms
    r0_val, a_val, sigma_val = 0.05, 0.1, 0.01
    strike_val = 0.05
    ex_dates = np.array([1.0, 2.0, 3.0, 4.0])  # annual Bermudan exercise dates
    maturity = 5.0  # final swap maturity
    num_paths = 100_000
    price = price_bermudan_swaption(r0_val, a_val, sigma_val, strike_val, ex_dates, maturity, num_paths)
    print(f"--- 100% Exact LSM Bermudan Swaption ---")
    print(f"Parameters: a={a_val}, sigma={sigma_val}, strike={strike_val}")
    print(f"Exercise Dates: {ex_dates}")
    print(f"Calculated Price: {price:.6f}")

View File

@@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,83 @@
import numpy as np
import QuantLib as ql  # fix: module name is case-sensitive ("QuantLib", not "Quantlib")
from hullwhite import (build_calibrated_hw, simulate_hw_state, swap_npv_from_state)
from instruments import (make_vanilla_swap, make_swaption_helpers)

# Global evaluation date and market conventions shared by the whole script
today = ql.Date(15, 1, 2025)
ql.Settings.instance().evaluationDate = today
calendar = ql.TARGET()
day_count = ql.Actual365Fixed()
def expected_exposure(
    swap,
    hw,
    curve_handle,
    today,
    time_grid,
    x_paths
):
    """Expected positive exposure EE(t) of `swap` at each grid time.

    For every grid point, revalues the swap in each simulated Hull-White
    state and averages the positive part of the NPV across paths.

    Args:
        swap: QuantLib swap instrument to revalue.
        hw: calibrated Hull-White model.
        curve_handle: yield-curve handle used for discounting.
        today: base evaluation date (ql.Date).
        time_grid: year fractions, length n_times.
        x_paths: simulated state array, shape (n_paths, n_times).

    Returns:
        np.ndarray of length len(time_grid).
    """
    ee = np.zeros(len(time_grid))
    for j, t in enumerate(time_grid):
        values = []
        for i in range(len(x_paths)):
            v = swap_npv_from_state(
                swap, hw, curve_handle, t, x_paths[i, j], today
            )
            # Exposure is the positive part of the NPV only
            values.append(max(v, 0.0))
        ee[j] = np.mean(values)
    return ee
def main():
    """Calibrate a Hull-White model, simulate its state, and compute swap EE."""
    # Monthly time grid out to 10 years
    time_grid = np.linspace(0, 10, 121)
    # Flat 2% discount/forecast curve
    flat_rate = 0.02
    curve = ql.FlatForward(today, flat_rate, day_count)
    curve_handle = ql.YieldTermStructureHandle(curve)
    # Create and calibrate Hull-White
    a_init = 0.05
    sigma_init = 0.01
    hw = ql.HullWhite(curve_handle, a_init, sigma_init)
    index = ql.Euribor6M(curve_handle)
    # Fix: make_swaption_helpers(swaption_data, curve_handle, index, model)
    # takes the quote list as its first argument -- the original call passed
    # only three arguments. NOTE(review): placeholder flat 20% ATM vols below;
    # replace with market quotes.
    swaption_data = [
        (ql.Period(expiry, ql.Years), ql.Period(10 - expiry, ql.Years), 0.20)
        for expiry in (1, 2, 3, 4, 5)
    ]
    helpers = make_swaption_helpers(swaption_data, curve_handle, index, hw)
    optimizer = ql.LevenbergMarquardt()
    end_criteria = ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8)
    hw.calibrate(helpers, optimizer, end_criteria)
    a, sigma = hw.params()
    print(f"Calibrated a = {a:.4f}")
    print(f"Calibrated sigma = {sigma:.4f}")
    # Simulate model state paths
    x_paths = simulate_hw_state(
        hw, curve_handle, time_grid, n_paths=10000
    )
    # Receiver swap (fix: the original called the undefined make_receiver_swap;
    # make_vanilla_swap defaults to pay_fixed=False, i.e. a receiver swap)
    swap = make_vanilla_swap(
        today, curve_handle, notional=1.0, fixed_rate=0.025,
        maturity_years=10
    )
    # Expected exposure profile
    ee = expected_exposure(
        swap, hw, curve_handle, today, time_grid, x_paths
    )
    return 0


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,66 @@
import QuantLib as ql
def build_calibrated_hw(
    curve_handle,
    swaption_helpers,
    a_init=0.05,
    sigma_init=0.01
):
    """Build a Hull-White model and calibrate (a, sigma) to swaption quotes.

    Args:
        curve_handle: term-structure handle the model discounts off.
        swaption_helpers: list of ql.SwaptionHelper with pricing engines set.
        a_init, sigma_init: optimizer starting values.

    Returns:
        The calibrated ql.HullWhite model.
    """
    hw = ql.HullWhite(curve_handle, a_init, sigma_init)
    optimizer = ql.LevenbergMarquardt()
    # max iterations, max stationary iterations, then three tolerances
    end_criteria = ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8)
    hw.calibrate(swaption_helpers, optimizer, end_criteria)
    return hw
import numpy as np
def simulate_hw_state(
    hw: ql.HullWhite,
    curve_handle,
    time_grid,
    n_paths,
    seed=42
):
    """Simulate the zero-mean Hull-White state factor x(t) on `time_grid`.

    Uses the exact Ornstein-Uhlenbeck transition per step:
    x_{i} = x_{i-1} * e^{-a*dt} + N(0, sigma^2 (1 - e^{-2a*dt}) / (2a)).

    Returns an (n_paths, len(time_grid)) array with x(0) = 0.
    """
    a, sigma = hw.params()
    step_sizes = np.diff(time_grid)
    np.random.seed(seed)
    paths = np.zeros((n_paths, len(time_grid)))
    for idx, step in enumerate(step_sizes, start=1):
        mean_decay = np.exp(-a * step)
        stdev = sigma * np.sqrt((1 - np.exp(-2 * a * step)) / (2 * a))
        shocks = np.random.normal(size=n_paths)
        paths[:, idx] = paths[:, idx - 1] * mean_decay + stdev * shocks
    return paths
def swap_npv_from_state(
    swap,
    hw,
    curve_handle,
    t,
    x_t,
    today
):
    """NPV of `swap` at horizon t conditional on the simulated state x_t.

    Moves the global evaluation date forward by ~t years, pushes the state
    into the model, and reprices the swap with a discounting engine.

    NOTE(review): `HullWhite.setState` and the 3-argument
    `DiscountingSwapEngine(curve, bool, model)` call do not match the
    standard QuantLib-Python API -- confirm these run against the QuantLib
    build in use. Also mutates the global evaluation date without restoring it.
    """
    # Approximate date shift: t years -> whole days (ignores calendars)
    ql.Settings.instance().evaluationDate = today + int(t * 365)
    hw.setState(x_t)
    engine = ql.DiscountingSwapEngine(
        curve_handle,
        False,
        hw
    )
    swap.setPricingEngine(engine)
    return swap.NPV()

View File

@@ -0,0 +1,165 @@
import QuantLib as ql
# ============================================================
# Global defaults (override via parameters if needed)
# ============================================================
# Shared market conventions used by every schedule/swap built in this module.
CALENDAR = ql.TARGET()
BUSINESS_CONVENTION = ql.ModifiedFollowing
DATE_GEN = ql.DateGeneration.Forward
FIXED_DAYCOUNT = ql.Thirty360(ql.Thirty360.European)  # fixed-leg accrual basis
FLOAT_DAYCOUNT = ql.Actual360()  # floating-leg accrual basis
# ============================================================
# Yield curve factory
# ============================================================
def make_flat_curve(
    evaluation_date: ql.Date,
    rate: float,
    day_count: ql.DayCounter = ql.Actual365Fixed()
) -> ql.YieldTermStructureHandle:
    """Build a flat-forward yield curve anchored at `evaluation_date`.

    Also sets the global QuantLib evaluation date as a side effect.
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    flat = ql.FlatForward(evaluation_date, rate, day_count)
    return ql.YieldTermStructureHandle(flat)
# ============================================================
# Index factory
# ============================================================
def make_euribor_6m(
    curve_handle: ql.YieldTermStructureHandle
) -> ql.IborIndex:
    """Construct a 6M Euribor index forecasting off `curve_handle`."""
    index = ql.Euribor6M(curve_handle)
    return index
# ============================================================
# Schedule factory
# ============================================================
def make_schedule(
    start: ql.Date,
    maturity: ql.Date,
    tenor: ql.Period,
    calendar: ql.Calendar = CALENDAR
) -> ql.Schedule:
    """Build a payment schedule using the module-level market conventions."""
    convention = BUSINESS_CONVENTION
    return ql.Schedule(
        start,
        maturity,
        tenor,
        calendar,
        convention,  # accrual-date adjustment
        convention,  # termination-date adjustment
        DATE_GEN,
        False,  # endOfMonth flag disabled
    )
# ============================================================
# Swap factory
# ============================================================
def make_vanilla_swap(
    evaluation_date: ql.Date,
    curve_handle: ql.YieldTermStructureHandle,
    notional: float,
    fixed_rate: float,
    maturity_years: int,
    pay_fixed: bool = False,
    fixed_leg_freq: ql.Period = ql.Annual,
    float_leg_freq: ql.Period = ql.Semiannual
) -> ql.VanillaSwap:
    """
    Vanilla fixed-for-float IRS factory.
    pay_fixed = True  -> payer swap
    pay_fixed = False -> receiver swap

    Spot-starting (T+2) swap against Euribor 6M, priced with a discounting
    engine on `curve_handle`. NOTE(review): ql.Annual / ql.Semiannual are
    Frequency constants rather than ql.Period instances; they are wrapped
    via ql.Period(freq) below, so the annotations overstate the type.
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    index = make_euribor_6m(curve_handle)
    # Spot start: two business days after the evaluation date
    start = CALENDAR.advance(evaluation_date, 2, ql.Days)
    maturity = CALENDAR.advance(start, maturity_years, ql.Years)
    fixed_schedule = make_schedule(
        start, maturity, ql.Period(fixed_leg_freq)
    )
    float_schedule = make_schedule(
        start, maturity, ql.Period(float_leg_freq)
    )
    swap_type = (
        ql.VanillaSwap.Payer if pay_fixed else ql.VanillaSwap.Receiver
    )
    swap = ql.VanillaSwap(
        swap_type,
        notional,
        fixed_schedule,
        fixed_rate,
        FIXED_DAYCOUNT,
        float_schedule,
        index,
        0.0,  # floating-leg spread
        FLOAT_DAYCOUNT
    )
    swap.setPricingEngine(
        ql.DiscountingSwapEngine(curve_handle)
    )
    return swap
# ============================================================
# Swaption helper factory (for HW calibration)
# ============================================================
def make_swaption_helpers(
    swaption_data: list,
    curve_handle: ql.YieldTermStructureHandle,
    index: ql.IborIndex,
    model: ql.HullWhite
) -> list:
    """
    Create ATM swaption helpers.
    swaption_data = [(expiry, tenor, vol), ...]

    Each helper gets a Jamshidian engine on `model`, so the returned list
    can be passed straight to model.calibrate(...).
    NOTE(review): the caller in the CVA main script invokes this with only
    (curve_handle, index, hw) -- the quote list must be supplied first.
    """
    helpers = []
    for expiry, tenor, vol in swaption_data:
        helper = ql.SwaptionHelper(
            expiry,
            tenor,
            ql.QuoteHandle(ql.SimpleQuote(vol)),  # Black vol quote
            index,
            index.tenor(),       # fixed-leg tenor
            index.dayCounter(),  # fixed-leg day count
            index.dayCounter(),  # floating-leg day count
            curve_handle
        )
        helper.setPricingEngine(
            ql.JamshidianSwaptionEngine(model)
        )
        helpers.append(helper)
    return helpers

109
python/cvatesting/main.py Normal file
View File

@@ -0,0 +1,109 @@
import QuantLib as ql
import numpy as np
import matplotlib.pyplot as plt
# --- UNIT 1: Instrument Setup ---
def create_30y_swap(today, rate_quote):
    """Build a 30Y payer swap (1mm notional) plus its curve handle and index.

    Returns (swap, yield_handle, index, fixing_dates): `yield_handle` is
    relinkable so the simulation loop can swap in a new curve per time step,
    and `fixing_dates` lists every Euribor fixing date over the swap's life.
    """
    ql.Settings.instance().evaluationDate = today
    # We use a Relinkable handle to swap curves per path/time-step
    yield_handle = ql.RelinkableYieldTermStructureHandle()
    yield_handle.linkTo(ql.FlatForward(today, rate_quote, ql.Actual365Fixed()))
    calendar = ql.TARGET()
    settle_date = calendar.advance(today, 2, ql.Days)
    maturity_date = calendar.advance(settle_date, 30, ql.Years)
    index = ql.Euribor6M(yield_handle)
    fixed_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Annual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)
    float_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Semiannual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)
    # Payer swap: pay fixed `rate_quote`, receive Euribor 6M flat
    swap = ql.VanillaSwap(ql.VanillaSwap.Payer, 1e6, fixed_schedule, rate_quote,
                          ql.Thirty360(ql.Thirty360.BondBasis), float_schedule,
                          index, 0.0, ql.Actual360())
    # Pre-calculate all fixing dates needed for the life of the swap
    fixing_dates = [index.fixingDate(d) for d in float_schedule]
    return swap, yield_handle, index, fixing_dates
# --- UNIT 2: The Simulation Loop ---
def run_simulation(n_paths=50):
    """Simulate Hull-White short-rate paths and revalue the 30Y swap yearly.

    For each path and each annual bucket, rebuilds a discount curve from
    the simulated short rate, injects the implied Euribor fixings, and
    stores the positive exposure max(NPV, 0).

    Returns:
        (times, npv_matrix): annual bucket times and positive exposures,
        npv_matrix shape (n_paths, len(times)).
    """
    today = ql.Date(27, 1, 2026)
    swap, yield_handle, index, fixing_dates = create_30y_swap(today, 0.03)
    # HW Parameters: a=mean reversion, sigma=vol
    model = ql.HullWhite(yield_handle, 0.01, 0.01)
    process = ql.HullWhiteProcess(yield_handle, 0.01, 0.01)
    times = np.arange(0, 31, 1.0)  # Annual buckets
    grid = ql.TimeGrid(times)
    rng = ql.GaussianRandomSequenceGenerator(ql.UniformRandomSequenceGenerator(
        len(grid)-1, ql.UniformRandomGenerator()))
    seq = ql.GaussianMultiPathGenerator(process, grid, rng, False)
    npv_matrix = np.zeros((n_paths, len(times)))
    for i in range(n_paths):
        path = seq.next().value()[0]
        # 1. Clear previous path's fixings to avoid data pollution
        ql.IndexManager.instance().clearHistories()
        for j, t in enumerate(times):
            if t >= 30: continue  # swap has matured; exposure stays zero
            eval_date = ql.TARGET().advance(today, int(t), ql.Years)
            ql.Settings.instance().evaluationDate = eval_date
            # 2. Update the curve with simulated short rate rt
            rt = path[j]
            # Use pillars up to 35Y to avoid extrapolation crashes
            tenors = [0, 1, 2, 5, 10, 20, 30, 35]
            dates = [ql.TARGET().advance(eval_date, y, ql.Years) for y in tenors]
            discounts = [model.discountBond(t, t + y, rt) for y in tenors]
            sim_curve = ql.DiscountCurve(dates, discounts, ql.Actual365Fixed())
            sim_curve.enableExtrapolation()
            yield_handle.linkTo(sim_curve)
            # 3. MANUAL FIXING INJECTION
            # For every fixing date that has passed or is 'today'
            for fd in fixing_dates:
                if fd <= eval_date:
                    # Direct manual calculation to avoid index.fixing() error
                    # We calculate the forward rate for the 6M period starting at fd
                    start_date = fd
                    end_date = ql.TARGET().advance(start_date, 6, ql.Months)
                    # Manual Euribor rate formula: (P1/P2 - 1) / dt
                    p1 = sim_curve.discount(start_date)
                    p2 = sim_curve.discount(end_date)
                    dt = ql.Actual360().yearFraction(start_date, end_date)
                    fwd_rate = (p1 / p2 - 1.0) / dt
                    index.addFixing(fd, fwd_rate, True)  # Force overwrite
            # 4. Valuation: keep the positive exposure only
            swap.setPricingEngine(ql.DiscountingSwapEngine(yield_handle))
            npv_matrix[i, j] = max(swap.NPV(), 0)
    # Restore the global evaluation date for any later use of the module
    ql.Settings.instance().evaluationDate = today
    return times, npv_matrix
# --- UNIT 3: Visualization ---
if __name__ == "__main__":
    # Guarded so that importing this module does not trigger the (slow)
    # Monte Carlo run and a blocking plt.show(); behavior as a script is unchanged.
    times, npv_matrix = run_simulation(n_paths=100)
    # Expected exposure = path average of the positive NPVs per bucket
    ee = np.mean(npv_matrix, axis=0)
    plt.figure(figsize=(10, 5))
    plt.plot(times, ee, lw=2, label="Expected Exposure (EE)")
    plt.fill_between(times, ee, alpha=0.3)
    plt.title("30Y Swap EE Profile - Hull White")
    plt.xlabel("Years"); plt.ylabel("Exposure"); plt.grid(True); plt.legend()
    plt.show()

View File

@@ -0,0 +1,42 @@
# Expression Evaluator
## Overview
Educational project teaching DAGs and state machines through a calculator.
Pure Python, no external dependencies.
## Running
```bash
python main.py "3 + 4 * 2" # single expression
python main.py # REPL mode
python main.py --show-tokens --show-ast --trace "expr" # show internals
python main.py --dot "3+4*2" | dot -Tpng -o ast.png # AST diagram
python main.py --dot-fsm | dot -Tpng -o fsm.png # FSM diagram
```
## Testing
```bash
python -m pytest tests/ -v
```
## Architecture
- `tokenizer.py` -- Explicit finite state machine (Mealy machine) tokenizer
- `parser.py` -- Recursive descent parser building an AST (DAG)
- `evaluator.py` -- Post-order tree walker (topological sort evaluation)
- `visualize.py` -- Graphviz dot generation for AST and FSM diagrams
- `main.py` -- CLI entry point with argparse, REPL mode
## Key Design Decisions
- State machine uses an explicit transition table (dict), not implicit if/else
- Unary minus resolved by examining previous token context
- Power operator (`^`) is right-associative (grammar uses right-recursion)
- AST nodes are dataclasses; evaluation uses structural pattern matching
- Graphviz output is raw dot strings (no graphviz Python package needed)
## Grammar
```
expression ::= term ((PLUS | MINUS) term)*
term ::= unary ((MULTIPLY | DIVIDE) unary)*
unary ::= UNARY_MINUS unary | power
power ::= atom (POWER power)?
atom ::= NUMBER | LPAREN expression RPAREN
```

View File

@@ -0,0 +1,87 @@
# Expression Evaluator -- DAGs & State Machines Tutorial
A calculator that teaches two fundamental CS patterns by building them from scratch:
1. **Finite State Machine** -- the tokenizer processes input character-by-character using an explicit transition table
2. **Directed Acyclic Graph (DAG)** -- the parser builds an expression tree, evaluated bottom-up in topological order
## What You'll Learn
| File | CS Concept | What it does |
|------|-----------|-------------|
| `tokenizer.py` | **State Machine** (Mealy machine) | Converts `"3 + 4 * 2"` into tokens using a transition table |
| `parser.py` | **DAG construction** | Builds an expression tree with operator precedence |
| `evaluator.py` | **Topological evaluation** | Walks the tree bottom-up (leaves before parents) |
| `visualize.py` | **Visualization** | Generates graphviz diagrams of both the FSM and AST |
## Quick Start
```bash
# Evaluate an expression
python main.py "3 + 4 * 2"
# => 11
# Interactive REPL
python main.py
# See how the state machine tokenizes
python main.py --show-tokens "(2 + 3) * -4"
# See the expression tree (DAG)
python main.py --show-ast "(2 + 3) * 4"
# *
# +-- +
# | +-- 2
# | `-- 3
# `-- 4
# Watch evaluation in topological order
python main.py --trace "(2 + 3) * 4"
# Step 1: 2 => 2
# Step 2: 3 => 3
# Step 3: 2 + 3 => 5
# Step 4: 4 => 4
# Step 5: 5 * 4 => 20
# Generate graphviz diagrams
python main.py --dot "(2 + 3) * 4" | dot -Tpng -o ast.png
python main.py --dot-fsm | dot -Tpng -o fsm.png
```
## Features
- Arithmetic: `+`, `-`, `*`, `/`, `^` (power)
- Parentheses: `(2 + 3) * 4`
- Unary minus: `-3`, `-(2 + 1)`, `2 * -3`
- Decimals: `3.14`, `.5`
- Standard precedence: parens > `^` > `*`/`/` > `+`/`-`
- Right-associative power: `2^3^4` = `2^(3^4)`
- Correct unary minus: `-3^2` = `-(3^2)` = `-9`
## Running Tests
```bash
python -m pytest tests/ -v
```
## How the State Machine Works
The tokenizer in `tokenizer.py` uses an **explicit transition table** -- a dictionary mapping `(current_state, character_class)` to `(next_state, action)`. This is the same pattern used in network protocol parsers, regex engines, and compiler lexers.
The three states are:
- `START` -- between tokens, dispatching based on the next character
- `INTEGER` -- accumulating digits (e.g., `"12"` so far)
- `DECIMAL` -- accumulating digits after a decimal point (e.g., `"12.3"`)
Use `--dot-fsm` to generate a visual diagram of the state machine.
## How the DAG Works
The parser in `parser.py` builds an **expression tree** (AST) where:
- **Leaf nodes** are numbers (no dependencies)
- **Interior nodes** are operators with edges to their operands
- **Edges** represent "depends on" relationships
Evaluation in `evaluator.py` walks this tree **bottom-up** -- children before parents. This is exactly a **topological sort** of the DAG: you can only compute a node after all its dependencies are resolved.
Use `--show-ast` to see the tree structure, or `--dot` to generate a graphviz diagram.

View File

@@ -0,0 +1,147 @@
"""
Part 3: DAG Evaluation -- Tree Walker
=======================================
Evaluating the AST bottom-up is equivalent to topological-sort
evaluation of a DAG. We must evaluate a node's children before
the node itself -- just like in any dependency graph.
For a tree, post-order traversal gives a topological ordering.
The recursive evaluate() function naturally does this:
1. Recursively evaluate all children (dependencies)
2. Combine the results (compute this node's value)
3. Return the result (make it available to the parent)
This is the same pattern as:
- make: build dependencies before the target
- pip/npm install: install dependencies before the package
- Spreadsheet recalculation: compute referenced cells first
"""
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
from tokenizer import TokenType
# ---------- Errors ----------
class EvalError(Exception):
    """Raised when a well-formed AST cannot be evaluated (e.g. division by zero)."""
    pass
# ---------- Evaluator ----------
# Human-readable symbol for each operator token; used in the trace output.
OP_SYMBOLS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
def evaluate(node):
    """
    Evaluate an AST by walking it bottom-up (post-order traversal).

    Each recursive call follows a DAG edge to a child node; children are
    always evaluated before their parent -- a topological order.
    Raises EvalError on division by zero or an unrecognized node/operator.
    """
    if isinstance(node, NumberNode):
        return node.value
    if isinstance(node, UnaryOpNode) and node.op == TokenType.UNARY_MINUS:
        return -evaluate(node.operand)
    if isinstance(node, BinOpNode):
        lhs = evaluate(node.left)
        rhs = evaluate(node.right)
        operator = node.op
        if operator == TokenType.PLUS:
            return lhs + rhs
        if operator == TokenType.MINUS:
            return lhs - rhs
        if operator == TokenType.MULTIPLY:
            return lhs * rhs
        if operator == TokenType.DIVIDE:
            if rhs == 0:
                raise EvalError("division by zero")
            return lhs / rhs
        if operator == TokenType.POWER:
            return lhs ** rhs
    raise EvalError(f"unknown node type: {type(node)}")
def evaluate_traced(node):
"""
Like evaluate(), but records each step for educational display.
Returns (result, list_of_trace_lines).
The trace shows the topological evaluation order -- how the DAG
is evaluated from leaves to root. Each step shows a node being
evaluated after all its dependencies are resolved.
"""
steps = []
counter = [0] # mutable counter for step numbering
def _walk(node, depth):
indent = " " * depth
counter[0] += 1
step = counter[0]
match node:
case NumberNode(value=v):
result = v
display = _format_number(v)
steps.append(f"{indent}Step {step}: {display} => {_format_number(result)}")
return result
case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
child_val = _walk(child, depth + 1)
result = -child_val
counter[0] += 1
step = counter[0]
steps.append(
f"{indent}Step {step}: neg({_format_number(child_val)}) "
f"=> {_format_number(result)}"
)
return result
case BinOpNode(op=op, left=left, right=right):
left_val = _walk(left, depth + 1)
right_val = _walk(right, depth + 1)
sym = OP_SYMBOLS[op]
match op:
case TokenType.PLUS:
result = left_val + right_val
case TokenType.MINUS:
result = left_val - right_val
case TokenType.MULTIPLY:
result = left_val * right_val
case TokenType.DIVIDE:
if right_val == 0:
raise EvalError("division by zero")
result = left_val / right_val
case TokenType.POWER:
result = left_val ** right_val
counter[0] += 1
step = counter[0]
steps.append(
f"{indent}Step {step}: {_format_number(left_val)} {sym} "
f"{_format_number(right_val)} => {_format_number(result)}"
)
return result
raise EvalError(f"unknown node type: {type(node)}")
result = _walk(node, 0)
return result, steps
def _format_number(v):
"""Display a number as integer when possible."""
if isinstance(v, float) and v == int(v):
return str(int(v))
return str(v)

View File

@@ -0,0 +1,163 @@
"""
Expression Evaluator -- Learn DAGs & State Machines
====================================================
CLI entry point and interactive REPL.
Usage:
python main.py "3 + 4 * 2" # evaluate
python main.py # REPL mode
python main.py --show-tokens --show-ast --trace "expr" # show internals
python main.py --dot "3 + 4 * 2" | dot -Tpng -o ast.png
python main.py --dot-fsm | dot -Tpng -o fsm.png
"""
import argparse
import sys
from tokenizer import tokenize, TokenError
from parser import Parser, ParseError
from evaluator import evaluate, evaluate_traced, EvalError
from visualize import ast_to_dot, fsm_to_dot, ast_to_text
def process_expression(expr, args):
    """Run one expression through the full pipeline: tokens -> AST -> value.

    A failure at any stage is reported to stdout and aborts processing.
    With --dot the graphviz text replaces the numeric result entirely.
    """
    try:
        tokens = tokenize(expr)
    except TokenError as e:
        _print_error(expr, e)
        return

    if args.show_tokens:
        print("\nTokens:")
        for tok in tokens:
            print(f" {tok}")

    try:
        ast = Parser(tokens).parse()
    except ParseError as e:
        _print_error(expr, e)
        return

    if args.show_ast:
        print("\nAST (text tree):")
        print(ast_to_text(ast))

    if args.dot:
        # dot output goes to stdout; skip the numeric result
        print(ast_to_dot(ast))
        return

    if not args.trace:
        try:
            result = evaluate(ast)
        except EvalError as e:
            print(f"Eval error: {e}")
            return
        print(_format_result(result))
        return

    try:
        result, steps = evaluate_traced(ast)
    except EvalError as e:
        print(f"Eval error: {e}")
        return
    print("\nEvaluation trace (topological order):")
    for step in steps:
        print(step)
    print(f"\nResult: {_format_result(result)}")
def repl(args):
    """Interactive read-eval-print loop; ends on EOF, Ctrl-C, or quit/exit/q."""
    print("Expression Evaluator REPL")
    print("Type an expression, or 'quit' to exit.")
    active = [
        flag
        for flag, enabled in (
            ("--show-tokens", args.show_tokens),
            ("--show-ast", args.show_ast),
            ("--trace", args.trace),
        )
        if enabled
    ]
    if active:
        print(f"Active flags: {' '.join(active)}")
    print()
    while True:
        try:
            line = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            print()
            break
        if line.lower() in ("quit", "exit", "q"):
            break
        if not line:
            continue
        process_expression(line, args)
        print()
def _print_error(expr, error):
"""Print an error with a caret pointing to the position."""
print(f"Error: {error}")
if hasattr(error, 'position') and error.position is not None:
print(f" {expr}")
print(f" {' ' * error.position}^")
def _format_result(v):
"""Format a numeric result: show as int when possible."""
if isinstance(v, float) and v == int(v) and abs(v) < 1e15:
return str(int(v))
return str(v)
def main() -> None:
    """CLI entry point: parse flags, then dispatch to FSM dump, REPL, or one-shot eval."""
    arg_parser = argparse.ArgumentParser(
        description="Expression Evaluator -- learn DAGs and state machines",
        epilog="Examples:\n"
        " python main.py '3 + 4 * 2'\n"
        " python main.py --show-tokens --trace '-(3 + 4) ^ 2'\n"
        " python main.py --dot '(2+3)*4' | dot -Tpng -o ast.png\n"
        " python main.py --dot-fsm | dot -Tpng -o fsm.png",
        # RawDescriptionHelpFormatter keeps the epilog's manual line breaks intact.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "expression", nargs="?",
        help="Expression to evaluate (omit for REPL mode)",
    )
    arg_parser.add_argument(
        "--show-tokens", action="store_true",
        help="Display tokenizer output",
    )
    arg_parser.add_argument(
        "--show-ast", action="store_true",
        help="Display AST as indented text tree",
    )
    arg_parser.add_argument(
        "--trace", action="store_true",
        help="Show step-by-step evaluation trace",
    )
    arg_parser.add_argument(
        "--dot", action="store_true",
        help="Output AST as graphviz dot (pipe to: dot -Tpng -o ast.png)",
    )
    arg_parser.add_argument(
        "--dot-fsm", action="store_true",
        help="Output tokenizer FSM as graphviz dot",
    )
    args = arg_parser.parse_args()
    # Special mode: just print the FSM diagram and exit
    # (takes precedence over any expression argument).
    if args.dot_fsm:
        print(fsm_to_dot())
        return
    # REPL mode if no expression given
    if args.expression is None:
        repl(args)
    else:
        process_expression(args.expression, args)


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,217 @@
"""
Part 2: DAG Construction -- Recursive Descent Parser
=====================================================
A parser converts a flat list of tokens into a tree structure (AST).
The AST is a DAG (Directed Acyclic Graph) where:
- Nodes are operations (BinOpNode) or values (NumberNode)
- Edges point from parent operations to their operands
- The graph is acyclic because an operation's inputs are always
"simpler" sub-expressions (no circular dependencies)
- It is a tree (a special case of DAG) because no node is shared
This is the same structure as:
- Spreadsheet dependency graphs (cell A1 depends on B1, B2...)
- Build systems (Makefile targets depend on other targets)
- Task scheduling (some tasks must finish before others start)
- Neural network computation graphs (forward pass is a DAG)
Key DAG concepts demonstrated:
- Nodes: operations and values
- Directed edges: from operation to its inputs (dependencies)
- Acyclic: no circular dependencies
- Topological ordering: natural evaluation order (leaves first)
Grammar (BNF) -- precedence is encoded by nesting depth:
expression ::= term ((PLUS | MINUS) term)* # lowest precedence
term ::= unary ((MULTIPLY | DIVIDE) unary)*
unary ::= UNARY_MINUS unary | power
power ::= atom (POWER power)? # right-associative
atom ::= NUMBER | LPAREN expression RPAREN # highest precedence
Call chain: expression -> term -> unary -> power -> atom
This means: +/- binds loosest, then *//, then unary -, then ^, then parens.
So -3^2 = -(3^2) = -9, matching standard math convention.
"""
from dataclasses import dataclass
from tokenizer import Token, TokenType
# ---------- AST node types ----------
# These are the nodes of our DAG. Each node is either a leaf (NumberNode)
# or an interior node with edges pointing to its children (operands).
@dataclass
class NumberNode:
    """Leaf of the AST: a numeric literal (a DAG node with no outgoing edges)."""
    value: float

    def __repr__(self):
        shown = int(self.value) if self.value == int(self.value) else self.value
        return f"NumberNode({shown})"
@dataclass
class BinOpNode:
    """
    Interior AST node: a binary operation with two children.

    The two DAG edges (this -> left, this -> right) mean "depends on":
    both operands must be computed before this node's value can be.
    """
    op: TokenType
    left: 'NumberNode | BinOpNode | UnaryOpNode'
    right: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        return "BinOpNode({}, {}, {})".format(self.op.name, self.left, self.right)
@dataclass
class UnaryOpNode:
    """Interior AST node: a unary operation (negation) with a single child."""
    op: TokenType
    operand: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        return "UnaryOpNode({}, {})".format(self.op.name, self.operand)
# Union type for any AST node
# (handy for annotating functions that accept an arbitrary subtree).
Node = NumberNode | BinOpNode | UnaryOpNode
# ---------- Errors ----------
class ParseError(Exception):
    """Raised by Parser when the token stream violates the grammar."""

    def __init__(self, message, position=None):
        self.position = position
        suffix = "" if position is None else f" at position {position}"
        super().__init__(f"Parse error{suffix}: {message}")
# ---------- Recursive descent parser ----------
class Parser:
    """
    Recursive descent parser: turns a token list into an AST.

    One method per grammar rule; the call hierarchy mirrors the tree being
    built. Operator precedence falls out of the nesting order -- rules
    reached later in the call chain bind tighter and therefore end up
    deeper in the resulting tree:

        expression ::= term ((PLUS | MINUS) term)*
        term       ::= unary ((MULTIPLY | DIVIDE) unary)*
        unary      ::= UNARY_MINUS unary | power
        power      ::= atom (POWER power)?
        atom       ::= NUMBER | LPAREN expression RPAREN
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    def peek(self):
        """Return the current token without advancing."""
        return self.tokens[self.pos]

    def consume(self, expected=None):
        """Advance past the current token, optionally checking its type."""
        tok = self.tokens[self.pos]
        if expected is not None and tok.type != expected:
            raise ParseError(
                f"expected {expected.name}, got {tok.type.name}",
                tok.position,
            )
        self.pos += 1
        return tok

    def parse(self):
        """Parse one complete expression; every token must be consumed."""
        if self.peek().type == TokenType.EOF:
            raise ParseError("empty expression")
        root = self.expression()
        self.consume(TokenType.EOF)
        return root

    # --- Grammar rules, lowest to highest precedence ---

    def expression(self):
        """expression ::= term ((PLUS | MINUS) term)*  -- left-associative."""
        node = self.term()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            # Each BinOpNode adds two DAG edges: to the left subtree built
            # so far and to the freshly parsed right operand.
            operator = self.consume().type
            node = BinOpNode(operator, node, self.term())
        return node

    def term(self):
        """term ::= unary ((MULTIPLY | DIVIDE) unary)*  -- left-associative."""
        node = self.unary()
        while self.peek().type in (TokenType.MULTIPLY, TokenType.DIVIDE):
            operator = self.consume().type
            node = BinOpNode(operator, node, self.unary())
        return node

    def unary(self):
        """
        unary ::= UNARY_MINUS unary | power

        Sits between term and power, so negation binds looser than ^ but
        tighter than * and /: -3^2 == -(3^2) == -9, the standard math
        convention. Self-recursion makes stacked negations (--3) work.
        """
        if self.peek().type != TokenType.UNARY_MINUS:
            return self.power()
        operator = self.consume().type
        return UnaryOpNode(operator, self.unary())

    def power(self):
        """
        power ::= atom (POWER power)?

        Recursing on the right-hand side (instead of looping) makes ^
        right-associative: 2^3^4 == 2^(3^4).
        """
        base = self.atom()
        if self.peek().type != TokenType.POWER:
            return base
        operator = self.consume().type
        return BinOpNode(operator, base, self.power())

    def atom(self):
        """
        atom ::= NUMBER | LPAREN expression RPAREN

        The base case. A parenthesized group recurses back into
        expression(), restarting the precedence climb from the top.
        """
        tok = self.peek()
        if tok.type == TokenType.NUMBER:
            self.consume()
            return NumberNode(float(tok.value))
        if tok.type == TokenType.LPAREN:
            self.consume()
            inner = self.expression()
            self.consume(TokenType.RPAREN)
            return inner
        raise ParseError(
            f"expected number or '(', got {tok.type.name}",
            tok.position,
        )

View File

@@ -0,0 +1,120 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize
from parser import Parser
from evaluator import evaluate, evaluate_traced, EvalError
def eval_expr(expr):
    """Helper: tokenize -> parse -> evaluate in one step."""
    tokens = tokenize(expr)
    ast = Parser(tokens).parse()
    return evaluate(ast)

# ---------- Basic arithmetic ----------

def test_addition():
    assert eval_expr("3 + 4") == 7.0

def test_subtraction():
    assert eval_expr("10 - 3") == 7.0

def test_multiplication():
    assert eval_expr("3 * 4") == 12.0

def test_division():
    assert eval_expr("10 / 4") == 2.5

def test_power():
    assert eval_expr("2 ^ 10") == 1024.0

# ---------- Precedence ----------

def test_standard_precedence():
    # * binds tighter than +: 3 + (4 * 2)
    assert eval_expr("3 + 4 * 2") == 11.0

def test_parentheses():
    assert eval_expr("(3 + 4) * 2") == 14.0

def test_power_precedence():
    # ^ binds tighter than *: 2 * (3 ^ 2)
    assert eval_expr("2 * 3 ^ 2") == 18.0

def test_right_associative_power():
    # 2^(2^3) = 2^8 = 256
    assert eval_expr("2 ^ 2 ^ 3") == 256.0

# ---------- Unary minus ----------

def test_negation():
    assert eval_expr("-5") == -5.0

def test_double_negation():
    assert eval_expr("--5") == 5.0

def test_negation_with_power():
    # -(3^2) = -9, not (-3)^2 = 9
    assert eval_expr("-3 ^ 2") == -9.0

def test_negation_in_parens():
    assert eval_expr("(-3) ^ 2") == 9.0

# ---------- Decimals ----------

def test_decimal_addition():
    # Classic float rounding case -- compare with a tolerance.
    assert eval_expr("0.1 + 0.2") == pytest.approx(0.3)

def test_leading_dot():
    assert eval_expr(".5 + .5") == 1.0

# ---------- Edge cases ----------

def test_nested_parens():
    assert eval_expr("((((3))))") == 3.0

def test_complex_expression():
    assert eval_expr("(2 + 3) * (7 - 2) / 5 ^ 1") == 5.0

def test_long_chain():
    assert eval_expr("1 + 2 + 3 + 4 + 5") == 15.0

def test_mixed_operations():
    assert eval_expr("2 + 3 * 4 - 6 / 2") == 11.0

# ---------- Division by zero ----------

def test_division_by_zero():
    with pytest.raises(EvalError):
        eval_expr("1 / 0")

def test_division_by_zero_in_expression():
    with pytest.raises(EvalError):
        eval_expr("5 + 3 / (2 - 2)")

# ---------- Traced evaluation ----------

def test_traced_returns_correct_result():
    tokens = tokenize("3 + 4 * 2")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 11.0
    assert len(steps) > 0

def test_traced_step_count():
    """A simple binary op has 3 evaluation events: left, right, combine."""
    tokens = tokenize("3 + 4")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 7.0
    # NumberNode(3), NumberNode(4), BinOp(+)
    assert len(steps) == 3

View File

@@ -0,0 +1,136 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize, TokenType
from parser import Parser, ParseError, NumberNode, BinOpNode, UnaryOpNode
def parse(expr):
    """Helper: tokenize and parse in one step."""
    return Parser(tokenize(expr)).parse()

# ---------- Basic parsing ----------

def test_parse_number():
    ast = parse("42")
    assert isinstance(ast, NumberNode)
    assert ast.value == 42.0

def test_parse_decimal():
    ast = parse("3.14")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.14

def test_parse_addition():
    ast = parse("3 + 4")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, NumberNode)

# ---------- Precedence ----------

def test_multiply_before_add():
    """3 + 4 * 2 should parse as 3 + (4 * 2)."""
    ast = parse("3 + 4 * 2")
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.MULTIPLY

def test_power_before_multiply():
    """2 * 3 ^ 4 should parse as 2 * (3 ^ 4)."""
    ast = parse("2 * 3 ^ 4")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER

def test_parentheses_override_precedence():
    """(3 + 4) * 2 should parse as (3 + 4) * 2."""
    ast = parse("(3 + 4) * 2")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.PLUS

# ---------- Associativity ----------

def test_left_associative_subtraction():
    """10 - 3 - 2 should parse as (10 - 3) - 2."""
    ast = parse("10 - 3 - 2")
    assert ast.op == TokenType.MINUS
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.MINUS
    assert isinstance(ast.right, NumberNode)

def test_power_right_associative():
    """2 ^ 3 ^ 4 should parse as 2 ^ (3 ^ 4)."""
    ast = parse("2 ^ 3 ^ 4")
    assert ast.op == TokenType.POWER
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER

# ---------- Unary minus ----------

def test_unary_minus():
    ast = parse("-3")
    assert isinstance(ast, UnaryOpNode)
    assert ast.operand.value == 3.0

def test_double_negation():
    ast = parse("--3")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, UnaryOpNode)
    assert ast.operand.operand.value == 3.0

def test_unary_minus_precedence():
    """-3^2 should parse as -(3^2), not (-3)^2."""
    ast = parse("-3 ^ 2")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, BinOpNode)
    assert ast.operand.op == TokenType.POWER

def test_unary_minus_in_expression():
    """2 * -3 should parse as 2 * (-(3))."""
    ast = parse("2 * -3")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, UnaryOpNode)

# ---------- Nested parentheses ----------

def test_nested_parens():
    # Parens add no AST nodes -- atom() returns the inner node directly.
    ast = parse("((3))")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.0

def test_complex_nesting():
    """((2 + 3) * (7 - 2))"""
    ast = parse("((2 + 3) * (7 - 2))")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.MULTIPLY

# ---------- Errors ----------

def test_missing_rparen():
    with pytest.raises(ParseError):
        parse("(3 + 4")

def test_empty_expression():
    with pytest.raises(ParseError):
        parse("")

def test_trailing_operator():
    with pytest.raises(ParseError):
        parse("3 +")

def test_empty_parens():
    with pytest.raises(ParseError):
        parse("()")

View File

@@ -0,0 +1,139 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize, TokenType, Token, TokenError
# ---------- Basic tokens ----------
def test_single_integer():
    tokens = tokenize("42")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "42"

def test_decimal_number():
    tokens = tokenize("3.14")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "3.14"

def test_leading_dot():
    # Token value keeps the raw lexeme; float() conversion happens in the parser.
    tokens = tokenize(".5")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == ".5"

def test_all_operators():
    """Operators between numbers are all binary."""
    tokens = tokenize("1 + 1 - 1 * 1 / 1 ^ 1")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]

def test_operators_between_numbers():
    tokens = tokenize("1 + 2 - 3 * 4 / 5 ^ 6")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]

def test_parentheses():
    tokens = tokenize("()")
    assert tokens[0].type == TokenType.LPAREN
    assert tokens[1].type == TokenType.RPAREN

# ---------- Unary minus ----------

def test_unary_minus_at_start():
    tokens = tokenize("-3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.NUMBER

def test_unary_minus_after_lparen():
    tokens = tokenize("(-3)")
    assert tokens[1].type == TokenType.UNARY_MINUS

def test_unary_minus_after_operator():
    tokens = tokenize("2 * -3")
    assert tokens[2].type == TokenType.UNARY_MINUS

def test_binary_minus():
    tokens = tokenize("5 - 3")
    assert tokens[1].type == TokenType.MINUS

def test_double_unary_minus():
    # UNARY_MINUS itself is a valid predecessor for another unary minus.
    tokens = tokenize("--3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.UNARY_MINUS
    assert tokens[2].type == TokenType.NUMBER

# ---------- Whitespace handling ----------

def test_no_spaces():
    tokens = tokenize("3+4")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3

def test_extra_spaces():
    tokens = tokenize(" 3 + 4 ")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3

# ---------- Position tracking ----------

def test_positions():
    tokens = tokenize("3 + 4")
    assert tokens[0].position == 0  # '3'
    assert tokens[1].position == 2  # '+'
    assert tokens[2].position == 4  # '4'

# ---------- Errors ----------

def test_invalid_character():
    with pytest.raises(TokenError):
        tokenize("3 & 4")

def test_double_dot():
    # A second '.' while already in DECIMAL state hits Action.ERROR.
    with pytest.raises(TokenError):
        tokenize("3.14.15")

# ---------- EOF token ----------

def test_eof_always_present():
    tokens = tokenize("42")
    assert tokens[-1].type == TokenType.EOF

def test_empty_input():
    tokens = tokenize("")
    assert len(tokens) == 1
    assert tokens[0].type == TokenType.EOF

# ---------- Complex expressions ----------

def test_complex_expression():
    tokens = tokenize("(3 + 4.5) * -2 ^ 3")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.PLUS,
        TokenType.NUMBER, TokenType.RPAREN, TokenType.MULTIPLY,
        TokenType.UNARY_MINUS, TokenType.NUMBER, TokenType.POWER,
        TokenType.NUMBER,
    ]

def test_adjacent_parens():
    # No implicit multiplication -- the tokenizer just emits both pairs.
    tokens = tokenize("(3)(4)")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
    ]

View File

@@ -0,0 +1,306 @@
"""
Part 1: State Machine Tokenizer
================================
A tokenizer (lexer) converts raw text into a stream of tokens.
This implementation uses an EXPLICIT finite state machine (FSM):
- States are named values (an enum), not implicit control flow
- A transition table maps (current_state, input_class) -> (next_state, action)
- The main loop reads one character at a time and consults the table
This is the same pattern used in:
- Network protocol parsers (HTTP, TCP state machines)
- Regular expression engines
- Compiler front-ends (lexers for C, Python, etc.)
- Game AI (enemy behavior states)
Key FSM concepts demonstrated:
- States: the "memory" of what we're currently building
- Transitions: rules for moving between states based on input
- Actions: side effects (emit a token, accumulate a character)
- Mealy machine: outputs depend on both state AND input
"""
from dataclasses import dataclass
from enum import Enum
# ---------- Token types ----------
class TokenType(Enum):
    """Token categories emitted by tokenize().

    UNARY_MINUS never comes straight out of the FSM -- a post-pass
    (_resolve_unary_minus) upgrades MINUS tokens found in prefix position.
    EOF terminates every token stream, even for empty input.
    """
    NUMBER = "NUMBER"
    PLUS = "PLUS"
    MINUS = "MINUS"
    MULTIPLY = "MULTIPLY"
    DIVIDE = "DIVIDE"
    POWER = "POWER"
    LPAREN = "LPAREN"
    RPAREN = "RPAREN"
    UNARY_MINUS = "UNARY_MINUS"
    EOF = "EOF"
@dataclass
class Token:
    """One lexical unit produced by tokenize()."""
    type: TokenType
    value: str      # raw text: "42", "+", "(", etc.
    position: int   # character offset in original expression

    def __repr__(self):
        return "Token({}, {!r}, pos={})".format(self.type.name, self.value, self.position)
# Single-character binary operators mapped to the token type each one emits.
OPERATOR_MAP = {
    '+': TokenType.PLUS,
    '-': TokenType.MINUS,
    '*': TokenType.MULTIPLY,
    '/': TokenType.DIVIDE,
    '^': TokenType.POWER,
}
# ---------- FSM state definitions ----------
class State(Enum):
    """
    The tokenizer's finite set of states.
    START -- idle / between tokens, deciding what comes next
    INTEGER -- accumulating digits of an integer (e.g. "12" so far)
    DECIMAL -- accumulating digits after a decimal point (e.g. "12.3" so far)

    Every emit transition in TRANSITIONS returns the machine to START.
    """
    START = "START"
    INTEGER = "INTEGER"
    DECIMAL = "DECIMAL"
class CharClass(Enum):
    """
    Character classification -- groups raw characters into categories
    so the transition table stays small and readable.
    """
    DIGIT = "DIGIT"        # 0-9
    DOT = "DOT"            # '.'
    OPERATOR = "OPERATOR"  # any key of OPERATOR_MAP: + - * / ^
    LPAREN = "LPAREN"      # '('
    RPAREN = "RPAREN"      # ')'
    SPACE = "SPACE"        # any str.isspace() character
    EOF = "EOF"            # synthetic end-of-input marker
    UNKNOWN = "UNKNOWN"    # anything else -> TokenError in tokenize()
class Action(Enum):
    """
    What the FSM does on a transition. In a Mealy machine, the output
    (action) depends on both the current state AND the input.
    """
    ACCUMULATE = "ACCUMULATE"                            # append char to the number buffer
    EMIT_NUMBER = "EMIT_NUMBER"                          # flush the buffer as a NUMBER token
    EMIT_OPERATOR = "EMIT_OPERATOR"                      # emit the current operator char
    EMIT_LPAREN = "EMIT_LPAREN"                          # emit '('
    EMIT_RPAREN = "EMIT_RPAREN"                          # emit ')'
    EMIT_NUMBER_THEN_OP = "EMIT_NUMBER_THEN_OP"          # flush buffer, then emit operator
    EMIT_NUMBER_THEN_LPAREN = "EMIT_NUMBER_THEN_LPAREN"  # flush buffer, then emit '('
    EMIT_NUMBER_THEN_RPAREN = "EMIT_NUMBER_THEN_RPAREN"  # flush buffer, then emit ')'
    EMIT_NUMBER_THEN_DONE = "EMIT_NUMBER_THEN_DONE"      # flush buffer at end of input
    SKIP = "SKIP"                                        # ignore (whitespace)
    DONE = "DONE"                                        # end of input, nothing buffered
    ERROR = "ERROR"                                      # invalid sequence -> TokenError
@dataclass(frozen=True)
class Transition:
    """One transition-table entry: where to go next and what side effect to run."""
    next_state: State
    action: Action
# ---------- Transition table ----------
# This is the heart of the state machine. Every (state, char_class) pair
# maps to exactly one transition: a next state and an action to perform.
# Making this a data structure (not nested if/else) means we can:
# 1. Inspect it programmatically (e.g. to generate a diagram)
# 2. Verify completeness (every combination is covered)
# 3. Understand the FSM at a glance
# (CharClass.UNKNOWN never reaches the table -- tokenize() rejects it first.)
TRANSITIONS = {
    # --- START: between tokens, dispatch based on character class ---
    (State.START, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.START, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.START, CharClass.OPERATOR): Transition(State.START, Action.EMIT_OPERATOR),
    (State.START, CharClass.LPAREN): Transition(State.START, Action.EMIT_LPAREN),
    (State.START, CharClass.RPAREN): Transition(State.START, Action.EMIT_RPAREN),
    (State.START, CharClass.SPACE): Transition(State.START, Action.SKIP),
    (State.START, CharClass.EOF): Transition(State.START, Action.DONE),
    # --- INTEGER: accumulating digits like "123" ---
    (State.INTEGER, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.INTEGER, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.INTEGER, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.INTEGER, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.INTEGER, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.INTEGER, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.INTEGER, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
    # --- DECIMAL: accumulating digits after "." like "123.45" ---
    (State.DECIMAL, CharClass.DIGIT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.DECIMAL, CharClass.DOT): Transition(State.START, Action.ERROR),
    (State.DECIMAL, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.DECIMAL, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.DECIMAL, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.DECIMAL, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.DECIMAL, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
}
# ---------- Errors ----------
class TokenError(Exception):
    """Raised when the input contains a character or sequence no token matches."""

    def __init__(self, message, position):
        self.position = position
        super().__init__("Token error at position {}: {}".format(position, message))
# ---------- Character classification ----------
def classify(ch):
    """Bucket a single character into the CharClass the FSM dispatches on."""
    if ch.isdigit():
        return CharClass.DIGIT
    if ch.isspace():
        return CharClass.SPACE
    if ch in OPERATOR_MAP:
        return CharClass.OPERATOR
    fixed = {'.': CharClass.DOT, '(': CharClass.LPAREN, ')': CharClass.RPAREN}
    return fixed.get(ch, CharClass.UNKNOWN)
# ---------- Main tokenize function ----------
def tokenize(expression):
    """
    Run *expression* through the state machine and return its tokens.

    Main loop, per character:
      1. Classify the character
      2. Look up (state, char_class) in TRANSITIONS
      3. Execute the action (accumulate, emit, skip, ...)
      4. Move to the next state

    After the scan, a post-pass (_resolve_unary_minus) reclassifies '-'
    tokens that appear in prefix position, and an EOF token is appended.

    Raises TokenError for characters the FSM cannot accept, and for a
    bare '.' with no digits -- previously that slipped through as the
    NUMBER token "." and crashed float() later in the parser.
    """
    state = State.START
    buffer = []       # characters of the NUMBER token being built
    buffer_start = 0  # position where the current buffer started
    tokens = []
    pos = 0
    # Append a sentinel so EOF is handled uniformly in the loop
    chars = expression + '\0'

    def flush_number():
        # Emit the buffered NUMBER token. A buffer of just "." is not a
        # number; reject it here instead of letting float() fail downstream.
        text = ''.join(buffer)
        if text == '.':
            raise TokenError("invalid number '.'", buffer_start)
        tokens.append(Token(TokenType.NUMBER, text, buffer_start))
        buffer.clear()

    while pos <= len(expression):
        ch = chars[pos]
        char_class = CharClass.EOF if pos == len(expression) else classify(ch)
        if char_class == CharClass.UNKNOWN:
            raise TokenError(f"unexpected character {ch!r}", pos)
        # Look up the transition
        transition = TRANSITIONS.get((state, char_class))
        if transition is None:
            raise TokenError(f"no transition for state={state.name}, input={char_class.name}", pos)
        action = transition.action
        # --- Execute the action ---
        if action == Action.ACCUMULATE:
            if not buffer:
                buffer_start = pos
            buffer.append(ch)
        elif action == Action.EMIT_NUMBER:
            flush_number()
        elif action == Action.EMIT_OPERATOR:
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))
        elif action == Action.EMIT_LPAREN:
            tokens.append(Token(TokenType.LPAREN, ch, pos))
        elif action == Action.EMIT_RPAREN:
            tokens.append(Token(TokenType.RPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_OP:
            flush_number()
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_LPAREN:
            flush_number()
            tokens.append(Token(TokenType.LPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_RPAREN:
            flush_number()
            tokens.append(Token(TokenType.RPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_DONE:
            flush_number()
        elif action in (Action.SKIP, Action.DONE):
            pass
        elif action == Action.ERROR:
            raise TokenError(f"unexpected {ch!r} in state {state.name}", pos)
        state = transition.next_state
        pos += 1

    # --- Post-processing: resolve unary minus ---
    # Whether '-' is unary or binary depends on the PREVIOUS TOKEN --
    # context the FSM alone does not track.
    _resolve_unary_minus(tokens)
    tokens.append(Token(TokenType.EOF, '', len(expression)))
    return tokens
def _resolve_unary_minus(tokens):
    """
    Rewrite MINUS tokens as UNARY_MINUS where context makes them prefix ops.

    A '-' is unary when it opens the stream or directly follows another
    operator, a '(', or another unary minus. The FSM cannot decide this:
    it only knows what token it is currently building, not what it last
    emitted. A lightweight post-pass supplies that context -- a common
    real-world lexer pattern.
    """
    prefix_context = frozenset({
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER, TokenType.LPAREN,
        TokenType.UNARY_MINUS,
    })
    previous = None
    for i, token in enumerate(tokens):
        is_unary = token.type == TokenType.MINUS and (
            previous is None or previous.type in prefix_context
        )
        if is_unary:
            token = Token(TokenType.UNARY_MINUS, token.value, token.position)
            tokens[i] = token
        previous = token

View File

@@ -0,0 +1,200 @@
"""
Part 4: Visualization -- Graphviz Dot Output
==============================================
Generate graphviz dot-format strings for:
1. The tokenizer's finite state machine (FSM)
2. Any expression's AST (DAG)
3. Text-based tree rendering for the terminal
No external dependencies -- outputs raw dot strings that can be piped
to the 'dot' command: python main.py --dot "3+4*2" | dot -Tpng -o ast.png
"""
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
from tokenizer import TRANSITIONS, State, CharClass, Action, TokenType
# ---------- FSM diagram ----------
# Human-readable labels for character classes
# (left half of each edge label: "<char> / <action>").
_CHAR_LABELS = {
    CharClass.DIGIT: "digit",
    CharClass.DOT: "'.'",
    CharClass.OPERATOR: "op",
    CharClass.LPAREN: "'('",
    CharClass.RPAREN: "')'",
    CharClass.SPACE: "space",
    CharClass.EOF: "EOF",
}
# Short labels for actions (right half of each edge label)
_ACTION_LABELS = {
    Action.ACCUMULATE: "accum",
    Action.EMIT_NUMBER: "emit num",
    Action.EMIT_OPERATOR: "emit op",
    Action.EMIT_LPAREN: "emit '('",
    Action.EMIT_RPAREN: "emit ')'",
    Action.EMIT_NUMBER_THEN_OP: "emit num+op",
    Action.EMIT_NUMBER_THEN_LPAREN: "emit num+'('",
    Action.EMIT_NUMBER_THEN_RPAREN: "emit num+')'",
    Action.EMIT_NUMBER_THEN_DONE: "emit num, done",
    Action.SKIP: "skip",
    Action.DONE: "done",
    Action.ERROR: "ERROR",
}
def fsm_to_dot():
    """
    Render the tokenizer's state machine as a graphviz dot string.

    Because TRANSITIONS is plain data (a dict), the diagram is produced by
    inspecting it directly -- an advantage explicit state machines have
    over control-flow-encoded ones.
    """
    header = [
        'digraph FSM {',
        ' rankdir=LR;',
        ' node [shape=circle, fontname="Helvetica"];',
        ' edge [fontname="Helvetica", fontsize=10];',
        '',
        ' // Start indicator',
        ' __start__ [shape=point, width=0.2];',
        ' __start__ -> START;',
        '',
    ]
    # Group labels by (source, destination) so parallel transitions
    # collapse into one arrow carrying a multi-line label.
    grouped = {}
    for (state, char_class), transition in TRANSITIONS.items():
        key = (state.name, transition.next_state.name)
        char_label = _CHAR_LABELS.get(char_class, char_class.name)
        action_label = _ACTION_LABELS.get(transition.action, transition.action.name)
        grouped.setdefault(key, []).append(f"{char_label} / {action_label}")
    body = []
    for (src, dst), labels in sorted(grouped.items()):
        merged = "\\n".join(labels)
        body.append(f' {src} -> {dst} [label="{merged}"];')
    return '\n'.join(header + body + ['}'])
# ---------- AST diagram ----------
# Display symbol for each operator token type (lookups fall back to .name).
_OP_LABELS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
def ast_to_dot(node):
"""
Generate a graphviz dot diagram of an AST (expression tree / DAG).
Each node gets a unique ID. Edges go from parent to children,
showing the directed acyclic structure. Leaves are boxed,
operators are ellipses.
"""
lines = [
'digraph AST {',
' node [fontname="Helvetica"];',
' edge [fontname="Helvetica"];',
'',
]
counter = [0]
def _visit(node):
nid = f"n{counter[0]}"
counter[0] += 1
match node:
case NumberNode(value=v):
label = _format_number(v)
lines.append(f' {nid} [label="{label}", shape=box, style=rounded];')
return nid
case UnaryOpNode(op=op, operand=child):
label = _OP_LABELS.get(op, op.name)
lines.append(f' {nid} [label="{label}", shape=ellipse];')
child_id = _visit(child)
lines.append(f' {nid} -> {child_id};')
return nid
case BinOpNode(op=op, left=left, right=right):
label = _OP_LABELS.get(op, op.name)
lines.append(f' {nid} [label="{label}", shape=ellipse];')
left_id = _visit(left)
right_id = _visit(right)
lines.append(f' {nid} -> {left_id} [label="L"];')
lines.append(f' {nid} -> {right_id} [label="R"];')
return nid
_visit(node)
lines.append('}')
return '\n'.join(lines)
# ---------- Text-based tree ----------
def ast_to_text(node, prefix="", connector=""):
"""
Render the AST as an indented text tree for terminal display.
Example output for (2 + 3) * 4:
*
+-- +
| +-- 2
| +-- 3
+-- 4
"""
match node:
case NumberNode(value=v):
label = _format_number(v)
case UnaryOpNode(op=op):
label = _OP_LABELS.get(op, op.name)
case BinOpNode(op=op):
label = _OP_LABELS.get(op, op.name)
lines = [f"{prefix}{connector}{label}"]
children = _get_children(node)
for i, child in enumerate(children):
is_last_child = (i == len(children) - 1)
if connector:
# Extend the prefix: if we used "+-- " then next children
# see "| " (continuing) or " " (last child)
child_prefix = prefix + ("| " if connector == "+-- " else " ")
else:
child_prefix = prefix
child_connector = "+-- " if is_last_child else "+-- "
# Use a different lead for non-last: the vertical bar continues
child_connector = "`-- " if is_last_child else "+-- "
child_lines = ast_to_text(child, child_prefix, child_connector)
lines.append(child_lines)
return '\n'.join(lines)
def _get_children(node):
match node:
case NumberNode():
return []
case UnaryOpNode(operand=child):
return [child]
case BinOpNode(left=left, right=right):
return [left, right]
return []
def _format_number(v):
if isinstance(v, float) and v == int(v):
return str(int(v))
return str(v)

View File

@@ -0,0 +1,38 @@
# Persian Language Tutor
## Overview
Gradio-based Persian (Farsi) language learning app for English speakers, using GCSE Persian vocabulary (Pearson spec) as seed data.
## Tech Stack
- **Frontend**: Gradio (browser handles RTL natively)
- **Spaced repetition**: py-fsrs (same algorithm as Anki)
- **AI**: Ollama (fast, local) + Claude CLI (smart, subprocess)
- **STT**: faster-whisper via sttlib from tool-speechtotext
- **Anki export**: genanki for .apkg generation
- **Database**: SQLite (file-based, data/progress.db)
- **Environment**: `whisper-ollama` conda env
## Running
```bash
mamba run -n whisper-ollama python app.py
```
## Testing
```bash
mamba run -n whisper-ollama python -m pytest tests/
```
## Key Paths
- `data/vocabulary.json` — GCSE vocabulary data
- `data/progress.db` — SQLite database (auto-created)
- `app.py` — Gradio entry point
- `db.py` — Database layer with FSRS integration
- `ai.py` — Dual AI backend (Ollama + Claude)
- `stt.py` — Persian speech-to-text wrapper
- `modules/` — Feature modules (vocab, dashboard, essay, tutor, idioms)
## Architecture
- Single-process Gradio app with shared SQLite connection
- FSRS Card objects serialized as JSON in SQLite TEXT columns
- Timestamps stored as ISO-8601 strings
- sttlib imported via sys.path from tool-speechtotext project

View File

@@ -0,0 +1,57 @@
# Persian Language Tutor
A Gradio-based Persian (Farsi) language learning app for English speakers, built around GCSE Persian vocabulary (Pearson specification).
## Features
- **Vocabulary Study** — Search, browse, and study 918 GCSE Persian words across 39 categories
- **Flashcards with FSRS** — Spaced repetition scheduling (same algorithm as Anki)
- **Idioms & Expressions** — 25 Persian social conventions with cultural context
- **AI Tutor** — Conversational Persian lessons by GCSE theme (via Ollama)
- **Essay Marking** — Write Persian essays, get AI feedback and grading (via Claude)
- **Dashboard** — Track progress, streaks, and mastery
- **Anki Export** — Generate .apkg decks for offline study
- **Voice Input** — Speak Persian via microphone (Whisper STT) in the Tutor tab
## Prerequisites
- `whisper-ollama` conda environment with Python 3.10+
- Ollama running locally with `qwen2.5:7b` (or another model)
- Claude CLI installed (for essay marking / smart mode)
## Setup
```bash
/home/ys/miniforge3/envs/whisper-ollama/bin/pip install gradio genanki fsrs
```
## Running the app
```bash
cd /home/ys/family-repo/Code/python/persian-tutor
/home/ys/miniforge3/envs/whisper-ollama/bin/python app.py
```
Then open http://localhost:7860 in your browser.
## Running tests
```bash
cd /home/ys/family-repo/Code/python/persian-tutor
/home/ys/miniforge3/envs/whisper-ollama/bin/python -m pytest tests/ -v
```
41 tests covering db, vocab, ai, and anki_export modules.
## Expanding vocabulary
The vocabulary can be expanded by editing `data/vocabulary.json` directly or by updating `scripts/build_vocab.py` and re-running it:
```bash
/home/ys/miniforge3/envs/whisper-ollama/bin/python scripts/build_vocab.py
```
## TODO
- [ ] Voice-based vocabulary testing — answer flashcard prompts by speaking Persian
- [ ] Improved UI theme and layout polish

View File

@@ -0,0 +1,56 @@
"""Dual AI backend: Ollama (fast/local) and Claude CLI (smart)."""
import subprocess
import ollama
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
_ollama_model = DEFAULT_OLLAMA_MODEL
def set_ollama_model(model):
    """Switch the Ollama model used for all subsequent fast queries."""
    global _ollama_model
    _ollama_model = model
def ask_ollama(prompt, system=None, model=None):
    """Send a single-turn prompt to Ollama, optionally with a system prompt."""
    chosen = model or _ollama_model
    convo = [{"role": "system", "content": system}] if system else []
    convo.append({"role": "user", "content": prompt})
    reply = ollama.chat(model=chosen, messages=convo)
    return reply.message.content
def ask_claude(prompt):
    """Query Claude via the CLI subprocess.

    Args:
        prompt: Text prompt passed to `claude -p`.

    Returns:
        Claude's response with surrounding whitespace stripped.

    Raises:
        RuntimeError: If the CLI binary is missing or exits non-zero.
    """
    try:
        result = subprocess.run(
            ["claude", "-p", prompt],
            capture_output=True,
            text=True,
        )
    except FileNotFoundError as err:
        # Surface a missing binary the same way as a failed call, so callers
        # only ever need to handle RuntimeError.
        raise RuntimeError("Claude CLI not found on PATH") from err
    if result.returncode != 0:
        raise RuntimeError(f"Claude CLI failed (exit {result.returncode}): {result.stderr.strip()}")
    return result.stdout.strip()
def ask(prompt, system=None, quality="fast"):
    """Unified interface. quality='fast' uses Ollama, 'smart' uses Claude."""
    if quality != "smart":
        return ask_ollama(prompt, system=system)
    return ask_claude(prompt)
def chat_ollama(messages, system=None, model=None):
    """Run a multi-turn Ollama conversation, optionally prefixing a system prompt."""
    prefix = [{"role": "system", "content": system}] if system else []
    convo = prefix + list(messages)
    reply = ollama.chat(model=model or _ollama_model, messages=convo)
    return reply.message.content

View File

@@ -0,0 +1,75 @@
"""Generate Anki .apkg decks from vocabulary data."""
import genanki
# Stable model/deck IDs (generated once, kept constant)
_MODEL_ID = 1607392319
_DECK_ID = 2059400110
def _make_model():
    """Build the genanki note model with two card templates (EN→FA, FA→EN)."""
    note_fields = [
        {"name": "English"},
        {"name": "Persian"},
        {"name": "Finglish"},
        {"name": "Category"},
    ]
    # Forward card: prompt in English, answer in Persian + transliteration.
    en_to_fa = {
        "name": "English → Persian",
        "qfmt": '<div style="font-size:1.5em">{{English}}</div><br><small>{{Category}}</small>',
        "afmt": '{{FrontSide}}<hr id="answer"><div dir="rtl" style="font-size:2em">{{Persian}}</div><br><div>{{Finglish}}</div>',
    }
    # Reverse card: prompt in Persian, answer in English + transliteration.
    fa_to_en = {
        "name": "Persian → English",
        "qfmt": '<div dir="rtl" style="font-size:2em">{{Persian}}</div><br><small>{{Category}}</small>',
        "afmt": '{{FrontSide}}<hr id="answer"><div style="font-size:1.5em">{{English}}</div><br><div>{{Finglish}}</div>',
    }
    return genanki.Model(
        _MODEL_ID,
        "GCSE Persian",
        fields=note_fields,
        templates=[en_to_fa, fa_to_en],
        css=".card { font-family: arial; text-align: center; }",
    )
def export_deck(vocab, categories=None, output_path="gcse-persian.apkg"):
    """Generate an Anki .apkg deck from vocabulary entries.

    Args:
        vocab: List of vocabulary entries (dicts with english, persian, finglish, category).
        categories: Optional list of categories to include. None = all.
        output_path: Where to save the .apkg file.

    Returns:
        Path to the generated .apkg file.
    """
    note_model = _make_model()
    deck = genanki.Deck(_DECK_ID, "GCSE Persian")
    # Keep every entry when no category filter was supplied.
    wanted = (e for e in vocab if not categories or e.get("category") in categories)
    for entry in wanted:
        deck.add_note(
            genanki.Note(
                model=note_model,
                fields=[
                    entry.get("english", ""),
                    entry.get("persian", ""),
                    entry.get("finglish", ""),
                    entry.get("category", ""),
                ],
                # Stable guid so re-imports update notes instead of duplicating.
                guid=genanki.guid_for(entry.get("id", entry["english"])),
            )
        )
    genanki.Package(deck).write_to_file(output_path)
    return output_path

525
python/persian-tutor/app.py Normal file
View File

@@ -0,0 +1,525 @@
"""Persian Language Tutor — Gradio UI."""
import json
import os
import tempfile
import time
import gradio as gr
import ai
import db
from modules import vocab, dashboard, essay, tutor, idioms
from modules.essay import GCSE_THEMES
from modules.tutor import THEME_PROMPTS
from anki_export import export_deck
# ---------- Initialise ----------
db.init_db()
vocabulary = vocab.load_vocab()
categories = ["All"] + vocab.get_categories()
# ---------- Helper ----------
def _rtl(text, size="2em"):
return f'<div dir="rtl" style="font-size:{size}; text-align:center">{text}</div>'
# ================================================================
# TAB HANDLERS
# ================================================================
# ---------- Dashboard ----------
def refresh_dashboard():
    """Rebuild all three dashboard widgets from the database."""
    return (
        dashboard.format_overview_markdown(),
        dashboard.get_category_breakdown(),
        dashboard.get_recent_quizzes(),
    )
# ---------- Vocabulary Search ----------
def do_search(query, category):
    """Search the vocabulary and render matches as markdown lines."""
    matches = vocab.search(query)
    if category and category != "All":
        matches = [m for m in matches if m["category"] == category]
    if not matches:
        return "No results found."
    rendered = []
    for m in matches:
        # Status icon: learning/mastered get a colour square, new gets nothing.
        icon = {"new": "", "learning": "🟨", "mastered": "🟩"}.get(
            vocab.get_word_status(m["id"]), ""
        )
        rendered.append(
            f'{icon} **{m["english"]}** — '
            f'<span dir="rtl">{m["persian"]}</span>'
            f' ({m.get("finglish", "")})'
        )
    return "\n\n".join(rendered)
def do_random_word(category, transliteration):
    """Show one random word card, honouring the transliteration setting."""
    entry = vocab.get_random_word(category=category)
    if entry is None or not entry:
        return "No words found."
    return vocab.format_word_card(entry, show_transliteration=transliteration)
# ---------- Flashcards ----------
def start_flashcards(category, direction):
    """Begin a 10-card session; returns the first prompt plus reset state."""
    batch = vocab.get_flashcard_batch(count=10, category=category)
    if not batch:
        return "No words available.", [], 0, 0, "", gr.update(visible=False)
    head = batch[0]
    if direction == "English → Persian":
        prompt = f'<div style="font-size:2em; text-align:center">{head["english"]}</div>'
    else:
        prompt = _rtl(head["persian"])
    # (card_display, batch state, index, score, cleared answer box, show answer area)
    return prompt, batch, 0, 0, "", gr.update(visible=True)
def submit_answer(user_answer, batch, index, score, direction, transliteration):
    """Check one flashcard answer and build the feedback markdown."""
    if not batch or index >= len(batch):
        return "Session complete!", batch, index, score, "", gr.update(visible=False), ""
    entry = batch[index]
    dir_key = "en_to_fa" if direction == "English → Persian" else "fa_to_en"
    is_correct, correct_answer, _ = vocab.check_answer(entry["id"], user_answer, direction=dir_key)
    if is_correct:
        score += 1
        result = "✅ **Correct!**"
    else:
        # Persian answers get an RTL span so they render right-to-left.
        shown = (
            f'<span dir="rtl">{correct_answer}</span>'
            if dir_key == "en_to_fa"
            else correct_answer
        )
        result = f"❌ **Incorrect.** The answer is: {shown}"
    card_info = vocab.format_word_card(entry, show_transliteration=transliteration)
    feedback = f"{result}\n\n{card_info}\n\n---\n*Rate your recall to continue:*"
    return feedback, batch, index, score, "", gr.update(visible=True), ""
def rate_and_next(rating_str, batch, index, score, direction):
    """Record an FSRS rating for the current card and advance the session."""
    if not batch or index >= len(batch):
        return "Session complete!", batch, index, score, gr.update(visible=False)
    import fsrs as fsrs_mod
    ratings = {
        "Again": fsrs_mod.Rating.Again,
        "Hard": fsrs_mod.Rating.Hard,
        "Good": fsrs_mod.Rating.Good,
        "Easy": fsrs_mod.Rating.Easy,
    }
    chosen = ratings.get(rating_str, fsrs_mod.Rating.Good)
    db.update_word_progress(batch[index]["id"], chosen)
    index += 1
    if index >= len(batch):
        summary = f"## Session Complete!\n\n**Score:** {score}/{len(batch)}\n\n"
        summary += f"**Accuracy:** {score/len(batch)*100:.0f}%"
        return summary, batch, index, score, gr.update(visible=False)
    upcoming = batch[index]
    if direction == "English → Persian":
        prompt = f'<div style="font-size:2em; text-align:center">{upcoming["english"]}</div>'
    else:
        prompt = _rtl(upcoming["persian"])
    return prompt, batch, index, score, gr.update(visible=True)
# ---------- Idioms ----------
def show_random_idiom(transliteration):
    """Pick a random expression, render it, and keep it in component state."""
    expr = idioms.get_random_expression()
    formatted = idioms.format_expression(expr, show_transliteration=transliteration)
    return formatted, expr
def explain_idiom(expr_state):
    """Explain the currently shown idiom (requires one to be picked first)."""
    return idioms.explain_expression(expr_state) if expr_state else "Pick an idiom first."
def browse_idioms(transliteration):
    """List every expression; append finglish unless transliteration is off."""
    rows = []
    for expr in idioms.get_all_expressions():
        entry = f'**<span dir="rtl">{expr["persian"]}</span>** — {expr["english"]}'
        if transliteration != "off":
            entry += f' *({expr["finglish"]})*'
        rows.append(entry)
    return "\n\n".join(rows)
# ---------- Tutor ----------
def start_tutor_lesson(theme):
    """Open a new lesson; returns the chat seed, message state, and start time."""
    opening, messages, system = tutor.start_lesson(theme)
    seeded_chat = [{"role": "assistant", "content": opening}]
    return seeded_chat, messages, system, time.time()
def send_tutor_message(user_msg, chat_history, messages, system, audio_input):
    """Handle one tutor turn; falls back to STT when only audio was provided."""
    if audio_input is not None and not (user_msg and user_msg.strip()):
        try:
            from stt import transcribe_persian
            user_msg = transcribe_persian(audio_input)
        except Exception:
            # Best-effort: a failed transcription behaves like no input at all.
            user_msg = ""
    if not (user_msg and user_msg.strip()):
        return chat_history, messages, "", None
    response, messages = tutor.process_response(user_msg, messages, system=system)
    chat_history += [
        {"role": "user", "content": user_msg},
        {"role": "assistant", "content": response},
    ]
    return chat_history, messages, "", None
def save_tutor(theme, messages, start_time):
    """Persist the tutor conversation if it went beyond the opening line."""
    if not messages or len(messages) <= 1:
        return "Nothing to save."
    tutor.save_session(theme, messages, start_time)
    return "Session saved!"
# ---------- Essay ----------
def submit_essay(text, theme):
    """Send an essay off for AI marking; reject empty submissions."""
    if text and text.strip():
        return essay.mark_essay(text, theme)
    return "Please write an essay first."
def load_essay_history():
    """Fetch past essays for the history table."""
    return essay.get_essay_history()
# ---------- Settings / Export ----------
def do_anki_export(cats_selected):
    """Write the Anki deck to the temp dir and return its path for download."""
    chosen = cats_selected or None  # empty selection means "all categories"
    target = os.path.join(tempfile.gettempdir(), "gcse-persian.apkg")
    export_deck(vocab.load_vocab(), categories=chosen, output_path=target)
    return target
def update_ollama_model(model):
    """Apply the Ollama model chosen in the Settings tab."""
    ai.set_ollama_model(model)
def update_whisper_size(size):
    """Apply the Whisper model size chosen in the Settings tab."""
    # Imported lazily so the app still starts when the STT stack is absent.
    from stt import set_whisper_size
    set_whisper_size(size)
def reset_progress():
    """Wipe all learner data: progress, quizzes, essays, and tutor sessions."""
    conn = db.get_connection()
    # Table names are trusted constants, not user input.
    for table in ("word_progress", "quiz_sessions", "essays", "tutor_sessions"):
        conn.execute(f"DELETE FROM {table}")
    conn.commit()
    return "Progress reset."
# ================================================================
# GRADIO UI
# ================================================================
# Single-page app: one gr.Blocks with a tab per feature. Handlers above
# are pure functions wired to components via .click/.submit/.change.
with gr.Blocks(title="Persian Language Tutor") as app:
    gr.Markdown("# 🇮🇷 Persian Language Tutor\n*GCSE Persian vocabulary with spaced repetition*")
    # Shared state
    # Transliteration preference is global: set in Settings, read by
    # vocabulary, flashcard, and idiom handlers.
    transliteration_state = gr.State(value="Finglish")
    with gr.Tabs():
        # ==================== DASHBOARD ====================
        with gr.Tab("📊 Dashboard"):
            overview_md = gr.Markdown("Loading...")
            with gr.Row():
                cat_table = gr.Dataframe(
                    headers=["Category", "Total", "Seen", "Mastered", "Progress"],
                    label="Category Breakdown",
                )
                quiz_table = gr.Dataframe(
                    headers=["Date", "Category", "Score", "Duration"],
                    label="Recent Quizzes",
                )
            refresh_btn = gr.Button("Refresh", variant="secondary")
            refresh_btn.click(
                fn=refresh_dashboard,
                outputs=[overview_md, cat_table, quiz_table],
            )
        # ==================== VOCABULARY ====================
        with gr.Tab("📚 Vocabulary"):
            with gr.Row():
                search_box = gr.Textbox(
                    label="Search (English or Persian)",
                    placeholder="Type to search...",
                )
                vocab_cat = gr.Dropdown(
                    choices=categories, value="All", label="Category"
                )
            search_btn = gr.Button("Search", variant="primary")
            random_btn = gr.Button("Random Word")
            search_results = gr.Markdown("Search for a word above.")
            # Button click and Enter-in-textbox trigger the same search.
            search_btn.click(
                fn=do_search,
                inputs=[search_box, vocab_cat],
                outputs=[search_results],
            )
            search_box.submit(
                fn=do_search,
                inputs=[search_box, vocab_cat],
                outputs=[search_results],
            )
            random_btn.click(
                fn=do_random_word,
                inputs=[vocab_cat, transliteration_state],
                outputs=[search_results],
            )
        # ==================== FLASHCARDS ====================
        with gr.Tab("🃏 Flashcards"):
            with gr.Row():
                fc_category = gr.Dropdown(
                    choices=categories, value="All", label="Category"
                )
                fc_direction = gr.Radio(
                    ["English → Persian", "Persian → English"],
                    value="English → Persian",
                    label="Direction",
                )
            start_fc_btn = gr.Button("Start Session", variant="primary")
            card_display = gr.Markdown("Press 'Start Session' to begin.")
            # Hidden states
            # Session state: the word batch, current index, and running score.
            fc_batch = gr.State([])
            fc_index = gr.State(0)
            fc_score = gr.State(0)
            # Answer area is hidden until a session starts.
            with gr.Group(visible=False) as answer_area:
                answer_box = gr.Textbox(
                    label="Your answer",
                    placeholder="Type your answer...",
                    rtl=True,
                )
                submit_ans_btn = gr.Button("Submit Answer", variant="primary")
                answer_feedback = gr.Markdown("")
                with gr.Row():
                    btn_again = gr.Button("Again", variant="stop")
                    btn_hard = gr.Button("Hard", variant="secondary")
                    btn_good = gr.Button("Good", variant="primary")
                    btn_easy = gr.Button("Easy", variant="secondary")
            start_fc_btn.click(
                fn=start_flashcards,
                inputs=[fc_category, fc_direction],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area],
            )
            submit_ans_btn.click(
                fn=submit_answer,
                inputs=[answer_box, fc_batch, fc_index, fc_score, fc_direction, transliteration_state],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area, answer_feedback],
            )
            answer_box.submit(
                fn=submit_answer,
                inputs=[answer_box, fc_batch, fc_index, fc_score, fc_direction, transliteration_state],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area, answer_feedback],
            )
            # gr.State(label) freezes each button's rating string as a
            # constant input, avoiding the late-binding closure pitfall.
            for btn, label in [(btn_again, "Again"), (btn_hard, "Hard"), (btn_good, "Good"), (btn_easy, "Easy")]:
                btn.click(
                    fn=rate_and_next,
                    inputs=[gr.State(label), fc_batch, fc_index, fc_score, fc_direction],
                    outputs=[card_display, fc_batch, fc_index, fc_score, answer_area],
                )
        # ==================== IDIOMS ====================
        with gr.Tab("💬 Idioms & Expressions"):
            idiom_display = gr.Markdown("Click 'Random Idiom' or browse below.")
            # Holds the currently shown expression so "Explain Usage" can refer to it.
            idiom_state = gr.State(None)
            with gr.Row():
                random_idiom_btn = gr.Button("Random Idiom", variant="primary")
                explain_idiom_btn = gr.Button("Explain Usage")
                browse_idiom_btn = gr.Button("Browse All")
            idiom_explanation = gr.Markdown("")
            random_idiom_btn.click(
                fn=show_random_idiom,
                inputs=[transliteration_state],
                outputs=[idiom_display, idiom_state],
            )
            explain_idiom_btn.click(
                fn=explain_idiom,
                inputs=[idiom_state],
                outputs=[idiom_explanation],
            )
            browse_idiom_btn.click(
                fn=browse_idioms,
                inputs=[transliteration_state],
                outputs=[idiom_display],
            )
        # ==================== TUTOR ====================
        with gr.Tab("🎓 Tutor"):
            tutor_theme = gr.Dropdown(
                choices=list(THEME_PROMPTS.keys()),
                value="Identity and culture",
                label="Theme",
            )
            start_lesson_btn = gr.Button("New Lesson", variant="primary")
            chatbot = gr.Chatbot(label="Conversation")
            # Tutor states
            # Raw LLM message list, system prompt, and lesson start time.
            tutor_messages = gr.State([])
            tutor_system = gr.State("")
            tutor_start_time = gr.State(0)
            with gr.Row():
                tutor_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type in English or Persian...",
                    scale=3,
                )
                tutor_mic = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    label="Speak",
                    scale=1,
                )
            send_btn = gr.Button("Send", variant="primary")
            save_btn = gr.Button("Save Session", variant="secondary")
            save_status = gr.Markdown("")
            start_lesson_btn.click(
                fn=start_tutor_lesson,
                inputs=[tutor_theme],
                outputs=[chatbot, tutor_messages, tutor_system, tutor_start_time],
            )
            send_btn.click(
                fn=send_tutor_message,
                inputs=[tutor_input, chatbot, tutor_messages, tutor_system, tutor_mic],
                outputs=[chatbot, tutor_messages, tutor_input, tutor_mic],
            )
            tutor_input.submit(
                fn=send_tutor_message,
                inputs=[tutor_input, chatbot, tutor_messages, tutor_system, tutor_mic],
                outputs=[chatbot, tutor_messages, tutor_input, tutor_mic],
            )
            save_btn.click(
                fn=save_tutor,
                inputs=[tutor_theme, tutor_messages, tutor_start_time],
                outputs=[save_status],
            )
        # ==================== ESSAY ====================
        with gr.Tab("✍️ Essay"):
            essay_theme = gr.Dropdown(
                choices=GCSE_THEMES,
                value="Identity and culture",
                label="Theme",
            )
            essay_input = gr.Textbox(
                label="Write your essay in Persian",
                lines=10,
                rtl=True,
                placeholder="اینجا بنویسید...",
            )
            submit_essay_btn = gr.Button("Submit for Marking", variant="primary")
            essay_feedback = gr.Markdown("Write an essay and submit for AI marking.")
            gr.Markdown("### Essay History")
            essay_history_table = gr.Dataframe(
                headers=["Date", "Theme", "Grade", "Preview"],
                label="Past Essays",
            )
            refresh_essays_btn = gr.Button("Refresh History")
            submit_essay_btn.click(
                fn=submit_essay,
                inputs=[essay_input, essay_theme],
                outputs=[essay_feedback],
            )
            refresh_essays_btn.click(
                fn=load_essay_history,
                outputs=[essay_history_table],
            )
        # ==================== SETTINGS ====================
        with gr.Tab("⚙️ Settings"):
            gr.Markdown("## Settings")
            transliteration_radio = gr.Radio(
                ["off", "Finglish", "Academic"],
                value="Finglish",
                label="Transliteration",
            )
            ollama_model = gr.Textbox(
                label="Ollama Model",
                value="qwen2.5:7b",
                info="Model used for fast AI responses",
            )
            whisper_size = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v3"],
                value="medium",
                label="Whisper Model Size",
            )
            gr.Markdown("### Anki Export")
            export_cats = gr.Dropdown(
                choices=vocab.get_categories(),
                multiselect=True,
                label="Categories to export (empty = all)",
            )
            export_btn = gr.Button("Export to Anki (.apkg)", variant="primary")
            export_file = gr.File(label="Download")
            export_btn.click(fn=do_anki_export, inputs=[export_cats], outputs=[export_file])
            # Wire model settings
            ollama_model.change(fn=update_ollama_model, inputs=[ollama_model])
            whisper_size.change(fn=update_whisper_size, inputs=[whisper_size])
            gr.Markdown("### Reset")
            reset_btn = gr.Button("Reset All Progress", variant="stop")
            reset_status = gr.Markdown("")
            reset_btn.click(fn=reset_progress, outputs=[reset_status])
            # Wire transliteration state
            transliteration_radio.change(
                fn=lambda x: x,
                inputs=[transliteration_radio],
                outputs=[transliteration_state],
            )
    # Load dashboard on app start
    app.load(fn=refresh_dashboard, outputs=[overview_md, cat_table, quiz_table])
if __name__ == "__main__":
    app.launch(theme=gr.themes.Soft())

File diff suppressed because it is too large Load Diff

241
python/persian-tutor/db.py Normal file
View File

@@ -0,0 +1,241 @@
"""SQLite database layer with FSRS spaced repetition integration."""
import json
import sqlite3
from datetime import datetime, timedelta, timezone
from pathlib import Path
import fsrs
DB_PATH = Path(__file__).parent / "data" / "progress.db"
_conn = None
_scheduler = fsrs.Scheduler()
def get_connection():
    """Return the shared SQLite connection (singleton)."""
    global _conn
    if _conn is None:
        # Create the data/ directory on first use so a fresh checkout works.
        DB_PATH.parent.mkdir(parents=True, exist_ok=True)
        # check_same_thread=False because Gradio handlers run on worker
        # threads yet share this one connection. NOTE(review): assumes writes
        # are effectively serialized — confirm; sqlite3 connections are not
        # inherently thread-safe.
        _conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
        _conn.row_factory = sqlite3.Row  # rows support name-based access
        # WAL journal mode lets readers proceed while a write is in progress.
        _conn.execute("PRAGMA journal_mode=WAL")
    return _conn
def init_db():
    """Create all tables if they don't exist. Called once at startup."""
    conn = get_connection()
    # word_progress mirrors the FSRS card: the full card is serialized into
    # fsrs_state (JSON TEXT) while due/stability/difficulty are denormalized
    # for direct SQL querying. The other tables are append-only activity logs.
    conn.executescript("""
    CREATE TABLE IF NOT EXISTS word_progress (
    word_id TEXT PRIMARY KEY,
    fsrs_state TEXT,
    due TIMESTAMP,
    stability REAL,
    difficulty REAL,
    reps INTEGER DEFAULT 0,
    lapses INTEGER DEFAULT 0,
    last_review TIMESTAMP
    );
    CREATE TABLE IF NOT EXISTS quiz_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    category TEXT,
    total_questions INTEGER,
    correct INTEGER,
    duration_seconds INTEGER
    );
    CREATE TABLE IF NOT EXISTS essays (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    essay_text TEXT,
    grade TEXT,
    feedback TEXT,
    theme TEXT
    );
    CREATE TABLE IF NOT EXISTS tutor_sessions (
    id INTEGER PRIMARY KEY AUTOINCREMENT,
    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
    theme TEXT,
    messages TEXT,
    duration_seconds INTEGER
    );
    """)
    conn.commit()
def get_word_progress(word_id):
    """Return learning state for one word, or None if never reviewed."""
    cursor = get_connection().execute(
        "SELECT * FROM word_progress WHERE word_id = ?", (word_id,)
    )
    row = cursor.fetchone()
    return None if row is None else dict(row)
def update_word_progress(word_id, rating):
    """Run FSRS review and persist the card's new scheduling state.

    Args:
        word_id: Vocabulary entry ID.
        rating: fsrs.Rating value (Again=1, Hard=2, Good=3, Easy=4).

    Returns:
        The updated fsrs.Card.
    """
    conn = get_connection()
    existing = get_word_progress(word_id)
    # Resume the serialized card if this word has been reviewed before.
    if existing and existing["fsrs_state"]:
        card = fsrs.Card.from_dict(json.loads(existing["fsrs_state"]))
    else:
        card = fsrs.Card()
    card, review_log = _scheduler.review_card(card, rating)
    now = datetime.now(timezone.utc).isoformat()
    # default=str handles datetime fields inside the card dict.
    card_json = json.dumps(card.to_dict(), default=str)
    prev_reps = existing["reps"] if existing else 0
    prev_lapses = existing["lapses"] if existing else 0
    # Bug fix: lapses used to be written back unchanged, so the column stayed
    # at 0 forever. Count a lapse whenever the learner rates the card "Again".
    lapses = prev_lapses + (1 if rating == fsrs.Rating.Again else 0)
    conn.execute(
        """INSERT OR REPLACE INTO word_progress
        (word_id, fsrs_state, due, stability, difficulty, reps, lapses, last_review)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            word_id,
            card_json,
            card.due.isoformat(),
            card.stability,
            card.difficulty,
            prev_reps + 1,
            lapses,
            now,
        ),
    )
    conn.commit()
    return card
def get_due_words(limit=20):
    """Return word IDs due for review (due <= now), earliest first."""
    cutoff = datetime.now(timezone.utc).isoformat()
    rows = get_connection().execute(
        "SELECT word_id FROM word_progress WHERE due <= ? ORDER BY due LIMIT ?",
        (cutoff, limit),
    ).fetchall()
    return [record["word_id"] for record in rows]
def get_word_counts(total_vocab_size=0):
    """Return dict with total/seen/mastered/due counts for the dashboard."""
    conn = get_connection()
    now = datetime.now(timezone.utc).isoformat()

    def scalar(sql, params=()):
        # Fetch a single aggregate value.
        return conn.execute(sql, params).fetchone()[0]

    return {
        "total": total_vocab_size,
        "seen": scalar("SELECT COUNT(*) FROM word_progress"),
        "mastered": scalar("SELECT COUNT(*) FROM word_progress WHERE stability > 10"),
        "due": scalar("SELECT COUNT(*) FROM word_progress WHERE due <= ?", (now,)),
    }
def get_all_word_progress():
    """Return every word's progress keyed by word_id (batch query, avoids N+1)."""
    cursor = get_connection().execute("SELECT * FROM word_progress")
    return {record["word_id"]: dict(record) for record in cursor.fetchall()}
def record_quiz_session(category, total_questions, correct, duration_seconds):
    """Persist one finished flashcard session to quiz_sessions."""
    db_conn = get_connection()
    params = (category, total_questions, correct, duration_seconds)
    db_conn.execute(
        "INSERT INTO quiz_sessions (category, total_questions, correct, duration_seconds) VALUES (?, ?, ?, ?)",
        params,
    )
    db_conn.commit()
def save_essay(essay_text, grade, feedback, theme):
    """Store an essay together with its AI grade and feedback."""
    db_conn = get_connection()
    params = (essay_text, grade, feedback, theme)
    db_conn.execute(
        "INSERT INTO essays (essay_text, grade, feedback, theme) VALUES (?, ?, ?, ?)",
        params,
    )
    db_conn.commit()
def save_tutor_session(theme, messages, duration_seconds):
    """Store a tutor conversation; messages are serialized as JSON text."""
    # ensure_ascii=False keeps Persian text readable in the DB.
    payload = json.dumps(messages, ensure_ascii=False)
    db_conn = get_connection()
    db_conn.execute(
        "INSERT INTO tutor_sessions (theme, messages, duration_seconds) VALUES (?, ?, ?)",
        (theme, payload, duration_seconds),
    )
    db_conn.commit()
def get_stats():
    """Aggregate data for the dashboard."""
    conn = get_connection()
    recent_quizzes = conn.execute(
        "SELECT * FROM quiz_sessions ORDER BY timestamp DESC LIMIT 10"
    ).fetchall()
    # COALESCE so an empty table yields 0, not NULL.
    total_reviews = conn.execute(
        "SELECT COALESCE(SUM(reps), 0) FROM word_progress"
    ).fetchone()[0]
    total_quizzes = conn.execute(
        "SELECT COUNT(*) FROM quiz_sessions"
    ).fetchone()[0]
    # Streak: count consecutive days with activity
    # Distinct review days, newest first; day i must equal (today - i) for
    # the streak to continue.
    # NOTE(review): with this rule the streak reads 0 whenever there has been
    # no activity *today*, even if yesterday was reviewed — confirm intended.
    days = conn.execute(
        "SELECT DISTINCT DATE(last_review) as d FROM word_progress WHERE last_review IS NOT NULL ORDER BY d DESC"
    ).fetchall()
    streak = 0
    today = datetime.now(timezone.utc).date()
    for i, row in enumerate(days):
        # DATE() returns ISO text from SQLite; guard in case a driver hands
        # back a date object already.
        day = datetime.fromisoformat(row["d"]).date() if isinstance(row["d"], str) else row["d"]
        expected = today - timedelta(days=i)
        if day == expected:
            streak += 1
        else:
            break
    return {
        "recent_quizzes": [dict(r) for r in recent_quizzes],
        "total_reviews": total_reviews,
        "total_quizzes": total_quizzes,
        "streak": streak,
    }
def get_recent_essays(limit=10):
    """Return the most recent essays, newest first."""
    cursor = get_connection().execute(
        "SELECT * FROM essays ORDER BY timestamp DESC LIMIT ?", (limit,)
    )
    return [dict(row) for row in cursor.fetchall()]
def close():
    """Close and drop the shared database connection."""
    global _conn
    if _conn is not None:
        _conn.close()
        _conn = None

View File

View File

@@ -0,0 +1,84 @@
"""Dashboard: progress stats, charts, and overview."""
import db
from modules.vocab import load_vocab, get_categories
def get_overview():
    """Return overview stats: word counts plus streak and review totals."""
    counts = db.get_word_counts(total_vocab_size=len(load_vocab()))
    stats = db.get_stats()
    for key in ("streak", "total_reviews", "total_quizzes"):
        counts[key] = stats[key]
    return counts
def get_category_breakdown():
    """Return per-category progress rows for the dashboard table."""
    vocab = load_vocab()
    # One batch query for all progress rows, instead of one query per word.
    all_progress = db.get_all_word_progress()
    rows = []
    for cat in get_categories():
        words = [entry for entry in vocab if entry["category"] == cat]
        seen = mastered = 0
        for entry in words:
            state = all_progress.get(entry["id"])
            if state is None:
                continue
            seen += 1
            if state["stability"] and state["stability"] > 10:
                mastered += 1
        rows.append({
            "Category": cat,
            "Total": len(words),
            "Seen": seen,
            "Mastered": mastered,
            "Progress": f"{seen}/{len(words)}" if len(words) > 0 else "0/0",
        })
    return rows
def get_recent_quizzes(limit=10):
    """Return recent quiz results as display-ready dicts."""
    recent = db.get_stats()["recent_quizzes"][:limit]
    return [
        {
            "Date": q["timestamp"],
            "Category": q["category"] or "All",
            "Score": f"{q['correct']}/{q['total_questions']}",
            "Duration": f"{q['duration_seconds'] or 0}s",
        }
        for q in recent
    ]
def format_overview_markdown():
    """Format overview stats as a markdown string for display."""
    o = get_overview()
    pct = (o["seen"] / o["total"] * 100) if o["total"] > 0 else 0
    # 20-segment text progress bar; each segment represents 5%.
    bar_filled = int(pct / 5)
    bar_empty = 20 - bar_filled
    # Bug fix: the bar glyphs were empty strings ("" * n), so the bar never
    # rendered anything between the backticks. Use block characters instead.
    progress_bar = "█" * bar_filled + "░" * bar_empty
    lines = [
        "## Dashboard",
        "",
        f"**Words studied:** {o['seen']} / {o['total']} ({pct:.0f}%)",
        f"`{progress_bar}`",
        "",
        f"**Due today:** {o['due']}",
        f"**Mastered:** {o['mastered']}",
        f"**Daily streak:** {o['streak']} day{'s' if o['streak'] != 1 else ''}",
        f"**Total reviews:** {o['total_reviews']}",
        f"**Quiz sessions:** {o['total_quizzes']}",
    ]
    return "\n".join(lines)

View File

@@ -0,0 +1,78 @@
"""Essay writing and AI marking."""
import db
from ai import ask
# System prompt shared by every marking request: fixes the AI's persona
# (GCSE Persian teacher) and tone (constructive, learner-friendly).
MARKING_SYSTEM_PROMPT = """You are an expert Persian (Farsi) language teacher marking a GCSE-level essay.
You write in English but can read and correct Persian text.
Always provide constructive, encouraging feedback suitable for a language learner."""
# Marking request template. The "**Grade:**" line format matters:
# mark_essay parses it back out of the response for storage.
MARKING_PROMPT_TEMPLATE = """Please mark this Persian essay written by a GCSE student.
Theme: {theme}
Student's essay:
{essay_text}
Please provide your response in this exact format:
**Grade:** [Give a grade from 1-9 matching GCSE grading, or a descriptive level like A2/B1]
**Summary:** [1-2 sentence overview of the essay quality]
**Corrections:**
[List specific errors with corrections. For each error, show the original text and the corrected version in Persian, with an English explanation]
**Improved version:**
[Rewrite the essay in corrected Persian]
**Tips for improvement:**
[3-5 specific, actionable tips for the student]"""
# The five Pearson GCSE Persian themes, shared with the tutor/essay UIs.
GCSE_THEMES = [
    "Identity and culture",
    "Local area and environment",
    "School and work",
    "Travel and tourism",
    "International and global dimension",
]
def mark_essay(essay_text, theme="General"):
    """Send an essay to the AI marker and persist the structured feedback.

    Returns the raw feedback text, or a short reminder when the essay is blank.
    """
    cleaned = (essay_text or "").strip()
    if not cleaned:
        return "Please write an essay first."
    marking_prompt = MARKING_PROMPT_TEMPLATE.format(theme=theme, essay_text=cleaned)
    feedback = ask(marking_prompt, system=MARKING_SYSTEM_PROMPT, quality="smart")
    # Best-effort extraction of the "**Grade:**" line from the feedback.
    grade = next(
        (
            ln.replace("**Grade:**", "").strip()
            for ln in feedback.split("\n")
            if ln.strip().startswith("**Grade:**")
        ),
        "",
    )
    # Save to database
    db.save_essay(cleaned, grade, feedback, theme)
    return feedback
def get_essay_history(limit=10):
    """Return recent essays as display rows for the history view.

    Args:
        limit: Maximum number of essays to return (newest first).

    Returns:
        List of dicts with "Date", "Theme", "Grade" and "Preview" keys.
    """
    essays = db.get_recent_essays(limit)
    result = []
    for e in essays:
        text = e["essay_text"] or ""
        # FIX: only append an ellipsis when the preview is actually truncated;
        # previously "..." was added even to essays shorter than 50 characters.
        preview = text[:50] + "..." if len(text) > 50 else text
        result.append({
            "Date": e["timestamp"],
            "Theme": e["theme"] or "General",
            "Grade": e["grade"] or "-",
            "Preview": preview,
        })
    return result

View File

@@ -0,0 +1,200 @@
"""Persian idioms, expressions, and social conventions."""
from ai import ask
# Built-in collection of common Persian expressions and idioms
# Each entry has four keys:
#   "persian"  — the expression in Persian script,
#   "finglish" — romanized transliteration,
#   "english"  — (near-)literal English meaning,
#   "context"  — usage / cultural note shown to the student.
EXPRESSIONS = [
    {
        "persian": "سلام علیکم",
        "finglish": "salâm aleykom",
        "english": "Peace be upon you (formal greeting)",
        "context": "Formal greeting, especially with elders",
    },
    {
        "persian": "خسته نباشید",
        "finglish": "khaste nabâshid",
        "english": "May you not be tired",
        "context": "Common greeting to someone who has been working. Used as 'hello' in shops, offices, etc.",
    },
    {
        "persian": "دستت درد نکنه",
        "finglish": "dastet dard nakone",
        "english": "May your hand not hurt",
        "context": "Thank you for your effort (after someone does something for you)",
    },
    {
        "persian": "قابلی نداره",
        "finglish": "ghâbeli nadâre",
        "english": "It's not worthy (of you)",
        "context": "You're welcome / Don't mention it — said when giving a gift or doing a favour",
    },
    {
        "persian": "تعارف نکن",
        "finglish": "ta'ârof nakon",
        "english": "Don't do ta'arof",
        "context": "Stop being politely modest — please accept! Part of Persian ta'arof culture.",
    },
    {
        "persian": "نوش جان",
        "finglish": "nush-e jân",
        "english": "May it nourish your soul",
        "context": "Said to someone eating — like 'bon appétit' or 'enjoy your meal'",
    },
    {
        "persian": "چشمت روز بد نبینه",
        "finglish": "cheshmet ruz-e bad nabine",
        "english": "May your eyes never see a bad day",
        "context": "A warm wish for someone's wellbeing",
    },
    {
        "persian": "قدمت روی چشم",
        "finglish": "ghadamet ru-ye cheshm",
        "english": "Your step is on my eye",
        "context": "Warm welcome — 'you're very welcome here'. Extremely hospitable expression.",
    },
    {
        "persian": "ان‌شاءالله",
        "finglish": "inshâ'allâh",
        "english": "God willing",
        "context": "Used when talking about future plans. Very common in daily speech.",
    },
    {
        "persian": "ماشاءالله",
        "finglish": "mâshâ'allâh",
        "english": "What God has willed",
        "context": "Expression of admiration or praise, also used to ward off the evil eye.",
    },
    {
        "persian": "الهی شکر",
        "finglish": "elâhi shokr",
        "english": "Thank God",
        "context": "Expression of gratitude, similar to 'thankfully'",
    },
    {
        "persian": "به سلامتی",
        "finglish": "be salâmati",
        "english": "To your health / Cheers",
        "context": "A toast or general well-wishing expression",
    },
    {
        "persian": "عید مبارک",
        "finglish": "eyd mobârak",
        "english": "Happy holiday/celebration",
        "context": "Used for any celebration, especially Nowruz",
    },
    {
        "persian": "تسلیت می‌گم",
        "finglish": "tasliyat migam",
        "english": "I offer my condolences",
        "context": "Expressing sympathy when someone has lost a loved one",
    },
    {
        # NOTE(review): finglish "khodâ biâmorzesh" ("may God forgive him/her")
        # does not quite match the Persian text "خدا بیامرزه" ("khodâ biâmorze");
        # confirm which form is intended before changing either.
        "persian": "خدا بیامرزه",
        "finglish": "khodâ biâmorzesh",
        "english": "May God forgive them (rest in peace)",
        "context": "Said about someone who has passed away",
    },
    {
        "persian": "زبونت رو گاز بگیر",
        "finglish": "zaboonet ro gâz begir",
        "english": "Bite your tongue",
        "context": "Don't say such things! (similar to English 'touch wood')",
    },
    {
        "persian": "دمت گرم",
        "finglish": "damet garm",
        "english": "May your breath be warm",
        "context": "Well done! / Good for you! (informal, friendly praise)",
    },
    {
        "persian": "چشم",
        "finglish": "cheshm",
        "english": "On my eye (I will do it)",
        "context": "Respectful way of saying 'yes, I'll do it' — shows obedience/respect",
    },
    {
        "persian": "بفرمایید",
        "finglish": "befarmâyid",
        "english": "Please (go ahead / help yourself / come in)",
        "context": "Very versatile polite expression: offering food, inviting someone in, or giving way",
    },
    {
        "persian": "ببخشید",
        "finglish": "bebakhshid",
        "english": "Excuse me / I'm sorry",
        "context": "Used for both apologies and getting someone's attention",
    },
    {
        "persian": "مخلصیم",
        "finglish": "mokhlesim",
        "english": "I'm your humble servant",
        "context": "Polite/humble way of saying goodbye or responding to a compliment (ta'arof)",
    },
    {
        "persian": "سرت سلامت باشه",
        "finglish": "saret salâmat bâshe",
        "english": "May your head be safe",
        "context": "Expression of condolence — 'I'm sorry for your loss'",
    },
    {
        "persian": "روی ما رو زمین ننداز",
        "finglish": "ru-ye mâ ro zamin nandâz",
        "english": "Don't throw our face on the ground",
        "context": "Please don't refuse/embarrass us — said when insisting on a request",
    },
    {
        "persian": "قربونت برم",
        "finglish": "ghorboonet beram",
        "english": "I'd sacrifice myself for you",
        "context": "Term of endearment — very common among family and close friends",
    },
    {
        "persian": "جون دل",
        "finglish": "jun-e del",
        "english": "Life of my heart",
        "context": "Affectionate term used with loved ones",
    },
]
def get_all_expressions():
    """Return the full built-in expression list (the module-level list itself, not a copy)."""
    return EXPRESSIONS
def get_random_expression():
    """Return one built-in expression chosen uniformly at random."""
    from random import choice
    return choice(EXPRESSIONS)
def explain_expression(expression):
    """Ask the AI for a detailed, student-friendly explanation of one expression.

    Uses the "fast" (local) model since this is an interactive browse feature.
    """
    # The prompt embeds all four fields of the expression dict verbatim.
    query = f"""Explain this Persian expression for an English-speaking student:
Persian: {expression['persian']}
Transliteration: {expression['finglish']}
Literal meaning: {expression['english']}
Context: {expression['context']}
Please provide:
1. A fuller explanation of when and how this is used
2. The cultural context (ta'arof, hospitality, etc.)
3. Two example dialogues showing it in use (in Persian with English translation)
4. Any variations or related expressions
Keep it concise and student-friendly."""
    return ask(query, quality="fast")
def format_expression(expr, show_transliteration="off"):
    """Render one expression as centred HTML; the Persian line is RTL.

    Row order: Persian, English, [transliteration], context note.
    """
    rows = [
        f'<div dir="rtl" style="font-size:1.8em; text-align:center">{expr["persian"]}</div>',
        f'<div style="text-align:center; font-size:1.2em">{expr["english"]}</div>',
        f'<div style="text-align:center; color:#888; margin-top:0.5em">{expr["context"]}</div>',
    ]
    if show_transliteration != "off":
        # Transliteration sits between the English gloss and the context note.
        rows.insert(2, f'<div style="text-align:center; color:#666; font-style:italic">{expr["finglish"]}</div>')
    return "\n".join(rows)

View File

@@ -0,0 +1,65 @@
"""Conversational Persian lessons by GCSE theme."""
import time
import db
from ai import chat_ollama
# Base system prompt for the conversational tutor; start_lesson() appends the
# chosen theme to it before the first model call.
TUTOR_SYSTEM_PROMPT = """You are a friendly Persian (Farsi) language tutor teaching English-speaking GCSE students.
Rules:
- Use a mix of English and Persian. Start mostly in English, gradually introducing more Persian.
- When you write Persian, also provide the Finglish transliteration in parentheses.
- Keep responses concise (2-4 sentences per turn).
- Ask the student to practice: translate phrases, answer questions in Persian, or fill in blanks.
- Correct mistakes gently and explain why.
- Stay on the current theme/topic.
- Use Iranian Persian (Farsi), not Dari or Tajik.
- Adapt to the student's level based on their responses."""

# Per-theme kick-off sentences; "Free conversation" doubles as the fallback
# when an unknown theme name is passed to start_lesson().
THEME_PROMPTS = {
    "Identity and culture": "Let's practice talking about family, personality, daily routines, and Persian celebrations like Nowruz!",
    "Local area and environment": "Let's practice talking about your home, neighbourhood, shopping, and the environment!",
    "School and work": "Let's practice talking about school subjects, school life, jobs, and future plans!",
    "Travel and tourism": "Let's practice talking about transport, directions, holidays, hotels, and restaurants!",
    "International and global dimension": "Let's practice talking about health, global issues, technology, and social media!",
    "Free conversation": "Let's have a free conversation in Persian! I'll help you along the way.",
}
def start_lesson(theme):
    """Generate the opening message for a new lesson.

    Args:
        theme: GCSE theme name; unknown themes fall back to "Free conversation".

    Returns:
        (assistant_message, messages_list, system_prompt) — the system prompt
        is returned too so later turns can pass it back to process_response().
        (The old docstring claimed a 2-tuple; the function returns three values.)
    """
    intro = THEME_PROMPTS.get(theme, THEME_PROMPTS["Free conversation"])
    # Theme is baked into the system prompt for the whole session.
    system = TUTOR_SYSTEM_PROMPT + f"\n\nCurrent topic: {theme}. {intro}"
    messages = [{"role": "user", "content": f"I'd like to practice Persian. Today's theme is: {theme}"}]
    response = chat_ollama(messages, system=system)
    messages.append({"role": "assistant", "content": response})
    return response, messages, system
def process_response(user_input, messages, system=None):
    """Append the student's turn, query the model, and append its reply.

    Returns:
        (assistant_response, updated_messages). Blank input is a no-op that
        returns ("", messages) without calling the model.
    """
    text = (user_input or "").strip()
    if not text:
        return "", messages
    messages.append({"role": "user", "content": text})
    reply = chat_ollama(messages, system=system)
    messages.append({"role": "assistant", "content": reply})
    return reply, messages
def save_session(theme, messages, start_time):
    """Persist a tutor session; duration is wall-clock seconds since start_time."""
    elapsed_seconds = int(time.time() - start_time)
    db.save_tutor_session(theme, messages, elapsed_seconds)

View File

@@ -0,0 +1,153 @@
"""Vocabulary search, flashcard logic, and FSRS-driven review."""
import json
import random
from pathlib import Path
import fsrs
import db
# Vocabulary JSON lives at <project root>/data/vocabulary.json.
VOCAB_PATH = Path(__file__).parent.parent / "data" / "vocabulary.json"
# Module-level cache, populated lazily by load_vocab().
_vocab_data = None
def load_vocab():
    """Load the vocabulary JSON, caching it in a module-level global."""
    global _vocab_data
    if _vocab_data is None:
        _vocab_data = json.loads(VOCAB_PATH.read_text(encoding="utf-8"))
    return _vocab_data
def get_categories():
    """Return the unique category names, sorted alphabetically."""
    unique = {entry["category"] for entry in load_vocab()}
    return sorted(unique)
def get_sections():
    """Return the unique section names, sorted alphabetically."""
    unique = {entry["section"] for entry in load_vocab()}
    return sorted(unique)
def search(query, vocab_data=None):
    """Search vocabulary by English, Persian, or Finglish text.

    Args:
        query: Search string; blank/None yields no results.
        vocab_data: Optional explicit vocabulary list; defaults to the cached
            JSON vocabulary when None.

    Returns:
        List of matching entries (substring match; case-insensitive for
        English/Finglish — Persian script has no case, so it is matched as-is).
    """
    if not query or not query.strip():
        return []
    # BUG FIX: "vocab_data or load_vocab()" treated an explicitly-passed empty
    # list as "not supplied" and silently fell back to the full vocabulary.
    vocab = vocab_data if vocab_data is not None else load_vocab()
    query_lower = query.strip().lower()
    results = []
    for entry in vocab:
        if (
            query_lower in entry["english"].lower()
            or query_lower in entry["persian"]
            or (entry.get("finglish") and query_lower in entry["finglish"].lower())
        ):
            results.append(entry)
    return results
def get_random_word(vocab_data=None, category=None):
    """Pick a random vocabulary entry, optionally filtered by category.

    Args:
        vocab_data: Optional explicit vocabulary list (defaults to cached JSON).
        category: Category name, or None/"All" for no filtering.

    Returns:
        A random matching entry, or None when nothing matches.
    """
    # BUG FIX: don't treat an explicitly-passed empty list as "use the full
    # vocabulary" — only fall back to load_vocab() when vocab_data is None.
    vocab = vocab_data if vocab_data is not None else load_vocab()
    if category and category != "All":
        filtered = [e for e in vocab if e["category"] == category]
    else:
        filtered = vocab
    if not filtered:
        return None
    return random.choice(filtered)
def get_flashcard_batch(count=10, category=None):
    """Get a batch of words for flashcard study.

    Prioritizes due words (FSRS), then fills with new/random words.

    Args:
        count: Target batch size (may return fewer if the pool is small).
        category: Optional category filter; None or "All" means no filter.

    Returns:
        A shuffled list of at most `count` vocabulary entries.
    """
    vocab = load_vocab()
    if category and category != "All":
        pool = [e for e in vocab if e["category"] == category]
    else:
        pool = vocab
    # Get due words first
    # NOTE(review): assumes db.get_due_words returns a container of word IDs
    # supporting `in` — confirm against db.py.
    due_ids = db.get_due_words(limit=count)
    due_entries = [e for e in pool if e["id"] in due_ids]
    # Fill remaining with unseen or random words
    remaining = count - len(due_entries)
    if remaining > 0:
        seen_ids = {e["id"] for e in due_entries}
        # Single batch query here avoids an N+1 per-word progress lookup;
        # presumably keyed/iterable by word_id — verify in db.py.
        all_progress = db.get_all_word_progress()
        # Prefer unseen words
        unseen = [e for e in pool if e["id"] not in seen_ids and e["id"] not in all_progress]
        if len(unseen) >= remaining:
            fill = random.sample(unseen, remaining)
        else:
            # Use all unseen + random from rest
            fill = unseen
            still_needed = remaining - len(fill)
            # `e not in fill` is a linear dict-equality scan — acceptable at
            # vocabulary scale (~1k entries).
            rest = [e for e in pool if e["id"] not in seen_ids and e not in fill]
            if rest:
                fill.extend(random.sample(rest, min(still_needed, len(rest))))
        due_entries.extend(fill)
    random.shuffle(due_entries)
    return due_entries
def check_answer(word_id, user_answer, direction="en_to_fa"):
    """Check if the user's answer matches the target word.

    Args:
        word_id: Vocabulary entry ID.
        user_answer: What the user typed.
        direction: "en_to_fa" (user writes Persian) or "fa_to_en" (user writes English).

    Returns:
        (is_correct, correct_answer, entry) — when the user is right, the
        second element echoes their own (stripped) answer.
    """
    entry = None
    for candidate in load_vocab():
        if candidate["id"] == word_id:
            entry = candidate
            break
    if entry is None:
        return False, "", None
    answer = user_answer.strip()
    if direction == "en_to_fa":
        # Persian comparison is exact (no case in Persian script).
        expected = entry["persian"].strip()
        ok = answer == expected
    else:
        # English comparison is case-insensitive.
        expected = entry["english"].strip().lower()
        ok = answer.lower() == expected
    if ok:
        return True, answer, entry
    return False, expected, entry
def format_word_card(entry, show_transliteration="off"):
    """Format a vocabulary entry for display as RTL-safe markdown/HTML.

    Row order: Persian (RTL), English, [finglish], category footer.
    """
    want_finglish = show_transliteration != "off" and entry.get("finglish")
    rows = [
        f'<div dir="rtl" style="font-size:2em; text-align:center">{entry["persian"]}</div>',
        f'<div style="font-size:1.3em; text-align:center">{entry["english"]}</div>',
        f'<div style="text-align:center; color:#666; font-style:italic">{entry["finglish"]}</div>' if want_finglish else None,
        f'<div style="text-align:center; color:#999; font-size:0.9em">{entry.get("category", "")}</div>',
    ]
    return "\n".join(row for row in rows if row is not None)
def get_word_status(word_id):
    """Return status string for a word: "new", "learning", or "mastered"."""
    progress = db.get_word_progress(word_id)
    if not progress:
        # Never reviewed.
        return "new"
    stability = progress["stability"]
    # FSRS stability above 10 counts as mastered.
    return "mastered" if stability and stability > 10 else "learning"

View File

@@ -0,0 +1,3 @@
gradio>=4.0
genanki
fsrs

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,81 @@
#!/usr/bin/env python3
"""One-time script to generate/update vocabulary.json with AI-assisted transliterations.
Usage:
python scripts/generate_vocab.py
This reads an existing vocabulary.json, finds entries missing finglish
transliterations, and uses Ollama to generate them.
"""
import json
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
from ai import ask_ollama
VOCAB_PATH = Path(__file__).parent.parent / "data" / "vocabulary.json"
def generate_transliterations(vocab):
"""Fill in missing finglish transliterations using AI."""
missing = [e for e in vocab if not e.get("finglish")]
if not missing:
print("All entries already have finglish transliterations.")
return vocab
print(f"Generating transliterations for {len(missing)} entries...")
# Process in batches of 20
batch_size = 20
for i in range(0, len(missing), batch_size):
batch = missing[i : i + batch_size]
pairs = "\n".join(f"{e['persian']} = {e['english']}" for e in batch)
prompt = f"""For each Persian word below, provide the Finglish (romanized) transliteration.
Use these conventions: â for آ, kh for خ, sh for ش, zh for ژ, gh for ق/غ, ch for چ.
Reply with ONLY the transliterations, one per line, in the same order.
{pairs}"""
try:
response = ask_ollama(prompt, model="qwen2.5:7b")
lines = [l.strip() for l in response.strip().split("\n") if l.strip()]
for j, entry in enumerate(batch):
if j < len(lines):
# Clean up the response line
line = lines[j]
# Remove any numbering or equals signs
for sep in ["=", ":", "-", "."]:
if sep in line:
line = line.split(sep)[-1].strip()
entry["finglish"] = line
print(f" Processed {min(i + batch_size, len(missing))}/{len(missing)}")
except Exception as e:
print(f" Error processing batch: {e}")
return vocab
def main():
    """Entry point: load vocabulary.json, fill missing transliterations, save it back."""
    if not VOCAB_PATH.exists():
        print(f"No vocabulary file found at {VOCAB_PATH}")
        return
    vocab = json.loads(VOCAB_PATH.read_text(encoding="utf-8"))
    print(f"Loaded {len(vocab)} entries")
    vocab = generate_transliterations(vocab)
    # ensure_ascii=False keeps the Persian script readable in the JSON file.
    VOCAB_PATH.write_text(json.dumps(vocab, ensure_ascii=False, indent=2), encoding="utf-8")
    print(f"Saved {len(vocab)} entries to {VOCAB_PATH}")


if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,77 @@
"""Persian speech-to-text wrapper using sttlib."""
import sys
from pathlib import Path
import numpy as np
# sttlib lives in sibling project tool-speechtotext
_sttlib_path = str(Path(__file__).resolve().parent.parent / "tool-speechtotext")
sys.path.insert(0, _sttlib_path)
from sttlib import load_whisper_model, transcribe, is_hallucination
# Lazily-loaded Whisper model singleton (see get_model()).
_model = None
# Active Whisper model size; changed at runtime via set_whisper_size().
_whisper_size = "medium"
# Common Whisper hallucinations in Persian/silence — transcripts equal to one
# of these are discarded by transcribe_persian().
PERSIAN_HALLUCINATIONS = [
    "ممنون",  # "thank you" hallucination
    "خداحافظ",  # "goodbye" hallucination
    "تماشا کنید",  # "watch" hallucination
    "لایک کنید",  # "like" hallucination
]
def set_whisper_size(size):
    """Change the Whisper model size; the model reloads lazily on next use."""
    global _whisper_size, _model
    if size == _whisper_size:
        return
    _whisper_size = size
    _model = None  # drop the cached model so get_model() reloads at the new size
def get_model():
    """Return the cached Whisper model, loading it on first call."""
    global _model
    if _model is not None:
        return _model
    _model = load_whisper_model(_whisper_size)
    return _model
def transcribe_persian(audio_tuple):
"""Transcribe Persian audio from Gradio audio component.
Args:
audio_tuple: (sample_rate, numpy_array) from gr.Audio component.
Returns:
Transcribed text string, or empty string on failure/hallucination.
"""
if audio_tuple is None:
return ""
sr, audio = audio_tuple
model = get_model()
# Convert to float32 normalized [-1, 1]
if audio.dtype == np.int16:
audio_float = audio.astype(np.float32) / 32768.0
elif audio.dtype == np.float32:
audio_float = audio
else:
audio_float = audio.astype(np.float32) / np.iinfo(audio.dtype).max
# Mono conversion if stereo
if audio_float.ndim > 1:
audio_float = audio_float.mean(axis=1)
# Use sttlib transcribe
text = transcribe(model, audio_float)
# Filter hallucinations (English + Persian)
if is_hallucination(text):
return ""
if text.strip() in PERSIAN_HALLUCINATIONS:
return ""
return text

View File

View File

@@ -0,0 +1,89 @@
"""Tests for ai.py — dual AI backend."""
import sys
from pathlib import Path
from unittest.mock import patch, MagicMock
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
import ai
def test_ask_ollama_calls_ollama_chat():
    """ask_ollama should call ollama.chat with correct messages."""
    mock_response = MagicMock()
    mock_response.message.content = "test response"
    # Patch the client so no Ollama server is needed.
    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        result = ai.ask_ollama("Hello", system="Be helpful")
        assert result == "test response"
        call_args = mock_chat.call_args
        # messages may be passed positionally or by keyword.
        messages = call_args.kwargs.get("messages") or call_args[1].get("messages")
        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"


def test_ask_ollama_no_system():
    """ask_ollama without system prompt should only send user message."""
    mock_response = MagicMock()
    mock_response.message.content = "response"
    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        ai.ask_ollama("Hi")
        call_args = mock_chat.call_args
        messages = call_args.kwargs.get("messages") or call_args[1].get("messages")
        assert len(messages) == 1
        assert messages[0]["role"] == "user"


def test_ask_claude_calls_subprocess():
    """ask_claude should call claude CLI via subprocess."""
    with patch("ai.subprocess.run") as mock_run:
        # Trailing newline checks that ask_claude strips CLI output.
        mock_run.return_value = MagicMock(stdout="Claude says hi\n")
        result = ai.ask_claude("Hello")
        assert result == "Claude says hi"
        mock_run.assert_called_once()
        args = mock_run.call_args[0][0]
        assert args[0] == "claude"
        assert "-p" in args


def test_ask_fast_uses_ollama():
    """ask with quality='fast' should use Ollama."""
    with patch("ai.ask_ollama", return_value="ollama response") as mock:
        result = ai.ask("test", quality="fast")
        assert result == "ollama response"
        mock.assert_called_once()


def test_ask_smart_uses_claude():
    """ask with quality='smart' should use Claude."""
    with patch("ai.ask_claude", return_value="claude response") as mock:
        result = ai.ask("test", quality="smart")
        assert result == "claude response"
        mock.assert_called_once()


def test_chat_ollama():
    """chat_ollama should pass multi-turn messages."""
    mock_response = MagicMock()
    mock_response.message.content = "continuation"
    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        result = ai.chat_ollama(messages, system="Be helpful")
        assert result == "continuation"
        call_args = mock_chat.call_args
        all_msgs = call_args.kwargs.get("messages") or call_args[1].get("messages")
        # system + 3 conversation messages
        assert len(all_msgs) == 4

View File

@@ -0,0 +1,86 @@
"""Tests for anki_export.py — Anki .apkg generation."""
import os
import sys
import tempfile
import zipfile
from pathlib import Path
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
from anki_export import export_deck
# Minimal three-entry vocabulary (two categories) used by all tests below.
SAMPLE_VOCAB = [
    {
        "id": "verb_go",
        "section": "High-frequency language",
        "category": "Common verbs",
        "english": "to go",
        "persian": "رفتن",
        "finglish": "raftan",
    },
    {
        "id": "verb_eat",
        "section": "High-frequency language",
        "category": "Common verbs",
        "english": "to eat",
        "persian": "خوردن",
        "finglish": "khordan",
    },
    {
        "id": "colour_red",
        "section": "High-frequency language",
        "category": "Colours",
        "english": "red",
        "persian": "قرمز",
        "finglish": "ghermez",
    },
]


def test_export_deck_creates_file(tmp_path):
    """export_deck should create a valid .apkg file."""
    output = str(tmp_path / "test.apkg")
    result = export_deck(SAMPLE_VOCAB, output_path=output)
    # export_deck returns the path it wrote.
    assert result == output
    assert os.path.exists(output)
    assert os.path.getsize(output) > 0


def test_export_deck_is_valid_zip(tmp_path):
    """An .apkg file is a zip archive containing an Anki SQLite database."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, output_path=output)
    assert zipfile.is_zipfile(output)


def test_export_deck_with_category_filter(tmp_path):
    """export_deck with category filter should only include matching entries."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, categories=["Colours"], output_path=output)
    # File should exist and be smaller than unfiltered
    assert os.path.exists(output)
    size_filtered = os.path.getsize(output)
    output2 = str(tmp_path / "test_all.apkg")
    export_deck(SAMPLE_VOCAB, output_path=output2)
    size_all = os.path.getsize(output2)
    # Filtered deck should be smaller (fewer cards); <= because compression
    # can make sizes equal for tiny decks.
    assert size_filtered <= size_all


def test_export_deck_empty_vocab(tmp_path):
    """export_deck with empty vocabulary should still create a valid file."""
    output = str(tmp_path / "test.apkg")
    export_deck([], output_path=output)
    assert os.path.exists(output)


def test_export_deck_no_category_match(tmp_path):
    """export_deck with non-matching category filter should create empty deck."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, categories=["Nonexistent"], output_path=output)
    assert os.path.exists(output)

View File

@@ -0,0 +1,151 @@
"""Tests for db.py — SQLite database layer with FSRS integration."""
import os
import sys
import tempfile
from pathlib import Path
from unittest.mock import patch
import pytest
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent.parent))
import fsrs
@pytest.fixture(autouse=True)
def temp_db(tmp_path):
    """Use a temporary database for each test.

    Resets the db module's cached connection and redirects DB_PATH so every
    test starts from a fresh SQLite file; closes the connection on teardown.
    """
    import db as db_mod
    db_mod._conn = None
    db_mod.DB_PATH = tmp_path / "test.db"
    db_mod.init_db()
    yield db_mod
    db_mod.close()


def test_init_db_creates_tables(temp_db):
    """init_db should create all required tables."""
    conn = temp_db.get_connection()
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    ).fetchall()
    table_names = {row["name"] for row in tables}
    assert "word_progress" in table_names
    assert "quiz_sessions" in table_names
    assert "essays" in table_names
    assert "tutor_sessions" in table_names


def test_get_word_progress_nonexistent(temp_db):
    """Should return None for a word that hasn't been reviewed."""
    assert temp_db.get_word_progress("nonexistent") is None


def test_update_and_get_word_progress(temp_db):
    """update_word_progress should create and update progress."""
    card = temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    assert card is not None
    assert card.stability is not None
    progress = temp_db.get_word_progress("verb_go")
    assert progress is not None
    assert progress["word_id"] == "verb_go"
    assert progress["reps"] == 1
    assert progress["fsrs_state"] is not None


def test_update_word_progress_increments_reps(temp_db):
    """Reviewing the same word multiple times should increment reps."""
    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    temp_db.update_word_progress("verb_go", fsrs.Rating.Easy)
    progress = temp_db.get_word_progress("verb_go")
    assert progress["reps"] == 2


def test_get_due_words(temp_db):
    """get_due_words should return words that are due for review."""
    # A newly reviewed word with Rating.Again should be due soon
    temp_db.update_word_progress("verb_go", fsrs.Rating.Again)
    # An easy word should have a later due date
    temp_db.update_word_progress("verb_eat", fsrs.Rating.Easy)
    # Due words depend on timing; at minimum both should be in the system
    all_progress = temp_db.get_connection().execute(
        "SELECT word_id FROM word_progress"
    ).fetchall()
    assert len(all_progress) == 2


def test_get_word_counts(temp_db):
    """get_word_counts should return correct counts."""
    counts = temp_db.get_word_counts(total_vocab_size=100)
    assert counts["total"] == 100
    assert counts["seen"] == 0
    assert counts["mastered"] == 0
    assert counts["due"] == 0
    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    counts = temp_db.get_word_counts(total_vocab_size=100)
    assert counts["seen"] == 1


def test_record_quiz_session(temp_db):
    """record_quiz_session should insert a quiz record."""
    temp_db.record_quiz_session("Common verbs", 10, 7, 120)
    rows = temp_db.get_connection().execute(
        "SELECT * FROM quiz_sessions"
    ).fetchall()
    assert len(rows) == 1
    assert rows[0]["correct"] == 7
    assert rows[0]["total_questions"] == 10


def test_save_essay(temp_db):
    """save_essay should store the essay and feedback."""
    temp_db.save_essay("متن آزمایشی", "B1", "Good effort!", "Identity and culture")
    essays = temp_db.get_recent_essays()
    assert len(essays) == 1
    assert essays[0]["grade"] == "B1"


def test_save_tutor_session(temp_db):
    """save_tutor_session should store the conversation."""
    messages = [
        {"role": "user", "content": "سلام"},
        {"role": "assistant", "content": "سلام! حالت چطوره؟"},
    ]
    temp_db.save_tutor_session("Identity and culture", messages, 300)
    rows = temp_db.get_connection().execute(
        "SELECT * FROM tutor_sessions"
    ).fetchall()
    assert len(rows) == 1
    assert rows[0]["theme"] == "Identity and culture"


def test_get_stats(temp_db):
    """get_stats should return aggregated stats."""
    stats = temp_db.get_stats()
    assert stats["total_reviews"] == 0
    assert stats["total_quizzes"] == 0
    assert stats["streak"] == 0
    assert isinstance(stats["recent_quizzes"], list)


def test_close_and_reopen(temp_db):
    """Closing and reopening should preserve data."""
    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    db_path = temp_db.DB_PATH
    temp_db.close()
    # Reopen against the same file and confirm the row survived.
    temp_db._conn = None
    temp_db.DB_PATH = db_path
    temp_db.init_db()
    progress = temp_db.get_word_progress("verb_go")
    assert progress is not None
    assert progress["reps"] == 1

View File

@@ -0,0 +1,204 @@
"""Tests for modules/vocab.py — vocabulary search and flashcard logic."""
import json
import sys
from pathlib import Path
from unittest.mock import patch
import pytest
sys.path.insert(0, str(Path(__file__).parent.parent))
SAMPLE_VOCAB = [
{
"id": "verb_go",
"section": "High-frequency language",
"category": "Common verbs",
"english": "to go",
"persian": "رفتن",
"finglish": "raftan",
},
{
"id": "verb_eat",
"section": "High-frequency language",
"category": "Common verbs",
"english": "to eat",
"persian": "خوردن",
"finglish": "khordan",
},
{
"id": "adj_big",
"section": "High-frequency language",
"category": "Common adjectives",
"english": "big",
"persian": "بزرگ",
"finglish": "bozorg",
},
{
"id": "colour_red",
"section": "High-frequency language",
"category": "Colours",
"english": "red",
"persian": "قرمز",
"finglish": "ghermez",
},
]
@pytest.fixture(autouse=True)
def mock_vocab_and_db(tmp_path):
"""Mock vocabulary loading and use temp DB."""
import db as db_mod
import modules.vocab as vocab_mod
# Temp DB
db_mod._conn = None
db_mod.DB_PATH = tmp_path / "test.db"
db_mod.init_db()
# Mock vocab
vocab_mod._vocab_data = SAMPLE_VOCAB
yield vocab_mod
db_mod.close()
vocab_mod._vocab_data = None
def test_load_vocab(mock_vocab_and_db):
"""load_vocab should return the vocabulary data."""
data = mock_vocab_and_db.load_vocab()
assert len(data) == 4
def test_get_categories(mock_vocab_and_db):
"""get_categories should return unique sorted categories."""
cats = mock_vocab_and_db.get_categories()
assert "Colours" in cats
assert "Common verbs" in cats
assert "Common adjectives" in cats
def test_search_english(mock_vocab_and_db):
"""Search should find entries by English text."""
results = mock_vocab_and_db.search("go")
assert len(results) == 1
assert results[0]["id"] == "verb_go"
def test_search_persian(mock_vocab_and_db):
"""Search should find entries by Persian text."""
results = mock_vocab_and_db.search("رفتن")
assert len(results) == 1
assert results[0]["id"] == "verb_go"
def test_search_finglish(mock_vocab_and_db):
"""Search should find entries by Finglish text."""
results = mock_vocab_and_db.search("raftan")
assert len(results) == 1
assert results[0]["id"] == "verb_go"
def test_search_empty(mock_vocab_and_db):
"""Empty search should return empty list."""
assert mock_vocab_and_db.search("") == []
assert mock_vocab_and_db.search(None) == []
def test_search_no_match(mock_vocab_and_db):
"""Search with no match should return empty list."""
assert mock_vocab_and_db.search("zzzzz") == []
def test_get_random_word(mock_vocab_and_db):
"""get_random_word should return a valid entry."""
word = mock_vocab_and_db.get_random_word()
assert word is not None
assert "id" in word
assert "english" in word
assert "persian" in word
def test_get_random_word_with_category(mock_vocab_and_db):
"""get_random_word with category filter should only return matching entries."""
word = mock_vocab_and_db.get_random_word(category="Colours")
assert word is not None
assert word["category"] == "Colours"
def test_get_random_word_nonexistent_category(mock_vocab_and_db):
"""get_random_word with bad category should return None."""
word = mock_vocab_and_db.get_random_word(category="Nonexistent")
assert word is None
def test_check_answer_correct_en_to_fa(mock_vocab_and_db):
"""Correct Persian answer should be marked correct."""
correct, answer, entry = mock_vocab_and_db.check_answer(
"verb_go", "رفتن", direction="en_to_fa"
)
assert correct is True
def test_check_answer_incorrect_en_to_fa(mock_vocab_and_db):
"""Incorrect Persian answer should be marked incorrect with correct answer."""
correct, answer, entry = mock_vocab_and_db.check_answer(
"verb_go", "خوردن", direction="en_to_fa"
)
assert correct is False
assert answer == "رفتن"
def test_check_answer_fa_to_en(mock_vocab_and_db):
"""Correct English answer (case-insensitive) should be marked correct."""
correct, answer, entry = mock_vocab_and_db.check_answer(
"verb_go", "To Go", direction="fa_to_en"
)
assert correct is True
def test_check_answer_nonexistent_word(mock_vocab_and_db):
"""Checking answer for nonexistent word should return False."""
correct, answer, entry = mock_vocab_and_db.check_answer(
"nonexistent", "test", direction="en_to_fa"
)
assert correct is False
assert entry is None
def test_format_word_card(mock_vocab_and_db):
    """The word-card HTML carries the Persian, English and Finglish forms."""
    card = mock_vocab_and_db.format_word_card(
        SAMPLE_VOCAB[0], show_transliteration="Finglish"
    )
    for fragment in ("رفتن", "to go", "raftan"):
        assert fragment in card


def test_format_word_card_no_transliteration(mock_vocab_and_db):
    """Turning transliteration off hides the Finglish form."""
    card = mock_vocab_and_db.format_word_card(
        SAMPLE_VOCAB[0], show_transliteration="off"
    )
    assert "raftan" not in card


def test_get_flashcard_batch(mock_vocab_and_db):
    """A requested batch size is honoured and every entry carries an id."""
    batch = mock_vocab_and_db.get_flashcard_batch(count=2)
    assert len(batch) == 2
    for entry in batch:
        assert "id" in entry
def test_get_word_status_new(mock_vocab_and_db):
    """A word that has never been reviewed is reported as 'new'."""
    status = mock_vocab_and_db.get_word_status("verb_go")
    assert status == "new"


def test_get_word_status_learning(mock_vocab_and_db):
    """A word reviewed once moves to the 'learning' state."""
    import db
    import fsrs

    db.update_word_progress("verb_go", fsrs.Rating.Good)
    status = mock_vocab_and_db.get_word_status("verb_go")
    assert status == "learning"

View File

@@ -1,5 +1,11 @@
{ {
"python-envs.defaultEnvManager": "ms-python.python:conda", "python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda", "python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": [] "python-envs.pythonProjects": [],
"python.testing.pytestEnabled": true,
"python.testing.unittestEnabled": false,
"python.testing.pytestArgs": [
"tests",
"-v"
]
} }

View File

@@ -12,11 +12,20 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
## Tools ## Tools
- `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama - `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama
- `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling - `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling
- `voice_to_xdotool.py` / `dotool.sh` — hands-free voice typing into any focused window (VAD + xdotool) - `voice_to_xdotool.py` / `xdotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
## Shared Library
- `sttlib/` — shared package used by all scripts and importable by other projects
- `whisper_loader.py` — model loading with GPU→CPU fallback
- `audio.py` — press-enter recording, PCM conversion
- `transcription.py` — Whisper transcribe wrapper, hallucination filter
- `vad.py` — VADProcessor, audio callback, constants
- Other projects import via: `sys.path.insert(0, "/path/to/tool-speechtotext")`
## Testing ## Testing
- To test scripts: `mamba run -n whisper-ollama python <script.py> --model-size base` - Run tests: `mamba run -n whisper-ollama python -m pytest tests/`
- Use `--model-size base` for faster iteration during development - Use `--model-size base` for faster iteration during development
- Tests mock hardware (Whisper model, VAD, mic) — no GPU/mic needed to run them
- Audio device is available — live mic testing is possible - Audio device is available — live mic testing is possible
- Test xdotool output by focusing a text editor window - Test xdotool output by focusing a text editor window
@@ -24,12 +33,13 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
- Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama - Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama
- Pip (in conda env): webrtcvad - Pip (in conda env): webrtcvad
- System: libportaudio2, xdotool - System: libportaudio2, xdotool
- Dev: pytest
## Conventions ## Conventions
- Shell wrappers go in .sh files using `mamba run -n whisper-ollama` - Shell wrappers go in .sh files using `mamba run -n whisper-ollama`
- All scripts set `CT2_CUDA_ALLOW_FP16=1` - Shared code lives in `sttlib/` — scripts are thin entry points that import from it
- Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback - Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback
- Keep scripts self-contained (no shared module) - `CT2_CUDA_ALLOW_FP16=1` is set by `sttlib.whisper_loader` at import time
- Don't print output for non-actionable events - Don't print output for non-actionable events
## Preferences ## Preferences

View File

@@ -1,57 +1,17 @@
import sounddevice as sd import argparse
import numpy as np
import pyperclip import pyperclip
import requests import requests
import sys from sttlib import load_whisper_model, record_until_enter, transcribe
import argparse
from faster_whisper import WhisperModel
import os
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
# --- Configuration --- # --- Configuration ---
MODEL_SIZE = "medium" # Options: "base", "small", "medium", "large-v3" OLLAMA_URL = "http://localhost:11434/api/generate"
OLLAMA_URL = "http://localhost:11434/api/generate" # Default is 11434
DEFAULT_OLLAMA_MODEL = "qwen3:latest" DEFAULT_OLLAMA_MODEL = "qwen3:latest"
# Load Whisper on GPU
# float16 is faster and uses less VRAM on NVIDIA cards
print("Loading Whisper model...")
try:
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
except Exception as e:
print(f"Error loading GPU: {e}")
print("Falling back to CPU (Check your CUDA/cuDNN installation)")
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="int16")
def record_audio():
fs = 16000
print("\n[READY] Press Enter to START recording...")
input()
print("[RECORDING] Press Enter to STOP...")
recording = []
def callback(indata, frames, time, status):
if status:
print(status, file=sys.stderr)
recording.append(indata.copy())
with sd.InputStream(samplerate=fs, channels=1, callback=callback):
input()
return np.concatenate(recording, axis=0)
def main(): def main():
# 1. Setup Parser
print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}") print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}")
parser = argparse.ArgumentParser(description="Whisper + Ollama CLI") parser = argparse.ArgumentParser(description="Whisper + Ollama CLI")
# Known Arguments (Hardcoded logic)
parser.add_argument("--nollm", "-n", action='store_true', parser.add_argument("--nollm", "-n", action='store_true',
help="turn off llm") help="turn off llm")
parser.add_argument("--system", "-s", default=None, parser.add_argument("--system", "-s", default=None,
@@ -65,30 +25,27 @@ def main():
parser.add_argument( parser.add_argument(
"--temp", default='0.7', help="temperature") "--temp", default='0.7', help="temperature")
# 2. Capture "Unknown" arguments
# args = known values, unknown = a list like ['--num_ctx', '4096', '--temp', '0.7']
args, unknown = parser.parse_known_args() args, unknown = parser.parse_known_args()
# Convert unknown list to a dictionary for the Ollama 'options' field # Convert unknown list to a dictionary for the Ollama 'options' field
# This logic pairs ['--key', 'value'] into {key: value}
extra_options = {} extra_options = {}
for i in range(0, len(unknown), 2): for i in range(0, len(unknown), 2):
key = unknown[i].lstrip('-') # remove the '--' key = unknown[i].lstrip('-')
val = unknown[i+1] val = unknown[i+1]
# Try to convert numbers to actual ints/floats
try: try:
val = float(val) if '.' in val else int(val) val = float(val) if '.' in val else int(val)
except ValueError: except ValueError:
pass pass
extra_options[key] = val extra_options[key] = val
model = load_whisper_model(args.model_size)
while True: while True:
try: try:
audio_data = record_audio() audio_data = record_until_enter()
print("[TRANSCRIBING]...") print("[TRANSCRIBING]...")
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5) text = transcribe(model, audio_data.flatten())
text = "".join([segment.text for segment in segments]).strip()
if not text: if not text:
print("No speech detected. Try again.") print("No speech detected. Try again.")
@@ -97,8 +54,7 @@ def main():
print(f"You said: {text}") print(f"You said: {text}")
pyperclip.copy(text) pyperclip.copy(text)
if (args.nollm == False): if not args.nollm:
# Send to Ollama
print(f"[OLLAMA] Thinking...") print(f"[OLLAMA] Thinking...")
payload = { payload = {
"model": args.ollama_model, "model": args.ollama_model,
@@ -108,9 +64,9 @@ def main():
} }
if args.system: if args.system:
payload["system"] = args payload["system"] = args.system
response = requests.post(OLLAMA_URL, json=payload)
response = requests.post(OLLAMA_URL, json=payload)
result = response.json().get("response", "") result = response.json().get("response", "")
print(f"\nLLM Response:\n{result}\n") print(f"\nLLM Response:\n{result}\n")
else: else:

View File

@@ -0,0 +1,7 @@
from sttlib.whisper_loader import load_whisper_model
from sttlib.audio import record_until_enter, pcm_bytes_to_float32
from sttlib.transcription import transcribe, is_hallucination, HALLUCINATION_PATTERNS
from sttlib.vad import (
VADProcessor, audio_callback, audio_queue,
SAMPLE_RATE, CHANNELS, FRAME_DURATION_MS, FRAME_SIZE, MIN_UTTERANCE_FRAMES,
)

View File

@@ -0,0 +1,28 @@
import sys
import numpy as np
import sounddevice as sd
def record_until_enter(sample_rate=16000):
    """Record mono audio from the default input device until the user presses Enter.

    Blocks on stdin twice: once to start recording, once to stop.

    Args:
        sample_rate: capture rate in Hz (default 16000 — the rate the rest of
            sttlib assumes).

    Returns:
        float32 numpy array of shape (n_samples, 1); an empty (0, 1) array if
        the user stopped before any audio frames were delivered.
    """
    print("\n[READY] Press Enter to START recording...")
    input()
    print("[RECORDING] Press Enter to STOP...")
    recording = []

    def callback(indata, frames, time, status):
        # status reports device over/underruns; surface it but keep capturing.
        if status:
            print(status, file=sys.stderr)
        recording.append(indata.copy())

    with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
        input()

    if not recording:
        # Guard: np.concatenate([]) raises ValueError when Enter is pressed
        # before the stream delivers its first block.
        return np.zeros((0, 1), dtype=np.float32)
    return np.concatenate(recording, axis=0)
def pcm_bytes_to_float32(pcm_bytes):
    """Convert raw 16-bit PCM bytes to a float32 array normalized to [-1, 1]."""
    samples = np.frombuffer(pcm_bytes, dtype=np.int16)
    # int16 full scale is 32768, so every quotient is exactly representable
    # in float32 and the cast loses nothing.
    return (samples / 32768.0).astype(np.float32)

View File

@@ -0,0 +1,19 @@
# Lowercase substrings that commonly show up when Whisper invents text for
# silence or background noise; consumed by is_hallucination() below.
HALLUCINATION_PATTERNS = [
    "thank you", "thanks for watching", "subscribe",
    "bye", "the end", "thank you for watching",
    "please subscribe", "like and subscribe",
]


def transcribe(model, audio_float32):
    """Run Whisper over a float32 audio buffer and return the stripped text."""
    segments, _info = model.transcribe(audio_float32, beam_size=5)
    pieces = [segment.text for segment in segments]
    return "".join(pieces).strip()


def is_hallucination(text):
    """Return True if text is too short or matches a known hallucination phrase."""
    normalized = text.lower().strip()
    # Under three characters there is nothing meaningful to type/forward.
    if len(normalized) < 3:
        return True
    for phrase in HALLUCINATION_PATTERNS:
        if phrase in normalized:
            return True
    return False

View File

@@ -0,0 +1,58 @@
import sys
import queue
import collections
import webrtcvad
# Frame geometry shared by the VAD pipeline and its consumers.
SAMPLE_RATE = 16000
CHANNELS = 1
FRAME_DURATION_MS = 30
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000)  # 480 samples
MIN_UTTERANCE_FRAMES = 10  # ~300ms minimum to filter coughs/clicks

# Thread-safe hand-off from the sounddevice callback thread to the consumer loop.
audio_queue = queue.Queue()


def audio_callback(indata, frames, time_info, status):
    """sounddevice callback that pushes raw bytes to the audio queue."""
    # status reports over/underruns; surface it on stderr but keep streaming.
    if status:
        print(status, file=sys.stderr)
    # Copy the buffer out as bytes — sounddevice reuses `indata` between calls.
    audio_queue.put(bytes(indata))
class VADProcessor:
    """Voice-activity state machine over 30 ms PCM frames.

    Accumulates frames while speech is detected — including a short pre-roll
    captured just before speech onset — and emits the complete utterance once
    `silence_threshold` seconds of silence follow it.
    """

    def __init__(self, aggressiveness, silence_threshold):
        # aggressiveness: webrtcvad mode 0-3 (higher = more aggressive filtering).
        # silence_threshold: seconds of continuous silence that end an utterance.
        self.vad = webrtcvad.Vad(aggressiveness)
        self.silence_threshold = silence_threshold
        self.reset()

    def reset(self):
        """Return to the idle (not-triggered) state, clearing all buffers."""
        self.triggered = False
        self.utterance_frames = []
        self.silence_duration = 0.0
        self.pre_buffer = collections.deque(maxlen=10)  # ~300ms pre-roll

    def process_frame(self, frame_bytes):
        """Process one 30ms frame. Returns utterance bytes when complete, else None."""
        is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)

        if not self.triggered:
            self.pre_buffer.append(frame_bytes)
            if is_speech:
                self.triggered = True
                self.silence_duration = 0.0
                # Seed the utterance with the pre-roll. The pre-roll already
                # ends with the current (triggering) frame, so it must NOT be
                # appended again — doing so duplicated ~30ms of audio.
                self.utterance_frames = list(self.pre_buffer)
            return None

        self.utterance_frames.append(frame_bytes)
        if is_speech:
            self.silence_duration = 0.0
            return None

        self.silence_duration += FRAME_DURATION_MS / 1000.0
        if self.silence_duration < self.silence_threshold:
            return None

        # Silence threshold reached: emit, unless the utterance is too short
        # to be real speech (coughs, clicks).
        if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
            self.reset()
            return None
        result = b"".join(self.utterance_frames)
        self.reset()
        return result

View File

@@ -0,0 +1,15 @@
import os
from faster_whisper import WhisperModel
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
def load_whisper_model(model_size):
    """Load a Whisper model, preferring GPU (cuda/float16) and falling back
    to CPU (cpu/int8) when GPU initialization raises."""
    print(f"Loading Whisper model ({model_size})...")
    try:
        model = WhisperModel(model_size, device="cuda", compute_type="float16")
    except Exception as e:
        print(f"GPU loading failed: {e}")
        print("Falling back to CPU (int8)")
        model = WhisperModel(model_size, device="cpu", compute_type="int8")
    return model

View File

@@ -0,0 +1,38 @@
import struct
import numpy as np
from sttlib.audio import pcm_bytes_to_float32
def test_known_value():
    """int16 value 16384 maps to 0.5."""
    converted = pcm_bytes_to_float32(struct.pack("<h", 16384))
    assert abs(converted[0] - 0.5) < 1e-5


def test_silence():
    """All-zero PCM stays all-zero."""
    converted = pcm_bytes_to_float32(b"\x00\x00" * 10)
    assert not np.any(converted)


def test_full_scale():
    """Max positive int16 (32767) maps to just under 1.0."""
    converted = pcm_bytes_to_float32(struct.pack("<h", 32767))
    assert abs(converted[0] - (32767 / 32768.0)) < 1e-5


def test_negative():
    """Min int16 (-32768) maps exactly to -1.0."""
    converted = pcm_bytes_to_float32(struct.pack("<h", -32768))
    assert converted[0] == -1.0


def test_round_trip_shape():
    """Sample count and float32 dtype are preserved."""
    converted = pcm_bytes_to_float32(b"\x00\x00" * 100)
    assert converted.shape == (100,)
    assert converted.dtype == np.float32

View File

@@ -0,0 +1,78 @@
from unittest.mock import MagicMock
from sttlib.transcription import transcribe, is_hallucination
# --- is_hallucination tests ---
def test_known_hallucinations():
    """Canonical filler phrases are flagged regardless of case."""
    for phrase in ("Thank you", "thanks for watching", "Subscribe", "the end"):
        assert is_hallucination(phrase)


def test_short_text():
    """Anything under three characters counts as noise."""
    for too_short in ("hi", "", "a"):
        assert is_hallucination(too_short)


def test_normal_text():
    """Ordinary sentences pass through."""
    for sentence in ("Hello how are you", "Please open the terminal"):
        assert not is_hallucination(sentence)


def test_case_insensitivity():
    assert is_hallucination("THANK YOU")
    assert is_hallucination("Thank You For Watching")


def test_substring_match():
    """A pattern embedded in a longer sentence still triggers."""
    assert is_hallucination("I want to subscribe to your channel")


def test_exactly_three_chars():
    """Three characters is the first length that is NOT auto-flagged."""
    assert not is_hallucination("hey")
# --- transcribe tests ---
def _make_segment(text):
    """Build a mock Whisper segment exposing only a .text attribute."""
    segment = MagicMock()
    segment.text = text
    return segment


def test_transcribe_joins_segments():
    """Segment texts are concatenated in order."""
    model = MagicMock()
    model.transcribe.return_value = (
        [_make_segment("Hello "), _make_segment("world")],
        None,
    )
    assert transcribe(model, MagicMock()) == "Hello world"


def test_transcribe_empty():
    """No segments yields the empty string."""
    model = MagicMock()
    model.transcribe.return_value = ([], None)
    assert transcribe(model, MagicMock()) == ""


def test_transcribe_strips_whitespace():
    """Leading/trailing whitespace is trimmed from the joined text."""
    model = MagicMock()
    model.transcribe.return_value = ([_make_segment(" hello ")], None)
    assert transcribe(model, MagicMock()) == "hello"


def test_transcribe_passes_beam_size():
    """The wrapper forwards beam_size=5 to the model."""
    model = MagicMock()
    model.transcribe.return_value = ([], None)
    audio = MagicMock()
    transcribe(model, audio)
    model.transcribe.assert_called_once_with(audio, beam_size=5)

View File

@@ -0,0 +1,151 @@
from unittest.mock import patch, MagicMock
from sttlib.vad import VADProcessor, FRAME_DURATION_MS, MIN_UTTERANCE_FRAMES
def _make_vad_processor(aggressiveness=3, silence_threshold=0.8):
    """Create a VADProcessor whose webrtcvad.Vad is replaced by a MagicMock."""
    with patch("sttlib.vad.webrtcvad.Vad") as vad_cls:
        fake_vad = MagicMock()
        vad_cls.return_value = fake_vad
        processor = VADProcessor(aggressiveness, silence_threshold)
    return processor, fake_vad


def _frame(label="x"):
    """Fake 30ms frame (480 samples * 2 bytes) of distinct, labelled content."""
    return label.encode() * 960


def test_no_speech_returns_none():
    """Pure silence never produces an utterance."""
    processor, fake_vad = _make_vad_processor()
    fake_vad.is_speech.return_value = False
    for _ in range(100):
        assert processor.process_frame(_frame()) is None


def test_speech_then_silence_triggers_utterance():
    """Speech followed by enough silence emits a non-empty utterance."""
    processor, fake_vad = _make_vad_processor(silence_threshold=0.3)
    fake_vad.is_speech.return_value = True
    for _ in range(MIN_UTTERANCE_FRAMES + 5):
        # Still accumulating — nothing emitted mid-speech.
        assert processor.process_frame(_frame("s")) is None
    fake_vad.is_speech.return_value = False
    utterance = None
    for _ in range(20):
        utterance = processor.process_frame(_frame("q"))
        if utterance is not None:
            break
    assert utterance is not None
    assert len(utterance) > 0
assert len(result) > 0
def test_short_utterance_filtered():
    """A too-short burst is discarded rather than emitted.

    With silence_threshold=0.09s the threshold is hit after 3 silent frames,
    keeping the total frame count well below MIN_UTTERANCE_FRAMES (10).
    """
    processor, fake_vad = _make_vad_processor(silence_threshold=0.09)
    # One speech frame is enough to trigger the state machine.
    fake_vad.is_speech.return_value = True
    processor.process_frame(_frame("s"))
    # Immediate silence — the utterance stays tiny.
    fake_vad.is_speech.return_value = False
    outcome = None
    for _ in range(20):
        outcome = processor.process_frame(_frame("q"))
        if outcome is not None:
            break
    assert outcome is None


def test_pre_buffer_included():
    """Frames captured just before speech onset appear in the utterance."""
    processor, fake_vad = _make_vad_processor(silence_threshold=0.3)
    # Fill the pre-roll with distinctive non-speech frames.
    fake_vad.is_speech.return_value = False
    pre_frame = _frame("p")
    for _ in range(10):
        processor.process_frame(pre_frame)
    # Then speak long enough to pass the minimum-length filter.
    fake_vad.is_speech.return_value = True
    speech_frame = _frame("s")
    for _ in range(MIN_UTTERANCE_FRAMES):
        processor.process_frame(speech_frame)
    # Silence closes the utterance.
    fake_vad.is_speech.return_value = False
    utterance = None
    for _ in range(20):
        utterance = processor.process_frame(_frame("q"))
        if utterance is not None:
            break
    assert utterance is not None
    assert pre_frame in utterance
def test_reset_after_utterance():
    """After emitting one utterance the processor is idle and can capture another."""
    processor, fake_vad = _make_vad_processor(silence_threshold=0.3)

    def run_one_utterance():
        fake_vad.is_speech.return_value = True
        for _ in range(MIN_UTTERANCE_FRAMES + 5):
            processor.process_frame(_frame("s"))
        fake_vad.is_speech.return_value = False
        for _ in range(20):
            emitted = processor.process_frame(_frame("q"))
            if emitted is not None:
                return emitted
        return None

    assert run_one_utterance() is not None
    # State returned to idle between utterances.
    assert not processor.triggered
    assert processor.utterance_frames == []
    assert run_one_utterance() is not None


def test_silence_threshold_boundary():
    """With a 0.3s threshold, exactly 10 silent 30ms frames close the utterance."""
    processor, fake_vad = _make_vad_processor(silence_threshold=0.3)
    fake_vad.is_speech.return_value = True
    for _ in range(MIN_UTTERANCE_FRAMES + 5):
        processor.process_frame(_frame("s"))
    frames_needed = 10  # 0.3s / 0.03s per frame
    fake_vad.is_speech.return_value = False
    # One fewer than needed must not emit anything.
    for i in range(frames_needed - 1):
        outcome = processor.process_frame(_frame("q"))
        assert outcome is None, f"Triggered too early at frame {i}"
    # The 10th silent frame crosses the threshold (0.3 >= 0.3).
    assert processor.process_frame(_frame("q")) is not None

View File

@@ -0,0 +1,37 @@
from unittest.mock import patch, MagicMock
from sttlib.whisper_loader import load_whisper_model
@patch("sttlib.whisper_loader.WhisperModel")
def test_gpu_success(whisper_cls):
    """When the CUDA load succeeds, that model is returned directly."""
    gpu_model = MagicMock()
    whisper_cls.return_value = gpu_model
    assert load_whisper_model("base") is gpu_model
    whisper_cls.assert_called_once_with("base", device="cuda", compute_type="float16")


@patch("sttlib.whisper_loader.WhisperModel")
def test_gpu_fails_cpu_fallback(whisper_cls):
    """A CUDA failure falls back to a cpu/int8 load."""
    cpu_model = MagicMock()
    whisper_cls.side_effect = [RuntimeError("no CUDA"), cpu_model]
    assert load_whisper_model("base") is cpu_model
    assert whisper_cls.call_count == 2
    _args, kwargs = whisper_cls.call_args
    assert kwargs == {"device": "cpu", "compute_type": "int8"}


@patch("sttlib.whisper_loader.WhisperModel")
def test_both_fail_propagates(whisper_cls):
    """If both loads fail, the exception escapes to the caller."""
    whisper_cls.side_effect = RuntimeError("no device")
    raised = False
    try:
        load_whisper_model("base")
    except RuntimeError:
        raised = True
    assert raised, "Should have raised"

View File

@@ -1,31 +1,16 @@
import sounddevice as sd
import numpy as np
import pyperclip
import sys
import argparse import argparse
import os
import subprocess import subprocess
import ollama
import json import json
from faster_whisper import WhisperModel import ollama
from sttlib import load_whisper_model, record_until_enter, transcribe
# --- Configuration --- # --- Configuration ---
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
MODEL_SIZE = "medium"
OLLAMA_MODEL = "qwen2.5-coder:7b" OLLAMA_MODEL = "qwen2.5-coder:7b"
CONFIRM_COMMANDS = True # Set to False to run commands instantly CONFIRM_COMMANDS = True # Set to False to run commands instantly
# Load Whisper on GPU
print("Loading Whisper model...")
try:
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
except Exception as e:
print(f"Error loading GPU: {e}, falling back to CPU")
model = WhisperModel(MODEL_SIZE, device="cpu", compute_type="int8")
# --- Terminal Tool --- # --- Terminal Tool ---
def run_terminal_command(command: str): def run_terminal_command(command: str):
""" """
Executes a bash command in the Linux terminal. Executes a bash command in the Linux terminal.
@@ -33,8 +18,7 @@ def run_terminal_command(command: str):
""" """
if CONFIRM_COMMANDS: if CONFIRM_COMMANDS:
print(f"\n{'='*40}") print(f"\n{'='*40}")
print(f"⚠️ AI SUGGESTED: \033[1;32m{command}\033[0m") print(f"\u26a0\ufe0f AI SUGGESTED: \033[1;32m{command}\033[0m")
# Allow user to provide feedback if they say 'n'
choice = input(" Confirm? [Y/n] or provide feedback: ").strip() choice = input(" Confirm? [Y/n] or provide feedback: ").strip()
if choice.lower() == 'n': if choice.lower() == 'n':
@@ -57,22 +41,15 @@ def run_terminal_command(command: str):
return f"Execution Error: {str(e)}" return f"Execution Error: {str(e)}"
def record_audio():
fs, recording = 16000, []
print("\n[READY] Press Enter to START...")
input()
print("[RECORDING] Press Enter to STOP...")
def cb(indata, f, t, s): recording.append(indata.copy())
with sd.InputStream(samplerate=fs, channels=1, callback=cb):
input()
return np.concatenate(recording, axis=0)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--model", default=OLLAMA_MODEL) parser.add_argument("--model", default=OLLAMA_MODEL)
parser.add_argument("--model-size", default="medium",
help="Whisper model size")
args, _ = parser.parse_known_args() args, _ = parser.parse_known_args()
whisper_model = load_whisper_model(args.model_size)
# Initial System Prompt # Initial System Prompt
messages = [{ messages = [{
'role': 'system', 'role': 'system',
@@ -88,9 +65,8 @@ def main():
while True: while True:
try: try:
# 1. Voice Capture # 1. Voice Capture
audio_data = record_audio() audio_data = record_until_enter()
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5) user_text = transcribe(whisper_model, audio_data.flatten())
user_text = "".join([s.text for s in segments]).strip()
if not user_text: if not user_text:
continue continue

View File

@@ -1,91 +1,14 @@
import sounddevice as sd
import numpy as np
import webrtcvad
import subprocess
import sys
import os
import argparse import argparse
import subprocess
import threading import threading
import queue import queue
import collections
import time import time
from faster_whisper import WhisperModel import sounddevice as sd
from sttlib import (
os.environ["CT2_CUDA_ALLOW_FP16"] = "1" load_whisper_model, transcribe, is_hallucination, pcm_bytes_to_float32,
VADProcessor, audio_callback, audio_queue,
# --- Constants --- SAMPLE_RATE, CHANNELS, FRAME_SIZE,
SAMPLE_RATE = 16000 )
CHANNELS = 1
FRAME_DURATION_MS = 30
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
HALLUCINATION_PATTERNS = [
"thank you", "thanks for watching", "subscribe",
"bye", "the end", "thank you for watching",
"please subscribe", "like and subscribe",
]
# --- Thread-safe audio queue ---
audio_queue = queue.Queue()
def audio_callback(indata, frames, time_info, status):
if status:
print(status, file=sys.stderr)
audio_queue.put(bytes(indata))
# --- Whisper model loading (reused pattern from assistant.py) ---
def load_whisper_model(model_size):
print(f"Loading Whisper model ({model_size})...")
try:
return WhisperModel(model_size, device="cuda", compute_type="float16")
except Exception as e:
print(f"GPU loading failed: {e}")
print("Falling back to CPU (int8)")
return WhisperModel(model_size, device="cpu", compute_type="int8")
# --- VAD State Machine ---
class VADProcessor:
def __init__(self, aggressiveness, silence_threshold):
self.vad = webrtcvad.Vad(aggressiveness)
self.silence_threshold = silence_threshold
self.reset()
def reset(self):
self.triggered = False
self.utterance_frames = []
self.silence_duration = 0.0
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
def process_frame(self, frame_bytes):
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
if not self.triggered:
self.pre_buffer.append(frame_bytes)
if is_speech:
self.triggered = True
self.silence_duration = 0.0
self.utterance_frames = list(self.pre_buffer)
self.utterance_frames.append(frame_bytes)
pass # silent until transcription confirms speech
else:
self.utterance_frames.append(frame_bytes)
if is_speech:
self.silence_duration = 0.0
else:
self.silence_duration += FRAME_DURATION_MS / 1000.0
if self.silence_duration >= self.silence_threshold:
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
self.reset()
return None
result = b"".join(self.utterance_frames)
self.reset()
return result
return None
# --- Typer Interface (xdotool) --- # --- Typer Interface (xdotool) ---
@@ -99,6 +22,7 @@ class Typer:
except FileNotFoundError: except FileNotFoundError:
print("ERROR: xdotool not found. Install it:") print("ERROR: xdotool not found. Install it:")
print(" sudo apt-get install xdotool") print(" sudo apt-get install xdotool")
import sys
sys.exit(1) sys.exit(1)
def type_text(self, text, submit_now=False): def type_text(self, text, submit_now=False):
@@ -120,24 +44,6 @@ class Typer:
pass pass
# --- Helpers ---
def pcm_bytes_to_float32(pcm_bytes):
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
return audio_int16.astype(np.float32) / 32768.0
def transcribe(model, audio_float32):
segments, _ = model.transcribe(audio_float32, beam_size=5)
return "".join(segment.text for segment in segments).strip()
def is_hallucination(text):
lowered = text.lower().strip()
if len(lowered) < 3:
return True
return any(p in lowered for p in HALLUCINATION_PATTERNS)
# --- CLI --- # --- CLI ---
def parse_args(): def parse_args():
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(