commit 42cde2128a61a30d70dd23249e3b05da7ba74034 Author: imiel Date: Wed Feb 11 17:24:09 2026 +0000 inital commit diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..5e60301 --- /dev/null +++ b/.gitignore @@ -0,0 +1,3 @@ +_build/ +.codex +*.install diff --git a/README.md b/README.md new file mode 100644 index 0000000..22f48c2 --- /dev/null +++ b/README.md @@ -0,0 +1,36 @@ +# vanity + +executable artefact about parametricity failure under specialisation and inlining with representation lowering + +this project models a small typed source calculus and then lowers it into an explicit representation calculus and audits whether the source observation and target observation remain related. the surface is semantic rather than structural, as it is about abstraction theorems, representation relations, strictness shifts, and the proof obligations that an optimiser assumes when it rewrites polymorphic code + +the target language makes representation explicit and forces each transition to be spelled out in the term. values are either boxed or unboxed; they can be projected from products, injected into sums, unpacked when needed, and moved across worker wrapper boundaries as part of calling convention shifts + +the target calculus is intentionally lower level than the source and trades abstraction for control over layout and calling. it has decent tuple and sum representations and uses explicit box and unbox operations to mediate between them. equality on integers and booleans is primitive rather than encoded, and recursive binding is built in rather than derived + +## example(s) + +a representation exposure witness starts from a polymorphic constant false term + +```ocaml +forall a. function a function a bool +``` + +under a safe profile, instantiating this term at `bool` preserves the baseline relation. 
(** Optimisation profiles exercised by the demo, in the order they are
    reported. Every profile starts from [safe_flags] except "unsafe-unbox",
    which uses [Pipeline.default_flags] unchanged. Only "specialise-only" is
    audited against [Relation.Baseline]; all others use
    [Relation.Boxed_unboxed]. *)
let profiles =
  let boxed name flags = { name; flags; relation = Relation.Boxed_unboxed } in
  let baseline name flags = { name; flags; relation = Relation.Baseline } in
  [
    boxed "safe" safe_flags;
    baseline "specialise-only"
      { safe_flags with Pipeline.inline = false; repr_lower = false };
    boxed "inline-no-repr-leak" { safe_flags with Pipeline.inline = true };
    boxed "unsafe-inline" { safe_flags with Pipeline.unsafe_repr_eq = true };
    boxed "unsafe-unbox" Pipeline.default_flags;
    boxed "unsafe-strictness"
      { safe_flags with Pipeline.inline = false; unsafe_strict_unroll = true };
  ]
(** Runs one generated campaign (160 specimens, depth 4) under the given
    profile's flags and relation, then prints how many specimens stayed
    related and how many did not. *)
let run_profile { name; flags; relation } =
  let { Gen.total; related; _ } =
    Gen.run_campaign flags relation ~count:160 ~max_depth:4 ()
  in
  let violations = total - related in
  Printf.printf
    "profile %s\nrelated %d/%d\nviolations %d\n\n"
    name related total violations
(** [classify typecheck compiled comparison] assigns a failure mode to one
    audited case.

    Precedence, from highest to lowest:
    - a typechecking error yields [Type_error] regardless of the verdict;
    - a [Related] verdict yields [Preserved];
    - an unrelated verdict is a [Representation_exposure] when the pipeline
      recorded an [Exposed_representation] obligation;
    - otherwise it is a [Strictness_shift] when the comparison message is the
      "target diverged" message emitted by [Relation.compare_programs], or
      when a [Strictness_risk] obligation was recorded;
    - anything else is [Other_failure] carrying the comparison message.

    The unrelated case binds the message once and branches with [if], so no
    clause introduces a pattern variable it does not use (the guarded clauses
    of the previous formulation tripped the unused-variable warning). *)
let classify typecheck compiled comparison =
  match typecheck, comparison.Relation.verdict with
  | Error msg, _ -> Type_error msg
  | Ok (), Relation.Related -> Preserved
  | Ok (), Relation.Unrelated msg ->
    let recorded kind = has_obligation kind compiled.Pipeline.obligations in
    if recorded Pipeline.Exposed_representation then Representation_exposure
    else if
      (* Must stay byte-identical to the message built in Relation. *)
      String.equal msg "target diverged where source terminated"
      || recorded Pipeline.Strictness_risk
    then Strictness_shift
    else Other_failure msg
(** Audits a single corpus case: compiles the source with the case's flags,
    compares source and target observations under the case's relation,
    evaluates every intermediate program, typechecks the source, and finally
    classifies the outcome and annotates each recorded obligation with a
    status derived from that classification. *)
let audit_case (case : Corpus.case) =
  let { Corpus.flags; ty; source; relation; _ } = case in
  let compiled = Pipeline.compile flags ty source in
  let { Pipeline.specialised; inlined; target_program; obligations = pending; _ } =
    compiled
  in
  let comparison = Relation.compare_programs relation ty source target_program in
  let source_trace = Source.evaluate source in
  let specialised_trace = Source.evaluate specialised in
  let inlined_trace = Source.evaluate inlined in
  let target_trace = Target.evaluate target_program in
  let typecheck = Typecheck.check ty source in
  let failure_mode = classify typecheck compiled comparison in
  {
    case;
    compiled;
    comparison;
    source_trace;
    specialised_trace;
    inlined_trace;
    target_trace;
    typecheck;
    failure_mode;
    obligations = List.map (status_for_obligation failure_mode) pending;
  }
(** Renders a Markdown summary table with one row per audit: case name,
    classification, verdict, and the source and target outcomes. Rows follow
    the order of [audits]; lines are joined with a newline and there is no
    trailing newline. *)
let emit_matrix audits =
  let header =
    [
      "# audit matrix";
      "";
      "| case | classification | verdict | source outcome | target outcome |";
      "| --- | --- | --- | --- | --- |";
    ]
  in
  let render audit =
    String.concat ""
      [
        "| ";
        audit.case.Corpus.name;
        " | `";
        failure_mode_to_string audit.failure_mode;
        "` | `";
        Reporting.verdict_to_string audit.comparison.Relation.verdict;
        "` | `";
        source_outcome audit.source_trace;
        "` | `";
        target_outcome audit.target_trace;
        "` |";
      ]
  in
  String.concat "\n" (header @ List.map render audits)
(** The polymorphic constant-false term: a type abstraction over ["a"] whose
    body takes two arguments of type ["a"] and returns [false] without
    inspecting either of them. *)
let poly_const_false =
  let a = TVar "a" in
  let ignore_second = Source.Lam ("y", a, TBool, Source.Bool false) in
  let ignore_first = Source.Lam ("x", a, TArrow (a, TBool), ignore_second) in
  Source.TyLam ("a", ignore_first)
