Compare commits
4 Commits
2e8c2c11d0
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
01d5532823 | ||
|
|
3a8705ece8 | ||
|
|
8b5eb8797f | ||
|
|
d484f9c236 |
5
python/american-mc/.vscode/settings.json
vendored
Normal file
5
python/american-mc/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
19
python/american-mc/CONTEXTRESUME.md
Normal file
19
python/american-mc/CONTEXTRESUME.md
Normal file
@@ -0,0 +1,19 @@
|
||||
Current Progress Summary:
|
||||
|
||||
The Baseline: We fixed a standard Longstaff-Schwartz (LSM) American Put pricer, correcting the cash-flow propagation logic and regression targets.
|
||||
|
||||
The Evolution: We moved to Bermudan Swaptions using the Hull-White One-Factor Model.
|
||||
|
||||
The "Gold Standard": We implemented a 100% exact simulation from scratch. Instead of Euler discretization, we used bivariate normal sampling to jointly simulate the short rate $r_t$ and the stochastic integral $\int_0^t r_s\,ds$. This accounts for the stochastic discount factor (the convexity adjustment) without approximation error.
|
||||
|
||||
The Current Frontier: We were debating the Risk-Neutral Measure (Q) vs. the Terminal Forward Measure (QT).
|
||||
|
||||
We concluded that while $Q^T$ simplifies European options, it makes Bermudan LSM "messy" because it introduces a time-dependent drift shift: $\mathrm{Drift}_{Q^T} = \mathrm{Drift}_{Q} - \frac{\sigma^2}{a}\left(1 - e^{-a(T-t)}\right)$.
|
||||
|
||||
Pending Topics:
|
||||
|
||||
Mathematical Proof: The derivation of the "Drift Shift" via Girsanov’s Theorem.
|
||||
|
||||
Exercise Boundary Impact: How the choice of measure (and the resulting drift) visually shifts the optimal exercise boundary in simulation.
|
||||
|
||||
Beyond One-Factor: Potential move toward Two-Factor models or non-flat initial term structures.
|
||||
145
python/american-mc/main.py
Normal file
145
python/american-mc/main.py
Normal file
@@ -0,0 +1,145 @@
|
||||
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from scipy.stats import norm
|
||||
|
||||
# -----------------------------
|
||||
# Black-Scholes European Put
|
||||
# -----------------------------
|
||||
def european_put_black_scholes(S0, K, r, sigma, T):
    """Closed-form Black-Scholes price of a European put.

    Parameters: S0 spot, K strike, r continuously-compounded risk-free
    rate, sigma lognormal volatility, T time to maturity in years.
    Returns the put value K*e^{-rT}*N(-d2) - S0*N(-d1).
    """
    sqrt_T = np.sqrt(T)
    d1 = (np.log(S0 / K) + (r + 0.5 * sigma ** 2) * T) / (sigma * sqrt_T)
    d2 = d1 - sigma * sqrt_T
    discounted_strike = K * np.exp(-r * T)
    return discounted_strike * norm.cdf(-d2) - S0 * norm.cdf(-d1)
|
||||
|
||||
# -----------------------------
|
||||
# GBM Simulation
|
||||
# -----------------------------
|
||||
def simulate_gbm(S0, r, sigma, T, M, N, seed=None):
    """Simulate M risk-neutral GBM paths over [0, T] in N steps.

    Returns an (M, N+1) array whose first column is the spot S0.
    One standard-normal vector is drawn per time step, so output is
    reproducible for a given seed.
    """
    if seed is not None:
        np.random.seed(seed)
    step = T / N
    # Per-step log-drift and log-vol are loop invariants; hoist them.
    log_drift = (r - 0.5 * sigma ** 2) * step
    log_vol = sigma * np.sqrt(step)
    paths = np.zeros((M, N + 1))
    paths[:, 0] = S0
    for k in range(1, N + 1):
        shocks = np.random.standard_normal(M)
        paths[:, k] = paths[:, k - 1] * np.exp(log_drift + log_vol * shocks)
    return paths
|
||||
|
||||
# -----------------------------
|
||||
# Longstaff-Schwartz Monte Carlo (Fixed)
|
||||
# -----------------------------
|
||||
def lsm_american_put(S0, K, r, sigma, T, M=100_000, N=50, basis_deg=2, seed=42, return_boundary=False):
    """Price an American put with Longstaff-Schwartz least-squares Monte Carlo.

    Parameters
    ----------
    S0, K, r, sigma, T : Black-Scholes market parameters.
    M, N : number of paths and time steps.
    basis_deg : degree of the polynomial regression basis.
    seed : RNG seed forwarded to the path generator.
    return_boundary : if True, also return the estimated exercise boundary
        as a list of (time, max exercised stock price) tuples.

    Returns
    -------
    price (float), and optionally the boundary list ordered forward in time.
    """
    S = simulate_gbm(S0, r, sigma, T, M, N, seed)
    dt = T / N
    # One-step discount factor reused throughout the backward induction.
    df = np.exp(-r * dt)

    # Immediate exercise value (payoff) at every step of every path.
    payoff = np.maximum(K - S, 0)

    # V holds each path's current option value during the backward sweep;
    # initialize with the terminal payoff at maturity.
    V = payoff[:, -1]

    boundary = []

    # Backward induction from t = N-1 down to t = 1 (t = 0 handled after the loop).
    for t in range(N - 1, 0, -1):
        # Identify in-the-money paths: only these are candidates for exercise,
        # and only these enter the regression (standard LSM restriction).
        itm = payoff[:, t] > 0

        # Default: the value at t is the discounted continuation value from t+1.
        V = V * df

        if np.any(itm):
            X = S[itm, t]
            Y = V[itm]  # already discounted back to time t

            # Regress discounted continuation values on the stock price to
            # estimate the conditional continuation value.
            coeffs = np.polyfit(X, Y, basis_deg)
            continuation_value = np.polyval(coeffs, X)

            # Exercise where immediate payoff beats estimated continuation.
            exercise = payoff[itm, t] > continuation_value

            # Map the ITM-relative exercise mask back to absolute path indices
            # so V can be overwritten in place with the exercise value.
            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]
            V[exercise_indices] = payoff[exercise_indices, t]

            # Boundary estimate: the highest stock price at which exercise
            # still occurs at this time step (NaN if no path exercises).
            if len(exercise_indices) > 0:
                boundary.append((t * dt, np.max(S[exercise_indices, t])))
            else:
                boundary.append((t * dt, np.nan))
        else:
            boundary.append((t * dt, np.nan))

    # Final discount from t = 1 back to t = 0, then average across paths.
    price = np.mean(V * df)

    if return_boundary:
        # Boundary was built backward in time; reverse to forward order.
        return price, boundary[::-1]
    return price
|
||||
|
||||
# -----------------------------
|
||||
# QuantLib American Put (v1.18)
|
||||
# -----------------------------
|
||||
|
||||
def quantlib_american_put(S0, K, r, sigma, T):
    """Reference American put price via QuantLib's finite-difference engine.

    Builds a Black-Scholes-Merton process with zero dividend yield and flat
    rate/vol curves, then prices with FdBlackScholesVanillaEngine.
    Imported locally so the rest of the module works without QuantLib.
    """
    import QuantLib as ql

    today = ql.Date().todaysDate()
    ql.Settings.instance().evaluationDate = today
    # Approximate year fraction with 365 calendar days.
    maturity = today + int(T*365)
    payoff = ql.PlainVanillaPayoff(ql.Option.Put, K)
    exercise = ql.AmericanExercise(today, maturity)

    spot = ql.QuoteHandle(ql.SimpleQuote(S0))
    # Zero dividend yield; flat risk-free and volatility term structures.
    dividend_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, 0.0, ql.Actual365Fixed()))
    riskfree_curve = ql.YieldTermStructureHandle(ql.FlatForward(today, r, ql.Actual365Fixed()))
    vol_curve = ql.BlackVolTermStructureHandle(ql.BlackConstantVol(today, ql.NullCalendar(), sigma, ql.Actual365Fixed()))

    process = ql.BlackScholesMertonProcess(spot, dividend_curve, riskfree_curve, vol_curve)
    # Finite-difference grid: 200 time steps, 400 asset-price nodes.
    engine = ql.FdBlackScholesVanillaEngine(process, 200, 400)
    option = ql.VanillaOption(payoff, exercise)
    option.setPricingEngine(engine)
    return option.NPV()
|
||||
|
||||
# -----------------------------
|
||||
# Main script
|
||||
# -----------------------------
|
||||
if __name__ == "__main__":
    # Market / product parameters shared by all three pricers.
    S0, K, r, sigma, T = 100, 100, 0.05, 0.2, 1.0
    # Monte Carlo settings: number of paths and exercise dates.
    M, N = 200_000, 50

    # European put (closed form) -- lower bound for the American price.
    eur_bs = european_put_black_scholes(S0, K, r, sigma, T)
    print(f"European Put (BS): {eur_bs:.4f}")

    # LSM American put plus the estimated exercise boundary.
    lsm_price, boundary = lsm_american_put(S0, K, r, sigma, T, M, N, return_boundary=True)
    print(f"LSM American Put (M={M}): {lsm_price:.4f}")

    # Finite-difference benchmark from QuantLib for comparison.
    ql_price = quantlib_american_put(S0, K, r, sigma, T)
    print(f"QuantLib American Put: {ql_price:.4f}")

    # Trivial lower bound: value of exercising immediately at t=0.
    print(f"Lower bound (immediate exercise): {max(K-S0,0):.4f}")

    # Plot the exercise boundary (time, critical stock price) pairs.
    times = [b[0] for b in boundary]
    boundaries = [b[1] for b in boundary]

    plt.figure(figsize=(8,5))
    plt.plot(times, boundaries, color='orange', label="LSM Exercise Boundary")
    plt.axhline(K, color='red', linestyle='--', label="Strike Price")
    plt.xlabel("Time to Maturity")
    plt.ylabel("Stock Price for Exercise")
    plt.title("American Put LSM Exercise Boundary")
    # Reverse the x-axis so maturity is on the left (time-to-maturity view).
    plt.gca().invert_xaxis()
    plt.legend()
    plt.grid(True)
    plt.show()
|
||||
133
python/american-mc/main2.py
Normal file
133
python/american-mc/main2.py
Normal file
@@ -0,0 +1,133 @@
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 1. Analytical Hull-White Bond Pricing
|
||||
# ---------------------------------------------------------
|
||||
def bond_price(r_t, t, T, a, sigma, r0):
    """Zero-coupon bond price P(t, T) for the one-factor short-rate model.

    Affine form P = A(t,T) * exp(-B(t,T) * r_t) with mean-reversion speed a,
    volatility sigma, and long-run level r0 (flat initial curve).

    NOTE(review): this is the constant-theta (Vasicek-style) A/B formula
    with long-run mean r0; at t=0 it does not exactly reproduce the flat
    discount curve exp(-r0*T) -- confirm this is the intended calibration.
    """
    tau = T - t
    B = (1 - np.exp(-a * tau)) / a
    A = np.exp((B - tau) * (a**2 * r0 - sigma**2/2) / a**2 - (sigma**2 * B**2 / (4 * a)))
    return A * np.exp(-B * r_t)
