You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/27 03:17:25 UTC

[GitHub] [incubator-tvm] comaniac commented on a change in pull request #6568: [AutoScheduler] Improve the rule of mutating parallel granularity

comaniac commented on a change in pull request #6568:
URL: https://github.com/apache/incubator-tvm/pull/6568#discussion_r495521431



##########
File path: tests/python/unittest/test_auto_scheduler_evolutionary_search.py
##########
@@ -22,56 +22,102 @@
 from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel
 
 
-class MockCostModel(PythonBasedModel):
-    """A mock cost model that rates 1 only for the states with tile_k=2."""
+def test_mutate_tile_size():
+    """
+    The test case initializes evo search with a batch of "bad" states and check whether
+    the search algorithm can find "good" states by mutating the "bad" states.
+
+    This unit test has been tested with 1,000 runs with no failures, meaning that
+    the failure rate is less than 0.1%.
+    """
 
-    def predict(self, task, states):
-        scores = []
-        found = False
-        for state in states:
+    class MockCostModel(PythonBasedModel):
+        """A mock cost model that rates 1 only for the states with tile_k=2."""
+
+        @staticmethod
+        def is_good_state(state):
             for line in str(state).split("\n"):
                 if line.find("k.1") != -1 and line.find("(0,2)") != -1:
-                    found = True
-                    break
-            scores.append(1 if found else 0)
-        return scores
+                    return True
+            return False
 
+        def predict(self, task, states):
+            scores = []
+            found = False
+            for state in states:
+                scores.append(1 if self.is_good_state(state) else 0)
+            return scores
 
-def test_evo_search():
-    """Test evolutionary search. Since we cannot mock random number generator,
-    we mocked the cost model to manually guide the evo search. If evo search works
-    as expected, it should find the target state after a sufficient number of iterations.
-    This unit test has been tested with 1,000 runs with no failures, meaning that
-    the failure rate is less than 0.1%.
-    """
     workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4))
     dag = auto_scheduler.ComputeDAG(workload_key)
     task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
     policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
     states = policy.sample_initial_population(50)
-    pruned_states = []
+
+    bad_states = []
     for state in states:
-        found = False
-        for line in str(state).split("\n"):
-            # Remove all tile_k=2 states and expect evo search will fine them.
-            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
-                found = True
-                break
-        if not found:
-            pruned_states.append(state)
+        if not MockCostModel.is_good_state(state):
+            bad_states.append(state)
 
-    new_states = policy.evolutionary_search(pruned_states, 50)
+    new_states = policy.evolutionary_search(bad_states, 50)
     found = False
     for state in new_states:
-        for line in str(state).split("\n"):
-            # Check if evo search found at least one state with tile_k=2.
-            if line.find("k.1") != -1 and line.find("(0,2)") != -1:
+        if MockCostModel.is_good_state(state):
+            found = True
+            break
+    assert found
+
+
+def test_mutate_parallel():
+    """
+    The test case initializes evo search with a batch of "bad" states and check whether
+    the search algorithm can find "good" states by mutating the "bad" states.
+
+    This unit test has been tested with 1,000 runs with no failures, meaning that
+    the failure rate is less than 0.1%.
+    """
+
+    class MockCostModel(PythonBasedModel):
+        @staticmethod
+        def is_good_state(state):
+            for line in str(state).split("\n"):
+                if (
+                    line.find("parallel i.0@ (0") != -1
+                    or line.find("parallel i.0@j.0@ (0") != -1
+                    or line.find("parallel i.0@j.0@i.1@ (0") != -1
+                ):
+                    return True
+            return False
+
+        def predict(self, task, states):
+            scores = []
+            found = False
+            for state in states:
+                scores.append(1 if self.is_good_state(state) else 0)
+            return scores
+
+    workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (1024, 1024, 1024))
+    dag = auto_scheduler.ComputeDAG(workload_key)
+    task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.Target("llvm"))
+    policy = auto_scheduler.SketchPolicy(task, program_cost_model=MockCostModel(), verbose=0)
+    states = policy.sample_initial_population(100)
+
+    bad_states = []
+    for state in states:
+        if not MockCostModel.is_good_state(state):
+            bad_states.append(state)
+
+    found = False
+    retry_ct = 0

Review comment:
       Where does `retry_ct` get updated? It is initialized to 0 here, but I don't see where it is incremented.

##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.cc
##########
@@ -1081,52 +1085,43 @@ PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* pol
 
   // Randomly pick one parallel step.
   size_t step_id = parallel_steps[(*rand_gen)() % parallel_steps.size()];
-  auto ps = (*state)->transform_steps[step_id].as<AnnotationStepNode>();
-  CHECK(ps);
-  size_t stage_id = ps->stage_id;
-  size_t iter_id = ps->iter_id;
-  const Stage& stage = (*state)->stages[stage_id];
-  const Iterator& it = stage->iters[iter_id];
 
   // Replay a new state until the picked fuse step.
   State tmp_s = policy->search_task->compute_dag->init_state;
   for (size_t s = 0; s < step_id - 1; ++s) {
-    auto step = (*state)->transform_steps[s];
+    const auto& step = (*state)->transform_steps[s];
     tmp_s.CopyOnWrite()->transform_steps.push_back(step);
     StepApplyToState(step, &tmp_s, policy->search_task->compute_dag);
   }
 
-  // Determine the fusion mutation direction.
-  // 0: fuse less; 1: fuse more.
+  // Compute all possible fusion granularities
   auto fuse_step = (*state)->transform_steps[step_id - 1].as<FuseStepNode>();
-  auto fused_ids = fuse_step->fused_ids;
-  std::vector<double> fuse_dir = {0.5, 1.0};
-
-  // The case that we can only fuse more. This may happen after multiple mutations.
-  if (fused_ids.size() == 1) {
-    fuse_dir[0] = 0.0;
-  }
+  int stage_id = fuse_step->stage_id;
+  const Stage& stage = tmp_s->stages[stage_id];
+  size_t iter_id;

Review comment:
       Maybe rename `iter_id` to `max_fusable_iter_id` (or something similar) to make its meaning clearer?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org