You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2017/09/10 06:35:09 UTC

arrow git commit: ARROW-1359: [Python] Add flavor='spark' option to parquet.write_table that sanitizes schema field names

Repository: arrow
Updated Branches:
  refs/heads/master 947ca871c -> 4a6a6cb47


ARROW-1359: [Python] Add flavor='spark' option to parquet.write_table that sanitizes schema field names

I also made the default for `use_deprecated_int96_timestamps` None so that we can distinguish between unspecified and explicitly False. If the user passes `flavor='spark'` and leaves this option unspecified, INT96 timestamps are enabled. Once Spark handles the INT96 deprecation, we can remove this special case.

Author: Wes McKinney <we...@twosigma.com>

Closes #1076 from wesm/ARROW-1359 and squashes the following commits:

8a60b66 [Wes McKinney] Use composition rather than inheritance
e3fa8ec [Wes McKinney] Add note about spark flavor to Sphinx docs
8159a51 [Wes McKinney] Add flavor='spark' option to write_parquet that sanitizes schema field names, turns on int96 timestamps


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/4a6a6cb4
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/4a6a6cb4
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/4a6a6cb4

Branch: refs/heads/master
Commit: 4a6a6cb47dcf832213f1fb31f2325ad10a3864bd
Parents: 947ca87
Author: Wes McKinney <we...@twosigma.com>
Authored: Sun Sep 10 08:34:43 2017 +0200
Committer: Uwe L. Korn <uw...@xhochy.com>
Committed: Sun Sep 10 08:34:43 2017 +0200

----------------------------------------------------------------------
 python/doc/source/parquet.rst        |  7 +++
 python/pyarrow/parquet.py            | 96 ++++++++++++++++++++++++++++---
 python/pyarrow/table.pxi             | 31 +++++++---
 python/pyarrow/tests/test_parquet.py | 28 ++++++++-
 python/pyarrow/types.pxi             |  5 +-
 5 files changed, 149 insertions(+), 18 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/doc/source/parquet.rst
----------------------------------------------------------------------
diff --git a/python/doc/source/parquet.rst b/python/doc/source/parquet.rst
index 7626c15..d466ba1 100644
--- a/python/doc/source/parquet.rst
+++ b/python/doc/source/parquet.rst
@@ -217,6 +217,13 @@ such as those produced by Hive:
    dataset = pq.ParquetDataset('dataset_name/')
    table = dataset.read()
 
+Using with Spark
+----------------
+
+Spark places some constraints on the types of Parquet files it will read. The
+option ``flavor='spark'`` sets the required writer options automatically and
+also replaces characters in field names that Spark SQL does not support.
+
 Multithreaded Reads
 -------------------
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/parquet.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/parquet.py b/python/pyarrow/parquet.py
index 568aad4..1584b84 100644
--- a/python/pyarrow/parquet.py
+++ b/python/pyarrow/parquet.py
@@ -18,17 +18,17 @@
 import os
 import inspect
 import json
-
+import re
 import six
 
 import numpy as np
 
 from pyarrow.filesystem import FileSystem, LocalFileSystem, S3FSWrapper
 from pyarrow._parquet import (ParquetReader, FileMetaData,  # noqa
-                              RowGroupMetaData, ParquetSchema,
-                              ParquetWriter)
+                              RowGroupMetaData, ParquetSchema)
 import pyarrow._parquet as _parquet  # noqa
 import pyarrow.lib as lib
+import pyarrow as pa
 
 
 # ----------------------------------------------------------------------
@@ -164,6 +164,73 @@ class ParquetFile(object):
         return indices
 
 
