You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/04 13:00:07 UTC

[GitHub] [incubator-tvm] MarisaKirisame commented on a change in pull request #5962: [Ansor][AutoTVM v2.0] Part 0: Ansor minimum system for auto schedule generating

MarisaKirisame commented on a change in pull request #5962:
URL: https://github.com/apache/incubator-tvm/pull/5962#discussion_r449759123



##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    Computation declaration graph.
+
+    Parameters
+    ----------
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
+    """
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + compute + ". Expect a string or list of Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+    def get_init_state(self):
+        """ Get init state of this ComputeDAG.
+
+        Returns
+        -------
+        state : State
+            The initial State without any transform steps.
+        """
+        return State(_ffi_api.ComputeDAGGetInitState(self), self)
+
+    def apply_steps_from_state(self, state):
+        """
+        Apply transform steps according to the history of a State.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+            A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+        """
+        state_obj = state if isinstance(state, StateObject) else state.state_object
+        return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
+
+    def print_python_code_from_state(self, state):

Review comment:
       I would just call this `codegen`.

##########
File path: python/tvm/ansor/compute_dag.py
##########
@@ -0,0 +1,153 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Computational graph and its analysis tools """
+
+import hashlib
+
+import tvm._ffi
+from tvm.runtime import Object
+from tvm.te import PlaceholderOp, ComputeOp
+
+from .loop_state import State, StateObject
+from .utils import get_const_tuple
+from .workload_registry import workload_key_to_tensors
+
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.ComputeDAG")
+class ComputeDAG(Object):
+    """
+    The Ansor computational graph and related program analyses.
+
+    We convert a compute declaration described by `tvm.compute` (could be a single operator or a
+    subgraph) to a ComputeDAG. It keeps the input/output tensors of the target compute declaration,
+    a list of all related operations in topo order as well as a set of analyses over each operation
+    stage (e.g. the total float operation count, consumer/producer relations of each operation
+    stage, whether a operation stage should be tiled/compute inlined ...). These analyses can
+    help the search policy to do some specific decisions during schedule search process.
+
+    ComputeDAG is also responsible for the interaction between Ansor LoopState and TVM schedule
+    (e.g. applying the LoopState transform steps to TVM schedule, providing LoopState with extra
+    information get from TVM schedule ...).
+
+    Parameters
+    ----------
+    compute : Union[List[Tensor], str]
+        `Tensor`s or workload key for a compute declaration.
+    """
+    def __init__(self, compute):
+        if isinstance(compute, str):
+            compute = workload_key_to_tensors(compute)
+        elif isinstance(compute, list):
+            for item in compute:
+                if not isinstance(item, tvm.te.Tensor):
+                    raise ValueError("The input of ComputeDAG should be a list of Tensor")
+        else:
+            raise ValueError("Invalid compute: " + compute +
+                             " . `ComputeDAG` expects a string or list of Tensor")
+        self.__init_handle_by_constructor__(_ffi_api.ComputeDAG, compute)
+
+    def get_init_state(self):
+        """ Get the init state of this ComputeDAG.
+
+        Returns
+        -------
+        state : State
+            The initial State without any transform steps.
+        """
+        return State(self.init_state, self)
+
+    def apply_steps_from_state(self, state):
+        """
+        Apply the history transform steps of a State to TVM schedule.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+            A `te.schedule` and the target `te.Tensor`s to be used in `tvm.lower` or `tvm.build`
+        """
+        state_obj = state if isinstance(state, StateObject) else state.state_object
+        return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
+
+    def print_python_code_from_state(self, state):
+        """
+        Print transform steps in the history of a State as TVM's python schedule primitive.
+
+        Parameters
+        ----------
+        state : Union[State, StateObject]
+            The target state to be applied to TVM schedule.
+
+        Returns
+        -------
+        str : Str
+            The Python schedule code.

Review comment:
       This is weird. Why do you return a string and parse it again, instead of just returning TVM objects that represent the schedule?

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a specific schedule for its target ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look like after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...).
+During the schedule search process, the loop structure can provide search policy with necessary
+information on how to perform further operations with the current state.
+The transform history is a sequence of TransformStep which will finally be mapped to schedule
+primitives. The steps can also be used for serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
+We don't use the existing TVM IR but to extend a new structure on it is because:
+1. We want fast incremental change to the loop structures, search policy needs to get the immediate
+loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+Since we share a lot of common objects during search, the transformation is implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """ The internal State object """
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and a history of transformations used to construct it.
+
+    Each State corresponds to a specific schedule for its target ComputeDAG.
+
+    Parameters
+    ----------
+    state_object : StateObject
+        The target StateObject, corresponding to C++ internal State object.
+    dag : ComputeDAG
+        The original target ComputeDAG of this State.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with copy-on-write property
+    """
+    def __init__(self, state_object, dag):
+        self.state_object = state_object
+        self.compute_dag = dag
+
+        self.stages_cache = None  # A list to cache all stages
+        self.stage_id_map = {}    # A dict maps operation to stage id
+        self._update_stage_id_map()
+
+    @property
+    def stages(self):
+        """
+        Returns
+        -------
+        stages : List[Stage]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return self.stages_cache
+
+    @property
+    def stage_ops(self):
+        """
+        Returns
+        -------
+        ops: List[Operation]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return [stage.op for stage in self.stages_cache]
+
+    def reorder(self, stage, order):
+        """ Schedule primitive corresponds to te.reorder.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be reordered, can be a Stage order index, Stage operation or stage
+            output tensor.
+        order : List[Iterator]
+            Iterators in the expected order
+        """
+        stage_id = self._resolve_stage_id(stage)
+
+        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+        self._clear_cache()
+
+    def split(self, stage, iterator, lengths, inner_to_outer=True):
+        """ Schedule primitive corresponds to te.split.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be split, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to split

Review comment:
       ```suggestion
               The iterator to split upon
   ```

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """

Review comment:
       What program? In a compiler, everything is a program. Be more specific about which step this is.

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1

Review comment:
       Why is the verbosity an int instead of a bool?

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.

Review comment:
       15 seconds, milliseconds, or minutes?

##########
File path: python/tvm/ansor/serialization.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Serialization and other I/O support for tuning logs (measurement records)"""
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    A measurement callback that writes measurement records into a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name for this callback to write log to.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    """
+    Reader of the json log file.
+
+    Parameters
+    ----------
+    filename : str = "ansor_tuning.json"
+        File name for this reader to load log from.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_lines=-1, skip_lines=0):
+        """ Read multiple lines from the log file.
+
+        Parameters
+        ----------
+        max_lines : int = -1
+            The maximum number of lines. -1 means to read all lines.

Review comment:
       Use `Optional[int]`, with `None` meaning "read all lines".
   

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(ProgramRunner):
+    """ LocalRunner that uses local CPU/GPU to measures the time cost of programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.
+    number : int = 3
+        Number of measure times.
+    repeat : int = 1
+        Number of repeat times in each measure.
+    min_repeat_ms : int = 0
+        The minimum duration of one repeat in milliseconds.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self,
+                 timeout=10,
+                 number=3,
+                 repeat=1,
+                 min_repeat_ms=0,
+                 cooldown_interval=0.0):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
+
+
+class MeasureErrorNo(object):
+    """ Error type for MeasureResult. """
+    NO_ERROR = 0              # No error
+    INSTANTIATION_ERROR = 1   # Errors happen when apply transform steps from init state
+                              # Errors happen when compiling code on host (e.g. tvm.build)
+    COMPILE_HOST = 2
+    COMPILE_DEVICE = 3        # Errors happen when compiling code on device
+                              # (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4        # Errors happen when run program on device
+    WRONG_ANSWER = 5          # Answer is wrong when compared to a reference output
+    BUILD_TIMEOUT = 6         # Timeout during compilation
+    RUN_TIMEOUT = 7           # Timeout during run
+    UNKNOWN_ERROR = 8         # Unknown error
+
+
+def make_error_msg():
+    """ Get the error message from traceback. """
+    error_msg = str(traceback.format_exc())
+    if len(error_msg) > MAX_ERROR_MSG_LEN:
+        error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
+            "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:]
+    return error_msg
+
+
+def local_build_worker(index):
+    """ Local builder function. """
+    # We use fork to copy arguments from a global variable.
+    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+    if not GLOBAL_BUILD_ARGUMENTS:
+        raise ValueError("GLOBAL_BUILD_ARGUMENTS not found")
+    measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS
+    assert isinstance(build_func, str)
+
+    if build_func == 'default':
+        build_func = tar.tar
+    elif build_func == 'ndk':
+        build_func = ndk.create_shared
+    else:
+        raise ValueError("Invalid build_func" + build_func)
+
+    def timed_func():
+        tic = time.time()
+        inp = measure_inputs[index]
+        task = inp.task
+
+        error_no = MeasureErrorNo.NO_ERROR
+        error_msg = None
+        args = []
+
+        try:
+            sch, args = task.compute_dag.apply_steps_from_state(
+                inp.state)
+        # pylint: disable=broad-except
+        except Exception:
+            error_no = MeasureErrorNo.INSTANTIATION_ERROR
+            error_msg = make_error_msg()
+
+        if error_no == 0:
+            dirname = tempfile.mkdtemp()
+            filename = os.path.join(
+                dirname, "tmp_func." + build_func.output_format)
+
+            try:
+                with transform.PassContext():  # todo(lmzheng): port the unroll pass
+                    func = build_module.build(
+                        sch, args, target=task.target, target_host=task.target_host)
+                func.export_library(filename, build_func)
+            # pylint: disable=broad-except
+            except Exception:
+                error_no = MeasureErrorNo.COMPILE_HOST
+                error_msg = make_error_msg()
+        else:
+            filename = ""
+
+        if verbose == 1:
+            if error_no == MeasureErrorNo.NO_ERROR:
+                print(".", end="")
+            else:
+                print(".E", end="")  # Build error
+        return filename, args, error_no, error_msg, time.time() - tic
+
+    res = call_func_with_timeout(timeout, timed_func)
+    if isinstance(res, TimeoutError):
+        if verbose == 1:
+            print(".T", end="")  # Build timeout
+        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
+
+    return res
+
+
+@tvm._ffi.register_func("ansor.local_builder.build")
+def local_builder_build(inputs, timeout, n_parallel, build_func, verbose):
+    """ Local builder build function. """
+    # We use fork to copy arguments from a global variable.
+    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+    global GLOBAL_BUILD_ARGUMENTS
+    GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)
+
+    pool = NoDaemonPool(n_parallel)
+    tuple_res = pool.map(local_build_worker, range(len(inputs)))
+    pool.terminate()
+    pool.join()
+    del pool
+
+    results = []
+    for res in tuple_res:
+        results.append(BuildResult(*res))
+
+    return results
+
+@tvm._ffi.register_func("ansor.local_runner.run")
+def local_run(inputs, build_results, timeout, number, repeat, min_repeat_ms, cooldown_interval,
+              verbose):
+    """ Local runner run function. """

Review comment:
       ```suggestion
       """ Execute Local runner. """
   ```

