You are viewing a plain-text version of this content; the link to the canonical version was not preserved in this copy.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/07/19 15:47:29 UTC

[arrow-datafusion] branch master updated: fix arrow type id mapping (#742)

This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new a4f6cdd  fix arrow type id mapping (#742)
a4f6cdd is described below

commit a4f6cdd64997617d3d32fac537615d19fa7cbe36
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Mon Jul 19 23:47:23 2021 +0800

    fix arrow type id mapping (#742)
---
 python/src/dataframe.rs               |  7 +----
 python/src/to_rust.rs                 |  1 +
 python/src/types.rs                   | 19 +++---------
 python/tests/test_pa_types.py         | 51 ++++++++++++++++++++++++++++++++
 python/tests/test_string_functions.py | 55 +++++++++++++++++++++++++++++++++++
 5 files changed, 112 insertions(+), 21 deletions(-)

diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 89c85f9..4a50262 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -159,12 +159,7 @@ impl DataFrame {
             }
         };
 
-        let builder = errors::wrap(builder.join(
-            &right.plan,
-            join_type,
-            on.clone(),
-            on,
-        ))?;
+        let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
 
         let plan = errors::wrap(builder.build())?;
 
diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs
index 2e3f7f0..e7957ec 100644
--- a/python/src/to_rust.rs
+++ b/python/src/to_rust.rs
@@ -48,6 +48,7 @@ pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
     Ok(array)
 }
 
+/// converts a pyarrow batch into a RecordBatch
 pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
     let schema = batch.getattr("schema")?;
     let names = schema.getattr("names")?.extract::<Vec<String>>()?;
diff --git a/python/src/types.rs b/python/src/types.rs
index ffa822e..bd6ef0d 100644
--- a/python/src/types.rs
+++ b/python/src/types.rs
@@ -48,24 +48,13 @@ fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
         7 => DataType::Int32,
         8 => DataType::UInt64,
         9 => DataType::Int64,
-
         10 => DataType::Float16,
         11 => DataType::Float32,
         12 => DataType::Float64,
-
-        //13 => DataType::Decimal,
-
-        // 14 => DataType::Date32(),
-        // 15 => DataType::Date64(),
-        // 16 => DataType::Timestamp(),
-        // 17 => DataType::Time32(),
-        // 18 => DataType::Time64(),
-        // 19 => DataType::Duration()
-        20 => DataType::Binary,
-        21 => DataType::Utf8,
-        22 => DataType::LargeBinary,
-        23 => DataType::LargeUtf8,
-
+        13 => DataType::Utf8,
+        14 => DataType::Binary,
+        34 => DataType::LargeUtf8,
+        35 => DataType::LargeBinary,
         other => {
             return Err(errors::DataFusionError::Common(format!(
                 "The type {} is not valid",
diff --git a/python/tests/test_pa_types.py b/python/tests/test_pa_types.py
new file mode 100644
index 0000000..069343f
--- /dev/null
+++ b/python/tests/test_pa_types.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+
+
+def test_type_ids():
+    """having this fixed is very important because internally we rely on this id to parse from
+    python"""
+    for idx, arrow_type in [
+        (0, pa.null()),
+        (1, pa.bool_()),
+        (2, pa.uint8()),
+        (3, pa.int8()),
+        (4, pa.uint16()),
+        (5, pa.int16()),
+        (6, pa.uint32()),
+        (7, pa.int32()),
+        (8, pa.uint64()),
+        (9, pa.int64()),
+        (10, pa.float16()),
+        (11, pa.float32()),
+        (12, pa.float64()),
+        (13, pa.string()),
+        (13, pa.utf8()),
+        (14, pa.binary()),
+        (16, pa.date32()),
+        (17, pa.date64()),
+        (18, pa.timestamp("us")),
+        (19, pa.time32("s")),
+        (20, pa.time64("us")),
+        (23, pa.decimal128(8, 1)),
+        (34, pa.large_utf8()),
+        (35, pa.large_binary()),
+    ]:
+
+        assert idx == arrow_type.id
diff --git a/python/tests/test_string_functions.py b/python/tests/test_string_functions.py
new file mode 100644
index 0000000..f8e1557
--- /dev/null
+++ b/python/tests/test_string_functions.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
+
+
+@pytest.fixture
+def df():
+    ctx = ExecutionContext()
+
+    # create a RecordBatch and a new DataFrame from it
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+
+    return ctx.create_dataframe([[batch]])
+
+
+def test_string_functions(df):
+    df = df.select(f.md5(f.col("a")), f.lower(f.col("a")))
+    result = df.collect()
+    assert len(result) == 1
+    result = result[0]
+    assert result.column(0) == pa.array(
+        [
+            "8b1a9953c4611296a827abf8c47804d7",
+            "f5a7924e621e84c9280a9a27e1bcb7f6",
+            "9033e0e305f247c0c3c80d0c7848c8b3",
+        ]
+    )
+    assert result.column(1) == pa.array(
+        [
+            "hello",
+            "world",
+            "!",
+        ]
+    )