This is an automated email from the ASF dual-hosted git repository.
hongyij pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/unity by this push:
new 9100a8e8c9 [Unity] [LiftTransformParams] Treat symbolic var in weight
shape as constant (#16049)
9100a8e8c9 is described below
commit 9100a8e8c9906a2cd93730c1951f19b13de60de6
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed Nov 8 11:58:15 2023 -0800
[Unity] [LiftTransformParams] Treat symbolic var in weight shape as
constant (#16049)
* treat symbolic var in param shape as constant
* address comment
* add xfail test
---
src/relax/transform/lift_transform_params.cc | 19 ++-
.../relax/test_transform_lift_transform_params.py | 155 +++++++++++++++++++++
2 files changed, 171 insertions(+), 3 deletions(-)
diff --git a/src/relax/transform/lift_transform_params.cc
b/src/relax/transform/lift_transform_params.cc
index cef19ff068..b500a3c3a3 100644
--- a/src/relax/transform/lift_transform_params.cc
+++ b/src/relax/transform/lift_transform_params.cc
@@ -32,6 +32,7 @@
#include <vector>
#include "../../support/ordered_set.h"
+#include "utils.h"
namespace tvm {
namespace relax {
@@ -237,6 +238,13 @@ class LiftTransformParamsPlanner : public ExprVisitor {
builder_.UpdateBasedOnRuntimeInput(function->params[i]);
} else {
builder_.AddInput(function->params[i]);
+ if (function->params[i]->struct_info_.defined()) {
+ Array<tir::Var> symbolic_vars = DefinableTIRVarsInStructInfo(
+ Downcast<StructInfo>(function->params[i]->struct_info_.value()));
+ for (const auto& var : symbolic_vars) {
+ param_symbolic_vars_.insert(var);
+ }
+ }
}
}
VisitExpr(function->body);
@@ -275,9 +283,12 @@ class LiftTransformParamsPlanner : public ExprVisitor {
can_lift = false;
}
- // Cond 4. Do not lift when its struct info contains symbolic variables.
- if (!TIRVarsInStructInfo(GetStructInfo(binding->var)).empty()) {
- can_lift = false;
+ // Cond 4. Do not lift when its struct info contains symbolic variables
that do not appear in
+ // params.
+ for (const auto& var : TIRVarsInStructInfo(GetStructInfo(binding->var))) {
+ if (!param_symbolic_vars_.count(var)) {
+ can_lift = false;
+ }
}
// Cond 5. Do not lift declarations of external functions
@@ -296,6 +307,8 @@ class LiftTransformParamsPlanner : public ExprVisitor {
TransformParamsFuncBuilder builder_;
// Whether we are in a dataflow block
bool is_in_dataflow_block_{false};
+ // The symbolic variables in the parameters
+ std::unordered_set<tir::Var, ObjectPtrHash, ObjectPtrEqual>
param_symbolic_vars_;
};
/*!
diff --git a/tests/python/relax/test_transform_lift_transform_params.py
b/tests/python/relax/test_transform_lift_transform_params.py
index 7389060bde..5b24614469 100644
--- a/tests/python/relax/test_transform_lift_transform_params.py
+++ b/tests/python/relax/test_transform_lift_transform_params.py
@@ -15,6 +15,8 @@
# specific language governing permissions and limitations
# under the License.
+import pytest
+
import tvm
import tvm.testing
from tvm import relax
@@ -642,5 +644,158 @@ def test_symbolic_var_from_shape():
tvm.ir.assert_structural_equal(Expected, after)
+def test_symbolic_var_in_param_shape():
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, "n"), "float32"),
+ w1: R.Tensor((16, "m", 3, 3), "float32"),
+ w2: R.Tensor((16, "m", 3, 3), "float32"),
+ ) -> R.Tensor((1, 16, 224, 224), "float32"):
+ m = T.int64()
+ n = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ zeros = R.zeros((n, n), "float32")
+ w1 = R.add(w1, R.const(1, "float32"))
+ conv1 = R.nn.conv2d(x, w1, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW")
+ conv2 = R.nn.conv2d(
+ conv1, w2, padding=(1, 1), data_layout="NCHW",
kernel_layout="OIHW"
+ )
+ R.output(conv2)
+ return conv2
+
+ @I.ir_module
+ class Expected:
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(
+ R.Tensor((16, "m", 3, 3), dtype="float32"),
+ R.Tensor((16, "m", 3, 3), dtype="float32"),
+ )
+ ) -> R.Tuple(
+ R.Tensor((16, "m", 3, 3), dtype="float32"), R.Tensor((16, "m", 3,
3), dtype="float32")
+ ):
+ m = T.int64()
+ with R.dataflow():
+ lv: R.Tensor((16, m, 3, 3), dtype="float32") = params[1]
+ lv1: R.Tensor((16, m, 3, 3), dtype="float32") = params[0]
+ lv2: R.Tensor((16, m, 3, 3), dtype="float32") = R.add(lv1,
R.const(1, "float32"))
+ gv: R.Tuple(
+ R.Tensor((16, m, 3, 3), dtype="float32"),
+ R.Tensor((16, m, 3, 3), dtype="float32"),
+ ) = (lv, lv2)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor((1, 16, 224, "n"), dtype="float32"),
+ transformed_param_0: R.Tensor((16, "m", 3, 3), dtype="float32"),
+ transformed_param_1: R.Tensor((16, "m", 3, 3), dtype="float32"),
+ ) -> R.Tensor((1, 16, 224, 224), dtype="float32"):
+ n = T.int64()
+ m = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ zeros: R.Tensor((n, n), dtype="float32") = R.zeros(R.shape([n,
n]), dtype="float32")
+ lv: R.Tensor((16, m, 3, 3), dtype="float32") =
transformed_param_1
+ conv1: R.Tensor((1, 16, 224, n), dtype="float32") =
R.nn.conv2d(
+ x,
+ lv,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ lv1: R.Tensor((16, m, 3, 3), dtype="float32") =
transformed_param_0
+ conv2: R.Tensor((1, 16, 224, n), dtype="float32") =
R.nn.conv2d(
+ conv1,
+ lv1,
+ strides=[1, 1],
+ padding=[1, 1, 1, 1],
+ dilation=[1, 1],
+ groups=1,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="NCHW",
+ out_dtype="void",
+ )
+ R.output(conv2)
+ return conv2
+
+ mod = Before
+ after = relax.transform.LiftTransformParams()(mod)
+ tvm.ir.assert_structural_equal(after, Expected)
+
+
+# not supported yet
[email protected]
+def test_symbolic_var_defined_in_params_but_used_in_weights():
+ """A symbolic variable's occurrence in the weights may not define it
+
+ In order to be a source of definition, a symbolic variable in the
+ parameters must occur as a distinct parameter, as a tensor shape
+ `R.Tensor(["var"])`, an explicit `R.Shape(["var"])`, or as a
+ `R.Prim(value="var")`. A variable that is part of a larger
+ expression, such as `R.Tensor(["m * n"])`, is a variable usage,
+ not a variable definition.
+ """
+
+ @tvm.script.ir_module
+ class Before:
+ @R.function
+ def main(
+ x: R.Tensor(["m", "n"], "float32"),
+ weight: R.Tensor(["m * n"], "float32"),
+ ) -> R.Tensor(["m", "n"], "float32"):
+ m = T.int64()
+ n = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ weight = R.add(weight, R.const(1, "float32"))
+ weight = R.reshape(weight, [m, n])
+ output = R.multiply(x, weight)
+ R.output(output)
+ return output
+
+ @tvm.script.ir_module
+ class Expected:
+ @R.function
+ def main_transform_params(
+ params: R.Tuple(R.Tensor(("k",), dtype="float32"))
+ ) -> R.Tuple(R.Tensor(dtype="float32", ndim=1)):
+ k = T.int64()
+ with R.dataflow():
+ lv: R.Tensor((k,), dtype="float32") = params[0]
+ gv: R.Tuple(R.Tensor((k,), dtype="float32")) = (lv,)
+ R.output(gv)
+ return gv
+
+ @R.function
+ def main(
+ x: R.Tensor(("m", "n"), dtype="float32"),
+ transformed_param_0: R.Tensor(dtype="float32", ndim=1),
+ ) -> R.Tensor(("m", "n"), dtype="float32"):
+ m = T.int64()
+ n = T.int64()
+ R.func_attr({"num_input": 1})
+ with R.dataflow():
+ lv: R.Tensor(dtype="float32", ndim=1) = transformed_param_0
+ weight: R.Tensor(dtype="float32", ndim=1) = R.add(lv,
R.const(1, "float32"))
+ weight_1: R.Tensor((m, n), dtype="float32") =
R.reshape(weight, R.shape([m, n]))
+ output: R.Tensor((m, n), dtype="float32") = R.multiply(x,
weight_1)
+ R.output(output)
+ return output
+
+ After = relax.transform.LiftTransformParams()(Before)
+ tvm.ir.assert_structural_equal(Expected, After)
+
+
if __name__ == "__main__":
tvm.testing.main()