@@ -16,8 +16,12 @@ def _cat_one(t):
16
16
17
17
class _CastBatchedDisjointBase (Layer ):
18
18
19
- def __init__ (self , reverse_indices : bool = False , dtype_batch : str = "int64" , dtype_index = None ,
20
- padded_disjoint : bool = False , uses_mask : bool = False ,
19
+ def __init__ (self ,
20
+ reverse_indices : bool = False ,
21
+ dtype_batch : str = "int64" ,
22
+ dtype_index = None ,
23
+ padded_disjoint : bool = False ,
24
+ uses_mask : bool = False ,
21
25
static_batched_node_output_shape : tuple = None ,
22
26
static_batched_edge_output_shape : tuple = None ,
23
27
remove_padded_disjoint_from_batched_output : bool = True ,
@@ -29,20 +33,26 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
29
33
dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
30
34
dtype_index (str): Dtype for index tensor. Default is None.
31
35
padded_disjoint (bool): Whether to keep padding in disjoint representation. Default is False.
36
+ Not used for ragged arguments.
32
37
uses_mask (bool): Whether the padding is marked by a boolean mask or by a length tensor, counting the
33
38
non-padded nodes from index 0. Default is False.
39
+ Not used for ragged arguments.
34
40
static_batched_node_output_shape (tuple): Statical output shape of nodes. Default is None.
41
+ Not used for ragged arguments.
35
42
static_batched_edge_output_shape (tuple): Statical output shape of edges. Default is None.
43
+ Not used for ragged arguments.
36
44
remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output
37
45
in case of padding.
46
+ Not used for ragged arguments.
38
47
"""
39
48
super (_CastBatchedDisjointBase , self ).__init__ (** kwargs )
40
49
self .reverse_indices = reverse_indices
41
50
self .dtype_index = dtype_index
42
51
self .dtype_batch = dtype_batch
43
52
self .uses_mask = uses_mask
44
53
self .padded_disjoint = padded_disjoint
45
- self .supports_jit = padded_disjoint
54
+ if padded_disjoint :
55
+ self .supports_jit = True
46
56
self .static_batched_node_output_shape = static_batched_node_output_shape
47
57
self .static_batched_edge_output_shape = static_batched_edge_output_shape
48
58
self .remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output
@@ -536,31 +546,7 @@ def call(self, inputs: list, **kwargs):
536
546
CastBatchedGraphStateToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
537
547
538
548
539
- class _CastRaggedToDisjointBase (Layer ):
540
-
541
- def __init__ (self , reverse_indices : bool = False , dtype_batch : str = "int64" , dtype_index = None , ** kwargs ):
542
- r"""Initialize layer.
543
-
544
- Args:
545
- reverse_indices (bool): Whether to reverse index order. Default is False.
546
- dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
547
- dtype_index (str): Dtype for index tensor. Default is None.
548
- """
549
- super (_CastRaggedToDisjointBase , self ).__init__ (** kwargs )
550
- self .reverse_indices = reverse_indices
551
- self .dtype_index = dtype_index
552
- self .dtype_batch = dtype_batch
553
- # self.supports_jit = False
554
-
555
- def get_config (self ):
556
- """Get config dictionary for this layer."""
557
- config = super (_CastRaggedToDisjointBase , self ).get_config ()
558
- config .update ({"reverse_indices" : self .reverse_indices , "dtype_batch" : self .dtype_batch ,
559
- "dtype_index" : self .dtype_index })
560
- return config
561
-
562
-
563
- class CastRaggedAttributesToDisjoint (_CastRaggedToDisjointBase ):
549
+ class CastRaggedAttributesToDisjoint (_CastBatchedDisjointBase ):
564
550
565
551
def __init__ (self , ** kwargs ):
566
552
super (CastRaggedAttributesToDisjoint , self ).__init__ (** kwargs )
@@ -598,10 +584,10 @@ def call(self, inputs, **kwargs):
598
584
return decompose_ragged_tensor (inputs , batch_dtype = self .dtype_batch )
599
585
600
586
601
- CastRaggedAttributesToDisjoint .__init__ .__doc__ = _CastRaggedToDisjointBase .__init__ .__doc__
587
+ CastRaggedAttributesToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
602
588
603
589
604
- class CastRaggedIndicesToDisjoint (_CastRaggedToDisjointBase ):
590
+ class CastRaggedIndicesToDisjoint (_CastBatchedDisjointBase ):
605
591
606
592
def __init__ (self , ** kwargs ):
607
593
super (CastRaggedIndicesToDisjoint , self ).__init__ (** kwargs )
@@ -685,10 +671,10 @@ def call(self, inputs, **kwargs):
685
671
return [nodes_flatten , disjoint_indices , graph_id_node , graph_id_edge , node_id , edge_id , node_len , edge_len ]
686
672
687
673
688
- CastRaggedIndicesToDisjoint .__init__ .__doc__ = _CastRaggedToDisjointBase .__init__ .__doc__
674
+ CastRaggedIndicesToDisjoint .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
689
675
690
676
691
- class CastDisjointToRaggedAttributes (_CastRaggedToDisjointBase ):
677
+ class CastDisjointToRaggedAttributes (_CastBatchedDisjointBase ):
692
678
693
679
def __init__ (self , ** kwargs ):
694
680
super (CastDisjointToRaggedAttributes , self ).__init__ (** kwargs )
@@ -713,4 +699,4 @@ def call(self, inputs, **kwargs):
713
699
raise NotImplementedError ()
714
700
715
701
716
- CastDisjointToRaggedAttributes .__init__ .__doc__ = CastDisjointToRaggedAttributes .__init__ .__doc__
702
+ CastDisjointToRaggedAttributes .__init__ .__doc__ = _CastBatchedDisjointBase .__init__ .__doc__
0 commit comments