You are viewing a plain-text version of this content. The canonical link for it was provided as a hyperlink ("here"), which is not preserved in this plain-text rendering.
Posted to commits@arrow.apache.org by ho...@apache.org on 2021/09/08 04:32:08 UTC

[arrow-datafusion] branch master updated: update datafusion to 5.1.0 for python binding (#967)

This is an automated email from the ASF dual-hosted git repository.

houqp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new bb616bf  update datafusion to 5.1.0 for python binding (#967)
bb616bf is described below

commit bb616bf94bd952da8cc0c18b20cf33c55c30459e
Author: QP Hou <qp...@scribd.com>
AuthorDate: Tue Sep 7 21:32:03 2021 -0700

    update datafusion to 5.1.0 for python binding (#967)
    
    * update datafusion to 5.1.0 for python binding
---
 python/Cargo.toml        |  8 ++++++--
 python/src/dataframe.rs  | 10 +++++++---
 python/tests/test_df.py  |  2 +-
 python/tests/test_sql.py | 10 +++++-----
 4 files changed, 19 insertions(+), 11 deletions(-)

diff --git a/python/Cargo.toml b/python/Cargo.toml
index 8dba538..8c81a53 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -16,7 +16,7 @@
 # under the License.
 
 [package]
-name = "datafusion"
+name = "datafusion-python"
 version = "0.3.0"
 homepage = "https://github.com/apache/arrow"
 repository = "https://github.com/apache/arrow"
@@ -31,7 +31,11 @@ libc = "0.2"
 tokio = { version = "1.0", features = ["macros", "rt", "rt-multi-thread", "sync"] }
 rand = "0.7"
 pyo3 = { version = "0.14.1", features = ["extension-module", "abi3", "abi3-py36"] }
-datafusion = { git = "https://github.com/apache/arrow-datafusion.git", rev = "4d61196dee8526998aee7e7bb10ea88422e5f9e1" }
+datafusion = { path = "../datafusion", version = "5.1.0" }
+# workaround for a bug introduced in
+# https://github.com/dtolnay/proc-macro2/pull/286
+# TODO: remove this version pin after upstream releases a fix
+proc-macro2 = { version = "=1.0.28" }
 
 [lib]
 name = "datafusion"
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 8e5657b..0885ae3 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -161,9 +161,13 @@ impl DataFrame {
         Ok(pretty::print_batches(&batches).unwrap())
     }
 
-
     /// Returns the join of two DataFrames `on`.
-    fn join(&self, right: &DataFrame, on: Vec<&str>, how: &str) -> PyResult<Self> {
+    fn join(
+        &self,
+        right: &DataFrame,
+        join_keys: (Vec<&str>, Vec<&str>),
+        how: &str,
+    ) -> PyResult<Self> {
         let builder = LogicalPlanBuilder::from(self.plan.clone());
 
         let join_type = match how {
@@ -182,7 +186,7 @@ impl DataFrame {
             }
         };
 
-        let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
+        let builder = errors::wrap(builder.join(&right.plan, join_type, join_keys))?;
 
         let plan = errors::wrap(builder.build())?;
 
diff --git a/python/tests/test_df.py b/python/tests/test_df.py
index 5b6cbdd..14ab5ff 100644
--- a/python/tests/test_df.py
+++ b/python/tests/test_df.py
@@ -104,7 +104,7 @@ def test_join():
     )
     df1 = ctx.create_dataframe([[batch]])
 
-    df = df.join(df1, on="a", how="inner")
+    df = df.join(df1, join_keys=(["a"], ["a"]), how="inner")
     df = df.sort([f.col("a").sort(ascending=True)])
     table = pa.Table.from_batches(df.collect())
 
diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py
index 669f640..beac578 100644
--- a/python/tests/test_sql.py
+++ b/python/tests/test_sql.py
@@ -69,7 +69,7 @@ def test_register_csv(ctx, tmp_path):
     for table in ["csv", "csv1", "csv2"]:
         result = ctx.sql(f"SELECT COUNT(int) FROM {table}").collect()
         result = pa.Table.from_batches(result)
-        assert result.to_pydict() == {"COUNT(int)": [4]}
+        assert result.to_pydict() == {f"COUNT({table}.int)": [4]}
 
     result = ctx.sql("SELECT * FROM csv3").collect()
     result = pa.Table.from_batches(result)
@@ -88,7 +88,7 @@ def test_register_parquet(ctx, tmp_path):
 
     result = ctx.sql("SELECT COUNT(a) FROM t").collect()
     result = pa.Table.from_batches(result)
-    assert result.to_pydict() == {"COUNT(a)": [100]}
+    assert result.to_pydict() == {"COUNT(t.a)": [100]}
 
 
 def test_execute(ctx, tmp_path):
@@ -123,8 +123,8 @@ def test_execute(ctx, tmp_path):
     result_values = []
     for result in results:
         pydict = result.to_pydict()
-        result_keys.extend(pydict["CAST(a AS Int32)"])
-        result_values.extend(pydict["COUNT(a)"])
+        result_keys.extend(pydict["CAST(t.a AS Int32)"])
+        result_values.extend(pydict["COUNT(t.a)"])
 
     result_keys, result_values = (
         list(t) for t in zip(*sorted(zip(result_keys, result_values)))
@@ -141,7 +141,7 @@ def test_execute(ctx, tmp_path):
     expected_cast = pa.array([50, 50], pa.int32())
     expected = [
         pa.RecordBatch.from_arrays(
-            [expected_a, expected_cast], ["a", "CAST(a AS Int32)"]
+            [expected_a, expected_cast], ["a", "CAST(t.a AS Int32)"]
         )
     ]
     np.testing.assert_equal(expected[0].column(1), expected[0].column(1))