You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/12/04 00:07:45 UTC

[GitHub] [tvm] MarisaKirisame commented on a change in pull request #7029: [Relay][Pass] Clean up DCE tests in preparation for refactoring.

MarisaKirisame commented on a change in pull request #7029:
URL: https://github.com/apache/tvm/pull/7029#discussion_r535737856



##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        (let %a = 1; 3) + 2
+    }
+    """
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        3 + 2
+    }
+    """
 
-def test_tuple_get_item():
-    tt = relay.TupleType([e.float32, e.float32])
-    t = relay.Var("t", tt)
-    a = relay.Var("a")
-    g = relay.TupleGetItem(t, 0)
-    dced = run_opt_pass(g, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
-    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-@pytest.mark.timeout(timeout=10, method="thread")
-def test_complexity():
-    g = inception_v3.get_net(1, 1000, (3, 299, 299), "float32")
-    run_opt_pass(g, transform.DeadCodeElimination())
+def test_dead_recursion():

Review comment:
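       Wrong name: `test_dead_recursion` is already defined above in this file. Redefining it shadows the earlier definitions, so pytest only collects and runs the last one. Please give each test a unique, descriptive name.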

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -25,59 +25,106 @@
 import pytest
 
 
-class env:
-    def __init__(self):
-        self.shape = tvm.runtime.convert([1, 2, 3])
-        self.tt = relay.TensorType(self.shape, "float32")
-        self.int32 = relay.TensorType([], "int32")
-        self.float32 = relay.TensorType([], "float32")
-        self.one = relay.const(1.0)
-        self.two = relay.const(2.0)
-        self.three = relay.const(3.0)
-        self.a = relay.Var("a", self.float32)
-        self.b = relay.Var("b", self.float32)
-        self.c = relay.Var("c", self.float32)
-        self.d = relay.Var("d", self.float32)
-        self.e = relay.Var("e", self.float32)
-        self.x = relay.Var("x", self.int32)
-        self.y = relay.Var("y", self.int32)
-        self.z = relay.Var("z", self.int32)
-
-
-e = env()
-
-
-def run_opt_pass(expr, opt_pass):
-    assert isinstance(opt_pass, tvm.transform.Pass)
-    mod = tvm.IRModule.from_expr(expr)
-    mod = opt_pass(mod)
-    entry = mod["main"]
-    return entry if isinstance(expr, relay.Function) else entry.body
-
-
-def test_let():
-    orig = relay.Let(e.x, e.y, e.z)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.z], e.z))
-
-
-def test_used_let():
-    orig = relay.Let(e.c, e.one, e.c + e.c)
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    expected = relay.Let(e.c, e.one, e.c + e.c)
-    assert tvm.ir.structural_equal(Function([], orig), Function([], expected))
+# class env:
+#     def __init__(self):
+#         self.shape = tvm.runtime.convert([1, 2, 3])
+#         self.tt = relay.TensorType(self.shape, "float32")
+#         self.int32 = relay.TensorType([], "int32")
+#         self.float32 = relay.TensorType([], "float32")
+#         self.one = relay.const(1.0)
+#         self.two = relay.const(2.0)
+#         self.three = relay.const(3.0)
+#         self.a = relay.Var("a", self.float32)
+#         self.b = relay.Var("b", self.float32)
+#         self.c = relay.Var("c", self.float32)
+#         self.d = relay.Var("d", self.float32)
+#         self.e = relay.Var("e", self.float32)
+#         self.x = relay.Var("x", self.int32)
+#         self.y = relay.Var("y", self.int32)
+#         self.z = relay.Var("z", self.int32)
+
+
+# e = env()
+
+
+# def run_opt_pass(expr, opt_pass):
+#     assert isinstance(opt_pass, tvm.transform.Pass)
+#     mod = tvm.IRModule.from_expr(expr)
+#     mod = opt_pass(mod)
+#     entry = mod["main"]
+#     return entry if isinstance(expr, relay.Function) else entry.body
+
+
+def optimize_source(source, passes):
+    if not isinstance(passes, list):
+        passes = [passes]
+
+    optimize = tvm.transform.Sequential(passes)
+    module = tvm.parser.parse(source)
+    return optimize(module)
+
+
+def optimize_and_check(before_source, after_source, passes):
+    optimize_module = optimize_source(before_source, passes)
+    after_module = tvm.parser.parse(after_source)
+    print(optimize_module)
+    print(after_module)
+    assert tvm.ir.structural_equal(after_module, optimize_module)
+
+
+def test_dead_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %z
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        %z
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_inline():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.c))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination(True))
-    tvm.ir.assert_structural_equal(Function(free_vars(orig), orig), Function([e.d], e.d))
+def test_one_live_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        let %y = 2;
+        %x + %x
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%z: int) {
+        let %x = 1;
+        %x + %x
+    }
+    """
+    optimize_and_check(before_program, after_program, transform.DeadCodeElimination())
 
 
-def test_chain_unused_let():
-    orig = relay.Let(e.a, e.b, relay.Let(e.c, e.d, e.e))
-    orig = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(orig), orig), Function([e.e], e.e))
+def test_nested_let():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %a = %b;
+        let %c = %d;
+        %c
+    }
+    """
+    after_program = """
+    #[version = "0.0.5"]
+    def @main(%d: int, %b: int) {
+        let %c = %d;

Review comment:
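       Are you sure? It should optimize to just `%d` (at least it did when I wrote it): the original `test_inline` ran `DeadCodeElimination(True)` on this pattern and expected only `%d`.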

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():

Review comment:
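       Wrong name: this test checks a dead `let` nested inside an operator expression (formerly `test_op_let`), not dead recursion. Reusing `test_dead_recursion` also shadows the earlier test of that name.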

##########
File path: tests/python/relay/test_pass_dead_code_elimination.py
##########
@@ -91,62 +138,107 @@ def use_f(func):
     return relay.Let(f, value, func(f))
 
 
-# make sure we dont infinite loop
-def test_recursion():
+def test_live_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
+    """
+
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        %f(2, 10000)
+    }
     """
-    Program:
-       let f(n: i32, data: f32) -> f32 = {
-          if (n == 0) {
-              return data;
-          } else {
-              return f(n - 1, log(data));
-          }
-       }
-       f(2, 10000);
+
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
+
+
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        let %f = fn (%n: int, %data: int) -> int {
+            if (%n == 0) {
+                %data
+            } else {
+                %f(%n - 1, log(%data))
+            }
+        };
+        ()
+    }
     """
-    orig = use_f(lambda f: relay.Call(f, [relay.const(2), relay.const(10000.0)]))
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    orig = run_opt_pass(orig, transform.InferType())
-    tvm.ir.assert_structural_equal(dced, orig)
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        ()
+    }
+    """
 
-def test_recursion_dead():
-    x = relay.Let(e.a, e.one, e.three)
-    dced_f = lambda f: x
-    dced = run_opt_pass(use_f(dced_f), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, e.three)
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-def test_op_let():
-    dced = run_opt_pass(add(relay.Let(e.a, e.one, e.three), e.two), transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(dced, add(e.three, e.two))
+def test_dead_recursion():
+    before_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        (let %a = 1; 3) + 2
+    }
+    """
 
+    after_program = """
+    #[version = "0.0.5"]
+    def @main() {
+        3 + 2
+    }
+    """
 
-def test_tuple_get_item():
-    tt = relay.TupleType([e.float32, e.float32])
-    t = relay.Var("t", tt)
-    a = relay.Var("a")
-    g = relay.TupleGetItem(t, 0)
-    dced = run_opt_pass(g, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
-    orig = relay.TupleGetItem(relay.Let(a, e.one, t), 0)
-    dced = run_opt_pass(orig, transform.DeadCodeElimination())
-    assert tvm.ir.structural_equal(Function(free_vars(dced), dced), Function(free_vars(g), g))
+    optimize_and_check(
+        before_program, after_program, [transform.DeadCodeElimination(), transform.InferType()]
+    )
 
 
-@pytest.mark.timeout(timeout=10, method="thread")

Review comment:
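       Should we keep this `@pytest.mark.timeout` test? It guarded `test_complexity` against DeadCodeElimination running too slowly on Inception v3.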




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org