Skip to content

Commit 0825fd2

Browse files
author
Weiming Zhao
committed
[Bug] Fix axis attribute for reduction instrs
When the axis op is constant, it should update the axis attribute. (cherry picked from commit 4fe94fa)
1 parent bfa366c commit 0825fd2

21 files changed

+28
-22
lines changed

lib/transforms/type_legalizer.cc

+8-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <algorithm>
2121
#include <cmath>
2222
#include <limits>
23+
#include <type_traits>
2324
#include <unordered_set>
2425

2526
#include "halo/api/halo_data.h"
@@ -475,8 +476,8 @@ static void RunOnInstruction(Conv2DTransposeInst* inst) {
475476
}
476477
}
477478

478-
static void RunOnCommonReductionInstruction(Instruction* inst,
479-
std::vector<int32_t> axis,
479+
template <typename T>
480+
static void RunOnCommonReductionInstruction(T* inst, std::vector<int32_t> axis,
480481
bool keep_dims) {
481482
const auto& input_type = inst->GetOperand(0).GetType();
482483
if (!input_type.IsValid()) {
@@ -528,6 +529,11 @@ static void RunOnCommonReductionInstruction(Instruction* inst,
528529
dt = DataType::INT32;
529530
}
530531

532+
constexpr bool is_arg_inst =
533+
std::is_same<T, ArgmaxInst>() || std::is_same<T, ArgminInst>();
534+
if constexpr (!is_arg_inst) { // NOLINT
535+
inst->SetAxis(axis);
536+
}
531537
inst->GetResultsTypes()[0] = halo::Type{dt, ret_shape};
532538
}
533539

tests/unittests/lit_cases/test_dnnl/test_reduce_mean_negative_axes_keepdims_example_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_example_dnnl.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_dnnl/test_reduce_mean_negative_axes_keepdims_random_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_random_dnnl.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_dnnl/test_reduce_min_negative_axes_keepdims_example_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_example_dnnl.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_dnnl/test_reduce_min_negative_axes_keepdims_random_dnnl.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_dnnl.exe 0.0001 0 dnnl %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_random_dnnl.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_popart/test_reduce_mean_negative_axes_keepdims_example_popart.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_example_popart.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_popart/test_reduce_mean_negative_axes_keepdims_random_popart.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_random_popart.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_popart/test_reduce_min_negative_axes_keepdims_example_popart.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_example_popart.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_popart/test_reduce_min_negative_axes_keepdims_random_popart.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_popart.exe 0.0001 0 popart %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_random_popart.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_l1_negative_axes_keep_dims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l1_negative_axes_keep_dims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_l1_negative_axes_keep_dims_example_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_l1_negative_axes_keep_dims_random_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l1_negative_axes_keep_dims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_l1_negative_axes_keep_dims_random_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_l2_negative_axes_keep_dims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l2_negative_axes_keep_dims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_l2_negative_axes_keep_dims_example_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_l2_negative_axes_keep_dims_random_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_l2_negative_axes_keep_dims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_l2_negative_axes_keep_dims_random_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_log_sum_exp_negative_axes_keepdims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_exp_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_log_sum_exp_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_log_sum_exp_negative_axes_keepdims_random_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_exp_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_log_sum_exp_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_log_sum_negative_axes_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_log_sum_negative_axes | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_log_sum_negative_axes_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_mean_negative_axes_keepdims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_mean_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_mean_negative_axes_keepdims_random_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_mean_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_mean_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_min_negative_axes_keepdims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_min_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_min_negative_axes_keepdims_random_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_min_negative_axes_keepdims_random | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_min_negative_axes_keepdims_random_tensorrt.cc.tmp.main.cc.in"

tests/unittests/lit_cases/test_tensorrt/test_reduce_sum_square_negative_axes_keepdims_example_tensorrt.cc

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@
2525
// RUN: %t_tensorrt.exe 0.0001 0 tensorrt %data_path/test_reduce_sum_square_negative_axes_keepdims_example | FileCheck %s
2626
// CHECK: Result Pass
2727
// clang-format on
28-
// XFAIL: *
28+
2929
#include "test_reduce_sum_square_negative_axes_keepdims_example_tensorrt.cc.tmp.main.cc.in"

0 commit comments

Comments
 (0)