You are viewing a plain-text version of this content; the link to its canonical version was not preserved in this text.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/04 15:11:39 UTC
[arrow-rs] branch master updated: Raise TypeError on PyArrow import (#4316)
This is an automated email from the ASF dual-hosted git repository.
tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git
The following commit(s) were added to refs/heads/master by this push:
new d8d5fca51 Raise TypeError on PyArrow import (#4316)
d8d5fca51 is described below
commit d8d5fca516a8f947d30b4e0d854710ee691a96b7
Author: Will Jones <wi...@gmail.com>
AuthorDate: Sun Jun 4 08:11:33 2023 -0700
Raise TypeError on PyArrow import (#4316)
* type error on PyArrow import
* fix error message
---
.../tests/test_sql.py | 22 ++++++++++++++++
arrow/src/pyarrow.rs | 30 +++++++++++++++++++++-
2 files changed, 51 insertions(+), 1 deletion(-)
diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index f631f67cb..a7c6b34a4 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -408,3 +408,25 @@ def test_record_batch_reader():
assert b.schema == schema
got_batches = list(b)
assert got_batches == batches
+
+def test_reject_other_classes():
+ """Each Rust round-trip entry point must raise TypeError naming the expected and the actual class."""
+ not_pyarrow = ["hello"]  # arbitrary value that is not a PyArrow type
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"):  # match is a regex search on str(exc); "." matches any char
+ rust.round_trip_array(not_pyarrow)
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"):
+ rust.round_trip_schema(not_pyarrow)
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"):
+ rust.round_trip_field(not_pyarrow)
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"):
+ rust.round_trip_type(not_pyarrow)
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"):
+ rust.round_trip_record_batch(not_pyarrow)
+
+ with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"):
+ rust.round_trip_record_batch_reader(not_pyarrow)
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index ba8d606f2..98e27ab30 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -24,7 +24,7 @@ use std::convert::{From, TryFrom};
use std::ptr::{addr_of, addr_of_mut};
use std::sync::Arc;
-use pyo3::exceptions::PyValueError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
use pyo3::ffi::Py_uintptr_t;
use pyo3::import_exception;
use pyo3::prelude::*;
@@ -67,8 +67,27 @@ impl<T: ToPyArrow> IntoPyArrow for T {
}
}
+fn validate_class(expected: &str, value: &PyAny) -> PyResult<()> { // Returns Ok(()) if `value` is an instance of `pyarrow.<expected>`, otherwise a TypeError naming both classes.
+ let pyarrow = PyModule::import(value.py(), "pyarrow")?; // an import failure is propagated to the caller via `?`
+ let class = pyarrow.getattr(expected)?; // `expected` must name a pyarrow attribute, e.g. "Array"
+ if !value.is_instance(class)? { // isinstance semantics: instances of subclasses also pass
+ let expected_module = class.getattr("__module__")?.extract::<&str>()?; // qualified names for the message, e.g. pyarrow.lib.Array
+ let expected_name = class.getattr("__name__")?.extract::<&str>()?;
+ let found_class = value.get_type(); // type(value), e.g. builtins.list
+ let found_module = found_class.getattr("__module__")?.extract::<&str>()?;
+ let found_name = found_class.getattr("__name__")?.extract::<&str>()?;
+ return Err(PyTypeError::new_err(format!( // TypeError rather than ValueError: the argument has the wrong type
+ "Expected instance of {}.{}, got {}.{}",
+ expected_module, expected_name, found_module, found_name
+ )));
+ }
+ Ok(())
+}
+
impl FromPyArrow for DataType {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("DataType", value)?;
+
let c_schema = FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -91,6 +110,8 @@ impl ToPyArrow for DataType {
impl FromPyArrow for Field {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("Field", value)?;
+
let c_schema = FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -113,6 +134,8 @@ impl ToPyArrow for Field {
impl FromPyArrow for Schema {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("Schema", value)?;
+
let c_schema = FFI_ArrowSchema::empty();
let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -135,6 +158,8 @@ impl ToPyArrow for Schema {
impl FromPyArrow for ArrayData {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("Array", value)?;
+
// prepare a pointer to receive the Array struct
let mut array = FFI_ArrowArray::empty();
let mut schema = FFI_ArrowSchema::empty();
@@ -194,6 +219,7 @@ impl<T: ToPyArrow> ToPyArrow for Vec<T> {
impl FromPyArrow for RecordBatch {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("RecordBatch", value)?;
// TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches
let schema = value.getattr("schema")?;
let schema = Arc::new(Schema::from_pyarrow(schema)?);
@@ -235,6 +261,8 @@ impl ToPyArrow for RecordBatch {
impl FromPyArrow for ArrowArrayStreamReader {
fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+ validate_class("RecordBatchReader", value)?;
+
// prepare a pointer to receive the stream struct
let mut stream = FFI_ArrowArrayStream::empty();
let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;