You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ks...@apache.org on 2021/07/02 10:32:30 UTC
[arrow-rs] branch master updated: Python FFI bridge for Schema,
Field and DataType (#439)
This is an automated email from the ASF dual-hosted git repository.
kszucs pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new 31bc052 Python FFI bridge for Schema, Field and DataType (#439)
31bc052 is described below
commit 31bc052126abc4834edf9b3cd7cb72384f84ba3e
Author: Krisztián Szűcs <sz...@gmail.com>
AuthorDate: Fri Jul 2 12:30:59 2021 +0200
Python FFI bridge for Schema, Field and DataType (#439)
* FFI bridge for Schema, Field and DataType
* Factor out conversion to datatypes/ffi.rs
* Add flags
* Rust tests
* Test datatypes from the python test suite
* Install a pinned nightly pyarrow wheel
* Python tests for Field and Schema
* Cleanup
* Remove comment
* cleanup
* Fix python tests after rebase
* fix clippy
Co-authored-by: Andrew Lamb <an...@nerdnetworks.org>
---
.github/workflows/integration.yml | 53 ++-
.github/workflows/rust.yml | 46 ---
arrow-pyarrow-integration-testing/src/lib.rs | 158 ++++++--
.../tests/test_sql.py | 398 +++++++++++++--------
arrow/Cargo.toml | 1 +
arrow/src/datatypes/ffi.rs | 359 +++++++++++++++++++
arrow/src/datatypes/mod.rs | 2 +
arrow/src/ffi.rs | 327 +++++------------
8 files changed, 870 insertions(+), 474 deletions(-)
diff --git a/.github/workflows/integration.yml b/.github/workflows/integration.yml
index cab6dd3..a713d05 100644
--- a/.github/workflows/integration.yml
+++ b/.github/workflows/integration.yml
@@ -23,7 +23,7 @@ on:
jobs:
- docker:
+ integration:
name: Integration Test
runs-on: ubuntu-latest
steps:
@@ -46,3 +46,54 @@ jobs:
run: pip install -e dev/archery[docker]
- name: Execute Docker Build
run: archery docker run -e ARCHERY_INTEGRATION_WITH_RUST=1 conda-integration
+
+ # test FFI against the C-Data interface exposed by pyarrow
+ pyarrow-integration-test:
+ name: Test Pyarrow C Data Interface
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ rust: [stable]
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ submodules: true
+ - name: Setup Rust toolchain
+ run: |
+ rustup toolchain install ${{ matrix.rust }}
+ rustup default ${{ matrix.rust }}
+ rustup component add rustfmt clippy
+ - name: Cache Cargo
+ uses: actions/cache@v2
+ with:
+ path: /home/runner/.cargo
+ key: cargo-maturin-cache-
+ - name: Cache Rust dependencies
+ uses: actions/cache@v2
+ with:
+ path: /home/runner/target
+ # this key is not equal because maturin uses different compilation flags.
+ key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
+ - uses: actions/setup-python@v2
+ with:
+ python-version: '3.7'
+ - name: Upgrade pip and setuptools
+ run: pip install --upgrade pip setuptools wheel
+ - name: Install python dependencies
+ run: pip install maturin==0.8.2 toml==0.10.1 pytest pytz
+ - name: Install nightly pyarrow wheel
+ # this points to a nightly pyarrow build containing the necessary
+ # API for integration testing (https://github.com/apache/arrow/pull/10529)
+ # the hardcoded version is wrong and should be removed either
+ # after https://issues.apache.org/jira/browse/ARROW-13083
+ # gets fixed or pyarrow 5.0 gets released
+ # hardcoded version is wrong, but contains the necessary API
+ run: pip install --index-url https://pypi.fury.io/arrow-nightlies/ pyarrow==3.1.0.dev1030
+ - name: Run tests
+ env:
+ CARGO_HOME: "/home/runner/.cargo"
+ CARGO_TARGET_DIR: "/home/runner/target"
+ working-directory: arrow-pyarrow-integration-testing
+ run: |
+ maturin develop
+ pytest -v .
diff --git a/.github/workflows/rust.yml b/.github/workflows/rust.yml
index 559c7c8..a041afc 100644
--- a/.github/workflows/rust.yml
+++ b/.github/workflows/rust.yml
@@ -283,52 +283,6 @@ jobs:
continue-on-error: true
run: bash <(curl -s https://codecov.io/bash)
- # test FFI against the C-Data interface exposed by pyarrow
- pyarrow-integration-test:
- name: Test Pyarrow C Data Interface
- runs-on: ubuntu-latest
- strategy:
- matrix:
- rust: [stable]
- steps:
- - uses: actions/checkout@v2
- with:
- submodules: true
- - name: Setup Rust toolchain
- run: |
- rustup toolchain install ${{ matrix.rust }}
- rustup default ${{ matrix.rust }}
- rustup component add rustfmt clippy
- - name: Cache Cargo
- uses: actions/cache@v2
- with:
- path: /home/runner/.cargo
- key: cargo-maturin-cache-
- - name: Cache Rust dependencies
- uses: actions/cache@v2
- with:
- path: /home/runner/target
- # this key is not equal because maturin uses different compilation flags.
- key: ${{ runner.os }}-${{ matrix.arch }}-target-maturin-cache-${{ matrix.rust }}-
- - uses: actions/setup-python@v2
- with:
- python-version: '3.7'
- - name: Install Python dependencies
- run: python -m pip install --upgrade pip setuptools wheel
- - name: Run tests
- run: |
- export CARGO_HOME="/home/runner/.cargo"
- export CARGO_TARGET_DIR="/home/runner/target"
-
- cd arrow-pyarrow-integration-testing
-
- python -m venv venv
- source venv/bin/activate
-
- pip install maturin==0.8.2 toml==0.10.1 pyarrow==1.0.0 pytz
- maturin develop
- python -m unittest discover tests
-
# test the arrow crate builds against wasm32 in stable rust
wasm32-build:
name: Build wasm32 on AMD64 Rust ${{ matrix.rust }}
diff --git a/arrow-pyarrow-integration-testing/src/lib.rs b/arrow-pyarrow-integration-testing/src/lib.rs
index 5b5462d..a601654 100644
--- a/arrow-pyarrow-integration-testing/src/lib.rs
+++ b/arrow-pyarrow-integration-testing/src/lib.rs
@@ -18,6 +18,7 @@
//! This library demonstrates a minimal usage of Rust's C data interface to pass
//! arrays from and to Python.
+use std::convert::TryFrom;
use std::error;
use std::fmt;
use std::sync::Arc;
@@ -28,8 +29,10 @@ use pyo3::{libc::uintptr_t, prelude::*};
use arrow::array::{make_array_from_raw, ArrayRef, Int64Array};
use arrow::compute::kernels;
+use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ffi;
+use arrow::ffi::FFI_ArrowSchema;
/// an error that bridges ArrowError with a Python error
#[derive(Debug)]
@@ -68,7 +71,107 @@ impl From<PyO3ArrowError> for PyErr {
}
}
-fn to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
+#[pyclass]
+struct PyDataType {
+ inner: DataType,
+}
+
+#[pyclass]
+struct PyField {
+ inner: Field,
+}
+
+#[pyclass]
+struct PySchema {
+ inner: Schema,
+}
+
+#[pymethods]
+impl PyDataType {
+ #[staticmethod]
+ fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ let c_schema = FFI_ArrowSchema::empty();
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+ let dtype = DataType::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
+ Ok(Self { inner: dtype })
+ }
+
+ fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+ let c_schema =
+ FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?;
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ let module = py.import("pyarrow")?;
+ let class = module.getattr("DataType")?;
+ let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+ Ok(dtype.into())
+ }
+}
+
+#[pymethods]
+impl PyField {
+ #[staticmethod]
+ fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ let c_schema = FFI_ArrowSchema::empty();
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+ let field = Field::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
+ Ok(Self { inner: field })
+ }
+
+ fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+ let c_schema =
+ FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?;
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ let module = py.import("pyarrow")?;
+ let class = module.getattr("Field")?;
+ let dtype = class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+ Ok(dtype.into())
+ }
+}
+
+#[pymethods]
+impl PySchema {
+ #[staticmethod]
+ fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ let c_schema = FFI_ArrowSchema::empty();
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ value.call_method1("_export_to_c", (c_schema_ptr as uintptr_t,))?;
+ let schema = Schema::try_from(&c_schema).map_err(PyO3ArrowError::from)?;
+ Ok(Self { inner: schema })
+ }
+
+ fn to_pyarrow(&self, py: Python) -> PyResult<PyObject> {
+ let c_schema =
+ FFI_ArrowSchema::try_from(&self.inner).map_err(PyO3ArrowError::from)?;
+ let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
+ let module = py.import("pyarrow")?;
+ let class = module.getattr("Schema")?;
+ let schema =
+ class.call_method1("_import_from_c", (c_schema_ptr as uintptr_t,))?;
+ Ok(schema.into())
+ }
+}
+
+impl<'source> FromPyObject<'source> for PyDataType {
+ fn extract(value: &'source PyAny) -> PyResult<Self> {
+ PyDataType::from_pyarrow(value)
+ }
+}
+
+impl<'source> FromPyObject<'source> for PyField {
+ fn extract(value: &'source PyAny) -> PyResult<Self> {
+ PyField::from_pyarrow(value)
+ }
+}
+
+impl<'source> FromPyObject<'source> for PySchema {
+ fn extract(value: &'source PyAny) -> PyResult<Self> {
+ PySchema::from_pyarrow(value)
+ }
+}
+
+fn array_to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
// prepare a pointer to receive the Array struct
let (array_pointer, schema_pointer) =
ffi::ArrowArray::into_raw(unsafe { ffi::ArrowArray::empty() });
@@ -82,13 +185,12 @@ fn to_rust(ob: PyObject, py: Python) -> PyResult<ArrayRef> {
)?;
let array = unsafe { make_array_from_raw(array_pointer, schema_pointer) }
- .map_err(|e| PyO3ArrowError::from(e))?;
+ .map_err(PyO3ArrowError::from)?;
Ok(array)
}
-fn to_py(array: ArrayRef, py: Python) -> PyResult<PyObject> {
- let (array_pointer, schema_pointer) =
- array.to_raw().map_err(|e| PyO3ArrowError::from(e))?;
+fn array_to_py(array: ArrayRef, py: Python) -> PyResult<PyObject> {
+ let (array_pointer, schema_pointer) = array.to_raw().map_err(PyO3ArrowError::from)?;
let pa = py.import("pyarrow")?;
@@ -103,22 +205,17 @@ fn to_py(array: ArrayRef, py: Python) -> PyResult<PyObject> {
#[pyfunction]
fn double(array: PyObject, py: Python) -> PyResult<PyObject> {
// import
- let array = to_rust(array, py)?;
+ let array = array_to_rust(array, py)?;
// perform some operation
- let array =
- array
- .as_any()
- .downcast_ref::<Int64Array>()
- .ok_or(PyO3ArrowError::ArrowError(ArrowError::ParseError(
- "Expects an int64".to_string(),
- )))?;
- let array =
- kernels::arithmetic::add(&array, &array).map_err(|e| PyO3ArrowError::from(e))?;
+ let array = array.as_any().downcast_ref::<Int64Array>().ok_or_else(|| {
+ PyO3ArrowError::ArrowError(ArrowError::ParseError("Expects an int64".to_string()))
+ })?;
+ let array = kernels::arithmetic::add(&array, &array).map_err(PyO3ArrowError::from)?;
let array = Arc::new(array);
// export
- to_py(array, py)
+ array_to_py(array, py)
}
/// calls a lambda function that receives and returns an array
@@ -130,11 +227,9 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult<bool> {
let expected = Arc::new(Int64Array::from(vec![Some(2), None, Some(6)])) as ArrayRef;
// to py
- let array = to_py(array, py)?;
-
- let array = lambda.call1(py, (array,))?;
-
- let array = to_rust(array, py)?;
+ let pyarray = array_to_py(array, py)?;
+ let pyarray = lambda.call1(py, (pyarray,))?;
+ let array = array_to_rust(pyarray, py)?;
Ok(array == expected)
}
@@ -143,42 +238,45 @@ fn double_py(lambda: PyObject, py: Python) -> PyResult<bool> {
#[pyfunction]
fn substring(array: PyObject, start: i64, py: Python) -> PyResult<PyObject> {
// import
- let array = to_rust(array, py)?;
+ let array = array_to_rust(array, py)?;
// substring
let array = kernels::substring::substring(array.as_ref(), start, &None)
- .map_err(|e| PyO3ArrowError::from(e))?;
+ .map_err(PyO3ArrowError::from)?;
// export
- to_py(array, py)
+ array_to_py(array, py)
}
/// Returns the input array concatenated with itself
#[pyfunction]
fn concatenate(array: PyObject, py: Python) -> PyResult<PyObject> {
// import
- let array = to_rust(array, py)?;
+ let array = array_to_rust(array, py)?;
// concat
let array = kernels::concat::concat(&[array.as_ref(), array.as_ref()])
- .map_err(|e| PyO3ArrowError::from(e))?;
+ .map_err(PyO3ArrowError::from)?;
// export
- to_py(array, py)
+ array_to_py(array, py)
}
/// Converts to rust and back to python
#[pyfunction]
-fn round_trip(array: PyObject, py: Python) -> PyResult<PyObject> {
+fn round_trip(pyarray: PyObject, py: Python) -> PyResult<PyObject> {
// import
- let array = to_rust(array, py)?;
+ let array = array_to_rust(pyarray, py)?;
// export
- to_py(array, py)
+ array_to_py(array, py)
}
#[pymodule]
fn arrow_pyarrow_integration_testing(_py: Python, m: &PyModule) -> PyResult<()> {
+ m.add_class::<PyDataType>()?;
+ m.add_class::<PyField>()?;
+ m.add_class::<PySchema>()?;
m.add_wrapped(wrap_pyfunction!(double))?;
m.add_wrapped(wrap_pyfunction!(double_py))?;
m.add_wrapped(wrap_pyfunction!(substring))?;
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index 5524c54..301eac8 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -16,156 +16,252 @@
# specific language governing permissions and limitations
# under the License.
-import unittest
-from datetime import date, datetime
-from decimal import Decimal
-
-import arrow_pyarrow_integration_testing
-import pyarrow
-from pytz import timezone
-
-
-class TestCase(unittest.TestCase):
- def test_primitive_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- a = pyarrow.array([1, 2, 3])
- b = arrow_pyarrow_integration_testing.double(a)
- self.assertEqual(b, pyarrow.array([2, 4, 6]))
- del a
- del b
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_primitive_rust(self):
- """
- Rust -> Python -> Rust
- """
- old_allocated = pyarrow.total_allocated_bytes()
-
- def double(array):
- array = array.to_pylist()
- return pyarrow.array([x * 2 if x is not None else None for x in array])
-
- is_correct = arrow_pyarrow_integration_testing.double_py(double)
- self.assertTrue(is_correct)
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_string_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- a = pyarrow.array(["a", None, "ccc"])
- b = arrow_pyarrow_integration_testing.substring(a, 1)
- self.assertEqual(b, pyarrow.array(["", None, "cc"]))
- del a
- del b
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_time32_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- a = pyarrow.array([None, 1, 2], pyarrow.time32("s"))
- b = arrow_pyarrow_integration_testing.concatenate(a)
- expected = pyarrow.array([None, 1, 2] + [None, 1, 2], pyarrow.time32("s"))
- self.assertEqual(b, expected)
- del a
- del b
- del expected
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_date32_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- py_array = [None, date(1990, 3, 9), date(2021, 6, 20)]
- a = pyarrow.array(py_array, pyarrow.date32())
- b = arrow_pyarrow_integration_testing.concatenate(a)
- expected = pyarrow.array(py_array + py_array, pyarrow.date32())
- self.assertEqual(b, expected)
- del a
- del b
- del expected
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_timestamp_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- py_array = [
- None,
- datetime(2021, 1, 1, 1, 1, 1, 1),
- datetime(2020, 3, 9, 1, 1, 1, 1),
+import contextlib
+import datetime
+import decimal
+import string
+
+import pytest
+import pyarrow as pa
+import pytz
+
+from arrow_pyarrow_integration_testing import PyDataType, PyField, PySchema
+import arrow_pyarrow_integration_testing as rust
+
+
+@contextlib.contextmanager
+def no_pyarrow_leak():
+ # No leak of C++ memory
+ old_allocation = pa.total_allocated_bytes()
+ try:
+ yield
+ finally:
+ assert pa.total_allocated_bytes() == old_allocation
+
+
+@pytest.fixture(autouse=True)
+def assert_pyarrow_leak():
+ # automatically applied to all test cases
+ with no_pyarrow_leak():
+ yield
+
+
+_supported_pyarrow_types = [
+ pa.null(),
+ pa.bool_(),
+ pa.int32(),
+ pa.time32("s"),
+ pa.time64("us"),
+ pa.date32(),
+ pa.timestamp("us"),
+ pa.timestamp("us", tz="UTC"),
+ pa.timestamp("us", tz="Europe/Paris"),
+ pa.float16(),
+ pa.float32(),
+ pa.float64(),
+ pa.decimal128(19, 4),
+ pa.string(),
+ pa.binary(),
+ pa.large_string(),
+ pa.large_binary(),
+ pa.list_(pa.int32()),
+ pa.large_list(pa.uint16()),
+ pa.struct(
+ [
+ pa.field("a", pa.int32()),
+ pa.field("b", pa.int8()),
+ pa.field("c", pa.string()),
]
- a = pyarrow.array(py_array, pyarrow.timestamp("us"))
- b = arrow_pyarrow_integration_testing.concatenate(a)
- expected = pyarrow.array(py_array + py_array, pyarrow.timestamp("us"))
- self.assertEqual(b, expected)
- del a
- del b
- del expected
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_timestamp_tz_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- py_array = [
- None,
- datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=timezone("America/New_York")),
- datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=timezone("America/New_York")),
+ ),
+ pa.struct(
+ [
+ pa.field("a", pa.int32(), nullable=False),
+ pa.field("b", pa.int8(), nullable=False),
+ pa.field("c", pa.string()),
]
- a = pyarrow.array(py_array, pyarrow.timestamp("us", tz="America/New_York"))
- b = arrow_pyarrow_integration_testing.concatenate(a)
- expected = pyarrow.array(
- py_array + py_array, pyarrow.timestamp("us", tz="America/New_York")
- )
- self.assertEqual(b, expected)
- del a
- del b
- del expected
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_decimal_python(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- py_array = [round(Decimal(123.45), 2), round(Decimal(-123.45), 2), None]
- a = pyarrow.array(py_array, pyarrow.decimal128(6, 2))
- b = arrow_pyarrow_integration_testing.round_trip(a)
- self.assertEqual(a, b)
- del a
- del b
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
-
- def test_list_array(self):
- """
- Python -> Rust -> Python
- """
- old_allocated = pyarrow.total_allocated_bytes()
- a = pyarrow.array([[], None, [1, 2], [4, 5, 6]], pyarrow.list_(pyarrow.int64()))
- b = arrow_pyarrow_integration_testing.round_trip(a)
-
- b.validate(full=True)
- assert a.to_pylist() == b.to_pylist()
- assert a.type == b.type
- del a
- del b
- # No leak of C++ memory
- self.assertEqual(old_allocated, pyarrow.total_allocated_bytes())
+ ),
+]
+
+_unsupported_pyarrow_types = [
+ pa.decimal256(76, 38),
+ pa.duration("s"),
+ pa.binary(10),
+ pa.list_(pa.int32(), 2),
+ pa.map_(pa.string(), pa.int32()),
+ pa.union(
+ [pa.field("a", pa.binary(10)), pa.field("b", pa.string())],
+ mode=pa.lib.UnionMode_DENSE,
+ ),
+ pa.union(
+ [pa.field("a", pa.binary(10)), pa.field("b", pa.string())],
+ mode=pa.lib.UnionMode_DENSE,
+ type_codes=[4, 8],
+ ),
+ pa.union(
+ [pa.field("a", pa.binary(10)), pa.field("b", pa.string())],
+ mode=pa.lib.UnionMode_SPARSE,
+ ),
+ pa.union(
+ [
+ pa.field("a", pa.binary(10), nullable=False),
+ pa.field("b", pa.string()),
+ ],
+ mode=pa.lib.UnionMode_SPARSE,
+ ),
+]
+
+
+@pytest.mark.parametrize("pyarrow_type", _supported_pyarrow_types, ids=str)
+def test_type_roundtrip(pyarrow_type):
+ ty = PyDataType.from_pyarrow(pyarrow_type)
+ restored = ty.to_pyarrow()
+ assert restored == pyarrow_type
+ assert restored is not pyarrow_type
+
+
+@pytest.mark.parametrize("pyarrow_type", _unsupported_pyarrow_types, ids=str)
+def test_type_roundtrip_raises(pyarrow_type):
+ with pytest.raises(Exception):
+ PyDataType.from_pyarrow(pyarrow_type)
+
+
+def test_dictionary_type_roundtrip():
+ # the dictionary type conversion is incomplete
+ pyarrow_type = pa.dictionary(pa.int32(), pa.string())
+ ty = PyDataType.from_pyarrow(pyarrow_type)
+ assert ty.to_pyarrow() == pa.int32()
+
+
+@pytest.mark.parametrize('pyarrow_type', _supported_pyarrow_types, ids=str)
+def test_field_roundtrip(pyarrow_type):
+ pyarrow_field = pa.field("test", pyarrow_type, nullable=True)
+ field = PyField.from_pyarrow(pyarrow_field)
+ assert field.to_pyarrow() == pyarrow_field
+
+ if pyarrow_type != pa.null():
+ # A null type field may not be non-nullable
+ pyarrow_field = pa.field("test", pyarrow_type, nullable=False)
+ field = PyField.from_pyarrow(pyarrow_field)
+ assert field.to_pyarrow() == pyarrow_field
+
+
+def test_schema_roundtrip():
+ pyarrow_fields = zip(string.ascii_lowercase, _supported_pyarrow_types)
+ pyarrow_schema = pa.schema(pyarrow_fields)
+ schema = PySchema.from_pyarrow(pyarrow_schema)
+ assert schema.to_pyarrow() == pyarrow_schema
+
+
+def test_primitive_python():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array([1, 2, 3])
+ b = rust.double(a)
+ assert b == pa.array([2, 4, 6])
+ del a
+ del b
+
+
+def test_primitive_rust():
+ """
+ Rust -> Python -> Rust
+ """
+
+ def double(array):
+ array = array.to_pylist()
+ return pa.array([x * 2 if x is not None else None for x in array])
+
+ is_correct = rust.double_py(double)
+ assert is_correct
+
+
+def test_string_python():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array(["a", None, "ccc"])
+ b = rust.substring(a, 1)
+ assert b == pa.array(["", None, "cc"])
+ del a
+ del b
+
+
+def test_time32_python():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array([None, 1, 2], pa.time32("s"))
+ b = rust.concatenate(a)
+ expected = pa.array([None, 1, 2] + [None, 1, 2], pa.time32("s"))
+ assert b == expected
+ del a
+ del b
+ del expected
+
+
+def test_list_array():
+ """
+ Python -> Rust -> Python
+ """
+ a = pa.array([[], None, [1, 2], [4, 5, 6]], pa.list_(pa.int64()))
+ b = rust.round_trip(a)
+ b.validate(full=True)
+ assert a.to_pylist() == b.to_pylist()
+ assert a.type == b.type
+ del a
+ del b
+
+
+def test_timestamp_python():
+ """
+ Python -> Rust -> Python
+ """
+ data = [
+ None,
+ datetime.datetime(2021, 1, 1, 1, 1, 1, 1),
+ datetime.datetime(2020, 3, 9, 1, 1, 1, 1),
+ ]
+ a = pa.array(data, pa.timestamp("us"))
+ b = rust.concatenate(a)
+ expected = pa.array(data + data, pa.timestamp("us"))
+ assert b == expected
+ del a
+ del b
+ del expected
+
+
+def test_timestamp_tz_python():
+ """
+ Python -> Rust -> Python
+ """
+ tzinfo = pytz.timezone("America/New_York")
+ pyarrow_type = pa.timestamp("us", tz="America/New_York")
+ data = [
+ None,
+ datetime.datetime(2021, 1, 1, 1, 1, 1, 1, tzinfo=tzinfo),
+ datetime.datetime(2020, 3, 9, 1, 1, 1, 1, tzinfo=tzinfo),
+ ]
+ a = pa.array(data, type=pyarrow_type)
+ b = rust.concatenate(a)
+ expected = pa.array(data * 2, type=pyarrow_type)
+ assert b == expected
+ del a
+ del b
+ del expected
+
+
+def test_decimal_python():
+ """
+ Python -> Rust -> Python
+ """
+ data = [
+ round(decimal.Decimal(123.45), 2),
+ round(decimal.Decimal(-123.45), 2),
+ None
+ ]
+ a = pa.array(data, pa.decimal128(6, 2))
+ b = rust.round_trip(a)
+ assert a == b
+ del a
+ del b
diff --git a/arrow/Cargo.toml b/arrow/Cargo.toml
index 0ed2a45..4a1016a 100644
--- a/arrow/Cargo.toml
+++ b/arrow/Cargo.toml
@@ -52,6 +52,7 @@ hex = "0.4"
prettytable-rs = { version = "0.8.0", optional = true }
lexical-core = "^0.7"
multiversion = "0.6.1"
+bitflags = "1.2.1"
[features]
default = ["csv", "ipc"]
diff --git a/arrow/src/datatypes/ffi.rs b/arrow/src/datatypes/ffi.rs
new file mode 100644
index 0000000..7e98508
--- /dev/null
+++ b/arrow/src/datatypes/ffi.rs
@@ -0,0 +1,359 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::convert::TryFrom;
+
+use crate::{
+ datatypes::{DataType, Field, Schema, TimeUnit},
+ error::{ArrowError, Result},
+ ffi::{FFI_ArrowSchema, Flags},
+};
+
+impl TryFrom<&FFI_ArrowSchema> for DataType {
+ type Error = ArrowError;
+
+ /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings
+ fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
+ let dtype = match c_schema.format() {
+ "n" => DataType::Null,
+ "b" => DataType::Boolean,
+ "c" => DataType::Int8,
+ "C" => DataType::UInt8,
+ "s" => DataType::Int16,
+ "S" => DataType::UInt16,
+ "i" => DataType::Int32,
+ "I" => DataType::UInt32,
+ "l" => DataType::Int64,
+ "L" => DataType::UInt64,
+ "e" => DataType::Float16,
+ "f" => DataType::Float32,
+ "g" => DataType::Float64,
+ "z" => DataType::Binary,
+ "Z" => DataType::LargeBinary,
+ "u" => DataType::Utf8,
+ "U" => DataType::LargeUtf8,
+ "tdD" => DataType::Date32,
+ "tdm" => DataType::Date64,
+ "tts" => DataType::Time32(TimeUnit::Second),
+ "ttm" => DataType::Time32(TimeUnit::Millisecond),
+ "ttu" => DataType::Time64(TimeUnit::Microsecond),
+ "ttn" => DataType::Time64(TimeUnit::Nanosecond),
+ "+l" => {
+ let c_child = c_schema.child(0);
+ DataType::List(Box::new(Field::try_from(c_child)?))
+ }
+ "+L" => {
+ let c_child = c_schema.child(0);
+ DataType::LargeList(Box::new(Field::try_from(c_child)?))
+ }
+ "+s" => {
+ let fields = c_schema.children().map(Field::try_from);
+ DataType::Struct(fields.collect::<Result<Vec<_>>>()?)
+ }
+ // Parametrized types, requiring string parse
+ other => {
+ match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
+ // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
+ ["d", extra] => {
+ match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
+ [precision, scale] => {
+ let parsed_precision = precision.parse::<usize>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The decimal type requires an integer precision".to_string(),
+ )
+ })?;
+ let parsed_scale = scale.parse::<usize>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The decimal type requires an integer scale".to_string(),
+ )
+ })?;
+ DataType::Decimal(parsed_precision, parsed_scale)
+ },
+ [precision, scale, bits] => {
+ if *bits != "128" {
+ return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string()));
+ }
+ let parsed_precision = precision.parse::<usize>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The decimal type requires an integer precision".to_string(),
+ )
+ })?;
+ let parsed_scale = scale.parse::<usize>().map_err(|_| {
+ ArrowError::CDataInterface(
+ "The decimal type requires an integer scale".to_string(),
+ )
+ })?;
+ DataType::Decimal(parsed_precision, parsed_scale)
+ }
+ _ => {
+ return Err(ArrowError::CDataInterface(format!(
+ "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation",
+ extra
+ )))
+ }
+ }
+ }
+
+ // Timestamps in format "tss:" (no timezone) and "tss:America/New_York" (with timezone); likewise for "tsm", "tsu" and "tsn".
+ ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
+ ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
+ ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
+ ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
+ ["tss", tz] => {
+ DataType::Timestamp(TimeUnit::Second, Some(tz.to_string()))
+ }
+ ["tsm", tz] => {
+ DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string()))
+ }
+ ["tsu", tz] => {
+ DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string()))
+ }
+ ["tsn", tz] => {
+ DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string()))
+ }
+ _ => {
+ return Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" is still not supported in Rust implementation",
+ other
+ )))
+ }
+ }
+ }
+ };
+ Ok(dtype)
+ }
+}
+
+impl TryFrom<&FFI_ArrowSchema> for Field {
+ type Error = ArrowError;
+
+ fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
+ let dtype = DataType::try_from(c_schema)?;
+ let field = Field::new(c_schema.name(), dtype, c_schema.nullable());
+ Ok(field)
+ }
+}
+
+impl TryFrom<&FFI_ArrowSchema> for Schema {
+ type Error = ArrowError;
+
+ fn try_from(c_schema: &FFI_ArrowSchema) -> Result<Self> {
+ // interpret it as a struct type then extract its fields
+ let dtype = DataType::try_from(c_schema)?;
+ if let DataType::Struct(fields) = dtype {
+ Ok(Schema::new(fields))
+ } else {
+ Err(ArrowError::CDataInterface(
+ "Unable to interpret C data struct as a Schema".to_string(),
+ ))
+ }
+ }
+}
+
+impl TryFrom<&DataType> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ /// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings
+ fn try_from(dtype: &DataType) -> Result<Self> {
+ let format = match dtype {
+ DataType::Null => "n".to_string(),
+ DataType::Boolean => "b".to_string(),
+ DataType::Int8 => "c".to_string(),
+ DataType::UInt8 => "C".to_string(),
+ DataType::Int16 => "s".to_string(),
+ DataType::UInt16 => "S".to_string(),
+ DataType::Int32 => "i".to_string(),
+ DataType::UInt32 => "I".to_string(),
+ DataType::Int64 => "l".to_string(),
+ DataType::UInt64 => "L".to_string(),
+ DataType::Float16 => "e".to_string(),
+ DataType::Float32 => "f".to_string(),
+ DataType::Float64 => "g".to_string(),
+ DataType::Binary => "z".to_string(),
+ DataType::LargeBinary => "Z".to_string(),
+ DataType::Utf8 => "u".to_string(),
+ DataType::LargeUtf8 => "U".to_string(),
+ DataType::Decimal(precision, scale) => format!("d:{},{}", precision, scale),
+ DataType::Date32 => "tdD".to_string(),
+ DataType::Date64 => "tdm".to_string(),
+ DataType::Time32(TimeUnit::Second) => "tts".to_string(),
+ DataType::Time32(TimeUnit::Millisecond) => "ttm".to_string(),
+ DataType::Time64(TimeUnit::Microsecond) => "ttu".to_string(),
+ DataType::Time64(TimeUnit::Nanosecond) => "ttn".to_string(),
+ DataType::Timestamp(TimeUnit::Second, None) => "tss:".to_string(),
+ DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:".to_string(),
+ DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:".to_string(),
+ DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:".to_string(),
+ DataType::Timestamp(TimeUnit::Second, Some(tz)) => format!("tss:{}", tz),
+ DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => format!("tsm:{}", tz),
+ DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => format!("tsu:{}", tz),
+ DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => format!("tsn:{}", tz),
+ DataType::List(_) => "+l".to_string(),
+ DataType::LargeList(_) => "+L".to_string(),
+ DataType::Struct(_) => "+s".to_string(),
+ other => {
+ return Err(ArrowError::CDataInterface(format!(
+ "The datatype \"{:?}\" is still not supported in Rust implementation",
+ other
+ )))
+ }
+ };
+ // allocate and hold the children
+ let children = match dtype {
+ DataType::List(child) | DataType::LargeList(child) => {
+ vec![FFI_ArrowSchema::try_from(child.as_ref())?]
+ }
+ DataType::Struct(fields) => fields
+ .iter()
+ .map(FFI_ArrowSchema::try_from)
+ .collect::<Result<Vec<_>>>()?,
+ _ => vec![],
+ };
+ FFI_ArrowSchema::try_new(&format, children)
+ }
+}
+
+impl TryFrom<&Field> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ fn try_from(field: &Field) -> Result<Self> {
+ let flags = if field.is_nullable() {
+ Flags::NULLABLE
+ } else {
+ Flags::empty()
+ };
+ FFI_ArrowSchema::try_from(field.data_type())?
+ .with_name(field.name())?
+ .with_flags(flags)
+ }
+}
+
+impl TryFrom<&Schema> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ fn try_from(schema: &Schema) -> Result<Self> {
+ let dtype = DataType::Struct(schema.fields().clone());
+ let c_schema = FFI_ArrowSchema::try_from(&dtype)?;
+ Ok(c_schema)
+ }
+}
+
+impl TryFrom<DataType> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ fn try_from(dtype: DataType) -> Result<Self> {
+ FFI_ArrowSchema::try_from(&dtype)
+ }
+}
+
+impl TryFrom<Field> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ fn try_from(field: Field) -> Result<Self> {
+ FFI_ArrowSchema::try_from(&field)
+ }
+}
+
+impl TryFrom<Schema> for FFI_ArrowSchema {
+ type Error = ArrowError;
+
+ fn try_from(schema: Schema) -> Result<Self> {
+ FFI_ArrowSchema::try_from(&schema)
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::datatypes::{DataType, Field, TimeUnit};
+ use crate::error::Result;
+ use std::convert::TryFrom;
+
+ fn round_trip_type(dtype: DataType) -> Result<()> {
+ let c_schema = FFI_ArrowSchema::try_from(&dtype)?;
+ let restored = DataType::try_from(&c_schema)?;
+ assert_eq!(restored, dtype);
+ Ok(())
+ }
+
+ fn round_trip_field(field: Field) -> Result<()> {
+ let c_schema = FFI_ArrowSchema::try_from(&field)?;
+ let restored = Field::try_from(&c_schema)?;
+ assert_eq!(restored, field);
+ Ok(())
+ }
+
+ fn round_trip_schema(schema: Schema) -> Result<()> {
+ let c_schema = FFI_ArrowSchema::try_from(&schema)?;
+ let restored = Schema::try_from(&c_schema)?;
+ assert_eq!(restored, schema);
+ Ok(())
+ }
+
+ #[test]
+ fn test_type() -> Result<()> {
+ round_trip_type(DataType::Int64)?;
+ round_trip_type(DataType::UInt64)?;
+ round_trip_type(DataType::Float64)?;
+ round_trip_type(DataType::Date64)?;
+ round_trip_type(DataType::Time64(TimeUnit::Nanosecond))?;
+ round_trip_type(DataType::Utf8)?;
+ round_trip_type(DataType::List(Box::new(Field::new(
+ "a",
+ DataType::Int16,
+ false,
+ ))))?;
+ round_trip_type(DataType::Struct(vec![Field::new(
+ "a",
+ DataType::Utf8,
+ true,
+ )]))?;
+ Ok(())
+ }
+
+ #[test]
+ fn test_field() -> Result<()> {
+ let dtype = DataType::Struct(vec![Field::new("a", DataType::Utf8, true)]);
+ round_trip_field(Field::new("test", dtype, true))?;
+ Ok(())
+ }
+
+ #[test]
+ fn test_schema() -> Result<()> {
+ let schema = Schema::new(vec![
+ Field::new("name", DataType::Utf8, false),
+ Field::new("address", DataType::Utf8, false),
+ Field::new("priority", DataType::UInt8, false),
+ ]);
+ round_trip_schema(schema)?;
+
+ // test that we can interpret struct types as schema
+ let dtype = DataType::Struct(vec![
+ Field::new("a", DataType::Utf8, true),
+ Field::new("b", DataType::Int16, false),
+ ]);
+ let c_schema = FFI_ArrowSchema::try_from(&dtype)?;
+ let schema = Schema::try_from(&c_schema)?;
+ assert_eq!(schema.fields().len(), 2);
+
+ // test that we assert the input type
+ let c_schema = FFI_ArrowSchema::try_from(&DataType::Float64)?;
+ let result = Schema::try_from(&c_schema);
+ assert!(result.is_err());
+ Ok(())
+ }
+}
diff --git a/arrow/src/datatypes/mod.rs b/arrow/src/datatypes/mod.rs
index 6a2d0dc..51b33dc 100644
--- a/arrow/src/datatypes/mod.rs
+++ b/arrow/src/datatypes/mod.rs
@@ -36,6 +36,8 @@ mod types;
pub use types::*;
mod datatype;
pub use datatype::*;
+mod ffi;
+pub use ffi::*;
/// A reference-counted reference to a [`Schema`](crate::datatypes::Schema).
pub type SchemaRef = Arc<Schema>;
diff --git a/arrow/src/ffi.rs b/arrow/src/ffi.rs
index b804dd2..e3589ca 100644
--- a/arrow/src/ffi.rs
+++ b/arrow/src/ffi.rs
@@ -77,24 +77,30 @@ To export an array, create an `ArrowArray` using [ArrowArray::try_new].
*/
use std::{
+ convert::TryFrom,
ffi::CStr,
ffi::CString,
iter,
mem::size_of,
+ os::raw::{c_char, c_void},
ptr::{self, NonNull},
sync::Arc,
};
+use bitflags::bitflags;
+
use crate::array::ArrayData;
use crate::buffer::Buffer;
-use crate::datatypes::{DataType, Field, TimeUnit};
+use crate::datatypes::DataType;
use crate::error::{ArrowError, Result};
use crate::util::bit_util;
-#[allow(dead_code)]
-struct SchemaPrivateData {
- field: Field,
- children_ptr: Box<[*mut FFI_ArrowSchema]>,
+bitflags! {
+ /// Typed view of the bits kept in the `flags` member of [`FFI_ArrowSchema`].
+ ///
+ /// NOTE(review): the names follow the `ARROW_FLAG_*` constants of the Arrow C
+ /// data interface; confirm the bit values against the specification before
+ /// adding or changing a flag.
+ pub struct Flags: i64 {
+ /// Bit 0. Presumably marks a dictionary-encoded type whose dictionary
+ /// order is meaningful; verify against the specification.
+ const DICTIONARY_ORDERED = 0b00000001;
+ /// Bit 1. The same bit that `FFI_ArrowSchema::nullable` tests.
+ const NULLABLE = 0b00000010;
+ /// Bit 2. Presumably marks a map type whose keys are sorted; verify
+ /// against the specification.
+ const MAP_KEYS_SORTED = 0b00000100;
+ }
}
/// ABI-compatible struct for `ArrowSchema` from C Data Interface
@@ -103,15 +109,19 @@ struct SchemaPrivateData {
#[repr(C)]
#[derive(Debug)]
pub struct FFI_ArrowSchema {
- format: *const ::std::os::raw::c_char,
- name: *const ::std::os::raw::c_char,
- metadata: *const ::std::os::raw::c_char,
+ format: *const c_char,
+ name: *const c_char,
+ metadata: *const c_char,
flags: i64,
n_children: i64,
children: *mut *mut FFI_ArrowSchema,
dictionary: *mut FFI_ArrowSchema,
- release: ::std::option::Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowSchema)>,
- private_data: *mut ::std::os::raw::c_void,
+ release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowSchema)>,
+ private_data: *mut c_void,
+}
+
+// Rust-side state that an exported [FFI_ArrowSchema] points to through its
+// `private_data` member; `release_schema` takes it back and drops it.
+struct SchemaPrivateData {
+ // Raw pointers to the child schemas. `FFI_ArrowSchema::try_new` obtains
+ // them from `Box::into_raw`, and `release_schema` reclaims each one with
+ // `Box::from_raw`.
+ children: Box<[*mut FFI_ArrowSchema]>,
}
// callback used to drop [FFI_ArrowSchema] when it is exported.
@@ -122,11 +132,16 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
let schema = &mut *schema;
// take ownership back to release it.
- CString::from_raw(schema.format as *mut std::os::raw::c_char);
- CString::from_raw(schema.name as *mut std::os::raw::c_char);
- let private = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
- for child in private.children_ptr.iter() {
- let _ = Box::from_raw(*child);
+ CString::from_raw(schema.format as *mut c_char);
+ if !schema.name.is_null() {
+ CString::from_raw(schema.name as *mut c_char);
+ }
+ if !schema.private_data.is_null() {
+ let private_data = Box::from_raw(schema.private_data as *mut SchemaPrivateData);
+ for child in private_data.children.iter() {
+ drop(Box::from_raw(*child))
+ }
+ drop(private_data);
}
schema.release = None;
@@ -134,54 +149,39 @@ unsafe extern "C" fn release_schema(schema: *mut FFI_ArrowSchema) {
impl FFI_ArrowSchema {
/// create a new [`Ffi_ArrowSchema`]. This fails if the fields' [`DataType`] is not supported.
- fn try_new(field: Field) -> Result<FFI_ArrowSchema> {
- let format = to_format(field.data_type())?;
- let name = field.name().clone();
-
- // allocate (and hold) the children
- let children_vec = match field.data_type() {
- DataType::List(field) => {
- vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)]
- }
- DataType::LargeList(field) => {
- vec![Box::new(FFI_ArrowSchema::try_new(field.as_ref().clone())?)]
- }
- DataType::Struct(fields) => fields
- .iter()
- .map(|field| Ok(Box::new(FFI_ArrowSchema::try_new(field.clone())?)))
- .collect::<Result<Vec<_>>>()?,
- _ => vec![],
- };
- // note: this cannot be done along with the above because the above is fallible and this op leaks.
- let children_ptr = children_vec
+ pub fn try_new(format: &str, children: Vec<FFI_ArrowSchema>) -> Result<Self> {
+ let mut this = Self::empty();
+
+ let mut children_ptr = children
.into_iter()
+ .map(Box::new)
.map(Box::into_raw)
.collect::<Box<_>>();
- let n_children = children_ptr.len() as i64;
- let flags = field.is_nullable() as i64 * 2;
+ this.format = CString::new(format).unwrap().into_raw();
+ this.release = Some(release_schema);
+ this.n_children = children_ptr.len() as i64;
+ this.children = children_ptr.as_mut_ptr();
- let mut private = Box::new(SchemaPrivateData {
- field,
- children_ptr,
+ let private_data = Box::new(SchemaPrivateData {
+ children: children_ptr,
});
+ this.private_data = Box::into_raw(private_data) as *mut c_void;
- // <https://arrow.apache.org/docs/format/CDataInterface.html#c.ArrowSchema>
- Ok(FFI_ArrowSchema {
- format: CString::new(format).unwrap().into_raw(),
- name: CString::new(name).unwrap().into_raw(),
- metadata: std::ptr::null_mut(),
- flags,
- n_children,
- children: private.children_ptr.as_mut_ptr(),
- dictionary: std::ptr::null_mut(),
- release: Some(release_schema),
- private_data: Box::into_raw(private) as *mut ::std::os::raw::c_void,
- })
+ Ok(this)
}
- /// create an empty [FFI_ArrowSchema]
- fn empty() -> Self {
+    /// Sets the name exported by this schema and returns the schema.
+    ///
+    /// The name is copied into a heap-allocated C string whose ownership is
+    /// handed to the schema; `release_schema` frees it again.
+    ///
+    /// Note that a name that was already set is overwritten without being
+    /// freed, so this is meant to be called at most once per schema.
+    ///
+    /// # Errors
+    ///
+    /// Returns [`ArrowError::CDataInterface`] if `name` contains an interior
+    /// NUL byte, because such a name cannot be represented as a C string.
+    pub fn with_name(mut self, name: &str) -> Result<Self> {
+        // A Rust string may legally contain NUL bytes; report that as an error
+        // instead of panicking inside a function that already returns `Result`.
+        let c_name = CString::new(name).map_err(|_| {
+            ArrowError::CDataInterface(format!(
+                "The name {:?} contains a NUL byte and cannot be exported as a C string",
+                name
+            ))
+        })?;
+        self.name = c_name.into_raw();
+        Ok(self)
+    }
+
+    /// Replaces this schema's raw `flags` value with the bits of `flags` and
+    /// returns the schema. Previously stored bits are overwritten, not merged.
+    pub fn with_flags(mut self, flags: Flags) -> Result<Self> {
+        let raw_bits = flags.bits();
+        self.flags = raw_bits;
+        Ok(self)
+    }
+
+ pub fn empty() -> Self {
Self {
format: std::ptr::null_mut(),
name: std::ptr::null_mut(),
@@ -208,15 +208,24 @@ impl FFI_ArrowSchema {
pub fn name(&self) -> &str {
assert!(!self.name.is_null());
// safe because the lifetime of `self.name` equals `self`
- unsafe { CStr::from_ptr(self.name) }.to_str().unwrap()
+ unsafe { CStr::from_ptr(self.name) }
+ .to_str()
+ .expect("The external API has a non-utf8 as name")
+ }
+
+ /// Interprets the raw `flags` member as [`Flags`] via `Flags::from_bits`.
+ ///
+ /// NOTE(review): bitflags documents `from_bits` as returning `None` when bits
+ /// outside the declared flags are set; confirm for the bitflags version in use.
+ pub fn flags(&self) -> Option<Flags> {
+ Flags::from_bits(self.flags)
}
pub fn child(&self, index: usize) -> &Self {
assert!(index < self.n_children as usize);
- assert!(!self.name.is_null());
unsafe { self.children.add(index).as_ref().unwrap().as_ref().unwrap() }
}
+    /// Returns an iterator over the child schemas in index order, from `0`
+    /// up to (excluding) `n_children`, resolving each one through `child`.
+    pub fn children(&self) -> impl Iterator<Item = &Self> {
+        let n_children = self.n_children as usize;
+        (0..n_children).map(move |index| self.child(index))
+    }
+
pub fn nullable(&self) -> bool {
(self.flags / 2) & 1 == 1
}
@@ -231,178 +240,6 @@ impl Drop for FFI_ArrowSchema {
}
}
-/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings
-fn to_field(schema: &FFI_ArrowSchema) -> Result<Field> {
- let data_type = match schema.format() {
- "n" => DataType::Null,
- "b" => DataType::Boolean,
- "c" => DataType::Int8,
- "C" => DataType::UInt8,
- "s" => DataType::Int16,
- "S" => DataType::UInt16,
- "i" => DataType::Int32,
- "I" => DataType::UInt32,
- "l" => DataType::Int64,
- "L" => DataType::UInt64,
- "e" => DataType::Float16,
- "f" => DataType::Float32,
- "g" => DataType::Float64,
- "z" => DataType::Binary,
- "Z" => DataType::LargeBinary,
- "u" => DataType::Utf8,
- "U" => DataType::LargeUtf8,
- "tdD" => DataType::Date32,
- "tdm" => DataType::Date64,
- "tts" => DataType::Time32(TimeUnit::Second),
- "ttm" => DataType::Time32(TimeUnit::Millisecond),
- "ttu" => DataType::Time64(TimeUnit::Microsecond),
- "ttn" => DataType::Time64(TimeUnit::Nanosecond),
- "+l" => {
- let child = schema.child(0);
- DataType::List(Box::new(to_field(child)?))
- }
- "+L" => {
- let child = schema.child(0);
- DataType::LargeList(Box::new(to_field(child)?))
- }
- "+s" => {
- let children = (0..schema.n_children as usize)
- .map(|x| to_field(schema.child(x)))
- .collect::<Result<Vec<_>>>()?;
- DataType::Struct(children)
- }
- // Parametrized types, requiring string parse
- other => {
- match other.splitn(2, ':').collect::<Vec<&str>>().as_slice() {
- // Decimal types in format "d:precision,scale" or "d:precision,scale,bitWidth"
- ["d", extra] => {
- match extra.splitn(3, ',').collect::<Vec<&str>>().as_slice() {
- [precision, scale] => {
- let parsed_precision = precision.parse::<usize>().map_err(|_| {
- ArrowError::CDataInterface(
- "The decimal type requires an integer precision".to_string(),
- )
- })?;
- let parsed_scale = scale.parse::<usize>().map_err(|_| {
- ArrowError::CDataInterface(
- "The decimal type requires an integer scale".to_string(),
- )
- })?;
- DataType::Decimal(parsed_precision, parsed_scale)
- },
- [precision, scale, bits] => {
- if *bits != "128" {
- return Err(ArrowError::CDataInterface("Only 128 bit wide decimal is supported in the Rust implementation".to_string()));
- }
- let parsed_precision = precision.parse::<usize>().map_err(|_| {
- ArrowError::CDataInterface(
- "The decimal type requires an integer precision".to_string(),
- )
- })?;
- let parsed_scale = scale.parse::<usize>().map_err(|_| {
- ArrowError::CDataInterface(
- "The decimal type requires an integer scale".to_string(),
- )
- })?;
- DataType::Decimal(parsed_precision, parsed_scale)
- }
- _ => {
- return Err(ArrowError::CDataInterface(format!(
- "The decimal pattern \"d:{:?}\" is not supported in the Rust implementation",
- extra
- )))
- }
- }
- }
-
- // Timestamps in format "tts:" and "tts:America/New_York" for no timezones and timezones resp.
- ["tss", ""] => DataType::Timestamp(TimeUnit::Second, None),
- ["tsm", ""] => DataType::Timestamp(TimeUnit::Millisecond, None),
- ["tsu", ""] => DataType::Timestamp(TimeUnit::Microsecond, None),
- ["tsn", ""] => DataType::Timestamp(TimeUnit::Nanosecond, None),
- ["tss", tz] => {
- DataType::Timestamp(TimeUnit::Second, Some(tz.to_string()))
- }
- ["tsm", tz] => {
- DataType::Timestamp(TimeUnit::Millisecond, Some(tz.to_string()))
- }
- ["tsu", tz] => {
- DataType::Timestamp(TimeUnit::Microsecond, Some(tz.to_string()))
- }
- ["tsn", tz] => {
- DataType::Timestamp(TimeUnit::Nanosecond, Some(tz.to_string()))
- }
-
- _ => {
- return Err(ArrowError::CDataInterface(format!(
- "The datatype \"{:?}\" is still not supported in Rust implementation",
- other
- )))
- }
- }
- }
- };
- Ok(Field::new(schema.name(), data_type, schema.nullable()))
-}
-
-/// See https://arrow.apache.org/docs/format/CDataInterface.html#data-type-description-format-strings
-fn to_format(data_type: &DataType) -> Result<String> {
- Ok(match data_type {
- DataType::Null => "n",
- DataType::Boolean => "b",
- DataType::Int8 => "c",
- DataType::UInt8 => "C",
- DataType::Int16 => "s",
- DataType::UInt16 => "S",
- DataType::Int32 => "i",
- DataType::UInt32 => "I",
- DataType::Int64 => "l",
- DataType::UInt64 => "L",
- DataType::Float16 => "e",
- DataType::Float32 => "f",
- DataType::Float64 => "g",
- DataType::Binary => "z",
- DataType::LargeBinary => "Z",
- DataType::Utf8 => "u",
- DataType::LargeUtf8 => "U",
- DataType::Decimal(precision, scale) => {
- return Ok(format!("d:{},{}", precision, scale))
- }
- DataType::Date32 => "tdD",
- DataType::Date64 => "tdm",
- DataType::Time32(TimeUnit::Second) => "tts",
- DataType::Time32(TimeUnit::Millisecond) => "ttm",
- DataType::Time64(TimeUnit::Microsecond) => "ttu",
- DataType::Time64(TimeUnit::Nanosecond) => "ttn",
- DataType::Timestamp(TimeUnit::Second, None) => "tss:",
- DataType::Timestamp(TimeUnit::Millisecond, None) => "tsm:",
- DataType::Timestamp(TimeUnit::Microsecond, None) => "tsu:",
- DataType::Timestamp(TimeUnit::Nanosecond, None) => "tsn:",
- DataType::Timestamp(TimeUnit::Second, Some(tz)) => {
- return Ok(format!("tss:{}", tz))
- }
- DataType::Timestamp(TimeUnit::Millisecond, Some(tz)) => {
- return Ok(format!("tsm:{}", tz))
- }
- DataType::Timestamp(TimeUnit::Microsecond, Some(tz)) => {
- return Ok(format!("tsu:{}", tz))
- }
- DataType::Timestamp(TimeUnit::Nanosecond, Some(tz)) => {
- return Ok(format!("tsn:{}", tz))
- }
- DataType::List(_) => "+l",
- DataType::LargeList(_) => "+L",
- DataType::Struct(_) => "+s",
- z => {
- return Err(ArrowError::CDataInterface(format!(
- "The datatype \"{:?}\" is still not supported in Rust implementation",
- z
- )))
- }
- }
- .to_string())
-}
-
// returns the number of bits that buffer `i` (in the C data interface) is expected to have.
// This is set by the Arrow specification
fn bit_width(data_type: &DataType, i: usize) -> Result<usize> {
@@ -482,16 +319,16 @@ pub struct FFI_ArrowArray {
pub(crate) offset: i64,
pub(crate) n_buffers: i64,
pub(crate) n_children: i64,
- pub(crate) buffers: *mut *const ::std::os::raw::c_void,
+ pub(crate) buffers: *mut *const c_void,
children: *mut *mut FFI_ArrowArray,
dictionary: *mut FFI_ArrowArray,
- release: ::std::option::Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowArray)>,
+ release: Option<unsafe extern "C" fn(arg1: *mut FFI_ArrowArray)>,
// When exported, this MUST contain everything that is owned by this array.
- // for example, any buffer pointed to in `buffers` must be here, as well as the `buffers` pointer
- // itself.
- // In other words, everything in [FFI_ArrowArray] must be owned by `private_data` and can assume
- // that they do not outlive `private_data`.
- private_data: *mut ::std::os::raw::c_void,
+ // for example, any buffer pointed to in `buffers` must be here, as well
+ // as the `buffers` pointer itself.
+ // In other words, everything in [FFI_ArrowArray] must be owned by
+ // `private_data` and can assume that they do not outlive `private_data`.
+ private_data: *mut c_void,
}
impl Drop for FFI_ArrowArray {
@@ -511,7 +348,7 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) {
let array = &mut *array;
// take ownership of `private_data`, therefore dropping it`
- let private = Box::from_raw(array.private_data as *mut PrivateData);
+ let private = Box::from_raw(array.private_data as *mut ArrayPrivateData);
for child in private.children.iter() {
let _ = Box::from_raw(*child);
}
@@ -519,9 +356,9 @@ unsafe extern "C" fn release_array(array: *mut FFI_ArrowArray) {
array.release = None;
}
-struct PrivateData {
+struct ArrayPrivateData {
buffers: Vec<Option<Buffer>>,
- buffers_ptr: Box<[*const std::os::raw::c_void]>,
+ buffers_ptr: Box<[*const c_void]>,
children: Box<[*mut FFI_ArrowArray]>,
}
@@ -542,7 +379,7 @@ impl FFI_ArrowArray {
.iter()
.map(|maybe_buffer| match maybe_buffer {
// note that `raw_data` takes into account the buffer's offset
- Some(b) => b.as_ptr() as *const std::os::raw::c_void,
+ Some(b) => b.as_ptr() as *const c_void,
None => std::ptr::null(),
})
.collect::<Box<[_]>>();
@@ -556,7 +393,7 @@ impl FFI_ArrowArray {
// create the private data owning everything.
// any other data must be added here, e.g. via a struct, to track lifetime.
- let mut private_data = Box::new(PrivateData {
+ let mut private_data = Box::new(ArrayPrivateData {
buffers,
buffers_ptr,
children,
@@ -572,7 +409,7 @@ impl FFI_ArrowArray {
children: private_data.children.as_mut_ptr(),
dictionary: std::ptr::null_mut(),
release: Some(release_array),
- private_data: Box::into_raw(private_data) as *mut ::std::os::raw::c_void,
+ private_data: Box::into_raw(private_data) as *mut c_void,
}
}
@@ -814,7 +651,7 @@ pub struct ArrowArrayChild<'a> {
impl ArrowArrayRef for ArrowArray {
/// the data_type as declared in the schema
fn data_type(&self) -> Result<DataType> {
- to_field(&self.schema).map(|x| x.data_type().clone())
+ DataType::try_from(self.schema.as_ref())
}
fn array(&self) -> &FFI_ArrowArray {
@@ -833,7 +670,7 @@ impl ArrowArrayRef for ArrowArray {
impl<'a> ArrowArrayRef for ArrowArrayChild<'a> {
/// the data_type as declared in the schema
fn data_type(&self) -> Result<DataType> {
- to_field(self.schema).map(|x| x.data_type().clone())
+ DataType::try_from(self.schema)
}
fn array(&self) -> &FFI_ArrowArray {
@@ -855,10 +692,8 @@ impl ArrowArray {
/// See safety of [ArrowArray]
#[allow(clippy::too_many_arguments)]
pub unsafe fn try_new(data: ArrayData) -> Result<Self> {
- let field = Field::new("", data.data_type().clone(), data.null_count() != 0);
let array = Arc::new(FFI_ArrowArray::new(&data));
- let schema = Arc::new(FFI_ArrowSchema::try_new(field)?);
-
+ let schema = Arc::new(FFI_ArrowSchema::try_from(data.data_type())?);
Ok(ArrowArray { array, schema })
}