@@ -83,28 +83,70 @@ def cat(*sparse_tensors):
83
83
>>> sout2 = ME.cat(sin, sin2, sout) # Can concatenate multiple sparse tensors
84
84
85
85
"""
86
- for s in sparse_tensors :
87
- assert isinstance (s , SparseTensor ), "Inputs must be sparse tensors."
88
- coordinate_manager = sparse_tensors [0 ].coordinate_manager
89
- coordinate_map_key = sparse_tensors [0 ].coordinate_map_key
90
- for s in sparse_tensors :
91
- assert (
92
- coordinate_manager == s .coordinate_manager
93
- ), COORDINATE_MANAGER_DIFFERENT_ERROR
94
- assert coordinate_map_key == s .coordinate_map_key , (
95
- COORDINATE_KEY_DIFFERENT_ERROR
96
- + str (coordinate_map_key )
97
- + " != "
98
- + str (s .coordinate_map_key )
86
+ assert (
87
+ len (sparse_tensors ) > 1
88
+ ), f"Invalid number of inputs. The input must be at least two len(sparse_tensors) > 1"
89
+
90
+ if isinstance (sparse_tensors [0 ], SparseTensor ):
91
+ device = sparse_tensors [0 ].device
92
+ coordinate_manager = sparse_tensors [0 ].coordinate_manager
93
+ coordinate_map_key = sparse_tensors [0 ].coordinate_map_key
94
+ for s in sparse_tensors :
95
+ assert isinstance (
96
+ s , SparseTensor
97
+ ), "Inputs must be either SparseTensors or TensorFields."
98
+ assert (
99
+ device == s .device
100
+ ), f"Device must be the same. { device } != { s .device } "
101
+ assert (
102
+ coordinate_manager == s .coordinate_manager
103
+ ), COORDINATE_MANAGER_DIFFERENT_ERROR
104
+ assert coordinate_map_key == s .coordinate_map_key , (
105
+ COORDINATE_KEY_DIFFERENT_ERROR
106
+ + str (coordinate_map_key )
107
+ + " != "
108
+ + str (s .coordinate_map_key )
109
+ )
110
+ tens = []
111
+ for s in sparse_tensors :
112
+ tens .append (s .F )
113
+ return SparseTensor (
114
+ torch .cat (tens , dim = 1 ),
115
+ coordinate_map_key = coordinate_map_key ,
116
+ coordinate_manager = coordinate_manager ,
117
+ )
118
+ elif isinstance (sparse_tensors [0 ], TensorField ):
119
+ device = sparse_tensors [0 ].device
120
+ coordinate_manager = sparse_tensors [0 ].coordinate_manager
121
+ coordinate_field_map_key = sparse_tensors [0 ].coordinate_field_map_key
122
+ for s in sparse_tensors :
123
+ assert isinstance (
124
+ s , TensorField
125
+ ), "Inputs must be either SparseTensors or TensorFields."
126
+ assert (
127
+ device == s .device
128
+ ), f"Device must be the same. { device } != { s .device } "
129
+ assert (
130
+ coordinate_manager == s .coordinate_manager
131
+ ), COORDINATE_MANAGER_DIFFERENT_ERROR
132
+ assert coordinate_field_map_key == s .coordinate_field_map_key , (
133
+ COORDINATE_KEY_DIFFERENT_ERROR
134
+ + str (coordinate_field_map_key )
135
+ + " != "
136
+ + str (s .coordinate_field_map_key )
137
+ )
138
+ tens = []
139
+ for s in sparse_tensors :
140
+ tens .append (s .F )
141
+ return TensorField (
142
+ torch .cat (tens , dim = 1 ),
143
+ coordinate_field_map_key = coordinate_field_map_key ,
144
+ coordinate_manager = coordinate_manager ,
145
+ )
146
+ else :
147
+ raise ValueError (
148
+ "Invalid data type. The input must be either a list of sparse tensors or a list of tensor fields."
99
149
)
100
- tens = []
101
- for s in sparse_tensors :
102
- tens .append (s .F )
103
- return SparseTensor (
104
- torch .cat (tens , dim = 1 ),
105
- coordinate_map_key = coordinate_map_key ,
106
- coordinate_manager = coordinate_manager ,
107
- )
108
150
109
151
110
152
def dense_coordinates (shape : Union [list , torch .Size ]):
@@ -131,7 +173,7 @@ def dense_coordinates(shape: Union[list, torch.Size]):
131
173
for s in np .meshgrid (
132
174
np .linspace (0 , B - 1 , B ),
133
175
* (np .linspace (0 , s - 1 , s ) for s in size [2 :]),
134
- indexing = "ij"
176
+ indexing = "ij" ,
135
177
)
136
178
],
137
179
1 ,
0 commit comments