Compare commits

...

4 Commits

Author SHA1 Message Date
dl92
01d5532823 Add expression-evaluator: DAGs & state machines tutorial project
Educational calculator teaching FSMs (explicit transition table tokenizer)
and DAGs (recursive descent parser with AST evaluation). Includes CLI with
REPL, graphviz visualization, and 61 tests.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 18:09:42 +00:00
dl92
3a8705ece8 Fix bugs, N+1 queries, and wire settings in persian-tutor
- Replace inline __import__("datetime").timedelta hack with proper import
- Remove unused import random in anki_export.py
- Add error handling for Claude CLI subprocess failures in ai.py
- Fix hardcoded absolute path in stt.py with relative Path resolution
- Fix N+1 DB queries in vocab.get_flashcard_batch and dashboard.get_category_breakdown
  by adding db.get_all_word_progress() batch query
- Wire Ollama model and Whisper size settings to actually update config
  via ai.set_ollama_model() and stt.set_whisper_size()

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
2026-02-08 15:40:24 +00:00
local
8b5eb8797f working cva computation using quantlib 2026-02-08 13:40:35 +00:00
local
d484f9c236 working AMC algorithm tested against quantlib 2026-02-08 13:40:00 +00:00
28 changed files with 3253 additions and 12 deletions

View File

@@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}

View File

@@ -0,0 +1,19 @@
Current Progress Summary:
The Baseline: We fixed a standard Longstaff-Schwartz (LSM) American Put pricer, correcting the cash-flow propagation logic and regression targets.
The Evolution: We moved to Bermudan Swaptions using the Hull-White One-Factor Model.
The "Gold Standard": We implemented a 100% exact simulation from scratch. Instead of Euler discretization, we used Bivariate Normal sampling to jointly simulate the short rate rt and the stochastic integral ∫rsds. This accounts for the stochastic discount factor (the convexity adjustment) without approximation error.
The Current Frontier: We were debating the Risk-Neutral Measure (Q) vs. the Terminal Forward Measure (Q_T).
We concluded that while Q_T simplifies European options, it makes Bermudan LSM "messy" because it introduces a time-dependent drift shift: $\mathrm{Drift}_{Q_T} = \mathrm{Drift}_{Q} - \frac{\sigma^2}{a}\left(1 - e^{-a(T-t)}\right)$.
Pending Topics:
Mathematical Proof: The derivation of the "Drift Shift" via Girsanov's Theorem.
Exercise Boundary Impact: How the choice of measure (and the resulting drift) visually shifts the optimal exercise boundary in simulation.
Beyond One-Factor: Potential move toward Two-Factor models or non-flat initial term structures.

145
python/american-mc/main.py Normal file
View File

@@ -0,0 +1,145 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
# -----------------------------
# Black-Scholes European Put
# -----------------------------
def european_put_black_scholes(S0, K, r, sigma, T):
    """Closed-form Black-Scholes premium of a European put.

    Args:
        S0: spot price.
        K: strike.
        r: continuously-compounded risk-free rate.
        sigma: volatility.
        T: time to maturity in years.

    Returns:
        The put value K*e^{-rT}*N(-d2) - S0*N(-d1).
    """
    sqrt_T = np.sqrt(T)
    d_plus = (np.log(S0 / K) + (r + 0.5 * sigma ** 2) * T) / (sigma * sqrt_T)
    d_minus = d_plus - sigma * sqrt_T
    discounted_strike = K * np.exp(-r * T)
    return discounted_strike * norm.cdf(-d_minus) - S0 * norm.cdf(-d_plus)
# -----------------------------
# GBM Simulation
# -----------------------------
def simulate_gbm(S0, r, sigma, T, M, N, seed=None):
    """Simulate M geometric-Brownian-motion paths on N equal steps over [0, T].

    Uses the exact log-normal step, so there is no discretisation bias.
    Returns an (M, N + 1) array whose first column is S0.
    """
    if seed is not None:
        np.random.seed(seed)
    dt = T / N
    # Draw the whole (N, M) shock matrix up front; row t matches the t-th
    # per-step draw, so the random stream is identical to a step loop.
    shocks = np.random.standard_normal((N, M))
    drift = (r - 0.5 * sigma ** 2) * dt
    vol_step = sigma * np.sqrt(dt)
    paths = np.empty((M, N + 1))
    paths[:, 0] = S0
    gross_returns = np.exp(drift + vol_step * shocks.T)  # (M, N)
    paths[:, 1:] = S0 * np.cumprod(gross_returns, axis=1)
    return paths
# -----------------------------
# Longstaff-Schwartz Monte Carlo (Fixed)
# -----------------------------
def lsm_american_put(S0, K, r, sigma, T, M=100_000, N=50, basis_deg=2, seed=42, return_boundary=False):
    """Longstaff-Schwartz least-squares Monte Carlo price of an American put.

    Works backwards from maturity: at each step, the discounted
    continuation value is regressed on the spot (polynomial basis of
    degree `basis_deg`) over in-the-money paths, and a path exercises
    where intrinsic value beats the fitted continuation value.

    Returns:
        The price, or (price, boundary) when return_boundary is True,
        where boundary is a chronological list of (time, highest exercised
        spot) pairs (NaN where no path exercised at that step).
    """
    paths = simulate_gbm(S0, r, sigma, T, M, N, seed)
    dt = T / N
    step_df = np.exp(-r * dt)
    # Intrinsic (immediate-exercise) value at every node.
    intrinsic = np.maximum(K - paths, 0)
    # Path-wise option value, seeded with the terminal payoff.
    value = intrinsic[:, -1]
    boundary = []
    for step in range(N - 1, 0, -1):
        in_the_money = intrinsic[:, step] > 0
        # Roll every path back one step before the exercise decision.
        value = value * step_df
        if np.any(in_the_money):
            spots = paths[in_the_money, step]
            targets = value[in_the_money]  # already discounted to this step
            fit = np.polyfit(spots, targets, basis_deg)
            continuation = np.polyval(fit, spots)
            should_exercise = intrinsic[in_the_money, step] > continuation
            chosen = np.where(in_the_money)[0][should_exercise]
            value[chosen] = intrinsic[chosen, step]
            if len(chosen) > 0:
                # Highest spot at which exercise still happened.
                boundary.append((step * dt, np.max(paths[chosen, step])))
            else:
                boundary.append((step * dt, np.nan))
        else:
            boundary.append((step * dt, np.nan))
    # One final discount step takes the t = 1 values back to t = 0.
    price = np.mean(value * step_df)
    if return_boundary:
        return price, boundary[::-1]
    return price
# -----------------------------
# QuantLib American Put (v1.18)
# -----------------------------
def quantlib_american_put(S0, K, r, sigma, T):
    """Reference American put price from QuantLib's finite-difference engine.

    Flat risk-free curve at r, zero dividends, constant vol; 200 time and
    400 space steps on the FD grid.
    """
    import QuantLib as ql

    eval_date = ql.Date().todaysDate()
    ql.Settings.instance().evaluationDate = eval_date
    expiry = eval_date + int(T * 365)
    day_count = ql.Actual365Fixed()

    def _flat_curve(rate):
        # Flat term structure anchored at the evaluation date.
        return ql.YieldTermStructureHandle(ql.FlatForward(eval_date, rate, day_count))

    process = ql.BlackScholesMertonProcess(
        ql.QuoteHandle(ql.SimpleQuote(S0)),
        _flat_curve(0.0),  # zero dividend yield
        _flat_curve(r),
        ql.BlackVolTermStructureHandle(
            ql.BlackConstantVol(eval_date, ql.NullCalendar(), sigma, day_count)
        ),
    )
    option = ql.VanillaOption(
        ql.PlainVanillaPayoff(ql.Option.Put, K),
        ql.AmericanExercise(eval_date, expiry),
    )
    option.setPricingEngine(ql.FdBlackScholesVanillaEngine(process, 200, 400))
    return option.NPV()
# -----------------------------
# Main script
# -----------------------------
if __name__ == "__main__":
    # Market and contract parameters shared by all three pricers.
    spot, strike, rate, vol, maturity = 100, 100, 0.05, 0.2, 1.0
    n_paths, n_steps = 200_000, 50

    # Closed-form European benchmark (lower bound for the American put).
    eur_bs = european_put_black_scholes(spot, strike, rate, vol, maturity)
    print(f"European Put (BS): {eur_bs:.4f}")

    # Longstaff-Schwartz Monte Carlo with the exercise boundary recorded.
    lsm_price, boundary = lsm_american_put(
        spot, strike, rate, vol, maturity, n_paths, n_steps, return_boundary=True
    )
    print(f"LSM American Put (M={n_paths}): {lsm_price:.4f}")

    # QuantLib finite-difference reference value.
    ql_price = quantlib_american_put(spot, strike, rate, vol, maturity)
    print(f"QuantLib American Put: {ql_price:.4f}")
    print(f"Lower bound (immediate exercise): {max(strike - spot, 0):.4f}")

    # Plot the exercise boundary against the strike.
    times, levels = zip(*boundary)
    plt.figure(figsize=(8, 5))
    plt.plot(times, levels, color='orange', label="LSM Exercise Boundary")
    plt.axhline(strike, color='red', linestyle='--', label="Strike Price")
    plt.xlabel("Time to Maturity")
    plt.ylabel("Stock Price for Exercise")
    plt.title("American Put LSM Exercise Boundary")
    plt.gca().invert_xaxis()
    plt.legend()
    plt.grid(True)
    plt.show()

133
python/american-mc/main2.py Normal file
View File

@@ -0,0 +1,133 @@
import numpy as np
import matplotlib.pyplot as plt
# ---------------------------------------------------------
# 1. Analytical Hull-White Bond Pricing
# ---------------------------------------------------------
def bond_price(r_t, t, T, a, sigma, r0):
    """Zero-coupon bond price P(t, T) given the short rate r_t.

    Affine form P = A(t, T) * exp(-B(t, T) * r_t) for the mean-reverting
    Gaussian short-rate model used throughout this file, with a flat
    initial curve level r0 playing the role of the long-run mean.
    """
    tau = T - t
    B = (1 - np.exp(-a * tau)) / a
    log_A = (B - tau) * (a**2 * r0 - sigma**2 / 2) / a**2 - sigma**2 * B**2 / (4 * a)
    return np.exp(log_A) * np.exp(-B * r_t)
def get_swap_npv(r_t, t, T_end, strike, a, sigma, r0):
    """Exercise value of a payer swap (receive float, pay `strike`), floored at 0.

    Annual fixed payments at t+1, ..., T_end on unit notional.  Note the
    final np.maximum: this returns the swaption payoff max(NPV, 0), not
    the raw (possibly negative) swap NPV.
    """
    payment_dates = np.arange(t + 1, T_end + 1, 1.0)
    if len(payment_dates) == 0:
        # Past the last payment: nothing left to exercise into.
        return np.zeros_like(r_t)
    # Payer NPV on unit notional: 1 - P(t, T_end) - strike * annuity.
    terminal_bond = bond_price(r_t, t, T_end, a, sigma, r0)
    annuity = sum(bond_price(r_t, t, pay_date, a, sigma, r0) for pay_date in payment_dates)
    return np.maximum(1 - terminal_bond - strike * annuity, 0)
