Posted to github@arrow.apache.org by GitBox <gi...@apache.org> on 2022/10/26 15:16:21 UTC

[GitHub] [arrow] pitrou commented on a diff in pull request #14266: ARROW-17867: [C++][FlightRPC] Expose bulk parameter binding in Flight SQL

pitrou commented on code in PR #14266:
URL: https://github.com/apache/arrow/pull/14266#discussion_r1005675232


##########
cpp/src/arrow/flight/sql/client.h:
##########
@@ -392,17 +392,18 @@ class ARROW_FLIGHT_SQL_EXPORT PreparedStatement {
 
   /// \brief Retrieve the parameter schema from the query.
   /// \return The parameter schema from the query.
-  std::shared_ptr<Schema> parameter_schema() const;
+  [[nodiscard]] const std::shared_ptr<Schema>& parameter_schema() const;

Review Comment:
   Is the `nodiscard` meant to protect against some incorrect usage here?
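   For reference, a minimal self-contained sketch of the misuse `[[nodiscard]]` flags (the `Statement`/`Schema` types below are hypothetical stand-ins, not the Flight SQL classes):

```cpp
#include <memory>

struct Schema {};

class Statement {
 public:
  // [[nodiscard]] asks the compiler to warn when the result is ignored.
  [[nodiscard]] const std::shared_ptr<Schema>& parameter_schema() const {
    return schema_;
  }

 private:
  std::shared_ptr<Schema> schema_ = std::make_shared<Schema>();
};

int main() {
  Statement stmt;
  stmt.parameter_schema();                // warning: ignoring nodiscard return value
  auto schema = stmt.parameter_schema();  // fine: result is consumed
  (void)schema;
  return 0;
}
```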



##########
cpp/src/arrow/flight/sql/client.cc:
##########
@@ -574,26 +584,21 @@ arrow::Result<int64_t> PreparedStatement::ExecuteUpdate(
   command.set_prepared_statement_handle(handle_);
   ARROW_ASSIGN_OR_RAISE(FlightDescriptor descriptor,
                         GetFlightDescriptorForCommand(command));
-  std::unique_ptr<FlightStreamWriter> writer;
-  std::unique_ptr<FlightMetadataReader> reader;
-
-  if (parameter_binding_ && parameter_binding_->num_rows() > 0) {
-    ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, parameter_binding_->schema(),
-                                       &writer, &reader));
-    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*parameter_binding_));
+  std::shared_ptr<Buffer> metadata;
+  if (parameter_binding_) {
+    ARROW_ASSIGN_OR_RAISE(metadata, BindParameters(client_->impl_.get(), options,
+                                                   descriptor, parameter_binding_.get()));
   } else {
     const std::shared_ptr<Schema> schema = arrow::schema({});
-    ARROW_RETURN_NOT_OK(client_->DoPut(options, descriptor, schema, &writer, &reader));
-    const ArrayVector columns;
-    const auto& record_batch = arrow::RecordBatch::Make(schema, 0, columns);
-    ARROW_RETURN_NOT_OK(writer->WriteRecordBatch(*record_batch));
+    auto record_batch = arrow::RecordBatch::Make(schema, 0, ArrayVector{});
+    ARROW_ASSIGN_OR_RAISE(auto params,
+                          RecordBatchReader::Make({std::move(record_batch)}));

Review Comment:
   Is it part of the spec that we have to send a single 0-row batch here?



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -346,18 +331,17 @@ class SQLiteFlightSqlServer::Impl {
 
   arrow::Result<std::unique_ptr<FlightDataStream>> DoGetCatalogs(
       const ServerCallContext& context) {
-    // As SQLite doesn't support catalogs, this will return an empty record batch.
-
+    // https://www.sqlite.org/cli.html
+    // > The ".databases" command shows a list of all databases open
+    // > in the current connection. There will always be at least
+    // > 2. The first one is "main", the original database opened.

Review Comment:
   This comment is a bit confusing, because we return a single row ("main") here.



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -77,50 +84,25 @@ std::string PrepareQueryForGetTables(const GetTables& command) {
   return table_query.str();
 }
 
-Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) {
+template <typename Callback>
+Status SetParametersOnSQLiteStatement(SqliteStatement* statement,
+                                      FlightMessageReader* reader, Callback callback) {
+  sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
   while (true) {
     ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
-    std::shared_ptr<RecordBatch>& record_batch = chunk.data;
-    if (record_batch == nullptr) break;
+    if (chunk.data == nullptr) break;
 
-    const int64_t num_rows = record_batch->num_rows();
-    const int& num_columns = record_batch->num_columns();
+    const int64_t num_rows = chunk.data->num_rows();
+    if (num_rows == 0) continue;
 
+    ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)}));
     for (int i = 0; i < num_rows; ++i) {
-      for (int c = 0; c < num_columns; ++c) {
-        const std::shared_ptr<Array>& column = record_batch->column(c);
-        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, column->GetScalar(i));
-
-        auto& holder = static_cast<DenseUnionScalar&>(*scalar).value;
-
-        switch (holder->type->id()) {
-          case Type::INT64: {
-            int64_t value = static_cast<Int64Scalar&>(*holder).value;
-            sqlite3_bind_int64(stmt, c + 1, value);
-            break;
-          }
-          case Type::FLOAT: {
-            double value = static_cast<FloatScalar&>(*holder).value;
-            sqlite3_bind_double(stmt, c + 1, value);
-            break;
-          }
-          case Type::STRING: {
-            std::shared_ptr<Buffer> buffer = static_cast<StringScalar&>(*holder).value;
-            sqlite3_bind_text(stmt, c + 1, reinterpret_cast<const char*>(buffer->data()),
-                              static_cast<int>(buffer->size()), SQLITE_TRANSIENT);
-            break;
-          }
-          case Type::BINARY: {
-            std::shared_ptr<Buffer> buffer = static_cast<BinaryScalar&>(*holder).value;
-            sqlite3_bind_blob(stmt, c + 1, buffer->data(),
-                              static_cast<int>(buffer->size()), SQLITE_TRANSIENT);
-            break;
-          }
-          default:
-            return Status::Invalid("Received unsupported data type: ",
-                                   holder->type->ToString());
-        }
+      if (sqlite3_clear_bindings(stmt) != SQLITE_OK) {
+        return Status::Invalid("Failed to reset bindings on row ", i, ": ",
+                               sqlite3_errmsg(statement->db()));
       }
+      ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i));
+      ARROW_RETURN_NOT_OK(callback());

Review Comment:
   Hmm, why call the callback for each row?
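   For context, a runnable schematic of the bind/execute/reset cycle SQLite expects when a prepared statement is driven once per parameter row (error handling elided; this is not the server code itself):

```cpp
#include <sqlite3.h>
#include <cstdint>
#include <cstdio>

// Each parameter row is a separate execution of the statement: bind the row,
// step the statement to completion, then reset it before binding the next row.
int main() {
  sqlite3* db = nullptr;
  sqlite3_open(":memory:", &db);
  sqlite3_exec(db, "CREATE TABLE t (x INTEGER)", nullptr, nullptr, nullptr);

  sqlite3_stmt* stmt = nullptr;
  sqlite3_prepare_v2(db, "INSERT INTO t VALUES (?)", -1, &stmt, nullptr);

  const int64_t values[] = {1, 2, 3};
  for (int64_t value : values) {
    sqlite3_clear_bindings(stmt);        // drop the previous row's bindings
    sqlite3_bind_int64(stmt, 1, value);  // bind this row's parameter
    while (sqlite3_step(stmt) == SQLITE_ROW) {
      // consume any result rows (an INSERT produces none)
    }
    sqlite3_reset(stmt);                 // rearm the statement for the next row
  }

  std::printf("inserted %d rows\n", sqlite3_total_changes(db));
  sqlite3_finalize(stmt);
  sqlite3_close(db);
  return 0;
}
```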



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.cc:
##########
@@ -160,10 +166,85 @@ arrow::Result<int> SqliteStatement::Reset() {
 sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; }
 
 arrow::Result<int64_t> SqliteStatement::ExecuteUpdate() {
-  ARROW_RETURN_NOT_OK(Step());
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(int rc, Step());
+    if (rc == SQLITE_DONE) break;
+  }
   return sqlite3_changes(db_);
 }
 
+Status SqliteStatement::SetParameters(
+    std::vector<std::shared_ptr<arrow::RecordBatch>> parameters) {
+  const int num_params = sqlite3_bind_parameter_count(stmt_);
+  for (const auto& batch : parameters) {
+    if (batch->num_columns() != num_params) {
+      return Status::Invalid("Expected ", num_params, " parameters, but got ",
+                             batch->num_columns());
+    }
+  }
+  parameters_ = std::move(parameters);
+  auto end = std::remove_if(
+      parameters_.begin(), parameters_.end(),
+      [](const std::shared_ptr<RecordBatch>& batch) { return batch->num_rows() == 0; });
+  parameters_.erase(end, parameters_.end());
+  return Status::OK();
+}
+
+Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) {
+  if (batch_index >= parameters_.size()) {
+    return Status::Invalid("Cannot bind to batch ", batch_index);

Review Comment:
   `IndexError` perhaps?
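   For illustration, a small self-contained sketch of the suggested status code (the `BindBatch` helper here is hypothetical, just to show the shape):

```cpp
#include <cstddef>
#include <iostream>

#include "arrow/status.h"

// Sketch: an out-of-range batch index reported as Status::IndexError rather
// than the generic Status::Invalid.
arrow::Status BindBatch(size_t batch_index, size_t num_batches) {
  if (batch_index >= num_batches) {
    return arrow::Status::IndexError("Cannot bind to batch ", batch_index);
  }
  return arrow::Status::OK();
}

int main() {
  std::cout << BindBatch(/*batch_index=*/3, /*num_batches=*/1).ToString()
            << std::endl;  // prints an IndexError status
  return 0;
}
```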



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.cc:
##########
@@ -160,10 +166,85 @@ arrow::Result<int> SqliteStatement::Reset() {
 sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; }
 
 arrow::Result<int64_t> SqliteStatement::ExecuteUpdate() {
-  ARROW_RETURN_NOT_OK(Step());
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(int rc, Step());
+    if (rc == SQLITE_DONE) break;
+  }
   return sqlite3_changes(db_);
 }
 
+Status SqliteStatement::SetParameters(
+    std::vector<std::shared_ptr<arrow::RecordBatch>> parameters) {
+  const int num_params = sqlite3_bind_parameter_count(stmt_);
+  for (const auto& batch : parameters) {
+    if (batch->num_columns() != num_params) {
+      return Status::Invalid("Expected ", num_params, " parameters, but got ",
+                             batch->num_columns());
+    }
+  }
+  parameters_ = std::move(parameters);
+  auto end = std::remove_if(
+      parameters_.begin(), parameters_.end(),
+      [](const std::shared_ptr<RecordBatch>& batch) { return batch->num_rows() == 0; });
+  parameters_.erase(end, parameters_.end());
+  return Status::OK();
+}
+
+Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) {
+  if (batch_index >= parameters_.size()) {
+    return Status::Invalid("Cannot bind to batch ", batch_index);
+  }
+  const RecordBatch& batch = *parameters_[batch_index];
+  if (row_index < 0 || row_index >= batch.num_rows()) {
+    return Status::Invalid("Cannot bind to row ", row_index, " in batch ", batch_index);

Review Comment:
   Same here.



##########
cpp/src/arrow/flight/sql/server_test.cc:
##########
@@ -502,51 +489,53 @@ TEST_F(TestFlightSqlServer, TestCommandPreparedStatementQueryWithParameterBindin
       auto prepared_statement,
       sql_client->Prepare({}, "SELECT * FROM intTable WHERE keyName LIKE ?"));
 
-  auto parameter_schema = prepared_statement->parameter_schema();
-
+  const std::shared_ptr<Schema>& parameter_schema =
+      prepared_statement->parameter_schema();
   const std::shared_ptr<Schema>& expected_parameter_schema =
       arrow::schema({arrow::field("parameter_1", example::GetUnknownColumnDataType())});
+  ASSERT_NO_FATAL_FAILURE(AssertSchemaEqual(expected_parameter_schema, parameter_schema));
 
-  AssertSchemaEqual(expected_parameter_schema, parameter_schema);
-
-  std::shared_ptr<Array> type_ids = ArrayFromJSON(int8(), R"([0])");
-  std::shared_ptr<Array> offsets = ArrayFromJSON(int32(), R"([0])");
-  std::shared_ptr<Array> string_array = ArrayFromJSON(utf8(), R"(["%one"])");
-  std::shared_ptr<Array> bytes_array = ArrayFromJSON(binary(), R"([])");
-  std::shared_ptr<Array> bigint_array = ArrayFromJSON(int64(), R"([])");
-  std::shared_ptr<Array> double_array = ArrayFromJSON(float64(), R"([])");
-
-  ASSERT_OK_AND_ASSIGN(
-      auto parameter_1_array,
-      DenseUnionArray::Make(*type_ids, *offsets,
-                            {string_array, bytes_array, bigint_array, double_array},
-                            {"string", "bytes", "bigint", "double"}, {0, 1, 2, 3}));
-
-  const std::shared_ptr<RecordBatch>& record_batch =
-      RecordBatch::Make(parameter_schema, 1, {parameter_1_array});
-
-  ASSERT_OK(prepared_statement->SetParameters(record_batch));
+  auto record_batch = RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]] ])");
+  ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch)));
 
   ASSERT_OK_AND_ASSIGN(auto flight_info, prepared_statement->Execute());
-
   ASSERT_OK_AND_ASSIGN(auto stream,
                        sql_client->DoGet({}, flight_info->endpoints()[0].ticket));
-
   ASSERT_OK_AND_ASSIGN(auto table, stream->ToTable());
 
   const std::shared_ptr<Schema>& expected_schema =
       arrow::schema({arrow::field("id", int64()), arrow::field("keyName", utf8()),
                      arrow::field("value", int64()), arrow::field("foreignId", int64())});
 
-  const auto id_array = ArrayFromJSON(int64(), R"([1, 3])");
-  const auto keyname_array = ArrayFromJSON(utf8(), R"(["one", "negative one"])");
-  const auto value_array = ArrayFromJSON(int64(), R"([1, -1])");
-  const auto foreignId_array = ArrayFromJSON(int64(), R"([1, 1])");
-
-  const std::shared_ptr<Table>& expected_table = Table::Make(
-      expected_schema, {id_array, keyname_array, value_array, foreignId_array});
-
-  AssertTablesEqual(*expected_table, *table);
+  auto expected_table = TableFromJSON(expected_schema, {R"([
+      [1, "one", 1, 1],
+      [3, "negative one", -1, 1]
+  ])"});
+  ASSERT_NO_FATAL_FAILURE(AssertTablesEqual(*expected_table, *table, /*verbose=*/true));
+
+  // Set multiple parameters at once
+  record_batch =
+      RecordBatchFromJSON(parameter_schema, R"([ [[0, "%one"]], [[0, "%zero"]] ])");
+  ASSERT_OK(prepared_statement->SetParameters(std::move(record_batch)));
+  ASSERT_OK_AND_ASSIGN(flight_info, prepared_statement->Execute());
+  ASSERT_OK_AND_ASSIGN(stream, sql_client->DoGet({}, flight_info->endpoints()[0].ticket));
+  ASSERT_OK_AND_ASSIGN(table, stream->ToTable());
+  expected_table = TableFromJSON(expected_schema, {R"([
+      [1, "one", 1, 1],
+      [3, "negative one", -1, 1],
+      [2, "zero", 0, 1]

Review Comment:
   For the record, are there any tests emitting/receiving nulls?



##########
cpp/src/arrow/flight/sql/example/sqlite_server.cc:
##########
@@ -77,50 +84,25 @@ std::string PrepareQueryForGetTables(const GetTables& command) {
   return table_query.str();
 }
 
-Status SetParametersOnSQLiteStatement(sqlite3_stmt* stmt, FlightMessageReader* reader) {
+template <typename Callback>
+Status SetParametersOnSQLiteStatement(SqliteStatement* statement,
+                                      FlightMessageReader* reader, Callback callback) {
+  sqlite3_stmt* stmt = statement->GetSqlite3Stmt();
   while (true) {
     ARROW_ASSIGN_OR_RAISE(FlightStreamChunk chunk, reader->Next());
-    std::shared_ptr<RecordBatch>& record_batch = chunk.data;
-    if (record_batch == nullptr) break;
+    if (chunk.data == nullptr) break;
 
-    const int64_t num_rows = record_batch->num_rows();
-    const int& num_columns = record_batch->num_columns();
+    const int64_t num_rows = chunk.data->num_rows();
+    if (num_rows == 0) continue;
 
+    ARROW_RETURN_NOT_OK(statement->SetParameters({std::move(chunk.data)}));
     for (int i = 0; i < num_rows; ++i) {
-      for (int c = 0; c < num_columns; ++c) {
-        const std::shared_ptr<Array>& column = record_batch->column(c);
-        ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, column->GetScalar(i));
-
-        auto& holder = static_cast<DenseUnionScalar&>(*scalar).value;
-
-        switch (holder->type->id()) {
-          case Type::INT64: {
-            int64_t value = static_cast<Int64Scalar&>(*holder).value;
-            sqlite3_bind_int64(stmt, c + 1, value);
-            break;
-          }
-          case Type::FLOAT: {
-            double value = static_cast<FloatScalar&>(*holder).value;
-            sqlite3_bind_double(stmt, c + 1, value);
-            break;
-          }
-          case Type::STRING: {
-            std::shared_ptr<Buffer> buffer = static_cast<StringScalar&>(*holder).value;
-            sqlite3_bind_text(stmt, c + 1, reinterpret_cast<const char*>(buffer->data()),
-                              static_cast<int>(buffer->size()), SQLITE_TRANSIENT);
-            break;
-          }
-          case Type::BINARY: {
-            std::shared_ptr<Buffer> buffer = static_cast<BinaryScalar&>(*holder).value;
-            sqlite3_bind_blob(stmt, c + 1, buffer->data(),
-                              static_cast<int>(buffer->size()), SQLITE_TRANSIENT);
-            break;
-          }
-          default:
-            return Status::Invalid("Received unsupported data type: ",
-                                   holder->type->ToString());
-        }
+      if (sqlite3_clear_bindings(stmt) != SQLITE_OK) {
+        return Status::Invalid("Failed to reset bindings on row ", i, ": ",
+                               sqlite3_errmsg(statement->db()));
       }
+      ARROW_RETURN_NOT_OK(statement->Bind(/*batch_index=*/0, i));

Review Comment:
   Is `/*batch_index=*/0` expected here?



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.cc:
##########
@@ -160,10 +166,85 @@ arrow::Result<int> SqliteStatement::Reset() {
 sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; }
 
 arrow::Result<int64_t> SqliteStatement::ExecuteUpdate() {
-  ARROW_RETURN_NOT_OK(Step());
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(int rc, Step());
+    if (rc == SQLITE_DONE) break;
+  }
   return sqlite3_changes(db_);
 }
 
+Status SqliteStatement::SetParameters(
+    std::vector<std::shared_ptr<arrow::RecordBatch>> parameters) {
+  const int num_params = sqlite3_bind_parameter_count(stmt_);
+  for (const auto& batch : parameters) {
+    if (batch->num_columns() != num_params) {
+      return Status::Invalid("Expected ", num_params, " parameters, but got ",
+                             batch->num_columns());
+    }
+  }
+  parameters_ = std::move(parameters);
+  auto end = std::remove_if(
+      parameters_.begin(), parameters_.end(),
+      [](const std::shared_ptr<RecordBatch>& batch) { return batch->num_rows() == 0; });
+  parameters_.erase(end, parameters_.end());
+  return Status::OK();
+}
+
+Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) {
+  if (batch_index >= parameters_.size()) {
+    return Status::Invalid("Cannot bind to batch ", batch_index);
+  }
+  const RecordBatch& batch = *parameters_[batch_index];
+  if (row_index < 0 || row_index >= batch.num_rows()) {
+    return Status::Invalid("Cannot bind to row ", row_index, " in batch ", batch_index);
+  }
+
+  if (sqlite3_clear_bindings(stmt_) != SQLITE_OK) {
+    return Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_));
+  }
+  for (int c = 0; c < batch.num_columns(); ++c) {
+    const std::shared_ptr<Array>& column = batch.column(c);
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, column->GetScalar(row_index));

Review Comment:
   Instead of using `GetScalar` and then switching on the type, wouldn't it be more efficient to first switch on the type and then access the concrete array directly?
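   For illustration, roughly the shape that would take for a non-union column (a sketch only, following the names in the surrounding code; the dense-union case would additionally resolve the child array via `DenseUnionArray::child_id()`/`value_offset()`):

```cpp
// Sketch: dispatch on the column type once, then read values straight from the
// concrete array, avoiding a Scalar allocation per cell.
switch (column->type_id()) {
  case Type::INT64: {
    const auto& values = checked_cast<const Int64Array&>(*column);
    if (values.IsNull(row_index)) {
      rc = sqlite3_bind_null(stmt_, c + 1);
    } else {
      rc = sqlite3_bind_int64(stmt_, c + 1, values.Value(row_index));
    }
    break;
  }
  // STRING, FLOAT, BINARY would be handled analogously via
  // StringArray/FloatArray/BinaryArray.
  default:
    return Status::TypeError("Received unsupported data type: ", *column->type());
}
```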



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.h:
##########
@@ -62,15 +62,25 @@ class SqliteStatement {
 
   /// \brief Returns the underlying sqlite3_stmt.
   /// \return A sqlite statement.
-  sqlite3_stmt* GetSqlite3Stmt() const;
+  [[nodiscard]] sqlite3_stmt* GetSqlite3Stmt() const;

Review Comment:
   Not sure why `nodiscard` is added on these?



##########
cpp/src/arrow/flight/sql/example/sqlite_statement.cc:
##########
@@ -160,10 +166,85 @@ arrow::Result<int> SqliteStatement::Reset() {
 sqlite3_stmt* SqliteStatement::GetSqlite3Stmt() const { return stmt_; }
 
 arrow::Result<int64_t> SqliteStatement::ExecuteUpdate() {
-  ARROW_RETURN_NOT_OK(Step());
+  while (true) {
+    ARROW_ASSIGN_OR_RAISE(int rc, Step());
+    if (rc == SQLITE_DONE) break;
+  }
   return sqlite3_changes(db_);
 }
 
+Status SqliteStatement::SetParameters(
+    std::vector<std::shared_ptr<arrow::RecordBatch>> parameters) {
+  const int num_params = sqlite3_bind_parameter_count(stmt_);
+  for (const auto& batch : parameters) {
+    if (batch->num_columns() != num_params) {
+      return Status::Invalid("Expected ", num_params, " parameters, but got ",
+                             batch->num_columns());
+    }
+  }
+  parameters_ = std::move(parameters);
+  auto end = std::remove_if(
+      parameters_.begin(), parameters_.end(),
+      [](const std::shared_ptr<RecordBatch>& batch) { return batch->num_rows() == 0; });
+  parameters_.erase(end, parameters_.end());
+  return Status::OK();
+}
+
+Status SqliteStatement::Bind(size_t batch_index, int64_t row_index) {
+  if (batch_index >= parameters_.size()) {
+    return Status::Invalid("Cannot bind to batch ", batch_index);
+  }
+  const RecordBatch& batch = *parameters_[batch_index];
+  if (row_index < 0 || row_index >= batch.num_rows()) {
+    return Status::Invalid("Cannot bind to row ", row_index, " in batch ", batch_index);
+  }
+
+  if (sqlite3_clear_bindings(stmt_) != SQLITE_OK) {
+    return Status::Invalid("Failed to reset bindings: ", sqlite3_errmsg(db_));
+  }
+  for (int c = 0; c < batch.num_columns(); ++c) {
+    const std::shared_ptr<Array>& column = batch.column(c);
+    ARROW_ASSIGN_OR_RAISE(std::shared_ptr<Scalar> scalar, column->GetScalar(row_index));
+    if (scalar->type->id() == Type::DENSE_UNION) {
+      scalar = checked_cast<DenseUnionScalar&>(*scalar).value;
+    }
+
+    int rc = 0;
+    if (!scalar->is_valid) {
+      rc = sqlite3_bind_null(stmt_, c + 1);
+      continue;
+    } else {
+      switch (scalar->type->id()) {
+        case Type::INT64: {
+          int64_t value = checked_cast<const Int64Scalar&>(*scalar).value;
+          rc = sqlite3_bind_int64(stmt_, c + 1, value);
+          break;
+        }
+        case Type::FLOAT: {
+          float value = checked_cast<const FloatScalar&>(*scalar).value;
+          rc = sqlite3_bind_double(stmt_, c + 1, value);
+          break;
+        }
+        case Type::STRING: {
+          const std::shared_ptr<Buffer>& buffer =
+              checked_cast<const StringScalar&>(*scalar).value;
+          rc = sqlite3_bind_text(stmt_, c + 1,
+                                 reinterpret_cast<const char*>(buffer->data()),
+                                 static_cast<int>(buffer->size()), SQLITE_TRANSIENT);
+          break;
+        }
+        default:
+          return Status::Invalid("Received unsupported data type: ", *scalar->type);

Review Comment:
   `TypeError`?


