You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2023/01/03 07:03:56 UTC

[tvm] branch main updated: [CLML] Version compatibility and various test cases (#13670)

This is an automated email from the ASF dual-hosted git repository.

masahi pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new b6851f344e [CLML] Version compatibility and various test cases  (#13670)
b6851f344e is described below

commit b6851f344e9ac3ee57e293626fba3decd333f6ab
Author: Siva <qu...@quicinc.com>
AuthorDate: Tue Jan 3 12:33:51 2023 +0530

    [CLML] Version compatibility and various test cases  (#13670)
    
    * [CLML][TEST] Codegen test cases for ops
    
    Codegen verification test cases for all the ops (convolution, concat, pad, pool, etc.)
    that are supported by the CLML BYOC path.
    
    Fix depthwise conv2d issue with layout
    
    * * lint errors
    
    * * version compatibility changes.
    
    * * review comments
    
    * * Make the adreno container compatible w/ and w/o CLML SDK availability
    
    Co-authored-by: Siva Rama Krishna Reddy B <si...@qti.qualcomm.com>
---
 cmake/modules/contrib/CLML.cmake                 |  16 +-
 python/tvm/relay/op/contrib/clml.py              |  58 ++--
 src/relay/backend/contrib/clml/codegen.cc        |   2 +-
 src/runtime/contrib/clml/clml_runtime.cc         |  38 ++-
 tests/python/contrib/test_clml/infrastructure.py |  58 +++-
 tests/python/contrib/test_clml/test_network.py   |  15 +-
 tests/python/contrib/test_clml/test_ops.py       | 377 +++++++++++++++++++++--
 tests/scripts/task_build_adreno_bins.sh          |   6 +-
 tests/scripts/task_config_build_adreno.sh        |   4 +-
 9 files changed, 482 insertions(+), 92 deletions(-)

diff --git a/cmake/modules/contrib/CLML.cmake b/cmake/modules/contrib/CLML.cmake
index e86a7e1ae0..811b8f8d58 100644
--- a/cmake/modules/contrib/CLML.cmake
+++ b/cmake/modules/contrib/CLML.cmake
@@ -22,7 +22,21 @@ if(USE_CLML)
     if(NOT USE_CLML_GRAPH_EXECUTOR)
         list(APPEND COMPILER_SRCS ${CLML_RUNTIME_MODULE})
     endif()
-    message(STATUS "Build with CLML support...")
+    message(STATUS "Build with CLML support : " ${USE_CLML})
+    if (NOT USE_CLML STREQUAL "ON")
+        set(CLML_VERSION_HEADER "${USE_CLML}/CL/cl_qcom_ml_ops.h")
+        if(EXISTS ${CLML_VERSION_HEADER})
+            file(READ ${CLML_VERSION_HEADER} ver)
+            string(REGEX MATCH "CL_QCOM_ML_OPS_H_MAJOR_VERSION ([0-9]*)" _ ${ver})
+            set(CLML_VERSION_MAJOR ${CMAKE_MATCH_1})
+        else()
+            set(CLML_VERSION_MAJOR "2")
+        endif()
+    else()
+        set(CLML_VERSION_MAJOR "2")
+    endif()
+    add_definitions(-DTVM_CLML_VERSION=${CLML_VERSION_MAJOR})
+    message(STATUS "CLML SDK Version :" ${CLML_VERSION_MAJOR})
 endif()
 
 if(USE_CLML_GRAPH_EXECUTOR)
diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 6453b8a06c..02e4f62bed 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -28,6 +28,12 @@ from .register import register_pattern_table
 from ..strategy.generic import is_depthwise_conv2d
 
 
+def clml_sdk_version():
+    """Utility function to get clml version version"""
+
+    return tvm.support.libinfo().get("TVM_CLML_VERSION", 2)
+
+
 def is_clml_runtime_enabled():
     """Check if the CLML graph runtime is present.
 
@@ -92,38 +98,35 @@ def preprocess_module(mod):
     preprocessed_mod : The processed module.
     """
 
-    def convert_layout_conv2d(conv2d_function):
-        def convert_conv(attrs, inputs, tinfos, desired_layouts):
-            new_attrs = dict(attrs)
-            data_info = tinfos[0]
-            weight_info = tinfos[1]
-            desired_data_layout, desired_kernel_layout = map(str, desired_layouts)
-            new_attrs["data_layout"] = desired_data_layout
-            new_attrs["kernel_layout"] = desired_kernel_layout
-
-            if is_depthwise_conv2d(
-                data_info.shape,
-                attrs["data_layout"],
-                weight_info.shape,
-                attrs["kernel_layout"],
-                attrs["groups"],
-            ):
-                dkl = desired_kernel_layout
-                new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
-            return conv2d_function(*inputs, **new_attrs)
-
-        return convert_conv
-
-    with OpAttrContext(
-        "nn.conv2d", "FTVMConvertOpLayout", convert_layout_conv2d(tvm.relay.nn.conv2d)
-    ):
+    def alter_conv(attrs, inputs, tinfos, out_type):
+        new_attrs = dict(attrs)
+        data_info = tinfos[0]
+        weight_info = tinfos[1]
+        (desired_data_layout, desired_kernel_layout) = ("NCHW", "OIHW")
+        new_attrs["data_layout"] = desired_data_layout
+        new_attrs["kernel_layout"] = desired_kernel_layout
+
+        if is_depthwise_conv2d(
+            data_info.shape,
+            attrs["data_layout"],
+            weight_info.shape,
+            attrs["kernel_layout"],
+            attrs["groups"],
+        ):
+            dkl = desired_kernel_layout
+            new_attrs["kernel_layout"] = dkl[1] + dkl[0] + dkl[2] + dkl[3]
+        return relay.nn.conv2d(*inputs, **new_attrs)
+
+    with OpAttrContext("nn.conv2d", "FTVMAlterOpLayout", alter_conv):
         seq = tvm.transform.Sequential(
             [
                 transform.ConvertLayout({"nn.conv2d": ["NCHW", "OIHW"]}),
+                transform.AlterOpLayout(),
                 transform.FoldConstant(),
             ]
         )
-        preprocessed_mod = seq(mod)
+        with tvm.transform.PassContext(opt_level=3):
+            preprocessed_mod = seq(mod)
     return preprocessed_mod
 
 
@@ -275,6 +278,9 @@ def clml_pattern_table():
         ("clml.add", is_op("add")(wildcard(), wildcard()), check_binary_op),
         ("clml.subtract", is_op("subtract")(wildcard(), wildcard()), check_binary_op),
         ("clml.multiply", is_op("multiply")(wildcard(), wildcard()), check_binary_op),
+        ("clml.divide", is_op("divide")(wildcard(), wildcard()), check_binary_op),
+        ("clml.minimum", is_op("minimum")(wildcard(), wildcard()), check_binary_op),
+        ("clml.maximum", is_op("maximum")(wildcard(), wildcard()), check_binary_op),
         ("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
         ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
         ("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), check_default_op),
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index 167c48e1ba..d8ca791ad8 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -328,7 +328,7 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     const auto* dense = fn->body.as<CallNode>();
     const CallNode* bias = nullptr;
 
-    if (backend::IsOp(dense, "add")) {
+    if (backend::IsOp(dense, "add") || backend::IsOp(dense, "nn.bias_add")) {
       bias = dense;
       dense = dense->args[0].as<CallNode>();
     }
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index a667caaafc..6396fce485 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -153,13 +153,25 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(result == CL_SUCCESS) << "clQueryMLInterfaceVersionsQCOM:" << result;
 
     for (cl_uint i = 0; i < numVersions; ++i) {
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
       if (majorVersions[i] == 2) {
-        LOG(WARNING) << "CLML Version Selected:" << majorVersions[i] << " : " << majorVersions[i];
         h_ClmlIntf = clGetMLInterfaceV2QCOM(0);
-        ICHECK(h_ClmlIntf != NULL) << "clGetMLInterfaceV2QCOM:" << result;
+        LOG(WARNING) << "CLML Target version:" << majorVersions[i];
         break;
       }
+#endif
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
+      if (majorVersions[i] == 3) {
+        h_ClmlIntf = clGetMLInterfaceV3QCOM(0);
+        LOG(WARNING) << "CLML Target version:" << majorVersions[i];
+        break;
+      }
+#endif
     }
+    ICHECK(h_ClmlIntf != NULL)
+        << "clGetMLInterfaceVxQCOM:" << result
+        << " Perhaps there is mispatch between CLML SDK version to target supported version:"
+        << majorVersions[numVersions - 1];
     char* tune_flag;
     if ((tune_flag = getenv("CLML_IS_TUNNING_RUN")))
       this->is_tuning_run = std::stoi(tune_flag);
@@ -400,7 +412,7 @@ class CLMLRuntime : public JSONRuntimeBase {
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
           this->layer_.func_outs.push_back(out);
         } else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name ||
-                   "minimum" == op_name || "maximum" == op_name) {
+                   "minimum" == op_name || "maximum" == op_name || "divide" == op_name) {
           auto out = CreateBinaryLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
           this->layer_.func_outs.push_back(out);
@@ -523,7 +535,7 @@ class CLMLRuntime : public JSONRuntimeBase {
   }
 
   cl_ml_tensor_qcom DeviceMakeCLMLTensor(
-      void* pClmlIntf, cl_context context, tensor_dims_t dims,
+      cl_context context, tensor_dims_t dims,
       cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
       cl_channel_type dtype = CL_FLOAT) {
     cl_ml_tensor_qcom tensor;
@@ -531,8 +543,7 @@ class CLMLRuntime : public JSONRuntimeBase {
 
     cl_ml_tensor_desc_qcom desc = {
         dtype, layout, dims.n, dims.c, dims.h, dims.w, 0, CL_TENSOR_DIMENSIONS_4D_QCOM, { 0 }};
-    CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
-    result = clmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &tensor);
     ICHECK(tensor && result == CL_SUCCESS) << "clCreateMLTensorQCOM:" << result;
     (void)result;
     return tensor;
@@ -544,9 +555,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_int result = CL_OUT_OF_HOST_MEMORY;
     cl_mem buffer = NULL;
 
-    CLMLInterfaceV2QCOM* clmlIntf = reinterpret_cast<CLMLInterfaceV2QCOM*>(pClmlIntf);
     result =
-        clmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
+        h_ClmlIntf->clGetMLTensorMemorySizeQCOM(workspace->context, pTensorMemDesc->tensor, &size);
     ICHECK(result == CL_SUCCESS) << "clGetMLTensorMemorySizeQCOM:" << result;
 
     buffer = clCreateBuffer(workspace->context, CL_MEM_READ_WRITE, size, NULL, &result);
@@ -612,8 +622,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
 
     auto tensor_dsc = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-    tensor_dsc->tensor =
-        DeviceMakeCLMLTensor(h_ClmlIntf, workspace->context, dims, layout, cl_dtype);
+    tensor_dsc->tensor = DeviceMakeCLMLTensor(workspace->context, dims, layout, cl_dtype);
     return tensor_dsc;
   }
 
@@ -901,7 +910,6 @@ class CLMLRuntime : public JSONRuntimeBase {
     auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
                                              cl_dtype);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-    auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
 
     std::vector<std::string> windows = node.GetAttr<std::vector<std::string>>("pool_size");
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
@@ -1103,7 +1111,6 @@ class CLMLRuntime : public JSONRuntimeBase {
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
     cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
     int inputSize = input_.size();
-    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
     auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize];
     for (int i = 0; i < inputSize; i++) {
@@ -1236,6 +1243,8 @@ class CLMLRuntime : public JSONRuntimeBase {
       binary_op = CL_TENSOR_OP_SUB_QCOM;
     else if (op_name == "multiply")
       binary_op = CL_TENSOR_OP_MUL_QCOM;
+    else if (op_name == "divide")
+      binary_op = CL_TENSOR_OP_DIV_QCOM;
     else if (op_name == "minimum")
       binary_op = CL_TENSOR_OP_MIN_QCOM;
     else if (op_name == "maximum")
@@ -1260,7 +1269,12 @@ class CLMLRuntime : public JSONRuntimeBase {
 
   CachedLayer layer_;
   // CLML Context
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 2
   CLMLInterfaceV2QCOM* h_ClmlIntf = NULL;
+#endif
+#if CL_QCOM_ML_OPS_H_MAJOR_VERSION == 3
+  CLMLInterfaceV3QCOM* h_ClmlIntf = NULL;
+#endif
   cl::OpenCLWorkspace* workspace = NULL;
   cl::OpenCLThreadEntry* tentry = NULL;
   cl_ml_tuningcache_qcom tuning_cache = NULL;
diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py
index 89c22255d7..be2bbc7f8a 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -39,9 +39,9 @@ class Device:
     Configuration for CLML tests.
 
     Check tests/python/contrib/clml/ for the presence of an test_config.json file.
-    This file can be used to override the default configuration here which will attempt to run the Arm
-    Compute Library runtime tests locally if the runtime is available. Changing the configuration
-    will allow these runtime tests to be offloaded to a remote Arm device via a tracker for example.
+    This file can be used to override the default configuration here which will attempt to run the
+    Open CLML runtime tests locally if the runtime is available. Changing the configuration
+    will allow these runtime tests to be offloaded to a remote Snapdragon device via a tracker for example.
 
     Notes
     -----
@@ -101,6 +101,25 @@ class Device:
         return device
 
 
+def get_cpu_op_count(mod):
+    """Traverse graph counting ops offloaded to TVM."""
+
+    class Counter(tvm.relay.ExprVisitor):
+        def __init__(self):
+            super().__init__()
+            self.count = 0
+
+        def visit_call(self, call):
+            if isinstance(call.op, tvm.ir.Op):
+                self.count += 1
+
+            super().visit_call(call)
+
+    c = Counter()
+    c.visit(mod["main"])
+    return c.count
+
+
 def skip_codegen_test():
     """Skip test if it requires the CLML codegen and it's not present."""
     if not tvm.get_global_func("relay.ext.clml", True):
@@ -130,7 +149,6 @@ def build_and_run(
 
     try:
         libm = build_module(mod, device.target, device.target_host, params, enable_clml, tune_log)
-
         clml_modules = extract_clml_modules(libm)
         for mod in clml_modules:
             source = mod.get_source("json")
@@ -155,9 +173,9 @@ def build_and_run(
     for _ in range(no_runs):
         gen_module.run()
         out.append([gen_module.get_output(i) for i in range(outputs)])
-    time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
-    cost = time_f().mean
-    print("%g secs/iteration\n" % cost)
+    # time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
+    # cost = time_f().mean
+    # print("%g secs/iteration\n" % cost)
     return out
 
 
@@ -181,16 +199,34 @@ def extract_clml_modules(module):
 
 
 def verify_codegen(
-    module,
+    mod,
     known_good_codegen,
+    device,
+    params,
     num_clml_modules=1,
     tvm_ops=0,
-    target="llvm -mtriple=aarch64-linux-gnu",
 ):
     """Check clml codegen against a known good output."""
-    module = build_module(module, target, tvm_ops=tvm_ops, clml_partitions=num_clml_modules)
-    clml_modules = extract_clml_modules(module)
+    if isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+        mod = clml.partition_for_clml(mod, params)
+        tvm_op_count = get_cpu_op_count(mod)
+        assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
+            tvm_op_count, tvm_ops
+        )
+        partition_count = 0
+        for global_var in mod.get_global_vars():
+            if "clml" in global_var.name_hint:
+                partition_count += 1
+
+        assert (
+            num_clml_modules == partition_count
+        ), "Got {} Open CLML partitions, expected {}".format(partition_count, num_clml_modules)
+    relay.backend.te_compiler.get().clear()
 
+    module = relay.build(mod, target=device.target, target_host=device.target_host, params=params)
+    clml_modules = extract_clml_modules(module)
     assert len(clml_modules) == num_clml_modules, (
         f"The number of CLML modules produced ({len(clml_modules)}) does not "
         f"match the expected value ({num_clml_modules})."
diff --git a/tests/python/contrib/test_clml/test_network.py b/tests/python/contrib/test_clml/test_network.py
index 8d740d6dce..177359d9b1 100644
--- a/tests/python/contrib/test_clml/test_network.py
+++ b/tests/python/contrib/test_clml/test_network.py
@@ -91,13 +91,8 @@ def test_mobilenet(device, dtype):
         mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
     )
 
-    # test
-    print("OpenCL:", outputs[0].asnumpy().shape)
-    print("CLML:", outputs[1].asnumpy().shape)
-
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
 
 
@@ -134,7 +129,6 @@ def test_inception_v3(device, dtype):
 
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, atol=1e-5)
 
 
@@ -176,11 +170,10 @@ def test_resnet50v2(device, dtype):
         mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
     )
 
-    # test
-    print("OpenCL:", outputs[0].asnumpy().shape)
-    print("CLML:", outputs[1].asnumpy().shape)
-
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-
     tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
+
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index da09715fbe..c4ec260324 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -14,15 +14,23 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-"""CLML integration conv2d tests."""
+"""CLML integration operator tests."""
 
 import tvm
 import numpy as np
 from tvm import relay
+from tvm.relay.op.contrib import clml
 from tvm.relay import testing
 from tvm.ir import IRModule
 from tvm.contrib import utils
-from test_clml.infrastructure import build_and_run, Device, skip_codegen_test
+from test_clml.infrastructure import (
+    build_and_run,
+    Device,
+    skip_codegen_test,
+    verify_codegen,
+    build_module,
+    get_cpu_op_count,
+)
 import pytest
 
 
@@ -54,11 +62,8 @@ def _get_conv_model(
         shape = (shape[0], shape[1], shape[2] + padding[0] * 2, shape[3] + padding[1] * 2)
     is_depthwise = shape[1] == channels == groups
 
-    weight_format = "OIHW" if is_depthwise else "OIHW"
-    if weight_format == "IOHW":
-        weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w)
-    else:
-        weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
+    weight_format = "OIHW"
+    weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
 
     w = tvm.nd.array(np.random.uniform(-1, 1, weight_shape).astype(dtype))
     weights = relay.const(w, dtype)
@@ -77,7 +82,7 @@ def _get_conv_model(
     )
     params = {"w": w}
     if has_bias:
-        bias_shape = weight_shape[2] if is_depthwise else weight_shape[0]
+        bias_shape = (weight_shape[0],)
         b = tvm.nd.array(np.random.uniform(-1, 1, bias_shape).astype(dtype))
         biasc = relay.const(b, dtype)
         out = relay.nn.bias_add(out, biasc, axis=1)
@@ -86,31 +91,121 @@ def _get_conv_model(
     if has_activation:
         out = relay.nn.relu(out)
 
-    print("Out:", out)
-
     return out, params
 
 
+def _get_conv_expected_codegen(
+    shape,
+    kernel_h,
+    kernel_w,
+    padding,
+    strides,
+    dilation,
+    groups,
+    dtype,
+    channels,
+    has_bias=False,
+    has_activation=False,
+):
+    if len(padding) == 2:
+        padding = (padding[0], padding[1], padding[0], padding[1])
+    output_height = ((shape[2] - kernel_h + padding[0] + padding[2]) / strides[0]) + 1
+    output_width = ((shape[3] - kernel_w + padding[1] + padding[3]) / strides[1]) + 1
+    output_shape = (1, channels, int(output_height), int(output_width))
+    out_dtype = dtype
+    is_depthwise = shape[1] == channels == groups
+
+    weight_format = "IOHW" if is_depthwise else "OIHW"
+    if weight_format == "OIHW":
+        weight_shape = (channels, shape[1] // groups, kernel_h, kernel_w)
+    else:
+        weight_shape = (shape[1] // groups, channels, kernel_h, kernel_w)
+
+    if is_depthwise:
+        name = "nn.depthwise_conv2d"
+    else:
+        name = "nn.conv2d"
+
+    node = {
+        "op": "kernel",
+        "name": name,
+        "inputs": [],
+        "attrs": {
+            "groups": [[str(groups)]],
+            "num_outputs": "1",
+            "data_layout": [["NCHW"]],
+            "kernel_layout": [[weight_format]],
+            "channels": [[str(channels)]],
+            "dilation": [[str(dilation[0]), str(dilation[1])]],
+            "out_layout": [[""]],
+            "out_dtype": [[out_dtype]],
+            "kernel_size": [[str(kernel_h), str(kernel_w)]],
+            "shape": [[list(output_shape)]],
+            "dtype": [[dtype]],
+            "padding": [[str(p) for p in padding]],
+            "strides": [[str(s) for s in strides]],
+        },
+    }
+
+    if has_activation:
+        node["attrs"]["activation_type"] = [["relu"]]
+
+    inputs = [
+        {"op": "input", "name": "", "attrs": {"shape": [[list(shape)]], "dtype": [[str(dtype)]]}},
+        {
+            "op": "const",
+            "name": "",
+            "attrs": {"shape": [[list(weight_shape)]], "dtype": [[str(dtype)]]},
+        },
+    ]
+
+    if has_bias:
+        bias_dtype = dtype
+        inputs.append(
+            {
+                "op": "const",
+                "name": "",
+                "attrs": {
+                    "shape": [[[1, weight_shape[1] if is_depthwise else weight_shape[0], 1, 1]]],
+                    "dtype": [[bias_dtype]],
+                },
+            }
+        )
+
+    input_idx = 0
+    for _ in range(len(inputs)):
+        node["inputs"].append([input_idx, 0, 0])
+        input_idx += 1
+    node["attrs"]["num_inputs"] = str(len(inputs))
+    inputs.append(node)
+    return inputs
+
+
 @pytest.mark.parametrize("dtype", ["float32"])
 @tvm.testing.requires_openclml
 def test_conv2d(device, dtype):
     trials = [
         # Normal convolution
-        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True)],
-        # Normal convolution
-        [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False)],
-        [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False)],
-        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True)],
-        [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False)],
-        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False)],
-        [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False)],
-        [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True)],
+        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True), False],
+        [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), False],
+        [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, False), False],
+        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), False],
+        [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), False],
+        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), False],
+        [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), False],
+        [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True), False],
+        # Depth-wise convolution
+        [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), True],
+        [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), True],
+        [3, 3, (2, 2), (2, 2), (1, 1), 14, (14, 10, 10), (False, False, False), True],
+        [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, False), True],
+        [3, 3, (1, 1), (2, 2), (1, 1), 14, (14, 10, 10), (False, True, True), True],
     ]
 
     for (
@@ -122,9 +217,13 @@ def test_conv2d(device, dtype):
         out_channels,
         shape,
         composite,
+        is_depthwise,
     ) in trials:
         shape = (1, *shape)