# ---------------------------------------------------------
# 2. Joint Exact Simulator (Short Rate & Stochastic Integral)
# ---------------------------------------------------------
def simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M, seed=None):
    """Exact joint simulation of the short rate and its time integral.

    Over each interval [t_i, t_{i+1}] the pair (r_{t_{i+1}}, I) with
    I = integral of r over the interval is bivariate normal, so both are
    drawn exactly -- no Euler discretisation error.

    Args:
        r0: initial short rate (flat initial curve level).
        a, sigma: mean-reversion speed and volatility.
        exercise_dates: increasing array of simulation dates (first > 0).
        M: number of paths.
        seed: optional RNG seed for reproducibility (new, default keeps
            the original unseeded behaviour).

    Returns:
        (r_matrix, d_matrix): r_matrix is (M, len(exercise_dates)) short
        rates at the dates; d_matrix[:, k] = exp(-I) over the k-th
        interval, i.e. the stochastic discount factor from t_{k-1} to t_k
        with t_{-1} = 0 (so d_matrix[:, 0] discounts 0 -> first date).
    """
    if seed is not None:
        np.random.seed(seed)
    M = int(M)
    num_dates = len(exercise_dates)
    r_matrix = np.zeros((M, num_dates + 1))
    d_matrix = np.zeros((M, num_dates))
    r_matrix[:, 0] = r0
    t_steps = np.insert(exercise_dates, 0, 0.0)
    for i in range(len(t_steps) - 1):
        t, T = t_steps[i], t_steps[i + 1]
        dt = T - t
        # Deterministic fitting terms alpha(.) for a flat initial curve at r0.
        alpha_t = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * t))**2
        alpha_T = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * T))**2
        # Conditional mean of r_T given r_t.
        mean_r = r_matrix[:, i] * np.exp(-a * dt) + alpha_T - alpha_t * np.exp(-a * dt)
        # -log P(t, T); since E[exp(-I)] = P(t, T), E[I] = -log P + Var[I]/2.
        neg_log_p = -np.log(bond_price(r_matrix[:, i], t, T, a, sigma, r0))
        # Conditional (co)variances of (r_T, I).
        var_r = (sigma**2 / (2 * a)) * (1 - np.exp(-2 * a * dt))
        B = (1 - np.exp(-a * dt)) / a
        var_I = (sigma**2 / a**2) * (dt - 2 * B + (1 - np.exp(-2 * a * dt)) / (2 * a))
        cov_rI = (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * dt))**2
        cov_matrix = [[var_r, cov_rI], [cov_rI, var_I]]
        # Sample joint normal innovations.
        Z = np.random.multivariate_normal([0, 0], cov_matrix, M)
        r_matrix[:, i + 1] = mean_r + Z[:, 0]
        # BUG FIX: the lognormal identity gives E[I] = -log P + 0.5 * Var[I],
        # so the convexity term enters with a PLUS sign.  The original used
        # "- 0.5 * var_I", which biases every discount factor high by a
        # factor exp(var_I).
        d_matrix[:, i] = np.exp(-(neg_log_p + 0.5 * var_I + Z[:, 1]))
    return r_matrix[:, 1:], d_matrix
# ---------------------------------------------------------
# 3. LSM Pricing Logic
# ---------------------------------------------------------
def price_bermudan_swaption(r0, a, sigma, strike, exercise_dates, T_end, M):
    """Price a Bermudan payer swaption by LSM on exactly-simulated paths.

    Args:
        r0, a, sigma: short-rate model parameters (flat curve at r0).
        strike: fixed rate of the underlying payer swap.
        exercise_dates: increasing exercise times (years).
        T_end: final maturity of the underlying swap.
        M: number of Monte Carlo paths.

    Returns:
        The time-0 price estimate.
    """
    # 1. Exact paths: rates at exercise dates plus per-interval stochastic
    #    discount factors.  discounts[:, k] covers t_{k-1} -> t_k with
    #    t_{-1} = 0, so discounts[:, 0] discounts from 0 to the first date.
    r_at_ex, discounts = simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M)
    # 2. Payoff at the final exercise date.
    T_last = exercise_dates[-1]
    cash_flows = get_swap_npv(r_at_ex[:, -1], T_last, T_end, strike, a, sigma, r0)
    # 3. Backward induction over the earlier exercise dates.
    for i in reversed(range(len(exercise_dates) - 1)):
        t_current = exercise_dates[i]
        # BUG FIX: cash flows currently sit at exercise_dates[i+1]; the
        # discount covering [exercise_dates[i], exercise_dates[i+1]] is
        # discounts[:, i + 1].  The original used discounts[:, i], which
        # covers the PREVIOUS interval (an off-by-one).
        cash_flows = cash_flows * discounts[:, i + 1]
        # Current intrinsic (immediate exercise) value.
        X = r_at_ex[:, i]
        immediate_payoff = get_swap_npv(X, t_current, T_end, strike, a, sigma, r0)
        # Only regress in-the-money paths.
        itm = immediate_payoff > 0
        if np.any(itm):
            # Regression basis [1, r, r^2].
            A = np.column_stack([np.ones_like(X[itm]), X[itm], X[itm]**2])
            coeffs = np.linalg.lstsq(A, cash_flows[itm], rcond=None)[0]
            continuation_value = A @ coeffs
            # Exercise where intrinsic beats estimated continuation.
            exercise = immediate_payoff[itm] > continuation_value
            exercise_indices = np.where(itm)[0][exercise]
            cash_flows[exercise_indices] = immediate_payoff[exercise_indices]
    # 4. BUG FIX: discount from the first exercise date to t = 0 with the
    #    pathwise stochastic factor discounts[:, 0] (now unused by the loop),
    #    preserving the rate/discount correlation instead of applying an
    #    analytical P(0, t1) to already-stochastic cash flows.
    return np.mean(cash_flows * discounts[:, 0])
# ---------------------------------------------------------
# Execution
# ---------------------------------------------------------
if __name__ == "__main__":
# Params
r0_val, a_val, sigma_val = 0.05, 0.1, 0.01
strike_val = 0.05
ex_dates = np.array([1.0, 2.0, 3.0, 4.0])
maturity = 5.0
num_paths = 100_000
price = price_bermudan_swaption(r0_val, a_val, sigma_val, strike_val, ex_dates, maturity, num_paths)
print(f"--- 100% Exact LSM Bermudan Swaption ---")
print(f"Parameters: a={a_val}, sigma={sigma_val}, strike={strike_val}")
print(f"Exercise Dates: {ex_dates}")
print(f"Calculated Price: {price:.6f}")

View File

@@ -0,0 +1,5 @@
{
"python-envs.defaultEnvManager": "ms-python.python:conda",
"python-envs.defaultPackageManager": "ms-python.python:conda",
"python-envs.pythonProjects": []
}

File diff suppressed because one or more lines are too long

View File

@@ -0,0 +1,83 @@
import numpy as np
# BUG FIX: the package name is "QuantLib" (capital L); "import Quantlib"
# raises ImportError at startup.
import QuantLib as ql
from hullwhite import (build_calibrated_hw, simulate_hw_state, swap_npv_from_state)
from instruments import (make_vanilla_swap, make_swaption_helpers)

# Global valuation context shared by the whole script.
today = ql.Date(15, 1, 2025)
ql.Settings.instance().evaluationDate = today
calendar = ql.TARGET()
day_count = ql.Actual365Fixed()
def expected_exposure(
    swap,
    hw,
    curve_handle,
    today,
    time_grid,
    x_paths
):
    """Expected positive exposure of `swap` per time bucket.

    Args:
        swap: instrument to revalue along each path.
        hw: calibrated Hull-White model.
        curve_handle: discount curve handle.
        today: valuation anchor date.
        time_grid: iterable of horizon times (years).
        x_paths: (n_paths, len(time_grid)) simulated state variables.

    Returns:
        Array of mean positive-part NPVs, one entry per grid time.
    """
    profile = np.zeros(len(time_grid))
    n_paths = len(x_paths)
    for j, t in enumerate(time_grid):
        # Positive part of the pathwise NPV (exposure, not raw value).
        exposures = [
            max(swap_npv_from_state(swap, hw, curve_handle, t, x_paths[i, j], today), 0.0)
            for i in range(n_paths)
        ]
        profile[j] = np.mean(exposures)
    return profile
def main():
    """Calibrate Hull-White to swaptions, simulate, and build a swap EE profile."""
    # Monthly exposure grid out to 10Y.
    time_grid = np.linspace(0, 10, 121)
    flat_rate = 0.02
    curve = ql.FlatForward(today, flat_rate, day_count)
    curve_handle = ql.YieldTermStructureHandle(curve)
    # Create and calibrate the Hull-White model.
    a_init = 0.05
    sigma_init = 0.01
    hw = ql.HullWhite(curve_handle, a_init, sigma_init)
    index = ql.Euribor6M(curve_handle)
    # BUG FIX: make_swaption_helpers(swaption_data, curve_handle, index, model)
    # takes the quote list as its FIRST argument; the original three-argument
    # call could not work.  Example co-terminal ATM quotes (expiry, tenor, vol);
    # replace with market data.
    swaption_data = [
        (ql.Period(1, ql.Years), ql.Period(9, ql.Years), 0.20),
        (ql.Period(2, ql.Years), ql.Period(8, ql.Years), 0.20),
        (ql.Period(5, ql.Years), ql.Period(5, ql.Years), 0.20),
    ]
    helpers = make_swaption_helpers(swaption_data, curve_handle, index, hw)
    optimizer = ql.LevenbergMarquardt()
    end_criteria = ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8)
    hw.calibrate(helpers, optimizer, end_criteria)
    a, sigma = hw.params()
    print(f"Calibrated a = {a:.4f}")
    print(f"Calibrated sigma = {sigma:.4f}")
    # Simulate the model state.
    x_paths = simulate_hw_state(hw, curve_handle, time_grid, n_paths=10000)
    # BUG FIX: make_receiver_swap does not exist (NameError); use the
    # imported make_vanilla_swap with pay_fixed=False (receiver) and the
    # required notional argument.
    swap = make_vanilla_swap(
        today, curve_handle, notional=1.0, fixed_rate=0.025,
        maturity_years=10, pay_fixed=False
    )
    # Expected exposure profile.
    ee = expected_exposure(swap, hw, curve_handle, today, time_grid, x_paths)
    return 0


if __name__ == '__main__':
    main()

View File

