Compare commits
6 Commits
848681087e
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01d5532823 | ||
|
|
3a8705ece8 | ||
|
|
8b5eb8797f | ||
|
|
d484f9c236 | ||
|
|
2e8c2c11d0 | ||
|
|
104da381fb |
5
python/american-mc/.vscode/settings.json
vendored
Normal file
5
python/american-mc/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
19
python/american-mc/CONTEXTRESUME.md
Normal file
19
python/american-mc/CONTEXTRESUME.md
Normal file
@@ -0,0 +1,19 @@
|
||||
Current Progress Summary:
|
||||
|
||||
The Baseline: We fixed a standard Longstaff-Schwartz (LSM) American Put pricer, correcting the cash-flow propagation logic and regression targets.
|
||||
|
||||
The Evolution: We moved to Bermudan Swaptions using the Hull-White One-Factor Model.
|
||||
|
||||
The "Gold Standard": We implemented a 100% exact simulation from scratch. Instead of Euler discretization, we used Bivariate Normal sampling to jointly simulate the short rate rt and the stochastic integral ∫rsds. This accounts for the stochastic discount factor (the convexity adjustment) without approximation error.
|
||||
|
||||
The Current Frontier: We were debating the Risk-Neutral Measure (Q) vs. the Terminal Forward Measure (QT).
|
||||
|
||||
We concluded that while QT simplifies European options, it makes Bermudan LSM "messy" because it introduces a time-dependent drift shift: DriftQT=DriftQ−aσ2(1−e−a(T−t)).
|
||||
|
||||
Pending Topics:
|
||||
|
||||
Mathematical Proof: The derivation of the "Drift Shift" via Girsanov’s Theorem.
|
||||
|
||||
Exercise Boundary Impact: How the choice of measure (and the resulting drift) visually shifts the optimal exercise boundary in simulation.
|
||||
|
||||
Beyond One-Factor: Potential move toward Two-Factor models or non-flat initial term structures.
|
||||
145
python/american-mc/main.py
Normal file
145
python/american-mc/main.py
Normal file
@@ -0,0 +1,145 @@
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.stats import norm
|
||||
|
||||
# -----------------------------
|
||||
# Black-Scholes European Put
|
||||
# -----------------------------
|
||||
def european_put_black_scholes(S0, K, r, sigma, T):
    """Price a European put under the Black-Scholes model.

    Parameters
    ----------
    S0 : float -- spot price
    K : float -- strike
    r : float -- continuously compounded risk-free rate
    sigma : float -- annualized volatility
    T : float -- time to maturity in years

    Returns
    -------
    float -- present value of the put.
    """
    # Degenerate cases: at expiry (or with zero vol) the put is worth the
    # positive part of the discounted strike minus spot; the closed form
    # below would otherwise divide by zero.
    if T <= 0 or sigma <= 0:
        return float(max(K * np.exp(-r * max(T, 0.0)) - S0, 0.0))
    sqrt_T = np.sqrt(T)
    d1 = (np.log(S0 / K) + (r + 0.5 * sigma**2) * T) / (sigma * sqrt_T)
    d2 = d1 - sigma * sqrt_T
    return K * np.exp(-r * T) * norm.cdf(-d2) - S0 * norm.cdf(-d1)
|
||||
|
||||
# -----------------------------
|
||||
# GBM Simulation
|
||||
# -----------------------------
|
||||
def simulate_gbm(S0, r, sigma, T, M, N, seed=None):
    """Simulate M geometric-Brownian-motion paths on N equal steps over [0, T].

    Parameters
    ----------
    S0 : float -- initial spot
    r, sigma : float -- risk-neutral drift and volatility
    T : float -- horizon in years
    M, N : int -- number of paths / time steps
    seed : int or None -- seeds the legacy global NumPy RNG when given

    Returns
    -------
    ndarray of shape (M, N+1); column 0 equals S0.
    """
    if seed is not None:
        np.random.seed(seed)
    dt = T / N
    # One block draw of shape (N, M) consumes the legacy generator in
    # exactly the same order as N sequential draws of size M, so the paths
    # are bit-identical to the old per-step loop, but the construction is
    # fully vectorized (no Python-level time loop).
    z = np.random.standard_normal((N, M))
    log_increments = (r - 0.5 * sigma**2) * dt + sigma * np.sqrt(dt) * z
    S = np.empty((M, N + 1))
    S[:, 0] = S0
    S[:, 1:] = S0 * np.exp(np.cumsum(log_increments, axis=0)).T
    return S
|
||||
|
||||
# -----------------------------
|
||||
# Longstaff-Schwartz Monte Carlo (Fixed)
|
||||
# -----------------------------
|
||||
def lsm_american_put(S0, K, r, sigma, T, M=100_000, N=50, basis_deg=2, seed=42, return_boundary=False):
    """Price an American put by Longstaff-Schwartz least-squares Monte Carlo.

    Simulates GBM paths, then walks backward from maturity: on in-the-money
    paths, a polynomial regression of the (discounted) future value on the
    spot estimates the continuation value; exercise wherever the immediate
    payoff exceeds it.

    Parameters
    ----------
    S0, K, r, sigma, T : floats -- usual Black-Scholes inputs
    M, N : int -- number of paths / exercise steps
    basis_deg : int -- degree of the polynomial regression basis
    seed : int -- RNG seed passed through to simulate_gbm
    return_boundary : bool -- also return the empirical exercise boundary

    Returns
    -------
    price : float
    boundary : list of (time, max exercised spot) pairs, earliest first
               (only when return_boundary=True; NaN where nothing exercised)
    """
    S = simulate_gbm(S0, r, sigma, T, M, N, seed)
    dt = T / N
    df = np.exp(-r * dt)  # one-step discount factor

    # Immediate exercise value (payoff) at each step
    payoff = np.maximum(K - S, 0)

    # V stores the value of the option at each path.
    # Initialize with the payoff at maturity.
    V = payoff[:, -1]

    boundary = []

    # Backward induction over interior exercise dates t = N-1 .. 1
    for t in range(N - 1, 0, -1):
        # Identify In-The-Money paths (only these may be exercised, and
        # restricting the regression to them is the standard LSM choice)
        itm = payoff[:, t] > 0

        # Default: option value at t is just the discounted value from t+1
        V = V * df

        if np.any(itm):
            X = S[itm, t]
            Y = V[itm]  # These are already discounted from t+1

            # Regression: estimate continuation value E[V_{t+1} | S_t]
            coeffs = np.polyfit(X, Y, basis_deg)
            continuation_value = np.polyval(coeffs, X)

            # Exercise if payoff > estimated continuation value
            exercise = payoff[itm, t] > continuation_value

            # Update V for paths where we exercise: map the boolean mask on
            # the ITM subset back to absolute path indices first.
            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]
            V[exercise_indices] = payoff[exercise_indices, t]

            # Boundary: the highest stock price at which we still exercise
            # (for a put, exercise happens below the boundary)
            if len(exercise_indices) > 0:
                boundary.append((t * dt, np.max(S[exercise_indices, t])))
            else:
                boundary.append((t * dt, np.nan))
        else:
            boundary.append((t * dt, np.nan))

    # Final price: V now sits at t=1, so discount one more step to t=0
    price = np.mean(V * df)

    if return_boundary:
        # boundary was collected backward; reverse to chronological order
        return price, boundary[::-1]
    return price
|
||||
|
||||
# -----------------------------
|
||||
# QuantLib American Put (v1.18)
|
||||
# -----------------------------
|
||||
|
||||
def quantlib_american_put(S0, K, r, sigma, T):
    """Reference price for an American put using QuantLib's FD engine.

    Builds flat risk-free, (zero) dividend, and volatility term structures
    and prices with a finite-difference Black-Scholes engine
    (200 time steps, 400 spatial grid points).  The import is local so the
    rest of the module runs without QuantLib installed.
    """
    import QuantLib as ql

    today = ql.Date().todaysDate()
    ql.Settings.instance().evaluationDate = today
    # Year fraction -> calendar days; adequate for flat curves.
    maturity = today + int(T*365)
    payoff = ql.PlainVanillaPayoff(ql.Option.Put, K)
    # American exercise: any date between today and maturity
    exercise = ql.AmericanExercise(today, maturity)

    spot = ql.QuoteHandle(ql.SimpleQuote(S0))
    dividend_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, 0.0, ql.Actual365Fixed()))
    riskfree_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, r, ql.Actual365Fixed()))
    vol_curve = ql.BlackVolTermStructureHandle(ql.BlackConstantVol(today, ql.NullCalendar(), sigma, ql.Actual365Fixed()))

    process = ql.BlackScholesMertonProcess(spot, dividend_curve, riskfree_curve, vol_curve)
    # FdBlackScholesVanillaEngine(process, tGrid, xGrid)
    engine = ql.FdBlackScholesVanillaEngine(process, 200, 400)
    option = ql.VanillaOption(payoff, exercise)
    option.setPricingEngine(engine)
    return option.NPV()
|
||||
|
||||
# -----------------------------
|
||||
# Main script
|
||||
# -----------------------------
|
||||
if __name__ == "__main__":
    # Parameters: ATM put, 5% rate, 20% vol, 1 year
    S0, K, r, sigma, T = 100, 100, 0.05, 0.2, 1.0
    M, N = 200_000, 50  # Monte Carlo paths / exercise steps

    # European Put (closed form, lower bound for the American price)
    eur_bs = european_put_black_scholes(S0, K, r, sigma, T)
    print(f"European Put (BS): {eur_bs:.4f}")

    # LSM American Put (this module's Monte Carlo estimator)
    lsm_price, boundary = lsm_american_put(S0, K, r, sigma, T, M, N, return_boundary=True)
    print(f"LSM American Put (M={M}): {lsm_price:.4f}")

    # QuantLib American Put (finite-difference benchmark)
    ql_price = quantlib_american_put(S0, K, r, sigma, T)
    print(f"QuantLib American Put: {ql_price:.4f}")

    print(f"Lower bound (immediate exercise): {max(K-S0,0):.4f}")

    # Plot the empirical exercise boundary collected by the LSM pass
    times = [b[0] for b in boundary]
    boundaries = [b[1] for b in boundary]

    plt.figure(figsize=(8,5))
    plt.plot(times, boundaries, color='orange', label="LSM Exercise Boundary")
    plt.axhline(K, color='red', linestyle='--', label="Strike Price")
    plt.xlabel("Time to Maturity")
    plt.ylabel("Stock Price for Exercise")
    plt.title("American Put LSM Exercise Boundary")
    # Reverse x so the plot reads as time-to-maturity decreasing
    plt.gca().invert_xaxis()
    plt.legend()
    plt.grid(True)
    plt.show()
|
||||
133
python/american-mc/main2.py
Normal file
133
python/american-mc/main2.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 1. Analytical Hull-White Bond Pricing
|
||||
# ---------------------------------------------------------
|
||||
def bond_price(r_t, t, T, a, sigma, r0):
    """Zero-coupon bond price P(t, T) given short rate r_t (affine form).

    P(t,T) = A(t,T) * exp(-B(t,T) * r_t).

    NOTE(review): the A(t,T) below is the constant-theta VASICEK formula
    with long-run level r0, i.e.
    ln A = (B-(T-t))(a^2 r0 - sigma^2/2)/a^2 - sigma^2 B^2/(4a).
    simulate_hw_exact_joint in this file uses the Hull-White alpha(t)
    fitted to a flat initial curve instead; the sigma^2 terms of the two
    formulations do NOT coincide -- confirm which model is intended.
    """
    B = (1 - np.exp(-a * (T - t))) / a
    A = np.exp((B - (T - t)) * (a**2 * r0 - sigma**2/2) / a**2 - (sigma**2 * B**2 / (4 * a)))
    return A * np.exp(-B * r_t)
|
||||
|
||||
def get_swap_npv(r_t, t, T_end, strike, a, sigma, r0):
    """Exercise value of a payer swap (receive floating, pay fixed strike).

    NOTE(review): despite the "NPV" name, this returns the POSITIVE PART
    max(NPV, 0) -- i.e. the swaption exercise payoff, not a signed NPV.
    Payment dates are assumed annual at whole-year offsets from t
    (arange(t + 1, T_end + 1)); confirm against the intended schedule.
    Notional is 1.
    """
    payment_dates = np.arange(t + 1, T_end + 1, 1.0)
    if len(payment_dates) == 0:
        # Past the last payment: nothing left to exercise
        return np.zeros_like(r_t)

    # Payer swap NPV = 1 - P(t, T_end) - strike * Sum_i P(t, T_i)
    p_end = bond_price(r_t, t, T_end, a, sigma, r0)
    fixed_leg = sum(bond_price(r_t, t, pd, a, sigma, r0) for pd in payment_dates)

    return np.maximum(1 - p_end - strike * fixed_leg, 0)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 2. Joint Exact Simulator (Short Rate & Stochastic Integral)
|
||||
# ---------------------------------------------------------
|
||||
def simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M):
    """Exactly sample (short rate, integrated short rate) jointly per step.

    For each interval (t_i, t_{i+1}) the pair (r(t_{i+1}), I) with
    I = integral of r over the interval is bivariate normal, so it is
    sampled without discretization error.  Uses the global NumPy RNG.

    Parameters
    ----------
    r0, a, sigma : floats -- flat initial rate, mean reversion, volatility
    exercise_dates : 1-D array of increasing times (years)
    M : int -- number of paths

    Returns
    -------
    r_at_dates : (M, len(exercise_dates)) short rates at each date
    d_matrix : (M, len(exercise_dates)) stochastic discount factors;
               column i is exp(-I) over (t_i, t_{i+1}) with t_0 = 0.
    """
    M = int(M)
    num_dates = len(exercise_dates)
    r_matrix = np.zeros((M, num_dates + 1))
    d_matrix = np.zeros((M, num_dates))  # d[:, i] = exp(-integral from t_i to t_{i+1})

    r_matrix[:, 0] = r0
    t_steps = np.insert(exercise_dates, 0, 0.0)

    for i in range(len(t_steps) - 1):
        t, T = t_steps[i], t_steps[i+1]
        dt = T - t

        # alpha(t): deterministic shift fitting a flat initial curve r0
        alpha_t = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * t))**2
        alpha_T = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * T))**2

        # Conditional mean of r(T) given r(t)
        mean_r = r_matrix[:, i] * np.exp(-a * dt) + alpha_T - alpha_t * np.exp(-a * dt)
        # The bond price pins the integral's distribution: E[exp(-I)] = P(t,T)
        mean_I = -np.log(bond_price(r_matrix[:, i], t, T, a, sigma, r0))

        # Conditional covariance of (r(T), I)
        var_r = (sigma**2 / (2 * a)) * (1 - np.exp(-2 * a * dt))
        B = (1 - np.exp(-a * dt)) / a
        var_I = (sigma**2 / a**2) * (dt - 2*B + (1 - np.exp(-2*a*dt))/(2*a))
        cov_rI = (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * dt))**2

        cov_matrix = [[var_r, cov_rI], [cov_rI, var_I]]

        # Sample joint normal innovations
        Z = np.random.multivariate_normal([0, 0], cov_matrix, M)

        r_matrix[:, i+1] = mean_r + Z[:, 0]
        # Convexity: E[exp(-I)] = exp(-E[I] + Var[I]/2) must equal P(t,T),
        # hence E[I] = -log P + Var[I]/2 and the 0.5*var_I term must be
        # ADDED.  (The original code subtracted it, which biased the mean
        # discount factor upward by a factor exp(var_I).)
        d_matrix[:, i] = np.exp(-(mean_I + Z[:, 1] + 0.5 * var_I))

    return r_matrix[:, 1:], d_matrix
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 3. LSM Pricing Logic
|
||||
# ---------------------------------------------------------
|
||||
def price_bermudan_swaption(r0, a, sigma, strike, exercise_dates, T_end, M):
    """Price a Bermudan payer swaption by LSM on exact Hull-White paths.

    Discounting convention (matching simulate_hw_exact_joint):
    discounts[:, 0] spans (0, t_1) and discounts[:, i] spans (t_i, t_{i+1}).

    Parameters
    ----------
    r0, a, sigma : floats -- model parameters (flat curve r0)
    strike : float -- fixed rate of the underlying payer swap
    exercise_dates : increasing array of exercise times (years)
    T_end : float -- final swap payment date
    M : int -- number of Monte Carlo paths

    Returns
    -------
    float -- Monte Carlo price estimate.
    """
    # 1. Generate exact paths of the short rate and pathwise discounts
    r_at_ex, discounts = simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M)

    # 2. Payoff if held to the final exercise date
    T_last = exercise_dates[-1]
    cash_flows = get_swap_npv(r_at_ex[:, -1], T_last, T_end, strike, a, sigma, r0)

    # 3. Backward induction over the earlier exercise dates
    for i in reversed(range(len(exercise_dates) - 1)):
        t_current = exercise_dates[i]

        # Pull cash flows from t_{i+1} back to t_i.  discounts[:, i+1]
        # spans exactly (t_i, t_{i+1}).  The original code used
        # discounts[:, i], which spans (t_{i-1}, t_i) -- an off-by-one
        # that double-counted the first period and dropped the last.
        cash_flows = cash_flows * discounts[:, i + 1]

        # Current intrinsic (exercise) value
        X = r_at_ex[:, i]
        immediate_payoff = get_swap_npv(X, t_current, T_end, strike, a, sigma, r0)

        # Only regress In-The-Money paths (standard LSM restriction)
        itm = immediate_payoff > 0
        if np.any(itm):
            # Regression basis: [1, r, r^2]
            A = np.column_stack([np.ones_like(X[itm]), X[itm], X[itm]**2])
            coeffs = np.linalg.lstsq(A, cash_flows[itm], rcond=None)[0]
            continuation_value = A @ coeffs

            # Exercise where intrinsic beats estimated continuation
            exercise = immediate_payoff[itm] > continuation_value

            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]

            # Replace downstream cash flows on exercised paths
            cash_flows[exercise_indices] = immediate_payoff[exercise_indices]

    # 4. Discount from the first exercise date to t=0 with the PATHWISE
    #    stochastic factor over (0, t_1).  The analytic P(0, t_1) used
    #    previously ignores the correlation between rates and payoffs on
    #    that interval (and, combined with the old indexing, double-counted
    #    the first period).
    final_price = np.mean(cash_flows * discounts[:, 0])
    return final_price
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Execution
|
||||
# ---------------------------------------------------------
|
||||
if __name__ == "__main__":
    # Params
    r0_val, a_val, sigma_val = 0.05, 0.1, 0.01  # flat rate, mean reversion, vol
    strike_val = 0.05                           # fixed rate (ATM vs the flat curve)
    ex_dates = np.array([1.0, 2.0, 3.0, 4.0])   # annual Bermudan exercise dates
    maturity = 5.0                              # final swap payment date
    num_paths = 100_000

    price = price_bermudan_swaption(r0_val, a_val, sigma_val, strike_val, ex_dates, maturity, num_paths)

    print(f"--- 100% Exact LSM Bermudan Swaption ---")
    print(f"Parameters: a={a_val}, sigma={sigma_val}, strike={strike_val}")
    print(f"Exercise Dates: {ex_dates}")
    print(f"Calculated Price: {price:.6f}")
|
||||
5
python/cvatesting/.vscode/settings.json
vendored
Normal file
5
python/cvatesting/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
909
python/cvatesting/CVA_calculation_I.ipynb
Normal file
909
python/cvatesting/CVA_calculation_I.ipynb
Normal file
File diff suppressed because one or more lines are too long
83
python/cvatesting/driver.py
Normal file
83
python/cvatesting/driver.py
Normal file
@@ -0,0 +1,83 @@
|
||||
|
||||
import numpy as np
# Fixed: the package installs as 'QuantLib'; 'import Quantlib' raises
# ImportError (module names are case-sensitive).
import QuantLib as ql

from hullwhite import (build_calibrated_hw, simulate_hw_state, swap_npv_from_state)
from instruments import (make_vanilla_swap, make_swaption_helpers)

# Global valuation context shared by everything in this script.
today = ql.Date(15, 1, 2025)
ql.Settings.instance().evaluationDate = today
calendar = ql.TARGET()
day_count = ql.Actual365Fixed()
|
||||
|
||||
|
||||
def expected_exposure(
    swap,
    hw,
    curve_handle,
    today,
    time_grid,
    x_paths
):
    """Expected (positive) exposure profile EE(t) of a swap.

    For every time bucket, revalues the swap on each simulated Hull-White
    state, floors the value at zero, and averages across paths.

    Parameters
    ----------
    swap, hw, curve_handle, today : QuantLib objects forwarded verbatim
        to swap_npv_from_state.
    time_grid : 1-D array of horizon times (years).
    x_paths : array of shape (n_paths, len(time_grid)) of simulated states.

    Returns
    -------
    ndarray of length len(time_grid) with the EE at each bucket.
    """
    ee = np.zeros(len(time_grid))
    n_paths = len(x_paths)

    for j, t in enumerate(time_grid):
        # Positive part of the pathwise NPVs at this horizon
        values = [
            max(swap_npv_from_state(swap, hw, curve_handle, t, x_paths[i, j], today), 0.0)
            for i in range(n_paths)
        ]
        ee[j] = np.mean(values)

    return ee
|
||||
|
||||
def main():
    """Calibrate a Hull-White model to swaptions, simulate it, and compute
    the expected-exposure profile of a 10Y receiver swap.

    Returns 0 on success (shell-style exit code).
    """
    # Time grid: 121 points over 10 years (monthly-ish buckets)
    time_grid = np.linspace(0, 10, 121)

    flat_rate = 0.02
    curve = ql.FlatForward(today, flat_rate, day_count)
    curve_handle = ql.YieldTermStructureHandle(curve)

    # Create and calibrate HW
    a_init = 0.05
    sigma_init = 0.01
    hw = ql.HullWhite(curve_handle, a_init, sigma_init)

    index = ql.Euribor6M(curve_handle)
    # Co-terminal (expiry, underlying tenor, vol) quotes for calibration.
    # The original called make_swaption_helpers(curve_handle, index, hw),
    # omitting its required first swaption_data argument (TypeError).
    # TODO(review): replace these placeholder vols with market data and
    # confirm whether the helpers expect lognormal or normal vols.
    swaption_data = [
        (ql.Period(y, ql.Years), ql.Period(10 - y, ql.Years), 0.20)
        for y in (1, 2, 3, 4, 5)
    ]
    helpers = make_swaption_helpers(swaption_data, curve_handle, index, hw)
    optimizer = ql.LevenbergMarquardt()
    end_criteria = ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8)
    hw.calibrate(helpers, optimizer, end_criteria)

    a, sigma = hw.params()
    print(f"Calibrated a = {a:.4f}")
    print(f"Calibrated sigma = {sigma:.4f}")

    # Simulate model
    x_paths = simulate_hw_state(
        hw, curve_handle, time_grid, n_paths=10000
    )

    # Swap: the original called make_receiver_swap, which is not defined
    # anywhere (NameError).  Use the imported make_vanilla_swap factory
    # with pay_fixed=False, which builds a receiver swap.
    swap = make_vanilla_swap(
        today, curve_handle, notional=1e6, fixed_rate=0.025,
        maturity_years=10, pay_fixed=False
    )

    # Exposure
    ee = expected_exposure(
        swap, hw, curve_handle, today, time_grid, x_paths
    )
    # The original computed ee and silently discarded it; report the peak.
    print(f"Peak expected exposure: {np.max(ee):.2f}")

    return 0
|
||||
|
||||
|
||||
if __name__ =='__main__':
|
||||
main()
|
||||
66
python/cvatesting/hullwhite.py
Normal file
66
python/cvatesting/hullwhite.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import QuantLib as ql
|
||||
|
||||
|
||||
def build_calibrated_hw(
    curve_handle,
    swaption_helpers,
    a_init=0.05,
    sigma_init=0.01
):
    """Construct a Hull-White model and calibrate it to swaption helpers.

    Starts from (a_init, sigma_init) and runs a Levenberg-Marquardt fit
    against the supplied helpers; returns the calibrated model.
    """
    model = ql.HullWhite(curve_handle, a_init, sigma_init)

    model.calibrate(
        swaption_helpers,
        ql.LevenbergMarquardt(),
        ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8),
    )

    return model
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def simulate_hw_state(
    hw,
    curve_handle,
    time_grid,
    n_paths,
    seed=42
):
    """Simulate the zero-mean Hull-White state factor x(t) exactly.

    Uses the exact one-step OU transition
    x(t+dt) = x(t) * exp(-a*dt) + N(0, sigma^2 (1 - e^{-2a dt}) / (2a)),
    so there is no discretization bias regardless of step size.

    Parameters
    ----------
    hw : calibrated model exposing params() -> (a, sigma)
    curve_handle : unused here; kept for interface symmetry with callers
    time_grid : 1-D increasing array of times (years), starting point included
    n_paths : int -- number of paths
    seed : int -- seeds the legacy global NumPy RNG

    Returns
    -------
    ndarray of shape (n_paths, len(time_grid)); column 0 is zero.
    """
    a, sigma = hw.params()

    dt = np.diff(time_grid)
    n_steps = len(time_grid)

    np.random.seed(seed)
    # Drawing the whole (n_steps-1, n_paths) block at once consumes the
    # legacy generator in the same order as one size-n_paths draw per step,
    # so the paths are unchanged while the per-step draw overhead is gone.
    z = np.random.normal(size=(n_steps - 1, n_paths))

    # Per-step decay and conditional standard deviation (exact OU moments)
    decay = np.exp(-a * dt)
    vol = sigma * np.sqrt((1 - np.exp(-2 * a * dt)) / (2 * a))

    x = np.zeros((n_paths, n_steps))
    for i in range(1, n_steps):
        x[:, i] = x[:, i-1] * decay[i-1] + vol[i-1] * z[i-1]

    return x
