106 lines
3.5 KiB
Lean4
106 lines
3.5 KiB
Lean4
import BidirTT.Context
|
||
import BidirTT.Eval
|
||
import BidirTT.Pretty
|
||
|
||
namespace BidirTT
|
||
|
||
/-- The type-checking monad: a computation either succeeds with a value or
fails with a `String` error message (`Except String`). -/
abbrev TCM := Except String
|
||
|
||
/-- Render a type value for use in error messages: quote `v` back into a term
at level `l` and pretty-print it. If quoting fails, the quoting error is
embedded in the returned string rather than propagated. -/
private def showTy (l : Lvl) (v : Val) : String :=
  let rendered := prettyTm <$> quote l v
  match rendered with
  | .ok text => text
  | .error msg => s!"<ill-scoped value: {msg}>"
|
||
|
||
mutual
|
||
/-- Check the raw term `r` against the expected type `ty` (a value) in context
`cxt`, returning the elaborated core term.

* `λ x. t` against `Π (x : A). B`: check `t` against `B` instantiated with a
  fresh variable at level `cxt.lvl`, in the context extended with `x : A`.
* `(t, u)` against `Σ (x : A). B`: check `t : A`, evaluate it, then check
  `u` against `B` instantiated with that value.
* `let x : A := t; u`: elaborate `A` (its type must be a universe), check
  `t : A`, then check `u` against `ty` with `x` defined as `t`.
* Lambdas and pairs whose expected type is not a Pi / Sigma type are
  rejected with an error that names the expected type.
* Every other term is inferred and its type compared with `ty` via `conv`.

Errors are reported by throwing a `String` in `TCM`. -/
partial def check (cxt : Cxt) : Raw → Val → TCM Tm
  | .lam x t, .pi a c => do
    let bodyTy ← cApp c (.var cxt.lvl)
    let t' ← check (cxt.bind x a) t bodyTy
    pure (.lam t')
  | .pair t u, .sig a c => do
    let t' ← check cxt t a
    let vt ← eval cxt.env t'
    let bodyTy ← cApp c vt
    let u' ← check cxt u bodyTy
    pure (.pair t' u')
  | .letE x a t u, ty => do
    let (a', _) ← inferUniverse cxt a
    let va ← eval cxt.env a'
    let t' ← check cxt t va
    let vt ← eval cxt.env t'
    let u' ← check (cxt.define x vt va) u ty
    pure (.letE a' t' u')
  | .lam _ _, ty =>
    -- Falling through to `infer` would always fail with "cannot infer type
    -- of lambda, use an annotation", which is misleading: an expected type
    -- was supplied, it just is not a Pi type.
    throw s!"expected {showTy cxt.lvl ty}, but found a lambda (needs a Pi type)"
  | .pair _ _, ty =>
    -- Same reasoning as for lambdas: `infer` can never type a bare pair.
    throw s!"expected {showTy cxt.lvl ty}, but found a pair (needs a Sigma type)"
  | r, ty => do
    let (t', ty') ← infer cxt r
    if (← conv cxt.lvl ty' ty) then
      pure t'
    else
      throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}"
|
||
|
||
/-- Infer the type of the raw term `r` in context `cxt`, returning the
elaborated core term paired with its type as a value.

Bare lambdas and pairs carry no type information of their own and are
rejected; wrap them in an annotation so that `check` can handle them. -/
partial def infer (cxt : Cxt) : Raw → TCM (Tm × Val)
  | .var x => do
    let some (ix, ty) := cxt.lookup x
      | throw s!"unknown variable {x}"
    pure (.var ix, ty)
  | .univ i =>
    -- `U i : U (i + 1)`
    pure (.univ i, .univ (i + 1))
  | .app f arg => do
    let (f', fTy) ← infer cxt f
    let .pi dom cod := fTy
      | throw s!"expected Pi type in application, got {showTy cxt.lvl fTy}"
    let arg' ← check cxt arg dom
    -- The result type is the codomain instantiated with the argument's value.
    let resTy ← cApp cod (← eval cxt.env arg')
    pure (.app f' arg', resTy)
  | .fst p => do
    let (p', pTy) ← infer cxt p
    let .sig fstTy _ := pTy
      | throw s!"expected Sigma type in .1, got {showTy cxt.lvl pTy}"
    pure (.fst p', fstTy)
  | .snd p => do
    let (p', pTy) ← infer cxt p
    let .sig _ sndClo := pTy
      | throw s!"expected Sigma type in .2, got {showTy cxt.lvl pTy}"
    -- The second component's type depends on the value of the first one.
    let pV ← eval cxt.env p'
    let sndTy ← cApp sndClo (← vFst pV)
    pure (.snd p', sndTy)
  | .pi x dom cod => do
    let (dom', i) ← inferUniverse cxt dom
    let domV ← eval cxt.env dom'
    let (cod', j) ← inferUniverse (cxt.bind x domV) cod
    pure (.pi dom' cod', .univ (Nat.max i j))
  | .sig x dom cod => do
    let (dom', i) ← inferUniverse cxt dom
    let domV ← eval cxt.env dom'
    let (cod', j) ← inferUniverse (cxt.bind x domV) cod
    pure (.sig dom' cod', .univ (Nat.max i j))
  | .ann t a => do
    let (a', _) ← inferUniverse cxt a
    let aV ← eval cxt.env a'
    -- The annotation only guides checking; it is not kept in the core term.
    pure (← check cxt t aV, aV)
  | .letE x a t u => do
    let (a', _) ← inferUniverse cxt a
    let aV ← eval cxt.env a'
    let t' ← check cxt t aV
    let tV ← eval cxt.env t'
    let (u', uTy) ← infer (cxt.define x tV aV) u
    pure (.letE a' t' u', uTy)
  | .lam _ _ => throw "cannot infer type of lambda, use an annotation"
  | .pair _ _ => throw "cannot infer type of pair, use an annotation"
|
||
|
||
/-- Infer the type of `r` and require it to be a universe `U u`, returning the
elaborated term together with the level `u`. Any other type is an error. -/
partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Nat) := do
  let (tm, vTy) ← infer cxt r
  let .univ u := vTy
    | throw s!"expected a universe, got {showTy cxt.lvl vTy}"
  pure (tm, u)
|
||
end
|
||
|
||
/-- Entry point for a closed top-level term: infer its type in the empty
context and return the elaborated term together with that type. -/
def checkTop (r : Raw) : TCM (Tm × Val) :=
  infer (cxt := Cxt.empty) r
|
||
|
||
end BidirTT
|