This is an automated email from the ASF dual-hosted git repository.

srk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 3a57a40c1b [RUNTIME][CLML] Fix for CLML ops and enable more test case 
(#15896)
3a57a40c1b is described below

commit 3a57a40c1ba40e1c330346905f8db72775fc9992
Author: krishnaraj36 <quic_kvegi...@quicinc.com>
AuthorDate: Wed Dec 20 13:50:00 2023 +0530

    [RUNTIME][CLML] Fix for CLML ops and enable more test case (#15896)
    
    * [RUNTIME][CLML] Fix for a few CLML ops
    
    Fixed the dense operator and enhanced the CLML network test case.
    
    * [RUNTIME][CLML] Fix for dense layer and float16
    
    Fixed the dense layer issue at the network level and improved
    coverage of the dense layer with CLML.
    Fixed a float16 crash.
    
    * Update comment for dense pattern
    
    * Fix in CLML test cases
    
    * Enable more test cases and a few fixes
    
    * Fix the import error
    
    * Fix the import error
    
    * Fix in batchnorm test case
    
    * Restructure clml test case and enable vm executor
    
    * Fix the import error in clml test network
    
    * Fix the test failure for vm tests
    
    * Update clml.py
---
 python/tvm/relay/op/contrib/clml.py              | 118 ++-
 src/relay/backend/contrib/clml/codegen.cc        |   2 +-
 src/runtime/contrib/clml/clml_runtime.cc         | 521 ++++++++-----
 tests/python/contrib/test_clml/conftest.py       |  21 +-
 tests/python/contrib/test_clml/infrastructure.py | 242 +++---
 tests/python/contrib/test_clml/test_network.py   | 249 +++---
 tests/python/contrib/test_clml/test_ops.py       | 942 +++++++++++++++++------
 tests/scripts/task_python_adreno.sh              |   1 +
 8 files changed, 1332 insertions(+), 764 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py 
b/python/tvm/relay/op/contrib/clml.py
index f194dd114b..14dd35a3cb 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -18,6 +18,7 @@
 """CLML Library supported operators."""
 import json
 from string import Template
+import numpy as np
 import tvm
 
 from tvm import relay
@@ -27,7 +28,7 @@ from tvm.relay import transform
 from tvm.relay.build_module import bind_params_by_name
 from tvm.relay import function as _function
 from tvm.relay.expr_functor import ExprMutator
-from tvm.relay.expr import Call, TupleGetItem
+from tvm.relay.expr import Call, TupleGetItem, Var, Constant
 
 from ...dataflow_pattern import wildcard, is_op, is_constant, 
is_tuple_get_item, is_tuple
 from .register import register_pattern_table
@@ -81,34 +82,61 @@ class RemoveDropoutPass:
         return RemoveDropout().visit(func)
 
 
-class BroadcastInputs(ExprMutator):
+class OptimizeBatchnorm(ExprMutator):
     """
-    Binary operators need broadcasting for CLML.
+    Fuse Conv+Batchnorm and constant folder to generate Conv+Add.
     """
 
-    def visit_call(self, call):
-        if call.op.name in ["add", "subtract", "multiply", "divide", 
"maximum", "minimum"]:
-            new_fn = self.visit(call.op)
-            call_shape = call.checked_type.shape
-            lhs = call.args[0]
-            rhs = call.args[1]
-            lhs_shape = lhs.checked_type.shape
-            rhs_shape = rhs.checked_type.shape
-            if list(call_shape) != list(lhs_shape):
-                lhs = relay.broadcast_to(self.visit(lhs), call_shape)
-            if list(call_shape) != list(rhs_shape):
-                rhs = relay.broadcast_to(self.visit(rhs), call_shape)
-            args = [lhs, rhs]
-            return Call(new_fn, args, call.attrs)
-        return super().visit_call(call)
+    def visit_call(self, call) -> relay.expr.Expr:
+        new_args = []
+        for arg in call.args:
+            if (
+                not isinstance(arg, (Var, Constant))
+                and isinstance(arg, tvm.relay.TupleGetItem)
+                and arg.tuple_value.op.name == "nn.batch_norm"
+                and (not isinstance(arg.tuple_value.args[0], (Var, Constant)))
+                and arg.tuple_value.args[0].op.name == "nn.conv2d"
+            ):
+                ep = arg.tuple_value.attrs["epsilon"]
+                wt = arg.tuple_value.args[1].data.numpy()
+                bs = arg.tuple_value.args[2].data.numpy()
+                mn = arg.tuple_value.args[3].data.numpy()
+                vr = arg.tuple_value.args[4].data.numpy() + ep
+                dino = np.sqrt(vr)
+                wt = wt / dino
+                bs = bs - mn * wt
+                conv_op = arg.tuple_value.args[0]
+                conv_args = list(conv_op.args)
+                wt_conv = conv_args[1].data.numpy()
+                if conv_op.attrs["kernel_layout"] == "OIHW":
+                    wt = wt.reshape(wt.shape[0], 1, 1, 1)
+                elif conv_op.attrs["kernel_layout"] == "IOHW":
+                    wt = wt.reshape(1, wt.shape[0], 1, 1)
+                else:
+                    raise ValueError("Unsupported Conv2d kernel layout")
+                wt_conv = wt_conv * wt
+                conv_args[1] = relay.const(tvm.nd.array(wt_conv))
+                bs_args = relay.const(tvm.nd.array(bs.reshape(-1, bs.shape[0], 
1, 1)))
+                conv_out = Call(
+                    arg.tuple_value.args[0].op, conv_args, 
arg.tuple_value.args[0].attrs
+                )
+                mod = tvm.relay.add(conv_out, bs_args)
+                new_args.append(mod)
+            else:
+                new_args.append(arg)
+
+        call = Call(call.op, new_args, call.attrs)
+        args = [self.visit(arg) for arg in call.args]
+
+        return Call(call.op, args, call.attrs)
 
 
 @transform.function_pass(opt_level=0)
-class BinaryOpBroadcaster:
+class OptimizeBatchnormPass:
     def transform_function(
         self, func: relay.function.Function, mod: tvm.IRModule, _: 
tvm.transform.PassContext
     ) -> relay.function.Function:
-        return BroadcastInputs().visit(func)
+        return OptimizeBatchnorm().visit(func)
 
 
 def partition_for_clml(mod, params=None, **opts):
@@ -134,8 +162,8 @@ def partition_for_clml(mod, params=None, **opts):
         [
             transform.InferType(),
             RemoveDropoutPass(),
-            BinaryOpBroadcaster(),
             transform.FoldConstant(),
+            OptimizeBatchnormPass(),
             transform.MergeComposite(clml_pattern_table()),
             transform.AnnotateTarget("clml", False),
             transform.MergeCompilerRegions(),
@@ -289,8 +317,15 @@ def clml_pattern_table():
 
         return pattern
 
-    def dense_pattern():
-        """Create a dense pattern."""
+    def dense1d_pattern():
+        """Create a dense pattern for 1d vector to matrix multiple."""
+        pattern = is_op("nn.dense")(wildcard(), is_constant())
+        pattern = pattern.optional(lambda x: is_op("nn.bias_add")(x, 
is_constant()))
+        pattern = pattern.optional(lambda x: is_op("add")(x, is_constant()))
+        return pattern
+
+    def dense2d_pattern():
+        """Create a dense pattern for 2d matrix to matrix multiple."""
         pattern = is_op("nn.dense")(wildcard(), is_constant())
         return pattern
 
@@ -377,6 +412,9 @@ def clml_pattern_table():
         if len(call.args[1].checked_type.shape) == 0:
             return False
 
+        if tuple(call.args[0].checked_type.shape) != 
tuple(call.args[1].checked_type.shape):
+            return False
+
         for arg in call.args:
             # Avoid any operators with dtype Int64
             if arg.checked_type.dtype == "int64":
@@ -436,11 +474,33 @@ def clml_pattern_table():
             return False
         return True
 
+    def check_dense1d_op(extract):
+        call = extract
+        # Only support single Matmul
+        if call.args[0].checked_type.shape[0] > 1:
+            return False
+        if not (call.op.name in ["nn.bias_add", "add"] and 
call.args[0].op.name == "nn.dense"):
+            return False
+        return True
+
+    def check_reshape(extract):
+        call = extract
+        call_shape = call.checked_type.shape
+        # Only support batch dim = 1
+        if call_shape[0] > 1:
+            return False
+        # Checking buffer indexing limit
+        for shape in call_shape:
+            if shape > 32768:
+                return False
+        return True
+
     return [
         ("clml.pad_conv2d", pad_conv_pattern(), check_conv),
         ("clml.conv2d", conv_pattern(), check_conv),
         ("clml.conv2d_transpose", conv_transpose_pattern(), 
check_conv_transpose),
-        ("clml.dense", dense_pattern(), check_default_op),
+        ("clml.dense1d", dense1d_pattern(), check_dense1d_op),
+        ("clml.dense2d", dense2d_pattern(), check_default_op),
         ("clml.pad", pad_pattern(), check_pad_op),
         ("clml.concat", concat_pattern(), check_concat_op),
         ("clml.batch_norm", batch_norm_pattern(), check_default_op),
@@ -451,7 +511,7 @@ def clml_pattern_table():
         ("clml.minimum", is_op("minimum")(wildcard(), wildcard()), 
check_binary_op),
         ("clml.maximum", is_op("maximum")(wildcard(), wildcard()), 
check_binary_op),
         ("clml.softmax", is_op("nn.softmax")(wildcard()), check_softmax_op),
-        # ("clml.reshape", is_op("reshape")(wildcard()), check_default_op),
+        ("clml.reshape", is_op("reshape")(wildcard()), check_reshape),
         ("clml.avg_pool2d", is_op("nn.avg_pool2d")(wildcard()), 
check_default_op),
         ("clml.max_pool2d", is_op("nn.max_pool2d")(wildcard()), 
check_default_op),
         ("clml.global_avg_pool2d", is_op("nn.global_avg_pool2d")(wildcard()), 
check_default_op),
@@ -807,7 +867,7 @@ class CLMLGetSubModuleSrc:
                         elif activation == "relu6":
                             activation = "CL_ACTIVATION_RELU6"
                         else:
-                            RuntimeError("Unknown activation:" + activation)
+                            raise RuntimeError("Unknown activation:" + 
activation)
                     has_bias = bool((node["inputs"] == 3) or (node["inputs"] 
== 7))
                     has_bn = bool((node["inputs"] == 6) or (node["inputs"] == 
7))
                     input_tensor = get_tensor_from_map(node["inputs"][0][0])
@@ -907,8 +967,8 @@ class CLMLGetSubModuleSrc:
                         )
                     )
                 elif node["name"] == "nn.batch_norm":
-                    bn_attrs = tuple(node["attrs"]["batchnorm"][0][0])
-                    axis = bn_attrs[0]
+                    bn_attrs = tuple(node["attrs"]["axis"])
+                    axis = int(bn_attrs[0][0])
                     bn_shape = [1, 1, 1, 1]
                     bn_node = self.nodes[node["inputs"][0][0]]
                     bn_shape[axis] = bn_node["attrs"]["shape"][0][0]
@@ -1094,7 +1154,7 @@ class CLMLGetSubModuleSrc:
                         )
                     )
                 else:
-                    RuntimeError("Unsupported Op:" + node["name"])
+                    raise RuntimeError("Unsupported Op:" + node["name"])
                 self.clml_code.append(
                     self.MapInsert.substitute(nid=node_out_name, 
tensor_desc=node_out_name)
                 )
