This is an automated email from the ASF dual-hosted git repository.
dickjc123 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-mxnet.git
The following commit(s) were added to refs/heads/master by this push:
new 5ab7f64  [2.0] [BACKPORT] of [1.x][FEATURE] CUDA graphs support (#19142) (#20324)
5ab7f64 is described below
commit 5ab7f64c2e99890502aa456798b74523e76f3c17
Author: Dick Carter <[email protected]>
AuthorDate: Fri Mar 18 16:55:07 2022 -0700
[2.0] [BACKPORT] of [1.x][FEATURE] CUDA graphs support (#19142) (#20324)
* [1.x][FEATURE] CUDA graphs support (#19142)
* Initial cherry-pick
* Store NodeAttrs in OpExecutor
* Do not allow stateful operations in CUDA graphs and provide mechanism for marking ops as safe
* Guard against using ops with synchronization
* Cleaning
* Properly guard graphs
* Limit graphs to CUDA 10.2+
* Fix the compilation when graphs are not available
* Guarding the libcuda.so usage behind RTC compilation flag
* Document the env variables
* Add test
* Fix the test
* Use with_environment
* Fix compile and test_cuda_graphs
* Fix lint
* Mark more ops as not CUDA Graphs compatible
* Mark some linalg ops as not CUDA Graphs compatible
* Marked 2 ops CUDA Graphs incompatible due to cpu->gpu copy
* Mark cuDNN Dropout as fully CUDA Graphs compatible. Reenable tests.
* clang-tidy fixes
* More clang-tidy fixes
* Avoid CUDA_CALL(e): improper macro expansion
* Add compile guard to Dropout's FIsCUDAGraphsCompatible def
* Temporarily add '-s' to pytest serial tests
* Fix DropoutOp.dropout_passthrough_ handling for CUDA Graphs
* Adapt test_gluon_gpu.py::test_cuda_graphs for gluon2.0
* Create CUDA Graph 'dot' files if MXNET_CUDA_GRAPHS_DBG_FILE=<file_prefix>
* Fix clang-tidy
* Fix more clang-tidy
* Skip test_np_standard_binary_funcs test of 0-dim array broadcast
* Improve test_rnn_layers_fp{16,32} invocation
* Run test_rnn_layers_fp32 only when cuDNN is present
* Fix potential out-of-bounds write in count_sketch.cu
* Add temp output to debug centos crash
* Mark InstanceNorm and LeakyRELU as not CUDA Graphs compatible
* Ops calling FStatefulCompute* are not CUDA Graphs compatible by default
* Fix clang-tidy
* Revert "Add temp output to debug centos crash"
This reverts commit e013a85ea599fa761cb98762f11feab6e7d74049.
* Quiet 'unused variable' compilation warning
* Trigger CI
* Check of FCreateOpState removed given new check for FStatefulCompute*
* Revert "Temporarily add '-s' to pytest serial tests"
This reverts commit 5a2f847558a7f55790f1ad1fb5ee930b4ad1a3a9.
Co-authored-by: Przemyslaw Tredak <[email protected]>
---
docs/static_site/src/pages/api/faq/env_var.md | 16 +
include/mxnet/op_attr_types.h | 13 +
src/imperative/attach_op_execs_pass.cc | 66 +--
src/imperative/cuda_graphs.h | 593 +++++++++++++++++++++++++
src/imperative/exec_pass.h | 19 +
src/imperative/imperative_utils.h | 21 +-
src/operator/contrib/adamw.cu | 8 +
src/operator/contrib/index_array.cu | 5 +-
src/operator/contrib/multi_lamb.cu | 4 +
src/operator/instance_norm.cu | 7 +-
src/operator/leaky_relu.cu | 7 +-
src/operator/nn/dropout-inl.h | 1 -
src/operator/nn/dropout.cu | 21 +-
src/operator/numpy/linalg/np_eig.cu | 10 +-
src/operator/numpy/linalg/np_eigvals.cu | 10 +-
src/operator/numpy/linalg/np_norm_backward.cu | 6 +
src/operator/numpy/linalg/np_norm_forward.cu | 9 +-
src/operator/numpy/np_boolean_mask_assign.cu | 4 +
src/operator/numpy/np_constraint_check.cu | 2 +
src/operator/numpy/np_matrix_op.cu | 37 +-
src/operator/numpy/np_nonzero_op.cu | 2 +
src/operator/numpy/np_pad_op.cu | 12 +-
src/operator/numpy/np_percentile_op.cu | 5 +-
src/operator/numpy/random/np_bernoulli_op.cu | 5 +-
src/operator/numpy/random/np_exponential_op.cu | 2 +
src/operator/numpy/random/np_gamma_op.cu | 5 +-
src/operator/numpy/random/np_multinomial_op.cu | 2 +
src/operator/numpy/random/np_normal_op.cu | 10 +-
src/operator/numpy/random/np_pareto_op.cu | 5 +-
src/operator/numpy/random/np_power_op.cu | 5 +-
src/operator/numpy/random/np_rayleigh_op.cu | 5 +-
src/operator/numpy/random/np_weibull_op.cu | 5 +-
src/operator/tensor/elemwise_unary_op_basic.cu | 5 +-
src/operator/tensor/indexing_op.cu | 5 +-
src/operator/tensor/la_op.cu | 30 +-
src/operator/tensor/matrix_op.cu | 24 +-
tests/python/gpu/test_gluon_gpu.py | 111 ++++-
37 files changed, 1033 insertions(+), 64 deletions(-)
diff --git a/docs/static_site/src/pages/api/faq/env_var.md b/docs/static_site/src/pages/api/faq/env_var.md
index dad481c..8e12b48 100644
--- a/docs/static_site/src/pages/api/faq/env_var.md
+++ b/docs/static_site/src/pages/api/faq/env_var.md
@@ -170,6 +170,22 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0
* MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN_BWD
- Values: Int ```(default=<value of MXNET_EXEC_BULK_EXEC_MAX_NODE_TRAIN>)```
  - The maximum number of nodes in the subgraph executed in bulk during training (not inference) in the backward pass.
+* MXNET_ENABLE_CUDA_GRAPHS
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If set to `1`, MXNet will utilize CUDA graphs when executing models on the GPU whenever possible.
+  - For CUDA graphs execution, one needs to use either a symbolic model or a Gluon model hybridized with the options `static_alloc` and `static_shape` set to `True`.
+* MXNET_CUDA_GRAPHS_VERBOSE
+  - Values: 0(false) or 1(true) ```(default=0)```
+  - If set to `1`, the CUDA graphs executor will provide information about the graph being captured and executed.
+* MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES
+  - Values: Int ```(default=0)```
+  - The maximum number of log messages generated by the CUDA graphs executor.
+* MXNET_CUDA_GRAPHS_DBG_FILE
+  - Values: String ```(default='', to indicate no debug dot files should be created)```
+  - The file prefix for the '.dot' files created for each graph. The full path is <prefix>-devN-{trn,inf}.<graphId>.dot .
+* MXNET_CUDA_GRAPHS_DBG_FILE_FLAGS
+  - Values: Int ```(default=<most verbose setting - includes all info>)```
+  - A bitmask enabling various types of info in the debug '.dot' files. See cudaGraphDebugDotFlags in the CUDA runtime API documentation for details.
## Control the Data Communication
diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h
index 2fec176..c936d3e 100644
--- a/include/mxnet/op_attr_types.h
+++ b/include/mxnet/op_attr_types.h
@@ -357,6 +357,19 @@ using FNeedCalibrateInput = std::function<std::vector<int>(const NodeAttrs& attr
 */
using FNeedCalibrateOutput = std::function<std::vector<int>(const NodeAttrs& attrs)>;
+#if MXNET_USE_CUDA
+
+/*!
+ * \brief Register a function to determine whether
+ * the operator implementation is compatible
+ * with CUDA graphs. This requires the execution
+ * to stay the same as long as the shape and type
+ * of the input stay the same.
+ */
+using FIsCUDAGraphsCompatible = std::function<bool(const NodeAttrs& attrs, const bool is_train)>;
+
+#endif
+
} // namespace mxnet
#endif // MXNET_OP_ATTR_TYPES_H_
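For orientation, the new attribute is consumed at operator registration time in the .cu files further below. A minimal sketch of such a registration follows; the operator name _my_custom_op and the functor MyCustomOpForward are hypothetical placeholders, while the attribute name and lambda signature follow the definition added above.

NNVM_REGISTER_OP(_my_custom_op)
    // Conservatively declare the GPU implementation unsafe to capture in a CUDA graph,
    // e.g. because it performs a cpu->gpu copy or relies on its own synchronization.
    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
                                       [](const NodeAttrs&, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", MyCustomOpForward<gpu>);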
diff --git a/src/imperative/attach_op_execs_pass.cc b/src/imperative/attach_op_execs_pass.cc
index 4a8c51d..732391f 100644
--- a/src/imperative/attach_op_execs_pass.cc
+++ b/src/imperative/attach_op_execs_pass.cc
@@ -47,8 +47,10 @@ namespace exec {
// FComputeExecutor and FStatefulComputeExecutor inherit from this class
class StorageFallbackOpExecutor : public OpExecutor {
public:
- explicit StorageFallbackOpExecutor(std::vector<uint32_t> mutate_idx)
- : mutate_idx_(std::move(mutate_idx)) {}
+ explicit StorageFallbackOpExecutor(const NodeAttrs& attrs,
+ DispatchMode dispatch_mode,
+ std::vector<uint32_t> mutate_idx)
+ : OpExecutor(attrs, dispatch_mode), mutate_idx_(std::move(mutate_idx)) {}
void Setup() override {
init_ = false;
@@ -146,11 +148,13 @@ class StatefulComputeExecutor : public StorageFallbackOpExecutor {
return state_;
}
- explicit StatefulComputeExecutor(OpStatePtr state,
+ explicit StatefulComputeExecutor(const NodeAttrs& attrs,
+ DispatchMode dispatch_mode,
+ OpStatePtr state,
FStatefulCompute fcompute,
ExecType exec_type,
const std::vector<uint32_t>& mutate_idx)
- : StorageFallbackOpExecutor(mutate_idx),
+ : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
state_(std::move(state)),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}
@@ -168,7 +172,7 @@ class StatefulComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
std::vector<NDArray>* pInArray = &in_array;
-    CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_);
+ CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs);
fcompute_(state_, op_ctx, *pInArray, req, out_array);
}
@@ -186,17 +190,17 @@ class StatefulComputeExExecutor : public OpExecutor {
return state_;
}
- explicit StatefulComputeExExecutor(NodeAttrs attrs,
+ explicit StatefulComputeExExecutor(const NodeAttrs& attrs,
+ DispatchMode dispatch_mode,
OpStatePtr state,
FStatefulComputeEx fcompute,
ExecType exec_type)
- : attrs_(std::move(attrs)),
+ : OpExecutor(attrs, dispatch_mode),
state_(std::move(state)),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}
private:
- NodeAttrs attrs_;
OpStatePtr state_;
FStatefulComputeEx fcompute_;
ExecType exec_type_;
@@ -210,7 +214,7 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
PreFCompute(is_gpu);
- fcompute_(attrs_, op_ctx, in_data_, req, out_data_);
+ fcompute_(attrs, op_ctx, in_data_, req, out_data_);
PostFCompute(is_gpu);
}
@@ -218,17 +222,16 @@ class FComputeExecutor : public StorageFallbackOpExecutor {
return exec_type_;
}
- explicit FComputeExecutor(NodeAttrs attrs,
+ explicit FComputeExecutor(const NodeAttrs& attrs,
+ DispatchMode dispatch_mode,
FCompute fcompute,
ExecType exec_type,
const std::vector<uint32_t>& mutate_idx)
- : StorageFallbackOpExecutor(mutate_idx),
- attrs_(std::move(attrs)),
+ : StorageFallbackOpExecutor(attrs, dispatch_mode, mutate_idx),
fcompute_(std::move(fcompute)),
exec_type_(exec_type) {}
private:
- NodeAttrs attrs_;
FCompute fcompute_;
ExecType exec_type_;
};
@@ -240,8 +243,8 @@ class FComputeExExecutor : public OpExecutor {
op_ctx.run_ctx = rctx;
INVALIDATE_OUTPUTS(out_array, req);
std::vector<NDArray>* pInArray = &in_array;
-    CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs_);
- fcompute_(attrs_, op_ctx, *pInArray, req, out_array);
+ CREATE_DEFAULT_INPUTS_DNNL(in_array, pInArray = &in_array_fallback, attrs);
+ fcompute_(attrs, op_ctx, *pInArray, req, out_array);
}
void Setup() override {}
@@ -250,11 +253,13 @@ class FComputeExExecutor : public OpExecutor {
return exec_type_;
}
-  explicit FComputeExExecutor(NodeAttrs attrs, FComputeEx fcompute, ExecType exec_type)
-      : attrs_(std::move(attrs)), fcompute_(std::move(fcompute)), exec_type_(exec_type) {}
+  explicit FComputeExExecutor(const NodeAttrs& attrs,
+                              DispatchMode dispatch_mode,
+                              FComputeEx fcompute,
+                              ExecType exec_type)
+      : OpExecutor(attrs, dispatch_mode), fcompute_(std::move(fcompute)), exec_type_(exec_type) {}
private:
- NodeAttrs attrs_;
FComputeEx fcompute_;
ExecType exec_type_;
};
@@ -309,14 +314,15 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
        ret[i] = std::make_shared<StatefulComputeExExecutor>(
-            inode.source->attrs, state, fcompute_ex, exec_type);
+            inode.source->attrs, dispatch_modes[i], state, fcompute_ex, exec_type);
      } else {
        FStatefulCompute fcompute =
            common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", vctx[i]);
        CHECK(fcompute != nullptr)
            << "One of FStatefulCompute and FStatefulComputeEx must be registered "
            << "for stateful operator " << op->name;
-        ret[i] = std::make_shared<StatefulComputeExecutor>(state, fcompute, exec_type, mutate_index);
+        ret[i] = std::make_shared<StatefulComputeExecutor>(
+            inode.source->attrs, dispatch_modes[i], state, fcompute, exec_type, mutate_index);
      }
} else if (is_layer_backward.get(op, false)) {
CHECK_GE(inode.control_deps.size(), 1);
@@ -327,25 +333,33 @@ void CreateOpExecs(const Graph& g, OpExecVector* p_ret, OpStateVector* p_state,
          common::GetFCompute<FStatefulComputeEx>(op, "FStatefulComputeEx", vctx[i]);
      // FStatefulComputeEx is dispatched only when dispatch_mode is DispatchMode::kFComputeEx
      if (fcompute_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<StatefulComputeExExecutor>(
-          inode.source->attrs, ret[fwd_id].get()->state(), fcompute_ex, exec_type);
+      ret[i] = std::make_shared<StatefulComputeExExecutor>(inode.source->attrs,
+                                                           dispatch_modes[i],
+                                                           ret[fwd_id].get()->state(),
+                                                           fcompute_ex,
+                                                           exec_type);
    } else {
      FStatefulCompute fcompute =
          common::GetFCompute<FStatefulCompute>(op, "FStatefulCompute", vctx[i]);
      CHECK(fcompute != nullptr)
          << "One of FStatefulCompute and FStatefulComputeEx must be registered "
          << "for stateful operator " << op->name;
-      ret[i] = std::make_shared<StatefulComputeExecutor>(
-          ret[fwd_id].get()->state(), fcompute, exec_type, mutate_index);
+      ret[i] = std::make_shared<StatefulComputeExecutor>(inode.source->attrs,
+                                                         dispatch_modes[i],
+                                                         ret[fwd_id].get()->state(),
+                                                         fcompute,
+                                                         exec_type,
+                                                         mutate_index);
    }
  } else {
    FCompute fcompute   = common::GetFCompute<FCompute>(op, "FCompute", vctx[i]);
    FComputeEx fcomp_ex = common::GetFCompute<FComputeEx>(op, "FComputeEx", vctx[i]);
    if (fcomp_ex != nullptr && dispatch_modes[i] == DispatchMode::kFComputeEx) {
-      ret[i] = std::make_shared<FComputeExExecutor>(inode.source->attrs, fcomp_ex, exec_type);
+      ret[i] = std::make_shared<FComputeExExecutor>(
+          inode.source->attrs, dispatch_modes[i], fcomp_ex, exec_type);
    } else if (fcompute != nullptr) {
      ret[i] = std::make_shared<FComputeExecutor>(
-          inode.source->attrs, fcompute, exec_type, mutate_index);
+          inode.source->attrs, dispatch_modes[i], fcompute, exec_type, mutate_index);
} else {
LOG(INFO) << "Neither FCompute nor FComputeEx registered " << op->name;
}
diff --git a/src/imperative/cuda_graphs.h b/src/imperative/cuda_graphs.h
new file mode 100644
index 0000000..c9e16d8
--- /dev/null
+++ b/src/imperative/cuda_graphs.h
@@ -0,0 +1,593 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_IMPERATIVE_CUDA_GRAPHS_H_
+#define MXNET_IMPERATIVE_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+#include <sstream>
+
+#include "./exec_pass.h"
+#include "../common/cuda/utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10020)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+ std::stringstream ss;
+ if (dims.z != 1)
+ ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+ else if (dims.y != 1)
+ ss << "(" << dims.x << "," << dims.y << ")";
+ else
+ ss << "(" << dims.x << ")";
+ return ss.str();
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+ size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+ if (numNodes == 0)
+ return std::vector<cudaGraphNode_t>();
+ std::vector<cudaGraphNode_t> graphNodes(numNodes);
+ CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+ return graphNodes;
+}
+
+// Create a description of a CUDA Graph node
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+ std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+ CUgraphNode cu_node = node;
+ CUgraphNodeType t;
+ CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+ switch (t) {
+ case CU_GRAPH_NODE_TYPE_KERNEL: {
+ CUDA_KERNEL_NODE_PARAMS kparams;
+ auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+ if (err == CUDA_SUCCESS) {
+ ss << "GPUKernel@" << kparams.func;
+ dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+ dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+ ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+ << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+ ss << "(...";
+ if (kparams.sharedMemBytes != 0)
+ ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+ ss << ")";
+ } else {
+ ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+ }
+ } break;
+ case CU_GRAPH_NODE_TYPE_MEMCPY: {
+ cudaMemcpy3DParms mparams = {};
+ CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+ // If memcpy is seen, return without setting up runnable executor
+ switch (mparams.kind) {
+ case cudaMemcpyHostToHost:
+ ss << "Host->Host ";
+ break;
+ case cudaMemcpyHostToDevice:
+ ss << "Host->Device ";
+ break;
+ case cudaMemcpyDeviceToHost:
+ ss << "Device->Host ";
+ break;
+ case cudaMemcpyDeviceToDevice:
+ ss << "Device->Device ";
+ break;
+ default:
+ break;
+ }
+ ss << "Memcpy";
+ } break;
+ case CU_GRAPH_NODE_TYPE_MEMSET: {
+ cudaMemsetParams mparams = {};
+ CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+ if (mparams.height == 1 && mparams.elementSize == 1) {
+        ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+ << ", count=" << mparams.width << ")";
+ } else {
+ if (mparams.elementSize == 1)
+ ss << "cudaMemset2D";
+ else
+ ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+ ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+ << ", value=" << mparams.value << ", width=" << mparams.width
+ << ", height=" << mparams.height << ")";
+ }
+ } break;
+ case CU_GRAPH_NODE_TYPE_HOST:
+ ss << "Host (executable) node";
+ break;
+ case CU_GRAPH_NODE_TYPE_GRAPH:
+ ss << "Node which executes an embedded graph";
+ break;
+ case CU_GRAPH_NODE_TYPE_EMPTY:
+ ss << "Empty (no-op) node";
+ break;
+ default:
+ ss << "Unknown/Invalid node type " << t;
+ }
+ return ss.str();
+}
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+ void operator()(cudaGraph_t graph) {
+ if (graph != nullptr)
+ CUDA_CALL(cudaGraphDestroy(graph));
+ }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+ void operator()(cudaGraphExec_t graph_exec) {
+ if (graph_exec != nullptr)
+ CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+ }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu,
+ bool verbose,
+ int from_op_idx,
+ int num_ops,
+ bool ops_are_cuda_graph_compatible = true)
+ : from_op_idx_(from_op_idx),
+ num_ops_(num_ops),
+ graph_(nullptr),
+ graph_exec_(nullptr),
+ graph_exec_id_(0) {
+ if (ops_are_cuda_graph_compatible) {
+ MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+ MakeGraphExec(exec_list, rctx);
+ }
+ }
+
+ void Update(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu,
+ bool verbose) {
+ // Current executor should be Runnable with the same parameters
+ CHECK(IsRunnable());
+ MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+ cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+ cudaGraphNode_t error_node;
+ cudaError_t err =
+        cudaGraphExecUpdate(graph_exec_.get(), graph_.get(), &error_node, &update_result);
+ switch (err) {
+ case cudaErrorGraphExecUpdateFailure:
+ MakeGraphExec(exec_list, rctx);
+ break;
+ case cudaSuccess:
+ CHECK_EQ(update_result, cudaGraphExecUpdateSuccess);
+ break;
+ default:
+ // Respond normally to unusual cudaGraphExecUpdate() ret vals
+ CUDA_CALL(err);
+ }
+ }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu) {
+ if (IsRunnable()) {
+ auto s = rctx.get_stream<gpu>();
+ const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+ CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+ } else {
+      // No CUDA Graph could be made for this portion of the OpSegment. Run conventionally.
+ for (int i = 0; i != num_ops_; ++i)
+ exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+ }
+ }
+
+ bool IsRunnable() {
+ return graph_exec_ != nullptr;
+ }
+
+ int NumGraphNodes() {
+ size_t numNodes;
+    CUDA_CALL(cudaGraphGetNodes(graph_.get(), static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+ return numNodes;
+ }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu,
+ bool verbose,
+ int from_op_idx,
+ int num_ops) {
+ auto s = rctx.get_stream<gpu>();
+ const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+ // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing - no actual GPU work is launched.
+ for (int i = 0; i != num_ops; ++i)
+ exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+ cudaGraph_t cuda_graph = nullptr;
+ CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+ graph_.reset(cuda_graph, CudaGraphDeleter());
+
+ if (verbose) {
+ std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+ size_t num_nodes = graph_nodes.size();
+ LOG(INFO) << " Graph has " << num_nodes << " nodes:";
+ for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "  node " << i << " = " << CudaGraphNodeToString(graph_nodes[i]);
+ }
+ }
+ }
+
+  void MakeGraphExec(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+                     const RunContext& rctx) {
+    // Note that this routine is not invoked when a graph executor is merely updated.
+ cudaGraphExec_t cuda_graph_exec;
+ cudaGraphNode_t error_node;
+ char log_buffer[1000];
+
+    CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(), &error_node, log_buffer, 1000));
+ graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+ // At this point we have a CUDA Graph executor
+ static int num_graph_creations = 0;
+ graph_exec_id_ = num_graph_creations++;
+
+    static size_t max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+    if (graph_exec_id_ < max_log_entries) {
+      LOG(INFO) << "Created CUDA graph " << graph_exec_id_;
+      if (num_graph_creations == max_log_entries)
+        LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+ }
+ // Create a .dot file for graph visualization if requested
+    static std::string dotfile_base = dmlc::GetEnv("MXNET_CUDA_GRAPHS_DBG_FILE", std::string());
+    if (dotfile_base.size() > 0) {
+#if CUDA_VERSION >= 11030
+      static int dotfile_flags = dmlc::GetEnv("MXNET_CUDA_GRAPHS_DBG_FILE_FLAGS",
+                                              static_cast<int>(cudaGraphDebugDotFlagsVerbose));
+      std::ostringstream filename;
+      const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+      int dev_id = rctx.ctx.dev_id;
+      filename << dotfile_base << "-"
+               << "dev" << dev_id << "-" << (is_train ? "trn" : "inf") << "-" << graph_exec_id_
+               << ".dot";
+      CUDA_CALL(cudaGraphDebugDotPrint(graph_.get(), filename.str().c_str(), dotfile_flags));
+#else
+ [[maybe_unused]] static bool dot_file_unsupported = []() { // NOLINT
+        LOG(INFO) << "MXNET_CUDA_GRAPHS_DBG_FILE setting ignored- requires CUDA version >= 11.3";
+ return true;
+ }();
+#endif // CUDA_VERSION >= 11030
+ }
+ }
+
+ int from_op_idx_;
+ int num_ops_;
+  using cudaGraphStruct_t     = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+ std::shared_ptr<cudaGraphStruct_t> graph_;
+ std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+ size_t graph_exec_id_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+ std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+ bool has_been_run_conventionally = false;
+ std::vector<void*> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext. If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+ cudaStream_t cu_s;
+ bool is_train;
+ // overload '<' so CudaGraphCacheKey can be used as a std::map key
+ bool operator<(const CudaGraphCacheKey& other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+ }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ bool is_gpu,
+ const char* opr_names)
+ : verbose_(false), is_enabled_(false) {
+ opr_names_ = opr_names ? std::string(opr_names) : std::string();
+ if (is_gpu) {
+ is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+ verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+ SetTempSpaces(exec_list);
+ }
+ }
+
+ void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+ if (!is_gpu || !is_enabled_) {
+ // Run all opr in the sub-graph
+ exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+ return;
+ }
+
+ // Also if we're in a warm-up period where tempspace pointers are likely
+ // to change, run normally and return
+ auto s = rctx.get_stream<gpu>();
+ const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+ // This may create a default-initialized new entry.
+ auto& cuda_graph_info = cache_[key];
+ if (!cuda_graph_info.has_been_run_conventionally) {
+ // Run all opr in the sub-graph
+ exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+ cuda_graph_info.has_been_run_conventionally = true;
+ return;
+ }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    // (there might be more than one executor if some ops in the segment are not capturable)
+ auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'.
+ if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+ cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors. Non-runnable executors launch their ops conventionally.
+ for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+ if (subseg_exec.IsRunnable())
+ subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+ }
+ } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+ // No executors exist yet, so create them.
+ if (verbose_)
+ LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+ // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+ for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+ int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+ if (OpOK(exec_list[last_op_idx]))
+ num_good_ops++;
+ else
+ break;
+ }
+ if (num_good_ops > 0) {
+ CreateSubExecOverRegion(exec_list,
+ rctx,
+ is_gpu,
+ first_op_idx,
+ first_op_idx + num_good_ops,
+ &cuda_graph_info.cuda_graph_subseg_execs);
+ first_op_idx += num_good_ops;
+ }
+ if (first_op_idx != exec_list.size()) {
+ // We had to have hit an op that was not OK.
+ if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+ << " of op segment " << opr_names_;
+ }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+ cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+ first_op_idx++;
+ }
+ }
+      // During graph capture, the ops may be asking for the tempworkspace. This should
+      // not alter the base pointers, since this op seg has been executed before on this
+ // stream (i.e. on this gpu worker). Safest to double-check this though.
+ auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+ if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+ cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+ }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+ if (verbose_) {
+ int runnable_execs = 0;
+ int bypassed_ops = 0;
+ for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+ if (subseg_exec.IsRunnable()) {
+          LOG(INFO) << "Launching captured graph with " << subseg_exec.NumGraphNodes() << " nodes.";
+ runnable_execs++;
+ } else {
+ bypassed_ops++;
+ }
+ }
+ if (bypassed_ops > 0)
+        LOG(INFO) << " (bypassing " << bypassed_ops << " un-capturable ops)";
+ }
+ for (auto& subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+ subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+ }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx). If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list,
+ const RunContext& rctx,
+ bool is_gpu,
+ size_t from_op_idx,
+ size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec>* cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+ int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+ if (full_opseg.IsRunnable()) {
+ cuda_graph_subseg_execs->push_back(full_opseg);
+ } else {
+ if (verbose_)
+ LOG(INFO) << " Graph was not runnable- creating op sub-segments...";
+ // Enter fall-back approach to making many sub-execs
+ for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx;) {
+ int num_good_ops = 0;
+      for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+        CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+ if (single_opseg.IsRunnable())
+ num_good_ops++;
+ // Is it time to create a subseg exec from accumulated good ops?
+        if (num_good_ops > 0 && (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+          if (verbose_)
+            LOG(INFO) << "Capturing CUDA graph of op sub segment[" << first_op_idx << ":"
+ << (first_op_idx + num_good_ops - 1) << "]"
+ << " of op segment " << opr_names_;
+ CudaGraphsSubSegExec good_opseg(
+ exec_list, rctx, is_gpu, verbose_, first_op_idx, num_good_ops);
+          CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+ cuda_graph_subseg_execs->push_back(good_opseg);
+ first_op_idx += num_good_ops;
+ }
+        // If the last single op was not runnable, use the exec to handle that op conventionally
+ if (!single_opseg.IsRunnable()) {
+ if (verbose_) {
+            LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+ << " of op segment " << opr_names_;
+            // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+            CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+ }
+ cuda_graph_subseg_execs->push_back(single_opseg);
+ first_op_idx++;
+ break;
+ }
+ }
+ }
+ }
+ }
+
+ // Is the Op OK to make part of a CUDA Graph?
+ bool OpOK(const std::shared_ptr<exec::OpExecutor>& exec) {
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    static auto& fcompute_ex = Op::GetAttr<FComputeEx>("FComputeEx<gpu>");
+    static auto& fstatefulcompute = Op::GetAttr<FStatefulCompute>("FStatefulCompute<gpu>");
+    static auto& fstatefulcompute_ex = Op::GetAttr<FStatefulComputeEx>("FStatefulComputeEx<gpu>");
+ const auto& attrs = exec->attrs;
+ if (attrs.op != nullptr) {
+ const auto f = fgraphcompatible.get(attrs.op, nullptr);
+ if (f != nullptr) {
+ return f(attrs, exec->op_ctx.is_train);
+ }
+ if (fstatefulcompute.get(attrs.op, nullptr) != nullptr ||
+ fstatefulcompute_ex.get(attrs.op, nullptr) != nullptr) {
+ if (verbose_) {
+          LOG(INFO) << "Omitting stateful operator " << attrs.op->name << " from CUDA graph.";
+ }
+ return false;
+ }
+ if ((fcompute_ex.get(attrs.op, nullptr) != nullptr &&
+ exec->dispatch_mode == DispatchMode::kFComputeEx) ||
+ exec->dispatch_mode == DispatchMode::kFComputeFallback) {
+ if (verbose_) {
+ LOG(INFO) << "Omitting operator " << attrs.op->name
+ << " from CUDA graph due to dispatch mode "
+ << static_cast<int>(exec->dispatch_mode);
+ }
+ return false;
+ }
+ }
+ for (auto& resource : exec->op_ctx.requested) {
+ if (!(resource.req.type == ResourceRequest::kTempSpace)) {
+ if (verbose_) {
+ LOG(INFO) << "Omitting operator " << attrs.op->name
+ << " from CUDA graph due to using the resource type "
+ << static_cast<int>(resource.req.type);
+ }
+ return false;
+ }
+ }
+ return true;
+ }
+
+  // Determine Tempspaces used by ops. Other resource uses disable CUDA Graphs.
+  void SetTempSpaces(const std::vector<std::shared_ptr<exec::OpExecutor>>& exec_list) {
+ // Gather info about the ops use of TempSpace.
+ if (is_enabled_) {
+ std::set<Resource*> tempspaces_set;
+ for (auto& exec : exec_list) {
+ for (auto& resource : exec->op_ctx.requested) {
+ if (resource.req.type == ResourceRequest::kTempSpace) {
+ tempspaces_set.insert(&resource);
+ }
+ }
+ }
+ tempspaces_.assign(tempspaces_set.begin(), tempspaces_set.end());
+ }
+ }
+
+ // Return the addresses of the gpu TempSpace areas
+ std::vector<void*> GetGPUTempspacePtrs(mshadow::Stream<gpu>* s) {
+ std::vector<void*> ret;
+ for (const auto& resource : tempspaces_) {
+      // Ask for minimal allocation to get base pointer without increasing the size
+      auto* base_ptr = resource->get_space_typed<gpu, 1, char>(mshadow::Shape1(1), s).dptr_;
+ ret.push_back(static_cast<void*>(base_ptr));
+ }
+ return ret;
+ }
+
+ CudaGraphCache cache_;
+ std::vector<Resource*> tempspaces_;
+ std::string opr_names_;
+ bool verbose_;
+ bool is_enabled_;
+};
+
+} // namespace cuda_graphs
+} // namespace mxnet
+
+#endif // CUDA_GRAPHS_AVAILABLE
+
+#endif // MXNET_IMPERATIVE_CUDA_GRAPHS_H_
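To make the control flow above easier to follow, here is a minimal standalone sketch (not MXNet code, error handling reduced to asserts) of the raw CUDA runtime sequence that CudaGraphsSubSegExec wraps: begin stream capture, record work, end capture, instantiate once, then relaunch the executable graph cheaply. It follows the 5-argument cudaGraphInstantiate form used above (CUDA 10.2-11.x era API); memsets are used as stand-ins for captured kernels.

#include <cassert>
#include <cuda_runtime.h>

int main() {
  const size_t bytes = 1 << 20;
  void* d_buf = nullptr;
  assert(cudaMalloc(&d_buf, bytes) == cudaSuccess);
  cudaStream_t stream;
  assert(cudaStreamCreate(&stream) == cudaSuccess);

  // 1. Capture: the async calls below are recorded into a graph, not executed.
  cudaGraph_t graph = nullptr;
  assert(cudaStreamBeginCapture(stream, cudaStreamCaptureModeThreadLocal) == cudaSuccess);
  assert(cudaMemsetAsync(d_buf, 0, bytes, stream) == cudaSuccess);
  assert(cudaMemsetAsync(d_buf, 1, bytes / 2, stream) == cudaSuccess);
  assert(cudaStreamEndCapture(stream, &graph) == cudaSuccess);

  // 2. Instantiate once into an executable graph (cf. MakeGraphExec above).
  cudaGraphExec_t graph_exec = nullptr;
  assert(cudaGraphInstantiate(&graph_exec, graph, nullptr, nullptr, 0) == cudaSuccess);

  // 3. Replay the whole captured segment with one launch call per iteration.
  for (int iter = 0; iter < 10; ++iter)
    assert(cudaGraphLaunch(graph_exec, stream) == cudaSuccess);
  assert(cudaStreamSynchronize(stream) == cudaSuccess);

  // In cuda_graphs.h this cleanup is handled by CudaGraphDeleter/CudaGraphExecDeleter (RAII).
  cudaGraphExecDestroy(graph_exec);
  cudaGraphDestroy(graph);
  cudaStreamDestroy(stream);
  cudaFree(d_buf);
  return 0;
}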
diff --git a/src/imperative/exec_pass.h b/src/imperative/exec_pass.h
index 7667d97..02fa967 100644
--- a/src/imperative/exec_pass.h
+++ b/src/imperative/exec_pass.h
@@ -30,6 +30,7 @@
#include <mxnet/graph_attr_types.h>
#include <nnvm/graph.h>
#include <nnvm/graph_attr_types.h>
+#include <utility>
#include <vector>
#include <memory>
#include <string>
@@ -84,6 +85,13 @@ class OpExecutor {
std::vector<OpReqType> req;
/*! \brief runtime op context, contains allocated resources */
OpContext op_ctx;
+ /*! \brief attributes of the node */
+ NodeAttrs attrs;
+ /*! \brief dispatch mode of the executor */
+ DispatchMode dispatch_mode;
+
+ explicit OpExecutor(NodeAttrs attrs, DispatchMode dispatch_mode)
+ : attrs(std::move(attrs)), dispatch_mode(dispatch_mode) {}
/*! \brief virtual destructor */
virtual ~OpExecutor() {}
/*!
@@ -98,6 +106,17 @@ class OpExecutor {
* \param rctx The runtime context passed in by environment.
*/
virtual void Run(RunContext rctx, bool is_gpu) = 0;
+ /*!
+   * \brief run the operators of a vector of execs, given runtime context on device.
+ * This function call does not synchronize the stream.
+ * \param rctx The runtime context passed in by environment.
+ */
+ static void RunAll(const std::vector<std::shared_ptr<OpExecutor>>& execs,
+ RunContext rctx,
+ bool is_gpu) {
+ for (auto& exec : execs)
+ exec->Run(rctx, is_gpu);
+ }
/*! \return the execution type */
virtual ExecType exec_type() const = 0;
/*! \return return engine variable for operator states */
diff --git a/src/imperative/imperative_utils.h b/src/imperative/imperative_utils.h
index ce1a60f..7f90528 100644
--- a/src/imperative/imperative_utils.h
+++ b/src/imperative/imperative_utils.h
@@ -27,6 +27,8 @@
#include <utility>
#include <vector>
+#include "./exec_pass.h"
+#include "./cuda_graphs.h"
#include "../c_api/c_api_common.h"
#include "../common/exec_utils.h"
#include "../common/utils.h"
@@ -1248,6 +1250,21 @@ inline Engine::OprHandle CreateEngineOp(
bool is_gpu = default_ctx.dev_mask() == gpu::kDevMask;
  bool is_async = execs.size() > 1 ? false : execs[0]->exec_type() == ExecType::kAsync;
+#if CUDA_GRAPHS_AVAILABLE
+ // Provide initialized `cuda_graphs_exec`, which when captured
+ // by exec_fun, acts like a static variable inside the mutable closure.
+ cuda_graphs::CudaGraphsExec cuda_graphs_exec(execs, is_gpu, opr_names);
+ auto exec_fun = [cuda_graphs_exec, execs, is_async, is_gpu](
+ RunContext ctx,
+ Engine::CallbackOnStart on_start,
+ Engine::CallbackOnComplete on_complete) mutable {
+ on_start();
+ if (is_async) {
+ execs[0]->op_ctx.async_on_complete = on_complete;
+ }
+ // Run all opr in the sub-graph with CUDA graphs executor if possible
+ cuda_graphs_exec.RunAll(execs, ctx, is_gpu);
+#else
  auto exec_fun = [execs, is_async, is_gpu](RunContext ctx,
                                            Engine::CallbackOnStart on_start,
                                            Engine::CallbackOnComplete on_complete) {
@@ -1255,8 +1272,8 @@ inline Engine::OprHandle CreateEngineOp(
if (is_async) {
execs[0]->op_ctx.async_on_complete = on_complete;
}
- for (const auto& exec : execs)
- exec->Run(ctx, is_gpu);
+ exec::OpExecutor::RunAll(execs, ctx, is_gpu);
+#endif
// call on complete only if it is async op
if (!is_async) {
if (is_gpu) {
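The comment above about `cuda_graphs_exec` acting like a static variable inside the mutable closure can be illustrated with a small standalone C++ sketch (not MXNet code); the CallCache struct is a hypothetical stand-in for cuda_graphs::CudaGraphsExec and its per-closure graph cache.

#include <iostream>

struct CallCache {
  int calls = 0;  // stands in for the per-closure CUDA graph cache
};

int main() {
  CallCache cache;
  // 'mutable' lets the lambda modify its own captured copy of 'cache',
  // so that copy persists across invocations of this particular closure.
  auto exec_fun = [cache]() mutable {
    ++cache.calls;
    std::cout << "invocation " << cache.calls << " of this closure\n";
  };
  exec_fun();  // invocation 1
  exec_fun();  // invocation 2 - state lives inside the closure, not in 'cache'
  std::cout << "outer cache.calls is still " << cache.calls << "\n";  // prints 0
  return 0;
}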
diff --git a/src/operator/contrib/adamw.cu b/src/operator/contrib/adamw.cu
index 8023788..0b24790 100644
--- a/src/operator/contrib/adamw.cu
+++ b/src/operator/contrib/adamw.cu
@@ -45,15 +45,23 @@ void GetScaleFloat<gpu>(mshadow::Stream<gpu>* s, const TBlob& scale_blob, float*
})}
NNVM_REGISTER_OP(_adamw_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, AdamWUpdate<gpu>>);
NNVM_REGISTER_OP(_mp_adamw_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", adamw::MPUpdate<gpu, MPAdamWUpdate<gpu>>);
NNVM_REGISTER_OP(_multi_adamw_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, false>);
NNVM_REGISTER_OP(_multi_mp_adamw_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", adamw::multiMPUpdate<gpu, true>);
} // namespace adamw
diff --git a/src/operator/contrib/index_array.cu
b/src/operator/contrib/index_array.cu
index 482cbf6..3702fed 100644
--- a/src/operator/contrib/index_array.cu
+++ b/src/operator/contrib/index_array.cu
@@ -82,7 +82,10 @@ void IndexArrayForwardGPU(const nnvm::NodeAttrs& attrs,
}
}
-NNVM_REGISTER_OP(_contrib_index_array).set_attr<FCompute>("FCompute<gpu>", IndexArrayForwardGPU);
+NNVM_REGISTER_OP(_contrib_index_array)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", IndexArrayForwardGPU);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/contrib/multi_lamb.cu
b/src/operator/contrib/multi_lamb.cu
index 118ec63..c6bedfc 100644
--- a/src/operator/contrib/multi_lamb.cu
+++ b/src/operator/contrib/multi_lamb.cu
@@ -268,9 +268,13 @@ void CallKernel2(Stream<gpu>* s,
}
NNVM_REGISTER_OP(_multi_lamb_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", MultiLAMBUpdate<gpu, false>);
NNVM_REGISTER_OP(_multi_mp_lamb_update)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", MultiLAMBUpdate<gpu, true>);
} // namespace op
diff --git a/src/operator/instance_norm.cu b/src/operator/instance_norm.cu
index ca45dbb..ce11fbf 100644
--- a/src/operator/instance_norm.cu
+++ b/src/operator/instance_norm.cu
@@ -28,9 +28,14 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(InstanceNorm).set_attr<FCompute>("FCompute<gpu>", InstanceNormForward<gpu>);
+NNVM_REGISTER_OP(InstanceNorm)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", InstanceNormForward<gpu>);
NNVM_REGISTER_OP(_backward_instance_norm)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", InstanceNormBackward<gpu>);
} // namespace op
diff --git a/src/operator/leaky_relu.cu b/src/operator/leaky_relu.cu
index d461949..82ec59b 100644
--- a/src/operator/leaky_relu.cu
+++ b/src/operator/leaky_relu.cu
@@ -28,9 +28,14 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(LeakyReLU).set_attr<FCompute>("FCompute<gpu>", LeakyReLUCompute<gpu>);
+NNVM_REGISTER_OP(LeakyReLU)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", LeakyReLUCompute<gpu>);
NNVM_REGISTER_OP(_backward_LeakyReLU)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
    .set_attr<FCompute>("FCompute<gpu>", LeakyReLUGradCompute<gpu>);
} // namespace op
diff --git a/src/operator/nn/dropout-inl.h b/src/operator/nn/dropout-inl.h
index 18f94cf..0baa8e4 100644
--- a/src/operator/nn/dropout-inl.h
+++ b/src/operator/nn/dropout-inl.h
@@ -437,7 +437,6 @@ class DropoutOp {
using namespace mshadow::expr;
Stream<xpu>* s = ctx.get_stream<xpu>();
if (!this->dropout_passthrough_) {
- this->dropout_passthrough_ = true;
const TBlob& gdata = in_grad[dropout::kData];
const TBlob& grad = out_grad[dropout::kOut];
const TBlob& mask = out_data[dropout::kMask];
diff --git a/src/operator/nn/dropout.cu b/src/operator/nn/dropout.cu
index d6c97f5..414b82e 100644
--- a/src/operator/nn/dropout.cu
+++ b/src/operator/nn/dropout.cu
@@ -28,7 +28,26 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(Dropout).set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);
+NNVM_REGISTER_OP(Dropout)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool is_train) {
+                                         // Dropout is a passthrough during inference for all impls
+                                         if (!is_train)
+                                           return true;
+#if MXNET_USE_CUDNN_DROPOUT
+                                         // cuDNN impl is compatible during training as well
+                                         const DropoutParam& param =
+                                             nnvm::get<DropoutParam>(attrs.parsed);
+                                         real_t pkeep = 1.0f - param.p;
+                                         bool cudnn_off =
+                                             param.cudnn_off && param.cudnn_off.value();
+                                         bool cudnn_available = pkeep > 0 && !cudnn_off;
+                                         return cudnn_available;
+#else
+                                         return false;
+#endif  // MXNET_USE_CUDNN_DROPOUT
+                                       })
+    .set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutCompute<gpu>);
NNVM_REGISTER_OP(_backward_Dropout)
    .set_attr<FStatefulCompute>("FStatefulCompute<gpu>", DropoutGradCompute<gpu>);
diff --git a/src/operator/numpy/linalg/np_eig.cu
b/src/operator/numpy/linalg/np_eig.cu
index 1f89106..a217b6d 100644
--- a/src/operator/numpy/linalg/np_eig.cu
+++ b/src/operator/numpy/linalg/np_eig.cu
@@ -28,11 +28,17 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_eig).set_attr<FCompute>("FCompute<gpu>", EigOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_eig)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", EigOpForward<gpu>);
#if MXNET_USE_CUSOLVER == 1
-NNVM_REGISTER_OP(_npi_eigh).set_attr<FCompute>("FCompute<gpu>", EighOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_eigh)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", EighOpForward<gpu>);
#endif
diff --git a/src/operator/numpy/linalg/np_eigvals.cu
b/src/operator/numpy/linalg/np_eigvals.cu
index dc03805..be00d8c 100644
--- a/src/operator/numpy/linalg/np_eigvals.cu
+++ b/src/operator/numpy/linalg/np_eigvals.cu
@@ -28,11 +28,17 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_eigvals).set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_eigvals)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);
#if MXNET_USE_CUSOLVER == 1
-NNVM_REGISTER_OP(_npi_eigvalsh).set_attr<FCompute>("FCompute<gpu>", EigvalshOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_eigvalsh)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", EigvalshOpForward<gpu>);
#endif
diff --git a/src/operator/numpy/linalg/np_norm_backward.cu
b/src/operator/numpy/linalg/np_norm_backward.cu
index 24d8783..23a021d 100644
--- a/src/operator/numpy/linalg/np_norm_backward.cu
+++ b/src/operator/numpy/linalg/np_norm_backward.cu
@@ -26,6 +26,12 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_backward_npi_norm)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs& attrs, const bool) {
+                                         const NumpyNormParam& param =
+                                             nnvm::get<NumpyNormParam>(attrs.parsed);
+ return param.axis.value().ndim() == 2;
+ })
.set_attr<FCompute>("FCompute<gpu>", NumpyNormComputeBackward<gpu>);
} // namespace op
diff --git a/src/operator/numpy/linalg/np_norm_forward.cu
b/src/operator/numpy/linalg/np_norm_forward.cu
index 8926763..7399727 100644
--- a/src/operator/numpy/linalg/np_norm_forward.cu
+++ b/src/operator/numpy/linalg/np_norm_forward.cu
@@ -25,7 +25,14 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_norm).set_attr<FCompute>("FCompute<gpu>", NumpyNormComputeForward<gpu>);
+NNVM_REGISTER_OP(_npi_norm)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs& attrs, const bool) {
+                                         const NumpyNormParam& param =
+                                             nnvm::get<NumpyNormParam>(attrs.parsed);
+ return param.axis.value().ndim() == 2;
+ })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyNormComputeForward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_boolean_mask_assign.cu
b/src/operator/numpy/np_boolean_mask_assign.cu
index 10f8612..216e8ff 100644
--- a/src/operator/numpy/np_boolean_mask_assign.cu
+++ b/src/operator/numpy/np_boolean_mask_assign.cu
@@ -273,9 +273,13 @@ void NumpyBooleanAssignForwardGPU(const nnvm::NodeAttrs& attrs,
}
NNVM_REGISTER_OP(_npi_boolean_mask_assign_scalar)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", NumpyBooleanAssignForwardGPU);
NNVM_REGISTER_OP(_npi_boolean_mask_assign_tensor)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", NumpyBooleanAssignForwardGPU);
} // namespace op
diff --git a/src/operator/numpy/np_constraint_check.cu
b/src/operator/numpy/np_constraint_check.cu
index 04a0a36..26a5f01 100644
--- a/src/operator/numpy/np_constraint_check.cu
+++ b/src/operator/numpy/np_constraint_check.cu
@@ -38,6 +38,8 @@ void GetReduceOutput<gpu>(mshadow::Stream<gpu>* s, const TBlob& output_blob, boo
}
NNVM_REGISTER_OP(_npx_constraint_check)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", ConstraintCheckForward<gpu>);
} // namespace op
diff --git a/src/operator/numpy/np_matrix_op.cu
b/src/operator/numpy/np_matrix_op.cu
index f207814..2785898 100644
--- a/src/operator/numpy/np_matrix_op.cu
+++ b/src/operator/numpy/np_matrix_op.cu
@@ -52,9 +52,15 @@ NNVM_REGISTER_OP(_npi_column_stack)
NNVM_REGISTER_OP(_backward_np_column_stack)
.set_attr<FCompute>("FCompute<gpu>", NumpyColumnStackBackward<gpu>);
-NNVM_REGISTER_OP(_npi_tril_indices).set_attr<FCompute>("FCompute<gpu>", TrilindicesOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_tril_indices)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", TrilindicesOpForward<gpu>);
-NNVM_REGISTER_OP(_npi_roll).set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);
+NNVM_REGISTER_OP(_npi_roll)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyRollCompute<gpu>);
template <>
void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
@@ -92,9 +98,15 @@ void NumpyFlipForwardImpl<gpu>(const OpContext& ctx,
});
}
-NNVM_REGISTER_OP(_npi_flip).set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
+NNVM_REGISTER_OP(_npi_flip)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
-NNVM_REGISTER_OP(_backward_npi_flip).set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
+NNVM_REGISTER_OP(_backward_npi_flip)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyFlipForward<gpu>);
NNVM_REGISTER_OP(_npi_moveaxis).set_attr<FCompute>("FCompute<gpu>", NumpyMoveaxisCompute<gpu>);
@@ -103,7 +115,22 @@ NNVM_REGISTER_OP(_npi_rollaxis).set_attr<FCompute>("FCompute<gpu>", NumpyRollaxi
NNVM_REGISTER_OP(_npi_rollaxis_backward)
.set_attr<FCompute>("FCompute<gpu>", NumpyRollaxisBackward<gpu>);
-NNVM_REGISTER_OP(_npi_rot90).set_attr<FCompute>("FCompute<gpu>", NumpyRot90Compute<gpu>);
+NNVM_REGISTER_OP(_npi_rot90)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) {
+                                         const auto& param =
+                                             nnvm::get<NumpyRot90Param>(attrs.parsed);
+                                         // Should track code in NumpyRot90Compute()
+                                         int real_k(param.k);
+                                         real_k = real_k % 4;
+                                         if (real_k < 0) {
+                                           real_k += 4;
+                                         }
+                                         // Avoid NumpyRot90ComputeFlipIml(),
+                                         // which uses mshadow::Copy()
+                                         return real_k != 2;
+                                       })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyRot90Compute<gpu>);
NNVM_REGISTER_OP(_npi_hsplit).set_attr<FCompute>("FCompute<gpu>", HSplitOpForward<gpu>);
diff --git a/src/operator/numpy/np_nonzero_op.cu
b/src/operator/numpy/np_nonzero_op.cu
index 1499030..597331e 100644
--- a/src/operator/numpy/np_nonzero_op.cu
+++ b/src/operator/numpy/np_nonzero_op.cu
@@ -115,6 +115,8 @@ NNVM_REGISTER_OP(_npx_nonzero)
[](const NodeAttrs& attrs) {
return
std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
})
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs& attrs, const bool) { return false; })
.set_attr<FComputeEx>("FComputeEx<gpu>", NonzeroForwardGPU);
} // namespace op
diff --git a/src/operator/numpy/np_pad_op.cu b/src/operator/numpy/np_pad_op.cu
index 01a7035..1b9f4f4 100644
--- a/src/operator/numpy/np_pad_op.cu
+++ b/src/operator/numpy/np_pad_op.cu
@@ -28,9 +28,17 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_pad).set_attr<FCompute>("FCompute<gpu>", NumpyPadOpForward<gpu>);
+NNVM_REGISTER_OP(_npi_pad)
+    // Incompatible due to Copy(xpu_tensor, cpu_tensor) in NumpyPadOpForward
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyPadOpForward<gpu>);
-NNVM_REGISTER_OP(_backward_npi_pad).set_attr<FCompute>("FCompute<gpu>", NumpyPadOpBackward<gpu>);
+NNVM_REGISTER_OP(_backward_npi_pad)
+    // Incompatible due to Copy(xpu_tensor, cpu_tensor) in NumpyPadOpBackward
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyPadOpBackward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/np_percentile_op.cu
b/src/operator/numpy/np_percentile_op.cu
index 13d076d..2dcc829 100644
--- a/src/operator/numpy/np_percentile_op.cu
+++ b/src/operator/numpy/np_percentile_op.cu
@@ -52,7 +52,10 @@ bool CheckInvalidInput(mshadow::Stream<gpu>* s,
return is_valid == 0;
}
-NNVM_REGISTER_OP(_npi_percentile).set_attr<FCompute>("FCompute<gpu>", NumpyPercentileForward<gpu>);
+NNVM_REGISTER_OP(_npi_percentile)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyPercentileForward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/random/np_bernoulli_op.cu
b/src/operator/numpy/random/np_bernoulli_op.cu
index 8cdceb5..eee89c1 100644
--- a/src/operator/numpy/random/np_bernoulli_op.cu
+++ b/src/operator/numpy/random/np_bernoulli_op.cu
@@ -27,7 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_bernoulli).set_attr<FCompute>("FCompute<gpu>", NumpyBernoulliForward<gpu>);
+NNVM_REGISTER_OP(_npi_bernoulli)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyBernoulliForward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/random/np_exponential_op.cu
b/src/operator/numpy/random/np_exponential_op.cu
index 60809fb..8ad7386 100644
--- a/src/operator/numpy/random/np_exponential_op.cu
+++ b/src/operator/numpy/random/np_exponential_op.cu
@@ -28,6 +28,8 @@ namespace mxnet {
namespace op {
NNVM_REGISTER_OP(_npi_exponential)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", NumpyExponentialForward<gpu>);
NNVM_REGISTER_OP(_backward_broadcast_exponential)
diff --git a/src/operator/numpy/random/np_gamma_op.cu
b/src/operator/numpy/random/np_gamma_op.cu
index 7e3cabc..0191fd5 100644
--- a/src/operator/numpy/random/np_gamma_op.cu
+++ b/src/operator/numpy/random/np_gamma_op.cu
@@ -28,7 +28,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_gamma).set_attr<FCompute>("FCompute<gpu>", NumpyGammaForward<gpu, double>);
+NNVM_REGISTER_OP(_npi_gamma)
+    .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
+    .set_attr<FCompute>("FCompute<gpu>", NumpyGammaForward<gpu, double>);
NNVM_REGISTER_OP(_backward_gamma_sample).set_attr<FCompute>("FCompute<gpu>", NumpyGammaGrad<gpu>);
diff --git a/src/operator/numpy/random/np_multinomial_op.cu
b/src/operator/numpy/random/np_multinomial_op.cu
index 083b410..575ad08 100644
--- a/src/operator/numpy/random/np_multinomial_op.cu
+++ b/src/operator/numpy/random/np_multinomial_op.cu
@@ -41,6 +41,8 @@ void CheckPvalGPU(const OpContext& ctx, DType* input, int
prob_length) {
}
NNVM_REGISTER_OP(_npi_multinomial)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+                                       [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", NumpyMultinomialForward<gpu>);
} // namespace op
diff --git a/src/operator/numpy/random/np_normal_op.cu b/src/operator/numpy/random/np_normal_op.cu
index db87461..525a0e1 100644
--- a/src/operator/numpy/random/np_normal_op.cu
+++ b/src/operator/numpy/random/np_normal_op.cu
@@ -27,12 +27,18 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_normal).set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
+NNVM_REGISTER_OP(_npi_normal)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
NNVM_REGISTER_OP(_backward_broadcast_normal)
.set_attr<FCompute>("FCompute<gpu>", NormalReparamBackward<gpu>);
-NNVM_REGISTER_OP(_npi_normal_n).set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
+NNVM_REGISTER_OP(_npi_normal_n)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyNormalForward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/random/np_pareto_op.cu b/src/operator/numpy/random/np_pareto_op.cu
index 7618d28..82fcd1f 100644
--- a/src/operator/numpy/random/np_pareto_op.cu
+++ b/src/operator/numpy/random/np_pareto_op.cu
@@ -27,7 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_pareto).set_attr<FCompute>("FCompute<gpu>", NumpyParetoForward<gpu>);
+NNVM_REGISTER_OP(_npi_pareto)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyParetoForward<gpu>);
NNVM_REGISTER_OP(_backward_broadcast_pareto)
.set_attr<FCompute>("FCompute<gpu>", ParetoReparamBackward<gpu>);
diff --git a/src/operator/numpy/random/np_power_op.cu b/src/operator/numpy/random/np_power_op.cu
index 2904420..f7a6686 100644
--- a/src/operator/numpy/random/np_power_op.cu
+++ b/src/operator/numpy/random/np_power_op.cu
@@ -27,7 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_powerd).set_attr<FCompute>("FCompute<gpu>", NumpyPowerForward<gpu>);
+NNVM_REGISTER_OP(_npi_powerd)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyPowerForward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/src/operator/numpy/random/np_rayleigh_op.cu b/src/operator/numpy/random/np_rayleigh_op.cu
index 586f174..f67a2fe 100644
--- a/src/operator/numpy/random/np_rayleigh_op.cu
+++ b/src/operator/numpy/random/np_rayleigh_op.cu
@@ -27,7 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_rayleigh).set_attr<FCompute>("FCompute<gpu>", NumpyRayleighForward<gpu>);
+NNVM_REGISTER_OP(_npi_rayleigh)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyRayleighForward<gpu>);
NNVM_REGISTER_OP(_backward_broadcast_rayleigh)
.set_attr<FCompute>("FCompute<gpu>", RayleighReparamBackward<gpu>);
diff --git a/src/operator/numpy/random/np_weibull_op.cu b/src/operator/numpy/random/np_weibull_op.cu
index 658be16..4495bab 100644
--- a/src/operator/numpy/random/np_weibull_op.cu
+++ b/src/operator/numpy/random/np_weibull_op.cu
@@ -27,7 +27,10 @@
namespace mxnet {
namespace op {
-NNVM_REGISTER_OP(_npi_weibull).set_attr<FCompute>("FCompute<gpu>", NumpyWeibullForward<gpu>);
+NNVM_REGISTER_OP(_npi_weibull)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", NumpyWeibullForward<gpu>);
NNVM_REGISTER_OP(_backward_broadcast_weibull)
.set_attr<FCompute>("FCompute<gpu>", WeibullReparamBackward<gpu>);
diff --git a/src/operator/tensor/elemwise_unary_op_basic.cu b/src/operator/tensor/elemwise_unary_op_basic.cu
index 7fdc047..5099301 100644
--- a/src/operator/tensor/elemwise_unary_op_basic.cu
+++ b/src/operator/tensor/elemwise_unary_op_basic.cu
@@ -115,7 +115,10 @@ void ShapeComputeGPU(const nnvm::NodeAttrs& attrs,
mshadow::Stream<gpu>::GetStream(s));
}
-NNVM_REGISTER_OP(shape_array).set_attr<FCompute>("FCompute<gpu>", ShapeComputeGPU);
+NNVM_REGISTER_OP(shape_array)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", ShapeComputeGPU);
void SizeComputeGPU(const nnvm::NodeAttrs& attrs,
const OpContext& ctx,
diff --git a/src/operator/tensor/indexing_op.cu b/src/operator/tensor/indexing_op.cu
index 9050430..992054f 100644
--- a/src/operator/tensor/indexing_op.cu
+++ b/src/operator/tensor/indexing_op.cu
@@ -957,7 +957,10 @@ NNVM_REGISTER_OP(batch_take).set_attr<FCompute>("FCompute<gpu>", BatchTakeOpForw
NNVM_REGISTER_OP(one_hot).set_attr<FCompute>("FCompute<gpu>", OneHotOpForward<gpu>);
-NNVM_REGISTER_OP(gather_nd).set_attr<FCompute>("FCompute<gpu>", GatherNDForwardGPU);
+NNVM_REGISTER_OP(gather_nd)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", GatherNDForwardGPU);
NNVM_REGISTER_OP(scatter_nd).set_attr<FCompute>("FCompute<gpu>", ScatterNDForward<gpu>);
diff --git a/src/operator/tensor/la_op.cu b/src/operator/tensor/la_op.cu
index 1f16e2d..a32143a 100644
--- a/src/operator/tensor/la_op.cu
+++ b/src/operator/tensor/la_op.cu
@@ -88,6 +88,8 @@ NNVM_REGISTER_OP(_backward_linalg_maketrian)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 1, 1, 1, copytrian>);
NNVM_REGISTER_OP(_linalg_potri)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potri>);
NNVM_REGISTER_OP(_backward_linalg_potri)
@@ -99,32 +101,56 @@ NNVM_REGISTER_OP(_linalg_inverse)
NNVM_REGISTER_OP(_backward_linalg_inverse)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, inverse_backward>);
-NNVM_REGISTER_OP(_linalg_det).set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 1, det>);
+NNVM_REGISTER_OP(_linalg_det)
+ // Incompatibility comes from allocs made in linalg_batch_getrf(), called by det::op()
+ // see https://github.com/apache/incubator-mxnet/issues/19353
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 1, det>);
NNVM_REGISTER_OP(_backward_linalg_det)
+ // Incompatibility comes from allocs made in linalg_batch_getri(),
+ // called by linalg_batch_det_backward_helper, called by det_backward::op()
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 1, det_backward>);
NNVM_REGISTER_OP(_linalg_slogdet)
+ // Incompatibility comes from allocs made in linalg_batch_getrf(),
+ // called by slogdet::op().
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpDetForward<gpu, 2, slogdet>);
NNVM_REGISTER_OP(_backward_linalg_slogdet)
+ // Incompatibility comes from allocs made in linalg_batch_getri(),
+ // called by linalg_batch_det_backward_helper, called by slogdet_backward::op()
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpDetBackward<gpu, 2, slogdet_backward>);
#if MXNET_USE_CUSOLVER == 1
NNVM_REGISTER_OP(_linalg_potrf)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 1, potrf>);
NNVM_REGISTER_OP(_backward_linalg_potrf)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 2, 1, potrf_backward>);
NNVM_REGISTER_OP(_linalg_gelqf)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
.set_attr<FCompute>("FCompute<gpu>", LaOpForward<gpu, 2, 2, 1, 2, gelqf>);
NNVM_REGISTER_OP(_backward_linalg_gelqf)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackward<gpu, 2, 2, 4, 1, gelqf_backward>);
-NNVM_REGISTER_OP(_linalg_syevd).set_attr<FCompute>("FCompute<gpu>", LaOpForwSyevd<gpu, syevd>);
+NNVM_REGISTER_OP(_linalg_syevd)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", LaOpForwSyevd<gpu, syevd>);
NNVM_REGISTER_OP(_backward_linalg_syevd)
.set_attr<FCompute>("FCompute<gpu>", LaOpBackwSyevd<gpu, syevd_backward>);
diff --git a/src/operator/tensor/matrix_op.cu b/src/operator/tensor/matrix_op.cu
index b5bd1c9..00007bd 100644
--- a/src/operator/tensor/matrix_op.cu
+++ b/src/operator/tensor/matrix_op.cu
@@ -412,9 +412,15 @@ NNVM_REGISTER_OP(tile).set_attr<FCompute>("FCompute<gpu>", TileOpForward<gpu>);
NNVM_REGISTER_OP(_backward_tile).set_attr<FCompute>("FCompute<gpu>", TileOpBackward<gpu>);
-NNVM_REGISTER_OP(reverse).set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
+NNVM_REGISTER_OP(reverse)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
-NNVM_REGISTER_OP(_backward_reverse).set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
+NNVM_REGISTER_OP(_backward_reverse)
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", ReverseOpForward<gpu>);
NNVM_REGISTER_OP(stack).set_attr<FCompute>("FCompute<gpu>", StackOpForward<gpu>);
@@ -429,9 +435,17 @@ NNVM_REGISTER_OP(depth_to_space).set_attr<FCompute>("FCompute<gpu>", DepthToSpac
NNVM_REGISTER_OP(space_to_depth).set_attr<FCompute>("FCompute<gpu>", SpaceToDepthOpForward<gpu>);
-NNVM_REGISTER_OP(_split_v2).set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);
-
-NNVM_REGISTER_OP(_split_v2_backward).set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
+NNVM_REGISTER_OP(_split_v2)
+ // Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpForwardImpl
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", SplitOpForwardGPU);
+
+NNVM_REGISTER_OP(_split_v2_backward)
+ // Incompatible due to Copy(xpu_tensor, cpu_tensor) in SplitOpBackwardImpl
+ .set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+ [](const NodeAttrs&, const bool) { return false; })
+ .set_attr<FCompute>("FCompute<gpu>", SplitOpBackward<gpu>);
} // namespace op
} // namespace mxnet
diff --git a/tests/python/gpu/test_gluon_gpu.py b/tests/python/gpu/test_gluon_gpu.py
index 492055f..20a7f26 100644
--- a/tests/python/gpu/test_gluon_gpu.py
+++ b/tests/python/gpu/test_gluon_gpu.py
@@ -18,6 +18,7 @@
import sys
import os
import time
+import random
import mxnet as mx
import multiprocessing as mp
from mxnet.test_utils import check_consistency, set_default_device, assert_almost_equal, rand_ndarray, environment
@@ -28,7 +29,7 @@ import pytest
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
-from common import assert_raises_cudnn_not_satisfied, run_in_spawned_process
+from common import assert_raises_cudnn_not_satisfied, run_in_spawned_process, random_seed
from test_gluon import *
from test_loss import *
from test_numpy_loss import *
@@ -595,3 +596,111 @@ def test_cudnn_dropout_reproducibility():
assert_almost_equal(a.grad, b.grad)
+@mx.util.use_np
+def test_cuda_graphs():
+ class GraphTester(gluon.HybridBlock):
+ def __init__(self, function_to_test, **kwargs):
+ super(GraphTester, self).__init__(**kwargs)
+ self.f = function_to_test()
+
+ def forward(self, *args):
+ # We need to isolate the operation to be fully inside the graph
+ # in order for graphs usage to be possible
+ copied_args = [mx.np.copy(a) for a in args]
+ outputs = self.f(*copied_args)
+ if isinstance(outputs, (list, tuple)):
+ return [mx.np.copy(o) for o in outputs]
+ else:
+ return mx.np.copy(outputs)
+
+ class TestDesc:
+ def __init__(self, name, f, num_inputs=1, input_dim=4):
+ self.name = name
+ self.f = f
+ self.num_inputs = num_inputs
+ self.input_dim = input_dim
+
+ def generate_inputs(self):
+ shape = tuple(_np.random.randint(4, 11, size=self.input_dim))
+ ret = [mx.np.random.uniform(size=shape) for _ in range(self.num_inputs)]
+ for r in ret:
+ r.attach_grad()
+ return ret
+
+ tested_ops = [
+ TestDesc('add', lambda: (lambda x, y: x + y), num_inputs = 2),
+ TestDesc('add_scalar', lambda: (lambda x: x + 0.5)),
+ TestDesc('Conv', lambda: mx.gluon.nn.Conv2D(channels=32, kernel_size=(1,1))),
+ TestDesc('ConvTranspose', lambda: mx.gluon.nn.Conv2DTranspose(channels=32, kernel_size=(1,1))),
+ TestDesc('Dense', lambda: mx.gluon.nn.Dense(units=128)),
+ TestDesc('Activation', lambda: mx.gluon.nn.Activation('tanh')),
+ TestDesc('Dropout', lambda: mx.gluon.nn.Dropout(0.5)),
+ TestDesc('Flatten', lambda: mx.gluon.nn.Flatten()),
+ TestDesc('MaxPool', lambda: mx.gluon.nn.MaxPool2D()),
+ TestDesc('AvgPool', lambda: mx.gluon.nn.AvgPool2D()),
+ TestDesc('GlobalMaxPool', lambda: mx.gluon.nn.GlobalMaxPool2D()),
+ TestDesc('GlobalAvgPool', lambda: mx.gluon.nn.GlobalAvgPool2D()),
+ TestDesc('ReflectionPad2D', lambda: mx.gluon.nn.ReflectionPad2D()),
+ TestDesc('BatchNorm', lambda: mx.gluon.nn.BatchNorm()),
+ TestDesc('InstanceNorm', lambda: mx.gluon.nn.InstanceNorm()),
+ TestDesc('LayerNorm', lambda: mx.gluon.nn.LayerNorm()),
+ TestDesc('LeakyReLU', lambda: mx.gluon.nn.LeakyReLU(0.1)),
+ TestDesc('PReLU', lambda: mx.gluon.nn.PReLU()),
+ TestDesc('ELU', lambda: mx.gluon.nn.ELU()),
+ TestDesc('SELU', lambda: mx.gluon.nn.SELU()),
+ TestDesc('Swish', lambda: mx.gluon.nn.Swish()),
+ ]
+
+ N = 10
+
+ with environment({'MXNET_ENABLE_CUDA_GRAPHS': '1',
+ 'MXNET_USE_FUSION': '0'}):
+ device = mx.gpu(0)
+ for test_desc in tested_ops:
+ print("Testing ", test_desc.name)
+ inputs = test_desc.generate_inputs()
+ inputsg = [i.copy() for i in inputs]
+ for i in inputsg:
+ i.attach_grad()
+ seed = random.randint(0, 10000)
+ net = GraphTester(test_desc.f)
+ netg = GraphTester(test_desc.f)
+
+ # initialize parameters
+ net.initialize(device=device)
+ netg.initialize(device=device)
+
+ net(*inputs)
+
+ for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()):
+ p2.set_data(p1.data())
+
+ netg.hybridize(static_alloc=True, static_shape=True)
+
+ print("Testing inference mode")
+ with random_seed(seed):
+ for _ in range(N):
+ assert_almost_equal(net(*inputs), netg(*inputsg))
+
+ mx.npx.waitall()
+ print("Testing training mode")
+ for _ in range(N):
+ with random_seed(seed):
+ with mx.autograd.record():
+ out = net(*inputs)
+ out.backward()
+
+ with random_seed(seed):
+ with mx.autograd.record():
+ outg = netg(*inputsg)
+ outg.backward()
+
+ assert_almost_equal(out, outg)
+ for i, ig in zip(inputs, inputsg):
+ assert_almost_equal(i.grad, ig.grad)
+
+ for p1, p2 in zip(net.collect_params().values(), netg.collect_params().values()):
+ assert_almost_equal(p1.data(), p2.data())
+ if p1.grad_req != 'null':
+ assert_almost_equal(p1.grad(), p2.grad())
+ mx.npx.waitall()
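---
For reference, below is a minimal usage sketch (not part of the patch) of how a user might opt in to CUDA graphs from Python, mirroring the environment variables and static-shape hybridization exercised by test_cuda_graphs above. It assumes an MXNet 2.0 CUDA build with at least one GPU; the Dense layer, input shape, and iteration count are arbitrary placeholders.

    # Illustrative sketch only; mirrors the settings used in test_cuda_graphs.
    import os
    os.environ['MXNET_ENABLE_CUDA_GRAPHS'] = '1'   # opt in to CUDA graphs before any GPU execution
    os.environ['MXNET_USE_FUSION'] = '0'           # the test above also disables pointwise fusion

    import mxnet as mx
    from mxnet import gluon
    from mxnet.test_utils import set_default_device

    mx.npx.set_np()
    device = mx.gpu(0)
    set_default_device(device)        # run everything on the first GPU, as the GPU tests do

    net = gluon.nn.Dense(128)         # placeholder network
    net.initialize(device=device)
    # Static allocation and static shapes let repeated launches be captured and replayed as a CUDA graph.
    net.hybridize(static_alloc=True, static_shape=True)

    x = mx.np.random.uniform(size=(32, 256))   # fixed input shape across iterations
    for _ in range(10):
        y = net(x)
    mx.npx.waitall()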