|
||||
|
||||
|
||||
def swap_npv_from_state(
    swap,
    hw,
    curve_handle,
    t,
    x_t,
    today
):
    """Revalue `swap` at horizon t (years) given simulated HW state x_t.

    Side effect: moves the GLOBAL evaluation date to today + t*365 days and
    leaves it there -- callers must restore it themselves.

    NOTE(review): `hw.setState(x_t)` is not part of the standard QuantLib
    Python HullWhite API, and DiscountingSwapEngine's third positional
    argument is a settlement Date, not a model -- passing `hw` there likely
    raises.  The discount curve is also not rebuilt from x_t, so as written
    the NPV may not reflect the simulated state at all.  Confirm against
    the QuantLib version in use.
    """
    ql.Settings.instance().evaluationDate = today + int(t * 365)

    hw.setState(x_t)

    engine = ql.DiscountingSwapEngine(
        curve_handle,
        False,
        hw
    )

    swap.setPricingEngine(engine)
    return swap.NPV()
|
||||
165
python/cvatesting/instruments.py
Normal file
165
python/cvatesting/instruments.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import QuantLib as ql
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Global defaults (override via parameters if needed)
|
||||
# ============================================================
|
||||
|
||||
# Shared market conventions for every schedule/leg built in this module.
CALENDAR = ql.TARGET()
BUSINESS_CONVENTION = ql.ModifiedFollowing
DATE_GEN = ql.DateGeneration.Forward

# EUR-market day counts: 30E/360 on the fixed leg, ACT/360 on the float leg.
FIXED_DAYCOUNT = ql.Thirty360(ql.Thirty360.European)
FLOAT_DAYCOUNT = ql.Actual360()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Yield curve factory
|
||||
# ============================================================
|
||||
|
||||
def make_flat_curve(
    evaluation_date: ql.Date,
    rate: float,
    day_count: ql.DayCounter = ql.Actual365Fixed()
) -> ql.YieldTermStructureHandle:
    """Build a handle to a flat forward curve anchored at evaluation_date.

    Side effect: pins the global QuantLib evaluation date.
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    return ql.YieldTermStructureHandle(
        ql.FlatForward(evaluation_date, rate, day_count)
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Index factory
|
||||
# ============================================================
|
||||
|
||||
def make_euribor_6m(
    curve_handle: ql.YieldTermStructureHandle
) -> ql.IborIndex:
    """Build a Euribor 6M index that forecasts off `curve_handle`."""
    index = ql.Euribor6M(curve_handle)
    return index
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Schedule factory
|
||||
# ============================================================
|
||||
|
||||
def make_schedule(
    start: ql.Date,
    maturity: ql.Date,
    tenor: ql.Period,
    calendar: ql.Calendar = CALENDAR
) -> ql.Schedule:
    """Roll a periodic payment schedule from start to maturity.

    Uses the module-level business-day convention for both date and
    termination-date adjustment, forward date generation, and no
    end-of-month rolling.
    """
    convention = BUSINESS_CONVENTION
    return ql.Schedule(
        start, maturity, tenor, calendar,
        convention, convention,
        DATE_GEN, False,
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Swap factory
|
||||
# ============================================================
|
||||
|
||||
def make_vanilla_swap(
    evaluation_date: ql.Date,
    curve_handle: ql.YieldTermStructureHandle,
    notional: float,
    fixed_rate: float,
    maturity_years: int,
    pay_fixed: bool = False,
    fixed_leg_freq: int = ql.Annual,      # ql.Frequency enum (the old ql.Period annotation was wrong)
    float_leg_freq: int = ql.Semiannual   # ql.Frequency enum
) -> ql.VanillaSwap:
    """
    Vanilla fixed-for-float IRS factory.

    pay_fixed = True  -> payer swap (pay fixed, receive Euribor 6M)
    pay_fixed = False -> receiver swap

    Side effect: sets the global QuantLib evaluation date.
    Spot start is T+2 business days; floating spread is 0.0.
    """

    ql.Settings.instance().evaluationDate = evaluation_date

    index = make_euribor_6m(curve_handle)

    # Spot lag, then full tenor
    start = CALENDAR.advance(evaluation_date, 2, ql.Days)
    maturity = CALENDAR.advance(start, maturity_years, ql.Years)

    # ql.Period(Frequency) converts the frequency enum into a tenor
    fixed_schedule = make_schedule(
        start, maturity, ql.Period(fixed_leg_freq)
    )
    float_schedule = make_schedule(
        start, maturity, ql.Period(float_leg_freq)
    )

    swap_type = (
        ql.VanillaSwap.Payer if pay_fixed else ql.VanillaSwap.Receiver
    )

    swap = ql.VanillaSwap(
        swap_type,
        notional,
        fixed_schedule,
        fixed_rate,
        FIXED_DAYCOUNT,
        float_schedule,
        index,
        0.0,            # floating spread
        FLOAT_DAYCOUNT
    )

    # Price off the same curve used for forecasting
    swap.setPricingEngine(
        ql.DiscountingSwapEngine(curve_handle)
    )

    return swap
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Swaption helper factory (for HW calibration)
|
||||
# ============================================================
|
||||
|
||||
def make_swaption_helpers(
    swaption_data: list,
    curve_handle: ql.YieldTermStructureHandle,
    index: ql.IborIndex,
    model: ql.HullWhite
) -> list:
    """
    Create ATM swaption calibration helpers priced on `model`.

    swaption_data = [(expiry, tenor, vol), ...] where expiry/tenor are
    ql.Period and vol is a volatility quote.

    NOTE(review): SwaptionHelper is given index.tenor() as the fixed-leg
    period and index.dayCounter() for BOTH legs -- confirm this matches the
    intended overload (fixedLegTenor, fixedLegDayCounter,
    floatingLegDayCounter).  Also, driver.py calls this function as
    make_swaption_helpers(curve_handle, index, hw), omitting
    swaption_data -- the call site and this signature disagree.
    """

    helpers = []

    for expiry, tenor, vol in swaption_data:
        helper = ql.SwaptionHelper(
            expiry,
            tenor,
            ql.QuoteHandle(ql.SimpleQuote(vol)),
            index,
            index.tenor(),
            index.dayCounter(),
            index.dayCounter(),
            curve_handle
        )

        # Calibrate against closed-form Jamshidian swaption prices
        helper.setPricingEngine(
            ql.JamshidianSwaptionEngine(model)
        )

        helpers.append(helper)

    return helpers
|
||||
109
python/cvatesting/main.py
Normal file
109
python/cvatesting/main.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import QuantLib as ql
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# --- UNIT 1: Instrument Setup ---
|
||||
def create_30y_swap(today, rate_quote):
    """Build a 30Y EUR payer swap (1mm notional) plus the plumbing the
    simulation loop needs.

    Returns (swap, yield_handle, index, fixing_dates) where yield_handle is
    relinkable so the caller can re-point it at simulated curves, and
    fixing_dates are the Euribor fixing dates for every float-leg period.
    """
    ql.Settings.instance().evaluationDate = today

    # We use a Relinkable handle to swap curves per path/time-step
    # without rebuilding the instrument.
    yield_handle = ql.RelinkableYieldTermStructureHandle()
    yield_handle.linkTo(ql.FlatForward(today, rate_quote, ql.Actual365Fixed()))

    calendar = ql.TARGET()
    settle_date = calendar.advance(today, 2, ql.Days)   # T+2 spot start
    maturity_date = calendar.advance(settle_date, 30, ql.Years)

    index = ql.Euribor6M(yield_handle)

    fixed_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Annual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)
    float_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Semiannual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)

    # Payer: pays fixed at rate_quote (30/360), receives Euribor 6M (ACT/360)
    swap = ql.VanillaSwap(ql.VanillaSwap.Payer, 1e6, fixed_schedule, rate_quote,
                          ql.Thirty360(ql.Thirty360.BondBasis), float_schedule,
                          index, 0.0, ql.Actual360())

    # Pre-calculate all fixing dates needed for the life of the swap
    fixing_dates = [index.fixingDate(d) for d in float_schedule]

    return swap, yield_handle, index, fixing_dates
|
||||
|
||||
# --- UNIT 2: The Simulation Loop ---
|
||||
def run_simulation(n_paths=50):
    """Simulate Hull-White short-rate paths and revalue a 30Y payer swap on
    an annual grid, recording pathwise positive exposure.

    Returns (times, npv_matrix) with npv_matrix of shape (n_paths, len(times))
    holding max(NPV, 0) per path and bucket (zeros at/after year 30).
    Side effect: mutates the global evaluation date during the loop and
    restores it to `today` at the end.
    """
    today = ql.Date(27, 1, 2026)
    swap, yield_handle, index, fixing_dates = create_30y_swap(today, 0.03)

    # HW Parameters: a=mean reversion, sigma=vol
    model = ql.HullWhite(yield_handle, 0.01, 0.01)
    process = ql.HullWhiteProcess(yield_handle, 0.01, 0.01)

    times = np.arange(0, 31, 1.0)  # Annual buckets
    grid = ql.TimeGrid(times)

    rng = ql.GaussianRandomSequenceGenerator(ql.UniformRandomSequenceGenerator(
        len(grid)-1, ql.UniformRandomGenerator()))
    seq = ql.GaussianMultiPathGenerator(process, grid, rng, False)

    npv_matrix = np.zeros((n_paths, len(times)))

    for i in range(n_paths):
        path = seq.next().value()[0]
        # 1. Clear previous path's fixings to avoid data pollution
        ql.IndexManager.instance().clearHistories()

        for j, t in enumerate(times):
            # At/after maturity the exposure is left at zero
            if t >= 30: continue

            eval_date = ql.TARGET().advance(today, int(t), ql.Years)
            ql.Settings.instance().evaluationDate = eval_date

            # 2. Rebuild the discount curve from the simulated short rate rt
            rt = path[j]
            # Use pillars up to 35Y to avoid extrapolation crashes
            tenors = [0, 1, 2, 5, 10, 20, 30, 35]
            dates = [ql.TARGET().advance(eval_date, y, ql.Years) for y in tenors]
            # discountBond(t, T, r) -- t here is the simulation time in years
            discounts = [model.discountBond(t, t + y, rt) for y in tenors]

            sim_curve = ql.DiscountCurve(dates, discounts, ql.Actual365Fixed())
            sim_curve.enableExtrapolation()
            yield_handle.linkTo(sim_curve)

            # 3. MANUAL FIXING INJECTION
            # For every fixing date that has passed or is 'today'.
            # NOTE(review): past fixings are derived from TODAY's simulated
            # curve, not from the curve that prevailed at the fixing date
            # on this path -- confirm this approximation is intended.
            for fd in fixing_dates:
                if fd <= eval_date:
                    # Direct manual calculation to avoid index.fixing() error:
                    # forward rate for the 6M period starting at fd
                    start_date = fd
                    end_date = ql.TARGET().advance(start_date, 6, ql.Months)

                    # Manual Euribor rate formula: (P1/P2 - 1) / dt
                    p1 = sim_curve.discount(start_date)
                    p2 = sim_curve.discount(end_date)
                    dt = ql.Actual360().yearFraction(start_date, end_date)
                    fwd_rate = (p1 / p2 - 1.0) / dt

                    index.addFixing(fd, fwd_rate, True)  # Force overwrite

            # 4. Valuation: positive part only (counterparty exposure)
            swap.setPricingEngine(ql.DiscountingSwapEngine(yield_handle))
            npv_matrix[i, j] = max(swap.NPV(), 0)

    # Restore the global evaluation date for any downstream code
    ql.Settings.instance().evaluationDate = today

    return times, npv_matrix
|
||||
|
||||
# --- UNIT 3: Visualization ---
# Run the Monte Carlo and average the pathwise positive exposures per
# bucket to obtain the Expected Exposure (EE) profile.
times, npv_matrix = run_simulation(n_paths=100)
ee = np.mean(npv_matrix, axis=0)

plt.figure(figsize=(10, 5))
plt.plot(times, ee, lw=2, label="Expected Exposure (EE)")
plt.fill_between(times, ee, alpha=0.3)
plt.title("30Y Swap EE Profile - Hull White")
plt.xlabel("Years"); plt.ylabel("Exposure"); plt.grid(True); plt.legend()
plt.show()
|
||||
42
python/expression-evaluator/CLAUDE.md
Normal file
42
python/expression-evaluator/CLAUDE.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Expression Evaluator
|
||||
|
||||
## Overview
|
||||
Educational project teaching DAGs and state machines through a calculator.
|
||||
Pure Python, no external dependencies.
|
||||
|
||||
## Running
|
||||
```bash
|
||||
python main.py "3 + 4 * 2" # single expression
|
||||
python main.py # REPL mode
|
||||
python main.py --show-tokens --show-ast --trace "expr" # show internals
|
||||
python main.py --dot "3+4*2" | dot -Tpng -o ast.png # AST diagram
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png # FSM diagram
|
||||
```
|
||||
|
||||
## Testing
|
||||
```bash
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
## Architecture
|
||||
- `tokenizer.py` -- Explicit finite state machine (Mealy machine) tokenizer
|
||||
- `parser.py` -- Recursive descent parser building an AST (DAG)
|
||||
- `evaluator.py` -- Post-order tree walker (topological sort evaluation)
|
||||
- `visualize.py` -- Graphviz dot generation for AST and FSM diagrams
|
||||
- `main.py` -- CLI entry point with argparse, REPL mode
|
||||
|
||||
## Key Design Decisions
|
||||
- State machine uses an explicit transition table (dict), not implicit if/else
|
||||
- Unary minus resolved by examining previous token context
|
||||
- Power operator (`^`) is right-associative (grammar uses right-recursion)
|
||||
- AST nodes are dataclasses; evaluation uses structural pattern matching
|
||||
- Graphviz output is raw dot strings (no graphviz Python package needed)
|
||||
|
||||
## Grammar
|
||||
```
|
||||
expression ::= term ((PLUS | MINUS) term)*
|
||||
term ::= unary ((MULTIPLY | DIVIDE) unary)*
|
||||
unary ::= UNARY_MINUS unary | power
|
||||
power ::= atom (POWER power)?
|
||||
atom ::= NUMBER | LPAREN expression RPAREN
|
||||
```
|
||||
87
python/expression-evaluator/README.md
Normal file
87
python/expression-evaluator/README.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# Expression Evaluator -- DAGs & State Machines Tutorial
|
||||
|
||||
A calculator that teaches two fundamental CS patterns by building them from scratch:
|
||||
|
||||
1. **Finite State Machine** -- the tokenizer processes input character-by-character using an explicit transition table
|
||||
2. **Directed Acyclic Graph (DAG)** -- the parser builds an expression tree, evaluated bottom-up in topological order
|
||||
|
||||
## What You'll Learn
|
||||
|
||||
| File | CS Concept | What it does |
|
||||
|------|-----------|-------------|
|
||||
| `tokenizer.py` | **State Machine** (Mealy machine) | Converts `"3 + 4 * 2"` into tokens using a transition table |
|
||||
| `parser.py` | **DAG construction** | Builds an expression tree with operator precedence |
|
||||
| `evaluator.py` | **Topological evaluation** | Walks the tree bottom-up (leaves before parents) |
|
||||
| `visualize.py` | **Visualization** | Generates graphviz diagrams of both the FSM and AST |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Evaluate an expression
|
||||
python main.py "3 + 4 * 2"
|
||||
# => 11
|
||||
|
||||
# Interactive REPL
|
||||
python main.py
|
||||
|
||||
# See how the state machine tokenizes
|
||||
python main.py --show-tokens "(2 + 3) * -4"
|
||||
|
||||
# See the expression tree (DAG)
|
||||
python main.py --show-ast "(2 + 3) * 4"
|
||||
# *
|
||||
# +-- +
|
||||
# | +-- 2
|
||||
# | `-- 3
|
||||
# `-- 4
|
||||
|
||||
# Watch evaluation in topological order
|
||||
python main.py --trace "(2 + 3) * 4"
|
||||
# Step 1: 2 => 2
|
||||
# Step 2: 3 => 3
|
||||
# Step 3: 2 + 3 => 5
|
||||
# Step 4: 4 => 4
|
||||
# Step 5: 5 * 4 => 20
|
||||
|
||||
# Generate graphviz diagrams
|
||||
python main.py --dot "(2 + 3) * 4" | dot -Tpng -o ast.png
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- Arithmetic: `+`, `-`, `*`, `/`, `^` (power)
|
||||
- Parentheses: `(2 + 3) * 4`
|
||||
- Unary minus: `-3`, `-(2 + 1)`, `2 * -3`
|
||||
- Decimals: `3.14`, `.5`
|
||||
- Standard precedence: parens > `^` > `*`/`/` > `+`/`-`
|
||||
- Right-associative power: `2^3^4` = `2^(3^4)`
|
||||
- Correct unary minus: `-3^2` = `-(3^2)` = `-9`
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
## How the State Machine Works
|
||||
|
||||
The tokenizer in `tokenizer.py` uses an **explicit transition table** -- a dictionary mapping `(current_state, character_class)` to `(next_state, action)`. This is the same pattern used in network protocol parsers, regex engines, and compiler lexers.
|
||||
|
||||
The three states are:
|
||||
- `START` -- between tokens, dispatching based on the next character
|
||||
- `INTEGER` -- accumulating digits (e.g., `"12"` so far)
|
||||
- `DECIMAL` -- accumulating digits after a decimal point (e.g., `"12.3"`)
|
||||
|
||||
Use `--dot-fsm` to generate a visual diagram of the state machine.
|
||||
|
||||
## How the DAG Works
|
||||
|
||||
The parser in `parser.py` builds an **expression tree** (AST) where:
|
||||
- **Leaf nodes** are numbers (no dependencies)
|
||||
- **Interior nodes** are operators with edges to their operands
|
||||
- **Edges** represent "depends on" relationships
|
||||
|
||||
Evaluation in `evaluator.py` walks this tree **bottom-up** -- children before parents. This is exactly a **topological sort** of the DAG: you can only compute a node after all its dependencies are resolved.
|
||||
|
||||
Use `--show-ast` to see the tree structure, or `--dot` to generate a graphviz diagram.
|
||||
147
python/expression-evaluator/evaluator.py
Normal file
147
python/expression-evaluator/evaluator.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Part 3: DAG Evaluation -- Tree Walker
|
||||
=======================================
|
||||
Evaluating the AST bottom-up is equivalent to topological-sort
|
||||
evaluation of a DAG. We must evaluate a node's children before
|
||||
the node itself -- just like in any dependency graph.
|
||||
|
||||
For a tree, post-order traversal gives a topological ordering.
|
||||
The recursive evaluate() function naturally does this:
|
||||
1. Recursively evaluate all children (dependencies)
|
||||
2. Combine the results (compute this node's value)
|
||||
3. Return the result (make it available to the parent)
|
||||
|
||||
This is the same pattern as:
|
||||
- make: build dependencies before the target
|
||||
- pip/npm install: install dependencies before the package
|
||||
- Spreadsheet recalculation: compute referenced cells first
|
||||
"""
|
||||
|
||||
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
|
||||
from tokenizer import TokenType
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class EvalError(Exception):
    """Raised when AST evaluation fails (division by zero, unknown node type)."""
    pass
|
||||
|
||||
|
||||
# ---------- Evaluator ----------
|
||||
|
||||
# Human-readable symbol for each operator token; used by the traced
# evaluator to render steps like "3 + 4 => 7".
OP_SYMBOLS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
|
||||
|
||||
|
||||
def evaluate(node):
    """
    Evaluate an AST bottom-up (post-order traversal).

    Each recursive call follows a DAG edge to a child node, so every
    child is fully evaluated before its parent -- topological order.

    Raises:
        EvalError: on division by zero or an unrecognized node/operator.
    """
    if isinstance(node, NumberNode):
        return node.value

    if isinstance(node, UnaryOpNode) and node.op == TokenType.UNARY_MINUS:
        return -evaluate(node.operand)

    if isinstance(node, BinOpNode):
        lhs = evaluate(node.left)
        rhs = evaluate(node.right)
        operator = node.op
        if operator == TokenType.PLUS:
            return lhs + rhs
        if operator == TokenType.MINUS:
            return lhs - rhs
        if operator == TokenType.MULTIPLY:
            return lhs * rhs
        if operator == TokenType.DIVIDE:
            if rhs == 0:
                raise EvalError("division by zero")
            return lhs / rhs
        if operator == TokenType.POWER:
            return lhs ** rhs
        # An unrecognized binary operator falls through to the error below.

    raise EvalError(f"unknown node type: {type(node)}")
