Replace hardcoded universe levels with a proper level language and constraint solving

This commit is contained in:
2026-04-19 15:50:59 +00:00
parent 963c9f3e94
commit 28c9f2f9f8
8 changed files with 90 additions and 13 deletions
+4 -4
View File
@@ -122,7 +122,7 @@ mutual
let _ := level
pure (.idElim motive' r' target' eq', resultTy)
| _ => throw s!"expected Id type in idElim, got {showTy cxt.lvl eqTy}"
| .univ i => pure (.univ i, .univ (i + 1))
| .univ i => pure (.univ i.normalise, .univ i.succ')
| .app t u => do
let (t', tty) infer cxt t
match tty with
@@ -150,12 +150,12 @@ mutual
let (a', i) inferUniverse cxt a
let va eval cxt.env a'
let (b', j) inferUniverse (cxt.bind x va) b
pure (.pi a' b', .univ (Nat.max i j))
pure (.pi a' b', .univ (i.max' j))
| .sig x a b => do
let (a', i) inferUniverse cxt a
let va eval cxt.env a'
let (b', j) inferUniverse (cxt.bind x va) b
pure (.sig a' b', .univ (Nat.max i j))
pure (.sig a' b', .univ (i.max' j))
| .ann t a => do
let (a', _) inferUniverse cxt a
let va eval cxt.env a'
@@ -171,7 +171,7 @@ mutual
| .lam _ _ => throw "cannot infer type of lambda, use an annotation"
| .pair _ _ => throw "cannot infer type of pair, use an annotation"
partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Nat) := do
partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Level) := do
let (t, ty) infer cxt r
match ty with
| .univ level => pure (t, level)
+4 -4
View File
@@ -67,7 +67,7 @@ mutual
let vy eval env y
let vp eval env p
vIdElim vm vr vy vp
| _, .univ i => pure (.univ i)
| _, .univ i => pure (.univ i.normalise)
| env, .letE _ t u => do
let vt eval env t
eval (vt :: env) u
@@ -187,7 +187,7 @@ mutual
let qu quote l u
pure (.id qa qt qu)
| _, .refl => pure .refl
| _, .univ i => pure (.univ i)
| _, .univ i => pure (.univ i.normalise)
end
private def andThen (lhs : EvalM Bool) (rhs : Unit EvalM Bool) : EvalM Bool := do
@@ -239,7 +239,7 @@ mutual
| _, _, _ => pure false
partial def conv : Lvl Val Val EvalM Bool
| _, .univ i, .univ j => pure (i == j)
| _, .univ i, .univ j => Level.eqv i j
| l, .pi a c, .pi a' c' =>
andThen (conv l a a') fun _ => do
let b cApp c (.neu (.var l))
@@ -299,7 +299,7 @@ mutual
end
partial def sub : Lvl Val Val EvalM Bool
| _, .univ i, .univ j => pure (i <= j)
| _, .univ i, .univ j => Level.leq i j
| l, .pi a c, .pi a' c' =>
andThen (sub l a' a) fun _ => do
let b cApp c (.neu (.var l))
+2
View File
@@ -46,6 +46,8 @@ def fstDepPair : Raw := .fst depPairAnn
def sndDepPair : Raw := .snd depPairAnn
def univMax : Raw := .univ (.max 0 1)
def natTwo : Raw :=
.succ (.succ .zero)
+52
View File
@@ -0,0 +1,52 @@
namespace BidirTT
inductive Level where
| zero : Level
| succ : Level Level
| max : Level Level Level
deriving Repr, Inhabited, BEq, DecidableEq
def Level.ofNat : Nat Level
| 0 => .zero
| n + 1 => .succ (Level.ofNat n)
instance (n : Nat) : OfNat Level n where
ofNat := Level.ofNat n
partial def Level.eval : Level Nat
| .zero => 0
| .succ l => l.eval + 1
| .max l r => Nat.max l.eval r.eval
def Level.normalise (l : Level) : Level :=
.ofNat l.eval
def Level.succ' (l : Level) : Level :=
.normalise (.succ l)
def Level.max' (l r : Level) : Level :=
.normalise (.max l r)
def Level.pretty (l : Level) : String :=
s!"{l.eval}"
abbrev LevelConstraint := Level × Level
def solveLevelConstraints (constraints : List LevelConstraint) : Except String Unit := do
match constraints.find? fun (lhs, rhs) => lhs.eval > rhs.eval with
| some (lhs, rhs) =>
throw s!"unsatisfiable level constraint {lhs.pretty} <= {rhs.pretty}"
| none =>
pure ()
def Level.leq (lhs rhs : Level) : Except String Bool := do
match solveLevelConstraints [(lhs, rhs)] with
| Except.ok _ => pure true
| Except.error _ => pure false
def Level.eqv (lhs rhs : Level) : Except String Bool := do
match solveLevelConstraints [(lhs, rhs), (rhs, lhs)] with
| Except.ok _ => pure true
| Except.error _ => pure false
end BidirTT
+1 -1
View File
@@ -71,7 +71,7 @@ mutual
let (py, nextFresh) := prettyTmWith names nextFresh y
let (pp, nextFresh) := prettyTmWith names nextFresh p
(s!"(idElim {pm} {pr} {py} {pp})", nextFresh)
| .univ i => (s!"U{i}", nextFresh)
| .univ i => (s!"U{i.pretty}", nextFresh)
| .letE a t u =>
let x := s!"x{nextFresh}"
let (pa, nextFresh) := prettyTmWith names (nextFresh + 1) a
+4 -2
View File
@@ -1,3 +1,5 @@
import BidirTT.Level
namespace BidirTT
abbrev Name := String
@@ -23,7 +25,7 @@ inductive Raw where
| id : Raw Raw Raw Raw
| refl : Raw
| idElim : Name Name Raw Raw Raw Raw Raw
| univ : Nat Raw
| univ : Level Raw
| letE : Name Raw Raw Raw Raw
| ann : Raw Raw Raw
deriving Repr, Inhabited, BEq, DecidableEq
@@ -49,7 +51,7 @@ inductive Tm where
| id : Tm Tm Tm Tm
| refl : Tm
| idElim : Tm Tm Tm Tm Tm
| univ : Nat Tm
| univ : Level Tm
| letE : Tm Tm Tm Tm
deriving Repr, Inhabited, BEq, DecidableEq
+1 -1
View File
@@ -27,7 +27,7 @@ mutual
| empty : Val
| id : Val Val Val Val
| refl : Val
| univ : Nat Val
| univ : Level Val
inductive Closure where
| mk : List Val Tm Closure
+22 -1
View File
@@ -25,8 +25,12 @@ structure TestCase where
def cases : List TestCase := [
"U0 is typed by U1", Examples.univ0,
.okTy (.univ 1),
"U(max 0 1) is typed by U2", Examples.univMax,
.okTy (.univ 2),
"U0 subsumes into U2", .ann (.univ 0) (.univ 2),
.okTy (.univ 2),
"U0 subsumes into U(max 0 2)", .ann (.univ 0) (.univ (.max 0 2)),
.okTy (.univ 2),
"Nat is typed by U0", .nat,
.okTy (.univ 0),
"succ zero infers Nat", (.succ .zero),
@@ -164,6 +168,22 @@ def runNeutralRepresentationChecks : IO Bool := do
IO.println "fail stuck eliminators stay in the neutral fragment"
pure false
def runLevelSolverChecks : IO Bool := do
let satOk :=
match solveLevelConstraints [(.max 0 1, 1), (1, .succ 1)] with
| Except.ok _ => true
| Except.error _ => false
let unsatOk :=
match solveLevelConstraints [(.succ 1, 1)] with
| Except.error err => containsText err "unsatisfiable level constraint"
| Except.ok _ => false
if satOk && unsatOk then
IO.println "pass level constraints are solved consistently"
pure true
else
IO.println "fail level constraints are solved consistently"
pure false
def runPrettyPrinterChecks : IO Bool := do
match checkTop Examples.idAnn with
| .ok (tm, ty) =>
@@ -192,8 +212,9 @@ def main : IO UInt32 := do
let results cases.mapM runCase
let safetyOk runInternalSafetyChecks
let neutralOk runNeutralRepresentationChecks
let levelOk runLevelSolverChecks
let prettyOk runPrettyPrinterChecks
let allResults := results ++ [safetyOk, neutralOk, prettyOk]
let allResults := results ++ [safetyOk, neutralOk, levelOk, prettyOk]
let failed := allResults.countP (· == false)
if failed == 0 then
IO.println s!"\n{allResults.length} passed, 0 failed"