Posted to commits@tvm.apache.org by tq...@apache.org on 2020/04/21 21:23:26 UTC
[incubator-tvm] branch master updated: [PYTHON] Migrate VTA TIR
passes to the new pass manager. (#5397)
This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new d327787 [PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
d327787 is described below
commit d3277874a24e775d2476b0eb0ad89f3a46964a14
Author: Tianqi Chen <tq...@users.noreply.github.com>
AuthorDate: Tue Apr 21 14:23:18 2020 -0700
[PYTHON] Migrate VTA TIR passes to the new pass manager. (#5397)
---
include/tvm/target/target.h | 5 +-
python/tvm/autotvm/measure/measure_methods.py | 8 +-
python/tvm/driver/build_module.py | 29 +-
python/tvm/tir/function.py | 16 +
src/target/target.cc | 4 +-
tests/python/relay/test_pass_fold_constant.py | 8 +-
tests/python/unittest/test_target_codegen_cuda.py | 10 +-
tests/python/unittest/test_target_codegen_llvm.py | 11 +-
.../unittest/test_tir_pass_verify_gpu_code.py | 8 +-
tutorials/dev/low_level_custom_pass.py | 11 +-
vta/python/vta/build_module.py | 56 +-
vta/python/vta/ir_pass.py | 995 ---------------------
vta/python/vta/transform.py | 962 ++++++++++++++++++++
13 files changed, 1050 insertions(+), 1073 deletions(-)
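The gist of the migration: entries in build_config's add_lower_pass used to be raw Stmt -> Stmt functions, and are now tvm.transform.Pass objects created through the new pass manager. A minimal sketch of the new style (the pass and tensor names here are illustrative, not part of this commit):

    import tvm
    from tvm import te

    @tvm.tir.transform.prim_func_pass(opt_level=0)
    def print_tir(f, mod, ctx):
        # New-style passes receive a PrimFunc (plus IRModule and pass
        # context) instead of a bare Stmt, and return a PrimFunc.
        print(f.body)
        return f

    A = te.placeholder((16,), name="A")
    B = te.compute((16,), lambda i: A[i] + 1.0, name="B")
    s = te.create_schedule(B.op)
    with tvm.target.build_config(add_lower_pass=[(1, print_tir)]):
        tvm.lower(s, [A, B], simple_mode=True)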
diff --git a/include/tvm/target/target.h b/include/tvm/target/target.h
index 59aa955..829de73 100644
--- a/include/tvm/target/target.h
+++ b/include/tvm/target/target.h
@@ -27,6 +27,7 @@
#include <tvm/support/with.h>
#include <tvm/node/container.h>
#include <tvm/ir/expr.h>
+#include <tvm/ir/transform.h>
#include <string>
#include <vector>
@@ -225,8 +226,8 @@ class BuildConfigNode : public Object {
/*! \brief Whether to partition const loop */
bool partition_const_loop = false;
- /*! \brief Whether to dump the IR of each pass (only when building from python) */
- std::vector< std::pair<int, runtime::PackedFunc> > add_lower_pass;
+ /*! \brief List of passes to be injected into the low-level pipeline. */
+ std::vector<std::pair<int, transform::Pass>> add_lower_pass;
/*! \brief Whether to dump the IR of each pass (only when building from python) */
bool dump_pass_ir = false;
diff --git a/python/tvm/autotvm/measure/measure_methods.py b/python/tvm/autotvm/measure/measure_methods.py
index 698ddbc..5ddc5df 100644
--- a/python/tvm/autotvm/measure/measure_methods.py
+++ b/python/tvm/autotvm/measure/measure_methods.py
@@ -615,9 +615,9 @@ def gpu_verify_pass(**kwargs):
"""Verify the validity of a gpu kernel.
This pass will check memory usage and number of threads per block.
"""
- def verify_pass(stmt):
- valid = ir_pass.VerifyGPUCode(stmt, kwargs)
+ def verify_pass(f, *_):
+ valid = ir_pass.VerifyGPUCode(f.body, kwargs)
if not valid:
raise InstantiationError("Skipped because of invalid gpu kernel")
- return stmt
- return verify_pass
+ return f
+ return tvm.tir.transform.prim_func_pass(verify_pass, opt_level=0)
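Callers that used to apply the returned function to a Stmt now get a ready-made Pass to drop into add_lower_pass. A hypothetical standalone use (constraint value illustrative; needs a CUDA-enabled TVM build):

    import tvm
    from tvm import te
    from tvm.autotvm.measure.measure_methods import gpu_verify_pass

    n = 1024
    A = te.placeholder((n,), name="A")
    B = te.compute((n,), lambda i: A[i] * 2.0, name="B")
    s = te.create_schedule(B.op)
    bx, tx = s[B].split(B.op.axis[0], factor=64)
    s[B].bind(bx, te.thread_axis("blockIdx.x"))
    s[B].bind(tx, te.thread_axis("threadIdx.x"))

    # Runs at phase 2 of lowering; raises InstantiationError on violation.
    check = gpu_verify_pass(max_threads_per_block=512)
    with tvm.target.create("cuda"):
        with tvm.target.build_config(add_lower_pass=[(2, check)]):
            tvm.build(s, [A, B], target="cuda")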
diff --git a/python/tvm/driver/build_module.py b/python/tvm/driver/build_module.py
index 35700ba..dcd6d44 100644
--- a/python/tvm/driver/build_module.py
+++ b/python/tvm/driver/build_module.py
@@ -123,25 +123,6 @@ def form_irmodule(sch, args, name, binds):
return tvm.IRModule({name: func})
-def _wrap_as_prim_func_pass(flist, name):
- """Wrap flist as a function pass.
-
- This is an temporary adapter before we fully
- migrate to the new pass manager.
- """
- def _transform(func, *_):
- stmt = func.body
- for f in flist:
- stmt = f(stmt)
- # create a new function with updated body.
- return tvm.tir.PrimFunc(func.params,
- stmt,
- func.ret_type,
- func.buffer_map,
- func.attrs)
- return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name=name)
-
-
def lower(sch,
args,
name="main",
@@ -190,15 +171,15 @@ def lower(sch,
else:
mod = sch
+ pass_list = lower_phase0
# Phase 1
- pass_list = [
- _wrap_as_prim_func_pass(lower_phase0, "Custom-Phase0"),
+ pass_list += [
tvm.tir.transform.InjectPrefetch(),
tvm.tir.transform.StorageFlatten(64, cfg.instrument_bound_checkers),
tvm.tir.transform.NarrowDataType(32),
tvm.tir.transform.Simplify(),
- _wrap_as_prim_func_pass(lower_phase1, "Custom-Phase1"),
]
+ pass_list += lower_phase1
# Phase 2
if not simple_mode:
@@ -214,8 +195,8 @@ def lower(sch,
cfg.auto_unroll_max_depth,
cfg.auto_unroll_max_extent,
cfg.unroll_explicit),
- _wrap_as_prim_func_pass(lower_phase2, "Custom-Phase2"),
]
+ pass_list += lower_phase2
# Phase 3
pass_list += [
@@ -225,7 +206,7 @@ def lower(sch,
if not cfg.disable_select_rewriting:
pass_list += [tvm.tir.transform.RewriteUnsafeSelect()]
- pass_list += [_wrap_as_prim_func_pass(lower_phase3, "Custom-Phase3")]
+ pass_list += lower_phase3
# Instrument BoundCheckers
if cfg.instrument_bound_checkers:
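The _wrap_as_prim_func_pass adapter can go away because add_lower_pass entries are Pass objects already. Downstream code that still holds old-style Stmt -> Stmt functions can lift them itself; a hypothetical one-liner adapter built on the with_body helper introduced below:

    import tvm

    def stmt_fn_to_pass(fstmt, name):
        # Hypothetical adapter: lift an old-style Stmt -> Stmt function
        # into a new-style PrimFunc pass.
        def _transform(f, mod, ctx):
            return f.with_body(fstmt(f.body))
        return tvm.tir.transform.prim_func_pass(
            _transform, opt_level=0, name=name)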
diff --git a/python/tvm/tir/function.py b/python/tvm/tir/function.py
index 4ec1a71..47ad94f 100644
--- a/python/tvm/tir/function.py
+++ b/python/tvm/tir/function.py
@@ -67,3 +67,19 @@ class PrimFunc(BaseFunc):
self.__init_handle_by_constructor__(
_ffi_api.PrimFunc, param_list, body, ret_type, buffer_map, attrs)
+
+ def with_body(self, new_body):
+ """Create a new PrimFunc with the same set signatures but a new body.
+
+ Parameters
+ ----------
+ new_body : Stmt
+ The new body.
+
+ Returns
+ -------
+ new_func : PrimFunc
+ The new function with the updated body.
+ """
+ return PrimFunc(
+ self.params, new_body, self.ret_type, self.buffer_map, self.attrs)
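with_body exists because most TIR passes only rewrite the body; the params, ret_type, buffer_map and attrs carry over unchanged. A minimal sketch of the intended usage (pass name illustrative):

    import tvm

    @tvm.tir.transform.prim_func_pass(opt_level=0, name="NoOp")
    def noop(f, mod, ctx):
        # Rebuild the function around a (here unchanged) body.
        return f.with_body(f.body)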
diff --git a/src/target/target.cc b/src/target/target.cc
index 50856d6..a72ce1c 100644
--- a/src/target/target.cc
+++ b/src/target/target.cc
@@ -434,12 +434,12 @@ TVM_REGISTER_GLOBAL("target.ExitBuildConfigScope")
TVM_REGISTER_GLOBAL("target.BuildConfigSetAddLowerPass")
.set_body([](TVMArgs args, TVMRetValue* ret) {
BuildConfig cfg = args[0];
- std::vector< std::pair<int, PackedFunc> > add_lower_pass;
+ std::vector<std::pair<int, transform::Pass>> add_lower_pass;
CHECK_EQ(args.size() % 2, 1);
for (int i = 1; i < args.size(); i += 2) {
add_lower_pass.push_back(std::make_pair(
args[i].operator int(),
- args[i + 1].operator tvm::runtime::PackedFunc()));
+ args[i + 1].operator transform::Pass()));
}
cfg->add_lower_pass = add_lower_pass;
});
diff --git a/tests/python/relay/test_pass_fold_constant.py b/tests/python/relay/test_pass_fold_constant.py
index 4f44d2b..b212b26 100644
--- a/tests/python/relay/test_pass_fold_constant.py
+++ b/tests/python/relay/test_pass_fold_constant.py
@@ -51,11 +51,13 @@ def test_fold_const():
z = relay.add(y, relay.const(c_data))
return relay.Function([x], z)
- def fail(x):
- raise RuntimeError()
+ def FailPass():
+ def _transform(m, *args):
+ raise RuntimeError()
+ return tvm.transform.module_pass(_transform, opt_level=0)
# the fold constant should work on any context.
- with tvm.target.build_config(add_lower_pass=[(0, fail)]):
+ with tvm.target.build_config(add_lower_pass=[(0, FailPass())]):
with tvm.target.create("cuda"):
zz = run_opt_pass(before(), transform.FoldConstant())
zexpected = run_opt_pass(expected(), transform.InferType())
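Since add_lower_pass now holds Pass objects, even this deliberately failing hook has to be a real pass, hence the module_pass wrapper. For contrast, a minimal well-behaved module pass (name illustrative); a module pass maps a whole IRModule to an IRModule:

    import tvm

    @tvm.transform.module_pass(opt_level=0, name="Identity")
    def identity(mod, ctx):
        # Module-level counterpart of prim_func_pass.
        return mod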
diff --git a/tests/python/unittest/test_target_codegen_cuda.py b/tests/python/unittest/test_target_codegen_cuda.py
index 739fc6f..4c2ec2e 100644
--- a/tests/python/unittest/test_target_codegen_cuda.py
+++ b/tests/python/unittest/test_target_codegen_cuda.py
@@ -182,7 +182,7 @@ def test_cuda_shuffle():
sch[c].bind(xo, thrx)
sch[c].vectorize(xi)
- def my_vectorize(stmt):
+ def MyVectorize():
def vectorizer(op):
if op.for_type == tvm.tir.For.Vectorized:
four = tvm.tir.const(4, 'int32')
@@ -198,9 +198,13 @@ def test_cuda_shuffle():
new_b = tvm.tir.Shuffle(bs, ids)
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
return None
- return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
- with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+ def _transform(f, *_):
+ return f.with_body(
+ tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+ return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="MyVectorize")
+
+ with tvm.target.build_config(add_lower_pass=[(1, MyVectorize())]):
module = tvm.build(sch, [a, b, c], target='cuda')
a_ = np.array(list(range(64)), dtype='int32')
b_ = np.array((list(range(4))[::-1]) * 16, dtype='int32')
diff --git a/tests/python/unittest/test_target_codegen_llvm.py b/tests/python/unittest/test_target_codegen_llvm.py
index 44b05c9..26f9347 100644
--- a/tests/python/unittest/test_target_codegen_llvm.py
+++ b/tests/python/unittest/test_target_codegen_llvm.py
@@ -671,8 +671,7 @@ def test_llvm_shuffle():
c = te.compute((8, ), lambda x: a[x] + b[7-x])
sch = te.create_schedule(c.op)
- def my_vectorize(stmt):
-
+ def my_vectorize():
def vectorizer(op):
store = op.body
idx = tvm.tir.Ramp(tvm.tir.const(0, 'int32'), tvm.tir.const(1, 'int32'), 8)
@@ -684,9 +683,13 @@ def test_llvm_shuffle():
value = new_a + new_b
return tvm.tir.Store(store.buffer_var, new_a + new_b, idx, all_ones)
- return tvm.tir.ir_pass.IRTransform(stmt, None, vectorizer, ['For'])
+ def _transform(f, *_):
+ return f.with_body(
+ tvm.tir.ir_pass.IRTransform(f.body, None, vectorizer, ['For']))
+
+ return tvm.tir.transform.prim_func_pass(_transform, opt_level=0, name="my_vectorize")
- with tvm.target.build_config(add_lower_pass=[(1, my_vectorize)]):
+ with tvm.target.build_config(add_lower_pass=[(1, my_vectorize())]):
ir = tvm.lower(sch, [a, b, c], simple_mode=True)
module = tvm.build(sch, [a, b, c])
a_ = tvm.nd.array(np.arange(1, 9, dtype='int32'))
diff --git a/tests/python/unittest/test_tir_pass_verify_gpu_code.py b/tests/python/unittest/test_tir_pass_verify_gpu_code.py
index 6e138a2..091a374 100644
--- a/tests/python/unittest/test_tir_pass_verify_gpu_code.py
+++ b/tests/python/unittest/test_tir_pass_verify_gpu_code.py
@@ -19,10 +19,10 @@ import tvm
from tvm import te
def get_verify_pass(valid, **kwargs):
- def verify_pass(stmt):
- valid[0] = tvm.tir.ir_pass.VerifyGPUCode(stmt, kwargs)
- return stmt
- return verify_pass
+ def _fverify(f, *_):
+ valid[0] = tvm.tir.ir_pass.VerifyGPUCode(f.body, kwargs)
+ return f
+ return tvm.tir.transform.prim_func_pass(_fverify, opt_level=0)
def test_shared_memory():
def check_shared_memory(dtype):
diff --git a/tutorials/dev/low_level_custom_pass.py b/tutorials/dev/low_level_custom_pass.py
index d35913b..49e86fd 100644
--- a/tutorials/dev/low_level_custom_pass.py
+++ b/tutorials/dev/low_level_custom_pass.py
@@ -117,19 +117,20 @@ def vectorize8(op):
return body
return None
-def vectorize(stmt):
+@tvm.tir.transform.prim_func_pass(opt_level=0)
+def vectorize(f, mod, ctx):
global loops
- tvm.tir.ir_pass.PostOrderVisit(stmt, find_width8)
+ tvm.tir.ir_pass.PostOrderVisit(f.body, find_width8)
if not loops:
- return stmt
+ return f
# The last list argument indicates what kinds of nodes will be transformed.
# Thus, in this case only `For` nodes will call `vectorize8`
- stmt = tvm.tir.ir_pass.IRTransform(stmt, None, vectorize8, ['For'])
+ return f.with_body(
+ tvm.tir.ir_pass.IRTransform(f.body, None, vectorize8, ['For']))
- return stmt
#####################################################################
# Glue to Lowering
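With the decorator applied, vectorize is itself a Pass object, so the glue section below can hand it to add_lower_pass directly; roughly (sch and the tensors are defined earlier in the tutorial, outside this hunk):

    with tvm.target.build_config(add_lower_pass=[(1, vectorize)]):
        print(tvm.lower(sch, [a, b, c], simple_mode=True))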
diff --git a/vta/python/vta/build_module.py b/vta/python/vta/build_module.py
index 4c33d36..40bee86 100644
--- a/vta/python/vta/build_module.py
+++ b/vta/python/vta/build_module.py
@@ -14,25 +14,22 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-# pylint: disable=unused-argument
+# pylint: disable=unused-argument, invalid-name
"""VTA specific buildin for runtime."""
import tvm
-from . import ir_pass
+from . import transform
from .environment import get_env
-def lift_coproc_scope(x):
- """Lift coprocessings cope to the """
- x = ir_pass.lift_alloc_to_scope_begin(x)
- x = tvm.tir.ir_pass.LiftAttrScope(x, "coproc_scope", False)
- return x
-
-def early_rewrite(stmt):
+def EarlyRewrite():
"""Try to do storage rewrite in early pass."""
- try:
- return tvm.tir.ir_pass.StorageRewrite(stmt)
- except tvm.error.TVMError:
- return stmt
+ def _transform(mod, ctx):
+ try:
+ return tvm.tir.transform.StorageRewrite()(mod)
+ except tvm.error.TVMError:
+ return mod
+ return tvm.transform.module_pass(
+ _transform, opt_level=0, name="tir.vta.EarlyRewrite")
def build_config(debug_flag=0, **kwargs):
@@ -60,27 +57,32 @@ def build_config(debug_flag=0, **kwargs):
vta_module = tvm.build(s, ...)
"""
env = get_env()
- def add_debug(stmt):
+
+ @tvm.tir.transform.prim_func_pass(opt_level=0)
+ def add_debug(f, *_):
debug = tvm.tir.call_extern(
"int32", "VTASetDebugMode",
env.dev.command_handle,
debug_flag)
- return tvm.tir.stmt_seq(debug, stmt)
- pass_list = [(0, ir_pass.inject_conv2d_transpose_skip),
- (1, ir_pass.inject_dma_intrin),
- (1, ir_pass.inject_skip_copy),
- (1, ir_pass.annotate_alu_coproc_scope),
- (1, lambda x: tvm.tir.ir_pass.LiftAttrScope(x, "coproc_uop_scope", True)),
- (1, lift_coproc_scope),
- (1, ir_pass.inject_coproc_sync),
- (1, early_rewrite)]
+ return f.with_body(tvm.tir.stmt_seq(debug, f.body))
+
+
+ pass_list = [(0, transform.InjectConv2DTransposeSkip()),
+ (1, transform.InjectDMAIntrin()),
+ (1, transform.InjectSkipCopy()),
+ (1, transform.AnnotateALUCoProcScope()),
+ (1, tvm.tir.transform.LiftAttrScope("coproc_uop_scope")),
+ (1, transform.LiftAllocToScopeBegin()),
+ (1, tvm.tir.transform.LiftAttrScope("coproc_scope")),
+ (1, transform.InjectCoProcSync()),
+ (1, EarlyRewrite())]
if debug_flag:
pass_list.append((1, add_debug))
- pass_list.append((2, ir_pass.inject_alu_intrin))
- pass_list.append((3, tvm.tir.ir_pass.LowerStorageAccessInfo))
- pass_list.append((3, ir_pass.fold_uop_loop))
- pass_list.append((3, ir_pass.cpu_access_rewrite))
+ pass_list.append((2, transform.InjectALUIntrin()))
+ pass_list.append((3, tvm.tir.transform.LowerDeviceStorageAccessInfo()))
+ pass_list.append((3, transform.FoldUopLoop()))
+ pass_list.append((3, transform.CPUAccessRewrite()))
return tvm.target.build_config(add_lower_pass=pass_list, **kwargs)
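EarlyRewrite above illustrates a reusable shape: wrap an existing Pass in a module_pass so that its failure can be tolerated. A hypothetical generalization of that wrapper:

    import tvm

    def OptionalPass(pass_obj, name="Optional"):
        # Hypothetical helper: apply pass_obj, but fall back to the
        # unmodified module if it raises a TVMError.
        def _transform(mod, ctx):
            try:
                return pass_obj(mod)
            except tvm.error.TVMError:
                return mod
        return tvm.transform.module_pass(_transform, opt_level=0, name=name)

    safe_rewrite = OptionalPass(tvm.tir.transform.StorageRewrite())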
diff --git a/vta/python/vta/ir_pass.py b/vta/python/vta/ir_pass.py
deleted file mode 100644
index 9836d13..0000000
--- a/vta/python/vta/ir_pass.py
+++ /dev/null
@@ -1,995 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-"""Additional IR Pass for VTA"""
-# pylint: disable=len-as-condition, no-else-return
-import tvm
-from tvm import te
-from topi import util
-
-from .environment import get_env
-
-
-def _match_pragma(stmt, key):
- """Internal helper to match stmt to pragma stmt.
-
- Parameters
- ----------
- stmt : Stmt
- The AttrStmt
-
- key : str
- The pragma key
- """
- return ((stmt.attr_key == "pragma_" + key) or
- (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
-
-
-def fold_uop_loop(stmt_in):
- """Detect and fold uop loop.
-
- VTA support uop programming model
- that recognizes loop structure.
- This pass detect the loop structure
- and extract that into uop loop AST.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Output statement.
- """
- env = get_env()
-
- def _fold_outermost_loop(body):
- stmt = body
- if not isinstance(stmt, tvm.tir.For):
- return None, body, None
-
- loop_var = stmt.loop_var
- gemm_offsets = [None, None, None]
- fail = [False]
-
- def _post_order(op):
- assert isinstance(op, tvm.tir.Call)
- base_args = 2
- if op.name == "VTAUopPush":
- args = []
- args += op.args[:base_args]
- for i in range(3):
- m = tvm.arith.detect_linear_equation(
- op.args[i + base_args], [loop_var])
- if not m:
- fail[0] = True
- return op
- if gemm_offsets[i] is not None:
- if not tvm.ir.structural_equal(m[0], gemm_offsets[i]):
- fail[0] = True
- return op
- args.append(m[1])
- else:
- gemm_offsets[i] = m[0]
- args.append(m[1])
- args += op.args[base_args+3:]
- return tvm.tir.call_extern("int32", "VTAUopPush", *args)
- if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
- raise RuntimeError("unexpected op %s" % op)
- return op
-
- ret = tvm.tir.ir_pass.IRTransform(
- stmt.body, None, _post_order, ["Call"])
-
- if not fail[0] and all(x is not None for x in gemm_offsets):
- def _visit(op):
- if op.same_as(loop_var):
- fail[0] = True
- tvm.tir.ir_pass.PostOrderVisit(ret, _visit)
- if not fail[0]:
- begin = tvm.tir.call_extern(
- "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
- end = tvm.tir.call_extern("int32", "VTAUopLoopEnd")
- return [begin, ret, end]
- raise ValueError("Failed to fold the GEMM instructions..")
-
- def _do_fold(stmt):
- if (stmt.attr_key == "coproc_uop_scope" and
- isinstance(stmt.value, tvm.tir.StringImm) and
- stmt.value.value == env.dev.vta_push_uop.value):
- body = stmt.body
- begins = []
- ends = []
- try:
- begin, body, end = _fold_outermost_loop(body)
- if begin is not None:
- begins.append(begin)
- if end is not None:
- ends.append(end)
- begin, body, end = _fold_outermost_loop(body)
- if begin is not None:
- begins.append(begin)
- if end is not None:
- ends.append(end)
- except ValueError:
- pass
- if body == stmt.body:
- return stmt
- ends = list(reversed(ends))
- body = tvm.tir.stmt_seq(*(begins + [body] + ends))
- return tvm.tir.AttrStmt(
- stmt.node, stmt.attr_key, stmt.value, body)
- return None
- out = tvm.tir.ir_pass.IRTransform(
- stmt_in, _do_fold, None, ["AttrStmt"])
- return out
-
-
-def cpu_access_rewrite(stmt_in):
- """Detect CPU access to VTA buffer and get address correctly.
-
- VTA's buffer is an opaque handle that do not
- correspond to address in CPU.
- This pass detect CPU access and rewrite to use pointer
- returned VTABufferCPUPtr for CPU access.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- env = get_env()
- rw_info = {}
- def _post_order(op):
- if isinstance(op, tvm.tir.Allocate):
- buffer_var = op.buffer_var
- if not buffer_var in rw_info:
- return None
- new_var = rw_info[buffer_var]
- let_stmt = tvm.tir.LetStmt(
- new_var, tvm.tir.call_extern(
- "handle", "VTABufferCPUPtr",
- env.dev.command_handle,
- buffer_var), op.body)
- alloc = tvm.tir.Allocate(
- buffer_var, op.dtype, op.extents,
- op.condition, let_stmt)
- del rw_info[buffer_var]
- return alloc
- if isinstance(op, tvm.tir.Load):
- buffer_var = op.buffer_var
- if not buffer_var in rw_info:
- rw_info[buffer_var] = te.var(
- buffer_var.name + "_ptr", "handle")
- new_var = rw_info[buffer_var]
- return tvm.tir.Load(op.dtype, new_var, op.index)
- if isinstance(op, tvm.tir.Store):
- buffer_var = op.buffer_var
- if not buffer_var in rw_info:
- rw_info[buffer_var] = te.var(
- buffer_var.name + "_ptr", "handle")
- new_var = rw_info[buffer_var]
- return tvm.tir.Store(new_var, op.value, op.index)
- raise RuntimeError("not reached")
- stmt = tvm.tir.ir_pass.IRTransform(
- stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
- for buffer_var, new_var in rw_info.items():
- stmt = tvm.tir.LetStmt(
- new_var, tvm.tir.call_extern(
- "handle", "VTABufferCPUPtr",
- env.dev.command_handle,
- buffer_var), stmt)
- return stmt
-
-
-def lift_alloc_to_scope_begin(stmt_in):
- """Lift allocate to beginning of the current scope.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- lift_stmt = [[]]
- def _merge_block(slist, body):
- for op in slist:
- if op.body == body:
- body = op
- elif isinstance(op, tvm.tir.Allocate):
- body = tvm.tir.Allocate(
- op.buffer_var, op.dtype,
- op.extents, op.condition, body)
- elif isinstance(op, tvm.tir.AttrStmt):
- body = tvm.tir.AttrStmt(
- op.node, op.attr_key, op.value, body)
- elif isinstance(op, tvm.tir.For):
- body = tvm.tir.For(
- op.loop_var, op.min, op.extent, op.for_type,
- op.device_api, body)
- else:
- raise RuntimeError("unexpected op")
- del slist[:]
- return body
-
- def _pre_order(op):
- if isinstance(op, tvm.tir.For):
- lift_stmt.append([])
- elif isinstance(op, tvm.tir.AttrStmt):
- if op.attr_key == "virtual_thread":
- lift_stmt.append([])
-
- def _post_order(op):
- if isinstance(op, tvm.tir.Allocate):
- lift_stmt[-1].append(op)
- return op.body
- if isinstance(op, tvm.tir.AttrStmt):
- if op.attr_key == "storage_scope":
- lift_stmt[-1].append(op)
- return op.body
- if op.attr_key == "virtual_thread":
- return _merge_block(lift_stmt.pop() + [op], op.body)
- return op
- if isinstance(op, tvm.tir.For):
- return _merge_block(lift_stmt.pop() + [op], op.body)
- raise RuntimeError("not reached")
- stmt = tvm.tir.ir_pass.IRTransform(
- stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
- assert len(lift_stmt) == 1
- return _merge_block(lift_stmt[0], stmt)
-
-
-def inject_skip_copy(stmt_in):
- """Pass to inject skip copy stmt, used for debug purpose.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- def _do_fold(stmt):
- if _match_pragma(stmt, "skip_dma_copy"):
- return tvm.tir.Evaluate(0)
- return None
- return tvm.tir.ir_pass.IRTransform(
- stmt_in, _do_fold, None, ["AttrStmt"])
-
-
-def inject_coproc_sync(stmt_in):
- """Pass to inject skip copy stmt, used in debug.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- success = [False]
- def _do_fold(stmt):
- if _match_pragma(stmt, "coproc_sync"):
- success[0] = True
- sync = tvm.tir.Call(
- "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
- return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
- if _match_pragma(stmt, "trim_loop"):
- op = stmt.body
- assert isinstance(op, tvm.tir.For)
- return tvm.tir.For(
- op.loop_var, op.min, 2, op.for_type,
- op.device_api, op.body)
- return None
- stmt = tvm.tir.ir_pass.IRTransform(
- stmt_in, None, _do_fold, ["AttrStmt"])
- stmt = tvm.tir.ir_pass.CoProcSync(stmt)
- return stmt
-
-
-def inject_dma_intrin(stmt_in):
- """Pass to inject DMA copy intrinsics.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- env = get_env()
- idxd = tvm.tir.indexdiv
- idxm = tvm.tir.indexmod
-
- def _check_compact(buf):
- ndim = len(buf.shape)
- size = tvm.tir.const(1, buf.shape[0].dtype)
- for i in reversed(range(ndim)):
- if not util.equal_const_int(size - buf.strides[i], 0):
- raise RuntimeError(
- "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
- size = size * buf.shape[i]
-
- def _fold_buffer_dim(buf, scope, elem_block):
- ndim = len(buf.shape)
- x_size = 1
- base = 0
- for i in range(1, ndim + 1):
- if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
- raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
- x_size = x_size * buf.shape[ndim - i]
- if util.equal_const_int(x_size - elem_block, 0):
- base = i + 1
- break
- if base == 0:
- raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
- scope, elem_block, buf.shape))
- shape = [elem_block]
- strides = [1]
-
- if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block):
- shape.append(1)
- strides.append(elem_block)
-
- analyzer = tvm.arith.Analyzer()
- while base < ndim + 1:
- x_size = 1
- x_stride = buf.strides[ndim - base]
- next_base = base
- if not util.equal_const_int(idxm(x_stride, elem_block), 0):
- raise RuntimeError(
- "scope %s need to have block=%d, shape=%s, strides=%s" % (
- scope, elem_block, buf.shape, buf.strides))
- for i in range(base, ndim + 1):
- k = ndim - i
- if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
- break
- x_size = x_size * buf.shape[k]
- next_base = i + 1
- shape.append(analyzer.simplify(x_size))
- strides.append(x_stride)
- assert next_base != base
- base = next_base
-
- strides = list(reversed(strides))
- shape = list(reversed(shape))
- return shape, strides
-
- def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
- elem_block = elem_bytes * 8 // elem_width
- if buf.dtype != dtype:
- raise RuntimeError("Expect buffer type to be %s instead of %s" %
- (dtype, buf.dtype))
- shape, strides = buf.shape, buf.strides
- if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
- raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
- if allow_fold:
- shape, strides = _fold_buffer_dim(buf, scope, elem_block)
- else:
- shape = list(x for x in shape)
- strides = list(x for x in strides)
-
- def raise_error():
- """Internal function to raise error """
- raise RuntimeError(
- ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
- " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
-
- ndim = len(shape)
-
- # Check if the inner-tensor is already flat
- flat = util.equal_const_int(shape[-1], elem_block)
-
- if flat:
- if not util.equal_const_int(strides[-1], 1):
- raise_error()
-
- if ndim == 1:
- x_size = 1
- x_stride = 1
- y_size = 1
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
- if not util.equal_const_int(strides[-2] - elem_block, 0):
- raise_error()
-
- if ndim == 2:
- x_size = shape[-2]
- x_stride = shape[-2]
- y_size = 1
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
- if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
- raise_error()
-
- if ndim == 3:
- x_size = shape[-2]
- x_stride = idxd(strides[-3], elem_block)
- y_size = shape[-3]
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-
- else:
- if not util.equal_const_int(strides[-1], 1):
- raise_error()
- if not util.equal_const_int(strides[-2] - shape[-1], 0):
- raise_error()
- if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
- raise_error()
-
- if ndim == 2:
- x_size = 1
- x_stride = 1
- y_size = 1
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
- if not util.equal_const_int(strides[-3], elem_block):
- raise_error()
-
- if ndim == 3:
- x_size = shape[-3]
- x_stride = shape[-3]
- y_size = 1
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
- if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
- raise_error()
-
- if ndim == 4:
- x_size = shape[-3]
- x_stride = idxd(strides[-4], elem_block)
- y_size = shape[-4]
- return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
-
- raise_error()
-
-
- def _inject_copy(src, dst, pad_before, pad_after, pad_value):
- # FIXME: pad_value is ignored...
- _ = pad_value
- if dst.scope == "global":
- # Store
- if pad_before or pad_after:
- raise RuntimeError("Do not support copy into DRAM with pad")
- if src.scope == env.acc_scope:
- elem_width = env.OUT_WIDTH
- elem_bytes = env.OUT_ELEM_BYTES
- mem_type = env.dev.MEM_ID_OUT
- data_type = "int%d" % env.OUT_WIDTH
- task_qid = env.dev.QID_STORE_OUT
- else:
- raise RuntimeError("Do not support copy %s->dram" % (src.scope))
- _check_compact(src)
- x_size, y_size, x_stride, offset = _get_2d_pattern(
- dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
- irb = tvm.tir.ir_builder.create()
- irb.scope_attr(env.dev.vta_axis, "coproc_scope",
- env.dev.get_task_qid(task_qid))
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAStoreBuffer2D",
- env.dev.command_handle,
- src.access_ptr("r", "int32"),
- mem_type, dst.data, offset, x_size, y_size, x_stride))
- return irb.get()
- elif src.scope == "global":
- if dst.scope == env.acc_scope:
- elem_width = env.ACC_WIDTH
- elem_bytes = env.ACC_ELEM_BYTES
- mem_type = env.dev.MEM_ID_ACC
- data_type = "int%d" % env.ACC_WIDTH
- task_qid = env.dev.QID_LOAD_OUT
- elif dst.scope == env.inp_scope:
- elem_width = env.INP_WIDTH
- elem_bytes = env.INP_ELEM_BYTES
- mem_type = env.dev.MEM_ID_INP
- data_type = "int%d" % env.INP_WIDTH
- task_qid = env.dev.QID_LOAD_INP
- elif dst.scope == env.wgt_scope:
- elem_width = env.WGT_WIDTH
- elem_bytes = env.WGT_ELEM_BYTES
- mem_type = env.dev.MEM_ID_WGT
- data_type = "int%d" % env.WGT_WIDTH
- task_qid = env.dev.QID_LOAD_WGT
- else:
- raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
- # collect pad statistics
- if pad_before:
- assert pad_after
- ndim = len(pad_before)
- if ndim <= 2 or ndim > 5:
- raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
- if ndim == 5:
- # This case occurs when batch size N > 1
- y_pad_before = pad_before[1]
- x_pad_before = pad_before[2]
- y_pad_after = pad_after[1]
- x_pad_after = pad_after[2]
- for dim in range(3, ndim):
- if not util.equal_const_int(pad_before[dim], 0):
- raise ValueError("Do not support pad on the innermost block")
- if not util.equal_const_int(pad_after[dim], 0):
- raise ValueError("Do not support pad on the innermost block")
- else:
- y_pad_before = pad_before[0]
- x_pad_before = pad_before[1]
- y_pad_after = pad_after[0]
- x_pad_after = pad_after[1]
- for dim in range(2, ndim):
- if not util.equal_const_int(pad_before[dim], 0):
- raise ValueError("Do not support pad on the innermost block")
- if not util.equal_const_int(pad_after[dim], 0):
- raise ValueError("Do not support pad on the innermost block")
- allow_fold = False
- else:
- x_pad_before = 0
- y_pad_before = 0
- x_pad_after = 0
- y_pad_after = 0
- allow_fold = True
-
- _check_compact(dst)
- x_size, y_size, x_stride, offset = _get_2d_pattern(
- src, elem_width, elem_bytes, data_type,
- dst.scope, allow_fold=allow_fold)
-
- irb = tvm.tir.ir_builder.create()
- irb.scope_attr(env.dev.vta_axis, "coproc_scope",
- env.dev.get_task_qid(task_qid))
-
- irb.emit(tvm.tir.call_extern(
- "int32", "VTALoadBuffer2D",
- env.dev.command_handle,
- src.data, offset, x_size, y_size, x_stride,
- x_pad_before, y_pad_before,
- x_pad_after, y_pad_after,
- dst.access_ptr("r", "int32"), mem_type))
- return irb.get()
-
- else:
- raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
-
- return tvm.tir.ir_pass.InjectCopyIntrin(stmt_in, "dma_copy", _inject_copy)
-
-
-def _get_gemm_intrin_buffer():
- env = get_env()
- wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
- assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
- wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
- assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
- inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
- assert inp_lanes == env.BATCH * env.BLOCK_IN
- inp_shape = (env.BATCH, env.BLOCK_IN)
- assert inp_shape[0] * inp_shape[1] == inp_lanes
- out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
- assert out_lanes == env.BATCH * env.BLOCK_OUT
- out_shape = (env.BATCH, env.BLOCK_OUT)
- assert out_shape[0] * out_shape[1] == out_lanes
- wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
- dtype="int%d" % env.WGT_WIDTH,
- name=env.wgt_scope)
- inp = te.placeholder((inp_shape[0], inp_shape[1]),
- dtype="int%d" % env.INP_WIDTH,
- name=env.inp_scope)
- k = te.reduce_axis((0, wgt_shape[1]), name="k")
- out_dtype = "int%d" % env.ACC_WIDTH
- out = te.compute((out_shape[0], out_shape[1]),
- lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
- wgt[j, k].astype(out_dtype),
- axis=[k]),
- name="out")
- wgt_layout = tvm.tir.decl_buffer(
- wgt.shape, wgt.dtype, env.wgt_scope,
- scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
- inp_layout = tvm.tir.decl_buffer(
- inp.shape, inp.dtype, env.inp_scope,
- scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
- out_layout = tvm.tir.decl_buffer(
- out.shape, out.dtype, env.acc_scope,
- scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
-
- return wgt_layout, inp_layout, out_layout
-
-
-def inject_conv2d_transpose_skip(stmt_in):
- """Pass to skip 0-weights in conv2d transpose with stride > 1.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- env = get_env()
- dwgt, dinp, dout = _get_gemm_intrin_buffer()
-
- calls = []
- selects = []
-
- def _find_basics(op):
- if isinstance(op, tvm.tir.BufferLoad):
- calls.append(op)
- elif isinstance(op, tvm.tir.Select):
- selects.append(op)
-
- def _do_fold(op):
- if _match_pragma(op, "conv2d_transpose_gemm"):
- is_init = ".init" in str(op)
- tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
-
- if is_init:
- # create inner most block
- irb = tvm.tir.ir_builder.create()
- dev = env.dev
- irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
- irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
- irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
- 0, 1,
- dout.access_ptr("rw", "int32"),
- 0, 0,
- 0, 0, 0))
- inner = irb.get()
- # TODO(@tmoreau89): This is only a temporary fix, please take a look.
- body = op.body.body
- while isinstance(body, tvm.tir.IfThenElse):
- body = body.then_case
- args = body.indices
- res_buffer = body.buffer
- tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
- inner = tvm.tir.AttrStmt(
- [dout, res_buffer], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
- return inner
- else:
- conv_call, data_call, kernel_call = calls[-3:]
- pad_data_tensor = data_call.buffer
- kernel_tensor = kernel_call.buffer
- res_tensor = conv_call.buffer
-
- if selects:
- condition = selects[0].condition
- else:
- condition = tvm.tir.const(1, 'int')
-
- # create inner most block
- irb = tvm.tir.ir_builder.create()
- with irb.if_scope(condition):
- dev = env.dev
- irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
- irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
- irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
- 0, 0,
- dout.access_ptr("rw", "int32"),
- dinp.access_ptr("r", "int32"),
- dwgt.access_ptr("r", "int32"),
- 0, 0, 0))
- inner = irb.get()
-
- args = conv_call.indices
- tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
- 1, 0, 1, 0, env.BLOCK_OUT)
- inner = tvm.tir.AttrStmt(
- [dout, res_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
- args = kernel_call.indices
- tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
- 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
- inner = tvm.tir.AttrStmt(
- [dwgt, kernel_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
- args = data_call.indices
- tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
- 1, 0, 1, 0, env.BLOCK_IN)
- inner = tvm.tir.AttrStmt(
- [dinp, pad_data_tensor], 'buffer_bind_scope',
- tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
- return inner
- return None
- ret = tvm.tir.ir_pass.IRTransform(
- stmt_in, _do_fold, None, ["AttrStmt"])
- return ret
-
-
-def annotate_alu_coproc_scope(stmt_in):
- """Pass to insert ALU instruction.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- env = get_env()
- def _do_fold(stmt):
- if _match_pragma(stmt, "alu"):
- irb = tvm.tir.ir_builder.create()
- irb.scope_attr(env.dev.vta_axis, "coproc_scope",
- env.dev.get_task_qid(env.dev.QID_COMPUTE))
- irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
- tvm.tir.StringImm("VTAPushALUOp"))
- irb.emit(stmt)
- return irb.get()
- if _match_pragma(stmt, "skip_alu"):
- return tvm.tir.Evaluate(0)
- return stmt
-
- stmt_out = tvm.tir.ir_pass.IRTransform(
- stmt_in, None, _do_fold, ["AttrStmt"])
-
- return stmt_out
-
-
-def inject_alu_intrin(stmt_in):
- """Pass to inject ALU micro-ops.
-
- Parameters
- ----------
- stmt_in : Stmt
- Input statement
-
- Returns
- -------
- stmt_out : Stmt
- Transformed statement
- """
- env = get_env()
- idxm = tvm.tir.indexmod
- analyzer = tvm.arith.Analyzer()
-
- def _do_fold(stmt):
- def _equal(x, y):
- return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
-
- def _flatten_loop(src_coeff, dst_coeff, extents):
- src_coeff = list(src_coeff)
- dst_coeff = list(dst_coeff)
- extents = list(extents)
- rev_src_coeff = [src_coeff.pop()]
- rev_dst_coeff = [dst_coeff.pop()]
- rev_extents = []
- assert src_coeff
- vsrc = src_coeff.pop()
- vdst = dst_coeff.pop()
- vext = extents.pop()
- while src_coeff:
- next_src = src_coeff.pop()
- next_dst = dst_coeff.pop()
- next_ext = extents.pop()
-
- if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
- vext = analyzer.simplify(vext * next_ext)
- else:
- rev_src_coeff.append(vsrc)
- rev_dst_coeff.append(vdst)
- rev_extents.append(vext)
- vsrc = next_src
- vdst = next_dst
- vext = next_ext
- rev_src_coeff.append(vsrc)
- rev_dst_coeff.append(vdst)
- rev_extents.append(vext)
- rev_src_coeff.reverse()
- rev_dst_coeff.reverse()
- rev_extents.reverse()
-
- return rev_src_coeff, rev_dst_coeff, rev_extents
-
- if _match_pragma(stmt, "alu"):
- # Get to the innermost loop body
- loop_body = stmt.body
- nest_size = 0
- while isinstance(loop_body, tvm.tir.For):
- loop_body = loop_body.body
- nest_size += 1
- # Get the src/dst arguments
- dst_var = loop_body.buffer_var
- dst_idx = loop_body.index
- # Derive loop variables and extents
- tmp_body = stmt.body
- indices = []
- extents = []
- for _ in range(nest_size):
- indices.append(tmp_body.loop_var)
- extents.append(tmp_body.extent)
- tmp_body = tmp_body.body
- # Derive opcode
- if isinstance(loop_body.value, tvm.tir.Add):
- alu_opcode = env.dev.ALU_OPCODE_ADD
- lhs = loop_body.value.a
- rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.tir.Sub):
- alu_opcode = env.dev.ALU_OPCODE_SUB
- lhs = loop_body.value.a
- rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.tir.Mul):
- alu_opcode = env.dev.ALU_OPCODE_MUL
- lhs = loop_body.value.a
- rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.tir.Min):
- alu_opcode = env.dev.ALU_OPCODE_MIN
- lhs = loop_body.value.a
- rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.tir.Max):
- alu_opcode = env.dev.ALU_OPCODE_MAX
- lhs = loop_body.value.a
- rhs = loop_body.value.b
- elif isinstance(loop_body.value, tvm.tir.Call):
- if loop_body.value.name == 'shift_left':
- alu_opcode = env.dev.ALU_OPCODE_SHR
- lhs = loop_body.value.args[0]
- rhs = analyzer.simplify(-loop_body.value.args[1])
- elif loop_body.value.name == 'shift_right':
- alu_opcode = env.dev.ALU_OPCODE_SHR
- lhs = loop_body.value.args[0]
- rhs = loop_body.value.args[1]
- else:
- raise RuntimeError(
- "Function call not recognized %s" % (loop_body.value.name))
- elif isinstance(loop_body.value, tvm.tir.Load):
- alu_opcode = env.dev.ALU_OPCODE_SHR
- lhs = loop_body.value
- rhs = tvm.tir.const(0, "int32")
- else:
- raise RuntimeError(
- "Expression not recognized %s, %s, %s" % (
- type(loop_body.value), str(loop_body.value), str(stmt)))
-
- # Derive array index coefficients
- dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
- # Check if lhs/rhs is immediate
- use_imm = False
- imm_val = None
- if isinstance(rhs, tvm.tir.IntImm):
- assert lhs.buffer_var.same_as(dst_var)
- src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
- use_imm = True
- imm_val = rhs
- if isinstance(lhs, tvm.tir.IntImm):
- assert rhs.buffer_var.same_as(dst_var)
- src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
- use_imm = True
- imm_val = lhs
- if imm_val is None:
- imm_val = 0
- assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
- src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
- src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
- # Determine which side has the same coefficients
- lhs_equal = True
- rhs_equal = True
- for i, coef in enumerate(dst_coeff):
- if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
- lhs_equal = False
- if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
- rhs_equal = False
- # Make sure at least one of the source is identical to the
- # destination (in-place computation)
- assert lhs_equal or rhs_equal
- # Assign the source coefficients
- if lhs_equal:
- src_coeff = src_rhs_coeff
- else:
- src_coeff = src_lhs_coeff
-
- # Ensure that we have the proper tensor dimensions in the
- # innermost loop (pattern match)
- src_coeff = list(src_coeff)
- dst_coeff = list(dst_coeff)
- extents = list(extents)
- assert len(src_coeff) > 1
- assert len(dst_coeff) > 1
- assert len(extents) != 0
- assert tvm.ir.structural_equal(
- analyzer.simplify(
- idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
- assert tvm.ir.structural_equal(
- analyzer.simplify(
- idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
- assert tvm.ir.structural_equal(src_coeff[-2], 1)
- assert tvm.ir.structural_equal(dst_coeff[-2], 1)
- if env.BATCH > 1:
- assert len(src_coeff) > 2
- assert len(dst_coeff) > 2
- assert len(extents) > 1
- assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
- assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
-
- # Apply tensorization of the loop coefficients
- src_offset = src_coeff[-1]
- dst_offset = dst_coeff[-1]
- if env.BATCH == 1:
- src_coeff = src_coeff[:-2]
- dst_coeff = dst_coeff[:-2]
- extents = extents[:-1]
- else:
- src_coeff = src_coeff[:-3]
- dst_coeff = dst_coeff[:-3]
- extents = extents[:-2]
- src_coeff.append(src_offset)
- dst_coeff.append(dst_offset)
- src_coeff = [
- analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
- dst_coeff = [
- analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
-
- # Flatten the outer loops
- if extents:
- src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
-
- # Insert ALU micro-ops
- irb = tvm.tir.ir_builder.create()
- for idx, extent in enumerate(extents):
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopLoopBegin",
- extent, dst_coeff[idx], src_coeff[idx], 0))
- use_imm = int(use_imm)
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopPush",
- 1, 0,
- dst_coeff[len(dst_coeff)-1],
- src_coeff[len(src_coeff)-1],
- 0,
- alu_opcode, use_imm, imm_val))
- for extent in extents:
- irb.emit(tvm.tir.call_extern(
- "int32", "VTAUopLoopEnd"))
- return irb.get()
- return stmt
-
- stmt_out = tvm.tir.ir_pass.IRTransform(
- stmt_in, None, _do_fold, ["AttrStmt"])
- return stmt_out
-
-
-def debug_print(stmt):
- """A debug pass that print the stmt
-
- Parameters
- ----------
- stmt : Stmt
- The input statement
-
- Returns
- -------
- stmt : Stmt
- The
- """
- # pylint: disable=superfluous-parens
- print(stmt)
- return stmt
diff --git a/vta/python/vta/transform.py b/vta/python/vta/transform.py
new file mode 100644
index 0000000..f930b3f
--- /dev/null
+++ b/vta/python/vta/transform.py
@@ -0,0 +1,962 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Additional Transformation Passes. for VTA"""
+# pylint: disable=len-as-condition, no-else-return, unused-argument, invalid-name
+import tvm
+from tvm import te
+from topi import util
+
+from .environment import get_env
+
+
+def _match_pragma(stmt, key):
+ """Internal helper to match stmt to pragma stmt.
+
+ Parameters
+ ----------
+ stmt : Stmt
+ The AttrStmt
+
+ key : str
+ The pragma key
+ """
+ return ((stmt.attr_key == "pragma_" + key) or
+ (stmt.attr_key == "pragma_scope" and stmt.value.value == key))
+
+
+def FoldUopLoop():
+ """Detect and fold uop loop.
+
+ VTA supports a uop programming model
+ that recognizes loop structure.
+ This pass detects the loop structure
+ and extracts it into the uop loop AST.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _fold_outermost_loop(body):
+ stmt = body
+ if not isinstance(stmt, tvm.tir.For):
+ return None, body, None
+
+ loop_var = stmt.loop_var
+ gemm_offsets = [None, None, None]
+ fail = [False]
+
+ def _post_order(op):
+ assert isinstance(op, tvm.tir.Call)
+ base_args = 2
+ if op.name == "VTAUopPush":
+ args = []
+ args += op.args[:base_args]
+ for i in range(3):
+ m = tvm.arith.detect_linear_equation(
+ op.args[i + base_args], [loop_var])
+ if not m:
+ fail[0] = True
+ return op
+ if gemm_offsets[i] is not None:
+ if not tvm.ir.structural_equal(m[0], gemm_offsets[i]):
+ fail[0] = True
+ return op
+ args.append(m[1])
+ else:
+ gemm_offsets[i] = m[0]
+ args.append(m[1])
+ args += op.args[base_args+3:]
+ return tvm.tir.call_extern("int32", "VTAUopPush", *args)
+ if op.name not in ("VTATLSCommandHandle", "tvm_thread_context"):
+ raise RuntimeError("unexpected op %s" % op)
+ return op
+
+ ret = tvm.tir.ir_pass.IRTransform(
+ stmt.body, None, _post_order, ["Call"])
+
+ if not fail[0] and all(x is not None for x in gemm_offsets):
+ def _visit(op):
+ if op.same_as(loop_var):
+ fail[0] = True
+ tvm.tir.ir_pass.PostOrderVisit(ret, _visit)
+ if not fail[0]:
+ begin = tvm.tir.call_extern(
+ "int32", "VTAUopLoopBegin", stmt.extent, *gemm_offsets)
+ end = tvm.tir.call_extern("int32", "VTAUopLoopEnd")
+ return [begin, ret, end]
+ raise ValueError("Failed to fold the GEMM instructions..")
+
+ def _do_fold(stmt):
+ env = get_env()
+ if (stmt.attr_key == "coproc_uop_scope" and
+ isinstance(stmt.value, tvm.tir.StringImm) and
+ stmt.value.value == env.dev.vta_push_uop.value):
+ body = stmt.body
+ begins = []
+ ends = []
+ try:
+ begin, body, end = _fold_outermost_loop(body)
+ if begin is not None:
+ begins.append(begin)
+ if end is not None:
+ ends.append(end)
+ begin, body, end = _fold_outermost_loop(body)
+ if begin is not None:
+ begins.append(begin)
+ if end is not None:
+ ends.append(end)
+ except ValueError:
+ pass
+ if body == stmt.body:
+ return stmt
+ ends = list(reversed(ends))
+ body = tvm.tir.stmt_seq(*(begins + [body] + ends))
+ return tvm.tir.AttrStmt(
+ stmt.node, stmt.attr_key, stmt.value, body)
+ return None
+
+ def _ftransform(f, mod, ctx):
+ return f.with_body(tvm.tir.ir_pass.IRTransform(
+ f.body, _do_fold, None, ["AttrStmt"]))
+
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.FoldUopLoop")
+
+
+def CPUAccessRewrite():
+ """Detect CPU access to VTA buffer and get address correctly.
+
+ VTA's buffer is an opaque handle that do not
+ correspond to address in CPU.
+ This pass detect CPU access and rewrite to use pointer
+ returned VTABufferCPUPtr for CPU access.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(f, mod, ctx):
+ rw_info = {}
+ env = get_env()
+ def _post_order(op):
+ if isinstance(op, tvm.tir.Allocate):
+ buffer_var = op.buffer_var
+ if not buffer_var in rw_info:
+ return None
+ new_var = rw_info[buffer_var]
+ let_stmt = tvm.tir.LetStmt(
+ new_var, tvm.tir.call_extern(
+ "handle", "VTABufferCPUPtr",
+ env.dev.command_handle,
+ buffer_var), op.body)
+ alloc = tvm.tir.Allocate(
+ buffer_var, op.dtype, op.extents,
+ op.condition, let_stmt)
+ del rw_info[buffer_var]
+ return alloc
+ if isinstance(op, tvm.tir.Load):
+ buffer_var = op.buffer_var
+ if not buffer_var in rw_info:
+ rw_info[buffer_var] = te.var(
+ buffer_var.name + "_ptr", "handle")
+ new_var = rw_info[buffer_var]
+ return tvm.tir.Load(op.dtype, new_var, op.index)
+ if isinstance(op, tvm.tir.Store):
+ buffer_var = op.buffer_var
+ if not buffer_var in rw_info:
+ rw_info[buffer_var] = te.var(
+ buffer_var.name + "_ptr", "handle")
+ new_var = rw_info[buffer_var]
+ return tvm.tir.Store(new_var, op.value, op.index)
+ raise RuntimeError("not reached")
+
+ stmt_in = f.body
+ stmt = tvm.tir.ir_pass.IRTransform(
+ stmt_in, None, _post_order, ["Allocate", "Load", "Store"])
+
+ for buffer_var, new_var in rw_info.items():
+ stmt = tvm.tir.LetStmt(
+ new_var, tvm.tir.call_extern(
+ "handle", "VTABufferCPUPtr",
+ env.dev.command_handle,
+ buffer_var), stmt)
+ return f.with_body(stmt)
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.CPUAccessRewrite")
+
+
+def LiftAllocToScopeBegin():
+ """Lift allocate to beginning of the current scope.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(f, mod, ctx):
+ lift_stmt = [[]]
+ def _merge_block(slist, body):
+ for op in slist:
+ if op.body == body:
+ body = op
+ elif isinstance(op, tvm.tir.Allocate):
+ body = tvm.tir.Allocate(
+ op.buffer_var, op.dtype,
+ op.extents, op.condition, body)
+ elif isinstance(op, tvm.tir.AttrStmt):
+ body = tvm.tir.AttrStmt(
+ op.node, op.attr_key, op.value, body)
+ elif isinstance(op, tvm.tir.For):
+ body = tvm.tir.For(
+ op.loop_var, op.min, op.extent, op.for_type,
+ op.device_api, body)
+ else:
+ raise RuntimeError("unexpected op")
+ del slist[:]
+ return body
+
+ def _pre_order(op):
+ if isinstance(op, tvm.tir.For):
+ lift_stmt.append([])
+ elif isinstance(op, tvm.tir.AttrStmt):
+ if op.attr_key == "virtual_thread":
+ lift_stmt.append([])
+
+ def _post_order(op):
+ if isinstance(op, tvm.tir.Allocate):
+ lift_stmt[-1].append(op)
+ return op.body
+ if isinstance(op, tvm.tir.AttrStmt):
+ if op.attr_key == "storage_scope":
+ lift_stmt[-1].append(op)
+ return op.body
+ if op.attr_key == "virtual_thread":
+ return _merge_block(lift_stmt.pop() + [op], op.body)
+ return op
+ if isinstance(op, tvm.tir.For):
+ return _merge_block(lift_stmt.pop() + [op], op.body)
+ raise RuntimeError("not reached")
+ stmt_in = f.body
+ stmt = tvm.tir.ir_pass.IRTransform(
+ stmt_in, _pre_order, _post_order, ["Allocate", "AttrStmt", "For"])
+ assert len(lift_stmt) == 1
+ return f.with_body(_merge_block(lift_stmt[0], stmt))
+
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.LiftAllocToScopeBegin")
+
+
+def InjectSkipCopy():
+ """Pass to inject skip copy stmt, used for debug purpose.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _do_fold(stmt):
+ if _match_pragma(stmt, "skip_dma_copy"):
+ return tvm.tir.Evaluate(0)
+ return None
+
+ def _ftransform(f, mod, ctx):
+ return f.with_body(tvm.tir.ir_pass.IRTransform(
+ f.body, _do_fold, None, ["AttrStmt"]))
+
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.InjectSkipCopy")
+
+
+def InjectCoProcSync():
+ """Pass inject coproc sync
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(f, *_):
+ success = [False]
+ def _do_fold(stmt):
+ if _match_pragma(stmt, "coproc_sync"):
+ success[0] = True
+ sync = tvm.tir.Call(
+ "int32", "vta.coproc_sync", [], tvm.tir.Call.Intrinsic, None, 0)
+ return tvm.tir.SeqStmt([stmt.body, tvm.tir.Evaluate(sync)])
+ if _match_pragma(stmt, "trim_loop"):
+ op = stmt.body
+ assert isinstance(op, tvm.tir.For)
+ return tvm.tir.For(
+ op.loop_var, op.min, 2, op.for_type,
+ op.device_api, op.body)
+ return None
+ return f.with_body(tvm.tir.ir_pass.IRTransform(
+ f.body, None, _do_fold, ["AttrStmt"]))
+ return tvm.transform.Sequential(
+ [tvm.tir.transform.prim_func_pass(_ftransform, 0, "tir.vta.InjectCoProcSync"),
+ tvm.tir.transform.CoProcSync()],
+ opt_level=0, name="tir.vta.InjectCoProcSync")
+
+
+def InjectDMAIntrin():
+ """Pass to inject DMA copy intrinsics.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ idxd = tvm.tir.indexdiv
+ idxm = tvm.tir.indexmod
+
+ def _check_compact(buf):
+ ndim = len(buf.shape)
+ size = tvm.tir.const(1, buf.shape[0].dtype)
+ for i in reversed(range(ndim)):
+ if not util.equal_const_int(size - buf.strides[i], 0):
+ raise RuntimeError(
+ "Cannot prove compact: shape=%s, strides=%s" % (buf.shape, buf.strides))
+ size = size * buf.shape[i]
+
+ def _fold_buffer_dim(buf, scope, elem_block):
+ ndim = len(buf.shape)
+ x_size = 1
+ base = 0
+ for i in range(1, ndim + 1):
+ if not util.equal_const_int(buf.strides[ndim - i] - x_size, 0):
+ raise RuntimeError("scope %s needs to have block=%d" % (scope, elem_block))
+ x_size = x_size * buf.shape[ndim - i]
+ if util.equal_const_int(x_size - elem_block, 0):
+ base = i + 1
+ break
+ if base == 0:
+ raise RuntimeError("scope %s need to have block=%d, shape=%s" % (
+ scope, elem_block, buf.shape))
+ shape = [elem_block]
+ strides = [1]
+
+ if base < ndim + 1 and not util.equal_const_int(buf.strides[ndim - base], elem_block):
+ shape.append(1)
+ strides.append(elem_block)
+
+ analyzer = tvm.arith.Analyzer()
+ while base < ndim + 1:
+ x_size = 1
+ x_stride = buf.strides[ndim - base]
+ next_base = base
+ if not util.equal_const_int(idxm(x_stride, elem_block), 0):
+ raise RuntimeError(
+ "scope %s need to have block=%d, shape=%s, strides=%s" % (
+ scope, elem_block, buf.shape, buf.strides))
+ for i in range(base, ndim + 1):
+ k = ndim - i
+ if not util.equal_const_int(x_size * x_stride - buf.strides[k], 0):
+ break
+ x_size = x_size * buf.shape[k]
+ next_base = i + 1
+ shape.append(analyzer.simplify(x_size))
+ strides.append(x_stride)
+ assert next_base != base
+ base = next_base
+
+ strides = list(reversed(strides))
+ shape = list(reversed(shape))
+ return shape, strides
+
+ def _get_2d_pattern(buf, elem_width, elem_bytes, dtype, scope, allow_fold):
+ elem_block = elem_bytes * 8 // elem_width
+ if buf.dtype != dtype:
+ raise RuntimeError("Expect buffer type to be %s instead of %s" %
+ (dtype, buf.dtype))
+ shape, strides = buf.shape, buf.strides
+ if not util.equal_const_int(idxm(buf.elem_offset, elem_block), 0):
+ raise RuntimeError("scope %s need to have block=%d" % (scope, elem_block))
+ if allow_fold:
+ shape, strides = _fold_buffer_dim(buf, scope, elem_block)
+ else:
+ shape = list(x for x in shape)
+ strides = list(x for x in strides)
+
+ def raise_error():
+ """Internal function to raise error """
+ raise RuntimeError(
+ ("Scope[%s]: cannot detect 2d pattern with elem_block=%d:" +
+ " shape=%s, strides=%s") % (scope, elem_block, buf.shape, buf.strides))
+
+ ndim = len(shape)
+
+ # Check if the inner-tensor is already flat
+ flat = util.equal_const_int(shape[-1], elem_block)
+
+ if flat:
+ if not util.equal_const_int(strides[-1], 1):
+ raise_error()
+
+ if ndim == 1:
+ x_size = 1
+ x_stride = 1
+ y_size = 1
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+ if not util.equal_const_int(strides[-2] - elem_block, 0):
+ raise_error()
+
+ if ndim == 2:
+ x_size = shape[-2]
+ x_stride = shape[-2]
+ y_size = 1
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+ if not util.equal_const_int(idxm(strides[-3], elem_block), 0):
+ raise_error()
+
+ if ndim == 3:
+ x_size = shape[-2]
+ x_stride = idxd(strides[-3], elem_block)
+ y_size = shape[-3]
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+
+ else:
+ if not util.equal_const_int(strides[-1], 1):
+ raise_error()
+ if not util.equal_const_int(strides[-2] - shape[-1], 0):
+ raise_error()
+ if not util.equal_const_int(shape[-1] * shape[-2], elem_block):
+ raise_error()
+
+ if ndim == 2:
+ x_size = 1
+ x_stride = 1
+ y_size = 1
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+ if not util.equal_const_int(strides[-3], elem_block):
+ raise_error()
+
+ if ndim == 3:
+ x_size = shape[-3]
+ x_stride = shape[-3]
+ y_size = 1
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+ if not util.equal_const_int(idxm(strides[-4], elem_block), 0):
+ raise_error()
+
+ if ndim == 4:
+ x_size = shape[-3]
+ x_stride = idxd(strides[-4], elem_block)
+ y_size = shape[-4]
+ return x_size, y_size, x_stride, idxd(buf.elem_offset, elem_block)
+
+ raise_error()
+
+
+ def _inject_copy(src, dst, pad_before, pad_after, pad_value):
+ # FIXME: pad_value is ignored...
+ env = get_env()
+ _ = pad_value
+ if dst.scope == "global":
+ # Store
+ if pad_before or pad_after:
+ raise RuntimeError("Do not support copy into DRAM with pad")
+ if src.scope == env.acc_scope:
+ elem_width = env.OUT_WIDTH
+ elem_bytes = env.OUT_ELEM_BYTES
+ mem_type = env.dev.MEM_ID_OUT
+ data_type = "int%d" % env.OUT_WIDTH
+ task_qid = env.dev.QID_STORE_OUT
+ else:
+ raise RuntimeError("Do not support copy %s->dram" % (src.scope))
+ _check_compact(src)
+ x_size, y_size, x_stride, offset = _get_2d_pattern(
+ dst, elem_width, elem_bytes, data_type, src.scope, allow_fold=True)
+ irb = tvm.tir.ir_builder.create()
+ irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+ env.dev.get_task_qid(task_qid))
+ irb.emit(tvm.tir.call_extern(
+ "int32", "VTAStoreBuffer2D",
+ env.dev.command_handle,
+ src.access_ptr("r", "int32"),
+ mem_type, dst.data, offset, x_size, y_size, x_stride))
+ return irb.get()
+ elif src.scope == "global":
+ if dst.scope == env.acc_scope:
+ elem_width = env.ACC_WIDTH
+ elem_bytes = env.ACC_ELEM_BYTES
+ mem_type = env.dev.MEM_ID_ACC
+ data_type = "int%d" % env.ACC_WIDTH
+ task_qid = env.dev.QID_LOAD_OUT
+ elif dst.scope == env.inp_scope:
+ elem_width = env.INP_WIDTH
+ elem_bytes = env.INP_ELEM_BYTES
+ mem_type = env.dev.MEM_ID_INP
+ data_type = "int%d" % env.INP_WIDTH
+ task_qid = env.dev.QID_LOAD_INP
+ elif dst.scope == env.wgt_scope:
+ elem_width = env.WGT_WIDTH
+ elem_bytes = env.WGT_ELEM_BYTES
+ mem_type = env.dev.MEM_ID_WGT
+ data_type = "int%d" % env.WGT_WIDTH
+ task_qid = env.dev.QID_LOAD_WGT
+ else:
+ raise RuntimeError("Do not support copy dram->%s" % (dst.scope))
+ # collect pad statistics
+ if pad_before:
+ assert pad_after
+ ndim = len(pad_before)
+ if ndim <= 2 or ndim > 5:
+ raise ValueError("Limitation of 2D pad load forbid ndim=%d" % ndim)
+ if ndim == 5:
+ # This case occurs when batch size N > 1
+ y_pad_before = pad_before[1]
+ x_pad_before = pad_before[2]
+ y_pad_after = pad_after[1]
+ x_pad_after = pad_after[2]
+ for dim in range(3, ndim):
+ if not util.equal_const_int(pad_before[dim], 0):
+ raise ValueError("Padding on the innermost block is not supported")
+ if not util.equal_const_int(pad_after[dim], 0):
+ raise ValueError("Padding on the innermost block is not supported")
+ else:
+ y_pad_before = pad_before[0]
+ x_pad_before = pad_before[1]
+ y_pad_after = pad_after[0]
+ x_pad_after = pad_after[1]
+ for dim in range(2, ndim):
+ if not util.equal_const_int(pad_before[dim], 0):
+ raise ValueError("Padding on the innermost block is not supported")
+ if not util.equal_const_int(pad_after[dim], 0):
+ raise ValueError("Padding on the innermost block is not supported")
+ allow_fold = False
+ else:
+ x_pad_before = 0
+ y_pad_before = 0
+ x_pad_after = 0
+ y_pad_after = 0
+ allow_fold = True
+
+ _check_compact(dst)
+ x_size, y_size, x_stride, offset = _get_2d_pattern(
+ src, elem_width, elem_bytes, data_type,
+ dst.scope, allow_fold=allow_fold)
+
+ irb = tvm.tir.ir_builder.create()
+ irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+ env.dev.get_task_qid(task_qid))
+
+ irb.emit(tvm.tir.call_extern(
+ "int32", "VTALoadBuffer2D",
+ env.dev.command_handle,
+ src.data, offset, x_size, y_size, x_stride,
+ x_pad_before, y_pad_before,
+ x_pad_after, y_pad_after,
+ dst.access_ptr("r", "int32"), mem_type))
+ return irb.get()
+
+ else:
+ raise RuntimeError("Do not support copy %s->%s" % (src.scope, dst.scope))
+
+ return tvm.tir.transform.InjectCopyIntrin("dma_copy", _inject_copy)
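+
+ # Usage sketch (illustrative; assumes the enclosing module-level function
+ # is exposed as InjectDMAIntrin): the returned object is a
+ # tvm.transform.Pass, so it can be applied directly to an IRModule of
+ # PrimFuncs:
+ #
+ # mod = InjectDMAIntrin()(mod) # rewrites "dma_copy" pragma regions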
+
+
+def _get_gemm_intrin_buffer():
+ env = get_env()
+ wgt_lanes = env.WGT_ELEM_BITS // env.WGT_WIDTH
+ assert wgt_lanes == env.BLOCK_OUT * env.BLOCK_IN
+ wgt_shape = (env.BLOCK_OUT, env.BLOCK_IN)
+ assert wgt_shape[0] * wgt_shape[1] == wgt_lanes
+ inp_lanes = env.INP_ELEM_BITS // env.INP_WIDTH
+ assert inp_lanes == env.BATCH * env.BLOCK_IN
+ inp_shape = (env.BATCH, env.BLOCK_IN)
+ assert inp_shape[0] * inp_shape[1] == inp_lanes
+ out_lanes = env.ACC_ELEM_BITS // env.ACC_WIDTH
+ assert out_lanes == env.BATCH * env.BLOCK_OUT
+ out_shape = (env.BATCH, env.BLOCK_OUT)
+ assert out_shape[0] * out_shape[1] == out_lanes
+ wgt = te.placeholder((wgt_shape[0], wgt_shape[1]),
+ dtype="int%d" % env.WGT_WIDTH,
+ name=env.wgt_scope)
+ inp = te.placeholder((inp_shape[0], inp_shape[1]),
+ dtype="int%d" % env.INP_WIDTH,
+ name=env.inp_scope)
+ k = te.reduce_axis((0, wgt_shape[1]), name="k")
+ out_dtype = "int%d" % env.ACC_WIDTH
+ out = te.compute((out_shape[0], out_shape[1]),
+ lambda i, j: te.sum(inp[i, k].astype(out_dtype) *
+ wgt[j, k].astype(out_dtype),
+ axis=[k]),
+ name="out")
+ wgt_layout = tvm.tir.decl_buffer(
+ wgt.shape, wgt.dtype, env.wgt_scope,
+ scope=env.wgt_scope, offset_factor=wgt_lanes, data_alignment=wgt_lanes)
+ inp_layout = tvm.tir.decl_buffer(
+ inp.shape, inp.dtype, env.inp_scope,
+ scope=env.inp_scope, offset_factor=inp_lanes, data_alignment=inp_lanes)
+ out_layout = tvm.tir.decl_buffer(
+ out.shape, out.dtype, env.acc_scope,
+ scope=env.acc_scope, offset_factor=out_lanes, data_alignment=out_lanes)
+
+ return wgt_layout, inp_layout, out_layout
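+
+ # Note (illustrative, assuming the default VTA configuration: BATCH=1,
+ # BLOCK_IN=BLOCK_OUT=16, 8-bit inp/wgt, 32-bit acc): the declarations
+ # above describe a (16, 16) int8 weight tile, a (1, 16) int8 input tile
+ # and a (1, 16) int32 output tile; offset_factor keeps buffer offsets
+ # aligned to whole tiles.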
+
+
+def InjectConv2DTransposeSkip():
+ """Pass to skip 0-weights in conv2d transpose with stride > 1.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(func, mod, ctx):
+ env = get_env()
+ dwgt, dinp, dout = _get_gemm_intrin_buffer()
+
+ calls = []
+ selects = []
+
+ def _find_basics(op):
+ if isinstance(op, tvm.tir.BufferLoad):
+ calls.append(op)
+ elif isinstance(op, tvm.tir.Select):
+ selects.append(op)
+
+ def _do_fold(op):
+ if _match_pragma(op, "conv2d_transpose_gemm"):
+ is_init = ".init" in str(op)
+ tvm.tir.ir_pass.PostOrderVisit(op, _find_basics)
+
+ if is_init:
+ # create the innermost block
+ irb = tvm.tir.ir_builder.create()
+ dev = env.dev
+ irb.scope_attr(dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+ irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
+ irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+ 0, 1,
+ dout.access_ptr("rw", "int32"),
+ 0, 0,
+ 0, 0, 0))
+ inner = irb.get()
+ # TODO(@tmoreau89): This is only a temporary fix, please take a look.
+ body = op.body.body
+ while isinstance(body, tvm.tir.IfThenElse):
+ body = body.then_case
+ args = body.indices
+ res_buffer = body.buffer
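+ # The "buffer_bind_scope" AttrStmt below views a slice of the result
+ # buffer through the intrinsic buffer dout; tvm_tuple packs one
+ # (min, extent) pair per dimension of the bound region.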
+ tpl = (args[0], 1, args[1], 1, args[2], 1, args[3], 1, 0, 1, 0, env.BLOCK_OUT)
+ inner = tvm.tir.AttrStmt(
+ [dout, res_buffer], 'buffer_bind_scope',
+ tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ return inner
+ else:
+ conv_call, data_call, kernel_call = calls[-3:]
+ pad_data_tensor = data_call.buffer
+ kernel_tensor = kernel_call.buffer
+ res_tensor = conv_call.buffer
+
+ if selects:
+ condition = selects[0].condition
+ else:
+ condition = tvm.tir.const(1, 'int')
+
+ # create the innermost block
+ irb = tvm.tir.ir_builder.create()
+ with irb.if_scope(condition):
+ dev = env.dev
+ irb.scope_attr(
+ dev.vta_axis, "coproc_scope", dev.get_task_qid(dev.QID_COMPUTE))
+ irb.scope_attr(dev.vta_axis, "coproc_uop_scope", dev.vta_push_uop)
+ irb.emit(tvm.tir.call_extern("int32", "VTAUopPush",
+ 0, 0,
+ dout.access_ptr("rw", "int32"),
+ dinp.access_ptr("r", "int32"),
+ dwgt.access_ptr("r", "int32"),
+ 0, 0, 0))
+ inner = irb.get()
+
+ args = conv_call.indices
+ tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+ 1, 0, 1, 0, env.BLOCK_OUT)
+ inner = tvm.tir.AttrStmt(
+ [dout, res_tensor], 'buffer_bind_scope',
+ tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ args = kernel_call.indices
+ tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+ 1, 0, env.BLOCK_OUT, 0, env.BLOCK_IN)
+ inner = tvm.tir.AttrStmt(
+ [dwgt, kernel_tensor], 'buffer_bind_scope',
+ tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ args = data_call.indices
+ tpl = (args[0], 1, args[1], 1, args[2], 1, args[3],
+ 1, 0, 1, 0, env.BLOCK_IN)
+ inner = tvm.tir.AttrStmt(
+ [dinp, pad_data_tensor], 'buffer_bind_scope',
+ tvm.tir.call_intrin('handle', 'tvm_tuple', *tpl), inner)
+ return inner
+ return None
+
+ return func.with_body(tvm.tir.ir_pass.IRTransform(
+ func.body, _do_fold, None, ["AttrStmt"]))
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.InjectConv2DTrasnposeSkip")
+
+
+def AnnotateALUCoProcScope():
+ """Pass to insert ALU instruction.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(func, mod, ctx):
+ env = get_env()
+ def _do_fold(stmt):
+ if _match_pragma(stmt, "alu"):
+ irb = tvm.tir.ir_builder.create()
+ irb.scope_attr(env.dev.vta_axis, "coproc_scope",
+ env.dev.get_task_qid(env.dev.QID_COMPUTE))
+ irb.scope_attr(env.dev.vta_axis, "coproc_uop_scope",
+ tvm.tir.StringImm("VTAPushALUOp"))
+ irb.emit(stmt)
+ return irb.get()
+ if _match_pragma(stmt, "skip_alu"):
+ return tvm.tir.Evaluate(0)
+ return stmt
+
+ return func.with_body(tvm.tir.ir_pass.IRTransform(
+ func.body, None, _do_fold, ["AttrStmt"]))
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.AnnotateALUCoProcScope")
+
+
+def InjectALUIntrin():
+ """Pass to inject ALU micro-ops.
+
+ Returns
+ -------
+ fpass : tvm.transform.Pass
+ The pass
+ """
+ def _ftransform(func, mod, ctx):
+ env = get_env()
+ idxm = tvm.tir.indexmod
+ analyzer = tvm.arith.Analyzer()
+
+ def _do_fold(stmt):
+ def _equal(x, y):
+ return tvm.ir.structural_equal(analyzer.simplify(x - y), 0)
+
+ def _flatten_loop(src_coeff, dst_coeff, extents):
+ src_coeff = list(src_coeff)
+ dst_coeff = list(dst_coeff)
+ extents = list(extents)
+ rev_src_coeff = [src_coeff.pop()]
+ rev_dst_coeff = [dst_coeff.pop()]
+ rev_extents = []
+ assert src_coeff
+ vsrc = src_coeff.pop()
+ vdst = dst_coeff.pop()
+ vext = extents.pop()
+ while src_coeff:
+ next_src = src_coeff.pop()
+ next_dst = dst_coeff.pop()
+ next_ext = extents.pop()
+
+ if _equal(next_src, vsrc * vext) and _equal(next_dst, vdst * vext):
+ vext = analyzer.simplify(vext * next_ext)
+ else:
+ rev_src_coeff.append(vsrc)
+ rev_dst_coeff.append(vdst)
+ rev_extents.append(vext)
+ vsrc = next_src
+ vdst = next_dst
+ vext = next_ext
+ rev_src_coeff.append(vsrc)
+ rev_dst_coeff.append(vdst)
+ rev_extents.append(vext)
+ rev_src_coeff.reverse()
+ rev_dst_coeff.reverse()
+ rev_extents.reverse()
+
+ return rev_src_coeff, rev_dst_coeff, rev_extents
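+
+ # Worked example (illustrative): with extents=[2, 8] and coefficients
+ # [8, 1, 0] (outer, inner, trailing offset) for both src and dst, the
+ # two loops merge because 8 == 1 * 8, leaving src_coeff=[1, 0],
+ # dst_coeff=[1, 0], extents=[16].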
+
+ if _match_pragma(stmt, "alu"):
+ # Get to the innermost loop body
+ loop_body = stmt.body
+ nest_size = 0
+ while isinstance(loop_body, tvm.tir.For):
+ loop_body = loop_body.body
+ nest_size += 1
+ # Get the src/dst arguments
+ dst_var = loop_body.buffer_var
+ dst_idx = loop_body.index
+ # Derive loop variables and extents
+ tmp_body = stmt.body
+ indices = []
+ extents = []
+ for _ in range(nest_size):
+ indices.append(tmp_body.loop_var)
+ extents.append(tmp_body.extent)
+ tmp_body = tmp_body.body
+ # Derive opcode
+ if isinstance(loop_body.value, tvm.tir.Add):
+ alu_opcode = env.dev.ALU_OPCODE_ADD
+ lhs = loop_body.value.a
+ rhs = loop_body.value.b
+ elif isinstance(loop_body.value, tvm.tir.Sub):
+ alu_opcode = env.dev.ALU_OPCODE_SUB
+ lhs = loop_body.value.a
+ rhs = loop_body.value.b
+ elif isinstance(loop_body.value, tvm.tir.Mul):
+ alu_opcode = env.dev.ALU_OPCODE_MUL
+ lhs = loop_body.value.a
+ rhs = loop_body.value.b
+ elif isinstance(loop_body.value, tvm.tir.Min):
+ alu_opcode = env.dev.ALU_OPCODE_MIN
+ lhs = loop_body.value.a
+ rhs = loop_body.value.b
+ elif isinstance(loop_body.value, tvm.tir.Max):
+ alu_opcode = env.dev.ALU_OPCODE_MAX
+ lhs = loop_body.value.a
+ rhs = loop_body.value.b
+ elif isinstance(loop_body.value, tvm.tir.Call):
+ if loop_body.value.name == 'shift_left':
+ alu_opcode = env.dev.ALU_OPCODE_SHR
+ lhs = loop_body.value.args[0]
+ rhs = analyzer.simplify(-loop_body.value.args[1])
+ elif loop_body.value.name == 'shift_right':
+ alu_opcode = env.dev.ALU_OPCODE_SHR
+ lhs = loop_body.value.args[0]
+ rhs = loop_body.value.args[1]
+ else:
+ raise RuntimeError(
+ "Unrecognized function call: %s" % (loop_body.value.name))
+ elif isinstance(loop_body.value, tvm.tir.Load):
+ alu_opcode = env.dev.ALU_OPCODE_SHR
+ lhs = loop_body.value
+ rhs = tvm.tir.const(0, "int32")
+ else:
+ raise RuntimeError(
+ "Unrecognized expression: %s, %s, %s" % (
+ type(loop_body.value), str(loop_body.value), str(stmt)))
+
+ # Derive array index coefficients
+ dst_coeff = tvm.arith.detect_linear_equation(dst_idx, indices)
+ # Check if lhs/rhs is immediate
+ use_imm = False
+ imm_val = None
+ if isinstance(rhs, tvm.tir.IntImm):
+ assert lhs.buffer_var.same_as(dst_var)
+ src_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
+ use_imm = True
+ imm_val = rhs
+ if isinstance(lhs, tvm.tir.IntImm):
+ assert rhs.buffer_var.same_as(dst_var)
+ src_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+ use_imm = True
+ imm_val = lhs
+ if imm_val is None:
+ imm_val = 0
+ assert lhs.buffer_var.same_as(dst_var) and rhs.buffer_var.same_as(dst_var)
+ src_lhs_coeff = tvm.arith.detect_linear_equation(lhs.index, indices)
+ src_rhs_coeff = tvm.arith.detect_linear_equation(rhs.index, indices)
+ # Determine which side has the same coefficients
+ lhs_equal = True
+ rhs_equal = True
+ for i, coef in enumerate(dst_coeff):
+ if not tvm.ir.structural_equal(coef, src_lhs_coeff[i]):
+ lhs_equal = False
+ if not tvm.ir.structural_equal(coef, src_rhs_coeff[i]):
+ rhs_equal = False
+ # Make sure at least one of the source is identical to the
+ # destination (in-place computation)
+ assert lhs_equal or rhs_equal
+ # Assign the source coefficients
+ if lhs_equal:
+ src_coeff = src_rhs_coeff
+ else:
+ src_coeff = src_lhs_coeff
+
+ # Ensure that we have the proper tensor dimensions in the
+ # innermost loop (pattern match)
+ src_coeff = list(src_coeff)
+ dst_coeff = list(dst_coeff)
+ extents = list(extents)
+ assert len(src_coeff) > 1
+ assert len(dst_coeff) > 1
+ assert len(extents) != 0
+ assert tvm.ir.structural_equal(
+ analyzer.simplify(
+ idxm(src_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+ assert tvm.ir.structural_equal(
+ analyzer.simplify(
+ idxm(dst_coeff[-1], env.BATCH * env.BLOCK_OUT)), 0)
+ assert tvm.ir.structural_equal(src_coeff[-2], 1)
+ assert tvm.ir.structural_equal(dst_coeff[-2], 1)
+ if env.BATCH > 1:
+ assert len(src_coeff) > 2
+ assert len(dst_coeff) > 2
+ assert len(extents) > 1
+ assert tvm.ir.structural_equal(src_coeff[-3], env.BLOCK_OUT)
+ assert tvm.ir.structural_equal(dst_coeff[-3], env.BLOCK_OUT)
+
+ # Apply tensorization of the loop coefficients
+ src_offset = src_coeff[-1]
+ dst_offset = dst_coeff[-1]
+ if env.BATCH == 1:
+ src_coeff = src_coeff[:-2]
+ dst_coeff = dst_coeff[:-2]
+ extents = extents[:-1]
+ else:
+ src_coeff = src_coeff[:-3]
+ dst_coeff = dst_coeff[:-3]
+ extents = extents[:-2]
+ src_coeff.append(src_offset)
+ dst_coeff.append(dst_offset)
+ src_coeff = [
+ analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in src_coeff]
+ dst_coeff = [
+ analyzer.simplify(c // (env.BATCH * env.BLOCK_OUT)) for c in dst_coeff]
+
+ # Flatten the outer loops
+ if extents:
+ src_coeff, dst_coeff, extents = _flatten_loop(src_coeff, dst_coeff, extents)
+
+ # Insert ALU micro-ops
+ irb = tvm.tir.ir_builder.create()
+ for idx, extent in enumerate(extents):
+ irb.emit(tvm.tir.call_extern(
+ "int32", "VTAUopLoopBegin",
+ extent, dst_coeff[idx], src_coeff[idx], 0))
+ use_imm = int(use_imm)
+ irb.emit(tvm.tir.call_extern(
+ "int32", "VTAUopPush",
+ 1, 0,
+ dst_coeff[-1],
+ src_coeff[-1],
+ 0,
+ alu_opcode, use_imm, imm_val))
+ for _ in extents:
+ irb.emit(tvm.tir.call_extern(
+ "int32", "VTAUopLoopEnd"))
+ return irb.get()
+ return stmt
+
+ return func.with_body(tvm.tir.ir_pass.IRTransform(
+ func.body, None, _do_fold, ["AttrStmt"]))
+
+ return tvm.tir.transform.prim_func_pass(
+ _ftransform, opt_level=0, name="tir.vta.InjectALUIntrin")