Lunderberg commented on code in PR #16049:
URL: https://github.com/apache/tvm/pull/16049#discussion_r1382198806


##########
tests/python/relax/test_transform_lift_transform_params.py:
##########
@@ -642,5 +642,95 @@ def slice(
     tvm.ir.assert_structural_equal(Expected, after)
 
 
+def test_symbolic_var_in_param_shape():
+    @tvm.script.ir_module
+    class Before:
+        @R.function
+        def main(
+            x: R.Tensor((1, 16, 224, "n"), "float32"),
+            w1: R.Tensor((16, "m", 3, 3), "float32"),
+            w2: R.Tensor((16, "m", 3, 3), "float32"),
+        ) -> R.Tensor((1, 16, 224, 224), "float32"):
+            m = T.int64()
+            n = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                zeros = R.zeros((n, n), "float32")
+                w1 = R.add(w1, R.const(1, "float32"))
+                conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW", 
kernel_layout="OIHW")
+                conv2 = R.nn.conv2d(
+                    conv1, w2, padding=(1, 1), data_layout="NCHW", 
kernel_layout="OIHW"
+                )
+                R.output(conv2)
+            return conv2
+
+    @I.ir_module
+    class Expected:
+        @R.function
+        def main_transform_params(
+            params: R.Tuple(
+                R.Tensor((16, "m", 3, 3), dtype="float32"),
+                R.Tensor((16, "m", 3, 3), dtype="float32"),
+            )
+        ) -> R.Tuple(
+            R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3, 
3), dtype="float32")
+        ):
+            m = T.int64()
+            with R.dataflow():
+                lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
+                lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
+                lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1, 
R.const(1, "float32"))
+                gv: R.Tuple(
+                    R.Tensor((16, m, 3, 3), dtype="float32"),
+                    R.Tensor((16, m, 3, 3), dtype="float32"),
+                ) = (lv, lv2)
+                R.output(gv)
+            return gv
+
+        @R.function
+        def main(
+            x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
+            transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
+            transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
+        ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+            n = T.int64()
+            m = T.int64()
+            R.func_attr({"num_input": 1})
+            with R.dataflow():
+                zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n, 
n]), dtype="float32")
+                lv: R.Tensor((16, m, 3, 3), dtype="float32") = 
transformed_param_1
+                conv1: R.Tensor((1, 16, 224, n), dtype="float32") = 
R.nn.conv2d(
+                    x,
+                    lv,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="void",
+                )
+                lv1: R.Tensor((16, m, 3, 3), dtype="float32") = 
transformed_param_0
+                conv2: R.Tensor((1, 16, 224, n), dtype="float32") = 
R.nn.conv2d(
+                    conv1,
+                    lv1,
+                    strides=[1, 1],
+                    padding=[1, 1, 1, 1],
+                    dilation=[1, 1],
+                    groups=1,
+                    data_layout="NCHW",
+                    kernel_layout="OIHW",
+                    out_layout="NCHW",
+                    out_dtype="void",
+                )
+                R.output(conv2)
+            return conv2
+
+    mod = Before
+    after = relax.transform.LiftTransformParams()(mod)
+    tvm.ir.assert_structural_equal(after, Expected)
+

Review Comment:
   Can we add a unit test for the case where the weights require symbolic 
variables that are used, but cannot be defined from, the shapes of the weights?  
This would require the `DefinableTIRVarsInStructInfo` usage described earlier.  
The following input would be able to trigger this case.
   
   ```python
   def test_symbolic_var_defined_in_params_but_used_in_weights():
       """A symbolic variable's occurrence in the weights may not define it
   
       In order to be a source of definition, a symbolic variable in the
       parameters must occur as a distinct parameter, as a tensor shape
       `R.Tensor(["var"])`, an explicit `R.Shape(["var"])`, or as a
       `R.Prim(value="var")`.  A variable that is part of a larger
       expression, such as `R.Tensor(["m * n"])`, is a variable usage,
       not a variable definition.
       """
   
       @tvm.script.ir_module
       class Before:
           @R.function
           def main(
               x: R.Tensor(["m", "n"], "float32"),
               weight: R.Tensor(["m * n"], "float32"),
           ) -> R.Tensor(["m", "n"], "float32"):
               m = T.int64()
               n = T.int64()
               R.func_attr({"num_input": 1})
               with R.dataflow():
                   weight = R.add(weight, R.const(1, "float32"))
                   weight = R.reshape(weight, [m, n])
                   output = R.multiply(x, weight)
                   R.output(output)
               return output
   
       @tvm.script.ir_module
       class Expected:
           pass  # TODO
   
       After = relax.transform.LiftTransformParams()(Before)
       tvm.ir.assert_structural_equal(Expected, After)
   ```



##########
src/relax/transform/lift_transform_params.cc:
##########
@@ -237,6 +238,13 @@ class LiftTransformParamsPlanner : public ExprVisitor {
         builder_.UpdateBasedOnRuntimeInput(function->params[i]);
       } else {
         builder_.AddInput(function->params[i]);
+        if (function->params[i]->struct_info_.defined()) {
+          Array<tir::Var> symbolic_vars =
+              
TIRVarsInStructInfo(Downcast<StructInfo>(function->params[i]->struct_info_.value()));

Review Comment:
   Nit: This should be 
[`DefinableTIRVarsInStructInfo`](https://github.com/apache/tvm/blob/unity/include/tvm/relax/analysis.h#L271-L283),
 as not all symbolic variables that appear in the struct info can be used to 
define a symbolic variable.  For example, in `def func(A: R.Tensor(["m", "n"]), 
B: R.Tensor(["m * n"]))`, the struct info of `A` can be used to define `m` and 
`n`, but the struct info of `B` cannot.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to