You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by al...@apache.org on 2021/07/19 15:47:29 UTC
[arrow-datafusion] branch master updated: fix arrow type id mapping (#742)
This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new a4f6cdd fix arrow type id mapping (#742)
a4f6cdd is described below
commit a4f6cdd64997617d3d32fac537615d19fa7cbe36
Author: Jiayu Liu <Ji...@users.noreply.github.com>
AuthorDate: Mon Jul 19 23:47:23 2021 +0800
fix arrow type id mapping (#742)
---
python/src/dataframe.rs | 7 +----
python/src/to_rust.rs | 1 +
python/src/types.rs | 19 +++---------
python/tests/test_pa_types.py | 51 ++++++++++++++++++++++++++++++++
python/tests/test_string_functions.py | 55 +++++++++++++++++++++++++++++++++++
5 files changed, 112 insertions(+), 21 deletions(-)
diff --git a/python/src/dataframe.rs b/python/src/dataframe.rs
index 89c85f9..4a50262 100644
--- a/python/src/dataframe.rs
+++ b/python/src/dataframe.rs
@@ -159,12 +159,7 @@ impl DataFrame {
}
};
- let builder = errors::wrap(builder.join(
- &right.plan,
- join_type,
- on.clone(),
- on,
- ))?;
+ let builder = errors::wrap(builder.join(&right.plan, join_type, on.clone(), on))?;
let plan = errors::wrap(builder.build())?;
diff --git a/python/src/to_rust.rs b/python/src/to_rust.rs
index 2e3f7f0..e7957ec 100644
--- a/python/src/to_rust.rs
+++ b/python/src/to_rust.rs
@@ -48,6 +48,7 @@ pub fn to_rust(ob: &PyAny) -> PyResult<ArrayRef> {
Ok(array)
}
+/// converts a pyarrow batch into a RecordBatch
pub fn to_rust_batch(batch: &PyAny) -> PyResult<RecordBatch> {
let schema = batch.getattr("schema")?;
let names = schema.getattr("names")?.extract::<Vec<String>>()?;
diff --git a/python/src/types.rs b/python/src/types.rs
index ffa822e..bd6ef0d 100644
--- a/python/src/types.rs
+++ b/python/src/types.rs
@@ -48,24 +48,13 @@ fn data_type_id(id: &i32) -> Result<DataType, errors::DataFusionError> {
7 => DataType::Int32,
8 => DataType::UInt64,
9 => DataType::Int64,
-
10 => DataType::Float16,
11 => DataType::Float32,
12 => DataType::Float64,
-
- //13 => DataType::Decimal,
-
- // 14 => DataType::Date32(),
- // 15 => DataType::Date64(),
- // 16 => DataType::Timestamp(),
- // 17 => DataType::Time32(),
- // 18 => DataType::Time64(),
- // 19 => DataType::Duration()
- 20 => DataType::Binary,
- 21 => DataType::Utf8,
- 22 => DataType::LargeBinary,
- 23 => DataType::LargeUtf8,
-
+ 13 => DataType::Utf8,
+ 14 => DataType::Binary,
+ 34 => DataType::LargeUtf8,
+ 35 => DataType::LargeBinary,
other => {
return Err(errors::DataFusionError::Common(format!(
"The type {} is not valid",
diff --git a/python/tests/test_pa_types.py b/python/tests/test_pa_types.py
new file mode 100644
index 0000000..069343f
--- /dev/null
+++ b/python/tests/test_pa_types.py
@@ -0,0 +1,51 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+
+
+def test_type_ids():
+ """Pin pyarrow's numeric DataType.id values: data_type_id in python/src/types.rs maps
+ these integer ids to DataFusion types, so any upstream renumbering must fail here."""
+ for idx, arrow_type in [
+ (0, pa.null()),
+ (1, pa.bool_()),
+ (2, pa.uint8()),
+ (3, pa.int8()),
+ (4, pa.uint16()),
+ (5, pa.int16()),
+ (6, pa.uint32()),
+ (7, pa.int32()),
+ (8, pa.uint64()),
+ (9, pa.int64()),
+ (10, pa.float16()),
+ (11, pa.float32()),
+ (12, pa.float64()),
+ (13, pa.string()),
+ (13, pa.utf8()),  # utf8() is expected to share string()'s id (alias)
+ (14, pa.binary()),
+ (16, pa.date32()),  # 16-23: ids pinned here even though data_type_id does not map them
+ (17, pa.date64()),
+ (18, pa.timestamp("us")),
+ (19, pa.time32("s")),
+ (20, pa.time64("us")),
+ (23, pa.decimal128(8, 1)),
+ (34, pa.large_utf8()),
+ (35, pa.large_binary()),
+ ]:
+
+ assert idx == arrow_type.id
diff --git a/python/tests/test_string_functions.py b/python/tests/test_string_functions.py
new file mode 100644
index 0000000..f8e1557
--- /dev/null
+++ b/python/tests/test_string_functions.py
@@ -0,0 +1,55 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import pyarrow as pa
+import pytest
+from datafusion import ExecutionContext
+from datafusion import functions as f
+
+
+@pytest.fixture
+def df():  # 3-row DataFrame: string column "a", integer column "b"
+ ctx = ExecutionContext()
+
+ # build one RecordBatch; [[batch]] nests it in a list of lists (presumably partitions of batches — confirm)
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array(["Hello", "World", "!"]), pa.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+
+ return ctx.create_dataframe([[batch]])
+
+
+def test_string_functions(df):  # md5() and lower() applied to string column "a"
+ df = df.select(f.md5(f.col("a")), f.lower(f.col("a")))
+ result = df.collect()
+ assert len(result) == 1  # the fixture supplies one batch; expect exactly one back
+ result = result[0]
+ assert result.column(0) == pa.array(  # lowercase hex md5 of "Hello", "World", "!"
+ [
+ "8b1a9953c4611296a827abf8c47804d7",
+ "f5a7924e621e84c9280a9a27e1bcb7f6",
+ "9033e0e305f247c0c3c80d0c7848c8b3",
+ ]
+ )
+ assert result.column(1) == pa.array(  # non-letters such as "!" pass through unchanged
+ [
+ "hello",
+ "world",
+ "!",
+ ]
+ )