Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/01/16 09:59:25 UTC

[GitHub] [tvm] jcf94 commented on a change in pull request #7260: [AutoScheduler] Bug fix & Custom sketch support

jcf94 commented on a change in pull request #7260:
URL: https://github.com/apache/tvm/pull/7260#discussion_r558834365



##########
File path: python/tvm/auto_scheduler/search_policy.py
##########
@@ -61,6 +62,71 @@ def __init__(self, filename):
         self.__init_handle_by_constructor__(_ffi_api.PreloadMeasuredStates, filename)
 
 
+@tvm._ffi.register_object("auto_scheduler.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
+    """
+    A SearchCallback for SketchSearchPolicy that allows users to add
+    a custom sketch rule.
+
+    Notes
+    -----
+    This is an advanced feature. Make sure you understand how it
+    works; it should only be used with SketchSearchPolicy.
+
+    Parameters
+    ----------
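+    meet_condition_func: Function
+        A function with signature `(policy, state, stage_id) -> int`; the
+        returned int should be one of the values in `CONDITION_NUM`.
+    apply_func: Function
+        A function with signature `(policy, state, stage_id) -> [[State, int], ...]`
+    rule_name: str = "CustomSketchRule"
+        The name of this custom sketch rule.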
+    """
+
+    CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}
+
+    def __init__(self, meet_condition_func, apply_func, rule_name="CustomSketchRule"):
+        self.__init_handle_by_constructor__(
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func, rule_name
+        )
+
+
+CUSTOM_SKETCH_REGISTRY = {}
+
+
+def register_custom_sketch_func(compute_name, func=None):
+    """Helper decorator to register custom sketch functions easily."""
+    global CUSTOM_SKETCH_REGISTRY
+
+    if callable(compute_name):
+        func = compute_name
+        compute_name = func.__name__
+
+    if not isinstance(compute_name, str):
+        raise ValueError("expect string function name")

Review comment:
       Yes, this is meant to be used as a decorator; I was trying to follow the pattern of `tvm._ffi.register_func`.
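The dual-mode decorator pattern the reply refers to can be sketched standalone. This is a minimal illustration, assuming the helper mirrors how `tvm._ffi.register_func` accepts either a bare `@decorator` or `@decorator("name")`. `SKETCH_REGISTRY`, `register_sketch_func`, and the sample functions are hypothetical and are not part of the PR.

```python
# Hypothetical registry mapping compute names to sketch functions.
SKETCH_REGISTRY = {}


def register_sketch_func(compute_name, func=None):
    """Hypothetical helper supporting three call styles:

    @register_sketch_func            -> key is the function's __name__
    @register_sketch_func("name")    -> key is "name"
    register_sketch_func("name", f)  -> direct, non-decorator call
    """
    if callable(compute_name):
        # Bare decorator: the function itself arrived as the first argument.
        func = compute_name
        compute_name = func.__name__

    if not isinstance(compute_name, str):
        raise ValueError("expect string function name")

    def register(myf):
        SKETCH_REGISTRY[compute_name] = myf
        # Return the function unchanged so the decorated name stays usable.
        return myf

    if func is not None:
        return register(func)
    # Called as @register_sketch_func("name"): return the real decorator.
    return register


@register_sketch_func
def matmul_sketch(policy, state, stage_id):
    return []


@register_sketch_func("conv2d")
def my_conv_sketch(policy, state, stage_id):
    return []
```

The `callable(...)` check is what lets one function serve both as the decorator and as a decorator factory.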




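The callback signatures in the `PreloadCustomSketchRule` docstring can be illustrated without TVM. This is a sketch under stated assumptions: `FakeState` is a hypothetical stand-in for a real `State`, the `"dense"` stage name is made up, and the int in each `[State, int]` pair is assumed to be the stage id to continue from, as in the built-in C++ sketch rules.

```python
from dataclasses import dataclass, field
from typing import List

# Copied from the PreloadCustomSketchRule class in the diff.
CONDITION_NUM = {"pass": 0, "apply": 1, "apply_and_skip_rest": 2}


@dataclass
class FakeState:
    """Hypothetical stand-in for an auto_scheduler State."""
    stage_names: List[str]
    # Records the (toy) transformations applied so far.
    applied: List[str] = field(default_factory=list)


def meet_condition_func(policy, state, stage_id):
    """(policy, state, stage_id) -> int, one of CONDITION_NUM's values."""
    if state.stage_names[stage_id] == "dense":
        # Apply this rule and skip the remaining rules for this stage.
        return CONDITION_NUM["apply_and_skip_rest"]
    return CONDITION_NUM["pass"]


def apply_func(policy, state, stage_id):
    """(policy, state, stage_id) -> [[State, int], ...].

    Assumption: the int is the next stage id to work on.
    """
    new_state = FakeState(list(state.stage_names), state.applied + [f"tile:{stage_id}"])
    return [[new_state, stage_id - 1]]
```

With TVM available, the two functions would presumably be wrapped as `PreloadCustomSketchRule(meet_condition_func, apply_func)` and passed to `auto_scheduler.SketchPolicy` through its `init_search_callbacks` argument. The real callbacks receive actual policy and state objects rather than these stand-ins.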
----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org