Posted to commits@tvm.apache.org by cs...@apache.org on 2022/04/12 15:27:25 UTC

[tvm] branch main updated: [Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (#10905)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 11d22bdc1b [Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (#10905)
11d22bdc1b is described below

commit 11d22bdc1bd45d952eb140684e64f01438b7f516
Author: Eric Lunderberg <Lu...@users.noreply.github.com>
AuthorDate: Tue Apr 12 10:27:20 2022 -0500

    [Hexagon][LLVM] Enable/test tensorized Hexagon DMA on 2d transformed layout (#10905)
    
    * [Hexagon][LLVM] Enable/test tensorized Hexagon DMA
    
    - In the `CodeGenLLVM::CreateIntrinsic` handler for
      `builtin::address_of()`, pass N-d indices to
      `CodeGenLLVM::CreateBufferPtr`.  The base class implementation still
      asserts that there is a flat memory space, while the
      `CodeGenHexagon::CreateBufferPtr` override allows 2-d memory.
    
    - Enable tensorization in `test_cache_read_write.py`, using
      `tir.address_of` to pass the lowered value.
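
    A minimal Python-level sketch of both changes above (condensed from the
    test diff below; the buffer shape and names are illustrative):

        import tvm
        from tvm import tir

        # A 2-d buffer in a non-flat memory space (e.g. Hexagon VTCM).
        buf = tir.decl_buffer((8, 16), "int8", name="buf", scope="global.vtcm")

        ib = tir.ir_builder.create()
        handle = ib.buffer_ptr(buf)

        # `tir.address_of` now accepts a BufferLoad with N-d indices; during
        # codegen this reaches CodeGenHexagon::CreateBufferPtr rather than
        # failing the base class's flat-memory assertion.
        zero_indices = [0, 0]
        addr = tir.call_intrin("handle", "tir.address_of", handle[zero_indices])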
    
    Co-authored-by: Adam Straw <as...@octoml.ai>
    
    * [TIR] Allow buffer_bind_scope of N-d buffers
    
    Previously, any `buffer_bind_scope` attribute that provides a view
    into a non-flat buffer would result in an error.  After this commit,
    `buffer_bind_scope` may be used for non-flat buffers, but use of
    `arg_buffer->elem_offset` within the body of the bind statement is
    still an error.
    
    The `BufferNode::elem_offset` field represents the offset between the
    pointer of the backing allocation and the first element of the buffer.
    This offset is only well-defined for flat memory spaces.
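
    For example, in a flat buffer of shape (4, 16), the offset of element
    (i, j) from the allocation base is the single scalar `i * 16 + j`.  With
    an axis separator between the two axes, the backing allocation is itself
    2-d and no such scalar exists, so the slice's `elem_offset` is left
    undefined.  A minimal sketch (assuming `decl_buffer`'s `axis_separators`
    parameter; names are illustrative):

        import tvm
        from tvm import tir

        flat = tir.decl_buffer((4, 16), "int8", name="flat")
        print(flat.elem_offset)  # a single scalar PrimExpr (0 by default)

        # With an axis separator, a slice would need one offset per axis
        # group; MakeSlice now marks elem_offset as undefined rather than
        # raising an error.
        nonflat = tir.decl_buffer((4, 16), "int8", name="nonflat",
                                  axis_separators=[1])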
    
    * update test to tensorize cache_read `y` (works) and cache_write `z` (fails)
    
    * add `split` to allow for tensorization of cache_write of `z`
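
    A condensed sketch of that scheduling step (mirroring the 2-d test added
    below; sizes are illustrative):

        import tvm
        from tvm import te

        factor = 16
        x = te.placeholder((128,), dtype="int8", name="x")
        z = te.compute((128,), lambda i: x[i], name="z")
        s = te.create_schedule(z.op)
        z_vtcm = s.cache_write(z, "global.vtcm")
        s[z_vtcm].transform_layout(
            lambda n: [n // factor, te.AXIS_SEPARATOR, n % factor]
        )

        # transform_layout on z_vtcm leaves the loop nest over z itself
        # unchanged, so split z's loop to expose an inner extent of `factor`
        # that a memory-copy intrinsic can be tensorized over.
        zouter, zinner = s[z].split(z.op.axis[0], factor=factor)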
    
    * fix typo and clean up comment
    
    * add back original 1d test_cache_read_write
    
    * update comments
    
    * format error
    
    Co-authored-by: Adam Straw <as...@octoml.ai>
---
 src/target/llvm/codegen_llvm.cc                    |  16 ++-
 src/tir/ir/buffer.cc                               |  17 ++-
 src/tir/transforms/arg_binder.cc                   |  33 +++---
 src/tir/transforms/storage_flatten.cc              |  10 ++
 .../contrib/test_hexagon/test_cache_read_write.py  | 125 +++++++++++++++------
 5 files changed, 143 insertions(+), 58 deletions(-)

diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index bacfbc9947..8cd8a5199d 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -1006,13 +1006,19 @@ llvm::Value* CodeGenLLVM::CreateIntrinsic(const CallNode* op) {
   } else if (op->op.same_as(builtin::address_of())) {
     const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
     ICHECK(op->args.size() == 1 && load);
-    ICHECK_EQ(load->indices.size(), 1) << "LLVM only supports flat memory allocations.";
-    PrimExpr index = load->indices[0];
-    if (const RampNode* r = index.as<RampNode>()) {
-      index = r->base;
+
+    Array<PrimExpr> indices = load->indices;
+    if (const RampNode* r = indices[indices.size() - 1].as<RampNode>()) {
+      indices.Set(indices.size() - 1, r->base);
+    }
+
+    std::vector<llvm::Value*> indices_val;
+    for (const auto& index : indices) {
+      indices_val.push_back(MakeValue(index));
     }
+
     TypedPointer buffer_ptr = CreateBufferPtr(MakeValue(load->buffer->data), load->buffer->dtype,
-                                              {MakeValue(index)}, load->dtype);
+                                              indices_val, load->dtype);
     unsigned addrspace =
         llvm::dyn_cast<llvm::PointerType>(buffer_ptr.addr->getType())->getAddressSpace();
     return builder_->CreatePointerCast(buffer_ptr.addr, t_char_->getPointerTo(addrspace));
diff --git a/src/tir/ir/buffer.cc b/src/tir/ir/buffer.cc
index ffeb4c0128..9cc92bd17e 100644
--- a/src/tir/ir/buffer.cc
+++ b/src/tir/ir/buffer.cc
@@ -460,7 +460,6 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
   begins = SimplifyArray(&ana, begins);
   Array<PrimExpr> elem_offset = n->ElemOffset(begins);
   elem_offset.MutateByApply([&](const PrimExpr& expr) { return ana.Simplify(expr); });
-  ICHECK_EQ(elem_offset.size(), 1) << "MakeSlice currently supports only flat 1-d memory.";
 
   Array<PrimExpr> strides = n->strides;
   if (strides.size() == 0) {
@@ -480,8 +479,20 @@ Buffer Buffer::MakeSlice(Array<PrimExpr> begins, Array<PrimExpr> extents) const
       return MakeStrideView().MakeSlice(begins, extents);
     }
   }
-  return Buffer(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
-                n->data_alignment, 0, n->buffer_type);
+  Buffer slice(n->data, n->dtype, extents, strides, elem_offset[0], n->name + "_slice",
+               n->data_alignment, 0, n->buffer_type);
+
+  // A Buffer must be constructed with a single scalar element offset, which means
+  // there is no support for n-dimensional buffers where n > 1.  Insert a sentinel
+  // value so that ArgBinder::BindBuffer treats any usage of the element offset as
+  // invalid in this case.  This allows a Buffer to be constructed with multiple
+  // element offsets, but disallows any usage of those element offsets.  See
+  // PR #10816 for discussion on supporting multiple element offsets in TIR Buffer.
+  // TODO(Lunderberg): Remove if/when TIR supports multiple element offsets in Buffer.
+  if (elem_offset.size() != 1) {
+    slice.CopyOnWrite()->elem_offset = PrimExpr();
+  }
+  return slice;
 }
 
 PrimExpr Buffer::access_ptr(int access_mask, DataType ptr_type, int content_lanes,
diff --git a/src/tir/transforms/arg_binder.cc b/src/tir/transforms/arg_binder.cc
index d7cd731a3d..2fc3bd2dca 100644
--- a/src/tir/transforms/arg_binder.cc
+++ b/src/tir/transforms/arg_binder.cc
@@ -96,22 +96,25 @@ void ArgBinder::BindBuffer(const Buffer& arg, const Buffer& value, const std::st
                  << " required_alignment=" << arg->data_alignment
                  << ", provided_alignment=" << value->data_alignment;
   }
-  // bind pointer and offset.
-  if (is_zero(arg->elem_offset)) {
-    ICHECK(is_zero(value->elem_offset))
-        << "Trying to bind a Buffer with offset into one without offset "
-        << " required elem_offset=" << arg->elem_offset
-        << ", provided elem_offset=" << value->elem_offset;
-  }
 
-  this->Bind(arg->data, value->data, arg_name + ".data");
-  if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
-    if (arg->offset_factor > 1) {
-      PrimExpr offset = value->elem_offset;
-      PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
-      PrimExpr zero = make_zero(offset.dtype());
-      BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
-                      &asserts_);
+  if (value->elem_offset.defined()) {
+    // bind pointer and offset.
+    if (is_zero(arg->elem_offset)) {
+      ICHECK(is_zero(value->elem_offset))
+          << "Trying to bind a Buffer with offset into one without offset "
+          << " required elem_offset=" << arg->elem_offset
+          << ", provided elem_offset=" << value->elem_offset;
+    }
+
+    this->Bind(arg->data, value->data, arg_name + ".data");
+    if (Bind_(arg->elem_offset, value->elem_offset, arg_name + ".elem_offset", false)) {
+      if (arg->offset_factor > 1) {
+        PrimExpr offset = value->elem_offset;
+        PrimExpr factor = make_const(offset.dtype(), arg->offset_factor);
+        PrimExpr zero = make_zero(offset.dtype());
+        BinderAddAssert(&analyzer_, truncmod(offset, factor) == zero, arg_name + ".elem_offset",
+                        &asserts_);
+      }
     }
   }
 
diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 0923517634..f97f91a1e5 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -887,6 +887,9 @@ class BufferBindUnwrapper : public StmtExprMutator {
   }
 
   PrimExpr VisitExpr_(const VarNode* op) final {
+    ICHECK(!illegal_vars_.count(op)) << "Variable " << op->name_hint << " is not well defined.  "
+                                     << "(e.g. use of buffer.elem_offset for a non-flat buffer)";
+
     auto it = var_remap_.find(op);
     if (it != var_remap_.end()) {
       return it->second;
@@ -1110,6 +1113,11 @@ class BufferBindUnwrapper : public StmtExprMutator {
     // transformations should have been handled in
     // BufferShapeLegalize.
     binder.BindBuffer(source, view, source->name, false);
+    if (auto* elem_offset_var = source->elem_offset.as<VarNode>()) {
+      if (!view->elem_offset.defined()) {
+        illegal_vars_.insert(elem_offset_var);
+      }
+    }
 
     // Apply the remaps
     Stmt body = op->body;
@@ -1162,6 +1170,8 @@ class BufferBindUnwrapper : public StmtExprMutator {
   // The buffer assignment map
   // Variable remap
   std::unordered_map<const VarNode*, PrimExpr> var_remap_;
+  // Variables that must not occur within the body.
+  std::unordered_set<const VarNode*> illegal_vars_;
   // Buffer map
   std::unordered_map<const BufferNode*, BufferEntry> buf_map_;
   // Set of vars that have occurred in an AllocateNode, but haven't
diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/test_cache_read_write.py
index 6bcd852424..e5595485a2 100644
--- a/tests/python/contrib/test_hexagon/test_cache_read_write.py
+++ b/tests/python/contrib/test_hexagon/test_cache_read_write.py
@@ -28,7 +28,6 @@ from .conftest import requires_hexagon_toolchain
 
 
 def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
-    assert len(shape) == 1
     src = te.placeholder(shape=shape, dtype=dtype, name="src")
     dst = te.compute(shape, lambda i: src[i], name="dst")
     size = shape[0] * np.dtype(dtype).itemsize
@@ -38,6 +37,7 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
         dtype,
         scope=src_scope,
         offset_factor=1,
+        name="mem_copy_src_buffer",
     )
 
     dst_buffer = tvm.tir.decl_buffer(
@@ -45,16 +45,27 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
         dtype,
         scope=dst_scope,
         offset_factor=1,
+        name="mem_copy_dst_buffer",
     )
 
+    zero_indices = [0 for _ in shape]
+
     def intrin_func(ins, outs):
         ib = tvm.tir.ir_builder.create()
 
         _src = ins[0]
         _dst = outs[0]
+
+        dst_handle = ib.buffer_ptr(dst_buffer)
+        src_handle = ib.buffer_ptr(src_buffer)
+
         ib.emit(
             tvm.tir.call_intrin(
-                "handle", "tir.mem_copy", _dst.access_ptr("w"), _src.access_ptr("r"), size
+                "handle",
+                "tir.mem_copy",
+                tvm.tir.call_intrin("handle", "tir.address_of", dst_handle[zero_indices]),
+                tvm.tir.call_intrin("handle", "tir.address_of", src_handle[zero_indices]),
+                size,
             )
         )
         return ib.get()
@@ -62,6 +73,36 @@ def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
     return te.decl_tensor_intrin(dst.op, intrin_func, binds={src: src_buffer, dst: dst_buffer})
 
 
+def verify(hexagon_session, s, x, y, z, size):
+    print(tvm.lower(s, [x, y, z]))
+
+    target_hexagon = tvm.target.hexagon("v68", link_params=True)
+    func = tvm.build(
+        s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
+    )
+
+    if hexagon_session is None:
+        pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
+
+    mod = hexagon_session.load_module(func)
+    xt = tvm.nd.array(
+        np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
+        device=hexagon_session.device,
+    )
+    yt = tvm.nd.array(
+        np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
+        device=hexagon_session.device,
+    )
+    zt = tvm.nd.array(
+        np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
+        device=hexagon_session.device,
+    )
+    mod["dmacpy"](xt, yt, zt)
+
+    ref = xt.numpy() + yt.numpy()
+    np.testing.assert_equal(zt.numpy(), ref)
+
+
 @requires_hexagon_toolchain
 def test_cache_read_write(hexagon_session):
     size = 128
@@ -75,52 +116,66 @@ def test_cache_read_write(hexagon_session):
     z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
     s = te.create_schedule(z.op)
 
-    x_global = s.cache_read(x, "global.vtcm", [z])
-    y_global = s.cache_read(y, "global.vtcm", [z])
-    z_global = s.cache_write(z, "global.vtcm")
+    x_vtcm = s.cache_read(x, "global.vtcm", [z])
+    y_vtcm = s.cache_read(y, "global.vtcm", [z])
+    z_vtcm = s.cache_write(z, "global.vtcm")
 
-    zouter, zinner = s[z_global].split(z_global.op.axis[0], factor=factor)
+    zouter, zinner = s[z_vtcm].split(z_vtcm.op.axis[0], factor=factor)
 
-    s[x_global].compute_at(s[z_global], zouter)
-    s[y_global].compute_at(s[z_global], zouter)
+    s[x_vtcm].compute_at(s[z_vtcm], zouter)
+    s[y_vtcm].compute_at(s[z_vtcm], zouter)
 
     mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
 
-    (cache_read_x,) = s[x_global].op.axis
-    s[x_global].tensorize(cache_read_x, mem_copy_read)
+    (cache_read_x,) = s[x_vtcm].op.axis
+    s[x_vtcm].tensorize(cache_read_x, mem_copy_read)
 
-    (cache_read_y,) = s[y_global].op.axis
-    s[y_global].tensorize(cache_read_y, mem_copy_read)
+    (cache_read_y,) = s[y_vtcm].op.axis
+    s[y_vtcm].tensorize(cache_read_y, mem_copy_read)
 
     mem_copy_write = intrin_mem_copy(outer_shape, dtype, "global", "global.vtcm")
 
     (cache_write_z,) = s[z].op.axis
     s[z].tensorize(cache_write_z, mem_copy_write)
 
-    print(tvm.lower(s, [x, y, z]))
+    verify(hexagon_session, s, x, y, z, size)
 
-    target_hexagon = tvm.target.hexagon("v68", link_params=True)
-    func = tvm.build(
-        s, [x, y, z], tvm.target.Target(target_hexagon, host=target_hexagon), name="dmacpy"
-    )
 
-    if hexagon_session is None:
-        pytest.skip("Skip hardware test since ANDROID_SERIAL_NUMBER is not set.")
+def layout_transform_2d(n):
+    return [n // 16, te.AXIS_SEPARATOR, n % 16]
 
-    mod = hexagon_session.load_module(func)
-    xt = tvm.nd.array(
-        np.random.randint(low=-128, high=127, size=size, dtype=x.dtype),
-        device=hexagon_session.device,
-    )
-    yt = tvm.nd.array(
-        np.random.randint(low=-128, high=127, size=size, dtype=y.dtype),
-        device=hexagon_session.device,
-    )
-    zt = tvm.nd.array(
-        np.random.randint(low=-128, high=127, size=size, dtype=z.dtype),
-        device=hexagon_session.device,
-    )
-    mod["dmacpy"](xt, yt, zt)
 
-    ref = xt.numpy() + yt.numpy()
-    np.testing.assert_equal(zt.numpy(), ref)
+@requires_hexagon_toolchain
+def test_cache_read_write_2d(hexagon_session):
+    size = 128
+    outer_shape = (size,)
+    factor = 16
+    inner_shape = (factor,)
+    dtype = "int8"
+
+    x = te.placeholder(shape=outer_shape, dtype=dtype, name="x")
+    y = te.placeholder(shape=outer_shape, dtype=dtype, name="y")
+    z = te.compute(outer_shape, lambda i: x[i] + y[i], name="z")
+    s = te.create_schedule(z.op)
+
+    x_vtcm = s.cache_read(x, "global.vtcm", [z])
+    y_vtcm = s.cache_read(y, "global.vtcm", [z])
+    z_vtcm = s.cache_write(z, "global.vtcm")
+
+    layout_x_vtcm = s[x_vtcm].transform_layout(layout_transform_2d)
+    layout_y_vtcm = s[y_vtcm].transform_layout(layout_transform_2d)
+    layout_z_vtcm = s[z_vtcm].transform_layout(layout_transform_2d)
+
+    mem_copy_read = intrin_mem_copy(inner_shape, dtype, "global.vtcm", "global")
+    s[x_vtcm].tensorize(layout_x_vtcm[1], mem_copy_read)
+    s[y_vtcm].tensorize(layout_y_vtcm[1], mem_copy_read)
+
+    # Calling `transform_layout` on `z_vtcm` above does not modify the loop
+    # schedule over `z`, so we must `split` the loop over `z` to match the
+    # layout of `z_vtcm`; this lets us accurately write `z_vtcm` back to `z`
+    # using the memory copy intrinsic.
+    zouter, zinner = s[z].split(z.op.axis[0], factor=factor)
+    mem_copy_write = intrin_mem_copy(inner_shape, dtype, "global", "global.vtcm")
+    s[z].tensorize(zinner, mem_copy_write)
+
+    verify(hexagon_session, s, x, y, z, size)