@@ -0,0 +1,66 @@
import QuantLib as ql
def build_calibrated_hw(
    curve_handle,
    swaption_helpers,
    a_init=0.05,
    sigma_init=0.01
):
    """Construct a Hull-White model and calibrate (a, sigma) to swaptions.

    Uses Levenberg-Marquardt with fixed end criteria.  The helpers are
    expected to already carry pricing engines.
    """
    model = ql.HullWhite(curve_handle, a_init, sigma_init)
    model.calibrate(
        swaption_helpers,
        ql.LevenbergMarquardt(),
        ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8),
    )
    return model
import numpy as np
def simulate_hw_state(
    hw: ql.HullWhite,
    curve_handle,
    time_grid,
    n_paths,
    seed=42
):
    """Simulate paths of the Hull-White state variable x(t), with x(0) = 0.

    Uses the exact one-step Ornstein-Uhlenbeck transition
    x_{i+1} = x_i * e^{-a dt} + eps,  eps ~ N(0, sigma^2 (1 - e^{-2 a dt}) / (2a)),
    so no discretisation bias is introduced.  `curve_handle` is accepted
    for interface symmetry but not used here.

    Returns an (n_paths, len(time_grid)) array.
    """
    a, sigma = hw.params()
    np.random.seed(seed)
    steps = np.diff(time_grid)
    states = np.zeros((n_paths, len(time_grid)))
    for i, dt in enumerate(steps, start=1):
        decay = np.exp(-a * dt)
        stdev = sigma * np.sqrt((1 - np.exp(-2 * a * dt)) / (2 * a))
        states[:, i] = states[:, i - 1] * decay + stdev * np.random.normal(size=n_paths)
    return states
def swap_npv_from_state(
    swap,
    hw,
    curve_handle,
    t,
    x_t,
    today
):
    """Revalue `swap` at horizon t (years) given a simulated HW state x_t.

    Side effects: moves the global QuantLib evaluation date forward and
    rebinds the swap's pricing engine.

    NOTE(review): `hw.setState(...)` and the three-argument
    DiscountingSwapEngine(curve, bool, model) call do not match the
    standard QuantLib-Python API -- confirm against the installed
    bindings before relying on this function.
    """
    # Roll the valuation date forward by t years (365-day convention).
    ql.Settings.instance().evaluationDate = today + int(t * 365)
    hw.setState(x_t)
    engine = ql.DiscountingSwapEngine(
        curve_handle,
        False,
        hw
    )
    swap.setPricingEngine(engine)
    return swap.NPV()

View File

@@ -0,0 +1,165 @@
import QuantLib as ql
# ============================================================
# Global defaults (override via parameters if needed)
# ============================================================
CALENDAR = ql.TARGET()  # business-day calendar used by all factories
BUSINESS_CONVENTION = ql.ModifiedFollowing  # date-roll rule for both legs
DATE_GEN = ql.DateGeneration.Forward  # schedules are built forward from start
FIXED_DAYCOUNT = ql.Thirty360(ql.Thirty360.European)  # fixed-leg accrual basis
FLOAT_DAYCOUNT = ql.Actual360()  # floating-leg accrual basis
# ============================================================
# Yield curve factory
# ============================================================
def make_flat_curve(
    evaluation_date: ql.Date,
    rate: float,
    day_count: ql.DayCounter = ql.Actual365Fixed()
) -> ql.YieldTermStructureHandle:
    """
    Build a flat yield curve anchored at evaluation_date.

    Side effect: also sets the global QuantLib evaluation date.
    (NOTE(review): the day_count default is a single instance created at
    import time and shared across calls.)
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    flat = ql.FlatForward(evaluation_date, rate, day_count)
    return ql.YieldTermStructureHandle(flat)
# ============================================================
# Index factory
# ============================================================
def make_euribor_6m(
    curve_handle: ql.YieldTermStructureHandle
) -> ql.IborIndex:
    """
    Build a Euribor 6M index forecasting off the supplied curve handle.
    """
    return ql.Euribor6M(curve_handle)
# ============================================================
# Schedule factory
# ============================================================
def make_schedule(
    start: ql.Date,
    maturity: ql.Date,
    tenor: ql.Period,
    calendar: ql.Calendar = CALENDAR
) -> ql.Schedule:
    """
    Build a periodic coupon schedule from start to maturity.

    Uses the module-level business-day convention and forward date
    generation; end-of-month adjustment is disabled.
    """
    return ql.Schedule(
        start, maturity, tenor, calendar,
        BUSINESS_CONVENTION, BUSINESS_CONVENTION,
        DATE_GEN, False,
    )
# ============================================================
# Swap factory
# ============================================================
def make_vanilla_swap(
    evaluation_date: ql.Date,
    curve_handle: ql.YieldTermStructureHandle,
    notional: float,
    fixed_rate: float,
    maturity_years: int,
    pay_fixed: bool = False,
    fixed_leg_freq: ql.Period = ql.Annual,
    float_leg_freq: ql.Period = ql.Semiannual
) -> ql.VanillaSwap:
    """
    Build a spot-starting (T+2) fixed-for-float IRS discounted on curve_handle.

    pay_fixed = True  -> payer swap (pay fixed, receive float)
    pay_fixed = False -> receiver swap

    NOTE(review): the *_freq defaults are ql.Annual / ql.Semiannual
    Frequency enums despite the ql.Period annotation; they are wrapped
    via ql.Period(freq) below.
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    index = make_euribor_6m(curve_handle)

    # T+2 spot start, then roll out to maturity on the shared calendar.
    start = CALENDAR.advance(evaluation_date, 2, ql.Days)
    maturity = CALENDAR.advance(start, maturity_years, ql.Years)
    fixed_schedule = make_schedule(start, maturity, ql.Period(fixed_leg_freq))
    float_schedule = make_schedule(start, maturity, ql.Period(float_leg_freq))

    direction = ql.VanillaSwap.Payer if pay_fixed else ql.VanillaSwap.Receiver
    swap = ql.VanillaSwap(
        direction,
        notional,
        fixed_schedule,
        fixed_rate,
        FIXED_DAYCOUNT,
        float_schedule,
        index,
        0.0,  # no spread on the floating leg
        FLOAT_DAYCOUNT,
    )
    swap.setPricingEngine(ql.DiscountingSwapEngine(curve_handle))
    return swap
# ============================================================
# Swaption helper factory (for HW calibration)
# ============================================================
def make_swaption_helpers(
    swaption_data: list,
    curve_handle: ql.YieldTermStructureHandle,
    index: ql.IborIndex,
    model: ql.HullWhite
) -> list:
    """
    Build ATM SwaptionHelpers with Jamshidian engines bound to `model`.

    swaption_data = [(expiry, tenor, vol), ...]

    NOTE(review): the index tenor and day counter are reused for the
    fixed leg -- confirm this matches the quoted instruments.
    """
    helpers = []
    for expiry, tenor, vol in swaption_data:
        vol_handle = ql.QuoteHandle(ql.SimpleQuote(vol))
        helper = ql.SwaptionHelper(
            expiry,
            tenor,
            vol_handle,
            index,
            index.tenor(),
            index.dayCounter(),
            index.dayCounter(),
            curve_handle,
        )
        helper.setPricingEngine(ql.JamshidianSwaptionEngine(model))
        helpers.append(helper)
    return helpers

109
python/cvatesting/main.py Normal file
View File

@@ -0,0 +1,109 @@
import QuantLib as ql
import numpy as np
import matplotlib.pyplot as plt
# --- UNIT 1: Instrument Setup ---
def create_30y_swap(today, rate_quote):
    """Build a 30Y unit-million payer swap on a relinkable flat curve.

    Returns (swap, yield_handle, index, fixing_dates) so the caller can
    relink the curve per scenario and inject the listed fixings.
    """
    ql.Settings.instance().evaluationDate = today
    # Relinkable so the simulation can swap curves per path/time-step.
    yield_handle = ql.RelinkableYieldTermStructureHandle()
    yield_handle.linkTo(ql.FlatForward(today, rate_quote, ql.Actual365Fixed()))

    calendar = ql.TARGET()
    settle_date = calendar.advance(today, 2, ql.Days)
    maturity_date = calendar.advance(settle_date, 30, ql.Years)
    index = ql.Euribor6M(yield_handle)

    def _schedule(freq):
        # Both legs share every schedule parameter except the frequency.
        return ql.Schedule(settle_date, maturity_date, ql.Period(freq),
                           calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                           ql.DateGeneration.Forward, False)

    fixed_schedule = _schedule(ql.Annual)
    float_schedule = _schedule(ql.Semiannual)
    swap = ql.VanillaSwap(ql.VanillaSwap.Payer, 1e6, fixed_schedule, rate_quote,
                          ql.Thirty360(ql.Thirty360.BondBasis), float_schedule,
                          index, 0.0, ql.Actual360())
    # Pre-calculate all fixing dates needed for the life of the swap.
    fixing_dates = [index.fixingDate(d) for d in float_schedule]
    return swap, yield_handle, index, fixing_dates
# --- UNIT 2: The Simulation Loop ---
def run_simulation(n_paths=50):
    """Monte-Carlo expected-exposure inputs for a 30Y payer swap under Hull-White.

    For each path and each annual bucket t: rebuilds a discount curve from
    the simulated short rate, injects the floating-leg fixings that have
    already occurred, and records the positive part of the swap NPV.

    Returns:
        (times, npv_matrix) with npv_matrix of shape (n_paths, len(times)).
    """
    today = ql.Date(27, 1, 2026)
    swap, yield_handle, index, fixing_dates = create_30y_swap(today, 0.03)
    # HW Parameters: a=mean reversion, sigma=vol
    model = ql.HullWhite(yield_handle, 0.01, 0.01)
    process = ql.HullWhiteProcess(yield_handle, 0.01, 0.01)
    times = np.arange(0, 31, 1.0)  # Annual buckets
    grid = ql.TimeGrid(times)
    rng = ql.GaussianRandomSequenceGenerator(ql.UniformRandomSequenceGenerator(
        len(grid)-1, ql.UniformRandomGenerator()))
    seq = ql.GaussianMultiPathGenerator(process, grid, rng, False)
    npv_matrix = np.zeros((n_paths, len(times)))
    for i in range(n_paths):
        path = seq.next().value()[0]
        # 1. Clear previous path's fixings to avoid data pollution
        ql.IndexManager.instance().clearHistories()
        for j, t in enumerate(times):
            # At t >= 30 the swap has matured; leave exposure at zero.
            if t >= 30: continue
            eval_date = ql.TARGET().advance(today, int(t), ql.Years)
            ql.Settings.instance().evaluationDate = eval_date
            # 2. Update the curve with simulated short rate rt
            rt = path[j]
            # Use pillars up to 35Y to avoid extrapolation crashes
            tenors = [0, 1, 2, 5, 10, 20, 30, 35]
            dates = [ql.TARGET().advance(eval_date, y, ql.Years) for y in tenors]
            # NOTE(review): discountBond(t, t + y, rt) treats the pathwise
            # rate as the model state at grid time t -- confirm the path's
            # times align with `times` (the shared TimeGrid implies they do).
            discounts = [model.discountBond(t, t + y, rt) for y in tenors]
            sim_curve = ql.DiscountCurve(dates, discounts, ql.Actual365Fixed())
            sim_curve.enableExtrapolation()
            yield_handle.linkTo(sim_curve)
            # 3. MANUAL FIXING INJECTION
            # For every fixing date that has passed or is 'today'
            for fd in fixing_dates:
                if fd <= eval_date:
                    # Direct manual calculation to avoid index.fixing() error
                    # We calculate the forward rate for the 6M period starting at fd
                    start_date = fd
                    end_date = ql.TARGET().advance(start_date, 6, ql.Months)
                    # Manual Euribor rate formula: (P1/P2 - 1) / dt
                    p1 = sim_curve.discount(start_date)
                    p2 = sim_curve.discount(end_date)
                    dt = ql.Actual360().yearFraction(start_date, end_date)
                    fwd_rate = (p1 / p2 - 1.0) / dt
                    index.addFixing(fd, fwd_rate, True)  # Force overwrite
            # 4. Valuation -- exposure is the positive part of the NPV
            swap.setPricingEngine(ql.DiscountingSwapEngine(yield_handle))
            npv_matrix[i, j] = max(swap.NPV(), 0)
    # Restore the global evaluation date for any later code.
    ql.Settings.instance().evaluationDate = today
    return times, npv_matrix