|
||||
|
||||
def get_swap_npv(r_t, t, T_end, strike, a, sigma, r0):
    """Positive part of a payer-swap NPV, i.e. the swaption exercise value.

    Unit-notional swap receiving floating / paying the fixed rate `strike`,
    with annual fixed payments at t+1, ..., T_end.  The floating leg is
    worth 1 - P(t, T_end); the fixed leg is strike times the annuity.
    Returns max(NPV, 0) elementwise (despite the name, this is the payoff,
    not a signed NPV).
    """
    coupon_dates = np.arange(t + 1, T_end + 1, 1.0)
    if len(coupon_dates) == 0:
        return np.zeros_like(r_t)

    # Annuity: sum of discount bonds at the annual fixed payment dates.
    annuity = sum(bond_price(r_t, t, d, a, sigma, r0) for d in coupon_dates)
    terminal_bond = bond_price(r_t, t, T_end, a, sigma, r0)

    return np.maximum(1 - terminal_bond - strike * annuity, 0)
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 2. Joint Exact Simulator (Short Rate & Stochastic Integral)
|
||||
# ---------------------------------------------------------
|
||||
def simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M, seed=None):
    """Exactly simulate the short rate and its time integral between dates.

    On each interval [t_i, t_{i+1}] the pair (r_{t_{i+1}}, I_i), with
    I_i = integral of r_s ds over the interval, is bivariate normal given
    r_{t_i}, so it is sampled jointly -- no Euler discretisation error.

    Parameters
    ----------
    r0, a, sigma : flat initial rate, mean-reversion speed, volatility.
    exercise_dates : increasing array of simulation dates (years).
    M : number of paths.
    seed : optional RNG seed for reproducibility (new, default None keeps
        the previous unseeded behaviour).

    Returns
    -------
    r_matrix : (M, len(exercise_dates)) short rates at the dates.
    d_matrix : (M, len(exercise_dates)) pathwise discount factors exp(-I_i),
        where column i covers the interval ending at exercise_dates[i].
    """
    if seed is not None:
        np.random.seed(seed)
    M = int(M)
    num_dates = len(exercise_dates)
    r_matrix = np.zeros((M, num_dates + 1))
    d_matrix = np.zeros((M, num_dates))  # d[:, i] = exp(-int r ds over interval i)

    r_matrix[:, 0] = r0
    t_steps = np.insert(exercise_dates, 0, 0.0)

    for i in range(len(t_steps) - 1):
        t, T = t_steps[i], t_steps[i+1]
        dt = T - t

        # alpha(s) = E[r_s] under Q for a flat initial curve r0.
        alpha_t = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * t))**2
        alpha_T = r0 + (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * T))**2

        # Conditional mean of r_T given r_t.
        mean_r = r_matrix[:, i] * np.exp(-a * dt) + alpha_T - alpha_t * np.exp(-a * dt)

        # Exact conditional covariance of (r_T, I) for the OU short rate.
        var_r = (sigma**2 / (2 * a)) * (1 - np.exp(-2 * a * dt))
        B = (1 - np.exp(-a * dt)) / a
        var_I = (sigma**2 / a**2) * (dt - 2*B + (1 - np.exp(-2*a*dt))/(2*a))
        cov_rI = (sigma**2 / (2 * a**2)) * (1 - np.exp(-a * dt))**2
        cov_matrix = [[var_r, cov_rI], [cov_rI, var_I]]

        # Bond-price identity: P(t,T) = E[exp(-I)] = exp(-mean_I + var_I/2)
        # since I is Gaussian, hence mean_I = -log P(t,T) + var_I/2.
        # (Bug fix: the convexity term previously entered with the wrong
        # sign -- exp(-(mean_I + Z - 0.5*var_I)) with mean_I = -log P --
        # which biased every pathwise discount factor up by exp(var_I).)
        mean_I = -np.log(bond_price(r_matrix[:, i], t, T, a, sigma, r0)) + 0.5 * var_I

        # Sample zero-mean joint innovations for (r, I).
        Z = np.random.multivariate_normal([0, 0], cov_matrix, M)

        r_matrix[:, i+1] = mean_r + Z[:, 0]
        d_matrix[:, i] = np.exp(-(mean_I + Z[:, 1]))

    return r_matrix[:, 1:], d_matrix
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# 3. LSM Pricing Logic
|
||||
# ---------------------------------------------------------
|
||||
def price_bermudan_swaption(r0, a, sigma, strike, exercise_dates, T_end, M):
    """Price a Bermudan payer swaption by LSM on exactly simulated HW paths.

    Parameters
    ----------
    r0, a, sigma : Hull-White parameters (flat initial curve r0).
    strike : fixed rate of the underlying unit-notional payer swap.
    exercise_dates : increasing array of Bermudan exercise dates (years).
    T_end : final maturity of the underlying swap.
    M : number of Monte Carlo paths.
    """
    # 1. Exact joint paths of the short rate and pathwise discount factors.
    #    discounts[:, i] covers the interval ENDING at exercise_dates[i]
    #    (discounts[:, 0] covers [0, exercise_dates[0]]).
    r_at_ex, discounts = simulate_hw_exact_joint(r0, a, sigma, exercise_dates, M)

    # 2. Payoff at the final exercise date.
    T_last = exercise_dates[-1]
    cash_flows = get_swap_npv(r_at_ex[:, -1], T_last, T_end, strike, a, sigma, r0)

    # 3. Backward induction over the earlier exercise dates.
    for i in reversed(range(len(exercise_dates) - 1)):
        t_current = exercise_dates[i]

        # Pull cash flows from exercise_dates[i+1] back to exercise_dates[i]
        # with the stochastic discount over that interval.
        # (Bug fix: this was discounts[:, i], which covers the interval
        # ending at exercise_dates[i] -- one interval too early, so paths
        # were discounted over [0, t1] twice and never over the last leg.)
        cash_flows = cash_flows * discounts[:, i + 1]

        # Immediate exercise value at the current date.
        X = r_at_ex[:, i]
        immediate_payoff = get_swap_npv(X, t_current, T_end, strike, a, sigma, r0)

        # Standard LSM: regress only in-the-money paths.
        itm = immediate_payoff > 0
        if np.any(itm):
            # Basis functions [1, r, r^2].
            A = np.column_stack([np.ones_like(X[itm]), X[itm], X[itm]**2])
            coeffs = np.linalg.lstsq(A, cash_flows[itm], rcond=None)[0]
            continuation_value = A @ coeffs

            # Exercise where intrinsic value beats estimated continuation.
            exercise = immediate_payoff[itm] > continuation_value

            itm_indices = np.where(itm)[0]
            exercise_indices = itm_indices[exercise]

            # Overwrite with the exercise value on early-exercised paths.
            cash_flows[exercise_indices] = immediate_payoff[exercise_indices]

    # 4. Discount from the first exercise date to t=0 with the pathwise
    # stochastic discount factor for [0, t1].
    # (Bug fix: previously multiplied by the analytical bond price P(0, t1),
    # which both double-counted [0, t1] -- see above -- and ignored the
    # correlation between the cash flow and the discount factor.)
    final_price = np.mean(cash_flows * discounts[:, 0])
    return final_price
|
||||
|
||||
# ---------------------------------------------------------
|
||||
# Execution
|
||||
# ---------------------------------------------------------
|
||||
if __name__ == "__main__":
    # Model and product parameters.
    r0_val, a_val, sigma_val = 0.05, 0.1, 0.01   # flat rate, mean reversion, vol
    strike_val = 0.05                            # fixed rate of the underlying swap
    ex_dates = np.array([1.0, 2.0, 3.0, 4.0])    # Bermudan exercise dates (years)
    maturity = 5.0                               # final maturity of the swap
    num_paths = 100_000

    price = price_bermudan_swaption(r0_val, a_val, sigma_val, strike_val, ex_dates, maturity, num_paths)

    # Plain string here: the old line used an f-string with no placeholders.
    print("--- 100% Exact LSM Bermudan Swaption ---")
    print(f"Parameters: a={a_val}, sigma={sigma_val}, strike={strike_val}")
    print(f"Exercise Dates: {ex_dates}")
    print(f"Calculated Price: {price:.6f}")
|
||||
5
python/cvatesting/.vscode/settings.json
vendored
Normal file
5
python/cvatesting/.vscode/settings.json
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
{
|
||||
"python-envs.defaultEnvManager": "ms-python.python:conda",
|
||||
"python-envs.defaultPackageManager": "ms-python.python:conda",
|
||||
"python-envs.pythonProjects": []
|
||||
}
|
||||
909
python/cvatesting/CVA_calculation_I.ipynb
Normal file
909
python/cvatesting/CVA_calculation_I.ipynb
Normal file
File diff suppressed because one or more lines are too long
83
python/cvatesting/driver.py
Normal file
83
python/cvatesting/driver.py
Normal file
@@ -0,0 +1,83 @@
|
||||
|
||||
import numpy as np
import QuantLib as ql  # Bug fix: the module name is case-sensitive; "Quantlib" raises ImportError.

from hullwhite import (build_calibrated_hw, simulate_hw_state, swap_npv_from_state)
from instruments import (make_vanilla_swap, make_swaption_helpers)

# Global evaluation setup shared by the whole script.
today = ql.Date(15, 1, 2025)
ql.Settings.instance().evaluationDate = today
calendar = ql.TARGET()
day_count = ql.Actual365Fixed()
|
||||
|
||||
|
||||
def expected_exposure(
    swap,
    hw,
    curve_handle,
    today,
    time_grid,
    x_paths
):
    """Expected positive exposure EE(t) of `swap` along simulated HW states.

    x_paths has shape (n_paths, len(time_grid)); entry (i, j) is the model
    state on path i at time_grid[j].  For each grid time the swap is
    revalued in every state, floored at zero, and averaged across paths.
    """
    exposure = np.zeros(len(time_grid))

    for j, t in enumerate(time_grid):
        positive_npvs = [
            max(swap_npv_from_state(swap, hw, curve_handle, t, x_paths[i, j], today), 0.0)
            for i in range(len(x_paths))
        ]
        exposure[j] = np.mean(positive_npvs)

    return exposure
|
||||
|
||||
def main():
    """Calibrate Hull-White to swaptions, simulate states, build an EE profile."""
    # Exposure time grid: 0..10y in monthly steps.
    time_grid = np.linspace(0, 10, 121)

    # Flat discount curve.
    flat_rate = 0.02
    curve = ql.FlatForward(today, flat_rate, day_count)
    curve_handle = ql.YieldTermStructureHandle(curve)

    # Create and calibrate HW.
    a_init = 0.05
    sigma_init = 0.01
    hw = ql.HullWhite(curve_handle, a_init, sigma_init)

    index = ql.Euribor6M(curve_handle)
    # ATM swaption quotes: (expiry, underlying tenor, Black vol).
    swaption_data = [
        (ql.Period(1, ql.Years), ql.Period(5, ql.Years), 0.20),
        (ql.Period(2, ql.Years), ql.Period(4, ql.Years), 0.20),
        (ql.Period(3, ql.Years), ql.Period(3, ql.Years), 0.20),
    ]
    # Bug fix: make_swaption_helpers takes the quote data as its first
    # argument; the previous call passed only (curve_handle, index, hw).
    helpers = make_swaption_helpers(swaption_data, curve_handle, index, hw)
    optimizer = ql.LevenbergMarquardt()
    end_criteria = ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8)
    hw.calibrate(helpers, optimizer, end_criteria)

    a, sigma = hw.params()
    print(f"Calibrated a = {a:.4f}")
    print(f"Calibrated sigma = {sigma:.4f}")

    # Simulate model states.
    x_paths = simulate_hw_state(
        hw, curve_handle, time_grid, n_paths=10000
    )

    # Swap.
    # Bug fix: there is no make_receiver_swap in instruments; use
    # make_vanilla_swap with pay_fixed=False (receiver) and an explicit
    # notional, which the factory requires.
    swap = make_vanilla_swap(
        today, curve_handle, notional=1e6, fixed_rate=0.025,
        maturity_years=10, pay_fixed=False
    )

    # Exposure profile.
    ee = expected_exposure(
        swap, hw, curve_handle, today, time_grid, x_paths
    )

    return 0


if __name__ == '__main__':
    main()
|
||||
66
python/cvatesting/hullwhite.py
Normal file
66
python/cvatesting/hullwhite.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import QuantLib as ql
|
||||
|
||||
|
||||
def build_calibrated_hw(
    curve_handle,
    swaption_helpers,
    a_init=0.05,
    sigma_init=0.01
):
    """Build a Hull-White model and calibrate (a, sigma) to swaption helpers.

    Calibrates with Levenberg-Marquardt under fixed end criteria and
    returns the calibrated model.
    """
    model = ql.HullWhite(curve_handle, a_init, sigma_init)

    model.calibrate(
        swaption_helpers,
        ql.LevenbergMarquardt(),
        ql.EndCriteria(1000, 500, 1e-8, 1e-8, 1e-8),
    )

    return model
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def simulate_hw_state(
    hw: ql.HullWhite,
    curve_handle,
    time_grid,
    n_paths,
    seed=42
):
    """Simulate the zero-mean Hull-White state x on `time_grid`.

    x follows dx = -a x dt + sigma dW and is stepped with the exact OU
    transition: x_{t+h} = x_t e^{-ah} + N(0, sigma^2 (1 - e^{-2ah}) / (2a)).
    Returns an (n_paths, len(time_grid)) array with x = 0 at the first time.
    One normal vector is drawn per step, so output is seed-reproducible.
    """
    a, sigma = hw.params()

    increments = np.diff(time_grid)
    n_steps = len(time_grid)

    np.random.seed(seed)
    states = np.zeros((n_paths, n_steps))

    for step in range(1, n_steps):
        h = increments[step - 1]
        decay = np.exp(-a * h)
        stdev = sigma * np.sqrt((1 - np.exp(-2 * a * h)) / (2 * a))
        draws = np.random.normal(size=n_paths)

        states[:, step] = states[:, step - 1] * decay + stdev * draws

    return states
