Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2022/11/22 20:33:57 UTC

[GitHub] [tvm] Lunderberg commented on a diff in pull request #13463: [TIR] Fix buffer shape and IndexMap indices dtype mismatch

Lunderberg commented on code in PR #13463:
URL: https://github.com/apache/tvm/pull/13463#discussion_r1029680983


##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -173,6 +173,35 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1
             vi, vj = T.axis.remap("SS", [i, j])
             C[vi, vj] = B[vi, vj] + 1.0
 
+
+
+@tvm.script.ir_module

Review Comment:
   Is this entire definition required for the test case?  As a reader, it's hard to tell which parts of this PrimFunc are needed to trigger the bug.



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -1055,13 +1055,43 @@ class TransformationIntroducesPaddingError : public ScheduleError {
   PrimExpr padding_predicate_;
 };
 
+// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
+// dtype-mismatch issues later.
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+  auto initial_indices_orig = index_map->initial_indices;
+  ICHECK(buf->shape.size() == initial_indices_orig.size());
+
+  Array<Var> initial_indices;
+  Map<Var, PrimExpr> var_map;
+
+  for (size_t i = 0; i < buf->shape.size(); ++i) {
+    if (buf->shape[i]->dtype != initial_indices_orig[i].dtype()) {

Review Comment:
   I think this would produce an incorrect result if only some of the index dtypes mismatch. In that case, `initial_indices` would contain only the variables whose dtype was changed, but it must have the same size as `initial_indices_orig`. Indices whose dtype already matches should be kept as they are:
   
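   ```c++
   if (buf->shape[i]->dtype == initial_indices_orig[i].dtype()) {
       // dtype already matches: keep the original index so the sizes stay aligned.
       initial_indices.push_back(initial_indices_orig[i]);
   } else {
       // dtype mismatch: create a replacement index with the buffer's dtype.
       auto new_idx = ...
   }
   ```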
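
A minimal sketch of the suggested loop, written in Python rather than C++ so that it runs standalone. `Var` and `legalize_index_dtypes` are hypothetical stand-ins, not TVM APIs, and dtypes are plain strings; TVM compares variables by identity rather than by name and dtype, which this sketch simplifies. What it illustrates is that every position produces exactly one output index, so the result always has the same length as the input, and only mismatched indices get an entry in the substitution map.

```python
from dataclasses import dataclass
from typing import Dict, List, Tuple


@dataclass(frozen=True)
class Var:
    """Hypothetical stand-in for a TIR variable: just a name and a dtype string."""
    name: str
    dtype: str


def legalize_index_dtypes(
    indices: List[Var], shape_dtypes: List[str]
) -> Tuple[List[Var], Dict[Var, Var]]:
    """Return indices whose dtypes match the buffer shape, plus an old->new substitution map."""
    # Mirrors the ICHECK in the diff: one index per buffer dimension.
    assert len(indices) == len(shape_dtypes)
    new_indices: List[Var] = []
    var_map: Dict[Var, Var] = {}
    for index, dtype in zip(indices, shape_dtypes):
        if index.dtype == dtype:
            # dtype already matches: keep the original index in its position.
            new_indices.append(index)
        else:
            # dtype mismatch: replace the index and remember the substitution.
            new_index = Var(index.name, dtype)
            new_indices.append(new_index)
            var_map[index] = new_index
    return new_indices, var_map


# Only the first index mismatches; the output still has one entry per dimension.
indices = [Var("n", "int32"), Var("c", "int64")]
new_indices, var_map = legalize_index_dtypes(indices, ["int64", "int64"])
print(new_indices)  # [Var(name='n', dtype='int64'), Var(name='c', dtype='int64')]
print(var_map)      # {Var(name='n', dtype='int32'): Var(name='n', dtype='int64')}
```

The substitution map is presumably what the `var_map` declared in the diff is for, i.e. rewriting the index map's expressions in terms of the new indices; that step is outside this sketch.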



##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -925,5 +954,24 @@ def expected(a: T.handle):
                 A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
 
 
+def test_index_map_dtype_legalize():
+    """Test dtype legalization of the index map indices."""
+
+    def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
+        return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32]
+
+    sch = tir.Schedule(Conv2dNCHW32c)
+
+    conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+    sch.cache_read(conv2d_block, 0, "global.vtcm")
+
+    # The following error is raised from the IterVar constructor without the dtype legalization.
+    # TVMError: Check failed: dom->extent.dtype() == var.dtype() (int64 vs. int32) :
+    # The dtype of the extent of an IterVar (int64) must match its associated Var's dtype (int32)
+    sch.transform_layout(
+        conv2d_block, ("read", 0), index_map=index_map_nchw32c_nchw8h8w32c, pad_value=0

Review Comment:
   Instead of `("read", 0)`, the buffer can be specified by name, `buffer = "data_pad_global_vtcm"`.
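
For reference, the index map in this test is an ordinary Python function, so its arithmetic can be checked on plain integers. The sketch below reuses its body and adds two hypothetical helpers, `transformed_shape` and `needs_padding` (illustrative only, not TVM APIs), assuming the H and W extents of the new layout are rounded up to a multiple of 8. When H or W is not a multiple of 8, the new layout has more positions than the input, and `pad_value` supplies the value for those padded positions.

```python
def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
    # Same body as in the test: split H and W into (outer, inner) pairs of 8.
    return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32]


def ceil_div(a, b):
    # Integer ceiling division for non-negative a and positive b.
    return -(-a // b)


def transformed_shape(n, c, h, w, c32):
    """Hypothetical helper: extents of the 7-D layout, rounding H and W up to multiples of 8."""
    return [n, c, ceil_div(h, 8), ceil_div(w, 8), 8, 8, c32]


def needs_padding(h, w):
    """Hypothetical helper: True if the 7-D layout has more H or W positions than the input."""
    return h % 8 != 0 or w % 8 != 0


print(index_map_nchw32c_nchw8h8w32c(0, 1, 13, 9, 5))  # [0, 1, 1, 1, 5, 1, 5]
print(transformed_shape(1, 2, 13, 9, 32))             # [1, 2, 2, 2, 8, 8, 32]
print(needs_padding(13, 9), needs_padding(16, 8))     # True False
```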



##########
src/tir/schedule/primitive/layout_transformation.cc:
##########
@@ -1055,13 +1055,43 @@ class TransformationIntroducesPaddingError : public ScheduleError {
   PrimExpr padding_predicate_;
 };
 
+// Make the dtypes of indices in IndexMap be the same as the dtype of the buffer shape, to avoid
+// dtype-mismatch issues later.
+IndexMap LegalizeIndexMapDType(const IndexMap& index_map, const Buffer& buf) {
+  auto initial_indices_orig = index_map->initial_indices;

Review Comment:
   Nit: `const auto&` instead of `auto`.



##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -173,6 +173,35 @@ def two_elementwise_unit_dim(A: T.Buffer[(1, 128), "float32"], C: T.Buffer[(1, 1
             vi, vj = T.axis.remap("SS", [i, j])
             C[vi, vj] = B[vi, vj] + 1.0
 
+
+
+@tvm.script.ir_module
+class Conv2dNCHW32c:
+    @T.prim_func

Review Comment:
   Placing `@T.prim_func` here means the function is parsed while the tests are being collected. That makes failures difficult to troubleshoot: an error while parsing or constructing the PrimFunc would prevent any unit test from running, not just the tests that use it. It would be better to apply `@T.prim_func` inside the unit test itself, so that a failure to construct the PrimFunc causes only that single test to fail.
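
The effect is not specific to TVMScript: any decorator that does eager work runs when its module is imported, which is when pytest collects tests. The sketch below uses a hypothetical `fake_prim_func` decorator (not a TVM API) that raises for functions whose name starts with `broken`, standing in for a PrimFunc that fails to parse, and loads two small module sources with `exec`. With the decorator at module level, loading the module fails outright; with it inside a test, only that test fails, and only when it runs.

```python
import types


def fake_prim_func(func):
    """Hypothetical stand-in for @T.prim_func: does its work eagerly at decoration time."""
    if func.__name__.startswith("broken"):
        raise ValueError(f"failed to parse {func.__name__}")
    return func


# Decorated at module level: runs as soon as the module is imported (collected).
MODULE_LEVEL = '''
@fake_prim_func
def broken_func():
    pass

def test_a():
    pass
'''

# Decorated inside a test: runs only when that test is executed.
IN_TEST = '''
def test_a():
    pass

def test_b():
    @fake_prim_func
    def broken_func():
        pass
'''


def load(source):
    """Execute the source as a fresh module, roughly as an import during collection would."""
    module = types.ModuleType("test_module")
    module.fake_prim_func = fake_prim_func
    exec(source, module.__dict__)
    return module


try:
    load(MODULE_LEVEL)
except ValueError as err:
    print("import failed:", err)  # import failed: failed to parse broken_func

module = load(IN_TEST)
module.test_a()  # unaffected by the broken function in test_b
try:
    module.test_b()
except ValueError as err:
    print("test_b failed:", err)  # test_b failed: failed to parse broken_func
```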



##########
tests/python/unittest/test_tir_schedule_transform_layout.py:
##########
@@ -925,5 +954,24 @@ def expected(a: T.handle):
                 A[i, j] = T.if_then_else(i == 3 and 2 <= j, 0, 42, dtype="int32")
 
 
+def test_index_map_dtype_legalize():
+    """Test dtype legalization of the index map indices."""
+
+    def index_map_nchw32c_nchw8h8w32c(n_batch, channel, height, width, channel_32):
+        return [n_batch, channel, height // 8, width // 8, height % 8, width % 8, channel_32]
+
+    sch = tir.Schedule(Conv2dNCHW32c)
+
+    conv2d_block = sch.get_block("conv2d_NCHWc_int8")
+    sch.cache_read(conv2d_block, 0, "global.vtcm")

Review Comment:
   Why does the test case call `cache_read`, instead of defining the PrimFunc directly in the form that is passed to `transform_layout`?



-- 
This is an automated message from the Apache Git Service.