|
||||
|
||||
|
||||
def evaluate_traced(node):
|
||||
"""
|
||||
Like evaluate(), but records each step for educational display.
|
||||
Returns (result, list_of_trace_lines).
|
||||
|
||||
The trace shows the topological evaluation order -- how the DAG
|
||||
is evaluated from leaves to root. Each step shows a node being
|
||||
evaluated after all its dependencies are resolved.
|
||||
"""
|
||||
steps = []
|
||||
counter = [0] # mutable counter for step numbering
|
||||
|
||||
def _walk(node, depth):
|
||||
indent = " " * depth
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
|
||||
match node:
|
||||
case NumberNode(value=v):
|
||||
result = v
|
||||
display = _format_number(v)
|
||||
steps.append(f"{indent}Step {step}: {display} => {_format_number(result)}")
|
||||
return result
|
||||
|
||||
case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
|
||||
child_val = _walk(child, depth + 1)
|
||||
result = -child_val
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
steps.append(
|
||||
f"{indent}Step {step}: neg({_format_number(child_val)}) "
|
||||
f"=> {_format_number(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
case BinOpNode(op=op, left=left, right=right):
|
||||
left_val = _walk(left, depth + 1)
|
||||
right_val = _walk(right, depth + 1)
|
||||
sym = OP_SYMBOLS[op]
|
||||
match op:
|
||||
case TokenType.PLUS:
|
||||
result = left_val + right_val
|
||||
case TokenType.MINUS:
|
||||
result = left_val - right_val
|
||||
case TokenType.MULTIPLY:
|
||||
result = left_val * right_val
|
||||
case TokenType.DIVIDE:
|
||||
if right_val == 0:
|
||||
raise EvalError("division by zero")
|
||||
result = left_val / right_val
|
||||
case TokenType.POWER:
|
||||
result = left_val ** right_val
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
steps.append(
|
||||
f"{indent}Step {step}: {_format_number(left_val)} {sym} "
|
||||
f"{_format_number(right_val)} => {_format_number(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
raise EvalError(f"unknown node type: {type(node)}")
|
||||
|
||||
result = _walk(node, 0)
|
||||
return result, steps
|
||||
|
||||
|
||||
def _format_number(v):
|
||||
"""Display a number as integer when possible."""
|
||||
if isinstance(v, float) and v == int(v):
|
||||
return str(int(v))
|
||||
return str(v)
|
||||
163
python/expression-evaluator/main.py
Normal file
163
python/expression-evaluator/main.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Expression Evaluator -- Learn DAGs & State Machines
|
||||
====================================================
|
||||
CLI entry point and interactive REPL.
|
||||
|
||||
Usage:
|
||||
python main.py "3 + 4 * 2" # evaluate
|
||||
python main.py # REPL mode
|
||||
python main.py --show-tokens --show-ast --trace "expr" # show internals
|
||||
python main.py --dot "3 + 4 * 2" | dot -Tpng -o ast.png
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tokenizer import tokenize, TokenError
|
||||
from parser import Parser, ParseError
|
||||
from evaluator import evaluate, evaluate_traced, EvalError
|
||||
from visualize import ast_to_dot, fsm_to_dot, ast_to_text
|
||||
|
||||
|
||||
def process_expression(expr, args):
    """Run one expression through the pipeline: tokenize -> parse -> evaluate.

    Honors the CLI flags on *args* (show_tokens, show_ast, dot, trace).
    A failure at any stage is reported and aborts this expression.
    """
    try:
        tokens = tokenize(expr)
    except TokenError as err:
        _print_error(expr, err)
        return

    if args.show_tokens:
        print("\nTokens:")
        for token in tokens:
            print(f" {token}")

    try:
        ast = Parser(tokens).parse()
    except ParseError as err:
        _print_error(expr, err)
        return

    if args.show_ast:
        print("\nAST (text tree):")
        print(ast_to_text(ast))

    if args.dot:
        # dot output goes to stdout, skip numeric result
        print(ast_to_dot(ast))
        return

    # Evaluation is shared between the traced and plain paths; only the
    # reporting differs.
    try:
        if args.trace:
            result, trace_lines = evaluate_traced(ast)
        else:
            result = evaluate(ast)
    except EvalError as err:
        print(f"Eval error: {err}")
        return

    if args.trace:
        print("\nEvaluation trace (topological order):")
        for line in trace_lines:
            print(line)
        print(f"\nResult: {_format_result(result)}")
    else:
        print(_format_result(result))
|
||||
|
||||
|
||||
def repl(args):
    """Interactive read-eval-print loop; exits on quit/exit/q or EOF."""
    print("Expression Evaluator REPL")
    print("Type an expression, or 'quit' to exit.")
    active = [
        name
        for name, enabled in (
            ("--show-tokens", args.show_tokens),
            ("--show-ast", args.show_ast),
            ("--trace", args.trace),
        )
        if enabled
    ]
    if active:
        print(f"Active flags: {' '.join(active)}")
    print()

    while True:
        try:
            entry = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: leave cleanly on a fresh line.
            print()
            break
        if entry.lower() in ("quit", "exit", "q"):
            break
        if not entry:
            continue
        process_expression(entry, args)
        print()
|
||||
|
||||
|
||||
def _print_error(expr, error):
    """Report *error*; add a caret line under the offending position if known."""
    print(f"Error: {error}")
    position = getattr(error, 'position', None)
    if position is not None:
        print(f" {expr}")
        print(f" {' ' * position}^")
|
||||
|
||||
|
||||
def _format_result(v):
|
||||
"""Format a numeric result: show as int when possible."""
|
||||
if isinstance(v, float) and v == int(v) and abs(v) < 1e15:
|
||||
return str(int(v))
|
||||
return str(v)
|
||||
|
||||
|
||||
def main():
    """Parse command-line arguments and dispatch to the requested mode."""
    arg_parser = argparse.ArgumentParser(
        description="Expression Evaluator -- learn DAGs and state machines",
        epilog=(
            "Examples:\n"
            " python main.py '3 + 4 * 2'\n"
            " python main.py --show-tokens --trace '-(3 + 4) ^ 2'\n"
            " python main.py --dot '(2+3)*4' | dot -Tpng -o ast.png\n"
            " python main.py --dot-fsm | dot -Tpng -o fsm.png"
        ),
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "expression", nargs="?",
        help="Expression to evaluate (omit for REPL mode)",
    )
    # All boolean switches share the same shape; register them in one pass.
    for flag, help_text in (
        ("--show-tokens", "Display tokenizer output"),
        ("--show-ast", "Display AST as indented text tree"),
        ("--trace", "Show step-by-step evaluation trace"),
        ("--dot", "Output AST as graphviz dot (pipe to: dot -Tpng -o ast.png)"),
        ("--dot-fsm", "Output tokenizer FSM as graphviz dot"),
    ):
        arg_parser.add_argument(flag, action="store_true", help=help_text)

    args = arg_parser.parse_args()

    # Special mode: just print the FSM diagram and exit
    if args.dot_fsm:
        print(fsm_to_dot())
        return

    # REPL mode if no expression given
    if args.expression is None:
        repl(args)
    else:
        process_expression(args.expression, args)
|
||||
|
||||
|
||||
# Script entry point: run the CLI only when executed directly.
if __name__ == "__main__":
    main()
|
||||
217
python/expression-evaluator/parser.py
Normal file
217
python/expression-evaluator/parser.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Part 2: DAG Construction -- Recursive Descent Parser
|
||||
=====================================================
|
||||
A parser converts a flat list of tokens into a tree structure (AST).
|
||||
The AST is a DAG (Directed Acyclic Graph) where:
|
||||
|
||||
- Nodes are operations (BinOpNode) or values (NumberNode)
|
||||
- Edges point from parent operations to their operands
|
||||
- The graph is acyclic because an operation's inputs are always
|
||||
"simpler" sub-expressions (no circular dependencies)
|
||||
- It is a tree (a special case of DAG) because no node is shared
|
||||
|
||||
This is the same structure as:
|
||||
- Spreadsheet dependency graphs (cell A1 depends on B1, B2...)
|
||||
- Build systems (Makefile targets depend on other targets)
|
||||
- Task scheduling (some tasks must finish before others start)
|
||||
- Neural network computation graphs (forward pass is a DAG)
|
||||
|
||||
Key DAG concepts demonstrated:
|
||||
- Nodes: operations and values
|
||||
- Directed edges: from operation to its inputs (dependencies)
|
||||
- Acyclic: no circular dependencies
|
||||
- Topological ordering: natural evaluation order (leaves first)
|
||||
|
||||
Grammar (BNF) -- precedence is encoded by nesting depth:
|
||||
expression ::= term ((PLUS | MINUS) term)* # lowest precedence
|
||||
term ::= unary ((MULTIPLY | DIVIDE) unary)*
|
||||
unary ::= UNARY_MINUS unary | power
|
||||
power ::= atom (POWER power)? # right-associative
|
||||
atom ::= NUMBER | LPAREN expression RPAREN # highest precedence
|
||||
|
||||
Call chain: expression -> term -> unary -> power -> atom
|
||||
This means: +/- binds loosest, then *//, then unary -, then ^, then parens.
|
||||
So -3^2 = -(3^2) = -9, matching standard math convention.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tokenizer import Token, TokenType
|
||||
|
||||
|
||||
# ---------- AST node types ----------
|
||||
# These are the nodes of our DAG. Each node is either a leaf (NumberNode)
|
||||
# or an interior node with edges pointing to its children (operands).
|
||||
|
||||
@dataclass
class NumberNode:
    """Leaf node: a numeric literal. In DAG terms, a node with no outgoing edges."""
    # Always a float in practice (the parser constructs float(token.value));
    # __repr__ re-displays whole values without the trailing ".0".
    value: float

    def __repr__(self):
        # float.is_integer() instead of comparing against int(self.value):
        # it is safely False for inf/nan, where int() would raise.
        if isinstance(self.value, float) and self.value.is_integer():
            return f"NumberNode({int(self.value)})"
        return f"NumberNode({self.value})"
|
||||
|
||||
|
||||
@dataclass
class BinOpNode:
    """
    Interior node: a binary operation with two children.

    In DAG terms this node carries two outgoing "depends on" edges, one
    per operand: its value cannot be computed until both ``left`` and
    ``right`` have been computed.
    """
    op: TokenType
    left: 'NumberNode | BinOpNode | UnaryOpNode'
    right: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        return "BinOpNode({}, {}, {})".format(self.op.name, self.left, self.right)
|
||||
|
||||
|
||||
@dataclass
class UnaryOpNode:
    """Interior node: a unary operation (negation) with one child."""
    # Single outgoing DAG edge: this node depends only on ``operand``.
    op: TokenType
    operand: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        return "UnaryOpNode({}, {})".format(self.op.name, self.operand)
|
||||
|
||||
|
||||
# Union type for any AST node (usable in annotations; requires Python 3.10+,
# which the file already needs for match statements elsewhere in the project).
Node = NumberNode | BinOpNode | UnaryOpNode
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class ParseError(Exception):
    """Parser failure; carries the optional character offset of the bad token."""

    def __init__(self, message, position=None):
        self.position = position
        suffix = "" if position is None else f" at position {position}"
        super().__init__(f"Parse error{suffix}: {message}")
|
||||
|
||||
|
||||
# ---------- Recursive descent parser ----------
|
||||
|
||||
class Parser:
    """
    Converts a list of tokens into an AST (expression tree / DAG).

    Each grammar rule becomes a method. The call tree mirrors the shape
    of the AST being built. When a deeper method returns a node, it
    becomes a child of the node built by the caller -- this is how
    the DAG edges form.

    Precedence is encoded by nesting: lower-precedence operators are
    parsed at higher (outer) levels, so they become closer to the root
    of the tree and are evaluated last.
    """

    def __init__(self, tokens):
        # Expects the tokenize() output: a token list terminated by EOF.
        self.tokens = tokens
        self.pos = 0  # index of the next unconsumed token

    def peek(self):
        """Look at the current token without consuming it."""
        return self.tokens[self.pos]

    def consume(self, expected=None):
        """Consume and return the current token, optionally asserting its type.

        Raises ParseError (carrying the token's position) when *expected*
        is given and does not match the current token's type.
        """
        token = self.tokens[self.pos]
        if expected is not None and token.type != expected:
            raise ParseError(
                f"expected {expected.name}, got {token.type.name}",
                token.position,
            )
        self.pos += 1
        return token

    def parse(self):
        """Entry point: parse the full expression and verify we consumed everything."""
        if self.peek().type == TokenType.EOF:
            raise ParseError("empty expression")
        node = self.expression()
        # Anything left over after a full expression (e.g. "3 4") is a
        # syntax error; consuming EOF enforces that.
        self.consume(TokenType.EOF)
        return node

    # --- Grammar rules ---
    # Each method corresponds to one production in the grammar.
    # The nesting encodes operator precedence.

    def expression(self):
        """expression ::= term ((PLUS | MINUS) term)*"""
        node = self.term()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            op_token = self.consume()
            right = self.term()
            # Build a new BinOpNode -- this creates a DAG edge from
            # the new node to both 'node' (left) and 'right'.
            # Folding left-to-right in a loop yields LEFT-associativity.
            node = BinOpNode(op_token.type, node, right)
        return node

    def term(self):
        """term ::= unary ((MULTIPLY | DIVIDE) unary)*"""
        node = self.unary()
        while self.peek().type in (TokenType.MULTIPLY, TokenType.DIVIDE):
            op_token = self.consume()
            right = self.unary()
            node = BinOpNode(op_token.type, node, right)
        return node

    def unary(self):
        """
        unary ::= UNARY_MINUS unary | power

        Unary minus is parsed here, between term and power, so it binds
        looser than ^ but tighter than * and /. This gives the standard
        math behavior: -3^2 = -(3^2) = -9.

        The recursion (unary calls itself) handles double negation: --3 = 3.
        """
        if self.peek().type == TokenType.UNARY_MINUS:
            op_token = self.consume()
            operand = self.unary()
            return UnaryOpNode(op_token.type, operand)
        return self.power()

    def power(self):
        """
        power ::= atom (POWER power)?

        Right-recursive for right-associativity: 2^3^4 = 2^(3^4) = 2^81.
        Compare with term() which uses a while loop for LEFT-associativity.
        """
        node = self.atom()
        if self.peek().type == TokenType.POWER:
            op_token = self.consume()
            right = self.power()  # recurse (not loop) for right-associativity
            node = BinOpNode(op_token.type, node, right)
        return node

    def atom(self):
        """
        atom ::= NUMBER | LPAREN expression RPAREN

        The base case: either a literal number or a parenthesized
        sub-expression. Parentheses work by recursing back to
        expression(), which restarts precedence parsing from the top.
        """
        token = self.peek()

        if token.type == TokenType.NUMBER:
            self.consume()
            return NumberNode(float(token.value))

        if token.type == TokenType.LPAREN:
            self.consume()
            node = self.expression()
            self.consume(TokenType.RPAREN)
            return node

        raise ParseError(
            f"expected number or '(', got {token.type.name}",
            token.position,
        )
|
||||
0
python/expression-evaluator/tests/__init__.py
Normal file
0
python/expression-evaluator/tests/__init__.py
Normal file
120
python/expression-evaluator/tests/test_evaluator.py
Normal file
120
python/expression-evaluator/tests/test_evaluator.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""Tests for evaluator.py: arithmetic, precedence, unary minus, errors, tracing."""
import sys
from pathlib import Path

# Make the package-under-test importable when running pytest from anywhere.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize
from parser import Parser
from evaluator import evaluate, evaluate_traced, EvalError


def eval_expr(expr):
    """Helper: tokenize -> parse -> evaluate in one step."""
    tokens = tokenize(expr)
    ast = Parser(tokens).parse()
    return evaluate(ast)


# ---------- Basic arithmetic ----------

def test_addition():
    assert eval_expr("3 + 4") == 7.0


def test_subtraction():
    assert eval_expr("10 - 3") == 7.0


def test_multiplication():
    assert eval_expr("3 * 4") == 12.0


def test_division():
    assert eval_expr("10 / 4") == 2.5


def test_power():
    assert eval_expr("2 ^ 10") == 1024.0


# ---------- Precedence ----------

def test_standard_precedence():
    assert eval_expr("3 + 4 * 2") == 11.0


def test_parentheses():
    assert eval_expr("(3 + 4) * 2") == 14.0


def test_power_precedence():
    assert eval_expr("2 * 3 ^ 2") == 18.0


def test_right_associative_power():
    # 2^(2^3) = 2^8 = 256
    assert eval_expr("2 ^ 2 ^ 3") == 256.0


# ---------- Unary minus ----------

def test_negation():
    assert eval_expr("-5") == -5.0


def test_double_negation():
    assert eval_expr("--5") == 5.0


def test_negation_with_power():
    # -(3^2) = -9, not (-3)^2 = 9
    assert eval_expr("-3 ^ 2") == -9.0


def test_negation_in_parens():
    assert eval_expr("(-3) ^ 2") == 9.0


# ---------- Decimals ----------

def test_decimal_addition():
    assert eval_expr("0.1 + 0.2") == pytest.approx(0.3)


def test_leading_dot():
    assert eval_expr(".5 + .5") == 1.0


# ---------- Edge cases ----------

def test_nested_parens():
    assert eval_expr("((((3))))") == 3.0


def test_complex_expression():
    assert eval_expr("(2 + 3) * (7 - 2) / 5 ^ 1") == 5.0


def test_long_chain():
    assert eval_expr("1 + 2 + 3 + 4 + 5") == 15.0


def test_mixed_operations():
    assert eval_expr("2 + 3 * 4 - 6 / 2") == 11.0


# ---------- Division by zero ----------

def test_division_by_zero():
    with pytest.raises(EvalError):
        eval_expr("1 / 0")


def test_division_by_zero_in_expression():
    with pytest.raises(EvalError):
        eval_expr("5 + 3 / (2 - 2)")


# ---------- Traced evaluation ----------

def test_traced_returns_correct_result():
    tokens = tokenize("3 + 4 * 2")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 11.0
    assert len(steps) > 0


def test_traced_step_count():
    """A simple binary op has 3 evaluation events: left, right, combine."""
    tokens = tokenize("3 + 4")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 7.0
    # NumberNode(3), NumberNode(4), BinOp(+)
    assert len(steps) == 3
|
||||
136
python/expression-evaluator/tests/test_parser.py
Normal file
136
python/expression-evaluator/tests/test_parser.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Tests for parser.py: AST shapes, precedence, associativity, and errors."""
import sys
from pathlib import Path

# Make the package-under-test importable when running pytest from anywhere.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize, TokenType
from parser import Parser, ParseError, NumberNode, BinOpNode, UnaryOpNode


def parse(expr):
    """Helper: tokenize and parse in one step."""
    return Parser(tokenize(expr)).parse()


# ---------- Basic parsing ----------

def test_parse_number():
    ast = parse("42")
    assert isinstance(ast, NumberNode)
    assert ast.value == 42.0


def test_parse_decimal():
    ast = parse("3.14")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.14


def test_parse_addition():
    ast = parse("3 + 4")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, NumberNode)


# ---------- Precedence ----------

def test_multiply_before_add():
    """3 + 4 * 2 should parse as 3 + (4 * 2)."""
    ast = parse("3 + 4 * 2")
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.MULTIPLY


def test_power_before_multiply():
    """2 * 3 ^ 4 should parse as 2 * (3 ^ 4)."""
    ast = parse("2 * 3 ^ 4")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER


def test_parentheses_override_precedence():
    """(3 + 4) * 2 should parse as (3 + 4) * 2."""
    ast = parse("(3 + 4) * 2")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.PLUS


# ---------- Associativity ----------

def test_left_associative_subtraction():
    """10 - 3 - 2 should parse as (10 - 3) - 2."""
    ast = parse("10 - 3 - 2")
    assert ast.op == TokenType.MINUS
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.MINUS
    assert isinstance(ast.right, NumberNode)


def test_power_right_associative():
    """2 ^ 3 ^ 4 should parse as 2 ^ (3 ^ 4)."""
    ast = parse("2 ^ 3 ^ 4")
    assert ast.op == TokenType.POWER
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER


# ---------- Unary minus ----------

def test_unary_minus():
    ast = parse("-3")
    assert isinstance(ast, UnaryOpNode)
    assert ast.operand.value == 3.0


def test_double_negation():
    ast = parse("--3")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, UnaryOpNode)
    assert ast.operand.operand.value == 3.0


def test_unary_minus_precedence():
    """-3^2 should parse as -(3^2), not (-3)^2."""
    ast = parse("-3 ^ 2")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, BinOpNode)
    assert ast.operand.op == TokenType.POWER


def test_unary_minus_in_expression():
    """2 * -3 should parse as 2 * (-(3))."""
    ast = parse("2 * -3")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, UnaryOpNode)


# ---------- Nested parentheses ----------

def test_nested_parens():
    ast = parse("((3))")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.0


def test_complex_nesting():
    """((2 + 3) * (7 - 2))"""
    ast = parse("((2 + 3) * (7 - 2))")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.MULTIPLY


# ---------- Errors ----------

def test_missing_rparen():
    with pytest.raises(ParseError):
        parse("(3 + 4")


def test_empty_expression():
    with pytest.raises(ParseError):
        parse("")


def test_trailing_operator():
    with pytest.raises(ParseError):
        parse("3 +")


def test_empty_parens():
    with pytest.raises(ParseError):
        parse("()")
|
||||
139
python/expression-evaluator/tests/test_tokenizer.py
Normal file
139
python/expression-evaluator/tests/test_tokenizer.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Tests for tokenizer.py: token kinds, unary-minus detection, positions, errors."""
import sys
from pathlib import Path

# Make the package-under-test importable when running pytest from anywhere.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize, TokenType, Token, TokenError


# ---------- Basic tokens ----------

def test_single_integer():
    tokens = tokenize("42")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "42"


def test_decimal_number():
    tokens = tokenize("3.14")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "3.14"


def test_leading_dot():
    tokens = tokenize(".5")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == ".5"


def test_all_operators():
    """Operators between numbers are all binary."""
    tokens = tokenize("1 + 1 - 1 * 1 / 1 ^ 1")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]


def test_operators_between_numbers():
    tokens = tokenize("1 + 2 - 3 * 4 / 5 ^ 6")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]


def test_parentheses():
    tokens = tokenize("()")
    assert tokens[0].type == TokenType.LPAREN
    assert tokens[1].type == TokenType.RPAREN


# ---------- Unary minus ----------

def test_unary_minus_at_start():
    tokens = tokenize("-3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.NUMBER


def test_unary_minus_after_lparen():
    tokens = tokenize("(-3)")
    assert tokens[1].type == TokenType.UNARY_MINUS


def test_unary_minus_after_operator():
    tokens = tokenize("2 * -3")
    assert tokens[2].type == TokenType.UNARY_MINUS


def test_binary_minus():
    tokens = tokenize("5 - 3")
    assert tokens[1].type == TokenType.MINUS


def test_double_unary_minus():
    tokens = tokenize("--3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.UNARY_MINUS
    assert tokens[2].type == TokenType.NUMBER


# ---------- Whitespace handling ----------

def test_no_spaces():
    tokens = tokenize("3+4")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3


def test_extra_spaces():
    tokens = tokenize("  3  +  4  ")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3


# ---------- Position tracking ----------

def test_positions():
    tokens = tokenize("3 + 4")
    assert tokens[0].position == 0  # '3'
    assert tokens[1].position == 2  # '+'
    assert tokens[2].position == 4  # '4'


# ---------- Errors ----------

def test_invalid_character():
    with pytest.raises(TokenError):
        tokenize("3 & 4")


def test_double_dot():
    with pytest.raises(TokenError):
        tokenize("3.14.15")


# ---------- EOF token ----------

def test_eof_always_present():
    tokens = tokenize("42")
    assert tokens[-1].type == TokenType.EOF


def test_empty_input():
    tokens = tokenize("")
    assert len(tokens) == 1
    assert tokens[0].type == TokenType.EOF


# ---------- Complex expressions ----------

def test_complex_expression():
    tokens = tokenize("(3 + 4.5) * -2 ^ 3")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.PLUS,
        TokenType.NUMBER, TokenType.RPAREN, TokenType.MULTIPLY,
        TokenType.UNARY_MINUS, TokenType.NUMBER, TokenType.POWER,
        TokenType.NUMBER,
    ]


def test_adjacent_parens():
    tokens = tokenize("(3)(4)")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
    ]
|
||||
306
python/expression-evaluator/tokenizer.py
Normal file
306
python/expression-evaluator/tokenizer.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Part 1: State Machine Tokenizer
|
||||
================================
|
||||
A tokenizer (lexer) converts raw text into a stream of tokens.
|
||||
This implementation uses an EXPLICIT finite state machine (FSM):
|
||||
|
||||
- States are named values (an enum), not implicit control flow
|
||||
- A transition table maps (current_state, input_class) -> (next_state, action)
|
||||
- The main loop reads one character at a time and consults the table
|
||||
|
||||
This is the same pattern used in:
|
||||
- Network protocol parsers (HTTP, TCP state machines)
|
||||
- Regular expression engines
|
||||
- Compiler front-ends (lexers for C, Python, etc.)
|
||||
- Game AI (enemy behavior states)
|
||||
|
||||
Key FSM concepts demonstrated:
|
||||
- States: the "memory" of what we're currently building
|
||||
- Transitions: rules for moving between states based on input
|
||||
- Actions: side effects (emit a token, accumulate a character)
|
||||
- Mealy machine: outputs depend on both state AND input
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ---------- Token types ----------
|
||||
|
||||
class TokenType(Enum):
    """Kinds of tokens the tokenizer can emit.

    UNARY_MINUS is never produced directly by the FSM: a post-pass
    (_resolve_unary_minus) re-classifies MINUS tokens based on the
    previously emitted token.
    """
    NUMBER = "NUMBER"
    PLUS = "PLUS"
    MINUS = "MINUS"
    MULTIPLY = "MULTIPLY"
    DIVIDE = "DIVIDE"
    POWER = "POWER"
    LPAREN = "LPAREN"
    RPAREN = "RPAREN"
    UNARY_MINUS = "UNARY_MINUS"
    EOF = "EOF"
|
||||
|
||||
|
||||
@dataclass
class Token:
    """One lexical token: its kind, raw text, and source offset."""
    type: TokenType    # which TokenType this token is
    value: str         # raw text: "42", "+", "(", etc.
    position: int      # character offset in original expression

    def __repr__(self):
        # Compact form shown in test failures and debugging output.
        return f"Token({self.type.name}, {self.value!r}, pos={self.position})"
|
||||
|
||||
|
||||
OPERATOR_MAP = {
|
||||
'+': TokenType.PLUS,
|
||||
'-': TokenType.MINUS,
|
||||
'*': TokenType.MULTIPLY,
|
||||
'/': TokenType.DIVIDE,
|
||||
'^': TokenType.POWER,
|
||||
}
|
||||
|
||||
|
||||
# ---------- FSM state definitions ----------
|
||||
|
||||
class State(Enum):
|
||||
"""
|
||||
The tokenizer's finite set of states.
|
||||
|
||||
START -- idle / between tokens, deciding what comes next
|
||||
INTEGER -- accumulating digits of an integer (e.g. "12" so far)
|
||||
DECIMAL -- accumulating digits after a decimal point (e.g. "12.3" so far)
|
||||
"""
|
||||
START = "START"
|
||||
INTEGER = "INTEGER"
|
||||
DECIMAL = "DECIMAL"
|
||||
|
||||
|
||||
class CharClass(Enum):
|
||||
"""
|
||||
Character classification -- groups raw characters into categories
|
||||
so the transition table stays small and readable.
|
||||
"""
|
||||
DIGIT = "DIGIT"
|
||||
DOT = "DOT"
|
||||
OPERATOR = "OPERATOR"
|
||||
LPAREN = "LPAREN"
|
||||
RPAREN = "RPAREN"
|
||||
SPACE = "SPACE"
|
||||
EOF = "EOF"
|
||||
UNKNOWN = "UNKNOWN"
|
||||
|
||||
|
||||
class Action(Enum):
|
||||
"""
|
||||
What the FSM does on a transition. In a Mealy machine, the output
|
||||
(action) depends on both the current state AND the input.
|
||||
"""
|
||||
ACCUMULATE = "ACCUMULATE"
|
||||
EMIT_NUMBER = "EMIT_NUMBER"
|
||||
EMIT_OPERATOR = "EMIT_OPERATOR"
|
||||
EMIT_LPAREN = "EMIT_LPAREN"
|
||||
EMIT_RPAREN = "EMIT_RPAREN"
|
||||
EMIT_NUMBER_THEN_OP = "EMIT_NUMBER_THEN_OP"
|
||||
EMIT_NUMBER_THEN_LPAREN = "EMIT_NUMBER_THEN_LPAREN"
|
||||
EMIT_NUMBER_THEN_RPAREN = "EMIT_NUMBER_THEN_RPAREN"
|
||||
EMIT_NUMBER_THEN_DONE = "EMIT_NUMBER_THEN_DONE"
|
||||
SKIP = "SKIP"
|
||||
DONE = "DONE"
|
||||
ERROR = "ERROR"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class Transition:
|
||||
next_state: State
|
||||
action: Action
|
||||
|
||||
|
||||
# ---------- Transition table ----------
|
||||
# This is the heart of the state machine. Every (state, char_class) pair
|
||||
# maps to exactly one transition: a next state and an action to perform.
|
||||
# Making this a data structure (not nested if/else) means we can:
|
||||
# 1. Inspect it programmatically (e.g. to generate a diagram)
|
||||
# 2. Verify completeness (every combination is covered)
|
||||
# 3. Understand the FSM at a glance
|
||||
|
||||
TRANSITIONS = {
    # --- START: between tokens, dispatch based on character class ---
    (State.START, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    # A leading '.' starts a decimal directly (".5" is a valid number here).
    (State.START, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.START, CharClass.OPERATOR): Transition(State.START, Action.EMIT_OPERATOR),
    (State.START, CharClass.LPAREN): Transition(State.START, Action.EMIT_LPAREN),
    (State.START, CharClass.RPAREN): Transition(State.START, Action.EMIT_RPAREN),
    (State.START, CharClass.SPACE): Transition(State.START, Action.SKIP),
    (State.START, CharClass.EOF): Transition(State.START, Action.DONE),

    # --- INTEGER: accumulating digits like "123" ---
    (State.INTEGER, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.INTEGER, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    # Any terminator flushes the pending number before emitting its own token.
    (State.INTEGER, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.INTEGER, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.INTEGER, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.INTEGER, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.INTEGER, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),

    # --- DECIMAL: accumulating digits after "." like "123.45" ---
    (State.DECIMAL, CharClass.DIGIT): Transition(State.DECIMAL, Action.ACCUMULATE),
    # A second '.' inside one number ("1.2.3") is a hard error.
    (State.DECIMAL, CharClass.DOT): Transition(State.START, Action.ERROR),
    (State.DECIMAL, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.DECIMAL, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.DECIMAL, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.DECIMAL, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.DECIMAL, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
}
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class TokenError(Exception):
    """Raised for any tokenization failure; carries the character offset."""

    def __init__(self, message, position):
        # position: 0-based character offset into the original expression.
        self.position = position
        super().__init__(f"Token error at position {position}: {message}")
|
||||
|
||||
|
||||
# ---------- Character classification ----------
|
||||
|
||||
def classify(ch):
    """Return the CharClass category for the single character *ch*.

    Unrecognised characters map to CharClass.UNKNOWN; the caller decides
    whether that constitutes an error.
    """
    if ch.isdigit():
        return CharClass.DIGIT
    punctuation = {
        '.': CharClass.DOT,
        '(': CharClass.LPAREN,
        ')': CharClass.RPAREN,
    }
    if ch in punctuation:
        return punctuation[ch]
    if ch in OPERATOR_MAP:
        return CharClass.OPERATOR
    if ch.isspace():
        return CharClass.SPACE
    return CharClass.UNKNOWN
|
||||
|
||||
|
||||
# ---------- Main tokenize function ----------
|
||||
|
||||
def tokenize(expression: str) -> list:
    """
    Process an expression string through the state machine, producing tokens.

    The main loop:
      1. Classify the current character
      2. Look up (state, char_class) in the transition table
      3. Execute the action (accumulate, emit, skip, etc.)
      4. Move to the next state
      5. Advance to the next character

    After all tokens are emitted, a post-processing step resolves
    unary minus: if a MINUS token appears at the start, after an operator,
    or after LPAREN, it is re-classified as UNARY_MINUS.

    Raises:
        TokenError: on an unrecognised character, a missing transition,
            or an explicit ERROR action (e.g. a second '.' in a number).
    """
    state = State.START
    buffer = []        # characters accumulated for the current token
    buffer_start = 0   # position where the current buffer started
    tokens = []
    pos = 0

    # Append a sentinel so EOF is handled uniformly in the loop
    chars = expression + '\0'

    # <= so the loop runs once more with pos == len(expression); that
    # iteration is classified as EOF and flushes any pending number.
    while pos <= len(expression):
        ch = chars[pos]
        char_class = CharClass.EOF if pos == len(expression) else classify(ch)

        if char_class == CharClass.UNKNOWN:
            raise TokenError(f"unexpected character {ch!r}", pos)

        # Look up the transition
        key = (state, char_class)
        transition = TRANSITIONS.get(key)
        if transition is None:
            raise TokenError(f"no transition for state={state.name}, input={char_class.name}", pos)

        action = transition.action
        next_state = transition.next_state

        # --- Execute the action ---

        if action == Action.ACCUMULATE:
            # Remember where the number started so the emitted token
            # carries the position of its first character.
            if not buffer:
                buffer_start = pos
            buffer.append(ch)

        elif action == Action.EMIT_NUMBER:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()

        elif action == Action.EMIT_OPERATOR:
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))

        elif action == Action.EMIT_LPAREN:
            tokens.append(Token(TokenType.LPAREN, ch, pos))

        elif action == Action.EMIT_RPAREN:
            tokens.append(Token(TokenType.RPAREN, ch, pos))

        # The EMIT_NUMBER_THEN_* actions flush the pending number first,
        # then emit the token for the character that terminated it.
        elif action == Action.EMIT_NUMBER_THEN_OP:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_LPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.LPAREN, ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_RPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.RPAREN, ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_DONE:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()

        elif action == Action.SKIP:
            pass

        elif action == Action.DONE:
            pass

        elif action == Action.ERROR:
            raise TokenError(f"unexpected {ch!r} in state {state.name}", pos)

        state = next_state
        pos += 1

    # --- Post-processing: resolve unary minus ---
    # A MINUS is unary if it appears:
    #   - at the very start of the token stream
    #   - immediately after an operator (+, -, *, /, ^) or LPAREN
    # This context-sensitivity cannot be captured by the FSM alone --
    # it requires looking at previously emitted tokens.
    _resolve_unary_minus(tokens)

    tokens.append(Token(TokenType.EOF, '', len(expression)))
    return tokens
|
||||
|
||||
|
||||
def _resolve_unary_minus(tokens):
    """
    Rewrite binary MINUS tokens as UNARY_MINUS where context demands it.

    The FSM only knows what kind of token it is currently building, not
    what it emitted before -- but '-' is unary exactly when the previous
    token is an operator, '(', another unary minus, or absent entirely.
    So this runs as a lightweight post-pass over the emitted stream,
    mutating the list in place.
    """
    preceders = (
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER, TokenType.LPAREN,
        TokenType.UNARY_MINUS,
    )
    prev = None
    for i, tok in enumerate(tokens):
        if tok.type is TokenType.MINUS and (prev is None or prev.type in preceders):
            tokens[i] = Token(TokenType.UNARY_MINUS, tok.value, tok.position)
        # Track the (possibly rewritten) token, matching the original's
        # in-place semantics where an earlier conversion feeds the next check.
        prev = tokens[i]
|
||||
200
python/expression-evaluator/visualize.py
Normal file
200
python/expression-evaluator/visualize.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Part 4: Visualization -- Graphviz Dot Output
|
||||
==============================================
|
||||
Generate graphviz dot-format strings for:
|
||||
1. The tokenizer's finite state machine (FSM)
|
||||
2. Any expression's AST (DAG)
|
||||
3. Text-based tree rendering for the terminal
|
||||
|
||||
No external dependencies -- outputs raw dot strings that can be piped
|
||||
to the 'dot' command: python main.py --dot "3+4*2" | dot -Tpng -o ast.png
|
||||
"""
|
||||
|
||||
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
|
||||
from tokenizer import TRANSITIONS, State, CharClass, Action, TokenType
|
||||
|
||||
|
||||
# ---------- FSM diagram ----------
|
||||
|
||||
# Human-readable labels for character classes
|
||||
_CHAR_LABELS = {
|
||||
CharClass.DIGIT: "digit",
|
||||
CharClass.DOT: "'.'",
|
||||
CharClass.OPERATOR: "op",
|
||||
CharClass.LPAREN: "'('",
|
||||
CharClass.RPAREN: "')'",
|
||||
CharClass.SPACE: "space",
|
||||
CharClass.EOF: "EOF",
|
||||
}
|
||||
|
||||
# Short labels for actions
|
||||
_ACTION_LABELS = {
|
||||
Action.ACCUMULATE: "accum",
|
||||
Action.EMIT_NUMBER: "emit num",
|
||||
Action.EMIT_OPERATOR: "emit op",
|
||||
Action.EMIT_LPAREN: "emit '('",
|
||||
Action.EMIT_RPAREN: "emit ')'",
|
||||
Action.EMIT_NUMBER_THEN_OP: "emit num+op",
|
||||
Action.EMIT_NUMBER_THEN_LPAREN: "emit num+'('",
|
||||
Action.EMIT_NUMBER_THEN_RPAREN: "emit num+')'",
|
||||
Action.EMIT_NUMBER_THEN_DONE: "emit num, done",
|
||||
Action.SKIP: "skip",
|
||||
Action.DONE: "done",
|
||||
Action.ERROR: "ERROR",
|
||||
}
|
||||
|
||||
|
||||
def fsm_to_dot():
    """
    Render the tokenizer's state machine as a graphviz dot string.

    Works directly off the TRANSITIONS dict: because the machine is plain
    data, the diagram is generated from it rather than hand-maintained --
    one of the payoffs of an explicit FSM over if/else control flow.
    """
    # Group every (state, input) transition by its (src, dst) edge so
    # parallel transitions collapse into one multi-line edge label.
    grouped = {}
    for (src_state, cclass), trans in TRANSITIONS.items():
        char_label = _CHAR_LABELS.get(cclass, cclass.name)
        action_label = _ACTION_LABELS.get(trans.action, trans.action.name)
        edge = (src_state.name, trans.next_state.name)
        grouped.setdefault(edge, []).append(f"{char_label} / {action_label}")

    out = [
        'digraph FSM {',
        '  rankdir=LR;',
        '  node [shape=circle, fontname="Helvetica"];',
        '  edge [fontname="Helvetica", fontsize=10];',
        '',
        '  // Start indicator',
        '  __start__ [shape=point, width=0.2];',
        '  __start__ -> START;',
        '',
    ]
    for (src, dst), labels in sorted(grouped.items()):
        combined = "\\n".join(labels)
        out.append(f'  {src} -> {dst} [label="{combined}"];')
    out.append('}')
    return '\n'.join(out)
