You are viewing a plain-text version of this content. The link to its canonical version was not preserved in this rendering.
Posted to commits@arrow.apache.org by li...@apache.org on 2023/04/15 00:31:18 UTC
[arrow-adbc] branch main updated: feat(java/driver/flight-sql): allow passing BufferAllocator (#564)
This is an automated email from the ASF dual-hosted git repository.
lidavidm pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/arrow-adbc.git
The following commit(s) were added to refs/heads/main by this push:
new 1d3831e feat(java/driver/flight-sql): allow passing BufferAllocator (#564)
1d3831e is described below
commit 1d3831efd68008b7087638a2bf0066c87bbdc0ba
Author: David Li <li...@gmail.com>
AuthorDate: Sat Apr 15 09:31:12 2023 +0900
feat(java/driver/flight-sql): allow passing BufferAllocator (#564)
Requires https://github.com/apache/arrow/pull/34776.
Fixes #534.
---
.github/workflows/java.yml | 6 +
.../flightsql/FlightSqlConnectionMetadataTest.java | 6 +
.../adbc/driver/flightsql/FlightSqlQuirks.java | 4 +-
.../arrow/adbc/driver/flightsql/package-info.java | 29 +++++
.../adbc/driver/flightsql/FlightSqlConnection.java | 3 +-
.../adbc/driver/flightsql/FlightSqlDriver.java | 17 ++-
.../adbc/driver/flightsql/FlightSqlStatement.java | 7 ++
.../testsuite/AbstractConnectionMetadataTest.java | 4 +-
.../testsuite/AbstractPartitionDescriptorTest.java | 11 +-
.../driver/testsuite/AbstractStatementTest.java | 135 +++++++++++++++------
.../adbc/driver/testsuite/ArrowAssertions.java | 118 ++++++++++++++++--
.../arrow/adbc/driver/testsuite/SqlTestUtil.java | 5 +-
12 files changed, 281 insertions(+), 64 deletions(-)
diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index bc18594..e4ec0c9 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -53,7 +53,13 @@ jobs:
cache: "maven"
distribution: "temurin"
java-version: ${{ matrix.java }}
+ - name: Start SQLite server
+ shell: bash -l {0}
+ run: |
+ docker-compose up -d golang-sqlite-flightsql
- name: Build/Test
+ env:
+ ADBC_SQLITE_FLIGHTSQL_URI: "grpc+tcp://localhost:8080"
run: |
cd java
mvn install
diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
index 605da7f..7441317 100644
--- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
+++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnectionMetadataTest.java
@@ -57,6 +57,12 @@ public class FlightSqlConnectionMetadataTest extends AbstractConnectionMetadataT
super.getTableSchema();
}
+ @Override
+ @Disabled("Not yet implemented")
+ public void getTableSchemaDoesNotExist() throws Exception {
+ super.getTableSchemaDoesNotExist();
+ }
+
@Override
@Disabled("Not yet implemented")
public void getTableTypes() throws Exception {
diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
index 2f1840d..477781d 100644
--- a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
+++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlQuirks.java
@@ -32,7 +32,7 @@ import org.apache.arrow.memory.RootAllocator;
import org.junit.jupiter.api.Assumptions;
public class FlightSqlQuirks extends SqlValidationQuirks {
- static final String FLIGHT_SQL_LOCATION_ENV_VAR = "ADBC_FLIGHT_SQL_LOCATION";
+ static final String FLIGHT_SQL_LOCATION_ENV_VAR = "ADBC_SQLITE_FLIGHTSQL_URI";
static String getFlightLocation() {
final String location = System.getenv(FLIGHT_SQL_LOCATION_ENV_VAR);
@@ -48,7 +48,7 @@ public class FlightSqlQuirks extends SqlValidationQuirks {
final Map<String, Object> parameters = new HashMap<>();
parameters.put(AdbcDriver.PARAM_URL, url);
- return FlightSqlDriver.INSTANCE.open(parameters);
+ return new FlightSqlDriver(allocator).open(parameters);
}
@Override
diff --git a/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/package-info.java b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/package-info.java
new file mode 100644
index 0000000..c507c80
--- /dev/null
+++ b/java/driver/flight-sql-validation/src/test/java/org/apache/arrow/adbc/driver/flightsql/package-info.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Integration tests of the Flight SQL driver against the Golang SQLite Flight SQL server.
+ *
+ * <p>To run the server:
+ *
+ * <pre>
+ * go run github.com/apache/arrow/go/v12/arrow/flight/flightsql/example/cmd/sqlite_flightsql_server@latest -port 54000
+ * </pre>
+ *
+ * <p>Then set <code>ADBC_SQLITE_FLIGHTSQL_URI=grpc+tcp://localhost:54000</code> and run the tests.
+ */
+package org.apache.arrow.adbc.driver.flightsql;
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
index 0a05077..7d6e7b1 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlConnection.java
@@ -36,6 +36,7 @@ import org.apache.arrow.flight.Ticket;
import org.apache.arrow.flight.impl.Flight;
import org.apache.arrow.flight.sql.FlightSqlClient;
import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
@@ -133,7 +134,7 @@ public class FlightSqlConnection implements AdbcConnection {
@Override
public void close() throws Exception {
- client.close();
+ AutoCloseables.close(client, allocator);
}
@Override
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
index 045d728..324fdea 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlDriver.java
@@ -18,6 +18,7 @@ package org.apache.arrow.adbc.driver.flightsql;
import java.net.URISyntaxException;
import java.util.Map;
+import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcDatabase;
import org.apache.arrow.adbc.core.AdbcDriver;
import org.apache.arrow.adbc.core.AdbcException;
@@ -29,14 +30,22 @@ import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.Preconditions;
/** An ADBC driver wrapping Arrow Flight SQL. */
-public enum FlightSqlDriver implements AdbcDriver {
- INSTANCE;
+public class FlightSqlDriver implements AdbcDriver {
+ public static final FlightSqlDriver INSTANCE = new FlightSqlDriver();
+
+ static {
+ AdbcDriverManager.getInstance()
+ .registerDriver("org.apache.arrow.adbc.driver.flightsql", INSTANCE);
+ }
private final BufferAllocator allocator;
FlightSqlDriver() {
- allocator = new RootAllocator();
- AdbcDriverManager.getInstance().registerDriver("org.apache.arrow.adbc.driver.flightsql", this);
+ this(new RootAllocator());
+ }
+
+ FlightSqlDriver(BufferAllocator allocator) {
+ this.allocator = Objects.requireNonNull(allocator);
}
@Override
diff --git a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
index 70d4dfe..1fd8b91 100644
--- a/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
+++ b/java/driver/flight-sql/src/main/java/org/apache/arrow/adbc/driver/flightsql/FlightSqlStatement.java
@@ -22,6 +22,7 @@ import com.google.protobuf.ByteString;
import java.util.ArrayList;
import java.util.List;
import java.util.Objects;
+import java.util.concurrent.ExecutionException;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatement;
import org.apache.arrow.adbc.core.AdbcStatusCode;
@@ -169,6 +170,12 @@ public class FlightSqlStatement implements AdbcStatement {
statement.close();
}
} catch (FlightRuntimeException e) {
+ // XXX: FlightSqlClient.executeUpdate does some extra wrapping that we need to undo
+ if (e.getCause() instanceof ExecutionException
+ && e.getCause().getCause() instanceof FlightRuntimeException) {
+ throw FlightSqlDriverUtil.fromFlightException(
+ (FlightRuntimeException) e.getCause().getCause());
+ }
throw FlightSqlDriverUtil.fromFlightException(e);
}
return new UpdateResult(bindRoot.getRowCount());
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
index e2b377b..925737b 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractConnectionMetadataTest.java
@@ -315,7 +315,9 @@ public abstract class AbstractConnectionMetadataTest {
assertThrows(
AdbcException.class,
() -> connection.getTableSchema(/*catalog*/ null, /*dbSchema*/ null, "DOESNOTEXIST"));
- assertThat(thrown.getStatus()).isEqualTo(AdbcStatusCode.NOT_FOUND);
+ assertThat(thrown.getStatus())
+ .describedAs(thrown.toString())
+ .isEqualTo(AdbcStatusCode.NOT_FOUND);
}
@Test
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractPartitionDescriptorTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractPartitionDescriptorTest.java
index a0dee52..c4c423a 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractPartitionDescriptorTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractPartitionDescriptorTest.java
@@ -30,11 +30,11 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
-import org.apache.arrow.vector.IntVector;
+import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ipc.ArrowReader;
-import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.junit.jupiter.api.AfterEach;
@@ -63,9 +63,9 @@ public abstract class AbstractPartitionDescriptorTest {
schema =
new Schema(
Arrays.asList(
+ Field.nullable(quirks.caseFoldColumnName("ints"), Types.MinorType.BIGINT.getType()),
Field.nullable(
- quirks.caseFoldColumnName("ints"), new ArrowType.Int(32, /*signed=*/ true)),
- Field.nullable(quirks.caseFoldColumnName("strs"), new ArrowType.Utf8())));
+ quirks.caseFoldColumnName("strs"), Types.MinorType.VARCHAR.getType())));
quirks.cleanupTable(tableName);
}
@@ -78,7 +78,7 @@ public abstract class AbstractPartitionDescriptorTest {
@Test
public void serializeDeserializeQuery() throws Exception {
try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
- final IntVector ints = (IntVector) root.getVector(0);
+ final BigIntVector ints = (BigIntVector) root.getVector(0);
final VarCharVector strs = (VarCharVector) root.getVector(1);
ints.allocateNew(4);
@@ -111,7 +111,6 @@ public abstract class AbstractPartitionDescriptorTest {
connection2.readPartition(
partitionResult.getPartitionDescriptors().get(0).getDescriptor())) {
assertThat(reader.loadNextBatch()).isTrue();
- assertThat(reader.getVectorSchemaRoot().getSchema()).isEqualTo(root.getSchema());
assertRoot(reader.getVectorSchemaRoot()).isEqualTo(root);
}
}
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
index b264bf4..e7a1a57 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/AbstractStatementTest.java
@@ -17,6 +17,7 @@
package org.apache.arrow.adbc.driver.testsuite;
+import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertField;
import static org.apache.arrow.adbc.driver.testsuite.ArrowAssertions.assertRoot;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.assertThrows;
@@ -34,9 +35,11 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.Preconditions;
+import org.apache.arrow.vector.BigIntVector;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -54,7 +57,9 @@ public abstract class AbstractStatementTest {
protected BufferAllocator allocator;
protected SqlTestUtil util;
protected String tableName;
- protected Schema schema;
+ // Implementations vary on the integer type
+ protected Schema schema32;
+ protected Schema schema64;
@BeforeEach
public void beforeEach() throws Exception {
@@ -64,12 +69,18 @@ public abstract class AbstractStatementTest {
connection = database.connect();
util = new SqlTestUtil(quirks);
tableName = quirks.caseFoldTableName("bulktable");
- schema =
+ schema32 =
new Schema(
Arrays.asList(
+ Field.nullable(quirks.caseFoldColumnName("ints"), Types.MinorType.INT.getType()),
Field.nullable(
- quirks.caseFoldColumnName("ints"), new ArrowType.Int(32, /*signed=*/ true)),
- Field.nullable(quirks.caseFoldColumnName("strs"), new ArrowType.Utf8())));
+ quirks.caseFoldColumnName("strs"), Types.MinorType.VARCHAR.getType())));
+ schema64 =
+ new Schema(
+ Arrays.asList(
+ Field.nullable(quirks.caseFoldColumnName("ints"), Types.MinorType.BIGINT.getType()),
+ Field.nullable(
+ quirks.caseFoldColumnName("strs"), Types.MinorType.VARCHAR.getType())));
quirks.cleanupTable(tableName);
}
@@ -81,53 +92,96 @@ public abstract class AbstractStatementTest {
@Test
public void bulkIngestAppend() throws Exception {
- try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
- final IntVector ints = (IntVector) root.getVector(0);
- final VarCharVector strs = (VarCharVector) root.getVector(1);
+ // Implementations vary on the integer type here.
+ try (final VectorSchemaRoot root32 = VectorSchemaRoot.create(schema32, allocator);
+ final VectorSchemaRoot root64 = VectorSchemaRoot.create(schema64, allocator)) {
+ {
+ final IntVector ints = (IntVector) root32.getVector(0);
+ final VarCharVector strs = (VarCharVector) root32.getVector(1);
+
+ ints.allocateNew(4);
+ ints.setSafe(0, 0);
+ ints.setSafe(1, 1);
+ ints.setSafe(2, 2);
+ ints.setNull(3);
+ strs.allocateNew(4);
+ strs.setNull(0);
+ strs.setSafe(1, "foo".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(2, "".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(3, "asdf".getBytes(StandardCharsets.UTF_8));
+ root32.setRowCount(4);
+ }
+ {
+ final BigIntVector ints = (BigIntVector) root64.getVector(0);
+ final VarCharVector strs = (VarCharVector) root64.getVector(1);
- ints.allocateNew(4);
- ints.setSafe(0, 0);
- ints.setSafe(1, 1);
- ints.setSafe(2, 2);
- ints.setNull(3);
- strs.allocateNew(4);
- strs.setNull(0);
- strs.setSafe(1, "foo".getBytes(StandardCharsets.UTF_8));
- strs.setSafe(2, "".getBytes(StandardCharsets.UTF_8));
- strs.setSafe(3, "asdf".getBytes(StandardCharsets.UTF_8));
- root.setRowCount(4);
+ ints.allocateNew(4);
+ ints.setSafe(0, 0);
+ ints.setSafe(1, 1);
+ ints.setSafe(2, 2);
+ ints.setNull(3);
+ strs.allocateNew(4);
+ strs.setNull(0);
+ strs.setSafe(1, "foo".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(2, "".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(3, "asdf".getBytes(StandardCharsets.UTF_8));
+ root64.setRowCount(4);
+ }
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.CREATE)) {
- stmt.bind(root);
+ stmt.bind(root32);
stmt.executeUpdate();
}
try (final AdbcStatement stmt = connection.createStatement()) {
stmt.setSqlQuery("SELECT * FROM " + tableName);
try (AdbcStatement.QueryResult queryResult = stmt.executeQuery()) {
assertThat(queryResult.getReader().loadNextBatch()).isTrue();
- assertRoot(queryResult.getReader().getVectorSchemaRoot()).isEqualTo(root);
+ assertThat(queryResult.getReader().getVectorSchemaRoot())
+ .satisfiesAnyOf(
+ data -> assertRoot(data).isEqualTo(root32),
+ data -> assertRoot(data).isEqualTo(root64));
}
}
// Append
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.APPEND)) {
- stmt.bind(root);
+ stmt.bind(root32);
stmt.executeUpdate();
}
try (final AdbcStatement stmt = connection.createStatement()) {
stmt.setSqlQuery("SELECT * FROM " + tableName);
try (AdbcStatement.QueryResult queryResult = stmt.executeQuery()) {
assertThat(queryResult.getReader().loadNextBatch()).isTrue();
- root.setRowCount(8);
- ints.setSafe(4, 0);
- ints.setSafe(5, 1);
- ints.setSafe(6, 2);
- ints.setNull(7);
- strs.setNull(4);
- strs.setSafe(5, "foo".getBytes(StandardCharsets.UTF_8));
- strs.setSafe(6, "".getBytes(StandardCharsets.UTF_8));
- strs.setSafe(7, "asdf".getBytes(StandardCharsets.UTF_8));
- assertRoot(queryResult.getReader().getVectorSchemaRoot()).isEqualTo(root);
+ {
+ root32.setRowCount(8);
+ final IntVector ints = (IntVector) root32.getVector(0);
+ final VarCharVector strs = (VarCharVector) root32.getVector(1);
+ ints.setSafe(4, 0);
+ ints.setSafe(5, 1);
+ ints.setSafe(6, 2);
+ ints.setNull(7);
+ strs.setNull(4);
+ strs.setSafe(5, "foo".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(6, "".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(7, "asdf".getBytes(StandardCharsets.UTF_8));
+ }
+ {
+ root64.setRowCount(8);
+ final BigIntVector ints = (BigIntVector) root64.getVector(0);
+ final VarCharVector strs = (VarCharVector) root64.getVector(1);
+ ints.setSafe(4, 0);
+ ints.setSafe(5, 1);
+ ints.setSafe(6, 2);
+ ints.setNull(7);
+ strs.setNull(4);
+ strs.setSafe(5, "foo".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(6, "".getBytes(StandardCharsets.UTF_8));
+ strs.setSafe(7, "asdf".getBytes(StandardCharsets.UTF_8));
+ }
+ assertThat(queryResult.getReader().getVectorSchemaRoot())
+ .satisfiesAnyOf(
+ data -> assertRoot(data).isEqualTo(root32),
+ data -> assertRoot(data).isEqualTo(root64));
}
}
}
@@ -139,7 +193,7 @@ public abstract class AbstractStatementTest {
new Schema(
Collections.singletonList(
Field.nullable(quirks.caseFoldColumnName("ints"), new ArrowType.Utf8())));
- try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema32, allocator)) {
root.setRowCount(1);
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.CREATE)) {
stmt.bind(root);
@@ -157,7 +211,7 @@ public abstract class AbstractStatementTest {
@Test
public void bulkIngestAppendNotFound() throws Exception {
- try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema32, allocator)) {
root.setRowCount(1);
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.APPEND)) {
stmt.bind(root);
@@ -169,14 +223,14 @@ public abstract class AbstractStatementTest {
@Test
public void bulkIngestCreateConflict() throws Exception {
- try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema32, allocator)) {
root.setRowCount(1);
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.CREATE)) {
stmt.bind(root);
stmt.executeUpdate();
}
}
- try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
+ try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema32, allocator)) {
try (final AdbcStatement stmt = connection.bulkIngest(tableName, BulkIngestMode.CREATE)) {
stmt.bind(root);
final AdbcException e = assertThrows(AdbcException.class, stmt::executeUpdate);
@@ -192,8 +246,12 @@ public abstract class AbstractStatementTest {
stmt.setSqlQuery("SELECT * FROM " + tableName);
stmt.prepare();
try (AdbcStatement.QueryResult queryResult = stmt.executeQuery()) {
- assertThat(queryResult.getReader().getVectorSchemaRoot().getSchema())
- .isEqualTo(expectedSchema);
+ // Implementations vary on the integer type here.
+ Schema actualSchema = queryResult.getReader().getVectorSchemaRoot().getSchema();
+ assertThat(actualSchema.getFields().size()).isEqualTo(2);
+ assertThat(actualSchema.getFields().get(0).getType())
+ .isIn(Types.MinorType.INT.getType(), Types.MinorType.BIGINT.getType());
+ assertField(actualSchema.getFields().get(1)).isEqualTo(expectedSchema.getFields().get(1));
assertThat(queryResult.getReader().loadNextBatch()).isTrue();
assertThat(queryResult.getReader().getVectorSchemaRoot().getRowCount()).isEqualTo(4);
while (queryResult.getReader().loadNextBatch()) {
@@ -271,7 +329,8 @@ public abstract class AbstractStatementTest {
stmt.setSqlQuery(String.format("SELECT * FROM %s WHERE INTS = ?", tableName));
stmt.prepare();
final Schema paramsSchema = stmt.getParameterSchema();
- assertThat(paramsSchema.getFields().size()).isEqualTo(1);
+ // Golang SQLite Flight SQL server doesn't support this
+ assertThat(paramsSchema).isNotNull();
}
}
}
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
index e357824..14e2146 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/ArrowAssertions.java
@@ -17,11 +17,16 @@
package org.apache.arrow.adbc.driver.testsuite;
+import java.util.List;
+import java.util.Objects;
import org.apache.arrow.adbc.core.AdbcException;
import org.apache.arrow.adbc.core.AdbcStatusCode;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.compare.TypeEqualsVisitor;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
+import org.apache.arrow.vector.types.pojo.Field;
+import org.apache.arrow.vector.types.pojo.Schema;
import org.assertj.core.api.AbstractAssert;
/** AssertJ assertions for Arrow. */
@@ -36,6 +41,14 @@ public final class ArrowAssertions {
return new VectorSchemaRootAssert(actual);
}
+ public static SchemaAssert assertSchema(Schema actual) {
+ return new SchemaAssert(actual);
+ }
+
+ public static FieldAssert assertField(Field actual) {
+ return new FieldAssert(actual);
+ }
+
public static final class AdbcExceptionAssert
extends AbstractAssert<AdbcExceptionAssert, AdbcException> {
AdbcExceptionAssert(AdbcException e) {
@@ -70,24 +83,23 @@ public final class ArrowAssertions {
expected.getClass().getName());
}
final VectorSchemaRoot expectedRoot = (VectorSchemaRoot) expected;
- if (!actual.getSchema().equals(expectedRoot.getSchema())) {
- throw failureWithActualExpected(
- actual,
- expected,
- "Expected Schema:\n%sActual Schema:\n%s",
- expectedRoot.getSchema(),
- actual.getSchema());
- }
+ assertSchema(actual.getSchema()).isEqualTo(expectedRoot.getSchema());
for (int i = 0; i < expectedRoot.getSchema().getFields().size(); i++) {
final FieldVector expectedVector = expectedRoot.getVector(i);
final FieldVector actualVector = actual.getVector(i);
- if (!VectorEqualsVisitor.vectorEquals(expectedVector, actualVector)) {
+ if (!VectorEqualsVisitor.vectorEquals(
+ expectedVector,
+ actualVector,
+ (v1, v2) ->
+ new TypeEqualsVisitor(v2, /*checkName*/ false, /*checkMetadata*/ false)
+ .equals(v1))) {
throw failureWithActualExpected(
actual,
expected,
- "Vector %s does not match.\nExpected vector: %s\nActual vector : %s",
+ "Vector %s does not match %s.\nExpected vector: %s\nActual vector : %s",
expectedVector.getField(),
+ actualVector.getField(),
expectedVector,
actualVector);
}
@@ -95,4 +107,90 @@ public final class ArrowAssertions {
return this;
}
}
+
+ public static class SchemaAssert extends AbstractAssert<SchemaAssert, Schema> {
+ SchemaAssert(Schema schema) {
+ super(schema, SchemaAssert.class);
+ }
+
+ @Override
+ public SchemaAssert isEqualTo(Object expected) {
+ if (!(expected instanceof Schema)) {
+ throw failure(
+ "Expected object is not a Schema, but rather a %s", expected.getClass().getName());
+ }
+ final Schema expectedSchema = (Schema) expected;
+ if (!schemasEqualIgnoringMetadata(expectedSchema, actual)) {
+ throw failureWithActualExpected(
+ actual, expected, "Expected Schema:\n%s\nActual Schema:\n%s", expectedSchema, actual);
+ }
+ return this;
+ }
+
+ private boolean schemasEqualIgnoringMetadata(Schema expected, Schema actual) {
+ if (expected.getFields().size() != actual.getFields().size()) {
+ return false;
+ }
+ for (int i = 0; i < expected.getFields().size(); i++) {
+ assertField(actual.getFields().get(i)).isEqualTo(expected.getFields().get(i));
+ }
+ return true;
+ }
+ }
+
+ public static class FieldAssert extends AbstractAssert<FieldAssert, Field> {
+ FieldAssert(Field field) {
+ super(field, FieldAssert.class);
+ }
+
+ @Override
+ public FieldAssert isEqualTo(Object expected) {
+ if (!(expected instanceof Field)) {
+ throw failure(
+ "Expected object is not a Field, but rather a %s", expected.getClass().getName());
+ }
+ final Field expectedField = (Field) expected;
+ if (!fieldsEqualIgnoringMetadata(expectedField, actual)) {
+ throw failureWithActualExpected(
+ actual, expected, "Expected Field:\n%s\nActual Field:\n%s", expectedField, actual);
+ }
+ return this;
+ }
+
+ private boolean fieldsEqualIgnoringMetadata(Field expectedField, Field actualField) {
+ if (!expectedField.getName().equals(actualField.getName())) {
+ return false;
+ }
+
+ if (!expectedField.getType().equals(actualField.getType())) {
+ return false;
+ }
+
+ if (expectedField.getFieldType().isNullable() != actualField.getFieldType().isNullable()) {
+ return false;
+ }
+
+ if (!Objects.equals(
+ expectedField.getFieldType().getDictionary(),
+ actualField.getFieldType().getDictionary())) {
+ return false;
+ }
+
+ return fieldsEqualIgnoringMetadata(expectedField.getChildren(), actualField.getChildren());
+ }
+
+ private boolean fieldsEqualIgnoringMetadata(List<Field> expected, List<Field> actual) {
+ if (expected.size() != actual.size()) {
+ return false;
+ }
+
+ for (int i = 0; i < expected.size(); i++) {
+ if (!fieldsEqualIgnoringMetadata(expected.get(i), actual.get(i))) {
+ return false;
+ }
+ }
+
+ return true;
+ }
+ }
}
diff --git a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
index 623bfab..814d0a8 100644
--- a/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
+++ b/java/driver/validation/src/main/java/org/apache/arrow/adbc/driver/testsuite/SqlTestUtil.java
@@ -27,6 +27,7 @@ import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.vector.IntVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
+import org.apache.arrow.vector.types.Types;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
@@ -45,9 +46,9 @@ public final class SqlTestUtil {
final Schema schema =
new Schema(
Arrays.asList(
+ Field.nullable(quirks.caseFoldColumnName("INTS"), Types.MinorType.INT.getType()),
Field.nullable(
- quirks.caseFoldColumnName("INTS"), new ArrowType.Int(32, /*signed=*/ true)),
- Field.nullable(quirks.caseFoldColumnName("STRS"), new ArrowType.Utf8())));
+ quirks.caseFoldColumnName("STRS"), Types.MinorType.VARCHAR.getType())));
try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
final IntVector ints = (IntVector) root.getVector(0);
final VarCharVector strs = (VarCharVector) root.getVector(1);