|
||||
|
||||
|
||||
def swap_npv_from_state(
    swap,
    hw,
    curve_handle,
    t,
    x_t,
    today
):
    """Revalue `swap` at horizon t (years) conditional on HW state x_t.

    Advances the global evaluation date by ~t*365 calendar days, pushes the
    simulated state into the model, and reprices with a discounting engine.

    NOTE(review): stock QuantLib-Python's HullWhite exposes no `setState`
    method -- confirm this relies on a patched/extended binding, otherwise
    this call raises AttributeError.
    NOTE(review): DiscountingSwapEngine's usual constructor is
    (discountCurve, includeSettlementDateFlows); the third `hw` argument
    looks suspect -- verify against the installed QuantLib version.
    NOTE(review): the global evaluation date is mutated and never restored,
    leaking state into subsequent valuations.
    """
    ql.Settings.instance().evaluationDate = today + int(t * 365)

    hw.setState(x_t)

    engine = ql.DiscountingSwapEngine(
        curve_handle,
        False,
        hw
    )

    swap.setPricingEngine(engine)
    return swap.NPV()
|
||||
165
python/cvatesting/instruments.py
Normal file
165
python/cvatesting/instruments.py
Normal file
@@ -0,0 +1,165 @@
|
||||
import QuantLib as ql
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Global defaults (override via parameters if needed)
|
||||
# ============================================================
|
||||
|
||||
CALENDAR = ql.TARGET()
|
||||
BUSINESS_CONVENTION = ql.ModifiedFollowing
|
||||
DATE_GEN = ql.DateGeneration.Forward
|
||||
|
||||
FIXED_DAYCOUNT = ql.Thirty360(ql.Thirty360.European)
|
||||
FLOAT_DAYCOUNT = ql.Actual360()
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Yield curve factory
|
||||
# ============================================================
|
||||
|
||||
def make_flat_curve(
    evaluation_date: ql.Date,
    rate: float,
    day_count: ql.DayCounter = ql.Actual365Fixed()
) -> ql.YieldTermStructureHandle:
    """Build a flat-forward curve handle anchored at `evaluation_date`.

    Side effect: also sets the global QuantLib evaluation date.
    """
    ql.Settings.instance().evaluationDate = evaluation_date
    return ql.YieldTermStructureHandle(
        ql.FlatForward(evaluation_date, rate, day_count)
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Index factory
|
||||
# ============================================================
|
||||
|
||||
def make_euribor_6m(
    curve_handle: ql.YieldTermStructureHandle
) -> ql.IborIndex:
    """
    Euribor 6M index factory, projecting fixings off `curve_handle`.
    """
    return ql.Euribor6M(curve_handle)
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Schedule factory
|
||||
# ============================================================
|
||||
|
||||
def make_schedule(
    start: ql.Date,
    maturity: ql.Date,
    tenor: ql.Period,
    calendar: ql.Calendar = CALENDAR
) -> ql.Schedule:
    """Forward-generated schedule from `start` to `maturity` in `tenor` steps.

    Applies the module-level business-day convention to both accrual and
    termination dates; no end-of-month adjustment.
    """
    return ql.Schedule(
        start, maturity, tenor, calendar,
        BUSINESS_CONVENTION, BUSINESS_CONVENTION,
        DATE_GEN, False,
    )
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Swap factory
|
||||
# ============================================================
|
||||
|
||||
def make_vanilla_swap(
    evaluation_date: ql.Date,
    curve_handle: ql.YieldTermStructureHandle,
    notional: float,
    fixed_rate: float,
    maturity_years: int,
    pay_fixed: bool = False,
    # NOTE(review): ql.Annual / ql.Semiannual are Frequency constants
    # (ints), not ql.Period -- annotations corrected accordingly.
    fixed_leg_freq: int = ql.Annual,
    float_leg_freq: int = ql.Semiannual
) -> ql.VanillaSwap:
    """
    Vanilla fixed-for-float IRS factory.

    Spot-starting (T+2) swap against Euribor 6M with zero float spread,
    30/360 fixed leg and Act/360 float leg, priced with a discounting
    engine on `curve_handle`.  Side effect: sets the global evaluation date.

    pay_fixed = True  -> payer swap
    pay_fixed = False -> receiver swap
    """

    ql.Settings.instance().evaluationDate = evaluation_date

    index = make_euribor_6m(curve_handle)

    # Spot start (T+2 business days) and calendar-advanced maturity.
    start = CALENDAR.advance(evaluation_date, 2, ql.Days)
    maturity = CALENDAR.advance(start, maturity_years, ql.Years)

    fixed_schedule = make_schedule(
        start, maturity, ql.Period(fixed_leg_freq)
    )
    float_schedule = make_schedule(
        start, maturity, ql.Period(float_leg_freq)
    )

    swap_type = (
        ql.VanillaSwap.Payer if pay_fixed else ql.VanillaSwap.Receiver
    )

    swap = ql.VanillaSwap(
        swap_type,
        notional,
        fixed_schedule,
        fixed_rate,
        FIXED_DAYCOUNT,
        float_schedule,
        index,
        0.0,  # float-leg spread
        FLOAT_DAYCOUNT
    )

    swap.setPricingEngine(
        ql.DiscountingSwapEngine(curve_handle)
    )

    return swap
|
||||
|
||||
|
||||
# ============================================================
|
||||
# Swaption helper factory (for HW calibration)
|
||||
# ============================================================
|
||||
|
||||
def make_swaption_helpers(
    swaption_data: list,
    curve_handle: ql.YieldTermStructureHandle,
    index: ql.IborIndex,
    model: ql.HullWhite
) -> list:
    """
    Create ATM swaption calibration helpers priced with a Jamshidian engine.

    swaption_data = [(expiry, tenor, vol), ...] where expiry/tenor are
    ql.Period and vol is the quoted volatility.

    NOTE(review): the fixed-leg tenor is set to index.tenor() (6M), which is
    unusual for EUR swaptions (typically annual) -- confirm intended.
    """

    helpers = []

    for expiry, tenor, vol in swaption_data:
        helper = ql.SwaptionHelper(
            expiry,
            tenor,
            ql.QuoteHandle(ql.SimpleQuote(vol)),
            index,
            index.tenor(),       # fixed-leg tenor
            index.dayCounter(),  # fixed-leg day counter
            index.dayCounter(),  # float-leg day counter
            curve_handle
        )

        # Jamshidian decomposition: exact swaption pricing under Hull-White.
        helper.setPricingEngine(
            ql.JamshidianSwaptionEngine(model)
        )

        helpers.append(helper)

    return helpers
|
||||
109
python/cvatesting/main.py
Normal file
109
python/cvatesting/main.py
Normal file
@@ -0,0 +1,109 @@
|
||||
import QuantLib as ql
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
# --- UNIT 1: Instrument Setup ---
|
||||
def create_30y_swap(today, rate_quote):
    """Build a 30Y payer swap on a relinkable flat curve for exposure runs.

    Returns (swap, relinkable curve handle, Euribor6M index, fixing dates).
    The relinkable handle lets the caller swap in a new simulated curve per
    path/time-step without rebuilding the instrument.
    Side effect: sets the global evaluation date.
    """
    ql.Settings.instance().evaluationDate = today

    # Relinkable handle so curves can be swapped per path/time-step.
    yield_handle = ql.RelinkableYieldTermStructureHandle()
    yield_handle.linkTo(ql.FlatForward(today, rate_quote, ql.Actual365Fixed()))

    calendar = ql.TARGET()
    settle_date = calendar.advance(today, 2, ql.Days)
    maturity_date = calendar.advance(settle_date, 30, ql.Years)

    index = ql.Euribor6M(yield_handle)

    # Annual fixed leg, semiannual float leg, both modified-following.
    fixed_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Annual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)
    float_schedule = ql.Schedule(settle_date, maturity_date, ql.Period(ql.Semiannual),
                                 calendar, ql.ModifiedFollowing, ql.ModifiedFollowing,
                                 ql.DateGeneration.Forward, False)

    # Payer swap: 1mm notional, fixed rate = the flat quote (roughly ATM).
    swap = ql.VanillaSwap(ql.VanillaSwap.Payer, 1e6, fixed_schedule, rate_quote,
                          ql.Thirty360(ql.Thirty360.BondBasis), float_schedule,
                          index, 0.0, ql.Actual360())

    # Pre-calculate every fixing date needed over the life of the swap.
    fixing_dates = [index.fixingDate(d) for d in float_schedule]

    return swap, yield_handle, index, fixing_dates
|
||||
|
||||
# --- UNIT 2: The Simulation Loop ---
|
||||
def run_simulation(n_paths=50):
    """Simulate Hull-White short-rate paths and revalue the 30Y swap yearly.

    Returns (times, npv_matrix) where npv_matrix[i, j] is the POSITIVE part
    of the swap NPV on path i at year times[j] (exposure, floored at 0).

    NOTE(review): column j for t >= 30 is skipped and stays 0.
    """
    today = ql.Date(27, 1, 2026)
    swap, yield_handle, index, fixing_dates = create_30y_swap(today, 0.03)

    # HW Parameters: a = mean reversion, sigma = vol (model for bond
    # reconstitution, process for path generation -- same parameters).
    model = ql.HullWhite(yield_handle, 0.01, 0.01)
    process = ql.HullWhiteProcess(yield_handle, 0.01, 0.01)

    times = np.arange(0, 31, 1.0)  # annual exposure buckets
    # NOTE(review): ql.TimeGrid is given a numpy array here -- confirm the
    # installed binding accepts it (a plain list is the documented input).
    grid = ql.TimeGrid(times)

    rng = ql.GaussianRandomSequenceGenerator(ql.UniformRandomSequenceGenerator(
        len(grid)-1, ql.UniformRandomGenerator()))
    seq = ql.GaussianMultiPathGenerator(process, grid, rng, False)

    npv_matrix = np.zeros((n_paths, len(times)))

    for i in range(n_paths):
        path = seq.next().value()[0]
        # 1. Clear the previous path's injected fixings to avoid pollution.
        ql.IndexManager.instance().clearHistories()

        for j, t in enumerate(times):
            if t >= 30: continue

            eval_date = ql.TARGET().advance(today, int(t), ql.Years)
            ql.Settings.instance().evaluationDate = eval_date

            # 2. Rebuild the discount curve from the simulated short rate rt
            # via the model's analytical bond formula.
            rt = path[j]
            # Pillars out to 35Y so 30Y cash flows never extrapolate.
            tenors = [0, 1, 2, 5, 10, 20, 30, 35]
            dates = [ql.TARGET().advance(eval_date, y, ql.Years) for y in tenors]
            discounts = [model.discountBond(t, t + y, rt) for y in tenors]

            sim_curve = ql.DiscountCurve(dates, discounts, ql.Actual365Fixed())
            sim_curve.enableExtrapolation()
            yield_handle.linkTo(sim_curve)

            # 3. MANUAL FIXING INJECTION for every fixing date at or before
            # the evaluation date (the float leg needs a fixed coupon rate).
            for fd in fixing_dates:
                if fd <= eval_date:
                    # Forward rate for the 6M period starting at fd,
                    # computed directly from the simulated curve to avoid
                    # index.fixing() errors on past dates.
                    start_date = fd
                    end_date = ql.TARGET().advance(start_date, 6, ql.Months)

                    # Simple rate: (P1/P2 - 1) / accrual (Act/360).
                    p1 = sim_curve.discount(start_date)
                    p2 = sim_curve.discount(end_date)
                    dt = ql.Actual360().yearFraction(start_date, end_date)
                    fwd_rate = (p1 / p2 - 1.0) / dt

                    index.addFixing(fd, fwd_rate, True)  # force overwrite

            # 4. Valuation; exposure is the positive part of the NPV.
            swap.setPricingEngine(ql.DiscountingSwapEngine(yield_handle))
            npv_matrix[i, j] = max(swap.NPV(), 0)

    # Restore the global evaluation date mutated inside the loop.
    ql.Settings.instance().evaluationDate = today

    return times, npv_matrix
|
||||
|
||||
# --- UNIT 3: Visualization ---
|
||||
# Run the simulation and average exposures across paths per time bucket.
times, npv_matrix = run_simulation(n_paths=100)
ee = np.mean(npv_matrix, axis=0)  # expected exposure profile EE(t)

