1
1
use super :: * ;
2
+ use defs2:: { Definitions , Id as DefId , TypesIn } ;
2
3
use kserd:: Number ;
3
4
use petgraph:: prelude:: * ;
4
5
use std:: ops:: Deref ;
@@ -23,6 +24,9 @@ pub enum AstNode {
23
24
Op {
24
25
op : Tag ,
25
26
blk : Tag ,
27
+
28
+ /// The invocation location of the op.
29
+ within : DefId ,
26
30
} ,
27
31
Intrinsic {
28
32
op : Tag ,
@@ -107,7 +111,7 @@ pub fn init(expr: ast::Expression, defs: &Definitions) -> Result<(AstGraph, Chgs
107
111
108
112
let expr_tag = expr. tag . clone ( ) ;
109
113
110
- graph. flatten_expr ( expr, & mut chgs, defs) ?;
114
+ graph. flatten_expr ( expr, & mut chgs, defs, defs2 :: ROOT . into ( ) ) ?;
111
115
112
116
let recursion_detector = & mut RecursionDetection :: default ( ) ;
113
117
@@ -130,6 +134,7 @@ impl AstGraph {
130
134
expr : ast:: Expression ,
131
135
chgs : & mut Chgs ,
132
136
defs : & Definitions ,
137
+ within : DefId ,
133
138
) -> Result < NodeIndex > {
134
139
use ast:: * ;
135
140
@@ -143,6 +148,7 @@ impl AstGraph {
143
148
} = expr;
144
149
145
150
let root = g. add_node ( AstNode :: Expr ( tag) ) ;
151
+ let tys = defs. types ( ) . within ( within) ;
146
152
147
153
// rather than building a recursive function, we'll use a queue of expressions to process,
148
154
// since expressions are the recursive element
@@ -157,7 +163,7 @@ impl AstGraph {
157
163
q. push_back ( Qi {
158
164
root,
159
165
blocks,
160
- out_ty : map_ty_tag ( out_ty, defs ) ?,
166
+ out_ty : map_ty_tag ( out_ty, tys ) ?,
161
167
} ) ;
162
168
163
169
// FIFO -- breadth-first
@@ -187,13 +193,17 @@ impl AstGraph {
187
193
}
188
194
} ;
189
195
190
- let op = g. add_node ( AstNode :: Op { op, blk : blk_tag } ) ;
196
+ let op = g. add_node ( AstNode :: Op {
197
+ op,
198
+ blk : blk_tag,
199
+ within,
200
+ } ) ;
191
201
g. add_edge ( root, op, Relation :: Normal ) ; // edge from the expression root to the op
192
202
193
- if let Some ( t) = map_ty_tag ( in_ty, defs ) ? {
203
+ if let Some ( t) = map_ty_tag ( in_ty, tys ) ? {
194
204
chgs. push ( Chg :: ObligeInput ( op, t) ) ;
195
205
}
196
- if let Some ( t) = map_ty_tag ( out_ty, defs ) ? {
206
+ if let Some ( t) = map_ty_tag ( out_ty, tys ) ? {
197
207
chgs. push ( Chg :: ObligeOutput ( op, t) ) ;
198
208
}
199
209
@@ -236,7 +246,7 @@ impl AstGraph {
236
246
q. push_back ( Qi {
237
247
root : term,
238
248
blocks,
239
- out_ty : map_ty_tag ( out_ty, defs ) ?,
249
+ out_ty : map_ty_tag ( out_ty, tys ) ?,
240
250
} ) ;
241
251
}
242
252
}
@@ -259,7 +269,7 @@ impl AstGraph {
259
269
let ops = self
260
270
. node_indices ( )
261
271
// is an Op
262
- . filter_map ( |n| self [ n] . op ( ) . map ( |_ | OpNode ( n) ) )
272
+ . filter_map ( |n| self [ n] . op ( ) . map ( |x | OpNode ( n) ) )
263
273
// have not already expanded it
264
274
. filter ( |& n| !self . op_expanded ( n) )
265
275
. collect :: < Vec < _ > > ( ) ;
@@ -283,20 +293,16 @@ impl AstGraph {
283
293
) -> Result < bool > {
284
294
let opnode_ = NodeIndex :: from ( opnode) ;
285
295
286
- let op = self [ opnode_]
287
- . op ( )
288
- . expect ( "opnode must be an Op variant" )
289
- . 0
290
- . clone ( ) ;
296
+ let ( op, _, within) = self [ opnode_] . op ( ) . expect ( "opnode must be an Op variant" ) ;
297
+ let op = op. clone ( ) ;
291
298
292
- let impls = defs. impls ( ) ;
299
+ let impls = defs. impls ( ) . within ( within ) ;
293
300
294
- if !impls. contains_op ( op. str ( ) ) {
295
- return Err ( Error :: op_not_found ( & op, None , false , impls) ) ;
301
+ let op_impls = impls. matches ( & op) ?;
302
+ if op_impls. is_empty ( ) {
303
+ return Err ( Error :: op_not_found2 ( & op, false , impls) ) ;
296
304
}
297
305
298
- let op_impls = impls. iter_op ( op. str ( ) ) ;
299
-
300
306
recursion_detector. clear_cache ( ) ;
301
307
302
308
let mut expanded = false ;
@@ -331,7 +337,9 @@ impl AstGraph {
331
337
}
332
338
333
339
// no recursion detected, this id will need to be added to the detector
334
- let tys = defs. types ( ) ;
340
+ // TODO: use something other than ROOT
341
+ todo ! ( ) ;
342
+ let tys = defs. types ( ) . within ( defs2:: ROOT ) ;
335
343
let params = def
336
344
. params
337
345
. iter ( )
@@ -342,7 +350,10 @@ impl AstGraph {
342
350
expr : def. expr . tag . clone ( ) ,
343
351
params,
344
352
} ) ;
345
- let expr = self . flatten_expr ( def. expr . clone ( ) , chgs, defs) ?;
353
+ // TODO: use something other than ROOT
354
+ // not sure what yet, need to consider how looking down into
355
+ // ast works
356
+ let expr = self . flatten_expr ( def. expr . clone ( ) , chgs, defs, todo ! ( ) ) ?;
346
357
// link cmd to expr
347
358
self . 0 . add_edge ( cmd, expr, Relation :: Normal ) ;
348
359
// add the id into the recursion detector
@@ -402,9 +413,8 @@ impl AstGraph {
402
413
}
403
414
}
404
415
405
- fn map_ty_tag ( tag : Option < Tag > , defs : & Definitions ) -> Result < Option < Type > > {
406
- tag. map ( |t| defs. types ( ) . get_using_tag ( & t) . map ( |x| x. clone ( ) ) )
407
- . transpose ( )
416
+ fn map_ty_tag ( tag : Option < Tag > , tys : TypesIn ) -> Result < Option < Type > > {
417
+ tag. map ( |t| tys. get ( & t) . map ( Clone :: clone) ) . transpose ( )
408
418
}
409
419
410
420
#[ derive( Default ) ]
@@ -709,9 +719,9 @@ impl AstGraph {
709
719
710
720
impl AstNode {
711
721
/// If this is an op node, returns the `(op, blk)` tags.
712
- pub fn op ( & self ) -> Option < ( & Tag , & Tag ) > {
722
+ pub fn op ( & self ) -> Option < ( & Tag , & Tag , DefId ) > {
713
723
match self {
714
- AstNode :: Op { op, blk } => Some ( ( op, blk) ) ,
724
+ AstNode :: Op { op, blk, within } => Some ( ( op, blk, * within ) ) ,
715
725
_ => None ,
716
726
}
717
727
}
@@ -752,7 +762,11 @@ impl AstNode {
752
762
use AstNode :: * ;
753
763
754
764
match self {
755
- Op { op, blk : _ } => op,
765
+ Op {
766
+ op,
767
+ blk : _,
768
+ within : _,
769
+ } => op,
756
770
Intrinsic { op, intrinsic : _ } => op,
757
771
Def { expr, params : _ } => expr,
758
772
Flag ( f) => f,
@@ -770,7 +784,7 @@ impl fmt::Display for AstNode {
770
784
use AstNode :: * ;
771
785
772
786
match self {
773
- Op { op, blk : _ } => write ! ( f, "Op({})" , op. str ( ) ) ,
787
+ Op { op, blk : _, within } => write ! ( f, "Op({})" , op. str ( ) ) ,
774
788
Intrinsic {
775
789
op : _,
776
790
intrinsic : _,
@@ -847,15 +861,15 @@ impl fmt::Display for Relation {
847
861
}
848
862
849
863
impl Parameter {
850
- fn from_ast ( param : & ast:: Parameter , tys : & types :: Types ) -> Result < Self > {
864
+ fn from_ast ( param : & ast:: Parameter , tys : TypesIn ) -> Result < Self > {
851
865
let ast:: Parameter { ident, ty } = param;
852
866
853
867
let name = ident. clone ( ) ;
854
868
let ty = ty. as_ref ( ) ;
855
869
let ty = if ty. map ( |t| t. str ( ) == "Expr" ) . unwrap_or ( false ) {
856
870
ParameterTy :: Expr
857
871
} else {
858
- ty. map ( |t| tys. get_using_tag ( t) )
872
+ ty. map ( |t| tys. get ( t) )
859
873
. transpose ( ) ?
860
874
. map ( Clone :: clone)
861
875
. map ( ParameterTy :: Specified )
0 commit comments