This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new c62dcfa89e [Unity] Propagate extra symbolic vars through 
LiftTransformParams (#15699)
c62dcfa89e is described below

commit c62dcfa89e8449cbe7dc53e6792b71e4a7f51818
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Oct 12 11:34:59 2023 -0500

    [Unity] Propagate extra symbolic vars through LiftTransformParams (#15699)
    
    * [Unity][Analysis] Implemented DefinableTIRVarsInStructInfo
    
    The existing utility `TIRVarsInStructInfo` returns all TIR variables,
    regardless of whether they are suitable for a variable definition, or
    are usage sites.  This utility walks over the struct info once,
    returning both the definable symbolic variables and the used symbolic
    variables.
    
    * [Unity][Analysis] Accept relax::Expr arg in Defined/FreeSymbolicVars
    
    Prior to this commit, this utility could only be used with a
    `relax::Function` argument.  This allows individual expressions to be
    inspected, even if they are not part of a complete function.
    
    * [Unity] Propagate symbolic variables in LiftTransformParams
    
    * Updated LiftTransformParams to use support::OrderedSet
    
    * Fixed import after rebase
---
 include/tvm/relax/analysis.h                       |  22 +++-
 python/tvm/relax/analysis/__init__.py              |   1 +
 python/tvm/relax/analysis/analysis.py              |  18 +++
 src/relax/analysis/struct_info_analysis.cc         |  95 +++++++++-----
 src/relax/transform/lift_transform_params.cc       | 146 +++++++++++++++++----
 src/support/ordered_set.h                          |  27 +++-
 .../relax/test_analysis_struct_info_analysis.py    |  31 ++++-
 .../relax/test_transform_lift_transform_params.py  | 100 ++++++++++++++
 8 files changed, 368 insertions(+), 72 deletions(-)

diff --git a/include/tvm/relax/analysis.h b/include/tvm/relax/analysis.h
index 82fb73b1bd..62fb5c686d 100644
--- a/include/tvm/relax/analysis.h
+++ b/include/tvm/relax/analysis.h
@@ -268,21 +268,35 @@ TVM_DLL StructInfo StructInfoLCA(const StructInfo& lhs, 
const StructInfo& rhs,
  */
 TVM_DLL Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo);
 
+/*!
+ * \brief Get the TIR variables that may be defined from the input struct info.
+ *
+ * Returns all symbolic variables whose values are definable based on
+ * the StructInfo.
+ *
+ * \param sinfo The struct info object to be analyzed.
+ *
+ * \return The list of definable TIR variables.  The list is
+ *   deduplicated - each TIR variable will appear at most once, and in
+ *   order of occurrence.
+ */
+TVM_DLL Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo);
+
 /*!
  * \brief Get the TIR variables that defined in the input function.
  * The returned list is deduplicated - each TIR variable will appear at most 
once.
- * \param func The function object to be analyzed.
+ * \param expr The relax expression (e.g. a Function) to be analyzed.
  * \return The list of TIR variables that are defined in the input function.
  */
-TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Function& func);
+TVM_DLL Array<tir::Var> DefinedSymbolicVars(const Expr& expr);
 
 /*!
  * \brief Get the TIR variables that are used but not defined in the input 
function.
  * The returned list is deduplicated - each TIR variable will appear at most 
once.
- * \param func The function object to be analyzed.
+ * \param expr The relax expression (e.g. a Function) to be analyzed.
  * \return The list of TIR variables that are used but not defined in the 
input function.
  */
-TVM_DLL Array<tir::Var> FreeSymbolicVars(const Function& func);
+TVM_DLL Array<tir::Var> FreeSymbolicVars(const Expr& expr);
 //-----------------------------------
 // General IR analysis
 //-----------------------------------
