This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 5618628586 [TVMScript][Relax] Preserve tir.SizeVar through TVMScript
round-trip (#17083)
5618628586 is described below
commit 561862858661aca27ecd6d0d14fb30b03ad9acab
Author: Eric Lunderberg <[email protected]>
AuthorDate: Thu Jun 13 06:50:20 2024 -0500
[TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip
(#17083)
* [TVMScript][Relax] Preserve tir.SizeVar through TVMScript round-trip
Prior to this commit, all symbolic variables were printed identically,
regardless of whether the underlying variable was a `tir.Var` or
`tir.SizeVar`, and the parser re-created each of them as a plain
`tir.Var`. As a result, numeric simplifications that rely on a
`tir.SizeVar` being non-negative may be skipped after a round-trip
through TVMScript.
This commit updates the TVMScript printing and parsing of Relax
functions to use `var = T.int64(is_size_var=True)` for `tir.SizeVar`,
matching how `tir.SizeVar` is parsed for TIR functions. As an added
benefit, this also allows symbolic variables used in Relax `R.Prim`
arguments to have a dtype other than `int64`. This may be useful in
the future, such as to specify the fill value for `R.full`.
* Remove the `strict=True` argument to `zip`, which is not available until Python 3.10
* lint fix
* Fix breakage in unit tests
---
python/tvm/script/parser/relax/parser.py | 46 +++++++++++++++++++---
src/script/printer/relax/tir.cc | 3 +-
tests/python/tvmscript/test_tvmscript_roundtrip.py | 28 +++++++++++++
3 files changed, 71 insertions(+), 6 deletions(-)
diff --git a/python/tvm/script/parser/relax/parser.py
b/python/tvm/script/parser/relax/parser.py
index 400c023aa7..08269ddeeb 100644
--- a/python/tvm/script/parser/relax/parser.py
+++ b/python/tvm/script/parser/relax/parser.py
@@ -68,7 +68,14 @@ def bind_assign_value(
"Expected the same dtype for TIR vars "
f"but got {value.dtype} vs {prev_value.dtype}",
)
- return prev_value
+ if not isinstance(value, type(prev_value)):
+ self.report_error(
+ node,
+ f"Expected the same IR type for TIR vars "
+ f"but existing value {type(value)} is mismatched "
+ f"to previous {type(prev_value)}",
+ )
+ value = prev_value
IRBuilder.name(var_name, value)
return value
@@ -144,18 +151,47 @@ def is_recursive(node: doc.FunctionDef) -> bool:
return False
+def collect_symbolic_var_from_prelude(
+ self: Parser, node: doc.FunctionDef, symbolic_vars: Dict[str, tir.Var]
+) -> Dict[str, tir.Var]:
+ prelude_vars = {}
+ for stmt in node.body:
+ if isinstance(stmt, doc.Assign) and all(
+ isinstance(target, doc.Name) and target.id in symbolic_vars for
target in stmt.targets
+ ):
+ values = self.eval_expr(stmt.value)
+
+ try:
+ iter(values)
+ except TypeError:
+ values = [values]
+
+ assert len(stmt.targets) == len(values)
+ for target, value in zip(stmt.targets, values):
+ name = target.id
+ prelude_vars[name] = value
+
+ return {**symbolic_vars, **prelude_vars}
+
+
def collect_symbolic_var_from_params(self: Parser, node: doc.FunctionDef) ->
None:
# Collect symbolic vars from parameters
- symbolic_vars = set()
+ symbolic_vars = {}
for arg in node.args.args:
if arg.annotation is None:
self.report_error(arg, "Type annotation is required for function
parameters.")
param_sinfo_proxy = eval_struct_info_proxy(self, arg.annotation)
- symbolic_vars.update(param_sinfo_proxy.get_symbolic_vars())
+
+ for var_name in param_sinfo_proxy.get_symbolic_vars():
+ if var_name not in symbolic_vars:
+ symbolic_vars[var_name] = tir.Var(var_name, "int64")
+
+ # Update symbolic vars based on their definitions in the function body (e.g. `N = T.int64()`)
+ symbolic_vars = collect_symbolic_var_from_prelude(self, node,
symbolic_vars)
# Define symbolic vars to the current var_table frame
- for var_name in symbolic_vars:
- self.var_table.add(var_name, tir.Var(var_name, "int64"),
allow_shadowing=False)
+ for var_name, var in symbolic_vars.items():
+ self.var_table.add(var_name, var, allow_shadowing=False)
@dispatch.register(token="relax", type_name="FunctionDef")
diff --git a/src/script/printer/relax/tir.cc b/src/script/printer/relax/tir.cc
index 1a9c5d0546..6f9a8cbf89 100644
--- a/src/script/printer/relax/tir.cc
+++ b/src/script/printer/relax/tir.cc
@@ -18,6 +18,7 @@
*/
#include <tvm/ir/expr.h>
+#include "../tir/utils.h"
#include "./utils.h"
namespace tvm {
@@ -59,7 +60,7 @@ Doc PrintTIRVar(tir::Var n, ObjectPath n_p, IRDocsifier d) {
}
IdDoc var = d->Define(n, GetRef<Frame>(f), n->name_hint.empty() ? "v" :
n->name_hint);
var->source_paths.push_back(n_p);
- f->stmts.push_back(AssignDoc(var, TIR(d, DType2Str(n->dtype))->Call({}),
NullOpt));
+ f->stmts.push_back(AssignDoc(var, PrintVarCreation(n, n_p, d), NullOpt));
}
if (Optional<ExprDoc> doc = d->GetVarDoc(n)) {
return doc.value();
diff --git a/tests/python/tvmscript/test_tvmscript_roundtrip.py
b/tests/python/tvmscript/test_tvmscript_roundtrip.py
index ee404f08ef..f81a80de6d 100644
--- a/tests/python/tvmscript/test_tvmscript_roundtrip.py
+++ b/tests/python/tvmscript/test_tvmscript_roundtrip.py
@@ -4088,6 +4088,32 @@ def relax_match_cast_struct_info_proxy():
yield make_ir_generator(subclass)
+def relax_symbolic_size_var():
+ """Relax symbolic variables may be SizeVar"""
+ N = tvm.tir.SizeVar("N", "int64")
+
+ @R.function
+ def func(A: R.Tensor([N], "float16")):
+ B: R.Tensor([N], "float16") = A
+ return B
+
+ return func
+
+
+def relax_float_symbolic_var():
+ """Relax symbolic variables may hold any dtype"""
+
+ @R.function
+ def func(A: R.Tensor(["N"], "float16"), _: R.Prim(value="threshold")):
+ N = T.int64()
+ threshold = T.float16()
+
+ B = A >= R.prim_value(threshold / T.cast(N, "float16"))
+ return B
+
+ return func
+
+
ir_generator = tvm.testing.parameter(
launch_env_thread,
opt_gemm_normalize,
@@ -4174,6 +4200,8 @@ ir_generator = tvm.testing.parameter(
return_zero_private_with_attr,
*op_of_literal(),
*relax_match_cast_struct_info_proxy(),
+ relax_symbolic_size_var,
+ relax_float_symbolic_var,
)
relax_ir_generator = tvm.testing.parameter(