You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by le...@apache.org on 2022/09/14 09:48:09 UTC

[tvm] branch main updated: [OpenCLML] More ops and network coverage (#12762)

This is an automated email from the ASF dual-hosted git repository.

leandron pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2aa0d1fbfc [OpenCLML] More ops and network coverage (#12762)
2aa0d1fbfc is described below

commit 2aa0d1fbfcf4a31e343cc6852fdc4abd660c850a
Author: Siva <qu...@quicinc.com>
AuthorDate: Wed Sep 14 15:18:03 2022 +0530

    [OpenCLML] More ops and network coverage (#12762)
    
    Added operators: pooling (avg, max), binary operators (add, subtract, multiply, minimum, maximum), and concat.
    A clip operator with a_min=0 and a_max=6 is remapped to relu6, so it benefits from CLML acceleration
    instead of being partitioned into a subgraph that runs on the fallback path.
    
    Added new test cases for the operators listed above, as well as end-to-end network test cases for
    ResNet50 and InceptionV3.
    
    CLML supports an FP16 arithmetic mode, which gives a significant performance boost over FP32. This PR
    enables FP16 usage based on each operator's data type in the Relay graph.
    
    Co-authored-by: Krishna Raju quic_kvegiraj@quicinc.com
    Co-authored-by: Shwetank Singh quic_shwesing@quicinc.com
---
 python/tvm/relay/op/contrib/clml.py              |  35 ++-
 src/relay/backend/contrib/clml/codegen.cc        |  37 +++
 src/runtime/contrib/clml/clml_runtime.cc         | 315 +++++++++++++++++++----
 tests/python/contrib/test_clml/infrastructure.py |  28 +-
 tests/python/contrib/test_clml/test_network.py   | 139 +++++++---
 tests/python/contrib/test_clml/test_ops.py       |  83 +++++-
 6 files changed, 529 insertions(+), 108 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index cacd10de28..d253544d45 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -23,7 +23,7 @@ from tvm._ffi import register_func
 from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 
-from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item
+from ...dataflow_pattern import wildcard, is_op, is_constant, is_tuple_get_item, is_tuple
 from .register import register_pattern_table
 from ..strategy.generic import is_depthwise_conv2d
 
@@ -135,6 +135,7 @@ def clml_pattern_table():
         """Create a convolution pattern."""
         pattern = is_op("nn.conv2d")(wildcard(), is_constant())
         pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
         pattern = pattern.optional(
             lambda x: is_op("nn.batch_norm")(
                 x, is_constant(), is_constant(), is_constant(), is_constant()
@@ -142,6 +143,7 @@ def clml_pattern_table():
         )
         pattern = pattern.optional(is_tuple_get_item)
         pattern = pattern.optional(is_op("nn.relu"))
+        pattern = pattern.optional(is_op("clip"))
         return pattern
 
     def batch_norm_pattern():
@@ -152,10 +154,24 @@ def clml_pattern_table():
         pattern = is_tuple_get_item(pattern)
         return pattern
 
+    def concat_pattern():
+        """Create a concat pattern.
+
+        Returns
+        -------
+        pattern : dataflow_pattern.AltPattern
+            Denotes the concat pattern.
+        """
+        pattern = is_tuple(None)
+        pattern = is_op("concatenate")(pattern)
+
+        return pattern
+
     def dense_pattern():
         """Create a dense pattern."""
         pattern = is_op("nn.dense")(wildcard(), is_constant())
         pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, is_constant()))
         return pattern
 
     def pad_pattern():
@@ -172,6 +188,13 @@ def clml_pattern_table():
             call = call.args[0]
             if isinstance(call, tvm.relay.expr.TupleGetItem):
                 call = call.tuple_value
+        elif call.op.name == "clip":
+            if call.attrs["a_min"] != 0.0 or call.attrs["a_max"] != 6.0:
+                return False
+            call = call.args[0]
+            if isinstance(call, tvm.relay.expr.TupleGetItem):
+                call = call.tuple_value
+
         while call.op.name != "nn.conv2d":
             call = call.args[0]
         attrs, args = call.attrs, call.args
@@ -194,6 +217,7 @@ def clml_pattern_table():
         ("clml.conv2d", conv_pattern(), check_conv),
         ("clml.dense", dense_pattern()),
         ("clml.pad", pad_pattern()),
+        ("clml.concat", concat_pattern()),
         ("clml.batch_norm", batch_norm_pattern()),
     ]
 
@@ -207,11 +231,18 @@ def _register_external_op_helper(op_name, supported=True):
 
 
 _register_external_op_helper("clip")
-_register_external_op_helper("relu")
+_register_external_op_helper("nn.relu")
 _register_external_op_helper("nn.global_avg_pool2d")
 _register_external_op_helper("nn.global_max_pool2d")
+_register_external_op_helper("nn.avg_pool2d")
+_register_external_op_helper("nn.max_pool2d")
 _register_external_op_helper("nn.softmax")
 _register_external_op_helper("reshape")
+_register_external_op_helper("add")
+_register_external_op_helper("subtract")
+_register_external_op_helper("multiply")
+_register_external_op_helper("minimum")
+_register_external_op_helper("maximum")
 
 
 class OpAttrContext(object):
diff --git a/src/relay/backend/contrib/clml/codegen.cc b/src/relay/backend/contrib/clml/codegen.cc
index fa082a423d..b89f05e178 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -91,6 +91,8 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
       json_node = CreateDenseJSONNode(cn);
     } else if (name == "clml.pad") {
       json_node = CreatePadJSONNode(cn);
+    } else if (name == "clml.concat") {
+      json_node = CreateConcatJSONNode(cn);
     } else {
       LOG(FATAL) << "Unrecognized CLML  pattern: " << name;
     }
