Skip to content

Commit f157ef8

Browse files
committed
update slice and tfield
1 parent 9e1322f commit f157ef8

File tree

2 files changed

+40
-25
lines changed

2 files changed

+40
-25
lines changed

MinkowskiEngine/MinkowskiSparseTensor.py

+26-24
Original file line numberDiff line numberDiff line change
@@ -512,15 +512,13 @@ def dense(self, shape=None, min_coordinate=None, contract_stride=True):
512512
tensor_stride = torch.IntTensor(self.tensor_stride)
513513
return dense_F, min_coordinate, tensor_stride
514514

515-
def slice(self, X, slicing_mode=0):
515+
def slice(self, X):
516516
r"""
517517
518518
Args:
519519
:attr:`X` (:attr:`MinkowskiEngine.SparseTensor`): a sparse tensor
520520
that discretized the original input.
521521
522-
:attr:`slicing_mode`: For future updates.
523-
524522
Returns:
525523
:attr:`tensor_field` (:attr:`MinkowskiEngine.TensorField`): the
526524
resulting tensor field contains features on the continuous
@@ -530,7 +528,7 @@ def slice(self, X, slicing_mode=0):
530528
531529
>>> # coords, feats from a data loader
532530
>>> print(len(coords)) # 227742
533-
>>> tfield = ME.TensorField(coords=coords, feats=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
531+
>>> tfield = ME.TensorField(coordinates=coords, features=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
534532
>>> print(len(tfield)) # 227742
535533
>>> sinput = tfield.sparse() # 161890 quantization results in fewer voxels
536534
>>> soutput = MinkUNet(sinput)
@@ -545,9 +543,7 @@ def slice(self, X, slicing_mode=0):
545543
SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
546544
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
547545
], "slice only available for sparse tensors with quantization RANDOM_SUBSAMPLE or UNWEIGHTED_AVERAGE"
548-
assert (
549-
X.coordinate_map_key == self.coordinate_map_key
550-
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
546+
551547
from MinkowskiTensorField import TensorField
552548

553549
if isinstance(X, TensorField):
@@ -557,23 +553,28 @@ def slice(self, X, slicing_mode=0):
557553
coordinate_manager=X.coordinate_manager,
558554
quantization_mode=X.quantization_mode,
559555
)
560-
else:
556+
elif isinstance(X, SparseTensor):
557+
assert (
558+
X.coordinate_map_key == self.coordinate_map_key
559+
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
561560
return TensorField(
562561
self.F[X.inverse_mapping],
563562
coordinates=self.C[X.inverse_mapping],
564-
coordinate_manager=X.coordinate_manager,
565-
quantization_mode=X.quantization_mode,
563+
coordinate_manager=self.coordinate_manager,
564+
quantization_mode=self.quantization_mode,
565+
)
566+
else:
567+
raise ValueError(
568+
"Invalid input. The input must be an instance of TensorField or SparseTensor."
566569
)
567570

568-
def cat_slice(self, X, slicing_mode=0):
571+
def cat_slice(self, X):
569572
r"""
570573
571574
Args:
572575
:attr:`X` (:attr:`MinkowskiEngine.SparseTensor`): a sparse tensor
573576
that discretized the original input.
574577
575-
:attr:`slicing_mode`: For future updates.
576-
577578
Returns:
578579
:attr:`tensor_field` (:attr:`MinkowskiEngine.TensorField`): the
579580
resulting tensor field contains the concatenation of features on the
@@ -584,7 +585,7 @@ def cat_slice(self, X, slicing_mode=0):
584585
585586
>>> # coords, feats from a data loader
586587
>>> print(len(coords)) # 227742
587-
>>> sinput = ME.SparseTensor(coords=coords, feats=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
588+
>>> sinput = ME.SparseTensor(coordinates=coords, features=feats, quantization_mode=SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE)
588589
>>> print(len(sinput)) # 161890 quantization results in fewer voxels
589590
>>> soutput = network(sinput)
590591
>>> print(len(soutput)) # 161890 Output with the same resolution
@@ -596,29 +597,30 @@ def cat_slice(self, X, slicing_mode=0):
596597
SparseTensorQuantizationMode.RANDOM_SUBSAMPLE,
597598
SparseTensorQuantizationMode.UNWEIGHTED_AVERAGE,
598599
], "slice only available for sparse tensors with quantization RANDOM_SUBSAMPLE or UNWEIGHTED_AVERAGE"
599-
assert (
600-
X.coordinate_map_key == self.coordinate_map_key
601-
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
600+
602601
from MinkowskiTensorField import TensorField
603602

604603
features = torch.cat((self.F[X.inverse_mapping], X.F), dim=1)
605604
if isinstance(X, TensorField):
606605
return TensorField(
607606
features,
608-
coordinate_map_key=X.coordinate_map_key,
609607
coordinate_field_map_key=X.coordinate_field_map_key,
610608
coordinate_manager=X.coordinate_manager,
611-
inverse_mapping=X.inverse_mapping,
612609
quantization_mode=X.quantization_mode,
613610
)
614-
else:
611+
elif isinstance(X, SparseTensor):
612+
assert (
613+
X.coordinate_map_key == self.coordinate_map_key
614+
), "Slice can only be applied on the same coordinates (coordinate_map_key)"
615615
return TensorField(
616616
features,
617617
coordinates=self.C[X.inverse_mapping],
618-
coordinate_map_key=X.coordinate_map_key,
619-
coordinate_manager=X.coordinate_manager,
620-
inverse_mapping=X.inverse_mapping,
621-
quantization_mode=X.quantization_mode,
618+
coordinate_manager=self.coordinate_manager,
619+
quantization_mode=self.quantization_mode,
620+
)
621+
else:
622+
raise ValueError(
623+
"Invalid input. The input must be an instance of TensorField or SparseTensor."
622624
)
623625

624626
def features_at_coordinates(self, query_coordinates: torch.Tensor):

MinkowskiEngine/MinkowskiTensorField.py

+14-1
Original file line numberDiff line numberDiff line change
@@ -234,13 +234,25 @@ def sparse(self, quantization_mode=None):
234234
if quantization_mode is None:
235235
quantization_mode = self.quantization_mode
236236

237-
return SparseTensor(
237+
sparse_tensor = SparseTensor(
238238
self._F,
239239
coordinates=self.coordinates,
240240
quantization_mode=quantization_mode,
241241
coordinate_manager=self.coordinate_manager,
242242
)
243243

244+
# Save the inverse mapping
245+
self._inverse_mapping = sparse_tensor.inverse_mapping
246+
return sparse_tensor
247+
248+
@property
249+
def inverse_mapping(self):
250+
if not hasattr(self, "_inverse_mapping"):
251+
raise ValueError(
252+
"Did you run SparseTensor.slice? The slice must take a tensor field that returned TensorField.space."
253+
)
254+
return self._inverse_mapping
255+
244256
def __repr__(self):
245257
return (
246258
self.__class__.__name__
@@ -269,5 +281,6 @@ def __repr__(self):
269281
"coordinate_field_map_key",
270282
"_manager",
271283
"quantization_mode",
284+
"_inverse_mapping",
272285
"_batch_rows",
273286
)

0 commit comments

Comments
 (0)