You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ag...@apache.org on 2021/05/04 12:25:04 UTC
[arrow-datafusion] branch master updated: Add datafusion-python
(#69)
This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 46bde0b Add datafusion-python (#69)
46bde0b is described below
commit 46bde0bd148aacf1677a575cb9ddbc154b6c4fb3
Author: Jorge Leitao <jo...@gmail.com>
AuthorDate: Tue May 4 14:24:57 2021 +0200
Add datafusion-python (#69)
* Added Python project.
* Update python/Cargo.toml
Co-authored-by: Andy Grove <an...@users.noreply.github.com>
* Update python/Cargo.toml
Co-authored-by: Uwe L. Korn <xh...@users.noreply.github.com>
* Added license and black formatting.
* License
* Fixing build.
* TesTestt
* Bumped to latest DataFusion.
* Bumped nightly.
* Bumped pyarrow in tests.
* Added some tests back.
Co-authored-by: Andy Grove <an...@users.noreply.github.com>
Co-authored-by: Uwe L. Korn <xh...@users.noreply.github.com>
---
.github/workflows/python_build.yml | 89 ++++++++++
.github/workflows/python_test.yaml | 58 +++++++
Cargo.toml | 4 +-
dev/release/rat_exclude_files.txt | 1 +
Cargo.toml => python/.cargo/config | 16 +-
Cargo.toml => python/.dockerignore | 13 +-
Cargo.toml => python/.gitignore | 14 +-
python/Cargo.toml | 57 +++++++
python/README.md | 146 ++++++++++++++++
Cargo.toml => python/pyproject.toml | 14 +-
python/rust-toolchain | 1 +
python/src/context.rs | 115 +++++++++++++
python/src/dataframe.rs | 161 ++++++++++++++++++
python/src/errors.rs | 61 +++++++
python/src/expression.rs | 162 ++++++++++++++++++
python/src/functions.rs | 165 ++++++++++++++++++
python/src/lib.rs | 44 +++++
python/src/scalar.rs | 36 ++++
python/src/to_py.rs | 77 +++++++++
python/src/to_rust.rs | 111 +++++++++++++
python/src/types.rs | 76 +++++++++
python/src/udaf.rs | 147 +++++++++++++++++
python/src/udf.rs | 62 +++++++
Cargo.toml => python/tests/__init__.py | 12 --
python/tests/generic.py | 75 +++++++++
python/tests/test_df.py | 115 +++++++++++++
python/tests/test_sql.py | 294 +++++++++++++++++++++++++++++++++
python/tests/test_udaf.py | 91 ++++++++++
28 files changed, 2160 insertions(+), 57 deletions(-)
diff --git a/.github/workflows/python_build.yml b/.github/workflows/python_build.yml
new file mode 100644
index 0000000..c86bb81
--- /dev/null
+++ b/.github/workflows/python_build.yml
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Build
+on:
+ push:
+ tags:
+ - v*
+
+jobs:
+ build-python-mac-win:
+ name: Mac/Win
+ runs-on: ${{ matrix.os }}
+ strategy:
+ fail-fast: false
+ matrix:
+ python-version: [3.6, 3.7, 3.8]
+ os: [macos-latest, windows-latest]
+ steps:
+ - uses: actions/checkout@v2
+
+ - uses: actions/setup-python@v1
+ with:
+ python-version: ${{ matrix.python-version }}
+
+ - uses: actions-rs/toolchain@v1
+ with:
+ toolchain: nightly-2021-01-06
+
+ - name: Install dependencies
+ run: |
+ python -m pip install --upgrade pip
+ pip install maturin
+
+ - name: Build Python package
+ run: cd python && maturin build --release --no-sdist --strip --interpreter python${{matrix.python_version}}
+
+ - name: List wheels
+ if: matrix.os == 'windows-latest'
+ run: dir python/target\wheels\
+
+ - name: List wheels
+ if: matrix.os != 'windows-latest'
+ run: find ./python/target/wheels/
+
+ - name: Archive wheels
+ uses: actions/upload-artifact@v2
+ with:
+ name: dist
+ path: python/target/wheels/*
+
+ build-manylinux:
+ name: Manylinux
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Build wheels
+ run: docker run --rm -v $(pwd):/io konstin2/maturin build --release --manylinux
+ - name: Archive wheels
+ uses: actions/upload-artifact@v2
+ with:
+ name: dist
+ path: python/target/wheels/*
+
+ release:
+ name: Publish in PyPI
+ needs: [build-manylinux, build-python-mac-win]
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/download-artifact@v2
+ - name: Publish to PyPI
+ uses: pypa/gh-action-pypi-publish@master
+ with:
+ user: __token__
+ password: ${{ secrets.pypi_password }}
diff --git a/.github/workflows/python_test.yaml b/.github/workflows/python_test.yaml
new file mode 100644
index 0000000..3b2111b
--- /dev/null
+++ b/.github/workflows/python_test.yaml
@@ -0,0 +1,58 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+name: Python test
+on: [push, pull_request]
+
+jobs:
+ test:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ - name: Setup Rust toolchain
+ run: |
+ rustup toolchain install nightly-2021-01-06
+ rustup default nightly-2021-01-06
+ rustup component add rustfmt
+ - name: Cache Cargo
+ uses: actions/cache@v2
+ with:
+ path: /home/runner/.cargo
+ key: cargo-maturin-cache-
+ - name: Cache Rust dependencies
+ uses: actions/cache@v2
+ with:
+ path: /home/runner/target
+ key: target-maturin-cache-
+ - uses: actions/setup-python@v2
+ with:
+ python-version: '3.7'
+ - name: Install Python dependencies
+ run: python -m pip install --upgrade pip setuptools wheel
+ - name: Run tests
+ run: |
+ cd python/
+ export CARGO_HOME="/home/runner/.cargo"
+ export CARGO_TARGET_DIR="/home/runner/target"
+
+ python -m venv venv
+ source venv/bin/activate
+
+ pip install maturin==0.10.4 toml==0.10.1 pyarrow==4.0.0
+ maturin develop
+
+ python -m unittest discover tests
diff --git a/Cargo.toml b/Cargo.toml
index fa36a0c..9795cb6 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -25,4 +25,6 @@ members = [
"ballista/rust/core",
"ballista/rust/executor",
"ballista/rust/scheduler",
-]
\ No newline at end of file
+]
+
+exclude = ["python"]
diff --git a/dev/release/rat_exclude_files.txt b/dev/release/rat_exclude_files.txt
index b94c0ea..6126699 100644
--- a/dev/release/rat_exclude_files.txt
+++ b/dev/release/rat_exclude_files.txt
@@ -104,3 +104,4 @@ rust-toolchain
benchmarks/queries/q*.sql
ballista/rust/scheduler/testdata/*
ballista/ui/scheduler/yarn.lock
+python/rust-toolchain
diff --git a/Cargo.toml b/python/.cargo/config
similarity index 77%
copy from Cargo.toml
copy to python/.cargo/config
index fa36a0c..0b24f30 100644
--- a/Cargo.toml
+++ b/python/.cargo/config
@@ -15,14 +15,8 @@
# specific language governing permissions and limitations
# under the License.
-[workspace]
-members = [
- "datafusion",
- "datafusion-cli",
- "datafusion-examples",
- "benchmarks",
- "ballista/rust/client",
- "ballista/rust/core",
- "ballista/rust/executor",
- "ballista/rust/scheduler",
-]
\ No newline at end of file
+[target.x86_64-apple-darwin]
+rustflags = [
+ "-C", "link-arg=-undefined",
+ "-C", "link-arg=dynamic_lookup",
+]
diff --git a/Cargo.toml b/python/.dockerignore
similarity index 77%
copy from Cargo.toml
copy to python/.dockerignore
index fa36a0c..08c131c 100644
--- a/Cargo.toml
+++ b/python/.dockerignore
@@ -15,14 +15,5 @@
# specific language governing permissions and limitations
# under the License.
-[workspace]
-members = [
- "datafusion",
- "datafusion-cli",
- "datafusion-examples",
- "benchmarks",
- "ballista/rust/client",
- "ballista/rust/core",
- "ballista/rust/executor",
- "ballista/rust/scheduler",
-]
\ No newline at end of file
+target
+venv
diff --git a/Cargo.toml b/python/.gitignore
similarity index 77%
copy from Cargo.toml
copy to python/.gitignore
index fa36a0c..48fe4db 100644
--- a/Cargo.toml
+++ b/python/.gitignore
@@ -15,14 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-[workspace]
-members = [
- "datafusion",
- "datafusion-cli",
- "datafusion-examples",
- "benchmarks",
- "ballista/rust/client",
- "ballista/rust/core",
- "ballista/rust/executor",
- "ballista/rust/scheduler",
-]
\ No newline at end of file
+/target
+Cargo.lock
+venv
diff --git a/python/Cargo.toml b/python/Cargo.toml
new file mode 100644
index 0000000..0707205
--- /dev/null
+++ b/python/Cargo.toml
@@ -0,0 +1,57 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+[package]
+name = "datafusion"
+version = "0.2.1"
+homepage = "https://github.com/apache/arrow"
+repository = "https://github.com/apache/arrow"
+authors = ["Apache Arrow <de...@arrow.apache.org>"]
+description = "Build and run queries against data"
+readme = "README.md"
+license = "Apache-2.0"
+edition = "2018"
+
+[dependencies]
+tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
+rand = "0.7"
+pyo3 = { version = "0.12.1", features = ["extension-module"] }
+datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "2423ff0d" }
+
+[lib]
+name = "datafusion"
+crate-type = ["cdylib"]
+
+[package.metadata.maturin]
+requires-dist = ["pyarrow>=1"]
+
+classifier = [
+ "Development Status :: 2 - Pre-Alpha",
+ "Intended Audience :: Developers",
+ "License :: OSI Approved :: Apache Software License",
+ "License :: OSI Approved",
+ "Operating System :: MacOS",
+ "Operating System :: Microsoft :: Windows",
+ "Operating System :: POSIX :: Linux",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.6",
+ "Programming Language :: Python :: 3.7",
+ "Programming Language :: Python :: 3.8",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python",
+ "Programming Language :: Rust",
+]
diff --git a/python/README.md b/python/README.md
new file mode 100644
index 0000000..1859fca
--- /dev/null
+++ b/python/README.md
@@ -0,0 +1,146 @@
+<!---
+ Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+-->
+
+## DataFusion in Python
+
+This is a Python library that binds to [Apache Arrow](https://arrow.apache.org/)'s in-memory query engine, [DataFusion](https://github.com/apache/arrow-datafusion).
+
+Like pyspark, it allows you to build a plan through SQL or a DataFrame API against in-memory data, parquet or CSV files, run it in a multi-threaded environment, and obtain the result back in Python.
+
+It also allows you to use UDFs and UDAFs for complex operations.
+
+The major advantage of this library over other execution engines is that this library achieves zero-copy between Python and its execution engine: there is no cost in using UDFs, UDAFs, and collecting the results to Python apart from having to lock the GIL when running those operations.
+
+Its query engine, DataFusion, is written in [Rust](https://www.rust-lang.org/), which provides strong guarantees about thread safety and memory safety.
+
+Technically, zero-copy is achieved via the [C data interface](https://arrow.apache.org/docs/format/CDataInterface.html).
+
+## How to use it
+
+Simple usage:
+
+```python
+import datafusion
+import pyarrow
+
+# an alias
+f = datafusion.functions
+
+# create a context
+ctx = datafusion.ExecutionContext()
+
+# create a RecordBatch and a new DataFrame from it
+batch = pyarrow.RecordBatch.from_arrays(
+ [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+ names=["a", "b"],
+)
+df = ctx.create_dataframe([[batch]])
+
+# create a new statement
+df = df.select(
+ f.col("a") + f.col("b"),
+ f.col("a") - f.col("b"),
+)
+
+# execute and collect the first (and only) batch
+result = df.collect()[0]
+
+assert result.column(0) == pyarrow.array([5, 7, 9])
+assert result.column(1) == pyarrow.array([-3, -3, -3])
+```
+
+### UDFs
+
+```python
+def is_null(array: pyarrow.Array) -> pyarrow.Array:
+ return array.is_null()
+
+udf = f.udf(is_null, [pyarrow.int64()], pyarrow.bool_())
+
+df = df.select(udf(f.col("a")))
+```
+
+### UDAF
+
+```python
+import pyarrow
+import pyarrow.compute
+
+
+class Accumulator:
+ """
+ Interface of a user-defined accumulation.
+ """
+ def __init__(self):
+ self._sum = pyarrow.scalar(0.0)
+
+ def to_scalars(self) -> [pyarrow.Scalar]:
+ return [self._sum]
+
+ def update(self, values: pyarrow.Array) -> None:
+ # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
+ self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(values).as_py())
+
+ def merge(self, states: pyarrow.Array) -> None:
+ # not nice since pyarrow scalars can't be summed yet. This breaks on `None`
+ self._sum = pyarrow.scalar(self._sum.as_py() + pyarrow.compute.sum(states).as_py())
+
+ def evaluate(self) -> pyarrow.Scalar:
+ return self._sum
+
+
+df = ...
+
+udaf = f.udaf(Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()])
+
+df = df.aggregate(
+ [],
+ [udaf(f.col("a"))]
+)
+```
+
+## How to install
+
+```bash
+pip install datafusion
+```
+
+## How to develop
+
+This assumes that you have rust and cargo installed. We use the workflow recommended by [pyo3](https://github.com/PyO3/pyo3) and [maturin](https://github.com/PyO3/maturin).
+
+Bootstrap:
+
+```bash
+# fetch this repo
+git clone git@github.com:apache/arrow-datafusion.git
+
+cd arrow-datafusion/python
+
+# prepare development environment (used to build wheel / install in development)
+python3 -m venv venv
+venv/bin/pip install maturin==0.10.4 toml==0.10.1 pyarrow==1.0.0
+```
+
+Whenever rust code changes (your changes or via git pull):
+
+```bash
+venv/bin/maturin develop
+venv/bin/python -m unittest discover tests
+```
diff --git a/Cargo.toml b/python/pyproject.toml
similarity index 77%
copy from Cargo.toml
copy to python/pyproject.toml
index fa36a0c..2748069 100644
--- a/Cargo.toml
+++ b/python/pyproject.toml
@@ -15,14 +15,6 @@
# specific language governing permissions and limitations
# under the License.
-[workspace]
-members = [
- "datafusion",
- "datafusion-cli",
- "datafusion-examples",
- "benchmarks",
- "ballista/rust/client",
- "ballista/rust/core",
- "ballista/rust/executor",
- "ballista/rust/scheduler",
-]
\ No newline at end of file
+[build-system]
+requires = ["maturin"]
+build-backend = "maturin"
diff --git a/python/rust-toolchain b/python/rust-toolchain
new file mode 100644
index 0000000..9d0cf79
--- /dev/null
+++ b/python/rust-toolchain
@@ -0,0 +1 @@
+nightly-2021-01-06
diff --git a/python/src/context.rs b/python/src/context.rs
new file mode 100644
index 0000000..14ef0f7
--- /dev/null
+++ b/python/src/context.rs
@@ -0,0 +1,115 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{collections::HashSet, sync::Arc};
+
+use rand::distributions::Alphanumeric;
+use rand::Rng;
+
+use pyo3::prelude::*;
+
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::MemTable;
+use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+
+use crate::dataframe;
+use crate::errors;
+use crate::functions;
+use crate::to_rust;
+use crate::types::PyDataType;
+
+/// `ExecutionContext` is able to plan and execute DataFusion plans.
+/// It has a powerful optimizer, a physical planner for local execution, and a
+/// multi-threaded execution engine to perform the execution.
+#[pyclass(unsendable)]
+pub(crate) struct ExecutionContext {
+    /// The wrapped DataFusion context; tables and UDFs are registered on it.
+    ctx: _ExecutionContext,
+}
+
+#[pymethods]
+impl ExecutionContext {
+    /// Creates a context backed by a freshly constructed DataFusion context.
+    #[new]
+    fn new() -> Self {
+        ExecutionContext {
+            ctx: _ExecutionContext::new(),
+        }
+    }
+
+    /// Returns a DataFrame whose plan corresponds to the SQL statement.
+    ///
+    /// Any error DataFusion reports for `query` is raised as a Python exception.
+    fn sql(&mut self, query: &str) -> PyResult<dataframe::DataFrame> {
+        let df = errors::wrap(self.ctx.sql(query))?;
+        Ok(dataframe::DataFrame::new(
+            self.ctx.state.clone(),
+            df.to_logical_plan(),
+        ))
+    }
+
+    /// Creates a DataFrame from in-memory record batches.
+    ///
+    /// `partitions` is a list of partitions, each being a list of record batches.
+    /// The batches are registered as an in-memory table under a random name and
+    /// the returned DataFrame scans that table.
+    ///
+    /// Raises an exception when no record batch is passed at all, because the
+    /// schema of the table is taken from the first available batch.
+    fn create_dataframe(
+        &mut self,
+        partitions: Vec<Vec<PyObject>>,
+        py: Python,
+    ) -> PyResult<dataframe::DataFrame> {
+        let partitions: Vec<Vec<RecordBatch>> = partitions
+            .iter()
+            .map(|batches| {
+                batches
+                    .iter()
+                    .map(|batch| to_rust::to_rust_batch(batch.as_ref(py)))
+                    .collect()
+            })
+            .collect::<PyResult<_>>()?;
+
+        // Look for the first batch in any partition instead of indexing `[0][0]`,
+        // which panics on an empty list or an empty first partition.
+        let schema = partitions
+            .iter()
+            .flatten()
+            .next()
+            .map(|batch| batch.schema())
+            .ok_or_else(|| {
+                errors::DataFusionError::Common(
+                    "Cannot create a DataFrame without at least one record batch"
+                        .to_string(),
+                )
+            })?;
+
+        let table = errors::wrap(MemTable::try_new(schema, partitions))?;
+
+        // generate a random (unique) name for this table
+        let name = rand::thread_rng()
+            .sample_iter(&Alphanumeric)
+            .take(10)
+            .collect::<String>();
+
+        errors::wrap(self.ctx.register_table(&*name, Arc::new(table)))?;
+        Ok(dataframe::DataFrame::new(
+            self.ctx.state.clone(),
+            errors::wrap(self.ctx.table(&*name))?.to_logical_plan(),
+        ))
+    }
+
+    /// Registers the parquet data at `path` as a table called `name`.
+    fn register_parquet(&mut self, name: &str, path: &str) -> PyResult<()> {
+        errors::wrap(self.ctx.register_parquet(name, path))?;
+        Ok(())
+    }
+
+    /// Registers the Python callable `func` as a scalar UDF called `name`,
+    /// taking arguments of `args_types` and returning `return_type`.
+    fn register_udf(
+        &mut self,
+        name: &str,
+        func: PyObject,
+        args_types: Vec<PyDataType>,
+        return_type: PyDataType,
+    ) {
+        let function = functions::create_udf(func, args_types, return_type, name);
+
+        self.ctx.register_udf(function.function);
+    }
+
+    /// Returns the names of the tables known to this context.
+    ///
+    /// A failure to list them is raised as a Python exception instead of panicking.
+    fn tables(&self) -> PyResult<HashSet<String>> {
+        Ok(errors::wrap(self.ctx.tables())?)
+    }
+}
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
new file mode 100644
index 0000000..f90a7cf
--- /dev/null
+++ b/python/src/dataframe.rs
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::{Arc, Mutex};
+
+use logical_plan::LogicalPlan;
+use pyo3::{prelude::*, types::PyTuple};
+use tokio::runtime::Runtime;
+
+use datafusion::execution::context::ExecutionContext as _ExecutionContext;
+use datafusion::logical_plan::{JoinType, LogicalPlanBuilder};
+use datafusion::physical_plan::collect;
+use datafusion::{execution::context::ExecutionContextState, logical_plan};
+
+use crate::{errors, to_py};
+use crate::{errors::DataFusionError, expression};
+
+/// A DataFrame is a representation of a logical plan and an API to compose statements.
+/// Use it to build a plan and `.collect()` to execute the plan and collect the result.
+/// The actual execution of a plan runs natively on Rust and Arrow on a multi-threaded environment.
+#[pyclass]
+pub(crate) struct DataFrame {
+    /// State shared with the context this DataFrame was created from.
+    ctx_state: Arc<Mutex<ExecutionContextState>>,
+    /// The logical plan composed so far; it only runs when `collect` is called.
+    plan: LogicalPlan,
+}
+
+impl DataFrame {
+    /// creates a new DataFrame
+    pub fn new(ctx_state: Arc<Mutex<ExecutionContextState>>, plan: LogicalPlan) -> Self {
+        Self { ctx_state, plan }
+    }
+
+    /// Builds the plan held by `builder` and wraps it in a DataFrame that shares
+    /// this DataFrame's context state.
+    ///
+    /// Lives outside of the `#[pymethods]` block so it is not exposed to Python.
+    fn finish(&self, builder: LogicalPlanBuilder) -> PyResult<Self> {
+        let plan = errors::wrap(builder.build())?;
+        Ok(Self::new(self.ctx_state.clone(), plan))
+    }
+}
+
+#[pymethods]
+impl DataFrame {
+    /// Select `expressions` from the existing DataFrame.
+    #[args(args = "*")]
+    fn select(&self, args: &PyTuple) -> PyResult<Self> {
+        let expressions = expression::from_tuple(args)?;
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder =
+            errors::wrap(builder.project(expressions.into_iter().map(|e| e.expr)))?;
+        self.finish(builder)
+    }
+
+    /// Filter according to the `predicate` expression
+    fn filter(&self, predicate: expression::Expression) -> PyResult<Self> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.filter(predicate.expr))?;
+        self.finish(builder)
+    }
+
+    /// Aggregates using expressions
+    fn aggregate(
+        &self,
+        group_by: Vec<expression::Expression>,
+        aggs: Vec<expression::Expression>,
+    ) -> PyResult<Self> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.aggregate(
+            group_by.into_iter().map(|e| e.expr),
+            aggs.into_iter().map(|e| e.expr),
+        ))?;
+        self.finish(builder)
+    }
+
+    /// Limits the plan to return at most `count` rows
+    fn limit(&self, count: usize) -> PyResult<Self> {
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.limit(count))?;
+        self.finish(builder)
+    }
+
+    /// Executes the plan, returning a list of `RecordBatch`es.
+    /// Unless some order is specified in the plan, there is no guarantee of the order of the result
+    ///
+    /// Planning, runtime start-up and execution failures are all raised as Python exceptions.
+    fn collect(&self, py: Python) -> PyResult<PyObject> {
+        let ctx = _ExecutionContext::from(self.ctx_state.clone());
+        let plan = errors::wrap(ctx.optimize(&self.plan))?;
+        let plan = errors::wrap(ctx.create_physical_plan(&plan))?;
+
+        // Starting the runtime can fail; report it instead of panicking.
+        let rt = Runtime::new().map_err(|e| {
+            DataFusionError::Common(format!("Unable to start the tokio runtime: {}", e))
+        })?;
+
+        // The GIL is released while the plan runs on the tokio runtime.
+        let batches = py.allow_threads(|| {
+            rt.block_on(async { errors::wrap(collect(plan).await) })
+        })?;
+        to_py::to_py(&batches)
+    }
+
+    /// Returns the join of two DataFrames `on`.
+    ///
+    /// `how` must be one of `"inner"`, `"left"` or `"right"`; the columns in `on`
+    /// are used as the join keys on both sides.
+    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
+        let join_type = match how {
+            "inner" => JoinType::Inner,
+            "left" => JoinType::Left,
+            "right" => JoinType::Right,
+            how => {
+                return Err(DataFusionError::Common(format!(
+                    "The join type {} does not exist or is not implemented",
+                    how
+                ))
+                .into())
+            }
+        };
+
+        let builder = LogicalPlanBuilder::from(&self.plan);
+        let builder = errors::wrap(builder.join(
+            &right.plan,
+            join_type,
+            on.as_slice(),
+            on.as_slice(),
+        ))?;
+        self.finish(builder)
+    }
+}
diff --git a/python/src/errors.rs b/python/src/errors.rs
new file mode 100644
index 0000000..fbe9803
--- /dev/null
+++ b/python/src/errors.rs
@@ -0,0 +1,61 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use core::fmt;
+
+use datafusion::arrow::error::ArrowError;
+use datafusion::error::DataFusionError as InnerDataFusionError;
+use pyo3::{exceptions, PyErr};
+
+/// Errors produced by these bindings.
+#[derive(Debug)]
+pub enum DataFusionError {
+    /// An error reported by DataFusion.
+    ExecutionError(InnerDataFusionError),
+    /// An error reported by Arrow.
+    ArrowError(ArrowError),
+    /// A plain message written by the bindings themselves.
+    Common(String),
+}
+
+impl fmt::Display for DataFusionError {
+    /// Prefixes wrapped errors with their origin; plain messages are written verbatim.
+    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
+        match self {
+            Self::ExecutionError(inner) => write!(f, "DataFusion error: {:?}", inner),
+            Self::ArrowError(inner) => write!(f, "Arrow error: {:?}", inner),
+            Self::Common(message) => write!(f, "{}", message),
+        }
+    }
+}
+
+impl From<DataFusionError> for PyErr {
+ fn from(err: DataFusionError) -> PyErr {
+ exceptions::PyException::new_err(err.to_string())
+ }
+}
+
+impl From<InnerDataFusionError> for DataFusionError {
+ fn from(err: InnerDataFusionError) -> DataFusionError {
+ DataFusionError::ExecutionError(err)
+ }
+}
+
+impl From<ArrowError> for DataFusionError {
+ fn from(err: ArrowError) -> DataFusionError {
+ DataFusionError::ArrowError(err)
+ }
+}
+
+pub(crate) fn wrap<T>(a: Result<T, InnerDataFusionError>) -> Result<T, DataFusionError> {
+ Ok(a?)
+}
diff --git a/python/src/expression.rs b/python/src/expression.rs
new file mode 100644
index 0000000..78ca6d7
--- /dev/null
+++ b/python/src/expression.rs
@@ -0,0 +1,162 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::{
+ basic::CompareOp, prelude::*, types::PyTuple, PyNumberProtocol, PyObjectProtocol,
+};
+
+use datafusion::logical_plan::Expr as _Expr;
+use datafusion::physical_plan::udaf::AggregateUDF as _AggregateUDF;
+use datafusion::physical_plan::udf::ScalarUDF as _ScalarUDF;
+
+/// An expression that can be used on a DataFrame
+#[pyclass]
+#[derive(Debug, Clone)]
+pub(crate) struct Expression {
+    /// The wrapped DataFusion logical expression.
+    pub(crate) expr: _Expr,
+}
+
+/// Extracts every item of `value` as an `Expression`, failing on the first item
+/// that is not one.
+pub(crate) fn from_tuple(value: &PyTuple) -> PyResult<Vec<Expression>> {
+    let mut expressions = Vec::with_capacity(value.len());
+    for item in value.iter() {
+        expressions.push(item.extract::<Expression>()?);
+    }
+    Ok(expressions)
+}
+
+/// Arithmetic and boolean operators; each one builds a new expression.
+#[pyproto]
+impl PyNumberProtocol for Expression {
+    fn __add__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr + rhs.expr;
+        Ok(Expression { expr })
+    }
+
+    fn __sub__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr - rhs.expr;
+        Ok(Expression { expr })
+    }
+
+    fn __truediv__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr / rhs.expr;
+        Ok(Expression { expr })
+    }
+
+    fn __mul__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr * rhs.expr;
+        Ok(Expression { expr })
+    }
+
+    fn __and__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr.and(rhs.expr);
+        Ok(Expression { expr })
+    }
+
+    fn __or__(lhs: Expression, rhs: Expression) -> PyResult<Expression> {
+        let expr = lhs.expr.or(rhs.expr);
+        Ok(Expression { expr })
+    }
+
+    fn __invert__(&self) -> PyResult<Expression> {
+        let expr = self.expr.clone().not();
+        Ok(Expression { expr })
+    }
+}
+
+/// Comparison operators; they return a comparison expression, not a boolean.
+#[pyproto]
+impl PyObjectProtocol for Expression {
+    fn __richcmp__(&self, other: Expression, op: CompareOp) -> Expression {
+        let lhs = self.expr.clone();
+        let rhs = other.expr;
+        let expr = match op {
+            CompareOp::Lt => lhs.lt(rhs),
+            CompareOp::Le => lhs.lt_eq(rhs),
+            CompareOp::Eq => lhs.eq(rhs),
+            CompareOp::Ne => lhs.not_eq(rhs),
+            CompareOp::Gt => lhs.gt(rhs),
+            CompareOp::Ge => lhs.gt_eq(rhs),
+        };
+        Expression { expr }
+    }
+}
+
+#[pymethods]
+impl Expression {
+    /// Returns a copy of this expression named `name`.
+    pub fn alias(&self, name: &str) -> PyResult<Expression> {
+        let expr = self.expr.clone().alias(name);
+        Ok(Expression { expr })
+    }
+}
+
+/// Python handle to a scalar user-defined function.
+#[pyclass]
+#[derive(Debug, Clone)]
+pub struct ScalarUDF {
+    /// The wrapped DataFusion scalar UDF.
+    pub(crate) function: _ScalarUDF,
+}
+
+#[pymethods]
+impl ScalarUDF {
+    /// Builds the expression that calls this UDF with `args`, which must all be expressions.
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+        let args: Vec<_Expr> = from_tuple(args)?.into_iter().map(|e| e.expr).collect();
+        let expr = self.function.call(args);
+
+        Ok(Expression { expr })
+    }
+}
+
+/// Python handle to a user-defined aggregate function.
+#[pyclass]
+#[derive(Debug, Clone)]
+pub struct AggregateUDF {
+    /// The wrapped DataFusion aggregate UDF.
+    pub(crate) function: _AggregateUDF,
+}
+
+#[pymethods]
+impl AggregateUDF {
+    /// Builds the expression that calls this UDAF with `args`, which must all be expressions.
+    #[call]
+    #[args(args = "*")]
+    fn __call__(&self, args: &PyTuple) -> PyResult<Expression> {
+        let args: Vec<_Expr> = from_tuple(args)?.into_iter().map(|e| e.expr).collect();
+        let expr = self.function.call(args);
+
+        Ok(Expression { expr })
+    }
+}
diff --git a/python/src/functions.rs b/python/src/functions.rs
new file mode 100644
index 0000000..68000cb
--- /dev/null
+++ b/python/src/functions.rs
@@ -0,0 +1,165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::{prelude::*, wrap_pyfunction};
+
+use datafusion::logical_plan;
+
+use crate::udaf;
+use crate::udf;
+use crate::{expression, types::PyDataType};
+
+/// Expression representing a column on the existing plan.
+#[pyfunction]
+#[text_signature = "(name)"]
+fn col(name: &str) -> expression::Expression {
+    let expr = logical_plan::col(name);
+    expression::Expression { expr }
+}
+
+/// Expression representing a constant value.
+///
+/// Accepts a Python `bool`, `int`, `float` or `str`. Integers that fit in 32
+/// bits keep producing a 32-bit literal (the previous behavior); wider ones
+/// fall back to 64-bit signed, then 64-bit unsigned.
+///
+/// Raises `TypeError` for any other Python type.
+#[pyfunction]
+#[text_signature = "(value)"]
+fn lit(value: &PyAny) -> PyResult<expression::Expression> {
+    // `bool` is tested first: Python booleans are also valid integers, and
+    // would otherwise silently become the integer literals 0 / 1.
+    let expr = if let Ok(v) = value.extract::<bool>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<i32>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<i64>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<u64>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<f64>() {
+        logical_plan::lit(v)
+    } else if let Ok(v) = value.extract::<String>() {
+        logical_plan::lit(v)
+    } else {
+        return Err(pyo3::exceptions::PyTypeError::new_err(
+            "lit() only supports bool, int, float and str values",
+        ));
+    };
+
+    Ok(expression::Expression { expr })
+}
+
+/// Expression applying `logical_plan::sum` to `value`.
+#[pyfunction]
+fn sum(value: expression::Expression) -> expression::Expression {
+    let expr = logical_plan::sum(value.expr);
+    expression::Expression { expr }
+}
+
+/// Expression applying `logical_plan::avg` to `value`.
+#[pyfunction]
+fn avg(value: expression::Expression) -> expression::Expression {
+    let expr = logical_plan::avg(value.expr);
+    expression::Expression { expr }
+}
+
+/// Expression applying `logical_plan::min` to `value`.
+#[pyfunction]
+fn min(value: expression::Expression) -> expression::Expression {
+    let expr = logical_plan::min(value.expr);
+    expression::Expression { expr }
+}
+
+/// Expression applying `logical_plan::max` to `value`.
+#[pyfunction]
+fn max(value: expression::Expression) -> expression::Expression {
+    let expr = logical_plan::max(value.expr);
+    expression::Expression { expr }
+}
+
+/// Expression applying `logical_plan::count` to `value`.
+#[pyfunction]
+fn count(value: expression::Expression) -> expression::Expression {
+    let expr = logical_plan::count(value.expr);
+    expression::Expression { expr }
+}
+
+/*
+#[pyfunction]
+fn concat(value: Vec<expression::Expression>) -> expression::Expression {
+ expression::Expression {
+ expr: logical_plan::concat(value.into_iter().map(|e| e.expr)),
+ }
+}
+ */
+
+/// Builds a `ScalarUDF` called `name` whose implementation is the Python
+/// callable `fun` (wrapped by `udf::array_udf`), with the given argument and
+/// return types.
+pub(crate) fn create_udf(
+    fun: PyObject,
+    input_types: Vec<PyDataType>,
+    return_type: PyDataType,
+    name: &str,
+) -> expression::ScalarUDF {
+    let arg_types: Vec<DataType> = input_types
+        .into_iter()
+        .map(|wrapped| wrapped.data_type)
+        .collect();
+
+    let function = logical_plan::create_udf(
+        name,
+        arg_types,
+        Arc::new(return_type.data_type),
+        udf::array_udf(fun),
+    );
+
+    expression::ScalarUDF { function }
+}
+
+/// Creates a new udf from a Python callable.
+///
+/// The udf is named after the callable's `__qualname__`.
+#[pyfunction]
+fn udf(
+    fun: PyObject,
+    input_types: Vec<PyDataType>,
+    return_type: PyDataType,
+    py: Python,
+) -> PyResult<expression::ScalarUDF> {
+    let qualname = fun.getattr(py, "__qualname__")?;
+    let name: String = qualname.extract(py)?;
+
+    Ok(create_udf(fun, input_types, return_type, &name))
+}
+
+/// Creates a new udaf from a Python accumulator factory.
+///
+/// The udaf is named after the `__qualname__` of `accumulator`; the
+/// accumulator object is handed to `udaf::array_udaf`.
+#[pyfunction]
+fn udaf(
+    accumulator: PyObject,
+    input_type: PyDataType,
+    return_type: PyDataType,
+    state_type: Vec<PyDataType>,
+    py: Python,
+) -> PyResult<expression::AggregateUDF> {
+    let qualname = accumulator.getattr(py, "__qualname__")?;
+    let name: String = qualname.extract(py)?;
+
+    let state_types: Vec<DataType> = state_type
+        .into_iter()
+        .map(|wrapped| wrapped.data_type)
+        .collect();
+
+    let function = logical_plan::create_udaf(
+        &name,
+        input_type.data_type,
+        Arc::new(return_type.data_type),
+        udaf::array_udaf(accumulator),
+        Arc::new(state_types),
+    );
+
+    Ok(expression::AggregateUDF { function })
+}
+
+/// Registers the functions defined in this file on the given Python module.
+///
+/// Returns an error as soon as wrapping or adding any function fails.
+pub fn init(module: &PyModule) -> PyResult<()> {
+ module.add_function(wrap_pyfunction!(col, module)?)?;
+ module.add_function(wrap_pyfunction!(lit, module)?)?;
+ // `concat` is intentionally not registered yet,
+ // see https://github.com/apache/arrow-datafusion/issues/226
+ //module.add_function(wrap_pyfunction!(concat, module)?)?;
+ module.add_function(wrap_pyfunction!(udf, module)?)?;
+ module.add_function(wrap_pyfunction!(sum, module)?)?;
+ module.add_function(wrap_pyfunction!(count, module)?)?;
+ module.add_function(wrap_pyfunction!(min, module)?)?;
+ module.add_function(wrap_pyfunction!(max, module)?)?;
+ module.add_function(wrap_pyfunction!(avg, module)?)?;
+ module.add_function(wrap_pyfunction!(udaf, module)?)?;
+ Ok(())
+}
diff --git a/python/src/lib.rs b/python/src/lib.rs
new file mode 100644
index 0000000..aecfe99
--- /dev/null
+++ b/python/src/lib.rs
@@ -0,0 +1,44 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::prelude::*;
+
+mod context;
+mod dataframe;
+mod errors;
+mod expression;
+mod functions;
+mod scalar;
+mod to_py;
+mod to_rust;
+mod types;
+mod udaf;
+mod udf;
+
+/// DataFusion.
+#[pymodule]
+fn datafusion(py: Python, m: &PyModule) -> PyResult<()> {
+    // classes exposed at the top level of the module
+    m.add_class::<context::ExecutionContext>()?;
+    m.add_class::<dataframe::DataFrame>()?;
+    m.add_class::<expression::Expression>()?;
+
+    // the `functions` submodule is populated by `functions::init`
+    let submodule = PyModule::new(py, "functions")?;
+    functions::init(submodule)?;
+    m.add_submodule(submodule)?;
+
+    Ok(())
+}
diff --git a/python/src/scalar.rs b/python/src/scalar.rs
new file mode 100644
index 0000000..0c562a9
--- /dev/null
+++ b/python/src/scalar.rs
@@ -0,0 +1,36 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::prelude::*;
+
+use datafusion::scalar::ScalarValue as _Scalar;
+
+use crate::to_rust::to_rust_scalar;
+
+/// Wrapper around a DataFusion `ScalarValue` that can be extracted from a
+/// Python object (see the `FromPyObject` implementation in this file).
+#[derive(Debug, Clone)]
+pub(crate) struct Scalar {
+ // the wrapped DataFusion scalar value
+ pub(crate) scalar: _Scalar,
+}
+
+impl<'source> FromPyObject<'source> for Scalar {
+    /// Converts a Python object into a `Scalar` by delegating to
+    /// `to_rust_scalar`; any conversion error is returned unchanged.
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        to_rust_scalar(ob).map(|scalar| Self { scalar })
+    }
+}
diff --git a/python/src/to_py.rs b/python/src/to_py.rs
new file mode 100644
index 0000000..deeb971
--- /dev/null
+++ b/python/src/to_py.rs
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::prelude::*;
+use pyo3::{libc::uintptr_t, PyErr};
+
+use std::convert::From;
+
+use datafusion::arrow::array::ArrayRef;
+use datafusion::arrow::record_batch::RecordBatch;
+
+use crate::errors;
+
+/// Converts a Rust arrow array into a `pyarrow.Array`.
+///
+/// The array is exported to a pair of raw pointers, which are handed to
+/// pyarrow's `Array._import_from_c` as integers.
+pub fn to_py_array(array: &ArrayRef, py: Python) -> PyResult<PyObject> {
+    let (array_ptr, schema_ptr) =
+        array.to_raw().map_err(errors::DataFusionError::from)?;
+    let pointers = (array_ptr as uintptr_t, schema_ptr as uintptr_t);
+
+    let array_class = py.import("pyarrow")?.getattr("Array")?;
+    let py_array = array_class.call_method1("_import_from_c", pointers)?;
+
+    Ok(py_array.to_object(py))
+}
+
+/// Converts a single `RecordBatch` into a `pyarrow.RecordBatch`.
+///
+/// Every column is converted with `to_py_array` and paired with the name of
+/// the schema field at the same position; the result is built through
+/// `pyarrow.RecordBatch.from_arrays`.
+fn to_py_batch<'a>(
+    batch: &RecordBatch,
+    py: Python,
+    pyarrow: &'a PyModule,
+) -> Result<PyObject, PyErr> {
+    let schema = batch.schema();
+
+    let converted = batch
+        .columns()
+        .iter()
+        .zip(schema.fields().iter())
+        .map(|(column, field)| Ok((to_py_array(column, py)?, field.name())))
+        .collect::<PyResult<Vec<_>>>()?;
+    let (py_arrays, py_names): (Vec<_>, Vec<_>) = converted.into_iter().unzip();
+
+    let record = pyarrow
+        .getattr("RecordBatch")?
+        .call_method1("from_arrays", (py_arrays, py_names))?;
+
+    Ok(record.into())
+}
+
+/// Converts a `&[RecordBatch]` into a Python `list` of `pyarrow.RecordBatch`.
+///
+/// Acquires the GIL for the duration of the conversion.
+pub fn to_py(batches: &[RecordBatch]) -> PyResult<PyObject> {
+    let gil = pyo3::Python::acquire_gil();
+    let py = gil.python();
+    let pyarrow = PyModule::import(py, "pyarrow")?;
+    let builtins = PyModule::import(py, "builtins")?;
+
+    let py_batches = batches
+        .iter()
+        .map(|batch| to_py_batch(batch, py, pyarrow))
+        .collect::<PyResult<Vec<_>>>()?;
+
+    let list = builtins.call1("list", (py_batches,))?;
+    Ok(list.into())
+}
diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs
new file mode 100644
index 0000000..d8f2307
--- /dev/null
+++ b/python/src/to_rust.rs
@@ -0,0 +1,111 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use datafusion::arrow::{
+ array::{make_array_from_raw, ArrayRef},
+ datatypes::Field,
+ datatypes::Schema,
+ ffi,
+ record_batch::RecordBatch,
+};
+use datafusion::scalar::ScalarValue;
+use pyo3::{libc::uintptr_t, prelude::*};
+
+use crate::{errors, types::PyDataType};
+
+/// converts a pyarrow Array into a Rust Array
+///
+/// `ob` must expose pyarrow's private `_export_to_c(array_ptr, schema_ptr)`
+/// method. Errors raised by that call, or by `make_array_from_raw`, are
+/// returned to the caller.
+pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
+ // prepare a pointer to receive the Array struct
+ let (array_pointer, schema_pointer) =
+ ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() });
+
+ // make the conversion through PyArrow's private API
+ // this changes the pointer's memory and is thus unsafe. In particular, `_export_to_c` can go out of bounds
+ // NOTE(review): if this call fails, `?` returns before the raw pointers are
+ // passed to `make_array_from_raw`; they do not appear to be reclaimed on
+ // that path - confirm whether this leaks.
+ ob.call_method1(
+ "_export_to_c",
+ (array_pointer as uintptr_t, schema_pointer as uintptr_t),
+ )?;
+
+ // rebuild a Rust array from the pointers that pyarrow has just filled in
+ let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
+ .map_err(errors::DataFusionError::from)?;
+ Ok(array)
+}
+
+/// Converts a `pyarrow.RecordBatch` into a Rust `RecordBatch`.
+///
+/// The schema is rebuilt field by field (name, type, nullability) from the
+/// Python schema, then every column is converted with `to_rust`.
+pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
+    let py_schema = batch.getattr("schema")?;
+    let names = py_schema.getattr("names")?.extract::<Vec<String>>()?;
+
+    let mut fields = Vec::with_capacity(names.len());
+    for (index, name) in names.iter().enumerate() {
+        let py_field = py_schema.call_method1("field", (index,))?;
+        let nullable = py_field.getattr("nullable")?.extract::<bool>()?;
+        let data_type = py_field
+            .getattr("type")?
+            .extract::<PyDataType>()?
+            .data_type;
+        fields.push(Field::new(name, data_type, nullable));
+    }
+    let schema = Arc::new(Schema::new(fields));
+
+    let mut arrays = Vec::with_capacity(names.len());
+    for index in 0..names.len() {
+        let py_column = batch.call_method1("column", (index,))?;
+        arrays.push(to_rust(py_column)?);
+    }
+
+    Ok(RecordBatch::try_new(schema, arrays).map_err(errors::DataFusionError::from)?)
+}
+
+/// converts a pyarrow Scalar into a Rust Scalar
+///
+/// The target variant is chosen from the Python class name of `ob`
+/// (e.g. `Int8Scalar`), and the value is read through `ob.as_py()`.
+/// A null scalar (for which `as_py()` yields `None`) becomes the
+/// corresponding `ScalarValue` variant holding `None`.
+///
+/// Returns an error when the class name is not one of the supported scalar
+/// types, or when the Python value cannot be extracted as the expected type.
+pub fn to_rust_scalar(ob: &PyAny) -> PyResult<ScalarValue> {
+    let t = ob
+        .getattr("__class__")?
+        .getattr("__name__")?
+        .extract::<&str>()?;
+
+    let p = ob.call_method0("as_py")?;
+
+    // Extracting into `Option<_>` maps Python's `None` to a null scalar
+    // instead of failing with an unrelated type error.
+    Ok(match t {
+        "Int8Scalar" => ScalarValue::Int8(p.extract::<Option<i8>>()?),
+        "Int16Scalar" => ScalarValue::Int16(p.extract::<Option<i16>>()?),
+        "Int32Scalar" => ScalarValue::Int32(p.extract::<Option<i32>>()?),
+        "Int64Scalar" => ScalarValue::Int64(p.extract::<Option<i64>>()?),
+        "UInt8Scalar" => ScalarValue::UInt8(p.extract::<Option<u8>>()?),
+        "UInt16Scalar" => ScalarValue::UInt16(p.extract::<Option<u16>>()?),
+        "UInt32Scalar" => ScalarValue::UInt32(p.extract::<Option<u32>>()?),
+        "UInt64Scalar" => ScalarValue::UInt64(p.extract::<Option<u64>>()?),
+        "FloatScalar" => ScalarValue::Float32(p.extract::<Option<f32>>()?),
+        "DoubleScalar" => ScalarValue::Float64(p.extract::<Option<f64>>()?),
+        "BooleanScalar" => ScalarValue::Boolean(p.extract::<Option<bool>>()?),
+        "StringScalar" => ScalarValue::Utf8(p.extract::<Option<String>>()?),
+        "LargeStringScalar" => {
+            ScalarValue::LargeUtf8(p.extract::<Option<String>>()?)
+        }
+        other => {
+            return Err(errors::DataFusionError::Common(format!(
+                "Type \"{}\" not yet implemented",
+                other
+            ))
+            .into())
+        }
+    })
+}
diff --git a/python/src/types.rs b/python/src/types.rs
new file mode 100644
index 0000000..ffa822e
--- /dev/null
+++ b/python/src/types.rs
@@ -0,0 +1,76 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::arrow::datatypes::DataType;
+use pyo3::{FromPyObject, PyAny, PyResult};
+
+use crate::errors;
+
+/// Utility struct used to convert a Python object into a native arrow
+/// `DataType` (see the `FromPyObject` implementation in this file).
+#[derive(Debug, Clone)]
+pub struct PyDataType {
+ // the native arrow data type that was extracted
+ pub data_type: DataType,
+}
+
+impl<'source> FromPyObject<'source> for PyDataType {
+    /// Reads the integer `id` attribute of `ob` and maps it to a native
+    /// `DataType` through `data_type_id`.
+    fn extract(ob: &'source PyAny) -> PyResult<Self> {
+        let type_id: i32 = ob.getattr("id")?.extract()?;
+        Ok(PyDataType {
+            data_type: data_type_id(&type_id)?,
+        })
+    }
+}
+
+/// Maps the numeric `id` of a `pyarrow.DataType` to a native arrow `DataType`.
+///
+/// Returns `DataFusionError::Common` for ids that are not supported.
+fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
+    // The ids are the values of the C++ `arrow::Type::type` enum, which is
+    // what `pyarrow.DataType.id` exposes. The declaration order in
+    // libarrow.pxd does NOT reflect those values.
+    // this is not ideal as it does not generalize for non-basic types
+    // Find a way to get a unique name from the pyarrow.DataType
+    Ok(match id {
+        1 => DataType::Boolean,
+        2 => DataType::UInt8,
+        3 => DataType::Int8,
+        4 => DataType::UInt16,
+        5 => DataType::Int16,
+        6 => DataType::UInt32,
+        7 => DataType::Int32,
+        8 => DataType::UInt64,
+        9 => DataType::Int64,
+
+        10 => DataType::Float16,
+        11 => DataType::Float32,
+        12 => DataType::Float64,
+
+        13 => DataType::Utf8,
+        14 => DataType::Binary,
+
+        // not supported yet:
+        // 15 => fixed-size binary
+        // 16 => DataType::Date32
+        // 17 => DataType::Date64
+        // 18 => DataType::Timestamp(..)
+        // 19 => DataType::Time32(..)
+        // 20 => DataType::Time64(..)
+        // 21, 22 => intervals
+        // 23, 24 => decimals
+        // 33 => DataType::Duration(..)
+
+        // NOTE(review): the ids of the large types assume the enum layout of
+        // Arrow >= 3.0 - confirm against the pinned pyarrow version.
+        34 => DataType::LargeUtf8,
+        35 => DataType::LargeBinary,
+
+        other => {
+            return Err(errors::DataFusionError::Common(format!(
+                "The type {} is not valid",
+                other
+            )))
+        }
+    })
+}
diff --git a/python/src/udaf.rs b/python/src/udaf.rs
new file mode 100644
index 0000000..3ce223d
--- /dev/null
+++ b/python/src/udaf.rs
@@ -0,0 +1,147 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use pyo3::{prelude::*, types::PyTuple};
+
+use datafusion::arrow::array::ArrayRef;
+
+use datafusion::error::Result;
+use datafusion::{
+ error::DataFusionError as InnerDataFusionError, physical_plan::Accumulator,
+ scalar::ScalarValue,
+};
+
+use crate::scalar::Scalar;
+use crate::to_py::to_py_array;
+use crate::to_rust::to_rust_scalar;
+
+/// Holds a Python accumulator object; the `Accumulator` implementation in
+/// this file forwards its calls to that object.
+#[derive(Debug)]
+struct PyAccumulator {
+    // the Python accumulator instance
+    accum: PyObject,
+}
+
+impl PyAccumulator {
+    /// Wraps the given Python accumulator instance.
+    fn new(accum: PyObject) -> Self {
+        PyAccumulator { accum }
+    }
+}
+
+impl Accumulator for PyAccumulator {
+    /// Returns the intermediate state: calls the Python object's
+    /// `to_scalars()` and converts every item into a `ScalarValue`.
+    fn state(&self) -> Result<Vec<datafusion::scalar::ScalarValue>> {
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        let state = self
+            .accum
+            .as_ref(py)
+            .call_method0("to_scalars")
+            .map_err(py_err_to_datafusion)?
+            .extract::<Vec<Scalar>>()
+            .map_err(py_err_to_datafusion)?;
+
+        Ok(state.into_iter().map(|v| v.scalar).collect::<Vec<_>>())
+    }
+
+    /// Row-wise update is not supported; only `update_batch` is implemented.
+    fn update(&mut self, _values: &[ScalarValue]) -> Result<()> {
+        // datafusion is not expected to call this (it uses `update_batch`);
+        // report an error rather than panicking if it ever does.
+        Err(InnerDataFusionError::NotImplemented(
+            "PyAccumulator only supports update_batch".to_string(),
+        ))
+    }
+
+    /// Row-wise merge is not supported; only `merge_batch` is implemented.
+    fn merge(&mut self, _states: &[ScalarValue]) -> Result<()> {
+        // datafusion is not expected to call this (it uses `merge_batch`);
+        // report an error rather than panicking if it ever does.
+        Err(InnerDataFusionError::NotImplemented(
+            "PyAccumulator only supports merge_batch".to_string(),
+        ))
+    }
+
+    /// Returns the final value: calls the Python object's `evaluate()` and
+    /// converts the result with `to_rust_scalar`.
+    fn evaluate(&self) -> Result<datafusion::scalar::ScalarValue> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        let value = self
+            .accum
+            .as_ref(py)
+            .call_method0("evaluate")
+            .map_err(py_err_to_datafusion)?;
+
+        to_rust_scalar(value).map_err(py_err_to_datafusion)
+    }
+
+    /// Converts every input array to a pyarrow array and passes them as
+    /// positional arguments to the Python object's `update(...)`.
+    fn update_batch(&mut self, values: &[ArrayRef]) -> Result<()> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        // 1. cast args to Pyarrow arrays; a failed conversion is returned as
+        //    an error instead of panicking
+        let py_args = values
+            .iter()
+            .map(|arg| to_py_array(arg, py))
+            .collect::<PyResult<Vec<_>>>()
+            .map_err(py_err_to_datafusion)?;
+        let py_args = PyTuple::new(py, py_args);
+
+        // 2. update accumulator
+        self.accum
+            .as_ref(py)
+            .call_method1("update", py_args)
+            .map_err(py_err_to_datafusion)?;
+
+        Ok(())
+    }
+
+    /// Converts the first state array to a pyarrow array and passes it to the
+    /// Python object's `merge(state)`.
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> Result<()> {
+        // get GIL
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        // 1. cast the first state to a Pyarrow array (only the first state
+        //    array is forwarded); an empty slice is an error, not a panic
+        let state = states.get(0).ok_or_else(|| {
+            InnerDataFusionError::Execution(
+                "merge_batch expects at least one state array".to_string(),
+            )
+        })?;
+
+        let state = to_py_array(state, py).map_err(py_err_to_datafusion)?;
+
+        // 2. merge
+        self.accum
+            .as_ref(py)
+            .call_method1("merge", (state,))
+            .map_err(py_err_to_datafusion)?;
+
+        Ok(())
+    }
+}
+
+/// Converts a Python error into a DataFusion execution error carrying the
+/// Python error's display text.
+fn py_err_to_datafusion(e: PyErr) -> InnerDataFusionError {
+    InnerDataFusionError::Execution(format!("{}", e))
+}
+
+/// Returns a factory that creates a new `PyAccumulator` on every call by
+/// invoking the Python callable `accumulator` with no arguments.
+///
+/// A Python error raised by that call is returned as an execution error.
+pub fn array_udaf(
+    accumulator: PyObject,
+) -> Arc<dyn Fn() -> Result<Box<dyn Accumulator>> + Send + Sync> {
+    let factory = move || -> Result<Box<dyn Accumulator>> {
+        let gil = pyo3::Python::acquire_gil();
+        let py = gil.python();
+
+        match accumulator.call0(py) {
+            Ok(instance) => {
+                Ok(Box::new(PyAccumulator::new(instance)) as Box<dyn Accumulator>)
+            }
+            Err(e) => Err(InnerDataFusionError::Execution(format!("{}", e))),
+        }
+    };
+
+    Arc::new(factory)
+}
diff --git a/python/src/udf.rs b/python/src/udf.rs
new file mode 100644
index 0000000..7fee710
--- /dev/null
+++ b/python/src/udf.rs
@@ -0,0 +1,62 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::{prelude::*, types::PyTuple};
+
+use datafusion::{arrow::array, physical_plan::functions::make_scalar_function};
+
+use datafusion::error::DataFusionError;
+use datafusion::physical_plan::functions::ScalarFunctionImplementation;
+
+use crate::to_py::to_py_array;
+use crate::to_rust::to_rust;
+
+/// creates a DataFusion's UDF implementation from a python function that expects pyarrow arrays
+/// This is more efficient as it performs a zero-copy of the contents.
+///
+/// Any failure (converting the arguments, calling `func`, or converting the
+/// returned value back) is reported as `DataFusionError::Execution`.
+pub fn array_udf(func: PyObject) -> ScalarFunctionImplementation {
+    make_scalar_function(
+        move |args: &[array::ArrayRef]| -> Result<array::ArrayRef, DataFusionError> {
+            // get GIL
+            let gil = pyo3::Python::acquire_gil();
+            let py = gil.python();
+
+            // 1. cast args to Pyarrow arrays
+            // 2. call function
+            // 3. cast to arrow::array::Array
+
+            // 1. a failed conversion becomes an error instead of a panic
+            let py_args = args
+                .iter()
+                .map(|arg| to_py_array(arg, py))
+                .collect::<PyResult<Vec<_>>>()
+                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+            let py_args = PyTuple::new(py, py_args);
+
+            // 2.
+            let value = func
+                .as_ref(py)
+                .call(py_args, None)
+                .map_err(|e| DataFusionError::Execution(format!("{:?}", e)))?;
+
+            // 3. the returned object may not be a pyarrow array; surface that
+            //    as an error as well
+            to_rust(value).map_err(|e| DataFusionError::Execution(format!("{:?}", e)))
+        },
+    )
+}
diff --git a/Cargo.toml b/python/tests/__init__.py
similarity index 77%
copy from Cargo.toml
copy to python/tests/__init__.py
index fa36a0c..13a8339 100644
--- a/Cargo.toml
+++ b/python/tests/__init__.py
@@ -14,15 +14,3 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-[workspace]
-members = [
- "datafusion",
- "datafusion-cli",
- "datafusion-examples",
- "benchmarks",
- "ballista/rust/client",
- "ballista/rust/core",
- "ballista/rust/executor",
- "ballista/rust/scheduler",
-]
\ No newline at end of file
diff --git a/python/tests/generic.py b/python/tests/generic.py
new file mode 100644
index 0000000..7362f0b
--- /dev/null
+++ b/python/tests/generic.py
@@ -0,0 +1,75 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import tempfile
+import datetime
+import os.path
+import shutil
+
+import numpy
+import pyarrow
+import datafusion
+
+# used to write parquet files
+import pyarrow.parquet
+
+
+def data():
+ data = numpy.concatenate(
+ [numpy.random.normal(0, 0.01, size=50), numpy.random.normal(50, 0.01, size=50)]
+ )
+ return pyarrow.array(data)
+
+
+def data_with_nans():
+ data = numpy.random.normal(0, 0.01, size=50)
+ mask = numpy.random.randint(0, 2, size=50)
+ data[mask == 0] = numpy.NaN
+ return data
+
+
+def data_datetime(f):
+ data = [
+ datetime.datetime.now(),
+ datetime.datetime.now() - datetime.timedelta(days=1),
+ datetime.datetime.now() + datetime.timedelta(days=1),
+ ]
+ return pyarrow.array(
+ data, type=pyarrow.timestamp(f), mask=numpy.array([False, True, False])
+ )
+
+
+def data_timedelta(f):
+ data = [
+ datetime.timedelta(days=100),
+ datetime.timedelta(days=1),
+ datetime.timedelta(seconds=1),
+ ]
+ return pyarrow.array(
+ data, type=pyarrow.duration(f), mask=numpy.array([False, True, False])
+ )
+
+
+def data_binary_other():
+ return numpy.array([1, 0, 0], dtype="u4")
+
+
+def write_parquet(path, data):
+ table = pyarrow.Table.from_arrays([data], names=["a"])
+ pyarrow.parquet.write_table(table, path)
+ return path
diff --git a/python/tests/test_df.py b/python/tests/test_df.py
new file mode 100644
index 0000000..520d4e6
--- /dev/null
+++ b/python/tests/test_df.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import pyarrow
+import datafusion
+f = datafusion.functions
+
+
+class TestCase(unittest.TestCase):
+
+ def _prepare(self):
+ ctx = datafusion.ExecutionContext()
+
+ # create a RecordBatch and a new DataFrame from it
+ batch = pyarrow.RecordBatch.from_arrays(
+ [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ return ctx.create_dataframe([[batch]])
+
+ def test_select(self):
+ df = self._prepare()
+
+ df = df.select(
+ f.col("a") + f.col("b"),
+ f.col("a") - f.col("b"),
+ )
+
+ # execute and collect the first (and only) batch
+ result = df.collect()[0]
+
+ self.assertEqual(result.column(0), pyarrow.array([5, 7, 9]))
+ self.assertEqual(result.column(1), pyarrow.array([-3, -3, -3]))
+
+ def test_filter(self):
+ df = self._prepare()
+
+ df = df \
+ .select(
+ f.col("a") + f.col("b"),
+ f.col("a") - f.col("b"),
+ ) \
+ .filter(f.col("a") > f.lit(2))
+
+ # execute and collect the first (and only) batch
+ result = df.collect()[0]
+
+ self.assertEqual(result.column(0), pyarrow.array([9]))
+ self.assertEqual(result.column(1), pyarrow.array([-3]))
+
+ def test_limit(self):
+ df = self._prepare()
+
+ df = df.limit(1)
+
+ # execute and collect the first (and only) batch
+ result = df.collect()[0]
+
+ self.assertEqual(len(result.column(0)), 1)
+ self.assertEqual(len(result.column(1)), 1)
+
+ def test_udf(self):
+ df = self._prepare()
+
+ # is_null is a pyarrow function over arrays
+ udf = f.udf(lambda x: x.is_null(), [pyarrow.int64()], pyarrow.bool_())
+
+ df = df.select(udf(f.col("a")))
+
+ self.assertEqual(df.collect()[0].column(0), pyarrow.array([False, False, False]))
+
+ def test_join(self):
+ ctx = datafusion.ExecutionContext()
+
+ batch = pyarrow.RecordBatch.from_arrays(
+ [pyarrow.array([1, 2, 3]), pyarrow.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ df = ctx.create_dataframe([[batch]])
+
+ batch = pyarrow.RecordBatch.from_arrays(
+ [pyarrow.array([1, 2]), pyarrow.array([8, 10])],
+ names=["a", "c"],
+ )
+ df1 = ctx.create_dataframe([[batch]])
+
+ df = df.join(df1, on="a", how="inner")
+
+ # execute and collect the first (and only) batch
+ batch = df.collect()[0]
+
+ if batch.column(0) == pyarrow.array([1, 2]):
+ self.assertEqual(batch.column(0), pyarrow.array([1, 2]))
+ self.assertEqual(batch.column(1), pyarrow.array([8, 10]))
+ self.assertEqual(batch.column(2), pyarrow.array([4, 5]))
+ else:
+ self.assertEqual(batch.column(0), pyarrow.array([2, 1]))
+ self.assertEqual(batch.column(1), pyarrow.array([10, 8]))
+ self.assertEqual(batch.column(2), pyarrow.array([5, 4]))
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
new file mode 100644
index 0000000..e9047ea
--- /dev/null
+++ b/python/tests/test_sql.py
@@ -0,0 +1,294 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+import tempfile
+import datetime
+import os.path
+import shutil
+
+import numpy
+import pyarrow
+import datafusion
+
+# used to write parquet files
+import pyarrow.parquet
+
+from tests.generic import *
+
+
+class TestCase(unittest.TestCase):
+ def setUp(self):
+     """Seed numpy's RNG and create a scratch directory for parquet files."""
+     numpy.random.seed(1)
+     # removed again in tearDown
+     self.test_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ """Delete the scratch directory created in setUp, including its contents."""
+ # Remove the directory after the test
+ shutil.rmtree(self.test_dir)
+
+ def test_no_table(self):
+     """Querying a table that was never registered must raise."""
+     # Build the context outside the `assertRaises` block, with the same
+     # constructor every other test uses.  The previous `datafusion.Context()`
+     # lookup sat inside the block, so a missing attribute would by itself
+     # have satisfied `assertRaises(Exception)` without running any SQL.
+     ctx = datafusion.ExecutionContext()
+
+     with self.assertRaises(Exception):
+         ctx.sql("SELECT a FROM b").collect()
+
+ def test_register(self):
+     """A registered parquet file is listed among the context's tables."""
+     parquet_file = os.path.join(self.test_dir, "a.parquet")
+     registered = write_parquet(parquet_file, data())
+
+     ctx = datafusion.ExecutionContext()
+     ctx.register_parquet("t", registered)
+
+     self.assertEqual(ctx.tables(), {"t"})
+
+ def test_execute(self):
+     """Run COUNT, WHERE, GROUP BY and ORDER BY/LIMIT queries over one column."""
+     # named `values` so the star-imported `data()` helper is not shadowed
+     values = [1, 1, 2, 2, 3, 11, 12]
+
+     ctx = datafusion.ExecutionContext()
+
+     # single column, "a"
+     path = write_parquet(
+         os.path.join(self.test_dir, "a.parquet"), pyarrow.array(values)
+     )
+     ctx.register_parquet("t", path)
+
+     self.assertEqual(ctx.tables(), {"t"})
+
+     # count
+     result = ctx.sql("SELECT COUNT(a) FROM t").collect()
+
+     expected = pyarrow.array([7], pyarrow.uint64())
+     expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+     self.assertEqual(expected, result)
+
+     # where
+     expected = pyarrow.array([2], pyarrow.uint64())
+     expected = [pyarrow.RecordBatch.from_arrays([expected], ["COUNT(a)"])]
+     self.assertEqual(
+         expected, ctx.sql("SELECT COUNT(a) FROM t WHERE a > 10").collect()
+     )
+
+     # group by
+     result = ctx.sql(
+         "SELECT CAST(a as int), COUNT(a) FROM t GROUP BY CAST(a as int)"
+     ).collect()
+
+     result_keys = result[0].to_pydict()["CAST(a AS Int32)"]
+     result_values = result[0].to_pydict()["COUNT(a)"]
+     # group order is not relied upon: sort the (key, count) pairs first
+     result_keys, result_values = (
+         list(t) for t in zip(*sorted(zip(result_keys, result_values)))
+     )
+
+     self.assertEqual(result_keys, [1, 2, 3, 11, 12])
+     self.assertEqual(result_values, [2, 2, 1, 1, 1])
+
+     # order by
+     result = ctx.sql(
+         "SELECT a, CAST(a AS int) FROM t ORDER BY a DESC LIMIT 2"
+     ).collect()
+
+     # The two largest entries of `values` are 12 and 11.  The previous check
+     # compared an expected column with itself (and listed values that are not
+     # in `values`), so it could never fail; compare the query output instead.
+     # NOTE(review): assumes both rows arrive in the first batch, like the
+     # other checks in this test - confirm.
+     ordered = result[0].to_pydict()
+     self.assertEqual(ordered["a"], [12, 11])
+     self.assertEqual(ordered["CAST(a AS Int32)"], [12, 11])
+
+ def test_cast(self):
+     """
+     Verify that a literal can be cast to each of the listed SQL types
+     """
+     ctx = datafusion.ExecutionContext()
+     ctx.register_parquet(
+         "t", write_parquet(os.path.join(self.test_dir, "a.parquet"), data())
+     )
+
+     valid_types = ("smallint", "int", "bigint", "float(32)", "float(64)", "float")
+
+     # one aliased CAST per type: A0, A1, ...
+     casts = []
+     for position, sql_type in enumerate(valid_types):
+         casts.append(f"CAST(9 AS {sql_type}) AS A{position}")
+     select = ", ".join(casts)
+
+     # running to completion without an error is the whole assertion
+     ctx.sql(f"SELECT {select} FROM t").collect()
+
+ def _test_udf(self, udf, args, return_type, array, expected):
+     """Register `udf`, apply it to `array` through SQL and compare to `expected`.
+
+     `array` is handed to `write_parquet` and the resulting file is registered
+     as table `t`; `args` and `return_type` are forwarded to `register_udf`
+     unchanged.
+     """
+     parquet_file = os.path.join(self.test_dir, "a.parquet")
+
+     ctx = datafusion.ExecutionContext()
+     ctx.register_parquet("t", write_parquet(parquet_file, array))
+     ctx.register_udf("udf", udf, args, return_type)
+
+     # only the first batch and its single projected column are checked
+     first_batch = ctx.sql("SELECT udf(a) AS tt FROM t").collect()[0]
+
+     self.assertEqual(expected, first_batch.column(0))
+
+ def test_udf_identity(self):
+     """A UDF that returns its input unchanged preserves values and nulls."""
+     values = [-1.2, None, 1.2]
+     source = pyarrow.array(values)
+     expected = pyarrow.array(values)
+
+     self._test_udf(
+         lambda x: x, [pyarrow.float64()], pyarrow.float64(), source, expected
+     )
+
+ def test_udf(self):
+     """A UDF calling pyarrow's `is_null` flags exactly the null slot."""
+     source = pyarrow.array([-1.2, None, 1.2])
+     null_flags = pyarrow.array([False, True, False])
+
+     self._test_udf(
+         lambda x: x.is_null(),
+         [pyarrow.float64()],
+         pyarrow.bool_(),
+         source,
+         null_flags,
+     )
+
+
+ class TestIO(unittest.TestCase):
+ """
+ Round-trip checks: every test hands its data to `_test_data`, which passes
+ it to `write_parquet`, registers the file as table `t`, reads column `a`
+ back through SQL and compares the result with the input.
+ """
+
+ def setUp(self):
+ # Create a temporary directory
+ self.test_dir = tempfile.mkdtemp()
+
+ def tearDown(self):
+ # Remove the directory after the test
+ shutil.rmtree(self.test_dir)
+
+ def _test_data(self, data):
+ """Write `data`, query it back as column `a` and assert it is unchanged."""
+ ctx = datafusion.ExecutionContext()
+
+ # write to disk
+ path = write_parquet(os.path.join(self.test_dir, "a.parquet"), data)
+ ctx.register_parquet("t", path)
+
+ batches = ctx.sql("SELECT a AS tt FROM t").collect()
+
+ # only the first batch and its single projected column are checked
+ result = batches[0].column(0)
+
+ numpy.testing.assert_equal(data, result)
+
+ def test_nans(self):
+ self._test_data(data_with_nans())
+
+ def test_utf8(self):
+ # the third positional argument of `pyarrow.array` is the null mask
+ # (True marks a null slot); the tests below use the same layout
+ array = pyarrow.array(
+ ["a", "b", "c"], pyarrow.utf8(), numpy.array([False, True, False])
+ )
+ self._test_data(array)
+
+ def test_large_utf8(self):
+ array = pyarrow.array(
+ ["a", "b", "c"], pyarrow.large_utf8(), numpy.array([False, True, False])
+ )
+ self._test_data(array)
+
+ # Error from Arrow
+ @unittest.expectedFailure
+ def test_datetime_s(self):
+ self._test_data(data_datetime("s"))
+
+ # C data interface missing
+ @unittest.expectedFailure
+ def test_datetime_ms(self):
+ self._test_data(data_datetime("ms"))
+
+ # C data interface missing
+ @unittest.expectedFailure
+ def test_datetime_us(self):
+ self._test_data(data_datetime("us"))
+
+ # Not writable to parquet
+ @unittest.expectedFailure
+ def test_datetime_ns(self):
+ self._test_data(data_datetime("ns"))
+
+ # Not writable to parquet
+ @unittest.expectedFailure
+ def test_timedelta_s(self):
+ self._test_data(data_timedelta("s"))
+
+ # Not writable to parquet
+ @unittest.expectedFailure
+ def test_timedelta_ms(self):
+ self._test_data(data_timedelta("ms"))
+
+ # Not writable to parquet
+ @unittest.expectedFailure
+ def test_timedelta_us(self):
+ self._test_data(data_timedelta("us"))
+
+ # Not writable to parquet
+ @unittest.expectedFailure
+ def test_timedelta_ns(self):
+ self._test_data(data_timedelta("ns"))
+
+ def test_date32(self):
+ array = pyarrow.array(
+ [
+ datetime.date(2000, 1, 1),
+ datetime.date(1980, 1, 1),
+ datetime.date(2030, 1, 1),
+ ],
+ pyarrow.date32(),
+ numpy.array([False, True, False]),
+ )
+ self._test_data(array)
+
+ def test_binary_variable(self):
+ array = pyarrow.array(
+ [b"1", b"2", b"3"], pyarrow.binary(), numpy.array([False, True, False])
+ )
+ self._test_data(array)
+
+ # C data interface missing
+ @unittest.expectedFailure
+ def test_binary_fixed(self):
+ array = pyarrow.array(
+ [b"1111", b"2222", b"3333"],
+ pyarrow.binary(4),
+ numpy.array([False, True, False]),
+ )
+ self._test_data(array)
+
+ def test_large_binary(self):
+ array = pyarrow.array(
+ [b"1111", b"2222", b"3333"],
+ pyarrow.large_binary(),
+ numpy.array([False, True, False]),
+ )
+ self._test_data(array)
+
+ def test_binary_other(self):
+ self._test_data(data_binary_other())
+
+ def test_bool(self):
+ # type is left as None, so pyarrow infers it from the Python values
+ array = pyarrow.array(
+ [False, True, True], None, numpy.array([False, True, False])
+ )
+ self._test_data(array)
+
+ def test_u32(self):
+ # NOTE(review): the type argument is None, so the array type is inferred
+ # from plain Python ints rather than set to uint32 as the test name
+ # suggests - confirm whether `pyarrow.uint32()` was intended.
+ array = pyarrow.array([0, 1, 2], None, numpy.array([False, True, False]))
+ self._test_data(array)
diff --git a/python/tests/test_udaf.py b/python/tests/test_udaf.py
new file mode 100644
index 0000000..ffd235e
--- /dev/null
+++ b/python/tests/test_udaf.py
@@ -0,0 +1,91 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import unittest
+
+import pyarrow
+import pyarrow.compute
+import datafusion
+
+ # shorthand for the functions module; the tests below use `f.col` and `f.udaf`
+ f = datafusion.functions
+
+
+ class Accumulator:
+     """
+     Interface of a user-defined accumulation.
+
+     This implementation keeps a running sum as a float pyarrow scalar,
+     starting at 0.0.
+     """
+
+     def __init__(self):
+         self._sum = pyarrow.scalar(0.0)
+
+     def to_scalars(self) -> [pyarrow.Scalar]:
+         """Return the accumulator state: a one-element list with the running sum."""
+         return [self._sum]
+
+     def update(self, values: pyarrow.Array) -> None:
+         """Add the sum of `values` to the running sum."""
+         self._add(values)
+
+     def merge(self, states: pyarrow.Array) -> None:
+         """Add the sum of the partial `states` to the running sum."""
+         self._add(states)
+
+     def evaluate(self) -> pyarrow.Scalar:
+         """Return the running sum."""
+         return self._sum
+
+     def _add(self, array: pyarrow.Array) -> None:
+         """Fold the sum of `array` into the running sum, ignoring a null sum."""
+         # pyarrow scalars can't be summed directly yet, so go through Python
+         # values.  `pyarrow.compute.sum` can return a null scalar (`None`
+         # after `as_py()`), e.g. when there are no valid values to add; skip
+         # it instead of failing on `float + None`.
+         partial = pyarrow.compute.sum(array).as_py()
+         if partial is not None:
+             self._sum = pyarrow.scalar(self._sum.as_py() + partial)
+
+
+ class TestCase(unittest.TestCase):
+     """Runs the `Accumulator` UDAF with and without grouping keys."""
+
+     def _prepare(self):
+         """Build a DataFrame with columns `a` and `b` from one in-memory batch."""
+         ctx = datafusion.ExecutionContext()
+
+         columns = [pyarrow.array([1, 2, 3]), pyarrow.array([4, 4, 6])]
+         batch = pyarrow.RecordBatch.from_arrays(columns, names=["a", "b"])
+
+         return ctx.create_dataframe([[batch]])
+
+     def test_aggregate(self):
+         """Without grouping keys the UDAF sums all of column `a`."""
+         frame = self._prepare()
+
+         udaf = f.udaf(
+             Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
+         )
+
+         # a single batch is expected, so only the first one is inspected
+         batch = frame.aggregate([], [udaf(f.col("a"))]).collect()[0]
+
+         self.assertEqual(batch.column(0), pyarrow.array([1.0 + 2.0 + 3.0]))
+
+     def test_group_by(self):
+         """Grouping by `b` sums `a` within each group."""
+         frame = self._prepare()
+
+         udaf = f.udaf(
+             Accumulator, pyarrow.float64(), pyarrow.float64(), [pyarrow.float64()]
+         )
+
+         # a single batch is expected, so only the first one is inspected
+         batch = frame.aggregate([f.col("b")], [udaf(f.col("a"))]).collect()[0]
+
+         # both groups (b=4 -> 1 + 2, b=6 -> 3) add up to 3.0
+         self.assertEqual(batch.column(1), pyarrow.array([1.0 + 2.0, 3.0]))