Posted to commits@tvm.apache.org by tq...@apache.org on 2023/04/19 21:17:31 UTC

[tvm] branch unity updated: [Unity] Update docs for operators (#14659)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch unity
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/unity by this push:
     new ec89242fbd [Unity] Update docs for operators (#14659)
ec89242fbd is described below

commit ec89242fbd074acea7f77925eec49cdc21bbcb79
Author: Yixin Dong <ub...@gmail.com>
AuthorDate: Thu Apr 20 05:17:24 2023 +0800

    [Unity] Update docs for operators (#14659)
    
    This PR:
    
    updates docs for several operators,
    adds a default value of 1 for the axis parameter of batch_norm,
    renames the module tvm.relax.transform.legalize_ops.creation to tvm.relax.transform.legalize_ops.create
    so that it aligns with tvm.relax.op.create, and fixes a previous upstream error in tests/python/relax/test_transform_legalize_ops_grad.py.
---
 python/tvm/relax/op/_op_gradient.py                |  6 ++--
 python/tvm/relax/op/grad/grad.py                   | 12 ++++----
 python/tvm/relax/op/nn/nn.py                       | 19 +++++++++++--
 .../tvm/relax/transform/legalize_ops/__init__.py   |  2 +-
 .../legalize_ops/{creation.py => create.py}        |  0
 .../relax/transform/legalize_ops/statistical.py    |  4 +++
 .../relax/test_transform_legalize_ops_grad.py      | 32 ++++++++++------------
 7 files changed, 44 insertions(+), 31 deletions(-)
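
As a side note on the rename (a hedged sketch, not part of the diff below): downstream
code that imported the legalization module directly would switch from the old name to
the new one, which now matches tvm.relax.op.create:

    # old import path (removed by this commit)
    # from tvm.relax.transform.legalize_ops import creation
    # new import path after the rename
    from tvm.relax.transform.legalize_ops import create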

diff --git a/python/tvm/relax/op/_op_gradient.py b/python/tvm/relax/op/_op_gradient.py
index 9de7370545..b0e37a9418 100644
--- a/python/tvm/relax/op/_op_gradient.py
+++ b/python/tvm/relax/op/_op_gradient.py
@@ -573,7 +573,7 @@ def mean_grad(
     Backward:
         Returns `[broadcast_to(y_output_grad, x.shape) / prod(x.shape[i] for i in axis)]`.
 
-        If `keepdims=False`, the meaned axis will be added back.
+        If `keepdims=False`, the mean axis will be added back.
     """
     axis = orig_call.attrs.axis
     keepdims = orig_call.attrs.keepdims
@@ -749,7 +749,7 @@ def cumsum_grad(
         `y = relax.cumsum(x, axis)`
 
     Backward:
-        The "reversed" cumsum along the same axis. Implement by some tricks now.
+        The "reversed" cumsum along the same axis. Implemented by some tricks now.
     """
 
     axis = orig_call.attrs["axis"]
@@ -786,7 +786,7 @@ def take_grad(
         `y = relax.take(x, indices, axis)`
 
     Backward:
-        Returns .
+        Returns [x_grad, no_grad].
 
         The second parameter, the indices, is not differentiable.
     """
diff --git a/python/tvm/relax/op/grad/grad.py b/python/tvm/relax/op/grad/grad.py
index b433dc9c60..e1f1591876 100644
--- a/python/tvm/relax/op/grad/grad.py
+++ b/python/tvm/relax/op/grad/grad.py
@@ -51,8 +51,8 @@ def nll_loss_backward(
     reduction: str = "mean",
     ignore_index: int = -100,
 ) -> Expr:
-    """Backward operator of relax.nll_loss. All parameters except output_grad is the same as
-    relax.nll_loss. Returns the gradient w.r.t. predictions.
+    """Backward operator of relax.nn.nll_loss. All parameters except output_grad is the same as
+    relax.nn.nll_loss. Returns the gradient w.r.t. predictions.
 
     Parameters
     ----------
@@ -80,8 +80,8 @@ def max_pool2d_backward(
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
-    """Backward operator of relax.max_pool2d. All parameters except output_grad is the same as
-    relax.max_pool2d. Returns the gradient w.r.t. data.
+    """Backward operator of relax.nn.max_pool2d. All parameters except output_grad is the same as
+    relax.nn.max_pool2d. Returns the gradient w.r.t. data.
 
     Parameters
     ----------
@@ -109,8 +109,8 @@ def avg_pool2d_backward(
     layout: str = "NCHW",
     out_layout: Optional[str] = None,
 ) -> Expr:
-    """Backward operator of relax.avg_pool2d. All parameters except output_grad is the same as
-    relax.avg_pool2d. Returns the gradient w.r.t. data.
+    """Backward operator of relax.nn.avg_pool2d. All parameters except output_grad is the same as
+    relax.nn.avg_pool2d. Returns the gradient w.r.t. data.
 
     Parameters
     ----------
diff --git a/python/tvm/relax/op/nn/nn.py b/python/tvm/relax/op/nn/nn.py
index 083bca653a..5483e7c5ee 100644
--- a/python/tvm/relax/op/nn/nn.py
+++ b/python/tvm/relax/op/nn/nn.py
@@ -667,6 +667,7 @@ def batch_norm(
 ) -> Expr:
     r"""
     Batch normalization layer (Ioffe and Szegedy, 2014).
+
     Normalizes the input at each batch, i.e. applies a transformation
     that maintains the mean activation close to 0 and the activation
     standard deviation close to 1.
@@ -676,6 +677,8 @@ def batch_norm(
         data\_mean[i] = mean(data[:,i,:,...]) \\
         data\_var[i] = var(data[:,i,:,...])
 
+    Both *mean* and *var* return a scalar by treating the input as a vector.
+
     Then compute the normalized output, which has the same shape as input, as following:
 
     .. math::
@@ -683,8 +686,6 @@ def batch_norm(
         out[:,i,:,...] = \frac{data[:,i,:,...] - data\_mean[i]}{\sqrt{data\_var[i]+\epsilon}}
             * gamma[i] + beta[i]
 
-    Both *mean* and *var* returns a scalar by treating the input as a vector.
-
     Assume the input has size *k* on axis 1, then both ``gamma`` and ``beta``
     have shape *(k,)*.
 
@@ -703,7 +704,19 @@ def batch_norm(
 
     .. note::
 
-        This operator can be optimized away for inference.
+        This operator has two modes:
+        - Training mode.
+            - Use the mean and var computed from the current batch to normalize.
+            - Update and then return the running mean and running var.
+        - Inference mode.
+            - Use the running_mean and running_var parameters to normalize.
+            - Do not update the running mean and running var; just return them unchanged.
+
+        In the legalization stage, this operator will be legalized to the training mode by default.
+
+        You can use tvm.relax.transform.DecomposeOpsForInference to decompose the operator, so it
+        executes the inference mode computation. Similarly, use
+        tvm.relax.transform.DecomposeOpsForTraining to execute the training mode computation.
 
     Parameters
     ----------
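
For the note on the two batch_norm modes above, a rough usage sketch (the pass names
come from the docstring; the wrapper functions and the no-argument pass construction
are assumptions for illustration):

    import tvm
    from tvm.relax.transform import (
        DecomposeOpsForInference,
        DecomposeOpsForTraining,
        LegalizeOps,
    )

    def lower_for_inference(mod: tvm.IRModule) -> tvm.IRModule:
        # Rewrite batch_norm into its inference-mode computation
        # (normalize with running_mean/running_var, leave them untouched), then legalize.
        mod = DecomposeOpsForInference()(mod)
        return LegalizeOps()(mod)

    def lower_for_training(mod: tvm.IRModule) -> tvm.IRModule:
        # Rewrite batch_norm into its training-mode computation
        # (normalize with batch statistics and update the running statistics), then legalize.
        mod = DecomposeOpsForTraining()(mod)
        return LegalizeOps()(mod)
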
diff --git a/python/tvm/relax/transform/legalize_ops/__init__.py b/python/tvm/relax/transform/legalize_ops/__init__.py
index 8b668e5040..613bd8970f 100644
--- a/python/tvm/relax/transform/legalize_ops/__init__.py
+++ b/python/tvm/relax/transform/legalize_ops/__init__.py
@@ -16,7 +16,7 @@
 # under the License.
 """Legalize high-level operator calls in Relax functions to call_tir."""
 from . import binary
-from . import creation
+from . import create
 from . import datatype
 from . import grad
 from . import image
diff --git a/python/tvm/relax/transform/legalize_ops/creation.py b/python/tvm/relax/transform/legalize_ops/create.py
similarity index 100%
rename from python/tvm/relax/transform/legalize_ops/creation.py
rename to python/tvm/relax/transform/legalize_ops/create.py
diff --git a/python/tvm/relax/transform/legalize_ops/statistical.py b/python/tvm/relax/transform/legalize_ops/statistical.py
index 71cf1ef808..e1f273bda0 100644
--- a/python/tvm/relax/transform/legalize_ops/statistical.py
+++ b/python/tvm/relax/transform/legalize_ops/statistical.py
@@ -47,6 +47,10 @@ def _te_mean(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor:
 def _te_variance(x: te.Tensor, axis: List[tir.IntImm], keepdims: bool) -> te.Tensor:
     dev = x - _te_mean(x, axis, True)
     return _te_mean(dev * dev, axis, keepdims)
+    # The commented-out version below has better memory locality and performance,
+    # but may trigger precision problems, so we use the two-pass version above for now.
+    # mean = _te_mean(x, axis, keepdims)
+    # return _te_mean(x * x, axis, keepdims) - mean * mean
 
 
 @register_legalize("relax.mean")
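
The precision concern in the comment above can be seen with a small NumPy sketch
(illustrative values, not from this commit): when the mean is large relative to the
spread, the one-pass formula E[x^2] - E[x]^2 cancels catastrophically in float32,
while the two-pass formula used by _te_variance stays close to the true value:

    import numpy as np

    # Values near 1e4 with a tiny spread; the true variance is about 5e-3.
    x = np.float32(10000.0) + np.linspace(0.0, 0.2, 5, dtype="float32")

    # Two-pass formula (the active implementation): mean of squared deviations.
    mean = x.mean(dtype="float32")
    var_two_pass = ((x - mean) ** 2).mean(dtype="float32")

    # One-pass formula (the commented-out variant): E[x^2] - E[x]^2.
    # x * x is around 1e8, where float32 spacing is about 8, far coarser than the
    # true variance, so the subtraction can lose most digits or even go negative.
    var_one_pass = (x * x).mean(dtype="float32") - mean * mean

    print(var_two_pass, var_one_pass)
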
diff --git a/tests/python/relax/test_transform_legalize_ops_grad.py b/tests/python/relax/test_transform_legalize_ops_grad.py
index e8f75d83a9..67d0b9194b 100644
--- a/tests/python/relax/test_transform_legalize_ops_grad.py
+++ b/tests/python/relax/test_transform_legalize_ops_grad.py
@@ -14,13 +14,10 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-import pytest
-
 import tvm
 from tvm.relax.transform import LegalizeOps
 from tvm.script import relax as R, tir as T, ir as I
 import tvm.testing
-from tvm.tir.op import div
 
 
 def test_nll_loss_backward():
@@ -207,7 +204,6 @@ def test_nll_loss_backward_no_batch():
     tvm.ir.assert_structural_equal(mod, Expected)
 
 
-@pytest.mark.skip("Regression to be fixed in the generated after merge.")
 def test_max_pool2d_backward():
     # fmt: off
     @tvm.script.ir_module
@@ -219,15 +215,9 @@ def test_max_pool2d_backward():
 
     @I.ir_module
     class Expected:
-        @R.function
-        def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"):
-            cls = Expected
-            gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data), out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
-            return gv
-
         @T.prim_func
-        def max_pool2d_backward(rxplaceholder: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), rxplaceholder_1: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")):
-            T.func_attr({"tir.noalias": True})
+        def max_pool2d_backward(A: T.Buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "float32"), B: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32"), T_pool_grad: T.Buffer((T.int64(3), T.int64(2), T.int64(10), T.int64(10)), "float32")):
+            T.func_attr({"tir.noalias": T.bool(True)})
             # with T.block("root"):
             pad_temp = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(15), T.int64(13)))
             maxpool_grad_argmax_v0 = T.alloc_buffer((T.int64(3), T.int64(2), T.int64(6), T.int64(5)), "int64")
@@ -235,29 +225,35 @@ def test_max_pool2d_backward():
             for ax0, ax1, ax2, ax3 in T.grid(T.int64(3), T.int64(2), T.int64(15), T.int64(13)):
                 with T.block("pad_temp"):
                     v_ax0, v_ax1, v_ax2, v_ax3 = T.axis.remap("SSSS", [ax0, ax1, ax2, ax3])
-                    T.reads(rxplaceholder_1[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)])
+                    T.reads(B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)])
                     T.writes(pad_temp[v_ax0, v_ax1, v_ax2, v_ax3])
-                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(11), rxplaceholder_1[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38))
+                    pad_temp[v_ax0, v_ax1, v_ax2, v_ax3] = T.if_then_else(T.int64(2) <= v_ax2 and v_ax2 < T.int64(12) and T.int64(1) <= v_ax3 and v_ax3 < T.int64(11), B[v_ax0, v_ax1, v_ax2 - T.int64(2), v_ax3 - T.int64(1)], T.float32(-3.4028234663852886e+38))
             for ax0, ax1, ax2, ax3, dh, dw in T.grid(T.int64(3), T.int64(2), T.int64(6), T.int64(5), T.int64(5), T.int64(5)):
                 with T.block("maxpool_grad_argmax"):
                     v_ax0, v_ax1, v_ax2, v_ax3, v_dh, v_dw = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, dh, dw])
                     T.reads(pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw])
                     T.writes(maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3])
                     with T.init():
-                        maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = -1
+                        maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = T.int64(-1)
                         maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(-3.4028234663852886e+38)
-                    v_maxpool_grad_argmax_v0: T.int64 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and  maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_ax0 * T.int64(390) + v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + v_ax3 * T.int64( [...]
+                    v_maxpool_grad_argmax_v0: T.int64 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] or maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] == pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw] and maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] < v_ax0 * T.int64(390) + v_ax1 * T.int64(195) + v_ax2 * T.int64(26) + v_dh * T.int64(13) + v_ax3 * T.int64(2 [...]
                     v_maxpool_grad_argmax_v1: T.float32 = T.Select(maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] > pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw], maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3], pad_temp[v_ax0, v_ax1, v_ax2 * T.int64(2) + v_dh, v_ax3 * T.int64(2) + v_dw])
                     maxpool_grad_argmax_v0[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v0
                     maxpool_grad_argmax_v1[v_ax0, v_ax1, v_ax2, v_ax3] = v_maxpool_grad_argmax_v1
             for ax0, ax1, ax2, ax3, wh, ww in T.grid(T.int64(3), T.int64(2), T.int64(10), T.int64(10), T.int64(3), T.int64(3)):
                 with T.block("T_pool_grad"):
                     v_ax0, v_ax1, v_ax2, v_ax3, v_wh, v_ww = T.axis.remap("SSSSRR", [ax0, ax1, ax2, ax3, wh, ww])
-                    T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh, div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww], rxplaceholder[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh, div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww])
+                    T.reads(maxpool_grad_argmax_v0[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww], A[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh, T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww])
                     T.writes(T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3])
                     with T.init():
                         T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T.float32(0)
-                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), div((v_ax2 - T.int64(3)), T.int64(2)) + T.int64(1)) <= div((v_ax2 + T.int64(2)), T.int64(2)) - v_wh and T.Select(v_ax3 < T.int64(4), T.int64(0), div((v_ax3 - T.int64(4)), T.int64(2)) + T.int64(1)) <= div((v_ax3 + T.int64(1)), T.int64(2)) - v_ww and T.Cast("int64", maxpool_grad_argmax_v0[v_ax0, v_ax1, div((v_ax2 + T.int64(2)), T.in [...]
+                    T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] = T_pool_grad[v_ax0, v_ax1, v_ax2, v_ax3] + T.if_then_else(T.Select(v_ax2 < T.int64(3), T.int64(0), T.Div(v_ax2 - T.int64(3), T.int64(2)) + T.int64(1)) <= T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh and T.Select(v_ax3 < T.int64(4), T.int64(0), T.Div(v_ax3 - T.int64(4), T.int64(2)) + T.int64(1)) <= T.Div(v_ax3 + T.int64(1), T.int64(2)) - v_ww and maxpool_grad_argmax_v0[v_ax0, v_ax1, T.Div(v_ax2 + T.int64(2), T.int64(2)) - v_wh,  [...]
+
+        @R.function
+        def main(output_grad: R.Tensor((3, 2, 6, 5), dtype="float32"), data: R.Tensor((3, 2, 10, 10), dtype="float32")) -> R.Tensor((3, 2, 10, 10), dtype="float32"):
+            cls = Expected
+            gv = R.call_tir(cls.max_pool2d_backward, (output_grad, data), out_sinfo=R.Tensor((3, 2, 10, 10), dtype="float32"))
+            return gv
     # fmt: on
 
     mod = LegalizeOps()(MaxPool2DBackward)