"worker-wrapper preserves the boxed to unboxed relation when argument and result projections are pointwise related"; + ty = TArrow (TPair (TInt, TBool), TPair (TInt, TBool)); + source = worker_wrapper_id; + flags = safer; + relation = Relation.Boxed_unboxed; + }; + { + name = "boxed-unboxed-pair-correspondence"; + summary = "project from a boxed source pair after lowering to an unboxed tuple"; + claim = "representation lowering is benign when no representation specific primitive crosses the abstraction boundary"; + ty = TInt; + source = boxed_pair_projection; + flags = safer; + relation = Relation.Boxed_unboxed; + }; + { + name = "sum-lowering-preserves-branching"; + summary = "preserve cbv case analysis across unboxed sum lowering"; + claim = "sum lowering preserves branch choice and payload relation when the scrutinee stays representation parametric"; + ty = TInt; + source = pair_sum_roundtrip; + flags = safer; + relation = Relation.Boxed_unboxed; + }; + { + name = "safe-polymorphic-instantiation"; + summary = "instantiate a polymorphic constant function at bool without exposing ground equality"; + claim = "specialisation alone neednt break a free theorem if inlining doesnt smuggle a representation test"; + ty = TBool; + source = safe_poly_bool; + flags = safer; + relation = Relation.Baseline; + }; + { + name = "free-theorem-fails-after-unsafe-inlining"; + summary = "inline a specialised polymorphic constant into `eq_int`"; + claim = "aggressive inlining breaks the abstraction theorem once it rewrites an abstract branch into a representation specific primitive"; + ty = TBool; + source = + Source.Let + ( "poly_const_false", + poly_const_false, + Source.App + ( Source.App (Source.TyApp (Source.Var "poly_const_false", TInt), Source.Int 1), + Source.Int 1 ) ); + flags = default; + relation = Relation.Baseline; + }; + { + name = "strictness-induced-termination-change"; + summary = "force a recursive payload earlier by lowering `roll` too aggressively"; + claim = 
(** [gen_typ depth] draws a random type built from [TInt], [TBool], [TPair]
    and [TSum], nested at most [depth] levels deep.

    At depth 0 only the two ground types are candidates. At positive depth
    the candidate list is built in full before [pick] selects from it, so
    each call performs four recursive calls (two per compound candidate) and
    the sub-types of the candidates that are not selected still consume draws
    from the module-level [rng]. Reordering or making the candidates lazy
    would therefore change the sequence of generated specimens. *)
let rec gen_typ depth =
  if depth = 0 then pick [TInt; TBool]
  else
    pick
      [
        TInt;
        TBool;
        TPair (gen_typ (depth - 1), gen_typ (depth - 1));
        TSum (gen_typ (depth - 1), gen_typ (depth - 1));
      ]
(** [shrink specimen] proposes strictly smaller specimens derived from
    [specimen], or [[]] when no shrink applies.

    - A pair term at a pair type shrinks to each of its components, each
      paired with the matching component type so the candidates stay
      well-typed. (Previously the first component was returned at the whole
      pair type, which is ill-typed, and the second component was never
      offered.)
    - A conditional shrinks to either branch at the unchanged type.
    - A pair term whose recorded type is not a pair type is not shrunk, since
      no component type can be recovered for it. *)
let shrink specimen =
  match specimen.term, specimen.ty with
  | Source.Pair (a, b), TPair (ta, tb) ->
    [ { term = a; ty = ta }; { term = b; ty = tb } ]
  | Source.If (_, t, e), _ ->
    [ { specimen with term = t }; { specimen with term = e } ]
  | _ -> []
(** [specialise_term obligations term] rewrites [term] bottom-up, replacing
    every type application whose (specialised) head is a type abstraction by
    the abstraction body with the type argument substituted in. All other
    constructors are rebuilt around their specialised sub-terms; variables
    and literals are returned unchanged.

    [obligations] is only threaded through the recursion; this pass does not
    read or extend it.

    The type-application case specialises the head first and only then looks
    for a [TyLam]. The earlier formulation matched the redex directly and
    recursed on the very same term, which never terminated on any program
    containing a literal [TyApp (TyLam _, _)]. Handling it in the general
    case also reduces redexes that only appear after the head has been
    specialised (for example nested type applications). *)
let rec specialise_term obligations term =
  let go = specialise_term obligations in
  match term with
  | Source.TyApp (t, ty) ->
    begin match go t with
    | Source.TyLam (x, body) ->
      (* [body] is already specialised: it came out of [go] on the TyLam. *)
      Source.substitute_type x ty body
    | t' -> Source.TyApp (t', ty)
    end
  | Source.Pair (a, b) -> Source.Pair (go a, go b)
  | Source.Inl (ty, t) -> Source.Inl (ty, go t)
  | Source.Inr (ty, t) -> Source.Inr (ty, go t)
  | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, go body)
  | Source.App (f, a) -> Source.App (go f, go a)
  | Source.TyLam (x, body) -> Source.TyLam (x, go body)
  | Source.Roll (ty, t) -> Source.Roll (ty, go t)
  | Source.Unroll t -> Source.Unroll (go t)
  | Source.Fix (f, ty, body) -> Source.Fix (f, ty, go body)
  | Source.Eq (g, a, b) -> Source.Eq (g, go a, go b)
  | Source.If (c, t, e) -> Source.If (go c, go t, go e)
  | Source.Let (x, a, b) -> Source.Let (x, go a, go b)
  | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, go a, go b)
  | Source.Case (s, (x, l), (y, r)) -> Source.Case (go s, (x, go l), (y, go r))
  | (Source.Var _ | Source.Int _ | Source.Bool _) as leaf -> leaf
