Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/03 16:59:30 UTC

[GitHub] [tvm] tkonolige commented on a change in pull request #7313: [AutoSchedule] Sparse dense tuning support with custom sketch rule

tkonolige commented on a change in pull request #7313:
URL: https://github.com/apache/tvm/pull/7313#discussion_r586586361



##########
File path: python/tvm/auto_scheduler/measure.py
##########
@@ -719,6 +720,45 @@ def local_builder_build(inputs, timeout, n_parallel, build_func="default", verbo
     return results
 
 
+def _prepare_input_map(args):
+    """This function deals with special task inputs.
+
+    Parameters
+    ----------
+    args : List[Tensor]
+        Input/output Tensor of a TVM subgraph.
+
+    Returns
+    -------
+    A Dict[Tensor, str] that maps the input Tensor to a buffer name.
+
+    Note
+    ----
+    The buffer names are specially designed, and these buffers should be provided in
+    `SearchTask(..., task_inputs={...})`.
+    """
+    # pylint: disable=import-outside-toplevel
+    from tvm import topi  # lazily import to avoid recursive dependency
+
+    # A dict that maps the input tensor arg to a buffer name
+    tensor_input_map = {}
+
+    # Case 0: Check placeholder name
+    for arg in args:
+        if isinstance(arg.op, tvm.te.PlaceholderOp):
+            if arg.op.name != "placeholder":
+                tensor_input_map[arg] = arg.op.name
+
+    # Case 1: Check sparse op
+    sparse_input_map = topi.nn.sparse.try_get_sparse_input(args)

Review comment:
       I think I asked this before, but can we have a more general mechanism than checking only for sparse inputs? There are other use cases that require specific inputs (e.g., sorting and scatter).
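One possible shape for the more general mechanism asked for in the comment above is a rule registry. Each op family (sparse, sort, scatter, ...) registers a function that returns `{tensor: buffer_name}` for the inputs it recognizes, and the input-map builder simply merges the results. This is a hypothetical sketch, not the PR's code: `register_input_map_rule`, `prepare_input_map`, `FakeOp`, and `FakeTensor` are illustrative names, not TVM APIs, and the fake classes stand in for `te.Tensor` so the sketch runs without TVM.

```python
from dataclasses import dataclass
from typing import Callable, Dict, List


# Hypothetical stand-ins for TVM's te.Tensor / te.PlaceholderOp so this sketch
# runs without TVM. Only the attributes the rules read are modeled.
@dataclass(frozen=True)
class FakeOp:
    name: str
    is_placeholder: bool = True


@dataclass(frozen=True)
class FakeTensor:
    op: FakeOp


# Hypothetical rule type: inspect all subgraph args, return {tensor: buffer_name}.
InputMapRule = Callable[[List[FakeTensor]], Dict[FakeTensor, str]]
_INPUT_MAP_RULES: List[InputMapRule] = []


def register_input_map_rule(rule: InputMapRule) -> InputMapRule:
    """Decorator that adds `rule` to the list of special-input rules (hypothetical API)."""
    _INPUT_MAP_RULES.append(rule)
    return rule


@register_input_map_rule
def _named_placeholder_rule(args: List[FakeTensor]) -> Dict[FakeTensor, str]:
    # Mirrors "Case 0" in the diff: a placeholder with a non-default name is
    # treated as a special input keyed by that name.
    return {
        a: a.op.name
        for a in args
        if a.op.is_placeholder and a.op.name != "placeholder"
    }


def prepare_input_map(args: List[FakeTensor]) -> Dict[FakeTensor, str]:
    """Merge the maps from every registered rule, rejecting conflicting names."""
    tensor_input_map: Dict[FakeTensor, str] = {}
    for rule in _INPUT_MAP_RULES:
        for tensor, name in rule(args).items():
            existing = tensor_input_map.get(tensor)
            if existing is not None and existing != name:
                raise ValueError(f"{tensor} mapped to both {existing!r} and {name!r}")
            tensor_input_map[tensor] = name
    return tensor_input_map
```

With this shape, supporting a new op such as scatter would mean registering one more rule instead of editing the builder, and the conflict check catches two rules that assign different buffer names to the same tensor.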

##########
File path: tests/python/unittest/test_auto_scheduler_search_task.py
##########
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Test search task"""
+
+import random
+import multiprocessing
+import numpy as np
+import tempfile
+
+import tvm
+import tvm.testing
+from tvm import auto_scheduler
+from tvm.auto_scheduler.utils import get_const_tuple
+
+from test_auto_scheduler_common import (
+    matmul_auto_scheduler_test,
+    zero_rank_compute_auto_scheduler_test,
+    zero_rank_reduce_auto_scheduler_test,
+)
+import multiprocessing

Review comment:
       `multiprocessing` is imported twice (it is already imported above), and it is unused.