|
||||
|
||||
|
||||
# ---------- AST diagram ----------
|
||||
|
||||
_OP_LABELS = {
|
||||
TokenType.PLUS: '+',
|
||||
TokenType.MINUS: '-',
|
||||
TokenType.MULTIPLY: '*',
|
||||
TokenType.DIVIDE: '/',
|
||||
TokenType.POWER: '^',
|
||||
TokenType.UNARY_MINUS: 'neg',
|
||||
}
|
||||
|
||||
|
||||
def ast_to_dot(node):
    """
    Generate a graphviz dot diagram of an AST (expression tree / DAG).

    Each node gets a unique ID. Edges go from parent to children,
    showing the directed acyclic structure. Leaves are boxed,
    operators are ellipses.
    """
    lines = [
        'digraph AST {',
        '  node [fontname="Helvetica"];',
        '  edge [fontname="Helvetica"];',
        '',
    ]
    # Single-element list so the nested function can mutate the counter
    # without needing `nonlocal` on a bare integer.
    counter = [0]

    def _visit(node):
        # Allocate a fresh dot node ID for this AST node.
        nid = f"n{counter[0]}"
        counter[0] += 1

        match node:
            case NumberNode(value=v):
                label = _format_number(v)
                lines.append(f'  {nid} [label="{label}", shape=box, style=rounded];')
                return nid

            case UnaryOpNode(op=op, operand=child):
                label = _OP_LABELS.get(op, op.name)
                lines.append(f'  {nid} [label="{label}", shape=ellipse];')
                child_id = _visit(child)
                lines.append(f'  {nid} -> {child_id};')
                return nid

            case BinOpNode(op=op, left=left, right=right):
                # Edge labels distinguish left/right operands, which matters
                # for non-commutative operators like - / ^.
                label = _OP_LABELS.get(op, op.name)
                lines.append(f'  {nid} [label="{label}", shape=ellipse];')
                left_id = _visit(left)
                right_id = _visit(right)
                lines.append(f'  {nid} -> {left_id} [label="L"];')
                lines.append(f'  {nid} -> {right_id} [label="R"];')
                return nid

    _visit(node)
    lines.append('}')
    return '\n'.join(lines)
|
||||
|
||||
|
||||
# ---------- Text-based tree ----------
|
||||
|
||||
def ast_to_text(node, prefix="", connector=""):
|
||||
"""
|
||||
Render the AST as an indented text tree for terminal display.
|
||||
|
||||
Example output for (2 + 3) * 4:
|
||||
*
|
||||
+-- +
|
||||
| +-- 2
|
||||
| +-- 3
|
||||
+-- 4
|
||||
"""
|
||||
match node:
|
||||
case NumberNode(value=v):
|
||||
label = _format_number(v)
|
||||
case UnaryOpNode(op=op):
|
||||
label = _OP_LABELS.get(op, op.name)
|
||||
case BinOpNode(op=op):
|
||||
label = _OP_LABELS.get(op, op.name)
|
||||
|
||||
lines = [f"{prefix}{connector}{label}"]
|
||||
|
||||
children = _get_children(node)
|
||||
for i, child in enumerate(children):
|
||||
is_last_child = (i == len(children) - 1)
|
||||
if connector:
|
||||
# Extend the prefix: if we used "+-- " then next children
|
||||
# see "| " (continuing) or " " (last child)
|
||||
child_prefix = prefix + ("| " if connector == "+-- " else " ")
|
||||
else:
|
||||
child_prefix = prefix
|
||||
child_connector = "+-- " if is_last_child else "+-- "
|
||||
# Use a different lead for non-last: the vertical bar continues
|
||||
child_connector = "`-- " if is_last_child else "+-- "
|
||||
child_lines = ast_to_text(child, child_prefix, child_connector)
|
||||
lines.append(child_lines)
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def _get_children(node):
    """Return a node's child nodes left-to-right (leaves have none)."""
    if isinstance(node, BinOpNode):
        return [node.left, node.right]
    if isinstance(node, UnaryOpNode):
        return [node.operand]
    # NumberNode and anything unrecognised: no children.
    return []
|
||||
|
||||
|
||||
def _format_number(v):
    """Format a numeric value for display, dropping a trailing '.0'.

    Uses float.is_integer() rather than comparing with int(v): the
    original `v == int(v)` raised OverflowError for infinities and
    ValueError for NaN, crashing the renderer on non-finite values.
    """
    if isinstance(v, float) and v.is_integer():
        return str(int(v))
    return str(v)
