You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2022/10/30 14:17:33 UTC
[arrow-datafusion] branch master updated: Add CI check to verify that benchmark queries return the expected results (#4015)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 665107043 Add CI check to verify that benchmark queries return the expected results (#4015)
665107043 is described below
commit 66510704369f9eb5a7d973a530148aee890c1191
Author: Andy Grove <an...@gmail.com>
AuthorDate: Sun Oct 30 08:17:28 2022 -0600
Add CI check to verify that benchmark queries return the expected results (#4015)
---
.github/workflows/rust.yml | 41 +++
Cargo.toml | 29 +-
benchmarks/Cargo.toml | 2 +-
benchmarks/src/bin/tpch.rs | 770 +++++++++++----------------------------------
benchmarks/src/lib.rs | 18 ++
benchmarks/src/tpch.rs | 512 ++++++++++++++++++++++++++++++
6 files changed, 766 insertions(+), 606 deletions(-)
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 4c6daf755..ebb1c07a4 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -117,6 +117,47 @@ jobs:
- name: Verify Working Directory Clean
run: git diff --exit-code
+ # verify that the benchmark queries return the correct results
+ verify-benchmark-results:
+ name: verify benchmark results (amd64)
+ needs: [linux-build-lib]
+ runs-on: ubuntu-latest
+ container:
+ image: amd64/rust
+ env:
+ # Disable full debug symbol generation to speed up CI build and keep memory down
+ # "1" means line tables only, which is useful for panic tracebacks.
+ RUSTFLAGS: "-C debuginfo=1"
+ steps:
+ - uses: actions/checkout@v3
+ with:
+ submodules: true
+ - name: Cache Cargo
+ uses: actions/cache@v3
+ with:
+ path: /github/home/.cargo
+ # this key equals the ones on `linux-build-lib` for re-use
+ key: cargo-cache-
+ - name: Setup Rust toolchain
+ uses: ./.github/actions/setup-builder
+ with:
+ rust-version: stable
+ - name: Generate benchmark data and expected query results
+ run: |
+ mkdir -p benchmarks/data/answers
+ git clone https://github.com/databricks/tpch-dbgen.git
+ cd tpch-dbgen
+ make
+ ./dbgen -f -s 1
+ mv *.tbl ../benchmarks/data
+ mv ./answers/* ../benchmarks/data/answers/
+ - name: Verify that benchmark queries return expected results
+ run: |
+ export TPCH_DATA=`pwd`/benchmarks/data
+ cargo test verify_q --profile release-nonlto --features=ci -- --test-threads=1
+ - name: Verify Working Directory Clean
+ run: git diff --exit-code
+
integration-test:
name: "Compare to postgres"
needs: [linux-build-lib]
diff --git a/Cargo.toml b/Cargo.toml
index ab3f427e4..36a9405b0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -16,21 +16,24 @@
# under the License.
[workspace]
-members = [
- "datafusion/common",
- "datafusion/core",
- "datafusion/expr",
- "datafusion/jit",
- "datafusion/optimizer",
- "datafusion/physical-expr",
- "datafusion/proto",
- "datafusion/row",
- "datafusion/sql",
- "datafusion-examples",
- "benchmarks",
-]
exclude = ["datafusion-cli"]
+members = ["datafusion/common", "datafusion/core", "datafusion/expr", "datafusion/jit", "datafusion/optimizer", "datafusion/physical-expr", "datafusion/proto", "datafusion/row", "datafusion/sql", "datafusion-examples", "benchmarks",
+]
[profile.release]
codegen-units = 1
lto = true
+
+# the release profile takes a long time to build so we can use this profile during development to save time
+# cargo build --profile release-nonlto
+[profile.release-nonlto]
+codegen-units = 16
+debug = false
+debug-assertions = false
+incremental = false
+inherits = "release"
+lto = false
+opt-level = 3
+overflow-checks = false
+panic = 'unwind'
+rpath = false
diff --git a/benchmarks/Cargo.toml b/benchmarks/Cargo.toml
index 7105a6033..8795a8611 100644
--- a/benchmarks/Cargo.toml
+++ b/benchmarks/Cargo.toml
@@ -24,10 +24,10 @@ authors = ["Apache Arrow <de...@arrow.apache.org>"]
homepage = "https://github.com/apache/arrow-datafusion"
repository = "https://github.com/apache/arrow-datafusion"
license = "Apache-2.0"
-publish = false
rust-version = "1.62"
[features]
+ci = []
default = ["mimalloc"]
simd = ["datafusion/simd"]
snmalloc = ["snmalloc-rs"]
diff --git a/benchmarks/src/bin/tpch.rs b/benchmarks/src/bin/tpch.rs
index 02de551f2..b9afe4d6a 100644
--- a/benchmarks/src/bin/tpch.rs
+++ b/benchmarks/src/bin/tpch.rs
@@ -18,7 +18,7 @@
//! Benchmark derived from TPC-H. This is not an official TPC-H benchmark.
use std::{
- fs::{self, File},
+ fs::File,
io::Write,
iter::Iterator,
path::{Path, PathBuf},
@@ -29,15 +29,9 @@ use std::{
use datafusion::datasource::{MemTable, TableProvider};
use datafusion::error::{DataFusionError, Result};
use datafusion::parquet::basic::Compression;
-use datafusion::parquet::file::properties::WriterProperties;
use datafusion::physical_plan::display::DisplayableExecutionPlan;
use datafusion::physical_plan::{collect, displayable};
use datafusion::prelude::*;
-use datafusion::{
- arrow::datatypes::{DataType, Field, Schema},
- datasource::file_format::{csv::CsvFormat, FileFormat},
- DATAFUSION_VERSION,
-};
use datafusion::{
arrow::record_batch::RecordBatch, datasource::file_format::parquet::ParquetFormat,
};
@@ -45,6 +39,11 @@ use datafusion::{
arrow::util::pretty,
datasource::listing::{ListingOptions, ListingTable, ListingTableConfig},
};
+use datafusion::{
+ datasource::file_format::{csv::CsvFormat, FileFormat},
+ DATAFUSION_VERSION,
+};
+use datafusion_benchmarks::tpch::*;
use datafusion::datasource::file_format::csv::DEFAULT_CSV_EXTENSION;
use datafusion::datasource::file_format::parquet::DEFAULT_PARQUET_EXTENSION;
@@ -145,10 +144,6 @@ enum TpchOpt {
Convert(ConvertOpt),
}
-const TABLES: &[&str] = &[
- "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
-];
-
#[tokio::main]
async fn main() -> Result<()> {
use BenchmarkSubCommandOpt::*;
@@ -158,7 +153,32 @@ async fn main() -> Result<()> {
TpchOpt::Benchmark(DataFusionBenchmark(opt)) => {
benchmark_datafusion(opt).await.map(|_| ())
}
- TpchOpt::Convert(opt) => convert_tbl(opt).await,
+ TpchOpt::Convert(opt) => {
+ let compression = match opt.compression.as_str() {
+ "none" => Compression::UNCOMPRESSED,
+ "snappy" => Compression::SNAPPY,
+ "brotli" => Compression::BROTLI,
+ "gzip" => Compression::GZIP,
+ "lz4" => Compression::LZ4,
+ "lz0" => Compression::LZO,
+ "zstd" => Compression::ZSTD,
+ other => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Invalid compression format: {}",
+ other
+ )));
+ }
+ };
+ convert_tbl(
+ opt.input_path.to_str().unwrap(),
+ opt.output_path.to_str().unwrap(),
+ &opt.file_format,
+ opt.partitions,
+ opt.batch_size,
+ compression,
+ )
+ .await
+ }
}
}
@@ -173,7 +193,7 @@ async fn benchmark_datafusion(opt: DataFusionBenchmarkOpt) -> Result<Vec<RecordB
let ctx = SessionContext::with_config(config);
// register tables
- for table in TABLES {
+ for table in TPCH_TABLES {
let table_provider = {
let mut session_state = ctx.state.write();
get_table(
@@ -264,38 +284,6 @@ fn write_summary_json(benchmark_run: &mut BenchmarkRun, path: &Path) -> Result<(
Ok(())
}
-/// Get the SQL statements from the specified query file
-fn get_query_sql(query: usize) -> Result<Vec<String>> {
- if query > 0 && query < 23 {
- let possibilities = vec![
- format!("queries/q{}.sql", query),
- format!("benchmarks/queries/q{}.sql", query),
- ];
- let mut errors = vec![];
- for filename in possibilities {
- match fs::read_to_string(&filename) {
- Ok(contents) => {
- return Ok(contents
- .split(';')
- .map(|s| s.trim())
- .filter(|s| !s.is_empty())
- .map(|s| s.to_string())
- .collect());
- }
- Err(e) => errors.push(format!("{}: {}", filename, e)),
- };
- }
- Err(DataFusionError::Plan(format!(
- "invalid query. Could not find query: {:?}",
- errors
- )))
- } else {
- Err(DataFusionError::Plan(
- "invalid query. Expected value between 1 and 22".to_owned(),
- ))
- }
-}
-
async fn execute_query(
ctx: &SessionContext,
sql: &str,
@@ -335,77 +323,6 @@ async fn execute_query(
Ok(result)
}
-async fn convert_tbl(opt: ConvertOpt) -> Result<()> {
- let output_root_path = Path::new(&opt.output_path);
- for table in TABLES {
- let start = Instant::now();
- let schema = get_schema(table);
-
- let input_path = format!("{}/{}.tbl", opt.input_path.to_str().unwrap(), table);
- let options = CsvReadOptions::new()
- .schema(&schema)
- .has_header(false)
- .delimiter(b'|')
- .file_extension(".tbl");
-
- let config = SessionConfig::new().with_batch_size(opt.batch_size);
- let ctx = SessionContext::with_config(config);
-
- // build plan to read the TBL file
- let mut csv = ctx.read_csv(&input_path, options).await?;
-
- // optionally, repartition the file
- if opt.partitions > 1 {
- csv = csv.repartition(Partitioning::RoundRobinBatch(opt.partitions))?
- }
-
- // create the physical plan
- let csv = csv.to_logical_plan()?;
- let csv = ctx.create_physical_plan(&csv).await?;
-
- let output_path = output_root_path.join(table);
- let output_path = output_path.to_str().unwrap().to_owned();
-
- println!(
- "Converting '{}' to {} files in directory '{}'",
- &input_path, &opt.file_format, &output_path
- );
- match opt.file_format.as_str() {
- "csv" => ctx.write_csv(csv, output_path).await?,
- "parquet" => {
- let compression = match opt.compression.as_str() {
- "none" => Compression::UNCOMPRESSED,
- "snappy" => Compression::SNAPPY,
- "brotli" => Compression::BROTLI,
- "gzip" => Compression::GZIP,
- "lz4" => Compression::LZ4,
- "lz0" => Compression::LZO,
- "zstd" => Compression::ZSTD,
- other => {
- return Err(DataFusionError::NotImplemented(format!(
- "Invalid compression format: {}",
- other
- )));
- }
- };
- let props = WriterProperties::builder()
- .set_compression(compression)
- .build();
- ctx.write_parquet(csv, output_path, Some(props)).await?
- }
- other => {
- return Err(DataFusionError::NotImplemented(format!(
- "Invalid output format: {}",
- other
- )));
- }
- }
- println!("Conversion completed in {} ms", start.elapsed().as_millis());
- }
-
- Ok(())
-}
-
async fn get_table(
ctx: &mut SessionState,
path: &str,
@@ -443,7 +360,7 @@ async fn get_table(
unimplemented!("Invalid file format '{}'", other);
}
};
- let schema = Arc::new(get_schema(table));
+ let schema = Arc::new(get_tpch_table_schema(table));
let options = ListingOptions {
format,
@@ -465,101 +382,6 @@ async fn get_table(
Ok(Arc::new(ListingTable::try_new(config)?))
}
-fn get_schema(table: &str) -> Schema {
- // note that the schema intentionally uses signed integers so that any generated Parquet
- // files can also be used to benchmark tools that only support signed integers, such as
- // Apache Spark
-
- match table {
- "part" => Schema::new(vec![
- Field::new("p_partkey", DataType::Int64, false),
- Field::new("p_name", DataType::Utf8, false),
- Field::new("p_mfgr", DataType::Utf8, false),
- Field::new("p_brand", DataType::Utf8, false),
- Field::new("p_type", DataType::Utf8, false),
- Field::new("p_size", DataType::Int32, false),
- Field::new("p_container", DataType::Utf8, false),
- Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
- Field::new("p_comment", DataType::Utf8, false),
- ]),
-
- "supplier" => Schema::new(vec![
- Field::new("s_suppkey", DataType::Int64, false),
- Field::new("s_name", DataType::Utf8, false),
- Field::new("s_address", DataType::Utf8, false),
- Field::new("s_nationkey", DataType::Int64, false),
- Field::new("s_phone", DataType::Utf8, false),
- Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
- Field::new("s_comment", DataType::Utf8, false),
- ]),
-
- "partsupp" => Schema::new(vec![
- Field::new("ps_partkey", DataType::Int64, false),
- Field::new("ps_suppkey", DataType::Int64, false),
- Field::new("ps_availqty", DataType::Int32, false),
- Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
- Field::new("ps_comment", DataType::Utf8, false),
- ]),
-
- "customer" => Schema::new(vec![
- Field::new("c_custkey", DataType::Int64, false),
- Field::new("c_name", DataType::Utf8, false),
- Field::new("c_address", DataType::Utf8, false),
- Field::new("c_nationkey", DataType::Int64, false),
- Field::new("c_phone", DataType::Utf8, false),
- Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
- Field::new("c_mktsegment", DataType::Utf8, false),
- Field::new("c_comment", DataType::Utf8, false),
- ]),
-
- "orders" => Schema::new(vec![
- Field::new("o_orderkey", DataType::Int64, false),
- Field::new("o_custkey", DataType::Int64, false),
- Field::new("o_orderstatus", DataType::Utf8, false),
- Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
- Field::new("o_orderdate", DataType::Date32, false),
- Field::new("o_orderpriority", DataType::Utf8, false),
- Field::new("o_clerk", DataType::Utf8, false),
- Field::new("o_shippriority", DataType::Int32, false),
- Field::new("o_comment", DataType::Utf8, false),
- ]),
-
- "lineitem" => Schema::new(vec![
- Field::new("l_orderkey", DataType::Int64, false),
- Field::new("l_partkey", DataType::Int64, false),
- Field::new("l_suppkey", DataType::Int64, false),
- Field::new("l_linenumber", DataType::Int32, false),
- Field::new("l_quantity", DataType::Decimal128(15, 2), false),
- Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
- Field::new("l_discount", DataType::Decimal128(15, 2), false),
- Field::new("l_tax", DataType::Decimal128(15, 2), false),
- Field::new("l_returnflag", DataType::Utf8, false),
- Field::new("l_linestatus", DataType::Utf8, false),
- Field::new("l_shipdate", DataType::Date32, false),
- Field::new("l_commitdate", DataType::Date32, false),
- Field::new("l_receiptdate", DataType::Date32, false),
- Field::new("l_shipinstruct", DataType::Utf8, false),
- Field::new("l_shipmode", DataType::Utf8, false),
- Field::new("l_comment", DataType::Utf8, false),
- ]),
-
- "nation" => Schema::new(vec![
- Field::new("n_nationkey", DataType::Int64, false),
- Field::new("n_name", DataType::Utf8, false),
- Field::new("n_regionkey", DataType::Int64, false),
- Field::new("n_comment", DataType::Utf8, false),
- ]),
-
- "region" => Schema::new(vec![
- Field::new("r_regionkey", DataType::Int64, false),
- Field::new("r_name", DataType::Utf8, false),
- Field::new("r_comment", DataType::Utf8, false),
- ]),
-
- _ => unimplemented!(),
- }
-}
-
#[derive(Debug, Serialize)]
struct BenchmarkRun {
/// Benchmark crate version
@@ -611,43 +433,10 @@ struct QueryResult {
#[cfg(test)]
mod tests {
use super::*;
- use std::env;
+ use datafusion::sql::TableReference;
use std::io::{BufRead, BufReader};
- use std::ops::{Div, Mul};
use std::sync::Arc;
- use datafusion::arrow::array::*;
- use datafusion::arrow::util::display::array_value_to_string;
- use datafusion::logical_expr::expr::Cast;
- use datafusion::logical_expr::Expr;
- use datafusion::logical_expr::Expr::ScalarFunction;
- use datafusion::sql::TableReference;
-
- const QUERY_LIMIT: [Option<usize>; 22] = [
- None,
- Some(100),
- Some(10),
- None,
- None,
- None,
- None,
- None,
- None,
- Some(20),
- None,
- None,
- None,
- None,
- None,
- None,
- None,
- Some(100),
- None,
- None,
- Some(100),
- None,
- ];
-
#[tokio::test]
async fn q1_expected_plan() -> Result<()> {
expected_plan(1).await
@@ -770,9 +559,9 @@ mod tests {
async fn expected_plan(query: usize) -> Result<()> {
let ctx = SessionContext::new();
- for table in TABLES {
+ for table in TPCH_TABLES {
let table = table.to_string();
- let schema = get_schema(&table);
+ let schema = get_tpch_table_schema(&table);
let mem_table = MemTable::try_new(Arc::new(schema), vec![])?;
ctx.register_table(
TableReference::from(table.as_str()),
@@ -829,113 +618,140 @@ mod tests {
Ok(str)
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q1() -> Result<()> {
+ async fn verify_q1() -> Result<()> {
verify_query(1).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q2() -> Result<()> {
+ async fn verify_q2() -> Result<()> {
verify_query(2).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q3() -> Result<()> {
+ async fn verify_q3() -> Result<()> {
verify_query(3).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q4() -> Result<()> {
+ async fn verify_q4() -> Result<()> {
verify_query(4).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q5() -> Result<()> {
+ async fn verify_q5() -> Result<()> {
verify_query(5).await
}
+ #[cfg(feature = "ci")]
+ #[ignore] // https://github.com/apache/arrow-datafusion/issues/4024
#[tokio::test]
- async fn q6() -> Result<()> {
+ async fn verify_q6() -> Result<()> {
verify_query(6).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q7() -> Result<()> {
+ async fn verify_q7() -> Result<()> {
verify_query(7).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q8() -> Result<()> {
+ async fn verify_q8() -> Result<()> {
verify_query(8).await
}
+ #[cfg(feature = "ci")]
+ #[ignore] // TODO produces correct result but has rounding error
#[tokio::test]
- async fn q9() -> Result<()> {
+ async fn verify_q9() -> Result<()> {
verify_query(9).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q10() -> Result<()> {
+ async fn verify_q10() -> Result<()> {
verify_query(10).await
}
+ #[cfg(feature = "ci")]
+ #[ignore] // https://github.com/apache/arrow-datafusion/issues/4023
#[tokio::test]
- async fn q11() -> Result<()> {
+ async fn verify_q11() -> Result<()> {
verify_query(11).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q12() -> Result<()> {
+ async fn verify_q12() -> Result<()> {
verify_query(12).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q13() -> Result<()> {
+ async fn verify_q13() -> Result<()> {
verify_query(13).await
}
+ #[cfg(feature = "ci")]
+ #[ignore] // https://github.com/apache/arrow-datafusion/issues/4025
#[tokio::test]
- async fn q14() -> Result<()> {
+ async fn verify_q14() -> Result<()> {
verify_query(14).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q15() -> Result<()> {
+ async fn verify_q15() -> Result<()> {
verify_query(15).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q16() -> Result<()> {
+ async fn verify_q16() -> Result<()> {
verify_query(16).await
}
+ #[cfg(feature = "ci")]
+ #[ignore] // https://github.com/apache/arrow-datafusion/issues/4026
#[tokio::test]
- async fn q17() -> Result<()> {
+ async fn verify_q17() -> Result<()> {
verify_query(17).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q18() -> Result<()> {
+ async fn verify_q18() -> Result<()> {
verify_query(18).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q19() -> Result<()> {
+ async fn verify_q19() -> Result<()> {
verify_query(19).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q20() -> Result<()> {
+ async fn verify_q20() -> Result<()> {
verify_query(20).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q21() -> Result<()> {
+ async fn verify_q21() -> Result<()> {
verify_query(21).await
}
+ #[cfg(feature = "ci")]
#[tokio::test]
- async fn q22() -> Result<()> {
+ async fn verify_q22() -> Result<()> {
verify_query(22).await
}
@@ -1049,253 +865,6 @@ mod tests {
run_query(22).await
}
- /// Specialised String representation
- fn col_str(column: &ArrayRef, row_index: usize) -> String {
- if column.is_null(row_index) {
- return "NULL".to_string();
- }
-
- array_value_to_string(column, row_index).unwrap()
- }
-
- /// Converts the results into a 2d array of strings, `result[row][column]`
- /// Special cases nulls to NULL for testing
- fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
- let mut result = vec![];
- for batch in results {
- for row_index in 0..batch.num_rows() {
- let row_vec = batch
- .columns()
- .iter()
- .map(|column| col_str(column, row_index))
- .collect();
- result.push(row_vec);
- }
- }
- result
- }
-
- fn get_answer_schema(n: usize) -> Schema {
- match n {
- 1 => Schema::new(vec![
- Field::new("l_returnflag", DataType::Utf8, true),
- Field::new("l_linestatus", DataType::Utf8, true),
- Field::new("sum_qty", DataType::Decimal128(15, 2), true),
- Field::new("sum_base_price", DataType::Decimal128(15, 2), true),
- Field::new("sum_disc_price", DataType::Decimal128(15, 2), true),
- Field::new("sum_charge", DataType::Decimal128(15, 2), true),
- Field::new("avg_qty", DataType::Decimal128(15, 2), true),
- Field::new("avg_price", DataType::Decimal128(15, 2), true),
- Field::new("avg_disc", DataType::Decimal128(15, 2), true),
- Field::new("count_order", DataType::Int64, true),
- ]),
-
- 2 => Schema::new(vec![
- Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
- Field::new("s_name", DataType::Utf8, true),
- Field::new("n_name", DataType::Utf8, true),
- Field::new("p_partkey", DataType::Int64, true),
- Field::new("p_mfgr", DataType::Utf8, true),
- Field::new("s_address", DataType::Utf8, true),
- Field::new("s_phone", DataType::Utf8, true),
- Field::new("s_comment", DataType::Utf8, true),
- ]),
-
- 3 => Schema::new(vec![
- Field::new("l_orderkey", DataType::Int64, true),
- Field::new("revenue", DataType::Decimal128(15, 2), true),
- Field::new("o_orderdate", DataType::Date32, true),
- Field::new("o_shippriority", DataType::Int32, true),
- ]),
-
- 4 => Schema::new(vec![
- Field::new("o_orderpriority", DataType::Utf8, true),
- Field::new("order_count", DataType::Int64, true),
- ]),
-
- 5 => Schema::new(vec![
- Field::new("n_name", DataType::Utf8, true),
- Field::new("revenue", DataType::Decimal128(15, 2), true),
- ]),
-
- 6 => Schema::new(vec![Field::new(
- "revenue",
- DataType::Decimal128(15, 2),
- true,
- )]),
-
- 7 => Schema::new(vec![
- Field::new("supp_nation", DataType::Utf8, true),
- Field::new("cust_nation", DataType::Utf8, true),
- Field::new("l_year", DataType::Int32, true),
- Field::new("revenue", DataType::Decimal128(15, 2), true),
- ]),
-
- 8 => Schema::new(vec![
- Field::new("o_year", DataType::Int32, true),
- Field::new("mkt_share", DataType::Decimal128(15, 2), true),
- ]),
-
- 9 => Schema::new(vec![
- Field::new("nation", DataType::Utf8, true),
- Field::new("o_year", DataType::Int32, true),
- Field::new("sum_profit", DataType::Decimal128(15, 2), true),
- ]),
-
- 10 => Schema::new(vec![
- Field::new("c_custkey", DataType::Int64, true),
- Field::new("c_name", DataType::Utf8, true),
- Field::new("revenue", DataType::Decimal128(15, 2), true),
- Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
- Field::new("n_name", DataType::Utf8, true),
- Field::new("c_address", DataType::Utf8, true),
- Field::new("c_phone", DataType::Utf8, true),
- Field::new("c_comment", DataType::Utf8, true),
- ]),
-
- 11 => Schema::new(vec![
- Field::new("ps_partkey", DataType::Int64, true),
- Field::new("value", DataType::Decimal128(15, 2), true),
- ]),
-
- 12 => Schema::new(vec![
- Field::new("l_shipmode", DataType::Utf8, true),
- Field::new("high_line_count", DataType::Int64, true),
- Field::new("low_line_count", DataType::Int64, true),
- ]),
-
- 13 => Schema::new(vec![
- Field::new("c_count", DataType::Int64, true),
- Field::new("custdist", DataType::Int64, true),
- ]),
-
- 14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
-
- 15 => Schema::new(vec![
- Field::new("s_suppkey", DataType::Int64, true),
- Field::new("s_name", DataType::Utf8, true),
- Field::new("s_address", DataType::Utf8, true),
- Field::new("s_phone", DataType::Utf8, true),
- Field::new("total_revenue", DataType::Decimal128(15, 2), true),
- ]),
-
- 16 => Schema::new(vec![
- Field::new("p_brand", DataType::Utf8, true),
- Field::new("p_type", DataType::Utf8, true),
- Field::new("p_size", DataType::Int32, true),
- Field::new("supplier_cnt", DataType::Int64, true),
- ]),
-
- 17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),
-
- 18 => Schema::new(vec![
- Field::new("c_name", DataType::Utf8, true),
- Field::new("c_custkey", DataType::Int64, true),
- Field::new("o_orderkey", DataType::Int64, true),
- Field::new("o_orderdate", DataType::Date32, true),
- Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
- Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true),
- ]),
-
- 19 => Schema::new(vec![Field::new(
- "revenue",
- DataType::Decimal128(15, 2),
- true,
- )]),
-
- 20 => Schema::new(vec![
- Field::new("s_name", DataType::Utf8, true),
- Field::new("s_address", DataType::Utf8, true),
- ]),
-
- 21 => Schema::new(vec![
- Field::new("s_name", DataType::Utf8, true),
- Field::new("numwait", DataType::Int64, true),
- ]),
-
- 22 => Schema::new(vec![
- Field::new("cntrycode", DataType::Utf8, true),
- Field::new("numcust", DataType::Int64, true),
- Field::new("totacctbal", DataType::Decimal128(15, 2), true),
- ]),
-
- _ => unimplemented!(),
- }
- }
-
- // convert expected schema to all utf8 so columns can be read as strings to be parsed separately
- // this is due to the fact that the csv parser cannot handle leading/trailing spaces
- fn string_schema(schema: Schema) -> Schema {
- Schema::new(
- schema
- .fields()
- .iter()
- .map(|field| {
- Field::new(
- Field::name(field),
- DataType::Utf8,
- Field::is_nullable(field),
- )
- })
- .collect::<Vec<Field>>(),
- )
- }
-
- async fn transform_actual_result(
- result: Vec<RecordBatch>,
- n: usize,
- ) -> Result<Vec<RecordBatch>> {
- // to compare the recorded answers to the answers we got back from running the query,
- // we need to round the decimal columns and trim the Utf8 columns
- let ctx = SessionContext::new();
- let result_schema = result[0].schema();
- let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
- let mut df = ctx.read_table(table)?
- .select(
- result_schema
- .fields
- .iter()
- .map(|field| {
- match Field::data_type(field) {
- DataType::Decimal128(_, _) => {
- // if decimal, then round it to 2 decimal places like the answers
- // round() doesn't support the second argument for decimal places to round to
- // this can be simplified to remove the mul and div when
- // https://github.com/apache/arrow-datafusion/issues/2420 is completed
- // cast it back to an over-sized Decimal with 2 precision when done rounding
- let round = Box::new(ScalarFunction {
- fun: datafusion::logical_expr::BuiltinScalarFunction::Round,
- args: vec![col(Field::name(field)).mul(lit(100))],
- }.div(lit(100)));
- Expr::Alias(
- Box::new(Expr::Cast(Cast::new(
- round,
- DataType::Decimal128(38, 2),
- ))),
- Field::name(field).to_string(),
- )
- }
- DataType::Utf8 => {
- // if string, then trim it like the answers got trimmed
- Expr::Alias(
- Box::new(trim(col(Field::name(field)))),
- Field::name(field).to_string(),
- )
- }
- _ => {
- col(Field::name(field))
- }
- }
- }).collect()
- )?;
- if let Some(x) = QUERY_LIMIT[n - 1] {
- df = df.limit(0, Some(x))?;
- }
-
- let df = df.collect().await?;
- Ok(df)
- }
-
async fn run_query(n: usize) -> Result<()> {
// Tests running query with empty tables, to see whether they run successfully.
@@ -1304,8 +873,8 @@ mod tests {
.with_batch_size(10);
let ctx = SessionContext::with_config(config);
- for &table in TABLES {
- let schema = get_schema(table);
+ for &table in TPCH_TABLES {
+ let schema = get_tpch_table_schema(table);
let batch = RecordBatch::new_empty(Arc::new(schema.to_owned()));
ctx.register_batch(table, batch)?;
@@ -1324,75 +893,95 @@ mod tests {
/// * datatypes returned in columns are correct
/// * the correct number of rows are returned
/// * the content of the rows is correct
+ #[cfg(feature = "ci")]
async fn verify_query(n: usize) -> Result<()> {
- if let Ok(path) = env::var("TPCH_DATA") {
- // load expected answers from tpch-dbgen
- // read csv as all strings, trim and cast to expected type as the csv string
- // to value parser does not handle data with leading/trailing spaces
- let ctx = SessionContext::new();
- let schema = string_schema(get_answer_schema(n));
- let options = CsvReadOptions::new()
- .schema(&schema)
- .delimiter(b'|')
- .file_extension(".out");
- let df = ctx
- .read_csv(&format!("{}/answers/q{}.out", path, n), options)
- .await?;
- let df = df.select(
- get_answer_schema(n)
- .fields()
- .iter()
- .map(|field| {
- match Field::data_type(field) {
- DataType::Decimal128(_, _) => {
- // there's no support for casting from Utf8 to Decimal, so
- // we'll cast from Utf8 to Float64 to Decimal for Decimal types
- let inner_cast = Box::new(Expr::Cast(Cast::new(
- Box::new(trim(col(Field::name(field)))),
- DataType::Float64,
- )));
- Expr::Alias(
- Box::new(Expr::Cast(Cast::new(
- inner_cast,
- Field::data_type(field).to_owned(),
- ))),
- Field::name(field).to_string(),
- )
- }
- _ => Expr::Alias(
+ use datafusion::arrow::datatypes::{DataType, Field};
+ use datafusion::logical_expr::expr::Cast;
+ use datafusion::logical_expr::Expr;
+ use std::env;
+
+ let path = env::var("TPCH_DATA").unwrap_or("benchmarks/data".to_string());
+ if !Path::new(&path).exists() {
+ return Err(DataFusionError::Execution(format!(
+ "Benchmark data not found (set TPCH_DATA env var to override): {}",
+ path
+ )));
+ }
+
+ let answer_file = format!("{}/answers/q{}.out", path, n);
+ if !Path::new(&answer_file).exists() {
+ return Err(DataFusionError::Execution(format!(
+ "Expected results not found: {}",
+ answer_file
+ )));
+ }
+
+ // load expected answers from tpch-dbgen
+ // read csv as all strings, trim and cast to expected type as the csv string
+ // to value parser does not handle data with leading/trailing spaces
+ let ctx = SessionContext::new();
+ let schema = string_schema(get_answer_schema(n));
+ let options = CsvReadOptions::new()
+ .schema(&schema)
+ .delimiter(b'|')
+ .file_extension(".out");
+ let df = ctx.read_csv(&answer_file, options).await?;
+ let df = df.select(
+ get_answer_schema(n)
+ .fields()
+ .iter()
+ .map(|field| {
+ match Field::data_type(field) {
+ DataType::Decimal128(_, _) => {
+ // there's no support for casting from Utf8 to Decimal, so
+ // we'll cast from Utf8 to Float64 to Decimal for Decimal types
+ let inner_cast = Box::new(Expr::Cast(Cast::new(
+ Box::new(trim(col(Field::name(field)))),
+ DataType::Float64,
+ )));
+ Expr::Alias(
Box::new(Expr::Cast(Cast::new(
- Box::new(trim(col(Field::name(field)))),
+ inner_cast,
Field::data_type(field).to_owned(),
))),
Field::name(field).to_string(),
- ),
+ )
}
- })
- .collect::<Vec<Expr>>(),
- )?;
- let expected = df.collect().await?;
-
- // run the query to compute actual results of the query
- let opt = DataFusionBenchmarkOpt {
- query: n,
- debug: false,
- iterations: 1,
- partitions: 2,
- batch_size: 8192,
- path: PathBuf::from(path.to_string()),
- file_format: "tbl".to_string(),
- mem_table: false,
- output_path: None,
- disable_statistics: false,
- };
- let actual = benchmark_datafusion(opt).await?;
+ _ => Expr::Alias(
+ Box::new(Expr::Cast(Cast::new(
+ Box::new(trim(col(Field::name(field)))),
+ Field::data_type(field).to_owned(),
+ ))),
+ Field::name(field).to_string(),
+ ),
+ }
+ })
+ .collect::<Vec<Expr>>(),
+ )?;
+ let expected = df.collect().await?;
+
+ // run the query to compute actual results of the query
+ let opt = DataFusionBenchmarkOpt {
+ query: n,
+ debug: false,
+ iterations: 1,
+ partitions: 2,
+ batch_size: 8192,
+ path: PathBuf::from(path.to_string()),
+ file_format: "tbl".to_string(),
+ mem_table: false,
+ output_path: None,
+ disable_statistics: false,
+ };
+ let actual = benchmark_datafusion(opt).await?;
- let transformed = transform_actual_result(actual, n).await?;
+ let transformed = transform_actual_result(actual, n).await?;
- // assert schema data types match
- let transformed_fields = &transformed[0].schema().fields;
- let expected_fields = &expected[0].schema().fields;
- let schema_matches = transformed_fields
+ // assert schema data types match
+ let transformed_fields = &transformed[0].schema().fields;
+ let expected_fields = &expected[0].schema().fields;
+ let schema_matches =
+ transformed_fields
.iter()
.zip(expected_fields.iter())
.all(|(t, e)| match t.data_type() {
@@ -1401,21 +990,18 @@ mod tests {
}
data_type => data_type == e.data_type(),
});
- assert!(schema_matches);
+ assert!(schema_matches);
- // convert both datasets to Vec<Vec<String>> for simple comparison
- let expected_vec = result_vec(&expected);
- let actual_vec = result_vec(&transformed);
+ // convert both datasets to Vec<Vec<String>> for simple comparison
+ let expected_vec = result_vec(&expected);
+ let actual_vec = result_vec(&transformed);
- // basic result comparison
- assert_eq!(expected_vec.len(), actual_vec.len());
+ // basic result comparison
+ assert_eq!(expected_vec.len(), actual_vec.len());
- // compare each row. this works as all TPC-H queries have deterministically ordered results
- for i in 0..actual_vec.len() {
- assert_eq!(expected_vec[i], actual_vec[i]);
- }
- } else {
- println!("TPCH_DATA environment variable not set, skipping test");
+ // compare each row. this works as all TPC-H queries have deterministically ordered results
+ for i in 0..actual_vec.len() {
+ assert_eq!(expected_vec[i], actual_vec[i]);
}
Ok(())
diff --git a/benchmarks/src/lib.rs b/benchmarks/src/lib.rs
new file mode 100644
index 000000000..af1dd46fd
--- /dev/null
+++ b/benchmarks/src/lib.rs
@@ -0,0 +1,18 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+pub mod tpch;
diff --git a/benchmarks/src/tpch.rs b/benchmarks/src/tpch.rs
new file mode 100644
index 000000000..46c53edf1
--- /dev/null
+++ b/benchmarks/src/tpch.rs
@@ -0,0 +1,512 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::record_batch::RecordBatch;
+use std::fs;
+use std::ops::{Div, Mul};
+use std::path::Path;
+use std::sync::Arc;
+use std::time::Instant;
+
+use datafusion::arrow::util::display::array_value_to_string;
+use datafusion::logical_expr::Cast;
+use datafusion::prelude::*;
+use datafusion::{
+ arrow::datatypes::{DataType, Field, Schema},
+ datasource::MemTable,
+ error::{DataFusionError, Result},
+};
+use parquet::basic::Compression;
+use parquet::file::properties::WriterProperties;
+
+pub const TPCH_TABLES: &[&str] = &[
+ "part", "supplier", "partsupp", "customer", "orders", "lineitem", "nation", "region",
+];
+
+/// Get the schema for the benchmarks derived from TPC-H
+pub fn get_tpch_table_schema(table: &str) -> Schema {
+ // note that the schema intentionally uses signed integers so that any generated Parquet
+ // files can also be used to benchmark tools that only support signed integers, such as
+ // Apache Spark
+
+ match table {
+ "part" => Schema::new(vec![
+ Field::new("p_partkey", DataType::Int64, false),
+ Field::new("p_name", DataType::Utf8, false),
+ Field::new("p_mfgr", DataType::Utf8, false),
+ Field::new("p_brand", DataType::Utf8, false),
+ Field::new("p_type", DataType::Utf8, false),
+ Field::new("p_size", DataType::Int32, false),
+ Field::new("p_container", DataType::Utf8, false),
+ Field::new("p_retailprice", DataType::Decimal128(15, 2), false),
+ Field::new("p_comment", DataType::Utf8, false),
+ ]),
+
+ "supplier" => Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, false),
+ Field::new("s_name", DataType::Utf8, false),
+ Field::new("s_address", DataType::Utf8, false),
+ Field::new("s_nationkey", DataType::Int64, false),
+ Field::new("s_phone", DataType::Utf8, false),
+ Field::new("s_acctbal", DataType::Decimal128(15, 2), false),
+ Field::new("s_comment", DataType::Utf8, false),
+ ]),
+
+ "partsupp" => Schema::new(vec![
+ Field::new("ps_partkey", DataType::Int64, false),
+ Field::new("ps_suppkey", DataType::Int64, false),
+ Field::new("ps_availqty", DataType::Int32, false),
+ Field::new("ps_supplycost", DataType::Decimal128(15, 2), false),
+ Field::new("ps_comment", DataType::Utf8, false),
+ ]),
+
+ "customer" => Schema::new(vec![
+ Field::new("c_custkey", DataType::Int64, false),
+ Field::new("c_name", DataType::Utf8, false),
+ Field::new("c_address", DataType::Utf8, false),
+ Field::new("c_nationkey", DataType::Int64, false),
+ Field::new("c_phone", DataType::Utf8, false),
+ Field::new("c_acctbal", DataType::Decimal128(15, 2), false),
+ Field::new("c_mktsegment", DataType::Utf8, false),
+ Field::new("c_comment", DataType::Utf8, false),
+ ]),
+
+ "orders" => Schema::new(vec![
+ Field::new("o_orderkey", DataType::Int64, false),
+ Field::new("o_custkey", DataType::Int64, false),
+ Field::new("o_orderstatus", DataType::Utf8, false),
+ Field::new("o_totalprice", DataType::Decimal128(15, 2), false),
+ Field::new("o_orderdate", DataType::Date32, false),
+ Field::new("o_orderpriority", DataType::Utf8, false),
+ Field::new("o_clerk", DataType::Utf8, false),
+ Field::new("o_shippriority", DataType::Int32, false),
+ Field::new("o_comment", DataType::Utf8, false),
+ ]),
+
+ "lineitem" => Schema::new(vec![
+ Field::new("l_orderkey", DataType::Int64, false),
+ Field::new("l_partkey", DataType::Int64, false),
+ Field::new("l_suppkey", DataType::Int64, false),
+ Field::new("l_linenumber", DataType::Int32, false),
+ Field::new("l_quantity", DataType::Decimal128(15, 2), false),
+ Field::new("l_extendedprice", DataType::Decimal128(15, 2), false),
+ Field::new("l_discount", DataType::Decimal128(15, 2), false),
+ Field::new("l_tax", DataType::Decimal128(15, 2), false),
+ Field::new("l_returnflag", DataType::Utf8, false),
+ Field::new("l_linestatus", DataType::Utf8, false),
+ Field::new("l_shipdate", DataType::Date32, false),
+ Field::new("l_commitdate", DataType::Date32, false),
+ Field::new("l_receiptdate", DataType::Date32, false),
+ Field::new("l_shipinstruct", DataType::Utf8, false),
+ Field::new("l_shipmode", DataType::Utf8, false),
+ Field::new("l_comment", DataType::Utf8, false),
+ ]),
+
+ "nation" => Schema::new(vec![
+ Field::new("n_nationkey", DataType::Int64, false),
+ Field::new("n_name", DataType::Utf8, false),
+ Field::new("n_regionkey", DataType::Int64, false),
+ Field::new("n_comment", DataType::Utf8, false),
+ ]),
+
+ "region" => Schema::new(vec![
+ Field::new("r_regionkey", DataType::Int64, false),
+ Field::new("r_name", DataType::Utf8, false),
+ Field::new("r_comment", DataType::Utf8, false),
+ ]),
+
+ _ => unimplemented!(),
+ }
+}
+
+/// Get the expected schema for the results of a query
+pub fn get_answer_schema(n: usize) -> Schema {
+ match n {
+ 1 => Schema::new(vec![
+ Field::new("l_returnflag", DataType::Utf8, true),
+ Field::new("l_linestatus", DataType::Utf8, true),
+ Field::new("sum_qty", DataType::Decimal128(15, 2), true),
+ Field::new("sum_base_price", DataType::Decimal128(15, 2), true),
+ Field::new("sum_disc_price", DataType::Decimal128(15, 2), true),
+ Field::new("sum_charge", DataType::Decimal128(15, 2), true),
+ Field::new("avg_qty", DataType::Decimal128(15, 2), true),
+ Field::new("avg_price", DataType::Decimal128(15, 2), true),
+ Field::new("avg_disc", DataType::Decimal128(15, 2), true),
+ Field::new("count_order", DataType::Int64, true),
+ ]),
+
+ 2 => Schema::new(vec![
+ Field::new("s_acctbal", DataType::Decimal128(15, 2), true),
+ Field::new("s_name", DataType::Utf8, true),
+ Field::new("n_name", DataType::Utf8, true),
+ Field::new("p_partkey", DataType::Int64, true),
+ Field::new("p_mfgr", DataType::Utf8, true),
+ Field::new("s_address", DataType::Utf8, true),
+ Field::new("s_phone", DataType::Utf8, true),
+ Field::new("s_comment", DataType::Utf8, true),
+ ]),
+
+ 3 => Schema::new(vec![
+ Field::new("l_orderkey", DataType::Int64, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
+ Field::new("o_orderdate", DataType::Date32, true),
+ Field::new("o_shippriority", DataType::Int32, true),
+ ]),
+
+ 4 => Schema::new(vec![
+ Field::new("o_orderpriority", DataType::Utf8, true),
+ Field::new("order_count", DataType::Int64, true),
+ ]),
+
+ 5 => Schema::new(vec![
+ Field::new("n_name", DataType::Utf8, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 6 => Schema::new(vec![Field::new(
+ "revenue",
+ DataType::Decimal128(15, 2),
+ true,
+ )]),
+
+ 7 => Schema::new(vec![
+ Field::new("supp_nation", DataType::Utf8, true),
+ Field::new("cust_nation", DataType::Utf8, true),
+ Field::new("l_year", DataType::Int32, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 8 => Schema::new(vec![
+ Field::new("o_year", DataType::Int32, true),
+ Field::new("mkt_share", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 9 => Schema::new(vec![
+ Field::new("nation", DataType::Utf8, true),
+ Field::new("o_year", DataType::Int32, true),
+ Field::new("sum_profit", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 10 => Schema::new(vec![
+ Field::new("c_custkey", DataType::Int64, true),
+ Field::new("c_name", DataType::Utf8, true),
+ Field::new("revenue", DataType::Decimal128(15, 2), true),
+ Field::new("c_acctbal", DataType::Decimal128(15, 2), true),
+ Field::new("n_name", DataType::Utf8, true),
+ Field::new("c_address", DataType::Utf8, true),
+ Field::new("c_phone", DataType::Utf8, true),
+ Field::new("c_comment", DataType::Utf8, true),
+ ]),
+
+ 11 => Schema::new(vec![
+ Field::new("ps_partkey", DataType::Int64, true),
+ Field::new("value", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 12 => Schema::new(vec![
+ Field::new("l_shipmode", DataType::Utf8, true),
+ Field::new("high_line_count", DataType::Int64, true),
+ Field::new("low_line_count", DataType::Int64, true),
+ ]),
+
+ 13 => Schema::new(vec![
+ Field::new("c_count", DataType::Int64, true),
+ Field::new("custdist", DataType::Int64, true),
+ ]),
+
+ 14 => Schema::new(vec![Field::new(
+ "promo_revenue",
+ DataType::Decimal128(38, 2),
+ true,
+ )]),
+
+ 15 => Schema::new(vec![
+ Field::new("s_suppkey", DataType::Int64, true),
+ Field::new("s_name", DataType::Utf8, true),
+ Field::new("s_address", DataType::Utf8, true),
+ Field::new("s_phone", DataType::Utf8, true),
+ Field::new("total_revenue", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 16 => Schema::new(vec![
+ Field::new("p_brand", DataType::Utf8, true),
+ Field::new("p_type", DataType::Utf8, true),
+ Field::new("p_size", DataType::Int32, true),
+ Field::new("supplier_cnt", DataType::Int64, true),
+ ]),
+
+ 17 => Schema::new(vec![Field::new(
+ "avg_yearly",
+ DataType::Decimal128(38, 2),
+ true,
+ )]),
+
+ 18 => Schema::new(vec![
+ Field::new("c_name", DataType::Utf8, true),
+ Field::new("c_custkey", DataType::Int64, true),
+ Field::new("o_orderkey", DataType::Int64, true),
+ Field::new("o_orderdate", DataType::Date32, true),
+ Field::new("o_totalprice", DataType::Decimal128(15, 2), true),
+ Field::new("sum_l_quantity", DataType::Decimal128(15, 2), true),
+ ]),
+
+ 19 => Schema::new(vec![Field::new(
+ "revenue",
+ DataType::Decimal128(15, 2),
+ true,
+ )]),
+
+ 20 => Schema::new(vec![
+ Field::new("s_name", DataType::Utf8, true),
+ Field::new("s_address", DataType::Utf8, true),
+ ]),
+
+ 21 => Schema::new(vec![
+ Field::new("s_name", DataType::Utf8, true),
+ Field::new("numwait", DataType::Int64, true),
+ ]),
+
+ 22 => Schema::new(vec![
+ Field::new("cntrycode", DataType::Utf8, true),
+ Field::new("numcust", DataType::Int64, true),
+ Field::new("totacctbal", DataType::Decimal128(15, 2), true),
+ ]),
+
+ _ => unimplemented!(),
+ }
+}
+
+/// Get the SQL statements from the specified query file
+pub fn get_query_sql(query: usize) -> Result<Vec<String>> {
+ if query > 0 && query < 23 {
+ let possibilities = vec![
+ format!("queries/q{}.sql", query),
+ format!("benchmarks/queries/q{}.sql", query),
+ ];
+ let mut errors = vec![];
+ for filename in possibilities {
+ match fs::read_to_string(&filename) {
+ Ok(contents) => {
+ return Ok(contents
+ .split(';')
+ .map(|s| s.trim())
+ .filter(|s| !s.is_empty())
+ .map(|s| s.to_string())
+ .collect());
+ }
+ Err(e) => errors.push(format!("{}: {}", filename, e)),
+ };
+ }
+ Err(DataFusionError::Plan(format!(
+ "invalid query. Could not find query: {:?}",
+ errors
+ )))
+ } else {
+ Err(DataFusionError::Plan(
+ "invalid query. Expected value between 1 and 22".to_owned(),
+ ))
+ }
+}
+
+/// Convert the TPC-H tbl (pipe-delimited csv) files to csv or parquet
+pub async fn convert_tbl(
+ input_path: &str,
+ output_path: &str,
+ file_format: &str,
+ partitions: usize,
+ batch_size: usize,
+ compression: Compression,
+) -> Result<()> {
+ let output_root_path = Path::new(output_path);
+ for table in TPCH_TABLES {
+ let start = Instant::now();
+ let schema = get_tpch_table_schema(table);
+
+ let input_path = format!("{}/{}.tbl", input_path, table);
+ let options = CsvReadOptions::new()
+ .schema(&schema)
+ .has_header(false)
+ .delimiter(b'|')
+ .file_extension(".tbl");
+
+ let config = SessionConfig::new().with_batch_size(batch_size);
+ let ctx = SessionContext::with_config(config);
+
+ // build plan to read the TBL file
+ let mut csv = ctx.read_csv(&input_path, options).await?;
+
+ // optionally, repartition the file
+ if partitions > 1 {
+ csv = csv.repartition(Partitioning::RoundRobinBatch(partitions))?
+ }
+
+ // create the physical plan
+ let csv = csv.to_logical_plan()?;
+ let csv = ctx.create_physical_plan(&csv).await?;
+
+ let output_path = output_root_path.join(table);
+ let output_path = output_path.to_str().unwrap().to_owned();
+
+ println!(
+ "Converting '{}' to {} files in directory '{}'",
+ &input_path, &file_format, &output_path
+ );
+ match file_format {
+ "csv" => ctx.write_csv(csv, output_path).await?,
+ "parquet" => {
+ let props = WriterProperties::builder()
+ .set_compression(compression)
+ .build();
+ ctx.write_parquet(csv, output_path, Some(props)).await?
+ }
+ other => {
+ return Err(DataFusionError::NotImplemented(format!(
+ "Invalid output format: {}",
+ other
+ )));
+ }
+ }
+ println!("Conversion completed in {} ms", start.elapsed().as_millis());
+ }
+
+ Ok(())
+}
+
+/// Converts the results into a 2d array of strings, `result[row][column]`
+/// Special cases nulls to NULL for testing
+pub fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+ let mut result = vec![];
+ for batch in results {
+ for row_index in 0..batch.num_rows() {
+ let row_vec = batch
+ .columns()
+ .iter()
+ .map(|column| col_str(column, row_index))
+ .collect();
+ result.push(row_vec);
+ }
+ }
+ result
+}
+
+/// Convert the expected schema to all Utf8 so that columns can be read as strings and parsed separately.
+/// This is needed because the CSV parser cannot handle leading/trailing spaces.
+pub fn string_schema(schema: Schema) -> Schema {
+ Schema::new(
+ schema
+ .fields()
+ .iter()
+ .map(|field| {
+ Field::new(
+ Field::name(field),
+ DataType::Utf8,
+ Field::is_nullable(field),
+ )
+ })
+ .collect::<Vec<Field>>(),
+ )
+}
+
+/// Specialised string representation of a single array value; nulls are rendered as `NULL`
+fn col_str(column: &ArrayRef, row_index: usize) -> String {
+ if column.is_null(row_index) {
+ return "NULL".to_string();
+ }
+
+ array_value_to_string(column, row_index).unwrap()
+}
+
+pub async fn transform_actual_result(
+ result: Vec<RecordBatch>,
+ n: usize,
+) -> Result<Vec<RecordBatch>> {
+ // to compare the recorded answers to the answers we got back from running the query,
+ // we need to round the decimal columns and trim the Utf8 columns
+ let ctx = SessionContext::new();
+ let result_schema = result[0].schema();
+ let table = Arc::new(MemTable::try_new(result_schema.clone(), vec![result])?);
+ let mut df = ctx.read_table(table)?
+ .select(
+ result_schema
+ .fields
+ .iter()
+ .map(|field| {
+ match Field::data_type(field) {
+ DataType::Decimal128(_, _) => {
+ // if decimal, then round it to 2 decimal places like the answers
+ // round() doesn't support the second argument for decimal places to round to
+ // this can be simplified to remove the mul and div when
+ // https://github.com/apache/arrow-datafusion/issues/2420 is completed
+                            // cast it back to an over-sized Decimal (precision 38, scale 2) when done rounding
+ let round = Box::new(Expr::ScalarFunction {
+ fun: datafusion::logical_expr::BuiltinScalarFunction::Round,
+ args: vec![col(Field::name(field)).mul(lit(100))],
+ }.div(lit(100)));
+ Expr::Alias(
+ Box::new(Expr::Cast(Cast::new(
+ round,
+ DataType::Decimal128(38, 2),
+ ))),
+ Field::name(field).to_string(),
+ )
+ }
+ DataType::Utf8 => {
+ // if string, then trim it like the answers got trimmed
+ Expr::Alias(
+ Box::new(trim(col(Field::name(field)))),
+ Field::name(field).to_string(),
+ )
+ }
+ _ => {
+ col(Field::name(field))
+ }
+ }
+ }).collect()
+ )?;
+ if let Some(x) = QUERY_LIMIT[n - 1] {
+ df = df.limit(0, Some(x))?;
+ }
+
+ let df = df.collect().await?;
+ Ok(df)
+}
+
+pub const QUERY_LIMIT: [Option<usize>; 22] = [
+ None,
+ Some(100),
+ Some(10),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(20),
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ None,
+ Some(100),
+ None,
+ None,
+ Some(100),
+ None,
+];