You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2021/01/25 11:27:51 UTC

[tvm] branch main updated: [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317)

This is an automated email from the ASF dual-hosted git repository.

zhaowu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e6d5318  [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317)
e6d5318 is described below

commit e6d53185b96cc39f2aaec5e86ae11ca0ac675b8a
Author: Cody Yu <co...@gmail.com>
AuthorDate: Mon Jan 25 03:27:34 2021 -0800

    [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing (#7317)
    
    * [AutoScheduler] Separate shapes from DAG hash and enable schedule sharing
    
    * Update CI logs
    
    * lint
    
    * fix registry
    
    * add message; fix layout rewrite mismatch
    
    * update message
    
    * support other formats
---
 include/tvm/auto_scheduler/compute_dag.h           |   7 ++
 python/tvm/auto_scheduler/compute_dag.py           |  35 +++---
 python/tvm/auto_scheduler/measure_record.py        | 126 +++++++++++++++++++--
 python/tvm/auto_scheduler/relay_integration.py     |   6 +-
 python/tvm/auto_scheduler/search_task.py           |   8 +-
 python/tvm/auto_scheduler/utils.py                 |  27 +++++
 python/tvm/auto_scheduler/workload_registry.py     |  37 +++---
 src/auto_scheduler/compute_dag.cc                  | 109 ++++++++++--------
 .../python/unittest/test_auto_scheduler_measure.py |  33 ++++++
 .../ci_logs/resnet-18-NHWC-B1-cuda.json            |  48 ++++----
 .../ci_logs/resnet-50-NHWC-B1-llvm.json            |  55 +++++----
 11 files changed, 342 insertions(+), 149 deletions(-)

diff --git a/include/tvm/auto_scheduler/compute_dag.h b/include/tvm/auto_scheduler/compute_dag.h
index 1e3f097..a87563e 100755
--- a/include/tvm/auto_scheduler/compute_dag.h
+++ b/include/tvm/auto_scheduler/compute_dag.h
@@ -263,6 +263,13 @@ class ComputeDAG : public ObjectRef {
   String PrintStepsAsPython(const Array<Step>& transform_steps) const;
 
   /*!
+   * \brief Print the compute DAG to a string. This is also used to generate the ComputeDAG hash.
+   * \param simple_mode Simple mode will only include the op names and brief compute.
+   * \return The ComputeDAG in a string.
+   */
+  String PrintDAG(bool simple_mode = false) const;
+
+  /*!
    * \brief Fill the correct bound information for a given state by calling ir_pass::InferBound.
    * The states can lose complete bound information after some transform steps (e.g., compute_at).
    * We can call this function to infer and fill all the bound information.
diff --git a/python/tvm/auto_scheduler/compute_dag.py b/python/tvm/auto_scheduler/compute_dag.py
index a7f200a..948f277 100755
--- a/python/tvm/auto_scheduler/compute_dag.py
+++ b/python/tvm/auto_scheduler/compute_dag.py
@@ -19,11 +19,11 @@
 """ The auto-scheduler's computational graph and related program analyses. """
 
 import hashlib
+import json
 
 import tvm._ffi
 from tvm.runtime import Object
 from tvm.runtime._ffi_node_api import LoadJSON, SaveJSON
-from tvm.te import ComputeOp, PlaceholderOp
 
 from . import _ffi_api
 from .loop_state import State, StateObject
@@ -220,32 +220,23 @@ class ComputeDAG(Object):
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state_obj)
 
-    def hash_key(self):
-        """Return the hash key of this compute DAG.
+    def workload_key(self):
+        """Return the workload key of this compute DAG.
+        The workload key is a JSON string from a tuple of (hash-key, tensor shapes...)
 
         Returns
         -------
         key: str
-            The hash key of this compute DAG
+            The workload key of this compute DAG
         """
-        # TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
-        # of ComputeDAG
-        str_key = ""
-        for op in self.ops:
-            t = op.output(0)
-            if isinstance(op, PlaceholderOp):
-                str_key += "placeholder,"
-                str_key += str(get_const_tuple(t.shape)) + ","
-                str_key += t.dtype + ";"
-            elif isinstance(op, ComputeOp):
-                str_key += str(t.op.body) + ","
-                str_key += str(get_const_tuple(t.shape)) + ","
-                str_key += t.dtype + ";"
-            else:
-                raise ValueError("Invalid op: " + op)
-
-        str_key = str_key.encode(encoding="utf-8")
-        return hashlib.md5(str_key).hexdigest()
+        str_dag = _ffi_api.ComputeDAGPrintDAG(self, True)
+        str_dag = str_dag.encode(encoding="utf-8")
+        hash_key = hashlib.md5(str_dag).hexdigest()
+
+        io_shapes = []
+        for tensor in self.tensors:
+            io_shapes += get_const_tuple(tensor.shape)
+        return json.dumps([hash_key] + io_shapes)
 
     def __str__(self):
         # pretty print
diff --git a/python/tvm/auto_scheduler/measure_record.py b/python/tvm/auto_scheduler/measure_record.py
index 35e5e9b..9eaef18 100644
--- a/python/tvm/auto_scheduler/measure_record.py
+++ b/python/tvm/auto_scheduler/measure_record.py
@@ -27,6 +27,7 @@ import numpy as np
 import tvm._ffi
 from tvm.runtime import Object
 from .measure import MeasureErrorNo, MeasureCallback
+from .utils import decode_workload_key
 from . import _ffi_api
 
 logger = logging.getLogger("auto_scheduler")
@@ -59,8 +60,37 @@ class RecordReader(Object):
     """
 
     def __init__(self, filename):
+        # a set to prevent printing duplicated messages
+        self.messages = set()
+
         self.__init_handle_by_constructor__(_ffi_api.RecordReader, filename)
 
+    def check_workload_key(self, inputs):
+        """Check and throw warnings for records with old format workload key.
+
+        Parameters
+        ----------
+        inputs: List[MeasureInput]
+            The measure inputs to be checked.
+
+        Notes
+        -----
+        This checker could be deprecated in the future.
+        """
+        for inp in inputs:
+            _, args = decode_workload_key(inp.task.workload_key)
+            if args is None:
+                continue
+            if not args:
+                msg = (
+                    "MeasureInput with old format workload key %s should be updated "
+                    "using the script from https://github.com/apache/tvm/pull/7317."
+                    % inp.task.workload_key
+                )
+                if msg not in self.messages:
+                    self.messages.add(msg)
+                    logger.warning(msg)
+
     def read_lines(self, max_lines=None, skip_lines=0):
         """Read multiple lines from the log file.
 
@@ -88,6 +118,7 @@ class RecordReader(Object):
         inputs, results = _ffi_api.RecordReaderReadLines(
             self, max_lines if max_lines else -1, skip_lines
         )
+        self.check_workload_key(inputs)
         return inputs, results
 
     def __iter__(self):
@@ -95,9 +126,69 @@ class RecordReader(Object):
             ret = _ffi_api.RecordReaderReadNext(self)
             if not ret:
                 break
+            self.check_workload_key([ret[0]])
             yield ret[0], ret[1]  # (input, result)
 
 
+def calc_workload_dis_factor(target_workload_key, workload_key):
+    """Calculate the distance factor of the workload to the target workload.
+    If two workloads are not compatible at all (i.e., different compute DAG or function),
+    then the distance factor is "inf". Otherwise, we calculate the factor by traversing
+    the workload arguments, which are the arguments of the compute function,
+    or the output shapes for the ComputeDAG. The factor is calculated by the following rules:
+
+    1. For non-zero integer values: `product(target_arg / candidate_arg)`.
+    2. For non-integer or zero values: "inf" if not equal else 1.
+
+    As a result, factor=1 is the optimal when two workloads are identical.
+
+    Parameters
+    ----------
+    target_workload_key: str
+        The target workload key in JSON string.
+
+    workload_key: str
+        The candidate workload key in JSON string.
+
+    Returns
+    -------
+    dis_f: float
+        The distance factor.
+    """
+
+    def flatten_list(inp):
+        ret = []
+        for elt in inp:
+            if isinstance(elt, list):
+                ret += flatten_list(elt)
+            else:
+                ret.append(elt)
+        return ret
+
+    target_key, target_args = decode_workload_key(target_workload_key)
+    target_args = flatten_list(target_args) if target_args is not None else []
+    key, args = decode_workload_key(workload_key)
+    args = flatten_list(args) if args is not None else []
+
+    # Not even the same func/DAG.
+    if key != target_key or len(target_args) != len(args):
+        return float("inf")
+
+    dis_f = 1
+    for target_arg, arg in zip(target_args, args):
+        if isinstance(target_arg, int):
+            if target_arg == 0 or arg == 0:
+                if target_arg != arg:
+                    return float("inf")
+            elif target_arg % arg != 0:
+                return float("inf")
+            else:
+                dis_f *= target_arg / arg
+        elif target_arg != arg:
+            return float("inf")
+    return dis_f
+
+
 def load_record_from_string(record):
     """
     Load the measure record from string.
@@ -174,7 +265,7 @@ def save_records(filename, inputs, results):
     _ffi_api.SaveRecords(filename, inputs, results)
 
 
-def load_best_record(filename, workload_key=None, target=None):
+def load_best_record(filename, workload_key=None, target=None, include_compatible=False):
     """Return the best measurement pair form a log file. This may return none results if
     there is no legal measure pair with the specified workload_key/target found from the log file.
 
@@ -188,6 +279,8 @@ def load_best_record(filename, workload_key=None, target=None):
     target : Optional[tvm.target.Target]
         The target device.
         With `None`, this returns the best measure pair of all target devices.
+    include_compatible: bool
+        When set to True, all compatible records in the log file will be considered.
 
     Returns
     -------
@@ -204,13 +297,23 @@ def load_best_record(filename, workload_key=None, target=None):
     for inp, res in log_reader:
         if res.error_no != MeasureErrorNo.NO_ERROR:
             continue
-        if workload_key and inp.task.workload_key != workload_key:
-            continue
         if target and inp.task.target.kind.name != target.kind.name:
             continue
 
         costs = [v.value for v in res.costs]
         cost = np.mean(costs)
+
+        if workload_key is not None:
+            dis_f = calc_workload_dis_factor(workload_key, inp.task.workload_key)
+            if dis_f == float("inf"):
+                continue
+            if not include_compatible and dis_f != 1:
+                continue
+
+            # Since different workloads have different FLOPS, we multiply the factor to
+            # eliminate this difference, which is basically the concept of throughput.
+            cost *= dis_f
+
         if cost < best_cost:
             best_cost = cost
             best_inp = inp
@@ -267,12 +370,8 @@ def distill_record_file(in_file, out_file):
     logger.info("Extract %d best records from %s to %s", len(inputs), in_file, out_file)
 
 
-"""
-Usage:
-* Distill the best entries from a large log file
-e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
-"""
-if __name__ == "__main__":
+def main():
+    """The main function for CLI."""
     parser = argparse.ArgumentParser()
     parser.add_argument("--mode", choices=["distill"], required=True)
     parser.add_argument("--i", type=str, help="input file")
@@ -285,3 +384,12 @@ if __name__ == "__main__":
     if args.mode == "distill":
         args.o = args.o or args.i + ".best.json"
         distill_record_file(args.i, args.o)
+
+
+"""
+Usage:
+* Distill the best entries from a large log file
+e.g. python -m tvm.auto_scheduler.measure_record --mode distill --i input.json
+"""
+if __name__ == "__main__":
+    main()
diff --git a/python/tvm/auto_scheduler/relay_integration.py b/python/tvm/auto_scheduler/relay_integration.py
index fb60da1..b39aba2 100644
--- a/python/tvm/auto_scheduler/relay_integration.py
+++ b/python/tvm/auto_scheduler/relay_integration.py
@@ -22,7 +22,6 @@ Integrate auto_scheduler into relay. It implements the following items:
 2. Provide auto-scheduling for all TOPI compute functions
 """
 
-import json
 import logging
 import threading
 
@@ -281,7 +280,7 @@ def auto_schedule_topi(outs):
         logger.info("Failed to create a ComputeDAG for auto_scheduler: %s", str(err))
         return None
 
-    key = register_workload_tensors(dag.hash_key(), io_tensors)
+    key = register_workload_tensors(dag.workload_key(), io_tensors)
     target = tvm.target.Target.current()
 
     env = TracingEnvironment.current
@@ -310,9 +309,8 @@ def auto_schedule_topi(outs):
                 return None
 
             # rewrite the layout and update the context for the new dag
-            dag = ComputeDAG(outs)
             new_dag = dag.rewrite_layout_from_state(state)
-            new_key = json.dumps((new_dag.hash_key(),))
+            new_key = new_dag.workload_key()
             if new_key != key:
                 dispatch_ctx.update(target, new_key, state)
     else:
diff --git a/python/tvm/auto_scheduler/search_task.py b/python/tvm/auto_scheduler/search_task.py
index d985ed1..83f665b 100644
--- a/python/tvm/auto_scheduler/search_task.py
+++ b/python/tvm/auto_scheduler/search_task.py
@@ -257,13 +257,15 @@ class SearchTask(Object):
 
         _ffi_api.AutoSchedule(search_policy, tuning_options)
 
-    def apply_best(self, log_file, layout_rewrite_option=None):
+    def apply_best(self, log_file, include_compatible=False, layout_rewrite_option=None):
         """Apply the history best from a log file and return the schedule.
 
         Parameters
         ----------
         log_file : str
            The name of the log file.
+        include_compatible: bool
+            When set to True, all compatible records in the log file will be considered.
         layout_rewrite_option : Optional[LayoutRewriteOption]
            The layout rewrite option.
 
@@ -272,7 +274,9 @@ class SearchTask(Object):
         -------
             A `te.Schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
         """
-        inp, _ = load_best_record(log_file, self.workload_key)
+        inp, _ = load_best_record(
+            log_file, self.workload_key, include_compatible=include_compatible
+        )
         if inp is None:
             raise RuntimeError(
                 "Cannot find any valid schedule for %s in file %s" % (self.workload_key, log_file)
diff --git a/python/tvm/auto_scheduler/utils.py b/python/tvm/auto_scheduler/utils.py
index 334acaf..fd25fdb 100644
--- a/python/tvm/auto_scheduler/utils.py
+++ b/python/tvm/auto_scheduler/utils.py
@@ -19,6 +19,7 @@
 """ Common utilities for auto_scheduler. """
 
 from typing import Hashable
+import json
 import multiprocessing
 import multiprocessing.pool
 import queue
@@ -42,6 +43,32 @@ from tvm.ir.transform import Sequential
 from ..te import Tensor, placeholder
 
 
+def decode_workload_key(workload_key):
+    """Decode the workload key from a string to the name and arguments. The workload key
+    is expected to be a list of "[func_name/hash, args ...]" in a JSON string. If not,
+    then simply return the workload key as the name without arguments.
+
+    Parameters
+    ----------
+    workload_key: str
+        The workload key in string. Format: "[func_name/hash, args ...]".
+
+    Returns
+    -------
+    name: str
+        The workload function name or the DAG hash.
+    args: Optional[List[Any]]
+        The arguments of the workload, or None if the workload key format is not decodeable.
+    """
+    try:
+        key_list = json.loads(workload_key)
+        if isinstance(key_list, list) and len(key_list) >= 1:
+            return key_list[0], key_list[1:]
+    except json.decoder.JSONDecodeError:
+        pass
+    return workload_key, None
+
+
 def get_func_name(func):
     """Get name of a function.
 
diff --git a/python/tvm/auto_scheduler/workload_registry.py b/python/tvm/auto_scheduler/workload_registry.py
index 9a7c15c..51ae64d 100644
--- a/python/tvm/auto_scheduler/workload_registry.py
+++ b/python/tvm/auto_scheduler/workload_registry.py
@@ -98,14 +98,14 @@ def register_workload(func_name, f=None, override=False):
     return register
 
 
-def register_workload_tensors(func_name, tensors, override=True):
+def register_workload_tensors(workload_key, tensors, override=True):
     """Register a workload by provding input/output tensors. Since this function is used
     when extracting/deserializing tasks, it expects duplicated registrations by default.
 
     Parameters
     ----------
-    func_name: str
-        The function name or the hash key of the compute DAG.
+    workload_key: str
+        The workload key of the compute DAG in JSON string.
     tensors: List[Tensor]
         The input/output tensors of a compute DAG
     override : boolean = True
@@ -113,11 +113,11 @@ def register_workload_tensors(func_name, tensors, override=True):
 
     Returns
     -------
-    key: str
-        The serialized JSON string as the workload key.
+    workload_key: str
+        The workload key of the compute DAG in JSON string.
     """
-    register_workload(func_name, override=override)(tensors)
-    return json.dumps((func_name,))
+    register_workload(workload_key, override=override)(tensors)
+    return workload_key
 
 
 def make_workload_key(func, args):
@@ -169,7 +169,8 @@ def workload_key_to_tensors(workload_key):
     Parameters
     ----------
     workload_key : str
-        The input workload key.
+        The input workload key in JSON string. The format is either (func_name, arguments...)
+        for compute functions, or (hash, shapes...) for ComputeDAG.
 
     Returns
     -------
@@ -178,16 +179,21 @@ def workload_key_to_tensors(workload_key):
     """
     global WORKLOAD_FUNC_REGISTRY
 
+    # We register ComputeDAG with both hash and arguments, which are fixed in ComputeDAG,
+    # so we use an entire workload key to query the ComputeDAG.
+    if workload_key in WORKLOAD_FUNC_REGISTRY:
+        return WORKLOAD_FUNC_REGISTRY[workload_key]
+
+    # We register compute function with only the function name since
+    # it does not bind to specific arguments, so we use the function name to query
+    # the function and call the function with arguments to get the tensors.
     workload = json.loads(workload_key)
     name = workload[0]
     value = WORKLOAD_FUNC_REGISTRY[name]
+    assert callable(value)
 
-    # "value" can be either a function or a list of tensors
-    if callable(value):  # if it is a func
-        args = deserialize_args(workload[1:])
-        return value(*args)
-    # otherwise, it is a list of tensors
-    return value
+    args = deserialize_args(workload[1:])
+    return value(*args)
 
 
 def serialize_workload_registry_entry(workload_key):
@@ -209,6 +215,9 @@ def serialize_workload_registry_entry(workload_key):
     """
     global WORKLOAD_FUNC_REGISTRY
 
+    if workload_key in WORKLOAD_FUNC_REGISTRY:
+        return (workload_key, WORKLOAD_FUNC_REGISTRY[workload_key])
+
     workload = json.loads(workload_key)
     name = workload[0]
     value = WORKLOAD_FUNC_REGISTRY[name]
diff --git a/src/auto_scheduler/compute_dag.cc b/src/auto_scheduler/compute_dag.cc
old mode 100755
new mode 100644
index 735f044..4e7fb05
--- a/src/auto_scheduler/compute_dag.cc
+++ b/src/auto_scheduler/compute_dag.cc
@@ -1243,6 +1243,62 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
   return ss.str();
 }
 
+String ComputeDAG::PrintDAG(bool simple_mode) const {
+  std::stringstream ss;
+
+  for (const auto& op : operator->()->ops) {
+    if (op->IsInstance<te::PlaceholderOpNode>()) {
+      ss << op->name << " = PLACEHOLDER ";
+      if (!simple_mode) {
+        ss << op.output(0)->shape;
+      }
+      ss << "\n";
+    } else if (auto pop = op.as<te::ComputeOpNode>()) {
+      for (size_t k = 0; k < pop->body.size(); ++k) {
+        ss << op->name << "(";
+        for (size_t i = 0; i < pop->axis.size(); i++) {
+          ss << pop->axis[i]->var->name_hint;
+          if (i != pop->axis.size() - 1) {
+            ss << ", ";
+          }
+        }
+        ss << ")";
+        if (pop->body.size() > 1) {
+          ss << ".v" << k;
+        }
+        if (auto preduce = pop->body[k].as<ReduceNode>()) {
+          ICHECK_LT(k, preduce->combiner->result.size());
+          PrimExpr combiner = preduce->combiner->result[k];
+          if (combiner->IsInstance<AddNode>()) {
+            ss << " += " << preduce->source[0] << "\n";
+          } else if (combiner->IsInstance<MaxNode>()) {
+            ss << " max= " << preduce->source[0] << "\n";
+          } else if (combiner->IsInstance<MinNode>()) {
+            ss << " min= " << preduce->source[0] << "\n";
+          } else if (combiner->IsInstance<SelectNode>()) {
+            const auto& select = combiner.as<SelectNode>();
+            ss << " select(" << select->condition << ", " << select->true_value << ", "
+               << select->false_value << ")= " << '(' << preduce->source[0] << ','
+               << preduce->source[1] << ")\n";
+          } else {
+            ss << "reduce" << combiner << "\n";
+          }
+        } else {
+          auto call = pop->body[k].as<CallNode>();
+          if (simple_mode && call) {
+            ss << " = " << call->op << "\n";
+          } else {
+            ss << " = " << pop->body[k] << "\n";
+          }
+        }
+      }
+    } else {
+      LOG(FATAL) << "Invalid op";
+    }
+  }
+  return String(ss.str());
+}
+
 State ComputeDAG::InferBound(const State& state) const {
   ICHECK(state->concrete) << "Only concrete state can be processed to get bound info.";
 
@@ -1383,51 +1439,9 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
       auto* node = static_cast<const ComputeDAGNode*>(ref.get());
-      std::stringstream ss;
-
-      for (const auto& op : node->ops) {
-        if (op->IsInstance<te::PlaceholderOpNode>()) {
-          ss << op->name << " = PLACEHOLDER " << op.output(0)->shape << "\n";
-        } else if (auto pop = op.as<te::ComputeOpNode>()) {
-          for (size_t k = 0; k < pop->body.size(); ++k) {
-            ss << op->name << "(";
-            for (size_t i = 0; i < pop->axis.size(); i++) {
-              ss << pop->axis[i]->var->name_hint;
-              if (i != pop->axis.size() - 1) {
-                ss << ", ";
-              }
-            }
-            ss << ")";
-            if (pop->body.size() > 1) {
-              ss << ".v" << k;
-            }
-            if (auto preduce = pop->body[k].as<ReduceNode>()) {
-              ICHECK_LT(k, preduce->combiner->result.size());
-              PrimExpr combiner = preduce->combiner->result[k];
-              if (combiner->IsInstance<AddNode>()) {
-                ss << " += " << preduce->source[0] << "\n";
-              } else if (combiner->IsInstance<MaxNode>()) {
-                ss << " max= " << preduce->source[0] << "\n";
-              } else if (combiner->IsInstance<MinNode>()) {
-                ss << " min= " << preduce->source[0] << "\n";
-              } else if (combiner->IsInstance<SelectNode>()) {
-                const auto& select = combiner.as<SelectNode>();
-                ss << " select(" << select->condition << ", " << select->true_value << ", "
-                   << select->false_value << ")= " << '(' << preduce->source[0] << ','
-                   << preduce->source[1] << ")\n";
-              } else {
-                ss << "reduce" << combiner << "\n";
-              }
-            } else {
-              ss << " = " << pop->body[k] << "\n";
-            }
-          }
-        } else {
-          LOG(FATAL) << "Invalid op";
-        }
-      }
-
-      p->stream << ss.str();
+      auto dag = GetRef<ComputeDAG>(node);
+      auto dag_str = dag.PrintDAG();
+      p->stream << dag_str;
     });
 
 Array<PrimExpr> GetShapeFromRewrittenLayout(String rewritten_layout, Array<String> axis_names) {
@@ -1476,6 +1490,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintPythonCodeFromState")
       return dag.PrintStepsAsPython(state->transform_steps);
     });
 
+TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGPrintDAG")
+    .set_body_typed([](const ComputeDAG& dag, bool simple_mode) {
+      return dag.PrintDAG(simple_mode);
+    });
+
 TVM_REGISTER_GLOBAL("auto_scheduler.ComputeDAGInferBoundFromState")
     .set_body_typed([](const ComputeDAG& dag, const State& state) {
       return dag.InferBound(state);
diff --git a/tests/python/unittest/test_auto_scheduler_measure.py b/tests/python/unittest/test_auto_scheduler_measure.py
index e9f1fa4..3b074b2 100644
--- a/tests/python/unittest/test_auto_scheduler_measure.py
+++ b/tests/python/unittest/test_auto_scheduler_measure.py
@@ -16,6 +16,7 @@
 # under the License.
 
 """ Test measurement and log serialization. """
+import json
 
 import multiprocessing
 import tvm
@@ -200,6 +201,38 @@ def test_recover_measure_input():
         assert str(correct_inp.state) == str(inp.state)
 
 
+def test_workload_dis_factor():
+    calc = auto_scheduler.measure_record.calc_workload_dis_factor
+
+    # Identical
+    target_wkl_key = json.dumps(
+        ["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"]
+    )
+    assert calc(target_wkl_key, target_wkl_key) == 1
+
+    # Compatible with a factor
+    wkl_key = json.dumps(["func1", [1, 3, 112, 112], [32, 3, 3, 3], [0, 0], [1, 1], "float32"])
+    assert calc(target_wkl_key, wkl_key) == 8 * 2 * 2
+
+    # Incompatible argument with zeros
+    wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [1, 1], [1, 1], "float32"])
+    assert calc(target_wkl_key, wkl_key) == float("inf")
+    wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [0, 0], "float32"])
+    assert calc(target_wkl_key, wkl_key) == float("inf")
+
+    # Incompatible non-integer argument
+    wkl_key = json.dumps(["func1", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "int8"])
+    assert calc(target_wkl_key, wkl_key) == float("inf")
+
+    # Incompatible function
+    wkl_key = json.dumps(["func2", [8, 3, 224, 224], [32, 3, 3, 3], [0, 0], [1, 1], "float32"])
+    assert calc(target_wkl_key, wkl_key) == float("inf")
+
+    # Incompatible due to non-divisible factor
+    wkl_key = json.dumps(["func1", [8, 3, 223, 223], [32, 3, 3, 3], [0, 0], [1, 1], "float32"])
+    assert calc(target_wkl_key, wkl_key) == float("inf")
+
+
 def test_measure_local_builder_runner():
     if not tvm.testing.device_enabled("llvm"):
         return
diff --git a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
index 8d0a6ae..7cb3a67 100644
--- a/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
+++ b/tutorials/auto_scheduler/ci_logs/resnet-18-NHWC-B1-cuda.json
@@ -1,26 +1,26 @@
 # Provide valid schedules for resnet-18 on GPU.
 # This is used to run the tutorial on the documentation web server.
-{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 4, 1, 1000, [40], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$512"], ["PR", 3, 0, "auto_unroll_max_step$512"]]]], "r": [[4.87396e-06], 0, 1.30575, 1606984701], "v": "v0.3"}
-{"i": [["[\"d09dc1a6bb90d59c91b68989ad3492ff\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]], ["CA", 2, 3, 6], ["CHR", 0, "shared", [3]], [" [...]
-{"i": [["[\"7de313da0ca29a8c63f647791692430d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.91068e-06], 0, 1.63708, 1606984742], "v": "v0.3"}
-{"i": [["[\"8d5a93959138dc7b2ee1f1b3219dfa14\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [2], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 16, [4,  [...]
-{"i": [["[\"ac6920940de3797cc3f9f9c260675e5d\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 16, [4, 2, 2, 1], 1], ["SP", 6, 1 [...]
-{"i": [["[\"7e83a2ee5cd5d50282ed19310700046a\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 16, [2, 1, 1, 8], 1], ["SP", 6, 15, 512, [1,  [...]
-{"i": [["[\"424ba83160af31badc0b098136e1a3b0\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 2], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 49, [1, 1, 1, 7], 1] [...]
-{"i": [["[\"a169cd0053d3a7ca82998fcb62e42c58\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 1, 1, 1], 1], ["SP", 6, 10, 49, [1, 7, 7, 1], 1], ["SP", 6, 1 [...]
-{"i": [["[\"0141ffc4fbabc10cc5a94c954419055b\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 49, [1, 1, 7, 1], 1], ["SP", 6, 15, 256, [4,  [...]
-{"i": [["[\"81aae4b8e2c076a4014d403e8a2c70a1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [2, 7, 1, 1], 1], ["SP", 3, 10, 14, [1, 7, 2, 1], 1], ["SP", 3, 15, 256, [2, 2, 1, 4], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [4, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, [...]
-{"i": [["[\"c7a6b56bdc04b94c829fb2ef9874019e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 196, [1, 7, 1, 7],  [...]
-{"i": [["[\"c035cc8b0568a8e054d06bd7f4950550\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 4], 1], ["SP", 6, 5, 4, [1, 1, 1, 2], 1], ["SP", 6, 10, 196, [2, 49, 1, 1], 1], ["SP", 6 [...]
-{"i": [["[\"c5ee3e05edd9754492d0763aa41fd025\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4, [1, 4, 1, 1], 1], ["SP", 6, 10, 196, [1, 2, 7, 1], 1], ["SP", 6, 15, 128, [1 [...]
-{"i": [["[\"022ebb6b7c55c5ed030421380ec83a04\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 2, 1], 1], ["SP", 3, 10, 28, [1, 7, 2, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 64, [4, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13,  [...]
-{"i": [["[\"de0df0893e01892cfe69f7bc2c24111f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 1, 1, 1], 1], ["SP", 6, 5, 6, [1, 1, 1, 2], 1], ["SP", 6, 10, 196, [2, 14, 1, 1], 1 [...]
-{"i": [["[\"f2e3c09a00e7d0a9897f70497e089f1e\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 2, 1], 1], ["SP", 6, 5, 6, [1, 3, 1, 2], 1], ["SP", 6, 10, 196, [1, 1, 4, 1], 1], ["SP", 6, [...]
-{"i": [["[\"fa26946d7ac51126bfa859cb183f9ca1\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 3, 1, 2], 1], ["SP", 6, 10, 196, [1, 1, 1, 4], 1], ["SP", 6, 15, 64, [1, [...]
-{"i": [["[\"ba2026d923536b75e9b4faed89287d5f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$16"]]]], "r": [[2.00968e-05], 0, 1.53065, 1606985193], "v": "v0.3"}
-{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 7, 1], 1], ["SP", 3, 10, 112, [1, 7, 1, 1], 1], ["SP", 3, 15, 64, [1, 8, 4, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13,  [...]
-{"i": [["[\"bf78a7bf0209980f72953637dfd14a6f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 2], 1], ["SP", 3, 10, 56, [1, 7, 1, 2], 1], ["SP", 3, 15, 64, [1, 16, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 8], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27,  [...]
-{"i": [["[\"6630936c26852f2b89dbfa2ff37fbb9c\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 7], 1], ["SP", 3, 15, 128, [8, 8, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27,  [...]
-{"i": [["[\"ba5f918733ccbbd4a1d7fd3724665a2f\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 1], 1], ["SP", 3, 10, 14, [2, 1, 7, 1], 1], ["SP", 3, 15, 256, [2, 64, 1, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27 [...]
-{"i": [["[\"21ad409d72953de188314010134e3acd\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [4, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27 [...]
-{"i": [["[\"1f6cd3637ec856bf5cf5010a623eed05\"]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 65536, 1024, 8, 32]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [7, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 4, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 256, [8, 2], 1], ["RE", 3, [0, 5, 10, 15, 1, 6, 11, 16, 2, 7, 12, 17, 20, 23, 26, 21, 24, 27, 3, 8, 13, 1 [...]
+{"i": [["[\"d7b65649a4dd54becea0a52aabbc5af5\", 1, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 4, 1, 1000, [40], 1], ["AN", 4, 2, 6], ["FSP", 3, 1, 0, 1], ["AN", 3, 2, 6], ["CA", 3, 4, 0], ["CI", 2], ["FSP", 1, 1, 0, 1], ["AN", 1, 2, 6], ["CA", 1, 4, 0], ["AN", 4, 0, 5], ["PR", 1, 0, "auto_unroll_max_step$512"], ["PR", 3, 0, "auto_unroll_max_step$512"]]]], "r": [[4.87396e-06], 0, 1.3 [...]
+{"i": [["[\"9847f8cc0b305137f49f2c5c0c8ab25d\", 1, 512, 1000, 512, 1000, 1, 1000]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["SP", 2, 0, 1, [1, 1, 1, 1], 1], ["SP", 2, 5, 1000, [1, 50, 1, 1], 1], ["SP", 2, 10, 512, [1, 16], 1], ["RE", 2, [0, 5, 1, 6, 2, 7, 10, 11, 3, 8, 12, 4, 9]], ["FSP", 4, 0, 0, 3], ["FSP", 4, 4, 1, 3], ["RE", 4, [0, 4, 1, 5, 2, 6, 3, 7]], ["CA", 2, 4, 5], ["CHR", 1, "shared", [2]],  [...]
+{"i": [["[\"69115f188984ae34ede37c3b8ca40b43\", 1, 7, 7, 512, 1, 1, 1, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 512, [2], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["FU", 1, [0, 1, 2, 3]], ["SP", 1, 0, 512, [32], 1], ["AN", 1, 0, 5], ["AN", 1, 1, 6], ["PR", 1, 0, "auto_unroll_max_step$64"]]]], "r": [[3.91068e-06], 0, 1.63708, 1606984742], "v": "v0.5"}
+{"i": [["[\"ad6cecbf5d85cb1cda3c2bb7af170211\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 1, 1, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 15], ["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [2], 1], ["SP", 8, 4, 512, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "l [...]
+{"i": [["[\"3a69f9fbc63760d99e36b4c17b3bfc57\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [4], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5 [...]
+{"i": [["[\"d730bcd28f0920f6b97245e2a11bd8d6\", 1, 7, 7, 512, 4, 4, 512, 512, 1, 7, 7, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 16, [1], 1], ["SP", 8, 4, 512, [8], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 2], 1], ["SP", 6, 5, 4, [1, 1,  [...]
+{"i": [["[\"f3b6c10fcc6ce01ff01add933e4d21e9\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [16], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, [...]
+{"i": [["[\"b8b52b9be9df6102466a22a014c44c1f\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [1], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP",  [...]
+{"i": [["[\"d374e472bd9d8164892b9e28a0a8cb59\", 1, 14, 14, 256, 4, 4, 256, 256, 1, 14, 14, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 49, [7], 1], ["SP", 8, 4, 256, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 4, 1, 1], 1], ["SP", 6, 5, 4, [ [...]
+{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 28, 28, 128, 3, 3, 128, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [2, 7, 1, 1], 1], ["SP", 3, 10, 14, [1, 7, 2, 1], 1], ["SP", 3, 15, 256, [2, 2, 1, 4], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 128, [4, 1], 1], ["RE", 3, [0 [...]
+{"i": [["[\"c4500b4e2fd04e695c32d2f31bbdc14a\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [14], 1], ["SP", 8, 4, 128, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0 [...]
+{"i": [["[\"e4cdf917b876dbdd64488c3818d9c141\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 1, 1, 4], 1], ["SP", [...]
+{"i": [["[\"dac19035dd5fe9424ee8617421b9c817\", 1, 28, 28, 128, 4, 4, 128, 128, 1, 28, 28, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 128, [4], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 4, [1, 2, 1, 1], 1], ["SP", 6, 5, 4,  [...]
+{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 56, 56, 64, 3, 3, 64, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 1, 2, 1], 1], ["SP", 3, 10, 28, [1, 7, 2, 2], 1], ["SP", 3, 15, 128, [1, 8, 8, 1], 1], ["SP", 3, 20, 3, [3, 1], 1], ["SP", 3, 23, 3, [1, 1], 1], ["SP", 3, 26, 64, [4, 2], 1], ["RE", 3, [0, 5 [...]
+{"i": [["[\"1e3c4211ffd2f2db91078ae4d04b779d\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 13], ["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [1], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, [...]
+{"i": [["[\"b818b53148cd450f86569dfc3e04cb8a\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 11], ["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [2], 1], ["SP", 8, 4, 64, [64], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 3, 2, 1], 1], ["SP", 6, 5 [...]
+{"i": [["[\"3ea73fb9b0364374730d09e068821f95\", 1, 56, 56, 64, 6, 6, 64, 64, 1, 56, 56, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 9], ["AN", 8, 0, 1], ["AN", 8, 1, 1], ["SP", 8, 2, 196, [49], 1], ["SP", 8, 4, 64, [2], 1], ["AN", 8, 6, 1], ["AN", 8, 7, 1], ["RE", 8, [2, 4, 3, 5, 0, 1, 6, 7]], ["CI", 7], ["CHW", 6, "local"], ["SP", 6, 0, 6, [1, 2, 1, 3], 1], ["SP", 6, 5, 6, [1, 3 [...]
+{"i": [["[\"a5612fdeb9db4d579a75ec225ea4c06a\", 1, 112, 112, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 4], ["CI", 1], ["FU", 5, [0, 1, 2, 3]], ["SP", 5, 0, 200704, [64], 1], ["AN", 5, 0, 5], ["AN", 5, 1, 6], ["FU", 2, [0, 1, 2, 3]], ["SP", 2, 0, 200704, [64], 1], ["AN", 2, 0, 5], ["AN", 2, 1, 6], ["PR", 2, 0, "auto_unroll_max_step$16"]]]], "r": [[2.00968e-05], 0, 1 [...]
+{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 224, 224, 3, 7, 7, 3, 64, 1, 1, 1, 64, 1, 112, 112, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 112, [1, 2, 7, 1], 1], ["SP", 3, 10, 112, [1, 7, 1, 1], 1], ["SP", 3, 15, 64, [1, 8, 4, 1], 1], ["SP", 3, 20, 7, [7, 1], 1], ["SP", 3, 23, 7, [1, 7], 1], ["SP", 3, 26, 3, [3, 1], 1], ["RE", 3, [0, 5, [...]
+{"i": [["[\"7006235cfc29b73be524cf390ed5a977\", 1, 56, 56, 64, 1, 1, 64, 64, 1, 56, 56, 64]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 56, [1, 2, 2, 2], 1], ["SP", 3, 10, 56, [1, 7, 1, 2], 1], ["SP", 3, 15, 64, [1, 16, 1, 4], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 8], 1], ["RE", 3, [0, 5, 10,  [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 56, 56, 64, 1, 1, 64, 128, 1, 28, 28, 128]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 28, [1, 7, 1, 1], 1], ["SP", 3, 10, 28, [1, 2, 1, 7], 1], ["SP", 3, 15, 128, [8, 8, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 64, [2, 2], 1], ["RE", 3, [0, 5, 10 [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 28, 28, 128, 1, 1, 128, 256, 1, 14, 14, 256]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 14, [1, 1, 1, 1], 1], ["SP", 3, 10, 14, [2, 1, 7, 1], 1], ["SP", 3, 15, 256, [2, 64, 1, 2], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 128, [1, 2], 1], ["RE", 3, [0, 5 [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 14, 14, 256, 1, 1, 256, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [1, 1, 7, 1], 1], ["SP", 3, 10, 7, [1, 1, 1, 1], 1], ["SP", 3, 15, 512, [4, 128, 1, 1], 1], ["SP", 3, 20, 1, [1, 1], 1], ["SP", 3, 23, 1, [1, 1], 1], ["SP", 3, 26, 256, [1, 16], 1], ["RE", 3, [0, 5,  [...]
+{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 14, 14, 256, 3, 3, 256, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "cuda -keys=cuda,gpu -max_num_threads=1024 -thread_warp_size=32", [-1, 16, 64, 49152, 2147483647, 1024, 8, 32], "", 0], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1, 1], 1], ["SP", 3, 5, 7, [7, 1, 1, 1], 1], ["SP", 3, 10, 7, [1, 7, 1, 1], 1], ["SP", 3, 15, 512, [1, 4, 1, 1], 1], ["SP", 3, 20, 3, [1, 3], 1], ["SP", 3, 23, 3, [1, 3], 1], ["SP", 3, 26, 256, [8, 2], 1], ["RE", 3, [0, 5, [...]
diff --git a/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json b/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json
index 611f776..3dd4541 100644
--- a/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json
+++ b/tutorials/auto_scheduler/ci_logs/resnet-50-NHWC-B1-llvm.json
@@ -1,31 +1,28 @@
 # Provide valid schedules for resnet-50 for CPU.
 # This is used to run the tutorial on the documentation web server.
-{"i": [["[\"b32ed43fb351136894c322ee49097a1a\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 1, 1000, [50], 1], ["RF", 3, 2, 1], ["RE", 3, [0, 2, 1]], ["SP", 1, 1, 1000, [20], 1], ["RF", 1, 2, 1], ["RE", 1, [0, 2, 1]], ["CR", 6], ["CA", 5, 6, 1], ["CR", 4], ["CA", 2, 3, 1], ["AN", 1, 0, 3], ["FU", 3, [0, 1]], ["AN", 3, 0, 3], ["AN", 4, 0, 3], ["FU", 6, [0, 1]], ["AN", 6, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"], ["PR", 2, 0, "auto [...]
-{"i": [["[\"6129df1a3d5f6326c8393a8d17160199\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 2, 0, 1, [1, 1, 1], 1], ["SP", 2, 4, 1000, [1, 1, 1], 1], ["SP", 2, 8, 16, [2, 2, 4], 1], ["SP", 2, 12, 128, [32], 1], ["RE", 2, [0, 4, 8, 1, 5, 9, 12, 2, 6, 10, 13, 3, 7, 11]], ["CR", 5], ["CA", 3, 5, 1], ["FU", 2, [0, 1]], ["AN", 2, 0, 3], ["FU", 5, [0, 1]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$16"], ["PR", 3, 0, "auto_unroll_max_s [...]
-{"i": [["[\"36ee2798ed60bae3bcd1bb89a0285fe8\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CA", 1, 2, 3], ["FU", 2, [0, 1, 2, 3]], ["AN", 2, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[6.28e-06, 8.176e-06, 8.048e-06, 7.942e-06, 7.977e-06, 8.002e-06, 8.093e-06, 7.924e-06, 7.943e-06, 7.924e-06], 0, 0.130759, 1606960900], "v": "v0.3"}
-{"i": [["[\"dcf6fcf5f56fa614bf9aef0c82382caf\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 9], ["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 2048, [8, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 10, 0, 3, 2], ["F [...]
-{"i": [["[\"7657f886f5e9d8b5f19a5fd2c5b90d8d\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 7], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 512, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["FU", 3, [0, 1, 2, 3, 4, 5, 6, 7]] [...]
-{"i": [["[\"7e09b626cf077cd419190fee02091dd6\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 1024, [1, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["FU", 3, [0, 1, 2, 3, [...]
-{"i": [["[\"1dce2c5e4269b8a12dfc50cd4dd23ff1\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [2, 1, 7], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 256, [16, 4, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["CI", 1], ["CR", 6], ["FU", 3, [0, 1, 2, 3, [...]
-{"i": [["[\"d3b36ce001dc24d693facfbdae1979b4\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 1], 1], ["SP", 3, 8, 28, [7, 1, 1], 1], ["SP", 3, 12, 512, [1, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 8, 0, 2, 2], ["FSP", 8, 3, [...]
-{"i": [["[\"a085717fb3dcb046e5c4c2c04d3dc541\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [14, 1, 2], 1], ["SP", 3, 8, 28, [2, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [16], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 1], ["FSP", 6, 2, 2, 1], [ [...]
-{"i": [["[\"8dd7d81db440763f622f03fdc99e6d46\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [14, 2, 2], 1], ["SP", 3, 8, 56, [2, 1, 2], 1], ["SP", 3, 12, 64, [1, 16, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 1], ["FSP", 6, 2, 2, 1], ["FS [...]
-{"i": [["[\"ba2026d923536b75e9b4faed89287d5f\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 4], ["CA", 2, 5, 3], ["CR", 1], ["FU", 1, [0, 1, 2]], ["AN", 1, 0, 3], ["FU", 5, [0, 1, 2]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$64"]]]], "r": [[2.9217e-05, 3.1065e-05, 3.188e-05, 3.0897e-05, 3.1295e-05, 3.1307e-05, 3.19e-05, 3.1038e-05, 3.1919e-05, 3.2077e-05], 0, 0.217184, 1606961266], "v": "v0.3"}
-{"i": [["[\"0fb1dfcdb5b755e2dab290ed0129dcf2\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 2], 1], ["SP", 3, 12, 128, [2, 2, 16], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 128, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["F [...]
-{"i": [["[\"e043f834cc7f19597227e09dc7f59503\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 256, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], [" [...]
-{"i": [["[\"a0eb8d6048282a4a0986cc2ccf14eaa2\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 112, [1, 1, 4], 1], ["SP", 3, 8, 112, [4, 2, 1], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 7, [7], 1], ["SP", 3, 18, 7, [7], 1], ["SP", 3, 20, 3, [3], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...]
-{"i": [["[\"03614e726dc588d11887eb0953a77e53\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 1], 1], ["SP", 3, 8, 7, [1, 1, 7], 1], ["SP", 3, 12, 2048, [256, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6,  [...]
-{"i": [["[\"b51e06c1131d4cded40d1b215f722a4e\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 1, 1], 1], ["SP", 3, 8, 56, [7, 4, 1], 1], ["SP", 3, 12, 64, [4, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...]
-{"i": [["[\"a9e632e5167afb60fbe29e7aeef1d152\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [1, 1, 1], 1], ["SP", 3, 8, 56, [7, 1, 4], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 64, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP [...]
-{"i": [["[\"e0a9eb3795b531085e0ebb772e7e800c\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 512, [2, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 2048, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP [...]
-{"i": [["[\"8fcee68a4342c38248a827f1c6c69177\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 2, 1], 1], ["SP", 3, 8, 56, [1, 1, 1], 1], ["SP", 3, 12, 256, [2, 4, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6, 2, [...]
-{"i": [["[\"4d7e646d99bfa3cea8245bd7100369cb\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 2, 1], 1], ["SP", 3, 8, 14, [14, 1, 1], 1], ["SP", 3, 12, 1024, [2, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...]
-{"i": [["[\"b2010aa63c95dedf1f58f3fe8bc78634\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [2, 1, 2], 1], ["SP", 3, 8, 28, [1, 2, 1], 1], ["SP", 3, 12, 512, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3, [...]
-{"i": [["[\"537c8642716948c33a6eaaabc86b159d\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 2048, [128, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...]
-{"i": [["[\"7e3f0cf5a6dd80d36dab1a3dad92674a\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 8], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FSP" [...]
-{"i": [["[\"cd7c4a374fb2bbc0d075c8cae638ad14\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 1024, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6 [...]
-{"i": [["[\"45b4de07687dee43ee1cbde9f516b2bf\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [56, 1, 1], 1], ["SP", 3, 8, 56, [14, 1, 2], 1], ["SP", 3, 12, 256, [1, 2, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 4, 0, 1, 2], ["FSP", 4, 3 [...]
-{"i": [["[\"95bf49cc8cf7a351e974b2359702aac0\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 2, 1], 1], ["SP", 3, 8, 14, [1, 7, 1], 1], ["SP", 3, 12, 256, [2, 1, 8], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], ["FS [...]
-{"i": [["[\"5e3ceb6e23ae8c351d5a1770d5fc6c7c\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 5, 0, 0, 2], ["FSP", 5, 3, 1, 2], ["FSP", 5, 6,  [...]
-{"i": [["[\"691feef049c8693bbe91bd5e7c9cdf34\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [7, 1, 4], 1], ["SP", 3, 8, 56, [4, 2, 1], 1], ["SP", 3, 12, 256, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 8, 0, 2, 2], ["FSP", 8, 3,  [...]
-{"i": [["[\"45acfc473c772458684f36a34549d8aa\"]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [18, 64, 64, 0, 0, 0, 0, 0]], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [7, 1, 4], 1], ["SP", 3, 8, 28, [14, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, 11, 15]], ["FSP", 6, 0, 1, 2], ["FSP", 6, 3, 2, 2], [" [...]
+{"i": [["[\"d7b65649a4dd54becea0a52aabbc5af5\", 1, 1000, 1, 1000]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 1, 1000, [50], 1], ["RF", 3, 2, 1], ["RE", 3, [0, 2, 1]], ["SP", 1, 1, 1000, [20], 1], ["RF", 1, 2, 1], ["RE", 1, [0, 2, 1]], ["CR", 6], ["CA", 5, 6, 1], ["CR", 4], ["CA", 2, 3, 1], ["AN", 1, 0, 3], ["FU", 3, [0, 1]], ["AN", 3, 0, 3], ["AN", 4, 0, 3], ["FU", 6, [0, 1]], ["AN", 6, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$ [...]
+{"i": [["[\"69115f188984ae34ede37c3b8ca40b43\", 1, 7, 7, 2048, 1, 1, 1, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CA", 1, 2, 3], ["FU", 2, [0, 1, 2, 3]], ["AN", 2, 0, 3], ["PR", 1, 0, "auto_unroll_max_step$16"]]]], "r": [[6.28e-06, 8.176e-06, 8.048e-06, 7.942e-06, 7.977e-06, 8.002e-06, 8.093e-06, 7.924e-06, 7.943e-06, 7.924e-06], 0, 0.130759, 1606960900], "v": "v0.5"}
+{"i": [["[\"875556d12d0be2269206a7775d5296a6\", 1, 7, 7, 512, 1, 1, 512, 2048, 1, 7, 7, 2048, 1, 1, 1, 2048, 1, 1, 1, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 9], ["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 2048, [8, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [2], 1], ["RE", 3, [0, 4, 8, [...]
+{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 14, 14, 1024, 1, 1, 1024, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 7], 1], ["SP", 3, 8, 7, [1, 1, 1], 1], ["SP", 3, 12, 512, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19 [...]
+{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 14, 14, 256, 1, 1, 256, 1024, 1, 14, 14, 1024, 1, 1, 1, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 1024, [1, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, [...]
+{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 28, 28, 512, 1, 1, 512, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [2, 1, 7], 1], ["SP", 3, 8, 14, [2, 1, 1], 1], ["SP", 3, 12, 256, [16, 4, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17,  [...]
+{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 28, 28, 128, 1, 1, 128, 512, 1, 28, 28, 512, 1, 1, 1, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 1], 1], ["SP", 3, 8, 28, [7, 1, 1], 1], ["SP", 3, 12, 512, [1, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16,  [...]
+{"i": [["[\"de7d1695278cf52778b038e6573d7626\", 1, 56, 56, 256, 1, 1, 256, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [14, 1, 2], 1], ["SP", 3, 8, 28, [2, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [16], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, [...]
+{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 56, 56, 64, 1, 1, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [14, 2, 2], 1], ["SP", 3, 8, 56, [2, 1, 2], 1], ["SP", 3, 12, 64, [1, 16, 4], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, [...]
+{"i": [["[\"a5612fdeb9db4d579a75ec225ea4c06a\", 1, 112, 112, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 4], ["CA", 2, 5, 3], ["CR", 1], ["FU", 1, [0, 1, 2]], ["AN", 1, 0, 3], ["FU", 5, [0, 1, 2]], ["AN", 5, 0, 3], ["PR", 2, 0, "auto_unroll_max_step$64"]]]], "r": [[2.9217e-05, 3.1065e-05, 3.188e-05, 3.0897e-05, 3.1295e-05, 3.1307e-05, 3.19e-05, 3.1038e-05, 3.1919e-05, 3.2077e-05], 0, 0.217184, 1606961 [...]
+{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 14, 14, 1024, 1, 1, 1024, 256, 1, 1, 1, 256, 1, 14, 14, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [1, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 256, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17 [...]
+{"i": [["[\"12b88bedece6984af589a28b43e0f3c4\", 1, 224, 224, 3, 7, 7, 3, 64, 1, 1, 1, 64, 1, 112, 112, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 112, [1, 1, 4], 1], ["SP", 3, 8, 112, [4, 2, 1], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 7, [7], 1], ["SP", 3, 18, 7, [7], 1], ["SP", 3, 20, 3, [3], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...]
+{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 7, 7, 512, 1, 1, 512, 2048, 1, 7, 7, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 1, 1], 1], ["SP", 3, 8, 7, [1, 1, 7], 1], ["SP", 3, 12, 2048, [256, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7, [...]
+{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 56, 56, 256, 1, 1, 256, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 1, 1], 1], ["SP", 3, 8, 56, [7, 4, 1], 1], ["SP", 3, 12, 64, [4, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...]
+{"i": [["[\"2350d19dc42a0665244368384c66b3a5\", 1, 56, 56, 64, 3, 3, 64, 64, 1, 1, 1, 64, 1, 56, 56, 64]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [1, 1, 1], 1], ["SP", 3, 8, 56, [7, 1, 4], 1], ["SP", 3, 12, 64, [1, 1, 16], 1], ["SP", 3, 16, 3, [1], 1], ["SP", 3, 18, 3, [3], 1], ["SP", 3, 20, 64, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21,  [...]
+{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 7, 7, 2048, 1, 1, 2048, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 512, [2, 2, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 2048, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...]
+{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [4, 2, 1], 1], ["SP", 3, 8, 56, [1, 1, 1], 1], ["SP", 3, 12, 256, [2, 4, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, 7,  [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 28, 28, 512, 1, 1, 512, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 2, 1], 1], ["SP", 3, 8, 14, [14, 1, 1], 1], ["SP", 3, 12, 1024, [2, 2, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 1 [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 56, 56, 256, 1, 1, 256, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [2, 1, 2], 1], ["SP", 3, 8, 28, [1, 2, 1], 1], ["SP", 3, 12, 512, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [8], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19,  [...]
+{"i": [["[\"f4380bb1dc62422a69ad4a1a9771f927\", 1, 14, 14, 1024, 1, 1, 1024, 2048, 1, 7, 7, 2048]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [7, 1, 1], 1], ["SP", 3, 8, 7, [1, 7, 1], 1], ["SP", 3, 12, 2048, [128, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 1024, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 1 [...]
+{"i": [["[\"2350d19dc42a0665244368384c66b3a5\", 1, 7, 7, 512, 3, 3, 512, 512, 1, 1, 1, 512, 1, 7, 7, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 7, [1, 7, 1], 1], ["SP", 3, 8, 7, [7, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 8], 1], ["SP", 3, 16, 3, [3], 1], ["SP", 3, 18, 3, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21,  [...]
+{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 14, 14, 256, 1, 1, 256, 1024, 1, 14, 14, 1024, 1, 14, 14, 1024]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 14, [7, 1, 2], 1], ["SP", 3, 8, 14, [7, 2, 1], 1], ["SP", 3, 12, 1024, [16, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 256, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...]
+{"i": [["[\"7006235cfc29b73be524cf390ed5a977\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CHW", 3, "local"], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [56, 1, 1], 1], ["SP", 3, 8, 56, [14, 1, 2], 1], ["SP", 3, 12, 256, [1, 2, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [64], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 2 [...]
+{"i": [["[\"1cc666833c122282e3fcf3595901b12b\", 1, 28, 28, 128, 1, 1, 128, 512, 1, 28, 28, 512, 1, 28, 28, 512]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [1, 1, 2], 1], ["SP", 3, 8, 28, [1, 1, 1], 1], ["SP", 3, 12, 512, [4, 1, 32], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 128, [4], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17, 19, 21, 3, [...]
+{"i": [["[\"1b524af89dd867d26059e1f621cf987c\", 1, 56, 56, 64, 1, 1, 64, 256, 1, 56, 56, 256, 1, 1, 1, 256, 1, 56, 56, 256]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 7], ["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 56, [7, 1, 4], 1], ["SP", 3, 8, 56, [4, 2, 1], 1], ["SP", 3, 12, 256, [32, 1, 8], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 64, [2], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, [...]
+{"i": [["[\"6b7583cf23c7c37d3212cad9d06e58c1\", 1, 28, 28, 512, 1, 1, 512, 128, 1, 1, 1, 128, 1, 28, 28, 128]", "llvm -keys=cpu -link-params=0 -mcpu=core-avx2", [8, 64, 64, 0, 0, 0, 0, 0], "", 2], [[], [["CI", 5], ["SP", 3, 0, 1, [1, 1, 1], 1], ["SP", 3, 4, 28, [7, 1, 4], 1], ["SP", 3, 8, 28, [14, 1, 1], 1], ["SP", 3, 12, 128, [1, 1, 16], 1], ["SP", 3, 16, 1, [1], 1], ["SP", 3, 18, 1, [1], 1], ["SP", 3, 20, 512, [1], 1], ["RE", 3, [0, 4, 8, 12, 1, 5, 9, 13, 16, 18, 20, 2, 6, 10, 14, 17,  [...]