You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@arrow.apache.org by ju...@apache.org on 2016/12/05 22:57:01 UTC
arrow git commit: ARROW-401: Floating point vectors should do an approximate comparison…
Repository: arrow
Updated Branches:
refs/heads/master 0ac01a5bf -> 599d516a7
ARROW-401: Floating point vectors should do an approximate comparison…
… in integration tests
Author: Julien Le Dem <ju...@dremio.com>
Closes #223 from julienledem/arrow_401 and squashes the following commits:
a9ee84d [Julien Le Dem] review feedback
da64ca0 [Julien Le Dem] ARROW-401: Floating point vectors should do an approximate comparison in integration tests
Project: http://git-wip-us.apache.org/repos/asf/arrow/repo
Commit: http://git-wip-us.apache.org/repos/asf/arrow/commit/599d516a
Tree: http://git-wip-us.apache.org/repos/asf/arrow/tree/599d516a
Diff: http://git-wip-us.apache.org/repos/asf/arrow/diff/599d516a
Branch: refs/heads/master
Commit: 599d516a7306de4d1f9d7e0ddc888f13026efd49
Parents: 0ac01a5
Author: Julien Le Dem <ju...@dremio.com>
Authored: Mon Dec 5 14:56:56 2016 -0800
Committer: Julien Le Dem <ju...@dremio.com>
Committed: Mon Dec 5 14:56:56 2016 -0800
----------------------------------------------------------------------
.../org/apache/arrow/tools/Integration.java | 51 +++++++++++-
.../org/apache/arrow/tools/TestIntegration.java | 84 +++++++++++++++++++-
2 files changed, 130 insertions(+), 5 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/arrow/blob/599d516a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
----------------------------------------------------------------------
diff --git a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
index 85af30d..fd835a6 100644
--- a/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
+++ b/java/tools/src/main/java/org/apache/arrow/tools/Integration.java
@@ -39,6 +39,8 @@ import org.apache.arrow.vector.file.ArrowWriter;
import org.apache.arrow.vector.file.json.JsonFileReader;
import org.apache.arrow.vector.file.json.JsonFileWriter;
import org.apache.arrow.vector.schema.ArrowRecordBatch;
+import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.FloatingPoint;
import org.apache.arrow.vector.types.pojo.Field;
import org.apache.arrow.vector.types.pojo.Schema;
import org.apache.commons.cli.CommandLine;
@@ -247,7 +249,7 @@ public class Integration {
for (int j = 0; j < valueCount; j++) {
Object arrow = arrowVector.getAccessor().getObject(j);
Object json = jsonVector.getAccessor().getObject(j);
- if (!Objects.equal(arrow, json)) {
+ if (!equals(field.getType(), arrow, json)) {
throw new IllegalArgumentException(
"Different values in column:\n" + field + " at index " + j + ": " + arrow + " != " + json);
}
@@ -255,6 +257,53 @@ public class Integration {
}
}
+ private static boolean equals(ArrowType type, final Object arrow, final Object json) {
+ if (type instanceof ArrowType.FloatingPoint) {
+ FloatingPoint fpType = (FloatingPoint) type;
+ switch (fpType.getPrecision()) {
+ case DOUBLE:
+ return equalEnough((Double)arrow, (Double)json);
+ case SINGLE:
+ return equalEnough((Float)arrow, (Float)json);
+ case HALF:
+ default:
+ throw new UnsupportedOperationException("unsupported precision: " + fpType);
+ }
+ }
+ return Objects.equal(arrow, json);
+ }
+
+ static boolean equalEnough(Float f1, Float f2) {
+ if (f1 == null || f2 == null) {
+ return f1 == null && f2 == null;
+ }
+ if (f1.isNaN()) {
+ return f2.isNaN();
+ }
+ if (f1.isInfinite()) {
+ return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
+ }
+ float average = Math.abs((f1 + f2) / 2);
+ float differenceScaled = Math.abs(f1 - f2) / (average == 0.0f ? 1f : average);
+ return differenceScaled < 1.0E-6f;
+ }
+
+ static boolean equalEnough(Double f1, Double f2) {
+ if (f1 == null || f2 == null) {
+ return f1 == null && f2 == null;
+ }
+ if (f1.isNaN()) {
+ return f2.isNaN();
+ }
+ if (f1.isInfinite()) {
+ return f2.isInfinite() && Math.signum(f1) == Math.signum(f2);
+ }
+ double average = Math.abs((f1 + f2) / 2);
+ double differenceScaled = Math.abs(f1 - f2) / (average == 0.0d ? 1d : average);
+ return differenceScaled < 1.0E-12d;
+ }
+
+
private static void compareSchemas(Schema jsonSchema, Schema arrowSchema) {
if (!arrowSchema.equals(jsonSchema)) {
throw new IllegalArgumentException("Different schemas:\n" + arrowSchema + "\n" + jsonSchema);
http://git-wip-us.apache.org/repos/asf/arrow/blob/599d516a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
----------------------------------------------------------------------
diff --git a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
index 464144b..ee6196b 100644
--- a/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
+++ b/java/tools/src/test/java/org/apache/arrow/tools/TestIntegration.java
@@ -22,6 +22,10 @@ import static org.apache.arrow.tools.ArrowFileTestFixtures.validateOutput;
import static org.apache.arrow.tools.ArrowFileTestFixtures.write;
import static org.apache.arrow.tools.ArrowFileTestFixtures.writeData;
import static org.apache.arrow.tools.ArrowFileTestFixtures.writeInput;
+import static org.apache.arrow.tools.Integration.equalEnough;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.BufferedReader;
@@ -39,9 +43,9 @@ import org.apache.arrow.vector.complex.impl.ComplexWriterImpl;
import org.apache.arrow.vector.complex.writer.BaseWriter.ComplexWriter;
import org.apache.arrow.vector.complex.writer.BaseWriter.MapWriter;
import org.apache.arrow.vector.complex.writer.BigIntWriter;
+import org.apache.arrow.vector.complex.writer.Float8Writer;
import org.apache.arrow.vector.complex.writer.IntWriter;
import org.junit.After;
-import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
@@ -121,7 +125,7 @@ public class TestIntegration {
String i, o;
int j = 0;
while ((i = orig.readLine()) != null && (o = rt.readLine()) != null) {
- Assert.assertEquals("line: " + j, i, o);
+ assertEquals("line: " + j, i, o);
++j;
}
}
@@ -142,6 +146,33 @@ public class TestIntegration {
}
+ /**
+ * the test should not be sensitive to small variations in float representation
+ */
+ @Test
+ public void testFloat() throws Exception {
+ File testValidInFile = testFolder.newFile("testValidFloatIn.arrow");
+ File testInvalidInFile = testFolder.newFile("testAlsoValidFloatIn.arrow");
+ File testJSONFile = testFolder.newFile("testValidOut.json");
+ testJSONFile.delete();
+
+ // generate an arrow file
+ writeInputFloat(testValidInFile, allocator, 912.4140000000002, 912.414);
+ // generate a different arrow file
+ writeInputFloat(testInvalidInFile, allocator, 912.414, 912.4140000000002);
+
+ Integration integration = new Integration();
+
+ // convert the "valid" file to json
+ String[] args1 = { "-arrow", testValidInFile.getAbsolutePath(), "-json", testJSONFile.getAbsolutePath(), "-command", Command.ARROW_TO_JSON.name()};
+ integration.run(args1);
+
+ // compare the "invalid" file to the "valid" json
+ String[] args3 = { "-arrow", testInvalidInFile.getAbsolutePath(), "-json", testJSONFile.getAbsolutePath(), "-command", Command.VALIDATE.name()};
+ // this should fail
+ integration.run(args3);
+ }
+
@Test
public void testInvalid() throws Exception {
File testValidInFile = testFolder.newFile("testValidIn.arrow");
@@ -167,12 +198,28 @@ public class TestIntegration {
integration.run(args3);
fail("should have failed");
} catch (IllegalArgumentException e) {
- Assert.assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
- Assert.assertTrue(e.getMessage(), e.getMessage().contains("999"));
+ assertTrue(e.getMessage(), e.getMessage().contains("Different values in column"));
+ assertTrue(e.getMessage(), e.getMessage().contains("999"));
}
}
+ static void writeInputFloat(File testInFile, BufferAllocator allocator, double... f) throws FileNotFoundException, IOException {
+ try (
+ BufferAllocator vectorAllocator = allocator.newChildAllocator("original vectors", 0, Integer.MAX_VALUE);
+ MapVector parent = new MapVector("parent", vectorAllocator, null)) {
+ ComplexWriter writer = new ComplexWriterImpl("root", parent);
+ MapWriter rootWriter = writer.rootAsMap();
+ Float8Writer floatWriter = rootWriter.float8("float");
+ for (int i = 0; i < f.length; i++) {
+ floatWriter.setPosition(i);
+ floatWriter.writeFloat8(f[i]);
+ }
+ writer.setValueCount(f.length);
+ write(parent.getChild("root"), testInFile);
+ }
+ }
+
static void writeInput2(File testInFile, BufferAllocator allocator) throws FileNotFoundException, IOException {
int count = ArrowFileTestFixtures.COUNT;
try (
@@ -192,4 +239,33 @@ public class TestIntegration {
}
}
+ @Test
+ public void testFloatComp() {
+ assertTrue(equalEnough(912.4140000000002F, 912.414F));
+ assertTrue(equalEnough(912.4140000000002D, 912.414D));
+ assertTrue(equalEnough(912.414F, 912.4140000000002F));
+ assertTrue(equalEnough(912.414D, 912.4140000000002D));
+ assertFalse(equalEnough(912.414D, 912.4140001D));
+ assertFalse(equalEnough(null, 912.414D));
+ assertTrue(equalEnough((Float)null, null));
+ assertTrue(equalEnough((Double)null, null));
+ assertFalse(equalEnough(912.414D, null));
+ assertFalse(equalEnough(Double.MAX_VALUE, Double.MIN_VALUE));
+ assertFalse(equalEnough(Double.MIN_VALUE, Double.MAX_VALUE));
+ assertTrue(equalEnough(Double.MAX_VALUE, Double.MAX_VALUE));
+ assertTrue(equalEnough(Double.MIN_VALUE, Double.MIN_VALUE));
+ assertTrue(equalEnough(Double.NEGATIVE_INFINITY, Double.NEGATIVE_INFINITY));
+ assertFalse(equalEnough(Double.NEGATIVE_INFINITY, Double.POSITIVE_INFINITY));
+ assertTrue(equalEnough(Double.NaN, Double.NaN));
+ assertFalse(equalEnough(1.0, Double.NaN));
+ assertFalse(equalEnough(Float.MAX_VALUE, Float.MIN_VALUE));
+ assertFalse(equalEnough(Float.MIN_VALUE, Float.MAX_VALUE));
+ assertTrue(equalEnough(Float.MAX_VALUE, Float.MAX_VALUE));
+ assertTrue(equalEnough(Float.MIN_VALUE, Float.MIN_VALUE));
+ assertTrue(equalEnough(Float.NEGATIVE_INFINITY, Float.NEGATIVE_INFINITY));
+ assertFalse(equalEnough(Float.NEGATIVE_INFINITY, Float.POSITIVE_INFINITY));
+ assertTrue(equalEnough(Float.NaN, Float.NaN));
+ assertFalse(equalEnough(1.0F, Float.NaN));
+ }
+
}