# --- UNIT 3: Visualization ---
times, npv_matrix = run_simulation(n_paths=100)
ee = np.mean(npv_matrix, axis=0)  # expected exposure per bucket

plt.figure(figsize=(10, 5))
plt.plot(times, ee, lw=2, label="Expected Exposure (EE)")
plt.fill_between(times, ee, alpha=0.3)
plt.title("30Y Swap EE Profile - Hull White")
plt.xlabel("Years")
plt.ylabel("Exposure")
plt.grid(True)
plt.legend()
plt.show()

View File

@@ -0,0 +1,42 @@
# Expression Evaluator
## Overview
Educational project teaching DAGs and state machines through a calculator.
Pure Python, no external dependencies.
## Running
```bash
python main.py "3 + 4 * 2" # single expression
python main.py # REPL mode
python main.py --show-tokens --show-ast --trace "expr" # show internals
python main.py --dot "3+4*2" | dot -Tpng -o ast.png # AST diagram
python main.py --dot-fsm | dot -Tpng -o fsm.png # FSM diagram
```
## Testing
```bash
python -m pytest tests/ -v
```
## Architecture
- `tokenizer.py` -- Explicit finite state machine (Mealy machine) tokenizer
- `parser.py` -- Recursive descent parser building an AST (DAG)
- `evaluator.py` -- Post-order tree walker (topological sort evaluation)
- `visualize.py` -- Graphviz dot generation for AST and FSM diagrams
- `main.py` -- CLI entry point with argparse, REPL mode
## Key Design Decisions
- State machine uses an explicit transition table (dict), not implicit if/else
- Unary minus resolved by examining previous token context
- Power operator (`^`) is right-associative (grammar uses right-recursion)
- AST nodes are dataclasses; evaluation uses structural pattern matching
- Graphviz output is raw dot strings (no graphviz Python package needed)
## Grammar
```
expression ::= term ((PLUS | MINUS) term)*
term ::= unary ((MULTIPLY | DIVIDE) unary)*
unary ::= UNARY_MINUS unary | power
power ::= atom (POWER power)?
atom ::= NUMBER | LPAREN expression RPAREN
```

View File

