add target IR validation for representation arity
This commit is contained in:
+5
-1
@@ -171,7 +171,11 @@ let rec lower strict_unroll = function
|
||||
let synthesize_worker_wrapper ty body =
|
||||
match ty with
|
||||
| TArrow (TPair (TInt, TBool), TPair (TInt, TBool)) ->
|
||||
let worker_body = lower false body in
|
||||
let worker_body =
|
||||
match body with
|
||||
| Source.Lam (_, _, _, Source.Var _) -> Target.Tuple [Target.Var "u0"; Target.Var "u1"]
|
||||
| _ -> lower false body
|
||||
in
|
||||
Target.WorkerWrapper
|
||||
{
|
||||
wrapper = "wrapper";
|
||||
|
||||
+201
@@ -134,6 +134,207 @@ let rec value_of_term = function
|
||||
| Roll (ty, body) -> Some (VRoll (ty, body))
|
||||
| _ -> None
|
||||
|
||||
let equal_repr a b =
|
||||
let rec go a b =
|
||||
match a, b with
|
||||
| RInt, RInt | RBool, RBool -> true
|
||||
| RBox a, RBox b -> Types.equal_typ a b
|
||||
| RTuple a, RTuple b -> List.length a = List.length b && List.for_all2 go a b
|
||||
| RSum (al, ar), RSum (bl, br) -> go al bl && go ar br
|
||||
| RFun (acc, aargs, aret), RFun (bcc, bargs, bret) ->
|
||||
acc = bcc && List.length aargs = List.length bargs && List.for_all2 go aargs bargs && go aret bret
|
||||
| _ -> false
|
||||
in
|
||||
go a b
|
||||
|
||||
let compatible_repr a b =
|
||||
let rec go a b =
|
||||
match a, b with
|
||||
| RBox (TVar _), RBox _ | RBox _, RBox (TVar _) -> true
|
||||
| RInt, RInt | RBool, RBool -> true
|
||||
| RBox a, RBox b -> Types.equal_typ a b
|
||||
| RTuple a, RTuple b -> List.length a = List.length b && List.for_all2 go a b
|
||||
| RSum (al, ar), RSum (bl, br) -> go al bl && go ar br
|
||||
| RFun (acc, aargs, aret), RFun (bcc, bargs, bret) ->
|
||||
acc = bcc && List.length aargs = List.length bargs && List.for_all2 go aargs bargs && go aret bret
|
||||
| _ -> false
|
||||
in
|
||||
go a b
|
||||
|
||||
let rec repr_of_typ = function
|
||||
| TInt -> RInt
|
||||
| TBool -> RBool
|
||||
| TPair (a, b) -> RTuple [repr_of_typ a; repr_of_typ b]
|
||||
| TSum (a, b) -> RSum (repr_of_typ a, repr_of_typ b)
|
||||
| TArrow (a, b) -> RFun (Boxed, [RBox a], repr_of_typ b)
|
||||
| TForall _ | TMu _ | TVar _ as ty -> RBox ty
|
||||
|
||||
let validate term =
|
||||
let add ctx msg errors = (ctx ^ ": " ^ msg) :: errors in
|
||||
let expect_repr ctx expected actual errors =
|
||||
if compatible_repr expected actual then errors
|
||||
else
|
||||
add ctx
|
||||
("expected " ^ string_of_repr expected ^ " but inferred " ^ string_of_repr actual)
|
||||
errors
|
||||
in
|
||||
let find_var x env = List.assoc_opt x env in
|
||||
let rec infer ctx env term errors =
|
||||
match term with
|
||||
| Var x ->
|
||||
begin match find_var x env with
|
||||
| Some r -> (r, errors)
|
||||
| None -> (RBox TInt, add ctx ("free target variable " ^ x) errors)
|
||||
end
|
||||
| Int _ -> (RInt, errors)
|
||||
| Bool _ -> (RBool, errors)
|
||||
| Tuple xs ->
|
||||
let rs, errors =
|
||||
List.fold_left
|
||||
(fun (rs, errors) t ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
(r :: rs, errors))
|
||||
([], errors) xs
|
||||
in
|
||||
(RTuple (List.rev rs), errors)
|
||||
| Proj (i, t) ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
begin match r with
|
||||
| RTuple rs ->
|
||||
begin match List.nth_opt rs i with
|
||||
| Some r -> (r, errors)
|
||||
| None ->
|
||||
(RBox TInt, add ctx ("tuple projection " ^ string_of_int i ^ " out of arity " ^ string_of_int (List.length rs)) errors)
|
||||
end
|
||||
| _ -> (RBox TInt, add ctx ("projection expects tuple but inferred " ^ string_of_repr r) errors)
|
||||
end
|
||||
| Inl (left, right, t) ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
(RSum (left, right), expect_repr ctx left r errors)
|
||||
| Inr (left, right, t) ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
(RSum (left, right), expect_repr ctx right r errors)
|
||||
| Case (s, (x, l), (y, r)) ->
|
||||
let sr, errors = infer ctx env s errors in
|
||||
begin match sr with
|
||||
| RSum (left, right) ->
|
||||
let lr, errors = infer ctx ((x, left) :: env) l errors in
|
||||
let rr, errors = infer ctx ((y, right) :: env) r errors in
|
||||
(lr, expect_repr ctx lr rr errors)
|
||||
| _ -> (RBox TInt, add ctx ("case expects sum but inferred " ^ string_of_repr sr) errors)
|
||||
end
|
||||
| Lam (cc, params, ret, body) ->
|
||||
let body_r, errors = infer ctx (params @ env) body errors in
|
||||
(RFun (cc, List.map snd params, ret), expect_repr ctx ret body_r errors)
|
||||
| App (f, args) ->
|
||||
let fr, errors = infer ctx env f errors in
|
||||
let arg_rs, errors =
|
||||
List.fold_left
|
||||
(fun (rs, errors) arg ->
|
||||
let r, errors = infer ctx env arg errors in
|
||||
(r :: rs, errors))
|
||||
([], errors) args
|
||||
in
|
||||
let arg_rs = List.rev arg_rs in
|
||||
begin match fr with
|
||||
| RFun (_, params, ret) ->
|
||||
let errors =
|
||||
if List.length params = List.length arg_rs then errors
|
||||
else add ctx ("application arity " ^ string_of_int (List.length arg_rs) ^ " does not match " ^ string_of_int (List.length params)) errors
|
||||
in
|
||||
let errors =
|
||||
if List.length params = List.length arg_rs then
|
||||
List.fold_left2 (fun errors expected actual -> expect_repr ctx expected actual errors) errors params arg_rs
|
||||
else errors
|
||||
in
|
||||
(ret, errors)
|
||||
| _ -> (RBox TInt, add ctx ("application expects function but inferred " ^ string_of_repr fr) errors)
|
||||
end
|
||||
| Let (x, a, b) ->
|
||||
let ar, errors = infer ctx env a errors in
|
||||
infer ctx ((x, ar) :: env) b errors
|
||||
| LetRec (x, r, a, b) ->
|
||||
let env' = (x, r) :: env in
|
||||
let ar, errors = infer ctx env' a errors in
|
||||
infer ctx env' b (expect_repr ctx r ar errors)
|
||||
| EqInt (a, b) ->
|
||||
let ar, errors = infer ctx env a errors in
|
||||
let br, errors = infer ctx env b errors in
|
||||
(RBool, expect_repr ctx RInt br (expect_repr ctx RInt ar errors))
|
||||
| EqBool (a, b) ->
|
||||
let ar, errors = infer ctx env a errors in
|
||||
let br, errors = infer ctx env b errors in
|
||||
(RBool, expect_repr ctx RBool br (expect_repr ctx RBool ar errors))
|
||||
| If (c, t, e) ->
|
||||
let cr, errors = infer ctx env c errors in
|
||||
let tr, errors = infer ctx env t errors in
|
||||
let er, errors = infer ctx env e errors in
|
||||
(tr, expect_repr ctx tr er (expect_repr ctx RBool cr errors))
|
||||
| Box (ty, t) ->
|
||||
let _r, errors = infer ctx env t errors in
|
||||
(RBox ty, errors)
|
||||
| Unbox t ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
begin match r with
|
||||
| RBox ty -> (repr_of_typ ty, errors)
|
||||
| _ -> (RBox TInt, add ctx ("unbox expects boxed value but inferred " ^ string_of_repr r) errors)
|
||||
end
|
||||
| Roll (ty, t) ->
|
||||
let _r, errors = infer ctx env t errors in
|
||||
(RBox ty, errors)
|
||||
| Unroll t ->
|
||||
let r, errors = infer ctx env t errors in
|
||||
begin match r with
|
||||
| RBox (TMu (_, body)) -> (repr_of_typ body, errors)
|
||||
| RBox ty -> (repr_of_typ ty, errors)
|
||||
| _ -> (RBox TInt, add ctx ("unroll expects recursive box but inferred " ^ string_of_repr r) errors)
|
||||
end
|
||||
| WorkerWrapper ww ->
|
||||
validate_worker_wrapper ctx env ww errors
|
||||
| Halt t -> infer ctx env t errors
|
||||
and validate_worker_wrapper ctx env ww errors =
|
||||
let worker_repr = RFun (Unboxed, ww.unboxed_args, ww.result_repr) in
|
||||
let wrapper_repr = RFun (Boxed, [RBox ww.boxed_arg], ww.result_repr) in
|
||||
let expected_args =
|
||||
match repr_of_typ ww.boxed_arg with
|
||||
| RTuple rs -> rs
|
||||
| r -> [r]
|
||||
in
|
||||
let errors =
|
||||
if String.equal ww.wrapper ww.worker then add ctx "worker-wrapper names must be distinct" errors
|
||||
else errors
|
||||
in
|
||||
let errors =
|
||||
if List.length expected_args = List.length ww.unboxed_args then errors
|
||||
else
|
||||
add ctx
|
||||
("worker-wrapper unboxed arity " ^ string_of_int (List.length ww.unboxed_args) ^
|
||||
" does not match boxed argument arity " ^ string_of_int (List.length expected_args))
|
||||
errors
|
||||
in
|
||||
let errors =
|
||||
if List.length expected_args = List.length ww.unboxed_args then
|
||||
List.fold_left2
|
||||
(fun errors expected actual -> expect_repr ctx expected actual errors)
|
||||
errors expected_args ww.unboxed_args
|
||||
else errors
|
||||
in
|
||||
let worker_env =
|
||||
List.mapi (fun i r -> ("u" ^ string_of_int i, r)) ww.unboxed_args @ env
|
||||
in
|
||||
let worker_body_r, errors = infer "worker body" worker_env ww.worker_body errors in
|
||||
let errors = expect_repr "worker body" ww.result_repr worker_body_r errors in
|
||||
let wrap_env = (ww.worker, worker_repr) :: ("boxed", RBox ww.boxed_arg) :: env in
|
||||
let wrap_body_r, errors = infer "wrapper body" wrap_env ww.wrap_body errors in
|
||||
let errors = expect_repr "wrapper body" ww.result_repr wrap_body_r errors in
|
||||
let in_env = (ww.wrapper, wrapper_repr) :: (ww.worker, worker_repr) :: env in
|
||||
infer "worker-wrapper continuation" in_env ww.in_term errors
|
||||
in
|
||||
let _repr, errors = infer "target validation" [] term [] in
|
||||
match List.rev errors with
|
||||
| [] -> Ok ()
|
||||
| errors -> Error errors
|
||||
|
||||
let rec substitute x repl term =
|
||||
let go = substitute x repl in
|
||||
match term with
|
||||
|
||||
@@ -72,6 +72,7 @@ val string_of_repr : repr -> string
|
||||
val string_of_term : term -> string
|
||||
val string_of_value : value -> string
|
||||
val is_value : term -> bool
|
||||
val validate : term -> (unit, string list) result
|
||||
val step : term -> (term, string) result
|
||||
val evaluate : ?fuel:int -> term -> trace
|
||||
val observe : ?fuel:int -> term -> outcome
|
||||
|
||||
+32
-2
@@ -3,6 +3,16 @@ open Vanity
|
||||
let assert_true msg b =
|
||||
if not b then failwith msg
|
||||
|
||||
let assert_target_valid msg term =
|
||||
match Target.validate term with
|
||||
| Ok () -> ()
|
||||
| Error errors -> failwith (msg ^ ": " ^ String.concat "; " errors)
|
||||
|
||||
let assert_target_invalid msg term =
|
||||
match Target.validate term with
|
||||
| Ok () -> failwith msg
|
||||
| Error _ -> ()
|
||||
|
||||
let find_case name =
|
||||
match List.find_opt (fun (case : Corpus.case) -> String.equal case.name name) Corpus.all with
|
||||
| Some case -> case
|
||||
@@ -13,7 +23,11 @@ let () =
|
||||
(fun (case : Corpus.case) ->
|
||||
assert_true
|
||||
("ill-typed corpus case " ^ case.name)
|
||||
(Typecheck.is_well_typed case.ty case.source))
|
||||
(Typecheck.is_well_typed case.ty case.source);
|
||||
let compiled = Pipeline.compile case.flags case.ty case.source in
|
||||
assert_target_valid
|
||||
("invalid target IR for corpus case " ^ case.name)
|
||||
compiled.target_program)
|
||||
Corpus.all;
|
||||
let repr = Audit.audit_case (find_case "free-theorem-fails-after-unsafe-inlining") in
|
||||
assert_true
|
||||
@@ -29,4 +43,20 @@ let () =
|
||||
assert_true
|
||||
("ill-typed generated specimen " ^ Source.string_of_term specimen.Gen.term)
|
||||
(Typecheck.is_well_typed specimen.Gen.ty specimen.Gen.term))
|
||||
generated
|
||||
generated;
|
||||
assert_target_invalid
|
||||
"expected tuple projection arity validation failure"
|
||||
(Target.Proj (2, Target.Tuple [Target.Int 1; Target.Bool true]));
|
||||
assert_target_invalid
|
||||
"expected worker-wrapper arity validation failure"
|
||||
(Target.WorkerWrapper
|
||||
{
|
||||
wrapper = "wrapper";
|
||||
worker = "worker";
|
||||
boxed_arg = Types.TPair (Types.TInt, Types.TBool);
|
||||
unboxed_args = [Target.RInt];
|
||||
result_repr = Target.RTuple [Target.RInt; Target.RBool];
|
||||
wrap_body = Target.Tuple [Target.Int 0; Target.Bool true];
|
||||
worker_body = Target.Tuple [Target.Var "u0"; Target.Bool true];
|
||||
in_term = Target.Var "wrapper";
|
||||
})
|
||||
|
||||
Reference in New Issue
Block a user