You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@pinot.apache.org by ja...@apache.org on 2023/05/25 17:17:57 UTC
[pinot] branch master updated: [Multi-stage] Support null in aggregate and filter (#10799)
This is an automated email from the ASF dual-hosted git repository.
jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git
The following commit(s) were added to refs/heads/master by this push:
new fcaebab69a [Multi-stage] Support null in aggregate and filter (#10799)
fcaebab69a is described below
commit fcaebab69a8a9d129cfae0130ab67c88815630bf
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Thu May 25 10:17:48 2023 -0700
[Multi-stage] Support null in aggregate and filter (#10799)
---
.../pinot/query/planner/logical/RexExpression.java | 18 +++--
.../partitioning/FieldSelectionKeySelector.java | 5 +-
.../query/runtime/operator/AggregateOperator.java | 91 +++++++++++++++-------
.../runtime/operator/WindowAggregateOperator.java | 10 +--
.../runtime/operator/operands/FilterOperand.java | 45 ++++++++---
.../operator/operands/TransformOperand.java | 77 ++----------------
.../runtime/operator/utils/AggregationUtils.java | 84 ++++++++++++++------
.../operator/utils/FunctionInvokeUtils.java | 6 +-
.../runtime/operator/AggregateOperatorTest.java | 9 ++-
.../operator/WindowAggregateOperatorTest.java | 4 +-
.../src/test/resources/queries/NullHandling.json | 51 ++++++++++++
.../test/resources/queries/WindowFunctions.json | 16 ++--
12 files changed, 255 insertions(+), 161 deletions(-)
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index 9b879ab779..ab78924548 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -82,19 +82,23 @@ public interface RexExpression {
operands);
}
- static Object toRexValue(FieldSpec.DataType dataType, Comparable value) {
+ @Nullable
+ static Object toRexValue(FieldSpec.DataType dataType, @Nullable Comparable<?> value) {
+ if (value == null) {
+ return null;
+ }
switch (dataType) {
case INT:
- return value == null ? 0 : ((BigDecimal) value).intValue();
+ return ((BigDecimal) value).intValue();
case LONG:
- return value == null ? 0L : ((BigDecimal) value).longValue();
+ return ((BigDecimal) value).longValue();
case FLOAT:
- return value == null ? 0f : ((BigDecimal) value).floatValue();
- case BIG_DECIMAL:
+ return ((BigDecimal) value).floatValue();
case DOUBLE:
- return value == null ? 0d : ((BigDecimal) value).doubleValue();
+ case BIG_DECIMAL:
+ return ((BigDecimal) value).doubleValue();
case STRING:
- return value == null ? "" : ((NlsString) value).getValue();
+ return ((NlsString) value).getValue();
default:
return value;
}
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
index 235c5bd491..b23b34433b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
@@ -85,7 +85,10 @@ public class FieldSelectionKeySelector implements KeySelector<Object[], Object[]
// TODO: consider better hashing algorithms than hashCode sum, such as XOR'ing
int hashCode = 0;
for (int columnIndex : _columnIndices) {
- hashCode += input[columnIndex].hashCode();
+ Object value = input[columnIndex];
+ if (value != null) {
+ hashCode += value.hashCode();
+ }
}
// return a positive number because this is used directly to modulo-index
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 8835980ace..d445a6a2ec 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -40,12 +40,9 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
import org.apache.pinot.segment.local.customobject.PinotFourthMoment;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
/**
- *
* AggregateOperator is used to aggregate values over a set of group by keys.
* Output data will be in the format of [group by key, aggregate result1, ... aggregate resultN]
* Currently, we only support SUM/COUNT/MIN/MAX aggregation.
@@ -60,7 +57,6 @@ import org.slf4j.LoggerFactory;
*/
public class AggregateOperator extends MultiStageOperator {
private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
- private static final Logger LOGGER = LoggerFactory.getLogger(AggregateOperator.class);
private final MultiStageOperator _inputOperator;
@@ -177,7 +173,7 @@ public class AggregateOperator extends MultiStageOperator {
private TransferableBlock constructEmptyAggResultBlock() {
Object[] row = new Object[_aggCalls.size()];
for (int i = 0; i < _accumulators.length; i++) {
- row[i] = _accumulators[i].getMerger().initialize(null, _accumulators[i].getDataType());
+ row[i] = _accumulators[i].getMerger().init(null, _accumulators[i].getDataType());
}
return new TransferableBlock(Collections.singletonList(row), _resultSchema, DataBlock.Type.ROW);
}
@@ -220,43 +216,76 @@ public class AggregateOperator extends MultiStageOperator {
private static class MergeFourthMomentNumeric implements AggregationUtils.Merger {
+ @Nullable
@Override
- public Object merge(Object left, Object right) {
- ((PinotFourthMoment) left).increment(((Number) right).doubleValue());
- return left;
+ public PinotFourthMoment init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+ if (value == null) {
+ return null;
+ }
+ PinotFourthMoment moment = new PinotFourthMoment();
+ moment.increment(((Number) value).doubleValue());
+ return moment;
}
+ @Nullable
@Override
- public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
- PinotFourthMoment moment = new PinotFourthMoment();
- moment.increment(((Number) other).doubleValue());
+ public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) {
+ PinotFourthMoment moment = (PinotFourthMoment) agg;
+ if (value == null) {
+ return moment;
+ }
+ if (moment == null) {
+ moment = new PinotFourthMoment();
+ }
+ moment.increment(((Number) value).doubleValue());
return moment;
}
}
private static class MergeFourthMomentObject implements AggregationUtils.Merger {
+ @Nullable
@Override
- public Object merge(Object left, Object right) {
- PinotFourthMoment agg = (PinotFourthMoment) left;
- agg.combine((PinotFourthMoment) right);
- return agg;
+ public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) {
+ PinotFourthMoment moment1 = (PinotFourthMoment) agg;
+ PinotFourthMoment moment2 = (PinotFourthMoment) value;
+ if (moment1 == null) {
+ return moment2;
+ }
+ if (moment2 == null) {
+ return moment1;
+ }
+ moment1.combine(moment2);
+ return moment1;
}
}
+ // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
private static class MergeCountDistinctScalars implements AggregationUtils.Merger {
- @SuppressWarnings("unchecked")
+
+ @Nullable
@Override
- public Object merge(Object agg, Object value) {
- // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
- ((Set<Object>) agg).add(value);
- return agg;
+ public Set<Object> init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+ if (value == null) {
+ return null;
+ }
+ Set<Object> set = new ObjectOpenHashSet<>();
+ set.add(value);
+ return set;
}
+ @SuppressWarnings("unchecked")
+ @Nullable
@Override
- public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
- ObjectOpenHashSet<Object> set = new ObjectOpenHashSet<>();
- set.add(other);
+ public Set<Object> merge(@Nullable Object agg, @Nullable Object value) {
+ Set<Object> set = (Set<Object>) agg;
+ if (value == null) {
+ return set;
+ }
+ if (set == null) {
+ set = new ObjectOpenHashSet<>();
+ }
+ set.add(value);
return set;
}
}
@@ -264,11 +293,19 @@ public class AggregateOperator extends MultiStageOperator {
private static class MergeCountDistinctSets implements AggregationUtils.Merger {
@SuppressWarnings("unchecked")
+ @Nullable
@Override
- public Object merge(Object agg, Object value) {
- // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
- ((Set<Object>) agg).addAll((Set<Object>) value);
- return agg;
+ public Set<Object> merge(@Nullable Object agg, @Nullable Object value) {
+ Set<Object> set1 = (Set<Object>) agg;
+ Set<Object> set2 = (Set<Object>) value;
+ if (set1 == null) {
+ return set2;
+ }
+ if (set2 == null) {
+ return set1;
+ }
+ set1.addAll(set2);
+ return set1;
}
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
index 6b29cdc82a..9791a72d42 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
@@ -402,13 +402,13 @@ public class WindowAggregateOperator extends MultiStageOperator {
private static class MergeRowNumber implements AggregationUtils.Merger {
@Override
- public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
+ public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
return 1L;
}
@Override
- public Object merge(Object left, Object right) {
- return ((Number) left).longValue() + 1L;
+ public Long merge(Object agg, @Nullable Object value) {
+ return (long) agg + 1;
}
}
@@ -440,7 +440,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
Object previousRowOutputValue) {
Object value = _inputRef == -1 ? _literal : row[_inputRef];
if (previousPartitionKey == null || !currentPartitionKey.equals(previousPartitionKey)) {
- return _merger.initialize(currentPartitionKey, _dataType);
+ return _merger.init(currentPartitionKey, _dataType);
} else {
return _merger.merge(previousRowOutputValue, value);
}
@@ -466,7 +466,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
_orderByResults.putIfAbsent(key, new OrderKeyResult());
if (currentRes == null) {
- _orderByResults.get(key).addOrderByResult(orderKey, _merger.initialize(value, _dataType));
+ _orderByResults.get(key).addOrderByResult(orderKey, _merger.init(value, _dataType));
} else {
Object mergedResult;
if (orderKey.equals(previousOrderKeyIfPresent)) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
index fbd8f95bd2..ab69ce67b2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
@@ -18,12 +18,13 @@
*/
package org.apache.pinot.query.runtime.operator.operands;
-
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.List;
+import java.util.function.IntPredicate;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
import org.apache.pinot.spi.utils.BooleanUtils;
@@ -100,15 +101,16 @@ public abstract class FilterOperand extends TransformOperand {
}
}
- public static abstract class Predicate extends FilterOperand {
- protected final TransformOperand _lhs;
- protected final TransformOperand _rhs;
- protected final boolean _requireCasting;
- protected final DataSchema.ColumnDataType _commonCastType;
+ public static class Predicate extends FilterOperand {
+ private final TransformOperand _lhs;
+ private final TransformOperand _rhs;
+ private final IntPredicate _comparisonResultPredicate;
+ private final boolean _requireCasting;
+ private final DataSchema.ColumnDataType _commonCastType;
/**
* Predicate constructor also resolve data type,
- * since we don't have a exhausted list of filter function signatures. we rely on type casting.
+ * since we don't have an exhausted list of filter function signatures. we rely on type casting.
*
* <ul>
* <li>if both RHS and LHS has null data type, exception occurs.</li>
@@ -116,22 +118,22 @@ public abstract class FilterOperand extends TransformOperand {
* <li>if either side supertype of the other, we use the super type.</li>
* <li>if we can't resolve a common data type, exception occurs.</li>
* </ul>
- *
- *
*/
- public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema) {
+ public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema,
+ IntPredicate comparisonResultPredicate) {
Preconditions.checkState(functionOperands.size() == 2,
"Expected 2 function ops for Predicate but got:" + functionOperands.size());
_lhs = TransformOperand.toTransformOperand(functionOperands.get(0), inputDataSchema);
_rhs = TransformOperand.toTransformOperand(functionOperands.get(1), inputDataSchema);
+ _comparisonResultPredicate = comparisonResultPredicate;
// TODO: Correctly throw exception instead of returning null.
// Currently exception thrown during constructor is not piped back to query dispatcher, thus in order to
// avoid silent failure, we deliberately set to null here, make the exception thrown during data processing.
// TODO: right now all the numeric columns are still doing conversion b/c even if the inputDataSchema asked for
// one of the number type, it might not contain the exact type in the payload.
- if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT
- || _rhs._resultType == null || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) {
+ if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT || _rhs._resultType == null
+ || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) {
_requireCasting = false;
_commonCastType = null;
} else if (_lhs._resultType.isSuperTypeOf(_rhs._resultType)) {
@@ -145,5 +147,24 @@ public abstract class FilterOperand extends TransformOperand {
_commonCastType = null;
}
}
+
+ @SuppressWarnings({"rawtypes", "unchecked"})
+ @Override
+ public Boolean apply(Object[] row) {
+ Comparable v1 = (Comparable) _lhs.apply(row);
+ if (v1 == null) {
+ return false;
+ }
+ Comparable v2 = (Comparable) _rhs.apply(row);
+ if (v2 == null) {
+ return false;
+ }
+ if (_requireCasting) {
+ v1 = (Comparable) FunctionInvokeUtils.convert(v1, _commonCastType);
+ v2 = (Comparable) FunctionInvokeUtils.convert(v2, _commonCastType);
+ assert v1 != null && v2 != null;
+ }
+ return _comparisonResultPredicate.test(v1.compareTo(v2));
+ }
}
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
index 7c34c8e46c..88eb4f37de 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
@@ -18,12 +18,11 @@
*/
package org.apache.pinot.query.runtime.operator.operands;
-
import com.google.common.base.Preconditions;
import java.util.List;
+import javax.annotation.Nullable;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
import org.apache.pinot.query.runtime.operator.utils.OperatorUtils;
@@ -43,7 +42,6 @@ public abstract class TransformOperand {
}
}
- @SuppressWarnings({"ConstantConditions", "rawtypes", "unchecked"})
private static TransformOperand toTransformOperand(RexExpression.FunctionCall functionCall,
DataSchema inputDataSchema) {
final List<RexExpression> functionOperands = functionCall.getFunctionOperands();
@@ -65,77 +63,17 @@ public abstract class TransformOperand {
"BOOL / IS_TRUE takes one argument, passed in argument size:" + operandSize);
return new FilterOperand.True(functionOperands.get(0), inputDataSchema);
case "equals":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) == 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) == 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v == 0);
case "notEquals":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) != 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) != 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v != 0);
case "greaterThan":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) > 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) > 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v > 0);
case "greaterThanOrEqual":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) >= 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) >= 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v >= 0);
case "lessThan":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) < 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) < 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v < 0);
case "lessThanOrEqual":
- return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
- @Override
- public Boolean apply(Object[] row) {
- if (_requireCasting) {
- return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
- FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) <= 0;
- } else {
- return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) <= 0;
- }
- }
- };
+ return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v <= 0);
default:
return new FunctionOperand(functionCall, inputDataSchema);
}
@@ -149,5 +87,6 @@ public abstract class TransformOperand {
return _resultType;
}
+ @Nullable
public abstract Object apply(Object[] row);
}
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index 81bd7dea0c..e3e466d7d2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
+import javax.annotation.Nullable;
import org.apache.pinot.common.utils.DataSchema;
import org.apache.pinot.core.data.table.Key;
import org.apache.pinot.query.planner.logical.RexExpression;
@@ -38,7 +39,6 @@ import org.apache.pinot.spi.data.FieldSpec;
* <p>Accumulation is used by {@code WindowAggregateOperator} and {@code AggregateOperator}.
*/
public class AggregationUtils {
-
private AggregationUtils() {
}
@@ -54,54 +54,92 @@ public class AggregationUtils {
return new Key(new Object[0]);
}
- private static Object mergeSum(Object left, Object right) {
- return ((Number) left).doubleValue() + ((Number) right).doubleValue();
+ // TODO: Use the correct type for SUM/MIN/MAX instead of always using double
+
+ @Nullable
+ private static Object mergeSum(@Nullable Object agg, @Nullable Object value) {
+ if (agg == null) {
+ return value;
+ }
+ if (value == null) {
+ return agg;
+ }
+ return ((Number) agg).doubleValue() + ((Number) value).doubleValue();
}
- private static Object mergeMin(Object left, Object right) {
- return Math.min(((Number) left).doubleValue(), ((Number) right).doubleValue());
+ @Nullable
+ private static Object mergeMin(@Nullable Object agg, @Nullable Object value) {
+ if (agg == null) {
+ return value;
+ }
+ if (value == null) {
+ return agg;
+ }
+ return Math.min(((Number) agg).doubleValue(), ((Number) value).doubleValue());
}
- private static Object mergeMax(Object left, Object right) {
- return Math.max(((Number) left).doubleValue(), ((Number) right).doubleValue());
+ @Nullable
+ private static Object mergeMax(@Nullable Object agg, @Nullable Object value) {
+ if (agg == null) {
+ return value;
+ }
+ if (value == null) {
+ return agg;
+ }
+ return Math.max(((Number) agg).doubleValue(), ((Number) value).doubleValue());
}
- private static Boolean mergeBoolAnd(Object left, Object right) {
- return ((Boolean) left) && ((Boolean) right);
+ @Nullable
+ private static Boolean mergeBoolAnd(@Nullable Object agg, @Nullable Object value) {
+ if (agg == null) {
+ return (Boolean) value;
+ }
+ if (value == null) {
+ return (Boolean) agg;
+ }
+ return ((Boolean) agg) & ((Boolean) value);
}
- private static Boolean mergeBoolOr(Object left, Object right) {
- return ((Boolean) left) || ((Boolean) right);
+ @Nullable
+ private static Boolean mergeBoolOr(@Nullable Object agg, @Nullable Object value) {
+ if (agg == null) {
+ return (Boolean) value;
+ }
+ if (value == null) {
+ return (Boolean) agg;
+ }
+ return ((Boolean) agg) | ((Boolean) value);
}
private static class MergeCounts implements AggregationUtils.Merger {
@Override
- public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
- return other == null ? 0 : 1;
+ public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+ return value == null ? 0L : 1L;
}
@Override
- public Object merge(Object left, Object right) {
- return ((Number) left).doubleValue() + (right == null ? 0 : 1);
+ public Long merge(Object agg, @Nullable Object value) {
+ return value == null ? (long) agg : (long) agg + 1;
}
}
public interface Merger {
+
/**
- * Initializes the merger based on the first input
+ * Initializes the merger based on the column data type and first value.
*/
- default Object initialize(Object other, DataSchema.ColumnDataType dataType) {
- // TODO: Initialize as a double so that if only one row is returned it matches the type when many rows are
- // returned
- return other == null ? dataType.getNullPlaceholder() : other;
+ @Nullable
+ default Object init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+ return value;
}
/**
- * Merges the existing aggregate (the result of {@link #initialize(Object, DataSchema.ColumnDataType)}) with
+ * Merges the existing aggregate (the result of {@link #init(Object, DataSchema.ColumnDataType)}) with
* the new value coming in (which may be an aggregate in and of itself).
*/
- Object merge(Object agg, Object value);
+ @Nullable
+ Object merge(@Nullable Object agg, @Nullable Object value);
}
/**
@@ -169,7 +207,7 @@ public class AggregationUtils {
Object value = _inputRef == -1 ? _literal : row[_inputRef];
if (currentRes == null) {
- _results.put(key, _merger.initialize(value, _dataType));
+ _results.put(key, _merger.init(value, _dataType));
} else {
Object mergedResult = _merger.merge(currentRes, value);
_results.put(key, mergedResult);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
index de26cdfab4..851748d2b2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
@@ -18,13 +18,12 @@
*/
package org.apache.pinot.query.runtime.operator.utils;
+import javax.annotation.Nullable;
import org.apache.pinot.common.utils.DataSchema;
public class FunctionInvokeUtils {
-
private FunctionInvokeUtils() {
- // do not instantiate.
}
/**
@@ -35,7 +34,8 @@ public class FunctionInvokeUtils {
* @param columnDataType desired column data type
* @return converted entry
*/
- public static Object convert(Object inputObj, DataSchema.ColumnDataType columnDataType) {
+ @Nullable
+ public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
return inputObj == null ? null : columnDataType.convert(inputObj);
} else {
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index fce25e7e3f..bda5086a40 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import java.util.Arrays;
+import java.util.Collections;
import java.util.List;
import org.apache.calcite.sql.SqlKind;
import org.apache.pinot.common.utils.DataSchema;
@@ -201,7 +202,7 @@ public class AggregateOperatorTest {
AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class);
Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
- Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
+ Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d);
DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
AggregateOperator operator =
new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, calls, group, inSchema,
@@ -213,7 +214,7 @@ public class AggregateOperatorTest {
// Then:
// should call merger twice, one from second row in first block and two from the first row
// in second block
- Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
+ Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any());
Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 12d},
"Expected two columns (group by key, agg value)");
@@ -226,8 +227,8 @@ public class AggregateOperatorTest {
RexExpression.FunctionCall agg = getSum(new RexExpression.InputRef(0));
DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
- OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Arrays.asList(agg),
- Arrays.asList(new RexExpression.InputRef(1)), inSchema);
+ OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Collections.singletonList(agg),
+ Collections.singletonList(new RexExpression.InputRef(1)), inSchema);
TransferableBlock result = sum0GroupBy1.getNextBlock();
while (result.isNoOpBlock()) {
result = sum0GroupBy1.getNextBlock();
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index 817c7239d6..fc54b8941e 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -319,7 +319,7 @@ public class WindowAggregateOperatorTest {
AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class);
Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
- Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
+ Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d);
DataSchema outSchema =
new DataSchema(new String[]{"group", "arg", "sum"}, new DataSchema.ColumnDataType[]{INT, INT, DOUBLE});
WindowAggregateOperator operator =
@@ -334,7 +334,7 @@ public class WindowAggregateOperatorTest {
// Then:
// should call merger twice, one from second row in first block and two from the first row
// in second block
- Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
+ Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any());
Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 1, 12d},
"Expected three columns (original two columns, agg literal value)");
diff --git a/pinot-query-runtime/src/test/resources/queries/NullHandling.json b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
new file mode 100644
index 0000000000..f51701317c
--- /dev/null
+++ b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
@@ -0,0 +1,51 @@
+{
+ "null_on_intermediate": {
+ "tables": {
+ "tbl1" : {
+ "schema": [
+ {"name": "strCol1", "type": "STRING"},
+ {"name": "intCol1", "type": "INT"},
+ {"name": "strCol2", "type": "STRING"}
+ ],
+ "inputs": [
+ ["foo", 1, "foo"],
+ ["bar", 2, "alice"]
+ ]
+ },
+ "tbl2" : {
+ "schema": [
+ {"name": "strCol1", "type": "STRING"},
+ {"name": "strCol2", "type": "STRING"},
+ {"name": "intCol1", "type": "INT"},
+ {"name": "doubleCol1", "type": "DOUBLE"}
+ ],
+ "inputs": [
+ ["foo", "bob", 3, 3.1416],
+ ["alice", "alice", 4, 2.7183]
+ ]
+ }
+ },
+ "queries": [
+ {
+ "description": "LEFT JOIN and FILTER",
+ "sql": "SELECT {tbl1}.strCol2, {tbl2}.doubleCol1 IS NULL OR {tbl1}.intCol1 > 3 AS boolFlag FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+ },
+ {
+ "description": "LEFT JOIN and TRANSFORM",
+ "sql": "SELECT {tbl1}.strCol2, {tbl1}.intCol1 * {tbl2}.doubleCol1 + {tbl2}.intCol1 FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+ },
+ {
+ "description": "LEFT JOIN and AGGREGATE",
+ "sql": "SELECT COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+ },
+ {
+ "description": "LEFT JOIN and GROUP BY",
+ "sql": "SELECT {tbl1}.strCol2, {tbl2}.intCol1, COUNT(*) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2, {tbl2}.intCol1"
+ },
+ {
+ "description": "LEFT JOIN and GROUP BY with AGGREGATE",
+ "sql": "SELECT {tbl1}.strCol2, COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2"
+ }
+ ]
+ }
+}
\ No newline at end of file
diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
index 5caa982bec..afc989fded 100644
--- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
+++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
@@ -572,7 +572,7 @@
"description": "Single empty OVER() with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
"outputs": [
- [0]
+ [null]
]
},
{
@@ -580,7 +580,7 @@
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
"keepOutputRowOrder": true,
"outputs": [
- [0]
+ [null]
]
},
{
@@ -1335,7 +1335,7 @@
"description": "Multiple empty OVER()s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count, MIN(double_col) OVER() as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)",
"outputs": [
- [0]
+ [null]
]
},
{
@@ -1343,7 +1343,7 @@
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count, MIN(double_col) OVER(ORDER BY string_col) as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)",
"keepOutputRowOrder": true,
"outputs": [
- [0]
+ [null]
]
},
{
@@ -2477,7 +2477,7 @@
"description": "Single OVER(PARTITION BY) with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
"outputs": [
- [0]
+ [null]
]
},
{
@@ -2486,7 +2486,7 @@
"comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are",
"keepOutputRowOrder": false,
"outputs": [
- [0]
+ [null]
]
},
{
@@ -3445,7 +3445,7 @@
"description": "Multiple OVER(PARTITION BY)s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
"sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count, AVG(int_col) OVER(PARTITION BY string_col) as avg FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
"outputs": [
- [0]
+ [null]
]
},
{
@@ -3454,7 +3454,7 @@
"comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are",
"keepOutputRowOrder": false,
"outputs": [
- [0]
+ [null]
]
},
{
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org