106 lines
3.5 KiB
Lean4
106 lines
3.5 KiB
Lean4
|
|
import BidirTT.Context
|
|||
|
|
import BidirTT.Eval
|
|||
|
|
import BidirTT.Pretty
|
|||
|
|
|
|||
|
|
namespace BidirTT
|
|||
|
|
|
|||
|
|
/-- The type-checking monad: a computation that either succeeds or fails with
a `String` error message (thrown via `throw` throughout this file). -/
abbrev TCM := Except String
|
|||
|
|
|
|||
|
|
/-- Render a type value for use in error messages.

The value is quoted back to a core term at level `l` and pretty-printed.
If quoting fails, the failure is not propagated; instead the error is
embedded in an `<ill-scoped value: …>` placeholder so that the caller can
still produce its own error message. -/
private def showTy (l : Lvl) (v : Val) : String :=
  match quote l v with
  | Except.error msg => s!"<ill-scoped value: {msg}>"
  | Except.ok term => prettyTm term
|
|||
|
|
|
|||
|
|
mutual
|
|||
|
|
/-- Check the raw term against an expected type, producing an elaborated core
term.

* A lambda is checked against a Pi type by instantiating the Pi codomain and
  checking the body in the context extended with the binder.
* A pair is checked against a Sigma type; the second component's expected
  type is computed from the evaluated first component.
* A `let` elaborates its type annotation (which must live in some universe),
  checks the bound value against it, and checks the body against the same
  expected type in the context extended with the definition.
* Every other form falls back to `infer` followed by a conversion check
  (`conv`) against the expected type.

Throws a descriptive `String` error when a lambda/pair meets a non-Pi /
non-Sigma expected type, or when the inferred and expected types are not
convertible. -/
partial def check (cxt : Cxt) : Raw → Val → TCM Tm
  | .lam x t, .pi a c => do
    -- Instantiate the codomain (applied via `cApp`) with a variable at the
    -- current level; presumably this is the level `cxt.bind` assigns to `x`.
    let bodyTy ← cApp c (.var cxt.lvl)
    let t' ← check (cxt.bind x a) t bodyTy
    pure (.lam t')
  | .lam _ _, ty =>
    -- Previously this fell through to `infer`, which reported the misleading
    -- "cannot infer type of lambda" even though an expected type was given.
    throw s!"cannot check lambda against non-Pi type {showTy cxt.lvl ty}"
  | .pair t u, .sig a c => do
    let t' ← check cxt t a
    -- The second component's type depends on the value of the first.
    let vt ← eval cxt.env t'
    let bodyTy ← cApp c vt
    let u' ← check cxt u bodyTy
    pure (.pair t' u')
  | .pair _ _, ty =>
    throw s!"cannot check pair against non-Sigma type {showTy cxt.lvl ty}"
  | .letE x a t u, ty => do
    let (a', _) ← inferUniverse cxt a
    let va ← eval cxt.env a'
    let t' ← check cxt t va
    let vt ← eval cxt.env t'
    -- The body is checked against the same expected type under the new
    -- definition.
    let u' ← check (cxt.define x vt va) u ty
    pure (.letE a' t' u')
  | r, ty => do
    -- Mode switch: infer a type, then require it to be convertible with the
    -- expected one.
    let (t', ty') ← infer cxt r
    if (← conv cxt.lvl ty' ty) then
      pure t'
    else
      throw s!"type mismatch: expected {showTy cxt.lvl ty}, got {showTy cxt.lvl ty'}"
|
|||
|
|
|
|||
|
|
/-- Infer a type for a raw term, returning the elaborated core term together
with its type as a value.

Unannotated lambdas and pairs carry no type information of their own and
are rejected; they must be annotated (or appear where `check` supplies an
expected type). Annotations (`.ann`) are used only to check the inner term
and do not appear in the elaborated output. -/
partial def infer (cxt : Cxt) : Raw → TCM (Tm × Val)
  | .univ i => pure (.univ i, .univ (i + 1))
  | .var x => do
    let some (ix, ty) := cxt.lookup x
      | throw s!"unknown variable {x}"
    pure (.var ix, ty)
  | .app f arg => do
    let (f', fTy) ← infer cxt f
    let .pi dom cod := fTy
      | throw s!"expected Pi type in application, got {showTy cxt.lvl fTy}"
    let arg' ← check cxt arg dom
    -- The result type is the codomain instantiated with the argument's value.
    let argV ← eval cxt.env arg'
    let resTy ← cApp cod argV
    pure (.app f' arg', resTy)
  | .fst p => do
    let (p', pTy) ← infer cxt p
    let .sig fstTy _ := pTy
      | throw s!"expected Sigma type in .1, got {showTy cxt.lvl pTy}"
    pure (.fst p', fstTy)
  | .snd p => do
    let (p', pTy) ← infer cxt p
    let .sig _ cod := pTy
      | throw s!"expected Sigma type in .2, got {showTy cxt.lvl pTy}"
    -- The second projection's type depends on the first projection's value.
    let fstV ← vFst (← eval cxt.env p')
    let sndTy ← cApp cod fstV
    pure (.snd p', sndTy)
  | .pi x dom cod => do
    let (dom', i) ← inferUniverse cxt dom
    let domV ← eval cxt.env dom'
    let (cod', j) ← inferUniverse (cxt.bind x domV) cod
    pure (.pi dom' cod', .univ (max i j))
  | .sig x dom cod => do
    let (dom', i) ← inferUniverse cxt dom
    let domV ← eval cxt.env dom'
    let (cod', j) ← inferUniverse (cxt.bind x domV) cod
    pure (.sig dom' cod', .univ (max i j))
  | .ann tm ty => do
    let (ty', _) ← inferUniverse cxt ty
    let tyV ← eval cxt.env ty'
    let tm' ← check cxt tm tyV
    pure (tm', tyV)
  | .letE x ty val body => do
    let (ty', _) ← inferUniverse cxt ty
    let tyV ← eval cxt.env ty'
    let val' ← check cxt val tyV
    let valV ← eval cxt.env val'
    let (body', bodyTy) ← infer (cxt.define x valV tyV) body
    pure (.letE ty' val' body', bodyTy)
  | .lam _ _ => throw "cannot infer type of lambda, use an annotation"
  | .pair _ _ => throw "cannot infer type of pair, use an annotation"
|
|||
|
|
|
|||
|
|
/-- Infer the type of `r` and require it to be a universe `.univ n`,
returning the elaborated term together with the level `n`. Throws if the
inferred type is not syntactically a universe value. -/
partial def inferUniverse (cxt : Cxt) (r : Raw) : TCM (Tm × Nat) := do
  let (tm, ty) ← infer cxt r
  let .univ n := ty
    | throw s!"expected a universe, got {showTy cxt.lvl ty}"
  return (tm, n)
|
|||
|
|
end
|
|||
|
|
|
|||
|
|
/-- Top-level entry point: elaborate a raw term in the empty context by
inference, returning the core term and its type. -/
def checkTop (r : Raw) : TCM (Tm × Val) := do
  let (tm, ty) ← infer Cxt.empty r
  return (tm, ty)
|
|||
|
|
|
|||
|
|
end BidirTT
|