You are viewing a plain text version of this content. The canonical link for it is here (the hyperlink is not preserved in this plain-text copy).
Posted to commits@arrow.apache.org by we...@apache.org on 2019/06/22 20:10:29 UTC

[arrow] branch master updated: ARROW-5169: [Python] preserve field nullability of specified schema in Table.from_pandas

This is an automated email from the ASF dual-hosted git repository.

wesm pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new a566bc2  ARROW-5169: [Python] preserve field nullability of specified schema in Table.from_pandas
a566bc2 is described below

commit a566bc21bf7b5113c76b85a5bbe4bbea13411f9b
Author: Joris Van den Bossche <jo...@gmail.com>
AuthorDate: Sat Jun 22 15:10:21 2019 -0500

    ARROW-5169: [Python] preserve field nullability of specified schema in Table.from_pandas
    
    Author: Joris Van den Bossche <jo...@gmail.com>
    
    Closes #4397 from jorisvandenbossche/ARROW-5169-table-from-pandas-schema-nullability and squashes the following commits:
    
    f4b9c71b5 <Joris Van den Bossche> Merge remote-tracking branch 'upstream/master' into ARROW-5169-table-from-pandas-schema-nullability
    4baef2a3c <Joris Van den Bossche> update return comment
    496f731ae <Joris Van den Bossche> Merge remote-tracking branch 'upstream/master' into ARROW-5169-table-from-pandas-schema-nullability
    dbdac8cbc <Joris Van den Bossche> Merge remote-tracking branch 'upstream/master' into ARROW-5169-table-from-pandas-schema-nullability
    d693d4645 <Joris Van den Bossche> fix case of None as column name
    d5b322431 <Joris Van den Bossche> Merge remote-tracking branch 'upstream/master' into ARROW-5169-table-from-pandas-schema-nullability
    fc357a642 <Joris Van den Bossche> ARROW-5169:  preserve field nullability of specified schema in Table.from_pandas
---
 python/pyarrow/pandas_compat.py     | 24 +++++++++++++++++++++---
 python/pyarrow/table.pxi            |  8 ++++----
 python/pyarrow/tests/test_pandas.py | 14 ++++++++++++++
 3 files changed, 39 insertions(+), 7 deletions(-)

diff --git a/python/pyarrow/pandas_compat.py b/python/pyarrow/pandas_compat.py
index ea38d41..50cabb8 100644
--- a/python/pyarrow/pandas_compat.py
+++ b/python/pyarrow/pandas_compat.py
@@ -374,6 +374,7 @@ def _get_columns_to_convert(df, schema, preserve_index, columns):
     # all_names : all of the columns in the resulting table including the data
     # columns and serialized index columns
     # column_names : the names of the data columns
+    # index_column_names : the names of the serialized index columns
     # index_descriptors : descriptions of each index to be used for
     # reconstruction
     # index_levels : the extracted index level values
@@ -381,8 +382,8 @@ def _get_columns_to_convert(df, schema, preserve_index, columns):
     # to be converted to Arrow format
     # columns_types : specified column types to use for coercion / casting
     # during serialization, if a Schema was provided
-    return (all_names, column_names, index_descriptors, index_levels,
-            columns_to_convert, convert_types)
+    return (all_names, column_names, index_column_names, index_descriptors,
+            index_levels, columns_to_convert, convert_types)
 
 
 def _get_range_index_descriptor(level):
@@ -418,6 +419,7 @@ def _resolve_columns_of_interest(df, schema, columns):
 def dataframe_to_types(df, preserve_index, columns=None):
     (all_names,
      column_names,
+     _,
      index_descriptors,
      index_columns,
      columns_to_convert,
@@ -446,6 +448,7 @@ def dataframe_to_arrays(df, schema, preserve_index, nthreads=1, columns=None,
                         safe=True):
     (all_names,
      column_names,
+     index_column_names,
      index_descriptors,
      index_columns,
      columns_to_convert,
@@ -485,10 +488,25 @@ def dataframe_to_arrays(df, schema, preserve_index, nthreads=1, columns=None,
 
     types = [x.type for x in arrays]
 
+    if schema is not None:
+        # add index columns
+        index_types = types[len(column_names):]
+        for name, type_ in zip(index_column_names, index_types):
+            name = name if name is not None else 'None'
+            schema = schema.append(pa.field(name, type_))
+    else:
+        fields = []
+        for name, type_ in zip(all_names, types):
+            name = name if name is not None else 'None'
+            fields.append(pa.field(name, type_))
+        schema = pa.schema(fields)
+
     metadata = construct_metadata(df, column_names, index_columns,
                                   index_descriptors, preserve_index,
                                   types)
-    return all_names, arrays, metadata
+    schema = schema.add_metadata(metadata)
+
+    return arrays, schema
 
 
 def get_datetimetz_type(values, dtype, type_):
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index db26dc2..688050b 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -861,10 +861,10 @@ cdef class RecordBatch(_PandasConvertible):
         pyarrow.RecordBatch
         """
         from pyarrow.pandas_compat import dataframe_to_arrays
-        names, arrays, metadata = dataframe_to_arrays(
+        arrays, schema = dataframe_to_arrays(
             df, schema, preserve_index, nthreads=nthreads, columns=columns
         )
-        return cls.from_arrays(arrays, names, metadata)
+        return cls.from_arrays(arrays, schema)
 
     @staticmethod
     def from_arrays(list arrays, names, metadata=None):
@@ -1142,7 +1142,7 @@ cdef class Table(_PandasConvertible):
         <pyarrow.lib.Table object at 0x7f05d1fb1b40>
         """
         from pyarrow.pandas_compat import dataframe_to_arrays
-        names, arrays, metadata = dataframe_to_arrays(
+        arrays, schema = dataframe_to_arrays(
             df,
             schema=schema,
             preserve_index=preserve_index,
@@ -1150,7 +1150,7 @@ cdef class Table(_PandasConvertible):
             columns=columns,
             safe=safe
         )
-        return cls.from_arrays(arrays, names=names, metadata=metadata)
+        return cls.from_arrays(arrays, schema=schema)
 
     @staticmethod
     def from_arrays(arrays, names=None, schema=None, metadata=None):
diff --git a/python/pyarrow/tests/test_pandas.py b/python/pyarrow/tests/test_pandas.py
index 4af3708..5ea3d19 100644
--- a/python/pyarrow/tests/test_pandas.py
+++ b/python/pyarrow/tests/test_pandas.py
@@ -2607,6 +2607,20 @@ def test_table_from_pandas_columns_and_schema_are_mutually_exclusive():
         pa.Table.from_pandas(df, schema=schema, columns=columns)
 
 
+def test_table_from_pandas_keeps_schema_nullability():
+    # ARROW-5169
+    df = pd.DataFrame({'a': [1, 2, 3, 4]})
+
+    schema = pa.schema([
+        pa.field('a', pa.int64(), nullable=False),
+    ])
+
+    table = pa.Table.from_pandas(df)
+    assert table.schema.field_by_name('a').nullable is True
+    table = pa.Table.from_pandas(df, schema=schema)
+    assert table.schema.field_by_name('a').nullable is False
+
+
 # ----------------------------------------------------------------------
 # RecordBatch, Table