comaniac commented on code in PR #11326:
URL: https://github.com/apache/tvm/pull/11326#discussion_r893779753
##########
cmake/modules/contrib/LIBXSMM.cmake:
##########
@@ -0,0 +1,31 @@
+if(IS_DIRECTORY ${USE_LIBXSMM})
+ find_library(LIBXSMM_LIBRARY NAMES xsmm HINTS ${USE_LIBXSMM}/lib/)
+ if(LIBXSMM_LIBRARY STREQUAL "LIBXSMM_LIBRARY-NOTFOUND")
+ message(WARNING "Cannot find LIBXSMM library at ${USE_LIBXSMM}/lib/.")
+ else()
+ include_directories(SYSTEM ${USE_LIBXSMM}/include)
+ list(APPEND TVM_RUNTIME_LINKER_LIBS ${LIBXSMM_LIBRARY})
+ MESSAGE(STATUS "Use LIBXSMM library " ${LIBXSMM_LIBRARY})
+
+ tvm_file_glob(GLOB LIBXSMM_RELAY_CONTRIB_SRC
src/relay/backend/contrib/libxsmm/*.cc)
+ list(APPEND COMPILER_SRCS ${LIBXSMM_RELAY_CONTRIB_SRC})
+ tvm_file_glob(GLOB LIBXSMM_RUNTIME_SRC src/runtime/contrib/libxsmm/*.cc)
+ list(APPEND RUNTIME_SRCS ${LIBXSMM_RUNTIME_SRC})
+ endif()
+elseif(USE_LIBXSMM STREQUAL "ON")
+ find_library(LIBXSMM_LIBRARY xsmm)
+ if(LIBXSMM_LIBRARY STREQUAL "LIBXSMM_LIBRARY-NOTFOUND")
+ message(WARNING "Cannot find LIBXSMM library at $(USE_LIBXSMM).")
+ else()
+ list(APPEND TVM_RUNTIME_LINKER_LIBS ${LIBXSMM_LIBRARY})
+ MESSAGE(STATUS "Use LIBXSMM library " ${LIBXSMM_LIBRARY})
+
+ tvm_file_glob(GLOB LIBXSMM_RELAY_CONTRIB_SRC
src/relay/backend/contrib/libxsmm/*.cc)
+ list(APPEND COMPILER_SRCS ${LIBXSMM_RELAY_CONTRIB_SRC})
+ tvm_file_glob(GLOB LIBXSMM_RUNTIME_SRC src/runtime/contrib/libxsmm/*.cc)
+ list(APPEND RUNTIME_SRCS ${LIBXSMM_RUNTIME_SRC})
+ endif()
Review Comment:
Yeah, I was thinking of merging these redundant lines (the library and source setup is duplicated in both branches), but it's not critical.
##########
src/runtime/contrib/libxsmm/libxsmm_json_runtime.cc:
##########
@@ -0,0 +1,192 @@
+#include <libxsmm.h>
+#include <libxsmm_typedefs.h>
+#include <tvm/runtime/device_api.h>
+#include <tvm/runtime/registry.h>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+class LibxsmmJSONRuntime : public json::JSONRuntimeBase {
+ public:
+ LibxsmmJSONRuntime(const std::string& symbol_name, const std::string&
graph_json,
+ const Array<String> const_names)
+ : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+ const char* type_key() const { return "libxsmm_json"; }
+
+ void Init(const Array<NDArray>& consts) override {
+
+ SetupConstants(consts);
+ for (size_t nid = 0; nid < nodes_.size(); ++nid) {
+ auto& node = nodes_[nid];
+ if (node.GetOpType() == "kernel") {
+ auto op_name = node.GetOpName();
+
+ // Check if has bias or relu fusion.
+ has_bias_ = op_name.find("_bias") != std::string::npos;
+ has_relu_ = op_name.find("_relu") != std::string::npos;
+
+ // Get M, N, K, lda, ldb, ldc.
+ auto data_entry = node.GetInputs()[0];
+ auto weight_entry = node.GetInputs()[1];
+ json::JSONGraphNodeEntry out_entry(nid, 0);
+
+ std::vector<int64_t> input_shape =
nodes_[data_entry.id_].GetOpShape()[data_entry.index_];
+ std::vector<int64_t> weight_shape =
+ nodes_[weight_entry.id_].GetOpShape()[weight_entry.index_];
+ std::vector<int64_t> out_shape =
nodes_[out_entry.id_].GetOpShape()[out_entry.index_];
+
+ M_ = input_shape[0];
+ N_ = weight_shape[0];
+ K_ = input_shape[1];
+
+ int lda = N_;
+ int ldb = K_;
+ int ldc = N_;
+
+ // Curently we support fp32 only.
+ libxsmm_datatype dtype = LIBXSMM_DATATYPE_F32;
+
+ // Configure GEMM related parameters
+ libxsmm_bitfield l_flags = LIBXSMM_GEMM_FLAG_NONE |
LIBXSMM_GEMM_FLAG_BETA_0;
+ libxsmm_bitfield l_prefetch_flags = LIBXSMM_GEMM_PREFETCH_NONE;
+ libxsmm_gemm_shape l_shape =
+ libxsmm_create_gemm_shape(N_, M_, K_, lda, ldb, ldc, dtype, dtype,
dtype, dtype);
+ libxsmm_blasint stride_a = N_ * K_ * sizeof(float);
+ libxsmm_blasint stride_b = K_ * M_ * sizeof(float);
+ libxsmm_gemm_batch_reduce_config l_brconfig =
libxsmm_create_gemm_batch_reduce_config(
+ LIBXSMM_GEMM_BATCH_REDUCE_STRIDE, stride_a, stride_b, 0
/*br_unrool_hint*/);
+
+ libxsmm_gemm_ext_unary_argops l_argops;
+ libxsmm_gemm_ext_binary_postops l_postops;
+ memset(&l_argops, 0, sizeof(libxsmm_gemm_ext_unary_argops));
+ memset(&l_postops, 0, sizeof(libxsmm_gemm_ext_binary_postops));
+
+ if (has_bias_) {
+ l_postops.d_in_type = dtype;
+ l_postops.d_binary_flags = LIBXSMM_MELTW_FLAG_BINARY_BCAST_COL_IN_0;
+ l_postops.d_binary_type = LIBXSMM_MELTW_TYPE_BINARY_ADD;
+ l_postops.ldd = ldc;
+ }
+
+ if (has_relu_) {
+ l_argops.cp_unary_flags = LIBXSMM_MELTW_FLAG_UNARY_NONE;
+ l_argops.cp_unary_type = LIBXSMM_MELTW_TYPE_UNARY_RELU;
+ l_argops.ldcp = ldc;
+ // relu mask should have the same size as matrix C.
+ relu_mask_.resize(M_ * N_, 0);
+ }
+
+ // Use "libxsmm_gemmfunction" for GEMM kernel, and
"libxsmm_gemmfunction_ext" for fused GEMM
+ // kernel.
+ if (has_bias_ || has_relu_) {
+ gemm_fusion_kernel_ = libxsmm_dispatch_brgemm_ext_v2(l_shape,
l_flags, l_prefetch_flags,
+ l_brconfig,
l_argops, l_postops);
+ } else {
+ gemm_kernel_ = libxsmm_dispatch_brgemm_v2(l_shape, l_flags,
l_prefetch_flags, l_brconfig);
+ }
+ }
+ }
+ }
+
+ void Run() override {
+ // Get input/output buffers.
+ auto data_eid = EntryID(input_nodes_[0], 0);
+ auto filter_eid = EntryID(input_nodes_[1], 0);
+ auto output_eid = EntryID(outputs_[0]);
+
+ void* data_handle = data_entry_[data_eid]->data;
+ void* filter_handle = data_entry_[filter_eid]->data;
+ void* output_handle = data_entry_[output_eid]->data;
+
+ // Transpose weight matrix since libxsmm only support GEMM rather than
DENSE.
+ if (!transposed_filter_handle_) {
+ TVMDeviceAllocDataSpace(dev_, K_ * N_ * sizeof(float), kAllocAlignment,
type_hint_,
+ &transposed_filter_handle_);
+ for (int k = 0; k < K_; ++k) {
+ for (int n = 0; n < N_; ++n) {
+ static_cast<float*>(transposed_filter_handle_)[k * N_ + n] =
+ static_cast<float*>(filter_handle)[n * K_ + k];
+ }
+ }
+ }
Review Comment:
This should be done at compile time because the weights are constants.
Also, if you manage constant weights in your codegen, you should provide
a constant updater that removes the copy in the metadata module; otherwise you'll
end up with two copies of the weights stored in the binary. You can search for
"constant_updater" to see example use cases.
##########
python/tvm/relay/op/contrib/libxsmm.py:
##########
@@ -0,0 +1,100 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name
+"""Patterns supported LIBXSMM."""
+import numpy as np
+
+import tvm
+from tvm import relay
+from tvm.ir.transform import Sequential
+from tvm.relay import transform
+from tvm.relay.build_module import bind_params_by_name
+
+from ...dataflow_pattern import wildcard, is_op, is_constant
+from .register import register_pattern_table
+
+
+def get_root_call(call, root_op_name):
+ if not isinstance(call, relay.Call):
+ return None
+ if str(call.op) == root_op_name:
+ return call
+ return get_root_call(call.args[0], root_op_name)
+
+
+def check_dense_shape(call):
+ dense = get_root_call(call, "nn.dense")
+ data = dense.args[0].checked_type
+ weight = dense.args[1].checked_type
+ m = int(data.shape[0])
+ n = int(weight.shape[0])
+ k = int(data.shape[1])
+
+ # Conditions to enable libxsmm BYOC.
+ # Note: currently we enable libxsmm when cube_root(m * n * k ) <= 256
since it has significant performance improvement.
+ return bool(np.cbrt(m * n * k) <= 256)
+
+
+@tvm.ir.register_op_attr("nn.dense", "target.libxsmm")
+def dense(expr):
+ return check_dense_shape(expr)
+
+
+def make_dense_pattern(with_bias=False, eltwise=None):
+ data = wildcard()
+ weight = is_constant()
Review Comment:
It would be better to add a comment here noting that only constant weights are supported.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]