##########
File path: src/ansor/auto_schedule.cc
##########
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/auto_schedule.cc
+ * \brief The user interface of the Ansor auto-scheduler.
+ */
+
+#include "auto_schedule.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
+
+TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
+                             int verbose, ProgramBuilder builder, ProgramRunner runner,
+                             Array<MeasureCallback> measure_callbacks,
+                             Array<SearchCallback> pre_search_callbacks) {
+  auto node = make_object<TuningOptionsNode>();
+  node->num_measure_trials = num_measure_trials;
+  node->early_stopping = early_stopping;
+  node->num_measures_per_round = num_measures_per_round;
+  node->verbose = verbose;
+  node->builder = std::move(builder);
+  node->runner = std::move(runner);
+  node->measure_callbacks = std::move(measure_callbacks);
+  node->pre_search_callbacks = std::move(pre_search_callbacks);
+  data_ = std::move(node);
+}
+
+std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(SearchTask task,
+                                                         SearchPolicy search_policy,
+                                                         TuningOptions tuning_options) {
+  // Create a ProgramMeasurer to handle the schedule build and performance measure
+  ProgramMeasurer measurer =
+      ProgramMeasurer(tuning_options->builder, tuning_options->runner,
+                      tuning_options->measure_callbacks, tuning_options->verbose);
+  // Search for the best schedule
+  State state = search_policy->Search(
+      task, tuning_options->num_measure_trials, tuning_options->early_stopping,

Review comment:
       Why don't you just pass tuning_options around?

##########
File path: src/ansor/auto_schedule.cc
##########
@@ -0,0 +1,82 @@
+/*

Review comment:
       should this file be called auto_scheduler?

##########
File path: src/ansor/serialization.cc
##########
@@ -0,0 +1,423 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/serialization.cc
+ * \brief Json serialization format for dumping and loading tuning records.
+ */
+
+#include "serialization.h"
+
+#include <dmlc/json.h>
+#include <tvm/runtime/registry.h>
+
+#include <fstream>
+#include <sstream>
+#include <string>
+#include <utility>
+#include <vector>
+
+#include "loop_state.h"
+#include "transform_step.h"
+#include "utils.h"
+
+// Json serialization handler for MeasureInput, MeasureResult
+// (and recursively for SearchTask, State, Step, ...)
+namespace dmlc {
+namespace json {
+
+inline std::vector<int>& IntArrayToVector(std::vector<int>* out,
+                                          const ::tvm::Array<::tvm::Integer>& data) {
+  out->clear();
+  for (const auto& x : data) {
+    CHECK(x.defined());
+    out->push_back(x);
+  }
+  return *out;
+}
+
+template <>
+struct Handler<::tvm::Array<::tvm::ansor::Stage>> {
+  inline static void Write(dmlc::JSONWriter* writer,
+                           const ::tvm::Array<::tvm::ansor::Stage>& data) {
+    writer->BeginArray(false);
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Stage>* data) {
+    bool s;
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(!s);
+  }
+};
+
+template <>
+struct Handler<::tvm::Array<::tvm::ansor::Step>> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::Array<::tvm::ansor::Step>& data) {
+    std::vector<int> tmp;
+    writer->BeginArray(false);
+    for (size_t i = 0; i < data.size(); ++i) {
+      writer->WriteArraySeperator();
+      writer->BeginArray(false);
+      if (auto ps = data[i].as<::tvm::ansor::ReorderStepNode>()) {
+        writer->WriteArrayItem(std::string("RE"));
+        writer->WriteArrayItem(ps->stage_id);
+        writer->WriteArrayItem(IntArrayToVector(&tmp, ps->after_ids));
+      } else if (auto ps = data[i].as<::tvm::ansor::SplitStepNode>()) {
+        writer->WriteArrayItem(std::string("SP"));
+        writer->WriteArrayItem(ps->stage_id);
+        writer->WriteArrayItem(ps->iter_id);
+        writer->WriteArrayItem(ps->extent.defined() ? ::tvm::ansor::GetIntImm(ps->extent) : 0);
+        writer->WriteArrayItem(IntArrayToVector(&tmp, ps->lengths));
+        writer->WriteArrayItem(static_cast<int>(ps->inner_to_outer));
+      } else if (auto ps = data[i].as<::tvm::ansor::FuseStepNode>()) {
+        writer->WriteArrayItem(std::string("FU"));
+        writer->WriteArrayItem(ps->stage_id);
+        writer->WriteArrayItem(IntArrayToVector(&tmp, ps->fused_ids));
+      } else {
+        LOG(FATAL) << "Invalid step: " << data[i];
+      }
+      writer->EndArray();
+    }
+    writer->EndArray();
+  }
+
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::Array<::tvm::ansor::Step>* data) {
+    std::vector<int> int_list;
+    bool s, inner_to_outer;
+    std::string name, scope_name, pragma_type, ti_func_name;
+    int stage_id, iter_id, extent;
+
+    reader->BeginArray();
+    data->clear();
+    while (reader->NextArrayItem()) {
+      reader->BeginArray();
+      s = reader->NextArrayItem();
+      CHECK(s);
+      reader->Read(&name);
+      if (name == "RE") {
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&stage_id);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&int_list);
+        ::tvm::Array<::tvm::Integer> after_ids;
+        for (const auto& i : int_list) {
+          after_ids.push_back(i);
+        }
+        data->push_back(::tvm::ansor::ReorderStep(stage_id, after_ids));
+      } else if (name == "SP") {
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&stage_id);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&iter_id);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&extent);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&int_list);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&inner_to_outer);
+        ::tvm::Array<::tvm::Integer> lengths;
+        for (const auto& i : int_list) {
+          lengths.push_back(i);
+        }
+        data->push_back(::tvm::ansor::SplitStep(
+            stage_id, iter_id, extent == 0 ? ::tvm::PrimExpr() : extent, lengths, inner_to_outer));
+      } else if (name == "FU") {
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&stage_id);
+        s = reader->NextArrayItem();
+        CHECK(s);
+        reader->Read(&int_list);
+        ::tvm::Array<::tvm::Integer> fused_ids;
+        for (const auto& i : int_list) {
+          fused_ids.push_back(i);
+        }
+        data->push_back(::tvm::ansor::FuseStep(stage_id, fused_ids));
+      } else {
+        LOG(FATAL) << "Invalid step format";
+      }
+      s = reader->NextArrayItem();
+      CHECK(!s);
+    }
+  }
+};
+
+template <>
+struct Handler<::tvm::ansor::StateNode> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::StateNode& data) {
+    writer->BeginArray(false);
+    writer->WriteArrayItem(data.stages);
+    writer->WriteArrayItem(data.transform_steps);
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::StateNode* data) {
+    reader->BeginArray();
+    bool s;
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->stages);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->transform_steps);
+    s = reader->NextArrayItem();
+    CHECK(!s);
+  }
+};
+
+template <>
+struct Handler<::tvm::ansor::SearchTaskNode> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::SearchTaskNode& data) {
+    writer->BeginArray(false);
+    writer->WriteArrayItem(std::string(data.workload_key));
+    writer->WriteArrayItem(data.target->str());
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::SearchTaskNode* data) {
+    std::string target_str;
+    bool s;
+
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&target_str);
+    data->workload_key = std::move(target_str);
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&target_str);
+    data->target = ::tvm::Target::Create(target_str);
+    s = reader->NextArrayItem();
+    CHECK(!s);
+  }
+};
+
+template <>
+struct Handler<::tvm::ansor::MeasureInputNode> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureInputNode& data) {
+    writer->BeginArray(false);
+    writer->WriteArrayItem(*data.task.operator->());
+    writer->WriteArrayItem(*data.state.operator->());
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureInputNode* data) {
+    bool s;
+    auto task_node = ::tvm::make_object<::tvm::ansor::SearchTaskNode>();
+    auto state_node = ::tvm::make_object<::tvm::ansor::StateNode>();
+    state_node->complete = true;
+
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(task_node.get());
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(state_node.get());
+    s = reader->NextArrayItem();
+    CHECK(!s);
+
+    data->task = ::tvm::ansor::SearchTask(task_node);
+    data->state = ::tvm::ansor::State(state_node);
+  }
+};
+
+template <>
+struct Handler<::tvm::ansor::MeasureResultNode> {
+  inline static void Write(dmlc::JSONWriter* writer, const ::tvm::ansor::MeasureResultNode& data) {
+    writer->BeginArray(false);
+    writer->WriteArraySeperator();
+    writer->BeginArray(false);
+    for (const auto& x : data.costs) {
+      auto pf = x.as<::tvm::tir::FloatImmNode>();
+      CHECK(pf != nullptr) << "Cost can only contain float values";
+      writer->WriteArrayItem(pf->value);
+    }
+    writer->EndArray();
+    writer->WriteArrayItem(data.error_no);
+    writer->WriteArrayItem(data.all_cost);
+    writer->WriteArrayItem(static_cast<int>((data.timestamp)));
+    writer->EndArray();
+  }
+  inline static void Read(dmlc::JSONReader* reader, ::tvm::ansor::MeasureResultNode* data) {
+    bool s;
+    std::vector<double> tmp;
+
+    reader->BeginArray();
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&tmp);
+    data->costs.clear();
+    for (const auto& i : tmp) {
+      data->costs.push_back(::tvm::FloatImm(::tvm::DataType::Float(64), i));
+    }
+    s = reader->NextArrayItem();
+    CHECK(s);
+    reader->Read(&data->error_no);
+    s = reader->NextArrayItem();

Review comment:
       Refactor the two calls to NextArrayItem() and Read() into a single function?

##########
File path: src/ansor/loop_state.h
##########
@@ -0,0 +1,375 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/loop_state.h
+ * \brief The definition of the "state" in search.

Review comment:
       @merrymercy had a comment about this file in Python. Please update it here as well.

##########
File path: src/ansor/utils.h
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/utils.h
+ * \brief Common utilities.
+ */
+
+#ifndef TVM_ANSOR_UTILS_H_
+#define TVM_ANSOR_UTILS_H_
+
+#include <dmlc/common.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <deque>
+#include <exception>
+#include <future>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace std {
+
+/*! \brief Hash function for std::pair */
+template <typename T1, typename T2>
+struct hash<std::pair<T1, T2>> {
+  std::size_t operator()(const std::pair<T1, T2>& k) const {
+    return ::dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
+  }
+};
+
+/*! \brief Hash function for std::tuple */
+template <typename T1, typename T2, typename T3>
+struct hash<std::tuple<T1, T2, T3>> {
+  std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
+    return ::dmlc::HashCombine(
+        ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), std::hash<T2>()(std::get<1>(k))),
+        std::hash<T3>()(std::get<2>(k)));
+  }
+};
+
+}  // namespace std
+
+namespace tvm {
+namespace ansor {
+
+/********** Utilities for Array, std::string **********/
+/*! \brief Get the first appearance index of elements in an Array */
+template <typename T>
+inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
+  for (const auto& v : to_locate) {
+    auto it = std::find(array.begin(), array.end(), v);
+    if (it != array.end()) {
+      indices->push_back(it - array.begin());
+    } else {
+      LOG(FATAL) << "Cannot find the item";
+    }
+  }
+}
+
+/*! \brief Get the first appearance index of an element in an Array */
+template <typename T>
+inline int GetIndex(const Array<T>& array, const T& to_locate) {
+  for (size_t i = 0; i < array.size(); ++i) {
+    if (array[i] == to_locate) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "Cannot find the item";
+  return -1;
+}
+
+/*! \brief Replace a sub-string to another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
+  auto pos = base->find(from);
+  while (pos != std::string::npos) {
+    base->replace(pos, from.size(), to);
+    pos = base->find(from, pos + to.size());
+  }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
+  double sum = 0;
+  if (float_array.empty()) {
+    return 0.0;
+  }
+
+  for (const auto& x : float_array) {
+    auto floatimm = x.as<tir::FloatImmNode>();
+    CHECK(floatimm != nullptr);
+    sum += floatimm->value;
+  }
+  return sum / float_array.size();
+}
+
+/********** Other Utilities **********/
+/*! \brief  Get an int value from an Expr */
+inline int64_t GetIntImm(const PrimExpr& expr) {
+  auto pint = expr.as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+/*! \brief  Compute the product of the lengths of axes */
+inline int64_t AxisLengthProd(const Array<tir::IterVar>& axes) {
+  int64_t ret = 1.0;
+  for (const auto& x : axes) {
+    if (const IntImmNode* imm = x->dom->extent.as<IntImmNode>()) {
+      ret *= imm->value;
+    } else {
+      return -1.0;

Review comment:
       Return an optional or throw an error instead of returning -1 as a sentinel.

##########
File path: src/ansor/transform_step.h
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/transform_step.h
+ * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
+ * step. The implementation of each step consists of 2 parts:
+ * - transform_step.cc: How each step interact with TVM system
+ * - loop_state.cc:     How each step reflect on LoopState

Review comment:
       What do you mean?

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topo-sort ops from tensors according to their read-write relations.
+// Results are stored in ops
+void TopoSortOps(const Array<te::Tensor>& tensors, Array<te::Operation>* ops) {
+  std::unordered_map<const te::OperationNode*, int> degree;
+  std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*> > edge_set;
+  std::unordered_map<const te::OperationNode*, int> priority;
+  std::unordered_set<const te::OperationNode*> visited;
+
+  // traverse to build edge_set and count degree
+  std::vector<const te::OperationNode*> stack;
+  stack.reserve(tensors.size());
+  for (const auto& x : tensors) {
+    stack.push_back(x->op.operator->());
+  }
+
+  int ct = 0;
+  while (!stack.empty()) {
+    const te::OperationNode* op = stack.back();
+    stack.pop_back();
+    if (visited.count(op)) {
+      continue;
+    }
+
+    priority[op] = ct;
+    ct++;
+    visited.insert(op);
+
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      degree[op] = 0;
+    } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
+      const Array<te::Tensor>& input_tensors = cop->InputTensors();
+      degree[op] = input_tensors.size();
+      for (const auto& ten : input_tensors) {
+        edge_set[ten->op.operator->()].push_back(op);
+        stack.push_back(ten->op.operator->());
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
+    }
+  }
+
+  // topo sort
+  ops->clear();
+
+  using Item = std::pair<const te::OperationNode*, int>;
+  auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; };
+  std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp);
+  for (const auto& iter : degree) {
+    if (iter.second == 0) {
+      queue.push(Item(iter.first, priority[iter.first]));
+    }
+  }
+
+  ops->reserve(degree.size());
+  while (!queue.empty()) {
+    Item item = queue.top();
+    queue.pop();
+    ops->push_back(GetRef<te::Operation>(item.first));
+    for (const auto& dst : edge_set[item.first]) {
+      degree[dst] -= 1;
+      if (degree[dst] == 0) {
+        queue.push(Item(dst, priority[dst]));
+      }
+    }
+  }
+}
+
+// Estimate number of float operations in an expression
+class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
+ public:
+  double EstimateFlop(const Array<te::Operation>& ops) {

Review comment:
       use option

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Common utilities for ansor. """
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):

Review comment:
       ```suggestion
       if not isinstance(exp, expr.IntImm):
   ```

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1

Review comment:
       This should also be turned into a single config data structure.

##########
File path: python/tvm/ansor/loop_state.py
##########
@@ -0,0 +1,221 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=unused-import
+
+"""
+The definition of the "state" in search.
+
+Each LoopState corresponds to a specific schedule for its target ComputeDAG.
+A LoopState consists of: 1. a current loop structure; 2. a history of transformations used to
+construct the loop structure.
+The loop structure keeps a preview of how the schedule will finally look like after lowering the
+current state (e.g. number of iterators, the extent of each iterator, the compute_at locations ...).
+During the schedule search process, the loop structure can provide search policy with necessary
+information on how to perform further operations with the current state.
+The transform history is a sequence of TransformStep which will finally be mapped to schedule
+primitives. The steps can also be used for serialization of a state.
+
+The LoopState can be seen as a lightweight loop structure IR specifically for schedule search.
+We don't use the existing TVM IR but to extend a new structure on it is because:
+1. We want fast incremental change to the loop structures, search policy needs to get the immediate
+loop structures update rather than after TVM lowering;
+2. We want serializable transform history for replay, backtracking, and mutation;
+3. We may create some macro schedule primitives that represent the combination of several
+TVM schedule primitives.
+
+When the search is complete, we will lower the state to TVM IR with TVM's schedule primitives.
+Since we share a lot of common objects during search, the transformation is implemented in
+copy on write style. All objects are immutable, which is similar to TVM IR.
+"""
+
+import tvm._ffi
+from tvm.te.tensor import Operation, Tensor
+from tvm.runtime import Object
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.Iterator")
+class Iterator(Object):
+    """ A loop iterator structure. """
+
+
+@tvm._ffi.register_object("ansor.Stage")
+class Stage(Object):
+    """A stage in the compute declaration. Similar to tvm.te.schedule.Stage"""
+
+
+@tvm._ffi.register_object("ansor.State")
+class StateObject(Object):
+    """ The internal State object """
+    def __eq__(self, other):
+        return _ffi_api.StateEqual(self, other)
+
+
+class State:
+    """
+    A state in the search process. It consists of the current loop structure
+    and a history of transformations used to construct it.
+
+    Each State corresponds to a specific schedule for its target ComputeDAG.
+
+    Parameters
+    ----------
+    state_object : StateObject
+        The target StateObject, corresponding to C++ internal State object.
+    dag : ComputeDAG
+        The original target ComputeDAG of this State.
+
+    Notes
+    -----
+    This is a wrapper class of StateObject to deal with copy-on-write property
+    """
+    def __init__(self, state_object, dag):
+        self.state_object = state_object
+        self.compute_dag = dag
+
+        self.stages_cache = None  # A list to cache all stages
+        self.stage_id_map = {}    # A dict maps operation to stage id
+        self._update_stage_id_map()
+
+    @property
+    def stages(self):
+        """
+        Returns
+        -------
+        stages : List[Stage]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return self.stages_cache
+
+    @property
+    def stage_ops(self):
+        """
+        Returns
+        -------
+        ops: List[Operation]
+        """
+        if not self.stages_cache:
+            self.stages_cache = self.state_object.stages
+        return [stage.op for stage in self.stages_cache]
+
+    def reorder(self, stage, order):
+        """ Schedule primitive corresponds to te.reorder.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be reordered, can be a Stage order index, Stage operation or stage
+            output tensor.
+        order : List[Iterator]
+            Iterators in the expected order
+        """
+        stage_id = self._resolve_stage_id(stage)
+
+        self.state_object = _ffi_api.StateReorder(self.state_object, stage_id, order)
+        self._clear_cache()
+
+    def split(self, stage, iterator, lengths, inner_to_outer=True):
+        """ Schedule primitive corresponds to te.split.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The target Stage to be split, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to split
+        lengths: List[int]
+            The split factors
+        inner_to_outer: bool = True
+            True to use `factor` to split from inner to outer,
+            False to use `nparts` to split from outer to inner

Review comment:
       ```suggestion
               Whether the factor go from inner to outer, or from outer to inner
   ```
   

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topo-sort ops from tensors according to their read-write relations.
+// Results are stored in ops
+void TopoSortOps(const Array<te::Tensor>& tensors, Array<te::Operation>* ops) {

Review comment:
       Why not just return the result instead of writing it through an output parameter?

##########
File path: python/tvm/ansor/serialization.py
##########
@@ -0,0 +1,156 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Serialization and other I/O support for tuning logs (measurement records)"""
+
+import numpy as np
+
+import tvm._ffi
+from tvm.runtime import Object
+from .measure import MeasureCallback, MeasureErrorNo
+from . import _ffi_api
+
+
+@tvm._ffi.register_object("ansor.LogToFile")
+class LogToFile(MeasureCallback):
+    """
+    A measurement callback that writes measurement records into a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name for this callback to write log to.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogToFile, filename)
+
+
+@tvm._ffi.register_object("ansor.LogReader")
+class LogReader(Object):
+    """
+    Reader of the json log file.
+
+    Parameters
+    ----------
+    filename : str = "ansor_tuning.json"
+        File name for this reader to load log from.
+    """
+    def __init__(self, filename="ansor_tuning.json"):
+        self.__init_handle_by_constructor__(_ffi_api.LogReader, filename)
+
+    def read_lines(self, max_lines=-1, skip_lines=0):
+        """ Read multiple lines from the log file.
+
+        Parameters
+        ----------
+        max_lines : int = -1
+            The maximum number of lines. -1 means to read all lines.
+        skip_lines : int = 0
+            Skip the first n lines.
+
+        Returns
+        -------
+        inputs : List[MeasureInput]
+            The MeasureInputs loaded from the log file.
+        results : List[MeasureResult]
+            The MeasureResults loaded from the log file.
+        """
+        inputs, results = _ffi_api.LogReaderReadLines(self, max_lines, skip_lines)
+        return inputs, results
+
+    def __iter__(self):
+        while True:
+            ret = _ffi_api.LogReaderReadNext(self)
+            if not ret:
+                break
+            yield ret[0], ret[1]  # (input, result)
+
+
+def load_from_file(filename):
+    """
+    Load measurement records from a file.
+
+    Parameters
+    ----------
+    filename : str
+        File name to load log from.
+
+    Returns
+    -------
+    logs : List[MeasureInput, MeasureResult]
+    """
+    return zip(*LogReader(filename).read_lines())
+
+
+def append_measure_records_to_file(filename, inputs, results):
+    """
+    Aappend measure records to file.

Review comment:
       ```suggestion
       Append measure records to file.
   ```

##########
File path: src/ansor/auto_schedule.h
##########
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/auto_schedule.h
+ * \brief The user interface of the Ansor auto-scheduler. This is the entry structure to get
+ * schedule search requirements from upper level (Python API), and returns a high performance
+ * schedule after search process.
+ */
+
+#ifndef TVM_ANSOR_AUTO_SCHEDULE_H_
+#define TVM_ANSOR_AUTO_SCHEDULE_H_
+
+#include <utility>
+
+#include "measure.h"
+#include "search_policy/search_policy.h"
+
+namespace tvm {
+namespace ansor {
+
+/*! \brief Tuning and measurement options. */
+class TuningOptionsNode : public Object {
+ public:
+  /*! \brief Number of total measurement trials. */
+  int num_measure_trials;
+  /*! \brief Stops early the tuning if no improvement after n measurements. */
+  int early_stopping;
+  /*! \brief The number of programs to be measured at each search round. */
+  int num_measures_per_round;
+  /*! \brief Verbosity level. 0 for silent, 1 to output information during schedule searching. */
+  int verbose;
+  /*! \brief ProgramBuilder which builds the program */
+  ProgramBuilder builder;
+  /*! \brief ProgramRunner which runs the program and measure time costs */
+  ProgramRunner runner;
+  /*! \brief MeasureCallback functions to be called after each measure batch */
+  Array<MeasureCallback> measure_callbacks;
+  /*! \brief SearchCallback functions to be called before schedule search */
+  Array<SearchCallback> pre_search_callbacks;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("num_measure_trials", &num_measure_trials);
+    v->Visit("early_stopping", &early_stopping);
+    v->Visit("num_measures_per_round", &num_measures_per_round);
+    v->Visit("verbose", &verbose);
+    v->Visit("builder", &builder);
+    v->Visit("runner", &runner);
+    v->Visit("measure_callbacks", &measure_callbacks);
+    v->Visit("pre_search_callbacks", &pre_search_callbacks);
+  }
+
+  static constexpr const char* _type_key = "ansor.TuningOptions";
+  TVM_DECLARE_FINAL_OBJECT_INFO(TuningOptionsNode, Object);
+};
+
+/*!
+ * \brief Managed reference to TuningOptionsNode.
+ * \sa TuningOptionsNode
+ */
+class TuningOptions : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor
+   * \param num_measure_trials Number of total measurement trials.
+   * \param early_stopping Stops early the tuning if no improvement after n measurements.
+   * \param num_measures_per_round The number of programs to be measured at each search round.
+   * \param verbose Verbosity level. 0 for silent, 1 to output information during schedule
+   * search.
+   * \param builder ProgramBuilder which builds the program.
+   * \param runner ProgramRunner which runs the program and measure time costs.
+   * \param measure_callbacks MeasureCallback functions to be called after each measure batch.
+   * \param pre_search_callbacks SearchCallback functions to be called before schedule search.
+   */
+  TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round, int verbose,
+                ProgramBuilder builder, ProgramRunner runner,
+                Array<MeasureCallback> measure_callbacks,
+                Array<SearchCallback> pre_search_callbacks);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(TuningOptions, ObjectRef, TuningOptionsNode);
+};
+
+/*!
+ * \brief Auto schedule search for a given compute declaration, by SearchTask.
+ * \param task The target search task.
+ * \param search_policy The search policy to be used for schedule search.
+ * \param tuning_options Tuning and measurement options.
+ * \return A `te::Schedule` and the target `te::Tensor` to be used in `tvm.lower` or `tvm.build`.
+ */
+std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(SearchTask task,

Review comment:
       We are on C++11 or newer, so just write `>>` without the space.

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be ran.
+        verbost : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder use local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(ProgramRunner):
+    """ LocalRunner that uses local CPU/GPU to measure the time cost of programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.
+    number : int = 3
+        Number of measure times.
+    repeat : int = 1
+        Number of repeat times in each measure.
+    min_repeat_ms : int = 0
+        The minimum duration of one repeat in milliseconds.
+    cooldown_interval : float = 0.0
+        The cool down interval between two measurements.
+    """
+
+    def __init__(self,
+                 timeout=10,
+                 number=3,
+                 repeat=1,
+                 min_repeat_ms=0,
+                 cooldown_interval=0.0):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval)
+
+
+class MeasureErrorNo(object):
+    """ Error type for MeasureResult. """
+    NO_ERROR = 0              # No error
+    INSTANTIATION_ERROR = 1   # Errors happen when apply transform steps from init state
+                              # Errors happen when compiling code on host (e.g. tvm.build)
+    COMPILE_HOST = 2
+    COMPILE_DEVICE = 3        # Errors happen when compiling code on device
+                              # (e.g. OpenCL JIT on the device)
+    RUNTIME_DEVICE = 4        # Errors happen when run program on device
+    WRONG_ANSWER = 5          # Answer is wrong when compared to a reference output
+    BUILD_TIMEOUT = 6         # Timeout during compilation
+    RUN_TIMEOUT = 7           # Timeout during run
+    UNKNOWN_ERROR = 8         # Unknown error
+
+
+def make_error_msg():
+    """ Get the error message from traceback. """
+    error_msg = str(traceback.format_exc())
+    if len(error_msg) > MAX_ERROR_MSG_LEN:
+        error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
+            "\n...\n" + error_msg[-MAX_ERROR_MSG_LEN//2:]
+    return error_msg
+
+
+def local_build_worker(index):
+    """ Local builder function. """
+    # We use fork to copy arguments from a global variable.
+    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+    if not GLOBAL_BUILD_ARGUMENTS:
+        raise ValueError("GLOBAL_BUILD_ARGUMENTS not found")
+    measure_inputs, build_func, timeout, verbose = GLOBAL_BUILD_ARGUMENTS
+    assert isinstance(build_func, str)
+
+    if build_func == 'default':
+        build_func = tar.tar
+    elif build_func == 'ndk':
+        build_func = ndk.create_shared
+    else:
+        raise ValueError("Invalid build_func" + build_func)
+
+    def timed_func():
+        tic = time.time()
+        inp = measure_inputs[index]
+        task = inp.task
+
+        error_no = MeasureErrorNo.NO_ERROR
+        error_msg = None
+        args = []
+
+        try:
+            sch, args = task.compute_dag.apply_steps_from_state(
+                inp.state)
+        # pylint: disable=broad-except
+        except Exception:
+            error_no = MeasureErrorNo.INSTANTIATION_ERROR
+            error_msg = make_error_msg()
+
+        if error_no == 0:
+            dirname = tempfile.mkdtemp()
+            filename = os.path.join(
+                dirname, "tmp_func." + build_func.output_format)
+
+            try:
+                with transform.PassContext():  # todo(lmzheng): port the unroll pass
+                    func = build_module.build(
+                        sch, args, target=task.target, target_host=task.target_host)
+                func.export_library(filename, build_func)
+            # pylint: disable=broad-except
+            except Exception:
+                error_no = MeasureErrorNo.COMPILE_HOST
+                error_msg = make_error_msg()
+        else:
+            filename = ""
+
+        if verbose == 1:
+            if error_no == MeasureErrorNo.NO_ERROR:
+                print(".", end="")
+            else:
+                print(".E", end="")  # Build error
+        return filename, args, error_no, error_msg, time.time() - tic
+
+    res = call_func_with_timeout(timeout, timed_func)
+    if isinstance(res, TimeoutError):
+        if verbose == 1:
+            print(".T", end="")  # Build timeout
+        res = None, [], MeasureErrorNo.BUILD_TIMEOUT, None, timeout
+
+    return res
+
+
+@tvm._ffi.register_func("ansor.local_builder.build")
+def local_builder_build(inputs, timeout, n_parallel, build_func, verbose):
+    """ Local builder build function. """
+    # We use fork to copy arguments from a global variable.
+    # This can avoid expensive serialization of TVM IR when using multiprocessing.Pool
+    global GLOBAL_BUILD_ARGUMENTS
+    GLOBAL_BUILD_ARGUMENTS = (inputs, build_func, timeout, verbose)

Review comment:
       This seems like bad program structure. Can you make this local and pass it around? You still don't have to serialize anything, since everything is passed by reference.
   What do you mean by multiprocessing tool?

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Common utilities for ansor. """
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):
+        opt = Sequential([Simplify()])
+        exp = opt(exp)
+    if not isinstance(exp, (expr.IntImm)):
+        raise ValueError("Expect value to be constant int")
+    return exp.value
+
+
+def get_const_tuple(in_tuple):
+    """Verifies input tuple is IntImm, returns tuple of int.
+
+    Parameters
+    ----------
+    in_tuple : tuple of Expr

Review comment:
       Tuple[Expr]

##########
File path: src/ansor/utils.h
##########
@@ -0,0 +1,292 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/utils.h
+ * \brief Common utilities.
+ */
+
+#ifndef TVM_ANSOR_UTILS_H_
+#define TVM_ANSOR_UTILS_H_
+
+#include <dmlc/common.h>
+#include <tvm/tir/expr.h>
+
+#include <algorithm>
+#include <deque>
+#include <exception>
+#include <future>
+#include <string>
+#include <thread>
+#include <tuple>
+#include <utility>
+#include <vector>
+
+namespace std {
+
+/*! \brief Hash function for std::pair */
+template <typename T1, typename T2>
+struct hash<std::pair<T1, T2>> {
+  std::size_t operator()(const std::pair<T1, T2>& k) const {
+    return ::dmlc::HashCombine(std::hash<T1>()(k.first), std::hash<T2>()(k.second));
+  }
+};
+
+/*! \brief Hash function for std::tuple */
+template <typename T1, typename T2, typename T3>
+struct hash<std::tuple<T1, T2, T3>> {
+  std::size_t operator()(const std::tuple<T1, T2, T3>& k) const {
+    return ::dmlc::HashCombine(
+        ::dmlc::HashCombine(std::hash<T1>()(std::get<0>(k)), std::hash<T2>()(std::get<1>(k))),
+        std::hash<T3>()(std::get<2>(k)));
+  }
+};
+
+}  // namespace std
+
+namespace tvm {
+namespace ansor {
+
+/********** Utilities for Array, std::string **********/
+/*! \brief Get the first appearance index of elements in an Array */
+template <typename T>
+inline void GetIndices(const Array<T>& array, const Array<T>& to_locate, Array<Integer>* indices) {
+  for (const auto& v : to_locate) {
+    auto it = std::find(array.begin(), array.end(), v);
+    if (it != array.end()) {
+      indices->push_back(it - array.begin());
+    } else {
+      LOG(FATAL) << "Cannot find the item";
+    }
+  }
+}
+
+/*! \brief Get the first appearance index of an element in an Array */
+template <typename T>
+inline int GetIndex(const Array<T>& array, const T& to_locate) {
+  for (size_t i = 0; i < array.size(); ++i) {
+    if (array[i] == to_locate) {
+      return i;
+    }
+  }
+  LOG(FATAL) << "Cannot find the item";
+  return -1;
+}
+
+/*! \brief Replace a sub-string to another sub-string in a string */
+inline void StrReplace(std::string* base, const std::string& from, const std::string& to) {
+  auto pos = base->find(from);
+  while (pos != std::string::npos) {
+    base->replace(pos, from.size(), to);
+    pos = base->find(from, pos + to.size());
+  }
+}
+
+/********** Utilities for TVM Containers / ByteArray **********/
+/*! \brief Compute mean of a FloatImm array */
+inline double FloatArrayMean(const Array<PrimExpr>& float_array) {
+  double sum = 0;
+  if (float_array.empty()) {
+    return 0.0;
+  }
+
+  for (const auto& x : float_array) {
+    auto floatimm = x.as<tir::FloatImmNode>();
+    CHECK(floatimm != nullptr);
+    sum += floatimm->value;
+  }
+  return sum / float_array.size();
+}
+
+/********** Other Utilities **********/
+/*! \brief  Get an int value from an Expr */
+inline int64_t GetIntImm(const PrimExpr& expr) {
+  auto pint = expr.as<IntImmNode>();
+  CHECK(pint != nullptr);
+  return pint->value;
+}
+
+/*! \brief  Compute the product of the lengths of axes */

Review comment:
       ```suggestion
   /*! \brief Compute the product of the lengths of axes */
   ```

##########
File path: python/tvm/ansor/workload_registry.py
##########
@@ -0,0 +1,170 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Workload registration and serialization.
+
+We use a json string to represent a workload (a compute dag).

Review comment:
       ```suggestion
   We use a json string to represent a workload (a computation graph).
   ```

##########
File path: python/tvm/ansor/measure.py
##########
@@ -0,0 +1,386 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Distributed measurement infrastructure to measure the runtime costs of tensor programs
+
+These functions are responsible for building the tvm module, uploading it to
+remote devices, recording the running time costs, and checking the correctness of the output.
+
+We implement these in python to utilize python's multiprocessing and error handling
+"""
+
+import os
+import time
+import shutil
+import traceback
+import tempfile
+import multiprocessing
+
+import tvm._ffi
+from tvm.runtime import Object, module, ndarray
+from tvm.driver import build_module
+from tvm.ir import transform
+from tvm.contrib import tar, ndk
+
+from . import _ffi_api
+from .utils import get_const_tuple, NoDaemonPool, call_func_with_timeout
+
+# The maximum length of error message
+MAX_ERROR_MSG_LEN = 512
+
+# Global variables used in build function
+GLOBAL_BUILD_ARGUMENTS = None
+
+@tvm._ffi.register_object("ansor.MeasureCallback")
+class MeasureCallback(Object):
+    """ The base class of measurement callback functions. """
+
+
+@tvm._ffi.register_object("ansor.MeasureInput")
+class MeasureInput(Object):
+    """ Store the input of a measurement.
+
+    Parameters
+    ----------
+    task : SearchTask
+        The target SearchTask.
+    state : State
+        The current State to be measured.
+    """
+    def __init__(self, task, state):
+        self.__init_handle_by_constructor__(_ffi_api.MeasureInput, task, state.state_object)
+
+
+@tvm._ffi.register_object("ansor.BuildResult")
+class BuildResult(Object):
+    """ Store the result of a build.
+
+    Parameters
+    ----------
+    filename : Optional[str]
+        The filename of built binary file.
+    args : List[Tensor]
+        The arguments.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    time_cost : float
+        The time cost of build.
+    """
+    def __init__(self, filename, args, error_no, error_msg, time_cost):
+        filename = filename if filename else ""
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.BuildResult, filename, args, error_no, error_msg, time_cost)
+
+
+@tvm._ffi.register_object("ansor.MeasureResult")
+class MeasureResult(Object):
+    """ Store the results of a measurement.
+
+    Parameters
+    ----------
+    costs : List[float]
+        The time costs of execution.
+    error_no : int
+        The error code.
+    error_msg : Optional[str]
+        The error message if there is any error.
+    all_cost : float
+        The time cost of build and run.
+    timestamp : float
+        The time stamps of this measurement.
+    """
+    def __init__(self, costs, error_no, error_msg, all_cost, timestamp):
+        error_msg = error_msg if error_msg else ""
+
+        self.__init_handle_by_constructor__(
+            _ffi_api.MeasureResult, costs, error_no,
+            error_msg, all_cost, timestamp)
+
+
+@tvm._ffi.register_object("ansor.ProgramBuilder")
+class ProgramBuilder(Object):
+    """ The base class of ProgramBuilders. """
+
+    def build(self, measure_inputs, verbose=1):
+        """ Build programs and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program building.
+
+        Returns
+        -------
+        res : List[BuildResult]
+        """
+        return _ffi_api.ProgramBuilderBuild(self, measure_inputs, verbose)
+
+
+@tvm._ffi.register_object("ansor.ProgramRunner")
+class ProgramRunner(Object):
+    """ The base class of ProgramRunners. """
+
+    def run(self, measure_inputs, build_results, verbose=1):
+        """ Run measurement and return results.
+
+        Parameters
+        ----------
+        measure_inputs : List[MeasureInput]
+            A List of MeasureInput.
+        build_results : List[BuildResult]
+            A List of BuildResult to be run.
+        verbose : int = 1
+            Verbosity level. 0 for silent, 1 to output information during program running.
+
+        Returns
+        -------
+        res : List[MeasureResult]
+        """
+        return _ffi_api.ProgramRunnerRun(self, measure_inputs, build_results, verbose)
+
+
+@tvm._ffi.register_object("ansor.LocalBuilder")
+class LocalBuilder(ProgramBuilder):
+    """ LocalBuilder uses local CPU cores to build programs in parallel.
+
+    Parameters
+    ----------
+    timeout : int = 15
+        The timeout limit for each build.
+    n_parallel : int = multiprocessing.cpu_count()
+        Number of threads used to build in parallel.
+    build_func : str = 'default'
+        The name of registered build function.
+    """
+
+    def __init__(self,
+                 timeout=15,
+                 n_parallel=multiprocessing.cpu_count(),
+                 build_func='default'):
+        self.__init_handle_by_constructor__(
+            _ffi_api.LocalBuilder, timeout, n_parallel, build_func)
+
+
+@tvm._ffi.register_object("ansor.LocalRunner")
+class LocalRunner(ProgramRunner):
+    """ LocalRunner that uses local CPU/GPU to measure the time cost of programs.
+
+    Parameters
+    ----------
+    timeout : int = 10
+        The timeout limit for each run.

Review comment:
       Please add the unit back to this description.

##########
File path: src/ansor/auto_schedule.cc
##########
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/auto_schedule.cc
+ * \brief The user interface of the Ansor auto-scheduler.
+ */
+
+#include "auto_schedule.h"
+
+#include <tvm/runtime/registry.h>
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_NODE_TYPE(TuningOptionsNode);
+
+TuningOptions::TuningOptions(int num_measure_trials, int early_stopping, int num_measures_per_round,
+                             int verbose, ProgramBuilder builder, ProgramRunner runner,
+                             Array<MeasureCallback> measure_callbacks,
+                             Array<SearchCallback> pre_search_callbacks) {
+  auto node = make_object<TuningOptionsNode>();
+  node->num_measure_trials = num_measure_trials;
+  node->early_stopping = early_stopping;
+  node->num_measures_per_round = num_measures_per_round;
+  node->verbose = verbose;
+  node->builder = std::move(builder);
+  node->runner = std::move(runner);
+  node->measure_callbacks = std::move(measure_callbacks);
+  node->pre_search_callbacks = std::move(pre_search_callbacks);
+  data_ = std::move(node);
+}
+
+std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(SearchTask task,
+                                                         SearchPolicy search_policy,
+                                                         TuningOptions tuning_options) {
+  // Create a ProgramMeasurer to handle the schedule build and performance measure
+  ProgramMeasurer measurer =
+      ProgramMeasurer(tuning_options->builder, tuning_options->runner,
+                      tuning_options->measure_callbacks, tuning_options->verbose);
+  // Search for the best schedule
+  State state = search_policy->Search(
+      task, tuning_options->num_measure_trials, tuning_options->early_stopping,

Review comment:
       it will be more extensible.

##########
File path: src/ansor/loop_state.cc
##########
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/loop_state.cc
+ * \brief A lightweight IR (intermediate representation) for loop structures.
+ * see ansor/loop_state.h for more explanation.
+ */
+
+#include "loop_state.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+
+#include <utility>
+
+#include "transform_step.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(StepNode);
+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(StateNode);
+TVM_REGISTER_NODE_TYPE(IteratorNode);
+
+/********** Iterator **********/
+Iterator::Iterator(String name, Range range, IteratorType iter_type,
+                   IteratorAnnotation annotation) {
+  auto node = make_object<IteratorNode>();
+  node->name = std::move(name);
+  node->range = std::move(range);
+  node->iter_type = iter_type;
+  node->annotation = annotation;
+  data_ = std::move(node);
+}
+
+/********** Stage **********/
+Stage::Stage(te::Operation op) {
+  auto node = make_object<StageNode>();
+  if (op->IsInstance<te::ComputeOpNode>()) {
+    node->op_type = kCompute;
+    auto* pop = op.as<te::ComputeOpNode>();
+    for (const auto& axis : pop->axis) {
+      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone));
+    }
+    for (const auto& axis : pop->reduce_axis) {
+      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone));
+    }
+  } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+    node->op_type = kPlaceholder;
+  } else {
+    LOG(FATAL) << "Unsupported operator type" << op->_type_key;
+  }
+
+  node->compute_at = kRoot;
+  node->op = std::move(op);
+  node->attrs.auto_unroll_max_step = 0;
+  node->attrs.storage_offset = 0;
+  data_ = std::move(node);
+}
+
+Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters,
+             ComputeAtType compute_at, StageAttributes attrs) {
+  auto node = make_object<StageNode>();
+  node->op = std::move(op);
+  node->op_type = op_type;
+  node->iters = iters;
+  node->compute_at = compute_at;
+  node->attrs = attrs;
+  data_ = std::move(node);
+}
+
+Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at,
+             StageAttributes attrs) {
+  auto node = make_object<StageNode>();
+  node->op = std::move(op);
+  node->op_type = op_type;
+  node->iters = std::move(iters);
+  node->compute_at = compute_at;
+  node->attrs = attrs;
+  data_ = std::move(node);
+}
+
+/********** State **********/
+State::State(const Array<te::Operation>& ops) {
+  auto node = make_object<StateNode>();
+  for (const auto& op : ops) {
+    node->stages.push_back(Stage(op));
+  }
+  node->complete = true;
+  data_ = std::move(node);
+}
+
+/********** Schedule primitives apis for state **********/
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+  const Stage& stage = operator->()->stages[stage_id];
+  CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+                                              << "should be specified";
+  Array<Integer> after_ids;
+  GetIndices(stage->iters, order, &after_ids);
+  ReorderStep step = ReorderStep(stage_id, after_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  DoReorderStep(step);
+}
+
+Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths,
+                             bool inner_to_outer) {
+  const Stage& stage = operator->()->stages[stage_id];
+  SplitStep step =
+      SplitStep(stage_id, GetIndex(stage->iters, it),
+                it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoSplitStep(step);
+}
+
+Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
+  const Stage& stage = operator->()->stages[stage_id];
+  Array<Integer> indices;
+  GetIndices(stage->iters, iters, &indices);
+  FuseStep step = FuseStep(stage_id, indices);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoFuseStep(step);
+}
+
+/********** Step implementations for state **********/
+void State::DoReorderStep(const ReorderStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+  Array<Iterator> iters;
+  for (auto x : step->after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters),
+                                           stage->compute_at, stage->attrs));
+}
+
+// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
+Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths,
+                                         bool inner_to_outer) {
+  const Stage& stage = operator->()->stages[stage_id];
+  const Iterator& it = stage->iters[iter_id];
+
+  PrimExpr tosplit_min, tosplit_extent;
+  if (it->range.defined()) {
+    tosplit_min = it->range->min;
+    tosplit_extent = it->range->extent;
+  } else {
+    tosplit_min = tosplit_extent = PrimExpr();
+  }
+
+  Array<Iterator> outs;
+  for (size_t i = 0; i < lengths.size(); ++i) {
+    PrimExpr l;
+    String name;
+    if (inner_to_outer) {
+      l = lengths[lengths.size() - i - 1];
+      name = it->name + "." + std::to_string(lengths.size() - i);
+    } else {
+      l = lengths[i];
+      name = it->name + "." + std::to_string(i);
+    }
+    Iterator res;
+    if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) {
+      res = Iterator(name, Range::FromMinExtent(tosplit_min, l), it->iter_type, kNone);
+      tosplit_min = 0;
+      tosplit_extent = indexdiv(tosplit_extent + l - 1, l);
+    } else {
+      res = Iterator(name, Range(), it->iter_type, kNone);
+      tosplit_min = tosplit_extent = PrimExpr();
+    }
+    outs.push_back(std::move(res));
+  }
+
+  Range range;
+  if (tosplit_min.defined() && tosplit_extent.defined()) {
+    range = Range::FromMinExtent(tosplit_min, tosplit_extent);
+  }
+  if (inner_to_outer) {
+    outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone));
+    // Reverse the Iterator array
+    Array<Iterator> temp(outs.rbegin(), outs.rend());
+    outs = std::move(temp);
+  } else {
+    outs.push_back(
+        Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_type, kNone));
+  }
+
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     stage->compute_at, stage->attrs));
+
+  return outs;
+}
+
+Array<Iterator> State::DoSplitStep(const SplitStep& step) {
+  return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer);
+}
+
+Iterator State::DoFuseStep(const FuseStep& step) {
+  int stage_id = step->stage_id;
+  const Stage& stage = operator->()->stages[stage_id];
+
+  String new_name;
+  PrimExpr new_extent = 1;
+  IteratorType new_iter_type = kSpecial;
+
+  for (size_t i = 0; i < step->fused_ids.size(); ++i) {
+    if (i > 0) {
+      CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
+    }
+
+    const Iterator& it = stage->iters[step->fused_ids[i]];
+    new_name = new_name + it->name + "@";
+
+    if (it->range.defined() && new_extent.defined()) {
+      new_extent = new_extent * it->range->extent;
+    } else {
+      new_extent = PrimExpr();
+    }
+
+    if (i == 0) {
+      new_iter_type = it->iter_type;
+    } else {
+      if (new_iter_type != it->iter_type) {
+        new_iter_type = kMixed;
+      }
+    }
+  }
+
+  Range range;
+  if (new_extent.defined()) {
+    range = Range::FromMinExtent(0, new_extent);
+  }
+  Iterator new_it = Iterator(new_name, range, new_iter_type, kNone);
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(),
+                   stage->iters.begin() + step->fused_ids.front());
+  new_iters.push_back(new_it);
+  new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
+                   stage->iters.end());
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     stage->compute_at, stage->attrs));
+
+  return new_it;
+}
+
+void State::DoSteps(const ComputeDAG& dag) {
+  CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
+
+  // Use complete rate for the study in the paper
+  const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE");
+  double complete_rate = -1.0;
+  if (complete_rate_str) {
+    complete_rate = std::stod(complete_rate_str);
+  }
+  size_t ct = 0;
+  for (const auto& step : operator->()->transform_steps) {
+    if (complete_rate >= 0 && ct++ > operator->()->transform_steps.size() * complete_rate) {
+      break;
+    }
+    if (auto ps = step.as<ReorderStepNode>()) {
+      DoReorderStep(GetRef<ReorderStep>(ps));
+    } else if (auto ps = step.as<SplitStepNode>()) {
+      DoSplitStep(GetRef<SplitStep>(ps));
+    } else if (auto ps = step.as<FuseStepNode>()) {
+      DoFuseStep(GetRef<FuseStep>(ps));
+    } else {
+      LOG(FATAL) << "Invalid step: " << step;
+    }
+  }
+}
+
+// Print stage to ostream
+void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent,

