This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit aaac5c44bcd5d7c7dcc462d51fe6c454d3217ddf
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 15:09:10 2023 -0500

    [Unity] Relax op: search (#13992)
    
    This PR is about the high-level tensor computation operators in Relax.
    
    It adds the search operators, currently consisting of `relax.where`.
---
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/search.py                      |  50 ++++
 python/tvm/script/ir_builder/relax/ir.py           |   4 +-
 src/relax/op/tensor/search.cc                      |  99 ++++++++
 src/relax/op/tensor/search.h                       |  41 +++
 tests/python/relax/test_op_search.py               | 278 +++++++++++++++++++++
 .../relax/test_tvmscript_parser_op_search.py       |  60 +++++
 7 files changed, 532 insertions(+), 1 deletion(-)

diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 4b2f990eaa..39a645ffea 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -27,6 +27,7 @@ from .linear_algebra import *
 from .manipulate import *
 from .op_attrs import *
 from .statistical import *
+from .search import *
 from .set import *
 from .ternary import *
 from .unary import *
diff --git a/python/tvm/relax/op/search.py b/python/tvm/relax/op/search.py
new file mode 100644
index 0000000000..8252b0e1d8
--- /dev/null
+++ b/python/tvm/relax/op/search.py
@@ -0,0 +1,50 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Search operators."""
+from . import _ffi_api
+from ..expr import Expr
+
+
+def where(condition: Expr, x1: Expr, x2: Expr) -> Expr:
+    """Select elements from either of the two input tensors depending on the value of the
+    condition.
+
+    For a given position, return the corresponding value in `x1` if `condition` is True,
+    and return the corresponding value in `x2` otherwise.
+
+    Parameters
+    ----------
+    condition : relax.Expr
+        When True, yield `x1`; otherwise, yield `x2`.
+        Must be broadcasting compatible with `x1` and `x2`.
+        Must have boolean dtype.
+
+    x1 : relax.Expr
+        The first input tensor.
+        Must be broadcasting compatible with `condition` and `x2`.
+
+    x2 : relax.Expr
+        The second input tensor. Must have the same dtype as `x1`.
+        Must be broadcasting compatible with `condition` and `x1`.
+
+    Returns
+    -------
+    result : relax.Expr
+        The result tensor, whose shape is the broadcast of the three input shapes.
+    """
+    return _ffi_api.where(condition, x1, x2)  # type: ignore
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 9f5fe03dec..b779bdac9c 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -101,6 +101,7 @@ from tvm.relax.op import (
     tril,
     triu,
     unique,
+    where,
     zeros,
     zeros_like,
     nn,
@@ -547,8 +548,9 @@ __all__ = [
     "tril",
     "triu",
     "tuple",
-    "variance",
     "unique",
+    "variance",
+    "where",
     "zeros",
     "zeros_like",
     "nn",    
diff --git a/src/relax/op/tensor/search.cc b/src/relax/op/tensor/search.cc
new file mode 100644
index 0000000000..5191017ea1
--- /dev/null
+++ b/src/relax/op/tensor/search.cc
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file search.cc
+ * \brief Searching operators.
+ */
+
+#include "search.h"
+
+#include <algorithm>
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/* relax.where: create a Call to the "relax.where" op with empty attributes. */
+Expr where(Expr condition, Expr x1, Expr x2) {
+  static const Op& op = Op::Get("relax.where");
+  return Call(op, {std::move(condition), std::move(x1), std::move(x2)}, Attrs(), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.where").set_body_typed(where);
+
+StructInfo InferStructInfoWhere(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo cond_sinfo = input_sinfo[0];
+  TensorStructInfo x1_sinfo = input_sinfo[1];
+  TensorStructInfo x2_sinfo = input_sinfo[2];
+
+  if (!cond_sinfo->dtype.is_bool()) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Where requires the input condition tensor to have boolean dtype. However, "
+                        "the given condition dtype is "
+                     << cond_sinfo->dtype);
+  }
+  DataType output_dtype = InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo);
+
+  int output_ndim;
+  if (cond_sinfo->IsUnknownNdim() || x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {
+    output_ndim = kUnknownNDim;
+  } else {
+    output_ndim = std::max(cond_sinfo->ndim, std::max(x1_sinfo->ndim, x2_sinfo->ndim));
+  }
+
+  const auto* cond_shape = cond_sinfo->shape.as<ShapeExprNode>();
+  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  if (cond_shape && x1_shape && x2_shape) {
+    // Step 1. Compute the broadcast shape of x1 and x2.
+    Optional<Array<PrimExpr>> broadcasted_shape =
+        InferBinaryBroadcastShape(call, ctx, x1_shape->values, x2_shape->values);
+    if (!broadcasted_shape.defined()) {
+      return TensorStructInfo(output_dtype, output_ndim);  // e.g. unprovable symbolic dims
+    }
+    // Step 2. Broadcast the shape of cond against the shape from step 1.
+    broadcasted_shape =
+        InferBinaryBroadcastShape(call, ctx, cond_shape->values, broadcasted_shape.value());
+    if (!broadcasted_shape.defined()) {
+      return TensorStructInfo(output_dtype, output_ndim);  // e.g. unprovable symbolic dims
+    }
+    ICHECK_EQ(static_cast<int>(broadcasted_shape.value().size()), output_ndim);
+    return TensorStructInfo(ShapeExpr(broadcasted_shape.value()), output_dtype);
+  } else if (cond_sinfo->shape.defined() &&                 //
+             x1_sinfo->shape.defined() &&                   //
+             x2_sinfo->shape.defined() &&                   //
+             cond_sinfo->shape.same_as(x1_sinfo->shape) &&  //
+             cond_sinfo->shape.same_as(x2_sinfo->shape)) {
+    return TensorStructInfo(cond_sinfo->shape.value(), output_dtype);  // one shared shape var
+  } else {
+    return TensorStructInfo(output_dtype, output_ndim);  // shape unknown: dtype and ndim only
+  }
+}
+
+TVM_REGISTER_OP("relax.where")
+    .set_num_inputs(3)
+    .add_argument("condition", "Tensor", "When True, yield `x1`; otherwise, yield `x2`.")
+    .add_argument("x1", "Tensor", "The first input tensor.")
+    .add_argument("x2", "Tensor", "The second input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoWhere);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/search.h b/src/relax/op/tensor/search.h
new file mode 100644
index 0000000000..aeae4a7157
--- /dev/null
+++ b/src/relax/op/tensor/search.h
@@ -0,0 +1,41 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file search.h
+ * \brief The functions to make Relax searching operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_SEARCH_H_
+#define TVM_RELAX_OP_TENSOR_SEARCH_H_
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Create a call to the `relax.where` operator, which selects elements from `x1` where
+ * `condition` is true and from `x2` otherwise. All three inputs are broadcast together.
+ */
+Expr where(Expr condition, Expr x1, Expr x2);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_SEARCH_H_
diff --git a/tests/python/relax/test_op_search.py b/tests/python/relax/test_op_search.py
new file mode 100644
index 0000000000..a2f271671b
--- /dev/null
+++ b/tests/python/relax/test_op_search.py
@@ -0,0 +1,278 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():
+    cond = relax.Var("cond", R.Tensor((2, 3), "bool"))
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Tensor((2, 3), "float32"))
+    assert relax.op.where(cond, x, y).op == Op.get("relax.where")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_where_infer_struct_info():
+    bb = relax.BlockBuilder()
+    cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool"))
+    cond1 = relax.Var("cond", R.Tensor("bool", ndim=5))
+    cond2 = relax.Var("cond", R.Tensor("bool"))
+    x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=4))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((5, 1, 3, 2)))  # unknown dtype
+    x4 = relax.Var("x", R.Tensor(ndim=4))  # unknown dtype, known ndim
+    x5 = relax.Var("x", R.Tensor())  # unknown dtype and ndim
+    y0 = relax.Var("y", R.Tensor((4, 3, 1), "float32"))
+    y1 = relax.Var("y", R.Tensor("float32", ndim=3))
+    y2 = relax.Var("y", R.Tensor("float32"))
+    y3 = relax.Var("y", R.Tensor((4, 3, 1)))
+    y4 = relax.Var("y", R.Tensor(ndim=3))
+    y5 = relax.Var("y", R.Tensor())
+
+    _check_inference(
+        bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float32")
+    )
+    _check_inference(
+        bb, relax.op.where(cond0, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond0, x2, y0), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(
+        bb, relax.op.where(cond0, x3, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="")
+    )
+    _check_inference(bb, relax.op.where(cond0, x4, y0), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.where(cond0, x5, y0), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond0, x2, y1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond0, x3, y1), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.where(cond0, x4, y1), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.where(cond0, x5, y1), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond0, x3, y2), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.where(cond0, x4, y2), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.where(cond0, x5, y2), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.where(cond0, x3, y3), relax.TensorStructInfo((6, 5, 4, 3, 2), dtype="")
+    )
+    _check_inference(bb, relax.op.where(cond0, x4, y3), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.where(cond0, x5, y3), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.where(cond0, x4, y4), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(bb, relax.op.where(cond0, x5, y4), relax.TensorStructInfo(dtype=""))
+    _check_inference(bb, relax.op.where(cond0, x5, y5), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb, relax.op.where(cond1, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond1, x2, y0), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond2, x0, y0), relax.TensorStructInfo(dtype="float32"))
+
+
+def test_where_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    c = tir.Var("c", "int64")
+    d0 = tir.Var("d", "int64")
+    d1 = tir.Var("d", "int64")  # same name as d0, but a distinct variable
+    e = tir.Var("e", "int64")
+    cond = relax.Var("cond", R.Tensor((a, b, 1, d0, 1), "bool"))
+    x0 = relax.Var("x", R.Tensor((b, 1, d0, e), "float32"))
+    x1 = relax.Var("x", R.Tensor((b, 1, d1, e), "float32"))
+    x2 = relax.Var("x", R.Tensor((b, 1, d0, e)))
+    y0 = relax.Var("y", R.Tensor((c, d0, 1), "float32"))
+    y1 = relax.Var("y", R.Tensor((c, d0, 1)))
+
+    _check_inference(
+        bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((a, b, c, d0, e), "float32")
+    )
+    _check_inference(
+        bb, relax.op.where(cond, x1, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.where(cond, x2, y0), relax.TensorStructInfo((a, b, c, d0, e), dtype="")
+    )
+    _check_inference(
+        bb, relax.op.where(cond, x0, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="")
+    )
+    _check_inference(bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo(dtype="", ndim=5))
+    _check_inference(
+        bb, relax.op.where(cond, x2, y1), relax.TensorStructInfo((a, b, c, d0, e), dtype="")
+    )
+
+
+def test_where_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    scond0 = relax.Var("scond", relax.ShapeStructInfo((6, 5, 1, 3, 1)))
+    scond1 = relax.Var("scond", relax.ShapeStructInfo(ndim=5))
+    scond2 = relax.Var("scond", relax.ShapeStructInfo())
+    sx0 = relax.Var("sx", relax.ShapeStructInfo((5, 1, 3, 2)))
+    sx1 = relax.Var("sx", relax.ShapeStructInfo(ndim=4))
+    sx2 = relax.Var("sx", relax.ShapeStructInfo())
+    sy0 = relax.Var("sy", relax.ShapeStructInfo((4, 3, 1)))
+    sy1 = relax.Var("sy", relax.ShapeStructInfo(ndim=3))
+    sy2 = relax.Var("sy", relax.ShapeStructInfo())
+    s0 = relax.Var("s", relax.ShapeStructInfo((6, 5, 4, 3, 2)))  # shared by cond3, x3 and y3
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=5))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    cond0 = relax.Var("cond", relax.TensorStructInfo(scond0, "bool"))
+    cond1 = relax.Var("cond", relax.TensorStructInfo(scond1, "bool"))
+    cond2 = relax.Var("cond", relax.TensorStructInfo(scond2, "bool"))
+    cond3 = relax.Var("cond", relax.TensorStructInfo(s0, "bool"))
+    cond4 = relax.Var("cond", relax.TensorStructInfo(s1, "bool"))
+    cond5 = relax.Var("cond", relax.TensorStructInfo(s2, "bool"))
+    x0 = relax.Var("x", relax.TensorStructInfo(sx0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(sx1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(sx2, "float32"))
+    x3 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x4 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x5 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    y0 = relax.Var("y", relax.TensorStructInfo(sy0, "float32"))
+    y1 = relax.Var("y", relax.TensorStructInfo(sy1, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(sy2, "float32"))
+    y3 = relax.Var("y", relax.TensorStructInfo(s0, "float32"))
+    y4 = relax.Var("y", relax.TensorStructInfo(s1, "float32"))
+    y5 = relax.Var("y", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(
+        bb, relax.op.where(cond0, x0, y0), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.where(cond0, x0, y1), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond0, x0, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(
+        bb, relax.op.where(cond0, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond0, x1, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond0, x2, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(
+        bb, relax.op.where(cond1, x1, y1), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond1, x1, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond1, x2, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond2, x2, y2), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond3, x3, y3), relax.TensorStructInfo(s0, "float32"))
+    _check_inference(
+        bb, relax.op.where(cond3, x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.where(cond3, x4, y3), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(
+        bb, relax.op.where(cond4, x3, y3), relax.TensorStructInfo(dtype="float32", ndim=5)
+    )
+    _check_inference(bb, relax.op.where(cond4, x4, y4), relax.TensorStructInfo(s1, "float32"))
+    _check_inference(bb, relax.op.where(cond4, x4, y5), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond4, x5, y4), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond5, x4, y4), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.where(cond5, x5, y5), relax.TensorStructInfo(s2, "float32"))
+
+
+def test_where_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    cond = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool"))
+    x0 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float16"))
+    y0 = relax.Var("y", R.Tensor((4, 3, 1), "float16"))
+    x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int8"))
+    y1 = relax.Var("y", R.Tensor((4, 3, 1), "int8"))
+    x2 = relax.Var("x", R.Tensor((5, 1, 3, 2), "int32"))
+    y2 = relax.Var("y", R.Tensor((4, 3, 1), "int32"))
+
+    _check_inference(
+        bb, relax.op.where(cond, x0, y0), relax.TensorStructInfo((6, 5, 4, 3, 2), "float16")
+    )
+    _check_inference(
+        bb, relax.op.where(cond, x1, y1), relax.TensorStructInfo((6, 5, 4, 3, 2), "int8")
+    )
+    _check_inference(
+        bb, relax.op.where(cond, x2, y2), relax.TensorStructInfo((6, 5, 4, 3, 2), "int32")
+    )
+
+
+def test_where_infer_struct_info_cond_not_boolean():
+    bb = relax.BlockBuilder()
+    cond0 = relax.Var("cond", R.Tensor((2, 3), "float32"))
+    cond1 = relax.Var("cond", R.Tensor((2, 3)))
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Tensor((2, 3), "float32"))
+
+    with pytest.raises(TVMError):  # float32 condition
+        bb.normalize(relax.op.where(cond0, x, y))
+    with pytest.raises(TVMError):  # condition with unknown dtype
+        bb.normalize(relax.op.where(cond1, x, y))
+
+
+def test_where_infer_struct_info_shape_unequal_const_int():
+    bb = relax.BlockBuilder()
+    cond0 = relax.Var("cond", R.Tensor((6, 5, 1, 4, 1), "bool"))
+    cond1 = relax.Var("cond", R.Tensor((6, 5, 1, 3, 1), "bool"))
+    x0 = relax.Var("x", R.Tensor((5, 1, 4, 2), "float32"))
+    x1 = relax.Var("x", R.Tensor((5, 1, 3, 2), "float32"))
+    y0 = relax.Var("y", R.Tensor((4, 4, 1), "float32"))
+    y1 = relax.Var("y", R.Tensor((4, 3, 1), "float32"))
+
+    with pytest.raises(TVMError):  # extent 4 in cond0 vs. 3 in broadcast(x1, y1)
+        bb.normalize(relax.op.where(cond0, x1, y1))
+    with pytest.raises(TVMError):  # extent 4 in x0 vs. 3 in y1
+        bb.normalize(relax.op.where(cond1, x0, y1))
+    with pytest.raises(TVMError):  # extent 3 in x1 vs. 4 in y0
+        bb.normalize(relax.op.where(cond1, x1, y0))
+
+
+def test_where_infer_struct_info_dtype_mismatch():
+    bb = relax.BlockBuilder()
+    cond = relax.Var("cond", R.Tensor((2, 3), "bool"))
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y0 = relax.Var("y", R.Tensor((2, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    y1 = relax.Var("y", R.Tensor((2, 3), "float32"))
+
+    with pytest.raises(TVMError):  # float32 vs. float16
+        bb.normalize(relax.op.where(cond, x0, y0))
+    with pytest.raises(TVMError):  # int8 vs. float32
+        bb.normalize(relax.op.where(cond, x1, y1))
+
+
+def test_where_infer_struct_info_wrong_input_type():
+    bb = relax.BlockBuilder()
+    cond0 = relax.Var("cond", relax.ShapeStructInfo((2, 3)))
+    cond1 = relax.Var("cond", R.Tensor((2, 3), "bool"))
+    x0 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))
+    x1 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y0 = relax.Var("y", relax.TupleStructInfo([R.Tensor((2, 3), "float32")]))
+    y1 = relax.Var("y", R.Tensor((2, 3), "float32"))
+
+    with pytest.raises(TVMError):  # condition is a shape, not a tensor
+        bb.normalize(relax.op.where(cond0, x1, y1))
+    with pytest.raises(TVMError):  # x1 argument is a function
+        bb.normalize(relax.op.where(cond1, x0, y1))
+    with pytest.raises(TVMError):  # x2 argument is a tuple
+        bb.normalize(relax.op.where(cond1, x1, y0))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_search.py b/tests/python/relax/test_tvmscript_parser_op_search.py
new file mode 100644
index 0000000000..a8eaa814aa
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_search.py
@@ -0,0 +1,60 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    test = parsed.script(show_meta=True)  # print to TVMScript text
+    roundtrip_mod = tvm.script.from_source(test)  # parse the printed text back
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:  # optionally compare against an independently built reference
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_where():
+    @R.function
+    def foo(
+        condition: R.Tensor((2, 1), "bool"),
+        x: R.Tensor((2, 3), "float32"),
+        y: R.Tensor((1, 3), "float32"),
+    ) -> R.Tensor((2, 3), "float32"):
+        gv: R.Tensor((2, 3), "float32") = R.where(condition, x, y)
+        return gv
+
+    bb = relax.BlockBuilder()  # build the expected function without the parser
+    condition = relax.Var("condition", R.Tensor((2, 1), "bool"))
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Tensor((1, 3), "float32"))
+    with bb.function("foo", [condition, x, y]):
+        gv = bb.emit(relax.op.where(condition, x, y))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to