-        groups = 1
+        if is_depthwise:
+            groups = shape[1]
+        else:
+            groups = 1
         outputs = []
         inputs = {
             "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype(dtype)),
@@ -151,11 +250,19 @@ def test_conv2d(device, dtype):
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-5, atol=1e-5
         )
+        args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
+        exp_codegen = _get_conv_expected_codegen(
+            *args, has_bias=composite[1], has_activation=composite[2]
+        )
+        verify_codegen(func, exp_codegen, device, params)
 
 
 @pytest.mark.parametrize("dtype", ["float16"])
 @tvm.testing.requires_openclml
-def _test_batchnorm(device, dtype):
+def test_batchnorm(device, dtype):
+    if tvm.support.libinfo().get("TVM_CLML_VERSION", 2) < 3:
+        print("Skip due to unsupported CLML version")
+        return
     in_shape = (1, 8, 64, 64)
     channels = 8
 
@@ -211,11 +318,80 @@ def test_concat(device, dtype):
     tvm.testing.assert_allclose(
         clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
     )
+    exp_codegen = [
+        {
+            "attrs": {
+                "dtype": [[dtype]],
+                "shape": [[list(in_shape_1)]],
+            },
+            "name": "",
+            "op": "input",
+        },
+        {
+            "attrs": {
+                "dtype": [[dtype]],
+                "shape": [[list(in_shape_2)]],
+            },
+            "name": "",
+            "op": "input",
+        },
+        {
+            "attrs": {
+                "axis": [["1"]],
+                "dtype": [[dtype]],
+                "num_inputs": "2",
+                "num_outputs": "1",
+                "shape": [[list(clml_out[0].shape)]],
+            },
+            "inputs": [[0, 0, 0], [1, 0, 0]],
+            "name": "concatenate",
+            "op": "kernel",
+        },
+    ]
+    verify_codegen(func, exp_codegen, device, params)
+
+
+def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, pool_type, dtype):
+    import math
+
+    pool_height = math.floor(((input_shape[2] + padding[2] - pool_size[0]) / stride[0]) + 1)
+    pool_width = math.floor(((input_shape[3] + padding[3] - pool_size[1]) / stride[1]) + 1)
+    output_shape = [input_shape[0], input_shape[1], pool_height, pool_width]
+    attrs = {
+        "ceil_mode": [["0"]],
+        "dilation": [["1", "1"]],
+        "layout": [["NCHW"]],
+        "num_inputs": "1",
+        "num_outputs": "1",
+        "out_layout": [[""]],
+        "padding": [[str(p) for p in padding]],
+        "pool_size": [[str(p) for p in pool_size]],
+        "shape": [[list(output_shape)]],
+        "dtype": [[dtype]],
+        "strides": [[str(s) for s in stride]],
+    }
+    if sum(padding):
+        attrs["count_include_pad"] = [["0"]]
+
+    exp_codegen = [
+        {
+            "op": "input",
+            "name": "",
+            "attrs": {"shape": [[list(input_shape)]], "dtype": [[str(dtype)]]},
+        },
+        {
+            "op": "kernel",
+            "name": "nn.avg_pool2d" if pool_type == "avg" else "nn.max_pool2d",
+            "inputs": [[0, 0, 0]],
+            "attrs": attrs,
+        },
+    ]
+    return exp_codegen
 
 
 @pytest.mark.parametrize("dtype", ["float16"])
 @tvm.testing.requires_openclml
