You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by uw...@apache.org on 2018/09/12 14:33:56 UTC

[arrow] branch master updated: ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)

This is an automated email from the ASF dual-hosted git repository.

uwe pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new e912675  ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)
e912675 is described below

commit e91267555cc72c3cf0e5a472d3c89eefed69d6f7
Author: Krisztián Szűcs <sz...@gmail.com>
AuthorDate: Wed Sep 12 16:33:41 2018 +0200

    ARROW-2936: [Python] Implement Table.cast for casting from one schema to another (if possible)
    
    Also contains a fix for float truncation.
    
    Author: Krisztián Szűcs <sz...@gmail.com>
    
    Closes #2530 from kszucs/ARROW-2936 and squashes the following commits:
    
    1d3b7ec0 <Krisztián Szűcs> unsafe cast assertion; py2 compatible tests
    ca44e219 <Krisztián Szűcs> apidoc
    772a666f <Krisztián Szűcs> flake8
    90fc3183 <Krisztián Szűcs> Table.cast implementation; fix float truncation casting rule
---
 cpp/src/arrow/compute/compute-test.cc | 44 +++++++++++++++----
 cpp/src/arrow/compute/kernels/cast.cc | 16 +++----
 python/pyarrow/table.pxi              | 37 ++++++++++++++--
 python/pyarrow/tests/test_array.py    | 32 ++++++++++++++
 python/pyarrow/tests/test_table.py    | 82 +++++++++++++++++++++++++++++++++++
 5 files changed, 189 insertions(+), 22 deletions(-)

diff --git a/cpp/src/arrow/compute/compute-test.cc b/cpp/src/arrow/compute/compute-test.cc
index a1dfdef..233f8a6 100644
--- a/cpp/src/arrow/compute/compute-test.cc
+++ b/cpp/src/arrow/compute/compute-test.cc
@@ -286,20 +286,21 @@ TEST_F(TestCast, ToIntDowncastUnsafe) {
 }
 
 TEST_F(TestCast, FloatingPointToInt) {
+  // CastOptions::Safe() means allow_float_truncate == false
   auto options = CastOptions::Safe();
 
   vector<bool> is_valid = {true, false, true, true, true};
   vector<bool> all_valid = {true, true, true, true, true};
 
-  // float32 point to integer
-  vector<float> v1 = {1.5, 0, 0.5, -1.5, 5.5};
+  // float32 to int32 no truncation
+  vector<float> v1 = {1.0, 0, 0.0, -1.0, 5.0};
   vector<int32_t> e1 = {1, 0, 0, -1, 5};
   CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, is_valid, int32(), e1,
                                                   options);
   CheckCase<FloatType, float, Int32Type, int32_t>(float32(), v1, all_valid, int32(), e1,
                                                   options);
 
-  // float64 point to integer
+  // float64 to int32 no truncation
   vector<double> v2 = {1.0, 0, 0.0, -1.0, 5.0};
   vector<int32_t> e2 = {1, 0, 0, -1, 5};
   CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, is_valid, int32(), e2,
@@ -307,15 +308,40 @@ TEST_F(TestCast, FloatingPointToInt) {
   CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v2, all_valid, int32(), e2,
                                                     options);
 
-  vector<double> v3 = {1.5, 0, 0.5, -1.5, 5.5};
-  vector<int32_t> e3 = {1, 0, 0, -1, 5};
-  CheckFails<DoubleType>(float64(), v3, is_valid, int32(), options);
-  CheckFails<DoubleType>(float64(), v3, all_valid, int32(), options);
+  // float64 to int64 no truncation
+  vector<double> v3 = {1.0, 0, 0.0, -1.0, 5.0};
+  vector<int64_t> e3 = {1, 0, 0, -1, 5};
+  CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, is_valid, int64(), e3,
+                                                    options);
+  CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v3, all_valid, int64(), e3,
+                                                    options);
+
+  // float64 to int32 truncate
+  vector<double> v4 = {1.5, 0, 0.5, -1.5, 5.5};
+  vector<int32_t> e4 = {1, 0, 0, -1, 5};
+
+  options.allow_float_truncate = false;
+  CheckFails<DoubleType>(float64(), v4, is_valid, int32(), options);
+  CheckFails<DoubleType>(float64(), v4, all_valid, int32(), options);
+
+  options.allow_float_truncate = true;
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, is_valid, int32(), e4,
+                                                    options);
+  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v4, all_valid, int32(), e4,
+                                                    options);
+
+  // float64 to int64 truncate
+  vector<double> v5 = {1.5, 0, 0.5, -1.5, 5.5};
+  vector<int64_t> e5 = {1, 0, 0, -1, 5};
+
+  options.allow_float_truncate = false;
+  CheckFails<DoubleType>(float64(), v5, is_valid, int64(), options);
+  CheckFails<DoubleType>(float64(), v5, all_valid, int64(), options);
 
   options.allow_float_truncate = true;
