This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e369c5a  [Relay,Topi][OP] affine_grid and grid_sample (#5657)
e369c5a is described below

commit e369c5a9cbacb926ca7b95ebc4ae01a6de33c6cd
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Sat May 23 00:57:58 2020 -0400

    [Relay,Topi][OP] affine_grid and grid_sample (#5657)
    
    * [Relay,Topi][OP] affine_grid and grid_sample
    
    * lint
---
 include/tvm/relay/attrs/image.h                |  28 +++++
 python/tvm/relay/frontend/mxnet.py             |  22 ++++
 python/tvm/relay/op/image/_image.py            |  20 +++
 python/tvm/relay/op/image/image.py             |  64 ++++++++++
 src/relay/op/image/grid_sample.cc              | 168 +++++++++++++++++++++++++
 tests/python/frontend/mxnet/test_forward.py    |  38 ++++++
 tests/python/relay/test_op_level5.py           |  52 ++++++++
 topi/python/topi/image/__init__.py             |   1 +
 topi/python/topi/image/grid_sample.py          | 124 ++++++++++++++++++
 topi/python/topi/testing/__init__.py           |   1 +
 topi/python/topi/testing/grid_sample_python.py |  65 ++++++++++
 topi/tests/python/test_topi_image.py           |  83 ++++++++++++
 12 files changed, 666 insertions(+)

diff --git a/include/tvm/relay/attrs/image.h b/include/tvm/relay/attrs/image.h
index 58fd44b..cf5a6ef 100644
--- a/include/tvm/relay/attrs/image.h
+++ b/include/tvm/relay/attrs/image.h
@@ -167,6 +167,34 @@ struct Dilation2DAttrs : public tvm::AttrsNode<Dilation2DAttrs> {
   }
 };
 
+/*! \brief Attributes used in image affine_grid operator */
+struct AffineGridAttrs : public tvm::AttrsNode<AffineGridAttrs> {
+  /*! \brief Spatial shape (H, W) of the generated sampling grid. */
+  Array<IndexExpr> target_shape;
+
+  TVM_DECLARE_ATTRS(AffineGridAttrs, "relay.attrs.AffineGridAttrs") {
+    TVM_ATTR_FIELD(target_shape).describe("Specifies the output shape (H, W).");
+  }
+};
+
+/*! \brief Attributes used in image grid_sample operator */
+struct GridSampleAttrs : public tvm::AttrsNode<GridSampleAttrs> {
+  /*! \brief Interpolation method, e.g. "bilinear". */
+  String method;
+  /*! \brief Layout of the input data and of the output, e.g. "NCHW". */
+  String layout;
+
+  TVM_DECLARE_ATTRS(GridSampleAttrs, "relay.attrs.GridSampleAttrs") {
+    // Adjacent string literals are concatenated by the compiler, so each piece
+    // must carry its own trailing space.
+    TVM_ATTR_FIELD(method)
+        .set_default("bilinear")
+        .describe(
+            "Specify the mode to use for sampling. "
+            "bilinear - Bilinear Interpolation");
+    TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
+        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
+        "dimensions respectively. Sampling is applied on the 'H' and "
+        "'W' dimensions.");
+  }
+};
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_IMAGE_H_
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index 9f97ee9..c75612d 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -757,6 +757,26 @@ def _mx_resize(inputs, attrs):
     return _op.image.resize(inputs[0], size,
                             coordinate_transformation_mode="align_corners")
 
