This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 2e2126f9e3 [Unity] Implement relax.Function.bind_symbolic_vars (#15509)
2e2126f9e3 is described below

commit 2e2126f9e342b1996f7b80ce0dd9e11095ef481c
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Aug 16 09:10:23 2023 -0500

    [Unity] Implement relax.Function.bind_symbolic_vars (#15509)
    
    * [Unity] Implement relax.Function.bind_symbolic_vars
    
    If a function has dynamic shape parameters, it can be useful to
    replace them with static parameters (e.g. when producing several
    models within the same family).  This commit introduces a utility
    function `relax.Function.bind_symbolic_vars`, which allows symbolic
    variables to be replaced with static values.
    
    This is related to the parameter binding done in
    `relax.transform.BindParams`, but does not require the bound parameter
    to be a fully static data array.
    
    * Updating ExprBinder to use tir::Substitute
    
    Previously, `ExprBinder` only checked whether a `PrimExpr` was a
    symbolic variable to be replaced, but did not handle cases where a
    `PrimExpr` contained a symbolic variable to be replaced.  As a result,
    when binding symbolic variables `{N: 16}`, a shape of `[N,2*N]` would be
    updated to `[16,2*N]` instead of `[16,32]`.  This commit updates
    `ExprBinder` to use `tir::Substitute` to ensure all occurrences of the
    symbolic variable are replaced.
    
    * Special case for updating symbolic vars in strided_slice attrs
    
    * Added IRModule pass to bind symbolic vars
    
    * Update unit test to include pytest
    
    * Co-authored-by: Sunghyun Park <[email protected]>
    
    * Correct match mode in kProvideDefinition context
    
    * Clean up implementation with VisitMode as a bitflag
---
 include/tvm/relax/transform.h                      |  18 ++
 python/tvm/relax/expr.py                           |  32 ++-
 python/tvm/relax/transform/transform.py            |  29 +++
 src/relax/analysis/struct_info_analysis.cc         |  63 +++--
 src/relax/transform/bind_symbolic_vars.cc          | 177 ++++++++++++++
 src/relax/utils.cc                                 |  58 ++++-
 tests/python/relax/test_bind_symbolic_vars.py      | 205 ++++++++++++++++
 .../relax/test_transform_bind_symbolic_vars.py     | 270 +++++++++++++++++++++
 8 files changed, 827 insertions(+), 25 deletions(-)

diff --git a/include/tvm/relax/transform.h b/include/tvm/relax/transform.h
index 8d01262aab..05b26f0242 100644
--- a/include/tvm/relax/transform.h
+++ b/include/tvm/relax/transform.h
@@ -182,6 +182,24 @@ TVM_DLL Pass EliminateCommonSubexpr(bool call_only = 
false);
  */
 TVM_DLL Pass BindParams(String func_name, Map<String, runtime::NDArray> 
params);
 
+/*!
+ * \brief Bind symbolic vars to constant shape values.
+ *
+ * \param binding_map The dictionary of symbolic variables and their
+ *      constant shape values.  Dictionary keys may be either a
+ *      `tir.Var` or a string name of the variable.  If the variables
+ *      are referred to by name, the name must uniquely identify a
+ *      symbolic variable in each function where it is used.
+ *
+ * \param func_name The name of the function in which to bind shape
+ *      values.  If NullOpt, all functions in the module will be
+ *      updated.
+ *
+ * \return The Pass.
+ */
+TVM_DLL Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map,
+                              Optional<String> func_name = NullOpt);
+
 /*!
  * \brief Fold constant expressions.
  *
diff --git a/python/tvm/relax/expr.py b/python/tvm/relax/expr.py
index fb8ccf98d3..49b91ffb3d 100644
--- a/python/tvm/relax/expr.py
+++ b/python/tvm/relax/expr.py
@@ -19,7 +19,7 @@
 """The expression nodes of Relax."""
 import typing
 from numbers import Number
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable, Dict, List, Optional, Union, Mapping
 
 import numpy as _np  # type: ignore
 
@@ -627,6 +627,36 @@ class Function(BaseFunc, Scriptable):
         """
         return Call(self, args, None, None)
 
+    def bind_symbolic_vars(
+        self, binding_map: Mapping[Union[str, tvm.tir.Var], PrimExpr]
+    ) -> "Function":
+        """Return a new function with updated symbolic variables
+
+        Parameters
+        ----------
+        binding_map: Mapping[Union[str, tvm.tir.Var], PrimExpr]
+
+            The mapping of values to be replaced.  Keys may be either
+            a `tir.Var` or a string name of the variable.  If the
+            variables are referred to by name, the name must uniquely
+            identify a symbolic variable in the function.
+
+        Returns
+        -------
+        func: Function
+
+            The updated function
+        """
+
+        # Relax uses int64 for symbolic variables, but the FFI
+        # converts python integers into int32.
+        binding_map = {
+            key: tvm.tir.const(value, "int64") if isinstance(value, int) else 
value
+            for key, value in binding_map.items()
+        }
+
+        return _ffi_api.FunctionBindSymbolicVars(self, binding_map)  # type: 
ignore
+
 
 @tvm._ffi.register_object("relax.expr.ExternFunc")
 class ExternFunc(BaseFunc):
diff --git a/python/tvm/relax/transform/transform.py 
b/python/tvm/relax/transform/transform.py
index f512e42bf6..438a6d1213 100644
--- a/python/tvm/relax/transform/transform.py
+++ b/python/tvm/relax/transform/transform.py
@@ -416,6 +416,35 @@ def BindParams(
     return _ffi_api.BindParams(func_name, tvm_params)  # type: ignore
 
 
+def BindSymbolicVars(
+    binding_map: Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr],
+    func_name: Optional[str] = None,
+) -> tvm.ir.transform.Pass:
+    """Bind the symbolic variables of functions in the module to constant values.
+    Parameters
+    ----------
+    binding_map : Mapping[Union[str, tvm.tir.Var], tvm.tir.PrimExpr]
+
+        The map from symbolic variable (a `tir.Var` or its name) to the value to bind.
+
+    func_name: Optional[str]
+
+        The function name to be bound.  If None (default), all
+        functions within the module will be updated.
+
+    Returns
+    -------
+    ret: tvm.ir.transform.Pass
+    """
+    # Relax uses int64 for symbolic variables, but the FFI
+    # converts python integers into int32.
+    binding_map = {
+        key: tvm.tir.const(value, "int64") if isinstance(value, int) else value
+        for key, value in binding_map.items()
+    }
+    return _ffi_api.BindSymbolicVars(binding_map, func_name)  # type: ignore
+
+
 def RunCodegen(
     target_options: Optional[dict] = None,
     entry_functions: Optional[List[str]] = None,
diff --git a/src/relax/analysis/struct_info_analysis.cc 
b/src/relax/analysis/struct_info_analysis.cc
index 22c2e9bbd4..9fae776279 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -978,17 +978,34 @@ class SymbolicVarCollector : public relax::ExprVisitor,
   using tir::ExprVisitor::VisitExpr;
   using tir::ExprVisitor::VisitExpr_;
 
-  // Possible mode of visitor
-  enum class VisitMode {
-    /*! \brief Check all vars are well-defined. */
-    kDefault,
-    /*! \brief Match define the vars on first occurrence. */
-    kMatchVarDef,
+  // Possible mode of visitor, used as bit-flags
+  enum VisitMode {
+    /*! \brief Do nothing on encountering a symbolic variable */
+    kNone = 0,
+
+    /*! \brief Provide a variable definition on first occurrence.
+     *
+     * If a symbolic variable occurs at a site where a definition can
+     * be provided, mark the variable as having a definition.
+     */
+    kProvideDefinition = 1,
+
+    /*! \brief Require a variable definition on occurrence.
+     *
+     * If a symbolic variable occurs, and has not previously been
+     * defined, mark the variable as being free/undefined.
+     */
+    kRequireDefinition = 2,
   };
 
   void VisitExpr_(const FunctionNode* op) final {
-    WithMode(VisitMode::kMatchVarDef, [&]() {
-      ICHECK(mode_ == VisitMode::kMatchVarDef);
+    WithMode(VisitMode::kProvideDefinition, [&]() {
+      for (Var param : op->params) {
+        relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
+      }
+    });
+
+    WithMode(VisitMode::kRequireDefinition, [&]() {
       for (Var param : op->params) {
         relax::StructInfoVisitor::VisitStructInfo(GetStructInfo(param));
       }
@@ -998,7 +1015,8 @@ class SymbolicVarCollector : public relax::ExprVisitor,
   }
 
   void VisitBinding_(const MatchCastNode* binding) final {
-    WithMode(VisitMode::kMatchVarDef, [&]() { 
this->VisitStructInfo(binding->struct_info); });
+    WithMode(VisitMode(VisitMode::kProvideDefinition | 
VisitMode::kRequireDefinition),
+             [&]() { this->VisitStructInfo(binding->struct_info); });
 
     relax::ExprVisitor::VisitBinding_(binding);
   }
@@ -1009,8 +1027,17 @@ class SymbolicVarCollector : public relax::ExprVisitor,
 
   void VisitStructInfo_(const FuncStructInfoNode* op) final {
     if (op->params.defined()) {
-      WithMode(VisitMode::kMatchVarDef, [&]() {
-        ICHECK(mode_ == VisitMode::kMatchVarDef);
+      // Visit the parameters once to collect bindings, and another
+      // time to collect usages.  Otherwise, a symbolic variable
+      // defined by a later parameter may be treated as undefined when
+      // used by an earlier parameter.
+      WithMode(VisitMode::kProvideDefinition, [&]() {
+        for (StructInfo param : op->params.value()) {
+          this->VisitStructInfo(param);
+        }
+      });
+
+      WithMode(VisitMode::kRequireDefinition, [&]() {
         for (StructInfo param : op->params.value()) {
           this->VisitStructInfo(param);
         }
@@ -1029,14 +1056,14 @@ class SymbolicVarCollector : public relax::ExprVisitor,
   }
 
   void VisitStructInfoExprField(const PrimExpr& expr) final {
-    if (mode_ == VisitMode::kMatchVarDef && expr->IsInstance<tir::VarNode>()) {
-      // populate symbolic var in first occurrence
-      const auto& var = Downcast<tir::Var>(expr);
-      if (defined_symbolic_var_.count(var) == 0) {
-        defined_symbolic_var_.insert(var);
+    if (mode_ & VisitMode::kProvideDefinition) {
+      if (auto var = expr.as<tir::Var>()) {
+        defined_symbolic_var_.insert(var.value());
       }
     }
-    tir::ExprVisitor::VisitExpr(expr);
+    if (mode_ & VisitMode::kRequireDefinition) {
+      tir::ExprVisitor::VisitExpr(expr);
+    }
   }
 
   void VisitExpr_(const tir::VarNode* op) final {
@@ -1056,7 +1083,7 @@ class SymbolicVarCollector : public relax::ExprVisitor,
   }
 
   /*! \brief The current visit mode. */
-  VisitMode mode_ = VisitMode::kDefault;
+  VisitMode mode_ = VisitMode::kRequireDefinition;
   /*! \brief The set of defined symbolic vars. */
   std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> 
defined_symbolic_var_;
   /*! \brief The set of free/undefined symbolic vars. */
diff --git a/src/relax/transform/bind_symbolic_vars.cc 
b/src/relax/transform/bind_symbolic_vars.cc
new file mode 100644
index 0000000000..2df9ed1f01
--- /dev/null
+++ b/src/relax/transform/bind_symbolic_vars.cc
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+#include <tvm/relax/analysis.h>
+#include <tvm/relax/expr.h>
+#include <tvm/relax/struct_info.h>
+#include <tvm/relax/transform.h>
+#include <tvm/relax/type.h>
+
+#include <string>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+namespace tvm {
+namespace relax {
+
+Function FunctionBindSymbolicVars(Function func, Map<ObjectRef, PrimExpr> 
obj_remap) {
+  // Early bail-out if no updates need to be made.
+  if (obj_remap.empty()) {
+    return func;
+  }
+
+  Array<tir::Var> old_symbolic_vars = DefinedSymbolicVars(func);
+
+  // Map from string to the variable(s) with that name.
+  std::unordered_map<std::string, Array<tir::Var>> string_lookup;
+  std::unordered_set<const tir::VarNode*> symbolic_var_set;
+  for (const auto& var : old_symbolic_vars) {
+    string_lookup[var->name_hint].push_back(var);
+    symbolic_var_set.insert(var.get());
+  }
+
+  // Replacement map to be used when rewriting the function.
+  Map<tir::Var, PrimExpr> var_remap;
+  for (const auto& [key, replacement] : obj_remap) {
+    if (auto opt = key.as<String>()) {
+      String string_key = opt.value();
+      auto it = string_lookup.find(string_key);
+      CHECK(it != string_lookup.end())
+          << "Function does not use symbolic var with name \"" << string_key 
<< "\".  "
+          << "Function has symbolic variables " << old_symbolic_vars;
+
+      CHECK_EQ(it->second.size(), 1)
+          << "Function contains multiple symbolic variables with name \"" << 
string_key << "\".  "
+          << "The TIR variables " << it->second << " are all named \"" << 
string_key << "\"";
+      auto var = it->second[0];
+
+      CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was 
defined multiple times";
+      var_remap.Set(var, replacement);
+    } else if (auto opt = key.as<tir::Var>()) {
+      auto var = opt.value();
+
+      CHECK(!var_remap.count(var)) << "Remap of variable " << var << " was 
defined multiple times";
+      CHECK(symbolic_var_set.count(var.get()))
+          << "Function does not use variable " << var << " as a symbolic 
variable.  "
+          << "Function has symbolic variables " << old_symbolic_vars;
+      var_remap.Set(var, replacement);
+    } else {
+      LOG(FATAL) << "Expected symbolic variable to be a tir::Var or a string 
name, "
+                 << "but " << key << " was of type " << key->GetTypeKey();
+    }
+  }
+
+  auto new_func = Downcast<Function>(Bind(func, {}, var_remap));
+
+  auto free_symbolic_vars = FreeSymbolicVars(new_func);
+
+  CHECK(free_symbolic_vars.empty())
+      << "Resulting function should not have any undefined symbolic variables, 
"
+      << "but TIR variables " << free_symbolic_vars << " were undefined.";
+
+  return new_func;
+}
+
+namespace {
+IRModule ModuleBindSymbolicVars(IRModule mod, Map<ObjectRef, PrimExpr> 
binding_map) {
+  std::unordered_set<const Object*> used;
+  IRModule updates;
+  for (const auto& [gvar, base_func] : mod->functions) {
+    if (auto opt = base_func.as<Function>()) {
+      auto func = opt.value();
+
+      // Collect bindings that are used by this function.
+      auto func_binding_map = [&]() -> Map<ObjectRef, PrimExpr> {
+        std::unordered_set<std::string> var_names;
+        std::unordered_set<const tir::VarNode*> vars;
+        for (const auto& var : DefinedSymbolicVars(func)) {
+          var_names.insert(var->name_hint);
+          vars.insert(var.get());
+        }
+
+        Map<ObjectRef, PrimExpr> out;
+        for (const auto& [key, replacement] : binding_map) {
+          bool used_by_function = false;
+          if (auto opt = key.as<String>()) {
+            used_by_function = var_names.count(opt.value());
+          } else if (auto ptr = key.as<tir::VarNode>()) {
+            used_by_function = vars.count(ptr);
+          } else {
+            LOG(FATAL) << "Expected symbolic variable to be a tir::Var "
+                       << "or a string name, but " << key << " was of type " 
<< key->GetTypeKey();
+          }
+          if (used_by_function) {
+            used.insert(key.get());
+            out.Set(key, replacement);
+          }
+        }
+        return out;
+      }();
+      func = FunctionBindSymbolicVars(func, func_binding_map);
+
+      if (!func.same_as(base_func)) {
+        updates->Add(gvar, func);
+      }
+    }
+  }
+
+  Array<ObjectRef> unused;
+  for (const auto& [key, replacement] : binding_map) {
+    if (!used.count(key.get())) {
+      unused.push_back(key);
+    }
+  }
+  CHECK_EQ(unused.size(), 0) << "Binding map contains keys " << unused
+                             << ", which did not correspond to any symbolic 
variables "
+                             << "in the module.";
+
+  if (updates->functions.size()) {
+    mod.CopyOnWrite()->Update(updates);
+  }
+  return mod;
+}
+}  // namespace
+
+TVM_REGISTER_GLOBAL("relax.FunctionBindSymbolicVars").set_body_typed(FunctionBindSymbolicVars);
+
+namespace transform {
+
+Pass BindSymbolicVars(Map<ObjectRef, PrimExpr> binding_map, Optional<String> 
func_name) {
+  auto pass_func = [=](IRModule mod, PassContext context) -> IRModule {
+    if (func_name) {
+      auto gvar = mod->GetGlobalVar(func_name.value());
+      auto func = Downcast<Function>(mod->Lookup(gvar));
+      auto new_func = FunctionBindSymbolicVars(func, binding_map);
+      if (!func.same_as(new_func)) {
+        mod.CopyOnWrite()->Update(gvar, new_func);
+      }
+    } else {
+      mod = ModuleBindSymbolicVars(mod, binding_map);
+    }
+    return mod;
+  };
+
+  return tvm::transform::CreateModulePass(pass_func, 1, 
"relax.BindSymbolicVars", {});
+}
+
+TVM_REGISTER_GLOBAL("relax.transform.BindSymbolicVars").set_body_typed(BindSymbolicVars);
+
+}  // namespace transform
+}  // namespace relax
+}  // namespace tvm
diff --git a/src/relax/utils.cc b/src/relax/utils.cc
index b0816b0eda..ccb72805e3 100644
--- a/src/relax/utils.cc
+++ b/src/relax/utils.cc
@@ -20,7 +20,9 @@
 #include "transform/utils.h"
 
 #include <tvm/relax/analysis.h>
+#include <tvm/relax/attrs/index.h>
 #include <tvm/relax/expr_functor.h>
+#include <tvm/tir/stmt_functor.h>
 
 namespace tvm {
 namespace relax {
@@ -33,6 +35,8 @@ class ExprBinder : public ExprMutator {
       : args_map_(args_map), symbolic_var_map_(symbolic_var_map) {}
 
  private:
+  using ExprMutator::VisitExpr_;
+
   Expr VisitExpr_(const FunctionNode* op) final {
     tvm::Array<Var> params;
     bool all_params_unchanged = true;
@@ -61,6 +65,49 @@ class ExprBinder : public ExprMutator {
     }
   }
 
+  Expr VisitExpr_(const CallNode* op) final {
+    auto call_node = Downcast<Call>(ExprMutator::VisitExpr_(op));
+
+    // Special case for strided_slice
+    //
+    // The strided_slice operator currently stores the begins/ends in
+    // the CallNode::attrs.  Because the CallNode::attrs is only
+    // intended to store static information, any PrimExpr members in
+    // the attributes are not visited by `ExprMutator::VisitPrimExpr`.
+    // Therefore, these must be explicitly visited.
+    //
+    // When the strided_slice operator is updated to store begins/ends
+    // as a tuple of `relax::PrimValue` in the arguments, this special
+    // case can be removed.
+    static auto strided_slice_op = Op::Get("relax.strided_slice");
+    if (call_node->op.same_as(strided_slice_op)) {
+      auto attrs = call_node->attrs.as<StridedSliceAttrs>();
+
+      auto visit_prim_expr = [this](const auto& expr) { return 
VisitPrimExpr(expr); };
+
+      Array<PrimExpr> begin = attrs->begin.Map(visit_prim_expr);
+      Array<PrimExpr> end = attrs->end.Map(visit_prim_expr);
+      auto strides = attrs->strides;
+      if (strides.defined()) {
+        strides = strides.value().Map(visit_prim_expr);
+      }
+
+      bool all_same = begin.same_as(attrs->begin) && end.same_as(attrs->end) &&
+                      (!strides.defined() || strides.same_as(attrs->strides));
+      if (!all_same) {
+        ObjectPtr<StridedSliceAttrs> new_attrs = 
make_object<StridedSliceAttrs>();
+        new_attrs->axes = attrs->axes;
+        new_attrs->begin = std::move(begin);
+        new_attrs->end = std::move(end);
+        new_attrs->strides = std::move(strides);
+        new_attrs->assume_inbound = attrs->assume_inbound;
+        call_node.CopyOnWrite()->attrs = Attrs(new_attrs);
+      }
+    }
+
+    return std::move(call_node);
+  }
+
   Expr VisitExpr_(const VarNode* op) final {
     auto id = GetRef<Var>(op);
     auto it = args_map_.find(id);
@@ -72,13 +119,12 @@ class ExprBinder : public ExprMutator {
   }
 
   PrimExpr VisitPrimExpr(const PrimExpr& expr) final {
-    if (const tir::VarNode* var = expr.as<tir::VarNode>()) {
-      auto it = symbolic_var_map_.find(GetRef<tir::Var>(var));
-      if (it != symbolic_var_map_.end()) {
-        return (*it).second;
-      }
+    auto new_expr = tir::Substitute(expr, symbolic_var_map_);
+    if (!expr.same_as(new_expr)) {
+      arith::Analyzer analyzer;
+      new_expr = analyzer.Simplify(new_expr);
     }
-    return ExprMutator::VisitPrimExpr(expr);
+    return new_expr;
   }
 
  private:
diff --git a/tests/python/relax/test_bind_symbolic_vars.py 
b/tests/python/relax/test_bind_symbolic_vars.py
new file mode 100644
index 0000000000..1dc1189a67
--- /dev/null
+++ b/tests/python/relax/test_bind_symbolic_vars.py
@@ -0,0 +1,205 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.script import relax as R, tir as T
+
+replace_by_tir_var = tvm.testing.parameter(
+    by_dict={"replace-by-string": False, "replace-by-tir-var": True}
+)
+
+
+def test_bind_static_value(replace_by_tir_var):
+    """Symbolic vars may be replaced
+
+    The replaced variables may be given either as strings, or as TIR variables
+    """
+
+    @R.function(private=True)
+    def before(A: R.Tensor(("M", "K")), B: R.Tensor(("K", "N"))) -> 
R.Tensor(("M", "N")):
+        return R.matmul(A, B)
+
+    @R.function(private=True)
+    def expected(A: R.Tensor((128, 64)), B: R.Tensor((64, 32))) -> 
R.Tensor((128, 32)):
+        return R.matmul(A, B)
+
+    if replace_by_tir_var:
+        M, K = before.params[0].struct_info.shape
+        _, N = before.params[1].struct_info.shape
+        symbolic_var_map = {M: 128, K: 64, N: 32}
+    else:
+        symbolic_var_map = {"M": 128, "K": 64, "N": 32}
+
+    after = before.bind_symbolic_vars(symbolic_var_map)
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_with_duplicate_var_names():
+    """Duplicate variable names may not be replaced by string
+
+    Two TIR variables may have the same name.  If two symbolic
+    variables share the same name, the replacement map may not refer
+    to that variable by string.
+    """
+    N1 = tvm.tir.Var("N", "int64")
+    N2 = tvm.tir.Var("N", "int64")
+
+    @R.function(private=True)
+    def func(A: R.Tensor((N1, N1)), B: R.Tensor((N1, N2))) -> R.Tensor((N1, 
N2)):
+        out: R.Tensor((N1, N2)) = R.matmul(A, B)
+        return out
+
+    with pytest.raises(tvm.TVMError):
+        func.bind_symbolic_vars({"N": 64})
+
+
+def test_string_var_when_other_var_has_duplicate_var_names():
+    """Like test_error_with_duplicate_var_names, but replacing a different 
variable
+
+    If two TIR variables share the same name, the restriction against
+    replacing variables by name only applies to those duplicate names.
+    Other variables may still be replaced by name.
+    """
+    N1 = tvm.tir.Var("N", "int64")
+    N2 = tvm.tir.Var("N", "int64")
+    BatchSize = tvm.tir.Var("BatchSize", "int64")
+
+    @R.function(private=True)
+    def before(
+        A: R.Tensor((BatchSize, N1, N1)), B: R.Tensor((N1, N2))
+    ) -> R.Tensor((BatchSize, N1, N2)):
+        out: R.Tensor((BatchSize, N1, N2)) = R.matmul(A, B)
+        return out
+
+    @R.function(private=True)
+    def expected(A: R.Tensor((16, N1, N1)), B: R.Tensor((N1, N2))) -> 
R.Tensor((16, N1, N2)):
+        out: R.Tensor((16, N1, N2)) = R.matmul(A, B)
+        return out
+
+    after = before.bind_symbolic_vars({"BatchSize": 16})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_error_with_nonexisting_var_name():
+    """A string name of a symbolic var must be used by the function"""
+
+    @R.function(private=True)
+    def func(A: R.Tensor(("M", "N"))):
+        return A
+
+    with pytest.raises(tvm.TVMError):
+        func.bind_symbolic_vars({"non_existing_symbolic_var": 64})
+
+
+def test_error_with_nonexisting_tir_var():
+    """A TIR symbolic var must be a symbolic var of the function"""
+
+    @R.function(private=True)
+    def func(A: R.Tensor(["M", "N"])):
+        return A
+
+    with pytest.raises(tvm.TVMError):
+        func.bind_symbolic_vars({tvm.tir.Var("M", "int64"): 64})
+
+
+def test_error_with_multiple_definitions():
+    """The string/TIR var syntaxes may not define the same variable"""
+
+    @R.function(private=True)
+    def func(A: R.Tensor(["M", "N"])):
+        return A
+
+    tir_var = func.params[0].struct_info.shape[0]
+    symbolic_var_map = {tir_var: 0, "M": 0}
+
+    with pytest.raises(tvm.TVMError):
+        func.bind_symbolic_vars(symbolic_var_map)
+
+
+def test_error_if_output_has_undefined():
+    """The replacements may not introduce undefined symbolic vars"""
+
+    @R.function(private=True)
+    def func(A: R.Tensor(["M", "N"])):
+        return A
+
+    outside_var = tvm.tir.Var("outside_var", "int64")
+
+    with pytest.raises(tvm.TVMError):
+        func.bind_symbolic_vars({"M": outside_var * 2})
+
+
+def test_replacements_may_produce_new_symbolic_vars():
+    """The output may introduce symbolic vars, but they must be bound"""
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M", "N"])):
+        return A
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["outside_var * 2", "outside_var"])):
+        return A
+
+    outside_var = tvm.tir.Var("outside_var", "int64")
+
+    after = before.bind_symbolic_vars({"M": outside_var * 2, "N": outside_var})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_symbolic_vars_in_shape():
+    """The bound variable should be replaced when appearing in struct info"""
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M", "N"])):
+        M = T.int64()
+        N = T.int64()
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([2 * M * 
N]))
+        return B
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["M", 16])):
+        M = T.int64()
+        B = R.call_dps_packed("dummy_func", [A], out_sinfo=R.Tensor([M * 32]))
+        return B
+
+    after = before.bind_symbolic_vars({"N": 16})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+def test_bind_strided_slice():
+    """relax.op.strided_slice stores PrimExpr attributes"""
+
+    @R.function(private=True)
+    def before(A: R.Tensor(["M", "N"])):
+        N = T.int64()
+        B = R.strided_slice(A, [1], [0], [N // 4])
+        return B
+
+    @R.function(private=True)
+    def expected(A: R.Tensor(["M", 32])):
+        B = R.strided_slice(A, [1], [0], [8])
+        return B
+
+    after = before.bind_symbolic_vars({"N": 32})
+    tvm.ir.assert_structural_equal(expected, after)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/relax/test_transform_bind_symbolic_vars.py 
b/tests/python/relax/test_transform_bind_symbolic_vars.py
new file mode 100644
index 0000000000..687945a650
--- /dev/null
+++ b/tests/python/relax/test_transform_bind_symbolic_vars.py
@@ -0,0 +1,270 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pytest
+
+import tvm
+import tvm.script
+import tvm.testing
+from tvm import relax
+from tvm.script import ir as I
+from tvm.script import relax as R
+from tvm.script import tir as T
+
+
def test_bind_tensors():
    """Symbolic variables may occur in Tensor shapes"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(("batch", "m"), dtype="float32"),
            w0: R.Tensor(("m", "n"), dtype="float32"),
            w1: R.Tensor(("k", 10), dtype="float32"),
        ) -> R.Tensor(("batch", "k"), dtype="float32"):
            # Declare the symbolic vars with T.int64(), consistent with
            # the Expected module below and the rest of this file
            # (T.Var("name", "int64") is the older spelling).
            batch = T.int64()
            n = T.int64()
            k = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "test0", (x, w0), out_sinfo=R.Tensor((batch, n), dtype="float32")
                )
                out = R.call_dps_packed(
                    "test1", (lv0, w1), out_sinfo=R.Tensor((batch, k), dtype="float32")
                )
                R.output(out)
            return out

    # Bind "batch" and "k" across the targeted function; "m" and "n"
    # remain symbolic in the output.
    symvar_map = {"batch": 1, "k": 3}
    After = relax.transform.BindSymbolicVars(symvar_map, "main")(Before)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, "m"), dtype="float32"),
            w0: R.Tensor(("m", "n"), dtype="float32"),
            w1: R.Tensor((3, 10), dtype="float32"),
        ) -> R.Tensor((1, 3), dtype="float32"):
            n = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "test0", (x, w0), out_sinfo=R.Tensor((1, n), dtype="float32")
                )
                out = R.call_dps_packed(
                    "test1", (lv0, w1), out_sinfo=R.Tensor((1, 3), dtype="float32")
                )
                R.output(out)
            return out

    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_bind_shape():
    """Symbolic variables may occur in ShapeExpr"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            x: R.Shape(("batch", "m")),
            w0: R.Shape(("m", "n")),
            w1: R.Shape(("k", 10)),
        ) -> R.Shape(("batch", "k")):
            # Declare symbolic vars with T.int64(), consistent with the
            # Expected module and the rest of this file.
            batch = T.int64()
            n = T.int64()
            k = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed("test0", (x, w0), out_sinfo=R.Tensor((batch, n)))
                out = R.call_dps_packed("test1", (lv0, w1), out_sinfo=R.Tensor((batch, k)))
                R.output(out)
            return out

    # Bind "batch" and "k" in the R.Shape annotations and call sites.
    symvar_map = {"batch": 1, "k": 3}
    After = relax.transform.BindSymbolicVars(symvar_map, "main")(Before)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Shape([1, "m"]), w0: R.Shape(["m", "n"]), w1: R.Shape([3, 10])
        ) -> R.Shape([1, 3]):
            n = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed("test0", (x, w0), out_sinfo=R.Tensor((1, n)))
                out = R.call_dps_packed("test1", (lv0, w1), out_sinfo=R.Tensor((1, 3)))
                R.output(out)
            return out

    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_arith():
    """Symbolic shapes may use TIR arithmetic expressions"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(
            x: R.Tensor(("batch", "m-1"), dtype="float32"),
            w0: R.Tensor(("m", "n"), dtype="float32"),
            w1: R.Tensor(("k", 10), dtype="float32"),
        ) -> R.Tensor(("batch", "k*m"), dtype="float32"):
            # Declare symbolic vars with T.int64(), consistent with the
            # Expected module and the rest of this file.
            batch = T.int64()
            m = T.int64()
            n = T.int64()
            k = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "test0",
                    (x, w0),
                    out_sinfo=R.Tensor((batch, m + n), dtype="float32"),
                )
                out = R.call_dps_packed(
                    "test1",
                    (lv0, w1),
                    out_sinfo=R.Tensor((batch, k + n), dtype="float32"),
                )
                R.output(out)
            return out

    # Substitution must reach vars inside arithmetic expressions:
    # m-1 -> 2, k*m -> 6, m+n -> n+3, k+n -> n+2.
    symvar_map = {"batch": 1, "k": 2, "m": 3}
    After = relax.transform.BindSymbolicVars(symvar_map, "main")(Before)

    @I.ir_module
    class Expected:
        @R.function
        def main(
            x: R.Tensor((1, 2), dtype="float32"),
            w0: R.Tensor((3, "n"), dtype="float32"),
            w1: R.Tensor((2, 10), dtype="float32"),
        ) -> R.Tensor((1, 6), dtype="float32"):
            n = T.int64()
            with R.dataflow():
                lv0 = R.call_dps_packed(
                    "test0", (x, w0), out_sinfo=R.Tensor((1, n + 3), dtype="float32")
                )
                out = R.call_dps_packed(
                    "test1", (lv0, w1), out_sinfo=R.Tensor((1, n + 2), dtype="float32")
                )
                R.output(out)
            return out

    tvm.ir.assert_structural_equal(Expected, After)
+
+
def test_bind_multiple_variables_by_name():
    """String names may be used to replace across multiple functions"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main_1(x: R.Tensor(("m", 16), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", 16), dtype="float32")):
            return x

    # With no function name given, a string key applies to every
    # function in the module that uses a var of that name.
    transform = relax.transform.BindSymbolicVars({"n": 16})
    tvm.ir.assert_structural_equal(Expected, transform(Before))
+
+
def test_bind_single_variable_by_identity():
    """TIR variables may be used to replace a specific var"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main_1(x: R.Tensor(("m", 16), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    # Keying on the tir.Var object itself (not its name) restricts the
    # replacement to main_1, even though main_2 also has a var named "n".
    var_n = Before["main_1"].params[0].struct_info.shape[1]
    transform = relax.transform.BindSymbolicVars({var_n: 16})
    tvm.ir.assert_structural_equal(Expected, transform(Before))
+
+
def test_bind_single_variable_by_function_name():
    """Variable name and function name may be used to replace a specific var"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main_1(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    @tvm.script.ir_module
    class Expected:
        @R.function
        def main_1(x: R.Tensor(("m", 16), dtype="float32")):
            return x

        @R.function
        def main_2(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    # Passing a function name limits the string-keyed replacement to
    # that one function; main_2 is left untouched.
    transform = relax.transform.BindSymbolicVars({"n": 16}, "main_1")
    tvm.ir.assert_structural_equal(Expected, transform(Before))
+
+
def test_error_for_unused_replacement():
    """Each replacement must be used"""

    @tvm.script.ir_module
    class Before:
        @R.function
        def main(x: R.Tensor(("m", "n"), dtype="float32")):
            return x

    # A replacement whose variable never occurs in the module must be
    # reported as an error when the pass is applied.
    transform = relax.transform.BindSymbolicVars({"non_existing_var_name": 16})
    with pytest.raises(tvm.TVMError):
        transform(Before)
+
+
# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()


Reply via email to