You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2018/09/12 14:33:56 UTC
[arrow] branch master updated: ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)
This is an automated email from the ASF dual-hosted git repository.
uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git
The following commit(s) were added to refs/heads/master by this push:
new e912675 ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)
e912675 is described below
commit e91267555cc72c3cf0e5a472d3c89eefed69d6f7
Author: Krisztián Szűcs <sz...@gmail.com>
AuthorDate: Wed Sep 12 16:33:41 2018 +0200
ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)
Also contains a fix for float truncation.
Author: Krisztián Szűcs <sz...@gmail.com>
Closes #2530 from kszucs/ARROW-2936 and squashes the following commits:
1d3b7ec0 <Krisztián Szűcs> unsafe cast assertion; py2 compatible tests
ca44e219 <Krisztián Szűcs> apidoc
772a666f <Krisztián Szűcs> flake8
90fc3183 <Krisztián Szűcs> Table.cast implementation; fix float truncation casting rule
---
cpp/src/arrow/compute/compute-test.cc | 44 +++++++++++++++----
cpp/src/arrow/compute/kernels/cast.cc | 16 +++----
python/pyarrow/table.pxi | 37 ++++++++++++++--
python/pyarrow/tests/test_array.py | 32 ++++++++++++++
python/pyarrow/tests/test_table.py | 82 +++++++++++++++++++++++++++++++++++
5 files changed, 189 insertions(+), 22 deletions(-)
diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
index a1dfdef..233f8a6 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -286,20 +286,21 @@ TEST_F(TestCast, ToIntDowncastUnsafe) {
}
TEST_F(TestCast, FloatingPointToInt) {
+ // which means allow_float_truncate == false
auto options = CastOptions::Safe();
vector<bool> is_valid = {true, false, true, true, true};
vector<bool> all_valid = {true, true, true, true, true};
- // float32 point to integer
- vector<float> v1 = {1.5, 0, 0.5, -1.5, 5.5};
+ // float32 to int32 no truncation
+ vector<float> v1 = {1.0, 0, 0.0, -1.0, 5.0};
vector<int32_t> e1 = {1, 0, 0, -1, 5};
CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, is_valid, int32(), e1,
options);
CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, all_valid, int32(), e1,
options);
- // float64 point to integer
+ // float64 to int32 no truncation
vector<double> v2 = {1.0, 0, 0.0, -1.0, 5.0};
vector<int32_t> e2 = {1, 0, 0, -1, 5};
CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, is_valid, int32(), e2,
@@ -307,15 +308,40 @@ TEST_F(TestCast, FloatingPointToInt) {
CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, all_valid, int32(), e2,
options);
- vector<double> v3 = {1.5, 0, 0.5, -1.5, 5.5};
- vector<int32_t> e3 = {1, 0, 0, -1, 5};
- CheckFails<DoubleType>(float64(), v3, is_valid, int32(), options);
- CheckFails<DoubleType>(float64(), v3, all_valid, int32(), options);
+ // float64 to int64 no truncation
+ vector<double> v3 = {1.0, 0, 0.0, -1.0, 5.0};
+ vector<int64_t> e3 = {1, 0, 0, -1, 5};
+ CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, is_valid, int64(), e3,
+ options);
+ CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, all_valid, int64(), e3,
+ options);
+
+ // float64 to int32 truncate
+ vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5};
+ vector<int32_t> e4 = {1, 0, 0, -1, 5};
+
+ options.allow_float_truncate = false;
+ CheckFails<DoubleType>(float64(), v4, is_valid, int32(), options);
+ CheckFails<DoubleType>(float64(), v4, all_valid, int32(), options);
+
+ options.allow_float_truncate = true;
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid, int32(), e4,
+ options);
+ CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, all_valid, int32(), e4,
+ options);
+
+ // float64 to int64 truncate
+ vector<double> v5 = {1.5, 0, 0.5, -1.5, 5.5};
+ vector<int64_t> e5 = {1, 0, 0, -1, 5};
+
+ options.allow_float_truncate = false;
+ CheckFails<DoubleType>(float64(), v5, is_valid, int64(), options);
+ CheckFails<DoubleType>(float64(), v5, all_valid, int64(), options);
options.allow_float_truncate = true;
- CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, is_valid, int32(), e3,
+ CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, is_valid, int64(), e5,
options);
- CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, all_valid, int32(), e3,
+ CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, all_valid, int64(), e5,
options);
}
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 2a0479d..369ebb9 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -194,20 +194,16 @@ struct is_integer_downcast<
};
template <typename O, typename I, typename Enable = void>
-struct is_float_downcast {
+struct is_float_truncate {
static constexpr bool value = false;
};
template <typename O, typename I>
-struct is_float_downcast<
+struct is_float_truncate<
O, I,
- typename std::enable_if<std::is_base_of<Number, O>::value &&
+ typename std::enable_if<std::is_base_of<Integer, O>::value &&
std::is_base_of<FloatingPoint, I>::value>::type> {
- using O_T = typename O::c_type;
- using I_T = typename I::c_type;
-
- // Smaller output size
- static constexpr bool value = !std::is_same<O, I>::value && (sizeof(O_T) < sizeof(I_T));
+ static constexpr bool value = true;
};
template <typename O, typename I>
@@ -270,7 +266,7 @@ struct CastFunctor<O, I,
};
template <typename O, typename I>
-struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>::type> {
+struct CastFunctor<O, I, typename std::enable_if<is_float_truncate<O, I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
using in_type = typename I::c_type;
@@ -316,7 +312,7 @@ struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>
template <typename O, typename I>
struct CastFunctor<O, I,
typename std::enable_if<is_numeric_cast<O, I>::value &&
- !is_float_downcast<O, I>::value &&
+ !is_float_truncate<O, I>::value &&
!is_integer_downcast<O, I>::value>::type> {
void operator()(FunctionContext* ctx, const CastOptions& options,
const ArrayData& input, ArrayData* output) {
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index bbf40e0..62f6803 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -638,8 +638,8 @@ cdef _schema_from_arrays(arrays, names, dict metadata,
raise ValueError('Must pass names when constructing '
'from Array objects')
if len(names) != K:
- raise ValueError("Length of names ({}) does not match "
- "length of arrays ({})".format(len(names), K))
+ raise ValueError('Length of names ({}) does not match '
+ 'length of arrays ({})'.format(len(names), K))
for i in range(K):
val = arrays[i]
if isinstance(val, (Array, ChunkedArray)):
@@ -760,7 +760,7 @@ cdef class RecordBatch:
def column(self, i):
"""
- Select single column from record batcha
+ Select single column from record batch
Returns
-------
@@ -1078,6 +1078,37 @@ cdef class Table:
return result
+ def cast(self, Schema target_schema, bint safe=True):
+ """
+ Cast table values to another schema
+
+ Parameters
+ ----------
+ target_schema : Schema
+ Schema to cast to, the names and order of fields must match
+ safe : boolean, default True
+ Check for overflows or other unsafe conversions
+
+ Returns
+ -------
+ casted : Table
+ """
+ cdef:
+ Column column, casted
+ Field field
+ list newcols = []
+
+ if self.schema.names != target_schema.names:
+ raise ValueError("Target schema's field names are not matching "
+ "the table's field names: {!r}, {!r}"
+ .format(self.schema.names, target_schema.names))
+
+ for column, field in zip(self.itercolumns(), target_schema):
+ casted = column.cast(field.type, safe=safe)
+ newcols.append(casted)
+
+ return Table.from_arrays(newcols, schema=target_schema)
+
@classmethod
def from_pandas(cls, df, Schema schema=None, bint preserve_index=True,
nthreads=None, columns=None, bint safe=True):
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index d4b582e..0002dce 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -549,6 +549,38 @@ def test_cast_integers_unsafe():
_check_cast_case(case, safe=False)
+def test_floating_point_truncate_safe():
+ safe_cases = [
+ (np.array([1.0, 2.0, 3.0], dtype='float32'), 'float32',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([1.0, 2.0, 3.0], dtype='float64'), 'float64',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([-10.0, 20.0, -30.0], dtype='float64'), 'float64',
+ np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+ ]
+ for case in safe_cases:
+ _check_cast_case(case, safe=True)
+
+
+def test_floating_point_truncate_unsafe():
+ unsafe_cases = [
+ (np.array([1.1, 2.2, 3.3], dtype='float32'), 'float32',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([1.1, 2.2, 3.3], dtype='float64'), 'float64',
+ np.array([1, 2, 3], dtype='i4'), pa.int32()),
+ (np.array([-10.1, 20.2, -30.3], dtype='float64'), 'float64',
+ np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+ ]
+ for case in unsafe_cases:
+ # test safe casting raises
+ with pytest.raises(pa.ArrowInvalid,
+ match='Floating point value truncated'):
+ _check_cast_case(case, safe=True)
+
+ # test unsafe casting truncates
+ _check_cast_case(case, safe=False)
+
+
def test_cast_timestamp_unit():
# ARROW-1680
val = datetime.datetime.now()
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 14609ad..f45e918 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -749,3 +749,85 @@ def test_table_negative_indexing():
with pytest.raises(IndexError):
table[4]
+
+
+def test_table_cast_to_incompatible_schema():
+ data = [
+ pa.array(range(5)),
+ pa.array([-10, -5, 0, 5, 10]),
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('ab'))
+
+ target_schema1 = pa.schema([
+ pa.field('A', pa.int32()),
+ pa.field('b', pa.int16()),
+ ])
+ target_schema2 = pa.schema([
+ pa.field('a', pa.int32()),
+ ])
+ message = ("Target schema's field names are not matching the table's "
+ "field names:.*")
+ with pytest.raises(ValueError, match=message):
+ table.cast(target_schema1)
+ with pytest.raises(ValueError, match=message):
+ table.cast(target_schema2)
+
+
+def test_table_safe_casting():
+ data = [
+ pa.array(range(5), type=pa.int64()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+ pa.array([1.0, 2.0, 3.0], type=pa.float64()),
+ pa.array(['ab', 'bc', 'cd'], type=pa.string())
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+ expected_data = [
+ pa.array(range(5), type=pa.int32()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+ pa.array([1, 2, 3], type=pa.int64()),
+ pa.array(['ab', 'bc', 'cd'], type=pa.string())
+ ]
+ expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+ target_schema = pa.schema([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int16()),
+ pa.field('c', pa.int64()),
+ pa.field('d', pa.string())
+ ])
+ casted_table = table.cast(target_schema)
+
+ assert casted_table.equals(expected_table)
+
+
+def test_table_unsafe_casting():
+ data = [
+ pa.array(range(5), type=pa.int64()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+ pa.array([1.1, 2.2, 3.3], type=pa.float64()),
+ pa.array(['ab', 'bc', 'cd'], type=pa.string())
+ ]
+ table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+ expected_data = [
+ pa.array(range(5), type=pa.int32()),
+ pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+ pa.array([1, 2, 3], type=pa.int64()),
+ pa.array(['ab', 'bc', 'cd'], type=pa.string())
+ ]
+ expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+ target_schema = pa.schema([
+ pa.field('a', pa.int32()),
+ pa.field('b', pa.int16()),
+ pa.field('c', pa.int64()),
+ pa.field('d', pa.string())
+ ])
+
+ with pytest.raises(pa.ArrowInvalid,
+ match='Floating point value truncated'):
+ table.cast(target_schema)
+
+ casted_table = table.cast(target_schema, safe=False)
+ assert casted_table.equals(expected_table)