From 8fb7e8748e383e1d57aa078333f2ed4a54c08ea9 Mon Sep 17 00:00:00 2001 From: imiel Date: Sat, 14 Feb 2026 13:34:18 +0000 Subject: [PATCH] thread optimisation obligations through rewrite passes --- src/pipeline.ml | 108 ++++++++++++++++++++++----------------------- test/invariants.ml | 13 ++++++ 2 files changed, 67 insertions(+), 54 deletions(-) diff --git a/src/pipeline.ml b/src/pipeline.ml index d251354..daaaa29 100644 --- a/src/pipeline.ml +++ b/src/pipeline.ml @@ -42,30 +42,27 @@ let default_flags = let obligation ?(kind = Preserve_relation) pass subject detail = { pass; kind; subject; detail } -let rec specialise_term obligations term = +let rec specialise_term note term = match term with | Source.TyApp (Source.TyLam (x, body), ty) -> - let body' = specialise_term obligations (Source.TyApp (Source.TyLam (x, body), ty)) in - begin match body' with - | Source.TyApp (Source.TyLam (x, body), ty) -> Source.(substitute_type x ty body) - | other -> other - end - | Source.Pair (a, b) -> Source.Pair (specialise_term obligations a, specialise_term obligations b) - | Source.Inl (ty, t) -> Source.Inl (ty, specialise_term obligations t) - | Source.Inr (ty, t) -> Source.Inr (ty, specialise_term obligations t) - | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, specialise_term obligations body) - | Source.App (f, a) -> Source.App (specialise_term obligations f, specialise_term obligations a) - | Source.TyLam (x, body) -> Source.TyLam (x, specialise_term obligations body) - | Source.TyApp (t, ty) -> Source.TyApp (specialise_term obligations t, ty) - | Source.Roll (ty, t) -> Source.Roll (ty, specialise_term obligations t) - | Source.Unroll t -> Source.Unroll (specialise_term obligations t) - | Source.Fix (f, ty, body) -> Source.Fix (f, ty, specialise_term obligations body) - | Source.Eq (g, a, b) -> Source.Eq (g, specialise_term obligations a, specialise_term obligations b) - | Source.If (c, t, e) -> Source.If (specialise_term obligations c, specialise_term obligations t, specialise_term obligations e) - | Source.Let (x, a, b) -> Source.Let (x, specialise_term obligations a, specialise_term obligations b) - | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, specialise_term obligations a, specialise_term obligations b) + note Preserve_relation "type application cloning" ("instantiated type binder " ^ x ^ " at " ^ Types.string_of_typ ty); + Source.(substitute_type x ty (specialise_term note body)) + | Source.Pair (a, b) -> Source.Pair (specialise_term note a, specialise_term note b) + | Source.Inl (ty, t) -> Source.Inl (ty, specialise_term note t) + | Source.Inr (ty, t) -> Source.Inr (ty, specialise_term note t) + | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, specialise_term note body) + | Source.App (f, a) -> Source.App (specialise_term note f, specialise_term note a) + | Source.TyLam (x, body) -> Source.TyLam (x, specialise_term note body) + | Source.TyApp (t, ty) -> Source.TyApp (specialise_term note t, ty) + | Source.Roll (ty, t) -> Source.Roll (ty, specialise_term note t) + | Source.Unroll t -> Source.Unroll (specialise_term note t) + | Source.Fix (f, ty, body) -> Source.Fix (f, ty, specialise_term note body) + | Source.Eq (g, a, b) -> Source.Eq (g, specialise_term note a, specialise_term note b) + | Source.If (c, t, e) -> Source.If (specialise_term note c, specialise_term note t, specialise_term note e) + | Source.Let (x, a, b) -> Source.Let (x, specialise_term note a, specialise_term note b) + | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, specialise_term note a, specialise_term note b) | Source.Case (s, (x, l), (y, r)) -> - Source.Case (specialise_term obligations s, (x, specialise_term obligations l), (y, specialise_term obligations r)) + Source.Case (specialise_term note s, (x, specialise_term note l), (y, specialise_term note r)) | other -> other let rec inline_cost = function @@ -82,47 +79,55 @@ let rec inline_cost = function | Source.LetPair (_, _, a, b) -> 1 + inline_cost a + inline_cost b | Source.Case (s, (_, l), (_, r)) -> 1 + inline_cost s + inline_cost l + inline_cost r -let rec inline_term obligations unsafe_repr_eq term = +let rec inline_term note unsafe_repr_eq term = match term with | Source.App (Source.Lam (x, _, _, body), arg) when inline_cost body <= 8 -> - Source.substitute x arg (inline_term obligations unsafe_repr_eq body) + note Preserve_relation "term beta inlining" ("inlined lambda argument " ^ x); + Source.substitute x arg (inline_term note unsafe_repr_eq body) | Source.TyApp (Source.TyLam (x, body), ty) -> - Source.substitute_type x ty (inline_term obligations unsafe_repr_eq body) + note Preserve_relation "type beta inlining" ("inlined type argument " ^ Types.string_of_typ ty ^ " for " ^ x); + Source.substitute_type x ty (inline_term note unsafe_repr_eq body) | Source.TyApp (Source.Var f, TInt) when unsafe_repr_eq && String.equal f "poly_const_false" -> - ignore obligations; + note Exposed_representation "parametric body exposure" "rewrote an abstract int instance into eq_int"; Source.Lam ("x", TInt, TArrow (TInt, TBool), Source.Lam ("y", TInt, TBool, Source.Eq (GInt, Source.Var "x", Source.Var "y"))) - | Source.Pair (a, b) -> Source.Pair (inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) - | Source.Inl (ty, t) -> Source.Inl (ty, inline_term obligations unsafe_repr_eq t) - | Source.Inr (ty, t) -> Source.Inr (ty, inline_term obligations unsafe_repr_eq t) - | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, inline_term obligations unsafe_repr_eq body) + | Source.Pair (a, b) -> Source.Pair (inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b) + | Source.Inl (ty, t) -> Source.Inl (ty, inline_term note unsafe_repr_eq t) + | Source.Inr (ty, t) -> Source.Inr (ty, inline_term note unsafe_repr_eq t) + | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, inline_term note unsafe_repr_eq body) | Source.App (f, a) -> - let f' = inline_term obligations unsafe_repr_eq f in - let a' = inline_term obligations unsafe_repr_eq a in + let f' = inline_term note unsafe_repr_eq f in + let a' = inline_term note unsafe_repr_eq a in begin match f' with | Source.Lam (x, _, _, body) when inline_cost body <= 8 -> - inline_term obligations unsafe_repr_eq (Source.substitute x a' body) + note Preserve_relation "term beta inlining" ("inlined lambda argument " ^ x); + inline_term note unsafe_repr_eq (Source.substitute x a' body) | _ -> Source.App (f', a') end - | Source.TyLam (x, body) -> Source.TyLam (x, inline_term obligations unsafe_repr_eq body) + | Source.TyLam (x, body) -> Source.TyLam (x, inline_term note unsafe_repr_eq body) | Source.TyApp (t, ty) -> - let t' = inline_term obligations unsafe_repr_eq t in + let t' = inline_term note unsafe_repr_eq t in begin match t' with - | Source.TyLam (x, body) -> inline_term obligations unsafe_repr_eq (Source.substitute_type x ty body) + | Source.TyLam (x, body) -> + note Preserve_relation "type beta inlining" ("inlined type argument " ^ Types.string_of_typ ty ^ " for " ^ x); + inline_term note unsafe_repr_eq (Source.substitute_type x ty body) | _ -> Source.TyApp (t', ty) end - | Source.Roll (ty, t) -> Source.Roll (ty, inline_term obligations unsafe_repr_eq t) - | Source.Unroll t -> Source.Unroll (inline_term obligations unsafe_repr_eq t) - | Source.Fix (f, ty, body) -> Source.Fix (f, ty, inline_term obligations unsafe_repr_eq body) - | Source.Eq (g, a, b) -> Source.Eq (g, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) + | Source.Roll (ty, t) -> Source.Roll (ty, inline_term note unsafe_repr_eq t) + | Source.Unroll t -> Source.Unroll (inline_term note unsafe_repr_eq t) + | Source.Fix (f, ty, body) -> Source.Fix (f, ty, inline_term note unsafe_repr_eq body) + | Source.Eq (g, a, b) -> Source.Eq (g, inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b) | Source.If (c, t, e) -> - Source.If (inline_term obligations unsafe_repr_eq c, inline_term obligations unsafe_repr_eq t, inline_term obligations unsafe_repr_eq e) + Source.If (inline_term note unsafe_repr_eq c, inline_term note unsafe_repr_eq t, inline_term note unsafe_repr_eq e) | Source.Let (x, a, b) -> - let a' = inline_term obligations unsafe_repr_eq a in - let b' = inline_term obligations unsafe_repr_eq b in - if inline_cost a' <= 4 then Source.substitute x a' b' else Source.Let (x, a', b') - | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) + let a' = inline_term note unsafe_repr_eq a in + let b' = inline_term note unsafe_repr_eq b in + if inline_cost a' <= 4 then begin + note Preserve_relation "let inlining" ("inlined binding " ^ x); + Source.substitute x a' b' + end else Source.Let (x, a', b') + | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b) | Source.Case (s, (x, l), (y, r)) -> - Source.Case (inline_term obligations unsafe_repr_eq s, (x, inline_term obligations unsafe_repr_eq l), (y, inline_term obligations unsafe_repr_eq r)) + Source.Case (inline_term note unsafe_repr_eq s, (x, inline_term note unsafe_repr_eq l), (y, inline_term note unsafe_repr_eq r)) | other -> other let rec repr_of_typ = function @@ -198,17 +203,12 @@ let compile flags source_type source_program = obligations := obligation ~kind pass subject detail :: !obligations in let specialised = - if flags.specialise then begin - note "specialise" Preserve_relation "type application cloning" "specialisation preserves source typing but requires instantiation closedness"; - specialise_term obligations source_program - end else source_program + if flags.specialise then specialise_term (note "specialise") source_program + else source_program in let inlined = - if flags.inline then begin - if flags.unsafe_repr_eq then - note "inline" Exposed_representation "parametric body exposure" "inlining may replace an abstract constant relation with eq_int at specialised int instances"; - inline_term obligations flags.unsafe_repr_eq specialised - end else specialised + if flags.inline then inline_term (note "inline") flags.unsafe_repr_eq specialised + else specialised in let target_program = if flags.repr_lower then begin diff --git a/test/invariants.ml b/test/invariants.ml index 31e8021..e09c649 100644 --- a/test/invariants.ml +++ b/test/invariants.ml @@ -18,6 +18,9 @@ let find_case name = | Some case -> case | None -> failwith ("missing case " ^ name) +let has_obligation kind obligations = + List.exists (fun (o : Pipeline.obligation) -> o.kind = kind) obligations + let () = List.iter (fun (case : Corpus.case) -> @@ -33,6 +36,16 @@ let () = assert_true "expected representation exposure witness" (repr.failure_mode = Audit.Representation_exposure); + assert_true + "expected unsafe inlining to emit representation exposure obligation" + (has_obligation Pipeline.Exposed_representation repr.compiled.obligations); + let safe = Audit.audit_case (find_case "safe-polymorphic-instantiation") in + assert_true + "safe instantiation should not emit representation exposure obligation" + (not (has_obligation Pipeline.Exposed_representation safe.compiled.obligations)); + assert_true + "safe instantiation should emit preserve relation obligations at rewrite sites" + (has_obligation Pipeline.Preserve_relation safe.compiled.obligations); let strict = Audit.audit_case (find_case "strictness-induced-termination-change") in assert_true "expected strictness shift witness"