This is an automated email from the ASF dual-hosted git repository.

lunderberg pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 72b75fe5b2 [Relax] Validate StructInfo of variable bindings (#17332)
72b75fe5b2 is described below

commit 72b75fe5b2f34765892b6ae3ba8709bad318b7bd
Author: Eric Lunderberg <[email protected]>
AuthorDate: Wed Sep 11 08:34:17 2024 -0500

    [Relax] Validate StructInfo of variable bindings (#17332)
    
    * [Relax] Validate StructInfo of variable bindings
    
    In Relax, both the variable and the expression in a `VarBinding` may
    contain `StructInfo` annotations.  Prior to this commit, these
    `StructInfo` annotations could be inconsistent, assigning an
    expression to a variable of incompatible type.
    
    This commit updates the Relax well-formed checker to verify that the
    `StructInfo` of Relax variables accurately describes their contents.
    
    * Fix unit tests
    
    * [Relax][Bugfix] LCA of PrimStructInfo must check known values
    
    The `StructInfoLCA` determines the lowest common ancestor between two
    `StructInfo` annotations.  This is primarily used in Relax to
    determine the appropriate `StructInfo` annotation for a `relax::If`
    node, given the `StructInfo` of each branch.  Prior to this commit,
    when determining the LCA of two `PrimStructInfo` annotations, the
    `StructInfoLCA` function only inspected the datatype of
    `PrimStructInfo` annotations, and did not check for known values.  For
    example, the LCA of `R.Prim(value=T.int64(128))` and
    `R.Prim(value=T.int64(64))` is `R.Prim("int64")`, but was incorrectly
    determined as `R.Prim(value=T.int64(128))` by the `StructInfoLCA`
    function.
    
    This commit updates `StructInfoLCA` to inspect the known values of a
    `PrimStructInfo`, as well as the datatype.
---
 src/relax/analysis/struct_info_analysis.cc         | 23 +++++-
 src/relax/analysis/well_formed.cc                  | 12 +++
 src/relax/transform/normalize.cc                   |  6 +-
 .../relax/test_analysis_struct_info_analysis.py    | 94 +++++++++++++++++++++-
 tests/python/relax/test_analysis_well_formed.py    | 87 ++++++++++++++++++++
 5 files changed, 216 insertions(+), 6 deletions(-)

diff --git a/src/relax/analysis/struct_info_analysis.cc b/src/relax/analysis/struct_info_analysis.cc
index a7e5404c20..6fe8f36020 100644
--- a/src/relax/analysis/struct_info_analysis.cc
+++ b/src/relax/analysis/struct_info_analysis.cc
@@ -982,10 +982,25 @@ class StructInfoLCAFinder
   StructInfo VisitStructInfo_(const PrimStructInfoNode* lhs, const StructInfo& other) final {
     auto* rhs = other.as<PrimStructInfoNode>();
     if (rhs == nullptr) return ObjectStructInfo(lhs->span);
-    if (lhs->dtype == rhs->dtype) return GetRef<StructInfo>(lhs);
-    // PrimType will be treated as their boxed(object) values
-    // as a result we can unify to object.
-    return ObjectStructInfo(lhs->span);
+    if (lhs->dtype != rhs->dtype) {
+      // PrimType will be treated as their boxed(object) values
+      // as a result we can unify to object.
+      return ObjectStructInfo(lhs->span);
+    }
+    if (!lhs->value.defined() || !rhs->value.defined() ||
+        !analyzer_->CanProveEqual(lhs->value.value(), rhs->value.value())) {
+      // The two values are known to contain the same dtype, but may
+      // contain different values.
+      if (!lhs->value.defined()) {
+        // If the mismatch was due to extra information in the RHS,
+        // prefer to avoid constructing a new object.
+        return GetRef<StructInfo>(lhs);
+      } else {
+        return PrimStructInfo(lhs->dtype, lhs->span);
+      }
+    }
+
+    return GetRef<StructInfo>(lhs);
   }
 
   StructInfo VisitStructInfo_(const ShapeStructInfoNode* lhs, const StructInfo& other) final {
diff --git a/src/relax/analysis/well_formed.cc b/src/relax/analysis/well_formed.cc
index 235059ece2..7688c4a642 100644
--- a/src/relax/analysis/well_formed.cc
+++ b/src/relax/analysis/well_formed.cc
@@ -429,6 +429,18 @@ class WellFormedChecker : public relax::ExprVisitor,
     }
 
     this->VisitVarDef(binding->var);
+
+    if (check_struct_info_ && binding->var->struct_info_.defined() &&
+        binding->value->struct_info_.defined()) {
+      auto expr_sinfo = GetStructInfo(binding->value);
+      auto var_sinfo = GetStructInfo(binding->var);
+      if (!IsBaseOf(var_sinfo, expr_sinfo)) {
+        Malformed(Diagnostic::Error(binding->var)
+                  << "Expression of type " << expr_sinfo
+                  << " cannot be assigned to a variable of type " << var_sinfo);
+      }
+    }
+
     if (is_lambda) {
       recur_vars_.erase(binding->var);
     }
diff --git a/src/relax/transform/normalize.cc b/src/relax/transform/normalize.cc
index 89080ebc3e..5493b44f82 100644
--- a/src/relax/transform/normalize.cc
+++ b/src/relax/transform/normalize.cc
@@ -65,7 +65,11 @@ class NormalizeMutator : public ExprMutatorBase {
 
   Expr VisitWithNewScope(const Expr& expr, Optional<Array<Var>> params = NullOpt) {
     builder_->BeginBindingBlock();
-    builder_->BeginScope(params);
+    if (params.defined()) {
+      builder_->BeginScope(params);
+    } else {
+      builder_->BeginInnerScope();
+    }
     Expr ret = this->VisitExpr(expr);
     BindingBlock prologue = builder_->EndBlock();
     if (!prologue->bindings.empty()) {
diff --git a/tests/python/relax/test_analysis_struct_info_analysis.py b/tests/python/relax/test_analysis_struct_info_analysis.py
index 83b1ddd4fc..b2931549e9 100644
--- a/tests/python/relax/test_analysis_struct_info_analysis.py
+++ b/tests/python/relax/test_analysis_struct_info_analysis.py
@@ -24,7 +24,7 @@ import tvm.testing
 from tvm import TVMError
 from tvm import relax as rx
 from tvm import tir, ir
-from tvm.script import relax as R
+from tvm.script import relax as R, tir as T
 
 
 def test_get_static_type_basic():
@@ -620,6 +620,98 @@ def test_struct_info_lca():
     _check_lca(fopaque2(), fn_info_shape(1), fopaque2())
 
 
+def _generate_prim_test_cases():
+    dtypes = [
+        "bool",
+        "int8",
+        "uint8",
+        "int16",
+        "uint16",
+        "int32",
+        "uint32",
+        "int64",
+        "uint64",
+        "float16",
+        "float32",
+        "float64",
+    ]
+
+    for dtype in dtypes:
+        # LCA of a PrimStructInfo with itself yields itself
+        yield (R.Prim(dtype), R.Prim(dtype), R.Prim(dtype))
+
+        # The LCA of two values, each statically known to be the same
+        # value, is known to have that value.
+        yield (
+            R.Prim(value=tir.const(0, dtype)),
+            R.Prim(value=tir.const(0, dtype)),
+            R.Prim(value=tir.const(0, dtype)),
+        )
+
+        # The LCA of two values, each of which is statically known to
+        # have a different value, no longer knows the contained value.
+        yield (
+            R.Prim(value=tir.const(0, dtype)),
+            R.Prim(value=tir.const(1, dtype)),
+            R.Prim(dtype=dtype),
+        )
+
+        # LCA of a known variable with itself yields itself
+        var_N = tir.Var("N", dtype)
+        yield (R.Prim(value=var_N), R.Prim(value=var_N), R.Prim(value=var_N))
+
+        # LCA of a known variable with a known static value is no
+        # longer known to have a specific value.
+        yield (R.Prim(value=var_N), R.Prim(value=tir.const(0, dtype)), R.Prim(dtype=dtype))
+        yield (R.Prim(value=tir.const(0, dtype)), R.Prim(value=var_N), R.Prim(dtype=dtype))
+
+        var_M = tir.Var("M", dtype)
+        yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Prim(dtype=dtype))
+
+    for dtype_a in dtypes:
+        for dtype_b in dtypes:
+            if dtype_a != dtype_b:
+                # Unlike R.Tensor, R.Prim does not currently support a
+                # value with an unknown datatype.  If the dtype
+                # differs between the two annotations, the next wider
+                # category is R.Object.
+                yield (R.Prim(dtype_a), R.Prim(dtype_b), R.Object)
+
+                # Because the dtypes are different, even `R.Prim` containing
+                # the same value in different representations (e.g.
+                # `T.float32(0)` vs `T.float16(0)`) fall back to `R.Object`.
+                yield (
+                    R.Prim(value=tir.const(0, dtype_a)),
+                    R.Prim(value=tir.const(0, dtype_b)),
+                    R.Object,
+                )
+
+                # And the same is true for known variable values
+                var_N = tir.Var("N", dtype_a)
+                var_M = tir.Var("M", dtype_b)
+                yield (R.Prim(value=var_N), R.Prim(value=var_M), R.Object)
+
+
+@pytest.mark.parametrize("test_case", list(_generate_prim_test_cases()))
+def test_prim_struct_info_lca(test_case):
+    def _normalize_sinfo(sinfo):
+        if isinstance(sinfo, tvm.relax.StructInfo):
+            return sinfo
+        elif isinstance(sinfo, tvm.script.parser.relax.entry.StructInfoProxy):
+            return sinfo.as_struct_info()
+        elif callable(sinfo):
+            return sinfo()
+        else:
+            raise TypeError(f"Cannot normalize {type(sinfo)} to StructInfo")
+
+    lhs, rhs, expected = map(_normalize_sinfo, test_case)
+
+    lca = rx.analysis.struct_info_lca(lhs, rhs)
+    assert tvm.ir.structural_equal(
+        lca, expected
+    ), f"Expected {lhs} and {rhs} to have LCA of {expected}, but instead found {lca}"
+
+
 def _generate_tir_var_test_cases():
     n, m = tir.Var("n", "int64"), tir.Var("m", "int64")
     shape0 = rx.ShapeStructInfo([1, n, 3])
diff --git a/tests/python/relax/test_analysis_well_formed.py b/tests/python/relax/test_analysis_well_formed.py
index c0b962c3f3..3db3efee1a 100644
--- a/tests/python/relax/test_analysis_well_formed.py
+++ b/tests/python/relax/test_analysis_well_formed.py
@@ -1208,5 +1208,92 @@ def test_call_tir_inplace_with_some_allocated_outputs():
     assert rx.analysis.well_formed(Module)
 
 
+def test_var_binding_must_have_compatible_struct_info():
+    """Variables must accurately describe their contents
+
+    To be well-formed, the inferred struct info must not conflict with
+    the StructInfo annotations.
+
+    """
+
+    # The function is equivalent to the TVMScript below.  However,
+    # TVMScript applies additional checks that would catch this error
+    # while parsing.  In order to validate the well-formed checker
+    # itself, this test directly constructs the function without using
+    # TVMScript, skipping the TVMScript-specific checks.
+    #
+    # @R.function
+    # def main(
+    #     A: R.Tensor(shape=[128, 32], dtype="float32"),
+    # ):
+    #     B: R.Tensor(shape=[128, 32], dtype="int32") = A
+    #     return B
+
+    param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
+    var = tvm.relax.Var("B", R.Tensor(shape=[128, 32], dtype="int32"))
+    binding = tvm.relax.VarBinding(var, param)
+    body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
+    tvm.relax.expr._update_struct_info(body, var.struct_info)
+    main = tvm.relax.Function([param], body)
+
+    assert not rx.analysis.well_formed(main)
+
+
+def test_var_binding_may_have_less_constrained_struct_info():
+    """StructInfo of variable may be less specific than expression
+
+    The StructInfo annotation of a variable is not required to be an
+    exact match to the expression's StructInfo, and may provide less
+    specific information than the inference would provide.
+
+    """
+
+    @I.ir_module
+    class Module:
+        @R.function
+        def main(
+            A: R.Tensor(shape=[128, 32], dtype="float32"),
+        ):
+            B: R.Object = R.add(A, A)
+            return B
+
+    assert isinstance(
+        Module["main"].body.blocks[0].bindings[0].var.struct_info, tvm.relax.ObjectStructInfo
+    ), "Validity of this test requires a variable with R.Object struct info"
+
+    assert rx.analysis.well_formed(Module)
+
+
+def test_var_binding_with_incomplete_struct_info_must_be_consistent():
+    """StructInfo of variable must be accurate
+
+    Even though StructInfo annotation may be less specific, the
+    information that they do contain must be correct.
+
+    """
+
+    # The function is equivalent to the TVMScript below.  However,
+    # TVMScript applies additional checks that would catch this error
+    # while parsing.  In order to validate the well-formed checker
+    # itself, this test directly constructs the function without using
+    # TVMScript, skipping the TVMScript-specific checks.
+    #
+    #   @R.function
+    #   def main(
+    #       A: R.Tensor(shape=[128, 32], dtype="float32"),
+    #   ):
+    #       B: R.Tensor(ndim=3) = A
+    #       return B
+
+    param = tvm.relax.Var("A", R.Tensor(shape=[128, 32], dtype="float32"))
+    # Only `ndim` is specified (and wrong); leaving dtype unspecified keeps
+    # the annotation less specific, so the ndim mismatch alone must be caught.
+    var = tvm.relax.Var("B", R.Tensor(ndim=3))
+    binding = tvm.relax.VarBinding(var, param)
+    body = tvm.relax.SeqExpr([tvm.relax.BindingBlock([binding])], var)
+    tvm.relax.expr._update_struct_info(body, var.struct_info)
+    main = tvm.relax.Function([param], body)
+
+    assert not rx.analysis.well_formed(main)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to