Skip to content

Commit

Permalink
Really fix #14
Browse files Browse the repository at this point in the history
  • Loading branch information
wies committed Feb 6, 2025
1 parent 2adfbb5 commit ddb3d1f
Show file tree
Hide file tree
Showing 4 changed files with 43 additions and 33 deletions.
2 changes: 1 addition & 1 deletion lib/ast/progUtils.ml
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ let is_ra_type (tp : AstDef.type_expr) : bool t =
let open Syntax in
let rec does_ident_implement_ra qual_ident =
let* symbol = find qual_ident in
Symbol.extract symbol ~f:(fun subst -> function
Symbol.extract symbol ~f:(fun _ subst -> function
| AstDef.Module.ModDef m -> return m.mod_decl.mod_decl_is_ra
| ModInst mod_inst ->
let* is_ra = does_ident_implement_ra mod_inst.mod_inst_type in
Expand Down
26 changes: 13 additions & 13 deletions lib/ast/rewriter.ml
Original file line number Diff line number Diff line change
Expand Up @@ -1292,22 +1292,22 @@ module Symbol = struct
i_l))
subst);*)

match subst with
match SymbolTbl.qid_subst subst with
| [] -> return symbol
| _ ->
let open Syntax in
let+ tbl = get_table in
let tbl_scope = SymbolTbl.goto (AstDef.Symbol.to_loc symbol) name tbl in
let symbol0 = match symbol, subst with
| ModDef mod_def, _ :: _ ->
let symbol0 = match symbol with
| ModDef mod_def when SymbolTbl.is_instance subst ->
let mod_decl = { mod_def.mod_decl with mod_decl_formals = [] } in
AstDef.Module.ModDef { mod_def with mod_decl }
| _ -> symbol
in
let _, symbol1 =
eval
(Module.rewrite_qual_idents_in_symbol
~f:(QualIdent.requalify subst)
~f:(subst |> SymbolTbl.qid_subst |> QualIdent.requalify)
symbol0)
tbl_scope
in
Expand All @@ -1324,13 +1324,13 @@ module Symbol = struct
| AstDef.Module.TypeDef { type_def_expr = None; _ } -> return None
| TypeDef { type_def_expr = Some tp_expr; _ } ->
let+ tp_expr =
Type.rewrite_qual_idents ~f:(QualIdent.requalify subst) tp_expr
Type.rewrite_qual_idents ~f:(subst |> SymbolTbl.qid_subst |> QualIdent.requalify) tp_expr
in
Some tp_expr
| ModDef { mod_decl = { mod_decl_rep = Some rep_id; _ }; _ } ->
let+ tp_expr =
AstDef.Type.mk_var (QualIdent.append name rep_id)
|> Type.rewrite_qual_idents ~f:(QualIdent.requalify subst)
|> Type.rewrite_qual_idents ~f:(subst |> SymbolTbl.qid_subst |> QualIdent.requalify)
in
Some tp_expr
| _ -> Error.error loc "Expected type identifier"
Expand All @@ -1342,24 +1342,24 @@ module Symbol = struct
| FieldDef field_def -> field_def.field_type
| _ -> Error.error loc "Expected expression identifier"
in
Type.rewrite_qual_idents ~f:(QualIdent.requalify subst) tp_expr
Type.rewrite_qual_idents ~f:(subst |> SymbolTbl.qid_subst |> QualIdent.requalify) tp_expr

