jroesch commented on a change in pull request #7029:
URL: https://github.com/apache/tvm/pull/7029#discussion_r535842427
##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -25,128 +25,179 @@
import pytest
-class env:
- def __init__(self):
- self.shape = tvm.runtime.convert([1, 2, 3])
- self.tt = relay.TensorType(self.shape, "float32")
- self.int32 = relay.TensorType([], "int32")
- self.float32 = relay.TensorType([], "float32")
- self.one = relay.const(1.0)
- self.two = relay.const(2.0)
- self.three = relay.const(3.0)
- self.a = relay.Var("a", self.float32)
- self.b = relay.Var("b", self.float32)
- self.c = relay.Var("c", self.float32)
- self.d = relay.Var("d", self.float32)
- self.e = relay.Var("e", self.float32)
- self.x = relay.Var("x", self.int32)
- self.y = relay.Var("y", self.int32)
- self.z = relay.Var("z", self.int32)
-
-
-e = env()
-
-
-def run_opt_pass(expr, opt_pass):
- assert isinstance(opt_pass, tvm.transform.Pass)
- mod = tvm.IRModule.from_expr(expr)
- mod = opt_pass(mod)
- entry = mod["main"]
- return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def test_let():
- orig = relay.Let(e.x, e.y, e.z)
- orig = run_opt_pass(orig, transform.DeadCodeElimination())
- assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
-
-
-def test_used_let():
- orig = relay.Let(e.c, e.one, e.c + e.c)
- orig = run_opt_pass(orig, transform.DeadCodeElimination())
- expected = relay.Let(e.c, e.one, e.c + e.c)
- assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
-
-
-def test_inline():
- orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
- orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
- tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
-
-
-def test_chain_unused_let():
- orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
- orig = run_opt_pass(orig, transform.DeadCodeElimination())
- assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
-
-
-def use_f(func):
- f = relay.Var("f")
- n = relay.Var("n", e.int32)
- data = relay.Var("data", e.float32)
- funcbody = relay.If(
- equal(n, relay.const(0)), data, relay.Call(f, [subtract(n, relay.const(1)), log(data)])
+def optimize_source(source, passes):
+ if not isinstance(passes, list):
+ passes = [passes]
+
+ optimize = tvm.transform.Sequential(passes)
+ module = tvm.parser.parse(source)
+ return optimize(module)
+
+
+def optimize_and_check(before_source, after_source, passes):
+ optimize_module = optimize_source(before_source, passes)
+ after_module = tvm.parser.parse(after_source)
+ print(optimize_module)
+ print(after_module)
+ assert tvm.ir.structural_equal(after_module, optimize_module)
+
+
+def test_dead_let():
+ before_program = """
+ #[version = "0.0.5"]
+ def @main(%z: int) {
+ let %x = 1;
+ %z
+ }
+ """
+ after_program = """
+ #[version = "0.0.5"]
+ def @main(%z: int) {
+ %z
+ }
+ """
+ optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
+
+
+def test_one_live_let():
+ before_program = """
+ #[version = "0.0.5"]
+ def @main(%z: int) {
+ let %x = 1;
+ let %y = 2;
+ %x + %x
+ }
+ """
+ after_program = """
+ #[version = "0.0.5"]
+ def @main(%z: int) {
+ let %x = 1;
+ %x + %x
+ }
+ """
+ optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
+
+
+def test_nested_let():
+ before_program = """
+ #[version = "0.0.5"]
+ def @main(%d: int, %b: int) {
+ let %a = %b;
+ let %c = %d;
+ %c
+ }
+ """
+ after_program = """
+ #[version = "0.0.5"]
+ def @main(%d: int, %b: int) {
+ let %c = %d;
+ %c
+ }
+ """
+ optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
Review comment:
I'll fix it in a follow-up.
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
For queries about this service, please contact Infrastructure at:
users@infra.apache.org