|
||||
38
python/persian-tutor/CLAUDE.md
Normal file
38
python/persian-tutor/CLAUDE.md
Normal file
@@ -0,0 +1,38 @@
|
||||
# Persian Language Tutor
|
||||
|
||||
## Overview
|
||||
Gradio-based Persian (Farsi) language learning app for English speakers, using GCSE Persian vocabulary (Pearson spec) as seed data.
|
||||
|
||||
## Tech Stack
|
||||
- **Frontend**: Gradio (browser handles RTL natively)
|
||||
- **Spaced repetition**: py-fsrs (same algorithm as Anki)
|
||||
- **AI**: Ollama (fast, local) + Claude CLI (smart, subprocess)
|
||||
- **STT**: faster-whisper via sttlib from tool-speechtotext
|
||||
- **Anki export**: genanki for .apkg generation
|
||||
- **Database**: SQLite (file-based, data/progress.db)
|
||||
- **Environment**: `whisper-ollama` conda env
|
||||
|
||||
## Running
|
||||
```bash
|
||||
mamba run -n whisper-ollama python app.py
|
||||
```
|
||||
|
||||
## Testing
|
||||
```bash
|
||||
mamba run -n whisper-ollama python -m pytest tests/
|
||||
```
|
||||
|
||||
## Key Paths
|
||||
- `data/vocabulary.json` — GCSE vocabulary data
|
||||
- `data/progress.db` — SQLite database (auto-created)
|
||||
- `app.py` — Gradio entry point
|
||||
- `db.py` — Database layer with FSRS integration
|
||||
- `ai.py` — Dual AI backend (Ollama + Claude)
|
||||
- `stt.py` — Persian speech-to-text wrapper
|
||||
- `modules/` — Feature modules (vocab, dashboard, essay, tutor, idioms)
|
||||
|
||||
## Architecture
|
||||
- Single-process Gradio app with shared SQLite connection
|
||||
- FSRS Card objects serialized as JSON in SQLite TEXT columns
|
||||
- Timestamps stored as ISO-8601 strings
|
||||
- sttlib imported via sys.path from tool-speechtotext project
|
||||
57
python/persian-tutor/README.md
Normal file
57
python/persian-tutor/README.md
Normal file
@@ -0,0 +1,57 @@
|
||||
# Persian Language Tutor
|
||||
|
||||
A Gradio-based Persian (Farsi) language learning app for English speakers, built around GCSE Persian vocabulary (Pearson specification).
|
||||
|
||||
## Features
|
||||
|
||||
- **Vocabulary Study** — Search, browse, and study 918 GCSE Persian words across 39 categories
|
||||
- **Flashcards with FSRS** — Spaced repetition scheduling (same algorithm as Anki)
|
||||
- **Idioms & Expressions** — 25 Persian social conventions with cultural context
|
||||
- **AI Tutor** — Conversational Persian lessons by GCSE theme (via Ollama)
|
||||
- **Essay Marking** — Write Persian essays, get AI feedback and grading (via Claude)
|
||||
- **Dashboard** — Track progress, streaks, and mastery
|
||||
- **Anki Export** — Generate .apkg decks for offline study
|
||||
- **Voice Input** — Speak Persian via microphone (Whisper STT) in the Tutor tab
|
||||
|
||||
## Prerequisites
|
||||
|
||||
- `whisper-ollama` conda environment with Python 3.10+
|
||||
- Ollama running locally with `qwen2.5:7b` (or another model)
|
||||
- Claude CLI installed (for essay marking / smart mode)
|
||||
|
||||
## Setup
|
||||
|
||||
```bash
|
||||
/home/ys/miniforge3/envs/whisper-ollama/bin/pip install gradio genanki fsrs
|
||||
```
|
||||
|
||||
## Running the app
|
||||
|
||||
```bash
|
||||
cd /home/ys/family-repo/Code/python/persian-tutor
|
||||
/home/ys/miniforge3/envs/whisper-ollama/bin/python app.py
|
||||
```
|
||||
|
||||
Then open http://localhost:7860 in your browser.
|
||||
|
||||
## Running tests
|
||||
|
||||
```bash
|
||||
cd /home/ys/family-repo/Code/python/persian-tutor
|
||||
/home/ys/miniforge3/envs/whisper-ollama/bin/python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
41 tests covering db, vocab, ai, and anki_export modules.
|
||||
|
||||
## Expanding vocabulary
|
||||
|
||||
The vocabulary can be expanded by editing `data/vocabulary.json` directly or by updating `scripts/build_vocab.py` and re-running it:
|
||||
|
||||
```bash
|
||||
/home/ys/miniforge3/envs/whisper-ollama/bin/python scripts/build_vocab.py
|
||||
```
|
||||
|
||||
## TODO
|
||||
|
||||
- [ ] Voice-based vocabulary testing — answer flashcard prompts by speaking Persian
|
||||
- [ ] Improved UI theme and layout polish
|
||||
56
python/persian-tutor/ai.py
Normal file
56
python/persian-tutor/ai.py
Normal file
@@ -0,0 +1,56 @@
|
||||
"""Dual AI backend: Ollama (fast/local) and Claude CLI (smart)."""
|
||||
|
||||
import subprocess
|
||||
|
||||
import ollama
|
||||
|
||||
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
|
||||
|
||||
_ollama_model = DEFAULT_OLLAMA_MODEL
|
||||
|
||||
|
||||
def set_ollama_model(model: str) -> None:
    """Change the Ollama model used for fast queries."""
    # Module-level default consumed by ask_ollama / chat_ollama when no
    # explicit model is passed.
    global _ollama_model
    _ollama_model = model
|
||||
|
||||
|
||||
def ask_ollama(prompt, system=None, model=None):
    """Single-turn Ollama query with an optional system prompt.

    Falls back to the module-level default model when *model* is not given.
    """
    chosen = model or _ollama_model
    convo = [{"role": "system", "content": system}] if system else []
    convo.append({"role": "user", "content": prompt})
    reply = ollama.chat(model=chosen, messages=convo)
    return reply.message.content
|
||||
|
||||
|
||||
def ask_claude(prompt):
    """Query Claude via the CLI subprocess; returns stripped stdout.

    Raises RuntimeError with the CLI's stderr when the process exits
    non-zero.
    """
    proc = subprocess.run(
        ["claude", "-p", prompt],
        capture_output=True,
        text=True,
    )
    if proc.returncode:
        raise RuntimeError(f"Claude CLI failed (exit {proc.returncode}): {proc.stderr.strip()}")
    return proc.stdout.strip()
|
||||
|
||||
|
||||
def ask(prompt, system=None, quality="fast"):
    """Unified entry point: quality='smart' routes to Claude, anything else to Ollama."""
    return ask_claude(prompt) if quality == "smart" else ask_ollama(prompt, system=system)
|
||||
|
||||
|
||||
def chat_ollama(messages, system=None, model=None):
    """Multi-turn conversation with Ollama.

    Prepends an optional system prompt to *messages* and returns the
    assistant's reply text.
    """
    convo = [{"role": "system", "content": system}] if system else []
    convo += list(messages)
    reply = ollama.chat(model=model or _ollama_model, messages=convo)
    return reply.message.content
|
||||
75
python/persian-tutor/anki_export.py
Normal file
75
python/persian-tutor/anki_export.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""Generate Anki .apkg decks from vocabulary data."""
|
||||
|
||||
import genanki
|
||||
|
||||
# Stable model/deck IDs (generated once, kept constant)
|
||||
_MODEL_ID = 1607392319
|
||||
_DECK_ID = 2059400110
|
||||
|
||||
|
||||
def _make_model():
    """Create an Anki note model with two card templates.

    The model ID is the module-level _MODEL_ID, kept constant (per the
    comment at the top of this module) so repeated exports reuse the
    same model rather than creating duplicates.
    """
    return genanki.Model(
        _MODEL_ID,
        "GCSE Persian",
        fields=[
            {"name": "English"},
            {"name": "Persian"},
            {"name": "Finglish"},
            {"name": "Category"},
        ],
        templates=[
            {
                # Card 1: recall the Persian given the English prompt.
                "name": "English → Persian",
                "qfmt": '<div style="font-size:1.5em">{{English}}</div>'
                        '<br><small>{{Category}}</small>',
                "afmt": '{{FrontSide}}<hr id="answer">'
                        '<div dir="rtl" style="font-size:2em">{{Persian}}</div>'
                        "<br><div>{{Finglish}}</div>",
            },
            {
                # Card 2: recognise the Persian script and recall the English.
                "name": "Persian → English",
                "qfmt": '<div dir="rtl" style="font-size:2em">{{Persian}}</div>'
                        '<br><small>{{Category}}</small>',
                "afmt": '{{FrontSide}}<hr id="answer">'
                        '<div style="font-size:1.5em">{{English}}</div>'
                        "<br><div>{{Finglish}}</div>",
            },
        ],
        css=".card { font-family: arial; text-align: center; }",
    )
|
||||
|
||||
|
||||
def export_deck(vocab, categories=None, output_path="gcse-persian.apkg"):
    """Generate an Anki .apkg deck from vocabulary entries.

    Args:
        vocab: List of vocabulary entries (dicts with english, persian,
            finglish, category).
        categories: Optional list of categories to include. None = all.
        output_path: Where to save the .apkg file.

    Returns:
        Path to the generated .apkg file.
    """
    model = _make_model()
    deck = genanki.Deck(_DECK_ID, "GCSE Persian")
    wanted = set(categories) if categories else None

    for entry in vocab:
        if wanted is not None and entry.get("category") not in wanted:
            continue
        fields = [
            entry.get("english", ""),
            entry.get("persian", ""),
            entry.get("finglish", ""),
            entry.get("category", ""),
        ]
        # Stable GUID (id, falling back to english) lets re-imports update
        # existing notes instead of duplicating them.
        guid = genanki.guid_for(entry.get("id", entry["english"]))
        deck.add_note(genanki.Note(model=model, fields=fields, guid=guid))

    genanki.Package(deck).write_to_file(output_path)
    return output_path
|
||||
525
python/persian-tutor/app.py
Normal file
525
python/persian-tutor/app.py
Normal file
@@ -0,0 +1,525 @@
|
||||
"""Persian Language Tutor — Gradio UI."""
|
||||
|
||||
import json
|
||||
import os
|
||||
import tempfile
|
||||
import time
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import ai
|
||||
import db
|
||||
from modules import vocab, dashboard, essay, tutor, idioms
|
||||
from modules.essay import GCSE_THEMES
|
||||
from modules.tutor import THEME_PROMPTS
|
||||
from anki_export import export_deck
|
||||
|
||||
|
||||
# ---------- Initialise ----------
|
||||
db.init_db()
|
||||
vocabulary = vocab.load_vocab()
|
||||
categories = ["All"] + vocab.get_categories()
|
||||
|
||||
|
||||
# ---------- Helper ----------
|
||||
def _rtl(text, size="2em"):
    """Wrap *text* in a centred right-to-left HTML div at the given font size."""
    style = f"font-size:{size}; text-align:center"
    return f'<div dir="rtl" style="{style}">{text}</div>'
|
||||
|
||||
|
||||
# ================================================================
|
||||
# TAB HANDLERS
|
||||
# ================================================================
|
||||
|
||||
# ---------- Dashboard ----------
|
||||
def refresh_dashboard():
    """Re-query all three dashboard widgets: overview markdown, category and quiz tables."""
    return (
        dashboard.format_overview_markdown(),
        dashboard.get_category_breakdown(),
        dashboard.get_recent_quizzes(),
    )
|
||||
|
||||
|
||||
# ---------- Vocabulary Search ----------
|
||||
def do_search(query, category):
    """Search vocabulary (optionally filtered by category) and render as markdown."""
    matches = vocab.search(query)
    if category and category != "All":
        matches = [m for m in matches if m["category"] == category]
    if not matches:
        return "No results found."

    status_icons = {"new": "⬜", "learning": "🟨", "mastered": "🟩"}
    rendered = []
    for m in matches:
        # Per-word progress icon; unknown statuses fall back to "new".
        icon = status_icons.get(vocab.get_word_status(m["id"]), "⬜")
        rendered.append(
            f'{icon} **{m["english"]}** — '
            f'<span dir="rtl">{m["persian"]}</span>'
            f' ({m.get("finglish", "")})'
        )
    return "\n\n".join(rendered)
|
||||
|
||||
|
||||
def do_random_word(category, transliteration):
    """Show a random word card, honouring the transliteration toggle."""
    picked = vocab.get_random_word(category=category)
    if picked is None or not picked:
        return "No words found."
    return vocab.format_word_card(picked, show_transliteration=transliteration)
|
||||
|
||||
|
||||
# ---------- Flashcards ----------
|
||||
def start_flashcards(category, direction):
    """Begin a 10-card flashcard session and return the tab's initial UI state.

    Returns (card_display, batch, index, score, answer_box, answer_area_update).
    """
    cards = vocab.get_flashcard_batch(count=10, category=category)
    if not cards:
        return "No words available.", [], 0, 0, "", gr.update(visible=False)

    head = cards[0]
    if direction == "English → Persian":
        shown = f'<div style="font-size:2em; text-align:center">{head["english"]}</div>'
    else:
        shown = _rtl(head["persian"])

    # prompt, batch state, current index, score, cleared answer box,
    # answer area made visible
    return shown, cards, 0, 0, "", gr.update(visible=True)
|
||||
|
||||
|
||||
def submit_answer(user_answer, batch, index, score, direction, transliteration):
    """Check a flashcard answer and build feedback markdown.

    Returns (feedback, batch, index, score, cleared_answer_box,
    answer_area_update, cleared_card_display). The index is NOT advanced
    here; rate_and_next does that after the FSRS rating.
    """
    if not batch or index >= len(batch):
        return "Session complete!", batch, index, score, "", gr.update(visible=False), ""

    entry = batch[index]
    dir_key = "en_to_fa" if direction == "English → Persian" else "fa_to_en"
    is_correct, correct_answer, _ = vocab.check_answer(entry["id"], user_answer, direction=dir_key)

    if is_correct:
        score += 1
        result = "✅ **Correct!**"
    else:
        # Plain literal: the original used an f-string with no placeholders (F541).
        result = "❌ **Incorrect.** The answer is: "
        if dir_key == "en_to_fa":
            # Persian answers need an RTL span to render correctly.
            result += f'<span dir="rtl">{correct_answer}</span>'
        else:
            result += correct_answer

    card_info = vocab.format_word_card(entry, show_transliteration=transliteration)
    feedback = f"{result}\n\n{card_info}\n\n---\n*Rate your recall to continue:*"

    return feedback, batch, index, score, "", gr.update(visible=True), ""
|
||||
|
||||
|
||||
def rate_and_next(rating_str, batch, index, score, direction):
    """Record an FSRS recall rating for the current card and advance.

    Returns (card_display_or_summary, batch, index, score, answer_area_update).
    """
    if not batch or index >= len(batch):
        return "Session complete!", batch, index, score, gr.update(visible=False)

    # Imported here rather than at module top -- presumably to defer the
    # fsrs dependency until the flashcard flow is used; TODO confirm.
    import fsrs as fsrs_mod
    rating_map = {
        "Again": fsrs_mod.Rating.Again,
        "Hard": fsrs_mod.Rating.Hard,
        "Good": fsrs_mod.Rating.Good,
        "Easy": fsrs_mod.Rating.Easy,
    }
    # Unrecognised labels default to Good instead of failing the UI callback.
    rating = rating_map.get(rating_str, fsrs_mod.Rating.Good)
    entry = batch[index]
    db.update_word_progress(entry["id"], rating)

    index += 1
    if index >= len(batch):
        # End of session: show score summary and hide the answer area.
        summary = f"## Session Complete!\n\n**Score:** {score}/{len(batch)}\n\n"
        summary += f"**Accuracy:** {score/len(batch)*100:.0f}%"
        return summary, batch, index, score, gr.update(visible=False)

    next_entry = batch[index]
    if direction == "English → Persian":
        prompt = f'<div style="font-size:2em; text-align:center">{next_entry["english"]}</div>'
    else:
        prompt = _rtl(next_entry["persian"])

    return prompt, batch, index, score, gr.update(visible=True)
|
||||
|
||||
|
||||
# ---------- Idioms ----------
|
||||
def show_random_idiom(transliteration):
    """Pick a random idiom; returns (formatted markdown, raw expression state)."""
    picked = idioms.get_random_expression()
    formatted = idioms.format_expression(picked, show_transliteration=transliteration)
    return formatted, picked
|
||||
|
||||
|
||||
def explain_idiom(expr_state):
    """Ask the AI backend to explain the currently shown idiom, if any."""
    return idioms.explain_expression(expr_state) if expr_state else "Pick an idiom first."
|
||||
|
||||
|
||||
def browse_idioms(transliteration):
    """List every idiom as markdown, optionally with transliteration."""
    rendered = []
    for item in idioms.get_all_expressions():
        entry = f'**<span dir="rtl">{item["persian"]}</span>** — {item["english"]}'
        if transliteration != "off":
            entry += f' *({item["finglish"]})*'
        rendered.append(entry)
    return "\n\n".join(rendered)
|
||||
|
||||
|
||||
# ---------- Tutor ----------
|
||||
def start_tutor_lesson(theme):
    """Open a themed lesson; returns (chat_history, messages, system, start_time)."""
    opening, messages, system = tutor.start_lesson(theme)
    history = [{"role": "assistant", "content": opening}]
    return history, messages, system, time.time()
|
||||
|
||||
|
||||
def send_tutor_message(user_msg, chat_history, messages, system, audio_input):
    """Handle one tutor turn: optional STT on *audio_input*, then chat.

    Returns (chat_history, messages, cleared_textbox, cleared_audio).
    """
    # Use STT if audio provided and no text
    if audio_input is not None and (not user_msg or not user_msg.strip()):
        try:
            from stt import transcribe_persian
            user_msg = transcribe_persian(audio_input)
        except Exception:
            # NOTE(review): STT failures are silently swallowed and the turn
            # becomes a no-op; consider surfacing an error to the user.
            user_msg = ""

    # Nothing to send (no text and no usable transcription): leave the
    # conversation untouched and just clear the inputs.
    if not user_msg or not user_msg.strip():
        return chat_history, messages, "", None

    response, messages = tutor.process_response(user_msg, messages, system=system)
    chat_history.append({"role": "user", "content": user_msg})
    chat_history.append({"role": "assistant", "content": response})
    return chat_history, messages, "", None
|
||||
|
||||
|
||||
def save_tutor(theme, messages, start_time):
    """Persist the tutor conversation if it actually went anywhere.

    A session with at most one message (just the opener) is not worth saving.
    """
    if not messages or len(messages) <= 1:
        return "Nothing to save."
    tutor.save_session(theme, messages, start_time)
    return "Session saved!"
|
||||
|
||||
|
||||
# ---------- Essay ----------
|
||||
def submit_essay(text, theme):
    """Send the student's essay off for AI marking."""
    if text and text.strip():
        return essay.mark_essay(text, theme)
    return "Please write an essay first."
|
||||
|
||||
|
||||
def load_essay_history():
    """Fetch recent essays for the history table."""
    return essay.get_essay_history()
|
||||
|
||||
|
||||
# ---------- Settings / Export ----------
|
||||
def do_anki_export(cats_selected):
    """Build an Anki .apkg in the temp dir and return its path for download.

    An empty category selection means "export everything" (export_deck
    receives None in that case).
    """
    vocab_entries = vocab.load_vocab()
    out_path = os.path.join(tempfile.gettempdir(), "gcse-persian.apkg")
    export_deck(vocab_entries, categories=cats_selected or None, output_path=out_path)
    return out_path
|
||||
|
||||
|
||||
def update_ollama_model(model):
    """Switch the Ollama model used for fast AI responses."""
    ai.set_ollama_model(model)
|
||||
|
||||
|
||||
def update_whisper_size(size):
    """Change the Whisper model size used for speech recognition."""
    # Local import keeps STT (and its model weights) lazy until needed.
    from stt import set_whisper_size
    set_whisper_size(size)
|
||||
|
||||
|
||||
def reset_progress():
    """Wipe all learning data (FSRS state, quizzes, essays, tutor logs)."""
    conn = db.get_connection()
    # Fixed table list — not user input — so string formatting is safe here.
    for table in ("word_progress", "quiz_sessions", "essays", "tutor_sessions"):
        conn.execute(f"DELETE FROM {table}")
    conn.commit()
    return "Progress reset."
|
||||
|
||||
|
||||
# ================================================================
|
||||
# GRADIO UI
|
||||
# ================================================================
|
||||
|
||||
with gr.Blocks(title="Persian Language Tutor") as app:
    gr.Markdown("# 🇮🇷 Persian Language Tutor\n*GCSE Persian vocabulary with spaced repetition*")

    # Shared state: current transliteration scheme ("off" / "Finglish" / "Academic"),
    # mirrored from the Settings radio and read by most display handlers.
    transliteration_state = gr.State(value="Finglish")

    with gr.Tabs():
        # ==================== DASHBOARD ====================
        with gr.Tab("📊 Dashboard"):
            overview_md = gr.Markdown("Loading...")
            with gr.Row():
                cat_table = gr.Dataframe(
                    headers=["Category", "Total", "Seen", "Mastered", "Progress"],
                    label="Category Breakdown",
                )
                quiz_table = gr.Dataframe(
                    headers=["Date", "Category", "Score", "Duration"],
                    label="Recent Quizzes",
                )
            refresh_btn = gr.Button("Refresh", variant="secondary")
            refresh_btn.click(
                fn=refresh_dashboard,
                outputs=[overview_md, cat_table, quiz_table],
            )

        # ==================== VOCABULARY ====================
        with gr.Tab("📚 Vocabulary"):
            with gr.Row():
                search_box = gr.Textbox(
                    label="Search (English or Persian)",
                    placeholder="Type to search...",
                )
                vocab_cat = gr.Dropdown(
                    choices=categories, value="All", label="Category"
                )
            search_btn = gr.Button("Search", variant="primary")
            random_btn = gr.Button("Random Word")
            search_results = gr.Markdown("Search for a word above.")

            search_btn.click(
                fn=do_search,
                inputs=[search_box, vocab_cat],
                outputs=[search_results],
            )
            # Pressing Enter in the search box triggers the same search.
            search_box.submit(
                fn=do_search,
                inputs=[search_box, vocab_cat],
                outputs=[search_results],
            )
            random_btn.click(
                fn=do_random_word,
                inputs=[vocab_cat, transliteration_state],
                outputs=[search_results],
            )

        # ==================== FLASHCARDS ====================
        with gr.Tab("🃏 Flashcards"):
            with gr.Row():
                fc_category = gr.Dropdown(
                    choices=categories, value="All", label="Category"
                )
                fc_direction = gr.Radio(
                    ["English → Persian", "Persian → English"],
                    value="English → Persian",
                    label="Direction",
                )
            start_fc_btn = gr.Button("Start Session", variant="primary")

            card_display = gr.Markdown("Press 'Start Session' to begin.")

            # Hidden states: current batch of cards, position in it, running score.
            fc_batch = gr.State([])
            fc_index = gr.State(0)
            fc_score = gr.State(0)

            # Answer area stays hidden until a session starts.
            with gr.Group(visible=False) as answer_area:
                answer_box = gr.Textbox(
                    label="Your answer",
                    placeholder="Type your answer...",
                    rtl=True,
                )
                submit_ans_btn = gr.Button("Submit Answer", variant="primary")
                answer_feedback = gr.Markdown("")

                with gr.Row():
                    btn_again = gr.Button("Again", variant="stop")
                    btn_hard = gr.Button("Hard", variant="secondary")
                    btn_good = gr.Button("Good", variant="primary")
                    btn_easy = gr.Button("Easy", variant="secondary")

            start_fc_btn.click(
                fn=start_flashcards,
                inputs=[fc_category, fc_direction],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area],
            )

            submit_ans_btn.click(
                fn=submit_answer,
                inputs=[answer_box, fc_batch, fc_index, fc_score, fc_direction, transliteration_state],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area, answer_feedback],
            )
            answer_box.submit(
                fn=submit_answer,
                inputs=[answer_box, fc_batch, fc_index, fc_score, fc_direction, transliteration_state],
                outputs=[card_display, fc_batch, fc_index, fc_score, answer_box, answer_area, answer_feedback],
            )

            # All four FSRS rating buttons share one handler; the rating label
            # is injected as a constant via gr.State.
            for btn, label in [(btn_again, "Again"), (btn_hard, "Hard"), (btn_good, "Good"), (btn_easy, "Easy")]:
                btn.click(
                    fn=rate_and_next,
                    inputs=[gr.State(label), fc_batch, fc_index, fc_score, fc_direction],
                    outputs=[card_display, fc_batch, fc_index, fc_score, answer_area],
                )

        # ==================== IDIOMS ====================
        with gr.Tab("💬 Idioms & Expressions"):
            idiom_display = gr.Markdown("Click 'Random Idiom' or browse below.")
            # Holds the currently displayed expression dict for "Explain Usage".
            idiom_state = gr.State(None)

            with gr.Row():
                random_idiom_btn = gr.Button("Random Idiom", variant="primary")
                explain_idiom_btn = gr.Button("Explain Usage")
                browse_idiom_btn = gr.Button("Browse All")

            idiom_explanation = gr.Markdown("")

            random_idiom_btn.click(
                fn=show_random_idiom,
                inputs=[transliteration_state],
                outputs=[idiom_display, idiom_state],
            )
            explain_idiom_btn.click(
                fn=explain_idiom,
                inputs=[idiom_state],
                outputs=[idiom_explanation],
            )
            browse_idiom_btn.click(
                fn=browse_idioms,
                inputs=[transliteration_state],
                outputs=[idiom_display],
            )

        # ==================== TUTOR ====================
        with gr.Tab("🎓 Tutor"):
            tutor_theme = gr.Dropdown(
                choices=list(THEME_PROMPTS.keys()),
                value="Identity and culture",
                label="Theme",
            )
            start_lesson_btn = gr.Button("New Lesson", variant="primary")

            chatbot = gr.Chatbot(label="Conversation")

            # Tutor states: raw LLM message list, system prompt, session start time.
            tutor_messages = gr.State([])
            tutor_system = gr.State("")
            tutor_start_time = gr.State(0)

            with gr.Row():
                tutor_input = gr.Textbox(
                    label="Your message",
                    placeholder="Type in English or Persian...",
                    scale=3,
                )
                tutor_mic = gr.Audio(
                    sources=["microphone"],
                    type="numpy",
                    label="Speak",
                    scale=1,
                )
            send_btn = gr.Button("Send", variant="primary")
            save_btn = gr.Button("Save Session", variant="secondary")
            save_status = gr.Markdown("")

            start_lesson_btn.click(
                fn=start_tutor_lesson,
                inputs=[tutor_theme],
                outputs=[chatbot, tutor_messages, tutor_system, tutor_start_time],
            )

            send_btn.click(
                fn=send_tutor_message,
                inputs=[tutor_input, chatbot, tutor_messages, tutor_system, tutor_mic],
                outputs=[chatbot, tutor_messages, tutor_input, tutor_mic],
            )
            tutor_input.submit(
                fn=send_tutor_message,
                inputs=[tutor_input, chatbot, tutor_messages, tutor_system, tutor_mic],
                outputs=[chatbot, tutor_messages, tutor_input, tutor_mic],
            )

            save_btn.click(
                fn=save_tutor,
                inputs=[tutor_theme, tutor_messages, tutor_start_time],
                outputs=[save_status],
            )

        # ==================== ESSAY ====================
        with gr.Tab("✍️ Essay"):
            essay_theme = gr.Dropdown(
                choices=GCSE_THEMES,
                value="Identity and culture",
                label="Theme",
            )
            essay_input = gr.Textbox(
                label="Write your essay in Persian",
                lines=10,
                rtl=True,
                placeholder="اینجا بنویسید...",
            )
            submit_essay_btn = gr.Button("Submit for Marking", variant="primary")
            essay_feedback = gr.Markdown("Write an essay and submit for AI marking.")

            gr.Markdown("### Essay History")
            essay_history_table = gr.Dataframe(
                headers=["Date", "Theme", "Grade", "Preview"],
                label="Past Essays",
            )
            refresh_essays_btn = gr.Button("Refresh History")

            submit_essay_btn.click(
                fn=submit_essay,
                inputs=[essay_input, essay_theme],
                outputs=[essay_feedback],
            )
            refresh_essays_btn.click(
                fn=load_essay_history,
                outputs=[essay_history_table],
            )

        # ==================== SETTINGS ====================
        with gr.Tab("⚙️ Settings"):
            gr.Markdown("## Settings")

            transliteration_radio = gr.Radio(
                ["off", "Finglish", "Academic"],
                value="Finglish",
                label="Transliteration",
            )

            ollama_model = gr.Textbox(
                label="Ollama Model",
                value="qwen2.5:7b",
                info="Model used for fast AI responses",
            )

            whisper_size = gr.Dropdown(
                choices=["tiny", "base", "small", "medium", "large-v3"],
                value="medium",
                label="Whisper Model Size",
            )

            gr.Markdown("### Anki Export")
            export_cats = gr.Dropdown(
                choices=vocab.get_categories(),
                multiselect=True,
                label="Categories to export (empty = all)",
            )
            export_btn = gr.Button("Export to Anki (.apkg)", variant="primary")
            export_file = gr.File(label="Download")

            export_btn.click(fn=do_anki_export, inputs=[export_cats], outputs=[export_file])

            # Wire model settings straight through to their setters.
            ollama_model.change(fn=update_ollama_model, inputs=[ollama_model])
            whisper_size.change(fn=update_whisper_size, inputs=[whisper_size])

            gr.Markdown("### Reset")
            reset_btn = gr.Button("Reset All Progress", variant="stop")
            reset_status = gr.Markdown("")
            reset_btn.click(fn=reset_progress, outputs=[reset_status])

            # Wire transliteration state: mirror the radio value into the
            # shared gr.State read by the other tabs.
            transliteration_radio.change(
                fn=lambda x: x,
                inputs=[transliteration_radio],
                outputs=[transliteration_state],
            )

    # Load dashboard on app start
    app.load(fn=refresh_dashboard, outputs=[overview_md, cat_table, quiz_table])
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # NOTE(review): Gradio documents themes as gr.Blocks(theme=...), not as a
    # launch() kwarg — confirm the pinned Gradio version accepts this here.
    app.launch(theme=gr.themes.Soft())
|
||||
7346
python/persian-tutor/data/vocabulary.json
Normal file
7346
python/persian-tutor/data/vocabulary.json
Normal file
File diff suppressed because it is too large
Load Diff
241
python/persian-tutor/db.py
Normal file
241
python/persian-tutor/db.py
Normal file
@@ -0,0 +1,241 @@
|
||||
"""SQLite database layer with FSRS spaced repetition integration."""
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import fsrs
|
||||
|
||||
DB_PATH = Path(__file__).parent / "data" / "progress.db"
|
||||
|
||||
_conn = None
|
||||
_scheduler = fsrs.Scheduler()
|
||||
|
||||
|
||||
def get_connection():
    """Return the shared SQLite connection, creating it on first use."""
    global _conn
    if _conn is not None:
        return _conn

    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    conn = sqlite3.connect(str(DB_PATH), check_same_thread=False)
    conn.row_factory = sqlite3.Row
    # WAL mode allows concurrent readers while a write is in flight,
    # which matters because Gradio handlers run on multiple threads.
    conn.execute("PRAGMA journal_mode=WAL")
    _conn = conn
    return _conn
|
||||
|
||||
|
||||
def init_db():
    """Create all tables if they don't exist. Called once at startup.

    Tables:
        word_progress:  one row per vocabulary word; the serialized FSRS card
                        plus denormalized due/stability/difficulty columns so
                        the dashboard can query without deserializing JSON.
        quiz_sessions:  one row per completed flashcard session.
        essays:         submitted essays with their AI grade and feedback.
        tutor_sessions: saved tutor conversations (messages stored as JSON).
    """
    conn = get_connection()
    # executescript runs the whole batch; IF NOT EXISTS keeps it idempotent.
    conn.executescript("""
        CREATE TABLE IF NOT EXISTS word_progress (
            word_id TEXT PRIMARY KEY,
            fsrs_state TEXT,
            due TIMESTAMP,
            stability REAL,
            difficulty REAL,
            reps INTEGER DEFAULT 0,
            lapses INTEGER DEFAULT 0,
            last_review TIMESTAMP
        );

        CREATE TABLE IF NOT EXISTS quiz_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            category TEXT,
            total_questions INTEGER,
            correct INTEGER,
            duration_seconds INTEGER
        );

        CREATE TABLE IF NOT EXISTS essays (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            essay_text TEXT,
            grade TEXT,
            feedback TEXT,
            theme TEXT
        );

        CREATE TABLE IF NOT EXISTS tutor_sessions (
            id INTEGER PRIMARY KEY AUTOINCREMENT,
            timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
            theme TEXT,
            messages TEXT,
            duration_seconds INTEGER
        );
    """)
    conn.commit()
|
||||
|
||||
|
||||
def get_word_progress(word_id):
    """Return learning state for one word, or None if never reviewed."""
    row = get_connection().execute(
        "SELECT * FROM word_progress WHERE word_id = ?", (word_id,)
    ).fetchone()
    if row is None:
        return None
    return dict(row)
|
||||
|
||||
|
||||
def update_word_progress(word_id, rating):
    """Run FSRS algorithm, update due date/stability/difficulty.

    Args:
        word_id: Vocabulary entry ID.
        rating: fsrs.Rating value (Again=1, Hard=2, Good=3, Easy=4).

    Returns:
        The updated fsrs.Card after scheduling.
    """
    conn = get_connection()
    existing = get_word_progress(word_id)

    # Restore the serialized FSRS card, or start fresh for a new word.
    if existing and existing["fsrs_state"]:
        card = fsrs.Card.from_dict(json.loads(existing["fsrs_state"]))
    else:
        card = fsrs.Card()

    card, review_log = _scheduler.review_card(card, rating)

    now = datetime.now(timezone.utc).isoformat()
    # default=str covers non-JSON values (e.g. datetimes) in the card dict.
    card_json = json.dumps(card.to_dict(), default=str)

    # Count a lapse when a previously-reviewed card is failed ("Again").
    # Fix: this column existed but was never incremented before.
    lapses = existing["lapses"] if existing else 0
    if existing and rating == fsrs.Rating.Again:
        lapses += 1

    conn.execute(
        """INSERT OR REPLACE INTO word_progress
        (word_id, fsrs_state, due, stability, difficulty, reps, lapses, last_review)
        VALUES (?, ?, ?, ?, ?, ?, ?, ?)""",
        (
            word_id,
            card_json,
            card.due.isoformat(),
            card.stability,
            card.difficulty,
            (existing["reps"] + 1) if existing else 1,
            lapses,
            now,
        ),
    )
    conn.commit()
    return card
|
||||
|
||||
|
||||
def get_due_words(limit=20):
    """Return word IDs whose review is due now or earlier, soonest first.

    Args:
        limit: Maximum number of IDs to return.
    """
    cutoff = datetime.now(timezone.utc).isoformat()
    rows = get_connection().execute(
        "SELECT word_id FROM word_progress WHERE due <= ? ORDER BY due LIMIT ?",
        (cutoff, limit),
    ).fetchall()
    return [r["word_id"] for r in rows]
|
||||
|
||||
|
||||
def get_word_counts(total_vocab_size=0):
    """Return dict with total/seen/mastered/due counts for the dashboard."""
    conn = get_connection()
    now = datetime.now(timezone.utc).isoformat()

    def scalar(sql, params=()):
        # Helper: run a single-value aggregate query.
        return conn.execute(sql, params).fetchone()[0]

    return {
        "total": total_vocab_size,
        "seen": scalar("SELECT COUNT(*) FROM word_progress"),
        # "Mastered" = FSRS stability above the 10 threshold used app-wide.
        "mastered": scalar("SELECT COUNT(*) FROM word_progress WHERE stability > 10"),
        "due": scalar("SELECT COUNT(*) FROM word_progress WHERE due <= ?", (now,)),
    }
|
||||
|
||||
|
||||
def get_all_word_progress():
    """Return every word's progress, keyed by word_id."""
    rows = get_connection().execute("SELECT * FROM word_progress").fetchall()
    return {r["word_id"]: dict(r) for r in rows}
|
||||
|
||||
|
||||
def record_quiz_session(category, total_questions, correct, duration_seconds):
    """Log one completed flashcard session for the dashboard history."""
    row = (category, total_questions, correct, duration_seconds)
    conn = get_connection()
    conn.execute(
        "INSERT INTO quiz_sessions (category, total_questions, correct, duration_seconds) VALUES (?, ?, ?, ?)",
        row,
    )
    conn.commit()
|
||||
|
||||
|
||||
def save_essay(essay_text, grade, feedback, theme):
    """Persist an essay together with its AI grade and feedback."""
    values = (essay_text, grade, feedback, theme)
    conn = get_connection()
    conn.execute(
        "INSERT INTO essays (essay_text, grade, feedback, theme) VALUES (?, ?, ?, ?)",
        values,
    )
    conn.commit()
|
||||
|
||||
|
||||
def save_tutor_session(theme, messages, duration_seconds):
    """Persist a tutor conversation; messages are stored as a JSON string."""
    conn = get_connection()
    # ensure_ascii=False keeps the Persian text readable in the DB.
    payload = json.dumps(messages, ensure_ascii=False)
    conn.execute(
        "INSERT INTO tutor_sessions (theme, messages, duration_seconds) VALUES (?, ?, ?)",
        (theme, payload, duration_seconds),
    )
    conn.commit()
|
||||
|
||||
|
||||
def get_stats():
    """Aggregate data for the dashboard.

    Returns:
        dict with:
            recent_quizzes: last 10 quiz_sessions rows (newest first).
            total_reviews:  sum of reps across all words.
            total_quizzes:  number of quiz sessions ever recorded.
            streak:         consecutive active days ending today (UTC).
    """
    conn = get_connection()

    recent_quizzes = conn.execute(
        "SELECT * FROM quiz_sessions ORDER BY timestamp DESC LIMIT 10"
    ).fetchall()

    # COALESCE so an empty table yields 0 rather than NULL/None.
    total_reviews = conn.execute(
        "SELECT COALESCE(SUM(reps), 0) FROM word_progress"
    ).fetchone()[0]

    total_quizzes = conn.execute(
        "SELECT COUNT(*) FROM quiz_sessions"
    ).fetchone()[0]

    # Streak: count consecutive days with activity.
    # DATE() extracts YYYY-MM-DD from the stored ISO timestamps; the rows come
    # back most-recent day first.
    days = conn.execute(
        "SELECT DISTINCT DATE(last_review) as d FROM word_progress WHERE last_review IS NOT NULL ORDER BY d DESC"
    ).fetchall()

    streak = 0
    today = datetime.now(timezone.utc).date()
    for i, row in enumerate(days):
        # SQLite returns DATE() as a string; the isinstance guard covers any
        # driver that hands back a date object instead.
        day = datetime.fromisoformat(row["d"]).date() if isinstance(row["d"], str) else row["d"]
        # The i-th most recent active day must be exactly i days ago for the
        # streak to continue. NOTE(review): a streak is 0 unless there was
        # activity today (UTC) — confirm that is the intended semantics.
        expected = today - timedelta(days=i)
        if day == expected:
            streak += 1
        else:
            break

    return {
        "recent_quizzes": [dict(r) for r in recent_quizzes],
        "total_reviews": total_reviews,
        "total_quizzes": total_quizzes,
        "streak": streak,
    }
|
||||
|
||||
|
||||
def get_recent_essays(limit=10):
    """Return up to `limit` most recent essays, newest first."""
    cursor = get_connection().execute(
        "SELECT * FROM essays ORDER BY timestamp DESC LIMIT ?", (limit,)
    )
    return [dict(row) for row in cursor.fetchall()]
|
||||
|
||||
|
||||
def close():
    """Close and forget the shared database connection."""
    global _conn
    if _conn is not None:
        _conn.close()
        _conn = None
|
||||
0
python/persian-tutor/modules/__init__.py
Normal file
0
python/persian-tutor/modules/__init__.py
Normal file
84
python/persian-tutor/modules/dashboard.py
Normal file
84
python/persian-tutor/modules/dashboard.py
Normal file
@@ -0,0 +1,84 @@
|
||||
"""Dashboard: progress stats, charts, and overview."""
|
||||
|
||||
import db
|
||||
from modules.vocab import load_vocab, get_categories
|
||||
|
||||
|
||||
def get_overview():
    """Return overview stats: total words, seen, mastered, due today.

    Merges the word counts with streak/review/quiz totals from db.get_stats().
    """
    totals = db.get_word_counts(total_vocab_size=len(load_vocab()))
    stats = db.get_stats()
    for key in ("streak", "total_reviews", "total_quizzes"):
        totals[key] = stats[key]
    return totals
|
||||
|
||||
|
||||
def get_category_breakdown():
    """Return per-category progress rows for the dashboard table."""
    vocab = load_vocab()
    all_progress = db.get_all_word_progress()

    rows = []
    for cat in get_categories():
        words = [entry for entry in vocab if entry["category"] == cat]
        seen = mastered = 0
        for entry in words:
            progress = all_progress.get(entry["id"])
            if not progress:
                continue
            seen += 1
            # Same mastery threshold as the dashboard overview.
            if progress["stability"] and progress["stability"] > 10:
                mastered += 1
        total = len(words)
        rows.append({
            "Category": cat,
            "Total": total,
            "Seen": seen,
            "Mastered": mastered,
            "Progress": f"{seen}/{total}" if total > 0 else "0/0",
        })
    return rows
|
||||
|
||||
|
||||
def get_recent_quizzes(limit=10):
    """Return recent quiz results shaped for the dashboard dataframe."""
    recent = db.get_stats()["recent_quizzes"][:limit]
    return [
        {
            "Date": q["timestamp"],
            "Category": q["category"] or "All",
            "Score": f"{q['correct']}/{q['total_questions']}",
            "Duration": f"{q['duration_seconds'] or 0}s",
        }
        for q in recent
    ]
|
||||
|
||||
|
||||
def format_overview_markdown():
    """Format the overview stats as markdown with a text progress bar."""
    o = get_overview()
    pct = (o["seen"] / o["total"] * 100) if o["total"] > 0 else 0
    filled = int(pct / 5)  # 20-char bar, one cell per 5%
    progress_bar = "█" * filled + "░" * (20 - filled)
    streak_unit = "day" if o["streak"] == 1 else "days"

    return "\n".join([
        "## Dashboard",
        "",
        f"**Words studied:** {o['seen']} / {o['total']} ({pct:.0f}%)",
        f"`{progress_bar}`",
        "",
        f"**Due today:** {o['due']}",
        f"**Mastered:** {o['mastered']}",
        f"**Daily streak:** {o['streak']} {streak_unit}",
        f"**Total reviews:** {o['total_reviews']}",
        f"**Quiz sessions:** {o['total_quizzes']}",
    ])
|
||||
78
python/persian-tutor/modules/essay.py
Normal file
78
python/persian-tutor/modules/essay.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""Essay writing and AI marking."""
|
||||
|
||||
import db
|
||||
from ai import ask
|
||||
|
||||
MARKING_SYSTEM_PROMPT = """You are an expert Persian (Farsi) language teacher marking a GCSE-level essay.
|
||||
You write in English but can read and correct Persian text.
|
||||
Always provide constructive, encouraging feedback suitable for a language learner."""
|
||||
|
||||
MARKING_PROMPT_TEMPLATE = """Please mark this Persian essay written by a GCSE student.
|
||||
|
||||
Theme: {theme}
|
||||
|
||||
Student's essay:
|
||||
{essay_text}
|
||||
|
||||
Please provide your response in this exact format:
|
||||
|
||||
**Grade:** [Give a grade from 1-9 matching GCSE grading, or a descriptive level like A2/B1]
|
||||
|
||||
**Summary:** [1-2 sentence overview of the essay quality]
|
||||
|
||||
**Corrections:**
|
||||
[List specific errors with corrections. For each error, show the original text and the corrected version in Persian, with an English explanation]
|
||||
|
||||
**Improved version:**
|
||||
[Rewrite the essay in corrected Persian]
|
||||
|
||||
**Tips for improvement:**
|
||||
[3-5 specific, actionable tips for the student]"""
|
||||
|
||||
|
||||
GCSE_THEMES = [
|
||||
"Identity and culture",
|
||||
"Local area and environment",
|
||||
"School and work",
|
||||
"Travel and tourism",
|
||||
"International and global dimension",
|
||||
]
|
||||
|
||||
|
||||
def mark_essay(essay_text, theme="General"):
    """Send essay to AI for marking; persist and return the feedback.

    Args:
        essay_text: The student's essay in Persian.
        theme: GCSE theme the essay was written under.
    """
    if not (essay_text and essay_text.strip()):
        return "Please write an essay first."

    body = essay_text.strip()
    prompt = MARKING_PROMPT_TEMPLATE.format(theme=theme, essay_text=body)
    feedback = ask(prompt, system=MARKING_SYSTEM_PROMPT, quality="smart")

    # Best-effort grade extraction from the "**Grade:**" line of the reply.
    grade = ""
    for line in feedback.split("\n"):
        if line.strip().startswith("**Grade:**"):
            grade = line.replace("**Grade:**", "").strip()
            break

    # Save to database for the history view.
    db.save_essay(body, grade, feedback, theme)

    return feedback
|
||||
|
||||
|
||||
def get_essay_history(limit=10):
    """Return recent essays shaped for the history table.

    Args:
        limit: Maximum number of essays to return.
    """
    rows = []
    for e in db.get_recent_essays(limit):
        text = e["essay_text"] or ""
        # Only add an ellipsis when the preview is actually truncated
        # (previously "..." was appended even to short essays).
        preview = text[:50] + ("..." if len(text) > 50 else "")
        rows.append({
            "Date": e["timestamp"],
            "Theme": e["theme"] or "General",
            "Grade": e["grade"] or "-",
            "Preview": preview,
        })
    return rows
|
||||
200
python/persian-tutor/modules/idioms.py
Normal file
200
python/persian-tutor/modules/idioms.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""Persian idioms, expressions, and social conventions."""
|
||||
|
||||
from ai import ask
|
||||
|
||||
# Built-in collection of common Persian expressions and idioms.
# Each entry has four keys:
#   persian  - the expression in Persian script
#   finglish - Latin-script transliteration
#   english  - literal English meaning
#   context  - when/how the expression is used
EXPRESSIONS = [
    {
        "persian": "سلام علیکم",
        "finglish": "salâm aleykom",
        "english": "Peace be upon you (formal greeting)",
        "context": "Formal greeting, especially with elders",
    },
    {
        "persian": "خسته نباشید",
        "finglish": "khaste nabâshid",
        "english": "May you not be tired",
        "context": "Common greeting to someone who has been working. Used as 'hello' in shops, offices, etc.",
    },
    {
        "persian": "دستت درد نکنه",
        "finglish": "dastet dard nakone",
        "english": "May your hand not hurt",
        "context": "Thank you for your effort (after someone does something for you)",
    },
    {
        "persian": "قابلی نداره",
        "finglish": "ghâbeli nadâre",
        "english": "It's not worthy (of you)",
        "context": "You're welcome / Don't mention it — said when giving a gift or doing a favour",
    },
    {
        "persian": "تعارف نکن",
        "finglish": "ta'ârof nakon",
        "english": "Don't do ta'arof",
        "context": "Stop being politely modest — please accept! Part of Persian ta'arof culture.",
    },
    {
        "persian": "نوش جان",
        "finglish": "nush-e jân",
        "english": "May it nourish your soul",
        "context": "Said to someone eating — like 'bon appétit' or 'enjoy your meal'",
    },
    {
        "persian": "چشمت روز بد نبینه",
        "finglish": "cheshmet ruz-e bad nabine",
        "english": "May your eyes never see a bad day",
        "context": "A warm wish for someone's wellbeing",
    },
    {
        "persian": "قدمت روی چشم",
        "finglish": "ghadamet ru-ye cheshm",
        "english": "Your step is on my eye",
        "context": "Warm welcome — 'you're very welcome here'. Extremely hospitable expression.",
    },
    {
        "persian": "انشاءالله",
        "finglish": "inshâ'allâh",
        "english": "God willing",
        "context": "Used when talking about future plans. Very common in daily speech.",
    },
    {
        "persian": "ماشاءالله",
        "finglish": "mâshâ'allâh",
        "english": "What God has willed",
        "context": "Expression of admiration or praise, also used to ward off the evil eye.",
    },
    {
        "persian": "الهی شکر",
        "finglish": "elâhi shokr",
        "english": "Thank God",
        "context": "Expression of gratitude, similar to 'thankfully'",
    },
    {
        "persian": "به سلامتی",
        "finglish": "be salâmati",
        "english": "To your health / Cheers",
        "context": "A toast or general well-wishing expression",
    },
    {
        "persian": "عید مبارک",
        "finglish": "eyd mobârak",
        "english": "Happy holiday/celebration",
        "context": "Used for any celebration, especially Nowruz",
    },
    {
        "persian": "تسلیت میگم",
        "finglish": "tasliyat migam",
        "english": "I offer my condolences",
        "context": "Expressing sympathy when someone has lost a loved one",
    },
    {
        "persian": "خدا بیامرزه",
        "finglish": "khodâ biâmorzesh",
        "english": "May God forgive them (rest in peace)",
        "context": "Said about someone who has passed away",
    },
    {
        "persian": "زبونت رو گاز بگیر",
        "finglish": "zaboonet ro gâz begir",
        "english": "Bite your tongue",
        "context": "Don't say such things! (similar to English 'touch wood')",
    },
    {
        "persian": "دمت گرم",
        "finglish": "damet garm",
        "english": "May your breath be warm",
        "context": "Well done! / Good for you! (informal, friendly praise)",
    },
    {
        "persian": "چشم",
        "finglish": "cheshm",
        "english": "On my eye (I will do it)",
        "context": "Respectful way of saying 'yes, I'll do it' — shows obedience/respect",
    },
    {
        "persian": "بفرمایید",
        "finglish": "befarmâyid",
        "english": "Please (go ahead / help yourself / come in)",
        "context": "Very versatile polite expression: offering food, inviting someone in, or giving way",
    },
    {
        "persian": "ببخشید",
        "finglish": "bebakhshid",
        "english": "Excuse me / I'm sorry",
        "context": "Used for both apologies and getting someone's attention",
    },
    {
        "persian": "مخلصیم",
        "finglish": "mokhlesim",
        "english": "I'm your humble servant",
        "context": "Polite/humble way of saying goodbye or responding to a compliment (ta'arof)",
    },
    {
        "persian": "سرت سلامت باشه",
        "finglish": "saret salâmat bâshe",
        "english": "May your head be safe",
        "context": "Expression of condolence — 'I'm sorry for your loss'",
    },
    {
        "persian": "روی ما رو زمین ننداز",
        "finglish": "ru-ye mâ ro zamin nandâz",
        "english": "Don't throw our face on the ground",
        "context": "Please don't refuse/embarrass us — said when insisting on a request",
    },
    {
        "persian": "قربونت برم",
        "finglish": "ghorboonet beram",
        "english": "I'd sacrifice myself for you",
        "context": "Term of endearment — very common among family and close friends",
    },
    {
        "persian": "جون دل",
        "finglish": "jun-e del",
        "english": "Life of my heart",
        "context": "Affectionate term used with loved ones",
    },
]
|
||||
|
||||
|
||||
def get_all_expressions():
    """Return the full built-in expression list."""
    return EXPRESSIONS
|
||||
|
||||
|
||||
def get_random_expression():
    """Pick a random expression from the built-in collection.

    Uses the module-level `random` import (previously imported inside the
    function on every call, which is unidiomatic for a stdlib module).
    """
    return random.choice(EXPRESSIONS)
|
||||
|
||||
|
||||
def explain_expression(expression):
    """Use AI to generate a detailed explanation with usage examples.

    Args:
        expression: One EXPRESSIONS-style dict with persian/finglish/
            english/context keys.

    Returns:
        The AI's explanation text (fast/cheap model).
    """
    prompt = f"""Explain this Persian expression for an English-speaking student:

Persian: {expression['persian']}
Transliteration: {expression['finglish']}
Literal meaning: {expression['english']}
Context: {expression['context']}

Please provide:
1. A fuller explanation of when and how this is used
2. The cultural context (ta'arof, hospitality, etc.)
3. Two example dialogues showing it in use (in Persian with English translation)
4. Any variations or related expressions

Keep it concise and student-friendly."""

    return ask(prompt, quality="fast")
|
||||
|
||||
|
||||
def format_expression(expr, show_transliteration="off"):
    """Render an expression as centered, RTL-safe HTML for display.

    The transliteration line is included only when *show_transliteration*
    is anything other than "off".
    """
    html_lines = [
        f'<div dir="rtl" style="font-size:1.8em; text-align:center">{expr["persian"]}</div>',
        f'<div style="text-align:center; font-size:1.2em">{expr["english"]}</div>',
    ]
    include_finglish = show_transliteration != "off"
    if include_finglish:
        html_lines.append(
            f'<div style="text-align:center; color:#666; font-style:italic">{expr["finglish"]}</div>'
        )
    html_lines.append(
        f'<div style="text-align:center; color:#888; margin-top:0.5em">{expr["context"]}</div>'
    )
    return "\n".join(html_lines)
|
||||
65
python/persian-tutor/modules/tutor.py
Normal file
65
python/persian-tutor/modules/tutor.py
Normal file
@@ -0,0 +1,65 @@
|
||||
"""Conversational Persian lessons by GCSE theme."""
|
||||
|
||||
import time
|
||||
|
||||
import db
|
||||
from ai import chat_ollama
|
||||
|
||||
TUTOR_SYSTEM_PROMPT = """You are a friendly Persian (Farsi) language tutor teaching English-speaking GCSE students.
|
||||
|
||||
Rules:
|
||||
- Use a mix of English and Persian. Start mostly in English, gradually introducing more Persian.
|
||||
- When you write Persian, also provide the Finglish transliteration in parentheses.
|
||||
- Keep responses concise (2-4 sentences per turn).
|
||||
- Ask the student to practice: translate phrases, answer questions in Persian, or fill in blanks.
|
||||
- Correct mistakes gently and explain why.
|
||||
- Stay on the current theme/topic.
|
||||
- Use Iranian Persian (Farsi), not Dari or Tajik.
|
||||
- Adapt to the student's level based on their responses."""
|
||||
|
||||
THEME_PROMPTS = {
|
||||
"Identity and culture": "Let's practice talking about family, personality, daily routines, and Persian celebrations like Nowruz!",
|
||||
"Local area and environment": "Let's practice talking about your home, neighbourhood, shopping, and the environment!",
|
||||
"School and work": "Let's practice talking about school subjects, school life, jobs, and future plans!",
|
||||
"Travel and tourism": "Let's practice talking about transport, directions, holidays, hotels, and restaurants!",
|
||||
"International and global dimension": "Let's practice talking about health, global issues, technology, and social media!",
|
||||
"Free conversation": "Let's have a free conversation in Persian! I'll help you along the way.",
|
||||
}
|
||||
|
||||
|
||||
def start_lesson(theme):
    """Generate the opening message for a new lesson.

    Args:
        theme: GCSE theme name; unknown themes fall back to "Free conversation".

    Returns:
        (assistant_message, messages_list, system_prompt) — the composed
        system prompt is returned so the caller can reuse it for later turns
        via process_response().
    """
    intro = THEME_PROMPTS.get(theme, THEME_PROMPTS["Free conversation"])
    # Append the current topic to the base tutor persona prompt.
    system = TUTOR_SYSTEM_PROMPT + f"\n\nCurrent topic: {theme}. {intro}"

    messages = [{"role": "user", "content": f"I'd like to practice Persian. Today's theme is: {theme}"}]
    response = chat_ollama(messages, system=system)
    messages.append({"role": "assistant", "content": response})

    return response, messages, system
|
||||
|
||||
|
||||
def process_response(user_input, messages, system=None):
    """Append the student's turn, query the model, and record its reply.

    Blank or missing input is a no-op that returns the conversation unchanged.

    Returns:
        (assistant_response, updated_messages)
    """
    text = (user_input or "").strip()
    if not text:
        return "", messages

    messages.append({"role": "user", "content": text})
    reply = chat_ollama(messages, system=system)
    messages.append({"role": "assistant", "content": reply})

    return reply, messages
|
||||
|
||||
|
||||
def save_session(theme, messages, start_time):
    """Persist a finished tutor conversation, recording elapsed whole seconds."""
    elapsed = int(time.time() - start_time)
    db.save_tutor_session(theme, messages, elapsed)
|
||||
153
python/persian-tutor/modules/vocab.py
Normal file
153
python/persian-tutor/modules/vocab.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Vocabulary search, flashcard logic, and FSRS-driven review."""
|
||||
|
||||
import json
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import fsrs
|
||||
|
||||
import db
|
||||
|
||||
|
||||
VOCAB_PATH = Path(__file__).parent.parent / "data" / "vocabulary.json"
|
||||
|
||||
_vocab_data = None
|
||||
|
||||
|
||||
def load_vocab():
    """Load vocabulary data from JSON, caching it at module level."""
    global _vocab_data
    if _vocab_data is None:
        _vocab_data = json.loads(VOCAB_PATH.read_text(encoding="utf-8"))
    return _vocab_data
|
||||
|
||||
|
||||
def get_categories():
    """Return the distinct vocabulary categories, sorted alphabetically."""
    return sorted({item["category"] for item in load_vocab()})
|
||||
|
||||
|
||||
def get_sections():
    """Return the distinct vocabulary sections, sorted alphabetically."""
    return sorted({item["section"] for item in load_vocab()})
|
||||
|
||||
|
||||
def search(query, vocab_data=None):
    """Search vocabulary by English, Persian, or Finglish text.

    Args:
        query: Search string; matched case-insensitively as a substring.
        vocab_data: Optional explicit entry list; defaults to the bundled
            vocabulary. An explicitly passed empty list is respected.

    Returns:
        List of matching vocabulary entries (possibly empty).
    """
    if not query or not query.strip():
        return []
    # `vocab_data or load_vocab()` would silently fall back to the full
    # vocabulary when an empty list is passed in; test for None explicitly.
    vocab = load_vocab() if vocab_data is None else vocab_data
    query_lower = query.strip().lower()
    results = []
    for entry in vocab:
        if (
            query_lower in entry["english"].lower()
            or query_lower in entry["persian"]
            or (entry.get("finglish") and query_lower in entry["finglish"].lower())
        ):
            results.append(entry)
    return results
|
||||
|
||||
|
||||
def get_random_word(vocab_data=None, category=None):
    """Pick a random vocabulary entry, optionally filtered by category.

    Args:
        vocab_data: Optional explicit entry list; defaults to the bundled
            vocabulary. An explicitly passed empty list is respected.
        category: Category name to filter by; "All" or None disables the filter.

    Returns:
        A random matching entry, or None when nothing matches.
    """
    # `vocab_data or load_vocab()` would ignore an explicitly passed empty
    # list; test for None explicitly.
    vocab = load_vocab() if vocab_data is None else vocab_data
    if category and category != "All":
        candidates = [e for e in vocab if e["category"] == category]
    else:
        candidates = vocab
    if not candidates:
        return None
    return random.choice(candidates)
|
||||
|
||||
|
||||
def get_flashcard_batch(count=10, category=None):
    """Get a batch of words for flashcard study.

    Prioritizes due words (FSRS), then fills with new/random words.

    Args:
        count: Maximum number of entries to return.
        category: Optional category filter; "All" or None disables filtering.

    Returns:
        A shuffled list of up to ``count`` vocabulary entries.
    """
    vocab = load_vocab()

    if category and category != "All":
        pool = [e for e in vocab if e["category"] == category]
    else:
        pool = vocab

    # Get due words first
    # NOTE(review): get_due_words is capped at `count` ids across ALL
    # categories, so with a category filter fewer due words may surface
    # than actually exist in that category — confirm this is intended.
    due_ids = db.get_due_words(limit=count)
    due_entries = [e for e in pool if e["id"] in due_ids]

    # Fill remaining with unseen or random words
    remaining = count - len(due_entries)
    if remaining > 0:
        seen_ids = {e["id"] for e in due_entries}
        all_progress = db.get_all_word_progress()
        # Prefer unseen words
        unseen = [e for e in pool if e["id"] not in seen_ids and e["id"] not in all_progress]
        if len(unseen) >= remaining:
            fill = random.sample(unseen, remaining)
        else:
            # Use all unseen + random from rest
            fill = unseen
            still_needed = remaining - len(fill)
            # `e not in fill` compares whole dicts (linear per check) —
            # acceptable at flashcard batch sizes.
            rest = [e for e in pool if e["id"] not in seen_ids and e not in fill]
            if rest:
                fill.extend(random.sample(rest, min(still_needed, len(rest))))
        due_entries.extend(fill)

    random.shuffle(due_entries)
    return due_entries
|
||||
|
||||
|
||||
def check_answer(word_id, user_answer, direction="en_to_fa"):
    """Check if the user's answer matches the target word.

    Args:
        word_id: Vocabulary entry ID.
        user_answer: What the user typed.
        direction: "en_to_fa" (user writes Persian) or "fa_to_en" (user writes English).

    Returns:
        (is_correct, correct_answer, entry) — on a correct answer the user's
        own stripped text is echoed back as ``correct_answer``; for an unknown
        word_id the result is (False, "", None).
    """
    entry = None
    for candidate in load_vocab():
        if candidate["id"] == word_id:
            entry = candidate
            break
    if not entry:
        return False, "", None

    answer = user_answer.strip()

    if direction == "en_to_fa":
        # Persian answers are compared exactly (no case folding).
        target = entry["persian"].strip()
        matched = answer == target
    else:
        # English answers are compared case-insensitively.
        target = entry["english"].strip().lower()
        matched = answer.lower() == target

    return matched, answer if matched else target, entry
|
||||
|
||||
|
||||
def format_word_card(entry, show_transliteration="off"):
    """Format a vocabulary entry for display as RTL-safe markdown/HTML.

    The Finglish line is included only when *show_transliteration* is not
    "off" and the entry actually carries a transliteration.
    """
    lines = [
        f'<div dir="rtl" style="font-size:2em; text-align:center">{entry["persian"]}</div>',
        f'<div style="font-size:1.3em; text-align:center">{entry["english"]}</div>',
    ]

    if show_transliteration != "off" and entry.get("finglish"):
        lines.append(
            f'<div style="text-align:center; color:#666; font-style:italic">{entry["finglish"]}</div>'
        )

    lines.append(
        f'<div style="text-align:center; color:#999; font-size:0.9em">{entry.get("category", "")}</div>'
    )
    return "\n".join(lines)
|
||||
|
||||
|
||||
def get_word_status(word_id):
    """Classify a word as 'new', 'learning', or 'mastered'.

    A word with no recorded progress is 'new'; FSRS stability above 10
    marks it 'mastered'; anything else is 'learning'.
    """
    progress = db.get_word_progress(word_id)
    if not progress:
        return "new"
    stability = progress["stability"]
    return "mastered" if stability and stability > 10 else "learning"
|
||||
3
python/persian-tutor/requirements.txt
Normal file
3
python/persian-tutor/requirements.txt
Normal file
@@ -0,0 +1,3 @@
|
||||
gradio>=4.0
|
||||
genanki
|
||||
fsrs
|
||||
1100
python/persian-tutor/scripts/build_vocab.py
Normal file
1100
python/persian-tutor/scripts/build_vocab.py
Normal file
File diff suppressed because it is too large
Load Diff
81
python/persian-tutor/scripts/generate_vocab.py
Normal file
81
python/persian-tutor/scripts/generate_vocab.py
Normal file
@@ -0,0 +1,81 @@
|
||||
#!/usr/bin/env python3
|
||||
"""One-time script to generate/update vocabulary.json with AI-assisted transliterations.
|
||||
|
||||
Usage:
|
||||
python scripts/generate_vocab.py
|
||||
|
||||
This reads an existing vocabulary.json, finds entries missing finglish
|
||||
transliterations, and uses Ollama to generate them.
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
from ai import ask_ollama
|
||||
|
||||
VOCAB_PATH = Path(__file__).parent.parent / "data" / "vocabulary.json"
|
||||
|
||||
|
||||
def generate_transliterations(vocab):
    """Fill in missing finglish transliterations using AI.

    Args:
        vocab: List of vocabulary entry dicts; entries without a truthy
            "finglish" value are updated in place.

    Returns:
        The same vocab list (entries mutated in place).
    """
    import re

    missing = [e for e in vocab if not e.get("finglish")]
    if not missing:
        print("All entries already have finglish transliterations.")
        return vocab

    print(f"Generating transliterations for {len(missing)} entries...")

    # Process in batches of 20 to keep prompts short.
    batch_size = 20
    for i in range(0, len(missing), batch_size):
        batch = missing[i : i + batch_size]
        pairs = "\n".join(f"{e['persian']} = {e['english']}" for e in batch)

        prompt = f"""For each Persian word below, provide the Finglish (romanized) transliteration.
Use these conventions: â for آ, kh for خ, sh for ش, zh for ژ, gh for ق/غ, ch for چ.
Reply with ONLY the transliterations, one per line, in the same order.

{pairs}"""

        try:
            response = ask_ollama(prompt, model="qwen2.5:7b")
            lines = [l.strip() for l in response.strip().split("\n") if l.strip()]

            for j, entry in enumerate(batch):
                if j < len(lines):
                    # Strip leading numbering/bullets and any echoed
                    # "persian = " / "persian: " prefix only. The previous
                    # approach split on "-" and "." anywhere in the line and
                    # kept the last piece, which truncated hyphenated
                    # transliterations (e.g. "ru-ye" became "ye").
                    line = re.sub(r"^\s*(?:\d+[.)]\s*|[-*•]\s+)", "", lines[j])
                    line = re.sub(r"^.*?[=:]\s*", "", line).strip()
                    entry["finglish"] = line

            print(f"  Processed {min(i + batch_size, len(missing))}/{len(missing)}")
        except Exception as e:
            # Best-effort: keep going so one bad batch doesn't lose the rest.
            print(f"  Error processing batch: {e}")

    return vocab
|
||||
|
||||
|
||||
def main():
    """Load vocabulary.json, fill in missing transliterations, save it back."""
    if not VOCAB_PATH.exists():
        print(f"No vocabulary file found at {VOCAB_PATH}")
        return

    vocab = json.loads(VOCAB_PATH.read_text(encoding="utf-8"))
    print(f"Loaded {len(vocab)} entries")

    vocab = generate_transliterations(vocab)

    VOCAB_PATH.write_text(
        json.dumps(vocab, ensure_ascii=False, indent=2), encoding="utf-8"
    )
    print(f"Saved {len(vocab)} entries to {VOCAB_PATH}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
77
python/persian-tutor/stt.py
Normal file
77
python/persian-tutor/stt.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""Persian speech-to-text wrapper using sttlib."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
# sttlib lives in sibling project tool-speechtotext
|
||||
_sttlib_path = str(Path(__file__).resolve().parent.parent / "tool-speechtotext")
|
||||
sys.path.insert(0, _sttlib_path)
|
||||
from sttlib import load_whisper_model, transcribe, is_hallucination
|
||||
|
||||
_model = None
|
||||
_whisper_size = "medium"
|
||||
|
||||
# Common Whisper hallucinations in Persian/silence
|
||||
PERSIAN_HALLUCINATIONS = [
|
||||
"ممنون", # "thank you" hallucination
|
||||
"خداحافظ", # "goodbye" hallucination
|
||||
"تماشا کنید", # "watch" hallucination
|
||||
"لایک کنید", # "like" hallucination
|
||||
]
|
||||
|
||||
|
||||
def set_whisper_size(size):
    """Select a Whisper model size; the model reloads lazily on next use."""
    global _whisper_size, _model
    if size == _whisper_size:
        return
    _whisper_size = size
    _model = None  # drop the cached model so get_model() reloads at the new size
|
||||
|
||||
|
||||
def get_model():
    """Return the process-wide Whisper model, loading it lazily on first use."""
    global _model
    if _model is not None:
        return _model
    _model = load_whisper_model(_whisper_size)
    return _model
|
||||
|
||||
|
||||
def transcribe_persian(audio_tuple):
    """Transcribe Persian audio from a Gradio audio component.

    Args:
        audio_tuple: (sample_rate, numpy_array) from gr.Audio component.

    Returns:
        Transcribed text string, or empty string on failure/hallucination.
    """
    if audio_tuple is None:
        return ""

    sr, audio = audio_tuple  # sample rate currently unused by sttlib.transcribe
    model = get_model()

    # Convert to float32 normalized to [-1, 1].
    if audio.dtype == np.int16:
        # Divide by 32768 (|int16 min|) to keep the exact historical scaling.
        audio_float = audio.astype(np.float32) / 32768.0
    elif np.issubdtype(audio.dtype, np.floating):
        # Covers float32 AND float64. The previous fallthrough called
        # np.iinfo() on any non-int16/float32 dtype, which raises for
        # floating dtypes such as float64.
        audio_float = audio.astype(np.float32, copy=False)
    else:
        audio_float = audio.astype(np.float32) / np.iinfo(audio.dtype).max

    # Mono conversion if stereo
    if audio_float.ndim > 1:
        audio_float = audio_float.mean(axis=1)

    # Use sttlib transcribe
    text = transcribe(model, audio_float)

    # Filter hallucinations (English + Persian)
    if is_hallucination(text):
        return ""
    if text.strip() in PERSIAN_HALLUCINATIONS:
        return ""

    return text
|
||||
0
python/persian-tutor/tests/__init__.py
Normal file
0
python/persian-tutor/tests/__init__.py
Normal file
89
python/persian-tutor/tests/test_ai.py
Normal file
89
python/persian-tutor/tests/test_ai.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Tests for ai.py — dual AI backend."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import ai
|
||||
|
||||
|
||||
def test_ask_ollama_calls_ollama_chat():
    """ask_ollama should call ollama.chat with correct messages."""
    mock_response = MagicMock()
    mock_response.message.content = "test response"

    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        result = ai.ask_ollama("Hello", system="Be helpful")
        assert result == "test response"

    call_args = mock_chat.call_args
    # Accept both keyword and positional "messages" call styles.
    messages = call_args.kwargs.get("messages") or call_args[1].get("messages")
    assert len(messages) == 2
    assert messages[0]["role"] == "system"
    assert messages[1]["role"] == "user"
    assert messages[1]["content"] == "Hello"


def test_ask_ollama_no_system():
    """ask_ollama without system prompt should only send user message."""
    mock_response = MagicMock()
    mock_response.message.content = "response"

    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        ai.ask_ollama("Hi")
    call_args = mock_chat.call_args
    messages = call_args.kwargs.get("messages") or call_args[1].get("messages")
    assert len(messages) == 1
    assert messages[0]["role"] == "user"


def test_ask_claude_calls_subprocess():
    """ask_claude should call claude CLI via subprocess."""
    with patch("ai.subprocess.run") as mock_run:
        # ask_claude is expected to strip trailing whitespace from stdout.
        mock_run.return_value = MagicMock(stdout="Claude says hi\n")
        result = ai.ask_claude("Hello")
        assert result == "Claude says hi"
        mock_run.assert_called_once()
        args = mock_run.call_args[0][0]
        assert args[0] == "claude"
        assert "-p" in args


def test_ask_fast_uses_ollama():
    """ask with quality='fast' should use Ollama."""
    with patch("ai.ask_ollama", return_value="ollama response") as mock:
        result = ai.ask("test", quality="fast")
        assert result == "ollama response"
        mock.assert_called_once()


def test_ask_smart_uses_claude():
    """ask with quality='smart' should use Claude."""
    with patch("ai.ask_claude", return_value="claude response") as mock:
        result = ai.ask("test", quality="smart")
        assert result == "claude response"
        mock.assert_called_once()


def test_chat_ollama():
    """chat_ollama should pass multi-turn messages."""
    mock_response = MagicMock()
    mock_response.message.content = "continuation"

    with patch("ai.ollama.chat", return_value=mock_response) as mock_chat:
        messages = [
            {"role": "user", "content": "Hi"},
            {"role": "assistant", "content": "Hello!"},
            {"role": "user", "content": "How are you?"},
        ]
        result = ai.chat_ollama(messages, system="Be helpful")
        assert result == "continuation"

    call_args = mock_chat.call_args
    all_msgs = call_args.kwargs.get("messages") or call_args[1].get("messages")
    # system + 3 conversation messages
    assert len(all_msgs) == 4
|
||||
86
python/persian-tutor/tests/test_anki_export.py
Normal file
86
python/persian-tutor/tests/test_anki_export.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Tests for anki_export.py — Anki .apkg generation."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import zipfile
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
from anki_export import export_deck
|
||||
|
||||
SAMPLE_VOCAB = [
|
||||
{
|
||||
"id": "verb_go",
|
||||
"section": "High-frequency language",
|
||||
"category": "Common verbs",
|
||||
"english": "to go",
|
||||
"persian": "رفتن",
|
||||
"finglish": "raftan",
|
||||
},
|
||||
{
|
||||
"id": "verb_eat",
|
||||
"section": "High-frequency language",
|
||||
"category": "Common verbs",
|
||||
"english": "to eat",
|
||||
"persian": "خوردن",
|
||||
"finglish": "khordan",
|
||||
},
|
||||
{
|
||||
"id": "colour_red",
|
||||
"section": "High-frequency language",
|
||||
"category": "Colours",
|
||||
"english": "red",
|
||||
"persian": "قرمز",
|
||||
"finglish": "ghermez",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
def test_export_deck_creates_file(tmp_path):
    """export_deck should create a valid .apkg file."""
    output = str(tmp_path / "test.apkg")
    result = export_deck(SAMPLE_VOCAB, output_path=output)
    # export_deck returns the path it wrote to.
    assert result == output
    assert os.path.exists(output)
    assert os.path.getsize(output) > 0


def test_export_deck_is_valid_zip(tmp_path):
    """An .apkg file is a zip archive containing an Anki SQLite database."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, output_path=output)
    assert zipfile.is_zipfile(output)


def test_export_deck_with_category_filter(tmp_path):
    """export_deck with category filter should only include matching entries."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, categories=["Colours"], output_path=output)
    # File should exist and be smaller than unfiltered
    assert os.path.exists(output)
    size_filtered = os.path.getsize(output)

    output2 = str(tmp_path / "test_all.apkg")
    export_deck(SAMPLE_VOCAB, output_path=output2)
    size_all = os.path.getsize(output2)

    # Filtered deck should be smaller (fewer cards)
    assert size_filtered <= size_all


def test_export_deck_empty_vocab(tmp_path):
    """export_deck with empty vocabulary should still create a valid file."""
    output = str(tmp_path / "test.apkg")
    export_deck([], output_path=output)
    assert os.path.exists(output)


def test_export_deck_no_category_match(tmp_path):
    """export_deck with non-matching category filter should create empty deck."""
    output = str(tmp_path / "test.apkg")
    export_deck(SAMPLE_VOCAB, categories=["Nonexistent"], output_path=output)
    assert os.path.exists(output)
|
||||
151
python/persian-tutor/tests/test_db.py
Normal file
151
python/persian-tutor/tests/test_db.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Tests for db.py — SQLite database layer with FSRS integration."""
|
||||
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
# Add project root to path
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
import fsrs
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def temp_db(tmp_path):
    """Use a temporary database for each test."""
    import db as db_mod

    # Reset the module-level connection so each test opens a fresh DB file.
    db_mod._conn = None
    db_mod.DB_PATH = tmp_path / "test.db"
    db_mod.init_db()
    yield db_mod
    db_mod.close()


def test_init_db_creates_tables(temp_db):
    """init_db should create all required tables."""
    conn = temp_db.get_connection()
    tables = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table'"
    ).fetchall()
    table_names = {row["name"] for row in tables}
    assert "word_progress" in table_names
    assert "quiz_sessions" in table_names
    assert "essays" in table_names
    assert "tutor_sessions" in table_names


def test_get_word_progress_nonexistent(temp_db):
    """Should return None for a word that hasn't been reviewed."""
    assert temp_db.get_word_progress("nonexistent") is None


def test_update_and_get_word_progress(temp_db):
    """update_word_progress should create and update progress."""
    card = temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    assert card is not None
    assert card.stability is not None

    progress = temp_db.get_word_progress("verb_go")
    assert progress is not None
    assert progress["word_id"] == "verb_go"
    assert progress["reps"] == 1
    assert progress["fsrs_state"] is not None


def test_update_word_progress_increments_reps(temp_db):
    """Reviewing the same word multiple times should increment reps."""
    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    temp_db.update_word_progress("verb_go", fsrs.Rating.Easy)
    progress = temp_db.get_word_progress("verb_go")
    assert progress["reps"] == 2


def test_get_due_words(temp_db):
    """get_due_words should return words that are due for review."""
    # A newly reviewed word with Rating.Again should be due soon
    temp_db.update_word_progress("verb_go", fsrs.Rating.Again)
    # An easy word should have a later due date
    temp_db.update_word_progress("verb_eat", fsrs.Rating.Easy)

    # Due words depend on timing; at minimum both should be in the system
    all_progress = temp_db.get_connection().execute(
        "SELECT word_id FROM word_progress"
    ).fetchall()
    assert len(all_progress) == 2


def test_get_word_counts(temp_db):
    """get_word_counts should return correct counts."""
    counts = temp_db.get_word_counts(total_vocab_size=100)
    assert counts["total"] == 100
    assert counts["seen"] == 0
    assert counts["mastered"] == 0
    assert counts["due"] == 0

    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    counts = temp_db.get_word_counts(total_vocab_size=100)
    assert counts["seen"] == 1


def test_record_quiz_session(temp_db):
    """record_quiz_session should insert a quiz record."""
    temp_db.record_quiz_session("Common verbs", 10, 7, 120)
    rows = temp_db.get_connection().execute(
        "SELECT * FROM quiz_sessions"
    ).fetchall()
    assert len(rows) == 1
    assert rows[0]["correct"] == 7
    assert rows[0]["total_questions"] == 10


def test_save_essay(temp_db):
    """save_essay should store the essay and feedback."""
    temp_db.save_essay("متن آزمایشی", "B1", "Good effort!", "Identity and culture")
    essays = temp_db.get_recent_essays()
    assert len(essays) == 1
    assert essays[0]["grade"] == "B1"


def test_save_tutor_session(temp_db):
    """save_tutor_session should store the conversation."""
    messages = [
        {"role": "user", "content": "سلام"},
        {"role": "assistant", "content": "سلام! حالت چطوره؟"},
    ]
    temp_db.save_tutor_session("Identity and culture", messages, 300)
    rows = temp_db.get_connection().execute(
        "SELECT * FROM tutor_sessions"
    ).fetchall()
    assert len(rows) == 1
    assert rows[0]["theme"] == "Identity and culture"


def test_get_stats(temp_db):
    """get_stats should return aggregated stats."""
    stats = temp_db.get_stats()
    assert stats["total_reviews"] == 0
    assert stats["total_quizzes"] == 0
    assert stats["streak"] == 0
    assert isinstance(stats["recent_quizzes"], list)


def test_close_and_reopen(temp_db):
    """Closing and reopening should preserve data."""
    temp_db.update_word_progress("verb_go", fsrs.Rating.Good)
    db_path = temp_db.DB_PATH

    temp_db.close()

    # Reopen
    temp_db._conn = None
    temp_db.DB_PATH = db_path
    temp_db.init_db()

    progress = temp_db.get_word_progress("verb_go")
    assert progress is not None
    assert progress["reps"] == 1
|
||||
204
python/persian-tutor/tests/test_vocab.py
Normal file
204
python/persian-tutor/tests/test_vocab.py
Normal file
@@ -0,0 +1,204 @@
|
||||
"""Tests for modules/vocab.py — vocabulary search and flashcard logic."""
|
||||
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
sys.path.insert(0, str(Path(__file__).parent.parent))
|
||||
|
||||
SAMPLE_VOCAB = [
|
||||
{
|
||||
"id": "verb_go",
|
||||
"section": "High-frequency language",
|
||||
"category": "Common verbs",
|
||||
"english": "to go",
|
||||
"persian": "رفتن",
|
||||
"finglish": "raftan",
|
||||
},
|
||||
{
|
||||
"id": "verb_eat",
|
||||
"section": "High-frequency language",
|
||||
"category": "Common verbs",
|
||||
"english": "to eat",
|
||||
"persian": "خوردن",
|
||||
"finglish": "khordan",
|
||||
},
|
||||
{
|
||||
"id": "adj_big",
|
||||
"section": "High-frequency language",
|
||||
"category": "Common adjectives",
|
||||
"english": "big",
|
||||
"persian": "بزرگ",
|
||||
"finglish": "bozorg",
|
||||
},
|
||||
{
|
||||
"id": "colour_red",
|
||||
"section": "High-frequency language",
|
||||
"category": "Colours",
|
||||
"english": "red",
|
||||
"persian": "قرمز",
|
||||
"finglish": "ghermez",
|
||||
},
|
||||
]
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
def mock_vocab_and_db(tmp_path):
    """Mock vocabulary loading and use temp DB."""
    import db as db_mod
    import modules.vocab as vocab_mod

    # Temp DB
    db_mod._conn = None
    db_mod.DB_PATH = tmp_path / "test.db"
    db_mod.init_db()

    # Mock vocab — inject directly into the module-level cache that
    # load_vocab() checks, bypassing the JSON file entirely.
    vocab_mod._vocab_data = SAMPLE_VOCAB

    yield vocab_mod

    db_mod.close()
    vocab_mod._vocab_data = None


def test_load_vocab(mock_vocab_and_db):
    """load_vocab should return the vocabulary data."""
    data = mock_vocab_and_db.load_vocab()
    assert len(data) == 4


def test_get_categories(mock_vocab_and_db):
    """get_categories should return unique sorted categories."""
    cats = mock_vocab_and_db.get_categories()
    assert "Colours" in cats
    assert "Common verbs" in cats
    assert "Common adjectives" in cats


def test_search_english(mock_vocab_and_db):
    """Search should find entries by English text."""
    results = mock_vocab_and_db.search("go")
    assert len(results) == 1
    assert results[0]["id"] == "verb_go"


def test_search_persian(mock_vocab_and_db):
    """Search should find entries by Persian text."""
    results = mock_vocab_and_db.search("رفتن")
    assert len(results) == 1
    assert results[0]["id"] == "verb_go"


def test_search_finglish(mock_vocab_and_db):
    """Search should find entries by Finglish text."""
    results = mock_vocab_and_db.search("raftan")
    assert len(results) == 1
    assert results[0]["id"] == "verb_go"


def test_search_empty(mock_vocab_and_db):
    """Empty search should return empty list."""
    assert mock_vocab_and_db.search("") == []
    assert mock_vocab_and_db.search(None) == []


def test_search_no_match(mock_vocab_and_db):
    """Search with no match should return empty list."""
    assert mock_vocab_and_db.search("zzzzz") == []


def test_get_random_word(mock_vocab_and_db):
    """get_random_word should return a valid entry."""
    word = mock_vocab_and_db.get_random_word()
    assert word is not None
    assert "id" in word
    assert "english" in word
    assert "persian" in word


def test_get_random_word_with_category(mock_vocab_and_db):
    """get_random_word with category filter should only return matching entries."""
    word = mock_vocab_and_db.get_random_word(category="Colours")
    assert word is not None
    assert word["category"] == "Colours"


def test_get_random_word_nonexistent_category(mock_vocab_and_db):
    """get_random_word with bad category should return None."""
    word = mock_vocab_and_db.get_random_word(category="Nonexistent")
    assert word is None


def test_check_answer_correct_en_to_fa(mock_vocab_and_db):
    """Correct Persian answer should be marked correct."""
    correct, answer, entry = mock_vocab_and_db.check_answer(
        "verb_go", "رفتن", direction="en_to_fa"
    )
    assert correct is True


def test_check_answer_incorrect_en_to_fa(mock_vocab_and_db):
    """Incorrect Persian answer should be marked incorrect with correct answer."""
    correct, answer, entry = mock_vocab_and_db.check_answer(
        "verb_go", "خوردن", direction="en_to_fa"
    )
    assert correct is False
    assert answer == "رفتن"


def test_check_answer_fa_to_en(mock_vocab_and_db):
    """Correct English answer (case-insensitive) should be marked correct."""
    correct, answer, entry = mock_vocab_and_db.check_answer(
        "verb_go", "To Go", direction="fa_to_en"
    )
    assert correct is True


def test_check_answer_nonexistent_word(mock_vocab_and_db):
    """Checking answer for nonexistent word should return False."""
    correct, answer, entry = mock_vocab_and_db.check_answer(
        "nonexistent", "test", direction="en_to_fa"
    )
    assert correct is False
    assert entry is None


def test_format_word_card(mock_vocab_and_db):
    """format_word_card should produce RTL HTML with correct content."""
    entry = SAMPLE_VOCAB[0]
    html = mock_vocab_and_db.format_word_card(entry, show_transliteration="Finglish")
    assert "رفتن" in html
    assert "to go" in html
    assert "raftan" in html


def test_format_word_card_no_transliteration(mock_vocab_and_db):
    """format_word_card with transliteration off should not show finglish."""
    entry = SAMPLE_VOCAB[0]
    html = mock_vocab_and_db.format_word_card(entry, show_transliteration="off")
    assert "raftan" not in html


def test_get_flashcard_batch(mock_vocab_and_db):
    """get_flashcard_batch should return a batch of entries."""
    batch = mock_vocab_and_db.get_flashcard_batch(count=2)
    assert len(batch) == 2
    assert all("id" in e for e in batch)


def test_get_word_status_new(mock_vocab_and_db):
    """Unreviewed word should have status 'new'."""
    assert mock_vocab_and_db.get_word_status("verb_go") == "new"
|
||||
|
||||
|
||||
def test_get_word_status_learning(mock_vocab_and_db):
|
||||
"""Recently reviewed word should have status 'learning'."""
|
||||
import db
|
||||
import fsrs
|
||||
|
||||
db.update_word_progress("verb_go", fsrs.Rating.Good)
|
||||
assert mock_vocab_and_db.get_word_status("verb_go") == "learning"
|
||||
@@ -1,5 +1,11 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
"python-envs.pythonProjects": [],
|
||||
"python.testing.pytestEnabled": true,
|
||||
"python.testing.unittestEnabled": false,
|
||||
"python.testing.pytestArgs": [
|
||||
"tests",
|
||||
"-v"
|
||||
]
|
||||
}
|
||||
@@ -12,11 +12,20 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
||||
## Tools
|
||||
- `assistant.py` / `talk.sh` — transcribe speech, copy to clipboard, optionally send to Ollama
|
||||
- `voice_to_terminal.py` / `terminal.sh` — voice-controlled terminal via Ollama tool calling
|
||||
- `voice_to_xdotool.py` / `dotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
||||
- `voice_to_xdotool.py` / `xdotool.sh` — hands-free voice typing into any focused window (VAD + xdotool)
|
||||
|
||||
## Shared Library
|
||||
- `sttlib/` — shared package used by all scripts and importable by other projects
|
||||
- `whisper_loader.py` — model loading with GPU→CPU fallback
|
||||
- `audio.py` — press-enter recording, PCM conversion
|
||||
- `transcription.py` — Whisper transcribe wrapper, hallucination filter
|
||||
- `vad.py` — VADProcessor, audio callback, constants
|
||||
- Other projects import via: `sys.path.insert(0, "/path/to/tool-speechtotext")`
|
||||
|
||||
## Testing
|
||||
- To test scripts: `mamba run -n whisper-ollama python <script.py> --model-size base`
|
||||
- Run tests: `mamba run -n whisper-ollama python -m pytest tests/`
|
||||
- Use `--model-size base` for faster iteration during development
|
||||
- Tests mock hardware (Whisper model, VAD, mic) — no GPU/mic needed to run them
|
||||
- Audio device is available — live mic testing is possible
|
||||
- Test xdotool output by focusing a text editor window
|
||||
|
||||
@@ -24,12 +33,13 @@ Speech-to-text command line utilities leveraging local models (faster-whisper, O
|
||||
- Conda: faster-whisper, sounddevice, numpy, pyperclip, requests, ollama
|
||||
- Pip (in conda env): webrtcvad
|
||||
- System: libportaudio2, xdotool
|
||||
- Dev: pytest
|
||||
|
||||
## Conventions
|
||||
- Shell wrappers go in .sh files using `mamba run -n whisper-ollama`
|
||||
- All scripts set `CT2_CUDA_ALLOW_FP16=1`
|
||||
- Shared code lives in `sttlib/` — scripts are thin entry points that import from it
|
||||
- Whisper model loading always has GPU (cuda/float16) -> CPU (cpu/int8) fallback
|
||||
- Keep scripts self-contained (no shared module)
|
||||
- `CT2_CUDA_ALLOW_FP16=1` is set by `sttlib.whisper_loader` at import time
|
||||
- Don't print output for non-actionable events
|
||||
|
||||
## Preferences
|
||||
|
||||
@@ -1,57 +1,17 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import argparse
|
||||
import pyperclip
|
||||
import requests
|
||||
import sys
|
||||
import argparse
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
import os
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||
|
||||
# --- Configuration ---
|
||||
MODEL_SIZE = "medium" # Options: "base", "small", "medium", "large-v3"
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate" # Default is 11434
|
||||
OLLAMA_URL = "http://localhost:11434/api/generate"
|
||||
DEFAULT_OLLAMA_MODEL = "qwen3:latest"
|
||||
|
||||
# Load Whisper on GPU
|
||||
# float16 is faster and uses less VRAM on NVIDIA cards
|
||||
print("Loading Whisper model...")
|
||||
|
||||
|
||||
try:
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"Error loading GPU: {e}")
|
||||
print("Falling back to CPU (Check your CUDA/cuDNN installation)")
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="int16")
|
||||
|
||||
|
||||
def record_audio():
|
||||
fs = 16000
|
||||
print("\n[READY] Press Enter to START recording...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
|
||||
recording = []
|
||||
|
||||
def callback(indata, frames, time, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
recording.append(indata.copy())
|
||||
|
||||
with sd.InputStream(samplerate=fs, channels=1, callback=callback):
|
||||
input()
|
||||
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def main():
|
||||
# 1. Setup Parser
|
||||
print(f"System active. Model: {DEFAULT_OLLAMA_MODEL}")
|
||||
parser = argparse.ArgumentParser(description="Whisper + Ollama CLI")
|
||||
|
||||
# Known Arguments (Hardcoded logic)
|
||||
parser.add_argument("--nollm", "-n", action='store_true',
|
||||
help="turn off llm")
|
||||
parser.add_argument("--system", "-s", default=None,
|
||||
@@ -65,30 +25,27 @@ def main():
|
||||
parser.add_argument(
|
||||
"--temp", default='0.7', help="temperature")
|
||||
|
||||
# 2. Capture "Unknown" arguments
|
||||
# args = known values, unknown = a list like ['--num_ctx', '4096', '--temp', '0.7']
|
||||
args, unknown = parser.parse_known_args()
|
||||
|
||||
# Convert unknown list to a dictionary for the Ollama 'options' field
|
||||
# This logic pairs ['--key', 'value'] into {key: value}
|
||||
extra_options = {}
|
||||
for i in range(0, len(unknown), 2):
|
||||
key = unknown[i].lstrip('-') # remove the '--'
|
||||
key = unknown[i].lstrip('-')
|
||||
val = unknown[i+1]
|
||||
# Try to convert numbers to actual ints/floats
|
||||
try:
|
||||
val = float(val) if '.' in val else int(val)
|
||||
except ValueError:
|
||||
pass
|
||||
extra_options[key] = val
|
||||
|
||||
model = load_whisper_model(args.model_size)
|
||||
|
||||
while True:
|
||||
try:
|
||||
audio_data = record_audio()
|
||||
audio_data = record_until_enter()
|
||||
|
||||
print("[TRANSCRIBING]...")
|
||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
||||
text = "".join([segment.text for segment in segments]).strip()
|
||||
text = transcribe(model, audio_data.flatten())
|
||||
|
||||
if not text:
|
||||
print("No speech detected. Try again.")
|
||||
@@ -97,8 +54,7 @@ def main():
|
||||
print(f"You said: {text}")
|
||||
pyperclip.copy(text)
|
||||
|
||||
if (args.nollm == False):
|
||||
# Send to Ollama
|
||||
if not args.nollm:
|
||||
print(f"[OLLAMA] Thinking...")
|
||||
payload = {
|
||||
"model": args.ollama_model,
|
||||
@@ -108,9 +64,9 @@ def main():
|
||||
}
|
||||
|
||||
if args.system:
|
||||
payload["system"] = args
|
||||
response = requests.post(OLLAMA_URL, json=payload)
|
||||
payload["system"] = args.system
|
||||
|
||||
response = requests.post(OLLAMA_URL, json=payload)
|
||||
result = response.json().get("response", "")
|
||||
print(f"\nLLM Response:\n{result}\n")
|
||||
else:
|
||||
|
||||
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
7
python/tool-speechtotext/sttlib/__init__.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from sttlib.whisper_loader import load_whisper_model
|
||||
from sttlib.audio import record_until_enter, pcm_bytes_to_float32
|
||||
from sttlib.transcription import transcribe, is_hallucination, HALLUCINATION_PATTERNS
|
||||
from sttlib.vad import (
|
||||
VADProcessor, audio_callback, audio_queue,
|
||||
SAMPLE_RATE, CHANNELS, FRAME_DURATION_MS, FRAME_SIZE, MIN_UTTERANCE_FRAMES,
|
||||
)
|
||||
28
python/tool-speechtotext/sttlib/audio.py
Normal file
28
python/tool-speechtotext/sttlib/audio.py
Normal file
@@ -0,0 +1,28 @@
|
||||
import sys
|
||||
import numpy as np
|
||||
import sounddevice as sd
|
||||
|
||||
|
||||
def record_until_enter(sample_rate=16000):
|
||||
"""Record audio until user presses Enter. Returns float32 numpy array."""
|
||||
print("\n[READY] Press Enter to START recording...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
|
||||
recording = []
|
||||
|
||||
def callback(indata, frames, time, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
recording.append(indata.copy())
|
||||
|
||||
with sd.InputStream(samplerate=sample_rate, channels=1, callback=callback):
|
||||
input()
|
||||
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def pcm_bytes_to_float32(pcm_bytes):
|
||||
"""Convert raw 16-bit PCM bytes to float32 array normalized to [-1, 1]."""
|
||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||
return audio_int16.astype(np.float32) / 32768.0
|
||||
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
19
python/tool-speechtotext/sttlib/transcription.py
Normal file
@@ -0,0 +1,19 @@
|
||||
HALLUCINATION_PATTERNS = [
|
||||
"thank you", "thanks for watching", "subscribe",
|
||||
"bye", "the end", "thank you for watching",
|
||||
"please subscribe", "like and subscribe",
|
||||
]
|
||||
|
||||
|
||||
def transcribe(model, audio_float32):
|
||||
"""Transcribe audio using Whisper. Returns stripped text."""
|
||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||
return "".join(segment.text for segment in segments).strip()
|
||||
|
||||
|
||||
def is_hallucination(text):
|
||||
"""Return True if text looks like a Whisper hallucination."""
|
||||
lowered = text.lower().strip()
|
||||
if len(lowered) < 3:
|
||||
return True
|
||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||
58
python/tool-speechtotext/sttlib/vad.py
Normal file
58
python/tool-speechtotext/sttlib/vad.py
Normal file
@@ -0,0 +1,58 @@
|
||||
import sys
|
||||
import queue
|
||||
import collections
|
||||
import webrtcvad
|
||||
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
FRAME_DURATION_MS = 30
|
||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
|
||||
def audio_callback(indata, frames, time_info, status):
|
||||
"""sounddevice callback that pushes raw bytes to the audio queue."""
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
audio_queue.put(bytes(indata))
|
||||
|
||||
|
||||
class VADProcessor:
|
||||
def __init__(self, aggressiveness, silence_threshold):
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.silence_threshold = silence_threshold
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.triggered = False
|
||||
self.utterance_frames = []
|
||||
self.silence_duration = 0.0
|
||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||
|
||||
def process_frame(self, frame_bytes):
|
||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||
|
||||
if not self.triggered:
|
||||
self.pre_buffer.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.triggered = True
|
||||
self.silence_duration = 0.0
|
||||
self.utterance_frames = list(self.pre_buffer)
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
else:
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.silence_duration = 0.0
|
||||
else:
|
||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||
if self.silence_duration >= self.silence_threshold:
|
||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||
self.reset()
|
||||
return None
|
||||
result = b"".join(self.utterance_frames)
|
||||
self.reset()
|
||||
return result
|
||||
return None
|
||||
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
15
python/tool-speechtotext/sttlib/whisper_loader.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import os
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
|
||||
|
||||
def load_whisper_model(model_size):
|
||||
"""Load Whisper with GPU (cuda/float16) -> CPU (cpu/int8) fallback."""
|
||||
print(f"Loading Whisper model ({model_size})...")
|
||||
try:
|
||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"GPU loading failed: {e}")
|
||||
print("Falling back to CPU (int8)")
|
||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||
0
python/tool-speechtotext/tests/__init__.py
Normal file
0
python/tool-speechtotext/tests/__init__.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
38
python/tool-speechtotext/tests/test_audio.py
Normal file
@@ -0,0 +1,38 @@
|
||||
import struct
|
||||
import numpy as np
|
||||
from sttlib.audio import pcm_bytes_to_float32
|
||||
|
||||
|
||||
def test_known_value():
|
||||
# 16384 in int16 -> 0.5 in float32
|
||||
pcm = struct.pack("<h", 16384)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert abs(result[0] - 0.5) < 1e-5
|
||||
|
||||
|
||||
def test_silence():
|
||||
pcm = b"\x00\x00" * 10
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert np.all(result == 0.0)
|
||||
|
||||
|
||||
def test_full_scale():
|
||||
# max int16 = 32767 -> ~1.0
|
||||
pcm = struct.pack("<h", 32767)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert abs(result[0] - (32767 / 32768.0)) < 1e-5
|
||||
|
||||
|
||||
def test_negative():
|
||||
# min int16 = -32768 -> -1.0
|
||||
pcm = struct.pack("<h", -32768)
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert result[0] == -1.0
|
||||
|
||||
|
||||
def test_round_trip_shape():
|
||||
# 100 samples worth of bytes
|
||||
pcm = b"\x00\x00" * 100
|
||||
result = pcm_bytes_to_float32(pcm)
|
||||
assert result.shape == (100,)
|
||||
assert result.dtype == np.float32
|
||||
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
78
python/tool-speechtotext/tests/test_transcription.py
Normal file
@@ -0,0 +1,78 @@
|
||||
from unittest.mock import MagicMock
|
||||
from sttlib.transcription import transcribe, is_hallucination
|
||||
|
||||
|
||||
# --- is_hallucination tests ---
|
||||
|
||||
def test_known_hallucinations():
|
||||
assert is_hallucination("Thank you")
|
||||
assert is_hallucination("thanks for watching")
|
||||
assert is_hallucination("Subscribe")
|
||||
assert is_hallucination("the end")
|
||||
|
||||
|
||||
def test_short_text():
|
||||
assert is_hallucination("hi")
|
||||
assert is_hallucination("")
|
||||
assert is_hallucination("a")
|
||||
|
||||
|
||||
def test_normal_text():
|
||||
assert not is_hallucination("Hello how are you")
|
||||
assert not is_hallucination("Please open the terminal")
|
||||
|
||||
|
||||
def test_case_insensitivity():
|
||||
assert is_hallucination("THANK YOU")
|
||||
assert is_hallucination("Thank You For Watching")
|
||||
|
||||
|
||||
def test_substring_match():
|
||||
assert is_hallucination("I want to subscribe to your channel")
|
||||
|
||||
|
||||
def test_exactly_three_chars():
|
||||
assert not is_hallucination("hey")
|
||||
|
||||
|
||||
# --- transcribe tests ---
|
||||
|
||||
def _make_segment(text):
|
||||
seg = MagicMock()
|
||||
seg.text = text
|
||||
return seg
|
||||
|
||||
|
||||
def test_transcribe_joins_segments():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = (
|
||||
[_make_segment("Hello "), _make_segment("world")],
|
||||
None,
|
||||
)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == "Hello world"
|
||||
|
||||
|
||||
def test_transcribe_empty():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = ([], None)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == ""
|
||||
|
||||
|
||||
def test_transcribe_strips_whitespace():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = (
|
||||
[_make_segment(" hello ")],
|
||||
None,
|
||||
)
|
||||
result = transcribe(model, MagicMock())
|
||||
assert result == "hello"
|
||||
|
||||
|
||||
def test_transcribe_passes_beam_size():
|
||||
model = MagicMock()
|
||||
model.transcribe.return_value = ([], None)
|
||||
audio = MagicMock()
|
||||
transcribe(model, audio)
|
||||
model.transcribe.assert_called_once_with(audio, beam_size=5)
|
||||
151
python/tool-speechtotext/tests/test_vad.py
Normal file
151
python/tool-speechtotext/tests/test_vad.py
Normal file
@@ -0,0 +1,151 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sttlib.vad import VADProcessor, FRAME_DURATION_MS, MIN_UTTERANCE_FRAMES
|
||||
|
||||
|
||||
def _make_vad_processor(aggressiveness=3, silence_threshold=0.8):
|
||||
"""Create VADProcessor with a mocked webrtcvad.Vad."""
|
||||
with patch("sttlib.vad.webrtcvad.Vad") as mock_vad_cls:
|
||||
mock_vad = MagicMock()
|
||||
mock_vad_cls.return_value = mock_vad
|
||||
proc = VADProcessor(aggressiveness, silence_threshold)
|
||||
return proc, mock_vad
|
||||
|
||||
|
||||
def _frame(label="x"):
|
||||
"""Return a fake 30ms frame (just needs to be distinct bytes)."""
|
||||
return label.encode() * 960 # 480 samples * 2 bytes
|
||||
|
||||
|
||||
def test_no_speech_returns_none():
|
||||
proc, mock_vad = _make_vad_processor()
|
||||
mock_vad.is_speech.return_value = False
|
||||
|
||||
for _ in range(100):
|
||||
assert proc.process_frame(_frame()) is None
|
||||
|
||||
|
||||
def test_speech_then_silence_triggers_utterance():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# Feed enough speech frames
|
||||
speech_count = MIN_UTTERANCE_FRAMES + 5
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(speech_count):
|
||||
result = proc.process_frame(_frame("s"))
|
||||
assert result is None # not done yet
|
||||
|
||||
# Feed silence frames until threshold (0.3s = 10 frames at 30ms)
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_short_utterance_filtered():
|
||||
# Use very short silence threshold so silence frames don't push total
|
||||
# past MIN_UTTERANCE_FRAMES. With threshold=0.09s (3 frames of silence):
|
||||
# 0 pre-buffer + 1 speech + 3 silence = 4 total < MIN_UTTERANCE_FRAMES (10)
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.09)
|
||||
|
||||
# Single speech frame triggers VAD
|
||||
mock_vad.is_speech.return_value = True
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
# Immediately go silent — threshold reached in 3 frames
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
# Should be filtered (too short — only 4 total frames)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_pre_buffer_included():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# Fill pre-buffer with non-speech frames
|
||||
mock_vad.is_speech.return_value = False
|
||||
pre_frame = _frame("p")
|
||||
for _ in range(10):
|
||||
proc.process_frame(pre_frame)
|
||||
|
||||
# Speech starts
|
||||
mock_vad.is_speech.return_value = True
|
||||
speech_frame = _frame("s")
|
||||
for _ in range(MIN_UTTERANCE_FRAMES):
|
||||
proc.process_frame(speech_frame)
|
||||
|
||||
# Silence to trigger
|
||||
mock_vad.is_speech.return_value = False
|
||||
result = None
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
|
||||
assert result is not None
|
||||
# Result should contain pre-buffer frames
|
||||
assert pre_frame in result
|
||||
|
||||
|
||||
def test_reset_after_utterance():
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=0.3)
|
||||
|
||||
# First utterance
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
mock_vad.is_speech.return_value = False
|
||||
for _ in range(20):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
if result is not None:
|
||||
break
|
||||
assert result is not None
|
||||
|
||||
# After reset, should be able to collect a second utterance
|
||||
assert not proc.triggered
|
||||
assert proc.utterance_frames == []
|
||||
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
mock_vad.is_speech.return_value = False
|
||||
result2 = None
|
||||
for _ in range(20):
|
||||
result2 = proc.process_frame(_frame("q"))
|
||||
if result2 is not None:
|
||||
break
|
||||
assert result2 is not None
|
||||
|
||||
|
||||
def test_silence_threshold_boundary():
|
||||
# Use 0.3s threshold: 0.3 / 0.03 = exactly 10 frames needed
|
||||
threshold = 0.3
|
||||
proc, mock_vad = _make_vad_processor(silence_threshold=threshold)
|
||||
|
||||
# Start with speech
|
||||
mock_vad.is_speech.return_value = True
|
||||
for _ in range(MIN_UTTERANCE_FRAMES + 5):
|
||||
proc.process_frame(_frame("s"))
|
||||
|
||||
frames_needed = 10 # 0.3s / 0.03s per frame
|
||||
mock_vad.is_speech.return_value = False
|
||||
|
||||
# Feed one less than needed — should NOT trigger
|
||||
for i in range(frames_needed - 1):
|
||||
result = proc.process_frame(_frame("q"))
|
||||
assert result is None, f"Triggered too early at frame {i}"
|
||||
|
||||
# The 10th frame should trigger (silence_duration = 0.3 >= 0.3)
|
||||
result = proc.process_frame(_frame("q"))
|
||||
assert result is not None
|
||||
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
37
python/tool-speechtotext/tests/test_whisper_loader.py
Normal file
@@ -0,0 +1,37 @@
|
||||
from unittest.mock import patch, MagicMock
|
||||
from sttlib.whisper_loader import load_whisper_model
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_gpu_success(mock_cls):
|
||||
mock_model = MagicMock()
|
||||
mock_cls.return_value = mock_model
|
||||
|
||||
result = load_whisper_model("base")
|
||||
|
||||
assert result is mock_model
|
||||
mock_cls.assert_called_once_with("base", device="cuda", compute_type="float16")
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_gpu_fails_cpu_fallback(mock_cls):
|
||||
mock_model = MagicMock()
|
||||
mock_cls.side_effect = [RuntimeError("no CUDA"), mock_model]
|
||||
|
||||
result = load_whisper_model("base")
|
||||
|
||||
assert result is mock_model
|
||||
assert mock_cls.call_count == 2
|
||||
_, kwargs = mock_cls.call_args
|
||||
assert kwargs == {"device": "cpu", "compute_type": "int8"}
|
||||
|
||||
|
||||
@patch("sttlib.whisper_loader.WhisperModel")
|
||||
def test_both_fail_propagates(mock_cls):
|
||||
mock_cls.side_effect = RuntimeError("no device")
|
||||
|
||||
try:
|
||||
load_whisper_model("base")
|
||||
assert False, "Should have raised"
|
||||
except RuntimeError:
|
||||
pass
|
||||
@@ -1,31 +1,16 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import pyperclip
|
||||
import sys
|
||||
import argparse
|
||||
import os
|
||||
import subprocess
|
||||
import ollama
|
||||
import json
|
||||
from faster_whisper import WhisperModel
|
||||
import ollama
|
||||
from sttlib import load_whisper_model, record_until_enter, transcribe
|
||||
|
||||
# --- Configuration ---
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
MODEL_SIZE = "medium"
|
||||
OLLAMA_MODEL = "qwen2.5-coder:7b"
|
||||
CONFIRM_COMMANDS = True # Set to False to run commands instantly
|
||||
|
||||
# Load Whisper on GPU
|
||||
print("Loading Whisper model...")
|
||||
try:
|
||||
model = WhisperModel(MODEL_SIZE, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"Error loading GPU: {e}, falling back to CPU")
|
||||
model = WhisperModel(MODEL_SIZE, device="cpu", compute_type="int8")
|
||||
|
||||
# --- Terminal Tool ---
|
||||
|
||||
|
||||
def run_terminal_command(command: str):
|
||||
"""
|
||||
Executes a bash command in the Linux terminal.
|
||||
@@ -33,8 +18,7 @@ def run_terminal_command(command: str):
|
||||
"""
|
||||
if CONFIRM_COMMANDS:
|
||||
print(f"\n{'='*40}")
|
||||
print(f"⚠️ AI SUGGESTED: \033[1;32m{command}\033[0m")
|
||||
# Allow user to provide feedback if they say 'n'
|
||||
print(f"\u26a0\ufe0f AI SUGGESTED: \033[1;32m{command}\033[0m")
|
||||
choice = input(" Confirm? [Y/n] or provide feedback: ").strip()
|
||||
|
||||
if choice.lower() == 'n':
|
||||
@@ -57,22 +41,15 @@ def run_terminal_command(command: str):
|
||||
return f"Execution Error: {str(e)}"
|
||||
|
||||
|
||||
def record_audio():
|
||||
fs, recording = 16000, []
|
||||
print("\n[READY] Press Enter to START...")
|
||||
input()
|
||||
print("[RECORDING] Press Enter to STOP...")
|
||||
def cb(indata, f, t, s): recording.append(indata.copy())
|
||||
with sd.InputStream(samplerate=fs, channels=1, callback=cb):
|
||||
input()
|
||||
return np.concatenate(recording, axis=0)
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--model", default=OLLAMA_MODEL)
|
||||
parser.add_argument("--model-size", default="medium",
|
||||
help="Whisper model size")
|
||||
args, _ = parser.parse_known_args()
|
||||
|
||||
whisper_model = load_whisper_model(args.model_size)
|
||||
|
||||
# Initial System Prompt
|
||||
messages = [{
|
||||
'role': 'system',
|
||||
@@ -88,9 +65,8 @@ def main():
|
||||
while True:
|
||||
try:
|
||||
# 1. Voice Capture
|
||||
audio_data = record_audio()
|
||||
segments, _ = model.transcribe(audio_data.flatten(), beam_size=5)
|
||||
user_text = "".join([s.text for s in segments]).strip()
|
||||
audio_data = record_until_enter()
|
||||
user_text = transcribe(whisper_model, audio_data.flatten())
|
||||
if not user_text:
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,91 +1,14 @@
|
||||
import sounddevice as sd
|
||||
import numpy as np
|
||||
import webrtcvad
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
import argparse
|
||||
import subprocess
|
||||
import threading
|
||||
import queue
|
||||
import collections
|
||||
import time
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
os.environ["CT2_CUDA_ALLOW_FP16"] = "1"
|
||||
|
||||
# --- Constants ---
|
||||
SAMPLE_RATE = 16000
|
||||
CHANNELS = 1
|
||||
FRAME_DURATION_MS = 30
|
||||
FRAME_SIZE = int(SAMPLE_RATE * FRAME_DURATION_MS / 1000) # 480 samples
|
||||
MIN_UTTERANCE_FRAMES = 10 # ~300ms minimum to filter coughs/clicks
|
||||
|
||||
HALLUCINATION_PATTERNS = [
|
||||
"thank you", "thanks for watching", "subscribe",
|
||||
"bye", "the end", "thank you for watching",
|
||||
"please subscribe", "like and subscribe",
|
||||
]
|
||||
|
||||
# --- Thread-safe audio queue ---
|
||||
audio_queue = queue.Queue()
|
||||
|
||||
|
||||
def audio_callback(indata, frames, time_info, status):
|
||||
if status:
|
||||
print(status, file=sys.stderr)
|
||||
audio_queue.put(bytes(indata))
|
||||
|
||||
|
||||
# --- Whisper model loading (reused pattern from assistant.py) ---
|
||||
def load_whisper_model(model_size):
|
||||
print(f"Loading Whisper model ({model_size})...")
|
||||
try:
|
||||
return WhisperModel(model_size, device="cuda", compute_type="float16")
|
||||
except Exception as e:
|
||||
print(f"GPU loading failed: {e}")
|
||||
print("Falling back to CPU (int8)")
|
||||
return WhisperModel(model_size, device="cpu", compute_type="int8")
|
||||
|
||||
|
||||
# --- VAD State Machine ---
|
||||
class VADProcessor:
|
||||
def __init__(self, aggressiveness, silence_threshold):
|
||||
self.vad = webrtcvad.Vad(aggressiveness)
|
||||
self.silence_threshold = silence_threshold
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
self.triggered = False
|
||||
self.utterance_frames = []
|
||||
self.silence_duration = 0.0
|
||||
self.pre_buffer = collections.deque(maxlen=10) # ~300ms pre-roll
|
||||
|
||||
def process_frame(self, frame_bytes):
|
||||
"""Process one 30ms frame. Returns utterance bytes when complete, else None."""
|
||||
is_speech = self.vad.is_speech(frame_bytes, SAMPLE_RATE)
|
||||
|
||||
if not self.triggered:
|
||||
self.pre_buffer.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.triggered = True
|
||||
self.silence_duration = 0.0
|
||||
self.utterance_frames = list(self.pre_buffer)
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
pass # silent until transcription confirms speech
|
||||
else:
|
||||
self.utterance_frames.append(frame_bytes)
|
||||
if is_speech:
|
||||
self.silence_duration = 0.0
|
||||
else:
|
||||
self.silence_duration += FRAME_DURATION_MS / 1000.0
|
||||
if self.silence_duration >= self.silence_threshold:
|
||||
if len(self.utterance_frames) < MIN_UTTERANCE_FRAMES:
|
||||
self.reset()
|
||||
return None
|
||||
result = b"".join(self.utterance_frames)
|
||||
self.reset()
|
||||
return result
|
||||
return None
|
||||
import sounddevice as sd
|
||||
from sttlib import (
|
||||
load_whisper_model, transcribe, is_hallucination, pcm_bytes_to_float32,
|
||||
VADProcessor, audio_callback, audio_queue,
|
||||
SAMPLE_RATE, CHANNELS, FRAME_SIZE,
|
||||
)
|
||||
|
||||
|
||||
# --- Typer Interface (xdotool) ---
|
||||
@@ -99,6 +22,7 @@ class Typer:
|
||||
except FileNotFoundError:
|
||||
print("ERROR: xdotool not found. Install it:")
|
||||
print(" sudo apt-get install xdotool")
|
||||
import sys
|
||||
sys.exit(1)
|
||||
|
||||
def type_text(self, text, submit_now=False):
|
||||
@@ -120,24 +44,6 @@ class Typer:
|
||||
pass
|
||||
|
||||
|
||||
# --- Helpers ---
|
||||
def pcm_bytes_to_float32(pcm_bytes):
|
||||
audio_int16 = np.frombuffer(pcm_bytes, dtype=np.int16)
|
||||
return audio_int16.astype(np.float32) / 32768.0
|
||||
|
||||
|
||||
def transcribe(model, audio_float32):
|
||||
segments, _ = model.transcribe(audio_float32, beam_size=5)
|
||||
return "".join(segment.text for segment in segments).strip()
|
||||
|
||||
|
||||
def is_hallucination(text):
|
||||
lowered = text.lower().strip()
|
||||
if len(lowered) < 3:
|
||||
return True
|
||||
return any(p in lowered for p in HALLUCINATION_PATTERNS)
|
||||
|
||||
|
||||
# --- CLI ---
|
||||
def parse_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
|
||||
Reference in New Issue
Block a user