inline_cost a + | Source.TyApp (f, _) -> 1 + inline_cost f + | Source.Pair (a, b) -> 1 + inline_cost a + inline_cost b + | Source.Inl (_, t) | Source.Inr (_, t) | Source.Roll (_, t) | Source.Unroll t -> 1 + inline_cost t + | Source.Fix (_, _, body) -> 2 + inline_cost body + | Source.Eq (_, a, b) -> 1 + inline_cost a + inline_cost b + | Source.If (c, t, e) -> 1 + inline_cost c + inline_cost t + inline_cost e + | Source.Let (_, a, b) -> inline_cost a + inline_cost b + | Source.LetPair (_, _, a, b) -> 1 + inline_cost a + inline_cost b + | Source.Case (s, (_, l), (_, r)) -> 1 + inline_cost s + inline_cost l + inline_cost r + +let rec inline_term obligations unsafe_repr_eq term = + match term with + | Source.App (Source.Lam (x, _, _, body), arg) when inline_cost body <= 8 -> + Source.substitute x arg (inline_term obligations unsafe_repr_eq body) + | Source.TyApp (Source.TyLam (x, body), ty) -> + Source.substitute_type x ty (inline_term obligations unsafe_repr_eq body) + | Source.TyApp (Source.Var f, TInt) when unsafe_repr_eq && String.equal f "poly_const_false" -> + ignore obligations; + Source.Lam ("x", TInt, TArrow (TInt, TBool), Source.Lam ("y", TInt, TBool, Source.Eq (GInt, Source.Var "x", Source.Var "y"))) + | Source.Pair (a, b) -> Source.Pair (inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) + | Source.Inl (ty, t) -> Source.Inl (ty, inline_term obligations unsafe_repr_eq t) + | Source.Inr (ty, t) -> Source.Inr (ty, inline_term obligations unsafe_repr_eq t) + | Source.Lam (x, a, b, body) -> Source.Lam (x, a, b, inline_term obligations unsafe_repr_eq body) + | Source.App (f, a) -> + let f' = inline_term obligations unsafe_repr_eq f in + let a' = inline_term obligations unsafe_repr_eq a in + begin match f' with + | Source.Lam (x, _, _, body) when inline_cost body <= 8 -> + inline_term obligations unsafe_repr_eq (Source.substitute x a' body) + | _ -> Source.App (f', a') + end + | Source.TyLam (x, body) -> Source.TyLam (x, inline_term 
obligations unsafe_repr_eq body) + | Source.TyApp (t, ty) -> + let t' = inline_term obligations unsafe_repr_eq t in + begin match t' with + | Source.TyLam (x, body) -> inline_term obligations unsafe_repr_eq (Source.substitute_type x ty body) + | _ -> Source.TyApp (t', ty) + end + | Source.Roll (ty, t) -> Source.Roll (ty, inline_term obligations unsafe_repr_eq t) + | Source.Unroll t -> Source.Unroll (inline_term obligations unsafe_repr_eq t) + | Source.Fix (f, ty, body) -> Source.Fix (f, ty, inline_term obligations unsafe_repr_eq body) + | Source.Eq (g, a, b) -> Source.Eq (g, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) + | Source.If (c, t, e) -> + Source.If (inline_term obligations unsafe_repr_eq c, inline_term obligations unsafe_repr_eq t, inline_term obligations unsafe_repr_eq e) + | Source.Let (x, a, b) -> + let a' = inline_term obligations unsafe_repr_eq a in + let b' = inline_term obligations unsafe_repr_eq b in + if inline_cost a' <= 4 then Source.substitute x a' b' else Source.Let (x, a', b') + | Source.LetPair (x, y, a, b) -> Source.LetPair (x, y, inline_term obligations unsafe_repr_eq a, inline_term obligations unsafe_repr_eq b) + | Source.Case (s, (x, l), (y, r)) -> + Source.Case (inline_term obligations unsafe_repr_eq s, (x, inline_term obligations unsafe_repr_eq l), (y, inline_term obligations unsafe_repr_eq r)) + | other -> other + +let rec repr_of_typ = function + | TInt -> Target.RInt + | TBool -> Target.RBool + | TPair (a, b) -> Target.RTuple [repr_of_typ a; repr_of_typ b] + | TSum (a, b) -> Target.RSum (repr_of_typ a, repr_of_typ b) + | TArrow (a, b) -> Target.RFun (Target.Boxed, [Target.RBox a], repr_of_typ b) + | TForall _ | TMu _ | TVar _ as ty -> Target.RBox ty + +let rec lower strict_unroll = function + | Source.Var x -> Target.Var x + | Source.Int n -> Target.Int n + | Source.Bool b -> Target.Bool b + | Source.Pair (a, b) -> Target.Tuple [lower strict_unroll a; lower strict_unroll b] + | Source.Inl (TSum 
(a, b), t) -> Target.Inl (repr_of_typ a, repr_of_typ b, lower strict_unroll t) + | Source.Inl (_, t) -> Target.Inl (Target.RBox TInt, Target.RBox TInt, lower strict_unroll t) + | Source.Inr (TSum (a, b), t) -> Target.Inr (repr_of_typ a, repr_of_typ b, lower strict_unroll t) + | Source.Inr (_, t) -> Target.Inr (Target.RBox TInt, Target.RBox TInt, lower strict_unroll t) + | Source.Lam (x, arg_ty, res_ty, body) -> + Target.Lam (Target.Boxed, [x, Target.RBox arg_ty], repr_of_typ res_ty, lower strict_unroll body) + | Source.App (f, a) -> Target.App (lower strict_unroll f, [Target.Box (TInt, lower strict_unroll a)]) + | Source.TyLam (_, body) -> lower strict_unroll body + | Source.TyApp (t, _) -> lower strict_unroll t + | Source.Roll (ty, body) -> + if strict_unroll then + Target.Let ("forced_roll_payload", lower strict_unroll body, Target.Roll (ty, Target.Var "forced_roll_payload")) + else Target.Box (ty, Target.Roll (ty, lower strict_unroll body)) + | Source.Unroll t -> + if strict_unroll then Target.Unroll (lower strict_unroll t) + else Target.Unroll (Target.Unbox (lower strict_unroll t)) + | Source.Fix (f, ty, body) -> + Target.LetRec (f, repr_of_typ ty, lower strict_unroll body, Target.Var f) + | Source.Eq (GInt, a, b) -> Target.EqInt (lower strict_unroll a, lower strict_unroll b) + | Source.Eq (GBool, a, b) -> Target.EqBool (lower strict_unroll a, lower strict_unroll b) + | Source.If (c, t, e) -> + Target.If (lower strict_unroll c, lower strict_unroll t, lower strict_unroll e) + | Source.Let (x, a, b) -> Target.Let (x, lower strict_unroll a, lower strict_unroll b) + | Source.LetPair (x, y, a, b) -> + Target.Let + ("tmp", lower strict_unroll a, + Target.Let (x, Target.Proj (0, Target.Var "tmp"), + Target.Let (y, Target.Proj (1, Target.Var "tmp"), lower strict_unroll b))) + | Source.Case (s, (x, l), (y, r)) -> Target.Case (lower strict_unroll s, (x, lower strict_unroll l), (y, lower strict_unroll r)) + +let synthesize_worker_wrapper ty body = + match ty with + | 
(** [compile flags source_type source_program] runs the enabled passes in the
    fixed order specialise, inline, representation lowering, and returns every
    intermediate program together with the proof obligations the passes
    recorded, in the order they were recorded.

    A disabled pass hands its input through unchanged. With [repr_lower]
    disabled the inlined program is still lowered, but never with the strict
    roll/unroll strategy and never through the worker/wrapper synthesis. *)
let compile flags source_type source_program =
  let ledger = ref [] in
  let record ~kind pass subject detail =
    ledger := { pass; kind; subject; detail } :: !ledger
  in
  let specialised =
    if not flags.specialise then source_program
    else begin
      record ~kind:Preserve_relation "specialise" "type application cloning"
        "specialisation preserves source typing but requires instantiation closedness";
      specialise_term ledger source_program
    end
  in
  let inlined =
    if not flags.inline then specialised
    else begin
      if flags.unsafe_repr_eq then
        record ~kind:Exposed_representation "inline" "parametric body exposure"
          "inlining may replace an abstract constant relation with eq_int at specialised int instances";
      inline_term ledger flags.unsafe_repr_eq specialised
    end
  in
  let target_program =
    if not flags.repr_lower then lower false inlined
    else begin
      record ~kind:Worker_wrapper_proof "repr_lower" "worker-wrapper"
        "wrapper must refine the boxed to unboxed relation on arguments and results";
      if flags.unsafe_strict_unroll then
        record ~kind:Strictness_risk "repr_lower" "roll/unroll"
          "strict unroll lowering may force recursive payloads earlier than the source";
      match source_type with
      | TArrow (TPair (TInt, TBool), TPair (TInt, TBool)) ->
        synthesize_worker_wrapper source_type inlined
      | _ -> lower flags.unsafe_strict_unroll inlined
    end
  in
  {
    source_type;
    source_program;
    specialised;
    inlined;
    target_program;
    (* Entries were consed on, so reverse to restore recording order. *)
    obligations = List.rev !ledger;
  }
ya) -> + source_value_relation rel ta xa ya + | _, TSum (_, tb), Source.VInr (_, xa), Source.VInr (_, ya) -> + source_value_relation rel tb xa ya + | Baseline, TForall _, Source.VTyLam _, Source.VTyLam _ -> true + | Baseline, TArrow _, Source.VLam _, Source.VLam _ -> true + | Baseline, TMu _, Source.VRoll _, Source.VRoll _ -> true + | Ground_only, _, _, _ -> false + | Boxed_unboxed, TArrow _, Source.VLam _, Source.VLam _ -> true + | Boxed_unboxed, TMu _, Source.VRoll _, Source.VRoll _ -> true + | _ -> false + +let rec cross_value_relation rel ty sv tv = + match rel, ty, sv, tv with + | _, TInt, Source.VInt x, Target.VInt y -> x = y + | _, TBool, Source.VBool x, Target.VBool y -> Bool.equal x y + | Boxed_unboxed, TInt, Source.VInt x, Target.VBox (_, Target.VInt y) -> x = y + | Boxed_unboxed, TBool, Source.VBool x, Target.VBox (_, Target.VBool y) -> Bool.equal x y + | _, TPair (ta, tb), Source.VPair (a1, a2), Target.VTuple [b1; b2] -> + cross_value_relation rel ta a1 b1 && cross_value_relation rel tb a2 b2 + | Boxed_unboxed, TPair (ta, tb), Source.VPair (a1, a2), Target.VBox (_, Target.VTuple [b1; b2]) -> + cross_value_relation rel ta a1 b1 && cross_value_relation rel tb a2 b2 + | _, TSum (ta, _), Source.VInl (_, a), Target.VInl (_, _, b) -> cross_value_relation rel ta a b + | _, TSum (_, tb), Source.VInr (_, a), Target.VInr (_, _, b) -> cross_value_relation rel tb a b + | Baseline, TArrow _, Source.VLam _, Target.VLam _ -> true + | Boxed_unboxed, TArrow _, Source.VLam _, Target.VLam _ -> true + | Baseline, TForall _, Source.VTyLam _, _ -> true + | Baseline, TMu _, Source.VRoll _, Target.VRoll _ -> true + | Boxed_unboxed, TMu _, Source.VRoll _, Target.VBox (_, Target.VRoll _) -> true + | _ -> false + +let compare_programs ?(fuel = 256) rel ty src tgt = + let source_outcome = Source.observe ~fuel src in + let target_outcome = Target.observe ~fuel tgt in + let verdict = + match source_outcome, target_outcome with + | Source.Value sv, Target.Value tv -> + if 
cross_value_relation rel ty sv tv then Related + else Unrelated "value relation failed" + | Source.Diverged _, Target.Diverged _ -> Related + | Source.Diverged _, Target.Value _ -> Unrelated "target terminated where source diverged" + | Source.Value _, Target.Diverged _ -> Unrelated "target diverged where source terminated" + | Source.Stuck msg, _ -> Unrelated ("source stuck: " ^ msg) + | _, Target.Stuck msg -> Unrelated ("target stuck: " ^ msg) + in + { source_outcome; target_outcome; verdict } + diff --git a/src/relation.mli b/src/relation.mli new file mode 100644 index 0000000..e249d18 --- /dev/null +++ b/src/relation.mli @@ -0,0 +1,21 @@ +open Types + +type relation = + | Baseline + | Boxed_unboxed + | Ground_only + +type verdict = + | Related + | Unrelated of string + +type comparison = { + source_outcome : Source.outcome; + target_outcome : Target.outcome; + verdict : verdict; +} + +val source_value_relation : relation -> typ -> Source.value -> Source.value -> bool +val cross_value_relation : relation -> typ -> Source.value -> Target.value -> bool +val compare_programs : ?fuel:int -> relation -> typ -> Source.term -> Target.term -> comparison + diff --git a/src/reporting.ml b/src/reporting.ml new file mode 100644 index 0000000..f0d0f49 --- /dev/null +++ b/src/reporting.ml @@ -0,0 +1,81 @@ +let obligation_kind_to_string = function + | Pipeline.Preserve_relation -> "preserve_relation" + | Pipeline.Exposed_representation -> "exposed_representation" + | Pipeline.Strictness_risk -> "strictness_risk" + | Pipeline.Worker_wrapper_proof -> "worker_wrapper_proof" + +let string_of_source_outcome = function + | Source.Value v -> "value " ^ Source.string_of_value v + | Source.Stuck msg -> "stuck " ^ msg + | Source.Diverged n -> "diverged after " ^ string_of_int n ^ " steps" + +let string_of_target_outcome = function + | Target.Value v -> "value " ^ Target.string_of_value v + | Target.Stuck msg -> "stuck " ^ msg + | Target.Diverged n -> "diverged after " ^ string_of_int n 
^ " steps" + +let string_of_relation = function + | Relation.Baseline -> "baseline" + | Relation.Boxed_unboxed -> "boxed_unboxed" + | Relation.Ground_only -> "ground_only" + +let verdict_to_string = function + | Relation.Related -> "related" + | Relation.Unrelated msg -> "unrelated: " ^ msg + +let take_steps limit steps = + let rec go acc n rest = + match rest with + | [] -> (List.rev acc, 0) + | _ when n = 0 -> (List.rev acc, List.length rest) + | x :: xs -> go (x :: acc) (n - 1) xs + in + go [] limit steps + +let emit_obligations (obligations : Pipeline.obligation list) = + if obligations = [] then "- none" + else + obligations + |> List.map (fun o -> + "- [" ^ obligation_kind_to_string o.Pipeline.kind ^ "] " ^ o.Pipeline.pass ^ + " / " ^ o.Pipeline.subject ^ ": " ^ o.Pipeline.detail) + |> String.concat "\n" + +let emit_case_header (case : Corpus.case) (comparison : Relation.comparison) (compiled : Pipeline.compiled) = + String.concat "\n" + [ + "## " ^ case.Corpus.name; + ""; + "| field | value |"; + "| --- | --- |"; + "| summary | " ^ case.Corpus.summary ^ " |"; + "| claim | " ^ case.Corpus.claim ^ " |"; + "| relation | `" ^ string_of_relation case.Corpus.relation ^ "` |"; + "| source type | `" ^ Types.string_of_typ case.Corpus.ty ^ "` |"; + "| verdict | `" ^ verdict_to_string comparison.Relation.verdict ^ "` |"; + "| obligations | " ^ string_of_int (List.length compiled.Pipeline.obligations) ^ " |"; + ""; + ] + +let emit_steps pp steps = + steps + |> List.mapi (fun i t -> string_of_int i ^ ": " ^ pp t) + |> String.concat "\n" + +let emit_counterexample title (src_trace : Source.trace) (tgt_trace : Target.trace) (obligations : Pipeline.obligation list) = + let _ = title in + let src_steps, src_hidden = take_steps 32 src_trace.Source.steps in + let tgt_steps, tgt_hidden = take_steps 48 tgt_trace.Target.steps in + let src_suffix = + if src_hidden = 0 then "" + else "\n... 
" ^ string_of_int src_hidden ^ " more source steps omitted" + in + let tgt_suffix = + if tgt_hidden = 0 then "" + else "\n... " ^ string_of_int tgt_hidden ^ " more target steps omitted" + in + "source trace:\n" ^ emit_steps Source.string_of_term src_steps ^ + src_suffix ^ "\n\nsource outcome: " ^ string_of_source_outcome src_trace.Source.outcome ^ + "\n\ntarget trace:\n" ^ emit_steps Target.string_of_term tgt_steps ^ + tgt_suffix ^ "\n\ntarget outcome: " ^ string_of_target_outcome tgt_trace.Target.outcome ^ + "\n\nobligations:\n" ^ emit_obligations obligations ^ "\n" diff --git a/src/reporting.mli b/src/reporting.mli new file mode 100644 index 0000000..7c5971e --- /dev/null +++ b/src/reporting.mli @@ -0,0 +1,9 @@ +val obligation_kind_to_string : Pipeline.obligation_kind -> string +val string_of_source_outcome : Source.outcome -> string +val string_of_target_outcome : Target.outcome -> string +val string_of_relation : Relation.relation -> string +val verdict_to_string : Relation.verdict -> string +val take_steps : int -> 'a list -> 'a list * int +val emit_obligations : Pipeline.obligation list -> string +val emit_case_header : Corpus.case -> Relation.comparison -> Pipeline.compiled -> string +val emit_counterexample : string -> Source.trace -> Target.trace -> Pipeline.obligation list -> string diff --git a/src/source.ml b/src/source.ml new file mode 100644 index 0000000..4132a66 --- /dev/null +++ b/src/source.ml @@ -0,0 +1,257 @@ +open Types + +type var = string + +type term = + | Var of var + | Int of int + | Bool of bool + | Pair of term * term + | Inl of typ * term + | Inr of typ * term + | Lam of var * typ * typ * term + | App of term * term + | TyLam of string * term + | TyApp of term * typ + | Roll of typ * term + | Unroll of term + | Fix of var * typ * term + | Eq of ground * term * term + | If of term * term * term + | Let of var * term * term + | LetPair of var * var * term * term + | Case of term * (var * term) * (var * term) + +type 'a expr = Expr : typ * 
term -> 'a expr +type packed_expr = Pack_expr : 'a expr -> packed_expr + +type value = + | VInt of int + | VBool of bool + | VPair of value * value + | VInl of typ * value + | VInr of typ * value + | VLam of var * typ * typ * term + | VTyLam of string * term + | VRoll of typ * term + +type frame = + | FAppL of term + | FAppR of value + | FPairL of term + | FPairR of value + | FInl of typ + | FInr of typ + | FIf of term * term + | FEqL of ground * term + | FEqR of ground * value + | FLet of var * term + | FLetPair of var * var * term + | FCase of (var * term) * (var * term) + | FTyApp of typ + | FUnroll + +type outcome = + | Value of value + | Stuck of string + | Diverged of int + +type trace = { + steps : term list; + outcome : outcome; +} + +let typ_of (Expr (ty, _)) = ty +let pack ty term = Pack_expr (Expr (ty, term)) + +let rec string_of_term = function + | Var x -> x + | Int n -> string_of_int n + | Bool b -> string_of_bool b + | Pair (a, b) -> "(" ^ string_of_term a ^ ", " ^ string_of_term b ^ ")" + | Inl (_, t) -> "(inl " ^ string_of_term t ^ ")" + | Inr (_, t) -> "(inr " ^ string_of_term t ^ ")" + | Lam (x, ty, _, body) -> + "(fun (" ^ x ^ " : " ^ string_of_typ ty ^ ") -> " ^ string_of_term body ^ ")" + | App (f, x) -> "(" ^ string_of_term f ^ " " ^ string_of_term x ^ ")" + | TyLam (a, body) -> "(/\\" ^ a ^ ". " ^ string_of_term body ^ ")" + | TyApp (t, ty) -> "(" ^ string_of_term t ^ " [" ^ string_of_typ ty ^ "])" + | Roll (ty, body) -> "(roll[" ^ string_of_typ ty ^ "] " ^ string_of_term body ^ ")" + | Unroll t -> "(unroll " ^ string_of_term t ^ ")" + | Fix (f, _, body) -> "(fix " ^ f ^ ". 
" ^ string_of_term body ^ ")" + | Eq (g, a, b) -> + "(eq_" ^ string_of_ground g ^ " " ^ string_of_term a ^ " " ^ string_of_term b ^ ")" + | If (c, t, e) -> + "(if " ^ string_of_term c ^ " then " ^ string_of_term t ^ " else " ^ string_of_term e ^ ")" + | Let (x, a, b) -> + "(let " ^ x ^ " = " ^ string_of_term a ^ " in " ^ string_of_term b ^ ")" + | LetPair (x, y, a, b) -> + "(let (" ^ x ^ ", " ^ y ^ ") = " ^ string_of_term a ^ " in " ^ string_of_term b ^ ")" + | Case (scrut, (x, l), (y, r)) -> + "(case " ^ string_of_term scrut ^ " of inl " ^ x ^ " -> " ^ string_of_term l ^ + " | inr " ^ y ^ " -> " ^ string_of_term r ^ ")" + +let rec string_of_value = function + | VInt n -> string_of_int n + | VBool b -> string_of_bool b + | VPair (a, b) -> "(" ^ string_of_value a ^ ", " ^ string_of_value b ^ ")" + | VInl (_, v) -> "(inl " ^ string_of_value v ^ ")" + | VInr (_, v) -> "(inr " ^ string_of_value v ^ ")" + | VLam _ -> "" + | VTyLam _ -> "" + | VRoll (_, _) -> "" + +let rec is_value = function + | Int _ | Bool _ | Lam _ | TyLam _ -> true + | Pair (a, b) -> is_value a && is_value b + | Inl (_, t) | Inr (_, t) -> is_value t + | Roll (_, _) -> true + | _ -> false + +let rec substitute x repl term = + let go = substitute x repl in + match term with + | Var y -> if String.equal x y then repl else term + | Int _ | Bool _ -> term + | Pair (a, b) -> Pair (go a, go b) + | Inl (ty, t) -> Inl (ty, go t) + | Inr (ty, t) -> Inr (ty, go t) + | Lam (y, a, b, body) -> + if String.equal x y then term else Lam (y, a, b, go body) + | App (f, a) -> App (go f, go a) + | TyLam _ -> term + | TyApp (t, ty) -> TyApp (go t, ty) + | Roll (ty, body) -> Roll (ty, go body) + | Unroll t -> Unroll (go t) + | Fix (f, ty, body) -> + if String.equal x f then term else Fix (f, ty, go body) + | Eq (g, a, b) -> Eq (g, go a, go b) + | If (c, t, e) -> If (go c, go t, go e) + | Let (y, a, b) -> + Let (y, go a, if String.equal x y then b else go b) + | LetPair (y, z, a, b) -> + LetPair (y, z, go a, if 
String.equal x y || String.equal x z then b else go b) + | Case (scrut, (y, l), (z, r)) -> + Case (go scrut, (y, if String.equal x y then l else go l), (z, if String.equal x z then r else go r)) + +let rec substitute_type x repl term = + let sub_ty = substitute_typ x repl in + let go = substitute_type x repl in + match term with + | Var _ | Int _ | Bool _ -> term + | Pair (a, b) -> Pair (go a, go b) + | Inl (ty, t) -> Inl (sub_ty ty, go t) + | Inr (ty, t) -> Inr (sub_ty ty, go t) + | Lam (y, a, b, body) -> Lam (y, sub_ty a, sub_ty b, go body) + | App (f, a) -> App (go f, go a) + | TyLam (y, body) -> if String.equal x y then term else TyLam (y, go body) + | TyApp (t, ty) -> TyApp (go t, sub_ty ty) + | Roll (ty, body) -> Roll (sub_ty ty, go body) + | Unroll t -> Unroll (go t) + | Fix (f, ty, body) -> Fix (f, sub_ty ty, go body) + | Eq (g, a, b) -> Eq (g, go a, go b) + | If (c, t, e) -> If (go c, go t, go e) + | Let (y, a, b) -> Let (y, go a, go b) + | LetPair (y, z, a, b) -> LetPair (y, z, go a, go b) + | Case (scrut, (y, l), (z, r)) -> Case (go scrut, (y, go l), (z, go r)) + +let rec value_to_term = function + | VInt n -> Int n + | VBool b -> Bool b + | VPair (a, b) -> Pair (value_to_term a, value_to_term b) + | VInl (ty, v) -> Inl (ty, value_to_term v) + | VInr (ty, v) -> Inr (ty, value_to_term v) + | VLam (x, a, b, body) -> Lam (x, a, b, body) + | VTyLam (x, body) -> TyLam (x, body) + | VRoll (ty, body) -> Roll (ty, body) + +let rec value_of_term = function + | Int n -> Some (VInt n) + | Bool b -> Some (VBool b) + | Pair (a, b) -> begin + match value_of_term a, value_of_term b with + | Some va, Some vb -> Some (VPair (va, vb)) + | _ -> None + end + | Inl (ty, t) -> Option.map (fun v -> VInl (ty, v)) (value_of_term t) + | Inr (ty, t) -> Option.map (fun v -> VInr (ty, v)) (value_of_term t) + | Lam (x, a, b, body) -> Some (VLam (x, a, b, body)) + | TyLam (x, body) -> Some (VTyLam (x, body)) + | Roll (ty, body) -> Some (VRoll (ty, body)) + | _ -> None + +let eq_ground 
g va vb = + match g, va, vb with + | GInt, VInt a, VInt b -> Ok (Bool (a = b)) + | GBool, VBool a, VBool b -> Ok (Bool (Bool.equal a b)) + | _ -> Error "ground equality applied to non-ground values" + +let rec step term = + let stuck msg = Error msg in + let stepv t = step t in + match term with + | App (Lam (x, _, _, body), arg) when is_value arg -> Ok (substitute x arg body) + | App (TyLam _, _) -> stuck "type abstraction used as term function" + | App (f, a) when not (is_value f) -> Result.map (fun f' -> App (f', a)) (stepv f) + | App (f, a) when not (is_value a) -> Result.map (fun a' -> App (f, a')) (stepv a) + | App _ -> stuck "application stuck" + | Pair (a, b) when not (is_value a) -> Result.map (fun a' -> Pair (a', b)) (stepv a) + | Pair (a, b) when not (is_value b) -> Result.map (fun b' -> Pair (a, b')) (stepv b) + | Pair _ -> stuck "value" + | Inl (ty, t) when not (is_value t) -> Result.map (fun t' -> Inl (ty, t')) (stepv t) + | Inr (ty, t) when not (is_value t) -> Result.map (fun t' -> Inr (ty, t')) (stepv t) + | Eq (g, a, b) when not (is_value a) -> Result.map (fun a' -> Eq (g, a', b)) (stepv a) + | Eq (g, a, b) when not (is_value b) -> Result.map (fun b' -> Eq (g, a, b')) (stepv b) + | Eq (g, a, b) -> begin + match value_of_term a, value_of_term b with + | Some va, Some vb -> eq_ground g va vb + | _ -> stuck "equality stuck on non-values" + end + | If (Bool true, t, _) -> Ok t + | If (Bool false, _, e) -> Ok e + | If (c, t, e) when not (is_value c) -> Result.map (fun c' -> If (c', t, e)) (stepv c) + | If _ -> stuck "if scrutinee is not a bool" + | Let (x, a, b) when is_value a -> Ok (substitute x a b) + | Let (x, a, b) -> Result.map (fun a' -> Let (x, a', b)) (stepv a) + | LetPair (x, y, Pair (a, b), body) when is_value a && is_value b -> + Ok (substitute y b (substitute x a body)) + | LetPair (x, y, scrut, body) when not (is_value scrut) -> + Result.map (fun s' -> LetPair (x, y, s', body)) (stepv scrut) + | LetPair _ -> stuck "let-pair scrutinee is 
not a pair value" + | Case (Inl (_, v), (x, l), _) when is_value v -> Ok (substitute x v l) + | Case (Inr (_, v), _, (y, r)) when is_value v -> Ok (substitute y v r) + | Case (scrut, l, r) when not (is_value scrut) -> + Result.map (fun s' -> Case (s', l, r)) (stepv scrut) + | Case _ -> stuck "case scrutinee is not a sum value" + | TyApp (TyLam (x, body), ty) -> Ok (substitute_type x ty body) + | TyApp (t, ty) when not (is_value t) -> Result.map (fun t' -> TyApp (t', ty)) (stepv t) + | TyApp _ -> stuck "type application stuck" + | Unroll (Roll (_, body)) -> Ok body + | Unroll t when not (is_value t) -> Result.map (fun t' -> Unroll t') (stepv t) + | Unroll _ -> stuck "unroll on non-roll value" + | Fix (f, _, body) -> Ok (substitute f term body) + | Var x -> stuck ("free variable: " ^ x) + | Int _ | Bool _ | Lam _ | TyLam _ | Roll _ -> stuck "value" + | Inl _ | Inr _ -> stuck "sum injection expects value subterm" + +let evaluate ?(fuel = 256) term = + let rec go n acc t = + if n = 0 then { steps = List.rev (t :: acc); outcome = Diverged fuel } + else if is_value t then + match value_of_term t with + | Some v -> { steps = List.rev (t :: acc); outcome = Value v } + | None -> { steps = List.rev (t :: acc); outcome = Stuck "internal value decoding failure" } + else + match step t with + | Ok t' -> go (n - 1) (t :: acc) t' + | Error "value" -> + begin + match value_of_term t with + | Some v -> { steps = List.rev (t :: acc); outcome = Value v } + | None -> { steps = List.rev (t :: acc); outcome = Stuck "internal value decoding failure" } + end + | Error msg -> { steps = List.rev (t :: acc); outcome = Stuck msg } + in + go fuel [] term + +let observe ?fuel term = (evaluate ?fuel term).outcome diff --git a/src/source.mli b/src/source.mli new file mode 100644 index 0000000..c01012c --- /dev/null +++ b/src/source.mli @@ -0,0 +1,74 @@ +open Types + +type var = string + +type term = + | Var of var + | Int of int + | Bool of bool + | Pair of term * term + | Inl of typ * term + | 
Inr of typ * term + | Lam of var * typ * typ * term + | App of term * term + | TyLam of string * term + | TyApp of term * typ + | Roll of typ * term + | Unroll of term + | Fix of var * typ * term + | Eq of ground * term * term + | If of term * term * term + | Let of var * term * term + | LetPair of var * var * term * term + | Case of term * (var * term) * (var * term) + +type 'a expr = Expr : typ * term -> 'a expr +type packed_expr = Pack_expr : 'a expr -> packed_expr + +type value = + | VInt of int + | VBool of bool + | VPair of value * value + | VInl of typ * value + | VInr of typ * value + | VLam of var * typ * typ * term + | VTyLam of string * term + | VRoll of typ * term + +type frame = + | FAppL of term + | FAppR of value + | FPairL of term + | FPairR of value + | FInl of typ + | FInr of typ + | FIf of term * term + | FEqL of ground * term + | FEqR of ground * value + | FLet of var * term + | FLetPair of var * var * term + | FCase of (var * term) * (var * term) + | FTyApp of typ + | FUnroll + +type outcome = + | Value of value + | Stuck of string + | Diverged of int + +type trace = { + steps : term list; + outcome : outcome; +} + +val typ_of : 'a expr -> typ +val pack : typ -> term -> packed_expr +val string_of_term : term -> string +val string_of_value : value -> string +val is_value : term -> bool +val substitute : var -> term -> term -> term +val substitute_type : string -> typ -> term -> term +val step : term -> (term, string) result +val evaluate : ?fuel:int -> term -> trace +val observe : ?fuel:int -> term -> outcome +val value_to_term : value -> term diff --git a/src/target.ml b/src/target.ml new file mode 100644 index 0000000..0b25726 --- /dev/null +++ b/src/target.ml @@ -0,0 +1,274 @@ +open Types + +type calling_conv = + | Boxed + | Unboxed + +type repr = + | RInt + | RBool + | RBox of typ + | RTuple of repr list + | RSum of repr * repr + | RFun of calling_conv * repr list * repr + +type var = string + +type term = + | Var of var + | Int of int + | 
Bool of bool + | Tuple of term list + | Proj of int * term + | Inl of repr * repr * term + | Inr of repr * repr * term + | Case of term * (var * term) * (var * term) + | Lam of calling_conv * (var * repr) list * repr * term + | App of term * term list + | Let of var * term * term + | LetRec of var * repr * term * term + | EqInt of term * term + | EqBool of term * term + | If of term * term * term + | Box of typ * term + | Unbox of term + | Roll of typ * term + | Unroll of term + | WorkerWrapper of worker_wrapper + | Halt of term + +and worker_wrapper = { + wrapper : var; + worker : var; + boxed_arg : typ; + unboxed_args : repr list; + result_repr : repr; + wrap_body : term; + worker_body : term; + in_term : term; +} + +type value = + | VInt of int + | VBool of bool + | VTuple of value list + | VInl of repr * repr * value + | VInr of repr * repr * value + | VLam of calling_conv * (var * repr) list * repr * term + | VBox of typ * value + | VRoll of typ * term + +type outcome = + | Value of value + | Stuck of string + | Diverged of int + +type trace = { + steps : term list; + outcome : outcome; +} + +let rec string_of_repr = function + | RInt -> "i#" + | RBool -> "b#" + | RBox ty -> "box[" ^ Types.string_of_typ ty ^ "]" + | RTuple rs -> "(" ^ String.concat " *# " (List.map string_of_repr rs) ^ ")" + | RSum (a, b) -> "(" ^ string_of_repr a ^ " +# " ^ string_of_repr b ^ ")" + | RFun (cc, args, ret) -> + let cc_s = match cc with Boxed -> "boxed" | Unboxed -> "unboxed" in + cc_s ^ "(" ^ String.concat ", " (List.map string_of_repr args) ^ " -> " ^ string_of_repr ret ^ ")" + +let rec string_of_term = function + | Var x -> x + | Int n -> string_of_int n ^ "#" + | Bool b -> string_of_bool b ^ "#" + | Tuple xs -> "(#" ^ String.concat ", " (List.map string_of_term xs) ^ "#)" + | Proj (i, t) -> "(proj" ^ string_of_int i ^ " " ^ string_of_term t ^ ")" + | Inl (_, _, t) -> "(inl# " ^ string_of_term t ^ ")" + | Inr (_, _, t) -> "(inr# " ^ string_of_term t ^ ")" + | Case (s, (x, l), 
(y, r)) -> + "(case# " ^ string_of_term s ^ " of inl# " ^ x ^ " -> " ^ string_of_term l ^ + " | inr# " ^ y ^ " -> " ^ string_of_term r ^ ")" + | Lam (_, xs, _, body) -> + let args = xs |> List.map fst |> String.concat " " in + "(fun# " ^ args ^ " -> " ^ string_of_term body ^ ")" + | App (f, xs) -> "(" ^ string_of_term f ^ " " ^ String.concat " " (List.map string_of_term xs) ^ ")" + | Let (x, a, b) -> "(let# " ^ x ^ " = " ^ string_of_term a ^ " in " ^ string_of_term b ^ ")" + | LetRec (x, _, a, b) -> "(letrec# " ^ x ^ " = " ^ string_of_term a ^ " in " ^ string_of_term b ^ ")" + | EqInt (a, b) -> "(eq_int# " ^ string_of_term a ^ " " ^ string_of_term b ^ ")" + | EqBool (a, b) -> "(eq_bool# " ^ string_of_term a ^ " " ^ string_of_term b ^ ")" + | If (c, t, e) -> + "(if# " ^ string_of_term c ^ " then " ^ string_of_term t ^ " else " ^ string_of_term e ^ ")" + | Box (_, t) -> "(box " ^ string_of_term t ^ ")" + | Unbox t -> "(unbox " ^ string_of_term t ^ ")" + | Roll (_, t) -> "(roll# " ^ string_of_term t ^ ")" + | Unroll t -> "(unroll# " ^ string_of_term t ^ ")" + | WorkerWrapper ww -> "(worker-wrapper " ^ ww.wrapper ^ "/" ^ ww.worker ^ " in " ^ string_of_term ww.in_term ^ ")" + | Halt t -> "(halt " ^ string_of_term t ^ ")" + +let rec is_value = function + | Int _ | Bool _ | Lam _ -> true + | Tuple xs -> List.for_all is_value xs + | Inl (_, _, t) | Inr (_, _, t) -> is_value t + | Box (_, t) -> is_value t + | Roll _ -> true + | _ -> false + +let rec value_of_term = function + | Int n -> Some (VInt n) + | Bool b -> Some (VBool b) + | Tuple xs -> + let rec collect acc = function + | [] -> Some (VTuple (List.rev acc)) + | x :: rest -> + begin match value_of_term x with + | Some v -> collect (v :: acc) rest + | None -> None + end + in + collect [] xs + | Inl (a, b, t) -> Option.map (fun v -> VInl (a, b, v)) (value_of_term t) + | Inr (a, b, t) -> Option.map (fun v -> VInr (a, b, v)) (value_of_term t) + | Lam (cc, xs, r, body) -> Some (VLam (cc, xs, r, body)) + | Box (ty, t) -> 
Option.map (fun v -> VBox (ty, v)) (value_of_term t) + | Roll (ty, body) -> Some (VRoll (ty, body)) + | _ -> None + +let rec substitute x repl term = + let go = substitute x repl in + match term with + | Var y -> if String.equal x y then repl else term + | Int _ | Bool _ -> term + | Tuple xs -> Tuple (List.map go xs) + | Proj (i, t) -> Proj (i, go t) + | Inl (a, b, t) -> Inl (a, b, go t) + | Inr (a, b, t) -> Inr (a, b, go t) + | Case (s, (y, l), (z, r)) -> + Case (go s, (y, if String.equal x y then l else go l), (z, if String.equal x z then r else go r)) + | Lam (cc, xs, r, body) -> + if List.exists (fun (y, _) -> String.equal x y) xs then term + else Lam (cc, xs, r, go body) + | App (f, xs) -> App (go f, List.map go xs) + | Let (y, a, b) -> Let (y, go a, if String.equal x y then b else go b) + | LetRec (y, r, a, b) -> + if String.equal x y then term else LetRec (y, r, go a, go b) + | EqInt (a, b) -> EqInt (go a, go b) + | EqBool (a, b) -> EqBool (go a, go b) + | If (c, t, e) -> If (go c, go t, go e) + | Box (ty, t) -> Box (ty, go t) + | Unbox t -> Unbox (go t) + | Roll (ty, t) -> Roll (ty, go t) + | Unroll t -> Unroll (go t) + | WorkerWrapper ww -> + WorkerWrapper { ww with wrap_body = go ww.wrap_body; worker_body = go ww.worker_body; in_term = go ww.in_term } + | Halt t -> Halt (go t) + +let rec apply_many body bindings = + match bindings with + | [] -> body + | (x, v) :: rest -> apply_many (substitute x v body) rest + +let rec step term = + let rec eval_list prefix suffix k = + match suffix with + | [] -> k (List.rev prefix) + | x :: xs when is_value x -> eval_list (x :: prefix) xs k + | x :: xs -> + Result.map + (fun x' -> Tuple (List.rev_append prefix (x' :: xs))) + (step x) + in + match term with + | App (Lam (_, params, _, body), args) when List.length params = List.length args && List.for_all is_value args -> + let binds = List.map2 (fun (x, _) arg -> (x, arg)) params args in + Ok (apply_many body binds) + | App (f, args) when not (is_value f) -> Result.map 
(fun f' -> App (f', args)) (step f) + | App (f, args) -> + let rec go acc rest = + match rest with + | [] -> Error "target application stuck" + | a :: xs when is_value a -> go (a :: acc) xs + | a :: xs -> Result.map (fun a' -> App (f, List.rev_append acc (a' :: xs))) (step a) + in + go [] args + | Tuple xs -> eval_list [] xs (fun _ -> Error "value") + | Proj (i, Tuple xs) when List.for_all is_value xs -> + begin match List.nth_opt xs i with + | Some v -> Ok v + | None -> Error "tuple projection out of bounds" + end + | Proj (i, t) when not (is_value t) -> Result.map (fun t' -> Proj (i, t')) (step t) + | Proj _ -> Error "projection on non-tuple" + | Inl (a, b, t) when not (is_value t) -> Result.map (fun t' -> Inl (a, b, t')) (step t) + | Inr (a, b, t) when not (is_value t) -> Result.map (fun t' -> Inr (a, b, t')) (step t) + | Case (Inl (_, _, v), (x, l), _) when is_value v -> Ok (substitute x v l) + | Case (Inr (_, _, v), _, (y, r)) when is_value v -> Ok (substitute y v r) + | Case (s, l, r) when not (is_value s) -> Result.map (fun s' -> Case (s', l, r)) (step s) + | Case _ -> Error "case on non-sum value" + | Let (x, a, b) when is_value a -> Ok (substitute x a b) + | Let (x, a, b) -> Result.map (fun a' -> Let (x, a', b)) (step a) + | LetRec (x, _, defn, body) -> Ok (substitute x (LetRec (x, RBox TInt, defn, defn)) body) + | EqInt (Int a, Int b) -> Ok (Bool (a = b)) + | EqInt (a, b) when not (is_value a) -> Result.map (fun a' -> EqInt (a', b)) (step a) + | EqInt (a, b) when not (is_value b) -> Result.map (fun b' -> EqInt (a, b')) (step b) + | EqInt _ -> Error "eq_int on non-int values" + | EqBool (Bool a, Bool b) -> Ok (Bool (Bool.equal a b)) + | EqBool (a, b) when not (is_value a) -> Result.map (fun a' -> EqBool (a', b)) (step a) + | EqBool (a, b) when not (is_value b) -> Result.map (fun b' -> EqBool (a, b')) (step b) + | EqBool _ -> Error "eq_bool on non-bool values" + | If (Bool true, t, _) -> Ok t + | If (Bool false, _, e) -> Ok e + | If (c, t, e) when not 
(is_value c) -> Result.map (fun c' -> If (c', t, e)) (step c) + | If _ -> Error "if# on non-bool value" + | Box (ty, t) when not (is_value t) -> Result.map (fun t' -> Box (ty, t')) (step t) + | Unbox (Box (_, v)) when is_value v -> Ok v + | Unbox t when not (is_value t) -> Result.map (fun t' -> Unbox t') (step t) + | Unbox _ -> Error "unbox on non-box" + | Unroll (Roll (_, t)) -> Ok t + | Unroll t when not (is_value t) -> Result.map (fun t' -> Unroll t') (step t) + | Unroll _ -> Error "unroll on non-roll value" + | WorkerWrapper ww -> + Ok + (Let + ( ww.worker, + Lam (Unboxed, List.mapi (fun i r -> ("u" ^ string_of_int i, r)) ww.unboxed_args, ww.result_repr, ww.worker_body), + Let + ( ww.wrapper, + Lam (Boxed, [("boxed", RBox ww.boxed_arg)], ww.result_repr, ww.wrap_body), + ww.in_term ) )) + | Halt t when not (is_value t) -> Result.map (fun t' -> Halt t') (step t) + | Halt _ -> Error "value" + | Var x -> Error ("free variable: " ^ x) + | Int _ | Bool _ | Lam _ | Box _ | Roll _ -> Error "value" + | Inl _ | Inr _ -> Error "sum injection expects value subterm" + +let evaluate ?(fuel = 256) term = + let rec go n acc t = + if n = 0 then { steps = List.rev (t :: acc); outcome = Diverged fuel } + else if is_value t then + match value_of_term t with + | Some v -> { steps = List.rev (t :: acc); outcome = Value v } + | None -> { steps = List.rev (t :: acc); outcome = Stuck "internal target value decoding failure" } + else + match step t with + | Ok t' -> go (n - 1) (t :: acc) t' + | Error "value" -> + begin match value_of_term t with + | Some v -> { steps = List.rev (t :: acc); outcome = Value v } + | None -> { steps = List.rev (t :: acc); outcome = Stuck "internal target value decoding failure" } + end + | Error msg -> { steps = List.rev (t :: acc); outcome = Stuck msg } + in + go fuel [] term + +let observe ?fuel term = (evaluate ?fuel term).outcome + +let string_of_value = function + | VInt n -> string_of_int n ^ "#" + | VBool b -> string_of_bool b ^ "#" + | VTuple _ -> 
"" + | VInl _ -> "" + | VInr _ -> "" + | VLam _ -> "" + | VBox (_, _) -> "" + | VRoll _ -> "" diff --git a/src/target.mli b/src/target.mli new file mode 100644 index 0000000..c9d48e1 --- /dev/null +++ b/src/target.mli @@ -0,0 +1,77 @@ +open Types + +type calling_conv = + | Boxed + | Unboxed + +type repr = + | RInt + | RBool + | RBox of typ + | RTuple of repr list + | RSum of repr * repr + | RFun of calling_conv * repr list * repr + +type var = string + +type term = + | Var of var + | Int of int + | Bool of bool + | Tuple of term list + | Proj of int * term + | Inl of repr * repr * term + | Inr of repr * repr * term + | Case of term * (var * term) * (var * term) + | Lam of calling_conv * (var * repr) list * repr * term + | App of term * term list + | Let of var * term * term + | LetRec of var * repr * term * term + | EqInt of term * term + | EqBool of term * term + | If of term * term * term + | Box of typ * term + | Unbox of term + | Roll of typ * term + | Unroll of term + | WorkerWrapper of worker_wrapper + | Halt of term + +and worker_wrapper = { + wrapper : var; + worker : var; + boxed_arg : typ; + unboxed_args : repr list; + result_repr : repr; + wrap_body : term; + worker_body : term; + in_term : term; +} + +type value = + | VInt of int + | VBool of bool + | VTuple of value list + | VInl of repr * repr * value + | VInr of repr * repr * value + | VLam of calling_conv * (var * repr) list * repr * term + | VBox of typ * value + | VRoll of typ * term + +type outcome = + | Value of value + | Stuck of string + | Diverged of int + +type trace = { + steps : term list; + outcome : outcome; +} + +val string_of_repr : repr -> string +val string_of_term : term -> string +val string_of_value : value -> string +val is_value : term -> bool +val step : term -> (term, string) result +val evaluate : ?fuel:int -> term -> trace +val observe : ?fuel:int -> term -> outcome diff --git a/src/typecheck.ml b/src/typecheck.ml new file mode 100644 index 0000000..351e81b --- /dev/null +++ 
open Types