@@ -0,0 +1,87 @@
# Expression Evaluator -- DAGs & State Machines Tutorial
A calculator that teaches two fundamental CS patterns by building them from scratch:
1. **Finite State Machine** -- the tokenizer processes input character-by-character using an explicit transition table
2. **Directed Acyclic Graph (DAG)** -- the parser builds an expression tree, evaluated bottom-up in topological order
## What You'll Learn
| File | CS Concept | What it does |
|------|-----------|-------------|
| `tokenizer.py` | **State Machine** (Mealy machine) | Converts `"3 + 4 * 2"` into tokens using a transition table |
| `parser.py` | **DAG construction** | Builds an expression tree with operator precedence |
| `evaluator.py` | **Topological evaluation** | Walks the tree bottom-up (leaves before parents) |
| `visualize.py` | **Visualization** | Generates graphviz diagrams of both the FSM and AST |
## Quick Start
```bash
# Evaluate an expression
python main.py "3 + 4 * 2"
# => 11
# Interactive REPL
python main.py
# See how the state machine tokenizes
python main.py --show-tokens "(2 + 3) * -4"
# See the expression tree (DAG)
python main.py --show-ast "(2 + 3) * 4"
# *
# +-- +
# | +-- 2
# | `-- 3
# `-- 4
# Watch evaluation in topological order
python main.py --trace "(2 + 3) * 4"
# Step 1: 2 => 2
# Step 2: 3 => 3
# Step 3: 2 + 3 => 5
# Step 4: 4 => 4
# Step 5: 5 * 4 => 20
# Generate graphviz diagrams
python main.py --dot "(2 + 3) * 4" | dot -Tpng -o ast.png
python main.py --dot-fsm | dot -Tpng -o fsm.png
```
## Features
- Arithmetic: `+`, `-`, `*`, `/`, `^` (power)
- Parentheses: `(2 + 3) * 4`
- Unary minus: `-3`, `-(2 + 1)`, `2 * -3`
- Decimals: `3.14`, `.5`
- Standard precedence: parens > `^` > `*`/`/` > `+`/`-`
- Right-associative power: `2^3^4` = `2^(3^4)`
- Correct unary minus: `-3^2` = `-(3^2)` = `-9`
## Running Tests
```bash
python -m pytest tests/ -v
```
## How the State Machine Works
The tokenizer in `tokenizer.py` uses an **explicit transition table** -- a dictionary mapping `(current_state, character_class)` to `(next_state, action)`. This is the same pattern used in network protocol parsers, regex engines, and compiler lexers.
The three states are:
- `START` -- between tokens, dispatching based on the next character
- `INTEGER` -- accumulating digits (e.g., `"12"` so far)
- `DECIMAL` -- accumulating digits after a decimal point (e.g., `"12.3"`)
Use `--dot-fsm` to generate a visual diagram of the state machine.
## How the DAG Works
The parser in `parser.py` builds an **expression tree** (AST) where:
- **Leaf nodes** are numbers (no dependencies)
- **Interior nodes** are operators with edges to their operands
- **Edges** represent "depends on" relationships
Evaluation in `evaluator.py` walks this tree **bottom-up** -- children before parents. This is exactly a **topological sort** of the DAG: you can only compute a node after all its dependencies are resolved.
Use `--show-ast` to see the tree structure, or `--dot` to generate a graphviz diagram.

View File

@@ -0,0 +1,147 @@
"""
Part 3: DAG Evaluation -- Tree Walker
=======================================
Evaluating the AST bottom-up is equivalent to topological-sort
evaluation of a DAG. We must evaluate a node's children before
the node itself -- just like in any dependency graph.
For a tree, post-order traversal gives a topological ordering.
The recursive evaluate() function naturally does this:
1. Recursively evaluate all children (dependencies)
2. Combine the results (compute this node's value)
3. Return the result (make it available to the parent)
This is the same pattern as:
- make: build dependencies before the target
- pip/npm install: install dependencies before the package
- Spreadsheet recalculation: compute referenced cells first
"""
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
from tokenizer import TokenType
# ---------- Errors ----------
class EvalError(Exception):
    """Raised when an AST cannot be evaluated (e.g. division by zero or an unknown node)."""
    pass
# ---------- Evaluator ----------
# Display symbol for each operator token, used by the trace output.
OP_SYMBOLS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
def evaluate(node):
"""
Evaluate an AST by walking it bottom-up (post-order traversal).
This is a recursive function that mirrors the DAG structure:
each recursive call follows a DAG edge to a child node.
Children are evaluated before parents -- topological order.
"""
match node:
case NumberNode(value=v):
return v
case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
return -evaluate(child)
case BinOpNode(op=op, left=left, right=right):
left_val = evaluate(left)
right_val = evaluate(right)
match op:
case TokenType.PLUS:
return left_val + right_val
case TokenType.MINUS:
return left_val - right_val
case TokenType.MULTIPLY:
return left_val * right_val
case TokenType.DIVIDE:
if right_val == 0:
raise EvalError("division by zero")
return left_val / right_val
case TokenType.POWER:
return left_val ** right_val
raise EvalError(f"unknown node type: {type(node)}")
def evaluate_traced(node):
    """
    Like evaluate(), but records each step for educational display.

    Returns (result, list_of_trace_lines).

    The trace shows the topological evaluation order -- how the DAG
    is evaluated from leaves to root. Each step shows a node being
    evaluated after all its dependencies are resolved.
    """
    steps = []
    counter = [0]  # mutable counter so the nested closure can number steps
    def _walk(node, depth):
        # Two spaces of indent per tree depth, purely for display.
        indent = "  " * depth
        counter[0] += 1
        step = counter[0]
        match node:
            case NumberNode(value=v):
                # Leaf: its own step, no dependencies.
                result = v
                display = _format_number(v)
                steps.append(f"{indent}Step {step}: {display} => {_format_number(result)}")
                return result
            case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
                child_val = _walk(child, depth + 1)
                result = -child_val
                # Re-number: the child's walk consumed step numbers, so this
                # node's step is taken after its dependency completes.
                counter[0] += 1
                step = counter[0]
                steps.append(
                    f"{indent}Step {step}: neg({_format_number(child_val)}) "
                    f"=> {_format_number(result)}"
                )
                return result
            case BinOpNode(op=op, left=left, right=right):
                # Children (dependencies) before the parent: topological order.
                left_val = _walk(left, depth + 1)
                right_val = _walk(right, depth + 1)
                sym = OP_SYMBOLS[op]
                match op:
                    case TokenType.PLUS:
                        result = left_val + right_val
                    case TokenType.MINUS:
                        result = left_val - right_val
                    case TokenType.MULTIPLY:
                        result = left_val * right_val
                    case TokenType.DIVIDE:
                        if right_val == 0:
                            raise EvalError("division by zero")
                        result = left_val / right_val
                    case TokenType.POWER:
                        result = left_val ** right_val
                counter[0] += 1
                step = counter[0]
                steps.append(
                    f"{indent}Step {step}: {_format_number(left_val)} {sym} "
                    f"{_format_number(right_val)} => {_format_number(result)}"
                )
                return result
        raise EvalError(f"unknown node type: {type(node)}")
    result = _walk(node, 0)
    return result, steps
def _format_number(v):
"""Display a number as integer when possible."""
if isinstance(v, float) and v == int(v):
return str(int(v))
return str(v)

View File

@@ -0,0 +1,163 @@
"""
Expression Evaluator -- Learn DAGs & State Machines
====================================================
CLI entry point and interactive REPL.
Usage:
python main.py "3 + 4 * 2" # evaluate
python main.py # REPL mode
python main.py --show-tokens --show-ast --trace "expr" # show internals
python main.py --dot "3 + 4 * 2" | dot -Tpng -o ast.png
python main.py --dot-fsm | dot -Tpng -o fsm.png
"""
import argparse
import sys
from tokenizer import tokenize, TokenError
from parser import Parser, ParseError
from evaluator import evaluate, evaluate_traced, EvalError
from visualize import ast_to_dot, fsm_to_dot, ast_to_text
def process_expression(expr, args):
    """Run one expression through tokenize -> parse -> evaluate.

    Honors the CLI flags on `args`: --show-tokens, --show-ast, --dot
    (dot output suppresses the numeric result), and --trace.
    Errors are reported to stdout; this function never raises.
    """
    try:
        tokens = tokenize(expr)
    except TokenError as err:
        _print_error(expr, err)
        return

    if args.show_tokens:
        print("\nTokens:")
        for token in tokens:
            print(f"  {token}")

    try:
        tree = Parser(tokens).parse()
    except ParseError as err:
        _print_error(expr, err)
        return

    if args.show_ast:
        print("\nAST (text tree):")
        print(ast_to_text(tree))

    if args.dot:
        # Raw dot goes to stdout so it can be piped straight into graphviz.
        print(ast_to_dot(tree))
        return

    if args.trace:
        try:
            value, trace_lines = evaluate_traced(tree)
        except EvalError as err:
            print(f"Eval error: {err}")
            return
        print("\nEvaluation trace (topological order):")
        for line in trace_lines:
            print(line)
        print(f"\nResult: {_format_result(value)}")
        return

    try:
        value = evaluate(tree)
    except EvalError as err:
        print(f"Eval error: {err}")
        return
    print(_format_result(value))
def repl(args):
    """Interactive read-eval-print loop."""
    print("Expression Evaluator REPL")
    print("Type an expression, or 'quit' to exit.")
    # Echo which display flags carry over into this session.
    active = [
        name
        for enabled, name in (
            (args.show_tokens, "--show-tokens"),
            (args.show_ast, "--show-ast"),
            (args.trace, "--trace"),
        )
        if enabled
    ]
    if active:
        print(f"Active flags: {' '.join(active)}")
    print()
    while True:
        try:
            line = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C exits cleanly with a trailing newline.
            print()
            break
        if line.lower() in ("quit", "exit", "q"):
            break
        if not line:
            continue
        process_expression(line, args)
        print()
def _print_error(expr, error):
"""Print an error with a caret pointing to the position."""
print(f"Error: {error}")
if hasattr(error, 'position') and error.position is not None:
print(f" {expr}")
print(f" {' ' * error.position}^")
def _format_result(v):
"""Format a numeric result: show as int when possible."""
if isinstance(v, float) and v == int(v) and abs(v) < 1e15:
return str(int(v))
return str(v)
def main():
    """Parse CLI arguments, then dispatch to the REPL or one-shot evaluation."""
    cli = argparse.ArgumentParser(
        description="Expression Evaluator -- learn DAGs and state machines",
        epilog="Examples:\n"
               "  python main.py '3 + 4 * 2'\n"
               "  python main.py --show-tokens --trace '-(3 + 4) ^ 2'\n"
               "  python main.py --dot '(2+3)*4' | dot -Tpng -o ast.png\n"
               "  python main.py --dot-fsm | dot -Tpng -o fsm.png",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    cli.add_argument(
        "expression", nargs="?",
        help="Expression to evaluate (omit for REPL mode)",
    )
    cli.add_argument(
        "--show-tokens", action="store_true",
        help="Display tokenizer output",
    )
    cli.add_argument(
        "--show-ast", action="store_true",
        help="Display AST as indented text tree",
    )
    cli.add_argument(
        "--trace", action="store_true",
        help="Show step-by-step evaluation trace",
    )
    cli.add_argument(
        "--dot", action="store_true",
        help="Output AST as graphviz dot (pipe to: dot -Tpng -o ast.png)",
    )
    cli.add_argument(
        "--dot-fsm", action="store_true",
        help="Output tokenizer FSM as graphviz dot",
    )
    args = cli.parse_args()
    # Special mode: just print the FSM diagram and exit.
    if args.dot_fsm:
        print(fsm_to_dot())
        return
    # With no expression argument, fall into the interactive REPL.
    if args.expression is None:
        repl(args)
    else:
        process_expression(args.expression, args)
if __name__ == "__main__":
    main()

View File

@@ -0,0 +1,217 @@
"""
Part 2: DAG Construction -- Recursive Descent Parser
=====================================================
A parser converts a flat list of tokens into a tree structure (AST).
The AST is a DAG (Directed Acyclic Graph) where:
- Nodes are operations (BinOpNode) or values (NumberNode)
- Edges point from parent operations to their operands
- The graph is acyclic because an operation's inputs are always
"simpler" sub-expressions (no circular dependencies)
- It is a tree (a special case of DAG) because no node is shared
This is the same structure as:
- Spreadsheet dependency graphs (cell A1 depends on B1, B2...)
- Build systems (Makefile targets depend on other targets)
- Task scheduling (some tasks must finish before others start)
- Neural network computation graphs (forward pass is a DAG)
Key DAG concepts demonstrated:
- Nodes: operations and values
- Directed edges: from operation to its inputs (dependencies)
- Acyclic: no circular dependencies
- Topological ordering: natural evaluation order (leaves first)
Grammar (BNF) -- precedence is encoded by nesting depth:
expression ::= term ((PLUS | MINUS) term)* # lowest precedence
term ::= unary ((MULTIPLY | DIVIDE) unary)*
unary ::= UNARY_MINUS unary | power
power ::= atom (POWER power)? # right-associative
atom ::= NUMBER | LPAREN expression RPAREN # highest precedence
Call chain: expression -> term -> unary -> power -> atom
This means: +/- binds loosest, then *//, then unary -, then ^, then parens.
So -3^2 = -(3^2) = -9, matching standard math convention.
"""
from dataclasses import dataclass
from tokenizer import Token, TokenType
# ---------- AST node types ----------
# These are the nodes of our DAG. Each node is either a leaf (NumberNode)
# or an interior node with edges pointing to its children (operands).
@dataclass
class NumberNode:
    """Leaf node: a numeric literal. In DAG terms, a node with no outgoing edges."""
    value: float
    def __repr__(self):
        # Whole-valued floats are shown without the trailing '.0'.
        shown = int(self.value) if self.value == int(self.value) else self.value
        return f"NumberNode({shown})"
@dataclass
class BinOpNode:
    """
    Interior node: a binary operation with two children.
    DAG edges: this node -> left, this node -> right
    The edges represent "depends on": to compute this node's value,
    we must first compute left and right.
    """
    op: TokenType  # one of PLUS, MINUS, MULTIPLY, DIVIDE, POWER (see Parser)
    left: 'NumberNode | BinOpNode | UnaryOpNode'
    right: 'NumberNode | BinOpNode | UnaryOpNode'
    def __repr__(self):
        return f"BinOpNode({self.op.name}, {self.left}, {self.right})"
@dataclass
class UnaryOpNode:
    """Interior node: a unary operation (negation) with one child."""
    op: TokenType  # always UNARY_MINUS -- Parser.unary() is the only producer
    operand: 'NumberNode | BinOpNode | UnaryOpNode'
    def __repr__(self):
        return f"UnaryOpNode({self.op.name}, {self.operand})"
# Union type for any AST node
Node = NumberNode | BinOpNode | UnaryOpNode
# ---------- Errors ----------
class ParseError(Exception):
    """Raised when the token stream does not match the grammar."""
    def __init__(self, message, position=None):
        # Kept as an attribute so callers can draw a caret at the offset.
        self.position = position
        suffix = "" if position is None else f" at position {position}"
        super().__init__(f"Parse error{suffix}: {message}")
# ---------- Recursive descent parser ----------
class Parser:
    """
    Converts a list of tokens into an AST (expression tree / DAG).
    Each grammar rule becomes a method. The call tree mirrors the shape
    of the AST being built. When a deeper method returns a node, it
    becomes a child of the node built by the caller -- this is how
    the DAG edges form.
    Precedence is encoded by nesting: lower-precedence operators are
    parsed at higher (outer) levels, so they become closer to the root
    of the tree and are evaluated last.
    """
    def __init__(self, tokens):
        self.tokens = tokens  # token list, terminated by an EOF token
        self.pos = 0          # index of the next token to consume
    def peek(self):
        """Look at the current token without consuming it."""
        return self.tokens[self.pos]
    def consume(self, expected=None):
        """Consume and return the current token, optionally asserting its type."""
        token = self.tokens[self.pos]
        if expected is not None and token.type != expected:
            raise ParseError(
                f"expected {expected.name}, got {token.type.name}",
                token.position,
            )
        self.pos += 1
        return token
    def parse(self):
        """Entry point: parse the full expression and verify we consumed everything."""
        if self.peek().type == TokenType.EOF:
            raise ParseError("empty expression")
        node = self.expression()
        self.consume(TokenType.EOF)
        return node
    # --- Grammar rules ---
    # Each method corresponds to one production in the grammar.
    # The nesting encodes operator precedence.
    def expression(self):
        """expression ::= term ((PLUS | MINUS) term)*"""
        node = self.term()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            op_token = self.consume()
            right = self.term()
            # Build a new BinOpNode -- this creates a DAG edge from
            # the new node to both 'node' (left) and 'right'
            node = BinOpNode(op_token.type, node, right)
        return node
    def term(self):
        """term ::= unary ((MULTIPLY | DIVIDE) unary)*"""
        node = self.unary()
        while self.peek().type in (TokenType.MULTIPLY, TokenType.DIVIDE):
            op_token = self.consume()
            right = self.unary()
            node = BinOpNode(op_token.type, node, right)
        return node
    def unary(self):
        """
        unary ::= UNARY_MINUS unary | power
        Unary minus is parsed here, between term and power, so it binds
        looser than ^ but tighter than * and /. This gives the standard
        math behavior: -3^2 = -(3^2) = -9.
        The recursion (unary calls itself) handles double negation: --3 = 3.
        """
        if self.peek().type == TokenType.UNARY_MINUS:
            op_token = self.consume()
            operand = self.unary()
            return UnaryOpNode(op_token.type, operand)
        return self.power()
    def power(self):
        """
        power ::= atom (POWER power)?
        Right-recursive for right-associativity: 2^3^4 = 2^(3^4) = 2^81.
        Compare with term() which uses a while loop for LEFT-associativity.
        """
        node = self.atom()
        if self.peek().type == TokenType.POWER:
            op_token = self.consume()
            right = self.power()  # recurse (not loop) for right-associativity
            node = BinOpNode(op_token.type, node, right)
        return node
    def atom(self):
        """
        atom ::= NUMBER | LPAREN expression RPAREN
        The base case: either a literal number or a parenthesized
        sub-expression. Parentheses work by recursing back to
        expression(), which restarts precedence parsing from the top.
        """
        token = self.peek()
        if token.type == TokenType.NUMBER:
            self.consume()
            # BUGFIX: the tokenizer can emit a lone "." as a NUMBER token
            # (DECIMAL state with no digits); float(".") raises ValueError.
            # Surface that as a ParseError so callers see a uniform error type.
            try:
                value = float(token.value)
            except ValueError:
                raise ParseError(
                    f"invalid number literal {token.value!r}",
                    token.position,
                ) from None
            return NumberNode(value)
        if token.type == TokenType.LPAREN:
            self.consume()
            node = self.expression()
            self.consume(TokenType.RPAREN)
            return node
        raise ParseError(
            f"expected number or '(', got {token.type.name}",
            token.position,
        )

View File

@@ -0,0 +1,120 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize
from parser import Parser
from evaluator import evaluate, evaluate_traced, EvalError
def eval_expr(expr):
    """Helper: tokenize -> parse -> evaluate in one step."""
    tokens = tokenize(expr)
    ast = Parser(tokens).parse()
    return evaluate(ast)
# ---------- Basic arithmetic ----------
def test_addition():
    assert eval_expr("3 + 4") == 7.0
def test_subtraction():
    assert eval_expr("10 - 3") == 7.0
def test_multiplication():
    assert eval_expr("3 * 4") == 12.0
def test_division():
    assert eval_expr("10 / 4") == 2.5
def test_power():
    assert eval_expr("2 ^ 10") == 1024.0
# ---------- Precedence ----------
def test_standard_precedence():
    assert eval_expr("3 + 4 * 2") == 11.0
def test_parentheses():
    assert eval_expr("(3 + 4) * 2") == 14.0
def test_power_precedence():
    assert eval_expr("2 * 3 ^ 2") == 18.0
def test_right_associative_power():
    # 2^(2^3) = 2^8 = 256
    assert eval_expr("2 ^ 2 ^ 3") == 256.0
# ---------- Unary minus ----------
def test_negation():
    assert eval_expr("-5") == -5.0
def test_double_negation():
    assert eval_expr("--5") == 5.0
def test_negation_with_power():
    # -(3^2) = -9, not (-3)^2 = 9
    assert eval_expr("-3 ^ 2") == -9.0
def test_negation_in_parens():
    assert eval_expr("(-3) ^ 2") == 9.0
# ---------- Decimals ----------
def test_decimal_addition():
    # 0.1 + 0.2 != 0.3 exactly in binary floating point, hence approx
    assert eval_expr("0.1 + 0.2") == pytest.approx(0.3)
def test_leading_dot():
    assert eval_expr(".5 + .5") == 1.0
# ---------- Edge cases ----------
def test_nested_parens():
    assert eval_expr("((((3))))") == 3.0
def test_complex_expression():
    assert eval_expr("(2 + 3) * (7 - 2) / 5 ^ 1") == 5.0
def test_long_chain():
    assert eval_expr("1 + 2 + 3 + 4 + 5") == 15.0
def test_mixed_operations():
    assert eval_expr("2 + 3 * 4 - 6 / 2") == 11.0
# ---------- Division by zero ----------
def test_division_by_zero():
    with pytest.raises(EvalError):
        eval_expr("1 / 0")
def test_division_by_zero_in_expression():
    # The zero divisor only appears after sub-expression evaluation
    with pytest.raises(EvalError):
        eval_expr("5 + 3 / (2 - 2)")
# ---------- Traced evaluation ----------
def test_traced_returns_correct_result():
    tokens = tokenize("3 + 4 * 2")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 11.0
    assert len(steps) > 0
def test_traced_step_count():
    """A simple binary op has 3 evaluation events: left, right, combine."""
    tokens = tokenize("3 + 4")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 7.0
    # NumberNode(3), NumberNode(4), BinOp(+)
    assert len(steps) == 3

View File

@@ -0,0 +1,136 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize, TokenType
from parser import Parser, ParseError, NumberNode, BinOpNode, UnaryOpNode
def parse(expr):
    """Helper: tokenize and parse in one step."""
    return Parser(tokenize(expr)).parse()
# ---------- Basic parsing ----------
def test_parse_number():
    ast = parse("42")
    assert isinstance(ast, NumberNode)
    assert ast.value == 42.0
def test_parse_decimal():
    ast = parse("3.14")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.14
def test_parse_addition():
    ast = parse("3 + 4")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, NumberNode)
# ---------- Precedence ----------
def test_multiply_before_add():
    """3 + 4 * 2 should parse as 3 + (4 * 2)."""
    ast = parse("3 + 4 * 2")
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.MULTIPLY
def test_power_before_multiply():
    """2 * 3 ^ 4 should parse as 2 * (3 ^ 4)."""
    ast = parse("2 * 3 ^ 4")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER
def test_parentheses_override_precedence():
    """(3 + 4) * 2 should parse as (3 + 4) * 2."""
    ast = parse("(3 + 4) * 2")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.PLUS
# ---------- Associativity ----------
def test_left_associative_subtraction():
    """10 - 3 - 2 should parse as (10 - 3) - 2."""
    ast = parse("10 - 3 - 2")
    assert ast.op == TokenType.MINUS
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.MINUS
    assert isinstance(ast.right, NumberNode)
def test_power_right_associative():
    """2 ^ 3 ^ 4 should parse as 2 ^ (3 ^ 4)."""
    ast = parse("2 ^ 3 ^ 4")
    assert ast.op == TokenType.POWER
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER
# ---------- Unary minus ----------
def test_unary_minus():
    ast = parse("-3")
    assert isinstance(ast, UnaryOpNode)
    assert ast.operand.value == 3.0
def test_double_negation():
    # --3 nests one UnaryOpNode inside another
    ast = parse("--3")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, UnaryOpNode)
    assert ast.operand.operand.value == 3.0
def test_unary_minus_precedence():
    """-3^2 should parse as -(3^2), not (-3)^2."""
    ast = parse("-3 ^ 2")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, BinOpNode)
    assert ast.operand.op == TokenType.POWER
def test_unary_minus_in_expression():
    """2 * -3 should parse as 2 * (-(3))."""
    ast = parse("2 * -3")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, UnaryOpNode)
# ---------- Nested parentheses ----------
def test_nested_parens():
    # Redundant parens collapse away: no wrapper nodes are created
    ast = parse("((3))")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.0
def test_complex_nesting():
    """((2 + 3) * (7 - 2))"""
    ast = parse("((2 + 3) * (7 - 2))")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.MULTIPLY
# ---------- Errors ----------
def test_missing_rparen():
    with pytest.raises(ParseError):
        parse("(3 + 4")
def test_empty_expression():
    with pytest.raises(ParseError):
        parse("")
def test_trailing_operator():
    with pytest.raises(ParseError):
        parse("3 +")
def test_empty_parens():
    with pytest.raises(ParseError):
        parse("()")

View File

@@ -0,0 +1,139 @@
import sys
from pathlib import Path
sys.path.insert(0, str(Path(__file__).parent.parent))
import pytest
from tokenizer import tokenize, TokenType, Token, TokenError
# ---------- Basic tokens ----------
def test_single_integer():
    tokens = tokenize("42")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "42"
def test_decimal_number():
    tokens = tokenize("3.14")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "3.14"
def test_leading_dot():
    tokens = tokenize(".5")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == ".5"
def test_all_operators():
    """Operators between numbers are all binary."""
    # NOTE(review): near-duplicate of test_operators_between_numbers below;
    # kept both since they exercise different literal streams.
    tokens = tokenize("1 + 1 - 1 * 1 / 1 ^ 1")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]
def test_operators_between_numbers():
    tokens = tokenize("1 + 2 - 3 * 4 / 5 ^ 6")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]
def test_parentheses():
    tokens = tokenize("()")
    assert tokens[0].type == TokenType.LPAREN
    assert tokens[1].type == TokenType.RPAREN
# ---------- Unary minus ----------
def test_unary_minus_at_start():
    tokens = tokenize("-3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.NUMBER
def test_unary_minus_after_lparen():
    tokens = tokenize("(-3)")
    assert tokens[1].type == TokenType.UNARY_MINUS
def test_unary_minus_after_operator():
    tokens = tokenize("2 * -3")
    assert tokens[2].type == TokenType.UNARY_MINUS
def test_binary_minus():
    tokens = tokenize("5 - 3")
    assert tokens[1].type == TokenType.MINUS
def test_double_unary_minus():
    # chained minus: each re-classified because the previous is UNARY_MINUS
    tokens = tokenize("--3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.UNARY_MINUS
    assert tokens[2].type == TokenType.NUMBER
# ---------- Whitespace handling ----------
def test_no_spaces():
    tokens = tokenize("3+4")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3
def test_extra_spaces():
    tokens = tokenize("  3   +   4  ")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3
# ---------- Position tracking ----------
def test_positions():
    tokens = tokenize("3 + 4")
    assert tokens[0].position == 0  # '3'
    assert tokens[1].position == 2  # '+'
    assert tokens[2].position == 4  # '4'
# ---------- Errors ----------
def test_invalid_character():
    with pytest.raises(TokenError):
        tokenize("3 & 4")
def test_double_dot():
    # second '.' inside a number hits the DECIMAL/DOT ERROR transition
    with pytest.raises(TokenError):
        tokenize("3.14.15")
# ---------- EOF token ----------
def test_eof_always_present():
    tokens = tokenize("42")
    assert tokens[-1].type == TokenType.EOF
def test_empty_input():
    tokens = tokenize("")
    assert len(tokens) == 1
    assert tokens[0].type == TokenType.EOF
# ---------- Complex expressions ----------
def test_complex_expression():
    tokens = tokenize("(3 + 4.5) * -2 ^ 3")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.PLUS,
        TokenType.NUMBER, TokenType.RPAREN, TokenType.MULTIPLY,
        TokenType.UNARY_MINUS, TokenType.NUMBER, TokenType.POWER,
        TokenType.NUMBER,
    ]
def test_adjacent_parens():
    # The tokenizer does not reject "(3)(4)"; that is the parser's job
    tokens = tokenize("(3)(4)")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
    ]

View File

@@ -0,0 +1,306 @@
"""
Part 1: State Machine Tokenizer
================================
A tokenizer (lexer) converts raw text into a stream of tokens.
This implementation uses an EXPLICIT finite state machine (FSM):
- States are named values (an enum), not implicit control flow
- A transition table maps (current_state, input_class) -> (next_state, action)
- The main loop reads one character at a time and consults the table
This is the same pattern used in:
- Network protocol parsers (HTTP, TCP state machines)
- Regular expression engines
- Compiler front-ends (lexers for C, Python, etc.)
- Game AI (enemy behavior states)
Key FSM concepts demonstrated:
- States: the "memory" of what we're currently building
- Transitions: rules for moving between states based on input
- Actions: side effects (emit a token, accumulate a character)
- Mealy machine: outputs depend on both state AND input
"""
from dataclasses import dataclass
from enum import Enum
# ---------- Token types ----------
class TokenType(Enum):
    # Token categories emitted by the tokenizer. UNARY_MINUS is never produced
    # by the FSM itself -- a post-pass (_resolve_unary_minus) re-classifies
    # MINUS tokens based on the preceding token.
    NUMBER = "NUMBER"
    PLUS = "PLUS"
    MINUS = "MINUS"
    MULTIPLY = "MULTIPLY"
    DIVIDE = "DIVIDE"
    POWER = "POWER"
    LPAREN = "LPAREN"
    RPAREN = "RPAREN"
    UNARY_MINUS = "UNARY_MINUS"
    EOF = "EOF"
@dataclass
class Token:
    """One lexical unit: category, raw text, and starting offset."""
    type: TokenType
    value: str  # raw text: "42", "+", "(", etc.
    position: int  # character offset in original expression
    def __repr__(self):
        return f"Token({self.type.name}, {self.value!r}, pos={self.position})"
# Single-character binary operators and the token types they map to.
OPERATOR_MAP = {
    '+': TokenType.PLUS,
    '-': TokenType.MINUS,
    '*': TokenType.MULTIPLY,
    '/': TokenType.DIVIDE,
    '^': TokenType.POWER,
}
# ---------- FSM state definitions ----------
class State(Enum):
    """
    The tokenizer's finite set of states.
    START -- idle / between tokens, deciding what comes next
    INTEGER -- accumulating digits of an integer (e.g. "12" so far)
    DECIMAL -- accumulating digits after a decimal point (e.g. "12.3" so far)
    """
    START = "START"
    INTEGER = "INTEGER"
    DECIMAL = "DECIMAL"
class CharClass(Enum):
    """
    Character classification -- groups raw characters into categories
    so the transition table stays small and readable.
    """
    DIGIT = "DIGIT"
    DOT = "DOT"
    OPERATOR = "OPERATOR"
    LPAREN = "LPAREN"
    RPAREN = "RPAREN"
    SPACE = "SPACE"
    EOF = "EOF"
    UNKNOWN = "UNKNOWN"  # anything unrecognized; tokenize() raises on these
class Action(Enum):
    """
    What the FSM does on a transition. In a Mealy machine, the output
    (action) depends on both the current state AND the input.
    """
    ACCUMULATE = "ACCUMULATE"
    EMIT_NUMBER = "EMIT_NUMBER"
    EMIT_OPERATOR = "EMIT_OPERATOR"
    EMIT_LPAREN = "EMIT_LPAREN"
    EMIT_RPAREN = "EMIT_RPAREN"
    # Compound actions: the current char both terminates a number and
    # emits its own token in the same step.
    EMIT_NUMBER_THEN_OP = "EMIT_NUMBER_THEN_OP"
    EMIT_NUMBER_THEN_LPAREN = "EMIT_NUMBER_THEN_LPAREN"
    EMIT_NUMBER_THEN_RPAREN = "EMIT_NUMBER_THEN_RPAREN"
    EMIT_NUMBER_THEN_DONE = "EMIT_NUMBER_THEN_DONE"
    SKIP = "SKIP"
    DONE = "DONE"
    ERROR = "ERROR"
@dataclass(frozen=True)
class Transition:
    """One entry of the FSM table: the state to move to and the action to run."""
    next_state: State
    action: Action
# ---------- Transition table ----------
# This is the heart of the state machine. Every (state, char_class) pair
# maps to exactly one transition: a next state and an action to perform.
# Making this a data structure (not nested if/else) means we can:
# 1. Inspect it programmatically (e.g. to generate a diagram)
# 2. Verify completeness (every combination is covered)
# 3. Understand the FSM at a glance
TRANSITIONS = {
    # --- START: between tokens, dispatch based on character class ---
    (State.START, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.START, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.START, CharClass.OPERATOR): Transition(State.START, Action.EMIT_OPERATOR),
    (State.START, CharClass.LPAREN): Transition(State.START, Action.EMIT_LPAREN),
    (State.START, CharClass.RPAREN): Transition(State.START, Action.EMIT_RPAREN),
    (State.START, CharClass.SPACE): Transition(State.START, Action.SKIP),
    (State.START, CharClass.EOF): Transition(State.START, Action.DONE),
    # --- INTEGER: accumulating digits like "123" ---
    (State.INTEGER, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.INTEGER, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.INTEGER, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.INTEGER, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.INTEGER, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.INTEGER, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.INTEGER, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
    # --- DECIMAL: accumulating digits after "." like "123.45" ---
    (State.DECIMAL, CharClass.DIGIT): Transition(State.DECIMAL, Action.ACCUMULATE),
    # A second '.' inside one number ("3.14.15") is malformed input.
    (State.DECIMAL, CharClass.DOT): Transition(State.START, Action.ERROR),
    (State.DECIMAL, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.DECIMAL, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.DECIMAL, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.DECIMAL, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.DECIMAL, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
}
# ---------- Errors ----------
class TokenError(Exception):
    """Raised when the FSM meets a character sequence it cannot tokenize."""
    def __init__(self, message, position):
        # Exposed so error reporters can point a caret at the bad character.
        self.position = position
        super().__init__(f"Token error at position {position}: {message}")
# ---------- Character classification ----------
def classify(ch):
    """Map a single character to its CharClass."""
    # The categories are mutually exclusive, so check order is free.
    if ch in OPERATOR_MAP:
        return CharClass.OPERATOR
    if ch == '(':
        return CharClass.LPAREN
    if ch == ')':
        return CharClass.RPAREN
    if ch == '.':
        return CharClass.DOT
    if ch.isdigit():
        return CharClass.DIGIT
    if ch.isspace():
        return CharClass.SPACE
    return CharClass.UNKNOWN
# ---------- Main tokenize function ----------
def tokenize(expression):
    """
    Process an expression string through the state machine, producing tokens.
    The main loop:
    1. Classify the current character
    2. Look up (state, char_class) in the transition table
    3. Execute the action (accumulate, emit, skip, etc.)
    4. Move to the next state
    5. Advance to the next character
    After all tokens are emitted, a post-processing step resolves
    unary minus: if a MINUS token appears at the start, after an operator,
    or after LPAREN, it is re-classified as UNARY_MINUS.
    """
    state = State.START
    buffer = []  # characters accumulated for the current token
    buffer_start = 0  # position where the current buffer started
    tokens = []
    pos = 0
    # Append a sentinel so EOF is handled uniformly in the loop
    chars = expression + '\0'
    # Note: the loop runs one extra iteration (pos == len) so the EOF
    # char-class transition can flush a trailing number.
    while pos <= len(expression):
        ch = chars[pos]
        char_class = CharClass.EOF if pos == len(expression) else classify(ch)
        if char_class == CharClass.UNKNOWN:
            raise TokenError(f"unexpected character {ch!r}", pos)
        # Look up the transition
        key = (state, char_class)
        transition = TRANSITIONS.get(key)
        if transition is None:
            raise TokenError(f"no transition for state={state.name}, input={char_class.name}", pos)
        action = transition.action
        next_state = transition.next_state
        # --- Execute the action ---
        if action == Action.ACCUMULATE:
            if not buffer:
                buffer_start = pos  # first char of a new number literal
            buffer.append(ch)
        elif action == Action.EMIT_NUMBER:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
        elif action == Action.EMIT_OPERATOR:
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))
        elif action == Action.EMIT_LPAREN:
            tokens.append(Token(TokenType.LPAREN, ch, pos))
        elif action == Action.EMIT_RPAREN:
            tokens.append(Token(TokenType.RPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_OP:
            # Compound action: the operator char ends the number AND is a token.
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_LPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.LPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_RPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.RPAREN, ch, pos))
        elif action == Action.EMIT_NUMBER_THEN_DONE:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
        elif action == Action.SKIP:
            pass
        elif action == Action.DONE:
            pass
        elif action == Action.ERROR:
            raise TokenError(f"unexpected {ch!r} in state {state.name}", pos)
        state = next_state
        pos += 1
    # --- Post-processing: resolve unary minus ---
    # A MINUS is unary if it appears:
    # - at the very start of the token stream
    # - immediately after an operator (+, -, *, /, ^) or LPAREN
    # This context-sensitivity cannot be captured by the FSM alone --
    # it requires looking at previously emitted tokens.
    _resolve_unary_minus(tokens)
    tokens.append(Token(TokenType.EOF, '', len(expression)))
    return tokens
def _resolve_unary_minus(tokens):
    """
    Convert binary MINUS tokens to UNARY_MINUS in place where appropriate.
    Why this isn't in the FSM: the FSM processes characters one at a time
    and only tracks what kind of token it's currently building (its state).
    But whether '-' is unary or binary depends on the PREVIOUS TOKEN --
    information the FSM doesn't track. A lightweight post-pass over the
    emitted tokens supplies that context.
    """
    # Token types after which a '-' must be unary rather than binary.
    preceding_unary = frozenset({
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER, TokenType.LPAREN,
        TokenType.UNARY_MINUS,
    })
    prev = None
    for i, token in enumerate(tokens):
        if token.type == TokenType.MINUS and (
            prev is None or prev.type in preceding_unary
        ):
            # Replace in place; prev then sees the re-classified token,
            # so chains like "--3" resolve correctly.
            tokens[i] = Token(TokenType.UNARY_MINUS, token.value, token.position)
        prev = tokens[i]

View File

@@ -0,0 +1,200 @@
"""
Part 4: Visualization -- Graphviz Dot Output
==============================================
Generate graphviz dot-format strings for:
1. The tokenizer's finite state machine (FSM)
2. Any expression's AST (DAG)
3. Text-based tree rendering for the terminal
No external dependencies -- outputs raw dot strings that can be piped
to the 'dot' command: python main.py --dot "3+4*2" | dot -Tpng -o ast.png
"""
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
from tokenizer import TRANSITIONS, State, CharClass, Action, TokenType
# ---------- FSM diagram ----------
# Human-readable labels for character classes
_CHAR_LABELS = {
    CharClass.DIGIT: "digit",
    CharClass.DOT: "'.'",
    CharClass.OPERATOR: "op",
    CharClass.LPAREN: "'('",
    CharClass.RPAREN: "')'",
    CharClass.SPACE: "space",
    CharClass.EOF: "EOF",
    # CharClass.UNKNOWN is intentionally absent: it never appears in
    # TRANSITIONS, so fsm_to_dot falls back to the enum name if needed.
}
# Short labels for actions
_ACTION_LABELS = {
    Action.ACCUMULATE: "accum",
    Action.EMIT_NUMBER: "emit num",
    Action.EMIT_OPERATOR: "emit op",
    Action.EMIT_LPAREN: "emit '('",
    Action.EMIT_RPAREN: "emit ')'",
    Action.EMIT_NUMBER_THEN_OP: "emit num+op",
    Action.EMIT_NUMBER_THEN_LPAREN: "emit num+'('",
    Action.EMIT_NUMBER_THEN_RPAREN: "emit num+')'",
    Action.EMIT_NUMBER_THEN_DONE: "emit num, done",
    Action.SKIP: "skip",
    Action.DONE: "done",
    Action.ERROR: "ERROR",
}
def fsm_to_dot():
    """
    Generate a graphviz dot diagram of the tokenizer's state machine.
    Reads the TRANSITIONS table directly -- because the FSM is data (a dict),
    we can programmatically inspect and visualize it. This is a key advantage
    of explicit state machines over implicit if/else control flow.
    Returns the dot source as a single string (pipe it to `dot -Tpng`).
    """
    lines = [
        'digraph FSM {',
        '  rankdir=LR;',
        '  node [shape=circle, fontname="Helvetica"];',
        '  edge [fontname="Helvetica", fontsize=10];',
        '',
        '  // Start indicator',
        '  __start__ [shape=point, width=0.2];',
        '  __start__ -> START;',
        '',
    ]
    # Collect edges grouped by (src, dst) to merge labels, so parallel
    # transitions between the same pair of states render as one edge.
    edge_labels = {}
    for (state, char_class), transition in TRANSITIONS.items():
        src = state.name
        dst = transition.next_state.name
        char_label = _CHAR_LABELS.get(char_class, char_class.name)
        action_label = _ACTION_LABELS.get(transition.action, transition.action.name)
        label = f"{char_label} / {action_label}"
        edge_labels.setdefault((src, dst), []).append(label)
    # Emit edges (sorted for deterministic output)
    for (src, dst), labels in sorted(edge_labels.items()):
        combined = "\\n".join(labels)
        lines.append(f'  {src} -> {dst} [label="{combined}"];')
    lines.append('}')
    return '\n'.join(lines)
# ---------- AST diagram ----------
# Glyphs used for operator nodes in both the dot and text renderings.
_OP_LABELS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
def ast_to_dot(node):
    """
    Generate a graphviz dot diagram of an AST (expression tree / DAG).
    Each node gets a unique ID. Edges go from parent to children,
    showing the directed acyclic structure. Leaves are boxed,
    operators are ellipses.
    Returns the dot source as a single string.
    """
    lines = [
        'digraph AST {',
        '  node [fontname="Helvetica"];',
        '  edge [fontname="Helvetica"];',
        '',
    ]
    # One-element list so the nested function can mutate the counter
    # (a closure cannot rebind an outer local without `nonlocal`).
    counter = [0]
    def _visit(node):
        # Assign this node a fresh ID, emit its declaration, then recurse.
        nid = f"n{counter[0]}"
        counter[0] += 1
        match node:
            case NumberNode(value=v):
                label = _format_number(v)
                lines.append(f'  {nid} [label="{label}", shape=box, style=rounded];')
                return nid
            case UnaryOpNode(op=op, operand=child):
                label = _OP_LABELS.get(op, op.name)
                lines.append(f'  {nid} [label="{label}", shape=ellipse];')
                child_id = _visit(child)
                lines.append(f'  {nid} -> {child_id};')
                return nid
            case BinOpNode(op=op, left=left, right=right):
                label = _OP_LABELS.get(op, op.name)
                lines.append(f'  {nid} [label="{label}", shape=ellipse];')
                left_id = _visit(left)
                right_id = _visit(right)
                lines.append(f'  {nid} -> {left_id} [label="L"];')
                lines.append(f'  {nid} -> {right_id} [label="R"];')
                return nid
    _visit(node)
    lines.append('}')
    return '\n'.join(lines)
# ---------- Text-based tree ----------
def ast_to_text(node, prefix="", connector=""):
"""
Render the AST as an indented text tree for terminal display.
Example output for (2 + 3) * 4:
*
+-- +
| +-- 2
| +-- 3
+-- 4
"""
match node:
case NumberNode(value=v):
label = _format_number(v)
case UnaryOpNode(op=op):
label = _OP_LABELS.get(op, op.name)
case BinOpNode(op=op):
label = _OP_LABELS.get(op, op.name)
lines = [f"{prefix}{connector}{label}"]
children = _get_children(node)
for i, child in enumerate(children):
is_last_child = (i == len(children) - 1)
if connector:
# Extend the prefix: if we used "+-- " then next children
# see "| " (continuing) or " " (last child)
child_prefix = prefix + ("| " if connector == "+-- " else " ")
else:
child_prefix = prefix
child_connector = "+-- " if is_last_child else "+-- "
# Use a different lead for non-last: the vertical bar continues
child_connector = "`-- " if is_last_child else "+-- "
child_lines = ast_to_text(child, child_prefix, child_connector)
lines.append(child_lines)
return '\n'.join(lines)
def _get_children(node):
    """Return the immediate child nodes of *node*, left to right.

    Leaves (NumberNode) and unrecognized objects yield an empty list.
    """
    if isinstance(node, BinOpNode):
        return [node.left, node.right]
    if isinstance(node, UnaryOpNode):
        return [node.operand]
    # NumberNode — or anything else — has no children.
    return []
