Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/02 20:12:05 UTC

[GitHub] [incubator-tvm] jroesch commented on a change in pull request #5962: [Ansor][AutoTVM v2.0] Part 0: Ansor minimum system for auto schedule generating

jroesch commented on a change in pull request #5962:
URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449198373



##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial

Review comment:
       ```suggestion
   In `Program sampling`, we use predefined precise or heuristic rules to generate several initial
   ```
   I think it makes sense to clarify that predefined means exact or precise rules here.

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will

Review comment:
       ```suggestion
   Candidate schedules are measured against the specific hardware target.
   ```

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:

Review comment:
       We should mark this TODO with someone responsible for it imo. 

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model

Review comment:
       ```suggestion
   schedules. Based on these initial starting points, we perform `Performance Tuning` which uses evolutionary search based on a cost model to select schedules with the best performance. 
   ```

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is design to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+    Parameters
+    ----------
+    num_cores : int
+        The number of device cores.
+    vector_unit_bytes : int
+        The width of vector units in bytes.
+    cache_line_bytes : int
+        The size of cache line in bytes.
+    max_unroll_vec : int
+        The max length of an axis to be unrolled or vectorized.
+    max_innermost_split_factor : int
+        The max split factor for the innermost tile.
+    """
+    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+                 max_unroll_vec, max_innermost_split_factor):
+        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+                                            vector_unit_bytes, cache_line_bytes,
+                                            max_unroll_vec, max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+    """ The meta-information of a search task.
+
+    Parameters
+    ----------
+    dag : ComputeDAG
+        The ComputeDAG for target compute declaration.
+    workload_key : str
+        The workload key for target compute declaration.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+    """
+    def __init__(self, dag, workload_key, target, target_host=None,
+                 hardware_params=None):
+        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+                                            workload_key, target, target_host,
+                                            hardware_params)
+
+
+@tvm._ffi.register_object("ansor.SearchPolicy")
+class SearchPolicy(Object):
+    """ The base class for search policy  """
+
+
+@tvm._ffi.register_object("ansor.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+    """ This is an example empty search policy which will always generate
+    the init state of target ComputeDAG.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)
+
+
+@tvm._ffi.register_object("ansor.TuneOption")
+class TuneOption(Object):
+    """ This controls the options of performance tuning.
+
+    Parameters
+    ----------
+    num_measure_trials: int = 0
+      The number of total schedule measure trials.
+      Ansor takes `num_measure_trials` state for measuring in total, and finally gets the best
+      schedule among them.
+      With `num_measure_trials` == 0, Ansor will do the schedule search but don't involve
+      measurement, this can be used if we want to quickly get a runnable schedule without
+      performance tuning.
+    early_stopping: int = -1
+      Stops early the tuning if no improvement get after n measurements.
+    num_measures_per_round: int = 64
+      The number of programs to be measured at each search round.
+      The whole schedule search process is designed to have several rounds to try a total
+      `num_measure_trials` schedules.
+      We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round`
+    verbose: int = 1
+      Verbosity level. 0 for silent, 1 to output information during schedule search.
+    builder: Union[Builder, str] = 'local'
+      Builder which builds the program.
+    runner: Union[Runner, str] = 'local'
+      Runner which runs the program and measures time costs.
+    measure_callbacks: Optional[List[MeasureCallback]]
+      Callback functions called after each measure.
+      Candidates:
+        - ansor.LogToFile
+    pre_search_callbacks: Optional[List[SearchCallback]]
+      Callback functions called before the search process.
+      Candidates:
+        - ansor.PreloadMeasuredStates
+        - ansor.PreloadCustomSketchRule
+        TODO(jcf94): Add these implementation in later PRs.
+    """
+    def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_round=64,

Review comment:
       Yeah, I agree. There are a lot of fields here and it's a bit hard to consume.
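
One way to make the field list easier to take in is to group the options in a single declaration with their defaults and construct it with keywords. The following is only a sketch: `TuneOptionSketch` is a hypothetical plain-Python stand-in, not the FFI-backed `ansor.TuneOption` in this PR, and it copies just the scalar fields and defaults listed in the docstring above.

```python
from dataclasses import dataclass


@dataclass
class TuneOptionSketch:
    """Hypothetical stand-in for illustration; not the class from the PR."""
    num_measure_trials: int = 0       # total schedules to measure; 0 skips measurement
    early_stopping: int = -1          # stop after n measurements without improvement
    num_measures_per_round: int = 64  # programs measured in each search round
    verbose: int = 1                  # 0 for silent, 1 to print search progress
    builder: str = "local"
    runner: str = "local"

    @property
    def num_search_rounds(self):
        # Relationship stated in the docstring under review.
        return self.num_measure_trials // self.num_measures_per_round


# Keyword construction keeps call sites readable even with many fields.
options = TuneOptionSketch(num_measure_trials=200, verbose=0)
print(options.num_search_rounds)
```

Whether something like this fits an object that is constructed through `__init_handle_by_constructor__` is an open question; the sketch only shows the readability benefit at the call site.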

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+    Parameters
+    ----------
+    num_cores : int
+        The number of device cores.
+    vector_unit_bytes : int
+        The width of vector units in bytes.
+    cache_line_bytes : int
+        The size of cache line in bytes.
+    max_unroll_vec : int
+        The max length of an axis to be unrolled or vectorized.
+    max_innermost_split_factor : int
+        The max split factor for the innermost tile.
+    """
+    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+                 max_unroll_vec, max_innermost_split_factor):
+        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+                                            vector_unit_bytes, cache_line_bytes,
+                                            max_unroll_vec, max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+    """ The meta-information of a search task.
+
+    Parameters
+    ----------
+    dag : ComputeDAG
+        The ComputeDAG for target compute declaration.
+    workload_key : str
+        The workload key for target compute declaration.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+    """
+    def __init__(self, dag, workload_key, target, target_host=None,
+                 hardware_params=None):
+        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+                                            workload_key, target, target_host,
+                                            hardware_params)
+
+
+@tvm._ffi.register_object("ansor.SearchPolicy")
+class SearchPolicy(Object):
+    """ The base class for search policy  """
+
+
+@tvm._ffi.register_object("ansor.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+    """ This is an example empty search policy which will always generate
+    the init state of target ComputeDAG.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)
+
+
+@tvm._ffi.register_object("ansor.TuneOption")
+class TuneOption(Object):

Review comment:
       This name reads a bit awkwardly in English imo; perhaps we could clarify? `TuningConfig`, `TuningOptions`? I think plurality is important here given we have more than one option.

##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """

Review comment:
       ```suggestion
   """ The Ansor computational graph and related program analyses. """
   ```

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.
+To enable flexible manipulation of the loop structures, we implemented a lightweight loop
+structure IR (Intermediate Representation) based on the original TVM IR but specifically
+for schedule search.
+
+We don't use the existing TVM IR but to extend a new Sketch IR on it is because:
+1. We want fast incremental change to the loop structures;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives.
+Because we share a lot common objects during search, the transformation is implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")

Review comment:
       Why have a split here? It seems like all the state should be in C++. cc @tqchen

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and
+ * some other supporting structures.

Review comment:
       ```suggestion
    *  other supporting data structures.
   ```

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ Base class for measurement callback function. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.Builder")
+class Builder(Object):
+    """ Base class of Builder. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.BuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.Runner")
+class Runner(Object):
+    """ Base class of Runner """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.RunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(Builder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(Runner):
+    """ LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.
+    number : int = 3
+        Number of measure times.
+    repeat : int = 1
+        Number of repeat times in each measure.
+    min_repeat_ms : int = 0
+        The minimum duration of one repeat in milliseconds.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self,
+                 timeout=10,
+                 number=3,
+                 repeat=1,
+                 min_repeat_ms=0,
+                 cooldown_interval=0.0):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
+
+
+class MeasureErrorNo(object):
+    """ Error type for MeasureResult. """
+    NO_ERROR = 0              # No error
+    INSTANTIATION_ERROR = 1   # Errors happen when apply transform steps from init state
+                              # Errors happen when compiling code on host (e.g. tvm.build)
+    COMPILE_HOST = 2
+    COMPILE_DEVICE = 3        # Errors happen when compiling code on device
+                              # (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4        # Errors happen when run program on device
+    WRONG_ANSWER = 5          # Answer is wrong when compared to a reference output
+    BUILD_TIMEOUT = 6         # Timeout during compilation
+    RUN_TIMEOUT = 7           # Timeout during run
+    UNKNOWN_ERROR = 8         # Unknown error
+
+
+def make_error_msg():
+    """ Get the error message from traceback. """
+    error_msg = str(traceback.format_exc())
+    if len(error_msg) > MAX_ERROR_MSG_LEN:
+        error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
+            "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:]
+    return error_msg
+
+
+def local_build_worker(index):

Review comment:
       It is not clear what the purpose of this function is. We should also not pass arguments globally.
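
As a rough sketch of the alternative, the worker can receive everything it needs as parameters, bound with `functools.partial`, instead of reading a module-level variable such as `GLOBAL_BUILD_ARGUMENTS`. The names `build_one` and `build_all` are hypothetical and the build step is a placeholder; the PR's `NoDaemonPool` and the real build logic are not reproduced here. It also assumes the arguments can be pickled cheaply, which may be the reason the global was used in the first place.

```python
import multiprocessing
from functools import partial


def build_one(index, inputs, timeout):
    """Hypothetical worker: everything it needs arrives as parameters."""
    # Placeholder for the real build step; returns (index, description).
    return index, "built {} (timeout={}s)".format(inputs[index], timeout)


def build_all(inputs, timeout=15, n_parallel=2):
    """Build every input in parallel without touching module-level state."""
    # Bind the shared arguments once; the pool only supplies the index.
    worker = partial(build_one, inputs=inputs, timeout=timeout)
    with multiprocessing.Pool(n_parallel) as pool:
        return pool.map(worker, range(len(inputs)))


if __name__ == "__main__":
    print(build_all(["task_a", "task_b", "task_c"]))
```

With this shape the function signature documents its own inputs, which also helps with the first point about its purpose being unclear.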

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and

Review comment:
       ```suggestion
    * \brief The base class for search policy, including the abstract definition of search policy and
   ```

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and
+ * some other supporting structures.
+ *
+ * The basic schedule search process for Ansor is design to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+ * schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+ * and evolutionary search to seek for schedules with the best performance. Candidate schedules
+ * will be measured in the target hardware.
+ *
+ * \note Adding a new search policy.
+ * In design, there's no need for users to implement their own search policy, our formal search
+ * policy(will be brought later) should be enough to cover auto schedule generation for different
+ * ops/subgraphs, and in the meantime, a custom rule mechanism will be provided to enable
+ * user-defined template search. (which should play a same role as the current AutoTVM template)
+ * This guide is to help understand it better and incase some advanced users have special
+ * requirements.
+ * 1. The only funcion that must be implemented is Search(), the design principe for it is to be
+ * the entry of starting a schedule search and returns the best schedule get.
+ * 2. Imformations about the target ops/subgraphs can be acquired from SearchTask, this structure
+ * also contains HardwareParams which can be used to limit the search space. (For exp. limit the
+ * max vectorize size depending on the vector unit weight of a specific device)
+ * 3. SearchCallback provides more flexibility to do extra affairs during the search process.
+ * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get
+ * during the search process.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+
+#include <tvm/node/node.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "../search_task.h"
+
+namespace tvm {
+namespace ansor {
+
+class ProgramMeasurer;
+class SearchPolicyNode;
+
+/*!
+ * \brief Callback function to be called by the search process.
+ * This interface allows to do extra initializations before schedule search or extra
+ * check during/after the schedule search.
+ */
+class SearchCallbackNode : public Object {
+ public:
+  /*!
+   * \brief Run the registered callback function.
+   * \param policy A pointer to SearchPolicyNode.

Review comment:
       ```suggestion
      * \param policy A pointer to a SearchPolicyNode.
   ```

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and
+ * some other supporting structures.
+ *
+ * The basic schedule search process for Ansor is design to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+ * schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+ * and evolutionary search to seek for schedules with the best performance. Candidate schedules
+ * will be measured in the target hardware.
+ *
+ * \note Adding a new search policy.

Review comment:
       The writing here could be improved/clarified. 

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+    Parameters
+    ----------
+    num_cores : int
+        The number of device cores.
+    vector_unit_bytes : int
+        The width of vector units in bytes.
+    cache_line_bytes : int
+        The size of cache line in bytes.
+    max_unroll_vec : int
+        The max length of an axis to be unrolled or vectorized.
+    max_innermost_split_factor : int
+        The max split factor for the innermost tile.
+    """
+    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+                 max_unroll_vec, max_innermost_split_factor):
+        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+                                            vector_unit_bytes, cache_line_bytes,
+                                            max_unroll_vec, max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+    """ The meta-information of a search task.

Review comment:
       Can you clarify here? "Meta-information" is vague and non-informative to me.

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of

Review comment:
       ```suggestion
       """ The parameters of target hardware used to guide the search process of
   ```

##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    Computation declaration graph.

Review comment:
       Can we maybe add some explanation of how this is different than the many other DAGs that exist in TVM today? I think a common user failure mode is confusion between the various IRs and representations.

##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    Computation declaration graph.
+
+    Parameters
+    ----------
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
+    """
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + compute + ". Expect a string or list of Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+    def get_init_state(self):
+        """ Get init state of this ComputeDAG.
+
+        Returns
+        -------
+        state : State
+            The initial State without any transform steps.
+        """
+        return State(_ffi_api.ComputeDAGGetInitState(self), self)
+
+    def apply_steps_from_state(self, state):

Review comment:
       @zhiics can you clarify? 

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function

Review comment:
       I think we should avoid any kind of global configuration, and instead unify configuration into context options like the ones we use universally throughout the system. cc @tqchen
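
To make the suggestion concrete, here is a minimal sketch of scoped context options replacing module-level globals. It is not TVM's API: `BuildOptions`, `build_options`, `current_options`, and `describe_build` are hypothetical names, and the two fields are placeholders for whatever the build function actually needs.

```python
import contextlib
import contextvars
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class BuildOptions:
    """Hypothetical options that might otherwise live in module-level globals."""
    timeout: int = 10
    build_func: str = "tar"


# One context variable holds the active options; the default applies outside any scope.
_current = contextvars.ContextVar("build_options", default=BuildOptions())


def current_options():
    """Return the options active in the current context."""
    return _current.get()


@contextlib.contextmanager
def build_options(**overrides):
    """Override selected options for the duration of a `with` block."""
    token = _current.set(replace(_current.get(), **overrides))
    try:
        yield _current.get()
    finally:
        # Restore whatever was active before entering the block.
        _current.reset(token)


def describe_build():
    """Stand-in for a build function that reads its configuration from the context."""
    opts = current_options()
    return f"{opts.build_func}:{opts.timeout}"
```

Callers would wrap a build in `with build_options(timeout=30): ...`, and nested scopes restore the previous values on exit, so nothing leaks between measurements the way a mutated global can.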

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+    Parameters
+    ----------
+    num_cores : int
+        The number of device cores.
+    vector_unit_bytes : int
+        The width of vector units in bytes.
+    cache_line_bytes : int
+        The size of cache line in bytes.
+    max_unroll_vec : int
+        The max length of an axis to be unrolled or vectorized.
+    max_innermost_split_factor : int
+        The max split factor for the innermost tile.
+    """
+    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+                 max_unroll_vec, max_innermost_split_factor):
+        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+                                            vector_unit_bytes, cache_line_bytes,
+                                            max_unroll_vec, max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+    """ The meta-information of a search task.
+
+    Parameters
+    ----------
+    dag : ComputeDAG
+        The ComputeDAG for target compute declaration.
+    workload_key : str
+        The workload key for target compute declaration.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+    """
+    def __init__(self, dag, workload_key, target, target_host=None,
+                 hardware_params=None):
+        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+                                            workload_key, target, target_host,
+                                            hardware_params)
+
+
+@tvm._ffi.register_object("ansor.SearchPolicy")
+class SearchPolicy(Object):
+    """ The base class for search policy  """

Review comment:
       ```suggestion
       """ The base class of search policies. """
   ```

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.
+To enable flexible manipulation of the loop structures, we implemented a lightweight loop
+structure IR (Intermediate Representation) based on the original TVM IR but specifically
+for schedule search.
+
+We don't use the existing TVM IR but to extend a new Sketch IR on it is because:
+1. We want fast incremental change to the loop structures;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives.

Review comment:
       ```suggestion
   When the search is complete, we will lower this IR to TVM IR with TVM's schedule primitives.
   ```

##########
File path: python/tvm/ansor/auto_schedule.py
##########
@@ -0,0 +1,206 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+User interface for Ansor auto-scheduler.
+
+The basic schedule search process for Ansor is designed to be:
+`Program sampling` -> `Performance Tuning`.
+
+In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+and evolutionary search to seek for schedules with the best performance. Candidate schedules will
+be measured in the target hardware.
+"""
+
+import tvm._ffi
+from tvm.runtime import Object
+from .compute_dag import ComputeDAG
+from .measure import LocalBuilder, LocalRunner
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.HardwareParams")
+class HardwareParams(Object):
+    """ The parameters of target hardware, this is used to guide the search process of
+    SearchPolicy.
+
+    TODO(...): This is considering to merge with the new Target:
+    https://discuss.tvm.ai/t/rfc-tvm-target-specification/6844
+
+    Parameters
+    ----------
+    num_cores : int
+        The number of device cores.
+    vector_unit_bytes : int
+        The width of vector units in bytes.
+    cache_line_bytes : int
+        The size of cache line in bytes.
+    max_unroll_vec : int
+        The max length of an axis to be unrolled or vectorized.
+    max_innermost_split_factor : int
+        The max split factor for the innermost tile.
+    """
+    def __init__(self, num_cores, vector_unit_bytes, cache_line_bytes,
+                 max_unroll_vec, max_innermost_split_factor):
+        self.__init_handle_by_constructor__(_ffi_api.HardwareParams, num_cores,
+                                            vector_unit_bytes, cache_line_bytes,
+                                            max_unroll_vec, max_innermost_split_factor)
+
+
+@tvm._ffi.register_object("ansor.SearchTask")
+class SearchTask(Object):
+    """ The meta-information of a search task.
+
+    Parameters
+    ----------
+    dag : ComputeDAG
+        The ComputeDAG for target compute declaration.
+    workload_key : str
+        The workload key for target compute declaration.
+    target : tvm.target.Target
+        The target device of this search task.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this search task.
+    hardware_params : Optional[HardwareParams]
+        Hardware parameters used in this search task.
+    """
+    def __init__(self, dag, workload_key, target, target_host=None,
+                 hardware_params=None):
+        self.__init_handle_by_constructor__(_ffi_api.SearchTask, dag,
+                                            workload_key, target, target_host,
+                                            hardware_params)
+
+
+@tvm._ffi.register_object("ansor.SearchPolicy")
+class SearchPolicy(Object):
+    """ The base class for search policy  """
+
+
+@tvm._ffi.register_object("ansor.EmptyPolicy")
+class EmptyPolicy(SearchPolicy):
+    """ This is an example empty search policy which will always generate
+    the init state of target ComputeDAG.
+    """
+    def __init__(self):
+        self.__init_handle_by_constructor__(_ffi_api.EmptyPolicy)
+
+
+@tvm._ffi.register_object("ansor.TuneOption")
+class TuneOption(Object):
+    """ This controls the options of performance tuning.
+
+    Parameters
+    ----------
+    num_measure_trials: int = 0
+      The number of total schedule measure trials.
+      Ansor takes `num_measure_trials` state for measuring in total, and finally gets the best
+      schedule among them.
+      With `num_measure_trials` == 0, Ansor will do the schedule search but don't involve
+      measurement, this can be used if we want to quickly get a runnable schedule without
+      performance tuning.
+    early_stopping: int = -1
+      Stops early the tuning if no improvement get after n measurements.
+    num_measures_per_round: int = 64
+      The number of programs to be measured at each search round.
+      The whole schedule search process is designed to have several rounds to try a total
+      `num_measure_trials` schedules.
+      We have: `num_search_rounds` = `num_measure_trials` // `num_measures_per_round`
+    verbose: int = 1
+      Verbosity level. 0 for silent, 1 to output information during schedule search.
+    builder: Union[Builder, str] = 'local'
+      Builder which builds the program.
+    runner: Union[Runner, str] = 'local'
+      Runner which runs the program and measures time costs.
+    measure_callbacks: Optional[List[MeasureCallback]]
+      Callback functions called after each measure.
+      Candidates:
+        - ansor.LogToFile
+    pre_search_callbacks: Optional[List[SearchCallback]]
+      Callback functions called before the search process.
+      Candidates:
+        - ansor.PreloadMeasuredStates
+        - ansor.PreloadCustomSketchRule
+        TODO(jcf94): Add these implementation in later PRs.
+    """
+    def __init__(self, num_measure_trials=0, early_stopping=-1, num_measures_per_round=64,
+                 verbose=1, builder='local', runner='local', measure_callbacks=None,
+                 pre_search_callbacks=None):
+        if isinstance(builder, str):
+            if builder == 'local':
+                builder = LocalBuilder()
+            else:
+                raise ValueError("Invalid builder: " + builder)
+
+        if isinstance(runner, str):
+            if runner == 'local':
+                runner = LocalRunner()
+            else:
+                raise ValueError("Invalid runner: " + runner)
+
+        measure_callbacks = [] if measure_callbacks is None else measure_callbacks
+        pre_search_callbacks = [] if pre_search_callbacks is None else pre_search_callbacks
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.TuneOption, num_measure_trials, early_stopping, num_measures_per_round,
+            verbose, builder, runner, measure_callbacks, pre_search_callbacks)
+
+
+def auto_schedule(task, target, target_host=None, search_policy='default',
+                  hardware_params=None, tune_option=None):
+    """ Do auto scheduling for a computation declaration.
+
+    The task parameter can be a `string` as workload_key, or directly
+    passing a `SearchTask` as input.
+
+    Parameters
+    ----------
+    task : Union[SearchTask, str]
+        The target search task or workload key.
+    target : tvm.target.Target
+        The target device of this schedule search.
+    target_host : Optional[tvm.target.Target]
+        The target host device of this schedule search.
+    search_policy : Union[SearchPolicy, str] = 'default'
+        The search policy to be used for schedule search.
+    hardware_params : Optional[HardwareParams]
+        The hardware parameters of this schedule search.
+    tune_option : Optional[TuneOption]
+        Tuning and measurement options.
+
+    Returns
+    -------
+        A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+    """
+    if isinstance(search_policy, str):
+        if search_policy == 'default':
+            # TODO(jcf94): This is an example policy for minimum system, will be upgrated to
+            # formal search policy later.
+            search_policy = EmptyPolicy()
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+    tune_option = tune_option if tune_option else TuneOption()
+
+    if isinstance(task, str):
+        dag = ComputeDAG(task)
+        task = SearchTask(dag, task, target, target_host, hardware_params)
+    elif not isinstance(task, SearchTask):
+        raise ValueError("Invalid task: " + task + ". Expect a string or SearchTask")

Review comment:
       ```suggestion
           raise ValueError("Invalid task: " + str(task) + ". `ansor.auto_schedule` expects a `str` or `SearchTask`.")
   ```
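
One more thing to watch in this branch: it is only reached when `task` is not a `str`, so concatenating `task` directly onto a string raises `TypeError` before the intended `ValueError` is ever constructed. A minimal sketch of the difference, where `FakeTask`, `check_task_concat`, and `check_task_fstring` are hypothetical stand-ins and not part of the PR:

```python
class FakeTask:
    """Hypothetical stand-in for an object that is neither `str` nor `SearchTask`."""


def check_task_concat(task):
    # Mirrors the concatenation in the diff: str + non-str fails with TypeError.
    if not isinstance(task, str):
        raise ValueError("Invalid task: " + task + ". Expect a string or SearchTask")


def check_task_fstring(task):
    # Formatting with !r works for any object, so the ValueError is actually raised.
    if not isinstance(task, str):
        raise ValueError(f"Invalid task: {task!r}. Expected a `str` or `SearchTask`.")
```

The same pattern appears in the `ComputeDAG` constructor's `"Invalid compute: " + compute` message, which is also reached only for non-string input.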

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.
+To enable flexible manipulation of the loop structures, we implemented a lightweight loop
+structure IR (Intermediate Representation) based on the original TVM IR but specifically
+for schedule search.
+
+We don't use the existing TVM IR but to extend a new Sketch IR on it is because:

Review comment:
       It might be better to combine these two paragraphs into one that states more clearly: "We have designed a new loop IR specifically for schedule search. Our loop IR is similar to TVM's TIR but has three important additions: ...."
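
For readers of the docstring, a toy model may help show why a serializable transform history matters. This is only a sketch under loud assumptions: `ToyState`, `SplitStep`, `dump_history`, and `replay` are hypothetical names, a "loop structure" is reduced to a list of loop extents, and none of this reflects Ansor's actual classes.

```python
import json
from dataclasses import dataclass, field


@dataclass(frozen=True)
class SplitStep:
    """Hypothetical record: split loop `axis` into (outer, inner) by `factor`."""
    axis: int
    factor: int


@dataclass
class ToyState:
    """A loop structure (extents, outermost first) plus the steps that produced it."""
    loops: list
    history: list = field(default_factory=list)

    def split(self, axis, factor):
        extent = self.loops[axis]
        if extent % factor != 0:
            raise ValueError("toy model only handles exact splits")
        # Replace one loop with an outer loop and an inner loop of length `factor`.
        self.loops[axis:axis + 1] = [extent // factor, factor]
        self.history.append(SplitStep(axis, factor))
        return self


def dump_history(state):
    """Serialize the transform history so it can be stored or mutated later."""
    return json.dumps([[s.axis, s.factor] for s in state.history])


def replay(init_loops, serialized, upto=None):
    """Rebuild a state by replaying a (prefix of a) serialized history."""
    state = ToyState(list(init_loops))
    for axis, factor in json.loads(serialized)[:upto]:
        state.split(axis, factor)
    return state
```

Replaying the full history reproduces the state, and replaying a prefix (`upto=1`) gives a simple form of backtracking; mutation would amount to editing one recorded step before replaying.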

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.
+To enable flexible manipulation of the loop structures, we implemented a lightweight loop

Review comment:
       `flexible` doesn't mean anything in a technical setting imo; it would be good to clarify why we need a new IR here.

##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    Computation declaration graph.
+
+    Parameters
+    ----------
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
+    """
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + compute + ". Expect a string or list of Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+    def get_init_state(self):
+        """ Get init state of this ComputeDAG.
+
+        Returns
+        -------
+        state : State
+            The initial State without any transform steps.
+        """
+        return State(self.init_state, self)
+
+    def apply_steps_from_state(self, state):
+        """
+        Apply transform steps according to the history of a State.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+            A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+        """
+        state_obj = state if isinstance(state, StateObject) else state.state_object
+        return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
+
+    def print_python_code_from_state(self, state):
+        """
+        Print transform steps in the history of a State as TVM's python schedule primitive.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+        str : Str
+            The Python schedule code.
+        """
+        state_obj = state if isinstance(state, StateObject) else state.state_object
+        return _ffi_api.ComputeDAGPrintPythonCodeFromState(self, state_obj)
+
+    def infer_bound_from_state(self, state):
+        """
+        Infer bound for a state using TVM schedule.

Review comment:
       I believe I know what you are talking about, but `infer bound` is relatively vague; it is worth explaining to users more clearly what this is doing.
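   My understanding, which may not match exactly what the C++ side does, is that this fills in the range of every loop iterator in the state. For example, after a split the outer loop's extent is the ceiling of the original extent divided by the factor. A tiny illustrative sketch (`split_extents` is a hypothetical helper of mine, not part of this PR):

```python
def split_extents(extent, factor):
    """Return (outer, inner) loop extents after splitting a loop by `factor`."""
    # Ceiling division: the outer loop must cover the whole original extent,
    # even when `factor` does not divide it evenly.
    outer = (extent + factor - 1) // factor
    return outer, factor


# A loop of extent 1024 split by 32 becomes a 32 x 32 loop nest.
outer, inner = split_extents(1024, 32)
```

   If that is roughly what the method does, a sentence to that effect in the docstring would help.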

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ Base class for measurement callback function. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):

Review comment:
       Yeah, I agree. I also personally think `build` is nonsensical in a compiler, and we should strive to remove it from TVM, period: it carries no technical meaning in the compiler world, and it is vague enough to everyone else that it doesn't tell us what operation is actually occurring.

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.

Review comment:
       ```suggestion
   and a history of transformations used to construct it. 
   ```

##########
File path: python/tvm/ansor/workload_registry.py
##########
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a compute dag).
+The format of the string is `[func_name, [args...]]`.
+The dag should be the return value of this `func_name(*args)`.
+
+Rationale: The workload is actually a compute dag defined by tvm dsl. But serializing compute dags
+and matching them efficiently is not easy. Therefore, we use the above string to encode a compute
+dag.
+These strings are efficient for serialization/matching and wont' be too long.
+When we need the dag, we decode the string and call the function, which will return the dag.
+"""
+
+import pickle
+import json
+
+import tvm._ffi
+from .utils import serialize_args, deserialize_args
+
+WORKLOAD_FUNC_REGISTRY = {}
+
+
+def register_workload_by_func(func):
+    """ Register a workload by generation function.

Review comment:
       This docstring doesn't make sense to me. Can you explain what "by generation function" means?
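   For reference, here is how I read the mechanism described in the module docstring, as a minimal sketch. All names below (`REGISTRY`, `register_workload`, `make_workload_key`, `workload_from_key`, `matmul`) are hypothetical, and a dict of shapes stands in for the real compute DAG:

```python
import json

# Hypothetical stand-in for the registry in this file: maps name -> function.
REGISTRY = {}


def register_workload(func):
    """Register `func` under its own name so a workload key can refer to it."""
    REGISTRY[func.__name__] = func
    return func


def make_workload_key(func_name, args):
    """Encode a workload as the JSON string `[func_name, [args...]]`."""
    return json.dumps([func_name, list(args)])


def workload_from_key(key):
    """Decode the key and call the registered function to rebuild the workload."""
    func_name, args = json.loads(key)
    return REGISTRY[func_name](*args)


@register_workload
def matmul(n, m, k):
    # A real generation function would return te.Tensor objects; a dict of
    # shapes keeps this sketch free of TVM imports.
    return {"A": (n, k), "B": (k, m), "C": (n, m)}


key = make_workload_key("matmul", (128, 64, 32))
dag = workload_from_key(key)
```

   If this matches the intent, the docstring could say that the function is registered by name so that a workload key can later be decoded back into the DAG.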

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common utilities for ansor"""

Review comment:
       ```suggestion
   """Common utilities for ansor."""
   ```

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ Base class for measurement callback function. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.Builder")
+class Builder(Object):
+    """ Base class of Builder. """
+
+    def build(self, measure_inputs, verbose=1):

Review comment:
       See the comment above about the name `build`.

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search. A state consists a current loop structure
+and the transform history to reach its current loop structure.
+To enable flexible manipulation of the loop structures, we implemented a lightweight loop
+structure IR (Intermediate Representation) based on the original TVM IR but specifically
+for schedule search.
+
+We don't use the existing TVM IR but to extend a new Sketch IR on it is because:
+1. We want fast incremental change to the loop structures;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+After the search is done, we will lower this IR to TVM IR with TVM's schedule primitives.
+Because we share a lot common objects during search, the transformation is implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """ The internal State object """
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and the history steps to reach this state.
+
+    Each State corresponds to a specific schedule for the target ComputeDAG.
+
+    Parameters
+    ----------
+    state_object : StateObject
+        The target StateObject, corresponding to C++ internal State object.
+    dag : ComputeDAG
+        The original target ComputeDAG of this State.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with copy-on-write property
+    """
+    def __init__(self, state_object, dag):
+        self.state_object = state_object
+        self.compute_dag = dag
+
+        self.stages_cache = None  # A list to cache all stages

Review comment:
       It's not clear to me that the cache should live in the Python interface; it seems like almost all operations on this object just interact with the cache and then call into C++.

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ Base class for measurement callback function. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.Builder")
+class Builder(Object):

Review comment:
       ```suggestion
   class ProgramBuilder(Object):
   ```
   We should be more specific here.

##########
File path: python/tvm/ansor/workload_registry.py
##########
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a compute dag).
+The format of the string is `[func_name, [args...]]`.

Review comment:
       If we are using a serialization format, we should version and track it; unversioned serialization is a bug waiting to happen imo.
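   One possible shape for this, as a sketch only: put an explicit version field in the encoded workload and reject anything the reader does not understand. The dict layout and the names `FORMAT_VERSION`, `encode_workload` and `decode_workload` are assumptions of mine, not the format used in this PR:

```python
import json

# Hypothetical version number; bump it on any incompatible format change.
FORMAT_VERSION = 1


def encode_workload(func_name, args):
    """Serialize a workload together with the format version."""
    return json.dumps(
        {"version": FORMAT_VERSION, "func_name": func_name, "args": list(args)}
    )


def decode_workload(text):
    """Parse a workload string, rejecting versions this code does not understand."""
    record = json.loads(text)
    # An unversioned `[func_name, [args...]]` list has no version field at all.
    version = record.get("version") if isinstance(record, dict) else None
    if version != FORMAT_VERSION:
        raise ValueError("unsupported workload format version: %r" % (version,))
    return record["func_name"], record["args"]


text = encode_workload("matmul", (128, 64, 32))
name, args = decode_workload(text)
```

   With a check like this, an old or foreign log fails loudly at load time instead of being silently misread.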

##########
File path: python/tvm/ansor/serialization.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Serialization and other I/O support for tuning logs (measurement records)"""
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    A measurement callback that writes measurement records into a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name for this callback to write log to.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    """
+    Reader of the json log file.
+
+    Parameters
+    ----------
+    filename : str = "ansor_tuning.json"
+        File name for this reader to load log from.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_lines=-1, skip_lines=0):
+        """ Read multiple lines from the log file.
+
+        Parameters
+        ----------
+        max_lines : int = -1
+            The maximum number of lines. -1 means to read all lines.
+        skip_lines : int = 0
+            Skip the first n lines.
+
+        Returns
+        -------
+        inputs : List[MeasureInput]
+            The MeasureInputs loaded from the log file.
+        results : List[MeasureResult]
+            The MeasureResults loaded from the log file.
+        """
+        inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines)
+        return inputs, results
+
+    def __iter__(self):
+        while True:
+            ret = _ffi_api.LogReaderReadNext(self)
+            if not ret:
+                break
+            yield ret[0], ret[1]  # (input, result)
+
+
+def load_from_file(filename):
+    """
+    Load measurement records from a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name to load log from.
+
+    Returns
+    -------
+    logs : List[MeasureInput, MeasureResult]
+    """
+    return zip(*LogReader(filename).read_lines())
+
+
+def append_measure_records_to_file(filename, inputs, results):

Review comment:
       👍 

##########
File path: python/tvm/ansor/serialization.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Serialization and other I/O support for tuning logs (measurement records)"""
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):

Review comment:
       A better question is: do we need classes for this at all? Could we not just use reader/writer functions? cc @tqchen
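   To illustrate what I mean by plain functions, here is a rough sketch that assumes one JSON object per line. `write_records` and `read_records` are hypothetical names, and the record layout is made up; only the `max_lines`/`skip_lines` parameters are borrowed from `read_lines` in this file:

```python
import json
import os
import tempfile


def write_records(filename, records):
    """Append each record to `filename` as one JSON object per line."""
    with open(filename, "a") as f:
        for record in records:
            f.write(json.dumps(record) + "\n")


def read_records(filename, max_lines=-1, skip_lines=0):
    """Yield records from `filename`; max_lines=-1 means read everything."""
    with open(filename) as f:
        for i, line in enumerate(f):
            if i < skip_lines:
                continue
            if max_lines >= 0 and i - skip_lines >= max_lines:
                break
            yield json.loads(line)


# Round trip through a temporary file: skip the first record, read two.
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "tuning.json")
    write_records(path, [{"cost": 1.5}, {"cost": 0.9}, {"cost": 1.1}])
    loaded = list(read_records(path, max_lines=2, skip_lines=1))
```

   The measurement callback could then just be a closure over `write_records`, with no object to register.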

##########
File path: src/ansor/serialization.h
##########
@@ -0,0 +1,136 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/serialization.h
+ * \brief Json serialization format for dumping and loading tuning records.
+ */
+
+#ifndef TVM_ANSOR_SERIALIZATION_H_
+#define TVM_ANSOR_SERIALIZATION_H_
+
+#include <fstream>
+#include <string>
+#include <utility>
+
+#include "measure.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Callback for logging the input and results of measurements to file */

Review comment:
       Re: my Python comment, it is probably more flexible to just have reader/writer callbacks.

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Common utilities for ansor"""
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):
+        opt = Sequential([Simplify()])
+        exp = opt(exp)
+    if not isinstance(exp, (expr.IntImm)):
+        raise ValueError("Expect value to be constant int")
+    return exp.value
+
+
+def get_const_tuple(in_tuple):
+    """Verifies input tuple is IntImm, returns tuple of int.
+
+    Parameters
+    ----------
+    in_tuple : tuple of Expr
+        The input.
+
+    Returns
+    -------
+    out_tuple : tuple of int
+        The output.
+    """
+    return tuple(get_const_int(x) for x in in_tuple)
+
+
+
+def list_to_tuple(x):
+    """ Convert a list to a tuple recursively. """
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args):
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple.
+    Currently this is mainly used for tvm.tensor.Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, Tensor):
+            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+        elif isinstance(t, list):
+            t = list_to_tuple(t)
+
+        assert isinstance(t, Hashable), str(t) + " is not hashable"
+        ret.append(t)
+
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`"""
+    ret = []
+    for t in args:
+        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+class NoDaemonProcess(multiprocessing.Process):
+    @property
+    def daemon(self):
+        return False
+
+    @daemon.setter
+    def daemon(self, value):
+        pass
+
+
+class NoDaemonContext(type(multiprocessing.get_context())):
+    Process = NoDaemonProcess
+
+
+class NoDaemonPool(multiprocessing.pool.Pool):
+    """A no daemon pool version of multiprocessing.Pool.

Review comment:
       Can we improve this comment? 
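       To make the request concrete, a fuller docstring could say why the pool exists at all. Below is a minimal sketch, not the PR's code: it follows the widely used non-daemonic pool recipe, and the `__init__` override, the stated use case, and all docstring wording are assumptions about what the truncated class is for.

```python
import multiprocessing
import multiprocessing.pool


class NoDaemonProcess(multiprocessing.Process):
    """A process whose ``daemon`` flag always reads False.

    multiprocessing.Pool marks its workers as daemonic, and daemonic
    processes are not allowed to create child processes. Pinning the
    flag to False lets a pool worker start processes of its own.
    """

    @property
    def daemon(self):
        return False

    @daemon.setter
    def daemon(self, value):
        # Pool assigns daemon=True to each worker; silently ignore it.
        pass


class NoDaemonContext(type(multiprocessing.get_context())):
    """The default multiprocessing context, but producing NoDaemonProcess."""

    Process = NoDaemonProcess


class NoDaemonPool(multiprocessing.pool.Pool):
    """A multiprocessing.Pool whose workers are non-daemonic.

    Use this instead of multiprocessing.Pool when the pooled function
    itself needs to start child processes (assumed use case). As with
    any pool, close() and join() it when done.
    """

    def __init__(self, *args, **kwargs):
        # Assumption: force the non-daemonic context, as in the common recipe.
        kwargs["context"] = NoDaemonContext()
        super().__init__(*args, **kwargs)
```

       The point is that each docstring states the constraint being worked around (daemonic workers cannot have children) rather than restating the class name.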

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,507 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;

Review comment:
       Please improve the comments in this file; vague comments are more likely to bitrot or mislead readers of the code.

##########
File path: src/ansor/utils.cc
##########
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/utils.cc
+ * \brief Common utilities.
+ */
+
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+NullStream& NullStream::Global() {
+  static NullStream stream;
+  return stream;
+}
+
+ThreadPool& ThreadPool::Global() {
+  static ThreadPool* pool = new ThreadPool();
+  static int ct = 0;
+
+  ct = (ct + 1) % ThreadPool::REFRESH_EVERY;

Review comment:
       Yes, that would be good. cc @tqchen: having multiple thread pools scares me a bit. Thoughts?

##########
File path: src/ansor/auto_schedule.h
##########
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/auto_schedule.h
+ * \brief The user interface of the Ansor auto-scheduler. This is the entry structure to get
+ * schedule search requirements from upper level (Python API), and returns a high performance
+ * schedule after search process.
+ */
+
+#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_
+#define TVM_ANSOR_AUTO_SCHEDULE_H_
+
+#include <utility>
+
+#include "measure.h"
+#include "search_policy/search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Tuning and measurement options. */

Review comment:
       Same comment as the one I placed above on the Python class name.

##########
File path: python/tvm/ansor/serialization.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Serialization and other I/O support for tuning logs (measurement records)"""

Review comment:
       I'd prefer `record`, `configuration`, or some other distinct term. We should use this as an opportunity to move away from the `log file` terminology: it's incredibly confusing when talking to people outside of TVM, who assume you mean traditional service logging. It would be good to have a distinct name for the results of AutoTVM/Ansor, imo.

##########
File path: src/ansor/search_task.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_task.h
+ * \brief Meta information and hardware parameters for a search task.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_TASK_H_
+#define TVM_ANSOR_SEARCH_TASK_H_
+
+#include <tvm/target/target.h>
+
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+class HardwareParams;
+
+/*! \brief Hardware related parameters */
+class HardwareParamsNode : public Object {
+ public:
+  /*! \brief The number of cores. */
+  int num_cores;
+  /*! \brief The width of vector units in bytes. */
+  int vector_unit_bytes;
+  /*! \brief The size of cache line in bytes. */
+  int cache_line_bytes;
+  /*! \brief The max length of an axis to be unrolled or vectorized. */
+  int max_unroll_vec;
+  /*! \brief The max split factor for the innermost tile. */
+  int max_innermost_split_factor;
+
+  // Limitation params for GPU
+
+  /*! \brief The max shared memory per block. */
+  int max_shared_memory_per_block{INT32_MAX};
+  /*! \brief The max register memory per block. */
+  int max_registers_per_block{INT32_MAX};
+  /*! \brief The max threads per block. */
+  int max_threads_per_block{INT32_MAX};
+  /*! \brief The max vthread extent. */
+  int max_vthread_extent{INT32_MAX};
+  /*! \brief The thread numbers of a warp. */
+  int warp_size{INT32_MAX};
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_cores", &num_cores);
+    v->Visit("vector_unit_bytes", &vector_unit_bytes);
+    v->Visit("cache_line_bytes", &cache_line_bytes);
+    v->Visit("max_unroll_vec", &max_unroll_vec);
+    v->Visit("max_innermost_split_factor", &max_innermost_split_factor);
+    v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
+    v->Visit("max_registers_per_block", &max_registers_per_block);
+    v->Visit("max_threads_per_block", &max_threads_per_block);
+    v->Visit("max_vthread_extent", &max_vthread_extent);
+    v->Visit("warp_size", &warp_size);
+  }
+
+  /*!
+   * \brief Get the default hardware params.
+   * \param target A `tvm.target`.
+   * \param target_host A `tvm.target` for host device.
+   * \return A HardwareParams object.
+   */
+  static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host);
+
+  static constexpr const char* _type_key = "ansor.HardwareParams";
+  TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to HardwareParamsNode.
+ * \sa HardwareParamsNode
+ */
+class HardwareParams : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param num_cores The number of cores.
+   * \param vector_unit_bytes The width of vector units in bytes.
+   * \param cache_line_bytes The size of cache line in bytes.
+   * \param max_unroll_vec The max length of an axis to be unrolled or vectorized.
+   * \param max_innermost_split_factor The max split factor for the innermost tile.
+   */
+  HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec,
+                 int max_innermost_split_factor);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
+};
+
+/*! \brief Meta-info for a search task */
+class SearchTaskNode : public Object {
+ public:
+  /*! \brief The ComputeDAG for target compute declaration. */
+  ComputeDAG compute_dag;
+  /*! \brief The workload key for target compute declaration. */
+  String workload_key;
+  /*! \brief The target device of this search task. */
+  Target target;
+  /*! \brief The target host device of this search task. */
+  Target target_host;
+  /*! \brief Hardware parameters used in this search task. */
+  HardwareParams hardware_params;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("compute_dag", &compute_dag);
+    v->Visit("workload_key", &workload_key);
+    v->Visit("target", &target);
+    v->Visit("target_host", &target_host);
+    v->Visit("hardware_params", &hardware_params);
+  }
+
+  static constexpr const char* _type_key = "ansor.SearchTask";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchTaskNode.
+ * \sa SearchTaskNode
+ */
+class SearchTask : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param compute_dag The ComputeDAG for target compute declaration.

Review comment:
       The phrase "target compute declaration" in this comment isn't clear.

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and
+ * some other supporting structures.
+ *
+ * The basic schedule search process for Ansor is design to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+ * schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+ * and evolutionary search to seek for schedules with the best performance. Candidate schedules
+ * will be measured in the target hardware.
+ *
+ * \note Adding a new search policy.
+ * In design, there's no need for users to implement their own search policy, our formal search
+ * policy(will be brought later) should be enough to cover auto schedule generation for different
+ * ops/subgraphs, and in the meantime, a custom rule mechanism will be provided to enable
+ * user-defined template search. (which should play a same role as the current AutoTVM template)
+ * This guide is to help understand it better and incase some advanced users have special
+ * requirements.
+ * 1. The only funcion that must be implemented is Search(), the design principe for it is to be
+ * the entry of starting a schedule search and returns the best schedule get.
+ * 2. Imformations about the target ops/subgraphs can be acquired from SearchTask, this structure
+ * also contains HardwareParams which can be used to limit the search space. (For exp. limit the
+ * max vectorize size depending on the vector unit weight of a specific device)
+ * 3. SearchCallback provides more flexibility to do extra affairs during the search process.
+ * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get
+ * during the search process.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+
+#include <tvm/node/node.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "../search_task.h"
+
+namespace tvm {
+namespace ansor {
+
+class ProgramMeasurer;
+class SearchPolicyNode;
+
+/*!
+ * \brief Callback function to be called by the search process.
+ * This interface allows to do extra initializations before schedule search or extra
+ * check during/after the schedule search.
+ */
+class SearchCallbackNode : public Object {
+ public:
+  /*!
+   * \brief Run the registered callback function.
+   * \param policy A pointer to SearchPolicyNode.
+   */
+  virtual void Callback(SearchPolicyNode* policy) = 0;
+
+  static constexpr const char* _type_key = "ansor.SearchCallback";
+  TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchCallbackNode.
+ * \sa SearchCallbackNode
+ */
+class SearchCallback : public ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
+};
+
+/*!
+ * \brief The base class for search policy.
+ */
+class SearchPolicyNode : public Object {
+ public:
+  /*! \brief The current search task. */
+  SearchTask cur_task;
+  /*!
+   * \brief Verbose level to control the screen output during schedule search.
+   * 0 for silent, 1 to output information.
+   */
+  int verbose;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("cur_task", &cur_task);
+    v->Visit("verbose", &verbose);
+  }
+
+  /*!
+   * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
+   * get during the search process.
+   * \param task The target search task.
+   * \param num_measure_trials Total schedules to be tried during this search.
+   * \param early_stopping Early stop if no better schedule is found.
+   * \param num_measures_per_round Max measure batch in one search round.
+   * \param verbose Verbose level. 0 for silent, 1 to output information during schedule search.
+   * \param measurer A ProgramMeasurer which packs Builder & Runner inside.
+   * \param pre_search_callbacks SearchCallback to be called before schedule search.
+   * \return The best state get.
+   */
+  virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
+                       int num_measures_per_round, int verbose, ProgramMeasurer measurer,
+                       Array<SearchCallback> pre_search_callbacks) = 0;
+
+  /*!
+   * \brief Call SearchCallback with the current SearchPolicyNode
+   * \param callbacks SearchCallback to be called.
+   */
+  void RunCallbacks(const Array<SearchCallback>& callbacks);
+
+  static constexpr const char* _type_key = "ansor.SearchPolicy";
+  TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
+
+ protected:
+  /*!
+   * \brief The set of already measured states.
+   * We store the string format for redundancy check.

Review comment:
       More details, please.
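       To illustrate the kind of detail that would help: presumably the set holds each measured state's string form so that a candidate can be skipped when an identical one has already been measured. A minimal Python sketch of that idea follows; `MeasuredStates`, `record`, and the use of `str(state)` as the key are hypothetical stand-ins, not the PR's implementation.

```python
class MeasuredStates:
    """Tracks measured states and rejects duplicates by string form."""

    def __init__(self):
        self._seen = set()        # string form of every measured state
        self.states = []          # measured states, in measurement order
        self.throughputs = []     # throughputs[i] belongs to states[i]

    def record(self, state, throughput):
        """Store a measurement; return False if the state was already seen."""
        key = str(state)  # assumption: equal states print identically
        if key in self._seen:
            return False
        self._seen.add(key)
        self.states.append(state)
        self.throughputs.append(throughput)
        return True


measured = MeasuredStates()
first = measured.record(("split", 0, 8), 1.5)   # new state, stored
again = measured.record(("split", 0, 8), 1.7)   # same string form, skipped
print(first, again, len(measured.states))
```

       A comment that states what the key is and what counts as "redundant" would answer most of this.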

##########
File path: src/ansor/compute_dag.h
##########
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.h
+ * \brief Compute declaration graph and its related analysis tools.
+ * ComputeDAG is also responsible for the interaction with the original TVM schedule system, to
+ * apply state to a runable TVM schedule or provide the schedule Python code.
+ */
+
+#ifndef TVM_ANSOR_COMPUTE_DAG_H_
+#define TVM_ANSOR_COMPUTE_DAG_H_
+
+#include <tvm/te/schedule.h>
+
+#include <utility>
+
+#include "loop_state.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Computation declaration graph. */
+class ComputeDAGNode : public Object {
+ public:
+  /*! \brief Input and output tensors. */
+  Array<te::Tensor> tensors;
+  /*! \brief All related operations in topo order. */
+  Array<te::Operation> ops;
+  /*! \brief Number of total float operations for this ComputeDAG. */
+  double flop_ct;
+  /*! \brief The initial state without any transform steps. */
+  State init_state;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("tensors", &tensors);
+    v->Visit("ops", &ops);
+    v->Visit("flop_ct", &flop_ct);
+    v->Visit("init_state", &init_state);
+  }
+
+  static constexpr const char* _type_key = "ansor.ComputeDAG";
+  TVM_DECLARE_FINAL_OBJECT_INFO(ComputeDAGNode, Object);
+};
+
+/*!
+ * \brief Managed reference to ComputeDAGNode.
+ * \sa ComputeDAGNode
+ */
+class ComputeDAG : public ObjectRef {
+ public:
+  /*! \brief The constructor.
+   * \param tensors `te::Tensor`s for a compute declaration.
+   */
+  explicit ComputeDAG(Array<te::Tensor> tensors);
+
+  /*!
+   * \brief Apply transform steps to the init state of this DAG, and get the
+   * equivalent `tvm::schedule`.
+   * \param transform_steps Transform steps of the target state.
+   * \return The return values can be used as arguments to `tvm.build` or `tvm.lower`.
+   */
+  std::pair<te::Schedule, Array<te::Tensor> > ApplySteps(const Array<Step>& transform_steps) const;
+  /*!
+   * \brief Print transform steps as equivalent python schedule API.
+   * \param transform_steps Transform steps of the target state.
+   * \return Python schedule code.
+   */
+  String PrintStepsAsPython(const Array<Step>& transform_steps) const;
+
+  /*!
+   * \brief Replay the transform steps and call ir_pass::InferBound to fill correct bound
+   * information.
+   * State api supports to define a split step with its split factor to be a blank placeholder,
+   * so sometimes we may get a State will incomplete iterator extent information.
+   * And another situation is after some steps (for exp. compute_at), it may be hard to track the
+   * extent change of all iterators.
+   * We perform infer bound using TVM schedule and fill the State with those informations. After
+   * applying this methods, the State is guaranteed to have complete interator extent information.
+   * \param transform_steps Transform steps of the target state.
+   * \return The State after inferbound.
+   */
+  State ReplayAndInferBound(const Array<Step>& transform_steps) const;
+  /*!
+   * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
+   * \param state The target state.
+   * \return The State after inferbound.
+   */
+  State InferBound(const State& state) const;
+  /*!
+   * \brief Fill the correct bound information for a list of given states.
+   * Return the new states inplace.
+   * \param states A pointer to a State vector, States are updated inplace.
+   */
+  void InferBound(Array<State>* states) const;
+
+  TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
+
+ private:
+  /*!
+   * \brief Internal common parts for replaying steps. This is the key method to apply steps to

Review comment:
       Can you clarify why this is split out from `ReplayAndInferBound`? Does it need to be part of the interface?

##########
File path: src/ansor/search_task.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_task.h
+ * \brief Meta information and hardware parameters for a search task.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_TASK_H_
+#define TVM_ANSOR_SEARCH_TASK_H_
+
+#include <tvm/target/target.h>
+
+#include "compute_dag.h"
+
+namespace tvm {
+namespace ansor {
+
+class HardwareParams;
+
+/*! \brief Hardware related parameters */
+class HardwareParamsNode : public Object {
+ public:
+  /*! \brief The number of cores. */
+  int num_cores;
+  /*! \brief The width of vector units in bytes. */
+  int vector_unit_bytes;
+  /*! \brief The size of cache line in bytes. */
+  int cache_line_bytes;
+  /*! \brief The max length of an axis to be unrolled or vectorized. */
+  int max_unroll_vec;
+  /*! \brief The max split factor for the innermost tile. */
+  int max_innermost_split_factor;
+
+  // Limitation params for GPU
+
+  /*! \brief The max shared memory per block. */
+  int max_shared_memory_per_block{INT32_MAX};
+  /*! \brief The max register memory per block. */
+  int max_registers_per_block{INT32_MAX};
+  /*! \brief The max threads per block. */
+  int max_threads_per_block{INT32_MAX};
+  /*! \brief The max vthread extent. */
+  int max_vthread_extent{INT32_MAX};
+  /*! \brief The thread numbers of a warp. */
+  int warp_size{INT32_MAX};
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_cores", &num_cores);
+    v->Visit("vector_unit_bytes", &vector_unit_bytes);
+    v->Visit("cache_line_bytes", &cache_line_bytes);
+    v->Visit("max_unroll_vec", &max_unroll_vec);
+    v->Visit("max_innermost_split_factor", &max_innermost_split_factor);
+    v->Visit("max_shared_memory_per_block", &max_shared_memory_per_block);
+    v->Visit("max_registers_per_block", &max_registers_per_block);
+    v->Visit("max_threads_per_block", &max_threads_per_block);
+    v->Visit("max_vthread_extent", &max_vthread_extent);
+    v->Visit("warp_size", &warp_size);
+  }
+
+  /*!
+   * \brief Get the default hardware params.
+   * \param target A `tvm.target`.
+   * \param target_host A `tvm.target` for host device.
+   * \return A HardwareParams object.
+   */
+  static HardwareParams GetDefaultHardwareParams(const Target& target, const Target& target_host);
+
+  static constexpr const char* _type_key = "ansor.HardwareParams";
+  TVM_DECLARE_FINAL_OBJECT_INFO(HardwareParamsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to HardwareParamsNode.
+ * \sa HardwareParamsNode
+ */
+class HardwareParams : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param num_cores The number of cores.
+   * \param vector_unit_bytes The width of vector units in bytes.
+   * \param cache_line_bytes The size of cache line in bytes.
+   * \param max_unroll_vec The max length of an axis to be unrolled or vectorized.
+   * \param max_innermost_split_factor The max split factor for the innermost tile.
+   */
+  HardwareParams(int num_cores, int vector_unit_bytes, int cache_line_bytes, int max_unroll_vec,
+                 int max_innermost_split_factor);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(HardwareParams, ObjectRef, HardwareParamsNode);
+  TVM_DEFINE_OBJECT_REF_COW_METHOD(HardwareParamsNode);
+};
+
+/*! \brief Meta-info for a search task */
+class SearchTaskNode : public Object {
+ public:
+  /*! \brief The ComputeDAG for target compute declaration. */
+  ComputeDAG compute_dag;
+  /*! \brief The workload key for target compute declaration. */
+  String workload_key;
+  /*! \brief The target device of this search task. */
+  Target target;
+  /*! \brief The target host device of this search task. */
+  Target target_host;
+  /*! \brief Hardware parameters used in this search task. */
+  HardwareParams hardware_params;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("compute_dag", &compute_dag);
+    v->Visit("workload_key", &workload_key);
+    v->Visit("target", &target);
+    v->Visit("target_host", &target_host);
+    v->Visit("hardware_params", &hardware_params);
+  }
+
+  static constexpr const char* _type_key = "ansor.SearchTask";
+  TVM_DECLARE_FINAL_OBJECT_INFO(SearchTaskNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchTaskNode.

Review comment:
       ```suggestion
    * \brief Managed reference to a SearchTaskNode.
   ```

##########
File path: src/ansor/search_policy/search_policy.h
##########
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/search_policy/search_policy.h
+ * \brief The base class for search policy, including the abstract defination of search policy and
+ * some other supporting structures.
+ *
+ * The basic schedule search process for Ansor is design to be:
+ * `Program sampling` -> `Performance Tuning`.
+ *
+ * In `Program sampling`, we use some predefined or heuristic rules to generate several initial
+ * schedules. Based on these initial start points, we have `Performance Tuning` to apply cost model
+ * and evolutionary search to seek for schedules with the best performance. Candidate schedules
+ * will be measured in the target hardware.
+ *
+ * \note Adding a new search policy.
+ * In design, there's no need for users to implement their own search policy, our formal search
+ * policy(will be brought later) should be enough to cover auto schedule generation for different
+ * ops/subgraphs, and in the meantime, a custom rule mechanism will be provided to enable
+ * user-defined template search. (which should play a same role as the current AutoTVM template)
+ * This guide is to help understand it better and incase some advanced users have special
+ * requirements.
+ * 1. The only funcion that must be implemented is Search(), the design principe for it is to be
+ * the entry of starting a schedule search and returns the best schedule get.
+ * 2. Imformations about the target ops/subgraphs can be acquired from SearchTask, this structure
+ * also contains HardwareParams which can be used to limit the search space. (For exp. limit the
+ * max vectorize size depending on the vector unit weight of a specific device)
+ * 3. SearchCallback provides more flexibility to do extra affairs during the search process.
+ * 4. ProgramMeasurer provides a simple but useful api to help check the performance of states get
+ * during the search process.
+ */
+
+#ifndef TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_SEARCH_POLICY_H_
+
+#include <tvm/node/node.h>
+
+#include <unordered_set>
+#include <vector>
+
+#include "../search_task.h"
+
+namespace tvm {
+namespace ansor {
+
+class ProgramMeasurer;
+class SearchPolicyNode;
+
+/*!
+ * \brief Callback function to be called by the search process.
+ * This interface allows to do extra initializations before schedule search or extra
+ * check during/after the schedule search.
+ */
+class SearchCallbackNode : public Object {
+ public:
+  /*!
+   * \brief Run the registered callback function.
+   * \param policy A pointer to SearchPolicyNode.
+   */
+  virtual void Callback(SearchPolicyNode* policy) = 0;
+
+  static constexpr const char* _type_key = "ansor.SearchCallback";
+  TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
+};
+
+/*!
+ * \brief Managed reference to SearchCallbackNode.
+ * \sa SearchCallbackNode
+ */
+class SearchCallback : public ObjectRef {
+ public:
+  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(SearchCallback, ObjectRef, SearchCallbackNode);
+};
+
+/*!
+ * \brief The base class for search policy.
+ */
+class SearchPolicyNode : public Object {
+ public:
+  /*! \brief The current search task. */
+  SearchTask cur_task;
+  /*!
+   * \brief Verbose level to control the screen output during schedule search.
+   * 0 for silent, 1 to output information.
+   */
+  int verbose;
+
+  void VisitAttrs(AttrVisitor* v) {
+    v->Visit("cur_task", &cur_task);
+    v->Visit("verbose", &verbose);
+  }
+
+  /*!
+   * \brief Do schedule search for a task. Takes the SearchTask as input and returns the best state
+   * found during the search process.
+   * \param task The target search task.
+   * \param num_measure_trials Total schedules to be tried during this search.
+   * \param early_stopping Early stop if no better schedule is found.
+   * \param num_measures_per_round Max measure batch in one search round.
+   * \param verbose Verbose level. 0 for silent, 1 to output information during schedule search.
+   * \param measurer A ProgramMeasurer which packs Builder & Runner inside.
+   * \param pre_search_callbacks SearchCallback to be called before schedule search.
+   * \return The best state found.
+   */
+  virtual State Search(SearchTask task, int num_measure_trials, int early_stopping,
+                       int num_measures_per_round, int verbose, ProgramMeasurer measurer,
+                       Array<SearchCallback> pre_search_callbacks) = 0;
+
+  /*!
+   * \brief Call each SearchCallback with the current SearchPolicyNode.
+   * \param callbacks SearchCallback to be called.
+   */
+  void RunCallbacks(const Array<SearchCallback>& callbacks);
+
+  static constexpr const char* _type_key = "ansor.SearchPolicy";
+  TVM_DECLARE_BASE_OBJECT_INFO(SearchPolicyNode, Object);
+
+ protected:
+  /*!
+   * \brief The set of already measured states.
+   * We store the string format for redundancy check.
+   */
+  std::unordered_set<String> measured_states_set_;
+  /*! \brief The array of already measured states. */
+  std::vector<State> measured_states_vector_;
+  /*! \brief The throughputs of already measured states. */
+  std::vector<float> measured_states_throughputs_;
+};
+
+/*!
+ * \brief Managed reference to SearchPolicyNode.

Review comment:
       ```suggestion
    * \brief Managed reference to a SearchPolicyNode.
   ```
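
For readers following this review, the node/reference split, the callback hook, and the measured-state bookkeeping in the header above can be illustrated with a small standalone sketch. It does not use TVM: `std::shared_ptr` stands in for `ObjectRef`, states are plain strings, and every `Toy*` name, as well as `SetVerbose`, `RecordMeasured`, and `NumMeasured`, is hypothetical and not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>

// Standalone sketch only. std::shared_ptr plays the role of a managed
// reference; none of these types exist in TVM.
struct ToyPolicyNode;

// Counterpart of SearchCallbackNode: one virtual hook that receives the policy.
struct ToyCallbackNode {
  virtual ~ToyCallbackNode() = default;
  virtual void Callback(ToyPolicyNode* policy) = 0;
};
using ToyCallback = std::shared_ptr<ToyCallbackNode>;

// Counterpart of SearchPolicyNode, reduced to the parts shown in the diff.
struct ToyPolicyNode {
  int verbose = 0;

  // Invoke every callback with a pointer to this policy.
  void RunCallbacks(const std::vector<ToyCallback>& callbacks) {
    for (const auto& cb : callbacks) {
      assert(cb != nullptr);
      cb->Callback(this);
    }
  }

  // Record a measured state; returns false if its string form was seen before.
  bool RecordMeasured(const std::string& state_str, float throughput) {
    if (!measured_states_set_.insert(state_str).second) return false;
    measured_states_vector_.push_back(state_str);
    measured_states_throughputs_.push_back(throughput);
    return true;
  }

  std::size_t NumMeasured() const { return measured_states_vector_.size(); }

 private:
  std::unordered_set<std::string> measured_states_set_;  // redundancy check
  std::vector<std::string> measured_states_vector_;      // insertion order
  std::vector<float> measured_states_throughputs_;       // parallel to the vector
};

// Example pre-search callback: adjust the policy's verbosity before searching.
struct SetVerbose : ToyCallbackNode {
  explicit SetVerbose(int level) : level_(level) {}
  void Callback(ToyPolicyNode* policy) override { policy->verbose = level_; }
  int level_;
};
```

A caller would build a `std::vector<ToyCallback>`, pass it to `RunCallbacks` before searching, and then call `RecordMeasured` once per measured state; a repeated state string is rejected by the set lookup and leaves the two vectors unchanged.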
