diff --git a/src/pipeline.ml b/src/pipeline.ml index daaaa29..fa0de0f 100644 --- a/src/pipeline.ml +++ b/src/pipeline.ml @@ -174,13 +174,9 @@ let rec lower strict_unroll = function | Source.Case (s, (x, l), (y, r)) -> Target.Case (lower strict_unroll s, (x, lower strict_unroll l), (y, lower strict_unroll r)) let synthesize_worker_wrapper ty body = - match ty with - | TArrow (TPair (TInt, TBool), TPair (TInt, TBool)) -> - let worker_body = - match body with - | Source.Lam (_, _, _, Source.Var _) -> Target.Tuple [Target.Var "u0"; Target.Var "u1"] - | _ -> lower false body - in + match ty, body with + | TArrow (TPair (TInt, TBool), TPair (TInt, TBool)), Source.Lam (arg, _, _, Source.Var result) + when String.equal arg result -> Target.WorkerWrapper { wrapper = "wrapper"; @@ -192,10 +188,12 @@ let synthesize_worker_wrapper ty body = Target.Let ("p", Target.Unbox (Target.Var "boxed"), Target.App (Target.Var "worker", [Target.Proj (0, Target.Var "p"); Target.Proj (1, Target.Var "p")])); - worker_body; + worker_body = Target.Tuple [Target.Var "u0"; Target.Var "u1"]; in_term = Target.Var "wrapper"; } - | _ -> lower false body + | TArrow (TPair (TInt, TBool), TPair (TInt, TBool)), _ -> + invalid_arg "unsupported worker-wrapper source shape" + | _ -> invalid_arg "unsupported worker-wrapper type" let compile flags source_type source_program = let obligations = ref [] in diff --git a/test/invariants.ml b/test/invariants.ml index e09c649..5c643b8 100644 --- a/test/invariants.ml +++ b/test/invariants.ml @@ -13,6 +13,13 @@ let assert_target_invalid msg term = | Ok () -> failwith msg | Error _ -> () +let assert_invalid_arg msg f = + try + let _ = f () in + failwith msg + with + | Invalid_argument _ -> () + let find_case name = match List.find_opt (fun (case : Corpus.case) -> String.equal case.name name) Corpus.all with | Some case -> case @@ -72,4 +79,15 @@ let () = wrap_body = Target.Tuple [Target.Int 0; Target.Bool true]; worker_body = Target.Tuple [Target.Var "u0"; Target.Bool true]; in_term = Target.Var "wrapper"; - }) + }); + assert_invalid_arg + "expected unsupported worker-wrapper source shape to be rejected" + (fun () -> + Pipeline.compile + { Pipeline.default_flags with unsafe_repr_eq = false; unsafe_strict_unroll = false } + (Types.TArrow (Types.TPair (Types.TInt, Types.TBool), Types.TPair (Types.TInt, Types.TBool))) + (Source.Lam + ( "p", + Types.TPair (Types.TInt, Types.TBool), + Types.TPair (Types.TInt, Types.TBool), + Source.Pair (Source.Int 0, Source.Bool true) )))