Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/06/10 19:38:28 UTC

[GitHub] [tvm] altanh commented on a change in pull request #8234: [Matmul] Add matmul op

altanh commented on a change in pull request #8234:
URL: https://github.com/apache/tvm/pull/8234#discussion_r649363398



##########
File path: python/tvm/relay/frontend/tensorflow.py
##########
@@ -1230,6 +1239,9 @@ def from_tensorflow(graph, layout="NHWC", shape=None, outputs=None):
     params : dict of str to tvm.nd.NDArray
         Dict of converted parameters stored in tvm.nd.NDArray format
     """
+    global _USE_DENSE_INSTEAD_OF_MATMUL

Review comment:
       is it possible to avoid using this global variable? I'm not familiar with the importer, but it would be nice if we could use an importer config dict or something similar
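A minimal sketch of the config-dict idea. It assumes a single boolean option that chooses between lowering TF `MatMul` to `nn.dense` or `nn.matmul`. The function names, the option key and its default are hypothetical and are not part of TVM's TensorFlow frontend:

```python
# Hypothetical sketch: pass importer options in a per-call config dict
# instead of a module-level global. These names are illustrative and are
# not part of TVM's TensorFlow frontend.

DEFAULT_IMPORTER_CONFIG = {
    # Assumed default: keep lowering TF MatMul to nn.dense.
    "use_dense": True,
}


def _matmul_op_name(config):
    """Choose the Relay op a TF MatMul node is lowered to."""
    return "nn.dense" if config["use_dense"] else "nn.matmul"


def from_tensorflow_sketch(op_types, config=None):
    # Merge caller overrides onto the defaults. Each call builds its own dict,
    # so one import cannot leak settings into the next the way a global can.
    cfg = {**DEFAULT_IMPORTER_CONFIG, **(config or {})}
    # Return the chosen op for each MatMul node (other ops are ignored here).
    return [_matmul_op_name(cfg) for op in op_types if op == "MatMul"]
```

Because the defaults are copied on every call, an override passed to one import has no effect on later imports.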

##########
File path: include/tvm/relay/attrs/nn.h
##########
@@ -961,19 +961,29 @@ struct AvgPool3DAttrs : public tvm::AttrsNode<AvgPool3DAttrs> {
   }
 };
 
-/*! \brief Attributes for dense operator */
-struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
+/*! \brief Attributes for matmul operator and dense operator */
+struct MatmulAttrs : public tvm::AttrsNode<MatmulAttrs> {
   IndexExpr units;
-  tvm::String auto_scheduler_rewritten_layout;  // The layout after auto-scheduler's layout rewrite
   DataType out_dtype;
+  bool data_transposed;

Review comment:
       nit: I wonder if we should use `transpose_data` and `transpose_weight` to more closely match existing frameworks (e.g. BLAS libraries use `transa`, TensorFlow uses `transpose_a`)

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -52,6 +52,32 @@
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.matmul")
+def leaglize_matmul(attrs, inputs, types):
+    """Legalize matmul op.

Review comment:
       could you summarize what the legalization does for this op?
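The following is one possible legalization strategy. It is an assumption and not a description of what the `nn.matmul` legalization in this PR actually does. Because `dense(x, w)` computes `x @ w^T`, every matmul layout can be rewritten into dense form by inserting explicit transposes. The data is transposed when `data_transposed` is set, and the weight is transposed when `weight_transposed` is not set. Below is a pure-Python sketch on nested lists, with hypothetical helper names, that checks the rewrite against a direct reference:

```python
# Hypothetical legalization sketch on nested lists (2-D only): rewrite any
# matmul layout into dense(x, w) == x @ w^T by inserting explicit transposes.
# Helper names are illustrative, not TVM APIs.

def transpose(m):
    return [list(row) for row in zip(*m)]


def dense(x, w):
    # dense semantics: out[i][j] = sum_k x[i][k] * w[j][k]
    return [[sum(a * b for a, b in zip(xr, wr)) for wr in w] for xr in x]


def matmul_reference(a, b, data_transposed=False, weight_transposed=False):
    # Bring both operands to the plain (M, K) x (K, N) form, then multiply.
    a2 = transpose(a) if data_transposed else a
    b2 = transpose(b) if weight_transposed else b
    return [[sum(a2[i][k] * b2[k][j] for k in range(len(b2)))
             for j in range(len(b2[0]))] for i in range(len(a2))]


def legalize_matmul_to_dense(a, b, data_transposed=False, weight_transposed=False):
    # dense expects data as (M, K) and weight as (N, K).
    x = transpose(a) if data_transposed else a
    w = b if weight_transposed else transpose(b)
    return dense(x, w)
```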

##########
File path: python/tvm/relay/op/nn/nn.py
##########
@@ -1496,11 +1535,21 @@ def dense(data, weight, units=None, out_dtype=""):
         Specifies the output data type for mixed precision dense,
         of shape `(d_1, d_2, ..., d_n, units)`.
 
+    data_transposed : bool, optional
+        Whether the data tensor is in transposed format. Expected to be False.
+
+    weight_transposed : bool, optional
+        Whether the weight tensor is in transposed format. Expected to be True.
+
     Returns
     -------
     result : tvm.relay.Expr
         The computed result.
     """
+    # Add data_transposed & weight_transposed parameters for some API requires to apply

Review comment:
       could you explain why?

##########
File path: tests/python/relay/test_op_level1.py
##########
@@ -426,7 +486,7 @@ def test_dense():
     for dtype in ["float16", "float32"]:
         # Dense accuracy for float16 is poor
         if dtype == "float16":
-            return
+            continue

Review comment:
       did some PR affect this?
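Judging from the quoted hunk, this change matters because float16 is the first dtype in the loop. The old `return` ended `test_dense` on the first iteration, so the float32 case was never exercised. With `continue`, only float16 is skipped. A minimal sketch of the difference, using illustrative function names:

```python
# Minimal illustration of `return` vs `continue` inside the dtype loop.

def run_with_return(dtypes):
    tested = []
    for dtype in dtypes:
        if dtype == "float16":
            return tested  # leaves the function: later dtypes never run
        tested.append(dtype)
    return tested


def run_with_continue(dtypes):
    tested = []
    for dtype in dtypes:
        if dtype == "float16":
            continue  # skips only this iteration
        tested.append(dtype)
    return tested
```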

##########
File path: src/relay/op/nn/nn.h
##########
@@ -83,11 +90,12 @@ bool DenseRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
     } else {
       ICHECK(static_cast<int>(weight->shape.size()) == 2);
       if (!data->shape.back().as<tir::AnyNode>()) {
-        ICHECK(reporter->AssertEQ(data->shape[data->shape.size() - 1], weight->shape[1]))
-            << "DenseRel: input dimension doesn't match,"
+        ICHECK((param->weight_transposed && reporter->AssertEQ(reduce, weight->shape[1])) ||

Review comment:
       could you try using diagnostics for this?
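For context, here is what the check enforces in the 2-D case, following the TOPI hunk in this PR. The reduction size is the last data dimension, or the first one when `data_transposed` is set. It must match `weight->shape[1]` when `weight_transposed` is true, and `weight->shape[0]` otherwise. The sketch below is a hypothetical pure-Python version of that shape relation. Its error message names both dimensions and both flags, which is the kind of detail a diagnostic could report. It does not use TVM's diagnostics API, and the function name is illustrative:

```python
# Hypothetical Python mirror of the reduction-dimension check (2-D only),
# raising a descriptive error instead of a bare assertion failure.

def infer_matmul_shape(data_shape, weight_shape,
                       data_transposed=False, weight_transposed=True):
    if len(data_shape) != 2 or len(weight_shape) != 2:
        raise ValueError("sketch only handles 2-D data and weight")
    if data_transposed:
        reduce_dim, batch = data_shape
    else:
        batch, reduce_dim = data_shape
    if weight_transposed:
        units, weight_reduce = weight_shape
    else:
        weight_reduce, units = weight_shape
    if reduce_dim != weight_reduce:
        raise ValueError(
            f"matmul: reduction dimension mismatch: data has {reduce_dim}, "
            f"weight has {weight_reduce} (data_transposed={data_transposed}, "
            f"weight_transposed={weight_transposed})"
        )
    return (batch, units)
```

The defaults (`data_transposed=False`, `weight_transposed=True`) follow the values the `dense` docstring in this PR says are expected.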

##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
-    batch, in_dim = data.shape
+    if data_transposed:
+        in_dim, batch = data.shape
+    else:
+        batch, in_dim = data.shape
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
         out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
-            auto_scheduler_rewritten_layout, ["j", "k"]
+            auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed else ["k", "j"]
         )
         auto_scheduler.remove_index_check(weight)
-    else:
+    elif weight_transposed:
         out_dim, red_dim = weight.shape
+    else:
+        red_dim, out_dim = weight.shape
     assert in_dim == red_dim
 
     k = te.reduce_axis((0, in_dim), name="k")
-    matmul = te.compute(
+    if data_transposed:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TT"
+        else:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[k, j].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TN"
+        compute_tag = "matmul"
+    else:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_dense"
+            compute_tag = "dense"
+        else:
+            compute_lambda = lambda i, j: te.sum(
+                data[i, k].astype(out_dtype) * weight[k, j].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul"

Review comment:
       `T_matmul_NN` for consistency
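A small hypothetical helper that derives the name from the two flags would make the naming rule explicit. In it, `N` means not transposed, `T` means transposed, and the data flag comes first. It reproduces `T_matmul_TT` and `T_matmul_TN` from the diff and gives `T_matmul_NN` for this branch:

```python
# Hypothetical helper: build the compute name from the two layout flags so
# every variant follows the T_matmul_<data><weight> pattern.

def matmul_compute_name(data_transposed, weight_transposed):
    data_flag = "T" if data_transposed else "N"
    weight_flag = "T" if weight_transposed else "N"
    return f"T_matmul_{data_flag}{weight_flag}"
```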

##########
File path: python/tvm/topi/nn/dense.py
##########
@@ -51,37 +65,120 @@ def dense(data, weight, bias=None, out_dtype=None, auto_scheduler_rewritten_layo
         assert len(bias.shape) == 1
     if out_dtype is None:
         out_dtype = data.dtype
-    batch, in_dim = data.shape
+    if data_transposed:
+        in_dim, batch = data.shape
+    else:
+        batch, in_dim = data.shape
 
     if auto_scheduler_rewritten_layout:
         # Infer shape for the rewritten layout
         out_dim, red_dim = auto_scheduler.get_shape_from_rewritten_layout(
-            auto_scheduler_rewritten_layout, ["j", "k"]
+            auto_scheduler_rewritten_layout, ["j", "k"] if weight_transposed else ["k", "j"]
         )
         auto_scheduler.remove_index_check(weight)
-    else:
+    elif weight_transposed:
         out_dim, red_dim = weight.shape
+    else:
+        red_dim, out_dim = weight.shape
     assert in_dim == red_dim
 
     k = te.reduce_axis((0, in_dim), name="k")
-    matmul = te.compute(
+    if data_transposed:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TT"
+        else:
+            compute_lambda = lambda i, j: te.sum(
+                data[k, i].astype(out_dtype) * weight[k, j].astype(out_dtype), axis=k
+            )
+            compute_name = "T_matmul_TN"
+        compute_tag = "matmul"
+    else:
+        if weight_transposed:
+            compute_lambda = lambda i, j: te.sum(
+                data[i, k].astype(out_dtype) * weight[j, k].astype(out_dtype), axis=k
+            )
+            compute_name = "T_dense"

Review comment:
       do we need to keep this as `dense` or can we unify it to be `T_matmul_NT`?
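On the unification question: the `T_dense` lambda indexes `data[i, k]` and `weight[j, k]`. That is exactly the `data_transposed=False, weight_transposed=True` case, so it computes the same values an `NT` matmul would. What might justify keeping the name is code that matches on it or on the `dense` tag, presumably schedules that dispatch on the op tag. The thread does not settle that. Below is a pure-Python mirror of the four lambdas (2-D only, hypothetical names) that checks the numerical equivalence:

```python
# Pure-Python mirror of the four compute lambdas in the diff (2-D only).
# Indexing follows the TE code: data[k, i] when data is transposed,
# weight[j, k] when weight is transposed.

def matmul_by_layout(data, weight, data_transposed, weight_transposed):
    if data_transposed:
        in_dim, batch = len(data), len(data[0])
    else:
        batch, in_dim = len(data), len(data[0])
    out_dim = len(weight) if weight_transposed else len(weight[0])

    def d(i, k):
        return data[k][i] if data_transposed else data[i][k]

    def w(k, j):
        return weight[j][k] if weight_transposed else weight[k][j]

    return [[sum(d(i, k) * w(k, j) for k in range(in_dim))
             for j in range(out_dim)] for i in range(batch)]


def dense(data, weight):
    # T_dense: out[i][j] = sum_k data[i][k] * weight[j][k]
    return [[sum(x * y for x, y in zip(row, wrow)) for wrow in weight]
            for row in data]
```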

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -52,6 +52,32 @@
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.matmul")
+def leaglize_matmul(attrs, inputs, types):
+    """Legalize matmul op.
+
+    Parameters
+    ----------
+    attrs : tvm.ir.Attrs
+        Attributes of current convolution

Review comment:
       ```suggestion
           Attributes of current matmul
   ```

##########
File path: python/tvm/relay/op/nn/_nn.py
##########
@@ -52,6 +52,32 @@
 reg.register_pattern("nn.log_softmax", OpPattern.OPAQUE)
 
 
+@reg.register_legalize("nn.matmul")
+def leaglize_matmul(attrs, inputs, types):

Review comment:
       typo (`leaglize` should be `legalize`), and can we make it consistent with `topi.nn.matmul_legalize`? (choose either `legalize_matmul` or `matmul_legalize`; I prefer the former)

##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -698,6 +698,26 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+    """dense cuda strategy"""
+    strategy = _op.OpStrategy()
+    if target.kind.name == "cuda" and "cublas" in target.libs:
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.cuda.matmul_cublas),
+            wrap_topi_schedule(topi.cuda.schedule_matmul_cublas),
+            name="matmul_cublas.cuda",
+            plevel=25,
+        )
+    if is_auto_scheduler_enabled():
+        strategy.add_implementation(
+            wrap_compute_matmul(topi.nn.matmul),
+            naive_schedule,
+            name="matmul.cuda",
+        )

Review comment:
       is it possible to fall back to the dense schedules when the layout is `NT`, or are we going to try to unify everything into matmul in a later PR?
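Here is one way the fallback could look, written as a hypothetical candidate list rather than TVM's real `OpStrategy`. The `NT` layout is exactly dense, so that case could also offer an existing dense implementation, while the other layouts keep the new matmul computes. Judging from the quoted hunk, a CUDA target without cuBLAS and without the auto-scheduler registers no matmul implementation at all. The sketch shows that case as an empty list. `matmul_cublas.cuda`, `matmul.cuda` and plevel 25 come from the diff. `dense.cuda` and plevel 10 are placeholders:

```python
# Hypothetical sketch of implementation candidates for nn.matmul on CUDA.
# Each entry is (implementation name, plevel). "matmul_cublas.cuda",
# "matmul.cuda" and plevel 25 come from the diff; "dense.cuda" and
# plevel 10 are placeholders for a reused dense implementation.

def matmul_candidates(data_transposed, weight_transposed,
                      has_cublas=False, auto_scheduler=False):
    candidates = []
    if has_cublas:
        candidates.append(("matmul_cublas.cuda", 25))
    if not data_transposed and weight_transposed:
        # NT layout computes the same thing as dense: reuse dense schedules.
        candidates.append(("dense.cuda", 10))
    if auto_scheduler:
        candidates.append(("matmul.cuda", 10))
    # An empty list means no implementation is available for this layout.
    return candidates
```

Without tuning records, TVM's op strategy generally prefers the implementation with the highest plevel, so cuBLAS would still be chosen when it is enabled.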

##########
File path: python/tvm/relay/op/strategy/cuda.py
##########
@@ -698,6 +698,26 @@ def conv1d_transpose_strategy_cuda(attrs, inputs, out_type, target):
     return strategy
 
 
+@matmul_strategy.register(["cuda", "gpu"])
+def matmul_strategy_cuda(attrs, inputs, out_type, target):
+    """dense cuda strategy"""

Review comment:
       ```suggestion
       """matmul cuda strategy"""
   ```




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org