Posted to commits@arrow.apache.org by al...@apache.org on 2022/02/15 19:00:11 UTC
[arrow-datafusion] branch arrow2 updated: Arrow2 02092022 (#1795)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch arrow2
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/arrow2 by this push:
new 14bf39d Arrow2 02092022 (#1795)
14bf39d is described below
commit 14bf39d4dbd28368b267aec728fba345764f759c
Author: Guillaume Balaine <ig...@gmail.com>
AuthorDate: Tue Feb 15 20:00:04 2022 +0100
Arrow2 02092022 (#1795)
* feat: add join type for logical plan display (#1674)
* (minor) Reduce memory manager and disk manager logs from `info!` to `debug!` (#1689)
* Move `information_schema` tests out of execution/context.rs to `sql_integration` tests (#1684)
* Move tests from context.rs to information_schema.rs
* Fix up tests to compile
* Move timestamp related tests out of context.rs and into sql integration test (#1696)
* Move some tests out of context.rs and into sql
* Move support test out of context.rs and into sql tests
* Fixup tests and make them compile
* Add `MemTrackingMetrics` to ease memory tracking for non-limited memory consumers (#1691)
* Memory manager no longer tracks consumers; update AggregatedMetricsSet
* Easy memory tracking with metrics
* use tracking metrics in SPMS
* tests
* fix
* doc
* Update datafusion/src/physical_plan/sorts/sort.rs
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* make tracker AtomicUsize
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Implement TableProvider for DataFrameImpl (#1699)
* Add TableProvider impl for DataFrameImpl
* Add physical plan in
* Clean up plan construction and names construction
* Remove duplicate comments
* Remove unused parameter
* Add test
* Remove duplicate limit comment
* Use cloned instead of individual clone
* Reduce the amount of code to get a schema
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Add comments to test
* Fix plan comparison
* Compare only the results of execution
* Remove println
* Refer to df_impl instead of table in test
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Fix the register_table test to use the correct result set for comparison
* Consolidate group/agg exprs
* Format
* Remove outdated comment
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* refine test in repartition.rs & coalesce_batches.rs (#1707)
* Fuzz test for spillable sort (#1706)
* Lazy TempDir creation in DiskManager (#1695)
* Incorporate dyn scalar kernels (#1685)
* Rebase
* impl ToNumeric for ScalarValue
* Update macro to be based on
* Add floats
* Cleanup
* Newline
* add annotation for select_to_plan (#1714)
* Support `create_physical_expr` without `ExecutionContextState` or `DefaultPhysicalPlanner` for faster planning (#1700)
* Change physical_expr creation API
* Refactor API usage to avoid creating ExecutionContextState
* Fixup ballista
* clippy!
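A hedged sketch of the reworked API described above (module paths and argument order follow the upstream PR but should be treated as assumptions): a `PhysicalExpr` can now be built from an `Expr`, the input schemas, and `ExecutionProps`, with no `ExecutionContextState` in sight.
```rust
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::Result;
use datafusion::execution::context::ExecutionProps;
use datafusion::logical_plan::{col, lit, DFSchema};
use datafusion::physical_plan::planner::create_physical_expr;

fn plan_predicate() -> Result<()> {
    // input schema: a single non-nullable Int32 column "a"
    let schema = Schema::new(vec![Field::new("a", DataType::Int32, false)]);
    let df_schema = DFSchema::try_from(schema.clone())?;

    // logical predicate a > 5, lowered directly to a physical expression
    let predicate = col("a").gt(lit(5));
    let _physical = create_physical_expr(
        &predicate,
        &df_schema,
        &schema,
        &ExecutionProps::new(),
    )?;
    Ok(())
}
```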
* Fix cannot load Parquet table from Spark in datafusion-cli. (#1665)
* fix cannot load parquet table from spark
* add invalid-file message to the log.
* fix fmt
* add upper bound for pub fn (#1713)
Signed-off-by: remzi <13...@gmail.com>
* Create SchemaAdapter trait to map table schema to file schemas (#1709)
* Create SchemaAdapter trait to map table schema to file schemas
* Linting fix
* Remove commented code
* approx_quantile() aggregation function (#1539)
* feat: implement TDigest for approx quantile
Adds a [TDigest] implementation providing approximate quantile
estimations of large inputs using a small amount of (bounded) memory.
A TDigest is most accurate near either "end" of the quantile range (that
is, 0.1, 0.9, 0.95, etc.) due to the use of a scaling function that
increases resolution at the tails. The paper claims single-digit
parts-per-million errors for q ≤ 0.001 or q ≥ 0.999 using 100 centroids, and
in practice I have found accuracy to be more than acceptable for an
approximate function across the entire quantile range.
The implementation is a modified copy of
https://github.com/MnO2/t-digest, itself a Rust port of [Facebook's C++
implementation]. Both Facebook's implementation and MnO2's Rust port
are Apache 2.0 licensed.
[TDigest]: https://arxiv.org/abs/1902.04023
[Facebook's C++ implementation]: https://github.com/facebook/folly/blob/main/folly/stats/TDigest.h
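To make the "resolution at the tails" point concrete, here is a hedged sketch of the k1 scale function from the TDigest paper (form and constant are from the paper, not from this implementation): its slope is steepest near q = 0 and q = 1, so centroids are packed most densely at the tails.
```rust
// k_1 scale function from the TDigest paper: maps a quantile q in [0, 1]
// to a centroid index; delta bounds the total number of centroids.
fn k1_scale(q: f64, delta: f64) -> f64 {
    delta / (2.0 * std::f64::consts::PI) * (2.0 * q - 1.0).asin()
}

fn main() {
    // finite-difference slope: resolution near a tail vs. the median
    let d = 1e-6;
    let slope = |q: f64| (k1_scale(q + d, 100.0) - k1_scale(q, 100.0)) / d;
    println!("slope at q=0.999: {:.1}", slope(0.999)); // steep: fine resolution
    println!("slope at q=0.5:   {:.1}", slope(0.5));   // shallow: coarse resolution
}
```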
* feat: approx_quantile aggregation
Adds the ApproxQuantile physical expression, plumbing & test cases.
The function signature is:
approx_quantile(column, quantile)
Where column can be any numeric type (that can be cast to a float64) and
quantile is a float64 literal between 0 and 1.
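A hedged usage sketch in the style of the README examples further below (the table and column are assumptions, and a later commit in this series renames the function to approx_percentile_cont):
```rust
use datafusion::prelude::*;

#[tokio::main]
async fn main() -> datafusion::error::Result<()> {
    let mut ctx = ExecutionContext::new();
    ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new())
        .await?;

    // approximate 95th percentile of numeric column b, per group
    let df = ctx
        .sql("SELECT a, approx_quantile(b, 0.95) FROM example GROUP BY a")
        .await?;
    df.show().await?;
    Ok(())
}
```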
* feat: approx_quantile dataframe function
Adds the approx_quantile() dataframe function, and exports it in the
prelude.
* refactor: ballista approx_quantile support
Adds ballista wire encoding for approx_quantile.
Adding support for this required modifying the AggregateExprNode proto
message to support propagating multiple LogicalExprNode aggregate
arguments - all the existing aggregations take a single argument, so
this wasn't needed before.
This commit adds "repeated" to the expr field, which I believe is
backwards compatible as described here:
https://developers.google.com/protocol-buffers/docs/proto3#updating
Specifically, adding "repeated" to an existing message field:
"For ... message fields, optional is compatible with repeated"
No existing tests needed fixing, and a new roundtrip test is included
that covers the change to allow multiple expr.
* refactor: use input type as return type
Casts the calculated quantile value to the same type as the input data.
* fixup! refactor: ballista approx_quantile support
* refactor: rebase onto main
* refactor: validate quantile value
Ensures the quantile value is between 0 and 1, emitting a plan error if
not.
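For illustration, a minimal sketch of the described check (the helper name is illustrative, not DataFusion's internal API; `DataFusionError::Plan` is the error variant used elsewhere in this commit):
```rust
use datafusion::error::{DataFusionError, Result};

// Reject quantile literals outside [0, 1] at plan time.
fn validate_quantile(q: f64) -> Result<f64> {
    if (0.0..=1.0).contains(&q) {
        Ok(q)
    } else {
        Err(DataFusionError::Plan(format!(
            "Quantile value must be between 0 and 1, got {}",
            q
        )))
    }
}
```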
* refactor: rename to approx_percentile_cont
* refactor: clippy lints
* support bitwise AND as an example (#1653)
* support bitwise AND as an example
* Use $OP in macro rather than `&`
* fix: change signature to &dyn Array
* fmt
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* fix: substr - correct behaviour with negative start pos (#1660)
* minor: fix cargo run --release error (#1723)
* Convert boolean case expressions to boolean logic (#1719)
* Convert boolean case expressions to boolean logic
* Review feedback
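A toy model of the rewrite this pass performs (illustrative types, not DataFusion's `Expr`): boolean CASE expressions whose arms are literal `true`/`false` collapse into plain boolean logic, e.g. `CASE WHEN cond THEN true ELSE false END` becomes `cond`, which lets the result feed further boolean simplification.
```rust
#[derive(Debug, Clone, PartialEq)]
enum BoolExpr {
    Column(String),
    Not(Box<BoolExpr>),
    Literal(bool),
    /// CASE WHEN cond THEN then_val ELSE else_val END, with literal arms
    Case {
        cond: Box<BoolExpr>,
        then_val: bool,
        else_val: bool,
    },
}

fn simplify(e: BoolExpr) -> BoolExpr {
    match e {
        // CASE WHEN cond THEN true ELSE false END  =>  cond
        BoolExpr::Case { cond, then_val: true, else_val: false } => simplify(*cond),
        // CASE WHEN cond THEN false ELSE true END  =>  NOT cond
        BoolExpr::Case { cond, then_val: false, else_val: true } => {
            BoolExpr::Not(Box::new(simplify(*cond)))
        }
        // identical literal arms: the condition is irrelevant
        BoolExpr::Case { then_val, else_val, .. } if then_val == else_val => {
            BoolExpr::Literal(then_val)
        }
        other => other,
    }
}

fn main() {
    let case = BoolExpr::Case {
        cond: Box::new(BoolExpr::Column("a".into())),
        then_val: true,
        else_val: false,
    };
    assert_eq!(simplify(case), BoolExpr::Column("a".into()));
}
```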
* substitute `parking_lot::Mutex` for `std::sync::Mutex` (#1720)
* Substitute parking_lot::Mutex for std::sync::Mutex
* enable parking_lot feature in tokio
* Add Expression Simplification API (#1717)
* Add Expression Simplification API
* fmt
* Add tests and CI for optional pyarrow module (#1711)
* Implement other side of conversion
* Add test workflow
* Add (failing) tests
* Get unit tests passing
* Use python -m pip
* Debug LD_LIBRARY_PATH
* Set LIBRARY_PATH
* Update help with better info
* Update parking_lot requirement from 0.11 to 0.12 (#1735)
Updates the requirements on [parking_lot](https://github.com/Amanieu/parking_lot) to permit the latest version.
- [Release notes](https://github.com/Amanieu/parking_lot/releases)
- [Changelog](https://github.com/Amanieu/parking_lot/blob/master/CHANGELOG.md)
- [Commits](https://github.com/Amanieu/parking_lot/compare/0.11.0...0.12.0)
---
updated-dependencies:
- dependency-name: parking_lot
dependency-type: direct:production
...
Signed-off-by: dependabot[bot] <su...@github.com>
Co-authored-by: dependabot[bot] <49...@users.noreply.github.com>
* Prevent repartitioning of certain operators' direct children (#1731) (#1732)
* Prevent repartitioning of certain operators' direct children (#1731)
* Update ballista tests
* Don't repartition children of RepartitionExec
* Revert partition restriction on Repartition and Projection
* Review feedback
* Lint
* API to get Expr's type and nullability without a `DFSchema` (#1726)
* API to get Expr type and nullability without a `DFSchema`
* Add test
* publicly export
* Improve docs
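A hedged sketch of the abstraction (the `ExprSchema` trait and its methods appear verbatim in the diff further below; the implementing struct here is a made-up example): any lightweight type that can answer nullability and data-type questions for a column can stand in for a full `DFSchema`.
```rust
use arrow::datatypes::DataType;
use datafusion_common::{Column, ExprSchema, Result};

/// Illustrative schema that reports every column as a nullable Int64.
struct AllInt64 {
    dt: DataType,
}

impl ExprSchema for AllInt64 {
    fn nullable(&self, _col: &Column) -> Result<bool> {
        Ok(true)
    }
    fn data_type(&self, _col: &Column) -> Result<&DataType> {
        Ok(&self.dt)
    }
}

fn main() -> Result<()> {
    let schema = AllInt64 { dt: DataType::Int64 };
    assert!(schema.nullable(&Column::from_name("a"))?);
    Ok(())
}
```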
* Fix typos in crate documentation (#1739)
* add `cargo check --release` to ci (#1737)
* remote test
* Update .github/workflows/rust.yml
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
* Move optimize test out of context.rs (#1742)
* Move optimize test out of context.rs
* Update
* use clap 3 style args parsing for datafusion cli (#1749)
* use clap 3 style args parsing for datafusion cli
* upgrade cli version
* Add partitioned_csv setup code to sql_integration test (#1743)
* use ordered-float 2.10 (#1756)
Signed-off-by: Andy Grove <ag...@apache.org>
* #1768 Support TimeUnit::Second in hasher (#1769)
* Support TimeUnit::Second in hasher
* fix linter
* format (#1745)
* Create built-in scalar functions programmatically (#1734)
* create built-in scalar functions programmatically
Signed-off-by: remzi <13...@gmail.com>
* solve conflict
Signed-off-by: remzi <13...@gmail.com>
* fix spelling mistake
Signed-off-by: remzi <13...@gmail.com>
* rename to call_fn
Signed-off-by: remzi <13...@gmail.com>
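A hedged sketch of the resulting API (the export path and exact signature are assumptions based on the commit messages above): a built-in scalar function is looked up by name and applied to a list of argument expressions.
```rust
use datafusion::error::Result;
use datafusion::logical_plan::{call_fn, col, Expr};

// Build sqrt(a) without a dedicated expression helper; call_fn resolves
// the built-in function by name (assumed: name + Vec<Expr> -> Result<Expr>).
fn sqrt_of_a() -> Result<Expr> {
    call_fn("sqrt", vec![col("a")])
}
```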
* [split/1] split datafusion-common module (#1751)
* split datafusion-common module
* pyarrow
* Update datafusion-common/README.md
Co-authored-by: Andy Grove <ag...@apache.org>
* Update datafusion/Cargo.toml
* include publishing
Co-authored-by: Andy Grove <ag...@apache.org>
* fix: Case insensitive unquoted identifiers (#1747)
* move dfschema and column (#1758)
* add datafusion-expr module (#1759)
* move column, dfschema, etc. to common module (#1760)
* include window frames and operator into datafusion-expr (#1761)
* move signature, type signature, and volatility to split module (#1763)
* [split/10] split up expr for rewriting, visiting, and simplification traits (#1774)
* split up expr for rewriting, visiting, and simplification
* add docs
* move built-in scalar functions (#1764)
* split expr type and null info to be expr-schemable (#1784)
* rewrite predicates before pushing to union inputs (#1781)
* move accumulator and columnar value (#1765)
* move accumulator and columnar value (#1762)
* fix bad data type in test_try_cast_decimal_to_decimal
* added projections for avro columns
Co-authored-by: xudong.w <wx...@gmail.com>
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
Co-authored-by: Yijie Shen <he...@gmail.com>
Co-authored-by: Phillip Cloud <41...@users.noreply.github.com>
Co-authored-by: Matthew Turner <ma...@outlook.com>
Co-authored-by: Yang <37...@users.noreply.github.com>
Co-authored-by: Remzi Yang <59...@users.noreply.github.com>
Co-authored-by: Dan Harris <13...@users.noreply.github.com>
Co-authored-by: Dom <do...@itsallbroken.com>
Co-authored-by: Kun Liu <li...@apache.org>
Co-authored-by: Dmitry Patsura <ta...@dmtry.me>
Co-authored-by: Raphael Taylor-Davies <17...@users.noreply.github.com>
Co-authored-by: Will Jones <wi...@gmail.com>
Co-authored-by: dependabot[bot] <49...@users.noreply.github.com>
Co-authored-by: r.4ntix <r....@gmail.com>
Co-authored-by: Jiayu Liu <Ji...@users.noreply.github.com>
Co-authored-by: Andy Grove <ag...@apache.org>
Co-authored-by: Rich <jy...@users.noreply.github.com>
Co-authored-by: Marko Mikulicic <mm...@gmail.com>
Co-authored-by: Eduard Karacharov <13...@users.noreply.github.com>
---
.github/workflows/rust.yml | 59 +-
Cargo.toml | 6 +-
README.md | 354 +---
ballista/rust/client/Cargo.toml | 2 +-
ballista/rust/core/Cargo.toml | 2 +-
ballista/rust/executor/Cargo.toml | 2 +-
ballista/rust/scheduler/Cargo.toml | 2 +-
datafusion-cli/Cargo.toml | 3 +-
datafusion-cli/src/command.rs | 11 +-
datafusion-cli/src/exec.rs | 9 +-
datafusion-cli/src/functions.rs | 4 +-
datafusion-cli/src/lib.rs | 1 -
datafusion-cli/src/main.rs | 162 +-
datafusion-cli/src/print_format.rs | 72 +-
.../rust/client => datafusion-common}/Cargo.toml | 36 +-
datafusion-common/README.md | 24 +
datafusion-common/src/column.rs | 150 ++
.../src}/dfschema.rs | 36 +-
{datafusion => datafusion-common}/src/error.rs | 12 +-
.../src/field_util.rs | 0
{datafusion-cli => datafusion-common}/src/lib.rs | 22 +-
datafusion-common/src/pyarrow.rs | 247 +++
.../src/record_batch.rs | 17 +
{datafusion => datafusion-common}/src/scalar.rs | 4 +-
.../src/scalar_tmp.rs | 0
{datafusion-cli => datafusion-expr}/Cargo.toml | 30 +-
datafusion-expr/README.md | 24 +
datafusion-expr/src/accumulator.rs | 44 +
datafusion-expr/src/aggregate_function.rs | 93 +
datafusion-expr/src/built_in_function.rs | 330 ++++
datafusion-expr/src/columnar_value.rs | 63 +
datafusion-expr/src/expr.rs | 698 +++++++
.../src/lib.rs => datafusion-expr/src/expr_fn.rs | 24 +-
datafusion-expr/src/function.rs | 46 +
datafusion-expr/src/lib.rs | 49 +
datafusion-expr/src/literal.rs | 138 ++
.../src/operator.rs | 36 +-
datafusion-expr/src/signature.rs | 116 ++
.../physical_plan => datafusion-expr/src}/udaf.rs | 95 +-
.../physical_plan => datafusion-expr/src}/udf.rs | 39 +-
.../src/window_frame.rs | 34 +-
datafusion-expr/src/window_function.rs | 204 ++
datafusion/Cargo.toml | 11 +-
datafusion/benches/sort_limit_query_sql.rs | 3 +
datafusion/fuzz-utils/Cargo.toml | 2 +-
datafusion/fuzz-utils/src/lib.rs | 11 +-
datafusion/src/avro_to_arrow/arrow_array_reader.rs | 2 +
datafusion/src/avro_to_arrow/reader.rs | 23 +-
datafusion/src/avro_to_arrow/schema.rs | 1 -
datafusion/src/dataframe.rs | 3 +-
datafusion/src/datasource/file_format/parquet.rs | 6 +-
datafusion/src/datasource/listing/helpers.rs | 2 +-
datafusion/src/datasource/memory.rs | 1 -
datafusion/src/error.rs | 171 +-
datafusion/src/execution/context.rs | 561 +++---
datafusion/src/execution/dataframe_impl.rs | 4 +-
datafusion/src/field_util.rs | 474 +----
datafusion/src/lib.rs | 4 +-
datafusion/src/logical_plan/builder.rs | 49 +-
datafusion/src/logical_plan/dfschema.rs | 667 +------
datafusion/src/logical_plan/expr.rs | 1966 +-------------------
datafusion/src/logical_plan/expr_rewriter.rs | 592 ++++++
datafusion/src/logical_plan/expr_schema.rs | 232 +++
datafusion/src/logical_plan/expr_simplier.rs | 97 +
datafusion/src/logical_plan/expr_visitor.rs | 176 ++
datafusion/src/logical_plan/mod.rs | 26 +-
datafusion/src/logical_plan/operators.rs | 123 +-
datafusion/src/logical_plan/window_frames.rs | 363 +---
.../src/optimizer/common_subexpr_eliminate.rs | 4 +-
datafusion/src/optimizer/filter_push_down.rs | 54 +-
datafusion/src/optimizer/simplify_expressions.rs | 62 +-
.../src/optimizer/single_distinct_to_groupby.rs | 1 +
datafusion/src/optimizer/utils.rs | 6 +-
datafusion/src/physical_optimizer/repartition.rs | 212 ++-
datafusion/src/physical_plan/aggregates.rs | 87 +-
.../src/physical_plan/expressions/try_cast.rs | 2 +-
.../src/physical_plan/file_format/parquet.rs | 52 +-
datafusion/src/physical_plan/functions.rs | 434 +----
datafusion/src/physical_plan/hash_utils.rs | 10 +
datafusion/src/physical_plan/limit.rs | 5 +
datafusion/src/physical_plan/mod.rs | 70 +-
datafusion/src/physical_plan/udaf.rs | 83 +-
datafusion/src/physical_plan/udf.rs | 85 +-
datafusion/src/physical_plan/union.rs | 4 +
datafusion/src/physical_plan/window_functions.rs | 186 +-
datafusion/src/pyarrow.rs | 96 -
datafusion/src/record_batch.rs | 452 +----
datafusion/src/scalar.rs | 1904 +------------------
datafusion/src/sql/planner.rs | 21 +-
datafusion/src/sql/utils.rs | 10 +
datafusion/tests/order_spill_fuzz.rs | 6 +-
datafusion/tests/parquet_pruning.rs | 75 +-
datafusion/tests/simplification.rs | 2 +
datafusion/tests/sql/explain.rs | 60 +
datafusion/tests/sql/mod.rs | 17 +
datafusion/tests/sql/partitioned_csv.rs | 95 +
datafusion/tests/sql/projection.rs | 192 ++
datafusion/tests/sql/select.rs | 58 +
docs/source/index.rst | 1 +
docs/source/specification/quarterly_roadmap.md | 72 +
docs/source/user-guide/sql/index.rst | 1 +
.../source/user-guide/sql/sql_status.md | 192 --
102 files changed, 4930 insertions(+), 8551 deletions(-)
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 8a7f673..a9ff52e 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -58,12 +58,18 @@ jobs:
rustup toolchain install ${{ matrix.rust }}
rustup default ${{ matrix.rust }}
rustup component add rustfmt
- - name: Build Workspace
+ - name: Build workspace in debug mode
run: |
cargo build
env:
CARGO_HOME: "/github/home/.cargo"
- CARGO_TARGET_DIR: "/github/home/target"
+ CARGO_TARGET_DIR: "/github/home/target/debug"
+ - name: Build workspace in release mode
+ run: |
+ cargo check --release
+ env:
+ CARGO_HOME: "/github/home/.cargo"
+ CARGO_TARGET_DIR: "/github/home/target/release"
- name: Check DataFusion Build without default features
run: |
cargo check --no-default-features -p datafusion
@@ -230,6 +236,55 @@ jobs:
# do not produce debug symbols to keep memory usage down
RUSTFLAGS: "-C debuginfo=0"
+ test-datafusion-pyarrow:
+ needs: [linux-build-lib]
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ arch: [amd64]
+ rust: [stable]
+ container:
+ image: ${{ matrix.arch }}/rust
+ env:
+ # Disable full debug symbol generation to speed up CI build and keep memory down
+ # "1" means line tables only, which is useful for panic tracebacks.
+ RUSTFLAGS: "-C debuginfo=1"
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ submodules: true
+ - name: Cache Cargo
+ uses: actions/cache@v2
+ with:
+ path: /github/home/.cargo
+ # this key equals the ones on `linux-build-lib` for re-use
+ key: cargo-cache-
+ - name: Cache Rust dependencies
+ uses: actions/cache@v2
+ with:
+ path: /github/home/target
+ # this key equals the ones on `linux-build-lib` for re-use
+ key: ${{ runner.os }}-${{ matrix.arch }}-target-cache-${{ matrix.rust }}
+ - uses: actions/setup-python@v2
+ with:
+ python-version: "3.8"
+ - name: Install PyArrow
+ run: |
+ echo "LIBRARY_PATH=$LD_LIBRARY_PATH" >> $GITHUB_ENV
+ python -m pip install pyarrow
+ - name: Setup Rust toolchain
+ run: |
+ rustup toolchain install ${{ matrix.rust }}
+ rustup default ${{ matrix.rust }}
+ rustup component add rustfmt
+ - name: Run tests
+ run: |
+ cd datafusion
+ cargo test --features=pyarrow
+ env:
+ CARGO_HOME: "/github/home/.cargo"
+ CARGO_TARGET_DIR: "/github/home/target"
+
lint:
name: Lint
runs-on: ubuntu-latest
diff --git a/Cargo.toml b/Cargo.toml
index 5af182e..a988927 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,6 +18,8 @@
[workspace]
members = [
"datafusion",
+ "datafusion-common",
+ "datafusion-expr",
"datafusion-cli",
"datafusion-examples",
"benchmarks",
@@ -33,5 +35,5 @@ lto = true
codegen-units = 1
[patch.crates-io]
-#arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" }
-#parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", branch = "main" }
+arrow2 = { git = "https://github.com/jorgecarleitao/arrow2.git", branch = "main" }
+parquet2 = { git = "https://github.com/jorgecarleitao/parquet2.git", branch = "main" }
diff --git a/README.md b/README.md
index 25fc16c..dc350f6 100644
--- a/README.md
+++ b/README.md
@@ -73,361 +73,23 @@ Here are some of the projects known to use DataFusion:
## Example Usage
-Run a SQL query against data stored in a CSV:
+Please see [example usage](https://arrow.apache.org/datafusion/user-guide/example-usage.html) to learn how to use DataFusion.
-```rust
-use datafusion::prelude::*;
-use datafusion::arrow::record_batch::RecordBatch;
-
-#[tokio::main]
-async fn main() -> datafusion::error::Result<()> {
- // register the table
- let mut ctx = ExecutionContext::new();
- ctx.register_csv("example", "tests/example.csv", CsvReadOptions::new()).await?;
-
- // create a plan to run a SQL query
- let df = ctx.sql("SELECT a, MIN(b) FROM example GROUP BY a LIMIT 100").await?;
-
- // execute and print results
- df.show().await?;
- Ok(())
-}
-```
-
-Use the DataFrame API to process data stored in a CSV:
-
-```rust
-use datafusion::prelude::*;
-use datafusion::arrow::record_batch::RecordBatch;
-
-#[tokio::main]
-async fn main() -> datafusion::error::Result<()> {
- // create the dataframe
- let mut ctx = ExecutionContext::new();
- let df = ctx.read_csv("tests/example.csv", CsvReadOptions::new()).await?;
-
- let df = df.filter(col("a").lt_eq(col("b")))?
- .aggregate(vec![col("a")], vec![min(col("b"))])?;
-
- // execute and print results
- df.show_limit(100).await?;
- Ok(())
-}
-```
-
-Both of these examples will produce
-
-```text
-+---+--------+
-| a | MIN(b) |
-+---+--------+
-| 1 | 2 |
-+---+--------+
-```
-
-## Using DataFusion as a library
-
-DataFusion is [published on crates.io](https://crates.io/crates/datafusion), and is [well documented on docs.rs](https://docs.rs/datafusion/).
-
-To get started, add the following to your `Cargo.toml` file:
-
-```toml
-[dependencies]
-datafusion = "6.0.0"
-```
-
-## Using DataFusion as a binary
-
-DataFusion also includes a simple command-line interactive SQL utility. See the [CLI reference](https://arrow.apache.org/datafusion/cli/index.html) for more information.
-
-# Roadmap
-
-A quarterly roadmap will be published to give the DataFusion community visibility into the priorities of the projects contributors. This roadmap is not binding.
-
-## 2022 Q1
-
-### DataFusion Core
-
-- Publish official Arrow2 branch
-- Implementation of memory manager (i.e. to enable spilling to disk as needed)
-
-### Benchmarking
-
-- Inclusion in Db-Benchmark with all quries covered
-- All TPCH queries covered
-
-### Performance Improvements
-
-- Predicate evaluation
-- Improve multi-column comparisons (that can't be vectorized at the moment)
-- Null constant support
-
-### New Features
-
-- Read JSON as table
-- Simplify DDL with Datafusion-Cli
-- Add Decimal128 data type and the attendant features such as Arrow Kernel and UDF support
-- Add new experimental e-graph based optimizer
-
-### Ballista
-
-- Begin work on design documents and plan / priorities for development
-
-### Extensions ([datafusion-contrib](https://github.com/datafusion-contrib]))
-
-- Stable S3 support
-- Begin design discussions and prototyping of a stream provider
-
-## Beyond 2022 Q1
-
-There is no clear timeline for the below, but community members have expressed interest in working on these topics.
-
-### DataFusion Core
-
-- Custom SQL support
-- Split DataFusion into multiple crates
-- Push based query execution and code generation
-
-### Ballista
-
-- Evolve architecture so that it can be deployed in a multi-tenant cloud native environment
-- Ensure Ballista is scalable, elastic, and stable for production usage
-- Develop distributed ML capabilities
-
-# Status
-
-## General
-
-- [x] SQL Parser
-- [x] SQL Query Planner
-- [x] Query Optimizer
-- [x] Constant folding
-- [x] Join Reordering
-- [x] Limit Pushdown
-- [x] Projection push down
-- [x] Predicate push down
-- [x] Type coercion
-- [x] Parallel query execution
-
-## SQL Support
-
-- [x] Projection
-- [x] Filter (WHERE)
-- [x] Filter post-aggregate (HAVING)
-- [x] Limit
-- [x] Aggregate
-- [x] Common math functions
-- [x] cast
-- [x] try_cast
-- [x] [`VALUES` lists](https://www.postgresql.org/docs/current/queries-values.html)
-- Postgres compatible String functions
- - [x] ascii
- - [x] bit_length
- - [x] btrim
- - [x] char_length
- - [x] character_length
- - [x] chr
- - [x] concat
- - [x] concat_ws
- - [x] initcap
- - [x] left
- - [x] length
- - [x] lpad
- - [x] ltrim
- - [x] octet_length
- - [x] regexp_replace
- - [x] repeat
- - [x] replace
- - [x] reverse
- - [x] right
- - [x] rpad
- - [x] rtrim
- - [x] split_part
- - [x] starts_with
- - [x] strpos
- - [x] substr
- - [x] to_hex
- - [x] translate
- - [x] trim
-- Miscellaneous/Boolean functions
- - [x] nullif
-- Approximation functions
- - [x] approx_distinct
-- Common date/time functions
- - [ ] Basic date functions
- - [ ] Basic time functions
- - [x] Basic timestamp functions
- - [x] [to_timestamp](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp)
- - [x] [to_timestamp_millis](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_millis)
- - [x] [to_timestamp_micros](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_micros)
- - [x] [to_timestamp_seconds](docs/user-guide/book/sql/datafusion-functions.html#to_timestamp_seconds)
-- nested functions
- - [x] Array of columns
-- [x] Schema Queries
- - [x] SHOW TABLES
- - [x] SHOW COLUMNS
- - [x] information_schema.{tables, columns}
- - [ ] information_schema other views
-- [x] Sorting
-- [ ] Nested types
-- [ ] Lists
-- [x] Subqueries
-- [x] Common table expressions
-- [x] Set Operations
- - [x] UNION ALL
- - [x] UNION
- - [x] INTERSECT
- - [x] INTERSECT ALL
- - [x] EXCEPT
- - [x] EXCEPT ALL
-- [x] Joins
- - [x] INNER JOIN
- - [x] LEFT JOIN
- - [x] RIGHT JOIN
- - [x] FULL JOIN
- - [x] CROSS JOIN
-- [ ] Window
- - [x] Empty window
- - [x] Common window functions
- - [x] Window with PARTITION BY clause
- - [x] Window with ORDER BY clause
- - [ ] Window with FILTER clause
- - [ ] [Window with custom WINDOW FRAME](https://github.com/apache/arrow-datafusion/issues/361)
- - [ ] UDF and UDAF for window functions
-
-## Data Sources
-
-- [x] CSV
-- [x] Parquet primitive types
-- [ ] Parquet nested types
-
-## Extensibility
-
-DataFusion is designed to be extensible at all points. To that end, you can provide your own custom:
-
-- [x] User Defined Functions (UDFs)
-- [x] User Defined Aggregate Functions (UDAFs)
-- [x] User Defined Table Source (`TableProvider`) for tables
-- [x] User Defined `Optimizer` passes (plan rewrites)
-- [x] User Defined `LogicalPlan` nodes
-- [x] User Defined `ExecutionPlan` nodes
-
-## Rust Version Compatbility
-
-This crate is tested with the latest stable version of Rust. We do not currently test against other, older versions of the Rust compiler.
-
-# Supported SQL
-
-This library currently supports many SQL constructs, including
-
-- `CREATE EXTERNAL TABLE X STORED AS PARQUET LOCATION '...';` to register a table's locations
-- `SELECT ... FROM ...` together with any expression
-- `ALIAS` to name an expression
-- `CAST` to change types, including e.g. `Timestamp(Nanosecond, None)`
-- Many mathematical unary and binary expressions such as `+`, `/`, `sqrt`, `tan`, `>=`.
-- `WHERE` to filter
-- `GROUP BY` together with one of the following aggregations: `MIN`, `MAX`, `COUNT`, `SUM`, `AVG`, `CORR`, `VAR`, `COVAR`, `STDDEV` (sample and population)
-- `ORDER BY` together with an expression and optional `ASC` or `DESC` and also optional `NULLS FIRST` or `NULLS LAST`
-
-## Supported Functions
-
-DataFusion strives to implement a subset of the [PostgreSQL SQL dialect](https://www.postgresql.org/docs/current/functions.html) where possible. We explicitly choose a single dialect to maximize interoperability with other tools and allow reuse of the PostgreSQL documents and tutorials as much as possible.
-
-Currently, only a subset of the PostgreSQL dialect is implemented, and we will document any deviations.
-
-## Schema Metadata / Information Schema Support
-
-DataFusion supports the showing metadata about the tables available. This information can be accessed using the views of the ISO SQL `information_schema` schema or the DataFusion specific `SHOW TABLES` and `SHOW COLUMNS` commands.
-
-More information can be found in the [Postgres docs](https://www.postgresql.org/docs/13/infoschema-schema.html)).
-
-To show tables available for use in DataFusion, use the `SHOW TABLES` command or the `information_schema.tables` view:
-
-```sql
-> show tables;
-+---------------+--------------------+------------+------------+
-| table_catalog | table_schema | table_name | table_type |
-+---------------+--------------------+------------+------------+
-| datafusion | public | t | BASE TABLE |
-| datafusion | information_schema | tables | VIEW |
-+---------------+--------------------+------------+------------+
-
-> select * from information_schema.tables;
-
-+---------------+--------------------+------------+--------------+
-| table_catalog | table_schema | table_name | table_type |
-+---------------+--------------------+------------+--------------+
-| datafusion | public | t | BASE TABLE |
-| datafusion | information_schema | TABLES | SYSTEM TABLE |
-+---------------+--------------------+------------+--------------+
-```
-
-To show the schema of a table in DataFusion, use the `SHOW COLUMNS` command or the or `information_schema.columns` view:
-
-```sql
-> show columns from t;
-+---------------+--------------+------------+-------------+-----------+-------------+
-| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |
-+---------------+--------------+------------+-------------+-----------+-------------+
-| datafusion | public | t | a | Int32 | NO |
-| datafusion | public | t | b | Utf8 | NO |
-| datafusion | public | t | c | Float32 | NO |
-+---------------+--------------+------------+-------------+-----------+-------------+
-
-> select table_name, column_name, ordinal_position, is_nullable, data_type from information_schema.columns;
-+------------+-------------+------------------+-------------+-----------+
-| table_name | column_name | ordinal_position | is_nullable | data_type |
-+------------+-------------+------------------+-------------+-----------+
-| t | a | 0 | NO | Int32 |
-| t | b | 1 | NO | Utf8 |
-| t | c | 2 | NO | Float32 |
-+------------+-------------+------------------+-------------+-----------+
-```
-
-## Supported Data Types
-
-DataFusion uses Arrow, and thus the Arrow type system, for query
-execution. The SQL types from
-[sqlparser-rs](https://github.com/ballista-compute/sqlparser-rs/blob/main/src/ast/data_type.rs#L57)
-are mapped to Arrow types according to the following table
-
-| SQL Data Type | Arrow DataType |
-| ------------- | --------------------------------- |
-| `CHAR` | `Utf8` |
-| `VARCHAR` | `Utf8` |
-| `UUID` | _Not yet supported_ |
-| `CLOB` | _Not yet supported_ |
-| `BINARY` | _Not yet supported_ |
-| `VARBINARY` | _Not yet supported_ |
-| `DECIMAL` | `Float64` |
-| `FLOAT` | `Float32` |
-| `SMALLINT` | `Int16` |
-| `INT` | `Int32` |
-| `BIGINT` | `Int64` |
-| `REAL` | `Float32` |
-| `DOUBLE` | `Float64` |
-| `BOOLEAN` | `Boolean` |
-| `DATE` | `Date32` |
-| `TIME` | `Time64(TimeUnit::Millisecond)` |
-| `TIMESTAMP` | `Timestamp(TimeUnit::Nanosecond)` |
-| `INTERVAL` | _Not yet supported_ |
-| `REGCLASS` | _Not yet supported_ |
-| `TEXT` | _Not yet supported_ |
-| `BYTEA` | _Not yet supported_ |
-| `CUSTOM` | _Not yet supported_ |
-| `ARRAY` | _Not yet supported_ |
-
-# Roadmap
+## Roadmap
Please see [Roadmap](docs/source/specification/roadmap.md) for information of where the project is headed.
-# Architecture Overview
+## Architecture Overview
There is no formal document describing DataFusion's architecture yet, but the following presentations offer a good overview of its different components and how they interact together.
- (March 2021): The DataFusion architecture is described in _Query Engine Design and the Rust-Based DataFusion in Apache Arrow_: [recording](https://www.youtube.com/watch?v=K6eCAVEk4kU) (DataFusion content starts [~ 15 minutes in](https://www.youtube.com/watch?v=K6eCAVEk4kU&t=875s)) and [slides](https://www.slideshare.net/influxdata/influxdb-iox-tech-talks-query-engine-design-and-the-rustbased-datafusion-in-apache-arrow-244161934)
- (February 2021): How DataFusion is used within the Ballista Project is described in \*Ballista: Distributed Compute with Rust and Apache Arrow: [recording](https://www.youtube.com/watch?v=ZZHQaOap9pQ)
-# Developer's guide
+## User's guide
+
+Please see [User Guide](https://arrow.apache.org/datafusion/) for more information about DataFusion.
+
+## Developer's guide
Please see [Developers Guide](DEVELOPERS.md) for information about developing DataFusion.
diff --git a/ballista/rust/client/Cargo.toml b/ballista/rust/client/Cargo.toml
index 4ec1abe..dff5d1a 100644
--- a/ballista/rust/client/Cargo.toml
+++ b/ballista/rust/client/Cargo.toml
@@ -35,7 +35,7 @@ log = "0.4"
tokio = "1.0"
tempfile = "3"
sqlparser = "0.13"
-parking_lot = "0.11"
+parking_lot = "0.12"
datafusion = { path = "../../../datafusion", version = "6.0.0" }
diff --git a/ballista/rust/core/Cargo.toml b/ballista/rust/core/Cargo.toml
index cdbbbf0..dfc6a6c 100644
--- a/ballista/rust/core/Cargo.toml
+++ b/ballista/rust/core/Cargo.toml
@@ -50,7 +50,7 @@ arrow = { package = "arrow2", version="0.9", features = ["io_ipc", "io_flight"]
datafusion = { path = "../../../datafusion", version = "6.0.0" }
-parking_lot = "0.11"
+parking_lot = "0.12"
[dev-dependencies]
tempfile = "3"
diff --git a/ballista/rust/executor/Cargo.toml b/ballista/rust/executor/Cargo.toml
index 310affd..ad48962 100644
--- a/ballista/rust/executor/Cargo.toml
+++ b/ballista/rust/executor/Cargo.toml
@@ -46,7 +46,7 @@ tokio-stream = { version = "0.1", features = ["net"] }
tonic = "0.6"
uuid = { version = "0.8", features = ["v4"] }
hyper = "0.14.4"
-parking_lot = "0.11"
+parking_lot = "0.12"
[dev-dependencies]
diff --git a/ballista/rust/scheduler/Cargo.toml b/ballista/rust/scheduler/Cargo.toml
index fdeb7e7..8acb13b 100644
--- a/ballista/rust/scheduler/Cargo.toml
+++ b/ballista/rust/scheduler/Cargo.toml
@@ -53,7 +53,7 @@ tokio-stream = { version = "0.1", features = ["net"], optional = true }
tonic = "0.6"
tower = { version = "0.4" }
warp = "0.3"
-parking_lot = "0.11"
+parking_lot = "0.12"
[dev-dependencies]
ballista-core = { path = "../core", version = "0.6.0" }
diff --git a/datafusion-cli/Cargo.toml b/datafusion-cli/Cargo.toml
index 09df15b..26ccdaf 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-cli/Cargo.toml
@@ -17,7 +17,8 @@
[package]
name = "datafusion-cli"
-version = "5.1.0"
+description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model. It supports executing SQL queries against CSV and Parquet files as well as querying directly against in-memory data."
+version = "6.0.0"
authors = ["Apache Arrow <de...@arrow.apache.org>"]
edition = "2021"
keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ]
diff --git a/datafusion-cli/src/command.rs b/datafusion-cli/src/command.rs
index fa37059..f6bedc2 100644
--- a/datafusion-cli/src/command.rs
+++ b/datafusion-cli/src/command.rs
@@ -20,7 +20,8 @@
use crate::context::Context;
use crate::functions::{display_all_functions, Function};
use crate::print_format::PrintFormat;
-use crate::print_options::{self, PrintOptions};
+use crate::print_options::PrintOptions;
+use clap::ArgEnum;
use datafusion::arrow::array::{ArrayRef, Utf8Array};
use datafusion::arrow::datatypes::{DataType, Field, Schema};
use datafusion::error::{DataFusionError, Result};
@@ -209,10 +210,14 @@ impl OutputFormat {
Self::ChangeFormat(format) => {
if let Ok(format) = format.parse::<PrintFormat>() {
print_options.format = format;
- println!("Output format is {}.", print_options.format);
+ println!("Output format is {:?}.", print_options.format);
Ok(())
} else {
- Err(DataFusionError::Execution(format!("{} is not a valid format type [possible values: csv, tsv, table, json, ndjson]", format)))
+ Err(DataFusionError::Execution(format!(
+ "{:?} is not a valid format type [possible values: {:?}]",
+ format,
+ PrintFormat::value_variants()
+ )))
}
}
}
diff --git a/datafusion-cli/src/exec.rs b/datafusion-cli/src/exec.rs
index acc340d..17b329b 100644
--- a/datafusion-cli/src/exec.rs
+++ b/datafusion-cli/src/exec.rs
@@ -21,19 +21,14 @@ use crate::{
command::{Command, OutputFormat},
context::Context,
helper::CliHelper,
- print_format::{all_print_formats, PrintFormat},
print_options::PrintOptions,
};
-use datafusion::error::{DataFusionError, Result};
-use datafusion::record_batch::RecordBatch;
-use rustyline::config::Config;
+use datafusion::error::Result;
use rustyline::error::ReadlineError;
use rustyline::Editor;
use std::fs::File;
use std::io::prelude::*;
use std::io::BufReader;
-use std::str::FromStr;
-use std::sync::Arc;
use std::time::Instant;
/// run and execute SQL statements and commands from a file, against a context with the given print options
@@ -108,7 +103,7 @@ pub async fn exec_from_repl(ctx: &mut Context, print_options: &mut PrintOptions)
);
}
} else {
- println!("Output format is {}.", print_options.format);
+ println!("Output format is {:?}.", print_options.format);
}
}
_ => {
diff --git a/datafusion-cli/src/functions.rs b/datafusion-cli/src/functions.rs
index 7839d4f..224f990 100644
--- a/datafusion-cli/src/functions.rs
+++ b/datafusion-cli/src/functions.rs
@@ -20,10 +20,8 @@ use arrow::array::{ArrayRef, Utf8Array};
use arrow::chunk::Chunk;
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::arrow::io::print;
-use datafusion::error::{DataFusionError, Result};
+use datafusion::error::Result;
use datafusion::field_util::SchemaExt;
-use datafusion::physical_plan::ColumnarValue::Array;
-use datafusion::record_batch::RecordBatch;
use std::fmt;
use std::str::FromStr;
use std::sync::Arc;
diff --git a/datafusion-cli/src/lib.rs b/datafusion-cli/src/lib.rs
index b2bcdd3..b75be33 100644
--- a/datafusion-cli/src/lib.rs
+++ b/datafusion-cli/src/lib.rs
@@ -16,7 +16,6 @@
// under the License.
#![doc = include_str!("../README.md")]
-#![allow(unused_imports)]
pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION");
pub mod command;
diff --git a/datafusion-cli/src/main.rs b/datafusion-cli/src/main.rs
index 4cb9e9d..788bb27 100644
--- a/datafusion-cli/src/main.rs
+++ b/datafusion-cli/src/main.rs
@@ -15,14 +15,11 @@
// specific language governing permissions and limitations
// under the License.
-use clap::{crate_version, App, Arg};
+use clap::Parser;
use datafusion::error::Result;
use datafusion::execution::context::ExecutionConfig;
use datafusion_cli::{
- context::Context,
- exec,
- print_format::{all_print_formats, PrintFormat},
- print_options::PrintOptions,
+ context::Context, exec, print_format::PrintFormat, print_options::PrintOptions,
DATAFUSION_CLI_VERSION,
};
use std::env;
@@ -30,117 +27,84 @@ use std::fs::File;
use std::io::BufReader;
use std::path::Path;
+#[derive(Debug, Parser, PartialEq)]
+#[clap(author, version, about, long_about= None)]
+struct Args {
+ #[clap(
+ short = 'p',
+ long,
+ help = "Path to your data, default to current directory",
+ validator(is_valid_data_dir)
+ )]
+ data_path: Option<String>,
+
+ #[clap(
+ short = 'c',
+ long,
+ help = "The batch size of each query, or use DataFusion default",
+ validator(is_valid_batch_size)
+ )]
+ batch_size: Option<usize>,
+
+ #[clap(
+ short,
+ long,
+ multiple_values = true,
+ help = "Execute commands from file(s), then exit",
+ validator(is_valid_file)
+ )]
+ file: Vec<String>,
+
+ #[clap(long, arg_enum, default_value_t = PrintFormat::Table)]
+ format: PrintFormat,
+
+ #[clap(long, help = "Ballista scheduler host")]
+ host: Option<String>,
+
+ #[clap(long, help = "Ballista scheduler port")]
+ port: Option<u16>,
+
+ #[clap(
+ short,
+ long,
+ help = "Reduce printing other than the results and work quietly"
+ )]
+ quiet: bool,
+}
+
#[tokio::main]
pub async fn main() -> Result<()> {
- let matches = App::new("DataFusion")
- .version(crate_version!())
- .about(
- "DataFusion is an in-memory query engine that uses Apache Arrow \
- as the memory model. It supports executing SQL queries against CSV and \
- Parquet files as well as querying directly against in-memory data.",
- )
- .arg(
- Arg::new("data-path")
- .help("Path to your data, default to current directory")
- .short('p')
- .long("data-path")
- .validator(is_valid_data_dir)
- .takes_value(true),
- )
- .arg(
- Arg::new("batch-size")
- .help("The batch size of each query, or use DataFusion default")
- .short('c')
- .long("batch-size")
- .validator(is_valid_batch_size)
- .takes_value(true),
- )
- .arg(
- Arg::new("file")
- .help("Execute commands from file(s), then exit")
- .short('f')
- .long("file")
- .multiple_occurrences(true)
- .validator(is_valid_file)
- .takes_value(true),
- )
- .arg(
- Arg::new("format")
- .help("Output format")
- .long("format")
- .default_value("table")
- .possible_values(
- &all_print_formats()
- .iter()
- .map(|format| format.to_string())
- .collect::<Vec<_>>()
- .iter()
- .map(|i| i.as_str())
- .collect::<Vec<_>>(),
- )
- .takes_value(true),
- )
- .arg(
- Arg::new("host")
- .help("Ballista scheduler host")
- .long("host")
- .takes_value(true),
- )
- .arg(
- Arg::new("port")
- .help("Ballista scheduler port")
- .long("port")
- .takes_value(true),
- )
- .arg(
- Arg::new("quiet")
- .help("Reduce printing other than the results and work quietly")
- .short('q')
- .long("quiet")
- .takes_value(false),
- )
- .get_matches();
-
- let quiet = matches.is_present("quiet");
-
- if !quiet {
- println!("DataFusion CLI v{}\n", DATAFUSION_CLI_VERSION);
- }
+ let args = Args::parse();
- let host = matches.value_of("host");
- let port = matches
- .value_of("port")
- .and_then(|port| port.parse::<u16>().ok());
+ if !args.quiet {
+ println!("DataFusion CLI v{}", DATAFUSION_CLI_VERSION);
+ }
- if let Some(path) = matches.value_of("data-path") {
+ if let Some(ref path) = args.data_path {
let p = Path::new(path);
env::set_current_dir(&p).unwrap();
};
let mut execution_config = ExecutionConfig::new().with_information_schema(true);
- if let Some(batch_size) = matches
- .value_of("batch-size")
- .and_then(|size| size.parse::<usize>().ok())
- {
+ if let Some(batch_size) = args.batch_size {
execution_config = execution_config.with_batch_size(batch_size);
};
- let mut ctx: Context = match (host, port) {
- (Some(h), Some(p)) => Context::new_remote(h, p)?,
+ let mut ctx: Context = match (args.host, args.port) {
+ (Some(ref h), Some(p)) => Context::new_remote(h, p)?,
_ => Context::new_local(&execution_config),
};
- let format = matches
- .value_of("format")
- .expect("No format is specified")
- .parse::<PrintFormat>()
- .expect("Invalid format");
-
- let mut print_options = PrintOptions { format, quiet };
+ let mut print_options = PrintOptions {
+ format: args.format,
+ quiet: args.quiet,
+ };
- if let Some(file_paths) = matches.values_of("file") {
- let files = file_paths
+ let files = args.file;
+ if !files.is_empty() {
+ let files = files
+ .into_iter()
.map(|file_path| File::open(file_path).unwrap())
.collect::<Vec<_>>();
for file in files {
diff --git a/datafusion-cli/src/print_format.rs b/datafusion-cli/src/print_format.rs
index fa8bf23..5a176e0 100644
--- a/datafusion-cli/src/print_format.rs
+++ b/datafusion-cli/src/print_format.rs
@@ -17,15 +17,14 @@
//! Print format variants
use arrow::io::json::write::{JsonArray, JsonFormat, LineDelimited};
-use datafusion::arrow::io::{csv::write, print};
+use datafusion::arrow::io::csv::write;
use datafusion::error::{DataFusionError, Result};
use datafusion::field_util::SchemaExt;
use datafusion::record_batch::RecordBatch;
-use std::fmt;
use std::str::FromStr;
/// Allow records to be printed in different formats
-#[derive(Debug, PartialEq, Eq, Clone)]
+#[derive(Debug, PartialEq, Eq, clap::ArgEnum, Clone)]
pub enum PrintFormat {
Csv,
Tsv,
@@ -34,40 +33,11 @@ pub enum PrintFormat {
NdJson,
}
-/// returns all print formats
-pub fn all_print_formats() -> Vec<PrintFormat> {
- vec![
- PrintFormat::Csv,
- PrintFormat::Tsv,
- PrintFormat::Table,
- PrintFormat::Json,
- PrintFormat::NdJson,
- ]
-}
-
impl FromStr for PrintFormat {
- type Err = ();
- fn from_str(s: &str) -> std::result::Result<Self, ()> {
- match s.to_lowercase().as_str() {
- "csv" => Ok(Self::Csv),
- "tsv" => Ok(Self::Tsv),
- "table" => Ok(Self::Table),
- "json" => Ok(Self::Json),
- "ndjson" => Ok(Self::NdJson),
- _ => Err(()),
- }
- }
-}
+ type Err = String;
-impl fmt::Display for PrintFormat {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match *self {
- Self::Csv => write!(f, "csv"),
- Self::Tsv => write!(f, "tsv"),
- Self::Table => write!(f, "table"),
- Self::Json => write!(f, "json"),
- Self::NdJson => write!(f, "ndjson"),
- }
+ fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+ clap::ArgEnum::from_str(s, true)
}
}
@@ -147,38 +117,6 @@ mod tests {
use std::sync::Arc;
#[test]
- fn test_from_str() {
- let format = "csv".parse::<PrintFormat>().unwrap();
- assert_eq!(PrintFormat::Csv, format);
-
- let format = "tsv".parse::<PrintFormat>().unwrap();
- assert_eq!(PrintFormat::Tsv, format);
-
- let format = "json".parse::<PrintFormat>().unwrap();
- assert_eq!(PrintFormat::Json, format);
-
- let format = "ndjson".parse::<PrintFormat>().unwrap();
- assert_eq!(PrintFormat::NdJson, format);
-
- let format = "table".parse::<PrintFormat>().unwrap();
- assert_eq!(PrintFormat::Table, format);
- }
-
- #[test]
- fn test_to_str() {
- assert_eq!("csv", PrintFormat::Csv.to_string());
- assert_eq!("table", PrintFormat::Table.to_string());
- assert_eq!("tsv", PrintFormat::Tsv.to_string());
- assert_eq!("json", PrintFormat::Json.to_string());
- assert_eq!("ndjson", PrintFormat::NdJson.to_string());
- }
-
- #[test]
- fn test_from_str_failure() {
- assert!("pretty".parse::<PrintFormat>().is_err());
- }
-
- #[test]
fn test_print_batches_with_sep() {
let batches = vec![];
assert_eq!("", print_batches_with_sep(&batches, b',').unwrap());
diff --git a/ballista/rust/client/Cargo.toml b/datafusion-common/Cargo.toml
similarity index 67%
copy from ballista/rust/client/Cargo.toml
copy to datafusion-common/Cargo.toml
index 4ec1abe..08f228f 100644
--- a/ballista/rust/client/Cargo.toml
+++ b/datafusion-common/Cargo.toml
@@ -16,29 +16,29 @@
# under the License.
[package]
-name = "ballista"
-description = "Ballista Distributed Compute"
-license = "Apache-2.0"
-version = "0.6.0"
+name = "datafusion-common"
+description = "DataFusion is an in-memory query engine that uses Apache Arrow as the memory model"
+version = "6.0.0"
homepage = "https://github.com/apache/arrow-datafusion"
repository = "https://github.com/apache/arrow-datafusion"
+readme = "README.md"
authors = ["Apache Arrow <de...@arrow.apache.org>"]
+license = "Apache-2.0"
+keywords = [ "arrow", "query", "sql" ]
edition = "2021"
rust-version = "1.58"
-[dependencies]
-ballista-core = { path = "../core", version = "0.6.0" }
-ballista-executor = { path = "../executor", version = "0.6.0", optional = true }
-ballista-scheduler = { path = "../scheduler", version = "0.6.0", optional = true }
-futures = "0.3"
-log = "0.4"
-tokio = "1.0"
-tempfile = "3"
-sqlparser = "0.13"
-parking_lot = "0.11"
-
-datafusion = { path = "../../../datafusion", version = "6.0.0" }
+[lib]
+name = "datafusion_common"
+path = "src/lib.rs"
[features]
-default = []
-standalone = ["ballista-executor", "ballista-scheduler"]
+pyarrow = ["pyo3"]
+
+[dependencies]
+arrow = { package = "arrow2", version = "0.9", default-features = false }
+parquet = { package = "parquet2", version = "0.10", default_features = false, features = ["stream"] }
+
+pyo3 = { version = "0.15", optional = true }
+sqlparser = "0.13"
+ordered-float = "2.10"
diff --git a/datafusion-common/README.md b/datafusion-common/README.md
new file mode 100644
index 0000000..8c44d78
--- /dev/null
+++ b/datafusion-common/README.md
@@ -0,0 +1,24 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# DataFusion Common
+
+This is an internal module for the most fundamental types of [DataFusion][df].
+
+[df]: https://crates.io/crates/datafusion
diff --git a/datafusion-common/src/column.rs b/datafusion-common/src/column.rs
new file mode 100644
index 0000000..02faa24
--- /dev/null
+++ b/datafusion-common/src/column.rs
@@ -0,0 +1,150 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Column
+
+use crate::{DFSchema, DataFusionError, Result};
+use std::collections::HashSet;
+use std::convert::Infallible;
+use std::fmt;
+use std::str::FromStr;
+use std::sync::Arc;
+
+/// A named reference to a qualified field in a schema.
+#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
+pub struct Column {
+ /// relation/table name.
+ pub relation: Option<String>,
+ /// field/column name.
+ pub name: String,
+}
+
+impl Column {
+ /// Create Column from unqualified name.
+ pub fn from_name(name: impl Into<String>) -> Self {
+ Self {
+ relation: None,
+ name: name.into(),
+ }
+ }
+
+ /// Deserialize a fully qualified name string into a column
+ pub fn from_qualified_name(flat_name: &str) -> Self {
+ use sqlparser::tokenizer::Token;
+
+ let dialect = sqlparser::dialect::GenericDialect {};
+ let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name);
+ if let Ok(tokens) = tokenizer.tokenize() {
+ if let [Token::Word(relation), Token::Period, Token::Word(name)] =
+ tokens.as_slice()
+ {
+ return Column {
+ relation: Some(relation.value.clone()),
+ name: name.value.clone(),
+ };
+ }
+ }
+ // any expression that's not in the form of `foo.bar` will be treated as unqualified column
+ // name
+ Column {
+ relation: None,
+ name: String::from(flat_name),
+ }
+ }
+
+ /// Serialize column into a flat name string
+ pub fn flat_name(&self) -> String {
+ match &self.relation {
+ Some(r) => format!("{}.{}", r, self.name),
+ None => self.name.clone(),
+ }
+ }
+
+ // Internal implementation of normalize
+ pub fn normalize_with_schemas(
+ self,
+ schemas: &[&Arc<DFSchema>],
+ using_columns: &[HashSet<Column>],
+ ) -> Result<Self> {
+ if self.relation.is_some() {
+ return Ok(self);
+ }
+
+ for schema in schemas {
+ let fields = schema.fields_with_unqualified_name(&self.name);
+ match fields.len() {
+ 0 => continue,
+ 1 => {
+ return Ok(fields[0].qualified_column());
+ }
+ _ => {
+ // More than 1 fields in this schema have their names set to self.name.
+ //
+ // This should only happen when a JOIN query with USING constraint references
+ // join columns using unqualified column name. For example:
+ //
+ // ```sql
+ // SELECT id FROM t1 JOIN t2 USING(id)
+ // ```
+ //
+ // In this case, both `t1.id` and `t2.id` will match unqualified column `id`.
+ // We will use the relation from the first matched field to normalize self.
+
+ // Compare matched fields with one USING JOIN clause at a time
+ for using_col in using_columns {
+ let all_matched = fields
+ .iter()
+ .all(|f| using_col.contains(&f.qualified_column()));
+ // All matched fields belong to the same using column set, in other words
+ // the same join clause. We simply pick the qualifier from the first match.
+ if all_matched {
+ return Ok(fields[0].qualified_column());
+ }
+ }
+ }
+ }
+ }
+
+ Err(DataFusionError::Plan(format!(
+ "Column {} not found in provided schemas",
+ self
+ )))
+ }
+}
+
+impl From<&str> for Column {
+ fn from(c: &str) -> Self {
+ Self::from_qualified_name(c)
+ }
+}
+
+impl FromStr for Column {
+ type Err = Infallible;
+
+ fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
+ Ok(s.into())
+ }
+}
+
+impl fmt::Display for Column {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match &self.relation {
+ Some(r) => write!(f, "#{}.{}", r, self.name),
+ None => write!(f, "#{}", self.name),
+ }
+ }
+}
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion-common/src/dfschema.rs
similarity index 95%
copy from datafusion/src/logical_plan/dfschema.rs
copy to datafusion-common/src/dfschema.rs
index b89b239..55c5c4c 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion-common/src/dfschema.rs
@@ -23,7 +23,7 @@ use std::convert::TryFrom;
use std::sync::Arc;
use crate::error::{DataFusionError, Result};
-use crate::logical_plan::Column;
+use crate::Column;
use crate::field_util::{FieldExt, SchemaExt};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
@@ -399,6 +399,40 @@ impl Display for DFSchema {
}
}
+/// Provides schema information needed by [Expr] methods such as
+/// [Expr::nullable] and [Expr::data_type].
+///
+/// Note that this trait is implemented for &[DFSchema] which is
+/// widely used in the DataFusion codebase.
+pub trait ExprSchema {
+ /// Is this column reference nullable?
+ fn nullable(&self, col: &Column) -> Result<bool>;
+
+ /// What is the datatype of this column?
+ fn data_type(&self, col: &Column) -> Result<&DataType>;
+}
+
+// Implement `ExprSchema` for `Arc<DFSchema>`
+impl<P: AsRef<DFSchema>> ExprSchema for P {
+ fn nullable(&self, col: &Column) -> Result<bool> {
+ self.as_ref().nullable(col)
+ }
+
+ fn data_type(&self, col: &Column) -> Result<&DataType> {
+ self.as_ref().data_type(col)
+ }
+}
+
+impl ExprSchema for DFSchema {
+ fn nullable(&self, col: &Column) -> Result<bool> {
+ Ok(self.field_from_column(col)?.is_nullable())
+ }
+
+ fn data_type(&self, col: &Column) -> Result<&DataType> {
+ Ok(self.field_from_column(col)?.data_type())
+ }
+}
+
/// DFField wraps an Arrow field and adds an optional qualifier
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DFField {
diff --git a/datafusion/src/error.rs b/datafusion-common/src/error.rs
similarity index 95%
copy from datafusion/src/error.rs
copy to datafusion-common/src/error.rs
index fbad9a9..33c4768 100644
--- a/datafusion/src/error.rs
+++ b/datafusion-common/src/error.rs
@@ -34,7 +34,6 @@ pub type GenericError = Box<dyn error::Error + Send + Sync>;
/// DataFusion error
#[derive(Debug)]
-#[allow(missing_docs)]
pub enum DataFusionError {
/// Error returned by arrow.
ArrowError(ArrowError),
@@ -83,8 +82,8 @@ impl From<DataFusionError> for ArrowError {
fn from(e: DataFusionError) -> Self {
match e {
DataFusionError::ArrowError(e) => e,
- DataFusionError::External(e) => ArrowError::External("".to_string(), e),
- other => ArrowError::External("".to_string(), Box::new(other)),
+ DataFusionError::External(e) => ArrowError::External(String::new(), e),
+ other => ArrowError::External(String::new(), Box::new(other)),
}
}
}
@@ -160,10 +159,7 @@ mod test {
#[test]
fn datafusion_error_to_arrow() {
let res = return_datafusion_error().unwrap_err();
- assert_eq!(
- res.to_string(),
- "Arrow error: Invalid argument error: Schema error: bar"
- );
+ assert_eq!(res.to_string(), "Arrow error: Schema error: bar");
}
/// Model what happens when implementing SendableRecrordBatchStream:
@@ -181,7 +177,7 @@ mod test {
fn return_datafusion_error() -> crate::error::Result<()> {
// Expect the '?' to work
let _bar = Err(ArrowError::InvalidArgumentError(
- "Schema error: bar".to_string(),
+ "bad schema bar".to_string(),
))?;
Ok(())
}
diff --git a/datafusion/src/field_util.rs b/datafusion-common/src/field_util.rs
similarity index 100%
copy from datafusion/src/field_util.rs
copy to datafusion-common/src/field_util.rs
diff --git a/datafusion-cli/src/lib.rs b/datafusion-common/src/lib.rs
similarity index 68%
copy from datafusion-cli/src/lib.rs
copy to datafusion-common/src/lib.rs
index b2bcdd3..cb06b46 100644
--- a/datafusion-cli/src/lib.rs
+++ b/datafusion-common/src/lib.rs
@@ -15,14 +15,16 @@
// specific language governing permissions and limitations
// under the License.
-#![doc = include_str!("../README.md")]
-#![allow(unused_imports)]
-pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION");
+mod column;
+mod dfschema;
+mod error;
+pub mod field_util;
+#[cfg(feature = "pyarrow")]
+mod pyarrow;
+pub mod record_batch;
+mod scalar;
-pub mod command;
-pub mod context;
-pub mod exec;
-pub mod functions;
-pub mod helper;
-pub mod print_format;
-pub mod print_options;
+pub use column::Column;
+pub use dfschema::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
+pub use error::{DataFusionError, Result};
+pub use scalar::{ScalarValue, MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};
diff --git a/datafusion-common/src/pyarrow.rs b/datafusion-common/src/pyarrow.rs
new file mode 100644
index 0000000..405e568
--- /dev/null
+++ b/datafusion-common/src/pyarrow.rs
@@ -0,0 +1,247 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! PyArrow
+
+use crate::{DataFusionError, ScalarValue};
+use arrow::array::ArrayData;
+use arrow::pyarrow::PyArrowConvert;
+use pyo3::exceptions::PyException;
+use pyo3::prelude::PyErr;
+use pyo3::types::PyList;
+use pyo3::{FromPyObject, IntoPy, PyAny, PyObject, PyResult, Python};
+
+impl From<DataFusionError> for PyErr {
+ fn from(err: DataFusionError) -> PyErr {
+ PyException::new_err(err.to_string())
+ }
+}
+
+impl PyArrowConvert for ScalarValue {
+ fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ let py = value.py();
+ let typ = value.getattr("type")?;
+ let val = value.call_method0("as_py")?;
+
+ // construct pyarrow array from the python value and pyarrow type
+ let factory = py.import("pyarrow")?.getattr("array")?;
+ let args = PyList::new(py, &[val]);
+ let array = factory.call1((args, typ))?;
+
+ // convert the pyarrow array to rust array using C data interface
+ let array = array.extract::<ArrayData>()?;
+ let scalar = ScalarValue::try_from_array(&array.into(), 0)?;
+
+ Ok(scalar)
+ }
+
+ fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+ let array = self.to_array();
+ // convert to pyarrow array using C data interface
+ let pyarray = array.data_ref().clone().into_py(py);
+ let pyscalar = pyarray.call_method1(py, "__getitem__", (0,))?;
+
+ Ok(pyscalar)
+ }
+}
+
+impl<'source> FromPyObject<'source> for ScalarValue {
+ fn extract(value: &'source PyAny) -> PyResult<Self> {
+ Self::from_pyarrow(value)
+ }
+}
+
+impl<'a> IntoPy<PyObject> for ScalarValue {
+ fn into_py(self, py: Python) -> PyObject {
+ self.to_pyarrow(py).unwrap()
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use pyo3::prepare_freethreaded_python;
+ use pyo3::py_run;
+ use pyo3::types::PyDict;
+
+ fn init_python() {
+ prepare_freethreaded_python();
+ Python::with_gil(|py| {
+ if let Err(err) = py.run("import pyarrow", None, None) {
+ let locals = PyDict::new(py);
+ py.run(
+ "import sys; executable = sys.executable; python_path = sys.path",
+ None,
+ Some(locals),
+ )
+ .expect("Couldn't get python info");
+ let executable: String =
+ locals.get_item("executable").unwrap().extract().unwrap();
+ let python_path: Vec<&str> =
+ locals.get_item("python_path").unwrap().extract().unwrap();
+
+ Err(err).expect(
+ format!(
+ "pyarrow not found\nExecutable: {}\nPython path: {:?}\n\
+ HINT: try `pip install pyarrow`\n\
+ NOTE: On Mac OS, you must compile against a Framework Python \
+ (default in python.org installers and brew, but not pyenv)\n\
+ NOTE: On Mac OS, PYO3 might point to incorrect Python library \
+ path when using virtual environments. Try \
+ `export PYTHONPATH=$(python -c \"import sys; print(sys.path[-1])\")`\n",
+ executable, python_path
+ )
+ .as_ref(),
+ )
+ }
+ })
+ }
+
+ #[test]
+ fn test_roundtrip() {
+ init_python();
+
+ let example_scalars = vec![
+ ScalarValue::Boolean(Some(true)),
+ ScalarValue::Int32(Some(23)),
+ ScalarValue::Float64(Some(12.34)),
+ ScalarValue::Utf8(Some("Hello!".to_string())),
+ ScalarValue::Date32(Some(1234)),
+ ];
+
+ Python::with_gil(|py| {
+ for scalar in example_scalars.iter() {
+ let result =
+ ScalarValue::from_pyarrow(scalar.to_pyarrow(py).unwrap().as_ref(py))
+ .unwrap();
+ assert_eq!(scalar, &result);
+ }
+ });
+ }
+
+ #[test]
+ fn test_py_scalar() {
+ init_python();
+
+ Python::with_gil(|py| {
+ let scalar_float = ScalarValue::Float64(Some(12.34));
+ let py_float = scalar_float.into_py(py).call_method0(py, "as_py").unwrap();
+ py_run!(py, py_float, "assert py_float == 12.34");
+
+ let scalar_string = ScalarValue::Utf8(Some("Hello!".to_string()));
+ let py_string = scalar_string.into_py(py).call_method0(py, "as_py").unwrap();
+ py_run!(py, py_string, "assert py_string == 'Hello!'");
+ });
+ }
+}
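A sketch of the round trip this module enables, assuming the crate is built with the `pyarrow` feature and run against a Python that has pyarrow installed:

    use arrow::pyarrow::PyArrowConvert;
    use datafusion_common::ScalarValue;
    use pyo3::Python;

    fn main() -> pyo3::PyResult<()> {
        pyo3::prepare_freethreaded_python();
        Python::with_gil(|py| {
            // Rust -> Python: materialize a pyarrow scalar
            let v = ScalarValue::Int32(Some(23));
            let py_scalar = v.to_pyarrow(py)?;
            // Python -> Rust: back through the C data interface
            let back = ScalarValue::from_pyarrow(py_scalar.as_ref(py))?;
            assert_eq!(v, back);
            Ok(())
        })
    }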
diff --git a/datafusion/src/record_batch.rs b/datafusion-common/src/record_batch.rs
similarity index 94%
copy from datafusion/src/record_batch.rs
copy to datafusion-common/src/record_batch.rs
index 8fba09e..4d45687 100644
--- a/datafusion/src/record_batch.rs
+++ b/datafusion-common/src/record_batch.rs
@@ -1,3 +1,20 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
//! Contains [`RecordBatch`].
use std::sync::Arc;
diff --git a/datafusion/src/scalar.rs b/datafusion-common/src/scalar.rs
similarity index 99%
copy from datafusion/src/scalar.rs
copy to datafusion-common/src/scalar.rs
index 847a9dd..c87ea73 100644
--- a/datafusion/src/scalar.rs
+++ b/datafusion-common/src/scalar.rs
@@ -45,8 +45,8 @@ type MutableLargeStringArray = MutableUtf8Array<i64>;
// TODO may need to be moved to arrow-rs
/// The max precision and scale for decimal128
-pub(crate) const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
-pub(crate) const MAX_SCALE_FOR_DECIMAL128: usize = 38;
+pub const MAX_PRECISION_FOR_DECIMAL128: usize = 38;
+pub const MAX_SCALE_FOR_DECIMAL128: usize = 38;
/// Represents a dynamically typed, nullable single value.
/// This is the single-valued counter-part of arrow’s `Array`.
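A quick sanity check of the now-public decimal bounds (a sketch, not part of the commit): 38 digits is the most that always fits in an i128, since 10^38 - 1 < 2^127.

    use datafusion_common::{MAX_PRECISION_FOR_DECIMAL128, MAX_SCALE_FOR_DECIMAL128};

    fn main() {
        // 38 nines still fit in i128; a 39-digit value may not
        assert_eq!(MAX_PRECISION_FOR_DECIMAL128, 38);
        assert!(MAX_SCALE_FOR_DECIMAL128 <= MAX_PRECISION_FOR_DECIMAL128);
    }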
diff --git a/datafusion/src/scalar.rs b/datafusion-common/src/scalar_tmp.rs
similarity index 100%
copy from datafusion/src/scalar.rs
copy to datafusion-common/src/scalar_tmp.rs
diff --git a/datafusion-cli/Cargo.toml b/datafusion-expr/Cargo.toml
similarity index 68%
copy from datafusion-cli/Cargo.toml
copy to datafusion-expr/Cargo.toml
index 09df15b..7da7c26 100644
--- a/datafusion-cli/Cargo.toml
+++ b/datafusion-expr/Cargo.toml
@@ -16,20 +16,26 @@
# under the License.
[package]
-name = "datafusion-cli"
-version = "5.1.0"
-authors = ["Apache Arrow <de...@arrow.apache.org>"]
-edition = "2021"
-keywords = [ "arrow", "datafusion", "ballista", "query", "sql" ]
-license = "Apache-2.0"
+name = "datafusion-expr"
+description = "Expression types for DataFusion, an in-memory query engine that uses Apache Arrow as the memory model"
+version = "6.0.0"
homepage = "https://github.com/apache/arrow-datafusion"
repository = "https://github.com/apache/arrow-datafusion"
+readme = "../README.md"
+authors = ["Apache Arrow <de...@arrow.apache.org>"]
+license = "Apache-2.0"
+keywords = [ "arrow", "query", "sql" ]
+edition = "2021"
rust-version = "1.58"
+[lib]
+name = "datafusion_expr"
+path = "src/lib.rs"
+
+[features]
+
[dependencies]
-clap = { version = "3", features = ["derive", "cargo"] }
-rustyline = "9.0"
-tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync", "parking_lot"] }
-datafusion = { path = "../datafusion", version = "6.0.0" }
-arrow = { package = "arrow2", version="0.9", features = ["io_print"] }
-ballista = { path = "../ballista/rust/client", version = "0.6.0" }
+datafusion-common = { path = "../datafusion-common", version = "6.0.0" }
+arrow = { package = "arrow2", version = "0.9", default-features = false }
+sqlparser = "0.13"
+ahash = { version = "0.7", default-features = false }
diff --git a/datafusion-expr/README.md b/datafusion-expr/README.md
new file mode 100644
index 0000000..25ac79c
--- /dev/null
+++ b/datafusion-expr/README.md
@@ -0,0 +1,24 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+# DataFusion Expr
+
+This is an internal crate providing the fundamental expression types of [DataFusion][df].
+
+[df]: https://crates.io/crates/datafusion
diff --git a/datafusion-expr/src/accumulator.rs b/datafusion-expr/src/accumulator.rs
new file mode 100644
index 0000000..599bd36
--- /dev/null
+++ b/datafusion-expr/src/accumulator.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use datafusion_common::{Result, ScalarValue};
+use std::fmt::Debug;
+
+/// An accumulator represents a stateful object that lives throughout the evaluation of multiple rows and
+/// generically accumulates values.
+///
+/// An accumulator knows how to:
+/// * update its state from inputs via `update_batch`
+/// * convert its internal state to a vector of scalar values
+/// * update its state from multiple accumulators' states via `merge_batch`
+/// * compute the final value from its internal state via `evaluate`
+pub trait Accumulator: Send + Sync + Debug {
+ /// Returns the state of the accumulator at the end of the accumulation.
+ // For example, an average that tracks `sum` and `n` returns a vector
+ // of two values: sum and n.
+ fn state(&self) -> Result<Vec<ScalarValue>>;
+
+ /// Updates the accumulator's state from a vector of input arrays.
+ fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()>;
+
+ /// Updates the accumulator's state from a vector of partial states.
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()>;
+
+ /// Returns the final aggregate value based on the current state.
+ fn evaluate(&self) -> Result<ScalarValue>;
+}
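To make the trait's contract concrete, here is a hypothetical COUNT accumulator against it (a sketch, not part of the commit): `update_batch` counts non-null inputs, and `merge_batch` folds in the partial counts other accumulators emitted via `state()`.

    use arrow::array::ArrayRef;
    use datafusion_common::{Result, ScalarValue};
    use datafusion_expr::Accumulator;

    #[derive(Debug)]
    struct CountAccumulator {
        n: u64,
    }

    impl Accumulator for CountAccumulator {
        fn state(&self) -> Result<Vec<ScalarValue>> {
            Ok(vec![ScalarValue::UInt64(Some(self.n))])
        }

        fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
            // count the non-null values of the first argument
            self.n += (values[0].len() - values[0].null_count()) as u64;
            Ok(())
        }

        fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
            // each row of the state column is another accumulator's count
            for i in 0..states[0].len() {
                if let ScalarValue::UInt64(Some(m)) =
                    ScalarValue::try_from_array(&states[0], i)?
                {
                    self.n += m;
                }
            }
            Ok(())
        }

        fn evaluate(&self) -> Result<ScalarValue> {
            Ok(ScalarValue::UInt64(Some(self.n)))
        }
    }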
diff --git a/datafusion-expr/src/aggregate_function.rs b/datafusion-expr/src/aggregate_function.rs
new file mode 100644
index 0000000..8f12e88
--- /dev/null
+++ b/datafusion-expr/src/aggregate_function.rs
@@ -0,0 +1,93 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion_common::{DataFusionError, Result};
+use std::{fmt, str::FromStr};
+
+/// Enum of all built-in aggregate functions
+#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Hash)]
+pub enum AggregateFunction {
+ /// count
+ Count,
+ /// sum
+ Sum,
+ /// min
+ Min,
+ /// max
+ Max,
+ /// avg
+ Avg,
+ /// Approximate distinct count (approx_distinct)
+ ApproxDistinct,
+ /// array_agg
+ ArrayAgg,
+ /// Variance (Sample)
+ Variance,
+ /// Variance (Population)
+ VariancePop,
+ /// Standard Deviation (Sample)
+ Stddev,
+ /// Standard Deviation (Population)
+ StddevPop,
+ /// Covariance (Sample)
+ Covariance,
+ /// Covariance (Population)
+ CovariancePop,
+ /// Correlation
+ Correlation,
+ /// Approximate continuous percentile function
+ ApproxPercentileCont,
+}
+
+impl fmt::Display for AggregateFunction {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ // uppercase of the debug.
+ write!(f, "{}", format!("{:?}", self).to_uppercase())
+ }
+}
+
+impl FromStr for AggregateFunction {
+ type Err = DataFusionError;
+ fn from_str(name: &str) -> Result<AggregateFunction> {
+ Ok(match name {
+ "min" => AggregateFunction::Min,
+ "max" => AggregateFunction::Max,
+ "count" => AggregateFunction::Count,
+ "avg" => AggregateFunction::Avg,
+ "sum" => AggregateFunction::Sum,
+ "approx_distinct" => AggregateFunction::ApproxDistinct,
+ "array_agg" => AggregateFunction::ArrayAgg,
+ "var" => AggregateFunction::Variance,
+ "var_samp" => AggregateFunction::Variance,
+ "var_pop" => AggregateFunction::VariancePop,
+ "stddev" => AggregateFunction::Stddev,
+ "stddev_samp" => AggregateFunction::Stddev,
+ "stddev_pop" => AggregateFunction::StddevPop,
+ "covar" => AggregateFunction::Covariance,
+ "covar_samp" => AggregateFunction::Covariance,
+ "covar_pop" => AggregateFunction::CovariancePop,
+ "corr" => AggregateFunction::Correlation,
+ "approx_percentile_cont" => AggregateFunction::ApproxPercentileCont,
+ _ => {
+ return Err(DataFusionError::Plan(format!(
+ "There is no built-in function named {}",
+ name
+ )));
+ }
+ })
+ }
+}
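The `FromStr` impl gives the SQL frontend a single place to resolve aliases; for example (a sketch):

    use datafusion_expr::AggregateFunction;

    fn main() {
        // the SQL-standard name and the short alias resolve identically
        assert_eq!(
            "var_samp".parse::<AggregateFunction>().unwrap(),
            AggregateFunction::Variance
        );
        // unknown names surface as DataFusionError::Plan
        assert!("no_such_fn".parse::<AggregateFunction>().is_err());
    }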
diff --git a/datafusion-expr/src/built_in_function.rs b/datafusion-expr/src/built_in_function.rs
new file mode 100644
index 0000000..0d5ee97
--- /dev/null
+++ b/datafusion-expr/src/built_in_function.rs
@@ -0,0 +1,330 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Built-in functions
+
+use crate::Volatility;
+use datafusion_common::{DataFusionError, Result};
+use std::fmt;
+use std::str::FromStr;
+
+/// Enum of all built-in scalar functions
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum BuiltinScalarFunction {
+ // math functions
+ /// abs
+ Abs,
+ /// acos
+ Acos,
+ /// asin
+ Asin,
+ /// atan
+ Atan,
+ /// ceil
+ Ceil,
+ /// cos
+ Cos,
+ /// Digest
+ Digest,
+ /// exp
+ Exp,
+ /// floor
+ Floor,
+ /// ln, Natural logarithm
+ Ln,
+ /// log, same as log10
+ Log,
+ /// log10
+ Log10,
+ /// log2
+ Log2,
+ /// round
+ Round,
+ /// signum
+ Signum,
+ /// sin
+ Sin,
+ /// sqrt
+ Sqrt,
+ /// tan
+ Tan,
+ /// trunc
+ Trunc,
+
+ // string functions
+ /// construct an array from columns
+ Array,
+ /// ascii
+ Ascii,
+ /// bit_length
+ BitLength,
+ /// btrim
+ Btrim,
+ /// character_length
+ CharacterLength,
+ /// chr
+ Chr,
+ /// concat
+ Concat,
+ /// concat_ws
+ ConcatWithSeparator,
+ /// date_part
+ DatePart,
+ /// date_trunc
+ DateTrunc,
+ /// initcap
+ InitCap,
+ /// left
+ Left,
+ /// lpad
+ Lpad,
+ /// lower
+ Lower,
+ /// ltrim
+ Ltrim,
+ /// md5
+ MD5,
+ /// nullif
+ NullIf,
+ /// octet_length
+ OctetLength,
+ /// random
+ Random,
+ /// regexp_replace
+ RegexpReplace,
+ /// repeat
+ Repeat,
+ /// replace
+ Replace,
+ /// reverse
+ Reverse,
+ /// right
+ Right,
+ /// rpad
+ Rpad,
+ /// rtrim
+ Rtrim,
+ /// sha224
+ SHA224,
+ /// sha256
+ SHA256,
+ /// sha384
+ SHA384,
+ /// Sha512
+ SHA512,
+ /// split_part
+ SplitPart,
+ /// starts_with
+ StartsWith,
+ /// strpos
+ Strpos,
+ /// substr
+ Substr,
+ /// to_hex
+ ToHex,
+ /// to_timestamp
+ ToTimestamp,
+ /// to_timestamp_millis
+ ToTimestampMillis,
+ /// to_timestamp_micros
+ ToTimestampMicros,
+ /// to_timestamp_seconds
+ ToTimestampSeconds,
+ /// now
+ Now,
+ /// translate
+ Translate,
+ /// trim
+ Trim,
+ /// upper
+ Upper,
+ /// regexp_match
+ RegexpMatch,
+}
+
+impl BuiltinScalarFunction {
+ /// An allowlist of functions that take zero arguments, so that they
+ /// receive special treatment during execution.
+ pub fn supports_zero_argument(&self) -> bool {
+ matches!(
+ self,
+ BuiltinScalarFunction::Random | BuiltinScalarFunction::Now
+ )
+ }
+ /// Returns the [Volatility] of the builtin function.
+ pub fn volatility(&self) -> Volatility {
+ match self {
+ //Immutable scalar builtins
+ BuiltinScalarFunction::Abs => Volatility::Immutable,
+ BuiltinScalarFunction::Acos => Volatility::Immutable,
+ BuiltinScalarFunction::Asin => Volatility::Immutable,
+ BuiltinScalarFunction::Atan => Volatility::Immutable,
+ BuiltinScalarFunction::Ceil => Volatility::Immutable,
+ BuiltinScalarFunction::Cos => Volatility::Immutable,
+ BuiltinScalarFunction::Exp => Volatility::Immutable,
+ BuiltinScalarFunction::Floor => Volatility::Immutable,
+ BuiltinScalarFunction::Ln => Volatility::Immutable,
+ BuiltinScalarFunction::Log => Volatility::Immutable,
+ BuiltinScalarFunction::Log10 => Volatility::Immutable,
+ BuiltinScalarFunction::Log2 => Volatility::Immutable,
+ BuiltinScalarFunction::Round => Volatility::Immutable,
+ BuiltinScalarFunction::Signum => Volatility::Immutable,
+ BuiltinScalarFunction::Sin => Volatility::Immutable,
+ BuiltinScalarFunction::Sqrt => Volatility::Immutable,
+ BuiltinScalarFunction::Tan => Volatility::Immutable,
+ BuiltinScalarFunction::Trunc => Volatility::Immutable,
+ BuiltinScalarFunction::Array => Volatility::Immutable,
+ BuiltinScalarFunction::Ascii => Volatility::Immutable,
+ BuiltinScalarFunction::BitLength => Volatility::Immutable,
+ BuiltinScalarFunction::Btrim => Volatility::Immutable,
+ BuiltinScalarFunction::CharacterLength => Volatility::Immutable,
+ BuiltinScalarFunction::Chr => Volatility::Immutable,
+ BuiltinScalarFunction::Concat => Volatility::Immutable,
+ BuiltinScalarFunction::ConcatWithSeparator => Volatility::Immutable,
+ BuiltinScalarFunction::DatePart => Volatility::Immutable,
+ BuiltinScalarFunction::DateTrunc => Volatility::Immutable,
+ BuiltinScalarFunction::InitCap => Volatility::Immutable,
+ BuiltinScalarFunction::Left => Volatility::Immutable,
+ BuiltinScalarFunction::Lpad => Volatility::Immutable,
+ BuiltinScalarFunction::Lower => Volatility::Immutable,
+ BuiltinScalarFunction::Ltrim => Volatility::Immutable,
+ BuiltinScalarFunction::MD5 => Volatility::Immutable,
+ BuiltinScalarFunction::NullIf => Volatility::Immutable,
+ BuiltinScalarFunction::OctetLength => Volatility::Immutable,
+ BuiltinScalarFunction::RegexpReplace => Volatility::Immutable,
+ BuiltinScalarFunction::Repeat => Volatility::Immutable,
+ BuiltinScalarFunction::Replace => Volatility::Immutable,
+ BuiltinScalarFunction::Reverse => Volatility::Immutable,
+ BuiltinScalarFunction::Right => Volatility::Immutable,
+ BuiltinScalarFunction::Rpad => Volatility::Immutable,
+ BuiltinScalarFunction::Rtrim => Volatility::Immutable,
+ BuiltinScalarFunction::SHA224 => Volatility::Immutable,
+ BuiltinScalarFunction::SHA256 => Volatility::Immutable,
+ BuiltinScalarFunction::SHA384 => Volatility::Immutable,
+ BuiltinScalarFunction::SHA512 => Volatility::Immutable,
+ BuiltinScalarFunction::Digest => Volatility::Immutable,
+ BuiltinScalarFunction::SplitPart => Volatility::Immutable,
+ BuiltinScalarFunction::StartsWith => Volatility::Immutable,
+ BuiltinScalarFunction::Strpos => Volatility::Immutable,
+ BuiltinScalarFunction::Substr => Volatility::Immutable,
+ BuiltinScalarFunction::ToHex => Volatility::Immutable,
+ BuiltinScalarFunction::ToTimestamp => Volatility::Immutable,
+ BuiltinScalarFunction::ToTimestampMillis => Volatility::Immutable,
+ BuiltinScalarFunction::ToTimestampMicros => Volatility::Immutable,
+ BuiltinScalarFunction::ToTimestampSeconds => Volatility::Immutable,
+ BuiltinScalarFunction::Translate => Volatility::Immutable,
+ BuiltinScalarFunction::Trim => Volatility::Immutable,
+ BuiltinScalarFunction::Upper => Volatility::Immutable,
+ BuiltinScalarFunction::RegexpMatch => Volatility::Immutable,
+
+ //Stable builtin functions
+ BuiltinScalarFunction::Now => Volatility::Stable,
+
+ //Volatile builtin functions
+ BuiltinScalarFunction::Random => Volatility::Volatile,
+ }
+ }
+}
+
+impl fmt::Display for BuiltinScalarFunction {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ // lowercase of the debug.
+ write!(f, "{}", format!("{:?}", self).to_lowercase())
+ }
+}
+
+impl FromStr for BuiltinScalarFunction {
+ type Err = DataFusionError;
+ fn from_str(name: &str) -> Result<BuiltinScalarFunction> {
+ Ok(match name {
+ // math functions
+ "abs" => BuiltinScalarFunction::Abs,
+ "acos" => BuiltinScalarFunction::Acos,
+ "asin" => BuiltinScalarFunction::Asin,
+ "atan" => BuiltinScalarFunction::Atan,
+ "ceil" => BuiltinScalarFunction::Ceil,
+ "cos" => BuiltinScalarFunction::Cos,
+ "exp" => BuiltinScalarFunction::Exp,
+ "floor" => BuiltinScalarFunction::Floor,
+ "ln" => BuiltinScalarFunction::Ln,
+ "log" => BuiltinScalarFunction::Log,
+ "log10" => BuiltinScalarFunction::Log10,
+ "log2" => BuiltinScalarFunction::Log2,
+ "round" => BuiltinScalarFunction::Round,
+ "signum" => BuiltinScalarFunction::Signum,
+ "sin" => BuiltinScalarFunction::Sin,
+ "sqrt" => BuiltinScalarFunction::Sqrt,
+ "tan" => BuiltinScalarFunction::Tan,
+ "trunc" => BuiltinScalarFunction::Trunc,
+
+ // string functions
+ "array" => BuiltinScalarFunction::Array,
+ "ascii" => BuiltinScalarFunction::Ascii,
+ "bit_length" => BuiltinScalarFunction::BitLength,
+ "btrim" => BuiltinScalarFunction::Btrim,
+ "char_length" => BuiltinScalarFunction::CharacterLength,
+ "character_length" => BuiltinScalarFunction::CharacterLength,
+ "concat" => BuiltinScalarFunction::Concat,
+ "concat_ws" => BuiltinScalarFunction::ConcatWithSeparator,
+ "chr" => BuiltinScalarFunction::Chr,
+ "date_part" | "datepart" => BuiltinScalarFunction::DatePart,
+ "date_trunc" | "datetrunc" => BuiltinScalarFunction::DateTrunc,
+ "initcap" => BuiltinScalarFunction::InitCap,
+ "left" => BuiltinScalarFunction::Left,
+ "length" => BuiltinScalarFunction::CharacterLength,
+ "lower" => BuiltinScalarFunction::Lower,
+ "lpad" => BuiltinScalarFunction::Lpad,
+ "ltrim" => BuiltinScalarFunction::Ltrim,
+ "md5" => BuiltinScalarFunction::MD5,
+ "nullif" => BuiltinScalarFunction::NullIf,
+ "octet_length" => BuiltinScalarFunction::OctetLength,
+ "random" => BuiltinScalarFunction::Random,
+ "regexp_replace" => BuiltinScalarFunction::RegexpReplace,
+ "repeat" => BuiltinScalarFunction::Repeat,
+ "replace" => BuiltinScalarFunction::Replace,
+ "reverse" => BuiltinScalarFunction::Reverse,
+ "right" => BuiltinScalarFunction::Right,
+ "rpad" => BuiltinScalarFunction::Rpad,
+ "rtrim" => BuiltinScalarFunction::Rtrim,
+ "sha224" => BuiltinScalarFunction::SHA224,
+ "sha256" => BuiltinScalarFunction::SHA256,
+ "sha384" => BuiltinScalarFunction::SHA384,
+ "sha512" => BuiltinScalarFunction::SHA512,
+ "digest" => BuiltinScalarFunction::Digest,
+ "split_part" => BuiltinScalarFunction::SplitPart,
+ "starts_with" => BuiltinScalarFunction::StartsWith,
+ "strpos" => BuiltinScalarFunction::Strpos,
+ "substr" => BuiltinScalarFunction::Substr,
+ "to_hex" => BuiltinScalarFunction::ToHex,
+ "to_timestamp" => BuiltinScalarFunction::ToTimestamp,
+ "to_timestamp_millis" => BuiltinScalarFunction::ToTimestampMillis,
+ "to_timestamp_micros" => BuiltinScalarFunction::ToTimestampMicros,
+ "to_timestamp_seconds" => BuiltinScalarFunction::ToTimestampSeconds,
+ "now" => BuiltinScalarFunction::Now,
+ "translate" => BuiltinScalarFunction::Translate,
+ "trim" => BuiltinScalarFunction::Trim,
+ "upper" => BuiltinScalarFunction::Upper,
+ "regexp_match" => BuiltinScalarFunction::RegexpMatch,
+ _ => {
+ return Err(DataFusionError::Plan(format!(
+ "There is no built-in function named {}",
+ name
+ )))
+ }
+ })
+ }
+}
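Volatility is what lets a planner decide when a call may be constant-folded: Immutable calls always, Stable calls only within a single query execution, Volatile calls never. A sketch of querying it:

    use std::str::FromStr;
    use datafusion_expr::{BuiltinScalarFunction, Volatility};

    fn main() {
        let now = BuiltinScalarFunction::from_str("now").unwrap();
        // `now()` is fixed within one query but not across queries
        assert!(matches!(now.volatility(), Volatility::Stable));
        // `random()` is one of the two zero-argument builtins
        assert!(BuiltinScalarFunction::from_str("random")
            .unwrap()
            .supports_zero_argument());
    }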
diff --git a/datafusion-expr/src/columnar_value.rs b/datafusion-expr/src/columnar_value.rs
new file mode 100644
index 0000000..fb00f0c
--- /dev/null
+++ b/datafusion-expr/src/columnar_value.rs
@@ -0,0 +1,63 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::array::ArrayRef;
+use arrow::array::NullArray;
+use arrow::datatypes::DataType;
+use datafusion_common::record_batch::RecordBatch;
+use datafusion_common::ScalarValue;
+use std::sync::Arc;
+
+/// Represents the result from an expression
+#[derive(Clone)]
+pub enum ColumnarValue {
+ /// Array of values
+ Array(ArrayRef),
+ /// A single value
+ Scalar(ScalarValue),
+}
+
+impl ColumnarValue {
+ pub fn data_type(&self) -> DataType {
+ match self {
+ ColumnarValue::Array(array_value) => array_value.data_type().clone(),
+ ColumnarValue::Scalar(scalar_value) => scalar_value.get_datatype(),
+ }
+ }
+
+ /// Convert a columnar value into an ArrayRef
+ pub fn into_array(self, num_rows: usize) -> ArrayRef {
+ match self {
+ ColumnarValue::Array(array) => array,
+ ColumnarValue::Scalar(scalar) => scalar.to_array_of_size(num_rows),
+ }
+ }
+}
+
+/// Null columnar values are implemented as a null array, so that the
+/// batch's `num_rows` can still be carried through.
+pub type NullColumnarValue = ColumnarValue;
+
+impl From<&RecordBatch> for NullColumnarValue {
+ fn from(batch: &RecordBatch) -> Self {
+ let num_rows = batch.num_rows();
+ ColumnarValue::Array(Arc::new(NullArray::new_null(
+ DataType::Struct(batch.schema().fields.to_vec()),
+ num_rows,
+ )))
+ }
+}
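`into_array` is what lets downstream kernels treat scalar and array results uniformly: a scalar is broadcast to the requested row count. A minimal sketch:

    use datafusion_common::ScalarValue;
    use datafusion_expr::ColumnarValue;

    fn main() {
        let value = ColumnarValue::Scalar(ScalarValue::Int32(Some(1)));
        // broadcast the scalar into an 8-row array
        let array = value.into_array(8);
        assert_eq!(array.len(), 8);
    }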
diff --git a/datafusion-expr/src/expr.rs b/datafusion-expr/src/expr.rs
new file mode 100644
index 0000000..f26f1df
--- /dev/null
+++ b/datafusion-expr/src/expr.rs
@@ -0,0 +1,697 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aggregate_function;
+use crate::built_in_function;
+use crate::expr_fn::binary_expr;
+use crate::window_frame;
+use crate::window_function;
+use crate::AggregateUDF;
+use crate::Operator;
+use crate::ScalarUDF;
+use arrow::datatypes::DataType;
+use datafusion_common::Column;
+use datafusion_common::{DFSchema, Result};
+use datafusion_common::{DataFusionError, ScalarValue};
+use std::fmt;
+use std::hash::{BuildHasher, Hash, Hasher};
+use std::ops::Not;
+use std::sync::Arc;
+
+/// `Expr` is the central type of DataFusion's query API, and
+/// represents logical expressions such as `A + 1`, or `CAST(c1 AS
+/// int)`.
+///
+/// An `Expr` can compute its [DataType](arrow::datatypes::DataType)
+/// and nullability, and has functions for building up complex
+/// expressions.
+///
+/// # Examples
+///
+/// ## Create an expression `c1` referring to column named "c1"
+/// ```
+/// # use datafusion_common::Column;
+/// # use datafusion_expr::{lit, col, Expr};
+/// let expr = col("c1");
+/// assert_eq!(expr, Expr::Column(Column::from_name("c1")));
+/// ```
+///
+/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together
+/// ```
+/// # use datafusion_expr::{lit, col, Operator, Expr};
+/// let expr = col("c1") + col("c2");
+///
+/// assert!(matches!(expr, Expr::BinaryExpr { ..} ));
+/// if let Expr::BinaryExpr { left, right, op } = expr {
+/// assert_eq!(*left, col("c1"));
+/// assert_eq!(*right, col("c2"));
+/// assert_eq!(op, Operator::Plus);
+/// }
+/// ```
+///
+/// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42`
+/// ```
+/// # use datafusion_common::ScalarValue;
+/// # use datafusion_expr::{lit, col, Operator, Expr};
+/// let expr = col("c1").eq(lit(42_i32));
+///
+/// assert!(matches!(expr, Expr::BinaryExpr { .. } ));
+/// if let Expr::BinaryExpr { left, right, op } = expr {
+/// assert_eq!(*left, col("c1"));
+/// let scalar = ScalarValue::Int32(Some(42));
+/// assert_eq!(*right, Expr::Literal(scalar));
+/// assert_eq!(op, Operator::Eq);
+/// }
+/// ```
+#[derive(Clone, PartialEq, Hash)]
+pub enum Expr {
+ /// An expression with a specific name.
+ Alias(Box<Expr>, String),
+ /// A named reference to a qualified field in a schema.
+ Column(Column),
+ /// A named reference to a variable in a registry.
+ ScalarVariable(Vec<String>),
+ /// A constant value.
+ Literal(ScalarValue),
+ /// A binary expression such as "age > 21"
+ BinaryExpr {
+ /// Left-hand side of the expression
+ left: Box<Expr>,
+ /// The comparison operator
+ op: Operator,
+ /// Right-hand side of the expression
+ right: Box<Expr>,
+ },
+ /// Negation of an expression. The expression's type must be a boolean to make sense.
+ Not(Box<Expr>),
+ /// Whether an expression is not Null. This expression is never null.
+ IsNotNull(Box<Expr>),
+ /// Whether an expression is Null. This expression is never null.
+ IsNull(Box<Expr>),
+ /// arithmetic negation of an expression, the operand must be of a signed numeric data type
+ Negative(Box<Expr>),
+ /// Returns the field of a [`ListArray`] or [`StructArray`] by key
+ GetIndexedField {
+ /// the expression to take the field from
+ expr: Box<Expr>,
+ /// The name of the field to take
+ key: ScalarValue,
+ },
+ /// Whether an expression is between a given range.
+ Between {
+ /// The value to compare
+ expr: Box<Expr>,
+ /// Whether the expression is negated
+ negated: bool,
+ /// The low end of the range
+ low: Box<Expr>,
+ /// The high end of the range
+ high: Box<Expr>,
+ },
+ /// The CASE expression is similar to a series of nested if/else and there are two forms that
+ /// can be used. The first form consists of a series of boolean "when" expressions with
+ /// corresponding "then" expressions, and an optional "else" expression.
+ ///
+ /// CASE WHEN condition THEN result
+ /// [WHEN ...]
+ /// [ELSE result]
+ /// END
+ ///
+ /// The second form uses a base expression and then a series of "when" clauses that match on a
+ /// literal value.
+ ///
+ /// CASE expression
+ /// WHEN value THEN result
+ /// [WHEN ...]
+ /// [ELSE result]
+ /// END
+ Case {
+ /// Optional base expression that can be compared to literal values in the "when" expressions
+ expr: Option<Box<Expr>>,
+ /// One or more when/then expressions
+ when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
+ /// Optional "else" expression
+ else_expr: Option<Box<Expr>>,
+ },
+ /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
+ /// This expression is guaranteed to have a fixed type.
+ Cast {
+ /// The expression being cast
+ expr: Box<Expr>,
+ /// The `DataType` the expression will yield
+ data_type: DataType,
+ },
+ /// Casts the expression to a given type and will return a null value if the expression cannot be cast.
+ /// This expression is guaranteed to have a fixed type.
+ TryCast {
+ /// The expression being cast
+ expr: Box<Expr>,
+ /// The `DataType` the expression will yield
+ data_type: DataType,
+ },
+ /// A sort expression, that can be used to sort values.
+ Sort {
+ /// The expression to sort on
+ expr: Box<Expr>,
+ /// The direction of the sort
+ asc: bool,
+ /// Whether to put Nulls before all other data values
+ nulls_first: bool,
+ },
+ /// Represents the call of a built-in scalar function with a set of arguments.
+ ScalarFunction {
+ /// The function
+ fun: built_in_function::BuiltinScalarFunction,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ },
+ /// Represents the call of a user-defined scalar function with arguments.
+ ScalarUDF {
+ /// The function
+ fun: Arc<ScalarUDF>,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ },
+ /// Represents the call of an aggregate built-in function with arguments.
+ AggregateFunction {
+ /// Name of the function
+ fun: aggregate_function::AggregateFunction,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ /// Whether this is a DISTINCT aggregation or not
+ distinct: bool,
+ },
+ /// Represents the call of a window function with arguments.
+ WindowFunction {
+ /// Name of the function
+ fun: window_function::WindowFunction,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ /// List of partition by expressions
+ partition_by: Vec<Expr>,
+ /// List of order by expressions
+ order_by: Vec<Expr>,
+ /// Window frame
+ window_frame: Option<window_frame::WindowFrame>,
+ },
+ /// Represents the call of a user-defined aggregate function with arguments.
+ AggregateUDF {
+ /// The function
+ fun: Arc<AggregateUDF>,
+ /// List of expressions to feed to the functions as arguments
+ args: Vec<Expr>,
+ },
+ /// Returns whether the list contains the expr value.
+ InList {
+ /// The expression to compare
+ expr: Box<Expr>,
+ /// A list of values to compare against
+ list: Vec<Expr>,
+ /// Whether the expression is negated
+ negated: bool,
+ },
+ /// Represents a reference to all fields in a schema.
+ Wildcard,
+}
+
+/// Fixed seed for the hashing, so that the hash-based ordering is consistent across runs
+const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
+
+impl PartialOrd for Expr {
+ fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
+ let mut hasher = SEED.build_hasher();
+ self.hash(&mut hasher);
+ let s = hasher.finish();
+
+ let mut hasher = SEED.build_hasher();
+ other.hash(&mut hasher);
+ let o = hasher.finish();
+
+ Some(s.cmp(&o))
+ }
+}
+
+impl Expr {
+ /// Returns the name of this expression based on [DFSchema].
+ ///
+ /// This represents how a column with this expression is named when no alias is chosen
+ pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
+ create_name(self, input_schema)
+ }
+
+ /// Return `self == other`
+ pub fn eq(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Eq, other)
+ }
+
+ /// Return `self != other`
+ pub fn not_eq(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::NotEq, other)
+ }
+
+ /// Return `self > other`
+ pub fn gt(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Gt, other)
+ }
+
+ /// Return `self >= other`
+ pub fn gt_eq(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::GtEq, other)
+ }
+
+ /// Return `self < other`
+ pub fn lt(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Lt, other)
+ }
+
+ /// Return `self <= other`
+ pub fn lt_eq(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::LtEq, other)
+ }
+
+ /// Return `self && other`
+ pub fn and(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::And, other)
+ }
+
+ /// Return `self || other`
+ pub fn or(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Or, other)
+ }
+
+ /// Return `!self`
+ #[allow(clippy::should_implement_trait)]
+ pub fn not(self) -> Expr {
+ !self
+ }
+
+ /// Calculate the modulus of two expressions.
+ /// Return `self % other`
+ pub fn modulus(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Modulo, other)
+ }
+
+ /// Return `self LIKE other`
+ pub fn like(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::Like, other)
+ }
+
+ /// Return `self NOT LIKE other`
+ pub fn not_like(self, other: Expr) -> Expr {
+ binary_expr(self, Operator::NotLike, other)
+ }
+
+ /// Return `self AS name` alias expression
+ pub fn alias(self, name: &str) -> Expr {
+ Expr::Alias(Box::new(self), name.to_owned())
+ }
+
+ /// Return `self IN <list>` if `negated` is false, otherwise
+ /// return `self NOT IN <list>`.
+ pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr {
+ Expr::InList {
+ expr: Box::new(self),
+ list,
+ negated,
+ }
+ }
+
+ /// Return `IsNull(Box(self))`
+ #[allow(clippy::wrong_self_convention)]
+ pub fn is_null(self) -> Expr {
+ Expr::IsNull(Box::new(self))
+ }
+
+ /// Return `IsNotNull(Box(self))`
+ #[allow(clippy::wrong_self_convention)]
+ pub fn is_not_null(self) -> Expr {
+ Expr::IsNotNull(Box::new(self))
+ }
+
+ /// Create a sort expression from an existing expression.
+ ///
+ /// ```
+ /// # use datafusion_expr::col;
+ /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST
+ /// ```
+ pub fn sort(self, asc: bool, nulls_first: bool) -> Expr {
+ Expr::Sort {
+ expr: Box::new(self),
+ asc,
+ nulls_first,
+ }
+ }
+}
+
+impl Not for Expr {
+ type Output = Self;
+
+ fn not(self) -> Self::Output {
+ Expr::Not(Box::new(self))
+ }
+}
+
+impl std::fmt::Display for Expr {
+ fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
+ match self {
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ref op,
+ } => write!(f, "{} {} {}", left, op, right),
+ Expr::AggregateFunction {
+ // Name of the function
+ ref fun,
+ // List of expressions to feed to the functions as arguments
+ ref args,
+ // Whether this is a DISTINCT aggregation or not
+ ref distinct,
+ } => fmt_function(f, &fun.to_string(), *distinct, args, true),
+ Expr::ScalarFunction {
+ // Name of the function
+ ref fun,
+ // List of expressions to feed to the functions as arguments
+ ref args,
+ } => fmt_function(f, &fun.to_string(), false, args, true),
+ _ => write!(f, "{:?}", self),
+ }
+ }
+}
+
+impl fmt::Debug for Expr {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
+ Expr::Column(c) => write!(f, "{}", c),
+ Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
+ Expr::Literal(v) => write!(f, "{:?}", v),
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ ..
+ } => {
+ write!(f, "CASE ")?;
+ if let Some(e) = expr {
+ write!(f, "{:?} ", e)?;
+ }
+ for (w, t) in when_then_expr {
+ write!(f, "WHEN {:?} THEN {:?} ", w, t)?;
+ }
+ if let Some(e) = else_expr {
+ write!(f, "ELSE {:?} ", e)?;
+ }
+ write!(f, "END")
+ }
+ Expr::Cast { expr, data_type } => {
+ write!(f, "CAST({:?} AS {:?})", expr, data_type)
+ }
+ Expr::TryCast { expr, data_type } => {
+ write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type)
+ }
+ Expr::Not(expr) => write!(f, "NOT {:?}", expr),
+ Expr::Negative(expr) => write!(f, "(- {:?})", expr),
+ Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr),
+ Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr),
+ Expr::BinaryExpr { left, op, right } => {
+ write!(f, "{:?} {} {:?}", left, op, right)
+ }
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => {
+ if *asc {
+ write!(f, "{:?} ASC", expr)?;
+ } else {
+ write!(f, "{:?} DESC", expr)?;
+ }
+ if *nulls_first {
+ write!(f, " NULLS FIRST")
+ } else {
+ write!(f, " NULLS LAST")
+ }
+ }
+ Expr::ScalarFunction { fun, args, .. } => {
+ fmt_function(f, &fun.to_string(), false, args, false)
+ }
+ Expr::ScalarUDF { fun, ref args, .. } => {
+ fmt_function(f, &fun.name, false, args, false)
+ }
+ Expr::WindowFunction {
+ fun,
+ args,
+ partition_by,
+ order_by,
+ window_frame,
+ } => {
+ fmt_function(f, &fun.to_string(), false, args, false)?;
+ if !partition_by.is_empty() {
+ write!(f, " PARTITION BY {:?}", partition_by)?;
+ }
+ if !order_by.is_empty() {
+ write!(f, " ORDER BY {:?}", order_by)?;
+ }
+ if let Some(window_frame) = window_frame {
+ write!(
+ f,
+ " {} BETWEEN {} AND {}",
+ window_frame.units,
+ window_frame.start_bound,
+ window_frame.end_bound
+ )?;
+ }
+ Ok(())
+ }
+ Expr::AggregateFunction {
+ fun,
+ distinct,
+ ref args,
+ ..
+ } => fmt_function(f, &fun.to_string(), *distinct, args, true),
+ Expr::AggregateUDF { fun, ref args, .. } => {
+ fmt_function(f, &fun.name, false, args, false)
+ }
+ Expr::Between {
+ expr,
+ negated,
+ low,
+ high,
+ } => {
+ if *negated {
+ write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high)
+ } else {
+ write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high)
+ }
+ }
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => {
+ if *negated {
+ write!(f, "{:?} NOT IN ({:?})", expr, list)
+ } else {
+ write!(f, "{:?} IN ({:?})", expr, list)
+ }
+ }
+ Expr::Wildcard => write!(f, "*"),
+ Expr::GetIndexedField { ref expr, key } => {
+ write!(f, "({:?})[{}]", expr, key)
+ }
+ }
+ }
+}
+
+fn fmt_function(
+ f: &mut fmt::Formatter,
+ fun: &str,
+ distinct: bool,
+ args: &[Expr],
+ display: bool,
+) -> fmt::Result {
+ let args: Vec<String> = match display {
+ true => args.iter().map(|arg| format!("{}", arg)).collect(),
+ false => args.iter().map(|arg| format!("{:?}", arg)).collect(),
+ };
+
+ let distinct_str = match distinct {
+ true => "DISTINCT ",
+ false => "",
+ };
+ write!(f, "{}({}{})", fun, distinct_str, args.join(", "))
+}
+
+fn create_function_name(
+ fun: &str,
+ distinct: bool,
+ args: &[Expr],
+ input_schema: &DFSchema,
+) -> Result<String> {
+ let names: Vec<String> = args
+ .iter()
+ .map(|e| create_name(e, input_schema))
+ .collect::<Result<_>>()?;
+ let distinct_str = match distinct {
+ true => "DISTINCT ",
+ false => "",
+ };
+ Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
+}
+
+/// Returns a readable name of an expression based on the input schema.
+/// This function recursively traverses the expression for names such as "CAST(a > 2)".
+fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
+ match e {
+ Expr::Alias(_, name) => Ok(name.clone()),
+ Expr::Column(c) => Ok(c.flat_name()),
+ Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
+ Expr::Literal(value) => Ok(format!("{:?}", value)),
+ Expr::BinaryExpr { left, op, right } => {
+ let left = create_name(left, input_schema)?;
+ let right = create_name(right, input_schema)?;
+ Ok(format!("{} {} {}", left, op, right))
+ }
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let mut name = "CASE ".to_string();
+ if let Some(e) = expr {
+ let e = create_name(e, input_schema)?;
+ name += &format!("{} ", e);
+ }
+ for (w, t) in when_then_expr {
+ let when = create_name(w, input_schema)?;
+ let then = create_name(t, input_schema)?;
+ name += &format!("WHEN {} THEN {} ", when, then);
+ }
+ if let Some(e) = else_expr {
+ let e = create_name(e, input_schema)?;
+ name += &format!("ELSE {} ", e);
+ }
+ name += "END";
+ Ok(name)
+ }
+ Expr::Cast { expr, data_type } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("CAST({} AS {:?})", expr, data_type))
+ }
+ Expr::TryCast { expr, data_type } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
+ }
+ Expr::Not(expr) => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("NOT {}", expr))
+ }
+ Expr::Negative(expr) => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("(- {})", expr))
+ }
+ Expr::IsNull(expr) => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("{} IS NULL", expr))
+ }
+ Expr::IsNotNull(expr) => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("{} IS NOT NULL", expr))
+ }
+ Expr::GetIndexedField { expr, key } => {
+ let expr = create_name(expr, input_schema)?;
+ Ok(format!("{}[{}]", expr, key))
+ }
+ Expr::ScalarFunction { fun, args, .. } => {
+ create_function_name(&fun.to_string(), false, args, input_schema)
+ }
+ Expr::ScalarUDF { fun, args, .. } => {
+ create_function_name(&fun.name, false, args, input_schema)
+ }
+ Expr::WindowFunction {
+ fun,
+ args,
+ window_frame,
+ partition_by,
+ order_by,
+ } => {
+ let mut parts: Vec<String> = vec![create_function_name(
+ &fun.to_string(),
+ false,
+ args,
+ input_schema,
+ )?];
+ if !partition_by.is_empty() {
+ parts.push(format!("PARTITION BY {:?}", partition_by));
+ }
+ if !order_by.is_empty() {
+ parts.push(format!("ORDER BY {:?}", order_by));
+ }
+ if let Some(window_frame) = window_frame {
+ parts.push(format!("{}", window_frame));
+ }
+ Ok(parts.join(" "))
+ }
+ Expr::AggregateFunction {
+ fun,
+ distinct,
+ args,
+ ..
+ } => create_function_name(&fun.to_string(), *distinct, args, input_schema),
+ Expr::AggregateUDF { fun, args } => {
+ let mut names = Vec::with_capacity(args.len());
+ for e in args {
+ names.push(create_name(e, input_schema)?);
+ }
+ Ok(format!("{}({})", fun.name, names.join(",")))
+ }
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => {
+ let expr = create_name(expr, input_schema)?;
+ let list = list.iter().map(|e| create_name(e, input_schema)).collect::<Result<Vec<_>>>()?;
+ if *negated {
+ Ok(format!("{} NOT IN ({})", expr, list.join(", ")))
+ } else {
+ Ok(format!("{} IN ({})", expr, list.join(", ")))
+ }
+ }
+ Expr::Between {
+ expr,
+ negated,
+ low,
+ high,
+ } => {
+ let expr = create_name(expr, input_schema)?;
+ let low = create_name(low, input_schema)?;
+ let high = create_name(high, input_schema)?;
+ if *negated {
+ Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high))
+ } else {
+ Ok(format!("{} BETWEEN {} AND {}", expr, low, high))
+ }
+ }
+ Expr::Sort { .. } => Err(DataFusionError::Internal(
+ "Create name does not support sort expression".to_string(),
+ )),
+ Expr::Wildcard => Err(DataFusionError::Internal(
+ "Create name does not support wildcard".to_string(),
+ )),
+ }
+}
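The fluent builders above compose into nested `Expr` trees whose `Debug` output matches the SQL-ish names produced by `create_name`; for instance (a sketch):

    use datafusion_expr::{col, lit};

    fn main() {
        // (a > 5) AND (b IS NOT NULL)
        let predicate = col("a").gt(lit(5_i64)).and(col("b").is_not_null());
        // prints something like: #a > Int64(5) AND #b IS NOT NULL
        println!("{:?}", predicate);
    }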
diff --git a/datafusion-cli/src/lib.rs b/datafusion-expr/src/expr_fn.rs
similarity index 68%
copy from datafusion-cli/src/lib.rs
copy to datafusion-expr/src/expr_fn.rs
index b2bcdd3..469a82d 100644
--- a/datafusion-cli/src/lib.rs
+++ b/datafusion-expr/src/expr_fn.rs
@@ -15,14 +15,18 @@
// specific language governing permissions and limitations
// under the License.
-#![doc = include_str!("../README.md")]
-#![allow(unused_imports)]
-pub const DATAFUSION_CLI_VERSION: &str = env!("CARGO_PKG_VERSION");
+use crate::{Expr, Operator};
-pub mod command;
-pub mod context;
-pub mod exec;
-pub mod functions;
-pub mod helper;
-pub mod print_format;
-pub mod print_options;
+/// Create a column expression based on a qualified or unqualified column name
+pub fn col(ident: &str) -> Expr {
+ Expr::Column(ident.into())
+}
+
+/// Return a new binary expression `l <op> r`
+pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
+ Expr::BinaryExpr {
+ left: Box::new(l),
+ op,
+ right: Box::new(r),
+ }
+}
diff --git a/datafusion-expr/src/function.rs b/datafusion-expr/src/function.rs
new file mode 100644
index 0000000..2bacd6a
--- /dev/null
+++ b/datafusion-expr/src/function.rs
@@ -0,0 +1,46 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::Accumulator;
+use crate::ColumnarValue;
+use arrow::datatypes::DataType;
+use datafusion_common::Result;
+use std::sync::Arc;
+
+/// Scalar function
+///
+/// The `Fn` param is the wrapped function. Note that the function is
+/// passed a slice of columnar values (either scalar or array), except
+/// for zero-argument functions, which receive a single-element vec.
+/// In that case the single element is a null array whose length
+/// indicates the batch's row count, so that a zero-argument function
+/// (such as `random()`) knows how many rows its result must have.
+pub type ScalarFunctionImplementation =
+ Arc<dyn Fn(&[ColumnarValue]) -> Result<ColumnarValue> + Send + Sync>;
+
+/// Computes a function's return type given its argument types
+pub type ReturnTypeFunction =
+ Arc<dyn Fn(&[DataType]) -> Result<Arc<DataType>> + Send + Sync>;
+
+/// A factory producing the [`Accumulator`] that implements an aggregate function
+pub type AccumulatorFunctionImplementation =
+ Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync>;
+
+/// Given an aggregate's return datatype, returns the types in which
+/// the aggregator serializes its intermediate state.
+pub type StateTypeFunction =
+ Arc<dyn Fn(&DataType) -> Result<Arc<Vec<DataType>>> + Send + Sync>;
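These aliases are the shape a UDF registration takes. A trivial pair, assuming they are later handed to a `ScalarUDF` constructor (a sketch, not part of the commit):

    use std::sync::Arc;
    use arrow::datatypes::DataType;
    use datafusion_expr::{ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation};

    fn main() {
        // identity over the first argument
        let fun: ScalarFunctionImplementation =
            Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone()));
        // return type: same as the first input's type
        let ret: ReturnTypeFunction =
            Arc::new(|types: &[DataType]| Ok(Arc::new(types[0].clone())));
        let _ = (fun, ret); // typically passed to a ScalarUDF constructor
    }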
diff --git a/datafusion-expr/src/lib.rs b/datafusion-expr/src/lib.rs
new file mode 100644
index 0000000..709fa63
--- /dev/null
+++ b/datafusion-expr/src/lib.rs
@@ -0,0 +1,49 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+mod accumulator;
+mod aggregate_function;
+mod built_in_function;
+mod columnar_value;
+pub mod expr;
+pub mod expr_fn;
+mod function;
+mod literal;
+mod operator;
+mod signature;
+mod udaf;
+mod udf;
+mod window_frame;
+mod window_function;
+
+pub use accumulator::Accumulator;
+pub use aggregate_function::AggregateFunction;
+pub use built_in_function::BuiltinScalarFunction;
+pub use columnar_value::{ColumnarValue, NullColumnarValue};
+pub use expr::Expr;
+pub use expr_fn::col;
+pub use function::{
+ AccumulatorFunctionImplementation, ReturnTypeFunction, ScalarFunctionImplementation,
+ StateTypeFunction,
+};
+pub use literal::{lit, lit_timestamp_nano, Literal, TimestampLiteral};
+pub use operator::Operator;
+pub use signature::{Signature, TypeSignature, Volatility};
+pub use udaf::AggregateUDF;
+pub use udf::ScalarUDF;
+pub use window_frame::{WindowFrame, WindowFrameBound, WindowFrameUnits};
+pub use window_function::{BuiltInWindowFunction, WindowFunction};
diff --git a/datafusion-expr/src/literal.rs b/datafusion-expr/src/literal.rs
new file mode 100644
index 0000000..02c75af
--- /dev/null
+++ b/datafusion-expr/src/literal.rs
@@ -0,0 +1,138 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::Expr;
+use datafusion_common::ScalarValue;
+
+/// Create a literal expression
+pub fn lit<T: Literal>(n: T) -> Expr {
+ n.lit()
+}
+
+/// Create a literal timestamp expression
+pub fn lit_timestamp_nano<T: TimestampLiteral>(n: T) -> Expr {
+ n.lit_timestamp_nano()
+}
+
+/// Trait for converting a value into a literal expression ([`Expr::Literal`]).
+pub trait Literal {
+ /// convert the value to a Literal expression
+ fn lit(&self) -> Expr;
+}
+
+/// Trait for converting a type to a literal timestamp
+pub trait TimestampLiteral {
+ fn lit_timestamp_nano(&self) -> Expr;
+}
+
+impl Literal for &str {
+ fn lit(&self) -> Expr {
+ Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
+ }
+}
+
+impl Literal for String {
+ fn lit(&self) -> Expr {
+ Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
+ }
+}
+
+impl Literal for Vec<u8> {
+ fn lit(&self) -> Expr {
+ Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+ }
+}
+
+impl Literal for &[u8] {
+ fn lit(&self) -> Expr {
+ Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
+ }
+}
+
+impl Literal for ScalarValue {
+ fn lit(&self) -> Expr {
+ Expr::Literal(self.clone())
+ }
+}
+
+macro_rules! make_literal {
+ ($TYPE:ty, $SCALAR:ident, $DOC: expr) => {
+ #[doc = $DOC]
+ impl Literal for $TYPE {
+ fn lit(&self) -> Expr {
+ Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())))
+ }
+ }
+ };
+}
+
+macro_rules! make_timestamp_literal {
+ ($TYPE:ty, $SCALAR:ident, $DOC: expr) => {
+ #[doc = $DOC]
+ impl TimestampLiteral for $TYPE {
+ fn lit_timestamp_nano(&self) -> Expr {
+ Expr::Literal(ScalarValue::TimestampNanosecond(
+ Some((self.clone()).into()),
+ None,
+ ))
+ }
+ }
+ };
+}
+
+make_literal!(bool, Boolean, "literal expression containing a bool");
+make_literal!(f32, Float32, "literal expression containing an f32");
+make_literal!(f64, Float64, "literal expression containing an f64");
+make_literal!(i8, Int8, "literal expression containing an i8");
+make_literal!(i16, Int16, "literal expression containing an i16");
+make_literal!(i32, Int32, "literal expression containing an i32");
+make_literal!(i64, Int64, "literal expression containing an i64");
+make_literal!(u8, UInt8, "literal expression containing a u8");
+make_literal!(u16, UInt16, "literal expression containing a u16");
+make_literal!(u32, UInt32, "literal expression containing a u32");
+make_literal!(u64, UInt64, "literal expression containing a u64");
+
+make_timestamp_literal!(i8, Int8, "literal expression containing an i8");
+make_timestamp_literal!(i16, Int16, "literal expression containing an i16");
+make_timestamp_literal!(i32, Int32, "literal expression containing an i32");
+make_timestamp_literal!(i64, Int64, "literal expression containing an i64");
+make_timestamp_literal!(u8, UInt8, "literal expression containing a u8");
+make_timestamp_literal!(u16, UInt16, "literal expression containing a u16");
+make_timestamp_literal!(u32, UInt32, "literal expression containing a u32");
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::expr_fn::col;
+ use datafusion_common::ScalarValue;
+
+ #[test]
+ fn test_lit_timestamp_nano() {
+ let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32
+ let expected =
+ col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None)));
+ assert_eq!(expr, expected);
+
+ let i: i64 = 10;
+ let expr = col("time").eq(lit_timestamp_nano(i));
+ assert_eq!(expr, expected);
+
+ let i: u32 = 10;
+ let expr = col("time").eq(lit_timestamp_nano(i));
+ assert_eq!(expr, expected);
+ }
+}
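Mirroring the test above, a small sketch of how the literal helpers compose with `col`:

    use datafusion_expr::{col, lit, lit_timestamp_nano};

    fn main() {
        // Strings, numbers and byte slices all go through `lit`.
        let predicate = col("name").eq(lit("alice"));
        // 10 is an i32 here; TimestampLiteral widens it to nanoseconds.
        let ts = col("time").eq(lit_timestamp_nano(10));
        println!("{:?}\n{:?}", predicate, ts);
    }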
diff --git a/datafusion/src/logical_plan/operators.rs b/datafusion-expr/src/operator.rs
similarity index 83%
copy from datafusion/src/logical_plan/operators.rs
copy to datafusion-expr/src/operator.rs
index 14ccab0..a1cad76 100644
--- a/datafusion/src/logical_plan/operators.rs
+++ b/datafusion-expr/src/operator.rs
@@ -15,9 +15,10 @@
// specific language governing permissions and limitations
// under the License.
-use std::{fmt, ops};
-
-use super::{binary_expr, Expr};
+use crate::expr_fn::binary_expr;
+use crate::Expr;
+use std::fmt;
+use std::ops;
/// Operators applied to expressions
#[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Hash)]
@@ -137,32 +138,3 @@ impl ops::Rem for Expr {
binary_expr(self, Operator::Modulo, rhs)
}
}
-
-#[cfg(test)]
-mod tests {
- use crate::prelude::lit;
-
- #[test]
- fn test_operators() {
- assert_eq!(
- format!("{:?}", lit(1u32) + lit(2u32)),
- "UInt32(1) + UInt32(2)"
- );
- assert_eq!(
- format!("{:?}", lit(1u32) - lit(2u32)),
- "UInt32(1) - UInt32(2)"
- );
- assert_eq!(
- format!("{:?}", lit(1u32) * lit(2u32)),
- "UInt32(1) * UInt32(2)"
- );
- assert_eq!(
- format!("{:?}", lit(1u32) / lit(2u32)),
- "UInt32(1) / UInt32(2)"
- );
- assert_eq!(
- format!("{:?}", lit(1u32) % lit(2u32)),
- "UInt32(1) % UInt32(2)"
- );
- }
-}
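The deleted assertions are expected to hold unchanged against the relocated code; a standalone sketch (assuming the move does not alter `Expr`'s Debug output):

    use datafusion_expr::lit;

    fn main() {
        assert_eq!(format!("{:?}", lit(1u32) + lit(2u32)), "UInt32(1) + UInt32(2)");
        assert_eq!(format!("{:?}", lit(1u32) % lit(2u32)), "UInt32(1) % UInt32(2)");
    }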
diff --git a/datafusion-expr/src/signature.rs b/datafusion-expr/src/signature.rs
new file mode 100644
index 0000000..5c27f42
--- /dev/null
+++ b/datafusion-expr/src/signature.rs
@@ -0,0 +1,116 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use arrow::datatypes::DataType;
+
+/// A function's volatility, which defines the function's eligibility for certain optimizations
+#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)]
+pub enum Volatility {
+ /// Immutable - An immutable function will always return the same output when given the same input. An example of this is [BuiltinScalarFunction::Cos].
+ Immutable,
+ /// Stable - A stable function may return different values given the same input across different queries but must return the same value for a given input within a query. An example of this is [BuiltinScalarFunction::Now].
+ Stable,
+ /// Volatile - A volatile function may change the return value from evaluation to evaluation. Multiple invocations of a volatile function may return different results when used in the same query. An example of this is [BuiltinScalarFunction::Random].
+ Volatile,
+}
+
+/// A function's type signature, which defines the function's supported argument types.
+#[derive(Debug, Clone, PartialEq, Hash)]
+pub enum TypeSignature {
+ /// arbitrary number of arguments of a common type out of a list of valid types
+ // A function such as `concat` is `Variadic(vec![DataType::Utf8, DataType::LargeUtf8])`
+ Variadic(Vec<DataType>),
+ /// arbitrary number of arguments of an arbitrary but equal type
+ // A function such as `array` is `VariadicEqual`
+ // The first argument decides the type used for coercion
+ VariadicEqual,
+ /// fixed number of arguments of an arbitrary but equal type, chosen from a list of valid types
+ // A function of one argument of f64 is `Uniform(1, vec![DataType::Float64])`
+ // A function of one argument of f64 or f32 is `Uniform(1, vec![DataType::Float32, DataType::Float64])`
+ Uniform(usize, Vec<DataType>),
+ /// exact number of arguments of an exact type
+ Exact(Vec<DataType>),
+ /// fixed number of arguments of arbitrary types
+ Any(usize),
+ /// One of a list of signatures
+ OneOf(Vec<TypeSignature>),
+}
+
+/// The Signature of a function defines its supported input types as well as its volatility.
+#[derive(Debug, Clone, PartialEq, Hash)]
+pub struct Signature {
+ /// type_signature - The types that the function accepts. See [TypeSignature] for more information.
+ pub type_signature: TypeSignature,
+ /// volatility - The volatility of the function. See [Volatility] for more information.
+ pub volatility: Volatility,
+}
+
+impl Signature {
+ /// new - Creates a new Signature from any type signature and the volatility.
+ pub fn new(type_signature: TypeSignature, volatility: Volatility) -> Self {
+ Signature {
+ type_signature,
+ volatility,
+ }
+ }
+ /// variadic - Creates a variadic signature that represents an arbitrary number of arguments all from a type in common_types.
+ pub fn variadic(common_types: Vec<DataType>, volatility: Volatility) -> Self {
+ Self {
+ type_signature: TypeSignature::Variadic(common_types),
+ volatility,
+ }
+ }
+ /// variadic_equal - Creates a variadic signature that represents an arbitrary number of arguments of the same type.
+ pub fn variadic_equal(volatility: Volatility) -> Self {
+ Self {
+ type_signature: TypeSignature::VariadicEqual,
+ volatility,
+ }
+ }
+ /// uniform - Creates a signature with a fixed number of arguments of the same type, which must be one of valid_types.
+ pub fn uniform(
+ arg_count: usize,
+ valid_types: Vec<DataType>,
+ volatility: Volatility,
+ ) -> Self {
+ Self {
+ type_signature: TypeSignature::Uniform(arg_count, valid_types),
+ volatility,
+ }
+ }
+ /// exact - Creates a signature which must match the types in exact_types in order.
+ pub fn exact(exact_types: Vec<DataType>, volatility: Volatility) -> Self {
+ Signature {
+ type_signature: TypeSignature::Exact(exact_types),
+ volatility,
+ }
+ }
+ /// any - Creates a signature which accepts a specified number of arguments of any type.
+ pub fn any(arg_count: usize, volatility: Volatility) -> Self {
+ Signature {
+ type_signature: TypeSignature::Any(arg_count),
+ volatility,
+ }
+ }
+ /// one_of - Creates a signature which can match any of the [TypeSignature]s which are passed in.
+ pub fn one_of(type_signatures: Vec<TypeSignature>, volatility: Volatility) -> Self {
+ Signature {
+ type_signature: TypeSignature::OneOf(type_signatures),
+ volatility,
+ }
+ }
+}
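A sketch of these constructors in use; the `concat`-style and math-style signatures below are illustrative, not part of this patch:

    use arrow::datatypes::DataType;
    use datafusion_expr::{Signature, Volatility};

    fn main() {
        // `concat`-style: any number of string arguments.
        let variadic = Signature::variadic(
            vec![DataType::Utf8, DataType::LargeUtf8],
            Volatility::Immutable,
        );
        // Math-style: exactly one f32 or f64 argument.
        let uniform = Signature::uniform(
            1,
            vec![DataType::Float32, DataType::Float64],
            Volatility::Immutable,
        );
        // Accept either shape.
        let either = Signature::one_of(
            vec![variadic.type_signature, uniform.type_signature],
            Volatility::Immutable,
        );
        println!("{:?}", either);
    }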
diff --git a/datafusion/src/physical_plan/udaf.rs b/datafusion-expr/src/udaf.rs
similarity index 53%
copy from datafusion/src/physical_plan/udaf.rs
copy to datafusion-expr/src/udaf.rs
index 0de696d..a39d58b 100644
--- a/datafusion/src/physical_plan/udaf.rs
+++ b/datafusion-expr/src/udaf.rs
@@ -17,26 +17,11 @@
//! This module contains functions and structs supporting user-defined aggregate functions.
-use fmt::{Debug, Formatter};
-use std::any::Any;
-use std::fmt;
-
-use arrow::{
- datatypes::Field,
- datatypes::{DataType, Schema},
-};
-
-use crate::physical_plan::PhysicalExpr;
-use crate::{error::Result, logical_plan::Expr};
-
-use super::{
- aggregates::AccumulatorFunctionImplementation,
- aggregates::StateTypeFunction,
- expressions::format_state_name,
- functions::{ReturnTypeFunction, Signature},
- type_coercion::coerce,
- Accumulator, AggregateExpr,
+use crate::Expr;
+use crate::{
+ AccumulatorFunctionImplementation, ReturnTypeFunction, Signature, StateTypeFunction,
};
+use std::fmt::{self, Debug, Formatter};
use std::sync::Arc;
/// Logical representation of a user-defined aggregate function (UDAF)
@@ -105,75 +90,3 @@ impl AggregateUDF {
}
}
}
-
-/// Creates a physical expression of the UDAF, that includes all necessary type coercion.
-/// This function errors when `args`' can't be coerced to a valid argument type of the UDAF.
-pub fn create_aggregate_expr(
- fun: &AggregateUDF,
- input_phy_exprs: &[Arc<dyn PhysicalExpr>],
- input_schema: &Schema,
- name: impl Into<String>,
-) -> Result<Arc<dyn AggregateExpr>> {
- // coerce
- let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;
-
- let coerced_exprs_types = coerced_phy_exprs
- .iter()
- .map(|arg| arg.data_type(input_schema))
- .collect::<Result<Vec<_>>>()?;
-
- Ok(Arc::new(AggregateFunctionExpr {
- fun: fun.clone(),
- args: coerced_phy_exprs.clone(),
- data_type: (fun.return_type)(&coerced_exprs_types)?.as_ref().clone(),
- name: name.into(),
- }))
-}
-
-/// Physical aggregate expression of a UDAF.
-#[derive(Debug)]
-pub struct AggregateFunctionExpr {
- fun: AggregateUDF,
- args: Vec<Arc<dyn PhysicalExpr>>,
- data_type: DataType,
- name: String,
-}
-
-impl AggregateExpr for AggregateFunctionExpr {
- /// Return a reference to Any that can be used for downcasting
- fn as_any(&self) -> &dyn Any {
- self
- }
-
- fn expressions(&self) -> Vec<Arc<dyn PhysicalExpr>> {
- self.args.clone()
- }
-
- fn state_fields(&self) -> Result<Vec<Field>> {
- let fields = (self.fun.state_type)(&self.data_type)?
- .iter()
- .enumerate()
- .map(|(i, data_type)| {
- Field::new(
- &format_state_name(&self.name, &format!("{}", i)),
- data_type.clone(),
- true,
- )
- })
- .collect::<Vec<Field>>();
-
- Ok(fields)
- }
-
- fn field(&self) -> Result<Field> {
- Ok(Field::new(&self.name, self.data_type.clone(), true))
- }
-
- fn create_accumulator(&self) -> Result<Box<dyn Accumulator>> {
- (self.fun.accumulator)()
- }
-
- fn name(&self) -> &str {
- &self.name
- }
-}
diff --git a/datafusion/src/physical_plan/udf.rs b/datafusion-expr/src/udf.rs
similarity index 73%
copy from datafusion/src/physical_plan/udf.rs
copy to datafusion-expr/src/udf.rs
index 7355746..79a17a4 100644
--- a/datafusion/src/physical_plan/udf.rs
+++ b/datafusion-expr/src/udf.rs
@@ -17,20 +17,10 @@
//! UDF support
-use fmt::{Debug, Formatter};
+use crate::{Expr, ReturnTypeFunction, ScalarFunctionImplementation, Signature};
use std::fmt;
-
-use arrow::datatypes::Schema;
-
-use crate::error::Result;
-use crate::{logical_plan::Expr, physical_plan::PhysicalExpr};
-
-use super::{
- functions::{
- ReturnTypeFunction, ScalarFunctionExpr, ScalarFunctionImplementation, Signature,
- },
- type_coercion::coerce,
-};
+use std::fmt::Debug;
+use std::fmt::Formatter;
use std::sync::Arc;
/// Logical representation of a UDF.
@@ -101,26 +91,3 @@ impl ScalarUDF {
}
}
}
-
-/// Create a physical expression of the UDF.
-/// This function errors when `args`' can't be coerced to a valid argument type of the UDF.
-pub fn create_physical_expr(
- fun: &ScalarUDF,
- input_phy_exprs: &[Arc<dyn PhysicalExpr>],
- input_schema: &Schema,
-) -> Result<Arc<dyn PhysicalExpr>> {
- // coerce
- let coerced_phy_exprs = coerce(input_phy_exprs, input_schema, &fun.signature)?;
-
- let coerced_exprs_types = coerced_phy_exprs
- .iter()
- .map(|e| e.data_type(input_schema))
- .collect::<Result<Vec<_>>>()?;
-
- Ok(Arc::new(ScalarFunctionExpr::new(
- &fun.name,
- fun.fun.clone(),
- coerced_phy_exprs,
- (fun.return_type)(&coerced_exprs_types)?.as_ref(),
- )))
-}
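For reference, a hedged sketch of building a scalar UDF against the relocated type; it assumes `ScalarUDF::new` keeps the upstream `(name, signature, return_type, fun)` shape, which this hunk does not show:

    use std::sync::Arc;
    use arrow::datatypes::DataType;
    use datafusion_expr::{
        ColumnarValue, ReturnTypeFunction, ScalarFunctionImplementation, ScalarUDF,
        Signature, Volatility,
    };

    fn main() {
        // Identity over a single Int64 argument.
        let fun: ScalarFunctionImplementation =
            Arc::new(|args: &[ColumnarValue]| Ok(args[0].clone()));
        let return_type: ReturnTypeFunction =
            Arc::new(|_| Ok(Arc::new(DataType::Int64)));
        let signature = Signature::exact(vec![DataType::Int64], Volatility::Immutable);
        // Constructor shape assumed from upstream DataFusion.
        let udf = ScalarUDF::new("identity", &signature, &return_type, &fun);
        println!("{}", udf.name);
    }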
diff --git a/datafusion/src/logical_plan/window_frames.rs b/datafusion-expr/src/window_frame.rs
similarity index 96%
copy from datafusion/src/logical_plan/window_frames.rs
copy to datafusion-expr/src/window_frame.rs
index 50e2ee7..ba65a50 100644
--- a/datafusion/src/logical_plan/window_frames.rs
+++ b/datafusion-expr/src/window_frame.rs
@@ -23,7 +23,7 @@
//! - An ending frame boundary,
//! - An EXCLUDE clause.
-use crate::error::{DataFusionError, Result};
+use datafusion_common::{DataFusionError, Result};
use sqlparser::ast;
use std::cmp::Ordering;
use std::convert::{From, TryFrom};
@@ -78,9 +78,9 @@ impl TryFrom<ast::WindowFrame> for WindowFrame {
))
} else if start_bound > end_bound {
Err(DataFusionError::Execution(format!(
- "Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
- start_bound, end_bound
- )))
+ "Invalid window frame: start bound ({}) cannot be larger than end bound ({})",
+ start_bound, end_bound
+ )))
} else {
let units = value.units.into();
if units == WindowFrameUnits::Range {
@@ -173,12 +173,6 @@ impl fmt::Display for WindowFrameBound {
}
}
-impl Hash for WindowFrameBound {
- fn hash<H: Hasher>(&self, state: &mut H) {
- self.get_rank().hash(state)
- }
-}
-
impl PartialEq for WindowFrameBound {
fn eq(&self, other: &Self) -> bool {
self.cmp(other) == Ordering::Equal
@@ -197,6 +191,12 @@ impl Ord for WindowFrameBound {
}
}
+impl Hash for WindowFrameBound {
+ fn hash<H: Hasher>(&self, state: &mut H) {
+ self.get_rank().hash(state)
+ }
+}
+
impl WindowFrameBound {
/// get the rank of this window frame bound.
///
@@ -268,9 +268,10 @@ mod tests {
};
let result = WindowFrame::try_from(window_frame);
assert_eq!(
- result.err().unwrap().to_string(),
- "Execution error: Invalid window frame: start bound cannot be unbounded following".to_owned()
- );
+ result.err().unwrap().to_string(),
+ "Execution error: Invalid window frame: start bound cannot be unbounded following"
+ .to_owned()
+ );
let window_frame = ast::WindowFrame {
units: ast::WindowFrameUnits::Range,
@@ -279,9 +280,10 @@ mod tests {
};
let result = WindowFrame::try_from(window_frame);
assert_eq!(
- result.err().unwrap().to_string(),
- "Execution error: Invalid window frame: end bound cannot be unbounded preceding".to_owned()
- );
+ result.err().unwrap().to_string(),
+ "Execution error: Invalid window frame: end bound cannot be unbounded preceding"
+ .to_owned()
+ );
let window_frame = ast::WindowFrame {
units: ast::WindowFrameUnits::Range,
diff --git a/datafusion-expr/src/window_function.rs b/datafusion-expr/src/window_function.rs
new file mode 100644
index 0000000..59523d6
--- /dev/null
+++ b/datafusion-expr/src/window_function.rs
@@ -0,0 +1,204 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::aggregate_function::AggregateFunction;
+use datafusion_common::{DataFusionError, Result};
+use std::{fmt, str::FromStr};
+
+/// A window function, either a built-in window function or an aggregate function evaluated over a window
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum WindowFunction {
+ /// window function that leverages an aggregate function
+ AggregateFunction(AggregateFunction),
+ /// window function that leverages a built-in window function
+ BuiltInWindowFunction(BuiltInWindowFunction),
+}
+
+impl FromStr for WindowFunction {
+ type Err = DataFusionError;
+ fn from_str(name: &str) -> Result<WindowFunction> {
+ let name = name.to_lowercase();
+ if let Ok(aggregate) = AggregateFunction::from_str(name.as_str()) {
+ Ok(WindowFunction::AggregateFunction(aggregate))
+ } else if let Ok(built_in_function) =
+ BuiltInWindowFunction::from_str(name.as_str())
+ {
+ Ok(WindowFunction::BuiltInWindowFunction(built_in_function))
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "There is no window function named {}",
+ name
+ )))
+ }
+ }
+}
+
+impl fmt::Display for BuiltInWindowFunction {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ BuiltInWindowFunction::RowNumber => write!(f, "ROW_NUMBER"),
+ BuiltInWindowFunction::Rank => write!(f, "RANK"),
+ BuiltInWindowFunction::DenseRank => write!(f, "DENSE_RANK"),
+ BuiltInWindowFunction::PercentRank => write!(f, "PERCENT_RANK"),
+ BuiltInWindowFunction::CumeDist => write!(f, "CUME_DIST"),
+ BuiltInWindowFunction::Ntile => write!(f, "NTILE"),
+ BuiltInWindowFunction::Lag => write!(f, "LAG"),
+ BuiltInWindowFunction::Lead => write!(f, "LEAD"),
+ BuiltInWindowFunction::FirstValue => write!(f, "FIRST_VALUE"),
+ BuiltInWindowFunction::LastValue => write!(f, "LAST_VALUE"),
+ BuiltInWindowFunction::NthValue => write!(f, "NTH_VALUE"),
+ }
+ }
+}
+
+impl fmt::Display for WindowFunction {
+ fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+ match self {
+ WindowFunction::AggregateFunction(fun) => fun.fmt(f),
+ WindowFunction::BuiltInWindowFunction(fun) => fun.fmt(f),
+ }
+ }
+}
+
+/// A built-in window function (one that is not backed by an aggregate function)
+#[derive(Debug, Clone, PartialEq, Eq, Hash)]
+pub enum BuiltInWindowFunction {
+ /// number of the current row within its partition, counting from 1
+ RowNumber,
+ /// rank of the current row with gaps; same as row_number of its first peer
+ Rank,
+ /// rank of the current row without gaps; this function counts peer groups
+ DenseRank,
+ /// relative rank of the current row: (rank - 1) / (total rows - 1)
+ PercentRank,
+ /// relative rank of the current row: (number of rows preceding or peer with current row) / (total rows)
+ CumeDist,
+ /// integer ranging from 1 to the argument value, dividing the partition as equally as possible
+ Ntile,
+ /// returns value evaluated at the row that is offset rows before the current row within the partition;
+ /// if there is no such row, instead return default (which must be of the same type as value).
+ /// Both offset and default are evaluated with respect to the current row.
+ /// If omitted, offset defaults to 1 and default defaults to null
+ Lag,
+ /// returns value evaluated at the row that is offset rows after the current row within the partition;
+ /// if there is no such row, instead return default (which must be of the same type as value).
+ /// Both offset and default are evaluated with respect to the current row.
+ /// If omitted, offset defaults to 1 and default defaults to null
+ Lead,
+ /// returns value evaluated at the row that is the first row of the window frame
+ FirstValue,
+ /// returns value evaluated at the row that is the last row of the window frame
+ LastValue,
+ /// returns value evaluated at the row that is the nth row of the window frame (counting from 1); null if no such row
+ NthValue,
+}
+
+impl FromStr for BuiltInWindowFunction {
+ type Err = DataFusionError;
+ fn from_str(name: &str) -> Result<BuiltInWindowFunction> {
+ Ok(match name.to_uppercase().as_str() {
+ "ROW_NUMBER" => BuiltInWindowFunction::RowNumber,
+ "RANK" => BuiltInWindowFunction::Rank,
+ "DENSE_RANK" => BuiltInWindowFunction::DenseRank,
+ "PERCENT_RANK" => BuiltInWindowFunction::PercentRank,
+ "CUME_DIST" => BuiltInWindowFunction::CumeDist,
+ "NTILE" => BuiltInWindowFunction::Ntile,
+ "LAG" => BuiltInWindowFunction::Lag,
+ "LEAD" => BuiltInWindowFunction::Lead,
+ "FIRST_VALUE" => BuiltInWindowFunction::FirstValue,
+ "LAST_VALUE" => BuiltInWindowFunction::LastValue,
+ "NTH_VALUE" => BuiltInWindowFunction::NthValue,
+ _ => {
+ return Err(DataFusionError::Plan(format!(
+ "There is no built-in window function named {}",
+ name
+ )))
+ }
+ })
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ #[test]
+ fn test_window_function_case_insensitive() -> Result<()> {
+ let names = vec![
+ "row_number",
+ "rank",
+ "dense_rank",
+ "percent_rank",
+ "cume_dist",
+ "ntile",
+ "lag",
+ "lead",
+ "first_value",
+ "last_value",
+ "nth_value",
+ "min",
+ "max",
+ "count",
+ "avg",
+ "sum",
+ ];
+ for name in names {
+ let fun = WindowFunction::from_str(name)?;
+ let fun2 = WindowFunction::from_str(name.to_uppercase().as_str())?;
+ assert_eq!(fun, fun2);
+ assert_eq!(fun.to_string(), name.to_uppercase());
+ }
+ Ok(())
+ }
+
+ #[test]
+ fn test_window_function_from_str() -> Result<()> {
+ assert_eq!(
+ WindowFunction::from_str("max")?,
+ WindowFunction::AggregateFunction(AggregateFunction::Max)
+ );
+ assert_eq!(
+ WindowFunction::from_str("min")?,
+ WindowFunction::AggregateFunction(AggregateFunction::Min)
+ );
+ assert_eq!(
+ WindowFunction::from_str("avg")?,
+ WindowFunction::AggregateFunction(AggregateFunction::Avg)
+ );
+ assert_eq!(
+ WindowFunction::from_str("cume_dist")?,
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::CumeDist)
+ );
+ assert_eq!(
+ WindowFunction::from_str("first_value")?,
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::FirstValue)
+ );
+ assert_eq!(
+ WindowFunction::from_str("LAST_value")?,
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::LastValue)
+ );
+ assert_eq!(
+ WindowFunction::from_str("LAG")?,
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lag)
+ );
+ assert_eq!(
+ WindowFunction::from_str("LEAD")?,
+ WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::Lead)
+ );
+ Ok(())
+ }
+}
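A sketch of the case-insensitive lookup defined above:

    use std::str::FromStr;
    use datafusion_expr::{BuiltInWindowFunction, WindowFunction};

    fn main() {
        // Parsing falls back from aggregate functions to built-ins.
        let f = WindowFunction::from_str("dense_rank").unwrap();
        assert_eq!(
            f,
            WindowFunction::BuiltInWindowFunction(BuiltInWindowFunction::DenseRank)
        );
        assert_eq!(f.to_string(), "DENSE_RANK");
    }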
diff --git a/datafusion/Cargo.toml b/datafusion/Cargo.toml
index c37c204..67c09e5 100644
--- a/datafusion/Cargo.toml
+++ b/datafusion/Cargo.toml
@@ -45,17 +45,18 @@ simd = ["arrow/simd"]
crypto_expressions = ["md-5", "sha2", "blake2", "blake3"]
regex_expressions = ["regex"]
unicode_expressions = ["unicode-segmentation"]
-# FIXME: add pyarrow support to arrow2 pyarrow = ["pyo3", "arrow/pyarrow"]
-pyarrow = ["pyo3"]
+pyarrow = ["pyo3", "datafusion-common/pyarrow"]
# Used for testing ONLY: causes all values to hash to the same value (test for collisions)
force_hash_collisions = []
# Used to enable the avro format
avro = ["arrow/io_avro", "arrow/io_avro_async", "arrow/io_avro_compression", "num-traits", "avro-schema"]
[dependencies]
+datafusion-common = { path = "../datafusion-common", version = "6.0.0" }
+datafusion-expr = { path = "../datafusion-expr", version = "6.0.0" }
ahash = { version = "0.7", default-features = false }
hashbrown = { version = "0.11", features = ["raw"] }
-parquet = { package = "parquet2", version = "0.9", default_features = false, features = ["stream"] }
+parquet = { package = "parquet2", version = "0.10", default_features = false, features = ["stream"] }
sqlparser = "0.13"
paste = "^1.0"
num_cpus = "1.13.0"
@@ -70,7 +71,7 @@ md-5 = { version = "^0.10.0", optional = true }
sha2 = { version = "^0.10.1", optional = true }
blake2 = { version = "^0.10.2", optional = true }
blake3 = { version = "1.0", optional = true }
-ordered-float = "2.0"
+ordered-float = "2.10"
unicode-segmentation = { version = "^1.7.1", optional = true }
regex = { version = "^1.4.3", optional = true }
lazy_static = { version = "^1.4.0" }
@@ -79,7 +80,7 @@ rand = "0.8"
num-traits = { version = "0.2", optional = true }
pyo3 = { version = "0.15", optional = true }
tempfile = "3"
-parking_lot = "0.11"
+parking_lot = "0.12"
avro-schema = { version = "0.2", optional = true }
# used to print arrow arrays in a nice columnar format
diff --git a/datafusion/benches/sort_limit_query_sql.rs b/datafusion/benches/sort_limit_query_sql.rs
index 7fe8e7c..2013a2b 100644
--- a/datafusion/benches/sort_limit_query_sql.rs
+++ b/datafusion/benches/sort_limit_query_sql.rs
@@ -25,6 +25,9 @@ use datafusion::datasource::object_store::local::LocalFileSystem;
use parking_lot::Mutex;
use std::sync::Arc;
+extern crate arrow;
+extern crate datafusion;
+
use arrow::datatypes::{DataType, Field, Schema};
use datafusion::datasource::MemTable;
diff --git a/datafusion/fuzz-utils/Cargo.toml b/datafusion/fuzz-utils/Cargo.toml
index cb1e2e9..b064645 100644
--- a/datafusion/fuzz-utils/Cargo.toml
+++ b/datafusion/fuzz-utils/Cargo.toml
@@ -23,7 +23,7 @@ edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
+datafusion-common = { path = "../../datafusion-common" }
arrow = { package = "arrow2", version="0.9", features = ["io_print"] }
-datafusion = { path = ".." }
rand = "0.8"
env_logger = "0.9.0"
diff --git a/datafusion/fuzz-utils/src/lib.rs b/datafusion/fuzz-utils/src/lib.rs
index 81da480..03b6678 100644
--- a/datafusion/fuzz-utils/src/lib.rs
+++ b/datafusion/fuzz-utils/src/lib.rs
@@ -20,14 +20,14 @@ use arrow::array::Int32Array;
use rand::prelude::StdRng;
use rand::Rng;
-use datafusion::record_batch::RecordBatch;
+use datafusion_common::record_batch::RecordBatch;
pub use env_logger;
/// Extracts the i32 values from the set of batches and returns them as a single Vec
pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
batches
.iter()
- .map(|batch| {
+ .flat_map(|batch| {
assert_eq!(batch.num_columns(), 1);
batch
.column(0)
@@ -37,7 +37,6 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
.iter()
.map(|v| v.copied())
})
- .flatten()
.collect()
}
@@ -45,8 +44,7 @@ pub fn batches_to_vec(batches: &[RecordBatch]) -> Vec<Option<i32>> {
pub fn partitions_to_sorted_vec(partitions: &[Vec<RecordBatch>]) -> Vec<Option<i32>> {
let mut values: Vec<_> = partitions
.iter()
- .map(|batches| batches_to_vec(batches).into_iter())
- .flatten()
+ .flat_map(|batches| batches_to_vec(batches).into_iter())
.collect();
values.sort_unstable();
@@ -62,7 +60,7 @@ pub fn add_empty_batches(
batches
.into_iter()
- .map(|batch| {
+ .flat_map(|batch| {
// insert 0, or 1 empty batches before and after the current batch
let empty_batch = RecordBatch::new_empty(schema.clone());
std::iter::repeat(empty_batch.clone())
@@ -70,6 +68,5 @@ pub fn add_empty_batches(
.chain(std::iter::once(batch))
.chain(std::iter::repeat(empty_batch).take(rng.gen_range(0..2)))
})
- .flatten()
.collect()
}
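These hunks are the standard clippy cleanup replacing `map(..).flatten()` with `flat_map(..)`; a standalone illustration of the equivalence:

    fn main() {
        let nested = vec![vec![1, 2], vec![3]];
        let a: Vec<i32> = nested.iter().map(|v| v.iter().copied()).flatten().collect();
        let b: Vec<i32> = nested.iter().flat_map(|v| v.iter().copied()).collect();
        assert_eq!(a, b);
    }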
diff --git a/datafusion/src/avro_to_arrow/arrow_array_reader.rs b/datafusion/src/avro_to_arrow/arrow_array_reader.rs
index 0fd50e9..8667c77 100644
--- a/datafusion/src/avro_to_arrow/arrow_array_reader.rs
+++ b/datafusion/src/avro_to_arrow/arrow_array_reader.rs
@@ -38,6 +38,7 @@ impl<'a, R: Read> AvroBatchReader<R> {
avro_schemas: Vec<avro_schema::Schema>,
codec: Option<Compression>,
file_marker: [u8; 16],
+ projection: Option<Vec<bool>>,
) -> Result<Self> {
let reader = AvroReader::new(
read::Decompressor::new(
@@ -46,6 +47,7 @@ impl<'a, R: Read> AvroBatchReader<R> {
),
avro_schemas,
schema.fields.clone(),
+ projection,
);
Ok(Self { reader, schema })
}
diff --git a/datafusion/src/avro_to_arrow/reader.rs b/datafusion/src/avro_to_arrow/reader.rs
index 7cb640e..a7a8e95 100644
--- a/datafusion/src/avro_to_arrow/reader.rs
+++ b/datafusion/src/avro_to_arrow/reader.rs
@@ -108,22 +108,16 @@ impl ReaderBuilder {
// check if schema should be inferred
source.seek(SeekFrom::Start(0))?;
- let (mut avro_schemas, mut schema, codec, file_marker) =
+ let (avro_schemas, schema, codec, file_marker) =
read::read_metadata(&mut source)?;
- if let Some(proj) = self.projection {
- let mut indices: Vec<usize> = schema
+
+ let projection = self.projection.map(|proj| {
+ schema
.fields
.iter()
- .filter(|f| !proj.contains(&f.name))
- .enumerate()
- .map(|(i, _)| i)
- .collect();
- indices.sort_by(|i1, i2| i2.cmp(i1));
- for i in indices {
- avro_schemas.remove(i);
- schema.fields.remove(i);
- }
- }
+ .map(|f| proj.contains(&f.name))
+ .collect::<Vec<bool>>()
+ });
Reader::try_new(
source,
@@ -132,6 +126,7 @@ impl ReaderBuilder {
avro_schemas,
codec,
file_marker,
+ projection,
)
}
}
@@ -155,6 +150,7 @@ impl<'a, R: Read> Reader<R> {
avro_schemas: Vec<avro_schema::Schema>,
codec: Option<Compression>,
file_marker: [u8; 16],
+ projection: Option<Vec<bool>>,
) -> Result<Self> {
Ok(Self {
array_reader: AvroBatchReader::try_new(
@@ -163,6 +159,7 @@ impl<'a, R: Read> Reader<R> {
avro_schemas,
codec,
file_marker,
+ projection,
)?,
schema,
batch_size,
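The reworked projection passes a boolean mask (one flag per schema field) down to the reader instead of mutating the schemas in place; the core of that mapping in isolation, with illustrative field names:

    fn main() {
        let field_names = ["id", "name", "payload"];
        let projection = vec!["id".to_string(), "payload".to_string()];
        // One bool per schema field: true iff the field was requested.
        let mask: Vec<bool> = field_names
            .iter()
            .map(|f| projection.contains(&f.to_string()))
            .collect();
        assert_eq!(mask, vec![true, false, true]);
    }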
diff --git a/datafusion/src/avro_to_arrow/schema.rs b/datafusion/src/avro_to_arrow/schema.rs
deleted file mode 100644
index 8b13789..0000000
--- a/datafusion/src/avro_to_arrow/schema.rs
+++ /dev/null
@@ -1 +0,0 @@
-
diff --git a/datafusion/src/dataframe.rs b/datafusion/src/dataframe.rs
index 9c4a4e4..dd1ebeb 100644
--- a/datafusion/src/dataframe.rs
+++ b/datafusion/src/dataframe.rs
@@ -19,13 +19,14 @@
use crate::error::Result;
use crate::logical_plan::{
- DFSchema, Expr, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
+ DFSchema, FunctionRegistry, JoinType, LogicalPlan, Partitioning,
};
use crate::record_batch::RecordBatch;
use std::sync::Arc;
use crate::physical_plan::SendableRecordBatchStream;
use async_trait::async_trait;
+use datafusion_expr::Expr;
/// DataFrame represents a logical set of rows with the same named columns.
/// Similar to a [Pandas DataFrame](https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.DataFrame.html) or
diff --git a/datafusion/src/datasource/file_format/parquet.rs b/datafusion/src/datasource/file_format/parquet.rs
index c32f7b2..7d5eb11 100644
--- a/datafusion/src/datasource/file_format/parquet.rs
+++ b/datafusion/src/datasource/file_format/parquet.rs
@@ -25,7 +25,7 @@ use arrow::datatypes::Schema;
use arrow::datatypes::SchemaRef;
use async_trait::async_trait;
-use arrow::io::parquet::read::{get_schema, read_metadata};
+use arrow::io::parquet::read::{infer_schema, read_metadata};
use futures::TryStreamExt;
use parquet::statistics::{
BinaryStatistics as ParquetBinaryStatistics,
@@ -265,7 +265,7 @@ fn summarize_min_max(
pub fn fetch_schema(object_reader: Arc<dyn ObjectReader>) -> Result<Schema> {
let mut reader = object_reader.sync_reader()?;
let meta_data = read_metadata(&mut reader)?;
- let schema = get_schema(&meta_data)?;
+ let schema = infer_schema(&meta_data)?;
Ok(schema)
}
@@ -273,7 +273,7 @@ pub fn fetch_schema(object_reader: Arc<dyn ObjectReader>) -> Result<Schema> {
fn fetch_statistics(object_reader: Arc<dyn ObjectReader>) -> Result<Statistics> {
let mut reader = object_reader.sync_reader()?;
let meta_data = read_metadata(&mut reader)?;
- let schema = get_schema(&meta_data)?;
+ let schema = infer_schema(&meta_data)?;
let num_fields = schema.fields().len();
let fields = schema.fields().to_vec();
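A sketch of the renamed schema API in use, assuming the arrow2 0.9 read path as called in this hunk (`data.parquet` is a placeholder):

    use std::fs::File;
    use arrow::io::parquet::read::{infer_schema, read_metadata};

    fn main() -> arrow::error::Result<()> {
        let mut reader = File::open("data.parquet").expect("example file");
        let metadata = read_metadata(&mut reader)?;
        let schema = infer_schema(&metadata)?;
        println!("{:?}", schema.fields);
        Ok(())
    }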
diff --git a/datafusion/src/datasource/listing/helpers.rs b/datafusion/src/datasource/listing/helpers.rs
index 0d52966..24d3b35 100644
--- a/datafusion/src/datasource/listing/helpers.rs
+++ b/datafusion/src/datasource/listing/helpers.rs
@@ -34,7 +34,7 @@ use log::debug;
use crate::{
error::Result,
execution::context::ExecutionContext,
- logical_plan::{self, Expr, ExpressionVisitor, Recursion},
+ logical_plan::{self, Expr, ExprVisitable, ExpressionVisitor, Recursion},
physical_plan::functions::Volatility,
scalar::ScalarValue,
};
diff --git a/datafusion/src/datasource/memory.rs b/datafusion/src/datasource/memory.rs
index 4b1e09e..ddd81ff 100644
--- a/datafusion/src/datasource/memory.rs
+++ b/datafusion/src/datasource/memory.rs
@@ -166,7 +166,6 @@ mod tests {
use arrow::array::Int32Array;
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
- use futures::StreamExt;
use std::collections::BTreeMap;
#[tokio::test]
diff --git a/datafusion/src/error.rs b/datafusion/src/error.rs
index fbad9a9..c2c80b4 100644
--- a/datafusion/src/error.rs
+++ b/datafusion/src/error.rs
@@ -16,173 +16,4 @@
// under the License.
//! DataFusion error types
-
-use std::error;
-use std::fmt::{Display, Formatter};
-use std::io;
-use std::result;
-
-use arrow::error::ArrowError;
-use parquet::error::ParquetError;
-use sqlparser::parser::ParserError;
-
-/// Result type for operations that could result in an [DataFusionError]
-pub type Result<T> = result::Result<T, DataFusionError>;
-
-/// Error type for generic operations that could result in DataFusionError::External
-pub type GenericError = Box<dyn error::Error + Send + Sync>;
-
-/// DataFusion error
-#[derive(Debug)]
-#[allow(missing_docs)]
-pub enum DataFusionError {
- /// Error returned by arrow.
- ArrowError(ArrowError),
- /// Wraps an error from the Parquet crate
- ParquetError(ParquetError),
- /// Error associated to I/O operations and associated traits.
- IoError(io::Error),
- /// Error returned when SQL is syntactically incorrect.
- SQL(ParserError),
- /// Error returned on a branch that we know it is possible
- /// but to which we still have no implementation for.
- /// Often, these errors are tracked in our issue tracker.
- NotImplemented(String),
- /// Error returned as a consequence of an error in DataFusion.
- /// This error should not happen in normal usage of DataFusion.
- // DataFusions has internal invariants that we are unable to ask the compiler to check for us.
- // This error is raised when one of those invariants is not verified during execution.
- Internal(String),
- /// This error happens whenever a plan is not valid. Examples include
- /// impossible casts, schema inference not possible and non-unique column names.
- Plan(String),
- /// Error returned during execution of the query.
- /// Examples include files not found, errors in parsing certain types.
- Execution(String),
- /// This error is thrown when a consumer cannot acquire memory from the Memory Manager
- /// we can just cancel the execution of the partition.
- ResourcesExhausted(String),
- /// Errors originating from outside DataFusion's core codebase.
- /// For example, a custom S3Error from the crate datafusion-objectstore-s3
- External(GenericError),
-}
-
-impl From<io::Error> for DataFusionError {
- fn from(e: io::Error) -> Self {
- DataFusionError::IoError(e)
- }
-}
-
-impl From<ArrowError> for DataFusionError {
- fn from(e: ArrowError) -> Self {
- DataFusionError::ArrowError(e)
- }
-}
-
-impl From<DataFusionError> for ArrowError {
- fn from(e: DataFusionError) -> Self {
- match e {
- DataFusionError::ArrowError(e) => e,
- DataFusionError::External(e) => ArrowError::External("".to_string(), e),
- other => ArrowError::External("".to_string(), Box::new(other)),
- }
- }
-}
-
-impl From<ParquetError> for DataFusionError {
- fn from(e: ParquetError) -> Self {
- DataFusionError::ParquetError(e)
- }
-}
-
-impl From<ParserError> for DataFusionError {
- fn from(e: ParserError) -> Self {
- DataFusionError::SQL(e)
- }
-}
-
-impl From<GenericError> for DataFusionError {
- fn from(err: GenericError) -> Self {
- DataFusionError::External(err)
- }
-}
-
-impl Display for DataFusionError {
- fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
- match *self {
- DataFusionError::ArrowError(ref desc) => write!(f, "Arrow error: {}", desc),
- DataFusionError::ParquetError(ref desc) => {
- write!(f, "Parquet error: {}", desc)
- }
- DataFusionError::IoError(ref desc) => write!(f, "IO error: {}", desc),
- DataFusionError::SQL(ref desc) => {
- write!(f, "SQL error: {:?}", desc)
- }
- DataFusionError::NotImplemented(ref desc) => {
- write!(f, "This feature is not implemented: {}", desc)
- }
- DataFusionError::Internal(ref desc) => {
- write!(f, "Internal error: {}. This was likely caused by a bug in DataFusion's \
- code and we would welcome that you file an bug report in our issue tracker", desc)
- }
- DataFusionError::Plan(ref desc) => {
- write!(f, "Error during planning: {}", desc)
- }
- DataFusionError::Execution(ref desc) => {
- write!(f, "Execution error: {}", desc)
- }
- DataFusionError::ResourcesExhausted(ref desc) => {
- write!(f, "Resources exhausted: {}", desc)
- }
- DataFusionError::External(ref desc) => {
- write!(f, "External error: {}", desc)
- }
- }
- }
-}
-
-impl error::Error for DataFusionError {}
-
-#[cfg(test)]
-mod test {
- use crate::error::DataFusionError;
- use arrow::error::ArrowError;
-
- #[test]
- fn arrow_error_to_datafusion() {
- let res = return_arrow_error().unwrap_err();
- assert_eq!(
- res.to_string(),
- "External error: Error during planning: foo"
- );
- }
-
- #[test]
- fn datafusion_error_to_arrow() {
- let res = return_datafusion_error().unwrap_err();
- assert_eq!(
- res.to_string(),
- "Arrow error: Invalid argument error: Schema error: bar"
- );
- }
-
- /// Model what happens when implementing SendableRecrordBatchStream:
- /// DataFusion code needs to return an ArrowError
- #[allow(clippy::try_err)]
- fn return_arrow_error() -> arrow::error::Result<()> {
- // Expect the '?' to work
- let _foo = Err(DataFusionError::Plan("foo".to_string()))?;
- Ok(())
- }
-
- /// Model what happens when using arrow kernels in DataFusion
- /// code: need to turn an ArrowError into a DataFusionError
- #[allow(clippy::try_err)]
- fn return_datafusion_error() -> crate::error::Result<()> {
- // Expect the '?' to work
- let _bar = Err(ArrowError::InvalidArgumentError(
- "Schema error: bar".to_string(),
- ))?;
- Ok(())
- }
-}
+pub use datafusion_common::{DataFusionError, Result};
diff --git a/datafusion/src/execution/context.rs b/datafusion/src/execution/context.rs
index 9aa2b47..09346a3 100644
--- a/datafusion/src/execution/context.rs
+++ b/datafusion/src/execution/context.rs
@@ -38,6 +38,8 @@ use crate::{
hash_build_probe_order::HashBuildProbeOrder, optimizer::PhysicalOptimizerRule,
},
};
+use arrow::array::ArrayRef;
+use arrow::chunk::Chunk;
use log::debug;
use parking_lot::Mutex;
use std::collections::{HashMap, HashSet};
@@ -50,12 +52,13 @@ use futures::{StreamExt, TryStreamExt};
use tokio::task::{self, JoinHandle};
use crate::record_batch::RecordBatch;
-use arrow::datatypes::SchemaRef;
+use arrow::datatypes::{PhysicalType, SchemaRef};
use arrow::error::{ArrowError, Result as ArrowResult};
use arrow::io::csv;
use arrow::io::parquet;
-use arrow::io::parquet::write::FallibleStreamingIterator;
-use arrow::io::parquet::write::WriteOptions;
+use arrow::io::parquet::write::{
+ row_group_iter, to_parquet_schema, Encoding, WriteOptions,
+};
use crate::catalog::{
catalog::{CatalogProvider, MemoryCatalogProvider},
@@ -791,56 +794,54 @@ impl ExecutionContext {
let stream = plan.execute(i, runtime.clone()).await?;
let handle: JoinHandle<Result<u64>> = task::spawn(async move {
- let parquet_schema = parquet::write::to_parquet_schema(&schema)?;
+ let parquet_schema = to_parquet_schema(&schema)?;
let a = parquet_schema.clone();
- let row_groups = stream.map(|batch: ArrowResult<RecordBatch>| {
- // map each record batch to a row group
- let r = batch.map(|batch| {
- let batch_cols = batch.columns().to_vec();
- // column chunk in row group
- let pages =
- batch_cols
- .into_iter()
- .zip(a.columns().iter().cloned())
- .map(move |(array, descriptor)| {
- parquet::write::array_to_pages(
- array.as_ref(),
- descriptor,
- options,
- parquet::write::Encoding::Plain,
- )
- .map(move |pages| {
- let encoded_pages =
- parquet::write::DynIter::new(
- pages.map(|x| Ok(x?)),
- );
- let compressed_pages =
- parquet::write::Compressor::new(
- encoded_pages,
- options.compression,
- vec![],
- )
- .map_err(ArrowError::from);
- parquet::write::DynStreamingIterator::new(
- compressed_pages,
- )
- })
- });
- parquet::write::DynIter::new(pages)
+ let encodings: Vec<Encoding> = schema
+ .fields()
+ .iter()
+ .map(|field| match field.data_type().to_physical_type() {
+ PhysicalType::Binary
+ | PhysicalType::LargeBinary
+ | PhysicalType::Utf8
+ | PhysicalType::LargeUtf8 => {
+ Encoding::DeltaLengthByteArray
+ }
+ _ => Encoding::Plain,
+ })
+ .collect();
+
+ let mut row_groups =
+ stream.map(|batch: ArrowResult<RecordBatch>| {
+ // map each record batch to a row group
+ batch.map(|batch| {
+ // column chunk in row group
+ let chunk: Chunk<ArrayRef> = batch.into();
+ let len = chunk.len();
+ (
+ row_group_iter(
+ chunk,
+ encodings.clone(),
+ a.columns().to_vec(),
+ options,
+ ),
+ len,
+ )
+ })
});
- async { r }
- });
- Ok(parquet::write::stream::write_stream(
+ let mut writer = parquet::write::FileWriter::try_new(
&mut file,
- row_groups,
schema.as_ref().clone(),
- parquet_schema,
options,
- None,
- )
- .await?)
+ )?;
+ writer.start()?;
+ while let Some(row_group) = row_groups.next().await {
+ let (group, len) = row_group?;
+ writer.write(group, len)?;
+ }
+ let (written, _) = writer.end(None)?;
+ Ok(written)
});
tasks.push(handle);
}
@@ -1204,7 +1205,7 @@ impl ExecutionProps {
var_type: VarType,
provider: Arc<dyn VarProvider + Send + Sync>,
) -> Option<Arc<dyn VarProvider + Send + Sync>> {
- let mut var_providers = self.var_providers.take().unwrap_or_default();
+ let mut var_providers = self.var_providers.take().unwrap_or_else(HashMap::new);
let old_provider = var_providers.insert(var_type, provider);
@@ -1345,11 +1346,9 @@ mod tests {
use super::*;
use crate::execution::context::QueryPlanner;
use crate::field_util::{FieldExt, SchemaExt};
- use crate::logical_plan::plan::Projection;
- use crate::logical_plan::TableScan;
use crate::logical_plan::{binary_expr, lit, Operator};
+ use crate::physical_plan::collect;
use crate::physical_plan::functions::{make_scalar_function, Volatility};
- use crate::physical_plan::{collect, collect_partitioned};
use crate::record_batch::RecordBatch;
use crate::test;
use crate::variable::VarType;
@@ -1367,8 +1366,7 @@ mod tests {
use arrow::compute::arithmetics::basic::add;
use arrow::datatypes::*;
use arrow::io::parquet::write::{
- to_parquet_schema, write_file, Compression, Encoding, RowGroupIterator, Version,
- WriteOptions,
+ Compression, Encoding, FileWriter, RowGroupIterator, Version, WriteOptions,
};
use async_trait::async_trait;
use std::collections::BTreeMap;
@@ -1377,7 +1375,6 @@ mod tests {
use std::thread::{self, JoinHandle};
use std::{io::prelude::*, sync::Mutex};
use tempfile::TempDir;
- use test::*;
#[tokio::test]
async fn shared_memory_and_disk_manager() {
@@ -1413,100 +1410,6 @@ mod tests {
));
}
- #[test]
- fn optimize_explain() {
- let schema = Schema::new(vec![Field::new("id", DataType::Int32, false)]);
-
- let plan = LogicalPlanBuilder::scan_empty(Some("employee"), &schema, None)
- .unwrap()
- .explain(true, false)
- .unwrap()
- .build()
- .unwrap();
-
- if let LogicalPlan::Explain(e) = &plan {
- assert_eq!(e.stringified_plans.len(), 1);
- } else {
- panic!("plan was not an explain: {:?}", plan);
- }
-
- // now optimize the plan and expect to see more plans
- let optimized_plan = ExecutionContext::new().optimize(&plan).unwrap();
- if let LogicalPlan::Explain(e) = &optimized_plan {
- // should have more than one plan
- assert!(
- e.stringified_plans.len() > 1,
- "plans: {:#?}",
- e.stringified_plans
- );
- // should have at least one optimized plan
- let opt = e
- .stringified_plans
- .iter()
- .any(|p| matches!(p.plan_type, PlanType::OptimizedLogicalPlan { .. }));
-
- assert!(opt, "plans: {:#?}", e.stringified_plans);
- } else {
- panic!("plan was not an explain: {:?}", plan);
- }
- }
-
- #[tokio::test]
- async fn parallel_projection() -> Result<()> {
- let partition_count = 4;
- let results = execute("SELECT c1, c2 FROM test", partition_count).await?;
-
- let expected = vec![
- "+----+----+",
- "| c1 | c2 |",
- "+----+----+",
- "| 3 | 1 |",
- "| 3 | 2 |",
- "| 3 | 3 |",
- "| 3 | 4 |",
- "| 3 | 5 |",
- "| 3 | 6 |",
- "| 3 | 7 |",
- "| 3 | 8 |",
- "| 3 | 9 |",
- "| 3 | 10 |",
- "| 2 | 1 |",
- "| 2 | 2 |",
- "| 2 | 3 |",
- "| 2 | 4 |",
- "| 2 | 5 |",
- "| 2 | 6 |",
- "| 2 | 7 |",
- "| 2 | 8 |",
- "| 2 | 9 |",
- "| 2 | 10 |",
- "| 1 | 1 |",
- "| 1 | 2 |",
- "| 1 | 3 |",
- "| 1 | 4 |",
- "| 1 | 5 |",
- "| 1 | 6 |",
- "| 1 | 7 |",
- "| 1 | 8 |",
- "| 1 | 9 |",
- "| 1 | 10 |",
- "| 0 | 1 |",
- "| 0 | 2 |",
- "| 0 | 3 |",
- "| 0 | 4 |",
- "| 0 | 5 |",
- "| 0 | 6 |",
- "| 0 | 7 |",
- "| 0 | 8 |",
- "| 0 | 9 |",
- "| 0 | 10 |",
- "+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &results);
-
- Ok(())
- }
-
#[tokio::test]
async fn create_variable_expr() -> Result<()> {
let tmp_dir = TempDir::new()?;
@@ -1552,184 +1455,6 @@ mod tests {
}
#[tokio::test]
- async fn parallel_query_with_filter() -> Result<()> {
- let tmp_dir = TempDir::new()?;
- let partition_count = 4;
- let ctx = create_ctx(&tmp_dir, partition_count).await?;
-
- let logical_plan =
- ctx.create_logical_plan("SELECT c1, c2 FROM test WHERE c1 > 0 AND c1 < 3")?;
- let logical_plan = ctx.optimize(&logical_plan)?;
-
- let physical_plan = ctx.create_physical_plan(&logical_plan).await?;
-
- let runtime = ctx.state.lock().runtime_env.clone();
- let results = collect_partitioned(physical_plan, runtime).await?;
-
- // note that the order of partitions is not deterministic
- let mut num_rows = 0;
- for partition in &results {
- for batch in partition {
- num_rows += batch.num_rows();
- }
- }
- assert_eq!(20, num_rows);
-
- let results: Vec<RecordBatch> = results.into_iter().flatten().collect();
- let expected = vec![
- "+----+----+",
- "| c1 | c2 |",
- "+----+----+",
- "| 1 | 1 |",
- "| 1 | 10 |",
- "| 1 | 2 |",
- "| 1 | 3 |",
- "| 1 | 4 |",
- "| 1 | 5 |",
- "| 1 | 6 |",
- "| 1 | 7 |",
- "| 1 | 8 |",
- "| 1 | 9 |",
- "| 2 | 1 |",
- "| 2 | 10 |",
- "| 2 | 2 |",
- "| 2 | 3 |",
- "| 2 | 4 |",
- "| 2 | 5 |",
- "| 2 | 6 |",
- "| 2 | 7 |",
- "| 2 | 8 |",
- "| 2 | 9 |",
- "+----+----+",
- ];
- assert_batches_sorted_eq!(expected, &results);
-
- Ok(())
- }
-
- #[tokio::test]
- async fn projection_on_table_scan() -> Result<()> {
- let tmp_dir = TempDir::new()?;
- let partition_count = 4;
- let ctx = create_ctx(&tmp_dir, partition_count).await?;
- let runtime = ctx.state.lock().runtime_env.clone();
-
- let table = ctx.table("test")?;
- let logical_plan = LogicalPlanBuilder::from(table.to_logical_plan())
- .project(vec![col("c2")])?
- .build()?;
-
- let optimized_plan = ctx.optimize(&logical_plan)?;
- match &optimized_plan {
- LogicalPlan::Projection(Projection { input, .. }) => match &**input {
- LogicalPlan::TableScan(TableScan {
- source,
- projected_schema,
- ..
- }) => {
- assert_eq!(source.schema().fields().len(), 3);
- assert_eq!(projected_schema.fields().len(), 1);
- }
- _ => panic!("input to projection should be TableScan"),
- },
- _ => panic!("expect optimized_plan to be projection"),
- }
-
- let expected = "Projection: #test.c2\
- \n TableScan: test projection=Some([1])";
- assert_eq!(format!("{:?}", optimized_plan), expected);
-
- let physical_plan = ctx.create_physical_plan(&optimized_plan).await?;
-
- assert_eq!(1, physical_plan.schema().fields().len());
- assert_eq!("c2", physical_plan.schema().field(0).name());
-
- let batches = collect(physical_plan, runtime).await?;
- assert_eq!(40, batches.iter().map(|x| x.num_rows()).sum::<usize>());
-
- Ok(())
- }
-
- #[tokio::test]
- async fn preserve_nullability_on_projection() -> Result<()> {
- let tmp_dir = TempDir::new()?;
- let ctx = create_ctx(&tmp_dir, 1).await?;
-
- let schema: Schema = ctx.table("test").unwrap().schema().clone().into();
- assert!(!schema.field_with_name("c1")?.is_nullable());
-
- let plan = LogicalPlanBuilder::scan_empty(None, &schema, None)?
- .project(vec![col("c1")])?
- .build()?;
-
- let plan = ctx.optimize(&plan)?;
- let physical_plan = ctx.create_physical_plan(&Arc::new(plan)).await?;
- assert!(!physical_plan.schema().field_with_name("c1")?.is_nullable());
- Ok(())
- }
-
- #[tokio::test]
- async fn projection_on_memory_scan() -> Result<()> {
- let schema = Schema::new(vec![
- Field::new("a", DataType::Int32, false),
- Field::new("b", DataType::Int32, false),
- Field::new("c", DataType::Int32, false),
- ]);
- let schema = SchemaRef::new(schema);
-
- let partitions = vec![vec![RecordBatch::try_new(
- schema.clone(),
- vec![
- Arc::new(Int32Array::from_slice(&[1, 10, 10, 100])),
- Arc::new(Int32Array::from_slice(&[2, 12, 12, 120])),
- Arc::new(Int32Array::from_slice(&[3, 12, 12, 120])),
- ],
- )?]];
-
- let plan = LogicalPlanBuilder::scan_memory(partitions, schema, None)?
- .project(vec![col("b")])?
- .build()?;
- assert_fields_eq(&plan, vec!["b"]);
-
- let ctx = ExecutionContext::new();
- let optimized_plan = ctx.optimize(&plan)?;
- match &optimized_plan {
- LogicalPlan::Projection(Projection { input, .. }) => match &**input {
- LogicalPlan::TableScan(TableScan {
- source,
- projected_schema,
- ..
- }) => {
- assert_eq!(source.schema().fields().len(), 3);
- assert_eq!(projected_schema.fields().len(), 1);
- }
- _ => panic!("input to projection should be InMemoryScan"),
- },
- _ => panic!("expect optimized_plan to be projection"),
- }
-
- let expected = format!(
- "Projection: #{}.b\
- \n TableScan: {} projection=Some([1])",
- UNNAMED_TABLE, UNNAMED_TABLE
- );
- assert_eq!(format!("{:?}", optimized_plan), expected);
-
- let physical_plan = ctx.create_physical_plan(&optimized_plan).await?;
-
- assert_eq!(1, physical_plan.schema().fields().len());
- assert_eq!("b", physical_plan.schema().field(0).name());
-
- let runtime = ctx.state.lock().runtime_env.clone();
- let batches = collect(physical_plan, runtime).await?;
- assert_eq!(1, batches.len());
- assert_eq!(1, batches[0].num_columns());
- assert_eq!(4, batches[0].num_rows());
-
- Ok(())
- }
-
- #[tokio::test]
async fn sort() -> Result<()> {
let results =
execute("SELECT c1, c2 FROM test ORDER BY c1 DESC, c2 ASC", 4).await?;
@@ -3684,7 +3409,6 @@ mod tests {
let ids = Arc::new(Int32Array::from_slice(&[i as i32]));
let names = Arc::new(Utf8Array::<i32>::from_slice(&["test"]));
let schema_ref = schema.as_ref();
- let parquet_schema = to_parquet_schema(schema_ref).unwrap();
let iter = vec![Ok(Chunk::new(vec![ids as ArrayRef, names as ArrayRef]))];
let row_groups = RowGroupIterator::try_new(
iter.into_iter(),
@@ -3693,16 +3417,14 @@ mod tests {
vec![Encoding::Plain, Encoding::Plain],
)
.unwrap();
-
- let _ = write_file(
- &mut file,
- row_groups,
- schema_ref,
- parquet_schema,
- options,
- None,
- )
- .unwrap();
+ let mut writer =
+ FileWriter::try_new(&mut file, schema_ref.clone(), options).unwrap();
+ writer.start().unwrap();
+ for rg in row_groups {
+ let (group, len) = rg.unwrap();
+ writer.write(group, len).unwrap();
+ }
+ writer.end(None).unwrap();
}
}
@@ -3718,6 +3440,173 @@ mod tests {
assert_eq!(result[0].schema().metadata(), result[1].schema().metadata());
}
+ #[tokio::test]
+ async fn normalized_column_identifiers() {
+ // create local execution context
+ let mut ctx = ExecutionContext::new();
+
+ // register csv file with the execution context
+ ctx.register_csv(
+ "case_insensitive_test",
+ "tests/example.csv",
+ CsvReadOptions::new(),
+ )
+ .await
+ .unwrap();
+
+ let sql = "SELECT A, b FROM case_insensitive_test";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| a | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql = "SELECT t.A, b FROM case_insensitive_test AS t";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| a | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ // Aliases
+
+ let sql = "SELECT t.A as x, b FROM case_insensitive_test AS t";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| x | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql = "SELECT t.A AS X, b FROM case_insensitive_test AS t";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| x | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t"#;
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| X | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ // Order by
+
+ let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY x";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| x | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql = "SELECT t.A AS x, b FROM case_insensitive_test AS t ORDER BY X";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| x | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql = r#"SELECT t.A AS "X", b FROM case_insensitive_test AS t ORDER BY "X""#;
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| X | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ // Where
+
+ let sql = "SELECT a, b FROM case_insensitive_test where A IS NOT null";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| a | b |",
+ "+---+---+",
+ "| 1 | 2 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ // Group by
+
+ let sql = "SELECT a as x, count(*) as c FROM case_insensitive_test GROUP BY X";
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| x | c |",
+ "+---+---+",
+ "| 1 | 1 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+
+ let sql =
+ r#"SELECT a as "X", count(*) as c FROM case_insensitive_test GROUP BY "X""#;
+ let result = plan_and_collect(&mut ctx, sql)
+ .await
+ .expect("ran plan correctly");
+ let expected = vec![
+ "+---+---+",
+ "| X | c |",
+ "+---+---+",
+ "| 1 | 1 |",
+ "+---+---+",
+ ];
+ assert_batches_sorted_eq!(expected, &result);
+ }
+
struct MyPhysicalPlanner {}
#[async_trait]
diff --git a/datafusion/src/execution/dataframe_impl.rs b/datafusion/src/execution/dataframe_impl.rs
index 1ad9595..73252f3 100644
--- a/datafusion/src/execution/dataframe_impl.rs
+++ b/datafusion/src/execution/dataframe_impl.rs
@@ -324,12 +324,12 @@ mod tests {
use super::*;
use crate::execution::options::CsvReadOptions;
- use crate::physical_plan::functions::ScalarFunctionImplementation;
- use crate::physical_plan::functions::Volatility;
use crate::physical_plan::{window_functions, ColumnarValue};
use crate::{assert_batches_sorted_eq, execution::context::ExecutionContext};
use crate::{logical_plan::*, test_util};
use arrow::datatypes::DataType;
+ use datafusion_expr::ScalarFunctionImplementation;
+ use datafusion_expr::Volatility;
#[tokio::test]
async fn select_columns() -> Result<()> {
diff --git a/datafusion/src/field_util.rs b/datafusion/src/field_util.rs
index 2dfccb7..4ad7990 100644
--- a/datafusion/src/field_util.rs
+++ b/datafusion/src/field_util.rs
@@ -15,476 +15,6 @@
// specific language governing permissions and limitations
// under the License.
-//! Utility functions for complex field access
+//! Field utilities re-exported from datafusion-common
-use arrow::array::{ArrayRef, StructArray};
-use arrow::datatypes::{DataType, Field, Metadata, Schema};
-use arrow::error::ArrowError;
-use std::borrow::Borrow;
-use std::collections::BTreeMap;
-
-use crate::error::{DataFusionError, Result};
-use crate::scalar::ScalarValue;
-
-/// Returns the field access indexed by `key` from a [`DataType::List`] or [`DataType::Struct`]
- /// # Errors
- /// Errors if
- /// * the `data_type` is not a `Struct` or a `List`, or
- /// * there is no field named `key`, or the key is not of the required index type
-pub fn get_indexed_field(data_type: &DataType, key: &ScalarValue) -> Result<Field> {
- match (data_type, key) {
- (DataType::List(lt), ScalarValue::Int64(Some(i))) => {
- if *i < 0 {
- Err(DataFusionError::Plan(format!(
- "List based indexed access requires a positive int, was {0}",
- i
- )))
- } else {
- Ok(Field::new(&i.to_string(), lt.data_type().clone(), false))
- }
- }
- (DataType::Struct(fields), ScalarValue::Utf8(Some(s))) => {
- if s.is_empty() {
- Err(DataFusionError::Plan(
- "Struct based indexed access requires a non empty string".to_string(),
- ))
- } else {
- let field = fields.iter().find(|f| f.name() == s);
- match field {
- None => Err(DataFusionError::Plan(format!(
- "Field {} not found in struct",
- s
- ))),
- Some(f) => Ok(f.clone()),
- }
- }
- }
- (DataType::Struct(_), _) => Err(DataFusionError::Plan(
- "Only utf8 strings are valid as an indexed field in a struct".to_string(),
- )),
- (DataType::List(_), _) => Err(DataFusionError::Plan(
- "Only ints are valid as an indexed field in a list".to_string(),
- )),
- _ => Err(DataFusionError::Plan(
- "The expression to get an indexed field is only valid for `List` types"
- .to_string(),
- )),
- }
-}
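For reference, a hedged usage sketch of the helper being removed here (it now lives in `datafusion-common`); the types follow the arrow2-style `DataType::List(Box<Field>)` shape used on this branch:

```rust
// Sketch only; assumes the arrow2 type shapes used on this branch.
use arrow::datatypes::{DataType, Field};
use datafusion::error::Result;
use datafusion::field_util::{get_indexed_field, FieldExt};
use datafusion::scalar::ScalarValue;

fn list_access_example() -> Result<()> {
    let list_ty = DataType::List(Box::new(Field::new("item", DataType::Int64, true)));
    let field = get_indexed_field(&list_ty, &ScalarValue::Int64(Some(0)))?;
    assert_eq!(field.name(), "0"); // list access yields a field named by index
    assert_eq!(field.data_type(), &DataType::Int64);
    Ok(())
}
```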
-
-/// Imitate arrow-rs StructArray behavior by extending arrow2 StructArray
-pub trait StructArrayExt {
- /// Return field names in this struct array
- fn column_names(&self) -> Vec<&str>;
- /// Return the child array whose field name equals `column_name`
- fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef>;
- /// Return the number of fields in this struct array
- fn num_columns(&self) -> usize;
- /// Return the column at the given position
- fn column(&self, pos: usize) -> ArrayRef;
-}
-
-impl StructArrayExt for StructArray {
- fn column_names(&self) -> Vec<&str> {
- self.fields().iter().map(|f| f.name.as_str()).collect()
- }
-
- fn column_by_name(&self, column_name: &str) -> Option<&ArrayRef> {
- self.fields()
- .iter()
- .position(|c| c.name() == column_name)
- .map(|pos| self.values()[pos].borrow())
- }
-
- fn num_columns(&self) -> usize {
- self.fields().len()
- }
-
- fn column(&self, pos: usize) -> ArrayRef {
- self.values()[pos].clone()
- }
-}
-
-/// Converts a list of field / array pairs to a struct array
-pub fn struct_array_from(pairs: Vec<(Field, ArrayRef)>) -> StructArray {
- let fields: Vec<Field> = pairs.iter().map(|v| v.0.clone()).collect();
- let values = pairs.iter().map(|v| v.1.clone()).collect();
- StructArray::from_data(DataType::Struct(fields), values, None)
-}
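And a hedged sketch of `struct_array_from` together with the `StructArrayExt` accessors above (the array constructor is assumed to match arrow2's API under this branch's `arrow` alias):

```rust
// Sketch only; Int64Array::from_slice is assumed from arrow2.
use arrow::array::{ArrayRef, Int64Array};
use arrow::datatypes::{DataType, Field};
use datafusion::field_util::{struct_array_from, StructArrayExt};
use std::sync::Arc;

fn struct_example() {
    let field = Field::new("a", DataType::Int64, false);
    let array: ArrayRef = Arc::new(Int64Array::from_slice(&[1, 2, 3]));
    let s = struct_array_from(vec![(field, array)]);
    assert_eq!(s.num_columns(), 1);
    assert_eq!(s.column_names(), vec!["a"]);
}
```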
-
-/// Imitate arrow-rs Schema behavior by extending arrow2 Schema
-pub trait SchemaExt {
- /// Creates a new [`Schema`] from a sequence of [`Field`] values.
- ///
- /// # Example
- ///
- /// ```
- /// use arrow::datatypes::{Field, DataType, Schema};
- /// use datafusion::field_util::SchemaExt;
- /// let field_a = Field::new("a", DataType::Int64, false);
- /// let field_b = Field::new("b", DataType::Boolean, false);
- ///
- /// let schema = Schema::new(vec![field_a, field_b]);
- /// ```
- fn new(fields: Vec<Field>) -> Self;
-
- /// Creates a new [`Schema`] from a sequence of [`Field`] values and [`arrow::datatypes::Metadata`]
- ///
- /// # Example
- ///
- /// ```
- /// use std::collections::BTreeMap;
- /// use arrow::datatypes::{Field, DataType, Schema};
- /// use datafusion::field_util::SchemaExt;
- ///
- /// let field_a = Field::new("a", DataType::Int64, false);
- /// let field_b = Field::new("b", DataType::Boolean, false);
- ///
- /// let schema_metadata: BTreeMap<String, String> =
- /// vec![("baz".to_string(), "barf".to_string())]
- /// .into_iter()
- /// .collect();
- /// let schema = Schema::new_with_metadata(vec![field_a, field_b], schema_metadata);
- /// ```
- fn new_with_metadata(fields: Vec<Field>, metadata: Metadata) -> Self;
-
- /// Creates an empty [`Schema`].
- fn empty() -> Self;
-
- /// Look up a column by name and return an immutable reference to the column along with
- /// its index.
- fn column_with_name(&self, name: &str) -> Option<(usize, &Field)>;
-
- /// Returns the first [`Field`] named `name`.
- fn field_with_name(&self, name: &str) -> Result<&Field>;
-
- /// Find the index of the column with the given name.
- fn index_of(&self, name: &str) -> Result<usize>;
-
- /// Returns the [`Field`] at position `index`.
- /// # Panics
- /// Panics if `index` is greater than or equal to the number of fields in this [`Schema`].
- fn field(&self, index: usize) -> &Field;
-
- /// Returns all [`Field`]s in this schema.
- fn fields(&self) -> &[Field];
-
- /// Returns an immutable reference to the Map of custom metadata key-value pairs.
- fn metadata(&self) -> &BTreeMap<String, String>;
-
- /// Merge schema into self if it is compatible. Struct fields will be merged recursively.
- ///
- /// Example:
- ///
- /// ```
- /// use arrow::datatypes::*;
- /// use datafusion::field_util::SchemaExt;
- ///
- /// let merged = Schema::try_merge(vec![
- /// Schema::new(vec![
- /// Field::new("c1", DataType::Int64, false),
- /// Field::new("c2", DataType::Utf8, false),
- /// ]),
- /// Schema::new(vec![
- /// Field::new("c1", DataType::Int64, true),
- /// Field::new("c2", DataType::Utf8, false),
- /// Field::new("c3", DataType::Utf8, false),
- /// ]),
- /// ]).unwrap();
- ///
- /// assert_eq!(
- /// merged,
- /// Schema::new(vec![
- /// Field::new("c1", DataType::Int64, true),
- /// Field::new("c2", DataType::Utf8, false),
- /// Field::new("c3", DataType::Utf8, false),
- /// ]),
- /// );
- /// ```
- fn try_merge(schemas: impl IntoIterator<Item = Self>) -> Result<Self>
- where
- Self: Sized;
-
- /// Return the field names
- fn field_names(&self) -> Vec<String>;
-
- /// Returns a new schema containing only the specified columns.
- /// Metadata from the parent schema is carried over as well.
- fn project(&self, indices: &[usize]) -> Result<Schema>;
-}
-
-impl SchemaExt for Schema {
- fn new(fields: Vec<Field>) -> Self {
- Self::from(fields)
- }
-
- fn new_with_metadata(fields: Vec<Field>, metadata: Metadata) -> Self {
- Self::new(fields).with_metadata(metadata)
- }
-
- fn empty() -> Self {
- Self::from(vec![])
- }
-
- fn column_with_name(&self, name: &str) -> Option<(usize, &Field)> {
- self.fields.iter().enumerate().find(|(_, f)| f.name == name)
- }
-
- fn field_with_name(&self, name: &str) -> Result<&Field> {
- Ok(&self.fields[self.index_of(name)?])
- }
-
- fn index_of(&self, name: &str) -> Result<usize> {
- self.column_with_name(name).map(|(i, _f)| i).ok_or_else(|| {
- DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!(
- "Unable to get field named \"{}\". Valid fields: {:?}",
- name,
- self.field_names()
- )))
- })
- }
-
- fn field(&self, index: usize) -> &Field {
- &self.fields[index]
- }
-
- #[inline]
- fn fields(&self) -> &[Field] {
- &self.fields
- }
-
- #[inline]
- fn metadata(&self) -> &BTreeMap<String, String> {
- &self.metadata
- }
-
- fn try_merge(schemas: impl IntoIterator<Item = Self>) -> Result<Self> {
- schemas
- .into_iter()
- .try_fold(Self::empty(), |mut merged, schema| {
- let Schema { metadata, fields } = schema;
- for (key, value) in metadata.into_iter() {
- // merge metadata
- if let Some(old_val) = merged.metadata.get(&key) {
- if old_val != &value {
- return Err(DataFusionError::ArrowError(
- ArrowError::InvalidArgumentError(
- "Fail to merge schema due to conflicting metadata."
- .to_string(),
- ),
- ));
- }
- }
- merged.metadata.insert(key, value);
- }
- // merge fields
- for field in fields.into_iter() {
- let mut new_field = true;
- for merged_field in &mut merged.fields {
- if field.name() != merged_field.name() {
- continue;
- }
- new_field = false;
- merged_field.try_merge(&field)?
- }
- // found a new field, add to field list
- if new_field {
- merged.fields.push(field);
- }
- }
- Ok(merged)
- })
- }
-
- fn field_names(&self) -> Vec<String> {
- self.fields.iter().map(|f| f.name.to_string()).collect()
- }
-
- fn project(&self, indices: &[usize]) -> Result<Schema> {
- let new_fields = indices
- .iter()
- .map(|i| {
- self.fields.get(*i).cloned().ok_or_else(|| {
- DataFusionError::ArrowError(ArrowError::InvalidArgumentError(
- format!(
- "project index {} out of bounds, max field {}",
- i,
- self.fields().len()
- ),
- ))
- })
- })
- .collect::<Result<Vec<_>>>()?;
- Ok(Self::new_with_metadata(new_fields, self.metadata.clone()))
- }
-}
-
-/// Imitate arrow-rs Field behavior by extending arrow2 Field
-pub trait FieldExt {
- /// The field name
- fn name(&self) -> &str;
-
- /// Whether the field is nullable
- fn is_nullable(&self) -> bool;
-
- /// Returns the field metadata
- fn metadata(&self) -> &BTreeMap<String, String>;
-
- /// Merge field into self if it is compatible. Struct will be merged recursively.
- /// NOTE: `self` may be left in an unexpected state if the merge fails.
- ///
- /// Example:
- ///
- /// ```
- /// use arrow2::datatypes::*;
- ///
- /// let mut field = Field::new("c1", DataType::Int64, false);
- /// assert!(field.try_merge(&Field::new("c1", DataType::Int64, true)).is_ok());
- /// assert!(field.is_nullable());
- /// ```
- fn try_merge(&mut self, from: &Field) -> Result<()>;
-
- /// Sets the `Field`'s optional custom metadata.
- /// An empty or `None` map leaves the existing metadata unchanged.
- fn set_metadata(&mut self, metadata: Option<BTreeMap<String, String>>);
-}
-
-impl FieldExt for Field {
- #[inline]
- fn name(&self) -> &str {
- &self.name
- }
-
- #[inline]
- fn is_nullable(&self) -> bool {
- self.is_nullable
- }
-
- #[inline]
- fn metadata(&self) -> &BTreeMap<String, String> {
- &self.metadata
- }
-
- fn try_merge(&mut self, from: &Field) -> Result<()> {
- // merge metadata
- for (key, from_value) in from.metadata() {
- if let Some(self_value) = self.metadata.get(key) {
- if self_value != from_value {
- return Err(DataFusionError::ArrowError(ArrowError::InvalidArgumentError(format!(
- "Fail to merge field due to conflicting metadata data value for key {}",
- key
- ))));
- }
- } else {
- self.metadata.insert(key.clone(), from_value.clone());
- }
- }
-
- match &mut self.data_type {
- DataType::Struct(nested_fields) => match &from.data_type {
- DataType::Struct(from_nested_fields) => {
- for from_field in from_nested_fields {
- let mut is_new_field = true;
- for self_field in nested_fields.iter_mut() {
- if self_field.name != from_field.name {
- continue;
- }
- is_new_field = false;
- self_field.try_merge(from_field)?;
- }
- if is_new_field {
- nested_fields.push(from_field.clone());
- }
- }
- }
- _ => {
- return Err(DataFusionError::ArrowError(
- ArrowError::InvalidArgumentError(
- "Fail to merge schema Field due to conflicting datatype"
- .to_string(),
- ),
- ));
- }
- },
- DataType::Union(nested_fields, _, _) => match &from.data_type {
- DataType::Union(from_nested_fields, _, _) => {
- for from_field in from_nested_fields {
- let mut is_new_field = true;
- for self_field in nested_fields.iter_mut() {
- if from_field == self_field {
- is_new_field = false;
- break;
- }
- }
- if is_new_field {
- nested_fields.push(from_field.clone());
- }
- }
- }
- _ => {
- return Err(DataFusionError::ArrowError(
- ArrowError::InvalidArgumentError(
- "Fail to merge schema Field due to conflicting datatype"
- .to_string(),
- ),
- ));
- }
- },
- DataType::Null
- | DataType::Boolean
- | DataType::Int8
- | DataType::Int16
- | DataType::Int32
- | DataType::Int64
- | DataType::UInt8
- | DataType::UInt16
- | DataType::UInt32
- | DataType::UInt64
- | DataType::Float16
- | DataType::Float32
- | DataType::Float64
- | DataType::Timestamp(_, _)
- | DataType::Date32
- | DataType::Date64
- | DataType::Time32(_)
- | DataType::Time64(_)
- | DataType::Duration(_)
- | DataType::Binary
- | DataType::LargeBinary
- | DataType::Interval(_)
- | DataType::LargeList(_)
- | DataType::List(_)
- | DataType::Dictionary(_, _, _)
- | DataType::FixedSizeList(_, _)
- | DataType::FixedSizeBinary(_)
- | DataType::Utf8
- | DataType::LargeUtf8
- | DataType::Extension(_, _, _)
- | DataType::Map(_, _)
- | DataType::Decimal(_, _) => {
- if self.data_type != from.data_type {
- return Err(DataFusionError::ArrowError(
- ArrowError::InvalidArgumentError(
- "Fail to merge schema Field due to conflicting datatype"
- .to_string(),
- ),
- ));
- }
- }
- }
- if from.is_nullable {
- self.is_nullable = from.is_nullable;
- }
-
- Ok(())
- }
-
- #[inline]
- fn set_metadata(&mut self, metadata: Option<BTreeMap<String, String>>) {
- if let Some(v) = metadata {
- if !v.is_empty() {
- self.metadata = v;
- }
- }
- }
-}
+pub use datafusion_common::field_util::*;
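Since the module body is replaced by a glob re-export from `datafusion-common`, existing call sites should keep compiling unchanged, e.g.:

```rust
// Unchanged import path (assuming the moved items keep their names):
use datafusion::field_util::{get_indexed_field, FieldExt, SchemaExt, StructArrayExt};
```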
diff --git a/datafusion/src/lib.rs b/datafusion/src/lib.rs
index 6b83980..682675e 100644
--- a/datafusion/src/lib.rs
+++ b/datafusion/src/lib.rs
@@ -160,8 +160,8 @@
//! * Sort: [`SortExec`](physical_plan::sort::SortExec)
//! * Coalesce partitions: [`CoalescePartitionsExec`](physical_plan::coalesce_partitions::CoalescePartitionsExec)
//! * Limit: [`LocalLimitExec`](physical_plan::limit::LocalLimitExec) and [`GlobalLimitExec`](physical_plan::limit::GlobalLimitExec)
-//! * Scan a CSV: [`CsvExec`](physical_plan::csv::CsvExec)
-//! * Scan a Parquet: [`ParquetExec`](physical_plan::parquet::ParquetExec)
+//! * Scan a CSV: [`CsvExec`](physical_plan::file_format::CsvExec)
+//! * Scan a Parquet: [`ParquetExec`](physical_plan::file_format::ParquetExec)
//! * Scan from memory: [`MemoryExec`](physical_plan::memory::MemoryExec)
//! * Explain the plan: [`ExplainExec`](physical_plan::explain::ExplainExec)
//!
diff --git a/datafusion/src/logical_plan/builder.rs b/datafusion/src/logical_plan/builder.rs
index 3a64f76..db4573c 100644
--- a/datafusion/src/logical_plan/builder.rs
+++ b/datafusion/src/logical_plan/builder.rs
@@ -26,6 +26,7 @@ use crate::datasource::{
};
use crate::error::{DataFusionError, Result};
use crate::field_util::SchemaExt;
+use crate::logical_plan::expr_schema::ExprSchemable;
use crate::logical_plan::plan::{
Aggregate, Analyze, EmptyRelation, Explain, Filter, Join, Projection, Sort,
TableScan, ToStringifiedPlan, Union, Window,
@@ -594,6 +595,17 @@ impl LogicalPlanBuilder {
self.join_detailed(right, join_type, join_keys, false)
}
+ fn normalize(
+ plan: &LogicalPlan,
+ column: impl Into<Column> + Clone,
+ ) -> Result<Column> {
+ let schemas = plan.all_schemas();
+ let using_columns = plan.using_columns()?;
+ column
+ .into()
+ .normalize_with_schemas(&schemas, &using_columns)
+ }
+
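This private helper replaces the removed `Column::normalize` method by delegating to `normalize_with_schemas`; a hedged sketch of a call site inside the builder (anything `Into<Column>`, including `&str`, is accepted):

```rust
// Sketch only: both forms resolve "id" against self.plan's schemas.
let key: Column = Self::normalize(&self.plan, "id")?;
let key2: Column = Self::normalize(&self.plan, Column::from_name("id"))?;
```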
/// Apply a join with on constraint and specified null equality
/// If null_equals_null is true then null == null, else null != null
pub fn join_detailed(
@@ -632,7 +644,10 @@ impl LogicalPlanBuilder {
match (l_is_left, l_is_right, r_is_left, r_is_right) {
(_, Ok(_), Ok(_), _) => (Ok(r), Ok(l)),
(Ok(_), _, _, Ok(_)) => (Ok(l), Ok(r)),
- _ => (l.normalize(&self.plan), r.normalize(right)),
+ _ => (
+ Self::normalize(&self.plan, l),
+ Self::normalize(right, r),
+ ),
}
}
(Some(lr), None) => {
@@ -642,9 +657,12 @@ impl LogicalPlanBuilder {
right.schema().field_with_qualified_name(lr, &l.name);
match (l_is_left, l_is_right) {
- (Ok(_), _) => (Ok(l), r.normalize(right)),
- (_, Ok(_)) => (r.normalize(&self.plan), Ok(l)),
- _ => (l.normalize(&self.plan), r.normalize(right)),
+ (Ok(_), _) => (Ok(l), Self::normalize(right, r)),
+ (_, Ok(_)) => (Self::normalize(&self.plan, r), Ok(l)),
+ _ => (
+ Self::normalize(&self.plan, l),
+ Self::normalize(right, r),
+ ),
}
}
(None, Some(rr)) => {
@@ -654,22 +672,25 @@ impl LogicalPlanBuilder {
right.schema().field_with_qualified_name(rr, &r.name);
match (r_is_left, r_is_right) {
- (Ok(_), _) => (Ok(r), l.normalize(right)),
- (_, Ok(_)) => (l.normalize(&self.plan), Ok(r)),
- _ => (l.normalize(&self.plan), r.normalize(right)),
+ (Ok(_), _) => (Ok(r), Self::normalize(right, l)),
+ (_, Ok(_)) => (Self::normalize(&self.plan, l), Ok(r)),
+ _ => (
+ Self::normalize(&self.plan, l),
+ Self::normalize(right, r),
+ ),
}
}
(None, None) => {
let mut swap = false;
- let left_key =
- l.clone().normalize(&self.plan).or_else(|_| {
+ let left_key = Self::normalize(&self.plan, l.clone())
+ .or_else(|_| {
swap = true;
- l.normalize(right)
+ Self::normalize(right, l)
});
if swap {
- (r.normalize(&self.plan), left_key)
+ (Self::normalize(&self.plan, r), left_key)
} else {
- (left_key, r.normalize(right))
+ (left_key, Self::normalize(right, r))
}
}
}
@@ -704,11 +725,11 @@ impl LogicalPlanBuilder {
let left_keys: Vec<Column> = using_keys
.clone()
.into_iter()
- .map(|c| c.into().normalize(&self.plan))
+ .map(|c| Self::normalize(&self.plan, c))
.collect::<Result<_>>()?;
let right_keys: Vec<Column> = using_keys
.into_iter()
- .map(|c| c.into().normalize(right))
+ .map(|c| Self::normalize(right, c))
.collect::<Result<_>>()?;
let on: Vec<(_, _)> = left_keys.into_iter().zip(right_keys.into_iter()).collect();
diff --git a/datafusion/src/logical_plan/dfschema.rs b/datafusion/src/logical_plan/dfschema.rs
index b89b239..eb62428 100644
--- a/datafusion/src/logical_plan/dfschema.rs
+++ b/datafusion/src/logical_plan/dfschema.rs
@@ -18,669 +18,4 @@
//! DFSchema is an extended schema struct that DataFusion uses to provide support for
//! fields with optional relation names.
-use std::collections::HashSet;
-use std::convert::TryFrom;
-use std::sync::Arc;
-
-use crate::error::{DataFusionError, Result};
-use crate::logical_plan::Column;
-
-use crate::field_util::{FieldExt, SchemaExt};
-use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
-use std::fmt::{Display, Formatter};
-
-/// A reference-counted reference to a `DFSchema`.
-pub type DFSchemaRef = Arc<DFSchema>;
-
-/// DFSchema wraps an Arrow schema and adds relation names
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct DFSchema {
- /// Fields
- fields: Vec<DFField>,
-}
-
-impl DFSchema {
- /// Creates an empty `DFSchema`
- pub fn empty() -> Self {
- Self { fields: vec![] }
- }
-
- /// Create a new `DFSchema`
- pub fn new(fields: Vec<DFField>) -> Result<Self> {
- let mut qualified_names = HashSet::new();
- let mut unqualified_names = HashSet::new();
-
- for field in &fields {
- if let Some(qualifier) = field.qualifier() {
- if !qualified_names.insert((qualifier, field.name())) {
- return Err(DataFusionError::Plan(format!(
- "Schema contains duplicate qualified field name '{}'",
- field.qualified_name()
- )));
- }
- } else if !unqualified_names.insert(field.name()) {
- return Err(DataFusionError::Plan(format!(
- "Schema contains duplicate unqualified field name '{}'",
- field.name()
- )));
- }
- }
-
- // check for mix of qualified and unqualified field with same unqualified name
- // note that we need to sort the contents of the HashSet first so that errors are
- // deterministic
- let mut qualified_names = qualified_names
- .iter()
- .map(|(l, r)| (l.as_str(), r.to_owned()))
- .collect::<Vec<(&str, &str)>>();
- qualified_names.sort_by(|a, b| {
- let a = format!("{}.{}", a.0, a.1);
- let b = format!("{}.{}", b.0, b.1);
- a.cmp(&b)
- });
- for (qualifier, name) in &qualified_names {
- if unqualified_names.contains(name) {
- return Err(DataFusionError::Plan(format!(
- "Schema contains qualified field name '{}.{}' \
- and unqualified field name '{}' which would be ambiguous",
- qualifier, name, name
- )));
- }
- }
- Ok(Self { fields })
- }
-
- /// Create a `DFSchema` from an Arrow schema
- pub fn try_from_qualified_schema(qualifier: &str, schema: &Schema) -> Result<Self> {
- Self::new(
- schema
- .fields()
- .iter()
- .map(|f| DFField::from_qualified(qualifier, f.clone()))
- .collect(),
- )
- }
-
- /// Combine two schemas
- pub fn join(&self, schema: &DFSchema) -> Result<Self> {
- let mut fields = self.fields.clone();
- fields.extend_from_slice(schema.fields().as_slice());
- Self::new(fields)
- }
-
- /// Merge a schema into self
- pub fn merge(&mut self, other_schema: &DFSchema) {
- for field in other_schema.fields() {
- // skip duplicate columns
- let duplicated_field = match field.qualifier() {
- Some(q) => self.field_with_name(Some(q.as_str()), field.name()).is_ok(),
- // for unqualified columns, check as unqualified name
- None => self.field_with_unqualified_name(field.name()).is_ok(),
- };
- if !duplicated_field {
- self.fields.push(field.clone());
- }
- }
- }
-
- /// Get a list of fields
- pub fn fields(&self) -> &Vec<DFField> {
- &self.fields
- }
-
- /// Returns an immutable reference to the `DFField` at the given
- /// offset within the internal `fields` vector
- pub fn field(&self, i: usize) -> &DFField {
- &self.fields[i]
- }
-
- /// Find the index of the column with the given unqualified name
- pub fn index_of(&self, name: &str) -> Result<usize> {
- for i in 0..self.fields.len() {
- if self.fields[i].name() == name {
- return Ok(i);
- }
- }
- Err(DataFusionError::Plan(format!(
- "No field named '{}'. Valid fields are {}.",
- name,
- self.get_field_names()
- )))
- }
-
- fn index_of_column_by_name(
- &self,
- qualifier: Option<&str>,
- name: &str,
- ) -> Result<usize> {
- let mut matches = self
- .fields
- .iter()
- .enumerate()
- .filter(|(_, field)| match (qualifier, &field.qualifier) {
- // field to lookup is qualified.
- // current field is qualified and not shared between relations, compare both
- // qualifier and name.
- (Some(q), Some(field_q)) => q == field_q && field.name() == name,
- // field to lookup is qualified but current field is unqualified.
- (Some(_), None) => false,
- // field to lookup is unqualified, no need to compare qualifier
- (None, Some(_)) | (None, None) => field.name() == name,
- })
- .map(|(idx, _)| idx);
- match matches.next() {
- None => Err(DataFusionError::Plan(format!(
- "No field named '{}.{}'. Valid fields are {}.",
- qualifier.unwrap_or("<unqualified>"),
- name,
- self.get_field_names()
- ))),
- Some(idx) => match matches.next() {
- None => Ok(idx),
- // found more than one matches
- Some(_) => Err(DataFusionError::Internal(format!(
- "Ambiguous reference to qualified field named '{}.{}'",
- qualifier.unwrap_or("<unqualified>"),
- name
- ))),
- },
- }
- }
-
- /// Find the index of the column with the given qualifier and name
- pub fn index_of_column(&self, col: &Column) -> Result<usize> {
- self.index_of_column_by_name(col.relation.as_deref(), &col.name)
- }
-
- /// Find the field with the given name
- pub fn field_with_name(
- &self,
- qualifier: Option<&str>,
- name: &str,
- ) -> Result<&DFField> {
- if let Some(qualifier) = qualifier {
- self.field_with_qualified_name(qualifier, name)
- } else {
- self.field_with_unqualified_name(name)
- }
- }
-
- /// Find all fields matching the given name
- pub fn fields_with_unqualified_name(&self, name: &str) -> Vec<&DFField> {
- self.fields
- .iter()
- .filter(|field| field.name() == name)
- .collect()
- }
-
- /// Find the field with the given name
- pub fn field_with_unqualified_name(&self, name: &str) -> Result<&DFField> {
- let matches = self.fields_with_unqualified_name(name);
- match matches.len() {
- 0 => Err(DataFusionError::Plan(format!(
- "No field with unqualified name '{}'. Valid fields are {}.",
- name,
- self.get_field_names()
- ))),
- 1 => Ok(matches[0]),
- _ => Err(DataFusionError::Plan(format!(
- "Ambiguous reference to field named '{}'",
- name
- ))),
- }
- }
-
- /// Find the field with the given qualified name
- pub fn field_with_qualified_name(
- &self,
- qualifier: &str,
- name: &str,
- ) -> Result<&DFField> {
- let idx = self.index_of_column_by_name(Some(qualifier), name)?;
- Ok(self.field(idx))
- }
-
- /// Find the field with the given qualified column
- pub fn field_from_column(&self, column: &Column) -> Result<&DFField> {
- match &column.relation {
- Some(r) => self.field_with_qualified_name(r, &column.name),
- None => self.field_with_unqualified_name(&column.name),
- }
- }
-
- /// Check whether the unqualified field names match the field names in the Arrow schema
- pub fn matches_arrow_schema(&self, arrow_schema: &Schema) -> bool {
- self.fields
- .iter()
- .zip(arrow_schema.fields().iter())
- .all(|(dffield, arrowfield)| dffield.name() == arrowfield.name())
- }
-
- /// Strip all field qualifiers from the schema
- pub fn strip_qualifiers(self) -> Self {
- DFSchema {
- fields: self
- .fields
- .into_iter()
- .map(|f| f.strip_qualifier())
- .collect(),
- }
- }
-
- /// Replace all field qualifiers with a new value in the schema
- pub fn replace_qualifier(self, qualifier: &str) -> Self {
- DFSchema {
- fields: self
- .fields
- .into_iter()
- .map(|f| {
- DFField::new(
- Some(qualifier),
- f.name(),
- f.data_type().to_owned(),
- f.is_nullable(),
- )
- })
- .collect(),
- }
- }
-
- /// Get a comma-separated list of field names for use in error messages
- fn get_field_names(&self) -> String {
- self.fields
- .iter()
- .map(|f| match f.qualifier() {
- Some(qualifier) => format!("'{}.{}'", qualifier, f.name()),
- None => format!("'{}'", f.name()),
- })
- .collect::<Vec<_>>()
- .join(", ")
- }
-}
-
-impl From<DFSchema> for Schema {
- /// Convert DFSchema into a Schema
- fn from(df_schema: DFSchema) -> Self {
- Schema::new(
- df_schema
- .fields
- .into_iter()
- .map(|f| {
- if f.qualifier().is_some() {
- Field::new(f.name(), f.data_type().to_owned(), f.is_nullable())
- } else {
- f.field
- }
- })
- .collect(),
- )
- }
-}
-
-impl From<&DFSchema> for Schema {
- /// Convert DFSchema reference into a Schema
- fn from(df_schema: &DFSchema) -> Self {
- Schema::new(df_schema.fields.iter().map(|f| f.field.clone()).collect())
- }
-}
-
-/// Create a `DFSchema` from an Arrow schema
-impl TryFrom<Schema> for DFSchema {
- type Error = DataFusionError;
- fn try_from(schema: Schema) -> std::result::Result<Self, Self::Error> {
- Self::new(
- schema
- .fields()
- .iter()
- .map(|f| DFField::from(f.clone()))
- .collect(),
- )
- }
-}
-
-impl From<DFSchema> for SchemaRef {
- fn from(df_schema: DFSchema) -> Self {
- SchemaRef::new(df_schema.into())
- }
-}
-
-/// Convenience trait to convert Schema like things to DFSchema and DFSchemaRef with fewer keystrokes
-pub trait ToDFSchema
-where
- Self: Sized,
-{
- /// Attempt to create a DFSchema
- #[allow(clippy::wrong_self_convention)]
- fn to_dfschema(self) -> Result<DFSchema>;
-
- /// Attempt to create a DFSchemaRef
- #[allow(clippy::wrong_self_convention)]
- fn to_dfschema_ref(self) -> Result<DFSchemaRef> {
- Ok(Arc::new(self.to_dfschema()?))
- }
-}
-
-impl ToDFSchema for Schema {
- #[allow(clippy::wrong_self_convention)]
- fn to_dfschema(self) -> Result<DFSchema> {
- DFSchema::try_from(self)
- }
-}
-
-impl ToDFSchema for SchemaRef {
- #[allow(clippy::wrong_self_convention)]
- fn to_dfschema(self) -> Result<DFSchema> {
- // Attempt to use the Schema directly if there are no other
- // references, otherwise clone
- match Self::try_unwrap(self) {
- Ok(schema) => DFSchema::try_from(schema),
- Err(schemaref) => DFSchema::try_from(schemaref.as_ref().clone()),
- }
- }
-}
-
-impl ToDFSchema for Vec<DFField> {
- fn to_dfschema(self) -> Result<DFSchema> {
- DFSchema::new(self)
- }
-}
-
-impl Display for DFSchema {
- fn fmt(&self, f: &mut Formatter) -> std::fmt::Result {
- write!(
- f,
- "{}",
- self.fields
- .iter()
- .map(|field| field.qualified_name())
- .collect::<Vec<String>>()
- .join(", ")
- )
- }
-}
-
-/// DFField wraps an Arrow field and adds an optional qualifier
-#[derive(Debug, Clone, PartialEq, Eq)]
-pub struct DFField {
- /// Optional qualifier (usually a table or relation name)
- qualifier: Option<String>,
- /// Arrow field definition
- field: Field,
-}
-
-impl DFField {
- /// Creates a new `DFField`
- pub fn new(
- qualifier: Option<&str>,
- name: &str,
- data_type: DataType,
- nullable: bool,
- ) -> Self {
- DFField {
- qualifier: qualifier.map(|s| s.to_owned()),
- field: Field::new(name, data_type, nullable),
- }
- }
-
- /// Create an unqualified field from an existing Arrow field
- pub fn from(field: Field) -> Self {
- Self {
- qualifier: None,
- field,
- }
- }
-
- /// Create a qualified field from an existing Arrow field
- pub fn from_qualified(qualifier: &str, field: Field) -> Self {
- Self {
- qualifier: Some(qualifier.to_owned()),
- field,
- }
- }
-
- /// Returns an immutable reference to the `DFField`'s unqualified name
- pub fn name(&self) -> &str {
- self.field.name()
- }
-
- /// Returns an immutable reference to the `DFField`'s data-type
- pub fn data_type(&self) -> &DataType {
- self.field.data_type()
- }
-
- /// Indicates whether this `DFField` supports null values
- pub fn is_nullable(&self) -> bool {
- self.field.is_nullable()
- }
-
- /// Returns the `DFField`'s qualified name as a string
- pub fn qualified_name(&self) -> String {
- if let Some(qualifier) = &self.qualifier {
- format!("{}.{}", qualifier, self.field.name())
- } else {
- self.field.name().to_owned()
- }
- }
-
- /// Builds a qualified column based on self
- pub fn qualified_column(&self) -> Column {
- Column {
- relation: self.qualifier.clone(),
- name: self.field.name().to_string(),
- }
- }
-
- /// Builds an unqualified column based on self
- pub fn unqualified_column(&self) -> Column {
- Column {
- relation: None,
- name: self.field.name().to_string(),
- }
- }
-
- /// Get the optional qualifier
- pub fn qualifier(&self) -> Option<&String> {
- self.qualifier.as_ref()
- }
-
- /// Get the arrow field
- pub fn field(&self) -> &Field {
- &self.field
- }
-
- /// Return field with qualifier stripped
- pub fn strip_qualifier(mut self) -> Self {
- self.qualifier = None;
- self
- }
-}
-
-#[cfg(test)]
-mod tests {
- use super::*;
- use arrow::datatypes::DataType;
-
- #[test]
- fn from_unqualified_field() {
- let field = Field::new("c0", DataType::Boolean, true);
- let field = DFField::from(field);
- assert_eq!("c0", field.name());
- assert_eq!("c0", field.qualified_name());
- }
-
- #[test]
- fn from_qualified_field() {
- let field = Field::new("c0", DataType::Boolean, true);
- let field = DFField::from_qualified("t1", field);
- assert_eq!("c0", field.name());
- assert_eq!("t1.c0", field.qualified_name());
- }
-
- #[test]
- fn from_unqualified_schema() -> Result<()> {
- let schema = DFSchema::try_from(test_schema_1())?;
- assert_eq!("c0, c1", schema.to_string());
- Ok(())
- }
-
- #[test]
- fn from_qualified_schema() -> Result<()> {
- let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- assert_eq!("t1.c0, t1.c1", schema.to_string());
- Ok(())
- }
-
- #[test]
- fn from_qualified_schema_into_arrow_schema() -> Result<()> {
- let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let arrow_schema: Schema = schema.into();
- let expected =
- "[Field { name: \"c0\", data_type: Boolean, is_nullable: true, metadata: {} }, \
- Field { name: \"c1\", data_type: Boolean, is_nullable: true, metadata: {} }]";
- assert_eq!(expected, format!("{:?}", arrow_schema.fields));
- Ok(())
- }
-
- #[test]
- fn join_qualified() -> Result<()> {
- let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let right = DFSchema::try_from_qualified_schema("t2", &test_schema_1())?;
- let join = left.join(&right)?;
- assert_eq!("t1.c0, t1.c1, t2.c0, t2.c1", join.to_string());
- // test valid access
- assert!(join.field_with_qualified_name("t1", "c0").is_ok());
- assert!(join.field_with_qualified_name("t2", "c0").is_ok());
- // test invalid access
- assert!(join.field_with_unqualified_name("c0").is_err());
- assert!(join.field_with_unqualified_name("t1.c0").is_err());
- assert!(join.field_with_unqualified_name("t2.c0").is_err());
- Ok(())
- }
-
- #[test]
- fn join_qualified_duplicate() -> Result<()> {
- let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let right = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let join = left.join(&right);
- assert!(join.is_err());
- assert_eq!(
- "Error during planning: Schema contains duplicate \
- qualified field name \'t1.c0\'",
- &format!("{}", join.err().unwrap())
- );
- Ok(())
- }
-
- #[test]
- fn join_unqualified_duplicate() -> Result<()> {
- let left = DFSchema::try_from(test_schema_1())?;
- let right = DFSchema::try_from(test_schema_1())?;
- let join = left.join(&right);
- assert!(join.is_err());
- assert_eq!(
- "Error during planning: Schema contains duplicate \
- unqualified field name \'c0\'",
- &format!("{}", join.err().unwrap())
- );
- Ok(())
- }
-
- #[test]
- fn join_mixed() -> Result<()> {
- let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let right = DFSchema::try_from(test_schema_2())?;
- let join = left.join(&right)?;
- assert_eq!("t1.c0, t1.c1, c100, c101", join.to_string());
- // test valid access
- assert!(join.field_with_qualified_name("t1", "c0").is_ok());
- assert!(join.field_with_unqualified_name("c0").is_ok());
- assert!(join.field_with_unqualified_name("c100").is_ok());
- assert!(join.field_with_name(None, "c100").is_ok());
- // test invalid access
- assert!(join.field_with_unqualified_name("t1.c0").is_err());
- assert!(join.field_with_unqualified_name("t1.c100").is_err());
- assert!(join.field_with_qualified_name("", "c100").is_err());
- Ok(())
- }
-
- #[test]
- fn join_mixed_duplicate() -> Result<()> {
- let left = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let right = DFSchema::try_from(test_schema_1())?;
- let join = left.join(&right);
- assert!(join.is_err());
- assert_eq!(
- "Error during planning: Schema contains qualified \
- field name \'t1.c0\' and unqualified field name \'c0\' which would be ambiguous",
- &format!("{}", join.err().unwrap())
- );
- Ok(())
- }
-
- #[test]
- fn helpful_error_messages() -> Result<()> {
- let schema = DFSchema::try_from_qualified_schema("t1", &test_schema_1())?;
- let expected_help = "Valid fields are \'t1.c0\', \'t1.c1\'.";
- assert!(schema
- .field_with_qualified_name("x", "y")
- .unwrap_err()
- .to_string()
- .contains(expected_help));
- assert!(schema
- .field_with_unqualified_name("y")
- .unwrap_err()
- .to_string()
- .contains(expected_help));
- assert!(schema
- .index_of("y")
- .unwrap_err()
- .to_string()
- .contains(expected_help));
- Ok(())
- }
-
- #[test]
- fn into() {
- // Demonstrate how to convert back and forth between Schema, SchemaRef, DFSchema, and DFSchemaRef
- let arrow_schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]);
- let arrow_schema_ref = Arc::new(arrow_schema.clone());
-
- let df_schema =
- DFSchema::new(vec![DFField::new(None, "c0", DataType::Int64, true)]).unwrap();
- let df_schema_ref = Arc::new(df_schema.clone());
-
- {
- let arrow_schema = arrow_schema.clone();
- let arrow_schema_ref = arrow_schema_ref.clone();
-
- assert_eq!(df_schema, arrow_schema.to_dfschema().unwrap());
- assert_eq!(df_schema, arrow_schema_ref.to_dfschema().unwrap());
- }
-
- {
- let arrow_schema = arrow_schema.clone();
- let arrow_schema_ref = arrow_schema_ref.clone();
-
- assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap());
- assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap());
- }
-
- // Now, consume the refs
- assert_eq!(df_schema_ref, arrow_schema.to_dfschema_ref().unwrap());
- assert_eq!(df_schema_ref, arrow_schema_ref.to_dfschema_ref().unwrap());
- }
-
- fn test_schema_1() -> Schema {
- Schema::new(vec![
- Field::new("c0", DataType::Boolean, true),
- Field::new("c1", DataType::Boolean, true),
- ])
- }
-
- fn test_schema_2() -> Schema {
- Schema::new(vec![
- Field::new("c100", DataType::Boolean, true),
- Field::new("c101", DataType::Boolean, true),
- ])
- }
-}
+pub use datafusion_common::{DFField, DFSchema, DFSchemaRef, ExprSchema, ToDFSchema};
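As with `field_util.rs`, the schema types now live in `datafusion-common` and are re-exported here, so downstream imports through `datafusion::logical_plan` should continue to resolve:

```rust
// Unchanged import path (assuming the moved types keep their names):
use datafusion::logical_plan::{DFField, DFSchema, DFSchemaRef, ToDFSchema};
```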
diff --git a/datafusion/src/logical_plan/expr.rs b/datafusion/src/logical_plan/expr.rs
index 2dd9f9e..3826e45 100644
--- a/datafusion/src/logical_plan/expr.rs
+++ b/datafusion/src/logical_plan/expr.rs
@@ -20,1147 +20,25 @@
pub use super::Operator;
-use arrow::{compute::cast::can_cast_types, datatypes::DataType};
+use arrow::datatypes::DataType;
use crate::error::{DataFusionError, Result};
-use crate::execution::context::ExecutionProps;
-use crate::field_util::{get_indexed_field, FieldExt};
-use crate::logical_plan::{
- plan::Aggregate, window_frames, DFField, DFSchema, LogicalPlan,
+use crate::logical_plan::ExprSchemable;
+use crate::logical_plan::{DFField, DFSchema};
+use crate::physical_plan::udaf::AggregateUDF;
+use crate::physical_plan::{aggregates, functions, udf::ScalarUDF};
+pub use datafusion_common::{Column, ExprSchema};
+pub use datafusion_expr::expr_fn::col;
+use datafusion_expr::AccumulatorFunctionImplementation;
+pub use datafusion_expr::Expr;
+use datafusion_expr::StateTypeFunction;
+pub use datafusion_expr::{lit, lit_timestamp_nano, Literal};
+use datafusion_expr::{
+ ReturnTypeFunction, ScalarFunctionImplementation, Signature, Volatility,
};
-use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier};
-use crate::physical_plan::functions::Volatility;
-use crate::physical_plan::{
- aggregates, expressions::binary_operator_data_type, functions, udf::ScalarUDF,
- window_functions,
-};
-use crate::{physical_plan::udaf::AggregateUDF, scalar::ScalarValue};
-use aggregates::{AccumulatorFunctionImplementation, StateTypeFunction};
-use functions::{ReturnTypeFunction, ScalarFunctionImplementation, Signature};
-use std::collections::{HashMap, HashSet};
-use std::convert::Infallible;
-use std::fmt;
-use std::hash::{BuildHasher, Hash, Hasher};
-use std::ops::Not;
-use std::str::FromStr;
+use std::collections::HashSet;
use std::sync::Arc;
-/// A named reference to a qualified field in a schema.
-#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
-pub struct Column {
- /// relation/table name.
- pub relation: Option<String>,
- /// field/column name.
- pub name: String,
-}
-
-impl Column {
- /// Create Column from unqualified name.
- pub fn from_name(name: impl Into<String>) -> Self {
- Self {
- relation: None,
- name: name.into(),
- }
- }
-
- /// Deserialize a fully qualified name string into a column
- pub fn from_qualified_name(flat_name: &str) -> Self {
- use sqlparser::tokenizer::Token;
-
- let dialect = sqlparser::dialect::GenericDialect {};
- let mut tokenizer = sqlparser::tokenizer::Tokenizer::new(&dialect, flat_name);
- if let Ok(tokens) = tokenizer.tokenize() {
- if let [Token::Word(relation), Token::Period, Token::Word(name)] =
- tokens.as_slice()
- {
- return Column {
- relation: Some(relation.value.clone()),
- name: name.value.clone(),
- };
- }
- }
- // any expression that's not in the form of `foo.bar` will be treated as an
- // unqualified column name
- Column {
- relation: None,
- name: String::from(flat_name),
- }
- }
-
- /// Serialize column into a flat name string
- pub fn flat_name(&self) -> String {
- match &self.relation {
- Some(r) => format!("{}.{}", r, self.name),
- None => self.name.clone(),
- }
- }
-
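A hedged sketch of the round trip the constructors and `flat_name` above provide:

```rust
use datafusion::logical_plan::Column;

fn main() {
    // Sketch only; uses just the methods shown above.
    let c = Column::from_qualified_name("t1.id");
    assert_eq!(c.relation.as_deref(), Some("t1"));
    assert_eq!(c.flat_name(), "t1.id");
    assert_eq!(Column::from_name("id").flat_name(), "id");
}
```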
- /// Normalizes `self` if it is unqualified (has no relation name)
- /// with an explicit qualifier from the first matching input
- /// schemas.
- ///
- /// For example, `foo` will be normalized to `t.foo` if there is a
- /// column named `foo` in a relation named `t` found in `schemas`
- pub fn normalize(self, plan: &LogicalPlan) -> Result<Self> {
- let schemas = plan.all_schemas();
- let using_columns = plan.using_columns()?;
- self.normalize_with_schemas(&schemas, &using_columns)
- }
-
- // Internal implementation of normalize
- fn normalize_with_schemas(
- self,
- schemas: &[&Arc<DFSchema>],
- using_columns: &[HashSet<Column>],
- ) -> Result<Self> {
- if self.relation.is_some() {
- return Ok(self);
- }
-
- for schema in schemas {
- let fields = schema.fields_with_unqualified_name(&self.name);
- match fields.len() {
- 0 => continue,
- 1 => {
- return Ok(fields[0].qualified_column());
- }
- _ => {
- // More than 1 fields in this schema have their names set to self.name.
- //
- // This should only happen when a JOIN query with USING constraint references
- // join columns using unqualified column name. For example:
- //
- // ```sql
- // SELECT id FROM t1 JOIN t2 USING(id)
- // ```
- //
- // In this case, both `t1.id` and `t2.id` will match unqualified column `id`.
- // We will use the relation from the first matched field to normalize self.
-
- // Compare matched fields with one USING JOIN clause at a time
- for using_col in using_columns {
- let all_matched = fields
- .iter()
- .all(|f| using_col.contains(&f.qualified_column()));
- // All matched fields belong to the same using column set, in other words
- // the same join clause. We simply pick the qualifier from the first match.
- if all_matched {
- return Ok(fields[0].qualified_column());
- }
- }
- }
- }
- }
-
- Err(DataFusionError::Plan(format!(
- "Column {} not found in provided schemas",
- self
- )))
- }
-}
-
-impl From<&str> for Column {
- fn from(c: &str) -> Self {
- Self::from_qualified_name(c)
- }
-}
-
-impl FromStr for Column {
- type Err = Infallible;
-
- fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
- Ok(s.into())
- }
-}
-
-impl fmt::Display for Column {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match &self.relation {
- Some(r) => write!(f, "#{}.{}", r, self.name),
- None => write!(f, "#{}", self.name),
- }
- }
-}
-
-/// `Expr` is a central struct of DataFusion's query API, and
-/// represent logical expressions such as `A + 1`, or `CAST(c1 AS
-/// int)`.
-///
-/// An `Expr` can compute its [DataType](arrow::datatypes::DataType)
-/// and nullability, and has functions for building up complex
-/// expressions.
-///
-/// # Examples
-///
-/// ## Create an expression `c1` referring to column named "c1"
-/// ```
-/// # use datafusion::logical_plan::*;
-/// let expr = col("c1");
-/// assert_eq!(expr, Expr::Column(Column::from_name("c1")));
-/// ```
-///
-/// ## Create the expression `c1 + c2` to add columns "c1" and "c2" together
-/// ```
-/// # use datafusion::logical_plan::*;
-/// let expr = col("c1") + col("c2");
-///
-/// assert!(matches!(expr, Expr::BinaryExpr { ..} ));
-/// if let Expr::BinaryExpr { left, right, op } = expr {
-/// assert_eq!(*left, col("c1"));
-/// assert_eq!(*right, col("c2"));
-/// assert_eq!(op, Operator::Plus);
-/// }
-/// ```
-///
- /// ## Create expression `c1 = 42` to compare the value in column "c1" to the literal value `42`
-/// ```
-/// # use datafusion::logical_plan::*;
-/// # use datafusion::scalar::*;
-/// let expr = col("c1").eq(lit(42));
-///
-/// assert!(matches!(expr, Expr::BinaryExpr { ..} ));
-/// if let Expr::BinaryExpr { left, right, op } = expr {
-/// assert_eq!(*left, col("c1"));
-/// let scalar = ScalarValue::Int32(Some(42));
-/// assert_eq!(*right, Expr::Literal(scalar));
-/// assert_eq!(op, Operator::Eq);
-/// }
-/// ```
-#[derive(Clone, PartialEq, Hash)]
-pub enum Expr {
- /// An expression with a specific name.
- Alias(Box<Expr>, String),
- /// A named reference to a qualified field in a schema.
- Column(Column),
- /// A named reference to a variable in a registry.
- ScalarVariable(Vec<String>),
- /// A constant value.
- Literal(ScalarValue),
- /// A binary expression such as "age > 21"
- BinaryExpr {
- /// Left-hand side of the expression
- left: Box<Expr>,
- /// The comparison operator
- op: Operator,
- /// Right-hand side of the expression
- right: Box<Expr>,
- },
- /// Negation of an expression. The expression's type must be a boolean to make sense.
- Not(Box<Expr>),
- /// Whether an expression is not Null. This expression is never null.
- IsNotNull(Box<Expr>),
- /// Whether an expression is Null. This expression is never null.
- IsNull(Box<Expr>),
- /// arithmetic negation of an expression, the operand must be of a signed numeric data type
- Negative(Box<Expr>),
- /// Returns the field of a [`ListArray`] or [`StructArray`] by key
- GetIndexedField {
- /// the expression to take the field from
- expr: Box<Expr>,
- /// The name of the field to take
- key: ScalarValue,
- },
- /// Whether an expression is between a given range.
- Between {
- /// The value to compare
- expr: Box<Expr>,
- /// Whether the expression is negated
- negated: bool,
- /// The low end of the range
- low: Box<Expr>,
- /// The high end of the range
- high: Box<Expr>,
- },
- /// The CASE expression is similar to a series of nested if/else and there are two forms that
- /// can be used. The first form consists of a series of boolean "when" expressions with
- /// corresponding "then" expressions, and an optional "else" expression.
- ///
- /// CASE WHEN condition THEN result
- /// [WHEN ...]
- /// [ELSE result]
- /// END
- ///
- /// The second form uses a base expression and then a series of "when" clauses that match on a
- /// literal value.
- ///
- /// CASE expression
- /// WHEN value THEN result
- /// [WHEN ...]
- /// [ELSE result]
- /// END
- Case {
- /// Optional base expression that can be compared to literal values in the "when" expressions
- expr: Option<Box<Expr>>,
- /// One or more when/then expressions
- when_then_expr: Vec<(Box<Expr>, Box<Expr>)>,
- /// Optional "else" expression
- else_expr: Option<Box<Expr>>,
- },
- /// Casts the expression to a given type and will return a runtime error if the expression cannot be cast.
- /// This expression is guaranteed to have a fixed type.
- Cast {
- /// The expression being cast
- expr: Box<Expr>,
- /// The `DataType` the expression will yield
- data_type: DataType,
- },
- /// Casts the expression to a given type and will return a null value if the expression cannot be cast.
- /// This expression is guaranteed to have a fixed type.
- TryCast {
- /// The expression being cast
- expr: Box<Expr>,
- /// The `DataType` the expression will yield
- data_type: DataType,
- },
- /// A sort expression, that can be used to sort values.
- Sort {
- /// The expression to sort on
- expr: Box<Expr>,
- /// The direction of the sort
- asc: bool,
- /// Whether to put Nulls before all other data values
- nulls_first: bool,
- },
- /// Represents the call of a built-in scalar function with a set of arguments.
- ScalarFunction {
- /// The function
- fun: functions::BuiltinScalarFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- },
- /// Represents the call of a user-defined scalar function with arguments.
- ScalarUDF {
- /// The function
- fun: Arc<ScalarUDF>,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- },
- /// Represents the call of an aggregate built-in function with arguments.
- AggregateFunction {
- /// Name of the function
- fun: aggregates::AggregateFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- /// Whether this is a DISTINCT aggregation or not
- distinct: bool,
- },
- /// Represents the call of a window function with arguments.
- WindowFunction {
- /// Name of the function
- fun: window_functions::WindowFunction,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- /// List of partition by expressions
- partition_by: Vec<Expr>,
- /// List of order by expressions
- order_by: Vec<Expr>,
- /// Window frame
- window_frame: Option<window_frames::WindowFrame>,
- },
- /// aggregate function
- AggregateUDF {
- /// The function
- fun: Arc<AggregateUDF>,
- /// List of expressions to feed to the functions as arguments
- args: Vec<Expr>,
- },
- /// Returns whether the list contains the expr value.
- InList {
- /// The expression to compare
- expr: Box<Expr>,
- /// A list of values to compare against
- list: Vec<Expr>,
- /// Whether the expression is negated
- negated: bool,
- },
- /// Represents a reference to all fields in a schema.
- Wildcard,
-}
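Most variants are built through helper functions (`col`, `lit`, and the `Expr` methods below) rather than written out literally; a hedged sketch of direct construction for the less obvious `GetIndexedField`:

```rust
// Sketch: struct-field access `s.age` over a struct column `s`.
use datafusion::logical_plan::{col, Expr};
use datafusion::scalar::ScalarValue;

fn indexed_field_example() -> Expr {
    Expr::GetIndexedField {
        expr: Box::new(col("s")),
        key: ScalarValue::Utf8(Some("age".to_string())),
    }
}
```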
-
-/// Fixed seed for the hashing so that Ords are consistent across runs
-const SEED: ahash::RandomState = ahash::RandomState::with_seeds(0, 0, 0, 0);
-
-impl PartialOrd for Expr {
- fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
- let mut hasher = SEED.build_hasher();
- self.hash(&mut hasher);
- let s = hasher.finish();
-
- let mut hasher = SEED.build_hasher();
- other.hash(&mut hasher);
- let o = hasher.finish();
-
- Some(s.cmp(&o))
- }
-}
-
-impl Expr {
- /// Returns the [arrow::datatypes::DataType] of the expression based on the given [`DFSchema`].
- ///
- /// # Errors
- ///
- /// This function errors when it is not possible to compute its [arrow::datatypes::DataType].
- /// This happens when e.g. the expression refers to a column that does not exist in the schema, or when
- /// the expression is incorrectly typed (e.g. `[utf8] + [bool]`).
- pub fn get_type(&self, schema: &DFSchema) -> Result<DataType> {
- match self {
- Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
- expr.get_type(schema)
- }
- Expr::Column(c) => Ok(schema.field_from_column(c)?.data_type().clone()),
- Expr::ScalarVariable(_) => Ok(DataType::Utf8),
- Expr::Literal(l) => Ok(l.get_datatype()),
- Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
- Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
- Ok(data_type.clone())
- }
- Expr::ScalarUDF { fun, args } => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- Ok((fun.return_type)(&data_types)?.as_ref().clone())
- }
- Expr::ScalarFunction { fun, args } => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- functions::return_type(fun, &data_types)
- }
- Expr::WindowFunction { fun, args, .. } => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- window_functions::return_type(fun, &data_types)
- }
- Expr::AggregateFunction { fun, args, .. } => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- aggregates::return_type(fun, &data_types)
- }
- Expr::AggregateUDF { fun, args, .. } => {
- let data_types = args
- .iter()
- .map(|e| e.get_type(schema))
- .collect::<Result<Vec<_>>>()?;
- Ok((fun.return_type)(&data_types)?.as_ref().clone())
- }
- Expr::Not(_)
- | Expr::IsNull(_)
- | Expr::Between { .. }
- | Expr::InList { .. }
- | Expr::IsNotNull(_) => Ok(DataType::Boolean),
- Expr::BinaryExpr {
- ref left,
- ref right,
- ref op,
- } => binary_operator_data_type(
- &left.get_type(schema)?,
- op,
- &right.get_type(schema)?,
- ),
- Expr::Wildcard => Err(DataFusionError::Internal(
- "Wildcard expressions are not valid in a logical query plan".to_owned(),
- )),
- Expr::GetIndexedField { ref expr, key } => {
- let data_type = expr.get_type(schema)?;
-
- get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
- }
- }
- }
-
- /// Returns the nullability of the expression based on the given [`DFSchema`].
- ///
- /// # Errors
- ///
- /// This function errors when it is not possible to compute its nullability.
- /// This happens when the expression refers to a column that does not exist in the schema.
- pub fn nullable(&self, input_schema: &DFSchema) -> Result<bool> {
- match self {
- Expr::Alias(expr, _)
- | Expr::Not(expr)
- | Expr::Negative(expr)
- | Expr::Sort { expr, .. }
- | Expr::Between { expr, .. }
- | Expr::InList { expr, .. } => expr.nullable(input_schema),
- Expr::Column(c) => Ok(input_schema.field_from_column(c)?.is_nullable()),
- Expr::Literal(value) => Ok(value.is_null()),
- Expr::Case {
- when_then_expr,
- else_expr,
- ..
- } => {
- // this expression is nullable if any of the input expressions are nullable
- let then_nullable = when_then_expr
- .iter()
- .map(|(_, t)| t.nullable(input_schema))
- .collect::<Result<Vec<_>>>()?;
- if then_nullable.contains(&true) {
- Ok(true)
- } else if let Some(e) = else_expr {
- e.nullable(input_schema)
- } else {
- Ok(false)
- }
- }
- Expr::Cast { expr, .. } => expr.nullable(input_schema),
- Expr::ScalarVariable(_)
- | Expr::TryCast { .. }
- | Expr::ScalarFunction { .. }
- | Expr::ScalarUDF { .. }
- | Expr::WindowFunction { .. }
- | Expr::AggregateFunction { .. }
- | Expr::AggregateUDF { .. } => Ok(true),
- Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
- Expr::BinaryExpr {
- ref left,
- ref right,
- ..
- } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
- Expr::Wildcard => Err(DataFusionError::Internal(
- "Wildcard expressions are not valid in a logical query plan".to_owned(),
- )),
- Expr::GetIndexedField { ref expr, key } => {
- let data_type = expr.get_type(input_schema)?;
- get_indexed_field(&data_type, key).map(|x| x.is_nullable())
- }
- }
- }
-
- /// Returns the name of this expression based on [crate::logical_plan::DFSchema].
- ///
- /// This represents how a column with this expression is named when no alias is chosen
- pub fn name(&self, input_schema: &DFSchema) -> Result<String> {
- create_name(self, input_schema)
- }
-
- /// Returns a [`DFField`] compatible with this expression.
- pub fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
- match self {
- Expr::Column(c) => Ok(DFField::new(
- c.relation.as_deref(),
- &c.name,
- self.get_type(input_schema)?,
- self.nullable(input_schema)?,
- )),
- _ => Ok(DFField::new(
- None,
- &self.name(input_schema)?,
- self.get_type(input_schema)?,
- self.nullable(input_schema)?,
- )),
- }
- }
-
- /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
- ///
- /// # Errors
- ///
- /// This function errors when it is impossible to cast the
- /// expression to the target [arrow::datatypes::DataType].
- pub fn cast_to(self, cast_to_type: &DataType, schema: &DFSchema) -> Result<Expr> {
- // TODO(kszucs): most of the operations do not validate the type correctness
- // like all of the binary expressions below. Perhaps Expr should track the
- // type of the expression?
- let this_type = self.get_type(schema)?;
- if this_type == *cast_to_type {
- Ok(self)
- } else if can_cast_types(&this_type, cast_to_type) {
- Ok(Expr::Cast {
- expr: Box::new(self),
- data_type: cast_to_type.clone(),
- })
- } else {
- Err(DataFusionError::Plan(format!(
- "Cannot automatically convert {:?} to {:?}",
- this_type, cast_to_type
- )))
- }
- }
-
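A hedged sketch of the behavior documented above: casting to the expression's own type is a no-op, a permitted cast wraps the expression in `Expr::Cast`, and anything else is a plan error (assumes a schema where `c1` is `Int32`):

```rust
// Sketch only; `schema` is assumed to declare c1: Int32.
let cast = col("c1").cast_to(&DataType::Int64, &schema)?;
assert!(matches!(cast, Expr::Cast { .. }));
let same = col("c1").cast_to(&DataType::Int32, &schema)?;
assert_eq!(same, col("c1")); // already the target type: unchanged
```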
- /// Return `self == other`
- pub fn eq(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Eq, other)
- }
-
- /// Return `self != other`
- pub fn not_eq(self, other: Expr) -> Expr {
- binary_expr(self, Operator::NotEq, other)
- }
-
- /// Return `self > other`
- pub fn gt(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Gt, other)
- }
-
- /// Return `self >= other`
- pub fn gt_eq(self, other: Expr) -> Expr {
- binary_expr(self, Operator::GtEq, other)
- }
-
- /// Return `self < other`
- pub fn lt(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Lt, other)
- }
-
- /// Return `self <= other`
- pub fn lt_eq(self, other: Expr) -> Expr {
- binary_expr(self, Operator::LtEq, other)
- }
-
- /// Return `self && other`
- pub fn and(self, other: Expr) -> Expr {
- binary_expr(self, Operator::And, other)
- }
-
- /// Return `self || other`
- pub fn or(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Or, other)
- }
-
- /// Return `!self`
- #[allow(clippy::should_implement_trait)]
- pub fn not(self) -> Expr {
- !self
- }
-
- /// Calculate the modulus of two expressions.
- /// Return `self % other`
- pub fn modulus(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Modulo, other)
- }
-
- /// Return `self LIKE other`
- pub fn like(self, other: Expr) -> Expr {
- binary_expr(self, Operator::Like, other)
- }
-
- /// Return `self NOT LIKE other`
- pub fn not_like(self, other: Expr) -> Expr {
- binary_expr(self, Operator::NotLike, other)
- }
-
- /// Return `self AS name` alias expression
- pub fn alias(self, name: &str) -> Expr {
- Expr::Alias(Box::new(self), name.to_owned())
- }
-
- /// Return `self IN <list>` if `negated` is false, otherwise
- /// return `self NOT IN <list>`.
- pub fn in_list(self, list: Vec<Expr>, negated: bool) -> Expr {
- Expr::InList {
- expr: Box::new(self),
- list,
- negated,
- }
- }
-
- /// Return `IsNull(Box(self))`
- #[allow(clippy::wrong_self_convention)]
- pub fn is_null(self) -> Expr {
- Expr::IsNull(Box::new(self))
- }
-
- /// Return `IsNotNull(Box(self))`
- #[allow(clippy::wrong_self_convention)]
- pub fn is_not_null(self) -> Expr {
- Expr::IsNotNull(Box::new(self))
- }
-
- /// Create a sort expression from an existing expression.
- ///
- /// ```
- /// # use datafusion::logical_plan::col;
- /// let sort_expr = col("foo").sort(true, true); // SORT ASC NULLS_FIRST
- /// ```
- pub fn sort(self, asc: bool, nulls_first: bool) -> Expr {
- Expr::Sort {
- expr: Box::new(self),
- asc,
- nulls_first,
- }
- }
-
- /// Performs a depth first walk of an expression and
- /// its children, calling [`ExpressionVisitor::pre_visit`] and
-    /// [`ExpressionVisitor::post_visit`].
- ///
- /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
- /// separate expression algorithms from the structure of the
- /// `Expr` tree and make it easier to add new types of expressions
- /// and algorithms that walk the tree.
- ///
- /// For an expression tree such as
- /// ```text
- /// BinaryExpr (GT)
- /// left: Column("foo")
- /// right: Column("bar")
- /// ```
- ///
- /// The nodes are visited using the following order
- /// ```text
- /// pre_visit(BinaryExpr(GT))
- /// pre_visit(Column("foo"))
-    /// post_visit(Column("foo"))
-    /// pre_visit(Column("bar"))
-    /// post_visit(Column("bar"))
- /// post_visit(BinaryExpr(GT))
- /// ```
- ///
-    /// If an `Err` result is returned, recursion is stopped immediately.
-    ///
-    /// If `Recursion::Stop` is returned on a call to `pre_visit`, no
-    /// children of that expression are visited, nor is `post_visit`
-    /// called on that expression.
- ///
- pub fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
- let visitor = match visitor.pre_visit(self)? {
- Recursion::Continue(visitor) => visitor,
- // If the recursion should stop, do not visit children
- Recursion::Stop(visitor) => return Ok(visitor),
- };
-
- // recurse (and cover all expression types)
- let visitor = match self {
- Expr::Alias(expr, _)
- | Expr::Not(expr)
- | Expr::IsNotNull(expr)
- | Expr::IsNull(expr)
- | Expr::Negative(expr)
- | Expr::Cast { expr, .. }
- | Expr::TryCast { expr, .. }
- | Expr::Sort { expr, .. }
- | Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
- Expr::Column(_)
- | Expr::ScalarVariable(_)
- | Expr::Literal(_)
- | Expr::Wildcard => Ok(visitor),
- Expr::BinaryExpr { left, right, .. } => {
- let visitor = left.accept(visitor)?;
- right.accept(visitor)
- }
- Expr::Between {
- expr, low, high, ..
- } => {
- let visitor = expr.accept(visitor)?;
- let visitor = low.accept(visitor)?;
- high.accept(visitor)
- }
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let visitor = if let Some(expr) = expr.as_ref() {
- expr.accept(visitor)
- } else {
- Ok(visitor)
- }?;
- let visitor = when_then_expr.iter().try_fold(
- visitor,
- |visitor, (when, then)| {
- let visitor = when.accept(visitor)?;
- then.accept(visitor)
- },
- )?;
- if let Some(else_expr) = else_expr.as_ref() {
- else_expr.accept(visitor)
- } else {
- Ok(visitor)
- }
- }
- Expr::ScalarFunction { args, .. }
- | Expr::ScalarUDF { args, .. }
- | Expr::AggregateFunction { args, .. }
- | Expr::AggregateUDF { args, .. } => args
- .iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
- Expr::WindowFunction {
- args,
- partition_by,
- order_by,
- ..
- } => {
- let visitor = args
- .iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
- let visitor = partition_by
- .iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
- let visitor = order_by
- .iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor))?;
- Ok(visitor)
- }
- Expr::InList { expr, list, .. } => {
- let visitor = expr.accept(visitor)?;
- list.iter()
- .try_fold(visitor, |visitor, arg| arg.accept(visitor))
- }
- }?;
-
- visitor.post_visit(self)
- }
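
To make the traversal order above concrete, here is a minimal `ExpressionVisitor` sketch that records every column it encounters. The `ColumnCollector` name is invented for illustration:

```rust
use datafusion::error::Result;
use datafusion::logical_plan::{col, Expr, ExpressionVisitor, Recursion};

/// Collects the names of all Column expressions it visits
#[derive(Default)]
struct ColumnCollector {
    columns: Vec<String>,
}

impl ExpressionVisitor for ColumnCollector {
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        if let Expr::Column(c) = expr {
            self.columns.push(c.flat_name());
        }
        // keep walking the children
        Ok(Recursion::Continue(self))
    }
}

// visiting `foo > bar` sees the columns in depth-first, pre-order
let expr = col("foo").gt(col("bar"));
let visitor = expr.accept(ColumnCollector::default()).unwrap();
assert_eq!(visitor.columns, vec!["foo", "bar"]);
```
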
-
- /// Performs a depth first walk of an expression and its children
- /// to rewrite an expression, consuming `self` producing a new
- /// [`Expr`].
- ///
- /// Implements a modified version of the [visitor
- /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
- /// separate algorithms from the structure of the `Expr` tree and
- /// make it easier to write new, efficient expression
- /// transformation algorithms.
- ///
- /// For an expression tree such as
- /// ```text
- /// BinaryExpr (GT)
- /// left: Column("foo")
- /// right: Column("bar")
- /// ```
- ///
- /// The nodes are visited using the following order
- /// ```text
- /// pre_visit(BinaryExpr(GT))
- /// pre_visit(Column("foo"))
-    /// mutate(Column("foo"))
- /// pre_visit(Column("bar"))
- /// mutate(Column("bar"))
- /// mutate(BinaryExpr(GT))
- /// ```
- ///
-    /// If an `Err` result is returned, recursion is stopped immediately.
-    ///
-    /// If [`RewriteRecursion::Stop`] is returned on a call to `pre_visit`, no
-    /// children of that expression are visited, nor is `mutate`
-    /// called on that expression
- ///
- pub fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
- where
- R: ExprRewriter,
- {
- let need_mutate = match rewriter.pre_visit(&self)? {
- RewriteRecursion::Mutate => return rewriter.mutate(self),
- RewriteRecursion::Stop => return Ok(self),
- RewriteRecursion::Continue => true,
- RewriteRecursion::Skip => false,
- };
-
-        // recurse into all sub-expressions (and cover all expression types)
- let expr = match self {
- Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
- Expr::Column(_) => self.clone(),
- Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
- Expr::Literal(value) => Expr::Literal(value),
- Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
- left: rewrite_boxed(left, rewriter)?,
- op,
- right: rewrite_boxed(right, rewriter)?,
- },
- Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?),
- Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?),
- Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?),
- Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?),
- Expr::Between {
- expr,
- low,
- high,
- negated,
- } => Expr::Between {
- expr: rewrite_boxed(expr, rewriter)?,
- low: rewrite_boxed(low, rewriter)?,
- high: rewrite_boxed(high, rewriter)?,
- negated,
- },
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let expr = rewrite_option_box(expr, rewriter)?;
- let when_then_expr = when_then_expr
- .into_iter()
- .map(|(when, then)| {
- Ok((
- rewrite_boxed(when, rewriter)?,
- rewrite_boxed(then, rewriter)?,
- ))
- })
- .collect::<Result<Vec<_>>>()?;
-
- let else_expr = rewrite_option_box(else_expr, rewriter)?;
-
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- }
- }
- Expr::Cast { expr, data_type } => Expr::Cast {
- expr: rewrite_boxed(expr, rewriter)?,
- data_type,
- },
- Expr::TryCast { expr, data_type } => Expr::TryCast {
- expr: rewrite_boxed(expr, rewriter)?,
- data_type,
- },
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => Expr::Sort {
- expr: rewrite_boxed(expr, rewriter)?,
- asc,
- nulls_first,
- },
- Expr::ScalarFunction { args, fun } => Expr::ScalarFunction {
- args: rewrite_vec(args, rewriter)?,
- fun,
- },
- Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
- args: rewrite_vec(args, rewriter)?,
- fun,
- },
- Expr::WindowFunction {
- args,
- fun,
- partition_by,
- order_by,
- window_frame,
- } => Expr::WindowFunction {
- args: rewrite_vec(args, rewriter)?,
- fun,
- partition_by: rewrite_vec(partition_by, rewriter)?,
- order_by: rewrite_vec(order_by, rewriter)?,
- window_frame,
- },
- Expr::AggregateFunction {
- args,
- fun,
- distinct,
- } => Expr::AggregateFunction {
- args: rewrite_vec(args, rewriter)?,
- fun,
- distinct,
- },
- Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
- args: rewrite_vec(args, rewriter)?,
- fun,
- },
- Expr::InList {
- expr,
- list,
- negated,
- } => Expr::InList {
- expr: rewrite_boxed(expr, rewriter)?,
- list: rewrite_vec(list, rewriter)?,
- negated,
- },
- Expr::Wildcard => Expr::Wildcard,
- Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
- expr: rewrite_boxed(expr, rewriter)?,
- key,
- },
- };
-
- // now rewrite this expression itself
- if need_mutate {
- rewriter.mutate(expr)
- } else {
- Ok(expr)
- }
- }
-
-    /// Simplifies this [`Expr`] as much as possible, evaluating
- /// constants and applying algebraic simplifications
- ///
- /// # Example:
-    /// `b > 2 AND b > 2`
-    /// can be simplified to
-    /// `b > 2`
- ///
- /// ```
- /// use datafusion::logical_plan::*;
- /// use datafusion::error::Result;
- /// use datafusion::execution::context::ExecutionProps;
- ///
- /// /// Simple implementation that provides `Simplifier` the information it needs
- /// #[derive(Default)]
- /// struct Info {
- /// execution_props: ExecutionProps,
-    /// }
- ///
- /// impl SimplifyInfo for Info {
- /// fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
- /// Ok(false)
- /// }
- /// fn nullable(&self, expr: &Expr) -> Result<bool> {
- /// Ok(true)
- /// }
- /// fn execution_props(&self) -> &ExecutionProps {
- /// &self.execution_props
- /// }
- /// }
- ///
-    /// // b > 2
-    /// let b_gt_2 = col("b").gt(lit(2));
-    ///
-    /// // (b > 2) OR (b > 2)
-    /// let expr = b_gt_2.clone().or(b_gt_2.clone());
-    ///
-    /// // (b > 2) OR (b > 2) --> (b > 2)
-    /// let expr = expr.simplify(&Info::default()).unwrap();
-    /// assert_eq!(expr, b_gt_2);
- /// ```
- pub fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self> {
- let mut rewriter = Simplifier::new(info);
- let mut const_evaluator = ConstEvaluator::new(info.execution_props());
-
- // TODO iterate until no changes are made during rewrite
- // (evaluating constants can enable new simplifications and
- // simplifications can enable new constant evaluation)
- // https://github.com/apache/arrow-datafusion/issues/1160
- self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter)
- }
-}
-
-impl Not for Expr {
- type Output = Self;
-
- fn not(self) -> Self::Output {
- Expr::Not(Box::new(self))
- }
-}
-
-impl std::fmt::Display for Expr {
- fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
- match self {
- Expr::BinaryExpr {
- ref left,
- ref right,
- ref op,
- } => write!(f, "{} {} {}", left, op, right),
- Expr::AggregateFunction {
-                // Name of the function
-                ref fun,
-                // List of expressions to feed to the function as arguments
-                ref args,
-                // Whether this is a DISTINCT aggregation or not
-                ref distinct,
- } => fmt_function(f, &fun.to_string(), *distinct, args, true),
- Expr::ScalarFunction {
-                // Name of the function
-                ref fun,
-                // List of expressions to feed to the function as arguments
-                ref args,
- } => fmt_function(f, &fun.to_string(), false, args, true),
- _ => write!(f, "{:?}", self),
- }
- }
-}
-
-#[allow(clippy::boxed_local)]
-fn rewrite_boxed<R>(boxed_expr: Box<Expr>, rewriter: &mut R) -> Result<Box<Expr>>
-where
- R: ExprRewriter,
-{
- // TODO: It might be possible to avoid an allocation (the
- // Box::new) below by reusing the box.
- let expr: Expr = *boxed_expr;
- let rewritten_expr = expr.rewrite(rewriter)?;
- Ok(Box::new(rewritten_expr))
-}
-
-fn rewrite_option_box<R>(
- option_box: Option<Box<Expr>>,
- rewriter: &mut R,
-) -> Result<Option<Box<Expr>>>
-where
- R: ExprRewriter,
-{
- option_box
- .map(|expr| rewrite_boxed(expr, rewriter))
- .transpose()
-}
-
-/// rewrite a `Vec` of `Expr`s with the rewriter
-fn rewrite_vec<R>(v: Vec<Expr>, rewriter: &mut R) -> Result<Vec<Expr>>
-where
- R: ExprRewriter,
-{
- v.into_iter().map(|expr| expr.rewrite(rewriter)).collect()
-}
-
-/// Controls how the visitor recursion should proceed.
-pub enum Recursion<V: ExpressionVisitor> {
- /// Attempt to visit all the children, recursively, of this expression.
- Continue(V),
-    /// Do not visit the children of this expression; the rest of the
-    /// walk (the parents of this expression) is unaffected
- Stop(V),
-}
-
-/// Encode the traversal of an expression tree. When passed to
-/// `Expr::accept`, `ExpressionVisitor::visit` is invoked
-/// recursively on all nodes of an expression tree. See the comments
-/// on `Expr::accept` for details on its use
-pub trait ExpressionVisitor: Sized {
-    /// Invoked before any children of `expr` are visited.
- fn pre_visit(self, expr: &Expr) -> Result<Recursion<Self>>;
-
- /// Invoked after all children of `expr` are visited. Default
- /// implementation does nothing.
- fn post_visit(self, _expr: &Expr) -> Result<Self> {
- Ok(self)
- }
-}
-
-/// Controls how the [ExprRewriter] recursion should proceed.
-pub enum RewriteRecursion {
- /// Continue rewrite / visit this expression.
- Continue,
-    /// Call [`ExprRewriter::mutate`] immediately and return.
- Mutate,
- /// Do not rewrite / visit the children of this expression.
- Stop,
-    /// Keep recursing into children, but skip calling `mutate` on this expression
- Skip,
-}
-
-/// Trait for potentially recursively rewriting an [`Expr`] expression
-/// tree. When passed to `Expr::rewrite`, `ExprRewriter::mutate` is
-/// invoked recursively on all nodes of an expression tree. See the
-/// comments on `Expr::rewrite` for details on its use
-pub trait ExprRewriter: Sized {
- /// Invoked before any children of `expr` are rewritten /
- /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
- fn pre_visit(&mut self, _expr: &Expr) -> Result<RewriteRecursion> {
- Ok(RewriteRecursion::Continue)
- }
-
- /// Invoked after all children of `expr` have been mutated and
- /// returns a potentially modified expr.
- fn mutate(&mut self, expr: Expr) -> Result<Expr>;
-}
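
Because `mutate` is invoked bottom-up (after an expression's children have been rewritten), even a tiny rule composes over deep trees. A hedged sketch, with the `DoubleNegation` name invented for illustration:

```rust
use datafusion::error::Result;
use datafusion::logical_plan::{col, lit, Expr, ExprRewriter};

/// Rewrites `NOT NOT expr` to `expr`
struct DoubleNegation;

impl ExprRewriter for DoubleNegation {
    fn mutate(&mut self, expr: Expr) -> Result<Expr> {
        if let Expr::Not(inner) = &expr {
            if let Expr::Not(innermost) = inner.as_ref() {
                // children were already rewritten, so this collapses
                // any even number of stacked negations
                return Ok(innermost.as_ref().clone());
            }
        }
        Ok(expr)
    }
}

// NOT NOT (a = 1)  -->  a = 1
let expr = !!(col("a").eq(lit(1)));
let rewritten = expr.rewrite(&mut DoubleNegation).unwrap();
assert_eq!(rewritten, col("a").eq(lit(1)));
```
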
-
-/// The information necessary to apply algebraic simplification to an
-/// [Expr]. See [SimplifyContext] for one implementation
-pub trait SimplifyInfo {
- /// returns true if this Expr has boolean type
- fn is_boolean_type(&self, expr: &Expr) -> Result<bool>;
-
-    /// returns true if this expr is nullable (could possibly be NULL)
- fn nullable(&self, expr: &Expr) -> Result<bool>;
-
- /// Returns details needed for partial expression evaluation
- fn execution_props(&self) -> &ExecutionProps;
-}
-
/// Helper struct for building [Expr::Case]
pub struct CaseBuilder {
expr: Option<Box<Expr>>,
@@ -1251,15 +129,6 @@ pub fn when(when: Expr, then: Expr) -> CaseBuilder {
}
}
-/// return a new expression l <op> r
-pub fn binary_expr(l: Expr, op: Operator, r: Expr) -> Expr {
- Expr::BinaryExpr {
- left: Box::new(l),
- op,
- right: Box::new(r),
- }
-}
-
/// return a new expression with a logical AND
pub fn and(left: Expr, right: Expr) -> Expr {
Expr::BinaryExpr {
@@ -1292,11 +161,6 @@ pub fn or(left: Expr, right: Expr) -> Expr {
}
}
-/// Create a column expression based on a qualified or unqualified column name
-pub fn col(ident: &str) -> Expr {
- Expr::Column(ident.into())
-}
-
/// Convert an expression into a Column expression if it is already provided by the input plan.
///
/// For example, it rewrites:
@@ -1329,183 +193,6 @@ pub fn columnize_expr(e: Expr, input_schema: &DFSchema) -> Expr {
}
}
-/// Recursively replace all Column expressions in a given expression tree with Column expressions
-/// provided by the hash map argument.
-pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
- struct ColumnReplacer<'a> {
- replace_map: &'a HashMap<&'a Column, &'a Column>,
- }
-
- impl<'a> ExprRewriter for ColumnReplacer<'a> {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- if let Expr::Column(c) = &expr {
- match self.replace_map.get(c) {
- Some(new_c) => Ok(Expr::Column((*new_c).to_owned())),
- None => Ok(expr),
- }
- } else {
- Ok(expr)
- }
- }
- }
-
- e.rewrite(&mut ColumnReplacer { replace_map })
-}
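
A usage sketch for `replace_col`, with hypothetical table names `t1`/`t2`; it assumes `Column::from_qualified_name` parses a `table.column` string as it does elsewhere in DataFusion:

```rust
use std::collections::HashMap;
use datafusion::logical_plan::{col, lit, replace_col, Column};

// map t1.a -> t2.a
let from = Column::from_qualified_name("t1.a");
let to = Column::from_qualified_name("t2.a");
let mut replace_map: HashMap<&Column, &Column> = HashMap::new();
replace_map.insert(&from, &to);

// #t1.a > 5  becomes  #t2.a > 5
let expr = col("t1.a").gt(lit(5));
let rewritten = replace_col(expr, &replace_map).unwrap();
assert_eq!(rewritten, col("t2.a").gt(lit(5)));
```
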
-
-/// Recursively call [`Column::normalize`] on all Column expressions
-/// in the `expr` expression tree.
-pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
- normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
-}
-
-/// Recursively call [`Column::normalize`] on all Column expressions
-/// in the `expr` expression tree.
-fn normalize_col_with_schemas(
- expr: Expr,
- schemas: &[&Arc<DFSchema>],
- using_columns: &[HashSet<Column>],
-) -> Result<Expr> {
- struct ColumnNormalizer<'a> {
- schemas: &'a [&'a Arc<DFSchema>],
- using_columns: &'a [HashSet<Column>],
- }
-
- impl<'a> ExprRewriter for ColumnNormalizer<'a> {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- if let Expr::Column(c) = expr {
- Ok(Expr::Column(c.normalize_with_schemas(
- self.schemas,
- self.using_columns,
- )?))
- } else {
- Ok(expr)
- }
- }
- }
-
- expr.rewrite(&mut ColumnNormalizer {
- schemas,
- using_columns,
- })
-}
-
-/// Recursively normalize all Column expressions in a list of expression trees
-pub fn normalize_cols(
- exprs: impl IntoIterator<Item = impl Into<Expr>>,
- plan: &LogicalPlan,
-) -> Result<Vec<Expr>> {
- exprs
- .into_iter()
- .map(|e| normalize_col(e.into(), plan))
- .collect()
-}
-
-/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
-/// For example, `max(x)` is rewritten to `col("MAX(x)")`
-pub fn rewrite_sort_cols_by_aggs(
- exprs: impl IntoIterator<Item = impl Into<Expr>>,
- plan: &LogicalPlan,
-) -> Result<Vec<Expr>> {
- exprs
- .into_iter()
- .map(|e| {
- let expr = e.into();
- match expr {
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => {
- let sort = Expr::Sort {
- expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
- asc,
- nulls_first,
- };
- Ok(sort)
- }
- expr => Ok(expr),
- }
- })
- .collect()
-}
-
-fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
- match plan {
- LogicalPlan::Aggregate(Aggregate {
- input, aggr_expr, ..
- }) => {
- struct Rewriter<'a> {
- plan: &'a LogicalPlan,
- input: &'a LogicalPlan,
- aggr_expr: &'a Vec<Expr>,
- }
-
- impl<'a> ExprRewriter for Rewriter<'a> {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- let normalized_expr = normalize_col(expr.clone(), self.plan);
- if normalized_expr.is_err() {
- // The expr is not based on Aggregate plan output. Skip it.
- return Ok(expr);
- }
- let normalized_expr = normalized_expr.unwrap();
- if let Some(found_agg) =
- self.aggr_expr.iter().find(|a| (**a) == normalized_expr)
- {
- let agg = normalize_col(found_agg.clone(), self.plan)?;
- let col = Expr::Column(
- agg.to_field(self.input.schema())
- .map(|f| f.qualified_column())?,
- );
- Ok(col)
- } else {
- Ok(expr)
- }
- }
- }
-
- expr.rewrite(&mut Rewriter {
- plan,
- input,
- aggr_expr,
- })
- }
- LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
- _ => Ok(expr),
- }
-}
-
-/// Recursively 'unnormalize' (remove all qualifiers) from an
-/// expression tree.
-///
-/// For example, if there were expressions like `foo.bar` this would
-/// rewrite it to just `bar`.
-pub fn unnormalize_col(expr: Expr) -> Expr {
- struct RemoveQualifier {}
-
- impl ExprRewriter for RemoveQualifier {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- if let Expr::Column(col) = expr {
- Ok(Expr::Column(Column {
- relation: None,
- name: col.name,
- }))
- } else {
- Ok(expr)
- }
- }
- }
-
- expr.rewrite(&mut RemoveQualifier {})
-        .expect("Unnormalize is infallible")
-}
-
-/// Recursively un-normalize all Column expressions in a list of expression trees
-#[inline]
-pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
- exprs.into_iter().map(unnormalize_col).collect()
-}
-
/// Recursively un-alias an expression
#[inline]
pub fn unalias(expr: Expr) -> Expr {
@@ -1578,102 +265,6 @@ pub fn in_list(expr: Expr, list: Vec<Expr>, negated: bool) -> Expr {
}
}
-/// Trait for converting a value to an [`Expr::Literal`] expression.
-pub trait Literal {
- /// convert the value to a Literal expression
- fn lit(&self) -> Expr;
-}
-
-/// Trait for converting a type to a literal timestamp
-pub trait TimestampLiteral {
- fn lit_timestamp_nano(&self) -> Expr;
-}
-
-impl Literal for &str {
- fn lit(&self) -> Expr {
- Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
- }
-}
-
-impl Literal for String {
- fn lit(&self) -> Expr {
- Expr::Literal(ScalarValue::Utf8(Some((*self).to_owned())))
- }
-}
-
-impl Literal for Vec<u8> {
- fn lit(&self) -> Expr {
- Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
- }
-}
-
-impl Literal for &[u8] {
- fn lit(&self) -> Expr {
- Expr::Literal(ScalarValue::Binary(Some((*self).to_owned())))
- }
-}
-
-impl Literal for ScalarValue {
- fn lit(&self) -> Expr {
- Expr::Literal(self.clone())
- }
-}
-
-macro_rules! make_literal {
- ($TYPE:ty, $SCALAR:ident, $DOC: expr) => {
- #[doc = $DOC]
- impl Literal for $TYPE {
- fn lit(&self) -> Expr {
- Expr::Literal(ScalarValue::$SCALAR(Some(self.clone())))
- }
- }
- };
-}
-
-macro_rules! make_timestamp_literal {
- ($TYPE:ty, $SCALAR:ident, $DOC: expr) => {
- #[doc = $DOC]
- impl TimestampLiteral for $TYPE {
- fn lit_timestamp_nano(&self) -> Expr {
- Expr::Literal(ScalarValue::TimestampNanosecond(
- Some((self.clone()).into()),
- None,
- ))
- }
- }
- };
-}
-
-make_literal!(bool, Boolean, "literal expression containing a bool");
-make_literal!(f32, Float32, "literal expression containing an f32");
-make_literal!(f64, Float64, "literal expression containing an f64");
-make_literal!(i8, Int8, "literal expression containing an i8");
-make_literal!(i16, Int16, "literal expression containing an i16");
-make_literal!(i32, Int32, "literal expression containing an i32");
-make_literal!(i64, Int64, "literal expression containing an i64");
-make_literal!(u8, UInt8, "literal expression containing a u8");
-make_literal!(u16, UInt16, "literal expression containing a u16");
-make_literal!(u32, UInt32, "literal expression containing a u32");
-make_literal!(u64, UInt64, "literal expression containing a u64");
-
-make_timestamp_literal!(i8, Int8, "literal expression containing an i8");
-make_timestamp_literal!(i16, Int16, "literal expression containing an i16");
-make_timestamp_literal!(i32, Int32, "literal expression containing an i32");
-make_timestamp_literal!(i64, Int64, "literal expression containing an i64");
-make_timestamp_literal!(u8, UInt8, "literal expression containing a u8");
-make_timestamp_literal!(u16, UInt16, "literal expression containing a u16");
-make_timestamp_literal!(u32, UInt32, "literal expression containing a u32");
-
-/// Create a literal expression
-pub fn lit<T: Literal>(n: T) -> Expr {
- n.lit()
-}
-
-/// Create a literal timestamp expression
-pub fn lit_timestamp_nano<T: TimestampLiteral>(n: T) -> Expr {
- n.lit_timestamp_nano()
-}
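
A short sketch of how these helpers dispatch on the value's type; the commented `ScalarValue` variants follow the `make_literal!`/`make_timestamp_literal!` expansions above:

```rust
// each suffixed literal picks the matching ScalarValue variant
let i = lit(5i64);    // ScalarValue::Int64(Some(5))
let f = lit(1.5f32);  // ScalarValue::Float32(Some(1.5))
let s = lit("hello"); // ScalarValue::Utf8(Some("hello".to_string()))

// integer-like values can also become nanosecond timestamps
let ts = lit_timestamp_nano(10i32); // ScalarValue::TimestampNanosecond(Some(10), None)
```
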
-
/// Concatenates the text representations of all the arguments. NULL arguments are ignored.
pub fn concat(args: &[Expr]) -> Expr {
Expr::ScalarFunction {
@@ -1878,311 +469,6 @@ pub fn create_udaf(
)
}
-fn fmt_function(
- f: &mut fmt::Formatter,
- fun: &str,
- distinct: bool,
- args: &[Expr],
- display: bool,
-) -> fmt::Result {
- let args: Vec<String> = match display {
- true => args.iter().map(|arg| format!("{}", arg)).collect(),
- false => args.iter().map(|arg| format!("{:?}", arg)).collect(),
- };
-
- let distinct_str = match distinct {
- true => "DISTINCT ",
- false => "",
- };
- write!(f, "{}({}{})", fun, distinct_str, args.join(", "))
-}
-
-impl fmt::Debug for Expr {
- fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
- match self {
- Expr::Alias(expr, alias) => write!(f, "{:?} AS {}", expr, alias),
- Expr::Column(c) => write!(f, "{}", c),
- Expr::ScalarVariable(var_names) => write!(f, "{}", var_names.join(".")),
- Expr::Literal(v) => write!(f, "{:?}", v),
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- ..
- } => {
- write!(f, "CASE ")?;
- if let Some(e) = expr {
- write!(f, "{:?} ", e)?;
- }
- for (w, t) in when_then_expr {
- write!(f, "WHEN {:?} THEN {:?} ", w, t)?;
- }
- if let Some(e) = else_expr {
- write!(f, "ELSE {:?} ", e)?;
- }
- write!(f, "END")
- }
- Expr::Cast { expr, data_type } => {
- write!(f, "CAST({:?} AS {:?})", expr, data_type)
- }
- Expr::TryCast { expr, data_type } => {
- write!(f, "TRY_CAST({:?} AS {:?})", expr, data_type)
- }
- Expr::Not(expr) => write!(f, "NOT {:?}", expr),
- Expr::Negative(expr) => write!(f, "(- {:?})", expr),
- Expr::IsNull(expr) => write!(f, "{:?} IS NULL", expr),
- Expr::IsNotNull(expr) => write!(f, "{:?} IS NOT NULL", expr),
- Expr::BinaryExpr { left, op, right } => {
- write!(f, "{:?} {} {:?}", left, op, right)
- }
- Expr::Sort {
- expr,
- asc,
- nulls_first,
- } => {
- if *asc {
- write!(f, "{:?} ASC", expr)?;
- } else {
- write!(f, "{:?} DESC", expr)?;
- }
- if *nulls_first {
- write!(f, " NULLS FIRST")
- } else {
- write!(f, " NULLS LAST")
- }
- }
- Expr::ScalarFunction { fun, args, .. } => {
- fmt_function(f, &fun.to_string(), false, args, false)
- }
- Expr::ScalarUDF { fun, ref args, .. } => {
- fmt_function(f, &fun.name, false, args, false)
- }
- Expr::WindowFunction {
- fun,
- args,
- partition_by,
- order_by,
- window_frame,
- } => {
- fmt_function(f, &fun.to_string(), false, args, false)?;
- if !partition_by.is_empty() {
- write!(f, " PARTITION BY {:?}", partition_by)?;
- }
- if !order_by.is_empty() {
- write!(f, " ORDER BY {:?}", order_by)?;
- }
- if let Some(window_frame) = window_frame {
- write!(
- f,
- " {} BETWEEN {} AND {}",
- window_frame.units,
- window_frame.start_bound,
- window_frame.end_bound
- )?;
- }
- Ok(())
- }
- Expr::AggregateFunction {
- fun,
- distinct,
- ref args,
- ..
- } => fmt_function(f, &fun.to_string(), *distinct, args, true),
- Expr::AggregateUDF { fun, ref args, .. } => {
- fmt_function(f, &fun.name, false, args, false)
- }
- Expr::Between {
- expr,
- negated,
- low,
- high,
- } => {
- if *negated {
- write!(f, "{:?} NOT BETWEEN {:?} AND {:?}", expr, low, high)
- } else {
- write!(f, "{:?} BETWEEN {:?} AND {:?}", expr, low, high)
- }
- }
- Expr::InList {
- expr,
- list,
- negated,
- } => {
- if *negated {
- write!(f, "{:?} NOT IN ({:?})", expr, list)
- } else {
- write!(f, "{:?} IN ({:?})", expr, list)
- }
- }
- Expr::Wildcard => write!(f, "*"),
- Expr::GetIndexedField { ref expr, key } => {
- write!(f, "({:?})[{}]", expr, key)
- }
- }
- }
-}
-
-fn create_function_name(
- fun: &str,
- distinct: bool,
- args: &[Expr],
- input_schema: &DFSchema,
-) -> Result<String> {
- let names: Vec<String> = args
- .iter()
- .map(|e| create_name(e, input_schema))
- .collect::<Result<_>>()?;
- let distinct_str = match distinct {
- true => "DISTINCT ",
- false => "",
- };
- Ok(format!("{}({}{})", fun, distinct_str, names.join(",")))
-}
-
-/// Returns a readable name of an expression based on the input schema.
-/// This function recursively traverses the expression to build names such as "CAST(a > 2)".
-fn create_name(e: &Expr, input_schema: &DFSchema) -> Result<String> {
- match e {
- Expr::Alias(_, name) => Ok(name.clone()),
- Expr::Column(c) => Ok(c.flat_name()),
- Expr::ScalarVariable(variable_names) => Ok(variable_names.join(".")),
- Expr::Literal(value) => Ok(format!("{:?}", value)),
- Expr::BinaryExpr { left, op, right } => {
- let left = create_name(left, input_schema)?;
- let right = create_name(right, input_schema)?;
- Ok(format!("{} {} {}", left, op, right))
- }
- Expr::Case {
- expr,
- when_then_expr,
- else_expr,
- } => {
- let mut name = "CASE ".to_string();
- if let Some(e) = expr {
- let e = create_name(e, input_schema)?;
- name += &format!("{} ", e);
- }
- for (w, t) in when_then_expr {
- let when = create_name(w, input_schema)?;
- let then = create_name(t, input_schema)?;
- name += &format!("WHEN {} THEN {} ", when, then);
- }
- if let Some(e) = else_expr {
- let e = create_name(e, input_schema)?;
- name += &format!("ELSE {} ", e);
- }
- name += "END";
- Ok(name)
- }
- Expr::Cast { expr, data_type } => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("CAST({} AS {:?})", expr, data_type))
- }
- Expr::TryCast { expr, data_type } => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("TRY_CAST({} AS {:?})", expr, data_type))
- }
- Expr::Not(expr) => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("NOT {}", expr))
- }
- Expr::Negative(expr) => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("(- {})", expr))
- }
- Expr::IsNull(expr) => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("{} IS NULL", expr))
- }
- Expr::IsNotNull(expr) => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("{} IS NOT NULL", expr))
- }
- Expr::GetIndexedField { expr, key } => {
- let expr = create_name(expr, input_schema)?;
- Ok(format!("{}[{}]", expr, key))
- }
- Expr::ScalarFunction { fun, args, .. } => {
- create_function_name(&fun.to_string(), false, args, input_schema)
- }
- Expr::ScalarUDF { fun, args, .. } => {
- create_function_name(&fun.name, false, args, input_schema)
- }
- Expr::WindowFunction {
- fun,
- args,
- window_frame,
- partition_by,
- order_by,
- } => {
- let mut parts: Vec<String> = vec![create_function_name(
- &fun.to_string(),
- false,
- args,
- input_schema,
- )?];
- if !partition_by.is_empty() {
- parts.push(format!("PARTITION BY {:?}", partition_by));
- }
- if !order_by.is_empty() {
- parts.push(format!("ORDER BY {:?}", order_by));
- }
- if let Some(window_frame) = window_frame {
- parts.push(format!("{}", window_frame));
- }
- Ok(parts.join(" "))
- }
- Expr::AggregateFunction {
- fun,
- distinct,
- args,
- ..
- } => create_function_name(&fun.to_string(), *distinct, args, input_schema),
- Expr::AggregateUDF { fun, args } => {
- let mut names = Vec::with_capacity(args.len());
- for e in args {
- names.push(create_name(e, input_schema)?);
- }
- Ok(format!("{}({})", fun.name, names.join(",")))
- }
- Expr::InList {
- expr,
- list,
- negated,
- } => {
-            let expr = create_name(expr, input_schema)?;
-            let list = list
-                .iter()
-                .map(|expr| create_name(expr, input_schema))
-                .collect::<Result<Vec<String>>>()?;
-            if *negated {
-                Ok(format!("{} NOT IN ({})", expr, list.join(", ")))
-            } else {
-                Ok(format!("{} IN ({})", expr, list.join(", ")))
-            }
- }
- Expr::Between {
- expr,
- negated,
- low,
- high,
- } => {
- let expr = create_name(expr, input_schema)?;
- let low = create_name(low, input_schema)?;
- let high = create_name(high, input_schema)?;
- if *negated {
- Ok(format!("{} NOT BETWEEN {} AND {}", expr, low, high))
- } else {
- Ok(format!("{} BETWEEN {} AND {}", expr, low, high))
- }
- }
- Expr::Sort { .. } => Err(DataFusionError::Internal(
- "Create name does not support sort expression".to_string(),
- )),
- Expr::Wildcard => Err(DataFusionError::Internal(
- "Create name does not support wildcard".to_string(),
- )),
- }
-}
-
/// Create field meta-data from an expression, for use in a result set schema
pub fn exprlist_to_fields<'a>(
expr: impl IntoIterator<Item = &'a Expr>,
@@ -2191,10 +477,25 @@ pub fn exprlist_to_fields<'a>(
expr.into_iter().map(|e| e.to_field(input_schema)).collect()
}
+/// Calls a named built-in function
+/// ```
+/// use datafusion::logical_plan::*;
+///
+/// // create the expression sin(x) < 0.2
+/// let expr = call_fn("sin", vec![col("x")]).unwrap().lt(lit(0.2));
+/// ```
+pub fn call_fn(name: impl AsRef<str>, args: Vec<Expr>) -> Result<Expr> {
+ match name.as_ref().parse::<functions::BuiltinScalarFunction>() {
+ Ok(fun) => Ok(Expr::ScalarFunction { fun, args }),
+ Err(e) => Err(e),
+ }
+}
+
#[cfg(test)]
mod tests {
use super::super::{col, lit, when};
use super::*;
+ use datafusion_expr::expr_fn::binary_expr;
#[test]
fn case_when_same_literal_then_types() -> Result<()> {
@@ -2213,40 +514,6 @@ mod tests {
}
#[test]
- fn test_lit_timestamp_nano() {
- let expr = col("time").eq(lit_timestamp_nano(10)); // 10 is an implicit i32
- let expected =
- col("time").eq(lit(ScalarValue::TimestampNanosecond(Some(10), None)));
- assert_eq!(expr, expected);
-
- let i: i64 = 10;
- let expr = col("time").eq(lit_timestamp_nano(i));
- assert_eq!(expr, expected);
-
- let i: u32 = 10;
- let expr = col("time").eq(lit_timestamp_nano(i));
- assert_eq!(expr, expected);
- }
-
- #[test]
- fn rewriter_visit() {
- let mut rewriter = RecordingRewriter::default();
- col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
-
- assert_eq!(
- rewriter.v,
- vec![
- "Previsited #state = Utf8(\"CO\")",
- "Previsited #state",
- "Mutated #state",
- "Previsited Utf8(\"CO\")",
- "Mutated Utf8(\"CO\")",
- "Mutated #state = Utf8(\"CO\")"
- ]
- )
- }
-
- #[test]
fn filter_is_null_and_is_not_null() {
let col_null = col("col1");
let col_not_null = col("col2");
@@ -2257,128 +524,6 @@ mod tests {
);
}
- #[derive(Default)]
- struct RecordingRewriter {
- v: Vec<String>,
- }
- impl ExprRewriter for RecordingRewriter {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- self.v.push(format!("Mutated {:?}", expr));
- Ok(expr)
- }
-
- fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
- self.v.push(format!("Previsited {:?}", expr));
- Ok(RewriteRecursion::Continue)
- }
- }
-
- #[test]
- fn rewriter_rewrite() {
- let mut rewriter = FooBarRewriter {};
-
- // rewrites "foo" --> "bar"
- let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap();
- assert_eq!(rewritten, col("state").eq(lit("bar")));
-
-        // doesn't rewrite
- let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap();
- assert_eq!(rewritten, col("state").eq(lit("baz")));
- }
-
- /// rewrites all "foo" string literals to "bar"
- struct FooBarRewriter {}
- impl ExprRewriter for FooBarRewriter {
- fn mutate(&mut self, expr: Expr) -> Result<Expr> {
- match expr {
- Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => {
- let utf8_val = if utf8_val == "foo" {
- "bar".to_string()
- } else {
- utf8_val
- };
- Ok(lit(utf8_val))
- }
- // otherwise, return the expression unchanged
- expr => Ok(expr),
- }
- }
- }
-
- #[test]
- fn normalize_cols() {
- let expr = col("a") + col("b") + col("c");
-
- // Schemas with some matching and some non matching cols
- let schema_a =
- DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")])
- .unwrap();
- let schema_c =
- DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")])
- .unwrap();
- let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
- // non matching
- let schema_f =
- DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")])
- .unwrap();
- let schemas = vec![schema_c, schema_f, schema_b, schema_a]
- .into_iter()
- .map(Arc::new)
- .collect::<Vec<_>>();
- let schemas = schemas.iter().collect::<Vec<_>>();
-
- let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
- assert_eq!(
- normalized_expr,
- col("tableA.a") + col("tableB.b") + col("tableC.c")
- );
- }
-
- #[test]
- fn normalize_cols_priority() {
- let expr = col("a") + col("b");
- // Schemas with multiple matches for column a, first takes priority
- let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
- let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
- let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap();
- let schemas = vec![schema_a2, schema_b, schema_a]
- .into_iter()
- .map(Arc::new)
- .collect::<Vec<_>>();
- let schemas = schemas.iter().collect::<Vec<_>>();
-
- let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
- assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
- }
-
- #[test]
- fn normalize_cols_non_exist() {
- // test normalizing columns when the name doesn't exist
- let expr = col("a") + col("b");
- let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
- let schemas = vec![schema_a].into_iter().map(Arc::new).collect::<Vec<_>>();
- let schemas = schemas.iter().collect::<Vec<_>>();
-
- let error = normalize_col_with_schemas(expr, &schemas, &[])
- .unwrap_err()
- .to_string();
- assert_eq!(
- error,
- "Error during planning: Column #b not found in provided schemas"
- );
- }
-
- #[test]
- fn unnormalize_cols() {
- let expr = col("tableA.a") + col("tableB.b");
- let unnormalized_expr = unnormalize_col(expr);
- assert_eq!(unnormalized_expr, col("a") + col("b"));
- }
-
- fn make_field(relation: &str, column: &str) -> DFField {
- DFField::new(Some(relation), column, DataType::Int8, false)
- }
-
#[test]
fn test_not() {
assert_eq!(lit(1).not(), !lit(1));
@@ -2559,4 +704,57 @@ mod tests {
combine_filters(&[filter1.clone(), filter2.clone(), filter3.clone()]);
assert_eq!(result, Some(and(and(filter1, filter2), filter3)));
}
+
+ #[test]
+ fn expr_schema_nullability() {
+ let expr = col("foo").eq(lit(1));
+ assert!(!expr.nullable(&MockExprSchema::new()).unwrap());
+ assert!(expr
+ .nullable(&MockExprSchema::new().with_nullable(true))
+ .unwrap());
+ }
+
+ #[test]
+ fn expr_schema_data_type() {
+ let expr = col("foo");
+ assert_eq!(
+ DataType::Utf8,
+ expr.get_type(&MockExprSchema::new().with_data_type(DataType::Utf8))
+ .unwrap()
+ );
+ }
+
+ struct MockExprSchema {
+ nullable: bool,
+ data_type: DataType,
+ }
+
+ impl MockExprSchema {
+ fn new() -> Self {
+ Self {
+ nullable: false,
+ data_type: DataType::Null,
+ }
+ }
+
+ fn with_nullable(mut self, nullable: bool) -> Self {
+ self.nullable = nullable;
+ self
+ }
+
+ fn with_data_type(mut self, data_type: DataType) -> Self {
+ self.data_type = data_type;
+ self
+ }
+ }
+
+ impl ExprSchema for MockExprSchema {
+ fn nullable(&self, _col: &Column) -> Result<bool> {
+ Ok(self.nullable)
+ }
+
+ fn data_type(&self, _col: &Column) -> Result<&DataType> {
+ Ok(&self.data_type)
+ }
+ }
}
diff --git a/datafusion/src/logical_plan/expr_rewriter.rs b/datafusion/src/logical_plan/expr_rewriter.rs
new file mode 100644
index 0000000..5062d5f
--- /dev/null
+++ b/datafusion/src/logical_plan/expr_rewriter.rs
@@ -0,0 +1,592 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Expression rewriter
+
+use super::Expr;
+use crate::logical_plan::plan::Aggregate;
+use crate::logical_plan::DFSchema;
+use crate::logical_plan::ExprSchemable;
+use crate::logical_plan::LogicalPlan;
+use datafusion_common::Column;
+use datafusion_common::Result;
+use std::collections::HashMap;
+use std::collections::HashSet;
+use std::sync::Arc;
+
+/// Controls how the [ExprRewriter] recursion should proceed.
+pub enum RewriteRecursion {
+ /// Continue rewrite / visit this expression.
+ Continue,
+    /// Call [`ExprRewriter::mutate`] immediately and return.
+ Mutate,
+ /// Do not rewrite / visit the children of this expression.
+ Stop,
+    /// Keep recursing into children, but skip calling `mutate` on this expression
+ Skip,
+}
+
+/// Trait for potentially recursively rewriting an [`Expr`] expression
+/// tree. When passed to `Expr::rewrite`, `ExprRewriter::mutate` is
+/// invoked recursively on all nodes of an expression tree. See the
+/// comments on `Expr::rewrite` for details on its use
+pub trait ExprRewriter<E: ExprRewritable = Expr>: Sized {
+ /// Invoked before any children of `expr` are rewritten /
+ /// visited. Default implementation returns `Ok(RewriteRecursion::Continue)`
+ fn pre_visit(&mut self, _expr: &E) -> Result<RewriteRecursion> {
+ Ok(RewriteRecursion::Continue)
+ }
+
+ /// Invoked after all children of `expr` have been mutated and
+ /// returns a potentially modified expr.
+ fn mutate(&mut self, expr: E) -> Result<E>;
+}
+
+/// A trait marking types that can be rewritten by an [ExprRewriter]
+pub trait ExprRewritable: Sized {
+ /// rewrite the expression tree using the given [ExprRewriter]
+ fn rewrite<R: ExprRewriter<Self>>(self, rewriter: &mut R) -> Result<Self>;
+}
+
+impl ExprRewritable for Expr {
+ /// Performs a depth first walk of an expression and its children
+ /// to rewrite an expression, consuming `self` producing a new
+ /// [`Expr`].
+ ///
+ /// Implements a modified version of the [visitor
+ /// pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
+ /// separate algorithms from the structure of the `Expr` tree and
+ /// make it easier to write new, efficient expression
+ /// transformation algorithms.
+ ///
+ /// For an expression tree such as
+ /// ```text
+ /// BinaryExpr (GT)
+ /// left: Column("foo")
+ /// right: Column("bar")
+ /// ```
+ ///
+ /// The nodes are visited using the following order
+ /// ```text
+ /// pre_visit(BinaryExpr(GT))
+ /// pre_visit(Column("foo"))
+    /// mutate(Column("foo"))
+ /// pre_visit(Column("bar"))
+ /// mutate(Column("bar"))
+ /// mutate(BinaryExpr(GT))
+ /// ```
+ ///
+    /// If an `Err` result is returned, recursion is stopped immediately.
+    ///
+    /// If [`RewriteRecursion::Stop`] is returned on a call to `pre_visit`, no
+    /// children of that expression are visited, nor is `mutate`
+    /// called on that expression
+ ///
+ fn rewrite<R>(self, rewriter: &mut R) -> Result<Self>
+ where
+ R: ExprRewriter<Self>,
+ {
+ let need_mutate = match rewriter.pre_visit(&self)? {
+ RewriteRecursion::Mutate => return rewriter.mutate(self),
+ RewriteRecursion::Stop => return Ok(self),
+ RewriteRecursion::Continue => true,
+ RewriteRecursion::Skip => false,
+ };
+
+        // recurse into all sub-expressions (and cover all expression types)
+ let expr = match self {
+ Expr::Alias(expr, name) => Expr::Alias(rewrite_boxed(expr, rewriter)?, name),
+ Expr::Column(_) => self.clone(),
+ Expr::ScalarVariable(names) => Expr::ScalarVariable(names),
+ Expr::Literal(value) => Expr::Literal(value),
+ Expr::BinaryExpr { left, op, right } => Expr::BinaryExpr {
+ left: rewrite_boxed(left, rewriter)?,
+ op,
+ right: rewrite_boxed(right, rewriter)?,
+ },
+ Expr::Not(expr) => Expr::Not(rewrite_boxed(expr, rewriter)?),
+ Expr::IsNotNull(expr) => Expr::IsNotNull(rewrite_boxed(expr, rewriter)?),
+ Expr::IsNull(expr) => Expr::IsNull(rewrite_boxed(expr, rewriter)?),
+ Expr::Negative(expr) => Expr::Negative(rewrite_boxed(expr, rewriter)?),
+ Expr::Between {
+ expr,
+ low,
+ high,
+ negated,
+ } => Expr::Between {
+ expr: rewrite_boxed(expr, rewriter)?,
+ low: rewrite_boxed(low, rewriter)?,
+ high: rewrite_boxed(high, rewriter)?,
+ negated,
+ },
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let expr = rewrite_option_box(expr, rewriter)?;
+ let when_then_expr = when_then_expr
+ .into_iter()
+ .map(|(when, then)| {
+ Ok((
+ rewrite_boxed(when, rewriter)?,
+ rewrite_boxed(then, rewriter)?,
+ ))
+ })
+ .collect::<Result<Vec<_>>>()?;
+
+ let else_expr = rewrite_option_box(else_expr, rewriter)?;
+
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ }
+ }
+ Expr::Cast { expr, data_type } => Expr::Cast {
+ expr: rewrite_boxed(expr, rewriter)?,
+ data_type,
+ },
+ Expr::TryCast { expr, data_type } => Expr::TryCast {
+ expr: rewrite_boxed(expr, rewriter)?,
+ data_type,
+ },
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => Expr::Sort {
+ expr: rewrite_boxed(expr, rewriter)?,
+ asc,
+ nulls_first,
+ },
+ Expr::ScalarFunction { args, fun } => Expr::ScalarFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::ScalarUDF { args, fun } => Expr::ScalarUDF {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::WindowFunction {
+ args,
+ fun,
+ partition_by,
+ order_by,
+ window_frame,
+ } => Expr::WindowFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ partition_by: rewrite_vec(partition_by, rewriter)?,
+ order_by: rewrite_vec(order_by, rewriter)?,
+ window_frame,
+ },
+ Expr::AggregateFunction {
+ args,
+ fun,
+ distinct,
+ } => Expr::AggregateFunction {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ distinct,
+ },
+ Expr::AggregateUDF { args, fun } => Expr::AggregateUDF {
+ args: rewrite_vec(args, rewriter)?,
+ fun,
+ },
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => Expr::InList {
+ expr: rewrite_boxed(expr, rewriter)?,
+ list: rewrite_vec(list, rewriter)?,
+ negated,
+ },
+ Expr::Wildcard => Expr::Wildcard,
+ Expr::GetIndexedField { expr, key } => Expr::GetIndexedField {
+ expr: rewrite_boxed(expr, rewriter)?,
+ key,
+ },
+ };
+
+ // now rewrite this expression itself
+ if need_mutate {
+ rewriter.mutate(expr)
+ } else {
+ Ok(expr)
+ }
+ }
+}
+
+#[allow(clippy::boxed_local)]
+fn rewrite_boxed<R>(boxed_expr: Box<Expr>, rewriter: &mut R) -> Result<Box<Expr>>
+where
+ R: ExprRewriter,
+{
+ // TODO: It might be possible to avoid an allocation (the
+ // Box::new) below by reusing the box.
+ let expr: Expr = *boxed_expr;
+ let rewritten_expr = expr.rewrite(rewriter)?;
+ Ok(Box::new(rewritten_expr))
+}
+
+fn rewrite_option_box<R>(
+ option_box: Option<Box<Expr>>,
+ rewriter: &mut R,
+) -> Result<Option<Box<Expr>>>
+where
+ R: ExprRewriter,
+{
+ option_box
+ .map(|expr| rewrite_boxed(expr, rewriter))
+ .transpose()
+}
+
+/// rewrite a `Vec` of `Expr`s with the rewriter
+fn rewrite_vec<R>(v: Vec<Expr>, rewriter: &mut R) -> Result<Vec<Expr>>
+where
+ R: ExprRewriter,
+{
+ v.into_iter().map(|expr| expr.rewrite(rewriter)).collect()
+}
+
+/// Rewrite sort on aggregate expressions to sort on the column of aggregate output
+/// For example, `max(x)` is rewritten to `col("MAX(x)")`
+pub fn rewrite_sort_cols_by_aggs(
+ exprs: impl IntoIterator<Item = impl Into<Expr>>,
+ plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .map(|e| {
+ let expr = e.into();
+ match expr {
+ Expr::Sort {
+ expr,
+ asc,
+ nulls_first,
+ } => {
+ let sort = Expr::Sort {
+ expr: Box::new(rewrite_sort_col_by_aggs(*expr, plan)?),
+ asc,
+ nulls_first,
+ };
+ Ok(sort)
+ }
+ expr => Ok(expr),
+ }
+ })
+ .collect()
+}
+
+fn rewrite_sort_col_by_aggs(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+ match plan {
+ LogicalPlan::Aggregate(Aggregate {
+ input, aggr_expr, ..
+ }) => {
+ struct Rewriter<'a> {
+ plan: &'a LogicalPlan,
+ input: &'a LogicalPlan,
+ aggr_expr: &'a Vec<Expr>,
+ }
+
+ impl<'a> ExprRewriter for Rewriter<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ let normalized_expr = normalize_col(expr.clone(), self.plan);
+ if normalized_expr.is_err() {
+ // The expr is not based on Aggregate plan output. Skip it.
+ return Ok(expr);
+ }
+ let normalized_expr = normalized_expr.unwrap();
+ if let Some(found_agg) =
+ self.aggr_expr.iter().find(|a| (**a) == normalized_expr)
+ {
+ let agg = normalize_col(found_agg.clone(), self.plan)?;
+ let col = Expr::Column(
+ agg.to_field(self.input.schema())
+ .map(|f| f.qualified_column())?,
+ );
+ Ok(col)
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut Rewriter {
+ plan,
+ input,
+ aggr_expr,
+ })
+ }
+ LogicalPlan::Projection(_) => rewrite_sort_col_by_aggs(expr, plan.inputs()[0]),
+ _ => Ok(expr),
+ }
+}
+
+/// Recursively call [`Column::normalize`] on all Column expressions
+/// in the `expr` expression tree.
+pub fn normalize_col(expr: Expr, plan: &LogicalPlan) -> Result<Expr> {
+ normalize_col_with_schemas(expr, &plan.all_schemas(), &plan.using_columns()?)
+}
+
+/// Recursively call [`Column::normalize`] on all Column expressions
+/// in the `expr` expression tree.
+fn normalize_col_with_schemas(
+ expr: Expr,
+ schemas: &[&Arc<DFSchema>],
+ using_columns: &[HashSet<Column>],
+) -> Result<Expr> {
+ struct ColumnNormalizer<'a> {
+ schemas: &'a [&'a Arc<DFSchema>],
+ using_columns: &'a [HashSet<Column>],
+ }
+
+ impl<'a> ExprRewriter for ColumnNormalizer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = expr {
+ Ok(Expr::Column(c.normalize_with_schemas(
+ self.schemas,
+ self.using_columns,
+ )?))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut ColumnNormalizer {
+ schemas,
+ using_columns,
+ })
+}
+
+/// Recursively normalize all Column expressions in a list of expression trees
+pub fn normalize_cols(
+ exprs: impl IntoIterator<Item = impl Into<Expr>>,
+ plan: &LogicalPlan,
+) -> Result<Vec<Expr>> {
+ exprs
+ .into_iter()
+ .map(|e| normalize_col(e.into(), plan))
+ .collect()
+}
+
+/// Recursively replace all Column expressions in a given expression tree with Column expressions
+/// provided by the hash map argument.
+pub fn replace_col(e: Expr, replace_map: &HashMap<&Column, &Column>) -> Result<Expr> {
+ struct ColumnReplacer<'a> {
+ replace_map: &'a HashMap<&'a Column, &'a Column>,
+ }
+
+ impl<'a> ExprRewriter for ColumnReplacer<'a> {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(c) = &expr {
+ match self.replace_map.get(c) {
+ Some(new_c) => Ok(Expr::Column((*new_c).to_owned())),
+ None => Ok(expr),
+ }
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ e.rewrite(&mut ColumnReplacer { replace_map })
+}
+
+/// Recursively 'unnormalize' (remove all qualifiers) from an
+/// expression tree.
+///
+/// For example, if there were expressions like `foo.bar` this would
+/// rewrite it to just `bar`.
+pub fn unnormalize_col(expr: Expr) -> Expr {
+ struct RemoveQualifier {}
+
+ impl ExprRewriter for RemoveQualifier {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ if let Expr::Column(col) = expr {
+ Ok(Expr::Column(Column {
+ relation: None,
+ name: col.name,
+ }))
+ } else {
+ Ok(expr)
+ }
+ }
+ }
+
+ expr.rewrite(&mut RemoveQualifier {})
+        .expect("Unnormalize is infallible")
+}
+
+/// Recursively un-normalize all Column expressions in a list of expression trees
+#[inline]
+pub fn unnormalize_cols(exprs: impl IntoIterator<Item = Expr>) -> Vec<Expr> {
+ exprs.into_iter().map(unnormalize_col).collect()
+}
+
+#[cfg(test)]
+mod test {
+ use super::*;
+ use crate::logical_plan::DFField;
+ use crate::prelude::{col, lit};
+ use arrow::datatypes::DataType;
+ use datafusion_common::ScalarValue;
+
+ #[derive(Default)]
+ struct RecordingRewriter {
+ v: Vec<String>,
+ }
+ impl ExprRewriter for RecordingRewriter {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ self.v.push(format!("Mutated {:?}", expr));
+ Ok(expr)
+ }
+
+ fn pre_visit(&mut self, expr: &Expr) -> Result<RewriteRecursion> {
+ self.v.push(format!("Previsited {:?}", expr));
+ Ok(RewriteRecursion::Continue)
+ }
+ }
+
+ #[test]
+ fn rewriter_rewrite() {
+ let mut rewriter = FooBarRewriter {};
+
+ // rewrites "foo" --> "bar"
+ let rewritten = col("state").eq(lit("foo")).rewrite(&mut rewriter).unwrap();
+ assert_eq!(rewritten, col("state").eq(lit("bar")));
+
+        // doesn't rewrite
+ let rewritten = col("state").eq(lit("baz")).rewrite(&mut rewriter).unwrap();
+ assert_eq!(rewritten, col("state").eq(lit("baz")));
+ }
+
+ /// rewrites all "foo" string literals to "bar"
+ struct FooBarRewriter {}
+ impl ExprRewriter for FooBarRewriter {
+ fn mutate(&mut self, expr: Expr) -> Result<Expr> {
+ match expr {
+ Expr::Literal(ScalarValue::Utf8(Some(utf8_val))) => {
+ let utf8_val = if utf8_val == "foo" {
+ "bar".to_string()
+ } else {
+ utf8_val
+ };
+ Ok(lit(utf8_val))
+ }
+ // otherwise, return the expression unchanged
+ expr => Ok(expr),
+ }
+ }
+ }
+
+ #[test]
+ fn normalize_cols() {
+ let expr = col("a") + col("b") + col("c");
+
+ // Schemas with some matching and some non matching cols
+ let schema_a =
+ DFSchema::new(vec![make_field("tableA", "a"), make_field("tableA", "aa")])
+ .unwrap();
+ let schema_c =
+ DFSchema::new(vec![make_field("tableC", "cc"), make_field("tableC", "c")])
+ .unwrap();
+ let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+ // non matching
+ let schema_f =
+ DFSchema::new(vec![make_field("tableC", "f"), make_field("tableC", "ff")])
+ .unwrap();
+ let schemas = vec![schema_c, schema_f, schema_b, schema_a]
+ .into_iter()
+ .map(Arc::new)
+ .collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+ assert_eq!(
+ normalized_expr,
+ col("tableA.a") + col("tableB.b") + col("tableC.c")
+ );
+ }
+
+ #[test]
+ fn normalize_cols_priority() {
+ let expr = col("a") + col("b");
+ // Schemas with multiple matches for column a, first takes priority
+ let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+ let schema_b = DFSchema::new(vec![make_field("tableB", "b")]).unwrap();
+ let schema_a2 = DFSchema::new(vec![make_field("tableA2", "a")]).unwrap();
+ let schemas = vec![schema_a2, schema_b, schema_a]
+ .into_iter()
+ .map(Arc::new)
+ .collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let normalized_expr = normalize_col_with_schemas(expr, &schemas, &[]).unwrap();
+ assert_eq!(normalized_expr, col("tableA2.a") + col("tableB.b"));
+ }
+
+ #[test]
+ fn normalize_cols_non_exist() {
+ // test normalizing columns when the name doesn't exist
+ let expr = col("a") + col("b");
+ let schema_a = DFSchema::new(vec![make_field("tableA", "a")]).unwrap();
+ let schemas = vec![schema_a].into_iter().map(Arc::new).collect::<Vec<_>>();
+ let schemas = schemas.iter().collect::<Vec<_>>();
+
+ let error = normalize_col_with_schemas(expr, &schemas, &[])
+ .unwrap_err()
+ .to_string();
+ assert_eq!(
+ error,
+ "Error during planning: Column #b not found in provided schemas"
+ );
+ }
+
+ #[test]
+ fn unnormalize_cols() {
+ let expr = col("tableA.a") + col("tableB.b");
+ let unnormalized_expr = unnormalize_col(expr);
+ assert_eq!(unnormalized_expr, col("a") + col("b"));
+ }
+
+ fn make_field(relation: &str, column: &str) -> DFField {
+ DFField::new(Some(relation), column, DataType::Int8, false)
+ }
+
+ #[test]
+ fn rewriter_visit() {
+ let mut rewriter = RecordingRewriter::default();
+ col("state").eq(lit("CO")).rewrite(&mut rewriter).unwrap();
+
+ assert_eq!(
+ rewriter.v,
+ vec![
+ "Previsited #state = Utf8(\"CO\")",
+ "Previsited #state",
+ "Mutated #state",
+ "Previsited Utf8(\"CO\")",
+ "Mutated Utf8(\"CO\")",
+ "Mutated #state = Utf8(\"CO\")"
+ ]
+ )
+ }
+}
diff --git a/datafusion/src/logical_plan/expr_schema.rs b/datafusion/src/logical_plan/expr_schema.rs
new file mode 100644
index 0000000..7bad353
--- /dev/null
+++ b/datafusion/src/logical_plan/expr_schema.rs
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use super::Expr;
+use crate::field_util::get_indexed_field;
+use crate::physical_plan::{
+ aggregates, expressions::binary_operator_data_type, functions, window_functions,
+};
+use arrow::compute::cast::can_cast_types;
+use arrow::datatypes::DataType;
+use datafusion_common::field_util::FieldExt;
+use datafusion_common::{DFField, DFSchema, DataFusionError, ExprSchema, Result};
+
+/// Trait to allow an expr to be typed with respect to a schema
+pub trait ExprSchemable {
+ /// given a schema, return the type of the expr
+ fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType>;
+
+ /// given a schema, return the nullability of the expr
+ fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool>;
+
+ /// convert to a field with respect to a schema
+ fn to_field(&self, input_schema: &DFSchema) -> Result<DFField>;
+
+ /// cast to a type with respect to a schema
+ fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr>;
+}
+
+impl ExprSchemable for Expr {
+ /// Returns the [arrow::datatypes::DataType] of the expression
+ /// based on [ExprSchema]
+ ///
+ /// Note: [DFSchema] implements [ExprSchema].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is not possible to compute its
+ /// [arrow::datatypes::DataType]. This happens when e.g. the
+ /// expression refers to a column that does not exist in the
+ /// schema, or when the expression is incorrectly typed
+ /// (e.g. `[utf8] + [bool]`).
+ fn get_type<S: ExprSchema>(&self, schema: &S) -> Result<DataType> {
+ match self {
+ Expr::Alias(expr, _) | Expr::Sort { expr, .. } | Expr::Negative(expr) => {
+ expr.get_type(schema)
+ }
+ Expr::Column(c) => Ok(schema.data_type(c)?.clone()),
+ Expr::ScalarVariable(_) => Ok(DataType::Utf8),
+ Expr::Literal(l) => Ok(l.get_datatype()),
+ Expr::Case { when_then_expr, .. } => when_then_expr[0].1.get_type(schema),
+ Expr::Cast { data_type, .. } | Expr::TryCast { data_type, .. } => {
+ Ok(data_type.clone())
+ }
+ Expr::ScalarUDF { fun, args } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ Ok((fun.return_type)(&data_types)?.as_ref().clone())
+ }
+ Expr::ScalarFunction { fun, args } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ functions::return_type(fun, &data_types)
+ }
+ Expr::WindowFunction { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ window_functions::return_type(fun, &data_types)
+ }
+ Expr::AggregateFunction { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ aggregates::return_type(fun, &data_types)
+ }
+ Expr::AggregateUDF { fun, args, .. } => {
+ let data_types = args
+ .iter()
+ .map(|e| e.get_type(schema))
+ .collect::<Result<Vec<_>>>()?;
+ Ok((fun.return_type)(&data_types)?.as_ref().clone())
+ }
+ Expr::Not(_)
+ | Expr::IsNull(_)
+ | Expr::Between { .. }
+ | Expr::InList { .. }
+ | Expr::IsNotNull(_) => Ok(DataType::Boolean),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ref op,
+ } => binary_operator_data_type(
+ &left.get_type(schema)?,
+ op,
+ &right.get_type(schema)?,
+ ),
+ Expr::Wildcard => Err(DataFusionError::Internal(
+ "Wildcard expressions are not valid in a logical query plan".to_owned(),
+ )),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(schema)?;
+
+ get_indexed_field(&data_type, key).map(|x| x.data_type().clone())
+ }
+ }
+ }
+
+ /// Returns the nullability of the expression based on [ExprSchema].
+ ///
+ /// Note: [DFSchema] implements [ExprSchema].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is not possible to compute its
+ /// nullability. This happens when the expression refers to a
+ /// column that does not exist in the schema.
+ fn nullable<S: ExprSchema>(&self, input_schema: &S) -> Result<bool> {
+ match self {
+ Expr::Alias(expr, _)
+ | Expr::Not(expr)
+ | Expr::Negative(expr)
+ | Expr::Sort { expr, .. }
+ | Expr::Between { expr, .. }
+ | Expr::InList { expr, .. } => expr.nullable(input_schema),
+ Expr::Column(c) => input_schema.nullable(c),
+ Expr::Literal(value) => Ok(value.is_null()),
+ Expr::Case {
+ when_then_expr,
+ else_expr,
+ ..
+ } => {
+ // this expression is nullable if any of the input expressions are nullable
+ let then_nullable = when_then_expr
+ .iter()
+ .map(|(_, t)| t.nullable(input_schema))
+ .collect::<Result<Vec<_>>>()?;
+ if then_nullable.contains(&true) {
+ Ok(true)
+ } else if let Some(e) = else_expr {
+ e.nullable(input_schema)
+ } else {
+ Ok(false)
+ }
+ }
+ Expr::Cast { expr, .. } => expr.nullable(input_schema),
+ Expr::ScalarVariable(_)
+ | Expr::TryCast { .. }
+ | Expr::ScalarFunction { .. }
+ | Expr::ScalarUDF { .. }
+ | Expr::WindowFunction { .. }
+ | Expr::AggregateFunction { .. }
+ | Expr::AggregateUDF { .. } => Ok(true),
+ Expr::IsNull(_) | Expr::IsNotNull(_) => Ok(false),
+ Expr::BinaryExpr {
+ ref left,
+ ref right,
+ ..
+ } => Ok(left.nullable(input_schema)? || right.nullable(input_schema)?),
+ Expr::Wildcard => Err(DataFusionError::Internal(
+ "Wildcard expressions are not valid in a logical query plan".to_owned(),
+ )),
+ Expr::GetIndexedField { ref expr, key } => {
+ let data_type = expr.get_type(input_schema)?;
+ get_indexed_field(&data_type, key).map(|x| x.is_nullable())
+ }
+ }
+ }
+
+ /// Returns a [DFField] compatible with this expression.
+ fn to_field(&self, input_schema: &DFSchema) -> Result<DFField> {
+ match self {
+ Expr::Column(c) => Ok(DFField::new(
+ c.relation.as_deref(),
+ &c.name,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ _ => Ok(DFField::new(
+ None,
+ &self.name(input_schema)?,
+ self.get_type(input_schema)?,
+ self.nullable(input_schema)?,
+ )),
+ }
+ }
+
+ /// Wraps this expression in a cast to a target [arrow::datatypes::DataType].
+ ///
+ /// # Errors
+ ///
+ /// This function errors when it is impossible to cast the
+ /// expression to the target [arrow::datatypes::DataType].
+ fn cast_to<S: ExprSchema>(self, cast_to_type: &DataType, schema: &S) -> Result<Expr> {
+ // TODO(kszucs): most of the operations do not validate the type correctness
+ // like all of the binary expressions below. Perhaps Expr should track the
+ // type of the expression?
+ let this_type = self.get_type(schema)?;
+ if this_type == *cast_to_type {
+ Ok(self)
+ } else if can_cast_types(&this_type, cast_to_type) {
+ Ok(Expr::Cast {
+ expr: Box::new(self),
+ data_type: cast_to_type.clone(),
+ })
+ } else {
+ Err(DataFusionError::Plan(format!(
+ "Cannot automatically convert {:?} to {:?}",
+ this_type, cast_to_type
+ )))
+ }
+ }
+}
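For orientation, a minimal sketch of how the new ExprSchemable trait is meant to be used (this is not part of the commit; the single-column schema, the `example` function, and the assumption that ExprSchemable is re-exported from datafusion::logical_plan are all illustrative):

use arrow::datatypes::DataType;
use datafusion::error::Result;
use datafusion::logical_plan::{col, DFField, DFSchema, ExprSchemable};

fn example() -> Result<()> {
    // a schema with a single nullable Int32 column "a" (illustrative)
    let schema =
        DFSchema::new(vec![DFField::new(None, "a", DataType::Int32, true)])?;

    // DFSchema implements ExprSchema, so it can type the expression directly
    let expr = col("a");
    assert_eq!(expr.get_type(&schema)?, DataType::Int32);
    assert!(expr.nullable(&schema)?);

    // cast_to is a no-op for matching types and wraps the expression
    // in Expr::Cast when arrow's can_cast_types allows the conversion
    let cast = expr.cast_to(&DataType::Int64, &schema)?;
    assert_eq!(cast.get_type(&schema)?, DataType::Int64);
    Ok(())
}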
diff --git a/datafusion/src/logical_plan/expr_simplier.rs b/datafusion/src/logical_plan/expr_simplier.rs
new file mode 100644
index 0000000..06e5856
--- /dev/null
+++ b/datafusion/src/logical_plan/expr_simplier.rs
@@ -0,0 +1,97 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Expression simplifier
+
+use super::Expr;
+use super::ExprRewritable;
+use crate::execution::context::ExecutionProps;
+use crate::optimizer::simplify_expressions::{ConstEvaluator, Simplifier};
+use datafusion_common::Result;
+
+/// The information necessary to apply algebraic simplification to an
+/// [Expr]. See [SimplifyContext] for one implementation
+pub trait SimplifyInfo {
+ /// returns true if this Expr has boolean type
+ fn is_boolean_type(&self, expr: &Expr) -> Result<bool>;
+
+ /// returns true if this expr is nullable (could possibly be NULL)
+ fn nullable(&self, expr: &Expr) -> Result<bool>;
+
+ /// Returns details needed for partial expression evaluation
+ fn execution_props(&self) -> &ExecutionProps;
+}
+
+/// trait for types that can be simplified
+pub trait ExprSimplifiable: Sized {
+ /// simplify this expression using the given SimplifyInfo
+ fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self>;
+}
+
+impl ExprSimplifiable for Expr {
+ /// Simplifies this [`Expr`] as much as possible, evaluating
+ /// constants and applying algebraic simplifications
+ ///
+ /// # Example:
+ /// `b > 2 AND b > 2`
+ /// can be simplified to
+ /// `b > 2`
+ ///
+ /// ```
+ /// use datafusion::logical_plan::*;
+ /// use datafusion::error::Result;
+ /// use datafusion::execution::context::ExecutionProps;
+ ///
+ /// /// Simple implementation that provides `Simplifier` with the information it needs
+ /// #[derive(Default)]
+ /// struct Info {
+ /// execution_props: ExecutionProps,
+ /// }
+ ///
+ /// impl SimplifyInfo for Info {
+ /// fn is_boolean_type(&self, expr: &Expr) -> Result<bool> {
+ /// Ok(false)
+ /// }
+ /// fn nullable(&self, expr: &Expr) -> Result<bool> {
+ /// Ok(true)
+ /// }
+ /// fn execution_props(&self) -> &ExecutionProps {
+ /// &self.execution_props
+ /// }
+ /// }
+ ///
+ /// // b > 2
+ /// let b_gt_2 = col("b").gt(lit(2));
+ ///
+ /// // (b > 2) OR (b > 2)
+ /// let expr = b_gt_2.clone().or(b_gt_2.clone());
+ ///
+ /// // (b > 2) OR (b > 2) --> (b > 2)
+ /// let expr = expr.simplify(&Info::default()).unwrap();
+ /// assert_eq!(expr, b_gt_2);
+ /// ```
+ fn simplify<S: SimplifyInfo>(self, info: &S) -> Result<Self> {
+ let mut rewriter = Simplifier::new(info);
+ let mut const_evaluator = ConstEvaluator::new(info.execution_props());
+
+ // TODO iterate until no changes are made during rewrite
+ // (evaluating constants can enable new simplifications and
+ // simplifications can enable new constant evaluation)
+ // https://github.com/apache/arrow-datafusion/issues/1160
+ self.rewrite(&mut const_evaluator)?.rewrite(&mut rewriter)
+ }
+}
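The TODO in `simplify` notes that constant folding and algebraic rewrites can each enable more of the other. A hedged sketch of the fixpoint loop that comment envisions (the commit itself performs only a single pass; `simplify_to_fixpoint` and its pass bound are assumptions made here for illustration):

fn simplify_to_fixpoint<S: SimplifyInfo>(mut expr: Expr, info: &S) -> Result<Expr> {
    // bound the number of passes so a pathological rewrite cannot loop forever
    const MAX_PASSES: usize = 8;
    for _ in 0..MAX_PASSES {
        // simplify consumes its receiver, so clone before comparing
        let simplified = expr.clone().simplify(info)?;
        if simplified == expr {
            // nothing changed: a fixpoint has been reached
            return Ok(simplified);
        }
        expr = simplified;
    }
    Ok(expr)
}

This relies only on `Expr` implementing `Clone` and `PartialEq`, both of which it already derives.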
diff --git a/datafusion/src/logical_plan/expr_visitor.rs b/datafusion/src/logical_plan/expr_visitor.rs
new file mode 100644
index 0000000..26084fb
--- /dev/null
+++ b/datafusion/src/logical_plan/expr_visitor.rs
@@ -0,0 +1,176 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Expression visitor
+
+use super::Expr;
+use datafusion_common::Result;
+
+/// Controls how the visitor recursion should proceed.
+pub enum Recursion<V: ExpressionVisitor> {
+ /// Attempt to visit all the children, recursively, of this expression.
+ Continue(V),
+ /// Do not visit the children of this expression; the rest of the
+ /// walk (siblings and ancestors of this expression) is unaffected
+ Stop(V),
+}
+
+/// Encode the traversal of an expression tree. When passed to
+/// `Expr::accept`, `ExpressionVisitor::visit` is invoked
+/// recursively on all nodes of an expression tree. See the comments
+/// on `Expr::accept` for details on its use
+pub trait ExpressionVisitor<E: ExprVisitable = Expr>: Sized {
+ /// Invoked before any children of `expr` are visited.
+ fn pre_visit(self, expr: &E) -> Result<Recursion<Self>>
+ where
+ Self: ExpressionVisitor;
+
+ /// Invoked after all children of `expr` are visited. Default
+ /// implementation does nothing.
+ fn post_visit(self, _expr: &E) -> Result<Self> {
+ Ok(self)
+ }
+}
+
+/// trait for types that can be visited by [`ExpressionVisitor`]
+pub trait ExprVisitable: Sized {
+ /// accept a visitor, invoking its methods on this expression and all of its children
+ fn accept<V: ExpressionVisitor<Self>>(&self, visitor: V) -> Result<V>;
+}
+
+impl ExprVisitable for Expr {
+ /// Performs a depth first walk of an expression and
+ /// its children, calling [`ExpressionVisitor::pre_visit`] and
+ /// [`ExpressionVisitor::post_visit`].
+ ///
+ /// Implements the [visitor pattern](https://en.wikipedia.org/wiki/Visitor_pattern) to
+ /// separate expression algorithms from the structure of the
+ /// `Expr` tree and make it easier to add new types of expressions
+ /// and algorithms that walk the tree.
+ ///
+ /// For an expression tree such as
+ /// ```text
+ /// BinaryExpr (GT)
+ /// left: Column("foo")
+ /// right: Column("bar")
+ /// ```
+ ///
+ /// The nodes are visited using the following order
+ /// ```text
+ /// pre_visit(BinaryExpr(GT))
+ /// pre_visit(Column("foo"))
+ /// pre_visit(Column("bar"))
+ /// post_visit(Column("bar"))
+ /// post_visit(Column("bar"))
+ /// post_visit(BinaryExpr(GT))
+ /// ```
+ ///
+ /// If an Err result is returned, recursion is stopped immediately
+ ///
+ /// If `Recursion::Stop` is returned on a call to pre_visit, no
+ /// children of that expression are visited, nor is post_visit
+ /// called on that expression
+ ///
+ fn accept<V: ExpressionVisitor>(&self, visitor: V) -> Result<V> {
+ let visitor = match visitor.pre_visit(self)? {
+ Recursion::Continue(visitor) => visitor,
+ // If the recursion should stop, do not visit children
+ Recursion::Stop(visitor) => return Ok(visitor),
+ };
+
+ // recurse (and cover all expression types)
+ let visitor = match self {
+ Expr::Alias(expr, _)
+ | Expr::Not(expr)
+ | Expr::IsNotNull(expr)
+ | Expr::IsNull(expr)
+ | Expr::Negative(expr)
+ | Expr::Cast { expr, .. }
+ | Expr::TryCast { expr, .. }
+ | Expr::Sort { expr, .. }
+ | Expr::GetIndexedField { expr, .. } => expr.accept(visitor),
+ Expr::Column(_)
+ | Expr::ScalarVariable(_)
+ | Expr::Literal(_)
+ | Expr::Wildcard => Ok(visitor),
+ Expr::BinaryExpr { left, right, .. } => {
+ let visitor = left.accept(visitor)?;
+ right.accept(visitor)
+ }
+ Expr::Between {
+ expr, low, high, ..
+ } => {
+ let visitor = expr.accept(visitor)?;
+ let visitor = low.accept(visitor)?;
+ high.accept(visitor)
+ }
+ Expr::Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ } => {
+ let visitor = if let Some(expr) = expr.as_ref() {
+ expr.accept(visitor)
+ } else {
+ Ok(visitor)
+ }?;
+ let visitor = when_then_expr.iter().try_fold(
+ visitor,
+ |visitor, (when, then)| {
+ let visitor = when.accept(visitor)?;
+ then.accept(visitor)
+ },
+ )?;
+ if let Some(else_expr) = else_expr.as_ref() {
+ else_expr.accept(visitor)
+ } else {
+ Ok(visitor)
+ }
+ }
+ Expr::ScalarFunction { args, .. }
+ | Expr::ScalarUDF { args, .. }
+ | Expr::AggregateFunction { args, .. }
+ | Expr::AggregateUDF { args, .. } => args
+ .iter()
+ .try_fold(visitor, |visitor, arg| arg.accept(visitor)),
+ Expr::WindowFunction {
+ args,
... 5971 lines suppressed ...
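To make the visitor API above concrete, an illustrative visitor that collects every column name referenced by an expression tree (an assumption for this write-up, not code from the commit; only the trait shapes come from the diff):

use datafusion::logical_plan::{Expr, ExpressionVisitor, ExprVisitable, Recursion};
use datafusion_common::Result;

#[derive(Default)]
struct ColumnCollector {
    columns: Vec<String>,
}

impl ExpressionVisitor for ColumnCollector {
    fn pre_visit(mut self, expr: &Expr) -> Result<Recursion<Self>> {
        // record column references; all other node kinds are just traversed
        if let Expr::Column(c) = expr {
            self.columns.push(c.name.clone());
        }
        // keep descending into the children of this node
        Ok(Recursion::Continue(self))
    }
}

// usage sketch: walk an expression and read back the collected names
// let collector = some_expr.accept(ColumnCollector::default())?;
// collector.columns now holds every referenced column, in pre-order

Because `accept` threads the visitor through by value, `pre_visit` takes and returns ownership rather than `&mut self`; returning `Recursion::Stop(self)` instead would skip the children of the current node without affecting the rest of the walk.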