You are viewing a plain text version of this content. The canonical link was removed during plain-text conversion; see the original HTML version of this message for it.
Posted to commits@tvm.apache.org by wu...@apache.org on 2023/06/16 21:55:23 UTC

[tvm] branch main updated: [TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)

This is an automated email from the ASF dual-hosted git repository.

wuwei pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3f2aa6817d [TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
3f2aa6817d is described below

commit 3f2aa6817d1e0773b4b05fcaa981c170b3ed4a24
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Fri Jun 16 17:55:16 2023 -0400

    [TIR] Handle DeclBuffer in LowerThreadAllreduce (#15078)
    
    Part of changes being split out from
    https://github.com/apache/tvm/pull/14778 into independent portions.
    This commit allows the `LowerThreadAllreduce` pass to handle
    `DeclBuffer` nodes that occur within its input.
---
 src/tir/transforms/lower_thread_allreduce.cc       | 108 ++++------
 src/tir/transforms/update_pointer_storage_scope.cc |   5 +
 src/tir/transforms/update_pointer_storage_scope.h  |   1 +
 .../test_tir_transform_lower_thread_all_reduce.py  | 239 +++++++++++++++++++++
 4 files changed, 287 insertions(+), 66 deletions(-)

diff --git a/src/tir/transforms/lower_thread_allreduce.cc b/src/tir/transforms/lower_thread_allreduce.cc
index 5c004fa5db..f6cda51f43 100644
--- a/src/tir/transforms/lower_thread_allreduce.cc
+++ b/src/tir/transforms/lower_thread_allreduce.cc
@@ -93,92 +93,70 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
     }
   }
   Stmt VisitStmt_(const AllocateNode* op) final {
-    Stmt stmt = StmtExprMutator::VisitStmt_(op);
-    op = stmt.as<AllocateNode>();
-    auto it = alloc_remap_.find(op->buffer_var.get());
-    if (it != alloc_remap_.end()) {
+    auto node = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
+
+    if (auto it = alloc_remap_.find(node->buffer_var.get()); it != alloc_remap_.end()) {
       const AllocateNode* repl = it->second.as<AllocateNode>();
       if (warp_allocs_.count(repl)) {
         new_storage_scopes_[repl->buffer_var.get()] = "local";
       } else {
         new_storage_scopes_[repl->buffer_var.get()] = "shared";
       }
-      return Allocate(repl->buffer_var, repl->dtype, repl->extents, repl->condition, op->body);
-    } else {
-      return stmt;
+      auto write_ptr = node.CopyOnWrite();
+      write_ptr->buffer_var = repl->buffer_var;
+      write_ptr->dtype = repl->dtype;
+      write_ptr->extents = repl->extents;
+      write_ptr->condition = repl->condition;
     }
+    return std::move(node);
   }
 
-  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
-    {
-      auto it = load_remap_.find(op->buffer->data.get());
-      if (it != load_remap_.end()) {
-        for (const auto& index : op->indices) {
-          ICHECK(is_zero(index));
-        }
-        return it->second;
-      }
+  Optional<Buffer> GetRemappedBuffer(const Buffer& buf) {
+    if (auto it = buf_remap_.find(buf.get()); it != buf_remap_.end()) {
+      return it->second;
     }
 
-    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
-    op = load.get();
-
-    {
-      auto it = buf_remap_.find(op->buffer.get());
-      if (it != buf_remap_.end()) {
-        return BufferLoad(it->second, op->indices, op->span);
-      }
+    if (auto it = var_remap_.find(buf->data.get()); it != var_remap_.end()) {
+      Buffer new_buf = buf;
+      new_buf.CopyOnWrite()->data = it->second;
+      buf_remap_[buf.get()] = new_buf;
+      return new_buf;
     }
 
-    {
-      auto it = var_remap_.find(op->buffer->data.get());
-      if (it != var_remap_.end()) {
-        Buffer remapped_buffer(it->second, op->buffer->dtype, op->buffer->shape,
-                               op->buffer->strides, op->buffer->elem_offset, op->buffer->name,
-                               op->buffer->data_alignment, op->buffer->offset_factor,
-                               op->buffer->buffer_type, op->buffer->axis_separators,
-                               op->buffer->span);
-        buf_remap_[op->buffer.get()] = remapped_buffer;
-        return BufferLoad(remapped_buffer, op->indices, op->span);
-      }
-    }
-    return StmtExprMutator::VisitExpr_(op);
+    return NullOpt;
   }
 
-  Stmt VisitStmt_(const BufferStoreNode* op) final {
-    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+    if (auto buf = GetRemappedBuffer(node->buffer)) {
+      node.CopyOnWrite()->buffer = buf.value();
+    }
+    return std::move(node);
+  }
 
-    auto it = store_remap_.find(store->buffer.get());
-    if (it != store_remap_.end()) {
+  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
+    if (auto it = load_remap_.find(op->buffer->data.get()); it != load_remap_.end()) {
       for (const auto& index : op->indices) {
         ICHECK(is_zero(index));
       }
-
-      auto writer = store.CopyOnWrite();
-      writer->buffer = it->second;
-      return std::move(store);
+      return it->second;
     }
 
-    {
-      auto it = buf_remap_.find(store->buffer.get());
-      if (it != buf_remap_.end()) {
-        return BufferStore(it->second, store->value, store->indices, store->span);
-      }
-    }
+    BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
+    op = load.get();
 
-    {
-      auto it = var_remap_.find(store->buffer->data.get());
-      if (it != var_remap_.end()) {
-        Buffer remapped_buffer(it->second, store->buffer->dtype, store->buffer->shape,
-                               store->buffer->strides, store->buffer->elem_offset,
-                               store->buffer->name, store->buffer->data_alignment,
-                               store->buffer->offset_factor, store->buffer->buffer_type,
-                               store->buffer->axis_separators, store->buffer->span);
-        buf_remap_[store->buffer.get()] = remapped_buffer;
-        return BufferStore(remapped_buffer, store->value, store->indices, store->span);
-      }
+    if (auto opt = GetRemappedBuffer(load->buffer)) {
+      load.CopyOnWrite()->buffer = opt.value();
     }
+    return std::move(load);
+  }
 
+  Stmt VisitStmt_(const BufferStoreNode* op) final {
+    BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
+
+    if (auto opt = GetRemappedBuffer(store->buffer)) {
+      store.CopyOnWrite()->buffer = opt.value();
+    }
     return std::move(store);
   }
 
@@ -446,11 +424,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
         PrimExpr val = BufferLoad(buf, zero_indices);
         ICHECK_EQ(val->dtype, types[i]);
         load_remap_[buffers[i]->data.get()] = val;
-        store_remap_[buffers[i].get()] = buf;
         Array<PrimExpr> extents{PrimExpr(1)};
         auto node = Allocate(buf->data, types[i], extents, pred, Evaluate(0));
         alloc_remap_[buffers[i]->data.get()] = node;
         var_remap_[buffers[i]->data.get()] = buf->data;
+        buf_remap_[buffers[i].get()] = buf;
         warp_allocs_.insert(node.get());
       }
     } else {
@@ -489,7 +467,7 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
             Allocate(shared_bufs[idx]->data, types[idx],
                      {PrimExpr(group_extent), PrimExpr(reduce_extent)}, pred, Evaluate(0));
         var_remap_[buffers[idx]->data.get()] = shared_bufs[idx]->data;
-        store_remap_[buffers[idx].get()] = shared_bufs[idx];
+        buf_remap_[buffers[idx].get()] = shared_bufs[idx];
       }
     }
 
@@ -718,8 +696,6 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
   std::vector<const CommReducerNode*> reduce_combiner_;
   // The load remap
   std::unordered_map<const VarNode*, PrimExpr> load_remap_;
-  // The store remap
-  std::unordered_map<const BufferNode*, Buffer> store_remap_;
   // Allocate remap
   std::unordered_map<const VarNode*, Stmt> alloc_remap_;
   // BufferVar remap
diff --git a/src/tir/transforms/update_pointer_storage_scope.cc b/src/tir/transforms/update_pointer_storage_scope.cc
index 157e81d77f..18950bc199 100644
--- a/src/tir/transforms/update_pointer_storage_scope.cc
+++ b/src/tir/transforms/update_pointer_storage_scope.cc
@@ -94,6 +94,11 @@ Buffer UpdatePointerStorageScope::GetUpdatedBuffer(Buffer buf) {
   return buf;
 }
 
+Stmt UpdatePointerStorageScope::VisitStmt_(const DeclBufferNode* op) {
+  auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+  return UpdateBufferAccess(node);
+}
+
 PrimExpr UpdatePointerStorageScope::VisitExpr_(const BufferLoadNode* op) {
   auto node = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
   return UpdateBufferAccess(node);
diff --git a/src/tir/transforms/update_pointer_storage_scope.h b/src/tir/transforms/update_pointer_storage_scope.h
index 5c082d24c4..1f1399fba7 100644
--- a/src/tir/transforms/update_pointer_storage_scope.h
+++ b/src/tir/transforms/update_pointer_storage_scope.h
@@ -41,6 +41,7 @@ class UpdatePointerStorageScope : public StmtExprMutator {
   virtual PrimExpr VisitExpr_(const VarNode*);
   virtual PrimExpr VisitExpr_(const BufferLoadNode*);
   virtual Stmt VisitStmt_(const AllocateNode*);
+  virtual Stmt VisitStmt_(const DeclBufferNode*);
   virtual Stmt VisitStmt_(const BufferStoreNode*);
 
  private:
diff --git a/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
new file mode 100644
index 0000000000..f20d11ffb4
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_lower_thread_all_reduce.py
@@ -0,0 +1,239 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import tvm.testing
+from tvm.script import tir as T
+
+
+class BaseCompare(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.LowerThreadAllreduce()
+
+
+class BaseFailure(BaseCompare):
+    expected = ValueError
+
+
+class TestBasic(BaseCompare):
+    def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer(4096, data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            reduce_data = T.allocate([1], "float32", "local")
+            reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    A_flat[0],
+                    T.bool(True),
+                    reduce[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+    def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer(4096, data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            reduce_data = T.allocate([1], "float32", "local")
+            reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                mask_data = T.allocate([1], "uint32", "local")
+                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+                t0_data = T.allocate([1], "float32", "local")
+                t0 = T.Buffer(1, data=t0_data, scope="local")
+
+                reduce[0] = A_flat[0]
+                mask[0] = T.tvm_warp_activemask()
+
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+
+class TestBasicWithDeclBuffer(BaseCompare):
+    def before(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer(4096, data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    A_flat[0],
+                    T.bool(True),
+                    reduce[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+    def expected(A: T.Buffer((128, 32), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer(4096, data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            reduce = T.decl_buffer(1, dtype="float32", scope="local")
+
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                mask_data = T.allocate([1], "uint32", "local")
+                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+                t0_data = T.allocate([1], "float32", "local")
+                t0 = T.Buffer(1, data=t0_data, scope="local")
+
+                reduce[0] = A_flat[0]
+                mask[0] = T.tvm_warp_activemask()
+
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+
+class TestReduceSummation(BaseCompare):
+    def before(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer((16384,), data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            normal_reduce_data = T.allocate([1], "float32", "local")
+            normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+
+            reduce_data = T.allocate([1], "float32", "local")
+            reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+            normal_reduce[0] = T.float32(0)
+
+            for ko in range(4):
+                normal_reduce[0] = normal_reduce[0] + A_flat[i * 128 + ko * 32 + threadIdx_x]
+
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                T.tvm_thread_allreduce(
+                    T.uint32(1),
+                    normal_reduce[0],
+                    T.bool(True),
+                    reduce[0],
+                    threadIdx_x,
+                )
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+    def expected(A: T.Buffer((128, 128), "float32"), B: T.Buffer(128, "float32")):
+        T.func_attr({"target": T.target("cuda", host="llvm")})
+        A_flat = T.Buffer(16384, data=A.data)
+
+        for i in range(128):
+            threadIdx_x = T.launch_thread("threadIdx.x", 32)
+
+            normal_reduce_data = T.allocate([1], "float32", "local")
+            normal_reduce = T.Buffer(1, data=normal_reduce_data, scope="local")
+
+            reduce_data = T.allocate([1], "float32", "local")
+            reduce = T.Buffer(1, data=reduce_data, scope="local")
+
+            normal_reduce[0] = T.float32(0)
+            for ko in range(4):
+                normal_reduce[0] = normal_reduce[0] + A_flat[i * 128 + ko * 32 + threadIdx_x]
+            with T.attr(
+                T.comm_reducer(lambda x, y: x + y, [T.float32(0)]),
+                "reduce_scope",
+                T.reinterpret("handle", T.uint64(0)),
+            ):
+                mask_data = T.allocate([1], "uint32", "local")
+                mask = T.Buffer(1, "uint32", data=mask_data, scope="local")
+
+                t0_data = T.allocate([1], "float32", "local")
+                t0 = T.Buffer(1, data=t0_data, scope="local")
+
+                reduce[0] = normal_reduce[0]
+                mask[0] = T.tvm_warp_activemask()
+
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 16, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 8, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 4, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 2, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                t0[0] = T.tvm_warp_shuffle_down(mask[0], reduce[0], 1, 32, 32)
+                reduce[0] = reduce[0] + t0[0]
+                reduce[0] = T.tvm_warp_shuffle(mask[0], reduce[0], 0, 32, 32)
+            if threadIdx_x == 0:
+                B[i] = reduce[0]
+
+
+if __name__ == "__main__":
+    tvm.testing.main()