plt.figure(figsize=(10, 5))
plt.plot(times, ee, lw=2, label="Expected Exposure (EE)")
plt.fill_between(times, ee, alpha=0.3)
plt.title("30Y Swap EE Profile - Hull White")
plt.xlabel("Years"); plt.ylabel("Exposure"); plt.grid(True); plt.legend()
plt.show()
|
||||
42
python/expression-evaluator/CLAUDE.md
Normal file
42
python/expression-evaluator/CLAUDE.md
Normal file
@@ -0,0 +1,42 @@
|
||||
# Expression Evaluator
|
||||
|
||||
## Overview
|
||||
Educational project teaching DAGs and state machines through a calculator.
|
||||
Pure Python, no external dependencies.
|
||||
|
||||
## Running
|
||||
```bash
|
||||
python main.py "3 + 4 * 2" # single expression
|
||||
python main.py # REPL mode
|
||||
python main.py --show-tokens --show-ast --trace "expr" # show internals
|
||||
python main.py --dot "3+4*2" | dot -Tpng -o ast.png # AST diagram
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png # FSM diagram
|
||||
```
|
||||
|
||||
## Testing
|
||||
```bash
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
## Architecture
|
||||
- `tokenizer.py` -- Explicit finite state machine (Mealy machine) tokenizer
|
||||
- `parser.py` -- Recursive descent parser building an AST (DAG)
|
||||
- `evaluator.py` -- Post-order tree walker (topological sort evaluation)
|
||||
- `visualize.py` -- Graphviz dot generation for AST and FSM diagrams
|
||||
- `main.py` -- CLI entry point with argparse, REPL mode
|
||||
|
||||
## Key Design Decisions
|
||||
- State machine uses an explicit transition table (dict), not implicit if/else
|
||||
- Unary minus resolved by examining previous token context
|
||||
- Power operator (`^`) is right-associative (grammar uses right-recursion)
|
||||
- AST nodes are dataclasses; evaluation uses structural pattern matching
|
||||
- Graphviz output is raw dot strings (no graphviz Python package needed)
|
||||
|
||||
## Grammar
|
||||
```
|
||||
expression ::= term ((PLUS | MINUS) term)*
|
||||
term ::= unary ((MULTIPLY | DIVIDE) unary)*
|
||||
unary ::= UNARY_MINUS unary | power
|
||||
power ::= atom (POWER power)?
|
||||
atom ::= NUMBER | LPAREN expression RPAREN
|
||||
```
|
||||
87
python/expression-evaluator/README.md
Normal file
87
python/expression-evaluator/README.md
Normal file
@@ -0,0 +1,87 @@
|
||||
# Expression Evaluator -- DAGs & State Machines Tutorial
|
||||
|
||||
A calculator that teaches two fundamental CS patterns by building them from scratch:
|
||||
|
||||
1. **Finite State Machine** -- the tokenizer processes input character-by-character using an explicit transition table
|
||||
2. **Directed Acyclic Graph (DAG)** -- the parser builds an expression tree, evaluated bottom-up in topological order
|
||||
|
||||
## What You'll Learn
|
||||
|
||||
| File | CS Concept | What it does |
|
||||
|------|-----------|-------------|
|
||||
| `tokenizer.py` | **State Machine** (Mealy machine) | Converts `"3 + 4 * 2"` into tokens using a transition table |
|
||||
| `parser.py` | **DAG construction** | Builds an expression tree with operator precedence |
|
||||
| `evaluator.py` | **Topological evaluation** | Walks the tree bottom-up (leaves before parents) |
|
||||
| `visualize.py` | **Visualization** | Generates graphviz diagrams of both the FSM and AST |
|
||||
|
||||
## Quick Start
|
||||
|
||||
```bash
|
||||
# Evaluate an expression
|
||||
python main.py "3 + 4 * 2"
|
||||
# => 11
|
||||
|
||||
# Interactive REPL
|
||||
python main.py
|
||||
|
||||
# See how the state machine tokenizes
|
||||
python main.py --show-tokens "(2 + 3) * -4"
|
||||
|
||||
# See the expression tree (DAG)
|
||||
python main.py --show-ast "(2 + 3) * 4"
|
||||
# *
|
||||
# +-- +
|
||||
# | +-- 2
|
||||
# | `-- 3
|
||||
# `-- 4
|
||||
|
||||
# Watch evaluation in topological order
|
||||
python main.py --trace "(2 + 3) * 4"
|
||||
# Step 1: 2 => 2
|
||||
# Step 2: 3 => 3
|
||||
# Step 3: 2 + 3 => 5
|
||||
# Step 4: 4 => 4
|
||||
# Step 5: 5 * 4 => 20
|
||||
|
||||
# Generate graphviz diagrams
|
||||
python main.py --dot "(2 + 3) * 4" | dot -Tpng -o ast.png
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png
|
||||
```
|
||||
|
||||
## Features
|
||||
|
||||
- Arithmetic: `+`, `-`, `*`, `/`, `^` (power)
|
||||
- Parentheses: `(2 + 3) * 4`
|
||||
- Unary minus: `-3`, `-(2 + 1)`, `2 * -3`
|
||||
- Decimals: `3.14`, `.5`
|
||||
- Standard precedence: parens > `^` > `*`/`/` > `+`/`-`
|
||||
- Right-associative power: `2^3^4` = `2^(3^4)`
|
||||
- Correct unary minus: `-3^2` = `-(3^2)` = `-9`
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
python -m pytest tests/ -v
|
||||
```
|
||||
|
||||
## How the State Machine Works
|
||||
|
||||
The tokenizer in `tokenizer.py` uses an **explicit transition table** -- a dictionary mapping `(current_state, character_class)` to `(next_state, action)`. This is the same pattern used in network protocol parsers, regex engines, and compiler lexers.
|
||||
|
||||
The three states are:
|
||||
- `START` -- between tokens, dispatching based on the next character
|
||||
- `INTEGER` -- accumulating digits (e.g., `"12"` so far)
|
||||
- `DECIMAL` -- accumulating digits after a decimal point (e.g., `"12.3"`)
|
||||
|
||||
Use `--dot-fsm` to generate a visual diagram of the state machine.
|
||||
|
||||
## How the DAG Works
|
||||
|
||||
The parser in `parser.py` builds an **expression tree** (AST) where:
|
||||
- **Leaf nodes** are numbers (no dependencies)
|
||||
- **Interior nodes** are operators with edges to their operands
|
||||
- **Edges** represent "depends on" relationships
|
||||
|
||||
Evaluation in `evaluator.py` walks this tree **bottom-up** -- children before parents. This is exactly a **topological sort** of the DAG: you can only compute a node after all its dependencies are resolved.
|
||||
|
||||
Use `--show-ast` to see the tree structure, or `--dot` to generate a graphviz diagram.
|
||||
147
python/expression-evaluator/evaluator.py
Normal file
147
python/expression-evaluator/evaluator.py
Normal file
@@ -0,0 +1,147 @@
|
||||
"""
|
||||
Part 3: DAG Evaluation -- Tree Walker
|
||||
=======================================
|
||||
Evaluating the AST bottom-up is equivalent to topological-sort
|
||||
evaluation of a DAG. We must evaluate a node's children before
|
||||
the node itself -- just like in any dependency graph.
|
||||
|
||||
For a tree, post-order traversal gives a topological ordering.
|
||||
The recursive evaluate() function naturally does this:
|
||||
1. Recursively evaluate all children (dependencies)
|
||||
2. Combine the results (compute this node's value)
|
||||
3. Return the result (make it available to the parent)
|
||||
|
||||
This is the same pattern as:
|
||||
- make: build dependencies before the target
|
||||
- pip/npm install: install dependencies before the package
|
||||
- Spreadsheet recalculation: compute referenced cells first
|
||||
"""
|
||||
|
||||
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
|
||||
from tokenizer import TokenType
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class EvalError(Exception):
    """Raised when AST evaluation fails (e.g. division by zero)."""
    pass
