This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch unity-staging in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 15800607f50221e4cff9d7d564db2b11da587a35 Author: Ruihang Lai <[email protected]> AuthorDate: Tue Feb 14 15:07:33 2023 -0500 [Unity] Relax op: linear algebra (#13988) This PR is about the high-level tensor computation operators in Relax. This PR includes the linear algebra operators. Co-authored-by: Siyuan Feng <[email protected]> --- include/tvm/relax/attrs/linear_algebra.h | 44 ++++ python/tvm/relax/op/__init__.py | 1 + python/tvm/relax/op/linear_algebra.py | 90 ++++++++ python/tvm/relax/op/op_attrs.py | 5 + python/tvm/script/ir_builder/relax/ir.py | 4 + src/relax/op/tensor/linear_algebra.cc | 123 +++++++++++ src/relax/op/tensor/linear_algebra.h | 49 +++++ tests/python/relax/test_op_linear_algebra.py | 244 +++++++++++++++++++++ .../test_tvmscript_parser_op_linear_algebra.py | 80 +++++++ 9 files changed, 640 insertions(+) diff --git a/include/tvm/relax/attrs/linear_algebra.h b/include/tvm/relax/attrs/linear_algebra.h new file mode 100644 index 0000000000..4b0e04298c --- /dev/null +++ b/include/tvm/relax/attrs/linear_algebra.h @@ -0,0 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file tvm/relax/attrs/linear_algebra.h + * \brief Attributes for linear algebra operators. 
+ */ +#ifndef TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ + +#include <tvm/relax/expr.h> + +namespace tvm { +namespace relax { + +/*! \brief Attributes for matmul operator */ +struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> { + DataType out_dtype; + + TVM_DECLARE_ATTRS(MatmulAttrs, "relax.attrs.MatmulAttrs") { + TVM_ATTR_FIELD(out_dtype).describe("The data type of the output tensor"); + } +}; // struct MatmulAttrs + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_ATTRS_LINEAR_ALGEBRA_H_ diff --git a/python/tvm/relax/op/__init__.py b/python/tvm/relax/op/__init__.py index 97d08c0946..4b2f990eaa 100644 --- a/python/tvm/relax/op/__init__.py +++ b/python/tvm/relax/op/__init__.py @@ -23,6 +23,7 @@ from .binary import * from .create import * from .datatype import * from .index import * +from .linear_algebra import * from .manipulate import * from .op_attrs import * from .statistical import * diff --git a/python/tvm/relax/op/linear_algebra.py b/python/tvm/relax/op/linear_algebra.py new file mode 100644 index 0000000000..940861a972 --- /dev/null +++ b/python/tvm/relax/op/linear_algebra.py @@ -0,0 +1,90 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+# pylint: disable=invalid-name +"""Relax linear algebra operators""" +from typing import Optional, Union + +from tvm import DataType + +from ..expr import Expr +from . import _ffi_api +from .manipulate import permute_dims + + +def matmul(x1: Expr, x2: Expr, out_dtype: Optional[Union[str, DataType]] = None) -> Expr: + """General matrix multiplication of two tensors, with broadcasting on batched dimensions. + + The semantics and output shape deduction rule is specified as + https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html. + + Parameters + ---------- + x1 : relax.Expr + The first input tensor. + + x2 : relax.Expr + The second input tensor. + + out_dtype: Optional[Union[str, DataType]] + The data type of the matmul result. + When it is not specified, the output dtype will be the same as input dtype. + + Returns + ------- + result : relax.Expr + The computed result. + """ + return _ffi_api.matmul(x1, x2, out_dtype) # type: ignore + + +def linear( + data: Expr, + weight: Expr, + bias: Optional[Expr] = None, + out_dtype: Optional[Union[str, DataType]] = None, +) -> Expr: + """Applies a linear transformation to the incoming data: y = xA^T + b + + Parameters + ---------- + data : relax.Expr + The input data. + + weight : relax.Expr + The weight tensor. + + bias : Optional[Expr] + The bias tensor. + + out_dtype: Optional[Union[str, DataType]] + The data type of the matmul result. + When it is not specified, the output dtype will be the same as input dtype. + + Notes + ----- + Relax does not regard the Linear Op as a primitive Op, + but combines the transpose, matmul and add ops to implement it. + + Returns + ------- + result : relax.Expr + The computed result. + """ + + # Since weight can be 1D or 2D, we use `axes=None` to support both cases. 
+ x = matmul(data, permute_dims(weight, axes=None), out_dtype=out_dtype) + return x + bias if bias is not None else x diff --git a/python/tvm/relax/op/op_attrs.py b/python/tvm/relax/op/op_attrs.py index ac6714d940..3a7ed427f9 100644 --- a/python/tvm/relax/op/op_attrs.py +++ b/python/tvm/relax/op/op_attrs.py @@ -44,6 +44,11 @@ class StridedSliceAttrs(Attrs): """Attributes used in strided_slice operator""" +@tvm._ffi.register_object("relax.attrs.MatmulAttrs") +class MatmulAttrs(Attrs): + """Attributes for matmul operator""" + + @tvm._ffi.register_object("relax.attrs.Conv2DAttrs") class Conv2DAttrs(Attrs): """Attributes for nn.conv2d""" diff --git a/python/tvm/script/ir_builder/relax/ir.py b/python/tvm/script/ir_builder/relax/ir.py index 118790372a..9f5fe03dec 100644 --- a/python/tvm/script/ir_builder/relax/ir.py +++ b/python/tvm/script/ir_builder/relax/ir.py @@ -63,8 +63,10 @@ from tvm.relax.op import ( isnan, less, less_equal, + linear, log, make_closure, + matmul, max, mean, memory, @@ -504,8 +506,10 @@ __all__ = [ "isnan", "less", "less_equal", + "linear", "log", "make_closure", + "matmul", "max", "mean", "memory", diff --git a/src/relax/op/tensor/linear_algebra.cc b/src/relax/op/tensor/linear_algebra.cc new file mode 100644 index 0000000000..50b53d0c8e --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.cc @@ -0,0 +1,123 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.cc + * \brief Linear algebra operators. + */ + +#include "linear_algebra.h" + +#include <algorithm> +#include <utility> + +namespace tvm { +namespace relax { + +/* relax.matmul */ +TVM_REGISTER_NODE_TYPE(MatmulAttrs); + +Expr matmul(Expr x1, Expr x2, DataType out_dtype) { + ObjectPtr<MatmulAttrs> attrs = make_object<MatmulAttrs>(); + attrs->out_dtype = out_dtype; + + static const Op& op = Op::Get("relax.matmul"); + return Call(op, {std::move(x1), std::move(x2)}, Attrs(attrs), {}); +} + +TVM_REGISTER_GLOBAL("relax.op.matmul").set_body_typed(matmul); + +StructInfo InferStructInfoMatmul(const Call& call, const BlockBuilder& ctx) { + Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx); + TensorStructInfo x1_sinfo = input_sinfo[0]; + TensorStructInfo x2_sinfo = input_sinfo[1]; + + const auto* attrs = call->attrs.as<MatmulAttrs>(); + DataType out_dtype = attrs->out_dtype.is_void() + ? InferBinaryArithOpOutDtype(call, ctx, x1_sinfo, x2_sinfo) + : attrs->out_dtype; + + if (x1_sinfo->IsUnknownNdim() || x2_sinfo->IsUnknownNdim()) { + return TensorStructInfo(out_dtype, kUnknownNDim); + } + int x1_ndim = x1_sinfo->ndim; + int x2_ndim = x2_sinfo->ndim; + if (x1_ndim == 0 || x2_ndim == 0) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires both inputs to have at least 1 dimension. However, " + << (x1_ndim == 0 ? 
"x1" : "x2") << " is a 0-rank tensor."); + } + + int x1_prepended = 0; + int x2_appended = 0; + if (x1_ndim == 1) { + x1_ndim = 2; + x1_prepended = 1; + } + if (x2_ndim == 1) { + x2_ndim = 2; + x2_appended = 1; + } + int output_ndim = std::max(x1_ndim, x2_ndim) - x1_prepended - x2_appended; + + const auto* x1_shape = x1_sinfo->shape.as<ShapeExprNode>(); + const auto* x2_shape = x2_sinfo->shape.as<ShapeExprNode>(); + if (x1_shape == nullptr || x2_shape == nullptr) { + return TensorStructInfo(out_dtype, output_ndim); + } + + Array<PrimExpr> x1_shape_prefix{x1_shape->values.begin(), + x1_shape->values.end() - 2 + x1_prepended}; + Array<PrimExpr> x2_shape_prefix{x2_shape->values.begin(), + x2_shape->values.end() - 2 + x2_appended}; + Optional<Array<PrimExpr>> output_shape_prefix = + InferBinaryBroadcastShape(call, ctx, x1_shape_prefix, x2_shape_prefix); + if (!output_shape_prefix.defined()) { + return TensorStructInfo(out_dtype, output_ndim); + } + + arith::Analyzer* analyzer = ctx->GetAnalyzer(); + PrimExpr x1_reduction_length = x1_shape->values[x1_sinfo->ndim - 1]; + PrimExpr x2_reduction_length = x2_shape->values[x2_ndim - 2]; + if (analyzer->CanProve(x1_reduction_length != x2_reduction_length)) { + ctx->ReportFatal(Diagnostic::Error(call) + << "Matmul requires the reduction length of x1 and x2 to be equal. 
However, " + "the reduction lengths of x1 and x2 are " + << x1_reduction_length << " and " << x2_reduction_length << " respectively."); + } + + Array<PrimExpr> output_shape = output_shape_prefix.value(); + if (!x1_prepended) { + output_shape.push_back(x1_shape->values[x1_ndim - 2]); + } + if (!x2_appended) { + output_shape.push_back(x2_shape->values[x2_ndim - 1]); + } + ICHECK_EQ(static_cast<int>(output_shape.size()), output_ndim); + return TensorStructInfo(ShapeExpr(output_shape), out_dtype); +} + +TVM_REGISTER_OP("relax.matmul") + .set_num_inputs(2) + .add_argument("x1", "Tensor", "The first input tensor.") + .add_argument("x2", "Tensor", "The second input tensor.") + .set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoMatmul); + +} // namespace relax +} // namespace tvm diff --git a/src/relax/op/tensor/linear_algebra.h b/src/relax/op/tensor/linear_algebra.h new file mode 100644 index 0000000000..af614c1f30 --- /dev/null +++ b/src/relax/op/tensor/linear_algebra.h @@ -0,0 +1,49 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file linear_algebra.h + * \brief The functions to make Relax linear algebra operator calls. 
+ */ +#ifndef TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ +#define TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ + +#include <tvm/relax/attrs/linear_algebra.h> + +#include "../op_common.h" + +namespace tvm { +namespace relax { + +/*! + * \brief General matrix multiplication of two tensors. + * The semantics and output shape deduction rule is specified as + * https://data-apis.org/array-api/latest/API_specification/generated/array_api.matmul.html. + * \param x1 The first input tensor. + * \param x2 The second input tensor. + * \param out_dtype The data type of the matmul result. + * When it is not specified, the output dtype will be the same as input dtype. + * \return The computed result. + */ +Expr matmul(Expr x1, Expr x2, DataType out_dtype); + +} // namespace relax +} // namespace tvm + +#endif // TVM_RELAX_OP_TENSOR_LINEAR_ALGEBRA_H_ diff --git a/tests/python/relax/test_op_linear_algebra.py b/tests/python/relax/test_op_linear_algebra.py new file mode 100644 index 0000000000..5eb19cf2b4 --- /dev/null +++ b/tests/python/relax/test_op_linear_algebra.py @@ -0,0 +1,244 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+import pytest +import tvm +import tvm.testing +from tvm import relax, tir +from tvm import TVMError +from tvm.ir import Op +from tvm.script import relax as R + + +def test_op_correctness(): + x = relax.Var("x", R.Tensor((2, 3), "float32")) + y = relax.Var("y", R.Tensor((3, 4), "float32")) + assert relax.op.matmul(x, y).op == Op.get("relax.matmul") + + +def _check_inference(bb: relax.BlockBuilder, call: relax.Call, expected_sinfo: relax.StructInfo): + ret = bb.normalize(call) + tvm.ir.assert_structural_equal(ret.struct_info, expected_sinfo) + + +def test_matmul_infer_struct_info(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((4,), "float32")) + x2 = relax.Var("x", R.Tensor((2, 3, 5, 4), "float32")) + x3 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float32")) + x4 = relax.Var("x", R.Tensor((2, 1, 4, 5))) + x5 = relax.Var("x", R.Tensor("float32")) + x6 = relax.Var("x", R.Tensor((2, 1, 4, 5), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((4,), "float32")) + y2 = relax.Var("y", R.Tensor((2, 3, 4, 5), "float32")) + y3 = relax.Var("y", R.Tensor((6, 1, 3, 5, 7), "float32")) + y4 = relax.Var("y", R.Tensor("float32", ndim=5)) + y5 = relax.Var("y", R.Tensor()) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((2, 3, 5), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "float32") + ) + _check_inference(bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((6, 2, 3, 4, 7), "")) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + _check_inference(bb, relax.op.matmul(x5, y3), 
relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x3, y5), relax.TensorStructInfo(dtype="")) + _check_inference( + bb, + relax.op.matmul(x3, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + _check_inference( + bb, + relax.op.matmul(x6, y3, out_dtype="float16"), + relax.TensorStructInfo((6, 2, 3, 4, 7), "float16"), + ) + + +def test_matmul_infer_struct_info_shape_symbolic(): + bb = relax.BlockBuilder() + m = tir.Var("m", "int64") + n = tir.Var("n", "int64") + k0 = tir.Var("k0", "int64") + k1 = tir.Var("k1", "int64") + a = tir.Var("a", "int64") + b = tir.Var("b", "int64") + b1 = tir.Var("b", "int64") + c = tir.Var("c", "int64") + x0 = relax.Var("x", R.Tensor((m, k0), "float32")) + x1 = relax.Var("x", R.Tensor((k0,), "float32")) + x2 = relax.Var("x", R.Tensor((a, b, m, k0), "float32")) + x3 = relax.Var("x", R.Tensor((b, 1, m, k0), "float32")) + x4 = relax.Var("x", R.Tensor((b, 1, m, k1), "float32")) + y0 = relax.Var("y", R.Tensor((k0, n), "float32")) + y1 = relax.Var("y", R.Tensor((k0,), "float32")) + y2 = relax.Var("y", R.Tensor((a, b, k0, n), "float32")) + y3 = relax.Var("y", R.Tensor((a, 1, c, k0, n), "float32")) + y4 = relax.Var("y", R.Tensor((a, b1, c, k0, n), "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((m, n), "float32")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((), "float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo((a, b, n), "float32")) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo((a, b, m), "float32")) + _check_inference( + bb, relax.op.matmul(x3, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference( + bb, relax.op.matmul(x4, y3), relax.TensorStructInfo((a, b, c, m, n), "float32") + ) + _check_inference(bb, relax.op.matmul(x3, y4), relax.TensorStructInfo(dtype="float32", ndim=5)) + + +def test_matmul_infer_struct_info_shape_var(): + bb = 
relax.BlockBuilder() + s0 = relax.Var("s0", relax.ShapeStructInfo(ndim=4)) + s1 = relax.Var("s1", relax.ShapeStructInfo(ndim=3)) + s2 = relax.Var("s3", relax.ShapeStructInfo(ndim=1)) + s3 = relax.Var("s4", relax.ShapeStructInfo(ndim=1)) + s5 = relax.Var("s5", relax.ShapeStructInfo()) + x0 = relax.Var("x", relax.TensorStructInfo(s0, "float32")) + x1 = relax.Var("x", relax.TensorStructInfo(s2, "float32")) + x2 = relax.Var("x", relax.TensorStructInfo(s5, "float32")) + y0 = relax.Var("y", relax.TensorStructInfo(s1, "float32")) + y1 = relax.Var("y", relax.TensorStructInfo(s2, "float32")) + y2 = relax.Var("y", relax.TensorStructInfo(s3, "float32")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo(dtype="float32", ndim=4)) + _check_inference(bb, relax.op.matmul(x1, y0), relax.TensorStructInfo(dtype="float32", ndim=2)) + _check_inference(bb, relax.op.matmul(x2, y0), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x0, y1), relax.TensorStructInfo(dtype="float32", ndim=3)) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo(dtype="float32", ndim=0)) + _check_inference(bb, relax.op.matmul(x2, y1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.matmul(x1, y2), relax.TensorStructInfo(dtype="float32", ndim=0)) + + +def test_matmul_infer_struct_info_more_input_dtype(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4), "int64")) + y2 = relax.Var("y", R.Tensor((4, 5), "int64")) + + _check_inference(bb, relax.op.matmul(x0, y0), relax.TensorStructInfo((3, 5), "float16")) + _check_inference(bb, relax.op.matmul(x1, y1), relax.TensorStructInfo((3, 5), "int8")) + _check_inference(bb, relax.op.matmul(x2, y2), relax.TensorStructInfo((3, 5), "int64")) + + +def 
test_matmul_infer_struct_info_mixed_precision(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float16")) + y0 = relax.Var("y", R.Tensor((4, 5), "float16")) + x1 = relax.Var("x", R.Tensor((3, 4), "int8")) + y1 = relax.Var("y", R.Tensor((4, 5), "int8")) + x2 = relax.Var("x", R.Tensor((3, 4))) + y2 = relax.Var("y", R.Tensor((4, 5))) + + _check_inference( + bb, + relax.op.matmul(x0, y0, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + _check_inference( + bb, relax.op.matmul(x1, y1, out_dtype="int32"), relax.TensorStructInfo((3, 5), "int32") + ) + _check_inference( + bb, + relax.op.matmul(x2, y2, out_dtype="float32"), + relax.TensorStructInfo((3, 5), "float32"), + ) + + +def test_matmul_infer_struct_info_zero_rank_input(): + bb = relax.BlockBuilder() + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((), "float32")) + y0 = relax.Var("y", R.Tensor((4, 5), "float32")) + y1 = relax.Var("y", R.Tensor((), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y1)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y0)) + + +def test_matmul_infer_struct_info_not_broadcastable(): + bb = relax.BlockBuilder() + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((2, 8, 3, 5, 6), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x, y)) + + +def test_matmul_infer_struct_info_unequal_reduction_length(): + bb = relax.BlockBuilder() + k = tir.Var("k", "int64") + x0 = relax.Var("x", R.Tensor((3, 4), "float32")) + x1 = relax.Var("x", R.Tensor((3, k), "float32")) + y0 = relax.Var("y", R.Tensor((6, 5), "float32")) + y1 = relax.Var("y", R.Tensor((k + 1, 5), "float32")) + + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x0, y0)) + with pytest.raises(TVMError): + bb.normalize(relax.op.matmul(x1, y1)) + + +def test_linear(): + # Since linear is only a sugar for transpose + matmul + add, + # we only 
have brief tests here. + bb = relax.BlockBuilder() + x1 = relax.Var("x", R.Tensor((2, 3, 4), "float32")) + x2 = relax.Var("x", R.Tensor("float32")) + w1 = relax.Var("w", R.Tensor((5, 4), "float32")) + w2 = relax.Var("w", R.Tensor((4,), "float32")) + w3 = relax.Var("w", R.Tensor("float32")) + b1 = relax.Var("b", R.Tensor((5,), "float32")) + b2 = relax.Var("b", R.Tensor((), "float32")) + + # Need a scope to normalize non-leaf nodes + with bb.function("func", [x1]): + _check_inference( + bb, relax.op.linear(x1, w1, b1), relax.TensorStructInfo((2, 3, 5), "float32") + ) + _check_inference( + bb, relax.op.linear(x1, w1, b2), relax.TensorStructInfo((2, 3, 5), "float32") + ) + with pytest.raises(TVMError): + bb.normalize(relax.op.linear(x1, w2, b1)) # error on Add with shape (2, 3, 5) and (4,) + _check_inference(bb, relax.op.linear(x1, w2, b2), relax.TensorStructInfo((2, 3), "float32")) + _check_inference(bb, relax.op.linear(x1, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x1, w3, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w1, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w2, b2), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b1), relax.TensorStructInfo(dtype="float32")) + _check_inference(bb, relax.op.linear(x2, w3, b2), relax.TensorStructInfo(dtype="float32")) + + # Fake output + gv = bb.emit_func_output(relax.Tuple([])) + + +if __name__ == "__main__": + tvm.testing.main() diff --git a/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py new file mode 100644 index 0000000000..1ed7fa9b91 --- /dev/null +++ 
b/tests/python/relax/test_tvmscript_parser_op_linear_algebra.py @@ -0,0 +1,80 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from typing import Optional, Union + +import tvm +import tvm.script +import tvm.testing +from tvm import IRModule, relax +from tvm.script import relax as R + + +def _check( + parsed: Union[relax.Function, IRModule], + expect: Optional[Union[relax.Function, IRModule]], +): + test = parsed.script(show_meta=True) + roundtrip_mod = tvm.script.from_source(test) + tvm.ir.assert_structural_equal(parsed, roundtrip_mod) + if expect: + tvm.ir.assert_structural_equal(parsed, expect) + + +def test_matmul(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), y: R.Tensor((6, 2, 3, 5, 7), "float32") + ) -> R.Tensor((6, 2, 3, 4, 7), "float32"): + gv: R.Tensor((6, 2, 3, 4, 7), "float32") = R.matmul(x, y) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + y = relax.Var("y", R.Tensor((6, 2, 3, 5, 7), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, y]): + gv = bb.emit(relax.op.matmul(x, y)) + bb.emit_func_output(gv) + + _check(foo, bb.get()["foo"]) + + +def test_linear(): + @R.function + def foo( + x: R.Tensor((2, 3, 4, 5), "float32"), + w: R.Tensor((3, 5), 
"float32"), + bias: R.Tensor((3,), "float32"), + ): + gv = R.linear(x, w, bias) + return gv + + x = relax.Var("x", R.Tensor((2, 3, 4, 5), "float32")) + w = relax.Var("y", R.Tensor((3, 5), "float32")) + bias = relax.Var("bias", R.Tensor((3,), "float32")) + bb = relax.BlockBuilder() + with bb.function("foo", [x, w, bias]): + w_T = bb.emit(relax.op.permute_dims(w, axes=None)) + matmul = bb.emit(relax.op.matmul(x, w_T)) + out = matmul + bias + bb.emit_func_output(out) + + _check(foo, bb.get()["foo"]) + + +if __name__ == "__main__": + tvm.testing.main()
