reject unsupported worker-wrapper source shapes

This commit is contained in:
2026-02-15 12:41:55 +00:00
parent ba72a797e5
commit 5bf83f0933
2 changed files with 26 additions and 10 deletions
+7 -9
View File
@@ -174,13 +174,9 @@ let rec lower strict_unroll = function
| Source.Case (s, (x, l), (y, r)) -> Target.Case (lower strict_unroll s, (x, lower strict_unroll l), (y, lower strict_unroll r))
let synthesize_worker_wrapper ty body =
match ty with
| TArrow (TPair (TInt, TBool), TPair (TInt, TBool)) ->
let worker_body =
match body with
| Source.Lam (_, _, _, Source.Var _) -> Target.Tuple [Target.Var "u0"; Target.Var "u1"]
| _ -> lower false body
in
match ty, body with
| TArrow (TPair (TInt, TBool), TPair (TInt, TBool)), Source.Lam (arg, _, _, Source.Var result)
when String.equal arg result ->
Target.WorkerWrapper
{
wrapper = "wrapper";
@@ -192,10 +188,12 @@ let synthesize_worker_wrapper ty body =
Target.Let
("p", Target.Unbox (Target.Var "boxed"),
Target.App (Target.Var "worker", [Target.Proj (0, Target.Var "p"); Target.Proj (1, Target.Var "p")]));
worker_body;
worker_body = Target.Tuple [Target.Var "u0"; Target.Var "u1"];
in_term = Target.Var "wrapper";
}
| _ -> lower false body
| TArrow (TPair (TInt, TBool), TPair (TInt, TBool)), _ ->
invalid_arg "unsupported worker-wrapper source shape"
| _ -> invalid_arg "unsupported worker-wrapper type"
let compile flags source_type source_program =
let obligations = ref [] in
+19 -1
View File
@@ -13,6 +13,13 @@ let assert_target_invalid msg term =
| Ok () -> failwith msg
| Error _ -> ()
let assert_invalid_arg msg f =
try
let _ = f () in
failwith msg
with
| Invalid_argument _ -> ()
let find_case name =
match List.find_opt (fun (case : Corpus.case) -> String.equal case.name name) Corpus.all with
| Some case -> case
@@ -72,4 +79,15 @@ let () =
wrap_body = Target.Tuple [Target.Int 0; Target.Bool true];
worker_body = Target.Tuple [Target.Var "u0"; Target.Bool true];
in_term = Target.Var "wrapper";
})
});
assert_invalid_arg
"expected unsupported worker-wrapper source shape to be rejected"
(fun () ->
Pipeline.compile
{ Pipeline.default_flags with unsafe_repr_eq = false; unsafe_strict_unroll = false }
(Types.TArrow (Types.TPair (Types.TInt, Types.TBool), Types.TPair (Types.TInt, Types.TBool)))
(Source.Lam
( "p",
Types.TPair (Types.TInt, Types.TBool),
Types.TPair (Types.TInt, Types.TBool),
Source.Pair (Source.Int 0, Source.Bool true) )))