Posted to commits@tvm.apache.org by cs...@apache.org on 2022/05/13 15:43:17 UTC

[tvm] branch main updated: Add Adreno GPU target and topi supporting textures with dynamically allocated textures (#11161)

This is an automated email from the ASF dual-hosted git repository.

csullivan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c2d1905779 Add Adreno GPU target and topi supporting textures with dynamically allocated textures (#11161)
c2d1905779 is described below

commit c2d190577928b0d81747947fdd5a2c2145adae62
Author: Andrey Malyshev <el...@gmail.com>
AuthorDate: Fri May 13 18:43:06 2022 +0300

    Add Adreno GPU target and topi supporting textures with dynamically allocated textures (#11161)
    
    * Add Adreno GPU target and topi supporting textures
    
    - There are 5 compute/schedule pairs: conv2d for NCHW/NHWC, depthwise_conv2d
      for NCHW/NHWC, and average pooling
    - Fix caching of dynamically allocated textures
    - Add texture-nhwc scope
    - Fix issue with codegen of vars having unacceptable symbols
    
    Co-authored-by: Chris Sullivan <cs...@octoml.ai>
    Co-authored-by: Egor Churaev <eg...@gmail.com>
    
    * Address comments
    
    * Add vectorization into some adreno pool flow
    
    Co-authored-by: Li <qu...@quicinc.com>
    
    * Fix adreno tests for running on the opencl host platform
    
    * remove unnecessary kDriverVersion in DeviceAttrKind
    
    * Move Adreno utility functions to a separate shared file
    
    * fix black formatting hits
    
    Co-authored-by: Chris Sullivan <cs...@octoml.ai>
    Co-authored-by: Egor Churaev <eg...@gmail.com>
    Co-authored-by: Li <qu...@quicinc.com>
---
 python/tvm/_ffi/runtime_ctypes.py                  |  11 +
 python/tvm/relay/op/strategy/__init__.py           |   1 +
 python/tvm/relay/op/strategy/adreno.py             | 162 ++++++
 python/tvm/target/target.py                        |  14 +
 python/tvm/topi/__init__.py                        |   1 +
 .../{relay/op/strategy => topi/adreno}/__init__.py |  22 +-
 python/tvm/topi/adreno/conv2d_alter_op.py          | 211 ++++++++
 python/tvm/topi/adreno/conv2d_nchw.py              | 344 +++++++++++++
 python/tvm/topi/adreno/conv2d_nhwc.py              | 339 +++++++++++++
 python/tvm/topi/adreno/depthwise_conv2d_nchw.py    | 316 ++++++++++++
 python/tvm/topi/adreno/depthwise_conv2d_nhwc.py    | 311 ++++++++++++
 python/tvm/topi/adreno/pooling.py                  |  89 ++++
 python/tvm/topi/adreno/utils.py                    | 549 ++++++++++++++++++++
 src/runtime/opencl/opencl_common.h                 |  11 +-
 src/runtime/opencl/opencl_device_api.cc            |   6 +
 src/runtime/texture.h                              |   6 +
 src/runtime/thread_storage_scope.h                 |   7 +
 src/target/source/codegen_source_base.cc           |   5 +
 src/target/target_kind.cc                          |   1 +
 tests/python/relay/test_conv2d_nchw_texture.py     | 394 +++++++++++++++
 tests/python/relay/test_conv2d_nhwc_texture.py     | 556 +++++++++++++++++++++
 .../relay/test_depthwise_conv2d_nchw_texture.py    | 194 +++++++
 .../relay/test_depthwise_conv2d_nhwc_texture.py    | 233 +++++++++
 tests/python/relay/utils/adreno_utils.py           | 118 +++++
 24 files changed, 3886 insertions(+), 15 deletions(-)

diff --git a/python/tvm/_ffi/runtime_ctypes.py b/python/tvm/_ffi/runtime_ctypes.py
index 03a68e9f97..5dc3fe0938 100644
--- a/python/tvm/_ffi/runtime_ctypes.py
+++ b/python/tvm/_ffi/runtime_ctypes.py
@@ -432,6 +432,17 @@ class Device(ctypes.Structure):
         """
         return self._GetDeviceAttr(self.device_type, self.device_id, 12)
 
+    def texture_spatial_limit(self):
+        """Returns limits for textures by spatial dimensions
+
+        Returns
+        -------
+        limit : int or None
+            Maximum size of the texture by spatial dimensions
+
+        """
+        return self._GetDeviceAttr(self.device_type, self.device_id, 12)
+
     def create_raw_stream(self):
         """Create a new runtime stream at the context.
 
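A minimal usage sketch for the new attribute query (the device index and the printed value are illustrative, and an OpenCL-capable device must be present):

    import tvm

    dev = tvm.opencl(0)
    # Query the maximum texture extent along a spatial dimension.
    limit = dev.texture_spatial_limit()
    print(limit)  # e.g. 16384 on many Adreno GPUs
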
diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/relay/op/strategy/__init__.py
index cf915777ed..1be5425e70 100644
--- a/python/tvm/relay/op/strategy/__init__.py
+++ b/python/tvm/relay/op/strategy/__init__.py
@@ -29,3 +29,4 @@ from . import bifrost
 from . import rocm
 from . import intel_graphics
 from . import hexagon
+from . import adreno
diff --git a/python/tvm/relay/op/strategy/adreno.py b/python/tvm/relay/op/strategy/adreno.py
new file mode 100644
index 0000000000..a783440bb3
--- /dev/null
+++ b/python/tvm/relay/op/strategy/adreno.py
@@ -0,0 +1,162 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Definition of adreno operator strategy."""
+# pylint: disable=invalid-name,unused-argument,wildcard-import,unused-wildcard-import
+from tvm import topi
+from .generic import *
+from .. import op as _op
+
+
+@conv2d_NCHWc_strategy.register("adreno")
+@conv2d_strategy.register("adreno")
+def conv2d_strategy_adreno(attrs, inputs, out_type, target):
+    """conv2d adreno strategy"""
+    strategy = _op.OpStrategy()
+    data, kernel = inputs
+    dilation_h, dilation_w = attrs.get_int_tuple("dilation")
+    groups = attrs.groups
+    data_layout = attrs.data_layout
+    kernel_layout = attrs.kernel_layout
+    if dilation_h < 1 or dilation_w < 1:
+        raise ValueError("dilation should be positive value")
+
+    if groups == 1:
+        if (data_layout == "NCHW" and kernel_layout == "OIHW") or (
+            data_layout == "NCHW4c" and kernel_layout == "OIHW4o"
+        ):
+            if out_type.dtype == "float16":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.adreno.conv2d_nchwc),
+                    wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc),
+                    name="conv2d_nchwc.image2d",
+                    plevel=10,
+                )
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.adreno.conv2d_nchwc_acc32),
+                wrap_topi_schedule(topi.adreno.schedule_conv2d_nchwc_acc32),
+                name="conv2d_nchwc_tpack.image2d",
+                plevel=20,
+            )
+        elif (data_layout == "NHWC" and kernel_layout == "HWIO") or (
+            data_layout == "NHWC4c" and kernel_layout == "HWIO4o"
+        ):
+            if out_type.dtype == "float16":
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.adreno.conv2d_nhwc),
+                    wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc),
+                    name="conv2d_nhwc.image2d",
+                    plevel=10,
+                )
+            strategy.add_implementation(
+                wrap_compute_conv2d(topi.adreno.conv2d_nhwc_acc32),
+                wrap_topi_schedule(topi.adreno.schedule_conv2d_nhwc_acc32),
+                name="conv2d_nhwc_acc32.image2d",
+                plevel=20,
+            )
+        else:
+            raise RuntimeError(
+                "Layout not supported: ("
+                + data_layout
+                + ", "
+                + kernel_layout
+                + ") - only support NCHW4c / OIHW4o and NHWC / HWOI layouts for conv2d"
+            )
+    else:
+        # cannot use is_depthwise_conv2d because it does not know about NHWC4c/HWOI4o layouts
+        if data_layout == "NCHW":
+            ic = data.shape[1]
+        elif data_layout == "NCHW4c":
+            ic = data.shape[1] * data.shape[4]
+        elif data_layout == "NHWC":
+            ic = data.shape[3]
+        elif data_layout == "NHWC4c":
+            ic = data.shape[3] * data.shape[4]
+        else:
+            raise RuntimeError("Unsupported depthwise_conv2d data layout {}".format(data_layout))
+        if kernel_layout == "OIHW":
+            oc = kernel.shape[0]
+        elif kernel_layout == "OIHW4o":
+            oc = kernel.shape[0] * kernel.shape[4]
+        elif kernel_layout == "HWOI":
+            oc = kernel.shape[2]
+        elif kernel_layout == "HWOI4o":
+            oc = kernel.shape[2] * kernel.shape[4]
+        else:
+            raise RuntimeError(
+                "Unsupported depthwise_conv2d kernel layout {}".format(kernel_layout)
+            )
+
+        if ic == oc == groups:
+            if (data_layout == "NCHW" and kernel_layout == "OIHW") or (
+                data_layout == "NCHW4c" and kernel_layout == "OIHW4o"
+            ):
+                if out_type.dtype == "float16":
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc),
+                        wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc),
+                        name="depthwise_conv2d_nchwc.image2d",
+                        plevel=10,
+                    )
+                strategy.add_implementation(
+                    wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nchwc_acc32),
+                    wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nchwc_acc32),
+                    name="depthwise_conv2d_nchwc_acc32.image2d",
+                    plevel=20,
+                )
+            elif (data_layout == "NHWC" and kernel_layout == "HWOI") or (
+                data_layout == "NHWC4c" and kernel_layout == "HWOI4o"
+            ):
+                if data.shape[-1] >= 4:
+                    if out_type.dtype == "float16":
+                        strategy.add_implementation(
+                            wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nhwc),
+                            wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nhwc),
+                            name="depthwise_conv2d_nhwc.image2d",
+                            plevel=10,
+                        )
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.adreno.depthwise_conv2d_nhwc_acc32),
+                        wrap_topi_schedule(topi.adreno.schedule_depthwise_conv2d_nhwc_acc32),
+                        name="depthwise_conv2d_nhwc_acc32.image2d",
+                        plevel=20,
+                    )
+                else:
+                    strategy.add_implementation(
+                        wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
+                        wrap_topi_schedule(topi.cuda.schedule_depthwise_conv2d_nhwc),
+                        name="depthwise_conv2d_nhwc.cuda",
+                    )
+            else:
+                raise RuntimeError(
+                    "Layout not supported: ("
+                    + data_layout
+                    + ", "
+                    + kernel_layout
+                    + ") - only support NCHW4c / OIHW4o and NHWC / HWOI layouts for conv2d"
+                )
+        else:
+            raise RuntimeError("General group convolution is not currently supported")
+    return strategy
+
+
+@schedule_pool.register("adreno")
+def schedule_pool_adreno(attrs, outs, target):
+    """schedule pooling ops for adreno"""
+    with target:
+        if attrs.layout == "NCHW4c":
+            return topi.adreno.schedule_pool(outs, attrs.layout)
+        return topi.cuda.schedule_pool(outs, attrs.layout)
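A sketch of how this strategy gets exercised end to end; the shapes, the float16 dtype, and the bare "opencl -device=adreno" target string are illustrative:

    import tvm
    from tvm import relay

    data = relay.var("data", shape=(1, 32, 56, 56), dtype="float16")
    weight = relay.var("weight", shape=(32, 32, 3, 3), dtype="float16")
    conv = relay.nn.conv2d(data, weight, padding=(1, 1), channels=32,
                           kernel_size=(3, 3), data_layout="NCHW",
                           kernel_layout="OIHW")
    mod = tvm.IRModule.from_expr(relay.Function([data, weight], conv))
    # With -device=adreno the strategies above are picked up, so conv2d
    # lowers to the conv2d_nchwc*.image2d implementations.
    with tvm.transform.PassContext(opt_level=3):
        lib = relay.build(mod, target="opencl -device=adreno")
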
diff --git a/python/tvm/target/target.py b/python/tvm/target/target.py
index 03115612c5..4752095d37 100644
--- a/python/tvm/target/target.py
+++ b/python/tvm/target/target.py
@@ -814,6 +814,20 @@ def stm32(series="unknown", options=None):
     return Target(" ".join(["c"] + opts))
 
 
+def adreno(model="unknown", options=None):
+    """Returns a Qualcomm GPU target.
+    Parameters
+    ----------
+    model: str
+        The model of this device
+    options : str or list of str
+        Additional options
+    """
+    opts = ["-device=adreno", "-model=%s" % model]
+    opts = _merge_opts(opts, options)
+    return Target(" ".join(["opencl"] + opts))
+
+
 def create(target):
     """Deprecated. Use the constructor of :py:mod:`tvm.target.Target` directly."""
     warnings.warn("tvm.target.create() is being deprecated. Please use tvm.target.Target() instead")
diff --git a/python/tvm/topi/__init__.py b/python/tvm/topi/__init__.py
index cc6c8fcc31..785ba395d2 100644
--- a/python/tvm/topi/__init__.py
+++ b/python/tvm/topi/__init__.py
@@ -64,6 +64,7 @@ from . import sparse
 from . import hls
 from . import random
 from . import hexagon
+from . import adreno
 
 # error reporting
 from .utils import InvalidShapeError
diff --git a/python/tvm/relay/op/strategy/__init__.py b/python/tvm/topi/adreno/__init__.py
similarity index 71%
copy from python/tvm/relay/op/strategy/__init__.py
copy to python/tvm/topi/adreno/__init__.py
index cf915777ed..6c9b7463c1 100644
--- a/python/tvm/relay/op/strategy/__init__.py
+++ b/python/tvm/topi/adreno/__init__.py
@@ -15,17 +15,11 @@
 # specific language governing permissions and limitations
 # under the License.
 
-# pylint: disable=wildcard-import
-"""Relay op strategies."""
-from __future__ import absolute_import as _abs
-
-from .generic import *
-from . import x86
-from . import arm_cpu
-from . import cuda
-from . import hls
-from . import mali
-from . import bifrost
-from . import rocm
-from . import intel_graphics
-from . import hexagon
+# pylint: disable=redefined-builtin, wildcard-import
+"""Qualcomm Adreno GPU specific declaration and schedules."""
+from .conv2d_nchw import *
+from .depthwise_conv2d_nchw import *
+from .conv2d_nhwc import *
+from .depthwise_conv2d_nhwc import *
+from .pooling import *
+from .conv2d_alter_op import *
diff --git a/python/tvm/topi/adreno/conv2d_alter_op.py b/python/tvm/topi/adreno/conv2d_alter_op.py
new file mode 100644
index 0000000000..e8944093c0
--- /dev/null
+++ b/python/tvm/topi/adreno/conv2d_alter_op.py
@@ -0,0 +1,211 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-member
+"""Conv2D alter op for Qualcomm Adreno GPU"""
+
+import logging
+
+import re
+import tvm
+from tvm import te
+from tvm import relay
+from tvm import autotvm
+from ..utils import get_const_tuple
+from ..nn import conv2d_alter_layout
+
+logger = logging.getLogger("topi")
+
+# Regex matchers for the supported blocked layouts to be transformed
+_NCHWc_matcher = re.compile("^NCHW[0-9]+c$")
+_OIHWo_matcher = re.compile("^OIHW[0-9]+o$")
+_NHWCc_matcher = re.compile("^NHWC[0-9]+c$")
+_HWIOo_matcher = re.compile("^HWIO[0-9]+o$")
+_HWOIo_matcher = re.compile("^HWOI[0-9]+o$")
+
+
+@conv2d_alter_layout.register("adreno")
+def _alter_conv2d_layout(attrs, inputs, tinfos, out_type):
+    """
+    Prepares the new conv2d with the proper target blocked layout attributes.
+    OpenCL textures support 1d/2d/3d/4d textures, but a read always fetches 4 elements
+    in a line. For now we therefore support only the 4-element blocked conversions:
+    NCHW -> NCHW4c & OIHW -> OIHW4o
+    NHWC -> NHWC4c & HWIO -> HWIO4o & HWOI -> HWOI4o
+    """
+    target = tvm.target.Target.current(allow_none=False)
+    dispatch_ctx = autotvm.task.DispatchContext.current
+    new_attrs = {k: attrs[k] for k in attrs.keys()}
+
+    # Parse the attributes.
+    padding = attrs.get_int_tuple("padding")
+    strides = attrs.get_int_tuple("strides")
+    dilation = attrs.get_int_tuple("dilation")
+    data_layout = attrs["data_layout"]
+    kernel_layout = attrs["kernel_layout"]
+    data_tensor, kernel_tensor = tinfos
+    data_dtype = data_tensor.dtype
+    kernel_dtype = kernel_tensor.dtype
+    out_dtype = out_type.dtype
+
+    if isinstance(dispatch_ctx, autotvm.task.ApplyGraphBest):
+        cfg = dispatch_ctx.query(target, None)
+        workload = cfg.workload
+    else:
+        impl, outs = relay.backend.te_compiler.select_implementation(
+            relay.op.get("nn.conv2d"), attrs, tinfos, out_type, target
+        )
+        workload = autotvm.task.get_workload(outs)
+        if workload is None:
+            return None
+
+        cfg = dispatch_ctx.query(target, workload)
+
+    topi_tmpl = workload[0]
+
+    if "conv2d_nchwc" in topi_tmpl:  # covers both conv2d_nchwc and depthwise_conv2d_nchwc
+        if data_layout == "NCHW" and kernel_layout == "OIHW":
+            batch, in_channels, in_height, in_width = data_tensor.shape
+            out_channles, _, kernel_h, kernel_w = kernel_tensor.shape
+            in_channel_block = in_channels % 4
+            if in_channel_block == 0:
+                in_channel_block = 4
+            num_filter_block = out_channles % 4
+            if num_filter_block == 0:
+                num_filter_block = 4
+
+            # no support yet for tensors with channels not divisible by 4
+            if in_channel_block != 4 or num_filter_block != 4:
+                return None
+
+            batch_size, in_channel, height, width = get_const_tuple(data_tensor.shape)
+            out_channel, in_filter_channel, kh, kw = get_const_tuple(kernel_tensor.shape)
+
+            # update new attrs
+            new_attrs["channels"] = out_channel
+            new_attrs["data_layout"] = "NCHW%dc" % in_channel_block
+            # (oc, ic, h, w) -> (OC, ic, h, w, oc)
+            new_attrs["kernel_layout"] = "OIHW%do" % num_filter_block
+            new_attrs["out_layout"] = "NCHW%dc" % num_filter_block
+
+            # Store the altered operator's config so tuned AutoTVM statistics can be applied
+            new_data = te.placeholder(
+                (batch_size, in_channel // in_channel_block, height, width, in_channel_block),
+                dtype=data_dtype,
+            )
+            new_kernel = te.placeholder(
+                (out_channel // num_filter_block, in_filter_channel, kh, kw, num_filter_block),
+                dtype=kernel_tensor.dtype,
+            )
+            new_workload = autotvm.task.args_to_workload(
+                [
+                    new_data,
+                    new_kernel,
+                    strides,
+                    padding,
+                    dilation,
+                    out_dtype,
+                ],
+                topi_tmpl,  # "conv2d_nchwc.image2d",
+            )
+            dispatch_ctx.update(target, new_workload, cfg)
+        else:
+            assert _NCHWc_matcher.match(data_layout)
+            assert _OIHWo_matcher.match(kernel_layout)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    if "conv2d_nhwc" in topi_tmpl:  # covers both conv2d_nhwcc and depthwise_conv2d_nhwcc
+        if (data_layout == "NHWC" and kernel_layout == "HWIO") or (
+            data_layout == "NHWC" and kernel_layout == "HWOI"
+        ):
+            if kernel_layout == "HWIO":
+                batch_size, in_height, in_width, in_channels = data_tensor.shape
+                kernel_h, kernel_w, in_filter_channel, out_channles = kernel_tensor.shape
+            else:
+                batch_size, in_height, in_width, in_channels = data_tensor.shape
+                kernel_h, kernel_w, out_channles, in_filter_channel = kernel_tensor.shape
+            in_channel_block = in_channels % 4
+            if in_channel_block == 0:
+                in_channel_block = 4
+            num_filter_block = out_channles % 4
+            if num_filter_block == 0:
+                num_filter_block = 4
+
+            # no support yet for tensors with channels not divisible by 4
+            if in_channel_block != 4 or num_filter_block != 4:
+                return None
+
+            # update new attrs
+            new_attrs["channels"] = out_channles
+            new_attrs["data_layout"] = "NHWC%dc" % in_channel_block
+            # (h, w, ic, oc) -> (h, w, ic, OC, oc)
+            if kernel_layout == "HWIO":
+                new_attrs["kernel_layout"] = "HWIO%do" % num_filter_block
+            else:
+                new_attrs["kernel_layout"] = "HWOI%do" % num_filter_block
+            new_attrs["out_layout"] = "NHWC%dc" % num_filter_block
+
+            # Store the altered operator's config so tuned AutoTVM statistics can be applied
+            new_data = te.placeholder(
+                (
+                    batch_size,
+                    in_height,
+                    in_width,
+                    in_channels // in_channel_block,
+                    in_channel_block,
+                ),
+                dtype=data_dtype,
+            )
+            if kernel_layout == "HWIO":
+                new_kernel = te.placeholder(
+                    (
+                        kernel_h,
+                        kernel_w,
+                        in_filter_channel,
+                        out_channles // num_filter_block,
+                        num_filter_block,
+                    ),
+                    dtype=kernel_tensor.dtype,
+                )
+            else:
+                new_kernel = te.placeholder(
+                    (
+                        kernel_h,
+                        kernel_w,
+                        out_channles // num_filter_block,
+                        in_filter_channel,
+                        num_filter_block,
+                    ),
+                    dtype=kernel_tensor.dtype,
+                )
+            new_workload = autotvm.task.args_to_workload(
+                [
+                    new_data,
+                    new_kernel,
+                    strides,
+                    padding,
+                    dilation,
+                    out_dtype,
+                ],
+                topi_tmpl,
+            )
+            dispatch_ctx.update(target, new_workload, cfg)
+        else:
+            assert _NHWCc_matcher.match(data_layout)
+            assert _HWIOo_matcher.match(kernel_layout) or _HWOIo_matcher.match(kernel_layout)
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    return None
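For intuition about the blocked layouts this pass produces, here is a numpy sketch of the NCHW -> NCHW4c packing (shapes are illustrative; channels are divisible by 4, as the pass requires):

    import numpy as np

    x = np.arange(1 * 8 * 2 * 2).reshape(1, 8, 2, 2)  # NCHW with C = 8
    n, c, h, w = x.shape
    # Split C into (chunks, block=4) and move the block innermost.
    x4c = x.reshape(n, c // 4, 4, h, w).transpose(0, 1, 3, 4, 2)
    assert x4c.shape == (1, 2, 2, 2, 4)               # NCHW4c
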
diff --git a/python/tvm/topi/adreno/conv2d_nchw.py b/python/tvm/topi/adreno/conv2d_nchw.py
new file mode 100644
index 0000000000..96368b3e57
--- /dev/null
+++ b/python/tvm/topi/adreno/conv2d_nchw.py
@@ -0,0 +1,344 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""conv2d nchw schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    expand_spatial_dimensions,
+    add_pad,
+    bind_data_copy,
+)
+
+
+@autotvm.register_topi_compute("conv2d_nchwc.image2d")
+def conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float16"}
+    return compute_conv2d_NCHWc_KCRSk(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_compute("conv2d_nchwc_acc32.image2d")
+def conv2d_nchwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float32"}
+    return compute_conv2d_NCHWc_KCRSk(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_schedule("conv2d_nchwc.image2d")
+def schedule_conv2d_nchwc(cfg, outs):
+    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16")
+
+
+@autotvm.register_topi_schedule("conv2d_nchwc_acc32.image2d")
+def schedule_conv2d_nchwc_acc32(cfg, outs):
+    return schedule_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32")
+
+
+def schedule_conv2d_nchwc_impl(cfg, outs, tag):
+    """Create the schedule for conv2d_nchw"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == tag:
+            schedule_conv2d_NCHWc_KCRSk(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def compute_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype, args):
+    """
+    Convolution operator in NCHWc layout.
+    Algo:
+      1. Convert into the blocked format if the original tensor is 4d.
+         In the AutoTVM case we replace the conversion with plain placeholder tensors,
+         since no such conversion exists for truly blocked convolution and there is no
+         sense in including it in tuning
+      2. Expand the spatial dimensions so that width and height are divisible by 4.
+         This slightly increases the amount of compute but utilizes the GPU much better
+      3. Add padding. This happens even if no pad is originally needed. It is useful
+         for working around the gaps in texture annotation between primary functions
+         and the limited support of textures in schedules. Later this pad is executed
+         separately and produces a texture
+      4. 5d convolution compute, accumulating into the requested accumulator type
+      5. Cast to the original output data type
+      6. For the 4d convolution case: convert the output from 5d back to 4d
+    """
+
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    convert_from4d = False
+    if len(Input.shape) == 4:
+        batch, in_channels, in_height, in_width = Input.shape
+        out_channles, in_filter_channels, kernel_h, kernel_w = Filter.shape
+
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4)
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channles, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_channel_chunks, in_height, in_width, in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder")
+            kshape = (out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            Input = pack_input(
+                Input,
+                "NCHW",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+            Filter = pack_filter(
+                Filter,
+                "OIHW",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+
+    else:
+        batch, in_channel_chunks, in_height, in_width, in_channel_block = Input.shape
+        out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block = Filter.shape
+
+    out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions(
+        in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w
+    )
+
+    temp = add_pad(
+        Input,
+        "NCHW",
+        out_height_orig,
+        out_width_orig,
+        kernel_h,
+        kernel_w,
+        dilation_h,
+        dilation_w,
+        padding,
+        stride_h,
+        stride_w,
+    )
+
+    rcc = te.reduce_axis((0, in_channel_chunks), name="rcc")
+    rcb = te.reduce_axis((0, in_channel_block), name="rcb")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+
+    conv = te.compute(
+        (batch, out_channel_chunks, out_height, out_width, out_channel_block),
+        lambda nn, ffc, yy, xx, ffb: te.sum(
+            (
+                temp[nn, rcc, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcb]
+                * Filter[ffc, rcc * in_channel_block + rcb, ry, rx, ffb]
+            ).astype(args["accumulator"]),
+            axis=[rcc, rcb, ry, rx],
+        ),
+        tag="conv2d_nchwc",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_channles, out_height_orig, out_width_orig),
+            lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % out_channel_block],
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+    else:
+        return te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype),
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+
+
+def schedule_conv2d_NCHWc_KCRSk(cfg, s, output):
+    """
+    schedule optimized for batch size = 1
+
+    Algo:
+    1. Split the output axes into three parts: global work size, vthread, local work size.
+       The tuning limitations include heuristics from some tuned networks to restrict the
+       search space and not spend much time on useless configurations.
+    2. In the case of 4d convolution, schedule the copying of the input (and filter)
+       into 5d tensors
+    3. The pad should be scheduled separately to create an independent OpenCL kernel.
+       If the pad is inlined into the convolution, this gives a 1.5x performance drop
+    4. We use cache_read to produce a texture and guarantee the best performance
+       on the next stage.
+    5. For 5d convolution we schedule the latest op, binding the 5d axes and vectorizing
+       for textures.
+       For a 4d tensor we do the same for the latest blocked stage, i.e. the conversion
+       of the data type
+    6. In the case of 4d conv we need to schedule the post-ops as well
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    ##### space definition begin #####
+    n, fc, y, x, fb = s[conv].op.axis
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+
+    if conv.shape[1] % 2 == 0:
+        min_threads_div = 2
+    else:
+        min_threads_div = 1
+    cfg.define_split(
+        "tile_fc",
+        fc,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8
+        and entity.size[2] >= min_threads_div
+        and entity.size[2] < 256,
+    )
+    cfg.define_split(
+        "tile_y",
+        y,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+    cfg.define_split(
+        "tile_x",
+        x,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+
+    cfg.define_split("tile_rcc", rcc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    cfg.define_knob("unroll_explicit", [0, 1])
+
+    ##### space definition end #####
+
+    pad_data, kernel = s[conv].op.input_tensors
+    if (
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
+    ):  # len(latest.op.axis) == 4:
+        # manage scheduling of datacopy
+        pad_data, kernel = s[conv].op.input_tensors
+        pack_data = pad_data.op.input_tensors[0]
+        bind_data_copy(s[pack_data])
+        bind_data_copy(s[kernel])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # create cache stage
+    AT = s.cache_read(pad_data, "global.texture", [conv])
+    bind_data_copy(s[AT])
+    WT = s.cache_read(kernel, "global.texture-weight", [conv])
+    bind_data_copy(s[WT])
+
+    # tile and bind spatial axes
+    n, fc, y, x, fb = s[latest_blocked].op.axis
+
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    bf = s[latest_blocked].fuse(n, bf)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, fc, y, x, fb = s[conv].op.axis
+
+    rcc, rcb, ry, rx = s[conv].op.reduce_axis
+    rco, rci = cfg["tile_rcc"].apply(s, conv, rcc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb)
+    s[conv].vectorize(fb)
+    s[conv].unroll(rcb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape)
+    _, IC, KH, KW, _ = get_const_tuple(kernel.shape)
+    ICKHKW = IC * KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW)
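A quick sanity check of the FLOP estimate recorded at the end of the schedule, with illustrative shapes:

    # 2 * N * OH * OW * OCC * OCB * (IC * KH * KW): one multiply and one
    # add per kernel tap and input channel for every output element.
    N, OH, OW, OCC, OCB = 1, 56, 56, 8, 4  # 32 output channels = 8 chunks x 4
    IC, KH, KW = 32, 3, 3
    flops = 2 * N * OH * OW * OCC * OCB * IC * KH * KW
    print(flops)  # 57802752, i.e. roughly 0.058 GFLOP for this layer
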
diff --git a/python/tvm/topi/adreno/conv2d_nhwc.py b/python/tvm/topi/adreno/conv2d_nhwc.py
new file mode 100644
index 0000000000..d40f813fdb
--- /dev/null
+++ b/python/tvm/topi/adreno/conv2d_nhwc.py
@@ -0,0 +1,339 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""conv2d nhwc schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    expand_spatial_dimensions,
+    add_pad,
+    bind_data_copy,
+    get_texture_storage,
+)
+
+
+@autotvm.register_topi_compute("conv2d_nhwc.image2d")
+def conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float16"}
+    return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, out_dtype, args=args)
+
+
+@autotvm.register_topi_compute("conv2d_nhwc_acc32.image2d")
+def conv2d_nhwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float32"}
+    return compute_conv2d_NHWC_HWIO(data, kernel, strides, padding, dilation, out_dtype, args=args)
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc.image2d")
+def schedule_conv2d_nhwc(cfg, outs):
+    return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc16")
+
+
+@autotvm.register_topi_schedule("conv2d_nhwc_acc32.image2d")
+def schedule_conv2d_nhwc_acc32(cfg, outs):
+    return schedule_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc32")
+
+
+def schedule_conv2d_nhwc_impl(cfg, outs, tag):
+    """Create the schedule for conv2d_nhwc"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == tag:
+            schedule_conv2d_NHWC(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def compute_conv2d_NHWC_HWIO(Input, Filter, stride, padding, dilation, out_dtype, args):
+    """
+    Convolution operator in NHWC layout.
+    Algo:
+      1. Convert into the blocked format if the original tensor is 4d.
+         In the AutoTVM case we replace the conversion with plain placeholder tensors,
+         since no such conversion exists for truly blocked convolution and there is no
+         sense in including it in tuning
+      2. Expand the spatial dimensions so that width and height are divisible by 4.
+         This slightly increases the amount of compute but utilizes the GPU much better
+      3. Add padding. This happens even if no pad is originally needed. It is useful
+         for working around the gaps in texture annotation between primary functions
+         and the limited support of textures in schedules. Later this pad is executed
+         separately and produces a texture
+      4. 5d convolution compute, accumulating into the requested accumulator type
+      5. Cast to the original output data type
+      6. For the 4d convolution case: convert the output from 5d back to 4d
+    """
+
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    convert_from4d = False
+    if len(Input.shape) == 4:
+        batch, in_height, in_width, in_channels = Input.shape
+        kernel_h, kernel_w, in_filter_channels, out_channles = Filter.shape
+
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4)
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channles, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_height, in_width, in_channel_chunks, in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder")
+            kshape = (kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            Input = pack_input(
+                Input,
+                "NHWC",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+            Filter = pack_filter(
+                Filter,
+                "HWIO",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+
+    else:
+        batch, in_height, in_width, in_channel_chunks, in_channel_block = Input.shape
+        kernel_h, kernel_w, in_filter_channels, out_channel_chunks, out_channel_block = Filter.shape
+
+    out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions(
+        in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w
+    )
+
+    temp = add_pad(
+        Input,
+        "NHWC",
+        out_height_orig,
+        out_width_orig,
+        kernel_h,
+        kernel_w,
+        dilation_h,
+        dilation_w,
+        padding,
+        stride_h,
+        stride_w,
+    )
+
+    rcc = te.reduce_axis((0, in_channel_chunks), name="rcc")
+    rcb = te.reduce_axis((0, in_channel_block), name="rcb")
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    conv = te.compute(
+        (batch, out_height, out_width, out_channel_chunks, out_channel_block),
+        lambda nn, yy, xx, fc, fb: te.sum(
+            (
+                temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, rcc, rcb]
+                * Filter[ry, rx, rcc * in_channel_block + rcb, fc, fb]
+            ).astype(args["accumulator"]),
+            axis=[ry, rx, rcc, rcb],
+        ),
+        tag="conv2d_nhwc",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
+            lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_height_orig, out_width_orig, out_channles),
+            lambda n, y, x, c: dummy_cast[n, y, x, c // out_channel_block, c % out_channel_block],
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+    else:
+        return te.compute(
+            (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
+            lambda n, y, x, ffc, ffb: conv[n, y, x, ffc, ffb].astype(out_dtype),
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+
+
+def schedule_conv2d_NHWC(cfg, s, output):
+    """
+    schedule optimized for batch size = 1
+
+    Algo:
+    1. Split the output axes into three parts: global work size, vthread, local work size.
+       The tuning limitations include heuristics from some tuned networks to restrict the
+       search space and not spend much time on useless configurations.
+    2. In the case of 4d convolution, schedule the copying of the input (and filter)
+       into 5d tensors
+    3. The pad should be scheduled separately to create an independent OpenCL kernel.
+       If the pad is inlined into the convolution, this gives a 1.5x performance drop
+    4. We use cache_read to produce a texture and guarantee the best performance
+       on the next stage.
+    5. For 5d convolution we schedule the latest op, binding the 5d axes and vectorizing
+       for textures.
+       For a 4d tensor we do the same for the latest blocked stage, i.e. the conversion
+       of the data type
+    6. In the case of 4d conv we need to schedule the post-ops as well
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    ##### space definition begin #####
+    n, y, x, fc, fb = s[conv].op.axis
+    ry, rx, rcc, rcb = s[conv].op.reduce_axis
+
+    if conv.shape[3] % 2 == 0:
+        min_threads_div = 2
+    else:
+        min_threads_div = 1
+
+    cfg.define_split(
+        "tile_fc",
+        fc,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8
+        and entity.size[2] >= min_threads_div
+        and entity.size[2] < 256,
+    )
+    cfg.define_split(
+        "tile_y",
+        y,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+    cfg.define_split(
+        "tile_x",
+        x,
+        num_outputs=3,
+        filter=lambda entity: entity.size[1] <= 8 and entity.size[2] <= 16,
+    )
+
+    cfg.define_split("tile_rcc", rcc, num_outputs=2)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    cfg.define_knob("unroll_explicit", [0, 1])
+
+    pad_data, kernel = s[conv].op.input_tensors
+    if (
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
+    ):  # len(latest.op.axis) == 4:
+        # manage scheduling of datacopy
+        pad_data, kernel = s[conv].op.input_tensors
+        pack_data = pad_data.op.input_tensors[0]
+        bind_data_copy(s[pack_data])
+        bind_data_copy(s[kernel])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # create cache stage
+    AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+    bind_data_copy(s[AT])
+    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+    bind_data_copy(s[WT])
+
+    # tile and bind spatial axes
+    n, y, x, fc, fb = s[latest_blocked].op.axis
+
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    by = s[latest_blocked].fuse(n, by)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, y, x, fc, fb = s[conv].op.axis
+
+    ry, rx, rcc, rcb = s[conv].op.reduce_axis
+    rco, rci = cfg["tile_rcc"].apply(s, conv, rcc)
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+
+    s[conv].reorder(rco, ryo, rxo, rci, ryi, rxi, rcb, n, fc, y, x, fb)
+    s[conv].vectorize(fb)
+    s[conv].unroll(rcb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OH, OW, OCC, OCB = get_const_tuple(latest_blocked.shape)
+    KH, KW, IC, _, _ = get_const_tuple(kernel.shape)
+    ICKHKW = IC * KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * ICKHKW)
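The NHWC variant relies on the analogous NHWC4c blocking. A numpy sketch with illustrative shapes; since the channel axis is already innermost, only a reshape is needed:

    import numpy as np

    x = np.arange(1 * 2 * 2 * 8).reshape(1, 2, 2, 8)  # NHWC with C = 8
    n, h, w, c = x.shape
    x4c = x.reshape(n, h, w, c // 4, 4)               # NHWC4c
    assert x4c.shape == (1, 2, 2, 2, 4)
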
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nchw.py b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
new file mode 100644
index 0000000000..298bd11e00
--- /dev/null
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nchw.py
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""depthwise_conv2d_nchw(c) schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    expand_spatial_dimensions,
+    add_pad,
+    bind_data_copy,
+)
+
+
+@autotvm.register_topi_compute("depthwise_conv2d_nchwc.image2d")
+def depthwise_conv2d_nchwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute depthwise_conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float16"}
+    return compute_depthwise_conv2d_NCHWc_KCRSk(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_compute("depthwise_conv2d_nchwc_acc32.image2d")
+def depthwise_conv2d_nchwc_acc32(
+    cfg, data, kernel, strides, padding, dilation, out_dtype="float16"
+):
+    """Compute depthwise_conv2d with NCHWc layout"""
+    args = {"shared": False, "accumulator": "float32"}
+    return compute_depthwise_conv2d_NCHWc_KCRSk(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nchwc.image2d")
+def schedule_depthwise_conv2d_nchwc(cfg, outs):
+    return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc16")
+
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nchwc_acc32.image2d")
+def schedule_depthwise_conv2d_nchwc_acc32(cfg, outs):
+    return schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag="cast_from_acc32")
+
+
+def schedule_depthwise_conv2d_nchwc_impl(cfg, outs, tag):
+    """Create the schedule for depthwise conv2d_nchw4c_ohwi4o"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == tag:
+            schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def compute_depthwise_conv2d_NCHWc_KCRSk(Input, Filter, stride, padding, dilation, out_dtype, args):
+    """
+    Depthwise convolution operator in NCHWc layout.
+    Algo:
+      1. Convert into the blocked format if the original tensor is 4d.
+         In the AutoTVM case we replace the conversion with plain placeholder tensors,
+         since no such conversion exists for truly blocked convolution and there is no
+         sense in including it in tuning
+      2. Expand the spatial dimensions so that width and height are divisible by 4.
+         This slightly increases the amount of compute but utilizes the GPU much better
+      3. Add padding. This happens even if no pad is originally needed. It is useful
+         for working around the gaps in texture annotation between primary functions
+         and the limited support of textures in schedules. Later this pad is executed
+         separately and produces a texture
+      4. 5d convolution compute, accumulating into the requested accumulator type
+      5. Cast to the original output data type
+      6. For the 4d convolution case: convert the output from 5d back to 4d
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    convert_from4d = False
+    if len(Input.shape) == 4:
+        batch, in_channels, in_height, in_width = Input.shape
+        out_channles, in_filter_channels, kernel_h, kernel_w = Filter.shape
+
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4)
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channles, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_channel_chunks, in_height, in_width, in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder")
+            kshape = (out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            Input = pack_input(
+                Input,
+                "NCHW",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+            Filter = pack_filter(
+                Filter,
+                "OIHW",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+
+    else:
+        batch, in_channel_chunks, in_height, in_width, in_channel_block = Input.shape
+        out_channel_chunks, in_filter_channels, kernel_h, kernel_w, out_channel_block = Filter.shape
+
+    out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions(
+        in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w
+    )
+
+    temp = add_pad(
+        Input,
+        "NCHW",
+        out_height_orig,
+        out_width_orig,
+        kernel_h,
+        kernel_w,
+        dilation_h,
+        dilation_w,
+        padding,
+        stride_h,
+        stride_w,
+    )
+
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    conv = te.compute(
+        (batch, out_channel_chunks, out_height, out_width, out_channel_block),
+        lambda nn, ffc, yy, xx, ffb: te.sum(
+            (
+                temp[
+                    nn,
+                    ffc // in_filter_channels,
+                    yy * stride_h + ry * dilation_h,
+                    xx * stride_w + rx * dilation_w,
+                    ffb,
+                ]
+                * Filter[ffc // in_filter_channels, ffc % in_filter_channels, ry, rx, ffb]
+            ).astype(args["accumulator"]),
+            axis=[ry, rx],
+        ),
+        tag="depthwise_conv2d_nchwc_kcrsk",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, fc, y, x, fb: conv[n, fc, y, x, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_channels, out_height_orig, out_width_orig),
+            lambda n, c, y, x: dummy_cast[n, c // out_channel_block, y, x, c % out_channel_block],
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+    else:
+        return te.compute(
+            (batch, out_channel_chunks, out_height_orig, out_width_orig, out_channel_block),
+            lambda n, ffc, y, x, ffb: conv[n, ffc, y, x, ffb].astype(out_dtype),
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+
+
+def schedule_depthwise_conv2d_NCHWc_KCRSk(cfg, s, output):
+    """
+    Schedule optimized for batch size = 1
+
+    Algo:
+    1. Split the output axes into three parts: global work size, vthread, local work
+       size. The tuning limitations include heuristics from some tuned networks that
+       narrow the search space and avoid spending time on useless configurations.
+    2. For depthwise convolution it is better to inline the pad into the conv2d
+       compute; the divergence in the OpenCL kernel is not as significant as for
+       regular conv2d.
+    3. For 5d convolution we schedule the latest op, binding its 5d axes and
+       vectorizing for textures.
+       For a 4d tensor we do the same for the latest blocked stage, i.e. the data
+       type conversion.
+    4. In the 4d conv case we need to schedule the postops as well.
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    ##### space definition begin #####
+    n, fc, y, x, fb = s[conv].op.axis
+    ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_fc", fc, num_outputs=3)
+    cfg.define_split("tile_y", y, num_outputs=3)
+    cfg.define_split("tile_x", x, num_outputs=3)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    cfg.define_knob("unroll_explicit", [0, 1])
+    ##### space definition end #####
+
+    pad_data, kernel = s[conv].op.input_tensors
+    if (
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
+    ):
+        # manage scheduling of datacopy
+        pad_data, kernel = s[conv].op.input_tensors
+        pack_data = pad_data.op.input_tensors[0]
+        bind_data_copy(s[pack_data])
+        bind_data_copy(s[kernel])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # create cache stage
+    AT = s.cache_read(pad_data, "global.texture", [conv])
+    WT = s.cache_read(kernel, "global.texture-weight", [conv])
+    bind_data_copy(s[AT])
+    bind_data_copy(s[WT])
+
+    # tile and bind spatial axes
+    n, fc, y, x, fb = s[latest_blocked].op.axis
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    bf = s[latest_blocked].fuse(n, bf)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, fc, y, x, fb = s[conv].op.axis
+
+    ry, rx = s[conv].op.reduce_axis
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+
+    s[conv].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb)
+    s[conv].vectorize(fb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OCC, OH, OW, OCB = get_const_tuple(latest_blocked.shape)
+    _, _, KH, KW, ICB = get_const_tuple(kernel.shape)
+    KHKW = KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW * ICB)
diff --git a/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
new file mode 100644
index 0000000000..b8a978d3c2
--- /dev/null
+++ b/python/tvm/topi/adreno/depthwise_conv2d_nhwc.py
@@ -0,0 +1,311 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""depthwise_conv2d_nhwc(c) schedule on Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from tvm import autotvm
+
+from ..utils import get_const_tuple, traverse_inline
+from .utils import (
+    split_to_chunks,
+    pack_input,
+    pack_filter,
+    expand_spatial_dimensions,
+    add_pad,
+    bind_data_copy,
+    get_texture_storage,
+)
+
+
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc.image2d")
+def depthwise_conv2d_nhwc(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute depthwise_conv2d with NHWC layout"""
+    args = {"shared": False, "accumulator": "float16"}
+    return compute_depthwise_conv2d_NHWC_HWOI(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_compute("depthwise_conv2d_nhwc_acc32.image2d")
+def depthwise_conv2d_nhwc_acc32(cfg, data, kernel, strides, padding, dilation, out_dtype="float16"):
+    """Compute depthwise_conv2d with NHWC layout"""
+    args = {"shared": False, "accumulator": "float32"}
+    return compute_depthwise_conv2d_NHWC_HWOI(
+        data, kernel, strides, padding, dilation, out_dtype, args=args
+    )
+
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc.image2d")
+def schedule_depthwise_conv2d_nhwc(cfg, outs):
+    return schedule_depthwise_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc16")
+
+
+@autotvm.register_topi_schedule("depthwise_conv2d_nhwc_acc32.image2d")
+def schedule_depthwise_conv2d_nhwc_acc32(cfg, outs):
+    return schedule_depthwise_conv2d_nhwc_impl(cfg, outs, tag="cast_from_acc32")
+
+
+def schedule_depthwise_conv2d_nhwc_impl(cfg, outs, tag):
+    """Create the schedule for depthwise conv2d_nchw4c_ohwi4o"""
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _callback(op):
+        if op.tag == tag:
+            schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, op.output(0))
+
+    traverse_inline(s, outs[0].op, _callback)
+    return s
+
+
+def compute_depthwise_conv2d_NHWC_HWOI(Input, Filter, stride, padding, dilation, out_dtype, args):
+    """
+    Depthwise convolution operator in NHWCc layout.
+    Algo:
+      1. Convert into blocked format if the original tensor is 4d.
+         In the AutoTVM case we replace the conversion by plain placeholder tensors,
+         since no such conversion exists for a truly blocked convolution and there is
+         no sense in including it in tuning
+      2. Expand the spatial dimensions so that width and height are divisible by a
+         factor of 4. This slightly increases the amount of compute but allows much
+         better GPU utilization
+      3. Add paddings. This happens even if no pad is originally needed. It is useful
+         for working around the gaps in texture annotation between Primary Functions
+         and the limited support of textures in schedules. Later this pad will be
+         executed separately and will produce a texture
+      4. 5d convolution compute accumulating into args["accumulator"]
+      5. Cast to the original output data type
+      6. In the 4d convolution case: convert the output from 5d back to 4d
+    """
+    if out_dtype is None:
+        out_dtype = Input.dtype
+    assert isinstance(stride, int) or len(stride) == 2
+    assert isinstance(dilation, int) or len(dilation) == 2
+
+    if isinstance(stride, int):
+        stride_h = stride_w = stride
+    else:
+        stride_h, stride_w = stride
+
+    if isinstance(dilation, int):
+        dilation_h = dilation_w = dilation
+    else:
+        dilation_h, dilation_w = dilation
+
+    convert_from4d = False
+    if len(Input.shape) == 4:
+        batch, in_height, in_width, in_channels = Input.shape
+        kernel_h, kernel_w, out_channels, in_filter_channels = Filter.shape
+
+        in_channel_chunks, in_channel_block, in_channel_tail = split_to_chunks(in_channels, 4)
+        out_channel_chunks, out_channel_block, out_channel_tail = split_to_chunks(out_channels, 4)
+
+        if autotvm.GLOBAL_SCOPE.in_tuning:
+            dshape = (batch, in_height, in_width, in_channel_chunks, in_channel_block)
+            Input = tvm.te.placeholder(dshape, Input.dtype, name="data_placeholder")
+            kshape = (kernel_h, kernel_w, out_channel_block, in_filter_channels, out_channel_chunks)
+            Filter = tvm.te.placeholder(kshape, Filter.dtype, name="kernel_placeholder")
+        else:
+            convert_from4d = True
+            Input = pack_input(
+                Input,
+                "NHWC",
+                batch,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                in_height,
+                in_width,
+            )
+            Filter = pack_filter(
+                Filter,
+                "HWOI",
+                out_channel_chunks,
+                out_channel_block,
+                out_channel_tail,
+                in_filter_channels,
+                in_channel_chunks,
+                in_channel_block,
+                in_channel_tail,
+                kernel_h,
+                kernel_w,
+            )
+
+    else:
+        batch, in_height, in_width, in_channel_chunks, in_channel_block = Input.shape
+        kernel_h, kernel_w, out_channel_chunks, in_filter_channels, out_channel_block = Filter.shape
+
+    out_height_orig, out_height, out_width_orig, out_width = expand_spatial_dimensions(
+        in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w
+    )
+
+    temp = add_pad(
+        Input,
+        "NHWC",
+        out_height_orig,
+        out_width_orig,
+        kernel_h,
+        kernel_w,
+        dilation_h,
+        dilation_w,
+        padding,
+        stride_h,
+        stride_w,
+    )
+
+    ry = te.reduce_axis((0, kernel_h), name="ry")
+    rx = te.reduce_axis((0, kernel_w), name="rx")
+    conv = te.compute(
+        (batch, out_height, out_width, out_channel_chunks, out_channel_block),
+        lambda nn, yy, xx, ffc, ffb: te.sum(
+            (
+                temp[nn, yy * stride_h + ry * dilation_h, xx * stride_w + rx * dilation_w, ffc, ffb]
+                * Filter[ry, rx, ffc, 0, ffb]
+            ).astype(args["accumulator"]),
+            axis=[ry, rx],
+        ),
+        tag="depthwise_conv2d_nhwc",
+    )
+
+    if convert_from4d and not autotvm.GLOBAL_SCOPE.in_tuning:
+        dummy_cast = te.compute(
+            (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
+            lambda n, y, x, fc, fb: conv[n, y, x, fc, fb].astype(out_dtype),
+            tag="dummy_cast",
+        )
+        return te.compute(
+            (batch, out_height_orig, out_width_orig, out_channels),
+            lambda n, y, x, c: dummy_cast[n, y, x, c // out_channel_block, c % out_channel_block],
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+    else:
+        return te.compute(
+            (batch, out_height_orig, out_width_orig, out_channel_chunks, out_channel_block),
+            lambda n, y, x, ffc, ffb: conv[n, y, x, ffc, ffb].astype(out_dtype),
+            tag="cast_from_acc" + args["accumulator"][-2:],
+        )
+
+
+def schedule_depthwise_conv2d_NHWC_HWOI(cfg, s, output):
+    """
+    Schedule optimized for batch size = 1
+
+    Algo:
+    1. Split the output axes into three parts: global work size, vthread, local work
+       size. The tuning limitations include heuristics from some tuned networks that
+       narrow the search space and avoid spending time on useless configurations.
+    2. In the 4d convolution case, schedule the copying of the input (and filter)
+       into 5d tensors.
+    3. For depthwise convolution it is better to inline the pad into the conv2d
+       compute; the divergence in the OpenCL kernel is not as significant as for
+       regular conv2d.
+    4. For 5d convolution we schedule the latest op, binding its 5d axes and
+       vectorizing for textures.
+       For a 4d tensor we do the same for the latest blocked stage, i.e. the data
+       type conversion.
+    5. In the 4d conv case we need to schedule the postops as well.
+    """
+    latest = s.outputs[0].output(0)
+    if len(latest.op.axis) == 4:
+        latest_blocked = dummy = output.op.input_tensors[0]
+        conv = dummy.op.input_tensors[0]
+    else:
+        conv = output.op.input_tensors[0]
+        latest_blocked = latest
+
+    ##### space definition begin #####
+    n, y, x, fc, fb = s[conv].op.axis
+    ry, rx = s[conv].op.reduce_axis
+    cfg.define_split("tile_fc", fc, num_outputs=3)
+    cfg.define_split("tile_y", y, num_outputs=3)
+    cfg.define_split("tile_x", x, num_outputs=3)
+    cfg.define_split("tile_ry", ry, num_outputs=2)
+    cfg.define_split("tile_rx", rx, num_outputs=2)
+    cfg.define_knob("auto_unroll_max_step", [0, 512, 1500])
+    cfg.define_knob("unroll_explicit", [0, 1])
+    ##### space definition end #####
+
+    pad_data, kernel = s[conv].op.input_tensors
+    if (
+        isinstance(kernel.op, tvm.te.ComputeOp) and "filter_pack" in kernel.op.tag
+    ):
+        # manage scheduling of datacopy
+        pad_data, kernel = s[conv].op.input_tensors
+        pack_data = pad_data.op.input_tensors[0]
+        bind_data_copy(s[pack_data])
+        bind_data_copy(s[kernel])
+
+    pad_data, kernel = s[conv].op.input_tensors
+
+    s[pad_data].compute_inline()
+
+    s[conv].set_scope("local")
+    if latest_blocked == latest and output != latest:
+        s[output].compute_inline()
+
+    # create cache stage
+    AT = s.cache_read(pad_data, get_texture_storage(pad_data.shape), [conv])
+    WT = s.cache_read(kernel, get_texture_storage(kernel.shape), [conv])
+    bind_data_copy(s[AT])
+    bind_data_copy(s[WT])
+
+    # tile and bind spatial axes
+    n, y, x, fc, fb = s[latest_blocked].op.axis
+    kernel_scope, n = s[latest_blocked].split(n, nparts=1)
+
+    bf, vf, tf = cfg["tile_fc"].apply(s, latest_blocked, fc)
+    by, vy, ty = cfg["tile_y"].apply(s, latest_blocked, y)
+    bx, vx, tx = cfg["tile_x"].apply(s, latest_blocked, x)
+
+    by = s[latest_blocked].fuse(n, by)
+    s[latest_blocked].bind(bf, te.thread_axis("blockIdx.z"))
+    s[latest_blocked].bind(by, te.thread_axis("blockIdx.y"))
+    s[latest_blocked].bind(bx, te.thread_axis("blockIdx.x"))
+    s[latest_blocked].bind(vf, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vy, te.thread_axis("vthread"))
+    s[latest_blocked].bind(vx, te.thread_axis("vthread"))
+    s[latest_blocked].bind(tf, te.thread_axis("threadIdx.z"))
+    s[latest_blocked].bind(ty, te.thread_axis("threadIdx.y"))
+    s[latest_blocked].bind(tx, te.thread_axis("threadIdx.x"))
+    s[latest_blocked].reorder(bf, by, bx, vf, vy, vx, tf, ty, tx, fb)
+    s[latest_blocked].vectorize(fb)
+
+    s[conv].compute_at(s[latest_blocked], tx)
+
+    # tile reduction axes
+    n, y, x, fc, fb = s[conv].op.axis
+
+    ry, rx = s[conv].op.reduce_axis
+    ryo, ryi = cfg["tile_ry"].apply(s, conv, ry)
+    rxo, rxi = cfg["tile_rx"].apply(s, conv, rx)
+
+    s[conv].reorder(ryo, rxo, ryi, rxi, n, fc, y, x, fb)
+    s[conv].vectorize(fb)
+
+    # unroll
+    s[latest_blocked].pragma(kernel_scope, "auto_unroll_max_step", cfg["auto_unroll_max_step"].val)
+    s[latest_blocked].pragma(kernel_scope, "unroll_explicit", cfg["unroll_explicit"].val)
+    if latest_blocked != latest:
+        s[latest].compute_root()
+        bind_data_copy(s[latest], 1)
+        if latest != output:
+            s[output].compute_inline()
+
+    N, OH, OW, OCC, OCB = get_const_tuple(latest_blocked.shape)
+    KH, KW, _, _, _ = get_const_tuple(kernel.shape)
+    KHKW = KH * KW
+
+    if isinstance(N, int):
+        cfg.add_flop(2 * N * OH * OW * OCC * OCB * KHKW)
diff --git a/python/tvm/topi/adreno/pooling.py b/python/tvm/topi/adreno/pooling.py
new file mode 100644
index 0000000000..49f103c04a
--- /dev/null
+++ b/python/tvm/topi/adreno/pooling.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""pooling schedules for Qualcomm Adreno GPU"""
+import tvm
+from tvm import te
+from .. import tag
+
+
+def schedule_pool(outs, layout):
+    """Schedule for various pooling operators.
+
+    Parameters
+    ----------
+    outs: Array of Tensor
+        The computation graph description of pool
+        in the format of an array of tensors.
+
+    layout: str
+        Data layout.
+
+    Returns
+    -------
+    s: Schedule
+        The computation schedule for pool.
+    """
+    outs = [outs] if isinstance(outs, te.tensor.Tensor) else outs
+    s = te.create_schedule([x.op for x in outs])
+
+    def _schedule(PaddedInput, Pool):
+        if isinstance(PaddedInput.op, tvm.te.ComputeOp):
+            s[PaddedInput].compute_inline()
+        num_thread = tvm.target.Target.current(allow_none=False).max_num_threads
+        num_thread = int(num_thread * 2)
+        if Pool.op in s.outputs:
+            Out = Pool
+            OL = s.cache_write(Pool, "local")
+        else:
+            Out = outs[0].op.output(0)
+            s[Pool].set_scope("local")
+        fused = s[Out].fuse(*s[Out].op.axis[:-1])
+        bx, tx = s[Out].split(fused, factor=num_thread)
+        s[Out].bind(bx, te.thread_axis("blockIdx.x"))
+        s[Out].bind(tx, te.thread_axis("threadIdx.x"))
+        s[Out].vectorize(s[Out].op.axis[-1])
+        if Pool.op in s.outputs:
+            s[OL].compute_at(s[Out], tx)
+            s[OL].vectorize(s[OL].op.axis[-1])
+        else:
+            s[Pool].compute_at(s[Out], tx)
+            s[Pool].vectorize(s[Pool].op.axis[-1])
+
+    scheduled_ops = []
+
+    def traverse(OP):
+        """Internal traverse function"""
+        # inline all one-to-one-mapping operators except the last stage (output)
+        if tag.is_broadcast(OP.tag):
+            if OP not in s.outputs:
+                s[OP].compute_inline()
+            for tensor in OP.input_tensors:
+                if isinstance(tensor.op, te.tensor.ComputeOp) and tensor.op not in scheduled_ops:
+                    traverse(tensor.op)
+        # schedule pool
+        elif OP.tag.startswith("pool"):
+            PaddedInput = OP.input_tensors[0]
+            Pool = OP.output(0)
+            _schedule(PaddedInput, Pool)
+        else:
+            raise RuntimeError("Unsupported operator: %s" % OP.tag)
+
+        scheduled_ops.append(OP)
+
+    traverse(outs[0].op)
+    return s
diff --git a/python/tvm/topi/adreno/utils.py b/python/tvm/topi/adreno/utils.py
new file mode 100644
index 0000000000..727741c11f
--- /dev/null
+++ b/python/tvm/topi/adreno/utils.py
@@ -0,0 +1,549 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name,unused-variable,unused-argument,no-else-return
+"""util functions to be reused in different compute/schedule on Qualcomm Adreno GPU"""
+
+import tvm
+import numpy
+from tvm import te
+from tvm.topi.utils import simplify
+from tvm.topi import nn
+from ..utils import get_const_tuple
+
+
+def get_div(value, start):
+    """Returns the maximum divider for `value` starting from `start` value"""
+    div = 1
+    for d in range(start, 0, -1):
+        if (value % d) == 0:
+            div = d
+            break
+    return div
+
+
+def split_to_chunks(extent, block):
+    """
+    Splits the trip count into chunks of the given block size and returns the tail;
+    the chunks and block together cover or overlap the original extent
+
+    If extent is divisible by block:
+        extent = chunks * block
+    else:
+        extent = (chunks - 1) * block + tail
+
+    Parameters
+    ----------
+    extent: int
+        trip count of the original compute
+
+    block: int
+        size of the block
+
+    Returns
+    ----------
+    out: tuple of (chunks, block, tail)
+         chunks = ceildiv(extent, block)
+         tail = number of original elements in the last chunk
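+
+    Examples
+    ----------
+        split_to_chunks(10, 4) == (3, 4, 2)  # 10 = (3 - 1) * 4 + 2
+        split_to_chunks(8, 4) == (2, 4, 4)   # 8 = 2 * 4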
+    """
+    tail = extent % block
+    chunks = extent // block
+    if tail == 0:
+        tail = block
+    else:
+        chunks += 1
+    return chunks, block, tail
+
+
+def pack_input(Input, layout, batch, chunks, block, original_tail, in_height, in_width):
+    """
+    Adds compute stages for packing of the data at runtime. Extends the channel
+    dimension to be divisible by a factor of 4
+
+    This function should be substituted by Schedule.transform_layout() in the future: see
+    https://github.com/apache/tvm-rfcs/blob/main/rfcs/0039-buffer-physical-layout.md
+
+    Parameters
+    ----------
+    Input: tvm.te.Tensor
+        Input tensor to be repacked at runtime
+
+    layout: string
+        Layout of the original 4d tensor
+        NCHW or NHWC are acceptable
+
+    batch: int
+        Batch size
+
+    chunks: int
+        Number of channel chunks in the final tensor
+
+    block: int
+        Size of the channel block
+
+    original_tail: int
+        Number of original elements in the last chunk, i.e. the difference between
+        the original number of channels and the blocked one
+        If original_tail != block:
+          original_channels = (chunks - 1) * block + original_tail
+        else
+          original_channels = chunks * block
+
+    in_height: int
+        Height of the feature map
+
+    in_width: int
+        Width of the feature map
+    """
+
+    pad_value = tvm.tir.const(0, Input.dtype)
+
+    def _reorder_data_nchw(*indices):
+        condition = []
+        condition.append(indices[1] == chunks - 1)
+        condition.append(indices[4] >= original_tail)
+        condition = tvm.tir.all(*condition)
+        return tvm.tir.if_then_else(
+            condition,
+            pad_value,
+            Input[indices[0], indices[1] * block + indices[4], indices[2], indices[3]],
+        )
+
+    def _reorder_data_nhwc(*indices):
+        condition = []
+        condition.append(indices[3] == chunks - 1)
+        condition.append(indices[4] >= original_tail)
+        condition = tvm.tir.all(*condition)
+        return tvm.tir.if_then_else(
+            condition,
+            pad_value,
+            Input[indices[0], indices[1], indices[2], indices[3] * block + indices[4]],
+        )
+
+    # compute:
+    if layout == "NCHW":
+        reordered_data = te.compute(
+            [batch, chunks, in_height, in_width, block],
+            _reorder_data_nchw,
+            name="input_pack",
+            tag="input_pack",
+        )
+    elif layout == "NHWC":
+        reordered_data = te.compute(
+            [batch, in_height, in_width, chunks, block],
+            _reorder_data_nhwc,
+            name="input_pack",
+            tag="input_pack",
+        )
+    else:
+        assert False, "Adreno util function pack_input does not accept unknown layout"
+    return reordered_data
+
+
+def pack_filter(
+    Filter,
+    layout,
+    out_chunks,
+    out_block,
+    out_original_tail,
+    in_filter_channels,
+    in_chunks,
+    in_block,
+    in_original_tail,
+    kernel_h,
+    kernel_w,
+):
+    """
+    Adds compute stages for packing of the filter at runtime. Extends the channel
+    dimensions to be divisible by a factor of 4
+
+    This function should be substituted by Schedule.transform_layout() in the future: see
+    https://github.com/apache/tvm-rfcs/blob/main/rfcs/0039-buffer-physical-layout.md
+
+    Parameters
+    ----------
+    Filter: tvm.te.Tensor
+        Filter tensor to be repacked in runtime
+
+    layout: string
+        Layout of the original 4d tensor
+        OIHW, HWIO and HWOI are acceptable
+
+    out_chunks: int
+        Number of chunks for filters
+
+    out_block: int
+        Size of the block for output channels
+
+    out_original_tail: int
+        Original size of the last chunk of output filters
+
+    in_filter_channels: int
+        Number of filter channels. Might differ from the number of input data
+        channels due to the group/depthwise nature of the convolution
+
+    in_chunks: int
+        Number of input data channel chunks
+
+    in_block: int
+        Size of the block for input data channels
+
+    in_original_tail: int
+        Original size of the last chunk of input data channels
+
+    kernel_h: int
+        Height of the conv2d kernel
+
+    kernel_w: int
+        Width of the conv2d kernel
+    """
+    pad_value = tvm.tir.const(0, Filter.dtype)
+
+    def _reorder_weights_depthwise_oihw(*indices):
+        conditionA = []
+        conditionA.append(indices[0] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
+        conditionAT = tvm.tir.all(*conditionA)
+
+        return tvm.tir.if_then_else(
+            conditionAT,
+            pad_value,
+            Filter[indices[0] * out_block + indices[4], indices[1], indices[2], indices[3]],
+        )
+
+    def _reorder_weights_depthwise_hwoi(*indices):
+        conditionA = []
+        conditionA.append(indices[2] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
+        conditionAT = tvm.tir.all(*conditionA)
+
+        return tvm.tir.if_then_else(
+            conditionAT,
+            pad_value,
+            Filter[indices[0], indices[1], indices[2] * out_block + indices[4], indices[3]],
+        )
+
+    def _reorder_weights_oihw(*indices):
+        conditionA = []
+        conditionA.append(indices[0] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
+        conditionAT = tvm.tir.all(*conditionA)
+
+        conditionO = []
+        conditionO.append(conditionAT)
+        conditionO.append(indices[1] >= in_chunks * in_block + in_original_tail)
+        conditionOT = tvm.tir.any(*conditionO)
+        return tvm.tir.if_then_else(
+            conditionOT,
+            pad_value,
+            Filter[indices[0] * out_block + indices[4], indices[1], indices[2], indices[3]],
+        )
+
+    def _reorder_weights_hwio(*indices):
+        conditionA = []
+        conditionA.append(indices[3] == out_chunks - 1)
+        conditionA.append(indices[4] >= out_original_tail)
+        conditionAT = tvm.tir.all(*conditionA)
+
+        conditionO = []
+        conditionO.append(conditionAT)
+        conditionO.append(indices[2] >= in_chunks * in_block + in_original_tail)
+        conditionOT = tvm.tir.any(*conditionO)
+        return tvm.tir.if_then_else(
+            conditionOT,
+            pad_value,
+            Filter[indices[0], indices[1], indices[2], indices[3] * out_block + indices[4]],
+        )
+
+    if in_filter_channels == 1:
+        if layout == "OIHW":
+            reordered_filter = te.compute(
+                [out_chunks, in_filter_channels, kernel_h, kernel_w, out_block],
+                _reorder_weights_depthwise_oihw,
+                name="filter_pack",
+                tag="filter_pack",
+            )
+        elif layout == "HWOI":
+            reordered_filter = te.compute(
+                [kernel_h, kernel_w, out_chunks, in_filter_channels, out_block],
+                _reorder_weights_depthwise_hwoi,
+                name="filter_pack",
+                tag="filter_pack",
+            )
+        else:
+            assert False, "Adreno util function def pack_filter does not accept unknown layout"
+    else:
+        if layout == "OIHW":
+            reordered_filter = te.compute(
+                [out_chunks, in_filter_channels, kernel_h, kernel_w, out_block],
+                _reorder_weights_oihw,
+                name="filter_pack",
+                tag="filter_pack",
+            )
+        elif layout == "HWIO":
+            reordered_filter = te.compute(
+                [kernel_h, kernel_w, in_filter_channels, out_chunks, out_block],
+                _reorder_weights_hwio,
+                name="filter_pack",
+                tag="filter_pack",
+            )
+        else:
+            assert False, "Adreno util function def pack_filter does not accept unknown layout"
+    return reordered_filter
+
+
+def expand_spatial_dimensions(
+    in_height, in_width, kernel_h, kernel_w, dilation_h, dilation_w, padding, stride_h, stride_w
+):
+    """
+    Expands the spatial dimensions to be divisible by a factor of 4. This allows much
+    better parallel computation on the GPU. The drawback of this solution is a number
+    of useless computations. In practice the speed-up from parallelism significantly
+    outweighs the slowdown of the extra compute, so overall this is a useful approach,
+    at least for GPUs
+
+    Parameters
+    ----------
+    in_height: int
+        Height of the feature map
+
+    in_width: int
+        Width of the feature map
+
+    kernel_h: int
+        Height of the conv2d kernel
+
+    kernel_w: int
+        Width of the conv2d kernel
+
+    dilation_h: int
+        Vertical dilation of the conv2d kernel
+
+    dilation_w: int
+        Horizontal dilation of the conv2d kernel
+
+    padding: tuple or list
+        Conv2d paddings
+
+    stride_h: int
+        Vertical stride of the conv2d kernel
+
+    stride_w: int
+        Horizontal stride of the conv2d kernel
+    """
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+
+    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
+
+    out_height_orig = out_height = simplify(
+        (in_height - dilated_kernel_h + pad_top + pad_down) // stride_h + 1
+    )
+    out_width_orig = out_width = simplify(
+        (in_width - dilated_kernel_w + pad_left + pad_right) // stride_w + 1
+    )
+
+    # Can the output shape be divided by 2 or even 4?
+    # If it cannot, we need to extend it to help further splits.
+    # Theoretically there should be additional padding of the inputs, but it will be
+    # optimized out by cache_read InferBound. We must pad here to produce exactly the
+    # tensor required for the calculation of the original output size, not more! Otherwise
+    # an intermediate tensor might be allocated with smaller sizes while the compute
+    # tries to fill the expanded one - a data discrepancy as a result.
+    # In the case of textures it is not a problem to provide a texture of smaller size
+    # because:
+    # 1. It does not matter which values are used for the extra calculations - they are
+    #    required only to better fit the GPU work groups
+    # 2. When we request a pixel out of bounds, the texture handles this correctly. As
+    #    mentioned above, the value itself is not important
+    if out_height % 2 != 0:
+        out_height += 1
+    if out_width % 2 != 0:
+        out_width += 1
+
+    if out_height % 4 != 0:
+        out_height += 2
+    if out_width % 4 != 0:
+        out_width += 2
+    return out_height_orig, out_height, out_width_orig, out_width
+
+
+def add_pad(
+    data,
+    layout,
+    out_height,
+    out_width,
+    kernel_h,
+    kernel_w,
+    dilation_h,
+    dilation_w,
+    padding,
+    stride_h,
+    stride_w,
+):
+    """Computes required padding values by the parameters of conv2d and adds
+        compute for extending of original tensor
+
+    Parameters
+    ----------
+    data: tvm.te.Tensor
+        4d or 5d tensor; the layout of the spatial dimensions is defined by the
+        separate layout argument
+
+    layout: string
+        Layout of the original 4d tensor
+
+    out_height: int
+        Height of the output feature map
+
+    out_width: int
+        Width of the output feature map
+
+    kernel_h: int
+        Height of the conv2d kernel
+
+    kernel_w: int
+        Width of the conv2d kernel
+
+    dilation_h: int
+        Height dilation value from conv2d attributes
+
+    dilation_w: int
+        Width dilation value from conv2d attributes
+
+    padding: list / tuple of n ints
+        Padding values from conv2d attributes
+
+    stride_h: int
+        Height stride value from conv2d attributes
+
+    stride_w: int
+        Width stride value from conv2d attributes
+
+    Returns
+    -------
+    Output : tvm.te.Tensor
+        n-D, the same layout as Input.
+    """
+    dilated_kernel_h = (kernel_h - 1) * dilation_h + 1
+    dilated_kernel_w = (kernel_w - 1) * dilation_w + 1
+    pad_top, pad_left, pad_down, pad_right = nn.get_pad_tuple(
+        padding, (dilated_kernel_h, dilated_kernel_w)
+    )
+
+    # compute graph
+    if layout == "NCHW":
+        y_axis = 2
+        x_axis = 3
+        if len(data.shape) == 4:
+            _, _, in_height, in_width = data.shape
+        else:
+            _, _, in_height, in_width, _ = data.shape
+    elif layout == "NHWC":
+        y_axis = 1
+        x_axis = 2
+        if len(data.shape) == 4:
+            _, in_height, in_width, _ = data.shape
+        else:
+            _, in_height, in_width, _, _ = data.shape
+    else:
+        assert False, "not supported layout in adreno util add_pad"
+    pad_before = [0, 0, 0, 0, 0]
+    pad_after = [0, 0, 0, 0, 0]
+    pad_before[y_axis] = pad_top
+    pad_before[x_axis] = pad_left
+    pad_after[y_axis] = pad_down
+    pad_after[x_axis] = pad_right
+
+    # calculation of the actually used input size:
+    input_latest_w = (out_width - 1) * stride_w + (kernel_w - 1) * dilation_w + 1
+    input_latest_h = (out_height - 1) * stride_h + (kernel_h - 1) * dilation_h + 1
+    if input_latest_w < in_width + pad_before[x_axis] + pad_after[x_axis]:
+        pad_after[x_axis] -= in_width + pad_before[x_axis] + pad_after[x_axis] - input_latest_w
+    if input_latest_h < in_height + pad_before[y_axis] + pad_after[y_axis]:
+        pad_after[y_axis] -= in_height + pad_before[y_axis] + pad_after[y_axis] - input_latest_h
+    return nn.pad(data, pad_before, pad_after, name="pad_temp")
+
+
+def bind_data_copy(stage, axis_to_vectorize=None):
+    """
+    Schedules the eltwise stages like copying of data or postops
+
+    Parameters
+    ----------
+    stage: tvm.te.schedule.Stage
+
+    axis_to_vectorize: int, optional
+        Causes the given axis to be split, moves the inner part to the end of the
+        schedule and enables vectorization along this axis
+        If the parameter is not provided, the schedule is vectorized when the
+        innermost dim equals 4 (the size of the vector in a texture)
+    """
+    shape = get_const_tuple(stage.op.output(0).shape)
+    if axis_to_vectorize and len(shape) == 4 and shape[axis_to_vectorize] % 4 == 0:
+        ax0, ax1, ax2, ax3 = stage.op.axis
+        if axis_to_vectorize == 1:
+            oax1, iax1 = stage.split(ax1, factor=4)
+            stage.reorder(ax0, oax1, ax2, ax3, iax1)
+            stage.vectorize(iax1)
+            fused = stage.fuse(ax0, oax1, ax2, ax3)
+        elif axis_to_vectorize == 3:
+            oax3, iax3 = stage.split(ax3, factor=4)
+            stage.reorder(ax0, ax1, ax2, oax3, iax3)
+            stage.vectorize(iax3)
+            fused = stage.fuse(ax0, ax1, ax2, oax3)
+
+        ftc = numpy.prod(shape) / 4
+        div = get_div(ftc, 128)
+        block, thread = stage.split(fused, factor=div)
+
+        stage.bind(block, te.thread_axis("blockIdx.z"))
+        stage.bind(thread, te.thread_axis("threadIdx.z"))
+    else:
+        axes = stage.op.axis
+        fused = stage.fuse(*axes[:-1])
+        if shape[-1] <= 32:
+            ftc = numpy.prod(shape[:-1])
+            div = get_div(ftc, 64)
+            block, thread = stage.split(fused, factor=div)
+            stage.bind(block, te.thread_axis("blockIdx.x"))
+            stage.bind(thread, te.thread_axis("threadIdx.x"))
+            if shape[-1] == 4:
+                stage.vectorize(axes[-1])
+        else:
+            stage.bind(fused, te.thread_axis("blockIdx.x"))
+            stage.bind(*axes[-1:], te.thread_axis("threadIdx.x"))
+
+
+def get_texture_storage(shape):
+    """
+    Returns the texture layout acceptable for the shape
+
+    Parameters
+    ----------
+    shape: array
+        Shape of the tensor to be packed to texture
+    """
+    # A limitation of Qualcomm devices. It should be determined for each device
+    # individually, but until we have access to the remote device during compilation,
+    # we have to define it uniformly for all target devices
+    limit = tvm.target.Target.current().attrs["texture_spatial_limit"]
+
+    if shape[0] * shape[1] * shape[2] < limit and shape[3] < limit:
+        return "global.texture"
+    elif shape[0] * shape[1] < limit and shape[2] * shape[3] < limit:
+        return "global.texture-nhwc"
+    else:
+        return "global.texture-weight"
diff --git a/src/runtime/opencl/opencl_common.h b/src/runtime/opencl/opencl_common.h
index c2905b4327..6877240c08 100644
--- a/src/runtime/opencl/opencl_common.h
+++ b/src/runtime/opencl/opencl_common.h
@@ -342,7 +342,12 @@ class OpenCLThreadEntry {
   static OpenCLThreadEntry* ThreadLocal();
 };
 
-/*! \brief OpenCL runtime buffer structure with tracked memory layout */
+/*! \brief OpenCL runtime buffer structure with tracked memory layout
+    TODO(tvm-team): Uncouple use of storage scope and data layout by using the transform_layout
+    schedule primitive to express the desired texture layout. This will require supporting Nd
+    indices in BufferLoad and BufferStore in CodegenOpenCL, and ensuring Nd allocations for
+    texture are correctly routed to the AllocateTexture packed function in the OpenCL DeviceAPI.
+*/
 struct BufferDescriptor {
   enum class MemoryLayout {
     /*! \brief One dimensional buffer in row-major layout*/
@@ -355,6 +360,10 @@ struct BufferDescriptor {
      *         e.g. image2d[height=O, width=IHW]
      */
     kImage2DWeight,
+    /*! \brief Two dimensional texture w/ height = axis[1]
+     *         e.g. image2d[height=NH, width=WC]
+     */
+    kImage2DNHWC,
   };
   BufferDescriptor() = default;
   explicit BufferDescriptor(Optional<String> scope) : layout(MemoryLayoutFromScope(scope)) {}
diff --git a/src/runtime/opencl/opencl_device_api.cc b/src/runtime/opencl/opencl_device_api.cc
index 80b95a6ebf..478ec181e8 100644
--- a/src/runtime/opencl/opencl_device_api.cc
+++ b/src/runtime/opencl/opencl_device_api.cc
@@ -72,6 +72,8 @@ cl::BufferDescriptor::MemoryLayout cl::BufferDescriptor::MemoryLayoutFromScope(
     return cl::BufferDescriptor::MemoryLayout::kImage2DActivation;
   } else if (mem_scope.value() == "global.texture-weight") {
     return cl::BufferDescriptor::MemoryLayout::kImage2DWeight;
+  } else if (mem_scope.value() == "global.texture-nhwc") {
+    return cl::BufferDescriptor::MemoryLayout::kImage2DNHWC;
   }
   LOG(FATAL) << "No memory layout defined for memory of scope: " << mem_scope.value();
   return cl::BufferDescriptor::MemoryLayout::kBuffer1D;
@@ -85,6 +87,8 @@ String cl::BufferDescriptor::ScopeFromMemoryLayout(cl::BufferDescriptor::MemoryL
       return "global.texture";
     case cl::BufferDescriptor::MemoryLayout::kImage2DWeight:
       return "global.texture-weight";
+    case cl::BufferDescriptor::MemoryLayout::kImage2DNHWC:
+      return "global.texture-nhwc";
   }
   LOG(FATAL) << "No scope corresponding to the provided memory layout: "
              << static_cast<int>(layout);
@@ -285,6 +289,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand
         break;
       case cl::BufferDescriptor::MemoryLayout::kImage2DActivation:
       case cl::BufferDescriptor::MemoryLayout::kImage2DWeight:
+      case cl::BufferDescriptor::MemoryLayout::kImage2DNHWC:
         auto image_info = GetImageInfo(from_desc, from);
         // TODO(csullivan): Support calculating row_pitch correctly in the case of reuse.
         // Note that when utilizing texture pools for memory reuse, the allocated image
@@ -306,6 +311,7 @@ void OpenCLWorkspace::CopyDataFromTo(DLTensor* from, DLTensor* to, TVMStreamHand
         break;
       case cl::BufferDescriptor::MemoryLayout::kImage2DActivation:
       case cl::BufferDescriptor::MemoryLayout::kImage2DWeight:
+      case cl::BufferDescriptor::MemoryLayout::kImage2DNHWC:
         auto image_info = GetImageInfo(to_desc, to);
         OPENCL_CALL(clEnqueueWriteImage(
             this->GetQueue(to->device), to_desc->buffer, CL_FALSE, image_info.origin,
diff --git a/src/runtime/texture.h b/src/runtime/texture.h
index 83725c00b8..5f43c8cee8 100644
--- a/src/runtime/texture.h
+++ b/src/runtime/texture.h
@@ -57,6 +57,12 @@ inline size_t DefaultTextureLayoutSeparator(size_t shape_rank,
     separator = shape_rank - 2;
   } else if (convention == "global.texture-weight") {
     separator = 1;
+  } else if (convention == "global.texture-nhwc") {
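+    // For "global.texture-nhwc" (image2d[height=NH, width=WC]) the first two
+    // axes form the image height for rank-4 shapes, the first axis for rank-3.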
+    if (shape_rank == 3) {
+      separator = 1;
+    } else {
+      separator = 2;
+    }
   } else {
     LOG(FATAL) << "Encountered unknown texture lowering convention: " << convention;
   }
diff --git a/src/runtime/thread_storage_scope.h b/src/runtime/thread_storage_scope.h
index 4122f9d079..bc9e2faa80 100644
--- a/src/runtime/thread_storage_scope.h
+++ b/src/runtime/thread_storage_scope.h
@@ -60,6 +60,8 @@ enum class StorageRank {
   kWMMAMatrixB = 5,
   /*! \brief wmma scope memory of accumulator */
   kWMMAAccumulator = 6,
+  /*! \brief global scope texture memory */
+  kTexture = 7,
 };
 
 /*!
@@ -109,6 +111,8 @@ struct StorageScope {
         return "wmma.matrix_b" + tag;
       case StorageRank::kWMMAAccumulator:
         return "wmma.accumulator" + tag;
+      case StorageRank::kTexture:
+        return "texture" + tag;
       default:
         LOG(FATAL) << "unknown storage scope";
         return "";
@@ -144,6 +148,9 @@ struct StorageScope {
     } else if (s.compare(0, 16, "wmma.accumulator") == 0) {
       r.rank = StorageRank::kWMMAAccumulator;
       r.tag = s.substr(16, std::string::npos);
+    } else if (s.compare(0, 7, "texture") == 0) {
+      r.rank = StorageRank::kTexture;
+      r.tag = s.substr(7, std::string::npos);
     } else {
       LOG(FATAL) << "unknown storage scope " << s;
     }
diff --git a/src/target/source/codegen_source_base.cc b/src/target/source/codegen_source_base.cc
index 5acb42071b..2353d2e6ba 100644
--- a/src/target/source/codegen_source_base.cc
+++ b/src/target/source/codegen_source_base.cc
@@ -22,6 +22,8 @@
  */
 #include "codegen_source_base.h"
 
+#include <algorithm>
+
 namespace tvm {
 namespace codegen {
 
@@ -73,6 +75,9 @@ std::string CodeGenSourceBase::AllocVarID(const tir::VarNode* v) {
   ICHECK(!var_idmap_.count(v)) << "Need input to be in SSA form dup " << v->name_hint;
   std::string key = v->name_hint;
   std::string vid = GetUniqueName(key);
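+  // Sanitize identifiers: TIR variable names may contain characters such as
+  // ':', '-' or '.' that are not valid in the generated C-like source.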
+  std::replace(vid.begin(), vid.end(), ':', '_');
+  std::replace(vid.begin(), vid.end(), '-', '_');
+  std::replace(vid.begin(), vid.end(), '.', '_');
   var_idmap_[v] = vid;
   return vid;
 }
diff --git a/src/target/target_kind.cc b/src/target/target_kind.cc
index 43bcfef105..1148013706 100644
--- a/src/target/target_kind.cc
+++ b/src/target/target_kind.cc
@@ -324,6 +324,7 @@ TVM_REGISTER_TARGET_KIND("opencl", kDLOpenCL)
     .add_attr_option<Bool>("system-lib")
     .add_attr_option<Integer>("max_num_threads", Integer(256))
     .add_attr_option<Integer>("thread_warp_size", Integer(1))
+    .add_attr_option<Integer>("texture_spatial_limit", Integer(16384))
     .set_default_keys({"opencl", "gpu"});
 
 // The metal has some limitations on the number of input parameters. This is why attribute
diff --git a/tests/python/relay/test_conv2d_nchw_texture.py b/tests/python/relay/test_conv2d_nchw_texture.py
new file mode 100644
index 0000000000..d36da51c8f
--- /dev/null
+++ b/tests/python/relay/test_conv2d_nchw_texture.py
@@ -0,0 +1,394 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from utils.adreno_utils import gpu_preprocess, build_run_compare
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 42, 42)
+    filter_shape = (96, 32, 3, 3)
+    bias_shape = (1, 96, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 40, 40)
+    filter_shape = (96, 32, 2, 2)
+    bias_shape = (1, 96, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_35_35_strides():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 48, 35, 35)
+    filter_shape = (64, 48, 5, 5)
+    bias_shape = (1, 64, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[2, 2, 2, 2],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(5, 5),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_resnet50_v2_nchw_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 3, 224, 224)
+    filter_shape = (64, 3, 7, 7)
+    bias_shape = (1, 64, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[3, 3, 3, 3],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(7, 7),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_nchw_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 3, 299, 299)
+    filter_shape = (64, 3, 3, 3)
+    bias_shape = (1, 64, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_1x1_16c16spatial():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 16, 256, 256)
+    filter_shape = (32, 16, 4, 4)
+    bias_shape = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_4x4_16c16pad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 256, 256)
+    filter_shape = (32, 32, 4, 4)
+    bias_shape = (1, 32, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[3, 3, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_4x4x4_16c16pad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 32, 256, 256)
+    filter_shape = (4, 32, 4, 4)
+    bias_shape = (1, 4, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[3, 3, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_yolov3_v2_nchw_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 1024, 13, 13)
+    filter_shape = (255, 1024, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=255,
+        kernel_size=(1, 1),
+    )
+
+    mod = relay.Function([A, B], conv)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    initializer("weight", filter_data)
+    params = {
+        "weight": tvm.nd.array(filter_data),
+    }
+
+    build_run_compare(mod, params, {"data": input_shape}, dtype, target)
diff --git a/tests/python/relay/test_conv2d_nhwc_texture.py b/tests/python/relay/test_conv2d_nhwc_texture.py
new file mode 100644
index 0000000000..a02b7cabbe
--- /dev/null
+++ b/tests/python/relay/test_conv2d_nhwc_texture.py
@@ -0,0 +1,556 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
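+"""Tests for conv2d NHWC compute/schedules on the Adreno GPU target using textures."""
+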
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from utils.adreno_utils import gpu_preprocess, build_run_compare
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 257, 257, 32)
+    filter_shape = (1, 1, 32, 16)
+    bias_shape = (filter_shape[-1],)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype=dtype,
+        channels=filter_shape[-1],
+        kernel_size=(1, 1),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_deeplabv3_1_257_257_32x1_1_32_16_with_padding():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 257, 257, 32)
+    filter_shape = (1, 1, 32, 16)
+    bias_shape = (filter_shape[-1],)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[3, 3, 3, 3],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=filter_shape[-1],
+        kernel_size=(1, 1),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_4_35_35_32x3_3_144_16():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (4, 35, 35, 32)
+    filter_shape = (3, 3, 32, 16)
+    bias_shape = (filter_shape[-1],)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype=dtype,
+        channels=filter_shape[-1],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_deeplabv3_1_513_513_3x3_3_3_32():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 513, 513, 3)
+    filter_shape = (3, 3, 3, 32)
+    bias_shape = (filter_shape[-1],)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        out_dtype=dtype,
+        channels=filter_shape[-1],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.ones(filter_shape).astype(dtype)
+    bias_data = np.ones(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 42, 42, 32)
+    filter_shape = (3, 3, 32, 96)
+    bias_shape = (1, 1, 1, 96)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_64x35x35_96x64x3x3_nopad_pass():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 40, 40, 32)
+    filter_shape = (2, 2, 32, 96)
+    bias_shape = (1, 1, 1, 96)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=96,
+        kernel_size=(2, 2),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_35_35_strides():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 35, 35, 48)
+    filter_shape = (5, 5, 48, 64)
+    bias_shape = (1, 1, 1, 64)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[2, 2, 2, 2],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(5, 5),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_resnet50_v2_nhwc_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 224, 224, 3)
+    filter_shape = (7, 7, 3, 64)
+    bias_shape = (1, 1, 1, 64)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[3, 3, 3, 3],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(7, 7),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_inceptionv3_nhwc_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 299, 299, 3)
+    filter_shape = (3, 3, 3, 64)
+    bias_shape = (1, 1, 1, 64)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_1x1_16c16spatial():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 256, 256, 16)
+    filter_shape = (4, 4, 16, 32)
+    bias_shape = (1, 1, 1, 32)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_4x4_16c16pad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 256, 256, 32)
+    filter_shape = (4, 4, 32, 32)
+    bias_shape = (1, 1, 1, 32)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[3, 3, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=32,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_4x4x4_16c16pad():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 256, 256, 32)
+    filter_shape = (4, 4, 32, 4)
+    bias_shape = (1, 1, 1, 4)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[3, 3, 0, 0],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=4,
+        kernel_size=(4, 4),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_conv2d_yolov3_v2_nhwc_3c():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 13, 13, 1024)
+    filter_shape = (1, 1, 1024, 255)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWIO",
+        padding=[0, 0, 0, 0],
+        strides=[1, 1],
+        out_dtype=dtype,
+        channels=255,
+        kernel_size=(1, 1),
+    )
+
+    mod = relay.Function([A, B], conv)
+    np.random.seed(0)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    initializer("weight", filter_data)
+    params = {
+        "weight": tvm.nd.array(filter_data),
+    }
+
+    build_run_compare(mod, params, {"data": input_shape}, dtype, target)
diff --git a/tests/python/relay/test_depthwise_conv2d_nchw_texture.py b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py
new file mode 100644
index 0000000000..71cf62c5d8
--- /dev/null
+++ b/tests/python/relay/test_depthwise_conv2d_nchw_texture.py
@@ -0,0 +1,194 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
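+"""Tests for depthwise conv2d NCHW compute/schedules on the Adreno GPU target using textures."""
+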
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from utils.adreno_utils import gpu_preprocess, build_run_compare
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_bias_nchwc():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 64, 112, 112)
+    filter_shape = (64, 1, 3, 3)
+    bias_shape = (1, 64, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        groups=64,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_nchwc():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 64, 112, 112)
+    filter_shape = (64, 1, 3, 3)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        groups=64,
+        kernel_size=(3, 3),
+    )
+
+    mod = relay.Function([A, B], conv)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    initializer("weight", filter_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target, gpu_preprocess)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_bias_nchw():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 64, 112, 112)
+    filter_shape = (64, 1, 3, 3)
+    bias_shape = (1, 64, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=64,
+        groups=64,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_repack_bias_nchw():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 63, 112, 112)
+    filter_shape = (63, 1, 3, 3)
+    bias_shape = (1, 63, 1, 1)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NCHW",
+        kernel_layout="OIHW",
+        padding=[1, 1, 1, 1],
+        strides=[2, 2],
+        out_dtype=dtype,
+        channels=63,
+        groups=63,
+        kernel_size=(3, 3),
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
diff --git a/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py
new file mode 100644
index 0000000000..16d26c77ca
--- /dev/null
+++ b/tests/python/relay/test_depthwise_conv2d_nhwc_texture.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
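+"""Tests for depthwise conv2d NHWC compute/schedules on the Adreno GPU target using textures."""
+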
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from utils.adreno_utils import build_run_compare
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 129, 129, 144)
+    filter_shape = (3, 3, 144, 1)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    bias_shape = (filter_shape[2],)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype=dtype,
+        groups=filter_shape[2],
+        channels=filter_shape[2],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_deeplabv3_4_35_35_576x3_3_576_1():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (4, 35, 35, 576)
+    filter_shape = (3, 3, 576, 1)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    bias_shape = (filter_shape[2],)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype=dtype,
+        groups=filter_shape[2],
+        channels=filter_shape[2],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_deeplabv3_1_129_129_144x3_3_144_1_with_padding():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 129, 129, 144)
+    filter_shape = (3, 3, 144, 1)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    bias_shape = (filter_shape[2],)
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        padding=[3, 3, 3, 3],
+        strides=[2, 2],
+        out_dtype=dtype,
+        groups=filter_shape[2],
+        channels=filter_shape[2],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.zeros(filter_shape).astype(dtype)
+    bias_data = np.zeros(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_1_513_513_7x3_3_7_1():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 513, 513, 7)
+    filter_shape = (3, 3, 7, 1)
+    bias_shape = (filter_shape[2],)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype=dtype,
+        channels=filter_shape[2],
+        groups=filter_shape[2],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.ones(filter_shape).astype(dtype)
+    bias_data = np.ones(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
+
+
+@tvm.testing.requires_opencl
+def test_depthwise_conv2d_1_513_513_3x3_3_3_1():
+    target = "opencl --device=adreno"
+    dtype = "float16"
+
+    input_shape = (1, 513, 513, 3)
+    filter_shape = (3, 3, 3, 1)
+    bias_shape = (filter_shape[2],)
+    kernel_size = (filter_shape[0], filter_shape[1])
+    A = relay.var("data", shape=input_shape, dtype=dtype)
+    B = relay.var("weight", shape=filter_shape, dtype=dtype)
+    bias = relay.var("bias", shape=bias_shape, dtype=dtype)
+
+    conv = relay.nn.conv2d(
+        A,
+        B,
+        data_layout="NHWC",
+        kernel_layout="HWOI",
+        out_dtype=dtype,
+        channels=filter_shape[2],
+        groups=filter_shape[2],
+        kernel_size=kernel_size,
+    )
+    D = relay.op.add(conv, bias)
+    D = relay.op.nn.relu(D)
+
+    mod = relay.Function([A, B, bias], D)
+    np.random.seed(1)
+    initializer = relay.testing.init.Xavier()
+    filter_data = np.ones(filter_shape).astype(dtype)
+    bias_data = np.ones(bias_shape).astype(dtype)
+    initializer("weight", filter_data)
+    initializer("bias", bias_data)
+    params1 = {
+        "weight": tvm.nd.array(filter_data),
+        "bias": tvm.nd.array(bias_data),
+    }
+
+    build_run_compare(mod, params1, {"data": input_shape}, dtype, target)
diff --git a/tests/python/relay/utils/adreno_utils.py b/tests/python/relay/utils/adreno_utils.py
new file mode 100644
index 0000000000..11abce3bfa
--- /dev/null
+++ b/tests/python/relay/utils/adreno_utils.py
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Utils for adreno compute/schedules"""
+
+import os
+import tvm
+import numpy as np
+from tvm import relay
+from tvm.relay import testing
+from tvm.relay.transform import recast
+from tvm.contrib import graph_runtime
+
+
+def get_cpu_reference(mod, params1, input_shape, inputs):
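+    """Build a float32 recast of mod for LLVM, run it on the CPU, and return its outputs."""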
+    mod_fp32 = recast(mod, "float32", "float32", ops=["nn.conv2d", "add", "nn.relu"])
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(mod_fp32, "llvm", params=params1)
+    ctx = tvm.cpu()
+    m = graph_runtime.create(graph, lib, ctx)
+    if isinstance(input_shape, dict):
+        for i, key in enumerate(input_shape):
+            m.set_input(key, inputs[i])
+    else:
+        m.set_input("data", inputs[-1])
+    m.set_input(**params)
+    m.run()
+    return [
+        m.get_output(0).asnumpy(),
+    ]
+
+
+# Build the module, run it on OpenCL and on the CPU, and compare the results.
+def build_run_compare(
+    tvm_mod, params1, input_shape, dtype="float32", target="llvm", gpu_preprocess=None
+):
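+    """Build tvm_mod for the given target, run it, and compare with the CPU reference.
+
+    When TVM_TRACKER_HOST and TVM_TRACKER_PORT are set in the environment, the
+    module is cross-compiled for Android and run on a remote device obtained
+    from the RPC tracker; otherwise it runs on the local OpenCL platform.
+    """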
+
+    if "TVM_TRACKER_HOST" in os.environ and "TVM_TRACKER_PORT" in os.environ:
+        rpc_tracker_host = os.environ["TVM_TRACKER_HOST"]
+        rpc_tracker_port = os.environ["TVM_TRACKER_PORT"]
+        run_on_host = 0
+        target_host = "llvm -mtriple=arm64-linux-android"
+        rpc_tracker_port = int(rpc_tracker_port)
+    else:
+        run_on_host = 1
+        target_host = "llvm"
+
+    if gpu_preprocess:
+        tvm_mod_nchwc = gpu_preprocess(tvm_mod)
+    else:
+        tvm_mod_nchwc = tvm_mod
+
+    with relay.build_config(opt_level=3):
+        graph, lib, params = relay.build(
+            tvm_mod_nchwc, target_host=target_host, target=target, params=params1
+        )
+    if run_on_host:
+        ctx = tvm.opencl()
+        m = graph_runtime.create(graph, lib, ctx)
+    else:
+        from tvm import rpc
+        from tvm.contrib import utils, ndk
+
+        rpc_key = "android"
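+        # Request a remote Android device from the tracker, export the library
+        # as a shared object with the NDK, and load it on the device.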
+        tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port)
+        remote = tracker.request(rpc_key, priority=0, session_timeout=600)
+        temp = utils.tempdir()
+        dso_binary = "dev_lib_cl.so"
+        dso_binary_path = temp.relpath(dso_binary)
+        ctx = remote.cl(0)
+        lib.export_library(dso_binary_path, ndk.create_shared)
+        remote.upload(dso_binary_path)
+        rlib = remote.load_module(dso_binary)
+        m = graph_runtime.create(graph, rlib, ctx)
+    m.set_input(**params)
+    inputs = []
+    if isinstance(input_shape, dict):
+        for key in input_shape:
+            inputs.append(np.random.normal(size=input_shape[key]).astype(dtype))
+            m.set_input(key, inputs[-1])
+    else:
+        inputs.append(np.random.normal(size=input_shape).astype(dtype))
+        m.set_input("data", inputs[-1])
+    m.run()
+
+    ref_outputs = get_cpu_reference(tvm_mod, params1, input_shape, inputs)
+    for i, ref_output in enumerate(ref_outputs):
+        tvm_output = m.get_output(i)
+        output = tvm_output.asnumpy()
+
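+        # Loose tolerances account for float16 accumulation differences between
+        # the GPU result and the float32 CPU reference.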
+        np.testing.assert_allclose(output, ref_output, rtol=1e-1, atol=1e-1)
+
+
+def gpu_preprocess(tvm_mod):
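+    """Convert conv2d to the blocked NCHW4c/OIHW4o layouts consumed by the
+    texture-based Adreno schedules."""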
+    layout_config = relay.transform.LayoutConfig()
+    desired_layouts = {"nn.conv2d": ["NCHW4c", "OIHW4o"]}
+    with layout_config:
+        seq = tvm.transform.Sequential([relay.transform.ConvertLayout(desired_layouts)])
+        with tvm.transform.PassContext(opt_level=3):
+            mod = tvm.IRModule.from_expr(tvm_mod)
+            tvm_mod_nchwc = seq(mod)
+            return tvm_mod_nchwc