You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ju...@apache.org on 2016/12/05 22:57:01 UTC

arrow git commit: ARROW-401: Floating point vectors should do an approximate comparison…

Repository: arrow
Updated Branches:
  refs/heads/master 0ac01a5bf -> 599d516a7


ARROW-401: Floating point vectors should do an approximate comparison…

… in integration tests

Author: Julien Le Dem <ju...@dremio.com>

Closes #223 from julienledem/arrow_401 and squashes the following commits:

a9ee84d [Julien Le Dem] review feedback
da64ca0 [Julien Le Dem] ARROW-401: Floating point vectors should do an approximate comparison in integration tests


Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/599d516a
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/599d516a
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/599d516a

Branch: refs/heads/master
Commit: 599d516a7306de4d1f9d7e0ddc888f13026efd49
Parents: 0ac01a5
Author: Julien Le Dem <ju...@dremio.com>
Authored: Mon Dec 5 14:56:56 2016 -0800
Committer: Julien Le Dem <ju...@dremio.com>
Committed: Mon Dec 5 14:56:56 2016 -0800

----------------------------------------------------------------------
 .../org/apache/arrow/tools/Integration.java     | 51 +++++++++++-
 .../org/apache/arrow/tools/TestIntegration.java | 84 +++++++++++++++++++-
 2 files changed, 130 insertions(+), 5 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/arrow/blob/599d516a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
----------------------------------------------------------------------
diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
index 85af30d..fd835a6 100644
--- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
+++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
@@ -39,6 +39,8 @@ import org.apache.arrow.vector.file.ArrowWriter;
 import org.apache.arrow.vector.file.json.JsonFileReader;
 import org.apache.arrow.vector.file.json.JsonFileWriter;
 import org.apache.arrow.vector.schema.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.apache.commons.cli.CommandLine;
