This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9bb30d7ea1 [Unity] Relax VM shape lowering pass (#13956)
9bb30d7ea1 is described below
commit 9bb30d7ea13df661e083b92774259c9cc7953c08
Author: Yuchen Jin <[email protected]>
AuthorDate: Fri Feb 10 23:33:46 2023 -0800
[Unity] Relax VM shape lowering pass (#13956)
This PR introduces Relax `FunctionPass` and `DataflowBlockPass` API, and
the `VMShapeLower` pass to lower the shape expression in Relax to TIR functions
and VM shape heap builtin functions.
Co-Authored-by: Ziheng Jiang <[email protected]>
Co-Authored-by: Lesheng Jin <[email protected]>
Co-Authored-by: Altan Haan <[email protected]>
Co-Authored-by: Junru Shao <[email protected]>
Co-Authored-by: Prakalp Srivastava <[email protected]>
Co-Authored-by: Ruihang Lai <[email protected]>
Co-Authored-by: Siyuan Feng <[email protected]>
Co-Authored-by: Steven S. Lyubomirsky <[email protected]>
Co-Authored-by: Sunghyun Park <[email protected]>
Co-Authored-by: Tianqi Chen <[email protected]>
Co-Authored-by: Yong Wu <[email protected]>
---
include/tvm/relax/backend.h | 44 ++
include/tvm/relax/transform.h | 72 ++
python/tvm/relax/__init__.py | 1 +
python/tvm/relax/transform/__init__.py | 20 +
python/tvm/relax/transform/_ffi_api.py | 19 +
python/tvm/relax/transform/transform.py | 345 ++++++++++
src/relax/backend/vm/vm_shape_lower.cc | 725 +++++++++++++++++++++
src/relax/ir/transform.cc | 413 ++++++++++++
.../relax/test_backend_transform_shape_lower.py | 429 ++++++++++++
9 files changed, 2068 insertions(+)
diff --git a/include/tvm/relax/backend.h b/include/tvm/relax/backend.h
new file mode 100644
index 0000000000..4ebeacac0f
--- /dev/null
+++ b/include/tvm/relax/backend.h
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/backend.h
+ * \brief Relax backend specific transformation passes.
+ */
+#ifndef TVM_RELAX_BACKEND_H_
+#define TVM_RELAX_BACKEND_H_
+
+#include <tvm/relax/transform.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+/*!
+ * \brief Lower the shape expression in relax to VM shape heap and TIR
functions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass VMShapeLower();
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_BACKEND_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
new file mode 100644
index 0000000000..fa288a7f06
--- /dev/null
+++ b/include/tvm/relax/transform.h
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file tvm/relax/transform.h
+ * \brief Relax specific transformation passes.
+ */
+#ifndef TVM_RELAX_TRANSFORM_H_
+#define TVM_RELAX_TRANSFORM_H_
+
+#include <tvm/ir/transform.h>
+#include <tvm/relax/expr.h>
+
+namespace tvm {
+namespace relax {
+namespace transform {
+
+using Pass = tvm::transform::Pass;
+using PassInfo = tvm::transform::PassInfo;
+using PassContext = tvm::transform::PassContext;
+using Function = tvm::relax::Function;
+using DataflowBlock = tvm::relax::DataflowBlock;
+
+/*!
+ * \brief Create a function pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the function pass.
+ * \param name The name of the function pass.
+ * \param required The list of the passes that the function pass is dependent
on.
+ *
+ * \return The created function pass.
+ */
+TVM_DLL Pass CreateFunctionPass(
+ const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>&
pass_func,
+ int opt_level, String name, tvm::Array<String> required);
+
+/*!
+ * \brief Create a dataflowblock pass.
+ *
+ * \param pass_func The packed function that contains the optimization.
+ * \param opt_level The optimization level of the dataflowblock pass.
+ * \param name The name of the dataflowblock pass.
+ * \param required The list of the passes that the dataflowblock pass is
dependent on.
+ *
+ * \return The created dataflowblock pass.
+ */
+TVM_DLL Pass CreateDataflowBlockPass(
+ const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)>& pass_func,
+ int opt_level, String name, tvm::Array<String> required);
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
+
+#endif // TVM_RELAX_TRANSFORM_H_
diff --git a/python/tvm/relax/__init__.py b/python/tvm/relax/__init__.py
index ce175354d0..a6306b788e 100644
--- a/python/tvm/relax/__init__.py
+++ b/python/tvm/relax/__init__.py
@@ -20,6 +20,7 @@ from . import exec_builder
from . import expr
from . import ty
from . import analysis
+from . import transform
from . import vm
from . import block_builder
from . import op
diff --git a/python/tvm/relax/transform/__init__.py
b/python/tvm/relax/transform/__init__.py
new file mode 100644
index 0000000000..eb4d5f710c
--- /dev/null
+++ b/python/tvm/relax/transform/__init__.py
@@ -0,0 +1,20 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=wildcard-import, redefined-builtin
+"""Relax transformations. """
+
+from .transform import *
diff --git a/python/tvm/relax/transform/_ffi_api.py
b/python/tvm/relax/transform/_ffi_api.py
new file mode 100644
index 0000000000..667aa62c2c
--- /dev/null
+++ b/python/tvm/relax/transform/_ffi_api.py
@@ -0,0 +1,19 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+"""FFI APIs for tvm.transform"""
+import tvm._ffi
+
+tvm._ffi._init_api("relax.transform", __name__)
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
new file mode 100644
index 0000000000..f20f06c522
--- /dev/null
+++ b/python/tvm/relax/transform/transform.py
@@ -0,0 +1,345 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Relax transformation passes."""
+import functools
+import inspect
+import types
+from typing import Callable, Union
+
+import tvm.ir
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("relax.FunctionPass")
+class FunctionPass(tvm.ir.transform.Pass):
+ """A pass that works on each tvm.relax.Function in a module. A function
+ pass class should be created through `function_pass`.
+ """
+
+
+@tvm._ffi.register_object("relax.DataflowBlockPass")
+class DataflowBlockPass(tvm.ir.transform.Pass):
+ """A pass that works on each tvm.relax.DataflowBlock in a module."""
+
+
+def VMShapeLower(*, emit_err_ctx: bool = True) -> tvm.ir.transform.Pass:
+ """Lower the symbolic shape and argument and match-cast structinfo
matching.
+
+ Parameters
+ ----------
+ emit_err_ctx: Optional[bool]
+ Whether emit err context string, can be turned off for testing
purposes.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.VMShapeLower(emit_err_ctx) # type: ignore
+
+
+def _wrap_class_function_pass(pass_cls, pass_info):
+ """Wrap a python class as function pass."""
+
+ class PyFunctionPass(FunctionPass):
+ """Internal wrapper class to create a class instance."""
+
+ def __init__(self, *args, **kwargs):
+ # initialize handle in case pass_cls creation failed.
+ self.handle = None
+ inst = pass_cls(*args, **kwargs)
+
+ # it is important not to capture self to
+ # avoid a cyclic dependency
+ def _pass_func(func, mod, ctx):
+ return inst.transform_function(func, mod, ctx)
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.MakeFunctionPass, _pass_func, pass_info # type:
ignore
+ )
+ self._inst = inst
+
+ def __getattr__(self, name):
+ # fall back to instance attribute if there is not any
+ return self._inst.__getattribute__(name)
+
+ functools.update_wrapper(PyFunctionPass.__init__, pass_cls.__init__)
+ PyFunctionPass.__name__ = pass_cls.__name__
+ PyFunctionPass.__doc__ = pass_cls.__doc__
+ PyFunctionPass.__module__ = pass_cls.__module__
+ return PyFunctionPass
+
+
+def function_pass(
+ pass_func=None,
+ opt_level=None,
+ name=None,
+ required=None,
+) -> Union[Callable, FunctionPass]:
+ """Decorate a function pass.
+
+ This function returns a callback when pass_func
+ is provided. Otherwise, it returns the created function pass using the
+ given optimization function.
+
+ Parameters
+ ----------
+ pass_func : Optional[Callable[(Function, Module, PassContext) -> Function]]
+ The transformation function or class.
+
+ opt_level : int
+ The optimization level of this function pass.
+
+ name : Optional[str]
+ The name of the function pass. The name could be empty. In this case,
the
+ name of the optimization function will be used as the pass name.
+
+ required : Optional[List[str]]
+ The list of passes that the function pass is dependent on.
+
+ Returns
+ -------
+ create_function_pass : Union[Callable, FunctionPass]
+
+ A decorator will be returned if pass_func is not provided,
+ otherwise return the decorated result.
+ The returned decorator has two behaviors depending on the input:
+ A new FunctionPass will be returned when we decorate a pass function.
+ A new FunctionPass class will be returned when we decorate a class
type.
+
+ Examples
+ --------
+ The following code block decorates a function pass class.
+
+ .. code-block:: python
+
+ @relax.transform.function_pass(opt_level=1)
+ class TestReplaceFunc:
+ def __init__(self, new_func):
+ self.new_func = new_func
+
+ def transform_function(self, func, mod, ctx):
+ # just for demo purposes
+ # transform func to new_func
+ return self.new_func
+
+ @R.function
+ def f1(x: Tensor[(m, n), "float32"]):
+ return x
+
+ @tvm.script.ir_module
+ class InputMod:
+ @R.function
+ def f2(x: Tensor[(m, n), "float32"]):
+ gv0 = relax.add(x, x)
+ return gv0
+ # fpass is now a special pass that replaces every
+ # function to f1
+ fpass = TestReplaceFunc(f1)
+ # now every function in InputMod is replaced by f1
+ updated_mod = fpass(InputMod)
+
+
+ The following code creates a function pass by decorating
+ a user defined transform function.
+
+ .. code-block:: python
+
+ @relax.transform.function_pass(opt_level=2)
+ def transform(func, mod, ctx):
+ # my transformations here.
+ return func
+
+ function_pass = transform
+ assert isinstance(function_pass, relax.transform.FunctionPass)
+ assert function_pass.info.opt_level == 2
+
+    # Given a module m, the optimization could be invoked as the following:
+ updated_mod = function_pass(m)
+ # Now transform should have been applied to every function in
+ # the provided module m. And the updated module will be returned.
+ """
+
+ if opt_level is None:
+ raise ValueError("Please provide opt_level for the function pass.")
+
+ required = required if required else []
+ if not isinstance(required, (list, tuple)):
+ raise TypeError("Required is expected to be the type of " +
"list/tuple.")
+
+ def create_function_pass(pass_arg):
+ """Internal function that creates a function pass"""
+ fname = name if name else pass_arg.__name__
+ info = tvm.transform.PassInfo(opt_level, fname, required)
+ if inspect.isclass(pass_arg):
+ return _wrap_class_function_pass(pass_arg, info)
+ if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for Function pass")
+ return _ffi_api.MakeFunctionPass(pass_arg, info) # type: ignore
+
+ if pass_func:
+ return create_function_pass(pass_func)
+ return create_function_pass
+
+
+def _wrap_class_dataflowblock_pass(pass_cls, pass_info):
+ """Wrap a python class as dataflowblock pass"""
+
+ class PyDataflowBlockPass(DataflowBlockPass):
+ """Internal wrapper class to create a class instance."""
+
+ def __init__(self, *args, **kwargs):
+ # initialize handle in case pass_cls creation failed.
+ self.handle = None
+ inst = pass_cls(*args, **kwargs)
+
+ # it is important not to capture self to
+ # avoid a cyclic dependency
+ def _pass_func(func, mod, ctx):
+ return inst.transform_dataflowblock(func, mod, ctx)
+
+ self.__init_handle_by_constructor__(
+ _ffi_api.MakeDataflowBlockPass, _pass_func, pass_info # type:
ignore
+ )
+ self._inst = inst
+
+ def __getattr__(self, name):
+ # fall back to instance attribute if there is not any
+ return self._inst.__getattribute__(name)
+
+ functools.update_wrapper(PyDataflowBlockPass.__init__, pass_cls.__init__)
+ PyDataflowBlockPass.__name__ = pass_cls.__name__
+ PyDataflowBlockPass.__doc__ = pass_cls.__doc__
+ PyDataflowBlockPass.__module__ = pass_cls.__module__
+ return PyDataflowBlockPass
+
+
+def dataflowblock_pass(
+ pass_func=None, opt_level=None, name=None, required=None
+) -> Union[Callable, DataflowBlockPass]:
+ """Decorate a dataflowblock pass.
+
+ This function returns a callback when pass_func
+ is provided. Otherwise, it returns the created dataflowblock pass using the
+ given optimization function.
+
+ Parameters
+ ----------
+ pass_func : Optional[Callable[(DataflowBlock, Module, PassContext) ->
DataflowBlock]]
+ The transformation function or class.
+
+ opt_level : int
+ The optimization level of this dataflowblock pass.
+
+ name : Optional[str]
+ The name of the dataflowblock pass. The name could be empty. In this
case, the
+ name of the optimization function will be used as the pass name.
+
+ required : Optional[List[str]]
+ The list of passes that the dataflowblock pass is dependent on.
+
+ Returns
+ -------
+ create_dataflowblock_pass : Union[Callable, DataflowBlockPass]
+
+ A decorator will be returned if pass_func is not provided,
+ otherwise return the decorated result.
+ The returned decorator has two behaviors depending on the input:
+ A new DataflowBlockPass will be returned when we decorate a pass
function.
+ A new DataflowBlockPass class will be returned when we decorate a
class type.
+
+ Examples
+ --------
+ The following code block decorates a dataflowblock pass class.
+
+ .. code-block:: python
+
+ @relax.transform.dataflowblock_pass(opt_level=1)
+ class TestReplaceBinding:
+ # Simple test function to replace the first VarBinding to another.
+
+ def __init__(self):
+ # create a new VarBinding
+ m, n = tir.Var("m", "int64"), tir.Var("n", "int64")
+ lv0 = relax.Var("lv1", relax.TensorStructInfo([m, n],
"float32"))
+ val = relax.const(np.random.rand(24, 56))
+ self.new_binding = relax.VarBinding(lv0, val)
+
+ def transform_dataflowblock(self, block, mod, ctx):
+ # just for demo purposes
+ # Replace the first binding in the DataflowBlock
+ new_bindings = [self.new_binding, block.bindings[1]]
+ new_block = relax.expr.DataflowBlock(new_bindings, block.span)
+ return new_block
+
+ @tvm.script.ir_module
+ class InputMod:
+ @R.function
+ def f1(x: Tensor[(m, n), "float32"]):
+ with relax.dataflow():
+ lv0 = relax.multiply(x, x)
+ gv0 = relax.add(x, x)
+ relax.output(gv0)
+ return gv0
+ # block_pass is now a special pass that replaces every
+ # first binding to the constant value binding
+ block_pass = TestReplaceBinding()
+ # now every first binding in DataflowBlock of InputMod
+ # is replaced by new_binding
+ updated_mod = block_pass(InputMod)
+
+
+ The following code creates a dataflowblock pass by decorating
+ a user defined transform function.
+
+ .. code-block:: python
+
+ @relax.transform.dataflowblock_pass(opt_level=2)
+ def transform(block, mod, ctx):
+ # my transformations here.
+ return block
+
+ block_pass = transform
+ assert isinstance(block_pass, relax.transform.DataflowBlockPass)
+ assert block_pass.info.opt_level == 2
+
+    # Given a module m, the optimization could be invoked as the following:
+ updated_mod = block_pass(m)
+ # Now transform should have been applied to every DataflowBlock in
+ # the provided module m. And the updated module will be returned.
+ """
+
+ if opt_level is None:
+ raise ValueError("Please provide opt_level for the dataflowblock
pass.")
+
+ required = required if required else []
+ if not isinstance(required, (list, tuple)):
+ raise TypeError("Required is expected to be the type of " +
"list/tuple.")
+
+ def create_dataflowblock_pass(pass_arg):
+ """Internal function that creates a dataflowblock pass"""
+ fname = name if name else pass_arg.__name__
+ info = tvm.transform.PassInfo(opt_level, fname, required)
+ if inspect.isclass(pass_arg):
+ return _wrap_class_dataflowblock_pass(pass_arg, info)
+ if not isinstance(pass_arg, (types.FunctionType, types.LambdaType)):
+ raise TypeError("pass_func must be a callable for DataflowBlock
pass")
+ return _ffi_api.MakeDataflowBlockPass(pass_arg, info) # type: ignore
+
+ if pass_func:
+ return create_dataflowblock_pass(pass_func)
+ return create_dataflowblock_pass
diff --git a/src/relax/backend/vm/vm_shape_lower.cc
b/src/relax/backend/vm/vm_shape_lower.cc
new file mode 100644
index 0000000000..090bcf01b5
--- /dev/null
+++ b/src/relax/backend/vm/vm_shape_lower.cc
@@ -0,0 +1,725 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+/*!
+ * \file src/relax/backend/vm/vm_shape_lower.cc
+ * \brief Lower the function boundary type checks and symbolic shape
computations.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/backend.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/runtime/relax_vm/builtin.h>
+#include <tvm/tir/analysis.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+#include <tvm/tir/stmt_functor.h>
+
+namespace tvm {
+namespace relax {
+
+/*! \brief A slot used in PrimExpr lowering. */
+struct PrimExprSlot {
+  /*! \brief The PrimExpr stored in this slot. */
+ PrimExpr expr;
+ /*! \brief The slot index */
+ int index;
+ // The following three members are auxiliary data
+ // to help shape rewriting.
+ /*!
+ * \brief List of slots whose PrimExpr uses this PrimExpr.
+   * \note user_slots can be non-empty only if this PrimExpr is a Var, and it does not
include itself.
+ */
+ std::vector<PrimExprSlot*> user_slots;
+ /*!
+ * \brief Number of outstanding vars that are not defined in this PrimExpr.
+ * \note This is a helper counter used in analysis to perform computations.
+ */
+ int outstanding_defs = 0;
+ /*! \brief Whether we have computed the value. */
+ bool value_computed = false;
+};
+
+/*!
+ * \brief Helper data structure to collect pairs of match shapes
+ * in a recursive matching process.
+ */
+struct MatchShapeTodoItem {
+ Expr input;
+ Array<PrimExpr> pattern;
+ String err_ctx;
+};
+
+/*! \brief Slot map used for shape lowering. */
+using PrimExprSlotMap =
+ std::unordered_map<PrimExpr, PrimExprSlot*, StructuralHash,
tir::ExprDeepEqual>;
+
+// Collector to collect PrimExprSlotMap
+class PrimExprSlotCollector : public ExprVisitor, public StructInfoVisitor {
+ public:
+ // collect the PrimExpr slot for a given function
+ static void Collect(Function func,
std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec,
+ PrimExprSlotMap* slot_map) {
+ PrimExprSlotCollector collector;
+ collector.slot_vec_ = slot_vec;
+ collector.slot_map_ = slot_map;
+ // collect shape declaration in func params
+ for (auto param : func->params) {
+ collector.VisitStructInfo(GetStructInfo(param));
+ collector.VisitExpr(param);
+ }
+ collector.VisitExpr(func->body);
+ }
+
+ private:
+ void VisitPrimExpr(const PrimExpr& expr) final {
+ if (expr->IsInstance<IntImmNode>()) return;
+ if (slot_map_->count(expr) == 0) {
+ auto slot = std::make_unique<PrimExprSlot>();
+ slot->expr = expr;
+ slot->index = static_cast<int>(slot_vec_->size());
+ slot_map_->emplace(expr, slot.get());
+ slot_vec_->emplace_back(std::move(slot));
+ }
+ }
+
+ void VisitBinding_(const MatchCastNode* op) final {
+ // Visit the match cast struct info so we can define
+ // the symbolic variables here.
+ this->VisitStructInfo(op->struct_info);
+ }
+
+ void VisitExpr_(const FunctionNode* op) final {
+ // Do not recurse into function node as it is self-contained
+ }
+
+ void VisitStructInfo_(const FuncStructInfoNode* op) final {
+ // Do not recurse into function struct info as it is self-contained
+ }
+
+ void VisitStructInfoExprField(const PrimExpr& expr) final {
VisitPrimExpr(expr); }
+
+ void VisitStructInfoExprField(const Expr& expr) final {
ExprVisitor::VisitExpr(expr); }
+
+ std::vector<std::unique_ptr<PrimExprSlot>>* slot_vec_;
+ PrimExprSlotMap* slot_map_;
+};
+
+/*!
+ * \brief Main logic to transform the shape lowered functions
+ *
+ * Consider the following input:
+ *
+ * \code
+ *
+ * def f(x: R.Tuple(R.Tensor([m, n+1]), R.Tensor([n, 2]))) -> R.Tensor:
+ * return x
+ *
+ * \endcode
+ *
+ * Overall flow of the algorithm:
+ * - Preprocess: PrimExprSlot collection, we scan the function and allocate
PrimExprSlot
+ * for each PrimExpr. In the above example, the result mapping from the slot
index
+ *   to expr would be {0: m, 1: n+1, 2: n}. Note that "n+1" also gets a slot.
+ * PrimExprSlot also comes with auxiliary fields that track whether its value
+ * can be readily computed.
+ *
+ * Steps at each matching point:
+ * - Step 0: We call CheckMatchCast,
+ * which will recursively unpack the StructInfo, and generate static
information checks.
+ * Note that this step only generates functions for checking types and ndim
info, but not
+ * the symbolic shape variables. The symbolic shape-matching results will be
returned as
+ * vector<MatchShapeTodoItem>. This is because symbolic shape matching may
not be completed
+ * in a single round. Importantly, CheckMatchCast also deals with tuple
unpacking.
+ *
+ * - Step 1: We then call RunMatch to generate the statements for matching
symbolic shapes.
+ * In the above example, the first round will store the value of m, n to
their corresponding
+ * slot. RunMatch may return outstanding items. In the above example
x.shape[1] == n+1 cannot
+ * be checked in the first round. RunMatch will populate new vars(this case
n, m), these vars
+ * are added to a ready queue (ready_vars_)
+ *
+ * - Step 2: We EmitOutstandingPrimExprCompute to check if ready_vars will
trigger new values
+ * to be computed. We eagerly compute all the outstanding values. The
trigger is done through
+ * a ref counter which decreases when each outstanding def is satisfied.
+ * This step can also generate additional TIR functions to carry out shape
computations.
+ *
+ * - Step 3: RunMatch again for given outstanding match todos. This time all
invariants
+ * should be checked.
+ *
+ * The above step would populate each slot (which is backed by an element in
shape_heap).
+ * Each time we find a symbolic shape tuple, we call MakeShape for given slot
indices
+ * in the shape_heap.
+ *
+ *
+ * Key functions in the flow:
+ * - PrimExprSlotCollector: preprocessing and collecting the slots
+ * - CheckMatchCast: recursively structinfo unpacking, generate checks and
match items.
+ * - RunMatch: generate symbolic shape matches
+ * - EmitOutstandingPrimExprCompute: tracks the variables to be computed and
emit shape computation
+ * - VisitExpr_(ShapeExprNode*): makes symbolic shape tuple.
+ *
+ * The checks and symbolic shape all maps to runtime builtin functions. Please
checkout
+ * runtime/relax_vm/builtin.cc for their definitions.
+ *
+ * Shape computations are lowered to host-side TIR functions that load var from
slot
+ * and store computed results into the slot. For a given slot map: {0: m, 1:
n+1, 2: n}
+ * It will create the shape_func below that loads data from H[2] (n's slot), runs the
compute,
+ * and stores the result back to H[1] (n+1's slot).
+ *
+ * \code
+ *
+ * @T.prim_func
+ * def shape_func(H: T.Buffer([3], "int64")):
+ * H[1] = H[2] + 1
+ *
+ * \endcode
+ *
+ * The current implementation will batch all shape computations at each match
point.
+ * For example, all the expressions that depend on n, m will be computed in a
single
+ * shape_func at the function boundary. If there are follow-up match_cast
points,
+ * that define new variables, then we will generate new shape
functions
+ * to compute expressions that depend on these variables.
+ */
+class VMShapeLowerMutator
+ : public ExprMutator,
+ public StructInfoFunctor<void(const StructInfo&, Expr, bool, const
String&,
+ std::vector<MatchShapeTodoItem>*)> {
+ public:
+ static IRModule Lower(IRModule mod, bool emit_err_ctx) {
+ VMShapeLowerMutator mutator(mod, emit_err_ctx);
+
+ for (auto& kv : mod->functions) {
+ if (auto* func = kv.second.as<FunctionNode>()) {
+ Function updated_func = mutator.Rewrite(kv.first,
GetRef<Function>(func));
+ mutator.builder_->UpdateFunction(kv.first, updated_func);
+ }
+ }
+ return mutator.builder_->GetContextIRModule();
+ }
+
+ private:
+ explicit VMShapeLowerMutator(IRModule mod, bool emit_err_ctx)
+ : ExprMutator(mod), emit_err_ctx_(emit_err_ctx) {}
+
+ using ExprMutator::VisitExpr_;
+
+ // Unit rewrite function per function.
+ Function Rewrite(GlobalVar gvar, Function func) {
+ // prepare mapping and heap var
+ PrimExprSlotCollector::Collect(func, &slot_vec_, &slot_map_);
+ heap_size_ = IntImm(ShapeDType(), static_cast<int64_t>(slot_vec_.size()));
+ VarBinding shape_heap_binding = this->AllocShapeHeapBinding(heap_size_);
+ shape_heap_ = shape_heap_binding->var;
+
+ // prepare slot information
+ this->PopulateSlotInfo();
+
+ Array<BindingBlock> blocks;
+
+ builder_->BeginScope(func->params);
+
+ {
+ // Check the parameter section.
+ builder_->BeginBindingBlock();
+ this->builder_->EmitNormalized(shape_heap_binding);
+ std::vector<MatchShapeTodoItem> match_todos;
+ for (size_t i = 0; i < func->params.size(); ++i) {
+ StructInfo sinfo = GetStructInfo(func->params[i]);
+ std::ostringstream err_ctx;
+ err_ctx << "ErrorContext(fn=" << gvar->name_hint << ", loc=param[" << i
+ << "], param=" << func->params[i]->name_hint() << ",
annotation=" << sinfo << ") ";
+ this->CheckMatchCast(sinfo, func->params[i], true, err_ctx.str(),
&match_todos);
+ }
+ // insert heap generation logic.
+ match_todos = this->RunMatch(match_todos, false);
+ this->EmitOutstandingPrimExprCompute();
+ this->RunMatch(match_todos, true);
+
+ BindingBlock pre_block = builder_->EndBlock();
+ blocks.push_back(pre_block);
+ }
+
+ // new body.
+ auto body_seq = Downcast<SeqExpr>(this->VisitWithNewScope(func->body,
func->params));
+ blocks.insert(blocks.end(), body_seq->blocks.begin(),
body_seq->blocks.end());
+
+ {
+ // Insert the return value check
+ builder_->BeginBindingBlock();
+ std::ostringstream err_ctx;
+ err_ctx << "ErrorContext(fn=" << gvar->name_hint
+ << ", loc=return, annotation=" << func->ret_struct_info << ") ";
+ std::vector<MatchShapeTodoItem> match_todos;
+ // NOTE: the return value's shape computation must already be defined.
+ this->CheckMatchCast(func->ret_struct_info, body_seq->body, false,
err_ctx.str(),
+ &match_todos);
+ // NOTE: the return value's shape computation must already be defined.
+ this->RunMatch(match_todos, true);
+ BindingBlock post_block = builder_->EndBlock();
+ blocks.push_back(post_block);
+ }
+
+ auto new_body = builder_->Normalize(SeqExpr(blocks, body_seq->body));
+ // create a new function
+ return Function(func->params, new_body, func->ret_struct_info,
func->attrs);
+ }
+
+ //-------------------------------------------------------
+ // PrimExpr slot handling
+ //-------------------------------------------------------
+ static DataType ShapeDType() { return DataType::Int(64); }
+
+ /*! \brief populate additional information in the slot. */
+ void PopulateSlotInfo() {
+ for (auto& kv : slot_map_) {
+ auto* slot = kv.second;
+ if (!slot->expr.as<tir::VarNode>()) {
+ Array<tir::Var> dep_vars = tir::UndefinedVars(slot->expr);
+ for (auto var : dep_vars) {
+ auto it = slot_map_.find(var);
+ ICHECK(it != slot_map_.end())
+ << "Var " << var << "is not defined in the function but is
referenced by "
+ << slot->expr;
+ auto* var_slot = it->second;
+ // populate the use slot.
+ var_slot->user_slots.push_back(slot);
+ }
+ // set outstanding defs.
+ slot->outstanding_defs += static_cast<int>(dep_vars.size());
+ }
+ }
+ }
+ //-------------------------------------------------------
+ // Helper functions
+ //-------------------------------------------------------
+ StringImm GetErrContext(String err_ctx) const {
+ return emit_err_ctx_ ? StringImm(err_ctx) : StringImm("");
+ }
+
+ VarBinding AllocShapeHeapBinding(IntImm heap_size) {
+ if (heap_size->value > 0) {
+ TensorStructInfo heap_sinfo(ShapeDType(), 1);
+ Var var("shape_heap", heap_sinfo);
+ // set up the builtin func.
+ Call call(call_builtin_with_ctx_op_,
+ {builtin_alloc_shape_heap_, Tuple({PrimValue(heap_size)})},
Attrs(), {heap_sinfo});
+ UpdateStructInfo(call, heap_sinfo);
+ return VarBinding(var, call);
+ } else {
+ Var var("shape_heap", ObjectStructInfo());
+ Call call(null_value_op_, {});
+ UpdateStructInfo(call, ObjectStructInfo());
+ return VarBinding(var, call);
+ }
+ }
+
+ //-------------------------------------------------------
+ // Expr mutation overloading.
+ //-------------------------------------------------------
+ Expr VisitExpr_(const FunctionNode* op) final {
+ LOG(FATAL) << "VMShapeLower does not work for local functions, make sure "
+ << "to run it after LambdaLift";
+ return GetRef<Expr>(op);
+ }
+
+ Expr VisitExpr_(const ShapeExprNode* op) final {
+ using runtime::relax_vm::MakeShapeCode;
+ // Constant shape can be preserved.
+ bool is_const_shape = std::all_of(op->values.begin(), op->values.end(),
[](const PrimExpr& e) {
+ return e->IsInstance<IntImmNode>();
+ });
+ if (is_const_shape) {
+ return GetRef<Expr>(op);
+ }
+
+ Array<Expr> args = {shape_heap_,
PrimValue::Int64(static_cast<int64_t>(op->values.size()))};
+ for (PrimExpr expr : op->values) {
+ if (auto* int_expr = expr.as<IntImmNode>()) {
+
args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kUseImm)));
+ args.push_back(PrimValue::Int64(int_expr->value));
+ } else {
+ auto it = slot_map_.find(expr);
+ ICHECK(it != slot_map_.end());
+ auto* slot = it->second;
+ ICHECK(slot->value_computed) << "PrimExpr " << expr << " has not been
computed";
+
args.push_back(PrimValue::Int64(static_cast<int>(MakeShapeCode::kLoadShape)));
+ args.push_back(PrimValue::Int64(slot->index));
+ }
+ }
+
+ // make_shape(heap, n, c[0], r[0], c[1], r[1] ..., c[n], r[n])
+ Call call(builtin_make_shape_, args, Attrs(),
+ {ShapeStructInfo(static_cast<int>(op->values.size()))});
+ return call;
+ }
+
+ void VisitBinding_(const MatchCastNode* binding) final {
+ Expr value = ExprMutator::VisitExpr(binding->value);
+ std::vector<MatchShapeTodoItem> match_todos;
+ std::ostringstream err_ctx;
+ err_ctx << "ErrorContext(match_cast, struct_info=" << binding->struct_info
<< ") ";
+ // always_check=false
+ this->CheckMatchCast(binding->struct_info, value, false, err_ctx.str(),
&match_todos);
+
+ match_todos = this->RunMatch(match_todos, false);
+ this->EmitOutstandingPrimExprCompute();
+ this->RunMatch(match_todos, true);
+
+ // These checks are emitted as extra, in codegen
+ // match-cast is simply ignored and treated as a normal binding.
+ builder_->EmitNormalized(GetRef<MatchCast>(binding));
+ }
+
+ // Do not override shape in struct info fields
+ // We only override the shape that are already part of the normal function
values
+ // If future passes lift those values out into the values,
+ // then codegen may not be able to handle symbolic values.
+ // Place this pass as last pass before codegen.
+ StructInfo VisitExprDepStructInfoField(const StructInfo& sinfo) final {
return sinfo; }
+
+ //-------------------------------------------------------
+ // Shape computations.
+ //-------------------------------------------------------
+ /*!
+ * \brief Execute the match todo items.
+ *
+ * This function can populate vars in the match items when seeing it for the
first time.
+ * These new vars will be added to this->ready_vars_.
+ *
+ * If an item contains PrimExpr that are yet to be computed (but may be
computable through
+ * vars defined in this round), it will be returned to the caller.
+ *
+ * The caller should call EmitOutstandingPrimExprCompute, then call RunMatch
again.
+ *
+ * \param match_todos The list of match items to be executed.
+ * \param require_value_computed Whether we require all expr to be computed.
+ * \return List of outstanding items that contains value that are yet to be
computed.
+ */
+ std::vector<MatchShapeTodoItem> RunMatch(const
std::vector<MatchShapeTodoItem>& match_todos,
+ bool require_value_computed) {
+ std::vector<MatchShapeTodoItem> outstanding_todos;
+
+ using runtime::relax_vm::MatchShapeCode;
+ for (const MatchShapeTodoItem& item : match_todos) {
+ int64_t shape_len = static_cast<int64_t>(item.pattern.size());
+ bool all_nop = true;
+ int num_outstanding_exprs = 0;
+
+ Array<Expr> args = {item.input, shape_heap_,
PrimValue::Int64(shape_len)};
+
+ for (PrimExpr expr : item.pattern) {
+ MatchShapeCode code = MatchShapeCode::kNoOp;
+ int64_t rvalue = 0;
+ if (auto* int_expr = expr.as<IntImmNode>()) {
+ code = MatchShapeCode::kAssertEqualToImm;
+ rvalue = int_expr->value;
+ } else {
+ auto it = slot_map_.find(expr);
+ ICHECK(it != slot_map_.end());
+ auto* slot = it->second;
+ if (slot->value_computed) {
+ code = MatchShapeCode::kAssertEqualToLoad;
+ rvalue = slot->index;
+ } else {
+ // the value is not yet computed
+ ICHECK(!require_value_computed) << "PrimExpr " << expr << " is not
computed";
+ if (expr.as<tir::VarNode>()) {
+ // if it is a var, we will populate it in this round.
+ // otherwise, we skip and mark it as outstanding
+ code = MatchShapeCode::kStoreToHeap;
+ rvalue = slot->index;
+ slot->value_computed = true;
+ ready_vars_.push_back(slot);
+ } else {
+ code = MatchShapeCode::kNoOp;
+ rvalue = 0;
+ ++num_outstanding_exprs;
+ }
+ }
+ }
+ all_nop = all_nop && code == MatchShapeCode::kNoOp;
+ args.push_back(PrimValue::Int64(static_cast<int>(code)));
+ args.push_back(PrimValue::Int64(rvalue));
+ }
+ if (num_outstanding_exprs != 0) {
+ outstanding_todos.push_back(item);
+ }
+ args.push_back(GetErrContext(item.err_ctx));
+ if (!all_nop) {
+ Call call(builtin_match_shape_, args, Attrs(), {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+ }
+ return std::move(outstanding_todos);
+ }
+
+ /*!
+ * \brief Compute a list of prim exprs that can now be computed
+ * for given ready vars.
+ */
+ std::vector<PrimExprSlot*> GetReadyPrimExprSlots() {
+ std::vector<PrimExprSlot*> to_compute;
+ for (PrimExprSlot* slot : ready_vars_) {
+ for (PrimExprSlot* user : slot->user_slots) {
+ ICHECK_GT(user->outstanding_defs, 0);
+ user->outstanding_defs -= 1;
+ if (user->outstanding_defs == 0) {
+ to_compute.push_back(user);
+ }
+ }
+ }
+ ready_vars_.clear();
+ return to_compute;
+ }
+
+ /*!
+ * \brief Check the dependent expressions of ready_vars_,
+ *
+ * If there are outstanding PrimExpr that can now be computed
+ * we generate a PrimFunc that compute the extra shape values
+ *
+ * We will then clear the ready_vars.
+ *
+ * \return Number of PrimExpr computed.
+ */
+ size_t EmitOutstandingPrimExprCompute() {
+ std::vector<PrimExprSlot*> to_compute = GetReadyPrimExprSlots();
+ if (to_compute.size() == 0) return 0;
+ ICHECK_GT(heap_size_->value, 0);
+ // construct a PrimFunc that compute the shape.
+ tir::Var heap("heap", DataType::Handle());
+ Array<PrimExpr> buffer_shape{heap_size_};
+ tir::Buffer buffer = tir::decl_buffer(buffer_shape, ShapeDType(), "H",
"global");
+ Map<tir::Var, tir::Buffer> buffer_map;
+ buffer_map.Set(heap, buffer);
+
+ auto var_map = [&](const tir::Var& var) -> Optional<PrimExpr> {
+ auto it = slot_map_.find(var);
+ ICHECK(it != slot_map_.end());
+ return tir::BufferLoad(buffer, {IntImm(ShapeDType(),
it->second->index)});
+ };
+
+ Array<tir::Stmt> seq;
+ for (PrimExprSlot* slot : to_compute) {
+ ICHECK(!slot->value_computed);
+ slot->value_computed = true;
+ PrimExpr value = tir::Substitute(slot->expr, var_map);
+ seq.push_back(tir::BufferStore(buffer, value, {IntImm(ShapeDType(),
slot->index)}));
+ }
+
+ tir::Stmt body = tir::SeqStmt::Flatten(seq);
+ Array<tir::Var> params{heap};
+ Type ret_type = VoidType();
+
+ // TODO(relax-team): Consider attach the target attribute to
+ // the shape_func to indicate that this is a host function
+ // This could require us to attach target to the relax function here.
+ tir::PrimFunc shape_func(params, body, ret_type, buffer_map);
+ GlobalVar shape_func_var = builder_->AddFunction(shape_func, "shape_func");
+ builder_->Emit(Call(shape_func_var, {shape_heap_}), "_");
+ return to_compute.size();
+ }
+ //-------------------------------------------------------
+ // StructInfo value match logic
+ //
+ // CheckMatchCast is the only function needed by
+ // other code sections
+ //-------------------------------------------------------
+ /*!
+ * \brief Insert runtime check of the match cast condition(value,
struct_info).
+ *
+ * \param struct_info The struct info to be matched.
+ * \param value The input value.
+ * \param always_check Whether we insert runtime check even if we can prove
+ * that value's struct info already satisfies the condition.
+ * This option is necessary for argument checking per our calling
convention.
+ *
+ * \param err_ctx Extra error context to bring more informative error
reporting.
+ * \param match_todos List of match shape todo items collected when
recursively
+ * visit the match cast.
+ */
+ void CheckMatchCast(const StructInfo& struct_info, Expr value, bool
always_check,
+ const String& err_ctx, std::vector<MatchShapeTodoItem>*
match_todos) {
+ return this->VisitStructInfo(struct_info, value, always_check, err_ctx,
match_todos);
+ }
+
+ void VisitStructInfo(const StructInfo& struct_info, Expr value, bool
always_check,
+ const String& err_ctx, std::vector<MatchShapeTodoItem>*
match_todos) final {
+ // short-cut, if the struct info already satisfies the
+ // constraint during match cast, we can skip matching
+ if (!always_check && IsBaseOf(struct_info, GetStructInfo(value))) return;
+ return StructInfoFunctor::VisitStructInfo(struct_info, value,
always_check, err_ctx,
+ match_todos);
+ }
+
+ void VisitStructInfo_(const ObjectStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ }
+
+ void VisitStructInfo_(const PrimStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ // TODO(relax-team) add PrimValue checks later.
+ LOG(FATAL) << "MatchCast of PrimValue is not yet supported";
+ }
+
+ void VisitStructInfo_(const ShapeStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ // emit runtime check of shape
+ if (always_check || !IsBaseOf(ShapeStructInfo(op->ndim),
GetStructInfo(value))) {
+ // check_shape_info(value, ndim, err_ctx)
+ Call call(builtin_check_shape_info_,
+ {value, PrimValue::Int64(op->ndim), GetErrContext(err_ctx)},
Attrs(),
+ {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+ if (op->values.defined()) {
+ MatchShapeTodoItem item;
+ item.input = value;
+ item.pattern = op->values.value();
+ item.err_ctx = err_ctx;
+ match_todos->push_back(item);
+ }
+ }
+
+ void VisitStructInfo_(const TensorStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ // emit runtime check of shape
+ if (always_check || !IsBaseOf(TensorStructInfo(op->dtype, op->ndim),
GetStructInfo(value))) {
+ // check_tensor_info(value, ndim, dtype, err_ctx)
+ Call call(builtin_check_tensor_info_,
+ {value, PrimValue::Int64(op->ndim), DataTypeImm(op->dtype),
GetErrContext(err_ctx)},
+ Attrs(), {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+
+ if (auto* shape_expr = op->shape.as<ShapeExprNode>()) {
+ MatchShapeTodoItem item;
+ item.input = value;
+ item.pattern = shape_expr->values;
+ item.err_ctx = err_ctx;
+ match_todos->push_back(item);
+ } else if (op->shape.as<VarNode>()) {
+ // NOTE: This part of the logic is left empty for future support as it
is less common.
+ // Future implementors: we can emit a binding here and assert here.
+ LOG(FATAL) << "Cannot handle Tensor shape pattern where a var appears
multiple times";
+ } else {
+ ICHECK(!op->shape.defined()) << "Can only handle tensor shape pattern
var";
+ }
+ }
+
+ // Internal helper function to make tuple get item.
+ // This function will try to simplify constant tuples
+ // the return value **always** have struct info.
+ Expr MakeTupleGetItem(Expr value, int64_t index) {
+ if (auto* tuple_expr = value.as<TupleNode>()) {
+ return tuple_expr->fields[index];
+ } else if (auto* tuple_sinfo =
GetStructInfoAs<TupleStructInfoNode>(value)) {
+ // value is tuple type, it is OK to run tuple get item.
+ auto ret = TupleGetItem(value, index);
+ UpdateStructInfo(ret, tuple_sinfo->fields[index]);
+ return ret;
+ } else {
+ // call runtime tuple get item, and return a object.
+ Call call(builtin_tuple_getitem_, {value, PrimValue::Int64(index)},
Attrs(), {object_sinfo_});
+ UpdateStructInfo(call, ObjectStructInfo());
+ return call;
+ }
+ }
+
+ void VisitStructInfo_(const TupleStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ auto* value_tinfo = GetStructInfoAs<TupleStructInfoNode>(value);
+ if (value_tinfo) {
+ CHECK_EQ(value_tinfo->fields.size(), op->fields.size())
+ << "TypeError: " << err_ctx << " during match-cast we find tuple
size mismatch";
+ }
+ if (always_check || !value_tinfo) {
+ // check_tuple_info(value, tuple_size)
+ Call call(builtin_check_tuple_info_,
+ {value,
PrimValue::Int64(static_cast<int64_t>(op->fields.size())),
+ GetErrContext(err_ctx)},
+ Attrs(), {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+ // recursively visit each sub-field and run matching
+ for (size_t i = 0; i < op->fields.size(); ++i) {
+ this->VisitStructInfo(op->fields[i], MakeTupleGetItem(value, i),
always_check, err_ctx,
+ match_todos);
+ }
+ }
+
+ void VisitStructInfo_(const FuncStructInfoNode* op, Expr value, bool
always_check,
+ const String& err_ctx,
std::vector<MatchShapeTodoItem>* match_todos) final {
+ // we only check function is callable.
+ if (!always_check && MatchStructInfo<FuncStructInfo>(value)) return;
+ // check_func_info(value, err_ctx)
+ Call call(builtin_check_func_info_, {value, GetErrContext(err_ctx)},
Attrs(), {void_sinfo_});
+ builder_->Emit(call, "_");
+ }
+
+ //-------------------------------------------------------
+ // Private member fields.
+ //-------------------------------------------------------
+ /*! \brief whether to emit error context, can be turned off for testing
purposes. */
+ bool emit_err_ctx_{true};
+ /*! \brief heap ptr to store the PrimExpr slots. */
+ Var shape_heap_;
+ /*! \brief heap size. */
+ IntImm heap_size_;
+ /*! \brief index => slot. */
+ std::vector<std::unique_ptr<PrimExprSlot>> slot_vec_;
+ /*! \brief Expr => slot. */
+ PrimExprSlotMap slot_map_;
+ /*!
+ * \brief List of vars that are being defined but
+ * have not gone through the outstanding shape compute check.
+ */
+ std::vector<PrimExprSlot*> ready_vars_;
+ // call builtin op
+ const Op& call_builtin_with_ctx_op_ = Op::Get("relax.call_builtin_with_ctx");
+ const Op& null_value_op_ = Op::Get("relax.null_value");
+ // common struct info
+ const StructInfo object_sinfo_ = ObjectStructInfo();
+ const StructInfo void_sinfo_ = TupleStructInfo(Array<StructInfo>({}));
+ // check function
+ const ExternFunc builtin_alloc_shape_heap_{"vm.builtin.alloc_shape_heap"};
+ const ExternFunc builtin_match_shape_{"vm.builtin.match_shape"};
+ const ExternFunc builtin_make_shape_{"vm.builtin.make_shape"};
+ const ExternFunc builtin_check_shape_info_{"vm.builtin.check_shape_info"};
+ const ExternFunc builtin_check_tensor_info_{"vm.builtin.check_tensor_info"};
+ const ExternFunc builtin_check_tuple_info_{"vm.builtin.check_tuple_info"};
+ const ExternFunc builtin_check_func_info_{"vm.builtin.check_func_info"};
+ const ExternFunc builtin_tuple_getitem_{"vm.builtin.tuple_getitem"};
+};
+
+namespace transform {
+
+Pass VMShapeLower(bool emit_err_ctx) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule mod, PassContext pc) { return
VMShapeLowerMutator::Lower(mod, emit_err_ctx); };
+ return CreateModulePass(pass_func, 0, "VMShapeLower", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.VMShapeLower").set_body_typed([](bool
emit_err_ctx) {
+ return VMShapeLower(emit_err_ctx);
+});
+
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/ir/transform.cc b/src/relax/ir/transform.cc
new file mode 100644
index 0000000000..1b077d8b88
--- /dev/null
+++ b/src/relax/ir/transform.cc
@@ -0,0 +1,413 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file relax/ir/transform.cc
+ * \brief Relax specific transformation passes.
+ */
+#include <dmlc/thread_local.h>
+#include <tvm/node/repr_printer.h>
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/struct_info_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relay/function.h>
+#include <tvm/runtime/registry.h>
+namespace tvm {
+namespace relax {
+namespace transform {
+
+TVM_REGISTER_PASS_CONFIG_OPTION("relax.fallback_device_type", IntImm);
+
+// TODO(@yuchen): will need to dedup with FunctionPass in Relay when we
upstream
+class FunctionPass;
+
+/*!
+ * \brief Function-level passes are used to implement various global
+ * optimizations for a given Relax IRModule. It fetches one function at a time
+ * from the function list in the IRModule for optimization.
+ *
+ * Note that the scope of passes at this level is a Relax function. Therefore,
+ * we cannot add or delete a function through these passes as they are not
aware
+ * of the global information.
+ */
+class FunctionPassNode : public tvm::transform::PassNode {
+ public:
+ /* \brief The pass meta data.*/
+ PassInfo pass_info;
+
+ /*! \brief The packed pass function sketches the real optimization. For
+ * instance, we can implement a pass that works on a Relax function as a
+ * `pass_func` and let it run on a given IRModule. The same `pass_func` will
+ * then be applied on each function in the IRModule.
+ */
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func;
+
+ FunctionPassNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
+
+ /*!
+ * \brief Run a function pass on given pass context.
+ *
+ * \param mod The IRModule that an optimization pass is applied on.
+ * \param pass_ctx The context that an optimization pass executes on.
+ *
+ * \return Return the updated IRModule.
+ */
+ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
+
+ /*!
+ * \brief Get the pass information/meta data.
+ */
+ PassInfo Info() const override { return pass_info; }
+
+ static constexpr const char* _type_key = "relax.FunctionPass";
+ TVM_DECLARE_FINAL_OBJECT_INFO(FunctionPassNode, PassNode);
+
+ private:
+ /*
+ * \brief Check if a function should be skipped for optimization.
+ *
+ * \param func The target function to be checked.
+ *
+ * \return Return true if the function will be skipped, otherwise false.
+ */
+ bool SkipFunction(const Function& func) const;
+};
+
+class FunctionPass : public Pass {
+ public:
+ /*!
+ * \brief The constructor
+ * \param pass_func The packed function which implements a pass.
+ * \param pass_info The pass info.
+ */
+ TVM_DLL FunctionPass(
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func,
+ PassInfo pass_info);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(FunctionPass, Pass, FunctionPassNode);
+};
+
+FunctionPass::FunctionPass(
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func,
+ PassInfo pass_info) {
+ auto n = make_object<FunctionPassNode>();
+ n->pass_func = std::move(pass_func);
+ n->pass_info = std::move(pass_info);
+ data_ = std::move(n);
+}
+
+// Perform IRModule -> IRModule optimizations at the Function level.
+IRModule FunctionPassNode::operator()(IRModule mod, const PassContext&
pass_ctx) const {
+ DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+ if (pass_ctx->diag_ctx) {
+ DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+ pass_ctx->diag_ctx = previous;
+ previous = tmp;
+ } else {
+ pass_ctx->diag_ctx = previous;
+ }
+
+ ICHECK(pass_ctx->diag_ctx)
+ << "The diagnostic context was set at the top of this block, this is a
bug.";
+
+ const PassInfo& pass_info = Info();
+
+ ICHECK(mod.defined());
+
+ VLOG_CONTEXT << pass_info->name;
+ VLOG(0) << "Executing function pass with opt level: " <<
pass_info->opt_level;
+ VLOG(1) << "Input module:" << std::endl << mod;
+
+ IRModule updated_mod = mod->ShallowCopy();
+
+ std::vector<std::pair<GlobalVar, Function> > updates;
+ for (const auto& it : updated_mod->functions) {
+ // only picks up relax::Function
+ if (auto* n = it.second.as<FunctionNode>()) {
+ Function func = GetRef<Function>(n);
+ auto updated_func = SkipFunction(func) ? func : pass_func(func,
updated_mod, pass_ctx);
+ updates.push_back({it.first, updated_func});
+ }
+ }
+
+ for (const auto& pair : updates) {
+ updated_mod->Add(pair.first, pair.second, true);
+ }
+
+ ICHECK(pass_ctx->diag_ctx)
+ << "The diagnostic context was set at the top of this block, this is a
bug.";
+
+ pass_ctx->diag_ctx.value().Render();
+ pass_ctx->diag_ctx = previous;
+
+ VLOG(1) << "Output module:" << std::endl << updated_mod;
+
+ return updated_mod;
+}
+
+bool FunctionPassNode::SkipFunction(const Function& func) const {
+ // TODO(@yuchen): will need to revisit in the future
+ return (func->GetAttr<String>(relay::attr::kCompiler).defined()) ||
+ func->GetAttr<Integer>(relay::attr::kSkipOptimization, 0) != 0;
+}
+
+Pass CreateFunctionPass(
+ const runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>&
pass_func,
+ int opt_level, String name, tvm::Array<String> required) {
+ PassInfo pass_info = PassInfo(opt_level, name, required);
+ return FunctionPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(FunctionPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeFunctionPass")
+ .set_body_typed(
+ [](runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func,
+ PassInfo pass_info) { return FunctionPass(pass_func, pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<FunctionPassNode>([](const ObjectRef& ref, ReprPrinter* p) {
+ auto* node = static_cast<const FunctionPassNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "Run Function pass: " << info->name << " at the
optimization level "
+ << info->opt_level;
+ });
+
+class DataflowBlockPass;
+
+/*!
+ * \brief DataflowBlock-level passes are used to implement various dataflow
block
+ * optimizations for a given Relax IRModule. It fetches one dataflow block at
a time
+ * from the functions in an IRModule, and yields a rewritten DataflowBlock.
+ *
+ * Note that the scope of passes at this level is a Relax DataflowBlock.
Therefore,
+ * we cannot modify the global scope Vars and symbolic shape Vars defined
inside the dataflow block.
+ */
+class DataflowBlockPassNode : public tvm::transform::PassNode {
+ public:
+ /* \brief The pass meta data.*/
+ PassInfo pass_info;
+
+ /*! \brief The packed pass function sketches the real optimization. For
+ * instance, we can implement a pass that works on a Relax DataflowBlock as a
+ * `pass_func` and let it run on a given IRModule. The same `pass_func` will
+ * then be applied on each DataflowBlock in the IRModule.
+ */
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func;
+
+ DataflowBlockPassNode() = default;
+
+ void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("pass_info", &pass_info); }
+
+ IRModule operator()(IRModule mod, const PassContext& pass_ctx) const final;
+
+ PassInfo Info() const override { return pass_info; }
+
+ static constexpr const char* _type_key = "relax.DataflowBlockPass";
+ TVM_DECLARE_FINAL_OBJECT_INFO(DataflowBlockPassNode, PassNode);
+};
+
+/*! \brief Helper to apply the passed function to dataflow blocks.*/
+class DataflowBlockMutator : public ExprMutator {
+ public:
+ DataflowBlockMutator(
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func,
+ IRModule mod, PassContext pass_ctx)
+ : pass_func_(pass_func), mod_(mod), pass_ctx_(pass_ctx) {}
+
+ /*!
+ * \brief Rewrite the DataflowBlockNode with pass_func_
+ *
+ * This function will check that there are no rewrites of the global scope
Vars
+ * and symbolic shape Vars defined inside the dataflow block.
+ */
+ BindingBlock VisitBindingBlock_(const DataflowBlockNode* n) final {
+ // collect Global Scope Vars and Symbolic Vars inside the DataflowBlock
+ Map<String, Var> global_scope_vars;
+ Map<String, tir::Var> symbolic_vars;
+ for (const Binding& binding : n->bindings) {
+ Var var = binding->var;
+ if (const auto* match_cast = binding.as<MatchCastNode>()) {
+ auto collected_vars =
SymbolicVarCollector::Collect(match_cast->struct_info);
+ for (const tir::VarNode* var : collected_vars) {
+ symbolic_vars.Set(var->name_hint, GetRef<tir::Var>(var));
+ }
+ }
+ if (!var.as<DataflowVarNode>()) {
+ global_scope_vars.Set(var->name_hint(), var);
+ }
+ }
+
+ // apply pass_func_ to the DataflowBlock
+ DataflowBlock block = GetRef<DataflowBlock>(n);
+ DataflowBlock updated_block = pass_func_(block, mod_, pass_ctx_);
+
+ // raise error if there are updates of recorded Global Scope Vars and
Symbolic Vars
+ for (const Binding& binding : updated_block->bindings) {
+ Var var = binding->var;
+ if (const auto* match_cast = binding.as<MatchCastNode>()) {
+ auto collected_vars =
SymbolicVarCollector::Collect(match_cast->struct_info);
+ for (const tir::VarNode* var : collected_vars) {
+ if (symbolic_vars.count(var->name_hint) > 0) {
+ tir::Var old_var = symbolic_vars[var->name_hint];
+ ICHECK(var == old_var.get())
+ << "Error: DataflowBlock Pass should not rewrite any Symbolic
Var.";
+ symbolic_vars.erase(var->name_hint);
+ }
+ }
+ }
+ if (!var.as<DataflowVarNode>() &&
global_scope_vars.count(var->name_hint()) > 0) {
+ ICHECK(var.same_as(global_scope_vars[var->name_hint()]))
+ << "Error: DataflowBlock Pass should not rewrite any GlobalScope
Var.";
+ global_scope_vars.erase(var->name_hint());
+ }
+ }
+ ICHECK(global_scope_vars.empty() && symbolic_vars.empty())
+ << "Error: DataflowBlock Pass should not delete any
GlobalScope/Symbolic Var.";
+
+ return std::move(updated_block);
+ }
+
+ private:
+ class SymbolicVarCollector : public StructInfoVisitor {
+ public:
+ static std::unordered_set<const tir::VarNode*> Collect(const StructInfo&
info) {
+ SymbolicVarCollector collector;
+ collector.VisitStructInfo(info);
+ return std::move(collector.symbolic_vars_);
+ }
+
+ private:
+ void VisitStructInfoExprField(const PrimExpr& expr) final {
+ if (const tir::VarNode* sym_var = expr.as<tir::VarNode>()) {
+ symbolic_vars_.insert(sym_var);
+ }
+ }
+
+ private:
+ std::unordered_set<const tir::VarNode*> symbolic_vars_;
+ };
+
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func_;
+ IRModule mod_;
+ PassContext pass_ctx_;
+};
+
+class DataflowBlockPass : public Pass {
+ public:
+ /*!
+ * \brief The constructor
+ * \param pass_func The packed function which implements a pass.
+ * \param pass_info The pass info.
+ */
+ TVM_DLL DataflowBlockPass(
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func,
+ PassInfo pass_info);
+
+ TVM_DEFINE_OBJECT_REF_METHODS(DataflowBlockPass, Pass,
DataflowBlockPassNode);
+};
+
+DataflowBlockPass::DataflowBlockPass(
+ runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func,
+ PassInfo pass_info) {
+ auto n = make_object<DataflowBlockPassNode>();
+ n->pass_func = std::move(pass_func);
+ n->pass_info = std::move(pass_info);
+ data_ = std::move(n);
+}
+
+// Perform IRModule -> IRModule transformations at the DataflowBlock level.
+IRModule DataflowBlockPassNode::operator()(IRModule mod, const PassContext&
pass_ctx) const {
+ DiagnosticContext previous = DiagnosticContext::Default(mod);
+
+ if (pass_ctx->diag_ctx) {
+ DiagnosticContext tmp = pass_ctx->diag_ctx.value();
+ pass_ctx->diag_ctx = previous;
+ previous = tmp;
+ } else {
+ pass_ctx->diag_ctx = previous;
+ }
+
+ ICHECK(pass_ctx->diag_ctx)
+ << "The diagnostic context was set at the top of this block, this is a
bug.";
+
+ const PassInfo& pass_info = Info();
+
+ ICHECK(mod.defined());
+
+ VLOG_CONTEXT << pass_info->name;
+ VLOG(0) << "Executing DataflowBlock pass with opt level: " <<
pass_info->opt_level;
+ VLOG(1) << "Input module:" << std::endl << mod;
+
+ IRModule updated_mod = mod->ShallowCopy();
+
+ DataflowBlockMutator dataflow_block_mutator(pass_func, updated_mod,
pass_ctx);
+ std::vector<std::pair<GlobalVar, Function> > updates;
+ for (const auto& it : updated_mod->functions) {
+ // only picks up relax::Function
+ if (auto* n = it.second.as<FunctionNode>()) {
+ Function func = GetRef<Function>(n);
+ Function updated_func =
Downcast<Function>(dataflow_block_mutator.VisitExpr(func));
+ updates.push_back({it.first, updated_func});
+ }
+ }
+
+ for (const auto& pair : updates) {
+ updated_mod->Add(pair.first, pair.second, true);
+ }
+
+ ICHECK(pass_ctx->diag_ctx)
+ << "The diagnostic context was set at the top of this block, this is a
bug.";
+
+ pass_ctx->diag_ctx.value().Render();
+ pass_ctx->diag_ctx = previous;
+
+ VLOG(1) << "Output module:" << std::endl << updated_mod;
+
+ return updated_mod;
+}
+
+Pass CreateDataflowBlockPass(
+ const runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)>& pass_func,
+ int opt_level, String name, tvm::Array<String> required) {
+ PassInfo pass_info = PassInfo(opt_level, name, required);
+ return DataflowBlockPass(pass_func, pass_info);
+}
+
+TVM_REGISTER_NODE_TYPE(DataflowBlockPassNode);
+
+TVM_REGISTER_GLOBAL("relax.transform.MakeDataflowBlockPass")
+ .set_body_typed(
+ [](runtime::TypedPackedFunc<DataflowBlock(DataflowBlock, IRModule,
PassContext)> pass_func,
+ PassInfo pass_info) { return DataflowBlockPass(pass_func,
pass_info); });
+
+TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
+ .set_dispatch<DataflowBlockPassNode>([](const ObjectRef& ref, ReprPrinter*
p) {
+ auto* node = static_cast<const DataflowBlockPassNode*>(ref.get());
+ const PassInfo info = node->Info();
+ p->stream << "Run DataflowBlock pass: " << info->name << " at the
optimization level "
+ << info->opt_level;
+ });
+} // namespace transform
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_backend_transform_shape_lower.py
b/tests/python/relax/test_backend_transform_shape_lower.py
new file mode 100644
index 0000000000..bf1bc61a6e
--- /dev/null
+++ b/tests/python/relax/test_backend_transform_shape_lower.py
@@ -0,0 +1,429 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.ir import assert_structural_equal
+from tvm.relax.testing.runtime_builtin import MakeShapeCode, MatchShapeCode
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
+def test_const_shape_arg():
+    """VMShapeLower on constant/unknown shape arguments.
+
+    Expects a check_shape_info call per shape argument, an immediate-value
+    match_shape check for the constant shape, and the unrelated PrimFunc
+    (extra_func) preserved untouched by the pass.
+    """
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer[T.int64(4), "int64"]):
+            """Extra function, checks if the pass preserves it."""
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Shape([1, 2]), y: R.Shape):
+            # No symbolic vars anywhere, so no real heap is allocated.
+            shape_heap = R.null_value()
+            # ndim checks: x must have 2 dims, y (unknown shape) uses -1.
+            _ = R.call_packed("vm.builtin.check_shape_info", x, 2, "",
sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, -1, "",
sinfo_args=[R.Tuple()])
+            # Each (code, operand) pair asserts one dimension equals an
+            # immediate: dim0 == 1, dim1 == 2.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+        @T.prim_func
+        def extra_func(H: T.Buffer[T.int64(4), "int64"]):
+            H[T.int64(1)] = H[T.int64(0)] + T.int64(1)
+
+    before = Before
+    expected = Expected
+    # emit_err_ctx=False keeps the error-context strings empty ("") so the
+    # structural comparison below stays stable.
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_static_fn_check():
+    """Check static shape and function."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            return y
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(f: R.Callable([R.Object], R.Object), y: R.Shape([1, 2])):
+            # No symbolic vars, so the shape heap is a null placeholder.
+            shape_heap = R.null_value()
+            # Callable argument gets a function-type check; shape argument
+            # gets an ndim check (2 dims).
+            _ = R.call_packed("vm.builtin.check_func_info", f, "",
sinfo_args=[R.Tuple()])
+            _ = R.call_packed("vm.builtin.check_shape_info", y, 2, "",
sinfo_args=[R.Tuple()])
+            # Assert y's dims equal the immediates [1, 2].
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_IMM,
+                1,
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return y
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_simple_symbolic_shape():
+    """VMShapeLower with symbolic dims: a real shape heap is allocated and
+    symbolic dims (n, m) are stored to their assigned heap slots, while the
+    static dim (2) is asserted against an immediate."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            return x
+
+    # Expected heap slot assignment for the two symbolic vars.
+    sindex = {
+        "n": 0,
+        "m": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(x: R.Tensor(["n", 2, "m"], "float32")):
+            # Heap sized by the number of symbolic slots (2).
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            # Static check: x is a 3-dim float32 tensor.
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 3, R.dtype("float32"), "",
sinfo_args=[R.Tuple()]
+            )
+            # dim0 -> heap[n], dim1 == 2 (immediate), dim2 -> heap[m].
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_IMM,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_symbolic_compute():
+    """VMShapeLower with derived symbolic expressions (k + 1): the pass is
+    expected to synthesize a TIR shape_func that computes k+1 on the heap,
+    and to build the returned shape value via vm.builtin.make_shape."""
+    MS = MatchShapeCode
+    MK = MakeShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            return (k + 1, m, 2)
+
+    # slot assignment:
+    # 0: n, 1: m, 2:k, 3: k+1
+    sindex = {"n": 0, "m": 1, "k": 2, "k+1": 3}
+
+    @tvm.script.ir_module
+    class Expected:
+        @T.prim_func
+        def shape_func(H: T.Buffer[T.int64(4), "int64"]):
+            # generated compute function: H[k+1 slot] = H[k slot] + 1
+            H[T.int64(sindex["k+1"])] = H[T.int64(sindex["k"])] + T.int64(1)
+
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Tensor(ndim=3, dtype=None)
+        ) -> R.Shape(ndim=3):
+            n = T.Var("n", "int64")
+            k = T.Var("k", "int64")
+            # 4 slots: n, m, k and the derived value k+1.
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(4)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "",
sinfo_args=[R.Tuple()]
+            )
+            # Empty dtype string: y's dtype is unconstrained (dtype=None).
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", y, 3, R.dtype(""), "",
sinfo_args=[R.Tuple()]
+            )
+            # x: store n and m to their heap slots.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.NO_OP,
+                0,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.NO_OP,
+                0,
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # Run the generated TIR function to fill the k+1 slot.
+            _ = shape_func(shape_heap)
+            # extra assertion on y's shape after shape computation
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y,
+                shape_heap,
+                3,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["k+1"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            z = R.match_cast(y, R.Tensor([k, m, k + 1], dtype=None))
+            # construct shape value for return
+            s = R.call_packed(
+                "vm.builtin.make_shape",
+                shape_heap,
+                3,
+                MK.LOAD_SHAPE,
+                sindex["k+1"],
+                MK.LOAD_SHAPE,
+                sindex["m"],
+                MK.USE_IMM,
+                2,
+                sinfo_args=[R.Shape(ndim=3)],
+            )
+            return s
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_tuple_handling():
+    """VMShapeLower on nested tuple arguments: the tuple is recursively
+    unpacked for per-element static checks, then per-tensor match_shape
+    checks share the same heap slots (so 'n' appearing twice is asserted
+    equal across the two tensors)."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape,
R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            return x
+
+    # slot assignment:
+    sindex = {"n": 0, "m": 1, "k": 2}
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tuple(
+                R.Tensor(["n", "m"], "float32"), R.Tuple(R.Shape,
R.Tensor(["n", "k"], "int32"))
+            )
+        ):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(3)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            # recursively unpack tuple for static info check
+            _ = R.call_packed("vm.builtin.check_tuple_info", x, 2, "",
sinfo_args=[R.Tuple()])
+            t0 = x[0]
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                t0,
+                2,
+                R.dtype("float32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            t1 = x[1]
+            _ = R.call_packed("vm.builtin.check_tuple_info", t1, 2, "",
sinfo_args=[R.Tuple()])
+            t1x0 = t1[0]
+            # Inner shape element: unknown ndim, hence -1.
+            _ = R.call_packed("vm.builtin.check_shape_info", t1x0, -1, "",
sinfo_args=[R.Tuple()])
+            t1x1 = t1[1]
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                t1x1,
+                2,
+                R.dtype("int32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # match shape checks.
+            # First occurrence of n/m: store to heap.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                t0,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # Second tensor: n already stored, so assert-equal; k is new.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                t1x1,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["k"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            return x
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+def test_return_match_check():
+    """Test when return body is not same as ret_struct_info, runtime match
check needed."""
+    MS = MatchShapeCode
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Object
+        ) -> R.Tuple(R.Tensor(["n", "m"], "float32")):
+            return y
+
+    # slot assignment:
+    sindex = {
+        "n": 0,
+        "m": 1,
+    }
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main(
+            x: R.Tensor(["n", "m"], "float32"), y: R.Object
+        ) -> R.Tuple(R.Tensor(["n", "m"], "float32")):
+            shape_heap = R.call_builtin_with_ctx(
+                "vm.builtin.alloc_shape_heap",
+                [R.prim_value(2)],
+                sinfo_args=[R.Tensor(ndim=1, dtype="int64")],
+            )
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info", x, 2, R.dtype("float32"), "",
sinfo_args=[R.Tuple()]
+            )
+            # Populate heap slots for n and m from x's runtime shape.
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                x,
+                shape_heap,
+                2,
+                MS.STORE_TO_HEAP,
+                sindex["n"],
+                MS.STORE_TO_HEAP,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # Return-value checks: y is R.Object but the declared return type
+            # is a 1-tuple of tensors, so runtime checks are inserted.
+            _ = R.call_packed("vm.builtin.check_tuple_info", y, 1, "",
sinfo_args=[R.Tuple()])
+            # emit runtime function call since y does not have the right type.
+            y1 = R.call_packed("vm.builtin.tuple_getitem", y, 0,
sinfo_args=[R.Object])
+            # run check
+            _ = R.call_packed(
+                "vm.builtin.check_tensor_info",
+                y1,
+                2,
+                R.dtype("float32"),
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+            # shape check
+            _ = R.call_packed(
+                "vm.builtin.match_shape",
+                y1,
+                shape_heap,
+                2,
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["n"],
+                MS.ASSERT_EQUAL_TO_LOAD,
+                sindex["m"],
+                "",
+                sinfo_args=[R.Tuple()],
+            )
+
+            return y
+
+    before = Before
+    expected = Expected
+    after = relax.transform.VMShapeLower(emit_err_ctx=False)(before)
+    assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+    # Run all tests in this file through TVM's testing entry point.
+    tvm.testing.main()