-def test_avgpool(device, dtype):
+def test_pool(device, dtype):
     trials = [
         # input size         pool_size stride  paading
         [(1, 64, 147, 147), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
@@ -251,7 +427,152 @@ def test_avgpool(device, dtype):
 
         opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
         clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
+        exp_codegen = _get_pool_expected_codegen(*args)
+        verify_codegen(func, exp_codegen, device, params)
+
 
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_dense(device, dtype):
+    """Exercise nn.dense through the CLML path, without and then with a bias_add.
+
+    Each case compares the CLML result with the plain OpenCL result and checks the
+    generated codegen against a hand-built expectation.
+    """
+
+    def _tensor_node(kind, shape):
+        # Expected codegen entry for a tensor of the given shape; kind is "input" or "const".
+        return {"attrs": {"dtype": [[dtype]], "shape": [[list(shape)]]}, "name": "", "op": kind}
+
+    def _get_model(x_shape, k_shape, has_bias=False):
+        """Build the dense expression and everything needed to run and check it.
+
+        Returns ``(out, params, inputs, exp_codegen)``: the relay expression, the bound
+        weights, the runtime inputs and the expected codegen node list.
+        """
+        units = k_shape[0]
+        data = relay.var("x", shape=(x_shape), dtype=dtype)
+        weight = relay.var("kernel", shape=(k_shape), dtype=dtype)
+        out = relay.nn.dense(data, weight, units=units)
+        params = {"kernel": tvm.nd.array(np.random.uniform(-1, 1, k_shape).astype(dtype))}
+        inputs = {"x": tvm.nd.array(np.random.uniform(-1, 1, x_shape).astype(dtype))}
+        exp_codegen = [_tensor_node("input", x_shape), _tensor_node("const", k_shape)]
+        # The first field of each entry is the position of a node in exp_codegen.
+        kernel_inputs = [[0, 0, 0], [1, 0, 0]]
+        if has_bias:
+            bias = relay.var("bias", shape=(units,), dtype=dtype)
+            out = relay.nn.bias_add(out, bias)
+            # The bias is expected in codegen as a (1, units) constant feeding the kernel.
+            exp_codegen.append(_tensor_node("const", (1, units)))
+            params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, (units,)).astype(dtype))
+            kernel_inputs.append([2, 0, 0])
+
+        exp_codegen.append(
+            {
+                "attrs": {
+                    "num_inputs": str(len(kernel_inputs)),
+                    "num_outputs": "1",
+                    "dtype": [[dtype]],
+                    "out_dtype": [[""]],
+                    "shape": [[[x_shape[0], units]]],
+                    "units": [[str(units)]],
+                },
+                "inputs": kernel_inputs,
+                "name": "nn.dense",
+                "op": "kernel",
+            }
+        )
+        return out, params, inputs, exp_codegen
+
+    def _verify(out, params, inputs, exp_codegen):
+        """Run on OpenCL and on CLML, compare the outputs, then check the codegen."""
+        mod = IRModule.from_expr(out)
+        opencl_out, clml_out = (
+            build_and_run(mod, inputs, 1, params, device, enable_clml=flag)[0]
+            for flag in (False, True)
+        )
         tvm.testing.assert_allclose(
             clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
         )
