masahi commented on code in PR #15462:
URL: https://github.com/apache/tvm/pull/15462#discussion_r1285746375
##########
python/tvm/contrib/cudnn.py:
##########
@@ -461,6 +466,8 @@ def conv_backward_data_find_algo(
convolution type
groups: int
number of groups
+ verbose: bool
+ whether to show the selection trails
Review Comment:
Typo: "trails" should be "trials".
Please fix all occurrences of this typo in this file.
##########
python/tvm/relax/backend/contrib/cudnn.py:
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Pattern table for cuDNN backend"""
+from tvm.relax import transform
+from tvm.relax.transform import PatternCheckContext
+
+from ..pattern_registry import get_patterns_with_prefix, register_patterns
+from ..patterns import make_conv2d_pattern
+
+
+def _is_supported_dtype(lhs_dtype, rhs_dtype):
+ """Check if dtypes in the given workload are supported by cuBLAS BYOC."""
+ return (lhs_dtype == "float16" and rhs_dtype == "float16") or (
+ lhs_dtype == "float32" and rhs_dtype == "float32"
+ )
+
+
+def _check_conv2d(context: PatternCheckContext) -> bool:
+ # Retrieve the annotated expression from context
+ input_expr = context.annotated_expr["input"]
+ weight_expr = context.annotated_expr["weight"]
+
+ # Check if the data types of input and weights are float32
+ input_dtype = input_expr.struct_info.dtype
+ weight_dtype = weight_expr.struct_info.dtype
+ if not _is_supported_dtype(input_dtype, weight_dtype):
+ return False
Review Comment:
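You should also check the layouts here. The cuDNN JSON runtime only accepts NCHW and NHWC and aborts with LOG(FATAL) on anything else, so unsupported layouts should be rejected during pattern checking instead of failing at runtime.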
##########
src/runtime/contrib/cudnn/cudnn_json_runtime.cc:
##########
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_json_runtime.cc
+ * \brief A simple JSON runtime for CUDNN.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class cuDNNJSONRuntime : public JSONRuntimeBase {
+ public:
+ cuDNNJSONRuntime(const std::string& symbol_name, const std::string&
graph_json,
+ const Array<String> const_names)
+ : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+ void Init(const Array<NDArray>& consts) override {
+ auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ auto attr_in_name = [this](const std::string& op_name, const std::string&
attr_name) {
+ return std::regex_search(op_name, std::regex(attr_name));
+ };
+
+ auto getVecIntAttrFromVecStr = [this](const JSONGraphNode& node, const
std::string& attrStr) {
+ auto stringToInt = [](const std::string& str) { return std::stoi(str); };
+ auto stringVec = node.GetAttr<std::vector<std::string>>(attrStr);
+ std::vector<int> intVec(stringVec.size());
+ std::transform(stringVec.begin(), stringVec.end(), intVec.begin(),
stringToInt);
+ return intVec;
Review Comment:
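Please fix the code style in this function: we use snake_case, not camelCase (e.g. `getVecIntAttrFromVecStr`, `attrStr`, `stringToInt`, `stringVec`, `intVec`).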
##########
src/runtime/contrib/cudnn/cudnn_json_runtime.cc:
##########
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/runtime/contrib/cudnn/cudnn_json_runtime.cc
+ * \brief A simple JSON runtime for CUDNN.
+ */
+
+#include <tvm/runtime/ndarray.h>
+#include <tvm/runtime/registry.h>
+
+#include <cstddef>
+#include <regex>
+#include <string>
+#include <vector>
+
+#include "../json/json_node.h"
+#include "../json/json_runtime.h"
+#include "cudnn_utils.h"
+
+namespace tvm {
+namespace runtime {
+namespace contrib {
+
+using namespace tvm::runtime;
+using namespace tvm::runtime::json;
+
+class cuDNNJSONRuntime : public JSONRuntimeBase {
+ public:
+ cuDNNJSONRuntime(const std::string& symbol_name, const std::string&
graph_json,
+ const Array<String> const_names)
+ : JSONRuntimeBase(symbol_name, graph_json, const_names) {}
+
+ void Init(const Array<NDArray>& consts) override {
+ auto* entry_ptr = tvm::contrib::CuDNNThreadEntry::ThreadLocal();
+ auto func = tvm::runtime::Registry::Get("runtime.get_cuda_stream");
+ ICHECK(func != nullptr);
+ stream = static_cast<cudaStream_t>((*func)().operator void*());
+
+ auto attr_in_name = [this](const std::string& op_name, const std::string&
attr_name) {
+ return std::regex_search(op_name, std::regex(attr_name));
+ };
+
+ auto getVecIntAttrFromVecStr = [this](const JSONGraphNode& node, const
std::string& attrStr) {
+ auto stringToInt = [](const std::string& str) { return std::stoi(str); };
+ auto stringVec = node.GetAttr<std::vector<std::string>>(attrStr);
+ std::vector<int> intVec(stringVec.size());
+ std::transform(stringVec.begin(), stringVec.end(), intVec.begin(),
stringToInt);
+ return intVec;
+ };
+ // get some config from the graph
+ for (size_t i = 0; i < nodes_.size(); ++i) {
+ const auto& node = nodes_[i];
+ if (node.GetOpType() == "kernel") {
+ op_name = node.GetOpName();
+ std::vector<int> input_dims, kernel_dims, output_dims;
+ auto input_node = nodes_[0];
+ auto input_shapes = input_node.GetOpShape()[0];
+ auto kernel_node = nodes_[1];
+ auto kernel_shapes = kernel_node.GetOpShape()[0];
+ auto output_shapes = node.GetOpShape()[0];
+ for (const auto& _i : input_shapes) {
+ input_dims.emplace_back(static_cast<int>(_i));
+ }
+ for (const auto& _i : kernel_shapes) {
+ kernel_dims.emplace_back(static_cast<int>(_i));
+ }
+ for (const auto& _i : output_shapes) {
+ output_dims.emplace_back(static_cast<int>(_i));
+ }
+ has_bias = attr_in_name(op_name, "bias");
+ groups =
std::stoi(node.GetAttr<std::vector<std::string>>("groups")[0]);
+ padding = getVecIntAttrFromVecStr(node, "padding");
+ strides = getVecIntAttrFromVecStr(node, "strides");
+ dilation = getVecIntAttrFromVecStr(node, "dilation");
+ conv_dtype = node.GetAttr<std::vector<std::string>>("out_dtype")[0];
+ std::string layout =
node.GetAttr<std::vector<std::string>>("out_layout")[0];
+ dims = layout.size() - 2; // remove O and I dims
+
+ if (layout == "NCHW")
+ format = CUDNN_TENSOR_NCHW;
+ else if (layout == "NHWC")
+ format = CUDNN_TENSOR_NHWC;
+ else
+ LOG(FATAL) << "Unsupported layout: " << layout;
+
+ if (attr_in_name(op_name, "relu")) {
+ act = CUDNN_ACTIVATION_RELU;
+ } else if (attr_in_name(op_name, "relu6")) {
+ act = CUDNN_ACTIVATION_CLIPPED_RELU;
+ coef = 6.0;
+ } else if (attr_in_name(op_name, "leaky_relu")) {
+ act = CUDNN_ACTIVATION_RELU;
+ coef = 0.1;
+ }
+ this->handle = entry_ptr->handle;
+ this->kernel_node = node;
+
+ // find best algo
+ TVMRetValue best_algo;
+
+ tvm::contrib::FindAlgo(format, dims, groups, padding.data(),
strides.data(),
+ dilation.data(), input_dims.data(),
kernel_dims.data(),
+ output_dims.data(), conv_dtype, conv_dtype,
false, &best_algo);
+
+ this->algo = best_algo.operator int();
+ }
+ }
+ }
+
+ const char* type_key() const override { return "cudnn_json"; } // May be
overridden
+
+ void Run() override {
+ auto get_inputs = [this](const JSONGraphNode& node, bool has_bias) {
+ const DLTensor* bias = nullptr;
+ if (has_bias) {
+ bias = GetInput(node, 2);
+ }
+ return std::make_tuple(GetInput(node, 0), GetInput(node, 1), bias);
+ };
+
+ auto [a_ptr, b_ptr, bias_ptr] = get_inputs(kernel_node, has_bias);
+ uint32_t output_eid = EntryID(outputs_[0]);
+ auto out_ptr = data_entry_[output_eid];
+
+ if (this->has_bias) {
+ tvm::contrib::CallCudnnConvolutionBiasActivationForward(
+ this->handle, this->stream, this->mode, this->format, this->algo,
this->dims,
+ this->groups, this->act, this->coef, this->padding.data(),
this->strides.data(),
+ this->dilation.data(), a_ptr, b_ptr, out_ptr, bias_ptr,
this->conv_dtype);
+ } else {
+ tvm::contrib::CallCudnnConvolutionForward(
+ this->handle, this->stream, this->mode, this->format, this->algo,
this->dims,
+ this->groups, this->padding.data(), this->strides.data(),
this->dilation.data(), a_ptr,
+ b_ptr, out_ptr, this->conv_dtype);
+ }
+ }
+
+ private:
+ const DLTensor* GetInput(const JSONGraphNode& node, const int idx) {
+ ICHECK_LT(idx, node.GetInputs().size());
+ auto eid = EntryID(node.GetInputs()[idx]);
+ ICHECK(eid < data_entry_.size());
+ return data_entry_[eid];
+ }
+ /*conv op name*/
+ std::string op_name;
+ /*conv mode: CUDNN_CROSS_CORRELATION by default*/
+ int mode = CUDNN_CROSS_CORRELATION;
+ /*algo: by default we select the implicit gemm algo, will be tuned in the
initial pass.*/
+ int algo = CUDNN_CONVOLUTION_FWD_ALGO_IMPLICIT_GEMM;
+ /*if has bias*/
+ bool has_bias = false;
+ /*args for function call*/
+ int act = CUDNN_ACTIVATION_IDENTITY;
+ double coef = 1.0;
+ int format = CUDNN_TENSOR_NHWC;
+ int dims = 2;
+ int groups = 1;
+ std::vector<int> padding;
+ std::vector<int> strides;
+ std::vector<int> dilation;
+ std::string conv_dtype;
Review Comment:
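Please clean up unnecessary member variables. For example, `op_name` appears to be used only inside `Init()`, so it can be a local variable there.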
##########
src/relax/backend/contrib/cudnn/codegen.cc:
##########
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file src/relax/backend/contrib/cudnn/codegen.cc
+ * \brief Implementation of the CUBLAS JSON serializer.
Review Comment:
Typo: this should say cuDNN, not CUBLAS.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]