-  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, is_valid, int32(), e3,
+  CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, is_valid, int64(), e5,
                                                     options);
-  CheckCase<DoubleType, double, Int32Type, int32_t>(float64(), v3, all_valid, int32(), e3,
+  CheckCase<DoubleType, double, Int64Type, int64_t>(float64(), v5, all_valid, int64(), e5,
                                                     options);
 }
 
diff --git a/cpp/src/arrow/compute/kernels/cast.cc b/cpp/src/arrow/compute/kernels/cast.cc
index 2a0479d..369ebb9 100644
--- a/cpp/src/arrow/compute/kernels/cast.cc
+++ b/cpp/src/arrow/compute/kernels/cast.cc
@@ -194,20 +194,16 @@ struct is_integer_downcast<
 };
 
 template <typename O, typename I, typename Enable = void>
-struct is_float_downcast {
+struct is_float_truncate {
   static constexpr bool value = false;
 };
 
 template <typename O, typename I>
-struct is_float_downcast<
+struct is_float_truncate<
     O, I,
-    typename std::enable_if<std::is_base_of<Number, O>::value &&
+    typename std::enable_if<std::is_base_of<Integer, O>::value &&
                             std::is_base_of<FloatingPoint, I>::value>::type> {
-  using O_T = typename O::c_type;
-  using I_T = typename I::c_type;
-
-  // Smaller output size
-  static constexpr bool value = !std::is_same<O, I>::value && (sizeof(O_T) < sizeof(I_T));
+  static constexpr bool value = true;
 };
 
 template <typename O, typename I>
@@ -270,7 +266,7 @@ struct CastFunctor<O, I,
 };
 
 template <typename O, typename I>
-struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>::type> {
+struct CastFunctor<O, I, typename std::enable_if<is_float_truncate<O, I>::value>::type> {
   void operator()(FunctionContext* ctx, const CastOptions& options,
                   const ArrayData& input, ArrayData* output) {
     using in_type = typename I::c_type;
@@ -316,7 +312,7 @@ struct CastFunctor<O, I, typename std::enable_if<is_float_downcast<O, I>::value>
 template <typename O, typename I>
 struct CastFunctor<O, I,
                    typename std::enable_if<is_numeric_cast<O, I>::value &&
-                                           !is_float_downcast<O, I>::value &&
+                                           !is_float_truncate<O, I>::value &&
                                            !is_integer_downcast<O, I>::value>::type> {
   void operator()(FunctionContext* ctx, const CastOptions& options,
                   const ArrayData& input, ArrayData* output) {
diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi
index bbf40e0..62f6803 100644
--- a/python/pyarrow/table.pxi
+++ b/python/pyarrow/table.pxi
@@ -638,8 +638,8 @@ cdef _schema_from_arrays(arrays, names, dict metadata,
             raise ValueError('Must pass names when constructing '
                              'from Array objects')
         if len(names) != K:
-            raise ValueError("Length of names ({}) does not match "
-                             "length of arrays ({})".format(len(names), K))
+            raise ValueError('Length of names ({}) does not match '
+                             'length of arrays ({})'.format(len(names), K))
         for i in range(K):
             val = arrays[i]
             if isinstance(val, (Array, ChunkedArray)):
@@ -760,7 +760,7 @@ cdef class RecordBatch:
 
     def column(self, i):
         """
-        Select single column from record batcha
+        Select single column from record batch
 
         Returns
         -------
@@ -1078,6 +1078,37 @@ cdef class Table:
 
         return result
 
+    def cast(self, Schema target_schema, bint safe=True):
+        """
+        Cast table values to another schema
+
+        Parameters
+        ----------
+        target_schema : Schema
+            Schema to cast to; field names and order must match the table's
+        safe : boolean, default True
+            Check for overflows or other unsafe conversions
+
+        Returns
+        -------
+        casted : Table
+        """
+        cdef:
+            Column column, casted
+            Field field
+            list newcols = []
+
+        if self.schema.names != target_schema.names:
+            raise ValueError("Target schema's field names are not matching "
+                             "the table's field names: {!r}, {!r}"
+                             .format(self.schema.names, target_schema.names))
+
+        for column, field in zip(self.itercolumns(), target_schema):
+            casted = column.cast(field.type, safe=safe)
+            newcols.append(casted)
+
+        return Table.from_arrays(newcols, schema=target_schema)
+
     @classmethod
     def from_pandas(cls, df, Schema schema=None, bint preserve_index=True,
                     nthreads=None, columns=None, bint safe=True):
diff --git a/python/pyarrow/tests/test_array.py b/python/pyarrow/tests/test_array.py
index d4b582e..0002dce 100644
--- a/python/pyarrow/tests/test_array.py
+++ b/python/pyarrow/tests/test_array.py
@@ -549,6 +549,38 @@ def test_cast_integers_unsafe():
         _check_cast_case(case, safe=False)
 
 
+def test_floating_point_truncate_safe():
+    safe_cases = [
+        (np.array([1.0, 2.0, 3.0], dtype='float32'), 'float32',
+         np.array([1, 2, 3], dtype='i4'), pa.int32()),
+        (np.array([1.0, 2.0, 3.0], dtype='float64'), 'float64',
+         np.array([1, 2, 3], dtype='i4'), pa.int32()),
+        (np.array([-10.0, 20.0, -30.0], dtype='float64'), 'float64',
+         np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+    ]
+    for case in safe_cases:
+        _check_cast_case(case, safe=True)
+
+
+def test_floating_point_truncate_unsafe():
+    unsafe_cases = [
+        (np.array([1.1, 2.2, 3.3], dtype='float32'), 'float32',
+         np.array([1, 2, 3], dtype='i4'), pa.int32()),
+        (np.array([1.1, 2.2, 3.3], dtype='float64'), 'float64',
+         np.array([1, 2, 3], dtype='i4'), pa.int32()),
+        (np.array([-10.1, 20.2, -30.3], dtype='float64'), 'float64',
+         np.array([-10, 20, -30], dtype='i4'), pa.int32()),
+    ]
+    for case in unsafe_cases:
+        # test safe casting raises
+        with pytest.raises(pa.ArrowInvalid,
+                           match='Floating point value truncated'):
+            _check_cast_case(case, safe=True)
+
+        # test unsafe casting truncates
+        _check_cast_case(case, safe=False)
+
+
 def test_cast_timestamp_unit():
     # ARROW-1680
     val = datetime.datetime.now()
diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py
index 14609ad..f45e918 100644
--- a/python/pyarrow/tests/test_table.py
+++ b/python/pyarrow/tests/test_table.py
@@ -749,3 +749,85 @@ def test_table_negative_indexing():
 
     with pytest.raises(IndexError):
         table[4]
+
+
+def test_table_cast_to_incompatible_schema():
+    data = [
+        pa.array(range(5)),
+        pa.array([-10, -5, 0, 5, 10]),
+    ]
+    table = pa.Table.from_arrays(data, names=tuple('ab'))
+
+    target_schema1 = pa.schema([
+        pa.field('A', pa.int32()),
+        pa.field('b', pa.int16()),
+    ])
+    target_schema2 = pa.schema([
+        pa.field('a', pa.int32()),
+    ])
+    message = ("Target schema's field names are not matching the table's "
+               "field names:.*")
+    with pytest.raises(ValueError, match=message):
+        table.cast(target_schema1)
+    with pytest.raises(ValueError, match=message):
+        table.cast(target_schema2)
+
+
+def test_table_safe_casting():
+    data = [
+        pa.array(range(5), type=pa.int64()),
+        pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+        pa.array([1.0, 2.0, 3.0], type=pa.float64()),
+        pa.array(['ab', 'bc', 'cd'], type=pa.string())
+    ]
+    table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+    expected_data = [
+        pa.array(range(5), type=pa.int32()),
+        pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+        pa.array([1, 2, 3], type=pa.int64()),
+        pa.array(['ab', 'bc', 'cd'], type=pa.string())
+    ]
+    expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+    target_schema = pa.schema([
+        pa.field('a', pa.int32()),
+        pa.field('b', pa.int16()),
+        pa.field('c', pa.int64()),
+        pa.field('d', pa.string())
+    ])
+    casted_table = table.cast(target_schema)
+
+    assert casted_table.equals(expected_table)
+
+
+def test_table_unsafe_casting():
+    data = [
+        pa.array(range(5), type=pa.int64()),
+        pa.array([-10, -5, 0, 5, 10], type=pa.int32()),
+        pa.array([1.1, 2.2, 3.3], type=pa.float64()),
+        pa.array(['ab', 'bc', 'cd'], type=pa.string())
+    ]
+    table = pa.Table.from_arrays(data, names=tuple('abcd'))
+
+    expected_data = [
+        pa.array(range(5), type=pa.int32()),
+        pa.array([-10, -5, 0, 5, 10], type=pa.int16()),
+        pa.array([1, 2, 3], type=pa.int64()),
+        pa.array(['ab', 'bc', 'cd'], type=pa.string())
+    ]
+    expected_table = pa.Table.from_arrays(expected_data, names=tuple('abcd'))
+
+    target_schema = pa.schema([
+        pa.field('a', pa.int32()),
+        pa.field('b', pa.int16()),
+        pa.field('c', pa.int64()),
+        pa.field('d', pa.string())
+    ])
+
+    with pytest.raises(pa.ArrowInvalid,
+                       match='Floating point value truncated'):
+        table.cast(target_schema)
+
+    casted_table = table.cast(target_schema, safe=False)
+    assert casted_table.equals(expected_table)