You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by ab...@apache.org on 2021/12/15 05:15:10 UTC

[druid] branch master updated: Fix incorrect type conversion in DruidLogicalValueRule (#11923)

This is an automated email from the ASF dual-hosted git repository.

abhishek pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/druid.git


The following commit(s) were added to refs/heads/master by this push:
     new 16642fb  Fix incorrect type conversion in DruidLogicalValueRule (#11923)
16642fb is described below

commit 16642fb2780b11ade62c8d46d3dd7f90e424add4
Author: Laksh Singla <30...@users.noreply.github.com>
AuthorDate: Wed Dec 15 10:44:35 2021 +0530

    Fix incorrect type conversion in DruidLogicalValueRule (#11923)
    
    When transforming to a DruidRel, DruidLogicalValuesRule can return incorrect values if the literal was created from a float value. The BigDecimal representation stores 123.0, and converting it with RexLiteral's getValueAs() method appears to return the unscaled value (1230) instead of 123. It is unclear whether this is intentional on Calcite's part, or whether the fix belongs somewhere else.
    
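    In DruidLogicalValuesRule, extract numeric values from the RexLiteral via RexLiteral.value() and Number.floatValue()/doubleValue()/longValue() (e.g. BigDecimal.longValue() for INT/LONG types) instead of getValueAs(). Null literals now return null.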
---
 .../sql/calcite/rule/DruidLogicalValuesRule.java   | 18 ++++++++++--
 .../druid/sql/calcite/CalciteSelectQueryTest.java  | 32 +++++++++++++++++++++
 .../calcite/rule/DruidLogicalValuesRuleTest.java   | 33 ++++++++++++++++++----
 3 files changed, 75 insertions(+), 8 deletions(-)

diff --git a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRule.java b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRule.java
index df8c516..845c460 100644
--- a/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRule.java
+++ b/sql/src/main/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRule.java
@@ -34,6 +34,7 @@ import org.apache.druid.sql.calcite.rel.DruidQueryRel;
 import org.apache.druid.sql.calcite.table.DruidTable;
 import org.apache.druid.sql.calcite.table.RowSignatures;
 
+import javax.annotation.Nullable;
 import java.util.List;
 import java.util.stream.Collectors;
 
@@ -93,24 +94,35 @@ public class DruidLogicalValuesRule extends RelOptRule
    *
    * @throws IllegalArgumentException for unsupported types
    */
+  @Nullable
   @VisibleForTesting
   static Object getValueFromLiteral(RexLiteral literal, PlannerContext plannerContext)
   {
     switch (literal.getType().getSqlTypeName()) {
       case CHAR:
       case VARCHAR:
+        // RexLiteral.stringValue(literal) was causing some issue during tests
         return literal.getValueAs(String.class);
       case FLOAT:
-        return literal.getValueAs(Float.class);
+        if (literal.isNull()) {
+          return null;
+        }
+        return ((Number) RexLiteral.value(literal)).floatValue();
       case DOUBLE:
       case REAL:
       case DECIMAL:
-        return literal.getValueAs(Double.class);
+        if (literal.isNull()) {
+          return null;
+        }
+        return ((Number) RexLiteral.value(literal)).doubleValue();
       case TINYINT:
       case SMALLINT:
       case INTEGER:
       case BIGINT:
-        return literal.getValueAs(Long.class);
+        if (literal.isNull()) {
+          return null;
+        }
+        return ((Number) RexLiteral.value(literal)).longValue();
       case BOOLEAN:
         return literal.isAlwaysTrue() ? 1L : 0L;
       case TIMESTAMP:
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
index a4aaab8..86af750 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/CalciteSelectQueryTest.java
@@ -36,6 +36,7 @@ import org.apache.druid.query.extraction.SubstringDimExtractionFn;
 import org.apache.druid.query.groupby.GroupByQuery;
 import org.apache.druid.query.ordering.StringComparators;
 import org.apache.druid.query.scan.ScanQuery;
+import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
 import org.apache.druid.query.topn.DimensionTopNMetricSpec;
 import org.apache.druid.query.topn.InvertedTopNMetricSpec;
 import org.apache.druid.query.topn.TopNQueryBuilder;
@@ -166,6 +167,37 @@ public class CalciteSelectQueryTest extends BaseCalciteQueryTest
     );
   }
 
+  // Test that the integers are getting correctly casted after being passed through a function when not selecting from
+  // a table
+  @Test
+  public void testDruidLogicalValuesRule() throws Exception
+  {
+    testQuery(
+        "SELECT FLOOR(123), CEIL(123), CAST(123.0 AS INTEGER)",
+        ImmutableList.of(
+            newScanQueryBuilder()
+                .dataSource(InlineDataSource.fromIterable(
+                    ImmutableList.of(new Object[]{123L, 123L, 123L}),
+                    RowSignature.builder()
+                                .add("EXPR$0", ColumnType.LONG)
+                                .add("EXPR$1", ColumnType.LONG)
+                                .add("EXPR$2", ColumnType.LONG)
+                                .build()
+                ))
+                .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Intervals.ETERNITY)))
+                .columns(ImmutableList.of("EXPR$0", "EXPR$1", "EXPR$2"))
+                .build()
+        ),
+        ImmutableList.of(
+            new Object[]{
+                123,
+                123,
+                123
+            }
+        )
+    );
+  }
+
   @Test
   public void testSelectConstantExpressionFromTable() throws Exception
   {
diff --git a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRuleTest.java b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRuleTest.java
index e47f156..de7a005 100644
--- a/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRuleTest.java
+++ b/sql/src/test/java/org/apache/druid/sql/calcite/rule/DruidLogicalValuesRuleTest.java
@@ -21,6 +21,7 @@ package org.apache.druid.sql.calcite.rule;
 
 import com.google.common.collect.ImmutableList;
 import org.apache.calcite.rel.type.RelDataType;
+import org.apache.calcite.rel.type.RelDataTypeFactory;
 import org.apache.calcite.rex.RexBuilder;
 import org.apache.calcite.rex.RexLiteral;
 import org.apache.calcite.sql.type.SqlTypeFactoryImpl;
@@ -46,6 +47,9 @@ import org.junit.runners.Parameterized.Parameters;
 import org.mockito.ArgumentMatchers;
 import org.mockito.Mockito;
 
+import java.lang.reflect.Field;
+import java.math.BigDecimal;
+
 @RunWith(Enclosed.class)
 public class DruidLogicalValuesRuleTest
 {
@@ -70,11 +74,11 @@ public class DruidLogicalValuesRuleTest
       );
     }
 
-    private final Object val;
+    private final Comparable<?> val;
     private final SqlTypeName sqlTypeName;
     private final Class<?> javaType;
 
-    public GetValueFromLiteralSimpleTypesTest(Object val, SqlTypeName sqlTypeName, Class<?> javaType)
+    public GetValueFromLiteralSimpleTypesTest(Comparable<?> val, SqlTypeName sqlTypeName, Class<?> javaType)
     {
       this.val = val;
       this.sqlTypeName = sqlTypeName;
@@ -89,14 +93,21 @@ public class DruidLogicalValuesRuleTest
       Assert.assertSame(javaType, fromLiteral.getClass());
       Assert.assertEquals(val, fromLiteral);
       Mockito.verify(literal, Mockito.times(1)).getType();
-      Mockito.verify(literal, Mockito.times(1)).getValueAs(ArgumentMatchers.any());
     }
 
-    private static RexLiteral makeLiteral(Object val, SqlTypeName typeName, Class<?> javaType)
+    private static RexLiteral makeLiteral(Comparable<?> val, SqlTypeName typeName, Class<?> javaType)
     {
       RelDataType dataType = Mockito.mock(RelDataType.class);
       Mockito.when(dataType.getSqlTypeName()).thenReturn(typeName);
       RexLiteral literal = Mockito.mock(RexLiteral.class);
+      try {
+        Field field = literal.getClass().getSuperclass().getDeclaredField("value");
+        field.setAccessible(true);
+        field.set(literal, val);
+      }
+      catch (Exception e) {
+        Assert.fail("Unable to mock the literal for test.\nException: " + e);
+      }
       Mockito.when(literal.getType()).thenReturn(dataType);
       Mockito.when(literal.getValueAs(ArgumentMatchers.any())).thenReturn(javaType.cast(val));
       return literal;
@@ -107,7 +118,8 @@ public class DruidLogicalValuesRuleTest
   {
     private static final PlannerContext DEFAULT_CONTEXT = Mockito.mock(PlannerContext.class);
     private static final DateTimeZone TIME_ZONE = DateTimes.inferTzFromString("Asia/Seoul");
-    private static final RexBuilder REX_BUILDER = new RexBuilder(new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE));
+    private static final RelDataTypeFactory TYPE_FACTORY = new SqlTypeFactoryImpl(DruidTypeSystem.INSTANCE);
+    private static final RexBuilder REX_BUILDER = new RexBuilder(TYPE_FACTORY);
 
     @Rule
     public ExpectedException expectedException = ExpectedException.none();
@@ -187,5 +199,16 @@ public class DruidLogicalValuesRuleTest
       expectedException.expectMessage("TIME_WITH_LOCAL_TIME_ZONE type is not supported");
       DruidLogicalValuesRule.getValueFromLiteral(literal, DEFAULT_CONTEXT);
     }
+
+    @Test
+    public void testGetCastedValuesFromFloatToNumeric()
+    {
+      RexLiteral literal = REX_BUILDER.makeExactLiteral(
+          new BigDecimal("123.0"),
+          TYPE_FACTORY.createSqlType(SqlTypeName.INTEGER)
+      );
+      Object value = DruidLogicalValuesRule.getValueFromLiteral(literal, DEFAULT_CONTEXT);
+      Assert.assertEquals(value, 123L);
+    }
   }
 }

---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org