+_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')
+
+
+def _sanitized_spark_field_name(name):
+    return _SPARK_DISALLOWED_CHARS.sub('_', name)
+
+
+def _sanitize_schema(schema, flavor):
+    if 'spark' in flavor:
+        sanitized_fields = []
+
+        schema_changed = False
+
+        for field in schema:
+            name = field.name
+            sanitized_name = _sanitized_spark_field_name(name)
+
+            if sanitized_name != name:
+                schema_changed = True
+                sanitized_field = pa.field(sanitized_name, field.type,
+                                           field.nullable, field.metadata)
+                sanitized_fields.append(sanitized_field)
+            else:
+                sanitized_fields.append(field)
+        return pa.schema(sanitized_fields), schema_changed
+    else:
+        return schema, False
+
+
+def _sanitize_table(table, new_schema, flavor):
+    # TODO: This will not handle prohibited characters in nested field names
+    if 'spark' in flavor:
+        column_data = [table[i].data for i in range(table.num_columns)]
+        return pa.Table.from_arrays(column_data, schema=new_schema)
+    else:
+        return table
+
+
+class ParquetWriter(object):
+    """
+
+    Parameters
+    ----------
+    where
+    schema
+    flavor : {'spark', ...}
+        Set options for compatibility with a particular reader
+    """
+    def __init__(self, where, schema, flavor=None, **options):
+        self.flavor = flavor
+        if flavor is not None:
+            schema, self.schema_changed = _sanitize_schema(schema, flavor)
+        else:
+            self.schema_changed = False
+
+        self.schema = schema
+        self.writer = _parquet.ParquetWriter(where, schema, **options)
+
+    def write_table(self, table, row_group_size=None):
+        if self.schema_changed:
+            table = _sanitize_table(table, self.schema, self.flavor)
+        self.writer.write_table(table, row_group_size=row_group_size)
+
+    def close(self):
+        self.writer.close()
+
+
 def _get_pandas_index_columns(keyvalues):
     return (json.loads(keyvalues[b'pandas'].decode('utf8'))
             ['index_columns'])
@@ -787,8 +854,9 @@ def read_pandas(source, columns=None, nthreads=1, metadata=None):
 
 def write_table(table, where, row_group_size=None, version='1.0',
                 use_dictionary=True, compression='snappy',
-                use_deprecated_int96_timestamps=False,
-                coerce_timestamps=None, **kwargs):
+                use_deprecated_int96_timestamps=None,
+                coerce_timestamps=None,
+                flavor=None, **kwargs):
     """
     Write a Table to Parquet format
 
@@ -804,15 +872,26 @@ def write_table(table, where, row_group_size=None, version='1.0',
     use_dictionary : bool or list
         Specify if we should use dictionary encoding in general or only for
         some columns.
-    use_deprecated_int96_timestamps : boolean, default False
-        Write nanosecond resolution timestamps to INT96 Parquet format
+    use_deprecated_int96_timestamps : boolean, default None
+        Write nanosecond resolution timestamps to INT96 Parquet
+        format. Defaults to False unless enabled by flavor argument
     coerce_timestamps : string, default None
         Cast timestamps a particular resolution.
         Valid values: {None, 'ms', 'us'}
     compression : str or dict
         Specify the compression codec, either on a general basis or per-column.
+    flavor : {'spark'}, default None
+        Sanitize schema or set other compatibility options for compatibility
     """
     row_group_size = kwargs.get('chunk_size', row_group_size)