|
||||
|
||||
|
||||
# ---------- Evaluator ----------
|
||||
|
||||
# Human-readable symbol for each operator token; used only for trace output.
OP_SYMBOLS = {
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
|
||||
|
||||
|
||||
def evaluate(node):
|
||||
"""
|
||||
Evaluate an AST by walking it bottom-up (post-order traversal).
|
||||
|
||||
This is a recursive function that mirrors the DAG structure:
|
||||
each recursive call follows a DAG edge to a child node.
|
||||
Children are evaluated before parents -- topological order.
|
||||
"""
|
||||
match node:
|
||||
case NumberNode(value=v):
|
||||
return v
|
||||
|
||||
case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
|
||||
return -evaluate(child)
|
||||
|
||||
case BinOpNode(op=op, left=left, right=right):
|
||||
left_val = evaluate(left)
|
||||
right_val = evaluate(right)
|
||||
match op:
|
||||
case TokenType.PLUS:
|
||||
return left_val + right_val
|
||||
case TokenType.MINUS:
|
||||
return left_val - right_val
|
||||
case TokenType.MULTIPLY:
|
||||
return left_val * right_val
|
||||
case TokenType.DIVIDE:
|
||||
if right_val == 0:
|
||||
raise EvalError("division by zero")
|
||||
return left_val / right_val
|
||||
case TokenType.POWER:
|
||||
return left_val ** right_val
|
||||
|
||||
raise EvalError(f"unknown node type: {type(node)}")
|
||||
|
||||
|
||||
def evaluate_traced(node):
|
||||
"""
|
||||
Like evaluate(), but records each step for educational display.
|
||||
Returns (result, list_of_trace_lines).
|
||||
|
||||
The trace shows the topological evaluation order -- how the DAG
|
||||
is evaluated from leaves to root. Each step shows a node being
|
||||
evaluated after all its dependencies are resolved.
|
||||
"""
|
||||
steps = []
|
||||
counter = [0] # mutable counter for step numbering
|
||||
|
||||
def _walk(node, depth):
|
||||
indent = " " * depth
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
|
||||
match node:
|
||||
case NumberNode(value=v):
|
||||
result = v
|
||||
display = _format_number(v)
|
||||
steps.append(f"{indent}Step {step}: {display} => {_format_number(result)}")
|
||||
return result
|
||||
|
||||
case UnaryOpNode(op=TokenType.UNARY_MINUS, operand=child):
|
||||
child_val = _walk(child, depth + 1)
|
||||
result = -child_val
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
steps.append(
|
||||
f"{indent}Step {step}: neg({_format_number(child_val)}) "
|
||||
f"=> {_format_number(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
case BinOpNode(op=op, left=left, right=right):
|
||||
left_val = _walk(left, depth + 1)
|
||||
right_val = _walk(right, depth + 1)
|
||||
sym = OP_SYMBOLS[op]
|
||||
match op:
|
||||
case TokenType.PLUS:
|
||||
result = left_val + right_val
|
||||
case TokenType.MINUS:
|
||||
result = left_val - right_val
|
||||
case TokenType.MULTIPLY:
|
||||
result = left_val * right_val
|
||||
case TokenType.DIVIDE:
|
||||
if right_val == 0:
|
||||
raise EvalError("division by zero")
|
||||
result = left_val / right_val
|
||||
case TokenType.POWER:
|
||||
result = left_val ** right_val
|
||||
counter[0] += 1
|
||||
step = counter[0]
|
||||
steps.append(
|
||||
f"{indent}Step {step}: {_format_number(left_val)} {sym} "
|
||||
f"{_format_number(right_val)} => {_format_number(result)}"
|
||||
)
|
||||
return result
|
||||
|
||||
raise EvalError(f"unknown node type: {type(node)}")
|
||||
|
||||
result = _walk(node, 0)
|
||||
return result, steps
|
||||
|
||||
|
||||
def _format_number(v):
|
||||
"""Display a number as integer when possible."""
|
||||
if isinstance(v, float) and v == int(v):
|
||||
return str(int(v))
|
||||
return str(v)
|
||||
163
python/expression-evaluator/main.py
Normal file
163
python/expression-evaluator/main.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Expression Evaluator -- Learn DAGs & State Machines
|
||||
====================================================
|
||||
CLI entry point and interactive REPL.
|
||||
|
||||
Usage:
|
||||
python main.py "3 + 4 * 2" # evaluate
|
||||
python main.py # REPL mode
|
||||
python main.py --show-tokens --show-ast --trace "expr" # show internals
|
||||
python main.py --dot "3 + 4 * 2" | dot -Tpng -o ast.png
|
||||
python main.py --dot-fsm | dot -Tpng -o fsm.png
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
|
||||
from tokenizer import tokenize, TokenError
|
||||
from parser import Parser, ParseError
|
||||
from evaluator import evaluate, evaluate_traced, EvalError
|
||||
from visualize import ast_to_dot, fsm_to_dot, ast_to_text
|
||||
|
||||
|
||||
def process_expression(expr, args):
    """
    Tokenize, parse, and evaluate a single expression.

    All output goes to stdout. Errors are printed (with a caret marker
    when a position is available) rather than raised, so the REPL keeps
    running after bad input. With --dot, the graphviz source is printed
    and the numeric result is skipped so the output can be piped to dot.
    """
    # Stage 1: lexing. TokenError carries a character position for the caret.
    try:
        tokens = tokenize(expr)
    except TokenError as e:
        _print_error(expr, e)
        return

    if args.show_tokens:
        print("\nTokens:")
        for tok in tokens:
            print(f" {tok}")

    # Stage 2: parsing the token stream into an AST.
    try:
        ast = Parser(tokens).parse()
    except ParseError as e:
        _print_error(expr, e)
        return

    if args.show_ast:
        print("\nAST (text tree):")
        print(ast_to_text(ast))

    if args.dot:
        print(ast_to_dot(ast))
        return  # dot output goes to stdout, skip numeric result

    # Stage 3: evaluation -- traced (step-by-step) or plain.
    if args.trace:
        try:
            result, steps = evaluate_traced(ast)
        except EvalError as e:
            print(f"Eval error: {e}")
            return
        print("\nEvaluation trace (topological order):")
        for step in steps:
            print(step)
        print(f"\nResult: {_format_result(result)}")
    else:
        try:
            result = evaluate(ast)
        except EvalError as e:
            print(f"Eval error: {e}")
            return
        print(_format_result(result))
|
||||
|
||||
|
||||
def repl(args):
    """Interactive read-eval-print loop; exits on quit/exit/q or EOF."""
    print("Expression Evaluator REPL")
    print("Type an expression, or 'quit' to exit.")
    # Report which display flags were passed on the command line.
    active = [
        flag
        for enabled, flag in (
            (args.show_tokens, "--show-tokens"),
            (args.show_ast, "--show-ast"),
            (args.trace, "--trace"),
        )
        if enabled
    ]
    if active:
        print(f"Active flags: {' '.join(active)}")
    print()

    while True:
        try:
            line = input(">>> ").strip()
        except (EOFError, KeyboardInterrupt):
            # Ctrl-D / Ctrl-C: newline for a clean prompt, then leave.
            print()
            break
        if line.lower() in ("quit", "exit", "q"):
            break
        if not line:
            continue
        process_expression(line, args)
        print()
|
||||
|
||||
|
||||
def _print_error(expr, error):
    """Print an error with a caret pointing to the position.

    The caret line is only emitted when the error object carries a
    non-None ``position`` attribute (TokenError and ParseError do);
    the leading spaces keep the caret aligned under the expression.
    """
    print(f"Error: {error}")
    if hasattr(error, 'position') and error.position is not None:
        print(f" {expr}")
        print(f" {' ' * error.position}^")
|
||||
|
||||
|
||||
def _format_result(v):
|
||||
"""Format a numeric result: show as int when possible."""
|
||||
if isinstance(v, float) and v == int(v) and abs(v) < 1e15:
|
||||
return str(int(v))
|
||||
return str(v)
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, then dispatch to the right mode.

    Modes, in priority order: --dot-fsm (print FSM diagram and exit),
    REPL (no expression given), or one-shot evaluation of the argument.
    """
    arg_parser = argparse.ArgumentParser(
        description="Expression Evaluator -- learn DAGs and state machines",
        epilog="Examples:\n"
        " python main.py '3 + 4 * 2'\n"
        " python main.py --show-tokens --trace '-(3 + 4) ^ 2'\n"
        " python main.py --dot '(2+3)*4' | dot -Tpng -o ast.png\n"
        " python main.py --dot-fsm | dot -Tpng -o fsm.png",
        # Raw formatter preserves the hand-written newlines in the epilog.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    arg_parser.add_argument(
        "expression", nargs="?",
        help="Expression to evaluate (omit for REPL mode)",
    )
    arg_parser.add_argument(
        "--show-tokens", action="store_true",
        help="Display tokenizer output",
    )
    arg_parser.add_argument(
        "--show-ast", action="store_true",
        help="Display AST as indented text tree",
    )
    arg_parser.add_argument(
        "--trace", action="store_true",
        help="Show step-by-step evaluation trace",
    )
    arg_parser.add_argument(
        "--dot", action="store_true",
        help="Output AST as graphviz dot (pipe to: dot -Tpng -o ast.png)",
    )
    arg_parser.add_argument(
        "--dot-fsm", action="store_true",
        help="Output tokenizer FSM as graphviz dot",
    )

    args = arg_parser.parse_args()

    # Special mode: just print the FSM diagram and exit
    if args.dot_fsm:
        print(fsm_to_dot())
        return

    # REPL mode if no expression given
    if args.expression is None:
        repl(args)
    else:
        process_expression(args.expression, args)
|
||||
|
||||
|
||||
# Script entry point: run the CLI only when executed directly,
# not when imported as a module (e.g. by the tests).
if __name__ == "__main__":
    main()
|
||||
217
python/expression-evaluator/parser.py
Normal file
217
python/expression-evaluator/parser.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Part 2: DAG Construction -- Recursive Descent Parser
|
||||
=====================================================
|
||||
A parser converts a flat list of tokens into a tree structure (AST).
|
||||
The AST is a DAG (Directed Acyclic Graph) where:
|
||||
|
||||
- Nodes are operations (BinOpNode) or values (NumberNode)
|
||||
- Edges point from parent operations to their operands
|
||||
- The graph is acyclic because an operation's inputs are always
|
||||
"simpler" sub-expressions (no circular dependencies)
|
||||
- It is a tree (a special case of DAG) because no node is shared
|
||||
|
||||
This is the same structure as:
|
||||
- Spreadsheet dependency graphs (cell A1 depends on B1, B2...)
|
||||
- Build systems (Makefile targets depend on other targets)
|
||||
- Task scheduling (some tasks must finish before others start)
|
||||
- Neural network computation graphs (forward pass is a DAG)
|
||||
|
||||
Key DAG concepts demonstrated:
|
||||
- Nodes: operations and values
|
||||
- Directed edges: from operation to its inputs (dependencies)
|
||||
- Acyclic: no circular dependencies
|
||||
- Topological ordering: natural evaluation order (leaves first)
|
||||
|
||||
Grammar (BNF) -- precedence is encoded by nesting depth:
|
||||
expression ::= term ((PLUS | MINUS) term)* # lowest precedence
|
||||
term ::= unary ((MULTIPLY | DIVIDE) unary)*
|
||||
unary ::= UNARY_MINUS unary | power
|
||||
power ::= atom (POWER power)? # right-associative
|
||||
atom ::= NUMBER | LPAREN expression RPAREN # highest precedence
|
||||
|
||||
Call chain: expression -> term -> unary -> power -> atom
|
||||
This means: +/- binds loosest, then *//, then unary -, then ^, then parens.
|
||||
So -3^2 = -(3^2) = -9, matching standard math convention.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from tokenizer import Token, TokenType
|
||||
|
||||
|
||||
# ---------- AST node types ----------
|
||||
# These are the nodes of our DAG. Each node is either a leaf (NumberNode)
|
||||
# or an interior node with edges pointing to its children (operands).
|
||||
|
||||
@dataclass
class NumberNode:
    """Leaf node: a numeric literal. In DAG terms, a node with no outgoing edges."""
    value: float

    def __repr__(self):
        # Show integral floats without the trailing ".0" for readability.
        # float.is_integer() is used instead of comparing with int(self.value):
        # int() raises OverflowError/ValueError for inf/NaN, while
        # is_integer() just returns False, so those values print as-is.
        if isinstance(self.value, float) and self.value.is_integer():
            return f"NumberNode({int(self.value)})"
        return f"NumberNode({self.value})"
|
||||
|
||||
|
||||
@dataclass
class BinOpNode:
    """
    Interior node: a binary operation with two children.

    DAG edges: this node -> left, this node -> right
    The edges represent "depends on": to compute this node's value,
    we must first compute left and right.
    """
    op: TokenType  # which operator (PLUS, MINUS, MULTIPLY, DIVIDE, POWER)
    left: 'NumberNode | BinOpNode | UnaryOpNode'
    right: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        # Recursive repr: children print themselves, giving a nested view.
        return f"BinOpNode({self.op.name}, {self.left}, {self.right})"
|
||||
|
||||
|
||||
@dataclass
class UnaryOpNode:
    """Interior node: a unary operation (negation) with one child."""
    op: TokenType  # in practice always UNARY_MINUS
    operand: 'NumberNode | BinOpNode | UnaryOpNode'

    def __repr__(self):
        return f"UnaryOpNode({self.op.name}, {self.operand})"
|
||||
|
||||
|
||||
# Type alias: any AST node. Useful for annotating functions that
# accept or return arbitrary subtrees.
Node = NumberNode | BinOpNode | UnaryOpNode
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class ParseError(Exception):
    """Raised when the token stream does not match the grammar.

    Stores the offending character offset in ``position`` (or None)
    so callers can draw a caret under the input.
    """

    def __init__(self, message, position=None):
        self.position = position
        if position is None:
            suffix = ""
        else:
            suffix = f" at position {position}"
        super().__init__(f"Parse error{suffix}: {message}")
|
||||
|
||||
|
||||
# ---------- Recursive descent parser ----------
|
||||
|
||||
class Parser:
    """
    Recursive-descent parser: converts a token list into an AST.

    One method per grammar production; operator precedence is encoded
    purely by the call nesting (outer rules bind loosest, so they end
    up nearer the root and are evaluated last). Every BinOpNode or
    UnaryOpNode constructed here adds DAG edges from an operation to
    the sub-expressions it depends on.
    """

    def __init__(self, tokens):
        self.tokens = tokens
        self.pos = 0

    def peek(self):
        """Return the current token without advancing."""
        return self.tokens[self.pos]

    def consume(self, expected=None):
        """Advance past the current token, optionally checking its type."""
        token = self.tokens[self.pos]
        if expected is not None and token.type != expected:
            raise ParseError(
                f"expected {expected.name}, got {token.type.name}",
                token.position,
            )
        self.pos += 1
        return token

    def parse(self):
        """Parse a complete expression; reject empty or trailing input."""
        if self.peek().type == TokenType.EOF:
            raise ParseError("empty expression")
        root = self.expression()
        # Anything left over (e.g. "3 4") is a syntax error.
        self.consume(TokenType.EOF)
        return root

    # --- Grammar productions (one method each; nesting = precedence) ---

    def expression(self):
        """expression ::= term ((PLUS | MINUS) term)*  -- left-associative."""
        result = self.term()
        while self.peek().type in (TokenType.PLUS, TokenType.MINUS):
            operator = self.consume()
            # Folding left-to-right makes +/- left-associative.
            result = BinOpNode(operator.type, result, self.term())
        return result

    def term(self):
        """term ::= unary ((MULTIPLY | DIVIDE) unary)*  -- left-associative."""
        result = self.unary()
        while self.peek().type in (TokenType.MULTIPLY, TokenType.DIVIDE):
            operator = self.consume()
            result = BinOpNode(operator.type, result, self.unary())
        return result

    def unary(self):
        """
        unary ::= UNARY_MINUS unary | power

        Sits between term and power, so negation binds looser than ^
        but tighter than * and /: -3^2 parses as -(3^2) = -9.
        The self-recursion handles stacked negation (--3 = 3).
        """
        if self.peek().type != TokenType.UNARY_MINUS:
            return self.power()
        operator = self.consume()
        return UnaryOpNode(operator.type, self.unary())

    def power(self):
        """
        power ::= atom (POWER power)?

        Recursing on the right-hand side (instead of looping) makes ^
        right-associative: 2^3^4 = 2^(3^4).
        """
        base = self.atom()
        if self.peek().type == TokenType.POWER:
            operator = self.consume()
            return BinOpNode(operator.type, base, self.power())
        return base

    def atom(self):
        """
        atom ::= NUMBER | LPAREN expression RPAREN

        Base case of the grammar. Parentheses recurse back up to
        expression(), restarting precedence parsing from the top.
        """
        token = self.peek()

        if token.type == TokenType.NUMBER:
            self.consume()
            return NumberNode(float(token.value))

        if token.type == TokenType.LPAREN:
            self.consume()
            inner = self.expression()
            self.consume(TokenType.RPAREN)
            return inner

        raise ParseError(
            f"expected number or '(', got {token.type.name}",
            token.position,
        )
|
||||
0
python/expression-evaluator/tests/__init__.py
Normal file
0
python/expression-evaluator/tests/__init__.py
Normal file
120
python/expression-evaluator/tests/test_evaluator.py
Normal file
120
python/expression-evaluator/tests/test_evaluator.py
Normal file
@@ -0,0 +1,120 @@
|
||||
"""End-to-end tests for evaluator.py: tokenize -> parse -> evaluate."""

import sys
from pathlib import Path

# Make the package directory importable when pytest runs from tests/.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize
from parser import Parser
from evaluator import evaluate, evaluate_traced, EvalError


def eval_expr(expr):
    """Helper: tokenize -> parse -> evaluate in one step."""
    tokens = tokenize(expr)
    ast = Parser(tokens).parse()
    return evaluate(ast)


# ---------- Basic arithmetic ----------

def test_addition():
    assert eval_expr("3 + 4") == 7.0


def test_subtraction():
    assert eval_expr("10 - 3") == 7.0


def test_multiplication():
    assert eval_expr("3 * 4") == 12.0


def test_division():
    assert eval_expr("10 / 4") == 2.5


def test_power():
    assert eval_expr("2 ^ 10") == 1024.0


# ---------- Precedence ----------

def test_standard_precedence():
    assert eval_expr("3 + 4 * 2") == 11.0


def test_parentheses():
    assert eval_expr("(3 + 4) * 2") == 14.0


def test_power_precedence():
    assert eval_expr("2 * 3 ^ 2") == 18.0


def test_right_associative_power():
    # 2^(2^3) = 2^8 = 256
    assert eval_expr("2 ^ 2 ^ 3") == 256.0


# ---------- Unary minus ----------

def test_negation():
    assert eval_expr("-5") == -5.0


def test_double_negation():
    assert eval_expr("--5") == 5.0


def test_negation_with_power():
    # -(3^2) = -9, not (-3)^2 = 9
    assert eval_expr("-3 ^ 2") == -9.0


def test_negation_in_parens():
    assert eval_expr("(-3) ^ 2") == 9.0


# ---------- Decimals ----------

def test_decimal_addition():
    # approx: 0.1 + 0.2 != 0.3 exactly in binary floating point.
    assert eval_expr("0.1 + 0.2") == pytest.approx(0.3)


def test_leading_dot():
    assert eval_expr(".5 + .5") == 1.0


# ---------- Edge cases ----------

def test_nested_parens():
    assert eval_expr("((((3))))") == 3.0


def test_complex_expression():
    assert eval_expr("(2 + 3) * (7 - 2) / 5 ^ 1") == 5.0


def test_long_chain():
    assert eval_expr("1 + 2 + 3 + 4 + 5") == 15.0


def test_mixed_operations():
    assert eval_expr("2 + 3 * 4 - 6 / 2") == 11.0


# ---------- Division by zero ----------

def test_division_by_zero():
    with pytest.raises(EvalError):
        eval_expr("1 / 0")


def test_division_by_zero_in_expression():
    with pytest.raises(EvalError):
        eval_expr("5 + 3 / (2 - 2)")


# ---------- Traced evaluation ----------

def test_traced_returns_correct_result():
    tokens = tokenize("3 + 4 * 2")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 11.0
    assert len(steps) > 0


def test_traced_step_count():
    """A simple binary op has 3 evaluation events: left, right, combine."""
    tokens = tokenize("3 + 4")
    ast = Parser(tokens).parse()
    result, steps = evaluate_traced(ast)
    assert result == 7.0
    # NumberNode(3), NumberNode(4), BinOp(+)
    assert len(steps) == 3
|
||||
136
python/expression-evaluator/tests/test_parser.py
Normal file
136
python/expression-evaluator/tests/test_parser.py
Normal file
@@ -0,0 +1,136 @@
|
||||
"""Structural tests for parser.py: AST shape, precedence, associativity."""

import sys
from pathlib import Path

# Make the package directory importable when pytest runs from tests/.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize, TokenType
from parser import Parser, ParseError, NumberNode, BinOpNode, UnaryOpNode


def parse(expr):
    """Helper: tokenize and parse in one step."""
    return Parser(tokenize(expr)).parse()


# ---------- Basic parsing ----------

def test_parse_number():
    ast = parse("42")
    assert isinstance(ast, NumberNode)
    assert ast.value == 42.0


def test_parse_decimal():
    ast = parse("3.14")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.14


def test_parse_addition():
    ast = parse("3 + 4")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, NumberNode)


# ---------- Precedence ----------

def test_multiply_before_add():
    """3 + 4 * 2 should parse as 3 + (4 * 2)."""
    ast = parse("3 + 4 * 2")
    assert ast.op == TokenType.PLUS
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.MULTIPLY


def test_power_before_multiply():
    """2 * 3 ^ 4 should parse as 2 * (3 ^ 4)."""
    ast = parse("2 * 3 ^ 4")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER


def test_parentheses_override_precedence():
    """(3 + 4) * 2 should parse as (3 + 4) * 2."""
    ast = parse("(3 + 4) * 2")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.PLUS


# ---------- Associativity ----------

def test_left_associative_subtraction():
    """10 - 3 - 2 should parse as (10 - 3) - 2."""
    ast = parse("10 - 3 - 2")
    assert ast.op == TokenType.MINUS
    assert isinstance(ast.left, BinOpNode)
    assert ast.left.op == TokenType.MINUS
    assert isinstance(ast.right, NumberNode)


def test_power_right_associative():
    """2 ^ 3 ^ 4 should parse as 2 ^ (3 ^ 4)."""
    ast = parse("2 ^ 3 ^ 4")
    assert ast.op == TokenType.POWER
    assert isinstance(ast.left, NumberNode)
    assert isinstance(ast.right, BinOpNode)
    assert ast.right.op == TokenType.POWER


# ---------- Unary minus ----------

def test_unary_minus():
    ast = parse("-3")
    assert isinstance(ast, UnaryOpNode)
    assert ast.operand.value == 3.0


def test_double_negation():
    ast = parse("--3")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, UnaryOpNode)
    assert ast.operand.operand.value == 3.0


def test_unary_minus_precedence():
    """-3^2 should parse as -(3^2), not (-3)^2."""
    ast = parse("-3 ^ 2")
    assert isinstance(ast, UnaryOpNode)
    assert isinstance(ast.operand, BinOpNode)
    assert ast.operand.op == TokenType.POWER


def test_unary_minus_in_expression():
    """2 * -3 should parse as 2 * (-(3))."""
    ast = parse("2 * -3")
    assert ast.op == TokenType.MULTIPLY
    assert isinstance(ast.right, UnaryOpNode)


# ---------- Nested parentheses ----------

def test_nested_parens():
    ast = parse("((3))")
    assert isinstance(ast, NumberNode)
    assert ast.value == 3.0


def test_complex_nesting():
    """((2 + 3) * (7 - 2))"""
    ast = parse("((2 + 3) * (7 - 2))")
    assert isinstance(ast, BinOpNode)
    assert ast.op == TokenType.MULTIPLY


# ---------- Errors ----------

def test_missing_rparen():
    with pytest.raises(ParseError):
        parse("(3 + 4")


def test_empty_expression():
    with pytest.raises(ParseError):
        parse("")


def test_trailing_operator():
    with pytest.raises(ParseError):
        parse("3 +")


def test_empty_parens():
    with pytest.raises(ParseError):
        parse("()")
|
||||
139
python/expression-evaluator/tests/test_tokenizer.py
Normal file
139
python/expression-evaluator/tests/test_tokenizer.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""Tests for tokenizer.py: token kinds, unary-minus context, positions, errors."""

import sys
from pathlib import Path

# Make the package directory importable when pytest runs from tests/.
sys.path.insert(0, str(Path(__file__).parent.parent))

import pytest
from tokenizer import tokenize, TokenType, Token, TokenError


# ---------- Basic tokens ----------

def test_single_integer():
    tokens = tokenize("42")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "42"


def test_decimal_number():
    tokens = tokenize("3.14")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == "3.14"


def test_leading_dot():
    tokens = tokenize(".5")
    assert tokens[0].type == TokenType.NUMBER
    assert tokens[0].value == ".5"


def test_all_operators():
    """Operators between numbers are all binary."""
    tokens = tokenize("1 + 1 - 1 * 1 / 1 ^ 1")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]


def test_operators_between_numbers():
    tokens = tokenize("1 + 2 - 3 * 4 / 5 ^ 6")
    ops = [t.type for t in tokens if t.type not in (TokenType.NUMBER, TokenType.EOF)]
    assert ops == [
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER,
    ]


def test_parentheses():
    tokens = tokenize("()")
    assert tokens[0].type == TokenType.LPAREN
    assert tokens[1].type == TokenType.RPAREN


# ---------- Unary minus ----------
# '-' is classified as unary or binary by what precedes it.

def test_unary_minus_at_start():
    tokens = tokenize("-3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.NUMBER


def test_unary_minus_after_lparen():
    tokens = tokenize("(-3)")
    assert tokens[1].type == TokenType.UNARY_MINUS


def test_unary_minus_after_operator():
    tokens = tokenize("2 * -3")
    assert tokens[2].type == TokenType.UNARY_MINUS


def test_binary_minus():
    tokens = tokenize("5 - 3")
    assert tokens[1].type == TokenType.MINUS


def test_double_unary_minus():
    tokens = tokenize("--3")
    assert tokens[0].type == TokenType.UNARY_MINUS
    assert tokens[1].type == TokenType.UNARY_MINUS
    assert tokens[2].type == TokenType.NUMBER


# ---------- Whitespace handling ----------

def test_no_spaces():
    tokens = tokenize("3+4")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3


def test_extra_spaces():
    tokens = tokenize(" 3 + 4 ")
    non_eof = [t for t in tokens if t.type != TokenType.EOF]
    assert len(non_eof) == 3


# ---------- Position tracking ----------

def test_positions():
    tokens = tokenize("3 + 4")
    assert tokens[0].position == 0  # '3'
    assert tokens[1].position == 2  # '+'
    assert tokens[2].position == 4  # '4'


# ---------- Errors ----------

def test_invalid_character():
    with pytest.raises(TokenError):
        tokenize("3 & 4")


def test_double_dot():
    with pytest.raises(TokenError):
        tokenize("3.14.15")


# ---------- EOF token ----------

def test_eof_always_present():
    tokens = tokenize("42")
    assert tokens[-1].type == TokenType.EOF


def test_empty_input():
    tokens = tokenize("")
    assert len(tokens) == 1
    assert tokens[0].type == TokenType.EOF


# ---------- Complex expressions ----------

def test_complex_expression():
    tokens = tokenize("(3 + 4.5) * -2 ^ 3")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.PLUS,
        TokenType.NUMBER, TokenType.RPAREN, TokenType.MULTIPLY,
        TokenType.UNARY_MINUS, TokenType.NUMBER, TokenType.POWER,
        TokenType.NUMBER,
    ]


def test_adjacent_parens():
    tokens = tokenize("(3)(4)")
    types = [t.type for t in tokens if t.type != TokenType.EOF]
    assert types == [
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
        TokenType.LPAREN, TokenType.NUMBER, TokenType.RPAREN,
    ]
|
||||
306
python/expression-evaluator/tokenizer.py
Normal file
306
python/expression-evaluator/tokenizer.py
Normal file
@@ -0,0 +1,306 @@
|
||||
"""
|
||||
Part 1: State Machine Tokenizer
|
||||
================================
|
||||
A tokenizer (lexer) converts raw text into a stream of tokens.
|
||||
This implementation uses an EXPLICIT finite state machine (FSM):
|
||||
|
||||
- States are named values (an enum), not implicit control flow
|
||||
- A transition table maps (current_state, input_class) -> (next_state, action)
|
||||
- The main loop reads one character at a time and consults the table
|
||||
|
||||
This is the same pattern used in:
|
||||
- Network protocol parsers (HTTP, TCP state machines)
|
||||
- Regular expression engines
|
||||
- Compiler front-ends (lexers for C, Python, etc.)
|
||||
- Game AI (enemy behavior states)
|
||||
|
||||
Key FSM concepts demonstrated:
|
||||
- States: the "memory" of what we're currently building
|
||||
- Transitions: rules for moving between states based on input
|
||||
- Actions: side effects (emit a token, accumulate a character)
|
||||
- Mealy machine: outputs depend on both state AND input
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
|
||||
|
||||
# ---------- Token types ----------
|
||||
|
||||
class TokenType(Enum):
    """Kinds of tokens the tokenizer emits."""
    NUMBER = "NUMBER"            # integer or decimal literal, e.g. "42", "4.5"
    PLUS = "PLUS"                # '+'
    MINUS = "MINUS"              # binary '-' (may be re-tagged, see UNARY_MINUS)
    MULTIPLY = "MULTIPLY"        # '*'
    DIVIDE = "DIVIDE"            # '/'
    POWER = "POWER"              # '^'
    LPAREN = "LPAREN"            # '('
    RPAREN = "RPAREN"            # ')'
    UNARY_MINUS = "UNARY_MINUS"  # assigned only by the _resolve_unary_minus post-pass
    EOF = "EOF"                  # sentinel appended by tokenize() at end of input
|
||||
|
||||
|
||||
@dataclass
class Token:
    """A single lexical token with its source location."""
    type: TokenType
    value: str       # raw text: "42", "+", "(", etc.
    position: int    # character offset in original expression

    def __repr__(self):
        """Compact debug form, e.g. Token(NUMBER, '42', pos=0)."""
        return f"Token({self.type.name}, {self.value!r}, pos={self.position})"
|
||||
|
||||
|
||||
OPERATOR_MAP = {
    # Single-character binary operators -> their token types.
    # Also consulted by classify() to detect CharClass.OPERATOR.
    '+': TokenType.PLUS,
    '-': TokenType.MINUS,
    '*': TokenType.MULTIPLY,
    '/': TokenType.DIVIDE,
    '^': TokenType.POWER,
}
|
||||
|
||||
|
||||
# ---------- FSM state definitions ----------
|
||||
|
||||
class State(Enum):
    """
    The tokenizer's finite set of states.

    Only three states are needed because numbers are the only
    multi-character tokens; everything else is emitted immediately.

    START -- idle / between tokens, deciding what comes next
    INTEGER -- accumulating digits of an integer (e.g. "12" so far)
    DECIMAL -- accumulating digits after a decimal point (e.g. "12.3" so far)
    """
    START = "START"
    INTEGER = "INTEGER"
    DECIMAL = "DECIMAL"
|
||||
|
||||
|
||||
class CharClass(Enum):
    """
    Character classification -- groups raw characters into categories
    so the transition table stays small and readable.

    Classification is performed by classify(); categories are disjoint.
    """
    DIGIT = "DIGIT"        # ch.isdigit()
    DOT = "DOT"            # '.'
    OPERATOR = "OPERATOR"  # any key of OPERATOR_MAP: + - * / ^
    LPAREN = "LPAREN"      # '('
    RPAREN = "RPAREN"      # ')'
    SPACE = "SPACE"        # ch.isspace()
    EOF = "EOF"            # synthesized by tokenize() past the last character
    UNKNOWN = "UNKNOWN"    # anything else; tokenize() raises TokenError on it
|
||||
|
||||
|
||||
class Action(Enum):
    """
    What the FSM does on a transition. In a Mealy machine, the output
    (action) depends on both the current state AND the input.

    The EMIT_NUMBER_THEN_* variants flush the accumulated number buffer
    first, then emit the token for the character just read.
    """
    ACCUMULATE = "ACCUMULATE"                          # append char to the number buffer
    EMIT_NUMBER = "EMIT_NUMBER"                        # flush buffer as a NUMBER token
    EMIT_OPERATOR = "EMIT_OPERATOR"                    # emit the operator for this char
    EMIT_LPAREN = "EMIT_LPAREN"                        # emit '('
    EMIT_RPAREN = "EMIT_RPAREN"                        # emit ')'
    EMIT_NUMBER_THEN_OP = "EMIT_NUMBER_THEN_OP"        # flush buffer, then emit operator
    EMIT_NUMBER_THEN_LPAREN = "EMIT_NUMBER_THEN_LPAREN"  # flush buffer, then emit '('
    EMIT_NUMBER_THEN_RPAREN = "EMIT_NUMBER_THEN_RPAREN"  # flush buffer, then emit ')'
    EMIT_NUMBER_THEN_DONE = "EMIT_NUMBER_THEN_DONE"    # flush buffer at end of input
    SKIP = "SKIP"                                      # ignore (whitespace)
    DONE = "DONE"                                      # end of input, nothing pending
    ERROR = "ERROR"                                    # raise TokenError
|
||||
|
||||
|
||||
@dataclass(frozen=True)
class Transition:
    """One FSM edge: the state to move to and the action to perform."""
    next_state: State
    action: Action
|
||||
|
||||
|
||||
# ---------- Transition table ----------
|
||||
# This is the heart of the state machine. Every (state, char_class) pair
|
||||
# maps to exactly one transition: a next state and an action to perform.
|
||||
# Making this a data structure (not nested if/else) means we can:
|
||||
# 1. Inspect it programmatically (e.g. to generate a diagram)
|
||||
# 2. Verify completeness (every combination is covered)
|
||||
# 3. Understand the FSM at a glance
|
||||
|
||||
TRANSITIONS = {
    # NOTE: CharClass.UNKNOWN never appears as a key: tokenize() raises
    # TokenError on unknown characters before consulting this table.

    # --- START: between tokens, dispatch based on character class ---
    (State.START, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.START, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.START, CharClass.OPERATOR): Transition(State.START, Action.EMIT_OPERATOR),
    (State.START, CharClass.LPAREN): Transition(State.START, Action.EMIT_LPAREN),
    (State.START, CharClass.RPAREN): Transition(State.START, Action.EMIT_RPAREN),
    (State.START, CharClass.SPACE): Transition(State.START, Action.SKIP),
    (State.START, CharClass.EOF): Transition(State.START, Action.DONE),

    # --- INTEGER: accumulating digits like "123" ---
    (State.INTEGER, CharClass.DIGIT): Transition(State.INTEGER, Action.ACCUMULATE),
    (State.INTEGER, CharClass.DOT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.INTEGER, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.INTEGER, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.INTEGER, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.INTEGER, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.INTEGER, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),

    # --- DECIMAL: accumulating digits after "." like "123.45" ---
    (State.DECIMAL, CharClass.DIGIT): Transition(State.DECIMAL, Action.ACCUMULATE),
    (State.DECIMAL, CharClass.DOT): Transition(State.START, Action.ERROR),  # second '.', e.g. "1.2.3"
    (State.DECIMAL, CharClass.OPERATOR): Transition(State.START, Action.EMIT_NUMBER_THEN_OP),
    (State.DECIMAL, CharClass.LPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_LPAREN),
    (State.DECIMAL, CharClass.RPAREN): Transition(State.START, Action.EMIT_NUMBER_THEN_RPAREN),
    (State.DECIMAL, CharClass.SPACE): Transition(State.START, Action.EMIT_NUMBER),
    (State.DECIMAL, CharClass.EOF): Transition(State.START, Action.EMIT_NUMBER_THEN_DONE),
}
|
||||
|
||||
|
||||
# ---------- Errors ----------
|
||||
|
||||
class TokenError(Exception):
    """Raised when tokenization fails.

    The character offset of the failure is kept in ``position``.
    """

    def __init__(self, message, position):
        super().__init__(f"Token error at position {position}: {message}")
        self.position = position
|
||||
|
||||
|
||||
# ---------- Character classification ----------
|
||||
|
||||
def classify(ch):
    """Map a single character to its CharClass category."""
    # Fixed punctuation characters; every category is disjoint, so the
    # check order does not affect the result.
    punctuation = {
        '.': CharClass.DOT,
        '(': CharClass.LPAREN,
        ')': CharClass.RPAREN,
    }
    if ch.isdigit():
        return CharClass.DIGIT
    if ch in punctuation:
        return punctuation[ch]
    if ch in OPERATOR_MAP:
        return CharClass.OPERATOR
    if ch.isspace():
        return CharClass.SPACE
    return CharClass.UNKNOWN
|
||||
|
||||
|
||||
# ---------- Main tokenize function ----------
|
||||
|
||||
def tokenize(expression):
    """
    Process an expression string through the state machine, producing tokens.

    The main loop:
    1. Classify the current character
    2. Look up (state, char_class) in the transition table
    3. Execute the action (accumulate, emit, skip, etc.)
    4. Move to the next state
    5. Advance to the next character

    After all tokens are emitted, a post-processing step resolves
    unary minus: if a MINUS token appears at the start, after an operator,
    or after LPAREN, it is re-classified as UNARY_MINUS.

    Returns a list of Token ending with an EOF token.
    Raises TokenError on unknown characters or invalid transitions
    (e.g. a second '.' inside one number).
    """
    state = State.START
    buffer = []  # characters accumulated for the current token
    buffer_start = 0  # position where the current buffer started
    tokens = []
    pos = 0

    # Append a sentinel so EOF is handled uniformly in the loop
    chars = expression + '\0'

    # <= rather than <: one extra iteration at pos == len(expression)
    # drives the EOF transition that flushes any pending number.
    while pos <= len(expression):
        ch = chars[pos]
        char_class = CharClass.EOF if pos == len(expression) else classify(ch)

        if char_class == CharClass.UNKNOWN:
            raise TokenError(f"unexpected character {ch!r}", pos)

        # Look up the transition
        key = (state, char_class)
        transition = TRANSITIONS.get(key)
        if transition is None:
            raise TokenError(f"no transition for state={state.name}, input={char_class.name}", pos)

        action = transition.action
        next_state = transition.next_state

        # --- Execute the action ---

        if action == Action.ACCUMULATE:
            # Record where the number started only on its first character.
            if not buffer:
                buffer_start = pos
            buffer.append(ch)

        elif action == Action.EMIT_NUMBER:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()

        elif action == Action.EMIT_OPERATOR:
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))

        elif action == Action.EMIT_LPAREN:
            tokens.append(Token(TokenType.LPAREN, ch, pos))

        elif action == Action.EMIT_RPAREN:
            tokens.append(Token(TokenType.RPAREN, ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_OP:
            # Number ended by an operator: flush it, then emit the operator.
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(OPERATOR_MAP[ch], ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_LPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.LPAREN, ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_RPAREN:
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()
            tokens.append(Token(TokenType.RPAREN, ch, pos))

        elif action == Action.EMIT_NUMBER_THEN_DONE:
            # NOTE(review): a lone "." reaches this path and is emitted as
            # NUMBER "." -- presumably rejected downstream by the parser's
            # number conversion; confirm against parser.py.
            tokens.append(Token(TokenType.NUMBER, ''.join(buffer), buffer_start))
            buffer.clear()

        elif action == Action.SKIP:
            pass

        elif action == Action.DONE:
            pass

        elif action == Action.ERROR:
            raise TokenError(f"unexpected {ch!r} in state {state.name}", pos)

        state = next_state
        pos += 1

    # --- Post-processing: resolve unary minus ---
    # A MINUS is unary if it appears:
    #   - at the very start of the token stream
    #   - immediately after an operator (+, -, *, /, ^) or LPAREN
    # This context-sensitivity cannot be captured by the FSM alone --
    # it requires looking at previously emitted tokens.
    _resolve_unary_minus(tokens)

    tokens.append(Token(TokenType.EOF, '', len(expression)))
    return tokens
|
||||
|
||||
|
||||
def _resolve_unary_minus(tokens):
    """
    Re-tag binary MINUS tokens as UNARY_MINUS where context demands it,
    mutating the token list in place.

    Why this isn't in the FSM: the FSM processes characters one at a time
    and only knows what kind of token it is currently building. Whether
    '-' is unary or binary depends on the PREVIOUS TOKEN -- information
    the FSM doesn't track. A common real-world pattern: the lexer does
    most of the work, then a lightweight post-pass adds context.
    """
    unary_context = {
        TokenType.PLUS, TokenType.MINUS, TokenType.MULTIPLY,
        TokenType.DIVIDE, TokenType.POWER, TokenType.LPAREN,
        TokenType.UNARY_MINUS,
    }
    for idx in range(len(tokens)):
        current = tokens[idx]
        if current.type != TokenType.MINUS:
            continue
        # Reading tokens[idx - 1] after possible mutation lets a chain
        # like "--3" resolve both minuses to unary.
        if idx == 0 or tokens[idx - 1].type in unary_context:
            tokens[idx] = Token(TokenType.UNARY_MINUS, current.value, current.position)
|
||||
200
python/expression-evaluator/visualize.py
Normal file
200
python/expression-evaluator/visualize.py
Normal file
@@ -0,0 +1,200 @@
|
||||
"""
|
||||
Part 4: Visualization -- Graphviz Dot Output
|
||||
==============================================
|
||||
Generate graphviz dot-format strings for:
|
||||
1. The tokenizer's finite state machine (FSM)
|
||||
2. Any expression's AST (DAG)
|
||||
3. Text-based tree rendering for the terminal
|
||||
|
||||
No external dependencies -- outputs raw dot strings that can be piped
|
||||
to the 'dot' command: python main.py --dot "3+4*2" | dot -Tpng -o ast.png
|
||||
"""
|
||||
|
||||
from parser import NumberNode, BinOpNode, UnaryOpNode, Node
|
||||
from tokenizer import TRANSITIONS, State, CharClass, Action, TokenType
|
||||
|
||||
|
||||
# ---------- FSM diagram ----------
|
||||
|
||||
# Human-readable labels for character classes
|
||||
_CHAR_LABELS = {
    # UNKNOWN is absent: tokenize() rejects unknown characters before the
    # table lookup, so no FSM edge ever carries that class.
    CharClass.DIGIT: "digit",
    CharClass.DOT: "'.'",
    CharClass.OPERATOR: "op",
    CharClass.LPAREN: "'('",
    CharClass.RPAREN: "')'",
    CharClass.SPACE: "space",
    CharClass.EOF: "EOF",
}
|
||||
|
||||
# Short labels for actions
|
||||
_ACTION_LABELS = {
    # Compact edge-label text for each FSM action.
    Action.ACCUMULATE: "accum",
    Action.EMIT_NUMBER: "emit num",
    Action.EMIT_OPERATOR: "emit op",
    Action.EMIT_LPAREN: "emit '('",
    Action.EMIT_RPAREN: "emit ')'",
    Action.EMIT_NUMBER_THEN_OP: "emit num+op",
    Action.EMIT_NUMBER_THEN_LPAREN: "emit num+'('",
    Action.EMIT_NUMBER_THEN_RPAREN: "emit num+')'",
    Action.EMIT_NUMBER_THEN_DONE: "emit num, done",
    Action.SKIP: "skip",
    Action.DONE: "done",
    Action.ERROR: "ERROR",
}
|
||||
|
||||
|
||||
def fsm_to_dot():
    """
    Generate a graphviz dot diagram of the tokenizer's state machine.

    Because the FSM lives in the TRANSITIONS dict, it can be inspected
    directly as data to build the diagram -- a key advantage of explicit
    state machines over implicit if/else control flow.
    """
    out = [
        'digraph FSM {',
        ' rankdir=LR;',
        ' node [shape=circle, fontname="Helvetica"];',
        ' edge [fontname="Helvetica", fontsize=10];',
        '',
        ' // Start indicator',
        ' __start__ [shape=point, width=0.2];',
        ' __start__ -> START;',
        '',
    ]

    # Group labels by (source, destination) so parallel transitions
    # collapse into one arrow carrying stacked labels.
    grouped = {}
    for (src_state, cls), trans in TRANSITIONS.items():
        pair = (src_state.name, trans.next_state.name)
        char_part = _CHAR_LABELS.get(cls, cls.name)
        action_part = _ACTION_LABELS.get(trans.action, trans.action.name)
        if pair not in grouped:
            grouped[pair] = []
        grouped[pair].append(f"{char_part} / {action_part}")

    # Emit edges in deterministic order.
    for src, dst in sorted(grouped):
        merged = "\\n".join(grouped[(src, dst)])
        out.append(f' {src} -> {dst} [label="{merged}"];')

    out.append('}')
    return '\n'.join(out)
|
||||
|
||||
|
||||
# ---------- AST diagram ----------
|
||||
|
||||
_OP_LABELS = {
    # Display glyph for each operator token type; unary minus renders
    # as "neg" to distinguish it from binary '-'.
    TokenType.PLUS: '+',
    TokenType.MINUS: '-',
    TokenType.MULTIPLY: '*',
    TokenType.DIVIDE: '/',
    TokenType.POWER: '^',
    TokenType.UNARY_MINUS: 'neg',
}
|
||||
|
||||
|
||||
def ast_to_dot(node):
    """
    Generate a graphviz dot diagram of an AST (expression tree / DAG).

    Each node receives a unique sequential ID; edges run parent -> child,
    with binary children labeled "L" and "R". Number leaves are rendered
    as rounded boxes, operators as ellipses.
    """
    out = [
        'digraph AST {',
        ' node [fontname="Helvetica"];',
        ' edge [fontname="Helvetica"];',
        '',
    ]
    next_id = 0

    def _emit(n):
        # Assign this node's ID before descending so IDs follow pre-order.
        nonlocal next_id
        nid = f"n{next_id}"
        next_id += 1

        if isinstance(n, NumberNode):
            label = _format_number(n.value)
            out.append(f' {nid} [label="{label}", shape=box, style=rounded];')
            return nid

        if isinstance(n, UnaryOpNode):
            label = _OP_LABELS.get(n.op, n.op.name)
            out.append(f' {nid} [label="{label}", shape=ellipse];')
            operand_id = _emit(n.operand)
            out.append(f' {nid} -> {operand_id};')
            return nid

        if isinstance(n, BinOpNode):
            label = _OP_LABELS.get(n.op, n.op.name)
            out.append(f' {nid} [label="{label}", shape=ellipse];')
            left_id = _emit(n.left)
            right_id = _emit(n.right)
            out.append(f' {nid} -> {left_id} [label="L"];')
            out.append(f' {nid} -> {right_id} [label="R"];')
            return nid

        # Unknown node kinds emit nothing (mirrors the match fallthrough).
        return None

    _emit(node)
    out.append('}')
    return '\n'.join(out)
|
||||
|
||||
|
||||
# ---------- Text-based tree ----------
|
||||
|
||||
def ast_to_text(node, prefix="", connector=""):
|
||||
"""
|
||||
Render the AST as an indented text tree for terminal display.
|
||||
|
||||
Example output for (2 + 3) * 4:
|
||||
*
|
||||
+-- +
|
||||
| +-- 2
|
||||
| +-- 3
|
||||
+-- 4
|
||||
"""
|
||||
match node:
|
||||
case NumberNode(value=v):
|
||||
label = _format_number(v)
|
||||
case UnaryOpNode(op=op):
|
||||
label = _OP_LABELS.get(op, op.name)
|
||||
case BinOpNode(op=op):
|
||||
label = _OP_LABELS.get(op, op.name)
|
||||
|
||||
lines = [f"{prefix}{connector}{label}"]
|
||||
|
||||
children = _get_children(node)
|
||||
for i, child in enumerate(children):
|
||||
is_last_child = (i == len(children) - 1)
|
||||
if connector:
|
||||
# Extend the prefix: if we used "+-- " then next children
|
||||
# see "| " (continuing) or " " (last child)
|
||||
child_prefix = prefix + ("| " if connector == "+-- " else " ")
|
||||
else:
|
||||
child_prefix = prefix
|
||||
child_connector = "+-- " if is_last_child else "+-- "
|
||||
# Use a different lead for non-last: the vertical bar continues
|
||||
child_connector = "`-- " if is_last_child else "+-- "
|
||||
child_lines = ast_to_text(child, child_prefix, child_connector)
|
||||
lines.append(child_lines)
|
||||
|
||||
return '\n'.join(lines)
|
||||
|
||||
|
||||
def _get_children(node):
    """Return the ordered child nodes of *node*; leaves yield an empty list."""
    if isinstance(node, UnaryOpNode):
        return [node.operand]
    if isinstance(node, BinOpNode):
        return [node.left, node.right]
    # NumberNode and any unknown kind have no children.
    return []
|
||||
|
||||
|
||||
def _format_number(v):
|
||||
if isinstance(v, float) and v == int(v):
|
||||
return str(int(v))
|
||||
return str(v)
|
||||
@@ -6,9 +6,18 @@ import ollama
|
||||
|
||||
DEFAULT_OLLAMA_MODEL = "qwen2.5:7b"
|
||||
|
||||
_ollama_model = DEFAULT_OLLAMA_MODEL
|
||||
|
||||
def ask_ollama(prompt, system=None, model=DEFAULT_OLLAMA_MODEL):
|
||||
|
||||
def set_ollama_model(model):
|
||||
"""Change the Ollama model used for fast queries."""
|
||||
global _ollama_model
|
||||
_ollama_model = model
|
||||
|
||||
|
||||
def ask_ollama(prompt, system=None, model=None):
|
||||
"""Query Ollama with an optional system prompt."""
|
||||
model = model or _ollama_model
|
||||
messages = []
|
||||
if system:
|
||||
messages.append({"role": "system", "content": system})
|
||||
@@ -24,6 +33,8 @@ def ask_claude(prompt):
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(f"Claude CLI failed (exit {result.returncode}): {result.stderr.strip()}")
|
||||
return result.stdout.strip()
|
||||
|
||||
|
||||
@@ -34,8 +45,9 @@ def ask(prompt, system=None, quality="fast"):
|
||||
return ask_ollama(prompt, system=system)
|
||||
|
||||
|
||||
def chat_ollama(messages, system=None, model=DEFAULT_OLLAMA_MODEL):
|
||||
def chat_ollama(messages, system=None, model=None):
|
||||
"""Multi-turn conversation with Ollama."""
|
||||
model = model or _ollama_model
|
||||
all_messages = []
|
||||
if system:
|
||||
all_messages.append({"role": "system", "content": system})
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"""Generate Anki .apkg decks from vocabulary data."""
|
||||
|
||||
import genanki
|
||||
import random
|
||||
|
||||
# Stable model/deck IDs (generated once, kept constant)
|
||||
_MODEL_ID = 1607392319
|
||||
|
||||
@@ -7,6 +7,7 @@ import time
|
||||
|
||||
import gradio as gr
|
||||
|
||||
import ai
|
||||
import db
|
||||
from modules import vocab, dashboard, essay, tutor, idioms
|
||||
from modules.essay import GCSE_THEMES
|
||||
@@ -214,6 +215,15 @@ def do_anki_export(cats_selected):
|
||||
return path
|
||||
|
||||
|
||||
def update_ollama_model(model):
|
||||
ai.set_ollama_model(model)
|
||||
|
||||
|
||||
def update_whisper_size(size):
|
||||
from stt import set_whisper_size
|
||||
set_whisper_size(size)
|
||||
|
||||
|
||||
def reset_progress():
|
||||
conn = db.get_connection()
|
||||
conn.execute("DELETE FROM word_progress")
|
||||
@@ -491,6 +501,10 @@ with gr.Blocks(title="Persian Language Tutor") as app:
|
||||
|
||||
export_btn.click(fn=do_anki_export, inputs=[export_cats], outputs=[export_file])
|
||||
|
||||
# Wire model settings
|
||||
ollama_model.change(fn=update_ollama_model, inputs=[ollama_model])
|
||||
whisper_size.change(fn=update_whisper_size, inputs=[whisper_size])
|
||||
|
||||
gr.Markdown("### Reset")
|
||||
reset_btn = gr.Button("Reset All Progress", variant="stop")
|
||||
reset_status = gr.Markdown("")
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import json
|
||||
import sqlite3
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from pathlib import Path
|
||||
|
||||
import fsrs
|
||||
@@ -148,6 +148,13 @@ def get_word_counts(total_vocab_size=0):
|
||||
}
|
||||
|
||||
|
||||
def get_all_word_progress():
|
||||
"""Return all word progress as a dict of word_id -> progress dict."""
|
||||
conn = get_connection()
|
||||
rows = conn.execute("SELECT * FROM word_progress").fetchall()
|
||||
return {row["word_id"]: dict(row) for row in rows}
|
||||
|
||||
|
||||
def record_quiz_session(category, total_questions, correct, duration_seconds):
|
||||
"""Log a completed flashcard session."""
|
||||
conn = get_connection()
|
||||
@@ -203,7 +210,7 @@ def get_stats():
|
||||
today = datetime.now(timezone.utc).date()
|
||||
for i, row in enumerate(days):
|
||||
day = datetime.fromisoformat(row["d"]).date() if isinstance(row["d"], str) else row["d"]
|
||||
expected = today - __import__("datetime").timedelta(days=i)
|
||||
expected = today - timedelta(days=i)
|
||||
if day == expected:
|
||||
streak += 1
|
||||
else:
|
||||
|
||||
@@ -19,17 +19,17 @@ def get_category_breakdown():
|
||||
"""Return progress per category as list of dicts."""
|
||||
vocab = load_vocab()
|
||||
categories = get_categories()
|
||||
all_progress = db.get_all_word_progress()
|
||||
|
||||
breakdown = []
|
||||
for cat in categories:
|
||||
cat_words = [e for e in vocab if e["category"] == cat]
|
||||
cat_ids = {e["id"] for e in cat_words}
|
||||
total = len(cat_words)
|
||||
|
||||
seen = 0
|
||||
mastered = 0
|
||||
for wid in cat_ids:
|
||||
progress = db.get_word_progress(wid)
|
||||
for e in cat_words:
|
||||
progress = all_progress.get(e["id"])
|
||||
if progress:
|
||||
seen += 1
|
||||
if progress["stability"] and progress["stability"] > 10:
|
||||
|
||||
@@ -84,8 +84,9 @@ def get_flashcard_batch(count=10, category=None):
|
||||
remaining = count - len(due_entries)
|
||||
if remaining > 0:
|
||||
seen_ids = {e["id"] for e in due_entries}
|
||||
all_progress = db.get_all_word_progress()
|
||||
# Prefer unseen words
|
||||
unseen = [e for e in pool if e["id"] not in seen_ids and not db.get_word_progress(e["id"])]
|
||||
unseen = [e for e in pool if e["id"] not in seen_ids and e["id"] not in all_progress]
|
||||
if len(unseen) >= remaining:
|
||||
fill = random.sample(unseen, remaining)
|
||||
else:
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
"""Persian speech-to-text wrapper using sttlib."""
|
||||
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
|
||||
sys.path.insert(0, "/home/ys/family-repo/Code/python/tool-speechtotext")
|
||||
# sttlib lives in sibling project tool-speechtotext
|
||||
_sttlib_path = str(Path(__file__).resolve().parent.parent / "tool-speechtotext")
|
||||
sys.path.insert(0, _sttlib_path)
|
||||
from sttlib import load_whisper_model, transcribe, is_hallucination
|
||||
|
||||
_model = None
|
||||
_whisper_size = "medium"
|
||||
|
||||
# Common Whisper hallucinations in Persian/silence
|
||||
PERSIAN_HALLUCINATIONS = [
|
||||
@@ -18,11 +22,19 @@ PERSIAN_HALLUCINATIONS = [
|
||||
]
|
||||
|
||||
|
||||
def get_model(size="medium"):
|
||||
def set_whisper_size(size):
|
||||
"""Change the Whisper model size. Reloads on next transcription."""
|
||||
global _whisper_size, _model
|
||||
if size != _whisper_size:
|
||||
_whisper_size = size
|
||||
_model = None
|
||||
|
||||
|
||||
def get_model():
|
||||
"""Load Whisper model (cached singleton)."""
|
||||
global _model
|
||||
if _model is None:
|
||||
_model = load_whisper_model(size)
|
||||
_model = load_whisper_model(_whisper_size)
|
||||
return _model
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user