diff --git a/src/pipeline.ml b/src/pipeline.ml index 6dca8b6..d251354 100644 --- a/src/pipeline.ml +++ b/src/pipeline.ml @@ -171,7 +171,11 @@ let rec lower strict_unroll = function let synthesize_worker_wrapper ty body = match ty with | TArrow (TPair (TInt, TBool), TPair (TInt, TBool)) -> - let worker_body = lower false body in + let worker_body = + match body with + | Source.Lam (_, _, _, Source.Var _) -> Target.Tuple [Target.Var "u0"; Target.Var "u1"] + | _ -> lower false body + in Target.WorkerWrapper { wrapper = "wrapper"; diff --git a/src/target.ml b/src/target.ml index 0b25726..5b5e579 100644 --- a/src/target.ml +++ b/src/target.ml @@ -134,6 +134,207 @@ let rec value_of_term = function | Roll (ty, body) -> Some (VRoll (ty, body)) | _ -> None +let equal_repr a b = + let rec go a b = + match a, b with + | RInt, RInt | RBool, RBool -> true + | RBox a, RBox b -> Types.equal_typ a b + | RTuple a, RTuple b -> List.length a = List.length b && List.for_all2 go a b + | RSum (al, ar), RSum (bl, br) -> go al bl && go ar br + | RFun (acc, aargs, aret), RFun (bcc, bargs, bret) -> + acc = bcc && List.length aargs = List.length bargs && List.for_all2 go aargs bargs && go aret bret + | _ -> false + in + go a b + +let compatible_repr a b = + let rec go a b = + match a, b with + | RBox (TVar _), RBox _ | RBox _, RBox (TVar _) -> true + | RInt, RInt | RBool, RBool -> true + | RBox a, RBox b -> Types.equal_typ a b + | RTuple a, RTuple b -> List.length a = List.length b && List.for_all2 go a b + | RSum (al, ar), RSum (bl, br) -> go al bl && go ar br + | RFun (acc, aargs, aret), RFun (bcc, bargs, bret) -> + acc = bcc && List.length aargs = List.length bargs && List.for_all2 go aargs bargs && go aret bret + | _ -> false + in + go a b + +let rec repr_of_typ = function + | TInt -> RInt + | TBool -> RBool + | TPair (a, b) -> RTuple [repr_of_typ a; repr_of_typ b] + | TSum (a, b) -> RSum (repr_of_typ a, repr_of_typ b) + | TArrow (a, b) -> RFun (Boxed, [RBox a], repr_of_typ b) + | TForall _ | TMu _ | TVar _ as ty -> RBox ty + +let validate term = + let add ctx msg errors = (ctx ^ ": " ^ msg) :: errors in + let expect_repr ctx expected actual errors = + if compatible_repr expected actual then errors + else + add ctx + ("expected " ^ string_of_repr expected ^ " but inferred " ^ string_of_repr actual) + errors + in + let find_var x env = List.assoc_opt x env in + let rec infer ctx env term errors = + match term with + | Var x -> + begin match find_var x env with + | Some r -> (r, errors) + | None -> (RBox TInt, add ctx ("free target variable " ^ x) errors) + end + | Int _ -> (RInt, errors) + | Bool _ -> (RBool, errors) + | Tuple xs -> + let rs, errors = + List.fold_left + (fun (rs, errors) t -> + let r, errors = infer ctx env t errors in + (r :: rs, errors)) + ([], errors) xs + in + (RTuple (List.rev rs), errors) + | Proj (i, t) -> + let r, errors = infer ctx env t errors in + begin match r with + | RTuple rs -> + begin match List.nth_opt rs i with + | Some r -> (r, errors) + | None -> + (RBox TInt, add ctx ("tuple projection " ^ string_of_int i ^ " out of arity " ^ string_of_int (List.length rs)) errors) + end + | _ -> (RBox TInt, add ctx ("projection expects tuple but inferred " ^ string_of_repr r) errors) + end + | Inl (left, right, t) -> + let r, errors = infer ctx env t errors in + (RSum (left, right), expect_repr ctx left r errors) + | Inr (left, right, t) -> + let r, errors = infer ctx env t errors in + (RSum (left, right), expect_repr ctx right r errors) + | Case (s, (x, l), (y, r)) -> + let sr, errors = infer ctx env s errors in + begin match sr with + | RSum (left, right) -> + let lr, errors = infer ctx ((x, left) :: env) l errors in + let rr, errors = infer ctx ((y, right) :: env) r errors in + (lr, expect_repr ctx lr rr errors) + | _ -> (RBox TInt, add ctx ("case expects sum but inferred " ^ string_of_repr sr) errors) + end + | Lam (cc, params, ret, body) -> + let body_r, errors = infer ctx (params @ env) body errors in + (RFun (cc, List.map snd params, ret), expect_repr ctx ret body_r errors) + | App (f, args) -> + let fr, errors = infer ctx env f errors in + let arg_rs, errors = + List.fold_left + (fun (rs, errors) arg -> + let r, errors = infer ctx env arg errors in + (r :: rs, errors)) + ([], errors) args + in + let arg_rs = List.rev arg_rs in + begin match fr with + | RFun (_, params, ret) -> + let errors = + if List.length params = List.length arg_rs then errors + else add ctx ("application arity " ^ string_of_int (List.length arg_rs) ^ " does not match " ^ string_of_int (List.length params)) errors + in + let errors = + if List.length params = List.length arg_rs then + List.fold_left2 (fun errors expected actual -> expect_repr ctx expected actual errors) errors params arg_rs + else errors + in + (ret, errors) + | _ -> (RBox TInt, add ctx ("application expects function but inferred " ^ string_of_repr fr) errors) + end + | Let (x, a, b) -> + let ar, errors = infer ctx env a errors in + infer ctx ((x, ar) :: env) b errors + | LetRec (x, r, a, b) -> + let env' = (x, r) :: env in + let ar, errors = infer ctx env' a errors in + infer ctx env' b (expect_repr ctx r ar errors) + | EqInt (a, b) -> + let ar, errors = infer ctx env a errors in + let br, errors = infer ctx env b errors in + (RBool, expect_repr ctx RInt br (expect_repr ctx RInt ar errors)) + | EqBool (a, b) -> + let ar, errors = infer ctx env a errors in + let br, errors = infer ctx env b errors in + (RBool, expect_repr ctx RBool br (expect_repr ctx RBool ar errors)) + | If (c, t, e) -> + let cr, errors = infer ctx env c errors in + let tr, errors = infer ctx env t errors in + let er, errors = infer ctx env e errors in + (tr, expect_repr ctx tr er (expect_repr ctx RBool cr errors)) + | Box (ty, t) -> + let _r, errors = infer ctx env t errors in + (RBox ty, errors) + | Unbox t -> + let r, errors = infer ctx env t errors in + begin match r with + | RBox ty -> (repr_of_typ ty, errors) + | _ -> (RBox TInt, add ctx ("unbox expects boxed value but inferred " ^ string_of_repr r) errors) + end + | Roll (ty, t) -> + let _r, errors = infer ctx env t errors in + (RBox ty, errors) + | Unroll t -> + let r, errors = infer ctx env t errors in + begin match r with + | RBox (TMu (_, body)) -> (repr_of_typ body, errors) + | RBox ty -> (repr_of_typ ty, errors) + | _ -> (RBox TInt, add ctx ("unroll expects recursive box but inferred " ^ string_of_repr r) errors) + end + | WorkerWrapper ww -> + validate_worker_wrapper ctx env ww errors + | Halt t -> infer ctx env t errors + and validate_worker_wrapper ctx env ww errors = + let worker_repr = RFun (Unboxed, ww.unboxed_args, ww.result_repr) in + let wrapper_repr = RFun (Boxed, [RBox ww.boxed_arg], ww.result_repr) in + let expected_args = + match repr_of_typ ww.boxed_arg with + | RTuple rs -> rs + | r -> [r] + in + let errors = + if String.equal ww.wrapper ww.worker then add ctx "worker-wrapper names must be distinct" errors + else errors + in + let errors = + if List.length expected_args = List.length ww.unboxed_args then errors + else + add ctx + ("worker-wrapper unboxed arity " ^ string_of_int (List.length ww.unboxed_args) ^ + " does not match boxed argument arity " ^ string_of_int (List.length expected_args)) + errors + in + let errors = + if List.length expected_args = List.length ww.unboxed_args then + List.fold_left2 + (fun errors expected actual -> expect_repr ctx expected actual errors) + errors expected_args ww.unboxed_args + else errors + in + let worker_env = + List.mapi (fun i r -> ("u" ^ string_of_int i, r)) ww.unboxed_args @ env + in + let worker_body_r, errors = infer "worker body" worker_env ww.worker_body errors in + let errors = expect_repr "worker body" ww.result_repr worker_body_r errors in + let wrap_env = (ww.worker, worker_repr) :: ("boxed", RBox ww.boxed_arg) :: env in + let wrap_body_r, errors = infer "wrapper body" wrap_env ww.wrap_body errors in + let errors = expect_repr "wrapper body" ww.result_repr wrap_body_r errors in + let in_env = (ww.wrapper, wrapper_repr) :: (ww.worker, worker_repr) :: env in + infer "worker-wrapper continuation" in_env ww.in_term errors + in + let _repr, errors = infer "target validation" [] term [] in + match List.rev errors with + | [] -> Ok () + | errors -> Error errors + let rec substitute x repl term = let go = substitute x repl in match term with diff --git a/src/target.mli b/src/target.mli index c9d48e1..ab99b6b 100644 --- a/src/target.mli +++ b/src/target.mli @@ -72,6 +72,7 @@ val string_of_repr : repr -> string val string_of_term : term -> string val string_of_value : value -> string val is_value : term -> bool +val validate : term -> (unit, string list) result val step : term -> (term, string) result val evaluate : ?fuel:int -> term -> trace val observe : ?fuel:int -> term -> outcome diff --git a/test/invariants.ml b/test/invariants.ml index a1db831..31e8021 100644 --- a/test/invariants.ml +++ b/test/invariants.ml @@ -3,6 +3,16 @@ open Vanity let assert_true msg b = if not b then failwith msg +let assert_target_valid msg term = + match Target.validate term with + | Ok () -> () + | Error errors -> failwith (msg ^ ": " ^ String.concat "; " errors) + +let assert_target_invalid msg term = + match Target.validate term with + | Ok () -> failwith msg + | Error _ -> () + let find_case name = match List.find_opt (fun (case : Corpus.case) -> String.equal case.name name) Corpus.all with | Some case -> case @@ -13,7 +23,11 @@ let () = (fun (case : Corpus.case) -> assert_true ("ill-typed corpus case " ^ case.name) - (Typecheck.is_well_typed case.ty case.source)) + (Typecheck.is_well_typed case.ty case.source); + let compiled = Pipeline.compile case.flags case.ty case.source in + assert_target_valid + ("invalid target IR for corpus case " ^ case.name) + compiled.target_program) Corpus.all; let repr = Audit.audit_case (find_case "free-theorem-fails-after-unsafe-inlining") in assert_true @@ -29,4 +43,20 @@ let () = assert_true ("ill-typed generated specimen " ^ Source.string_of_term specimen.Gen.term) (Typecheck.is_well_typed specimen.Gen.ty specimen.Gen.term)) - generated + generated; + assert_target_invalid + "expected tuple projection arity validation failure" + (Target.Proj (2, Target.Tuple [Target.Int 1; Target.Bool true])); + assert_target_invalid + "expected worker-wrapper arity validation failure" + (Target.WorkerWrapper + { + wrapper = "wrapper"; + worker = "worker"; + boxed_arg = Types.TPair (Types.TInt, Types.TBool); + unboxed_args = [Target.RInt]; + result_repr = Target.RTuple [Target.RInt; Target.RBool]; + wrap_body = Target.Tuple [Target.Int 0; Target.Bool true]; + worker_body = Target.Tuple [Target.Var "u0"; Target.Bool true]; + in_term = Target.Var "wrapper"; + })