You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@tvm.apache.org by tq...@apache.org on 2021/01/21 15:57:00 UTC

[tvm] branch main updated: [µTVM] Add TVMPlatformGenerateRandom, a non-cryptographic random number generator. (#7266)

This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 8524b28  [µTVM] Add TVMPlatformGenerateRandom, a non-cryptographic random number generator. (#7266)
8524b28 is described below

commit 8524b28078928caf5c8ca82442ad0eab81dce838
Author: Andrew Reusch <ar...@octoml.ai>
AuthorDate: Thu Jan 21 07:56:42 2021 -0800

    [µTVM] Add TVMPlatformGenerateRandom, a non-cryptographic random number generator. (#7266)
    
    * [uTVM] Add TVMPlatformGenerateRandom, and use it to generate the Session nonce.
    
     * This change prepares for autotuning support in microTVM. It
       also cleans up a loose end in the microTVM RPC server
       implementation: the initial Session nonce, previously a
       hard-coded constant, is now chosen randomly.
     * Randomness is needed in two places in the CRT:
        1. to initialize the Session nonce, which provides a more robust
           way to detect reboots and ensure that messages are not confused
           across them.
        2. to fill input tensors when timing AutoTVM operators (once
           AutoTVM support lands in the next PR).
    
     * This change adds TVMPlatformGenerateRandom, a platform function for
       generating non-cryptographic random data, to service those needs.
---
 include/tvm/runtime/crt/platform.h            | 19 +++++++++++++++++
 include/tvm/runtime/crt/rpc_common/session.h  | 10 +++++----
 src/runtime/crt/common/crt_runtime_api.c      |  5 +++++
 src/runtime/crt/host/main.cc                  | 15 ++++++++++++++
 src/runtime/crt/utvm_rpc_common/session.cc    |  5 ++++-
 src/runtime/crt/utvm_rpc_server/rpc_server.cc | 12 ++++++++---
 src/runtime/micro/micro_session.cc            | 30 ++++++++++++++++++++++++---
 tests/crt/session_test.cc                     | 14 +++++++------
 tests/micro/qemu/zephyr-runtime/prj.conf      |  4 ++++
 tests/micro/qemu/zephyr-runtime/src/main.c    | 21 +++++++++++++++++++
 10 files changed, 118 insertions(+), 17 deletions(-)

diff --git a/include/tvm/runtime/crt/platform.h b/include/tvm/runtime/crt/platform.h
index 8e03839..d1226e3 100644
--- a/include/tvm/runtime/crt/platform.h
+++ b/include/tvm/runtime/crt/platform.h
@@ -97,6 +97,25 @@ tvm_crt_error_t TVMPlatformTimerStart();
  */
 tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds);
 
+/*! \brief Fill a buffer with random data.
+ *
+ * Cryptographically-secure random data is NOT required. This function is intended for use
+ * cases such as filling autotuning input tensors and choosing the nonce used for microTVM RPC.
+ *
+ * This function does not need to be implemented for inference tasks. It is used only by
+ * AutoTVM and the RPC server. When not implemented, an internal weak-linked stub is provided.
+ *
+ * Please take care that across successive resets, this function returns different sequences of
+ * values. If e.g. the random number generator is seeded with the same value, it may make it
+ * difficult for a host to detect device resets during autotuning or host-driven inference.
+ *
+ * \param buffer Pointer to the 0th byte to write with random data. `num_bytes` of random data
+ * should be written here.
+ * \param num_bytes Number of bytes to write.
+ * \return kTvmErrorNoError if successful; a descriptive error code otherwise.
+ */
+tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes);
+
 #ifdef __cplusplus
 }  // extern "C"
 #endif