@@ -148,6 +150,15 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
       } else {
         current_call = current_call->args[0].as<CallNode>();
       }
+    } else if (backend::IsOp(current_call, "clip")) {
+      nodes.activation = current_call;
+      nodes.act_type = "relu6";
+      if (current_call->args[0].as<TupleGetItemNode>()) {
+        auto tuple_item = current_call->args[0].as<TupleGetItemNode>();
+        current_call = tuple_item->tuple.as<CallNode>();
+      } else {
+        current_call = current_call->args[0].as<CallNode>();
+      }
     }
     if (backend::IsOp(current_call, "nn.batch_norm")) {
       nodes.bn = current_call;
@@ -279,6 +290,32 @@ class CLMLJSONSerializer : public backend::contrib::JSONSerializer {
     return json_node;
   }
 
+  /*!
+   * \brief Create a JSON representation of a Concat operator.
+   *
+   * \param cn The call to be represented.
+   * \return A JSON representation of a specific operator.
+   */
+  std::shared_ptr<JSONGraphNode> CreateConcatJSONNode(const CallNode* cn) {
+    const auto* fn = cn->op.as<FunctionNode>();
+    ICHECK(fn);
+    const auto* concat = fn->body.as<CallNode>();
+
+    ICHECK(backend::IsOp(concat, "concatenate"));
+    const auto* concat_op = concat->op.as<OpNode>();
+    ICHECK(concat_op);
+    const std::string name = concat_op->name;
+
+    std::vector<JSONGraphNodeEntry> inputs;
+    for (auto arg : cn->args) {
+      inputs.push_back(VisitExpr(arg)[0]);
+    }
+
+    auto json_node = std::make_shared<JSONGraphNode>(name, "kernel", inputs, 1);
+    SetCallNodeAttribute(json_node, concat);
+    return json_node;
+  }
+
   /*!
    * \brief Create a JSON representation of a Dense operator.
    *
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index da41442ef9..cdc3b9a7b5 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -335,13 +335,15 @@ class CLMLRuntime : public JSONRuntimeBase {
     size_t nid;
     for (nid = 0; nid < nodes_.size(); ++nid) {
       const auto& node = nodes_[nid];
+      DLDataType tvm_dtype = node.GetOpDataType()[0];
+      cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
       if (node.GetOpType() == "input") {
-        auto clml_input = MakeCLMLTensorFromJSONNode(node);
+        auto clml_input = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
         this->layer_.storage_map.insert({nid, std::make_pair(clml_input, node)});
         this->layer_.inputs.push_back(clml_input);
         // Input copy placeholder Tensor
         this->layer_.in_placeholder.push_back(
-            MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM));
+            MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype));
       } else if (node.GetOpType() == "kernel") {
         auto op_name = node.GetOpName();
         if ("nn.conv2d" == op_name) {
@@ -364,6 +366,11 @@ class CLMLRuntime : public JSONRuntimeBase {
           auto out = CreateBatchNormLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
           this->layer_.func_outs.push_back(out);
+        } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name ||
+                   "nn.l2_pool2d" == op_name) {
+          auto out = CreatePoolingLayer(&layer_, node);
+          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
+          this->layer_.func_outs.push_back(out);
         } else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" == op_name) {
           auto out = CreateGlobalPoolingLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
@@ -372,6 +379,10 @@ class CLMLRuntime : public JSONRuntimeBase {
           auto out = CreateReshapeLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
           this->layer_.func_outs.push_back(out);
+        } else if ("concatenate" == op_name) {
+          auto out = CreateConcatLayer(&layer_, node);
+          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
+          this->layer_.func_outs.push_back(out);
         } else if ("nn.dense" == op_name) {
           auto out = CreateDenseLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
@@ -388,6 +399,11 @@ class CLMLRuntime : public JSONRuntimeBase {
           auto out = CreateClipLayer(&layer_, node);
           this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
           this->layer_.func_outs.push_back(out);
+        } else if ("add" == op_name || "subtract" == op_name || "multiply" == op_name ||
+                   "minimum" == op_name || "maximum" == op_name) {
+          auto out = CreateBinaryLayer(&layer_, node);
+          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
+          this->layer_.func_outs.push_back(out);
         } else {
           LOG(FATAL) << "Unsupported op: " << op_name;
         }
@@ -396,10 +412,14 @@ class CLMLRuntime : public JSONRuntimeBase {
         LOG(WARNING) << "Build Engine: Unknown Node:" << node.GetOpType();
       }
     }
-    if (nid > 0) {
-      this->layer_.outputs.push_back(this->layer_.storage_map[nid - 1].first);
+
+    for (size_t i = 0; i < outputs_.size(); ++i) {
+      nid = outputs_[i].id_;
+      DLDataType tvm_dtype = nodes_[nid].GetOpDataType()[0];
+      cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+      this->layer_.outputs.push_back(this->layer_.storage_map[nid].first);
       this->layer_.out_placeholder.push_back(
-          MakeCLMLTensorFromJSONNode(nodes_[nid - 1], CL_TENSOR_LAYOUT_NCHW_QCOM));
+          MakeCLMLTensorFromJSONNode(nodes_[nid], CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype));
     }
     // ALlocate device memories and initialize the params if any
     cl_int result = 0;
@@ -558,6 +578,20 @@ class CLMLRuntime : public JSONRuntimeBase {
     }
   }
 
+  cl_arithmetic_mode_qcom MakeCLArithMode(const cl_channel_type& data_type,
+                                          const cl_channel_type& acc_type = CL_FLOAT) {
+    if (data_type == CL_FLOAT && acc_type == CL_FLOAT) {
+      return CL_ARITHMETIC_MODE_FP32_QCOM;
+    } else if (data_type == CL_HALF_FLOAT && acc_type == CL_FLOAT) {
+      return CL_ARITHMETIC_MODE_FP16_ACC32_QCOM;
+    } else if (data_type == CL_HALF_FLOAT && acc_type == CL_HALF_FLOAT) {
+      return CL_ARITHMETIC_MODE_FP16_QCOM;
+    } else {
+      LOG(FATAL) << "Datatype " << data_type << " unsupported by CLML runtime";
+      return CL_ARITHMETIC_MODE_FP32_QCOM;
+    }
+  }
+
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> MakeCLMLTensor(
       const JSONGraphNode& tensor_rep, void* data, std::vector<size_t> c_shape,
       cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_uint dtype = CL_FLOAT) {
@@ -634,6 +668,9 @@ class CLMLRuntime : public JSONRuntimeBase {
     std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
     std::vector<std::string> dilation = node.GetAttr<std::vector<std::string>>("dilation");
     std::vector<cl_uint> clml_padding = GetVectorValues(padding);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
     if (!node.HasAttr("padding")) {
       clml_padding.resize(4);
       std::fill(clml_padding.begin(), clml_padding.end(), 0);
@@ -668,7 +705,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       has_act = true;
     }
     cl_ml_op_activation_desc_qcom act_desc = {clml_act_type, CL_PROPAGATE_NAN_QCOM,
-                                              CL_ARITHMETIC_MODE_FP32_QCOM};
+                                              cl_arithmetic_mode};
 
     // Collect inputs and outputs, handling nn.conv2d.
     std::vector<JSONGraphNodeEntry> inputs = node.GetInputs();
@@ -680,15 +717,15 @@ class CLMLRuntime : public JSONRuntimeBase {
     has_bias = (num_inputs == 3) || (num_inputs == 7);
     has_bn = (num_inputs == 6) || (num_inputs == 7);
     // Input
-    auto input = MakeCLMLTensorFromJSONEntry(inputs[0]);
-
+    auto input =
+        MakeCLMLTensorFromJSONEntry(inputs[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     // Weight
-    auto weight = MakeCLMLTensorFromJSONEntry(inputs[1]);
-
+    auto weight =
+        MakeCLMLTensorFromJSONEntry(inputs[1], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     // Bias
     auto bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     if (has_bias) {
-      bias = MakeCLMLTensorFromJSONEntry(inputs[2]);
+      bias = MakeCLMLTensorFromJSONEntry(inputs[2], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     } else {
       cl_ml_tensor_desc_qcom desc = {};
       desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
@@ -698,7 +735,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       bias->tensor = layer_.unusedTensor;
     }
     // Output
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_ml_op_convolution_desc_qcom conv_desc{mode,
                                              groups,
                                              4,
@@ -707,7 +744,7 @@ class CLMLRuntime : public JSONRuntimeBase {
                                              {clml_strides[0], clml_strides[1]},
                                              {clml_dilation[0], clml_dilation[1]},
                                              0,
-                                             CL_ARITHMETIC_MODE_FP32_QCOM};
+                                             cl_arithmetic_mode};
 
     cl_ml_op_qcom op = NULL;
     if (!has_bn) {
@@ -734,13 +771,16 @@ class CLMLRuntime : public JSONRuntimeBase {
       auto bn_var = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
       auto bn_scale = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
       auto bn_bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-      bn_scale = MakeCLMLTensorFromJSONEntry(inputs[bn_index], bn_shape);
-      bn_bias = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 1], bn_shape);
-      bn_mean = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 2], bn_shape);
-      bn_var = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 3], bn_shape);
-
-      cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM,
-                                              CL_ARITHMETIC_MODE_FP32_QCOM};
+      bn_scale = MakeCLMLTensorFromJSONEntry(inputs[bn_index], bn_shape,
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+      bn_bias = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 1], bn_shape,
+                                            CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+      bn_mean = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 2], bn_shape,
+                                            CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+      bn_var = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 3], bn_shape,
+                                           CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+
+      cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode};
       if (!has_act) {
         result = h_ClmlIntf->clCreateMLOpFusedConvolutionBatchNormForwardQCOM(
             workspace->context, 0, &conv_desc, &bn_desc, input->tensor, weight->tensor,
@@ -772,11 +812,15 @@ class CLMLRuntime : public JSONRuntimeBase {
       cl_activation_function_qcom clml_act_type = CL_ACTIVATION_RELU) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     cl_ml_op_activation_desc_qcom act_desc = {clml_act_type, CL_PROPAGATE_NAN_QCOM,
-                                              CL_ARITHMETIC_MODE_FP32_QCOM};
+                                              cl_arithmetic_mode};
 
     cl_ml_tensor_desc_qcom desc = {};
     desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
@@ -805,7 +849,11 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                       const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
     int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
     auto bn_dims = get_tensor_dims(nodes_[node.GetInputs()[1].id_]);
     std::vector<size_t> bn_shape = {1, 1, 1, 1};
@@ -814,15 +862,18 @@ class CLMLRuntime : public JSONRuntimeBase {
     auto bn_var = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     auto bn_scale = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     auto bn_bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-    bn_scale = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], bn_shape);
-    bn_bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], bn_shape);
-    bn_mean = MakeCLMLTensorFromJSONEntry(node.GetInputs()[3], bn_shape);
-    bn_var = MakeCLMLTensorFromJSONEntry(node.GetInputs()[4], bn_shape);
+    bn_scale = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], bn_shape,
+                                           CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    bn_bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], bn_shape,
+                                          CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    bn_mean = MakeCLMLTensorFromJSONEntry(node.GetInputs()[3], bn_shape,
+                                          CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    bn_var = MakeCLMLTensorFromJSONEntry(node.GetInputs()[4], bn_shape,
+                                         CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
-    cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM,
-                                            CL_ARITHMETIC_MODE_FP32_QCOM};
+    cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpBatchNormForwardQCOM(
         workspace->context, 0, &bn_desc, input->tensor, bn_mean->tensor, bn_var->tensor,
@@ -834,6 +885,61 @@ class CLMLRuntime : public JSONRuntimeBase {
     return output;
   }
 
+  /*!
+   * \brief Create a creating pooling layer.
+   *
+   * \note Currently global_max_pool2d and global_avg_pool2d are supported.
+   *
+   * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function.
+   * \param node The JSON representation of the operator.
+   */
+  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreatePoolingLayer(CachedLayer* layer,
+                                                                    const JSONGraphNode& node) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = NULL;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
+
+    std::vector<std::string> windows = node.GetAttr<std::vector<std::string>>("pool_size");
+    std::vector<std::string> strides = node.GetAttr<std::vector<std::string>>("strides");
+    std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("padding");
+    std::vector<cl_uint> clml_window = GetVectorValues(windows);
+    std::vector<cl_uint> clml_stride = GetVectorValues(strides);
+    std::vector<cl_uint> clml_padding = GetVectorValues(padding);
+
+    cl_ml_op_pooling_desc_qcom pool_desc = {
+        node.GetOpName() == "nn.max_pool2d" ? CL_POOLING_MODE_MAX_QCOM
+                                            : CL_POOLING_MODE_AVERAGE_EXCLUDE_PADDING_QCOM,
+        4,  // reserved
+        {clml_padding[0], clml_padding[1]},
+        {clml_padding[2], clml_padding[3]},
+        {clml_stride[0], clml_stride[1]},
+        {clml_window[0], clml_window[1]},
+        CL_PROPAGATE_NAN_QCOM,
+        cl_arithmetic_mode,
+    };
+
+    cl_ml_tensor_desc_qcom desc = {};
+    cl_ml_tensor_qcom unusedTensor = NULL;
+    desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
+    result = h_ClmlIntf->clCreateMLTensorQCOM(workspace->context, NULL, &desc, &unusedTensor);
+    ICHECK(unusedTensor && result == CL_SUCCESS) << ":" << result;
+
+    result =
+        h_ClmlIntf->clCreateMLOpPoolingForwardQCOM(workspace->context, 0, &pool_desc, input->tensor,
+                                                   unusedTensor, output->tensor, &op, tuning_cache);
+    ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
+
+    layer_.func_ins.push_back(input);
+    layer->function.push_back(op);
+    return output;
+  }
+
   /*!
    * \brief Create a global pooling layer.
    *
@@ -846,8 +952,12 @@ class CLMLRuntime : public JSONRuntimeBase {
       CachedLayer* layer, const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto in_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
     cl_ml_op_pooling_desc_qcom pool_desc = {
         node.GetOpName() == "nn.global_max_pool2d" ? CL_POOLING_MODE_MAX_QCOM
@@ -858,7 +968,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         {1, 1},
         {in_dims.w, in_dims.h},
         CL_PROPAGATE_NAN_QCOM,
-        CL_ARITHMETIC_MODE_FP32_QCOM,
+        cl_arithmetic_mode,
     };
 
     cl_ml_tensor_desc_qcom desc = {};
@@ -887,14 +997,17 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                     const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
     auto out_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
-    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, CL_FLOAT, nullptr,
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype, nullptr,
                                              {out_dims.n, out_dims.c, 1, 1});
 
     cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM,
-                                               CL_SOFTMAX_MODE_INSTANCE_QCOM,
-                                               CL_ARITHMETIC_MODE_FP32_QCOM};
+                                               CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpSoftmaxQCOM(workspace->context, 0, &softmax_desc,
                                                  input->tensor, output->tensor, &op, tuning_cache);
@@ -915,8 +1028,12 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                 const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     std::string pad_mode = node.GetAttr<std::vector<std::string>>("pad_mode")[0];
     std::vector<std::string> padding = node.GetAttr<std::vector<std::string>>("pad_width");
@@ -936,7 +1053,7 @@ class CLMLRuntime : public JSONRuntimeBase {
         clml_pad_mode,
         {0, 0},
         {clml_padding[0], clml_padding[1], clml_padding[2], clml_padding[3], 0, 0, 0, 0},
-        CL_ARITHMETIC_MODE_FP32_QCOM};
+        cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpPadQCOM(workspace->context, 0, &pad_desc, input->tensor,
                                              output->tensor, &op, tuning_cache);
@@ -957,8 +1074,11 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                     const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     result = h_ClmlIntf->clCreateMLOpReshapeQCOM(workspace->context, 0, input->tensor,
                                                  output->tensor, &op, tuning_cache);
@@ -969,6 +1089,42 @@ class CLMLRuntime : public JSONRuntimeBase {
     return output;
   }
 
+  /*!
+   * \brief Create a concat layer.
+   *
+   *
+   * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function.
+   * \param node The JSON representation of the operator.
+   */
+  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateConcatLayer(CachedLayer* layer,
+                                                                   const JSONGraphNode& node) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = NULL;
+    std::vector<JSONGraphNodeEntry> input_ = node.GetInputs();
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    int inputSize = input_.size();
+    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize];
+    for (int i = 0; i < inputSize; i++) {
+      auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[i], {},
+                                               CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+      concatInputs[i] = input->tensor;
+    }
+    cl_ml_op_concat_desc_qcom concatDesc = {1, (cl_uint)inputSize, cl_arithmetic_mode};
+
+    result = h_ClmlIntf->clCreateMLOpConcatQCOM(workspace->context, 0, &concatDesc, concatInputs,
+                                                output->tensor, &op, tuning_cache);
+    ICHECK(op && result == CL_SUCCESS) << "Concat Error:" << result;
+
+    layer->function.push_back(op);
+
+    delete[] concatInputs;
+    return output;
+  }
+
   /*!
    * \brief Create a dense layer.
    *
@@ -980,21 +1136,27 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                   const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto inp_dims = get_tensor_dims(nodes_[node.GetInputs()[0].id_]);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {1, inp_dims.c, 1, 1},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto wt_dims = get_tensor_dims(nodes_[node.GetInputs()[1].id_]);
     bool has_bias = node.GetInputs().size() == 3 ? true : false;
-
-    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, wt_dims.n, wt_dims.c});
+    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, wt_dims.n, wt_dims.c},
+                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     if (has_bias) {
       auto bias_dims = get_tensor_dims(nodes_[node.GetInputs()[2].id_]);
-      bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], {1, bias_dims.c, 1, 1});
+      bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], {1, bias_dims.c, 1, 1},
+                                         CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     }
 
     cl_ml_op_fully_connected_desc_qcom fc_desc = {1, CL_FC_WEIGHT_TRANSFORM_TRANSPOSE_QCOM,
-                                                  CL_ARITHMETIC_MODE_FP32_QCOM};
+                                                  cl_arithmetic_mode};
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
-    auto output = MakeCLMLTensorFromJSONNode(node);
     if (has_bias) {
       result = h_ClmlIntf->clCreateMLOpFullyConnectedQCOM(
           workspace->context, 0, &fc_desc, input->tensor, weight->tensor, bias->tensor,
@@ -1021,15 +1183,17 @@ class CLMLRuntime : public JSONRuntimeBase {
                                                                  const JSONGraphNode& node) {
     cl_int result = 0;
     cl_ml_op_qcom op = NULL;
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0]);
-    auto output = MakeCLMLTensorFromJSONNode(node);
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
+                                             cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_float a_max = std::stof(node.GetAttr<std::vector<std::string>>("a_max")[0]);
     cl_float a_min = std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
 
-    cl_ml_op_clip_desc_qcom clip_desc = {CL_CLIP_BY_VALUE_QCOM,
-                                         {{a_max}, CL_FLOAT},
-                                         {{a_min}, CL_FLOAT},
-                                         CL_ARITHMETIC_MODE_FP32_QCOM};
+    cl_ml_op_clip_desc_qcom clip_desc = {
+        CL_CLIP_BY_VALUE_QCOM, {{a_max}, CL_FLOAT}, {{a_min}, CL_FLOAT}, cl_arithmetic_mode};
 
     result = h_ClmlIntf->clCreateMLOpClipQCOM(workspace->context, 0, &clip_desc, input->tensor,
                                               output->tensor, &op, tuning_cache);
@@ -1040,6 +1204,47 @@ class CLMLRuntime : public JSONRuntimeBase {
     return output;
   }
 
+  /*!
+   * \brief Create an elementwise binary layer (add, subtract, multiply, minimum or maximum).
+   *
+   * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function.
+   * \param node The JSON representation of the operator.
+   */
+  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBinaryLayer(CachedLayer* layer,
+                                                                   const JSONGraphNode& node) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = NULL;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    auto input_a = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
+                                               CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto input_b = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {},
+                                               CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    std::string op_name = node.GetOpName();
+    cl_binary_op_qcom binary_op = CL_TENSOR_OP_ADD_QCOM;
+    if (op_name == "subtract")
+      binary_op = CL_TENSOR_OP_SUB_QCOM;
+    else if (op_name == "multiply")
+      binary_op = CL_TENSOR_OP_MUL_QCOM;
+    else if (op_name == "minimum")
+      binary_op = CL_TENSOR_OP_MIN_QCOM;
+    else if (op_name == "maximum")
+      binary_op = CL_TENSOR_OP_MAX_QCOM;
+    else if (op_name != "add")  // never silently lower an unknown op to ADD
+      LOG(FATAL) << "Unsupported CLML binary op: " << op_name;
+    cl_ml_op_binary_desc_qcom add_desc = {
+        binary_op, {{1.0}, CL_FLOAT}, {{1.0}, CL_FLOAT}, {{0.0}, CL_FLOAT}, cl_arithmetic_mode};
+    result = h_ClmlIntf->clCreateMLOpBinaryQCOM(workspace->context, 0, &add_desc, input_a->tensor,
+                                                input_b->tensor, output->tensor, &op, tuning_cache);
+    ICHECK(op && result == CL_SUCCESS) << op_name << " Node Error:" << result;
+    layer_.func_ins.push_back(input_a);
+    layer_.func_ins.push_back(input_b);
+    layer->function.push_back(op);
+    return output;
+  }
+
   /*!
    * \brief The network layers represented by acl functions.
    * \note Currently only supports a single layer.
diff --git a/tests/python/contrib/test_clml/infrastructure.py b/tests/python/contrib/test_clml/infrastructure.py
index 0cf76079e8..08b11525ec 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -29,6 +29,7 @@ from tvm import rpc
 from tvm.contrib import graph_executor
 from tvm.relay.op.contrib import clml
 from tvm.contrib import utils
+from tvm import autotvm
 from tvm.autotvm.measure import request_remote
 from tvm.relay.expr_functor import ExprMutator, Call
 
@@ -144,35 +145,28 @@ def skip_codegen_test():
         return True
 
 
-def build_module(mod, target, target_host, params=None, enable_clml=True):
+def build_module(mod, target, target_host, params=None, enable_clml=True, tune_log=""):
     """Build module with option to build for CLML."""
     if isinstance(mod, tvm.relay.expr.Call):
         mod = tvm.IRModule.from_expr(mod)
 
-    with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
-        if enable_clml:
-            mod = clml.partition_for_clml(mod, params)
-        relay.backend.te_compiler.get().clear()
-        # print("Build  Mod:", mod)
-        return relay.build(mod, target=target, target_host=target_host, params=params)
+    with autotvm.apply_history_best(tune_log):
+        with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
+            if enable_clml:
+                mod = clml.partition_for_clml(mod, params)
+            relay.backend.te_compiler.get().clear()
+            return relay.build(mod, target=target, target_host=target_host, params=params)
 
 
 def build_and_run(
-    mod,
-    inputs,
-    outputs,
-    params,
-    device,
-    enable_clml=True,
-    no_runs=1,
-    config=None,
+    mod, inputs, outputs, params, device, enable_clml=True, no_runs=1, config=None, tune_log=""
 ):
     """Build and run the relay module."""
     if config is None:
         config = {}
 
     try:
-        libm = build_module(mod, device.target, device.target_host, params, enable_clml)
+        libm = build_module(mod, device.target, device.target_host, params, enable_clml, tune_log)
 
         clml_modules = extract_clml_modules(libm)
         for mod in clml_modules:
@@ -198,7 +192,7 @@ def build_and_run(
     for _ in range(no_runs):
         gen_module.run()
         out.append([gen_module.get_output(i) for i in range(outputs)])
-    time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=50)
+    time_f = gen_module.module.time_evaluator("run", device.device.cl(0), number=1)
     cost = time_f().mean
     print("%g secs/iteration\n" % cost)
     return out
diff --git a/tests/python/contrib/test_clml/test_network.py b/tests/python/contrib/test_clml/test_network.py
index 405f5782ff..95f3a45baf 100644
--- a/tests/python/contrib/test_clml/test_network.py
+++ b/tests/python/contrib/test_clml/test_network.py
@@ -25,20 +25,13 @@ import tvm
 from test_clml.infrastructure import skip_runtime_test, build_and_run, Device
 
 
-def _build_and_run_network(mod, params, inputs, data, device, atol, rtol):
+def _build_and_run_network(mod, params, inputs, data, device, atol, rtol, tvm_log=""):
     """Helper function to build and run a network."""
 
     outputs = []
     for clml in [True, False]:
         outputs.append(
-            build_and_run(
-                mod,
-                data,
-                1,
-                params,
-                device,
-                enable_clml=clml,
-            )[0]
+            build_and_run(mod, data, 1, params, device, enable_clml=clml, tune_log=tvm_log)[0][0]
         )
     return outputs
 
@@ -55,11 +48,7 @@ def _get_keras_model(keras_model, inputs_dict, data):
     def get_bottom_top_model(model, layer_name):
         layer = model.get_layer(layer_name)
         bottom_input = model.layers[0].input
-        bottom_output = bottom_input
-        for layer in model.layers:
-            bottom_output = layer(bottom_output)
-            if layer.name == layer_name:
-                break
+        bottom_output = layer.output
         bottom_model = Model(bottom_input, bottom_output)
         return bottom_model
 
@@ -81,6 +70,9 @@ def test_mobilenet():
 
     def get_model():
         from tensorflow.keras.applications import MobileNet
+        import tensorflow as tf
+
+        tf.keras.backend.clear_session()
 
         mobilenet = MobileNet(
             include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000
@@ -106,32 +98,113 @@ def test_mobilenet():
     )
 
     # test
-    print("OpenCL:", outputs[0][0].asnumpy().shape)
-    print("CLML:", outputs[1][0].asnumpy().shape)
+    print("OpenCL:", outputs[0].asnumpy().shape)
+    print("CLML:", outputs[1].asnumpy().shape)
 
-    opencl_sort = np.argsort(outputs[1][0].asnumpy()).flatten()
-    clml_sort = np.argsort(outputs[0][0].asnumpy()).flatten()
+    opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
+    clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
 
     tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
 
 
-"""
-    tvm.testing.assert_allclose(
-         ref_outputs, outputs[1][0].asnumpy(), rtol=1e-5, atol=1e-5)
-    print("OpenCL to Keras looks good")
-    tvm.testing.assert_allclose(
-         outputs[0][0].asnumpy(), outputs[1][0].asnumpy(), rtol=1e-5, atol=1e-5)
-    print("OpenCL to CLML looks good")
-    exit(0)
+def test_inception_v3():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    dtype = "float16"
+
+    def get_model():
+        from tensorflow.keras.applications import InceptionV3
+        import tensorflow as tf
+
+        tf.keras.backend.clear_session()
+
+        inceptionV3 = InceptionV3(
+            include_top=True, weights=None, input_shape=(299, 299, 3), classes=1000
+        )
+        inputs = {inceptionV3.input_names[0]: ((1, 3, 299, 299), "float16")}
+
+        data = {}
+        np.random.seed(0)
+        for name, (shape, dtype) in inputs.items():
+            if dtype == "uint8":
+                low, high = 0, 1
+            else:
+                low, high = -2, 1
+            data[name] = np.random.uniform(low, high, shape).astype(dtype)
+
+        mod, params, ref_outputs = _get_keras_model(inceptionV3, inputs, data)
+        return mod, params, inputs, data, ref_outputs
+
+    mod, params, inputs, input_data, ref_outputs = get_model()
+    outputs = _build_and_run_network(
+        mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
+    )
+
+    opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
+    clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
+
+    tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, atol=1e-5)
+
+
+def test_resnet50v2():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    dtype = "float16"
+
+    def get_model():
+        from tensorflow.keras.applications import ResNet50V2
+        import tensorflow as tf
+
+        tf.keras.backend.clear_session()
 
-    tvm.testing.assert_allclose(
-         ref_outputs.transpose(0, 3, 1, 2), outputs[1][0].asnumpy(), rtol=1e-5, atol=1e-5)
-    print("OpenCL to Keras looks good")
-    tvm.testing.assert_allclose(
-         outputs[0][0].asnumpy(), outputs[1][0].asnumpy(), rtol=1e-5, atol=1e-5)
-    print("OpenCL to CLML looks good")
-"""
+        model = ResNet50V2(include_top=True, weights=None, input_shape=(224, 224, 3), classes=1000)
+        inputs_dict = {model.input_names[0]: ((1, 3, 224, 224), "float32")}
+
+        data = {}
+        np.random.seed(0)
+
+        for name, (shape, dtype) in inputs_dict.items():
+            if dtype == "uint8":
+                low, high = 0, 1
+            else:
+                low, high = -1, 1
+            data[name] = np.random.uniform(low, high, shape).astype(dtype)
+
+        # Convert Keras graph to relay.
+        inputs = {}
+        for name, (shape, _) in inputs_dict.items():
+            inputs[name] = shape
+
+        ref_outputs = model.predict(data[model.input_names[0]].transpose(0, 2, 3, 1))
+
+        mod, params = relay.frontend.from_keras(model, inputs, layout="NCHW")
+
+        return mod, params, inputs, data, ref_outputs
+
+    mod, params, inputs, input_data, ref_outputs = get_model()
+    outputs = _build_and_run_network(
+        mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
+    )
+
+    # outputs[0] is the CLML run and outputs[1] the plain OpenCL run.
+    print("CLML:", outputs[0].asnumpy().shape)
+    print("OpenCL:", outputs[1].asnumpy().shape)
+
+    opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
+    clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
+
+    tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, atol=1e-5)
 
 
 if __name__ == "__main__":
     test_mobilenet()
+    test_resnet50v2()
+    test_inception_v3()
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index 13f49d1527..d14a5ec6e9 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -211,6 +211,87 @@ def test_batchnorm():
     )
 
 
+def test_concat():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    dtype = "float16"
+    in_shape_1 = (1, 16, 16, 16)
+    in_shape_2 = (1, 16, 16, 16)
+    a = relay.var("input_1", shape=in_shape_1, dtype=dtype)
+    b = relay.var("input_2", shape=in_shape_2, dtype=dtype)
+    low, high = -1, 1
+    inputs = {
+        "input_1": tvm.nd.array(np.random.uniform(low, high, in_shape_1).astype(dtype)),
+        "input_2": tvm.nd.array(np.random.uniform(low, high, in_shape_2).astype(dtype)),
+    }
+
+    params = {}
+    func = relay.concatenate((a, b), axis=1)
+    mod = IRModule.from_expr(func)
+
+    opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+    clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+
+    tvm.testing.assert_allclose(
+        clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+    )
+
+
+def test_avgpool():
+    Device.load("test_config.json")
+
+    if skip_runtime_test():
+        return
+
+    device = Device()
+    dtype = "float16"
+    trials = [
+        # input size         pool_size stride  padding       pool type
+        [(1, 64, 147, 147), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
+        [(1, 192, 71, 71), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
+        [(1, 288, 35, 35), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
+        [(1, 768, 17, 17), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
+        [(1, 2048, 17, 17), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
+        [(1, 192, 35, 35), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
+        [(1, 256, 35, 35), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
+        [(1, 288, 35, 35), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
+        [(1, 768, 17, 17), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
+        [(1, 1280, 8, 8), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
+    ]
+    params = {}
+    for (
+        input_shape,
+        pool_size,
+        stride,
+        padding,
+        pooling_type,
+    ) in trials:
+        a = relay.var("input_1", shape=input_shape, dtype=dtype)
+        input_arr = tvm.nd.array(np.random.uniform(-1, 1, input_shape).astype(dtype))
+        inputs = {
+            "input_1": input_arr,
+        }
+
+        if pooling_type == "max":
+            func = relay.nn.max_pool2d(a, pool_size=pool_size, strides=stride, padding=padding)
+        else:
+            func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, padding=padding)
+        mod = IRModule.from_expr(func)
+
+        opencl_out = build_and_run(mod, inputs, 1, params, device, enable_clml=False)[0]
+        clml_out = build_and_run(mod, inputs, 1, params, device, enable_clml=True)[0]
+
+        tvm.testing.assert_allclose(
+            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        )
+
+
 if __name__ == "__main__":
     test_conv2d()
-    test_batchnorm()
+    # test_batchnorm()
+    test_avgpool()
+    test_concat()