You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by wu...@apache.org on 2023/06/16 21:55:23 UTC
[tvm] branch main updated: [TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
This is an automated email from the ASF dual-hosted git repository.
wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 3f2aa6817d [TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
3f2aa6817d is described below
commit 3f2aa6817d1e0773b4b05fcaa981c170b3ed4a24
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Jun 16 17:55:16 2023 -0400
[TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
This commit is part of the changes being split out from
https://github.com/apache/tvm/pull/14778 into independent portions.
This commit allows the `LowerThreadAllreduce` pass to handle
`DeclBuffer` nodes that occur within its input.
---
src/tir/transforms/lower_thread_allreduce.cc | 108 ++++------
src/tir/transforms/update_pointer_storage_scope.cc | 5 +
src/tir/transforms/update_pointer_storage_scope.h | 1 +
.../test_tir_transform_lower_thread_all_reduce.py | 239 +++++++++++++++++++++
4 files changed, 287 insertions(+), 66 deletions(-)
diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 5c004fa5db..f6cda51f43 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -93,92 +93,70 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
}
}
Stmt VisitStmt_(const AllocateNode* op) final {
- Stmt stmt = StmtExprMutator::VisitStmt_(op);
- op = stmt.as<AllocateNode>();
- auto it = alloc_remap_.find(op->buffer_var.get());
- if (it != alloc_remap_.end()) {
+ auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+
+ if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
const AllocateNode* repl = it->second.as<AllocateNode>();
if (warp_allocs_.count(repl)) {
new_storage_scopes_[repl->buffer_var.get()] = "local";
} else {
new_storage_scopes_[repl->buffer_var.get()] = "shared";
}
- return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
- } else {
- return stmt;
+ auto write_ptr = node.CopyOnWrite();
+ write_ptr->buffer_var = repl->buffer_var;
+ write_ptr->dtype = repl->dtype;
+ write_ptr->extents = repl->extents;
+ write_ptr->condition = repl->condition;
}
+ return std::move(node);
}
- PrimExpr VisitExpr_(const BufferLoadNode* op) final {
- {
- auto it = load_remap_.find(op->buffer->data.get());
- if (it != load_remap_.end()) {
- for (const auto& index : op->indices) {
- ICHECK(is_zero(index));
- }
- return it->second;
- }
+ Optional<Buffer> GetRemappedBuffer(const Buffer& buf) {
+ if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) {
+ return it->second;
}
- BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
- op = load.get();
-
- {
- auto it = buf_remap_.find(op->buffer.get());
- if (it != buf_remap_.end()) {
- return BufferLoad(it->second, op->indices, op->span);
- }
+ if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) {
+ Buffer new_buf = buf;
+ new_buf.CopyOnWrite()->data = it->second;
+ buf_remap_[buf.get()] = new_buf;
+ return new_buf;
}
- {
- auto it = var_remap_.find(op->buffer->data.get());
- if (it != var_remap_.end()) {
- Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
- op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
- op->buffer->data_alignment, op->buffer->offset_factor,
- op->buffer->buffer_type, op->buffer->axis_separators,
- op->buffer->span);
- buf_remap_[op->buffer.get()] = remapped_buffer;
- return BufferLoad(remapped_buffer, op->indices, op->span);
- }
- }
- return StmtExprMutator::VisitExpr_(op);
+ return NullOpt;
}
- Stmt VisitStmt_(const BufferStoreNode* op) final {
- BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+ Stmt VisitStmt_(const DeclBufferNode* op) final {
+ auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+ if (auto buf = GetRemappedBuffer(node->buffer)) {
+ node.CopyOnWrite()->buffer = buf.value();
+ }
+ return std::move(node);
+ }
- auto it = store_remap_.find(store->buffer.get());
- if (it != store_remap_.end()) {
+ PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+ if (auto it = load_remap_.find(op->buffer->data.get()); it != load_remap_.end()) {
for (const auto& index : op->indices) {
ICHECK(is_zero(index));
}
-
- auto writer = store.CopyOnWrite();
- writer->buffer = it->second;
- return std::move(store);
+ return it->second;
}
- {
- auto it = buf_remap_.find(store->buffer.get());
- if (it != buf_remap_.end()) {
- return BufferStore(it->second, store->value, store->indices, store->span);
- }
- }
+ BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+ op = load.get();
- {
- auto it = var_remap_.find(store->buffer->data.get());
- if (it != var_remap_.end()) {
- Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
- store->buffer->strides, store->buffer->elem_offset,
- store->buffer->name, store->buffer->data_alignment,
- store->buffer->offset_factor, store->buffer->buffer_type,
- store->buffer->axis_separators, store->buffer->span);
- buf_remap_[store->buffer.get()] = remapped_buffer;
- return BufferStore(remapped_buffer, store->value, store->indices, store->span);
- }
+ if (auto opt = GetRemappedBuffer(load->buffer)) {
+ load.CopyOnWrite()->buffer = opt.value();
}
+ return std::move(load);
+ }
+ Stmt VisitStmt_(const BufferStoreNode* op) final {
+ BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+ if (auto opt = GetRemappedBuffer(store->buffer)) {
+ store.CopyOnWrite()->buffer = opt.value();
+ }
return std::move(store);
}
@@ -446,11 +424,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
PrimExpr val = BufferLoad(buf, zero_indices);
ICHECK_EQ(val->dtype, types[i]);
load_remap_[buffers[i]->data.get()] = val;
- store_remap_[buffers[i].get()] = buf;
Array<PrimExpr> extents{PrimExpr(1)};
auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
alloc_remap_[buffers[i]->data.get()] = node;
var_remap_[buffers[i]->data.get()] = buf->data;
+ buf_remap_[buffers[i].get()] = buf;
warp_allocs_.insert(node.get());
}
} else {
@@ -489,7 +467,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
Allocate(shared_bufs[idx]->data, types[idx],
{PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
- store_remap_[buffers[idx].get()] = shared_bufs[idx];
+ buf_remap_[buffers[idx].get()] = shared_bufs[idx];
}
}
@@ -718,8 +696,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
std::vector<const CommReducerNode*> reduce_combiner_;
// The load remap
std::unordered_map<const VarNode*, PrimExpr> load_remap_;
- // The store remap
- std::unordered_map<const BufferNode*, Buffer> store_remap_;
// Allocate remap
std::unordered_map<const VarNode*, Stmt> alloc_remap_;
// BufferVar remap
diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc
index 157e81d77f..18950bc199 100644
--- a/src/tir/transforms/update_pointer_storage_scope.cc
+++ b/src/tir/transforms/update_pointer_storage_scope.cc
@@ -94,6 +94,11 @@ Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) {
return buf;
}
+Stmt UpdatePointerStorageScope::VisitStmt_(const DeclBufferNode* op) {
+ auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+ return UpdateBufferAccess(node);
+}
+
PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) {
auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return UpdateBufferAccess(node);
diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h
index 5c082d24c4..1f1399fba7 100644
--- a/src/tir/transforms/update_pointer_storage_scope.h
+++ b/src/tir/transforms/update_pointer_storage_scope.h
@@ -41,6 +41,7 @@ class UpdatePointerStorageScope : public StmtExprMutator {
virtual PrimExpr VisitExpr_(const VarNode*);
virtual PrimExpr VisitExpr_(const BufferLoadNode*);
virtual Stmt VisitStmt_(const AllocateNode*);
+ virtual Stmt VisitStmt_(const DeclBufferNode*);
virtual Stmt VisitStmt_(const BufferStoreNode*);
private:
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
new file mode 100644
index 0000000000..f20d11ffb4
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+ transform = tvm.tir.transform.LowerThreadAllreduce()
+
+
+class BaseFailure(BaseCompare):
+ expected = ValueError
+
+
+class TestBasic(BaseCompare):
+ def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer(4096, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce_data = T.allocate([1], "float32", "local")
+ reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer(4096, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce_data = T.allocate([1], "float32", "local")
+ reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ mask_data = T.allocate([1], "uint32", "local")
+ mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+ t0_data = T.allocate([1], "float32", "local")
+ t0 = T.Buffer(1, data=t0_data, scope="local")
+
+ reduce[0] = A_flat[0]
+ mask[0] = T.tvm_warp_activemask()
+
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+
+class TestBasicWithDeclBuffer(BaseCompare):
+ def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer(4096, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ A_flat[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+ def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer(4096, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ mask_data = T.allocate([1], "uint32", "local")
+ mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+ t0_data = T.allocate([1], "float32", "local")
+ t0 = T.Buffer(1, data=t0_data, scope="local")
+
+ reduce[0] = A_flat[0]
+ mask[0] = T.tvm_warp_activemask()
+
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+
+class TestReduceSummation(BaseCompare):
+ def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer((16384,), data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ normal_reduce_data = T.allocate([1], "float32", "local")
+ normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+
+ reduce_data = T.allocate([1], "float32", "local")
+ reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+ normal_reduce[0] = T.float32(0)
+
+ for ko in range(4):
+ normal_reduce[0] = normal_reduce[0] + A_flat[i * 128 + ko * 32 + threadIdx_x]
+
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ T.tvm_thread_allreduce(
+ T.uint32(1),
+ normal_reduce[0],
+ T.bool(True),
+ reduce[0],
+ threadIdx_x,
+ )
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+ def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
+ T.func_attr({"target": T.target("cuda", host="llvm")})
+ A_flat = T.Buffer(16384, data=A.data)
+
+ for i in range(128):
+ threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+ normal_reduce_data = T.allocate([1], "float32", "local")
+ normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+
+ reduce_data = T.allocate([1], "float32", "local")
+ reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+ normal_reduce[0] = T.float32(0)
+ for ko in range(4):
+ normal_reduce[0] = normal_reduce[0] + A_flat[i * 128 + ko * 32 + threadIdx_x]
+ with T.attr(
+ T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+ "reduce_scope",
+ T.reinterpret("handle", T.uint64(0)),
+ ):
+ mask_data = T.allocate([1], "uint32", "local")
+ mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+ t0_data = T.allocate([1], "float32", "local")
+ t0 = T.Buffer(1, data=t0_data, scope="local")
+
+ reduce[0] = normal_reduce[0]
+ mask[0] = T.tvm_warp_activemask()
+
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+ reduce[0] = reduce[0] + t0[0]
+ reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+ if threadIdx_x == 0:
+ B[i] = reduce[0]
+
+
+if __name__ == "__main__":
+ tvm.testing.main()