You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by GitBox <gi...@apache.org> on 2021/04/14 21:33:40 UTC

[GitHub] [tvm] areusch commented on a change in pull request #7838: [RPC] microtvm: fix rpc large transfer size for microtvm targets

areusch commented on a change in pull request #7838:
URL: https://github.com/apache/tvm/pull/7838#discussion_r612565248



##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -801,14 +801,14 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
   std::lock_guard<std::mutex> lock(mutex_);
   RPCCode code = RPCCode::kCopyToRemote;
 
-  uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
-  ICHECK_EQ(nbytes, num_data_bytes);
+  uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
+  ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size.";

Review comment:
       Also print the details (byte_offset, nbytes, tensor_max_size_bytes) in the failure message.

##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -801,14 +800,14 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
   std::lock_guard<std::mutex> lock(mutex_);
   RPCCode code = RPCCode::kCopyToRemote;
 
-  uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
-  ICHECK_EQ(nbytes, num_data_bytes);
+  uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));

Review comment:
       Maybe also update the variable names here, if we go the route of returning TVM_CRT_MAX_PACKET_SIZE_BYTES to this code.

##########
File path: src/runtime/crt/host/main.cc
##########
@@ -110,7 +110,7 @@ tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
 }
 }
 
-uint8_t memory[512 * 1024];
+uint8_t memory[2048 * 1024];

Review comment:
       Is this increase (from 512 KiB to 2 MiB) needed?

##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -968,7 +967,10 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   /*!
    * \brief param endpoint The client endpoint of the session.
    */
-  explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {}
+  explicit RPCClientSession(std::shared_ptr<RPCEndpoint> endpoint) : endpoint_(endpoint) {
+    // update max transfer size if not set already.
+    SetRPCMaxTransferSize();

Review comment:
       could we do this lazily in CopyToRemote rather than immediately on establishing the session?

##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -801,14 +801,14 @@ void RPCEndpoint::CopyToRemote(void* from_bytes, DLTensor* to, uint64_t nbytes)
   std::lock_guard<std::mutex> lock(mutex_);
   RPCCode code = RPCCode::kCopyToRemote;
 
-  uint64_t num_data_bytes = static_cast<uint64_t>(GetDataSize(*to));
-  ICHECK_EQ(nbytes, num_data_bytes);
+  uint64_t tensor_max_size_bytes = static_cast<uint64_t>(GetDataSize(*to));
+  ICHECK_LE(to->byte_offset + nbytes, tensor_max_size_bytes) << "Overflow in tensor size.";
 
-  uint64_t to_data = reinterpret_cast<uint64_t>(to->data);
+  uint64_t to_data = reinterpret_cast<uint64_t>(static_cast<char*>(to->data) + to->byte_offset);

Review comment:
       prefer uint8_t over char, since it's more explicit

##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -981,7 +981,20 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   }
 
   void CopyToRemote(void* local_from_bytes, DLTensor* remote_to, uint64_t nbytes) final {
-    endpoint_->CopyToRemote(local_from_bytes, remote_to, nbytes);
+    uint64_t block_size = 2048;

Review comment:
       Make a named constant out of this instead of hard-coding 2048.

##########
File path: src/runtime/crt/host/crt_config.h
##########
@@ -46,11 +46,14 @@
 #define TVM_CRT_GLOBAL_FUNC_REGISTRY_SIZE_BYTES 256
 
 /*! Maximum packet size, in bytes, including the length header. */
-#define TVM_CRT_MAX_PACKET_SIZE_BYTES 64000
+#define TVM_CRT_MAX_PACKET_SIZE_BYTES 8 * 1024
 
 /*! \brief Maximum length of a PackedFunc function name. */
 #define TVM_CRT_MAX_FUNCTION_NAME_LENGTH_BYTES 30
 
+/*! Size of the global function for max RPC transfer, in bytes. */
+#define TVM_CRT_RPC_MAX_TRANSFER_SIZE_BYTES 2048

Review comment:
       I think the RPC function should return an RPC-level limit rather than the limit for one particular RPC call. So I'd vote to have the function return TVM_CRT_MAX_PACKET_SIZE_BYTES and then, in CopyToRemote, compute the maximum tensor size that fits in TVM_CRT_MAX_PACKET_SIZE_BYTES. I think that approach should avoid the need to add this constant here.

##########
File path: src/runtime/rpc/rpc_endpoint.h
##########
@@ -48,6 +48,9 @@ const int kRPCSuccess = kRPCMagic + 0;
 // cannot found matched key in server
 const int kRPCMismatch = kRPCMagic + 2;
 
+// When tvm.rpc.server.GetTransferMaxSize global function is not registered.
+const int kRPCMaxTransferSizeDefault = 128000;

Review comment:
       Include "bytes" in the constant's name so the unit is explicit.

##########
File path: src/runtime/crt/common/crt_runtime_api.c
##########
@@ -298,8 +298,14 @@ static tvm_crt_error_t FindFunctionOrSetAPIError(tvm_module_index_t module_index
 }
 
 int TVMFuncGetGlobal(const char* name, TVMFunctionHandle* out) {
-  return FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name,
-                                   out);
+  tvm_crt_error_t to_return =
+      FindFunctionOrSetAPIError(kGlobalFuncModuleIndex, &global_func_registry.registry, name, out);
+  // For compatibility with C++

Review comment:
       maybe just be a little clearer: "for compatibility with the C++ runtime equivalent, in src/runtime/registry.cc"

##########
File path: src/runtime/rpc/rpc_endpoint.cc
##########
@@ -1042,7 +1057,23 @@ class RPCClientSession : public RPCSession, public DeviceAPI {
   bool IsLocalSession() const final { return false; }
 
  private:
+  void RPCMaxTransferRemoteReturnValue(TVMArgs args) {
+    // Use args[1] as return value, args[0] is tcode
+    rpc_chunk_max_size_bytes_ = (int64_t)args[1];
+  }
+
+  void SetRPCMaxTransferSize() {

Review comment:
       might consider renaming to GetRPCMaxTransferSize and/or making this the getter function which handles lazily fetching from server as needed.




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

For queries about this service, please contact Infrastructure at:
users@infra.apache.org