You are viewing a plain text version of this content. The canonical link for it was "here" (the hyperlink was not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by ko...@apache.org on 2022/11/15 08:51:09 UTC

[arrow] 13/27: ARROW-18294: [Java] Fix Flight SQL JDBC PreparedStatement#executeUpdate (#14616)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch maint-10.0.x
in repository https://gitbox.apache.org/repos/asf/arrow.git

commit 89f1b91a86583c3bcd7122a058fcb81815157599
Author: David Li <li...@gmail.com>
AuthorDate: Wed Nov 9 16:07:19 2022 -0500

    ARROW-18294: [Java] Fix Flight SQL JDBC PreparedStatement#executeUpdate (#14616)
    
    We need to implement a separate code path for executing a prepared statement that returns an update count.
    
    Authored-by: David Li <li...@gmail.com>
    Signed-off-by: David Li <li...@gmail.com>
---
 .../arrow/driver/jdbc/ArrowFlightConnection.java   |  2 +-
 .../arrow/driver/jdbc/ArrowFlightMetaImpl.java     | 78 +++++++++++++++++++---
 .../jdbc/ArrowFlightPreparedStatementTest.java     | 15 ++++-
 3 files changed, 83 insertions(+), 12 deletions(-)

diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
index d2b6e89e3f..79bc04d27f 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightConnection.java
@@ -139,7 +139,7 @@ public final class ArrowFlightConnection extends AvaticaConnection {
    *
    * @return the handler.
    */
-  ArrowFlightSqlClientHandler getClientHandler() throws SQLException {
+  ArrowFlightSqlClientHandler getClientHandler() {
     return clientHandler;
   }
 
diff --git a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
index cc7addc3a7..f825e7d13c 100644
--- a/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
+++ b/java/flight/flight-sql-jdbc-driver/src/main/java/org/apache/arrow/driver/jdbc/ArrowFlightMetaImpl.java
@@ -42,7 +42,7 @@ import org.apache.calcite.avatica.remote.TypedValue;
  * Metadata handler for Arrow Flight.
  */
 public class ArrowFlightMetaImpl extends MetaImpl {
-  private final Map<StatementHandle, PreparedStatement> statementHandlePreparedStatementMap;
+  private final Map<StatementHandleKey, PreparedStatement> statementHandlePreparedStatementMap;
 
   /**
    * Constructs a {@link MetaImpl} object specific for Arrow Flight.
@@ -67,7 +67,8 @@ public class ArrowFlightMetaImpl extends MetaImpl {
 
   @Override
   public void closeStatement(final StatementHandle statementHandle) {
-    PreparedStatement preparedStatement = statementHandlePreparedStatementMap.remove(statementHandle);
+    PreparedStatement preparedStatement =
+        statementHandlePreparedStatementMap.remove(new StatementHandleKey(statementHandle));
     // Testing if the prepared statement was created because the statement can be not created until this moment
     if (preparedStatement != null) {
       preparedStatement.close();
@@ -82,12 +83,25 @@ public class ArrowFlightMetaImpl extends MetaImpl {
   @Override
   public ExecuteResult execute(final StatementHandle statementHandle,
                                final List<TypedValue> typedValues, final long maxRowCount) {
-    // TODO Why is maxRowCount ignored?
-    Preconditions.checkNotNull(statementHandle.signature, "Signature not found.");
-    return new ExecuteResult(
-        Collections.singletonList(MetaResultSet.create(
-            statementHandle.connectionId, statementHandle.id,
-            true, statementHandle.signature, null)));
+    Preconditions.checkArgument(connection.id.equals(statementHandle.connectionId),
+        "Connection IDs are not consistent");
+    if (statementHandle.signature == null) {
+      // No signature (Avatica's UPDATE path): run the stored prepared statement, report its count
+      final StatementHandleKey key = new StatementHandleKey(statementHandle);
+      PreparedStatement preparedStatement = statementHandlePreparedStatementMap.get(key);
+      if (preparedStatement == null) {
+        throw new IllegalStateException("Prepared statement not found: " + statementHandle);
+      }
+      long updatedCount = preparedStatement.executeUpdate();
+      return new ExecuteResult(Collections.singletonList(MetaResultSet.count(statementHandle.connectionId,
+          statementHandle.id, updatedCount)));
+    } else {
+      // Query path (signature present). TODO Why is maxRowCount ignored?
+      return new ExecuteResult(
+          Collections.singletonList(MetaResultSet.create(
+              statementHandle.connectionId, statementHandle.id,
+              true, statementHandle.signature, null)));
+    }
   }
 
   @Override
@@ -121,6 +135,9 @@ public class ArrowFlightMetaImpl extends MetaImpl {
                                  final String query, final long maxRowCount) {
     final StatementHandle handle = super.createStatement(connectionHandle);
     handle.signature = newSignature(query);
+    final PreparedStatement preparedStatement =
+        ((ArrowFlightConnection) connection).getClientHandler().prepare(query);
+    statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
     return handle;
   }
 
@@ -143,7 +160,7 @@ public class ArrowFlightMetaImpl extends MetaImpl {
       final PreparedStatement preparedStatement =
           ((ArrowFlightConnection) connection).getClientHandler().prepare(query);
       final StatementType statementType = preparedStatement.getType();
-      statementHandlePreparedStatementMap.put(handle, preparedStatement);
+      statementHandlePreparedStatementMap.put(new StatementHandleKey(handle), preparedStatement);
       final Signature signature = newSignature(query);
       final long updateCount =
           statementType.equals(StatementType.UPDATE) ? preparedStatement.executeUpdate() : -1;
@@ -195,6 +212,47 @@ public class ArrowFlightMetaImpl extends MetaImpl {
   }
 
   PreparedStatement getPreparedStatement(StatementHandle statementHandle) {
-    return statementHandlePreparedStatementMap.get(statementHandle);
+    return statementHandlePreparedStatementMap.get(new StatementHandleKey(statementHandle));
+  }
+
+  // Map key identifying a statement by (connectionId, id) only. Avatica doesn't give us the signature
+  // in an UPDATE code path, so StatementHandle itself can't serve as the key for later look-ups.
+  private static final class StatementHandleKey {
+    public final String connectionId;
+    public final int id;
+
+    StatementHandleKey(String connectionId, int id) {  // NOTE(review): no caller in this patch
+      this.connectionId = connectionId;
+      this.id = id;
+    }
+
+    StatementHandleKey(StatementHandle statementHandle) {
+      this.connectionId = statementHandle.connectionId;
+      this.id = statementHandle.id;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+      if (this == o) {
+        return true;
+      }
+      if (o == null || getClass() != o.getClass()) {
+        return false;
+      }
+
+      StatementHandleKey that = (StatementHandleKey) o;
+
+      if (id != that.id) {
+        return false;
+      }
+      return connectionId.equals(that.connectionId);  // NOTE(review): assumes connectionId is non-null
+    }
+
+    @Override
+    public int hashCode() {
+      int result = connectionId.hashCode();
+      result = 31 * result + id;
+      return result;
+    }
   }
 }
diff --git a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
index 51c491be28..8af529296f 100644
--- a/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
+++ b/java/flight/flight-sql-jdbc-driver/src/test/java/org/apache/arrow/driver/jdbc/ArrowFlightPreparedStatementTest.java
@@ -18,6 +18,7 @@
 package org.apache.arrow.driver.jdbc;
 
 import static org.hamcrest.CoreMatchers.equalTo;
+import static org.junit.jupiter.api.Assertions.assertEquals;
 
 import java.sql.Connection;
 import java.sql.PreparedStatement;
@@ -25,6 +26,7 @@ import java.sql.ResultSet;
 import java.sql.SQLException;
 
 import org.apache.arrow.driver.jdbc.utils.CoreMockedSqlProducers;
+import org.apache.arrow.driver.jdbc.utils.MockFlightSqlProducer;
 import org.junit.AfterClass;
 import org.junit.BeforeClass;
 import org.junit.ClassRule;
@@ -34,9 +36,10 @@ import org.junit.rules.ErrorCollector;
 
 public class ArrowFlightPreparedStatementTest {
 
+  public static final MockFlightSqlProducer PRODUCER = CoreMockedSqlProducers.getLegacyProducer();
   @ClassRule
   public static final FlightServerTestRule FLIGHT_SERVER_TEST_RULE = FlightServerTestRule
-      .createStandardTestRule(CoreMockedSqlProducers.getLegacyProducer());
+      .createStandardTestRule(PRODUCER);
 
   private static Connection connection;
 
@@ -75,4 +78,14 @@ public class ArrowFlightPreparedStatementTest {
       collector.checkThat(6, equalTo(psmt.getMetaData().getColumnCount()));
     }
   }
+
+  @Test
+  public void testUpdateQuery() throws SQLException {
+    String query = "Fake update";
+    PRODUCER.addUpdateQuery(query, /*updatedRows*/42);
+    try (final PreparedStatement stmt = connection.prepareStatement(query)) {
+      int updated = stmt.executeUpdate();
+      assertEquals(42, updated);  // NOTE(review): JUnit 5 Assertions used in an otherwise JUnit 4 test class
+    }
+  }
 }