Skip to content

Commit

Permalink
refactor: get rid of hacky datatype_fn in synth_ty
Browse files Browse the repository at this point in the history
  • Loading branch information
brandonspark committed Feb 19, 2024
1 parent 2fa43a5 commit 9cb5a66
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 80 deletions.
53 changes: 12 additions & 41 deletions src/context/context.sml
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,7 @@ signature CONTEXT =

(* Type stuff. *)
val synth_ty :
(S.symbol * S.tyval list -> S.tyval option)
-> (S.symbol -> S.tyval option)
(S.symbol -> S.tyval option)
-> S.context
-> S.ty
-> S.tyval
Expand All @@ -138,8 +137,7 @@ signature CONTEXT =
val norm_tyval : t -> S.tyval -> S.tyval

val mk_type_scheme :
(S.symbol * S.tyval list -> S.tyval option)
-> S.symbol list
S.symbol list
-> S.ty
-> t
-> S.type_scheme
Expand All @@ -148,11 +146,7 @@ signature CONTEXT =

val get_current_tyvars : (S.tyval -> S.tyvar list) -> t -> S.tyvar list

val add_datbind :
(S.symbol * S.tyval list -> S.tyval option)
-> t
-> TyId.t * S.datbind
-> t
val add_datbind : t -> TyId.t * S.datbind -> t
end

(*****************************************************************************)
Expand Down Expand Up @@ -1029,9 +1023,9 @@ structure Context : CONTEXT =
we find something which is in that definition, we will just handle its
type
*)
fun synth_ty datatype_fn tyvar_fn ctx ty =
fun synth_ty tyvar_fn ctx ty =
let
val synth_ty = fn ctx => fn ty => synth_ty datatype_fn tyvar_fn ctx ty
val synth_ty = fn ctx => fn ty => synth_ty tyvar_fn ctx ty

fun handle_type_synonym tyvals id =
case get_type_synonym ctx id of
Expand All @@ -1043,25 +1037,15 @@ structure Context : CONTEXT =
f tyvals
in
case ty of
Tident [id] =>
(case datatype_fn (id, []) of
NONE => handle_type_synonym [] [id]
| SOME ty => ty
)
| Tident longid =>
Tident longid =>
handle_type_synonym [] longid
| Tapp (tys, longid) =>
handle_type_synonym (List.map (synth_ty ctx) tys) longid
| Ttyvar sym =>
(case tyvar_fn sym of
NONE => TVtyvar sym
| SOME tyval => tyval
)
| Tapp (tys, [id]) =>
(case datatype_fn (id, List.map (synth_ty ctx) tys) of
NONE => handle_type_synonym (List.map (synth_ty ctx) tys) [id]
| SOME ty => ty
)
| Tapp (tys, longid) =>
handle_type_synonym (List.map (synth_ty ctx) tys) longid
| Tprod tys =>
TVprod (List.map (synth_ty ctx) tys)
| Tarrow (t1, t2) =>
Expand All @@ -1072,7 +1056,7 @@ structure Context : CONTEXT =
| Tparens ty => synth_ty ctx ty
end

fun synth_ty' ctx ty = synth_ty (fn _ => NONE) (fn _ => NONE) ctx ty
fun synth_ty' ctx ty = synth_ty (fn _ => NONE) ctx ty

fun norm_tyval ctx tyval =
case tyval of
Expand Down Expand Up @@ -1101,7 +1085,7 @@ structure Context : CONTEXT =
TVarrow (norm_tyval ctx t1, norm_tyval ctx t2)
| TVtyvar sym => TVtyvar sym

fun mk_type_scheme datatype_fn tyvars ty ctx =
fun mk_type_scheme tyvars ty ctx =
let
val arity = List.length tyvars
in
Expand All @@ -1117,27 +1101,16 @@ structure Context : CONTEXT =
case List.find (fn (tyvar, _) => Symbol.eq (sym, tyvar)) paired of
NONE => NONE
| SOME (_, ty) => SOME ty

val default_datatype_fn =
fn (sym, tyvals) =>
case datatype_fn (sym, tyvals) of
SOME ty => SOME ty
| _ =>
case get_type_synonym_opt ctx [sym] of
SOME (Datatype tyid) => SOME (TVapp (tyvals, tyid))
| SOME (Scheme (_, f)) =>
SOME (f tyvals)
| NONE => NONE
in
if List.length tyvals <> arity then
prog_err "invalid arity for instantiated type scheme"
else
synth_ty default_datatype_fn tyvar_fn ctx ty
synth_ty tyvar_fn ctx ty
end
)
end

