Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/07/23 07:07:20 UTC

[GitHub] [incubator-tvm] zhanghaohit opened a new pull request #6126: Feature/intelfocl pr

zhanghaohit opened a new pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126


   This is related to #5840 and split from PR #5842 
   
   This PR only keeps the modifications to TVM for OpenCL support 


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r477024332



##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -88,16 +88,6 @@ class BuiltinLower : public StmtExprMutator {
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
     int64_t nbytes = GetVectorBytes(op->dtype);

Review comment:
       @tqchen perhaps you'd have some input on why this code was needed in the first place?







[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r460328729



##########
File path: vta/tutorials/autotvm/tune_alu_vta.py
##########
@@ -0,0 +1,317 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-tuning an ALU fused op on VTA
+"""
+
+import os
+from mxnet.gluon.model_zoo import vision
+import numpy as np
+from PIL import Image
+
+import topi
+import tvm
+from tvm import te
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, util, download
+from tvm.autotvm.measure.measure_methods import request_remote
+from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.autotvm import record
+
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack
+import copy
+
+
+#################################################################
+# Compile network
+# ---------------
+# Perform vta-specific compilation with Relay from a Gluon model
+def compile_network(env, target, model, start_pack, stop_pack):
+
+    # Populate the shape and data type dictionary
+    dtype_dict = {"data": 'float32'}
+    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
+
+    # Get off the shelf gluon model, and convert to relay
+    gluon_model = vision.get_model(model, pretrained=True)
+    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
+
+    # Update shape and type dictionary
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Perform quantization in Relay
+    # Note: We set opt_level to 3 in order to fold batch norm
+    with relay.build_config(opt_level=3):
+        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
+            mod = relay.quantize.quantize(mod, params=params)
+
+    # Perform graph packing and constant folding for VTA target
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
+        relay_prog = graph_pack(mod["main"],
+                                env.BATCH,
+                                env.BLOCK_OUT,
+                                env.WGT_WIDTH,
+                                start_name=start_pack,
+                                stop_name=stop_pack)
+
+    return relay_prog, params
+
+
+###########################################
+# Set Tuning Options
+# ------------------
+# Before tuning, we should apply some configurations.
+# Here we use a Pynq-Z1 board as an example.
+
+# Tracker host and port can be set by your environment
+tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
+tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+network = "resnet50_v2"
+start_pack = "nn.max_pool2d"
+stop_pack = "nn.global_avg_pool2d"
+
+# Tuning option
+log_file = "%s.alu.%s.log" % (device, network)
+tuning_option = {
+    'log_filename': log_file,
+
+    'tuner': 'random',
+    'n_trial': 1000,
+    'early_stopping': None,
+
+    'measure_option': autotvm.measure_option(
+        builder=autotvm.LocalBuilder(n_parallel=1),
+        runner=autotvm.RPCRunner(env.TARGET,
+                                 host=tracker_host,
+                                 port=tracker_port,
+                                 number=5,
+                                 timeout=60,
+                                 check_correctness=True),
+    ),
+}
+
+
+def log_to_file(file_out, protocol='json'):
+    """Log the tuning records into file.
+    The rows of the log are stored in the format of autotvm.record.encode.
+    for lhs == rhs, we add an extra rhs = [] record
+
+    Parameters
+    ----------
+    file_out : str
+        The file to log to.
+    protocol: str, optional
+        The log protocol. Can be 'json' or 'pickle'
+
+    Returns
+    -------
+    callback : callable
+        Callback function to do the logging.
+    """
+    def _callback(_, inputs, results):
+        with open(file_out, "a") as f:
+            for inp, result in zip(inputs, results):
+                f.write(record.encode(inp, result, protocol) + "\n")
+
+                # we only consider task with same lhs and rhs
+                if inp.task.args[0] == inp.task.args[1]:
+                    args = list(inp.task.args)
+                    args[1] = (args[0][0], (), args[0][2])
+                    inp_copy = copy.deepcopy(inp)
+                    inp_copy.task.args = tuple(args)
+                    f.write(record.encode(inp_copy, result, protocol) + "\n")
+
+    return _callback
+
+
+def tune_tasks(tasks,
+               measure_option,
+               tuner='xgb',
+               n_trial=10,
+               early_stopping=None,
+               log_filename='tuning.log',
+               use_transfer_learning=True):
+
+    # create tmp log file
+    tmp_log_file = log_filename + ".tmp"
+    if os.path.exists(tmp_log_file):
+        os.remove(tmp_log_file)
+
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+
+        # create tuner
+        if tuner == 'xgb' or tuner == 'xgb-rank':
+            tuner_obj = XGBTuner(tsk, loss_type='rank')
+        elif tuner == 'xgb_knob':
+            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
+        elif tuner == 'ga':
+            tuner_obj = GATuner(tsk, pop_size=50)
+        elif tuner == 'random':
+            tuner_obj = RandomTuner(tsk)
+        elif tuner == 'gridsearch':
+            tuner_obj = GridSearchTuner(tsk)
+        else:
+            raise ValueError("Invalid tuner: " + tuner)
+
+        if use_transfer_learning:
+            if os.path.isfile(tmp_log_file):
+                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
+
+        # do tuning
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(n_trial=tsk_trial,
+                       early_stopping=early_stopping,
+                       measure_option=measure_option,
+                       callbacks=[
+                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                           log_to_file(tmp_log_file)
+                       ])
+
+    # pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_filename)
+    os.remove(tmp_log_file)
+
+
+########################################################################
+# Register VTA-specific tuning tasks
+def register_vta_tuning_tasks():
+    from tvm.autotvm.task import TaskExtractEnv
+
+    @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
+    def my_clip(x, a_min, a_max):
+        """Unlike topi's current clip, put min and max into two stages."""
+        const_min = tvm.tir.const(a_min, x.dtype)
+        const_max = tvm.tir.const(a_max, x.dtype)
+        x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
+        x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
+        return x
+
+    # init autotvm env to register VTA operator
+    TaskExtractEnv()
+
+    @autotvm.template("add.vta")
+    def _topi_add(*args, **kwargs):
+        assert not kwargs, "Do not support kwargs in template function call"
+        A, B = args[:2]
+
+        with tvm.target.vta():
+            res = vta.top.op.add_packed(*args, **kwargs)
+            res = my_clip(res, 0, 127)
+            res = topi.cast(res, "int8")
+
+        if tvm.target.Target.current().device_name == 'vta':
+            s = vta.top.op.schedule_add_packed([res])
+        else:
+            s = te.create_schedule([res.op])
+        return s, [A, B, res]
+
+    @autotvm.template("multiply.vta")

Review comment:
    Given that the other versions of VTA don't have a multiplier, we may want to add a check against, e.g., the intelfocl target. An alternative is to have an additional parameter in the VTA params file, e.g. ALU_MULT = True.
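    A minimal sketch of what such a guard could look like in vta/python/vta/top/op.py (the `ALU_MULT` field is hypothetical, `reg` follows the import convention assumed for that file, and `multiply_strategy_vta` is the strategy added by this PR):
    ```python
    import vta
    from tvm.relay.op import op as reg

    env = vta.get_env()

    # Register the multiply strategy only when the hardware actually has a
    # multiplier: either behind a hypothetical ALU_MULT flag in the VTA params
    # file, or behind an explicit whitelist of targets that support it.
    if getattr(env, "ALU_MULT", False) or env.TARGET == "intelfocl":
        reg.get("multiply").get_attr("FTVMStrategy").register(multiply_strategy_vta, "vta")
    ```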







[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r477022969



##########
File path: python/tvm/autotvm/task/topi_integration.py
##########
@@ -227,17 +227,21 @@ def wrapper(outs, *args, **kwargs):
     return _decorate
 
 
-def get_workload(outs):
+def get_workload(outs, task_name=None):
     """Retrieve the workload from outputs"""
     def traverse(tensors):
         """traverse all ops to find attached workload"""
         for t in tensors:
             op = t.op
-            if 'workload' in op.attrs:
-                return args_to_workload(op.attrs['workload'])
             wkl = traverse(op.input_tensors)
             if wkl:
                 return wkl
+

Review comment:
    Thanks for clarifying. How do we guard against extracting add as a standalone op for other backends in this case?







[GitHub] [incubator-tvm] tmoreau89 commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-663772661


   I've also noticed that the CI issue is that the VTA submodule depends on changes that haven't landed yet. We'll first merge the changes you have made in incubator-tvm-vta, then this one.





[GitHub] [incubator-tvm] tmoreau89 commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-680580847


   @liangfu would you mind doing a pass on this PR as well?





[GitHub] [incubator-tvm] zhanghaohit commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
zhanghaohit commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-669646117


   > Thanks @zhanghaohit. This is converging nicely. I made some additional comments.
   > 
   > In addition, I'd like to request further partitioning given the large size of the PR.
   > 
   > (1) the following files will need to migrate to #6125:
   > 
   > * src/relay/op/annotation/annotation.cc
   > * python/tvm/relay/op/_tensor.py
   > 
   > (2) changes made for quantization should be isolated to an additional PR, this includes:
   > 
   > * src/relay/quantize/realize.cc
   > * python/tvm/relay/quantize/_partition.py
   > * python/tvm/relay/quantize/_annotate.py
   
   changes made for quantization are moved to #6191





[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r460328002



##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -88,16 +88,6 @@ class BuiltinLower : public StmtExprMutator {
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
     int64_t nbytes = GetVectorBytes(op->dtype);

Review comment:
       do you mind explaining the reasoning behind this deletion?







[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r460327072



##########
File path: python/tvm/autotvm/task/topi_integration.py
##########
@@ -227,17 +227,21 @@ def wrapper(outs, *args, **kwargs):
     return _decorate
 
 
-def get_workload(outs):
+def get_workload(outs, task_name=None):
     """Retrieve the workload from outputs"""
     def traverse(tensors):
         """traverse all ops to find attached workload"""
         for t in tensors:
             op = t.op
-            if 'workload' in op.attrs:
-                return args_to_workload(op.attrs['workload'])
             wkl = traverse(op.input_tensors)
             if wkl:
                 return wkl
+

Review comment:
       do you mind explaining the changes made to this file? 







[GitHub] [incubator-tvm] zhanghaohit commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
zhanghaohit commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r462007395



##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -88,16 +88,6 @@ class BuiltinLower : public StmtExprMutator {
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
     int64_t nbytes = GetVectorBytes(op->dtype);

Review comment:
    The original code fails if there are multiple workloads in one schedule. For example, in `fused_nn_conv2d_add_add_right_shift_clip_cast_31`, both `conv2d` and `add` may carry `workload` attrs, so we have to pick the correct workload by comparing the `task_name`.
    
    Previously this worked fine, because `add` was not a tunable op. But since we now also want to put middle ALU-only nodes (residual blocks) on VTA, such as `fused_cast_cast_add_nn_relu_clip_cast_3`, we create a VTA schedule for `add` (see [add.alu](https://github.com/apache/incubator-tvm/blob/a1daa1c47d2a19e51ab96da5c20e187f3bbf3413/vta/python/vta/top/op.py#L166)).

##########
File path: python/tvm/autotvm/task/topi_integration.py
##########
@@ -227,17 +227,21 @@ def wrapper(outs, *args, **kwargs):
     return _decorate
 
 
-def get_workload(outs):
+def get_workload(outs, task_name=None):
     """Retrieve the workload from outputs"""
     def traverse(tensors):
         """traverse all ops to find attached workload"""
         for t in tensors:
             op = t.op
-            if 'workload' in op.attrs:
-                return args_to_workload(op.attrs['workload'])
             wkl = traverse(op.input_tensors)
             if wkl:
                 return wkl
+

Review comment:
    The original code fails if there are multiple workloads in one schedule. For example, in `fused_nn_conv2d_add_add_right_shift_clip_cast_31`, both `conv2d` and `add` may carry `workload` attrs, so we have to pick the correct workload by comparing the `task_name`.
    
    Previously this worked fine, because `add` was not a tunable op. But since we now also want to put middle ALU-only nodes (residual blocks) on VTA, such as `fused_cast_cast_add_nn_relu_clip_cast_3`, we create a VTA schedule for `add` (see [add.alu](https://github.com/apache/incubator-tvm/blob/a1daa1c47d2a19e51ab96da5c20e187f3bbf3413/vta/python/vta/top/op.py#L166)).
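    For reference, a standalone sketch of the intended lookup (not the actual topi_integration.py code): walk the outputs and return only the workload whose task name matches, since a fused op may now carry several `workload` attrs:
    ```python
    def get_workload(outs, task_name=None):
        """Retrieve the workload attached to the given output tensors."""
        def traverse(tensors):
            for t in tensors:
                op = t.op
                # search the inputs first, then check this op's own workload
                wkl = traverse(op.input_tensors)
                if wkl is not None:
                    return wkl
                if "workload" in op.attrs:
                    wkl = tuple(op.attrs["workload"])
                    # workload[0] is the task name; skip workloads that belong
                    # to a different fused-in op
                    if task_name is None or str(wkl[0]) == task_name:
                        return wkl
            return None

        outs = outs if isinstance(outs, (list, tuple)) else [outs]
        return traverse(outs)
    ```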

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -230,7 +230,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
           << "Two complicated op in a primitive function "
           << " master=" << master_op_ << " current=" << op;
     }
-    if (op_pattern >= master_op_pattern_) {
+    if (op_pattern > master_op_pattern_) {

Review comment:
    With this change, among ops that share the same op_pattern, the op visited first has higher priority to become the master_op. I think it is reasonable that the earlier ops are more important.
    
    This change is also motivated by the introduction of ALU-only nodes, e.g., `fused_cast_cast_add_left_shift_add_nn_relu_add_right_shift_cast_2`. In this case, we choose the first `add` as the master_op. The last `add` is actually an addition of a tensor and a constant. If we chose the last `add` as the master_op, the autotune config might cause ALU buffer overflow, since the whole fused op actually needs more ALU buffer.
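    A toy illustration of the new selection rule (not the compile_engine.cc code): with a strict `>` comparison and an initial pattern of -1, the first op with the highest op_pattern wins ties, so the earlier `add` becomes the master_op:
    ```python
    # op patterns here are made up for illustration only
    def select_master_op(ops_in_visit_order):
        master_op, master_pattern = None, -1
        for name, pattern in ops_in_visit_order:
            if pattern > master_pattern:  # '>' keeps the earlier op on ties
                master_op, master_pattern = name, pattern
        return master_op

    fused = [("cast", 0), ("add_1", 1), ("left_shift", 0), ("add_2", 1), ("nn.relu", 0)]
    print(select_master_op(fused))  # -> 'add_1' (tensor + tensor), not 'add_2' (tensor + const)
    ```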

##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -288,7 +288,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   tvm::Target target_;
   Op master_op_;
   Attrs master_attrs_;
-  int master_op_pattern_{0};
+  int master_op_pattern_{-1};

Review comment:
    This is due to the change from:
    ```
        if (op_pattern >= master_op_pattern_) {
    ```
    to
    ```
        if (op_pattern > master_op_pattern_) {
    ```
    Since the comparison is now strict, the initial value has to be -1 so that an op with pattern 0 (kElemWise) can still be selected as the master_op.

##########
File path: src/tir/transforms/lower_tvm_builtin.cc
##########
@@ -88,16 +88,6 @@ class BuiltinLower : public StmtExprMutator {
     op = stmt.as<AllocateNode>();
     // Get constant allocation bound.
     int64_t nbytes = GetVectorBytes(op->dtype);

Review comment:
    This removes the special handling for kDLCPU. Otherwise, it may cause an LLVM parameter type mismatch error:
   
   ```bash
   Traceback (most recent call last):
     File "vta/tutorials/frontend/deploy_classification.py", line 210, in <module>
       params=params, target_host=env.target_host)
     File "/4pd/home/zhanghao/workspace/tvm-2/tvm/python/tvm/relay/build_module.py", line 251, in build
       graph_json, mod, params = bld_mod.build(mod, target, target_host, params)
     File "/4pd/home/zhanghao/workspace/tvm-2/tvm/python/tvm/relay/build_module.py", line 120, in build
       self._build(mod, target, target_host)
     File "tvm/_ffi/_cython/./packed_func.pxi", line 321, in tvm._ffi._cy3.core.PackedFuncBase.__call__
     File "tvm/_ffi/_cython/./packed_func.pxi", line 256, in tvm._ffi._cy3.core.FuncCall
     File "tvm/_ffi/_cython/./packed_func.pxi", line 245, in tvm._ffi._cy3.core.FuncCall3
     File "tvm/_ffi/_cython/./base.pxi", line 160, in tvm._ffi._cy3.core.CALL
   tvm._ffi.base.TVMError: Traceback (most recent call last):
     [bt] (8) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(TVMFuncCall+0x4c) [0x7f385ac9bc1c]
     [bt] (7) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(std::_Function_handler<void (tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*), tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}>::_M_invoke(std::_Any_data const&, tvm::runtime::TVMArgs&&, tvm::runtime::TVMRetValue*&&)+0x316) [0x7f385ab2a566]
     [bt] (6) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, std::unordered_map<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, tvm::runtime::NDArray, std::hash<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::equal_to<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > >, std::allocator<std::pair<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const, tvm::runtime::NDArray> > > const&)+0xe31) [0x7f385ab29c11]
     [bt] (5) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::runtime::String, tvm::IRModule, void, void> const&, tvm::Target const&)+0x3c4) [0x7f385a4322d4]
     [bt] (4) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::build(tvm::Map<tvm::Target, tvm::IRModule, void, void> const&, tvm::Target const&)+0x326) [0x7f385a4318c6]
     [bt] (3) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::codegen::Build(tvm::IRModule, tvm::Target const&)+0x67a) [0x7f385a74f68a]
     [bt] (2) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(+0x1277ea1) [0x7f385ac7eea1]
     [bt] (1) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(tvm::codegen::LLVMModuleNode::Init(tvm::IRModule const&, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >)+0x1388) [0x7f385ac82c68]
     [bt] (0) /4pd/home/zhanghao/workspace/tvm-2/tvm/build/libtvm.so(+0x1276a57) [0x7f385ac7da57]
     File "/4pd/home/zhanghao/workspace/tvm-2/tvm/src/target/llvm/llvm_module.cc", line 230
   TVMError: LLVM module verification failed with the following errors: 
   Call parameter type does not match function signature!
     %.sub = getelementptr inbounds [4 x <8 x float>], [4 x <8 x float>]* %3, i64 0, i64 0
    i8*  %34 = call i8* @VTABufferCPUPtr(i8* %17, <8 x float>* nonnull %.sub)
   Call parameter type does not match function signature!
     %.sub = getelementptr inbounds [8 x float], [8 x float]* %3, i64 0, i64 0
    i8*  %31 = call i8* @VTABufferCPUPtr(i8* %14, float* nonnull %.sub)
   ```
    The raised error comes from this code in LLVM (lib/IR/Verifier.cpp):
   ```{.c++ filename="lib/IR/Verifier.cpp"}
    // Verify that all arguments to the call match the function type.
    for (unsigned i = 0, e = FTy->getNumParams(); i != e; ++i)
      Assert(CS.getArgument(i)->getType() == FTy->getParamType(i),
             "Call parameter type does not match function signature!",
             CS.getArgument(i), FTy->getParamType(i), I);
   ```
   
    It raises this error if the special handling for kDLCPU is kept. I think it is because the signature for the AllocateNode is not consistent with the actual parameter type? Any ideas about an alternative fix?

##########
File path: vta/tutorials/autotvm/tune_alu_vta.py
##########
@@ -0,0 +1,317 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+Auto-tuning an ALU fused op on VTA
+"""
+
+import os
+from mxnet.gluon.model_zoo import vision
+import numpy as np
+from PIL import Image
+
+import topi
+import tvm
+from tvm import te
+from tvm import rpc, autotvm, relay
+from tvm.contrib import graph_runtime, util, download
+from tvm.autotvm.measure.measure_methods import request_remote
+from tvm.autotvm.tuner import XGBTuner, GATuner, RandomTuner, GridSearchTuner
+from tvm.autotvm import record
+
+import vta
+from vta.testing import simulator
+from vta.top import graph_pack
+import copy
+
+
+#################################################################
+# Compile network
+# ---------------
+# Perform vta-specific compilation with Relay from a Gluon model
+def compile_network(env, target, model, start_pack, stop_pack):
+
+    # Populate the shape and data type dictionary
+    dtype_dict = {"data": 'float32'}
+    shape_dict = {"data": (env.BATCH, 3, 224, 224)}
+
+    # Get off the shelf gluon model, and convert to relay
+    gluon_model = vision.get_model(model, pretrained=True)
+    mod, params = relay.frontend.from_mxnet(gluon_model, shape_dict)
+
+    # Update shape and type dictionary
+    shape_dict.update({k: v.shape for k, v in params.items()})
+    dtype_dict.update({k: str(v.dtype) for k, v in params.items()})
+
+    # Perform quantization in Relay
+    # Note: We set opt_level to 3 in order to fold batch norm
+    with relay.build_config(opt_level=3):
+        with relay.quantize.qconfig(global_scale=8.0, skip_conv_layers=[0]):
+            mod = relay.quantize.quantize(mod, params=params)
+
+    # Perform graph packing and constant folding for VTA target
+    if target.device_name == "vta":
+        assert env.BLOCK_IN == env.BLOCK_OUT
+        relay_prog = graph_pack(mod["main"],
+                                env.BATCH,
+                                env.BLOCK_OUT,
+                                env.WGT_WIDTH,
+                                start_name=start_pack,
+                                stop_name=stop_pack)
+
+    return relay_prog, params
+
+
+###########################################
+# Set Tuning Options
+# ------------------
+# Before tuning, we should apply some configurations.
+# Here we use a Pynq-Z1 board as an example.
+
+# Tracker host and port can be set by your environment
+tracker_host = os.environ.get("TVM_TRACKER_HOST", '0.0.0.0')
+tracker_port = int(os.environ.get("TVM_TRACKER_PORT", 9190))
+
+# Load VTA parameters from the vta/config/vta_config.json file
+env = vta.get_env()
+
+# This target is used for cross compilation. You can query it by :code:`gcc -v` on your device.
+# Set ``device=arm_cpu`` to run inference on the CPU
+# or ``device=vta`` to run inference on the FPGA.
+device = "vta"
+target = env.target if device == "vta" else env.target_vta_cpu
+
+# Name of Gluon model to compile
+# The ``start_pack`` and ``stop_pack`` labels indicate where
+# to start and end the graph packing relay pass: in other words
+# where to start and finish offloading to VTA.
+network = "resnet50_v2"
+start_pack = "nn.max_pool2d"
+stop_pack = "nn.global_avg_pool2d"
+
+# Tuning option
+log_file = "%s.alu.%s.log" % (device, network)
+tuning_option = {
+    'log_filename': log_file,
+
+    'tuner': 'random',
+    'n_trial': 1000,
+    'early_stopping': None,
+
+    'measure_option': autotvm.measure_option(
+        builder=autotvm.LocalBuilder(n_parallel=1),
+        runner=autotvm.RPCRunner(env.TARGET,
+                                 host=tracker_host,
+                                 port=tracker_port,
+                                 number=5,
+                                 timeout=60,
+                                 check_correctness=True),
+    ),
+}
+
+
+def log_to_file(file_out, protocol='json'):
+    """Log the tuning records into file.
+    The rows of the log are stored in the format of autotvm.record.encode.
+    for lhs == rhs, we add an extra rhs = [] record
+
+    Parameters
+    ----------
+    file_out : str
+        The file to log to.
+    protocol: str, optional
+        The log protocol. Can be 'json' or 'pickle'
+
+    Returns
+    -------
+    callback : callable
+        Callback function to do the logging.
+    """
+    def _callback(_, inputs, results):
+        with open(file_out, "a") as f:
+            for inp, result in zip(inputs, results):
+                f.write(record.encode(inp, result, protocol) + "\n")
+
+                # we only consider task with same lhs and rhs
+                if inp.task.args[0] == inp.task.args[1]:
+                    args = list(inp.task.args)
+                    args[1] = (args[0][0], (), args[0][2])
+                    inp_copy = copy.deepcopy(inp)
+                    inp_copy.task.args = tuple(args)
+                    f.write(record.encode(inp_copy, result, protocol) + "\n")
+
+    return _callback
+
+
+def tune_tasks(tasks,
+               measure_option,
+               tuner='xgb',
+               n_trial=10,
+               early_stopping=None,
+               log_filename='tuning.log',
+               use_transfer_learning=True):
+
+    # create tmp log file
+    tmp_log_file = log_filename + ".tmp"
+    if os.path.exists(tmp_log_file):
+        os.remove(tmp_log_file)
+
+    for i, tsk in enumerate(reversed(tasks)):
+        prefix = "[Task %2d/%2d] " % (i + 1, len(tasks))
+
+        # create tuner
+        if tuner == 'xgb' or tuner == 'xgb-rank':
+            tuner_obj = XGBTuner(tsk, loss_type='rank')
+        elif tuner == 'xgb_knob':
+            tuner_obj = XGBTuner(tsk, loss_type='rank', feature_type='knob')
+        elif tuner == 'ga':
+            tuner_obj = GATuner(tsk, pop_size=50)
+        elif tuner == 'random':
+            tuner_obj = RandomTuner(tsk)
+        elif tuner == 'gridsearch':
+            tuner_obj = GridSearchTuner(tsk)
+        else:
+            raise ValueError("Invalid tuner: " + tuner)
+
+        if use_transfer_learning:
+            if os.path.isfile(tmp_log_file):
+                tuner_obj.load_history(autotvm.record.load_from_file(tmp_log_file))
+
+        # do tuning
+        tsk_trial = min(n_trial, len(tsk.config_space))
+        tuner_obj.tune(n_trial=tsk_trial,
+                       early_stopping=early_stopping,
+                       measure_option=measure_option,
+                       callbacks=[
+                           autotvm.callback.progress_bar(tsk_trial, prefix=prefix),
+                           log_to_file(tmp_log_file)
+                       ])
+
+    # pick best records to a cache file
+    autotvm.record.pick_best(tmp_log_file, log_filename)
+    os.remove(tmp_log_file)
+
+
+########################################################################
+# Register VTA-specific tuning tasks
+def register_vta_tuning_tasks():
+    from tvm.autotvm.task import TaskExtractEnv
+
+    @tvm.te.tag_scope(tag=topi.tag.ELEMWISE)
+    def my_clip(x, a_min, a_max):
+        """Unlike topi's current clip, put min and max into two stages."""
+        const_min = tvm.tir.const(a_min, x.dtype)
+        const_max = tvm.tir.const(a_max, x.dtype)
+        x = te.compute(x.shape, lambda *i: tvm.te.min(x(*i), const_max), name="clipA")
+        x = te.compute(x.shape, lambda *i: tvm.te.max(x(*i), const_min), name="clipB")
+        return x
+
+    # init autotvm env to register VTA operator
+    TaskExtractEnv()
+
+    @autotvm.template("add.vta")
+    def _topi_add(*args, **kwargs):
+        assert not kwargs, "Do not support kwargs in template function call"
+        A, B = args[:2]
+
+        with tvm.target.vta():
+            res = vta.top.op.add_packed(*args, **kwargs)
+            res = my_clip(res, 0, 127)
+            res = topi.cast(res, "int8")
+
+        if tvm.target.Target.current().device_name == 'vta':
+            s = vta.top.op.schedule_add_packed([res])
+        else:
+            s = te.create_schedule([res.op])
+        return s, [A, B, res]
+
+    @autotvm.template("multiply.vta")

Review comment:
    Actually this tune_alu_vta.py tutorial is for intelfocl only. If the target is not intelfocl, we show an error message and return.
    
    In the main code, I added the check here: [vta/top/op.py](https://github.com/apache/incubator-tvm/blob/a1daa1c47d2a19e51ab96da5c20e187f3bbf3413/vta/python/vta/top/op.py#L194)
   
   ```
   env = get_env()
   # other target does not support alu-only ops
   if env.TARGET in ["sim", "tsim", "intelfocl"]:
       reg.get("add").get_attr("FTVMStrategy").register(add_strategy_vta, "vta")
       reg.get("multiply").get_attr("FTVMStrategy").register(multiply_strategy_vta, "vta")
   ```







[GitHub] [incubator-tvm] zhanghaohit commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
zhanghaohit commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r478102076



##########
File path: vta/runtime/runtime.cc
##########
@@ -329,7 +442,7 @@ class BaseQueue {
   // End location of current SRAM write in FIFO mode
   uint32_t sram_end_{0};
   // The buffer in DRAM
-  std::vector<T> dram_buffer_;
+  std::vector<T, AlignmentAllocator<T, 64>> dram_buffer_;

Review comment:
       done







[GitHub] [incubator-tvm] tmoreau89 commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-680579495


   Thanks, overall the PR is in good shape. I think we'll need to merge the incubator-tvm-vta PR first, before we merge this one in (due to the bump of the VTA submodule).





[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r460327463



##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -230,7 +230,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
           << "Two complicated op in a primitive function "
           << " master=" << master_op_ << " current=" << op;
     }
-    if (op_pattern >= master_op_pattern_) {
+    if (op_pattern > master_op_pattern_) {

Review comment:
       what corner case is this fixing?







[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r477024974



##########
File path: vta/runtime/runtime.cc
##########
@@ -329,7 +442,7 @@ class BaseQueue {
   // End location of current SRAM write in FIFO mode
   uint32_t sram_end_{0};
   // The buffer in DRAM
-  std::vector<T> dram_buffer_;
+  std::vector<T, AlignmentAllocator<T, 64>> dram_buffer_;

Review comment:
       I'm thinking we may want to make 64 a preprocessor constant







[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r477025169



##########
File path: vta/runtime/runtime.cc
##########
@@ -429,14 +542,24 @@ class UopQueue : public BaseQueue<VTAUop> {
       buff_size += cache_[i]->size() * kElemBytes;
     }
     CHECK(buff_size <= kMaxBytes);
-    // Move kernel contents to FPGA readable buffer
+
+    // merge all the cache entries and do CopyFromHost once
+    uint32_t total_size = 0;
+    for (uint32_t i = 0; i < cache_.size(); ++i) {
+      uint32_t ksize = cache_[i]->size() * kElemBytes;
+      total_size += ksize;
+    }
+
+    char *lbuf = (char*)memalign(64, total_size);

Review comment:
       Same here with the constant







[GitHub] [incubator-tvm] tmoreau89 commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-663772343


   Thanks @zhanghaohit. This is converging nicely. I made some additional comments.
   
   In addition, I'd like to request further partitioning given the large size of the PR.
   
   (1) the following files will need to migrate to #6125:
   - src/relay/op/annotation/annotation.cc
   - python/tvm/relay/op/_tensor.py
   
   (2) changes made for quantization should be isolated to an additional PR, this includes:
   - src/relay/quantize/realize.cc
   - python/tvm/relay/quantize/_partition.py
   - python/tvm/relay/quantize/_annotate.py





[GitHub] [incubator-tvm] tmoreau89 commented on a change in pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
tmoreau89 commented on a change in pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#discussion_r460327531



##########
File path: src/relay/backend/compile_engine.cc
##########
@@ -288,7 +288,7 @@ class ScheduleGetter : public backend::MemoizedExprTranslator<Array<te::Tensor>>
   tvm::Target target_;
   Op master_op_;
   Attrs master_attrs_;
-  int master_op_pattern_{0};
+  int master_op_pattern_{-1};

Review comment:
       could you comment on this change?







[GitHub] [incubator-tvm] zhanghaohit commented on pull request #6126: [VTA][OpenCL] intelfocl

Posted by GitBox <gi...@apache.org>.
zhanghaohit commented on pull request #6126:
URL: https://github.com/apache/incubator-tvm/pull/6126#issuecomment-669645896


   > Please address aforementioned changes, thank you
   
   Done. Thanks @tmoreau89 for the comments. 

