You are viewing a plain text version of this content. The canonical version is available at the original archive link, which is not preserved in this plain-text rendering.
Posted to commits@pinot.apache.org by ja...@apache.org on 2023/05/25 17:17:57 UTC

[pinot] branch master updated: [Multi-stage] Support null in aggregate and filter (#10799)

This is an automated email from the ASF dual-hosted git repository.

jackie pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new fcaebab69a [Multi-stage] Support null in aggregate and filter (#10799)
fcaebab69a is described below

commit fcaebab69a8a9d129cfae0130ab67c88815630bf
Author: Xiaotian (Jackie) Jiang <17...@users.noreply.github.com>
AuthorDate: Thu May 25 10:17:48 2023 -0700

    [Multi-stage] Support null in aggregate and filter (#10799)
---
 .../pinot/query/planner/logical/RexExpression.java | 18 +++--
 .../partitioning/FieldSelectionKeySelector.java    |  5 +-
 .../query/runtime/operator/AggregateOperator.java  | 91 +++++++++++++++-------
 .../runtime/operator/WindowAggregateOperator.java  | 10 +--
 .../runtime/operator/operands/FilterOperand.java   | 45 ++++++++---
 .../operator/operands/TransformOperand.java        | 77 ++----------------
 .../runtime/operator/utils/AggregationUtils.java   | 84 ++++++++++++++------
 .../operator/utils/FunctionInvokeUtils.java        |  6 +-
 .../runtime/operator/AggregateOperatorTest.java    |  9 ++-
 .../operator/WindowAggregateOperatorTest.java      |  4 +-
 .../src/test/resources/queries/NullHandling.json   | 51 ++++++++++++
 .../test/resources/queries/WindowFunctions.json    | 16 ++--
 12 files changed, 255 insertions(+), 161 deletions(-)

diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
index 9b879ab779..ab78924548 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java
@@ -82,19 +82,23 @@ public interface RexExpression {
         operands);
   }
 
-  static Object toRexValue(FieldSpec.DataType dataType, Comparable value) {
+  @Nullable
+  static Object toRexValue(FieldSpec.DataType dataType, @Nullable Comparable<?> value) {
+    if (value == null) {
+      return null;
+    }
     switch (dataType) {
       case INT:
-        return value == null ? 0 : ((BigDecimal) value).intValue();
+        return ((BigDecimal) value).intValue();
       case LONG:
-        return value == null ? 0L : ((BigDecimal) value).longValue();
+        return ((BigDecimal) value).longValue();
       case FLOAT:
-        return value == null ? 0f : ((BigDecimal) value).floatValue();
-      case BIG_DECIMAL:
+        return ((BigDecimal) value).floatValue();
       case DOUBLE:
-        return value == null ? 0d : ((BigDecimal) value).doubleValue();
+      case BIG_DECIMAL:
+        return ((BigDecimal) value).doubleValue();
       case STRING:
-        return value == null ? "" : ((NlsString) value).getValue();
+        return ((NlsString) value).getValue();
       default:
         return value;
     }
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
index 235c5bd491..b23b34433b 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/partitioning/FieldSelectionKeySelector.java
@@ -85,7 +85,10 @@ public class FieldSelectionKeySelector implements KeySelector<Object[], Object[]
     // TODO: consider better hashing algorithms than hashCode sum, such as XOR'ing
     int hashCode = 0;
     for (int columnIndex : _columnIndices) {
-      hashCode += input[columnIndex].hashCode();
+      Object value = input[columnIndex];
+      if (value != null) {
+        hashCode += value.hashCode();
+      }
     }
 
     // return a positive number because this is used directly to modulo-index
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
index 8835980ace..d445a6a2ec 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/AggregateOperator.java
@@ -40,12 +40,9 @@ import org.apache.pinot.query.runtime.blocks.TransferableBlockUtils;
 import org.apache.pinot.query.runtime.operator.utils.AggregationUtils;
 import org.apache.pinot.query.runtime.plan.OpChainExecutionContext;
 import org.apache.pinot.segment.local.customobject.PinotFourthMoment;
-import org.slf4j.Logger;
-import org.slf4j.LoggerFactory;
 
 
 /**
- *
  * AggregateOperator is used to aggregate values over a set of group by keys.
  * Output data will be in the format of [group by key, aggregate result1, ... aggregate resultN]
  * Currently, we only support SUM/COUNT/MIN/MAX aggregation.
@@ -60,7 +57,6 @@ import org.slf4j.LoggerFactory;
  */
 public class AggregateOperator extends MultiStageOperator {
   private static final String EXPLAIN_NAME = "AGGREGATE_OPERATOR";
-  private static final Logger LOGGER = LoggerFactory.getLogger(AggregateOperator.class);
 
   private final MultiStageOperator _inputOperator;
 
@@ -177,7 +173,7 @@ public class AggregateOperator extends MultiStageOperator {
   private TransferableBlock constructEmptyAggResultBlock() {
     Object[] row = new Object[_aggCalls.size()];
     for (int i = 0; i < _accumulators.length; i++) {
-      row[i] = _accumulators[i].getMerger().initialize(null, _accumulators[i].getDataType());
+      row[i] = _accumulators[i].getMerger().init(null, _accumulators[i].getDataType());
     }
     return new TransferableBlock(Collections.singletonList(row), _resultSchema, DataBlock.Type.ROW);
   }
@@ -220,43 +216,76 @@ public class AggregateOperator extends MultiStageOperator {
 
   private static class MergeFourthMomentNumeric implements AggregationUtils.Merger {
 
+    @Nullable
     @Override
-    public Object merge(Object left, Object right) {
-      ((PinotFourthMoment) left).increment(((Number) right).doubleValue());
-      return left;
+    public PinotFourthMoment init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+      if (value == null) {
+        return null;
+      }
+      PinotFourthMoment moment = new PinotFourthMoment();
+      moment.increment(((Number) value).doubleValue());
+      return moment;
     }
 
+    @Nullable
     @Override
-    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
-      PinotFourthMoment moment = new PinotFourthMoment();
-      moment.increment(((Number) other).doubleValue());
+    public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) {
+      PinotFourthMoment moment = (PinotFourthMoment) agg;
+      if (value == null) {
+        return moment;
+      }
+      if (moment == null) {
+        moment = new PinotFourthMoment();
+      }
+      moment.increment(((Number) value).doubleValue());
       return moment;
     }
   }
 
   private static class MergeFourthMomentObject implements AggregationUtils.Merger {
 
+    @Nullable
     @Override
-    public Object merge(Object left, Object right) {
-      PinotFourthMoment agg = (PinotFourthMoment) left;
-      agg.combine((PinotFourthMoment) right);
-      return agg;
+    public PinotFourthMoment merge(@Nullable Object agg, @Nullable Object value) {
+      PinotFourthMoment moment1 = (PinotFourthMoment) agg;
+      PinotFourthMoment moment2 = (PinotFourthMoment) value;
+      if (moment1 == null) {
+        return moment2;
+      }
+      if (moment2 == null) {
+        return moment1;
+      }
+      moment1.combine(moment2);
+      return moment1;
     }
   }
 
+  // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
   private static class MergeCountDistinctScalars implements AggregationUtils.Merger {
-    @SuppressWarnings("unchecked")
+
+    @Nullable
     @Override
-    public Object merge(Object agg, Object value) {
-      // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
-      ((Set<Object>) agg).add(value);
-      return agg;
+    public Set<Object> init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+      if (value == null) {
+        return null;
+      }
+      Set<Object> set = new ObjectOpenHashSet<>();
+      set.add(value);
+      return set;
     }
 
+    @SuppressWarnings("unchecked")
+    @Nullable
     @Override
-    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
-      ObjectOpenHashSet<Object> set = new ObjectOpenHashSet<>();
-      set.add(other);
+    public Set<Object> merge(@Nullable Object agg, @Nullable Object value) {
+      Set<Object> set = (Set<Object>) agg;
+      if (value == null) {
+        return set;
+      }
+      if (set == null) {
+        set = new ObjectOpenHashSet<>();
+      }
+      set.add(value);
       return set;
     }
   }
@@ -264,11 +293,19 @@ public class AggregateOperator extends MultiStageOperator {
   private static class MergeCountDistinctSets implements AggregationUtils.Merger {
 
     @SuppressWarnings("unchecked")
+    @Nullable
     @Override
-    public Object merge(Object agg, Object value) {
-      // TODO: this casts everything to `Set<?>` instead of using the primitive version (e.g. IntSet)
-      ((Set<Object>) agg).addAll((Set<Object>) value);
-      return agg;
+    public Set<Object> merge(@Nullable Object agg, @Nullable Object value) {
+      Set<Object> set1 = (Set<Object>) agg;
+      Set<Object> set2 = (Set<Object>) value;
+      if (set1 == null) {
+        return set2;
+      }
+      if (set2 == null) {
+        return set1;
+      }
+      set1.addAll(set2);
+      return set1;
     }
   }
 
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
index 6b29cdc82a..9791a72d42 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperator.java
@@ -402,13 +402,13 @@ public class WindowAggregateOperator extends MultiStageOperator {
   private static class MergeRowNumber implements AggregationUtils.Merger {
 
     @Override
-    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
+    public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
       return 1L;
     }
 
     @Override
-    public Object merge(Object left, Object right) {
-      return ((Number) left).longValue() + 1L;
+    public Long merge(Object agg, @Nullable Object value) {
+      return (long) agg + 1;
     }
   }
 
@@ -440,7 +440,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
         Object previousRowOutputValue) {
       Object value = _inputRef == -1 ? _literal : row[_inputRef];
       if (previousPartitionKey == null || !currentPartitionKey.equals(previousPartitionKey)) {
-        return _merger.initialize(currentPartitionKey, _dataType);
+        return _merger.init(currentPartitionKey, _dataType);
       } else {
         return _merger.merge(previousRowOutputValue, value);
       }
@@ -466,7 +466,7 @@ public class WindowAggregateOperator extends MultiStageOperator {
 
       _orderByResults.putIfAbsent(key, new OrderKeyResult());
       if (currentRes == null) {
-        _orderByResults.get(key).addOrderByResult(orderKey, _merger.initialize(value, _dataType));
+        _orderByResults.get(key).addOrderByResult(orderKey, _merger.init(value, _dataType));
       } else {
         Object mergedResult;
         if (orderKey.equals(previousOrderKeyIfPresent)) {
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
index fbd8f95bd2..ab69ce67b2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/FilterOperand.java
@@ -18,12 +18,13 @@
  */
 package org.apache.pinot.query.runtime.operator.operands;
 
-
 import com.google.common.base.Preconditions;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.function.IntPredicate;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
+import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
 import org.apache.pinot.spi.utils.BooleanUtils;
 
 
@@ -100,15 +101,16 @@ public abstract class FilterOperand extends TransformOperand {
     }
   }
 
-  public static abstract class Predicate extends FilterOperand {
-    protected final TransformOperand _lhs;
-    protected final TransformOperand _rhs;
-    protected final boolean _requireCasting;
-    protected final DataSchema.ColumnDataType _commonCastType;
+  public static class Predicate extends FilterOperand {
+    private final TransformOperand _lhs;
+    private final TransformOperand _rhs;
+    private final IntPredicate _comparisonResultPredicate;
+    private final boolean _requireCasting;
+    private final DataSchema.ColumnDataType _commonCastType;
 
     /**
      * Predicate constructor also resolve data type,
-     * since we don't have a exhausted list of filter function signatures. we rely on type casting.
+     * since we don't have an exhausted list of filter function signatures. we rely on type casting.
      *
      * <ul>
      *   <li>if both RHS and LHS has null data type, exception occurs.</li>
@@ -116,22 +118,22 @@ public abstract class FilterOperand extends TransformOperand {
      *   <li>if either side supertype of the other, we use the super type.</li>
      *   <li>if we can't resolve a common data type, exception occurs.</li>
      * </ul>
-     *
-     *
      */
-    public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema) {
+    public Predicate(List<RexExpression> functionOperands, DataSchema inputDataSchema,
+        IntPredicate comparisonResultPredicate) {
       Preconditions.checkState(functionOperands.size() == 2,
           "Expected 2 function ops for Predicate but got:" + functionOperands.size());
       _lhs = TransformOperand.toTransformOperand(functionOperands.get(0), inputDataSchema);
       _rhs = TransformOperand.toTransformOperand(functionOperands.get(1), inputDataSchema);
+      _comparisonResultPredicate = comparisonResultPredicate;
 
       // TODO: Correctly throw exception instead of returning null.
       // Currently exception thrown during constructor is not piped back to query dispatcher, thus in order to
       // avoid silent failure, we deliberately set to null here, make the exception thrown during data processing.
       // TODO: right now all the numeric columns are still doing conversion b/c even if the inputDataSchema asked for
       // one of the number type, it might not contain the exact type in the payload.
-      if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT
-          || _rhs._resultType == null || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) {
+      if (_lhs._resultType == null || _lhs._resultType == DataSchema.ColumnDataType.OBJECT || _rhs._resultType == null
+          || _rhs._resultType == DataSchema.ColumnDataType.OBJECT) {
         _requireCasting = false;
         _commonCastType = null;
       } else if (_lhs._resultType.isSuperTypeOf(_rhs._resultType)) {
@@ -145,5 +147,24 @@ public abstract class FilterOperand extends TransformOperand {
         _commonCastType = null;
       }
     }
+
+    @SuppressWarnings({"rawtypes", "unchecked"})
+    @Override
+    public Boolean apply(Object[] row) {
+      Comparable v1 = (Comparable) _lhs.apply(row);
+      if (v1 == null) {
+        return false;
+      }
+      Comparable v2 = (Comparable) _rhs.apply(row);
+      if (v2 == null) {
+        return false;
+      }
+      if (_requireCasting) {
+        v1 = (Comparable) FunctionInvokeUtils.convert(v1, _commonCastType);
+        v2 = (Comparable) FunctionInvokeUtils.convert(v2, _commonCastType);
+        assert v1 != null && v2 != null;
+      }
+      return _comparisonResultPredicate.test(v1.compareTo(v2));
+    }
   }
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
index 7c34c8e46c..88eb4f37de 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/operands/TransformOperand.java
@@ -18,12 +18,11 @@
  */
 package org.apache.pinot.query.runtime.operator.operands;
 
-
 import com.google.common.base.Preconditions;
 import java.util.List;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.query.planner.logical.RexExpression;
-import org.apache.pinot.query.runtime.operator.utils.FunctionInvokeUtils;
 import org.apache.pinot.query.runtime.operator.utils.OperatorUtils;
 
 
@@ -43,7 +42,6 @@ public abstract class TransformOperand {
     }
   }
 
-  @SuppressWarnings({"ConstantConditions", "rawtypes", "unchecked"})
   private static TransformOperand toTransformOperand(RexExpression.FunctionCall functionCall,
       DataSchema inputDataSchema) {
     final List<RexExpression> functionOperands = functionCall.getFunctionOperands();
@@ -65,77 +63,17 @@ public abstract class TransformOperand {
             "BOOL / IS_TRUE takes one argument, passed in argument size:" + operandSize);
         return new FilterOperand.True(functionOperands.get(0), inputDataSchema);
       case "equals":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) == 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) == 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v == 0);
       case "notEquals":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) != 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) != 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v != 0);
       case "greaterThan":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) > 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) > 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v > 0);
       case "greaterThanOrEqual":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) >= 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) >= 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v >= 0);
       case "lessThan":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) < 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) < 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v < 0);
       case "lessThanOrEqual":
