Posted to commits@mxnet.apache.org by GitBox <gi...@apache.org> on 2020/09/14 20:43:17 UTC

[GitHub] [incubator-mxnet] ptrendx opened a new pull request #19142: [1.x][FEATURE] CUDA graphs support

ptrendx opened a new pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142


   ## Description ##
   CUDA graphs are a CUDA 10 feature that lowers CPU overhead by bundling multiple kernel launches together.
   
   The main limitation of CUDA graphs is that they require the graph to be static: no kernel parameters can change, otherwise the graph needs to be recaptured. That is why this feature is currently enabled only for symbolic models and Gluon models hybridized with `hybridize(static_alloc=True, static_shape=True)`.
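
To make the static-graph restriction concrete, here is a toy model of capture and replay in plain Python. It is only an illustration under stated assumptions: `ToyGraph`, `scale` and `shift` are hypothetical names and none of this is the CUDA or MXNet API. It shows that a replay reuses the exact arguments seen at capture time, so new arguments need a new capture.

```python
# Toy model of capture/replay; this is NOT the CUDA or MXNet API.
class ToyGraph:
    """Records (function, arguments) pairs once and replays them verbatim."""

    def __init__(self):
        self.nodes = []

    def capture(self, launches):
        # Freeze each launch together with the exact arguments seen at capture time.
        self.nodes = [(fn, tuple(args)) for fn, args in launches]

    def replay(self):
        # A single call replays every recorded launch with the frozen arguments.
        return [fn(*args) for fn, args in self.nodes]


def scale(x, factor):
    return x * factor


def shift(x, offset):
    return x + offset


graph = ToyGraph()
graph.capture([(scale, (3, 2)), (shift, (3, 10))])
first = graph.replay()   # uses the captured arguments
second = graph.replay()  # same work again, nothing is re-specified

# Different arguments (think: a new shape or pointer) need a fresh capture.
graph.capture([(scale, (5, 2)), (shift, (5, 10))])
third = graph.replay()

print(first, second, third)
```

In this picture, `static_alloc=True` and `static_shape=True` are what keep the "arguments" (buffers and shapes) the same from one iteration to the next, so the recorded work stays valid.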
   
   The feature is not enabled by default and requires the environment variable `MXNET_ENABLE_CUDA_GRAPHS` to be set. To avoid capturing operations whose execution may change during the course of a training job, stateful operators and operators relying on resources other than workspace are not included in the graph.
   
   Since the feature lowers the CPU overhead of execution, the impact is most visible in inference or scale-out workloads with small batch sizes. Consider the following script, which compares fp16, batch size 1 inference of the RN50v2 model from GluonCV with and without graphs:
   
   ```python
   
   import mxnet as mx
   import gluoncv as gcv
   from gluoncv.model_zoo import get_model
   import time
   import os
   
   net = get_model('ResNet50_V2')
   net2 = get_model('ResNet50_V2')
   
   net.initialize(ctx=mx.gpu())
   net2.initialize(ctx=mx.gpu())
   
   net.cast('float16')
   net2.cast('float16')
   
   net.hybridize(static_alloc=True, static_shape=True)
   net2.hybridize(static_alloc=True, static_shape=True)
   
   img = mx.random.uniform(shape=(1, 3, 224, 224), ctx=mx.gpu(), dtype='float16')
   
   os.environ["MXNET_ENABLE_CUDA_GRAPHS"] = "0"
   
   for _ in range(10):
       o = net(img)
   
   mx.nd.waitall()
   
   s = time.time()
   for _ in range(1000):
       o = net(img)
   
   mx.nd.waitall()
   e = time.time()
   print("No graphs: ", e - s)
   mx.nd.waitall()
   
   os.environ["MXNET_ENABLE_CUDA_GRAPHS"] = "1"
   
   for _ in range(10):
       o = net2(img)
   
   mx.nd.waitall()
   
   s = time.time()
   for _ in range(1000):
       o = net2(img)
   
   mx.nd.waitall()
   e = time.time()
   print("With graphs: ", e - s)
   ```
   
   The result obtained on a V100 16GB:
   
   ```bash
   [20:33:46] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
   No graphs:  2.8153152465820312
   With graphs:  2.3230245113372803
   ```
   
   so the run time drops by over 17%. The same script with batch size 128 gives a much smaller improvement:
   
   ```bash
   [20:38:34] ../src/operator/nn/./cudnn/./cudnn_algoreg-inl.h:97: Running performance tests to find the best convolution algorithm, this can take a while... (set the environment variable MXNET_CUDNN_AUTOTUNE_DEFAULT to 0 to disable)
   No graphs:  47.31952977180481
   With graphs:  47.152849197387695
   ```
   
   so about 0.35%.
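
For reference, the percentages can be recomputed from the logged timings. The short script below is only a worked calculation; the helper name `summarize` is made up for illustration, and the only inputs are the four timings printed above and the 1000 iterations used by the benchmark loop.

```python
# Timings in seconds for 1000 iterations, copied from the logs above.
# Keys are batch sizes; values are (no graphs, with graphs).
timings = {
    1: (2.8153152465820312, 2.3230245113372803),
    128: (47.31952977180481, 47.152849197387695),
}
ITERATIONS = 1000


def summarize(no_graphs, with_graphs, iterations=ITERATIONS):
    # Fraction of the elapsed time that disappears when graphs are enabled.
    time_reduction = (no_graphs - with_graphs) / no_graphs
    # Equivalent gain in iterations per second.
    throughput_gain = no_graphs / with_graphs - 1.0
    # Absolute time saved per forward pass, in milliseconds.
    saved_ms_per_iter = (no_graphs - with_graphs) / iterations * 1000.0
    return time_reduction, throughput_gain, saved_ms_per_iter


results = {batch: summarize(*pair) for batch, pair in timings.items()}
for batch, (reduction, gain, saved_ms) in results.items():
    print(f"batch {batch}: time -{reduction:.2%}, "
          f"throughput +{gain:.2%}, saved {saved_ms:.3f} ms/iter")
```

At batch size 1 this gives roughly a 17.5% shorter run time (about 21% more iterations per second); at batch size 128 both figures are about 0.35%.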
   
   ## Checklist ##
   ### Essentials ###
   - [x] PR's title starts with a category (e.g. [BUGFIX], [MODEL], [TUTORIAL], [FEATURE], [DOC], etc)
   - [ ] Changes are complete (i.e. I finished coding on this PR)
   - [ ] All changes have test coverage
   - [ ] Code is well-documented
   


----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488245153



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       I guess look here: https://github.com/apache/incubator-mxnet/pull/19142/files#diff-789523bf443903e74acfa010a5d6b572R33-R37 - this is for dropout, which uses the random resource during training (and is thus excluded), but is just a passthrough for inference (so we want to include it there).
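
As a rough illustration of how such a per-op, mode-dependent check interacts with the segmentation loop in `RunAll`, here is a plain-Python sketch. It is not the MXNet implementation: `split_segments` and `make_predicate` are hypothetical helpers, the op names are plain strings, and the Dropout rule is only modeled on the comment above. The sketch also ignores the fallback path taken when capture itself fails.

```python
from typing import Callable, List, Tuple


def split_segments(ops: List[str],
                   is_ok: Callable[[str], bool]) -> List[Tuple[int, int, bool]]:
    """Return (start, count, capturable) triples that cover `ops` in order."""
    segments = []
    i = 0
    while i < len(ops):
        # Count the run of consecutive compatible ops starting at i.
        n_good = 0
        while i + n_good < len(ops) and is_ok(ops[i + n_good]):
            n_good += 1
        if n_good > 0:
            segments.append((i, n_good, True))
            i += n_good
        if i < len(ops):
            # The op that ended the run is bypassed on its own.
            segments.append((i, 1, False))
            i += 1
    return segments


def make_predicate(is_train: bool) -> Callable[[str], bool]:
    # Assumed rule: Dropout is capturable only when not training.
    def is_ok(op: str) -> bool:
        if op == "Dropout":
            return not is_train
        return True
    return is_ok


ops = ["Convolution", "Activation", "Dropout", "FullyConnected"]
train_segments = split_segments(ops, make_predicate(is_train=True))
infer_segments = split_segments(ops, make_predicate(is_train=False))
print(train_segments)
print(infer_segments)
```

With these assumptions the training case is split around Dropout into two capturable runs plus one bypassed op, while the inference case stays a single capturable segment.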







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488288258



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       That's what the default value is for:
   ```
   const auto f = fgraphcompatible.get(attrs.op, nullptr);
   ```
   You can just check whether it's `null` instead of calling it to return false. Shouldn't the default be to exclude an op rather than include it?
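
A minimal sketch of the suggestion above, under stated assumptions: `AttrMap` and `IsCompatibleFn` are hypothetical stand-ins for the operator attribute registry and the `FIsCUDAGraphsCompatible` function type, not the MXNet API, and operator names are plain strings for illustration. It only shows the lookup pattern, where `nullptr` is passed as the default and a missing attribute is treated as "exclude the op":

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical predicate type: decides whether an op may be captured,
// given the is_train flag. Stands in for FIsCUDAGraphsCompatible.
using IsCompatibleFn = std::function<bool(bool /*is_train*/)>;

// Hypothetical attribute registry keyed by operator name.
class AttrMap {
 public:
  void Set(const std::string& op_name, IsCompatibleFn fn) {
    assert(fn != nullptr);  // only callable predicates are registered
    table_[op_name] = std::move(fn);
  }

  // Returns the registered predicate, or default_fn if the op has none.
  IsCompatibleFn Get(const std::string& op_name, IsCompatibleFn default_fn) const {
    auto it = table_.find(op_name);
    if (it == table_.end())
      return default_fn;
    return it->second;
  }

 private:
  std::unordered_map<std::string, IsCompatibleFn> table_;
};

// Opt-in policy: an op without a registered predicate is excluded from capture.
inline bool OpOK(const AttrMap& attrs, const std::string& op_name, bool is_train) {
  const IsCompatibleFn f = attrs.Get(op_name, nullptr);
  if (f == nullptr)
    return false;  // nothing registered: exclude by default
  return f(is_train);
}
```

With this policy an operator has to register a predicate before it can be captured, so an op that nobody has reviewed for CUDA Graphs compatibility is run conventionally rather than captured by accident. The trade-off is that every compatible op needs an explicit registration.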







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488823750



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       I think we should be OK, since it's disabled by default and only enabled explicitly by setting the env var.
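
To illustrate the opt-in gating mentioned here, the following is a minimal sketch of a default-off environment-variable switch. The PR itself reads the flag with `dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false)`; the helper `EnvFlagEnabled` and its parsing rules below are hypothetical stand-ins built on `std::getenv`, not dmlc's implementation.

```cpp
#include <cstdlib>
#include <string>

// Hypothetical helper (not part of MXNet or dmlc-core): reports whether an
// environment variable switches a feature on. An unset variable yields the
// default, which keeps the feature off unless the user opts in.
inline bool EnvFlagEnabled(const char* name, bool default_value = false) {
  const char* raw = std::getenv(name);
  if (raw == nullptr)
    return default_value;
  const std::string value(raw);
  // Assumed parsing rule for this sketch: only "1" and "true" enable the feature.
  return value == "1" || value == "true";
}
```

With a helper like this, an unset variable leaves the feature off, so existing jobs keep running through the conventional execution path.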




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488238529



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       Do something like this?
   https://github.com/apache/incubator-mxnet/blob/72eff9b66ecc683c3e7f9ad2c0ba69efa8dd423b/src/imperative/imperative_utils.h#L83-L84
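
For readers following the quoted `RunAll` hunk, the greedy split into capturable runs and bypassed single ops can be sketched on its own. This is a simplified illustration, not the PR's code: `Segment` and `PartitionOps` are hypothetical names, and a `std::vector<bool>` stands in for the per-op result of `OpOK`.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical stand-in for a sub-segment executor: covers the ops
// [first, first + count) and records whether they would be captured.
struct Segment {
  std::size_t first;
  std::size_t count;
  bool capturable;
};

// Greedy partition: each maximal run of compatible ops becomes one capturable
// segment, and each incompatible op becomes its own single-op segment that
// would be run conventionally.
inline std::vector<Segment> PartitionOps(const std::vector<bool>& op_ok) {
  std::vector<Segment> segments;
  std::size_t first = 0;
  while (first != op_ok.size()) {
    // Count the run of compatible ops starting at 'first'.
    std::size_t num_good = 0;
    while (first + num_good != op_ok.size() && op_ok[first + num_good])
      ++num_good;
    if (num_good > 0) {
      segments.push_back(Segment{first, num_good, true});
      first += num_good;
    }
    // If ops remain, the run was stopped by an incompatible op.
    if (first != op_ok.size()) {
      segments.push_back(Segment{first, 1, false});
      ++first;
    }
  }
  return segments;
}
```

For an op list whose compatibility flags are `{true, true, false, true}`, this yields three segments: ops 0 to 1 as one capturable run, op 2 bypassed, and op 3 as a second capturable run.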







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488255168



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       ya that one is fine, but "_npi_eig" is always false...
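
For readers following the quoted hunk: the loop in `CudaGraphsExec::RunAll` splits the op list into maximal runs of ops that pass `OpOK`, and gives every op that fails the check its own single-op, non-capturable sub-segment. Below is a minimal standalone sketch of just that partitioning step, with no CUDA or MXNet dependency. The names `SubSeg` and `PartitionOps` and the `op_ok` callback are hypothetical stand-ins introduced here for illustration; they are not part of the PR.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a sub-segment executor: covers op indices
// [first, first + count) and records whether it may be graph-captured.
struct SubSeg {
  std::size_t first;
  std::size_t count;
  bool capturable;
};

// Split op indices [0, num_ops) into maximal runs of compatible ops.
// Each incompatible op gets its own single-op, non-capturable sub-segment,
// mirroring the first_op_idx / num_good_ops loop in the quoted diff.
std::vector<SubSeg> PartitionOps(std::size_t num_ops,
                                 const std::function<bool(std::size_t)>& op_ok) {
  std::vector<SubSeg> segs;
  std::size_t first = 0;
  while (first != num_ops) {
    // Count consecutive compatible ops starting at 'first'.
    std::size_t good = 0;
    while (first + good != num_ops && op_ok(first + good))
      ++good;
    if (good > 0) {
      segs.push_back({first, good, true});
      first += good;
    }
    // If we stopped before the end, the op at 'first' failed the check.
    if (first != num_ops) {
      segs.push_back({first, 1, false});
      ++first;
    }
  }
  return segs;
}
```

In the PR itself, each capturable run is then handed to `CreateSubExecOverRegion`, which may split it further if stream capture of the whole run does not yield a runnable executor; this sketch does not model that second stage.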







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488288258



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       That's what the default value is for:
   ```
   const auto f = fgraphcompatible.get(attrs.op, nullptr);
   ```
   You can just check whether it's `null` instead of calling it to return false.
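
The lookup-with-default pattern discussed here can be sketched in plain C++. This is only an illustration: `AttrMap`, `IsGraphCompatibleFn`, `OpIsCompatible` and the `default_if_unregistered` parameter are hypothetical stand-ins, not MXNet or NNVM types, and which policy applies to an unregistered operator is left as an explicit argument because the rest of `OpOK` is not shown in this hunk.

```cpp
#include <cassert>
#include <functional>
#include <string>
#include <unordered_map>
#include <utility>

// Hypothetical stand-in for a per-operator attribute: a predicate that says
// whether the operator may be captured, given the is_train flag.
using IsGraphCompatibleFn = std::function<bool(bool is_train)>;

// Hypothetical attribute registry keyed by operator name.
class AttrMap {
 public:
  void set(const std::string& op, IsGraphCompatibleFn fn) {
    // A null entry would be indistinguishable from "not registered".
    assert(fn != nullptr);
    map_[op] = std::move(fn);
  }

  // Returns the registered value, or def_value when op has no entry.
  IsGraphCompatibleFn get(const std::string& op, IsGraphCompatibleFn def_value) const {
    auto it = map_.find(op);
    return it == map_.end() ? def_value : it->second;
  }

 private:
  std::unordered_map<std::string, IsGraphCompatibleFn> map_;
};

// Null check on the looked-up attribute: an unregistered op falls back to
// default_if_unregistered, a registered op is asked directly.
bool OpIsCompatible(const AttrMap& attrs, const std::string& op,
                    bool is_train, bool default_if_unregistered) {
  const IsGraphCompatibleFn f = attrs.get(op, nullptr);
  if (f == nullptr)
    return default_if_unregistered;
  return f(is_train);
}
```

With `default_if_unregistered = false`, only operators that register a predicate can be captured, so nothing needs a lambda that returns `false`. With `true`, operators opt out by registering one, as the `_npi_eigvals` hunk in this pull request does.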




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488239192



##########
File path: src/operator/numpy/linalg/np_eigvals.cu
##########
@@ -29,11 +29,19 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_eigvals)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+    [](const NodeAttrs&, const bool) {
+      return false;
+    })
 .set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);
 
 #if MXNET_USE_CUSOLVER == 1
 
 NNVM_REGISTER_OP(_npi_eigvalsh)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",

Review comment:
       Well, the "everywhere" here is actually far fewer places than if we went the other way around (only a handful of operators with `FCompute` do synchronization that is not allowed under graphs).







[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488793102



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       For the first question: yes - nearly all of the FCompute functions that do not use the random resource are compatible with graphs.
   For the second question - for an FCompute operator to be incompatible, it needs to do synchronization inside the operator, either via a stream synchronize or an allocation. You generally do not want to do either of those, as it really hurts performance (and the operators that I marked incompatible in this PR do just that). If you just launch a kernel (or multiple), which is the case for the vast majority of operators, then you are good and do not even need to think about graphs - it will just work.
   
   I'm still exploring ways of automatically testing newly added operators so that the feature can be on by default, but I do not consider this in scope for this PR, as the v1.x branch is not really supposed to get many more operators (I will do that in the PR to master). Generally this would involve testing operators with `hybridize(static_alloc=True, static_shape=True)` (which should be tested much more anyway: right now testing of this functionality is really limited, even though it is widely used).
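
To illustrate what happens to the few incompatible operators, here is a minimal, CUDA-free sketch of the splitting done by the outer loop of `RunAll` in the quoted diff. Compatible ops are grouped into maximal runs and each incompatible op is isolated in its own single-op sub-segment. `SubSegment` and `PartitionOps` are hypothetical names for this illustration. The vector of flags stands in for calling `OpOK` on each executor, and the second-level fallback in `CreateSubExecOverRegion` is not modelled.

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// One sub-segment covering ops [first, first + count).
// captured == true means the run would be captured as a CUDA graph;
// captured == false means the single op would be run conventionally.
struct SubSegment {
  std::size_t first;
  std::size_t count;
  bool captured;
};

// Split per-op compatibility flags into maximal runs of compatible ops,
// placing every incompatible op in its own one-op sub-segment.
std::vector<SubSegment> PartitionOps(const std::vector<bool>& op_ok) {
  std::vector<SubSegment> segs;
  std::size_t first = 0;
  while (first != op_ok.size()) {
    // Count consecutive compatible ops starting at 'first'.
    std::size_t good = 0;
    while (first + good != op_ok.size() && op_ok[first + good])
      ++good;
    if (good > 0) {
      segs.push_back({first, good, true});
      first += good;
    }
    if (first != op_ok.size()) {
      // The run can only have stopped at an incompatible op.
      assert(!op_ok[first]);
      segs.push_back({first, 1, false});
      ++first;
    }
  }
  return segs;
}
```

For flags `{true, true, false, true}` this yields three sub-segments: ops 0-1 captured together, op 2 bypassed, and op 3 captured on its own. A single incompatible operator therefore splits the graph but does not disable capture for the rest of the segment.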







[GitHub] [incubator-mxnet] ptrendx commented on pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#issuecomment-692327098


   Forgot to add in the description that most of the work here was done by @DickJC123.





[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488240768



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       What's the point of registering a lambda function that just returns false?
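
       One possible reading of the question, sketched under stated assumptions rather than taken from the PR: if a registered `FIsCUDAGraphsCompatible` predicate takes precedence over the generic checks (the quoted hunk stops before showing this), then a lambda that always returns false would exclude an op that is not stateful and would otherwise be captured. The toy code below is plain C++ with no CUDA or MXNet dependency. `ToyOp`, `CompatRegistry`, `ToyOpOK`, `ToySegment` and `SplitIntoSegments` are hypothetical names used only for illustration, and the `stateful` flag stands in for the `FCreateOpState` lookup. The splitting loop mirrors the way `RunAll` groups consecutive OK ops and isolates each not-OK op.

```cpp
#include <cstddef>
#include <functional>
#include <map>
#include <string>
#include <vector>

// Hypothetical stand-in for an operator; not an MXNet type.
struct ToyOp {
  std::string name;
  bool stateful;  // stands in for "the op has an FCreateOpState attribute"
};

using CompatFn = std::function<bool(const ToyOp&)>;

// Toy attribute registry keyed by op name. Ops without an entry express no opinion.
using CompatRegistry = std::map<std::string, CompatFn>;

// Assumption: a registered predicate wins; otherwise only stateful ops are rejected.
inline bool ToyOpOK(const ToyOp& op, const CompatRegistry& registry) {
  auto it = registry.find(op.name);
  if (it != registry.end() && it->second)
    return it->second(op);
  return !op.stateful;
}

// Half-open op range [first, last) and whether it may be captured in a graph.
struct ToySegment {
  std::size_t first;
  std::size_t last;
  bool capturable;
};

// Group consecutive OK ops into one capturable segment. Each not-OK op becomes
// its own single-op segment that would be run conventionally.
inline std::vector<ToySegment> SplitIntoSegments(const std::vector<ToyOp>& ops,
                                                 const CompatRegistry& registry) {
  std::vector<ToySegment> segments;
  std::size_t i = 0;
  while (i < ops.size()) {
    std::size_t j = i;
    while (j < ops.size() && ToyOpOK(ops[j], registry))
      ++j;
    if (j > i)
      segments.push_back({i, j, true});
    if (j < ops.size()) {
      // ops[j] was rejected: bypass it on its own and resume after it.
      segments.push_back({j, j + 1, false});
      ++j;
    }
    i = j;
  }
  return segments;
}
```

       In this sketch, registering `[](const ToyOp&) { return false; }` for a non-stateful op splits the surrounding run into two capturable segments with the bypassed op between them. Without the registration the same op would be captured. Whether the PR relies on that effect is for the author to confirm.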




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] mxnet-bot commented on pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
mxnet-bot commented on pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#issuecomment-692302943


   Hey @ptrendx, thanks for submitting the PR.
   All tests are already queued to run once. If tests fail, you can trigger one or more tests again with the following commands:
   - To trigger all jobs: @mxnet-bot run ci [all]
   - To trigger specific jobs: @mxnet-bot run ci [job1, job2]
   ***
   **CI supported jobs**: [clang, centos-cpu, windows-gpu, windows-cpu, unix-cpu, website, miscellaneous, sanity, edge, unix-gpu, centos-gpu]
   ***
   _Note_:
   Only the following 3 categories can trigger CI: PR Author, MXNet Committer, Jenkins Admin.
   All CI tests must pass before the PR can be merged.
   





[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488290777



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       We chose to include operators by default, and instead made the feature itself opt-in (non-default).
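
       To illustrate the policy described above, here is a minimal sketch (not the PR's actual code). The names `CompatRegistry`, `OpIsGraphCompatible` and `ShouldCapture` are hypothetical stand-ins, and the predicate signature is an assumption made only for this example. Operators with no registered predicate are treated as capturable, while the feature flag as a whole defaults to off.

```cpp
#include <functional>
#include <string>
#include <unordered_map>

// Hypothetical stand-in for a per-operator compatibility predicate.
using IsGraphCompatibleFn = std::function<bool(bool is_train)>;

// Hypothetical registry: only operators needing special handling are listed.
using CompatRegistry = std::unordered_map<std::string, IsGraphCompatibleFn>;

// Operators without an entry (or with an empty function) are included by default.
inline bool OpIsGraphCompatible(const CompatRegistry& registry,
                                const std::string& op_name,
                                bool is_train) {
  auto it = registry.find(op_name);
  if (it == registry.end() || !it->second)
    return true;  // default: include the operator
  return it->second(is_train);
}

// The feature itself is opt-in: nothing is captured unless it is enabled.
inline bool ShouldCapture(bool feature_enabled,
                          const CompatRegistry& registry,
                          const std::string& op_name,
                          bool is_train) {
  return feature_enabled && OpIsGraphCompatible(registry, op_name, is_train);
}
```

       With this shape, an operator author only has to register something when an op must be excluded; everything else is picked up automatically once the user turns the feature on.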




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488239585



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       Not sure I understand - I do want to have the function call per `Op` here, not just a simple `true`/`false`.
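
       As a rough sketch of why a per-`Op` function can be more useful than a fixed `true`/`false`: a function can look at the node's parameters and the training flag before answering. Everything below is hypothetical (`NodeAttrsSketch`, `CompatPredicate`, `DropoutLikeIsCompatible`, and the `mode` parameter semantics are assumptions for illustration, not the signatures used in this PR).

```cpp
#include <functional>
#include <map>
#include <string>

// Hypothetical, simplified stand-in for an operator node's attributes.
struct NodeAttrsSketch {
  std::string op_name;
  std::map<std::string, std::string> params;
};

// A per-op predicate can inspect both the node's parameters and is_train.
using CompatPredicate =
    std::function<bool(const NodeAttrsSketch& attrs, bool is_train)>;

// Illustrative predicate for a dropout-like op: treat it as capturable only
// when it acts as a pass-through, i.e. not training and mode != "always".
inline bool DropoutLikeIsCompatible(const NodeAttrsSketch& attrs, bool is_train) {
  auto it = attrs.params.find("mode");
  const bool always_on = (it != attrs.params.end() && it->second == "always");
  return !is_train && !always_on;
}
```

       A plain boolean attribute could not express this, since the same operator would be capturable in one configuration and not in another.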







[GitHub] [incubator-mxnet] samskalicky merged pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky merged pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142


   





[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488291244



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace ptrs have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       Can the majority of ops be included in CUDA graphs? If a new contributor writes an op, do they need to know how to handle CUDA graph support?
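
       For context on this question, here is a minimal standalone sketch of how the `RunAll` loop in the hunk above appears to partition an op segment: consecutive ops that pass the compatibility check are grouped into one capturable run, and each op that fails it gets its own single-op segment that is run conventionally. `SubSegment`, `PartitionOps` and the `op_ok` callback are hypothetical names used only for illustration (stand-ins for `CudaGraphsSubSegExec` and `OpOK`); the sketch makes no CUDA or MXNet calls.

```cpp
#include <cstddef>
#include <functional>
#include <vector>

// Hypothetical stand-in for a sub-segment executor: a start index, an op
// count, and whether the run may be captured into a CUDA graph.
struct SubSegment {
  std::size_t from_op_idx;
  std::size_t num_ops;
  bool capturable;
};

// Split ops [0, num_ops) into maximal runs of compatible ops. Every
// incompatible op becomes its own non-capturable single-op segment.
inline std::vector<SubSegment> PartitionOps(
    std::size_t num_ops, const std::function<bool(std::size_t)>& op_ok) {
  std::vector<SubSegment> segments;
  std::size_t first = 0;
  while (first != num_ops) {
    // Count how many consecutive ops starting at `first` are compatible.
    std::size_t num_good = 0;
    while (first + num_good != num_ops && op_ok(first + num_good))
      ++num_good;
    if (num_good > 0) {
      segments.push_back({first, num_good, true});
      first += num_good;
    }
    // If we stopped before the end, the op at `first` was not compatible.
    if (first != num_ops) {
      segments.push_back({first, 1, false});
      ++first;
    }
  }
  return segments;
}
```

       Under this scheme an op that is not marked compatible only splits the surrounding graph in two; it does not disable capture for the rest of the segment.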







[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488290581



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       But then I would need to add that function for almost all operators for it to return true...
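
For context, the partitioning loop in the quoted hunk above can be sketched on its own. This is only an illustration, assuming the per-op `OpOK` results are already available as a vector of flags; `SubSeg` and `Partition` are hypothetical names that do not appear in the PR. Every op that is not OK ends a captured run and becomes its own conventionally-run sub-segment, so the default verdict for unannotated ops directly determines how fragmented the captured graphs are.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical description of one sub-segment covering ops [first, first + count).
struct SubSeg {
  std::size_t first;
  std::size_t count;
  bool captured;  // true: launched as a CUDA graph; false: run conventionally
};

// Split per-op compatibility flags into maximal captured runs. Each
// incompatible op becomes its own single-op, conventionally-run sub-segment.
inline std::vector<SubSeg> Partition(const std::vector<bool> &op_ok) {
  std::vector<SubSeg> out;
  for (std::size_t first = 0; first != op_ok.size();) {
    // Count consecutive compatible ops starting at `first`.
    std::size_t good = 0;
    while (first + good != op_ok.size() && op_ok[first + good]) ++good;
    if (good > 0) {
      out.push_back({first, good, true});
      first += good;
    }
    // If we are not at the end, the run was stopped by an incompatible op.
    if (first != op_ok.size()) {
      out.push_back({first, 1, false});
      ++first;
    }
  }
  return out;
}
```

With flags `{true, true, false, true}` this yields a captured run over ops 0 and 1, a bypassed op 2, and a captured run over op 3.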




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488240388



##########
File path: src/operator/numpy/linalg/np_eigvals.cu
##########
@@ -29,11 +29,19 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_eigvals)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+    [](const NodeAttrs&, const bool) {
+      return false;
+    })
 .set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);
 
 #if MXNET_USE_CUSOLVER == 1
 
 NNVM_REGISTER_OP(_npi_eigvalsh)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",

Review comment:
       And the long-term goal here is actually to make those excluded operators compatible too.







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488236848



##########
File path: src/operator/numpy/linalg/np_eigvals.cu
##########
@@ -29,11 +29,19 @@ namespace mxnet {
 namespace op {
 
 NNVM_REGISTER_OP(_npi_eigvals)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",
+    [](const NodeAttrs&, const bool) {
+      return false;
+    })
 .set_attr<FCompute>("FCompute<gpu>", EigvalsOpForward<gpu>);
 
 #if MXNET_USE_CUSOLVER == 1
 
 NNVM_REGISTER_OP(_npi_eigvalsh)
+.set_attr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible",

Review comment:
       Instead of setting false everywhere, can we just check `hasAttr("FIsCUDAGraphsCompatible")` so that the default is false?







[GitHub] [incubator-mxnet] samskalicky commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
samskalicky commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488288258



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       That's what the default value is for:
   ```
   const auto f = fgraphcompatible.get(attrs.op, nullptr);
   ```
   You can just check whether it's `null` instead of calling it to return false. Shouldn't the default be to include an op rather than exclude it?
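
A minimal sketch of the suggestion above, under stated assumptions: `AttrMap`, the op names, and the one-argument function type are hypothetical stand-ins, not MXNet's `Op::GetAttr` registry or the real `FIsCUDAGraphsCompatible` signature. It only illustrates the include-by-default policy, where the `nullptr` default from the lookup is treated as "compatible" rather than "not compatible".

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>
#include <utility>

// Hypothetical attribute type; the real signature in the PR may differ.
using IsGraphCompatibleFn = std::function<bool(bool is_train)>;

// Hypothetical stand-in for an operator attribute registry.
class AttrMap {
 public:
  void Set(const std::string& op, IsGraphCompatibleFn fn) { fns_[op] = std::move(fn); }

  // Returns the registered function, or `def` when the op has no entry.
  IsGraphCompatibleFn Get(const std::string& op, IsGraphCompatibleFn def) const {
    auto it = fns_.find(op);
    return it == fns_.end() ? def : it->second;
  }

 private:
  std::map<std::string, IsGraphCompatibleFn> fns_;
};

// Include-by-default policy: an op with no registered function is captured.
inline bool OpOK(const AttrMap& attrs, const std::string& op, bool is_train) {
  const IsGraphCompatibleFn f = attrs.Get(op, nullptr);
  if (f == nullptr) return true;  // nothing registered: include the op
  return f(is_train);             // otherwise defer to the op's own answer
}

// Small usage check: only the dropout-like op registers anything.
inline void Demo() {
  AttrMap attrs;
  attrs.Set("dropout", [](bool is_train) { return !is_train; });
  assert(OpOK(attrs, "relu", true));      // unregistered op is included
  assert(!OpOK(attrs, "dropout", true));  // excluded while training
}
```

With this policy only the ops that need special handling register a function; whether that is safe for every existing operator is a separate question that the sketch does not answer.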




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org



[GitHub] [incubator-mxnet] ptrendx commented on a change in pull request #19142: [1.x][FEATURE] CUDA graphs support

Posted by GitBox <gi...@apache.org>.
ptrendx commented on a change in pull request #19142:
URL: https://github.com/apache/incubator-mxnet/pull/19142#discussion_r488257131



##########
File path: src/executor/cuda_graphs.h
##########
@@ -0,0 +1,546 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * Copyright (c) 2020 by Contributors
+ * \file cuda_graphs.h
+ * \brief Wrappers for use of CUDA Graphs API
+ */
+#ifndef MXNET_EXECUTOR_CUDA_GRAPHS_H_
+#define MXNET_EXECUTOR_CUDA_GRAPHS_H_
+
+#include <mxnet/base.h>
+#include <vector>
+#include <string>
+#include <map>
+#include <set>
+
+#include "./exec_pass.h"
+#include "../common/cuda_utils.h"
+
+#if MXNET_USE_CUDA
+#define CUDA_GRAPHS_AVAILABLE (CUDA_VERSION >= 10010)
+#else
+#define CUDA_GRAPHS_AVAILABLE (0)
+#endif
+
+#if CUDA_GRAPHS_AVAILABLE
+
+namespace mxnet {
+namespace cuda_graphs {
+
+inline std::string CudaDim3ToString(const dim3& dims) {
+  std::stringstream ss;
+  if (dims.z != 1)
+    ss << "(" << dims.x << "," << dims.y << "," << dims.z << ")";
+  else if (dims.y != 1)
+    ss << "(" << dims.x << "," << dims.y << ")";
+  else
+    ss << "(" << dims.x << ")";
+  return ss.str();
+}
+
+// Get the type of a CUDA Graph node (e.g. kernel launch, memcpy, etc.)
+inline CUgraphNodeType CudaGraphNodeType(const cudaGraphNode_t node) {
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  return t;
+}
+
+// Return the list of CUDA Graph nodes from a graph
+inline std::vector<cudaGraphNode_t> GetCudaGraphNodes(cudaGraph_t cuda_graph) {
+  size_t numNodes;
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, static_cast<cudaGraphNode_t*>(nullptr), &numNodes));
+  if (numNodes == 0)
+    return std::vector<cudaGraphNode_t>();
+  std::vector<cudaGraphNode_t> graphNodes(numNodes);
+  CUDA_CALL(cudaGraphGetNodes(cuda_graph, graphNodes.data(), &numNodes));
+  return graphNodes;
+}
+
+inline std::string CudaGraphNodeToString(const cudaGraphNode_t node) {
+  std::stringstream ss;
+
+  // The following introspection calls are made through the driver API in order to bypass
+  // problems that would arise if multiple statically-linked copies of the runtime exist.
+
+  CUgraphNode cu_node = node;
+  CUgraphNodeType t;
+  CUDA_DRIVER_CALL(cuGraphNodeGetType(cu_node, &t));
+  switch (t) {
+    case CU_GRAPH_NODE_TYPE_KERNEL:
+      {
+        CUDA_KERNEL_NODE_PARAMS kparams;
+        auto err = cuGraphKernelNodeGetParams(cu_node, &kparams);
+        if (err == CUDA_SUCCESS) {
+          ss << "GPUKernel@" << kparams.func;
+          dim3 gridDim(kparams.gridDimX, kparams.gridDimY, kparams.gridDimZ);
+          dim3 blockDim(kparams.blockDimX, kparams.blockDimY, kparams.blockDimZ);
+          ss << "<<<gridDim=" << CudaDim3ToString(gridDim)
+             << ", blkDim=" << CudaDim3ToString(blockDim) << ">>>";
+          ss << "(...";
+          if (kparams.sharedMemBytes != 0)
+            ss << ", dynSharedMemBytes=" << kparams.sharedMemBytes;
+          ss << ")";
+        } else {
+          ss << "GPU Kernel: cuGraphKernelNodeGetParams() fails with " << err;
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMCPY:
+      {
+        cudaMemcpy3DParms mparams = {};
+        CUDA_CALL(cudaGraphMemcpyNodeGetParams(node, &mparams));
+        // If memcpy is seen, return without setting up runnable executor
+        switch (mparams.kind) {
+          case cudaMemcpyHostToHost: ss << "Host->Host "; break;
+          case cudaMemcpyHostToDevice: ss << "Host->Device "; break;
+          case cudaMemcpyDeviceToHost: ss << "Device->Host "; break;
+          case cudaMemcpyDeviceToDevice: ss << "Device->Device "; break;
+          default: break;
+        }
+        ss << "Memcpy";
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_MEMSET:
+      {
+        cudaMemsetParams mparams = {};
+        CUDA_CALL(cudaGraphMemsetNodeGetParams(node, &mparams));
+        if (mparams.height == 1 && mparams.elementSize == 1) {
+          ss << "cudaMemset(devPtr=" << mparams.dst << ", value=" << mparams.value
+             << ", count=" << mparams.width << ")";
+        } else {
+          if (mparams.elementSize == 1)
+            ss << "cudaMemset2D";
+          else
+            ss << "MemSet<elemBytes=" << mparams.elementSize << ">";
+          ss << "(devPtr=" << mparams.dst << ", pitch=" << mparams.pitch
+             << ", value=" << mparams.value << ", width=" << mparams.width
+             << ", height=" << mparams.height << ")";
+        }
+      }
+      break;
+    case CU_GRAPH_NODE_TYPE_HOST: ss << "Host (executable) node"; break;
+    case CU_GRAPH_NODE_TYPE_GRAPH: ss << "Node which executes an embedded graph"; break;
+    case CU_GRAPH_NODE_TYPE_EMPTY: ss << "Empty (no-op) node"; break;
+    default: ss << "Unknown/Invalid node type " << t;
+  }
+  return ss.str();
+}
+
+
+// CUDA Graphs are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphDeleter {
+ public:
+  void operator() (cudaGraph_t graph) {
+    if (graph != nullptr)
+      CUDA_CALL(cudaGraphDestroy(graph));
+  }
+};
+
+// CUDA Graphs Executors are managed in RAII fashion by smart pointers below.
+// Function objects (preferred for readability) provide the deleter function.
+class CudaGraphExecDeleter {
+ public:
+  void operator() (cudaGraphExec_t graph_exec) {
+    if (graph_exec != nullptr)
+      CUDA_CALL(cudaGraphExecDestroy(graph_exec));
+  }
+};
+
+// A CUDA Graphs executor for a portion of an Operator Segment (i.e. a 'SubSegment'),
+// characterized by a starting index in the OpExecutor list and a number of ops.
+class CudaGraphsSubSegExec {
+ public:
+  CudaGraphsSubSegExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops,
+                       bool ops_are_cuda_graph_compatible = true) :
+  from_op_idx_(from_op_idx),
+  num_ops_(num_ops),
+  graph_(nullptr),
+  graph_exec_(nullptr) {
+    if (ops_are_cuda_graph_compatible) {
+      MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx, num_ops);
+      MakeGraphExec();
+    }
+  }
+
+  void Update(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu,
+              bool verbose) {
+    // Current executor should be Runnable with the same parameters
+    CHECK(IsRunnable());
+    MakeGraph(exec_list, rctx, is_gpu, verbose, from_op_idx_, num_ops_);
+
+    cudaGraphExecUpdateResult update_result = cudaGraphExecUpdateError;
+    cudaGraphNode_t error_node;
+    CUDA_CALL(cudaGraphExecUpdate(graph_exec_.get(), graph_.get(),
+                                  &error_node, &update_result));
+    // If update fails make a new executor, discarding old one.
+    if (update_result != cudaGraphExecUpdateSuccess)
+      MakeGraphExec();
+  }
+
+  void RunSubSeg(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 const RunContext &rctx,
+                 bool is_gpu) {
+    if (IsRunnable()) {
+      auto s = rctx.get_stream<gpu>();
+      const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+      CUDA_CALL(cudaGraphLaunch(graph_exec_.get(), cu_s));
+    } else {
+      // No CUDA Graph could be made for this portion of the OpSegment.  Run conventionally.
+      for (int i = 0; i != num_ops_; ++i)
+        exec_list[from_op_idx_ + i]->Run(rctx, is_gpu);
+    }
+  }
+
+  bool IsRunnable() { return graph_exec_ != nullptr; }
+
+ private:
+  void MakeGraph(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                       const RunContext &rctx,
+                       bool is_gpu,
+                       bool verbose,
+                       int from_op_idx,
+                       int num_ops) {
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // Create CUDA Graph
+    // Use of cudaStreamCaptureModeThreadLocal allows other threads like GPU Copy workers
+    // to sync their streams without disturbing this capture.
+    CUDA_CALL(cudaStreamBeginCapture(cu_s, cudaStreamCaptureModeThreadLocal));
+    // Run those oprs in the sub segment while capturing- no actual GPU work is launched.
+    for (int i = 0; i != num_ops; ++i)
+      exec_list[from_op_idx + i]->Run(rctx, is_gpu);
+    cudaGraph_t cuda_graph = nullptr;
+    CUDA_CALL(cudaStreamEndCapture(cu_s, &cuda_graph));
+    graph_.reset(cuda_graph, CudaGraphDeleter());
+
+    if (verbose) {
+      std::vector<cudaGraphNode_t> graph_nodes = GetCudaGraphNodes(cuda_graph);
+      size_t num_nodes = graph_nodes.size();
+      LOG(INFO) << "  Graph has " << num_nodes << " nodes:";
+      for (size_t i = 0; i != num_nodes; ++i) {
+        LOG(INFO) << "    node " << i << " = "
+                  << CudaGraphNodeToString(graph_nodes[i]);
+      }
+    }
+  }
+
+  void MakeGraphExec() {
+      cudaGraphExec_t cuda_graph_exec;
+      cudaGraphNode_t error_node;
+      char log_buffer[1000];
+
+      CUDA_CALL(cudaGraphInstantiate(&cuda_graph_exec, graph_.get(),
+                                     &error_node, log_buffer, 1000));
+      graph_exec_.reset(cuda_graph_exec, CudaGraphExecDeleter());
+
+      // At this point we have a CUDA Graph executor
+      static int num_graph_creations_logged = 0;
+      static int max_log_entries = dmlc::GetEnv("MXNET_CUDA_GRAPHS_MAX_LOG_ENTRIES", 0);
+      if (num_graph_creations_logged < max_log_entries) {
+        num_graph_creations_logged++;
+        LOG(INFO) << "Created CUDA graph " << num_graph_creations_logged;
+        if (num_graph_creations_logged == max_log_entries)
+          LOG(INFO) << "Further CUDA graph creation log messages are suppressed.";
+      }
+  }
+
+  int from_op_idx_;
+  int num_ops_;
+  using cudaGraphStruct_t = typename std::remove_pointer<cudaGraph_t>::type;
+  using cudaGraphExecStruct_t = typename std::remove_pointer<cudaGraphExec_t>::type;
+  std::shared_ptr<cudaGraphStruct_t> graph_;
+  std::shared_ptr<cudaGraphExecStruct_t> graph_exec_;
+};
+
+// The CudaGraph executor and associated Tempspace ptrs for which it is valid.
+struct CudaGraphInfo {
+  std::vector<CudaGraphsSubSegExec> cuda_graph_subseg_execs;
+  bool has_been_run_conventionally = false;
+  std::vector<void *> tempspace_dptrs;
+};
+// A CUDA graph is maintained for every combination of cudaStream_t (i.e. GPU Worker) and
+// the state of the is_train flag of the OpContext.  If the tempspace_dptrs change, we
+// don't expect to ever see the old tempspace_dptrs config again, so we discard the CUDA graph.
+struct CudaGraphCacheKey {
+  cudaStream_t cu_s;
+  bool is_train;
+  // overload '<' so CudaGraphCacheKey can be used as a std::map key
+  bool operator<(const CudaGraphCacheKey &other) const {
+    return cu_s < other.cu_s || (cu_s == other.cu_s && is_train < other.is_train);
+  }
+};
+using CudaGraphCache = std::map<CudaGraphCacheKey, CudaGraphInfo>;
+
+class CudaGraphsExec {
+ public:
+  CudaGraphsExec(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                 bool is_gpu,
+                 const char *opr_names) :
+  verbose_(false), is_enabled_(false) {
+    opr_names_ = opr_names ? std::string(opr_names) : std::string();
+    if (is_gpu) {
+      is_enabled_ = dmlc::GetEnv("MXNET_ENABLE_CUDA_GRAPHS", false);
+      verbose_ = dmlc::GetEnv("MXNET_CUDA_GRAPHS_VERBOSE", false);
+      SetTempSpaces(exec_list);
+    }
+  }
+
+  void RunAll(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+              const RunContext &rctx,
+              bool is_gpu) {
+    // If this is a CPU op or CUDA Graphs use isn't possible, run normally and return
+    if (!is_gpu || !is_enabled_) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      return;
+    }
+
+    // Also if we're in a warm-up period where tempspace pointers are likely
+    // to change, run normally and return
+    auto s = rctx.get_stream<gpu>();
+    const cudaStream_t cu_s = mshadow::Stream<gpu>::GetStream(s);
+    // All the ops in the bulked segment will have the same setting of is_train as the first op
+    const bool is_train = exec_list.size() > 0 && exec_list[0]->op_ctx.is_train;
+    const CudaGraphCacheKey key = {cu_s, is_train};
+    // Look-up the CUDA Graph info for this combo of stream and is_train setting
+    // This may create a default-initialized new entry.
+    auto &cuda_graph_info = cache_[key];
+    if (!cuda_graph_info.has_been_run_conventionally) {
+      // Run all opr in the sub-graph
+      exec::OpExecutor::RunAll(exec_list, rctx, is_gpu);
+      cuda_graph_info.has_been_run_conventionally = true;
+      return;
+    }
+
+    // At this point we will launch one or more CUDA Graphs through CUDA Graphs 'executors'
+    //     (there might be more than one executor if some ops in the segment are not capturable)
+    auto before_exec_tempspace_ptrs = GetGPUTempspacePtrs(s);
+
+    // Executors exist, but the tempspace pts have changed, so update them in-place via 'recapture'.
+    if (cuda_graph_info.cuda_graph_subseg_execs.size() > 0 &&
+        cuda_graph_info.tempspace_dptrs != before_exec_tempspace_ptrs) {
+      // Update all runnable executors.  Non-runnable executors launch their ops conventionally.
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          subseg_exec.Update(exec_list, rctx, is_gpu, verbose_);
+      }
+    } else if (cuda_graph_info.cuda_graph_subseg_execs.size() == 0) {
+      // No executors exist yet, so create them.
+      if (verbose_)
+        LOG(INFO) << "Capturing CUDA graph of op segment " << opr_names_;
+      // Make one or more CUDA Graphs, avoiding ops that are not compatible.
+      for (size_t first_op_idx = 0; first_op_idx != exec_list.size();) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != exec_list.size(); ++last_op_idx) {
+          if (OpOK(exec_list[last_op_idx]))
+            num_good_ops++;
+          else
+            break;
+        }
+        if (num_good_ops > 0) {
+          CreateSubExecOverRegion(exec_list, rctx, is_gpu,
+                                  first_op_idx,
+                                  first_op_idx + num_good_ops,
+                                  &cuda_graph_info.cuda_graph_subseg_execs);
+          first_op_idx += num_good_ops;
+        }
+        if (first_op_idx != exec_list.size()) {
+          // We had to have hit an op that was not OK.
+          if (verbose_) {
+            LOG(INFO) << "Bypassing notOK op segment[" << first_op_idx << "," << first_op_idx << "]"
+                      << " of op segment "  << opr_names_;
+          }
+          CudaGraphsSubSegExec notOK_opseg(exec_list, rctx, is_gpu, false, first_op_idx, 1, false);
+          cuda_graph_info.cuda_graph_subseg_execs.push_back(notOK_opseg);
+          first_op_idx++;
+        }
+      }
+      // During graph capture, the ops may be asking for the tempworkspace.  This should
+      // not alter the base pointers, since this op seg has been executed before on this
+      // stream (i.e. on this gpu worker).  Safest to double-check this though.
+      auto after_capture_tempspace_ptrs = GetGPUTempspacePtrs(s);
+      if (before_exec_tempspace_ptrs != after_capture_tempspace_ptrs)
+        LOG(FATAL) << "Internal error: saw change in TempSpace ptrs during CUDA graph use.";
+      cuda_graph_info.tempspace_dptrs = before_exec_tempspace_ptrs;
+    }
+    // Now execute the CUDA Graph that we either just created or looked-up in the cache.
+    if (verbose_) {
+      int runnable_execs = 0;
+      int bypassed_ops = 0;
+      for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs) {
+        if (subseg_exec.IsRunnable())
+          runnable_execs++;
+        else
+          bypassed_ops++;
+      }
+      LOG(INFO) << "Launching " << runnable_execs
+                << " captured CUDA graph(s) for op segment " << opr_names_;
+      if (bypassed_ops > 0)
+        LOG(INFO) << "    (bypassing " << bypassed_ops << " un-capturable ops)";
+    }
+    for (auto &subseg_exec : cuda_graph_info.cuda_graph_subseg_execs)
+      subseg_exec.RunSubSeg(exec_list, rctx, is_gpu);
+  }
+
+ private:
+  // Make a CUDA Graph of the region of ops [from_op_idx, upto_op_idx).  If such a graph
+  // is not runnable, e.g. if it includes memcpys from unpinned cpu memory, then make a
+  // number of smaller graphs that avoid those ops with the memcpys.
+  void CreateSubExecOverRegion(const std::vector<std::shared_ptr<exec::OpExecutor> > &exec_list,
+                               const RunContext &rctx,
+                               bool is_gpu,
+                               size_t from_op_idx,
+                               size_t upto_op_idx,
+                               std::vector<CudaGraphsSubSegExec> *cuda_graph_subseg_execs) {
+    // Optimistically try to create a CUDA Graph of the entire op segment region
+
+    int num_ops = upto_op_idx - from_op_idx;
+    CudaGraphsSubSegExec full_opseg(exec_list, rctx, is_gpu, verbose_, from_op_idx, num_ops);
+    if (full_opseg.IsRunnable()) {
+      cuda_graph_subseg_execs->push_back(full_opseg);
+    } else {
+      if (verbose_)
+        LOG(INFO) << "  Graph was not runnable- creating op sub-segments...";
+      // Enter fall-back approach to making many sub-execs
+      for (size_t first_op_idx = from_op_idx; first_op_idx != upto_op_idx; ) {
+        int num_good_ops = 0;
+        for (size_t last_op_idx = first_op_idx; last_op_idx != upto_op_idx; ++last_op_idx) {
+          CudaGraphsSubSegExec single_opseg(exec_list, rctx, is_gpu, false, last_op_idx, 1);
+          if (single_opseg.IsRunnable())
+            num_good_ops++;
+          // Is it time to create a subseg exec from accumulated good ops?
+          if (num_good_ops > 0 &&
+              (last_op_idx == upto_op_idx - 1 || !single_opseg.IsRunnable())) {
+            if (verbose_)
+              LOG(INFO) << "Capturing CUDA graph of op sub segment["
+                        << first_op_idx << ":" << (first_op_idx + num_good_ops - 1) << "]"
+                        << " of op segment "  << opr_names_;
+            CudaGraphsSubSegExec good_opseg(exec_list, rctx, is_gpu, verbose_,
+                                            first_op_idx, num_good_ops);
+            CHECK(good_opseg.IsRunnable()) << "Unexpected issue with CUDA Graphs creation";
+            cuda_graph_subseg_execs->push_back(good_opseg);
+            first_op_idx += num_good_ops;
+          }
+          // If the last single op was not runnable, use the exec to handle that op conventionally
+          if (!single_opseg.IsRunnable()) {
+            if (verbose_) {
+              LOG(INFO) << "Bypassing op sub segment[" << last_op_idx << "," << last_op_idx << "]"
+                        << " of op segment "  << opr_names_;
+              // Generate throw-away exec in order to produce a diagnostic listing of graph nodes
+              CudaGraphsSubSegExec dummy(exec_list, rctx, is_gpu, verbose_, last_op_idx, 1);
+            }
+            cuda_graph_subseg_execs->push_back(single_opseg);
+            first_op_idx++;
+            break;
+          }
+        }
+      }
+    }
+  }
+
+  // Is the Op OK to make part of a CUDA Graph?
+  bool OpOK(const std::shared_ptr<exec::OpExecutor> &exec) {
+    static auto& fstateful = Op::GetAttr<FCreateOpState>("FCreateOpState");
+    static auto& fgraphcompatible = Op::GetAttr<FIsCUDAGraphsCompatible>("FIsCUDAGraphsCompatible");
+    const auto& attrs = exec->attrs;
+    if (attrs.op != nullptr) {
+      const auto f = fgraphcompatible.get(attrs.op, nullptr);

Review comment:
       Sure, but the signature needs to be the same for everything, so if I have a lambda for dropout, then I need a lambda for the other ones as well...
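
To illustrate the point about a single signature, here is a rough sketch with hypothetical names throughout (the function type, `AlwaysCompatible`, and the toy registry are assumptions, not code from this PR). Because the attribute stores one function type, an op that is always capturable still has to supply a function of that type; a shared helper is one way to avoid repeating the same lambda.

```cpp
#include <cassert>
#include <functional>
#include <map>
#include <string>

// Hypothetical attribute type: one signature shared by every op.
using IsGraphCompatibleFn = std::function<bool(bool is_train)>;
using Registry = std::map<std::string, IsGraphCompatibleFn>;

// Shared helper for ops whose answer never depends on is_train.
inline bool AlwaysCompatible(bool /*is_train*/) { return true; }

// Builds a toy registry in which every listed op supplies a function of the same type.
inline Registry MakeRegistry() {
  Registry reg;
  reg["relu"] = AlwaysCompatible;
  reg["conv"] = AlwaysCompatible;
  // Assumption: a dropout-like op is capturable only outside of training.
  reg["dropout"] = [](bool is_train) { return !is_train; };
  return reg;
}

// Explicit-registration lookup: every op is expected to have an entry.
inline bool IsCompatible(const Registry& reg, const std::string& op, bool is_train) {
  auto it = reg.find(op);
  assert(it != reg.end() && "op must register a function in this scheme");
  return it->second(is_train);
}
```

The trade-off against a `nullptr` default is that explicit registration is opt-in per operator, at the cost of one registration line for each op.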