+        verify_codegen(out, exp_codegen, device, params)
+
+    # Same shapes for both cases; only the presence of the bias differs.
+    for has_bias in (False, True):
+        _verify(*_get_model((1, 16), (32, 16), has_bias))
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_binary_ops(device, dtype):
+    """Element-wise binary ops must match the OpenCL result and be fully offloaded to CLML."""
+
+    def _get_model(a_shape, b_shape, op):
+        # Two graph inputs with random data and no bound params.
+        lhs = relay.var("a", shape=(a_shape), dtype=dtype)
+        rhs = relay.var("b", shape=(b_shape), dtype=dtype)
+        a_data = np.random.uniform(-1, 1, a_shape).astype(dtype)
+        b_data = np.random.uniform(-1, 1, b_shape).astype(dtype)
+        return op(lhs, rhs), {}, {"a": tvm.nd.array(a_data), "b": tvm.nd.array(b_data)}
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        # Reference run on plain OpenCL first, then the CLML run.
+        opencl_out, clml_out = (
+            build_and_run(mod, inputs, 1, params, device, enable_clml=flag)[0]
+            for flag in (False, True)
+        )
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        # After partitioning, nothing may be left for TVM's own (non-CLML) compute path.
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            partitioned = clml.partition_for_clml(mod, params)
+            native_ops = get_cpu_op_count(partitioned)
+            assert native_ops == 0, "Got {} TVM Native Compute partitions, expected 0".format(
+                native_ops
+            )
+
+    # Every op is tried on two identically shaped (1, 16) operands.
+    ops = [relay.add, relay.subtract, relay.multiply, relay.divide, relay.minimum, relay.maximum]
+    for op in ops:
+        _verify(*_get_model((1, 16), (1, 16), op))
+
+
+@pytest.mark.parametrize("dtype", ["float32"])
+@tvm.testing.requires_openclml
+def test_unary_ops(device, dtype):
+    """Unary ops (softmax, relu) must match the OpenCL result and be fully offloaded to CLML."""
+
+    def _get_model(a_shape, op):
+        data = relay.var("a", shape=(a_shape), dtype=dtype)
+        return op(data), {}, {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))}
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        opencl_out, clml_out = (
+            build_and_run(mod, inputs, 1, params, device, enable_clml=flag)[0]
+            for flag in (False, True)
+        )
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+        # After partitioning, nothing may be left for TVM's own (non-CLML) compute path.
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            partitioned = clml.partition_for_clml(mod, params)
+            native_ops = get_cpu_op_count(partitioned)
+            assert native_ops == 0, "Got {} TVM Native Compute partitions, expected 0".format(
+                native_ops
+            )
+
+    for op in (relay.nn.softmax, relay.nn.relu):
+        _verify(*_get_model((1, 16), op))
+
+if __name__ == "__main__":
+    tvm.testing.main()
diff --git a/tests/scripts/task_build_adreno_bins.sh b/tests/scripts/task_build_adreno_bins.sh
index 6b43d7cbc4..187ca7f815 100755
--- a/tests/scripts/task_build_adreno_bins.sh
+++ b/tests/scripts/task_build_adreno_bins.sh
@@ -29,8 +29,12 @@ cd ${output_directory}
 cp ../cmake/config.cmake .
 
 echo set\(USE_MICRO OFF\) >> config.cmake
