You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by lm...@apache.org on 2020/11/07 19:10:38 UTC
[incubator-tvm] branch main updated: [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)
This is an automated email from the ASF dual-hosted git repository.
lmzheng pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 12582fd [AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)
12582fd is described below
commit 12582fdcbf2b67e7eb3064d13d06bbf6c71e62f0
Author: Cody Yu <co...@gmail.com>
AuthorDate: Sat Nov 7 11:10:25 2020 -0800
[AutoScheduler] Make SearchTask and ComputeDAG serializable (#6842)
* serialize task and dag
* fix test
* more tests
* format
* format
* format
* trigger ci
---
python/tvm/auto_scheduler/compute_dag.py | 23 +++++++++++++----
python/tvm/auto_scheduler/search_task.py | 29 +++++++++++++++++++++
.../python/unittest/test_auto_scheduler_common.py | 4 +--
.../unittest/test_auto_scheduler_compute_dag.py | 30 +++++++++++++++++++++-
4 files changed, 77 insertions(+), 9 deletions(-)
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index d50ff39..9390a9c 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -21,14 +21,14 @@ import hashlib
import tvm._ffi
from tvm.runtime import Object
-from tvm.te import PlaceholderOp, ComputeOp
+from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
+from tvm.te import ComputeOp, PlaceholderOp
+from . import _ffi_api
from .loop_state import State, StateObject
from .utils import get_const_tuple
from .workload_registry import workload_key_to_tensors
-from . import _ffi_api
-
@tvm._ffi.register_object("auto_scheduler.ComputeDAG")
class ComputeDAG(Object):
@@ -63,7 +63,10 @@ class ComputeDAG(Object):
elif isinstance(compute_or_sche, list):
for item in compute_or_sche:
if not isinstance(item, tvm.te.Tensor):
- raise ValueError("The input of ComputeDAG should be a list of Tensor")
+ raise ValueError(
+ "The input of ComputeDAG should be a list of Tensor, but got %s"
+ % type(item)
+ )
compute = compute_or_sche
sche = None
elif isinstance(compute_or_sche, tvm.te.Schedule):
@@ -72,8 +75,10 @@ class ComputeDAG(Object):
else:
raise ValueError(
"Invalid compute type: %s. ComputeDAG expects string, list of Tensor, or Schedule"
- % type(compute)
+ % type(compute_or_sche)
)
+ self.compute = compute
+ self.sche = sche
self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute, sche)
def get_init_state(self):
@@ -182,3 +187,11 @@ class ComputeDAG(Object):
str_key = str_key.encode(encoding="utf-8")
return hashlib.md5(str_key).hexdigest()
+
+ def __getstate__(self):
+ return {"compute": SaveJSON(self.compute), "sche": SaveJSON(self.sche)}
+
+ def __setstate__(self, state):
+ self.compute = LoadJSON(state["compute"]) # pylint: disable=assignment-from-no-return
+ self.sche = LoadJSON(state["sche"]) # pylint: disable=assignment-from-no-return
+ self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, self.compute, self.sche)
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index 92c4f48..7c5021b 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -42,6 +42,35 @@ class SearchTask(Object):
"""
def __init__(self, dag, workload_key, target, target_host=None, hardware_params=None):
+ self.dag = dag
+ self.workload_key = workload_key
+ self.target = target
+ self.target_host = target_host
+ self.hardware_params = hardware_params
self.__init_handle_by_constructor__(
_ffi_api.SearchTask, dag, workload_key, target, target_host, hardware_params
)
+
+ def __getstate__(self):
+ return {
+ "dag": self.dag,
+ "workload_key": self.workload_key,
+ "target": self.target,
+ "target_host": self.target_host,
+ "hardware_params": self.hardware_params,
+ }
+
+ def __setstate__(self, state):
+ self.dag = state["dag"]
+ self.workload_key = state["workload_key"]
+ self.target = state["target"]
+ self.target_host = state["target_host"]
+ self.hardware_params = state["hardware_params"]
+ self.__init_handle_by_constructor__(
+ _ffi_api.SearchTask,
+ self.dag,
+ self.workload_key,
+ self.target,
+ self.target_host,
+ self.hardware_params,
+ )
diff --git a/tests/python/unittest/test_auto_scheduler_common.py b/tests/python/unittest/test_auto_scheduler_common.py
index 6a3fe4e..5b7add9 100644
--- a/tests/python/unittest/test_auto_scheduler_common.py
+++ b/tests/python/unittest/test_auto_scheduler_common.py
@@ -161,14 +161,12 @@ def conv2d_winograd_nhwc_auto_scheduler_test(
r = KW
m = tile_size
alpha = m + r - 1
- A, B, G = winograd_transform_matrices(m, r, "float32")
+ A, B, _ = winograd_transform_matrices(m, r, "float32")
H = (H + 2 * HPAD - KH) // HSTR + 1
W = (W + 2 * WPAD - KW) // WSTR + 1
nH, nW = (H + m - 1) // m, (W + m - 1) // m
P = N * nH * nW
- r_kh = te.reduce_axis((0, KH), name="r_kh")
- r_kw = te.reduce_axis((0, KW), name="r_kw")
kshape = (alpha, alpha, CI, CO)
kernel_pack = te.placeholder(kshape, inputs.dtype, name="weight")
diff --git a/tests/python/unittest/test_auto_scheduler_compute_dag.py b/tests/python/unittest/test_auto_scheduler_compute_dag.py
index 2ccedef..e777475 100644
--- a/tests/python/unittest/test_auto_scheduler_compute_dag.py
+++ b/tests/python/unittest/test_auto_scheduler_compute_dag.py
@@ -16,6 +16,7 @@
# under the License.
"""Test ComputeDAG (replay, infer bound)"""
+import pickle
import tvm
from tvm import topi
@@ -32,7 +33,7 @@ def test_apply_steps():
dag, s = get_tiled_matmul()
dag.print_python_code_from_state(s)
sch, tensors = dag.apply_steps_from_state(s)
- stmt = tvm.lower(sch, tensors, simple_mode=True)
+ tvm.lower(sch, tensors, simple_mode=True)
def test_infer_bound():
@@ -61,6 +62,7 @@ def test_estimate_flop():
def test_stage_order():
+ """Test if the stage order is preserved when recovering a DAG."""
N = 512
A, B, C, D, E = parallel_matmul_auto_scheduler_test(N)
sch = te.create_schedule([D.op, E.op])
@@ -87,6 +89,11 @@ def test_stage_order():
elif op.name in ["B", "C"]:
assert stage_ops_1[idx + 1].name == "%s.shared" % op.name
+ # Serialize and deserialize the ComputeDAG constructed by a schedule.
+ loaded_dag = pickle.loads(pickle.dumps(dag))
+ assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+ assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
# Apply the same schedule to Ansor state and it should have the same stage order
dag = auto_scheduler.ComputeDAG([A, B, C, D, E])
state = dag.get_init_state()
@@ -105,6 +112,27 @@ def test_stage_order():
for op1, op2 in zip(stage_ops_1, stage_ops_2):
assert op1.name == op2.name
+ # Serialize and deserialize the ComputeDAG constructed by a list of tensor ops.
+ loaded_dag = pickle.loads(pickle.dumps(dag))
+ assert str(loaded_dag.get_init_state()) == str(dag.get_init_state())
+ assert len(loaded_dag.get_init_state().stage_ops) == len(dag.get_init_state().stage_ops)
+
+ # Serialize and deserialize the search task.
+ task = auto_scheduler.SearchTask(
+ dag,
+ "test1",
+ tvm.target.Target("llvm"),
+ hardware_params=auto_scheduler.HardwareParams(100000, 16, 64),
+ )
+ task2 = pickle.loads(pickle.dumps(task))
+ assert str(task.dag.get_init_state()) == str(task2.dag.get_init_state())
+ assert len(task.dag.get_init_state().stage_ops) == len(task2.dag.get_init_state().stage_ops)
+ assert task.workload_key == task2.workload_key
+ assert str(task.target) == str(task2.target)
+ assert task.hardware_params.num_cores == task2.hardware_params.num_cores
+ assert task.hardware_params.vector_unit_bytes == task2.hardware_params.vector_unit_bytes
+ assert task.hardware_params.cache_line_bytes == task2.hardware_params.cache_line_bytes
+
if __name__ == "__main__":
test_apply_steps()