You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ne...@apache.org on 2021/01/09 05:38:03 UTC
[arrow] branch master updated: ARROW-8853: [Rust] [Integration Testing] Enable Flight tests
This is an automated email from the ASF dual-hosted git repository.
nevime pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new f202e70 ARROW-8853: [Rust] [Integration Testing] Enable Flight tests
f202e70 is described below
commit f202e70d5fa7214535b664a5eb3ebd7d0e89d4df
Author: Carol (Nichols || Goulding) <ca...@gmail.com>
AuthorDate: Sat Jan 9 07:37:00 2021 +0200
ARROW-8853: [Rust] [Integration Testing] Enable Flight tests
This PR contains a few refactorings, followed by the main commit, which adds a new Flight integration test client and server 🎉
The middleware scenario tests are currently skipped because they will fail until `tonic` can be updated to a version containing [a fix having to do with trailers](https://github.com/hyperium/tonic/pull/510); this is tracked in [ARROW-10961](https://issues.apache.org/jira/browse/ARROW-10961).
Some Rust <-> Java integration tests will fail until [this PR is merged](https://github.com/apache/arrow/pull/8963); I'm happy to rebase once that goes in, but I wanted to get code review started on this. Thank you!!
Closes #9049 from carols10cents/rust-flight-integration
Lead-authored-by: Carol (Nichols || Goulding) <ca...@gmail.com>
Co-authored-by: Jake Goulding <ja...@gmail.com>
Signed-off-by: Neville Dipale <ne...@gmail.com>
---
dev/archery/archery/integration/runner.py | 4 +-
dev/archery/archery/integration/tester_rust.py | 86 +--
rust/arrow-flight/src/utils.rs | 107 ++--
rust/arrow/src/ipc/reader.rs | 2 +-
rust/arrow/src/ipc/writer.rs | 9 +-
rust/datafusion/examples/flight_client.rs | 9 +-
rust/datafusion/examples/flight_server.rs | 7 +-
rust/integration-testing/Cargo.toml | 24 +-
.../src/bin/arrow-json-integration-test.rs | 570 +-------------------
.../src/bin/flight-test-integration-client.rs | 62 +++
.../src/bin/flight-test-integration-server.rs | 55 ++
.../src/{lib.rs => flight_client_scenarios.rs} | 4 +-
.../flight_client_scenarios/auth_basic_proto.rs | 109 ++++
.../flight_client_scenarios/integration_test.rs | 271 ++++++++++
.../src/flight_client_scenarios/middleware.rs | 82 +++
.../src/flight_server_scenarios.rs | 49 ++
.../flight_server_scenarios/auth_basic_proto.rs | 226 ++++++++
.../flight_server_scenarios/integration_test.rs | 385 ++++++++++++++
.../src/flight_server_scenarios/middleware.rs | 150 ++++++
rust/integration-testing/src/lib.rs | 583 +++++++++++++++++++++
20 files changed, 2116 insertions(+), 678 deletions(-)
diff --git a/dev/archery/archery/integration/runner.py b/dev/archery/archery/integration/runner.py
index c1d7a69..520f9c4 100644
--- a/dev/archery/archery/integration/runner.py
+++ b/dev/archery/archery/integration/runner.py
@@ -347,7 +347,9 @@ def run_all_tests(with_cpp=True, with_java=True, with_js=True,
description="Authenticate using the BasicAuth protobuf."),
Scenario(
"middleware",
- description="Ensure headers are propagated via middleware."),
+ description="Ensure headers are propagated via middleware.",
+ skip={"Rust"} # TODO(ARROW-10961): tonic upgrade needed
+ ),
]
runner = IntegrationRunner(json_files, flight_scenarios, testers, **kwargs)
diff --git a/dev/archery/archery/integration/tester_rust.py b/dev/archery/archery/integration/tester_rust.py
index 23c2d37..bca80eb 100644
--- a/dev/archery/archery/integration/tester_rust.py
+++ b/dev/archery/archery/integration/tester_rust.py
@@ -15,7 +15,9 @@
# specific language governing permissions and limitations
# under the License.
+import contextlib
import os
+import subprocess
from .tester import Tester
from .util import run_cmd, ARROW_ROOT_DEFAULT, log
@@ -24,8 +26,8 @@ from .util import run_cmd, ARROW_ROOT_DEFAULT, log
class RustTester(Tester):
PRODUCER = True
CONSUMER = True
- # FLIGHT_SERVER = True
- # FLIGHT_CLIENT = True
+ FLIGHT_SERVER = True
+ FLIGHT_CLIENT = True
EXE_PATH = os.path.join(ARROW_ROOT_DEFAULT, 'rust/target/debug')
@@ -34,11 +36,11 @@ class RustTester(Tester):
STREAM_TO_FILE = os.path.join(EXE_PATH, 'arrow-stream-to-file')
FILE_TO_STREAM = os.path.join(EXE_PATH, 'arrow-file-to-stream')
- # FLIGHT_SERVER_CMD = [
- # os.path.join(EXE_PATH, 'flight-test-integration-server')]
- # FLIGHT_CLIENT_CMD = [
- # os.path.join(EXE_PATH, 'flight-test-integration-client'),
- # "-host", "localhost"]
+ FLIGHT_SERVER_CMD = [
+ os.path.join(EXE_PATH, 'flight-test-integration-server')]
+ FLIGHT_CLIENT_CMD = [
+ os.path.join(EXE_PATH, 'flight-test-integration-client'),
+ "--host", "localhost"]
name = 'Rust'
@@ -72,34 +74,42 @@ class RustTester(Tester):
cmd = [self.FILE_TO_STREAM, file_path, '>', stream_path]
self.run_shell_command(cmd)
- # @contextlib.contextmanager
- # def flight_server(self):
- # cmd = self.FLIGHT_SERVER_CMD + ['-port=0']
- # if self.debug:
- # log(' '.join(cmd))
- # server = subprocess.Popen(cmd,
- # stdout=subprocess.PIPE,
- # stderr=subprocess.PIPE)
- # try:
- # output = server.stdout.readline().decode()
- # if not output.startswith("Server listening on localhost:"):
- # server.kill()
- # out, err = server.communicate()
- # raise RuntimeError(
- # "Flight-C++ server did not start properly, "
- # "stdout:\n{}\n\nstderr:\n{}\n"
- # .format(output + out.decode(), err.decode()))
- # port = int(output.split(":")[1])
- # yield port
- # finally:
- # server.kill()
- # server.wait(5)
-
- # def flight_request(self, port, json_path):
- # cmd = self.FLIGHT_CLIENT_CMD + [
- # '-port=' + str(port),
- # '-path=' + json_path,
- # ]
- # if self.debug:
- # log(' '.join(cmd))
- # run_cmd(cmd)
+ @contextlib.contextmanager
+ def flight_server(self, scenario_name=None):
+ cmd = self.FLIGHT_SERVER_CMD + ['--port=0']
+ if scenario_name:
+ cmd = cmd + ["--scenario", scenario_name]
+ if self.debug:
+ log(' '.join(cmd))
+ server = subprocess.Popen(cmd,
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE)
+ try:
+ output = server.stdout.readline().decode()
+ if not output.startswith("Server listening on localhost:"):
+ server.kill()
+ out, err = server.communicate()
+ raise RuntimeError(
+ "Flight-Rust server did not start properly, "
+ "stdout:\n{}\n\nstderr:\n{}\n"
+ .format(output + out.decode(), err.decode()))
+ port = int(output.split(":")[1])
+ yield port
+ finally:
+ server.kill()
+ server.wait(5)
+
+ def flight_request(self, port, json_path=None, scenario_name=None):
+ cmd = self.FLIGHT_CLIENT_CMD + [
+ '--port=' + str(port),
+ ]
+ if json_path:
+ cmd.extend(('--path', json_path))
+ elif scenario_name:
+ cmd.extend(('--scenario', scenario_name))
+ else:
+ raise TypeError("Must provide one of json_path or scenario_name")
+
+ if self.debug:
+ log(' '.join(cmd))
+ run_cmd(cmd)
diff --git a/rust/arrow-flight/src/utils.rs b/rust/arrow-flight/src/utils.rs
index c2e01fb..659668c 100644
--- a/rust/arrow-flight/src/utils.rs
+++ b/rust/arrow-flight/src/utils.rs
@@ -21,17 +21,18 @@ use std::convert::TryFrom;
use crate::{FlightData, SchemaResult};
+use arrow::array::ArrayRef;
use arrow::datatypes::{Schema, SchemaRef};
use arrow::error::{ArrowError, Result};
-use arrow::ipc::{convert, reader, writer, writer::IpcWriteOptions};
+use arrow::ipc::{convert, reader, writer, writer::EncodedData, writer::IpcWriteOptions};
use arrow::record_batch::RecordBatch;
/// Convert a `RecordBatch` to a vector of `FlightData` representing the bytes of the dictionaries
-/// and values
+/// and a `FlightData` representing the bytes of the batch's values
pub fn flight_data_from_arrow_batch(
batch: &RecordBatch,
options: &IpcWriteOptions,
-) -> Vec<FlightData> {
+) -> (Vec<FlightData>, FlightData) {
let data_gen = writer::IpcDataGenerator::default();
let mut dictionary_tracker = writer::DictionaryTracker::new(false);
@@ -39,16 +40,20 @@ pub fn flight_data_from_arrow_batch(
.encoded_batch(batch, &mut dictionary_tracker, &options)
.expect("DictionaryTracker configured above to not error on replacement");
- encoded_dictionaries
- .into_iter()
- .chain(std::iter::once(encoded_batch))
- .map(|data| FlightData {
- flight_descriptor: None,
- app_metadata: vec![],
+ let flight_dictionaries = encoded_dictionaries.into_iter().map(Into::into).collect();
+ let flight_batch = encoded_batch.into();
+
+ (flight_dictionaries, flight_batch)
+}
+
+impl From<EncodedData> for FlightData {
+ fn from(data: EncodedData) -> Self {
+ FlightData {
data_header: data.ipc_message,
data_body: data.arrow_data,
- })
- .collect()
+ ..Default::default()
+ }
+ }
}
/// Convert a `Schema` to `SchemaResult` by converting to an IPC message
@@ -56,11 +61,8 @@ pub fn flight_schema_from_arrow_schema(
schema: &Schema,
options: &IpcWriteOptions,
) -> SchemaResult {
- let data_gen = writer::IpcDataGenerator::default();
- let schema_bytes = data_gen.schema_to_bytes(schema, &options);
-
SchemaResult {
- schema: schema_bytes.ipc_message,
+ schema: flight_schema_as_flatbuffer(schema, options),
}
}
@@ -69,16 +71,41 @@ pub fn flight_data_from_arrow_schema(
schema: &Schema,
options: &IpcWriteOptions,
) -> FlightData {
- let data_gen = writer::IpcDataGenerator::default();
- let schema = data_gen.schema_to_bytes(schema, &options);
+ let data_header = flight_schema_as_flatbuffer(schema, options);
FlightData {
- flight_descriptor: None,
- app_metadata: vec![],
- data_header: schema.ipc_message,
- data_body: vec![],
+ data_header,
+ ..Default::default()
}
}
+/// Convert a `Schema` to bytes in the format expected in `FlightInfo.schema`
+pub fn ipc_message_from_arrow_schema(
+ arrow_schema: &Schema,
+ options: &IpcWriteOptions,
+) -> Result<Vec<u8>> {
+ let encoded_data = flight_schema_as_encoded_data(arrow_schema, options);
+
+ let mut schema = vec![];
+ arrow::ipc::writer::write_message(&mut schema, encoded_data, options)?;
+ Ok(schema)
+}
+
+fn flight_schema_as_flatbuffer(
+ arrow_schema: &Schema,
+ options: &IpcWriteOptions,
+) -> Vec<u8> {
+ let encoded_data = flight_schema_as_encoded_data(arrow_schema, options);
+ encoded_data.ipc_message
+}
+
+fn flight_schema_as_encoded_data(
+ arrow_schema: &Schema,
+ options: &IpcWriteOptions,
+) -> EncodedData {
+ let data_gen = writer::IpcDataGenerator::default();
+ data_gen.schema_to_bytes(arrow_schema, options)
+}
+
/// Try convert `FlightData` into an Arrow Schema
///
/// Returns an error if the `FlightData` header is not a valid IPC schema
@@ -113,21 +140,12 @@ impl TryFrom<&SchemaResult> for Schema {
pub fn flight_data_to_arrow_batch(
data: &FlightData,
schema: SchemaRef,
-) -> Option<Result<RecordBatch>> {
+ dictionaries_by_field: &[Option<ArrayRef>],
+) -> Result<RecordBatch> {
// check that the data_header is a record batch message
- let res = arrow::ipc::root_as_message(&data.data_header[..]);
-
- // Catch error.
- if let Err(err) = res {
- return Some(Err(ArrowError::ParseError(format!(
- "Unable to get root as message: {:?}",
- err
- ))));
- }
-
- let message = res.unwrap();
-
- let dictionaries_by_field = Vec::new();
+ let message = arrow::ipc::root_as_message(&data.data_header[..]).map_err(|err| {
+ ArrowError::ParseError(format!("Unable to get root as message: {:?}", err))
+ })?;
message
.header_as_record_batch()
@@ -136,17 +154,14 @@ pub fn flight_data_to_arrow_batch(
"Unable to convert flight data header to a record batch".to_string(),
)
})
- .map_or_else(
- |err| Some(Err(err)),
- |batch| {
- Some(reader::read_record_batch(
- &data.data_body,
- batch,
- schema,
- &dictionaries_by_field,
- ))
- },
- )
+ .map(|batch| {
+ reader::read_record_batch(
+ &data.data_body,
+ batch,
+ schema,
+ &dictionaries_by_field,
+ )
+ })?
}
// TODO: add more explicit conversion that exposes flight descriptor and metadata options
diff --git a/rust/arrow/src/ipc/reader.rs b/rust/arrow/src/ipc/reader.rs
index d3f2829..65dcfa6 100644
--- a/rust/arrow/src/ipc/reader.rs
+++ b/rust/arrow/src/ipc/reader.rs
@@ -445,7 +445,7 @@ pub fn read_record_batch(
/// Read the dictionary from the buffer and provided metadata,
/// updating the `dictionaries_by_field` with the resulting dictionary
-fn read_dictionary(
+pub fn read_dictionary(
buf: &[u8],
batch: ipc::DictionaryBatch,
schema: &Schema,
diff --git a/rust/arrow/src/ipc/writer.rs b/rust/arrow/src/ipc/writer.rs
index fdec26c..688829a 100644
--- a/rust/arrow/src/ipc/writer.rs
+++ b/rust/arrow/src/ipc/writer.rs
@@ -554,10 +554,9 @@ pub struct EncodedData {
/// Arrow buffers to be written, should be an empty vec for schema messages
pub arrow_data: Vec<u8>,
}
-
/// Write a message's IPC data and buffers, returning metadata and buffer data lengths written
-fn write_message<W: Write>(
- mut writer: &mut BufWriter<W>,
+pub fn write_message<W: Write>(
+ mut writer: W,
encoded: EncodedData,
write_options: &IpcWriteOptions,
) -> Result<(usize, usize)> {
@@ -602,7 +601,7 @@ fn write_message<W: Write>(
Ok((aligned_size, body_len))
}
-fn write_body_buffers<W: Write>(writer: &mut BufWriter<W>, data: &[u8]) -> Result<usize> {
+fn write_body_buffers<W: Write>(mut writer: W, data: &[u8]) -> Result<usize> {
let len = data.len() as u32;
let pad_len = pad_to_8(len) as u32;
let total_len = len + pad_len;
@@ -620,7 +619,7 @@ fn write_body_buffers<W: Write>(writer: &mut BufWriter<W>, data: &[u8]) -> Resul
/// Write a record batch to the writer, writing the message size before the message
/// if the record batch is being written to a stream
fn write_continuation<W: Write>(
- writer: &mut BufWriter<W>,
+ mut writer: W,
write_options: &IpcWriteOptions,
total_len: i32,
) -> Result<usize> {
diff --git a/rust/datafusion/examples/flight_client.rs b/rust/datafusion/examples/flight_client.rs
index 13fd394..2c2954d 100644
--- a/rust/datafusion/examples/flight_client.rs
+++ b/rust/datafusion/examples/flight_client.rs
@@ -62,10 +62,13 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
// all the remaining stream messages should be dictionary and record batches
let mut results = vec![];
+ let dictionaries_by_field = vec![None; schema.fields().len()];
while let Some(flight_data) = stream.message().await? {
- // the unwrap is infallible and thus safe
- let record_batch =
- flight_data_to_arrow_batch(&flight_data, schema.clone()).unwrap()?;
+ let record_batch = flight_data_to_arrow_batch(
+ &flight_data,
+ schema.clone(),
+ &dictionaries_by_field,
+ )?;
results.push(record_batch);
}
diff --git a/rust/datafusion/examples/flight_server.rs b/rust/datafusion/examples/flight_server.rs
index 47c476a..75b470d 100644
--- a/rust/datafusion/examples/flight_server.rs
+++ b/rust/datafusion/examples/flight_server.rs
@@ -125,11 +125,14 @@ impl FlightService for FlightServiceImpl {
let mut batches: Vec<Result<FlightData, Status>> = results
.iter()
.flat_map(|batch| {
- let flight_data =
+ let (flight_dictionaries, flight_batch) =
arrow_flight::utils::flight_data_from_arrow_batch(
batch, &options,
);
- flight_data.into_iter().map(Ok)
+ flight_dictionaries
+ .into_iter()
+ .chain(std::iter::once(flight_batch))
+ .map(Ok)
})
.collect();
diff --git a/rust/integration-testing/Cargo.toml b/rust/integration-testing/Cargo.toml
index 1c26870..e4f798d 100644
--- a/rust/integration-testing/Cargo.toml
+++ b/rust/integration-testing/Cargo.toml
@@ -25,22 +25,20 @@ authors = ["Apache Arrow <de...@arrow.apache.org>"]
license = "Apache-2.0"
edition = "2018"
+[features]
+logging = ["tracing-subscriber"]
+
[dependencies]
arrow = { path = "../arrow" }
+arrow-flight = { path = "../arrow-flight" }
+async-trait = "0.1.41"
clap = "2.33"
+futures = "0.3"
+hex = "0.4"
+prost = "0.6"
serde = { version = "1.0", features = ["rc"] }
serde_derive = "1.0"
serde_json = { version = "1.0", features = ["preserve_order"] }
-hex = "0.4"
-
-[[bin]]
-name = "arrow-file-to-stream"
-path = "src/bin/arrow-file-to-stream.rs"
-
-[[bin]]
-name = "arrow-stream-to-file"
-path = "src/bin/arrow-stream-to-file.rs"
-
-[[bin]]
-name = "arrow-json-integration-test"
-path = "src/bin/arrow-json-integration-test.rs"
+tokio = { version = "0.2", features = ["macros", "rt-core", "rt-threaded"] }
+tonic = "0.3"
+tracing-subscriber = { version = "0.2.15", optional = true }
diff --git a/rust/integration-testing/src/bin/arrow-json-integration-test.rs b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
index 8d17c1a..cd89a8e 100644
--- a/rust/integration-testing/src/bin/arrow-json-integration-test.rs
+++ b/rust/integration-testing/src/bin/arrow-json-integration-test.rs
@@ -15,27 +15,15 @@
// specific language governing permissions and limitations
// under the License.
-use std::collections::HashMap;
use std::fs::File;
-use std::io::BufReader;
-use std::sync::Arc;
use clap::{App, Arg};
-use hex::decode;
-use serde_json::Value;
-use arrow::array::*;
-use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema};
use arrow::error::{ArrowError, Result};
use arrow::ipc::reader::FileReader;
use arrow::ipc::writer::FileWriter;
-use arrow::record_batch::RecordBatch;
-use arrow::{
- buffer::Buffer,
- buffer::MutableBuffer,
- datatypes::ToByteSlice,
- util::{bit_util, integration_util::*},
-};
+use arrow::util::integration_util::*;
+use arrow_integration_testing::read_json_file;
fn main() -> Result<()> {
let matches = App::new("rust arrow-json-integration-test")
@@ -93,520 +81,6 @@ fn json_to_arrow(json_name: &str, arrow_name: &str, verbose: bool) -> Result<()>
Ok(())
}
-fn record_batch_from_json(
- schema: &Schema,
- json_batch: ArrowJsonBatch,
- json_dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
-) -> Result<RecordBatch> {
- let mut columns = vec![];
-
- for (field, json_col) in schema.fields().iter().zip(json_batch.columns) {
- let col = array_from_json(field, json_col, json_dictionaries)?;
- columns.push(col);
- }
-
- RecordBatch::try_new(Arc::new(schema.clone()), columns)
-}
-
-/// Construct an Arrow array from a partially typed JSON column
-fn array_from_json(
- field: &Field,
- json_col: ArrowJsonColumn,
- dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
-) -> Result<ArrayRef> {
- match field.data_type() {
- DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))),
- DataType::Boolean => {
- let mut b = BooleanBuilder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_bool().unwrap()),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Int8 => {
- let mut b = Int8Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_i64().ok_or_else(|| {
- ArrowError::JsonError(format!(
- "Unable to get {:?} as int64",
- value
- ))
- })? as i8),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Int16 => {
- let mut b = Int16Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_i64().unwrap() as i16),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Int32
- | DataType::Date32(DateUnit::Day)
- | DataType::Time32(_)
- | DataType::Interval(IntervalUnit::YearMonth) => {
- let mut b = Int32Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_i64().unwrap() as i32),
- _ => b.append_null(),
- }?;
- }
- let array = Arc::new(b.finish()) as ArrayRef;
- arrow::compute::cast(&array, field.data_type())
- }
- DataType::Int64
- | DataType::Date64(DateUnit::Millisecond)
- | DataType::Time64(_)
- | DataType::Timestamp(_, _)
- | DataType::Duration(_)
- | DataType::Interval(IntervalUnit::DayTime) => {
- let mut b = Int64Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(match value {
- Value::Number(n) => n.as_i64().unwrap(),
- Value::String(s) => {
- s.parse().expect("Unable to parse string as i64")
- }
- _ => panic!("Unable to parse {:?} as number", value),
- }),
- _ => b.append_null(),
- }?;
- }
- let array = Arc::new(b.finish()) as ArrayRef;
- arrow::compute::cast(&array, field.data_type())
- }
- DataType::UInt8 => {
- let mut b = UInt8Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_u64().unwrap() as u8),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::UInt16 => {
- let mut b = UInt16Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_u64().unwrap() as u16),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::UInt32 => {
- let mut b = UInt32Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_u64().unwrap() as u32),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::UInt64 => {
- let mut b = UInt64Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(
- value
- .as_str()
- .unwrap()
- .parse()
- .expect("Unable to parse string as u64"),
- ),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Float32 => {
- let mut b = Float32Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_f64().unwrap() as f32),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Float64 => {
- let mut b = Float64Builder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_f64().unwrap()),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Binary => {
- let mut b = BinaryBuilder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => {
- let v = decode(value.as_str().unwrap()).unwrap();
- b.append_value(&v)
- }
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::LargeBinary => {
- let mut b = LargeBinaryBuilder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => {
- let v = decode(value.as_str().unwrap()).unwrap();
- b.append_value(&v)
- }
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::Utf8 => {
- let mut b = StringBuilder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_str().unwrap()),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::LargeUtf8 => {
- let mut b = LargeStringBuilder::new(json_col.count);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => b.append_value(value.as_str().unwrap()),
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::FixedSizeBinary(len) => {
- let mut b = FixedSizeBinaryBuilder::new(json_col.count, *len);
- for (is_valid, value) in json_col
- .validity
- .as_ref()
- .unwrap()
- .iter()
- .zip(json_col.data.unwrap())
- {
- match is_valid {
- 1 => {
- let v = hex::decode(value.as_str().unwrap()).unwrap();
- b.append_value(&v)
- }
- _ => b.append_null(),
- }?;
- }
- Ok(Arc::new(b.finish()))
- }
- DataType::List(child_field) => {
- let null_buf = create_null_buf(&json_col);
- let children = json_col.children.clone().unwrap();
- let child_array = array_from_json(
- &child_field,
- children.get(0).unwrap().clone(),
- dictionaries,
- )?;
- let offsets: Vec<i32> = json_col
- .offset
- .unwrap()
- .iter()
- .map(|v| v.as_i64().unwrap() as i32)
- .collect();
- let list_data = ArrayData::builder(field.data_type().clone())
- .len(json_col.count)
- .offset(0)
- .add_buffer(Buffer::from(&offsets.to_byte_slice()))
- .add_child_data(child_array.data())
- .null_bit_buffer(null_buf)
- .build();
- Ok(Arc::new(ListArray::from(list_data)))
- }
- DataType::LargeList(child_field) => {
- let null_buf = create_null_buf(&json_col);
- let children = json_col.children.clone().unwrap();
- let child_array = array_from_json(
- &child_field,
- children.get(0).unwrap().clone(),
- dictionaries,
- )?;
- let offsets: Vec<i64> = json_col
- .offset
- .unwrap()
- .iter()
- .map(|v| match v {
- Value::Number(n) => n.as_i64().unwrap(),
- Value::String(s) => s.parse::<i64>().unwrap(),
- _ => panic!("64-bit offset must be either string or number"),
- })
- .collect();
- let list_data = ArrayData::builder(field.data_type().clone())
- .len(json_col.count)
- .offset(0)
- .add_buffer(Buffer::from(&offsets.to_byte_slice()))
- .add_child_data(child_array.data())
- .null_bit_buffer(null_buf)
- .build();
- Ok(Arc::new(LargeListArray::from(list_data)))
- }
- DataType::FixedSizeList(child_field, _) => {
- let children = json_col.children.clone().unwrap();
- let child_array = array_from_json(
- &child_field,
- children.get(0).unwrap().clone(),
- dictionaries,
- )?;
- let null_buf = create_null_buf(&json_col);
- let list_data = ArrayData::builder(field.data_type().clone())
- .len(json_col.count)
- .add_child_data(child_array.data())
- .null_bit_buffer(null_buf)
- .build();
- Ok(Arc::new(FixedSizeListArray::from(list_data)))
- }
- DataType::Struct(fields) => {
- // construct struct with null data
- let null_buf = create_null_buf(&json_col);
- let mut array_data = ArrayData::builder(field.data_type().clone())
- .len(json_col.count)
- .null_bit_buffer(null_buf);
-
- for (field, col) in fields.iter().zip(json_col.children.unwrap()) {
- let array = array_from_json(field, col, dictionaries)?;
- array_data = array_data.add_child_data(array.data());
- }
-
- let array = StructArray::from(array_data.build());
- Ok(Arc::new(array))
- }
- DataType::Dictionary(key_type, value_type) => {
- let dict_id = field.dict_id().ok_or_else(|| {
- ArrowError::JsonError(format!(
- "Unable to find dict_id for field {:?}",
- field
- ))
- })?;
- // find dictionary
- let dictionary = dictionaries
- .ok_or_else(|| {
- ArrowError::JsonError(format!(
- "Unable to find any dictionaries for field {:?}",
- field
- ))
- })?
- .get(&dict_id);
- match dictionary {
- Some(dictionary) => dictionary_array_from_json(
- field, json_col, key_type, value_type, dictionary,
- ),
- None => Err(ArrowError::JsonError(format!(
- "Unable to find dictionary for field {:?}",
- field
- ))),
- }
- }
- t => Err(ArrowError::JsonError(format!(
- "data type {:?} not supported",
- t
- ))),
- }
-}
-
-fn dictionary_array_from_json(
- field: &Field,
- json_col: ArrowJsonColumn,
- dict_key: &DataType,
- dict_value: &DataType,
- dictionary: &ArrowJsonDictionaryBatch,
-) -> Result<ArrayRef> {
- match dict_key {
- DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64 => {
- let null_buf = create_null_buf(&json_col);
-
- // build the key data into a buffer, then construct values separately
- let key_field = Field::new_dict(
- "key",
- dict_key.clone(),
- field.is_nullable(),
- field
- .dict_id()
- .expect("Dictionary fields must have a dict_id value"),
- field
- .dict_is_ordered()
- .expect("Dictionary fields must have a dict_is_ordered value"),
- );
- let keys = array_from_json(&key_field, json_col, None)?;
- // note: not enough info on nullability of dictionary
- let value_field = Field::new("value", dict_value.clone(), true);
- println!("dictionary value type: {:?}", dict_value);
- let values =
- array_from_json(&value_field, dictionary.data.columns[0].clone(), None)?;
-
- // convert key and value to dictionary data
- let dict_data = ArrayData::builder(field.data_type().clone())
- .len(keys.len())
- .add_buffer(keys.data().buffers()[0].clone())
- .null_bit_buffer(null_buf)
- .add_child_data(values.data())
- .build();
-
- let array = match dict_key {
- DataType::Int8 => {
- Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef
- }
- DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)),
- DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)),
- DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)),
- DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)),
- DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)),
- DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)),
- DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)),
- _ => unreachable!(),
- };
- Ok(array)
- }
- _ => Err(ArrowError::JsonError(format!(
- "Dictionary key type {:?} not supported",
- dict_key
- ))),
- }
-}
-
-/// A helper to create a null buffer from a Vec<bool>
-fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer {
- let num_bytes = bit_util::ceil(json_col.count, 8);
- let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
- json_col
- .validity
- .clone()
- .unwrap()
- .iter()
- .enumerate()
- .for_each(|(i, v)| {
- let null_slice = null_buf.as_slice_mut();
- if *v != 0 {
- bit_util::set_bit(null_slice, i);
- }
- });
- null_buf.into()
-}
-
fn arrow_to_json(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
if verbose {
eprintln!("Converting {} to {}", arrow_name, json_name);
@@ -702,43 +176,3 @@ fn validate(arrow_name: &str, json_name: &str, verbose: bool) -> Result<()> {
Ok(())
}
-
-struct ArrowFile {
- schema: Schema,
- // we can evolve this into a concrete Arrow type
- // this is temporarily not being read from
- _dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
- batches: Vec<RecordBatch>,
-}
-
-fn read_json_file(json_name: &str) -> Result<ArrowFile> {
- let json_file = File::open(json_name)?;
- let reader = BufReader::new(json_file);
- let arrow_json: Value = serde_json::from_reader(reader).unwrap();
- let schema = Schema::from(&arrow_json["schema"])?;
- // read dictionaries
- let mut dictionaries = HashMap::new();
- if let Some(dicts) = arrow_json.get("dictionaries") {
- for d in dicts
- .as_array()
- .expect("Unable to get dictionaries as array")
- {
- let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone())
- .expect("Unable to get dictionary from JSON");
- // TODO: convert to a concrete Arrow type
- dictionaries.insert(json_dict.id, json_dict);
- }
- }
-
- let mut batches = vec![];
- for b in arrow_json["batches"].as_array().unwrap() {
- let json_batch: ArrowJsonBatch = serde_json::from_value(b.clone()).unwrap();
- let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?;
- batches.push(batch);
- }
- Ok(ArrowFile {
- schema,
- _dictionaries: dictionaries,
- batches,
- })
-}
diff --git a/rust/integration-testing/src/bin/flight-test-integration-client.rs b/rust/integration-testing/src/bin/flight-test-integration-client.rs
new file mode 100644
index 0000000..1901553
--- /dev/null
+++ b/rust/integration-testing/src/bin/flight-test-integration-client.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_integration_testing::flight_client_scenarios;
+
+use clap::{App, Arg};
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Entry point of the Rust Flight integration-test client.
+///
+/// Parses `--host`, `--port`, `--path` and `--scenario`, then runs the matching
+/// scenario from `flight_client_scenarios`. When `--scenario` is absent, the
+/// default integration test is run against the JSON file named by `--path`.
+///
+/// # Errors
+///
+/// Returns an error (rather than panicking) when a required argument is missing
+/// or the scenario name is unknown, and propagates any failure of the selected
+/// scenario.
+#[tokio::main]
+async fn main() -> Result {
+    #[cfg(feature = "logging")]
+    tracing_subscriber::fmt::init();
+
+    let matches = App::new("rust flight-test-integration-client")
+        .arg(Arg::with_name("host").long("host").takes_value(true))
+        .arg(Arg::with_name("port").long("port").takes_value(true))
+        .arg(Arg::with_name("path").long("path").takes_value(true))
+        .arg(
+            Arg::with_name("scenario")
+                .long("scenario")
+                .takes_value(true),
+        )
+        .get_matches();
+
+    // A missing argument is a usage error, not a broken invariant, so it is
+    // reported through the returned `Result` instead of a panic.
+    let host = matches.value_of("host").ok_or("Host is required")?;
+    let port = matches.value_of("port").ok_or("Port is required")?;
+
+    match matches.value_of("scenario") {
+        Some("middleware") => {
+            flight_client_scenarios::middleware::run_scenario(host, port).await?
+        }
+        Some("auth:basic_proto") => {
+            flight_client_scenarios::auth_basic_proto::run_scenario(host, port).await?
+        }
+        Some(scenario_name) => {
+            return Err(format!("Scenario not found: {}", scenario_name).into());
+        }
+        None => {
+            let path = matches
+                .value_of("path")
+                .ok_or("Path is required if scenario is not specified")?;
+            flight_client_scenarios::integration_test::run_scenario(host, port, path)
+                .await?;
+        }
+    }
+
+    Ok(())
+}
diff --git a/rust/integration-testing/src/bin/flight-test-integration-server.rs b/rust/integration-testing/src/bin/flight-test-integration-server.rs
new file mode 100644
index 0000000..b1b2807
--- /dev/null
+++ b/rust/integration-testing/src/bin/flight-test-integration-server.rs
@@ -0,0 +1,55 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use clap::{App, Arg};
+
+use arrow_integration_testing::flight_server_scenarios;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Entry point of the Rust Flight integration-test server.
+///
+/// Parses `--port` (defaulting to `"0"`) and `--scenario`, then starts the matching
+/// scenario from `flight_server_scenarios`. When `--scenario` is absent, the
+/// default integration-test server is started.
+///
+/// # Errors
+///
+/// Returns an error (rather than panicking) when the scenario name is unknown,
+/// and propagates any failure of the selected scenario.
+#[tokio::main]
+async fn main() -> Result {
+    #[cfg(feature = "logging")]
+    tracing_subscriber::fmt::init();
+
+    let matches = App::new("rust flight-test-integration-server")
+        .about("Integration testing server for Flight.")
+        .arg(Arg::with_name("port").long("port").takes_value(true))
+        .arg(
+            Arg::with_name("scenario")
+                .long("scenario")
+                .takes_value(true),
+        )
+        .get_matches();
+
+    let port = matches.value_of("port").unwrap_or("0");
+
+    match matches.value_of("scenario") {
+        Some("middleware") => {
+            flight_server_scenarios::middleware::scenario_setup(port).await?
+        }
+        Some("auth:basic_proto") => {
+            flight_server_scenarios::auth_basic_proto::scenario_setup(port).await?
+        }
+        // An unknown scenario is bad user input, so it is returned as an error
+        // instead of aborting through `unimplemented!`.
+        Some(scenario_name) => {
+            return Err(format!("Scenario not found: {}", scenario_name).into());
+        }
+        None => {
+            flight_server_scenarios::integration_test::scenario_setup(port).await?;
+        }
+    }
+    Ok(())
+}
diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/flight_client_scenarios.rs
similarity index 91%
copy from rust/integration-testing/src/lib.rs
copy to rust/integration-testing/src/flight_client_scenarios.rs
index 596017a..66cced5 100644
--- a/rust/integration-testing/src/lib.rs
+++ b/rust/integration-testing/src/flight_client_scenarios.rs
@@ -15,4 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-//! Common code used in the integration test binaries
+pub mod auth_basic_proto;
+pub mod integration_test;
+pub mod middleware;
diff --git a/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs
new file mode 100644
index 0000000..5e8cd46
--- /dev/null
+++ b/rust/integration-testing/src/flight_client_scenarios/auth_basic_proto.rs
@@ -0,0 +1,109 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{AUTH_PASSWORD, AUTH_USERNAME};
+
+use arrow_flight::{
+ flight_service_client::FlightServiceClient, BasicAuth, HandshakeRequest,
+};
+use futures::{stream, StreamExt};
+use prost::Message;
+use tonic::{metadata::MetadataValue, Request, Status};
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+type Client = FlightServiceClient<tonic::transport::Channel>;
+
+/// Runs the `auth:basic_proto` client scenario against the server at `host:port`.
+///
+/// An unauthenticated `do_action` must be rejected with `UNAUTHENTICATED`. After a
+/// handshake with `AUTH_USERNAME`/`AUTH_PASSWORD`, the same action sent with the
+/// returned token in the `auth-token-bin` header must answer with the username.
+pub async fn run_scenario(host: &str, port: &str) -> Result {
+    let url = format!("http://{}:{}", host, port);
+    let mut client = FlightServiceClient::connect(url).await?;
+
+    let action = arrow_flight::Action::default();
+
+    // This client is unauthenticated and should fail.
+    match client.do_action(Request::new(action.clone())).await {
+        Err(status) if status.code() == tonic::Code::Unauthenticated => {}
+        Err(status) => {
+            let msg = format!("Expected UNAUTHENTICATED but got {:?}", status);
+            return Err(Box::new(Status::internal(msg)));
+        }
+        Ok(response) => {
+            let msg = format!("Expected UNAUTHENTICATED but got {:?}", response);
+            return Err(Box::new(Status::internal(msg)));
+        }
+    }
+
+    let token = authenticate(&mut client, AUTH_USERNAME, AUTH_PASSWORD)
+        .await
+        .expect("must respond successfully from handshake");
+
+    // Repeat the action, this time carrying the token from the handshake.
+    let mut authenticated_request = Request::new(action);
+    authenticated_request.metadata_mut().insert_bin(
+        "auth-token-bin",
+        MetadataValue::from_bytes(token.as_bytes()),
+    );
+
+    let mut results = client
+        .do_action(authenticated_request)
+        .await?
+        .into_inner();
+
+    let first = results
+        .next()
+        .await
+        .expect("No response received")
+        .expect("Invalid response received");
+
+    let body = String::from_utf8(first.body).unwrap();
+    assert_eq!(body, AUTH_USERNAME);
+
+    Ok(())
+}
+
+/// Performs the basic-auth handshake and returns the token sent back by the server.
+///
+/// The credentials are encoded as a `BasicAuth` message and sent as the payload of
+/// a single `HandshakeRequest`. Panics if the server sends no response, sends more
+/// than one, or answers with a payload that is not valid UTF-8.
+async fn authenticate(
+    client: &mut Client,
+    username: &str,
+    password: &str,
+) -> Result<String> {
+    let credentials = BasicAuth {
+        username: username.into(),
+        password: password.into(),
+    };
+    let mut payload = vec![];
+    credentials.encode(&mut payload)?;
+
+    let handshake = HandshakeRequest {
+        payload,
+        ..HandshakeRequest::default()
+    };
+    let outgoing = stream::once(async move { handshake });
+
+    let mut responses = client
+        .handshake(Request::new(outgoing))
+        .await?
+        .into_inner();
+
+    let reply = responses
+        .next()
+        .await
+        .expect("must respond from handshake")?;
+    assert!(
+        responses.next().await.is_none(),
+        "must not respond a second time"
+    );
+
+    Ok(String::from_utf8(reply.payload).unwrap())
+}
diff --git a/rust/integration-testing/src/flight_client_scenarios/integration_test.rs b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs
new file mode 100644
index 0000000..ff61b5c
--- /dev/null
+++ b/rust/integration-testing/src/flight_client_scenarios/integration_test.rs
@@ -0,0 +1,271 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{read_json_file, ArrowFile};
+
+use arrow::{
+ array::ArrayRef,
+ datatypes::SchemaRef,
+ ipc::{self, reader, writer},
+ record_batch::RecordBatch,
+};
+use arrow_flight::{
+ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
+ utils::flight_data_to_arrow_batch, FlightData, FlightDescriptor, Location, Ticket,
+};
+use futures::{channel::mpsc, sink::SinkExt, stream, StreamExt};
+use tonic::{Request, Streaming};
+
+use std::sync::Arc;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+type Client = FlightServiceClient<tonic::transport::Channel>;
+
+/// Runs the default Flight integration scenario.
+///
+/// Reads the schema and record batches from the JSON file at `path`, uploads them
+/// to the server at `host:port` under a path descriptor naming that file, and then
+/// reads them back for verification.
+pub async fn run_scenario(host: &str, port: &str, path: &str) -> Result {
+    let client =
+        FlightServiceClient::connect(format!("http://{}:{}", host, port)).await?;
+
+    let ArrowFile {
+        schema, batches, ..
+    } = read_json_file(path)?;
+    let schema = Arc::new(schema);
+
+    let mut descriptor = FlightDescriptor {
+        path: vec![path.to_string()],
+        ..FlightDescriptor::default()
+    };
+    descriptor.set_type(DescriptorType::Path);
+
+    let upload = upload_data(
+        client.clone(),
+        schema.clone(),
+        descriptor.clone(),
+        batches.clone(),
+    );
+    upload.await?;
+
+    verify_data(client, descriptor, schema, &batches).await
+}
+
+/// Uploads `original_data` to the server with `do_put`, checking every acknowledgement.
+///
+/// The schema message (carrying `descriptor`) is queued first. Each batch is tagged
+/// with its index as `app_metadata`, and the server is expected to echo that
+/// metadata back in the matching `PutResult`.
+async fn upload_data(
+    mut client: Client,
+    schema: SchemaRef,
+    descriptor: FlightDescriptor,
+    original_data: Vec<RecordBatch>,
+) -> Result {
+    let (mut upload_tx, upload_rx) = mpsc::channel(10);
+
+    let options = arrow::ipc::writer::IpcWriteOptions::default();
+    let mut schema_message =
+        arrow_flight::utils::flight_data_from_arrow_schema(&schema, &options);
+    schema_message.flight_descriptor = Some(descriptor);
+    upload_tx.send(schema_message).await?;
+
+    let mut numbered_batches = original_data.iter().enumerate();
+
+    let (first_index, first_batch) = match numbered_batches.next() {
+        Some(first) => first,
+        None => {
+            // Nothing to upload beyond the schema: close the stream and make the call.
+            drop(upload_tx);
+            client.do_put(Request::new(upload_rx)).await?;
+            return Ok(());
+        }
+    };
+
+    // Preload the first batch into the channel before starting the request
+    let first_metadata = first_index.to_string().into_bytes();
+    send_batch(&mut upload_tx, &first_metadata, first_batch, &options).await?;
+
+    let mut results = client.do_put(Request::new(upload_rx)).await?.into_inner();
+    expect_put_ack(&mut results, &first_metadata).await;
+
+    // Stream the rest of the batches
+    for (index, batch) in numbered_batches {
+        let metadata = index.to_string().into_bytes();
+        send_batch(&mut upload_tx, &metadata, batch, &options).await?;
+        expect_put_ack(&mut results, &metadata).await;
+    }
+
+    drop(upload_tx);
+    assert!(
+        results.next().await.is_none(),
+        "Should not receive more results"
+    );
+
+    Ok(())
+}
+
+/// Waits for the next `PutResult` and asserts that it echoes `expected_metadata`.
+async fn expect_put_ack(
+    results: &mut Streaming<arrow_flight::PutResult>,
+    expected_metadata: &[u8],
+) {
+    let ack = results
+        .next()
+        .await
+        .expect("No response received")
+        .expect("Invalid response received");
+    assert_eq!(expected_metadata, &ack.app_metadata[..]);
+}
+
+/// Sends `batch` on `upload_tx`, preceded by the dictionary messages that
+/// `flight_data_from_arrow_batch` returns alongside it.
+///
+/// Only the record batch's own `FlightData` carries `metadata` as `app_metadata`.
+async fn send_batch(
+    upload_tx: &mut mpsc::Sender<FlightData>,
+    metadata: &[u8],
+    batch: &RecordBatch,
+    options: &writer::IpcWriteOptions,
+) -> Result {
+    let (dictionaries, mut record_batch) =
+        arrow_flight::utils::flight_data_from_arrow_batch(batch, options);
+
+    let mut dictionary_stream = stream::iter(dictionaries).map(Ok);
+    upload_tx.send_all(&mut dictionary_stream).await?;
+
+    // Only the record batch's FlightData gets app_metadata
+    record_batch.app_metadata = metadata.to_vec();
+    upload_tx.send(record_batch).await?;
+    Ok(())
+}
+
+/// Fetches the flight info for `descriptor` and checks the data served at every
+/// location of every endpoint against `expected_data`.
+///
+/// Panics if the server returns no endpoints, an endpoint without a ticket, or an
+/// endpoint without locations.
+async fn verify_data(
+    mut client: Client,
+    descriptor: FlightDescriptor,
+    expected_schema: SchemaRef,
+    expected_data: &[RecordBatch],
+) -> Result {
+    let info = client
+        .get_flight_info(Request::new(descriptor))
+        .await?
+        .into_inner();
+
+    assert!(
+        !info.endpoint.is_empty(),
+        "No endpoints returned from Flight server",
+    );
+
+    for endpoint in info.endpoint {
+        let locations = endpoint.location;
+        let ticket = endpoint
+            .ticket
+            .expect("No ticket returned from Flight server");
+
+        assert!(
+            !locations.is_empty(),
+            "No locations returned from Flight server",
+        );
+
+        for location in locations {
+            consume_flight_location(
+                location,
+                ticket.clone(),
+                expected_data,
+                expected_schema.clone(),
+            )
+            .await?;
+        }
+    }
+
+    Ok(())
+}
+
+/// Connects to `location`, fetches `ticket` with `do_get`, and checks that the
+/// returned stream contains exactly `expected_data`.
+///
+/// Each received batch must carry its index as `app_metadata` and match the
+/// expected batch in schema, column count, row count and column data.
+///
+/// # Errors
+///
+/// Returns an error if connecting, the `do_get` call, or the leading schema
+/// message fails.
+///
+/// # Panics
+///
+/// Panics if the stream is empty, or if the received batches differ from
+/// `expected_data` in number or content.
+async fn consume_flight_location(
+    location: Location,
+    ticket: Ticket,
+    expected_data: &[RecordBatch],
+    schema: SchemaRef,
+) -> Result {
+    // The other Flight implementations use the `grpc+tcp` scheme, but the Rust http libs
+    // don't recognize this as valid.
+    let uri = location.uri.replace("grpc+tcp://", "grpc://");
+
+    let mut client = FlightServiceClient::connect(uri).await?;
+    let resp = client.do_get(ticket).await?;
+    let mut resp = resp.into_inner();
+
+    // We already have the schema from the FlightInfo, but the server sends it again as the
+    // first FlightData. Its content is ignored, but a gRPC error arriving in its place is
+    // propagated rather than silently discarded.
+    let _schema_again = resp
+        .next()
+        .await
+        .expect("No schema message received from Flight server")?;
+
+    let mut dictionaries_by_field = vec![None; schema.fields().len()];
+
+    for (counter, expected_batch) in expected_data.iter().enumerate() {
+        let data = receive_batch_flight_data(
+            &mut resp,
+            schema.clone(),
+            &mut dictionaries_by_field,
+        )
+        .await
+        .unwrap_or_else(|| {
+            panic!(
+                "Got fewer batches than expected, received so far: {} expected: {}",
+                counter,
+                expected_data.len(),
+            )
+        });
+
+        let metadata = counter.to_string().into_bytes();
+        assert_eq!(metadata, data.app_metadata);
+
+        let actual_batch =
+            flight_data_to_arrow_batch(&data, schema.clone(), &dictionaries_by_field)
+                .expect("Unable to convert flight data to Arrow batch");
+
+        assert_eq!(expected_batch.schema(), actual_batch.schema());
+        assert_eq!(expected_batch.num_columns(), actual_batch.num_columns());
+        assert_eq!(expected_batch.num_rows(), actual_batch.num_rows());
+
+        // Compare column by column so that a mismatch names the offending field.
+        let batch_schema = expected_batch.schema();
+        for i in 0..expected_batch.num_columns() {
+            let field_name = batch_schema.field(i).name();
+
+            let expected_column = expected_batch.column(i).data();
+            let actual_column = actual_batch.column(i).data();
+
+            assert_eq!(
+                expected_column, actual_column,
+                "Data for field {}",
+                field_name
+            );
+        }
+    }
+
+    assert!(
+        resp.next().await.is_none(),
+        "Got more batches than the expected: {}",
+        expected_data.len(),
+    );
+
+    Ok(())
+}
+
+/// Reads the next non-dictionary `FlightData` message from `resp`.
+///
+/// Dictionary-batch messages that come first are decoded into
+/// `dictionaries_by_field` and skipped. Returns `None` if the stream ends before
+/// such a message arrives.
+///
+/// # Panics
+///
+/// Panics if the server reports an error, or if a message header or dictionary
+/// cannot be parsed. A server error is reported as such instead of being folded
+/// into `None`, so the caller does not mistake it for a stream that ended early.
+async fn receive_batch_flight_data(
+    resp: &mut Streaming<FlightData>,
+    schema: SchemaRef,
+    dictionaries_by_field: &mut [Option<ArrayRef>],
+) -> Option<FlightData> {
+    let mut data = resp
+        .next()
+        .await?
+        .expect("Error receiving flight data");
+    let mut message = arrow::ipc::root_as_message(&data.data_header[..])
+        .expect("Error parsing first message");
+
+    while message.header_type() == ipc::MessageHeader::DictionaryBatch {
+        reader::read_dictionary(
+            &data.data_body,
+            message
+                .header_as_dictionary_batch()
+                .expect("Error parsing dictionary"),
+            &schema,
+            dictionaries_by_field,
+        )
+        .expect("Error reading dictionary");
+
+        data = resp
+            .next()
+            .await?
+            .expect("Error receiving flight data");
+        message = arrow::ipc::root_as_message(&data.data_header[..])
+            .expect("Error parsing message");
+    }
+
+    Some(data)
+}
diff --git a/rust/integration-testing/src/flight_client_scenarios/middleware.rs b/rust/integration-testing/src/flight_client_scenarios/middleware.rs
new file mode 100644
index 0000000..607eab1
--- /dev/null
+++ b/rust/integration-testing/src/flight_client_scenarios/middleware.rs
@@ -0,0 +1,82 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow_flight::{
+ flight_descriptor::DescriptorType, flight_service_client::FlightServiceClient,
+ FlightDescriptor,
+};
+use tonic::{Request, Status};
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Runs the `middleware` client scenario against the server at `host:port`.
+///
+/// Every request goes through `middleware_interceptor`. The scenario makes one
+/// `get_flight_info` call that must fail and one that must succeed, and requires
+/// the `x-middleware: expected value` header in the metadata of both outcomes.
+pub async fn run_scenario(host: &str, port: &str) -> Result {
+    let url = format!("http://{}:{}", host, port);
+    let conn = tonic::transport::Endpoint::new(url)?.connect().await?;
+    let mut client = FlightServiceClient::with_interceptor(conn, middleware_interceptor);
+
+    let mut descriptor = FlightDescriptor::default();
+    descriptor.set_type(DescriptorType::Cmd);
+    descriptor.cmd = b"".to_vec();
+
+    // This call is expected to fail.
+    let failure = match client
+        .get_flight_info(Request::new(descriptor.clone()))
+        .await
+    {
+        Err(status) => status,
+        Ok(_) => return Err(Box::new(Status::internal("Expected call to fail"))),
+    };
+    let failure_header = failure
+        .metadata()
+        .get("x-middleware")
+        .map(|v| v.to_str().unwrap())
+        .unwrap_or("");
+    if failure_header != "expected value" {
+        let msg = format!(
+            "On failing call: Expected to receive header 'x-middleware: expected value', \
+             but instead got: '{}'",
+            failure_header
+        );
+        return Err(Box::new(Status::internal(msg)));
+    }
+
+    // This call should succeed
+    descriptor.cmd = b"success".to_vec();
+    let response = client.get_flight_info(Request::new(descriptor)).await?;
+    let success_header = response
+        .metadata()
+        .get("x-middleware")
+        .map(|v| v.to_str().unwrap())
+        .unwrap_or("");
+    if success_header != "expected value" {
+        let msg = format!(
+            "On success call: Expected to receive header 'x-middleware: expected value', \
+             but instead got: '{}'",
+            success_header
+        );
+        return Err(Box::new(Status::internal(msg)));
+    }
+
+    Ok(())
+}
+
+/// Client interceptor that adds the `x-middleware: expected value` header to the
+/// request it is given.
+fn middleware_interceptor(mut req: Request<()>) -> Result<Request<()>, Status> {
+    let header_value = "expected value".parse().unwrap();
+    req.metadata_mut().insert("x-middleware", header_value);
+    Ok(req)
+}
diff --git a/rust/integration-testing/src/flight_server_scenarios.rs b/rust/integration-testing/src/flight_server_scenarios.rs
new file mode 100644
index 0000000..3d99e53
--- /dev/null
+++ b/rust/integration-testing/src/flight_server_scenarios.rs
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::net::SocketAddr;
+
+use arrow_flight::{FlightEndpoint, Location, Ticket};
+use tokio::net::TcpListener;
+
+pub mod auth_basic_proto;
+pub mod integration_test;
+pub mod middleware;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Binds a TCP listener on `0.0.0.0:<port>` and prints the bound port to stdout.
+///
+/// Returns the listener together with the local address reported by the listener.
+pub async fn listen_on(port: &str) -> Result<(TcpListener, SocketAddr)> {
+    let requested: SocketAddr = format!("0.0.0.0:{}", port).parse()?;
+    let listener = TcpListener::bind(requested).await?;
+
+    let bound = listener.local_addr()?;
+    println!("Server listening on localhost:{}", bound.port());
+
+    Ok((listener, bound))
+}
+
+/// Builds a `FlightEndpoint` with the single location `location_uri` and a ticket
+/// holding the bytes of `ticket`.
+pub fn endpoint(ticket: &str, location_uri: impl Into<String>) -> FlightEndpoint {
+    let ticket = Ticket {
+        ticket: ticket.as_bytes().to_vec(),
+    };
+    let location = Location {
+        uri: location_uri.into(),
+    };
+
+    FlightEndpoint {
+        ticket: Some(ticket),
+        location: vec![location],
+    }
+}
diff --git a/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs
new file mode 100644
index 0000000..355209f
--- /dev/null
+++ b/rust/integration-testing/src/flight_server_scenarios/auth_basic_proto.rs
@@ -0,0 +1,226 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::pin::Pin;
+use std::sync::Arc;
+
+use arrow_flight::{
+ flight_service_server::FlightService, flight_service_server::FlightServiceServer,
+ Action, ActionType, BasicAuth, Criteria, Empty, FlightData, FlightDescriptor,
+ FlightInfo, HandshakeRequest, HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
+use tokio::sync::Mutex;
+use tonic::{
+ metadata::MetadataMap, transport::Server, Request, Response, Status, Streaming,
+};
+
+type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
+
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+use prost::Message;
+
+use crate::{AUTH_PASSWORD, AUTH_USERNAME};
+
+/// Starts the `auth:basic_proto` scenario server on `port` and serves incoming
+/// connections until the server future completes.
+pub async fn scenario_setup(port: &str) -> Result {
+    let (mut listener, _) = super::listen_on(port).await?;
+
+    let svc = FlightServiceServer::new(AuthBasicProtoScenarioImpl {
+        username: AUTH_USERNAME.into(),
+        password: AUTH_PASSWORD.into(),
+        peer_identity: Arc::new(Mutex::new(None)),
+    });
+
+    let server = Server::builder()
+        .add_service(svc)
+        .serve_with_incoming(listener.incoming());
+    server.await?;
+    Ok(())
+}
+
+/// Flight service used by the `auth:basic_proto` scenario.
+#[derive(Clone)]
+pub struct AuthBasicProtoScenarioImpl {
+ /// Username accepted by the handshake; it is also the token value that
+ /// `is_valid` compares against.
+ username: Arc<str>,
+ /// Password accepted by the handshake.
+ password: Arc<str>,
+ /// Set to `None` by `scenario_setup`. NOTE(review): no read of this field is
+ /// visible in this file section - confirm whether it is still needed.
+ peer_identity: Arc<Mutex<Option<String>>>,
+}
+
+impl AuthBasicProtoScenarioImpl {
+    /// Reads the token from the binary `auth-token-bin` metadata entry and
+    /// validates it. A missing entry, or one that cannot be decoded to UTF-8, is
+    /// treated as having no token.
+    async fn check_auth(
+        &self,
+        metadata: &MetadataMap,
+    ) -> Result<GrpcServerCallContext, Status> {
+        let token = match metadata.get_bin("auth-token-bin") {
+            Some(value) => match value.to_bytes() {
+                Ok(bytes) => String::from_utf8(bytes.to_vec()).ok(),
+                Err(_) => None,
+            },
+            None => None,
+        };
+        self.is_valid(token).await
+    }
+
+    /// Accepts `token` only when it equals the configured username; any other
+    /// value, or no token at all, yields an `unauthenticated` status.
+    async fn is_valid(
+        &self,
+        token: Option<String>,
+    ) -> Result<GrpcServerCallContext, Status> {
+        let is_known_user = token.map_or(false, |t| t == *self.username);
+        if is_known_user {
+            Ok(GrpcServerCallContext {
+                peer_identity: self.username.to_string(),
+            })
+        } else {
+            Err(Status::unauthenticated("Invalid token"))
+        }
+    }
+}
+
+/// Result of a successful authentication check.
+struct GrpcServerCallContext {
+ /// Filled with the service's username by `is_valid`.
+ peer_identity: String,
+}
+
+impl GrpcServerCallContext {
+    /// Returns the stored peer identity.
+    pub fn peer_identity(&self) -> &str {
+        self.peer_identity.as_str()
+    }
+}
+
+/// Shared reply for the RPCs that this scenario does not implement.
+fn not_yet_implemented<T>() -> Result<Response<T>, Status> {
+    Err(Status::unimplemented("Not yet implemented"))
+}
+
+/// Flight service for the `auth:basic_proto` scenario.
+///
+/// Every RPC except `handshake` first checks the `auth-token-bin` metadata via
+/// `check_auth`. Only `handshake` and `do_action` do real work; the remaining RPCs
+/// answer `unimplemented` once authentication has passed.
+#[tonic::async_trait]
+impl FlightService for AuthBasicProtoScenarioImpl {
+    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
+    type DoGetStream = TonicStream<Result<FlightData, Status>>;
+    type DoPutStream = TonicStream<Result<PutResult, Status>>;
+    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
+    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
+
+    async fn get_schema(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    async fn do_get(
+        &self,
+        request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    /// Answers each `HandshakeRequest` with the username as payload when the
+    /// decoded `BasicAuth` credentials match, and with `unauthenticated` otherwise.
+    async fn handshake(
+        &self,
+        request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        let (mut tx, rx) = mpsc::channel(10);
+
+        let username = self.username.clone();
+        let password = self.password.clone();
+
+        // The requests are answered on a spawned task; the receiving half of the
+        // channel is handed back as the response stream.
+        tokio::spawn(async move {
+            let mut requests = request.into_inner();
+
+            while let Some(req) = requests.next().await {
+                let req = req.expect("Error reading handshake request");
+                let HandshakeRequest { payload, .. } = req;
+
+                let auth = BasicAuth::decode(&*payload)
+                    .expect("Error parsing handshake request");
+
+                let credentials_match =
+                    *auth.username == *username && *auth.password == *password;
+                let resp = if credentials_match {
+                    Ok(HandshakeResponse {
+                        payload: username.as_bytes().to_vec(),
+                        ..HandshakeResponse::default()
+                    })
+                } else {
+                    Err(Status::unauthenticated(format!(
+                        "Don't know user {}",
+                        auth.username
+                    )))
+                };
+
+                tx.send(resp)
+                    .await
+                    .expect("Error sending handshake response");
+            }
+        });
+
+        Ok(Response::new(Box::pin(rx)))
+    }
+
+    async fn list_flights(
+        &self,
+        request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    async fn do_put(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    /// Replies with a single result whose body is the authenticated peer identity.
+    async fn do_action(
+        &self,
+        request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        let context = self.check_auth(request.metadata()).await?;
+        // Respond with the authenticated username.
+        let result = arrow_flight::Result {
+            body: context.peer_identity().as_bytes().to_vec(),
+        };
+        let output = futures::stream::once(async move { Ok(result) });
+        Ok(Response::new(Box::pin(output) as Self::DoActionStream))
+    }
+
+    async fn list_actions(
+        &self,
+        request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+
+    async fn do_exchange(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        self.check_auth(request.metadata()).await?;
+        not_yet_implemented()
+    }
+}
diff --git a/rust/integration-testing/src/flight_server_scenarios/integration_test.rs b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs
new file mode 100644
index 0000000..a555b2e
--- /dev/null
+++ b/rust/integration-testing/src/flight_server_scenarios/integration_test.rs
@@ -0,0 +1,385 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::collections::HashMap;
+use std::convert::TryFrom;
+use std::pin::Pin;
+use std::sync::Arc;
+
+use arrow::{
+ array::ArrayRef,
+ datatypes::Schema,
+ datatypes::SchemaRef,
+ ipc::{self, reader},
+ record_batch::RecordBatch,
+};
+use arrow_flight::{
+ flight_descriptor::DescriptorType, flight_service_server::FlightService,
+ flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty,
+ FlightData, FlightDescriptor, FlightEndpoint, FlightInfo, HandshakeRequest,
+ HandshakeResponse, PutResult, SchemaResult, Ticket,
+};
+use futures::{channel::mpsc, sink::SinkExt, Stream, StreamExt};
+use tokio::sync::Mutex;
+use tonic::{transport::Server, Request, Response, Status, Streaming};
+
+/// Pinned, boxed, thread-safe stream; used as the associated stream type of
+/// every streaming RPC in the `FlightService` implementation below.
+type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
+
+/// Boxed, thread-safe error used as the default error of this module's `Result`.
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+/// `Result` alias whose success type defaults to `()` and whose error type
+/// defaults to [`Error`]; RPC methods override the error with `Status`.
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Starts the integration-test Flight server on `port` and serves until the
+/// server stops.
+///
+/// The address returned by `super::listen_on` is recorded (as a `grpc+tcp://`
+/// URI) in the service so that it can be handed to `super::endpoint` when
+/// building `FlightInfo` responses.
+pub async fn scenario_setup(port: &str) -> Result {
+    let (mut listener, addr) = super::listen_on(port).await?;
+    let server_location = format!("grpc+tcp://{}", addr);
+
+    let flight_service = FlightServiceServer::new(FlightServiceImpl {
+        server_location,
+        ..Default::default()
+    });
+
+    let incoming = listener.incoming();
+    Server::builder()
+        .add_service(flight_service)
+        .serve_with_incoming(incoming)
+        .await?;
+
+    Ok(())
+}
+
+/// One uploaded flight: the schema taken from the first DoPut message and the
+/// record batches decoded from the messages that followed it.
+#[derive(Debug, Clone)]
+struct IntegrationDataset {
+ /// Schema sent back first by `do_get` and reported by `get_flight_info`.
+ schema: Schema,
+ /// Record batches in upload order.
+ chunks: Vec<RecordBatch>,
+}
+
+/// State of the integration-test Flight service.
+#[derive(Clone, Default)]
+pub struct FlightServiceImpl {
+ /// Location of this server; `scenario_setup` fills it in as
+ /// `grpc+tcp://<listen address>`.
+ server_location: String,
+ /// Uploaded flights, keyed by the first path element of the descriptor
+ /// used in DoPut (and by the ticket bytes in DoGet).
+ uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
+}
+
+impl FlightServiceImpl {
+ /// Builds the `FlightEndpoint` for `path` by passing it, together with
+ /// this server's location, to `super::endpoint`.
+ fn endpoint_from_path(&self, path: &str) -> FlightEndpoint {
+ super::endpoint(path, &self.server_location)
+ }
+}
+
+// Flight service for the main integration scenario. Only `do_put`,
+// `get_flight_info` and `do_get` do real work; every other RPC answers with an
+// `unimplemented` status.
+#[tonic::async_trait]
+impl FlightService for FlightServiceImpl {
+    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
+    type DoGetStream = TonicStream<Result<FlightData, Status>>;
+    type DoPutStream = TonicStream<Result<PutResult, Status>>;
+    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
+    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
+
+    /// Not needed by the scenario.
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Streams back a previously uploaded flight.
+    ///
+    /// The ticket bytes, read as UTF-8, are the key of the flight. The reply
+    /// is one schema message followed, for every stored batch, by its
+    /// dictionary messages and then the batch message itself.
+    async fn do_get(
+        &self,
+        request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        let ticket = request.into_inner();
+
+        let key = match String::from_utf8(ticket.ticket.to_vec()) {
+            Ok(key) => key,
+            Err(e) => {
+                return Err(Status::invalid_argument(format!(
+                    "Invalid ticket: {:?}",
+                    e
+                )))
+            }
+        };
+
+        let uploaded_chunks = self.uploaded_chunks.lock().await;
+
+        let flight = match uploaded_chunks.get(&key) {
+            Some(flight) => flight,
+            None => {
+                return Err(Status::not_found(format!(
+                    "Could not find flight. {}",
+                    key
+                )))
+            }
+        };
+
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
+
+        // The schema always goes first.
+        let mut messages: Vec<Result<FlightData, Status>> = Vec::new();
+        messages.push(Ok(arrow_flight::utils::flight_data_from_arrow_schema(
+            &flight.schema,
+            &options,
+        )));
+
+        for (counter, batch) in flight.chunks.iter().enumerate() {
+            let (dictionary_flight_data, mut batch_flight_data) =
+                arrow_flight::utils::flight_data_from_arrow_batch(batch, &options);
+
+            // Only the record batch's FlightData gets app_metadata
+            batch_flight_data.app_metadata = counter.to_string().into_bytes();
+
+            messages.extend(dictionary_flight_data.into_iter().map(Ok));
+            messages.push(Ok(batch_flight_data));
+        }
+
+        let output = futures::stream::iter(messages);
+
+        Ok(Response::new(Box::pin(output) as Self::DoGetStream))
+    }
+
+    /// Not needed by the scenario.
+    async fn handshake(
+        &self,
+        _request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Describes an uploaded flight addressed by a path descriptor.
+    ///
+    /// Only `Path` descriptors are supported, and only the first path element
+    /// is used as the key. The total row count is reported; the byte count is
+    /// reported as `-1`.
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        let descriptor = request.into_inner();
+
+        if descriptor.r#type != DescriptorType::Path as i32 {
+            return Err(Status::unimplemented(format!(
+                "Request type: {}",
+                descriptor.r#type
+            )));
+        }
+
+        let path = &descriptor.path;
+        if path.is_empty() {
+            return Err(Status::invalid_argument("Invalid path"));
+        }
+        let flight_name = &path[0];
+
+        let uploaded_chunks = self.uploaded_chunks.lock().await;
+        let flight = match uploaded_chunks.get(flight_name) {
+            Some(flight) => flight,
+            None => {
+                return Err(Status::not_found(format!(
+                    "Could not find flight. {}",
+                    flight_name
+                )))
+            }
+        };
+
+        let endpoint = self.endpoint_from_path(flight_name);
+
+        let total_records = flight
+            .chunks
+            .iter()
+            .fold(0usize, |rows, chunk| rows + chunk.num_rows());
+
+        let options = arrow::ipc::writer::IpcWriteOptions::default();
+        let schema = arrow_flight::utils::ipc_message_from_arrow_schema(
+            &flight.schema,
+            &options,
+        )
+        .expect(
+            "Could not generate schema bytes from schema stored by a DoPut; \
+             this should be impossible",
+        );
+
+        Ok(Response::new(FlightInfo {
+            schema,
+            flight_descriptor: Some(descriptor.clone()),
+            endpoint: vec![endpoint],
+            total_records: total_records as i64,
+            total_bytes: -1,
+        }))
+    }
+
+    /// Accepts an upload.
+    ///
+    /// The first message must carry a `Path` descriptor (its first element
+    /// becomes the key) and the schema. The remaining messages are consumed by
+    /// `save_uploaded_chunks` on a spawned task, which reports progress and
+    /// errors through the returned stream.
+    async fn do_put(
+        &self,
+        request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        let mut input_stream = request.into_inner();
+        let first_message = input_stream
+            .message()
+            .await?
+            .ok_or_else(|| Status::invalid_argument("Must send some FlightData"))?;
+
+        let descriptor = first_message
+            .flight_descriptor
+            .clone()
+            .ok_or_else(|| Status::invalid_argument("Must have a descriptor"))?;
+
+        let is_path = descriptor.r#type == DescriptorType::Path as i32;
+        if !is_path || descriptor.path.is_empty() {
+            return Err(Status::invalid_argument("Must specify a path"));
+        }
+
+        let key = descriptor.path[0].clone();
+
+        let schema = match Schema::try_from(&first_message) {
+            Ok(schema) => schema,
+            Err(e) => {
+                return Err(Status::invalid_argument(format!(
+                    "Invalid schema: {:?}",
+                    e
+                )))
+            }
+        };
+        let schema_ref = Arc::new(schema.clone());
+
+        let (response_tx, response_rx) = mpsc::channel(10);
+
+        let uploaded_chunks = self.uploaded_chunks.clone();
+
+        tokio::spawn(async move {
+            // Keep a second sender so a failure can still be reported after
+            // `response_tx` has been handed over.
+            let mut error_tx = response_tx.clone();
+            let outcome = save_uploaded_chunks(
+                uploaded_chunks,
+                schema_ref,
+                input_stream,
+                response_tx,
+                schema,
+                key,
+            )
+            .await;
+
+            if let Err(e) = outcome {
+                error_tx.send(Err(e)).await.expect("Error sending error");
+            }
+        });
+
+        Ok(Response::new(Box::pin(response_rx) as Self::DoPutStream))
+    }
+
+    /// Not needed by the scenario.
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+}
+
+/// Sends a `PutResult` carrying a copy of `app_metadata` through `tx`.
+///
+/// A failed send is reported as an `internal` status.
+async fn send_app_metadata(
+    tx: &mut mpsc::Sender<Result<PutResult, Status>>,
+    app_metadata: &[u8],
+) -> Result<(), Status> {
+    let put_result = PutResult {
+        app_metadata: app_metadata.to_vec(),
+    };
+
+    match tx.send(Ok(put_result)).await {
+        Ok(()) => Ok(()),
+        Err(e) => Err(Status::internal(format!(
+            "Could not send PutResult: {:?}",
+            e
+        ))),
+    }
+}
+
+/// Decodes an IPC record-batch message into a `RecordBatch`.
+///
+/// Returns an `internal` status if the message header is not a record batch
+/// or if the body cannot be read with the given schema and dictionaries.
+async fn record_batch_from_message(
+    message: ipc::Message<'_>,
+    data_body: &[u8],
+    schema_ref: SchemaRef,
+    dictionaries_by_field: &[Option<ArrayRef>],
+) -> Result<RecordBatch, Status> {
+    let ipc_batch = match message.header_as_record_batch() {
+        Some(ipc_batch) => ipc_batch,
+        None => {
+            return Err(Status::internal(
+                "Could not parse message header as record batch",
+            ))
+        }
+    };
+
+    match reader::read_record_batch(
+        data_body,
+        ipc_batch,
+        schema_ref,
+        dictionaries_by_field,
+    ) {
+        Ok(batch) => Ok(batch),
+        Err(e) => Err(Status::internal(format!(
+            "Could not convert to RecordBatch: {:?}",
+            e
+        ))),
+    }
+}
+
+/// Decodes an IPC dictionary-batch message and records the dictionary in
+/// `dictionaries_by_field`.
+///
+/// Returns an `internal` status if the message header is not a dictionary
+/// batch or if the dictionary cannot be read.
+async fn dictionary_from_message(
+    message: ipc::Message<'_>,
+    data_body: &[u8],
+    schema_ref: SchemaRef,
+    dictionaries_by_field: &mut [Option<ArrayRef>],
+) -> Result<(), Status> {
+    let ipc_batch = match message.header_as_dictionary_batch() {
+        Some(ipc_batch) => ipc_batch,
+        None => {
+            return Err(Status::internal(
+                "Could not parse message header as dictionary batch",
+            ))
+        }
+    };
+
+    match reader::read_dictionary(
+        data_body,
+        ipc_batch,
+        &schema_ref,
+        dictionaries_by_field,
+    ) {
+        Ok(()) => Ok(()),
+        Err(e) => Err(Status::internal(format!(
+            "Could not convert to Dictionary: {:?}",
+            e
+        ))),
+    }
+}
+
+/// Consumes the remainder of a DoPut stream and stores the decoded batches
+/// under `key`.
+///
+/// For every record-batch message the message's `app_metadata` is echoed back
+/// through `response_tx` before the batch is decoded. Dictionary messages
+/// update the per-field dictionaries used to decode later batches.
+///
+/// # Errors
+///
+/// Returns the `Status` produced by the stream itself if the client's stream
+/// fails, and an `internal` status if a message cannot be parsed, is a schema
+/// message, or has an unsupported header type. Nothing is stored in that case.
+async fn save_uploaded_chunks(
+    uploaded_chunks: Arc<Mutex<HashMap<String, IntegrationDataset>>>,
+    schema_ref: Arc<Schema>,
+    mut input_stream: Streaming<FlightData>,
+    mut response_tx: mpsc::Sender<Result<PutResult, Status>>,
+    schema: Schema,
+    key: String,
+) -> Result<(), Status> {
+    let mut chunks = vec![];
+
+    let mut dictionaries_by_field = vec![None; schema_ref.fields().len()];
+
+    while let Some(data) = input_stream.next().await {
+        // A failed stream must not be mistaken for a finished one: propagate
+        // the error instead of storing whatever arrived before it.
+        let data = data?;
+
+        let message = arrow::ipc::root_as_message(&data.data_header[..])
+            .map_err(|e| Status::internal(format!("Could not parse message: {:?}", e)))?;
+
+        match message.header_type() {
+            ipc::MessageHeader::Schema => {
+                return Err(Status::internal(
+                    "Not expecting a schema when messages are read",
+                ))
+            }
+            ipc::MessageHeader::RecordBatch => {
+                send_app_metadata(&mut response_tx, &data.app_metadata).await?;
+
+                let batch = record_batch_from_message(
+                    message,
+                    &data.data_body,
+                    schema_ref.clone(),
+                    &dictionaries_by_field,
+                )
+                .await?;
+
+                chunks.push(batch);
+            }
+            ipc::MessageHeader::DictionaryBatch => {
+                dictionary_from_message(
+                    message,
+                    &data.data_body,
+                    schema_ref.clone(),
+                    &mut dictionaries_by_field,
+                )
+                .await?;
+            }
+            t => {
+                return Err(Status::internal(format!(
+                    "Reading types other than record batches not yet supported, \
+                     unable to read {:?}",
+                    t
+                )));
+            }
+        }
+    }
+
+    let dataset = IntegrationDataset { schema, chunks };
+
+    // Take the lock only for the insertion so that other RPCs are not blocked
+    // for as long as the client keeps streaming.
+    uploaded_chunks.lock().await.insert(key, dataset);
+
+    Ok(())
+}
diff --git a/rust/integration-testing/src/flight_server_scenarios/middleware.rs b/rust/integration-testing/src/flight_server_scenarios/middleware.rs
new file mode 100644
index 0000000..12421bc
--- /dev/null
+++ b/rust/integration-testing/src/flight_server_scenarios/middleware.rs
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::pin::Pin;
+
+use arrow_flight::{
+ flight_descriptor::DescriptorType, flight_service_server::FlightService,
+ flight_service_server::FlightServiceServer, Action, ActionType, Criteria, Empty,
+ FlightData, FlightDescriptor, FlightInfo, HandshakeRequest, HandshakeResponse,
+ PutResult, SchemaResult, Ticket,
+};
+use futures::Stream;
+use tonic::{transport::Server, Request, Response, Status, Streaming};
+
+/// Pinned, boxed, thread-safe stream; used as the associated stream type of
+/// every streaming RPC in the `FlightService` implementation below.
+type TonicStream<T> = Pin<Box<dyn Stream<Item = T> + Send + Sync + 'static>>;
+
+/// Boxed, thread-safe error used as the default error of this module's `Result`.
+type Error = Box<dyn std::error::Error + Send + Sync + 'static>;
+/// `Result` alias whose success type defaults to `()` and whose error type
+/// defaults to [`Error`]; RPC methods override the error with `Status`.
+type Result<T = (), E = Error> = std::result::Result<T, E>;
+
+/// Starts the middleware-scenario Flight server on `port` and serves until the
+/// server stops. The bound address returned by `super::listen_on` is not used.
+pub async fn scenario_setup(port: &str) -> Result {
+    let (mut listener, _) = super::listen_on(port).await?;
+
+    let flight_service = FlightServiceServer::new(MiddlewareScenarioImpl {});
+
+    let incoming = listener.incoming();
+    Server::builder()
+        .add_service(flight_service)
+        .serve_with_incoming(incoming)
+        .await?;
+
+    Ok(())
+}
+
+/// Flight service for the middleware scenario; it carries no state.
+#[derive(Clone, Default)]
+pub struct MiddlewareScenarioImpl {}
+
+// Flight service for the middleware scenario. Only `get_flight_info` does real
+// work; every other RPC answers with an `unimplemented` status.
+#[tonic::async_trait]
+impl FlightService for MiddlewareScenarioImpl {
+    type HandshakeStream = TonicStream<Result<HandshakeResponse, Status>>;
+    type ListFlightsStream = TonicStream<Result<FlightInfo, Status>>;
+    type DoGetStream = TonicStream<Result<FlightData, Status>>;
+    type DoPutStream = TonicStream<Result<PutResult, Status>>;
+    type DoActionStream = TonicStream<Result<arrow_flight::Result, Status>>;
+    type ListActionsStream = TonicStream<Result<ActionType, Status>>;
+    type DoExchangeStream = TonicStream<Result<FlightData, Status>>;
+
+    /// Not needed by the scenario.
+    async fn get_schema(
+        &self,
+        _request: Request<FlightDescriptor>,
+    ) -> Result<Response<SchemaResult>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn do_get(
+        &self,
+        _request: Request<Ticket>,
+    ) -> Result<Response<Self::DoGetStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn handshake(
+        &self,
+        _request: Request<Streaming<HandshakeRequest>>,
+    ) -> Result<Response<Self::HandshakeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn list_flights(
+        &self,
+        _request: Request<Criteria>,
+    ) -> Result<Response<Self::ListFlightsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Echoes the `x-middleware` request header back to the caller.
+    ///
+    /// A `Cmd` descriptor whose command is `success` gets a `FlightInfo` reply;
+    /// anything else gets an `unknown` status. In both cases the header value,
+    /// when the request carried one, is copied into the reply's metadata.
+    async fn get_flight_info(
+        &self,
+        request: Request<FlightDescriptor>,
+    ) -> Result<Response<FlightInfo>, Status> {
+        let middleware_header = request.metadata().get("x-middleware").cloned();
+
+        let descriptor = request.into_inner();
+
+        let is_success_cmd = descriptor.r#type == DescriptorType::Cmd as i32
+            && descriptor.cmd == b"success";
+
+        if !is_success_cmd {
+            let mut status = Status::unknown("Unknown");
+            if let Some(value) = middleware_header {
+                status.metadata_mut().insert("x-middleware", value);
+            }
+            return Err(status);
+        }
+
+        // Return a fake location - the test doesn't read it
+        let endpoint = super::endpoint("foo", "grpc+tcp://localhost:10010");
+
+        let mut response = Response::new(FlightInfo {
+            flight_descriptor: Some(descriptor),
+            endpoint: vec![endpoint],
+            ..Default::default()
+        });
+        if let Some(value) = middleware_header {
+            response.metadata_mut().insert("x-middleware", value);
+        }
+
+        Ok(response)
+    }
+
+    /// Not needed by the scenario.
+    async fn do_put(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoPutStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn do_action(
+        &self,
+        _request: Request<Action>,
+    ) -> Result<Response<Self::DoActionStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn list_actions(
+        &self,
+        _request: Request<Empty>,
+    ) -> Result<Response<Self::ListActionsStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+
+    /// Not needed by the scenario.
+    async fn do_exchange(
+        &self,
+        _request: Request<Streaming<FlightData>>,
+    ) -> Result<Response<Self::DoExchangeStream>, Status> {
+        Err(Status::unimplemented("Not yet implemented"))
+    }
+}
diff --git a/rust/integration-testing/src/lib.rs b/rust/integration-testing/src/lib.rs
index 596017a..1a96f44 100644
--- a/rust/integration-testing/src/lib.rs
+++ b/rust/integration-testing/src/lib.rs
@@ -16,3 +16,586 @@
// under the License.
//! Common code used in the integration test binaries
+
+use hex::decode;
+use serde_json::Value;
+
+use arrow::util::integration_util::ArrowJsonBatch;
+
+use arrow::array::*;
+use arrow::datatypes::{DataType, DateUnit, Field, IntervalUnit, Schema};
+use arrow::error::{ArrowError, Result};
+use arrow::record_batch::RecordBatch;
+use arrow::{
+ buffer::Buffer,
+ buffer::MutableBuffer,
+ datatypes::ToByteSlice,
+ util::{bit_util, integration_util::*},
+};
+
+use std::collections::HashMap;
+use std::fs::File;
+use std::io::BufReader;
+use std::sync::Arc;
+
+/// The expected username for the basic auth integration test.
+pub const AUTH_USERNAME: &str = "arrow";
+/// The expected password for the basic auth integration test.
+pub const AUTH_PASSWORD: &str = "flight";
+
+pub mod flight_client_scenarios;
+pub mod flight_server_scenarios;
+
+/// Contents of an integration-test JSON file as produced by `read_json_file`.
+pub struct ArrowFile {
+ /// Schema built from the file's `schema` entry.
+ pub schema: Schema,
+ // we can evolve this into a concrete Arrow type
+ // this is temporarily not being read from
+ /// Dictionary batches from the file's `dictionaries` entry, keyed by id.
+ pub _dictionaries: HashMap<i64, ArrowJsonDictionaryBatch>,
+ /// Record batches built from the file's `batches` entry, in file order.
+ pub batches: Vec<RecordBatch>,
+}
+
+/// Reads an integration-test JSON file into an [`ArrowFile`].
+///
+/// The file must contain a `schema` entry and a `batches` array; a
+/// `dictionaries` array is optional.
+///
+/// # Errors
+///
+/// Returns an error if the file cannot be opened, and an
+/// `ArrowError::JsonError` if it is not valid JSON, if `dictionaries` or
+/// `batches` is not an array, or if one of their elements does not have the
+/// expected shape. Errors from building the schema or the record batches are
+/// passed through.
+pub fn read_json_file(json_name: &str) -> Result<ArrowFile> {
+    let json_file = File::open(json_name)?;
+    let reader = BufReader::new(json_file);
+    let arrow_json: Value = serde_json::from_reader(reader).map_err(|e| {
+        ArrowError::JsonError(format!(
+            "Unable to parse {} as JSON: {}",
+            json_name, e
+        ))
+    })?;
+    let schema = Schema::from(&arrow_json["schema"])?;
+
+    // read dictionaries
+    let mut dictionaries = HashMap::new();
+    if let Some(dicts) = arrow_json.get("dictionaries") {
+        let dicts = dicts.as_array().ok_or_else(|| {
+            ArrowError::JsonError("Unable to get dictionaries as array".to_string())
+        })?;
+        for d in dicts {
+            let json_dict: ArrowJsonDictionaryBatch = serde_json::from_value(d.clone())
+                .map_err(|e| {
+                    ArrowError::JsonError(format!(
+                        "Unable to get dictionary from JSON: {}",
+                        e
+                    ))
+                })?;
+            // TODO: convert to a concrete Arrow type
+            dictionaries.insert(json_dict.id, json_dict);
+        }
+    }
+
+    let json_batches = arrow_json["batches"].as_array().ok_or_else(|| {
+        ArrowError::JsonError("Unable to get batches as array".to_string())
+    })?;
+
+    let mut batches = Vec::with_capacity(json_batches.len());
+    for b in json_batches {
+        let json_batch: ArrowJsonBatch =
+            serde_json::from_value(b.clone()).map_err(|e| {
+                ArrowError::JsonError(format!(
+                    "Unable to get record batch from JSON: {}",
+                    e
+                ))
+            })?;
+        let batch = record_batch_from_json(&schema, json_batch, Some(&dictionaries))?;
+        batches.push(batch);
+    }
+
+    Ok(ArrowFile {
+        schema,
+        _dictionaries: dictionaries,
+        batches,
+    })
+}
+
+/// Builds a `RecordBatch` by pairing each schema field with the JSON column in
+/// the same position and converting the pair with `array_from_json`.
+///
+/// Conversion stops at the first column that fails.
+fn record_batch_from_json(
+    schema: &Schema,
+    json_batch: ArrowJsonBatch,
+    json_dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
+) -> Result<RecordBatch> {
+    let columns = schema
+        .fields()
+        .iter()
+        .zip(json_batch.columns)
+        .map(|(field, json_col)| array_from_json(field, json_col, json_dictionaries))
+        .collect::<Result<Vec<_>>>()?;
+
+    RecordBatch::try_new(Arc::new(schema.clone()), columns)
+}
+
+/// Construct an Arrow array from a partially typed JSON column
+///
+/// Flat types are built value by value from the column's `validity` and `data`
+/// entries; list and struct types recurse into the column's children;
+/// dictionary types are delegated to `dictionary_array_from_json`. Types that
+/// are not listed produce an `ArrowError::JsonError`.
+///
+/// Most malformed input (missing `validity`/`data`, a value of the wrong JSON
+/// type, invalid hex) panics rather than returning an error.
+fn array_from_json(
+    field: &Field,
+    json_col: ArrowJsonColumn,
+    dictionaries: Option<&HashMap<i64, ArrowJsonDictionaryBatch>>,
+) -> Result<ArrayRef> {
+    // Every flat type is decoded the same way: walk `validity` and `data` in
+    // lock step, append the converted value where the validity entry is 1 and
+    // a null otherwise, then finish the builder. `$convert` is evaluated with
+    // `$value` bound to the current JSON value.
+    macro_rules! build_flat_array {
+        ($builder:expr, |$value:ident| $convert:expr) => {{
+            let mut b = $builder;
+            for (is_valid, $value) in json_col
+                .validity
+                .as_ref()
+                .unwrap()
+                .iter()
+                .zip(json_col.data.unwrap())
+            {
+                match is_valid {
+                    1 => b.append_value($convert),
+                    _ => b.append_null(),
+                }?;
+            }
+            b.finish()
+        }};
+    }
+
+    match field.data_type() {
+        DataType::Null => Ok(Arc::new(NullArray::new(json_col.count))),
+        DataType::Boolean => Ok(Arc::new(build_flat_array!(
+            BooleanBuilder::new(json_col.count),
+            |value| value.as_bool().unwrap()
+        ))),
+        DataType::Int8 => Ok(Arc::new(build_flat_array!(
+            Int8Builder::new(json_col.count),
+            |value| value.as_i64().ok_or_else(|| {
+                ArrowError::JsonError(format!(
+                    "Unable to get {:?} as int64",
+                    value
+                ))
+            })? as i8
+        ))),
+        DataType::Int16 => Ok(Arc::new(build_flat_array!(
+            Int16Builder::new(json_col.count),
+            |value| value.as_i64().unwrap() as i16
+        ))),
+        DataType::Int32
+        | DataType::Date32(DateUnit::Day)
+        | DataType::Time32(_)
+        | DataType::Interval(IntervalUnit::YearMonth) => {
+            // Built as Int32 first, then cast to the field's actual type.
+            let array = Arc::new(build_flat_array!(
+                Int32Builder::new(json_col.count),
+                |value| value.as_i64().unwrap() as i32
+            )) as ArrayRef;
+            arrow::compute::cast(&array, field.data_type())
+        }
+        DataType::Int64
+        | DataType::Date64(DateUnit::Millisecond)
+        | DataType::Time64(_)
+        | DataType::Timestamp(_, _)
+        | DataType::Duration(_)
+        | DataType::Interval(IntervalUnit::DayTime) => {
+            // Built as Int64 first, then cast to the field's actual type.
+            // Values may be JSON numbers or numeric strings.
+            let array = Arc::new(build_flat_array!(
+                Int64Builder::new(json_col.count),
+                |value| match value {
+                    Value::Number(n) => n.as_i64().unwrap(),
+                    Value::String(s) => {
+                        s.parse().expect("Unable to parse string as i64")
+                    }
+                    _ => panic!("Unable to parse {:?} as number", value),
+                }
+            )) as ArrayRef;
+            arrow::compute::cast(&array, field.data_type())
+        }
+        DataType::UInt8 => Ok(Arc::new(build_flat_array!(
+            UInt8Builder::new(json_col.count),
+            |value| value.as_u64().unwrap() as u8
+        ))),
+        DataType::UInt16 => Ok(Arc::new(build_flat_array!(
+            UInt16Builder::new(json_col.count),
+            |value| value.as_u64().unwrap() as u16
+        ))),
+        DataType::UInt32 => Ok(Arc::new(build_flat_array!(
+            UInt32Builder::new(json_col.count),
+            |value| value.as_u64().unwrap() as u32
+        ))),
+        // Unsigned 64-bit values are read from strings.
+        DataType::UInt64 => Ok(Arc::new(build_flat_array!(
+            UInt64Builder::new(json_col.count),
+            |value| value
+                .as_str()
+                .unwrap()
+                .parse()
+                .expect("Unable to parse string as u64")
+        ))),
+        DataType::Float32 => Ok(Arc::new(build_flat_array!(
+            Float32Builder::new(json_col.count),
+            |value| value.as_f64().unwrap() as f32
+        ))),
+        DataType::Float64 => Ok(Arc::new(build_flat_array!(
+            Float64Builder::new(json_col.count),
+            |value| value.as_f64().unwrap()
+        ))),
+        // Binary values are hex-encoded strings.
+        DataType::Binary => Ok(Arc::new(build_flat_array!(
+            BinaryBuilder::new(json_col.count),
+            |value| &decode(value.as_str().unwrap()).unwrap()
+        ))),
+        DataType::LargeBinary => Ok(Arc::new(build_flat_array!(
+            LargeBinaryBuilder::new(json_col.count),
+            |value| &decode(value.as_str().unwrap()).unwrap()
+        ))),
+        DataType::Utf8 => Ok(Arc::new(build_flat_array!(
+            StringBuilder::new(json_col.count),
+            |value| value.as_str().unwrap()
+        ))),
+        DataType::LargeUtf8 => Ok(Arc::new(build_flat_array!(
+            LargeStringBuilder::new(json_col.count),
+            |value| value.as_str().unwrap()
+        ))),
+        DataType::FixedSizeBinary(len) => Ok(Arc::new(build_flat_array!(
+            FixedSizeBinaryBuilder::new(json_col.count, *len),
+            |value| &hex::decode(value.as_str().unwrap()).unwrap()
+        ))),
+        DataType::List(child_field) => {
+            let null_buf = create_null_buf(&json_col);
+            let children = json_col.children.clone().unwrap();
+            let first_child = children.get(0).unwrap().clone();
+            let child_array = array_from_json(child_field, first_child, dictionaries)?;
+
+            let mut offsets: Vec<i32> = Vec::new();
+            for v in json_col.offset.unwrap().iter() {
+                offsets.push(v.as_i64().unwrap() as i32);
+            }
+
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .offset(0)
+                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(ListArray::from(list_data)))
+        }
+        DataType::LargeList(child_field) => {
+            let null_buf = create_null_buf(&json_col);
+            let children = json_col.children.clone().unwrap();
+            let first_child = children.get(0).unwrap().clone();
+            let child_array = array_from_json(child_field, first_child, dictionaries)?;
+
+            // 64-bit offsets may be JSON numbers or numeric strings.
+            let mut offsets: Vec<i64> = Vec::new();
+            for v in json_col.offset.unwrap().iter() {
+                offsets.push(match v {
+                    Value::Number(n) => n.as_i64().unwrap(),
+                    Value::String(s) => s.parse::<i64>().unwrap(),
+                    _ => panic!("64-bit offset must be either string or number"),
+                });
+            }
+
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .offset(0)
+                .add_buffer(Buffer::from(&offsets.to_byte_slice()))
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(LargeListArray::from(list_data)))
+        }
+        DataType::FixedSizeList(child_field, _) => {
+            let children = json_col.children.clone().unwrap();
+            let first_child = children.get(0).unwrap().clone();
+            let child_array = array_from_json(child_field, first_child, dictionaries)?;
+            let null_buf = create_null_buf(&json_col);
+
+            let list_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .add_child_data(child_array.data())
+                .null_bit_buffer(null_buf)
+                .build();
+            Ok(Arc::new(FixedSizeListArray::from(list_data)))
+        }
+        DataType::Struct(fields) => {
+            // construct struct with null data
+            let null_buf = create_null_buf(&json_col);
+            let mut struct_data = ArrayData::builder(field.data_type().clone())
+                .len(json_col.count)
+                .null_bit_buffer(null_buf);
+
+            // One child array per struct field, in declaration order.
+            for (child_field, child_col) in
+                fields.iter().zip(json_col.children.unwrap())
+            {
+                let child = array_from_json(child_field, child_col, dictionaries)?;
+                struct_data = struct_data.add_child_data(child.data());
+            }
+
+            Ok(Arc::new(StructArray::from(struct_data.build())))
+        }
+        DataType::Dictionary(key_type, value_type) => {
+            let dict_id = field.dict_id().ok_or_else(|| {
+                ArrowError::JsonError(format!(
+                    "Unable to find dict_id for field {:?}",
+                    field
+                ))
+            })?;
+            // find dictionary
+            let dictionary = dictionaries
+                .ok_or_else(|| {
+                    ArrowError::JsonError(format!(
+                        "Unable to find any dictionaries for field {:?}",
+                        field
+                    ))
+                })?
+                .get(&dict_id)
+                .ok_or_else(|| {
+                    ArrowError::JsonError(format!(
+                        "Unable to find dictionary for field {:?}",
+                        field
+                    ))
+                })?;
+
+            dictionary_array_from_json(field, json_col, key_type, value_type, dictionary)
+        }
+        t => Err(ArrowError::JsonError(format!(
+            "data type {:?} not supported",
+            t
+        ))),
+    }
+}
+
+/// Builds a dictionary array for `field` from its JSON key column and the
+/// JSON dictionary batch holding its values.
+///
+/// The keys are decoded as a plain integer array of type `dict_key`, the
+/// values as a nullable array of type `dict_value` taken from the first column
+/// of `dictionary`; both are then combined into a dictionary array.
+///
+/// # Errors
+///
+/// Returns `ArrowError::JsonError` if `dict_key` is not an integer type or if
+/// the dictionary batch has no columns, and passes through errors from
+/// decoding the keys or values.
+///
+/// # Panics
+///
+/// Panics if `field` has no `dict_id` or `dict_is_ordered` value.
+fn dictionary_array_from_json(
+    field: &Field,
+    json_col: ArrowJsonColumn,
+    dict_key: &DataType,
+    dict_value: &DataType,
+    dictionary: &ArrowJsonDictionaryBatch,
+) -> Result<ArrayRef> {
+    match dict_key {
+        DataType::Int8
+        | DataType::Int16
+        | DataType::Int32
+        | DataType::Int64
+        | DataType::UInt8
+        | DataType::UInt16
+        | DataType::UInt32
+        | DataType::UInt64 => {
+            let null_buf = create_null_buf(&json_col);
+
+            // build the key data into a buffer, then construct values separately
+            let key_field = Field::new_dict(
+                "key",
+                dict_key.clone(),
+                field.is_nullable(),
+                field
+                    .dict_id()
+                    .expect("Dictionary fields must have a dict_id value"),
+                field
+                    .dict_is_ordered()
+                    .expect("Dictionary fields must have a dict_is_ordered value"),
+            );
+            let keys = array_from_json(&key_field, json_col, None)?;
+
+            // note: not enough info on nullability of dictionary
+            let value_field = Field::new("value", dict_value.clone(), true);
+            let value_col = dictionary.data.columns.first().ok_or_else(|| {
+                ArrowError::JsonError(format!(
+                    "Dictionary batch for field {:?} has no columns",
+                    field
+                ))
+            })?;
+            let values = array_from_json(&value_field, value_col.clone(), None)?;
+
+            // convert key and value to dictionary data
+            let dict_data = ArrayData::builder(field.data_type().clone())
+                .len(keys.len())
+                .add_buffer(keys.data().buffers()[0].clone())
+                .null_bit_buffer(null_buf)
+                .add_child_data(values.data())
+                .build();
+
+            let array = match dict_key {
+                DataType::Int8 => {
+                    Arc::new(Int8DictionaryArray::from(dict_data)) as ArrayRef
+                }
+                DataType::Int16 => Arc::new(Int16DictionaryArray::from(dict_data)),
+                DataType::Int32 => Arc::new(Int32DictionaryArray::from(dict_data)),
+                DataType::Int64 => Arc::new(Int64DictionaryArray::from(dict_data)),
+                DataType::UInt8 => Arc::new(UInt8DictionaryArray::from(dict_data)),
+                DataType::UInt16 => Arc::new(UInt16DictionaryArray::from(dict_data)),
+                DataType::UInt32 => Arc::new(UInt32DictionaryArray::from(dict_data)),
+                DataType::UInt64 => Arc::new(UInt64DictionaryArray::from(dict_data)),
+                // The outer match only lets integer key types through.
+                _ => unreachable!(),
+            };
+            Ok(array)
+        }
+        _ => Err(ArrowError::JsonError(format!(
+            "Dictionary key type {:?} not supported",
+            dict_key
+        ))),
+    }
+}
+
+/// A helper to create a null (validity) buffer from a JSON column.
+///
+/// The buffer holds one bit per row, rounded up to whole bytes and initially
+/// all unset; bit `i` is set when the `i`-th `validity` entry is nonzero.
+///
+/// # Panics
+///
+/// Panics if the column has no `validity` entry.
+fn create_null_buf(json_col: &ArrowJsonColumn) -> Buffer {
+    let num_bytes = bit_util::ceil(json_col.count, 8);
+    let mut null_buf = MutableBuffer::new(num_bytes).with_bitset(num_bytes, false);
+
+    // Borrow the validity entries rather than cloning them, and take the
+    // mutable view of the buffer once instead of once per entry.
+    let validity = json_col.validity.as_ref().unwrap();
+    let null_slice = null_buf.as_slice_mut();
+    for (i, v) in validity.iter().enumerate() {
+        if *v != 0 {
+            bit_util::set_bit(null_slice, i);
+        }
+    }
+
+    null_buf.into()
+}