Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2020/09/14 04:05:59 UTC

[GitHub] [incubator-tvm] wjliu1998 commented on a change in pull request #6468: add aten::pixel_shuffle implementation (#6328)

wjliu1998 commented on a change in pull request #6468:
URL: https://github.com/apache/incubator-tvm/pull/6468#discussion_r487640206



##########
File path: python/tvm/relay/frontend/pytorch.py
##########
@@ -1230,6 +1230,42 @@ def _impl(inputs, input_types):
 
     return _impl
 
+def _pixel_shuffle(prelude):
+    def _impl(inputs, input_types):
+        data = inputs[0]
+        upscale_factor = inputs[1]
+        upscale_squared = upscale_factor * upscale_factor
+        b, c, h, w = _infer_shape(data)
+        assert c % upscale_squared == 0, \
+            "input channel should be divisible by square of upscale_factor"
+
+        import torch
+        if isinstance(data, _expr.Expr):
+            ndims = len(_infer_shape(data, prelude.mod))
+        elif isinstance(data, list):
+            ndims = data
+        elif isinstance(data, (torch.Tensor, np.ndarray)):
+            ndims = _infer_shape(data)
+        else:
+            msg = "Data type %s could not be parsed in transpose op" % (type(data))
+            raise AssertionError(msg)
+
+        if isinstance(data, tvm.runtime.NDArray):
+            ndims = len(_infer_shape(data))
+        axes = list(range(ndims))
+        num_inputs = len(inputs)
+        oc = c // upscale_squared
+        oh = h * upscale_factor
+        ow = w * upscale_factor
+
+        new_shape = [b, oc, upscale_factor, upscale_factor, h, w]
+        out_shape = [b, oc, oh, ow]
+
+        data = _op.transform.reshape(data, new_shape)
+        axes = [0, 1, 4, 2, 5, 3]

Review comment:
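       After the first reshape, the shape order is [b, oc, upscale_factor, upscale_factor, h, w]. The hardcoded axes [0, 1, 4, 2, 5, 3] transpose it to [b, oc, h, upscale_factor, w, upscale_factor], so that the following reshape can merge each (h, upscale_factor) and (w, upscale_factor) pair into [b, oc, oh, ow].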
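
For illustration, the reshape / transpose / reshape sequence described above can be sketched in plain NumPy. This is only a reference sketch, not the code in the PR: the function name `pixel_shuffle_ref` and the sample input are hypothetical, and NumPy stands in for the Relay ops so the axis order can be checked by hand.

```python
import numpy as np


def pixel_shuffle_ref(x, upscale_factor):
    """NumPy sketch of pixel shuffle (hypothetical helper, not from the PR)."""
    b, c, h, w = x.shape
    r = upscale_factor
    assert c % (r * r) == 0, "input channel should be divisible by square of upscale_factor"
    oc = c // (r * r)

    # Split the channel axis: [b, oc, r, r, h, w]
    x = x.reshape(b, oc, r, r, h, w)
    # Move each upscale axis next to its spatial axis: [b, oc, h, r, w, r]
    x = x.transpose(0, 1, 4, 2, 5, 3)
    # Merge (h, r) and (w, r): [b, oc, h * r, w * r]
    return x.reshape(b, oc, h * r, w * r)


# Small sample input: 4 channels, 2x2 spatial, upscale factor 2.
x = np.arange(1 * 4 * 2 * 2).reshape(1, 4, 2, 2)
y = pixel_shuffle_ref(x, 2)
```

With this axis order, output element `y[n, k, i*r + p, j*r + q]` is taken from input element `x[n, k*r*r + p*r + q, i, j]`, which is why the two upscale axes have to sit directly after `h` and `w` before the final reshape.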




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org