Posted to commits@tvm.apache.org by ju...@apache.org on 2022/04/01 16:31:21 UTC
[tvm] branch main updated: [Metaschedule] Add test case for multi-anchor subgraph (#10856)
This is an automated email from the ASF dual-hosted git repository.
junrushao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 93b255c [Metaschedule] Add test case for multi-anchor subgraph (#10856)
93b255c is described below
commit 93b255cb63514dc6e59560b44fb0a9a979bd8aac
Author: Masahiro Masuda <ma...@gmail.com>
AuthorDate: Sat Apr 2 01:30:41 2022 +0900
[Metaschedule] Add test case for multi-anchor subgraph (#10856)
This adds a demonstration of extracting, scheduling, and end-to-end compiling Relay subgraphs with multiple anchor ops. Since task extraction is no longer tied to TE scheduling, extracting a subgraph with multiple anchor TE computes just works.
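As a sketch of the API involved (using the names from the test below; `relay_mod`, `target`, and `params` are assumed to be defined as in the test):
```
# Sketch: task extraction is a standalone query over the Relay module,
# with no TE schedule attached to the extracted tasks.
from tvm.meta_schedule.tune import extract_task_from_relay

extracted_tasks = extract_task_from_relay(relay_mod, target, params)
# Each task's `dispatched` field holds the TIR module(s) for one fused
# subgraph, no matter how many anchor ops the subgraph contains.
```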
The test case manually creates a simple fused module with two `relay.dense` ops; in the future, an effort like https://github.com/apache/tvm/pull/9628 should make it easier to construct multi-anchor subgraphs.
The extracted TensorIR PrimFunc corresponding to the two TE `dense` computes looks like this:
```
@tvm.script.ir_module
class Module:
    @T.prim_func
    def main(placeholder: T.Buffer[(128, 128), "float32"], placeholder_1: T.Buffer[(128, 128), "float32"], placeholder_2: T.Buffer[(128, 128), "float32"], T_matmul_NT: T.Buffer[(128, 128), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "main", "tir.noalias": True})
        # body
        # with T.block("root")
        T_matmul_NT_1 = T.alloc_buffer([128, 128], dtype="float32")
        for i0, i1, i2 in T.grid(128, 128, 128):
            with T.block("T_matmul_NT"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.reads(placeholder[i, k], placeholder_1[j, k])
                T.writes(T_matmul_NT_1[i, j])
                T.block_attr({"layout_free_placeholders": [placeholder_1]})
                with T.init():
                    T_matmul_NT_1[i, j] = T.float32(0)
                T_matmul_NT_1[i, j] = T_matmul_NT_1[i, j] + placeholder[i, k] * placeholder_1[j, k]
        for i0, i1, i2 in T.grid(128, 128, 128):
            with T.block("T_matmul_NT_1"):
                i, j, k = T.axis.remap("SSR", [i0, i1, i2])
                T.reads(T_matmul_NT_1[i, k], placeholder_2[j, k])
                T.writes(T_matmul_NT[i, j])
                T.block_attr({"layout_free_placeholders": [placeholder_2]})
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + T_matmul_NT_1[i, k] * placeholder_2[j, k]
```
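The test's `schedule_dense_dense` below deliberately leaves the schedule body as `# ...`; for illustration only, here is a minimal sketch of how the two blocks above could be scheduled with the TensorIR schedule API (the function name `schedule_two_matmuls` is hypothetical, not part of the commit):
```
# Illustrative sketch, not part of this commit: one plausible CPU schedule
# for the two matmul blocks shown above.
import tvm

def schedule_two_matmuls(sch: tvm.tir.Schedule):
    for name in ["T_matmul_NT", "T_matmul_NT_1"]:
        block = sch.get_block(name)
        y, x, k = sch.get_loops(block)
        # Fuse the two spatial loops of each matmul and parallelize the result.
        sch.parallel(sch.fuse(y, x))
```
Because the schedule is recorded as a trace (`sch.trace`), committing it as a `TuningRecord` lets `ApplyHistoryBest` replay it during `relay.build`, which is exactly what the test does.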
---
src/relay/backend/te_compiler_cache.cc | 6 +-
.../unittest/test_meta_schedule_multi_anchor.py | 131 +++++++++++++++++++++
2 files changed, 135 insertions(+), 2 deletions(-)
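In `te_compiler_cache.cc`, the meta-schedule flag is now cached in the `ScheduleBuilder` constructor and added to the condition guarding the "two complicated ops" check, so the single-anchor restriction that TOPI scheduling requires no longer fires when meta schedule is enabled.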
diff --git a/src/relay/backend/te_compiler_cache.cc b/src/relay/backend/te_compiler_cache.cc
index 3534697..cd3ce80 100644
--- a/src/relay/backend/te_compiler_cache.cc
+++ b/src/relay/backend/te_compiler_cache.cc
@@ -299,6 +299,7 @@ class ScheduleBuilder : public ExprVisitor {
explicit ScheduleBuilder(Target target) : target_(target) {
// Whether to use auto_scheduler schedule.
use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
+ use_meta_scheduler_ = backend::IsMetaScheduleEnabled();
}
CachedFunc Create(const Function& relay_func, std::function<std::string(std::string)> renamer) {
@@ -336,7 +337,7 @@ class ScheduleBuilder : public ExprVisitor {
schedule = Downcast<te::Schedule>(obj);
}
}
- if (backend::IsMetaScheduleEnabled()) {
+ if (use_meta_scheduler_) {
IRModule relay_mod({{prim_fn_var, relay_func}});
IRModule tir_mod({{prim_fn_var, tir::CreatePrimFunc(Concat(fn_inputs, tensor_outs))}});
Optional<IRModule> scheduled_mod = meta_schedule::MetaScheduleContext::QueryInsideWithScope(
@@ -377,7 +378,7 @@ class ScheduleBuilder : public ExprVisitor {
}
int op_pattern = fpattern[op];
- if (!use_auto_scheduler_ && op_pattern >= kCommReduce) {
+ if (!use_auto_scheduler_ && !use_meta_scheduler_ && op_pattern >= kCommReduce) {
ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
<< "Cannot apply TOPI schedule to a primitive function with two complicated ops"
<< " anchor=" << anchor_op_ << " current=" << op;
@@ -395,6 +396,7 @@ class ScheduleBuilder : public ExprVisitor {
Attrs anchor_attrs_;
int anchor_op_pattern_{0};
bool use_auto_scheduler_;
+ bool use_meta_scheduler_;
};
/*!
diff --git a/tests/python/unittest/test_meta_schedule_multi_anchor.py b/tests/python/unittest/test_meta_schedule_multi_anchor.py
new file mode 100644
index 0000000..e596391
--- /dev/null
+++ b/tests/python/unittest/test_meta_schedule_multi_anchor.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import os
+import tempfile
+
+import numpy as np
+
+import tvm
+import tvm.testing
+from tvm import relay
+from tvm.meta_schedule.tune import Parse, extract_task_from_relay
+from tvm.meta_schedule.database import TuningRecord, JSONDatabase
+from tvm.meta_schedule.integration import ApplyHistoryBest
+
+
+def get_dense_dense(data_shape, weight_shape):
+ def multi_dense():
+ p_data = relay.var("p_data", shape=data_shape, dtype="float32")
+ p_weight1 = relay.var("p_weight1", shape=weight_shape, dtype="float32")
+ p_weight2 = relay.var("p_weight2", shape=weight_shape, dtype="float32")
+
+ dense1 = relay.nn.dense(p_data, p_weight1)
+ dense2 = relay.nn.dense(dense1, p_weight2)
+
+ f = relay.Function([p_data, p_weight1, p_weight2], dense2)
+ f = f.with_attr("Primitive", tvm.tir.IntImm("int32", 1))
+ return f
+
+ data = relay.var("data", shape=data_shape, dtype="float32")
+ weight1 = relay.var("weight1", shape=weight_shape, dtype="float32")
+ weight2 = relay.var("weight2", shape=weight_shape, dtype="float32")
+
+ out = relay.Call(multi_dense(), [data, weight1, weight2])
+ return relay.Function([data, weight1, weight2], out)
+
+
+def get_ref(data_np, weight1_np, weight2_np):
+ dense1 = np.dot(data_np, np.transpose(weight1_np))
+ return np.dot(dense1, np.transpose(weight2_np))
+
+
+def schedule_dense_dense(sch):
+ dense1 = sch.get_block("T_matmul_NT")
+ dense2 = sch.get_block("T_matmul_NT_1")
+
+ y1, x1, k1 = sch.get_loops(dense1)
+ y2, x2, k2 = sch.get_loops(dense2)
+
+ # ...
+
+
+def test_dense_dense():
+ M, N, K = 128, 128, 128
+ data_shape = (M, K)
+ weight_shape = (N, K)
+
+ relay_mod = tvm.IRModule.from_expr(get_dense_dense(data_shape, weight_shape))
+
+ # print(relay.transform.InferType()(relay_mod))
+
+ target = "llvm"
+
+ data_np = np.random.randn(*data_shape).astype("float32")
+ weight1_np = np.random.randn(*weight_shape).astype("float32")
+ weight2_np = np.random.randn(*weight_shape).astype("float32")
+
+ params = {"weight1": weight1_np, "weight2": weight2_np}
+
+ extracted_tasks = extract_task_from_relay(relay_mod, target, params)
+
+ assert len(extracted_tasks) == 1
+
+ task = extracted_tasks[0]
+
+ mod = Parse._mod(task.dispatched[0])
+
+ with tempfile.TemporaryDirectory() as work_dir:
+ database = JSONDatabase(
+ path_workload=os.path.join(work_dir, "database_workload.json"),
+ path_tuning_record=os.path.join(work_dir, "database_tuning_record.json"),
+ )
+
+ workload = database.commit_workload(mod)
+
+ sch = tvm.tir.Schedule(mod)
+
+ schedule_dense_dense(sch)
+
+ # print(sch.mod.script())
+
+ tune_rec = TuningRecord(sch.trace, [0.0], workload, tvm.target.Target(target), [])
+
+ database.commit_tuning_record(tune_rec)
+
+ with ApplyHistoryBest(database):
+ with tvm.transform.PassContext(
+ opt_level=3,
+ config={"relay.backend.use_meta_schedule": True},
+ ):
+ lib = relay.build(relay_mod, target=target, params=params)
+
+ dev = tvm.device(target, 0)
+
+ runtime = tvm.contrib.graph_executor.GraphModule(lib["default"](dev))
+
+ runtime.set_input("data", data_np)
+ runtime.run()
+
+ out = runtime.get_output(0).numpy()
+
+ ref = get_ref(data_np, weight1_np, weight2_np)
+
+ tvm.testing.assert_allclose(out, ref, atol=1e-4, rtol=1e-4)
+
+
+if __name__ == "__main__":
+ test_dense_dense()