You are viewing a plain text version of this content. The canonical link for it is here.
Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/04/02 15:32:29 UTC

[GitHub] [arrow] lidavidm commented on a change in pull request #12442: ARROW-15706: [C++][FlightRPC] Implement a UCX transport

lidavidm commented on a change in pull request #12442:
URL: https://github.com/apache/arrow/pull/12442#discussion_r841088787



##########
File path: cpp/src/arrow/flight/transport/ucx/ucx_internal.cc
##########
@@ -0,0 +1,1164 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+#include "arrow/flight/transport/ucx/ucx_internal.h"
+
+#include <array>
+#include <cstdint>
+#include <cstdlib>
+#include <cstring>
+#include <limits>
+#include <memory>
+#include <mutex>
+#include <string>
+#include <unordered_map>
+#include <utility>
+#include <vector>
+
+#include "arrow/buffer.h"
+#include "arrow/flight/transport/ucx/util_internal.h"
+#include "arrow/flight/types.h"
+#include "arrow/util/base64.h"
+#include "arrow/util/bit_util.h"
+#include "arrow/util/logging.h"
+#include "arrow/util/make_unique.h"
+#include "arrow/util/uri.h"
+
+namespace arrow {
+namespace flight {
+namespace transport {
+namespace ucx {
+
+// Compile-time switches for trying out alternative send strategies. Both are
+// off by default; uncomment the corresponding #define below to enable one.
+// - ARROW_FLIGHT_UCX_SEND_CONTIG: use the CONTIG path for CPU-only data.
+// #define ARROW_FLIGHT_UCX_SEND_CONTIG
+// - ARROW_FLIGHT_UCX_SEND_IOV_MAP: call ucp_mem_map on buffers in the IOV path.
+// #define ARROW_FLIGHT_UCX_SEND_IOV_MAP
+
+// Header key for the RPC method name. NOTE(review): meaning inferred from the
+// name; its uses are outside this excerpt.
+constexpr char kHeaderMethod[] = ":method:";
+
+namespace {
+/// \brief Encode `in` into `out` as a 4-byte big-endian unsigned integer.
+///
+/// \return Invalid if `in` is negative or larger than UINT32_MAX; `out` is
+///   only written on success.
+Status SizeToUInt32BytesBe(const int64_t in, uint8_t* out) {
+  if (ARROW_PREDICT_FALSE(in < 0)) {
+    return Status::Invalid("Length cannot be negative");
+  } else if (ARROW_PREDICT_FALSE(
+                 in > static_cast<int64_t>(std::numeric_limits<uint32_t>::max()))) {
+    return Status::Invalid("Length cannot exceed uint32_t");
+  }
+  UInt32ToBytesBe(static_cast<uint32_t>(in), out);
+  return Status::OK();
+}
+
+/// \brief Choose the UCX memory-type hint for a buffer.
+///
+/// CPU buffers are reported as UCS_MEMORY_TYPE_UNKNOWN (presumably leaving
+/// detection to UCX). NOTE(review): every non-CPU buffer is assumed to be CUDA
+/// memory; other device types would be mislabeled.
+ucs_memory_type InferMemoryType(const Buffer& buffer) {
+  if (!buffer.is_cpu()) {
+    return UCS_MEMORY_TYPE_CUDA;
+  }
+  return UCS_MEMORY_TYPE_UNKNOWN;
+}
+
+/// \brief Best-effort registration of [buffer, buffer + size) with UCX.
+///
+/// On failure a warning is logged and *memh_p is set to nullptr, so the
+/// handle can always be passed to TryUnmapBuffer afterwards.
+void TryMapBuffer(ucp_context_h context, const void* buffer, const size_t size,
+                  ucs_memory_type memory_type, ucp_mem_h* memh_p) {
+  ucp_mem_map_params_t map_param;
+  std::memset(&map_param, 0, sizeof(map_param));
+  map_param.field_mask = UCP_MEM_MAP_PARAM_FIELD_ADDRESS |
+                         UCP_MEM_MAP_PARAM_FIELD_LENGTH |
+                         UCP_MEM_MAP_PARAM_FIELD_MEMORY_TYPE;
+  // The params struct holds a non-const pointer.
+  map_param.address = const_cast<void*>(buffer);
+  map_param.length = size;
+  map_param.memory_type = memory_type;
+  auto ucs_status = ucp_mem_map(context, &map_param, memh_p);
+  if (ucs_status != UCS_OK) {
+    *memh_p = nullptr;
+    ARROW_LOG(WARNING) << "Could not map memory: "
+                       << FromUcsStatus("ucp_mem_map", ucs_status);
+  }
+}
+
+/// \brief Map the whole of `buffer`, inferring its UCX memory type.
+///
+/// NOTE(review): address() is presumably used rather than data() so that
+/// non-CPU buffers can be mapped too.
+void TryMapBuffer(ucp_context_h context, const Buffer& buffer, ucp_mem_h* memh_p) {
+  TryMapBuffer(context, reinterpret_cast<void*>(buffer.address()),
+               static_cast<size_t>(buffer.size()), InferMemoryType(buffer), memh_p);
+}
+
+/// \brief Release a handle obtained from TryMapBuffer; a null handle is a no-op.
+///
+/// Failures are logged as warnings rather than returned.
+void TryUnmapBuffer(ucp_context_h context, ucp_mem_h memh_p) {
+  if (memh_p) {
+    auto ucs_status = ucp_mem_unmap(context, memh_p);
+    if (ucs_status != UCS_OK) {
+      ARROW_LOG(WARNING) << "Could not unmap memory: "
+                         << FromUcsStatus("ucp_mem_unmap", ucs_status);
+    }
+  }
+}
+
+/// \brief Wrapper around a UCX zero copy buffer (a host memory DATA
+///   buffer).
+///
+/// Holds a shared reference to the associated worker so that the worker stays
+/// alive until the data is handed back with ucp_am_data_release() in the
+/// destructor.
+class UcxDataBuffer : public Buffer {
+ public:
+  explicit UcxDataBuffer(std::shared_ptr<UcpWorker> worker, void* data, size_t size)
+      : Buffer(static_cast<const uint8_t*>(data), static_cast<int64_t>(size)),
+        worker_(std::move(worker)) {}
+
+  ~UcxDataBuffer() override {
+    ucp_am_data_release(worker_->get(), const_cast<uint8_t*>(data()));
+  }
+
+ private:
+  std::shared_ptr<UcpWorker> worker_;
+};
+}  // namespace
+
+// Out-of-line definitions of FrameHeader's static constexpr members. Before
+// C++17, a static constexpr data member that is ODR-used (for example, bound
+// to a reference parameter) needs a namespace-scope definition like these.
+constexpr size_t FrameHeader::kFrameHeaderBytes;
+constexpr uint8_t FrameHeader::kFrameVersion;
+
+/// \brief Fill in the fixed-size frame header.
+///
+/// Layout written: byte 0 = kFrameVersion, byte 1 = frame type, bytes 4-7 =
+/// counter (big-endian), bytes 8-11 = body size (big-endian).
+/// NOTE(review): bytes 2-3 are not written here (Frame::ParseHeader does not
+/// read them either); confirm they are initialized where `header` is declared.
+///
+/// \return Invalid if body_size is negative or does not fit in uint32_t, in
+///   which case the header is left unmodified.
+Status FrameHeader::Set(FrameType frame_type, uint32_t counter, int64_t body_size) {
+  // Validate/encode the size first: SizeToUInt32BytesBe writes nothing when it
+  // rejects the value, so a bad size cannot leave the header half-updated.
+  RETURN_NOT_OK(SizeToUInt32BytesBe(body_size, header.data() + 8));
+  header[0] = kFrameVersion;
+  header[1] = static_cast<uint8_t>(frame_type);
+  UInt32ToBytesBe(counter, header.data() + 4);
+  return Status::OK();
+}
+
+/// \brief Decode a fixed-size frame header into a Frame with no body attached.
+///
+/// \return IOError for a short header, a version mismatch, or an unknown frame
+///   type; Cancelled if the header announces a disconnect.
+arrow::Result<std::shared_ptr<Frame>> Frame::ParseHeader(const void* header,
+                                                         size_t header_length) {
+  if (header_length < FrameHeader::kFrameHeaderBytes) {
+    return Status::IOError("Header is too short, must be at least ",
+                           FrameHeader::kFrameHeaderBytes, " bytes, got ", header_length);
+  }
+
+  const auto* bytes = static_cast<const uint8_t*>(header);
+  const uint8_t version = bytes[0];
+  const uint8_t raw_type = bytes[1];
+
+  // Byte 0: protocol version, which must match exactly.
+  if (version != FrameHeader::kFrameVersion) {
+    return Status::IOError("Expected frame version ",
+                           static_cast<int>(FrameHeader::kFrameVersion), " but got ",
+                           static_cast<int>(version));
+  }
+  // Byte 1: frame type; values past the last known type are rejected.
+  if (raw_type > static_cast<uint8_t>(FrameType::kMaxFrameType)) {
+    return Status::IOError("Unknown frame type ", static_cast<int>(raw_type));
+  }
+
+  const auto frame_type = static_cast<FrameType>(raw_type);
+  // Disconnect frames are reported as a Cancelled status, not as a Frame.
+  if (frame_type == FrameType::kDisconnect) {
+    return Status::Cancelled("Client initiated disconnect");
+  }
+
+  // Bytes 4-7: counter; bytes 8-11: body size (both big-endian).
+  const uint32_t counter = BytesToUInt32Be(bytes + 4);
+  const uint32_t body_size = BytesToUInt32Be(bytes + 8);
+  return std::make_shared<Frame>(frame_type, body_size, counter, nullptr);
+}
+
+/// \brief Decode a headers payload, taking ownership of `buffer`.
+///
+/// Wire format: a u32 header count, then per header a u32 key length, a u32
+/// value length, the key bytes and the value bytes (integers big-endian).
+/// The returned frame's headers are views into `buffer`, which the frame keeps.
+///
+/// \return Invalid if `buffer` is null or too short for the lengths it declares.
+arrow::Result<HeadersFrame> HeadersFrame::Parse(std::unique_ptr<Buffer> buffer) {
+  if (ARROW_PREDICT_FALSE(!buffer)) {
+    return Status::Invalid("Cannot parse headers from a null buffer");
+  }
+  HeadersFrame result;
+  const uint8_t* payload = buffer->data();
+  const uint8_t* const end = payload + buffer->size();
+  // Bytes not yet consumed. Kept unsigned so that comparisons against the
+  // uint32_t lengths decoded below do not mix signed and unsigned operands.
+  const auto remaining = [&payload, end]() -> size_t {
+    return static_cast<size_t>(end - payload);
+  };
+
+  if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+    return Status::Invalid("Buffer underflow, expected number of headers");
+  }
+  const uint32_t num_headers = BytesToUInt32Be(payload);
+  payload += 4;
+  // Every count and length is untrusted input and is bounds-checked before use.
+  for (uint32_t i = 0; i < num_headers; i++) {
+    if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of key ", i + 1);
+    }
+    const uint32_t key_length = BytesToUInt32Be(payload);
+    payload += 4;
+
+    if (ARROW_PREDICT_FALSE(remaining() < 4)) {
+      return Status::Invalid("Buffer underflow, expected length of value ", i + 1);
+    }
+    const uint32_t value_length = BytesToUInt32Be(payload);
+    payload += 4;
+
+    if (ARROW_PREDICT_FALSE(remaining() < key_length)) {
+      return Status::Invalid("Buffer underflow, expected key ", i + 1, " to have length ",
+                             key_length, ", but only ", remaining(), " bytes remain");
+    }
+    const util::string_view key(reinterpret_cast<const char*>(payload), key_length);
+    payload += key_length;
+
+    if (ARROW_PREDICT_FALSE(remaining() < value_length)) {
+      return Status::Invalid("Buffer underflow, expected value ", i + 1,
+                             " to have length ", value_length, ", but only ",
+                             remaining(), " bytes remain");
+    }
+    const util::string_view value(reinterpret_cast<const char*>(payload), value_length);
+    payload += value_length;
+    // The views point into *buffer; moving the unique_ptr into result.buffer_
+    // below keeps the same Buffer object, so they stay valid with `result`.
+    result.headers_.emplace_back(key, value);
+  }
+
+  result.buffer_ = std::move(buffer);
+  return result;
+}
+/// \brief Encode `headers` into a new buffer and return the parsed frame.
+///
+/// \return Invalid if the header count or any key/value length exceeds
+///   uint32_t; allocation errors are propagated.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  // Encoded size: a u32 header count, then per header a u32 key length, a u32
+  // value length, and the raw key and value bytes. Accumulated in 64 bits so a
+  // large header set cannot wrap the total and under-allocate the buffer that
+  // the memcpy calls below fill; each individual length is still range-checked
+  // by SizeToUInt32BytesBe before any of its bytes are copied.
+  int64_t total_length = 4 /* # of headers */;
+  for (const auto& header : headers) {
+    total_length += 4 /* key length */ + 4 /* value length */ +
+                    static_cast<int64_t>(header.first.size()) /* key */ +
+                    static_cast<int64_t>(header.second.size()) /* value */;
+  }
+
+  ARROW_ASSIGN_OR_RAISE(auto buffer, AllocateBuffer(total_length));
+  uint8_t* payload = buffer->mutable_data();
+
+  RETURN_NOT_OK(SizeToUInt32BytesBe(static_cast<int64_t>(headers.size()), payload));
+  payload += 4;
+  for (const auto& header : headers) {
+    RETURN_NOT_OK(
+        SizeToUInt32BytesBe(static_cast<int64_t>(header.first.size()), payload));
+    payload += 4;
+    RETURN_NOT_OK(
+        SizeToUInt32BytesBe(static_cast<int64_t>(header.second.size()), payload));
+    payload += 4;
+    std::memcpy(payload, header.first.data(), header.first.size());
+    payload += header.first.size();
+    std::memcpy(payload, header.second.data(), header.second.size());
+    payload += header.second.size();
+  }
+  // Parse the encoded bytes so the returned frame owns the buffer and its
+  // header views point into it.
+  return Parse(std::move(buffer));
+}
+/// \brief Build a headers frame carrying `headers` plus the fields of `status`.
+///
+/// Always appends the status code and message. If the status has a detail, a
+/// FlightStatusDetail is sent as its code plus extra info; any other detail is
+/// sent as its ToString() form.
+arrow::Result<HeadersFrame> HeadersFrame::Make(
+    const Status& status,
+    const std::vector<std::pair<std::string, std::string>>& headers) {
+  std::vector<std::pair<std::string, std::string>> all_headers(headers);
+  all_headers.emplace_back(kHeaderStatusCode,
+                           std::to_string(static_cast<int32_t>(status.code())));
+  all_headers.emplace_back(kHeaderStatusMessage, status.message());
+
+  const auto& detail = status.detail();
+  if (!detail) {
+    return Make(all_headers);
+  }
+
+  auto flight_detail = FlightStatusDetail::UnwrapStatus(status);
+  if (flight_detail) {
+    all_headers.emplace_back(
+        kHeaderStatusDetailCode,
+        std::to_string(static_cast<int32_t>(flight_detail->code())));
+    all_headers.emplace_back(kHeaderStatusDetail, flight_detail->extra_info());
+  } else {
+    all_headers.emplace_back(kHeaderStatusDetail, detail->ToString());
+  }
+  return Make(all_headers);
+}
+
+/// \brief Look up a header by exact name.
+///
+/// \return the value of the first header whose name equals `key`, or KeyError
+///   if none matches. The view refers to storage owned by this frame.
+arrow::Result<util::string_view> HeadersFrame::Get(const std::string& key) {
+  for (auto it = headers_.begin(); it != headers_.end(); ++it) {
+    if (it->first == key) {
+      return it->second;
+    }
+  }
+  return Status::KeyError(key);
+}
+
+Status HeadersFrame::GetStatus(Status* out) {
+  util::string_view code_str, message_str;
+  auto status = Get(kHeaderStatusCode).Value(&code_str);
+  if (!status.ok()) {
+    return Status::KeyError("Server did not send status code header ", kHeaderStatusCode);
+  }
+
+  StatusCode status_code = StatusCode::OK;
+  auto code = std::strtol(code_str.data(), nullptr, /*base=*/10);
+  switch (code) {
+    case 0:

Review comment:
       These are the Arrow status codes, actually. However, sending Arrow status codes isn't portable across Arrow implementations. Flight does define some standard codes; I've started untangling this in a separate PR and will follow up to make gRPC and UCX consistent: https://github.com/apache/arrow/pull/12749




-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscribe@arrow.apache.org

For queries about this service, please contact Infrastructure at:
users@infra.apache.org