-echo set\(USE_CLML ON\) >> config.cmake
+if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then # CLML header found under ADRENO_OPENCL: point the CLML options at it
+echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
 echo set\(USE_CLML_GRAPH_EXECUTOR "${ADRENO_OPENCL}"\) >> config.cmake
+else # header not found: leave CLML unset and only turn on USE_OPENCL
+echo set\(USE_OPENCL ON\) >> config.cmake
+fi
 echo set\(USE_RPC ON\) >> config.cmake
 echo set\(USE_CPP_RPC ON\) >> config.cmake
 echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
diff --git a/tests/scripts/task_config_build_adreno.sh b/tests/scripts/task_config_build_adreno.sh
index d45c5e8b7d..d378b5f842 100755
--- a/tests/scripts/task_config_build_adreno.sh
+++ b/tests/scripts/task_config_build_adreno.sh
@@ -24,7 +24,9 @@ cd "$BUILD_DIR"
 cp ../cmake/config.cmake .
 
 echo set\(USE_OPENCL ON\) >> config.cmake
-echo set\(USE_CLML ON\) >> config.cmake
+if [ -f "${ADRENO_OPENCL}/CL/cl_qcom_ml_ops.h" ] ; then # set USE_CLML only when the CLML header exists under ADRENO_OPENCL
+echo set\(USE_CLML "${ADRENO_OPENCL}"\) >> config.cmake
+fi
 echo set\(USE_RPC ON\) >> config.cmake
 echo set\(USE_GRAPH_EXECUTOR ON\) >> config.cmake
 echo set\(USE_LIBBACKTRACE AUTO\) >> config.cmake