|
1 | 1 | # Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
| -# Copyright 2024-2025 Arm Limited and/or its affiliates. |
3 | 2 | # All rights reserved.
|
| 3 | +# Copyright 2024-2025 Arm Limited and/or its affiliates. |
4 | 4 | #
|
5 | 5 | # This source code is licensed under the BSD-style license found in the
|
6 | 6 | # LICENSE file in the root directory of this source tree.
|
@@ -714,23 +714,30 @@ def _assert_outputs_equal(model_output, ref_output, atol=1e-03, rtol=1e-03):
|
714 | 714 | assert (
|
715 | 715 | ref.shape == model.shape
|
716 | 716 | ), f"Output {i} shape {model.shape} does not match reference output shape {ref.shape}"
|
717 |
| - assert torch.allclose( |
718 |
| - model, |
719 |
| - ref, |
720 |
| - atol=atol, |
721 |
| - rtol=rtol, |
722 |
| - ), ( |
723 |
| - f"Output {i} does not match reference output.\n" |
724 |
| - f"\tGiven atol: {atol}, rtol: {rtol}.\n" |
725 |
| - f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" |
726 |
| - f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" |
727 |
| - f"\t-- Model vs. Reference --\n" |
728 |
| - f"\t Numel: {model.numel()}, {ref.numel()}\n" |
729 |
| - f"\tMedian: {model.median()}, {ref.median()}\n" |
730 |
| - f"\t Mean: {model.mean()}, {ref.mean()}\n" |
731 |
| - f"\t Max: {model.max()}, {ref.max()}\n" |
732 |
| - f"\t Min: {model.min()}, {ref.min()}\n" |
733 |
| - ) |
| 717 | + if model.dtype == torch.bool: |
| 718 | + assert torch.equal(model, ref), ( |
| 719 | + f"Output {i} (bool tensor) does not match reference output.\n" |
| 720 | + f"\tShape: {model.shape}\n" |
| 721 | + f"\tMismatched count: {(model != ref).sum().item()} / {model.numel()}\n" |
| 722 | + ) |
| 723 | + else: |
| 724 | + assert torch.allclose( |
| 725 | + model, |
| 726 | + ref, |
| 727 | + atol=atol, |
| 728 | + rtol=rtol, |
| 729 | + ), ( |
| 730 | + f"Output {i} does not match reference output.\n" |
| 731 | + f"\tGiven atol: {atol}, rtol: {rtol}.\n" |
| 732 | + f"\tOutput tensor shape: {model.shape}, dtype: {model.dtype}\n" |
| 733 | + f"\tDifference: max: {torch.max(model-ref)}, abs: {torch.max(torch.abs(model-ref))}, mean abs error: {torch.mean(torch.abs(model-ref))}.\n" |
| 734 | + f"\t-- Model vs. Reference --\n" |
| 735 | + f"\t Numel: {model.numel()}, {ref.numel()}\n" |
| 736 | + f"\tMedian: {model.median()}, {ref.median()}\n" |
| 737 | + f"\t Mean: {model.mean()}, {ref.mean()}\n" |
| 738 | + f"\t Max: {model.max()}, {ref.max()}\n" |
| 739 | + f"\t Min: {model.min()}, {ref.min()}\n" |
| 740 | + ) |
734 | 741 |
|
735 | 742 | @staticmethod
|
736 | 743 | def _compare_outputs(
|
|
0 commit comments