Skip to content

Commit f8ae4ad

Browse files
committed
cat with tfield support
1 parent f157ef8 commit f8ae4ad

File tree

1 file changed

+64
-22
lines changed

1 file changed

+64
-22
lines changed

MinkowskiEngine/MinkowskiOps.py

+64-22
Original file line numberDiff line numberDiff line change
@@ -83,28 +83,70 @@ def cat(*sparse_tensors):
8383
>>> sout2 = ME.cat(sin, sin2, sout) # Can concatenate multiple sparse tensors
8484
8585
"""
86-
for s in sparse_tensors:
87-
assert isinstance(s, SparseTensor), "Inputs must be sparse tensors."
88-
coordinate_manager = sparse_tensors[0].coordinate_manager
89-
coordinate_map_key = sparse_tensors[0].coordinate_map_key
90-
for s in sparse_tensors:
91-
assert (
92-
coordinate_manager == s.coordinate_manager
93-
), COORDINATE_MANAGER_DIFFERENT_ERROR
94-
assert coordinate_map_key == s.coordinate_map_key, (
95-
COORDINATE_KEY_DIFFERENT_ERROR
96-
+ str(coordinate_map_key)
97-
+ " != "
98-
+ str(s.coordinate_map_key)
86+
assert (
87+
len(sparse_tensors) > 1
88+
), f"Invalid number of inputs. The input must be at least two len(sparse_tensors) > 1"
89+
90+
if isinstance(sparse_tensors[0], SparseTensor):
91+
device = sparse_tensors[0].device
92+
coordinate_manager = sparse_tensors[0].coordinate_manager
93+
coordinate_map_key = sparse_tensors[0].coordinate_map_key
94+
for s in sparse_tensors:
95+
assert isinstance(
96+
s, SparseTensor
97+
), "Inputs must be either SparseTensors or TensorFields."
98+
assert (
99+
device == s.device
100+
), f"Device must be the same. {device} != {s.device}"
101+
assert (
102+
coordinate_manager == s.coordinate_manager
103+
), COORDINATE_MANAGER_DIFFERENT_ERROR
104+
assert coordinate_map_key == s.coordinate_map_key, (
105+
COORDINATE_KEY_DIFFERENT_ERROR
106+
+ str(coordinate_map_key)
107+
+ " != "
108+
+ str(s.coordinate_map_key)
109+
)
110+
tens = []
111+
for s in sparse_tensors:
112+
tens.append(s.F)
113+
return SparseTensor(
114+
torch.cat(tens, dim=1),
115+
coordinate_map_key=coordinate_map_key,
116+
coordinate_manager=coordinate_manager,
117+
)
118+
elif isinstance(sparse_tensors[0], TensorField):
119+
device = sparse_tensors[0].device
120+
coordinate_manager = sparse_tensors[0].coordinate_manager
121+
coordinate_field_map_key = sparse_tensors[0].coordinate_field_map_key
122+
for s in sparse_tensors:
123+
assert isinstance(
124+
s, TensorField
125+
), "Inputs must be either SparseTensors or TensorFields."
126+
assert (
127+
device == s.device
128+
), f"Device must be the same. {device} != {s.device}"
129+
assert (
130+
coordinate_manager == s.coordinate_manager
131+
), COORDINATE_MANAGER_DIFFERENT_ERROR
132+
assert coordinate_field_map_key == s.coordinate_field_map_key, (
133+
COORDINATE_KEY_DIFFERENT_ERROR
134+
+ str(coordinate_field_map_key)
135+
+ " != "
136+
+ str(s.coordinate_field_map_key)
137+
)
138+
tens = []
139+
for s in sparse_tensors:
140+
tens.append(s.F)
141+
return TensorField(
142+
torch.cat(tens, dim=1),
143+
coordinate_field_map_key=coordinate_field_map_key,
144+
coordinate_manager=coordinate_manager,
145+
)
146+
else:
147+
raise ValueError(
148+
"Invalid data type. The input must be either a list of sparse tensors or a list of tensor fields."
99149
)
100-
tens = []
101-
for s in sparse_tensors:
102-
tens.append(s.F)
103-
return SparseTensor(
104-
torch.cat(tens, dim=1),
105-
coordinate_map_key=coordinate_map_key,
106-
coordinate_manager=coordinate_manager,
107-
)
108150

109151

110152
def dense_coordinates(shape: Union[list, torch.Size]):
@@ -131,7 +173,7 @@ def dense_coordinates(shape: Union[list, torch.Size]):
131173
for s in np.meshgrid(
132174
np.linspace(0, B - 1, B),
133175
*(np.linspace(0, s - 1, s) for s in size[2:]),
134-
indexing="ij"
176+
indexing="ij",
135177
)
136178
],
137179
1,

0 commit comments

Comments
 (0)