+
+    if use_deprecated_int96_timestamps is None:
+        # Use int96 timestamps for Spark
+        if flavor is not None and 'spark' in flavor:
+            use_deprecated_int96_timestamps = True
+        else:
+            use_deprecated_int96_timestamps = False
+
     options = dict(
         use_dictionary=use_dictionary,
         compression=compression,
@@ -822,7 +901,8 @@ def write_table(table, where, row_group_size=None, version='1.0',
 
     writer = None
     try:
-        writer = ParquetWriter(where, table.schema, **options)
+        writer = ParquetWriter(where, table.schema, flavor=flavor,
+                               **options)
         writer.write_table(table, row_group_size=row_group_size)
     except:
         if writer is not None:

http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/table.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index fc6099f..68eb5cb 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -758,7 +758,7 @@ cdef class Table:
         return cls.from_arrays(arrays, names=names, metadata=metadata)
 
     @staticmethod
-    def from_arrays(arrays, names=None, dict metadata=None):
+    def from_arrays(arrays, names=None, schema=None, dict metadata=None):
         """
         Construct a Table from Arrow arrays or columns
 
@@ -777,11 +777,22 @@ cdef class Table:
         """
         cdef:
             vector[shared_ptr[CColumn]] columns
-            shared_ptr[CSchema] schema
+            Schema cy_schema
+            shared_ptr[CSchema] c_schema
             shared_ptr[CTable] table
             int i, K = <int> len(arrays)
 
-        _schema_from_arrays(arrays, names, metadata, &schema)
+        if schema is None:
+            _schema_from_arrays(arrays, names, metadata, &c_schema)
+        elif schema is not None:
+            if names is not None:
+                raise ValueError('Cannot pass schema and arrays')
+            cy_schema = schema
+
+            if len(schema) != len(arrays):
+                raise ValueError('Schema and number of arrays unequal')
+
+            c_schema = cy_schema.sp_schema
 
         columns.reserve(K)
 
@@ -789,23 +800,29 @@ cdef class Table:
             if isinstance(arrays[i], Array):
                 columns.push_back(
                     make_shared[CColumn](
-                        schema.get().field(i),
+                        c_schema.get().field(i),
                         (<Array> arrays[i]).sp_array
                     )
                 )
             elif isinstance(arrays[i], ChunkedArray):
                 columns.push_back(
                     make_shared[CColumn](
-                        schema.get().field(i),
+                        c_schema.get().field(i),
                         (<ChunkedArray> arrays[i]).sp_chunked_array
                     )
                 )
             elif isinstance(arrays[i], Column):
-                columns.push_back((<Column> arrays[i]).sp_column)
+                # Make sure schema field and column are consistent
+                columns.push_back(
+                    make_shared[CColumn](
+                        c_schema.get().field(i),
+                        (<Column> arrays[i]).sp_column.get().data()
+                    )
+                )
             else:
                 raise ValueError(type(arrays[i]))
 
-        table.reset(new CTable(schema, columns))
+        table.reset(new CTable(c_schema, columns))
         return pyarrow_wrap_table(table)
 
     @staticmethod

http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/tests/test_parquet.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
index 5dfe0a5..9b5a4bc 100644
--- a/python/pyarrow/tests/test_parquet.py
+++ b/python/pyarrow/tests/test_parquet.py
@@ -17,6 +17,7 @@
 
 from os.path import join as pjoin
 import datetime
+import gc
 import io
 import os
 import json
@@ -562,6 +563,10 @@ def test_date_time_types():
     _check_roundtrip(table, expected=expected, version='2.0',
                      use_deprecated_int96_timestamps=True)
 
+    # Check that setting flavor to 'spark' uses int96 timestamps
+    _check_roundtrip(table, expected=expected, version='2.0',
+                     flavor='spark')
+
     # Unsupported stuff
     def _assert_unsupported(array):
         table = pa.Table.from_arrays([array], ['unsupported'])
@@ -577,6 +582,18 @@ def test_date_time_types():
 
 
 @parquet
+def test_sanitized_spark_field_names():
+    a0 = pa.array([0, 1, 2, 3, 4])
+    name = 'prohib; ,\t{}'
+    table = pa.Table.from_arrays([a0], [name])
+
+    result = _roundtrip_table(table, flavor='spark')
+
+    expected_name = 'prohib______'
+    assert result.schema[0].name == expected_name
+
+
+@parquet
 def test_fixed_size_binary():
     t0 = pa.binary(10)
     data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo']
@@ -587,15 +604,19 @@ def test_fixed_size_binary():
     _check_roundtrip(table)
 
 
-def _check_roundtrip(table, expected=None, **params):
+def _roundtrip_table(table, **params):
     buf = io.BytesIO()
     _write_table(table, buf, **params)
     buf.seek(0)
 
+    return _read_table(buf)
+
+
+def _check_roundtrip(table, expected=None, **params):
     if expected is None:
         expected = table
 
-    result = _read_table(buf)
+    result = _roundtrip_table(table, **params)
     assert result.equals(expected)
 
 
@@ -1181,6 +1202,9 @@ def test_write_error_deletes_incomplete_file(tmpdir):
     except pa.ArrowException:
         pass
 
+    # Ensure that object has been destructed; this causes test failures on
+    # Windows
+    gc.collect()
     assert not os.path.exists(filename)
 
 

http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/types.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 3eaee6c..56670f6 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -299,7 +299,6 @@ cdef class Schema:
         return self.schema.num_fields()
 
     def __getitem__(self, int i):
-
         cdef:
             Field result = Field()
             int num_fields = self.schema.num_fields()
@@ -318,6 +317,10 @@ cdef class Schema:
 
         return result
 
+    def __iter__(self):
+        for i in range(len(self)):
+            yield self[i]
+
     def _check_null(self):
         if self.schema == NULL:
             raise ReferenceError(