You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ri...@apache.org on 2022/04/15 23:53:15 UTC

[pinot] branch master updated: add scalar function for cast so it can be calculated at compile time (#8535)

This is an automated email from the ASF dual-hosted git repository.

richardstartin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 87fb007f7d add scalar function for cast so it can be calculated at compile time (#8535)
87fb007f7d is described below

commit 87fb007f7ddbeec5412b526f94ecd3bbd8fe2e08
Author: Richard Startin <ri...@startree.ai>
AuthorDate: Sat Apr 16 00:53:10 2022 +0100

    add scalar function for cast so it can be calculated at compile time (#8535)
---
 config/checkstyle.xml                              |   1 +
 .../scalar/DataTypeConversionFunctions.java        |  34 ++++
 .../scalar/DataTypeConversionFunctionsTest.java    |  63 ++++++++
 .../pinot/sql/parsers/CalciteSqlCompilerTest.java  |  30 +---
 .../function/CastTransformFunctionTest.java        |  23 ++-
 .../org/apache/pinot/queries/CastQueriesTest.java  | 171 +++++++++++++++++++++
 6 files changed, 295 insertions(+), 27 deletions(-)

diff --git a/config/checkstyle.xml b/config/checkstyle.xml
index 8c7242ca8b..86e163f50c 100644
--- a/config/checkstyle.xml
+++ b/config/checkstyle.xml
@@ -136,6 +136,7 @@
                 org.apache.pinot.controller.recommender.rules.io.params.RecommenderConstants.RulesToExecute.*,
                 org.apache.pinot.controller.recommender.rules.utils.PredicateParseResult.*,
                 org.apache.pinot.client.utils.Constants.*,
+                org.apache.pinot.common.utils.PinotDataType.*,
                 org.apache.pinot.segment.local.startree.StarTreeBuilderUtils.*,
                 org.apache.pinot.segment.local.startree.v2.store.StarTreeIndexMapUtils.*,
                 org.apache.pinot.segment.local.utils.GeometryType.*,
diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
index 47c9ce91f5..789cb35cb7 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
@@ -18,12 +18,19 @@
  */
 package org.apache.pinot.common.function.scalar;
 
+import com.google.common.base.Preconditions;
 import java.math.BigDecimal;
 import java.util.Base64;
+import org.apache.pinot.common.utils.PinotDataType;
 import org.apache.pinot.spi.annotations.ScalarFunction;
 import org.apache.pinot.spi.utils.BigDecimalUtils;
 import org.apache.pinot.spi.utils.BytesUtils;
 
+import static org.apache.pinot.common.utils.PinotDataType.DOUBLE;
+import static org.apache.pinot.common.utils.PinotDataType.INTEGER;
+import static org.apache.pinot.common.utils.PinotDataType.LONG;
+import static org.apache.pinot.common.utils.PinotDataType.STRING;
+
 
 /**
  * Contains function to convert a datatype to another datatype.
@@ -32,6 +39,33 @@ public class DataTypeConversionFunctions {
   private DataTypeConversionFunctions() {
   }
 
+  /**
+   * Casts a single value (arrays other than byte[] are rejected) to the type named by {@code targetTypeLiteral}.
+   * The name is matched case-insensitively; INT and VARCHAR are accepted as aliases for INTEGER and STRING.
+   *
+   * @throws IllegalArgumentException wrapping (as cause) any IllegalArgumentException raised while casting
+   */
+  @ScalarFunction
+  public static Object cast(Object value, String targetTypeLiteral) {
+    try {
+      Class<?> clazz = value.getClass();
+      Preconditions.checkArgument(!clazz.isArray() || clazz == byte[].class, "%s must not be an array type", clazz);
+      PinotDataType sourceType = PinotDataType.getSingleValueType(clazz);
+      // Locale.ROOT keeps the upper-casing independent of the JVM default locale (e.g. Turkish dotted I)
+      String transformed = targetTypeLiteral.toUpperCase(java.util.Locale.ROOT);
+      PinotDataType targetDataType = "INT".equals(transformed) ? INTEGER
+          : "VARCHAR".equals(transformed) ? STRING : PinotDataType.valueOf(transformed);
+      boolean toIntegral = targetDataType == INTEGER || targetDataType == LONG;
+      if (sourceType == STRING && toIntegral && String.valueOf(value).contains(".")) {
+        // route decimal strings such as "10.0" through DOUBLE to avoid integer parse errors
+        return targetDataType.convert(DOUBLE.convert(value, sourceType), DOUBLE);
+      }
+      return targetDataType.convert(value, sourceType);
+    } catch (IllegalArgumentException e) {
+      throw new IllegalArgumentException("Unknown data type: " + targetTypeLiteral, e);
+    }
+  }
+
   /**
    * Converts big decimal string representation to bytes.
    * Only scale of upto 2 bytes is supported by the function
diff --git a/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java b/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java
new file mode 100644
index 0000000000..981e39e2ad
--- /dev/null
+++ b/pinot-common/src/test/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctionsTest.java
@@ -0,0 +1,63 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.common.function.scalar;
+
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+
+
+public class DataTypeConversionFunctionsTest { // table-driven checks for DataTypeConversionFunctions.cast
+
+  @DataProvider(name = "testCases")
+  public static Object[][] testCases() {
+    return new Object[][]{ // each row: {input value, target type literal, expected result}
+        {"a", "string", "a"},
+        {"10", "int", 10},
+        {"10", "long", 10L},
+        {"10", "float", 10F},
+        {"10", "double", 10D},
+        {"10.0", "int", 10}, // decimal string to an integral type
+        {"10.0", "long", 10L}, // decimal string to an integral type
+        {"10.0", "float", 10F},
+        {"10.0", "double", 10D},
+        {10, "string", "10"},
+        {10L, "string", "10"},
+        {10F, "string", "10.0"},
+        {10D, "string", "10.0"},
+        {"a", "string", "a"}, // NOTE(review): duplicate of the first row
+        {10, "int", 10}, // identity casts
+        {10L, "long", 10L},
+        {10F, "float", 10F},
+        {10D, "double", 10D},
+        {10L, "int", 10}, // numeric narrowing / widening
+        {10, "long", 10L},
+        {10D, "float", 10F},
+        {10F, "double", 10D},
+        {"abc1", "bytes", new byte[]{(byte) 0xab, (byte) 0xc1}}, // string is expected to be read as hex
+        {new byte[]{(byte) 0xab, (byte) 0xc1}, "string", "abc1"} // bytes are expected to render as hex
+    };
+  }
+
+  @Test(dataProvider = "testCases")
+  public void test(Object value, String type, Object expected) {
+    assertEquals(DataTypeConversionFunctions.cast(value, type), expected); // TestNG order: (actual, expected)
+  }
+}
diff --git a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
index 83cde06edc..90f0d2604b 100644
--- a/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
+++ b/pinot-common/src/test/java/org/apache/pinot/sql/parsers/CalciteSqlCompilerTest.java
@@ -1611,41 +1611,19 @@ public class CalciteSqlCompilerTest {
   public void testCastTransformation() {
     PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery("select CAST(25.65 AS int) from myTable");
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(), 25.65);
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
-        "INTEGER");
+    Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 25);
 
     pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST('20170825' AS LONG) from myTable");
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getStringValue(),
-        "20170825");
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(), "LONG");
+    Assert.assertEquals(pinotQuery.getSelectList().get(0).getLiteral().getLongValue(), 20170825);
 
     pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS Float) from myTable");
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(),
-        20170825.0);
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
-        "FLOAT");
+    Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F);
 
     pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(20170825.0 AS dOuble) from myTable");
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
-    Assert.assertEquals(pinotQuery.getSelectList().get(0).getFunctionCall().getOperator(), "cast");
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(0).getLiteral().getDoubleValue(),
-        20170825.0);
-    Assert.assertEquals(
-        pinotQuery.getSelectList().get(0).getFunctionCall().getOperands().get(1).getLiteral().getStringValue(),
-        "DOUBLE");
+    Assert.assertEquals((float) pinotQuery.getSelectList().get(0).getLiteral().getDoubleValue(), 20170825.0F);
 
     pinotQuery = CalciteSqlParser.compileToPinotQuery("SELECT CAST(column1 AS STRING) from myTable");
     Assert.assertEquals(pinotQuery.getSelectListSize(), 1);
diff --git a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java
index 6f9c985896..19556a3d9e 100644
--- a/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java
+++ b/pinot-core/src/test/java/org/apache/pinot/core/operator/transform/function/CastTransformFunctionTest.java
@@ -23,6 +23,9 @@ import org.apache.pinot.common.request.context.RequestContextUtils;
 import org.testng.Assert;
 import org.testng.annotations.Test;
 
+import static org.apache.pinot.common.function.scalar.DataTypeConversionFunctions.cast;
+import static org.testng.Assert.assertEquals;
+
 
 public class CastTransformFunctionTest extends BaseTransformFunctionTest {
 
@@ -32,22 +35,28 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest {
         RequestContextUtils.getExpressionFromSQL(String.format("CAST(%s AS string)", INT_SV_COLUMN));
     TransformFunction transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CastTransformFunction);
-    Assert.assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
+    assertEquals(transformFunction.getName(), CastTransformFunction.FUNCTION_NAME);
     String[] expectedValues = new String[NUM_ROWS];
+    String[] scalarStringValues = new String[NUM_ROWS];
     for (int i = 0; i < NUM_ROWS; i++) {
       expectedValues[i] = Integer.toString(_intSVValues[i]);
+      scalarStringValues[i] = (String) cast(_intSVValues[i], "string");
     }
     testTransformFunction(transformFunction, expectedValues);
+    assertEquals(expectedValues, scalarStringValues);
 
     expression =
         RequestContextUtils.getExpressionFromSQL(String.format("CAST(CAST(%s as INT) as FLOAT)", FLOAT_SV_COLUMN));
     transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CastTransformFunction);
     float[] expectedFloatValues = new float[NUM_ROWS];
+    float[] scalarFloatValues = new float[NUM_ROWS];
     for (int i = 0; i < NUM_ROWS; i++) {
       expectedFloatValues[i] = (int) _floatSVValues[i];
+      scalarFloatValues[i] = (float) cast(cast(_floatSVValues[i], "int"), "float");
     }
     testTransformFunction(transformFunction, expectedFloatValues);
+    assertEquals(expectedFloatValues, scalarFloatValues);
 
     expression = RequestContextUtils.getExpressionFromSQL(
         String.format("CAST(ADD(CAST(%s AS LONG), %s) AS STRING)", DOUBLE_SV_COLUMN, LONG_SV_COLUMN));
@@ -55,18 +64,26 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest {
     Assert.assertTrue(transformFunction instanceof CastTransformFunction);
     for (int i = 0; i < NUM_ROWS; i++) {
       expectedValues[i] = Double.toString((double) (long) _doubleSVValues[i] + (double) _longSVValues[i]);
+      scalarStringValues[i] = (String) cast(
+          (double) (long) cast(_doubleSVValues[i], "long") + (double) _longSVValues[i], "string");
     }
     testTransformFunction(transformFunction, expectedValues);
+    assertEquals(expectedValues, scalarStringValues);
 
     expression = RequestContextUtils.getExpressionFromSQL(
         String.format("caSt(cAst(casT(%s as inT) + %s aS sTring) As DouBle)", FLOAT_SV_COLUMN, INT_SV_COLUMN));
     transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CastTransformFunction);
     double[] expectedDoubleValues = new double[NUM_ROWS];
+    double[] scalarDoubleValues = new double[NUM_ROWS];
     for (int i = 0; i < NUM_ROWS; i++) {
       expectedDoubleValues[i] = (double) (int) _floatSVValues[i] + (double) _intSVValues[i];
+      scalarDoubleValues[i] =
+          (double) cast(cast((double) (int) cast(_floatSVValues[i], "int") + (double) _intSVValues[i], "string"),
+              "double");
     }
     testTransformFunction(transformFunction, expectedDoubleValues);
+    assertEquals(expectedDoubleValues, scalarDoubleValues);
 
     expression = RequestContextUtils.getExpressionFromSQL(String
         .format("CAST(CAST(%s AS INT) - CAST(%s AS FLOAT) / CAST(%s AS DOUBLE) AS LONG)", DOUBLE_SV_COLUMN,
@@ -74,10 +91,14 @@ public class CastTransformFunctionTest extends BaseTransformFunctionTest {
     transformFunction = TransformFunctionFactory.get(expression, _dataSourceMap);
     Assert.assertTrue(transformFunction instanceof CastTransformFunction);
     long[] expectedLongValues = new long[NUM_ROWS];
+    long[] longScalarValues = new long[NUM_ROWS];
     for (int i = 0; i < NUM_ROWS; i++) {
       expectedLongValues[i] =
           (long) ((double) (int) _doubleSVValues[i] - (double) (float) _longSVValues[i] / (double) _intSVValues[i]);
+      longScalarValues[i] = (long) cast((double) (int) cast(_doubleSVValues[i], "int")
+          - (double) (float) cast(_longSVValues[i], "float") / (double) cast(_intSVValues[i], "double"), "long");
     }
     testTransformFunction(transformFunction, expectedLongValues);
+    assertEquals(expectedLongValues, longScalarValues);
   }
 }
diff --git a/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java b/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java
new file mode 100644
index 0000000000..47f6b0d4c2
--- /dev/null
+++ b/pinot-core/src/test/java/org/apache/pinot/queries/CastQueriesTest.java
@@ -0,0 +1,171 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.pinot.queries;
+
+import java.io.File;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.List;
+import org.apache.commons.io.FileUtils;
+import org.apache.pinot.core.common.Operator;
+import org.apache.pinot.core.operator.query.AggregationGroupByOperator;
+import org.apache.pinot.core.operator.query.AggregationOperator;
+import org.apache.pinot.core.operator.query.SelectionOnlyOperator;
+import org.apache.pinot.core.query.aggregation.groupby.AggregationGroupByResult;
+import org.apache.pinot.core.query.aggregation.groupby.GroupKeyGenerator;
+import org.apache.pinot.segment.local.indexsegment.immutable.ImmutableSegmentLoader;
+import org.apache.pinot.segment.local.segment.creator.impl.SegmentIndexCreationDriverImpl;
+import org.apache.pinot.segment.local.segment.readers.GenericRowRecordReader;
+import org.apache.pinot.segment.spi.ImmutableSegment;
+import org.apache.pinot.segment.spi.IndexSegment;
+import org.apache.pinot.segment.spi.creator.SegmentGeneratorConfig;
+import org.apache.pinot.spi.config.table.TableConfig;
+import org.apache.pinot.spi.config.table.TableType;
+import org.apache.pinot.spi.data.FieldSpec;
+import org.apache.pinot.spi.data.Schema;
+import org.apache.pinot.spi.data.readers.GenericRow;
+import org.apache.pinot.spi.utils.ReadMode;
+import org.apache.pinot.spi.utils.builder.TableConfigBuilder;
+import org.testng.annotations.BeforeClass;
+import org.testng.annotations.Test;
+
+import static org.testng.Assert.assertEquals;
+import static org.testng.Assert.assertNotNull;
+import static org.testng.Assert.assertTrue;
+
+
+public class CastQueriesTest extends BaseQueriesTest {
+
+  private static final File INDEX_DIR = new File(FileUtils.getTempDirectory(), "CastQueriesTest");
+  private static final String RAW_TABLE_NAME = "testTable";
+  private static final String SEGMENT_NAME = "testSegment";
+
+  private static final int NUM_RECORDS = 1000;
+  private static final int BUCKET_SIZE = 8;
+  private static final String CLASSIFICATION_COLUMN = "class";
+  private static final String X_COL = "x";
+  private static final String Y_COL = "y";
+
+  private static final Schema SCHEMA = new Schema.SchemaBuilder()
+      .addSingleValueDimension(X_COL, FieldSpec.DataType.DOUBLE)
+      .addSingleValueDimension(Y_COL, FieldSpec.DataType.DOUBLE)
+      .addSingleValueDimension(CLASSIFICATION_COLUMN, FieldSpec.DataType.STRING)
+      .build();
+
+  private static final TableConfig TABLE_CONFIG =
+      new TableConfigBuilder(TableType.OFFLINE).setTableName(RAW_TABLE_NAME).build();
+
+  private IndexSegment _indexSegment;
+  private List<IndexSegment> _indexSegments;
+
+  @Override
+  protected String getFilter() {
+    return "";
+  }
+
+  @Override
+  protected IndexSegment getIndexSegment() {
+    return _indexSegment;
+  }
+
+  @Override
+  protected List<IndexSegment> getIndexSegments() {
+    return _indexSegments;
+  }
+
+  @BeforeClass
+  public void setUp()
+      throws Exception {
+    FileUtils.deleteQuietly(INDEX_DIR); // start from a clean output directory
+
+    List<GenericRow> records = new ArrayList<>(NUM_RECORDS);
+    for (int i = 0; i < NUM_RECORDS; i++) {
+      GenericRow record = new GenericRow();
+      record.putValue(X_COL, 0.5); // constant, so sums are easy to predict
+      record.putValue(Y_COL, 0.25); // constant, so sums are easy to predict
+      record.putValue(CLASSIFICATION_COLUMN, "" + (i % BUCKET_SIZE)); // label cycles through "0".."7"
+      records.add(record);
+    }
+
+    SegmentGeneratorConfig segmentGeneratorConfig = new SegmentGeneratorConfig(TABLE_CONFIG, SCHEMA);
+    segmentGeneratorConfig.setTableName(RAW_TABLE_NAME);
+    segmentGeneratorConfig.setSegmentName(SEGMENT_NAME);
+    segmentGeneratorConfig.setOutDir(INDEX_DIR.getPath());
+
+    SegmentIndexCreationDriverImpl driver = new SegmentIndexCreationDriverImpl();
+    driver.init(segmentGeneratorConfig, new GenericRowRecordReader(records));
+    driver.build(); // writes the segment under INDEX_DIR
+
+    ImmutableSegment immutableSegment = ImmutableSegmentLoader.load(new File(INDEX_DIR, SEGMENT_NAME), ReadMode.mmap);
+    _indexSegment = immutableSegment; // NOTE(review): this class has no @AfterClass to destroy it or delete INDEX_DIR
+    _indexSegments = Arrays.asList(immutableSegment, immutableSegment); // the same segment listed twice
+  }
+
+  @Test
+  public void testCastSum() {
+    String query = "select cast(sum(" + X_COL + ") as int), " // cast wraps the aggregation result
+        + "cast(sum(" + Y_COL + ") as int) "
+        + "from " + RAW_TABLE_NAME;
+    Operator<?> operator = getOperatorForSqlQuery(query);
+    assertTrue(operator instanceof AggregationOperator);
+    List<Object> aggregationResult = ((AggregationOperator) operator).nextBlock().getAggregationResult();
+    assertNotNull(aggregationResult);
+    assertEquals(aggregationResult.size(), 2); // one entry per select expression
+    assertEquals(((Number) aggregationResult.get(0)).intValue(), NUM_RECORDS / 2); // every x is 0.5
+    assertEquals(((Number) aggregationResult.get(1)).intValue(), NUM_RECORDS / 4); // every y is 0.25
+  }
+
+  @Test
+  public void testCastSumGroupBy() {
+    String query = "select cast(sum(" + X_COL + ") as int), "
+        + "cast(sum(" + Y_COL + ") as int) "
+        + "from " + RAW_TABLE_NAME + " "
+        + "group by " + CLASSIFICATION_COLUMN; // BUCKET_SIZE distinct labels, 125 rows each
+    Operator<?> operator = getOperatorForSqlQuery(query);
+    assertTrue(operator instanceof AggregationGroupByOperator);
+    AggregationGroupByResult result = ((AggregationGroupByOperator) operator).nextBlock().getAggregationGroupByResult();
+    assertNotNull(result);
+    Iterator<GroupKeyGenerator.GroupKey> it = result.getGroupKeyIterator();
+    while (it.hasNext()) { // every group must yield the same pair of values
+      GroupKeyGenerator.GroupKey groupKey = it.next();
+      Object aggregate = result.getResultForGroupId(0, groupKey._groupId); // index 0: sum over x
+      assertEquals(((Number) aggregate).intValue(), NUM_RECORDS / (2 * BUCKET_SIZE)); // 62.5 as int -> 62
+      aggregate = result.getResultForGroupId(1, groupKey._groupId); // index 1: sum over y
+      assertEquals(((Number) aggregate).intValue(), NUM_RECORDS / (4 * BUCKET_SIZE)); // 31.25 as int -> 31
+    }
+  }
+
+  @Test
+  public void testCastFilterAndProject() {
+    String query = "select cast(" + CLASSIFICATION_COLUMN + " as int)" // cast in the projection
+        + " from " + RAW_TABLE_NAME
+        + " where " + CLASSIFICATION_COLUMN + " = cast(0 as string) limit " + NUM_RECORDS; // cast of a literal
+    Operator<?> operator = getOperatorForSqlQuery(query);
+    assertTrue(operator instanceof SelectionOnlyOperator);
+    Collection<Object[]> result = ((SelectionOnlyOperator) operator).nextBlock().getSelectionResult();
+    assertNotNull(result);
+    assertEquals(result.size(), NUM_RECORDS / BUCKET_SIZE); // only rows labelled "0" are expected to match
+    for (Object[] row : result) {
+      assertEquals(row.length, 1); // single projected column
+      assertEquals(row[0], 0); // label "0" projected as an int
+    }
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org