diff --git a/src/relay/backend/contrib/clml/codegen.cc 
b/src/relay/backend/contrib/clml/codegen.cc
index 069e11dac5..5d6fc0c2cf 100644
--- a/src/relay/backend/contrib/clml/codegen.cc
+++ b/src/relay/backend/contrib/clml/codegen.cc
@@ -87,7 +87,7 @@ class CLMLJSONSerializer : public 
backend::contrib::JSONSerializer {
       json_node = CreateCompositeConvJSONNode(cn);
     } else if (name == "clml.batch_norm") {
       json_node = CreateBatchNormJSONNode(cn);
-    } else if (name == "clml.dense") {
+    } else if (name == "clml.dense1d" || name == "clml.dense2d") {
       json_node = CreateDenseJSONNode(cn);
     } else if (name == "clml.pad") {
       json_node = CreatePadJSONNode(cn);
diff --git a/src/runtime/contrib/clml/clml_runtime.cc 
b/src/runtime/contrib/clml/clml_runtime.cc
index 1146ff7249..aa1e2b82b6 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -512,36 +512,36 @@ class CLMLRuntime : public JSONRuntimeBase {
   /*!
    * \brief Create an CLML tensor from JSON node entry. Lookup storage map 
before creation.
    *
-   * \param tensor The tensor as Node Entry .
+   * \param nid The node index of graph JSON.
    * \param shape shape information of tensor
    * \param layout the tensor layout to be used
    * \param dtype tensor data type
    * \return CLML Tensor descriptor.
    */
   std::shared_ptr<cl_ml_tensor_memory_desc_qcom> MakeCLMLTensorFromJSONEntry(
-      const JSONGraphNodeEntry& tensor, std::vector<size_t> shape, 
cl_ml_tensor_layout_qcom layout,
-      cl_uint dtype) {
-    JSONGraphNode node = nodes_[tensor.id_];
+      size_t nid, std::vector<size_t> shape, cl_ml_tensor_layout_qcom layout, 
cl_uint dtype) {
+    const JSONGraphNode node = nodes_[nid];
 
-    if (this->layer_.storage_map.find(tensor.id_) == 
this->layer_.storage_map.end()) {
+    if (this->layer_.storage_map.find(nid) == this->layer_.storage_map.end()) {
       void* node_data = nullptr;
       if (node.GetOpType() == "const") {
-        node_data = data_entry_[EntryID(tensor)]->data;
+        uint32_t eid = EntryID(nid, 0);
+        node_data = data_entry_[eid]->data;
       }
       auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, 
node_data, shape);
-      this->layer_.storage_map.insert({tensor.id_, std::make_pair(clml_tensor, 
node)});
+      this->layer_.storage_map.insert({nid, std::make_pair(clml_tensor, 
node)});
 
       if ("input" == node.GetOpType()) {
-        this->layer_.inputs.insert({tensor.id_, clml_tensor});
+        this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].first});
         // Input copy placeholder Tensor
         this->layer_.in_placeholder.insert(
-            {tensor.id_, MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_NCHW_QCOM, dtype,
-                                                    node_data, shape)});
+            {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, 
dtype, node_data,
+                                             shape)});
       }
 
       return clml_tensor;
     } else {
-      return this->layer_.storage_map[tensor.id_].first;
+      return this->layer_.storage_map[nid].first;
     }
   }
 
