Skip to content

Commit ab2a914

Browse files
committed
Simplify casting base layer. Inputs are ignored for ragged inputs.
1 parent 123194a commit ab2a914

File tree

1 file changed

+19
-33
lines changed

1 file changed

+19
-33
lines changed

kgcnn/layers/casting.py

Lines changed: 19 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,12 @@ def _cat_one(t):
1616

1717
class _CastBatchedDisjointBase(Layer):
1818

19-
def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None,
20-
padded_disjoint: bool = False, uses_mask: bool = False,
19+
def __init__(self,
20+
reverse_indices: bool = False,
21+
dtype_batch: str = "int64",
22+
dtype_index=None,
23+
padded_disjoint: bool = False,
24+
uses_mask: bool = False,
2125
static_batched_node_output_shape: tuple = None,
2226
static_batched_edge_output_shape: tuple = None,
2327
remove_padded_disjoint_from_batched_output: bool = True,
@@ -29,20 +33,26 @@ def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dt
2933
dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
3034
dtype_index (str): Dtype for index tensor. Default is None.
3135
padded_disjoint (bool): Whether to keep padding in disjoint representation. Default is False.
36+
Not used for ragged arguments.
3237
uses_mask (bool): Whether the padding is marked by a boolean mask or by a length tensor, counting the
3338
non-padded nodes from index 0. Default is False.
39+
Not used for ragged arguments.
3440
static_batched_node_output_shape (tuple): Statical output shape of nodes. Default is None.
41+
Not used for ragged arguments.
3542
static_batched_edge_output_shape (tuple): Statical output shape of edges. Default is None.
43+
Not used for ragged arguments.
3644
remove_padded_disjoint_from_batched_output (bool): Whether to remove the first element on batched output
3745
in case of padding.
46+
Not used for ragged arguments.
3847
"""
3948
super(_CastBatchedDisjointBase, self).__init__(**kwargs)
4049
self.reverse_indices = reverse_indices
4150
self.dtype_index = dtype_index
4251
self.dtype_batch = dtype_batch
4352
self.uses_mask = uses_mask
4453
self.padded_disjoint = padded_disjoint
45-
self.supports_jit = padded_disjoint
54+
if padded_disjoint:
55+
self.supports_jit = True
4656
self.static_batched_node_output_shape = static_batched_node_output_shape
4757
self.static_batched_edge_output_shape = static_batched_edge_output_shape
4858
self.remove_padded_disjoint_from_batched_output = remove_padded_disjoint_from_batched_output
@@ -536,31 +546,7 @@ def call(self, inputs: list, **kwargs):
536546
CastBatchedGraphStateToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__
537547

538548

539-
class _CastRaggedToDisjointBase(Layer):
540-
541-
def __init__(self, reverse_indices: bool = False, dtype_batch: str = "int64", dtype_index=None, **kwargs):
542-
r"""Initialize layer.
543-
544-
Args:
545-
reverse_indices (bool): Whether to reverse index order. Default is False.
546-
dtype_batch (str): Dtype for batch ID tensor. Default is 'int64'.
547-
dtype_index (str): Dtype for index tensor. Default is None.
548-
"""
549-
super(_CastRaggedToDisjointBase, self).__init__(**kwargs)
550-
self.reverse_indices = reverse_indices
551-
self.dtype_index = dtype_index
552-
self.dtype_batch = dtype_batch
553-
# self.supports_jit = False
554-
555-
def get_config(self):
556-
"""Get config dictionary for this layer."""
557-
config = super(_CastRaggedToDisjointBase, self).get_config()
558-
config.update({"reverse_indices": self.reverse_indices, "dtype_batch": self.dtype_batch,
559-
"dtype_index": self.dtype_index})
560-
return config
561-
562-
563-
class CastRaggedAttributesToDisjoint(_CastRaggedToDisjointBase):
549+
class CastRaggedAttributesToDisjoint(_CastBatchedDisjointBase):
564550

565551
def __init__(self, **kwargs):
566552
super(CastRaggedAttributesToDisjoint, self).__init__(**kwargs)
@@ -598,10 +584,10 @@ def call(self, inputs, **kwargs):
598584
return decompose_ragged_tensor(inputs, batch_dtype=self.dtype_batch)
599585

600586

601-
CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__
587+
CastRaggedAttributesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__
602588

603589

604-
class CastRaggedIndicesToDisjoint(_CastRaggedToDisjointBase):
590+
class CastRaggedIndicesToDisjoint(_CastBatchedDisjointBase):
605591

606592
def __init__(self, **kwargs):
607593
super(CastRaggedIndicesToDisjoint, self).__init__(**kwargs)
@@ -685,10 +671,10 @@ def call(self, inputs, **kwargs):
685671
return [nodes_flatten, disjoint_indices, graph_id_node, graph_id_edge, node_id, edge_id, node_len, edge_len]
686672

687673

688-
CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastRaggedToDisjointBase.__init__.__doc__
674+
CastRaggedIndicesToDisjoint.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__
689675

690676

691-
class CastDisjointToRaggedAttributes(_CastRaggedToDisjointBase):
677+
class CastDisjointToRaggedAttributes(_CastBatchedDisjointBase):
692678

693679
def __init__(self, **kwargs):
694680
super(CastDisjointToRaggedAttributes, self).__init__(**kwargs)
@@ -713,4 +699,4 @@ def call(self, inputs, **kwargs):
713699
raise NotImplementedError()
714700

715701

716-
CastDisjointToRaggedAttributes.__init__.__doc__ = CastDisjointToRaggedAttributes.__init__.__doc__
702+
CastDisjointToRaggedAttributes.__init__.__doc__ = _CastBatchedDisjointBase.__init__.__doc__

0 commit comments

Comments
 (0)