(** Type errors are reported as plain human-readable messages. *)
type error = string

(* Monadic bind over [result]; stops at the first [Error].

   This has to rebuild the error branch (as [Result.bind] does).  Returning
   the aliased scrutinee, as in [| Error _ as e -> e], gives the result the
   scrutinee's own type and therefore forces the continuation to return the
   same [Ok] payload type.  The checker below needs binds that change the
   payload, e.g. [let* () = expect l tp in Ok (TSum (l, r))]. *)
let bind = Result.bind

let ( let* ) = bind

(* [lookup x env] returns the type of the first binding of [x] in the
   association list [env], or an "unbound variable" error. *)
let lookup x env =
  match List.assoc_opt x env with
  | Some ty -> Ok ty
  | None -> Error ("unbound variable " ^ x)

(* [expect expected actual] succeeds when the two types are equal according
   to [equal_typ]; otherwise it reports both types. *)
let expect expected actual =
  if equal_typ expected actual then Ok ()
  else
    Error
      ("expected " ^ string_of_typ expected ^ " but got "
     ^ string_of_typ actual)

(* The object-language type denoted by a ground tag. *)
let ground_type = function
  | GInt -> TInt
  | GBool -> TBool

(* [type_of_env env term] infers the type of [term] under the term-variable
   environment [env] (innermost binding first).  Sub-terms are checked left
   to right and the first error encountered is returned. *)
