This is an automated email from the ASF dual-hosted git repository.

echuraev pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 927df59662 [Relay] Disable exception for ADT in mixed precision pass (#15533)
927df59662 is described below

commit 927df5966237f10978319044716d93c90bf8843c
Author: Egor Churaev <[email protected]>
AuthorDate: Mon Aug 14 10:32:50 2023 +0300

    [Relay] Disable exception for ADT in mixed precision pass (#15533)
    
    If the topology contains a while loop and we want to transform it to
    mixed precision, we get the exception "ADT are not supported for
    mixed precision pass". This happens because a while loop is
    implemented as a lambda that is bound to a VarNode.

    This commit changes the behavior of the ToMixedPrecision pass:
    instead of raising an exception, it simply leaves such expressions
    unchanged.

    A corresponding regression test is added; a minimal reproduction
    sketch follows below.
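
    As a minimal reproduction sketch (hand-written for this note, not
    part of the commit; it mirrors the regression test in the diff below,
    and "float16" is just an example target precision):

        import tvm
        from tvm import relay
        from tvm.relay.transform import InferType, ToMixedPrecision

        # Build a trivial counting loop: while (i < 10): i = i + 1
        i = relay.var("i", shape=(), dtype="int32")
        loop = relay.loops.while_loop(
            lambda i: relay.op.less(i, relay.const(10, "int32")),
            [i],
            lambda i: (i + relay.const(1, "int32"),),
        )
        start = relay.var("start", shape=(), dtype="int32")
        func = relay.Function([start], relay.TupleGetItem(loop(start), 0))
        mod = tvm.IRModule()
        mod["main"] = func

        mod = InferType()(mod)
        # Before this commit, the next line hit the ADT ICHECK failure;
        # with the fix, the loop's Let/Var expressions pass through
        # unchanged.
        mod = ToMixedPrecision("float16")(mod)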
---
 src/relay/transforms/to_mixed_precision.cc    |  9 ++++---
 tests/python/relay/test_to_mixed_precision.py | 35 ++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 5 deletions(-)

diff --git a/src/relay/transforms/to_mixed_precision.cc b/src/relay/transforms/to_mixed_precision.cc
index 4638ee5477..5026b1bcba 100644
--- a/src/relay/transforms/to_mixed_precision.cc
+++ b/src/relay/transforms/to_mixed_precision.cc
@@ -350,10 +350,11 @@ class MixedPrecisionPass : public MixedModeMutator {
 
     // TODO(AndrewZhaoLuo): Support ADTs
     // Relay's algebraic data types are not supported yet.
-    ICHECK(!cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
-           && !cur_op.as<ConstructorNode>()  // constructing ADT types
-           && !cur_op.as<VarNode>())         // used for calling recursive functions
-        << "Algebraic Data Types (ADT) are not supported yet for mixed precision pass.";
+    bool isADT = (cur_op.as<GlobalVarNode>()       // used to declare functions for recursion
+                  || cur_op.as<ConstructorNode>()  // constructing ADT types
+                  || cur_op.as<LetNode>()          // used for binding lambdas
+                  || cur_op.as<VarNode>());        // used for calling recursive functions
+    if (isADT) return post;
 
     // Get info on the operation being called:
     // conversion category (int), accumulation dtype (str), output dtype (str)
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index a802eee6d6..4c97642498 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -49,7 +49,6 @@ def verify_mixed_precision_output_close(
     atol: float = 0,
     keep_orig_output_dtype=False,
 ) -> tvm.runtime.Module:
-
     mod = InferType()(mod)
     result_fp32 = run_module(mod, mod_params)
 
@@ -586,5 +585,39 @@ def test_clip_with_pre_op(target_precision):
     assert tvm.ir.structural_equal(expected_mod, output_mod)
 
 
+def test_loop(target_precision):
+    i = relay.var("i", shape=(), dtype="int32")
+    st = relay.var("st", shape=(relay.Any(), 1), dtype="int32")
+
+    def int32(val):
+        return relay.const(val, "int32")
+
+    def _cond(i, st):
+        return relay.op.min(relay.op.less(i, int32(10)))
+
+    def _body(i, st):
+        i_vec = relay.op.reshape(i, (1, 1))
+        ret = relay.op.concatenate([st, i_vec], axis=0)
+        return i + int32(1), ret
+
+    loop = relay.loops.while_loop(_cond, [i, st], _body)
+    start = relay.var("start", shape=(), dtype="int32")
+    body = loop(start, relay.op.reshape(relay.const(0), newshape=(1, 1)))
+    func = relay.Function([start], relay.TupleGetItem(body, 1))
+    mod = tvm.IRModule()
+    mod["main"] = func
+
+    mod_params = {
+        "start": np.random.uniform(-1, 1, size=()).astype("int32"),
+    }
+    output_mod = verify_mixed_precision_output_close(
+        mod, mod_params, mixed_precision_dtype=target_precision, atol=0.01, rtol=0.01
+    )
+
+    # Create expected module
+    expected_mod = InferType()(mod)
+    assert tvm.ir.structural_equal(expected_mod, output_mod)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

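For reference, the new regression test can be run on its own with pytest
(an assumed invocation; tvm.testing.main() makes the file runnable under
pytest):

    pytest tests/python/relay/test_to_mixed_precision.py -k test_loop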