This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new c62dcfa89e [Unity] Propagate extra symbolic vars through
LiftTransformParams (#15699)
c62dcfa89e is described below
commit c62dcfa89e8449cbe7dc53e6792b71e4a7f51818
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 12 11:34:59 2023 -0500
[Unity] Propagate extra symbolic vars through LiftTransformParams (#15699)
* [Unity][Analysis] Implemented DefinableTIRVarsInStructInfo
The existing utility `TIRVarsInStructInfo` returns all TIR variables,
regardless of whether they are suitable for a variable definition, or
are usage sites. This utility walks over the struct info and returns
only the definable symbolic variables, i.e. those that appear directly
as a shape entry (`n` in `[1, n]`, but not in `[2 * n]`).
* [Unity][Analysis] Accept relax::Expr arg in Defined/FreeSymbolicVars
Prior to this commit, this utility could only be used with a
`relax::Function` argument. This allows individual expressions to be
inspected, even if they are not part of a complete function.
* [Unity] Propagate symbolic variables in LiftTransformParams
* Updated LiftTransformParams to use support::OrderedSet
* Fixed import after rebase
---
include/tvm/relax/analysis.h | 22 +++-
python/tvm/relax/analysis/__init__.py | 1 +
python/tvm/relax/analysis/analysis.py | 18 +++
src/relax/analysis/struct_info_analysis.cc | 95 +++++++++-----
src/relax/transform/lift_transform_params.cc | 146 +++++++++++++++++----
src/support/ordered_set.h | 27 +++-
.../relax/test_analysis_struct_info_analysis.py | 31 ++++-
.../relax/test_transform_lift_transform_params.py | 100 ++++++++++++++
8 files changed, 368 insertions(+), 72 deletions(-)
diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 82fb73b1bd..62fb5c686d 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -268,21 +268,35 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs,
const StructInfo& rhs,
*/
TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
+/*!
+ * \brief Get the TIR variables that can be defined by the input struct info.
+ *
+ * Only variables that appear directly as a shape entry (e.g. `n` in
+ * `[1, n]`, but not in `[2 * n]`) are definable; see TIRVarsInStructInfo.
+ *
+ * \param sinfo The struct info object to be analyzed.
+ *
+ * \return The list of definable TIR variables, deduplicated so that each
+ * TIR variable appears at most once, in order of first occurrence.
+ * Variables that are only used inside larger expressions are not included.
+ */
+TVM_DLL Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);
+
/*!
* \brief Get the TIR variables that defined in the input function.
* The returned list is deduplicated - each TIR variable will appear at most
once.
- * \param func The function object to be analyzed.
+ * \param expr The relax expression (e.g. a Function) to be analyzed.
* \return The list of TIR variables that are defined in the input function.
*/
-TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Function& func);
+TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Expr& expr);
/*!
* \brief Get the TIR variables that are used but not defined in the input
function.
* The returned list is deduplicated - each TIR variable will appear at most
once.
- * \param func The function object to be analyzed.
+ * \param expr The relax expression (e.g. a Function) to be analyzed.
* \return The list of TIR variables that are used but not defined in the
input function.
*/
-TVM_DLL Array<tir::Var> FreeSymbolicVars(const Function& func);
+TVM_DLL Array<tir::Var> FreeSymbolicVars(const Expr& expr);
//-----------------------------------
// General IR analysis
//-----------------------------------
diff --git a/python/tvm/relax/analysis/__init__.py
b/python/tvm/relax/analysis/__init__.py
index cc0a36622e..d8454a02cc 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -22,6 +22,7 @@ from .analysis import (
all_vars,
bound_vars,
contains_impure_call,
+ definable_tir_vars_in_struct_info,
defined_symbolic_vars,
derive_call_ret_struct_info,
detect_recursion,
diff --git a/python/tvm/relax/analysis/analysis.py
b/python/tvm/relax/analysis/analysis.py
index 6ed319ed85..38f5ea2fea 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -184,6 +184,24 @@ def tir_vars_in_struct_info(sinfo: StructInfo) ->
List[tir.Var]:
return _ffi_api.TIRVarsInStructInfo(sinfo) # type: ignore
+def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
+ """Get the TIR variables that may be defined from input struct info.
+ The returned list is deduplicated - each TIR variable will appear at most
once.
+
+ Parameters
+ ----------
+ sinfo : StructInfo
+ The struct info object to be analyzed.
+
+ Returns
+ -------
+ ret : List[tir.Var]
+ The list of TIR variables that can be defined from the StructInfo,
+ i.e. those appearing directly as a shape entry (``n`` in ``[n]``, not in ``[2 * n]``).
+ """
+ return _ffi_api.DefinableTIRVarsInStructInfo(sinfo) # type: ignore
+
+
def defined_symbolic_vars(func: Function) -> List[Var]:
"""Get the TIR variables that defined in the input function.
The returned list is deduplicated - each TIR variable will appear at most
once.
diff --git a/src/relax/analysis/struct_info_analysis.cc
b/src/relax/analysis/struct_info_analysis.cc
index 96e51eede8..18d21cb4d4 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -920,58 +920,89 @@ TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA")
// TIRVarsInStructInfo
//--------------------------
-Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
- struct TIRVarsDetector : public StructInfoVisitor {
- void VisitShape(Array<PrimExpr> shape) {
- for (const PrimExpr& value : shape) {
- Array<tir::Var> vars = tir::UndefinedVars(value);
- for (const tir::Var& var : vars) {
- auto insert_res = tir_vars_dedup_set.insert(var.get());
- if (insert_res.second) {
- tir_vars.push_back(var);
- }
+class TIRVarsDetector : public StructInfoVisitor {
+ public:
+ enum class VarType {
+ Definition,
+ Usage,
+ };
+ TIRVarsDetector(VarType collection_type) : collection_type(collection_type)
{}
+
+ Array<tir::Var> GetTIRVars() const { return tir_vars_; }
+
+ private:
+ void VisitShape(Array<PrimExpr> shape) {
+ for (const PrimExpr& value : shape) {
+ if (collection_type == VarType::Definition) {
+ if (auto opt = value.as<tir::Var>()) {
+ RecordTIRVar(opt.value());
}
+ } else if (collection_type == VarType::Usage) {
+ for (const tir::Var& tir_var : tir::UndefinedVars(value)) {
+ RecordTIRVar(tir_var);
+ }
+ } else {
+ LOG(FATAL) << "Invalid value for VarType enum, " <<
static_cast<int>(collection_type);
}
}
+ }
- void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
- if (shape_sinfo->values.defined()) {
- VisitShape(shape_sinfo->values.value());
- }
+ void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
+ if (shape_sinfo->values.defined()) {
+ VisitShape(shape_sinfo->values.value());
}
+ }
- void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
- if (tensor_sinfo->shape.defined()) {
- VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
- }
+ void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
+ if (tensor_sinfo->shape.defined()) {
+ VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
}
+ }
- Array<tir::Var> tir_vars;
- std::unordered_set<const tir::VarNode*> tir_vars_dedup_set;
- };
+ void RecordTIRVar(const tir::Var& tir_var) {
+ auto insert_res = used_tir_vars_dedup_.insert(tir_var.get());
+ if (insert_res.second) {
+ tir_vars_.push_back(tir_var);
+ }
+ }
+
+ Array<tir::Var> tir_vars_;
+ std::unordered_set<const tir::VarNode*> used_tir_vars_dedup_;
+
+ VarType collection_type;
+};
+
+Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
+ TIRVarsDetector detector(TIRVarsDetector::VarType::Usage);
+ detector(sinfo);
+ return detector.GetTIRVars();
+}
- TIRVarsDetector detector;
+Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo) {
+ TIRVarsDetector detector(TIRVarsDetector::VarType::Definition);
detector(sinfo);
- return detector.tir_vars;
+ return detector.GetTIRVars();
}
-TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo")
- .set_body_typed([](const StructInfo& sinfo) { return
TIRVarsInStructInfo(sinfo); });
+TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo);
+
+TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo")
+ .set_body_typed(DefinableTIRVarsInStructInfo);
class SymbolicVarCollector : public relax::ExprVisitor,
public relax::StructInfoVisitor,
public tir::ExprVisitor {
public:
- static Array<tir::Var> Free(const Function& func) {
+ static Array<tir::Var> Free(const Expr& expr) {
SymbolicVarCollector collector;
- collector.VisitExpr(func);
+ collector.VisitExpr(expr);
Array<tir::Var> ret{collector.free_symbolic_var_.begin(),
collector.free_symbolic_var_.end()};
return ret;
}
- static Array<tir::Var> Defined(const Function& func) {
+ static Array<tir::Var> Defined(const Expr& expr) {
SymbolicVarCollector collector;
- collector.VisitExpr(func);
+ collector.VisitExpr(expr);
Array<tir::Var> ret{collector.defined_symbolic_var_.begin(),
collector.defined_symbolic_var_.end()};
return ret;
@@ -1098,10 +1129,10 @@ class SymbolicVarCollector : public relax::ExprVisitor,
std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
free_symbolic_var_;
};
-Array<tir::Var> DefinedSymbolicVars(const Function& func) {
- return SymbolicVarCollector::Defined(func);
+Array<tir::Var> DefinedSymbolicVars(const Expr& expr) {
+ return SymbolicVarCollector::Defined(expr);
}
-Array<tir::Var> FreeSymbolicVars(const Function& func) { return
SymbolicVarCollector::Free(func); }
+Array<tir::Var> FreeSymbolicVars(const Expr& expr) { return
SymbolicVarCollector::Free(expr); }
TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars);
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index 7201786c37..cef19ff068 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -31,6 +31,8 @@
#include <iostream>
#include <vector>
+#include "../../support/ordered_set.h"
+
namespace tvm {
namespace relax {
@@ -49,13 +51,54 @@ class TransformParamsFuncBuilder : public ExprMutator {
TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
/*! \brief Add a input parameter. */
- void AddInput(const Var& var) { inputs_.push_back(var); }
+ void AddInput(const Var& var) {
+ inputs_.push_back(var);
+ lifted_binding_lookup_.insert(var);
+ }
+
+ void UpdateBasedOnRuntimeInput(const Var& var) {
+ for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
+ known_symbolic_var_during_inference_.insert(var);
+ }
+ for (const auto& var : TIRVarsInStructInfo(GetStructInfo(var))) {
+ required_symbolic_var_during_inference_.insert(var);
+ }
+ }
/*! \brief Add a binding to lift. */
- void AddBinding(const VarBinding& binding) { bindings_.push_back(binding); }
+ void AddInternalBinding(const VarBinding& binding) {
+ bindings_.push_back(binding);
+ lifted_binding_lookup_.insert(binding->var);
+ }
+
+ /*! \brief Update based on bindings not being lifted. */
+ void UpdateBasedOnRuntimeBinding(const VarBinding& binding) {
+ for (const auto& producer : FreeVars(binding->value)) {
+ // An external value that uses a lifted binding requires the
+ // lifted binding to be returned as output.
+ if (lifted_binding_lookup_.count(producer)) {
+ outputs_.insert(producer);
+
+ for (const auto& var :
DefinableTIRVarsInStructInfo(GetStructInfo(producer))) {
+ known_symbolic_var_during_inference_.insert(var);
+ }
+ }
+ }
+
+ // All TIR variables used in the binding must be available at runtime.
+ for (const auto& var : FreeSymbolicVars(binding->value)) {
+ required_symbolic_var_during_inference_.insert(var);
+ }
+ }
- /*! \brief Mark a variable as the output of the function. */
- void MarkOutput(const Var& output) { outputs_.insert(output); }
+ bool UsesOnlyLiftableProducers(const Expr& expr) {
+ auto producers = FreeVars(expr);
+ bool uses_only_liftable_producers = [&]() {
+ return std::all_of(producers.begin(), producers.end(),
+ [&](const auto& var) { return
lifted_binding_lookup_.count(var); });
+ }();
+ return uses_only_liftable_producers;
+ }
/*!
* \brief Build the function that transforms the parameters
@@ -63,6 +106,13 @@ class TransformParamsFuncBuilder : public ExprMutator {
* of the element of the output tuple
*/
std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash,
ObjectPtrEqual>> Build() {
+ Array<PrimExpr> extra_symbolic_vars;
+ for (const auto& var : required_symbolic_var_during_inference_) {
+ if (!known_symbolic_var_during_inference_.count(var)) {
+ extra_symbolic_vars.push_back(var);
+ }
+ }
+
Array<StructInfo> input_sinfo;
Array<Expr> output_vars;
std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual>
output_to_index;
@@ -72,6 +122,10 @@ class TransformParamsFuncBuilder : public ExprMutator {
}
Var params("params", TupleStructInfo(input_sinfo));
+ if (extra_symbolic_vars.size()) {
+ output_vars.push_back(builder_->Emit(ShapeExpr(extra_symbolic_vars),
"extra_symbolic_vars"));
+ }
+
// Helper to add a variable to the output tuple
// original_var: the binding variable in the original function
// output_var: the variable, which is a binding in the transform_params
function, that is added
@@ -107,7 +161,7 @@ class TransformParamsFuncBuilder : public ExprMutator {
// Create the function.
Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
BindingBlock block = builder_->EndBlock();
- Expr body = builder_->Normalize(SeqExpr({block}, transformed_params));
+ Expr body = VisitWithNewScope(SeqExpr({block}, transformed_params),
Array<Var>{params});
Function f_transform_params =
Function(/*params=*/{params}, /*body=*/body,
/*ret_struct_info=*/NullOpt);
return {f_transform_params, output_to_index};
@@ -130,6 +184,39 @@ class TransformParamsFuncBuilder : public ExprMutator {
Array<VarBinding> bindings_;
// The variables that are marked as the output of the function.
std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+
+ // The bindings that are lifted
+ std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual>
lifted_binding_lookup_;
+
+ /* Symbolic variables that are known during the transform_params execution.
+ *
+ * NOTE(review): this set is declared but never populated or read in
+ * this change -- AddInput does not insert into it, and no lifting
+ * decision consults it. Either wire it up or remove it. (Presumably
+ * intended: variables usable inside the transformation function.)
+ */
+ support::OrderedSet<tir::Var> known_symbolic_var_during_transform_;
+
+ /* Symbolic variables that are known during the runtime
+ *
+ * Populated by UpdateBasedOnRuntimeInput (variables definable from
+ * the runtime inputs) and by UpdateBasedOnRuntimeBinding (variables
+ * definable from lifted values). A variable that is present in
+ * required_symbolic_var_during_inference_, but not present in this
+ * set, causes the Build() function to output an additional
+ * R.ShapeExpr in order to propagate the symbolic variables.
+ */
+ support::OrderedSet<tir::Var> known_symbolic_var_during_inference_;
+
+ /* Symbolic variables that must be known at runtime
+ *
+ * Populated with the variables used by runtime inputs and by
+ * bindings that are not lifted. A variable present here, but not in
+ * known_symbolic_var_during_inference_, is emitted by Build() as
+ * part of an extra R.ShapeExpr output of the transform_params
+ * function.
+ */
+ support::OrderedSet<tir::Var> required_symbolic_var_during_inference_;
};
/*!
@@ -145,14 +232,17 @@ class TransformParamsFuncBuilder : public ExprMutator {
class LiftTransformParamsPlanner : public ExprVisitor {
public:
LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
- for (int i = num_inputs; i < static_cast<int>(function->params.size());
++i) {
- builder_.AddInput(function->params[i]);
- lifted_bindings_.emplace(function->params[i]);
+ for (int i = 0; i < static_cast<int>(function->params.size()); ++i) {
+ if (i < num_inputs) {
+ builder_.UpdateBasedOnRuntimeInput(function->params[i]);
+ } else {
+ builder_.AddInput(function->params[i]);
+ }
}
VisitExpr(function->body);
const auto& [f_transform_params, output_to_index] = builder_.Build();
- return {f_transform_params, output_to_index, std::move(lifted_bindings_)};
+ return {f_transform_params, output_to_index,
std::move(builder_.lifted_binding_lookup_)};
}
private:
@@ -163,7 +253,6 @@ class LiftTransformParamsPlanner : public ExprVisitor {
}
void VisitBinding_(const VarBindingNode* binding) final {
- std::vector<const VarNode*> producers;
bool can_lift = true;
// Cond 1. Do not lift bindings outside dataflow blocks.
@@ -180,34 +269,29 @@ class LiftTransformParamsPlanner : public ExprVisitor {
}
// Cond 3. Do not lift when involving Vars that are not liftable.
- PostOrderVisit(binding->value, [&](const ObjectRef& obj) {
- if (const VarNode* var = obj.as<VarNode>()) {
- producers.push_back(var);
- if (!lifted_bindings_.count(GetRef<Var>(var))) {
- can_lift = false;
- }
- }
- });
+ auto producers = FreeVars(binding->value);
+ bool uses_only_liftable_producers =
builder_.UsesOnlyLiftableProducers(binding->value);
+ if (!uses_only_liftable_producers) {
+ can_lift = false;
+ }
// Cond 4. Do not lift when its struct info contains symbolic variables.
if (!TIRVarsInStructInfo(GetStructInfo(binding->var)).empty()) {
can_lift = false;
}
+ // Cond 5. Do not lift declarations of external functions
+ if (binding->value.as<relax::ExternFuncNode>()) {
+ can_lift = false;
+ }
+
if (can_lift) {
- lifted_bindings_.insert(binding->var);
- builder_.AddBinding(GetRef<VarBinding>(binding));
+ builder_.AddInternalBinding(GetRef<VarBinding>(binding));
} else {
- for (const VarNode* producer : producers) {
- if (lifted_bindings_.count(GetRef<Var>(producer))) {
- builder_.MarkOutput(GetRef<Var>(producer));
- }
- }
+ builder_.UpdateBasedOnRuntimeBinding(GetRef<VarBinding>(binding));
}
}
- // The bindings that are lifted
- std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> lifted_bindings_;
// The builder of the function that transforms the parameters
TransformParamsFuncBuilder builder_;
// Whether we are in a dataflow block
@@ -256,6 +340,7 @@ class TransformParamsLifter : ExprMutator {
// Step 3.1: Update the function signature
Array<StructInfo> param_fields =
Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
+
Array<Var> new_params(func->params.begin(), func->params.begin() +
num_input);
for (size_t i = 0; i < param_fields.size(); i++) {
std::stringstream name;
@@ -320,12 +405,15 @@ Pass LiftTransformParams() {
}
}
}
- for (const auto& [gvar, transform_func] :
mutator.GetTransformParamFunctions()) {
+ for (auto [gvar, transform_func] : mutator.GetTransformParamFunctions()) {
String name = gvar->name_hint + "_transform_params";
GlobalVar new_gvar(name);
new_gvar->struct_info_ = transform_func->struct_info_;
- updates->Add(new_gvar, WithAttr(transform_func,
tvm::attr::kGlobalSymbol, name));
+ transform_func = CopyWithNewVars(transform_func);
+ transform_func = WithAttr(transform_func, tvm::attr::kGlobalSymbol,
name);
+
+ updates->Add(new_gvar, transform_func);
}
if (updates->functions.size()) {
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
index 8ba7708961..96ac45b769 100644
--- a/src/support/ordered_set.h
+++ b/src/support/ordered_set.h
@@ -24,15 +24,36 @@
#ifndef TVM_SUPPORT_ORDERED_SET_H_
#define TVM_SUPPORT_ORDERED_SET_H_
+#include <tvm/runtime/object.h>
+
#include <list>
#include <unordered_map>
namespace tvm {
namespace support {
+namespace detail {
+/* \brief Utility to allow use for standard and ObjectRef types
+ *
+ * \tparam T The type held by the OrderedSet
+ */
+template <typename T, typename = void>
+struct OrderedSetLookupType {
+ using MapType = std::unordered_map<T, typename std::list<T>::iterator>;
+};
+
+template <typename T>
+struct OrderedSetLookupType<T,
std::enable_if_t<std::is_base_of_v<runtime::ObjectRef, T>>> {
+ using MapType = std::unordered_map<T, typename std::list<T>::iterator,
runtime::ObjectPtrHash,
+ runtime::ObjectPtrEqual>;
+};
+} // namespace detail
+
template <typename T>
class OrderedSet {
public:
+ OrderedSet() = default;
+
void push_back(const T& t) {
if (!elem_to_iter_.count(t)) {
elements_.push_back(t);
@@ -40,6 +61,8 @@ class OrderedSet {
}
}
+ void insert(const T& t) { push_back(t); }
+
void erase(const T& t) {
if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) {
elements_.erase(it->second);
@@ -52,6 +75,8 @@ class OrderedSet {
elem_to_iter_.clear();
}
+ size_t count(const T& t) const { return elem_to_iter_.count(t); }
+
auto begin() const { return elements_.begin(); }
auto end() const { return elements_.end(); }
auto size() const { return elements_.size(); }
@@ -59,7 +84,7 @@ class OrderedSet {
private:
std::list<T> elements_;
- std::unordered_map<T, typename std::list<T>::iterator> elem_to_iter_;
+ typename detail::OrderedSetLookupType<T>::MapType elem_to_iter_;
};
} // namespace support
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py
b/tests/python/relax/test_analysis_struct_info_analysis.py
index 1b1ea2e53e..b28df7b224 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -619,21 +619,40 @@ def test_struct_info_lca():
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())
-def test_tir_vars_in_struct_info():
+def _generate_tir_var_test_cases():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
shape0 = rx.ShapeStructInfo([1, n, 3])
shape1 = rx.ShapeStructInfo([1, 2 * n, n, m])
+ shape2 = rx.ShapeStructInfo([1, 2 * n, m])
tensor0 = rx.TensorStructInfo([1, n, 3], "int32")
tensor1 = rx.TensorStructInfo([1, 2 * n, n, m], "int32")
+ tensor2 = rx.TensorStructInfo([1, 2 * n, m], "int32")
func = rx.FuncStructInfo(
[rx.TensorStructInfo([1, 2 * n, n, m], "int32")],
rx.TensorStructInfo([1, n, 3], "int32")
)
-
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(shape0), [n])
-
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(shape1), [n,
m])
-
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(tensor0),
[n])
-
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(tensor1),
[n, m])
- tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func),
[n, m])
+ yield shape0, [n], [n]
+ yield shape1, [n, m], [n, m]
+ yield shape2, [m], [n, m]
+ yield tensor0, [n], [n]
+ yield tensor1, [n, m], [n, m]
+ yield tensor2, [m], [n, m]
+ yield func, [n, m], [n, m]
+
+
+tir_var_test_case = tvm.testing.parameter(*_generate_tir_var_test_cases())
+
+
+def test_tir_vars_in_struct_info(tir_var_test_case):
+ sinfo, _vars_definable, vars_used = tir_var_test_case
+ tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(sinfo),
vars_used)
+
+
+def test_definable_tir_vars_in_struct_info(tir_var_test_case):
+ sinfo, vars_definable, _vars_used = tir_var_test_case
+ tvm.ir.assert_structural_equal(
+ rx.analysis.definable_tir_vars_in_struct_info(sinfo), vars_definable
+ )
def test_collect_symbolic_var_from_tensor_shape():
diff --git a/tests/python/relax/test_transform_lift_transform_params.py
b/tests/python/relax/test_transform_lift_transform_params.py
index c23efe655b..7389060bde 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -542,5 +542,105 @@ def test_symbolic_var_2():
tvm.ir.assert_structural_equal(after, Expected)
+def test_symbolic_var_from_shape():
+ @I.ir_module
+ class Before:
+ @R.function
+ def main(
+ A: R.Tensor([16, 16], "int32"),
+ B: R.Tensor([16, 16], "int32"),
+ shape: R.Shape(["slice_index"]),
+ ) -> R.Tensor([16], "int32"):
+ R.func_attr({"num_input": 1})
+ slice_index = T.int64()
+ cls = Before
+ with R.dataflow():
+ B_slice = R.call_tir(
+ cls.slice,
+ [B],
+ tir_vars=R.ShapeExpr([slice_index]),
+ out_sinfo=R.Tensor([16], dtype="int32"),
+ )
+ A_slice = R.call_tir(
+ cls.slice,
+ [A],
+ tir_vars=R.ShapeExpr([slice_index]),
+ out_sinfo=R.Tensor([16], dtype="int32"),
+ )
+ A_scale = R.multiply(A_slice, B_slice)
+ R.output(A_scale)
+ return A_scale
+
+ @T.prim_func(private=True)
+ def slice(
+ Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
+ Output_Slice: T.Buffer(shape=[16], dtype="int32"),
+ slice_index: T.int64,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for j in range(16):
+ with T.block("T_full"):
+ vj = T.axis.remap("S", [j])
+ Output_Slice[vj] = Input_2d[slice_index, vj]
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main(
+ A: R.Tensor([16, 16], "int32"),
+ shape: R.Shape(["slice_index"]),
+ B_slice: R.Tensor([16], "int32"),
+ ) -> R.Tensor([16], "int32"):
+ R.func_attr({"num_input": 1})
+ slice_index = T.int64()
+ cls = Expected
+ with R.dataflow():
+ A_slice = R.call_tir(
+ cls.slice,
+ [A],
+ tir_vars=R.ShapeExpr([slice_index]),
+ out_sinfo=R.Tensor([16], dtype="int32"),
+ )
+ B_slice = B_slice
+ A_scale = R.multiply(A_slice, B_slice)
+ R.output(A_scale)
+ return A_scale
+
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(R.Tensor([16, 16], "int32"),
R.Shape(["slice_index"]))
+ ):
+ slice_index = T.int64()
+ cls = Expected
+ with R.dataflow():
+ extra_symbolic_vars = R.ShapeExpr([slice_index])
+ B = params[0]
+ B_slice = R.call_tir(
+ cls.slice,
+ [B],
+ tir_vars=R.ShapeExpr([slice_index]),
+ out_sinfo=R.Tensor([16], dtype="int32"),
+ )
+ output = (extra_symbolic_vars, B_slice)
+ R.output(output)
+ return output
+
+ @T.prim_func(private=True)
+ def slice(
+ Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
+ Output_Slice: T.Buffer(shape=[16], dtype="int32"),
+ slice_index: T.int64,
+ ):
+ T.func_attr({"tir.noalias": T.bool(True)})
+ for j in range(16):
+ with T.block("T_full"):
+ vj = T.axis.remap("S", [j])
+ Output_Slice[vj] = Input_2d[slice_index, vj]
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(Expected, after)
+
+
if __name__ == "__main__":
tvm.testing.main()