-        return new FilterOperand.Predicate(functionOperands, inputDataSchema) {
-          @Override
-          public Boolean apply(Object[] row) {
-            if (_requireCasting) {
-              return ((Comparable) FunctionInvokeUtils.convert(_lhs.apply(row), _commonCastType)).compareTo(
-                  FunctionInvokeUtils.convert(_rhs.apply(row), _commonCastType)) <= 0;
-            } else {
-              return ((Comparable) _lhs.apply(row)).compareTo(_rhs.apply(row)) <= 0;
-            }
-          }
-        };
+        return new FilterOperand.Predicate(functionOperands, inputDataSchema, v -> v <= 0);
       default:
         return new FunctionOperand(functionCall, inputDataSchema);
     }
@@ -149,5 +87,6 @@ public abstract class TransformOperand {
     return _resultType;
   }
 
+  @Nullable
   public abstract Object apply(Object[] row);
 }
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
index 81bd7dea0c..e3e466d7d2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/AggregationUtils.java
@@ -24,6 +24,7 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 import java.util.function.Function;
+import javax.annotation.Nullable;
 import org.apache.pinot.common.utils.DataSchema;
 import org.apache.pinot.core.data.table.Key;
 import org.apache.pinot.query.planner.logical.RexExpression;
@@ -38,7 +39,6 @@ import org.apache.pinot.spi.data.FieldSpec;
  * <p>Accumulation is used by {@code WindowAggregateOperator} and {@code AggregateOperator}.
  */
 public class AggregationUtils {
-
   private AggregationUtils() {
   }
 
@@ -54,54 +54,92 @@ public class AggregationUtils {
     return new Key(new Object[0]);
   }
 
-  private static Object mergeSum(Object left, Object right) {
-    return ((Number) left).doubleValue() + ((Number) right).doubleValue();
+  // TODO: Use the correct type for SUM/MIN/MAX instead of always using double
+
+  @Nullable
+  private static Object mergeSum(@Nullable Object agg, @Nullable Object value) {
+    if (agg == null) {
+      return value;
+    }
+    if (value == null) {
+      return agg;
+    }
+    return ((Number) agg).doubleValue() + ((Number) value).doubleValue();
   }
 
-  private static Object mergeMin(Object left, Object right) {
-    return Math.min(((Number) left).doubleValue(), ((Number) right).doubleValue());
+  @Nullable
+  private static Object mergeMin(@Nullable Object agg, @Nullable Object value) {
+    if (agg == null) {
+      return value;
+    }
+    if (value == null) {
+      return agg;
+    }
+    return Math.min(((Number) agg).doubleValue(), ((Number) value).doubleValue());
   }
 
-  private static Object mergeMax(Object left, Object right) {
-    return Math.max(((Number) left).doubleValue(), ((Number) right).doubleValue());
+  @Nullable
+  private static Object mergeMax(@Nullable Object agg, @Nullable Object value) {
+    if (agg == null) {
+      return value;
+    }
+    if (value == null) {
+      return agg;
+    }
+    return Math.max(((Number) agg).doubleValue(), ((Number) value).doubleValue());
   }
 
-  private static Boolean mergeBoolAnd(Object left, Object right) {
-    return ((Boolean) left) && ((Boolean) right);
+  @Nullable
+  private static Boolean mergeBoolAnd(@Nullable Object agg, @Nullable Object value) {
+    if (agg == null) {
+      return (Boolean) value;
+    }
+    if (value == null) {
+      return (Boolean) agg;
+    }
+    return ((Boolean) agg) & ((Boolean) value);
   }
 
-  private static Boolean mergeBoolOr(Object left, Object right) {
-    return ((Boolean) left) || ((Boolean) right);
+  @Nullable
+  private static Boolean mergeBoolOr(@Nullable Object agg, @Nullable Object value) {
+    if (agg == null) {
+      return (Boolean) value;
+    }
+    if (value == null) {
+      return (Boolean) agg;
+    }
+    return ((Boolean) agg) | ((Boolean) value);
   }
 
   private static class MergeCounts implements AggregationUtils.Merger {
 
     @Override
-    public Object initialize(Object other, DataSchema.ColumnDataType dataType) {
-      return other == null ? 0 : 1;
+    public Long init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+      return value == null ? 0L : 1L;
     }
 
     @Override
-    public Object merge(Object left, Object right) {
-      return ((Number) left).doubleValue() + (right == null ? 0 : 1);
+    public Long merge(Object agg, @Nullable Object value) {
+      return value == null ? (long) agg : (long) agg + 1;
     }
   }
 
   public interface Merger {
+
     /**
-     * Initializes the merger based on the first input
+     * Initializes the merger based on the column data type and first value.
      */
-    default Object initialize(Object other, DataSchema.ColumnDataType dataType) {
-      // TODO: Initialize as a double so that if only one row is returned it matches the type when many rows are
-      //       returned
-      return other == null ? dataType.getNullPlaceholder() : other;
+    @Nullable
+    default Object init(@Nullable Object value, DataSchema.ColumnDataType dataType) {
+      return value;
     }
 
     /**
-     * Merges the existing aggregate (the result of {@link #initialize(Object, DataSchema.ColumnDataType)}) with
+     * Merges the existing aggregate (the result of {@link #init(Object, DataSchema.ColumnDataType)}) with
      * the new value coming in (which may be an aggregate in and of itself).
      */
-    Object merge(Object agg, Object value);
+    @Nullable
+    Object merge(@Nullable Object agg, @Nullable Object value);
   }
 
   /**
@@ -169,7 +207,7 @@ public class AggregationUtils {
       Object value = _inputRef == -1 ? _literal : row[_inputRef];
 
       if (currentRes == null) {
-        _results.put(key, _merger.initialize(value, _dataType));
+        _results.put(key, _merger.init(value, _dataType));
       } else {
         Object mergedResult = _merger.merge(currentRes, value);
         _results.put(key, mergedResult);
diff --git a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
index de26cdfab4..851748d2b2 100644
--- a/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
+++ b/pinot-query-runtime/src/main/java/org/apache/pinot/query/runtime/operator/utils/FunctionInvokeUtils.java
@@ -18,13 +18,12 @@
  */
 package org.apache.pinot.query.runtime.operator.utils;
 
+import javax.annotation.Nullable;
 import org.apache.pinot.common.utils.DataSchema;
 
 
 public class FunctionInvokeUtils {
-
   private FunctionInvokeUtils() {
-    // do not instantiate.
   }
 
   /**
@@ -35,7 +34,8 @@ public class FunctionInvokeUtils {
    * @param columnDataType desired column data type
    * @return converted entry
    */
-  public static Object convert(Object inputObj, DataSchema.ColumnDataType columnDataType) {
+  @Nullable
+  public static Object convert(@Nullable Object inputObj, DataSchema.ColumnDataType columnDataType) {
     if (columnDataType.isNumber() && columnDataType != DataSchema.ColumnDataType.BIG_DECIMAL) {
       return inputObj == null ? null : columnDataType.convert(inputObj);
     } else {
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
index fce25e7e3f..bda5086a40 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/AggregateOperatorTest.java
@@ -21,6 +21,7 @@ package org.apache.pinot.query.runtime.operator;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import java.util.Arrays;
+import java.util.Collections;
 import java.util.List;
 import org.apache.calcite.sql.SqlKind;
 import org.apache.pinot.common.utils.DataSchema;
@@ -201,7 +202,7 @@ public class AggregateOperatorTest {
 
     AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class);
     Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
-    Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
+    Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d);
     DataSchema outSchema = new DataSchema(new String[]{"sum"}, new ColumnDataType[]{DOUBLE});
     AggregateOperator operator =
         new AggregateOperator(OperatorTestUtil.getDefaultContext(), _input, outSchema, calls, group, inSchema,
@@ -213,7 +214,7 @@ public class AggregateOperatorTest {
     // Then:
     // should call merger twice, one from second row in first block and two from the first row
     // in second block
-    Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
+    Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any());
     Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
     Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 12d},
         "Expected two columns (group by key, agg value)");
@@ -226,8 +227,8 @@ public class AggregateOperatorTest {
     RexExpression.FunctionCall agg = getSum(new RexExpression.InputRef(0));
     DataSchema inSchema = new DataSchema(new String[]{"group", "arg"}, new ColumnDataType[]{INT, INT});
     AggregateOperator sum0GroupBy1 = new AggregateOperator(OperatorTestUtil.getDefaultContext(), upstreamOperator,
-        OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Arrays.asList(agg),
-        Arrays.asList(new RexExpression.InputRef(1)), inSchema);
+        OperatorTestUtil.getDataSchema(OperatorTestUtil.OP_1), Collections.singletonList(agg),
+        Collections.singletonList(new RexExpression.InputRef(1)), inSchema);
     TransferableBlock result = sum0GroupBy1.getNextBlock();
     while (result.isNoOpBlock()) {
       result = sum0GroupBy1.getNextBlock();
diff --git a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
index 817c7239d6..fc54b8941e 100644
--- a/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
+++ b/pinot-query-runtime/src/test/java/org/apache/pinot/query/runtime/operator/WindowAggregateOperatorTest.java
@@ -319,7 +319,7 @@ public class WindowAggregateOperatorTest {
 
     AggregationUtils.Merger merger = Mockito.mock(AggregationUtils.Merger.class);
     Mockito.when(merger.merge(Mockito.any(), Mockito.any())).thenReturn(12d);
-    Mockito.when(merger.initialize(Mockito.any(), Mockito.any())).thenReturn(1d);
+    Mockito.when(merger.init(Mockito.any(), Mockito.any())).thenReturn(1d);
     DataSchema outSchema =
         new DataSchema(new String[]{"group", "arg", "sum"}, new DataSchema.ColumnDataType[]{INT, INT, DOUBLE});
     WindowAggregateOperator operator =
@@ -334,7 +334,7 @@ public class WindowAggregateOperatorTest {
     // Then:
     // should call merger twice, one from second row in first block and two from the first row
     // in second block
-    Mockito.verify(merger, Mockito.times(1)).initialize(Mockito.any(), Mockito.any());
+    Mockito.verify(merger, Mockito.times(1)).init(Mockito.any(), Mockito.any());
     Mockito.verify(merger, Mockito.times(2)).merge(Mockito.any(), Mockito.any());
     Assert.assertEquals(resultBlock.getContainer().get(0), new Object[]{1, 1, 12d},
         "Expected three columns (original two columns, agg literal value)");
diff --git a/pinot-query-runtime/src/test/resources/queries/NullHandling.json b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
new file mode 100644
index 0000000000..f51701317c
--- /dev/null
+++ b/pinot-query-runtime/src/test/resources/queries/NullHandling.json
@@ -0,0 +1,51 @@
+{
+  "null_on_intermediate": {
+    "tables": {
+      "tbl1" : {
+        "schema": [
+          {"name": "strCol1", "type": "STRING"},
+          {"name": "intCol1", "type": "INT"},
+          {"name": "strCol2", "type": "STRING"}
+        ],
+        "inputs": [
+          ["foo", 1, "foo"],
+          ["bar", 2, "alice"]
+        ]
+      },
+      "tbl2" : {
+        "schema": [
+          {"name": "strCol1", "type": "STRING"},
+          {"name": "strCol2", "type": "STRING"},
+          {"name": "intCol1", "type": "INT"},
+          {"name": "doubleCol1", "type": "DOUBLE"}
+        ],
+        "inputs": [
+          ["foo", "bob", 3, 3.1416],
+          ["alice", "alice", 4, 2.7183]
+        ]
+      }
+    },
+    "queries": [
+      {
+        "description": "LEFT JOIN and FILTER",
+        "sql": "SELECT {tbl1}.strCol2, {tbl2}.doubleCol1 IS NULL OR {tbl1}.intCol1 > 3 AS boolFlag FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+      },
+      {
+        "description": "LEFT JOIN and TRANSFORM",
+        "sql": "SELECT {tbl1}.strCol2, {tbl1}.intCol1 * {tbl2}.doubleCol1 + {tbl2}.intCol1 FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+      },
+      {
+        "description": "LEFT JOIN and AGGREGATE",
+        "sql": "SELECT COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1"
+      },
+      {
+        "description": "LEFT JOIN and GROUP BY",
+        "sql": "SELECT {tbl1}.strCol2, {tbl2}.intCol1, COUNT(*) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2, {tbl2}.intCol1"
+      },
+      {
+        "description": "LEFT JOIN and GROUP BY with AGGREGATE",
+        "sql": "SELECT {tbl1}.strCol2, COUNT({tbl2}.intCol1), MIN({tbl2}.intCol1), MAX({tbl2}.doubleCol1), SUM({tbl2}.doubleCol1) FROM {tbl1} LEFT OUTER JOIN {tbl2} ON {tbl1}.strCol1 = {tbl2}.strCol1 GROUP BY {tbl1}.strCol2"
+      }
+    ]
+  }
+}
\ No newline at end of file
diff --git a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
index 5caa982bec..afc989fded 100644
--- a/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
+++ b/pinot-query-runtime/src/test/resources/queries/WindowFunctions.json
@@ -572,7 +572,7 @@
         "description": "Single empty OVER() with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -580,7 +580,7 @@
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
         "keepOutputRowOrder": true,
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -1335,7 +1335,7 @@
         "description": "Multiple empty OVER()s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER() as count, MIN(double_col) OVER() as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)",
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -1343,7 +1343,7 @@
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(ORDER BY string_col) as count, MIN(double_col) OVER(ORDER BY string_col) as min FROM {tbl} WHERE string_col = 'a' AND bool_col != false AND int_col > 200)",
         "keepOutputRowOrder": true,
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -2477,7 +2477,7 @@
         "description": "Single OVER(PARTITION BY) with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -2486,7 +2486,7 @@
         "comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are",
         "keepOutputRowOrder": false,
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -3445,7 +3445,7 @@
         "description": "Multiple OVER(PARTITION BY)s with select col and filter which matches no rows in a sub-query and outer query with aggregation on that column",
         "sql": "SELECT SUM(count) FROM (SELECT string_col, COUNT(bool_col) OVER(PARTITION BY string_col) as count, AVG(int_col) OVER(PARTITION BY string_col) as avg FROM {tbl} WHERE string_col = 'a' AND bool_col = false AND int_col > 200)",
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {
@@ -3454,7 +3454,7 @@
         "comments": "Cannot enforce a global ordering as partitions aren't ordered, just keys within a partition are",
         "keepOutputRowOrder": false,
         "outputs": [
-          [0]
+          [null]
         ]
       },
       {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@pinot.apache.org
For additional commands, e-mail: commits-help@pinot.apache.org