This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ba475010db [Unity][Pass] BindParams pass, FoldConstant pass (#14016)
ba475010db is described below
commit ba475010db82cfe21f13324cc9e8e4ae52cdcc22
Author: Sunghyun Park <[email protected]>
AuthorDate: Thu Feb 16 19:35:03 2023 -0800
[Unity][Pass] BindParams pass, FoldConstant pass (#14016)
This PR introduces FoldConstant/BindParam passes.
---
include/tvm/ir/function.h | 133 ++++++----
include/tvm/relax/transform.h | 15 ++
python/tvm/relax/transform/transform.py | 62 ++++-
src/relax/transform/bind_params.cc | 113 +++++++++
src/relax/transform/fold_constant.cc | 230 +++++++++++++++++
tests/python/relax/test_transform_bind_params.py | 75 ++++++
tests/python/relax/test_transform_fold_constant.py | 280 +++++++++++++++++++++
7 files changed, 861 insertions(+), 47 deletions(-)
diff --git a/include/tvm/ir/function.h b/include/tvm/ir/function.h
index 1493544e73..381ea6b8d6 100644
--- a/include/tvm/ir/function.h
+++ b/include/tvm/ir/function.h
@@ -65,6 +65,68 @@ enum class CallingConv : int {
kDeviceKernelLaunch = 2,
};
+/*!
+ * \brief Supported linkage types.
+ */
+enum class LinkageType : int {
+ /*!
+ * \brief Internal linkage.
+ */
+ kInternal = 0,
+ /*!
+ * \brief External linkage.
+ - Function with external linkage should have a global symbol attached to it.
+ */
+ kExternal = 1
+};
+
+/*!
+ * \brief Generic attribute names that can be attached to any function.
+ *
+ * \sa tvm::tir::attr, tvm::relay::attr
+ */
+namespace attr {
+/*!
+ * \brief Indicates the special calling convention.
+ *
+ * Type: Integer
+ *
+ * \sa tvm::CallingConv
+ */
+constexpr const char* kCallingConv = "calling_conv";
+
+/*!
+ * \brief Compilation target of the function.
+ *
+ * Type: Target
+ *
+ * \sa tvm::Target
+ */
+constexpr const char* kTarget = "target";
+
+/*!
+ * \brief Global linker symbol of the function in generated code.
+ *
+ * This option forces the code generator to name the
+ * function with the given name.
+ *
+ * For example, we could set a global_symbol of a function
+ * early to make sure that we can always refer to it by
+ * the symbol name in the generated DLL.
+ *
+ * We should not set the attribute for local functions,
+ * so that the compiler can freely rename them.
+ *
+ * A unique global symbol will be automatically assigned
+ * to each function in the module before the target code
+ * generation phase.
+ *
+ * Type: String
+ */
+constexpr const char* kGlobalSymbol = "global_symbol";
+
+} // namespace attr
+
/*!
* \brief Base node of all functions.
*
@@ -130,6 +192,31 @@ class BaseFuncNode : public RelayExprNode {
* \endcode
*/
bool HasNonzeroAttr(const std::string& attr_key) const { return
attrs.HasNonzeroAttr(attr_key); }
+ /*!
+ * \brief Get the type of the linkage.
+ *
+ * Currently, we only consider external/internal linkage.
+ * This can be extended in the future when necessary.
+ *
+ * \return Linkage type.
+ *
+ * \code
+ *
+ * void Example(const BaseFunc& f) {
+ * if (f->GetLinkageType() == tvm::LinkageType::kExternal) {
+ * // Do not remove a function with external linkage
+ * }
+ * }
+ *
+ * \endcode
+ */
+
+ LinkageType GetLinkageType() const {
+ if (GetAttr<String>(attr::kGlobalSymbol))
+ return LinkageType::kExternal;
+ else
+ return LinkageType::kInternal;
+ }
static constexpr const char* _type_key = "BaseFunc";
static constexpr const uint32_t _type_child_slots = 2;
@@ -145,51 +232,5 @@ class BaseFunc : public RelayExpr {
TVM_DEFINE_OBJECT_REF_METHODS(BaseFunc, RelayExpr, BaseFuncNode);
};
-/*!
- * \brief Generic attribute names that can be attached to any function.
- *
- * \sa tvm::tir::attr, tvm::relay::attr
- */
-namespace attr {
-/*!
- * \brief Indicates the special calling convention.
- *
- * Type: Integer
- *
- * \sa tvm::CallingConv
- */
-constexpr const char* kCallingConv = "calling_conv";
-
-/*!
- * \brief Compilation target of the function.
- *
- * Type: Target
- *
- * \sa tvm::Target
- */
-constexpr const char* kTarget = "target";
-
-/*!
- * \brief Global linker symbol of the function in generated code.
- *
- * This option forces the code generator to name the
- * function with the given.
- *
- * For example, we could set a global_symbol of a function
- * early to make sure that we can always refer to it by
- * the symbol name in the generated DLL.
- *
- * We should not set the attribute for local functions,
- * so that the compiler can freely rename them.
- *
- * A unique global symbol will be automatically assigned
- * to each function in the module before the target code
- * generation phase.
- *
- * Type: String
- */
-constexpr const char* kGlobalSymbol = "global_symbol";
-
-} // namespace attr
} // namespace tvm
#endif // TVM_IR_FUNCTION_H_
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index ff98b16d25..dab062588a 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -99,7 +99,22 @@ TVM_DLL Pass RewriteDataflowReshape();
* \return The Pass.
*/
TVM_DLL Pass AttachGlobalSymbol();
+/*!
+ * \brief Bind params of function of the module to constant tensors.
+ *
+ * \param func_name The name of the function to bind parameters.
+ * \param params The parameters to bind.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray>
params);
+/*!
+ * \brief Fold constant expressions.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass FoldConstant();
} // namespace transform
} // namespace relax
} // namespace tvm
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 1a525431dd..745a26a4da 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -19,7 +19,8 @@
import functools
import inspect
import types
-from typing import Callable, Union
+from typing import Callable, Dict, Union, Optional, List
+import numpy as np # type: ignore
import tvm.ir
from . import _ffi_api
@@ -115,6 +116,65 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
return _ffi_api.AttachGlobalSymbol() # type: ignore
+def BindParams(
+ func_name: str,
+ params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
+) -> tvm.ir.transform.Pass:
+ """Bind params of function of the module to constant tensors.
+
+ Parameters
+ ----------
+
+ func_name: str
+ The function name to be bound
+
+ params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
+ The map from param name to constant tensors.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ tvm_params = {}
+ for k, v in params.items():
+ if isinstance(v, np.ndarray):
+ v = tvm.nd.array(v)
+ assert isinstance(
+ v, tvm.runtime.NDArray
+ ), f"param values are expected to be TVM.NDArray or numpy.ndarray, but
got {type(v)}"
+ tvm_params[k] = v
+
+ return _ffi_api.BindParams(func_name, tvm_params) # type: ignore
+
+
+def RemoveUnusedFunctions(entry_functions: Optional[List[str]] = None) ->
tvm.ir.transform.Pass:
+ """Remove unused relax/prim functions without external linkage in a
IRModule.
+
+ Parameters
+ ----------
+ entry_functions: Optional[List[str]]
+ The set of entry functions to start from.
+
+ Returns
+ -------
+ ret : tvm.transform.Pass
+ The registered pass to remove unused functions.
+ """
+ if entry_functions is None:
+ entry_functions = ["main"]
+ return _ffi_api.RemoveUnusedFunctions(entry_functions) # type: ignore
+
+
+def FoldConstant() -> tvm.ir.transform.Pass:
+ """Fold constant expressions.
+
+ Returns
+ -------
+ ret: tvm.ir.transform.Pass
+ """
+ return _ffi_api.FoldConstant() # type: ignore
+
+
def AnnotateTIROpPattern() -> tvm.ir.transform.Pass:
"""Annotate Op Pattern Kind for TIR functions
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
new file mode 100644
index 0000000000..1de8d94461
--- /dev/null
+++ b/src/relax/transform/bind_params.cc
@@ -0,0 +1,113 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/function.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/op.h>
+
+#include <utility>
+
+namespace tvm {
+namespace relax {
+
+/*!
+ * \brief Bind params to function by using name
+ * \param func Relax function
+ * \param params params dict
+ * \return Function
+ */
+inline Function BindParamsByName(Function func, const Map<String,
runtime::NDArray>& params) {
+ std::unordered_map<std::string, Var> name_dict;
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
+ for (auto arg : func->params) {
+ const auto& name = arg->name_hint();
+ if (name_dict.count(name)) {
+ repeat_var.insert(name_dict[name]);
+ } else {
+ name_dict[name] = arg;
+ }
+ }
+
+ std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> bind_dict;
+ for (auto& kv : params) {
+ if (name_dict.count(kv.first) == 0) {
+ continue;
+ }
+ auto arg = name_dict.at(kv.first);
+ if (repeat_var.count(arg)) {
+ LOG(FATAL) << "ValueError: Multiple args in the function have name " <<
kv.first;
+ }
+ bind_dict[arg] = Constant(kv.second);
+ }
+ Expr bound_expr = Bind(func, bind_dict);
+ Function ret = Downcast<Function>(bound_expr);
+ ICHECK(ret.defined()) << "The returning type is expected to be a Relax
Function."
+ << "\n";
+ return ret;
+}
+
+/*!
+ * \brief Bind params to a specific function in a module
+ * \param m The module
+ * \param func_name The name of the specific function
+ * \param param The param dict
+ * \return The module after binding params.
+ */
+IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray>
param) {
+ IRModuleNode* new_module = m.CopyOnWrite();
+ Map<GlobalVar, BaseFunc> functions = m->functions;
+ for (const auto& func_pr : functions) {
+ if (const auto* relax_f = func_pr.second.as<FunctionNode>()) {
+ if (relax_f->GetLinkageType() == LinkageType::kExternal) {
+ // Use global_symbol if it's external linkage
+ Optional<String> gsymbol =
relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
+ if (gsymbol.defined() && gsymbol.value() == func_name) {
+ Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f),
param);
+ new_module->Update(func_pr.first, f_after_bind);
+ }
+ } else {
+ // Use global var's name_hint if it's internal linkage
+ if (func_pr.first->name_hint == func_name) {
+ Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f),
param);
+ new_module->Update(func_pr.first, f_after_bind);
+ }
+ }
+ }
+ }
+ return GetRef<IRModule>(new_module);
+}
+
+namespace transform {
+
+Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
+ runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
+ [=](IRModule mod, PassContext pc) { return BindParam(std::move(mod),
func_name, params); };
+ return CreateModulePass(pass_func, 0, "BindParams", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BindParams").set_body_typed(BindParams);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/src/relax/transform/fold_constant.cc
b/src/relax/transform/fold_constant.cc
new file mode 100644
index 0000000000..aa55ee7f7e
--- /dev/null
+++ b/src/relax/transform/fold_constant.cc
@@ -0,0 +1,230 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#include <tvm/driver/driver_api.h>
+#include <tvm/ir/function.h>
+#include <tvm/relax/expr_functor.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+#include <tvm/tir/function.h>
+#include <tvm/tir/op.h>
+
+namespace tvm {
+namespace relax {
+
+class ConstantFolder : public ExprMutator {
+ public:
+ explicit ConstantFolder(IRModule ctx_module) : ctx_module_(ctx_module) {}
+
+ private:
+ /*!
+ * \brief Pattern match the shape inside the given struct info to a
+ * constant shape and get runtime shape tuple from it.
+ * \param struct_info The given struct info whose shape inside is to be
casted.
+ * \return The runtime shape tuple, or nullopt if it is not a constant shape.
+ * \note Only TensorStructInfo is supported at this moment. Return NullOpt
+ * if the input struct info is not TensorStructInfo.
+ */
+ static Optional<runtime::ShapeTuple> MatchConstShape(const StructInfo&
struct_info) {
+ // Only support single output for call_tir at this moment.
+ const auto* tensor_sinfo = struct_info.as<TensorStructInfoNode>();
+ if (tensor_sinfo == nullptr) {
+ return NullOpt;
+ }
+
+ const auto* shape = tensor_sinfo->shape.as<ShapeExprNode>();
+ ICHECK(shape != nullptr) << "struct info given by call_tir should have
ShapeExpr shape";
+
+ std::vector<int64_t> shape_values;
+ for (const auto v : shape->values) {
+ auto* ptr = v.as<IntImmNode>();
+ if (!ptr) return NullOpt;
+ shape_values.push_back(ptr->value);
+ }
+ return runtime::ShapeTuple(shape_values.begin(), shape_values.end());
+ }
+
+ /*!
+ * \brief Pattern match op to constant array arguments.
+ * \return The constant array arguments, or nullopt if match fails.
+ */
+ static Optional<Array<runtime::NDArray>> MatchConstArrayArgs(const
Array<Expr>& args) {
+ Array<runtime::NDArray> res;
+ for (auto arg : args) {
+ auto* ptr = arg.as<relax::ConstantNode>();
+ if (!ptr) return NullOpt;
+ res.push_back(ptr->data);
+ }
+ return res;
+ }
+
+ /*!
+ * \brief Pattern match op to a TIR function and look it up.
+ * \return The TIR function, or nullopt if pattern match fails.
+ */
+ Optional<tir::PrimFunc> MatchPrimFunc(const Expr& op) {
+ if (auto* ptr = op.as<GlobalVarNode>()) {
+ // NOTE: as check works for nullptr(returns null)
+ Optional<BaseFunc> base_func =
ctx_module_->functions.Get(GetRef<GlobalVar>(ptr));
+ if (auto* pfunc = base_func.as<tir::PrimFuncNode>()) {
+ return GetRef<tir::PrimFunc>(pfunc);
+ }
+ }
+ return NullOpt;
+ }
+
+ /*!
+ * \brief Get a cached build version of func
+ * \return The cached func, nullopt if func cannot be built.
+ */
+ Optional<PackedFunc> GetCachedBuild(tir::PrimFunc func) {
+ // TODO(tvm-team): consider another way of bulk extract and build PrimFunc
once
+ // would be helpful for future cases where PrimFunc recursively call into
each other
+ Target eval_cpu_target{"llvm"};
+
+ auto it = func_build_cache_.find(func);
+ if (it != func_build_cache_.end()) {
+ return it->second;
+ }
+ Optional<PackedFunc> build_func = NullOpt;
+
+ try {
+ // Not all the primfunc can be directly built via llvm, for example, if
a function is
+ // already scheduled to only work on GPU, we will need to skip this in
the const folder for
+ // now
+ // TODO(Hongyi): further check and narrow the scope of foldable function
+ runtime::Module rt_module =
+ build(LowerPrimFunc(func, "tir_function"), eval_cpu_target,
eval_cpu_target);
+ build_func = rt_module.GetFunction("tir_function");
+ } catch (const tvm::Error& err) {
+ // build failure may happen in which case we skip
+ DLOG(WARNING) << "Build failure for function " << func << ", Error
message: " << err.what();
+ }
+ func_build_cache_[func] = build_func;
+ return build_func;
+ }
+
+ // Try constant evaluate the function call
+ // if failed return NullOpt
+ Optional<Expr> ConstEvaluateCallTIR(tir::PrimFunc tir_func,
Array<runtime::NDArray> arr_args,
+ runtime::ShapeTuple shape, DataType
ret_type) {
+ // obtain function from the cache.
+ Optional<PackedFunc> func = GetCachedBuild(tir_func);
+ if (!func) return NullOpt;
+
+ // here the vector size has an additional + 1 because we need to put
ret_tensor at the end
+ std::vector<TVMValue> values(arr_args.size() + 1);
+ std::vector<int> type_codes(arr_args.size() + 1);
+
+ DLDevice cpu_dev = {DLDeviceType::kDLCPU, 0};
+ runtime::NDArray ret_tensor = runtime::NDArray::Empty(shape, ret_type,
cpu_dev);
+
+ // avoid set rvalue ref which get de-allocated later, store args in a
vector
+ // where temp_args[i] are lvalue ref that is stable
+ std::vector<runtime::NDArray> temp_args(arr_args.begin(), arr_args.end());
+
+ size_t arg_offset = 0;
+ for (; arg_offset < arr_args.size(); ++arg_offset) {
+ runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset,
temp_args[arg_offset]);
+ }
+ // set return value
+ runtime::TVMArgsSetter(values.data(), type_codes.data())(arg_offset++,
ret_tensor);
+
+ TVMRetValue ret;
+ // invoke
+ func.value().CallPacked(TVMArgs(values.data(), type_codes.data(),
values.size()), &ret);
+ return Constant(ret_tensor);
+ }
+
+ Expr VisitCallTIR(Call call) {
+ // call_tir needs to have at least three arguments
+ ICHECK_GE(call->args.size(), 2);
+ Optional<tir::PrimFunc> func = MatchPrimFunc(call->args[0]);
+ ICHECK(call->args[1].as<TupleNode>()) << "call_tir.args[1] must be Tuple";
+ Optional<Array<runtime::NDArray>> arr_args =
+ MatchConstArrayArgs(call->args[1].as<TupleNode>()->fields);
+ ICHECK_EQ(call->sinfo_args.size(), 1) << "call_tir should have exactly one
sinfo arg";
+ Optional<runtime::ShapeTuple> shape = MatchConstShape(call->sinfo_args[0]);
+ bool output_not_tuple = call->sinfo_args.size() == 1;
+ // Pattern 0: call constant function, const argument with const shape.
+ if (func && arr_args && shape && output_not_tuple) {
+ DynTensorType ret_type = Downcast<DynTensorType>(call->checked_type());
+ // value_or will return value if it is not null, otherwise return or
+ return ConstEvaluateCallTIR(func.value(), arr_args.value(),
shape.value(), ret_type->dtype)
+ .value_or(call);
+ }
+ // TODO(hongyi): support const-fold tuple outputs
+ return std::move(call);
+ }
+
+ using ExprMutator::VisitExpr_;
+
+ Expr VisitExpr_(const CallNode* call) final {
+ // post-order mutation
+ Call post_call = Downcast<Call>(VisitExprPostOrder_(call));
+ static const Op& call_tir_op = Op::Get("relax.call_tir");
+
+ if (call->op.same_as(call_tir_op)) {
+ return VisitCallTIR(post_call);
+ }
+ return std::move(post_call);
+ }
+
+ Expr VisitExpr_(const DataflowVarNode* op) final {
+ Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+ // `as` check checks if opt is not null and is instance of constant
+ if (opt.as<relax::ConstantNode>()) {
+ return opt.value();
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ Expr VisitExpr_(const VarNode* op) final {
+ Optional<Expr> opt = LookupBinding(GetRef<Var>(op));
+ // `as` check checks if opt is not null and is instance of constant
+ if (opt.as<relax::ConstantNode>()) {
+ return opt.value();
+ }
+ return ExprMutator::VisitExpr_(op);
+ }
+
+ // the context module to lookup functions
+ IRModule ctx_module_;
+ // cache for function build, via structural equality
+ std::unordered_map<tir::PrimFunc, Optional<runtime::PackedFunc>,
StructuralHash, StructuralEqual>
+ func_build_cache_;
+};
+
+namespace transform {
+
+Pass FoldConstant() {
+ runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)>
pass_func =
+ [=](Function f, IRModule m, PassContext pc) {
+ ConstantFolder folder(m);
+ return Downcast<Function>(folder(f));
+ };
+ return CreateFunctionPass(pass_func, 0, "FoldConstant", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.FoldConstant").set_body_typed(FoldConstant);
+
+} // namespace transform
+
+} // namespace relax
+} // namespace tvm
diff --git a/tests/python/relax/test_transform_bind_params.py
b/tests/python/relax/test_transform_bind_params.py
new file mode 100644
index 0000000000..b96fb89e6c
--- /dev/null
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+use_np_array = tvm.testing.parameter(False, True)
+
+
+def test_bind_params(use_np_array):
+ @tvm.script.ir_module
+ class InputModule:
+ @T.prim_func
+ def tir_matmul(x: T.handle, y: T.handle, z: T.handle) -> None:
+ T.func_attr({"global_symbol": "tir_matmul"})
+ A = T.match_buffer(x, (16, 16))
+ B = T.match_buffer(y, (16, 16))
+ C = T.match_buffer(z, (16, 16))
+ for i0, j, k0, i1, k1 in T.grid(4, 16, 4, 4, 4):
+ with T.block("matmul"):
+ vi = T.axis.S(16, i0 * 4 + i1)
+ vj = T.axis.S(16, j)
+ vk = T.axis.R(16, k0 * 4 + k1)
+ with T.init():
+ C[vi, vj] = T.float32(0)
+ C[vi, vj] = C[vi, vj] + A[vi, vk] * B[vk, vj]
+
+ @R.function
+ def main(
+ x: R.Tensor((16, 16), "float32"), w: R.Tensor((16, 16), "float32")
+ ) -> R.Tensor((16, 16), "float32"):
+ gv0 = R.call_tir(tir_matmul, (x, w), R.Tensor((16, 16),
dtype="float32"))
+ return gv0
+
+ x_np = np.random.rand(16, 16).astype(np.float32)
+ w_np = np.random.rand(16, 16).astype(np.float32)
+ x_tvm = tvm.nd.array(x_np)
+ w_tvm = tvm.nd.array(w_np)
+ params_dict = {"w": w_np if use_np_array else w_tvm}
+ mod = relax.transform.BindParams("main", params_dict)(InputModule)
+ assert len(mod["main"].params) == 1
+
+ target = tvm.target.Target("llvm")
+ ex_after = relax.vm.build(mod, target)
+ vm_after = relax.VirtualMachine(ex_after, tvm.cpu())
+ res_after = vm_after["main"](x_tvm)
+
+ ex_before = relax.vm.build(InputModule, target)
+ vm_before = relax.VirtualMachine(ex_before, tvm.cpu())
+ res_before = vm_before["main"](x_tvm, w_tvm)
+
+ tvm.testing.assert_allclose(res_before.numpy(), res_after.numpy())
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
new file mode 100644
index 0000000000..32ee3e7000
--- /dev/null
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -0,0 +1,280 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+import tvm.testing
+from tvm import relax
+import numpy as np
+
+import tvm.script
+from tvm.script import tir as T, relax as R
+
+
+def gen_mod(mod, name, binding):
+    """Select relax function with name, rename to main and bind constant.
+
+ Parameters
+ ----------
+ mod: IRModule
+ The input module
+
+ name: str
+ The name of relax function to preserve and rename to main
+
+ binding: Dict[str, array]
+ The const parameter bindings
+ """
+ funcs = {}
+ binding = {k: tvm.nd.array(v) for k, v in binding.items()}
+
+ for k, v in mod.functions.items():
+ if isinstance(v, tvm.relax.Function):
+ if k.name_hint == name:
+ # rename to main
+ gv = tvm.ir.GlobalVar("main")
+ funcs[gv] = tvm.relax.Function(v.params, v.body,
v.ret_struct_info).with_attr(
+ "global_symbol", "main"
+ )
+ else:
+ funcs[k] = v
+ mod = tvm.IRModule(funcs)
+ return relax.transform.BindParams("main", binding)(mod)
+
+
+def test_one_fold_addone():
+ # put before after in a single module
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def addone(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16),
"float32"]) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32")):
+ lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16),
dtype="float32"))
+ return lv0
+
+ @R.function
+ def expected(c1: R.Tensor((16, 16), "float32")):
+ lv0 = c1
+ return c1
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c1_np = c0_np + 1
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_one_fold_transpose():
+ # put before after in a single module
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def func(A: T.Buffer[(2, 3), "float32"], B: T.Buffer[(3, 2),
"float32"]) -> None:
+ for i, j in T.grid(3, 2):
+ with T.block("transpose"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vj, vi]
+
+ @R.function
+ def before(c0: R.Tensor((2, 3), "float32")):
+ lv0 = relax.call_tir(func, (c0,), R.Tensor((3, 2),
dtype="float32"))
+ return lv0
+
+ @R.function
+ def expected(c1: R.Tensor((3, 2), "float32")):
+ lv0 = c1
+ return c1
+
+ c0_np = np.arange(2 * 3).astype("float32").reshape(2, 3)
+ c1_np = c0_np.T
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_two_hop_addone():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def addone(A: T.Buffer[(2, 2), "float32"], B: T.Buffer[(2, 2),
"float32"]) -> None:
+ for i, j in T.grid(2, 2):
+ with T.block("addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+
+ @R.function
+ def before(c0: R.Tensor((2, 2), "float32")):
+ lv0 = relax.call_tir(addone, (c0,), R.Tensor((2, 2),
dtype="float32"))
+ lv1 = relax.call_tir(addone, (lv0,), R.Tensor((2, 2),
dtype="float32"))
+ return lv1
+
+ @R.function
+ def expected(c1: R.Tensor((2, 2), "float32"), c2: R.Tensor((2, 2),
"float32")):
+ lv0 = c1
+ lv1 = c2
+ return c2
+
+ c0_np = np.arange((2 * 2)).astype("float32").reshape(2, 2)
+ c1_np = c0_np + 1
+ c2_np = c1_np + 1
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np, "c2": c2_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_dataflow_fold():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def identity(A: T.Buffer[(16, 16), "float32"], B: T.Buffer[(16, 16),
"float32"]) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("identity"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj]
+
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32")):
+ with R.dataflow():
+ gv0 = relax.call_tir(identity, (c0,), R.Tensor((16, 16),
dtype="float32"))
+ R.output(gv0)
+ return gv0
+
+ @R.function
+ def expected(c1: R.Tensor((16, 16), "float32")):
+ with R.dataflow():
+ gv0 = c1
+ R.output(gv0)
+ return c1
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c1_np = c0_np
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np})
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_fold_mixed_case():
+ @tvm.script.ir_module
+ class Module:
+ # TIR function can handle different cases.
+ @T.prim_func
+ def addone(a: T.handle, b: T.handle) -> None:
+ n = T.var("int32")
+ m = T.var("int32")
+ A = T.match_buffer(a, (n, m))
+ B = T.match_buffer(b, (n, m))
+ for i, j in T.grid(n, m):
+ with T.block("addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.float32(1)
+
+ @T.prim_func
+ def sub(
+ A: T.Buffer[(16, 16), "float32"],
+ B: T.Buffer[(16, 16), "float32"],
+ C: T.Buffer[(16, 16), "float32"],
+ ) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("sub"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ C[vi, vj] = A[vi, vj] - B[vi, vj]
+
+ @R.function
+ def before(c0: R.Tensor((16, 16), "float32"), x: R.Tensor("float32",
ndim=2)):
+ n, m = T.var("int64"), T.var("int64")
+ x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
+ # this line cannot be folded because n is unknown
+ lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16),
dtype="float32"))
+ # this line can be folded
+ lv1 = relax.call_tir(addone, (c0,), R.Tensor((16, 16),
dtype="float32"))
+ # this line can be folded because all inputs are const
+ lv2 = relax.call_tir(sub, (c0, lv1), R.Tensor((16, 16),
dtype="float32"))
+ # this line can not be folded because x's shape is unknown
+ lv3 = relax.call_tir(sub, (lv2, x), R.Tensor((16, 16),
dtype="float32"))
+ return lv3
+
+ @R.function
+ def expected(
+ c0: R.Tensor((16, 16), "float32"),
+ c1: R.Tensor((16, 16), "float32"),
+ c2: R.Tensor((16, 16), "float32"),
+ x: R.Tensor("float32", ndim=2),
+ ) -> R.Tensor:
+ n, m = T.var("int64"), T.var("int64")
+ x0 = R.match_cast(x, R.Tensor((n, m), "float32"))
+ # this line cannot be folded because n is unknown
+ lv0 = relax.call_tir(addone, (c0,), R.Tensor((n, 16),
dtype="float32"))
+ # this line can be folded
+ lv1 = c1
+ # this line can be folded because all inputs are const
+ lv2 = c2
+ # this line can not be folded because x's shape is unknown
+ lv3 = relax.call_tir(sub, (c2, x), R.Tensor((16, 16),
dtype="float32"))
+ return lv3
+
+ c0_np = np.arange((16 * 16)).astype("float32").reshape(16, 16)
+ c1_np = c0_np + 1
+ c2_np = c0_np - c1_np
+
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c0": c0_np, "c1": c1_np, "c2":
c2_np})
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+def test_int32_fold():
+ @tvm.script.ir_module
+ class Module:
+ @T.prim_func
+ def addone(A: T.Buffer[(16, 16), "int32"], B: T.Buffer[(16, 16),
"int32"]) -> None:
+ for i, j in T.grid(16, 16):
+ with T.block("addone"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] + T.int32(1)
+
+ @R.function
+ def before(c0: R.Tensor((16, 16), "int32")):
+ lv0 = relax.call_tir(addone, (c0,), R.Tensor((16, 16),
dtype="int32"))
+ return lv0
+
+ @R.function
+ def expected(c1: R.Tensor((16, 16), "int32")):
+ lv0 = c1
+ return c1
+
+ c0_np = np.arange((16 * 16)).astype("int32").reshape(16, 16)
+ c1_np = c0_np + 1
+ before = gen_mod(Module, "before", {"c0": c0_np})
+ expected = gen_mod(Module, "expected", {"c1": c1_np})
+
+ after = relax.transform.FoldConstant()(before)
+ tvm.ir.assert_structural_equal(after, expected)
+
+
+if __name__ == "__main__":
+ tvm.testing.main()