From 963c9f3e9439fd741274ac004144d29dbd3bd420 Mon Sep 17 00:00:00 2001 From: imiel Date: Sun, 19 Apr 2026 15:03:40 +0000 Subject: [PATCH] Split neutrals from canonical values so stuck eliminators arent encoded via constructor overloading --- BidirTT/Check.lean | 8 +- BidirTT/Context.lean | 2 +- BidirTT/Eval.lean | 361 ++++++++++++++++++++++--------------------- BidirTT/Value.lean | 20 ++- Tests.lean | 61 +++++--- 5 files changed, 245 insertions(+), 207 deletions(-) diff --git a/BidirTT/Check.lean b/BidirTT/Check.lean index cb0e81f..cf7d820 100644 --- a/BidirTT/Check.lean +++ b/BidirTT/Check.lean @@ -14,7 +14,7 @@ private def showTy (l : Lvl) (v : Val) : String := mutual partial def check (cxt : Cxt) : Raw → Val → TCM Tm | .lam x t, .pi a c => do - let bodyTy ← cApp c (.var cxt.lvl) + let bodyTy ← cApp c (.neu (.var cxt.lvl)) let t' ← check (cxt.bind x a) t bodyTy pure (.lam t') | .pair t u, .sig a c => do @@ -62,8 +62,8 @@ mutual let zTy ← vApp vmotive .zero let z' ← check cxt z zTy let kCxt := cxt.bind k .nat - let ihTy ← vApp vmotive (.var cxt.lvl) - let stepTy ← vApp vmotive (.succ (.var cxt.lvl)) + let ihTy ← vApp vmotive (.neu (.var cxt.lvl)) + let stepTy ← vApp vmotive (.succ (.neu (.var cxt.lvl))) let stepBody' ← check (kCxt.bind ih ihTy) s stepTy let step' : Tm := .lam (.lam stepBody') let resultTy ← vApp vmotive vscrut @@ -109,7 +109,7 @@ mutual let vtarget ← eval cxt.env target' if !(← conv cxt.lvl vtarget rhs) then throw s!"idElim target {showTy cxt.lvl vtarget} does not match the equality endpoint {showTy cxt.lvl rhs}" - let eqVarTy : Val := .id a x (.var cxt.lvl) + let eqVarTy : Val := .id a x (.neu (.var cxt.lvl)) let (motiveBody', level) ← inferUniverse ((cxt.bind y a).bind p eqVarTy) motive let motive' : Tm := .lam (.lam motiveBody') let vmotive ← eval cxt.env motive' diff --git a/BidirTT/Context.lean b/BidirTT/Context.lean index abf70dc..c5865d3 100644 --- a/BidirTT/Context.lean +++ b/BidirTT/Context.lean @@ -11,7 +11,7 @@ structure Cxt where def Cxt.empty : Cxt := ⟨[], [], 0⟩ def Cxt.bind (cxt : Cxt) (x : Name) (a : Val) : Cxt := - { env := .var cxt.lvl :: cxt.env + { env := .neu (.var cxt.lvl) :: cxt.env , types := (x, a) :: cxt.types , lvl := cxt.lvl + 1 } diff --git a/BidirTT/Eval.lean b/BidirTT/Eval.lean index 9d38332..e9cd58b 100644 --- a/BidirTT/Eval.lean +++ b/BidirTT/Eval.lean @@ -73,16 +73,19 @@ mutual eval (vt :: env) u partial def vApp : Val → Val → EvalM Val - | .lam c, u => cApp c u - | t, u => pure (.app t u) + | .lam c, u => cApp c u + | .neu t, u => pure (.neu (.app t u)) + | _, _ => throw "bad application head during evaluation" partial def vFst : Val → EvalM Val | .pair a _ => pure a - | t => pure (.fst t) + | .neu t => pure (.neu (.fst t)) + | _ => throw "bad fst projection during evaluation" partial def vSnd : Val → EvalM Val | .pair _ b => pure b - | t => pure (.snd t) + | .neu t => pure (.neu (.snd t)) + | _ => throw "bad snd projection during evaluation" partial def vNatElim : Val → Val → Val → Val → EvalM Val | _, z, _, .zero => pure z @@ -90,93 +93,102 @@ mutual let ih ← vNatElim m z s n let step ← vApp s n vApp step ih - | m, z, s, n => pure (.natElim m z s n) + | m, z, s, .neu n => pure (.neu (.natElim m z s n)) + | _, _, _, _ => throw "bad Nat eliminand during evaluation" partial def vUnitElim : Val → Val → Val → EvalM Val - | _, t, .triv => pure t - | m, t, u => pure (.unitElim m t u) + | _, t, .triv => pure t + | m, t, .neu u => pure (.neu (.unitElim m t u)) + | _, _, _ => throw "bad Unit eliminand during evaluation" partial def vEmptyElim : Val → Val → EvalM Val - | m, e => pure (.emptyElim m e) + | m, .neu e => pure (.neu (.emptyElim m e)) + | _, _ => throw "bad Empty eliminand during evaluation" partial def vIdElim : Val → Val → Val → Val → EvalM Val | _, r, _, .refl => pure r - | m, r, y, p => pure (.idElim m r y p) + | m, r, y, .neu p => pure (.neu (.idElim m r y p)) + | _, _, _, _ => throw "bad Id eliminand during evaluation" partial def cApp : Closure → Val → EvalM Val | .mk env body, v => eval (v :: env) body end -partial def quote : Lvl → Val → EvalM Tm - | l, .var x => - if x < l then - pure (.var (l - x - 1)) - else - throw s!"bad level {x} while quoting at level {l}" - | l, .app t u => do - let qt ← quote l t - let qu ← quote l u - pure (.app qt qu) - | l, .fst t => do - let qt ← quote l t - pure (.fst qt) - | l, .snd t => do - let qt ← quote l t - pure (.snd qt) - | _, .nat => pure .nat - | _, .zero => pure .zero - | l, .succ t => do - let qt ← quote l t - pure (.succ qt) - | l, .natElim m z s n => do - let qm ← quote l m - let qz ← quote l z - let qs ← quote l s - let qn ← quote l n - pure (.natElim qm qz qs qn) - | _, .unit => pure .unit - | _, .triv => pure .triv - | l, .unitElim m t u => do - let qm ← quote l m - let qt ← quote l t - let qu ← quote l u - pure (.unitElim qm qt qu) - | _, .empty => pure .empty - | l, .emptyElim m e => do - let qm ← quote l m - let qe ← quote l e - pure (.emptyElim qm qe) - | l, .id a t u => do - let qa ← quote l a - let qt ← quote l t - let qu ← quote l u - pure (.id qa qt qu) - | _, .refl => pure .refl - | l, .idElim m r y p => do - let qm ← quote l m - let qr ← quote l r - let qy ← quote l y - let qp ← quote l p - pure (.idElim qm qr qy qp) - | l, .lam c => do - let body ← cApp c (.var l) - let qb ← quote (l + 1) body - pure (.lam qb) - | l, .pi a c => do - let qa ← quote l a - let body ← cApp c (.var l) - let qb ← quote (l + 1) body - pure (.pi qa qb) - | l, .sig a c => do - let qa ← quote l a - let body ← cApp c (.var l) - let qb ← quote (l + 1) body - pure (.sig qa qb) - | l, .pair a b => do - let qa ← quote l a - let qb ← quote l b - pure (.pair qa qb) - | _, .univ i => pure (.univ i) +mutual + partial def quoteNeutral : Lvl → Neutral → EvalM Tm + | l, .var x => + if x < l then + pure (.var (l - x - 1)) + else + throw s!"bad level {x} while quoting at level {l}" + | l, .app t u => do + let qt ← quoteNeutral l t + let qu ← quote l u + pure (.app qt qu) + | l, .fst t => do + let qt ← quoteNeutral l t + pure (.fst qt) + | l, .snd t => do + let qt ← quoteNeutral l t + pure (.snd qt) + | l, .natElim m z s n => do + let qm ← quote l m + let qz ← quote l z + let qs ← quote l s + let qn ← quoteNeutral l n + pure (.natElim qm qz qs qn) + | l, .unitElim m t u => do + let qm ← quote l m + let qt ← quote l t + let qu ← quoteNeutral l u + pure (.unitElim qm qt qu) + | l, .emptyElim m e => do + let qm ← quote l m + let qe ← quoteNeutral l e + pure (.emptyElim qm qe) + | l, .idElim m r y p => do + let qm ← quote l m + let qr ← quote l r + let qy ← quote l y + let qp ← quoteNeutral l p + pure (.idElim qm qr qy qp) + + partial def quote : Lvl → Val → EvalM Tm + | l, .neu n => quoteNeutral l n + | l, .lam c => do + let body ← cApp c (.neu (.var l)) + let qb ← quote (l + 1) body + pure (.lam qb) + | l, .pi a c => do + let qa ← quote l a + let body ← cApp c (.neu (.var l)) + let qb ← quote (l + 1) body + pure (.pi qa qb) + | l, .sig a c => do + let qa ← quote l a + let body ← cApp c (.neu (.var l)) + let qb ← quote (l + 1) body + pure (.sig qa qb) + | l, .pair a b => do + let qa ← quote l a + let qb ← quote l b + pure (.pair qa qb) + | _, .nat => pure .nat + | _, .zero => pure .zero + | l, .succ t => do + let qt ← quote l t + pure (.succ qt) + | _, .unit => pure .unit + | _, .triv => pure .triv + | _, .empty => pure .empty + | l, .id a t u => do + let qa ← quote l a + let qt ← quote l t + let qu ← quote l u + pure (.id qa qt qu) + | _, .refl => pure .refl + | _, .univ i => pure (.univ i) +end private def andThen (lhs : EvalM Bool) (rhs : Unit → EvalM Bool) : EvalM Bool := do if (← lhs) then @@ -184,114 +196,119 @@ private def andThen (lhs : EvalM Bool) (rhs : Unit → EvalM Bool) : EvalM Bool else pure false -partial def conv : Lvl → Val → Val → EvalM Bool - | _, .univ i, .univ j => pure (i == j) - | l, .pi a c, .pi a' c' => - andThen (conv l a a') fun _ => do - let b ← cApp c (.var l) - let b' ← cApp c' (.var l) - conv (l + 1) b b' - | l, .sig a c, .sig a' c' => - andThen (conv l a a') fun _ => do - let b ← cApp c (.var l) - let b' ← cApp c' (.var l) - conv (l + 1) b b' - | _, .nat, .nat => pure true - | _, .zero, .zero => pure true - | l, .succ n, .succ n' => conv l n n' - | l, .natElim m z s n, .natElim m' z' s' n' => - andThen (conv l m m') fun _ => do - let sameZ ← conv l z z' - if sameZ then - let sameS ← conv l s s' - if sameS then - conv l n n' +mutual + partial def convNeutral : Lvl → Neutral → Neutral → EvalM Bool + | _, .var x, .var y => pure (x == y) + | l, .app t u, .app t' u' => + andThen (convNeutral l t t') fun _ => conv l u u' + | l, .fst t, .fst t' => + convNeutral l t t' + | l, .snd t, .snd t' => + convNeutral l t t' + | l, .natElim m z s n, .natElim m' z' s' n' => + andThen (conv l m m') fun _ => do + let sameZ ← conv l z z' + if sameZ then + let sameS ← conv l s s' + if sameS then + convNeutral l n n' + else + pure false else pure false - else - pure false - | _, .unit, .unit => pure true - | _, .triv, .triv => pure true - | l, .unitElim m t u, .unitElim m' t' u' => - andThen (conv l m m') fun _ => do - let sameT ← conv l t t' - if sameT then - conv l u u' - else - pure false - | _, .empty, .empty => pure true - | l, .emptyElim m e, .emptyElim m' e' => - andThen (conv l m m') fun _ => conv l e e' - | l, .id a t u, .id a' t' u' => - andThen (conv l a a') fun _ => do - let sameT ← conv l t t' - if sameT then - conv l u u' - else - pure false - | _, .refl, .refl => pure true - | l, .idElim m r y p, .idElim m' r' y' p' => - andThen (conv l m m') fun _ => do - let sameR ← conv l r r' - if sameR then - let sameY ← conv l y y' - if sameY then - conv l p p' + | l, .unitElim m t u, .unitElim m' t' u' => + andThen (conv l m m') fun _ => do + let sameT ← conv l t t' + if sameT then + convNeutral l u u' else pure false - else - pure false - | l, .lam c, .lam c' => - do - let body ← cApp c (.var l) - let body' ← cApp c' (.var l) + | l, .emptyElim m e, .emptyElim m' e' => + andThen (conv l m m') fun _ => convNeutral l e e' + | l, .idElim m r y p, .idElim m' r' y' p' => + andThen (conv l m m') fun _ => do + let sameR ← conv l r r' + if sameR then + let sameY ← conv l y y' + if sameY then + convNeutral l p p' + else + pure false + else + pure false + | _, _, _ => pure false + + partial def conv : Lvl → Val → Val → EvalM Bool + | _, .univ i, .univ j => pure (i == j) + | l, .pi a c, .pi a' c' => + andThen (conv l a a') fun _ => do + let b ← cApp c (.neu (.var l)) + let b' ← cApp c' (.neu (.var l)) + conv (l + 1) b b' + | l, .sig a c, .sig a' c' => + andThen (conv l a a') fun _ => do + let b ← cApp c (.neu (.var l)) + let b' ← cApp c' (.neu (.var l)) + conv (l + 1) b b' + | _, .nat, .nat => pure true + | _, .zero, .zero => pure true + | l, .succ n, .succ n' => conv l n n' + | _, .unit, .unit => pure true + | _, .triv, .triv => pure true + | _, .empty, .empty => pure true + | l, .id a t u, .id a' t' u' => + andThen (conv l a a') fun _ => do + let sameT ← conv l t t' + if sameT then + conv l u u' + else + pure false + | _, .refl, .refl => pure true + | l, .lam c, .lam c' => do + let body ← cApp c (.neu (.var l)) + let body' ← cApp c' (.neu (.var l)) conv (l + 1) body body' - | l, .lam c, t => - do - let body ← cApp c (.var l) - let apped ← vApp t (.var l) + | l, .lam c, t => do + let body ← cApp c (.neu (.var l)) + let apped ← vApp t (.neu (.var l)) conv (l + 1) body apped - | l, t, .lam c => - do - let apped ← vApp t (.var l) - let body ← cApp c (.var l) + | l, t, .lam c => do + let apped ← vApp t (.neu (.var l)) + let body ← cApp c (.neu (.var l)) conv (l + 1) apped body - | l, .pair a b, .pair a' b' => - andThen (conv l a a') fun _ => conv l b b' - | l, .pair a b, p => - andThen - (do - let fstp ← vFst p - conv l a fstp) - fun _ => do - let sndp ← vSnd p - conv l b sndp - | l, p, .pair a b => - andThen - (do - let fstp ← vFst p - conv l fstp a) - fun _ => do - let sndp ← vSnd p - conv l sndp b - | _, .var x, .var y => pure (x == y) - | l, .app t u, .app t' u' => - andThen (conv l t t') fun _ => conv l u u' - | l, .fst t, .fst t' => conv l t t' - | l, .snd t, .snd t' => conv l t t' - | _, _, _ => pure false + | l, .pair a b, .pair a' b' => + andThen (conv l a a') fun _ => conv l b b' + | l, .pair a b, p => + andThen + (do + let fstp ← vFst p + conv l a fstp) + fun _ => do + let sndp ← vSnd p + conv l b sndp + | l, p, .pair a b => + andThen + (do + let fstp ← vFst p + conv l fstp a) + fun _ => do + let sndp ← vSnd p + conv l sndp b + | l, .neu n, .neu n' => convNeutral l n n' + | _, _, _ => pure false +end partial def sub : Lvl → Val → Val → EvalM Bool | _, .univ i, .univ j => pure (i <= j) | l, .pi a c, .pi a' c' => andThen (sub l a' a) fun _ => do - let b ← cApp c (.var l) - let b' ← cApp c' (.var l) + let b ← cApp c (.neu (.var l)) + let b' ← cApp c' (.neu (.var l)) sub (l + 1) b b' | l, .sig a c, .sig a' c' => andThen (sub l a a') fun _ => do - let b ← cApp c (.var l) - let b' ← cApp c' (.var l) + let b ← cApp c (.neu (.var l)) + let b' ← cApp c' (.neu (.var l)) sub (l + 1) b b' | l, t, t' => conv l t t' diff --git a/BidirTT/Value.lean b/BidirTT/Value.lean index 87f4515..220a01e 100644 --- a/BidirTT/Value.lean +++ b/BidirTT/Value.lean @@ -3,11 +3,18 @@ import BidirTT.Syntax namespace BidirTT mutual + inductive Neutral where + | var : Nat → Neutral + | app : Neutral → Val → Neutral + | fst : Neutral → Neutral + | snd : Neutral → Neutral + | natElim : Val → Val → Val → Neutral → Neutral + | unitElim : Val → Val → Neutral → Neutral + | emptyElim : Val → Neutral → Neutral + | idElim : Val → Val → Val → Neutral → Neutral + inductive Val where - | var : Nat → Val - | app : Val → Val → Val - | fst : Val → Val - | snd : Val → Val + | neu : Neutral → Val | lam : Closure → Val | pi : Val → Closure → Val | sig : Val → Closure → Val @@ -15,15 +22,11 @@ mutual | nat : Val | zero : Val | succ : Val → Val - | natElim : Val → Val → Val → Val → Val | unit : Val | triv : Val - | unitElim : Val → Val → Val → Val | empty : Val - | emptyElim : Val → Val → Val | id : Val → Val → Val → Val | refl : Val - | idElim : Val → Val → Val → Val → Val | univ : Nat → Val inductive Closure where @@ -33,6 +36,7 @@ end abbrev Env := List Val abbrev Lvl := Nat +instance : Inhabited Neutral := ⟨.var 0⟩ instance : Inhabited Val := ⟨.univ 0⟩ instance : Inhabited Closure := ⟨.mk [] (.univ 0)⟩ diff --git a/Tests.lean b/Tests.lean index 1ab4541..f45672b 100644 --- a/Tests.lean +++ b/Tests.lean @@ -84,52 +84,52 @@ def runCase (tc : TestCase) : IO Bool := do match renderType ty with | Except.ok actualTy => if actualTy == expectedTy then - IO.println s!"PASS {tc.name}" + IO.println s!"pass {tc.name}" pure true else - IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" + IO.println s!"fail {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" pure false | Except.error err => - IO.println s!"FAIL {tc.name} (could not quote type: {err})" + IO.println s!"fail {tc.name} (could not quote type: {err})" pure false | .okTyNorm expectedTy expectedNf, .ok (tm, ty) => match renderType ty, renderNormal tm with | Except.ok actualTy, Except.ok actualNf => if actualTy == expectedTy && actualNf == expectedNf then - IO.println s!"PASS {tc.name}" + IO.println s!"pass {tc.name}" pure true else if actualTy != expectedTy then - IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" + IO.println s!"fail {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" pure false else - IO.println s!"FAIL {tc.name} (expected nf {BidirTT.prettyTm expectedNf}, got {BidirTT.prettyTm actualNf})" + IO.println s!"fail {tc.name} (expected nf {BidirTT.prettyTm expectedNf}, got {BidirTT.prettyTm actualNf})" pure false | Except.error err, _ => - IO.println s!"FAIL {tc.name} (could not quote type: {err})" + IO.println s!"fail {tc.name} (could not quote type: {err})" pure false | _, Except.error err => - IO.println s!"FAIL {tc.name} (could not normalize term: {err})" + IO.println s!"fail {tc.name} (could not normalize term: {err})" pure false | .okTy expectedTy, .error err => - IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" + IO.println s!"fail {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" pure false | .okTyNorm expectedTy _, .error err => - IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" + IO.println s!"fail {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" pure false | .errContains needle, .error err => if containsText err needle then - IO.println s!"PASS {tc.name}" + IO.println s!"pass {tc.name}" pure true else - IO.println s!"FAIL {tc.name} (expected error containing {needle}, got {err})" + IO.println s!"fail {tc.name} (expected error containing {needle}, got {err})" pure false | .errContains needle, .ok (_, ty) => match renderType ty with | Except.ok actualTy => - IO.println s!"FAIL {tc.name} (expected error containing {needle}, got type {BidirTT.prettyTm actualTy})" + IO.println s!"fail {tc.name} (expected error containing {needle}, got type {BidirTT.prettyTm actualTy})" pure false | Except.error err => - IO.println s!"FAIL {tc.name} (expected error containing {needle}, got quote failure {err})" + IO.println s!"fail {tc.name} (expected error containing {needle}, got quote failure {err})" pure false def runInternalSafetyChecks : IO Bool := do @@ -138,14 +138,30 @@ def runInternalSafetyChecks : IO Bool := do | Except.error err => containsText err "bad de Bruijn index 0" | Except.ok _ => false let malformedQuoteOk := - match quote 0 (.var 0) with + match quote 0 (.neu (.var 0)) with | Except.error err => containsText err "bad level 0" | Except.ok _ => false if malformedEvalOk && malformedQuoteOk then - IO.println "PASS malformed core terms are rejected safely" + IO.println "pass malformed core terms are rejected safely" pure true else - IO.println "FAIL malformed core terms are rejected safely" + IO.println "fail malformed core terms are rejected safely" + pure false + +def runNeutralRepresentationChecks : IO Bool := do + let stuckAppOk := + match vApp (.neu (.var 0)) .zero with + | Except.ok (.neu (.app (.var 0) .zero)) => true + | _ => false + let stuckNatElimOk := + match vNatElim .nat .zero (.neu (.var 1)) (.neu (.var 0)) with + | Except.ok (.neu (.natElim .nat .zero (.neu (.var 1)) (.var 0))) => true + | _ => false + if stuckAppOk && stuckNatElimOk then + IO.println "pass stuck eliminators stay in the neutral fragment" + pure true + else + IO.println "fail stuck eliminators stay in the neutral fragment" pure false def runPrettyPrinterChecks : IO Bool := do @@ -160,23 +176,24 @@ def runPrettyPrinterChecks : IO Bool := do containsText typeText "Pi (x0 : U0)" && !containsText typeText "#" if ok then - IO.println "PASS pretty printer rehydrates binder names" + IO.println "pass pretty printer rehydrates binder names" pure true else - IO.println s!"FAIL pretty printer rehydrates binder names (term: {termText}, type: {typeText})" + IO.println s!"fail pretty printer rehydrates binder names (term: {termText}, type: {typeText})" pure false | Except.error err => - IO.println s!"FAIL pretty printer rehydrates binder names (could not quote type: {err})" + IO.println s!"fail pretty printer rehydrates binder names (could not quote type: {err})" pure false | .error err => - IO.println s!"FAIL pretty printer rehydrates binder names (could not elaborate fixture: {err})" + IO.println s!"fail pretty printer rehydrates binder names (could not elaborate fixture: {err})" pure false def main : IO UInt32 := do let results ← cases.mapM runCase let safetyOk ← runInternalSafetyChecks + let neutralOk ← runNeutralRepresentationChecks let prettyOk ← runPrettyPrinterChecks - let allResults := results ++ [safetyOk, prettyOk] + let allResults := results ++ [safetyOk, neutralOk, prettyOk] let failed := allResults.countP (· == false) if failed == 0 then IO.println s!"\n{allResults.length} passed, 0 failed"