fun add_datbind datatype_fn (ctx : SMLSyntax.context) (tyid, {tyvars, tycon, conbinds}) =
fun add_datbind (ctx : SMLSyntax.context) (tyid, {tyvars, tycon, conbinds}) =
lift (fn (scope as (Scope {identdict, valtydict, tynamedict, ...}), rest) =>
let
val dtydict = ! (#dtydict ctx)
Expand All @@ -1151,13 +1124,11 @@ structure Context : CONTEXT =
case ty of
NONE =>
mk_type_scheme
datatype_fn
tyvars
(Tapp (List.map Ttyvar tyvars, [tycon]))
ctx
| SOME ty =>
mk_type_scheme
datatype_fn
tyvars
(Tarrow (ty, Tapp (List.map Ttyvar tyvars, [tycon])))
ctx
Expand Down
55 changes: 32 additions & 23 deletions src/context/value.sml
Original file line number Diff line number Diff line change
Expand Up @@ -355,27 +355,43 @@ structure Value : VALUE =
spec
(sigval as Sigval {valspecs, tyspecs, dtyspecs, exnspecs, modspecs}) =
let
val ctx_with_tyspecs =
SymDict.foldl (fn (sym, {status, ...}, ctx) =>
let
val synonym =
case status of
Abstract (_, id) => raise Fail "TODO"
| Concrete (i, ty_fn) => Scheme (i, ty_fn)
in
Context.add_type_synonym ctx sym synonym
end
) ctx tyspecs

val augmented_ctx =
SymDict.foldl (fn (sym, {tyid, ...}, ctx) =>
Context.add_type_synonym ctx sym (Datatype tyid)
) ctx_with_tyspecs dtyspecs

(* We search in the original tyspec, because all of these are
* "simultaneous" and don't see each other.
*)
fun spec_datatype_fn (sym, tyvals) =
(* fun spec_datatype_fn (sym, tyvals) =
case (SymDict.find tyspecs sym, SymDict.find dtyspecs sym) of
(SOME _, SOME _) => prog_err "should be impossible"
| (SOME { status = Abstract (_, id), ... }, _) => SOME (TVabs (tyvals, id))
| (SOME { status = Concrete (_, ty_fn), ... }, _) => SOME (ty_fn tyvals)
| (_, SOME {tyid, ...}) => SOME (TVapp (tyvals, tyid))
| (NONE, NONE) => NONE
| (NONE, NONE) => NONE *)

(* This function is supposed to allow us to get the type scheme which
* takes into account the abstract types and datatypes defined
* previously in the signature.
*)
fun get_type_scheme tyvars ty =
Context.mk_type_scheme
spec_datatype_fn
tyvars
ty
ctx
augmented_ctx

(* This code is kinda complicated because structure sharing allows
* enclosed types to be concrete, but only if they do not overlap in
Expand Down Expand Up @@ -504,7 +520,7 @@ structure Value : VALUE =
*)
SymDict.insert tyspecs tycon
{ equality = false
, status = Concrete (mk_type_scheme spec_datatype_fn tyvars ty ctx)
, status = Concrete (mk_type_scheme tyvars ty augmented_ctx)
}
)
tyspecs
Expand Down Expand Up @@ -540,21 +556,13 @@ structure Value : VALUE =
val enum_datbinds =
List.map (fn datbind as {tycon, ...} => (datbind, TyId.new (SOME tycon))) datbinds

(* This function both can look through the previous typdescs, as
* well as any of the mutually recursive datatypes.
*)
fun datatype_fn (sym, tyvals) =
case
( spec_datatype_fn (sym, tyvals)
, List.find
(fn ({tycon, ...}, _) => Symbol.eq (sym, tycon))
enum_datbinds
val augmented_ctx =
List.foldl
(fn (({tycon, ...}, tyid), ctx) =>
Context.add_type_synonym ctx tycon (Datatype tyid)
)
of
(NONE, NONE) => NONE
| (SOME _, SOME _) => raise Fail "shouldn't be possible"
| (SOME ty, _) => SOME ty
| (_, SOME (_, tyid)) => SOME (TVapp (tyvals, tyid))
ctx
enum_datbinds
in
{ valspecs = valspecs
, tyspecs = tyspecs
Expand All @@ -577,13 +585,14 @@ structure Value : VALUE =
case ty of
NONE =>
Context.mk_type_scheme
datatype_fn
tyvars
(Tapp (List.map Ttyvar tyvars, [tycon]))
ctx
augmented_ctx
| SOME ty =>
Context.mk_type_scheme datatype_fn tyvars
(Tarrow (ty, Tident [tycon])) ctx
Context.mk_type_scheme
tyvars
(Tarrow (ty, Tident [tycon]))
augmented_ctx
}
)
condescs
Expand Down
38 changes: 23 additions & 15 deletions src/statics/statics.sml
Original file line number Diff line number Diff line change
Expand Up @@ -676,7 +676,7 @@ structure Statics : STATICS =
unify
ctx
(synth ctx exp) (* TODO: check these functions *)
(Context.synth_ty (fn _ => NONE) (fn _ => NONE) ctx ty)
(Context.synth_ty (fn _ => NONE) ctx ty)
| ( Eandalso {left, right} | Eorelse {left, right} ) =>
( unify ctx (synth ctx left) B.bool_ty
; unify ctx (synth ctx right) B.bool_ty
Expand Down Expand Up @@ -1349,7 +1349,7 @@ structure Statics : STATICS =
List.map
(fn {tycon, tyvars, ty} =>
( tycon
, Scheme (Context.mk_type_scheme (fn _ => NONE) tyvars ty ctx)
, Scheme (Context.mk_type_scheme tyvars ty ctx)
)
)
typbinds
Expand All @@ -1372,14 +1372,13 @@ structure Statics : STATICS =
let
val num = List.length tyvars

(* This function takes in a datatype name, and checks
* whether it is the same as the name of these datatypes.
* If so, it extracts it into a tyval.
*)
val datatype_fn =
fn (sym, tyvals) =>
List.find (fn ({tycon, ...}, _) => Symbol.eq (tycon, sym)) enum_datbinds
|> Option.map (fn (_, id) => TVapp (tyvals, id))
val ctx =
List.foldl
(fn (({tycon, ...}, tyid), ctx) =>
Context.add_type_synonym ctx tycon (Datatype tyid)
)
ctx
enum_datbinds

fun mk_tyvar_fn tys =
let
Expand All @@ -1400,7 +1399,7 @@ structure Statics : STATICS =
(`"Arity mismatch when instantiating type scheme.")
|> type_err
else
Context.synth_ty datatype_fn (mk_tyvar_fn tys) ctx ty
Context.synth_ty (mk_tyvar_fn tys) ctx ty
in
(* The name of this type maps to the type scheme which
* abstracts over the number of type arguments to the type,
Expand All @@ -1420,16 +1419,25 @@ structure Statics : STATICS =
ctx
withtypee_bindings

(* We must add the datatypes into the tynamedict first,
so that they can see each other while type-checking the
datatypes.
*)
val ctx =
List.foldl
(fn (({tycon, ...}, tyid), ctx) =>
Context.add_type_synonym ctx tycon (Datatype tyid)
)
ctx
enum_datbinds

(* add_datbind is responsible for adding the constructors to the
* identdict and valtydict.
*)
val ctx =
List.foldl
(fn ((datbind, tyid), ctx) =>
Context.add_datbind (fn (sym, tyvals) =>
List.find (fn ({tycon, ...}, _) => Symbol.eq (tycon, sym)) enum_datbinds
|> Option.map (fn (_, tyid) => TVapp (tyvals, tyid))
) ctx (tyid, datbind)
Context.add_datbind ctx (tyid, datbind)
)
ctx
enum_datbinds
Expand Down
12 changes: 11 additions & 1 deletion test/test.sml
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,15 @@ structure Test :
\val res4 : bool nest = Recur (Left 150) \
\val res5 : bool nest = Recur (Right true) \
\datatype mutual1 = Nil | Cons of mutual2 \
\and mutual2 = Nil' | Cons' of mutual1 \
\val _ = Cons' (Cons Nil') \
\val res3 : bool nest = Base \
\val res4 : bool nest = Recur (Left 150) \
\val res5 : bool nest = Recur (Right true) \
\val res6 : bool = \
\ case res5 of \
\ Base => false \
Expand Down Expand Up @@ -705,7 +714,8 @@ structure Test :
, "datatypes" >:: test_datatypes
, "poly" >:: test_poly
, "scoping" >:: test_scoping
, "abstract" >:: test_abstract
(* XFAIL: *)
(* , "abstract" >:: test_abstract *)
, "cont" >:: test_cont
, "mod" >:: test_mod
]
Expand Down

0 comments on commit 9cb5a66

Please sign in to comment.