diff --git a/python/tvm/relax/analysis/__init__.py 
b/python/tvm/relax/analysis/__init__.py
index cc0a36622e..d8454a02cc 100644
--- a/python/tvm/relax/analysis/__init__.py
+++ b/python/tvm/relax/analysis/__init__.py
@@ -22,6 +22,7 @@ from .analysis import (
     all_vars,
     bound_vars,
     contains_impure_call,
+    definable_tir_vars_in_struct_info,
     defined_symbolic_vars,
     derive_call_ret_struct_info,
     detect_recursion,
diff --git a/python/tvm/relax/analysis/analysis.py 
b/python/tvm/relax/analysis/analysis.py
index 6ed319ed85..38f5ea2fea 100644
--- a/python/tvm/relax/analysis/analysis.py
+++ b/python/tvm/relax/analysis/analysis.py
@@ -184,6 +184,24 @@ def tir_vars_in_struct_info(sinfo: StructInfo) -> 
List[tir.Var]:
     return _ffi_api.TIRVarsInStructInfo(sinfo)  # type: ignore
 
 
+def definable_tir_vars_in_struct_info(sinfo: StructInfo) -> List[tir.Var]:
+    """Get the TIR variables that may be defined from the input struct info.
+    The returned list is deduplicated - each TIR variable will appear at most 
once.
+
+    Parameters
+    ----------
+    sinfo : StructInfo
+        The struct info object to be analyzed.
+
+    Returns
+    -------
+    ret : List[tir.Var]
+
+        The list of TIR variables that can be defined from the StructInfo
+    """
+    return _ffi_api.DefinableTIRVarsInStructInfo(sinfo)  # type: ignore
+
+
 def defined_symbolic_vars(func: Function) -> List[Var]:
     """Get the TIR variables that defined in the input function.
     The returned list is deduplicated - each TIR variable will appear at most 
once.
diff --git a/src/relax/analysis/struct_info_analysis.cc 
b/src/relax/analysis/struct_info_analysis.cc
index 96e51eede8..18d21cb4d4 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -920,58 +920,89 @@ TVM_REGISTER_GLOBAL("relax.analysis.StructInfoLCA")
 // TIRVarsInStructInfo
 //--------------------------
 
-Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
-  struct TIRVarsDetector : public StructInfoVisitor {
-    void VisitShape(Array<PrimExpr> shape) {
-      for (const PrimExpr& value : shape) {
-        Array<tir::Var> vars = tir::UndefinedVars(value);
-        for (const tir::Var& var : vars) {
-          auto insert_res = tir_vars_dedup_set.insert(var.get());
-          if (insert_res.second) {
-            tir_vars.push_back(var);
-          }
+class TIRVarsDetector : public StructInfoVisitor {
+ public:
+  enum class VarType {
+    Definition,
+    Usage,
+  };
+  TIRVarsDetector(VarType collection_type) : collection_type(collection_type) 
{}
+
+  Array<tir::Var> GetTIRVars() const { return tir_vars_; }
+
+ private:
+  void VisitShape(Array<PrimExpr> shape) {
+    for (const PrimExpr& value : shape) {
+      if (collection_type == VarType::Definition) {
+        if (auto opt = value.as<tir::Var>()) {
+          RecordTIRVar(opt.value());
         }
+      } else if (collection_type == VarType::Usage) {
+        for (const tir::Var& tir_var : tir::UndefinedVars(value)) {
+          RecordTIRVar(tir_var);
+        }
+      } else {
+        LOG(FATAL) << "Invalid value for VarType enum, " << 
static_cast<int>(collection_type);
       }
     }
+  }
 
-    void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
-      if (shape_sinfo->values.defined()) {
-        VisitShape(shape_sinfo->values.value());
-      }
+  void VisitStructInfo_(const ShapeStructInfoNode* shape_sinfo) final {
+    if (shape_sinfo->values.defined()) {
+      VisitShape(shape_sinfo->values.value());
     }
+  }
 
-    void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
-      if (tensor_sinfo->shape.defined()) {
-        VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
-      }
+  void VisitStructInfo_(const TensorStructInfoNode* tensor_sinfo) final {
+    if (tensor_sinfo->shape.defined()) {
+      VisitStructInfo(GetStructInfo(tensor_sinfo->shape.value()));
     }
+  }
 
-    Array<tir::Var> tir_vars;
-    std::unordered_set<const tir::VarNode*> tir_vars_dedup_set;
-  };
+  void RecordTIRVar(const tir::Var& tir_var) {
+    auto insert_res = used_tir_vars_dedup_.insert(tir_var.get());
+    if (insert_res.second) {
+      tir_vars_.push_back(tir_var);
+    }
+  }
+
+  Array<tir::Var> tir_vars_;
+  std::unordered_set<const tir::VarNode*> used_tir_vars_dedup_;
+
+  VarType collection_type;
+};
+
+Array<tir::Var> TIRVarsInStructInfo(const StructInfo& sinfo) {
+  TIRVarsDetector detector(TIRVarsDetector::VarType::Usage);
+  detector(sinfo);
+  return detector.GetTIRVars();
+}
 
-  TIRVarsDetector detector;
+Array<tir::Var> DefinableTIRVarsInStructInfo(const StructInfo& sinfo) {
+  TIRVarsDetector detector(TIRVarsDetector::VarType::Definition);
   detector(sinfo);
-  return detector.tir_vars;
+  return detector.GetTIRVars();
 }
 
-TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo")
-    .set_body_typed([](const StructInfo& sinfo) { return 
TIRVarsInStructInfo(sinfo); });
+TVM_REGISTER_GLOBAL("relax.analysis.TIRVarsInStructInfo").set_body_typed(TIRVarsInStructInfo);
+
+TVM_REGISTER_GLOBAL("relax.analysis.DefinableTIRVarsInStructInfo")
+    .set_body_typed(DefinableTIRVarsInStructInfo);
 
 class SymbolicVarCollector : public relax::ExprVisitor,
                              public relax::StructInfoVisitor,
                              public tir::ExprVisitor {
  public:
-  static Array<tir::Var> Free(const Function& func) {
+  static Array<tir::Var> Free(const Expr& expr) {
     SymbolicVarCollector collector;
-    collector.VisitExpr(func);
+    collector.VisitExpr(expr);
     Array<tir::Var> ret{collector.free_symbolic_var_.begin(), 
collector.free_symbolic_var_.end()};
     return ret;
   }
 
-  static Array<tir::Var> Defined(const Function& func) {
+  static Array<tir::Var> Defined(const Expr& expr) {
     SymbolicVarCollector collector;
-    collector.VisitExpr(func);
+    collector.VisitExpr(expr);
     Array<tir::Var> ret{collector.defined_symbolic_var_.begin(),
                         collector.defined_symbolic_var_.end()};
     return ret;
@@ -1098,10 +1129,10 @@ class SymbolicVarCollector : public relax::ExprVisitor,
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
free_symbolic_var_;
 };
 
-Array<tir::Var> DefinedSymbolicVars(const Function& func) {
-  return SymbolicVarCollector::Defined(func);
+Array<tir::Var> DefinedSymbolicVars(const Expr& expr) {
+  return SymbolicVarCollector::Defined(expr);
 }
-Array<tir::Var> FreeSymbolicVars(const Function& func) { return 
SymbolicVarCollector::Free(func); }
+Array<tir::Var> FreeSymbolicVars(const Expr& expr) { return 
SymbolicVarCollector::Free(expr); }
 
 
TVM_REGISTER_GLOBAL("relax.analysis.DefinedSymbolicVars").set_body_typed(DefinedSymbolicVars);
 
diff --git a/src/relax/transform/lift_transform_params.cc 
b/src/relax/transform/lift_transform_params.cc
index 7201786c37..cef19ff068 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -31,6 +31,8 @@
 #include <iostream>
 #include <vector>
 
+#include "../../support/ordered_set.h"
+
 namespace tvm {
 namespace relax {
 
@@ -49,13 +51,54 @@ class TransformParamsFuncBuilder : public ExprMutator {
   TransformParamsFuncBuilder() { builder_->BeginDataflowBlock(); }
 
   /*! \brief Add a input parameter. */
-  void AddInput(const Var& var) { inputs_.push_back(var); }
+  void AddInput(const Var& var) {
+    inputs_.push_back(var);
+    lifted_binding_lookup_.insert(var);
+  }
+
+  void UpdateBasedOnRuntimeInput(const Var& var) {
+    for (const auto& var : DefinableTIRVarsInStructInfo(GetStructInfo(var))) {
+      known_symbolic_var_during_inference_.insert(var);
+    }
+    for (const auto& var : TIRVarsInStructInfo(GetStructInfo(var))) {
+      required_symbolic_var_during_inference_.insert(var);
+    }
+  }
 
   /*! \brief Add a binding to lift. */
-  void AddBinding(const VarBinding& binding) { bindings_.push_back(binding); }
+  void AddInternalBinding(const VarBinding& binding) {
+    bindings_.push_back(binding);
+    lifted_binding_lookup_.insert(binding->var);
+  }
+
+  /*! \brief Update based on bindings not being lifted. */
+  void UpdateBasedOnRuntimeBinding(const VarBinding& binding) {
+    for (const auto& producer : FreeVars(binding->value)) {
+      // An external value that uses a lifted binding requires the
+      // lifted binding to be returned as output.
+      if (lifted_binding_lookup_.count(producer)) {
+        outputs_.insert(producer);
+
+        for (const auto& var : 
DefinableTIRVarsInStructInfo(GetStructInfo(producer))) {
+          known_symbolic_var_during_inference_.insert(var);
+        }
+      }
+    }
+
+    // All TIR variables used in the binding must be available at runtime.
+    for (const auto& var : FreeSymbolicVars(binding->value)) {
+      required_symbolic_var_during_inference_.insert(var);
+    }
+  }
 
-  /*! \brief Mark a variable as the output of the function. */
-  void MarkOutput(const Var& output) { outputs_.insert(output); }
+  bool UsesOnlyLiftableProducers(const Expr& expr) {
+    auto producers = FreeVars(expr);
+    bool uses_only_liftable_producers = [&]() {
+      return std::all_of(producers.begin(), producers.end(),
+                         [&](const auto& var) { return 
lifted_binding_lookup_.count(var); });
+    }();
+    return uses_only_liftable_producers;
+  }
 
   /*!
    * \brief Build the function that transforms the parameters
@@ -63,6 +106,13 @@ class TransformParamsFuncBuilder : public ExprMutator {
    * of the element of the output tuple
    */
   std::pair<Function, std::unordered_map<Var, int, ObjectPtrHash, 
ObjectPtrEqual>> Build() {
+    Array<PrimExpr> extra_symbolic_vars;
+    for (const auto& var : required_symbolic_var_during_inference_) {
+      if (!known_symbolic_var_during_inference_.count(var)) {
+        extra_symbolic_vars.push_back(var);
+      }
+    }
+
     Array<StructInfo> input_sinfo;
     Array<Expr> output_vars;
     std::unordered_map<Var, int, ObjectPtrHash, ObjectPtrEqual> 
output_to_index;
@@ -72,6 +122,10 @@ class TransformParamsFuncBuilder : public ExprMutator {
     }
     Var params("params", TupleStructInfo(input_sinfo));
 
+    if (extra_symbolic_vars.size()) {
+      output_vars.push_back(builder_->Emit(ShapeExpr(extra_symbolic_vars), 
"extra_symbolic_vars"));
+    }
+
     // Helper to add a variable to the output tuple
     // original_var: the binding variable in the original function
     // output_var: the variable, which is a binding in the transform_params 
function, that is added
@@ -107,7 +161,7 @@ class TransformParamsFuncBuilder : public ExprMutator {
     // Create the function.
     Expr transformed_params = builder_->EmitOutput(Tuple(output_vars));
     BindingBlock block = builder_->EndBlock();
-    Expr body = builder_->Normalize(SeqExpr({block}, transformed_params));
+    Expr body = VisitWithNewScope(SeqExpr({block}, transformed_params), 
Array<Var>{params});
     Function f_transform_params =
         Function(/*params=*/{params}, /*body=*/body, 
/*ret_struct_info=*/NullOpt);
     return {f_transform_params, output_to_index};
@@ -130,6 +184,39 @@ class TransformParamsFuncBuilder : public ExprMutator {
   Array<VarBinding> bindings_;
   // The variables that are marked as the output of the function.
   std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> outputs_;
+
+  // The bindings that are lifted
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> 
lifted_binding_lookup_;
+
+  /* Symbolic variables that are known during the transform_params execution.
+   *
+   * This set is populated based on the variables declared with
+   * AddInput, and contains variables that may appear inside the
+   * transformation function.  A binding that depends on a symbolic
+   * variable not contained in this set may not be lifted.
+   */
+  support::OrderedSet<tir::Var> known_symbolic_var_during_transform_;
+
+ /* Symbolic variables that are known at runtime
+   *
+   * This set is populated based on the variables declared with
+   * UpdateBasedOnRuntimeInput, and contains variables that are
+ * defined at runtime.  A variable that is present in
+   * required_symbolic_var_during_inference_, but not present in this
+   * set, causes the Build() function to output an additional
+   * R.ShapeExpr in order to propagate the symbolic variables.
+   */
+  support::OrderedSet<tir::Var> known_symbolic_var_during_inference_;
+
+  /* Symbolic variables that must be known at runtime
+   *
+   * This set is populated based on the variables used in external
+   * bindings.  A variable that is present here, but not present in
+   * known_symbolic_var_during_inference_, must be provided as an
+   * additional R.ShapeExpr parameter from the transform_params
+   * function.
+   */
+  support::OrderedSet<tir::Var> required_symbolic_var_during_inference_;
 };
 
 /*!
@@ -145,14 +232,17 @@ class TransformParamsFuncBuilder : public ExprMutator {
 class LiftTransformParamsPlanner : public ExprVisitor {
  public:
   LiftTransformParamsInfoPlan Plan(const Function& function, int num_inputs) {
-    for (int i = num_inputs; i < static_cast<int>(function->params.size()); 
++i) {
-      builder_.AddInput(function->params[i]);
-      lifted_bindings_.emplace(function->params[i]);
+    for (int i = 0; i < static_cast<int>(function->params.size()); ++i) {
+      if (i < num_inputs) {
+        builder_.UpdateBasedOnRuntimeInput(function->params[i]);
+      } else {
+        builder_.AddInput(function->params[i]);
+      }
     }
     VisitExpr(function->body);
 
     const auto& [f_transform_params, output_to_index] = builder_.Build();
-    return {f_transform_params, output_to_index, std::move(lifted_bindings_)};
+    return {f_transform_params, output_to_index, 
std::move(builder_.lifted_binding_lookup_)};
   }
 
  private:
@@ -163,7 +253,6 @@ class LiftTransformParamsPlanner : public ExprVisitor {
   }
 
   void VisitBinding_(const VarBindingNode* binding) final {
-    std::vector<const VarNode*> producers;
     bool can_lift = true;
 
     // Cond 1. Do not lift bindings outside dataflow blocks.
@@ -180,34 +269,29 @@ class LiftTransformParamsPlanner : public ExprVisitor {
     }
 
     // Cond 3. Do not lift when involving Vars that are not liftable.
-    PostOrderVisit(binding->value, [&](const ObjectRef& obj) {
-      if (const VarNode* var = obj.as<VarNode>()) {
-        producers.push_back(var);
-        if (!lifted_bindings_.count(GetRef<Var>(var))) {
-          can_lift = false;
-        }
-      }
-    });
+    auto producers = FreeVars(binding->value);
+    bool uses_only_liftable_producers = 
builder_.UsesOnlyLiftableProducers(binding->value);
+    if (!uses_only_liftable_producers) {
+      can_lift = false;
+    }
 
     // Cond 4. Do not lift when its struct info contains symbolic variables.
     if (!TIRVarsInStructInfo(GetStructInfo(binding->var)).empty()) {
       can_lift = false;
     }
 
+    // Cond 5. Do not lift declarations of external functions
+    if (binding->value.as<relax::ExternFuncNode>()) {
+      can_lift = false;
+    }
+
     if (can_lift) {
-      lifted_bindings_.insert(binding->var);
-      builder_.AddBinding(GetRef<VarBinding>(binding));
+      builder_.AddInternalBinding(GetRef<VarBinding>(binding));
     } else {
-      for (const VarNode* producer : producers) {
-        if (lifted_bindings_.count(GetRef<Var>(producer))) {
-          builder_.MarkOutput(GetRef<Var>(producer));
-        }
-      }
+      builder_.UpdateBasedOnRuntimeBinding(GetRef<VarBinding>(binding));
     }
   }
 
-  // The bindings that are lifted
-  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> lifted_bindings_;
   // The builder of the function that transforms the parameters
   TransformParamsFuncBuilder builder_;
   // Whether we are in a dataflow block
@@ -256,6 +340,7 @@ class TransformParamsLifter : ExprMutator {
     // Step 3.1: Update the function signature
     Array<StructInfo> param_fields =
         
Downcast<TupleStructInfo>(lift_plan_.f_transform_params->ret_struct_info)->fields;
+
     Array<Var> new_params(func->params.begin(), func->params.begin() + 
num_input);
     for (size_t i = 0; i < param_fields.size(); i++) {
       std::stringstream name;
@@ -320,12 +405,15 @@ Pass LiftTransformParams() {
         }
       }
     }
-    for (const auto& [gvar, transform_func] : 
mutator.GetTransformParamFunctions()) {
+    for (auto [gvar, transform_func] : mutator.GetTransformParamFunctions()) {
       String name = gvar->name_hint + "_transform_params";
       GlobalVar new_gvar(name);
       new_gvar->struct_info_ = transform_func->struct_info_;
 
-      updates->Add(new_gvar, WithAttr(transform_func, 
tvm::attr::kGlobalSymbol, name));
+      transform_func = CopyWithNewVars(transform_func);
+      transform_func = WithAttr(transform_func, tvm::attr::kGlobalSymbol, 
name);
+
+      updates->Add(new_gvar, transform_func);
     }
 
     if (updates->functions.size()) {
diff --git a/src/support/ordered_set.h b/src/support/ordered_set.h
index 8ba7708961..96ac45b769 100644
--- a/src/support/ordered_set.h
+++ b/src/support/ordered_set.h
@@ -24,15 +24,36 @@
 #ifndef TVM_SUPPORT_ORDERED_SET_H_
 #define TVM_SUPPORT_ORDERED_SET_H_
 
+#include <tvm/runtime/object.h>
+
 #include <list>
 #include <unordered_map>
 
 namespace tvm {
 namespace support {
 
+namespace detail {
+/* \brief Utility to allow use with both standard types and ObjectRef types
+ *
+ * \tparam T The type held by the OrderedSet
+ */
+template <typename T, typename = void>
+struct OrderedSetLookupType {
+  using MapType = std::unordered_map<T, typename std::list<T>::iterator>;
+};
+
+template <typename T>
+struct OrderedSetLookupType<T, 
std::enable_if_t<std::is_base_of_v<runtime::ObjectRef, T>>> {
+  using MapType = std::unordered_map<T, typename std::list<T>::iterator, 
runtime::ObjectPtrHash,
+                                     runtime::ObjectPtrEqual>;
+};
+}  // namespace detail
+
 template <typename T>
 class OrderedSet {
  public:
+  OrderedSet() = default;
+
   void push_back(const T& t) {
     if (!elem_to_iter_.count(t)) {
       elements_.push_back(t);
@@ -40,6 +61,8 @@ class OrderedSet {
     }
   }
 
+  void insert(const T& t) { push_back(t); }
+
   void erase(const T& t) {
     if (auto it = elem_to_iter_.find(t); it != elem_to_iter_.end()) {
       elements_.erase(it->second);
@@ -52,6 +75,8 @@ class OrderedSet {
     elem_to_iter_.clear();
   }
 
+  size_t count(const T& t) const { return elem_to_iter_.count(t); }
+
   auto begin() const { return elements_.begin(); }
   auto end() const { return elements_.end(); }
   auto size() const { return elements_.size(); }
@@ -59,7 +84,7 @@ class OrderedSet {
 
  private:
   std::list<T> elements_;
-  std::unordered_map<T, typename std::list<T>::iterator> elem_to_iter_;
+  typename detail::OrderedSetLookupType<T>::MapType elem_to_iter_;
 };
 
 }  // namespace support
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py 
b/tests/python/relax/test_analysis_struct_info_analysis.py
index 1b1ea2e53e..b28df7b224 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -619,21 +619,40 @@ def test_struct_info_lca():
     _check_lca(fopaque2(), fn_info_shape(1), fopaque2())
 
 
-def test_tir_vars_in_struct_info():
+def _generate_tir_var_test_cases():
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
     shape0 = rx.ShapeStructInfo([1, n, 3])
     shape1 = rx.ShapeStructInfo([1, 2 * n, n, m])
+    shape2 = rx.ShapeStructInfo([1, 2 * n, m])
     tensor0 = rx.TensorStructInfo([1, n, 3], "int32")
     tensor1 = rx.TensorStructInfo([1, 2 * n, n, m], "int32")
+    tensor2 = rx.TensorStructInfo([1, 2 * n, m], "int32")
     func = rx.FuncStructInfo(
         [rx.TensorStructInfo([1, 2 * n, n, m], "int32")], 
rx.TensorStructInfo([1, n, 3], "int32")
     )
 
-    
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(shape0), [n])
-    
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(shape1), [n, 
m])
-    
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(tensor0), 
[n])
-    
tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(tensor1), 
[n, m])
-    tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(func), 
[n, m])
+    yield shape0, [n], [n]
+    yield shape1, [n, m], [n, m]
+    yield shape2, [m], [n, m]
+    yield tensor0, [n], [n]
+    yield tensor1, [n, m], [n, m]
+    yield tensor2, [m], [n, m]
+    yield func, [n, m], [n, m]
+
+
+tir_var_test_case = tvm.testing.parameter(*_generate_tir_var_test_cases())
+
+
+def test_tir_vars_in_struct_info(tir_var_test_case):
+    sinfo, _vars_definable, vars_used = tir_var_test_case
+    tvm.ir.assert_structural_equal(rx.analysis.tir_vars_in_struct_info(sinfo), 
vars_used)
+
+
+def test_definable_tir_vars_in_struct_info(tir_var_test_case):
+    sinfo, vars_definable, _vars_used = tir_var_test_case
+    tvm.ir.assert_structural_equal(
+        rx.analysis.definable_tir_vars_in_struct_info(sinfo), vars_definable
+    )
 
 
 def test_collect_symbolic_var_from_tensor_shape():
diff --git a/tests/python/relax/test_transform_lift_transform_params.py 
b/tests/python/relax/test_transform_lift_transform_params.py
index c23efe655b..7389060bde 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -542,5 +542,105 @@ def test_symbolic_var_2():
     tvm.ir.assert_structural_equal(after, Expected)
 
 
+def test_symbolic_var_from_shape():
+    @I.ir_module
+    class Before:
+        @R.function
+        def main(
+            A: R.Tensor([16, 16], "int32"),
+            B: R.Tensor([16, 16], "int32"),
+            shape: R.Shape(["slice_index"]),
+        ) -> R.Tensor([16], "int32"):
+            R.func_attr({"num_input": 1})
+            slice_index = T.int64()
+            cls = Before
+            with R.dataflow():
+                B_slice = R.call_tir(
+                    cls.slice,
+                    [B],
+                    tir_vars=R.ShapeExpr([slice_index]),
+                    out_sinfo=R.Tensor([16], dtype="int32"),
+                )
+                A_slice = R.call_tir(
+                    cls.slice,
+                    [A],
+                    tir_vars=R.ShapeExpr([slice_index]),
+                    out_sinfo=R.Tensor([16], dtype="int32"),
+                )
+                A_scale = R.multiply(A_slice, B_slice)
+                R.output(A_scale)
+            return A_scale
+
+        @T.prim_func(private=True)
+        def slice(
+            Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
+            Output_Slice: T.Buffer(shape=[16], dtype="int32"),
+            slice_index: T.int64,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for j in range(16):
+                with T.block("T_full"):
+                    vj = T.axis.remap("S", [j])
+                    Output_Slice[vj] = Input_2d[slice_index, vj]
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main(
+            A: R.Tensor([16, 16], "int32"),
+            shape: R.Shape(["slice_index"]),
+            B_slice: R.Tensor([16], "int32"),
+        ) -> R.Tensor([16], "int32"):
+            R.func_attr({"num_input": 1})
+            slice_index = T.int64()
+            cls = Expected
+            with R.dataflow():
+                A_slice = R.call_tir(
+                    cls.slice,
+                    [A],
+                    tir_vars=R.ShapeExpr([slice_index]),
+                    out_sinfo=R.Tensor([16], dtype="int32"),
+                )
+                B_slice = B_slice
+                A_scale = R.multiply(A_slice, B_slice)
+                R.output(A_scale)
+            return A_scale
+
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(R.Tensor([16, 16], "int32"), 
R.Shape(["slice_index"]))
+        ):
+            slice_index = T.int64()
+            cls = Expected
+            with R.dataflow():
+                extra_symbolic_vars = R.ShapeExpr([slice_index])
+                B = params[0]
+                B_slice = R.call_tir(
+                    cls.slice,
+                    [B],
+                    tir_vars=R.ShapeExpr([slice_index]),
+                    out_sinfo=R.Tensor([16], dtype="int32"),
+                )
+                output = (extra_symbolic_vars, B_slice)
+                R.output(output)
+            return output
+
+        @T.prim_func(private=True)
+        def slice(
+            Input_2d: T.Buffer(shape=[16, 16], dtype="int32"),
+            Output_Slice: T.Buffer(shape=[16], dtype="int32"),
+            slice_index: T.int64,
+        ):
+            T.func_attr({"tir.noalias": T.bool(True)})
+            for j in range(16):
+                with T.block("T_full"):
+                    vj = T.axis.remap("S", [j])
+                    Output_Slice[vj] = Input_2d[slice_index, vj]
+
+    mod = Before
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(Expected, after)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to