You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/11/07 19:10:38 UTC

[incubator-tvm] branch main updated: [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)

This is an automated email from the ASF dual-hosted git repository.

lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 12582fd  [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)
12582fd is described below

commit 12582fdcbf2b67e7eb3064d13d06bbf6c71e62f0
Author: Cody Yu <co...@gmail.com>
AuthorDate: Sat Nov 7 11:10:25 2020 -0800

    [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)
    
    * serialize task and dag
    
    * fix test
    
    * more tests
    
    * format
    
    * format
    
    * format
    
    * trigger ci
---
 python/tvm/auto_scheduler/compute_dag.py           | 23 +++++++++++++----
 python/tvm/auto_scheduler/search_task.py           | 29 +++++++++++++++++++++
 .../python/unittest/test_auto_scheduler_common.py  |  4 +--
 .../unittest/test_auto_scheduler_compute_dag.py    | 30 +++++++++++++++++++++-
 4 files changed, 77 insertions(+), 9 deletions(-)

diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d50ff39..9390a9c 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -21,14 +21,14 @@ import hashlib
 
 import tvm._ffi
 from tvm.runtime import Object
-from tvm.te import PlaceholderOp, ComputeOp
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
+from tvm.te import ComputeOp, PlaceholderOp
 
+from . import _ffi_api
 from .loop_state import State, StateObject
 from .utils import get_const_tuple
 from .workload_registry import workload_key_to_tensors
 
-from . import _ffi_api
-
 
 @tvm._ffi.register_object("auto_scheduler.ComputeDAG")
 class ComputeDAG(Object):
@@ -63,7 +63,10 @@ class ComputeDAG(Object):
         elif isinstance(compute_or_sche, list):
             for item in compute_or_sche:
                 if not isinstance(item, tvm.te.Tensor):
-                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+                    raise ValueError(
+                        "The input of ComputeDAG should be a list of Tensor, but got %s"
+                        % type(item)
+                    )
             compute = compute_or_sche
             sche = None
         elif isinstance(compute_or_sche, tvm.te.Schedule):
@@ -72,8 +75,10 @@ class ComputeDAG(Object):
         else:
             raise ValueError(
                 "Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
-                % type(compute)
+                % type(compute_or_sche)
             )
+        self.compute = compute
+        self.sche = sche
         self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
 
     def get_init_state(self):
@@ -182,3 +187,11 @@ class ComputeDAG(Object):
 
         str_key = str_key.encode(encoding="utf-8")
         return hashlib.md5(str_key).hexdigest()
+
+    def __getstate__(self):
+        return {"compute": SaveJSON(self.compute), "sche": SaveJSON(self.sche)}
+
+    def __setstate__(self, state):
+        self.compute = LoadJSON(state["compute"])  # pylint: disable=assignment-from-no-return
+        self.sche = LoadJSON(state["sche"])  # pylint: disable=assignment-from-no-return
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 92c4f48..7c5021b 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -42,6 +42,35 @@ class SearchTask(Object):
     """
 
     def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+        self.dag = dag
+        self.workload_key = workload_key
+        self.target = target
+        self.target_host = target_host
+        self.hardware_params = hardware_params
         self.__init_handle_by_constructor__(
             _ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
         )
+
+    def __getstate__(self):
+        return {
+            "dag": self.dag,
+            "workload_key": self.workload_key,
+            "target": self.target,
+            "target_host": self.target_host,
+            "hardware_params": self.hardware_params,
+        }
+
+    def __setstate__(self, state):
+        self.dag = state["dag"]
+        self.workload_key = state["workload_key"]
+        self.target = state["target"]
+        self.target_host = state["target_host"]
+        self.hardware_params = state["hardware_params"]
+        self.__init_handle_by_constructor__(
+            _ffi_api.SearchTask,
+            self.dag,
+            self.workload_key,
+            self.target,
+            self.target_host,
+            self.hardware_params,
+        )
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 6a3fe4e..5b7add9 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -161,14 +161,12 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
     r = KW
     m = tile_size
     alpha = m + r - 1
-    A, B, G = winograd_transform_matrices(m, r, "float32")
+    A, B, _ = winograd_transform_matrices(m, r, "float32")
 
     H = (H + 2 * HPAD - KH) // HSTR + 1
     W = (W + 2 * WPAD - KW) // WSTR + 1
     nH, nW = (H + m - 1) // m, (W + m - 1) // m
     P = N * nH * nW
-    r_kh = te.reduce_axis((0, KH), name="r_kh")
-    r_kw = te.reduce_axis((0, KW), name="r_kw")
     kshape = (alpha, alpha, CI, CO)
     kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
 
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 2ccedef..e777475 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -16,6 +16,7 @@
 # under the License.
 
 """Test ComputeDAG (replay, infer bound)"""
+import pickle
 
 import tvm
 from tvm import topi
@@ -32,7 +33,7 @@ def test_apply_steps():
     dag, s = get_tiled_matmul()
     dag.print_python_code_from_state(s)
     sch, tensors = dag.apply_steps_from_state(s)
-    stmt = tvm.lower(sch, tensors, simple_mode=True)
+    tvm.lower(sch, tensors, simple_mode=True)
 
 
 def test_infer_bound():
@@ -61,6 +62,7 @@ def test_estimate_flop():
 
 
 def test_stage_order():
+    """Test if the stage order is preserved when recovering a DAG."""
     N = 512
     A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
     sch = te.create_schedule([D.op, E.op])
@@ -87,6 +89,11 @@ def test_stage_order():
         elif op.name in ["B", "C"]:
             assert stage_ops_1[idx + 1].name == "%s.shared" % op.name
 
+    # Serialize and deserialize the ComputeDAG constructed by a schedule.
+    loaded_dag = pickle.loads(pickle.dumps(dag))
+    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+    assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
     # Apply the same schedule to Ansor state and it should have the same stage order
     dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
     state = dag.get_init_state()
@@ -105,6 +112,27 @@ def test_stage_order():
     for op1, op2 in zip(stage_ops_1, stage_ops_2):
         assert op1.name == op2.name
 
+    # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
+    loaded_dag = pickle.loads(pickle.dumps(dag))
+    assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+    assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
+    # Serialize and deserialize the search task.
+    task = auto_scheduler.SearchTask(
+        dag,
+        "test1",
+        tvm.target.Target("llvm"),
+        hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
+    )
+    task2 = pickle.loads(pickle.dumps(task))
+    assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
+    assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops)
+    assert task.workload_key == task2.workload_key
+    assert str(task.target) == str(task2.target)
+    assert task.hardware_params.num_cores == task2.hardware_params.num_cores
+    assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
+    assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
+
 
 if __name__ == "__main__":
     test_apply_steps()