You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/10/26 21:29:24 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6750: [AutoScheduler] New layout rewrite option: Weight pre-transpose

comaniac commented on a change in pull request #6750:
URL: https://github.com/apache/incubator-tvm/pull/6750#discussion_r512274881



##########
File path: python/tvm/auto_scheduler/compute_dag.py
##########
@@ -50,6 +50,11 @@ class ComputeDAG(Object):
     compute : Union[List[Tensor], str, Schedule]
         Input/output tensors or workload key for a compute declaration.
     """
+    LAYOUT_REWRITE_TABLE = {
+        "NoRewrite": 0,
+        "RewriteWithPlaceholder": 1,
+        "RewriteWithPreTranspose": 2,
+    }

Review comment:
       Is it possible to avoid duplicating the layout rewrite option table here? To be honest, though, I have no idea how to do that...

##########
File path: src/auto_scheduler/compute_dag.cc
##########
@@ -902,28 +903,91 @@ void ComputeDAG::RewriteLayout(const Array<Step>& transform_steps) {
       if (!direct_consumer) {
         continue;
       }
+      handled_ops.insert(placeholder_op);
 
+      // Process original layout
       std::set<std::string> placeholder_axis_names;
-      GetOrigLayout(&placeholder_axis_names, op, placeholder);
+      std::string origin_layout = GetOrigLayout(&placeholder_axis_names, op, placeholder);
+      Array<PrimExpr> origin_shape;
+      std::vector<std::string> origin_axes;
+      ParseKernelLayout(origin_layout, &origin_shape, &origin_axes);
 
-      Array<PrimExpr> new_shape;
+      // Process new layout
       std::string new_layout =
-          GetNewLayout(&new_shape, state, stage_id, stage, op, placeholder, placeholder_axis_names);
-
-      handled_ops.insert(placeholder_op);
-
-      Array<te::Operation> old_ops = p_dag->ops;
-      ArrayNode* pops = p_dag->ops.CopyOnWrite();
-
-      // Create new placeholder
-      te::Operation new_placeholder_op;
-      new_placeholder_op = te::PlaceholderOp(placeholder_op->name, new_shape,
+          GetNewLayout(state, stage_id, stage, op, placeholder, placeholder_axis_names);
+      Array<PrimExpr> new_shape;
+      std::vector<std::string> new_axes;
+      ParseKernelLayout(new_layout, &new_shape, &new_axes);
+
+      // Process op updates
+      te::Operation new_op_to_update;
+      if (layout_rewrite == LayoutRewriteOption::RewriteWithPlaceholder) {
+        // Create new placeholder
+        new_op_to_update = te::PlaceholderOp(placeholder_op->name, new_shape,
                                              placeholder_op.as<te::PlaceholderOpNode>()->dtype);
+      } else if (layout_rewrite == LayoutRewriteOption::RewriteWithPreTranspose) {
+        // Process index strides
+        std::unordered_map<std::string, PrimExpr> axes_stride;
+        for (const auto& i : origin_axes) {
+          axes_stride[i] = Integer(1);
+        }
+        Array<PrimExpr> new_stride(new_shape.size(), PrimExpr());
+        PrimExpr temp = Integer(1);
+        for (int i = new_shape.size() - 1; i >= 0; i--) {
+          new_stride.Set(i, axes_stride[new_axes[i]]);
+          axes_stride[new_axes[i]] *= new_shape[i];
+        }
+
+        // Add extra layout transpose stage
+        const auto& layout_transform_tensor = te::compute(
+            new_shape,
+            [&new_stride, &placeholder_op, &origin_shape, &new_shape, &origin_axes,
+             &new_axes](const tvm::runtime::Array<tvm::tir::Var>& indices) -> tvm::PrimExpr {
+              Array<PrimExpr> access_indices;
+              for (size_t indice_index = 0; indice_index < origin_shape.size(); indice_index++) {
+                PrimExpr temp = Integer(0);
+                for (size_t i = 0; i < new_shape.size(); i++) {
+                  if (origin_axes[indice_index].compare(new_axes[i]) == 0) {
+                    temp += indices[i] * new_stride[i];
+                  }
+                }
+                access_indices.push_back(temp);
+              }
+              return placeholder_op.output(0)(access_indices);
+            },
+            "auto_schedule_layout_transpose");
+        new_op_to_update = layout_transform_tensor->op;
+
+        // Update the transform steps
+        for (size_t i = 0; i < transform_steps->size(); i++) {
+          Step step = (*transform_steps)[i];
+          if (step->stage_id >= static_cast<int>(stage_id)) {
+            step.CopyOnWrite()->stage_id++;
+          }
+          if (step->IsInstance<ComputeAtStepNode>()) {
+            auto compute_at_step = tvm::Downcast<ComputeAtStep>(step);
+            if (compute_at_step->target_stage_id >= static_cast<int>(stage_id)) {
+              dynamic_cast<ComputeAtStepNode*>(compute_at_step.CopyOnWrite())->target_stage_id++;
+            }
+            transform_steps->Set(i, std::move(compute_at_step));
+          } else {
+            transform_steps->Set(i, std::move(step));
+          }
+        }
+        Array<Integer> to_fuse;
+        for (size_t i = 0; i < new_shape.size() - 1; i++) {
+          to_fuse.push_back(i);
+        }
+        transform_steps->push_back(FuseStep(stage_id, to_fuse));
+        transform_steps->push_back(AnnotationStep(stage_id, 0, IteratorAnnotation::kParallel));
+      } else {
+        LOG(FATAL) << "Call ComputeDAG::RewriteLayout with NoRewrite.";

Review comment:
       This would be better as an assertion, since this condition is already enforced in `ApplySteps`.

##########
File path: include/tvm/auto_scheduler/transform_step.h
##########
@@ -182,7 +182,23 @@ class StepNode : public Object {
  */
 class Step : public ObjectRef {
  public:
-  TVM_DEFINE_MUTABLE_OBJECT_REF_METHODS(Step, ObjectRef, StepNode);
+  /*!
+   * \brief CopyOnWrite function for Step.
+   * This works almost the same as a normal ObjectRef.CopyOnWrite(), but can dispatch to different
+   * steps.
+   * \return A base StepNode pointer, need to cast to its real StepNode type before doing any
+   * modifies.

Review comment:
       ```suggestion
      * modifications.
   ```

##########
File path: tests/python/unittest/test_auto_scheduler_layout_rewrite.py
##########
@@ -50,16 +55,18 @@ def test_layout_rewrite_correctness():
 
         search_policy = auto_scheduler.SketchPolicy(task)
 
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext()

Review comment:
       It is better to explicitly `del measure_ctx` once measurement is finished, so the local RPC tracker and server are shut down.

##########
File path: tests/python/unittest/test_auto_scheduler_layout_rewrite.py
##########
@@ -100,10 +107,56 @@ def test_layout_rewrite_correctness():
         func_ref(*args_ref)
         ctx.sync()
 
-        np.testing.assert_allclose(np_args[0], np_args_ref[0])
-        np.testing.assert_allclose(np_args[2], np_args_ref[2])
+        np.testing.assert_allclose(args[0].asnumpy(), args_ref[0].asnumpy())
+        np.testing.assert_allclose(args[2].asnumpy(), args_ref[2].asnumpy())
+
+
+def test_correctness_layout_rewrite_with_pre_transpose():
+    N = 128
+    target = tvm.target.Target("llvm")
+    task = auto_scheduler.create_task(matmul_auto_scheduler_test, (N, N, N), target)
+    dag = task.compute_dag
+
+    with tempfile.NamedTemporaryFile() as fp:
+        log_file = fp.name
+
+        search_policy = auto_scheduler.SketchPolicy(task)
+
+        measure_ctx = auto_scheduler.LocalRPCMeasureContext()

Review comment:
       ditto




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org