diff --git a/include/tvm/runtime/crt/rpc_common/session.h b/include/tvm/runtime/crt/rpc_common/session.h
index 9e6a9f3..eee1de6 100644
--- a/include/tvm/runtime/crt/rpc_common/session.h
+++ b/include/tvm/runtime/crt/rpc_common/session.h
@@ -78,9 +78,9 @@ class Session {
   /*! \brief An invalid nonce value that typically indicates an unknown nonce. */
   static constexpr const uint8_t kInvalidNonce = 0;
 
-  Session(uint8_t initial_session_nonce, Framer* framer, FrameBuffer* receive_buffer,
-          MessageReceivedFunc message_received_func, void* message_received_func_context)
-      : local_nonce_{initial_session_nonce},
+  Session(Framer* framer, FrameBuffer* receive_buffer, MessageReceivedFunc message_received_func,
+          void* message_received_func_context)
+      : local_nonce_{kInvalidNonce},
         session_id_{0},
         state_{State::kReset},
         receiver_{this},
@@ -99,9 +99,11 @@ class Session {
 
   /*!
    * \brief Send a session terminate message, usually done at startup to interrupt a hanging remote.
+   * \param initial_session_nonce Initial nonce that should be used on the first session start
+   *      message. Callers should ensure this is different across device resets.
    * \return kTvmErrorNoError on success, or an error code otherwise.
    */
-  tvm_crt_error_t Initialize();
+  tvm_crt_error_t Initialize(uint8_t initial_session_nonce);
 
   /*!
    * \brief Terminate any previously-established session.
diff --git a/src/runtime/crt/common/crt_runtime_api.c b/src/runtime/crt/common/crt_runtime_api.c
index 960f844..bc47f99 100644
--- a/src/runtime/crt/common/crt_runtime_api.c
+++ b/src/runtime/crt/common/crt_runtime_api.c
@@ -509,3 +509,8 @@ release_and_return : {
 }
   return err;
 }
+
+// Default implementation, overridden by the platform runtime.
+__attribute__((weak)) tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
+  return kTvmErrorFunctionCallNotImplemented;
+}
diff --git a/src/runtime/crt/host/main.cc b/src/runtime/crt/host/main.cc
index 7db17f5..bf36dea 100644
--- a/src/runtime/crt/host/main.cc
+++ b/src/runtime/crt/host/main.cc
@@ -22,6 +22,7 @@
  * \brief main entry point for host subprocess-based CRT
  */
 #include <inttypes.h>
+#include <time.h>
 #include <tvm/runtime/c_runtime_api.h>
 #include <tvm/runtime/crt/logging.h>
 #include <tvm/runtime/crt/memory.h>
@@ -93,6 +94,20 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
   g_utvm_timer_running = 0;
   return kTvmErrorNoError;
 }
+
+static_assert(RAND_MAX >= (1 << 8), "RAND_MAX is smaller than acceptable");
+unsigned int random_seed = 0;
+tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
+  if (random_seed == 0) {
+    random_seed = (unsigned int)time(NULL);
+  }
+  for (size_t i = 0; i < num_bytes; ++i) {
+    int random = rand_r(&random_seed);
+    buffer[i] = (uint8_t)random;
+  }
+
+  return kTvmErrorNoError;
+}
 }
 
 uint8_t memory[512 * 1024];
diff --git a/src/runtime/crt/utvm_rpc_common/session.cc b/src/runtime/crt/utvm_rpc_common/session.cc
index 5930863..e1e338e 100644
--- a/src/runtime/crt/utvm_rpc_common/session.cc
+++ b/src/runtime/crt/utvm_rpc_common/session.cc
@@ -95,7 +95,10 @@ tvm_crt_error_t Session::StartSession() {
   return to_return;
 }
 