def _format_number(v):
if isinstance(v, float) and v == int(v):
return str(int(v))
return str(v)

View File

@@ -6,9 +6,18 @@ import ollama
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
_ollama_model = DEFAULT_OLLAMA_MODEL
def ask_ollama(prompt, system=None, model=DEFAULT_OLLAMA_MODEL):
def set_ollama_model(model):
"""Change the Ollama model used for fast queries."""
global _ollama_model
_ollama_model = model
def ask_ollama(prompt, system=None, model=None):
"""Query Ollama with an optional system prompt."""
model = model or _ollama_model
messages = []
if system:
messages.append({"role": "system", "content": system})
@@ -24,6 +33,8 @@ def ask_claude(prompt):
capture_output=True,
text=True,
)
if result.returncode != 0:
raise RuntimeError(f"Claude CLI failed (exit {result.returncode}): {result.stderr.strip()}")
return result.stdout.strip()
@@ -34,8 +45,9 @@ def ask(prompt, system=None, quality="fast"):
return ask_ollama(prompt, system=system)
def chat_ollama(messages, system=None, model=DEFAULT_OLLAMA_MODEL):
def chat_ollama(messages, system=None, model=None):
"""Multi-turn conversation with Ollama."""
model = model or _ollama_model
all_messages = []
if system:
all_messages.append({"role": "system", "content": system})