let rec type_of_env env term =
  match term with
  | Source.Var x -> lookup x env
  | Source.Int _ -> Ok TInt
  | Source.Bool _ -> Ok TBool
  | Source.Pair (a, b) ->
      let* ta = type_of_env env a in
      let* tb = type_of_env env b in
      Ok (TPair (ta, tb))
  (* Injections carry the full sum type as an annotation; the payload must
     match the corresponding summand. *)
  | Source.Inl (TSum (l, r), payload) ->
      let* tp = type_of_env env payload in
      let* () = expect l tp in
      Ok (TSum (l, r))
  | Source.Inl _ -> Error "inl annotation must be a sum type"
  | Source.Inr (TSum (l, r), payload) ->
      let* tp = type_of_env env payload in
      let* () = expect r tp in
      Ok (TSum (l, r))
  | Source.Inr _ -> Error "inr annotation must be a sum type"
  | Source.Lam (x, arg_ty, res_ty, body) ->
      let* body_ty = type_of_env ((x, arg_ty) :: env) body in
      let* () = expect res_ty body_ty in
      Ok (TArrow (arg_ty, res_ty))
  | Source.App (f, arg) ->
      let* tf = type_of_env env f in
      let* ta = type_of_env env arg in
      begin match tf with
      | TArrow (dom, cod) ->
          let* () = expect dom ta in
          Ok cod
      | _ -> Error ("application expected function but got " ^ string_of_typ tf)
      end
  (* The body is checked in the same term environment; no separate scope of
     type variables is tracked here. *)
  | Source.TyLam (a, body) ->
      let* body_ty = type_of_env env body in
      Ok (TForall (a, body_ty))
  | Source.TyApp (t, ty) ->
      let* tf = type_of_env env t in
      begin match tf with
      | TForall (a, body) -> Ok (substitute_typ a ty body)
      | _ -> Error ("type application expected forall but got " ^ string_of_typ tf)
      end
  (* [roll] is annotated with the recursive type; the payload must have the
     one-step unfolding of that type. *)
  | Source.Roll (ty, payload) ->
      begin match ty with
      | TMu (a, body) ->
          let unfolded = substitute_typ a ty body in
          let* payload_ty = type_of_env env payload in
          let* () = expect unfolded payload_ty in
          Ok ty
      | _ -> Error "roll annotation must be recursive"
      end
  | Source.Unroll t ->
      let* ty = type_of_env env t in
      begin match ty with
      | TMu (a, body) -> Ok (substitute_typ a ty body)
      | _ -> Error ("unroll expected recursive type but got " ^ string_of_typ ty)
      end
  (* The fixpoint variable is in scope in its own body at the annotated
     type, and the body must have that same type. *)
  | Source.Fix (f, ty, body) ->
      let* body_ty = type_of_env ((f, ty) :: env) body in
      let* () = expect ty body_ty in
      Ok ty
  | Source.Eq (g, a, b) ->
      let expected = ground_type g in
      let* ta = type_of_env env a in
      let* tb = type_of_env env b in
      let* () = expect expected ta in
      let* () = expect expected tb in
      Ok TBool
  | Source.If (c, t, e) ->
      let* tc = type_of_env env c in
      let* () = expect TBool tc in
      let* tt = type_of_env env t in
      let* te = type_of_env env e in
      let* () = expect tt te in
      Ok tt
  | Source.Let (x, a, b) ->
      let* ta = type_of_env env a in
      type_of_env ((x, ta) :: env) b
  | Source.LetPair (x, y, scrut, body) ->
      let* ts = type_of_env env scrut in
      begin match ts with
      (* [y] is pushed last, so it shadows [x] if both names coincide. *)
      | TPair (a, b) -> type_of_env ((y, b) :: (x, a) :: env) body
      | _ -> Error ("let-pair expected pair but got " ^ string_of_typ ts)
      end
  | Source.Case (scrut, (x, l), (y, r)) ->
      let* ts = type_of_env env scrut in
      begin match ts with
      | TSum (a, b) ->
          let* tl = type_of_env ((x, a) :: env) l in
          let* tr = type_of_env ((y, b) :: env) r in
          let* () = expect tl tr in
          Ok tl
      | _ -> Error ("case expected sum but got " ^ string_of_typ ts)
      end

