You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/06/23 05:20:16 UTC
[incubator-tvm] branch master updated: [RFC] Improve quantized
convolution performance for armv8 architectures (#5754)
This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new b94e8b7 [RFC] Improve quantized convolution performance for armv8 architectures (#5754)
b94e8b7 is described below
commit b94e8b7290c5ced98728e730634ec73727c53c51
Author: Giuseppe Rossini <gi...@arm.com>
AuthorDate: Tue Jun 23 06:20:05 2020 +0100
[RFC] Improve quantized convolution performance for armv8 architectures (#5754)
* Improve quantized conv2d performance for armv8
Signed-off-by: Giuseppe Rossini <gi...@arm.com>
Change-Id: I3a3d29f5332dd9b3354e8e0dfb24677a521f9c8f
* Add ASF header to conv2d_gemm.py
Change-Id: I33853279e39c849ae1b555a9c91d7557985a0a35
* Run clang-format-10 on c++ files
Change-Id: Ieee22f032e595dabfc1616ab33466fcbf8d94365
* Fix pylint errors/warnings
Change-Id: I435d4d7bca7500db99547f4401fdc0d0995a1ff4
* Fix pylint errors/warnings in topi
Change-Id: I2fc1ad8453e9020072ab967c849df5390c2967b5
* Fix legalizations tests for aarch64
Change-Id: I0a67a49a7849f52ef7d57b9292ce9125bbb7cb2c
* Reintroduce conv2d_nhwc_spatial_pack.arm_cpu and int16 cast
Change-Id: I91b67fabd475e90a9b75f2dd5ecfee851265e0bb
* Switch type of legalization depending on the strategy used
Change-Id: I9a03040a8c40a6cd2658ed14c3751e05a8e19f2b
* Revert last commit
Change-Id: Ice34101e358e3ce8ebfb12c58f73e910ba5de8e8
* Fix the auto-tuner by registering the correct schedules
Change-Id: Id9273688b2620e1ea849ab01b4c46af8fbf37fd0
* Address review comments
Change-Id: Ia1755a0af7b6d159072d9f0c93c932c481101e48
* Improve usability and readability of conv2d_gemm_weight_transform
Change-Id: I3333186bbc2fe4054b58ce15d910e3be7b315482
* Change variable name to weight in Conv2DGemmWeightTransformRel
Change-Id: Ifb5f1f33af7512fe67c6b049b20a42a0bb2d26c9
* Fix clang-10 linting errors
Change-Id: I25ccc844d9cee23766096e1daddb6180abc413a6
* Trigger tests
Change-Id: Id37706fb7cf77a87a3cc817ecf8046297d9ca95a
---
include/tvm/relay/attrs/nn.h | 11 +
python/tvm/relay/op/nn/_nn.py | 17 ++
python/tvm/relay/op/nn/nn.py | 91 ++++++++
python/tvm/relay/op/strategy/arm_cpu.py | 42 ++++
python/tvm/relay/op/strategy/generic.py | 13 ++
python/tvm/relay/qnn/op/legalizations.py | 8 +-
src/relay/op/nn/convolution.cc | 82 +++++++
src/relay/op/nn/convolution.h | 131 +++++++++++
topi/python/topi/arm_cpu/conv2d_alter_op.py | 65 +++++-
topi/python/topi/arm_cpu/conv2d_gemm.py | 174 ++++++++++++++
topi/python/topi/arm_cpu/conv2d_int8.py | 38 +++-
topi/python/topi/arm_cpu/tensor_intrin.py | 339 ++++++++++++++++++++++++++++
topi/python/topi/generic/nn.py | 19 ++
topi/python/topi/nn/conv2d.py | 49 ++++
14 files changed, 1065 insertions(+), 14 deletions(-)
diff --git a/include/tvm/relay/attrs/nn.h b/include/tvm/relay/attrs/nn.h
index abe63e5..5f1ee2f 100644
--- a/include/tvm/relay/attrs/nn.h
+++ b/include/tvm/relay/attrs/nn.h
@@ -187,6 +187,17 @@ struct ConvWinogradWeightTransformAttrs : public tvm::AttrsNode<ConvWinogradWeig
}
};
+/*! \brief Attributes used in gemm weight transformation operators */
+struct ConvGemmWeightTransformAttrs : public tvm::AttrsNode<ConvGemmWeightTransformAttrs> {
+ int tile_rows;
+ int tile_cols;
+
+ TVM_DECLARE_ATTRS(ConvGemmWeightTransformAttrs, "relay.attrs.ConvGemmWeightTransformAttrs") {
+ TVM_ATTR_FIELD(tile_rows).describe("Tile rows of the weight transformation for ConvGemm.");
+ TVM_ATTR_FIELD(tile_cols).describe("Tile columns of the weight transformation for ConvGemm.");
+ }
+};
+
/*! \brief Attributes used in convolution operators with winograd algorithm */
struct Conv2DWinogradAttrs : public tvm::AttrsNode<Conv2DWinogradAttrs> {
int tile_size;
diff --git a/python/tvm/relay/op/nn/_nn.py b/python/tvm/relay/op/nn/_nn.py
index 1c76f57..564d6f7 100644
--- a/python/tvm/relay/op/nn/_nn.py
+++ b/python/tvm/relay/op/nn/_nn.py
@@ -446,6 +446,23 @@ reg.register_strategy("nn.contrib_conv2d_winograd_without_weight_transform",
reg.register_pattern("nn.contrib_conv2d_winograd_without_weight_transform",
OpPattern.OUT_ELEMWISE_FUSABLE)
+# conv2d_gemm related operators
+reg.register_strategy("nn.contrib_conv2d_gemm_without_weight_transform",
+ strategy.conv2d_gemm_without_weight_transform_strategy)
+reg.register_pattern("nn.contrib_conv2d_gemm_without_weight_transform",
+ OpPattern.OUT_ELEMWISE_FUSABLE)
+
+@reg.register_compute("nn.contrib_conv2d_gemm_weight_transform")
+def compute_contrib_conv2d_gemm_weight_transform(attrs, inputs, out_dtype):
+ """Compute definition of contrib_conv2d_gemm_weight_transform"""
+ out = topi.nn.conv2d_gemm_weight_transform(
+ inputs[0], attrs.tile_rows, attrs.tile_cols)
+ return [out]
+
+reg.register_schedule("nn.contrib_conv2d_gemm_weight_transform",
+ strategy.schedule_conv2d_gemm_weight_transform)
+reg.register_pattern("nn.contrib_conv2d_gemm_weight_transform",
+ OpPattern.OUT_ELEMWISE_FUSABLE)
@reg.register_compute("nn.contrib_conv2d_winograd_weight_transform")
def compute_contrib_conv2d_winograd_weight_transform(attrs, inputs, out_dtype):
diff --git a/python/tvm/relay/op/nn/nn.py b/python/tvm/relay/op/nn/nn.py
index 34d07dc..3c47cf7 100644
--- a/python/tvm/relay/op/nn/nn.py
+++ b/python/tvm/relay/op/nn/nn.py
@@ -2046,6 +2046,74 @@ def contrib_conv2d_winograd_without_weight_transform(data,
kernel_layout, out_layout, out_dtype)
+def contrib_conv2d_gemm_without_weight_transform(data,
+ weight,
+ strides=(1, 1),
+ padding=(0, 0),
+ dilation=(1, 1),
+ groups=1,
+ channels=None,
+ kernel_size=None,
+ data_layout="NCHW",
+ kernel_layout="OIHW",
+ out_layout="",
+ out_dtype=""):
+ r"""2D convolution with gemm algorithm.
+
+ The basic parameters are the same as the ones in vanilla conv2d.
+ It assumes the weight is pre-transformed by nn.contrib_conv2d_gemm_weight_transform
+
+ Parameters
+ ----------
+ data : tvm.relay.Expr
+ The input data to the operator.
+
+ weight : tvm.relay.Expr
+ The weight expressions.
+
+ strides : tuple of int, optional
+ The strides of convolution.
+
+ padding : tuple of int, optional
+ The padding of convolution on both sides of inputs before convolution.
+
+ dilation : tuple of int, optional
+ Specifies the dilation rate to be used for dilated convolution.
+
+ groups : int, optional
+ Number of groups for grouped convolution.
+
+ channels : int, optional
+ Number of output channels of this convolution.
+
+ kernel_size : tuple of int, optional
+ The spatial dimensions of the convolution kernel.
+
+ data_layout : str, optional
+ Layout of the input.
+
+ kernel_layout : str, optional
+ Layout of the weight.
+
+ out_layout : str, optional
+ Layout of the output, by default, out_layout is the same as data_layout
+
+ out_dtype : str, optional
+ Specifies the output data type for mixed precision conv2d.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ # convert 2-way padding to 4-way padding
+ padding = get_pad_tuple2d(padding)
+ return _make.contrib_conv2d_gemm_without_weight_transform(
+ data, weight, strides, padding, dilation,
+ groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, out_dtype)
+
+
def contrib_conv2d_nchwc(data,
kernel,
strides=(1, 1),
@@ -2204,6 +2272,29 @@ def contrib_conv2d_winograd_weight_transform(weight,
return _make.contrib_conv2d_winograd_weight_transform(weight, tile_size)
+def contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols):
+ r"""Weight Transformation part for 2D convolution with gemm algorithm.
+
+ We separate this as a single op to enable pre-compute for inference.
+ Use this together with nn.contrib_conv2d_gemm_without_weight_transform
+
+ Parameters
+ ----------
+ weights : tvm.relay.Expr
+ The weight expressions.
+ tile_rows: int
+ Tile rows of the weight transformation for ConvGemm.
+ tile_cols: int
+ Tile columns of the weight transformation for ConvGemm.
+
+ Returns
+ -------
+ result : tvm.relay.Expr
+ The computed result.
+ """
+ return _make.contrib_conv2d_gemm_weight_transform(weights, tile_rows, tile_cols)
+
+
def contrib_conv3d_winograd_weight_transform(weight,
tile_size):
r"""Weight Transformation part for 3D convolution with winograd algorithm.
diff --git a/python/tvm/relay/op/strategy/arm_cpu.py b/python/tvm/relay/op/strategy/arm_cpu.py
index 6bdec67..d682aad 100644
--- a/python/tvm/relay/op/strategy/arm_cpu.py
+++ b/python/tvm/relay/op/strategy/arm_cpu.py
@@ -112,6 +112,14 @@ def conv2d_strategy_arm_cpu(attrs, inputs, out_type, target):
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_direct_simd),
name='conv2d_direct_simd.micro_dev')
elif kernel_layout == "HWIO":
+ is_aarch64 = "aarch64" in str(isa.target)
+
+ if is_aarch64 and data.dtype in ["int8", "uint8"]:
+ strategy.add_implementation(
+ wrap_compute_conv2d(topi.arm_cpu.compute_conv2d_NHWC_quantized),
+ wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
+ name="conv2d_NHWC_quantized.arm_cpu")
+
strategy.add_implementation(
wrap_compute_conv2d(topi.arm_cpu.conv2d_nhwc_spatial_pack),
wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_nhwc_spatial_pack),
@@ -246,6 +254,40 @@ def conv2d_winograd_without_weight_transfrom_strategy_arm_cpu(attrs, inputs, out
format(layout))
return strategy
+def wrap_compute_conv2d_gemm(topi_compute):
+ """wrap topi compute for conv2d_gemm"""
+
+ def _compute_conv2d_gemm(attrs, inputs, out_type):
+ padding = attrs.get_int_tuple("padding")
+ strides = attrs.get_int_tuple("strides")
+ dilation = attrs.get_int_tuple("dilation")
+ out_dtype = attrs.get_str("out_dtype")
+ channels = attrs['channels']
+ kernel_size = attrs['kernel_size']
+ out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
+ return [topi_compute(inputs[0], inputs[1], strides, padding,
+ dilation, out_dtype, kernel_size, channels)]
+
+ return _compute_conv2d_gemm
+
+@conv2d_gemm_without_weight_transform_strategy.register("arm_cpu")
+def conv2d_gemm_without_weight_transform_strategy_arm_cpu(attrs, inputs, out_type, target):
+ """conv2d_gemm_without_weight_transform arm cpu strategy"""
+ layout = attrs.data_layout
+ data = inputs[0]
+ strategy = _op.OpStrategy()
+
+ if layout == "NHWC" and data.dtype in ['int8', 'uint8']:  # only quantized NHWC is handled
+ strategy.add_implementation(
+ wrap_compute_conv2d_gemm(topi.arm_cpu.compute_conv2d_NHWC_quantized_without_transform),
+ wrap_topi_schedule(topi.arm_cpu.schedule_conv2d_NHWC_quantized),
+ name="conv2d_NHWC_quantized_without_transform.arm_cpu")
+ else:
+ raise RuntimeError(
+ "Unsupported conv2d_gemm_without_weight_transform layout {0} with datatype {1}".
+ format(layout, data.dtype))
+ return strategy
+
@conv2d_transpose_strategy.register(["arm_cpu", "micro_dev"])
def conv2d_transpose_strategy_arm_cpu(attrs, inputs, out_type, target):
"""conv2d_transpose arm cpu strategy"""
diff --git a/python/tvm/relay/op/strategy/generic.py b/python/tvm/relay/op/strategy/generic.py
index b1fb421..a0dd6bf 100644
--- a/python/tvm/relay/op/strategy/generic.py
+++ b/python/tvm/relay/op/strategy/generic.py
@@ -266,6 +266,12 @@ def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, t
"""conv2d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")
+# conv2d_gemm_without_weight_transform
+@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
+def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
+ """conv2d_gemm_without_weight_transform generic strategy"""
+ raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")
+
# conv2d_winograd_weight_transform
@generic_func
def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
@@ -280,6 +286,13 @@ def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
+# conv2d_gemm_weight_transform
+@generic_func
+def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
+ """Schedule conv2d_gemm_weight_transform"""
+ with target:
+ return topi.generic.schedule_conv2d_gemm_weight_transform(outs)
+
# deformable_conv2d
def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""
diff --git a/python/tvm/relay/qnn/op/legalizations.py b/python/tvm/relay/qnn/op/legalizations.py
index d3b0e44..7246214 100644
--- a/python/tvm/relay/qnn/op/legalizations.py
+++ b/python/tvm/relay/qnn/op/legalizations.py
@@ -237,6 +237,11 @@ def is_fast_int8_on_arm():
target = tvm.target.Target.current(allow_none=False)
return '+v8.2a,+dotprod' in ' '.join(target.options)
+def is_aarch64_arm():
+ """ Checks whether we are compiling for an AArch64 target. """
+ target = tvm.target.Target.current(allow_none=False)
+ return 'aarch64' in ' '.join(target.options)
+
########################
# ARM CPU legalizations.
########################
@@ -244,10 +249,11 @@ def is_fast_int8_on_arm():
@qnn_conv2d_legalize.register('arm_cpu')
def _qnn_conv2d_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be same.
- if is_fast_int8_on_arm():
+ if (is_aarch64_arm() and attrs["data_layout"] == "NHWC") or is_fast_int8_on_arm():
return helper_change_dtypes_to_be_same(attrs, inputs, types, relay.qnn.op.conv2d)
return helper_no_fast_int8_hw_legalization(attrs, inputs, types, relay.nn.conv2d)
+
@qnn_dense_legalize.register('arm_cpu')
def _qnn_dense_legalize_arm_cpu(attrs, inputs, types):
# ARM prefers the dtypes to be same.
diff --git a/src/relay/op/nn/convolution.cc b/src/relay/op/nn/convolution.cc
index 6c6eb1e..f63c489 100644
--- a/src/relay/op/nn/convolution.cc
+++ b/src/relay/op/nn/convolution.cc
@@ -77,6 +77,26 @@ Expr MakeConvWinograd(Expr data, Expr weight, int tile_size, Array<IndexExpr> st
return Call(op, {data, weight}, Attrs(attrs), {});
}
+template <typename T>
+Expr MakeConvGemm(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout, std::string kernel_layout,
+ std::string out_layout, DataType out_dtype, std::string op_name) {
+ auto attrs = make_object<T>();
+ attrs->strides = std::move(strides);
+ attrs->padding = std::move(padding);
+ attrs->dilation = std::move(dilation);
+ attrs->groups = groups;
+ attrs->channels = std::move(channels);
+ attrs->kernel_size = std::move(kernel_size);
+ attrs->data_layout = std::move(data_layout);
+ attrs->kernel_layout = std::move(kernel_layout);
+ attrs->out_layout = std::move(out_layout);
+ attrs->out_dtype = std::move(out_dtype);
+ const Op& op = Op::Get(op_name);
+ return Call(op, {data, weight}, Attrs(attrs), {});
+}
+
Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_name) {
auto attrs = make_object<ConvWinogradWeightTransformAttrs>();
attrs->tile_size = tile_size;
@@ -84,6 +104,14 @@ Expr MakeConvWinogradWeightTransform(Expr weight, int tile_size, std::string op_
return Call(op, {weight}, Attrs(attrs), {});
}
+Expr MakeConvGemmWeightTransform(Expr weight, int tile_rows, int tile_cols, std::string op_name) {
+ auto attrs = make_object<ConvGemmWeightTransformAttrs>();
+ attrs->tile_rows = tile_rows;
+ attrs->tile_cols = tile_cols;
+ const Op& op = Op::Get(op_name);
+ return Call(op, {weight}, Attrs(attrs), {});
+}
+
template <typename T>
Expr MakeConvTranspose(Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
Array<IndexExpr> dilation, int groups, IndexExpr channels,
@@ -504,6 +532,60 @@ weight transformation in advance.
.set_support_level(10)
.add_type_rel("Conv2DWinogradNNPACKWeightTransform", Conv2DWinogradNNPACKWeightTransformRel);
+// relay.nn.contrib_conv2d_gemm_without_weight_transform
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_without_weight_transform")
+ .set_body_typed([](Expr data, Expr weight, Array<IndexExpr> strides, Array<IndexExpr> padding,
+ Array<IndexExpr> dilation, int groups, IndexExpr channels,
+ Array<IndexExpr> kernel_size, std::string data_layout,
+ std::string kernel_layout, std::string out_layout, DataType out_dtype) {
+ return MakeConvGemm<Conv2DAttrs>(
+ data, weight, strides, padding, dilation, groups, channels, kernel_size, data_layout,
+ kernel_layout, out_layout, out_dtype, "nn.contrib_conv2d_gemm_without_weight_transform");
+ });
+
+RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_without_weight_transform")
+ .describe(R"code(Compute conv2d with gemm algorithm. Only supports NHWC layout.
+ This operator assumes the weight tensor is already pre-transformed by
+ nn.contrib_conv2d_gemm_weight_transform.
+
+- **data**: Input is 4D array of shape (batch_size, height, width, in_channels)
+- **weight**: Any shape
+ We do not check the shape for this input tensor. Since different backend
+ has different layout strategy.
+
+- **out**: Output is 4D array of shape (batch_size, out_height, out_width, channels)
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<Conv2DAttrs>()
+ .set_num_inputs(2)
+ .add_argument("data", "Tensor", "The input tensor.")
+ .add_argument("weight", "Tensor", "The weight tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DGemm", Conv2DGemmRel<Conv2DAttrs>)
+ .set_attr<FInferCorrectLayout>("FInferCorrectLayout", ConvInferCorrectLayout<Conv2DAttrs>);
+
+// relay.nn.contrib_conv2d_gemm_weight_transform
+
+TVM_REGISTER_NODE_TYPE(ConvGemmWeightTransformAttrs);
+
+TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_gemm_weight_transform")
+ .set_body_typed([](Expr weights, int tile_rows, int tile_cols) {
+ return MakeConvGemmWeightTransform(weights, tile_rows, tile_cols,
+ "nn.contrib_conv2d_gemm_weight_transform");
+ });
+
+RELAY_REGISTER_OP("nn.contrib_conv2d_gemm_weight_transform")
+ .describe(R"code(Weight transformation of GEMM convolution algorithm.
+
+Separate this into another operator in order to enable Precompute Pass to compute the
+weight transformation in advance.
+
+)code" TVM_ADD_FILELINE)
+ .set_attrs_type<ConvGemmWeightTransformAttrs>()
+ .set_num_inputs(1)
+ .add_argument("weights", "Tensor", "The weights tensor.")
+ .set_support_level(10)
+ .add_type_rel("Conv2DGemmWeightTransform", Conv2DGemmWeightTransformRel);
+
// Positional relay function to create conv2d NCHWc operator
// used by frontend FFI.
TVM_REGISTER_GLOBAL("relay.op.nn._make.contrib_conv2d_NCHWc")
diff --git a/src/relay/op/nn/convolution.h b/src/relay/op/nn/convolution.h
index 0c5b20a..f53f4e0 100644
--- a/src/relay/op/nn/convolution.h
+++ b/src/relay/op/nn/convolution.h
@@ -383,6 +383,65 @@ inline bool Conv2DWinogradWeightTransformRel(const Array<Type>& types, int num_i
return true;
}
+// Gemm convolution shape relations
+// In order to run GEMM we need to block-transpose and interleave the K x N weights matrix W.
+// The high level idea is to subdivide W in tiles of tile_cols x tile_rows, and transpose and
+// interleave them. The final output is a [N//tile_rows, K//tile_cols, tile_rows, tile_cols]
+// matrix that we call W_interleaved_t.
+//
+// In the following picture, we show how the first [tile_cols,tile_rows] block of W is transformed
+// for tile_rows = 4 and tile_cols = 16
+//
+// W[0,0,:,:] W_interleaved_t[0,0,:,:]
+// +-------------------------------+ +----------------------------------- +
+// |W[0,0] W[0,1] W[0,2] W[0,3] | |W[0,0] W[1,0] W[2,0] ... W[15,0]|
+// |W[1,0] W[1,1] W[1,2] W[1,3] | --\ |W[0,1] W[1,1] W[2,1] ... W[15,1]|
+// |W[2,0] W[2,1] W[2,2] W[2,3] | --/ |W[0,2] W[1,2] W[2,2] ... W[15,2]|
+// | ... ... ... ... | |W[0,3] W[1,3] W[2,3] ... W[15,3]|
+// | ... ... ... ... | +------------------------------------+
+// |W[15,0] W[15,1] W[15,2] W[15,3]|
+// +-------------------------------+
+//
+// Tile columns is usually the direction of the reduction. So, if our target can reduce k elements
+// at a time, we should set tile_cols = k.
+// Tile rows is connected with the number of registers available for the given target.
+//
+inline bool Conv2DGemmWeightTransformRel(const Array<Type>& types, int num_inputs,
+ const Attrs& attrs, const TypeReporter& reporter) {
+ CHECK_EQ(types.size(), 2);
+ const auto* weight = types[0].as<TensorTypeNode>();
+ if (weight == nullptr) return false;
+
+ const ConvGemmWeightTransformAttrs* param = attrs.as<ConvGemmWeightTransformAttrs>();
+ CHECK(param != nullptr);
+ int n = param->tile_rows;
+ int k = param->tile_cols;
+
+ CHECK_EQ(weight->shape.size(), 4) << "Only support HWIO kernel layout";
+
+ const auto K = weight->shape[0] * weight->shape[1] * weight->shape[2];
+ const auto N = weight->shape[3];
+
+ auto K_mod_k = indexmod(K, k);
+ auto N_mod_n = indexmod(N, n);
+
+ auto pad_K = tvm::if_then_else(K_mod_k != 0, k - K_mod_k, tir::make_zero(DataType::Int(32)));
+ auto pad_N = tvm::if_then_else(N_mod_n != 0, n - N_mod_n, tir::make_zero(DataType::Int(32)));
+
+ const auto N_padded = N + pad_N;
+ const auto K_padded = K + pad_K;
+
+ Array<IndexExpr> oshape{
+ indexdiv(N_padded, n),
+ indexdiv(K_padded, k),
+ n,
+ k,
+ };
+
+ reporter->Assign(types[1], TensorType(oshape, weight->dtype));
+ return true;
+}
+
inline bool Conv3DWinogradWeightTransformRel(const Array<Type>& types, int num_inputs,
const Attrs& attrs, const TypeReporter& reporter) {
CHECK_EQ(types.size(), 2);
@@ -520,6 +579,78 @@ bool Conv2DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& at
}
template <typename AttrType>
+bool Conv2DGemmRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
+ const TypeReporter& reporter) {  // type relation: infers the output type from data + attrs
+ CHECK_EQ(types.size(), 3);  // types[0] is the data, types[2] is the output being inferred
+ const auto* data = types[0].as<TensorTypeNode>();
+ if (data == nullptr) return false;
+ static const Layout kNHWC("NHWC");
+ static const Layout kHWIO("HWIO");
+
+ const AttrType* param = attrs.as<AttrType>();
+ CHECK(param != nullptr);
+ const Layout in_layout(param->data_layout);
+ const Layout kernel_layout(param->kernel_layout);
+
+ const auto trans_in_layout = tir::BijectiveLayout(in_layout, kNHWC);
+ CHECK(trans_in_layout.defined())
+ << "Conv only support input layouts that are convertible from NHWC."
+ << " But got " << in_layout;
+
+ const auto trans_kernel_layout = tir::BijectiveLayout(kernel_layout, kHWIO);
+ CHECK(trans_kernel_layout.defined())
+ << "Conv only support kernel layouts that are convertible from HWIO."
+ << " But got " << kernel_layout;
+
+ Layout out_layout(param->out_layout == "" ? param->data_layout : param->out_layout);
+ const auto trans_out_layout = tir::BijectiveLayout(out_layout, kNHWC);
+ CHECK(trans_out_layout.defined())
+ << "Conv only support output layouts that are convertible from NHWC."
+ << " But got " << out_layout;
+
+ Array<IndexExpr> dshape_nhwc = trans_in_layout.ForwardShape(data->shape);
+
+ IndexExpr channels, dilated_ksize_y, dilated_ksize_x;
+
+ CHECK(param->kernel_size.defined() && param->channels.defined())
+ << "The kernel size and channels of a Conv must be set or inferred by previous pass";
+
+ CHECK_EQ(param->kernel_size.size(), 2);
+ CHECK_EQ(param->dilation.size(), 2);
+
+ channels = param->channels;
+ dilated_ksize_y = 1 + (param->kernel_size[0] - 1) * param->dilation[0];
+ dilated_ksize_x = 1 + (param->kernel_size[1] - 1) * param->dilation[1];
+
+ // NOTE: Do not check weight shape here!
+
+ // output shape in NHWC order; H and W (indices 1 and 2) are filled in below
+ Array<IndexExpr> oshape({dshape_nhwc[0], 0, 0, channels});
+
+ IndexExpr pad_h, pad_w;
+ GetPaddingHeightWidth(param->padding, &pad_h, &pad_w);
+ if (!dshape_nhwc[1].as<tir::AnyNode>()) {  // height is index 1 in NHWC
+ oshape.Set(1, (dshape_nhwc[1] + pad_h - dilated_ksize_y) / param->strides[0] + 1);
+ } else {
+ oshape.Set(1, dshape_nhwc[1]);
+ }
+ if (!dshape_nhwc[2].as<tir::AnyNode>()) {  // width is index 2 in NHWC
+ oshape.Set(2, (dshape_nhwc[2] + pad_w - dilated_ksize_x) / param->strides[1] + 1);
+ } else {
+ oshape.Set(2, dshape_nhwc[2]);
+ }
+
+ DataType out_dtype = param->out_dtype;
+ if (out_dtype.bits() == 0) {  // no explicit out_dtype: fall back to the input dtype
+ out_dtype = data->dtype;
+ }
+ oshape = trans_out_layout.BackwardShape(oshape);
+ // assign output type
+ reporter->Assign(types[2], TensorType(oshape, out_dtype));
+ return true;
+}
+
+template <typename AttrType>
bool Conv3DWinogradRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
const TypeReporter& reporter) {
CHECK_EQ(types.size(), 3);
diff --git a/topi/python/topi/arm_cpu/conv2d_alter_op.py b/topi/python/topi/arm_cpu/conv2d_alter_op.py
index 3206168..99fdf21 100644
--- a/topi/python/topi/arm_cpu/conv2d_alter_op.py
+++ b/topi/python/topi/arm_cpu/conv2d_alter_op.py
@@ -59,10 +59,6 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
data, kernel = tinfos
out_dtype = out_type.dtype
- # We only perform layout alteration for NCHW data layout.
- if data_layout == "NHWC":
- return None
-
# Extract data types
data_tensor, kernel_tensor = tinfos
data_dtype = data_tensor.dtype
@@ -70,6 +66,10 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
idxd = tvm.tir.indexdiv
+ # We don't perform layout alteration for NHWC layout with real data types
+ if data_layout == "NHWC" and data_dtype not in ['uint8', 'int8']:
+ return None
+
if topi_tmpl == "conv2d_nchw_spatial_pack.arm_cpu":
assert data_layout == "NCHW" and kernel_layout == "OIHW"
N, CI, H, W = get_const_tuple(data.shape)
@@ -88,21 +88,27 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
return relay.nn.conv2d(*inputs, **new_attrs)
if topi_tmpl == "conv2d_nhwc_spatial_pack.arm_cpu":
+ assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
+ data.dtype == 'uint8' and kernel.dtype == 'uint8')
+
assert data_layout == "NHWC" and kernel_layout == "HWIO"
- N, H, W, CI = get_const_tuple(data.shape)
- KH, KW, _, CO = get_const_tuple(kernel.shape)
- VC = cfg['tile_co'].size[-1]
- new_attrs['kernel_layout'] = 'OHWI%do' % VC
+ data_expr, kernel_expr = inputs
+
+ data_int16 = relay.cast(data_expr, dtype='int16')
+ kernel_int16 = relay.cast(kernel_expr, dtype='int16')
+
+ new_attrs = {k : attrs[k] for k in attrs.keys()}
+
+ new_data = te.placeholder(data.shape, 'int16')
+ new_kernel = te.placeholder(kernel.shape, 'int16')
- new_data = data
- new_kernel = te.placeholder((idxd(CO, VC), KH, KW, CI, VC), dtype=kernel.dtype)
new_workload = autotvm.task.args_to_workload(
[new_data, new_kernel, strides, padding, dilation, out_dtype],
- "conv2d_nhwc_spatial_pack.arm_cpu")
+ 'conv2d_nhwc_spatial_pack.arm_cpu')
dispatch_ctx.update(target, new_workload, cfg)
- return relay.nn.conv2d(*inputs, **new_attrs)
+ return relay.nn.conv2d(data_int16, kernel_int16, **new_attrs)
if topi_tmpl == "conv2d_nchw_winograd.arm_cpu":
assert data_layout == "NCHW" and kernel_layout == "OIHW"
@@ -235,5 +241,40 @@ def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
new_attrs['out_layout'], out_dtype], topi_tmpl)
dispatch_ctx.update(target, new_workload, cfg)
return relay.nn.contrib_depthwise_conv2d_nchwc(*inputs, **new_attrs)
+ if topi_tmpl == "conv2d_NHWC_quantized.arm_cpu":
+ assert (data.dtype == 'int8' and kernel.dtype == 'int8' or
+ data.dtype == 'uint8' and kernel.dtype == 'uint8')
+ assert data_layout == "NHWC" and kernel_layout == "HWIO"
+ KH, KW, IC, CO = get_const_tuple(kernel.shape)  # HWIO: (kernel_h, kernel_w, in_ch, out_ch)
+ K = KH * KW * IC  # reduction length of the GEMM
+ N = CO  # number of GEMM output columns
+
+ tile_rows = 4
+ tile_cols = 16
+ pad_K = 0
+ pad_N = 0
+
+ if N % tile_rows != 0:
+ pad_N = tile_rows - (N % tile_rows)
+ if K % tile_cols != 0:
+ pad_K = tile_cols - (K % tile_cols)
+
+ N_padded = N + pad_N
+ K_padded = K + pad_K
+ kernel_expr = relay.nn.contrib_conv2d_gemm_weight_transform(inputs[1], tile_rows, tile_cols)
+ new_kernel = te.placeholder((N_padded // tile_rows,
+ K_padded // tile_cols,
+ tile_rows,
+ tile_cols), kernel.dtype)
+
+ new_workload = autotvm.task.args_to_workload([data, new_kernel,
+ strides, padding, dilation,
+ out_dtype, (KH, KW), CO],
+ "conv2d_NHWC_quantized_without_transform.arm_cpu")
+ dispatch_ctx.update(target, new_workload, cfg)
+
+ return relay.nn.contrib_conv2d_gemm_without_weight_transform(inputs[0],
+ kernel_expr,
+ **new_attrs)
return None
diff --git a/topi/python/topi/arm_cpu/conv2d_gemm.py b/topi/python/topi/arm_cpu/conv2d_gemm.py
new file mode 100644
index 0000000..2b61229
--- /dev/null
+++ b/topi/python/topi/arm_cpu/conv2d_gemm.py
@@ -0,0 +1,174 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, unused-variable, too-many-locals
+# pylint: disable=unused-argument, redefined-builtin
+"""GEMM Convolution schedule on ARM"""
+import tvm
+from tvm import te
+from topi import nn
+from ..util import get_const_tuple
+from ..nn.util import get_pad_tuple
+from .tensor_intrin import gemv_quantized, gemv_quantized_impl
+
+
+# Compute function
+def compute_conv2d_gemm_without_weight_transform(cfg,
+ data, B_interleaved_t, strides, padding, dilation,
+ out_dtype, kernel_size, output_channels):
+ """Compute NHWC conv2d as im2col + interleaved GEMM against the pre-transformed weights
+ B_interleaved_t, then unpack the GEMM result into a (batches, OH, OW, OC) tensor"""
+ batches, IH, IW, IC = get_const_tuple(data.shape)
+
+ KH, KW = kernel_size
+ OC = output_channels
+
+ K_AREA = KH * KW
+
+ if isinstance(dilation, int):
+ dilation_h = dilation_w = dilation
+ else:
+ dilation_h, dilation_w = dilation
+
+ dilated_kernel_h = (KH - 1) * dilation_h + 1
+ dilated_kernel_w = (KW - 1) * dilation_w + 1
+
+ pad_top, pad_left, pad_down, pad_right = \
+ get_pad_tuple(padding, (dilated_kernel_h, dilated_kernel_w))
+ HSTR, WSTR = strides if isinstance(strides, (tuple, list)) else (strides, strides)
+
+ OH = (IH + pad_top + pad_down - dilated_kernel_h) // HSTR + 1
+ OW = (IW + pad_left + pad_right - dilated_kernel_w) // WSTR + 1
+ if pad_top or pad_left or pad_down or pad_right:  # OH/OW above count all four pads
+ data_pad = nn.pad(data, [0, pad_top, pad_left, 0], [0, pad_down, pad_right, 0],
+ name="data_pad")
+ else:
+ data_pad = data
+
+ # --- Im2col
+ M = OH * OW
+ K = IC * K_AREA
+ N = OC
+
+ A_shape = (batches, M, K)
+ if K_AREA == 1:
+ A = te.compute(A_shape, lambda n, x, y: data_pad[n, HSTR * (x // OW), WSTR * (x % OW), y],
+ name='data_flatten')
+ else:
+ A = te.compute(A_shape, lambda n, x, y:
+ data_pad[n,
+ HSTR * (x // OW) + dilation_h * ((y // IC) // KW),
+ WSTR * (x % OW) + dilation_w * ((y // IC) % KW), y % IC],
+ name='data_im2col')
+ N_transformed = B_interleaved_t.shape[0]
+
+ # --- Pad if necessary
+ idxm = tvm.tir.indexmod
+
+ pad_m = 0
+ pad_k = 0
+
+ if M % 4 != 0:  # rows of A are interleaved in groups of 4
+ pad_m = 4 - (M % 4)
+
+ if K % 16 != 0:  # the reduction axis is split in chunks of 16
+ pad_k = 16 - (K % 16)
+
+ M_padded = M + pad_m
+ K_padded = K + pad_k
+
+ pad_before = (0, 0, 0)
+ pad_after = (0, pad_m, pad_k)
+
+ if pad_m != 0 or pad_k != 0:
+ A = nn.pad(A, pad_before=pad_before, pad_after=pad_after, name="A_padded")
+
+ # --- GEMM: A*B'
+ k = te.reduce_axis((0, K_padded), "k")
+
+ A_interleaved = te.compute((batches, M_padded // 4, K_padded // 16, 4, 16),
+ lambda b, x, y, z, w: A[b, z + 4 * x, w + 16 * y],
+ name='A_interleaved')
+
+ C_interleaved = te.compute((batches, M_padded // 4, N_transformed, 4, 4),
+ lambda b, x, y, w, z:
+ te.sum(A_interleaved[b, x, k//16, w, idxm(k, 16)].astype(out_dtype)*
+ B_interleaved_t[y, k//16, z, idxm(k, 16)].astype(out_dtype),
+ axis=k),
+ name='C_interleaved')
+
+ # --- Unpack C
+ C = te.compute((batches, M, N),
+ lambda b, x, y:
+ C_interleaved[b, x // 4, y // 4, idxm(x, 4), idxm(y, 4)],
+ name="C", tag='injective')
+
+ # --- Produce the conv output
+ out_shape = (batches, OH, OW, OC)
+ out = te.compute(out_shape, lambda b, x, y, z: C(b, y + OW * x, z),
+ name='conv2d_gemm_output')
+
+ return out
+
+# Schedules
+def schedule_conv2d_gemm(cfg, s, out):
+ """Create schedule for tensors"""
+ C = out.op.input_tensors[0]
+ C_interleaved = C.op.input_tensors[0]
+ A_interleaved = C_interleaved.op.input_tensors[0]
+
+ # Input transform
+ A_interleaved_input = A_interleaved.op.input_tensors[0]
+ if A_interleaved_input.op.name == "A_padded":
+ s[A_interleaved_input].compute_at(s[A_interleaved], A_interleaved.op.axis[3])
+ s[A_interleaved_input].vectorize(A_interleaved_input.op.axis[2])
+ s[A_interleaved_input].compute_inline()
+ data_im2col = A_interleaved_input.op.input_tensors[0]
+ else:
+ data_im2col = A_interleaved_input
+
+ b, m, n = data_im2col.op.axis
+ if data_im2col.op.name == "data_im2col":
+ n_outer, n_inner = s[data_im2col].split(n, 16)
+ s[data_im2col].unroll(n_outer)
+ s[data_im2col].vectorize(n_inner)
+ else:
+ s[data_im2col].compute_inline()
+
+ # Computation(through tensorize)
+ b, xo, yo, xi, yi = C_interleaved.op.axis
+ s[C_interleaved].reorder(xo, yo, yi, xi)
+ s[C_interleaved].parallel(xo)
+ s[A_interleaved].compute_at(s[C_interleaved], xo)
+ s[A_interleaved].vectorize(A_interleaved.op.axis[4])
+
+ in_type = A_interleaved.dtype
+ out_type = C.dtype
+ if out_type == 'int32':
+ K = A_interleaved_input.shape[2]
+ _, M, N = C.shape
+ assert in_type in ['int8', 'uint8'], "Only int8 and uint8 gemm are supported"
+
+ gem_v_dotprod = gemv_quantized(M, N, K, in_type, out_type)
+ s[C_interleaved].pragma(xo, "import_llvm", gemv_quantized_impl(M, N, in_type))
+ s[C_interleaved].tensorize(yi, gem_v_dotprod)
+
+ # Output transform
+ N, OH, OW, OC = out.shape
+ s[C].split(C.op.axis[1], OW)
+ s[C].compute_at(s[out], out.op.axis[3])
+
+ return s
diff --git a/topi/python/topi/arm_cpu/conv2d_int8.py b/topi/python/topi/arm_cpu/conv2d_int8.py
index 06412b6..5a895c0 100644
--- a/topi/python/topi/arm_cpu/conv2d_int8.py
+++ b/topi/python/topi/arm_cpu/conv2d_int8.py
@@ -19,11 +19,12 @@
from tvm import te
from tvm import autotvm
from .. import tag
-from ..util import get_const_tuple
+from ..util import traverse_inline, get_const_tuple
from ..generic import conv2d as conv2d_generic
from .. import nn
from ..nn.conv2d import _get_workload as _get_conv2d_workload
from .tensor_intrin import dot_int8_int8_int32
+from .conv2d_gemm import compute_conv2d_gemm_without_weight_transform, schedule_conv2d_gemm
def _get_default_config(cfg, data, kernel, strides, padding, out_dtype):
@@ -109,3 +110,38 @@ def schedule_conv2d_NCHWc_int8(cfg, outs):
traverse(outs[0].op)
return s
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_quantized.arm_cpu")
+def compute_conv2d_NHWC_quantized(cfg, data, kernel, strides, padding, dilation, out_dtype):
+    """Compute conv2d by transforming the kernel and delegating to the gemm implementation.
+
+    The kernel is first rearranged with ``nn.conv2d_gemm_weight_transform``
+    using 4 x 16 tiles; the result is passed, together with the original
+    kernel height/width and output channel count, to
+    ``compute_conv2d_gemm_without_weight_transform``.
+    """
+    # The data shape is unpacked as four constants; the values are not used further
+    _batch, _in_h, _in_w, _in_c = get_const_tuple(data.shape)
+    kernel_h, kernel_w, _, out_channels = get_const_tuple(kernel.shape)
+
+    rows_per_tile = 4
+    cols_per_tile = 16
+    transformed_kernel = nn.conv2d_gemm_weight_transform(kernel, rows_per_tile, cols_per_tile)
+
+    return compute_conv2d_gemm_without_weight_transform(
+        cfg, data, transformed_kernel, strides, padding, dilation, out_dtype,
+        (kernel_h, kernel_w), out_channels)
+
+
+@autotvm.register_topi_compute("conv2d_NHWC_quantized_without_transform.arm_cpu")
+def compute_conv2d_NHWC_quantized_without_transform(cfg, data, B, strides, padding,
+                                                    dilation, out_dtype, kernel_size=None,
+                                                    output_channels=None):
+    """Compute conv2d for a kernel ``B`` that is passed through untransformed.
+
+    All arguments are forwarded unchanged, in the same order, to
+    ``compute_conv2d_gemm_without_weight_transform``.
+    """
+    forwarded_args = (cfg, data, B, strides, padding, dilation, out_dtype,
+                      kernel_size, output_channels)
+    return compute_conv2d_gemm_without_weight_transform(*forwarded_args)
+
+
+@autotvm.register_topi_schedule("conv2d_NHWC_quantized.arm_cpu")
+def schedule_conv2d_NHWC_quantized(cfg, outs):
+    """Create the schedule for the output tensors in ``outs``.
+
+    The graph reachable from ``outs[0]`` is traversed with ``traverse_inline``;
+    every op named ``conv2d_gemm_output`` is handed to ``schedule_conv2d_gemm``.
+    """
+    sched = te.create_schedule([tensor.op for tensor in outs])
+
+    def _schedule_gemm_output(op):
+        """Schedule ``op`` when it is the gemm convolution output; ignore any other op."""
+        if op.name != "conv2d_gemm_output":
+            return
+        schedule_conv2d_gemm(cfg, sched, op.output(0))
+
+    traverse_inline(sched, outs[0].op, _schedule_gemm_output)
+    return sched
diff --git a/topi/python/topi/arm_cpu/tensor_intrin.py b/topi/python/topi/arm_cpu/tensor_intrin.py
index da9c71a..6ef2548 100644
--- a/topi/python/topi/arm_cpu/tensor_intrin.py
+++ b/topi/python/topi/arm_cpu/tensor_intrin.py
@@ -19,6 +19,345 @@
import tvm
from tvm import te
+from tvm.contrib import util, clang
+
+def gemv_quantized_impl(M, N, data_type='uint8'):
+ """ Assembly implementation of a blocked gemv. Given
+ a block a of shape (4, k) and a block b' of shape (4, k)
+ produces the output block c = a*b of shape (4,4)
+
+ The C source (with inline AArch64 assembly) is assembled as a string and
+ handed to clang. The generated function is named
+ gemv_<data_type>_<data_type>_int32_<min(4, M)>_<min(4, N)>.
+
+ Parameters
+ ----------
+ M: int
+ Selects how many 16-byte rows of a are loaded and how many rows of c
+ are stored (at most 4).
+ N: int
+ Selects how many 16-byte rows of b' are loaded (at most 4) and the row
+ stride used when storing c.
+ data_type: str, {'uint8', 'int8'}
+ Element type of a and b'. For 'int8' the unsigned types and
+ instructions in the template are replaced by their signed counterparts.
+
+ Returns
+ -------
+ ll_code:
+ The value returned by clang.create_llvm for the generated source.
+ """
+
+ # Both values only end up in the name of the generated function
+ stepA = min(4, M)
+ stepB = min(4, N)
+ assert data_type in ['uint8', 'int8'], 'Only uint8/int8 supported for this implementation'
+
+ # Function signature
+ cc_code = """
+ extern "C" int gemv_{0}_{0}_int32_{1}_{2}(int *c_buffer,
+ unsigned char *a_buffer,
+ unsigned char *b_buffer,
+ int K, int m, int n)
+ """.format(data_type, stepA, stepB)
+
+ # Prologue: the loop counter is K / 16 and the accumulators v16-v31 are zeroed
+ cc_code += """
+ {
+ unsigned char * a_ptr = a_buffer;
+ unsigned char * b_ptr = b_buffer;
+ int * c_ptr = c_buffer;
+
+ int k = K / 16;
+
+ __asm__ __volatile__ (
+ "movi v16.4s, #0\\n"
+ "movi v17.4s, #0\\n"
+ "movi v18.4s, #0\\n"
+ "movi v19.4s, #0\\n"
+ "movi v20.4s, #0\\n"
+ "movi v21.4s, #0\\n"
+ "movi v22.4s, #0\\n"
+ "movi v23.4s, #0\\n"
+ "movi v24.4s, #0\\n"
+ "movi v25.4s, #0\\n"
+ "movi v26.4s, #0\\n"
+ "movi v27.4s, #0\\n"
+ "movi v28.4s, #0\\n"
+ "movi v29.4s, #0\\n"
+ "movi v30.4s, #0\\n"
+ "movi v31.4s, #0\\n"
+ "1:"
+ """
+
+ # Load up to four rows of a into v0-v3; registers for rows beyond M are zeroed
+ cc_code += ' "ldr q0, [%[a_ptr]]\\n" '
+
+ if M > 1:
+ cc_code += ' "ldr q1, [%[a_ptr], #16]\\n" '
+ else:
+ cc_code += ' "movi v1.4s, #0\\n" '
+
+ if M > 2:
+ cc_code += ' "ldr q2, [%[a_ptr], #32]\\n" '
+ else:
+ cc_code += ' "movi v2.4s, #0\\n" '
+
+ if M > 3:
+ cc_code += ' "ldr q3, [%[a_ptr], #48]\\n" '
+ else:
+ cc_code += ' "movi v3.4s, #0\\n" '
+
+ # Load up to four rows of b' into v4-v7.
+ # NOTE(review): unlike v1-v3 above, v5-v7 are not zeroed when N < 4 although
+ # the multiplications below still read them -- confirm this is intended.
+ cc_code += ' "ldr q4, [%[b_ptr]]\\n" '
+
+ if N > 1:
+ cc_code += ' "ldr q5, [%[b_ptr], #16]\\n" '
+
+ if N > 2:
+ cc_code += ' "ldr q6, [%[b_ptr], #32]\\n" '
+
+ if N > 3:
+ cc_code += ' "ldr q7, [%[b_ptr], #48]\\n" '
+
+ # Loop body: widening multiplies into v8-v15, pairwise accumulation into v16-v31
+ cc_code += """
+ // First half
+ // Higher part of a0 * {b0,b1,b2,b3}
+ "umull v8.8h, v0.8b, v4.8b\\n"
+ "umull v9.8h, v0.8b, v5.8b\\n"
+ "umull v10.8h, v0.8b, v6.8b\\n"
+ "umull v11.8h, v0.8b, v7.8b\\n"
+
+ // Higher part of a1 * {b0,b1,b2,b3}
+ "umull v12.8h, v1.8b, v4.8b\\n"
+ "umull v13.8h, v1.8b, v5.8b\\n"
+ "umull v14.8h, v1.8b, v6.8b\\n"
+ "umull v15.8h, v1.8b, v7.8b\\n"
+
+ // Accumulate
+ "uadalp v16.4s, v8.8h\\n"
+ "uadalp v17.4s, v9.8h\\n"
+ "uadalp v18.4s, v10.8h\\n"
+ "uadalp v19.4s, v11.8h\\n"
+ "uadalp v20.4s, v12.8h\\n"
+ "uadalp v21.4s, v13.8h\\n"
+ "uadalp v22.4s, v14.8h\\n"
+ "uadalp v23.4s, v15.8h\\n"
+
+ // Lower part of a0 * {b0,b1,b2,b3}
+ "umull2 v8.8h, v0.16b, v4.16b\\n"
+ "umull2 v9.8h, v0.16b, v5.16b\\n"
+ "umull2 v10.8h, v0.16b, v6.16b\\n"
+ "umull2 v11.8h, v0.16b, v7.16b\\n"
+
+ // Lower part of a1 * {b0,b1,b2,b3}
+ "umull2 v12.8h, v1.16b, v4.16b\\n"
+ "umull2 v13.8h, v1.16b, v5.16b\\n"
+ "umull2 v14.8h, v1.16b, v6.16b\\n"
+ "umull2 v15.8h, v1.16b, v7.16b\\n"
+
+ // Accumulate again
+ "uadalp v16.4s, v8.8h\\n"
+ "uadalp v17.4s, v9.8h\\n"
+ "uadalp v18.4s, v10.8h\\n"
+ "uadalp v19.4s, v11.8h\\n"
+ "uadalp v20.4s, v12.8h\\n"
+ "uadalp v21.4s, v13.8h\\n"
+ "uadalp v22.4s, v14.8h\\n"
+ "uadalp v23.4s, v15.8h\\n"
+
+ // Second half
+
+ // Lower part of a2 * {b0,b1,b2,b3}
+ "umull v8.8h, v2.8b, v4.8b\\n"
+ "umull v9.8h, v2.8b, v5.8b\\n"
+ "umull v10.8h, v2.8b, v6.8b\\n"
+ "umull v11.8h, v2.8b, v7.8b\\n"
+
+ // Lower part of a3 * {b0,b1,b2,b3}
+ "umull v12.8h, v3.8b, v4.8b\\n"
+ "umull v13.8h, v3.8b, v5.8b\\n"
+ "umull v14.8h, v3.8b, v6.8b\\n"
+ "umull v15.8h, v3.8b, v7.8b\\n"
+
+ // Accumulate
+ "uadalp v24.4s, v8.8h\\n"
+ "uadalp v25.4s, v9.8h\\n"
+ "uadalp v26.4s, v10.8h\\n"
+ "uadalp v27.4s, v11.8h\\n"
+ "uadalp v28.4s, v12.8h\\n"
+ "uadalp v29.4s, v13.8h\\n"
+ "uadalp v30.4s, v14.8h\\n"
+ "uadalp v31.4s, v15.8h\\n"
+
+ // Higher part of a2 * {b0,b1,b2,b3}
+ "umull2 v8.8h, v2.16b, v4.16b\\n"
+ "umull2 v9.8h, v2.16b, v5.16b\\n"
+ "umull2 v10.8h, v2.16b, v6.16b\\n"
+ "umull2 v11.8h, v2.16b, v7.16b\\n"
+
+ // Higher part of a3 * {b0,b1,b2,b3}
+ "umull2 v12.8h, v3.16b, v4.16b\\n"
+ "umull2 v13.8h, v3.16b, v5.16b\\n"
+ "umull2 v14.8h, v3.16b, v6.16b\\n"
+ "umull2 v15.8h, v3.16b, v7.16b\\n"
+
+ // Accumulate again
+ "uadalp v24.4s, v8.8h\\n"
+ "uadalp v25.4s, v9.8h\\n"
+ "uadalp v26.4s, v10.8h\\n"
+ "uadalp v27.4s, v11.8h\\n"
+ "uadalp v28.4s, v12.8h\\n"
+ "uadalp v29.4s, v13.8h\\n"
+ "uadalp v30.4s, v14.8h\\n"
+ "uadalp v31.4s, v15.8h\\n"
+ """
+ # Bytes by which each pointer advances per iteration: 16 per row, capped at 64
+ blockA = min(64, M * 16)
+ blockB = min(64, N * 16)
+
+ cc_code += """
+ // Increment pointers and decrement k
+ "add %[a_ptr], %[a_ptr], #{0}\\n"
+ "add %[b_ptr], %[b_ptr], #{1}\\n"
+ "subs %w[k], %w[k], #1\\n"
+ """.format(blockA, blockB)
+
+ # Number of int32 elements between consecutive rows of c in the stores below
+ stepC = min(4, N)
+
+ cc_code += """
+ "cbnz %w[k], 1b\\n"
+
+ // Final additions
+
+ // v16 contains the four partial sums of a[0, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+ // v17 contains the four partial sums of a[0, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+ // v18 contains the four partial sums of a[0, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+ // v19 contains the four partial sums of a[0, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+ "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b, c+d, e+f, g+h)
+ "addp v17.4s, v18.4s, v19.4s\\n" // v17 = (i+j, k+l, m+n, o+p)
+ "addp v16.4s, v16.4s, v17.4s\\n" // v16 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+ // v20 contains the four partial sums of a[1, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+ // v21 contains the four partial sums of a[1, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+ // v22 contains the four partial sums of a[1, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+ // v23 contains the four partial sums of a[1, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+ "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b, c+d, e+f, g+h)
+ "addp v21.4s, v22.4s, v23.4s\\n" // v21 = (i+j, k+l, m+n, o+p)
+ "addp v20.4s, v20.4s, v21.4s\\n" // v20 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+ // v24 contains the four partial sums of a[2, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+ // v25 contains the four partial sums of a[2, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+ // v26 contains the four partial sums of a[2, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+ // v27 contains the four partial sums of a[2, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+ "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b, c+d, e+f, g+h)
+ "addp v25.4s, v26.4s, v27.4s\\n" // v25 = (i+j, k+l, m+n, o+p)
+ "addp v24.4s, v24.4s, v25.4s\\n" // v24 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+ // v28 contains the four partial sums of a[3, 0:K].*b[0,0:K], let's call them (a,b,c,d)
+ // v29 contains the four partial sums of a[3, 0:K].*b[1,0:K], let's call them (e,f,g,h)
+ // v30 contains the four partial sums of a[3, 0:K].*b[2,0:K], let's call them (i,j,k,l)
+ // v31 contains the four partial sums of a[3, 0:K].*b[3,0:K], let's call them (m,n,o,p)
+ "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b, c+d, e+f, g+h)
+ "addp v29.4s, v30.4s, v31.4s\\n" // v29 = (i+j, k+l, m+n, o+p)
+ "addp v28.4s, v28.4s, v29.4s\\n" // v28 = (a+b+c+d, e+f+g+h, i+j+k+l, m+n+o+p)
+
+ "str q16, [%[c_ptr]]\\n"
+ """
+
+ # Store rows 1..3 of c only when they exist; byte offset is row * stepC * 4
+ if M > 1:
+ cc_code += ' "str q20, [%[c_ptr], #{0}]\\n" '.format(stepC * 4)
+
+ if M > 2:
+ cc_code += ' "str q24, [%[c_ptr], #{0}]\\n" '.format(stepC * 8)
+
+ if M > 3:
+ cc_code += ' "str q28, [%[c_ptr], #{0}]\\n" '.format(stepC * 12)
+
+ # Operand and clobber lists of the inline assembly, then the end of the function
+ cc_code += """
+ : [c_ptr] "+r" (c_ptr), [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [k] "+r" (k)
+ :
+ : "cc", "memory", "v0", "v1", "v2", "v3", "v4", "v5", "v6",
+ "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16",
+ "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26",
+ "v27", "v28", "v29", "v30", "v31"
+ );
+ return 0;
+ }
+ """
+
+ # The template above is written for unsigned data; switch the C types and
+ # the multiply/accumulate instructions to their signed variants for int8
+ if data_type == 'int8':
+ cc_code = cc_code.replace('unsigned char', 'char')
+ cc_code = cc_code.replace('umull', 'smull')
+ cc_code = cc_code.replace('uadalp', 'sadalp')
+
+ temp = util.tempdir()
+ ll_path = temp.relpath("temp.ll")
+ # Create LLVM ir from c source code
+ ll_code = clang.create_llvm(cc_code,
+ options=["--target=aarch64-linux-gnu -mattr=+neon"],
+ output=ll_path)
+ return ll_code
+
+
+def gemv_quantized(M, N, K, in_type, out_type):
+    """
+    Use integer ARM v8 instructions in order to produce a block c of 4x4 elements
+    given two 4xK blocks a and b' (where b' is a Kx4 block transposed). The final
+    result is c = a*b (where '*' indicates the matrix product)
+
+    Every row of the matrix c is obtained (for uint8) by a sequence of
+
+          umull -> uadalp -> umull2 -> uadalp
+
+    The block size is constrained by the number of registers available in armv8. This
+    function returns a TensorIntrin that can be used to tensorize
+    a schedule.
+
+    Both operands are declared with shape (K // 16, rows, 16), i.e. the reduction
+    axis is split into chunks of 16 elements. The intrinsic body is a single
+    extern call to ``gemv_<type>_<type>_int32_<min(4, M)>_<min(4, N)>``, the
+    symbol naming used by ``gemv_quantized_impl``.
+
+    Parameters
+    ----------
+    M: int
+        rows of the matrix A
+    N: int
+        columns of the matrix B
+    K: int
+        columns of matrix A
+    in_type: str, {'uint8', 'int8'}
+    out_type: str, {'uint32', 'int32'}
+
+    Returns
+    -------
+    intrin : TensorIntrin
+        The ARM uint8/int8 TensorIntrin that can be used in tensorizing schedule
+    """
+    A = te.placeholder((K // 16, te.var("m"), 16), dtype=in_type, name='A')
+    B = te.placeholder((K // 16, te.var("n"), 16), dtype=in_type, name='B')
+
+    idxm = tvm.tir.indexmod
+
+    k = te.reduce_axis((0, K), "k")
+
+    C = te.compute((te.var("m"), te.var("n")),
+                   lambda x, y: te.sum(A[k // 16, x, idxm(k, 16)].astype(out_type) *
+                                       B[k // 16, y, idxm(k, 16)].astype(out_type),
+                                       axis=k), name="C")
+
+    a_buffer = tvm.tir.decl_buffer(A.shape, dtype=in_type, name="a_buffer",
+                                   offset_factor=1, strides=[te.var('sa_1'), te.var('sa_2'), 1])
+
+    b_buffer = tvm.tir.decl_buffer(B.shape, dtype=in_type, name="b_buffer",
+                                   offset_factor=1, strides=[te.var('sb_1'), te.var('sb_2'), 1])
+
+    c_buffer = tvm.tir.decl_buffer(C.shape, dtype=out_type, name="c_buffer",
+                                   offset_factor=1, strides=[te.var('sc'), 1])
+
+    def _intrin_func(ins, outs):  # pylint: disable=unused-argument
+        """Emit the extern call; the buffers declared above are used directly."""
+
+        def _instr():
+            ib = tvm.tir.ir_builder.create()
+            stepA = min(4, M)
+            stepB = min(4, N)
+
+            # Anything that is not 'int8' falls back to the uint8 kernel,
+            # as in the original two-branch version.
+            kernel_type = 'int8' if in_type == 'int8' else 'uint8'
+
+            # The C function takes (c, a, b, K, m, n). The int8 branch used to
+            # pass only the first four arguments, which does not match that
+            # signature; m and n are now passed for both data types.
+            ib.emit(tvm.tir.call_extern("int32",
+                                        "gemv_{0}_{0}_int32_{1}_{2}".format(kernel_type,
+                                                                            stepA, stepB),
+                                        c_buffer.access_ptr("w"),
+                                        a_buffer.access_ptr("r"),
+                                        b_buffer.access_ptr("r"),
+                                        K,
+                                        C.shape[0],  # m, very useful for debug
+                                        C.shape[1]))  # n, very useful for debug
+            return ib.get()
+
+        # body, reset, update
+        return _instr()
+
+    buffer_params = {"offset_factor": 1}
+    return te.decl_tensor_intrin(C.op, _intrin_func,
+                                 binds={A:a_buffer, B:b_buffer, C:c_buffer},
+                                 default_buffer_params=buffer_params)
+
def dot_int8_int8_int32(int32_lanes, dtype='uint'):
"""
diff --git a/topi/python/topi/generic/nn.py b/topi/python/topi/generic/nn.py
index 767087b..7645588 100644
--- a/topi/python/topi/generic/nn.py
+++ b/topi/python/topi/generic/nn.py
@@ -187,6 +187,25 @@ def schedule_conv2d_winograd_weight_transform(outs):
return s
+def schedule_conv2d_gemm_weight_transform(outs):
+    """Create the default schedule for the gemm weight transformation.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of this operator
+        in the format of an array of tensors.
+
+    Returns
+    -------
+    sch: Schedule
+        The computation schedule for the op.
+    """
+    # Typically this is computed in PreCompute pass
+    output_ops = [tensor.op for tensor in outs]
+    return te.create_schedule(output_ops)
+
+
def schedule_conv3d_winograd_weight_transform(outs):
"""Schedule for weight transformation of 3D winograd
diff --git a/topi/python/topi/nn/conv2d.py b/topi/python/topi/nn/conv2d.py
index 4c7941b..5928889 100644
--- a/topi/python/topi/nn/conv2d.py
+++ b/topi/python/topi/nn/conv2d.py
@@ -590,6 +590,55 @@ def conv2d_NCHWc_int8(data, kernel, stride, padding, dilation, layout, out_layou
name='conv2d_NCHWc_int8', tag="conv2d_NCHWc_int8")
+def conv2d_gemm_weight_transform(kernel, tile_rows, tile_cols):
+    """Weight transformation for gemm-based convolution
+
+    The kernel is flattened to a (KH*KW*IC, OC) matrix, zero-padded so that
+    both dimensions are multiples of the tile sizes, and then rearranged into
+    blocks of tile_rows x tile_cols elements.
+
+    Parameters
+    ----------
+    kernel: Tensor
+        The raw kernel tensor with shape (KH, KW, IC, OC).
+    tile_rows: int
+        Tile rows of the weight transformation for ConvGemm. OC is padded up
+        to a multiple of this value.
+    tile_cols: int
+        Tile columns of the weight transformation for ConvGemm. KH*KW*IC is
+        padded up to a multiple of this value.
+
+    Returns
+    -------
+    output : tvm.te.Tensor
+        4-D with shape
+        [OC_padded // tile_rows, (KH*KW*IC)_padded // tile_cols, tile_rows, tile_cols]
+    """
+    KH, KW, IC, OC = get_const_tuple(kernel.shape)
+    K = KH * KW * IC
+    N = OC
+
+    kernel_flat = te.compute((K, N), lambda x, y:
+                             kernel[(x // IC) // KW, (x // IC) % KW, x % IC, y],
+                             'weight_flatten')
+
+    # Amount of padding needed to reach the next multiple of the tile size
+    # (0 when the dimension is already a multiple). The original assigned the
+    # K padding to a misspelled name (pad_k), so K was never actually padded.
+    pad_N = (tile_rows - N % tile_rows) % tile_rows
+    pad_K = (tile_cols - K % tile_cols) % tile_cols
+
+    N_padded = N + pad_N
+    K_padded = K + pad_K
+
+    if pad_K != 0 or pad_N != 0:
+        kernel_flat = pad(kernel_flat, pad_before=(0, 0), pad_after=(pad_K, pad_N),
+                          name='weight_padding')
+
+    # Element (x, y, z, w) of the result is kernel_flat[y-th K tile, x-th N tile]
+    # at offset (w, z) inside the tile.
+    return te.compute((N_padded // tile_rows,
+                       K_padded // tile_cols,
+                       tile_rows,
+                       tile_cols), lambda x, y, z, w:
+                      kernel_flat[w + tile_cols * y, z + tile_rows * x],
+                      name='weight_block_reshape')
+
+
def conv2d_winograd_weight_transform(kernel, tile_size):
"""Weight transformation for winograd