@@ -33,7 +33,7 @@ class AugmentationSequential(TransformMatrixMinIn, ImageSequential):
33
33
*args: a list of kornia augmentation modules.
34
34
35
35
data_keys: the input type sequential for applying augmentations. Accepts "input", "image", "mask",
36
- "bbox", "bbox_xyxy", "bbox_xywh", "keypoints".
36
+ "bbox", "bbox_xyxy", "bbox_xywh", "keypoints", "class", "label" .
37
37
38
38
same_on_batch: apply the same transformation across the batch. If None, it will not overwrite the function-wise
39
39
settings.
@@ -234,7 +234,7 @@ def __init__(
234
234
self .data_keys = data_keys
235
235
236
236
if self .data_keys :
237
- if not all (in_type in DataKey for in_type in self .data_keys ):
237
+ if any (in_type not in DataKey for in_type in self .data_keys ):
238
238
raise AssertionError (f"`data_keys` must be in { DataKey } . Got { self .data_keys } ." )
239
239
240
240
if self .data_keys [0 ] != DataKey .INPUT :
@@ -446,9 +446,9 @@ def _preproc_dict_data(
446
446
if self .data_keys is not None :
447
447
raise ValueError ("If you are using a dictionary as input, the data_keys should be None." )
448
448
449
- data_keys = self ._read_datakeys_from_dict (tuple (data .keys ()))
450
449
keys = tuple (data .keys ())
451
- data_unpacked = tuple (v for v in data .values ())
450
+ data_keys = self ._read_datakeys_from_dict (keys )
451
+ data_unpacked = tuple (data .values ())
452
452
453
453
return keys , data_keys , data_unpacked
454
454
@@ -467,7 +467,7 @@ def retrieve_key(key: str) -> DataKey:
467
467
468
468
allowed_dk = " | " .join (f"`{ d .name } `" for d in DataKey )
469
469
raise ValueError (
470
- f"You input data dictionary keys should starts with some of datakey values: { allowed_dk } . Got `{ key } `"
470
+ f"Your input data dictionary keys should start with some of datakey values: { allowed_dk } . Got `{ key } `"
471
471
)
472
472
473
473
return [DataKey .get (retrieve_key (k )) for k in keys ]
0 commit comments