Posted to commits@arrow.apache.org by ag...@apache.org on 2020/12/28 17:06:32 UTC

[arrow] branch master updated: ARROW-10712: [Rust] [DataFusion] Add tests to TPC-H benchmarks

This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new b07489b  ARROW-10712: [Rust] [DataFusion] Add tests to TPC-H benchmarks
b07489b is described below

commit b07489bf7086bb3d857f3c7fd9dc1d634d1b3c5b
Author: Mike Seddon <se...@gmail.com>
AuthorDate: Mon Dec 28 10:05:28 2020 -0700

    ARROW-10712: [Rust] [DataFusion] Add tests to TPC-H benchmarks
    
    This PR adds the ability to load and parse the expected query answers included in the https://github.com/databricks/tpch-dbgen/tree/master/answers repository, which users have to clone anyway to generate the TPC-H data.
    
    DataFusion does not currently support Decimal types, which all of the numeric values in TPC-H use, so the expected precision errors appear in the current results. These tests are still useful, as they already surface interesting behaviour such as the non-deterministic results of query 5.
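    
    A likely way to run the new tests locally (a sketch, assuming tpch-dbgen
    has been cloned, built, and used to generate the data as described in
    rust/benchmarks/README.md, with TPCH_DATA pointing at that checkout):
    
        export TPCH_DATA=/path/to/tpch-dbgen
        cd rust/benchmarks
        cargo test --bin tpch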
    
    @andygrove
    
    Closes #9015 from seddonm1/test-tpch
    
    Authored-by: Mike Seddon <se...@gmail.com>
    Signed-off-by: Andy Grove <an...@gmail.com>
---
 rust/benchmarks/README.md       |   3 +-
 rust/benchmarks/src/bin/tpch.rs | 439 ++++++++++++++++++++++++++++++++++++++--
 rust/datafusion/src/prelude.rs  |   4 +-
 3 files changed, 424 insertions(+), 22 deletions(-)

diff --git a/rust/benchmarks/README.md b/rust/benchmarks/README.md
index 9bff3e2..2ae035b 100644
--- a/rust/benchmarks/README.md
+++ b/rust/benchmarks/README.md
@@ -37,6 +37,7 @@ clone the repository and build the source code.
 git clone git@github.com:databricks/tpch-dbgen.git
 cd tpch-dbgen
 make
+export TPCH_DATA=$(pwd)
 ```
 
 Data can now be generated with the following command. Note that `-s 1` means use Scale Factor 1 or ~1 GB of
@@ -63,7 +64,7 @@ This utility does not yet provide support for changing the number of partitions
 option is to use the following Docker image to perform the conversion from `tbl` files to CSV or Parquet.
 
 ```bash
-docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT 
+docker run -it ballistacompute/spark-benchmarks:0.4.0-SNAPSHOT
   -h, --help   Show help message
 
 Subcommand: convert-tpch
diff --git a/rust/benchmarks/src/bin/tpch.rs b/rust/benchmarks/src/bin/tpch.rs
index 769668c..eb789ba 100644
--- a/rust/benchmarks/src/bin/tpch.rs
+++ b/rust/benchmarks/src/bin/tpch.rs
@@ -108,15 +108,14 @@ const TABLES: &[&str] = &[
 
 #[tokio::main]
 async fn main() -> Result<()> {
+    env_logger::init();
     match TpchOpt::from_args() {
-        TpchOpt::Benchmark(opt) => benchmark(opt).await,
+        TpchOpt::Benchmark(opt) => benchmark(opt).await.map(|_| ()),
         TpchOpt::Convert(opt) => convert_tbl(opt).await,
     }
 }
 
-async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
-    env_logger::init();
-
+async fn benchmark(opt: BenchmarkOpt) -> Result<Vec<arrow::record_batch::RecordBatch>> {
     println!("Running benchmarks with the following options: {:?}", opt);
     let config = ExecutionConfig::new()
         .with_concurrency(opt.concurrency)
@@ -146,10 +145,11 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
 
     let mut millis = vec![];
     // run benchmark
+    let mut result: Vec<arrow::record_batch::RecordBatch> = Vec::with_capacity(1);
     for i in 0..opt.iterations {
         let start = Instant::now();
         let plan = create_logical_plan(&mut ctx, opt.query)?;
-        execute_query(&mut ctx, &plan, opt.debug).await?;
+        result = execute_query(&mut ctx, &plan, opt.debug).await?;
         let elapsed = start.elapsed().as_secs_f64() * 1000.0;
         millis.push(elapsed as f64);
         println!("Query {} iteration {} took {:.1} ms", opt.query, i, elapsed);
@@ -158,7 +158,7 @@ async fn benchmark(opt: BenchmarkOpt) -> Result<()> {
     let avg = millis.iter().sum::<f64>() / millis.len() as f64;
     println!("Query {} avg time: {:.2} ms", opt.query, avg);
 
-    Ok(())
+    Ok(result)
 }
 
 fn create_logical_plan(ctx: &mut ExecutionContext, query: usize) -> Result<LogicalPlan> {
@@ -994,7 +994,7 @@ async fn execute_query(
     ctx: &mut ExecutionContext,
     plan: &LogicalPlan,
     debug: bool,
-) -> Result<()> {
+) -> Result<Vec<arrow::record_batch::RecordBatch>> {
     if debug {
         println!("Logical plan:\n{:?}", plan);
     }
@@ -1007,12 +1007,11 @@ async fn execute_query(
     if debug {
         pretty::print_batches(&result)?;
     }
-    Ok(())
+    Ok(result)
 }
 
 async fn convert_tbl(opt: ConvertOpt) -> Result<()> {
     let output_root_path = Path::new(&opt.output_path);
-
     for table in TABLES {
         let start = Instant::now();
         let schema = get_schema(table);
@@ -1088,13 +1087,14 @@ fn get_table(
     table_format: &str,
 ) -> Result<Box<dyn TableProvider + Send + Sync>> {
     match table_format {
-        // dbgen creates .tbl ('|' delimited) files
+        // dbgen creates .tbl ('|' delimited) files without header
         "tbl" => {
             let path = format!("{}/{}.tbl", path, table);
             let schema = get_schema(table);
             let options = CsvReadOptions::new()
                 .schema(&schema)
                 .delimiter(b'|')
+                .has_header(false)
                 .file_extension(".tbl");
 
             Ok(Box::new(CsvFile::try_new(&path, options)?))
@@ -1130,7 +1130,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("p_type", DataType::Utf8, false),
             Field::new("p_size", DataType::Int32, false),
             Field::new("p_container", DataType::Utf8, false),
-            Field::new("p_retailprice", DataType::Float64, false), // decimal
+            Field::new("p_retailprice", DataType::Float64, false),
             Field::new("p_comment", DataType::Utf8, false),
         ]),
 
@@ -1140,7 +1140,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("s_address", DataType::Utf8, false),
             Field::new("s_nationkey", DataType::Int32, false),
             Field::new("s_phone", DataType::Utf8, false),
-            Field::new("s_acctbal", DataType::Float64, false), // decimal
+            Field::new("s_acctbal", DataType::Float64, false),
             Field::new("s_comment", DataType::Utf8, false),
         ]),
 
@@ -1148,7 +1148,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("ps_partkey", DataType::Int32, false),
             Field::new("ps_suppkey", DataType::Int32, false),
             Field::new("ps_availqty", DataType::Int32, false),
-            Field::new("ps_supplycost", DataType::Float64, false), // decimal
+            Field::new("ps_supplycost", DataType::Float64, false),
             Field::new("ps_comment", DataType::Utf8, false),
         ]),
 
@@ -1158,7 +1158,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("c_address", DataType::Utf8, false),
             Field::new("c_nationkey", DataType::Int32, false),
             Field::new("c_phone", DataType::Utf8, false),
-            Field::new("c_acctbal", DataType::Float64, false), // decimal
+            Field::new("c_acctbal", DataType::Float64, false),
             Field::new("c_mktsegment", DataType::Utf8, false),
             Field::new("c_comment", DataType::Utf8, false),
         ]),
@@ -1167,7 +1167,7 @@ fn get_schema(table: &str) -> Schema {
             Field::new("o_orderkey", DataType::Int32, false),
             Field::new("o_custkey", DataType::Int32, false),
             Field::new("o_orderstatus", DataType::Utf8, false),
-            Field::new("o_totalprice", DataType::Float64, false), // decimal
+            Field::new("o_totalprice", DataType::Float64, false),
             Field::new("o_orderdate", DataType::Date32(DateUnit::Day), false),
             Field::new("o_orderpriority", DataType::Utf8, false),
             Field::new("o_clerk", DataType::Utf8, false),
@@ -1180,10 +1180,10 @@ fn get_schema(table: &str) -> Schema {
             Field::new("l_partkey", DataType::Int32, false),
             Field::new("l_suppkey", DataType::Int32, false),
             Field::new("l_linenumber", DataType::Int32, false),
-            Field::new("l_quantity", DataType::Float64, false), // decimal
-            Field::new("l_extendedprice", DataType::Float64, false), // decimal
-            Field::new("l_discount", DataType::Float64, false), // decimal
-            Field::new("l_tax", DataType::Float64, false),      // decimal
+            Field::new("l_quantity", DataType::Float64, false),
+            Field::new("l_extendedprice", DataType::Float64, false),
+            Field::new("l_discount", DataType::Float64, false),
+            Field::new("l_tax", DataType::Float64, false),
             Field::new("l_returnflag", DataType::Utf8, false),
             Field::new("l_linestatus", DataType::Utf8, false),
             Field::new("l_shipdate", DataType::Date32(DateUnit::Day), false),
@@ -1210,3 +1210,404 @@ fn get_schema(table: &str) -> Schema {
         _ => unimplemented!(),
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use std::env;
+    use std::sync::Arc;
+
+    use arrow::array::*;
+    use arrow::record_batch::RecordBatch;
+    use arrow::util::display::array_value_to_string;
+
+    use datafusion::logical_plan::Expr;
+    use datafusion::logical_plan::Expr::Cast;
+
+    #[tokio::test]
+    async fn q1() -> Result<()> {
+        verify_query(1).await
+    }
+
+    #[tokio::test]
+    async fn q2() -> Result<()> {
+        verify_query(2).await
+    }
+
+    #[tokio::test]
+    async fn q3() -> Result<()> {
+        verify_query(3).await
+    }
+
+    #[tokio::test]
+    async fn q4() -> Result<()> {
+        verify_query(4).await
+    }
+
+    #[tokio::test]
+    async fn q5() -> Result<()> {
+        verify_query(5).await
+    }
+
+    #[tokio::test]
+    async fn q6() -> Result<()> {
+        verify_query(6).await
+    }
+
+    #[tokio::test]
+    async fn q7() -> Result<()> {
+        verify_query(7).await
+    }
+
+    #[tokio::test]
+    async fn q8() -> Result<()> {
+        verify_query(8).await
+    }
+
+    #[tokio::test]
+    async fn q9() -> Result<()> {
+        verify_query(9).await
+    }
+
+    #[tokio::test]
+    async fn q10() -> Result<()> {
+        verify_query(10).await
+    }
+
+    #[tokio::test]
+    async fn q11() -> Result<()> {
+        verify_query(11).await
+    }
+
+    #[tokio::test]
+    async fn q12() -> Result<()> {
+        verify_query(12).await
+    }
+
+    #[tokio::test]
+    async fn q13() -> Result<()> {
+        verify_query(13).await
+    }
+
+    #[tokio::test]
+    async fn q14() -> Result<()> {
+        verify_query(14).await
+    }
+
+    #[tokio::test]
+    async fn q15() -> Result<()> {
+        verify_query(15).await
+    }
+
+    #[tokio::test]
+    async fn q16() -> Result<()> {
+        verify_query(16).await
+    }
+
+    #[tokio::test]
+    async fn q17() -> Result<()> {
+        verify_query(17).await
+    }
+
+    #[tokio::test]
+    async fn q18() -> Result<()> {
+        verify_query(18).await
+    }
+
+    #[tokio::test]
+    async fn q19() -> Result<()> {
+        verify_query(19).await
+    }
+
+    #[tokio::test]
+    async fn q20() -> Result<()> {
+        verify_query(20).await
+    }
+
+    #[tokio::test]
+    async fn q21() -> Result<()> {
+        verify_query(21).await
+    }
+
+    #[tokio::test]
+    async fn q22() -> Result<()> {
+        verify_query(22).await
+    }
+
+    /// Specialised String representation of a single column value
+    fn col_str(column: &ArrayRef, row_index: usize) -> String {
+        if column.is_null(row_index) {
+            return "NULL".to_string();
+        }
+
+        // Special case FixedSizeListArray as there is no pretty print support for it yet
+        if let DataType::FixedSizeList(_, n) = column.data_type() {
+            let array = column
+                .as_any()
+                .downcast_ref::<FixedSizeListArray>()
+                .unwrap()
+                .value(row_index);
+
+            let mut r = Vec::with_capacity(*n as usize);
+            for i in 0..*n {
+                r.push(col_str(&array, i as usize));
+            }
+            return format!("[{}]", r.join(","));
+        }
+
+        array_value_to_string(column, row_index).unwrap()
+    }
+
+    /// Converts the results into a 2d array of strings, `result[row][column]`
+    /// Special cases nulls to NULL for testing
+    fn result_vec(results: &[RecordBatch]) -> Vec<Vec<String>> {
+        let mut result = vec![];
+        for batch in results {
+            for row_index in 0..batch.num_rows() {
+                let row_vec = batch
+                    .columns()
+                    .iter()
+                    .map(|column| col_str(column, row_index))
+                    .collect();
+                result.push(row_vec);
+            }
+        }
+        result
+    }
+
+    fn get_answer_schema(n: usize) -> Schema {
+        match n {
+            1 => Schema::new(vec![
+                Field::new("l_returnflag", DataType::Utf8, true),
+                Field::new("l_linestatus", DataType::Utf8, true),
+                Field::new("sum_qty", DataType::Float64, true),
+                Field::new("sum_base_price", DataType::Float64, true),
+                Field::new("sum_disc_price", DataType::Float64, true),
+                Field::new("sum_charge", DataType::Float64, true),
+                Field::new("avg_qty", DataType::Float64, true),
+                Field::new("avg_price", DataType::Float64, true),
+                Field::new("avg_disc", DataType::Float64, true),
+                Field::new("count_order", DataType::UInt64, true),
+            ]),
+
+            2 => Schema::new(vec![
+                Field::new("s_acctbal", DataType::Float64, true),
+                Field::new("s_name", DataType::Utf8, true),
+                Field::new("n_name", DataType::Utf8, true),
+                Field::new("p_partkey", DataType::Int32, true),
+                Field::new("p_mfgr", DataType::Utf8, true),
+                Field::new("s_address", DataType::Utf8, true),
+                Field::new("s_phone", DataType::Utf8, true),
+                Field::new("s_comment", DataType::Utf8, true),
+            ]),
+
+            3 => Schema::new(vec![
+                Field::new("l_orderkey", DataType::Int32, true),
+                Field::new("revenue", DataType::Float64, true),
+                Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true),
+                Field::new("o_shippriority", DataType::Int32, true),
+            ]),
+
+            4 => Schema::new(vec![
+                Field::new("o_orderpriority", DataType::Utf8, true),
+                Field::new("order_count", DataType::Int32, true),
+            ]),
+
+            5 => Schema::new(vec![
+                Field::new("n_name", DataType::Utf8, true),
+                Field::new("revenue", DataType::Float64, true),
+            ]),
+
+            6 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]),
+
+            7 => Schema::new(vec![
+                Field::new("supp_nation", DataType::Utf8, true),
+                Field::new("cust_nation", DataType::Utf8, true),
+                Field::new("l_year", DataType::Int32, true),
+                Field::new("revenue", DataType::Float64, true),
+            ]),
+
+            8 => Schema::new(vec![
+                Field::new("o_year", DataType::Int32, true),
+                Field::new("mkt_share", DataType::Float64, true),
+            ]),
+
+            9 => Schema::new(vec![
+                Field::new("nation", DataType::Utf8, true),
+                Field::new("o_year", DataType::Int32, true),
+                Field::new("sum_profit", DataType::Float64, true),
+            ]),
+
+            10 => Schema::new(vec![
+                Field::new("c_custkey", DataType::Int32, true),
+                Field::new("c_name", DataType::Utf8, true),
+                Field::new("revenue", DataType::Float64, true),
+                Field::new("c_acctbal", DataType::Float64, true),
+                Field::new("n_name", DataType::Utf8, true),
+                Field::new("c_address", DataType::Utf8, true),
+                Field::new("c_phone", DataType::Utf8, true),
+                Field::new("c_comment", DataType::Utf8, true),
+            ]),
+
+            11 => Schema::new(vec![
+                Field::new("ps_partkey", DataType::Int32, true),
+                Field::new("value", DataType::Float64, true),
+            ]),
+
+            12 => Schema::new(vec![
+                Field::new("l_shipmode", DataType::Utf8, true),
+                Field::new("high_line_count", DataType::Int64, true),
+                Field::new("low_line_count", DataType::Int64, true),
+            ]),
+
+            13 => Schema::new(vec![
+                Field::new("c_count", DataType::Int64, true),
+                Field::new("custdist", DataType::Int64, true),
+            ]),
+
+            14 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
+
+            15 => Schema::new(vec![Field::new("promo_revenue", DataType::Float64, true)]),
+
+            16 => Schema::new(vec![
+                Field::new("p_brand", DataType::Utf8, true),
+                Field::new("p_type", DataType::Utf8, true),
+                Field::new("c_phone", DataType::Int32, true),
+                Field::new("c_comment", DataType::Int32, true),
+            ]),
+
+            17 => Schema::new(vec![Field::new("avg_yearly", DataType::Float64, true)]),
+
+            18 => Schema::new(vec![
+                Field::new("c_name", DataType::Utf8, true),
+                Field::new("c_custkey", DataType::Int32, true),
+                Field::new("o_orderkey", DataType::Int32, true),
+                Field::new("o_orderdat", DataType::Date32(DateUnit::Day), true),
+                Field::new("o_totalprice", DataType::Float64, true),
+                Field::new("sum_l_quantity", DataType::Float64, true),
+            ]),
+
+            19 => Schema::new(vec![Field::new("revenue", DataType::Float64, true)]),
+
+            20 => Schema::new(vec![
+                Field::new("s_name", DataType::Utf8, true),
+                Field::new("s_address", DataType::Utf8, true),
+            ]),
+
+            21 => Schema::new(vec![
+                Field::new("s_name", DataType::Utf8, true),
+                Field::new("numwait", DataType::Int32, true),
+            ]),
+
+            22 => Schema::new(vec![
+                Field::new("cntrycode", DataType::Int32, true),
+                Field::new("numcust", DataType::Int32, true),
+                Field::new("totacctbal", DataType::Float64, true),
+            ]),
+
+            _ => unimplemented!(),
+        }
+    }
+
+    // convert expected schema to all utf8 so columns can be read as strings to be parsed separately
+    // this is because the csv parser cannot handle leading/trailing spaces
+    fn string_schema(schema: Schema) -> Schema {
+        Schema::new(
+            schema
+                .fields()
+                .iter()
+                .map(|field| {
+                    Field::new(
+                        Field::name(&field),
+                        DataType::Utf8,
+                        Field::is_nullable(&field),
+                    )
+                })
+                .collect::<Vec<Field>>(),
+        )
+    }
+
+    // convert the schema to an equivalent one with all columns set to nullable=true.
+    // this allows direct schema comparison while ignoring nullability.
+    fn nullable_schema(schema: Arc<Schema>) -> Schema {
+        Schema::new(
+            schema
+                .fields()
+                .iter()
+                .map(|field| {
+                    Field::new(
+                        Field::name(&field),
+                        Field::data_type(&field).to_owned(),
+                        true,
+                    )
+                })
+                .collect::<Vec<Field>>(),
+        )
+    }
+
+    async fn verify_query(n: usize) -> Result<()> {
+        if let Ok(path) = env::var("TPCH_DATA") {
+            // load expected answers from tpch-dbgen
+            // read csv as all strings, then trim and cast to the expected type, as the
+            // csv string-to-value parser does not handle data with leading/trailing spaces
+            let mut ctx = ExecutionContext::new();
+            let schema = string_schema(get_answer_schema(n));
+            let options = CsvReadOptions::new()
+                .schema(&schema)
+                .delimiter(b'|')
+                .file_extension(".out");
+            let df = ctx.read_csv(&format!("{}/answers/q{}.out", path, n), options)?;
+            let df = df.select(
+                get_answer_schema(n)
+                    .fields()
+                    .iter()
+                    .map(|field| {
+                        Expr::Alias(
+                            Box::new(Cast {
+                                expr: Box::new(trim(col(Field::name(&field)))),
+                                data_type: Field::data_type(&field).to_owned(),
+                            }),
+                            Field::name(&field).to_string(),
+                        )
+                    })
+                    .collect::<Vec<Expr>>(),
+            )?;
+            let expected = df.collect().await?;
+
+            // run the query to compute actual results of the query
+            let opt = BenchmarkOpt {
+                query: n,
+                debug: false,
+                iterations: 1,
+                concurrency: 2,
+                batch_size: 4096,
+                path: PathBuf::from(path.to_string()),
+                file_format: "tbl".to_string(),
+                mem_table: false,
+            };
+            let actual = benchmark(opt).await?;
+
+            // assert schema equality, ignoring the nullable flag
+            assert_eq!(
+                nullable_schema(expected[0].schema()),
+                nullable_schema(actual[0].schema())
+            );
+
+            // convert both datasets to Vec<Vec<String>> for simple comparison
+            let expected_vec = result_vec(&expected);
+            let actual_vec = result_vec(&actual);
+
+            // basic result comparison
+            assert_eq!(expected_vec.len(), actual_vec.len());
+
+            // compare each row. this works as all TPC-H queries have deterministically ordered results
+            for i in 0..actual_vec.len() {
+                assert_eq!(expected_vec[i], actual_vec[i]);
+            }
+        } else {
+            println!("TPCH_DATA environment variable not set, skipping test");
+        }
+
+        Ok(())
+    }
+}
diff --git a/rust/datafusion/src/prelude.rs b/rust/datafusion/src/prelude.rs
index 309b75b..c8a4804 100644
--- a/rust/datafusion/src/prelude.rs
+++ b/rust/datafusion/src/prelude.rs
@@ -28,7 +28,7 @@
 pub use crate::dataframe::DataFrame;
 pub use crate::execution::context::{ExecutionConfig, ExecutionContext};
 pub use crate::logical_plan::{
-    array, avg, col, concat, count, create_udf, length, lit, max, min, sum, JoinType,
-    Partitioning,
+    array, avg, col, concat, count, create_udf, length, lit, lower, max, min, sum, trim,
+    upper, JoinType, Partitioning,
 };
 pub use crate::physical_plan::csv::CsvReadOptions;
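
The prelude change above re-exports lower, trim, and upper alongside the
existing expression helpers, which is what lets verify_query call trim()
directly. A minimal sketch of the pattern, assuming the DataFrame API at this
commit; the file name and column names are illustrative only:

```rust
use std::sync::Arc;

use datafusion::dataframe::DataFrame;
use datafusion::error::Result;
use datafusion::prelude::*;

/// Trim and re-case some string columns, mirroring how verify_query
/// normalises the expected answers before comparison.
fn trimmed_customers(ctx: &mut ExecutionContext) -> Result<Arc<dyn DataFrame>> {
    let df = ctx.read_csv("customers.csv", CsvReadOptions::new())?;
    // each helper takes an Expr and returns a new Expr wrapping the function call
    df.select(vec![
        trim(col("c_name")),
        upper(col("c_mktsegment")),
        lower(col("c_comment")),
    ])
}
```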