You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/17 00:36:57 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6073: [Ansor][AutoTVM v2.0] Part 1: Add annotation/compute_at/compute_root/compute_inline steps

comaniac commented on a change in pull request #6073:
URL: https://github.com/apache/incubator-tvm/pull/6073#discussion_r456141719



##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -90,12 +90,69 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
   data_ = std::move(node);
 }
 
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  // Delete the current entry of this stage
+  DeleteStageEntry(pnode, stage_id);
+
+  // Store the new relations to map
+  IterKey iter_key(target_stage_id, target_iter_id);
+  pnode->stage_to_attach_iter[stage_id] = iter_key;
+  pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+  // Delete the original stage entry
+  DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& old_iters,
+                            const std::vector<IterKey>& new_iters) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  CHECK_EQ(old_iters.size(), new_iters.size());
+  for (size_t i = 0; i < old_iters.size(); ++i) {
+    auto entry = pnode->iter_to_attached_stages.find(old_iters[i]);
+    if (entry == pnode->iter_to_attached_stages.end()) {
+      continue;
+    }
+
+    // Replace iter in the value of `stage_to_attach_iter`
+    for (const auto& s : entry->second) {
+      pnode->stage_to_attach_iter[s] = new_iters[i];
+    }
+
+    // Replace iter in the key of `iter_to_attached_stages`

Review comment:
       Maybe "Remove the old iter from iter_to_attached_stages and add the new iter" would be clearer.

##########
File path: src/auto_scheduler/measure_record.cc
##########
@@ -169,6 +206,18 @@ struct Handler<::tvm::Array<::tvm::auto_scheduler::Step>> {
           fused_ids.push_back(i);
         }
         data->push_back(::tvm::auto_scheduler::FuseStep(stage_id, fused_ids));
+      } else if (name == "AN") {

Review comment:
       I feel the current approach to (de)serializing records is neither scalable nor easy to maintain. Specifically, we put all deserialization rules in `Read` and all serialization rules in `Write`, and the only link between a serialization rule and its corresponding deserialization rule is a two-character short name (e.g., AN). It seems to me that it would be better for each step to define its own short name and (de)serialization logic. That way, we can use `ps->serialize(writer)` for serialization and build a step factory to deserialize them.
   

##########
File path: python/tvm/auto_scheduler/loop_state.py
##########
@@ -161,16 +202,116 @@ def fuse(self, stage, iters):
             The Stage to be fused, can be a Stage order index, Stage operation or stage
             output tensor.
         iters : List[Iterator]
-            The iterators to be fused
+            The iterators to be fused.
+
+        Returns
+        -------
+        res_it : Iterator
+            The fused Iterator.
+        """
+        self.state_object, res = _ffi_api.StateFuse(self.state_object,
+                                                    self._resolve_stage_id(stage), iters)
+        return res
+
+    def vectorize(self, stage, iterator):
+        """ Schedule primitive corresponds to te.vectorize.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be vectorized, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to be vectorized.
 
         Returns
         -------
         res_it : Iterator
-            The fused Iterator
+            The vectorized Iterator.
         """
-        stage_id = self._resolve_stage_id(stage)
+        self.state_object, res = _ffi_api.StateVectorize(self.state_object,
+                                                         self._resolve_stage_id(stage), iterator)
+        return res
+
+    def parallel(self, stage, iterator):
+        """ Schedule primitive corresponds to te.parallel.
 
-        self.state_object, res = _ffi_api.StateFuse(self.state_object, stage_id, iters)
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be paralleled, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to be paralleled.
+
+        Returns
+        -------
+        res_it : Iterator
+            The paralleled Iterator.
+        """
+        self.state_object, res = _ffi_api.StateParallel(self.state_object,
+                                                        self._resolve_stage_id(stage), iterator)
+        return res
+
+    def unroll(self, stage, iterator, max_unroll=None):
+        """ Schedule primitive corresponds to te.unroll.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be unrolled, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to be unrolled.
+        max_unroll : Optional[int]
+            The max unroll limit. Iterator with extent larger than this limit will be skipped.
+
+        Returns
+        -------
+        res_it : Iterator
+            The unrolled Iterator.
+        """
+        self.state_object, res = _ffi_api.StateUnroll(self.state_object,
+                                                      self._resolve_stage_id(stage), iterator,
+                                                      max_unroll if max_unroll else -1)
+        return res
+
+    def bind(self, stage, iterator, thread_name):
+        """ Schedule primitive corresponds to te.bind.
+
+        Parameters
+        ----------
+        stage : Union[int, Operation, Tensor]
+            The Stage to be binded, can be a Stage order index, Stage operation or stage
+            output tensor.
+        iterator : Iterator
+            The iterator to be binded.
+        thread_name : str
+            The thread type to be binded. Currently support:
+            - vthread
+            - blockIdx.x
+            - threadIdx.x
+            - blockIdx.y
+            - threadIdx.y
+
+        Returns
+        -------
+        res_it : Iterator
+            The binded Iterator.
+        """
+        trans_table = {
+            "vthread": 4,
+            "blockIdx.x": 5,
+            "threadIdx.x": 6,
+            "blockIdx.y": 7,
+            "threadIdx.y": 8,
+        }

Review comment:
       We should make this dict static (e.g., a module- or class-level constant) so it is not rebuilt on every call.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -143,12 +265,67 @@ void State::DoReorderStep(const ReorderStep& step) {
                      Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
 }
 
+void State::DoComputeAtStep(const ComputeAtStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound

Review comment:
       This comment should be in the `compute_at()` function definition in `loop_state.h`, because it describes a side effect of that function.
   
   In addition, I think it's fine to say that we intentionally remove the loop sizes of that stage when doing this step, since they are no longer accurate after compute_at. Run ComputeDAG::InferBound again to recover the loop sizes.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -264,6 +462,38 @@ Iterator State::DoFuseStep(const FuseStep& step) {
   pstate->stages.Set(stage_id,
                      Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
 
+  // We have to update the iterator relations in attach map, these two vectors keep the replacement
+  // mapping

Review comment:
       ditto.

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    s0.compute_inline(bn_mul)
+    s0.compute_inline(bias_add)
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    print(s0)

Review comment:
       Remove this line.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -90,12 +90,69 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
   data_ = std::move(node);
 }
 
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  // Delete the current entry of this stage
+  DeleteStageEntry(pnode, stage_id);
+
+  // Store the new relations to map
+  IterKey iter_key(target_stage_id, target_iter_id);
+  pnode->stage_to_attach_iter[stage_id] = iter_key;
+  pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+  // Delete the original stage entry
+  DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& old_iters,
+                            const std::vector<IterKey>& new_iters) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  CHECK_EQ(old_iters.size(), new_iters.size());
+  for (size_t i = 0; i < old_iters.size(); ++i) {
+    auto entry = pnode->iter_to_attached_stages.find(old_iters[i]);
+    if (entry == pnode->iter_to_attached_stages.end()) {
+      continue;
+    }
+
+    // Replace iter in the value of `stage_to_attach_iter`

Review comment:
       Maybe "Update the attaching target of an old iter to the new iter" would be clearer.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -90,12 +90,69 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
   data_ = std::move(node);
 }
 
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  // Delete the current entry of this stage
+  DeleteStageEntry(pnode, stage_id);
+
+  // Store the new relations to map
+  IterKey iter_key(target_stage_id, target_iter_id);
+  pnode->stage_to_attach_iter[stage_id] = iter_key;
+  pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+  // Delete the original stage entry
+  DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& old_iters,
+                            const std::vector<IterKey>& new_iters) {

Review comment:
       I guess you need to guarantee that both vectors have the same length?

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -143,12 +265,67 @@ void State::DoReorderStep(const ReorderStep& step) {
                      Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
 }
 
+void State::DoComputeAtStep(const ComputeAtStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                           ComputeAtKind::kIter, stage->attrs));
+  // Update attach map
+  pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id);
+}
+
+void State::DoComputeRootStep(const ComputeRootStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound

Review comment:
       ditto.

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -90,12 +90,69 @@ Stage::Stage(te::Operation op, StageKind op_type, const Array<Iterator>& iters,
   data_ = std::move(node);
 }
 
+/********** AttachMap **********/
+void AttachMap::SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  // Delete the current entry of this stage
+  DeleteStageEntry(pnode, stage_id);
+
+  // Store the new relations to map
+  IterKey iter_key(target_stage_id, target_iter_id);
+  pnode->stage_to_attach_iter[stage_id] = iter_key;
+  pnode->iter_to_attached_stages[iter_key].push_back(stage_id);
+}
+
+void AttachMap::DeleteStage(int stage_id) {
+  AttachMapNode* pnode = CopyOnWrite();
+  // Delete the original stage entry
+  DeleteStageEntry(pnode, stage_id);
+}
+
+void AttachMap::UpdateIters(const std::vector<IterKey>& old_iters,
+                            const std::vector<IterKey>& new_iters) {
+  AttachMapNode* pnode = CopyOnWrite();
+
+  CHECK_EQ(old_iters.size(), new_iters.size());
+  for (size_t i = 0; i < old_iters.size(); ++i) {
+    auto entry = pnode->iter_to_attached_stages.find(old_iters[i]);
+    if (entry == pnode->iter_to_attached_stages.end()) {
+      continue;
+    }

Review comment:
       Add a comment saying that we are skipping the old iterators that have no stage attached.

##########
File path: src/auto_scheduler/loop_state.h
##########
@@ -217,6 +196,68 @@ class Stage : public ObjectRef {
   TVM_DEFINE_OBJECT_REF_COW_METHOD(StageNode);
 };
 
+/*! \brief Use stage_id to represent a stage. */
+using StageKey = int;
+/*! \brief Use stage_id and iter_id to represent a iterator. */
+using IterKey = std::pair<int, int>;
+
+/*!
+ * \brief stores the compute_at relation between stages
+ * This stores a bi-directional mapping from stages and iter:
+ * 1. Stage to its attached iterator
+ * 2. Iterator to the stage attached to it
+ * You can use AttachMapNode::stage_to_attach_iter and AttachMapNode::iter_to_attached_stages
+ * to query the relations
+ */
+class AttachMapNode : public Object {
+ public:
+  /*! \brief A Map to store the mapping of stage to its attached iterator. */
+  std::unordered_map<StageKey, IterKey> stage_to_attach_iter;
+  /*! \brief A Map to store the mapping of iterator to the stage attached to it. */
+  std::unordered_map<IterKey, std::vector<StageKey>> iter_to_attached_stages;
+
+  static constexpr const char* _type_key = "auto_scheduler.AttachMap";
+  TVM_DECLARE_FINAL_OBJECT_INFO(AttachMapNode, Object);
+};
+
+/*!
+ * \brief Managed reference to AttachMapNode.
+ * \sa AttachMapNode
+ */
+class AttachMap : public ObjectRef {
+ public:
+  /*!
+   * \brief Process the stage/iterator mapping after compute at.
+   * \param stage_id The index of the stage to be compute at.
+   * \param target_stage_id The index of stage that this step will compute at to.
+   * \param target_iter_id The index of iterator in target stage that this step will compute at to.
+   */
+  void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);
+  /*!
+   * \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
+   * \param stage_id The index of the stage to be compute at.
+   */
+  void DeleteStage(int stage_id);
+  /*!
+   * \brief Update the iterator relations in AttachMap.

Review comment:
       As I understand it, this function updates each mapping from (stage -> old_iter) to (stage -> new_iter). It needs a clearer description.

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")

Review comment:
       Should we also check their annotations?

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -143,12 +265,67 @@ void State::DoReorderStep(const ReorderStep& step) {
                      Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
 }
 
+void State::DoComputeAtStep(const ComputeAtStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                           ComputeAtKind::kIter, stage->attrs));
+  // Update attach map
+  pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id);
+}
+
+void State::DoComputeRootStep(const ComputeRootStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                           ComputeAtKind::kRoot, stage->attrs));
+  // Update attach map
+  pstate->attach_map.DeleteStage(step->stage_id);
+}
+
+void State::DoComputeInlineStep(const ComputeInlineStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // CHECK the validity of compute_inline
+  for (size_t i = 0; i < stage->iters.size(); ++i) {
+    CHECK_EQ(operator->()->attach_map->iter_to_attached_stages.count(
+                 std::make_pair(step->stage_id, i)),
+             0)
+        << "Invalid compute_inline: Because there are some other stages "

Review comment:
       Remove "Because"

##########
File path: tests/python/unittest/test_auto_scheduler_loop_state.py
##########
@@ -61,5 +61,79 @@ def test_split_fuse_reorder():
     assert s1[C].iters[4].range.extent == 8
     assert s1[C].iters[5].range.extent == 2
 
+    s1.parallel(C, j1)
+    s1.unroll(C, j2)
+    s1.vectorize(C, j3)
+    s1.bind(C, i1, "blockIdx.x")
+    s1.bind(C, i2, "vthread")
+    s1.bind(C, i3, "threadIdx.y")
+
+
+def test_compute_at_root_inline():
+    dag = auto_scheduler.ComputeDAG(conv2d_nchw_bn_relu(1, 224, 224, 3, 64, 7, 2, 3))
+    s0 = dag.get_init_state()
+
+    # data, padding, kernel = 0, 1, 2
+    conv = s0.stage_ops[3]
+    # bias = 4
+    bias_add = s0.stage_ops[5]
+    # bn_scale = 6
+    bn_mul = s0.stage_ops[7]
+    # bn_offset = 8
+    bn_add = s0.stage_ops[9]
+    relu = s0.stage_ops[10]
+
+    s0.compute_inline(bn_add)
+    s0.compute_inline(bn_mul)
+    s0.compute_inline(bias_add)
+    s0.compute_at(conv, relu, s0[relu].iters[2])
+    print(s0)
+    assert str(s0) == \

Review comment:
       I don't think it's a good idea to check correctness by comparing printed strings...

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -143,12 +265,67 @@ void State::DoReorderStep(const ReorderStep& step) {
                      Stage(stage->op, stage->op_type, iters, stage->compute_at, stage->attrs));
 }
 
+void State::DoComputeAtStep(const ComputeAtStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                           ComputeAtKind::kIter, stage->attrs));
+  // Update attach map
+  pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id, step->target_iter_id);
+}
+
+void State::DoComputeRootStep(const ComputeRootStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // After compute_at, we don't know the accurate length information any more
+  // If we do want to know the accurate lengths, we can call ComputeDAG::InferBound
+  Array<Iterator> new_iters;
+  for (const Iterator& it : stage->iters) {
+    new_iters.push_back(Iterator(it->name, Range(), it->iter_kind, it->annotation));
+  }
+
+  StateNode* pstate = CopyOnWrite();
+  pstate->stages.Set(step->stage_id, Stage(stage->op, stage->op_type, std::move(new_iters),
+                                           ComputeAtKind::kRoot, stage->attrs));
+  // Update attach map
+  pstate->attach_map.DeleteStage(step->stage_id);
+}
+
+void State::DoComputeInlineStep(const ComputeInlineStep& step) {
+  const Stage& stage = operator->()->stages[step->stage_id];
+
+  // CHECK the validity of compute_inline

Review comment:
       s/CHECK/Check

##########
File path: src/auto_scheduler/loop_state.cc
##########
@@ -210,6 +387,16 @@ Array<Iterator> State::DoSplitStepCommon(int stage_id, int iter_id,
                      Stage(stage->op, stage->op_type, new_iters, stage->compute_at, stage->attrs));
   pstate->concrete &= concrete;
 
+  // We have to update the iterator relations in attach map, these two vectors keep the replacement
+  // mapping

Review comment:
       ```suggestion
     // Use two vectors to represent the iterator relation before and after the split
     // in order to update the attach_map
   ```
   
   We may also need to mention in the split function definition in `loop_state.h` that if the attach target iterator is split, the innermost split iterator becomes the new attach target.

##########
File path: tests/python/unittest/test_auto_scheduler_measure.py
##########
@@ -18,14 +18,51 @@
 """ Test measurement and log serialization. """
 
 import tvm
-from tvm import auto_scheduler
+import topi
+from tvm import te, auto_scheduler
 import tempfile
 
 from test_auto_scheduler_common import get_tiled_matmul
 
 
 def test_record():
-    dag, s = get_tiled_matmul()
+    A = te.placeholder((512, 512), name='A')

Review comment:
       Move the checker at L67 here. If LLVM is disabled, we don't even need to build this DAG.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org