You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/05/27 02:40:29 UTC

[GitHub] [tvm] euntaik opened a new issue #8145: Auto-scheduler tuned OpenCL kernel produces Invalid Workgroup Size Error

euntaik opened a new issue #8145:
URL: https://github.com/apache/tvm/issues/8145


   TVM Version: 0.8
   OS: Ubuntu 20.04
   
   Auto-scheduler tuned OpenCL kernel produces Invalid Workgroup Size Error.
   
   Below is the script that I used to reproduce the issue. The script creates a tflite model, then repeatedly tunes it and runs it until an error occurs. Each iteration tunes for 10 trials and then runs the model to check for errors.
   I've been able to reproduce the issue in about 100 iterations or fewer.
   
   ```
   import numpy as np
   import tensorflow as tf
   import tvm
   from tvm import relay, auto_scheduler
   from tensorflow.keras import Model, Input, layers, models, initializers
   import tflite
   from tvm.contrib import ndk
   import tarfile
   from tvm.driver.tvmc.runner import run_module
   from tvm.auto_scheduler.utils import request_remote
   import os
   import time
   
   ###############################################################################
   # RPC configuration
   #
   RPC_KEY = 'android'
   RPC_PORT = 9190
   RPC_TRACKER = '0.0.0.0'
   
   ###############################################################################
   # Tune config
   #
   # When True, the tuning step is skipped and each epoch only compiles and runs.
   SKIP_TUNE = False
   # Total measurement-trial budget across all epochs.
   NUM_TRIALS = 20000
   # Measurement trials performed per tune-then-run epoch.
   MIN_TRIALS = 10
   # Number of epochs. Floor division keeps this an integer count instead of
   # the float that true division would produce.
   EPOCH_CNT = NUM_TRIALS // MIN_TRIALS
   
   def generate_tflite_file(N, H, W, CO, CI, KH, KW):
       """Build a small Keras conv model, convert it to TFLite and save it.

       The network stacks Conv2D -> Conv2DTranspose -> Conv2D -> Conv2DTranspose.
       Every layer has CO filters, a (KH, KW) kernel, ReLU activation and
       RandomUniform kernel/bias initializers.  The input layer is declared
       with shape (CI, H, W) and batch size N.

       Returns the name of the written .tflite file, which encodes all seven
       arguments.
       """
       def build_model():
           # Layer constructors in the order they are stacked.
           layer_types = (
               layers.Conv2D,
               layers.Conv2DTranspose,
               layers.Conv2D,
               layers.Conv2DTranspose,
           )
           net_input = Input(shape=(CI, H, W), batch_size=N, name='input')
           tensor = net_input
           for make_layer in layer_types:
               tensor = make_layer(
                   CO, (KH, KW),
                   kernel_initializer=initializers.RandomUniform(),
                   bias_initializer=initializers.RandomUniform(),
                   activation='relu',
               )(tensor)
           return Model(net_input, tensor)

       keras_model = build_model()

       # Convert the Keras model to a TFLite flatbuffer.
       tflite_bytes = tf.lite.TFLiteConverter.from_keras_model(keras_model).convert()

       # The file name carries every dimension so different shapes do not collide.
       suffix = '_'.join(map(str, (N, H, W, CO, CI, KH, KW)))
       out_path = f'invalid_ws_test_{suffix}.tflite'
       with open(out_path, 'wb') as out:
           out.write(tflite_bytes)

       return out_path
   
   
   ###############################################################################
   # Generate a tflite file
   #
   # Arguments are (N, H, W, CO, CI, KH, KW).
   tflite_file = generate_tflite_file(1,140,108,64,64,2,2)
   
   
   
   ###############################################################################
   # tune and run until Invalid workgroup size error is generated
   #
   # Host-side code targets 64-bit ARM Android; device kernels target OpenCL.
   target_host = tvm.target.Target('llvm -mtriple=arm64-linux-android')
   target = tvm.target.Target('opencl')
   
   # Tuning log: written during tuning and read back when compiling.
   record_filename = f'{tflite_file}.records'
   
   # Open a remote session (3 second timeout) and take its OpenCL device handle.
   print('Request remote...')
   remote = request_remote(RPC_KEY, RPC_TRACKER, RPC_PORT, timeout=3)
   dev = remote.cl()
   
   # Print the properties reported by the remote OpenCL device.
   print(dev.max_clock_rate)
   print(dev.max_shared_memory_per_block)
   print(dev.max_thread_dimensions)
   print(dev.max_threads_per_block)
   print(dev.multi_processor_count)
   print(dev.warp_size)
   
   INT32_MAX = 2147483647
   # NOTE(review): arguments are positional. Presumably the order is
   # (num_cores, vector_unit_bytes, cache_line_bytes,
   #  max_shared_memory_per_block, max_local_memory_per_block,
   #  max_threads_per_block, max_vthread_extent, warp_size) -- verify against
   # the auto_scheduler.HardwareParams signature of the TVM version in use.
   hardware_params = auto_scheduler.HardwareParams(
       dev.multi_processor_count,
       16,
       64,
       dev.max_shared_memory_per_block,
       INT32_MAX,
       dev.max_threads_per_block,
       # warp_size / 4 when that is greater than 1, otherwise warp_size itself.
       int(dev.warp_size / 4) if int(dev.warp_size / 4) > 1 else dev.warp_size,
       dev.warp_size,
   )
   
   # wait for remote session to timeout
   time.sleep(3)
   
   # Each epoch: import the model, tune for MIN_TRIALS more trials (unless
   # SKIP_TUNE), compile with the accumulated records, package the artifacts
   # and run them on the remote device.  The loop ends after EPOCH_CNT epochs
   # or when any step raises.
   epoch_cnt = 0
   while epoch_cnt < EPOCH_CNT:
       print('===============================================================================')
       print(f'Starting Epoch #{epoch_cnt}')
       print('===============================================================================')
   
       # Re-import the model every epoch (params is rebound by relay.build()
       # below).  The context manager closes the file instead of leaking one
       # handle per epoch.
       with open(tflite_file, "rb") as model_file:
           tf_model_buf = model_file.read()
       tflite_model = tflite.Model.GetRootAsModel(tf_model_buf, 0)
       mod, params = relay.frontend.from_tflite(tflite_model, shape_dict=None, dtype_dict=None)
   
       if not SKIP_TUNE:
           tasks, task_weights = auto_scheduler.extract_tasks(
                               mod["main"], params, target,
                               target_host=target_host,
                               hardware_params=hardware_params)
   
           runner = auto_scheduler.RPCRunner(key=RPC_KEY, host=RPC_TRACKER, port=RPC_PORT, repeat=1, timeout=50, n_parallel=1)
   
           builder = auto_scheduler.LocalBuilder(build_func="ndk")
   
           # load_log_file lets the scheduler pick up records from earlier epochs.
           tuner = auto_scheduler.TaskScheduler(tasks, task_weights, load_log_file=record_filename)
           tune_option = auto_scheduler.TuningOptions(
               num_measure_trials=MIN_TRIALS,
               builder=builder,
               runner=runner,
               measure_callbacks=[auto_scheduler.RecordToFile(record_filename)],
               verbose=1,
           )
   
           print("Tuning...")
           tuner.tune(tune_option)
           # save best records
           auto_scheduler.measure_record.distill_record_file(record_filename, record_filename+'.best')
   
       print("Compiling...")
       if os.path.isfile(record_filename):
           # Compile with the tuning records applied.
           config = {'relay.backend.use_auto_scheduler':True}
           with auto_scheduler.ApplyHistoryBest(record_filename):
               with tvm.transform.PassContext(opt_level=3, config=config, disabled_pass=None):
                   graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params)
       else:
           # No record file exists: build without the auto-scheduler config.
           graph, lib, params = relay.build(mod, target=target, target_host=target_host, params=params)
   
       lib_file = tflite_file + '.so'
       graph_file = tflite_file + '.json'
       param_file = tflite_file + '.params'
   
       lib.export_library(lib_file, ndk.create_shared)
   
       with open(graph_file, 'w') as f:
           f.write(graph)
       with open(param_file, 'wb') as f:
           f.write(relay.save_param_dict(params))
   
       # Bundle library, params and graph into one tar under fixed member names.
       tvm_model_file = tflite_file + '.tar'
       with tarfile.open(tvm_model_file, 'w') as tar:
           tar.add(lib_file, arcname='mod.so')
           tar.add(param_file, arcname='mod.params')
           tar.add(graph_file, arcname='mod.json')
   
       # Run on target
       print(f'Running {tvm_model_file} on target device...')
       # Bound to run_times (not `time`) so the imported time module is not shadowed.
       outputs, run_times = run_module(tvm_model_file, hostname=RPC_TRACKER, port=RPC_PORT, rpc_key=RPC_KEY,
           device='cl', inputs=None, fill_mode='random', repeat=1, profile=False)
   
       print(run_times)
       #print(outputs)
       epoch_cnt += 1
   
   ```
   
   
   With a little more effort I was able to capture what the auto-scheduler ran and what the final kernel was.
   
   Kernel compiled during tuning was:
   ```
   /*
    * Generated OpenCL kernel captured during auto-scheduler tuning.
    * Reads placeholder and placeholder1, writes compute.
    * Declares 3408 + 1024 floats of __local memory (17728 bytes).
    * NOTE(review): the staging loops below advance by 560 elements per
    * iteration on top of get_local_id(0), which suggests the schedule expects
    * a work-group of 560 work-items in dimension 0. The launch configuration
    * is not shown here -- confirm against the device's maximum work-group size.
    */
   __kernel void default_function_kernel0(__global float* restrict placeholder, __global float* restrict placeholder1, __global float* restrict compute) {
   float compute_local[4];
   __local float data_pad_shared[3408];
   __local float placeholder_shared[1024];
   /* Zero the four per-work-item accumulators. */
   for (int c_c_inner_init = 0; c_c_inner_init < 2; ++c_c_inner_init) {
   compute_local[(c_c_inner_init)] = 0.000000e+00f;
   compute_local[((c_c_inner_init + 2))] = 0.000000e+00f;
   }
   /* Outer loop: each iteration stages both __local tiles, then accumulates from them. */
   for (int dc_outer_outer = 0; dc_outer_outer < 4; ++dc_outer_outer) {
   barrier(CLK_LOCAL_MEM_FENCE);
   /* Stage data_pad_shared from placeholder, storing 0.0f where the bounds checks fail. */
   for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer < 7; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) < 3408) {
   data_pad_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))))] = (((((1 <= (((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71))) && ((((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71)) < 140)) && (1 <= ((((((int)get_group_id(0)) & 63) >> 1) * 2) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71)))) && (((((((int)get_group_id(0)) & 63) >> 1) * 2) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71)) < 64)) ? placeholder[((((((((dc_outer_outer * 140112) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) / 213) * 8757)) + (((((int)get_group_id(0)) & 63) >> 1) * 278)) + (((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 213) / 71) * 139)) + (
  (((int)get_group_id(0)) & 1) * 70)) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) % 71)) - 140))] : 0.000000e+00f);
   }
   }
   /* Stage placeholder_shared from placeholder1. */
   for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) < 256) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 280) + (((int)get_local_id(0)) >> 1)) < 512) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))) < 1024) {
   placeholder_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))))] = placeholder1[((((((dc_outer_outer * 4096) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) >> 4) * 256)) + ((((int)get_group_id(0)) >> 6) * 64)) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) & 15) * 4)) + (((int)get_local_id(0)) & 3)))];
   }
   }
   }
   }
   /* Work-group barrier between staging the tiles and reading them. */
   barrier(CLK_LOCAL_MEM_FENCE);
   /* Accumulate products of the two staged tiles into compute_local. */
   for (int dc_outer_inner = 0; dc_outer_inner < 8; ++dc_outer_inner) {
   for (int dh_outer_inner = 0; dh_outer_inner < 2; ++dh_outer_inner) {
   for (int dc_inner = 0; dc_inner < 2; ++dc_inner) {
   for (int dw_inner = 0; dw_inner < 2; ++dw_inner) {
   for (int c_c_inner = 0; c_c_inner < 2; ++c_c_inner) {
   compute_local[(c_c_inner)] = (compute_local[(c_c_inner)] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (dh_outer_inner * 71)) + dw_inner) + (((int)get_local_id(0)) % 70)))] * placeholder_shared[((((((((dc_outer_inner * 128) + (dc_inner * 64)) + ((((int)get_local_id(0)) / 140) * 8)) + (c_c_inner * 4)) + 3) - dw_inner) - (dh_outer_inner * 2)))]));
   compute_local[((c_c_inner + 2))] = (compute_local[((c_c_inner + 2))] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (dh_outer_inner * 71)) + dw_inner) + (((int)get_local_id(0)) % 70)))] * placeholder_shared[((((((((dc_outer_inner * 128) + (dc_inner * 64)) + ((((int)get_local_id(0)) / 140) * 8)) + (c_c_inner * 4)) + 35) - dw_inner) - (dh_outer_inner * 2)))]));
   }
   }
   }
   }
   }
   }
   /* Write the four accumulators out to compute. */
   for (int c_inner = 0; c_inner < 2; ++c_inner) {
   compute[(((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 140) / 70) * 140)) + ((((int)get_group_id(0)) & 1) * 70)) + (((int)get_local_id(0)) % 70)))] = compute_local[(c_inner)];
   compute[((((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 140) / 70) * 140)) + ((((int)get_group_id(0)) & 1) * 70)) + (((int)get_local_id(0)) % 70)) + 71680))] = compute_local[((c_inner + 2))];
   }
   }
   ```
   
   
   Kernel compiled using the record written with the above tuning results:
   ```
   __kernel void fused_nn_conv2d_transpose_83_kernel0(__global float* restrict placeholder, __global float* restrict placeholder1, __global float* restrict compute) {
   
   float compute_local[4];
   __local float data_pad_shared[3408];
   __local float placeholder_shared[1024];
   for (int c_c_outer_inner_init = 0; c_c_outer_inner_init < 2; ++c_c_outer_inner_init) {
   compute_local[(c_c_outer_inner_init)] = 0.000000e+00f;
   compute_local[((c_c_outer_inner_init + 2))] = 0.000000e+00f;
   }
   for (int dc_outer_outer = 0; dc_outer_outer < 4; ++dc_outer_outer) {
   barrier(CLK_LOCAL_MEM_FENCE);
   for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer < 7; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))) < 3408) {
   data_pad_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer * 560) + ((int)get_local_id(0))))] = (((((1 <= (((((int)get_group_id(0)) & 1) * 70) + (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_o
   }
   }
   for (int ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 = 0; ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 < 2; ++ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 140) + (((int)get_local_id(0)) >> 2)) < 256) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 280) + (((int)get_local_id(0)) >> 1)) < 512) {
   if (((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))) < 1024) {
   placeholder_shared[(((ax0_ax1_fused_ax2_fused_ax3_fused_outer_outer1 * 560) + ((int)get_local_id(0))))] = placeholder1[((((((dc_outer_outer * 4096) + ((((ax0_ax1_fused_ax2_fused_ax3_fused_out
   }
   }
   }
   }
   barrier(CLK_LOCAL_MEM_FENCE);
   for (int dc_outer_inner = 0; dc_outer_inner < 8; ++dc_outer_inner) {
   for (int dh_outer_inner = 0; dh_outer_inner < 2; ++dh_outer_inner) {
   for (int c_c_outer_inner = 0; c_c_outer_inner < 2; ++c_c_outer_inner) {
   for (int dc_inner = 0; dc_inner < 2; ++dc_inner) {
   for (int dw_inner = 0; dw_inner < 2; ++dw_inner) {
   compute_local[(c_c_outer_inner)] = (compute_local[(c_c_outer_inner)] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70) * 71)) + (d
   compute_local[((c_c_outer_inner + 2))] = (compute_local[((c_c_outer_inner + 2))] + (data_pad_shared[(((((((dc_outer_inner * 426) + (dc_inner * 213)) + (((((int)get_local_id(0)) % 140) / 70)
   }
   }
   }
   }
   }
   }
   for (int c_inner = 0; c_inner < 2; ++c_inner) {
   compute[(((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) % 1
   compute[((((((((((((int)get_group_id(0)) >> 6) * 143360) + ((((int)get_local_id(0)) / 140) * 17920)) + (c_inner * 8960)) + (((((int)get_group_id(0)) & 63) >> 1) * 280)) + (((((int)get_local_id(0)) %
   }
   }
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 closed issue #8145: Auto-scheduler tuned OpenCL kernel produces Invalid Workgroup Size Error

Posted by GitBox <gi...@apache.org>.
vinx13 closed issue #8145:
URL: https://github.com/apache/tvm/issues/8145


   


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [tvm] vinx13 commented on issue #8145: Auto-scheduler tuned OpenCL kernel produces Invalid Workgroup Size Error

Posted by GitBox <gi...@apache.org>.
vinx13 commented on issue #8145:
URL: https://github.com/apache/tvm/issues/8145#issuecomment-855507092


   Thanks for asking the question. Please open a new thread on https://discuss.tvm.apache.org/, as we use the forum for related discussions. For this case, could you check if the hardware params are correct?


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org