Posted to commits@arrow.apache.org by uw...@apache.org on 2017/09/10 06:35:09 UTC
arrow git commit: ARROW-1359: [C++] Add flavor='spark' option to write_parquet that sanitizes schema field names
Repository: arrow
Updated Branches:
refs/heads/master 947ca871c -> 4a6a6cb47
ARROW-1359: [C++] Add flavor='spark' option to write_parquet that sanitizes schema field names
I also made the default for `use_deprecated_int96_timestamps` None so that we can distinguish between an unspecified value and an explicit False. When the user passes `flavor='spark'`, int96 timestamps are enabled. Once Spark handles the int96 deprecation in a future release, we can remove this special case.
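For example, after this change the following minimal sketch (the table contents and file path are illustrative) writes a Spark-compatible file, rewriting the field name and using int96 timestamps:

    import pyarrow as pa
    import pyarrow.parquet as pq

    # 'bad name' contains a space, which Spark SQL rejects; with
    # flavor='spark' it is rewritten to 'bad_name' on write
    table = pa.Table.from_arrays([pa.array([1, 2, 3])], ['bad name'])
    pq.write_table(table, 'example.parquet', flavor='spark')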
Author: Wes McKinney <we...@twosigma.com>
Closes #1076 from wesm/ARROW-1359 and squashes the following commits:
8a60b66 [Wes McKinney] Use composition rather than inheritance
e3fa8ec [Wes McKinney] Add note about spark flavor to Sphinx docs
8159a51 [Wes McKinney] Add flavor='spark' option to write_parquet that sanitizes schema field names, turns on int96 timestamps
Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/4a6a6cb4
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/4a6a6cb4
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/4a6a6cb4
Branch: refs/heads/master
Commit: 4a6a6cb47dcf832213f1fb31f2325ad10a3864bd
Parents: 947ca87
Author: Wes McKinney <we...@twosigma.com>
Authored: Sun Sep 10 08:34:43 2017 +0200
Committer: Uwe L. Korn <uw...@xhochy.com>
Committed: Sun Sep 10 08:34:43 2017 +0200
----------------------------------------------------------------------
 python/doc/source/parquet.rst        |  7 +++
 python/pyarrow/parquet.py            | 96 ++++++++++++++++++++++++++++---
 python/pyarrow/table.pxi             | 31 +++++++---
 python/pyarrow/tests/test_parquet.py | 28 ++++++++-
 python/pyarrow/types.pxi             |  5 +-
5 files changed, 149 insertions(+), 18 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/doc/source/parquet.rst
----------------------------------------------------------------------
diff --git a/python/doc/source/parquet.rst b/python/doc/source/parquet.rst
index 7626c15..d466ba1 100644
--- a/python/doc/source/parquet.rst
+++ b/python/doc/source/parquet.rst
@@ -217,6 +217,13 @@ such as those produced by Hive:
dataset = pq.ParquetDataset('dataset_name/')
table = dataset.read()
+Using with Spark
+----------------
+
+Spark places some constraints on the types of Parquet files it will read. The
+option ``flavor='spark'`` will set the required options automatically and
+also replace field-name characters that Spark SQL does not support.
+
Multithreaded Reads
-------------------
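A file written this way can then be loaded from PySpark; a quick sketch, assuming an active SparkSession named `spark` and the illustrative path from the example above:

    df = spark.read.parquet('example.parquet')
    df.printSchema()  # no field names rejected by Spark SQL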
http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/parquet.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/parquet.py b/python/pyarrow/parquet.py
index 568aad4..1584b84 100644
--- a/python/pyarrow/parquet.py
+++ b/python/pyarrow/parquet.py
@@ -18,17 +18,17 @@
import os
import inspect
import json
-
+import re
import six
import numpy as np
from pyarrow.filesystem import FileSystem, LocalFileSystem, S3FSWrapper
from pyarrow._parquet import (ParquetReader, FileMetaData, # noqa
- RowGroupMetaData, ParquetSchema,
- ParquetWriter)
+ RowGroupMetaData, ParquetSchema)
import pyarrow._parquet as _parquet # noqa
import pyarrow.lib as lib
+import pyarrow as pa
# ----------------------------------------------------------------------
@@ -164,6 +164,73 @@ class ParquetFile(object):
return indices
+_SPARK_DISALLOWED_CHARS = re.compile('[ ,;{}()\n\t=]')
+
+
+def _sanitized_spark_field_name(name):
+ return _SPARK_DISALLOWED_CHARS.sub('_', name)
+
+
+def _sanitize_schema(schema, flavor):
+ if 'spark' in flavor:
+ sanitized_fields = []
+
+ schema_changed = False
+
+ for field in schema:
+ name = field.name
+ sanitized_name = _sanitized_spark_field_name(name)
+
+ if sanitized_name != name:
+ schema_changed = True
+ sanitized_field = pa.field(sanitized_name, field.type,
+ field.nullable, field.metadata)
+ sanitized_fields.append(sanitized_field)
+ else:
+ sanitized_fields.append(field)
+ return pa.schema(sanitized_fields), schema_changed
+ else:
+ return schema, False
+
+
+def _sanitize_table(table, new_schema, flavor):
+ # TODO: This will not handle prohibited characters in nested field names
+ if 'spark' in flavor:
+ column_data = [table[i].data for i in range(table.num_columns)]
+ return pa.Table.from_arrays(column_data, schema=new_schema)
+ else:
+ return table
+
+
+class ParquetWriter(object):
+ """
+
+ Parameters
+ ----------
+ where
+ schema
+ flavor : {'spark', ...}
+ Set options for compatibility with a particular reader
+ """
+ def __init__(self, where, schema, flavor=None, **options):
+ self.flavor = flavor
+ if flavor is not None:
+ schema, self.schema_changed = _sanitize_schema(schema, flavor)
+ else:
+ self.schema_changed = False
+
+ self.schema = schema
+ self.writer = _parquet.ParquetWriter(where, schema, **options)
+
+ def write_table(self, table, row_group_size=None):
+ if self.schema_changed:
+ table = _sanitize_table(table, self.schema, self.flavor)
+ self.writer.write_table(table, row_group_size=row_group_size)
+
+ def close(self):
+ self.writer.close()
+
+
def _get_pandas_index_columns(keyvalues):
return (json.loads(keyvalues[b'pandas'].decode('utf8'))
['index_columns'])
@@ -787,8 +854,9 @@ def read_pandas(source, columns=None, nthreads=1, metadata=None):
def write_table(table, where, row_group_size=None, version='1.0',
use_dictionary=True, compression='snappy',
- use_deprecated_int96_timestamps=False,
- coerce_timestamps=None, **kwargs):
+ use_deprecated_int96_timestamps=None,
+ coerce_timestamps=None,
+ flavor=None, **kwargs):
"""
Write a Table to Parquet format
@@ -804,15 +872,26 @@ def write_table(table, where, row_group_size=None, version='1.0',
use_dictionary : bool or list
Specify if we should use dictionary encoding in general or only for
some columns.
- use_deprecated_int96_timestamps : boolean, default False
- Write nanosecond resolution timestamps to INT96 Parquet format
+ use_deprecated_int96_timestamps : boolean, default None
+ Write nanosecond resolution timestamps to INT96 Parquet
+ format. Defaults to False unless enabled by flavor argument
coerce_timestamps : string, default None
Cast timestamps to a particular resolution.
Valid values: {None, 'ms', 'us'}
compression : str or dict
Specify the compression codec, either on a general basis or per-column.
+ flavor : {'spark'}, default None
+ Sanitize schema or set other options for compatibility with a particular reader
"""
row_group_size = kwargs.get('chunk_size', row_group_size)
+
+ if use_deprecated_int96_timestamps is None:
+ # Use int96 timestamps for Spark
+ if flavor is not None and 'spark' in flavor:
+ use_deprecated_int96_timestamps = True
+ else:
+ use_deprecated_int96_timestamps = False
+
options = dict(
use_dictionary=use_dictionary,
compression=compression,
@@ -822,7 +901,8 @@ def write_table(table, where, row_group_size=None, version='1.0',
writer = None
try:
- writer = ParquetWriter(where, table.schema, **options)
+ writer = ParquetWriter(where, table.schema, flavor=flavor,
+ **options)
writer.write_table(table, row_group_size=row_group_size)
except:
if writer is not None:
http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/table.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index fc6099f..68eb5cb 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -758,7 +758,7 @@ cdef class Table:
return cls.from_arrays(arrays, names=names, metadata=metadata)
@staticmethod
- def from_arrays(arrays, names=None, dict metadata=None):
+ def from_arrays(arrays, names=None, schema=None, dict metadata=None):
"""
Construct a Table from Arrow arrays or columns
@@ -777,11 +777,22 @@ cdef class Table:
"""
cdef:
vector[shared_ptr[CColumn]] columns
- shared_ptr[CSchema] schema
+ Schema cy_schema
+ shared_ptr[CSchema] c_schema
shared_ptr[CTable] table
int i, K = <int> len(arrays)
- _schema_from_arrays(arrays, names, metadata, &schema)
+ if schema is None:
+ _schema_from_arrays(arrays, names, metadata, &c_schema)
+ else:
+ if names is not None:
+ raise ValueError('Cannot pass both schema and names')
+ cy_schema = schema
+
+ if len(schema) != len(arrays):
+ raise ValueError('Number of fields in schema does not match number of arrays')
+
+ c_schema = cy_schema.sp_schema
columns.reserve(K)
@@ -789,23 +800,29 @@ cdef class Table:
if isinstance(arrays[i], Array):
columns.push_back(
make_shared[CColumn](
- schema.get().field(i),
+ c_schema.get().field(i),
(<Array> arrays[i]).sp_array
)
)
elif isinstance(arrays[i], ChunkedArray):
columns.push_back(
make_shared[CColumn](
- schema.get().field(i),
+ c_schema.get().field(i),
(<ChunkedArray> arrays[i]).sp_chunked_array
)
)
elif isinstance(arrays[i], Column):
- columns.push_back((<Column> arrays[i]).sp_column)
+ # Make sure schema field and column are consistent
+ columns.push_back(
+ make_shared[CColumn](
+ c_schema.get().field(i),
+ (<Column> arrays[i]).sp_column.get().data()
+ )
+ )
else:
raise ValueError(type(arrays[i]))
- table.reset(new CTable(schema, columns))
+ table.reset(new CTable(c_schema, columns))
return pyarrow_wrap_table(table)
@staticmethod
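The new `schema` argument to `Table.from_arrays` can be exercised directly; a minimal sketch (the field name and type are illustrative):

    import pyarrow as pa

    # schema supplies the field names, so names must not be passed;
    # it must also contain exactly one field per array
    schema = pa.schema([pa.field('ints', pa.int64())])
    table = pa.Table.from_arrays([pa.array([1, 2, 3])], schema=schema)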
http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/tests/test_parquet.py
----------------------------------------------------------------------
diff --git a/python/pyarrow/tests/test_parquet.py b/python/pyarrow/tests/test_parquet.py
index 5dfe0a5..9b5a4bc 100644
--- a/python/pyarrow/tests/test_parquet.py
+++ b/python/pyarrow/tests/test_parquet.py
@@ -17,6 +17,7 @@
from os.path import join as pjoin
import datetime
+import gc
import io
import os
import json
@@ -562,6 +563,10 @@ def test_date_time_types():
_check_roundtrip(table, expected=expected, version='2.0',
use_deprecated_int96_timestamps=True)
+ # Check that setting flavor to 'spark' uses int96 timestamps
+ _check_roundtrip(table, expected=expected, version='2.0',
+ flavor='spark')
+
# Unsupported stuff
def _assert_unsupported(array):
table = pa.Table.from_arrays([array], ['unsupported'])
@@ -577,6 +582,18 @@ def test_date_time_types():
@parquet
+def test_sanitized_spark_field_names():
+ a0 = pa.array([0, 1, 2, 3, 4])
+ name = 'prohib; ,\t{}'
+ table = pa.Table.from_arrays([a0], [name])
+
+ result = _roundtrip_table(table, flavor='spark')
+
+ expected_name = 'prohib______'
+ assert result.schema[0].name == expected_name
+
+
+@parquet
def test_fixed_size_binary():
t0 = pa.binary(10)
data = [b'fooooooooo', None, b'barooooooo', b'quxooooooo']
@@ -587,15 +604,19 @@ def test_fixed_size_binary():
_check_roundtrip(table)
-def _check_roundtrip(table, expected=None, **params):
+def _roundtrip_table(table, **params):
buf = io.BytesIO()
_write_table(table, buf, **params)
buf.seek(0)
+ return _read_table(buf)
+
+
+def _check_roundtrip(table, expected=None, **params):
if expected is None:
expected = table
- result = _read_table(buf)
+ result = _roundtrip_table(table, **params)
assert result.equals(expected)
@@ -1181,6 +1202,9 @@ def test_write_error_deletes_incomplete_file(tmpdir):
except pa.ArrowException:
pass
+ # Ensure that the writer object has been destructed; an open file
+ # handle would otherwise cause this test to fail on Windows
+ gc.collect()
assert not os.path.exists(filename)
http://git-wip-us.apache.org/repos/asf/arrow/blob/4a6a6cb4/python/pyarrow/types.pxi
----------------------------------------------------------------------
diff --git a/python/pyarrow/types.pxi b/python/pyarrow/types.pxi
index 3eaee6c..56670f6 100644
--- a/python/pyarrow/types.pxi
+++ b/python/pyarrow/types.pxi
@@ -299,7 +299,6 @@ cdef class Schema:
return self.schema.num_fields()
def __getitem__(self, int i):
-
cdef:
Field result = Field()
int num_fields = self.schema.num_fields()
@@ -318,6 +317,10 @@ cdef class Schema:
return result
+ def __iter__(self):
+ for i in range(len(self)):
+ yield self[i]
+
def _check_null(self):
if self.schema == NULL:
raise ReferenceError(
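With `Schema.__iter__` added, the fields of a schema can be traversed directly; a short sketch:

    import pyarrow as pa

    schema = pa.schema([pa.field('a', pa.int32()), pa.field('b', pa.string())])
    for field in schema:  # uses the new __iter__
        print(field.name, field.type)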