(** [type_of term] infers the type of [term] in the empty environment. *)
let type_of term = type_of_env [] term

(** [check expected term] infers the type of [term] and succeeds when it is
    equal to [expected] according to [equal_typ]. *)
let check expected term =
  let* actual = type_of term in
  expect expected actual

(** [is_well_typed expected term] is [true] exactly when
    [check expected term] succeeds. *)
let is_well_typed expected term = Result.is_ok (check expected term)
(** Tags for the two base types. *)
type ground =
  | GInt
  | GBool

(** Object-language types.  [TForall] and [TMu] bind the named type variable
    in their body; [TVar] refers to a type variable by name. *)
type typ =
  | TInt
  | TBool
  | TPair of typ * typ
  | TSum of typ * typ
  | TArrow of typ * typ
  | TForall of string * typ
  | TMu of string * typ
  | TVar of string

(** Witnesses indexed by an OCaml type.  [WDynamic] carries an
    object-language type and is indexed by the untyped [Obj.t]. *)
type _ witness =
  | WInt : int witness
  | WBool : bool witness
  | WPair : 'a witness * 'b witness -> ('a * 'b) witness
  | WDynamic : typ -> Obj.t witness

(** A value packaged with the witness of its type. *)
type packed = Pack : 'a witness * 'a -> packed

