Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/10/19 23:16:23 UTC

[GitHub] [tvm] adstraw commented on a change in pull request #9287: Adjust Hexagon conv2d schedule to split channel out (k) and move to outer loop

adstraw commented on a change in pull request #9287:
URL: https://github.com/apache/tvm/pull/9287#discussion_r732307896



##########
File path: tests/python/contrib/test_hexagon/README.md
##########
@@ -118,173 +128,220 @@ primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> ()
             for (wi.c: int32, 0, 8) {
               for (ki.c: int32, 0, 32) {
                 for (rc.inner: int32, 0, 32) {
-                  output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
+                  output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
                   (
-                    (float32*)output.cache[(((((wo.c*4096) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
+                    (float32*)output.cache[((((wo.c*2048) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
                     (
                       (float32*)input.cache[(((((wo.c*4096) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] *
-                      (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))]
+                      (float32*)filter.cache[((((rc.outer*1024) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))]
                     )
                   )
                 }
               }
             }
           }
-        } // end rc.outer
-      } // end ko.c
-    } // end wo.c
+        }
+      } // end wo.c
 
-    // cache write
-    for (wo: int32, 0, 8) {
-      for (ko: int32, 0, 2) {
+      // cache write
+      for (wo: int32, 0, 8) {
         for (hi: int32, 0, 8) {
           for (wi: int32, 0, 8) {
             for (ki: int32, 0, 32) {
-              output_pointer[((((((ho.outer*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = 
-                (float32*)output.cache[(((((wo*4096) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)]
+              output_pointer[((((((ho.outer*65536) + (wo*8192)) + (ko.outer*2048)) + (hi*256)) + (wi*32)) + ki)] = 
+                (float32*)output.cache[((((wo*2048) + (hi*256)) + (wi*32)) + ki)]
             }
           }
         }
       }
-    }
-  }
+    } // end ho.outer
+  } // end ko.outer
 }
 ```
 
-# Split on Height - "Full Output Slice"
+# Split on Channel Out and Height - "Full Output Slice"
 
-Adds a new parameter `h_split` which creates a loop split on the height `h` dimension.  The cache reads and writes are moved to the outer of the two loops created by that split - the loop over `ho.outer`.  This increases cache usage by a factor equivalent to `h_split`.  The compute is still "full width" and "full depth" in the channel-out dimension and now over multiple slices in the height `h` dimension.  
+Adds new parameters `k_split` and `h_split`, which create loop splits on the outer channel-out `ko` and height `ho` loops, producing `outer` and `inner` loops for each split.  The cache reads and writes are computed at `ho.outer`, which means that the cache allocations grow in proportion to the `k_split` and `h_split` factors.
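
The loop-split semantics described above can be sketched in plain Python (this is not the schedule code from the PR, just the index arithmetic a split with a given factor performs; the extents assume the 64x64 spatial, 128-output-channel workload from the parameter table below):

```python
# Sketch: recombining split loop variables covers every (ko, ho) pair
# exactly once. k_split and h_split are the split factors; KO and HO
# are the blocked channel-out and height extents (assumed values).
k_split, h_split = 2, 2
KO, HO = 4, 8

visited = []
for ko_outer in range(KO // k_split):
    for ho_outer in range(HO // h_split):
        # cache reads/writes happen here, once per (ko.outer, ho.outer)
        for ko_inner in range(k_split):
            for ho_inner in range(h_split):
                ko = ko_outer * k_split + ko_inner
                ho = ho_outer * h_split + ho_inner
                visited.append((ko, ho))

# every (ko, ho) pair is visited exactly once
assert sorted(visited) == [(k, h) for k in range(KO) for h in range(HO)]
```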
 
-The key changes in TIR versus the baseline are ...
+The key changes in TIR versus the above are...
 
 1) Increased cache allocations:
 
 ```
+  // input cache grows by factor of h_split = 2
   allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global;
+
+  // filter cache grows by factor of k_split = 2
+  allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global;
+
+  // output cache grows by factor of h_split * k_split = 4
   allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global;
 ```
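
As a sanity check (not part of the diff), the three allocation sizes above follow directly from the blocking parameters; the block sizes and workload values here are assumed from the layouts and parameter table elsewhere in this README:

```python
# One NHWC8h8w32c block holds 8*8*32 float32 values; one OIHW8i32o4i
# block (1x1 spatial) holds 8*32*4 float32 values.
SPATIAL_BLOCK = 8 * 8 * 32   # 2048
FILTER_BLOCK = 8 * 32 * 4    # 1024

h_split, k_split = 2, 2
wo_blocks, ci_blocks = 8, 2  # width and channel-in blocks held in cache

input_cache = h_split * wo_blocks * ci_blocks * SPATIAL_BLOCK
filter_cache = k_split * ci_blocks * FILTER_BLOCK
output_cache = h_split * k_split * wo_blocks * SPATIAL_BLOCK

assert (input_cache, filter_cache, output_cache) == (65536, 4096, 65536)
```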
 
-2) The loop split on the `h` dimension:
+2) Outer loop splits using the `k_split` and `h_split` factors:
 
 ```
-  for (ho.outer: int32, 0, 4) {
-    for (ho.inner: int32, 0, 2) {
+  // ko.outer = outer loop split on ko using k_split factor
+  for (ko.outer: int32, 0, 2) {
+    // ho.outer = outer loop split on ho using h_split factor
+    for (ho.outer: int32, 0, 4) {
+```
+
+3) Inner loop splits in both the cache read / write and compute schedules.  The following is taken from the compute schedule:
+```
+      for (ko.c.inner: int32, 0, 2) {
+        for (ho.c.inner: int32, 0, 2) {
 ```
 
 ## Command
 
-pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-1-64-64-64-llvm]"
+pytest -sv "tests/python/contrib/test_hexagon/test_conv2d_blocked.py::TestConv2dPackedFilter::test_conv2d[conv2d_packed_filter-1-1-0-float32-2-2-1-64-64-128-llvm]"
 
 ## Parameters
 
 | Parameter | Value       |
 | --------- | ----------- |
 | Batch     | 1           |
-| Kernel    | 1x1         |
+| Filter    | 1x1         |
 | Spatial   | 64x64       |
 | Input Ch  | 64          |
-| Output Ch | 64          |
+| Output Ch | 128         |
 | Stride    | 1           |
 | Padding   | 0           |
 | Layout    | NHWC8h8w32c |
+| k_split   | 2           |
 | h_split   | 2           |
 
 ## Assumptions
 
-Same as baseline
+* n/a - With the loop splits on `ko` and `ho`, the compute schedule now iterates over `ko.inner`, `ho.inner`, `wo`, etc.  This should fit the pattern matching for microkernels.
 
 ## To Do
 
-Same as baseline
+* n/a
 
 ## Annotated TIR
 
 ```
-primfn(input_handle: handle, kernel_handle: handle, output_handle: handle) -> ()
+primfn(input_handle: handle, filter_handle: handle, output_handle: handle) -> ()
   attr = {"from_legacy_te_schedule": True, "global_symbol": "default_function", "tir.noalias": True, "target": meta[Target][0]}
-  buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 2, 8, 8, 32], []),
-             kernel_buffer: Buffer(kernel_pointer: Pointer(float32), float32, [2, 2, 1, 1, 8, 32, 4], []),
-             input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])}
-  buffer_map = {input_handle: input_buffer, kernel_handle: kernel_buffer, output_handle: output_buffer} {
-  
-  // increased cache usage due to h_split parameter
+  buffers = {output_buffer: Buffer(output_pointer: Pointer(float32), float32, [1, 8, 8, 4, 8, 8, 32], []), // NHWC8h8w32c
+             filter_buffer: Buffer(filter_pointer: Pointer(float32), float32, [4, 2, 1, 1, 8, 32, 4], []), // OIHW8i32o4i
+             input_buffer: Buffer(input_pointer: Pointer(float32), float32, [1, 64, 64, 64], [])} // NHWC (pending RFC)
+  buffer_map = {input_handle: input_buffer, filter_handle: filter_buffer, output_handle: output_buffer} {
+
+  // input cache grows by factor of h_split = 2
   allocate(input.cache: Pointer(global float32), float32, [65536]), storage_scope = global;
+
+  // filter cache grows by factor of k_split = 2
+  allocate(filter.cache: Pointer(global float32), float32, [4096]), storage_scope = global;
+
+  // output cache grows by factor of h_split * k_split = 4
   allocate(output.cache: Pointer(global float32), float32, [65536]), storage_scope = global;
+  
+  // ko.outer = outer loop split on ko using k_split factor
+  for (ko.outer: int32, 0, 2) {
+    // ho.outer = outer loop split on ho using h_split factor
+    for (ho.outer: int32, 0, 4) {
+
+      // input cache read
+      // NHWC -> NHWC8h8w32c (pending RFC)
+      for (ho.inner: int32, 0, 2) {
+        for (wo: int32, 0, 8) {
+          for (co: int32, 0, 2) {
+            for (hi: int32, 0, 8) {
+              for (wi: int32, 0, 8) {
+                for (ci: int32, 0, 32) {
+                  input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = 
+                    (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)]
+                }
+              }
+            }
+          }
+        }
+      } // end ho.inner
 
-  // loop split ho.outer vs. ho.inner based on h_split parameter
-  for (ho.outer: int32, 0, 4) {
-    for (ho.inner: int32, 0, 2) {
-      for (wo: int32, 0, 8) {
+      // filter cache read
+      for (ko.inner: int32, 0, 2) {
         for (co: int32, 0, 2) {
-          for (hi: int32, 0, 8) {
-            for (wi: int32, 0, 8) {
-              for (ci: int32, 0, 32) {
-                input.cache[((((((ho.inner*32768) + (wo*4096)) + (co*2048)) + (hi*256)) + (wi*32)) + ci)] = 
-                  (float32*)input_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (hi*4096)) + (wo*512)) + (wi*64)) + (co*32)) + ci)]
+          for (ci8: int32, 0, 8) {
+            for (ki: int32, 0, 32) {
+              for (ci4: int32, 0, 4) {
+                filter.cache[(((((ko.inner*2048) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)] = 
+                  (float32*)filter_pointer[((((((ko.outer*4096) + (ko.inner*2048)) + (co*1024)) + (ci8*128)) + (ki*4)) + ci4)]
               }
             }
           }
         }
-      }
-    }
-    for (ho.c.inner: int32, 0, 2) {
-      for (wo.c: int32, 0, 8) {
-        for (ko.c: int32, 0, 2) {
-          for (hi.c.init: int32, 0, 8) {
-            for (wi.c.init: int32, 0, 8) {
-              for (ki.c.init: int32, 0, 32) {
-                output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
+      } // end ko.inner
+
+      // compute
+      for (ko.c.inner: int32, 0, 2) {
+        for (ho.c.inner: int32, 0, 2) {
+          for (wo.c: int32, 0, 8) {
+
+            // init output cache
+            for (hi.c.init: int32, 0, 8) {
+              for (wi.c.init: int32, 0, 8) {
+                for (ki.c.init: int32, 0, 32) {
+                  output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c.init*256)) + (wi.c.init*32)) + ki.c.init)] = 0f32
+                }
               }
             }
-          }
-          for (rc.outer: int32, 0, 2) {
-            for (hi.c: int32, 0, 8) {
-              for (wi.c: int32, 0, 8) {
-                for (ki.c: int32, 0, 32) {
-                  for (rc.inner: int32, 0, 32) {
-                    output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
-                    (
-                      (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
+
+            // convolution
+            for (rc.outer: int32, 0, 2) {
+              for (hi.c: int32, 0, 8) {
+                for (wi.c: int32, 0, 8) {
+                  for (ki.c: int32, 0, 32) {
+                    for (rc.inner: int32, 0, 32) {
+                      output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] = 
                       (
-                        (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] *
-                        (float32*)kernel_pointer[(((((ko.c*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))]
+                        (float32*)output.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (ko.c.inner*2048)) + (hi.c*256)) + (wi.c*32)) + ki.c)] + 
+                        (
+                          (float32*)input.cache[((((((ho.c.inner*32768) + (wo.c*4096)) + (rc.outer*2048)) + (hi.c*256)) + (wi.c*32)) + rc.inner)] *
+                          (float32*)filter.cache[(((((ko.c.inner*2048) + (rc.outer*1024)) + (floordiv(rc.inner, 4)*128)) + (ki.c*4)) + floormod(rc.inner, 4))]
+                        )
                       )
-                    )
+                    }
                   }
                 }
               }
             }
-          }
-        }
-      }
-    }
-    for (ho.inner: int32, 0, 2) {
-      for (wo: int32, 0, 8) {
-        for (ko: int32, 0, 2) {
-          for (hi: int32, 0, 8) {
-            for (wi: int32, 0, 8) {
-              for (ki: int32, 0, 32) {
-                output_pointer[(((((((ho.outer*65536) + (ho.inner*32768)) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)] = 
-                  (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko*2048)) + (hi*256)) + (wi*32)) + ki)]
+          } // end wo.c
+        } // end ho.c.inner
+      } // end ko.c.inner
+
+      // cache write
+      for (ko.inner: int32, 0, 2) {
+        for (ho.inner: int32, 0, 2) {
+          for (wo: int32, 0, 8) {
+            for (hi: int32, 0, 8) {
+              for (wi: int32, 0, 8) {
+                for (ki: int32, 0, 32) {
+                  output_pointer[((((((((ho.outer*131072) + (ho.inner*65536)) + (wo*8192)) + (ko.outer*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)] = 
+                    (float32*)output.cache[((((((ho.inner*32768) + (wo*4096)) + (ko.inner*2048)) + (hi*256)) + (wi*32)) + ki)]
+                }
               }
             }
           }
-        }
-      }
-    }
-  }
+        } // end ho.inner
+      } // end ko.inner
+    } // end ho.outer
+  } // end ko.outer
 }
 ```
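
The flat offsets in the cache write above are just row-major offsets into the `[1, 8, 8, 4, 8, 8, 32]` NHWC8h8w32c output buffer once the split loop variables are recombined. A small sketch (not from the PR) to confirm the arithmetic:

```python
# Sketch: rebuild the cache-write output offset from the blocked shape.
def output_offset(ho_outer, ho_inner, wo, ko_outer, ko_inner, hi, wi, ki):
    ho = ho_outer * 2 + ho_inner  # h_split = 2
    ko = ko_outer * 2 + ko_inner  # k_split = 2
    # row-major offset for shape [8, 8, 4, 8, 8, 32] (batch = 1 dropped)
    return ((((ho * 8 + wo) * 4 + ko) * 8 + hi) * 8 + wi) * 32 + ki

# matches the stride form in the TIR cache write:
# ho.outer*131072 + ho.inner*65536 + wo*8192 + ko.outer*4096
#   + ko.inner*2048 + hi*256 + wi*32 + ki
assert output_offset(1, 1, 2, 1, 0, 3, 4, 5) == (
    1 * 131072 + 1 * 65536 + 2 * 8192 + 1 * 4096 + 0 * 2048
    + 3 * 256 + 4 * 32 + 5
)
```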
 
 # 3x3 conv2d (no padding)
 
-Change from a 1x1 kernel to a 3x3 kernel.  The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output.  This is due to the fact that the 3x3 kernel will "fall off the bottom" of the input and thus the vertically adjacent "full width" slice must be prefetched into the input cache.
+Change from a 1x1 filter to a 3x3 filter.  The implication of this change is that `h_split + 1` rather than just `h_split` "full width" slices of the input are required to compute the output.  This is because the 3x3 filter will "fall off the bottom" of the input, so the vertically adjacent "full width" slice must be prefetched into the input cache.
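
The `h_split + 1` requirement can be checked with a few lines of arithmetic (a sketch, assuming 8-row slices from the 8h blocking and no padding):

```python
# Sketch: with a 3x3 filter and no padding, output row r reads input
# rows r..r+2, so the last output row of a slice group reaches 2 rows
# past it, pulling in one extra vertically adjacent input slice.
import math

h_split, rows_per_slice, filter_h = 2, 8, 3
out_rows = h_split * rows_per_slice           # 16 output rows per group
in_rows_needed = out_rows + (filter_h - 1)    # 18 input rows
slices_needed = math.ceil(in_rows_needed / rows_per_slice)

assert slices_needed == h_split + 1
```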
 
 The key changes in TIR versus the above are...
 
 1) Increased input cache size to hold the vertically adjacent slice
 
 ```
+  // input cache grows to hold vertically adjacent slice

Review comment:
      A full-width, full-channel-in-depth slice.  The explanation for this is above, at line 337: you need `h_split + 1` vertical slices to calculate `h_split` output slices.