@@ -553,76 +553,62 @@ class CLMLRuntime : public JSONRuntimeBase {
    */
   void BuildEngine() {
     size_t nid;
+    // Create tensors for the operators which has distinct layout format
+    // other than CL_TENSOR_LAYOUT_OPTIMAL_QCOM.
+    for (nid = 0; nid < nodes_.size(); ++nid) {
+      const auto& node = nodes_[nid];
+      if ("nn.dense" == node.GetOpName()) CreateDenseLayerTensor(&layer_, 
node, nid);
+      if ("nn.batch_matmul" == node.GetOpName()) 
CreateBatchMatmulLayerTensor(&layer_, node, nid);
+    }
+
     for (nid = 0; nid < nodes_.size(); ++nid) {
       const auto& node = nodes_[nid];
-      DLDataType tvm_dtype = node.GetOpDataType()[0];
-      cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
       if (node.GetOpType() == "input") {
         // Layers may request for different layout. Differ the input 
allocation.
       } else if (node.GetOpType() == "kernel") {
         auto op_name = node.GetOpName();
-        if ("nn.conv2d" == op_name) {
-          auto out = CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_CONVOLUTION_QCOM);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.depthwise_conv2d" == op_name) {
-          auto out = CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_DEPTHWISE_QCOM);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.conv2d_transpose" == op_name) {
-          auto out = CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_TRANSPOSE_QCOM);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.relu6" == op_name) {
-          auto out = CreateReLULayer(&layer_, node, CL_ACTIVATION_RELU6);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.relu" == op_name) {
-          auto out = CreateReLULayer(&layer_, node, CL_ACTIVATION_RELU);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.batch_norm" == op_name) {
-          auto out = CreateBatchNormLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name ||
-                   "nn.l2_pool2d" == op_name) {
-          auto out = CreatePoolingLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" 
== op_name) {
-          auto out = CreateGlobalPoolingLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("reshape" == op_name) {
-          auto out = CreateReshapeLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("concatenate" == op_name) {
-          auto out = CreateConcatLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.dense" == op_name) {
-          auto out = CreateDenseLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.softmax" == op_name) {
-          auto out = CreateSoftMaxLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.pad" == op_name) {
-          auto out = CreatePadLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.batch_flatten" == op_name) {
-          auto out = CreateBatchFlattenLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("clip" == op_name) {
-          auto out = CreateClipLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("add" == op_name || "subtract" == op_name || "multiply" == 
op_name ||
-                   "minimum" == op_name || "maximum" == op_name || "divide" == 
op_name) {
-          auto out = CreateBinaryLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.depth_to_space" == op_name) {
-          auto out = CreateDepthToSpaceLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.upsampling" == op_name) {
-          auto out = CreateResizeLayer(&layer_, node);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else if ("nn.batch_matmul" == op_name) {
-          auto out = CreateBatchMatmulLayer(&layer_, node, nid);
-          this->layer_.storage_map.insert({nid, std::make_pair(out, node)});
-        } else {
+        if ("nn.conv2d" == op_name)
+          CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_CONVOLUTION_QCOM, nid);
+        else if ("nn.depthwise_conv2d" == op_name)
+          CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_DEPTHWISE_QCOM, nid);
+        else if ("nn.conv2d_transpose" == op_name)
+          CreateConvolution2DLayer(&layer_, node, 
CL_CONVOLUTION_MODE_TRANSPOSE_QCOM, nid);
+        else if ("nn.relu6" == op_name)
+          CreateReLULayer(&layer_, node, nid, CL_ACTIVATION_RELU6);
+        else if ("nn.relu" == op_name)
+          CreateReLULayer(&layer_, node, nid, CL_ACTIVATION_RELU);
+        else if ("nn.batch_norm" == op_name)
+          CreateBatchNormLayer(&layer_, node, nid);
+        else if ("nn.max_pool2d" == op_name || "nn.avg_pool2d" == op_name ||
+                 "nn.l2_pool2d" == op_name)
+          CreatePoolingLayer(&layer_, node, nid);
+        else if ("nn.global_max_pool2d" == op_name || "nn.global_avg_pool2d" 
== op_name)
+          CreateGlobalPoolingLayer(&layer_, node, nid);
+        else if ("reshape" == op_name)
+          CreateReshapeLayer(&layer_, node, nid);
+        else if ("concatenate" == op_name)
+          CreateConcatLayer(&layer_, node, nid);
+        else if ("nn.dense" == op_name)
+          CreateDenseLayer(&layer_, node, nid);
+        else if ("nn.softmax" == op_name)
+          CreateSoftMaxLayer(&layer_, node, nid);
+        else if ("nn.pad" == op_name)
+          CreatePadLayer(&layer_, node, nid);
+        else if ("nn.batch_flatten" == op_name)
+          CreateBatchFlattenLayer(&layer_, node, nid);
+        else if ("clip" == op_name)
+          CreateClipLayer(&layer_, node, nid);
+        else if ("add" == op_name || "subtract" == op_name || "multiply" == 
op_name ||
+                 "minimum" == op_name || "maximum" == op_name || "divide" == 
op_name)
+          CreateBinaryLayer(&layer_, node, nid);
+        else if ("nn.depth_to_space" == op_name)
+          CreateDepthToSpaceLayer(&layer_, node, nid);
+        else if ("nn.upsampling" == op_name)
+          CreateResizeLayer(&layer_, node, nid);
+        else if ("nn.batch_matmul" == op_name)
+          CreateBatchMatmulLayer(&layer_, node, nid);
+        else
           LOG(FATAL) << "Unsupported op: " << op_name;
-        }
         this->layer_.layer_names.push_back(op_name);
       } else if (node.GetOpType() != "const") {
         LOG(WARNING) << "Build Engine: Unknown Node:" << node.GetOpType();
@@ -778,16 +764,20 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param mode The conv2d mode type - CL_CONVOLUTION_MODE_CONVOLUTION_QCOM
+   *                                    or CL_CONVOLUTION_MODE_DEPTHWISE_QCOM
+   *                                    or CL_CONVOLUTION_MODE_TRANSPOSE_QCOM.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateConvolution2DLayer(
-      CachedLayer* layer, const JSONGraphNode& node, cl_convolution_mode_qcom 
mode) {
+  void CreateConvolution2DLayer(CachedLayer* layer, const JSONGraphNode& node,
+                                cl_convolution_mode_qcom mode, size_t nid) {
     std::vector<std::string> padding = 
node.GetAttr<std::vector<std::string>>("padding");
     std::vector<std::string> strides = 
node.GetAttr<std::vector<std::string>>("strides");
     std::vector<std::string> dilation = 
node.GetAttr<std::vector<std::string>>("dilation");
     std::vector<cl_uint> clml_padding = GetVectorValues(padding);
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
     if (!node.HasAttr("padding")) {
       clml_padding.resize(4);
       std::fill(clml_padding.begin(), clml_padding.end(), 0);
@@ -835,14 +825,15 @@ class CLMLRuntime : public JSONRuntimeBase {
     has_bn = (num_inputs == 6) || (num_inputs == 7);
     // Input
     auto input =
-        MakeCLMLTensorFromJSONEntry(inputs[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+        MakeCLMLTensorFromJSONEntry(inputs[0].id_, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     // Weight
     auto weight =
-        MakeCLMLTensorFromJSONEntry(inputs[1], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+        MakeCLMLTensorFromJSONEntry(inputs[1].id_, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     // Bias
     auto bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     if (has_bias) {
-      bias = MakeCLMLTensorFromJSONEntry(inputs[2], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+      bias =
+          MakeCLMLTensorFromJSONEntry(inputs[2].id_, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     } else {
       cl_ml_tensor_desc_qcom desc = {};
       desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
@@ -851,7 +842,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       bias->tensor = layer_.unusedTensor;
     }
     // Output
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_ml_op_convolution_desc_qcom conv_desc{mode,
                                              groups,
                                              4,
@@ -886,13 +877,13 @@ class CLMLRuntime : public JSONRuntimeBase {
       auto bn_var = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
       auto bn_scale = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
       auto bn_bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-      bn_scale = MakeCLMLTensorFromJSONEntry(inputs[bn_index], bn_shape,
+      bn_scale = MakeCLMLTensorFromJSONEntry(inputs[bn_index].id_, bn_shape,
                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-      bn_bias = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 1], bn_shape,
+      bn_bias = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 1].id_, bn_shape,
                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-      bn_mean = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 2], bn_shape,
+      bn_mean = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 2].id_, bn_shape,
                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-      bn_var = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 3], bn_shape,
+      bn_var = MakeCLMLTensorFromJSONEntry(inputs[bn_index + 3].id_, bn_shape,
                                            CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
 
       cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, 
cl_arithmetic_mode};
@@ -912,7 +903,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       }
       layer->function.push_back(op);
     }
-    return output;
+    return;
   }
 
   /*!
@@ -920,18 +911,18 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateReLULayer(
-      CachedLayer* layer, const JSONGraphNode& node,
-      cl_activation_function_qcom clml_act_type = CL_ACTIVATION_RELU) {
+  void CreateReLULayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid,
+                       cl_activation_function_qcom clml_act_type = 
CL_ACTIVATION_RELU) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     cl_ml_op_activation_desc_qcom act_desc = {clml_act_type, 
CL_PROPAGATE_NAN_QCOM,
                                               cl_arithmetic_mode};
@@ -947,7 +938,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Activation Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -956,16 +947,16 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateBatchNormLayer(CachedLayer* layer,
-                                                                      const 
JSONGraphNode& node) {
+  void CreateBatchNormLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
     int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
     float epsilon = 
std::stof(node.GetAttr<std::vector<std::string>>("epsilon")[0]);
 
@@ -981,16 +972,16 @@ class CLMLRuntime : public JSONRuntimeBase {
     auto bn_var = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     auto bn_scale = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
     auto bn_bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
-    bn_scale = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], bn_shape,
+    bn_scale = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, bn_shape,
                                            CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-    bn_bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2], bn_shape,
+    bn_bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2].id_, bn_shape,
                                           CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-    bn_mean = MakeCLMLTensorFromJSONEntry(node.GetInputs()[3], bn_shape,
+    bn_mean = MakeCLMLTensorFromJSONEntry(node.GetInputs()[3].id_, bn_shape,
                                           CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-    bn_var = MakeCLMLTensorFromJSONEntry(node.GetInputs()[4], bn_shape,
+    bn_var = MakeCLMLTensorFromJSONEntry(node.GetInputs()[4].id_, bn_shape,
                                          CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
 
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     cl_ml_op_batchnorm_desc_qcom bn_desc = {CL_BATCHNORM_MODE_SPATIAL_QCOM, 
cl_arithmetic_mode};
 
@@ -1000,7 +991,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Batchnorm Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1010,17 +1001,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreatePoolingLayer(CachedLayer* layer,
-                                                                    const 
JSONGraphNode& node) {
+  void CreatePoolingLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     std::vector<std::string> windows = 
node.GetAttr<std::vector<std::string>>("pool_size");
     std::vector<std::string> strides = 
node.GetAttr<std::vector<std::string>>("strides");
@@ -1053,7 +1044,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1063,17 +1054,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateGlobalPoolingLayer(
-      CachedLayer* layer, const JSONGraphNode& node) {
+  void CreateGlobalPoolingLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
     cl_ml_op_pooling_desc_qcom pool_desc = {
         node.GetOpName() == "nn.global_max_pool2d" ? CL_POOLING_MODE_MAX_QCOM
@@ -1098,7 +1089,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Pooling Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1106,19 +1097,19 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateSoftMaxLayer(CachedLayer* layer,
-                                                                    const 
JSONGraphNode& node) {
+  void CreateSoftMaxLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
     auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype, nullptr,
-                                             {out_dims.n, out_dims.c, 1, 1});
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {out_dims.n, out_dims.c, 1, 
1},
+                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
 
     cl_ml_op_softmax_desc_qcom softmax_desc = 
{CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM,
                                                CL_SOFTMAX_MODE_INSTANCE_QCOM, 
cl_arithmetic_mode};
@@ -1128,7 +1119,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1136,17 +1127,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreatePadLayer(CachedLayer* 
layer,
-                                                                const 
JSONGraphNode& node) {
+  void CreatePadLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     std::string pad_mode = 
node.GetAttr<std::vector<std::string>>("pad_mode")[0];
     std::vector<std::string> padding = 
node.GetAttr<std::vector<std::string>>("pad_width");
@@ -1173,7 +1164,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Pad Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1181,23 +1172,23 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateBatchFlattenLayer(
-      CachedLayer* layer, const JSONGraphNode& node) {
+  void CreateBatchFlattenLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     result = CLML_INTF->clCreateMLOpReshapeQCOM(CLML_CTX, nullptr, 
input->tensor, output->tensor,
                                                 &op, layer_.tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1205,23 +1196,23 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateReshapeLayer(CachedLayer* layer,
-                                                                    const 
JSONGraphNode& node) {
+  void CreateReshapeLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
 
     result = CLML_INTF->clCreateMLOpReshapeQCOM(CLML_CTX, nullptr, 
input->tensor, output->tensor,
                                                 &op, layer_.tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "Reshape Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1230,21 +1221,21 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateConcatLayer(CachedLayer* layer,
-                                                                   const 
JSONGraphNode& node) {
+  void CreateConcatLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     std::vector<JSONGraphNodeEntry> input_ = node.GetInputs();
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
     int inputSize = input_.size();
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_uint axis = 
std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
     cl_ml_tensor_qcom* concatInputs = new cl_ml_tensor_qcom[inputSize];
     for (int i = 0; i < inputSize; i++) {
-      auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[i], {},
+      auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[i].id_, {},
                                                CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
       concatInputs[i] = input->tensor;
     }
@@ -1257,7 +1248,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     layer->function.push_back(op);
 
     delete[] concatInputs;
-    return output;
+    return;
   }
 
   /*!
@@ -1266,40 +1257,112 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateDenseLayer(CachedLayer* 
layer,
-                                                                  const 
JSONGraphNode& node) {
+  void CreateDenseLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    size_t num_inputs = node.GetInputs().size();
+    bool has_bias = (num_inputs == 3);
     auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
-    auto input =
-        MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype);
+    cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM;
+    bool is_vec_matmul = false;
+    if (in_dims.n == 1 && has_bias) {
+      layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM;
+      is_vec_matmul = true;
+    }
+
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, 
layout, cl_dtype);
     auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]);
-    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, 
wt_dims.n, wt_dims.c},
-                                              CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
-    cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM;
-    if (in_dims.c == wt_dims.c) {
-      b_transform = CL_GEMM_TRANSFORM_TRANSPOSE_QCOM;
+    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, {1, 1, 
wt_dims.n, wt_dims.c},
+                                              layout, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+
+    auto bias = std::make_shared<cl_ml_tensor_memory_desc_qcom>();
+    if (has_bias) {
+      bias = MakeCLMLTensorFromJSONEntry(node.GetInputs()[2].id_, {}, layout, 
cl_dtype);
+    } else {
+      cl_ml_tensor_desc_qcom desc = {};
+      desc.num_dimensions = CL_TENSOR_UNUSED_QCOM;
+      bias->tensor = layer_.unusedTensor;
     }
-    cl_ml_op_gemm_desc_qcom gemmDesc = {in_dims.n,                    // m
-                                        wt_dims.n,                    // n
-                                        wt_dims.c,                    // k
-                                        CL_GEMM_TRANSFORM_NONE_QCOM,  // A 
transform
-                                        b_transform,                  // B 
transform
-                                        {{1.0}, CL_FLOAT},            // alpha
-                                        {{0.0}, CL_FLOAT},            // beta
-                                        cl_arithmetic_mode};
 
-    result = CLML_INTF->clCreateMLOpGemmQCOM(CLML_CTX, 0, &gemmDesc, 
input->tensor, weight->tensor,
-                                             output->tensor, &op, 
layer_.tuning_cache);
-    ICHECK(op && result == CL_SUCCESS) << "Dense Error:" << result;
+    if (is_vec_matmul) {
+      cl_fc_weight_transform_qcom w_transform = 
CL_FC_WEIGHT_TRANSFORM_NONE_QCOM;
+      if (in_dims.c == wt_dims.c) w_transform = 
CL_FC_WEIGHT_TRANSFORM_TRANSPOSE_QCOM;
 
-    layer->function.push_back(op);
-    return output;
+      cl_ml_op_fully_connected_desc_qcom fc_desc{1,  // refer clml_ops.txt for 
struct
+                                                 w_transform, 
cl_arithmetic_mode};
+
+      result = CLML_INTF->clCreateMLOpFullyConnectedQCOM(CLML_CTX, nullptr, 
&fc_desc, input->tensor,
+                                                         weight->tensor, 
bias->tensor,
+                                                         output->tensor, &op, 
layer_.tuning_cache);
+      ICHECK(op && result == CL_SUCCESS) << "FC layer Error:" << result;
+      layer->function.push_back(op);
+    } else {
+      cl_gemm_transform_qcom b_transform = CL_GEMM_TRANSFORM_NONE_QCOM;
+      if (in_dims.c == wt_dims.c) b_transform = 
CL_GEMM_TRANSFORM_TRANSPOSE_QCOM;
+
+      cl_ml_op_gemm_desc_qcom gemmDesc = {in_dims.n,                    // m
+                                          wt_dims.n,                    // n
+                                          wt_dims.c,                    // k
+                                          CL_GEMM_TRANSFORM_NONE_QCOM,  // A 
transform
+                                          b_transform,                  // B 
transform
+                                          {{1.0}, CL_FLOAT},            // 
alpha
+                                          {{0.0}, CL_FLOAT},            // beta
+                                          cl_arithmetic_mode};
+
+      result =
+          CLML_INTF->clCreateMLOpGemmQCOM(CLML_CTX, 0, &gemmDesc, 
input->tensor, weight->tensor,
+                                          output->tensor, &op, 
layer_.tuning_cache);
+      ICHECK(op && result == CL_SUCCESS) << "Gemm layer Error:" << result;
+      layer->function.push_back(op);
+      if (has_bias) {
+        cl_ml_op_binary_desc_qcom binaryDesc = {CL_TENSOR_OP_ADD_QCOM,
+                                                {{1.0}, CL_FLOAT},  // alpha
+                                                {{1.0}, CL_FLOAT},  // beta
+                                                {{1.0}, CL_FLOAT},  // gamma
+                                                cl_arithmetic_mode};
+        result = CLML_INTF->clCreateMLOpBinaryQCOM(CLML_CTX, 0, &binaryDesc, 
bias->tensor,
+                                                   layer_.unusedTensor, 
output->tensor, &op,
+                                                   layer_.tuning_cache);
+        layer->function.push_back(op);
+      }
+    }
+
+    return;
+  }
+
+  /*!
+   * \brief Create a dense layer Tensors with supported layout.
+   *
+   *
+   * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
+   * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
+   */
+  void CreateDenseLayerTensor(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = nullptr;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
+    size_t num_inputs = node.GetInputs().size();
+    bool has_bias = (num_inputs == 3);
+    cl_ml_tensor_layout_qcom layout = CL_TENSOR_LAYOUT_NCHW_QCOM;
+    if (in_dims.n == 1 && has_bias) {
+      layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM;
+    }
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, 
layout, cl_dtype);
+    auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]);
+    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, {1, 1, 
wt_dims.n, wt_dims.c},
+                                              layout, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+
+    return;
   }
 
   /*!
@@ -1308,20 +1371,19 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateBatchMatmulLayer(CachedLayer* layer,
-                                                                        const 
JSONGraphNode& node,
-                                                                        int 
nid) {
+  void CreateBatchMatmulLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
     auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {in_dims.c, 
in_dims.h},
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, 
{in_dims.c, in_dims.h},
                                              CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
     auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]);
-    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {1, 1, 
wt_dims.c, wt_dims.h},
+    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, {1, 1, 
wt_dims.c, wt_dims.h},
                                               CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
 
     std::vector<int64_t> out_shape = node.GetOpShape()[0];
@@ -1330,8 +1392,8 @@ class CLMLRuntime : public JSONRuntimeBase {
     clml_out_shape.push_back(out_shape[2]);
     clml_out_shape.push_back(1);
     clml_out_shape.push_back(1);
-    auto output = MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype, nullptr,
-                                             clml_out_shape);
+    auto output =
+        MakeCLMLTensorFromJSONEntry(nid, clml_out_shape, 
CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype);
     layer->out_shapes.insert({nid, clml_out_shape});
 
     cl_bool b_transpose = 
std::stoi(node.GetAttr<std::vector<std::string>>("transpose_b")[0]);
@@ -1353,7 +1415,40 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "BatchMatmul Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
+  }
+
+  /*!
+   * \brief Create a Batch matmul layer(batch_size=1 supported) Tensors with 
supported layout.
+   *
+   *
+   * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML function.
+   * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
+   */
+  void CreateBatchMatmulLayerTensor(CachedLayer* layer, const JSONGraphNode& 
node, size_t nid) {
+    cl_int result = 0;
+    cl_ml_op_qcom op = nullptr;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto in_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, 
{in_dims.c, in_dims.h},
+                                             CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
+    auto wt_dims = GetTensorDims(nodes_[node.GetInputs()[1].id_]);
+    auto weight = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, {1, 1, 
wt_dims.c, wt_dims.h},
+                                              CL_TENSOR_LAYOUT_NCHW_QCOM, 
cl_dtype);
+
+    std::vector<int64_t> out_shape = node.GetOpShape()[0];
+    std::vector<size_t> clml_out_shape;
+    clml_out_shape.push_back(out_shape[1]);
+    clml_out_shape.push_back(out_shape[2]);
+    clml_out_shape.push_back(1);
+    clml_out_shape.push_back(1);
+    auto output =
+        MakeCLMLTensorFromJSONEntry(nid, clml_out_shape, 
CL_TENSOR_LAYOUT_NCHW_QCOM, cl_dtype);
+    layer->out_shapes.insert({nid, clml_out_shape});
+    return;
   }
 
   /*!
@@ -1361,17 +1456,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateClipLayer(CachedLayer* 
layer,
-                                                                 const 
JSONGraphNode& node) {
+  void CreateClipLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_float a_max = 
std::stof(node.GetAttr<std::vector<std::string>>("a_max")[0]);
     cl_float a_min = 
std::stof(node.GetAttr<std::vector<std::string>>("a_min")[0]);
 
@@ -1383,7 +1478,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Clip Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1391,19 +1486,19 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateBinaryLayer(CachedLayer* layer,
-                                                                   const 
JSONGraphNode& node) {
+  void CreateBinaryLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input_a = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {},
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input_a = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
                                                CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-    auto input_b = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1], {},
+    auto input_b = MakeCLMLTensorFromJSONEntry(node.GetInputs()[1].id_, {},
                                                CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     std::string op_name = node.GetOpName();
     cl_binary_op_qcom binary_op = CL_TENSOR_OP_ADD_QCOM;
     if (op_name == "subtract")
@@ -1425,7 +1520,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << op_name << " Node Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1433,17 +1528,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> CreateDepthToSpaceLayer(
-      CachedLayer* layer, const JSONGraphNode& node) {
+  void CreateDepthToSpaceLayer(CachedLayer* layer, const JSONGraphNode& node, 
size_t nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_uint block_size = 
std::stoi(node.GetAttr<std::vector<std::string>>("block_size")[0]);
 
     cl_ml_op_depthtospace_desc_qcom dtos_desc = {block_size, 
cl_arithmetic_mode};
@@ -1452,7 +1547,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "DepthToSpace Layer Error:" << 
result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
@@ -1460,17 +1555,17 @@ class CLMLRuntime : public JSONRuntimeBase {
    *
    * \param layer The CLML layer to build. Containing inputs, outputs and the 
CLML output.
    * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this 
operator.
    */
-  std::shared_ptr<cl_ml_tensor_memory_desc_qcom> 
CreateResizeLayer(CachedLayer* layer,
-                                                                   const 
JSONGraphNode& node) {
+  void CreateResizeLayer(CachedLayer* layer, const JSONGraphNode& node, size_t 
nid) {
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
-    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0], {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM,
-                                             cl_dtype);
-    auto output = MakeCLMLTensorFromJSONNode(node, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
+    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, 
cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
+                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, 
cl_dtype);
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, 
CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
     cl_bool align_corners = 
std::stoi(node.GetAttr<std::vector<std::string>>("align_corners")[0]);
 
     cl_ml_op_resize_bilinear_desc_qcom resize_desc = {align_corners, false, 
cl_arithmetic_mode};
@@ -1479,7 +1574,7 @@ class CLMLRuntime : public JSONRuntimeBase {
     ICHECK(op && result == CL_SUCCESS) << "Resize Layer Error:" << result;
 
     layer->function.push_back(op);
-    return output;
+    return;
   }
 
   /*!
diff --git a/tests/python/contrib/test_clml/conftest.py 
b/tests/python/contrib/test_clml/conftest.py
index a51fc8edf1..6b9c91ec10 100644
--- a/tests/python/contrib/test_clml/conftest.py
+++ b/tests/python/contrib/test_clml/conftest.py
@@ -15,12 +15,25 @@
 # specific language governing permissions and limitations
 # under the License.
 
-import sys
+import os
 import tvm
+from tvm import rpc
 import pytest
-from test_clml.infrastructure import Device
 
 
 @pytest.fixture(scope="session")
-def device():
-    return Device()
+def remote():
+    if (
+        "TVM_TRACKER_HOST" in os.environ
+        and "TVM_TRACKER_PORT" in os.environ
+        and "RPC_DEVICE_KEY" in os.environ
+    ):
+
+        rpc_tracker_host = os.environ["TVM_TRACKER_HOST"]
+        rpc_tracker_port = int(os.environ["TVM_TRACKER_PORT"])
+        rpc_device_key = os.environ["RPC_DEVICE_KEY"]
+        tracker = rpc.connect_tracker(rpc_tracker_host, rpc_tracker_port)
+        remote = tracker.request(rpc_device_key, priority=0, 
session_timeout=600)
+        return remote
+    else:
+        return None
diff --git a/tests/python/contrib/test_clml/infrastructure.py 
b/tests/python/contrib/test_clml/infrastructure.py
index f0a513cc17..b8ce236cdd 100644
--- a/tests/python/contrib/test_clml/infrastructure.py
+++ b/tests/python/contrib/test_clml/infrastructure.py
@@ -33,72 +33,23 @@ from tvm import autotvm
 from tvm.autotvm.measure import request_remote
 from tvm.relay.expr_functor import ExprMutator, Call
 
+"""Utils for adreno compute/schedules"""
 
-class Device:
-    """
-    Configuration for CLML tests.
-
-    Check tests/python/contrib/clml/ for the presence of an test_config.json 
file.
-    This file can be used to override the default configuration here which 
will attempt to run the
-    Open CLML runtime tests locally if the runtime is available. Changing the 
configuration
-    will allow these runtime tests to be offloaded to a remote Snapdragon 
device via a tracker for example.
-
-    Notes
-    -----
-        The test configuration will be loaded once when the class is created. 
If the configuration
-        changes between tests, any changes will not be picked up.
-
-    Parameters
-    ----------
-    device : RPCSession
-        Allows tests to connect to and use remote device.
-
-    Attributes
-    ----------
-    connection_type : str
-        Details the type of RPC connection to use. Options:
-        local - Use the local device,
-        tracker - Connect to a tracker to request a remote device,
-        remote - Connect to a remote device directly.
-    host : str
-        Specify IP address or hostname of remote target.
-    port : int
-        Specify port number of remote target.
-    target : str
-        The compilation target.
-    device_key : str
-        The device key of the remote target. Use when connecting to a remote 
device via a tracker.
-    cross_compile : str
-        Specify path to cross compiler to use when connecting a remote device 
from a non-arm platform.
-    """
-
-    connection_type = "tracker"
-    host = os.getenv("TVM_TRACKER_HOST", "localhost")
-    port = int(os.getenv("TVM_TRACKER_PORT", 9090))
-    target = "opencl"
-    target_host = "llvm -mtriple=aarch64-linux-gnu"
-    device_key = os.getenv("RPC_DEVICE_KEY", "android")
-    cross_compile = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")
-
-    def __init__(self):
-        """Keep remote device for lifetime of object."""
-        self.device = self._get_remote()
-
-    @classmethod
-    def _get_remote(cls):
-        """Get a remote (or local) device to use for testing."""
-        if cls.connection_type == "tracker":
-            device = request_remote(cls.device_key, cls.host, cls.port, 
timeout=1000)
-        elif cls.connection_type == "remote":
-            device = rpc.connect(cls.host, cls.port)
-        elif cls.connection_type == "local":
-            device = rpc.LocalSession()
-        else:
-            raise ValueError(
-                "connection_type in test_config.json should be one of: " 
"local, tracker, remote."
-            )
+import os
+import tvm
+import numpy as np
+from tvm import relay
+from tvm import autotvm
+from tvm import rpc
+from tvm.contrib import utils, ndk
+from tvm.relay import testing
+from tvm.relay.transform import recast
+from tvm.contrib import graph_runtime
+from tvm.runtime.vm import VirtualMachine
+import json
 
-        return device
+
+NDK_CROSS_COMPILER = os.getenv("TVM_NDK_CC", "aarch64-linux-android-g++")
 
 
 def get_cpu_op_count(mod):
@@ -139,78 +90,102 @@ def get_non_cpu_op_count(mod):
     return c.count
 
 
-def skip_codegen_test():
-    """Skip test if it requires the CLML codegen and it's not present."""
-    if not tvm.get_global_func("relay.ext.clml", True):
-        print("Skip because CLML codegen is not available.")
-        return True
-
+# build module run with opencl or clml target with graph executor
+def build_and_run(
+    remote,
+    mod,
+    params1,
+    inputs,
+    target="llvm",
+    enable_clml=False,
+    stat_file=None,
+):
+    if remote is None:
+        target_host = "llvm"
+    else:
+        target_host = "llvm -mtriple=arm64-linux-android"
 
-def build_module(mod, target, target_host, params=None, enable_clml=True, 
tune_log=""):
-    """Build module with option to build for CLML."""
     if isinstance(mod, tvm.relay.expr.Call):
         mod = tvm.IRModule.from_expr(mod)
 
-    with autotvm.apply_history_best(tune_log):
-        with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
+    with autotvm.apply_history_best(stat_file):
+        with tvm.transform.PassContext(opt_level=3):
             if enable_clml:
-                mod = clml.preprocess_module(mod)
-                mod = clml.partition_for_clml(mod, params)
-            relay.backend.te_compiler.get().clear()
-            return relay.build(mod, target=target, target_host=target_host, 
params=params)
-
+                mod = clml.partition_for_clml(mod, params1)
+            graph, lib, params = relay.build(
+                mod, target_host=target_host, target=target, params=params1
+            )
 
-def build_and_run(
-    mod, inputs, outputs, params, device, enable_clml=True, no_runs=1, 
config=None, tune_log=""
+    if remote is None:
+        ctx = tvm.opencl()
+        m = graph_runtime.create(graph, lib, ctx)
+    else:
+        temp = utils.tempdir()
+        dso_binary = "dev_lib_cl.so"
+        dso_binary_path = temp.relpath(dso_binary)
+        ctx = remote.cl(0)
+        lib.export_library(dso_binary_path, fcompile=ndk.create_shared)
+        remote.upload(dso_binary_path)
+        rlib = remote.load_module(dso_binary)
+        m = graph_runtime.create(graph, rlib, ctx)
+    m.set_input(**params)
+    m.set_input(**inputs)
+    m.run()
+    return m.get_output(0)
+
+
+# build module run with opencl or clml target with vm executor
+def build_and_run_vm(
+    remote,
+    mod,
+    params1,
+    inputs,
+    target="llvm",
+    enable_clml=False,
+    stat_file=None,
 ):
-    """Build and run the relay module."""
-    if config is None:
-        config = {}
-
-    try:
-        libm = build_module(mod, device.target, device.target_host, params, 
enable_clml, tune_log)
-        clml_modules = extract_clml_modules(libm)
-        for mod in clml_modules:
-            source = mod.get_source("json")
-            codegen = json.loads(source)["nodes"]
-            # remove input and const names as these cannot be predetermined
-            for node in range(len(codegen)):
-                if codegen[node]["op"] == "input" or codegen[node]["op"] == 
"const":
-                    codegen[node]["name"] = ""
-            codegen_str = json.dumps(codegen, sort_keys=True, indent=2)
-
-    except Exception as e:
-        err_msg = "The module could not be built.\n"
-        if config:
-            err_msg += f"The test failed with the following parameters: 
{config}\n"
-        err_msg += str(e)
-        raise Exception(err_msg)
-
-    lib = update_lib(libm, device.device, device.cross_compile)
-    gen_module = 
graph_executor.GraphModule(lib["default"](device.device.cl(0)))
-    gen_module.set_input(**inputs)
-    out = []
-    for _ in range(no_runs):
-        gen_module.run()
-        out.append([gen_module.get_output(i) for i in range(outputs)])
-    # time_f = gen_module.module.time_evaluator("run", device.device.cl(0), 
number=1)
-    # cost = time_f().mean
-    # print("%g secs/iteration\n" % cost)
-    return out
+    if remote is None:
+        target_host = "llvm"
+    else:
+        target_host = "llvm -mtriple=arm64-linux-android"
+
+    target_host = tvm.target.Target(target_host)
+    target = tvm.target.Target(target, target_host)
+    if isinstance(mod, relay.Function):
+        module = tvm.IRModule({})
+        module["main"] = mod
+        mod = module
+    elif isinstance(mod, tvm.relay.expr.Call):
+        mod = tvm.IRModule.from_expr(mod)
 
+    with autotvm.apply_history_best(stat_file):
+        with tvm.transform.PassContext(opt_level=3):
+            if enable_clml:
+                mod = clml.partition_for_clml(mod, params1)
+            vmc = relay.vm.compile(mod, target=target, params=params1)
 
-def update_lib(lib, device, cross_compile):
-    """Export the library to the remote/local device."""
-    lib_name = "mod.so"
-    temp = utils.tempdir()
-    lib_path = temp.relpath(lib_name)
-    if cross_compile:
-        lib.export_library(lib_path, cc=cross_compile)
+    if remote is None:
+        dev = tvm.opencl()
+        vm = VirtualMachine(vmc, dev, "naive")
     else:
-        lib.export_library(lib_path)
-    device.upload(lib_path)
-    lib = device.load_module(lib_name)
-    return lib
+        temp = utils.tempdir()
+        dso_binary = "dev_lib_cl.so"
+        dso_binary_path = temp.relpath(dso_binary)
+        dev = remote.cl(0)
+        vmc.mod.export_library(dso_binary_path, cc=NDK_CROSS_COMPILER)
+        remote.upload(dso_binary_path)
+        rlib = remote.load_module(dso_binary)
+        vm = VirtualMachine(rlib, dev, "naive")
+    inputs_data = {}
+    for key in inputs.keys():
+        inputs_data[key] = tvm.nd.array(inputs[key], dev)
+    for k, v in params1.items():
+        inputs_data[k] = tvm.nd.array(v, dev)
+    vm.set_input("main", **inputs_data)
+    vm.invoke_stateful("main")
+    out = vm.get_outputs()[0]
+
+    return out
 
 
 def extract_clml_modules(module):
@@ -219,18 +194,23 @@ def extract_clml_modules(module):
 
 
 def verify_codegen(
+    remote,
     mod,
-    known_good_codegen,
-    device,
     params,
+    known_good_codegen,
+    target="llvm",
     num_clml_modules=1,
     tvm_ops=0,
 ):
+    if remote is None:
+        target_host = "llvm"
+    else:
+        target_host = "llvm -mtriple=arm64-linux-android"
+
     """Check clml codegen against a known good output."""
     if isinstance(mod, tvm.relay.expr.Call):
         mod = tvm.IRModule.from_expr(mod)
-    with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
-        mod = clml.preprocess_module(mod)
+    with tvm.transform.PassContext(opt_level=3):
         mod = clml.partition_for_clml(mod, params)
         tvm_op_count = get_cpu_op_count(mod)
         assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected 
{}".format(
@@ -246,7 +226,7 @@ def verify_codegen(
         ), "Got {} Open CLML partitions, expected {}".format(partition_count, 
num_clml_modules)
     relay.backend.te_compiler.get().clear()
 
-    module = relay.build(mod, target=device.target, 
target_host=device.target_host, params=params)
+    module = relay.build(mod, target=target, target_host=target_host, 
params=params)
     clml_modules = extract_clml_modules(module)
     assert len(clml_modules) == num_clml_modules, (
         f"The number of CLML modules produced ({len(clml_modules)}) does not "
diff --git a/tests/python/contrib/test_clml/test_network.py 
b/tests/python/contrib/test_clml/test_network.py
index 177359d9b1..ec51510920 100644
--- a/tests/python/contrib/test_clml/test_network.py
+++ b/tests/python/contrib/test_clml/test_network.py
@@ -21,158 +21,137 @@ import numpy as np
 from tvm import relay
 from tvm.relay import testing
 from tvm.contrib import utils
-from test_clml.infrastructure import build_and_run, Device
+from test_clml.infrastructure import build_and_run, build_and_run_vm
 import pytest
 
 
-def _build_and_run_network(mod, params, inputs, data, device, atol, rtol, 
tvm_log=""):
+def _build_and_run_network(remote, mod, params, input_data, target, 
executor_type, tvm_log=""):
     """Helper function to build and run a network."""
 
     outputs = []
     for clml in [True, False]:
-        outputs.append(
-            build_and_run(mod, data, 1, params, device, enable_clml=clml, 
tune_log=tvm_log)[0][0]
-        )
+        if executor_type == "ge":
+            outputs.append(
+                build_and_run(
+                    remote,
+                    mod,
+                    params,
+                    input_data,
+                    target,
+                    enable_clml=clml,
+                    stat_file=tvm_log,
+                )
+            )
+        else:
+            outputs.append(
+                build_and_run_vm(
+                    remote,
+                    mod,
+                    params,
+                    input_data,
+                    target,
+                    enable_clml=clml,
+                    stat_file=tvm_log,
+                )
+            )
     return outputs
 
 
-def _get_keras_model(keras_model, inputs_dict, data):
-    """Convert Keras graph to relay."""
-    inputs = {}
-    for name, (shape, _) in inputs_dict.items():
-        inputs[keras_model.input_names[0]] = shape
-
-    from tensorflow.keras.layers import Input
-    from tensorflow.keras.models import Model
-
-    def get_bottom_top_model(model, layer_name):
-        layer = model.get_layer(layer_name)
-        bottom_input = model.layers[0].input
-        bottom_output = layer.output
-        bottom_model = Model(bottom_input, bottom_output)
-        return bottom_model
-
-    keras_model = get_bottom_top_model(keras_model, "predictions")
-    ref_output = keras_model.predict(data["input_1"].transpose(0, 2, 3, 1))
-
-    mod, params = relay.frontend.from_keras(keras_model, inputs, layout="NCHW")
-    return mod, params, ref_output
-
-
-@pytest.mark.parametrize("dtype", ["float16"])
-@tvm.testing.requires_openclml
-def test_mobilenet(device, dtype):
-    def get_model():
-        from tensorflow.keras.applications import MobileNet
-        import tensorflow as tf
-
-        tf.keras.backend.clear_session()
-
-        mobilenet = MobileNet(
-            include_top=True, weights=None, input_shape=(224, 224, 3), 
classes=1000
+def get_network(name, batch_size, dtype="float32"):
+    """Get the symbol definition and random weight of a network
+
+    Parameters
+    ----------
+    name: str
+        The name of the network, can be 'resnet-18', 'resnet-50', 'vgg-16', 
'inception_v3', 'mobilenet', ...
+    batch_size: int
+        batch size
+    dtype: str
+        Data type
+
+    Returns
+    -------
+    net: tvm.IRModule
+        The relay function of network definition
+    params: dict
+        The random parameters for benchmark
+    input_shape: tuple
+        The shape of input tensor
+    output_shape: tuple
+        The shape of output tensor
+    """
+    input_shape = (batch_size, 3, 224, 224)
+    output_shape = (batch_size, 1000)
+
+    if name == "mobilenet":
+        net, params = testing.mobilenet.get_workload(batch_size=batch_size, 
dtype=dtype)
+    elif name == "inception_v3":
+        input_shape = (batch_size, 3, 299, 299)
+        net, params = testing.inception_v3.get_workload(batch_size=batch_size, 
dtype=dtype)
+    elif "resnet" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.resnet.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
         )
-        inputs = {mobilenet.input_names[0]: ((1, 3, 224, 224), "float32")}
-
-        data = {}
-        np.random.seed(0)
-
-        for name, (shape, dtype) in inputs.items():
-            if dtype == "uint8":
-                low, high = 0, 1
-            else:
-                low, high = -1, 1
-            data[name] = np.random.uniform(low, high, shape).astype(dtype)
-
-        mod, params, ref_outputs = _get_keras_model(mobilenet, inputs, data)
-        return mod, params, inputs, data, ref_outputs
-
-    mod, params, inputs, input_data, ref_outputs = get_model()
-    outputs = _build_and_run_network(
-        mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
-    )
-
-    opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
-    clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-    tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, 
atol=1e-5)
-
-
-@pytest.mark.parametrize("dtype", ["float16"])
-@tvm.testing.requires_openclml
-def test_inception_v3(device, dtype):
-    def get_model():
-        from tensorflow.keras.applications import InceptionV3
-        import tensorflow as tf
-
-        tf.keras.backend.clear_session()
-
-        inceptionV3 = InceptionV3(
-            include_top=True, weights=None, input_shape=(299, 299, 3), 
classes=1000
+    elif "vgg" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.vgg.get_workload(
+            num_layers=n_layer, batch_size=batch_size, dtype=dtype
         )
-        inputs = {inceptionV3.input_names[0]: ((1, 3, 299, 299), "float16")}
-
-        data = {}
-        np.random.seed(0)
-        for name, (shape, dtype) in inputs.items():
-            if dtype == "uint8":
-                low, high = 0, 1
-            else:
-                low, high = -2, 1
-            data[name] = np.random.uniform(low, high, shape).astype(dtype)
-
-        mod, params, ref_outputs = _get_keras_model(inceptionV3, inputs, data)
-        return mod, params, inputs, data, ref_outputs
-
-    mod, params, inputs, input_data, ref_outputs = get_model()
-    outputs = _build_and_run_network(
-        mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
-    )
-
-    opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
-    clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-    tvm.testing.assert_allclose(opencl_sort[:5], clml_sort[:5], rtol=1e-5, 
atol=1e-5)
-
-
-@pytest.mark.parametrize("dtype", ["float16"])
+    elif "densenet" in name:
+        n_layer = int(name.split("-")[1])
+        net, params = testing.densenet.get_workload(
+            densenet_size=n_layer, batch_size=batch_size, dtype=dtype
+        )
+    elif "squeezenet" in name:
+        version = name.split("_v")[1]
+        net, params = testing.squeezenet.get_workload(
+            batch_size=batch_size, version=version, dtype=dtype
+        )
+    else:
+        raise ValueError("Unsupported network: " + name)
+
+    initializer = relay.testing.init.Xavier()
+    for param_name in list(params.keys()):
+        filter_data = 
np.zeros(params[param_name].shape).astype(params[param_name].dtype)
+        if len(filter_data.shape) > 1:
+            initializer("weight", filter_data)
+        else:
+            initializer("bias", filter_data)
+        params[param_name] = tvm.nd.array(filter_data)
+
+    return net, params, {"data": (input_shape, dtype)}, output_shape
+
+
+executor_type = tvm.testing.parameter("ge", "vm")
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "name",
+    [
+        "resnet-18",
+        "resnet-34",
+        "resnet-50",
+        "inception_v3",
+        "mobilenet",
+    ],
+)
 @tvm.testing.requires_openclml
-def test_resnet50v2(device, dtype):
-    def get_model():
-        from tensorflow.keras.applications import ResNet50V2
-        import tensorflow as tf
-
-        tf.keras.backend.clear_session()
-
-        model = ResNet50V2(include_top=True, weights=None, input_shape=(224, 
224, 3), classes=1000)
-        inputs_dict = {model.input_names[0]: ((1, 3, 224, 224), "float32")}
-
-        data = {}
-        np.random.seed(0)
-
-        for name, (shape, dtype) in inputs_dict.items():
-            if dtype == "uint8":
-                low, high = 0, 1
-            else:
-                low, high = -1, 1
-            data[name] = np.random.uniform(low, high, shape).astype(dtype)
-
-        """Convert Keras graph to relay."""
-        inputs = {}
-        for name, (shape, _) in inputs_dict.items():
-            inputs[model.input_names[0]] = shape
-
-        ref_outputs = model.predict(data["input_1"].transpose(0, 2, 3, 1))
-
-        mod, params = relay.frontend.from_keras(model, inputs, layout="NCHW")
-
-        return mod, params, inputs, data, ref_outputs
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_network(remote, name, dtype, target, executor_type):
+    print("Network evaluating .. " + name + " " + dtype)
+    np.random.seed(0)
+    mod, params, inputs, _ = get_network(name, 1, dtype=dtype)
+    input_data = {}
 
-    mod, params, inputs, input_data, ref_outputs = get_model()
-    outputs = _build_and_run_network(
-        mod, params, inputs, input_data, device=device, atol=1e-5, rtol=1e-5
-    )
+    for name, (shape, dtype) in inputs.items():
+        input_data[name] = np.random.uniform(-1.0, 1.0, shape).astype(dtype)
 
+    outputs = _build_and_run_network(remote, mod, params, input_data, target, 
executor_type)
     opencl_sort = np.argsort(outputs[1].asnumpy()).flatten()
     clml_sort = np.argsort(outputs[0].asnumpy()).flatten()
-    tvm.testing.assert_allclose(opencl_sort[:10], clml_sort[:10], rtol=1e-5, 
atol=1e-5)
+    tvm.testing.assert_allclose(opencl_sort[-5:], clml_sort[-5:], rtol=0, 
atol=0)
 
 
 if __name__ == "__main__":
diff --git a/tests/python/contrib/test_clml/test_ops.py 
b/tests/python/contrib/test_clml/test_ops.py
index e59a73a485..58365bf429 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -25,15 +25,47 @@ from tvm.ir import IRModule
 from tvm.contrib import utils
 from test_clml.infrastructure import (
     build_and_run,
-    Device,
-    skip_codegen_test,
+    build_and_run_vm,
     verify_codegen,
-    build_module,
-    get_cpu_op_count,
 )
 import pytest
 
 
+executor_type = tvm.testing.parameter("ge", "vm")
+
+
+def _build_and_run_network(remote, mod, params, input_data, target, 
executor_type, tvm_log=""):
+    """Helper function to build and run a network."""
+
+    outputs = []
+    for clml in [True, False]:
+        if executor_type == "ge":
+            outputs.append(
+                build_and_run(
+                    remote,
+                    mod,
+                    params,
+                    input_data,
+                    target,
+                    enable_clml=clml,
+                    stat_file=tvm_log,
+                )
+            )
+        else:
+            outputs.append(
+                build_and_run_vm(
+                    remote,
+                    mod,
+                    params,
+                    input_data,
+                    target,
+                    enable_clml=clml,
+                    stat_file=tvm_log,
+                )
+            )
+    return outputs
+
+
 def _get_conv_model(
     shape,
     kernel_h,
@@ -181,34 +213,36 @@ def _get_conv_expected_codegen(
     return inputs
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
-@tvm.testing.requires_openclml
-def test_conv2d(device, dtype):
-    trials = [
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
         # Normal convolution
-        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), 
False],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (15, 16, 12), (True, False, True), 
False],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), 
False],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, True), 
False],
-        [2, 2, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, False, False), 
False],
-        [2, 1, (2, 2), (1, 1), (1, 1), 7, (16, 12, 15), (False, False, True), 
False],
-        [3, 3, (2, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), 
False],
-        [3, 3, (1, 1), (1, 1), (1, 1), 16, (16, 12, 15), (False, False, 
False), False],
-        [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), 
False],
-        [1, 3, (1, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, True), 
False],
-        [2, 2, (2, 2), (1, 1), (1, 1), 4, (20, 20, 20), (False, True, False), 
False],
+        [3, 3, (1, 1), (1, 1), (1, 1), 4, (14, 10, 10), (False, True, False), 
False],
+        [2, 2, (1, 1), (1, 1), (1, 1), 4, (16, 10, 10), (False, False, False), 
False],
         [5, 5, (1, 1), (2, 2), (1, 1), 4, (14, 10, 10), (False, False, False), 
False],
-        [3, 3, (2, 1), (1, 1), (1, 1), 7, (20, 20, 20), (False, False, False), 
False],
-        [3, 3, (1, 1), (2, 2), (1, 1), 16, (14, 10, 10), (False, True, True), 
False],
+        [5, 5, (1, 1), (2, 2), (1, 1), 4, (16, 10, 10), (False, False, False), 
False],
+        [5, 5, (1, 1), (1, 1), (1, 1), 4, (6, 256, 256), (True, True, True), 
False],
+        [3, 3, (0, 0), (1, 1), (1, 1), 4, (4, 512, 512), (False, True, False), 
False],
+        [3, 3, (1, 1), (1, 1), (1, 1), 8, (6, 512, 512), (False, True, False), 
False],
+        [1, 3, (0, 0), (1, 1), (1, 1), 16, (16, 20, 20), (False, False, True), 
False],
+        [3, 1, (0, 0), (1, 1), (1, 1), 64, (64, 20, 20), (False, False, True), 
False],
+        # [3, 3, (1, 1), (1, 1), (1, 1), 128, (128, 16, 16), (False, True, 
False), False],
+        # [3, 3, (1, 1), (2, 2), (1, 1), 256, (128, 16, 16), (False, True, 
True), False],
         # Depth-wise convolution
-        [3, 3, (1, 1), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, True), 
True],
-        [5, 5, (2, 2), (1, 1), (1, 1), 20, (20, 20, 20), (False, True, False), 
True],
-        [3, 3, (2, 2), (2, 2), (1, 1), 14, (14, 10, 10), (False, False, 
False), True],
-        [5, 5, (0, 0), (1, 1), (1, 1), 20, (20, 20, 20), (False, False, 
False), True],
-        [3, 3, (1, 1), (2, 2), (1, 1), 14, (14, 10, 10), (False, True, True), 
True],
-    ]
+        [3, 3, (1, 1), (1, 1), (1, 1), 11, (11, 20, 20), (False, False, True), 
True],
+        [5, 5, (2, 2), (1, 1), (1, 1), 32, (32, 20, 20), (False, True, False), 
True],
+        [3, 3, (2, 2), (2, 2), (1, 1), 128, (128, 8, 8), (False, False, 
False), True],
+        [5, 5, (0, 0), (1, 1), (1, 1), 64, (64, 32, 32), (False, False, 
False), True],
+        [3, 3, (1, 1), (2, 2), (1, 1), 16, (16, 256, 256), (False, True, 
True), True],
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_conv2d(remote, dtype, target, trials, executor_type):
+    np.random.seed(0)
 
-    for (
+    (
         kernel_h,
         kernel_w,
         pad,
@@ -218,43 +252,43 @@ def test_conv2d(device, dtype):
         shape,
         composite,
         is_depthwise,
-    ) in trials:
-        shape = (1, *shape)
-        if is_depthwise:
-            groups = shape[1]
-        else:
-            groups = 1
-        outputs = []
-        inputs = {
-            "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype(dtype)),
-        }
+    ) = trials
 
-        func, params = _get_conv_model(
-            shape,
-            kernel_h,
-            kernel_w,
-            pad,
-            stride,
-            dilation,
-            groups,
-            dtype,
-            out_channels,
-            inputs,
-            has_pad=composite[0],
-            has_bias=composite[1],
-            has_activation=composite[2],
-        )
-        opencl_out = build_and_run(func, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(func, inputs, 1, params, device, 
enable_clml=True)[0]
+    shape = (1, *shape)
+    if is_depthwise:
+        groups = shape[1]
+    else:
+        groups = 1
+    outputs = []
+    inputs = {
+        "a": tvm.nd.array(np.random.uniform(-1, 1, shape).astype(dtype)),
+    }
 
-        tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-5, 
atol=1e-5
-        )
-        args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, 
dtype, out_channels)
-        exp_codegen = _get_conv_expected_codegen(
-            *args, has_bias=composite[1], has_activation=composite[2]
-        )
-        verify_codegen(func, exp_codegen, device, params)
+    func, params = _get_conv_model(
+        shape,
+        kernel_h,
+        kernel_w,
+        pad,
+        stride,
+        dilation,
+        groups,
+        dtype,
+        out_channels,
+        inputs,
+        has_pad=composite[0],
+        has_bias=composite[1],
+        has_activation=composite[2],
+    )
+    outputs = _build_and_run_network(remote, func, params, inputs, target, 
executor_type)
+    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    tvm.testing.assert_allclose(
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+    )
+    args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, 
out_channels)
+    exp_codegen = _get_conv_expected_codegen(
+        *args, has_bias=composite[1], has_activation=composite[2]
+    )
+    verify_codegen(remote, func, params, exp_codegen, target)
 
 
 def _get_conv2d_transpose_expected_codegen(
@@ -301,69 +335,75 @@ def _get_conv2d_transpose_expected_codegen(
     return exp_codegen
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
-@tvm.testing.requires_openclml
-def test_conv2d_transpose(device, dtype):
-    trials = [
-        [(1, 256, 100, 100), (256, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 
1)],
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
         [(1, 64, 200, 200), (64, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 1)],
         [(1, 64, 400, 400), (64, 16, 4, 4), 16, (4, 4), (2, 2), (1, 1, 1, 1)],
-    ]
-    for (dshape, kshape, channels, kernel_size, strides, padding) in trials:
-        x = relay.var("input", shape=dshape, dtype=dtype)
-        input_arr = tvm.nd.array(np.random.uniform(-1, 1, 
dshape).astype(dtype))
-        w = relay.var("wt", shape=kshape, dtype=dtype)
-        weight_arr = tvm.nd.array(np.random.uniform(-1, 1, 
kshape).astype(dtype))
-        inputs = {
-            "input": input_arr,
-        }
-        params = {
-            "wt": weight_arr,
-        }
-        y = relay.nn.conv2d_transpose(
-            x,
-            w,
-            channels=channels,
-            kernel_size=kernel_size,
-            strides=strides,
-            padding=padding,
-            kernel_layout="IOHW",
-            data_layout="NCHW",
-        )
-        func = relay.Function([x, w], y)
-        mod = IRModule.from_expr(func)
-
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
-        tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
-        )
-
-        args = (
-            dshape,
-            kshape,
-            channels,
-            kernel_size,
-            strides,
-            padding,
-            (1, 1),
-            dtype,
-            opencl_out[0].shape,
-        )
-        exp_codegen = _get_conv2d_transpose_expected_codegen(*args)
-        verify_codegen(mod, exp_codegen, device, params)
+        [(1, 16, 32, 32), (16, 16, 3, 3), 16, (3, 3), (1, 1), (1, 1, 1, 1)],
+        # [(1, 256, 100, 100), (256, 64, 4, 4), 64, (4, 4), (2, 2), (1, 1, 1, 
1)],
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_conv2d_transpose(remote, dtype, target, trials, executor_type):
+    np.random.seed(0)
+    (dshape, kshape, channels, kernel_size, strides, padding) = trials
+    x = relay.var("input", shape=dshape, dtype=dtype)
+    input_arr = tvm.nd.array(np.random.uniform(-1, 1, dshape).astype(dtype))
+    w = relay.var("wt", shape=kshape, dtype=dtype)
+    weight_arr = tvm.nd.array(np.random.uniform(-1, 1, kshape).astype(dtype))
+    inputs = {
+        "input": input_arr,
+    }
+    params = {
+        "wt": weight_arr,
+    }
+    y = relay.nn.conv2d_transpose(
+        x,
+        w,
+        channels=channels,
+        kernel_size=kernel_size,
+        strides=strides,
+        padding=padding,
+        kernel_layout="IOHW",
+        data_layout="NCHW",
+    )
+    func = relay.Function([x, w], y)
+    mod = IRModule.from_expr(func)
+    outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    tvm.testing.assert_allclose(
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+    )
+    args = (
+        dshape,
+        kshape,
+        channels,
+        kernel_size,
+        strides,
+        padding,
+        (1, 1),
+        dtype,
+        outputs[0].shape,
+    )
+    exp_codegen = _get_conv2d_transpose_expected_codegen(*args)
+    verify_codegen(remote, mod, params, exp_codegen, target)
 
 
-@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize("trials", [[1, 64, 8, 8], [1, 16, 64, 64]])
 @tvm.testing.requires_openclml
-def test_batchnorm(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_batchnorm(remote, dtype, target, trials, executor_type):
     if clml.clml_sdk_version() < 3:
         print("Skip due to unsupported CLML version:", clml.clml_sdk_version())
         return
-    in_shape = (1, 8, 64, 64)
-    channels = 8
+    in_shape = trials
+    channels = in_shape[1]
 
-    np.random.seed(8)
+    np.random.seed(0)
 
     input_arr = tvm.nd.array(np.random.uniform(-1, 1, in_shape).astype(dtype))
     inp = relay.var("a", shape=in_shape, dtype=dtype)
@@ -381,24 +421,58 @@ def test_batchnorm(device, dtype):
 
     func = relay.nn.batch_norm(inp, gamma, beta, mean, variance, axis=1, 
epsilon=0.0003)[0]
     mod = IRModule.from_expr(func)
-
     inputs = {
         "a": input_arr,
     }
-
-    opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-    clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
-
+    outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+    out_rtol = 1e-3 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
     )
+    exp_codegen = [
+        {
+            "attrs": {"dtype": [[dtype]], "shape": 
[[list(inputs["a"].shape)]]},
+            "name": "",
+            "op": "input",
+        },
+        {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", 
"op": "const"},
+        {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", 
"op": "const"},
+        {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", 
"op": "const"},
+        {"attrs": {"dtype": [[dtype]], "shape": [[[channels]]]}, "name": "", 
"op": "const"},
+        {
+            "attrs": {
+                "axis": [["1"]],
+                "center": [["1"]],
+                "dtype": [[dtype]],
+                "epsilon": [["0.00029999999999999997"]],
+                "num_inputs": "5",
+                "num_outputs": "1",
+                "scale": [["1"]],
+                "shape": [[list(outputs[0].shape)]],
+            },
+            "inputs": [[0, 0, 0], [1, 0, 0], [2, 0, 0], [3, 0, 0], [4, 0, 0]],
+            "name": "nn.batch_norm",
+            "op": "kernel",
+        },
+    ]
+    verify_codegen(remote, mod, params, exp_codegen, target)
 
 
-@pytest.mark.parametrize("dtype", ["float16"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(1, 64, 64, 40), (1, 64, 64, 40)],
+        [(1, 1280, 32, 32), (1, 640, 32, 32)],
+        [(1, 64), (1, 32)],
+    ],
+)
 @tvm.testing.requires_openclml
-def test_concat(device, dtype):
-    in_shape_1 = (1, 16, 16, 16)
-    in_shape_2 = (1, 16, 16, 16)
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_concat(remote, dtype, target, trials, executor_type):
+    np.random.seed(0)
+    in_shape_1 = trials[0]
+    in_shape_2 = trials[1]
     a = relay.var("input_1", shape=in_shape_1, dtype=dtype)
     b = relay.var("input_2", shape=in_shape_2, dtype=dtype)
     low, high = -1, 1
@@ -409,14 +483,13 @@ def test_concat(device, dtype):
 
     params = {}
     func = relay.concatenate((a, b), axis=1)
-    mod = IRModule.from_expr(func)
-
-    opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-    clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
 
+    outputs = _build_and_run_network(remote, func, params, inputs, target, 
executor_type)
+    out_rtol = 1e-2 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, atol=1e-3
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
     )
+
     exp_codegen = [
         {
             "attrs": {
@@ -440,14 +513,14 @@ def test_concat(device, dtype):
                 "dtype": [[dtype]],
                 "num_inputs": "2",
                 "num_outputs": "1",
-                "shape": [[list(clml_out[0].shape)]],
+                "shape": [[list(outputs[0].shape)]],
             },
             "inputs": [[0, 0, 0], [1, 0, 0]],
             "name": "concatenate",
             "op": "kernel",
         },
     ]
-    verify_codegen(func, exp_codegen, device, params)
+    verify_codegen(remote, func, params, exp_codegen, target)
 
 
 def _get_pool_expected_codegen(input_shape, pool_size, stride, padding, 
pool_type, dtype):
@@ -488,10 +561,10 @@ def _get_pool_expected_codegen(input_shape, pool_size, 
stride, padding, pool_typ
     return exp_codegen
 
 
-@pytest.mark.parametrize("dtype", ["float16"])
-@tvm.testing.requires_openclml
-def test_pool(device, dtype):
-    trials = [
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
         # input size         pool_size stride  paading
         [(1, 64, 147, 147), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
         [(1, 192, 71, 71), (3, 3), (2, 2), (0, 0, 0, 0), "max"],
@@ -503,42 +576,59 @@ def test_pool(device, dtype):
         [(1, 288, 35, 35), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
         [(1, 768, 17, 17), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
         [(1, 1280, 8, 8), (3, 3), (1, 1), (0, 0, 1, 1), "avg"],
-    ]
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_pool(remote, dtype, target, trials, executor_type):
+    np.random.seed(0)
     params = {}
-    for (
+    (
         input_shape,
         pool_size,
         stride,
         padding,
         pooling_type,
-    ) in trials:
-        a = relay.var("input_1", shape=input_shape, dtype=dtype)
-        input_arr = tvm.nd.array(np.random.uniform(-1, 1, 
input_shape).astype(dtype))
-        inputs = {
-            "input_1": input_arr,
-        }
-
-        if pooling_type == "max":
-            func = relay.nn.max_pool2d(a, pool_size=pool_size, strides=stride, 
padding=padding)
-        else:
-            func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, 
padding=padding)
-        mod = IRModule.from_expr(func)
-
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
-        tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
-        )
+    ) = trials
+    a = relay.var("input_1", shape=input_shape, dtype=dtype)
+    input_arr = tvm.nd.array(np.random.uniform(-1, 1, 
input_shape).astype(dtype))
+    inputs = {
+        "input_1": input_arr,
+    }
+    if pooling_type == "max":
+        func = relay.nn.max_pool2d(a, pool_size=pool_size, strides=stride, 
padding=padding)
+    else:
+        func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, 
padding=padding)
 
-        args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
-        exp_codegen = _get_pool_expected_codegen(*args)
-        verify_codegen(func, exp_codegen, device, params)
+    outputs = _build_and_run_network(remote, func, params, inputs, target, 
executor_type)
+    out_rtol = 1e-2 if dtype == "float16" else 1e-5
+    tvm.testing.assert_allclose(
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+    )
+    args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
+    exp_codegen = _get_pool_expected_codegen(*args)
+    verify_codegen(remote, func, params, exp_codegen, target)
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(5, 16), (32, 16), False],
+        [(320, 64), (320, 64), False],
+        [(256, 256), (256, 256), False],
+        [(512, 512), (512, 512), False],
+        [(1, 256), (100, 256), False],
+        [(1, 16), (32, 16), True],
+        [(1, 512), (512, 512), True],
+        [(1, 5), (4, 5), True],
+    ],
+)
 @tvm.testing.requires_openclml
-def test_dense(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_dense(remote, dtype, target, trials, executor_type):
     def _get_model(x_shape, k_shape, has_bias=False):
+        np.random.seed(0)
         x = relay.var("x", shape=(x_shape), dtype=dtype)
         kernel = relay.var("kernel", shape=(k_shape), dtype=dtype)
         out = relay.nn.dense(x, kernel, units=k_shape[0])
@@ -562,22 +652,8 @@ def test_dense(device, dtype):
                 "op": "const",
             },
         ]
-
-        dense_node = {
-            "attrs": {
-                "num_inputs": "2",
-                "num_outputs": "1",
-                "dtype": [[dtype]],
-                "out_dtype": [[""]],
-                "shape": [[[x_shape[0], k_shape[0]]]],
-                "units": [[str(k_shape[0])]],
-            },
-            "inputs": [[0, 0, 0], [1, 0, 0]],
-            "name": "nn.dense",
-            "op": "kernel",
-        }
-        exp_codegen.append(dense_node)
-
+        input_nodes = [[0, 0, 0], [1, 0, 0]]
+        num_inputs = 2
         if has_bias:
             bias = relay.var("bias", shape=(k_shape[0],), dtype=dtype)
             out = relay.nn.bias_add(out, bias)
@@ -590,43 +666,48 @@ def test_dense(device, dtype):
                 "op": "const",
             }
             exp_codegen.append(bias_data_node)
-            bias_node = {
-                "attrs": {
-                    "num_inputs": "2",
-                    "num_outputs": "1",
-                    "dtype": [[dtype]],
-                    "shape": [[[x_shape[0], k_shape[0]]]],
-                },
-                "inputs": [[2, 0, 0], [3, 0, 0]],
-                "name": "add",
-                "op": "kernel",
-            }
-            exp_codegen.append(bias_node)
-
+            input_nodes.append([2, 0, 0])
+            num_inputs += 1
             params["bias"] = tvm.nd.array(np.random.uniform(-1, 1, 
(k_shape[0],)).astype(dtype))
 
+        dense_node = {
+            "attrs": {
+                "num_inputs": str(num_inputs),
+                "num_outputs": "1",
+                "dtype": [[dtype]],
+                "out_dtype": [[""]],
+                "shape": [[[x_shape[0], k_shape[0]]]],
+                "units": [[str(k_shape[0])]],
+            },
+            "inputs": input_nodes,
+            "name": "nn.dense",
+            "op": "kernel",
+        }
+        exp_codegen.append(dense_node)
+
         return out, params, inputs, exp_codegen
 
     def _verify(out, params, inputs, exp_codegen):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-2, 
atol=1e-2
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
-        verify_codegen(out, exp_codegen, device, params)
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
-    _verify(*(_get_model((5, 16), (32, 16), False)))
-    _verify(*(_get_model((1, 16), (32, 16), True)))
+    _verify(*(_get_model(trials[0], trials[1], trials[2])))
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
 @tvm.testing.requires_openclml
-def test_binary_ops(device, dtype):
-    def _get_model(a_shape, b_shape, op):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_binary_ops(remote, dtype, target, executor_type):
+    def _get_model(a_shape, b_shape, op_func):
+        np.random.seed(0)
         a = relay.var("a", shape=(a_shape), dtype=dtype)
         b = relay.var("b", shape=(b_shape), dtype=dtype)
-        out = op(a, b)
+        out = op_func(a, b)
         inputs = {
             "a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype)),
             "b": tvm.nd.array(np.random.uniform(-1, 1, b_shape).astype(dtype)),
@@ -636,32 +717,56 @@ def test_binary_ops(device, dtype):
 
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
-
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
-        with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
-            mod = clml.partition_for_clml(mod, params)
-            tvm_op_count = get_cpu_op_count(mod)
-            assert tvm_op_count == 0, "Got {} TVM Native Compute partitions, 
expected 0".format(
-                tvm_op_count
-            )
+        exp_codegen = [
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(inputs["a"].shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(inputs["b"].shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "num_inputs": "2",
+                    "num_outputs": "1",
+                    "shape": [[list(outputs[0].shape)]],
+                },
+                "inputs": [[0, 0, 0], [1, 0, 0]],
+                "name": str(out.op.name),
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
     _verify(*(_get_model((1, 16), (1, 16), relay.add)))
-    _verify(*(_get_model((1, 16), (1, 16), relay.subtract)))
-    _verify(*(_get_model((1, 16), (1, 16), relay.multiply)))
-    _verify(*(_get_model((1, 16), (1, 16), relay.divide)))
+    _verify(*(_get_model((1, 18), (1, 18), relay.subtract)))
+    _verify(*(_get_model((1, 256), (1, 256), relay.multiply)))
+    _verify(*(_get_model((1, 10), (1, 10), relay.divide)))
     _verify(*(_get_model((1, 16), (1, 16), relay.minimum)))
-    _verify(*(_get_model((1, 16), (1, 16), relay.maximum)))
+    _verify(*(_get_model((1, 512), (1, 512), relay.maximum)))
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
 @tvm.testing.requires_openclml
-def test_unary_ops(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_unary_ops(remote, dtype, target, executor_type):
     def _get_model(a_shape, op):
+        np.random.seed(0)
         a = relay.var("a", shape=(a_shape), dtype=dtype)
         out = op(a)
         inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, 
a_shape).astype(dtype))}
@@ -670,28 +775,45 @@ def test_unary_ops(device, dtype):
 
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
-        with tvm.transform.PassContext(opt_level=3, 
disabled_pass=["AlterOpLayout"]):
-            mod = clml.partition_for_clml(mod, params)
-            tvm_op_count = get_cpu_op_count(mod)
-            assert tvm_op_count == 0, "Got {} TVM Native Compute partitions, 
expected 0".format(
-                tvm_op_count
-            )
+        exp_codegen = [
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "shape": [[list(inputs["a"].shape)]],
+                },
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "shape": [[list(outputs[0].shape)]],
+                },
+                "inputs": [[0, 0, 0]],
+                "name": "nn.relu",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
-    _verify(*(_get_model((1, 16), relay.nn.softmax)))
     _verify(*(_get_model((1, 16), relay.nn.relu)))
+    _verify(*(_get_model((1, 256), relay.nn.relu)))
 
 
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 @tvm.testing.requires_openclml
-def test_depth_to_space(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_depth_to_space(remote, dtype, target, executor_type):
     def _get_model(a_shape, block_size):
+        np.random.seed(0)
         a = relay.var("a", shape=(a_shape), dtype=dtype)
         out = relay.nn.depth_to_space(a, block_size)
         inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, 
a_shape).astype(dtype))}
@@ -700,10 +822,10 @@ def test_depth_to_space(device, dtype):
 
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
 
         # Check to make sure these ops are offloaded to CLML instead of TVM.
@@ -724,23 +846,26 @@ def test_depth_to_space(device, dtype):
                     "dtype": [[dtype]],
                     "num_inputs": "1",
                     "num_outputs": "1",
-                    "shape": [[list(clml_out[0].shape)]],
+                    "shape": [[list(outputs[0].shape)]],
                 },
                 "inputs": [[0, 0, 0]],
                 "name": "nn.depth_to_space",
                 "op": "kernel",
             },
         ]
-        verify_codegen(out, exp_codegen, device, params)
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
     _verify(*(_get_model((1, 64, 8, 8), 4)))
     _verify(*(_get_model((1, 64, 8, 8), 8)))
+    _verify(*(_get_model((1, 512, 8, 8), 8)))
 
 
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
 @tvm.testing.requires_openclml
-def test_resize_bilinear(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_resize_bilinear(remote, dtype, target, executor_type):
     def _get_model(a_shape, scale, align_corners):
+        np.random.seed(0)
         a = relay.var("a", shape=(a_shape), dtype=dtype)
         out = relay.nn.upsampling(
             a, scale_h=scale[0], scale_w=scale[1], method="bilinear", 
align_corners=align_corners
@@ -751,10 +876,10 @@ def test_resize_bilinear(device, dtype):
 
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
 
         # Check to make sure these ops are offloaded to CLML instead of TVM.
@@ -777,23 +902,35 @@ def test_resize_bilinear(device, dtype):
                     "dtype": [[dtype]],
                     "num_inputs": "1",
                     "num_outputs": "1",
-                    "shape": [[list(clml_out[0].shape)]],
+                    "shape": [[list(outputs[0].shape)]],
                 },
                 "inputs": [[0, 0, 0]],
                 "name": "nn.upsampling",
                 "op": "kernel",
             },
         ]
-        verify_codegen(out, exp_codegen, device, params)
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
     _verify(*(_get_model((1, 16, 8, 8), (2, 2), False)))
     _verify(*(_get_model((1, 16, 7, 7), (2, 2), True)))
+    _verify(*(_get_model((1, 64, 8, 8), (2, 2), True)))
 
 
-@pytest.mark.parametrize("dtype", ["float32"])
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(1, 512, 32), (1, 512, 32), False, True],
+        [(1, 128, 32), (1, 128, 32), False, True],
+        [(1, 128, 128), (1, 32, 128), False, True],
+        [(1, 64, 40), (1, 64, 40), False, True],
+    ],
+)
 @tvm.testing.requires_openclml
-def test_batch_matmul(device, dtype):
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_batch_matmul(remote, dtype, target, executor_type, trials):
     def _get_model(a_shape, b_shape, a_transpose, b_transpose):
+        np.random.seed(0)
         a = relay.var("a", shape=(a_shape), dtype=dtype)
         b = relay.var("b", shape=(b_shape), dtype=dtype)
         out = relay.nn.batch_matmul(a, b, transpose_a=a_transpose, 
transpose_b=b_transpose)
@@ -806,10 +943,10 @@ def test_batch_matmul(device, dtype):
 
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
-        opencl_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=False)[0]
-        clml_out = build_and_run(mod, inputs, 1, params, device, 
enable_clml=True)[0]
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            clml_out[0].asnumpy(), opencl_out[0].asnumpy(), rtol=1e-3, 
atol=1e-3
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
         )
 
         # Check to make sure these ops are offloaded to CLML instead of TVM.
@@ -838,17 +975,320 @@ def test_batch_matmul(device, dtype):
                     "dtype": [[dtype]],
                     "num_inputs": "2",
                     "num_outputs": "1",
-                    "shape": [[list(clml_out[0].shape)]],
+                    "shape": [[list(outputs[0].shape)]],
                 },
                 "inputs": [[0, 0, 0], [1, 0, 0]],
                 "name": "nn.batch_matmul",
                 "op": "kernel",
             },
         ]
-        verify_codegen(out, exp_codegen, device, params)
+        verify_codegen(remote, mod, params, exp_codegen, target)
+
+    _verify(*(_get_model(trials[0], trials[1], trials[2], trials[3])))
+
+
+def _get_softmax_exp_codegen(inputs, dtype, output_shape, axis):
+
+    exp_codegen = [
+        {
+            "attrs": {
+                "dtype": [[dtype]],
+                "shape": [[list(inputs["a"].shape)]],
+            },
+            "name": "",
+            "op": "input",
+        },
+        {
+            "attrs": {
+                "axis": [[str(axis)]],
+                "dtype": [[dtype]],
+                "num_inputs": "1",
+                "num_outputs": "1",
+                "shape": [[list(output_shape)]],
+            },
+            "inputs": [[0, 0, 0]],
+            "name": "nn.softmax",
+            "op": "kernel",
+        },
+    ]
+    return exp_codegen
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_softmax(remote, dtype, target, executor_type):
+    def _get_model(a_shape, axis):
+        np.random.seed(0)
+        a = relay.var("a", shape=(a_shape), dtype=dtype)
+        inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, 
a_shape).astype(dtype))}
+        out = relay.nn.softmax(a, axis)
+        params = {}
+        return out, params, inputs, axis
+
+    def _verify(out, params, inputs, axis):
+        mod = IRModule.from_expr(out)
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-1 if dtype == "float16" else 1e-5
+        tvm.testing.assert_allclose(
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+        )
+        args = (inputs, dtype, outputs[0].shape, axis)
+        exp_codegen = _get_softmax_exp_codegen(*args)
+        verify_codegen(remote, mod, params, exp_codegen, target)
+
+    _verify(*(_get_model((1, 5), 1)))
+    _verify(*(_get_model((1, 1000), 1)))
+    _verify(*(_get_model((1, 3), 1)))
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(1, 1, 2, 2), 2, 1],
+        [(1, 16, 2, 2), 4, 4],
+        [(1, 8, 4, 4), 3, 2],
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_upsampling(remote, dtype, target, executor_type, trials):
+    def _verify(in_shape, scale_h, scale_w):
+        np.random.seed(0)
+        a = relay.var("a", shape=in_shape, dtype=dtype)
+        inputs = {
+            "a": tvm.nd.array(np.random.uniform(-1, 1, 
in_shape).astype(dtype)),
+        }
+        params = {}
+        func = relay.nn.upsampling(
+            a, scale_h, scale_w, layout="NCHW", method="bilinear", 
align_corners=False
+        )
+        mod = IRModule.from_expr(func)
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        tvm.testing.assert_allclose(
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+        )
+        exp_codegen = [
+            {
+                "attrs": {"dtype": [[dtype]], "shape": 
[[list(inputs["a"].shape)]]},
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "align_corners": [["0"]],
+                    "dtype": [[dtype]],
+                    "layout": [["NCHW"]],
+                    "method": [["bilinear"]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "scale_h": [[str(scale_h)]],
+                    "scale_w": [[str(scale_w)]],
+                    "shape": [[list(outputs[0].shape)]],
+                },
+                "inputs": [[0, 0, 0]],
+                "name": "nn.upsampling",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(remote, mod, params, exp_codegen, target)
+
+    _verify(trials[0], trials[1], trials[2])
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(1, 40, 64, 64), (1, 40, 4096)],
+        [(1, 77, 768), (1, 1, -1, 768)],
+        [(1, 80, 32, 32), (1, 80, 1024)],
+        [(1, 2, 3, 4), (1, 0, -1)],
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_reshape(remote, dtype, target, executor_type, trials):
+    def _verify(shape, newshape):
+        np.random.seed(0)
+        x = relay.var("x", shape=(shape), dtype=dtype)
+        # Defined the test case with unary operator
+        # Single reshape op is failing in native OpenCL with vm executor type
+        # Empty TVM mod in VM doesn't pick appropriate cross compiler
+        out = relay.nn.relu(x)
+        out = relay.reshape(out, newshape)
+
+        inputs = {"x": tvm.nd.array(np.random.uniform(-1, 1, 
shape).astype(dtype))}
+        params = {}
+        mod = IRModule.from_expr(out)
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        tvm.testing.assert_allclose(
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+        )
+        exp_codegen = [
+            {
+                "attrs": {"dtype": [[dtype]], "shape": 
[[list(inputs["x"].shape)]]},
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "shape": [[list(inputs["x"].shape)]],
+                },
+                "inputs": [[0, 0, 0]],
+                "name": "nn.relu",
+                "op": "kernel",
+            },
+            {
+                "attrs": {
+                    "allowzero": [["0"]],
+                    "dtype": [[dtype]],
+                    "newshape": [[str(ele) for ele in list(newshape)]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "shape": [[list(outputs[0].shape)]],
+                },
+                "inputs": [[1, 0, 0]],
+                "name": "reshape",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(remote, mod, params, exp_codegen, target)
+
+    _verify(trials[0], trials[1])
+
+
+def _get_pool_global_expected_codegen(input_shape, pool_type, dtype, 
out_shape):
+
+    exp_codegen = [
+        {
+            "attrs": {
+                "dtype": [[str(dtype)]],
+                "shape": [[list(input_shape)]],
+            },
+            "name": "",
+            "op": "input",
+        },
+        {
+            "attrs": {
+                "dtype": [[str(dtype)]],
+                "layout": [["NCHW"]],
+                "num_inputs": "1",
+                "num_outputs": "1",
+                "out_layout": [[""]],
+                "shape": [[list(out_shape)]],
+            },
+            "inputs": [[0, 0, 0]],
+            "name": "nn.global_avg_pool2d" if pool_type == "avg" else 
"nn.global_max_pool2d",
+            "op": "kernel",
+        },
+    ]
+    return exp_codegen
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@pytest.mark.parametrize(
+    "trials",
+    [
+        [(1, 3, 32, 32), "avg"],
+        [(1, 64, 147, 147), "max"],
+        [(1, 192, 71, 71), "max"],
+        [(1, 288, 35, 35), "max"],
+        [(1, 768, 17, 17), "max"],
+        [(1, 2048, 17, 17), "max"],
+        [(1, 192, 35, 35), "avg"],
+        [(1, 256, 35, 35), "avg"],
+        [(1, 288, 35, 35), "avg"],
+        [(1, 768, 17, 17), "avg"],
+        [(1, 1280, 8, 8), "avg"],
+    ],
+)
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_pool_global(remote, dtype, target, executor_type, trials):
+    params = {}
+    (input_shape, pooling_type) = trials
+    np.random.seed(0)
+    a = relay.var("a", shape=input_shape, dtype=dtype)
+    inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, 
input_shape).astype(dtype))}
+    if pooling_type == "max":
+        func = relay.nn.global_max_pool2d(a)
+    else:
+        func = relay.nn.global_avg_pool2d(a)
+    mod = IRModule.from_expr(func)
+    outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+    out_rtol = 1e-3 if dtype == "float16" else 1e-5
+    tvm.testing.assert_allclose(
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+    )
+    args = (input_shape, pooling_type, dtype, outputs[0].shape)
+    exp_codegen = _get_pool_global_expected_codegen(*args)
+    verify_codegen(remote, mod, params, exp_codegen, target)
+
+
+@pytest.mark.parametrize("dtype", ["float32", "float16"])
+@tvm.testing.requires_openclml
+@tvm.testing.parametrize_targets("opencl -device=adreno")
+def test_batch_flatten(remote, dtype, target, executor_type):
+    def _get_model(a_shape):
+        a = relay.var("a", shape=(a_shape), dtype=dtype)
+        # Defined the test case with unary operator
+        # Single batch_flatten op is failing in native OpenCL
+        # Empty TVM mod in VM doesn't pick appropriate cross compiler
+        out = relay.nn.relu(a)
+        out = relay.nn.batch_flatten(out)
+        inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, 
a_shape).astype(dtype))}
+        params = {}
+        return out, params, inputs
+
+    def _verify(out, params, inputs):
+        mod = IRModule.from_expr(out)
+        outputs = _build_and_run_network(remote, mod, params, inputs, target, 
executor_type)
+        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        tvm.testing.assert_allclose(
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, 
atol=out_rtol
+        )
+        exp_codegen = [
+            {
+                "attrs": {"dtype": [[dtype]], "shape": 
[[list(inputs["a"].shape)]]},
+                "name": "",
+                "op": "input",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "shape": [[list(inputs["a"].shape)]],
+                },
+                "inputs": [[0, 0, 0]],
+                "name": "nn.relu",
+                "op": "kernel",
+            },
+            {
+                "attrs": {
+                    "dtype": [[dtype]],
+                    "num_inputs": "1",
+                    "num_outputs": "1",
+                    "shape": [[list(outputs[0].shape)]],
+                },
+                "inputs": [[1, 0, 0]],
+                "name": "nn.batch_flatten",
+                "op": "kernel",
+            },
+        ]
+        verify_codegen(remote, mod, params, exp_codegen, target)
 
-    _verify(*(_get_model((1, 128, 32), (1, 128, 32), False, True)))
-    _verify(*(_get_model((1, 128, 128), (1, 32, 128), False, True)))
+    _verify(*(_get_model((1, 3, 2))))
+    _verify(*(_get_model((1, 4, 3, 2))))
+    _verify(*(_get_model((1, 64, 8, 8))))
+    _verify(*(_get_model((1, 128, 4, 4))))
 
 
 if __name__ == "__main__":
diff --git a/tests/scripts/task_python_adreno.sh 
b/tests/scripts/task_python_adreno.sh
index 634d9adbd6..18e0feb815 100755
--- a/tests/scripts/task_python_adreno.sh
+++ b/tests/scripts/task_python_adreno.sh
@@ -31,6 +31,7 @@ export TVM_TRACKER_PORT=$(((RANDOM % 100) + 9100))
 export RPC_DEVICE_KEY="android"
 export RPC_TARGET="adreno"
 export 
TVM_NDK_CC="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang"
+export 
CXX="${ANDROID_NDK_HOME}/toolchains/llvm/prebuilt/linux-x86_64/bin/aarch64-linux-android28-clang"
 
 env PYTHONPATH=python python3 -m tvm.exec.rpc_tracker --host 
"${TVM_TRACKER_HOST}" --port "${TVM_TRACKER_PORT}" &
 TRACKER_PID=$!

Reply via email to