You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/01 16:31:21 UTC

[tvm] branch main updated: [Metaschedule] Add test case for multi-anchor subgraph (#10856)

This is an automated email from the ASF dual-hosted git repository.

junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 93b255c  [Metaschedule] Add test case for multi-anchor subgraph (#10856)
93b255c is described below

commit 93b255cb63514dc6e59560b44fb0a9a979bd8aac
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Sat Apr 2 01:30:41 2022 +0900

    [Metaschedule] Add test case for multi-anchor subgraph (#10856)
    
    This adds a demonstration of extracting, scheduling, and end-to-end compiling Relay subgraphs with multiple anchor ops. Since task extraction is no longer tied to TE scheduling, extracting a subgraph with multiple anchor TE computes just works.
    
    The test case manually creates a simple fused module with two `relay.nn.dense` ops. In the future, an effort like https://github.com/apache/tvm/pull/9628 should make it easier to construct multi-anchor subgraphs.
    
    The TensorIR extracted for the two TE `dense` computes (one block per dense) looks like this:
    
    ```
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def main(placeholder: T.Buffer[(128, 128), "float32"], placeholder_1: T.Buffer[(128, 128), "float32"], placeholder_2: T.Buffer[(128, 128), "float32"], T_matmul_NT: T.Buffer[(128, 128), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "main", "tir.noalias": True})
            # body
            # with T.block("root")
            T_matmul_NT_1 = T.alloc_buffer([128, 128], dtype="float32")
            for i0, i1, i2 in T.grid(128, 128, 128):
                with T.block("T_matmul_NT"):
                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                    T.reads(placeholder[i, k], placeholder_1[j, k])
                    T.writes(T_matmul_NT_1[i, j])
                    T.block_attr({"layout_free_placeholders":[placeholder_1]})
                    with T.init():
                        T_matmul_NT_1[i, j] = T.float32(0)
                    T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + placeholder[i, k] * placeholder_1[j, k]
            for i0, i1, i2 in T.grid(128, 128, 128):
                with T.block("T_matmul_NT_1"):
                    i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                    T.reads(T_matmul_NT_1[i, k], placeholder_2[j, k])
                    T.writes(T_matmul_NT[i, j])
                    T.block_attr({"layout_free_placeholders":[placeholder_2]})
                    with T.init():
                        T_matmul_NT[i, j] = T.float32(0)
                    T_matmul_NT[i, j] = T_matmul_NT[i, j] + T_matmul_NT_1[i, k] * placeholder_2[j, k]
    
    ```
---
 src/relay/backend/te_compiler_cache.cc             |   6 +-
 .../unittest/test_meta_schedule_multi_anchor.py    | 131 +++++++++++++++++++++
 2 files changed, 135 insertions(+), 2 deletions(-)

diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 3534697..cd3ce80 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -299,6 +299,7 @@ class ScheduleBuilder : public ExprVisitor {
   explicit ScheduleBuilder(Target target) : target_(target) {
     // Whether to use auto_scheduler schedule.
     use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+    use_meta_scheduler_ = backend::IsMetaScheduleEnabled();
   }
 
   CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
@@ -336,7 +337,7 @@ class ScheduleBuilder : public ExprVisitor {
           schedule = Downcast<te::Schedule>(obj);
         }
       }
-      if (backend::IsMetaScheduleEnabled()) {
+      if (use_meta_scheduler_) {
         IRModule relay_mod({{prim_fn_var, relay_func}});
         IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
         Optional<IRModule> scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope(
@@ -377,7 +378,7 @@ class ScheduleBuilder : public ExprVisitor {
     }
 
     int op_pattern = fpattern[op];
-    if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+    if (!use_auto_scheduler_ && !use_meta_scheduler_ && op_pattern >= kCommReduce) {
       ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
           << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
           << " anchor=" << anchor_op_ << " current=" << op;
@@ -395,6 +396,7 @@ class ScheduleBuilder : public ExprVisitor {
   Attrs anchor_attrs_;
   int anchor_op_pattern_{0};
   bool use_auto_scheduler_;
+  bool use_meta_scheduler_;
 };
 
 /*!
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
new file mode 100644
index 0000000..e596391
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relay
+from tvm.meta_schedule.tune import Parse, extract_task_from_relay
+from tvm.meta_schedule.database import TuningRecord, JSONDatabase
+from tvm.meta_schedule.integration import ApplyHistoryBest
+
+
+def get_dense_dense(data_shape, weight_shape):
+    """Build relay: a call to one Primitive function holding dense(dense(x, w1), w2)."""
+
+    def make_vars(prefix):
+        # (data, weight1, weight2) float32 vars, names prefixed with `prefix`.
+        return (
+            relay.var(prefix + "data", shape=data_shape, dtype="float32"),
+            relay.var(prefix + "weight1", shape=weight_shape, dtype="float32"),
+            relay.var(prefix + "weight2", shape=weight_shape, dtype="float32"),
+        )
+
+    # Inner function: the fused group with two anchor (dense) ops.
+    p_data, p_weight1, p_weight2 = make_vars("p_")
+    chained = relay.nn.dense(relay.nn.dense(p_data, p_weight1), p_weight2)
+    fused = relay.Function([p_data, p_weight1, p_weight2], chained)
+    fused = fused.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+
+    outer_args = make_vars("")
+    return relay.Function(list(outer_args), relay.Call(fused, list(outer_args)))
+
+
+def get_ref(data_np, weight1_np, weight2_np):
+    """NumPy reference: (data @ weight1.T) @ weight2.T, i.e. relay.nn.dense applied twice."""
+    return np.dot(np.dot(data_np, weight1_np.T), weight2_np.T)
+
+
+def schedule_dense_dense(sch):
+    """Stub schedule: applies no loop transformations yet.
+
+    Looks up both matmul blocks and their loops (these calls are recorded in
+    ``sch.trace``) and checks that each block has exactly three loops.
+    """
+    for block_name in ("T_matmul_NT", "T_matmul_NT_1"):
+        assert len(sch.get_loops(sch.get_block(block_name))) == 3
+
+
+def test_dense_dense():
+    """Schedule and build a Primitive relay function holding two dense ops.
+
+    The JSON database lives in a temporary directory, so every use of it --
+    committing records and ApplyHistoryBest during relay.build -- stays
+    inside that directory's ``with`` block.
+    """
+    # `import tvm` alone does not guarantee tvm.contrib.graph_executor is loaded.
+    from tvm.contrib import graph_executor
+
+    M, N, K = 128, 128, 128
+    data_shape = (M, K)
+    weight_shape = (N, K)
+
+    relay_mod = tvm.IRModule.from_expr(get_dense_dense(data_shape, weight_shape))
+    target = "llvm"
+
+    # Seeded (and local, not global) so failures are reproducible.
+    rng = np.random.RandomState(0)
+    data_np = rng.randn(*data_shape).astype("float32")
+    weight1_np = rng.randn(*weight_shape).astype("float32")
+    weight2_np = rng.randn(*weight_shape).astype("float32")
+
+    params = {"weight1": weight1_np, "weight2": weight2_np}
+
+    # Both dense ops sit in one Primitive function, so one task is expected.
+    extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+    assert len(extracted_tasks) == 1
+
+    mod = Parse._mod(extracted_tasks[0].dispatched[0])
+
+    with tempfile.TemporaryDirectory() as work_dir:
+        database = JSONDatabase(
+            path_workload=os.path.join(work_dir, "database_workload.json"),
+            path_tuning_record=os.path.join(work_dir, "database_tuning_record.json"),
+        )
+
+        workload = database.commit_workload(mod)
+
+        sch = tvm.tir.Schedule(mod)
+        schedule_dense_dense(sch)
+
+        tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
+        database.commit_tuning_record(tune_rec)
+
+        # Build while the database and its backing files still exist.
+        with ApplyHistoryBest(database):
+            with tvm.transform.PassContext(
+                opt_level=3,
+                config={"relay.backend.use_meta_schedule": True},
+            ):
+                lib = relay.build(relay_mod, target=target, params=params)
+
+    dev = tvm.device(target, 0)
+    runtime = graph_executor.GraphModule(lib["default"](dev))
+
+    runtime.set_input("data", data_np)
+    runtime.run()
+    out = runtime.get_output(0).numpy()
+
+    ref = get_ref(data_np, weight1_np, weight2_np)
+    tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+    test_dense_dense()  # lets the file be run directly, without pytest