You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by ma...@apache.org on 2022/05/06 11:14:28 UTC
[tvm] branch main updated: Revert "Implemented rpc logging (#10967)" (#11227)
This is an automated email from the ASF dual-hosted git repository.
manupa pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new ff7efe767a Revert "Implemented rpc logging (#10967)" (#11227)
ff7efe767a is described below
commit ff7efe767a25ca140dc70ed723c37e6a1cd11c77
Author: Leandro Nunes <le...@arm.com>
AuthorDate: Fri May 6 12:14:20 2022 +0100
Revert "Implemented rpc logging (#10967)" (#11227)
This reverts commit aa3bcd9d3374878c5e958b842f51bfd82f0ebd9e, because it
fails on Windows CI as reported in issue #11220. PR #11223 tries to address
it, but it is failing in the regular CI with a testing issue on Hexagon.
---
CMakeLists.txt | 1 -
python/tvm/micro/session.py | 1 -
python/tvm/rpc/client.py | 13 +-
src/runtime/crt/microtvm_rpc_server/rpc_server.cc | 2 +
src/runtime/micro/micro_session.cc | 8 -
src/runtime/minrpc/minrpc_interfaces.h | 93 ----
src/runtime/minrpc/minrpc_logger.cc | 291 ----------
src/runtime/minrpc/minrpc_logger.h | 296 ----------
src/runtime/minrpc/minrpc_server.h | 649 +++++++++-------------
src/runtime/minrpc/minrpc_server_logging.h | 166 ------
src/runtime/rpc/rpc_channel_logger.h | 183 ------
src/runtime/rpc/rpc_endpoint.h | 2 -
src/runtime/rpc/rpc_socket_impl.cc | 21 +-
tests/python/unittest/test_runtime_rpc.py | 23 +-
14 files changed, 275 insertions(+), 1474 deletions(-)
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7023caf97e..90cc0f9518 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -318,7 +318,6 @@ list(APPEND COMPILER_SRCS "src/target/datatype/myfloat/myfloat.cc")
tvm_file_glob(GLOB RUNTIME_SRCS
src/runtime/*.cc
src/runtime/vm/*.cc
- src/runtime/minrpc/*.cc
)
if(BUILD_FOR_HEXAGON)
diff --git a/python/tvm/micro/session.py b/python/tvm/micro/session.py
index 4c38476207..4f754d9d44 100644
--- a/python/tvm/micro/session.py
+++ b/python/tvm/micro/session.py
@@ -133,7 +133,6 @@ class Session:
int(timeouts.session_start_timeout_sec * 1e6),
int(timeouts.session_established_timeout_sec * 1e6),
self._cleanup,
- False,
)
)
self.device = self._rpc.cpu(0)
diff --git a/python/tvm/rpc/client.py b/python/tvm/rpc/client.py
index eddc324b33..4e6c902538 100644
--- a/python/tvm/rpc/client.py
+++ b/python/tvm/rpc/client.py
@@ -459,9 +459,7 @@ class TrackerSession(object):
)
-def connect(
- url, port, key="", session_timeout=0, session_constructor_args=None, enable_logging=False
-):
+def connect(url, port, key="", session_timeout=0, session_constructor_args=None):
"""Connect to RPC Server
Parameters
@@ -485,9 +483,6 @@ def connect(
The first element of the list is always a string specifying the name of
the session constructor, the following args are the positional args to that function.
- enable_logging: boolean
- flag to enable/disable logging. Logging is disabled by default.
-
Returns
-------
sess : RPCSession
@@ -508,9 +503,9 @@ def connect(
.. code-block:: python
client_via_proxy = rpc.connect(
- proxy_server_url, proxy_server_port, proxy_server_key, enable_logging
+ proxy_server_url, proxy_server_port, proxy_server_key,
session_constructor_args=[
- "rpc.Connect", internal_url, internal_port, internal_key, internal_logging])
+ "rpc.Connect", internal_url, internal_port, internal_key])
"""
try:
@@ -519,7 +514,7 @@ def connect(
session_constructor_args = session_constructor_args if session_constructor_args else []
if not isinstance(session_constructor_args, (list, tuple)):
raise TypeError("Expect the session constructor to be a list or tuple")
- sess = _ffi_api.Connect(url, port, key, enable_logging, *session_constructor_args)
+ sess = _ffi_api.Connect(url, port, key, *session_constructor_args)
except NameError:
raise RuntimeError("Please compile with USE_RPC=1")
return RPCSession(sess)
diff --git a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
index b7bae243ec..ac10c82b58 100644
--- a/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
+++ b/src/runtime/crt/microtvm_rpc_server/rpc_server.cc
@@ -193,6 +193,8 @@ class MicroRPCServer {
} // namespace runtime
} // namespace tvm
+void* operator new[](size_t count, void* ptr) noexcept { return ptr; }
+
extern "C" {
static microtvm_rpc_server_t g_rpc_server = nullptr;
diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc
index 6911c2021a..9e6664ff59 100644
--- a/src/runtime/micro/micro_session.cc
+++ b/src/runtime/micro/micro_session.cc
@@ -38,7 +38,6 @@
#include "../../support/str_escape.h"
#include "../rpc/rpc_channel.h"
-#include "../rpc/rpc_channel_logger.h"
#include "../rpc/rpc_endpoint.h"
#include "../rpc/rpc_session.h"
#include "crt_config.h"
@@ -405,13 +404,6 @@ TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue*
throw std::runtime_error(ss.str());
}
std::unique_ptr<RPCChannel> channel(micro_channel);
- bool enable_logging = false;
- if (args.num_args > 7) {
- enable_logging = args[7];
- }
- if (enable_logging) {
- channel.reset(new RPCChannelLogging(std::move(channel)));
- }
auto ep = RPCEndpoint::Create(std::move(channel), args[0], "", args[6]);
auto sess = CreateClientSession(ep);
*rv = CreateRPCSessionModule(sess);
diff --git a/src/runtime/minrpc/minrpc_interfaces.h b/src/runtime/minrpc/minrpc_interfaces.h
deleted file mode 100644
index a45dee9f2c..0000000000
--- a/src/runtime/minrpc/minrpc_interfaces.h
+++ /dev/null
@@ -1,93 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief Return interface used in ExecInterface to generate and send the responses.
- */
-class MinRPCReturnInterface {
- public:
- virtual ~MinRPCReturnInterface() {}
- /*! * \brief sends a response to the client with kTVMNullptr in payload. */
- virtual void ReturnVoid() = 0;
-
- /*! * \brief sends a response to the client with one kTVMOpaqueHandle in payload. */
- virtual void ReturnHandle(void* handle) = 0;
-
- /*! * \brief sends an exception response to the client with a kTVMStr in payload. */
- virtual void ReturnException(const char* msg) = 0;
-
- /*! * \brief sends a packed argument sequnce to the client. */
- virtual void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) = 0;
-
- /*! * \brief sends a copy of the requested remote data to the client. */
- virtual void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) = 0;
-
- /*! * \brief sends an exception response to the client with the last TVM erros as the message. */
- virtual void ReturnLastTVMError() = 0;
-
- /*! * \brief internal error. */
- virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0;
-};
-
-/*!
- * \brief Execute interface used in MinRPCServer to process different received commands
- */
-class MinRPCExecInterface {
- public:
- virtual ~MinRPCExecInterface() {}
-
- /*! * \brief Execute an Initilize server command. */
- virtual void InitServer(int num_args) = 0;
-
- /*! * \brief calls a function specified by the call_handle. */
- virtual void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes,
- int num_args) = 0;
-
- /*! * \brief Execute a copy from remote command by sending the data described in arr to the client
- */
- virtual void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0;
-
- /*! * \brief Execute a copy to remote command by receiving the data described in arr from the
- * client */
- virtual int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) = 0;
-
- /*! * \brief calls a system function specified by the code. */
- virtual void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) = 0;
-
- /*! * \brief internal error. */
- virtual void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) = 0;
-
- /*! * \brief return the ReturnInterface pointer that is used to generate and send the responses.
- */
- virtual MinRPCReturnInterface* GetReturnInterface() = 0;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MINRPC_MINRPC_INTERFACES_H_
diff --git a/src/runtime/minrpc/minrpc_logger.cc b/src/runtime/minrpc/minrpc_logger.cc
deleted file mode 100644
index 4f3b7e764c..0000000000
--- a/src/runtime/minrpc/minrpc_logger.cc
+++ /dev/null
@@ -1,291 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#include "minrpc_logger.h"
-
-#include <string.h>
-#include <time.h>
-#include <tvm/runtime/c_runtime_api.h>
-#include <tvm/runtime/logging.h>
-
-#include <functional>
-#include <iostream>
-#include <sstream>
-#include <unordered_map>
-
-#include "minrpc_interfaces.h"
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-void Logger::LogTVMValue(int tcode, TVMValue value) {
- switch (tcode) {
- case kDLInt: {
- LogValue<int64_t>("(int64)", value.v_int64);
- break;
- }
- case kDLUInt: {
- LogValue<uint64_t>("(uint64)", value.v_int64);
- break;
- }
- case kDLFloat: {
- LogValue<float>("(float)", value.v_float64);
- break;
- }
- case kTVMDataType: {
- LogDLData("DLDataType(code,bits,lane)", &value.v_type);
- break;
- }
- case kDLDevice: {
- LogDLDevice("DLDevice(type,id)", &value.v_device);
- break;
- }
- case kTVMPackedFuncHandle: {
- LogValue<void*>("(PackedFuncHandle)", value.v_handle);
- break;
- }
- case kTVMModuleHandle: {
- LogValue<void*>("(ModuleHandle)", value.v_handle);
- break;
- }
- case kTVMOpaqueHandle: {
- LogValue<void*>("(OpaqueHandle)", value.v_handle);
- break;
- }
- case kTVMDLTensorHandle: {
- LogValue<void*>("(TensorHandle)", value.v_handle);
- break;
- }
- case kTVMNDArrayHandle: {
- LogValue<void*>("kTVMNDArrayHandle", value.v_handle);
- break;
- }
- case kTVMNullptr: {
- Log("Nullptr");
- break;
- }
- case kTVMStr: {
- Log("\"");
- Log(value.v_str);
- Log("\"");
- break;
- }
- case kTVMBytes: {
- TVMByteArray* bytes = static_cast<TVMByteArray*>(value.v_handle);
- int len = bytes->size;
- LogValue<int64_t>("(Bytes) [size]: ", len);
- if (PRINT_BYTES) {
- Log(", [Values]:");
- Log(" { ");
- if (len > 0) {
- LogValue<uint64_t>("", (uint8_t)bytes->data[0]);
- }
- for (int j = 1; j < len; j++) LogValue<uint64_t>(" - ", (uint8_t)bytes->data[j]);
- Log(" } ");
- }
- break;
- }
- default: {
- Log("ERROR-kUnknownTypeCode)");
- break;
- }
- }
- Log("; ");
-}
-
-void Logger::OutputLog() {
- LOG(INFO) << os_.str();
- os_.str(std::string());
-}
-
-void MinRPCReturnsWithLog::ReturnVoid() {
- next_->ReturnVoid();
- logger_->Log("-> ReturnVoid");
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnHandle(void* handle) {
- next_->ReturnHandle(handle);
- if (code_ == RPCCode::kGetGlobalFunc) {
- RegisterHandleName(handle);
- }
- logger_->LogValue<void*>("-> ReturnHandle: ", handle);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnException(const char* msg) {
- next_->ReturnException(msg);
- logger_->Log("-> Exception: ");
- logger_->Log(msg);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes,
- int num_args) {
- next_->ReturnPackedSeq(arg_values, type_codes, num_args);
- ProcessValues(arg_values, type_codes, num_args);
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {
- next_->ReturnCopyFromRemote(data_ptr, num_bytes);
- logger_->LogValue<uint64_t>("-> CopyFromRemote: ", num_bytes);
- logger_->LogValue<void*>(", ", static_cast<void*>(data_ptr));
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ReturnLastTVMError() {
- const char* err = TVMGetLastError();
- ReturnException(err);
-}
-
-void MinRPCReturnsWithLog::ThrowError(RPCServerStatus code, RPCCode info) {
- next_->ThrowError(code, info);
- logger_->Log("-> ERROR: ");
- logger_->Log(RPCServerStatusToString(code));
- logger_->OutputLog();
-}
-
-void MinRPCReturnsWithLog::ProcessValues(const TVMValue* values, const int* tcodes, int num_args) {
- if (tcodes != nullptr) {
- logger_->Log("-> [");
- for (int i = 0; i < num_args; ++i) {
- logger_->LogTVMValue(tcodes[i], values[i]);
-
- if (tcodes[i] == kTVMOpaqueHandle) {
- RegisterHandleName(values[i].v_handle);
- }
- }
- logger_->Log("]");
- }
-}
-
-void MinRPCReturnsWithLog::ResetHandleName(RPCCode code) {
- code_ = code;
- handle_name_.clear();
-}
-
-void MinRPCReturnsWithLog::UpdateHandleName(const char* name) {
- if (handle_name_.length() != 0) {
- handle_name_.append("::");
- }
- handle_name_.append(name);
-}
-
-void MinRPCReturnsWithLog::GetHandleName(void* handle) {
- if (handle_descriptions_.find(handle) != handle_descriptions_.end()) {
- handle_name_.append(handle_descriptions_[handle]);
- logger_->LogHandleName(handle_name_);
- }
-}
-
-void MinRPCReturnsWithLog::ReleaseHandleName(void* handle) {
- if (handle_descriptions_.find(handle) != handle_descriptions_.end()) {
- logger_->LogHandleName(handle_descriptions_[handle]);
- handle_descriptions_.erase(handle);
- }
-}
-
-void MinRPCReturnsWithLog::RegisterHandleName(void* handle) {
- handle_descriptions_[handle] = handle_name_;
-}
-
-void MinRPCExecuteWithLog::InitServer(int num_args) {
- SetRPCCode(RPCCode::kInitServer);
- logger_->Log("Init Server");
- next_->InitServer(num_args);
-}
-
-void MinRPCExecuteWithLog::NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes,
- int num_args) {
- SetRPCCode(RPCCode::kCallFunc);
- logger_->LogValue<void*>("call_handle: ", reinterpret_cast<void*>(call_handle));
- ret_handler_->GetHandleName(reinterpret_cast<void*>(call_handle));
- if (num_args > 0) {
- logger_->Log(", ");
- }
- ProcessValues(values, tcodes, num_args);
- next_->NormalCallFunc(call_handle, values, tcodes, num_args);
-}
-
-void MinRPCExecuteWithLog::CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) {
- SetRPCCode(RPCCode::kCopyFromRemote);
- logger_->LogValue<void*>("data_handle: ", static_cast<void*>(arr->data));
- logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device));
- logger_->LogValue<int64_t>(", ndim: ", arr->ndim);
- logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype));
- logger_->LogValue<uint64_t>(", num_bytes:", num_bytes);
- next_->CopyFromRemote(arr, num_bytes, temp_data);
-}
-
-int MinRPCExecuteWithLog::CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
- SetRPCCode(RPCCode::kCopyToRemote);
- logger_->LogValue<void*>("data_handle: ", static_cast<void*>(arr->data));
- logger_->LogDLDevice(", DLDevice(type,id):", &(arr->device));
- logger_->LogValue<int64_t>(", ndim: ", arr->ndim);
- logger_->LogDLData(", DLDataType(code,bits,lane): ", &(arr->dtype));
- logger_->LogValue<uint64_t>(", byte_offset: ", arr->byte_offset);
- return next_->CopyToRemote(arr, num_bytes, data_ptr);
-}
-
-void MinRPCExecuteWithLog::SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {
- SetRPCCode(code);
- if ((code) == RPCCode::kFreeHandle) {
- if ((num_args == 2) && (tcodes[0] == kTVMOpaqueHandle) && (tcodes[1] == kDLInt)) {
- logger_->LogValue<void*>("handle: ", static_cast<void*>(values[0].v_handle));
- if (values[1].v_int64 == kTVMModuleHandle || values[1].v_int64 == kTVMPackedFuncHandle) {
- ret_handler_->ReleaseHandleName(static_cast<void*>(values[0].v_handle));
- }
- }
- } else {
- ProcessValues(values, tcodes, num_args);
- }
- next_->SysCallFunc(code, values, tcodes, num_args);
-}
-
-void MinRPCExecuteWithLog::ThrowError(RPCServerStatus code, RPCCode info) {
- logger_->Log("-> Error\n");
- next_->ThrowError(code, info);
-}
-
-void MinRPCExecuteWithLog::ProcessValues(TVMValue* values, int* tcodes, int num_args) {
- if (tcodes != nullptr) {
- logger_->Log("[");
- for (int i = 0; i < num_args; ++i) {
- logger_->LogTVMValue(tcodes[i], values[i]);
-
- if (tcodes[i] == kTVMStr) {
- if (strlen(values[i].v_str) > 0) {
- ret_handler_->UpdateHandleName(values[i].v_str);
- }
- }
- }
- logger_->Log("]");
- }
-}
-
-void MinRPCExecuteWithLog::SetRPCCode(RPCCode code) {
- logger_->Log(RPCCodeToString(code));
- logger_->Log(", ");
- ret_handler_->ResetHandleName(code);
-}
-
-} // namespace runtime
-} // namespace tvm
diff --git a/src/runtime/minrpc/minrpc_logger.h b/src/runtime/minrpc/minrpc_logger.h
deleted file mode 100644
index 13d44c3cba..0000000000
--- a/src/runtime/minrpc/minrpc_logger.h
+++ /dev/null
@@ -1,296 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_
-
-#include <tvm/runtime/c_runtime_api.h>
-
-#include <functional>
-#include <sstream>
-#include <string>
-#include <unordered_map>
-
-#include "minrpc_interfaces.h"
-#include "rpc_reference.h"
-
-namespace tvm {
-namespace runtime {
-
-#define PRINT_BYTES false
-
-/*!
- * \brief Generates a user readeable log on the console
- */
-class Logger {
- public:
- Logger() {}
-
- /*!
- * \brief this function logs a string
- *
- * \param s the string to be logged.
- */
- void Log(const char* s) { os_ << s; }
- void Log(std::string s) { os_ << s; }
-
- /*!
- * \brief this function logs a numerical value
- *
- * \param desc adds any necessary description before the value.
- * \param val is the value to be logged.
- */
- template <typename T>
- void LogValue(const char* desc, T val) {
- os_ << desc << val;
- }
-
- /*!
- * \brief this function logs the properties of a DLDevice
- *
- * \param desc adds any necessary description before the DLDevice.
- * \param dev is the pointer to the DLDevice to be logged.
- */
- void LogDLDevice(const char* desc, DLDevice* dev) {
- os_ << desc << "(" << dev->device_type << "," << dev->device_id << ")";
- }
-
- /*!
- * \brief this function logs the properties of a DLDataType
- *
- * \param desc adds any necessary description before the DLDataType.
- * \param data is the pointer to the DLDataType to be logged.
- */
- void LogDLData(const char* desc, DLDataType* data) {
- os_ << desc << "(" << (uint16_t)data->code << "," << (uint16_t)data->bits << "," << data->lanes
- << ")";
- }
-
- /*!
- * \brief this function logs a handle name.
- *
- * \param name is the name to be logged.
- */
- void LogHandleName(std::string name) {
- if (name.length() > 0) {
- os_ << " <" << name.c_str() << ">";
- }
- }
-
- /*!
- * \brief this function logs a TVMValue based on its type.
- *
- * \param tcode the type_code of the value stored in TVMValue.
- * \param value is the TVMValue to be logged.
- */
- void LogTVMValue(int tcode, TVMValue value);
-
- /*!
- * \brief this function output the log to the console.
- */
- void OutputLog();
-
- private:
- std::stringstream os_;
-};
-
-/*!
- * \brief A wrapper for a MinRPCReturns object, that also logs the responses.
- *
- * \param next underlying MinRPCReturns that generates the responses.
- */
-class MinRPCReturnsWithLog : public MinRPCReturnInterface {
- public:
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- MinRPCReturnsWithLog(MinRPCReturnInterface* next, Logger* logger)
- : next_(next), logger_(logger) {}
-
- ~MinRPCReturnsWithLog() {}
-
- void ReturnVoid();
-
- void ReturnHandle(void* handle);
-
- void ReturnException(const char* msg);
-
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args);
-
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes);
-
- void ReturnLastTVMError();
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone);
-
- /*!
- * \brief this function logs a list of TVMValues, and registers handle_name when needed.
- *
- * \param values is the list of TVMValues.
- * \param tcodes is the list type_code of the TVMValues.
- * \param num_args is the number of items in the list.
- */
- void ProcessValues(const TVMValue* values, const int* tcodes, int num_args);
-
- /*!
- * \brief this function is called when a new command is executed.
- * It clears the handle_name_ and records the command code.
- *
- * \param code the RPC command code.
- */
- void ResetHandleName(RPCCode code);
-
- /*!
- * \brief appends name to the handle_name_.
- *
- * \param name handle name.
- */
- void UpdateHandleName(const char* name);
-
- /*!
- * \brief get the stored handle description.
- *
- * \param handle the handle to get the description for.
- */
- void GetHandleName(void* handle);
-
- /*!
- * \brief remove the handle description from handle_descriptions_.
- *
- * \param handle the handle to remove the description for.
- */
- void ReleaseHandleName(void* handle);
-
- private:
- /*!
- * \brief add the handle description to handle_descriptions_.
- *
- * \param handle the handle to add the description for.
- */
- void RegisterHandleName(void* handle);
-
- MinRPCReturnInterface* next_;
- std::string handle_name_;
- std::unordered_map<void*, std::string> handle_descriptions_;
- RPCCode code_;
- Logger* logger_;
-};
-
-/*!
- * \brief A wrapper for a MinRPCExecute object, that also logs the responses.
- *
- * \param next: underlying MinRPCExecute that processes the packets.
- */
-class MinRPCExecuteWithLog : public MinRPCExecInterface {
- public:
- MinRPCExecuteWithLog(MinRPCExecInterface* next, Logger* logger) : next_(next), logger_(logger) {
- ret_handler_ = reinterpret_cast<MinRPCReturnsWithLog*>(next_->GetReturnInterface());
- }
-
- ~MinRPCExecuteWithLog() {}
-
- void InitServer(int num_args);
-
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args);
-
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data);
-
- int CopyToRemote(DLTensor* arr, uint64_t _num_bytes, uint8_t* _data_ptr);
-
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args);
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone);
-
- MinRPCReturnInterface* GetReturnInterface() { return next_->GetReturnInterface(); }
-
- private:
- /*!
- * \brief this function logs a list of TVMValues, and updates handle_name when needed.
- *
- * \param values is the list of TVMValues.
- * \param tcodes is the list type_code of the TVMValues.
- * \param num_args is the number of items in the list.
- */
- void ProcessValues(TVMValue* values, int* tcodes, int num_args);
-
- /*!
- * \brief this function is called when a new command is executed.
- *
- * \param code the RPC command code.
- */
- void SetRPCCode(RPCCode code);
-
- MinRPCExecInterface* next_;
- MinRPCReturnsWithLog* ret_handler_;
- Logger* logger_;
-};
-
-/*!
- * \brief A No-operation MinRPCReturns used within the MinRPCSniffer
- *
- * \tparam TIOHandler* IO provider to provide io handling.
- */
-template <typename TIOHandler>
-class MinRPCReturnsNoOp : public MinRPCReturnInterface {
- public:
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- explicit MinRPCReturnsNoOp(TIOHandler* io) : io_(io) {}
- ~MinRPCReturnsNoOp() {}
- void ReturnVoid() {}
- void ReturnHandle(void* handle) {}
- void ReturnException(const char* msg) {}
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {}
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {}
- void ReturnLastTVMError() {}
- void ThrowError(RPCServerStatus code, RPCCode info) {}
-
- private:
- TIOHandler* io_;
-};
-
-/*!
- * \brief A No-operation MinRPCExecute used within the MinRPCSniffer
- *
- * \tparam ReturnInterface* ReturnInterface pointer to generate and send the responses.
-
- */
-class MinRPCExecuteNoOp : public MinRPCExecInterface {
- public:
- explicit MinRPCExecuteNoOp(MinRPCReturnInterface* ret_handler) : ret_handler_(ret_handler) {}
- ~MinRPCExecuteNoOp() {}
- void InitServer(int _num_args) {}
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) {}
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* temp_data) {}
- int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) { return 1; }
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {}
- void ThrowError(RPCServerStatus code, RPCCode info) {}
- MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; }
-
- private:
- MinRPCReturnInterface* ret_handler_;
-};
-
-} // namespace runtime
-} // namespace tvm
-
-#endif // TVM_RUNTIME_MINRPC_MINRPC_LOGGER_H_"
diff --git a/src/runtime/minrpc/minrpc_server.h b/src/runtime/minrpc/minrpc_server.h
index 4684aa0e16..92cb2e819f 100644
--- a/src/runtime/minrpc/minrpc_server.h
+++ b/src/runtime/minrpc/minrpc_server.h
@@ -28,25 +28,27 @@
#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_H_
-#ifndef DMLC_LITTLE_ENDIAN
#define DMLC_LITTLE_ENDIAN 1
-#endif
-
#include <string.h>
#include <tvm/runtime/c_runtime_api.h>
-#include <memory>
-#include <utility>
-
#include "../../support/generic_arena.h"
-#include "minrpc_interfaces.h"
#include "rpc_reference.h"
+/*! \brief Whether or not to enable glog style DLOG */
+#ifndef TVM_MINRPC_ENABLE_LOGGING
+#define TVM_MINRPC_ENABLE_LOGGING 0
+#endif
+
#ifndef MINRPC_CHECK
#define MINRPC_CHECK(cond) \
if (!(cond)) this->ThrowError(RPCServerStatus::kCheckError);
#endif
+#if TVM_MINRPC_ENABLE_LOGGING
+#include <tvm/runtime/logging.h>
+#endif
+
namespace tvm {
namespace runtime {
@@ -56,133 +58,95 @@ class PageAllocator;
}
/*!
- * \brief Responses to a minimum RPC command.
+ * \brief A minimum RPC server that only depends on the tvm C runtime..
+ *
+ * All the dependencies are provided by the io arguments.
*
* \tparam TIOHandler IO provider to provide io handling.
+ * An IOHandler needs to provide the following functions:
+ * - PosixWrite, PosixRead, Close: posix style, read, write, close API.
+ * - MessageStart(num_bytes), MessageDone(): framing APIs.
+ * - Exit: exit with status code.
*/
-template <typename TIOHandler>
-class MinRPCReturns : public MinRPCReturnInterface {
+template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator>
+class MinRPCServer {
public:
+ using PageAllocator = Allocator<TIOHandler>;
+
/*!
* \brief Constructor.
* \param io The IO handler.
*/
- explicit MinRPCReturns(TIOHandler* io) : io_(io) {}
-
- void ReturnVoid() {
- int32_t num_args = 1;
- int32_t tcode = kTVMNullptr;
- RPCCode code = RPCCode::kReturn;
-
- uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- Write(num_args);
- Write(tcode);
- io_->MessageDone();
- }
+ explicit MinRPCServer(TIOHandler* io) : io_(io), arena_(PageAllocator(io)) {}
- void ReturnHandle(void* handle) {
- int32_t num_args = 1;
- int32_t tcode = kTVMOpaqueHandle;
- RPCCode code = RPCCode::kReturn;
- uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
- uint64_t packet_nbytes =
- sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- Write(num_args);
- Write(tcode);
- Write(encode_handle);
- io_->MessageDone();
- }
-
- void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
-
- void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {
- RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
- }
-
- void ReturnCopyFromRemote(uint8_t* data_ptr, uint64_t num_bytes) {
- RPCCode code = RPCCode::kCopyAck;
- uint64_t packet_nbytes = sizeof(code) + num_bytes;
-
- io_->MessageStart(packet_nbytes);
- Write(packet_nbytes);
- Write(code);
- WriteArray(data_ptr, num_bytes);
- io_->MessageDone();
- }
-
- void ReturnLastTVMError() {
- const char* err = TVMGetLastError();
- ReturnException(err);
- }
-
- void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); }
-
- void MessageDone() { io_->MessageDone(); }
+ /*! \brief Process a single request.
+ *
+ * \return true when the server should continue processing requests. false when it should be
+ * shutdown.
+ */
+ bool ProcessOnePacket() {
+ RPCCode code;
+ uint64_t packet_len;
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- io_->Exit(static_cast<int>(code));
- }
+ arena_.RecycleAll();
+ allow_clean_shutdown_ = true;
- template <typename T>
- void Write(const T& data) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return WriteRawBytes(&data, sizeof(T));
- }
+ this->Read(&packet_len);
+ if (packet_len == 0) return true;
+ this->Read(&code);
- template <typename T>
- void WriteArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return WriteRawBytes(data, sizeof(T) * count);
- }
+ allow_clean_shutdown_ = false;
- private:
- void WriteRawBytes(const void* data, size_t size) {
- const uint8_t* buf = static_cast<const uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixWrite(buf, size - ndone);
- if (ret <= 0) {
- this->ThrowError(RPCServerStatus::kWriteError);
+ if (code >= RPCCode::kSyscallCodeStart) {
+ this->HandleSyscallFunc(code);
+ } else {
+ switch (code) {
+ case RPCCode::kCallFunc: {
+ HandleNormalCallFunc();
+ break;
+ }
+ case RPCCode::kInitServer: {
+ HandleInitServer();
+ break;
+ }
+ case RPCCode::kCopyFromRemote: {
+ HandleCopyFromRemote();
+ break;
+ }
+ case RPCCode::kCopyToRemote: {
+ HandleCopyToRemote();
+ break;
+ }
+ case RPCCode::kShutdown: {
+ this->Shutdown();
+ return false;
+ }
+ default: {
+ this->ThrowError(RPCServerStatus::kUnknownRPCCode);
+ break;
+ }
}
- buf += ret;
- ndone += ret;
}
- }
- TIOHandler* io_;
-};
-
-/*!
- * \brief Executing a minimum RPC command.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- * \tparam MinRPCReturnInterface* handles response generatation and transmission.
- */
-template <typename TIOHandler>
-class MinRPCExecute : public MinRPCExecInterface {
- public:
- MinRPCExecute(TIOHandler* io, MinRPCReturnInterface* ret_handler)
- : io_(io), ret_handler_(ret_handler) {}
+ return true;
+ }
- void InitServer(int num_args) {
- MINRPC_CHECK(num_args == 0);
- ret_handler_->ReturnVoid();
+ void Shutdown() {
+ arena_.FreeAll();
+ io_->Close();
}
- void NormalCallFunc(uint64_t call_handle, TVMValue* values, int* tcodes, int num_args) {
+ void HandleNormalCallFunc() {
+ uint64_t call_handle;
+ TVMValue* values;
+ int* tcodes;
+ int num_args;
TVMValue ret_value[3];
int ret_tcode[3];
+ this->Read(&call_handle);
+ RecvPackedSeq(&values, &tcodes, &num_args);
+
int call_ecode = TVMFuncCall(reinterpret_cast<void*>(call_handle), values, tcodes, num_args,
&(ret_value[1]), &(ret_tcode[1]));
@@ -195,27 +159,46 @@ class MinRPCExecute : public MinRPCExecInterface {
ret_tcode[1] = kTVMDLTensorHandle;
ret_value[2].v_handle = ret_value[1].v_handle;
ret_tcode[2] = kTVMOpaqueHandle;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 3);
+ this->ReturnPackedSeq(ret_value, ret_tcode, 3);
} else if (rv_tcode == kTVMBytes) {
ret_tcode[1] = kTVMBytes;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
+ this->ReturnPackedSeq(ret_value, ret_tcode, 2);
TVMByteArrayFree(reinterpret_cast<TVMByteArray*>(ret_value[1].v_handle)); // NOLINT(*)
} else if (rv_tcode == kTVMPackedFuncHandle || rv_tcode == kTVMModuleHandle) {
ret_tcode[1] = kTVMOpaqueHandle;
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
+ this->ReturnPackedSeq(ret_value, ret_tcode, 2);
} else {
- ret_handler_->ReturnPackedSeq(ret_value, ret_tcode, 2);
+ this->ReturnPackedSeq(ret_value, ret_tcode, 2);
}
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
- void CopyFromRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
+ void HandleCopyFromRemote() {
+ DLTensor* arr = this->ArenaAlloc<DLTensor>(1);
+ uint64_t data_handle;
+ this->Read(&data_handle);
+ arr->data = reinterpret_cast<void*>(data_handle);
+ this->Read(&(arr->device));
+ this->Read(&(arr->ndim));
+ this->Read(&(arr->dtype));
+ arr->shape = this->ArenaAlloc<int64_t>(arr->ndim);
+ this->ReadArray(arr->shape, arr->ndim);
+ arr->strides = nullptr;
+ this->Read(&(arr->byte_offset));
+
+ uint64_t num_bytes;
+ this->Read(&num_bytes);
+
+ uint8_t* data_ptr;
int call_ecode = 0;
- if (arr->device.device_type != kDLCPU) {
+ if (arr->device.device_type == kDLCPU) {
+ data_ptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
+ } else {
+ data_ptr = this->ArenaAlloc<uint8_t>(num_bytes);
DLTensor temp;
- temp.data = static_cast<void*>(data_ptr);
+ temp.data = reinterpret_cast<void*>(data_ptr);
temp.device = DLDevice{kDLCPU, 0};
temp.ndim = arr->ndim;
temp.dtype = arr->dtype;
@@ -230,21 +213,43 @@ class MinRPCExecute : public MinRPCExecInterface {
}
if (call_ecode == 0) {
- ret_handler_->ReturnCopyFromRemote(data_ptr, num_bytes);
+ RPCCode code = RPCCode::kCopyAck;
+ uint64_t packet_nbytes = sizeof(code) + num_bytes;
+
+ io_->MessageStart(packet_nbytes);
+ this->Write(packet_nbytes);
+ this->Write(code);
+ this->WriteArray(data_ptr, num_bytes);
+ io_->MessageDone();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
- int CopyToRemote(DLTensor* arr, uint64_t num_bytes, uint8_t* data_ptr) {
- int call_ecode = 0;
-
- int ret = ReadArray(data_ptr, num_bytes);
- if (ret <= 0) return ret;
+ void HandleCopyToRemote() {
+ DLTensor* arr = this->ArenaAlloc<DLTensor>(1);
+ uint64_t data_handle;
+ this->Read(&data_handle);
+ arr->data = reinterpret_cast<void*>(data_handle);
+ this->Read(&(arr->device));
+ this->Read(&(arr->ndim));
+ this->Read(&(arr->dtype));
+ arr->shape = this->ArenaAlloc<int64_t>(arr->ndim);
+ this->ReadArray(arr->shape, arr->ndim);
+ arr->strides = nullptr;
+ this->Read(&(arr->byte_offset));
+ uint64_t num_bytes;
+ this->Read(&num_bytes);
- if (arr->device.device_type != kDLCPU) {
+ int call_ecode = 0;
+ if (arr->device.device_type == kDLCPU) {
+ uint8_t* dptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
+ this->ReadArray(dptr, num_bytes);
+ } else {
+ uint8_t* temp_data = this->ArenaAlloc<uint8_t>(num_bytes);
+ this->ReadArray(temp_data, num_bytes);
DLTensor temp;
- temp.data = data_ptr;
+ temp.data = temp_data;
temp.device = DLDevice{kDLCPU, 0};
temp.ndim = arr->ndim;
temp.dtype = arr->dtype;
@@ -259,71 +264,87 @@ class MinRPCExecute : public MinRPCExecInterface {
}
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
-
- return 1;
}
- void SysCallFunc(RPCCode code, TVMValue* values, int* tcodes, int num_args) {
+ void HandleSyscallFunc(RPCCode code) {
+ TVMValue* values;
+ int* tcodes;
+ int num_args;
+ RecvPackedSeq(&values, &tcodes, &num_args);
switch (code) {
case RPCCode::kFreeHandle: {
- SyscallFreeHandle(values, tcodes, num_args);
+ this->SyscallFreeHandle(values, tcodes, num_args);
break;
}
case RPCCode::kGetGlobalFunc: {
- SyscallGetGlobalFunc(values, tcodes, num_args);
+ this->SyscallGetGlobalFunc(values, tcodes, num_args);
break;
}
case RPCCode::kDevSetDevice: {
- ret_handler_->ReturnException("SetDevice not supported");
+ this->ReturnException("SetDevice not supported");
break;
}
case RPCCode::kDevGetAttr: {
- ret_handler_->ReturnException("GetAttr not supported");
+ this->ReturnException("GetAttr not supported");
break;
}
case RPCCode::kDevAllocData: {
- SyscallDevAllocData(values, tcodes, num_args);
+ this->SyscallDevAllocData(values, tcodes, num_args);
break;
}
case RPCCode::kDevAllocDataWithScope: {
- SyscallDevAllocDataWithScope(values, tcodes, num_args);
+ this->SyscallDevAllocDataWithScope(values, tcodes, num_args);
break;
}
case RPCCode::kDevFreeData: {
- SyscallDevFreeData(values, tcodes, num_args);
+ this->SyscallDevFreeData(values, tcodes, num_args);
break;
}
case RPCCode::kDevCreateStream: {
- SyscallDevCreateStream(values, tcodes, num_args);
+ this->SyscallDevCreateStream(values, tcodes, num_args);
break;
}
case RPCCode::kDevFreeStream: {
- SyscallDevFreeStream(values, tcodes, num_args);
+ this->SyscallDevFreeStream(values, tcodes, num_args);
break;
}
case RPCCode::kDevStreamSync: {
- SyscallDevStreamSync(values, tcodes, num_args);
+ this->SyscallDevStreamSync(values, tcodes, num_args);
break;
}
case RPCCode::kDevSetStream: {
- SyscallDevSetStream(values, tcodes, num_args);
+ this->SyscallDevSetStream(values, tcodes, num_args);
break;
}
case RPCCode::kCopyAmongRemote: {
- SyscallCopyAmongRemote(values, tcodes, num_args);
+ this->SyscallCopyAmongRemote(values, tcodes, num_args);
break;
}
default: {
- ret_handler_->ReturnException("Syscall not recognized");
+ this->ReturnException("Syscall not recognized");
break;
}
}
}
+ void HandleInitServer() {
+ uint64_t len;
+ this->Read(&len);
+ char* proto_ver = this->ArenaAlloc<char>(len + 1);
+ this->ReadArray(proto_ver, len);
+
+ TVMValue* values;
+ int* tcodes;
+ int num_args;
+ RecvPackedSeq(&values, &tcodes, &num_args);
+ MINRPC_CHECK(num_args == 0);
+ this->ReturnVoid();
+ }
+
void SyscallFreeHandle(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 2);
MINRPC_CHECK(tcodes[0] == kTVMOpaqueHandle);
@@ -343,22 +364,23 @@ class MinRPCExecute : public MinRPCExecInterface {
}
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
void SyscallGetGlobalFunc(TVMValue* values, int* tcodes, int num_args) {
MINRPC_CHECK(num_args == 1);
MINRPC_CHECK(tcodes[0] == kTVMStr);
+
void* handle;
int call_ecode = TVMFuncGetGlobal(values[0].v_str, &handle);
if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
+ this->ReturnHandle(handle);
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -379,9 +401,9 @@ class MinRPCExecute : public MinRPCExecInterface {
reinterpret_cast<DLTensor*>(to), stream);
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -401,9 +423,9 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMDeviceAllocDataSpace(dev, nbytes, alignment, type_hint, &handle);
if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
+ this->ReturnHandle(handle);
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -412,15 +434,15 @@ class MinRPCExecute : public MinRPCExecInterface {
MINRPC_CHECK(tcodes[0] == kTVMDLTensorHandle);
MINRPC_CHECK(tcodes[1] == kTVMNullptr || tcodes[1] == kTVMStr);
- DLTensor* arr = static_cast<DLTensor*>(values[0].v_handle);
+ DLTensor* arr = reinterpret_cast<DLTensor*>(values[0].v_handle);
const char* mem_scope = (tcodes[1] == kTVMNullptr ? nullptr : values[1].v_str);
void* handle;
int call_ecode = TVMDeviceAllocDataSpaceWithScope(arr->device, arr->ndim, arr->shape,
arr->dtype, mem_scope, &handle);
if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
+ this->ReturnHandle(handle);
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -435,9 +457,9 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMDeviceFreeDataSpace(dev, handle);
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -451,9 +473,9 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMStreamCreate(dev.device_type, dev.device_id, &handle);
if (call_ecode == 0) {
- ret_handler_->ReturnHandle(handle);
+ this->ReturnHandle(handle);
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -468,9 +490,9 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMStreamFree(dev.device_type, dev.device_id, handle);
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -485,9 +507,9 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMSynchronize(dev.device_type, dev.device_id, handle);
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
@@ -502,265 +524,103 @@ class MinRPCExecute : public MinRPCExecInterface {
int call_ecode = TVMSetStream(dev.device_type, dev.device_id, handle);
if (call_ecode == 0) {
- ret_handler_->ReturnVoid();
+ this->ReturnVoid();
} else {
- ret_handler_->ReturnLastTVMError();
+ this->ReturnLastTVMError();
}
}
void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- ret_handler_->ThrowError(code, info);
+ io_->Exit(static_cast<int>(code));
}
- MinRPCReturnInterface* GetReturnInterface() { return ret_handler_; }
-
- private:
template <typename T>
- int ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
- }
-
- int ReadRawBytes(void* data, size_t size) {
- uint8_t* buf = static_cast<uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixRead(buf, size - ndone);
- if (ret <= 0) return ret;
- ndone += ret;
- buf += ret;
- }
- return 1;
- }
-
- TIOHandler* io_;
- MinRPCReturnInterface* ret_handler_;
-};
-
-/*!
- * \brief A minimum RPC server that only depends on the tvm C runtime..
- *
- * All the dependencies are provided by the io arguments.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- * An IOHandler needs to provide the following functions:
- * - PosixWrite, PosixRead, Close: posix style, read, write, close API.
- * - MessageStart(num_bytes), MessageDone(): framing APIs.
- * - Exit: exit with status code.
- */
-template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator>
-class MinRPCServer {
- public:
- using PageAllocator = Allocator<TIOHandler>;
-
- /*!
- * \brief Constructor.
- * \param io The IO handler.
- */
- MinRPCServer(TIOHandler* io, std::unique_ptr<MinRPCExecInterface>&& exec_handler)
- : io_(io), arena_(PageAllocator(io_)), exec_handler_(std::move(exec_handler)) {}
-
- explicit MinRPCServer(TIOHandler* io)
- : io_(io),
- arena_(PageAllocator(io)),
- ret_handler_(new MinRPCReturns<TIOHandler>(io_)),
- exec_handler_(std::unique_ptr<MinRPCExecInterface>(
- new MinRPCExecute<TIOHandler>(io_, ret_handler_))) {}
-
- ~MinRPCServer() {
- if (ret_handler_ != nullptr) {
- delete ret_handler_;
- }
+ T* ArenaAlloc(int count) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ return arena_.template allocate_<T>(count);
}
- /*! \brief Process a single request.
- *
- * \return true when the server should continue processing requests. false when it should be
- * shutdown.
- */
- bool ProcessOnePacket() {
- RPCCode code;
- uint64_t packet_len;
-
- arena_.RecycleAll();
- allow_clean_shutdown_ = true;
-
- Read(&packet_len);
- if (packet_len == 0) return true;
- Read(&code);
- allow_clean_shutdown_ = false;
-
- if (code >= RPCCode::kSyscallCodeStart) {
- HandleSyscallFunc(code);
- } else {
- switch (code) {
- case RPCCode::kCallFunc: {
- HandleNormalCallFunc();
- break;
- }
- case RPCCode::kInitServer: {
- HandleInitServer();
- break;
- }
- case RPCCode::kCopyFromRemote: {
- HandleCopyFromRemote();
- break;
- }
- case RPCCode::kCopyToRemote: {
- HandleCopyToRemote();
- break;
- }
- case RPCCode::kShutdown: {
- Shutdown();
- return false;
- }
- default: {
- this->ThrowError(RPCServerStatus::kUnknownRPCCode);
- break;
- }
- }
- }
-
- return true;
+ template <typename T>
+ void Read(T* data) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ this->ReadRawBytes(data, sizeof(T));
}
- void HandleInitServer() {
- uint64_t len;
- Read(&len);
- char* proto_ver = ArenaAlloc<char>(len + 1);
- ReadArray(proto_ver, len);
- TVMValue* values;
- int* tcodes;
- int num_args;
- RecvPackedSeq(&values, &tcodes, &num_args);
- exec_handler_->InitServer(num_args);
+ template <typename T>
+ void ReadArray(T* data, size_t count) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ return this->ReadRawBytes(data, sizeof(T) * count);
}
- void Shutdown() {
- arena_.FreeAll();
- io_->Close();
+ template <typename T>
+ void Write(const T& data) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ return this->WriteRawBytes(&data, sizeof(T));
}
- void HandleNormalCallFunc() {
- uint64_t call_handle;
- TVMValue* values;
- int* tcodes;
- int num_args;
-
- Read(&call_handle);
- RecvPackedSeq(&values, &tcodes, &num_args);
- exec_handler_->NormalCallFunc(call_handle, values, tcodes, num_args);
+ template <typename T>
+ void WriteArray(T* data, size_t count) {
+ static_assert(std::is_pod<T>::value, "need to be trival");
+ return this->WriteRawBytes(data, sizeof(T) * count);
}
- void HandleCopyFromRemote() {
- DLTensor* arr = ArenaAlloc<DLTensor>(1);
- uint64_t data_handle;
- Read(&data_handle);
- arr->data = reinterpret_cast<void*>(data_handle);
- Read(&(arr->device));
- Read(&(arr->ndim));
- Read(&(arr->dtype));
- arr->shape = ArenaAlloc<int64_t>(arr->ndim);
- ReadArray(arr->shape, arr->ndim);
- arr->strides = nullptr;
- Read(&(arr->byte_offset));
-
- uint64_t num_bytes;
- Read(&num_bytes);
+ void MessageStart(uint64_t packet_nbytes) { io_->MessageStart(packet_nbytes); }
- uint8_t* data_ptr;
- if (arr->device.device_type == kDLCPU) {
- data_ptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
- } else {
- data_ptr = ArenaAlloc<uint8_t>(num_bytes);
- }
+ void MessageDone() { io_->MessageDone(); }
- exec_handler_->CopyFromRemote(arr, num_bytes, data_ptr);
+ private:
+ void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
+ RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
}
- void HandleCopyToRemote() {
- DLTensor* arr = ArenaAlloc<DLTensor>(1);
- uint64_t data_handle;
- Read(&data_handle);
- arr->data = reinterpret_cast<void*>(data_handle);
- Read(&(arr->device));
- Read(&(arr->ndim));
- Read(&(arr->dtype));
- arr->shape = ArenaAlloc<int64_t>(arr->ndim);
- ReadArray(arr->shape, arr->ndim);
- arr->strides = nullptr;
- Read(&(arr->byte_offset));
- uint64_t num_bytes;
- Read(&num_bytes);
- int ret;
- if (arr->device.device_type == kDLCPU) {
- uint8_t* dptr = reinterpret_cast<uint8_t*>(data_handle) + arr->byte_offset;
- ret = exec_handler_->CopyToRemote(arr, num_bytes, dptr);
- } else {
- uint8_t* temp_data = ArenaAlloc<uint8_t>(num_bytes);
- ret = exec_handler_->CopyToRemote(arr, num_bytes, temp_data);
- }
- if (ret == 0) {
- if (allow_clean_shutdown_) {
- Shutdown();
- io_->Exit(0);
- } else {
- this->ThrowError(RPCServerStatus::kReadError);
- }
- }
- if (ret == -1) {
- this->ThrowError(RPCServerStatus::kReadError);
- }
- }
+ void ReturnVoid() {
+ int32_t num_args = 1;
+ int32_t tcode = kTVMNullptr;
+ RPCCode code = RPCCode::kReturn;
- void HandleSyscallFunc(RPCCode code) {
- TVMValue* values;
- int* tcodes;
- int num_args;
- RecvPackedSeq(&values, &tcodes, &num_args);
+ uint64_t packet_nbytes = sizeof(code) + sizeof(num_args) + sizeof(tcode);
- exec_handler_->SysCallFunc(code, values, tcodes, num_args);
+ io_->MessageStart(packet_nbytes);
+ this->Write(packet_nbytes);
+ this->Write(code);
+ this->Write(num_args);
+ this->Write(tcode);
+ io_->MessageDone();
}
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- io_->Exit(static_cast<int>(code));
- }
+ void ReturnHandle(void* handle) {
+ int32_t num_args = 1;
+ int32_t tcode = kTVMOpaqueHandle;
+ RPCCode code = RPCCode::kReturn;
+ uint64_t encode_handle = reinterpret_cast<uint64_t>(handle);
+ uint64_t packet_nbytes =
+ sizeof(code) + sizeof(num_args) + sizeof(tcode) + sizeof(encode_handle);
- template <typename T>
- T* ArenaAlloc(int count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return arena_.template allocate_<T>(count);
+ io_->MessageStart(packet_nbytes);
+ this->Write(packet_nbytes);
+ this->Write(code);
+ this->Write(num_args);
+ this->Write(tcode);
+ this->Write(encode_handle);
+ io_->MessageDone();
}
- template <typename T>
- void Read(T* data) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- ReadRawBytes(data, sizeof(T));
- }
+ void ReturnException(const char* msg) { RPCReference::ReturnException(msg, this); }
- template <typename T>
- void ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
+ void ReturnPackedSeq(const TVMValue* arg_values, const int* type_codes, int num_args) {
+ RPCReference::ReturnPackedSeq(arg_values, type_codes, num_args, this);
}
- private:
- void RecvPackedSeq(TVMValue** out_values, int** out_tcodes, int* out_num_args) {
- RPCReference::RecvPackedSeq(out_values, out_tcodes, out_num_args, this);
- }
+ void ReturnLastTVMError() { this->ReturnException(TVMGetLastError()); }
void ReadRawBytes(void* data, size_t size) {
- uint8_t* buf = static_cast<uint8_t*>(data);
+ uint8_t* buf = reinterpret_cast<uint8_t*>(data);
size_t ndone = 0;
while (ndone < size) {
ssize_t ret = io_->PosixRead(buf, size - ndone);
if (ret == 0) {
if (allow_clean_shutdown_) {
- Shutdown();
+ this->Shutdown();
io_->Exit(0);
} else {
this->ThrowError(RPCServerStatus::kReadError);
@@ -774,15 +634,26 @@ class MinRPCServer {
}
}
+ void WriteRawBytes(const void* data, size_t size) {
+ const uint8_t* buf = reinterpret_cast<const uint8_t*>(data);
+ size_t ndone = 0;
+ while (ndone < size) {
+ ssize_t ret = io_->PosixWrite(buf, size - ndone);
+ if (ret == 0 || ret == -1) {
+ this->ThrowError(RPCServerStatus::kWriteError);
+ }
+ buf += ret;
+ ndone += ret;
+ }
+ }
+
/*! \brief IO handler. */
TIOHandler* io_;
/*! \brief internal arena. */
support::GenericArena<PageAllocator> arena_;
- MinRPCReturns<TIOHandler>* ret_handler_ = nullptr;
- std::unique_ptr<MinRPCExecInterface> exec_handler_;
/*! \brief Whether we are in a state that allows clean shutdown. */
bool allow_clean_shutdown_{true};
- static_assert(DMLC_LITTLE_ENDIAN == 1, "MinRPC only works on little endian.");
+ static_assert(DMLC_LITTLE_ENDIAN, "MinRPC only works on little endian.");
};
namespace detail {
diff --git a/src/runtime/minrpc/minrpc_server_logging.h b/src/runtime/minrpc/minrpc_server_logging.h
deleted file mode 100644
index deca2156ce..0000000000
--- a/src/runtime/minrpc/minrpc_server_logging.h
+++ /dev/null
@@ -1,166 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-#ifndef TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
-#define TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
-
-#include <memory>
-#include <utility>
-
-#include "minrpc_logger.h"
-#include "minrpc_server.h"
-
-namespace tvm {
-namespace runtime {
-
-/*!
- * \brief A minimum RPC server that logs the received commands.
- *
- * \tparam TIOHandler IO provider to provide io handling.
- */
-template <typename TIOHandler>
-class MinRPCServerWithLog {
- public:
- explicit MinRPCServerWithLog(TIOHandler* io)
- : ret_handler_(io),
- ret_handler_wlog_(&ret_handler_, &logger_),
- exec_handler_(io, &ret_handler_wlog_),
- exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
- next_(io, std::move(exec_handler_ptr_)) {}
-
- bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
-
- private:
- Logger logger_;
- MinRPCReturns<TIOHandler> ret_handler_;
- MinRPCExecute<TIOHandler> exec_handler_;
- MinRPCReturnsWithLog ret_handler_wlog_;
- std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
- MinRPCServer<TIOHandler> next_;
-};
-
-/*!
- * \brief A minimum RPC server that only logs the outgoing commands and received responses.
- * (Does not process the packets or respond to them.)
- *
- * \tparam TIOHandler IO provider to provide io handling.
- */
-template <typename TIOHandler, template <typename> class Allocator = detail::PageAllocator>
-class MinRPCSniffer {
- public:
- using PageAllocator = Allocator<TIOHandler>;
- explicit MinRPCSniffer(TIOHandler* io)
- : io_(io),
- arena_(PageAllocator(io_)),
- ret_handler_(io_),
- ret_handler_wlog_(&ret_handler_, &logger_),
- exec_handler_(&ret_handler_wlog_),
- exec_handler_ptr_(new MinRPCExecuteWithLog(&exec_handler_, &logger_)),
- next_(io_, std::move(exec_handler_ptr_)) {}
-
- bool ProcessOnePacket() { return next_.ProcessOnePacket(); }
-
- void ProcessOneResponse() {
- RPCCode code;
- uint64_t packet_len = 0;
-
- if (!Read(&packet_len)) return;
- if (packet_len == 0) {
- OutputLog();
- return;
- }
- if (!Read(&code)) return;
- switch (code) {
- case RPCCode::kReturn: {
- int32_t num_args;
- int* type_codes;
- TVMValue* values;
- RPCReference::RecvPackedSeq(&values, &type_codes, &num_args, this);
- ret_handler_wlog_.ReturnPackedSeq(values, type_codes, num_args);
- break;
- }
- case RPCCode::kException: {
- ret_handler_wlog_.ReturnException("");
- break;
- }
- default: {
- OutputLog();
- break;
- }
- }
- }
-
- void OutputLog() { logger_.OutputLog(); }
-
- void ThrowError(RPCServerStatus code, RPCCode info = RPCCode::kNone) {
- logger_.Log("-> ");
- logger_.Log(RPCServerStatusToString(code));
- OutputLog();
- }
-
- template <typename T>
- T* ArenaAlloc(int count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return arena_.template allocate_<T>(count);
- }
-
- template <typename T>
- bool Read(T* data) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T));
- }
-
- template <typename T>
- bool ReadArray(T* data, size_t count) {
- static_assert(std::is_trivial<T>::value && std::is_standard_layout<T>::value,
- "need to be trival");
- return ReadRawBytes(data, sizeof(T) * count);
- }
-
- private:
- bool ReadRawBytes(void* data, size_t size) {
- uint8_t* buf = reinterpret_cast<uint8_t*>(data);
- size_t ndone = 0;
- while (ndone < size) {
- ssize_t ret = io_->PosixRead(buf, size - ndone);
- if (ret <= 0) {
- this->ThrowError(RPCServerStatus::kReadError);
- return false;
- }
- ndone += ret;
- buf += ret;
- }
- return true;
- }
-
- Logger logger_;
- TIOHandler* io_;
- support::GenericArena<PageAllocator> arena_;
- MinRPCReturnsNoOp<TIOHandler> ret_handler_;
- MinRPCReturnsWithLog ret_handler_wlog_;
- MinRPCExecuteNoOp exec_handler_;
- std::unique_ptr<MinRPCExecuteWithLog> exec_handler_ptr_;
- MinRPCServer<TIOHandler> next_;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_MINRPC_MINRPC_SERVER_LOGGING_H_
diff --git a/src/runtime/rpc/rpc_channel_logger.h b/src/runtime/rpc/rpc_channel_logger.h
deleted file mode 100644
index 53144956eb..0000000000
--- a/src/runtime/rpc/rpc_channel_logger.h
+++ /dev/null
@@ -1,183 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements. See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership. The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied. See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file rpc_channel_logger.h
- * \brief A wrapper for RPCChannel with a NanoRPCListener for logging the commands.
- */
-#ifndef TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
-#define TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
-
-#include <memory>
-#include <utility>
-
-#include "../minrpc/minrpc_server_logging.h"
-#include "rpc_channel.h"
-
-#define RX_BUFFER_SIZE 65536
-
-namespace tvm {
-namespace runtime {
-
-class Buffer {
- public:
- Buffer(uint8_t* data, size_t data_size_bytes)
- : data_{data}, capacity_{data_size_bytes}, num_valid_bytes_{0}, read_cursor_{0} {}
-
- size_t Write(const uint8_t* data, size_t data_size_bytes) {
- size_t num_bytes_available = capacity_ - num_valid_bytes_;
- size_t num_bytes_to_copy = data_size_bytes;
- if (num_bytes_available < num_bytes_to_copy) {
- num_bytes_to_copy = num_bytes_available;
- }
-
- memcpy(&data_[num_valid_bytes_], data, num_bytes_to_copy);
- num_valid_bytes_ += num_bytes_to_copy;
- return num_bytes_to_copy;
- }
-
- size_t Read(uint8_t* data, size_t data_size_bytes) {
- size_t num_bytes_to_copy = data_size_bytes;
- size_t num_bytes_available = num_valid_bytes_ - read_cursor_;
- if (num_bytes_available < num_bytes_to_copy) {
- num_bytes_to_copy = num_bytes_available;
- }
-
- memcpy(data, &data_[read_cursor_], num_bytes_to_copy);
- read_cursor_ += num_bytes_to_copy;
- return num_bytes_to_copy;
- }
-
- void Clear() {
- num_valid_bytes_ = 0;
- read_cursor_ = 0;
- }
-
- size_t Size() const { return num_valid_bytes_; }
-
- private:
- /*! \brief pointer to data buffer. */
- uint8_t* data_;
-
- /*! \brief The total number of bytes available in data_.*/
- size_t capacity_;
-
- /*! \brief number of valid bytes in the buffer. */
- size_t num_valid_bytes_;
-
- /*! \brief Read cursor position. */
- size_t read_cursor_;
-};
-
-/*!
- * \brief A simple IO handler for MinRPCSniffer.
- *
- * \tparam Buffer* buffer to store received data.
- */
-class SnifferIOHandler {
- public:
- explicit SnifferIOHandler(Buffer* receive_buffer) : receive_buffer_(receive_buffer) {}
-
- void MessageStart(size_t message_size_bytes) {}
-
- ssize_t PosixWrite(const uint8_t* buf, size_t buf_size_bytes) { return 0; }
-
- void MessageDone() {}
-
- ssize_t PosixRead(uint8_t* buf, size_t buf_size_bytes) {
- return receive_buffer_->Read(buf, buf_size_bytes);
- }
-
- void Close() {}
-
- void Exit(int code) {}
-
- private:
- Buffer* receive_buffer_;
-};
-
-/*!
- * \brief A simple rpc session that logs the received commands.
- */
-class NanoRPCListener {
- public:
- NanoRPCListener()
- : receive_buffer_(receive_storage_, receive_storage_size_bytes_),
- io_(&receive_buffer_),
- rpc_server_(&io_) {}
-
- void Listen(const uint8_t* data, size_t size) { receive_buffer_.Write(data, size); }
-
- void ProcessTxPacket() {
- rpc_server_.ProcessOnePacket();
- ClearBuffer();
- }
-
- void ProcessRxPacket() {
- rpc_server_.ProcessOneResponse();
- ClearBuffer();
- }
-
- private:
- void ClearBuffer() { receive_buffer_.Clear(); }
-
- private:
- size_t receive_storage_size_bytes_ = RX_BUFFER_SIZE;
- uint8_t receive_storage_[RX_BUFFER_SIZE];
- Buffer receive_buffer_;
- SnifferIOHandler io_;
- MinRPCSniffer<SnifferIOHandler> rpc_server_;
-
- void HandleCompleteMessage() { rpc_server_.ProcessOnePacket(); }
-
- static void HandleCompleteMessageCb(void* context) {
- static_cast<NanoRPCListener*>(context)->HandleCompleteMessage();
- }
-};
-
-/*!
- * \brief A wrapper for RPCChannel, that also logs the commands sent.
- *
- * \tparam std::unique_ptr<RPCChannel>&& underlying RPCChannel unique_ptr.
- */
-class RPCChannelLogging : public RPCChannel {
- public:
- explicit RPCChannelLogging(std::unique_ptr<RPCChannel>&& next) { next_ = std::move(next); }
-
- size_t Send(const void* data, size_t size) {
- listener_.ProcessRxPacket();
- listener_.Listen((const uint8_t*)data, size);
- listener_.ProcessTxPacket();
- return next_->Send(data, size);
- }
-
- size_t Recv(void* data, size_t size) {
- size_t ret = next_->Recv(data, size);
- listener_.Listen((const uint8_t*)data, size);
- return ret;
- }
-
- private:
- std::unique_ptr<RPCChannel> next_;
- NanoRPCListener listener_;
-};
-
-} // namespace runtime
-} // namespace tvm
-#endif // TVM_RUNTIME_RPC_RPC_CHANNEL_LOGGER_H_
diff --git a/src/runtime/rpc/rpc_endpoint.h b/src/runtime/rpc/rpc_endpoint.h
index d8e2dece73..ed19a3f59e 100644
--- a/src/runtime/rpc/rpc_endpoint.h
+++ b/src/runtime/rpc/rpc_endpoint.h
@@ -34,7 +34,6 @@
#include "../../support/ring_buffer.h"
#include "../minrpc/rpc_reference.h"
#include "rpc_channel.h"
-#include "rpc_channel_logger.h"
#include "rpc_session.h"
namespace tvm {
@@ -181,7 +180,6 @@ class RPCEndpoint {
void Shutdown();
// Internal channel.
std::unique_ptr<RPCChannel> channel_;
-
// Internal mutex
std::mutex mutex_;
// Internal ring buffer.
diff --git a/src/runtime/rpc/rpc_socket_impl.cc b/src/runtime/rpc/rpc_socket_impl.cc
index bc274ff888..1456fc7191 100644
--- a/src/runtime/rpc/rpc_socket_impl.cc
+++ b/src/runtime/rpc/rpc_socket_impl.cc
@@ -65,7 +65,7 @@ class SockChannel final : public RPCChannel {
};
std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string key,
- bool enable_logging, TVMArgs init_seq) {
+ TVMArgs init_seq) {
support::TCPSocket sock;
support::SockAddr addr(url.c_str(), port);
sock.Create(addr.ss_family());
@@ -96,20 +96,14 @@ std::shared_ptr<RPCEndpoint> RPCConnect(std::string url, int port, std::string k
remote_key.resize(keylen);
ICHECK_EQ(sock.RecvAll(&remote_key[0], keylen), keylen);
}
-
- std::unique_ptr<RPCChannel> channel{new SockChannel(sock)};
- if (enable_logging) {
- channel.reset(new RPCChannelLogging(std::move(channel)));
- }
- auto endpt = RPCEndpoint::Create(std::move(channel), key, remote_key);
-
+ auto endpt =
+ RPCEndpoint::Create(std::unique_ptr<SockChannel>(new SockChannel(sock)), key, remote_key);
endpt->InitRemoteSession(init_seq);
return endpt;
}
-Module RPCClientConnect(std::string url, int port, std::string key, bool enable_logging,
- TVMArgs init_seq) {
- auto endpt = RPCConnect(url, port, "client:" + key, enable_logging, init_seq);
+Module RPCClientConnect(std::string url, int port, std::string key, TVMArgs init_seq) {
+ auto endpt = RPCConnect(url, port, "client:" + key, init_seq);
return CreateRPCSessionModule(CreateClientSession(endpt));
}
@@ -130,9 +124,8 @@ TVM_REGISTER_GLOBAL("rpc.Connect").set_body([](TVMArgs args, TVMRetValue* rv) {
std::string url = args[0];
int port = args[1];
std::string key = args[2];
- bool enable_logging = args[3];
- *rv = RPCClientConnect(url, port, key, enable_logging,
- TVMArgs(args.values + 4, args.type_codes + 4, args.size() - 4));
+ *rv = RPCClientConnect(url, port, key,
+ TVMArgs(args.values + 3, args.type_codes + 3, args.size() - 3));
});
TVM_REGISTER_GLOBAL("rpc.ServerLoop").set_body([](TVMArgs args, TVMRetValue* rv) {
diff --git a/tests/python/unittest/test_runtime_rpc.py b/tests/python/unittest/test_runtime_rpc.py
index 63be742fdb..f0ddcb60a1 100644
--- a/tests/python/unittest/test_runtime_rpc.py
+++ b/tests/python/unittest/test_runtime_rpc.py
@@ -109,25 +109,6 @@ def test_rpc_simple():
check_remote()
-@tvm.testing.requires_rpc
-def test_rpc_simple_wlog():
- server = rpc.Server(key="x1")
- client = rpc.connect("127.0.0.1", server.port, key="x1", enable_logging=True)
-
- def check_remote():
- f1 = client.get_function("rpc.test.addone")
- assert f1(10) == 11
- f3 = client.get_function("rpc.test.except")
-
- with pytest.raises(tvm._ffi.base.TVMError):
- f3("abc")
-
- f2 = client.get_function("rpc.test.strcat")
- assert f2("abc", 11) == "abc:11"
-
- check_remote()
-
-
@tvm.testing.requires_rpc
def test_rpc_runtime_string():
server = rpc.Server(key="x1")
@@ -250,7 +231,7 @@ def test_rpc_remote_module():
"127.0.0.1",
server0.port,
key="x0",
- session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False],
+ session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1"],
)
def check_remote(remote):
@@ -385,7 +366,7 @@ def test_rpc_session_constructor_args():
"127.0.0.1",
server0.port,
key="x0",
- session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1", False],
+ session_constructor_args=["rpc.Connect", "127.0.0.1", server1.port, "x1"],
)
fecho = client.get_function("testing.echo")