You are viewing a plain-text version of this content. The canonical link to it (a hyperlink in the original page) is not preserved in this copy.
Posted to commits@tvm.apache.org by la...@apache.org on 2020/11/04 04:58:45 UTC

[incubator-tvm] branch main updated: Register shape functions for some image related ops (#6373)

This is an automated email from the ASF dual-hosted git repository.

laurawly pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 237744d  Register shape functions for some image related ops (#6373)
237744d is described below

commit 237744d52bb2afecdbc4404f4bd09ad1dbe73ccc
Author: Leyuan Wang <la...@gmail.com>
AuthorDate: Tue Nov 3 20:57:50 2020 -0800

    Register shape functions for some image related ops (#6373)
    
    * debugging
    
    * added three shape funcs
    
    * fix lint
    
    * address comment
    
    * resolve conflicts
    
    * resolve conflicts
    
    * resolve conflicts
    
    * resolve conflicts
    
    * resolve conflicts
---
 python/tvm/relay/op/image/_image.py | 76 ++++++++++++++++++++++++++++++++
 tests/python/relay/test_any.py      | 88 +++++++++++++++++++++++++++++++++++++
 2 files changed, 164 insertions(+)

diff --git a/python/tvm/relay/op/image/_image.py b/python/tvm/relay/op/image/_image.py
index c0cdf64..ee8a5b3 100644
--- a/python/tvm/relay/op/image/_image.py
+++ b/python/tvm/relay/op/image/_image.py
@@ -42,6 +42,45 @@ def compute_resize(attrs, inputs, out_type):
 reg.register_injective_schedule("image.resize")
 
 
+@script
+def _resize_shape_func(image_shape, size, batch_axis, height_axis, width_axis, channel_axis):  # Build the 4-entry int64 output shape of image.resize from the input shape and target (H, W).
+    out = output_tensor((4,), "int64")  # output rank is hard-coded to 4
+    out[batch_axis] = int64(image_shape[0])  # NOTE(review): reads input axis 0 rather than batch_axis; only correct when N is the first layout axis — consider image_shape[batch_axis]
+    out[height_axis] = int64(size[0])
+    out[width_axis] = int64(size[1])
+    out[channel_axis] = image_shape[channel_axis]  # NOTE(review): unlike the other entries this lacks int64(); harmless only if image_shape is already int64 — confirm
+    return out
+
+
+@reg.register_shape_func("image.resize", False)  # second arg presumably means "not data-dependent" (shapes, not values, are passed in) — confirm
+def resize_shape_func(attrs, inputs, _):
+    """
+    Shape function for image.resize: batch/channel dims come from the input shape, H/W from attrs.size.
+    """
+    layout = attrs.layout
+    height_axis = width_axis = channel_axis = 1  # NOTE(review): batch_axis gets no default, so a layout without "N" raises UnboundLocalError at convert(batch_axis) below
+    for i, letter in enumerate(layout):  # find the position of each upper-case dimension letter in the layout string
+        if letter == "N":
+            batch_axis = i
+        if letter == "H":
+            height_axis = i
+        if letter == "W":
+            width_axis = i
+        if letter == "C":
+            channel_axis = i
+    size = get_const_tuple(attrs.size)  # target (H, W) must be a compile-time constant tuple
+    return [
+        _resize_shape_func(
+            inputs[0],
+            convert(size),
+            convert(batch_axis),
+            convert(height_axis),
+            convert(width_axis),
+            convert(channel_axis),
+        )
+    ]
+
+
 @reg.register_compute("image.resize3d")
 def compute_resize3d(attrs, inputs, out_type):
     size = attrs.size
@@ -134,6 +173,25 @@ def compute_affine_grid(attrs, inputs, out_dtype):
 reg.register_injective_schedule("image.affine_grid")
 
 
+@script
+def _affine_grid_func(data, target_shape):  # Output shape of image.affine_grid: (data[0], 2, target_shape[0], target_shape[1]).
+    out = output_tensor((4,), "int64")
+    out[0] = int64(data[0])  # leading dim of the input shape; `data` is presumably the input's shape tensor, not its values — confirm
+    out[1] = int64(2)  # constant size-2 axis (presumably the x/y grid coordinates — confirm)
+    out[2] = int64(target_shape[0])
+    out[3] = int64(target_shape[1])
+    return out
+
+
+@reg.register_shape_func("image.affine_grid", False)  # second arg presumably means "not data-dependent" — confirm
+def affine_grid_func(attrs, inputs, _):
+    """
+    Shape function for image.affine_grid: batch from the input shape, spatial dims from attrs.target_shape.
+    """
+    target_shape = get_const_tuple(attrs.target_shape)  # target_shape must be a compile-time constant tuple
+    return [_affine_grid_func(inputs[0], convert(target_shape))]
+
+
 # grid_sample
 @reg.register_compute("image.grid_sample")
 def compute_grid_sample(attrs, inputs, out_dtype):
@@ -143,3 +201,21 @@ def compute_grid_sample(attrs, inputs, out_dtype):
 
 
 reg.register_injective_schedule("image.grid_sample")
+
+
+@script
+def _grid_sample_func(data, grid):  # Output shape of image.grid_sample: (data[0], data[1], grid[2], grid[3]).
+    out = output_tensor((4,), "int64")
+    out[0] = int64(data[0])  # NOTE(review): batch taken from data only; no check that it matches grid[0]
+    out[1] = int64(data[1])
+    out[2] = int64(grid[2])  # output spatial size follows the grid, not the input data
+    out[3] = int64(grid[3])
+    return out
+
+
+@reg.register_shape_func("image.grid_sample", False)  # second arg presumably means "not data-dependent" — confirm
+def grid_sample_func(attrs, inputs, _):
+    """
+    Shape function for image.grid_sample: first two dims from data (inputs[0]), last two from grid (inputs[1]).
+    """
+    return [_grid_sample_func(inputs[0], inputs[1])]
diff --git a/tests/python/relay/test_any.py b/tests/python/relay/test_any.py
index 8784b97..5469737 100644
--- a/tests/python/relay/test_any.py
+++ b/tests/python/relay/test_any.py
@@ -1121,6 +1121,94 @@ def test_any_ndarray_size():
     verify_any_ndarray_size((1, 2, 3, 4))
 
 
+def verify_any_resize(data_shape, scale, layout, static_data_shape, ref_out_shape):  # Build resize on a partly symbolic input shape, run it, and check the output shape.
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    if layout == "NHWC":
+        size = (data_shape[1] * scale, data_shape[2] * scale)  # H, W at axes 1, 2; assumes those entries are ints, not relay.Any()
+    else:
+        size = (data_shape[2] * scale, data_shape[3] * scale)  # any non-NHWC layout is treated as NCHW (H, W at axes 2, 3)
+    y = relay.image.resize(data, size, layout)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)  # concrete input matching the symbolic data_shape
+    check_result([data_np], mod, ref_out_shape, assert_shape=True)
+
+
+@tvm.testing.uses_gpu
+def test_any_resize():  # Symbolic batch dim; H and W upscaled by an integer factor in NHWC and NCHW.
+    verify_any_resize(
+        data_shape=(relay.Any(), 4, 4, 4),
+        scale=2,
+        layout="NHWC",
+        static_data_shape=(1, 4, 4, 4),
+        ref_out_shape=(1, 8, 8, 4),
+    )
+    verify_any_resize(
+        data_shape=(relay.Any(), 8, 17, 20),
+        scale=3,
+        layout="NCHW",
+        static_data_shape=(2, 8, 17, 20),
+        ref_out_shape=(2, 8, 51, 60),
+    )
+
+
+def verify_any_grid_sample(data_shape, grid_shape, static_data_shape, ref_out_shape):  # Build grid_sample with a partly symbolic data shape, run it, and check the output shape.
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    grid = relay.var("grid", shape=grid_shape, dtype=dtype)
+    y = relay.image.grid_sample(data, grid)
+    mod["main"] = relay.Function([data, grid], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    grid_np = np.random.uniform(size=grid_shape).astype(dtype)  # grid_shape doubles as the numpy size, so it must be fully static
+    check_result([data_np, grid_np], mod, ref_out_shape, assert_shape=True)
+
+
+@tvm.testing.uses_gpu
+def test_any_grid_sample():  # Output spatial size should follow the grid (8x8, then 32x32), not the 16x32 data.
+    verify_any_grid_sample(
+        data_shape=(relay.Any(), 4, 16, 32),
+        grid_shape=(4, 2, 8, 8),
+        static_data_shape=(4, 4, 16, 32),
+        ref_out_shape=(4, 4, 8, 8),
+    )
+    verify_any_grid_sample(
+        data_shape=(relay.Any(), 4, 16, 32),
+        grid_shape=(4, 2, 32, 32),
+        static_data_shape=(4, 4, 16, 32),
+        ref_out_shape=(4, 4, 32, 32),
+    )
+
+
+def verify_any_affine_grid(num_batch, static_num_batch, target_shape, ref_out_shape):  # Build affine_grid on a (num_batch, 2, 3) input, run it, and check the output shape.
+    mod = tvm.IRModule()
+    dtype = "float32"
+    data_shape = (num_batch, 2, 3)  # only the batch dim varies; each item is a 2x3 matrix
+    static_data_shape = (static_num_batch, 2, 3)
+    data = relay.var("data", shape=data_shape, dtype=dtype)
+    y = relay.image.affine_grid(data, target_shape)
+    mod["main"] = relay.Function([data], y)
+    data_np = np.random.uniform(size=static_data_shape).astype(dtype)
+    check_result([data_np], mod, ref_out_shape, assert_shape=True)
+
+
+@tvm.testing.uses_gpu
+def test_any_affine_grid():  # Symbolic batch resolved at runtime (1, then 8); output is (batch, 2, *target_shape).
+    verify_any_affine_grid(
+        num_batch=relay.Any(),
+        static_num_batch=1,
+        target_shape=(16, 32),
+        ref_out_shape=(1, 2, 16, 32),
+    )
+    verify_any_affine_grid(
+        num_batch=relay.Any(),
+        static_num_batch=8,
+        target_shape=(32, 32),
+        ref_out_shape=(8, 2, 32, 32),
+    )
+
+
 def test_any_consecutive_broadcast():
     dtype = "float32"
     data0 = relay.var("data0", shape=any_dims(2), dtype=dtype)