You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/13 18:17:15 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7260: [AutoScheduler] Bug fix & Custom sketch support

comaniac commented on a change in pull request #7260:
URL: https://github.com/apache/tvm/pull/7260#discussion_r556711333



##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add

Review comment:
       ```suggestion
       A SearchCallback for SketchSearchPolicy that allows users to add
   ```

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function

Review comment:
       ```suggestion
       apply_func: Callable
   ```

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):

Review comment:
       Based on your test case, I think we should not provide a default rule name. Otherwise, it's easy to get a rule name conflict if users call `PreloadCustomSketchRule` twice.

##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.h
##########
@@ -131,6 +131,29 @@ DEFINE_SKETCH_GENERATION_RULE(RuleCrossThreadReduction);
  * location of the producers of compute ops that perform "fake reduction" with const tensors. */
 DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU);
 
+/*! \brief The rule that allows users to generate custom sketches. */
+class RuleCustomSketch : public SketchGenerationRule {

Review comment:
       Should this be named "CustomSketchRule"?

##########
File path: python/tvm/auto_scheduler/task_scheduler.py
##########
@@ -290,6 +296,9 @@ def tune(self, tune_option, search_policy="default", search_policy_params=None):
             "sketch.random" for SketchPolicy + RandomModel.
         search_policy_params : Optional[Dict[str, Any]]
             The parameters of the search policy
+        adapative_training : bool = False
+            Option used for XGBModel, which will reduce the model training frequency when there're

Review comment:
       ```suggestion
               Option used by XGBModel to reduce the model training frequency when there're
   ```

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function

Review comment:
       ```suggestion
       meet_condition_func: Callable
   ```

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""
+    global CUSTOM_SKETCH_REGISTRY
+
+    if callable(compute_name):
+        func = compute_name
+        compute_name = func.__name__
+
+    if not isinstance(compute_name, str):
+        raise ValueError("expect string function name")

Review comment:
       If this function is used as a decorator, `assert` should be sufficient to check its type.

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""
+    global CUSTOM_SKETCH_REGISTRY
+
+    if callable(compute_name):
+        func = compute_name
+        compute_name = func.__name__
+
+    if not isinstance(compute_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        if compute_name in CUSTOM_SKETCH_REGISTRY:
+            raise RuntimeError(
+                "Custom Sketch for %s has been registered for compute already" % compute_name
+            )
+
+        def meet_condition_func(policy, state, stage_id):
+            state = State(state, policy.search_task.compute_dag)
+            if state.stages[stage_id].op.name == compute_name:
+                return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"]
+            return PreloadCustomSketchRule.CONDITION_NUM["pass"]
+
+        CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf)
+        return myf
+
+    if func:
+        return register(func)

Review comment:
       When will `func=None` at this point?

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""

Review comment:
       This is a public interface, so please complete the docstring.

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""
+    global CUSTOM_SKETCH_REGISTRY
+
+    if callable(compute_name):
+        func = compute_name
+        compute_name = func.__name__
+
+    if not isinstance(compute_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        if compute_name in CUSTOM_SKETCH_REGISTRY:
+            raise RuntimeError(
+                "Custom Sketch for %s has been registered for compute already" % compute_name
+            )
+
+        def meet_condition_func(policy, state, stage_id):
+            state = State(state, policy.search_task.compute_dag)
+            if state.stages[stage_id].op.name == compute_name:
+                return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"]
+            return PreloadCustomSketchRule.CONDITION_NUM["pass"]
+
+        CUSTOM_SKETCH_REGISTRY[compute_name] = PreloadCustomSketchRule(meet_condition_func, myf)

Review comment:
       Should `compute_name` be passed as the 3rd argument (`rule_name`)?

##########
File path: src/auto_scheduler/search_policy/sketch_policy.cc
##########
@@ -671,6 +671,26 @@ Array<MeasureInput> SketchPolicyNode::PickStatesWithEpsGreedy(const Array<State>
   return inputs;
 }
 
+/********** PreloadCustomSketchRule **********/
+TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
+
+PreloadCustomSketchRule::PreloadCustomSketchRule(PackedFunc meet_condition_func,
+                                                 PackedFunc apply_func, String rule_name) {
+  auto node = make_object<PreloadCustomSketchRuleNode>();
+  node->meet_condition_func = std::move(meet_condition_func);
+  node->apply_func = std::move(apply_func);
+  node->rule_name = std::move(rule_name);
+  data_ = std::move(node);
+}
+
+void PreloadCustomSketchRuleNode::Callback(SearchPolicyNode* policy) {
+  CHECK(policy->IsInstance<SketchPolicyNode>());
+  auto sketch_policy = dynamic_cast<SketchPolicyNode*>(policy);
+  sketch_policy->sketch_rules.push_back(
+      new RuleCustomSketch(meet_condition_func, apply_func, rule_name));
+  StdCout(policy->verbose) << "Custom sketch rule added." << std::endl;

Review comment:
       Better to print the rule name in this message.

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}

Review comment:
       Since this will be used by users who create custom rules, comments are needed to explain how to use it.

##########
File path: src/auto_scheduler/search_policy/sketch_policy_rules.cc
##########
@@ -461,6 +461,33 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(
   return {std::make_pair(std::move(tmp_s), stage_id - 1)};
 }
 
+/********** RuleCustomSketch **********/
+
+SketchGenerationRule::ConditionKind RuleCustomSketch::MeetCondition(const SketchPolicyNode& policy,
+                                                                    const State& state,
+                                                                    int stage_id) const {
+  auto ret = meet_condition_func_(tvm::runtime::GetRef<SketchPolicy>(&policy), state, stage_id);
+  if (ret.type_code() == 0) {
+    return ConditionKind(static_cast<int>(ret));
+  } else {
+    LOG(WARNING) << "Wrong value returned from custom sketch, try apply and skip the rest";

Review comment:
       ```suggestion
       LOG(WARNING) << "Wrong rule condition value. Apply the rule and skip the rest";
   ```

##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allowing users to add
+    custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you're clear how it
+    works and this should only be used in SketchSearchPolicy.
+
+    Parameters
+    ----------
+    meet_condition_func: Function
+        A function with `(policy, state, stage_id) -> int`
+    apply_func: Function
+        A function with `(policy, state, stage_id) -> [[State, int], ...]`
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""
+    global CUSTOM_SKETCH_REGISTRY
+
+    if callable(compute_name):
+        func = compute_name
+        compute_name = func.__name__
+
+    if not isinstance(compute_name, str):
+        raise ValueError("expect string function name")
+
+    def register(myf):
+        if compute_name in CUSTOM_SKETCH_REGISTRY:
+            raise RuntimeError(
+                "Custom Sketch for %s has been registered for compute already" % compute_name
+            )
+
+        def meet_condition_func(policy, state, stage_id):
+            state = State(state, policy.search_task.compute_dag)
+            if state.stages[stage_id].op.name == compute_name:
+                return PreloadCustomSketchRule.CONDITION_NUM["apply_and_skip_rest"]
+            return PreloadCustomSketchRule.CONDITION_NUM["pass"]

Review comment:
       Does this mean that if users use this decorator to register a custom rule, they cannot customize the condition function?




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org