This is an automated email from the ASF dual-hosted git repository.
csullivan pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new ec4a8b3011 [Unity] Implement relax.Function.bind_params (#15626)
ec4a8b3011 is described below
commit ec4a8b301182ceba0cc6b9b07a74720287f343dc
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 6 12:23:30 2023 -0700
[Unity] Implement relax.Function.bind_params (#15626)
* [Unity] Implement relax.Function.bind_params
Similar to `relax.Function.bind_symbolic_vars`, implemented in
https://github.com/apache/tvm/pull/15509, this commit introduces
`relax.Function.bind_params` to allow Relax parameters to be
manipulated on a per-function basis. This utility function and the
existing `BindParams` transform both use the same underlying
implementation.
* Update relay_translator unit tests to avoid binding the same parameters twice
* Update unit test that attempted to bind a non-existent parameter
---
include/tvm/relax/transform.h | 2 +-
include/tvm/relax/utils.h | 22 +++
python/tvm/relax/expr.py | 50 +++++++
python/tvm/relax/transform/transform.py | 11 +-
src/relax/transform/bind_params.cc | 116 ++++++++++-----
src/relax/utils.cc | 56 ++++++++
tests/python/relax/test_bind_params.py | 156 +++++++++++++++++++++
tests/python/relax/test_relay_translator.py | 6 +-
tests/python/relax/test_transform_bind_params.py | 52 +++++++
tests/python/relax/test_transform_fold_constant.py | 3 +-
10 files changed, 432 insertions(+), 42 deletions(-)
diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 0922126b78..45a31b0911 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -180,7 +180,7 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = false);
*
* \return The Pass.
*/
-TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> params);
+TVM_DLL Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params);
/*!
* \brief Bind symbolic vars to constant shape values.
diff --git a/include/tvm/relax/utils.h b/include/tvm/relax/utils.h
index 1a6d5d4a52..0e0249b863 100644
--- a/include/tvm/relax/utils.h
+++ b/include/tvm/relax/utils.h
@@ -24,7 +24,9 @@
#ifndef TVM_RELAX_UTILS_H_
#define TVM_RELAX_UTILS_H_
+#include <tvm/arith/analyzer.h>
#include <tvm/ir/module.h>
+#include <tvm/relax/expr.h>
#include <tvm/runtime/logging.h>
namespace tvm {
@@ -48,6 +50,26 @@ namespace relax {
TVM_DLL Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>& binds,
const tvm::Map<tir::Var, PrimExpr>& symbolic_var_map = {});
+/*!
+ * \brief Infer a binding map for symbolic variables
+ *
+ * If a set of relax variables are replaced within an expression, this
+ * may result in removal of the definition site of a symbolic
+ * variable. This utility function determines the symbolic variable
+ * replacements that can be inferred based on the replaced relax
+ * variables, and can be used alongside the `Bind` utility function to
+ * replace both the relax variables and the implied symbolic
+ * variables.
+ *
+ * A replacement is inferred only where a dimension in the struct
+ * info of a replaced relax variable (the shape of a tensor, or the
+ * values of a shape) is a bare `tir::Var`; compound dimension
+ * expressions are not solved for.  If the variable's shape is known,
+ * binding it to an expression of a different struct info kind is an
+ * error, as is a mismatch in the number of dimensions when both
+ * shapes are known.
+ *
+ * \param binds A map of relax variables to relax expressions
+ *
+ * \param analyzer The analyzer to use for simplifications
+ *
+ * \return A map of TIR variables to TIR expressions
+ */
+TVM_DLL tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
+ const tvm::Map<relax::Var, relax::Expr>& binds, arith::Analyzer* analyzer);
+
/*!
* \brief Check if the given StructInfo is for a boolean scalar (tensor of
rank 0 with a boolean
* dtype).
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index 49b91ffb3d..cd5dfa2863 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -657,6 +657,56 @@ class Function(BaseFunc, Scriptable):
return _ffi_api.FunctionBindSymbolicVars(self, binding_map) # type:
ignore
+    def bind_params(
+        self,
+        binding_map: Mapping[
+            Union[str, Var],
+            Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
+        ],
+    ) -> "Function":
+        """Return a new function with the specified parameters bound to values
+
+        The bound parameters are removed from the function signature,
+        and every use of them is replaced by the corresponding value.
+        Symbolic variables whose values are implied by the bound
+        values are replaced as well.
+
+        Parameters
+        ----------
+        binding_map: Mapping[
+            Union[str, Var],
+            Union[int, float, PrimExpr, tvm.runtime.NDArray, _np.ndarray, Expr],
+        ]
+
+            The mapping of values to be replaced.
+
+            Keys may be either a `relax.Var` or a string name of the
+            Relax variable.  If the variables are referred to by name,
+            the name must uniquely identify a parameter in the
+            function.
+
+            Values must be a relax expression, or a value that is
+            convertible into a relax expression.  The value must be
+            compatible with the variable being replaced.
+
+        Returns
+        -------
+        func: Function
+
+            The updated function
+        """
+
+        def _normalize_value(value):
+            # Conversions that must occur prior to the FFI
+            # conversions.
+            if isinstance(value, int):
+                # Relax uses int64 for symbolic variables, but the FFI
+                # converts python integers into int32.
+                return tvm.tir.const(value, "int64")
+            elif isinstance(value, (_np.ndarray, tvm.nd.NDArray)):
+                return tvm.relax.const(value)
+            else:
+                return value
+
+        binding_map = {key: _normalize_value(value) for key, value in binding_map.items()}
+
+        return _ffi_api.FunctionBindParams(self, binding_map)  # type: ignore
+
@tvm._ffi.register_object("relax.expr.ExternFunc")
class ExternFunc(BaseFunc):
diff --git a/python/tvm/relax/transform/transform.py
b/python/tvm/relax/transform/transform.py
index 5c9a2ae554..13874aa044 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -405,7 +405,7 @@ def AttachGlobalSymbol() -> tvm.ir.transform.Pass:
def BindParams(
func_name: str,
- params: Dict[str, Union[tvm.runtime.NDArray, np.ndarray]],
+ params: Dict[Union[str, Var], Union[tvm.runtime.NDArray, np.ndarray]],
) -> tvm.ir.transform.Pass:
"""Bind params of function of the module to constant tensors.
@@ -415,8 +415,13 @@ def BindParams(
func_name: str
The function name to be bound
- params : Dict[str, Union[tvm.runtime.NDArray, np.ndarray]]
- The map from param name to constant tensors.
+ params : Dict[
+ Union[str, relax.Var],
+ Union[tvm.runtime.NDArray, np.ndarray],
+ ]
+
+ The map from parameter, or parameter name, to constant
+ tensors.
Returns
-------
diff --git a/src/relax/transform/bind_params.cc
b/src/relax/transform/bind_params.cc
index c444a84f44..27931b6017 100644
--- a/src/relax/transform/bind_params.cc
+++ b/src/relax/transform/bind_params.cc
@@ -25,6 +25,7 @@
#include <tvm/relax/type.h>
#include <tvm/tir/op.h>
+#include <tuple>
#include <utility>
namespace tvm {
@@ -81,45 +82,88 @@ void MatchSymbolicVar(const Expr& arg, const Expr& constant,
}
}
+/*!
+ * \brief Normalize user-provided parameter bindings of a function
+ *
+ * \param func The function whose parameters are being bound
+ *
+ * \param untyped_params Map from parameter to value.  Each key is either
+ *     a relax::Var that is a parameter of `func`, or a String that
+ *     uniquely names one of its parameters.  Each value is a relax::Expr,
+ *     a runtime::NDArray (wrapped into a relax::Constant), or a PrimExpr
+ *     (wrapped into a relax::PrimValue).
+ *
+ * \return The relax variable replacements, together with the symbolic
+ *     variable replacements that can be inferred from them.
+ */
+std::tuple<Map<Var, Expr>, Map<tir::Var, PrimExpr>> NormalizeBindings(
+    const Function& func, const Map<ObjectRef, ObjectRef>& untyped_params) {
+  ICHECK(func.defined());
+  ICHECK(untyped_params.defined());
+
+  // Map from string to the variable(s) with that name.  Several
+  // parameters may share a name, in which case binding by that name is
+  // ambiguous and rejected in `normalize_key`.
+  std::unordered_map<std::string, Array<relax::Var>> string_lookup;
+  std::unordered_set<const relax::VarNode*> var_set;
+  for (const auto& param : func->params) {
+    string_lookup[param->name_hint()].push_back(param);
+    var_set.insert(param.get());
+  }
+
+  Map<relax::Var, relax::Expr> relax_var_remap;
+
+  auto normalize_key = [&](ObjectRef obj) -> relax::Var {
+    if (auto opt_str = obj.as<String>()) {
+      std::string str = opt_str.value();
+      auto it = string_lookup.find(str);
+      CHECK(it != string_lookup.end())
+          << "Function does not have parameter with name \"" << str << "\". "
+          << "Function parameters are named "
+          << func->params.Map([](const auto& param) { return param->name_hint(); });
+      CHECK_EQ(it->second.size(), 1)
+          << "Function contains multiple parameters with name \"" << str << "\". "
+          << "The Relax variables " << it->second << " are all named \"" << str << "\"";
+      auto var = it->second[0];
+      // The same parameter may have been given both by name and by
+      // relax::Var, which would otherwise silently overwrite a binding.
+      CHECK(!relax_var_remap.count(var))
+          << "Remap of variable " << var << " was defined multiple times";
+
+      return var;
+    } else if (auto opt_var = obj.as<relax::Var>()) {
+      auto var = opt_var.value();
+      CHECK(!relax_var_remap.count(var))
+          << "Remap of variable " << var << " was defined multiple times";
+      CHECK(var_set.count(var.get()))
+          << "Function does not use Relax variable " << var << " as a parameter. "
+          << "Function parameters are " << func->params;
+      return var;
+    } else {
+      LOG(FATAL)
+          << "Expected bound parameter to be a relax::Var, "
+          << " or a string that uniquely identifies a relax::Var param within the function. "
+          << "However, received object " << obj << " of type " << obj->GetTypeKey();
+    }
+  };
+  auto normalize_value = [&](ObjectRef obj) -> relax::Expr {
+    if (auto opt = obj.as<relax::Expr>()) {
+      return opt.value();
+    } else if (auto opt = obj.as<runtime::NDArray>()) {
+      return Constant(opt.value());
+    } else if (auto opt = obj.as<PrimExpr>()) {
+      // Scalars, such as the int64 `tir.const` that the Python
+      // `relax.Function.bind_params` produces for integer values.
+      return PrimValue(opt.value());
+    } else {
+      LOG(FATAL) << "Cannot coerce object of type " << obj->GetTypeKey()
+                 << " into relax expression";
+    }
+  };
+
+  for (const auto& [key, value] : untyped_params) {
+    relax_var_remap.Set(normalize_key(key), normalize_value(value));
+  }
+
+  arith::Analyzer analyzer;
+  Map<tir::Var, PrimExpr> symbolic_var_map = InferSymbolicVarMap(relax_var_remap, &analyzer);
+
+  return {relax_var_remap, symbolic_var_map};
+}
+
/*!
* \brief Bind params to function by using name
* \param func Relax function
* \param params params dict
* \return Function
*/
-inline Function BindParamsByName(Function func, const Map<String,
runtime::NDArray>& params) {
- std::unordered_map<std::string, Var> name_dict;
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> repeat_var;
- for (auto arg : func->params) {
- const auto& name = arg->name_hint();
- if (name_dict.count(name)) {
- repeat_var.insert(name_dict[name]);
- } else {
- name_dict[name] = arg;
- }
- }
+Function FunctionBindParams(Function func, const Map<ObjectRef, ObjectRef>&
untyped_params) {
+ auto [bind_dict, symbolic_var_map] = NormalizeBindings(func, untyped_params);
- arith::Analyzer analyzer;
- Map<Var, Expr> bind_dict;
- Map<tir::Var, PrimExpr> symbolic_var_map;
-
- for (auto& kv : params) {
- if (name_dict.count(kv.first) == 0) {
- continue;
- }
- const Var& arg = name_dict.at(kv.first);
- if (repeat_var.count(arg)) {
- LOG(FATAL) << "ValueError: Multiple args in the function have name " <<
kv.first;
- }
- Expr const_expr = Constant(kv.second);
- bind_dict.Set(arg, const_expr);
- MatchSymbolicVar(arg, const_expr, &symbolic_var_map, &analyzer);
- }
Expr bound_expr = Bind(func, bind_dict, symbolic_var_map);
- Function ret = Downcast<Function>(bound_expr);
- ICHECK(ret.defined()) << "The returning type is expected to be a Relax
Function."
- << "\n";
- return ret;
+ return Downcast<Function>(bound_expr);
}
/*!
@@ -129,7 +173,7 @@ inline Function BindParamsByName(Function func, const
Map<String, runtime::NDArr
* \param param The param dict
* \return The module after binding params.
*/
-IRModule BindParam(IRModule m, String func_name, Map<String, runtime::NDArray>
param) {
+IRModule BindParam(IRModule m, String func_name, Map<ObjectRef, ObjectRef>
bind_params) {
IRModuleNode* new_module = m.CopyOnWrite();
Map<GlobalVar, BaseFunc> functions = m->functions;
for (const auto& func_pr : functions) {
@@ -138,13 +182,13 @@ IRModule BindParam(IRModule m, String func_name,
Map<String, runtime::NDArray> p
// Use global_symbol if it's external linkage
Optional<String> gsymbol =
relax_f->GetAttr<String>(tvm::attr::kGlobalSymbol);
if (gsymbol.defined() && gsymbol.value() == func_name) {
- Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f),
param);
+ Function f_after_bind =
FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
} else {
// Use global var's name_hint if it's internal linkage
if (func_pr.first->name_hint == func_name) {
- Function f_after_bind = BindParamsByName(GetRef<Function>(relax_f),
param);
+ Function f_after_bind =
FunctionBindParams(GetRef<Function>(relax_f), bind_params);
new_module->Update(func_pr.first, f_after_bind);
}
}
@@ -153,9 +197,11 @@ IRModule BindParam(IRModule m, String func_name,
Map<String, runtime::NDArray> p
return GetRef<IRModule>(new_module);
}
+// Registered as "relax.FunctionBindParams"; the Python method
+// `relax.Function.bind_params` calls it through `_ffi_api.FunctionBindParams`.
+TVM_REGISTER_GLOBAL("relax.FunctionBindParams").set_body_typed(FunctionBindParams);
+
namespace transform {
-Pass BindParams(String func_name, Map<String, runtime::NDArray> params) {
+Pass BindParams(String func_name, Map<ObjectRef, ObjectRef> params) {
runtime::TypedPackedFunc<IRModule(IRModule, PassContext)> pass_func =
[=](IRModule mod, PassContext pc) { return BindParam(std::move(mod),
func_name, params); };
return CreateModulePass(pass_func, 0, "BindParams", {});
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index ccb72805e3..f8235def24 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -144,6 +144,62 @@ Expr Bind(const Expr& expr, const tvm::Map<Var, Expr>&
binds,
return ExprBinder(binds, symbolic_var_map).VisitExpr(expr);
}
+tvm::Map<tir::Var, PrimExpr> InferSymbolicVarMap(
+    const tvm::Map<relax::Var, relax::Expr>& relax_var_remap, arith::Analyzer* analyzer) {
+  // The analyzer is used to compare repeated bindings of the same symbolic
+  // variable.  Fall back to a local analyzer if the caller did not provide one.
+  arith::Analyzer fallback_analyzer;
+  if (analyzer == nullptr) {
+    analyzer = &fallback_analyzer;
+  }
+
+  tvm::Map<tir::Var, PrimExpr> tir_var_remap;
+
+  // Bind a single dimension.  Only a dimension that is a bare symbolic
+  // variable can be solved for.  The same symbolic variable may occur in
+  // several dimensions or parameters (e.g. `R.Tensor([N, N])`), so the first
+  // binding is kept, and a later binding that is provably different is
+  // reported as an error instead of silently replacing it.
+  auto bind_from_prim_expr = [&tir_var_remap, analyzer](const PrimExpr& var_shape,
+                                                        const PrimExpr& expr_shape) {
+    if (auto var = var_shape.as<tir::Var>()) {
+      if (auto prev = tir_var_remap.Get(var.value())) {
+        CHECK(!analyzer->CanProve(prev.value() != expr_shape))
+            << "Cannot bind symbolic variable " << var.value() << " to " << expr_shape
+            << ", because it was already bound to " << prev.value();
+      } else {
+        tir_var_remap.Set(var.value(), expr_shape);
+      }
+    }
+  };
+
+  // Bind each dimension of a ShapeStructInfo whose values are known.
+  auto bind_from_shape = [&bind_from_prim_expr](const StructInfo& var, const StructInfo& expr) {
+    auto var_shape = var.as<ShapeStructInfoNode>();
+    if (!var_shape) return;
+    if (!var_shape->values.defined()) return;
+
+    auto expr_shape = expr.as<ShapeStructInfoNode>();
+    CHECK(expr_shape) << "Cannot bind expression with struct type " << expr
+                      << " to variable with struct type " << var;
+    if (!expr_shape->values.defined()) return;
+
+    auto var_shape_arr = var_shape->values.value();
+    auto expr_shape_arr = expr_shape->values.value();
+    CHECK_EQ(var_shape_arr.size(), expr_shape_arr.size())
+        << "Cannot bind shape " << expr_shape_arr << " of dimension " << expr_shape_arr.size()
+        << " to variable with shape " << var_shape_arr << " of dimension " << var_shape_arr.size();
+    for (size_t i = 0; i < var_shape_arr.size(); i++) {
+      bind_from_prim_expr(var_shape_arr[i], expr_shape_arr[i]);
+    }
+  };
+
+  // Bind through the shape of a TensorStructInfo, when both shapes are known.
+  auto bind_from_tensor = [&bind_from_shape](const StructInfo& var, const StructInfo& expr) {
+    auto var_tensor = var.as<TensorStructInfoNode>();
+    if (!var_tensor) return;
+    if (!var_tensor->shape.defined()) return;
+
+    auto expr_tensor = expr.as<TensorStructInfoNode>();
+    CHECK(expr_tensor) << "Cannot bind expression with struct type " << expr
+                       << " to variable with struct type " << var;
+    if (!expr_tensor->shape.defined()) return;
+
+    bind_from_shape(GetStructInfo(var_tensor->shape.value()),
+                    GetStructInfo(expr_tensor->shape.value()));
+  };
+
+  for (const auto& [relax_var, relax_expr] : relax_var_remap) {
+    auto var_sinfo = GetStructInfo(relax_var);
+    auto expr_sinfo = GetStructInfo(relax_expr);
+
+    bind_from_tensor(var_sinfo, expr_sinfo);
+    bind_from_shape(var_sinfo, expr_sinfo);
+  }
+
+  return tir_var_remap;
+}
+
bool IsBoolStructInfo(const StructInfo& sinfo, bool permit_unknown_rank,
bool permit_unknown_dtype) {
const TensorStructInfoNode* tt = sinfo.as<TensorStructInfoNode>();
diff --git a/tests/python/relax/test_bind_params.py
b/tests/python/relax/test_bind_params.py
new file mode 100644
index 0000000000..a92e4fe8e5
--- /dev/null
+++ b/tests/python/relax/test_bind_params.py
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax, tir
+from tvm.script import relax as R
+
+import numpy as np
+import pytest
+
+# How the parameter to be bound is identified: by its name, or by its relax.Var.
+param_specification = tvm.testing.parameter("by_string", "by_var")
+# Annotation of the parameter's shape: static, symbolic, known rank only, or unknown.
+param_shape = tvm.testing.parameter("static_shape", "dynamic_shape", "ndim",
"arbitrary")
+# dtype annotation of the tensor parameter; None presumably leaves it unspecified.
+tensor_param_dtype = tvm.testing.parameter("float32", None)
+
+
+def test_bind_tensor_param(param_specification, param_shape, tensor_param_dtype):
+    """Binding a tensor parameter replaces it with a constant.
+
+    Exercised for every way of identifying the parameter and every
+    supported level of detail in its shape/dtype annotation.
+    """
+    if param_shape == "static_shape":
+        shape = [16]
+        ndim = -1
+    elif param_shape == "dynamic_shape":
+        shape = [tir.Var("N", "int64")]
+        ndim = -1
+    elif param_shape == "ndim":
+        shape = None
+        ndim = 1
+    elif param_shape == "arbitrary":
+        shape = None
+        ndim = -1
+    else:
+        raise ValueError(f"Unknown param_shape: {param_shape}")
+
+    @R.function
+    def before(A: R.Tensor(shape, ndim=ndim, dtype=tensor_param_dtype)):
+        R.func_attr({"global_symbol": "main"})
+        B: R.Tensor(shape=shape, ndim=ndim, dtype=tensor_param_dtype) = A
+        out = R.add(B, B)
+        return out
+
+    np_data = np.arange(16).astype("float32")
+    inlined_relax_const = relax.const(np_data)
+
+    @R.function
+    def expected() -> R.Tensor([16], "float32"):
+        R.func_attr({"global_symbol": "main"})
+        B = inlined_relax_const
+        out = R.add(B, B)
+        return out
+
+    if param_specification == "by_string":
+        var = "A"
+    elif param_specification == "by_var":
+        var = before.params[0]
+    else:
+        raise ValueError(f"Unknown param_specification: {param_specification}")
+
+    after = before.bind_params({var: np_data})
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_shape_param(param_shape):
+ """Binding an R.Shape parameter by name replaces it with the given
+ ShapeExpr, for static, symbolic, rank-only, and unknown shape annotations."""
+ if param_shape == "static_shape":
+ shape = [16]
+ ndim = -1
+ elif param_shape == "dynamic_shape":
+ shape = [tir.Var("N", "int64")]
+ ndim = -1
+ elif param_shape == "ndim":
+ shape = None
+ ndim = 1
+ elif param_shape == "arbitrary":
+ shape = None
+ ndim = -1
+ else:
+ raise ValueError(f"Unknown param_shape: {param_shape}")
+
+ @R.function
+ def before(A: R.Shape(shape, ndim=ndim)):
+ R.func_attr({"global_symbol": "main"})
+ B: R.Shape(shape, ndim=ndim) = A
+ return B
+
+ @R.function
+ def expected() -> R.Shape([16]):
+ R.func_attr({"global_symbol": "main"})
+ B = R.ShapeExpr([16])
+ return B
+
+ after = before.bind_params({"A": relax.ShapeExpr([16])})
+
+ tvm.ir.assert_structural_equal(expected, after)
+
+
+prim_value_dtype = tvm.testing.parameter("int64", "int32", "float32")
+
+
+@pytest.mark.xfail(reason="Depends on relax.PrimValue holding a tir.PrimExpr, PR#15577")
+def test_bind_prim_value(prim_value_dtype):
+    """Binding an R.Prim parameter replaces it with the given PrimValue.
+
+    Expected to fail until PR#15577 lands (see the xfail reason).
+    """
+
+    @R.function
+    def before(A: R.Prim(value="N", dtype=prim_value_dtype)):
+        R.func_attr({"global_symbol": "main"})
+        B: R.Prim(value="N", dtype=prim_value_dtype) = A
+        return B
+
+    @R.function
+    def expected() -> R.Prim(value=16, dtype=prim_value_dtype):
+        R.func_attr({"global_symbol": "main"})
+        B = R.PrimValue(value=16, dtype=prim_value_dtype)
+        return B
+
+    after = before.bind_params({"A": relax.PrimValue(tir.const(16, prim_value_dtype))})
+
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_on_unknown_var():
+ """Binding a relax.Var that is not a parameter of the function raises TVMError."""
+ @R.function
+ def before(A: R.Tensor([16], dtype="float32")):
+ R.func_attr({"global_symbol": "main"})
+ return A
+
+ unknown_var = relax.Var("unknown_var")
+
+ with pytest.raises(tvm.TVMError):
+ before.bind_params({unknown_var: np.arange(16).astype("float32")})
+
+
+def test_error_on_unknown_var_name():
+ """Binding a name that matches no parameter of the function raises TVMError."""
+ @R.function
+ def before(A: R.Tensor([16], dtype="float32")):
+ R.func_attr({"global_symbol": "main"})
+ return A
+
+ with pytest.raises(tvm.TVMError):
+ before.bind_params({"unknown_var_name":
np.arange(16).astype("float32")})
+
+
+if __name__ == "__main__":
+ tvm.testing.main()
diff --git a/tests/python/relax/test_relay_translator.py
b/tests/python/relax/test_relay_translator.py
index 6790ae851b..c32382dbb0 100644
--- a/tests/python/relax/test_relay_translator.py
+++ b/tests/python/relax/test_relay_translator.py
@@ -126,10 +126,14 @@ def test_verify_e2e_translation_gpu(layout, batch_size,
image_shape):
def verify_extracted_tasks(target_str, layout, batch_size, image_shape,
module_equality):
target = Target(target_str)
relay_mod, params = get_resnet(batch_size, "float32", layout, image_shape)
+ # Parameters can be bound either as part of the `from_relay`
+ # conversion, or as part of the `extract_tasks` method. However,
+ # they shouldn't be used in both locations, because
+ # `relax.BindParams` validates that there exists an unbound
+ # parameter of the specified name.
relax_mod = relay_translator.from_relay(
relay_mod["main"],
target,
- params,
pass_config={
"relay.backend.use_meta_schedule": True,
"relay.FuseOps.max_depth": 1, # Disable relay fusion
diff --git a/tests/python/relax/test_transform_bind_params.py
b/tests/python/relax/test_transform_bind_params.py
index 8e760b6fd7..9e212693f9 100644
--- a/tests/python/relax/test_transform_bind_params.py
+++ b/tests/python/relax/test_transform_bind_params.py
@@ -123,5 +123,57 @@ def test_bind_params_symbolic_vars():
)
+param_specification = tvm.testing.parameter("by_string", "by_var")
+
+
+def test_bind_params_by_var_obj(param_specification):
+    """BindParams accepts either the parameter's name or its relax.Var as key."""
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(A: R.Tensor([16], "float32")):
+            return A
+
+    np_data = np.arange(16).astype("float32")
+    inlined_relax_const = relax.const(np_data)
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main():
+            return inlined_relax_const
+
+    if param_specification == "by_string":
+        var = "A"
+    elif param_specification == "by_var":
+        var = Before["main"].params[0]
+    else:
+        raise ValueError(f"Unknown param_specification: {param_specification}")
+
+    After = relax.transform.BindParams("main", {var: np_data})(Before)
+
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
+def test_bind_params_by_var_name():
+ """Bind a parameter of "main" by its name through the BindParams transform.
+
+ NOTE(review): this duplicates the "by_string" case of
+ test_bind_params_by_var_obj; consider keeping only one of them.
+ """
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(A: R.Tensor([16], "float32")):
+ return A
+
+ np_data = np.arange(16).astype("float32")
+ inlined_relax_const = relax.const(np_data)
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main():
+ return inlined_relax_const
+
+ After = relax.transform.BindParams("main", {"A": np_data})(Before)
+
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/python/relax/test_transform_fold_constant.py
b/tests/python/relax/test_transform_fold_constant.py
index b8ad5c4487..c2a3bd5092 100644
--- a/tests/python/relax/test_transform_fold_constant.py
+++ b/tests/python/relax/test_transform_fold_constant.py
@@ -378,8 +378,7 @@ def
test_fold_multiple_relax_ops_with_data_dependent_reshape():
before = gen_mod(Module, "before", {"c0": c0_np, "c1": c1_np})
assert relax.analysis.well_formed(before)
- c2_np = np.multiply(np.add(c0_np, c0_np), c1_np)
- expected = gen_mod(Module, "expected", {"c2": c2_np})
+ expected = gen_mod(Module, "expected", {})
after = relax.transform.FoldConstant()(before)
tvm.ir.assert_structural_equal(after, expected)