You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by tu...@apache.org on 2023/06/04 15:11:39 UTC

[arrow-rs] branch master updated: Raise TypeError on PyArrow import (#4316)

This is an automated email from the ASF dual-hosted git repository.

tustvold pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-rs.git


The following commit(s) were added to refs/heads/master by this push:
     new d8d5fca51 Raise TypeError on PyArrow import (#4316)
d8d5fca51 is described below

commit d8d5fca516a8f947d30b4e0d854710ee691a96b7
Author: Will Jones <wi...@gmail.com>
AuthorDate: Sun Jun 4 08:11:33 2023 -0700

    Raise TypeError on PyArrow import (#4316)
    
    * type error on PyArrow import
    
    * fix error message
---
 .../tests/test_sql.py                              | 22 ++++++++++++++++
 arrow/src/pyarrow.rs                               | 30 +++++++++++++++++++++-
 2 files changed, 51 insertions(+), 1 deletion(-)

diff --git a/arrow-pyarrow-integration-testing/tests/test_sql.py b/arrow-pyarrow-integration-testing/tests/test_sql.py
index f631f67cb..a7c6b34a4 100644
--- a/arrow-pyarrow-integration-testing/tests/test_sql.py
+++ b/arrow-pyarrow-integration-testing/tests/test_sql.py
@@ -408,3 +408,25 @@ def test_record_batch_reader():
     assert b.schema == schema
     got_batches = list(b)
     assert got_batches == batches
+
+def test_reject_other_classes():
+    # Arbitrary type that is not a PyArrow type
+    not_pyarrow = ["hello"]
+
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Array, got builtins.list"):
+        rust.round_trip_array(not_pyarrow)
+    
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Schema, got builtins.list"):
+        rust.round_trip_schema(not_pyarrow)
+    
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.Field, got builtins.list"):
+        rust.round_trip_field(not_pyarrow)
+    
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.DataType, got builtins.list"):
+        rust.round_trip_type(not_pyarrow)
+
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatch, got builtins.list"):
+        rust.round_trip_record_batch(not_pyarrow)
+    
+    with pytest.raises(TypeError, match="Expected instance of pyarrow.lib.RecordBatchReader, got builtins.list"):
+        rust.round_trip_record_batch_reader(not_pyarrow)
diff --git a/arrow/src/pyarrow.rs b/arrow/src/pyarrow.rs
index ba8d606f2..98e27ab30 100644
--- a/arrow/src/pyarrow.rs
+++ b/arrow/src/pyarrow.rs
@@ -24,7 +24,7 @@ use std::convert::{From, TryFrom};
 use std::ptr::{addr_of, addr_of_mut};
 use std::sync::Arc;
 
-use pyo3::exceptions::PyValueError;
+use pyo3::exceptions::{PyTypeError, PyValueError};
 use pyo3::ffi::Py_uintptr_t;
 use pyo3::import_exception;
 use pyo3::prelude::*;
@@ -67,8 +67,27 @@ impl<T: ToPyArrow> IntoPyArrow for T {
     }
 }
 
+fn validate_class(expected: &str, value: &PyAny) -> PyResult<()> {
+    let pyarrow = PyModule::import(value.py(), "pyarrow")?;
+    let class = pyarrow.getattr(expected)?;
+    if !value.is_instance(class)? {
+        let expected_module = class.getattr("__module__")?.extract::<&str>()?;
+        let expected_name = class.getattr("__name__")?.extract::<&str>()?;
+        let found_class = value.get_type();
+        let found_module = found_class.getattr("__module__")?.extract::<&str>()?;
+        let found_name = found_class.getattr("__name__")?.extract::<&str>()?;
+        return Err(PyTypeError::new_err(format!(
+            "Expected instance of {}.{}, got {}.{}",
+            expected_module, expected_name, found_module, found_name
+        )));
+    }
+    Ok(())
+}
+
 impl FromPyArrow for DataType {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("DataType", value)?;
+
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
         value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -91,6 +110,8 @@ impl ToPyArrow for DataType {
 
 impl FromPyArrow for Field {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("Field", value)?;
+
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
         value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -113,6 +134,8 @@ impl ToPyArrow for Field {
 
 impl FromPyArrow for Schema {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("Schema", value)?;
+
         let c_schema = FFI_ArrowSchema::empty();
         let c_schema_ptr = &c_schema as *const FFI_ArrowSchema;
         value.call_method1("_export_to_c", (c_schema_ptr as Py_uintptr_t,))?;
@@ -135,6 +158,8 @@ impl ToPyArrow for Schema {
 
 impl FromPyArrow for ArrayData {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("Array", value)?;
+
         // prepare a pointer to receive the Array struct
         let mut array = FFI_ArrowArray::empty();
         let mut schema = FFI_ArrowSchema::empty();
@@ -194,6 +219,7 @@ impl<T: ToPyArrow> ToPyArrow for Vec<T> {
 
 impl FromPyArrow for RecordBatch {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("RecordBatch", value)?;
         // TODO(kszucs): implement the FFI conversions in arrow-rs for RecordBatches
         let schema = value.getattr("schema")?;
         let schema = Arc::new(Schema::from_pyarrow(schema)?);
@@ -235,6 +261,8 @@ impl ToPyArrow for RecordBatch {
 
 impl FromPyArrow for ArrowArrayStreamReader {
     fn from_pyarrow(value: &PyAny) -> PyResult<Self> {
+        validate_class("RecordBatchReader", value)?;
+
         // prepare a pointer to receive the stream struct
         let mut stream = FFI_ArrowArrayStream::empty();
         let stream_ptr = &mut stream as *mut FFI_ArrowArrayStream;