You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/04 00:15:08 UTC

[GitHub] [tvm] altanh commented on a change in pull request #7029: [Relay][Pass] Clean up DCE tests in preparation for refactoring.

altanh commented on a change in pull request #7029:
URL: https://github.com/apache/tvm/pull/7029#discussion_r535739423



##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():

Review comment:
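       🤔 Duplicated function name: `test_dead_recursion` is already defined above, so this second definition would shadow the first one. Maybe you meant `test_op_let`?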

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        (let %a = 1; 3) + 2
+    }
+    """
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        3 + 2
+    }
+    """
 
-def test_tuple_get_item():
-    tt = relay.TupleType([e.float32, e.float32])
-    t = relay.Var("t", tt)
-    a = relay.Var("a")
-    g = relay.TupleGetItem(t, 0)
-    dced = run_opt_pass(g, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
-    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-@pytest.mark.timeout(timeout=10, method="thread")
-def test_complexity():
-    g = inception_v3.get_net(1, 1000, (3, 299, 299), "float32")
-    run_opt_pass(g, transform.DeadCodeElimination())
+def test_dead_recursion():

Review comment:
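       Duplicated again: this is another definition named `test_dead_recursion`. Please give this test its own distinct name.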

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        (let %a = 1; 3) + 2
+    }
+    """
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        3 + 2
+    }
+    """
 
-def test_tuple_get_item():
-    tt = relay.TupleType([e.float32, e.float32])
-    t = relay.Var("t", tt)
-    a = relay.Var("a")
-    g = relay.TupleGetItem(t, 0)
-    dced = run_opt_pass(g, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
-    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-@pytest.mark.timeout(timeout=10, method="thread")
-def test_complexity():

Review comment:
       Do we want to keep something like this test (`test_complexity`)? It could be useful for catching performance regressions.

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -25,59 +25,106 @@
 import pytest
 
 
-class env:
-    def __init__(self):
-        self.shape = tvm.runtime.convert([1, 2, 3])
-        self.tt = relay.TensorType(self.shape, "float32")
-        self.int32 = relay.TensorType([], "int32")
-        self.float32 = relay.TensorType([], "float32")
-        self.one = relay.const(1.0)
-        self.two = relay.const(2.0)
-        self.three = relay.const(3.0)
-        self.a = relay.Var("a", self.float32)
-        self.b = relay.Var("b", self.float32)
-        self.c = relay.Var("c", self.float32)
-        self.d = relay.Var("d", self.float32)
-        self.e = relay.Var("e", self.float32)
-        self.x = relay.Var("x", self.int32)
-        self.y = relay.Var("y", self.int32)
-        self.z = relay.Var("z", self.int32)
-
-
-e = env()
-
-
-def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, tvm.transform.Pass)
-    mod = tvm.IRModule.from_expr(expr)
-    mod = opt_pass(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def test_let():
-    orig = relay.Let(e.x, e.y, e.z)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
-
-
-def test_used_let():
-    orig = relay.Let(e.c, e.one, e.c + e.c)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    expected = relay.Let(e.c, e.one, e.c + e.c)
-    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
+# class env:
+#     def __init__(self):
+#         self.shape = tvm.runtime.convert([1, 2, 3])
+#         self.tt = relay.TensorType(self.shape, "float32")
+#         self.int32 = relay.TensorType([], "int32")
+#         self.float32 = relay.TensorType([], "float32")
+#         self.one = relay.const(1.0)
+#         self.two = relay.const(2.0)
+#         self.three = relay.const(3.0)
+#         self.a = relay.Var("a", self.float32)
+#         self.b = relay.Var("b", self.float32)
+#         self.c = relay.Var("c", self.float32)
+#         self.d = relay.Var("d", self.float32)
+#         self.e = relay.Var("e", self.float32)
+#         self.x = relay.Var("x", self.int32)
+#         self.y = relay.Var("y", self.int32)
+#         self.z = relay.Var("z", self.int32)
+
+
+# e = env()
+
+
+# def run_opt_pass(expr, opt_pass):
+#     assert isinstance(opt_pass, tvm.transform.Pass)
+#     mod = tvm.IRModule.from_expr(expr)
+#     mod = opt_pass(mod)
+#     entry = mod["main"]
+#     return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def optimize_source(source, passes):
+    if not isinstance(passes, list):
+        passes = [passes]
+
+    optimize = tvm.transform.Sequential(passes)
+    module = tvm.parser.parse(source)
+    return optimize(module)
+
+
+def optimize_and_check(before_source, after_source, passes):
+    optimize_module = optimize_source(before_source, passes)
+    after_module = tvm.parser.parse(after_source)
+    print(optimize_module)
+    print(after_module)
+    assert tvm.ir.structural_equal(after_module, optimize_module)
+
+
+def test_dead_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %z
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        %z
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_inline():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
-    tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+def test_one_live_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        let %y = 2;
+        %x + %x
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %x + %x
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_chain_unused_let():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
+def test_nested_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %a = %b;
+        let %c = %d;
+        %c
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %c = %d;
+        %c
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
 def use_f(func):

Review comment:
       Is `use_f` still used anywhere? The old `test_recursion` and `test_recursion_dead`, which called it, are removed in this diff.

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -25,59 +25,106 @@
 import pytest
 
 
-class env:
-    def __init__(self):
-        self.shape = tvm.runtime.convert([1, 2, 3])
-        self.tt = relay.TensorType(self.shape, "float32")
-        self.int32 = relay.TensorType([], "int32")
-        self.float32 = relay.TensorType([], "float32")
-        self.one = relay.const(1.0)
-        self.two = relay.const(2.0)
-        self.three = relay.const(3.0)
-        self.a = relay.Var("a", self.float32)
-        self.b = relay.Var("b", self.float32)
-        self.c = relay.Var("c", self.float32)
-        self.d = relay.Var("d", self.float32)
-        self.e = relay.Var("e", self.float32)
-        self.x = relay.Var("x", self.int32)
-        self.y = relay.Var("y", self.int32)
-        self.z = relay.Var("z", self.int32)
-
-
-e = env()
-
-
-def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, tvm.transform.Pass)
-    mod = tvm.IRModule.from_expr(expr)
-    mod = opt_pass(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def test_let():
-    orig = relay.Let(e.x, e.y, e.z)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
-
-
-def test_used_let():
-    orig = relay.Let(e.c, e.one, e.c + e.c)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    expected = relay.Let(e.c, e.one, e.c + e.c)
-    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
+# class env:

Review comment:
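       Please remove this commented-out code (the `env` class and the helpers that follow it) rather than leaving it in the file.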




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org