This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity-staging
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit c3dfa324f98ce0930e04647fe581ae7d6450f0f4
Author: Ruihang Lai <[email protected]>
AuthorDate: Tue Feb 14 15:07:33 2023 -0500

    [Unity] Relax op: linear algebra (#13988)
    
    This PR is about the high-level tensor computation operators in Relax.
    
    This PR includes the linear algebra operators.
    
    Co-authored-by: Siyuan Fneg <[email protected]>
---
 include/tvm/relax/attrs/linear_algebra.h           |  44 ++++
 python/tvm/relax/op/__init__.py                    |   1 +
 python/tvm/relax/op/linear_algebra.py              |  90 ++++++++
 python/tvm/relax/op/op_attrs.py                    |   5 +
 python/tvm/script/ir_builder/relax/ir.py           |   4 +
 src/relax/op/tensor/linear_algebra.cc              | 123 +++++++++++
 src/relax/op/tensor/linear_algebra.h               |  49 +++++
 tests/python/relax/test_op_linear_algebra.py       | 244 +++++++++++++++++++++
 .../test_tvmscript_parser_op_linear_algebra.py     |  80 +++++++
 9 files changed, 640 insertions(+)

diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h
new file mode 100644
index 0000000000..4b0e04298c
--- /dev/null
+++ b/include/tvm/relax/attrs/linear_algebra.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/attrs/linear_algebra.h
+ * \brief Attributes for linear algebra operators.
+ */
+#ifndef TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_
+#define TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_
+
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief Attributes for the relax.matmul operator. */
+struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
+  DataType out_dtype;  // void when unspecified; the output dtype is then inferred from the inputs
+
+  TVM_DECLARE_ATTRS(MatmulAttrs, "relax.attrs.MatmulAttrs") {
+    TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor");
+  }
+};  // struct MatmulAttrs
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_
diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py
index 97d08c0946..4b2f990eaa 100644
--- a/python/tvm/relax/op/__init__.py
+++ b/python/tvm/relax/op/__init__.py
@@ -23,6 +23,7 @@ from .binary import *
 from .create import *
 from .datatype import *
 from .index import *
+from .linear_algebra import *
 from .manipulate import *
 from .op_attrs import *
 from .statistical import *
diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py
new file mode 100644
index 0000000000..940861a972
--- /dev/null
+++ b/python/tvm/relax/op/linear_algebra.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Relax linear algebra operators"""
+from typing import Optional, Union
+
+from tvm import DataType
+
+from ..expr import Expr
+from . import _ffi_api
+from .manipulate import permute_dims
+
+
+def matmul(x1: Expr, x2: Expr, out_dtype: Optional[Union[str, DataType]] = None) -> Expr:
+    """General matrix multiplication of two tensors, with broadcasting on batched dimensions.
+
+    The semantics and output shape deduction rules follow the Array API standard:
+    https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.
+
+    Parameters
+    ----------
+    x1 : relax.Expr
+        The first input tensor.
+
+    x2 : relax.Expr
+        The second input tensor.
+
+    out_dtype : Optional[Union[str, DataType]]
+        The data type of the matmul result.
+        When it is not specified, the output dtype is the same as the input dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+    """
+    return _ffi_api.matmul(x1, x2, out_dtype)  # type: ignore
+
+
+def linear(
+    data: Expr,
+    weight: Expr,
+    bias: Optional[Expr] = None,
+    out_dtype: Optional[Union[str, DataType]] = None,
+) -> Expr:
+    """Apply a linear transformation to the input data: y = x A^T (+ b when bias is given).
+
+    Parameters
+    ----------
+    data : relax.Expr
+        The input data tensor.
+
+    weight : relax.Expr
+        The weight tensor (1-D or 2-D); it is transposed before the matmul.
+
+    bias : Optional[relax.Expr]
+        The optional bias tensor, added to the matmul result when given.
+
+    out_dtype : Optional[Union[str, DataType]]
+        The data type of the matmul result.
+        When it is not specified, the output dtype is the same as the input dtype.
+
+    Returns
+    -------
+    result : relax.Expr
+        The computed result.
+
+    Notes
+    -----
+    Relax does not treat linear as a primitive op; it is composed of
+    permute_dims (transpose), matmul and, when bias is given, add.
+    """
+
+    # Since weight can be 1D or 2D, we use `axes=None` to support both cases.
+    x = matmul(data, permute_dims(weight, axes=None), out_dtype=out_dtype)
+    return x + bias if bias is not None else x
diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py
index ac6714d940..3a7ed427f9 100644
--- a/python/tvm/relax/op/op_attrs.py
+++ b/python/tvm/relax/op/op_attrs.py
@@ -44,6 +44,11 @@ class StridedSliceAttrs(Attrs):
     """Attributes used in strided_slice operator"""
 
 
+@tvm._ffi.register_object("relax.attrs.MatmulAttrs")
+class MatmulAttrs(Attrs):
+    """Attributes used in matmul operator (holds the optional out_dtype)"""
+
+
 @tvm._ffi.register_object("relax.attrs.Conv2DAttrs")
 class Conv2DAttrs(Attrs):
     """Attributes for nn.conv2d"""
diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py
index 118790372a..9f5fe03dec 100644
--- a/python/tvm/script/ir_builder/relax/ir.py
+++ b/python/tvm/script/ir_builder/relax/ir.py
@@ -63,8 +63,10 @@ from tvm.relax.op import (
     isnan,
     less,
     less_equal,
+    linear,
     log,
     make_closure,
+    matmul,
     max,
     mean,
     memory,
@@ -504,8 +506,10 @@ __all__ = [
     "isnan",
     "less",
     "less_equal",
+    "linear",
     "log",
     "make_closure",
+    "matmul",
     "max",
     "mean",
     "memory",
diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc
new file mode 100644
index 0000000000..50b53d0c8e
--- /dev/null
+++ b/src/relax/op/tensor/linear_algebra.cc
@@ -0,0 +1,123 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file linear_algebra.cc
+ * \brief Linear algebra operators.
+ */
+
+#include "linear_algebra.h"
+
+#include <algorithm>
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/* relax.matmul: attrs registration and call builder */
+TVM_REGISTER_NODE_TYPE(MatmulAttrs);
+
+Expr matmul(Expr x1, Expr x2, DataType out_dtype) {
+  ObjectPtr<MatmulAttrs> attrs = make_object<MatmulAttrs>();
+  attrs->out_dtype = out_dtype;  // a void dtype defers to inference from the inputs
+
+  static const Op& op = Op::Get("relax.matmul");  // looked up once, then cached in the static
+  return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {});
+}
+
+TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul);  // reached from Python as _ffi_api.matmul
+
+StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) {
+  Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
+  TensorStructInfo x1_sinfo = input_sinfo[0];
+  TensorStructInfo x2_sinfo = input_sinfo[1];
+
+  const auto* attrs = call->attrs.as<MatmulAttrs>();
+  DataType out_dtype = attrs->out_dtype.is_void()  // infer from the inputs unless explicitly set
+                           ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo)
+                           : attrs->out_dtype;
+
+  if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) {  // rank unknown: dtype only
+    return TensorStructInfo(out_dtype, kUnknownNDim);
+  }
+  int x1_ndim = x1_sinfo->ndim;
+  int x2_ndim = x2_sinfo->ndim;
+  if (x1_ndim == 0 || x2_ndim == 0) {
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Matmul requires both inputs to have at least 1 dimension. However, "
+                     << (x1_ndim == 0 ? "x1" : "x2") << " is a 0-rank tensor.");
+  }
+
+  int x1_prepended = 0;
+  int x2_appended = 0;
+  if (x1_ndim == 1) {  // a 1-D x1 acts as a row vector; its prepended dim is not in the output
+    x1_ndim = 2;
+    x1_prepended = 1;
+  }
+  if (x2_ndim == 1) {  // a 1-D x2 acts as a column vector; its appended dim is not in the output
+    x2_ndim = 2;
+    x2_appended = 1;
+  }
+  int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended;
+
+  const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>();
+  const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>();
+  if (x1_shape == nullptr || x2_shape == nullptr) {  // shape values unknown: rank only
+    return TensorStructInfo(out_dtype, output_ndim);
+  }
+
+  Array<PrimExpr> x1_shape_prefix{x1_shape->values.begin(),  // batch dims of x1
+                                  x1_shape->values.end() - 2 + x1_prepended};
+  Array<PrimExpr> x2_shape_prefix{x2_shape->values.begin(),  // batch dims of x2
+                                  x2_shape->values.end() - 2 + x2_appended};
+  Optional<Array<PrimExpr>> output_shape_prefix =
+      InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix);
+  if (!output_shape_prefix.defined()) {  // broadcast result not statically determinable
+    return TensorStructInfo(out_dtype, output_ndim);
+  }
+
+  arith::Analyzer* analyzer = ctx->GetAnalyzer();
+  PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1];  // original ndim: x1_ndim may be bumped
+  PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2];  // index 0 when x2 is 1-D
+  if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) {  // only provable mismatches fail
+    ctx->ReportFatal(Diagnostic::Error(call)
+                     << "Matmul requires the reduction length of x1 and x2 to be equal. However, "
+                        "the reduction lengths of x1 and x2 are "
+                     << x1_reduction_length << " and " << x2_reduction_length << " respectively.");
+  }
+
+  Array<PrimExpr> output_shape = output_shape_prefix.value();
+  if (!x1_prepended) {  // keep x1's row dim unless it was the 1-D promotion dim
+    output_shape.push_back(x1_shape->values[x1_ndim - 2]);
+  }
+  if (!x2_appended) {  // keep x2's column dim unless it was the 1-D promotion dim
+    output_shape.push_back(x2_shape->values[x2_ndim - 1]);
+  }
+  ICHECK_EQ(static_cast<int>(output_shape.size()), output_ndim);
+  return TensorStructInfo(ShapeExpr(output_shape), out_dtype);
+}
+
+TVM_REGISTER_OP("relax.matmul")  // output struct info is deduced by InferStructInfoMatmul
+    .set_num_inputs(2)
+    .add_argument("x1", "Tensor", "The first input tensor.")
+    .add_argument("x2", "Tensor", "The second input tensor.")
+    .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul);
+
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h
new file mode 100644
index 0000000000..af614c1f30
--- /dev/null
+++ b/src/relax/op/tensor/linear_algebra.h
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file linear_algebra.h
+ * \brief The functions to make Relax linear algebra operator calls.
+ */
+#ifndef TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_
+#define TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_
+
+#include <tvm/relax/attrs/linear_algebra.h>
+
+#include "../op_common.h"
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief General matrix multiplication of two tensors, with broadcasting on batched dimensions.
+ * The semantics and output shape deduction rules follow the Array API standard:
+ * https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html.
+ * \param x1 The first input tensor.
+ * \param x2 The second input tensor.
+ * \param out_dtype The data type of the matmul result.
+ * When it is void (not specified), the output dtype is the same as the input dtype.
+ * \return The computed result.
+ */
+Expr matmul(Expr x1, Expr x2, DataType out_dtype);
+
+}  // namespace relax
+}  // namespace tvm
+
+#endif  // TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_
diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py
new file mode 100644
index 0000000000..5eb19cf2b4
--- /dev/null
+++ b/tests/python/relax/test_op_linear_algebra.py
@@ -0,0 +1,244 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import pytest
+import tvm
+import tvm.testing
+from tvm import relax, tir
+from tvm import TVMError
+from tvm.ir import Op
+from tvm.script import relax as R
+
+
+def test_op_correctness():  # matmul builds a call to the registered "relax.matmul" op
+    x = relax.Var("x", R.Tensor((2, 3), "float32"))
+    y = relax.Var("y", R.Tensor((3, 4), "float32"))
+    assert relax.op.matmul(x, y).op == Op.get("relax.matmul")
+
+
+def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo):
+    ret = bb.normalize(call)  # normalization fills in struct_info via the op's inference
+    tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo)
+
+
+def test_matmul_infer_struct_info():  # 1-D promotion, batch broadcasting, unknown dtype/ndim, out_dtype
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((4,), "float32"))
+    x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32"))
+    x3 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float32"))
+    x4 = relax.Var("x", R.Tensor((2, 1, 4, 5)))  # unknown dtype
+    x5 = relax.Var("x", R.Tensor("float32"))  # unknown ndim
+    x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16"))
+    y0 = relax.Var("y", R.Tensor((4, 5), "float32"))
+    y1 = relax.Var("y", R.Tensor((4,), "float32"))
+    y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32"))
+    y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32"))
+    y4 = relax.Var("y", R.Tensor("float32", ndim=5))  # ndim known, shape values unknown
+    y5 = relax.Var("y", R.Tensor())  # unknown ndim and dtype
+
+    _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32"))
+    _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32"))
+    _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32"))
+    _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32"))
+    _check_inference(
+        bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "float32")
+    )
+    _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), ""))
+    _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5))
+    _check_inference(bb, relax.op.matmul(x5, y3), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.matmul(x3, y5), relax.TensorStructInfo(dtype=""))
+    _check_inference(
+        bb,
+        relax.op.matmul(x3, y3, out_dtype="float16"),
+        relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"),
+    )
+    _check_inference(
+        bb,
+        relax.op.matmul(x6, y3, out_dtype="float16"),
+        relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"),
+    )
+
+
+def test_matmul_infer_struct_info_shape_symbolic():  # symbolic (tir.Var) dimensions
+    bb = relax.BlockBuilder()
+    m = tir.Var("m", "int64")
+    n = tir.Var("n", "int64")
+    k0 = tir.Var("k0", "int64")
+    k1 = tir.Var("k1", "int64")
+    a = tir.Var("a", "int64")
+    b = tir.Var("b", "int64")
+    b1 = tir.Var("b", "int64")  # a distinct var from `b` despite sharing its name
+    c = tir.Var("c", "int64")
+    x0 = relax.Var("x", R.Tensor((m, k0), "float32"))
+    x1 = relax.Var("x", R.Tensor((k0,), "float32"))
+    x2 = relax.Var("x", R.Tensor((a, b, m, k0), "float32"))
+    x3 = relax.Var("x", R.Tensor((b, 1, m, k0), "float32"))
+    x4 = relax.Var("x", R.Tensor((b, 1, m, k1), "float32"))  # k1 vs k0: not provably unequal
+    y0 = relax.Var("y", R.Tensor((k0, n), "float32"))
+    y1 = relax.Var("y", R.Tensor((k0,), "float32"))
+    y2 = relax.Var("y", R.Tensor((a, b, k0, n), "float32"))
+    y3 = relax.Var("y", R.Tensor((a, 1, c, k0, n), "float32"))
+    y4 = relax.Var("y", R.Tensor((a, b1, c, k0, n), "float32"))
+
+    _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((m, n), "float32"))
+    _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32"))
+    _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((a, b, n), "float32"))
+    _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((a, b, m), "float32"))
+    _check_inference(
+        bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((a, b, c, m, n), "float32")
+    )
+    _check_inference(
+        bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((a, b, c, m, n), "float32")
+    )
+    _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5))
+
+
+def test_matmul_infer_struct_info_shape_var():  # shapes held in shape vars: only ndim is deduced
+    bb = relax.BlockBuilder()
+    s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4))
+    s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3))
+    s2 = relax.Var("s3", relax.ShapeStructInfo(ndim=1))  # NOTE(review): Var name "s3" differs from s2
+    s3 = relax.Var("s4", relax.ShapeStructInfo(ndim=1))  # NOTE(review): Var name "s4" differs from s3
+    s5 = relax.Var("s5", relax.ShapeStructInfo())
+    x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32"))
+    x1 = relax.Var("x", relax.TensorStructInfo(s2, "float32"))
+    x2 = relax.Var("x", relax.TensorStructInfo(s5, "float32"))
+    y0 = relax.Var("y", relax.TensorStructInfo(s1, "float32"))
+    y1 = relax.Var("y", relax.TensorStructInfo(s2, "float32"))
+    y2 = relax.Var("y", relax.TensorStructInfo(s3, "float32"))
+
+    _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo(dtype="float32", ndim=4))
+    _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorStructInfo(dtype="float32", ndim=2))
+    _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=3))
+    _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=0))
+    _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo(dtype="float32"))
+    _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo(dtype="float32", ndim=0))
+
+
+def test_matmul_infer_struct_info_more_input_dtype():  # no out_dtype: output keeps the input dtype
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float16"))
+    y0 = relax.Var("y", R.Tensor((4, 5), "float16"))
+    x1 = relax.Var("x", R.Tensor((3, 4), "int8"))
+    y1 = relax.Var("y", R.Tensor((4, 5), "int8"))
+    x2 = relax.Var("x", R.Tensor((3, 4), "int64"))
+    y2 = relax.Var("y", R.Tensor((4, 5), "int64"))
+
+    _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float16"))
+    _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((3, 5), "int8"))
+    _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorStructInfo((3, 5), "int64"))
+
+
+def test_matmul_infer_struct_info_mixed_precision():  # explicit out_dtype overrides the input dtype
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float16"))
+    y0 = relax.Var("y", R.Tensor((4, 5), "float16"))
+    x1 = relax.Var("x", R.Tensor((3, 4), "int8"))
+    y1 = relax.Var("y", R.Tensor((4, 5), "int8"))
+    x2 = relax.Var("x", R.Tensor((3, 4)))  # unknown dtype
+    y2 = relax.Var("y", R.Tensor((4, 5)))  # unknown dtype
+
+    _check_inference(
+        bb,
+        relax.op.matmul(x0, y0, out_dtype="float32"),
+        relax.TensorStructInfo((3, 5), "float32"),
+    )
+    _check_inference(
+        bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorStructInfo((3, 5), "int32")
+    )
+    _check_inference(
+        bb,
+        relax.op.matmul(x2, y2, out_dtype="float32"),
+        relax.TensorStructInfo((3, 5), "float32"),
+    )
+
+
+def test_matmul_infer_struct_info_zero_rank_input():  # 0-rank operands must be rejected
+    bb = relax.BlockBuilder()
+    x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((), "float32"))
+    y0 = relax.Var("y", R.Tensor((4, 5), "float32"))
+    y1 = relax.Var("y", R.Tensor((), "float32"))
+
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.matmul(x0, y1))
+    with pytest.raises(TVMError):
+        bb.normalize(relax.op.matmul(x1, y0))
+
+
+def test_matmul_infer_struct_info_not_broadcastable():
+    bb = relax.BlockBuilder()
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    y = relax.Var("y", R.Tensor((2, 8, 3, 5, 6), "float32"))
+
+    with pytest.raises(TVMError):  # batch dims (2, 3) and (2, 8, 3) cannot broadcast
+        bb.normalize(relax.op.matmul(x, y))
+
+
+def test_matmul_infer_struct_info_unequal_reduction_length():
+    bb = relax.BlockBuilder()
+    k = tir.Var("k", "int64")
+    x0 = relax.Var("x", R.Tensor((3, 4), "float32"))
+    x1 = relax.Var("x", R.Tensor((3, k), "float32"))
+    y0 = relax.Var("y", R.Tensor((6, 5), "float32"))
+    y1 = relax.Var("y", R.Tensor((k + 1, 5), "float32"))
+
+    with pytest.raises(TVMError):  # static mismatch: 4 vs 6
+        bb.normalize(relax.op.matmul(x0, y0))
+    with pytest.raises(TVMError):  # symbolic mismatch: k vs k + 1 is provably unequal
+        bb.normalize(relax.op.matmul(x1, y1))
+
+
+def test_linear():
+    # Since linear is only sugar for transpose + matmul + add,
+    # we only have brief tests here.
+    bb = relax.BlockBuilder()
+    x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32"))
+    x2 = relax.Var("x", R.Tensor("float32"))
+    w1 = relax.Var("w", R.Tensor((5, 4), "float32"))
+    w2 = relax.Var("w", R.Tensor((4,), "float32"))
+    w3 = relax.Var("w", R.Tensor("float32"))
+    b1 = relax.Var("b", R.Tensor((5,), "float32"))
+    b2 = relax.Var("b", R.Tensor((), "float32"))
+
+    # Need a scope to normalize non-leaf nodes
+    with bb.function("func", [x1]):
+        _check_inference(
+            bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32")
+        )
+        _check_inference(
+            bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32")
+        )
+        with pytest.raises(TVMError):
+            bb.normalize(relax.op.linear(x1, w2, b1))  # error on Add with shapes (2, 3) and (5,)
+        _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorStructInfo((2, 3), "float32"))
+        _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorStructInfo(dtype="float32"))
+        _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorStructInfo(dtype="float32"))
+
+        # Emit a dummy output only to close the function scope
+        gv = bb.emit_func_output(relax.Tuple([]))
+
+
+if __name__ == "__main__":  # allow running this test file directly
+    tvm.testing.main()
diff --git a/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py
new file mode 100644
index 0000000000..1ed7fa9b91
--- /dev/null
+++ b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py
@@ -0,0 +1,80 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from typing import Optional, Union
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import IRModule, relax
+from tvm.script import relax as R
+
+
+def _check(
+    parsed: Union[relax.Function, IRModule],
+    expect: Optional[Union[relax.Function, IRModule]],
+):
+    test = parsed.script(show_meta=True)  # print back to TVMScript text
+    roundtrip_mod = tvm.script.from_source(test)  # re-parse; must equal the original
+    tvm.ir.assert_structural_equal(parsed, roundtrip_mod)
+    if expect:  # optionally compare against a BlockBuilder-constructed reference
+        tvm.ir.assert_structural_equal(parsed, expect)
+
+
+def test_matmul():  # R.matmul in TVMScript must match relax.op.matmul emitted by BlockBuilder
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((6, 2, 3, 5, 7), "float32")
+    ) -> R.Tensor((6, 2, 3, 4, 7), "float32"):
+        gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    y = relax.Var("y", R.Tensor((6, 2, 3, 5, 7), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, y]):
+        gv = bb.emit(relax.op.matmul(x, y))
+        bb.emit_func_output(gv)
+
+    _check(foo, bb.get()["foo"])
+
+
+def test_linear():  # R.linear must expand to the same permute_dims + matmul + add sequence
+    @R.function
+    def foo(
+        x: R.Tensor((2, 3, 4, 5), "float32"),
+        w: R.Tensor((3, 5), "float32"),
+        bias: R.Tensor((3,), "float32"),
+    ):
+        gv = R.linear(x, w, bias)
+        return gv
+
+    x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32"))
+    w = relax.Var("y", R.Tensor((3, 5), "float32"))  # NOTE(review): Var name "y" differs from w
+    bias = relax.Var("bias", R.Tensor((3,), "float32"))
+    bb = relax.BlockBuilder()
+    with bb.function("foo", [x, w, bias]):
+        w_T = bb.emit(relax.op.permute_dims(w, axes=None))
+        matmul = bb.emit(relax.op.matmul(x, w_T))
+        out = matmul + bias
+        bb.emit_func_output(out)
+
+    _check(foo, bb.get()["foo"])
+
+
+if __name__ == "__main__":  # allow running this test file directly
+    tvm.testing.main()

Reply via email to