This is an automated email from the ASF dual-hosted git repository.

srk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a5e883e846 [RUNTIME][CLML] Fix for Softmax op for 4D tensors (#16328)
a5e883e846 is described below

commit a5e883e8465e11221d3f22d6ef2f61a1bfa5d1f2
Author: krishnaraj36 <quic_kvegi...@quicinc.com>
AuthorDate: Thu Jan 18 12:38:57 2024 +0530

    [RUNTIME][CLML] Fix for Softmax op for 4D tensors (#16328)
    
    Fixed the softmax layer for 4D tensors to support the NCHW and NHWC
    layout types.
    Enabled the relevant test cases for the softmax layer.
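
    For illustration, a minimal Relay module of the kind this change lets the
    CLML integration offload (shapes, dtype, and the partition call below are
    illustrative; partition_for_clml is assumed to be the partitioning entry
    point provided by python/tvm/relay/op/contrib/clml.py):

        import tvm
        from tvm import relay
        from tvm.relay.op.contrib import clml

        # 4D softmax: axis=1 maps to NCHW handling, axis=-1 or 3 to NHWC handling.
        a = relay.var("a", shape=(1, 64, 8, 8), dtype="float16")
        func = relay.Function([a], relay.nn.softmax(a, axis=1))
        mod = tvm.IRModule.from_expr(func)
        mod = clml.partition_for_clml(mod, params={})  # assumed CLML partition helper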
---
 python/tvm/relay/op/contrib/clml.py        |  3 +-
 src/runtime/contrib/clml/clml_runtime.cc   | 62 ++++++++++++++++-----
 tests/python/contrib/test_clml/test_ops.py | 86 ++++++++++++++++--------------
 3 files changed, 98 insertions(+), 53 deletions(-)

diff --git a/python/tvm/relay/op/contrib/clml.py b/python/tvm/relay/op/contrib/clml.py
index 14dd35a3cb..53b022c347 100644
--- a/python/tvm/relay/op/contrib/clml.py
+++ b/python/tvm/relay/op/contrib/clml.py
@@ -437,7 +437,8 @@ def clml_pattern_table():
 
     def check_softmax_op(extract):
         call = extract
-        if len(call.args[0].checked_type.shape) > 2:
+        # supports 2D and 4D tensors
+        if len(call.args[0].checked_type.shape) not in [2, 4]:
             return False
         return True
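
The guard now keys purely on input rank: rank-2 and rank-4 softmax inputs are
admitted for CLML offload, anything else stays with TVM. A standalone sketch of
the same predicate (the helper name here is hypothetical):

    # Hypothetical mirror of check_softmax_op: accept only 2D and 4D inputs.
    def softmax_rank_ok(input_shape):
        return len(input_shape) in [2, 4]

    assert softmax_rank_ok((1, 1000))        # 2D: offloaded
    assert softmax_rank_ok((1, 64, 8, 8))    # 4D NCHW/NHWC: offloaded after this fix
    assert not softmax_rank_ok((1, 8, 8))    # 3D: falls back to TVM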
 
diff --git a/src/runtime/contrib/clml/clml_runtime.cc b/src/runtime/contrib/clml/clml_runtime.cc
index aa1e2b82b6..8e69cb8bd1 100644
--- a/src/runtime/contrib/clml/clml_runtime.cc
+++ b/src/runtime/contrib/clml/clml_runtime.cc
@@ -511,6 +511,7 @@ class CLMLRuntime : public JSONRuntimeBase {
 
   /*!
   * \brief Create an CLML tensor from JSON node entry. Lookup storage map before creation.
+   * Update input placeholder for NHWC layout
    *
    * \param nid The node index of graph JSON.
    * \param shape shape information of tensor
@@ -528,15 +529,22 @@ class CLMLRuntime : public JSONRuntimeBase {
         uint32_t eid = EntryID(nid, 0);
         node_data = data_entry_[eid]->data;
       }
+
      auto clml_tensor = MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape);
+
      this->layer_.storage_map.insert({nid, std::make_pair(clml_tensor, node)});
 
       if ("input" == node.GetOpType()) {
         this->layer_.inputs.insert({nid, this->layer_.storage_map[nid].first});
         // Input copy placeholder Tensor
-        this->layer_.in_placeholder.insert(
-            {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data,
-                                             shape)});
+        if (layout == CL_TENSOR_LAYOUT_OPTIMAL_QCOM) {
+          this->layer_.in_placeholder.insert(
+              {nid, MakeCLMLTensorFromJSONNode(node, CL_TENSOR_LAYOUT_NCHW_QCOM, dtype, node_data,
+                                               shape)});
+        } else {
+          this->layer_.in_placeholder.insert(
+              {nid, MakeCLMLTensorFromJSONNode(node, layout, dtype, node_data, shape)});
+        }
       }
 
       return clml_tensor;
@@ -559,6 +567,7 @@ class CLMLRuntime : public JSONRuntimeBase {
       const auto& node = nodes_[nid];
      if ("nn.dense" == node.GetOpName()) CreateDenseLayerTensor(&layer_, node, nid);
      if ("nn.batch_matmul" == node.GetOpName()) CreateBatchMatmulLayerTensor(&layer_, node, nid);
+      if ("nn.softmax" == node.GetOpName()) CreateSoftmaxLayerTensor(&layer_, node, nid);
     }
 
     for (nid = 0; nid < nodes_.size(); ++nid) {
@@ -1092,6 +1101,37 @@ class CLMLRuntime : public JSONRuntimeBase {
     return;
   }
 
+  /*!
+   * \brief Create a Softmax layer Tensors with supported layout.
+   * \param layer The CLML layer to build. Containing inputs, outputs and the CLML function.
+   * \param node The JSON representation of the operator.
+   * \param nid The node index of JSON graph node, which points to this operator.
+   */
+
+  void CreateSoftmaxLayerTensor(CachedLayer* layer, const JSONGraphNode& node, size_t nid) {
+    cl_ml_tensor_layout_qcom layout;
+    cl_int result = 0;
+    cl_ml_op_qcom op = nullptr;
+    DLDataType tvm_dtype = node.GetOpDataType()[0];
+    cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
+    auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
+    int axis = std::stoi(node.GetAttr<std::vector<std::string>>("axis")[0]);
+    // For 4D tensors, choose the NHWC or NCHW layout based on the softmax axis value
+    if (out_dims.h >= 1 && out_dims.w >= 1) {
+      if (axis == 3 || axis == -1) {
+        layout = CL_TENSOR_LAYOUT_NHWC_QCOM;
+      } else {
+        layout = CL_TENSOR_LAYOUT_NCHW_QCOM;
+      }
+    } else {  // default layout for 2D
+      layout = CL_TENSOR_LAYOUT_OPTIMAL_QCOM;
+    }
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype);
+
+    return;
+  }
+
   /*!
    * \brief Create a SoftMax layer.
    *
@@ -1100,24 +1140,20 @@ class CLMLRuntime : public JSONRuntimeBase {
   * \param nid The node index of JSON graph node, which points to this operator.
    */
  void CreateSoftMaxLayer(CachedLayer* layer, const JSONGraphNode& node, size_t nid) {
+    cl_ml_tensor_layout_qcom layout;
+    cl_softmax_mode_qcom mode = CL_SOFTMAX_MODE_SPATIAL_QCOM;
     cl_int result = 0;
     cl_ml_op_qcom op = nullptr;
     DLDataType tvm_dtype = node.GetOpDataType()[0];
     cl_channel_type cl_dtype = MakeCLDataType(tvm_dtype);
    cl_arithmetic_mode_qcom cl_arithmetic_mode = MakeCLArithMode(cl_dtype, cl_dtype);
-    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {},
-                                             CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-    auto out_dims = GetTensorDims(nodes_[node.GetInputs()[0].id_]);
-    auto output = MakeCLMLTensorFromJSONEntry(nid, {out_dims.n, out_dims.c, 1, 1},
-                                              CL_TENSOR_LAYOUT_OPTIMAL_QCOM, cl_dtype);
-
-    cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM,
-                                               CL_SOFTMAX_MODE_INSTANCE_QCOM, cl_arithmetic_mode};
-
+    auto output = MakeCLMLTensorFromJSONEntry(nid, {}, layout, cl_dtype);
+    auto input = MakeCLMLTensorFromJSONEntry(node.GetInputs()[0].id_, {}, layout, cl_dtype);
+    cl_ml_op_softmax_desc_qcom softmax_desc = {CL_SOFTMAX_ALGORITHM_ACCURATE_QCOM, mode,
+                                               cl_arithmetic_mode};
    result = CLML_INTF->clCreateMLOpSoftmaxQCOM(CLML_CTX, nullptr, &softmax_desc, input->tensor,
                                                output->tensor, &op, layer_.tuning_cache);
     ICHECK(op && result == CL_SUCCESS) << "SoftMax Error:" << result;
-
     layer->function.push_back(op);
     return;
   }
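
Taken together, the runtime changes split softmax handling into two steps:
CreateSoftmaxLayerTensor picks the tensor layout up front (per its comment, 4D
inputs use NHWC when the softmax axis is the last one, i.e. 3 or -1, and NCHW
otherwise, while 2D inputs keep the optimal layout), and CreateSoftMaxLayer then
builds the op with CL_SOFTMAX_MODE_SPATIAL_QCOM instead of
CL_SOFTMAX_MODE_INSTANCE_QCOM. A small Python sketch of that layout rule
(hypothetical helper; the string names stand in for the CLML layout enums):

    # Hypothetical mirror of the layout selection in CreateSoftmaxLayerTensor.
    def choose_softmax_layout(shape, axis):
        if len(shape) == 4:                      # H and W dimensions present
            return "NHWC" if axis in (3, -1) else "NCHW"
        return "OPTIMAL"                         # 2D default

    assert choose_softmax_layout((1, 5, 3, 4), 1) == "NCHW"
    assert choose_softmax_layout((1, 64, 5, 32), -1) == "NHWC"
    assert choose_softmax_layout((1, 1000), -1) == "OPTIMAL"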
diff --git a/tests/python/contrib/test_clml/test_ops.py b/tests/python/contrib/test_clml/test_ops.py
index 58365bf429..3d89994126 100644
--- a/tests/python/contrib/test_clml/test_ops.py
+++ b/tests/python/contrib/test_clml/test_ops.py
@@ -280,9 +280,9 @@ def test_conv2d(remote, dtype, target, trials, executor_type):
         has_activation=composite[2],
     )
    outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    out_tol = 1e-1 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
    args = (shape, kernel_h, kernel_w, pad, stride, dilation, groups, dtype, out_channels)
     exp_codegen = _get_conv_expected_codegen(
@@ -373,9 +373,9 @@ def test_conv2d_transpose(remote, dtype, target, trials, executor_type):
     func = relay.Function([x, w], y)
     mod = IRModule.from_expr(func)
    outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-1 if dtype == "float16" else 1e-5
+    out_tol = 1e-1 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (
         dshape,
@@ -425,9 +425,9 @@ def test_batchnorm(remote, dtype, target, trials, executor_type):
         "a": input_arr,
     }
    outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-3 if dtype == "float16" else 1e-5
+    out_tol = 1e-3 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     exp_codegen = [
         {
@@ -485,9 +485,9 @@ def test_concat(remote, dtype, target, trials, executor_type):
     func = relay.concatenate((a, b), axis=1)
 
    outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-2 if dtype == "float16" else 1e-5
+    out_tol = 1e-2 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
 
     exp_codegen = [
@@ -601,9 +601,9 @@ def test_pool(remote, dtype, target, trials, executor_type):
        func = relay.nn.avg_pool2d(a, pool_size=pool_size, strides=stride, padding=padding)
 
    outputs = _build_and_run_network(remote, func, params, inputs, target, executor_type)
-    out_rtol = 1e-2 if dtype == "float16" else 1e-5
+    out_tol = 1e-2 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (input_shape, pool_size, stride, padding, pooling_type, dtype)
     exp_codegen = _get_pool_expected_codegen(*args)
@@ -690,9 +690,9 @@ def test_dense(remote, dtype, target, trials, executor_type):
     def _verify(out, params, inputs, exp_codegen):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
+        out_tol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         verify_codegen(remote, mod, params, exp_codegen, target)
 
@@ -718,9 +718,9 @@ def test_binary_ops(remote, dtype, target, executor_type):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         exp_codegen = [
             {
@@ -776,9 +776,9 @@ def test_unary_ops(remote, dtype, target, executor_type):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
 
         exp_codegen = [
@@ -823,12 +823,11 @@ def test_depth_to_space(remote, dtype, target, executor_type):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
         exp_codegen = [
             {
                 "attrs": {
@@ -877,12 +876,11 @@ def test_resize_bilinear(remote, dtype, target, executor_type):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
         exp_codegen = [
             {
                 "attrs": {
@@ -944,12 +942,11 @@ def test_batch_matmul(remote, dtype, target, executor_type, trials):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
+        out_tol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
 
-        # Check to make sure these ops are offloaded to CLML instead of TVM.
         exp_codegen = [
             {
                 "attrs": {
@@ -1026,20 +1023,30 @@ def test_softmax(remote, dtype, target, executor_type):
         params = {}
         return out, params, inputs, axis
 
-    def _verify(out, params, inputs, axis):
+    def _verify(out, params, inputs, axis, out_tol):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-1 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].numpy(), rtol=out_tol, atol=out_tol
         )
         args = (inputs, dtype, outputs[0].shape, axis)
         exp_codegen = _get_softmax_exp_codegen(*args)
         verify_codegen(remote, mod, params, exp_codegen, target)
 
-    _verify(*(_get_model((1, 5), 1)))
-    _verify(*(_get_model((1, 1000), 1)))
-    _verify(*(_get_model((1, 3), 1)))
+    # 2D tensor test cases
+    _verify(*(_get_model((1, 5), 1)), 1e-3)
+    _verify(*(_get_model((1, 16), 1)), 1e-3)
+    _verify(*(_get_model((1, 1000), -1)), 1e-3)
+
+    # 4D tensor test cases, layout = NCHW
+    _verify(*(_get_model((1, 100, 64, 100), 1)), 1e-3)
+    _verify(*(_get_model((1, 64, 64, 64), 1)), 1e-3)
+    _verify(*(_get_model((1, 5, 3, 4), 1)), 1e-3)
+
+    # 4D tensor test cases, layout = NHWC
+    _verify(*(_get_model((1, 64, 100, 100), 3)), 1e-1)
+    _verify(*(_get_model((1, 100, 100, 100), 3)), 1e-1)
+    _verify(*(_get_model((1, 64, 5, 32), -1)), 1e-1)
 
 
 @pytest.mark.parametrize("dtype", ["float32", "float16"])
@@ -1066,9 +1073,9 @@ def test_upsampling(remote, dtype, target, executor_type, trials):
         )
         mod = IRModule.from_expr(func)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-2 if dtype == "float16" else 1e-5
+        out_tol = 1e-2 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         exp_codegen = [
             {
@@ -1124,9 +1131,9 @@ def test_reshape(remote, dtype, target, executor_type, trials):
         params = {}
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        out_tol = 1e-3 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         exp_codegen = [
             {
@@ -1223,9 +1230,9 @@ def test_pool_global(remote, dtype, target, executor_type, trials):
         func = relay.nn.global_avg_pool2d(a)
     mod = IRModule.from_expr(func)
    outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-    out_rtol = 1e-3 if dtype == "float16" else 1e-5
+    out_tol = 1e-3 if dtype == "float16" else 1e-5
     tvm.testing.assert_allclose(
-        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+        outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
     )
     args = (input_shape, pooling_type, dtype, outputs[0].shape)
     exp_codegen = _get_pool_global_expected_codegen(*args)
@@ -1241,6 +1248,7 @@ def test_batch_flatten(remote, dtype, target, executor_type):
         # Defined the test case with unary operator
         # Single batch_flatten op is failing in native OpenCL
         # Empty TVM mod in VM doesn't pick appropriate cross compiler
+        np.random.seed(0)
         out = relay.nn.relu(a)
         out = relay.nn.batch_flatten(out)
        inputs = {"a": tvm.nd.array(np.random.uniform(-1, 1, a_shape).astype(dtype))}
@@ -1250,9 +1258,9 @@ def test_batch_flatten(remote, dtype, target, executor_type):
     def _verify(out, params, inputs):
         mod = IRModule.from_expr(out)
        outputs = _build_and_run_network(remote, mod, params, inputs, target, executor_type)
-        out_rtol = 1e-3 if dtype == "float16" else 1e-5
+        out_tol = 1e-3 if dtype == "float16" else 1e-5
         tvm.testing.assert_allclose(
-            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_rtol, atol=out_rtol
+            outputs[0].asnumpy(), outputs[1].asnumpy(), rtol=out_tol, atol=out_tol
         )
         exp_codegen = [
             {
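
Note on the test changes above: the tolerance variable is renamed from out_rtol to
out_tol because the same value is passed as both rtol and atol to
tvm.testing.assert_allclose, and test_softmax now takes the tolerance per case
(1e-3 for the 2D and NCHW cases, 1e-1 for the NHWC cases) instead of deriving it
from dtype. A minimal usage sketch (values illustrative):

    import numpy as np
    import tvm.testing

    out_tol = 1e-3  # e.g. a 2D or NCHW softmax case; the NHWC cases use 1e-1
    ref = np.random.uniform(size=(1, 16)).astype("float32")
    got = ref + 1e-4
    # The same value bounds both the relative and absolute error, as in the updated tests.
    tvm.testing.assert_allclose(got, ref, rtol=out_tol, atol=out_tol)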