(** Equality on ground tags. *)
let equal_ground a b =
  match a, b with
  | GInt, GInt | GBool, GBool -> true
  | GInt, GBool | GBool, GInt -> false

(** [equal_typ a b] decides equality of types up to renaming of bound
    variables: [forall a. a] and [forall b. b] are equal, while free
    variables must carry the same name. *)
let equal_typ a b =
  (* [env] pairs the binders crossed so far, innermost first.  Two variables
     match when they refer to the same binder pair, or are both free and
     share a name. *)
  let rec same_var env xa xb =
    match env with
    | [] -> String.equal xa xb
    | (ya, yb) :: outer ->
        let hit_a = String.equal xa ya and hit_b = String.equal xb yb in
        if hit_a || hit_b then hit_a && hit_b else same_var outer xa xb
  in
  let rec go env a b =
    match a, b with
    | TInt, TInt | TBool, TBool -> true
    | TPair (a1, a2), TPair (b1, b2)
    | TSum (a1, a2), TSum (b1, b2)
    | TArrow (a1, a2), TArrow (b1, b2) -> go env a1 b1 && go env a2 b2
    | TForall (xa, ta), TForall (xb, tb)
    | TMu (xa, ta), TMu (xb, tb) -> go ((xa, xb) :: env) ta tb
    | TVar xa, TVar xb -> same_var env xa xb
    (* Different head constructors. *)
    | _ -> false
  in
  go [] a b

