chunit-quic commented on code in PR #13402:
URL: https://github.com/apache/tvm/pull/13402#discussion_r1030068464


##########
python/tvm/testing/utils.py:
##########
@@ -2081,3 +2081,28 @@ def pprint(name, obj):
                 f"or an instance of `tvm.tir.PrimFunc`.  "
                 f"Instead, received {type(expected)}."
             )
+
+
+class _control_span_filling:
+    def __init__(self, on=True):
+        self._old_state = os.environ["TVM_SPANFILLING"] if "TVM_SPANFILLING" 
in os.environ else None

Review Comment:
   This way we can disable span filling with a command like the following, without
modifying the code, which is a bit more convenient.
   ```bash
   TVM_SPANFILLING=0 python ${your_program.py}
   ```
   



##########
src/relay/ir/expr.cc:
##########
@@ -72,8 +72,8 @@ Constant::Constant(runtime::NDArray data, Span span) {
 
 TVM_REGISTER_NODE_TYPE(ConstantNode);
 
-TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray 
data) {
-  return Constant(data);
+TVM_REGISTER_GLOBAL("relay.ir.Constant").set_body_typed([](runtime::NDArray 
data, Span span) {

Review Comment:
    Thank you for pointing out this concern. :D
   Just a gentle reminder that these types already have a default span value.
The constructors declared in the header
([include/tvm/relay/expr.h](https://github.com/apache/tvm/blob/b2058f4dd2e0ae1fc5ab51ac9f84b372a389a65a/include/tvm/relay/expr.h#L107))
 already default their span argument to Span(). This change only makes the span
attribute settable from Python.
   
   



##########
tests/python/frontend/test_common.py:
##########
@@ -27,6 +32,203 @@ def test_key_is_not_present():
     assert not attrs.has_attr("b")
 
 
+def test_set_span():
+    def _verify_env_var_switch():
+        def _res(should_fill):
+            if should_fill:
+                with testing.enable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), 
"x_var")
+            else:
+                with testing.disable_span_filling():
+                    return set_span(relay.var("x", shape=(1, 64, 56, 56)), 
"x_var")
+
+        disable = relay.var("x", shape=(1, 64, 56, 56))
+        enable = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+
+        assert _verify_structural_equal_with_span(_res(False), disable)
+        assert _verify_structural_equal_with_span(_res(True), enable)
+
+    # Should tag all exprs without span, and stop when expr is span-tagged
+    def _verify_builtin_tuple():
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            return set_span(tuple([a, b]), "tuple")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", 
span=_create_span("tuple"))
+            return tuple([a, b])
+
+        res_tuple, golden_tuple = _res(), _golden()
+        assert len(res_tuple) == len(golden_tuple)
+        for i in range(len(res_tuple)):
+            assert _verify_structural_equal_with_span(res_tuple[i], 
golden_tuple[i])
+
+    def _verify_builtin_list():
+        def _res():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            t_a = relay.TupleGetItem(t, 0)
+            t_b = relay.TupleGetItem(t, 1)
+            return set_span([t_a, t_b], "list")
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.zeros([1, 1, 1]), dtype="int64", 
span=_create_span("list"))
+            t = relay.Tuple([a, b], span=_create_span("list"))
+            t_a = relay.TupleGetItem(t, 0, span=_create_span("list"))
+            t_b = relay.TupleGetItem(t, 1, span=_create_span("list"))
+            return [t_a, t_b]
+
+        res_list, golden_list = _res(), _golden()
+        assert len(res_list) == len(golden_list)
+        for i in range(len(res_list)):
+            assert _verify_structural_equal_with_span(res_list[i], 
golden_list[i])
+
+    def _verify_var():
+        x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+        x_expected = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+        assert _verify_structural_equal_with_span(x, x_expected)
+
+    def _verify_constant():
+        c = set_span(relay.const(np.ones([64, 64, 3, 3]), dtype="int64"), 
"const_c")
+        c_expected = relay.const(
+            np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("const_c")
+        )
+        assert _verify_structural_equal_with_span(c, c_expected)
+
+    def _verify_call():
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("conv2d"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "conv2d"
+            )
+            return relay.Function([x], y)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_tuple():
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = set_span(relay.Tuple([a, b]), "t")
+            return relay.Function([], t)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("t"))
+            t = relay.Tuple([a, b], span=_create_span("t"))
+            return relay.Function([], t)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_tuple_getitem():
+        def _res():
+            a = set_span(relay.const(np.ones([1, 1, 1]), dtype="int64"), "a")
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64")
+            t = relay.Tuple([a, b])
+            i = set_span(relay.TupleGetItem(t, 0), "i")
+            return relay.Function([], i)
+
+        def _golden():
+            a = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("a"))
+            b = relay.const(np.ones([1, 1, 1]), dtype="int64", 
span=_create_span("i"))
+            t = relay.Tuple([a, b], span=_create_span("i"))
+            i = relay.TupleGetItem(t, 0, span=_create_span("i"))
+            return relay.Function([], i)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_let():
+        def _res():
+            x = set_span(relay.Var("x"), "x_var")
+            c_1 = relay.const(np.ones(10))
+            add = relay.add(x, x)
+            body = set_span(relay.Let(x, c_1, add), "let")
+
+            c_2 = set_span(relay.const(np.zeros(10)), "zeros")
+            y = set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        def _golden():
+            x = relay.Var("x", span=_create_span("x_var"))
+            c_1 = relay.const(np.ones(10), span=_create_span("let"))
+            add = _set_span(relay.add(x, x), "let")
+            body = relay.Let(x, c_1, add, span=_create_span("let"))
+
+            c_2 = relay.const(np.zeros(10), span=_create_span("zeros"))
+            y = _set_span(relay.add(body, c_2), "add_2")
+            return relay.Function([x], y)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_if():
+        def _res():
+            x = set_span(relay.var("x", shape=[], dtype="float32"), "x_var")
+            y = set_span(relay.var("y", shape=[], dtype="float32"), "y_var")
+            eq = relay.equal(x, y)
+
+            true_branch = set_span(relay.add(x, y), "true_branch")
+            false_branch = relay.subtract(x, y)
+            ife = set_span(relay.If(eq, true_branch, false_branch), "if")
+            return relay.Function([x, y], ife)
+
+        def _golden():
+            x = relay.var("x", shape=[], dtype="float32", 
span=_create_span("x_var"))
+            y = relay.var("y", shape=[], dtype="float32", 
span=_create_span("y_var"))
+            eq = _set_span(relay.equal(x, y), "if")
+
+            true_branch = _set_span(relay.add(x, y), "true_branch")
+            false_branch = _set_span(relay.subtract(x, y), "if")
+            ife = relay.If(eq, true_branch, false_branch, 
span=_create_span("if"))
+            return relay.Function([x, y], ife)
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    def _verify_fn():
+        def _res():
+            x = set_span(relay.var("x", shape=(1, 64, 56, 56)), "x_var")
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64")
+            y = relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1))
+            f = set_span(relay.Function([x], y), "func")
+            return f
+
+        def _golden():
+            x = relay.var("x", shape=(1, 64, 56, 56), 
span=_create_span("x_var"))
+            w = relay.const(np.ones([64, 64, 3, 3]), dtype="int64", 
span=_create_span("func"))
+            y = _set_span(
+                relay.nn.conv2d(x, w, channels=64, kernel_size=(3, 3), 
padding=(1, 1)), "func"
+            )
+            f = relay.Function([x], y, span=_create_span("func"))
+            return f
+
+        assert _verify_structural_equal_with_span(_res(), _golden())
+
+    _verify_env_var_switch()
+    _verify_builtin_tuple()
+    _verify_builtin_list()
+    _verify_var()
+    _verify_constant()
+    _verify_call()
+    _verify_tuple()
+    _verify_tuple_getitem()
+    _verify_let()
+    _verify_if()
+    _verify_fn()
+
+
 if __name__ == "__main__":
     test_key_is_present()

Review Comment:
   Will do.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to