comaniac commented on code in PR #11326:
URL: https://github.com/apache/tvm/pull/11326#discussion_r887270920


##########
python/tvm/relay/op/contrib/libxsmm.py:
##########
@@ -0,0 +1,106 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Patterns supported LIBXSMM."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.ir.transform import Sequential
+from tvm.relay import transform
+
+from ...dataflow_pattern import wildcard, is_op, is_constant
+from .register import register_pattern_table
+
+
+def use_libxsmm(m, n, k):

Review Comment:
   This name is misleading. It would be better to use something more specific, 
such as `use_dense_libxsmm`.
   Another option is to inline this function into `check_dense_shape` and use 
that in `dense` as well. This looks cleaner to me:
   
   ```python
   def check_dense_shape(call):
       dense = get_root_call(call, "nn.dense")
       data = dense.args[0].checked_type
       weight = dense.args[1].checked_type
       m = int(data.shape[0])
       n = int(weight.shape[0])
       k = int(data.shape[1])
       return bool(np.cbrt(m * n * k) <= 256)
   
   @tvm.ir.register_op_attr("nn.dense", "target.libxsmm")
   def dense(expr):
       return check_dense_shape(expr)
   
   ...
   ```



##########
cmake/modules/contrib/LIBXSMM.cmake:
##########
@@ -0,0 +1,31 @@
+if(IS_DIRECTORY ${USE_LIBXSMM})
+  find_library(LIBXSMM_LIBRARY NAMES xsmm HINTS ${USE_LIBXSMM}/lib/)
+  if(LIBXSMM_LIBRARY STREQUAL "LIBXSMM_LIBRARY-NOTFOUND")
+    message(WARNING "Cannot find LIBXSMM library at ${USE_LIBXSMM}/lib/.")
+  else()
+    include_directories(SYSTEM ${USE_LIBXSMM}/include)
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${LIBXSMM_LIBRARY})
+    MESSAGE(STATUS "Use LIBXSMM library " ${LIBXSMM_LIBRARY})
+
+    tvm_file_glob(GLOB LIBXSMM_RELAY_CONTRIB_SRC 
src/relay/backend/contrib/libxsmm/*.cc)
+    list(APPEND COMPILER_SRCS ${LIBXSMM_RELAY_CONTRIB_SRC})
+    tvm_file_glob(GLOB LIBXSMM_RUNTIME_SRC src/runtime/contrib/libxsmm/*.cc)
+    list(APPEND RUNTIME_SRCS ${LIBXSMM_RUNTIME_SRC})
+  endif()
+elseif(USE_LIBXSMM STREQUAL "ON")
+  find_library(LIBXSMM_LIBRARY xsmm)
+  if(LIBXSMM_LIBRARY STREQUAL "LIBXSMM_LIBRARY-NOTFOUND")
+    message(WARNING "Cannot find LIBXSMM library at $(USE_LIBXSMM).")
+  else()
+    list(APPEND TVM_RUNTIME_LINKER_LIBS ${LIBXSMM_LIBRARY})
+    MESSAGE(STATUS "Use LIBXSMM library " ${LIBXSMM_LIBRARY})
+
+    tvm_file_glob(GLOB LIBXSMM_RELAY_CONTRIB_SRC 
src/relay/backend/contrib/libxsmm/*.cc)
+    list(APPEND COMPILER_SRCS ${LIBXSMM_RELAY_CONTRIB_SRC})
+    tvm_file_glob(GLOB LIBXSMM_RUNTIME_SRC src/runtime/contrib/libxsmm/*.cc)
+    list(APPEND RUNTIME_SRCS ${LIBXSMM_RUNTIME_SRC})
+  endif()

Review Comment:
   There seems to be a lot of redundancy between these two branches. Can they be merged?



##########
src/runtime/contrib/libxsmm/libxsmm_json_runtime.cc:
##########
@@ -0,0 +1,191 @@
+#include <libxsmm.h>
+#include <libxsmm_typedefs.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+class LibxsmmJSONRuntime : public json::JSONRuntimeBase {
+ public:
+  LibxsmmJSONRuntime(const std::string& symbol_name, const std::string& 
graph_json,
+                     const Array<String> const_names)
+      : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+  const char* type_key() const { return "libxsmm_json"; }
+
+  void Init(const Array<NDArray>& consts) override {
+    for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+      auto& node = nodes_[nid];
+      if (node.GetOpType() == "kernel") {
+        auto op_name = node.GetOpName();
+
+        // Check if has bias or relu fusion.
+        has_bias_ = op_name.find("_bias") != std::string::npos;
+        has_relu_ = op_name.find("_relu") != std::string::npos;
+
+        // Get M, N, K, lda, ldb, ldc.
+        auto data_entry = node.GetInputs()[0];
+        auto weight_entry = node.GetInputs()[1];
+        json::JSONGraphNodeEntry out_entry(nid, 0);
+
+        std::vector<int64_t> input_shape = 
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+        std::vector<int64_t> weight_shape =
+            nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
+        std::vector<int64_t> out_shape = 
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
+
+        M = input_shape[0];
+        N = weight_shape[0];
+        K = input_shape[1];
+
+        int lda = N;
+        int ldb = K;
+        int ldc = N;
+
+        // Curently we support fp32 only.
+        libxsmm_datatype dtype = LIBXSMM_DATATYPE_F32;
+
+        // Configure GEMM related parameters
+        libxsmm_bitfield l_flags = LIBXSMM_GEMM_FLAG_NONE | 
LIBXSMM_GEMM_FLAG_BETA_0;
+        libxsmm_bitfield l_prefetch_flags = LIBXSMM_GEMM_PREFETCH_NONE;
+        libxsmm_gemm_shape l_shape =
+            libxsmm_create_gemm_shape(N, M, K, lda, ldb, ldc, dtype, dtype, 
dtype, dtype);
+        libxsmm_blasint stride_a = N * K * sizeof(float);
+        libxsmm_blasint stride_b = K * M * sizeof(float);
+        libxsmm_gemm_batch_reduce_config l_brconfig = 
libxsmm_create_gemm_batch_reduce_config(
+            LIBXSMM_GEMM_BATCH_REDUCE_STRIDE, stride_a, stride_b, 0 
/*br_unrool_hint*/);
+
+        libxsmm_gemm_ext_unary_argops l_argops;
+        libxsmm_gemm_ext_binary_postops l_postops;
+        memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops));
+        memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops));
+
+        if (has_bias_) {
+          l_postops.d_in_type = dtype;
+          l_postops.d_binary_flags = LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_0;
+          l_postops.d_binary_type = LIBXSMM_MELTW_TYPE_BINARY_ADD;
+          l_postops.ldd = ldc;
+        }
+
+        if (has_relu_) {
+          l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
+          l_argops.cp_unary_type = LIBXSMM_MELTW_TYPE_UNARY_RELU;
+          l_argops.ldcp = ldc;
+          // relu mask should have the same size as matrix C.
+          relu_mask_.resize(M * N, 0);
+        }
+
+        // Use "libxsmm_gemmfunction" for GEMM kernel, and 
"libxsmm_gemmfunction_ext" for fused GEMM
+        // kernel.
+        if (has_bias_ || has_relu_) {
+          gemm_fusion_kernel_ = libxsmm_dispatch_brgemm_ext_v2(l_shape, 
l_flags, l_prefetch_flags,
+                                                               l_brconfig, 
l_argops, l_postops);
+        } else {
+          gemm_kernel_ = libxsmm_dispatch_brgemm_v2(l_shape, l_flags, 
l_prefetch_flags, l_brconfig);
+        }
+      }
+    }
+  }
+
+  void Run() override {
+    // Get input/output buffers.
+    auto data_eid = EntryID(input_nodes_[0], 0);
+    auto filter_eid = EntryID(input_nodes_[1], 0);
+    auto output_eid = EntryID(outputs_[0]);
+
+    void* data_handle = data_entry_[data_eid]->data;
+    void* filter_handle = data_entry_[filter_eid]->data;
+    void* output_handle = data_entry_[output_eid]->data;
+
+    // Transpose weight matrix since libxsmm only support GEMM rather than 
DENSE.
+    if (!transposed_filter_handle_) {
+      TVMDeviceAllocDataSpace(dev, K * N * sizeof(float), kAllocAlignment, 
type_hint,
+                              &transposed_filter_handle_);
+      for (int k = 0; k < K; ++k) {
+        for (int n = 0; n < N; ++n) {
+          static_cast<float*>(transposed_filter_handle_)[k * N + n] =
+              static_cast<float*>(filter_handle)[n * K + k];
+        }
+      }
+    }
+
+    if (has_bias_ || has_relu_) {
+      // Setup GEMM params.
+      libxsmm_gemm_ext_param gemm_param_ext;
+      gemm_param_ext.a.secondary = NULL;
+      gemm_param_ext.b.secondary = NULL;
+
+      gemm_param_ext.a.primary = transposed_filter_handle_;
+      gemm_param_ext.b.primary = data_handle;
+      gemm_param_ext.c.primary = output_handle;
+      if (has_bias_) {
+        auto bias_eid = EntryID(input_nodes_[2], 0);
+        void* bias_handle = data_entry_[bias_eid]->data;
+
+        gemm_param_ext.d.primary = bias_handle;
+      }
+      if (has_relu_) {
+        gemm_param_ext.c.secondary = relu_mask_.data();
+      }
+      gemm_param_ext.op.tertiary = &blocks_;
+
+      // Run GEMM fusion kernel.
+      gemm_fusion_kernel_(&gemm_param_ext);
+    } else {
+      // Setup GEMM params.
+      libxsmm_gemm_param gemm_param;
+      gemm_param.a.secondary = NULL;
+      gemm_param.b.secondary = NULL;
+
+      gemm_param.a.primary = transposed_filter_handle_;
+      gemm_param.b.primary = data_handle;
+      gemm_param.c.primary = output_handle;
+      gemm_param.op.tertiary = &blocks_;
+
+      // Run GEMM kernel.
+      gemm_kernel_(&gemm_param);
+    }
+  }
+
+  ~LibxsmmJSONRuntime() { TVMDeviceFreeDataSpace(dev, 
transposed_filter_handle_); }
+
+ private:
+  libxsmm_gemmfunction gemm_kernel_;
+  libxsmm_gemmfunction_ext gemm_fusion_kernel_;
+
+  // Transposed weight is saved to avoid redundant transpose in following 
steps.
+  // TODO(wenxizhu): check if current graph executor is in inference mode.
+  void* transposed_filter_handle_{nullptr};
+
+  DLDevice dev{kDLCPU, 0};
+  DLDataType type_hint{2, 32, 1};
+
+  int64_t M;
+  int64_t K;
+  int64_t N;

Review Comment:
   Please follow the naming convention for private class members.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to