Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/03/11 20:33:57 UTC

[GitHub] [tvm] comaniac commented on a change in pull request #7616: [Runtime] Extend Graph Runtime To Support Cuda Graph Launch

comaniac commented on a change in pull request #7616:
URL: https://github.com/apache/tvm/pull/7616#discussion_r592694212



##########
File path: tests/python/unittest/test_runtime_graph_cugraph.py
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import os
+import re
+import sys
+import time
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy as np
+
+from tvm.contrib import utils, graph_runtime
+from tvm.contrib.cu_graph import cugraph_runtime
+
+
+bx = te.thread_axis("blockIdx.x")
+tx = te.thread_axis("threadIdx.x")
+
+
+@tvm.testing.requires_cuda
+def test_graph_simple():

Review comment:
       You may need to check if CuGraph is enabled and skip this test if not.
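       A minimal sketch of one way to do that (assuming the packed function name registered by this PR, `tvm.graph_runtime_cugraph.create`, and a hypothetical helper `cugraph_enabled`):
   ```python
   import pytest

   import tvm
   import tvm.testing


   def cugraph_enabled():
       # The packed function is only registered when TVM was built with
       # USE_GRAPH_RUNTIME_CUGRAPH=ON, so its absence means the feature is off.
       return tvm.get_global_func("tvm.graph_runtime_cugraph.create", allow_missing=True) is not None


   @tvm.testing.requires_cuda
   def test_graph_simple():
       if not cugraph_enabled():
           pytest.skip("cuGraph runtime is not enabled (USE_GRAPH_RUNTIME_CUGRAPH=OFF)")
       ...
   ```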

##########
File path: CMakeLists.txt
##########
@@ -321,6 +322,16 @@ if(USE_GRAPH_RUNTIME)
     set_source_files_properties(${RUNTIME_GRAPH_SRCS}
       PROPERTIES COMPILE_DEFINITIONS "TVM_GRAPH_RUNTIME_DEBUG")
   endif(USE_GRAPH_RUNTIME_DEBUG)
+
+  if(USE_CUDA)
+    if(USE_GRAPH_RUNTIME_CUGRAPH)

Review comment:
       This makes USE_GRAPH_RUNTIME_CUGRAPH silently ignored when CUDA is OFF, which may confuse users. We should have something like
   ```
   if(USE_GRAPH_RUNTIME_CUGRAPH)
     if(NOT USE_CUDA)
       message(FATAL_ERROR "USE_GRAPH_RUNTIME_CUGRAPH requires CUDA; please configure with USE_CUDA=ON.")
     endif()
   endif()
   ```

##########
File path: python/tvm/contrib/cu_graph/cugraph_runtime.py
##########
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph runtime test cuGraph"""
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.contrib import graph_runtime
+
+
+def create(graph_json_str, libmod, ctx):
+    assert isinstance(graph_json_str, string_types)
+    try:
+        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
+        if num_rpc_ctx == len(ctx):
+            pass
+        else:
+            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")
+    except ValueError:
+        raise ValueError(
+            "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in "
+            "config.cmake and rebuild TVM to enable cu_graph test mode"

Review comment:
       Why test mode?

##########
File path: src/runtime/graph/cugraph/graph_runtime_cugraph.cc
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_cugraph.cc
+ */
+
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "../graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+class GraphRuntimeCuGraph : public GraphRuntime {
+ public:
+  int StartCapture() {

Review comment:
       Please add a docstring (ditto for the rest).

##########
File path: tests/python/unittest/test_runtime_graph_cugraph.py
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import os
+import re
+import sys
+import time
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy as np
+
+from tvm.contrib import utils, graph_runtime
+from tvm.contrib.cu_graph import cugraph_runtime
+
+
+bx = te.thread_axis("blockIdx.x")
+tx = te.thread_axis("threadIdx.x")
+
+
+@tvm.testing.requires_cuda
+def test_graph_simple():
+    n = 32
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
+    s = te.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=8)
+    s[B].bind(xo, bx)
+    s[B].bind(xi, tx)
+
+    node0 = {"op": "null", "name": "x", "inputs": []}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
+    nodes = [node0, node1]
+    arg_nodes = [0]
+    node_row_ptr = [0, 1, 2]
+    outputs = [[1, 0, 0]]
+    shape = (n,)
+    attrs = {
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
+    }
+    graph = json.dumps(graph)
+
+    def check_verify():
+        mlib = tvm.build(s, [A, B], "cuda", name="myadd")
+        ctx = tvm.gpu(0)
+        try:
+            mod = cugraph_runtime.create(graph, mlib, ctx)

Review comment:
       As I mentioned before, can we use the unified GraphModule interface?
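       For illustration, a rough sketch of what that could look like in this test, assuming the registered packed function takes the same `(json, lib, device_type, device_id)` arguments as the standard graph runtime factory:
   ```python
   # Hypothetical usage sketch, not the PR's current API.
   fcreate = tvm.get_global_func("tvm.graph_runtime_cugraph.create")
   mod = graph_runtime.GraphModule(fcreate(graph, mlib, ctx.device_type, ctx.device_id))
   mod.set_input(x=np.random.uniform(size=(n,)).astype(A.dtype))
   mod.run()
   out = mod.get_output(0)
   ```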

##########
File path: CMakeLists.txt
##########
@@ -36,6 +36,7 @@ tvm_option(USE_LLVM "Build with LLVM, can be set to specific llvm-config path" O
 tvm_option(USE_STACKVM_RUNTIME "Include stackvm into the runtime" OFF)
 tvm_option(USE_GRAPH_RUNTIME "Build with tiny graph runtime" ON)
 tvm_option(USE_GRAPH_RUNTIME_DEBUG "Build with tiny graph runtime debug mode" OFF)
+tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime cuGraph launch mode" OFF)

Review comment:
       ```suggestion
   tvm_option(USE_GRAPH_RUNTIME_CUGRAPH "Build with tiny graph runtime with cuGraph for GPUs" OFF)
   ```

##########
File path: python/tvm/contrib/cu_graph/cugraph_runtime.py
##########
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph runtime test cuGraph"""
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.contrib import graph_runtime
+
+
+def create(graph_json_str, libmod, ctx):
+    assert isinstance(graph_json_str, string_types)
+    try:
+        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
+        if num_rpc_ctx == len(ctx):
+            pass
+        else:
+            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")

Review comment:
       ```suggestion
           if num_rpc_ctx != len(ctx):
               fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")
   ```
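       Not part of the suggestion itself, just an illustrative sketch of how the whole creation path might read after that simplification (raising early for the all-RPC case is an assumption, not the PR's current behavior):
   ```python
   ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
   if num_rpc_ctx == len(ctx):
       raise RuntimeError("The cuGraph runtime does not support RPC-only contexts")
   try:
       fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")
   except ValueError:
       raise ValueError(
           "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in config.cmake "
           "and rebuild TVM to enable the cuGraph runtime"
       )
   func_obj = fcreate(graph_json_str, libmod, *device_type_id)
   ```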

##########
File path: python/tvm/contrib/cu_graph/cugraph_runtime.py
##########
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Graph runtime test cuGraph"""
+import tvm._ffi
+
+from tvm._ffi.base import string_types
+from tvm.contrib import graph_runtime
+
+
+def create(graph_json_str, libmod, ctx):
+    assert isinstance(graph_json_str, string_types)
+    try:
+        ctx, num_rpc_ctx, device_type_id = graph_runtime.get_device_ctx(libmod, ctx)
+        if num_rpc_ctx == len(ctx):
+            pass
+        else:
+            fcreate = tvm._ffi.get_global_func("tvm.graph_runtime_cugraph.create")
+    except ValueError:
+        raise ValueError(
+            "Please set '(USE_GRAPH_RUNTIME_CUGRAPH ON)' in "
+            "config.cmake and rebuild TVM to enable cu_graph test mode"
+        )
+
+    func_obj = fcreate(graph_json_str, libmod, *device_type_id)
+    return GraphModuleCuGraph(func_obj, ctx, graph_json_str)
+
+
+class GraphModuleCuGraph(graph_runtime.GraphModule):
+    def __init__(self, module, ctx, graph_json_str):
+
+        self._start_capture = module["start_capture"]
+        self._end_capture = module["end_capture"]
+        self._run_cuda_graph = module["run_cuda_graph"]
+
+        graph_runtime.GraphModule.__init__(self, module)
+
+    def capture_cuda_graph(self):
+        self._run()  # call cuModuleLoadData before cudaStream API
+
+        print("====== Start Stream Capture ======")

Review comment:
       Remove unnecessary prints (ditto for the rest).

##########
File path: tests/python/unittest/test_runtime_graph_cugraph.py
##########
@@ -0,0 +1,92 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import json
+import os
+import re
+import sys
+import time
+
+import pytest
+
+import tvm
+import tvm.testing
+from tvm import te
+import numpy as np
+
+from tvm.contrib import utils, graph_runtime
+from tvm.contrib.cu_graph import cugraph_runtime
+
+
+bx = te.thread_axis("blockIdx.x")
+tx = te.thread_axis("threadIdx.x")
+
+
+@tvm.testing.requires_cuda
+def test_graph_simple():
+    n = 32
+    A = te.placeholder((n,), name="A")
+    B = te.compute(A.shape, lambda *i: A(*i) + 1.0, name="B")
+    s = te.create_schedule(B.op)
+    xo, xi = s[B].split(B.op.axis[0], factor=8)
+    s[B].bind(xo, bx)
+    s[B].bind(xi, tx)
+
+    node0 = {"op": "null", "name": "x", "inputs": []}
+    node1 = {
+        "op": "tvm_op",
+        "name": "add",
+        "inputs": [[0, 0, 0]],
+        "attrs": {"func_name": "myadd", "flatten_data": "1", "num_inputs": "1", "num_outputs": "1"},
+    }
+    nodes = [node0, node1]
+    arg_nodes = [0]
+    node_row_ptr = [0, 1, 2]
+    outputs = [[1, 0, 0]]
+    shape = (n,)
+    attrs = {
+        "shape": ["list_shape", [shape, shape]],
+        "dltype": ["list_str", ["float32", "float32"]],
+        "storage_id": ["list_int", [0, 1]],
+    }
+    graph = {
+        "nodes": nodes,
+        "arg_nodes": arg_nodes,
+        "node_row_ptr": node_row_ptr,
+        "heads": outputs,
+        "attrs": attrs,
+    }
+    graph = json.dumps(graph)
+
+    def check_verify():
+        mlib = tvm.build(s, [A, B], "cuda", name="myadd")
+        ctx = tvm.gpu(0)
+        try:
+            mod = cugraph_runtime.create(graph, mlib, ctx)
+        except ValueError:
+            return
+        mod.capture_cuda_graph()
+        a = np.random.uniform(size=(n,)).astype(A.dtype)
+        mod.set_input(x=a)
+        mod.run_cuda_graph()

Review comment:
       We should think about the user interface a bit more. For example, can we use the unified `run` API as in graph_runtime, and automatically call `capture_cuda_graph` before the first run?
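       A rough sketch of that idea, using the method names from this PR (`start_capture`, `end_capture`, `run_cuda_graph`); the lazy capture on the first `run` and the exact capture sequence are assumptions, not the PR's current behavior:
   ```python
   from tvm.contrib import graph_runtime


   class GraphModuleCuGraph(graph_runtime.GraphModule):
       def __init__(self, module, ctx, graph_json_str):
           graph_runtime.GraphModule.__init__(self, module)
           self._start_capture = module["start_capture"]
           self._end_capture = module["end_capture"]
           self._run_cuda_graph = module["run_cuda_graph"]
           self._captured = False

       def run(self, **input_dict):
           if input_dict:
               self.set_input(**input_dict)
           if not self._captured:
               # First run: warm up once (loads the CUDA module), then record the
               # kernel launches into a CUDA graph via stream capture.
               self._run()
               self._start_capture()
               self._run()
               self._end_capture()
               self._captured = True
           # Every run after capture just replays the recorded CUDA graph.
           self._run_cuda_graph()
   ```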

##########
File path: src/runtime/graph/cugraph/graph_runtime_cugraph.cc
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_cugraph.cc
+ */
+
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "../graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+class GraphRuntimeCuGraph : public GraphRuntime {
+ public:
+  int StartCapture() {
+    const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx;
+
+    TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_);
+    TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_);
+
+    CUDA_CALL(cudaStreamBeginCapture(static_cast<cudaStream_t>(capture_stream_),
+                                     cudaStreamCaptureModeGlobal));
+    return 0;

Review comment:
       If 0 is the only return value, we could simply return void.

##########
File path: src/runtime/graph/cugraph/graph_runtime_cugraph.cc
##########
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file graph_runtime_cugraph.cc
+ */
+
+#include <tvm/runtime/registry.h>
+
+#include "../../cuda/cuda_common.h"
+#include "../graph_runtime.h"
+
+namespace tvm {
+namespace runtime {
+
+class GraphRuntimeCuGraph : public GraphRuntime {
+ public:
+  int StartCapture() {
+    const TVMContext& ctx = data_entry_[entry_id(0, 0)]->ctx;
+
+    TVMStreamCreate(ctx.device_type, ctx.device_id, &capture_stream_);
+    TVMSetStream(ctx.device_type, ctx.device_id, capture_stream_);
+
+    CUDA_CALL(cudaStreamBeginCapture(static_cast<cudaStream_t>(capture_stream_),
+                                     cudaStreamCaptureModeGlobal));
+    return 0;
+  }
+
+  int RunCudaGraph() {
+    cudaStream_t cuStream = static_cast<cudaStream_t>(capture_stream_);
+    CUDA_CALL(cudaGraphLaunch(cu_graph_exec_, cuStream));
+    CUDA_CALL(cudaStreamSynchronize(cuStream));
+    return 0;
+  }
+
+  int EndCapture() {
+    cudaGraph_t graph;
+    CUDA_CALL(cudaStreamEndCapture(static_cast<cudaStream_t>(capture_stream_), &graph));
+
+    cudaGraphNode_t* nodes = NULL;
+    size_t numNodes = 0;
+    CUDA_CALL(cudaGraphGetNodes(graph, nodes, &numNodes));
+    LOG(INFO) << "Num of nodes in the cuda graph created using stream capture API = " << numNodes;
+
+    CUDA_CALL(cudaGraphInstantiate(&cu_graph_exec_, graph, NULL, NULL, 0));
+    return 0;
+  }
+
+  /*!
+   * \brief GetFunction Get the function based on input.
+   * \param name The function which needs to be invoked.
+   * \param sptr_to_self Packed function pointer.
+   */
+  PackedFunc GetFunction(const std::string& name, const ObjectPtr<Object>& sptr_to_self);
+
+ private:
+  TVMStreamHandle capture_stream_;

Review comment:
       docstring.




----------------------------------------------------------------
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org