let reify_field_type loc (_name, symbol, subst) : (AstDef.Type.t, 'a) t_ext =
let tp_expr =
match symbol with
| AstDef.Module.FieldDef { field_type = App (Fld, [ tp ], _); _ } -> tp
| _ -> Error.error loc "Expected field identifier"
in
Type.rewrite_qual_idents ~f:(QualIdent.requalify subst) tp_expr
Type.rewrite_qual_idents ~f:(subst |> SymbolTbl.qid_subst |> QualIdent.requalify) tp_expr

let orig_symbol (_name, symbol, _subst) = symbol
let orig_qid (name, _symbol, _subst) = name
let subst (_name, _symbol, subst) = subst
let extract (_name, symbol, subst) ~f = f (QualIdent.requalify subst) symbol
let add_subst s (name, symbol, subst) = (name, symbol, s :: subst)
let is_derived (_, _, subst) = not @@ Base.List.is_empty subst
let subst (_name, _symbol, subst) = SymbolTbl.qid_subst subst
let extract (_name, symbol, subst) ~f = f (SymbolTbl.is_instance subst) (subst |> SymbolTbl.qid_subst |> QualIdent.requalify) symbol
let extend_subst s (name, symbol, subst) = (name, symbol, SymbolTbl.extend_subst s subst)
let is_instance (_, _, subst) = SymbolTbl.is_instance subst

type t = QualIdent.t * AstDef.Module.symbol * QualIdent.subst
type t = QualIdent.t * AstDef.Module.symbol * SymbolTbl.subst

let pr ppf (name, symbol, subst) =
Stdlib.Format.fprintf ppf "%a -> %a [%a]" QualIdent.pr name AstDef.Symbol.pr
Expand Down
35 changes: 22 additions & 13 deletions lib/ast/symbolTbl.ml
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,15 @@ open Util
let unknown_ident_error loc id =
Error.error (QualIdent.to_loc id) ("Unknown identifier " ^ QualIdent.to_string id)

(* Substitutions for reifying aliased symbols.
The Boolean indicates whether the derived symbol is the result of a functor instantiation.
*)
type subst = bool * QualIdent.subst

let is_instance (b, _) = b
let qid_subst (_, ss) = ss
let extend_subst s (b, ss) = (b, s :: ss)

type entry =
| Symbol of QualIdent.t
| Alias of bool * QualIdent.t * QualIdent.subst
Expand All @@ -24,7 +33,7 @@ type scope = {
(* Symbols defined in this scope *)
scope_entries : entry IdentHashtbl.t; [@hash.ignore]
(* Cache that maps names (partially) qualified names to fully qualified names, relative to this scope. *)
scope_cache : (QualIdent.t * QualIdent.t * QualIdent.subst) QualIdentHashtbl.t;
scope_cache : (QualIdent.t * QualIdent.t * subst) QualIdentHashtbl.t;
[@hash.ignore]
}
[@@deriving hash]
Expand Down Expand Up @@ -145,7 +154,7 @@ let fully_qualify ident scope tbl : QualIdent.t =
if QualIdent.(scope_ident = root_ident tbl) then QualIdent.from_ident ident
else QualIdent.append scope_ident ident

let pr_subst ppf subst =
let pr_subst ppf (_, subst) =
let open Stdlib.Format in
fprintf ppf "[ @[%a@] ]"
(Print.pr_list_sep ",\n" (fun ppf (a, b) ->
Expand All @@ -161,7 +170,7 @@ let is_parent scope tbl =

(** Resolve [name] to its fully qualified identifier relative to the current scope in [tbl]. *)
let resolve name (tbl : t) :
(QualIdent.t * QualIdent.t * QualIdent.subst) option =
(QualIdent.t * QualIdent.t * subst) option =
let open Option.Syntax in
let rec go_forward inst_scopes scope subst ids =
(* Logs.debug (fun m -> m "SymbolTbl.resolve.go_forward: scope: %a" QualIdent.pr (get_scope_id scope)); *)
Expand Down Expand Up @@ -192,7 +201,7 @@ let resolve name (tbl : t) :
) subst1) ; *)
let subst1 =
List.map subst1 ~f:(fun (s, t) ->
(QualIdent.requalify subst s, t))
(QualIdent.requalify (snd subst) s, t))
in

(* if the first argument is abstract, then it needs to be requalified. The second arg doesn't because this is taken care of by the order in which elements are added to the subst list. QualIdent.requalify will make sure the renaming on the second argument by existing substitutions happens *)
Expand All @@ -209,9 +218,9 @@ let resolve name (tbl : t) :
| _ ->
( target_qual_ident,
fully_qualify first_id scope tbl |> QualIdent.to_list )
:: subst
:: snd subst
in
let subst = subst1 @ target_subst in
let subst = fst subst || not @@ List.is_empty subst1, subst1 @ target_subst in
let new_inst_scopes =
if is_abstract then inst_scopes
else Set.add inst_scopes target_qual_ident
Expand Down Expand Up @@ -250,11 +259,11 @@ let resolve name (tbl : t) :
let+ alias_qual_ident, subst, is_local =
go_forward
(Set.empty (module QualIdent))
curr_scope [] (QualIdent.to_list name)
curr_scope (false, []) (QualIdent.to_list name)
in
(* Compute resolved identifier *)
let orig_qual_ident =
alias_qual_ident |> QualIdent.requalify subst
alias_qual_ident |> QualIdent.requalify (snd subst)
in
(* Don't qualify orig_qual_ident if it identifies a symbol in a local scope *)
let orig_qual_ident =
Expand Down Expand Up @@ -293,7 +302,7 @@ let resolve_exn name tbl =
- a substitution map that maps the symbol definition to the scope where it is used.
*)
let resolve_and_find name tbl :
(QualIdent.t * QualIdent.t * Module.symbol * QualIdent.subst) option =
(QualIdent.t * QualIdent.t * Module.symbol * subst) option =
let open Option.Syntax in
(* Logs.debug (fun m -> m "SymbolTbl.resolve_and_find: %a" QualIdent.pr name); *)
let* alias_qual_ident, orig_qual_ident, subst = resolve name tbl in
Expand All @@ -319,7 +328,7 @@ let resolve_and_find_exn name (tbl : t) =
unknown_ident_error (QualIdent.to_loc name) name)

