You are viewing a plain-text version of this content. The canonical link for it was a hyperlink on the original archive page and is not preserved in this plain-text copy.
Posted to commits@tvm.apache.org by me...@apache.org on 2022/04/26 00:18:25 UTC

[tvm] branch main updated: [Hexagon] Add test for registered schedules (#11016)

This is an automated email from the ASF dual-hosted git repository.

mehrdadh pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new dca94ec9d1 [Hexagon] Add test for registered schedules (#11016)
dca94ec9d1 is described below

commit dca94ec9d1a2ea553d0f7c2ee09e9487b73d4d35
Author: Mehrdad Hessar <mh...@octoml.ai>
AuthorDate: Mon Apr 25 17:18:19 2022 -0700

    [Hexagon] Add test for registered schedules (#11016)
    
    * add hexagon schedule tests
    
    * moved tests to sub-directories
---
 python/tvm/contrib/hexagon/session.py              |   2 +-
 .../contrib/test_hexagon/conv2d/__init__.py}       |  24 +-
 .../{ => conv2d}/test_conv2d_blocked.md            |   0
 .../{ => conv2d}/test_conv2d_blocked.py            |   2 +-
 .../{ => conv2d}/test_conv2d_conv2d.md             |   0
 .../{ => conv2d}/test_conv2d_conv2d.py             |   2 +-
 .../test_hexagon/test_2d_physical_buffers.py       |   0
 .../contrib/test_hexagon/topi/__init__.py}         |  22 +-
 .../contrib/test_hexagon/topi/test_batch_matmul.py | 141 ++++
 .../{ => topi}/test_cache_read_write.py            |   5 +-
 .../contrib/test_hexagon/topi/test_conv2d_nchw.py  | 246 +++++++
 .../contrib/test_hexagon/topi/test_conv2d_nhwc.py  | 126 ++++
 .../python/contrib/test_hexagon/topi/test_dense.py | 112 ++++
 .../contrib/test_hexagon/topi/test_pooling.py      | 740 +++++++++++++++++++++
 .../contrib/test_hexagon/topi/test_reduce.py       | 165 +++++
 .../contrib/test_hexagon/topi/test_softmax.py      | 101 +++
 tests/scripts/task_build_hexagon_api.sh            |  12 +-
 17 files changed, 1648 insertions(+), 52 deletions(-)

diff --git a/python/tvm/contrib/hexagon/session.py b/python/tvm/contrib/hexagon/session.py
index 7d2eecbc2c..a69a33e270 100644
--- a/python/tvm/contrib/hexagon/session.py
+++ b/python/tvm/contrib/hexagon/session.py
@@ -57,7 +57,7 @@ class Session:
         remote_kw: dict,
         session_name: str = "hexagon-rpc",
         remote_stack_size_bytes: int = 256 * 1024,  # Min size for main thread in QuRT/sim
-        rpc_receive_buffer_size_bytes: int = 2 * 1024 * 1024,
+        rpc_receive_buffer_size_bytes: int = 5 * 1024 * 1024,  # Size for passing hexagon tests
     ):
         self._launcher = launcher
         self._session_name: str = session_name
diff --git a/tests/scripts/task_python_hexagon_simulator.sh b/tests/python/contrib/test_hexagon/conv2d/__init__.py
old mode 100755
new mode 100644
similarity index 54%
rename from tests/scripts/task_python_hexagon_simulator.sh
rename to tests/python/contrib/test_hexagon/conv2d/__init__.py
index c8ae847e3e..1c727042a9
--- a/tests/scripts/task_python_hexagon_simulator.sh
+++ b/tests/python/contrib/test_hexagon/conv2d/__init__.py
@@ -1,4 +1,3 @@
-#!/usr/bin/env bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,25 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -e
-set -u
-
-source tests/scripts/setup-pytest-env.sh
-
-make cython3
-
-export TVM_TRACKER_PORT=9190
-export TVM_TRACKER_HOST=0.0.0.0
-env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host "${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" &
-TRACKER_PID=$!
-sleep 5   # Wait for tracker to bind
-
-# Temporary workaround for symbol visibility
-export HEXAGON_SHARED_LINK_FLAGS="-Lbuild/hexagon_api_output -lhexagon_rpc_sim"
-
-# HEXAGON_TOOLCHAIN is already set
-export HEXAGON_SDK_ROOT=${HEXAGON_SDK_PATH}
-export ANDROID_SERIAL_NUMBER=simulator
-run_pytest ctypes python-contrib-hexagon-simulator tests/python/contrib/test_hexagon
-
-kill ${TRACKER_PID}
+""" Testing infrastructure for Hexagon/TOPI/Conv2d """
diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md
similarity index 100%
rename from tests/python/contrib/test_hexagon/test_conv2d_blocked.md
rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.md
diff --git a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py
similarity index 99%
rename from tests/python/contrib/test_hexagon/test_conv2d_blocked.py
rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py
index 9c8f759414..6762db85e6 100644
--- a/tests/python/contrib/test_hexagon/test_conv2d_blocked.py
+++ b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_blocked.py
@@ -23,7 +23,7 @@ from tvm import te
 from tvm import topi
 from tvm.topi import testing
 
