thread optimisation obligations through rewrite passes

This commit is contained in:
2026-02-14 13:34:18 +00:00
parent 6164dcb7d9
commit 8fb7e8748e
2 changed files with 67 additions and 54 deletions
+54 -54
View File
@@ -42,30 +42,27 @@ let default_flags =
let obligation ?(kind = Preserve_relation) pass subject detail = let obligation ?(kind = Preserve_relation) pass subject detail =
{ pass; kind; subject; detail } { pass; kind; subject; detail }
let rec specialise_term obligations term = let rec specialise_term note term =
match term with match term with
| Source.TyApp (Source.TyLam (x, body), ty) -> | Source.TyApp (Source.TyLam (x, body), ty) ->
let body' = specialise_term obligations (Source.TyApp (Source.TyLam (x, body), ty)) in note Preserve_relation "type application cloning" ("instantiated type binder " ^ x ^ " at " ^ Types.string_of_typ ty);
begin match body' with Source.(substitute_type x ty (specialise_term note body))
| Source.TyApp (Source.TyLam (x, body), ty) -> Source.(substitute_type x ty body) | Source.Pair (a, b) -> Source.Pair (specialise_term note a, specialise_term note b)
| other -> other | Source.Inl (ty, t) -> Source.Inl (ty, specialise_term note t)
end | Source.Inr (ty, t) -> Source.Inr (ty, specialise_term note t)
| Source.Pair (a, b) -> Source.Pair (specialise_term obligations a, specialise_term obligations b) | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, specialise_term note body)
| Source.Inl (ty, t) -> Source.Inl (ty, specialise_term obligations t) | Source.App (f, a) -> Source.App (specialise_term note f, specialise_term note a)
| Source.Inr (ty, t) -> Source.Inr (ty, specialise_term obligations t) | Source.TyLam (x, body) -> Source.TyLam (x, specialise_term note body)
| Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, specialise_term obligations body) | Source.TyApp (t, ty) -> Source.TyApp (specialise_term note t, ty)
| Source.App (f, a) -> Source.App (specialise_term obligations f, specialise_term obligations a) | Source.Roll (ty, t) -> Source.Roll (ty, specialise_term note t)
| Source.TyLam (x, body) -> Source.TyLam (x, specialise_term obligations body) | Source.Unroll t -> Source.Unroll (specialise_term note t)
| Source.TyApp (t, ty) -> Source.TyApp (specialise_term obligations t, ty) | Source.Fix (f, ty, body) -> Source.Fix (f, ty, specialise_term note body)
| Source.Roll (ty, t) -> Source.Roll (ty, specialise_term obligations t) | Source.Eq (g, a, b) -> Source.Eq (g, specialise_term note a, specialise_term note b)
| Source.Unroll t -> Source.Unroll (specialise_term obligations t) | Source.If (c, t, e) -> Source.If (specialise_term note c, specialise_term note t, specialise_term note e)
| Source.Fix (f, ty, body) -> Source.Fix (f, ty, specialise_term obligations body) | Source.Let (x, a, b) -> Source.Let (x, specialise_term note a, specialise_term note b)
| Source.Eq (g, a, b) -> Source.Eq (g, specialise_term obligations a, specialise_term obligations b) | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, specialise_term note a, specialise_term note b)
| Source.If (c, t, e) -> Source.If (specialise_term obligations c, specialise_term obligations t, specialise_term obligations e)
| Source.Let (x, a, b) -> Source.Let (x, specialise_term obligations a, specialise_term obligations b)
| Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, specialise_term obligations a, specialise_term obligations b)
| Source.Case (s, (x, l), (y, r)) -> | Source.Case (s, (x, l), (y, r)) ->
Source.Case (specialise_term obligations s, (x, specialise_term obligations l), (y, specialise_term obligations r)) Source.Case (specialise_term note s, (x, specialise_term note l), (y, specialise_term note r))
| other -> other | other -> other
let rec inline_cost = function let rec inline_cost = function
@@ -82,47 +79,55 @@ let rec inline_cost = function
| Source.LetPair (_, _, a, b) -> 1 + inline_cost a + inline_cost b | Source.LetPair (_, _, a, b) -> 1 + inline_cost a + inline_cost b
| Source.Case (s, (_, l), (_, r)) -> 1 + inline_cost s + inline_cost l + inline_cost r | Source.Case (s, (_, l), (_, r)) -> 1 + inline_cost s + inline_cost l + inline_cost r
let rec inline_term obligations unsafe_repr_eq term = let rec inline_term note unsafe_repr_eq term =
match term with match term with
| Source.App (Source.Lam (x, _, _, body), arg) when inline_cost body <= 8 -> | Source.App (Source.Lam (x, _, _, body), arg) when inline_cost body <= 8 ->
Source.substitute x arg (inline_term obligations unsafe_repr_eq body) note Preserve_relation "term beta inlining" ("inlined lambda argument " ^ x);
Source.substitute x arg (inline_term note unsafe_repr_eq body)
| Source.TyApp (Source.TyLam (x, body), ty) -> | Source.TyApp (Source.TyLam (x, body), ty) ->
Source.substitute_type x ty (inline_term obligations unsafe_repr_eq body) note Preserve_relation "type beta inlining" ("inlined type argument " ^ Types.string_of_typ ty ^ " for " ^ x);
Source.substitute_type x ty (inline_term note unsafe_repr_eq body)
| Source.TyApp (Source.Var f, TInt) when unsafe_repr_eq && String.equal f "poly_const_false" -> | Source.TyApp (Source.Var f, TInt) when unsafe_repr_eq && String.equal f "poly_const_false" ->
ignore obligations; note Exposed_representation "parametric body exposure" "rewrote an abstract int instance into eq_int";
Source.Lam ("x", TInt, TArrow (TInt, TBool), Source.Lam ("y", TInt, TBool, Source.Eq (GInt, Source.Var "x", Source.Var "y"))) Source.Lam ("x", TInt, TArrow (TInt, TBool), Source.Lam ("y", TInt, TBool, Source.Eq (GInt, Source.Var "x", Source.Var "y")))
| Source.Pair (a, b) -> Source.Pair (inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) | Source.Pair (a, b) -> Source.Pair (inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b)
| Source.Inl (ty, t) -> Source.Inl (ty, inline_term obligations unsafe_repr_eq t) | Source.Inl (ty, t) -> Source.Inl (ty, inline_term note unsafe_repr_eq t)
| Source.Inr (ty, t) -> Source.Inr (ty, inline_term obligations unsafe_repr_eq t) | Source.Inr (ty, t) -> Source.Inr (ty, inline_term note unsafe_repr_eq t)
| Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, inline_term obligations unsafe_repr_eq body) | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, inline_term note unsafe_repr_eq body)
| Source.App (f, a) -> | Source.App (f, a) ->
let f' = inline_term obligations unsafe_repr_eq f in let f' = inline_term note unsafe_repr_eq f in
let a' = inline_term obligations unsafe_repr_eq a in let a' = inline_term note unsafe_repr_eq a in
begin match f' with begin match f' with
| Source.Lam (x, _, _, body) when inline_cost body <= 8 -> | Source.Lam (x, _, _, body) when inline_cost body <= 8 ->
inline_term obligations unsafe_repr_eq (Source.substitute x a' body) note Preserve_relation "term beta inlining" ("inlined lambda argument " ^ x);
inline_term note unsafe_repr_eq (Source.substitute x a' body)
| _ -> Source.App (f', a') | _ -> Source.App (f', a')
end end
| Source.TyLam (x, body) -> Source.TyLam (x, inline_term obligations unsafe_repr_eq body) | Source.TyLam (x, body) -> Source.TyLam (x, inline_term note unsafe_repr_eq body)
| Source.TyApp (t, ty) -> | Source.TyApp (t, ty) ->
let t' = inline_term obligations unsafe_repr_eq t in let t' = inline_term note unsafe_repr_eq t in
begin match t' with begin match t' with
| Source.TyLam (x, body) -> inline_term obligations unsafe_repr_eq (Source.substitute_type x ty body) | Source.TyLam (x, body) ->
note Preserve_relation "type beta inlining" ("inlined type argument " ^ Types.string_of_typ ty ^ " for " ^ x);
inline_term note unsafe_repr_eq (Source.substitute_type x ty body)
| _ -> Source.TyApp (t', ty) | _ -> Source.TyApp (t', ty)
end end
| Source.Roll (ty, t) -> Source.Roll (ty, inline_term obligations unsafe_repr_eq t) | Source.Roll (ty, t) -> Source.Roll (ty, inline_term note unsafe_repr_eq t)
| Source.Unroll t -> Source.Unroll (inline_term obligations unsafe_repr_eq t) | Source.Unroll t -> Source.Unroll (inline_term note unsafe_repr_eq t)
| Source.Fix (f, ty, body) -> Source.Fix (f, ty, inline_term obligations unsafe_repr_eq body) | Source.Fix (f, ty, body) -> Source.Fix (f, ty, inline_term note unsafe_repr_eq body)
| Source.Eq (g, a, b) -> Source.Eq (g, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) | Source.Eq (g, a, b) -> Source.Eq (g, inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b)
| Source.If (c, t, e) -> | Source.If (c, t, e) ->
Source.If (inline_term obligations unsafe_repr_eq c, inline_term obligations unsafe_repr_eq t, inline_term obligations unsafe_repr_eq e) Source.If (inline_term note unsafe_repr_eq c, inline_term note unsafe_repr_eq t, inline_term note unsafe_repr_eq e)
| Source.Let (x, a, b) -> | Source.Let (x, a, b) ->
let a' = inline_term obligations unsafe_repr_eq a in let a' = inline_term note unsafe_repr_eq a in
let b' = inline_term obligations unsafe_repr_eq b in let b' = inline_term note unsafe_repr_eq b in
if inline_cost a' <= 4 then Source.substitute x a' b' else Source.Let (x, a', b') if inline_cost a' <= 4 then begin
| Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) note Preserve_relation "let inlining" ("inlined binding " ^ x);
Source.substitute x a' b'
end else Source.Let (x, a', b')
| Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, inline_term note unsafe_repr_eq a, inline_term note unsafe_repr_eq b)
| Source.Case (s, (x, l), (y, r)) -> | Source.Case (s, (x, l), (y, r)) ->
Source.Case (inline_term obligations unsafe_repr_eq s, (x, inline_term obligations unsafe_repr_eq l), (y, inline_term obligations unsafe_repr_eq r)) Source.Case (inline_term note unsafe_repr_eq s, (x, inline_term note unsafe_repr_eq l), (y, inline_term note unsafe_repr_eq r))
| other -> other | other -> other
let rec repr_of_typ = function let rec repr_of_typ = function
@@ -198,17 +203,12 @@ let compile flags source_type source_program =
obligations := obligation ~kind pass subject detail :: !obligations obligations := obligation ~kind pass subject detail :: !obligations
in in
let specialised = let specialised =
if flags.specialise then begin if flags.specialise then specialise_term (note "specialise") source_program
note "specialise" Preserve_relation "type application cloning" "specialisation preserves source typing but requires instantiation closedness"; else source_program
specialise_term obligations source_program
end else source_program
in in
let inlined = let inlined =
if flags.inline then begin if flags.inline then inline_term (note "inline") flags.unsafe_repr_eq specialised
if flags.unsafe_repr_eq then else specialised
note "inline" Exposed_representation "parametric body exposure" "inlining may replace an abstract constant relation with eq_int at specialised int instances";
inline_term obligations flags.unsafe_repr_eq specialised
end else specialised
in in
let target_program = let target_program =
if flags.repr_lower then begin if flags.repr_lower then begin
+13
View File
@@ -18,6 +18,9 @@ let find_case name =
| Some case -> case | Some case -> case
| None -> failwith ("missing case " ^ name) | None -> failwith ("missing case " ^ name)
let has_obligation kind obligations =
List.exists (fun (o : Pipeline.obligation) -> o.kind = kind) obligations
let () = let () =
List.iter List.iter
(fun (case : Corpus.case) -> (fun (case : Corpus.case) ->
@@ -33,6 +36,16 @@ let () =
assert_true assert_true
"expected representation exposure witness" "expected representation exposure witness"
(repr.failure_mode = Audit.Representation_exposure); (repr.failure_mode = Audit.Representation_exposure);
assert_true
"expected unsafe inlining to emit representation exposure obligation"
(has_obligation Pipeline.Exposed_representation repr.compiled.obligations);
let safe = Audit.audit_case (find_case "safe-polymorphic-instantiation") in
assert_true
"safe instantiation should not emit representation exposure obligation"
(not (has_obligation Pipeline.Exposed_representation safe.compiled.obligations));
assert_true
"safe instantiation should emit preserve relation obligations at rewrite sites"
(has_obligation Pipeline.Preserve_relation safe.compiled.obligations);
let strict = Audit.audit_case (find_case "strictness-induced-termination-change") in let strict = Audit.audit_case (find_case "strictness-induced-termination-change") in
assert_true assert_true
"expected strictness shift witness" "expected strictness shift witness"