Skip to content
This repository was archived by the owner on Jan 3, 2023. It is now read-only.

Sarkars/unify pooling translate functions #420

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
91 changes: 18 additions & 73 deletions src/ngraph_builder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -607,9 +607,16 @@ static Status TranslateArgMinMaxOp(
return Status::OK();
}

static Status TranslateAvgPoolOp(
template <typename T>
static Status TranslatePoolOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
bool is_avgpool = std::is_same<T, ng::op::AvgPool>::value;
bool is_maxpool = std::is_same<T, ng::op::MaxPool>::value;
if (!(is_avgpool || is_maxpool)) {
return errors::InvalidArgument(
"Expected pooling type node to be average or maxpool type");
}
shared_ptr<ng::Node> ng_input;
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input));

Expand All @@ -624,7 +631,7 @@ static Status TranslateAvgPoolOp(

if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
return errors::InvalidArgument(
"AvgPool data format is neither NHWC nor NCHW");
"Pooling data format is neither NHWC nor NCHW");
}

bool is_nhwc = (tf_data_format == "NHWC");
Expand All @@ -646,7 +653,7 @@ static Status TranslateAvgPoolOp(
NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

// TODO: change this once nGraph supports negative padding
// (CoordinateDiff) for AvgPool
// (CoordinateDiff) for Pooling
// ng::CoordinateDiff ng_padding_below{0,0};
// ng::CoordinateDiff ng_padding_above{0,0};
ng::Shape ng_padding_below{0, 0};
Expand All @@ -655,15 +662,15 @@ static Status TranslateAvgPoolOp(
Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
ng_strides, ng_padding_below, ng_padding_above);

std::shared_ptr<ng::Node> ng_avgpool = ConstructNgNode<ng::op::AvgPool>(
op->name(), ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above, false);
std::shared_ptr<ng::Node> ng_pool =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Question:
Does TF AvgPool considers padding in average calculation? Just want to confirm here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Avg pool has this extra param whose default value is false:
https://github.com/NervanaSystems/ngraph/blob/066037c25901478ac79ead90aa6da721ab27c0c3/src/ngraph/op/avg_pool.hpp#L51

Earlier we were explicitly passing the false value, but even without it, the default value is going to be false

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool! Sounds good!

ConstructNgNode<T>(op->name(), ng_input, ng_kernel_shape, ng_strides,
ng_padding_below, ng_padding_above);

BatchToTensorflow(is_nhwc, ng_avgpool);
NGRAPH_VLOG(3) << "avgpool outshape: {" << ng::join(ng_avgpool->get_shape())
BatchToTensorflow(is_nhwc, ng_pool);
NGRAPH_VLOG(3) << "pooling outshape: {" << ng::join(ng_pool->get_shape())
<< "}";

SaveNgOp(ng_op_map, op->name(), ng_avgpool);
SaveNgOp(ng_op_map, op->name(), ng_pool);
return Status::OK();
}

Expand Down Expand Up @@ -2191,68 +2198,6 @@ static Status TranslateMatMulOp(
return Status::OK();
}

static Status TranslateMaxPoolOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
shared_ptr<ng::Node> ng_input;
TF_RETURN_IF_ERROR(GetInputNodes(ng_op_map, op, &ng_input));

std::vector<int32> tf_strides;
std::vector<int32> tf_ksize;
std::string tf_padding_type;
std::string tf_data_format;
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "strides", &tf_strides));
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "ksize", &tf_ksize));
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "padding", &tf_padding_type));
TF_RETURN_IF_ERROR(GetNodeAttr(op->attrs(), "data_format", &tf_data_format));

if (tf_data_format != "NHWC" && tf_data_format != "NCHW") {
return errors::InvalidArgument(
"MaxPool data format is neither NHWC nor NCHW");
}

bool is_nhwc = (tf_data_format == "NHWC");

NGRAPH_VLOG(3) << ng::join(tf_strides);
NGRAPH_VLOG(3) << ng::join(tf_ksize);
NGRAPH_VLOG(3) << tf_padding_type;
NGRAPH_VLOG(3) << tf_data_format;

ng::Strides ng_strides(2);
ng::Shape ng_image_shape(2);
ng::Shape ng_kernel_shape(2);

BatchedOpParamToNGraph(is_nhwc, tf_strides, ng_strides);
BatchedOpParamToNGraph(is_nhwc, ng_input->get_shape(), ng_image_shape);
BatchedOpParamToNGraph(is_nhwc, tf_ksize, ng_kernel_shape);
BatchToNGraph(is_nhwc, ng_input);
NGRAPH_VLOG(3) << "ng_strides: " << ng::join(ng_strides);
NGRAPH_VLOG(3) << "ng_image_shape: " << ng::join(ng_image_shape);
NGRAPH_VLOG(3) << "ng_kernel_shape: " << ng::join(ng_kernel_shape);

// TODO: change this once nGraph supports negative padding
// (CoordinateDiff) for MaxPool
// ng::CoordinateDiff ng_padding_below{0,0};
// ng::CoordinateDiff ng_padding_above{0,0};
ng::Shape ng_padding_below{0, 0};
ng::Shape ng_padding_above{0, 0};

Builder::MakePadding(tf_padding_type, ng_image_shape, ng_kernel_shape,
ng_strides, ng_padding_below, ng_padding_above);

std::shared_ptr<ng::Node> ng_maxpool = ConstructNgNode<ng::op::MaxPool>(
op->name(), ng_input, ng_kernel_shape, ng_strides, ng_padding_below,
ng_padding_above);

BatchToTensorflow(is_nhwc, ng_maxpool);

NGRAPH_VLOG(3) << "maxpool outshape: {" << ng::join(ng_maxpool->get_shape())
<< "}";

SaveNgOp(ng_op_map, op->name(), ng_maxpool);
return Status::OK();
}

static Status TranslateMaxPool3DOp(
const Node* op, const std::vector<const Tensor*>& static_input_map,
Builder::OpMap& ng_op_map) {
Expand Down Expand Up @@ -4395,7 +4340,7 @@ const static std::map<
{"All", TranslateDirectReduceOp<ng::op::All>},
{"ArgMax", TranslateArgMinMaxOp<ng::op::ArgMax>},
{"ArgMin", TranslateArgMinMaxOp<ng::op::ArgMin>},
{"AvgPool", TranslateAvgPoolOp},
{"AvgPool", TranslatePoolOp<ng::op::AvgPool>},
{"AvgPoolGrad", TranslateAvgPoolGradOp},
{"BatchMatMul", TranslateBatchMatMulOp},
{"BiasAdd", TranslateBiasAddOp},
Expand Down Expand Up @@ -4436,7 +4381,7 @@ const static std::map<
{"MatMul", TranslateMatMulOp},
{"Max", TranslateDirectReduceOp<ng::op::Max>},
{"Maximum", TranslateBinaryOp<ngraph::op::Maximum>},
{"MaxPool", TranslateMaxPoolOp},
{"MaxPool", TranslatePoolOp<ng::op::MaxPool>},
{"MaxPool3D", TranslateMaxPool3DOp},
{"MaxPoolGrad", TranslateMaxPoolGradOp},
{"NonMaxSuppressionV4", TranslateNonMaxSuppressionV4Op},
Expand Down