This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 3e4af0dc9b742d7c8686d14f5680fd9ce862abcf
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 14:55:09 2023 -0500

    [Unity] Relax op: datatype (#13986)
---
 include/tvm/relax/attrs/datatype.h                 |  44 +++++++++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/{op_attrs.py => datatype.py}   |  31 ++++--
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/script/ir_builder/relax/ir.py           |   2 +
 src/relax/op/tensor/datatype.cc                    |  60 ++++++++++++
 src/relax/op/tensor/datatype.h                     |  45 +++++++++
 tests/python/relax/test_op_datatype.py             | 105 +++++++++++++++++++++
 .../relax/test_tvmscript_parser_op_datatype.py     |  54 +++++++++++
 9 files changed, 338 insertions(+), 9 deletions(-)

diff --git a/include/tvm/relax/attrs/datatype.h b/include/tvm/relax/attrs/datatype.h
new file mode 100644
index 0000000000..79cb345688
--- /dev/null
+++ b/include/tvm/relax/attrs/datatype.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/datatype.h
+ * \brief Attributes for datatype operators.
+ */
+#ifndef TVM_RELAX_ATTRS_DATATYPE_H_
+#define TVM_RELAX_ATTRS_DATATYPE_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes of the relax.astype operator. */
+struct AstypeAttrs : public tvm::AttrsNode<AstypeAttrs> {
+  DataType dtype;  // element type the input tensor is cast to
+
+  TVM_DECLARE_ATTRS(AstypeAttrs, "relax.attrs.AstypeAttrs") {
+    TVM_ATTR_FIELD(dtype).describe("Target data type");
+  }
+};  // struct AstypeAttrs.
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_DATATYPE_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 3393a5dcae..f3ab9085b8 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -20,6 +20,7 @@
 # Operators
 from .base import *
 from .binary import *
+from .datatype import *
 from .index import *
 from .manipulate import *
 from .op_attrs import *
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/datatype.py
similarity index 60%
copy from python/tvm/relax/op/op_attrs.py
copy to python/tvm/relax/op/datatype.py
index 44cb2cf3a5..5c02776dd7 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/datatype.py
@@ -14,16 +14,29 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""The attributes node used for Relax operators"""
-from tvm.ir import Attrs
-import tvm._ffi
+"""Datatype operators."""
+from typing import Union
 
+from tvm import DataType
 
-@tvm._ffi.register_object("relax.attrs.TakeAttrs")
-class TakeAttrs(Attrs):
-    """Attributes used in take operator"""
+from . import _ffi_api
+from ..expr import Expr
 
 
-@tvm._ffi.register_object("relax.attrs.StridedSliceAttrs")
-class StridedSliceAttrs(Attrs):
-    """Attributes used in strided_slice operator"""
+def astype(x: Expr, dtype: Union[str, DataType]) -> Expr:
+    """Cast the input tensor to the given data type (the ``relax.astype`` operator).
+
+    Parameters
+    ----------
+    x : relax.Expr
+        The input tensor to cast.
+
+    dtype : Union[str, DataType]
+        The target data type, e.g. ``"float16"``.
+
+    Returns
+    -------
+    result : relax.Expr
+        The cast result: same shape as ``x``, with element type ``dtype``.
+    """
+    return _ffi_api.astype(x, dtype)  # type: ignore
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index 44cb2cf3a5..cb33363944 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -19,6 +19,11 @@ from tvm.ir import Attrs
 import tvm._ffi
 
 
+@tvm._ffi.register_object("relax.attrs.AstypeAttrs")
+class AstypeAttrs(Attrs):
+    """Attributes of the relax.astype operator (carries the target ``dtype``)."""
+
+
 @tvm._ffi.register_object("relax.attrs.TakeAttrs")
 class TakeAttrs(Attrs):
     """Attributes used in take operator"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 75a00ea049..aaee0f4e2f 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -31,6 +31,7 @@ from tvm.relax import Call, Expr, ExternFunc, TupleGetItem, Var, const
 from tvm.relax.op import (
     add,
     assert_op,
+    astype,
     builtin,
     call_builtin_with_ctx,
     call_tir,
@@ -403,6 +404,7 @@ __all__ = [
     "add",
     "arg",
     "assert_op",
+    "astype",
     "builtin",
     "call_packed",
     "call_tir",
diff --git a/src/relax/op/tensor/datatype.cc b/src/relax/op/tensor/datatype.cc
new file mode 100644
index 0000000000..0c647aa866
--- /dev/null
+++ b/src/relax/op/tensor/datatype.cc
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file datatype.cc
+ * \brief Datatype operators.
+ */
+
+#include "datatype.h"
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/* relax.astype: Call constructor, FFI registration and struct-info inference */
+TVM_REGISTER_NODE_TYPE(AstypeAttrs);
+
+Expr astype(Expr x, DataType dtype) {
+  ObjectPtr<AstypeAttrs> attrs = make_object<AstypeAttrs>();
+  attrs->dtype = dtype;  // the op's only attribute: the cast's target element type
+
+  static const Op& op = Op::Get("relax.astype");  // looked up once (function-local static)
+  return Call(op, {std::move(x)}, Attrs(attrs), {});  // result sinfo via FInferStructInfo
+}
+
+TVM_REGISTER_GLOBAL("relax.op.astype").set_body_typed(astype);  // used by _ffi_api.astype
+
+StructInfo InferStructInfoAstype(const Call& call, const BlockBuilder& ctx) {
+  TensorStructInfo sinfo = GetUnaryInputTensorStructInfo(call, ctx);
+  const auto* attrs = call->attrs.as<AstypeAttrs>();
+  ObjectPtr<TensorStructInfoNode> new_sinfo = make_object<TensorStructInfoNode>(*sinfo.get());
+  new_sinfo->dtype = attrs->dtype;  // copy of the input sinfo; only the dtype is replaced
+  return TensorStructInfo(new_sinfo);
+}
+
+TVM_REGISTER_OP("relax.astype")  // unary cast; target dtype is carried in AstypeAttrs
+    .set_attrs_type<AstypeAttrs>()
+    .set_num_inputs(1)
+    .add_argument("x", "Tensor", "The input tensor")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoAstype);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/datatype.h b/src/relax/op/tensor/datatype.h
new file mode 100644
index 0000000000..6afa7a50d4
--- /dev/null
+++ b/src/relax/op/tensor/datatype.h
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file datatype.h
+ * \brief The functions to make Relax datatype operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_DATATYPE_H_
+#define TVM_RELAX_OP_TENSOR_DATATYPE_H_
+
+#include <tvm/relax/attrs/datatype.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Cast the input tensor to the given data type (builds a relax.astype call).
+ * \param x The input tensor to cast.
+ * \param dtype The target data type.
+ * \return The cast result; its struct info keeps the shape of x and uses dtype.
+ */
+Expr astype(Expr x, DataType dtype);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_DATATYPE_H_
diff --git a/tests/python/relax/test_op_datatype.py b/tests/python/relax/test_op_datatype.py
new file mode 100644
index 0000000000..56bbe464cf
--- /dev/null
+++ b/tests/python/relax/test_op_datatype.py
@@ -0,0 +1,105 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness() -> None:
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))  # any tensor input suffices here
+    assert relax.op.astype(x, "float16").op == Op.get("relax.astype")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: 
relax.StructInfo):
+    ret = bb.normalize(call)
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_astype_infer_struct_info():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float32"))
+    x1 = relax.Var("x", R.Tensor("float32", ndim=2))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    x3 = relax.Var("x", R.Tensor((2, 3)))
+    x4 = relax.Var("x", R.Tensor(ndim=2))
+    x5 = relax.Var("x", R.Tensor())
+
+    _check_inference(bb, relax.op.astype(x0, "float16"), 
relax.TensorStructInfo((2, 3), "float16"))
+    _check_inference(
+        bb, relax.op.astype(x1, "float16"), 
relax.TensorStructInfo(dtype="float16", ndim=2)
+    )
+    _check_inference(bb, relax.op.astype(x2, "float16"), 
relax.TensorStructInfo(dtype="float16"))
+    _check_inference(bb, relax.op.astype(x3, "float16"), 
relax.TensorStructInfo((2, 3), "float16"))
+    _check_inference(
+        bb, relax.op.astype(x4, "float16"), 
relax.TensorStructInfo(dtype="float16", ndim=2)
+    )
+    _check_inference(bb, relax.op.astype(x5, "float16"), 
relax.TensorStructInfo(dtype="float16"))
+
+
+def test_astype_infer_struct_info_shape_symbolic():
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    x0 = relax.Var("x", R.Tensor((m, n), "float32"))
+    x1 = relax.Var("x", R.Tensor((m, n)))
+
+    _check_inference(bb, relax.op.astype(x0, "float16"), 
relax.TensorStructInfo((m, n), "float16"))
+    _check_inference(bb, relax.op.astype(x1, "float16"), 
relax.TensorStructInfo((m, n), "float16"))
+
+
+def test_astype_infer_struct_info_shape_var():
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s", relax.ShapeStructInfo((2, 3)))
+    s1 = relax.Var("s", relax.ShapeStructInfo(ndim=2))
+    s2 = relax.Var("s", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s1, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+
+    _check_inference(bb, relax.op.astype(x0, "float16"), 
relax.TensorStructInfo(s0, "float16"))
+    _check_inference(bb, relax.op.astype(x1, "float16"), 
relax.TensorStructInfo(s1, "float16"))
+    _check_inference(bb, relax.op.astype(x2, "float16"), 
relax.TensorStructInfo(s2, "float16"))
+
+
+def test_astype_infer_struct_info_more_input_dtype():
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((2, 3), "float16"))
+    x1 = relax.Var("x", R.Tensor((2, 3), "int8"))
+    x2 = relax.Var("x", R.Tensor((2, 3), "int32"))
+
+    _check_inference(bb, relax.op.astype(x0, "float32"), 
relax.TensorStructInfo((2, 3), "float32"))
+    _check_inference(bb, relax.op.astype(x1, "int32"), 
relax.TensorStructInfo((2, 3), "int32"))
+    _check_inference(bb, relax.op.astype(x2, "int8"), 
relax.TensorStructInfo((2, 3), "int8"))
+
+
+def test_astype_infer_struct_info_wrong_input_type() -> None:
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", relax.ShapeStructInfo((2, 3)))  # a shape, not a tensor
+    x1 = relax.Var("x", relax.FuncStructInfo([], R.Tensor((2, 3), "float32")))  # a function
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.astype(x0, "float16"))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.astype(x1, "float16"))
+
+
+if __name__ == "__main__":
+    tvm.testing.main()  # NOTE(review): presumably runs this file's tests
diff --git a/tests/python/relax/test_tvmscript_parser_op_datatype.py b/tests/python/relax/test_tvmscript_parser_op_datatype.py
new file mode 100644
index 0000000000..ec71e868d4
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_datatype.py
@@ -0,0 +1,54 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+) -> None:
+    test = parsed.script(show_meta=True)  # print the parsed IR back to TVMScript text
+    roundtrip_mod = tvm.script.from_source(test)  # re-parse the printed text
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_astype():
+    @R.function
+    def expected(x: R.Tensor((2, 3, 4), "float32")) -> R.Tensor((2, 3, 4), 
"float16"):
+        gv: R.Tensor((2, 3, 4), "float16") = R.astype(x, "float16")
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("main", [x]):
+        gv = bb.emit(relax.op.astype(x, "float16"))
+        bb.emit_func_output(gv)
+
+    _check(expected, bb.get()["main"])
+
+
+if __name__ == "__main__":
+    tvm.testing.main()

Reply via email to