let string_of_ground = function
  | GInt -> "int"
  | GBool -> "bool"

(** Fully parenthesised rendering of a type. *)
let rec string_of_typ = function
  | TInt -> "int"
  | TBool -> "bool"
  | TVar x -> x
  | TPair (a, b) -> "(" ^ string_of_typ a ^ " * " ^ string_of_typ b ^ ")"
  | TSum (a, b) -> "(" ^ string_of_typ a ^ " + " ^ string_of_typ b ^ ")"
  | TArrow (a, b) -> "(" ^ string_of_typ a ^ " -> " ^ string_of_typ b ^ ")"
  | TForall (x, t) -> "(forall " ^ x ^ ". " ^ string_of_typ t ^ ")"
  | TMu (x, t) -> "(mu " ^ x ^ ". " ^ string_of_typ t ^ ")"

(** [is_ground ty] is the ground tag of [ty] when [ty] is [TInt] or [TBool],
    and [None] for every other type. *)
let is_ground = function
  | TInt -> Some GInt
  | TBool -> Some GBool
  | TPair _ | TSum _ | TArrow _ | TForall _ | TMu _ | TVar _ -> None

(* [occurs_free x ty] holds when the type variable [x] occurs in [ty]
   outside the scope of any binder named [x]. *)
let rec occurs_free x = function
  | TInt | TBool -> false
  | TVar y -> String.equal x y
  | TPair (a, b) | TSum (a, b) | TArrow (a, b) ->
      occurs_free x a || occurs_free x b
  | TForall (y, body) | TMu (y, body) ->
      (not (String.equal x y)) && occurs_free x body

(* Append primes to [base] until the name is no longer [taken].  Terminates
   because a finite type mentions finitely many names. *)
let rec fresh_name base taken =
  let candidate = base ^ "'" in
  if taken candidate then fresh_name candidate taken else candidate

(** [substitute_typ x repl ty] replaces the free occurrences of the type
    variable [x] in [ty] by [repl].

    The substitution is capture-avoiding: when a binder in [ty] would
    capture a free variable of [repl], that binder (and its bound
    occurrences) is renamed to a fresh primed name first.  When no capture
    can occur the result is structurally the same as a naive substitution. *)
let rec substitute_typ x repl ty =
  match ty with
  | TInt | TBool -> ty
  | TVar y -> if String.equal x y then repl else ty
  | TPair (a, b) -> TPair (substitute_typ x repl a, substitute_typ x repl b)
  | TSum (a, b) -> TSum (substitute_typ x repl a, substitute_typ x repl b)
  | TArrow (a, b) -> TArrow (substitute_typ x repl a, substitute_typ x repl b)
  | TForall (y, body) ->
      let y, body = substitute_under x repl y body in
      TForall (y, body)
  | TMu (y, body) ->
      let y, body = substitute_under x repl y body in
      TMu (y, body)

(* Substitution below a binder [y] with scope [body]; returns the (possibly
   renamed) binder together with the substituted body. *)
and substitute_under x repl y body =
  if String.equal x y || not (occurs_free x body) then
    (* [x] is shadowed, or there is nothing to replace. *)
    (y, body)
  else if occurs_free y repl then begin
    (* [y] would capture a free variable of [repl]: rename it first. *)
    let taken name =
      String.equal name x || occurs_free name repl || occurs_free name body
    in
    let y' = fresh_name y taken in
    (y', substitute_typ x repl (substitute_typ y (TVar y') body))
  end
  else (y, substitute_typ x repl body)
open Vanity

(* Abort the test run with [msg] unless [b] holds. *)
let assert_true msg b =
  if b then () else failwith msg

(* Look up a corpus case by name; a missing case is a test failure. *)
let find_case name =
  let has_name (case : Corpus.case) = String.equal case.name name in
  match List.find_opt has_name Corpus.all with
  | None -> failwith ("missing case " ^ name)
  | Some case -> case

(* Every corpus case must type-check at its declared type. *)
let check_corpus_case (case : Corpus.case) =
  assert_true
    ("ill-typed corpus case " ^ case.name)
    (Typecheck.is_well_typed case.ty case.source)

(* Every generated specimen must type-check at its recorded type. *)
let check_specimen specimen =
  assert_true
    ("ill-typed generated specimen " ^ Source.string_of_term specimen.Gen.term)
    (Typecheck.is_well_typed specimen.Gen.ty specimen.Gen.term)

let () =
  List.iter check_corpus_case Corpus.all;
  (* The two named witnesses must be classified with the expected failure
     mode by the audit. *)
  let exposure =
    Audit.audit_case (find_case "free-theorem-fails-after-unsafe-inlining")
  in
  assert_true
    "expected representation exposure witness"
    (exposure.failure_mode = Audit.Representation_exposure);
  let strictness =
    Audit.audit_case (find_case "strictness-induced-termination-change")
  in
  assert_true
    "expected strictness shift witness"
    (strictness.failure_mode = Audit.Strictness_shift);
  let specimens = Gen.sample_terms ~count:80 ~max_depth:4 () in
  List.iter check_specimen specimens
["(latest)"]