This is an automated email from the ASF dual-hosted git repository.
lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 72b75fe5b2 [Relax] Validate StructInfo of variable bindings (#17332)
72b75fe5b2 is described below
commit 72b75fe5b2f34765892b6ae3ba8709bad318b7bd
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 11 08:34:17 2024 -0500
[Relax] Validate StructInfo of variable bindings (#17332)
* [Relax] Validate StructInfo of variable bindings
In Relax, both the variable and the expression in a `VarBinding` may
contain `StructInfo` annotations. Prior to this commit, these
`StructInfo` annotations could be inconsistent, assigning an
expression to a variable of incompatible type.
This commit updates the Relax well-formed checker to verify that the
`StructInfo` of Relax variables accurately describes their contents.
* Fix unit tests
* [Relax][Bugfix] LCA of PrimStructInfo must check known values
The `StructInfoLCA` determines the lowest common ancestor between two
`StructInfo` annotations. This is primarily used in Relax to
determine the appropriate `StructInfo` annotation for a `relax::If`
node, given the `StructInfo` of each branch. Prior to this commit,
when determining the LCA of two `PrimStructInfo` annotations, the
`StructInfoLCA` function only inspected the datatype of
`PrimStructInfo` annotations, and did not check for known values. For
example, the LCA of `R.Prim(value=T.int64(128))` and
`R.Prim(value=T.int64(64))` is `R.Prim("int64")`, but was incorrectly
determined as `R.Prim(value=T.int64(128))` by the `StructInfoLCA`
function.
This commit updates `StructInfoLCA` to inspect the known values of a
`PrimStructInfo`, as well as the datatype.
---
src/relax/analysis/struct_info_analysis.cc | 23 +++++-
src/relax/analysis/well_formed.cc | 12 +++
src/relax/transform/normalize.cc | 6 +-
.../relax/test_analysis_struct_info_analysis.py | 94 +++++++++++++++++++++-
tests/python/relax/test_analysis_well_formed.py | 87 ++++++++++++++++++++
5 files changed, 216 insertions(+), 6 deletions(-)
diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index a7e5404c20..6fe8f36020 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -982,10 +982,25 @@ class StructInfoLCAFinder
StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo&
other) final {
auto* rhs = other.as<PrimStructInfoNode>();
if (rhs == nullptr) return ObjectStructInfo(lhs->span);
- if (lhs->dtype == rhs->dtype) return GetRef<StructInfo>(lhs);
- // PrimType will be treated as their boxed(object) values
- // as a result we can unify to object.
- return ObjectStructInfo(lhs->span);
+ if (lhs->dtype != rhs->dtype) {
+ // PrimType will be treated as their boxed(object) values
+ // as a result we can unify to object.
+ return ObjectStructInfo(lhs->span);
+ }
+ if (!lhs->value.defined() || !rhs->value.defined() ||
+ !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
+ // The two values are known to contain the same dtype, but may
+ // contain different values.
+ if (!lhs->value.defined()) {
+ // If the mismatch was due to extra information in the RHS,
+ // prefer to avoid constructing a new object.
+ return GetRef<StructInfo>(lhs);
+ } else {
+ return PrimStructInfo(lhs->dtype, lhs->span);
+ }
+ }
+
+ return GetRef<StructInfo>(lhs);
}
StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const
StructInfo& other) final {
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 235059ece2..7688c4a642 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -429,6 +429,18 @@ class WellFormedChecker : public relax::ExprVisitor,
}
this->VisitVarDef(binding->var);
+
+ if (check_struct_info_ && binding->var->struct_info_.defined() &&
+ binding->value->struct_info_.defined()) {
+ auto expr_sinfo = GetStructInfo(binding->value);
+ auto var_sinfo = GetStructInfo(binding->var);
+ if (!IsBaseOf(var_sinfo, expr_sinfo)) {
+ Malformed(Diagnostic::Error(binding->var)
+ << "Expression of type " << expr_sinfo
+ << " cannot be assigned to a variable of type " <<
var_sinfo);
+ }
+ }
+
if (is_lambda) {
recur_vars_.erase(binding->var);
}
diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
index 89080ebc3e..5493b44f82 100644
--- a/src/relax/transform/normalize.cc
+++ b/src/relax/transform/normalize.cc
@@ -65,7 +65,11 @@ class NormalizeMutator : public ExprMutatorBase {
Expr VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params =
NullOpt) {
builder_->BeginBindingBlock();
- builder_->BeginScope(params);
+ if (params.defined()) {
+ builder_->BeginScope(params);
+ } else {
+ builder_->BeginInnerScope();
+ }
Expr ret = this->VisitExpr(expr);
BindingBlock prologue = builder_->EndBlock();
if (!prologue->bindings.empty()) {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 83b1ddd4fc..b2931549e9 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -24,7 +24,7 @@ import tvm.testing
from tvm import TVMError
from tvm import relax as rx
from tvm import tir, ir
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
def test_get_static_type_basic():
@@ -620,6 +620,98 @@ def test_struct_info_lca():
_check_lca(fopaque2(), fn_info_shape(1), fopaque2())
+def _generate_prim_test_cases():
+ dtypes = [
+ "bool",
+ "int8",
+ "uint8",
+ "int16",
+ "uint16",
+ "int32",
+ "uint32",
+ "int64",
+ "uint64",
+ "float16",
+ "float32",
+ "float64",
+ ]
+
+ for dtype in dtypes:
+ # LCA of a PrimStructInfo with itself yields itself
+ yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype))
+
+ # The LCA of two values, each statically known to be the same
+ # value, is known to have that value.
+ yield (
+ R.Prim(value=tir.const(0, dtype)),
+ R.Prim(value=tir.const(0, dtype)),
+ R.Prim(value=tir.const(0, dtype)),
+ )
+
+ # The LCA of two values, each of which is statically known to
+ # have a different value, no longer knows the contained value.
+ yield (
+ R.Prim(value=tir.const(0, dtype)),
+ R.Prim(value=tir.const(1, dtype)),
+ R.Prim(dtype=dtype),
+ )
+
+ # LCA of a known variable with itself yields itself
+ var_N = tir.Var("N", dtype)
+ yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N))
+
+ # LCA of a known variable with a known static value is no
+ # longer known to have a specific value.
+ yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)),
R.Prim(dtype=dtype))
+ yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N),
R.Prim(dtype=dtype))
+
+ var_M = tir.Var("M", dtype)
+ yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype))
+
+ for dtype_a in dtypes:
+ for dtype_b in dtypes:
+ if dtype_a != dtype_b:
+ # Unlike R.Tensor, R.Prim does not currently support a
+ # value with an unknown datatype. If the dtype
+ # differs between the two annotations, the next wider
+ # category is R.Object.
+ yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object)
+
+ # Because the dtypes are different, even `R.Prim` containing
+ # the same value in different representations (e.g.
+ # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`.
+ yield (
+ R.Prim(value=tir.const(0, dtype_a)),
+ R.Prim(value=tir.const(0, dtype_b)),
+ R.Object,
+ )
+
+ # And the same is true for known variable values
+ var_N = tir.Var("N", dtype_a)
+ var_M = tir.Var("M", dtype_b)
+ yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object)
+
+
[email protected]("test_case", list(_generate_prim_test_cases()))
+def test_prim_struct_info_lca(test_case):
+ def _normalize_sinfo(sinfo):
+ if isinstance(sinfo, tvm.relax.StructInfo):
+ return sinfo
+ elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy):
+ return sinfo.as_struct_info()
+ elif callable(sinfo):
+ return sinfo()
+ else:
+ raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo")
+
+ lhs, rhs, expected = map(_normalize_sinfo, test_case)
+
+ lca = rx.analysis.struct_info_lca(lhs, rhs)
+ assert tvm.ir.structural_equal(
+ lca, expected
+ ), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found
{lca}"
+
+
def _generate_tir_var_test_cases():
n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
shape0 = rx.ShapeStructInfo([1, n, 3])
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index c0b962c3f3..3db3efee1a 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -1208,5 +1208,92 @@ def test_call_tir_inplace_with_some_allocated_outputs():
assert rx.analysis.well_formed(Module)
+def test_var_binding_must_have_compatible_struct_info():
+ """Variables must accurately describe their contents
+
+ To be well-formed, the inferred struct info must not conflict with
+ the StructInfo annotations.
+
+ """
+
+ # The function is equivalent to the TVMScript below. However,
+ # TVMScript applies additional checks that would catch this error
+ # while parsing. In order to validate the well-formed checker
+ # itself, this test directly constructs the function without using
+ # TVMScript, skipping the TVMScript-specific checks.
+ #
+ # @R.function
+ # def main(
+ # A: R.Tensor(shape=[128, 32], dtype="float32"),
+ # ):
+ # B: R.Tensor(shape=[128, 32], dtype="int32") = A
+ # return B
+
+ param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
+ var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32"))
+ binding = tvm.relax.VarBinding(var, param)
+ body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
+ tvm.relax.expr._update_struct_info(body, var.struct_info)
+ main = tvm.relax.Function([param], body)
+
+ assert not rx.analysis.well_formed(main)
+
+
+def test_var_binding_may_have_less_constrained_struct_info():
+ """StructInfo of variable may be less specific than expression
+
+ The StructInfo annotation of a variable is not required to be an
+ exact match to the expression's StructInfo, and may provide less
+ specific information than the inference would provide.
+
+ """
+
+ @I.ir_module
+ class Module:
+ @R.function
+ def main(
+ A: R.Tensor(shape=[128, 32], dtype="float32"),
+ ):
+ B: R.Object = R.add(A, A)
+ return B
+
+ assert isinstance(
+ Module["main"].body.blocks[0].bindings[0].var.struct_info,
tvm.relax.ObjectStructInfo
+ ), "Validity of this test requires a variable with R.Object struct info"
+
+ assert rx.analysis.well_formed(Module)
+
+
+def test_var_binding_with_incomplete_struct_info_must_be_consistent():
+ """StructInfo of variable must be accurate
+
+ Even though StructInfo annotations may be less specific, the
+ information that they do contain must be correct.
+
+ """
+
+ # The function is equivalent to the TVMScript below. However,
+ # TVMScript applies additional checks that would catch this error
+ # while parsing. In order to validate the well-formed checker
+ # itself, this test directly constructs the function without using
+ # TVMScript, skipping the TVMScript-specific checks.
+ #
+ # @R.function
+ # def main(
+ # A: R.Tensor(shape=[128, 32], dtype="float32"),
+ # ):
+ # B: R.Tensor(ndim=3, dtype="int32") = A
+ # return B
+
+ param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
+ var = tvm.relax.Var("B", R.Tensor(ndim=3, dtype="int32"))
+ binding = tvm.relax.VarBinding(var, param)
+ body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
+ tvm.relax.expr._update_struct_info(body, var.struct_info)
+ main = tvm.relax.Function([param], body)
+
+ assert not rx.analysis.well_formed(main)
+
+
if __name__ == "__main__":
tvm.testing.main()