You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ko...@apache.org on 2023/01/06 05:16:40 UTC

[arrow] branch master updated: ARROW-18086: [Ruby] Add support for HalfFloat (#15204)

This is an automated email from the ASF dual-hosted git repository.

kou pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new d2481a610f ARROW-18086: [Ruby] Add support for HalfFloat (#15204)
d2481a610f is described below

commit d2481a610f7653e1b965366461dd6be0c22c1fda
Author: Sutou Kouhei <ko...@clear-code.com>
AuthorDate: Fri Jan 6 14:16:28 2023 +0900

    ARROW-18086: [Ruby] Add support for HalfFloat (#15204)
    
    Authored-by: Sutou Kouhei <ko...@clear-code.com>
    Signed-off-by: Sutou Kouhei <ko...@clear-code.com>
---
 ruby/red-arrow/ext/arrow/converters.hpp            |  47 +++++---
 ruby/red-arrow/ext/arrow/raw-records.cpp           |   3 +-
 ruby/red-arrow/ext/arrow/values.cpp                |   3 +-
 .../lib/arrow/half-float-array-builder.rb          |  32 +++++
 ruby/red-arrow/lib/arrow/half-float-array.rb       |  24 ++++
 ruby/red-arrow/lib/arrow/half-float.rb             | 118 +++++++++++++++++++
 ruby/red-arrow/lib/arrow/loader.rb                 |   4 +
 .../test/raw-records/test-basic-arrays.rb          |  10 ++
 ruby/red-arrow/test/test-half-float-array.rb       |  43 +++++++
 ruby/red-arrow/test/test-half-float.rb             | 130 +++++++++++++++++++++
 ruby/red-arrow/test/values/test-basic-arrays.rb    |  10 ++
 11 files changed, 406 insertions(+), 18 deletions(-)

diff --git a/ruby/red-arrow/ext/arrow/converters.hpp b/ruby/red-arrow/ext/arrow/converters.hpp
index 5a500574de..28955432a7 100644
--- a/ruby/red-arrow/ext/arrow/converters.hpp
+++ b/ruby/red-arrow/ext/arrow/converters.hpp
@@ -106,10 +106,34 @@ namespace red_arrow {
       return ULL2NUM(array.Value(i));
     }
 
-    // TODO
-    // inline VALUE convert(const arrow::HalfFloatArray& array,
-    //                      const int64_t i) {
-    // }
+    // Converts the i-th value of a half float (IEEE 754 binary16) array
+    // to a Ruby Float.
+    //
+    // | sign (1 bit) | exponent (5 bit) | fraction (10 bit) |
+    //
+    // * exponent == all ones: infinity if fraction == 0, NaN otherwise.
+    // * exponent == 0: zero or subnormal; no implicit leading 1 bit and a
+    //   fixed scale of 2^(1 - bias).
+    // * otherwise: 2^(exponent - bias) * (1 + fraction / 2^10).
+    inline VALUE convert(const arrow::HalfFloatArray& array,
+                         const int64_t i) {
+      const uint16_t value = array.Value(i);
+      constexpr int exponent_n_bits = 5;
+      constexpr int exponent_mask = (1 << exponent_n_bits) - 1;
+      constexpr int exponent_bias = 15;
+      constexpr int fraction_n_bits = 10;
+      constexpr int fraction_mask = (1 << fraction_n_bits) - 1;
+      constexpr double fraction_denominator = 1 << fraction_n_bits;
+      // Keep every field a signed int: with unsigned fields,
+      // exponent - exponent_bias wraps around for exponent < bias and
+      // std::pow() overflows to infinity.
+      const int sign = (value >> (exponent_n_bits + fraction_n_bits)) & 1;
+      const int exponent = (value >> fraction_n_bits) & exponent_mask;
+      const int fraction = value & fraction_mask;
+      const double sign_factor = (sign == 0) ? 1.0 : -1.0;
+      if (exponent == exponent_mask) {
+        if (fraction != 0) {
+          return DBL2NUM(std::nan(""));
+        }
+        return DBL2NUM(sign_factor * HUGE_VAL);
+      }
+      if (exponent == 0) {
+        // Subnormal (or signed zero): same scale as the smallest normal.
+        return DBL2NUM(sign_factor *
+                       std::pow(2.0, 1 - exponent_bias) *
+                       (fraction / fraction_denominator));
+      }
+      return DBL2NUM(sign_factor *
+                     std::pow(2.0, exponent - exponent_bias) *
+                     (1.0 + fraction / fraction_denominator));
+    }
 
     inline VALUE convert(const arrow::FloatArray& array,
                          const int64_t i) {
@@ -320,8 +344,7 @@ namespace red_arrow {
     VISIT(UInt16)
     VISIT(UInt32)
     VISIT(UInt64)
-    // TODO
-    // VISIT(HalfFloat)
+    VISIT(HalfFloat)
     VISIT(Float)
     VISIT(Double)
     VISIT(Binary)
@@ -427,8 +450,7 @@ namespace red_arrow {
     VISIT(UInt16)
     VISIT(UInt32)
     VISIT(UInt64)
-    // TODO
-    // VISIT(HalfFloat)
+    VISIT(HalfFloat)
     VISIT(Float)
     VISIT(Double)
     VISIT(Binary)
@@ -530,8 +552,7 @@ namespace red_arrow {
     VISIT(UInt16)
     VISIT(UInt32)
     VISIT(UInt64)
-    // TODO
-    // VISIT(HalfFloat)
+    VISIT(HalfFloat)
     VISIT(Float)
     VISIT(Double)
     VISIT(Binary)
@@ -634,8 +655,7 @@ namespace red_arrow {
     VISIT(UInt16)
     VISIT(UInt32)
     VISIT(UInt64)
-    // TODO
-    // VISIT(HalfFloat)
+    VISIT(HalfFloat)
     VISIT(Float)
     VISIT(Double)
     VISIT(Binary)
@@ -761,8 +781,7 @@ namespace red_arrow {
     VISIT(UInt16)
     VISIT(UInt32)
     VISIT(UInt64)
-    // TODO
-    // VISIT(HalfFloat)
+    VISIT(HalfFloat)
     VISIT(Float)
     VISIT(Double)
     VISIT(Binary)
diff --git a/ruby/red-arrow/ext/arrow/raw-records.cpp b/ruby/red-arrow/ext/arrow/raw-records.cpp
index e34ea2d3c8..e0326f9d2f 100644
--- a/ruby/red-arrow/ext/arrow/raw-records.cpp
+++ b/ruby/red-arrow/ext/arrow/raw-records.cpp
@@ -84,8 +84,7 @@ namespace red_arrow {
       VISIT(UInt16)
       VISIT(UInt32)
       VISIT(UInt64)
-      // TODO
-      // VISIT(HalfFloat)
+      VISIT(HalfFloat)
       VISIT(Float)
       VISIT(Double)
       VISIT(Binary)
diff --git a/ruby/red-arrow/ext/arrow/values.cpp b/ruby/red-arrow/ext/arrow/values.cpp
index 0fcb46e1bb..e412ce2273 100644
--- a/ruby/red-arrow/ext/arrow/values.cpp
+++ b/ruby/red-arrow/ext/arrow/values.cpp
@@ -65,8 +65,7 @@ namespace red_arrow {
       VISIT(UInt16)
       VISIT(UInt32)
       VISIT(UInt64)
-      // TODO
-      // VISIT(HalfFloat)
+      VISIT(HalfFloat)
       VISIT(Float)
       VISIT(Double)
       VISIT(Binary)
diff --git a/ruby/red-arrow/lib/arrow/half-float-array-builder.rb b/ruby/red-arrow/lib/arrow/half-float-array-builder.rb
new file mode 100644
index 0000000000..2b171e57a9
--- /dev/null
+++ b/ruby/red-arrow/lib/arrow/half-float-array-builder.rb
@@ -0,0 +1,32 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+module Arrow
+  class HalfFloatArrayBuilder
+    private
+    def convert_to_arrow_value(value)
+      case value
+      when Float
+        HalfFloat.new(value).to_uint16
+      when HalfFloat
+        value.to_uint16
+      else
+        value
+      end
+    end
+  end
+end
diff --git a/ruby/red-arrow/lib/arrow/half-float-array.rb b/ruby/red-arrow/lib/arrow/half-float-array.rb
new file mode 100644
index 0000000000..94b8ebd51a
--- /dev/null
+++ b/ruby/red-arrow/lib/arrow/half-float-array.rb
@@ -0,0 +1,24 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+module Arrow
+  class HalfFloatArray
+    def get_value(i)
+      HalfFloat.new(get_raw_value(i)).to_f
+    end
+  end
+end
diff --git a/ruby/red-arrow/lib/arrow/half-float.rb b/ruby/red-arrow/lib/arrow/half-float.rb
new file mode 100644
index 0000000000..e6fe976a29
--- /dev/null
+++ b/ruby/red-arrow/lib/arrow/half-float.rb
@@ -0,0 +1,118 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+module Arrow
+  class HalfFloat
+    MAX = 65504
+    MIN = -65504
+    EXPONENT_N_BITS = 5
+    EXPONENT_MASK = (2 ** EXPONENT_N_BITS) - 1
+    EXPONENT_BIAS = 15
+    FRACTION_N_BITS = 10
+    FRACTION_MASK = (2 ** FRACTION_N_BITS) - 1
+    FRACTION_DENOMINATOR = 2.0 ** FRACTION_N_BITS
+
+    attr_reader :sign
+    attr_reader :exponent
+    attr_reader :fraction
+    def initialize(*args)
+      n_args = args.size
+      case n_args
+      when 1
+        if args[0].is_a?(Float)
+          @sign, @exponent, @fraction = deconstruct_float(args[0])
+        else
+          @sign, @exponent, @fraction = deconstruct_uint16(args[0])
+        end
+      when 3
+        @sign, @exponent, @fraction = *args
+      else
+        message = "wrong number of arguments (given #{n_args}, expected 1 or 3)"
+        raise ArgumentError, message
+      end
+    end
+
+    def to_f
+      if @exponent == EXPONENT_MASK
+        if @sign.zero?
+          Float::INFINITY
+        else
+          -Float::INFINITY
+        end
+      else
+        if @exponent.zero?
+          implicit_fraction = 0
+        else
+          implicit_fraction = 1
+        end
+        ((-1) ** @sign) *
+          (2 ** (@exponent - EXPONENT_BIAS)) *
+          (implicit_fraction + @fraction / FRACTION_DENOMINATOR)
+      end
+    end
+
+    def to_uint16
+      (@sign << (EXPONENT_N_BITS + FRACTION_N_BITS)) ^
+        (@exponent << FRACTION_N_BITS) ^
+        @fraction
+    end
+
+    def pack
+      [to_uint16].pack("S")
+    end
+
+    private
+    def deconstruct_float(float)
+      if float > MAX
+        float = Float::INFINITY
+      elsif float < MIN
+        float = -Float::INFINITY
+      end
+      is_infinite = float.infinite?
+      if is_infinite
+        sign = (is_infinite == 1) ? 0 : 1
+        exponent = EXPONENT_MASK
+        fraction = 0
+      elsif float.zero?
+        sign = 0
+        exponent = 0
+        fraction = 0
+      else
+        sign = (float.positive? ? 0 : 1)
+        float_abs = float.abs
+        1.upto(EXPONENT_MASK) do |e|
+          next_exponent_value = 2 ** (e + 1 - EXPONENT_BIAS)
+          next if float_abs > next_exponent_value
+          exponent = e
+          exponent_value = 2 ** (e - EXPONENT_BIAS)
+          fraction =
+            ((float_abs / exponent_value - 1) * FRACTION_DENOMINATOR).round
+          break
+        end
+      end
+      [sign, exponent, fraction]
+    end
+
+    def deconstruct_uint16(uint16)
+      # | sign (1 bit) | exponent (5 bit) | fraction (10 bit) |
+      sign = (uint16 >> (EXPONENT_N_BITS + FRACTION_N_BITS))
+      exponent = ((uint16 >> FRACTION_N_BITS) & EXPONENT_MASK)
+      fraction = (uint16 & FRACTION_MASK)
+      [sign, exponent, fraction]
+    end
+  end
+end
diff --git a/ruby/red-arrow/lib/arrow/loader.rb b/ruby/red-arrow/lib/arrow/loader.rb
index 58b11e567f..9c8300628a 100644
--- a/ruby/red-arrow/lib/arrow/loader.rb
+++ b/ruby/red-arrow/lib/arrow/loader.rb
@@ -81,6 +81,9 @@ module Arrow
       require "arrow/fixed-size-binary-array-builder"
       require "arrow/function"
       require "arrow/group"
+      require "arrow/half-float"
+      require "arrow/half-float-array"
+      require "arrow/half-float-array-builder"
       require "arrow/list-array-builder"
       require "arrow/list-data-type"
       require "arrow/map-array"
@@ -196,6 +199,7 @@ module Arrow
            "Arrow::Date64Array",
            "Arrow::Decimal128Array",
            "Arrow::Decimal256Array",
+           "Arrow::HalfFloatArray",
            "Arrow::Time32Array",
            "Arrow::Time64Array",
            "Arrow::TimestampArray"
diff --git a/ruby/red-arrow/test/raw-records/test-basic-arrays.rb b/ruby/red-arrow/test/raw-records/test-basic-arrays.rb
index 0180cb92b4..15cdee6820 100644
--- a/ruby/red-arrow/test/raw-records/test-basic-arrays.rb
+++ b/ruby/red-arrow/test/raw-records/test-basic-arrays.rb
@@ -117,6 +117,16 @@ module RawRecordsBasicArraysTests
     assert_equal(records, target.raw_records)
   end
 
+  def test_half_float
+    records = [
+      [-1.5],
+      [nil],
+      [1.5],
+    ]
+    target = build({column: :half_float}, records)
+    assert_equal(records, target.raw_records)
+  end
+
   def test_float
     records = [
       [-1.0],
diff --git a/ruby/red-arrow/test/test-half-float-array.rb b/ruby/red-arrow/test/test-half-float-array.rb
new file mode 100644
index 0000000000..a13dcea2f9
--- /dev/null
+++ b/ruby/red-arrow/test/test-half-float-array.rb
@@ -0,0 +1,43 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+class HalfFloatArrayTest < Test::Unit::TestCase
+  sub_test_case(".new") do
+    test("Float") do
+      array = Arrow::HalfFloatArray.new([1.5])
+      assert_equal([1.5], array.to_a)
+    end
+
+    test("Integer") do
+      one_half = Arrow::HalfFloat.new(1.5)
+      array = Arrow::HalfFloatArray.new([one_half.to_uint16])
+      assert_equal([one_half.to_f], array.to_a)
+    end
+
+    test("HalfFloat") do
+      one_half = Arrow::HalfFloat.new(1.5)
+      array = Arrow::HalfFloatArray.new([one_half])
+      assert_equal([one_half.to_f], array.to_a)
+    end
+  end
+
+  test("#[]") do
+    one_half = Arrow::HalfFloat.new(1.5)
+    array = Arrow::HalfFloatArray.new([one_half.to_uint16])
+    assert_equal(one_half.to_f, array[0])
+  end
+end
diff --git a/ruby/red-arrow/test/test-half-float.rb b/ruby/red-arrow/test/test-half-float.rb
new file mode 100644
index 0000000000..1b551a0333
--- /dev/null
+++ b/ruby/red-arrow/test/test-half-float.rb
@@ -0,0 +1,130 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+class HalfFloatTest < Test::Unit::TestCase
+  sub_test_case(".new") do
+    test("Array") do
+      positive_infinity = Arrow::HalfFloat.new(0b1, 0b11111, 0b0000000000)
+      assert_equal([0b1, 0b11111, 0b0000000000],
+                   [
+                     positive_infinity.sign,
+                     positive_infinity.exponent,
+                     positive_infinity.fraction,
+                   ])
+    end
+
+    test("Integer - 0") do
+      zero = Arrow::HalfFloat.new(0)
+      assert_equal([0b0, 0b00000, 0b0000000000],
+                   [
+                     zero.sign,
+                     zero.exponent,
+                     zero.fraction,
+                   ])
+    end
+
+    test("Integer - +infinity") do
+      positive_infinity = Arrow::HalfFloat.new(0x7c00)
+      assert_equal([0b0, 0b11111, 0b0000000000],
+                   [
+                     positive_infinity.sign,
+                     positive_infinity.exponent,
+                     positive_infinity.fraction,
+                   ])
+    end
+
+    test("Integer - -infinity") do
+      negative_infinity = Arrow::HalfFloat.new(0xfc00)
+      assert_equal([0b1, 0b11111, 0b0000000000],
+                   [
+                     negative_infinity.sign,
+                     negative_infinity.exponent,
+                     negative_infinity.fraction,
+                   ])
+    end
+
+    test("Integer - 1/3") do
+      one_thirds = Arrow::HalfFloat.new(0x3555)
+      assert_equal([0b0, 0b01101, 0b0101010101],
+                   [
+                     one_thirds.sign,
+                     one_thirds.exponent,
+                     one_thirds.fraction,
+                   ])
+    end
+
+    test("Float - 0") do
+      zero = Arrow::HalfFloat.new(0.0)
+      assert_equal([0b0, 0b00000, 0b0000000000],
+                   [
+                     zero.sign,
+                     zero.exponent,
+                     zero.fraction,
+                   ])
+    end
+
+    test("Float - too large") do
+      positive_infinity = Arrow::HalfFloat.new(65504.1)
+      assert_equal([0b0, 0b11111, 0b0000000000],
+                   [
+                     positive_infinity.sign,
+                     positive_infinity.exponent,
+                     positive_infinity.fraction,
+                   ])
+    end
+
+    test("Float - +infinity") do
+      positive_infinity = Arrow::HalfFloat.new(Float::INFINITY)
+      assert_equal([0b0, 0b11111, 0b0000000000],
+                   [
+                     positive_infinity.sign,
+                     positive_infinity.exponent,
+                     positive_infinity.fraction,
+                   ])
+    end
+
+    test("Float - too small") do
+      negative_infinity = Arrow::HalfFloat.new(-65504.1)
+      assert_equal([0b1, 0b11111, 0b0000000000],
+                   [
+                     negative_infinity.sign,
+                     negative_infinity.exponent,
+                     negative_infinity.fraction,
+                   ])
+    end
+
+    test("Float - -infinity") do
+      negative_infinity = Arrow::HalfFloat.new(-Float::INFINITY)
+      assert_equal([0b1, 0b11111, 0b0000000000],
+                   [
+                     negative_infinity.sign,
+                     negative_infinity.exponent,
+                     negative_infinity.fraction,
+                   ])
+    end
+
+    test("Float - 1/3") do
+      one_thirds = Arrow::HalfFloat.new((2 ** -2) * (1 + 341 / 1024.0))
+      assert_equal([0b0, 0b01101, 0b0101010101],
+                   [
+                     one_thirds.sign,
+                     one_thirds.exponent,
+                     one_thirds.fraction,
+                   ])
+    end
+  end
+end
diff --git a/ruby/red-arrow/test/values/test-basic-arrays.rb b/ruby/red-arrow/test/values/test-basic-arrays.rb
index 237385fa7b..ae469d1bf0 100644
--- a/ruby/red-arrow/test/values/test-basic-arrays.rb
+++ b/ruby/red-arrow/test/values/test-basic-arrays.rb
@@ -107,6 +107,16 @@ module ValuesBasicArraysTests
     assert_equal(values, target.values)
   end
 
+  def test_half_float
+    values = [
+      -1.5,
+      nil,
+      1.5,
+    ]
+    target = build(Arrow::HalfFloatArray.new(values))
+    assert_equal(values, target.values)
+  end
+
   def test_float
     values = [
       -1.0,