-from .infrastructure import (
+from ..infrastructure import (
     build_and_run,
     conv2d_compute,
     conv2d_verify,
diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.md b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md
similarity index 100%
rename from tests/python/contrib/test_hexagon/test_conv2d_conv2d.md
rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.md
diff --git a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py
similarity index 99%
rename from tests/python/contrib/test_hexagon/test_conv2d_conv2d.py
rename to tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py
index d0d381f0aa..437bdb750b 100644
--- a/tests/python/contrib/test_hexagon/test_conv2d_conv2d.py
+++ b/tests/python/contrib/test_hexagon/conv2d/test_conv2d_conv2d.py
@@ -23,7 +23,7 @@ from tvm import te
 from tvm import topi
 from tvm.topi import testing
 
-from .infrastructure import (
+from ..infrastructure import (
     build_and_run,
     conv2d_compute,
     conv2d_verify,
diff --git a/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py b/tests/python/contrib/test_hexagon/test_2d_physical_buffers.py
old mode 100755
new mode 100644
diff --git a/tests/scripts/task_build_hexagon_api.sh b/tests/python/contrib/test_hexagon/topi/__init__.py
old mode 100755
new mode 100644
similarity index 58%
copy from tests/scripts/task_build_hexagon_api.sh
copy to tests/python/contrib/test_hexagon/topi/__init__.py
index 89b7545f4d..fb6657b09e
--- a/tests/scripts/task_build_hexagon_api.sh
+++ b/tests/python/contrib/test_hexagon/topi/__init__.py
@@ -1,4 +1,3 @@
-#!/bin/bash
 # Licensed to the Apache Software Foundation (ASF) under one
 # or more contributor license agreements.  See the NOTICE file
 # distributed with this work for additional information
@@ -16,23 +15,4 @@
 # specific language governing permissions and limitations
 # under the License.
 
-set -e
-set -u
-
-cd apps/hexagon_api
-rm -rf build
-mkdir -p build
-cd build
-
-output_binary_directory=$(realpath ${PWD}/../../../build/hexagon_api_output)
-rm -rf ${output_binary_directory}
-
-cmake -DANDROID_ABI=arm64-v8a \
-    -DANDROID_PLATFORM=android-28 \
-    -DUSE_ANDROID_TOOLCHAIN="${ANDROID_NDK_HOME}/build/cmake/android.toolchain.cmake" \
-    -DUSE_HEXAGON_ARCH=v68 \
-    -DUSE_HEXAGON_SDK="${HEXAGON_SDK_PATH}" \
-    -DUSE_HEXAGON_TOOLCHAIN="${HEXAGON_TOOLCHAIN}" \
-    -DUSE_OUTPUT_BINARY_DIR="${output_binary_directory}" ..
-
-make -j$(nproc)
+""" Testing infrastructure for Hexagon/TOPI """
diff --git a/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py
new file mode 100644
index 0000000000..d73ab46424
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_batch_matmul.py
@@ -0,0 +1,141 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for matmul"""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+dtype = tvm.testing.parameter(
+    "float32",
+    "float16",
+)
+
+
+class TestMatMulFloat:
+    x_batch, y_batch, M, N, K = tvm.testing.parameters(
+        (1, 1, 16, 16, 32),
+        (5, 5, 16, 16, 32),
+        (5, 5, 16, 20, 32),
+        (30, 30, 16, 20, 32),
+        # Test batch broadcasting.
+        (1, 5, 16, 16, 32),
+        (5, 1, 16, 16, 32),
+    )
+
+    # TODO(mehrdadh): add dynamic testing
+    @requires_hexagon_toolchain
+    def test_batch_matmul(self, hexagon_session, x_batch, y_batch, M, N, K, dtype):
+        if dtype == "float16":
+            pytest.xfail("float16 is not supported.")
+
+        x = te.placeholder((x_batch, M, K), name="x")
+        y = te.placeholder((y_batch, N, K), name="y")
+
+        def get_ref_data():
+            a_np = np.random.uniform(size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.uniform(size=(y_batch, N, K)).astype(dtype)
+            c_np = tvm.topi.testing.batch_matmul(a_np, b_np)
+            return (a_np, b_np, c_np)
+
+        # get the test data
+        a_np, b_np, c_np = get_ref_data()
+
+        target_hexagon = tvm.target.hexagon("v68")
+        with tvm.target.Target(target_hexagon):
+            fcompute = topi.nn.batch_matmul
+            fschedule = topi.hexagon.schedule_batch_matmul
+            out = fcompute(x, y)
+            s = fschedule([out])
+            out_shape = out.shape
+
+        func = tvm.build(
+            s,
+            [x, y, out],
+            tvm.target.Target(target_hexagon, host=target_hexagon),
+            name="batch_matmul",
+        )
+        mod = hexagon_session.load_module(func)
+
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(get_const_tuple(out_shape), dtype=dtype), dev)
+        mod["batch_matmul"](a, b, c)
+
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+
+class TestMatMulInt8:
+    x_batch, y_batch, M, N, K = tvm.testing.parameters(
+        (1, 1, 2, 3, 1),
+        (1, 1, 16, 24, 32),
+        (5, 5, 24, 16, 32),
+        (30, 30, 16, 20, 32),
+        (1, 5, 16, 16, 32),
+        (5, 1, 16, 16, 32),
+    )
+
+    @requires_hexagon_toolchain
+    def test_batch_matmul_int8(self, hexagon_session, x_batch, y_batch, M, N, K):
+        dtype = "int8"
+        out_dtype = "int8"
+        assert x_batch == y_batch or x_batch == 1 or y_batch == 1
+        x = te.placeholder((x_batch, M, K), name="x", dtype=dtype)
+        y = te.placeholder((y_batch, N, K), name="y", dtype=dtype)
+
+        def get_ref_data():
+            a_np = np.random.randint(low=-128, high=127, size=(x_batch, M, K)).astype(dtype)
+            b_np = np.random.randint(low=-128, high=127, size=(y_batch, N, K)).astype(dtype)
+            c_np = tvm.topi.testing.batch_matmul(a_np, b_np, out_dtype=out_dtype)
+            return (a_np, b_np, c_np)
+
+        # get the test data
+        a_np, b_np, c_np = get_ref_data()
+
+        target_hexagon = tvm.target.hexagon("v68")
+        with tvm.target.Target(target_hexagon):
+            fcompute = topi.nn.batch_matmul
+            fschedule = topi.hexagon.schedule_batch_matmul
+            out = fcompute(x, y)
+            s = fschedule([out])
+
+        func = tvm.build(
+            s,
+            [x, y, out],
+            tvm.target.Target(target_hexagon, host=target_hexagon),
+            name="batch_matmul_int8",
+        )
+        mod = hexagon_session.load_module(func)
+
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, dev)
+        b = tvm.nd.array(b_np, dev)
+        c = tvm.nd.array(np.zeros(get_const_tuple(out.shape), dtype=out_dtype), dev)
+        mod["batch_matmul_int8"](a, b, c)
+        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/test_cache_read_write.py b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
similarity index 97%
rename from tests/python/contrib/test_hexagon/test_cache_read_write.py
rename to tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
index 8f94531871..46e78f6683 100644
--- a/tests/python/contrib/test_hexagon/test_cache_read_write.py
+++ b/tests/python/contrib/test_hexagon/topi/test_cache_read_write.py
@@ -20,11 +20,8 @@ import numpy as np
 
 import tvm.testing
 from tvm import te
-from tvm.contrib import utils
-from tvm.contrib.hexagon.build import HexagonLauncher
-import tvm.contrib.hexagon as hexagon
 
-from .conftest import requires_hexagon_toolchain
+from ..conftest import requires_hexagon_toolchain
 
 
 def intrin_mem_copy(shape, dtype, dst_scope, src_scope):
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py
new file mode 100644
index 0000000000..12417e80af
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nchw.py
@@ -0,0 +1,246 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for convolution."""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.topi.nn.utils import get_pad_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+
+dtype = tvm.testing.parameter("float32")
+random_seed = tvm.testing.parameter(0)
+
+
+@tvm.testing.fixture
+def input_shape(batch, in_channel, in_size):
+    return (batch, in_channel, in_size, in_size)
+
+
+@tvm.testing.fixture
+def weight_shape(num_filter, in_channel, kernel):
+    return (num_filter, in_channel, kernel, kernel)
+
+
+@tvm.testing.fixture
+def bias_shape(num_filter):
+    return (num_filter, 1, 1)
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def ref_data(
+    random_seed,
+    input_shape,
+    weight_shape,
+    bias_shape,
+    dtype,
+    stride,
+    padding,
+    dilation,
+    add_bias,
+    apply_relu,
+):
+    np.random.seed(random_seed)
+
+    # scipy.signal.convolve2d does not support float16 data types, and
+    # the python fallback is too slow for general use.  Computing
+    # ref_data in float32 will have fewer rounding errors than the TVM
+    # float16 compute, but those vary based on schedule anyways.
+    conv_dtype = "float32" if dtype == "float16" else dtype
+
+    a_np = np.random.uniform(size=input_shape).astype(dtype)
+    w_np = np.random.uniform(size=weight_shape).astype(dtype)
+    b_np = np.random.uniform(size=bias_shape).astype(dtype)
+    dw_np = tvm.topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
+    c_np = tvm.topi.testing.conv2d_nchw_python(
+        a_np.astype(conv_dtype), dw_np.astype(conv_dtype), stride, padding
+    ).astype(dtype)
+
+    if add_bias:
+        c_np = c_np + b_np
+    if apply_relu:
+        c_np = np.maximum(c_np, 0)
+    return a_np, w_np, b_np, c_np
+
+
+class BaseConv2DTests:
+    add_bias = tvm.testing.parameter(False)
+    apply_relu = tvm.testing.parameter(False)
+    dilation = tvm.testing.parameter(1)
+    batch = tvm.testing.parameter(1)
+
+    @requires_hexagon_toolchain
+    def test_conv2d_nchw(
+        self,
+        hexagon_session,
+        batch,
+        in_channel,
+        in_size,
+        num_filter,
+        kernel,
+        stride,
+        padding,
+        dtype,
+        ref_data,
+        dilation,
+        add_bias,
+        apply_relu,
+    ):
+        target_hexagon = tvm.target.hexagon("v68")
+
+        pad_top, pad_left, pad_bottom, pad_right = get_pad_tuple(padding, (kernel, kernel))
+        padding_sum = pad_top + pad_left + pad_bottom + pad_right
+
+        a_np, w_np, b_np, c_np = ref_data
+
+        A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+        W = te.placeholder(w_np.shape, name="W", dtype=dtype)
+        bias = te.placeholder(b_np.shape, name="bias", dtype=dtype)
+
+        if "int" in dtype:
+            tol = {"atol": 0, "rtol": 0}
+        elif dtype == "float32":
+            tol = {"rtol": 1e-4, "atol": 2e-4}
+        elif dtype == "float16":
+            # A summation in float16 with a single accumulator very
+            # quickly runs into large rounding errors.  At some point,
+            # this tolerance should be schedule-dependent for to avoid
+            # false negatives.
+            num_values_summed = in_channel * kernel * kernel
+            gap_size = np.nextafter(c_np.max(), np.inf, dtype=c_np.dtype) - c_np.max()
+            tol = {"rtol": 1e-3, "atol": num_values_summed * gap_size / 2}
+
+        with tvm.target.Target(target_hexagon):
+            fcompute = topi.nn.conv2d_nchw
+            fschedule = topi.hexagon.schedule_conv2d_nchw
+            C = fcompute(A, W, (stride, stride), padding, (dilation, dilation), dtype)
+            if add_bias:
+                C = topi.add(C, bias)
+            if apply_relu:
+                C = topi.nn.relu(C)
+            s = fschedule([C])
+
+        func_name = "conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
+            dtype,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding_sum,
+            dilation,
+        )
+        func = tvm.build(
+            s,
+            [A, W, bias, C],
+            tvm.target.Target(target_hexagon, host=target_hexagon),
+            name=func_name,
+        )
+        mod = hexagon_session.load_module(func)
+
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, dev)
+        w = tvm.nd.array(w_np, dev)
+        b = tvm.nd.array(b_np, dev)
+
+        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
+        mod[func_name](a, w, b, c)
+        tvm.testing.assert_allclose(c.numpy(), c_np, **tol)
+
+
+class TestBatchSize(BaseConv2DTests):
+    in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (32, 28, 32, 3, 1, 1),
+    )
+    batch = tvm.testing.parameter(1, 4, 9)
+
+
+class TestBiasRelu(BaseConv2DTests):
+    apply_relu = tvm.testing.parameter(True, False, ids=["relu", "no_relu"])
+    add_bias = tvm.testing.parameter(True, False, ids=["bias", "no_bias"])
+    in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (64, 56, 64, 3, 1, 1),
+        (64, 8, 64, 3, 1, (1, 2, 2, 1)),
+        (64, 8, 64, 5, 2, (1, 3)),
+        (64, 8, 64, 3, 1, "VALID"),
+        (32, 8, 32, 24, 1, "SAME"),
+    )
+
+
+class TestResNet18Workloads(BaseConv2DTests):
+    in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (3, 224, 64, 7, 2, 3),
+        (64, 56, 64, 3, 1, 1),
+        (64, 56, 64, 1, 1, 0),
+        (64, 56, 32, 3, 2, 1),
+        (64, 56, 32, 1, 2, 0),
+        (64, 28, 32, 3, 1, 1),
+    )
+
+
+class TestMobilenet(BaseConv2DTests):
+    batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (1, 32, 112, 32, 3, 1, 1),
+    )
+
+
+class TestWeirdWorkloads(BaseConv2DTests):
+    batch, in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (2, 2, 2, 2, 2, 2, 2),
+        (3, 3, 3, 3, 3, 3, 3),
+        (4, 4, 4, 4, 4, 4, 4),
+        (5, 5, 5, 5, 5, 5, 5),
+        (6, 6, 6, 6, 6, 6, 6),
+        (1, 1, 1, 1, 1, 1, 1),
+        (2, 13, 71, 59, 3, 1, 1),
+    )
+
+
+class TestAsymmetricPadding(BaseConv2DTests):
+    dilation = tvm.testing.parameter(1, 2)
+    in_channel, in_size, num_filter, kernel, stride, padding = tvm.testing.parameters(
+        (3, 35, 64, 7, 2, (0, 0, 1, 1)),
+        (64, 8, 128, 3, 1, (3, 3, 2, 2)),
+        (64, 8, 64, 1, 1, (1, 2, 2, 1)),
+        (64, 17, 48, 1, 1, (1, 2)),
+        (64, 8, 64, 3, 1, (3, 1)),
+        (128, 8, 96, 3, 1, (0, 2)),
+        (64, 35, 64, 3, 1, (1, 2)),
+        (64, 8, 64, 1, 1, "VALID"),
+        (388, 8, 64, 3, 1, "VALID"),
+        (64, 10, 48, 3, 1, "VALID"),
+        (64, 19, 64, 1, 1, "SAME"),
+        (64, 5, 32, 2, 1, "SAME"),
+        (32, 8, 32, 3, 1, "SAME"),
+        (64, 8, 64, 3, 1, (1, 2, 2, 1)),
+        (64, 8, 64, 5, 2, (1, 3)),
+        (64, 8, 64, 3, 1, "VALID"),
+        (32, 8, 32, 24, 1, "SAME"),
+        (32, 35, 64, 7, 2, (0, 0, 2, 2)),
+    )
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py
new file mode 100644
index 0000000000..60b0b7ea6d
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_conv2d_nhwc.py
@@ -0,0 +1,126 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for convolution."""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+from tvm.topi.nn.utils import get_pad_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+dtype = tvm.testing.parameter("float32")
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def ref_data(dtype, batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation):
+    in_height = in_width = in_size
+    a_shape = (batch, in_height, in_width, in_channel)
+    w_shape = (kernel, kernel, in_channel, num_filter)
+
+    a_np = np.random.uniform(size=a_shape).astype(dtype)
+    w_np = np.random.uniform(size=w_shape).astype(dtype)
+    dw_np = tvm.topi.testing.dilate_python(w_np, (dilation, dilation, 1, 1))
+    b_np = tvm.topi.testing.conv2d_nhwc_python(a_np, dw_np, stride, padding)
+    return a_np, w_np, b_np
+
+
+class BaseConv2DTests:
+    @requires_hexagon_toolchain
+    def test_conv2d_nhwc(
+        self,
+        hexagon_session,
+        ref_data,
+        batch,
+        in_channel,
+        in_size,
+        num_filter,
+        kernel,
+        dtype,
+        stride,
+        padding,
+        dilation,
+    ):
+        target_hexagon = tvm.target.hexagon("v68")
+
+        a_np, w_np, b_np = ref_data
+
+        A = te.placeholder(a_np.shape, name="A", dtype=dtype)
+        W = te.placeholder(w_np.shape, name="W", dtype=dtype)
+
+        with tvm.target.Target(target_hexagon):
+            fcompute = topi.nn.conv2d_nhwc
+            fschedule = topi.hexagon.schedule_conv2d_nhwc
+            B = fcompute(A, W, stride, padding, dilation, dtype)
+            s = fschedule([B])
+
+        func_name = "conv2d_{}_{}_{}_{}_{}_{}_{}_{}_{}".format(
+            dtype,
+            batch,
+            in_channel,
+            in_size,
+            num_filter,
+            kernel,
+            stride,
+            padding,
+            dilation,
+        )
+        func = tvm.build(
+            s, [A, W, B], tvm.target.Target(target_hexagon, host=target_hexagon), name=func_name
+        )
+        mod = hexagon_session.load_module(func)
+
+        dev = hexagon_session.device
+        a = tvm.nd.array(a_np, dev)
+        w = tvm.nd.array(w_np, dev)
+        b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+
+        mod[func_name](a, w, b)
+        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
+class TestConv2dNHWC(BaseConv2DTests):
+    (
+        batch,
+        in_channel,
+        in_size,
+        num_filter,
+        kernel,
+        stride,
+        padding,
+        dilation,
+    ) = tvm.testing.parameters(
+        (1, 64, 32, 64, 3, 1, "SAME", 1),
+        (4, 32, 16, 32, 5, 2, "SAME", 1),
+        (1, 64, 32, 64, 3, 1, "VALID", 1),
+        (4, 32, 16, 32, 5, 2, "VALID", 1),
+        (1, 32, 16, 64, 3, 2, (0, 0, 1, 1), 1),
+        (1, 32, 16, 64, 3, 2, (1, 1, 2, 2), 1),
+        (1, 32, 16, 32, 5, 2, (3, 3, 2, 2), 1),
+        (1, 32, 16, 64, 3, 2, (0, 1, 2, 3), 1),
+        (1, 64, 32, 64, 3, 1, "SAME", 2),
+        (1, 64, 32, 64, 3, 1, (1, 1, 2, 2), 2),
+    )
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_dense.py b/tests/python/contrib/test_hexagon/topi/test_dense.py
new file mode 100644
index 0000000000..59a1573a6b
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_dense.py
@@ -0,0 +1,112 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for dense"""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+random_seed = tvm.testing.parameter(0)
+
+use_bias = tvm.testing.parameter(True, False)
+
+# batch_size more than 8 would break
+batch_size = tvm.testing.parameter(1, 2, 8)
+
+in_dim, out_dim = tvm.testing.parameters((1024, 1000))
+
+in_dtype, out_dtype = tvm.testing.parameters(
+    ("float32", "float32"),
+    ("float16", "float32"),
+    ("int8", "int32"),
+)
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def dense_ref_data(random_seed, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype):
+    np.random.seed(random_seed)
+
+    if "float" in in_dtype:
+        a_np = np.random.uniform(size=(batch_size, in_dim)).astype(in_dtype)
+        b_np = np.random.uniform(size=(out_dim, in_dim)).astype(in_dtype)
+        c_np = np.random.uniform(size=(out_dim,)).astype(out_dtype)
+    elif in_dtype == "int8":
+        a_np = np.random.randint(low=-128, high=127, size=(batch_size, in_dim)).astype(in_dtype)
+        b_np = np.random.randint(low=-128, high=127, size=(out_dim, in_dim)).astype(in_dtype)
+        c_np = np.random.randint(low=-128, high=127, size=(out_dim,)).astype(out_dtype)
+    else:
+        raise ValueError("No method to generate test data for data type '{}'".format(in_dtype))
+
+    matmul = np.dot(a_np.astype(out_dtype), b_np.T.astype(out_dtype))
+
+    if use_bias:
+        matmul += c_np
+
+    d_np = np.maximum(matmul, 0)
+    return (a_np, b_np, c_np, d_np)
+
+
+@requires_hexagon_toolchain
+def test_dense(
+    hexagon_session, batch_size, in_dim, out_dim, use_bias, in_dtype, out_dtype, dense_ref_data
+):
+    if in_dtype == "float16":
+        pytest.xfail("float16 is not supported.")
+
+    if "int" in in_dtype:
+        tol = {"atol": 0, "rtol": 0}
+    elif in_dtype == "float32":
+        tol = {"rtol": 1e-5, "atol": 1e-5}
+
+    A = te.placeholder((batch_size, in_dim), name="A", dtype=in_dtype)
+    B = te.placeholder((out_dim, in_dim), name="B", dtype=in_dtype)
+    C = te.placeholder((out_dim,), name="C", dtype=out_dtype)
+
+    a_np, b_np, c_np, d_np = dense_ref_data
+
+    fcompute = topi.nn.dense
+    fschedule = topi.hexagon.schedule_dense
+
+    target_hexagon = tvm.target.hexagon("v68")
+    with tvm.target.Target(target_hexagon):
+        D = fcompute(A, B, C if use_bias else None, out_dtype)
+        D = topi.nn.relu(D)
+        s = fschedule([D])
+
+    func = tvm.build(
+        s, [A, B, C, D], tvm.target.Target(target_hexagon, host=target_hexagon), name="dense"
+    )
+    mod = hexagon_session.load_module(func)
+
+    dev = hexagon_session.device
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(b_np, dev)
+    c = tvm.nd.array(c_np, dev)
+    d = tvm.nd.array(np.zeros(get_const_tuple(D.shape), dtype=out_dtype), dev)
+    mod["dense"](a, b, c, d)
+    tvm.testing.assert_allclose(d.numpy(), d_np, **tol)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_pooling.py b/tests/python/contrib/test_hexagon/topi/test_pooling.py
new file mode 100644
index 0000000000..f05611f2f5
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_pooling.py
@@ -0,0 +1,740 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for pooling"""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+
+class TestAdaptivePool:
+    dshape, out_size, pool_type, layout = tvm.testing.parameters(
+        ((1, 3, 112, 112), (1, 1), "max", "NCHW"),
+        ((1, 3, 112, 112), (1, 1), "avg", "NCHW"),
+        ((1, 14, 56, 78), (34, 13), "max", "NCHW"),
+        ((1, 5, 46, 97), (4, 96), "avg", "NCHW"),
+        ((1, 112, 112, 3), (1, 1), "max", "NHWC"),
+        ((1, 5, 46, 97), (4, 96), "avg", "NHWC"),
+        ((1, 16, 32, 32, 32), (1, 1, 1), "max", "NCDHW"),
+        ((1, 16, 32, 32, 32), (1, 1, 1), "avg", "NCDHW"),
+        ((1, 16, 32, 32, 32), (2, 2, 2), "avg", "NCDHW"),
+        (
+            (1, 16, 64, 32, 32),
+            (7, 8, 9),
+            "avg",
+            "NCDHW",
+        ),
+        (
+            (1, 16, 64, 32, 32),
+            (8, 16, 16),
+            "avg",
+            "NCDHW",
+        ),
+        ((1, 16, 32, 32, 32), (1, 1, 1), "avg", "NDHWC"),
+        ((1, 16, 32, 32, 32), (2, 2, 2), "max", "NDHWC"),
+        ((1, 16, 32, 32, 32), (2, 4, 4), "max", "NDHWC"),
+    )
+
+    @requires_hexagon_toolchain
+    def test_adaptive_pool(self, hexagon_session, dshape, out_size, pool_type, layout):
+        dtype = "float32"
+        np_data = np.random.uniform(low=0, high=255, size=dshape).astype(dtype)
+        np_out = tvm.topi.testing.adaptive_pool(np_data, out_size, pool_type, layout)
+        oshape = np_out.shape
+
+        data = te.placeholder(dshape, name="data", dtype=dtype)
+        if len(out_size) == 2:
+            out = topi.nn.adaptive_pool(data, out_size, pool_type, layout)
+        else:
+            assert len(out_size) == 3
+            out = topi.nn.adaptive_pool3d(data, out_size, pool_type, layout)
+
+        target_hexagon = tvm.target.hexagon("v68")
+        with tvm.target.Target(target_hexagon):
+            fschedule = topi.hexagon.schedule_adaptive_pool
+            s = fschedule(out)
+
+        func = tvm.build(
+            s,
+            [data, out],
+            tvm.target.Target(target_hexagon, host=target_hexagon),
+            name="adaptive-pool",
+        )
+        mod = hexagon_session.load_module(func)
+
+        dev = hexagon_session.device
+        a = tvm.nd.array(np_data, dev)
+        b = tvm.nd.array(np.zeros(get_const_tuple(oshape), dtype=out.dtype), dev)
+        mod["adaptive-pool"](a, b)
+
+        tvm.testing.assert_allclose(b.numpy(), np_out, rtol=4e-5, atol=1e-6)
+
+
+def verify_poolnd(
+    hexagon_session,
+    n,
+    input_shape,
+    kernel,
+    stride,
+    dilation,
+    padding,
+    pool_type,
+    ceil_mode,
+    count_include_pad=True,
+    layout="NCW",
+):
+    A = te.placeholder(input_shape, name="A")
+
+    if n == 1:
+        B = topi.nn.pool1d(
+            A,
+            kernel=kernel,
+            stride=stride,
+            dilation=dilation,
+            padding=padding,
+            pool_type=pool_type,
+            ceil_mode=ceil_mode,
+            layout=layout,
+            count_include_pad=count_include_pad,
+        )
+    elif n == 2:
+        B = topi.nn.pool2d(
+            A,
+            kernel=kernel,
+            stride=stride,
+            dilation=dilation,
+            padding=padding,
+            pool_type=pool_type,
+            ceil_mode=ceil_mode,
+            layout=layout,
+            count_include_pad=count_include_pad,
+        )
+    elif n == 3:
+        B = topi.nn.pool3d(
+            A,
+            kernel=kernel,
+            stride=stride,
+            dilation=dilation,
+            padding=padding,
+            pool_type=pool_type,
+            ceil_mode=ceil_mode,
+            layout=layout,
+            count_include_pad=count_include_pad,
+        )
+    else:
+        raise ValueError(f"PoolND only supports n=1, 2, 3 got n={n}")
+
+    B = topi.nn.relu(B)
+    dtype = A.dtype
+    output_shape = [int(i) for i in B.shape]
+
+    input_np = np.random.uniform(low=0.001, size=input_shape).astype(dtype)
+
+    padding_before = padding[:n]
+    padding_after = padding[n:]
+    ref_np = tvm.topi.testing.poolnd_python(
+        input_np,
+        kernel,
+        stride,
+        dilation,
+        padding_before,
+        padding_after,
+        pool_type,
+        count_include_pad,
+        ceil_mode,
+        layout=layout,
+    )
+
+    np.testing.assert_equal(tuple(output_shape), tuple(ref_np.shape))
+
+    target_hexagon = tvm.target.hexagon("v68")
+    with tvm.target.Target(target_hexagon):
+        fschedule = topi.hexagon.schedule_pool
+        s = fschedule(B, layout)
+
+    func = tvm.build(s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name="pool")
+    mod = hexagon_session.load_module(func)
+
+    dev = hexagon_session.device
+    a = tvm.nd.array(input_np, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=dtype), dev)
+    mod["pool"](a, b)
+
+    tvm.testing.assert_allclose(b.numpy(), ref_np, rtol=1e-5)
+
+
+class TestPool1D:
+    (
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ) = tvm.testing.parameters(
+        ([1, 16, 32], [2], [2], [1], [0, 0], "avg", False, True, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [1, 2], "avg", False, True, "NCW"),
+        ([1, 16, 32], [2], [2], [1], [1, 2], "avg", False, False, "NCW"),
+        ([1, 16, 31], [4], [4], [1], [3, 3], "avg", False, False, "NCW"),
+        ([1, 16, 31], [4], [4], [1], [0, 0], "avg", False, False, "NCW"),
+        ([1, 16, 32], [2], [2], [1], [0, 0], "max", False, True, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [2, 1], "max", False, True, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [2, 1], "max", True, True, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [2, 5], "avg", False, True, "NCW"),
+        ([1, 16, 32], [2], [2], [1], [0, 3], "avg", False, False, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [1, 4], "max", False, True, "NCW"),
+        ([1, 16, 31], [3], [3], [1], [3, 0], "max", True, True, "NCW"),
+        # Test non-1 dilations
+        ([1, 16, 31], [3], [3], [2], [2, 5], "avg", False, True, "NCW"),
+        ([1, 16, 32], [2], [2], [3], [0, 3], "avg", False, False, "NCW"),
+        ([1, 16, 31], [3], [3], [2], [1, 4], "max", False, True, "NCW"),
+        ([1, 16, 31], [3], [3], [3], [3, 0], "max", True, True, "NCW"),
+        # Test Channel last
+        ([1, 32, 16], [2], [2], [1], [0, 0], "avg", False, True, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [1, 2], "avg", False, True, "NWC"),
+        ([1, 32, 16], [2], [2], [1], [1, 2], "avg", False, False, "NWC"),
+        ([1, 31, 16], [4], [4], [1], [3, 3], "avg", False, False, "NWC"),
+        ([1, 31, 16], [4], [4], [1], [0, 0], "avg", False, False, "NWC"),
+        ([1, 32, 16], [2], [2], [1], [0, 0], "max", False, True, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [2, 1], "max", False, True, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [2, 1], "max", True, True, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [2, 5], "avg", False, True, "NWC"),
+        ([1, 31, 16], [2], [2], [1], [0, 3], "avg", False, False, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [1, 4], "max", False, True, "NWC"),
+        ([1, 31, 16], [3], [3], [1], [3, 0], "max", True, True, "NWC"),
+        ([1, 31, 16], [3], [3], [2], [2, 5], "avg", False, True, "NWC"),
+        ([1, 32, 16], [2], [2], [3], [0, 3], "avg", False, False, "NWC"),
+        ([1, 31, 16], [3], [3], [2], [1, 4], "max", False, True, "NWC"),
+        ([1, 31, 16], [3], [3], [3], [3, 0], "max", True, True, "NWC"),
+    )
+
+    @requires_hexagon_toolchain
+    def test_pool1d(
+        self,
+        hexagon_session,
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ):
+        verify_poolnd(
+            hexagon_session,
+            1,
+            input_shape,
+            kernel,
+            stride,
+            dilation,
+            padding,
+            pool_type,
+            ceil_mode,
+            count_include_pad,
+            layout,
+        )
+
+
+class TestPool2D:
+    (
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ) = tvm.testing.parameters(
+        ([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, "NCHW"),
+        ([1, 16, 32, 32], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, "NCHW"),
+        ([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, "NCHW"),
+        ([1, 16, 31, 31], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, "NCHW"),
+        ([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, "NCHW"),
+        ([1, 16, 32, 32], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, True, "NCHW"),
+        # Test non-1 dilations
+        ([1, 16, 31, 31], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, "NCHW"),
+        ([1, 16, 32, 32], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, True, "NCHW"),
+        ([1, 16, 31, 31], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, True, "NCHW"),
+        # Test channel last
+        ([1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [0, 0, 0, 0], "avg", False, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 2, 1, 2], "avg", False, True, "NHWC"),
+        ([1, 32, 32, 16], [2, 2], [2, 2], [1, 1], [1, 2, 1, 2], "avg", False, False, "NHWC"),
+        ([1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [3, 3, 3, 3], "avg", False, False, "NHWC"),
+        ([1, 31, 31, 16], [4, 4], [4, 4], [1, 1], [0, 0, 0, 0], "avg", False, False, "NHWC"),
+        ([1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 0, 0, 0], "max", False, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", False, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 2, 1], "max", True, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [2, 1, 0, 3], "avg", False, True, "NHWC"),
+        ([1, 32, 32, 16], [2, 3], [2, 2], [1, 1], [0, 3, 2, 1], "avg", False, False, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [1, 0, 3, 2], "max", False, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [1, 1], [3, 2, 1, 0], "max", True, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [2, 1], [2, 1, 0, 3], "avg", False, True, "NHWC"),
+        ([1, 32, 32, 16], [2, 3], [2, 2], [2, 3], [0, 3, 2, 1], "avg", False, False, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [3, 3], [1, 0, 3, 2], "max", False, True, "NHWC"),
+        ([1, 31, 31, 16], [3, 3], [3, 3], [2, 2], [3, 2, 1, 0], "max", True, True, "NHWC"),
+    )
+
+    @requires_hexagon_toolchain
+    def test_pool2d(
+        self,
+        hexagon_session,
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ):
+        verify_poolnd(
+            hexagon_session,
+            2,
+            input_shape,
+            kernel,
+            stride,
+            dilation,
+            padding,
+            pool_type,
+            ceil_mode,
+            count_include_pad,
+            layout,
+        )
+
+
+class TestPool3D:
+    (
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ) = tvm.testing.parameters(
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "avg",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [1, 1, 2, 2, 2, 1],
+            "avg",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [1, 1, 2, 2, 2, 1],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [4, 4, 4],
+            [4, 4, 4],
+            [1, 1, 1],
+            [3, 3, 3, 3, 3, 3],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [4, 4, 4],
+            [4, 4, 4],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "max",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 2, 1, 1, 1, 2],
+            "max",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 2, 1, 1, 1, 2],
+            "max",
+            True,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 1, 0, 5, 4, 3],
+            "avg",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 5, 4, 3, 2, 1],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [1, 0, 5, 4, 3, 2],
+            "max",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [3, 2, 1, 0, 5, 4],
+            "max",
+            True,
+            True,
+            "NCDHW",
+        ),
+        # Test non-1 dilation
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 1, 0, 5, 4, 3],
+            "avg",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [0, 5, 4, 3, 2, 1],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 1, 3],
+            [1, 0, 5, 4, 3, 2],
+            "max",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 2, 3],
+            [3, 2, 1, 0, 5, 4],
+            "max",
+            True,
+            True,
+            "NCDHW",
+        ),
+        # Test channel last layouts
+        (
+            [1, 32, 32, 32, 16],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "avg",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [1, 1, 2, 2, 2, 1],
+            "avg",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 32, 32, 32, 16],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [1, 1, 2, 2, 2, 1],
+            "avg",
+            False,
+            False,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [4, 4, 4],
+            [4, 4, 4],
+            [1, 1, 1],
+            [3, 3, 3, 3, 3, 3],
+            "avg",
+            False,
+            False,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [4, 4, 4],
+            [4, 4, 4],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "avg",
+            False,
+            False,
+            "NDHWC",
+        ),
+        (
+            [1, 32, 32, 32, 16],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 0, 0, 0, 0, 0],
+            "max",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 2, 1, 1, 1, 2],
+            "max",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 2, 1, 1, 1, 2],
+            "max",
+            True,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [2, 1, 0, 5, 4, 3],
+            "avg",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 32, 32, 32, 16],
+            [2, 2, 2],
+            [2, 2, 2],
+            [1, 1, 1],
+            [0, 5, 4, 3, 2, 1],
+            "avg",
+            False,
+            False,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [1, 0, 5, 4, 3, 2],
+            "max",
+            False,
+            True,
+            "NDHWC",
+        ),
+        (
+            [1, 31, 31, 31, 16],
+            [3, 3, 3],
+            [3, 3, 3],
+            [1, 1, 1],
+            [3, 2, 1, 0, 5, 4],
+            "max",
+            True,
+            True,
+            "NDHWC",
+        ),
+        # Test non-1 dilation
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 1, 0, 5, 4, 3],
+            "avg",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 32, 32, 32],
+            [2, 2, 2],
+            [2, 2, 2],
+            [2, 2, 2],
+            [0, 5, 4, 3, 2, 1],
+            "avg",
+            False,
+            False,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 1, 3],
+            [1, 0, 5, 4, 3, 2],
+            "max",
+            False,
+            True,
+            "NCDHW",
+        ),
+        (
+            [1, 16, 31, 31, 31],
+            [3, 3, 3],
+            [3, 3, 3],
+            [2, 2, 3],
+            [3, 2, 1, 0, 5, 4],
+            "max",
+            True,
+            True,
+            "NCDHW",
+        ),
+    )
+
+    @requires_hexagon_toolchain
+    def test_pool3d(
+        self,
+        hexagon_session,
+        input_shape,
+        kernel,
+        stride,
+        dilation,
+        padding,
+        pool_type,
+        ceil_mode,
+        count_include_pad,
+        layout,
+    ):
+        verify_poolnd(
+            hexagon_session,
+            3,
+            input_shape,
+            kernel,
+            stride,
+            dilation,
+            padding,
+            pool_type,
+            ceil_mode,
+            count_include_pad,
+            layout,
+        )
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_reduce.py b/tests/python/contrib/test_hexagon/topi/test_reduce.py
new file mode 100644
index 0000000000..7978e3854f
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_reduce.py
@@ -0,0 +1,165 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for reduce"""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+
+from ..conftest import requires_hexagon_toolchain
+
+
+# Reduce test cases: (input shape, reduction axis/axes or None for all axes,
+# keepdims, reduction kind, input dtype). The logical "all"/"any" cases use
+# bool inputs; every other case uses float32.
+in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
+    ((32,), 0, False, "argmax", "float32"),
+    ((32, 24, 32, 24), (1, 2, 3), True, "sum", "float32"),
+    ((2, 3), None, True, "all", "bool"),
+    ((32, 24 * 32 * 24), (1,), False, "max", "float32"),
+    ((32, 128, 24), None, True, "sum", "float32"),
+    ((32, 128, 24), None, True, "all", "bool"),
+    ((32, 24, 32, 24), (0, 2), False, "min", "float32"),
+    ((32, 128), 1, True, "argmax", "float32"),
+    ((32, 24, 32, 24), 2, False, "argmin", "float32"),
+    ((31, 21, 15), None, True, "argmax", "float32"),
+    ((31, 21, 15), None, False, "sum", "float32"),
+    ((2, 3), None, True, "any", "bool"),
+    ((32, 128, 24), None, True, "any", "bool"),
+    ((1, 4, 7), 1, True, "any", "bool"),
+    ((32, 24, 32, 24), 2, False, "any", "bool"),
+)
+
+
+def _my_npy_argmax(arr, axis, keepdims):
+    if not keepdims:
+        return arr.argmax(axis=axis)
+    else:
+        if axis is None:
+            out_shape = [1 for _ in arr.shape]
+        else:
+            out_shape = list(arr.shape)
+            out_shape[axis] = 1
+
+        return arr.argmax(axis=axis).reshape(out_shape)
+
+
+def _my_npy_argmin(arr, axis, keepdims):
+    if not keepdims:
+        return arr.argmin(axis=axis)
+    else:
+        if axis is None:
+            out_shape = [1 for _ in arr.shape]
+        else:
+            out_shape = list(arr.shape)
+            out_shape[axis] = 1
+        return arr.argmin(axis=axis).reshape(out_shape)
+
+
+@tvm.testing.fixture(cache_return_value=True)
+def ref_data(in_shape, axis, keepdims, reduce_type, dtype):
+    # Test
+    if dtype == "bool":
+        in_npy_map = in_npy = np.random.choice([True, False], size=in_shape)
+    else:
+        in_npy = np.random.uniform(-1, 1, size=in_shape).astype(dtype)
+        in_npy_map = np.sqrt(np.exp(in_npy)).astype(dtype)
+
+    if reduce_type == "sum":
+        out_npy = in_npy_map.sum(axis=axis, keepdims=keepdims)
+    elif reduce_type == "all" and dtype == "bool":
+        out_npy = in_npy_map.all(axis=axis, keepdims=keepdims)
+    elif reduce_type == "any" and dtype == "bool":
+        out_npy = in_npy_map.any(axis=axis, keepdims=keepdims)
+    elif reduce_type == "max":
+        out_npy = in_npy_map.max(axis=axis, keepdims=keepdims)
+    elif reduce_type == "min":
+        out_npy = in_npy_map.min(axis=axis, keepdims=keepdims)
+    elif reduce_type == "argmax":
+        out_npy = _my_npy_argmax(in_npy_map, axis=axis, keepdims=keepdims)
+    elif reduce_type == "argmin":
+        out_npy = _my_npy_argmin(in_npy_map, axis=axis, keepdims=keepdims)
+    else:
+        raise NotImplementedError
+
+    return in_npy, in_npy_map, out_npy
+
+
+@requires_hexagon_toolchain
+def test_reduce_map(hexagon_session, ref_data, in_shape, axis, keepdims, reduce_type, dtype):
+    in_npy, in_npy_map, out_npy = ref_data
+
+    # Build the logic and compile the function
+    A = te.placeholder(shape=in_shape, name="A", dtype=dtype)
+    A1 = topi.sqrt(topi.exp(A))
+    out_dtype = dtype
+    if reduce_type == "sum":
+        B = topi.sum(A1, axis=axis, keepdims=keepdims)
+    elif reduce_type == "all":
+        B = topi.all(A, axis=axis, keepdims=keepdims)
+    elif reduce_type == "any":
+        B = topi.any(A, axis=axis, keepdims=keepdims)
+    elif reduce_type == "max":
+        B = topi.max(A1, axis=axis, keepdims=keepdims)
+    elif reduce_type == "min":
+        B = topi.min(A1, axis=axis, keepdims=keepdims)
+    elif reduce_type == "argmax":
+        B = topi.argmax(A1, axis=axis, keepdims=keepdims)
+        out_dtype = "int32"
+    elif reduce_type == "argmin":
+        B = topi.argmin(A1, axis=axis, keepdims=keepdims)
+        out_dtype = "int32"
+    else:
+        raise NotImplementedError
+
+    target_hexagon = tvm.target.hexagon("v68")
+    with tvm.target.Target(target_hexagon):
+        fschedule = topi.hexagon.schedule_reduce
+        s = fschedule(B)
+
+    func = tvm.build(
+        s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name=reduce_type
+    )
+    mod = hexagon_session.load_module(func)
+
+    dev = hexagon_session.device
+    data_tvm = tvm.nd.array(in_npy, device=dev)
+    out_tvm = tvm.nd.empty(shape=out_npy.shape, device=dev, dtype=out_dtype)
+
+    mod[reduce_type](data_tvm, out_tvm)
+
+    if reduce_type == "argmax" or reduce_type == "argmin":
+        out_tvm_indices = out_tvm.numpy()
+        if keepdims:
+            out_tvm_indices = np.take(out_tvm_indices, indices=0, axis=axis)
+        if axis is None:
+            out_tvm_val = in_npy_map.ravel()[out_tvm_indices]
+        else:
+            other_indices = tuple(np.indices(in_shape[0:axis] + in_shape[(axis + 1) :]))
+            sel_indices = other_indices[0:axis] + (out_tvm_indices,) + other_indices[axis:]
+            out_tvm_val = in_npy_map[sel_indices]
+        if reduce_type == "argmax":
+            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.max(axis=axis), 1e-3, 1e-3)
+        elif reduce_type == "argmin":
+            tvm.testing.assert_allclose(out_tvm_val, in_npy_map.min(axis=axis), 1e-3, 1e-3)
+    else:
+        tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/python/contrib/test_hexagon/topi/test_softmax.py b/tests/python/contrib/test_hexagon/topi/test_softmax.py
new file mode 100644
index 0000000000..4825d1e524
--- /dev/null
+++ b/tests/python/contrib/test_hexagon/topi/test_softmax.py
@@ -0,0 +1,101 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Test code for softmax"""
+import numpy as np
+import pytest
+import sys
+
+import tvm
+from tvm import topi
+from tvm import te
+import tvm.topi.testing
+from tvm.topi.utils import get_const_tuple
+
+from ..conftest import requires_hexagon_toolchain
+
+# Element types exercised by test_softmax; float16 is currently marked xfail there.
+dtype = tvm.testing.parameter(
+    "float16",
+    "float32",
+)
+
+# TODO(mehrdadh): add log_softmax to config
+# Operation name -> TOPI compute ("topi"), NumPy reference ("ref") and the
+# input ranks the operation is tested with ("dimensions").
+configs = {
+    "softmax": {
+        "topi": topi.nn.softmax,
+        "ref": tvm.topi.testing.softmax_python,
+        "dimensions": [2, 4],
+    },
+}
+
+# TODO(mehrdadh): larger size like (1, 16, 256, 256) would fail due to TVM_HEXAGON_RPC_BUFF_SIZE_BYTES
+shapes = [(32, 10), (3, 4), (1, 16, 32, 32)]
+# One test case per (operation, shape) pair whose rank is listed in that
+# operation's "dimensions".
+softmax_operation, shape = tvm.testing.parameters(
+    *[
+        (name, shape)
+        for name, config in configs.items()
+        for shape in shapes
+        if len(shape) in config["dimensions"]
+    ]
+)
+
+
+@requires_hexagon_toolchain
+def test_softmax(hexagon_session, shape, dtype, softmax_operation):
+    if dtype == "float16":
+        pytest.xfail("float16 is not supported.")
+    A = te.placeholder(shape, dtype=dtype, name="A")
+
+    topi_op = configs[softmax_operation]["topi"]
+    B = topi_op(A, axis=1)
+
+    def get_ref_data(shape):
+        ref_func = tvm.topi.testing.softmax_python
+        a_np = np.random.uniform(size=shape).astype(dtype)
+
+        if len(shape) == 2:
+            b_np = ref_func(a_np)
+        elif len(shape) == 4:
+            _, c, h, w = a_np.shape
+            a_np_2d = a_np.transpose(0, 2, 3, 1).reshape(h * w, c)
+            b_np_2d = tvm.topi.testing.softmax_python(a_np_2d)
+            b_np = b_np_2d.reshape(1, h, w, c).transpose(0, 3, 1, 2)
+
+        return a_np, b_np
+
+    # get the test data
+    a_np, b_np = get_ref_data(shape)
+
+    target_hexagon = tvm.target.hexagon("v68")
+    with tvm.target.Target(target_hexagon):
+        fschedule = topi.hexagon.schedule_softmax
+        s = fschedule(B)
+
+    func = tvm.build(
+        s, [A, B], tvm.target.Target(target_hexagon, host=target_hexagon), name="softmax"
+    )
+    mod = hexagon_session.load_module(func)
+
+    dev = hexagon_session.device
+    a = tvm.nd.array(a_np, dev)
+    b = tvm.nd.array(np.zeros(get_const_tuple(B.shape), dtype=B.dtype), dev)
+    mod["softmax"](a, b)
+
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
+if __name__ == "__main__":
+    sys.exit(pytest.main(sys.argv))
diff --git a/tests/scripts/task_build_hexagon_api.sh b/tests/scripts/task_build_hexagon_api.sh
index 89b7545f4d..ae4d421268 100755
--- a/tests/scripts/task_build_hexagon_api.sh
+++ b/tests/scripts/task_build_hexagon_api.sh
@@ -19,8 +19,18 @@
 set -e
 set -u
 
+# Optional first argument "--use-cache": keep an existing build directory
+# instead of deleting it before building hexagon_api.
+use_cache=false
+if [ $# -ge 1 ] && [[ "$1" == "--use-cache" ]]; then
+    use_cache=true
+    # Consume the flag so any remaining arguments are unaffected.
+    shift 1
+fi
+
 cd apps/hexagon_api
-rm -rf build
+
+# Start from a clean build directory unless --use-cache was given.
+if [ "$use_cache" = false ]; then
+    rm -rf build
+fi
+
 mkdir -p build
 cd build