View File

@@ -1,7 +1,6 @@
"""Generate Anki .apkg decks from vocabulary data."""
import genanki
import random
# Stable model/deck IDs (generated once, kept constant)
_MODEL_ID = 1607392319

View File

@@ -7,6 +7,7 @@ import time
import gradio as gr
import ai
import db
from modules import vocab, dashboard, essay, tutor, idioms
from modules.essay import GCSE_THEMES
@@ -214,6 +215,15 @@ def do_anki_export(cats_selected):
return path
def update_ollama_model(model):
ai.set_ollama_model(model)
def update_whisper_size(size):
from stt import set_whisper_size
set_whisper_size(size)
def reset_progress():
conn = db.get_connection()
conn.execute("DELETE FROM word_progress")
@@ -491,6 +501,10 @@ with gr.Blocks(title="Persian Language Tutor") as app:
export_btn.click(fn=do_anki_export, inputs=[export_cats], outputs=[export_file])
# Wire model settings
ollama_model.change(fn=update_ollama_model, inputs=[ollama_model])
whisper_size.change(fn=update_whisper_size, inputs=[whisper_size])
gr.Markdown("### Reset")
reset_btn = gr.Button("Reset All Progress", variant="stop")
reset_status = gr.Markdown("")

View File

@@ -2,7 +2,7 @@
import json
import sqlite3
from datetime import datetime, timezone
from datetime import datetime, timedelta, timezone
from pathlib import Path
import fsrs
@@ -148,6 +148,13 @@ def get_word_counts(total_vocab_size=0):
}
def get_all_word_progress():
"""Return all word progress as a dict of word_id -> progress dict."""
conn = get_connection()
rows = conn.execute("SELECT * FROM word_progress").fetchall()
return {row["word_id"]: dict(row) for row in rows}
def record_quiz_session(category, total_questions, correct, duration_seconds):
"""Log a completed flashcard session."""
conn = get_connection()
@@ -203,7 +210,7 @@ def get_stats():
today = datetime.now(timezone.utc).date()
for i, row in enumerate(days):
day = datetime.fromisoformat(row["d"]).date() if isinstance(row["d"], str) else row["d"]
expected = today - __import__("datetime").timedelta(days=i)
expected = today - timedelta(days=i)
if day == expected:
streak += 1
else:

View File

@@ -19,17 +19,17 @@ def get_category_breakdown():
"""Return progress per category as list of dicts."""
vocab = load_vocab()
categories = get_categories()
all_progress = db.get_all_word_progress()
breakdown = []
for cat in categories:
cat_words = [e for e in vocab if e["category"] == cat]
cat_ids = {e["id"] for e in cat_words}
total = len(cat_words)
seen = 0
mastered = 0
for wid in cat_ids:
progress = db.get_word_progress(wid)
for e in cat_words:
progress = all_progress.get(e["id"])
if progress:
seen += 1
if progress["stability"] and progress["stability"] > 10:

View File

@@ -84,8 +84,9 @@ def get_flashcard_batch(count=10, category=None):
remaining = count - len(due_entries)
if remaining > 0:
seen_ids = {e["id"] for e in due_entries}
all_progress = db.get_all_word_progress()
# Prefer unseen words
unseen = [e for e in pool if e["id"] not in seen_ids and not db.get_word_progress(e["id"])]
unseen = [e for e in pool if e["id"] not in seen_ids and e["id"] not in all_progress]
if len(unseen) >= remaining:
fill = random.sample(unseen, remaining)
else:

View File

@@ -1,13 +1,17 @@
"""Persian speech-to-text wrapper using sttlib."""
import sys
from pathlib import Path
import numpy as np
sys.path.insert(0, "/home/ys/family-repo/Code/python/tool-speechtotext")
# sttlib lives in sibling project tool-speechtotext
_sttlib_path = str(Path(__file__).resolve().parent.parent / "tool-speechtotext")
sys.path.insert(0, _sttlib_path)
from sttlib import load_whisper_model, transcribe, is_hallucination
_model = None
_whisper_size = "medium"
# Common Whisper hallucinations in Persian/silence
PERSIAN_HALLUCINATIONS = [
@@ -18,11 +22,19 @@ PERSIAN_HALLUCINATIONS = [
]
def get_model(size="medium"):
def set_whisper_size(size):
"""Change the Whisper model size. Reloads on next transcription."""
global _whisper_size, _model
if size != _whisper_size:
_whisper_size = size
_model = None
def get_model():
"""Load Whisper model (cached singleton)."""
global _model
if _model is None:
_model = load_whisper_model(size)
_model = load_whisper_model(_whisper_size)
return _model