@@ -247,7 +249,7 @@ public class Integration {
       for (int j = 0; j < valueCount; j++) {
         Object arrow = arrowVector.getAccessor().getObject(j);
         Object json = jsonVector.getAccessor().getObject(j);
-        if (!Objects.equal(arrow, json)) {
+        if (!equals(field.getType(), arrow, json)) {
           throw new IllegalArgumentException(
               "Different values in column:\n" + field + " at index " + j + ": " + arrow + " != " + json);
         }
@@ -255,6 +257,53 @@ public class Integration {
     }
   }
 
+  /**
+   * Compares one value read from the Arrow vector with the matching value read from JSON.
+   * Floating point columns are compared approximately through {@code equalEnough};
+   * every other type is compared with {@code Objects.equal}.
+   *
+   * @param type  the type of the field both values belong to
+   * @param arrow the value obtained from the Arrow vector (may be null)
+   * @param json  the value obtained from the JSON vector (may be null)
+   * @return true when the two values are considered equal for validation purposes
+   * @throws UnsupportedOperationException for a floating point precision other than
+   *         SINGLE or DOUBLE (this includes HALF)
+   */
+  private static boolean equals(ArrowType type, final Object arrow, final Object json) {
+    // Anything that is not floating point keeps the exact comparison.
+    if (!(type instanceof ArrowType.FloatingPoint)) {
+      return Objects.equal(arrow, json);
+    }
+    final FloatingPoint floatingPoint = (FloatingPoint) type;
+    switch (floatingPoint.getPrecision()) {
+    case SINGLE:
+      return equalEnough((Float) arrow, (Float) json);
+    case DOUBLE:
+      return equalEnough((Double) arrow, (Double) json);
+    default:
+      // HALF ends up here as well: there is no approximate comparison for it.
+      throw new UnsupportedOperationException("unsupported precision: " + floatingPoint);
+    }
+  }
+
+  /**
+   * Null-safe approximate equality for single precision values.
+   * Two nulls are equal, NaN only equals NaN, and an infinity only equals an infinity
+   * of the same sign. Otherwise the absolute difference is divided by the magnitude of
+   * the mean of the two values (or left unscaled when that mean is zero) and must be
+   * below 1.0E-6.
+   *
+   * @param f1 first value, may be null
+   * @param f2 second value, may be null
+   * @return true when the values are equal within the tolerance
+   */
+  static boolean equalEnough(Float f1, Float f2) {
+    if (f1 == null || f2 == null) {
+      return f1 == null && f2 == null;
+    }
+    if (f1.isNaN()) {
+      return f2.isNaN();
+    }
+    if (f1.isInfinite()) {
+      return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
+    }
+    // Halve each operand before adding: f1 + f2 can overflow to Infinity for large
+    // same-sign values, which would scale any finite difference down to 0 and
+    // wrongly report the values as equal.
+    float average = Math.abs(f1 / 2 + f2 / 2);
+    // A zero mean cannot be used as a divisor; fall back to the absolute difference.
+    float differenceScaled = Math.abs(f1 - f2) / (average == 0.0f ? 1f : average);
+    return differenceScaled < 1.0E-6f;
+  }
+
+  /**
+   * Null-safe approximate equality for double precision values.
+   * Two nulls are equal, NaN only equals NaN, and an infinity only equals an infinity
+   * of the same sign. Otherwise the absolute difference is divided by the magnitude of
+   * the mean of the two values (or left unscaled when that mean is zero) and must be
+   * below 1.0E-12.
+   *
+   * @param f1 first value, may be null
+   * @param f2 second value, may be null
+   * @return true when the values are equal within the tolerance
+   */
+  static boolean equalEnough(Double f1, Double f2) {
+    if (f1 == null || f2 == null) {
+      return f1 == null && f2 == null;
+    }
+    if (f1.isNaN()) {
+      return f2.isNaN();
+    }
+    if (f1.isInfinite()) {
+      return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
+    }
+    // Halve each operand before adding: f1 + f2 can overflow to Infinity for large
+    // same-sign values, which would scale any finite difference down to 0 and
+    // wrongly report the values as equal.
+    double average = Math.abs(f1 / 2 + f2 / 2);
+    // A zero mean cannot be used as a divisor; fall back to the absolute difference.
+    double differenceScaled = Math.abs(f1 - f2) / (average == 0.0d ? 1d : average);
+    return differenceScaled < 1.0E-12d;
+  }
+
+
   private static void compareSchemas(Schema jsonSchema, Schema arrowSchema) {
     if (!arrowSchema.equals(jsonSchema)) {
       throw new IllegalArgumentException("Different schemas:\n" + arrowSchema + "\n" + jsonSchema);

http://git-wip-us.apache.org/repos/asf/arrow/blob/599d516a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
----------------------------------------------------------------------
diff --git a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
index 464144b..ee6196b 100644
--- a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
+++ b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
@@ -22,6 +22,10 @@ import static org.apache.arrow.tools.ArrowFileTestFixtures.validateOutput;
 import static org.apache.arrow.tools.ArrowFileTestFixtures.write;
 import static org.apache.arrow.tools.ArrowFileTestFixtures.writeData;
 import static org.apache.arrow.tools.ArrowFileTestFixtures.writeInput;
+import static org.apache.arrow.tools.Integration.equalEnough;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
 import static org.junit.Assert.fail;
 
 import java.io.BufferedReader;
@@ -39,9 +43,9 @@ import org.apache.arrow.vector.complex.impl.ComplexWriterImpl;
 import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter;
 import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter;
 import org.apache.arrow.vector.complex.writer.BigIntWriter;
+import org.apache.arrow.vector.complex.writer.Float8Writer;
 import org.apache.arrow.vector.complex.writer.IntWriter;
 import org.junit.After;
-import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Rule;
 import org.junit.Test;
@@ -121,7 +125,7 @@ public class TestIntegration {
     String i, o;
     int j = 0;
     while ((i = orig.readLine()) != null && (o = rt.readLine()) != null) {
-      Assert.assertEquals("line: " + j, i, o);
+      assertEquals("line: " + j, i, o);
       ++j;
     }
   }
@@ -142,6 +146,33 @@ public class TestIntegration {
   }
 
 
+  /**
+   * Validation should not be sensitive to small variations in float representation.
+   * Two Arrow files are written holding 912.414 and 912.4140000000002 in swapped
+   * positions; the first is converted to JSON and the second is then validated against
+   * that JSON. No assertion is made beyond both {@code run} calls completing without
+   * throwing.
+   */
+  @Test
+  public void testFloat() throws Exception {
+    File testValidInFile = testFolder.newFile("testValidFloatIn.arrow");
+    File testAlsoValidInFile = testFolder.newFile("testAlsoValidFloatIn.arrow");
+    File testJSONFile = testFolder.newFile("testValidOut.json");
+    // NOTE(review): presumably removed so the ARROW_TO_JSON step creates it itself - confirm
+    testJSONFile.delete();
+
+    // generate an arrow file
+    writeInputFloat(testValidInFile, allocator, 912.4140000000002, 912.414);
+    // generate a second arrow file whose values differ from the first only in the last digits
+    writeInputFloat(testAlsoValidInFile, allocator, 912.414, 912.4140000000002);
+
+    Integration integration = new Integration();
+
+    // convert the first file to json
+    String[] args1 = { "-arrow", testValidInFile.getAbsolutePath(), "-json",  testJSONFile.getAbsolutePath(), "-command", Command.ARROW_TO_JSON.name()};
+    integration.run(args1);
+
+    // validate the second file against the json produced from the first
+    String[] args2 = { "-arrow", testAlsoValidInFile.getAbsolutePath(), "-json",  testJSONFile.getAbsolutePath(), "-command", Command.VALIDATE.name()};
+    // this must succeed: the values are equal within the floating point tolerance
+    integration.run(args2);
+  }
+
   @Test
   public void testInvalid() throws Exception {
     File testValidInFile = testFolder.newFile("testValidIn.arrow");
@@ -167,12 +198,28 @@ public class TestIntegration {
       integration.run(args3);
       fail("should have failed");
     } catch (IllegalArgumentException e) {
-      Assert.assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
-      Assert.assertTrue(e.getMessage(), e.getMessage().contains("999"));
+      assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
+      assertTrue(e.getMessage(), e.getMessage().contains("999"));
     }
 
   }
 
+  /**
+   * Writes the given doubles, in order, into a float8 column named "float" under a
+   * "root" map and hands the resulting vector to {@code write} together with the
+   * target file.
+   *
+   * @param testInFile the file passed on to {@code write}
+   * @param allocator  parent allocator; a child allocator is created from it and closed on exit
+   * @param f          the values to store, one per row
+   */
+  static void writeInputFloat(File testInFile, BufferAllocator allocator, double... f) throws FileNotFoundException, IOException {
+    try (
+        BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE);
+        MapVector parent = new MapVector("parent", vectorAllocator, null)) {
+      ComplexWriter complexWriter = new ComplexWriterImpl("root", parent);
+      Float8Writer float8Writer = complexWriter.rootAsMap().float8("float");
+      int position = 0;
+      for (double value : f) {
+        // each value goes to its own row, in argument order
+        float8Writer.setPosition(position++);
+        float8Writer.writeFloat8(value);
+      }
+      complexWriter.setValueCount(f.length);
+      write(parent.getChild("root"), testInFile);
+    }
+  }
+
   static void writeInput2(File testInFile, BufferAllocator allocator) throws FileNotFoundException, IOException {
     int count = ArrowFileTestFixtures.COUNT;
     try (
@@ -192,4 +239,33 @@ public class TestIntegration {
     }
   }
 
+  /**
+   * Exercises both {@code equalEnough} overloads: rounding noise, real differences,
+   * nulls, extreme magnitudes, infinities and NaN.
+   */
+  @Test
+  public void testFloatComp() {
+    // differences in the last digits are tolerated, in either argument order
+    assertTrue(equalEnough(912.4140000000002F, 912.414F));
+    assertTrue(equalEnough(912.414F, 912.4140000000002F));
+    assertTrue(equalEnough(912.4140000000002D, 912.414D));
+    assertTrue(equalEnough(912.414D, 912.4140000000002D));
+
+    // a larger difference is not
+    assertFalse(equalEnough(912.414D, 912.4140001D));
+
+    // null only matches null
+    assertTrue(equalEnough((Float)null, null));
+    assertTrue(equalEnough((Double)null, null));
+    assertFalse(equalEnough(null, 912.414D));
+    assertFalse(equalEnough(912.414D, null));
+
+    // double: extremes, infinities, NaN
+    assertTrue(equalEnough(Double.MAX_VALUE, Double.MAX_VALUE));
+    assertTrue(equalEnough(Double.MIN_VALUE, Double.MIN_VALUE));
+    assertFalse(equalEnough(Double.MAX_VALUE, Double.MIN_VALUE));
+    assertFalse(equalEnough(Double.MIN_VALUE, Double.MAX_VALUE));
+    assertTrue(equalEnough(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY));
+    assertFalse(equalEnough(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY));
+    assertTrue(equalEnough(Double.NaN, Double.NaN));
+    assertFalse(equalEnough(1.0, Double.NaN));
+
+    // float: extremes, infinities, NaN
+    assertTrue(equalEnough(Float.MAX_VALUE, Float.MAX_VALUE));
+    assertTrue(equalEnough(Float.MIN_VALUE, Float.MIN_VALUE));
+    assertFalse(equalEnough(Float.MAX_VALUE, Float.MIN_VALUE));
+    assertFalse(equalEnough(Float.MIN_VALUE, Float.MAX_VALUE));
+    assertTrue(equalEnough(Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY));
+    assertFalse(equalEnough(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY));
+    assertTrue(equalEnough(Float.NaN, Float.NaN));
+    assertFalse(equalEnough(1.0F, Float.NaN));
+  }
+
 }