Review comment:
       Same for `const StateNode* state`: why not pass it by reference?

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Common utilities for ansor. """
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):
+        opt = Sequential([Simplify()])
+        exp = opt(exp)
+    if not isinstance(exp, (expr.IntImm)):
+        raise ValueError("Expect value to be constant int")
+    return exp.value
+
+
+def get_const_tuple(in_tuple):
+    """Verifies input tuple is IntImm, returns tuple of int.
+
+    Parameters
+    ----------
+    in_tuple : tuple of Expr
+        The input.
+
+    Returns
+    -------
+    out_tuple : tuple of int
+        The output.
+    """
+    return tuple(get_const_int(x) for x in in_tuple)
+
+
+
+def list_to_tuple(x):
+    """ Convert a list to a tuple recursively. """
+    assert isinstance(x, list)
+    return tuple(list_to_tuple(y) if isinstance(y, list) else y for y in x)
+
+
+def serialize_args(args):
+    """
+    Serialize arguments of a function to a hashable and jsonable tuple.
+    Currently this is mainly used for tvm.tensor.Tensor
+    """
+    ret = []
+    for t in args:
+        if isinstance(t, Tensor):
+            t = ('TENSOR', get_const_tuple(t.shape), t.dtype)
+        elif isinstance(t, list):
+            t = list_to_tuple(t)
+
+        assert isinstance(t, Hashable), str(t) + " is not hashable"
+        ret.append(t)
+
+    return tuple(ret)
+
+
+def deserialize_args(args):
+    """The inverse function of :code:`serialize_args`"""
+    ret = []
+    for t in args:
+        if isinstance(t, (tuple, list)) and t[0] == 'TENSOR':
+            ret.append(placeholder(shape=t[1], dtype=t[2]))
+        else:
+            ret.append(t)
+    return ret
+
+
+class NoDaemonProcess(multiprocessing.Process):
+    @property
+    def daemon(self):
+        return False
+
+    @daemon.setter
+    def daemon(self, value):
+        pass
+
+
+class NoDaemonContext(type(multiprocessing.get_context())):
+    Process = NoDaemonProcess
+
+
+class NoDaemonPool(multiprocessing.pool.Pool):
+    """A no daemon pool version of multiprocessing.Pool.
+    This allows us to start new processes inside the worker function"""
+
+    def __init__(self, *args, **kwargs):
+        kwargs['context'] = NoDaemonContext()
+        super().__init__(*args, **kwargs)
+
+    def __reduce__(self):
+        pass
+
+
+def kill_child_processes(parent_pid, sig=signal.SIGTERM):
+    """kill all child processes recursively"""
+    try:
+        parent = psutil.Process(parent_pid)
+    except psutil.NoSuchProcess:
+        return
+    children = parent.children(recursive=True)
+    for process in children:
+        try:
+            process.send_signal(sig)
+        except psutil.NoSuchProcess:
+            return
+
+
+def call_func_with_timeout(timeout, func, args=(), kwargs=None):
+    """Call a function with timeout"""
+    def func_wrapper(que):
+        if kwargs:
+            que.put(func(*args, **kwargs))
+        else:
+            que.put(func(*args))
+
+    que = multiprocessing.Queue(2)
+    process = multiprocessing.Process(target=func_wrapper, args=(que,))
+    process.start()
+    process.join(timeout)
+
+    try:
+        res = que.get(block=False)
+    except queue.Empty:
+        res = TimeoutError()
+
+    # clean queue and process
+    kill_child_processes(process.pid)
+    process.terminate()

Review comment:
       Is this code exception-safe? If not, rewrite it in RAII style or put the cleanup in a finalizer.

##########
File path: python/tvm/ansor/utils.py
##########
@@ -0,0 +1,195 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+""" Common utilities for ansor. """
+
+from typing import Hashable
+import multiprocessing
+import multiprocessing.pool
+import queue
+import signal
+
+try:
+    import psutil
+except ImportError:
+    raise ImportError("psutil not found, try `pip install psutil` to fix this")
+
+from tvm.tir import expr
+from tvm.tir.transform import Simplify
+from tvm.ir.transform import Sequential
+from ..te import Tensor, placeholder
+
+
+def get_func_name(func):
+    """Get name of a function.
+
+    Parameters
+    ----------
+    func: Function
+        The target function.
+
+    Returns
+    -------
+    name: str
+        The function name.
+    """
+    return func.func_name if hasattr(func, 'func_name') else func.__name__
+
+
+def get_const_int(exp):
+    """Verifies expr is integer and get the constant value.
+
+    Parameters
+    ----------
+    exp : tvm.Expr or int
+        The input expression.
+
+    Returns
+    -------
+    out_value : int
+        The output.
+    """
+    if isinstance(exp, int):
+        return exp
+    if not isinstance(exp, (expr.IntImm)):
+        opt = Sequential([Simplify()])
+        exp = opt(exp)
+    if not isinstance(exp, (expr.IntImm)):

Review comment:
       ```suggestion
       if not isinstance(exp, expr.IntImm):
   ```

##########
File path: src/ansor/transform_step.h
##########
@@ -0,0 +1,225 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/transform_step.h
+ * \brief Transformation steps. For each schedule primitive, there is a corresponding transform
+ * step. The implementation of each step consists of 2 parts:
+ * - transform_step.cc: How each step interact with TVM system
+ * - loop_state.cc:     How each step reflect on LoopState
+ *
+ * \note Adding a new transform step.

Review comment:
       ```suggestion
    * \note To add a new transform step:
   ```

##########
File path: src/ansor/measure.h
##########
@@ -0,0 +1,430 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/measure.h
+ * \brief Distributed measurement infrastructure to measure the runtime costs of tensor programs.
+ * MeasureInput -> BuildResult -> MeasureResult
+ */
+
+#ifndef TVM_ANSOR_MEASURE_H_
+#define TVM_ANSOR_MEASURE_H_
+
+#include <unordered_map>
+#include <utility>
+
+#include "loop_state.h"
+#include "search_task.h"
+
+namespace tvm {
+namespace ansor {
+
+class SearchPolicy;
+class MeasureInput;
+class MeasureResult;
+
+/*! \brief The error code of one measurement */
+enum MeasureErrorNO {
+  /*! \brief No error. */
+  kNoError = 0,
+  /*! \brief Errors happen when apply transform steps from init state. */
+  kInstantiationError = 1,
+  /*! \brief Errors happen when compiling code on host. (when build module) */
+  kCompileHostError = 2,
+  /*! \brief Errors happen when compiling code on device. (when load module) */
+  kCompileDeviceError = 3,
+  /*! \brief Errors happen when run program on device. */
+  kRuntimeDeviceError = 4,
+  /*! \brief Answer is wrong when compared to a reference output. */
+  kWrongAnswerError = 5,
+  /*! \brief Timeout during compilation. */
+  kBuildTimeoutError = 6,
+  /*! \brief Timeout during run. */
+  kRunTimeoutError = 7,
+  /*! \brief Unknown error. */
+  kUnknonwError = 8,
+};
+
+// Inputs and results of one measurement
+
+/*! \brief Store the input of a measurement */
+class MeasureInputNode : public Object {
+ public:
+  /*! \brief The search task. */
+  SearchTask task;
+  /*! \brief The program state to be measured. */
+  State state;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("task", &task);
+    v->Visit("state", &state);
+  }
+
+  /*! \brief Do deep copy. */
+  MeasureInput copy() const;
+
+  static constexpr const char* _type_key = "ansor.MeasureInput";
+  TVM_DECLARE_FINAL_OBJECT_INFO(MeasureInputNode, Object);
+};
+
+/*!
+ * \brief Managed reference to MeasureInputNode.
+ * \sa MeasureInputNode
+ */
+class MeasureInput : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param task The target SearchTask.
+   * \param state The target State.
+   */
+  MeasureInput(SearchTask task, State state);
+
+  TVM_DEFINE_OBJECT_REF_METHODS(MeasureInput, ObjectRef, MeasureInputNode);
+};
+
+/*! \brief Store the result of a build. */
+class BuildResultNode : public Object {
+ public:
+  /*! \brief The filename of built binary file. */
+  String filename;
+  /*! \brief The arguments. */
+  Array<te::Tensor> args;
+  /*! \brief The error code. (0 means no error, see MeasureErrorNO) */
+  int error_no;
+  /*! \brief The error message if there is any error. */
+  String error_msg;
+  /*! \brief The time cost of build. */
+  double time_cost;
+
+  void VisitAttrs(tvm::AttrVisitor* v) {
+    v->Visit("filename", &filename);
+    v->Visit("args", &args);
+    v->Visit("error_no", &error_no);
+    v->Visit("error_msg", &error_msg);
+    v->Visit("time_cost", &time_cost);
+  }
+
+  static constexpr const char* _type_key = "ansor.BuildResult";
+  TVM_DECLARE_FINAL_OBJECT_INFO(BuildResultNode, Object);
+};
+
+/*!
+ * \brief Managed reference to BuildResultNode.
+ * \sa BuildResultNode
+ */
+class BuildResult : public ObjectRef {
+ public:
+  /*!
+   * \brief The constructor.
+   * \param filename The filename of built binary file.
+   * \param args The arguments.
+   * \param error_no The error code.
+   * \param error_msg The error message if there is any error.
+   * \param time_cost The time cost of build.
+   */
+  BuildResult(String filename, Array<te::Tensor> args, int error_no, String error_msg,
+              double time_cost);
+  TVM_DEFINE_OBJECT_REF_METHODS(BuildResult, ObjectRef, BuildResultNode);
+};
+
+/*! \brief Store the results of a measurement. */
+class MeasureResultNode : public Object {

Review comment:
       MeasurementNode

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topo-sort ops from tensors according to their read-write relations.
+// Results are stored in ops
+void TopoSortOps(const Array<te::Tensor>& tensors, Array<te::Operation>* ops) {
+  std::unordered_map<const te::OperationNode*, int> degree;
+  std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*> > edge_set;
+  std::unordered_map<const te::OperationNode*, int> priority;
+  std::unordered_set<const te::OperationNode*> visited;
+
+  // traverse to build edge_set and count degree
+  std::vector<const te::OperationNode*> stack;
+  stack.reserve(tensors.size());
+  for (const auto& x : tensors) {
+    stack.push_back(x->op.operator->());
+  }
+
+  int ct = 0;
+  while (!stack.empty()) {
+    const te::OperationNode* op = stack.back();
+    stack.pop_back();
+    if (visited.count(op)) {
+      continue;
+    }
+
+    priority[op] = ct;
+    ct++;
+    visited.insert(op);
+
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      degree[op] = 0;
+    } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
+      const Array<te::Tensor>& input_tensors = cop->InputTensors();
+      degree[op] = input_tensors.size();
+      for (const auto& ten : input_tensors) {
+        edge_set[ten->op.operator->()].push_back(op);
+        stack.push_back(ten->op.operator->());
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
+    }
+  }
+
+  // topo sort
+  ops->clear();
+
+  using Item = std::pair<const te::OperationNode*, int>;
+  auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; };
+  std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp);
+  for (const auto& iter : degree) {
+    if (iter.second == 0) {
+      queue.push(Item(iter.first, priority[iter.first]));
+    }
+  }
+
+  ops->reserve(degree.size());
+  while (!queue.empty()) {
+    Item item = queue.top();
+    queue.pop();
+    ops->push_back(GetRef<te::Operation>(item.first));
+    for (const auto& dst : edge_set[item.first]) {
+      degree[dst] -= 1;
+      if (degree[dst] == 0) {
+        queue.push(Item(dst, priority[dst]));
+      }
+    }
+  }
+}
+
+// Estimate number of float operations in an expression
+class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
+ public:
+  double EstimateFlop(const Array<te::Operation>& ops) {
+    double ret = 0;
+    for (const auto& op : ops) {
+      if (auto pop = op.as<te::ComputeOpNode>()) {
+        double num_element = AxisLengthProd(pop->axis);
+        if (num_element == -1) {
+          fail = true;
+          break;
+        }
+        double op_per_element = 0;
+        for (const auto& x : pop->body) {
+          op_per_element += VisitExpr(x);
+        }
+        ret += num_element * op_per_element;
+      } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+        {}  // do nothing
+      } else {
+        LOG(FATAL) << "Invalid op type " << op;
+      }
+    }
+
+    return fail ? -1 : ret;
+  }
+
+  double VisitExpr_(const ReduceNode* op) final {
+    uint64_t num_iter = 1;
+    for (const auto& x : op->axis) {
+      if (auto imm = x->dom->extent.as<IntImmNode>()) {
+        num_iter *= imm->value;
+      } else {
+        fail = true;
+        num_iter = -1;
+      }
+    }
+    double body_flop = 0;
+    for (size_t i = 0; i < op->combiner->result.size(); ++i) {
+      body_flop += VisitExpr(op->combiner->result[i]);
+      body_flop += VisitExpr(op->source[i]);
+    }
+    return num_iter * body_flop;
+  }
+
+  double VisitExpr_(const FloatImmNode* op) final { return 0.0; }

Review comment:
       Accessing a value still takes some time. Why 0 instead of 1?

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topo-sort ops from tensors according to their read-write relations.
+// Results are stored in ops
+void TopoSortOps(const Array<te::Tensor>& tensors, Array<te::Operation>* ops) {
+  std::unordered_map<const te::OperationNode*, int> degree;
+  std::unordered_map<const te::OperationNode*, std::vector<const te::OperationNode*> > edge_set;
+  std::unordered_map<const te::OperationNode*, int> priority;
+  std::unordered_set<const te::OperationNode*> visited;
+
+  // traverse to build edge_set and count degree
+  std::vector<const te::OperationNode*> stack;
+  stack.reserve(tensors.size());
+  for (const auto& x : tensors) {
+    stack.push_back(x->op.operator->());
+  }
+
+  int ct = 0;
+  while (!stack.empty()) {
+    const te::OperationNode* op = stack.back();
+    stack.pop_back();
+    if (visited.count(op)) {
+      continue;
+    }
+
+    priority[op] = ct;
+    ct++;
+    visited.insert(op);
+
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      degree[op] = 0;
+    } else if (auto cop = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
+      const Array<te::Tensor>& input_tensors = cop->InputTensors();
+      degree[op] = input_tensors.size();
+      for (const auto& ten : input_tensors) {
+        edge_set[ten->op.operator->()].push_back(op);
+        stack.push_back(ten->op.operator->());
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
+    }
+  }
+
+  // topo sort
+  ops->clear();
+
+  using Item = std::pair<const te::OperationNode*, int>;
+  auto cmp = [](const Item& left, const Item& right) { return left.second < right.second; };
+  std::priority_queue<Item, std::vector<Item>, decltype(cmp)> queue(cmp);
+  for (const auto& iter : degree) {
+    if (iter.second == 0) {
+      queue.push(Item(iter.first, priority[iter.first]));
+    }
+  }
+
+  ops->reserve(degree.size());
+  while (!queue.empty()) {
+    Item item = queue.top();
+    queue.pop();
+    ops->push_back(GetRef<te::Operation>(item.first));
+    for (const auto& dst : edge_set[item.first]) {
+      degree[dst] -= 1;
+      if (degree[dst] == 0) {
+        queue.push(Item(dst, priority[dst]));
+      }
+    }
+  }
+}
+
+// Estimate number of float operations in an expression
+class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
+ public:
+  double EstimateFlop(const Array<te::Operation>& ops) {
+    double ret = 0;
+    for (const auto& op : ops) {
+      if (auto pop = op.as<te::ComputeOpNode>()) {
+        double num_element = AxisLengthProd(pop->axis);
+        if (num_element == -1) {
+          fail = true;
+          break;
+        }
+        double op_per_element = 0;
+        for (const auto& x : pop->body) {
+          op_per_element += VisitExpr(x);
+        }
+        ret += num_element * op_per_element;
+      } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+        {}  // do nothing
+      } else {
+        LOG(FATAL) << "Invalid op type " << op;
+      }
+    }
+
+    return fail ? -1 : ret;
+  }
+
+  double VisitExpr_(const ReduceNode* op) final {
+    uint64_t num_iter = 1;
+    for (const auto& x : op->axis) {
+      if (auto imm = x->dom->extent.as<IntImmNode>()) {
+        num_iter *= imm->value;
+      } else {
+        fail = true;
+        num_iter = -1;
+      }
+    }
+    double body_flop = 0;
+    for (size_t i = 0; i < op->combiner->result.size(); ++i) {
+      body_flop += VisitExpr(op->combiner->result[i]);
+      body_flop += VisitExpr(op->source[i]);
+    }
+    return num_iter * body_flop;
+  }
+
+  double VisitExpr_(const FloatImmNode* op) final { return 0.0; }
+  double VisitExpr_(const IntImmNode* op) final { return 0.0; }
+  double VisitExpr_(const ProducerLoadNode* op) final { return 0.0; }
+
+  double VisitExpr_(const CastNode* op) final { return VisitExpr(op->value); }
+  double VisitExpr_(const VarNode* op) final { return 0.0; }
+
+  double VisitExpr_(const SelectNode* op) final {
+    return VisitExpr(op->condition) +
+           std::max(VisitExpr(op->true_value), VisitExpr(op->false_value));
+  }
+
+#define VisitBinary(Node) \
+  double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a) + VisitExpr(op->b); }
+#define VisitUnary(Node) \
+  double VisitExpr_(const Node* op) final { return 1.0 + VisitExpr(op->a); }
+
+  VisitBinary(AddNode);
+  VisitBinary(SubNode);
+  VisitBinary(MulNode);
+  VisitBinary(DivNode);
+  VisitBinary(ModNode);
+  VisitBinary(FloorDivNode);
+  VisitBinary(FloorModNode);
+  VisitBinary(MaxNode);
+  VisitBinary(MinNode);
+  VisitBinary(EQNode);
+  VisitBinary(NENode);
+  VisitBinary(LTNode);
+  VisitBinary(LENode);
+  VisitBinary(GTNode);
+  VisitBinary(GENode);
+  VisitBinary(AndNode);
+  VisitBinary(OrNode);
+  VisitUnary(NotNode);
+
+  double VisitExpr_(const CallNode* op) final {
+    double ret = 0.0;
+    for (const auto& x : op->args) {
+      ret += VisitExpr(x);
+    }
+    return ret;

Review comment:
       What about the cost of the function call itself?

##########
File path: src/ansor/loop_state.cc
##########
@@ -0,0 +1,447 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/loop_state.cc
+ * \brief A lightweight IR (intermediate representation) for loop structures.
+ * see ansor/loop_state.h for more explanation.
+ */
+
+#include "loop_state.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+
+#include <utility>
+
+#include "transform_step.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+TVM_REGISTER_OBJECT_TYPE(StepNode);
+TVM_REGISTER_NODE_TYPE(StageNode);
+TVM_REGISTER_NODE_TYPE(StateNode);
+TVM_REGISTER_NODE_TYPE(IteratorNode);
+
+/********** Iterator **********/
+Iterator::Iterator(String name, Range range, IteratorType iter_type,
+                   IteratorAnnotation annotation) {
+  auto node = make_object<IteratorNode>();
+  node->name = std::move(name);
+  node->range = std::move(range);
+  node->iter_type = iter_type;
+  node->annotation = annotation;
+  data_ = std::move(node);
+}
+
+/********** Stage **********/
+Stage::Stage(te::Operation op) {
+  auto node = make_object<StageNode>();
+  if (op->IsInstance<te::ComputeOpNode>()) {
+    node->op_type = kCompute;
+    auto* pop = op.as<te::ComputeOpNode>();
+    for (const auto& axis : pop->axis) {
+      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kSpace, kNone));
+    }
+    for (const auto& axis : pop->reduce_axis) {
+      node->iters.push_back(Iterator(CleanName(axis->var->name_hint), axis->dom, kReduce, kNone));
+    }
+  } else if (op->IsInstance<te::PlaceholderOpNode>()) {
+    node->op_type = kPlaceholder;
+  } else {
+    LOG(FATAL) << "Unsupported operator type" << op->_type_key;
+  }
+
+  node->compute_at = kRoot;
+  node->op = std::move(op);
+  node->attrs.auto_unroll_max_step = 0;
+  node->attrs.storage_offset = 0;
+  data_ = std::move(node);
+}
+
+Stage::Stage(te::Operation op, StageType op_type, const Array<Iterator>& iters,
+             ComputeAtType compute_at, StageAttributes attrs) {
+  auto node = make_object<StageNode>();
+  node->op = std::move(op);
+  node->op_type = op_type;
+  node->iters = iters;
+  node->compute_at = compute_at;
+  node->attrs = attrs;
+  data_ = std::move(node);
+}
+
+Stage::Stage(te::Operation op, StageType op_type, Array<Iterator>&& iters, ComputeAtType compute_at,
+             StageAttributes attrs) {
+  auto node = make_object<StageNode>();
+  node->op = std::move(op);
+  node->op_type = op_type;
+  node->iters = std::move(iters);
+  node->compute_at = compute_at;
+  node->attrs = attrs;
+  data_ = std::move(node);
+}
+
+/********** State **********/
+State::State(const Array<te::Operation>& ops) {
+  auto node = make_object<StateNode>();
+  for (const auto& op : ops) {
+    node->stages.push_back(Stage(op));
+  }
+  node->complete = true;
+  data_ = std::move(node);
+}
+
+/********** Schedule primitives apis for state **********/
+void State::reorder(int stage_id, const Array<Iterator>& order) {
+  const Stage& stage = operator->()->stages[stage_id];
+  CHECK_EQ(order.size(), stage->iters.size()) << "The order of all iterators "
+                                              << "should be specified";
+  Array<Integer> after_ids;
+  GetIndices(stage->iters, order, &after_ids);
+  ReorderStep step = ReorderStep(stage_id, after_ids);
+  CopyOnWrite()->transform_steps.push_back(step);
+  DoReorderStep(step);
+}
+
+Array<Iterator> State::split(int stage_id, const Iterator& it, const Array<Integer>& lengths,
+                             bool inner_to_outer) {
+  const Stage& stage = operator->()->stages[stage_id];
+  SplitStep step =
+      SplitStep(stage_id, GetIndex(stage->iters, it),
+                it->range.defined() ? it->range->extent : PrimExpr(), lengths, inner_to_outer);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoSplitStep(step);
+}
+
+Iterator State::fuse(int stage_id, const Array<Iterator>& iters) {
+  const Stage& stage = operator->()->stages[stage_id];
+  Array<Integer> indices;
+  GetIndices(stage->iters, iters, &indices);
+  FuseStep step = FuseStep(stage_id, indices);
+  CopyOnWrite()->transform_steps.push_back(step);
+  return DoFuseStep(step);
+}
+
+/********** Step implementations for state **********/
+void State::DoReorderStep(const ReorderStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+  Array<Iterator> iters;
+  for (auto x : step->after_ids) {
+    iters.push_back(stage->iters[x]);
+  }
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(iters),
+                                           stage->compute_at, stage->attrs));
+}
+
+// common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
+Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id, const Array<Integer>& lengths,
+                                         bool inner_to_outer) {
+  const Stage& stage = operator->()->stages[stage_id];
+  const Iterator& it = stage->iters[iter_id];
+
+  PrimExpr tosplit_min, tosplit_extent;
+  if (it->range.defined()) {
+    tosplit_min = it->range->min;
+    tosplit_extent = it->range->extent;
+  } else {
+    tosplit_min = tosplit_extent = PrimExpr();
+  }
+
+  Array<Iterator> outs;
+  for (size_t i = 0; i < lengths.size(); ++i) {
+    PrimExpr l;
+    String name;
+    if (inner_to_outer) {
+      l = lengths[lengths.size() - i - 1];
+      name = it->name + "." + std::to_string(lengths.size() - i);
+    } else {
+      l = lengths[i];
+      name = it->name + "." + std::to_string(i);
+    }
+    Iterator res;
+    if (l.defined() && tosplit_min.defined() && tosplit_extent.defined()) {
+      res = Iterator(name, Range::FromMinExtent(tosplit_min, l), it->iter_type, kNone);
+      tosplit_min = 0;
+      tosplit_extent = indexdiv(tosplit_extent + l - 1, l);
+    } else {
+      res = Iterator(name, Range(), it->iter_type, kNone);
+      tosplit_min = tosplit_extent = PrimExpr();
+    }
+    outs.push_back(std::move(res));
+  }
+
+  Range range;
+  if (tosplit_min.defined() && tosplit_extent.defined()) {
+    range = Range::FromMinExtent(tosplit_min, tosplit_extent);
+  }
+  if (inner_to_outer) {
+    outs.push_back(Iterator(it->name + ".0", range, it->iter_type, kNone));
+    // Reverse the Iterator array
+    Array<Iterator> temp(outs.rbegin(), outs.rend());
+    outs = std::move(temp);
+  } else {
+    outs.push_back(
+        Iterator(it->name + "." + std::to_string(lengths.size()), range, it->iter_type, kNone));
+  }
+
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(), stage->iters.begin() + iter_id);
+  new_iters.insert(new_iters.end(), outs.begin(), outs.end());
+  new_iters.insert(new_iters.end(), stage->iters.begin() + iter_id + 1, stage->iters.end());
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     stage->compute_at, stage->attrs));
+
+  return outs;
+}
+
+Array<Iterator> State::DoSplitStep(const SplitStep& step) {
+  return DoSplitStepCommon(step->stage_id, step->iter_id, step->lengths, step->inner_to_outer);
+}
+
+Iterator State::DoFuseStep(const FuseStep& step) {
+  int stage_id = step->stage_id;
+  const Stage& stage = operator->()->stages[stage_id];
+
+  String new_name;
+  PrimExpr new_extent = 1;
+  IteratorType new_iter_type = kSpecial;
+
+  for (size_t i = 0; i < step->fused_ids.size(); ++i) {
+    if (i > 0) {
+      CHECK_EQ(step->fused_ids[i]->value, step->fused_ids[i - 1]->value + 1);
+    }
+
+    const Iterator& it = stage->iters[step->fused_ids[i]];
+    new_name = new_name + it->name + "@";
+
+    if (it->range.defined() && new_extent.defined()) {
+      new_extent = new_extent * it->range->extent;
+    } else {
+      new_extent = PrimExpr();
+    }
+
+    if (i == 0) {
+      new_iter_type = it->iter_type;
+    } else {
+      if (new_iter_type != it->iter_type) {
+        new_iter_type = kMixed;
+      }
+    }
+  }
+
+  Range range;
+  if (new_extent.defined()) {
+    range = Range::FromMinExtent(0, new_extent);
+  }
+  Iterator new_it = Iterator(new_name, range, new_iter_type, kNone);
+  Array<Iterator> new_iters;
+  new_iters.insert(new_iters.end(), stage->iters.begin(),
+                   stage->iters.begin() + step->fused_ids.front());
+  new_iters.push_back(new_it);
+  new_iters.insert(new_iters.end(), stage->iters.begin() + step->fused_ids.back() + 1,
+                   stage->iters.end());
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                     stage->compute_at, stage->attrs));
+
+  return new_it;
+}
+
+void State::DoSteps(const ComputeDAG& dag) {
+  CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";
+
+  // Use complete rate for the study in the paper
+  const char* complete_rate_str = getenv("ANSOR_PROGRAM_COMPLETE_RATE");
+  double complete_rate = -1.0;
+  if (complete_rate_str) {
+    complete_rate = std::stod(complete_rate_str);
+  }
+  size_t ct = 0;
+  for (const auto& step : operator->()->transform_steps) {
+    if (complete_rate >= 0 && ct++ > operator->()->transform_steps.size() * complete_rate) {
+      break;
+    }
+    if (auto ps = step.as<ReorderStepNode>()) {
+      DoReorderStep(GetRef<ReorderStep>(ps));
+    } else if (auto ps = step.as<SplitStepNode>()) {
+      DoSplitStep(GetRef<SplitStep>(ps));
+    } else if (auto ps = step.as<FuseStepNode>()) {
+      DoFuseStep(GetRef<FuseStep>(ps));
+    } else {
+      LOG(FATAL) << "Invalid step: " << step;
+    }
+  }
+}
+
+// Print stage to ostream
+void PrintStage(std::ostream* os, int stage_id, const StateNode* state, size_t base_indent,

Review comment:
       Why not `&`? Other people use `ostream&` rather than `ostream*`.

##########
File path: src/ansor/compute_dag.cc
##########
@@ -0,0 +1,505 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file ansor/compute_dag.cc
+ * \brief Compute declaration graph and its related analysis tools.
+ */
+
+#include "compute_dag.h"
+
+#include <tvm/runtime/registry.h>
+#include <tvm/te/operation.h>
+#include <tvm/te/schedule.h>
+#include <tvm/te/schedule_pass.h>
+#include <tvm/tir/stmt_functor.h>
+
+#include <algorithm>
+#include <queue>
+#include <unordered_map>
+#include <unordered_set>
+#include <vector>
+
+#include "loop_state.h"
+#include "utils.h"
+
+namespace tvm {
+namespace ansor {
+
+using namespace tvm::tir;
+
+TVM_REGISTER_NODE_TYPE(ComputeDAGNode);
+
+// Topologically sort the operations reachable from `tensors` by their read-write relations,
+// so that an operation is emitted only after the operations producing its input tensors.
+// The previous contents of `ops` are discarded and replaced by the sorted result.
+void TopoSortOps(const Array<te::Tensor>& tensors, Array<te::Operation>* ops) {
+  using OpPtr = const te::OperationNode*;
+
+  // Number of input tensors of each op whose producer has not been emitted yet.
+  std::unordered_map<OpPtr, int> degree;
+  // Maps a producer op to the ops that read one of its tensors.
+  std::unordered_map<OpPtr, std::vector<OpPtr> > edge_set;
+  // Order in which each op was first reached during the traversal below.
+  std::unordered_map<OpPtr, int> priority;
+  std::unordered_set<OpPtr> visited;
+
+  // Walk backwards from the given tensors to fill edge_set, degree and priority.
+  std::vector<OpPtr> pending;
+  pending.reserve(tensors.size());
+  for (const auto& tensor : tensors) {
+    pending.push_back(tensor->op.operator->());
+  }
+
+  int next_priority = 0;
+  while (!pending.empty()) {
+    OpPtr op = pending.back();
+    pending.pop_back();
+    // insert() reports whether the op is new; ops reached before are skipped.
+    if (!visited.insert(op).second) {
+      continue;
+    }
+    priority[op] = next_priority++;
+
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      degree[op] = 0;
+    } else if (auto compute_op = GetRef<te::Operation>(op).as<te::ComputeOpNode>()) {
+      const Array<te::Tensor> inputs = compute_op->InputTensors();
+      degree[op] = inputs.size();
+      for (const auto& input : inputs) {
+        OpPtr producer = input->op.operator->();
+        edge_set[producer].push_back(op);
+        pending.push_back(producer);
+      }
+    } else {
+      LOG(FATAL) << "Unsupported op " << GetRef<te::Operation>(op);
+    }
+  }
+
+  // Emit ops whose inputs are all available; among ready ops the one with the largest
+  // priority value is taken first.
+  ops->clear();
+
+  using Item = std::pair<OpPtr, int>;
+  auto by_priority = [](const Item& lhs, const Item& rhs) { return lhs.second < rhs.second; };
+  std::priority_queue<Item, std::vector<Item>, decltype(by_priority)> ready(by_priority);
+  for (const auto& entry : degree) {
+    if (entry.second == 0) {
+      ready.push(Item(entry.first, priority[entry.first]));
+    }
+  }
+
+  ops->reserve(degree.size());
+  while (!ready.empty()) {
+    const OpPtr current = ready.top().first;
+    ready.pop();
+    ops->push_back(GetRef<te::Operation>(current));
+    for (OpPtr consumer : edge_set[current]) {
+      // A consumer becomes ready once every one of its inputs has been emitted.
+      if (--degree[consumer] == 0) {
+        ready.push(Item(consumer, priority[consumer]));
+      }
+    }
+  }
+}
+
+// Estimate number of float operations in an expression
+class FlopEstimator : public ExprFunctor<double(const PrimExpr& n)> {
+ public:
+  double EstimateFlop(const Array<te::Operation>& ops) {
+    double ret = 0;
+    for (const auto& op : ops) {
+      if (auto pop = op.as<te::ComputeOpNode>()) {
+        double num_element = AxisLengthProd(pop->axis);

Review comment:
       use option




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org