+def _mx_grid_generator(inputs, attrs):
+    """Convert MXNet ``GridGenerator`` to relay.
+
+    * ``affine``: the input is reshaped to ``(batch, 2, 3)`` affine matrices and
+      passed to ``image.affine_grid`` together with the ``target_shape`` attribute.
+    * ``warp``: the input is a 2-channel flow field; it is rescaled into the
+      normalized [-1, 1] coordinate range and added to an identity sampling grid
+      of the same spatial size (``target_shape`` is not read).
+
+    Raises
+    ------
+    ValueError
+        If ``transform_type`` is neither ``affine`` nor ``warp``.
+    """
+    transform_type = attrs.get_str("transform_type")
+    if transform_type == 'affine':
+        target_shape = attrs.get_int_tuple("target_shape")
+        return _op.image.affine_grid(_op.reshape(inputs[0], (0, 2, 3)), target_shape)
+    if transform_type == 'warp':
+        checked_type = _infer_type(inputs[0]).checked_type
+        batch, _, height, width = get_const_tuple(checked_type.shape)
+        dtype = checked_type.dtype
+        identity_affine = relay.const(np.array([[[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]], dtype=dtype))
+        identity_affine = _op.broadcast_to(identity_affine, (batch, 2, 3))
+        # Scale the x channel by 2/(W-1) and the y channel by 2/(H-1), the same
+        # step used to spread affine_grid's coordinates over [-1, 1].
+        normalizer = (2.0 / np.array([width - 1, height - 1])).reshape(1, -1, 1, 1).astype(dtype)
+        normalized_flow = inputs[0] * relay.const(normalizer)
+        grid = _op.image.affine_grid(identity_affine, (height, width))
+        return grid + normalized_flow
+    raise ValueError("unknown transform type: {}".format(transform_type))
+
+def _mx_bilinear_sampler(inputs, attrs):
+    """Convert MXNet ``BilinearSampler`` (inputs: data, grid) to ``image.grid_sample``.
+
+    Always uses bilinear interpolation and NCHW layout; ``attrs`` is not read.
+    """
+    return _op.image.grid_sample(inputs[0], inputs[1], 'bilinear', 'NCHW')
+
 def _mx_roi_pooling(inputs, attrs):
     new_attrs = {}
     new_attrs["pooled_size"] = attrs.get_int_tuple("pooled_size")
@@ -1996,6 +2016,8 @@ _convert_map = {
     "_contrib_box_nms" : _mx_box_nms,
     "_contrib_DeformableConvolution" : _mx_deformable_convolution,
     "_contrib_AdaptiveAvgPooling2D" : _mx_adaptive_avg_pooling,
+    "GridGenerator"                 : _mx_grid_generator,
+    "BilinearSampler"               : _mx_bilinear_sampler,
     # NLP
     "RNN"               : _mx_rnn_layer,
     "_rnn_param_concat" : _mx_rnn_param_concat,
diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py
index 290c0a2..bcb110f 100644
--- a/python/tvm/relay/op/image/_image.py
+++ b/python/tvm/relay/op/image/_image.py
@@ -19,6 +19,7 @@
 from __future__ import absolute_import
 
 import topi
+from topi.util import get_const_tuple
 from .. import op as reg
 from .. import strategy
 from ..op import OpPattern
@@ -67,3 +68,22 @@ reg.register_injective_schedule("image.crop_and_resize")
 # dilation2d
 reg.register_strategy("image.dilation2d", strategy.dilation2d_strategy)
 reg.register_pattern("image.dilation2d", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
+# affine_grid
+@reg.register_compute("image.affine_grid")
+def compute_affine_grid(attrs, inputs, out_dtype):
+    """Compute for image.affine_grid; lowers to topi.image.affine_grid."""
+    # attrs.target_shape is an Array of IndexExpr; topi does Python arithmetic on it,
+    # so convert it to a tuple of constants first.
+    target_shape = get_const_tuple(attrs.target_shape)
+    return [topi.image.affine_grid(inputs[0], target_shape)]
+
+# The op's pattern is kInjective, so the generic injective schedule is used.
+reg.register_injective_schedule("image.affine_grid")
+
+
+# grid_sample
+@reg.register_compute("image.grid_sample")
+def compute_grid_sample(attrs, inputs, out_dtype):
+    """Compute for image.grid_sample; inputs are (data, grid), lowered to topi."""
+    method = attrs.method
+    layout = attrs.layout
+    return [topi.image.grid_sample(inputs[0], inputs[1], method, layout)]
+
+# The op's pattern is kInjective, so the generic injective schedule is used.
+reg.register_injective_schedule("image.grid_sample")
diff --git a/python/tvm/relay/op/image/image.py b/python/tvm/relay/op/image/image.py
index 49b35d8..62889e0 100644
--- a/python/tvm/relay/op/image/image.py
+++ b/python/tvm/relay/op/image/image.py
@@ -215,3 +215,67 @@ def dilation2d(data,
 
     return _make.dilation2d(data, weight, strides, padding, dilations, 
data_layout,
                             kernel_layout, out_dtype)
+
+
+def affine_grid(data, target_shape=None):
+    """affine_grid operator that generates 2D sampling grid.
+
+    This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+    sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+    transformation is then applied on the sampling grid.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        3-D with shape [batch, 2, 3]. The affine matrix.
+
+    target_shape: list/tuple of two int
+        Specifies the output shape (H, W). Despite the ``None`` default it must be
+        provided: type inference rejects a missing or non-2D target_shape.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        4-D with shape [batch, 2, target_height, target_width]. Channel 0 is the x
+        coordinate and channel 1 the y coordinate, as consumed by grid_sample.
+    """
+    return _make.affine_grid(data, target_shape)
+
+def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+    """Applies bilinear sampling to input feature map.
+
+    Given :math:`data` and :math:`grid`, then the output is computed by
+
+    .. math::
+
+        x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+        y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+        output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])
+
+    :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+    :math:`G()` denotes the interpolation function.
+    Out-of-bound points are padded with zeros. The shape of the output will be
+    (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+    The operator assumes that :math:`grid` has been normalized to [-1, 1].
+
+    grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+    Parameters
+    ----------
+    data : tvm.relay.Expr
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    grid : tvm.relay.Expr
+        4-D with shape [batch, 2, out_height, out_width]
+
+    method : str
+        The interpolation method. Only 'bilinear' is supported.
+
+    layout : str
+        The layout of input data and the output.
+
+    Returns
+    -------
+    Output : tvm.relay.Expr
+        4-D with shape [batch, in_channel, out_height, out_width]
+    """
+    return _make.grid_sample(data, grid, method, layout)
diff --git a/src/relay/op/image/grid_sample.cc b/src/relay/op/image/grid_sample.cc
new file mode 100644
index 0000000..bc69891
--- /dev/null
+++ b/src/relay/op/image/grid_sample.cc
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file grid_sample.cc
+ * \brief affine_grid and grid_sample operator
+ */
+#include <tvm/relay/attrs/image.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.image.affine_grid
+TVM_REGISTER_NODE_TYPE(AffineGridAttrs);
+
+// Type relation for image.affine_grid: data [batch, 2, 3] -> out [batch, 2, H, W],
+// where (H, W) comes from AffineGridAttrs::target_shape.
+bool AffineGridRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 2);
+  const auto* data = types[0].as<TensorTypeNode>();
+  if (data == nullptr) return false;
+
+  const AffineGridAttrs* param = attrs.as<AffineGridAttrs>();
+  CHECK(param != nullptr);
+
+  // Validate the rank before any shape element is read.
+  CHECK(data->shape.size() == 3U && reporter->AssertEQ(data->shape[1], 2) &&
+        reporter->AssertEQ(data->shape[2], 3))
+      << "data should be an "
+         "affine matrix with shape [batch_size, 2, 3]";
+  CHECK(param->target_shape.defined() && param->target_shape.size() == 2)
+      << "target_shape should be 2D";
+
+  Array<IndexExpr> oshape;
+  oshape.push_back(data->shape[0]);
+  oshape.push_back(2);
+  oshape.push_back(param->target_shape[0]);
+  oshape.push_back(param->target_shape[1]);
+
+  // assign output type
+  reporter->Assign(types[1], TensorType(oshape, data->dtype));
+  return true;
+}
+
+// Positional relay function to create affine_grid operator
+// used by frontend FFI. target_shape is stored on AffineGridAttrs and checked
+// by the AffineGrid type relation.
+Expr MakeAffineGrid(Expr data, Array<IndexExpr> target_shape) {
+  auto attrs = make_object<AffineGridAttrs>();
+  attrs->target_shape = std::move(target_shape);
+  static const Op& op = Op::Get("image.affine_grid");
+  return Call(op, {data}, Attrs(attrs), {});
+}
+
+// Exposed to Python as _make.affine_grid (called by relay.image.affine_grid).
+TVM_REGISTER_GLOBAL("relay.op.image._make.affine_grid").set_body_typed(MakeAffineGrid);
+
+// Operator registration; kInjective lets the generic injective schedule be used.
+RELAY_REGISTER_OP("image.affine_grid")
+    .describe(R"code(affine_grid operator that generates 2D sampling grid.
+
+This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+transformation is then applied on the sampling grid.
+
+- **data**: data is 3D array of shape [batch, 2, 3], which defines an affine transformation.
+
+- **out**: out is 4D array of shape [batch, 2, height, width], where each vector
+           :math:`out[b, :, h, w]` represents the coordinate :math:`(x, y)`
+
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<AffineGridAttrs>()
+    .set_num_inputs(1)
+    .add_argument("data", "Tensor", "The affine matrix.")
+    .set_support_level(5)
+    .add_type_rel("AffineGrid", AffineGridRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+// relay.image.grid_sample
+TVM_REGISTER_NODE_TYPE(GridSampleAttrs);
+
+// Type relation for image.grid_sample: the output keeps data's batch and channel
+// dimensions, takes its spatial size from grid, and is expressed in data's layout.
+bool GridSampleRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                   const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data = types[0].as<TensorTypeNode>();
+  const auto* grid = types[1].as<TensorTypeNode>();
+  if (!data || !grid) return false;
+  const auto* param = attrs.as<GridSampleAttrs>();
+  CHECK(param);
+  // grid stores one (x, y) pair per output location; validate before indexing it.
+  CHECK(grid->shape.size() == 4U && reporter->AssertEQ(grid->shape[1], 2))
+      << "grid should have shape [batch, 2, out_height, out_width]";
+  static const Layout kNCHW("NCHW");
+  const Layout in_layout(param->layout);
+  // Build the output shape in NCHW, then map it back to the user's layout.
+  auto layout_converter = tir::BijectiveLayout(in_layout, kNCHW);
+  auto oshape = layout_converter.ForwardShape(data->shape);
+  oshape.Set(2, grid->shape[2]);
+  oshape.Set(3, grid->shape[3]);
+  // assign output type
+  reporter->Assign(types[2], TensorType(layout_converter.BackwardShape(oshape), data->dtype));
+  return true;
+}
+
+// Positional relay function to create grid_sample operator
+// used by frontend FFI. method and layout are stored on GridSampleAttrs.
+Expr MakeGridSample(Expr data, Expr grid, String method, String layout) {
+  auto attrs = make_object<GridSampleAttrs>();
+  attrs->method = std::move(method);
+  attrs->layout = std::move(layout);
+  static const Op& op = Op::Get("image.grid_sample");
+  return Call(op, {data, grid}, Attrs(attrs), {});
+}
+
+// Exposed to Python as _make.grid_sample (called by relay.image.grid_sample).
+TVM_REGISTER_GLOBAL("relay.op.image._make.grid_sample").set_body_typed(MakeGridSample);
+
+RELAY_REGISTER_OP("image.grid_sample")
+    .describe(R"code(Applies grid sampling to input feature map.
+
+Given :math:`data` and :math:`grid`, then the output is computed by
+
+.. math::
+  x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+  y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+  output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])
+
+:math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+:math:`G()` denotes the interpolation function.
+Out-of-bound points are padded with zeros. The shape of the output will be
+(data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+The operator assumes that :math:`data` has 'NCHW' layout and :math:`grid` has been normalized to [-1, 1].
+
+grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+- **data**: data is 4D array of shape
+            (batch_size, channels, in_height, in_width) for NCHW
+            (batch_size, in_height, in_width, channels) for NHWC
+
+- **grid**: grid is 4D array of shape [batch, 2, out_height, out_width], where each vector
+           :math:`grid[b, :, h, w]` represents the coordinate :math:`(x, y)`
+
+- **out**: out is 4D array of shape
+           (batch, in_channel, out_height, out_width) for NCHW
+           (batch_size, out_height, out_width, channels) for NHWC
+
+)code" TVM_ADD_FILELINE)
+    .set_num_inputs(2)
+    .set_attrs_type<GridSampleAttrs>()
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("grid", "Tensor", "The sampling grid.")
+    .set_support_level(5)
+    .add_type_rel("GridSample", GridSampleRel)
+    .set_attr<TOpPattern>("TOpPattern", kInjective);
+
+}  // namespace relay
+}  // namespace tvm
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 99fc6c3..6d36ea3 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -639,6 +639,42 @@ def test_forward_bilinear_resize():
     mx_sym = mx.sym.contrib.BilinearResize2D(data, height=5, width=10)
     verify_mxnet_frontend_impl(mx_sym, (1, 2, 3, 4), (1, 2, 5, 10))
 