(** Find the symbol associated with [name] relative to the current scope in [tbl]. *)
let find name tbl : (Module.symbol * QualIdent.subst) option =
let find name tbl : (Module.symbol * subst) option =
let open Option.Syntax in
let* alias_qual_ident, _, subst = resolve name tbl in
let+ symbol = Map.find tbl.tbl_symbols alias_qual_ident in
Expand Down Expand Up @@ -450,10 +459,10 @@ let add_symbol ?(scope : scope option = None) symbol tbl =
let formals =
match mod_inst_symbol with
| Module.ModDef mdef ->
if List.is_empty subst1
if not @@ is_instance subst1
then mdef.mod_decl.mod_decl_formals
else []
| _ -> Error.type_error symbol_loc "Expected module identifier"
| _ -> Error.type_error (QualIdent.to_loc mod_inst_func) "Expected module identifier"
in
let res =
List.map2 formals mod_inst_args ~f:(fun formal arg ->
Expand All @@ -469,7 +478,7 @@ let add_symbol ?(scope : scope option = None) symbol tbl =
| Ok subst ->
(mod_inst_func, subst)
| Unequal_lengths ->
Error.type_error symbol_loc
Error.type_error (QualIdent.to_loc mod_inst_func)
(Printf.sprintf
!"Module %{QualIdent} expects %d arguments"
mod_inst_func (List.length formals)))
Expand Down
13 changes: 7 additions & 6 deletions lib/frontend/typing.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2241,7 +2241,7 @@ module ProcessModule = struct
Rewriter.resolve_and_find int_ident
in
let interfaces =
Rewriter.Symbol.extract mod_symbol ~f:(fun _subst -> function
Rewriter.Symbol.extract mod_symbol ~f:(fun _ _subst -> function
| Ast.Module.ModDef mod_def ->
(*Set.map (module QualIdent) mod_def.mod_decl.mod_decl_interfaces ~f:subst*)
mod_def.mod_decl.mod_decl_interfaces
Expand Down Expand Up @@ -2310,9 +2310,10 @@ module ProcessModule = struct
mod_inst_func
in
let formals =
Rewriter.Symbol.extract functor_symbol ~f:(fun subst ->
Rewriter.Symbol.extract functor_symbol ~f:(fun is_instance subst ->
function
| Ast.Module.ModDef mod_def when not @@ Rewriter.Symbol.is_derived functor_symbol ->
| Ast.Module.ModDef mod_def when not is_instance ->
Logs.info (fun m -> m !"%{QualIdent}" mod_inst_func);
List.map mod_def.mod_decl.mod_decl_formals
~f:(fun mod_inst ->
subst mod_inst.mod_inst_type)
Expand All @@ -2322,7 +2323,7 @@ module ProcessModule = struct
match List.zip mod_inst_args formals with
| Ok res -> res
| Unequal_lengths ->
Error.type_error mod_inst.mod_inst_loc
Error.type_error (*mod_inst.mod_inst_loc*) (QualIdent.to_loc mod_inst_func)
(Printf.sprintf
!"Module %{QualIdent} expects %d arguments"
mod_inst_func (List.length formals))
Expand Down Expand Up @@ -2513,7 +2514,7 @@ module ProcessModule = struct
Rewriter.resolve_and_find mid
in
let interface_symbol =
Rewriter.Symbol.add_subst
Rewriter.Symbol.extend_subst
(qual_interface_ident, QualIdent.to_list mod_qual_ident)
interface_symbol
in
Expand Down Expand Up @@ -2574,7 +2575,7 @@ module ProcessModule = struct
Rewriter.resolve_and_find
interface_ident
in
Rewriter.Symbol.extract interface_symbol ~f:(fun _ -> function
Rewriter.Symbol.extract interface_symbol ~f:(fun _ _ -> function
| Module.ModDef mod_def -> mod_def.mod_decl.mod_decl_is_ra
| _ -> false))
in
Expand Down

0 comments on commit ddb3d1f

Please sign in to comment.