-tvm_crt_error_t Session::Initialize() { return TerminateSession(); }
+tvm_crt_error_t Session::Initialize(uint8_t initial_session_nonce) {
+  local_nonce_ = initial_session_nonce;
+  return TerminateSession();
+}
 
 tvm_crt_error_t Session::TerminateSession() {
   SetSessionId(0, 0);
diff --git a/src/runtime/crt/utvm_rpc_server/rpc_server.cc b/src/runtime/crt/utvm_rpc_server/rpc_server.cc
index 074799c..0b9e96c 100644
--- a/src/runtime/crt/utvm_rpc_server/rpc_server.cc
+++ b/src/runtime/crt/utvm_rpc_server/rpc_server.cc
@@ -112,7 +112,7 @@ class MicroRPCServer {
                  utvm_rpc_channel_write_t write_func, void* write_func_ctx)
       : receive_buffer_{receive_storage, receive_storage_size_bytes},
         framer_{&send_stream_},
-        session_{0xa5, &framer_, &receive_buffer_, &HandleCompleteMessageCb, this},
+        session_{&framer_, &receive_buffer_, &HandleCompleteMessageCb, this},
         io_{&session_, &receive_buffer_},
         unframer_{session_.Receiver()},
         rpc_server_{&io_},
@@ -120,7 +120,13 @@ class MicroRPCServer {
 
   void* operator new(size_t count, void* ptr) { return ptr; }
 
-  void Initialize() { CHECK_EQ(kTvmErrorNoError, session_.Initialize(), "rpc server init"); }
+  void Initialize() {
+    uint8_t initial_session_nonce = Session::kInvalidNonce;
+    tvm_crt_error_t error =
+        TVMPlatformGenerateRandom(&initial_session_nonce, sizeof(initial_session_nonce));
+    CHECK_EQ(kTvmErrorNoError, error, "generating random session id");
+    CHECK_EQ(kTvmErrorNoError, session_.Initialize(initial_session_nonce), "rpc server init");
+  }
 
   /*! \brief Process one message from the receive buffer, if possible.
    *
@@ -242,7 +248,7 @@ void TVMLogf(const char* format, ...) {
   } else {
     tvm::runtime::micro_rpc::SerialWriteStream write_stream;
     tvm::runtime::micro_rpc::Framer framer{&write_stream};
-    tvm::runtime::micro_rpc::Session session{0xa5, &framer, nullptr, nullptr, nullptr};
+    tvm::runtime::micro_rpc::Session session{&framer, nullptr, nullptr, nullptr};
     tvm_crt_error_t err =
         session.SendMessage(tvm::runtime::micro_rpc::MessageType::kLog,
                             reinterpret_cast<uint8_t*>(log_buffer), num_bytes_logged);
diff --git a/src/runtime/micro/micro_session.cc b/src/runtime/micro/micro_session.cc
index ceaa5dd..f26a717 100644
--- a/src/runtime/micro/micro_session.cc
+++ b/src/runtime/micro/micro_session.cc
@@ -105,7 +105,7 @@ class MicroTransportChannel : public RPCChannel {
         write_stream_{fsend, session_start_timeout},
         framer_{&write_stream_},
         receive_buffer_{new uint8_t[TVM_CRT_MAX_PACKET_SIZE_BYTES], TVM_CRT_MAX_PACKET_SIZE_BYTES},
-        session_{0x5c, &framer_, &receive_buffer_, &HandleMessageReceivedCb, this},
+        session_{&framer_, &receive_buffer_, &HandleMessageReceivedCb, this},
         unframer_{session_.Receiver()},
         did_receive_message_{false},
         frecv_{frecv},
@@ -161,13 +161,35 @@ class MicroTransportChannel : public RPCChannel {
     }
   }
 
+  static constexpr const int kNumRandRetries = 10;
+  static std::atomic<unsigned int> random_seed;
+
+  inline uint8_t GenerateRandomNonce() {
+    // NOTE: this is bad concurrent programming but in practice we don't really expect race
+    // conditions here, and even if they occur we don't particularly care whether a competing
+    // process computes a different random seed. This value is just chosen pseudo-randomly to
+    // form an initial distinct session id. Here we just want to protect against bad loads causing
+    // confusion.
+    unsigned int seed = random_seed.load();
+    if (seed == 0) {
+      seed = (unsigned int)time(NULL);
+    }
+    uint8_t initial_nonce = 0;
+    for (int i = 0; i < kNumRandRetries && initial_nonce == 0; ++i) {
+      initial_nonce = rand_r(&seed);
+    }
+    random_seed.store(seed);
+    ICHECK_NE(initial_nonce, 0) << "rand() does not seem to be producing random values";
+    return initial_nonce;
+  }
+
   bool StartSessionInternal() {
     using ::std::chrono::duration_cast;
     using ::std::chrono::microseconds;
     using ::std::chrono::steady_clock;
 
     steady_clock::time_point start_time = steady_clock::now();
-    ICHECK_EQ(kTvmErrorNoError, session_.Initialize());
+    ICHECK_EQ(kTvmErrorNoError, session_.Initialize(GenerateRandomNonce()));
     ICHECK_EQ(kTvmErrorNoError, session_.StartSession());
 
     if (session_start_timeout_ == microseconds::zero() &&
@@ -198,7 +220,7 @@ class MicroTransportChannel : public RPCChannel {
       }
       end_time += session_start_retry_timeout_;
 
-      ICHECK_EQ(kTvmErrorNoError, session_.Initialize());
+      ICHECK_EQ(kTvmErrorNoError, session_.Initialize(GenerateRandomNonce()));
       ICHECK_EQ(kTvmErrorNoError, session_.StartSession());
     }
 
@@ -365,6 +387,8 @@ class MicroTransportChannel : public RPCChannel {
   std::string pending_chunk_;
 };
 
+std::atomic<unsigned int> MicroTransportChannel::random_seed{0};
+
 TVM_REGISTER_GLOBAL("micro._rpc_connect").set_body([](TVMArgs args, TVMRetValue* rv) {
   MicroTransportChannel* micro_channel =
       new MicroTransportChannel(args[1], args[2], ::std::chrono::microseconds(uint64_t(args[3])),
diff --git a/tests/crt/session_test.cc b/tests/crt/session_test.cc
index a1d57fc..60686be 100644
--- a/tests/crt/session_test.cc
+++ b/tests/crt/session_test.cc
@@ -55,8 +55,9 @@ class TestSession {
   TestSession(uint8_t initial_nonce)
       : framer{&framer_write_stream},
         receive_buffer{receive_buffer_array, sizeof(receive_buffer_array)},
-        sess{initial_nonce, &framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
-        unframer{sess.Receiver()} {}
+        sess{&framer, &receive_buffer, TestSessionMessageReceivedThunk, this},
+        unframer{sess.Receiver()},
+        initial_nonce{initial_nonce} {}
 
   void WriteTo(TestSession* other) {
     auto framer_buffer = framer_write_stream.BufferContents();
@@ -84,6 +85,7 @@ class TestSession {
   FrameBuffer receive_buffer;
   Session sess;
   Unframer unframer;
+  uint8_t initial_nonce;
 };
 
 #define EXPECT_FRAMED_PACKET(session, expected)          \
@@ -126,14 +128,14 @@ class SessionTest : public ::testing::Test {
 
 TEST_F(SessionTest, NormalExchange) {
   tvm_crt_error_t err;
-  err = alice_.sess.Initialize();
+  err = alice_.sess.Initialize(alice_.initial_nonce);
   EXPECT_EQ(kTvmErrorNoError, err);
   EXPECT_FRAMED_PACKET(alice_,
                        "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
                        "fw");
   alice_.WriteTo(&bob_);
 
-  err = bob_.sess.Initialize();
+  err = bob_.sess.Initialize(bob_.initial_nonce);
   EXPECT_EQ(kTvmErrorNoError, err);
   EXPECT_FRAMED_PACKET(bob_,
                        "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
@@ -212,14 +214,14 @@ static constexpr const char kBobStartPacket[] = "\xff\xfd\x04\0\0\0f\0\0\x01`\xa
 
 TEST_F(SessionTest, DoubleStart) {
   tvm_crt_error_t err;
-  err = alice_.sess.Initialize();
+  err = alice_.sess.Initialize(alice_.initial_nonce);
   EXPECT_EQ(kTvmErrorNoError, err);
   EXPECT_FRAMED_PACKET(alice_,
                        "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
                        "fw");
   alice_.WriteTo(&bob_);
 
-  err = bob_.sess.Initialize();
+  err = bob_.sess.Initialize(bob_.initial_nonce);
   EXPECT_EQ(kTvmErrorNoError, err);
   EXPECT_FRAMED_PACKET(bob_,
                        "\xfe\xff\xfd\x03\0\0\0\0\0\x02"
diff --git a/tests/micro/qemu/zephyr-runtime/prj.conf b/tests/micro/qemu/zephyr-runtime/prj.conf
index cebb557..7be42b2 100644
--- a/tests/micro/qemu/zephyr-runtime/prj.conf
+++ b/tests/micro/qemu/zephyr-runtime/prj.conf
@@ -29,3 +29,7 @@ CONFIG_FPU=y
 
 # For TVMPlatformAbort().
 CONFIG_REBOOT=y
+
+# For TVMPlatformGenerateRandom(). Remember, these values do not need to be truly random.
+CONFIG_TEST_RANDOM_GENERATOR=y
+CONFIG_TIMER_RANDOM_GENERATOR=y
diff --git a/tests/micro/qemu/zephyr-runtime/src/main.c b/tests/micro/qemu/zephyr-runtime/src/main.c
index 9d10504..e04fc20 100644
--- a/tests/micro/qemu/zephyr-runtime/src/main.c
+++ b/tests/micro/qemu/zephyr-runtime/src/main.c
@@ -26,6 +26,7 @@
 #include <drivers/uart.h>
 #include <kernel.h>
 #include <power/reboot.h>
+#include <random/rand32.h>
 #include <stdio.h>
 #include <sys/printk.h>
 #include <sys/ring_buffer.h>
@@ -161,6 +162,26 @@ tvm_crt_error_t TVMPlatformTimerStop(double* elapsed_time_seconds) {
   return kTvmErrorNoError;
 }
 
+tvm_crt_error_t TVMPlatformGenerateRandom(uint8_t* buffer, size_t num_bytes) {
+  uint32_t random;  // one unit of random data.
+
+  // Fill parts of `buffer` which are as large as `random`.
+  size_t num_full_blocks = num_bytes / sizeof(random);
+  for (int i = 0; i < num_full_blocks; ++i) {
+    random = sys_rand32_get();
+    memcpy(&buffer[i * sizeof(random)], &random, sizeof(random));
+  }
+
+  // Fill any leftover tail which is smaller than `random`.
+  size_t num_tail_bytes = num_bytes % sizeof(random);
+  if (num_tail_bytes > 0) {
+    random = sys_rand32_get();
+    memcpy(&buffer[num_bytes - num_tail_bytes], &random, num_tail_bytes);
+  }
+
+  return kTvmErrorNoError;
+}
+
 #define RING_BUF_SIZE 512
 struct uart_rx_buf_t {
   struct ring_buf buf;