From 28c9f2f9f85b4adbb867abccb1efc176c7ca7361 Mon Sep 17 00:00:00 2001 From: imiel Date: Sun, 19 Apr 2026 15:50:59 +0000 Subject: [PATCH] Replace hardcoded universe levels with a proper level language and constraint solving --- BidirTT/Check.lean | 8 +++---- BidirTT/Eval.lean | 8 +++---- BidirTT/Examples.lean | 2 ++ BidirTT/Level.lean | 52 +++++++++++++++++++++++++++++++++++++++++++ BidirTT/Pretty.lean | 2 +- BidirTT/Syntax.lean | 6 +++-- BidirTT/Value.lean | 2 +- Tests.lean | 23 ++++++++++++++++++- 8 files changed, 90 insertions(+), 13 deletions(-) create mode 100644 BidirTT/Level.lean diff --git a/BidirTT/Check.lean b/BidirTT/Check.lean index cf7d820..82a5880 100644 --- a/BidirTT/Check.lean +++ b/BidirTT/Check.lean @@ -122,7 +122,7 @@ mutual let _ := level pure (.idElim motive' r' target' eq', resultTy) | _ => throw s!"expected Id type in idElim, got {showTy cxt.lvl eqTy}" - | .univ i => pure (.univ i, .univ (i + 1)) + | .univ i => pure (.univ i.normalise, .univ i.succ') | .app t u => do let (t', tty) ← infer cxt t match tty with @@ -150,12 +150,12 @@ mutual let (a', i) ← inferUniverse cxt a let va ← eval cxt.env a' let (b', j) ← inferUniverse (cxt.bind x va) b - pure (.pi a' b', .univ (Nat.max i j)) + pure (.pi a' b', .univ (i.max' j)) | .sig x a b => do let (a', i) ← inferUniverse cxt a let va ← eval cxt.env a' let (b', j) ← inferUniverse (cxt.bind x va) b - pure (.sig a' b', .univ (Nat.max i j)) + pure (.sig a' b', .univ (i.max' j)) | .ann t a => do let (a', _) ← inferUniverse cxt a let va ← eval cxt.env a' @@ -171,7 +171,7 @@ mutual | .lam _ _ => throw "cannot infer type of lambda, use an annotation" | .pair _ _ => throw "cannot infer type of pair, use an annotation" - partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Nat) := do + partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Level) := do let (t, ty) ← infer cxt r match ty with | .univ level => pure (t, level) diff --git a/BidirTT/Eval.lean b/BidirTT/Eval.lean index e9cd58b..e15516b 100644 --- a/BidirTT/Eval.lean +++ b/BidirTT/Eval.lean @@ -67,7 +67,7 @@ mutual let vy ← eval env y let vp ← eval env p vIdElim vm vr vy vp - | _, .univ i => pure (.univ i) + | _, .univ i => pure (.univ i.normalise) | env, .letE _ t u => do let vt ← eval env t eval (vt :: env) u @@ -187,7 +187,7 @@ mutual let qu ← quote l u pure (.id qa qt qu) | _, .refl => pure .refl - | _, .univ i => pure (.univ i) + | _, .univ i => pure (.univ i.normalise) end private def andThen (lhs : EvalM Bool) (rhs : Unit → EvalM Bool) : EvalM Bool := do @@ -239,7 +239,7 @@ mutual | _, _, _ => pure false partial def conv : Lvl → Val → Val → EvalM Bool - | _, .univ i, .univ j => pure (i == j) + | _, .univ i, .univ j => Level.eqv i j | l, .pi a c, .pi a' c' => andThen (conv l a a') fun _ => do let b ← cApp c (.neu (.var l)) @@ -299,7 +299,7 @@ mutual end partial def sub : Lvl → Val → Val → EvalM Bool - | _, .univ i, .univ j => pure (i <= j) + | _, .univ i, .univ j => Level.leq i j | l, .pi a c, .pi a' c' => andThen (sub l a' a) fun _ => do let b ← cApp c (.neu (.var l)) diff --git a/BidirTT/Examples.lean b/BidirTT/Examples.lean index e78d998..78c09a1 100644 --- a/BidirTT/Examples.lean +++ b/BidirTT/Examples.lean @@ -46,6 +46,8 @@ def fstDepPair : Raw := .fst depPairAnn def sndDepPair : Raw := .snd depPairAnn +def univMax : Raw := .univ (.max 0 1) + def natTwo : Raw := .succ (.succ .zero) diff --git a/BidirTT/Level.lean b/BidirTT/Level.lean new file mode 100644 index 0000000..793ac2e --- /dev/null +++ b/BidirTT/Level.lean @@ -0,0 +1,52 @@ +namespace BidirTT + +inductive Level where + | zero : Level + | succ : Level → Level + | max : Level → Level → Level + deriving Repr, Inhabited, BEq, DecidableEq + +def Level.ofNat : Nat → Level + | 0 => .zero + | n + 1 => .succ (Level.ofNat n) + +instance (n : Nat) : OfNat Level n where + ofNat := Level.ofNat n + +partial def Level.eval : Level → Nat + | .zero => 0 + | .succ l => l.eval + 1 + | .max l r => Nat.max l.eval r.eval + +def Level.normalise (l : Level) : Level := + .ofNat l.eval + +def Level.succ' (l : Level) : Level := + .normalise (.succ l) + +def Level.max' (l r : Level) : Level := + .normalise (.max l r) + +def Level.pretty (l : Level) : String := + s!"{l.eval}" + +abbrev LevelConstraint := Level × Level + +def solveLevelConstraints (constraints : List LevelConstraint) : Except String Unit := do + match constraints.find? fun (lhs, rhs) => lhs.eval > rhs.eval with + | some (lhs, rhs) => + throw s!"unsatisfiable level constraint {lhs.pretty} <= {rhs.pretty}" + | none => + pure () + +def Level.leq (lhs rhs : Level) : Except String Bool := do + match solveLevelConstraints [(lhs, rhs)] with + | Except.ok _ => pure true + | Except.error _ => pure false + +def Level.eqv (lhs rhs : Level) : Except String Bool := do + match solveLevelConstraints [(lhs, rhs), (rhs, lhs)] with + | Except.ok _ => pure true + | Except.error _ => pure false + +end BidirTT diff --git a/BidirTT/Pretty.lean b/BidirTT/Pretty.lean index 74596ab..a33cd07 100644 --- a/BidirTT/Pretty.lean +++ b/BidirTT/Pretty.lean @@ -71,7 +71,7 @@ mutual let (py, nextFresh) := prettyTmWith names nextFresh y let (pp, nextFresh) := prettyTmWith names nextFresh p (s!"(idElim {pm} {pr} {py} {pp})", nextFresh) - | .univ i => (s!"U{i}", nextFresh) + | .univ i => (s!"U{i.pretty}", nextFresh) | .letE a t u => let x := s!"x{nextFresh}" let (pa, nextFresh) := prettyTmWith names (nextFresh + 1) a diff --git a/BidirTT/Syntax.lean b/BidirTT/Syntax.lean index f0b140a..9776abc 100644 --- a/BidirTT/Syntax.lean +++ b/BidirTT/Syntax.lean @@ -1,3 +1,5 @@ +import BidirTT.Level + namespace BidirTT abbrev Name := String @@ -23,7 +25,7 @@ inductive Raw where | id : Raw → Raw → Raw → Raw | refl : Raw | idElim : Name → Name → Raw → Raw → Raw → Raw → Raw - | univ : Nat → Raw + | univ : Level → Raw | letE : Name → Raw → Raw → Raw → Raw | ann : Raw → Raw → Raw deriving Repr, Inhabited, BEq, DecidableEq @@ -49,7 +51,7 @@ inductive Tm where | id : Tm → Tm → Tm → Tm | refl : Tm | idElim : Tm → Tm → Tm → Tm → Tm - | univ : Nat → Tm + | univ : Level → Tm | letE : Tm → Tm → Tm → Tm deriving Repr, Inhabited, BEq, DecidableEq diff --git a/BidirTT/Value.lean b/BidirTT/Value.lean index 220a01e..a3cfd43 100644 --- a/BidirTT/Value.lean +++ b/BidirTT/Value.lean @@ -27,7 +27,7 @@ mutual | empty : Val | id : Val → Val → Val → Val | refl : Val - | univ : Nat → Val + | univ : Level → Val inductive Closure where | mk : List Val → Tm → Closure diff --git a/Tests.lean b/Tests.lean index f45672b..7ef79e1 100644 --- a/Tests.lean +++ b/Tests.lean @@ -25,8 +25,12 @@ structure TestCase where def cases : List TestCase := [ ⟨"U0 is typed by U1", Examples.univ0, .okTy (.univ 1)⟩, + ⟨"U(max 0 1) is typed by U2", Examples.univMax, + .okTy (.univ 2)⟩, ⟨"U0 subsumes into U2", .ann (.univ 0) (.univ 2), .okTy (.univ 2)⟩, + ⟨"U0 subsumes into U(max 0 2)", .ann (.univ 0) (.univ (.max 0 2)), + .okTy (.univ 2)⟩, ⟨"Nat is typed by U0", .nat, .okTy (.univ 0)⟩, ⟨"succ zero infers Nat", (.succ .zero), @@ -164,6 +168,22 @@ def runNeutralRepresentationChecks : IO Bool := do IO.println "fail stuck eliminators stay in the neutral fragment" pure false +def runLevelSolverChecks : IO Bool := do + let satOk := + match solveLevelConstraints [(.max 0 1, 1), (1, .succ 1)] with + | Except.ok _ => true + | Except.error _ => false + let unsatOk := + match solveLevelConstraints [(.succ 1, 1)] with + | Except.error err => containsText err "unsatisfiable level constraint" + | Except.ok _ => false + if satOk && unsatOk then + IO.println "pass level constraints are solved consistently" + pure true + else + IO.println "fail level constraints are solved consistently" + pure false + def runPrettyPrinterChecks : IO Bool := do match checkTop Examples.idAnn with | .ok (tm, ty) => @@ -192,8 +212,9 @@ def main : IO UInt32 := do let results ← cases.mapM runCase let safetyOk ← runInternalSafetyChecks let neutralOk ← runNeutralRepresentationChecks + let levelOk ← runLevelSolverChecks let prettyOk ← runPrettyPrinterChecks - let allResults := results ++ [safetyOk, neutralOk, prettyOk] + let allResults := results ++ [safetyOk, neutralOk, levelOk, prettyOk] let failed := allResults.countP (· == false) if failed == 0 then IO.println s!"\n{allResults.length} passed, 0 failed"