masahi commented on a change in pull request #4478: [TOPI] implement pool3d op
URL: https://github.com/apache/incubator-tvm/pull/4478#discussion_r356737854
 
 

 ##########
 File path: src/relay/op/nn/pooling.cc
 ##########
 @@ -720,5 +720,238 @@ RELAY_REGISTER_OP("nn.avg_pool2d_grad")
 .set_attr<FTVMCompute>("FTVMCompute", Pool2DGradCompute<AvgPool2DAttrs, 
topi::nn::kAvgPool>);
 
 
+// relay.nn.max_pool3d & relay.nn.avg_pool3d
+TVM_REGISTER_NODE_TYPE(MaxPool3DAttrs);
+TVM_REGISTER_NODE_TYPE(AvgPool3DAttrs);
+
+template <typename AttrType>
+bool Pool3DRel(const Array<Type>& types,
+               int num_inputs,
+               const Attrs& attrs,
+               const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+
+  if (data == nullptr) return false;
+
+  const auto dshape = data->shape;
+  CHECK_GE(dshape.size(), 3U)
+      << "Pool3D only support input >= 3-D: input must have depth, height and 
width";
+  const auto param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+
+  Layout layout(param->layout);
+  CHECK(layout.Contains(LayoutAxis::Get('D')) && 
layout.Contains(LayoutAxis::Get('H')) &&
+        layout.Contains(LayoutAxis::Get('W')) && 
!layout.Contains(LayoutAxis::Get('d')) &&
+        !layout.Contains(LayoutAxis::Get('h')) && 
!layout.Contains(LayoutAxis::Get('w')))
+    << "Invalid layout " << layout
+    << ". Pool3D layout must have D, H and W, which cannot be split";
+
+  const auto didx = layout.IndexOf(LayoutAxis::Get('D'));
+  const auto hidx = layout.IndexOf(LayoutAxis::Get('H'));
+  const auto widx = layout.IndexOf(LayoutAxis::Get('W'));
+
+  IndexExpr pad_d, pad_h, pad_w;
+  if (param->padding.size() == 1) {
+    pad_d = param->padding[0] * 2;
+    pad_h = param->padding[0] * 2;
+    pad_w = param->padding[0] * 2;
+  } else if (param->padding.size() == 3) {
+    // (front, top, left)
+    pad_d = param->padding[0] * 2;
+    pad_h = param->padding[1] * 2;
+    pad_w = param->padding[2] * 2;
+  } else if (param->padding.size() == 6) {
+    // (front, top, left, back, bottom, right)
+    pad_d = param->padding[0] + param->padding[3];
+    pad_h = param->padding[1] + param->padding[4];
+    pad_w = param->padding[2] + param->padding[5];
+  } else {
+    return false;
+  }
+
+  std::vector<IndexExpr> oshape;
+  for (const auto& e : dshape) {
+    oshape.push_back(e);
+  }
+
+  std::vector<int> idxes = {didx, hidx, widx};
+  for (int i = 0; i < 3; i++) {
+    int ii = idxes[i];
+    if (dshape[ii].as<ir::Any>()) {
+      oshape[ii] = dshape[ii];
+    } else {
+      if (param->ceil_mode) {
+        oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i] +
+                         param->strides[i] - 1) / param->strides[i]) + 1;
+      } else {
+        oshape[ii] = ((dshape[ii] + pad_d - param->pool_size[i]) / 
param->strides[i]) + 1;
+      }
+    }
+  }
+
+  // assign output type
+  reporter->Assign(types[1], TensorTypeNode::make(oshape, data->dtype));
+  return true;
+}
+
+// MaxPool3D
+Expr MakeMaxPool3D(Expr data,
+                   Array<IndexExpr> pool_size,
+                   Array<IndexExpr> strides,
+                   Array<IndexExpr> padding,
+                   std::string layout,
+                   bool ceil_mode) {
+  auto attrs = make_node<MaxPool3DAttrs>();
+  attrs->pool_size = std::move(pool_size);
+  attrs->strides = std::move(strides);
+  attrs->padding = std::move(padding);
+  attrs->layout = std::move(layout);
+  attrs->ceil_mode = ceil_mode;
+  static const Op& op = Op::Get("nn.max_pool3d");
+  return CallNode::make(op, {data}, Attrs(attrs), {});
+}
+
+template<typename AttrType, topi::nn::PoolType mode>
+Array<Tensor> Pool3DCompute(const Attrs& attrs,
+                            const Array<Tensor>& inputs,
+                            const Type& out_type,
+                            const Target& target) {
+  static const Layout kNCDHW("NCDHW");
+  const auto* param = attrs.as<AttrType>();
+  CHECK(param != nullptr);
+  auto pool_size = param->pool_size;
+  auto strides = param->strides;
+  auto padding = param->padding;
+  auto ceil_mode = param->ceil_mode;
+  Layout layout(param->layout);
+
+  CHECK(BijectiveLayoutNode::make(layout, kNCDHW).defined())
+      << "max_pool3d currently only supports layouts that are convertible from 
NCDHW";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('d')), -1)
+      << "max_pool3d does not support input split on depth";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('h')), -1)
+      << "max_pool3d does not support input split on height";
+  CHECK_EQ(layout.IndexOf(LayoutAxis::Get('w')), -1)
+      << "max_pool3d does not support input split on width";
+
+  CHECK(inputs[0].ndim() == 4U ||
+        inputs[0].ndim() == 5U ||
+        inputs[0].ndim() == 6U)
+      << "Pool3D only support 5-D input (e.g., NCDHW)"
+      << " or 6-D input (e.g. NCDHWc on for vector instructions)"
+      << " or 7-D input (e.g. NCDHWnc for tensor accelerators)";
+
+  if (param->padding.size() == 1) {
+    padding.push_back(padding[0]);
+    padding.push_back(padding[0]);
+    padding.push_back(padding[0]);
+  } else if (param->padding.size() == 3) {
+    padding.push_back(padding[0]);
+    padding.push_back(padding[1]);
+    padding.push_back(padding[2]);
+  }
+  if (mode == topi::nn::kAvgPool) {
+    bool count_include_pad = reinterpret_cast<const 
AvgPool3DAttrs*>(param)->count_include_pad;
+    return Array<Tensor>{
+      topi::nn::pool3d(inputs[0], pool_size, strides, padding,
+                       mode, ceil_mode, layout.name(), count_include_pad)};
+  } else {
+    return Array<Tensor>{
+      topi::nn::pool3d(inputs[0], pool_size, strides, padding,
+                       mode, ceil_mode, layout.name())};
+  }
+}
+
+TVM_REGISTER_API("relay.op.nn._make.max_pool3d")
+.set_body_typed(MakeMaxPool3D);
 
 Review comment:
   If you unify MakeMaxPool2D and MakeMaxPool3D into a single template as I 
suggested above, you will need to pass a lambda here to supply the additional op-name argument. It would look something like this:
   
   ```
    [](Expr data,
          Array<IndexExpr> pool_size,
          Array<IndexExpr> strides,
          Array<IndexExpr> padding,
          std::string layout,
          bool ceil_mode) {
       return MakeMaxPool<MaxPool3DAttrs>(data, pool_size, strides, padding, 
layout, ceil_mode, "nn.max_pool3d");
    }
   ```

----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
 
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


With regards,
Apache Git Services

Reply via email to