This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new f4f525d  [AMP] Disallow fp16 conversion for summation-like ops (#8810)
f4f525d is described below

commit f4f525dab86af653636bce95ce3609288fbaa587
Author: masahi <[email protected]>
AuthorDate: Fri Aug 27 07:16:54 2021 +0900

    [AMP] Disallow fp16 conversion for summation-like ops (#8810)
    
    * [AMP] Disallow fp16 conversion for summation-like ops
    
    * test only structural equality
---
 python/tvm/relay/transform/mixed_precision.py | 15 ++++++-------
 tests/python/relay/test_to_mixed_precision.py | 31 +++++++++++++++++++--------
 2 files changed, 29 insertions(+), 17 deletions(-)
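
For context (not part of the commit itself): fp16 has a maximum finite value of
65504 and only 11 significand bits, so a long running sum can overflow or
silently stall well below the true total. A minimal NumPy sketch of both
failure modes:

    import numpy as np

    # Naive sequential accumulation in fp16 stalls: once the running sum
    # reaches 2048, adding 1.0 rounds back to 2048 (the fp16 spacing above
    # 2048 is 2), so the loop never gets anywhere near the true total.
    acc = np.float16(0.0)
    for _ in range(70000):
        acc += np.float16(1.0)
    print(acc)  # 2048.0, not 70000.0

    # NumPy's sum keeps an fp16 accumulator for fp16 input; the true total
    # (70000) exceeds 65504, so the pairwise partial sums overflow to inf.
    print(np.ones(70000, dtype=np.float16).sum())  # inf

    # Accumulating in fp32 recovers the exact answer.
    print(np.ones(70000, dtype=np.float16).sum(dtype=np.float32))  # 70000.0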

diff --git a/python/tvm/relay/transform/mixed_precision.py b/python/tvm/relay/transform/mixed_precision.py
index 1657f89..fb4d3fa 100644
--- a/python/tvm/relay/transform/mixed_precision.py
+++ b/python/tvm/relay/transform/mixed_precision.py
@@ -81,8 +81,6 @@ DEFAULT_FOLLOW_LIST = [
     "divide",
     "nn.bias_add",
     "nn.batch_norm",
-    "sum",
-    "mean",
     "sqrt",
     "shape_of",
     # Simple activations
@@ -107,15 +105,9 @@ DEFAULT_FOLLOW_LIST = [
     # "nn.global_max_pool1d", # does not exist yet
     "nn.global_max_pool2d",
     # "nn.global_max_pool3d", # does not exist yet
-    # "nn.global_avg_pool1d", # does not exist yet
-    "nn.global_avg_pool2d",
-    # "nn.global_avg_pool3d", # does not exist yet
     "nn.adaptive_max_pool1d",
     "nn.adaptive_max_pool2d",
     "nn.adaptive_max_pool3d",
-    "nn.adaptive_avg_pool1d",
-    "nn.adaptive_avg_pool2d",
-    "nn.adaptive_avg_pool3d",
 ]
 DEFAULT_NEVER_LIST = [
     # In general if |f(x)| >> |x| for expected inputs then put the op here.
@@ -131,6 +123,13 @@ DEFAULT_NEVER_LIST = [
     # Do not allow arange arguments (begin/end) to be fp16. "end" can be a big fp32 number
     # not representable in fp16.
     "arange",
+    # Ops that could involve a large summation are not allowed in fp16.
+    "nn.global_avg_pool2d",
+    "nn.adaptive_avg_pool1d",
+    "nn.adaptive_avg_pool2d",
+    "nn.adaptive_avg_pool3d",
+    "sum",
+    "mean",
 ]
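
A quick way to observe the new behavior (a usage sketch, not part of the
commit): run the pass over a module built from one of the ops above and check
that it comes back unchanged, mirroring the new test added below.

    import tvm
    from tvm import relay
    from tvm.relay.transform import InferType, ToMixedPrecision

    a = relay.var("a", shape=(1, 3, 16, 16), dtype="float32")
    mod = tvm.IRModule.from_expr(relay.mean(a))

    # "mean" is now on DEFAULT_NEVER_LIST, so the pass leaves fp32 intact.
    out_mod = ToMixedPrecision("float16")(mod)
    assert tvm.ir.structural_equal(InferType()(mod), out_mod)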
 
 
diff --git a/tests/python/relay/test_to_mixed_precision.py b/tests/python/relay/test_to_mixed_precision.py
index 99078b7..472f987 100644
--- a/tests/python/relay/test_to_mixed_precision.py
+++ b/tests/python/relay/test_to_mixed_precision.py
@@ -221,12 +221,9 @@ def test_do_not_convert_softmax():
     b = relay.nn.softmax(a)
     mod = tvm.IRModule.from_expr(b)
     mod = tvm.relay.transform.InferType()(mod)
-
-    mod_params = {
-        "a": np.random.uniform(-1, 1, size=shape).astype("float32"),
-    }
-    output_mod = verify_mixed_precision_output_close(mod, mod_params, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_do_not_convert_arange():
@@ -234,10 +231,26 @@ def test_do_not_convert_arange():
     dtype = "float32"
     arange = relay.arange(relay.const(1, dtype), relay.const(128, dtype))
     mod = tvm.IRModule.from_expr(arange)
-    mod = tvm.relay.transform.InferType()(mod)
+    out_mod = ToMixedPrecision("float16")(mod)
+    orig_mod = tvm.relay.transform.InferType()(mod)
+    assert tvm.ir.structural_equal(orig_mod, out_mod)
 
-    output_mod = verify_mixed_precision_output_close(mod, {}, atol=0.0, rtol=0)
-    assert tvm.ir.structural_equal(mod, output_mod)
+
+def test_do_not_convert_summation():
+    """Ops that could involve a large summation are not allowed in fp16."""
+    shape = [1, 3, 16, 16]
+    a = relay.var("a", shape=shape)
+    ops = [
+        relay.sum,
+        relay.mean,
+        relay.nn.global_avg_pool2d,
+        lambda inp: relay.nn.adaptive_avg_pool2d(inp, (1, 1)),
+    ]
+    for op in ops:
+        mod = tvm.IRModule.from_expr(op(a))
+        out_mod = ToMixedPrecision("float16")(mod)
+        orig_mod = tvm.relay.transform.InferType()(mod)
+        assert tvm.ir.structural_equal(orig_mod, out_mod)
 
 
 def test_green_gray_propagates_simple():
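
If a model's reductions are known to be small, this never-list choice can be
overridden per op. The sketch below assumes the
register_mixed_precision_conversion helper from tvm.relay.op and the
MIXED_PRECISION_ALWAYS constant from this same mixed_precision.py module, with
the hook returning [conversion category, accumulation dtype, output dtype];
verify both signatures against your TVM version before relying on this.

    from tvm.relay.op import register_mixed_precision_conversion
    from tvm.relay.transform.mixed_precision import MIXED_PRECISION_ALWAYS

    # Hypothetical per-op override: convert "mean" to fp16 anyway. level=11
    # outranks the default registrations, which use the default level of 10.
    @register_mixed_precision_conversion("mean", level=11)
    def mean_to_fp16(call_node, mixed_precision_type):
        return [MIXED_PRECISION_ALWAYS, mixed_precision_type, mixed_precision_type]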
