This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 0be100160f23aed11d5547e91e97d9d474644ed7
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 14:57:05 2023 -0500

    [Unity] Relax op: image (#13994)
    
    This PR is part of the series adding high-level tensor computation
    operators to Relax.
    
    This PR adds the image operators (currently `relax.image.resize2d`).
---
 include/tvm/relax/attrs/image.h                    |  81 +++++++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/{ => image}/__init__.py        |  16 +-
 .../relax/op/{__init__.py => image/_ffi_api.py}    |  15 +-
 python/tvm/relax/op/image/image.py                 | 128 +++++++++++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/image/resize.cc                       | 113 ++++++++++
 src/relax/op/image/resize.h                        |  43 ++++
 tests/python/relax/test_op_image.py                | 245 +++++++++++++++++++++
 .../python/relax/test_tvmscript_parser_op_image.py |  54 +++++
 11 files changed, 678 insertions(+), 25 deletions(-)

diff --git a/include/tvm/relax/attrs/image.h b/include/tvm/relax/attrs/image.h
new file mode 100644
index 0000000000..13463aaa48
--- /dev/null
+++ b/include/tvm/relax/attrs/image.h
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/image.h
+ * \brief Attributes for image operators.
+ */
+#ifndef TVM_RELAX_ATTRS_IMAGE_H_
+#define TVM_RELAX_ATTRS_IMAGE_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Attributes used in image resize2d operator.
+ *
+ * Every argument of relax.image.resize2d other than the data tensor and the output size is
+ * stored here.
+ */
+struct Resize2DAttrs : public tvm::AttrsNode<Resize2DAttrs> {
+  Array<FloatImm> roi;
+  String layout;
+  String method;
+  String coordinate_transformation_mode;
+  String rounding_method;
+  double cubic_alpha;
+  int cubic_exclude;
+  double extrapolation_value;
+  DataType out_dtype;
+
+  // NOTE: adjacent string literals are concatenated by the compiler, so every fragment except
+  // the last one ends with a separating space.
+  TVM_DECLARE_ATTRS(Resize2DAttrs, "relax.attrs.Resize2DAttrs") {
+    TVM_ATTR_FIELD(roi).describe(
+        "Region of Interest for coordinate transformation mode 'tf_crop_and_resize'");
+    TVM_ATTR_FIELD(layout).describe(
+        "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
+        "'N', 'C', 'H', 'W' stand for batch, channel, height, and width "
+        "dimensions respectively. Resize is applied on the 'H' and "
+        "'W' dimensions.");
+    TVM_ATTR_FIELD(method).describe(
+        "Specify the mode to use for scaling. "
+        "nearest_neighbor - Nearest Neighbor, "
+        "linear - Bilinear Interpolation, "
+        "cubic - Bicubic Interpolation");
+    TVM_ATTR_FIELD(coordinate_transformation_mode)
+        .describe(
+            "Describes how to transform the coordinate in the resized tensor "
+            "to the coordinate in the original tensor. "
+            "Refer to the ONNX Resize operator specification for details. "
+            "Available options are half_pixel, align_corners, asymmetric, "
+            "pytorch_half_pixel, tf_half_pixel_for_nn and tf_crop_and_resize.");
+    TVM_ATTR_FIELD(rounding_method)
+        .describe(
+            "indicates how to find the \"nearest\" pixel in nearest_neighbor method. "
+            "Available options are round, floor, and ceil.");
+    TVM_ATTR_FIELD(cubic_alpha).describe("Spline Coefficient for Bicubic Interpolation");
+    TVM_ATTR_FIELD(cubic_exclude)
+        .describe("Flag to exclude exterior of the image during bicubic interpolation");
+    TVM_ATTR_FIELD(extrapolation_value)
+        .describe("Value to return when roi is outside of the image");
+    TVM_ATTR_FIELD(out_dtype).describe(
+        "The dtype of the output tensor. If it is not specified, the output will have the same "
+        "dtype as the input.");
+  }
+};  // struct Resize2DAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_IMAGE_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index da29c3715d..3857351269 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -26,4 +26,5 @@ from .manipulate import *
 from .op_attrs import *
 from .set import *
 from . import builtin
+from . import image
 from . import memory
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/image/__init__.py
similarity index 72%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/image/__init__.py
index da29c3715d..f2552ad6ac 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/image/__init__.py
@@ -14,16 +14,6 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
-
-# Operators
-from .base import *
-from .binary import *
-from .datatype import *
-from .index import *
-from .manipulate import *
-from .op_attrs import *
-from .set import *
-from . import builtin
-from . import memory
+# pylint: disable=wildcard-import
+"""Image operators."""
+from .image import *
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/image/_ffi_api.py
similarity index 72%
copy from python/tvm/relax/op/__init__.py
copy to python/tvm/relax/op/image/_ffi_api.py
index da29c3715d..e666203ae7 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/image/_ffi_api.py
@@ -14,16 +14,7 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-# pylint: disable=wildcard-import, redefined-builtin
-"""Relax core operators."""
+"""Constructor APIs"""
+import tvm._ffi
 
-# Operators
-from .base import *
-from .binary import *
-from .datatype import *
-from .index import *
-from .manipulate import *
-from .op_attrs import *
-from .set import *
-from . import builtin
-from . import memory
+# Bind the packed functions registered under the "relax.op.image" prefix (for example the
+# C++ global "relax.op.image.resize2d") onto this module, so they can be called as
+# `_ffi_api.resize2d(...)`.
+tvm._ffi._init_api("relax.op.image", __name__)
diff --git a/python/tvm/relax/op/image/image.py b/python/tvm/relax/op/image/image.py
new file mode 100644
index 0000000000..562de5021d
--- /dev/null
+++ b/python/tvm/relax/op/image/image.py
@@ -0,0 +1,128 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Image operators."""
+from typing import Optional, Tuple, Union
+
+from tvm import DataType
+from tvm.ir.expr import PrimExpr
+
+from . import _ffi_api
+from ...expr import Expr, ShapeExpr
+
+
+PrimExprLike = Union[int, PrimExpr]
+
+
+def resize2d(
+    data: Expr,
+    size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]],
+    roi: Optional[Union[float, Tuple[float]]] = None,
+    layout: str = "NCHW",
+    method: str = "linear",
+    coordinate_transformation_mode: str = "half_pixel",
+    rounding_method: str = "round",
+    cubic_alpha: float = -0.5,
+    cubic_exclude: int = 0,
+    extrapolation_value: float = 0.0,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    """Image resize2d operator.
+
+    This operator takes data as input and does 2D scaling to the given output size.
+    In the default case, where the data_layout is `NCHW`
+    with data of shape (n, c, h, w)
+    out will have a shape (n, c, size[0], size[1])
+
+    method indicates the algorithm to be used while calculating the out value
+    and method can be one of ("linear", "nearest_neighbor", "cubic")
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data to the operator.
+
+    size: Union[Expr, PrimExprLike, Tuple[PrimExprLike]]
+        The out size to which the image will be resized.
+        If specified as a single int or PrimExpr, it is used for both height and width.
+        If specified as a tuple, it is required to have length either 1 or 2; a
+        length-1 tuple is used for both height and width. Python lists are not
+        converted, so pass a tuple instead.
+        If specified as an Expr, it is required to be a shape with ndim 2.
+
+    roi: Optional[Union[float, Tuple[float]]]
+        The region of interest for cropping the input image. Expected to be of
+        size 4, and format [start_h, start_w, end_h, end_w]. A single number is
+        used for all four entries; None means (0.0, 0.0, 0.0, 0.0).
+        Only used if coordinate_transformation_mode is tf_crop_and_resize.
+
+    layout : str
+        Layout of the input.
+
+    method : str
+        Scale method to use [nearest_neighbor, linear, cubic].
+
+    coordinate_transformation_mode : str
+        Describes how to transform the coordinate in the resized tensor
+        to the coordinate in the original tensor. Definitions can be found
+        in topi/image/resize.py.
+        [half_pixel, align_corners, asymmetric, pytorch_half_pixel,
+        tf_half_pixel_for_nn, and tf_crop_and_resize].
+
+    rounding_method: str
+        indicates how to find the "nearest" pixel in nearest_neighbor method
+        [round, floor, ceil]
+
+    cubic_alpha: float
+        Spline Coefficient for bicubic interpolation
+
+    cubic_exclude: int
+        Flag to exclude exterior of the image during bicubic interpolation
+
+    extrapolation_value: float
+        Fill value to use when roi is outside of the image
+
+    out_dtype : Optional[Union[str, DataType]]
+        The dtype of the output tensor.
+        If it is not specified, the output will have the same dtype as the input.
+
+    Returns
+    -------
+    result: relax.Expr
+        The resized result.
+    """
+    if roi is None:
+        roi = (0.0, 0.0, 0.0, 0.0)  # type: ignore
+    elif isinstance(roi, (int, float)):
+        # Broadcast a scalar roi to all four [start_h, start_w, end_h, end_w] entries.
+        # The C++ side stores roi as Array<FloatImm>, so ints are converted to float here.
+        roi = (float(roi),) * 4  # type: ignore
+
+    # A scalar size applies to both spatial dimensions.
+    if isinstance(size, (int, PrimExpr)):
+        size = (size, size)
+    if isinstance(size, tuple):
+        if len(size) == 1:
+            size = ShapeExpr([size[0], size[0]])
+        else:
+            # Lengths other than 2 are rejected later by struct info inference.
+            size = ShapeExpr(size)
+
+    return _ffi_api.resize2d(  # type: ignore
+        data,
+        size,
+        roi,
+        layout,
+        method,
+        coordinate_transformation_mode,
+        rounding_method,
+        cubic_alpha,
+        cubic_exclude,
+        extrapolation_value,
+        out_dtype,
+    )
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 47c3b28798..fb64443b7e 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -34,6 +34,11 @@ class StridedSliceAttrs(Attrs):
     """Attributes used in strided_slice operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.Resize2DAttrs")
+class Resize2DAttrs(Attrs):
+    """Attributes used in image resize2d operator.
+
+    Python handle for the node registered under "relax.attrs.Resize2DAttrs"; it holds the
+    non-tensor arguments of relax.image.resize2d (roi, layout, method, ...).
+    """
+
+
 @tvm._ffi.register_object("relax.attrs.UniqueAttrs")
 class UniqueAttrs(Attrs):
     """Attributes used for the unique operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 537adec615..22b85f6f40 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -35,6 +35,7 @@ from tvm.relax.op import (
     builtin,
     call_builtin_with_ctx,
     call_tir,
+    image,
     invoke_closure,
     make_closure,
     memory,
@@ -420,6 +421,7 @@ __all__ = [
     "func_ret_struct_info",
     "func_ret_value",
     "function",
+    "image",
     "invoke_closure",
     "make_closure",
     "memory",
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
new file mode 100644
index 0000000000..2711b3cc45
--- /dev/null
+++ b/src/relax/op/image/resize.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file resize.cc
+ * \brief Image resize operators.
+ */
+
+#include "resize.h"
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/* relax.image.resize2d */
+TVM_REGISTER_NODE_TYPE(Resize2DAttrs);
+
+/*!
+ * \brief Build a call to the relax.image.resize2d operator.
+ * Only `data` and `size` become call arguments; all remaining parameters are packed into a
+ * Resize2DAttrs node attached to the call.
+ */
+Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method,
+              String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
+              int cubic_exclude, double extrapolation_value, DataType out_dtype) {
+  ObjectPtr<Resize2DAttrs> attrs = make_object<Resize2DAttrs>();
+  attrs->roi = std::move(roi);
+  attrs->layout = std::move(layout);
+  attrs->method = std::move(method);
+  attrs->coordinate_transformation_mode = std::move(coordinate_transformation_mode);
+  attrs->rounding_method = std::move(rounding_method);
+  attrs->cubic_alpha = cubic_alpha;
+  attrs->cubic_exclude = cubic_exclude;
+  attrs->extrapolation_value = extrapolation_value;
+  attrs->out_dtype = out_dtype;
+
+  static const Op& op = Op::Get("relax.image.resize2d");
+  return Call(op, {std::move(data), std::move(size)}, Attrs(attrs), {});
+}
+
+// Exposed to Python as tvm.relax.op.image._ffi_api.resize2d.
+TVM_REGISTER_GLOBAL("relax.op.image.resize2d").set_body_typed(resize2d);
+
+/*!
+ * \brief Infer the output struct info of relax.image.resize2d.
+ * The output keeps the input layout; its H and W extents are replaced by the requested size.
+ * A full output shape is produced only when the input shape is known and the size is a literal
+ * ShapeExpr. Otherwise only the output rank and dtype are inferred.
+ */
+StructInfo InferStructInfoResize2D(const Call& call, const BlockBuilder& ctx) {
+  // The op is registered with exactly two inputs (data, size), and both are read below. A
+  // one-argument call previously passed this check and then indexed past the end of `args`.
+  if (call->args.size() != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Resize2D expects two arguments, while the given number of arguments is "
+                     << call->args.size());
+  }
+
+  const auto* data_sinfo = GetStructInfoAs<TensorStructInfoNode>(call->args[0]);
+  const auto* size_sinfo = GetStructInfoAs<ShapeStructInfoNode>(call->args[1]);
+  // Non-null only when the size is given as a literal ShapeExpr; the concrete output shape can
+  // only be computed in that case.
+  const auto* size_value = call->args[1].as<ShapeExprNode>();
+  // ReportFatal is relied on not to return: the pointers are dereferenced unconditionally below.
+  if (data_sinfo == nullptr) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Resize2D expects the input data to be a Tensor, while the given data is "
+                     << call->args[0]->GetTypeKey());
+  }
+  if (size_sinfo == nullptr) {
+    ctx->ReportFatal(
+        Diagnostic::Error(call)
+        << "Resize2D expects the given output image size to be a Shape, while the given one is "
+        << call->args[1]->GetTypeKey());
+  }
+  if (size_sinfo->ndim != 2) {
+    ctx->ReportFatal(Diagnostic::Error(call) << "Resize2D expects the given output image size to "
+                                                "be a 2-dim shape, while the given one has ndim "
+                                             << size_sinfo->ndim);
+  }
+
+  const auto* attrs = call->attrs.as<Resize2DAttrs>();
+  // Relate the user-provided layout to NCHW so the H/W axes can be located for any layout.
+  auto [data_layout, data2NCHW] = CheckTensorLayout(call, ctx, attrs->layout,  //
+                                                    /*tgt_layout=*/"NCHW",     //
+                                                    /*tensor_name=*/"data");
+
+  // A void out_dtype means "same dtype as the input".
+  DataType out_dtype = attrs->out_dtype.is_void() ? data_sinfo->dtype : attrs->out_dtype;
+
+  Optional<ShapeExpr> data_shape =
+      CheckNdimPerLayoutAndGetShape(call, ctx, GetRef<TensorStructInfo>(data_sinfo), data_layout);
+  if (!data_shape.defined() || size_value == nullptr) {
+    return TensorStructInfo(out_dtype, data_layout.ndim());
+  }
+
+  // Replace H and W (axes 2 and 3 in NCHW) with the requested size, then map back to the
+  // original layout.
+  Array<PrimExpr> data_NCHW_shape = data2NCHW.ForwardShape(data_shape.value()->values);
+  Array<PrimExpr> out_NCHW_shape(data_NCHW_shape);
+  out_NCHW_shape.Set(2, size_value->values[0]);
+  out_NCHW_shape.Set(3, size_value->values[1]);
+
+  Array<PrimExpr> out_shape = data2NCHW.BackwardShape(out_NCHW_shape);
+  return TensorStructInfo(ShapeExpr(out_shape), out_dtype);
+}
+
+// Register relax.image.resize2d. It takes two inputs in this order: the data tensor and the
+// output size shape. Its attributes are Resize2DAttrs, and its struct info comes from
+// InferStructInfoResize2D.
+TVM_REGISTER_OP("relax.image.resize2d")
+    .set_attrs_type<Resize2DAttrs>()
+    .set_num_inputs(2)
+    .add_argument("data", "Tensor", "The input tensor.")
+    .add_argument("size", "Shape", "The output image shape.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoResize2D);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/image/resize.h b/src/relax/op/image/resize.h
new file mode 100644
index 0000000000..085a1cbc5d
--- /dev/null
+++ b/src/relax/op/image/resize.h
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file resize.h
+ * \brief The functions to make Relax image resize operator calls.
+ */
+
+#ifndef TVM_RELAX_OP_IMAGE_RESIZE_H_
+#define TVM_RELAX_OP_IMAGE_RESIZE_H_
+
+#include <tvm/relax/attrs/image.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Image resize2d operator.
+ * \param data The input image tensor.
+ * \param size The output image size, a 2-dim shape (height, width).
+ * \param roi Region of interest [start_h, start_w, end_h, end_w]; only used when
+ *   coordinate_transformation_mode is "tf_crop_and_resize".
+ * \param layout Layout of the input data, e.g. "NCHW" or "NHWC".
+ * \param method Interpolation method: "nearest_neighbor", "linear" or "cubic".
+ * \param coordinate_transformation_mode How coordinates in the resized tensor map to
+ *   coordinates in the original tensor.
+ * \param rounding_method How nearest_neighbor picks a pixel: "round", "floor" or "ceil".
+ * \param cubic_alpha Spline coefficient for bicubic interpolation.
+ * \param cubic_exclude Flag to exclude the exterior of the image during bicubic interpolation.
+ * \param extrapolation_value Value returned when the roi is outside of the image.
+ * \param out_dtype Output dtype; when void, the input dtype is used.
+ * \return A call to the relax.image.resize2d operator.
+ */
+Expr resize2d(Expr data, Expr size, Array<FloatImm> roi, String layout, String method,
+              String coordinate_transformation_mode, String rounding_method, double cubic_alpha,
+              int cubic_exclude, double extrapolation_value, DataType out_dtype);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_IMAGE_RESIZE_H_
diff --git a/tests/python/relax/test_op_image.py b/tests/python/relax/test_op_image.py
new file mode 100644
index 0000000000..b06b51a2a1
--- /dev/null
+++ b/tests/python/relax/test_op_image.py
@@ -0,0 +1,245 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    """The Python constructor must produce a call to the registered relax.image.resize2d op."""
+    data = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    call = relax.op.image.resize2d(data, (28, 28))
+    assert call.op == Op.get("relax.image.resize2d")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    """Normalize `call` with `bb` and assert its inferred struct info equals `expected_sinfo`."""
+    normalized = bb.normalize(call)
+    actual_sinfo = normalized.struct_info
+    tvm.ir.assert_structural_equal(actual_sinfo, expected_sinfo)
+
+
+def test_resize2d_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor((2, 4, 32, 32, 16), "float32"))
+    x3 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x4 = relax.Var("x", R.Tensor("float32", ndim=5))
+    x5 = relax.Var("x", R.Tensor("float32"))
+    x6 = relax.Var("x", R.Tensor(ndim=4))
+    x7 = relax.Var("x", R.Tensor())
+    resize = relax.op.image.resize2d
+    sinfo = relax.TensorStructInfo
+
+    # (call, expected struct info) pairs, checked in order.
+    cases = [
+        (resize(x0, (28, 28)), sinfo((2, 3, 28, 28), "float32")),
+        (resize(x0, size=28), sinfo((2, 3, 28, 28), "float32")),
+        (resize(x0, size=(28, 30)), sinfo((2, 3, 28, 30), "float32")),
+        (resize(x1, size=28, layout="NHWC"), sinfo((2, 28, 28, 3), "float32")),
+        (resize(x0, size=28, out_dtype="float16"), sinfo((2, 3, 28, 28), "float16")),
+        (resize(x2, size=28, layout="NCHW16c"), sinfo((2, 4, 28, 28, 16), "float32")),
+        (resize(x3, size=28), sinfo(dtype="float32", ndim=4)),
+        (resize(x4, size=28, layout="NCHW16c"), sinfo(dtype="float32", ndim=5)),
+        (resize(x5, size=28), sinfo(dtype="float32", ndim=4)),
+        (resize(x6, size=28), sinfo(dtype="", ndim=4)),
+        (resize(x6, size=28, out_dtype="float32"), sinfo(dtype="float32", ndim=4)),
+        (resize(x7, size=28), sinfo(dtype="", ndim=4)),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_resize2d_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    n, c, ih, iw, oh, ow = [
+        tir.Var(name, "int64") for name in ("n", "c", "ih", "iw", "oh", "ow")
+    ]
+    x0 = relax.Var("x", R.Tensor((n, c, ih, iw), "float32"))
+    x1 = relax.Var("x", R.Tensor((n, c, ih, iw, 16), "float32"))
+    resize = relax.op.image.resize2d
+
+    cases = [
+        (resize(x0, size=oh), relax.TensorStructInfo((n, c, oh, oh), "float32")),
+        (resize(x0, size=(oh, ow)), relax.TensorStructInfo((n, c, oh, ow), "float32")),
+        (
+            resize(x1, size=(oh, ow), layout="NCHW16c"),
+            relax.TensorStructInfo((n, c, oh, ow, 16), "float32"),
+        ),
+    ]
+    for call, expected in cases:
+        _check_inference(bb, call, expected)
+
+
+def test_resize2d_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0, x1, x2 = [relax.Var("x", relax.TensorStructInfo(s, "float32")) for s in (s0, s1, s2)]
+
+    # A symbolic input shape yields only the rank implied by the layout.
+    expected_4d = relax.TensorStructInfo(dtype="float32", ndim=4)
+    expected_5d = relax.TensorStructInfo(dtype="float32", ndim=5)
+    _check_inference(bb, relax.op.image.resize2d(x0, size=32), expected_4d)
+    _check_inference(bb, relax.op.image.resize2d(x1, size=32, layout="NCHW16c"), expected_5d)
+    _check_inference(bb, relax.op.image.resize2d(x2, size=32, layout="NCHW16c"), expected_5d)
+
+
+def test_resize2d_infer_struct_info_pool_size_var():
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    size_known = relax.Var("s", relax.ShapeStructInfo((30, 30)))
+    size_unknown = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+
+    # A size held in a Var (not a literal ShapeExpr) yields only the output rank, even when the
+    # Var's struct info carries concrete values.
+    for size in (size_known, size_unknown):
+        _check_inference(
+            bb,
+            relax.op.image.resize2d(data, size),
+            relax.TensorStructInfo(dtype="float32", ndim=4),
+        )
+
+
+def test_resize2d_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    # Without out_dtype, the output keeps the input dtype.
+    for dtype in ("float16", "int8", "int64"):
+        data = relax.Var("x", R.Tensor((2, 3, 32, 32), dtype))
+        _check_inference(
+            bb,
+            relax.op.image.resize2d(data, size=28),
+            relax.TensorStructInfo((2, 3, 28, 28), dtype),
+        )
+
+
+def test_resize2d_infer_struct_info_wrong_layout_string():
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 28, 28), "float32"))
+    # "OIHW" is not a data layout with N/C/H/W axes, so inference must fail.
+    with pytest.raises(TVMError):
+        call = relax.op.image.resize2d(data, size=28, layout="OIHW")
+        bb.normalize(call)
+
+
+def test_resize2d_wrong_input_ndim():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    x1 = relax.Var("x", R.Tensor((2, 3, 32, 32, 3), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32", ndim=3))
+
+    # Each call is built inside pytest.raises, as in a plain sequence of `with` blocks.
+    bad_calls = [
+        lambda: relax.op.image.resize2d(x0, size=28, layout="NCHW16c"),
+        lambda: relax.op.image.resize2d(x1, size=28, layout="NCHW"),
+        lambda: relax.op.image.resize2d(x2, size=28),
+    ]
+    for make_call in bad_calls:
+        with pytest.raises(TVMError):
+            bb.normalize(make_call())
+
+
+def test_resize2d_wrong_pool_size_ndim():
+    bb = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 3, 32, 32), "float16"))
+    # Every size below is not a 2-dim shape (or has unknown ndim) and must be rejected.
+    bad_sizes = [
+        (3, 3, 3),
+        relax.ShapeExpr((3,)),
+        relax.Var("s", relax.ShapeStructInfo((30, 30, 30))),
+        relax.Var("s", relax.ShapeStructInfo(ndim=3)),
+        relax.Var("s", relax.ShapeStructInfo(ndim=1)),
+        relax.Var("s", relax.ShapeStructInfo(ndim=0)),
+        relax.Var("s", relax.ShapeStructInfo()),
+    ]
+    for size in bad_sizes:
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.image.resize2d(data, size))
+
+
+def test_resize2d_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3, 28, 28)))
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3, 28, 28), "float32")))
+    x2 = relax.Var("x", R.Tensor((2, 3, 32, 32), "float32"))
+    s0 = relax.Var("s", R.Tensor((3, 3)))
+
+    # Non-tensor data or a non-shape size must be rejected during inference.
+    for data, size in ((x0, 32), (x1, 32), (x2, s0)):
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.image.resize2d(data, size))
+    # A Python list is passed through unconverted and is expected to fail at construction.
+    with pytest.raises(TVMError):
+        relax.op.image.resize2d(x2, [30, 30])
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly; delegates to tvm.testing.main().
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_image.py b/tests/python/relax/test_tvmscript_parser_op_image.py
new file mode 100644
index 0000000000..a90da37812
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_image.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    """Round-trip `parsed` through TVMScript and optionally compare it to `expect`.
+
+    The script printed with metadata is parsed back and must be structurally equal to
+    `parsed`. When `expect` is given, `parsed` must also be structurally equal to it.
+    """
+    test = parsed.script(show_meta=True)
+    roundtrip_mod = tvm.script.from_source(test)
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    # Compare against None explicitly so only a missing expectation skips the check,
+    # independent of how the IR object evaluates in a boolean context.
+    if expect is not None:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_resize2d():
+    """R.image.resize2d in TVMScript must match the equivalent BlockBuilder construction."""
+
+    @R.function
+    def foo(x: R.Tensor((2, 14, 14, 3), "float32")) -> R.Tensor((2, 28, 28, 3), "float32"):
+        gv: R.Tensor((2, 28, 28, 3), "float32") = R.image.resize2d(x, size=(28, 28), layout="NHWC")
+        return gv
+
+    builder = relax.BlockBuilder()
+    data = relax.Var("x", R.Tensor((2, 14, 14, 3), "float32"))
+    with builder.function("foo", [data]):
+        resized = builder.emit(relax.op.image.resize2d(data, (28, 28), layout="NHWC"))
+        builder.emit_func_output(resized)
+    expected = builder.get()["foo"]
+
+    _check(foo, expected)
+
+
+if __name__ == "__main__":
+    # Allow running this test file directly; delegates to tvm.testing.main().
+    tvm.testing.main()

Reply via email to