DickJC123 commented on a change in pull request #20635: URL: https://github.com/apache/incubator-mxnet/pull/20635#discussion_r745831518
########## File path: src/operator/cudnn_ops.cc ########## @@ -0,0 +1,765 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * Copyright (c) 2021 by Contributors + * \file cudnn_ops.cc + * \brief cuDNN v8 ops + */ + +#include "cudnn_ops.h" + +#include <mxnet/base.h> +#if MXNET_USE_CUDNN == 1 + +#include <dmlc/parameter.h> + +#include <algorithm> +#include <cstdlib> +#include <iomanip> +#include <iterator> +#include <limits> +#include <numeric> +#include <sstream> +#include <string> +#include <utility> + +namespace mxnet { +namespace op { + +using cudnn_cxx::Descriptor; +using cudnn_cxx::GetAttr; +using cudnn_cxx::GetSomeAttrs; +using cudnn_cxx::IsCompatible; +using cudnn_cxx::MakeAvgSampler; +using cudnn_cxx::MakeFinalized; +using cudnn_cxx::PackedStrides; +using cudnn_cxx::PlanStr; + +namespace cudnn { + +cudnnDataType_t CudnnType(mshadow::TypeFlag dtype) { + static std::unordered_map<mshadow::TypeFlag, cudnnDataType_t> type_map { + {mshadow::kFloat32, CUDNN_DATA_FLOAT}, {mshadow::kFloat64, CUDNN_DATA_DOUBLE}, + {mshadow::kFloat16, CUDNN_DATA_HALF}, {mshadow::kUint8, CUDNN_DATA_UINT8}, + {mshadow::kInt8, CUDNN_DATA_INT8}, {mshadow::kInt32, CUDNN_DATA_INT32}, +#if CUDNN_VERSION >= 8100 + 
{mshadow::kInt64, CUDNN_DATA_INT64}, +#endif // CUDNN_VERSION >= 8100 + }; + auto it = type_map.find(dtype); + CHECK(it != type_map.end()) << "Unsupported type: " << dtype; + return it->second; +} + +std::vector<size_t> LayoutInfo::Order() const { + std::vector<size_t> ret(n_space_dims + 2); + std::iota(ret.begin(), ret.end(), 0); + if (channel_last) + std::rotate(ret.begin() + 1, ret.begin() + 2, ret.end()); + return ret; +} + +size_t LayoutInfo::ChannelIdx() const { + return channel_last ? 1 + n_space_dims : 1; +} + +std::vector<int64_t> LayoutInfo::Strides(const std::vector<int64_t>& dims) const { + return PackedStrides(Order(), dims); +} + +LayoutInfo GetLayoutInfo(mshadow::LayoutFlag layout) { + static std::unordered_map<mshadow::LayoutFlag, LayoutInfo> layout_map{ + {mshadow::kNCW, {1, false}}, + {mshadow::kNWC, {1, true}}, + {mshadow::kNCHW, {2, false}}, + {mshadow::kNHWC, {2, true}}, + {mshadow::kNCDHW, {3, false}}, + {mshadow::kNDHWC, {3, true}}, + }; + auto it = layout_map.find(layout); + CHECK(it != layout_map.end()) << "Unsupported layout: " << layout; + return it->second; +} + +TShape ExpandChannelDims(mshadow::LayoutFlag layout, int c) { + auto li = GetLayoutInfo(layout); + std::vector<int> dims(li.n_space_dims + 2, 1); + dims[li.ChannelIdx()] = c; + return TShape(dims.begin(), dims.end()); +} + +std::vector<size_t> ReverseOrder(const std::vector<size_t>& o) { + std::vector<size_t> ret(o.size()); + for (size_t i = 0; i < ret.size(); ++i) + ret[o[i]] = i; + return ret; +} + +std::vector<cudnnBackendNumericalNote_t> RequireNumerics() { + std::vector<cudnnBackendNumericalNote_t> ret; + return ret; +} + +std::vector<cudnnBackendNumericalNote_t> ExcludeNumerics() { + std::vector<cudnnBackendNumericalNote_t> ret; + if (!dmlc::GetEnv("MXNET_CUDA_ALLOW_TENSOR_CORE", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_TENSOR_CORE); + if (!dmlc::GetEnv("MXNET_CUDA_TENSOR_OP_MATH_ALLOW_CONVERSION", false)) + ret.push_back(CUDNN_NUMERICAL_NOTE_DOWN_CONVERT_INPUTS); + if 
(!dmlc::GetEnv("MXNET_CUDNN_ALLOW_REDUCED_PRECISION_REDUCTION", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_REDUCED_PRECISION_REDUCTION); + if (!dmlc::GetEnv("MXNET_CUDNN_ALLOW_FFT", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_FFT); + if (dmlc::GetEnv("MXNET_ENFORCE_DETERMINISM", false)) + ret.push_back(CUDNN_NUMERICAL_NOTE_NONDETERMINISTIC); + if (!dmlc::GetEnv("MXNET_CUDNN_ALLOW_WINOGRAD", true)) + ret.push_back(CUDNN_NUMERICAL_NOTE_WINOGRAD); + return ret; +} + +Descriptor MakeTensorDesc(int64_t uid, + cudnnDataType_t dtype, + const std::vector<int64_t>& dims, + const std::vector<int64_t>& strides, + bool is_virtual) { + int64_t alignment = 16; // TODO(vcherepanov): ? + return MakeFinalized(CUDNN_BACKEND_TENSOR_DESCRIPTOR, + CUDNN_ATTR_TENSOR_UNIQUE_ID, + uid, + CUDNN_ATTR_TENSOR_DATA_TYPE, + dtype, + CUDNN_ATTR_TENSOR_BYTE_ALIGNMENT, + alignment, + CUDNN_ATTR_TENSOR_DIMENSIONS, + dims, + CUDNN_ATTR_TENSOR_STRIDES, + strides, + CUDNN_ATTR_TENSOR_IS_VIRTUAL, + is_virtual); +} + +Descriptor MakeTensorDesc(int64_t uid, + const TBlob& blob, + const LayoutInfo& li, + bool expand_1d, + bool is_virtual) { + std::vector<int64_t> dims(blob.shape_.ndim()); + CHECK_EQ(dims.size(), li.n_space_dims + 2); + auto rev_order = ReverseOrder(li.Order()); + for (size_t i = 0; i < dims.size(); ++i) + dims[i] = blob.shape_[rev_order[i]]; + auto strides = li.Strides(dims); + if (li.n_space_dims == 1 && expand_1d) { + dims.insert(dims.begin() + 2, 1); + std::vector<size_t> order(dims.size()); + std::iota(order.begin(), order.end(), 0); + if (li.channel_last) + std::rotate(order.begin() + 1, order.begin() + 2, order.end()); + strides = PackedStrides(order, dims); + } + return MakeTensorDesc( + uid, CudnnType(static_cast<mshadow::TypeFlag>(blob.type_flag_)), dims, strides, is_virtual); +} + +Descriptor MakeCTensorDescExpandDims(int64_t uid, + const TBlob& b, + const LayoutInfo& li, + bool is_virtual) { + std::vector<int64_t> dims(li.n_space_dims + 2, 1); + dims[1] = b.shape_[0]; + auto 
dtype = CudnnType(static_cast<mshadow::TypeFlag>(b.type_flag_)); + return MakeTensorDesc(uid, dtype, dims, li.Strides(dims), is_virtual); +} + +Descriptor MakeConvDesc(const ConvParam& param, mshadow::TypeFlag dtype) { + int64_t sdims = param.kernel.ndim(); + std::vector<int64_t> stride(param.stride.begin(), param.stride.end()); + std::vector<int64_t> dilate(param.dilate.begin(), param.dilate.end()); + std::vector<int64_t> pad(param.pad.begin(), param.pad.end()); + + auto comp_type = CudnnType(dtype); + if (comp_type == CUDNN_DATA_HALF) + comp_type = CUDNN_DATA_FLOAT; + + if (sdims == 1) { + // TODO(vcherepanov): remove this once cuDNN properly supports 1D convolutions. + // For now, making spacial dims 2D: 1 x W. + ++sdims; + stride.insert(stride.begin(), 1); + dilate.insert(dilate.begin(), 1); + pad.insert(pad.begin(), 0); + } + return MakeFinalized(CUDNN_BACKEND_CONVOLUTION_DESCRIPTOR, + CUDNN_ATTR_CONVOLUTION_SPATIAL_DIMS, + sdims, + CUDNN_ATTR_CONVOLUTION_COMP_TYPE, + comp_type, + CUDNN_ATTR_CONVOLUTION_CONV_MODE, + CUDNN_CROSS_CORRELATION, + CUDNN_ATTR_CONVOLUTION_FILTER_STRIDES, + stride, + CUDNN_ATTR_CONVOLUTION_DILATIONS, + dilate, + CUDNN_ATTR_CONVOLUTION_PRE_PADDINGS, + pad, + CUDNN_ATTR_CONVOLUTION_POST_PADDINGS, + pad); +} + +Descriptor MakeConvFwdOp(const Descriptor& conv, + const Descriptor& x, + const Descriptor& w, + const Descriptor& y, + bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_FORWARD_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_CONV_DESC, + conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_X, + x, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_W, + w, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_Y, + y); + if (GetAttr<cudnnDataType_t>(x, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + add_to ? 
1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_ALPHA, + 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_FORWARD_BETA, + add_to ? 1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeConvDgradOp(const Descriptor& conv, + const Descriptor& w, + const Descriptor& dy, + const Descriptor& dx, + bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_DATA_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_CONV_DESC, + conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_W, + w, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DY, + dy, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_DX, + dx); + if (GetAttr<cudnnDataType_t>(w, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, + 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, + add_to ? 1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_ALPHA, + 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_DATA_BETA, + add_to ? 1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeConvWgradOp(const Descriptor& conv, + const Descriptor& x, + const Descriptor& dy, + const Descriptor& dw, + bool add_to) { + auto ret = Make(CUDNN_BACKEND_OPERATION_CONVOLUTION_BACKWARD_FILTER_DESCRIPTOR, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_CONV_DESC, + conv, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_X, + x, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DY, + dy, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_DW, + dw); + if (GetAttr<cudnnDataType_t>(x, CUDNN_ATTR_TENSOR_DATA_TYPE) == CUDNN_DATA_DOUBLE) { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, + 1.0, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, + add_to ? 1.0 : 0.0); + } else { + SetAttrs(ret, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_ALPHA, + 1.0f, + CUDNN_ATTR_OPERATION_CONVOLUTION_BWD_FILTER_BETA, + add_to ? 
1.0f : 0.0f); + } + CUDNN_CALL(cudnnBackendFinalize(ret.get())); + return ret; +} + +Descriptor MakeOpGraph(cudnnHandle_t handle, const std::vector<Descriptor>& ops) { + return MakeFinalized(CUDNN_BACKEND_OPERATIONGRAPH_DESCRIPTOR, + CUDNN_ATTR_OPERATIONGRAPH_HANDLE, + handle, + CUDNN_ATTR_OPERATIONGRAPH_OPS, + ops); +} + +ConvParam::ConvParam(const ConvolutionParam& p, bool add_to) + : kernel(p.kernel), + stride(p.stride), + dilate(p.dilate), + pad(p.pad), + num_filter(p.num_filter), + num_group(p.num_group), + workspace(p.workspace), + cudnn_tune(p.cudnn_tune), + layout(p.layout), + add_to(add_to) {} + +ConvParam::ConvParam(const DeconvolutionParam& p, bool add_to) + : kernel(p.kernel), + stride(p.stride), + dilate(p.dilate), + pad(p.pad), + num_filter(p.num_filter), + num_group(p.num_group), + workspace(p.workspace), + cudnn_tune(p.cudnn_tune), + layout(p.layout), + add_to(add_to) {} + +void TuneWarnOnce() { + thread_local bool done = false; + if (!done) { + LOG(INFO) << "Auto-tuning cuDNN op, set MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable"; + done = true; + } +} + +std::vector<Descriptor> MakeFallbackPlans( + const std::vector<int64_t>& ixs, + cudnnHandle_t handle, + const Descriptor& op_graph, + size_t workspace_limit, + size_t* max_workspace, + const std::unordered_set<int64_t>& excl_engines, + const std::vector<cudnnBackendNumericalNote_t>& req_numeric, + const std::vector<cudnnBackendNumericalNote_t>& excl_numeric +#if CUDNN_VERSION >= 8200 + , + const std::vector<cudnnBackendBehaviorNote_t>& req_behavior, + const std::vector<cudnnBackendBehaviorNote_t>& excl_behavior +#endif // CUDNN_VERSION >= 8200 +) { + std::vector<Descriptor> plans; + if (max_workspace) + *max_workspace = 0; + for (auto ix : ixs) { + if (excl_engines.count(ix)) + continue; + auto engine = Make(CUDNN_BACKEND_ENGINE_DESCRIPTOR, + CUDNN_ATTR_ENGINE_OPERATION_GRAPH, + op_graph, + CUDNN_ATTR_ENGINE_GLOBAL_INDEX, + ix); + auto err = cudnnBackendFinalize(engine.get()); + if (err == 
CUDNN_STATUS_NOT_SUPPORTED || err == CUDNN_STATUS_ARCH_MISMATCH) + continue; + if (err != CUDNN_STATUS_SUCCESS) { + LOG(WARNING) << "Unexpected cuDNN status: " << err << ": " << cudnnGetErrorString(err); + continue; + } + auto cfg = + MakeFinalized(CUDNN_BACKEND_ENGINECFG_DESCRIPTOR, CUDNN_ATTR_ENGINECFG_ENGINE, engine); + auto plan = Make(CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR, + CUDNN_ATTR_EXECUTION_PLAN_HANDLE, + handle, + CUDNN_ATTR_EXECUTION_PLAN_ENGINE_CONFIG, + cfg); + err = cudnnBackendFinalize(plan.get()); + if (err == CUDNN_STATUS_NOT_SUPPORTED || err == CUDNN_STATUS_ARCH_MISMATCH) + continue; + if (err != CUDNN_STATUS_SUCCESS) { + LOG(WARNING) << "Unexpected cuDNN status: " << err << ": " << cudnnGetErrorString(err); + continue; + } + auto workspace = GetAttr<int64_t>(plan, CUDNN_ATTR_EXECUTION_PLAN_WORKSPACE_SIZE); + if (workspace > workspace_limit) + continue; + auto numerical = GetSomeAttrs<cudnnBackendNumericalNote_t>( + CUDNN_NUMERICAL_NOTE_TYPE_COUNT, engine, CUDNN_ATTR_ENGINE_NUMERICAL_NOTE); + if (!IsCompatible(numerical, req_numeric, excl_numeric)) + continue; +#if CUDNN_VERSION >= 8200 + auto behavior = GetSomeAttrs<cudnnBackendBehaviorNote_t>( + CUDNN_BEHAVIOR_NOTE_TYPE_COUNT, engine, CUDNN_ATTR_ENGINE_BEHAVIOR_NOTE); + if (!IsCompatible(behavior, req_behavior, excl_behavior)) + continue; +#endif // CUDNN_VERSION >= 8200 + plans.push_back(std::move(plan)); + if (max_workspace) + *max_workspace = std::max(*max_workspace, static_cast<size_t>(workspace)); + } + return plans; +} + +cudnnBackendHeurMode_t HeurMode() { +#if CUDNN_VERSION >= 8100 + auto minor = cudnnGetVersion() / 100 % 10; + int default_mode = minor < 2 ? 
CUDNN_HEUR_MODE_INSTANT : CUDNN_HEUR_MODE_B; +#else + int default_mode = CUDNN_HEUR_MODE_INSTANT; +#endif // CUDNN_VERSION >= 8100 + return static_cast<cudnnBackendHeurMode_t>(dmlc::GetEnv("MXNET_CUDNN_HEUR_MODE", default_mode)); Review comment: The conclusion here is that it's simpler to define the MXNET_CUDNN_HEUR_MODE values as integers that align with the current cudnnBackendHeurMode_t enum definition. If future versions of cudnn*.h invalidate this mapping, we can insert a remapping function at that time if we feel it's important to maintain backward compatibility. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
