Write Nat/Unit/Empty/Id Eliminators Through NbE and Bidir Elaboration

This commit is contained in:
2026-04-19 13:55:05 +00:00
parent a154e2b98c
commit 85be37b1d6
8 changed files with 374 additions and 2 deletions
+80 -1
View File
@@ -23,6 +23,11 @@ mutual
let bodyTy cApp c vt let bodyTy cApp c vt
let u' check cxt u bodyTy let u' check cxt u bodyTy
pure (.pair t' u') pure (.pair t' u')
| .refl, .id a t u => do
if ( conv cxt.lvl t u) then
pure .refl
else
throw s!"refl cannot inhabit {showTy cxt.lvl (.id a t u)} because the endpoints are not definitionally equal"
| .letE x a t u, ty => do | .letE x a t u, ty => do
let (a', _) inferUniverse cxt a let (a', _) inferUniverse cxt a
let va := eval cxt.env a' let va := eval cxt.env a'
@@ -33,7 +38,7 @@ mutual
pure (.letE a' t' u') pure (.letE a' t' u')
| r, ty => do | r, ty => do
let (t', ty') infer cxt r let (t', ty') infer cxt r
if ( conv cxt.lvl ty' ty) then if ( sub cxt.lvl ty' ty) then
pure t' pure t'
else else
throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}" throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}"
@@ -43,6 +48,80 @@ mutual
match cxt.lookup x with match cxt.lookup x with
| some (i, a) => pure (.var i, a) | some (i, a) => pure (.var i, a)
| none => throw s!"unknown variable {x}" | none => throw s!"unknown variable {x}"
| .nat => pure (.nat, .univ 0)
| .zero => pure (.zero, .nat)
| .succ t => do
let t' check cxt t .nat
pure (.succ t', .nat)
| .natElim n motive z k ih s scrut => do
let scrut' check cxt scrut .nat
let vscrut eval cxt.env scrut'
let (motiveBody', level) inferUniverse (cxt.bind n .nat) motive
let motive' : Tm := .lam motiveBody'
let vmotive eval cxt.env motive'
let zTy vApp vmotive .zero
let z' check cxt z zTy
let kCxt := cxt.bind k .nat
let ihTy vApp vmotive (.var cxt.lvl)
let stepTy vApp vmotive (.succ (.var cxt.lvl))
let stepBody' check (kCxt.bind ih ihTy) s stepTy
let step' : Tm := .lam (.lam stepBody')
let resultTy vApp vmotive vscrut
let _ := level
pure (.natElim motive' z' step' scrut', resultTy)
| .unit => pure (.unit, .univ 0)
| .triv => pure (.triv, .unit)
| .unitElim u motive t scrut => do
let scrut' check cxt scrut .unit
let vscrut eval cxt.env scrut'
let (motiveBody', level) inferUniverse (cxt.bind u .unit) motive
let motive' : Tm := .lam motiveBody'
let vmotive eval cxt.env motive'
let tTy vApp vmotive .triv
let t' check cxt t tTy
let resultTy vApp vmotive vscrut
let _ := level
pure (.unitElim motive' t' scrut', resultTy)
| .empty => pure (.empty, .univ 0)
| .emptyElim e motive scrut => do
let scrut' check cxt scrut .empty
let vscrut eval cxt.env scrut'
let (motiveBody', level) inferUniverse (cxt.bind e .empty) motive
let motive' : Tm := .lam motiveBody'
let vmotive eval cxt.env motive'
let resultTy vApp vmotive vscrut
let _ := level
pure (.emptyElim motive' scrut', resultTy)
| .id a t u => do
let (a', level) inferUniverse cxt a
let va eval cxt.env a'
let t' check cxt t va
let u' check cxt u va
pure (.id a' t' u', .univ level)
| .refl => throw "cannot infer type of refl, use an annotation"
| .idElim y p motive r target eq => do
let (eq', eqTy) infer cxt eq
match eqTy with
| .id a x rhs => do
let (target', targetTy) infer cxt target
if !( conv cxt.lvl targetTy a) then
throw s!"idElim target has type {showTy cxt.lvl targetTy}, but the equality lives over {showTy cxt.lvl a}"
let vtarget eval cxt.env target'
if !( conv cxt.lvl vtarget rhs) then
throw s!"idElim target {showTy cxt.lvl vtarget} does not match the equality endpoint {showTy cxt.lvl rhs}"
let eqVarTy : Val := .id a x (.var cxt.lvl)
let (motiveBody', level) inferUniverse ((cxt.bind y a).bind p eqVarTy) motive
let motive' : Tm := .lam (.lam motiveBody')
let vmotive eval cxt.env motive'
let reflTy vApp vmotive x
let reflTy vApp reflTy .refl
let r' check cxt r reflTy
let veq eval cxt.env eq'
let resultTy vApp vmotive vtarget
let resultTy vApp resultTy veq
let _ := level
pure (.idElim motive' r' target' eq', resultTy)
| _ => throw s!"expected Id type in idElim, got {showTy cxt.lvl eqTy}"
| .univ i => pure (.univ i, .univ (i + 1)) | .univ i => pure (.univ i, .univ (i + 1))
| .app t u => do | .app t u => do
let (t', tty) infer cxt t let (t', tty) infer cxt t
+148
View File
@@ -32,6 +32,41 @@ mutual
| env, .snd t => do | env, .snd t => do
let vt eval env t let vt eval env t
vSnd vt vSnd vt
| _, .nat => pure .nat
| _, .zero => pure .zero
| env, .succ t => do
let vt eval env t
pure (.succ vt)
| env, .natElim m z s n => do
let vm eval env m
let vz eval env z
let vs eval env s
let vn eval env n
vNatElim vm vz vs vn
| _, .unit => pure .unit
| _, .triv => pure .triv
| env, .unitElim m t u => do
let vm eval env m
let vt eval env t
let vu eval env u
vUnitElim vm vt vu
| _, .empty => pure .empty
| env, .emptyElim m e => do
let vm eval env m
let ve eval env e
vEmptyElim vm ve
| env, .id a t u => do
let va eval env a
let vt eval env t
let vu eval env u
pure (.id va vt vu)
| _, .refl => pure .refl
| env, .idElim m r y p => do
let vm eval env m
let vr eval env r
let vy eval env y
let vp eval env p
vIdElim vm vr vy vp
| _, .univ i => pure (.univ i) | _, .univ i => pure (.univ i)
| env, .letE _ t u => do | env, .letE _ t u => do
let vt eval env t let vt eval env t
@@ -49,6 +84,25 @@ mutual
| .pair _ b => pure b | .pair _ b => pure b
| t => pure (.snd t) | t => pure (.snd t)
partial def vNatElim : Val Val Val Val EvalM Val
| _, z, _, .zero => pure z
| m, z, s, .succ n => do
let ih vNatElim m z s n
let step vApp s n
vApp step ih
| m, z, s, n => pure (.natElim m z s n)
partial def vUnitElim : Val Val Val EvalM Val
| _, t, .triv => pure t
| m, t, u => pure (.unitElim m t u)
partial def vEmptyElim : Val Val EvalM Val
| m, e => pure (.emptyElim m e)
partial def vIdElim : Val Val Val Val EvalM Val
| _, r, _, .refl => pure r
| m, r, y, p => pure (.idElim m r y p)
partial def cApp : Closure Val EvalM Val partial def cApp : Closure Val EvalM Val
| .mk env body, v => eval (v :: env) body | .mk env body, v => eval (v :: env) body
end end
@@ -69,6 +123,41 @@ partial def quote : Lvl → Val → EvalM Tm
| l, .snd t => do | l, .snd t => do
let qt quote l t let qt quote l t
pure (.snd qt) pure (.snd qt)
| _, .nat => pure .nat
| _, .zero => pure .zero
| l, .succ t => do
let qt quote l t
pure (.succ qt)
| l, .natElim m z s n => do
let qm quote l m
let qz quote l z
let qs quote l s
let qn quote l n
pure (.natElim qm qz qs qn)
| _, .unit => pure .unit
| _, .triv => pure .triv
| l, .unitElim m t u => do
let qm quote l m
let qt quote l t
let qu quote l u
pure (.unitElim qm qt qu)
| _, .empty => pure .empty
| l, .emptyElim m e => do
let qm quote l m
let qe quote l e
pure (.emptyElim qm qe)
| l, .id a t u => do
let qa quote l a
let qt quote l t
let qu quote l u
pure (.id qa qt qu)
| _, .refl => pure .refl
| l, .idElim m r y p => do
let qm quote l m
let qr quote l r
let qy quote l y
let qp quote l p
pure (.idElim qm qr qy qp)
| l, .lam c => do | l, .lam c => do
let body cApp c (.var l) let body cApp c (.var l)
let qb quote (l + 1) body let qb quote (l + 1) body
@@ -107,6 +196,51 @@ partial def conv : Lvl → Val → Val → EvalM Bool
let b cApp c (.var l) let b cApp c (.var l)
let b' cApp c' (.var l) let b' cApp c' (.var l)
conv (l + 1) b b' conv (l + 1) b b'
| _, .nat, .nat => pure true
| _, .zero, .zero => pure true
| l, .succ n, .succ n' => conv l n n'
| l, .natElim m z s n, .natElim m' z' s' n' =>
andThen (conv l m m') fun _ => do
let sameZ conv l z z'
if sameZ then
let sameS conv l s s'
if sameS then
conv l n n'
else
pure false
else
pure false
| _, .unit, .unit => pure true
| _, .triv, .triv => pure true
| l, .unitElim m t u, .unitElim m' t' u' =>
andThen (conv l m m') fun _ => do
let sameT conv l t t'
if sameT then
conv l u u'
else
pure false
| _, .empty, .empty => pure true
| l, .emptyElim m e, .emptyElim m' e' =>
andThen (conv l m m') fun _ => conv l e e'
| l, .id a t u, .id a' t' u' =>
andThen (conv l a a') fun _ => do
let sameT conv l t t'
if sameT then
conv l u u'
else
pure false
| _, .refl, .refl => pure true
| l, .idElim m r y p, .idElim m' r' y' p' =>
andThen (conv l m m') fun _ => do
let sameR conv l r r'
if sameR then
let sameY conv l y y'
if sameY then
conv l p p'
else
pure false
else
pure false
| l, .lam c, .lam c' => | l, .lam c, .lam c' =>
do do
let body cApp c (.var l) let body cApp c (.var l)
@@ -147,4 +281,18 @@ partial def conv : Lvl → Val → Val → EvalM Bool
| l, .snd t, .snd t' => conv l t t' | l, .snd t, .snd t' => conv l t t'
| _, _, _ => pure false | _, _, _ => pure false
partial def sub : Lvl Val Val EvalM Bool
| _, .univ i, .univ j => pure (i <= j)
| l, .pi a c, .pi a' c' =>
andThen (sub l a' a) fun _ => do
let b cApp c (.var l)
let b' cApp c' (.var l)
sub (l + 1) b b'
| l, .sig a c, .sig a' c' =>
andThen (sub l a a') fun _ => do
let b cApp c (.var l)
let b' cApp c' (.var l)
sub (l + 1) b b'
| l, t, t' => conv l t t'
end BidirTT end BidirTT
+31
View File
@@ -46,6 +46,32 @@ def fstDepPair : Raw := .fst depPairAnn
def sndDepPair : Raw := .snd depPairAnn def sndDepPair : Raw := .snd depPairAnn
def natTwo : Raw :=
.succ (.succ .zero)
def natFoldId : Raw :=
.ann
(.natElim "n" .nat .zero "k" "ih" (.succ (.var "ih")) natTwo)
.nat
def unitToNat : Raw :=
.ann
(.unitElim "u" .nat natTwo .triv)
.nat
def absurdNat : Raw :=
.ann
(.lam "e" (.emptyElim "x" .nat (.var "e")))
(.pi "e" .empty .nat)
def reflZero : Raw :=
.ann .refl (.id .nat .zero .zero)
def idElimNat : Raw :=
.ann
(.idElim "y" "p" .nat .zero .zero reflZero)
.nat
def omegaTy : Raw := def omegaTy : Raw :=
.pi "A" (.univ 0) (.var "A") .pi "A" (.univ 0) (.var "A")
@@ -67,6 +93,11 @@ def letUniverse : Raw :=
(.letE "A" (.univ 1) (.pi "_" (.univ 0) (.univ 0)) (.var "A")) (.letE "A" (.univ 1) (.pi "_" (.univ 0) (.univ 0)) (.var "A"))
(.univ 1) (.univ 1)
def badSucc : Raw := .succ (.univ 0)
def badRefl : Raw :=
.ann .refl (.id .nat .zero (.succ .zero))
def univ0 : Raw := .univ 0 def univ0 : Raw := .univ 0
end BidirTT.Examples end BidirTT.Examples
+16
View File
@@ -12,6 +12,22 @@ mutual
| .pair t u => s!"({prettyTm t}, {prettyTm u})" | .pair t u => s!"({prettyTm t}, {prettyTm u})"
| .fst t => s!"({prettyTm t}.1)" | .fst t => s!"({prettyTm t}.1)"
| .snd t => s!"({prettyTm t}.2)" | .snd t => s!"({prettyTm t}.2)"
| .nat => "Nat"
| .zero => "zero"
| .succ t => s!"(succ {prettyTm t})"
| .natElim m z s n =>
s!"(natElim {prettyTm m} {prettyTm z} {prettyTm s} {prettyTm n})"
| .unit => "Unit"
| .triv => "tt"
| .unitElim m t u =>
s!"(unitElim {prettyTm m} {prettyTm t} {prettyTm u})"
| .empty => "Empty"
| .emptyElim m e =>
s!"(emptyElim {prettyTm m} {prettyTm e})"
| .id a t u => s!"(Id {prettyTm a} {prettyTm t} {prettyTm u})"
| .refl => "refl"
| .idElim m r y p =>
s!"(idElim {prettyTm m} {prettyTm r} {prettyTm y} {prettyTm p})"
| .univ i => s!"U{i}" | .univ i => s!"U{i}"
| .letE a t u => s!"(let : {prettyTm a} := {prettyTm t}; {prettyTm u})" | .letE a t u => s!"(let : {prettyTm a} := {prettyTm t}; {prettyTm u})"
end end
+24
View File
@@ -11,6 +11,18 @@ inductive Raw where
| pair : Raw Raw Raw | pair : Raw Raw Raw
| fst : Raw Raw | fst : Raw Raw
| snd : Raw Raw | snd : Raw Raw
| nat : Raw
| zero : Raw
| succ : Raw Raw
| natElim : Name Raw Raw Name Name Raw Raw Raw
| unit : Raw
| triv : Raw
| unitElim : Name Raw Raw Raw Raw
| empty : Raw
| emptyElim : Name Raw Raw Raw
| id : Raw Raw Raw Raw
| refl : Raw
| idElim : Name Name Raw Raw Raw Raw Raw
| univ : Nat Raw | univ : Nat Raw
| letE : Name Raw Raw Raw Raw | letE : Name Raw Raw Raw Raw
| ann : Raw Raw Raw | ann : Raw Raw Raw
@@ -25,6 +37,18 @@ inductive Tm where
| pair : Tm Tm Tm | pair : Tm Tm Tm
| fst : Tm Tm | fst : Tm Tm
| snd : Tm Tm | snd : Tm Tm
| nat : Tm
| zero : Tm
| succ : Tm Tm
| natElim : Tm Tm Tm Tm Tm
| unit : Tm
| triv : Tm
| unitElim : Tm Tm Tm Tm
| empty : Tm
| emptyElim : Tm Tm Tm
| id : Tm Tm Tm Tm
| refl : Tm
| idElim : Tm Tm Tm Tm Tm
| univ : Nat Tm | univ : Nat Tm
| letE : Tm Tm Tm Tm | letE : Tm Tm Tm Tm
deriving Repr, Inhabited, BEq, DecidableEq deriving Repr, Inhabited, BEq, DecidableEq
+12
View File
@@ -12,6 +12,18 @@ mutual
| pi : Val Closure Val | pi : Val Closure Val
| sig : Val Closure Val | sig : Val Closure Val
| pair : Val Val Val | pair : Val Val Val
| nat : Val
| zero : Val
| succ : Val Val
| natElim : Val Val Val Val Val
| unit : Val
| triv : Val
| unitElim : Val Val Val Val
| empty : Val
| emptyElim : Val Val Val
| id : Val Val Val Val
| refl : Val
| idElim : Val Val Val Val Val
| univ : Nat Val | univ : Nat Val
inductive Closure where inductive Closure where
+7
View File
@@ -23,8 +23,15 @@ def main : IO Unit := do
runOne "depPair" Examples.depPairAnn runOne "depPair" Examples.depPairAnn
runOne "depPair.1" Examples.fstDepPair runOne "depPair.1" Examples.fstDepPair
runOne "depPair.2" Examples.sndDepPair runOne "depPair.2" Examples.sndDepPair
runOne "nat fold id" Examples.natFoldId
runOne "unit elim" Examples.unitToNat
runOne "empty absurd" Examples.absurdNat
runOne "refl zero" Examples.reflZero
runOne "id elim" Examples.idElimNat
runOne "let universe" Examples.letUniverse runOne "let universe" Examples.letUniverse
runOne "omega (bad)" Examples.omegaAnn runOne "omega (bad)" Examples.omegaAnn
runOne "unknown var" Examples.unknownVar runOne "unknown var" Examples.unknownVar
runOne "pair mismatch" Examples.pairMismatch runOne "pair mismatch" Examples.pairMismatch
runOne "bad fst" Examples.badFst runOne "bad fst" Examples.badFst
runOne "bad succ" Examples.badSucc
runOne "bad refl" Examples.badRefl
+56 -1
View File
@@ -4,11 +4,16 @@ open BidirTT
inductive Expectation where inductive Expectation where
| okTy : Tm Expectation | okTy : Tm Expectation
| okTyNorm : Tm Tm Expectation
| errContains : String Expectation | errContains : String Expectation
private def renderType (ty : Val) : Except String Tm := private def renderType (ty : Val) : Except String Tm :=
BidirTT.quote 0 ty BidirTT.quote 0 ty
private def renderNormal (tm : Tm) : Except String Tm := do
let v eval [] tm
quote 0 v
private def containsText (haystack needle : String) : Bool := private def containsText (haystack needle : String) : Bool :=
needle.isEmpty || (haystack.splitOn needle).length > 1 needle.isEmpty || (haystack.splitOn needle).length > 1
@@ -20,20 +25,47 @@ structure TestCase where
def cases : List TestCase := [ def cases : List TestCase := [
"U0 is typed by U1", Examples.univ0, "U0 is typed by U1", Examples.univ0,
.okTy (.univ 1), .okTy (.univ 1),
"U0 subsumes into U2", .ann (.univ 0) (.univ 2),
.okTy (.univ 2),
"Nat is typed by U0", .nat,
.okTy (.univ 0),
"succ zero infers Nat", (.succ .zero),
.okTy .nat,
"id typechecks", Examples.idAnn, "id typechecks", Examples.idAnn,
.okTy (.pi (.univ 0) (.pi (.var 0) (.var 1))), .okTy (.pi (.univ 0) (.pi (.var 0) (.var 1))),
"Pi subsumption is contravariant in the domain", .ann
(.ann
(.lam "A" (.var "A"))
(.pi "A" (.univ 1) (.univ 1)))
(.pi "A" (.univ 0) (.univ 2)),
.okTy (.pi (.univ 0) (.univ 2)),
"const typechecks", Examples.constAnn, "const typechecks", Examples.constAnn,
.okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.var 1) (.pi (.var 1) (.var 3))))), .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.var 1) (.pi (.var 1) (.var 3))))),
"swap typechecks", Examples.swapAnn, "swap typechecks", Examples.swapAnn,
.okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.sig (.var 1) (.var 1)) (.sig (.var 1) (.var 3))))), .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.sig (.var 1) (.var 1)) (.sig (.var 1) (.var 3))))),
"dependent pair typechecks", Examples.depPairAnn, "dependent pair typechecks", Examples.depPairAnn,
.okTy (.sig (.univ 2) (.var 0)), .okTy (.sig (.univ 2) (.var 0)),
"dependent pair subsumes into a lifted Sigma", .ann Examples.depPairAnn
(.sig "A" (.univ 3) (.var "A")),
.okTy (.sig (.univ 3) (.var 0)),
"natElim computes on numerals", Examples.natFoldId,
.okTyNorm .nat (.succ (.succ .zero)),
"unitElim computes on tt", Examples.unitToNat,
.okTyNorm .nat (.succ (.succ .zero)),
"emptyElim builds absurd maps", Examples.absurdNat,
.okTy (.pi .empty .nat),
"refl inhabits reflexive identity", Examples.reflZero,
.okTyNorm (.id .nat .zero .zero) .refl,
"idElim computes on refl", Examples.idElimNat,
.okTyNorm .nat .zero,
"fst infers the first projection", Examples.fstDepPair, "fst infers the first projection", Examples.fstDepPair,
.okTy (.univ 2), .okTy (.univ 2),
"snd infers the dependent second projection", Examples.sndDepPair, "snd infers the dependent second projection", Examples.sndDepPair,
.okTy (.univ 1), .okTy (.univ 1),
"let infers through definitions", Examples.letUniverse, "let infers through definitions", Examples.letUniverse,
.okTy (.univ 1), .okTy (.univ 1),
"bad succ rejected", Examples.badSucc,
.errContains "type mismatch: expected Nat, got U1",
"self application rejected", Examples.omegaAnn, "self application rejected", Examples.omegaAnn,
.errContains "expected Pi type in application", .errContains "expected Pi type in application",
"unknown variable rejected", Examples.unknownVar, "unknown variable rejected", Examples.unknownVar,
@@ -41,7 +73,9 @@ def cases : List TestCase := [
"pair mismatch rejected at the Sigma body", Examples.pairMismatch, "pair mismatch rejected at the Sigma body", Examples.pairMismatch,
.errContains "type mismatch: expected U1, got U2", .errContains "type mismatch: expected U1, got U2",
"bad fst rejected", Examples.badFst, "bad fst rejected", Examples.badFst,
.errContains "expected Sigma type in .1, got U1" .errContains "expected Sigma type in .1, got U1",
"bad refl rejected", Examples.badRefl,
.errContains "refl cannot inhabit"
] ]
def runCase (tc : TestCase) : IO Bool := do def runCase (tc : TestCase) : IO Bool := do
@@ -58,9 +92,30 @@ def runCase (tc : TestCase) : IO Bool := do
| Except.error err => | Except.error err =>
IO.println s!"FAIL {tc.name} (could not quote type: {err})" IO.println s!"FAIL {tc.name} (could not quote type: {err})"
pure false pure false
| .okTyNorm expectedTy expectedNf, .ok (tm, ty) =>
match renderType ty, renderNormal tm with
| Except.ok actualTy, Except.ok actualNf =>
if actualTy == expectedTy && actualNf == expectedNf then
IO.println s!"PASS {tc.name}"
pure true
else if actualTy != expectedTy then
IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})"
pure false
else
IO.println s!"FAIL {tc.name} (expected nf {BidirTT.prettyTm expectedNf}, got {BidirTT.prettyTm actualNf})"
pure false
| Except.error err, _ =>
IO.println s!"FAIL {tc.name} (could not quote type: {err})"
pure false
| _, Except.error err =>
IO.println s!"FAIL {tc.name} (could not normalize term: {err})"
pure false
| .okTy expectedTy, .error err => | .okTy expectedTy, .error err =>
IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})"
pure false pure false
| .okTyNorm expectedTy _, .error err =>
IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})"
pure false
| .errContains needle, .error err => | .errContains needle, .error err =>
if containsText err needle then if containsText err needle then
IO.println s!"PASS {tc.name}" IO.println s!"PASS {tc.name}"