diff --git a/BidirTT/Check.lean b/BidirTT/Check.lean index ce80f3b..cb0e81f 100644 --- a/BidirTT/Check.lean +++ b/BidirTT/Check.lean @@ -23,6 +23,11 @@ mutual let bodyTy ← cApp c vt let u' ← check cxt u bodyTy pure (.pair t' u') + | .refl, .id a t u => do + if (← conv cxt.lvl t u) then + pure .refl + else + throw s!"refl cannot inhabit {showTy cxt.lvl (.id a t u)} because the endpoints are not definitionally equal" | .letE x a t u, ty => do let (a', _) ← inferUniverse cxt a let va := eval cxt.env a' @@ -33,7 +38,7 @@ mutual pure (.letE a' t' u') | r, ty => do let (t', ty') ← infer cxt r - if (← conv cxt.lvl ty' ty) then + if (← sub cxt.lvl ty' ty) then pure t' else throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}" @@ -43,6 +48,80 @@ mutual match cxt.lookup x with | some (i, a) => pure (.var i, a) | none => throw s!"unknown variable {x}" + | .nat => pure (.nat, .univ 0) + | .zero => pure (.zero, .nat) + | .succ t => do + let t' ← check cxt t .nat + pure (.succ t', .nat) + | .natElim n motive z k ih s scrut => do + let scrut' ← check cxt scrut .nat + let vscrut ← eval cxt.env scrut' + let (motiveBody', level) ← inferUniverse (cxt.bind n .nat) motive + let motive' : Tm := .lam motiveBody' + let vmotive ← eval cxt.env motive' + let zTy ← vApp vmotive .zero + let z' ← check cxt z zTy + let kCxt := cxt.bind k .nat + let ihTy ← vApp vmotive (.var cxt.lvl) + let stepTy ← vApp vmotive (.succ (.var cxt.lvl)) + let stepBody' ← check (kCxt.bind ih ihTy) s stepTy + let step' : Tm := .lam (.lam stepBody') + let resultTy ← vApp vmotive vscrut + let _ := level + pure (.natElim motive' z' step' scrut', resultTy) + | .unit => pure (.unit, .univ 0) + | .triv => pure (.triv, .unit) + | .unitElim u motive t scrut => do + let scrut' ← check cxt scrut .unit + let vscrut ← eval cxt.env scrut' + let (motiveBody', level) ← inferUniverse (cxt.bind u .unit) motive + let motive' : Tm := .lam motiveBody' + let vmotive ← eval cxt.env motive' + let tTy ← vApp vmotive .triv + let t' ← check cxt t tTy + let resultTy ← vApp vmotive vscrut + let _ := level + pure (.unitElim motive' t' scrut', resultTy) + | .empty => pure (.empty, .univ 0) + | .emptyElim e motive scrut => do + let scrut' ← check cxt scrut .empty + let vscrut ← eval cxt.env scrut' + let (motiveBody', level) ← inferUniverse (cxt.bind e .empty) motive + let motive' : Tm := .lam motiveBody' + let vmotive ← eval cxt.env motive' + let resultTy ← vApp vmotive vscrut + let _ := level + pure (.emptyElim motive' scrut', resultTy) + | .id a t u => do + let (a', level) ← inferUniverse cxt a + let va ← eval cxt.env a' + let t' ← check cxt t va + let u' ← check cxt u va + pure (.id a' t' u', .univ level) + | .refl => throw "cannot infer type of refl, use an annotation" + | .idElim y p motive r target eq => do + let (eq', eqTy) ← infer cxt eq + match eqTy with + | .id a x rhs => do + let (target', targetTy) ← infer cxt target + if !(← conv cxt.lvl targetTy a) then + throw s!"idElim target has type {showTy cxt.lvl targetTy}, but the equality lives over {showTy cxt.lvl a}" + let vtarget ← eval cxt.env target' + if !(← conv cxt.lvl vtarget rhs) then + throw s!"idElim target {showTy cxt.lvl vtarget} does not match the equality endpoint {showTy cxt.lvl rhs}" + let eqVarTy : Val := .id a x (.var cxt.lvl) + let (motiveBody', level) ← inferUniverse ((cxt.bind y a).bind p eqVarTy) motive + let motive' : Tm := .lam (.lam motiveBody') + let vmotive ← eval cxt.env motive' + let reflTy ← vApp vmotive x + let reflTy ← vApp reflTy .refl + let r' ← check cxt r reflTy + let veq ← eval cxt.env eq' + let resultTy ← vApp vmotive vtarget + let resultTy ← vApp resultTy veq + let _ := level + pure (.idElim motive' r' target' eq', resultTy) + | _ => throw s!"expected Id type in idElim, got {showTy cxt.lvl eqTy}" | .univ i => pure (.univ i, .univ (i + 1)) | .app t u => do let (t', tty) ← infer cxt t diff --git a/BidirTT/Eval.lean b/BidirTT/Eval.lean index 46e43aa..9d38332 100644 --- a/BidirTT/Eval.lean +++ b/BidirTT/Eval.lean @@ -32,6 +32,41 @@ mutual | env, .snd t => do let vt ← eval env t vSnd vt + | _, .nat => pure .nat + | _, .zero => pure .zero + | env, .succ t => do + let vt ← eval env t + pure (.succ vt) + | env, .natElim m z s n => do + let vm ← eval env m + let vz ← eval env z + let vs ← eval env s + let vn ← eval env n + vNatElim vm vz vs vn + | _, .unit => pure .unit + | _, .triv => pure .triv + | env, .unitElim m t u => do + let vm ← eval env m + let vt ← eval env t + let vu ← eval env u + vUnitElim vm vt vu + | _, .empty => pure .empty + | env, .emptyElim m e => do + let vm ← eval env m + let ve ← eval env e + vEmptyElim vm ve + | env, .id a t u => do + let va ← eval env a + let vt ← eval env t + let vu ← eval env u + pure (.id va vt vu) + | _, .refl => pure .refl + | env, .idElim m r y p => do + let vm ← eval env m + let vr ← eval env r + let vy ← eval env y + let vp ← eval env p + vIdElim vm vr vy vp | _, .univ i => pure (.univ i) | env, .letE _ t u => do let vt ← eval env t @@ -49,6 +84,25 @@ mutual | .pair _ b => pure b | t => pure (.snd t) + partial def vNatElim : Val → Val → Val → Val → EvalM Val + | _, z, _, .zero => pure z + | m, z, s, .succ n => do + let ih ← vNatElim m z s n + let step ← vApp s n + vApp step ih + | m, z, s, n => pure (.natElim m z s n) + + partial def vUnitElim : Val → Val → Val → EvalM Val + | _, t, .triv => pure t + | m, t, u => pure (.unitElim m t u) + + partial def vEmptyElim : Val → Val → EvalM Val + | m, e => pure (.emptyElim m e) + + partial def vIdElim : Val → Val → Val → Val → EvalM Val + | _, r, _, .refl => pure r + | m, r, y, p => pure (.idElim m r y p) + partial def cApp : Closure → Val → EvalM Val | .mk env body, v => eval (v :: env) body end @@ -69,6 +123,41 @@ partial def quote : Lvl → Val → EvalM Tm | l, .snd t => do let qt ← quote l t pure (.snd qt) + | _, .nat => pure .nat + | _, .zero => pure .zero + | l, .succ t => do + let qt ← quote l t + pure (.succ qt) + | l, .natElim m z s n => do + let qm ← quote l m + let qz ← quote l z + let qs ← quote l s + let qn ← quote l n + pure (.natElim qm qz qs qn) + | _, .unit => pure .unit + | _, .triv => pure .triv + | l, .unitElim m t u => do + let qm ← quote l m + let qt ← quote l t + let qu ← quote l u + pure (.unitElim qm qt qu) + | _, .empty => pure .empty + | l, .emptyElim m e => do + let qm ← quote l m + let qe ← quote l e + pure (.emptyElim qm qe) + | l, .id a t u => do + let qa ← quote l a + let qt ← quote l t + let qu ← quote l u + pure (.id qa qt qu) + | _, .refl => pure .refl + | l, .idElim m r y p => do + let qm ← quote l m + let qr ← quote l r + let qy ← quote l y + let qp ← quote l p + pure (.idElim qm qr qy qp) | l, .lam c => do let body ← cApp c (.var l) let qb ← quote (l + 1) body @@ -107,6 +196,51 @@ partial def conv : Lvl → Val → Val → EvalM Bool let b ← cApp c (.var l) let b' ← cApp c' (.var l) conv (l + 1) b b' + | _, .nat, .nat => pure true + | _, .zero, .zero => pure true + | l, .succ n, .succ n' => conv l n n' + | l, .natElim m z s n, .natElim m' z' s' n' => + andThen (conv l m m') fun _ => do + let sameZ ← conv l z z' + if sameZ then + let sameS ← conv l s s' + if sameS then + conv l n n' + else + pure false + else + pure false + | _, .unit, .unit => pure true + | _, .triv, .triv => pure true + | l, .unitElim m t u, .unitElim m' t' u' => + andThen (conv l m m') fun _ => do + let sameT ← conv l t t' + if sameT then + conv l u u' + else + pure false + | _, .empty, .empty => pure true + | l, .emptyElim m e, .emptyElim m' e' => + andThen (conv l m m') fun _ => conv l e e' + | l, .id a t u, .id a' t' u' => + andThen (conv l a a') fun _ => do + let sameT ← conv l t t' + if sameT then + conv l u u' + else + pure false + | _, .refl, .refl => pure true + | l, .idElim m r y p, .idElim m' r' y' p' => + andThen (conv l m m') fun _ => do + let sameR ← conv l r r' + if sameR then + let sameY ← conv l y y' + if sameY then + conv l p p' + else + pure false + else + pure false | l, .lam c, .lam c' => do let body ← cApp c (.var l) @@ -147,4 +281,18 @@ partial def conv : Lvl → Val → Val → EvalM Bool | l, .snd t, .snd t' => conv l t t' | _, _, _ => pure false +partial def sub : Lvl → Val → Val → EvalM Bool + | _, .univ i, .univ j => pure (i <= j) + | l, .pi a c, .pi a' c' => + andThen (sub l a' a) fun _ => do + let b ← cApp c (.var l) + let b' ← cApp c' (.var l) + sub (l + 1) b b' + | l, .sig a c, .sig a' c' => + andThen (sub l a a') fun _ => do + let b ← cApp c (.var l) + let b' ← cApp c' (.var l) + sub (l + 1) b b' + | l, t, t' => conv l t t' + end BidirTT diff --git a/BidirTT/Examples.lean b/BidirTT/Examples.lean index 6d7b52d..e78d998 100644 --- a/BidirTT/Examples.lean +++ b/BidirTT/Examples.lean @@ -46,6 +46,32 @@ def fstDepPair : Raw := .fst depPairAnn def sndDepPair : Raw := .snd depPairAnn +def natTwo : Raw := + .succ (.succ .zero) + +def natFoldId : Raw := + .ann + (.natElim "n" .nat .zero "k" "ih" (.succ (.var "ih")) natTwo) + .nat + +def unitToNat : Raw := + .ann + (.unitElim "u" .nat natTwo .triv) + .nat + +def absurdNat : Raw := + .ann + (.lam "e" (.emptyElim "x" .nat (.var "e"))) + (.pi "e" .empty .nat) + +def reflZero : Raw := + .ann .refl (.id .nat .zero .zero) + +def idElimNat : Raw := + .ann + (.idElim "y" "p" .nat .zero .zero reflZero) + .nat + def omegaTy : Raw := .pi "A" (.univ 0) (.var "A") @@ -67,6 +93,11 @@ def letUniverse : Raw := (.letE "A" (.univ 1) (.pi "_" (.univ 0) (.univ 0)) (.var "A")) (.univ 1) +def badSucc : Raw := .succ (.univ 0) + +def badRefl : Raw := + .ann .refl (.id .nat .zero (.succ .zero)) + def univ0 : Raw := .univ 0 end BidirTT.Examples diff --git a/BidirTT/Pretty.lean b/BidirTT/Pretty.lean index e396206..7425644 100644 --- a/BidirTT/Pretty.lean +++ b/BidirTT/Pretty.lean @@ -12,6 +12,22 @@ mutual | .pair t u => s!"({prettyTm t}, {prettyTm u})" | .fst t => s!"({prettyTm t}.1)" | .snd t => s!"({prettyTm t}.2)" + | .nat => "Nat" + | .zero => "zero" + | .succ t => s!"(succ {prettyTm t})" + | .natElim m z s n => + s!"(natElim {prettyTm m} {prettyTm z} {prettyTm s} {prettyTm n})" + | .unit => "Unit" + | .triv => "tt" + | .unitElim m t u => + s!"(unitElim {prettyTm m} {prettyTm t} {prettyTm u})" + | .empty => "Empty" + | .emptyElim m e => + s!"(emptyElim {prettyTm m} {prettyTm e})" + | .id a t u => s!"(Id {prettyTm a} {prettyTm t} {prettyTm u})" + | .refl => "refl" + | .idElim m r y p => + s!"(idElim {prettyTm m} {prettyTm r} {prettyTm y} {prettyTm p})" | .univ i => s!"U{i}" | .letE a t u => s!"(let : {prettyTm a} := {prettyTm t}; {prettyTm u})" end diff --git a/BidirTT/Syntax.lean b/BidirTT/Syntax.lean index 9bb2437..f0b140a 100644 --- a/BidirTT/Syntax.lean +++ b/BidirTT/Syntax.lean @@ -11,6 +11,18 @@ inductive Raw where | pair : Raw → Raw → Raw | fst : Raw → Raw | snd : Raw → Raw + | nat : Raw + | zero : Raw + | succ : Raw → Raw + | natElim : Name → Raw → Raw → Name → Name → Raw → Raw → Raw + | unit : Raw + | triv : Raw + | unitElim : Name → Raw → Raw → Raw → Raw + | empty : Raw + | emptyElim : Name → Raw → Raw → Raw + | id : Raw → Raw → Raw → Raw + | refl : Raw + | idElim : Name → Name → Raw → Raw → Raw → Raw → Raw | univ : Nat → Raw | letE : Name → Raw → Raw → Raw → Raw | ann : Raw → Raw → Raw @@ -25,6 +37,18 @@ inductive Tm where | pair : Tm → Tm → Tm | fst : Tm → Tm | snd : Tm → Tm + | nat : Tm + | zero : Tm + | succ : Tm → Tm + | natElim : Tm → Tm → Tm → Tm → Tm + | unit : Tm + | triv : Tm + | unitElim : Tm → Tm → Tm → Tm + | empty : Tm + | emptyElim : Tm → Tm → Tm + | id : Tm → Tm → Tm → Tm + | refl : Tm + | idElim : Tm → Tm → Tm → Tm → Tm | univ : Nat → Tm | letE : Tm → Tm → Tm → Tm deriving Repr, Inhabited, BEq, DecidableEq diff --git a/BidirTT/Value.lean b/BidirTT/Value.lean index 9267c64..87f4515 100644 --- a/BidirTT/Value.lean +++ b/BidirTT/Value.lean @@ -12,6 +12,18 @@ mutual | pi : Val → Closure → Val | sig : Val → Closure → Val | pair : Val → Val → Val + | nat : Val + | zero : Val + | succ : Val → Val + | natElim : Val → Val → Val → Val → Val + | unit : Val + | triv : Val + | unitElim : Val → Val → Val → Val + | empty : Val + | emptyElim : Val → Val → Val + | id : Val → Val → Val → Val + | refl : Val + | idElim : Val → Val → Val → Val → Val | univ : Nat → Val inductive Closure where diff --git a/Main.lean b/Main.lean index 2778915..e51a79b 100644 --- a/Main.lean +++ b/Main.lean @@ -23,8 +23,15 @@ def main : IO Unit := do runOne "depPair" Examples.depPairAnn runOne "depPair.1" Examples.fstDepPair runOne "depPair.2" Examples.sndDepPair + runOne "nat fold id" Examples.natFoldId + runOne "unit elim" Examples.unitToNat + runOne "empty absurd" Examples.absurdNat + runOne "refl zero" Examples.reflZero + runOne "id elim" Examples.idElimNat runOne "let universe" Examples.letUniverse runOne "omega (bad)" Examples.omegaAnn runOne "unknown var" Examples.unknownVar runOne "pair mismatch" Examples.pairMismatch runOne "bad fst" Examples.badFst + runOne "bad succ" Examples.badSucc + runOne "bad refl" Examples.badRefl diff --git a/Tests.lean b/Tests.lean index d5e087f..848b803 100644 --- a/Tests.lean +++ b/Tests.lean @@ -4,11 +4,16 @@ open BidirTT inductive Expectation where | okTy : Tm → Expectation + | okTyNorm : Tm → Tm → Expectation | errContains : String → Expectation private def renderType (ty : Val) : Except String Tm := BidirTT.quote 0 ty +private def renderNormal (tm : Tm) : Except String Tm := do + let v ← eval [] tm + quote 0 v + private def containsText (haystack needle : String) : Bool := needle.isEmpty || (haystack.splitOn needle).length > 1 @@ -20,20 +25,47 @@ structure TestCase where def cases : List TestCase := [ ⟨"U0 is typed by U1", Examples.univ0, .okTy (.univ 1)⟩, + ⟨"U0 subsumes into U2", .ann (.univ 0) (.univ 2), + .okTy (.univ 2)⟩, + ⟨"Nat is typed by U0", .nat, + .okTy (.univ 0)⟩, + ⟨"succ zero infers Nat", (.succ .zero), + .okTy .nat⟩, ⟨"id typechecks", Examples.idAnn, .okTy (.pi (.univ 0) (.pi (.var 0) (.var 1)))⟩, + ⟨"Pi subsumption is contravariant in the domain", .ann + (.ann + (.lam "A" (.var "A")) + (.pi "A" (.univ 1) (.univ 1))) + (.pi "A" (.univ 0) (.univ 2)), + .okTy (.pi (.univ 0) (.univ 2))⟩, ⟨"const typechecks", Examples.constAnn, .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.var 1) (.pi (.var 1) (.var 3)))))⟩, ⟨"swap typechecks", Examples.swapAnn, .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.sig (.var 1) (.var 1)) (.sig (.var 1) (.var 3)))))⟩, ⟨"dependent pair typechecks", Examples.depPairAnn, .okTy (.sig (.univ 2) (.var 0))⟩, + ⟨"dependent pair subsumes into a lifted Sigma", .ann Examples.depPairAnn + (.sig "A" (.univ 3) (.var "A")), + .okTy (.sig (.univ 3) (.var 0))⟩, + ⟨"natElim computes on numerals", Examples.natFoldId, + .okTyNorm .nat (.succ (.succ .zero))⟩, + ⟨"unitElim computes on tt", Examples.unitToNat, + .okTyNorm .nat (.succ (.succ .zero))⟩, + ⟨"emptyElim builds absurd maps", Examples.absurdNat, + .okTy (.pi .empty .nat)⟩, + ⟨"refl inhabits reflexive identity", Examples.reflZero, + .okTyNorm (.id .nat .zero .zero) .refl⟩, + ⟨"idElim computes on refl", Examples.idElimNat, + .okTyNorm .nat .zero⟩, ⟨"fst infers the first projection", Examples.fstDepPair, .okTy (.univ 2)⟩, ⟨"snd infers the dependent second projection", Examples.sndDepPair, .okTy (.univ 1)⟩, ⟨"let infers through definitions", Examples.letUniverse, .okTy (.univ 1)⟩, + ⟨"bad succ rejected", Examples.badSucc, + .errContains "type mismatch: expected Nat, got U1"⟩, ⟨"self application rejected", Examples.omegaAnn, .errContains "expected Pi type in application"⟩, ⟨"unknown variable rejected", Examples.unknownVar, @@ -41,7 +73,9 @@ def cases : List TestCase := [ ⟨"pair mismatch rejected at the Sigma body", Examples.pairMismatch, .errContains "type mismatch: expected U1, got U2"⟩, ⟨"bad fst rejected", Examples.badFst, - .errContains "expected Sigma type in .1, got U1"⟩ + .errContains "expected Sigma type in .1, got U1"⟩, + ⟨"bad refl rejected", Examples.badRefl, + .errContains "refl cannot inhabit"⟩ ] def runCase (tc : TestCase) : IO Bool := do @@ -58,9 +92,30 @@ def runCase (tc : TestCase) : IO Bool := do | Except.error err => IO.println s!"FAIL {tc.name} (could not quote type: {err})" pure false + | .okTyNorm expectedTy expectedNf, .ok (tm, ty) => + match renderType ty, renderNormal tm with + | Except.ok actualTy, Except.ok actualNf => + if actualTy == expectedTy && actualNf == expectedNf then + IO.println s!"PASS {tc.name}" + pure true + else if actualTy != expectedTy then + IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" + pure false + else + IO.println s!"FAIL {tc.name} (expected nf {BidirTT.prettyTm expectedNf}, got {BidirTT.prettyTm actualNf})" + pure false + | Except.error err, _ => + IO.println s!"FAIL {tc.name} (could not quote type: {err})" + pure false + | _, Except.error err => + IO.println s!"FAIL {tc.name} (could not normalize term: {err})" + pure false | .okTy expectedTy, .error err => IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" pure false + | .okTyNorm expectedTy _, .error err => + IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" + pure false | .errContains needle, .error err => if containsText err needle then IO.println s!"PASS {tc.name}"