Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/02/08 03:28:28 UTC

[GitHub] [tvm] ilovetvm opened a new issue #7420: Parallel on Reduction Axis in CPU needs rigorous checking

ilovetvm opened a new issue #7420:
URL: https://github.com/apache/tvm/issues/7420


   Using the `parallel` primitive on a reduction axis on CPU (`target="llvm"`) can lead to silent errors. In the following GEMM example, TVM compiles the schedule without complaint and the built function runs as if its output were correct. However, comparing against PyTorch shows that the final result is wrong.
   TVM should check for this case more rigorously. At a minimum, it should emit a warning when a reduction axis is parallelized, to alert users who are still learning how to write correct schedules.
   ```
   import tvm
   import torch
   import numpy as np
   
   M = 2
   N = 2
   K = 4
   
   def gemm(inputs, weight):
     assert inputs.shape[1].value == weight.shape[0].value
     M, K = inputs.shape[0], inputs.shape[1]
     N = weight.shape[1].value
     k = tvm.te.reduce_axis((0, K))
     return tvm.te.compute((M, N), lambda i, j: tvm.te.sum(inputs[i, k] * weight[k, j], axis=k))
   
   def test(parallel):
     A_np = np.random.random([M, K]).astype(np.float32) * 100
     B_np = np.random.random([K, N]).astype(np.float32) * 100
   
     A_torch = torch.tensor(A_np)
     B_torch = torch.tensor(B_np)
     C_torch = A_torch @ B_torch
   
     tvm_ctx = tvm.context("llvm", 0)
     A_tvm = tvm.nd.array(A_np, tvm_ctx)
     B_tvm = tvm.nd.array(B_np, tvm_ctx)
   
     C_tvm = tvm.nd.array(np.zeros(C_torch.shape).astype(np.float32), tvm_ctx)
     A_t = tvm.te.placeholder(A_np.shape, dtype="float32")
     B_t = tvm.te.placeholder(B_np.shape, dtype="float32")
   
     C = gemm(A_t, B_t)
   
     s = tvm.te.create_schedule(C.op)
   
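     if parallel:
       # Parallelize the reduction axis k; this is the schedule that
       # compiles cleanly but produces the wrong result.
       k_axis, = s[C].op.reduce_axis
       s[C].parallel(k_axis)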
   
     print(tvm.lower(s, [A_t, B_t, C], simple_mode=True))
   
     func = tvm.build(s, [A_t, B_t, C], "llvm")
     func(A_tvm, B_tvm, C_tvm)
   
     np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)
   
   test(False)
   print("===========================================================")
   test(True)
   ```
   Running the code above produces the following output:
   ```
   primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2], []),
                placeholder: Buffer(placeholder_4: Pointer(float32), float32, [2, 4], []),
                placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [4, 2], [])}
     buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, compute_1: compute} {
     for (i: int32, 0, 2) {
       for (j: int32, 0, 2) {
         compute_2[((i*2) + j)] = 0f32
         for (rv: int32, 0, 4) {
           compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] + ((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
         }
       }
     }
   }
   
   
   ===========================================================
   primfn(placeholder_2: handle, placeholder_3: handle, compute_1: handle) -> ()
     attr = {"global_symbol": "main", "tir.noalias": True}
     buffers = {compute: Buffer(compute_2: Pointer(float32), float32, [2, 2], []),
                placeholder: Buffer(placeholder_4: Pointer(float32), float32, [2, 4], []),
                placeholder_1: Buffer(placeholder_5: Pointer(float32), float32, [4, 2], [])}
     buffer_map = {placeholder_2: placeholder, placeholder_3: placeholder_1, compute_1: compute} {
     for (i: int32, 0, 2) {
       for (j: int32, 0, 2) {
         compute_2[((i*2) + j)] = 0f32
         for (rv: int32, 0, 4) "parallel" {
           compute_2[((i*2) + j)] = ((float32*)compute_2[((i*2) + j)] + ((float32*)placeholder_4[((i*4) + rv)]*(float32*)placeholder_5[((rv*2) + j)]))
         }
       }
     }
   }
   
   
   Traceback (most recent call last):
   ...
       np.testing.assert_allclose(C_tvm.asnumpy(), C_torch.numpy(), rtol=1e-5)
   ...
   AssertionError: 
   Not equal to tolerance rtol=1e-05, atol=0
   
   Mismatched elements: 3 / 4 (75%)
   Max absolute difference: 1769.9526
   Max relative difference: 0.21025768
    x: array([[ 6648.064,  9747.168],
          [ 5256.175, 11761.918]], dtype=float32)
    y: array([[ 8418.017 , 10098.07  ],
          [ 5256.1753, 13276.437 ]], dtype=float32)
   ```
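
The wrong values are consistent with a data race: in the second lowered function, every iteration of the parallel `rv` loop reads and writes the same element `compute_2[((i*2) + j)]`, so unsynchronized read-modify-write updates can overwrite one another. The pure-Python sketch below illustrates that lost-update pattern deterministically, without TVM. The function names and the two-worker interleaving are hypothetical and chosen only for illustration; they do not describe how TVM's runtime actually schedules threads.

```python
def serial_dot(a, b):
    """Reference reduction: one accumulator updated in order."""
    acc = 0.0
    for x, y in zip(a, b):
        acc += x * y
    return acc


def racy_dot(a, b, workers=2):
    """Simulate unsynchronized workers sharing one accumulator.

    Hypothetical interleaving: in each round, every worker first reads the
    shared accumulator, and only then do the workers write back in turn.
    Each write is based on a stale read, so only the last one survives.
    """
    acc = 0.0
    products = [x * y for x, y in zip(a, b)]
    for start in range(0, len(products), workers):
        chunk = products[start:start + workers]
        stale = [acc for _ in chunk]      # all workers read the same value
        for old, p in zip(stale, chunk):  # writes land one after another
            acc = old + p                 # earlier writes are overwritten
    return acc


def partial_sum_dot(a, b, workers=2):
    """Race-free alternative: private partial sums, combined at the end."""
    partial = [0.0] * workers
    for idx, (x, y) in enumerate(zip(a, b)):
        partial[idx % workers] += x * y   # each worker owns one slot
    return sum(partial)


a = [1.0, 2.0, 3.0, 4.0]
b = [5.0, 6.0, 7.0, 8.0]
print(serial_dot(a, b))       # 70.0
print(racy_dot(a, b))         # 44.0, two of the four products are lost
print(partial_sum_dot(a, b))  # 70.0
```

The third function shows the usual remedy in miniature: give each worker its own partial accumulator and combine the partials afterwards. In TVM, the `rfactor` schedule primitive factors a reduction in a similar spirit, and parallelizing a spatial axis such as `i` avoids the shared accumulator altogether.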
   
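
The kind of check requested in this issue could be sketched roughly as follows. This is a toy model, not TVM code: the `Loop` class, its fields, and `find_parallel_reductions` are hypothetical names introduced for illustration, and a real check would have to operate on TVM's own IR during lowering.

```python
from dataclasses import dataclass, field
from typing import List


@dataclass
class Loop:
    """Toy loop-nest node (hypothetical, not a TVM class)."""
    var: str
    extent: int
    kind: str = "serial"        # "serial" or "parallel"
    is_reduction: bool = False  # True if the loop iterates a reduce axis
    body: List["Loop"] = field(default_factory=list)


def find_parallel_reductions(loop: Loop) -> List[str]:
    """Return the variables of all reduction loops marked parallel."""
    found = []
    if loop.is_reduction and loop.kind == "parallel":
        found.append(loop.var)
    for child in loop.body:
        found.extend(find_parallel_reductions(child))
    return found


# Loop nest mirroring the second lowered function in the output above.
nest = Loop("i", 2, body=[
    Loop("j", 2, body=[
        Loop("rv", 4, kind="parallel", is_reduction=True),
    ]),
])

for var in find_parallel_reductions(nest):
    print(f"warning: reduction axis '{var}' is parallel; "
          "updates to the output element may race")
```

A warning rather than a hard error seems the safer default in such a sketch, since a reduction that has first been factored into private partial results can legitimately be parallelized.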

