This is an automated email from the ASF dual-hosted git repository.

hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new 9100a8e8c9 [Unity] [LiftTransformParams] Treat symbolic var in weight shape as constant (#16049)
9100a8e8c9 is described below

commit 9100a8e8c9906a2cd93730c1951f19b13de60de6
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Nov 8 11:58:15 2023 -0800

    [Unity] [LiftTransformParams] Treat symbolic var in weight shape as constant (#16049)
    
    * treat symbolic var in param shape as constant
    
    * address comment
    
    * add xfail test
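    
    For context, a minimal self-contained sketch (illustrative only, not
    part of this commit) of how the pass is applied.  `num_input` marks
    how many leading parameters are runtime inputs; the remaining
    parameters are weights, and computation that depends only on the
    weights is lifted into a separate `main_transform_params` function.
    With this change, a weight shape such as ("m",) also defines the
    symbolic variable m for the lifted function:
    
        import tvm
        from tvm import relax
        from tvm.script import ir as I, relax as R
    
        @I.ir_module
        class Module:
            @R.function
            def main(
                x: R.Tensor((1, "m"), "float32"),
                w: R.Tensor(("m",), "float32"),
            ) -> R.Tensor((1, "m"), "float32"):
                R.func_attr({"num_input": 1})
                with R.dataflow():
                    # Weight-only computation: liftable even though its
                    # shape contains the symbolic variable m.
                    w2 = R.add(w, R.const(1, "float32"))
                    # Depends on the runtime input x: stays in main.
                    out = R.add(x, w2)
                    R.output(out)
                return out
    
        lifted = relax.transform.LiftTransformParams()(Module)
        lifted.show()  # prints both main and main_transform_params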
---
 src/relax/transform/lift_transform_params.cc       |  19 ++-
 .../relax/test_transform_lift_transform_params.py  | 155 +++++++++++++++++++++
 2 files changed, 171 insertions(+), 3 deletions(-)

diff --git a/src/relax/transform/lift_transform_params.cc b/src/relax/transform/lift_transform_params.cc
index cef19ff068..b500a3c3a3 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -32,6 +32,7 @@
 #include <vector>
 
 #include "../../support/ordered_set.h"
+#include "utils.h"
 
 namespace tvm {
 namespace relax {
@@ -237,6 +238,13 @@ class LiftTransformParamsPlanner : public ExprVisitor {
         builder_.UpdateBasedOnRuntimeInput(function->params[i]);
       } else {
         builder_.AddInput(function->params[i]);
+        if (function->params[i]->struct_info_.defined()) {
+          Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
+              Downcast<StructInfo>(function->params[i]->struct_info_.value()));
+          for (const auto& var : symbolic_vars) {
+            param_symbolic_vars_.insert(var);
+          }
+        }
       }
     }
     VisitExpr(function->body);
@@ -275,9 +283,12 @@ class LiftTransformParamsPlanner : public ExprVisitor {
       can_lift = false;
     }
 
-    // Cond 4. Do not lift when its struct info contains symbolic variables.
-    if (!TIRVarsInStructInfo(GetStructInfo(binding->var)).empty()) {
-      can_lift = false;
+    // Cond 4. Do not lift when its struct info contains symbolic variables that do not appear in
+    // params.
+    for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
+      if (!param_symbolic_vars_.count(var)) {
+        can_lift = false;
+      }
     }
 
     // Cond 5. Do not lift declarations of external functions
@@ -296,6 +307,8 @@ class LiftTransformParamsPlanner : public ExprVisitor {
   TransformParamsFuncBuilder builder_;
   // Whether we are in a dataflow block
   bool is_in_dataflow_block_{false};
+  // The symbolic variables in the parameters
+  std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual> param_symbolic_vars_;
 };
 
 /*!
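
For readers skimming the C++ hunks above: the planner now records, via
DefinableTIRVarsInStructInfo, every symbolic variable that a weight
parameter's struct info defines (param_symbolic_vars_), and Cond 4 only
rejects bindings that use a symbolic variable outside that set.  A rough
self-contained Python paraphrase of the updated check (plain strings
stand in for the tir::Var handles; can_lift is a hypothetical name, not
a TVM API):

    def can_lift(used_vars: set, param_symbolic_vars: set) -> bool:
        # Lift only when every symbolic var used by the binding's
        # struct info is defined by some weight parameter's shape and
        # is therefore known when the weights are transformed.
        return used_vars <= param_symbolic_vars

    assert can_lift({"m"}, {"m", "n"})      # shape (16, m, 3, 3): liftable
    assert not can_lift({"k"}, {"m", "n"})  # unknown extent k: not lifted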
diff --git a/tests/python/relax/test_transform_lift_transform_params.py b/tests/python/relax/test_transform_lift_transform_params.py
index 7389060bde..5b24614469 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -15,6 +15,8 @@
 # specific language governing permissions and limitations
 # under the License.
 
+import pytest
+
 import tvm
 import tvm.testing
 from tvm import relax
@@ -642,5 +644,158 @@ def test_symbolic_var_from_shape():
     tvm.ir.assert_structural_equal(Expected, after)
 
 
+def test_symbolic_var_in_param_shape():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 16, 224, "n"), "float32"),
+            w1: R.Tensor((16, "m", 3, 3), "float32"),
+            w2: R.Tensor((16, "m", 3, 3), "float32"),
+        ) -> R.Tensor((1, 16, 224, 224), "float32"):
+            m = T.int64()
+            n = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                zeros = R.zeros((n, n), "float32")
+                w1 = R.add(w1, R.const(1, "float32"))
+                conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW")
+                conv2 = R.nn.conv2d(
+                    conv1, w2, padding=(1, 1), data_layout="NCHW", kernel_layout="OIHW"
+                )
+                R.output(conv2)
+            return conv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((16, "m", 3, 3), dtype="float32"),
+                R.Tensor((16, "m", 3, 3), dtype="float32"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 3), dtype="float32")
+        ):
+            m = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
+                lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
+                lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1, R.const(1, "float32"))
+                gv: R.Tuple(
+                    R.Tensor((16, m, 3, 3), dtype="float32"),
+                    R.Tensor((16, m, 3, 3), dtype="float32"),
+                ) = (lv, lv2)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
+            transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
+            transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, n]), dtype="float32")
+                lv: R.Tensor((16, m, 3, 3), dtype="float32") = transformed_param_1
+                conv1: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
+                    x,
+                    lv,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="void",
+                )
+                lv1: R.Tensor((16, m, 3, 3), dtype="float32") = transformed_param_0
+                conv2: R.Tensor((1, 16, 224, n), dtype="float32") = R.nn.conv2d(
+                    conv1,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="void",
+                )
+                R.output(conv2)
+            return conv2
+
+    mod = Before
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+
+
+# not supported yet
[email protected]
+def test_symbolic_var_defined_in_params_but_used_in_weights():
+    """A symbolic variable's occurrence in the weights may not define it
+
+    In order to be a source of definition, a symbolic variable in the
+    parameters must occur on its own: as a tensor shape dimension
+    `R.Tensor(["var"])`, an explicit `R.Shape(["var"])`, or a
+    `R.Prim(value="var")`.  A variable that is part of a larger
+    expression, such as `R.Tensor(["m * n"])`, is a variable usage,
+    not a variable definition.
+    """
+
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor(["m", "n"], "float32"),
+            weight: R.Tensor(["m * n"], "float32"),
+        ) -> R.Tensor(["m", "n"], "float32"):
+            m = T.int64()
+            n = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                weight = R.add(weight, R.const(1, "float32"))
+                weight = R.reshape(weight, [m, n])
+                output = R.multiply(x, weight)
+                R.output(output)
+            return output
+
+    @tvm.script.ir_module
+    class Expected:
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(R.Tensor(("k",), dtype="float32"))
+        ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+            k = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((k,), dtype="float32") = params[0]
+                gv: R.Tuple(R.Tensor((k,), dtype="float32")) = (lv,)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor(("m", "n"), dtype="float32"),
+            transformed_param_0: R.Tensor(dtype="float32", ndim=1),
+        ) -> R.Tensor(("m", "n"), dtype="float32"):
+            m = T.int64()
+            n = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                lv: R.Tensor(dtype="float32", ndim=1) = transformed_param_0
+                weight: R.Tensor(dtype="float32", ndim=1) = R.add(lv, R.const(1, "float32"))
+                weight_1: R.Tensor((m, n), dtype="float32") = R.reshape(weight, R.shape([m, n]))
+                output: R.Tensor((m, n), dtype="float32") = R.multiply(x, weight_1)
+                R.output(output)
+            return output
+
+    After = relax.transform.LiftTransformParams()(Before)
+    tvm.ir.assert_structural_equal(Expected, After)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

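A note on the xfail test above: a weight annotated R.Tensor(["m * n"])
only constrains the product m * n, so the transform cannot recover either
factor from the weight alone.  A small arithmetic illustration of why the
product does not define its factors:

    # All factor pairs consistent with a weight of length k == 12; the
    # annotation ["m * n"] cannot single one of them out.
    k = 12
    factor_pairs = [(m, k // m) for m in range(1, k + 1) if k % m == 0]
    print(factor_pairs)  # [(1, 12), (2, 6), (3, 4), (4, 3), (6, 2), (12, 1)]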