Skip to content

Commit 0e35c30

Browse files
authored
Arm backend: Avoid subtraction error on boolean tensors in error diff logging (#11249)
This patch add a dtype check before diff calculation to avoid runtime errors since boolean tensors do not support arithmetic operations like subtraction. Reports the mismatch count for boolean tensors. for example: ``` ############### ERROR DIFFERENCE ############### BATCH 0 (BOOLEAN tensor) 0 / 8 elements differ (0.00%) BATCH 1 (BOOLEAN tensor) 8 / 8 elements differ (100.00%) BATCH 2 (BOOLEAN tensor) 8 / 8 elements differ (100.00%) ################################################ ... ... AssertionError: Output 0 (bool tensor) does not match reference output. ``` Signed-off-by: Fang-Ching <Fang-Ching.Chen@arm.com>
1 parent bca2cf5 commit 0e35c30

File tree

2 files changed

+41
-26
lines changed

2 files changed

+41
-26
lines changed

backends/arm/test/tester/analyze_output_utils.py

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,13 @@ def print_error_diffs(
154154
output_str += f"BATCH {n}\n"
155155
result_batch = result[n, :, :, :]
156156
reference_batch = reference[n, :, :, :]
157+
158+
if reference_batch.dtype == torch.bool or result_batch.dtype == torch.bool:
159+
mismatches = (reference_batch != result_batch).sum().item()
160+
total = reference_batch.numel()
161+
output_str += f"(BOOLEAN tensor) {mismatches} / {total} elements differ ({mismatches / total:.2%})\n"
162+
continue
163+
157164
is_close = torch.allclose(result_batch, reference_batch, rtol, atol)
158165
if is_close:
159166
output_str += ".\n"
@@ -180,14 +187,15 @@ def print_error_diffs(
180187
output_str += _print_elements(
181188
result[n, :, :, :], reference[n, :, :, :], C, H, W, rtol, atol
182189
)
183-
184-
reference_range = torch.max(reference) - torch.min(reference)
185-
diff = torch.abs(reference - result).flatten()
186-
diff = diff[diff.nonzero()]
187-
if not len(diff) == 0:
188-
diff_percent = diff / reference_range
189-
output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n"
190-
output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n"
190+
# Only compute numeric error metrics if tensor is not boolean
191+
if reference.dtype != torch.bool and result.dtype != torch.bool:
192+
reference_range = torch.max(reference) - torch.min(reference)
193+
diff = torch.abs(reference - result).flatten()
194+
diff = diff[diff.nonzero()]
195+
if not len(diff) == 0:
196+
diff_percent = diff / reference_range
197+
output_str += "\nMEAN MEDIAN MAX MIN (error as % of reference output range)\n"
198+
output_str += f"{torch.mean(diff_percent):<8.2%} {torch.median(diff_percent):<8.2%} {torch.max(diff_percent):<8.2%} {torch.min(diff_percent):<8.2%}\n"
191199

192200
# Over-engineer separators to match output width
193201
lines = output_str.split("\n")

backends/xnnpack/test/tester/tester.py

Lines changed: 25 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
# Copyright (c) Meta Platforms, Inc. and affiliates.
2-
# Copyright 2024-2025 Arm Limited and/or its affiliates.
32
# All rights reserved.
3+
# Copyright 2024-2025 Arm Limited and/or its affiliates.
44
#
55
# This source code is licensed under the BSD-style license found in the
66
# LICENSE file in the root directory of this source tree.
@@ -714,23 +714,30 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
714714
assert (
715715
ref.shape == model.shape
716716
), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
717-
assert torch.allclose(
718-
model,
719-
ref,
720-
atol=atol,
721-
rtol=rtol,
722-
), (
723-
f"Output {i} does not match reference output.\n"
724-
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
725-
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
726-
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
727-
f"\t-- Model vs. Reference --\n"
728-
f"\t Numel: {model.numel()}, {ref.numel()}\n"
729-
f"\tMedian: {model.median()}, {ref.median()}\n"
730-
f"\t Mean: {model.mean()}, {ref.mean()}\n"
731-
f"\t Max: {model.max()}, {ref.max()}\n"
732-
f"\t Min: {model.min()}, {ref.min()}\n"
733-
)
717+
if model.dtype == torch.bool:
718+
assert torch.equal(model, ref), (
719+
f"Output {i} (bool tensor) does not match reference output.\n"
720+
f"\tShape: {model.shape}\n"
721+
f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n"
722+
)
723+
else:
724+
assert torch.allclose(
725+
model,
726+
ref,
727+
atol=atol,
728+
rtol=rtol,
729+
), (
730+
f"Output {i} does not match reference output.\n"
731+
f"\tGiven atol: {atol}, rtol: {rtol}.\n"
732+
f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n"
733+
f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n"
734+
f"\t-- Model vs. Reference --\n"
735+
f"\t Numel: {model.numel()}, {ref.numel()}\n"
736+
f"\tMedian: {model.median()}, {ref.median()}\n"
737+
f"\t Mean: {model.mean()}, {ref.mean()}\n"
738+
f"\t Max: {model.max()}, {ref.max()}\n"
739+
f"\t Min: {model.min()}, {ref.min()}\n"
740+
)
734741

735742
@staticmethod
736743
def _compare_outputs(

0 commit comments

Comments
 (0)