This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/master by this push:
     new e55f9ff  [Relay, Topi][OP] Correlation (#5628)
e55f9ff is described below

commit e55f9ff115e2e9364e8ed8b3eeb44dbbc1894eb1
Author: Wuwei Lin <wu...@apache.org>
AuthorDate: Fri May 22 02:00:33 2020 -0400

    [Relay, Topi][OP] Correlation (#5628)
    
    * [Relay,Topi] Correlation
    
    * fix
    
    * move
    
    * typo
    
    * Update test_topi_correlation.py
---
 include/tvm/relay/attrs/nn.h                       |  30 ++++
 python/tvm/relay/frontend/mxnet.py                 |  14 ++
 python/tvm/relay/op/nn/_nn.py                      |   5 +
 python/tvm/relay/op/nn/nn.py                       |  83 ++++++++++
 python/tvm/relay/op/op_attrs.py                    |   5 +
 python/tvm/relay/op/strategy/cuda.py               |  12 ++
 python/tvm/relay/op/strategy/generic.py            |  27 ++++
 src/relay/op/nn/correlation.cc                     | 136 ++++++++++++++++
 tests/python/frontend/mxnet/test_forward.py        |  35 +++-
 tests/python/relay/test_op_level2.py               |  40 +++++
 topi/python/topi/cuda/__init__.py                  |   1 +
 topi/python/topi/cuda/correlation.py               | 176 +++++++++++++++++++++
 topi/python/topi/generic/nn.py                     |  17 ++
 topi/python/topi/nn/__init__.py                    |   1 +
 topi/python/topi/nn/correlation.py                 | 116 ++++++++++++++
 topi/python/topi/testing/__init__.py               |   1 +
 .../python/topi/testing/correlation_nchw_python.py | 103 ++++++++++++
 topi/tests/python/test_topi_correlation.py         |  93 +++++++++++
 18 files changed, 894 insertions(+), 1 deletion(-)

diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index a9c3059..dcb4cb6 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -1203,6 +1203,36 @@ struct SubPixelAttrs : public tvm::AttrsNode<SubPixelAttrs> {
   }
 };  // struct SubPixelAttrs
 
+/*! \brief Attributes used in correlation operators */
+struct CorrelationAttrs : public tvm::AttrsNode<CorrelationAttrs> {
+  int kernel_size;
+  int max_displacement;
+  int stride1;
+  int stride2;
+  Array<IndexExpr> padding;
+  bool is_multiply;
+  String layout;
+
+  TVM_DECLARE_ATTRS(CorrelationAttrs, "relay.attrs.CorrelationAttrs") {
+    TVM_ATTR_FIELD(kernel_size)
+        .describe("Kernel size for correlation, must be an odd number.")
+        .set_default(1);
+    TVM_ATTR_FIELD(max_displacement).describe("Max displacement of Correlation.").set_default(1);
+    TVM_ATTR_FIELD(stride1).describe("Stride for data1.").set_default(1);
+    TVM_ATTR_FIELD(stride2).describe("Stride for data2.").set_default(1);
+    TVM_ATTR_FIELD(padding)
+        .describe("Padding for data1 and data2.")
+        .set_default(Array<IndexExpr>{0, 0});
+    TVM_ATTR_FIELD(is_multiply)
+        .describe("Operation type is either multiplication or substraction.")
+        .set_default(true);
+    TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
+        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
+        "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
+        "dimensions respectively.");
+  }
+};  // struct CorrelationAttrs
+
 }  // namespace relay
 }  // namespace tvm
 #endif  // TVM_RELAY_ATTRS_NN_H_
diff --git a/python/tvm/relay/frontend/mxnet.py b/python/tvm/relay/frontend/mxnet.py
index edf6680..9f97ee9 100644
--- a/python/tvm/relay/frontend/mxnet.py
+++ b/python/tvm/relay/frontend/mxnet.py
@@ -1133,6 +1133,19 @@ def _mx_space_to_depth(inputs, attrs):
     return _op.nn.space_to_depth(*inputs, **new_attrs)
 
 
+def _mx_correlation(inputs, attrs):
+    assert len(inputs) == 2
+    new_attrs = {}
+    new_attrs["kernel_size"] = attrs.get_int("kernel_size", 1)
+    new_attrs["max_displacement"] = attrs.get_int("max_displacement", 1)
+    new_attrs["stride1"] = attrs.get_int("stride1", 1)
+    new_attrs["stride2"] = attrs.get_int("stride2", 1)
+    new_attrs["padding"] = attrs.get_int("pad_size", 0)
+    new_attrs["is_multiply"] = attrs.get_bool("is_multiply", True)
+    new_attrs["layout"] = "NCHW"
+    return _op.nn.correlation(*inputs, **new_attrs)
+
+
 def _mx_contrib_fifo_buffer(inputs, attrs):
     new_attrs = {}
     new_attrs['axis'] = attrs.get_int('axis')
@@ -1971,6 +1984,7 @@ _convert_map = {
     "one_hot"           : _mx_one_hot,
     "depth_to_space"    : _mx_depth_to_space,
     "space_to_depth"    : _mx_space_to_depth,
+    "Correlation"       : _mx_correlation,
     # vision
     "_contrib_BilinearResize2D" : _mx_resize,
     "_contrib_MultiBoxPrior" : _mx_multibox_prior,
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 9a9bfe0..0633451 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -563,6 +563,11 @@ reg.register_injective_schedule("nn.space_to_depth")
 reg.register_pattern("nn.space_to_depth", OpPattern.INJECTIVE)
 
 
+# correlation
+reg.register_strategy("nn.correlation", strategy.correlation_strategy)
+reg.register_pattern("nn.correlation", OpPattern.OUT_ELEMWISE_FUSABLE)
+
+
 #####################
 #  Shape functions  #
 #####################
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 96708c9..0f1f158 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2761,3 +2761,86 @@ def global_avg_pool3d(data,
     """
     output_size = [1, 1, 1]
     return _make.adaptive_avg_pool3d(data, output_size, layout)
+
+
+def correlation(data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
+                is_multiply, layout):
+    r"""Applies correlation to inputs.
+
+    The correlation layer performs multiplicative patch comparisons between two feature maps.
+    Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and
+    :math:`c` being their width, height, and number of channels, the correlation layer lets the
+    network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.
+
+    For now we consider only a single comparison of two patches. The 'correlation' of two patches
+    centered at :math:`x_{1}` in the first map and :math:`x_{2}` in the second map is then defined
+    as:
+
+    .. math::
+
+        c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>
+
+    for a square patch of size :math:`K:=2k+1`.
+
+    Note that the equation above is identical to one step of a convolution in neural networks, but
+    instead of convolving data with a filter, it convolves data with other data. For this
+    reason, it has no training weights.
+
+    Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all
+    patch combinations involves :math:`w^{2}*h^{2}` such computations.
+
+    Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes
+    correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
+    by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize
+    :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
+    centered around :math:`x_{1}`.
+
+    The final output is defined by the following expression:
+
+    .. math::
+
+        out[n, q, i, j] = c(x_{i, j}, x_{q})
+
+    where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q`
+    denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
+
+    Parameters
+    ----------
+    data1 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    data2 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    kernel_size: int
+        Kernel size for correlation, must be an odd number
+
+    max_displacement: int
+        Max displacement of Correlation
+
+    stride1: int
+        Stride for data1
+
+    stride2: int
+        Stride for data2 within the neighborhood centered around data1
+
+    padding : int or a list/tuple of 2 or 4 ints
+        Padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    is_multiply: bool
+        operation type is either multiplication or subtraction
+
+    layout: str
+        layout of data1, data2 and the output
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    if isinstance(padding, int):
+        padding = (padding, padding)
+    return _make.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2,
+                             padding, is_multiply, layout)
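
As a usage sketch of the new Python API (a minimal example; the shape and
attribute values match the first correlation test case added below):

    from tvm import relay

    data1 = relay.var("data1", shape=(1, 3, 10, 10))
    data2 = relay.var("data2", shape=(1, 3, 10, 10))
    # kernel_size=1, max_displacement=4, stride1=1, stride2=1, padding=4,
    # is_multiply=True, layout="NCHW"
    out = relay.nn.correlation(data1, data2, 1, 4, 1, 1, 4, True, "NCHW")
    func = relay.Function([data1, data2], out)
    # Type inference gives a (1, 81, 10, 10) output: with d=4 and stride2=1
    # the displacement grid has (2*4+1)**2 = 81 channels.
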
diff --git a/python/tvm/relay/op/op_attrs.py b/python/tvm/relay/op/op_attrs.py
index fee213c..0686125 100644
--- a/python/tvm/relay/op/op_attrs.py
+++ b/python/tvm/relay/op/op_attrs.py
@@ -357,3 +357,8 @@ class DilateAttrs(Attrs):
 @tvm._ffi.register_object("relay.attrs.SubPixelAttrs")
 class SubPixelAttrs(Attrs):
     """Attributes used in depth to space and space to depth operators"""
+
+
+@tvm._ffi.register_object("relay.attrs.CorrelationAttrs")
+class CorrelationAttrs(Attrs):
+    """Attributes used in correlation operators"""
diff --git a/python/tvm/relay/op/strategy/cuda.py b/python/tvm/relay/op/strategy/cuda.py
index 83e4e40..59d4ec9 100644
--- a/python/tvm/relay/op/strategy/cuda.py
+++ b/python/tvm/relay/op/strategy/cuda.py
@@ -590,3 +590,15 @@ def winograd_judge(N, H, W, KH, KW, CI, CO, padding, stride_h,
                               stride_h == 1 and stride_w == 1 and \
                               dilation_h == 1 and dilation_w == 1
     return judge_winograd_tensorcore, judge_winograd_shape
+
+@correlation_strategy.register(["cuda", "gpu"])
+def correlation_strategy_cuda(attrs, inputs, out_type, target):
+    """correlation cuda strategy"""
+    layout = attrs.layout
+    assert layout == "NCHW", "Only support NCHW layout"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_correlation(topi.cuda.correlation_nchw),
+        wrap_topi_schedule(topi.cuda.schedule_correlation_nchw),
+        name="correlation.cuda")
+    return strategy
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index c3eadce..6db5b14 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -829,3 +829,30 @@ def bitserial_dense_strategy(attrs, inputs, out_type, target):
         wrap_topi_schedule(topi.generic.schedule_bitserial_dense),
         name="bitserial_dense.generic")
     return strategy
+
+# correlation
+def wrap_compute_correlation(topi_compute):
+    """wrap correlation topi compute"""
+    def _compute_correlation(attrs, inputs, out_type):
+        kernel_size = attrs.kernel_size
+        max_displacement = attrs.max_displacement
+        stride1 = attrs.stride1
+        stride2 = attrs.stride2
+        padding = get_const_tuple(attrs.padding)
+        is_multiply = attrs.is_multiply
+        return [topi_compute(inputs[0], inputs[1], kernel_size, max_displacement, stride1, stride2,
+                             padding, is_multiply)]
+    return _compute_correlation
+
+@override_native_generic_func("correlation_strategy")
+def correlation_strategy(attrs, inputs, out_type, target):
+    """correlation generic strategy"""
+    logger.warning("correlation is not optimized for this platform.")
+    layout = attrs.layout
+    assert layout == "NCHW", "Only support NCHW layout"
+    strategy = _op.OpStrategy()
+    strategy.add_implementation(
+        wrap_compute_correlation(topi.nn.correlation_nchw),
+        wrap_topi_schedule(topi.generic.schedule_correlation_nchw),
+        name="correlation.generic")
+    return strategy
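
End to end, relay.build selects one of the strategies registered above for
the chosen target. A rough sketch, assuming a plain CPU target (shapes are
illustrative):

    import tvm
    from tvm import relay

    data1 = relay.var("data1", shape=(1, 3, 10, 10))
    data2 = relay.var("data2", shape=(1, 3, 10, 10))
    out = relay.nn.correlation(data1, data2, 1, 4, 1, 1, 4, True, "NCHW")
    mod = tvm.IRModule.from_expr(relay.Function([data1, data2], out))
    # On "llvm" the generic strategy is picked (and its warning is logged);
    # on "cuda" the strategy from cuda.py above would be used instead.
    lib = relay.build(mod, target="llvm")
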
diff --git a/src/relay/op/nn/correlation.cc b/src/relay/op/nn/correlation.cc
new file mode 100644
index 0000000..67f42b7
--- /dev/null
+++ b/src/relay/op/nn/correlation.cc
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file correlation.cc
+ * \brief Correlation operators
+ */
+#include <topi/nn.h>
+#include <tvm/relay/attrs/nn.h>
+#include <tvm/relay/op.h>
+#include <tvm/tir/data_layout.h>
+#include <tvm/tir/op.h>
+
+#include <vector>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relay {
+
+// relay.nn.correlation
+TVM_REGISTER_NODE_TYPE(CorrelationAttrs);
+
+Array<Array<Layout>> CorrelationInferCorrectLayout(const Attrs& attrs,
+                                                   const Array<Layout>& new_in_layouts,
+                                                   const Array<Layout>& old_in_layouts,
+                                                   const Array<tvm::relay::Type>& old_in_types) {
+  const auto* params = attrs.as<CorrelationAttrs>();
+  Layout layout{params->layout};
+  return Array<Array<Layout>>{{layout, layout}, {layout}};
+}
+
+// Positional relay function to create correlation operator
+// used by frontend FFI.
+Expr MakeCorrelation(Expr data1, Expr data2, int kernel_size, int max_displacement, int stride1,
+                     int stride2, Array<IndexExpr> padding, bool is_multiply, String layout) {
+  auto attrs = make_object<CorrelationAttrs>();
+  attrs->kernel_size = kernel_size;
+  attrs->max_displacement = max_displacement;
+  attrs->stride1 = stride1;
+  attrs->stride2 = stride2;
+  attrs->padding = std::move(padding);
+  attrs->is_multiply = is_multiply;
+  attrs->layout = std::move(layout);
+  static const Op& op = Op::Get("nn.correlation");
+  return Call(op, {data1, data2}, Attrs(attrs), {});
+}
+
+bool CorrelationRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+                    const TypeReporter& reporter) {
+  CHECK_EQ(types.size(), 3);
+  const auto* data1 = types[0].as<TensorTypeNode>();
+  const auto* data2 = types[1].as<TensorTypeNode>();
+  if (data1 == nullptr || data2 == nullptr) return false;
+
+  const CorrelationAttrs* param = attrs.as<CorrelationAttrs>();
+  CHECK(param != nullptr);
+  CHECK_EQ(param->layout, "NCHW") << "layout not supported.";
+  IndexExpr pad_h, pad_w;
+  GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+  IndexExpr padded_height = data1->shape[2] + pad_h;
+  IndexExpr padded_width = data2->shape[3] + pad_w;
+  int kernel_radius = (param->kernel_size - 1) / 2;
+  int border_size = param->max_displacement + kernel_radius;
+  int displacement_radius = param->max_displacement / param->stride2;
+  int displacement_size = 2 * displacement_radius + 1;
+  int out_channel = displacement_size * displacement_size;
+  IndexExpr out_height =
+      indexdiv((padded_height - 2 * border_size + param->stride1 - 1), param->stride1);
+  IndexExpr out_width =
+      indexdiv((padded_width - 2 * border_size + param->stride1 - 1), param->stride1);
+  Array<tvm::PrimExpr> oshape{data1->shape[0], out_channel, out_height, out_width};
+  // assign output type
+  reporter->Assign(types[2], TensorType(oshape, data1->dtype));
+  return true;
+}
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.correlation").set_body_typed(MakeCorrelation);
+
+RELAY_REGISTER_OP("nn.correlation")
+    .describe(R"code(Applies correlation to inputs.
+
+The correlation layer performs multiplicative patch comparisons between two feature maps.
+Given two multi-channel feature maps :math:`f_{1}, f_{2}`, with :math:`w`, :math:`h`, and :math:`c` being their width, height, and number of channels,
+the correlation layer lets the network compare each patch from :math:`f_{1}` with each patch from :math:`f_{2}`.
+
+For now we consider only a single comparison of two patches. The 'correlation' of two patches centered at :math:`x_{1}` in the first map and
+:math:`x_{2}` in the second map is then defined as:
+
+.. math::
+   c(x_{1}, x_{2}) = \sum_{o \in [-k,k] \times [-k,k]} <f_{1}(x_{1} + o), f_{2}(x_{2} + o)>
+
+for a square patch of size :math:`K:=2k+1`.
+
+Note that the equation above is identical to one step of a convolution in neural networks, but instead of convolving data with a filter, it convolves data with other
+data. For this reason, it has no training weights.
+
+Computing :math:`c(x_{1}, x_{2})` involves :math:`c * K^{2}` multiplications. Comparing all patch combinations involves :math:`w^{2}*h^{2}` such computations.
+
+Given a maximum displacement :math:`d`, for each location :math:`x_{1}` it computes correlations :math:`c(x_{1}, x_{2})` only in a neighborhood of size :math:`D:=2d+1`,
+by limiting the range of :math:`x_{2}`. We use strides :math:`s_{1}, s_{2}`, to quantize :math:`x_{1}` globally and to quantize :math:`x_{2}` within the neighborhood
+centered around :math:`x_{1}`.
+
+The final output is defined by the following expression:
+
+.. math::
+  out[n, q, i, j] = c(x_{i, j}, x_{q})
+
+where :math:`i` and :math:`j` enumerate spatial locations in :math:`f_{1}`, and :math:`q` denotes the :math:`q^{th}` neighborhood of :math:`x_{i,j}`.
+)code" TVM_ADD_FILELINE)
+    .set_attrs_type<CorrelationAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data1", "Tensor", "Input data1 to the correlation.")
+    .add_argument("data2", "Tensor", "Input data2 to the correlation.")
+    .set_support_level(2)
+    .set_attr<FInferCorrectLayout>("FInferCorrectLayout", CorrelationInferCorrectLayout)
+    .add_type_rel("Correlation", CorrelationRel);
+
+}  // namespace relay
+}  // namespace tvm
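
The shape arithmetic in CorrelationRel is easy to check by hand; here is a
hedged Python transcription (the helper name is ours, and symmetric integer
padding is assumed, matching what the Python API passes down):

    def correlation_out_shape(data_shape, kernel_size, max_displacement,
                              stride1, stride2, pad):
        # Mirrors CorrelationRel for NCHW inputs.
        batch, _, height, width = data_shape
        padded_h, padded_w = height + 2 * pad, width + 2 * pad
        border = max_displacement + (kernel_size - 1) // 2
        displacement_radius = max_displacement // stride2
        out_channel = (2 * displacement_radius + 1) ** 2
        out_h = (padded_h - 2 * border + stride1 - 1) // stride1
        out_w = (padded_w - 2 * border + stride1 - 1) // stride1
        return (batch, out_channel, out_h, out_w)

    # First frontend test case below: (1, 3, 10, 10), k=1, d=4, s1=s2=1, pad=4
    assert correlation_out_shape((1, 3, 10, 10), 1, 4, 1, 1, 4) == (1, 81, 10, 10)
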
diff --git a/tests/python/frontend/mxnet/test_forward.py b/tests/python/frontend/mxnet/test_forward.py
index 6e8acde..99fc6c3 100644
--- a/tests/python/frontend/mxnet/test_forward.py
+++ b/tests/python/frontend/mxnet/test_forward.py
@@ -1114,6 +1114,38 @@ def test_forward_space_to_depth():
     verify((1, 1, 9, 9), 3)
 
 
+def test_forward_correlation():
+    def verify(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size,
+               is_multiply):
+        data1 = np.random.uniform(size=data_shape).astype("float32")
+        data2 = np.random.uniform(size=data_shape).astype("float32")
+        ref_res = mx.nd.Correlation(data1=mx.nd.array(data1), data2=mx.nd.array(data2),
+                                    kernel_size=kernel_size, max_displacement=max_displacement,
+                                    stride1=stride1, stride2=stride2, pad_size=pad_size,
+                                    is_multiply=is_multiply)
+        mx_sym = mx.sym.Correlation(data1=mx.sym.var('data1'), data2=mx.sym.var('data2'),
+                                    kernel_size=kernel_size, max_displacement=max_displacement,
+                                    stride1=stride1, stride2=stride2, pad_size=pad_size,
+                                    is_multiply=is_multiply)
+        shape_dict = {"data1": data1.shape, "data2": data2.shape}
+        mod, _ = relay.frontend.from_mxnet(mx_sym, shape_dict)
+        for target, ctx in ctx_list():
+            for kind in ["graph", "debug"]:
+                intrp = relay.create_executor(kind, mod=mod, ctx=ctx, target=target)
+                op_res = intrp.evaluate()(data1, data2)
+                tvm.testing.assert_allclose(op_res.asnumpy(), ref_res.asnumpy(), rtol=1e-3, atol=1e-5)
+
+    verify((1, 3, 10, 10), kernel_size=1, max_displacement=4, stride1=1, stride2=1, pad_size=4, is_multiply=False)
+    verify((5, 1, 15, 15), kernel_size=1, max_displacement=5, stride1=1, stride2=1, pad_size=5, is_multiply=False)
+    verify((5, 1, 15, 15), kernel_size=1, max_displacement=5, stride1=1, stride2=1, pad_size=5, is_multiply=True)
+    verify((5, 1, 15, 15), kernel_size=1, max_displacement=10, stride1=1, stride2=2, pad_size=10, is_multiply=True)
+    verify((5, 1, 4, 4), kernel_size=3, max_displacement=1, stride1=1, stride2=1, pad_size=2, is_multiply=True)
+    verify((5, 1, 4, 4), kernel_size=3, max_displacement=1, stride1=2, stride2=1, pad_size=2, is_multiply=True)
+    verify((5, 1, 4, 4), kernel_size=3, max_displacement=1, stride1=2, stride2=1, pad_size=2, is_multiply=False)
+    verify((5, 1, 6, 4), kernel_size=3, max_displacement=1, stride1=2, stride2=1, pad_size=2, is_multiply=False)
+    verify((5, 1, 11, 11), kernel_size=5, max_displacement=1, stride1=1, stride2=1, pad_size=2, is_multiply=False)
+
+
 if __name__ == '__main__':
     test_forward_mlp()
     test_forward_vgg()
@@ -1177,4 +1209,5 @@ if __name__ == '__main__':
     test_forward_cond()
     test_forward_make_loss()
     test_forward_unravel_index()
-    test_forward_swap_axis()
\ No newline at end of file
+    test_forward_swap_axis()
+    test_forward_correlation()
diff --git a/tests/python/relay/test_op_level2.py b/tests/python/relay/test_op_level2.py
index cf9d2d4..68eced3 100644
--- a/tests/python/relay/test_op_level2.py
+++ b/tests/python/relay/test_op_level2.py
@@ -1342,6 +1342,45 @@ def test_bitpack_infer_type():
 # TODO(@jwfromm): Need to add bitserial_conv2d & bitpack run test cases
 
 
+def test_correlation():
+    def _test_correlation(data_shape, kernel_size, max_displacement, stride1, stride2, padding, is_multiply, dtype='float32'):
+        data1 = relay.var("data1", relay.ty.TensorType(data_shape, dtype))
+        data2 = relay.var("data2", relay.ty.TensorType(data_shape, dtype))
+        y = relay.nn.correlation(data1, data2, kernel_size, max_displacement, stride1, stride2,
+                                 padding, is_multiply, "NCHW")
+        yy = run_infer_type(y)
+        padded_height = data_shape[2] + 2 * padding
+        padded_width = data_shape[3] + 2 * padding
+        border_size = (kernel_size - 1) // 2 + max_displacement
+        displacement_radius = max_displacement // stride2
+        out_channel = ((2 * displacement_radius) + 1) ** 2
+        out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1
+        out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1
+        assert yy.checked_type == relay.TensorType(
+            (data_shape[0], out_channel, out_height, out_width), dtype
+        )
+        func = relay.Function([data1, data2], y)
+        data1_np = np.random.uniform(size=data_shape).astype(dtype)
+        data2_np = np.random.uniform(size=data_shape).astype(dtype)
+        ref_res = topi.testing.correlation_nchw_python(data1_np, data2_np, kernel_size, max_displacement, stride1, stride2, padding, is_multiply)
+
+        for target, ctx in ctx_list():
+            intrp1 = relay.create_executor("graph", ctx=ctx, target=target)
+            op_res1 = intrp1.evaluate(func)(data1_np, data2_np)
+            tvm.testing.assert_allclose(op_res1.asnumpy(), ref_res, rtol=1e-5, atol=1e-5)
+
+    _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=4,
+                      stride1=1, stride2=1, padding=4, is_multiply=True)
+    _test_correlation((1, 3, 10, 10), kernel_size=1, max_displacement=5,
+                      stride1=1, stride2=1, padding=5, is_multiply=True)
+    _test_correlation((5, 1, 4, 4), kernel_size=3, max_displacement=1,
+                      stride1=2, stride2=1, padding=2, is_multiply=True)
+    _test_correlation((5, 1, 6, 4), kernel_size=3, max_displacement=1,
+                      stride1=2, stride2=2, padding=2, is_multiply=False)
+    _test_correlation((5, 1, 11, 11), kernel_size=5, max_displacement=1,
+                      stride1=1, stride2=1, padding=2, is_multiply=False)
+
+
 if __name__ == "__main__":
     test_pool1d()
     test_pool2d()
@@ -1374,3 +1413,4 @@ if __name__ == "__main__":
     test_upsampling3d()
     test_conv2d_int8_intrinsics()
     test_depthwise_conv2d_int8()
+    test_correlation()
diff --git a/topi/python/topi/cuda/__init__.py b/topi/python/topi/cuda/__init__.py
index 8ccd80f..ba5c54b 100644
--- a/topi/python/topi/cuda/__init__.py
+++ b/topi/python/topi/cuda/__init__.py
@@ -49,3 +49,4 @@ from .sort import *
 from .conv2d_nhwc_tensorcore import *
 from .conv3d_ndhwc_tensorcore import *
 from .dense_tensorcore import *
+from .correlation import *
diff --git a/topi/python/topi/cuda/correlation.py b/topi/python/topi/cuda/correlation.py
new file mode 100644
index 0000000..a383e4e
--- /dev/null
+++ b/topi/python/topi/cuda/correlation.py
@@ -0,0 +1,176 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Correlation operators on CUDA"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from .. import nn
+from ..util import traverse_inline
+
+
+@autotvm.register_topi_compute("correlation_nchw.cuda")
+def correlation_nchw(cfg, data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
+                     is_multiply):
+    """Correlation operator in NCHW layout.
+
+    Parameters
+    ----------
+    data1 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    data2 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    kernel_size: int
+        Kernel size for correlation, must be an odd number
+
+    max_displacement: int
+        Max displacement of Correlation
+
+    stride1: int
+        Stride for data1
+
+    stride2: int
+        Stride for data2 within the neighborhood centered around data1
+
+    padding : int or a list/tuple of 2 or 4 ints
+        Padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    is_multiply: bool
+        operation type is either multiplication or subtraction
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    # pylint: disable=unused-argument
+    return nn.correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2,
+                               padding, is_multiply)
+
+
+def _schedule_correlation_nchw(cfg, s, correlation):
+    """Schedule correlation_nchw direct template"""
+    # pylint: disable=invalid-name
+    ##### space definition begin #####
+    n, f, y, x = s[correlation].op.axis
+    rc, ry, rx = s[correlation].op.reduce_axis
+    cfg.define_split("tile_f", f, num_outputs=4)
+    cfg.define_split("tile_y", y, num_outputs=4)
+    cfg.define_split("tile_x", x, num_outputs=4)
+    cfg.define_split("tile_rc", rc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+
+    target = tvm.target.Target.current()
+    if target.target_name in ['nvptx', 'rocm']:
+        cfg.define_knob("unroll_explicit", [1])
+    else:
+        cfg.define_knob("unroll_explicit", [0, 1])
+
+    ##### space definition end #####
+
+    padded_data1, padded_data2 = s[correlation].op.input_tensors
+    s[padded_data1].compute_inline()
+    s[padded_data2].compute_inline()
+
+    # create cache stage
+    s[correlation].set_scope('local')
+    AA = s.cache_read(padded_data1, 'shared', [correlation])
+    BB = s.cache_read(padded_data2, 'shared', [correlation])
+
+    output = s.outputs[0].output(0)
+
+    # tile and bind spatial axes
+    n, f, y, x = s[output].op.axis
+    kernel_scope, n = s[output].split(n, nparts=1)
+
+    bf, vf, tf, fi = cfg["tile_f"].apply(s, output, f)
+    by, vy, ty, yi = cfg["tile_y"].apply(s, output, y)
+    bx, vx, tx, xi = cfg["tile_x"].apply(s, output, x)
+
+    bf = s[output].fuse(n, bf)
+    s[output].bind(bf, te.thread_axis("blockIdx.z"))
+    s[output].bind(by, te.thread_axis("blockIdx.y"))
+    s[output].bind(bx, te.thread_axis("blockIdx.x"))
+    s[output].bind(vf, te.thread_axis("vthread"))
+    s[output].bind(vy, te.thread_axis("vthread"))
+    s[output].bind(vx, te.thread_axis("vthread"))
+    s[output].bind(tf, te.thread_axis("threadIdx.z"))
+    s[output].bind(ty, te.thread_axis("threadIdx.y"))
+    s[output].bind(tx, te.thread_axis("threadIdx.x"))
+    s[output].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fi, yi, xi)
+    s[correlation].compute_at(s[output], tx)
+
+    # tile reduction axes
+    n, f, y, x = s[correlation].op.axis
+    rc, ry, rx = s[correlation].op.reduce_axis
+    rco, rci = cfg['tile_rc'].apply(s, correlation, rc)
+    ryo, ryi = cfg['tile_ry'].apply(s, correlation, ry)
+    rxo, rxi = cfg['tile_rx'].apply(s, correlation, rx)
+    s[correlation].reorder(rco, ryo, rxo, rci, ryi, rxi, n, f, y, x)
+
+    s[AA].compute_at(s[correlation], rxo)
+    s[BB].compute_at(s[correlation], rxo)
+
+    # cooperative fetching
+    for load in [AA, BB]:
+        n, f, y, x = s[load].op.axis
+        fused = s[load].fuse(n, f, y, x)
+        tz, fused = s[load].split(fused, nparts=cfg["tile_f"].size[2])
+        ty, fused = s[load].split(fused, nparts=cfg["tile_y"].size[2])
+        tx, fused = s[load].split(fused, nparts=cfg["tile_x"].size[2])
+        s[load].bind(tz, te.thread_axis("threadIdx.z"))
+        s[load].bind(ty, te.thread_axis("threadIdx.y"))
+        s[load].bind(tx, te.thread_axis("threadIdx.x"))
+
+    # unroll
+    s[output].pragma(kernel_scope, 'auto_unroll_max_step', cfg['auto_unroll_max_step'].val)
+    s[output].pragma(kernel_scope, 'unroll_explicit', cfg['unroll_explicit'].val)
+
+
+@autotvm.register_topi_schedule("correlation_nchw.cuda")
+def schedule_correlation_nchw(cfg, outs):
+    """schedule of correlation_nchw for cuda gpu
+
+    Parameters
+    ----------
+    cfg: ConfigEntity
+        The config for this template
+
+    outs: Array of Tensor
+        The computation graph description of correlation
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for correlation.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == 'correlation_nchw':
+            _schedule_correlation_nchw(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
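
Because schedule_correlation_nchw is registered as an AutoTVM template, it
can be tuned like the other direct CUDA templates. A sketch, assuming the
autotvm.task.create API of this TVM version and the task name registered
above (shapes are illustrative):

    from tvm import autotvm, te

    data1 = te.placeholder((1, 3, 10, 10), name="data1")
    data2 = te.placeholder((1, 3, 10, 10), name="data2")
    # args follow correlation_nchw(cfg, data1, data2, kernel_size,
    # max_displacement, stride1, stride2, padding, is_multiply)
    task = autotvm.task.create("correlation_nchw.cuda",
                               args=(data1, data2, 1, 4, 1, 1, 4, True),
                               target="cuda")
    print(task.config_space)  # the tile/unroll knobs defined in the template
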
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 2be4bbb..d0c165d 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -672,3 +672,20 @@ def schedule_batch_matmul(outs):
         The computation schedule for the op.
     """
     return _default_schedule(outs, False)
+
+
+def schedule_correlation_nchw(outs):
+    """Schedule for correlation_nchw
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+          The computation graph description of correlation_nchw
+          in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    return _default_schedule(outs, False)
diff --git a/topi/python/topi/nn/__init__.py b/topi/python/topi/nn/__init__.py
index bd806b9..3830bd0 100644
--- a/topi/python/topi/nn/__init__.py
+++ b/topi/python/topi/nn/__init__.py
@@ -22,6 +22,7 @@ from __future__ import absolute_import as _abs
 from .conv1d import *
 from .conv2d import *
 from .conv3d import *
+from .correlation import *
 from .deformable_conv2d import *
 from .depthwise_conv2d import *
 from .elemwise import *
diff --git a/topi/python/topi/nn/correlation.py b/topi/python/topi/nn/correlation.py
new file mode 100644
index 0000000..94aea55
--- /dev/null
+++ b/topi/python/topi/nn/correlation.py
@@ -0,0 +1,116 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Correlation operators"""
+from tvm import te
+
+from .pad import pad
+from ..util import get_const_tuple
+
+
+def correlation_nchw(data1, data2, kernel_size, max_displacement, stride1, stride2, padding,
+                     is_multiply):
+    """Correlation operator in NCHW layout.
+
+    Parameters
+    ----------
+    data1 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    data2 : tvm.te.Tensor
+        4-D with shape [batch, channel, height, width]
+
+    kernel_size: int
+        Kernel size for correlation, must be an odd number
+
+    max_displacement: int
+        Max displacement of Correlation
+
+    stride1: int
+        Stride for data1
+
+    stride2: int
+        Stride for data2 within the neighborhood centered around data1
+
+    padding : int or a list/tuple of 2 or 4 ints
+        Padding size, or
+        [pad_height, pad_width] for 2 ints, or
+        [pad_top, pad_left, pad_bottom, pad_right] for 4 ints
+
+    is_multiply: bool
+        operation type is either multiplication or subtraction
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    # pylint: disable=unnecessary-lambda, invalid-name
+    data_shape = get_const_tuple(data1.shape)
+    assert get_const_tuple(data2.shape) == data_shape, "data1 and data2 should have the same shape"
+    assert kernel_size > 0 and kernel_size % 2, "kernel_size should be a positive odd number"
+    if isinstance(padding, (tuple, list)):
+        if len(padding) == 2:
+            pad_before_h = pad_after_h = padding[0]
+            pad_before_w = pad_after_w = padding[1]
+        elif len(padding) == 4:
+            pad_before_h, pad_before_w, pad_after_h, pad_after_w = padding
+        else:
+            raise ValueError("invalid padding")
+    elif isinstance(padding, int):
+        pad_before_h = pad_after_h = pad_before_w = pad_after_w = padding
+    else:
+        raise ValueError("invalid padding")
+    pad_before = [0, 0, pad_before_h, pad_before_w]
+    pad_after = [0, 0, pad_after_h, pad_after_w]
+    padded_data1 = pad(data1, pad_before, pad_after)
+    padded_data2 = pad(data2, pad_before, pad_after)
+
+    batch, channel, height, width = data_shape
+
+    kernel_radius = (kernel_size - 1) // 2
+    border_size = max_displacement + kernel_radius
+    displacement_radius = max_displacement // stride2
+    displacement_size = 2 * displacement_radius + 1
+
+    padded_width = width + pad_before_w + pad_after_w
+    padded_height = height + pad_before_h + pad_after_h
+    out_channel = displacement_size * displacement_size
+    out_height = (padded_height - 2 * border_size + stride1 - 1) // stride1
+    out_width = (padded_width - 2 * border_size + stride1 - 1) // stride1
+
+    rc = te.reduce_axis((0, channel), name='rc')
+    ry = te.reduce_axis((0, kernel_size), name='ry')
+    rx = te.reduce_axis((0, kernel_size), name='rx')
+
+    if is_multiply:
+        corr_func = lambda x, y: x * y
+    else:
+        corr_func = lambda x, y: te.abs(x - y)
+
+    def _compute_correlation(n, q, i, j):
+        # location in data1
+        y1 = i * stride1 + max_displacement
+        x1 = j * stride1 + max_displacement
+        # location in data2
+        y2 = y1 + (te.indexdiv(q, displacement_size) - displacement_radius) * stride2
+        x2 = x1 + (te.indexmod(q, displacement_size) - displacement_radius) * stride2
+        return te.sum(corr_func(padded_data1[n, rc, y1 + ry, x1 + rx],
+                                padded_data2[n, rc, y2 + ry, x2 + rx]), axis=[rc, ry, rx])
+
+    correlation = te.compute((batch, out_channel, out_height, out_width), lambda n, q, i, j:
+                             _compute_correlation(n, q, i, j), tag="correlation_nchw")
+    return correlation / (kernel_size * kernel_size * channel)
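
A minimal te-level sketch of driving this compute directly (a default
schedule is enough to build for CPU; real schedules are exercised in the
topi test added at the end of this commit):

    import tvm
    from tvm import te
    import topi

    data1 = te.placeholder((1, 3, 10, 10), name="data1")
    data2 = te.placeholder((1, 3, 10, 10), name="data2")
    out = topi.nn.correlation_nchw(data1, data2, kernel_size=1, max_displacement=4,
                                   stride1=1, stride2=1, padding=4, is_multiply=True)
    s = te.create_schedule(out.op)
    f = tvm.build(s, [data1, data2, out], "llvm")
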
diff --git a/topi/python/topi/testing/__init__.py b/topi/python/topi/testing/__init__.py
index 36c460e..511fe16 100644
--- a/topi/python/topi/testing/__init__.py
+++ b/topi/python/topi/testing/__init__.py
@@ -29,6 +29,7 @@ from .conv3d_ncdhw_python import conv3d_ncdhw_python
 from .conv3d_ndhwc_python import conv3d_ndhwc_python
 from .conv2d_transpose_python import conv2d_transpose_nchw_python, conv2d_transpose_nhwc_python
 from .conv1d_transpose_ncw_python import conv1d_transpose_ncw_python
+from .correlation_nchw_python import correlation_nchw_python
 from .deformable_conv2d_nchw_python import deformable_conv2d_nchw_python
 from .depthwise_conv2d_python import depthwise_conv2d_python_nchw, depthwise_conv2d_python_nhwc
 from .dilate_python import dilate_python
diff --git a/topi/python/topi/testing/correlation_nchw_python.py b/topi/python/topi/testing/correlation_nchw_python.py
new file mode 100644
index 0000000..f053656
--- /dev/null
+++ b/topi/python/topi/testing/correlation_nchw_python.py
@@ -0,0 +1,103 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, line-too-long, unused-variable, too-many-locals
+"""Convolution 3D in python"""
+import numpy as np
+
+
+def correlation_nchw_python(data1, data2, kernel_size, max_displacement, stride1, stride2, padding, is_multiply):
+    """Correlation operator in NCHW layout.
+
+    Parameters
+    ----------
+    data1 : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    data2 : numpy.ndarray
+        4-D with shape [batch, in_channel, in_height, in_width]
+
+    kernel_size: int
+        Kernel size for correlation, must be an odd number
+
+    max_displacement: int
+        Max displacement of Correlation
+
+    stride1: int
+        Stride for data1
+
+    stride2: int
+        Stride for data2 within the neighborhood centered around data1
+
+    padding: int
+        Padding for correlation
+
+    is_multiply: bool
+        operation type is either multiplication or subtraction
+
+    Returns
+    -------
+    c_np : np.ndarray
+        4-D with shape [batch, out_channel, out_height, out_width]
+    """
+    # compute output's dimension
+    pad_data_height = data1.shape[2] + 2 * padding
+    pad_data_width = data1.shape[3] + 2 * padding
+    kernel_radius = (kernel_size - 1) // 2
+    border_size = max_displacement + kernel_radius
+    out_width = (pad_data_width - border_size * 2) // stride1
+    out_height = (pad_data_height - border_size * 2) // stride1
+    neighborhood_grid_radius = max_displacement // stride2
+    neighborhood_grid_width = neighborhood_grid_radius * 2 + 1
+    out_channel = neighborhood_grid_width * neighborhood_grid_width
+
+    out = np.zeros((data1.shape[0], out_channel, out_height, out_width))
+    pad_data1 = np.zeros((data1.shape[0], data1.shape[1],
+                          pad_data_height, pad_data_width))
+    pad_data2 = np.zeros((data1.shape[0], data1.shape[1],
+                          pad_data_height, pad_data_width))
+
+    pad_data1[:, :, padding:padding + data1.shape[2],
+              padding:padding + data1.shape[3]] = data1[:, :, :, :]
+    pad_data2[:, :, padding:padding + data2.shape[2],
+              padding:padding + data2.shape[3]] = data2[:, :, :, :]
+
+    if is_multiply:
+        corr_func = lambda x, y: x * y
+    else:
+        corr_func = lambda x, y: abs(x - y)
+
+    # pylint: disable=too-many-nested-blocks
+    for i in range(out_height):
+        for j in range(out_width):
+            for nbatch in range(data1.shape[0]):
+                # x1, y1 is the location in data1; i, j is the location in the output
+                x1 = j * stride1 + max_displacement
+                y1 = i * stride1 + max_displacement
+
+                for q in range(out_channel):
+                    # location in data2
+                    x2 = x1 + (q % neighborhood_grid_width - neighborhood_grid_radius) * stride2
+                    y2 = y1 + (q // neighborhood_grid_width - neighborhood_grid_radius) * stride2
+
+                    for h in range(kernel_size):
+                        for w in range(kernel_size):
+                            for channel in range(data1.shape[1]):
+                                out[nbatch, q, i, j] += corr_func(pad_data1[nbatch, channel, y1 + h, x1 + w],
+                                                                  pad_data2[nbatch, channel, y2 + h, x2 + w])
+
+    out /= float(kernel_size ** 2 * data1.shape[1])
+    return out
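
A quick sanity check of this reference (ours, not part of the commit): with
a 1x1 kernel, zero displacement and no padding, the single output channel
reduces to the channel-mean of the elementwise product.

    import numpy as np

    a = np.random.uniform(size=(1, 3, 4, 4)).astype("float32")
    b = np.random.uniform(size=(1, 3, 4, 4)).astype("float32")
    out = correlation_nchw_python(a, b, kernel_size=1, max_displacement=0,
                                  stride1=1, stride2=1, padding=0, is_multiply=True)
    np.testing.assert_allclose(out[:, 0], (a * b).mean(axis=1), rtol=1e-5)
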
diff --git a/topi/tests/python/test_topi_correlation.py b/topi/tests/python/test_topi_correlation.py
new file mode 100644
index 0000000..663564f
--- /dev/null
+++ b/topi/tests/python/test_topi_correlation.py
@@ -0,0 +1,93 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""test of correlation operator in NCHW layout"""
+import numpy as np
+import tvm
+from tvm import te
+from tvm import autotvm
+import topi
+import topi.testing
+from tvm.contrib.pickle_memoize import memoize
+from topi.util import get_const_tuple
+
+from common import get_all_backend
+
+
+_correlation_implement = {
+    "generic": (topi.nn.correlation_nchw, 
topi.generic.schedule_correlation_nchw),
+    "cuda": (topi.cuda.correlation_nchw, topi.cuda.schedule_correlation_nchw),
+}
+
+
+def verify_correlation_nchw(data_shape, kernel_size, max_displacement, stride1, stride2, pad_size,
+                            is_multiply):
+    print("Workload: (%d, %d, %d, %d, %d, %d, %d, %d, %d, %d)" % (data_shape[0], data_shape[1], data_shape[2], data_shape[3],
+                                                                  kernel_size, max_displacement, stride1, stride2, pad_size,
+                                                                  is_multiply))
+
+    A = te.placeholder(data_shape, name='data1')
+    B = te.placeholder(data_shape, name='data2')
+    dtype = A.dtype
+
+    @memoize("topi.tests.test_topi_correlation_nchw.verify_correlation_nchw")
+    def get_ref_data():
+        a_np = np.random.uniform(size=data_shape).astype(dtype)
+        b_np = np.random.uniform(size=data_shape).astype(dtype)
+        c_np = topi.testing.correlation_nchw_python(a_np, b_np, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)
+        return a_np, b_np, c_np
+
+    a_np, b_np, c_np = get_ref_data()
+
+    def check_device(device):
+        ctx = tvm.context(device, 0)
+        if not ctx.exist:
+            print("Skip because %s is not enabled" % device)
+            return
+        print("Running on target: %s" % device)
+        fcompute, fschedule = topi.testing.dispatch(
+            device, _correlation_implement)
+        with tvm.target.create(device):
+            C = fcompute(A, B, kernel_size, max_displacement, stride1, stride2, pad_size, is_multiply)
+            s = fschedule([C])
+
+            a = tvm.nd.array(a_np, ctx)
+            b = tvm.nd.array(b_np, ctx)
+            c = tvm.nd.empty(c_np.shape, dtype=dtype, ctx=ctx)
+
+            func = tvm.build(s, [A, B, C], device)
+            func(a, b, c)
+            tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-5)
+
+    for device in get_all_backend():
+        check_device(device)
+
+
+def test_correlation_nchw():
+    verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=4,
+                            stride1=1, stride2=1, pad_size=4, is_multiply=True)
+    verify_correlation_nchw((1, 3, 10, 10), kernel_size=1, max_displacement=5,
+                            stride1=1, stride2=1, pad_size=5, is_multiply=True)
+    verify_correlation_nchw((5, 1, 4, 4), kernel_size=3, max_displacement=1,
+                            stride1=2, stride2=1, pad_size=2, is_multiply=True)
+    verify_correlation_nchw((5, 1, 6, 4), kernel_size=3, max_displacement=1,
+                            stride1=2, stride2=2, pad_size=2, is_multiply=False)
+    verify_correlation_nchw((5, 1, 11, 11), kernel_size=5, max_displacement=1,
+                            stride1=1, stride2=1, pad_size=2, is_multiply=False)
+
+
+if __name__ == "__main__":
+    test_correlation_nchw()
