You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by zh...@apache.org on 2020/04/15 08:49:25 UTC
[incubator-tvm] branch master updated: Windows Support for cpp_rpc
(#4857)
This is an automated email from the ASF dual-hosted git repository.
zhaowu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-tvm.git
The following commit(s) were added to refs/heads/master by this push:
new afcf939 Windows Support for cpp_rpc (#4857)
afcf939 is described below
commit afcf9397b60ae7ccf46601cf29828992ca9d5f57
Author: jmorrill <je...@gmail.com>
AuthorDate: Wed Apr 15 01:49:15 2020 -0700
Windows Support for cpp_rpc (#4857)
* Windows Support for cpp_rpc
* Add missing patches that fix crashes under Windows
* On Windows, use python to untar vs wsl
* remove some CMakeLists.txt stuff
* more minor CMakeLists.txt changes
* Remove items from CMakeLists.txt
* Minor CMakeLists.txt changes
* More minor CMakeLists.txt changes
* Even more minor CMakeLists.txt changes
* Modify readme
---
CMakeLists.txt | 8 +
apps/cpp_rpc/CMakeLists.txt | 27 ++++
apps/cpp_rpc/README.md | 10 +-
apps/cpp_rpc/main.cc | 95 ++++++++----
apps/cpp_rpc/rpc_env.cc | 305 ++++++++++++++++++++-----------------
apps/cpp_rpc/rpc_env.h | 6 +-
apps/cpp_rpc/rpc_server.cc | 250 +++++++++++++++---------------
apps/cpp_rpc/rpc_server.h | 23 ++-
apps/cpp_rpc/win32_process.cc | 273 +++++++++++++++++++++++++++++++++
apps/cpp_rpc/win32_process.h | 43 ++++++
src/runtime/rpc/rpc_socket_impl.cc | 11 +-
src/support/ring_buffer.h | 2 +-
12 files changed, 751 insertions(+), 302 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index cf334ff..8a559b8 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -66,9 +66,14 @@ tvm_option(USE_NNPACK "Build with nnpack support" OFF)
tvm_option(USE_RANDOM "Build with random support" OFF)
tvm_option(USE_MICRO_STANDALONE_RUNTIME "Build with micro.standalone_runtime support" OFF)
tvm_option(USE_ANTLR "Build with ANTLR for Relay parsing" OFF)
+tvm_option(USE_CPP_RPC "Build CPP RPC" OFF)
tvm_option(USE_TFLITE "Build with tflite support" OFF)
tvm_option(USE_TENSORFLOW_PATH "TensorFlow root path when use TFLite" none)
+if(USE_CPP_RPC AND UNIX)
+ message(FATAL_ERROR "USE_CPP_RPC is only supported with WIN32. Use the Makefile for non-Windows.")
+endif()
+
# include directories
include_directories(${CMAKE_INCLUDE_PATH})
include_directories("include")
@@ -309,6 +314,9 @@ add_library(tvm SHARED ${COMPILER_SRCS} ${RUNTIME_SRCS})
add_library(tvm_topi SHARED ${TOPI_SRCS})
add_library(tvm_runtime SHARED ${RUNTIME_SRCS})
+if(USE_CPP_RPC)
+ add_subdirectory("apps/cpp_rpc")
+endif()
if(USE_RELAY_DEBUG)
message(STATUS "Building Relay in debug mode...")
diff --git a/apps/cpp_rpc/CMakeLists.txt b/apps/cpp_rpc/CMakeLists.txt
new file mode 100644
index 0000000..9888738
--- /dev/null
+++ b/apps/cpp_rpc/CMakeLists.txt
@@ -0,0 +1,27 @@
+set(TVM_RPC_SOURCES
+ main.cc
+ rpc_env.cc
+ rpc_server.cc
+)
+
+if(WIN32)
+ list(APPEND TVM_RPC_SOURCES win32_process.cc)
+endif()
+
+# Set output to same directory as the other TVM libs
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR})
+add_executable(tvm_rpc ${TVM_RPC_SOURCES})
+set_property(TARGET tvm_rpc PROPERTY INTERPROCEDURAL_OPTIMIZATION_RELEASE TRUE)
+
+if(WIN32)
+ target_compile_definitions(tvm_rpc PUBLIC -DNOMINMAX)
+endif()
+
+target_include_directories(
+ tvm_rpc
+ PUBLIC "../../include"
+ PUBLIC DLPACK_PATH
+ PUBLIC DMLC_PATH
+)
+
+target_link_libraries(tvm_rpc tvm_runtime)
\ No newline at end of file
diff --git a/apps/cpp_rpc/README.md b/apps/cpp_rpc/README.md
index 4baecaf..c826dae 100644
--- a/apps/cpp_rpc/README.md
+++ b/apps/cpp_rpc/README.md
@@ -18,7 +18,7 @@
# TVM RPC Server
This folder contains a simple recipe to make RPC server in c++.
-## Usage
+## Usage (Non-Windows)
- Build tvm runtime
- Make the rpc executable [Makefile](Makefile).
`make CXX=/path/to/cross compiler g++/ TVM_RUNTIME_DIR=/path/to/tvm runtime library directory/ OS=Linux`
@@ -35,6 +35,12 @@ This folder contains a simple recipe to make RPC server in c++.
```
- Use `./tvm_rpc server` to start the RPC server
+## Usage (Windows)
+- Build tvm with the CMake argument `-DUSE_CPP_RPC=ON`
+- Install [LLVM pre-built binaries](https://releases.llvm.org/download.html), making sure to select the option to add it to the PATH.
+- Verify Python 3.6 or newer is installed and in the PATH.
+- Use `<tvm_output_dir>\tvm_rpc.exe` to start the RPC server
+
## How it works
- The tvm runtime dll is linked along with this executable and when the RPC server starts it will load the tvm runtime library.
@@ -53,4 +59,4 @@ Command line usage
```
## Note
-Currently support is only there for Linux / Android environment and proxy mode doesn't be supported currently.
\ No newline at end of file
+Support is currently limited to Linux / Android / Windows environments, and proxy mode is not supported yet.
\ No newline at end of file
diff --git a/apps/cpp_rpc/main.cc b/apps/cpp_rpc/main.cc
index ae66bd2..5168da3 100644
--- a/apps/cpp_rpc/main.cc
+++ b/apps/cpp_rpc/main.cc
@@ -21,10 +21,12 @@
* \file rpc_server.cc
* \brief RPC Server for TVM.
*/
-#include <stdlib.h>
-#include <signal.h>
-#include <stdio.h>
+#include <cstdlib>
+#include <csignal>
+#include <cstdio>
+#if defined(__linux__) || defined(__ANDROID__)
#include <unistd.h>
+#endif
#include <dmlc/logging.h>
#include <iostream>
#include <cstring>
@@ -35,11 +37,15 @@
#include "../../src/support/socket.h"
#include "rpc_server.h"
+#if defined(_WIN32)
+#include "win32_process.h"
+#endif
+
using namespace std;
using namespace tvm::runtime;
using namespace tvm::support;
-static const string kUSAGE = \
+static const string kUsage = \
"Command line usage\n" \
" server - Start the server\n" \
"--host - The hostname of the server, Default=0.0.0.0\n" \
@@ -73,13 +79,16 @@ struct RpcServerArgs {
string key;
string custom_addr;
bool silent = false;
+#if defined(WIN32)
+ std::string mmap_path;
+#endif
};
/*!
* \brief PrintArgs print the contents of RpcServerArgs
* \param args RpcServerArgs structure
*/
-void PrintArgs(struct RpcServerArgs args) {
+void PrintArgs(const RpcServerArgs& args) {
LOG(INFO) << "host = " << args.host;
LOG(INFO) << "port = " << args.port;
LOG(INFO) << "port_end = " << args.port_end;
@@ -89,6 +98,7 @@ void PrintArgs(struct RpcServerArgs args) {
LOG(INFO) << "silent = " << ((args.silent) ? ("True"): ("False"));
}
+#if defined(__linux__) || defined(__ANDROID__)
/*!
* \brief CtrlCHandler, exits if Ctrl+C is pressed
* \param s signal
@@ -109,7 +119,7 @@ void HandleCtrlC() {
sigIntHandler.sa_flags = 0;
sigaction(SIGINT, &sigIntHandler, nullptr);
}
-
+#endif
/*!
* \brief GetCmdOption Parse and find the command option.
* \param argc arg counter
@@ -129,7 +139,7 @@ string GetCmdOption(int argc, char* argv[], string option, bool key = false) {
}
// We assume "=" is the end of option.
CHECK_EQ(*option.rbegin(), '=');
- cmd = arg.substr(arg.find("=") + 1);
+ cmd = arg.substr(arg.find('=') + 1);
return cmd;
}
}
@@ -156,41 +166,41 @@ bool ValidateTracker(string &tracker) {
* \brief ParseCmdArgs parses the command line arguments.
* \param argc arg counter
* \param argv arg values
- * \param args, the output structure which holds the parsed values
+ * \param args the output structure which holds the parsed values
*/
void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
- string silent = GetCmdOption(argc, argv, "--silent", true);
+ const string silent = GetCmdOption(argc, argv, "--silent", true);
if (!silent.empty()) {
args.silent = true;
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
- string host = GetCmdOption(argc, argv, "--host=");
+ const string host = GetCmdOption(argc, argv, "--host=");
if (!host.empty()) {
if (!ValidateIP(host)) {
LOG(WARNING) << "Wrong host address format.";
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
exit(1);
}
args.host = host;
}
- string port = GetCmdOption(argc, argv, "--port=");
+ const string port = GetCmdOption(argc, argv, "--port=");
if (!port.empty()) {
if (!IsNumber(port) || stoi(port) > 65535) {
LOG(WARNING) << "Wrong port number.";
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
exit(1);
}
args.port = stoi(port);
}
- string port_end = GetCmdOption(argc, argv, "--port_end=");
+ const string port_end = GetCmdOption(argc, argv, "--port_end=");
if (!port_end.empty()) {
if (!IsNumber(port_end) || stoi(port_end) > 65535) {
LOG(WARNING) << "Wrong port_end number.";
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
exit(1);
}
args.port_end = stoi(port_end);
@@ -200,26 +210,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
if (!tracker.empty()) {
if (!ValidateTracker(tracker)) {
LOG(WARNING) << "Wrong tracker address format.";
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
exit(1);
}
args.tracker = tracker;
}
- string key = GetCmdOption(argc, argv, "--key=");
+ const string key = GetCmdOption(argc, argv, "--key=");
if (!key.empty()) {
args.key = key;
}
- string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
+ const string custom_addr = GetCmdOption(argc, argv, "--custom_addr=");
if (!custom_addr.empty()) {
if (!ValidateIP(custom_addr)) {
LOG(WARNING) << "Wrong custom address format.";
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
exit(1);
}
args.custom_addr = custom_addr;
}
+#if defined(WIN32)
+ const string mmap_path = GetCmdOption(argc, argv, "--child_proc=");
+ if(!mmap_path.empty()) {
+ args.mmap_path = mmap_path;
+ dmlc::InitLogging("--minloglevel=0");
+ }
+#endif
+
}
/*!
@@ -229,17 +247,34 @@ void ParseCmdArgs(int argc, char * argv[], struct RpcServerArgs &args) {
* \return result of operation.
*/
int RpcServer(int argc, char * argv[]) {
- struct RpcServerArgs args;
+ RpcServerArgs args;
/* parse the command line args */
ParseCmdArgs(argc, argv, args);
PrintArgs(args);
- // Ctrl+C handler
LOG(INFO) << "Starting CPP Server, Press Ctrl+C to stop.";
+#if defined(__linux__) || defined(__ANDROID__)
+ // Ctrl+C handler
HandleCtrlC();
- tvm::runtime::RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
- args.key, args.custom_addr, args.silent);
+#endif
+
+#if defined(WIN32)
+ if(!args.mmap_path.empty()) {
+ int ret = 0;
+
+ try {
+ ChildProcSocketHandler(args.mmap_path);
+ } catch (const std::exception&) {
+ ret = -1;
+ }
+
+ return ret;
+ }
+#endif
+
+ RPCServerCreate(args.host, args.port, args.port_end, args.tracker,
+ args.key, args.custom_addr, args.silent);
return 0;
}
@@ -251,15 +286,21 @@ int RpcServer(int argc, char * argv[]) {
*/
int main(int argc, char * argv[]) {
if (argc <= 1) {
- LOG(INFO) << kUSAGE;
+ LOG(INFO) << kUsage;
return 0;
}
+ // Runs WSAStartup on Win32, no-op on POSIX
+ Socket::Startup();
+#if defined(_WIN32)
+ SetEnvironmentVariableA("CUDA_CACHE_DISABLE", "1");
+#endif
+
if (0 == strcmp(argv[1], "server")) {
- RpcServer(argc, argv);
- } else {
- LOG(INFO) << kUSAGE;
+ return RpcServer(argc, argv);
}
+ LOG(INFO) << kUsage;
+
return 0;
}
diff --git a/apps/cpp_rpc/rpc_env.cc b/apps/cpp_rpc/rpc_env.cc
index 844a7af..b5dc51b 100644
--- a/apps/cpp_rpc/rpc_env.cc
+++ b/apps/cpp_rpc/rpc_env.cc
@@ -20,77 +20,86 @@
* \file rpc_env.cc
* \brief Server environment of the RPC.
*/
+#include <cerrno>
#include <tvm/runtime/registry.h>
-#include <errno.h>
-#ifndef _MSC_VER
-#include <sys/stat.h>
+#ifndef _WIN32
#include <dirent.h>
+#include <sys/stat.h>
#include <unistd.h>
#else
#include <Windows.h>
+#include <direct.h>
+namespace {
+ int mkdir(const char* path, int /* ignored */) { return _mkdir(path); }
+}
#endif
+#include <cstring>
#include <fstream>
-#include <vector>
#include <iostream>
#include <string>
-#include <cstring>
+#include <vector>
+#include <string>
-#include "rpc_env.h"
#include "../../src/support/util.h"
#include "../../src/runtime/file_util.h"
+#include "rpc_env.h"
+
+namespace {
+ std::string GenerateUntarCommand(const std::string& tar_file, const std::string& output_dir) {
+ std::string untar_cmd;
+ untar_cmd.reserve(512);
+#if defined(__linux__) || defined(__ANDROID__)
+ untar_cmd += "tar -C ";
+ untar_cmd += output_dir;
+ untar_cmd += " -zxf ";
+ untar_cmd += tar_file;
+#elif defined(_WIN32)
+ untar_cmd += "python -m tarfile -e ";
+ untar_cmd += tar_file;
+ untar_cmd += " ";
+ untar_cmd += output_dir;
+#endif
+ return untar_cmd;
+ }
+
+}// Anonymous namespace
namespace tvm {
namespace runtime {
-
RPCEnv::RPCEnv() {
- #if defined(__linux__) || defined(__ANDROID__)
- base_ = "./rpc";
- mkdir(&base_[0], 0777);
-
- TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath")
- .set_body([](TVMArgs args, TVMRetValue* rv) {
- static RPCEnv env;
- *rv = env.GetPath(args[0]);
- });
+ base_ = "./rpc";
+ mkdir(base_.c_str(), 0777);
+ TVM_REGISTER_GLOBAL("tvm.rpc.server.workpath").set_body([](TVMArgs args, TVMRetValue* rv) {
+ static RPCEnv env;
+ *rv = env.GetPath(args[0]);
+ });
- TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module")
- .set_body([](TVMArgs args, TVMRetValue *rv) {
- static RPCEnv env;
- std::string file_name = env.GetPath(args[0]);
- *rv = Load(&file_name, "");
- LOG(INFO) << "Load module from " << file_name << " ...";
- });
- #else
- LOG(FATAL) << "Only support RPC in linux environment";
- #endif
+ TVM_REGISTER_GLOBAL("tvm.rpc.server.load_module").set_body([](TVMArgs args, TVMRetValue* rv) {
+ static RPCEnv env;
+ std::string file_name = env.GetPath(args[0]);
+ *rv = Load(&file_name, "");
+ LOG(INFO) << "Load module from " << file_name << " ...";
+ });
}
/*!
- * \brief GetPath To get the workpath from packed function
- * \param name The file name
+ * \brief GetPath To get the work path from packed function
+ * \param file_name The file name
* \return The full path of file.
*/
-std::string RPCEnv::GetPath(std::string file_name) {
+std::string RPCEnv::GetPath(const std::string& file_name) const {
// we assume file_name has "/" means file_name is the exact path
// and does not create /.rpc/
- if (file_name.find("/") != std::string::npos) {
- return file_name;
- } else {
- return base_ + "/" + file_name;
- }
+ return file_name.find('/') != std::string::npos ? file_name : base_ + "/" + file_name;
}
/*!
* \brief Remove The RPC Environment cleanup function
*/
-void RPCEnv::CleanUp() {
- #if defined(__linux__) || defined(__ANDROID__)
- CleanDir(&base_[0]);
- int ret = rmdir(&base_[0]);
- if (ret != 0) {
- LOG(WARNING) << "Remove directory " << base_ << " failed";
- }
- #else
- LOG(FATAL) << "Only support RPC in linux environment";
- #endif
+void RPCEnv::CleanUp() const {
+ CleanDir(base_);
+ const int ret = rmdir(base_.c_str());
+ if (ret != 0) {
+ LOG(WARNING) << "Remove directory " << base_ << " failed";
+ }
}
/*!
@@ -98,53 +107,54 @@ void RPCEnv::CleanUp() {
* \param dirname The root directory name
* \return vector Files in directory.
*/
-std::vector<std::string> ListDir(const std::string &dirname) {
+std::vector<std::string> ListDir(const std::string& dirname) {
std::vector<std::string> vec;
- #ifndef _MSC_VER
- DIR *dp = opendir(dirname.c_str());
- if (dp == nullptr) {
- int errsv = errno;
- LOG(FATAL) << "ListDir " << dirname <<" error: " << strerror(errsv);
- }
- dirent *d;
- while ((d = readdir(dp)) != nullptr) {
- std::string filename = d->d_name;
- if (filename != "." && filename != "..") {
- std::string f = dirname;
- if (f[f.length() - 1] != '/') {
- f += '/';
- }
- f += d->d_name;
- vec.push_back(f);
+#ifndef _WIN32
+ DIR* dp = opendir(dirname.c_str());
+ if (dp == nullptr) {
+ int errsv = errno;
+ LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
+ }
+ dirent* d;
+ while ((d = readdir(dp)) != nullptr) {
+ std::string filename = d->d_name;
+ if (filename != "." && filename != "..") {
+ std::string f = dirname;
+ if (f[f.length() - 1] != '/') {
+ f += '/';
}
+ f += d->d_name;
+ vec.push_back(f);
}
- closedir(dp);
- #else
- WIN32_FIND_DATA fd;
- std::string pattern = dirname + "/*";
- HANDLE handle = FindFirstFile(pattern.c_str(), &fd);
- if (handle == INVALID_HANDLE_VALUE) {
- int errsv = GetLastError();
- LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
- }
- do {
- if (fd.cFileName != "." && fd.cFileName != "..") {
- std::string f = dirname;
- char clast = f[f.length() - 1];
- if (f == ".") {
- f = fd.cFileName;
- } else if (clast != '/' && clast != '\\') {
- f += '/';
- f += fd.cFileName;
- }
- vec.push_back(f);
+ }
+ closedir(dp);
+#elif defined(_WIN32)
+ WIN32_FIND_DATAA fd;
+ const std::string pattern = dirname + "/*";
+ HANDLE handle = FindFirstFileA(pattern.c_str(), &fd);
+ if (handle == INVALID_HANDLE_VALUE) {
+ const int errsv = GetLastError();
+ LOG(FATAL) << "ListDir " << dirname << " error: " << strerror(errsv);
+ }
+ do {
+ std::string filename = fd.cFileName;
+ if (filename != "." && filename != "..") {
+ std::string f = dirname;
+ if (f[f.length() - 1] != '/') {
+ f += '/';
}
- } while (FindNextFile(handle, &fd));
- FindClose(handle);
- #endif
+ f += filename;
+ vec.push_back(f);
+ }
+ } while (FindNextFileA(handle, &fd));
+ FindClose(handle);
+#else
+ LOG(FATAL) << "Operating system not supported";
+#endif
return vec;
}
+#if defined(__linux__) || defined(__ANDROID__)
/*!
* \brief LinuxShared Creates a linux shared library
* \param output The output file name
@@ -152,9 +162,9 @@ std::vector<std::string> ListDir(const std::string &dirname) {
* \param options The compiler options
* \param cc The compiler
*/
-void LinuxShared(const std::string output,
+void LinuxShared(const std::string output,
const std::vector<std::string> &files,
- std::string options = "",
+ std::string options = "",
std::string cc = "g++") {
std::string cmd = cc;
cmd += " -shared -fPIC ";
@@ -169,18 +179,48 @@ void LinuxShared(const std::string output,
LOG(FATAL) << err_msg;
}
}
+#endif
+
+#ifdef _WIN32
+/*!
+ * \brief WindowsShared Creates a Windows shared library
+ * \param output The output file name
+ * \param files The files for building
+ * \param options The compiler options
+ * \param cc The compiler
+ */
+void WindowsShared(const std::string& output,
+ const std::vector<std::string>& files,
+ const std::string& options = "",
+ const std::string& cc = "clang") {
+ std::string cmd = cc;
+ cmd += " -O2 -flto=full -fuse-ld=lld-link -Wl,/EXPORT:__tvm_main__ -shared ";
+ cmd += " -o " + output;
+ for (const auto& file : files) {
+ cmd += " " + file;
+ }
+ cmd += " " + options;
+ std::string err_msg;
+ const auto executed_status = support::Execute(cmd, &err_msg);
+ if (executed_status) {
+ LOG(FATAL) << err_msg;
+ }
+}
+#endif
/*!
* \brief CreateShared Creates a shared library
* \param output The output file name
* \param files The files for building
*/
-void CreateShared(const std::string output, const std::vector<std::string> &files) {
- #if defined(__linux__) || defined(__ANDROID__)
- LinuxShared(output, files);
- #else
- LOG(FATAL) << "Do not support creating shared library";
- #endif
+void CreateShared(const std::string& output, const std::vector<std::string>& files) {
+#if defined(__linux__) || defined(__ANDROID__)
+ LinuxShared(output, files);
+#elif defined(_WIN32)
+ WindowsShared(output, files);
+#else
+ LOG(FATAL) << "Operating system not supported";
+#endif
}
/*!
@@ -193,61 +233,52 @@ void CreateShared(const std::string output, const std::vector<std::string> &file
* \param fmt The format of file
* \return Module The loaded module
*/
-Module Load(std::string *fileIn, const std::string fmt) {
- std::string file = *fileIn;
- if (support::EndsWith(file, ".so")) {
- return Module::LoadFromFile(file, fmt);
+Module Load(std::string *fileIn, const std::string& fmt) {
+ const std::string& file = *fileIn;
+ if (support::EndsWith(file, ".so") || support::EndsWith(file, ".dll")) {
+ return Module::LoadFromFile(file, fmt);
}
- #if defined(__linux__) || defined(__ANDROID__)
- std::string file_name = file + ".so";
- if (support::EndsWith(file, ".o")) {
- std::vector<std::string> files;
- files.push_back(file);
- CreateShared(file_name, files);
- } else if (support::EndsWith(file, ".tar")) {
- std::string tmp_dir = "./rpc/tmp/";
- mkdir(&tmp_dir[0], 0777);
- std::string cmd = "tar -C " + tmp_dir + " -zxf " + file;
- std::string err_msg;
- int executed_status = support::Execute(cmd, &err_msg);
- if (executed_status) {
- LOG(FATAL) << err_msg;
- }
- CreateShared(file_name, ListDir(tmp_dir));
- CleanDir(tmp_dir);
- rmdir(&tmp_dir[0]);
- } else {
- file_name = file;
+ std::string file_name = file + ".so";
+ if (support::EndsWith(file, ".o")) {
+ std::vector<std::string> files;
+ files.push_back(file);
+ CreateShared(file_name, files);
+ } else if (support::EndsWith(file, ".tar")) {
+ const std::string tmp_dir = "./rpc/tmp/";
+ mkdir(tmp_dir.c_str(), 0777);
+
+ const std::string cmd = GenerateUntarCommand(file, tmp_dir);
+
+ std::string err_msg;
+ const int executed_status = support::Execute(cmd, &err_msg);
+ if (executed_status) {
+ LOG(FATAL) << err_msg;
}
- *fileIn = file_name;
- return Module::LoadFromFile(file_name, fmt);
- #else
- LOG(FATAL) << "Do not support creating shared library";
- #endif
+ CreateShared(file_name, ListDir(tmp_dir));
+ CleanDir(tmp_dir);
+ (void)rmdir(tmp_dir.c_str());
+ } else {
+ file_name = file;
+ }
+ *fileIn = file_name;
+ return Module::LoadFromFile(file_name, fmt);
}
/*!
* \brief CleanDir Removes the files from the directory
* \param dirname The name of the directory
*/
-void CleanDir(const std::string &dirname) {
- #if defined(__linux__) || defined(__ANDROID__)
- DIR *dp = opendir(dirname.c_str());
- dirent *d;
- while ((d = readdir(dp)) != nullptr) {
- std::string filename = d->d_name;
- if (filename != "." && filename != "..") {
- filename = dirname + "/" + d->d_name;
- int ret = std::remove(&filename[0]);
- if (ret != 0) {
- LOG(WARNING) << "Remove file " << filename << " failed";
- }
- }
+void CleanDir(const std::string& dirname) {
+ auto files = ListDir(dirname);
+ for (const auto& filename : files) {
+ std::string file_path = dirname + "/";
+ file_path += filename;
+ const int ret = std::remove(filename.c_str());
+ if (ret != 0) {
+ LOG(WARNING) << "Remove file " << filename << " failed";
}
- #else
- LOG(FATAL) << "Only support RPC in linux environment";
- #endif
+ }
}
} // namespace runtime
diff --git a/apps/cpp_rpc/rpc_env.h b/apps/cpp_rpc/rpc_env.h
index 82409ba..d046f6e 100644
--- a/apps/cpp_rpc/rpc_env.h
+++ b/apps/cpp_rpc/rpc_env.h
@@ -40,7 +40,7 @@ namespace runtime {
* \param file The format of file
* \return Module The loaded module
*/
-Module Load(std::string *path, const std::string fmt = "");
+Module Load(std::string *path, const std::string& fmt = "");
/*!
* \brief CleanDir Removes the files from the directory
@@ -62,11 +62,11 @@ struct RPCEnv {
* \param name The file name
* \return The full path of file.
*/
- std::string GetPath(std::string file_name);
+ std::string GetPath(const std::string& file_name) const;
/*!
* \brief The RPC Environment cleanup function
*/
- void CleanUp();
+ void CleanUp() const;
private:
/*!
diff --git a/apps/cpp_rpc/rpc_server.cc b/apps/cpp_rpc/rpc_server.cc
index 1a29421..ea4ab00 100644
--- a/apps/cpp_rpc/rpc_server.cc
+++ b/apps/cpp_rpc/rpc_server.cc
@@ -22,24 +22,27 @@
* \brief RPC Server implementation.
*/
#include <tvm/runtime/registry.h>
-
#if defined(__linux__) || defined(__ANDROID__)
#include <sys/select.h>
#include <sys/wait.h>
#endif
-#include <set>
-#include <iostream>
-#include <future>
-#include <thread>
#include <chrono>
+#include <future>
+#include <iostream>
+#include <set>
#include <string>
-#include "rpc_server.h"
-#include "rpc_env.h"
-#include "rpc_tracker_client.h"
+#include "../../src/support/socket.h"
#include "../../src/runtime/rpc/rpc_session.h"
#include "../../src/runtime/rpc/rpc_socket_impl.h"
-#include "../../src/support/socket.h"
+#include "rpc_env.h"
+#include "rpc_server.h"
+#include "rpc_tracker_client.h"
+#if defined(_WIN32)
+#include "win32_process.h"
+#endif
+
+using namespace std::chrono;
namespace tvm {
namespace runtime {
@@ -49,7 +52,7 @@ namespace runtime {
* \param status status value
*/
#if defined(__linux__) || defined(__ANDROID__)
-static pid_t waitPidEintr(int *status) {
+static pid_t waitPidEintr(int* status) {
pid_t pid = 0;
while ((pid = waitpid(-1, status, 0)) == -1) {
if (errno == EINTR) {
@@ -76,34 +79,32 @@ class RPCServer {
public:
/*!
* \brief Constructor.
- */
- RPCServer(const std::string &host,
- int port,
- int port_end,
- const std::string &tracker_addr,
- const std::string &key,
- const std::string &custom_addr) {
- // Init the values
- host_ = host;
- port_ = port;
- port_end_ = port_end;
- tracker_addr_ = tracker_addr;
- key_ = key;
- custom_addr_ = custom_addr;
+ */
+ RPCServer(std::string host, int port, int port_end, std::string tracker_addr,
+ std::string key, std::string custom_addr) :
+ host_(std::move(host)), port_(port), my_port_(0), port_end_(port_end),
+ tracker_addr_(std::move(tracker_addr)), key_(std::move(key)),
+ custom_addr_(std::move(custom_addr))
+ {
+
}
/*!
* \brief Destructor.
- */
+ */
~RPCServer() {
- // Free the resources
- tracker_sock_.Close();
- listen_sock_.Close();
+ try {
+ // Free the resources
+ tracker_sock_.Close();
+ listen_sock_.Close();
+ } catch(...) {
+
+ }
}
/*!
* \brief Start Creates the RPC listen process and execution.
- */
+ */
void Start() {
listen_sock_.Create();
my_port_ = listen_sock_.TryBindHost(host_, port_, port_end_);
@@ -130,102 +131,98 @@ class RPCServer {
tracker.TryConnect();
// step 2: wait for in-coming connections
AcceptConnection(&tracker, &conn, &addr, &opts);
- }
- catch (const char* msg) {
+ } catch (const char* msg) {
LOG(WARNING) << "Socket exception: " << msg;
// close tracker resource
tracker.Close();
continue;
- }
- catch (std::exception& e) {
- // Other errors
+ } catch (const std::exception& e) {
+ // close tracker resource
+ tracker.Close();
LOG(WARNING) << "Exception standard: " << e.what();
continue;
}
int timeout = GetTimeOutFromOpts(opts);
- #if defined(__linux__) || defined(__ANDROID__)
- // step 3: serving
- if (timeout != 0) {
- const pid_t timer_pid = fork();
- if (timer_pid == 0) {
- // Timer process
- sleep(timeout);
- exit(0);
- }
+#if defined(__linux__) || defined(__ANDROID__)
+ // step 3: serving
+ if (timeout != 0) {
+ const pid_t timer_pid = fork();
+ if (timer_pid == 0) {
+ // Timer process
+ sleep(timeout);
+ exit(0);
+ }
- const pid_t worker_pid = fork();
- if (worker_pid == 0) {
- // Worker process
- ServerLoopProc(conn, addr);
- exit(0);
- }
+ const pid_t worker_pid = fork();
+ if (worker_pid == 0) {
+ // Worker process
+ ServerLoopProc(conn, addr);
+ exit(0);
+ }
- int status = 0;
- const pid_t finished_first = waitPidEintr(&status);
- if (finished_first == timer_pid) {
- kill(worker_pid, SIGKILL);
- } else if (finished_first == worker_pid) {
- kill(timer_pid, SIGKILL);
- } else {
- LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
- }
+ int status = 0;
+ const pid_t finished_first = waitPidEintr(&status);
+ if (finished_first == timer_pid) {
+ kill(worker_pid, SIGKILL);
+ } else if (finished_first == worker_pid) {
+ kill(timer_pid, SIGKILL);
+ } else {
+ LOG(INFO) << "Child pid=" << finished_first << " unexpected, but still continue.";
+ }
- int status_second = 0;
- waitPidEintr(&status_second);
+ int status_second = 0;
+ waitPidEintr(&status_second);
- // Logging.
- if (finished_first == timer_pid) {
- LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
- << "), Process status = " << status_second;
- } else if (finished_first == worker_pid) {
- LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
- }
- } else {
- auto pid = fork();
- if (pid == 0) {
- ServerLoopProc(conn, addr);
- exit(0);
- }
- // Wait for the result
- int status = 0;
- wait(&status);
- LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
+ // Logging.
+ if (finished_first == timer_pid) {
+ LOG(INFO) << "Child pid=" << worker_pid << " killed (timeout = " << timeout
+ << "), Process status = " << status_second;
+ } else if (finished_first == worker_pid) {
+ LOG(INFO) << "Child pid=" << timer_pid << " killed, Process status = " << status_second;
}
- #else
- // step 3: serving
- std::future<void> proc(std::async(std::launch::async,
- &RPCServer::ServerLoopProc, this, conn, addr));
- // wait until server process finish or timeout
- if (timeout != 0) {
- // Autoterminate after timeout
- proc.wait_for(std::chrono::seconds(timeout));
- } else {
- // Wait for the result
- proc.get();
+ } else {
+ auto pid = fork();
+ if (pid == 0) {
+ ServerLoopProc(conn, addr);
+ exit(0);
}
- #endif
+ // Wait for the result
+ int status = 0;
+ wait(&status);
+ LOG(INFO) << "Child pid=" << pid << " exited, Process status =" << status;
+ }
+#elif defined(WIN32)
+ auto start_time = high_resolution_clock::now();
+ try {
+ SpawnRPCChild(conn.sockfd, seconds(timeout));
+ } catch (const std::exception&) {
+
+ }
+ auto dur = high_resolution_clock::now() - start_time;
+
+ LOG(INFO) << "Serve Time " << duration_cast<milliseconds>(dur).count() << "ms";
+#endif
// close from our side.
LOG(INFO) << "Socket Connection Closed";
conn.Close();
}
}
-
/*!
* \brief AcceptConnection Accepts the RPC Server connection.
* \param tracker Tracker details.
- * \param conn New connection information.
+ * \param conn_sock New connection information.
* \param addr New connection address information.
* \param opts Parsed options for socket
* \param ping_period Timeout for select call waiting
*/
- void AcceptConnection(TrackerClient* tracker,
+ void AcceptConnection(TrackerClient* tracker,
support::TCPSocket* conn_sock,
- support::SockAddr* addr,
- std::string* opts,
+ support::SockAddr* addr,
+ std::string* opts,
int ping_period = 2) {
- std::set <std::string> old_keyset;
+ std::set<std::string> old_keyset;
std::string matchkey;
// Report resource to tracker and get key
@@ -236,7 +233,7 @@ class RPCServer {
support::TCPSocket conn = listen_sock_.Accept(addr);
int code = kRPCMagic;
- CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
+ CHECK_EQ(conn.RecvAll(&code, sizeof(code)), sizeof(code));
if (code != kRPCMagic) {
conn.Close();
LOG(FATAL) << "Client connected is not TVM RPC server";
@@ -265,15 +262,15 @@ class RPCServer {
std::string arg0;
ssin >> arg0;
if (arg0 != expect_header) {
- code = kRPCMismatch;
- CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
- conn.Close();
- LOG(WARNING) << "Mismatch key from" << addr->AsString();
- continue;
+ code = kRPCMismatch;
+ CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
+ conn.Close();
+ LOG(WARNING) << "Mismatch key from" << addr->AsString();
+ continue;
} else {
code = kRPCSuccess;
CHECK_EQ(conn.SendAll(&code, sizeof(code)), sizeof(code));
- keylen = server_key.length();
+ keylen = int(server_key.length());
CHECK_EQ(conn.SendAll(&keylen, sizeof(keylen)), sizeof(keylen));
CHECK_EQ(conn.SendAll(server_key.c_str(), keylen), keylen);
LOG(INFO) << "Connection success " << addr->AsString();
@@ -289,25 +286,23 @@ class RPCServer {
* \param sock The socket information
* \param addr The socket address information
*/
- void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
- // Server loop
- auto env = RPCEnv();
- RPCServerLoop(sock.sockfd);
- LOG(INFO) << "Finish serving " << addr.AsString();
- env.CleanUp();
+ static void ServerLoopProc(support::TCPSocket sock, support::SockAddr addr) {
+ // Server loop
+ const auto env = RPCEnv();
+ RPCServerLoop(int(sock.sockfd));
+ LOG(INFO) << "Finish serving " << addr.AsString();
+ env.CleanUp();
}
/*!
* \brief GetTimeOutFromOpts Parse and get the timeout option.
* \param opts The option string
- * \param timeout value after parsing.
*/
- int GetTimeOutFromOpts(std::string opts) {
- std::string cmd;
- std::string option = "-timeout=";
+ int GetTimeOutFromOpts(const std::string& opts) const {
+ const std::string option = "-timeout=";
if (opts.find(option) == 0) {
- cmd = opts.substr(opts.find_last_of(option) + 1);
+ const std::string cmd = opts.substr(opts.find_last_of(option) + 1);
CHECK(support::IsNumber(cmd)) << "Timeout is not valid";
return std::stoi(cmd);
}
@@ -325,29 +320,40 @@ class RPCServer {
support::TCPSocket tracker_sock_;
};
+#if defined(WIN32)
+/*!
+* \brief ServerLoopFromChild The Server loop process.
+* \param socket The socket information
+*/
+void ServerLoopFromChild(SOCKET socket) {
+ // Server loop
+ tvm::support::TCPSocket sock(socket);
+ const auto env = RPCEnv();
+ RPCServerLoop(int(sock.sockfd));
+
+ sock.Close();
+ env.CleanUp();
+}
+#endif
+
/*!
* \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0
* \param port The port of the RPC, Default=9090
* \param port_end The end search port of the RPC, Default=9199
- * \param tracker The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
+ * \param tracker_addr The address of RPC tracker in host:port format e.g. 10.77.1.234:9190 Default=""
* \param key The key used to identify the device type in tracker. Default=""
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True
*/
-void RPCServerCreate(std::string host,
- int port,
- int port_end,
- std::string tracker_addr,
- std::string key,
- std::string custom_addr,
- bool silent) {
+void RPCServerCreate(std::string host, int port, int port_end, std::string tracker_addr,
+ std::string key, std::string custom_addr, bool silent) {
if (silent) {
// Only errors and fatal is logged
dmlc::InitLogging("--minloglevel=2");
}
// Start the rpc server
- RPCServer rpc(host, port, port_end, tracker_addr, key, custom_addr);
+ RPCServer rpc(std::move(host), port, port_end, std::move(tracker_addr), std::move(key), std::move(custom_addr));
rpc.Start();
}
diff --git a/apps/cpp_rpc/rpc_server.h b/apps/cpp_rpc/rpc_server.h
index 205182e..db7c89d 100644
--- a/apps/cpp_rpc/rpc_server.h
+++ b/apps/cpp_rpc/rpc_server.h
@@ -30,6 +30,15 @@
namespace tvm {
namespace runtime {
+#if defined(WIN32)
+/*!
+ * \brief ServerLoopFromChild The Server loop process, run from a
+ *        child process spawned on Windows.
+ * \param socket The socket information
+ */
+void ServerLoopFromChild(SOCKET socket);
+#endif
+
/*!
* \brief RPCServerCreate Creates the RPC Server.
* \param host The hostname of the server, Default=0.0.0.0
@@ -40,13 +49,13 @@ namespace runtime {
* \param custom_addr Custom IP Address to Report to RPC Tracker. Default=""
* \param silent Whether run in silent mode. Default=True
*/
-TVM_DLL void RPCServerCreate(std::string host = "",
- int port = 9090,
- int port_end = 9099,
- std::string tracker_addr = "",
- std::string key = "",
- std::string custom_addr = "",
- bool silent = true);
+void RPCServerCreate(std::string host = "",
+ int port = 9090,
+ int port_end = 9099,
+ std::string tracker_addr = "",
+ std::string key = "",
+ std::string custom_addr = "",
+ bool silent = true);
} // namespace runtime
} // namespace tvm
#endif // TVM_APPS_CPP_RPC_SERVER_H_
diff --git a/apps/cpp_rpc/win32_process.cc b/apps/cpp_rpc/win32_process.cc
new file mode 100644
index 0000000..c6c72d7
--- /dev/null
+++ b/apps/cpp_rpc/win32_process.cc
@@ -0,0 +1,273 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+#ifndef WIN32_LEAN_AND_MEAN
+#define WIN32_LEAN_AND_MEAN
+#endif
+#include <winsock2.h>
+#include <ws2tcpip.h>
+#include <cstdio>
+#include <memory>
+#include <conio.h>
+#include <string>
+#include <stdexcept>
+#include <dmlc/logging.h>
+#include "win32_process.h"
+#include "rpc_server.h"
+
+using namespace std::chrono;
+using namespace tvm::runtime;
+
+namespace {
+// The prefix path for the memory mapped file used to store IPC information
+const std::string kMemoryMapPrefix = "/MAPPED_FILE/TVM_RPC";
+// Used to construct unique names for named resources in the parent process
+const std::string kParent = "parent";
+// Used to construct unique names for named resources in the child process
+const std::string kChild = "child";
+// The timeout of the WIN32 events, in the parent and the child
+const milliseconds kEventTimeout(2000);
+
+// Used to create unique WIN32 mmap paths and event names
+int child_counter_ = 0;
+
+/*!
+ * \brief HandleDeleter Deleter for UniqueHandle smart pointer
+ * \param handle The WIN32 HANDLE to manage
+ */
+struct HandleDeleter {
+ void operator()(HANDLE handle) const {
+ if (handle != INVALID_HANDLE_VALUE && handle != nullptr) {
+ CloseHandle(handle);
+ }
+ }
+};
+
+/*!
+ * \brief UniqueHandle Smart pointer to manage a WIN32 HANDLE
+ */
+using UniqueHandle = std::unique_ptr<void, HandleDeleter>;
+
+/*!
+ * \brief MakeUniqueHandle Helper method to construct a UniqueHandle
+ * \param handle The WIN32 HANDLE to manage
+ */
+UniqueHandle MakeUniqueHandle(HANDLE handle) {
+ if (handle == INVALID_HANDLE_VALUE || handle == nullptr) {
+ return nullptr;
+ }
+
+ return UniqueHandle(handle);
+}
+
+/*!
+ * \brief GetSocket Gets the socket info from the parent process and duplicates the socket
+ * \param mmap_path The path to the memory mapped info set by the parent
+ */
+SOCKET GetSocket(const std::string& mmap_path) {
+ WSAPROTOCOL_INFO protocol_info;
+
+ const std::string parent_event_name = mmap_path + kParent;
+ const std::string child_event_name = mmap_path + kChild;
+
+ // Open the events
+ UniqueHandle parent_file_mapping_event;
+ if ((parent_file_mapping_event = MakeUniqueHandle(OpenEventA(SYNCHRONIZE, false, parent_event_name.c_str()))) == nullptr) {
+ LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
+ }
+
+ UniqueHandle child_file_mapping_event;
+ if ((child_file_mapping_event = MakeUniqueHandle(OpenEventA(EVENT_MODIFY_STATE, false, child_event_name.c_str()))) == nullptr) {
+ LOG(FATAL) << "OpenEvent() failed: " << GetLastError();
+ }
+
+ // Wait for the parent to set the event, notifying WSAPROTOCOL_INFO is ready to be read
+ if (WaitForSingleObject(parent_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
+ LOG(FATAL) << "WaitForSingleObject() failed: " << GetLastError();
+ }
+
+ const UniqueHandle file_map = MakeUniqueHandle(OpenFileMappingA(FILE_MAP_READ | FILE_MAP_WRITE,
+ false,
+ mmap_path.c_str()));
+ if (!file_map) {
+ LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
+ }
+
+ void* map_view = MapViewOfFile(file_map.get(),
+ FILE_MAP_READ | FILE_MAP_WRITE,
+ 0, 0, 0);
+
+ SOCKET sock_duplicated = INVALID_SOCKET;
+
+ if (map_view != nullptr) {
+ memcpy(&protocol_info, map_view, sizeof(WSAPROTOCOL_INFO));
+ UnmapViewOfFile(map_view);
+
+ // Creates the duplicate socket, that was created in the parent
+ sock_duplicated = WSASocket(FROM_PROTOCOL_INFO,
+ FROM_PROTOCOL_INFO,
+ FROM_PROTOCOL_INFO,
+ &protocol_info,
+ 0,
+ 0);
+
+ // Let the parent know we are finished duplicating the socket
+ SetEvent(child_file_mapping_event.get());
+ } else {
+ LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError();
+ }
+
+ return sock_duplicated;
+}
+}// Anonymous namespace
+
+namespace tvm {
+namespace runtime {
+/*!
+ * \brief SpawnRPCChild Spawns a child process to serve a client socket and waits for it to exit
+ * \param fd The client socket to duplicate in the child
+ * \param timeout The time in seconds to wait for the child to complete before termination; 0 waits indefinitely
+ */
+void SpawnRPCChild(SOCKET fd, seconds timeout) {
+ STARTUPINFOA startup_info;
+
+ memset(&startup_info, 0, sizeof(startup_info));
+ startup_info.cb = sizeof(startup_info);
+
+ std::string file_map_path = kMemoryMapPrefix + std::to_string(child_counter_++);
+
+ const std::string parent_event_name = file_map_path + kParent;
+ const std::string child_event_name = file_map_path + kChild;
+
+ // Create an event to let the child know the socket info was set to the mmap file
+ UniqueHandle parent_file_mapping_event;
+ if ((parent_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, parent_event_name.c_str()))) == nullptr) {
+ LOG(FATAL) << "CreateEvent for parent file mapping failed";
+ }
+
+ UniqueHandle child_file_mapping_event;
+ // An event to let the parent know the socket info was read from the mmap file
+ if ((child_file_mapping_event = MakeUniqueHandle(CreateEventA(nullptr, true, false, child_event_name.c_str()))) == nullptr) {
+ LOG(FATAL) << "CreateEvent for child file mapping failed";
+ }
+
+ char current_executable[MAX_PATH];
+
+ // Get the full path of the current executable
+ GetModuleFileNameA(nullptr, current_executable, MAX_PATH);
+
+ std::string child_command_line = current_executable;
+ child_command_line += " server --child_proc=";
+ child_command_line += file_map_path;
+
+ // CreateProcessA requires a non const char*, so we copy our std::string
+ std::unique_ptr<char[]> command_line_ptr(new char[child_command_line.size() + 1]);
+ strcpy(command_line_ptr.get(), child_command_line.c_str());
+
+ PROCESS_INFORMATION child_process_info;
+ if (CreateProcessA(nullptr,
+ command_line_ptr.get(),
+ nullptr,
+ nullptr,
+ false,
+ CREATE_NO_WINDOW,
+ nullptr,
+ nullptr,
+ &startup_info,
+ &child_process_info)) {
+ // Child process and thread handles must be closed, so wrapped in RAII
+ auto child_process_handle = MakeUniqueHandle(child_process_info.hProcess);
+ auto child_process_thread_handle = MakeUniqueHandle(child_process_info.hThread);
+
+ WSAPROTOCOL_INFO protocol_info;
+ // Get info needed to duplicate the socket
+ if (WSADuplicateSocket(fd,
+ child_process_info.dwProcessId,
+ &protocol_info) == SOCKET_ERROR) {
+ LOG(FATAL) << "WSADuplicateSocket(): failed. Error =" << WSAGetLastError();
+ }
+
+ // Create a mmap file to store the info needed for duplicating the SOCKET in the child proc
+ UniqueHandle file_map = MakeUniqueHandle(CreateFileMappingA(INVALID_HANDLE_VALUE,
+ nullptr,
+ PAGE_READWRITE,
+ 0,
+ sizeof(WSAPROTOCOL_INFO),
+ file_map_path.c_str()));
+ if (!file_map) {
+ LOG(INFO) << "CreateFileMapping() failed: " << GetLastError();
+ }
+
+ if (GetLastError() == ERROR_ALREADY_EXISTS) {
+ LOG(FATAL) << "CreateFileMapping(): mapping file already exists";
+ } else {
+ void* map_view = MapViewOfFile(file_map.get(), FILE_MAP_READ | FILE_MAP_WRITE, 0, 0, 0);
+
+ if (map_view != nullptr) {
+ memcpy(map_view, &protocol_info, sizeof(WSAPROTOCOL_INFO));
+ UnmapViewOfFile(map_view);
+
+ // Let child proc know the mmap file is ready to be read
+ SetEvent(parent_file_mapping_event.get());
+
+ // Wait for the child to finish reading mmap file
+ if (WaitForSingleObject(child_file_mapping_event.get(), uint32_t(kEventTimeout.count())) != WAIT_OBJECT_0) {
+ TerminateProcess(child_process_handle.get(), 0);
+ LOG(FATAL) << "WaitForSingleObject for child file mapping timed out. Terminating child process.";
+ }
+ } else {
+ TerminateProcess(child_process_handle.get(), 0);
+ LOG(FATAL) << "MapViewOfFile() failed: " << GetLastError();
+ }
+ }
+
+ const DWORD process_timeout = timeout.count()
+ ? uint32_t(duration_cast<milliseconds>(timeout).count())
+ : INFINITE;
+
+ // Wait for child process to exit, or hit configured timeout
+ if (WaitForSingleObject(child_process_handle.get(), process_timeout) != WAIT_OBJECT_0) {
+ LOG(INFO) << "Child process timeout. Terminating.";
+ TerminateProcess(child_process_handle.get(), 0);
+ }
+ } else {
+ LOG(INFO) << "Create child process failed: " << GetLastError();
+ }
+}
+/*!
+ * \brief ChildProcSocketHandler Runs in the child process and starts the server loop to handle the client socket
+ * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
+ */
+void ChildProcSocketHandler(const std::string& mmap_path) {
+ SOCKET socket;
+
+ // Set high thread priority to keep the thread scheduler from
+ // interfering with any measurements in the RPC server.
+ SetThreadPriority(GetCurrentThread(), THREAD_PRIORITY_TIME_CRITICAL);
+
+ if ((socket = GetSocket(mmap_path)) != INVALID_SOCKET) {
+ tvm::runtime::ServerLoopFromChild(socket);
+ }
+ else {
+ LOG(FATAL) << "GetSocket() failed";
+ }
+
+}
+} // namespace runtime
+} // namespace tvm
\ No newline at end of file
diff --git a/apps/cpp_rpc/win32_process.h b/apps/cpp_rpc/win32_process.h
new file mode 100644
index 0000000..7d1a276
--- /dev/null
+++ b/apps/cpp_rpc/win32_process.h
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+ /*!
+ * \file win32_process.h
+ * \brief Win32 process code to mimic a POSIX fork()
+ */
+#ifndef TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
+#define TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
+#include <chrono>
+#include <string>
+namespace tvm {
+namespace runtime {
+/*!
+ * \brief SpawnRPCChild Spawns a child process to serve a client socket and waits for it to exit
+ * \param fd The client socket to duplicate in the child
+ * \param timeout The time in seconds to wait for the child to complete before termination; 0 waits indefinitely
+ */
+void SpawnRPCChild(SOCKET fd, std::chrono::seconds timeout);
+/*!
+ * \brief ChildProcSocketHandler Runs in the child process and starts the server loop to handle the client socket
+ * \param mmap_path The memory mapped file path that will contain the information to duplicate the client socket from the parent
+ */
+void ChildProcSocketHandler(const std::string& mmap_path);
+} // namespace runtime
+} // namespace tvm
+#endif // TVM_APPS_CPP_RPC_WIN32_PROCESS_H_
\ No newline at end of file
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc
index 6b4e341..642fbb8 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -34,8 +34,12 @@ class SockChannel final : public RPCChannel {
explicit SockChannel(support::TCPSocket sock)
: sock_(sock) {}
~SockChannel() {
- if (!sock_.BadSocket()) {
- sock_.Close();
+ try {
+ // BadSocket can throw
+ if (!sock_.BadSocket()) {
+ sock_.Close();
+ }
+ } catch (...) {
}
}
size_t Send(const void* data, size_t size) final {
@@ -100,7 +104,8 @@ Module RPCClientConnect(std::string url, int port, std::string key) {
return CreateRPCModule(RPCConnect(url, port, "client:" + key));
}
-void RPCServerLoop(int sockfd) {
+// TVM_DLL needed for MSVC
+TVM_DLL void RPCServerLoop(int sockfd) {
support::TCPSocket sock(
static_cast<support::TCPSocket::SockType>(sockfd));
RPCSession::Create(
diff --git a/src/support/ring_buffer.h b/src/support/ring_buffer.h
index 7700a96..e6e3b04 100644
--- a/src/support/ring_buffer.h
+++ b/src/support/ring_buffer.h
@@ -63,7 +63,7 @@ class RingBuffer {
size_t ncopy = head_ptr_ + bytes_available_ - old_size;
memcpy(&ring_[0] + old_size, &ring_[0], ncopy);
}
- } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity) {
+ } else if (ring_.size() > n * 8 && ring_.size() > kInitCapacity && bytes_available_ > 0) {
// shrink too large temporary buffer to avoid out of memory on some embedded devices
size_t old_bytes = bytes_available_;