From a154e2b98c0b1a74766eeee4a557751e6ad5de8d Mon Sep 17 00:00:00 2001 From: imiel Date: Sun, 19 Apr 2026 04:17:45 +0000 Subject: [PATCH] Initial --- .gitignore | 4 ++ BidirTT.lean | 7 ++ BidirTT/Check.lean | 105 +++++++++++++++++++++++++++++ BidirTT/Context.lean | 30 +++++++++ BidirTT/Eval.lean | 150 ++++++++++++++++++++++++++++++++++++++++++ BidirTT/Examples.lean | 72 ++++++++++++++++++++ BidirTT/Pretty.lean | 19 ++++++ BidirTT/Syntax.lean | 32 +++++++++ BidirTT/Value.lean | 27 ++++++++ Main.lean | 30 +++++++++ README.md | 17 +++++ Tests.lean | 106 +++++++++++++++++++++++++++++ lakefile.lean | 17 +++++ lean-toolchain | 1 + 14 files changed, 617 insertions(+) create mode 100644 .gitignore create mode 100644 BidirTT.lean create mode 100644 BidirTT/Check.lean create mode 100644 BidirTT/Context.lean create mode 100644 BidirTT/Eval.lean create mode 100644 BidirTT/Examples.lean create mode 100644 BidirTT/Pretty.lean create mode 100644 BidirTT/Syntax.lean create mode 100644 BidirTT/Value.lean create mode 100644 Main.lean create mode 100644 README.md create mode 100644 Tests.lean create mode 100644 lakefile.lean create mode 100644 lean-toolchain diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..a76ef6a --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +.lake/ +build/ +lakefile.olean +lake-manifest.json diff --git a/BidirTT.lean b/BidirTT.lean new file mode 100644 index 0000000..8d5226d --- /dev/null +++ b/BidirTT.lean @@ -0,0 +1,7 @@ +import BidirTT.Syntax +import BidirTT.Value +import BidirTT.Pretty +import BidirTT.Eval +import BidirTT.Context +import BidirTT.Check +import BidirTT.Examples diff --git a/BidirTT/Check.lean b/BidirTT/Check.lean new file mode 100644 index 0000000..ce80f3b --- /dev/null +++ b/BidirTT/Check.lean @@ -0,0 +1,105 @@ +import BidirTT.Context +import BidirTT.Eval +import BidirTT.Pretty + +namespace BidirTT + +abbrev TCM := Except String + +private def showTy (l : Lvl) (v : Val) : String := + match quote l v with + | Except.ok tm => prettyTm tm + | Except.error err => s!"" + +mutual + partial def check (cxt : Cxt) : Raw → Val → TCM Tm + | .lam x t, .pi a c => do + let bodyTy ← cApp c (.var cxt.lvl) + let t' ← check (cxt.bind x a) t bodyTy + pure (.lam t') + | .pair t u, .sig a c => do + let t' ← check cxt t a + let vt ← eval cxt.env t' + let bodyTy ← cApp c vt + let u' ← check cxt u bodyTy + pure (.pair t' u') + | .letE x a t u, ty => do + let (a', _) ← inferUniverse cxt a + let va := eval cxt.env a' + let va ← va + let t' ← check cxt t va + let vt ← eval cxt.env t' + let u' ← check (cxt.define x vt va) u ty + pure (.letE a' t' u') + | r, ty => do + let (t', ty') ← infer cxt r + if (← conv cxt.lvl ty' ty) then + pure t' + else + throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}" + + partial def infer (cxt : Cxt) : Raw → TCM (Tm × Val) + | .var x => + match cxt.lookup x with + | some (i, a) => pure (.var i, a) + | none => throw s!"unknown variable {x}" + | .univ i => pure (.univ i, .univ (i + 1)) + | .app t u => do + let (t', tty) ← infer cxt t + match tty with + | .pi a c => do + let u' ← check cxt u a + let vu ← eval cxt.env u' + let bodyTy ← cApp c vu + pure (.app t' u', bodyTy) + | _ => throw s!"expected Pi type in application, got {showTy cxt.lvl tty}" + | .fst t => do + let (t', tty) ← infer cxt t + match tty with + | .sig a _ => pure (.fst t', a) + | _ => throw s!"expected Sigma type in .1, got {showTy cxt.lvl tty}" + | .snd t => do + let (t', tty) ← infer cxt t + match tty with + | .sig _ c => do + let vt ← eval cxt.env t' + let fstv ← vFst vt + let bodyTy ← cApp c fstv + pure (.snd t', bodyTy) + | _ => throw s!"expected Sigma type in .2, got {showTy cxt.lvl tty}" + | .pi x a b => do + let (a', i) ← inferUniverse cxt a + let va ← eval cxt.env a' + let (b', j) ← inferUniverse (cxt.bind x va) b + pure (.pi a' b', .univ (Nat.max i j)) + | .sig x a b => do + let (a', i) ← inferUniverse cxt a + let va ← eval cxt.env a' + let (b', j) ← inferUniverse (cxt.bind x va) b + pure (.sig a' b', .univ (Nat.max i j)) + | .ann t a => do + let (a', _) ← inferUniverse cxt a + let va ← eval cxt.env a' + let t' ← check cxt t va + pure (t', va) + | .letE x a t u => do + let (a', _) ← inferUniverse cxt a + let va ← eval cxt.env a' + let t' ← check cxt t va + let vt ← eval cxt.env t' + let (u', uty) ← infer (cxt.define x vt va) u + pure (.letE a' t' u', uty) + | .lam _ _ => throw "cannot infer type of lambda, use an annotation" + | .pair _ _ => throw "cannot infer type of pair, use an annotation" + + partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Nat) := do + let (t, ty) ← infer cxt r + match ty with + | .univ level => pure (t, level) + | _ => throw s!"expected a universe, got {showTy cxt.lvl ty}" +end + +def checkTop (r : Raw) : TCM (Tm × Val) := + infer Cxt.empty r + +end BidirTT diff --git a/BidirTT/Context.lean b/BidirTT/Context.lean new file mode 100644 index 0000000..abf70dc --- /dev/null +++ b/BidirTT/Context.lean @@ -0,0 +1,30 @@ +import BidirTT.Value + +namespace BidirTT + +structure Cxt where + env : Env + types : List (Name × Val) + lvl : Lvl + deriving Inhabited + +def Cxt.empty : Cxt := ⟨[], [], 0⟩ + +def Cxt.bind (cxt : Cxt) (x : Name) (a : Val) : Cxt := + { env := .var cxt.lvl :: cxt.env + , types := (x, a) :: cxt.types + , lvl := cxt.lvl + 1 } + +def Cxt.define (cxt : Cxt) (x : Name) (v a : Val) : Cxt := + { env := v :: cxt.env + , types := (x, a) :: cxt.types + , lvl := cxt.lvl + 1 } + +private def lookupGo : Name → List (Name × Val) → Nat → Option (Nat × Val) + | _, [], _ => none + | x, (y, a) :: rest, i => if x == y then some (i, a) else lookupGo x rest (i+1) + +def Cxt.lookup (cxt : Cxt) (x : Name) : Option (Nat × Val) := + lookupGo x cxt.types 0 + +end BidirTT diff --git a/BidirTT/Eval.lean b/BidirTT/Eval.lean new file mode 100644 index 0000000..46e43aa --- /dev/null +++ b/BidirTT/Eval.lean @@ -0,0 +1,150 @@ +import BidirTT.Value + +namespace BidirTT + +abbrev EvalM := Except String + +mutual + partial def eval : Env → Tm → EvalM Val + | env, .var i => + match env[i]? with + | some v => pure v + | none => + throw s!"bad de Bruijn index {i} in environment of size {env.length}" + | env, .lam t => pure (.lam (.mk env t)) + | env, .app t u => do + let vt ← eval env t + let vu ← eval env u + vApp vt vu + | env, .pi a b => do + let va ← eval env a + pure (.pi va (.mk env b)) + | env, .sig a b => do + let va ← eval env a + pure (.sig va (.mk env b)) + | env, .pair t u => do + let vt ← eval env t + let vu ← eval env u + pure (.pair vt vu) + | env, .fst t => do + let vt ← eval env t + vFst vt + | env, .snd t => do + let vt ← eval env t + vSnd vt + | _, .univ i => pure (.univ i) + | env, .letE _ t u => do + let vt ← eval env t + eval (vt :: env) u + + partial def vApp : Val → Val → EvalM Val + | .lam c, u => cApp c u + | t, u => pure (.app t u) + + partial def vFst : Val → EvalM Val + | .pair a _ => pure a + | t => pure (.fst t) + + partial def vSnd : Val → EvalM Val + | .pair _ b => pure b + | t => pure (.snd t) + + partial def cApp : Closure → Val → EvalM Val + | .mk env body, v => eval (v :: env) body +end + +partial def quote : Lvl → Val → EvalM Tm + | l, .var x => + if x < l then + pure (.var (l - x - 1)) + else + throw s!"bad level {x} while quoting at level {l}" + | l, .app t u => do + let qt ← quote l t + let qu ← quote l u + pure (.app qt qu) + | l, .fst t => do + let qt ← quote l t + pure (.fst qt) + | l, .snd t => do + let qt ← quote l t + pure (.snd qt) + | l, .lam c => do + let body ← cApp c (.var l) + let qb ← quote (l + 1) body + pure (.lam qb) + | l, .pi a c => do + let qa ← quote l a + let body ← cApp c (.var l) + let qb ← quote (l + 1) body + pure (.pi qa qb) + | l, .sig a c => do + let qa ← quote l a + let body ← cApp c (.var l) + let qb ← quote (l + 1) body + pure (.sig qa qb) + | l, .pair a b => do + let qa ← quote l a + let qb ← quote l b + pure (.pair qa qb) + | _, .univ i => pure (.univ i) + +private def andThen (lhs : EvalM Bool) (rhs : Unit → EvalM Bool) : EvalM Bool := do + if (← lhs) then + rhs () + else + pure false + +partial def conv : Lvl → Val → Val → EvalM Bool + | _, .univ i, .univ j => pure (i == j) + | l, .pi a c, .pi a' c' => + andThen (conv l a a') fun _ => do + let b ← cApp c (.var l) + let b' ← cApp c' (.var l) + conv (l + 1) b b' + | l, .sig a c, .sig a' c' => + andThen (conv l a a') fun _ => do + let b ← cApp c (.var l) + let b' ← cApp c' (.var l) + conv (l + 1) b b' + | l, .lam c, .lam c' => + do + let body ← cApp c (.var l) + let body' ← cApp c' (.var l) + conv (l + 1) body body' + | l, .lam c, t => + do + let body ← cApp c (.var l) + let apped ← vApp t (.var l) + conv (l + 1) body apped + | l, t, .lam c => + do + let apped ← vApp t (.var l) + let body ← cApp c (.var l) + conv (l + 1) apped body + | l, .pair a b, .pair a' b' => + andThen (conv l a a') fun _ => conv l b b' + | l, .pair a b, p => + andThen + (do + let fstp ← vFst p + conv l a fstp) + fun _ => do + let sndp ← vSnd p + conv l b sndp + | l, p, .pair a b => + andThen + (do + let fstp ← vFst p + conv l fstp a) + fun _ => do + let sndp ← vSnd p + conv l sndp b + | _, .var x, .var y => pure (x == y) + | l, .app t u, .app t' u' => + andThen (conv l t t') fun _ => conv l u u' + | l, .fst t, .fst t' => conv l t t' + | l, .snd t, .snd t' => conv l t t' + | _, _, _ => pure false + +end BidirTT diff --git a/BidirTT/Examples.lean b/BidirTT/Examples.lean new file mode 100644 index 0000000..6d7b52d --- /dev/null +++ b/BidirTT/Examples.lean @@ -0,0 +1,72 @@ +import BidirTT.Syntax + +namespace BidirTT.Examples + +open BidirTT + +def idTy : Raw := + .pi "A" (.univ 0) (.pi "_" (.var "A") (.var "A")) + +def idTm : Raw := + .lam "A" (.lam "x" (.var "x")) + +def idAnn : Raw := .ann idTm idTy + +def constTy : Raw := + .pi "A" (.univ 0) (.pi "B" (.univ 0) + (.pi "_" (.var "A") (.pi "_" (.var "B") (.var "A")))) + +def constTm : Raw := + .lam "A" (.lam "B" (.lam "x" (.lam "_" (.var "x")))) + +def constAnn : Raw := .ann constTm constTy + +def swapTy : Raw := + .pi "A" (.univ 0) (.pi "B" (.univ 0) + (.pi "_" (.sig "_" (.var "A") (.var "B")) + (.sig "_" (.var "B") (.var "A")))) + +def swapTm : Raw := + .lam "A" (.lam "B" (.lam "p" + (.pair (.snd (.var "p")) (.fst (.var "p"))))) + +def swapAnn : Raw := .ann swapTm swapTy + +def depPairTy : Raw := + .sig "A" (.univ 2) (.var "A") + +def depPairTm : Raw := + .pair + (.univ 1) + (.pi "_" (.univ 0) (.univ 0)) + +def depPairAnn : Raw := .ann depPairTm depPairTy + +def fstDepPair : Raw := .fst depPairAnn + +def sndDepPair : Raw := .snd depPairAnn + +def omegaTy : Raw := + .pi "A" (.univ 0) (.var "A") + +def omegaTm : Raw := + .lam "x" (.app (.var "x") (.var "x")) + +def omegaAnn : Raw := .ann omegaTm omegaTy + +def unknownVar : Raw := .var "nope" + +def pairMismatch : Raw := + .ann (.pair (.univ 1) (.univ 1)) + (.sig "A" (.univ 2) (.var "A")) + +def badFst : Raw := .fst (.univ 0) + +def letUniverse : Raw := + .ann + (.letE "A" (.univ 1) (.pi "_" (.univ 0) (.univ 0)) (.var "A")) + (.univ 1) + +def univ0 : Raw := .univ 0 + +end BidirTT.Examples diff --git a/BidirTT/Pretty.lean b/BidirTT/Pretty.lean new file mode 100644 index 0000000..e396206 --- /dev/null +++ b/BidirTT/Pretty.lean @@ -0,0 +1,19 @@ +import BidirTT.Syntax + +namespace BidirTT + +mutual + partial def prettyTm : Tm → String + | .var i => s!"#{i}" + | .lam t => s!"(fun => {prettyTm t})" + | .app t u => s!"({prettyTm t} {prettyTm u})" + | .pi a b => s!"(Pi {prettyTm a} -> {prettyTm b})" + | .sig a b => s!"(Sigma {prettyTm a} * {prettyTm b})" + | .pair t u => s!"({prettyTm t}, {prettyTm u})" + | .fst t => s!"({prettyTm t}.1)" + | .snd t => s!"({prettyTm t}.2)" + | .univ i => s!"U{i}" + | .letE a t u => s!"(let : {prettyTm a} := {prettyTm t}; {prettyTm u})" +end + +end BidirTT diff --git a/BidirTT/Syntax.lean b/BidirTT/Syntax.lean new file mode 100644 index 0000000..9bb2437 --- /dev/null +++ b/BidirTT/Syntax.lean @@ -0,0 +1,32 @@ +namespace BidirTT + +abbrev Name := String + +inductive Raw where + | var : Name → Raw + | lam : Name → Raw → Raw + | app : Raw → Raw → Raw + | pi : Name → Raw → Raw → Raw + | sig : Name → Raw → Raw → Raw + | pair : Raw → Raw → Raw + | fst : Raw → Raw + | snd : Raw → Raw + | univ : Nat → Raw + | letE : Name → Raw → Raw → Raw → Raw + | ann : Raw → Raw → Raw + deriving Repr, Inhabited, BEq, DecidableEq + +inductive Tm where + | var : Nat → Tm + | lam : Tm → Tm + | app : Tm → Tm → Tm + | pi : Tm → Tm → Tm + | sig : Tm → Tm → Tm + | pair : Tm → Tm → Tm + | fst : Tm → Tm + | snd : Tm → Tm + | univ : Nat → Tm + | letE : Tm → Tm → Tm → Tm + deriving Repr, Inhabited, BEq, DecidableEq + +end BidirTT diff --git a/BidirTT/Value.lean b/BidirTT/Value.lean new file mode 100644 index 0000000..9267c64 --- /dev/null +++ b/BidirTT/Value.lean @@ -0,0 +1,27 @@ +import BidirTT.Syntax + +namespace BidirTT + +mutual + inductive Val where + | var : Nat → Val + | app : Val → Val → Val + | fst : Val → Val + | snd : Val → Val + | lam : Closure → Val + | pi : Val → Closure → Val + | sig : Val → Closure → Val + | pair : Val → Val → Val + | univ : Nat → Val + + inductive Closure where + | mk : List Val → Tm → Closure +end + +abbrev Env := List Val +abbrev Lvl := Nat + +instance : Inhabited Val := ⟨.univ 0⟩ +instance : Inhabited Closure := ⟨.mk [] (.univ 0)⟩ + +end BidirTT diff --git a/Main.lean b/Main.lean new file mode 100644 index 0000000..2778915 --- /dev/null +++ b/Main.lean @@ -0,0 +1,30 @@ +import BidirTT + +open BidirTT + +def runOne (label : String) (r : Raw) : IO Unit := do + match checkTop r with + | .ok (t, ty) => + IO.println s!"[ok] {label}" + IO.println s!" term : {BidirTT.prettyTm t}" + match quote 0 ty with + | Except.ok qt => + IO.println s!" type : {BidirTT.prettyTm qt}" + | Except.error err => + IO.println s!" type : " + | .error e => + IO.println s!"[err] {label}: {e}" + +def main : IO Unit := do + runOne "U0" Examples.univ0 + runOne "id" Examples.idAnn + runOne "const" Examples.constAnn + runOne "swap" Examples.swapAnn + runOne "depPair" Examples.depPairAnn + runOne "depPair.1" Examples.fstDepPair + runOne "depPair.2" Examples.sndDepPair + runOne "let universe" Examples.letUniverse + runOne "omega (bad)" Examples.omegaAnn + runOne "unknown var" Examples.unknownVar + runOne "pair mismatch" Examples.pairMismatch + runOne "bad fst" Examples.badFst diff --git a/README.md b/README.md new file mode 100644 index 0000000..6019e69 --- /dev/null +++ b/README.md @@ -0,0 +1,17 @@ +# iris + +Small (toy) dependently typed core w. a bidirectional typechecker. It takes a named raw syntax, checks and elaborates it into a de Bruijn core and evaluates terms into semantic values, quotes them back for diagnostics and uses conversion checking to compare normal forms. The current kernel has explicit universe levels `U0`, `U1`, `U2`, dependent function and pair types, projections, annotations, and `let` + +One of the checked examples is: + +```lean +def depPairTy : Raw := + .sig "A" (.univ 2) (.var "A") + +def depPairTm : Raw := + .pair + (.univ 1) + (.pi "_" (.univ 0) (.univ 0)) +``` + +which elaborates to a pair whose first component is the type `U1` and whose second component inhabits it diff --git a/Tests.lean b/Tests.lean new file mode 100644 index 0000000..d5e087f --- /dev/null +++ b/Tests.lean @@ -0,0 +1,106 @@ +import BidirTT + +open BidirTT + +inductive Expectation where + | okTy : Tm → Expectation + | errContains : String → Expectation + +private def renderType (ty : Val) : Except String Tm := + BidirTT.quote 0 ty + +private def containsText (haystack needle : String) : Bool := + needle.isEmpty || (haystack.splitOn needle).length > 1 + +structure TestCase where + name : String + input : Raw + expect : Expectation + +def cases : List TestCase := [ + ⟨"U0 is typed by U1", Examples.univ0, + .okTy (.univ 1)⟩, + ⟨"id typechecks", Examples.idAnn, + .okTy (.pi (.univ 0) (.pi (.var 0) (.var 1)))⟩, + ⟨"const typechecks", Examples.constAnn, + .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.var 1) (.pi (.var 1) (.var 3)))))⟩, + ⟨"swap typechecks", Examples.swapAnn, + .okTy (.pi (.univ 0) (.pi (.univ 0) (.pi (.sig (.var 1) (.var 1)) (.sig (.var 1) (.var 3)))))⟩, + ⟨"dependent pair typechecks", Examples.depPairAnn, + .okTy (.sig (.univ 2) (.var 0))⟩, + ⟨"fst infers the first projection", Examples.fstDepPair, + .okTy (.univ 2)⟩, + ⟨"snd infers the dependent second projection", Examples.sndDepPair, + .okTy (.univ 1)⟩, + ⟨"let infers through definitions", Examples.letUniverse, + .okTy (.univ 1)⟩, + ⟨"self application rejected", Examples.omegaAnn, + .errContains "expected Pi type in application"⟩, + ⟨"unknown variable rejected", Examples.unknownVar, + .errContains "unknown variable nope"⟩, + ⟨"pair mismatch rejected at the Sigma body", Examples.pairMismatch, + .errContains "type mismatch: expected U1, got U2"⟩, + ⟨"bad fst rejected", Examples.badFst, + .errContains "expected Sigma type in .1, got U1"⟩ +] + +def runCase (tc : TestCase) : IO Bool := do + match tc.expect, checkTop tc.input with + | .okTy expectedTy, .ok (_, ty) => + match renderType ty with + | Except.ok actualTy => + if actualTy == expectedTy then + IO.println s!"PASS {tc.name}" + pure true + else + IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got {BidirTT.prettyTm actualTy})" + pure false + | Except.error err => + IO.println s!"FAIL {tc.name} (could not quote type: {err})" + pure false + | .okTy expectedTy, .error err => + IO.println s!"FAIL {tc.name} (expected type {BidirTT.prettyTm expectedTy}, got error {err})" + pure false + | .errContains needle, .error err => + if containsText err needle then + IO.println s!"PASS {tc.name}" + pure true + else + IO.println s!"FAIL {tc.name} (expected error containing {needle}, got {err})" + pure false + | .errContains needle, .ok (_, ty) => + match renderType ty with + | Except.ok actualTy => + IO.println s!"FAIL {tc.name} (expected error containing {needle}, got type {BidirTT.prettyTm actualTy})" + pure false + | Except.error err => + IO.println s!"FAIL {tc.name} (expected error containing {needle}, got quote failure {err})" + pure false + +def runInternalSafetyChecks : IO Bool := do + let malformedEvalOk := + match eval [] (.var 0) with + | Except.error err => containsText err "bad de Bruijn index 0" + | Except.ok _ => false + let malformedQuoteOk := + match quote 0 (.var 0) with + | Except.error err => containsText err "bad level 0" + | Except.ok _ => false + if malformedEvalOk && malformedQuoteOk then + IO.println "PASS malformed core terms are rejected safely" + pure true + else + IO.println "FAIL malformed core terms are rejected safely" + pure false + +def main : IO UInt32 := do + let results ← cases.mapM runCase + let safetyOk ← runInternalSafetyChecks + let allResults := results ++ [safetyOk] + let failed := allResults.countP (· == false) + if failed == 0 then + IO.println s!"\n{allResults.length} passed, 0 failed" + pure 0 + else + IO.println s!"\n{allResults.length - failed} passed, {failed} failed" + pure 1 diff --git a/lakefile.lean b/lakefile.lean new file mode 100644 index 0000000..59c9e53 --- /dev/null +++ b/lakefile.lean @@ -0,0 +1,17 @@ +import Lake +open Lake DSL + +package bidirtt where + leanOptions := #[ + ⟨`pp.unicode.fun, true⟩, + ⟨`autoImplicit, false⟩ + ] + +@[default_target] +lean_lib BidirTT where + +lean_exe bidirtt where + root := `Main + +lean_exe tests where + root := `Tests diff --git a/lean-toolchain b/lean-toolchain new file mode 100644 index 0000000..d0eb99f --- /dev/null +++ b/lean-toolchain @@ -0,0 +1 @@ +leanprover/lean4:v4.15.0