+def test_forward_grid_generator():
+    """GridGenerator ('affine' and 'warp') converted from MXNet matches MXNet's output."""
+    def verify(shape, transform_type, target_shape):
+        x = np.random.uniform(size=shape).astype("float32")
+        ref_res = mx.nd.GridGenerator(mx.nd.array(x), transform_type, target_shape)
+        mx_sym = mx.sym.GridGenerator(mx.sym.var("x"), transform_type, target_shape)
+        shape_dict = {"x": x.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(
+                    kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(x)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((4, 6), 'affine', (16, 32))
+    # The 'warp' converter does not read target_shape; the grid size comes from the flow.
+    verify((4, 2, 16, 16), 'warp', None)
+    verify((1, 2, 16, 16), 'warp', None)
+
+def test_forward_bilinear_sampler():
+    """BilinearSampler converted from MXNet matches MXNet's output."""
+    def verify(data_shape, grid_shape):
+        data = np.random.uniform(size=data_shape).astype("float32")
+        # Values outside [-1, 1] exercise the zero-padding path.
+        grid = np.random.uniform(low=-1.5, high=1.5, size=grid_shape).astype("float32")
+        ref_res = mx.nd.BilinearSampler(mx.nd.array(data), mx.nd.array(grid))
+        mx_sym = mx.sym.BilinearSampler(mx.sym.var("data"), mx.sym.var("grid"))
+        shape_dict = {"data": data.shape, "grid": grid.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(
+                    kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(data, grid)
+                tvm.testing.assert_allclose(
+                    op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-5, atol=1e-5)
+    verify((4, 4, 16, 32), (4, 2, 8, 8))
+    verify((4, 4, 16, 32), (4, 2, 32, 32))
+
 def test_forward_rnn_layer():
     def verify(mode, seq_len, input_size, hidden_size, num_layers,
                batch=1, init_states=True, bidirectional=False):
@@ -1211,3 +1247,5 @@ if __name__ == '__main__':
     test_forward_unravel_index()
     test_forward_swap_axis()
     test_forward_correlation()
+    test_forward_grid_generator()
+    test_forward_bilinear_sampler()
diff --git a/tests/python/relay/test_op_level5.py b/tests/python/relay/test_op_level5.py
index c9d7d42..c306752 100644
--- a/tests/python/relay/test_op_level5.py
+++ b/tests/python/relay/test_op_level5.py
@@ -823,6 +823,56 @@ def test_dilation2d_run():
                         data_layout='NHWC', kernel_layout='HWI')
 
 
+def test_affine_grid():
+    """Type inference and numerics of relay.image.affine_grid against the NumPy reference."""
+    def verify_affine_grid(num_batch, target_shape):
+        dtype = 'float32'
+        data_shape = (num_batch, 2, 3)
+        data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
+        y = relay.image.affine_grid(data, target_shape)
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.ty.TensorType((num_batch, len(target_shape), *target_shape), dtype)
+
+        func = relay.Function([data], y)
+        data_np = np.random.uniform(size=data_shape).astype(dtype)
+        ref_res = topi.testing.affine_grid_python(data_np, target_shape)
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data_np)
+                tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    verify_affine_grid(1, (16, 32))
+    verify_affine_grid(4, (16, 32))
+
+
+def test_grid_sample():
+    """Type inference and numerics of relay.image.grid_sample against the NumPy reference."""
+    def verify_grid_sample(data_shape, grid_shape):
+        dtype = 'float32'
+        batch, channel, _, _ = data_shape
+        _, _, out_height, out_width = grid_shape
+        data = relay.var("data", relay.ty.TensorType(data_shape, dtype))
+        grid = relay.var("grid", relay.ty.TensorType(grid_shape, dtype))
+        y = relay.image.grid_sample(data, grid, method='bilinear', layout='NCHW')
+        yy = run_infer_type(y)
+        assert yy.checked_type == relay.TensorType((batch, channel, out_height, out_width), dtype)
+        func = relay.Function([data, grid], y)
+
+        data_np = np.random.uniform(size=data_shape).astype(dtype)
+        # Values outside [-1, 1] exercise the zero-padding path.
+        grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
+        ref_res = topi.testing.grid_sample_nchw_python(data_np, grid_np, method='bilinear')
+
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp1 = relay.create_executor(kind, ctx=ctx, target=target)
+                op_res1 = intrp1.evaluate(func)(data_np, grid_np)
+                tvm.testing.assert_allclose(
+                    op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
+    verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
+
+
 if __name__ == "__main__":
     test_resize_infer_type()
     test_resize()
@@ -843,3 +893,5 @@ if __name__ == "__main__":
     test_space_to_depth()
     test_dilation2d_infer_type()
     test_dilation2d_run()
+    test_affine_grid()
+    test_grid_sample()
diff --git a/topi/python/topi/image/__init__.py b/topi/python/topi/image/__init__.py
index 86b9825..914b02e 100644
--- a/topi/python/topi/image/__init__.py
+++ b/topi/python/topi/image/__init__.py
@@ -21,3 +21,4 @@ from __future__ import absolute_import as _abs
 
 from .resize import *
 from .dilation2d import *
+from .grid_sample import *
diff --git a/topi/python/topi/image/grid_sample.py b/topi/python/topi/image/grid_sample.py
new file mode 100644
index 0000000..32b6112
--- /dev/null
+++ b/topi/python/topi/image/grid_sample.py
@@ -0,0 +1,124 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""affine_grid and grid_sample operator"""
+from tvm import te, tir
+
+
+def affine_grid(data, target_shape):
+    """affine_grid operator that generates 2D sampling grid.
+
+    This operation is described in https://arxiv.org/pdf/1506.02025.pdf. It generates a uniform
+    sampling grid within the target shape and normalizes it to [-1, 1]. The provided affine
+    transformation is then applied on the sampling grid.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        3-D with shape [batch, 2, 3]. The affine matrix.
+
+    target_shape: list/tuple of two int
+        Specifies the output shape (H, W). Both values must be greater than 1.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, 2, target_height, target_width]. Channel ``d`` is row ``d``
+        of the affine matrix applied to the normalized (x, y, 1) of each output pixel.
+    """
+    assert target_shape is not None
+    assert len(target_shape) == 2
+    assert target_shape[0] > 1 and target_shape[1] > 1, \
+        "target height/width should be greater than 1"
+
+    dtype = data.dtype
+    # NOTE(review): the 1e-7 shrink keeps the last grid coordinate just below 1.0;
+    # its motivation is not documented here -- confirm before changing.
+    y_step = tir.const((2.0 - 1e-7)/ (target_shape[0] - 1), dtype=dtype)
+    x_step = tir.const((2.0 - 1e-7)/ (target_shape[1] - 1), dtype=dtype)
+    start = tir.const(-1.0, dtype=dtype)
+
+    def _compute(n, dim, i, j):
+        # Normalized (x, y) of output pixel (i, j), transformed by row `dim`
+        # of the affine matrix: A[dim, 0] * x + A[dim, 1] * y + A[dim, 2].
+        y = start + i * y_step
+        x = start + j * x_step
+        return data[n, dim, 0] * x + data[n, dim, 1] * y + data[n, dim, 2]
+
+    oshape = (data.shape[0], len(target_shape), *target_shape)
+    return te.compute(oshape, _compute, tag='affine_grid')
+
+
+def grid_sample(data, grid, method='bilinear', layout='NCHW'):
+    """Applies bilinear sampling to input feature map.
+
+    Given :math:`data` and :math:`grid`, assuming NCHW layout, then the output is computed by
+
+    .. math::
+
+        x_{src} = grid[batch, 0, y_{dst}, x_{dst}] \\
+        y_{src} = grid[batch, 1, y_{dst}, x_{dst}] \\
+        output[batch, channel, y_{dst}, x_{dst}] = G(data[batch, channel, y_{src}, x_{src}])
+
+    :math:`x_{dst}`, :math:`y_{dst}` enumerate all spatial locations in :math:`output`, and
+    :math:`G()` denotes the interpolation method.
+    Out-of-bound points are padded with zeros. The shape of the output will be
+    (data.shape[0], data.shape[1], grid.shape[2], grid.shape[3]).
+
+    The operator assumes that :math:`grid` has been normalized to [-1, 1].
+
+    grid_sample often cooperates with affine_grid which generates sampling grids for grid_sample.
+
+    Parameters
+    ----------
+    data : tvm.Tensor
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    grid : tvm.Tensor
+        4-D with shape [batch, 2, out_height, out_width]
+
+    method : str
+        The interpolation method. Only 'bilinear' is supported.
+
+    layout : str
+        The layout of input data and the output. Only 'NCHW' is supported.
+
+    Returns
+    -------
+    Output : tvm.Tensor
+        4-D with shape [batch, in_channel, out_height, out_width]
+    """
+    batch, in_channel, in_height, in_width = data.shape
+    out_height, out_width = grid.shape[2:]
+    assert method == 'bilinear', "Only bilinear is supported"
+    assert layout == "NCHW", "Only NCHW is supported"
+
+    def _get_pixel_value(n, c, h, w):
+        # Zero padding: positions outside the input contribute 0.
+        return te.if_then_else(te.all(h >= 0, w >= 0, h < in_height, w < in_width),
+                               data[n, c, h, w], tir.const(0.0, dtype=data.dtype))
+
+    def _bilinear_sample(n, c, h, w):
+        x = grid[n, 0, h, w]
+        y = grid[n, 1, h, w]
+        # Map normalized coordinates to input pixel coordinates (-1 -> 0, 1 -> size - 1).
+        y = (y + 1) * (in_height - 1) / 2
+        x = (x + 1) * (in_width - 1) / 2
+        x0 = te.floor(x).astype('int32')
+        y0 = te.floor(y).astype('int32')
+        x1 = x0 + tir.const(1, 'int32')
+        y1 = y0 + tir.const(1, 'int32')
+        # Bilinear weights of the four neighbouring pixels.
+        return _get_pixel_value(n, c, y0, x0) * (1.0 - (y - y0)) * (1.0 - (x - x0)) \
+            + _get_pixel_value(n, c, y0, x1) * (1.0 - (y - y0)) * (x - x0) \
+            + _get_pixel_value(n, c, y1, x0) * (y - y0) * (1.0 - (x - x0)) \
+            + _get_pixel_value(n, c, y1, x1) * (y - y0) * (x - x0)
+
+    return te.compute((batch, in_channel, out_height, out_width), _bilinear_sample,
+                      tag='grid_sample')
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 511fe16..e677a11 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -57,3 +57,4 @@ from .crop_and_resize_python import crop_and_resize_python
 from .common import get_injective_schedule, get_reduce_schedule, 
get_broadcast_schedule, \
     get_elemwise_schedule, get_conv2d_nchw_implement, dispatch
 from .adaptive_pool_python import adaptive_pool
+from .grid_sample_python import affine_grid_python, grid_sample_nchw_python
diff --git a/topi/python/topi/testing/grid_sample_python.py b/topi/python/topi/testing/grid_sample_python.py
new file mode 100644
index 0000000..964d8a2
--- /dev/null
+++ b/topi/python/topi/testing/grid_sample_python.py
@@ -0,0 +1,65 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""affine_grid and grid_sample operators in python"""
+import math
+import numpy as np
+
+
+def affine_grid_python(data, target_shape):
+    """NumPy reference for affine_grid.
+
+    Parameters
+    ----------
+    data : numpy.ndarray
+        Affine matrices of shape [batch, 2, 3].
+    target_shape : tuple of two int
+        Output spatial shape (H, W); a size of 1 would divide by zero below.
+
+    Returns
+    -------
+    numpy.ndarray
+        Grid of shape [batch, 2, H, W].
+    """
+    yv, xv = np.meshgrid(
+        np.arange(target_shape[0]), np.arange(target_shape[1]))
+    # meshgrid's default 'xy' indexing yields (W, H) arrays; transpose to (H, W)
+    # and normalize pixel indices to [-1, 1].
+    yv = yv.T * 2 / (target_shape[0] - 1) - 1
+    xv = xv.T * 2 / (target_shape[1] - 1) - 1
+    ones = np.ones_like(xv)
+    # Homogeneous coordinates (x, y, 1) of every output pixel, shape (3, H*W).
+    grid = np.stack([xv, yv, ones]).reshape(3, -1)
+    return data.reshape(-1, 3).dot(grid).reshape(data.shape[0], 2, *target_shape)
+
+
+def _bilinear_sample_nchw_python(data, grid):
+    """NumPy reference for bilinear grid sampling on NCHW data.
+
+    ``grid`` has shape [batch, 2, out_height, out_width] and holds normalized (x, y)
+    coordinates; neighbours that fall outside the input contribute zero.
+    """
+    batch, in_channel, in_height, in_width = data.shape
+    _, _, out_height, out_width = grid.shape
+    out = np.zeros((batch, in_channel, out_height, out_width), dtype=data.dtype)
+
+    def _within_bound(y, x):
+        return 0 <= y < in_height and 0 <= x < in_width
+
+    for n in range(0, batch):
+        for h in range(0, out_height):
+            for w in range(0, out_width):
+                x, y = grid[n, :, h, w]
+                # Normalized [-1, 1] -> input pixel coordinates.
+                y = (y + 1) * (in_height - 1) / 2
+                x = (x + 1) * (in_width - 1) / 2
+                y0 = int(math.floor(y))
+                x0 = int(math.floor(x))
+                y1 = y0 + 1
+                x1 = x0 + 1
+                # Accumulate each in-bound neighbour weighted by its bilinear coefficient.
+                if _within_bound(y0, x0):
+                    out[n, :, h, w] += data[n, :, y0, x0] * (1.0 - (y - y0)) * (1.0 - (x - x0))
+                if _within_bound(y0, x1):
+                    out[n, :, h, w] += data[n, :, y0, x1] * (1.0 - (y - y0)) * (x - x0)
+                if _within_bound(y1, x0):
+                    out[n, :, h, w] += data[n, :, y1, x0] * (y - y0) * (1.0 - (x - x0))
+                if _within_bound(y1, x1):
+                    out[n, :, h, w] += data[n, :, y1, x1] * (y - y0) * (x - x0)
+    return out
+
+
+def grid_sample_nchw_python(data, grid, method='bilinear'):
+    """NumPy reference for grid_sample on NCHW data.
+
+    Only ``method='bilinear'`` is implemented; any other value raises ValueError.
+    """
+    if method == 'bilinear':
+        return _bilinear_sample_nchw_python(data, grid)
+    raise ValueError("invalid method")
diff --git a/topi/tests/python/test_topi_image.py b/topi/tests/python/test_topi_image.py
index 4eea75d..012ed42 100644
--- a/topi/tests/python/test_topi_image.py
+++ b/topi/tests/python/test_topi_image.py
@@ -20,6 +20,7 @@ import tvm
 from tvm import te
 import topi
 import topi.testing
+from tvm.contrib.pickle_memoize import memoize
 
 from common import get_all_backend
 
@@ -204,7 +205,89 @@ def test_crop_and_resize():
                            size_1, method='nearest_neighbor')
     verify_crop_and_resize((1, 3, 224, 224), boxes_1, indices_1, size_1, 
layout="NCHW")
 
+
+def test_affine_grid():
+    """topi.image.affine_grid matches the NumPy reference on every enabled backend."""
+    def verify_affine_grid(num_batch, target_shape):
+        dtype = "float32"
+        data_shape = (num_batch, 2, 3)
+        data = te.placeholder(data_shape, dtype=dtype)
+        out = topi.image.affine_grid(data, target_shape)
+
+        # Reference data is cached by memoize, keyed on this closure's captured values.
+        @memoize("topi.tests.test_affine_grid.verify_affine_grid")
+        def get_ref_data():
+            data_np = np.random.uniform(size=data_shape).astype(dtype)
+            out_np = topi.testing.affine_grid_python(data_np, target_shape)
+            return data_np, out_np
+
+        data_np, out_np = get_ref_data()
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.testing.get_injective_schedule(device)(out)
+            tvm_data = tvm.nd.array(data_np, ctx)
+            tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx)
+            f = tvm.build(s, [data, out], device)
+            f(tvm_data, tvm_out)
+
+            tvm.testing.assert_allclose(
+                tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    verify_affine_grid(1, (16, 32))
+    verify_affine_grid(4, (16, 32))
+
+
+def test_grid_sample():
+    """topi.image.grid_sample matches the NumPy reference on every enabled backend."""
+    def verify_grid_sample(data_shape, grid_shape):
+        dtype = "float32"
+        data = te.placeholder(data_shape, dtype=dtype)
+        grid = te.placeholder(grid_shape, dtype=dtype)
+        out = topi.image.grid_sample(data, grid, 'bilinear')
+
+        @memoize("topi.tests.test_grid_sample.verify_grid_sample")
+        def get_ref_data():
+            data_np = np.random.uniform(size=data_shape).astype(dtype)
+            # allow grid values to be out-of-bound
+            grid_np = np.random.uniform(size=grid_shape, low=-1.5, high=1.5).astype(dtype)
+            out_np = topi.testing.grid_sample_nchw_python(data_np, grid_np, 'bilinear')
+            return data_np, grid_np, out_np
+
+        data_np, grid_np, out_np = get_ref_data()
+
+        def check_device(device):
+            ctx = tvm.context(device, 0)
+            if not ctx.exist:
+                print("Skip because %s is not enabled" % device)
+                return
+            print("Running on target: %s" % device)
+            with tvm.target.create(device):
+                s = topi.testing.get_injective_schedule(device)(out)
+            tvm_data = tvm.nd.array(data_np, ctx)
+            tvm_grid = tvm.nd.array(grid_np, ctx)
+            tvm_out = tvm.nd.empty(out_np.shape, dtype, ctx)
+            f = tvm.build(s, [data, grid, out], device)
+            f(tvm_data, tvm_grid, tvm_out)
+
+            tvm.testing.assert_allclose(
+                tvm_out.asnumpy(), out_np, rtol=1e-5, atol=1e-5)
+
+        for device in get_all_backend():
+            check_device(device)
+
+    verify_grid_sample((4, 4, 16, 32), (4, 2, 8, 8))
+    verify_grid_sample((4, 4, 16, 32), (4, 2, 32, 32))
+
+
 if __name__ == "__main__":
     test_resize()
     test_resize3d()
     test_crop_and_resize()
+    test_affine_grid()
+    test_grid_sample()

Reply via email to