masahi commented on a change in pull request #4478: [TOPI] implement pool3d op URL: https://github.com/apache/incubator-tvm/pull/4478#discussion_r356737854
########## File path: src/relay/op/nn/pooling.cc ########## @@ -720,5 +720,238 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad") .set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, topi::nn::kAvgPool>); +// relay.nn.max_pool3d & relay.nn.avg_pool3d +TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs); +TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs); + +template <typename AttrType> +bool Pool3DRel(const Array<Type>& types, + int num_inputs, + const Attrs& attrs, + const TypeReporter& reporter) { + CHECK_EQ(types.size(), 2); + const auto* data = types[0].as<TensorTypeNode>(); + + if (data == nullptr) return false; + + const auto dshape = data->shape; + CHECK_GE(dshape.size(), 3U) + << "Pool3D only support input >= 3-D: input must have depth, height and width"; + const auto param = attrs.as<AttrType>(); + CHECK(param != nullptr); + + Layout layout(param->layout); + CHECK(layout.Contains(LayoutAxis::Get('D')) && layout.Contains(LayoutAxis::Get('H')) && + layout.Contains(LayoutAxis::Get('W')) && !layout.Contains(LayoutAxis::Get('d')) && + !layout.Contains(LayoutAxis::Get('h')) && !layout.Contains(LayoutAxis::Get('w'))) + << "Invalid layout " << layout + << ". Pool3D layout must have D, H and W, which cannot be split"; + + const auto didx = layout.IndexOf(LayoutAxis::Get('D')); + const auto hidx = layout.IndexOf(LayoutAxis::Get('H')); + const auto widx = layout.IndexOf(LayoutAxis::Get('W')); + + IndexExpr pad_d, pad_h, pad_w; + if (param->padding.size() == 1) { + pad_d = param->padding[0] * 2; + pad_h = param->padding[0] * 2; + pad_w = param->padding[0] * 2; + } else if (param->padding.size() == 3) { + // (front, top, left) + pad_d = param->padding[0] * 2; + pad_h = param->padding[1] * 2; + pad_w = param->padding[2] * 2; + } else if (param->padding.size() == 6) { + // (front, top, left, back, bottom, right) + pad_d = param->padding[0] + param->padding[3]; + pad_h = param->padding[1] + param->padding[4]; + pad_w = param->padding[2] + param->padding[5]; + } else { + return false; + } + + std::vector<IndexExpr> oshape; + for (const auto& e : dshape) { + oshape.push_back(e); + } + + std::vector<int> idxes = {didx, hidx, widx}; + for (int i = 0; i < 3; i++) { + int ii = idxes[i]; + if (dshape[ii].as<ir::Any>()) { + oshape[ii] = dshape[ii]; + } else { + if (param->ceil_mode) { + oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i] + + param->strides[i] - 1) / param->strides[i]) + 1; + } else { + oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i]) / param->strides[i]) + 1; + } + } + } + + // assign output type + reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype)); + return true; +} + +// MaxPool3D +Expr MakeMaxPool3D(Expr data, + Array<IndexExpr> pool_size, + Array<IndexExpr> strides, + Array<IndexExpr> padding, + std::string layout, + bool ceil_mode) { + auto attrs = make_node<MaxPool3DAttrs>(); + attrs->pool_size = std::move(pool_size); + attrs->strides = std::move(strides); + attrs->padding = std::move(padding); + attrs->layout = std::move(layout); + attrs->ceil_mode = ceil_mode; + static const Op& op = Op::Get("nn.max_pool3d"); + return CallNode::make(op, {data}, Attrs(attrs), {}); +} + +template<typename AttrType, topi::nn::PoolType mode> +Array<Tensor> Pool3DCompute(const Attrs& attrs, + const Array<Tensor>& inputs, + const Type& out_type, + const Target& target) { + static const Layout kNCDHW("NCDHW"); + const auto* param = attrs.as<AttrType>(); + CHECK(param != nullptr); + auto pool_size = param->pool_size; + auto strides = param->strides; + auto padding = param->padding; + auto ceil_mode = param->ceil_mode; + Layout layout(param->layout); + + CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined()) + << "max_pool3d currently only supports layouts that are convertible from NCDHW"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1) + << "max_pool3d does not support input split on depth"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1) + << "max_pool3d does not support input split on height"; + CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1) + << "max_pool3d does not support input split on width"; + + CHECK(inputs[0].ndim() == 4U || + inputs[0].ndim() == 5U || + inputs[0].ndim() == 6U) + << "Pool3D only support 5-D input (e.g., NCDHW)" + << " or 6-D input (e.g. NCDHWc on for vector instructions)" + << " or 7-D input (e.g. NCDHWnc for tensor accelerators)"; + + if (param->padding.size() == 1) { + padding.push_back(padding[0]); + padding.push_back(padding[0]); + padding.push_back(padding[0]); + } else if (param->padding.size() == 3) { + padding.push_back(padding[0]); + padding.push_back(padding[1]); + padding.push_back(padding[2]); + } + if (mode == topi::nn::kAvgPool) { + bool count_include_pad = reinterpret_cast<const AvgPool3DAttrs*>(param)->count_include_pad; + return Array<Tensor>{ + topi::nn::pool3d(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name(), count_include_pad)}; + } else { + return Array<Tensor>{ + topi::nn::pool3d(inputs[0], pool_size, strides, padding, + mode, ceil_mode, layout.name())}; + } +} + +TVM_REGISTER_API("relay.op.nn._make.max_pool3d") +.set_body_typed(MakeMaxPool3D); Review comment: If you unify MakeMaxPool2D and MakeMaxPool3D as I said above, you need to pass a lambda here to pass an additional argument. It will be something like ``` [](Expr data, Array<IndexExpr> pool_size, Array<IndexExpr> strides, Array<IndexExpr> padding, std::string layout, bool ceil_mode) { return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, layout, ceil_mode, "nn.max_pool3d"); } ``` ---------------------------------------------------------------- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. For queries about this service, please contact Infrastructure at: us...@infra.apache.org With regards, Apache Git Services