You are viewing a plain-text version of this content. The canonical link to it is here.
Posted to commits@arrow.apache.org by ra...@apache.org on 2019/04/30 14:09:09 UTC

[arrow] branch master updated: ARROW-5243: [Java][Gandiva] Add decimal compare tests

This is an automated email from the ASF dual-hosted git repository.

ravindra pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow.git


The following commit(s) were added to refs/heads/master by this push:
     new 782a63d  ARROW-5243: [Java][Gandiva] Add decimal compare tests
782a63d is described below

commit 782a63d0e71d122dadbad97091d96f1fe3c0e06d
Author: Pindikura Ravindra <ra...@dremio.com>
AuthorDate: Tue Apr 30 19:38:30 2019 +0530

    ARROW-5243: [Java][Gandiva] Add decimal compare tests
    
    Author: Pindikura Ravindra <ra...@dremio.com>
    
    Closes #4227 from pravindra/arrow-5243 and squashes the following commits:
    
    4f0bd606 <Pindikura Ravindra> ARROW-5243: Add decimal compare tests
---
 cpp/src/gandiva/decimal_xlarge.cc                  |  2 +-
 .../gandiva/evaluator/ProjectorDecimalTest.java    | 95 ++++++++++++++++++++++
 2 files changed, 96 insertions(+), 1 deletion(-)

diff --git a/cpp/src/gandiva/decimal_xlarge.cc b/cpp/src/gandiva/decimal_xlarge.cc
index 60917ed..392c14c 100644
--- a/cpp/src/gandiva/decimal_xlarge.cc
+++ b/cpp/src/gandiva/decimal_xlarge.cc
@@ -92,7 +92,7 @@ void ExportedDecimalFunctions::AddMappings(Engine* engine) const {
           types->i32_type()};  // int32_t y_scale
 
   engine->AddGlobalMappingForFunc("gdv_xlarge_compare", types->i32_type() /*return_type*/,
-                                  args, reinterpret_cast<void*>(gdv_xlarge_mod));
+                                  args, reinterpret_cast<void*>(gdv_xlarge_compare));
 }
 
 }  // namespace gandiva
diff --git a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
index e4a7cc3..5dc36c0 100644
--- a/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
+++ b/java/gandiva/src/test/java/org/apache/arrow/gandiva/evaluator/ProjectorDecimalTest.java
@@ -18,22 +18,28 @@
 package org.apache.arrow.gandiva.evaluator;
 
 
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertFalse;
 import static org.junit.Assert.assertTrue;
 
 import java.math.BigDecimal;
 import java.util.ArrayList;
+import java.util.Arrays;
 import java.util.List;
 
 import org.apache.arrow.gandiva.exceptions.GandivaException;
 import org.apache.arrow.gandiva.expression.ExpressionTree;
 import org.apache.arrow.gandiva.expression.TreeBuilder;
 import org.apache.arrow.gandiva.expression.TreeNode;
+import org.apache.arrow.vector.BitVector;
 import org.apache.arrow.vector.DecimalVector;
 import org.apache.arrow.vector.ValueVector;
 import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
 import org.apache.arrow.vector.ipc.message.ArrowRecordBatch;
 import org.apache.arrow.vector.types.pojo.ArrowType;
+import org.apache.arrow.vector.types.pojo.ArrowType.Bool;
+import org.apache.arrow.vector.types.pojo.ArrowType.Decimal;
 import org.apache.arrow.vector.types.pojo.Field;
 import org.apache.arrow.vector.types.pojo.Schema;
 import org.junit.Test;
@@ -213,4 +219,93 @@ public class ProjectorDecimalTest extends org.apache.arrow.gandiva.evaluator.Bas
     releaseValueVectors(output);
     eval.close();
   }
+
+  @Test
+  public void testCompare() throws GandivaException {
+    Decimal aType = new Decimal(38, 3);
+    Decimal bType = new Decimal(38, 2);
+    Field a = Field.nullable("a", aType);
+    Field b = Field.nullable("b", bType);
+    List<Field> args = Lists.newArrayList(a, b);
+
+    List<ExpressionTree> exprs = new ArrayList<>(
+        Arrays.asList(
+            TreeBuilder.makeExpression("equal", args, Field.nullable("eq", boolType)),
+            TreeBuilder.makeExpression("not_equal", args, Field.nullable("ne", boolType)),
+            TreeBuilder.makeExpression("less_than", args, Field.nullable("lt", boolType)),
+            TreeBuilder.makeExpression("less_than_or_equal_to", args, Field.nullable("le", boolType)),
+            TreeBuilder.makeExpression("greater_than", args, Field.nullable("gt", boolType)),
+            TreeBuilder.makeExpression("greater_than_or_equal_to", args, Field.nullable("ge", boolType))
+        )
+    );
+
+    Schema schema = new Schema(args);
+    Projector eval = Projector.make(schema, exprs);
+
+    List<ValueVector> output = null;
+    ArrowRecordBatch batch = null;
+    try {
+      int numRows = 4;
+      String[] aValues = new String[]{"7.620", "2.380", "3.860", "-18.160"};
+      String[] bValues = new String[]{"7.62", "3.50", "1.90", "-1.45"};
+
+      DecimalVector valuesa = decimalVector(aValues, aType.getPrecision(), aType.getScale());
+      DecimalVector valuesb = decimalVector(bValues, bType.getPrecision(), bType.getScale());
+      batch =
+          new ArrowRecordBatch(
+              numRows,
+              Lists.newArrayList(new ArrowFieldNode(numRows, 0), new ArrowFieldNode(numRows, 0)),
+              Lists.newArrayList(valuesa.getValidityBuffer(), valuesa.getDataBuffer(),
+                  valuesb.getValidityBuffer(), valuesb.getDataBuffer()));
+
+      // expected results.
+      boolean[][] expected = {
+          {true, false, false, false}, // eq
+          {false, true, true, true}, // ne
+          {false, true, false, true}, // lt
+          {true, true, false, true}, // le
+          {false, false, true, false}, // gt
+          {true, false, true, false}, // ge
+      };
+
+      // Allocate output vectors.
+      output = new ArrayList<>(
+          Arrays.asList(
+              new BitVector("eq", allocator),
+              new BitVector("ne", allocator),
+              new BitVector("lt", allocator),
+              new BitVector("le", allocator),
+              new BitVector("gt", allocator),
+              new BitVector("ge", allocator)
+          )
+      );
+      for (ValueVector v : output) {
+        v.allocateNew();
+      }
+
+      // evaluate expressions.
+      eval.evaluate(batch, output);
+
+      // compare the outputs.
+      for (int idx = 0; idx < output.size(); ++idx) {
+        boolean[] expectedArray = expected[idx];
+        BitVector resultVector = (BitVector) output.get(idx);
+
+        for (int i = 0; i < numRows; i++) {
+          assertFalse(resultVector.isNull(i));
+          assertEquals("mismatch in result for expr at idx " + idx + " for row " + i,
+              expectedArray[i], resultVector.getObject(i).booleanValue());
+        }
+      }
+    } finally {
+      // free buffers
+      if (batch != null) {
+        releaseRecordBatch(batch);
+      }
+      if (output != null) {
+        releaseValueVectors(output);
+      }
+      eval.close();
+    }
+  }
 }