You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2018/08/13 01:06:04 UTC

[4/4] calcite git commit: [CALCITE-2402] Implement regr functions: COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY

[CALCITE-2402] Implement regr functions: COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY

Use aggregate filters in the AggregateReduceFunctionsRule expansions so that only rows whose arguments are all non-null are aggregated.

Close apache/calcite#779


Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/ca858dd7
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/ca858dd7
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/ca858dd7

Branch: refs/heads/master
Commit: ca858dd725dea6bf9b4a9059cf1c3ba98bd82f26
Parents: 5574873
Author: snuyanzin <sn...@gmail.com>
Authored: Fri Jun 29 11:41:24 2018 +0300
Committer: Julian Hyde <jh...@apache.org>
Committed: Sun Aug 12 18:04:44 2018 -0700

----------------------------------------------------------------------
 core/src/main/codegen/templates/Parser.jj       |   1 +
 .../calcite/adapter/enumerable/RexImpTable.java |   3 +
 .../rel/rules/AggregateReduceFunctionsRule.java | 259 ++++++++++++++++++-
 .../java/org/apache/calcite/sql/SqlKind.java    |  23 +-
 .../calcite/sql/fun/SqlCountAggFunction.java    |  10 +-
 .../calcite/sql/fun/SqlCovarAggFunction.java    |   8 +-
 .../sql/fun/SqlRegrCountAggFunction.java        |  37 +++
 .../calcite/sql/fun/SqlStdOperatorTable.java    |   6 +
 .../apache/calcite/sql/type/ReturnTypes.java    |   2 +-
 .../sql2rel/StandardConvertletTable.java        | 250 ++++++++++++++----
 .../apache/calcite/sql/test/SqlAdvisorTest.java |   1 +
 core/src/test/resources/sql/agg.iq              |  97 +++++++
 core/src/test/resources/sql/winagg.iq           | 133 ++++++++++
 site/_docs/reference.md                         |   2 +-
 14 files changed, 761 insertions(+), 71 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/codegen/templates/Parser.jj
----------------------------------------------------------------------
diff --git a/core/src/main/codegen/templates/Parser.jj b/core/src/main/codegen/templates/Parser.jj
index be05d9c..0dc23eb 100644
--- a/core/src/main/codegen/templates/Parser.jj
+++ b/core/src/main/codegen/templates/Parser.jj
@@ -5171,6 +5171,7 @@ SqlIdentifier ReservedFunctionName() :
     |   <PERCENT_RANK>
     |   <POWER>
     |   <RANK>
+    |   <REGR_COUNT>
     |   <REGR_SXX>
     |   <REGR_SYY>
     |   <ROW_NUMBER>

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 5ba5959..80d5541 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -191,6 +191,7 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RADIANS;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND_INTEGER;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RANK;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REGR_COUNT;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REINTERPRET;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REPLACE;
 import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND;
@@ -438,6 +439,7 @@ public class RexImpTable {
     map.put(LOCALTIMESTAMP, systemFunctionImplementor);
 
     aggMap.put(COUNT, constructorSupplier(CountImplementor.class));
+    aggMap.put(REGR_COUNT, constructorSupplier(CountImplementor.class));
     aggMap.put(SUM0, constructorSupplier(SumImplementor.class));
     aggMap.put(SUM, constructorSupplier(SumImplementor.class));
     Supplier<MinMaxImplementor> minMax =
@@ -464,6 +466,7 @@ public class RexImpTable {
     winAggMap.put(LAG, constructorSupplier(LagImplementor.class));
     winAggMap.put(NTILE, constructorSupplier(NtileImplementor.class));
     winAggMap.put(COUNT, constructorSupplier(CountWinImplementor.class));
+    winAggMap.put(REGR_COUNT, constructorSupplier(CountWinImplementor.class));
   }
 
   private <T> Supplier<T> constructorSupplier(Class<T> klass) {

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 8bdd6c1..68f6b16 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -72,6 +72,17 @@ import java.util.Map;
  *
  * <li>VAR_SAMP(x) &rarr; (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
  *        / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
+ *
+ * <li>COVAR_POP(x, y) &rarr; (SUM(x * y) - SUM(x, y) * SUM(y, x)
+ *     / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+ *
+ * <li>COVAR_SAMP(x, y) &rarr; (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y))
+ *     / CASE REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END
+ *
+ * <li>REGR_SXX(x, y) &rarr; REGR_COUNT(x, y) * VAR_POP(y)
+ *
+ * <li>REGR_SYY(x, y) &rarr; REGR_COUNT(x, y) * VAR_POP(x)
+ *
  * </ul>
  *
  * <p>Since many of these rewrites introduce multiple occurrences of simpler
@@ -127,7 +138,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
    * Returns whether the aggregate call is a reducible function
    */
   private boolean isReducible(final SqlKind kind) {
-    if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) {
+    if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)
+        || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind)) {
       return true;
     }
     switch (kind) {
@@ -201,6 +213,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
       List<RexNode> inputExprs) {
     final SqlKind kind = oldCall.getAggregation().getKind();
     if (isReducible(kind)) {
+      final Integer y;
+      final Integer x;
       switch (kind) {
       case SUM:
         // replace original SUM(x) with
@@ -209,6 +223,37 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
       case AVG:
         // replace original AVG(x) with SUM(x) / COUNT(x)
         return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
+      case COVAR_POP:
+        // replace original COVAR_POP(x, y) with
+        //     (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y))
+        //     / REGR_COUNT(x, y)
+        return reduceCovariance(oldAggRel, oldCall, true, newCalls,
+            aggCallMapping, inputExprs);
+      case COVAR_SAMP:
+        // replace original COVAR_SAMP(x, y) with
+        //     (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y))
+        //     / CASE REGR_COUNT(x, y) WHEN 1 THEN NULL
+        //       ELSE REGR_COUNT(x, y) - 1 END
+        return reduceCovariance(oldAggRel, oldCall, false, newCalls,
+            aggCallMapping, inputExprs);
+      case REGR_SXX:
+        // replace original REGR_SXX(x, y) with
+        // REGR_COUNT(x, y) * VAR_POP(y)
+        assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+        x = oldCall.getArgList().get(0);
+        y = oldCall.getArgList().get(1);
+        //noinspection SuspiciousNameCombination
+        return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping,
+            inputExprs, y, y, x);
+      case REGR_SYY:
+        // replace original REGR_SYY(x, y) with
+        // REGR_COUNT(x, y) * VAR_POP(x)
+        assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+        x = oldCall.getArgList().get(0);
+        y = oldCall.getArgList().get(1);
+        //noinspection SuspiciousNameCombination
+        return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping,
+            inputExprs, x, x, y);
       case STDDEV_POP:
         // replace original STDDEV_POP(x) with
         //   SQRT(
@@ -260,16 +305,17 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
       RelDataType operandType,
       Aggregate oldAggRel,
       AggregateCall oldCall,
-      int argOrdinal) {
+      int argOrdinal,
+      int filter) {
     final Aggregate.AggCallBinding binding =
         new Aggregate.AggCallBinding(typeFactory, aggFunction,
             ImmutableList.of(operandType), oldAggRel.getGroupCount(),
-            oldCall.filterArg >= 0);
+            filter >= 0);
     return AggregateCall.create(aggFunction,
         oldCall.isDistinct(),
         oldCall.isApproximate(),
         ImmutableIntList.of(argOrdinal),
-        oldCall.filterArg,
+        filter,
         aggFunction.inferReturnType(binding),
         null);
   }
@@ -346,6 +392,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
         getFieldType(
             oldAggRel.getInput(),
             arg);
+
     final AggregateCall sumZeroCall =
         AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(),
             oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
@@ -424,7 +471,6 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
 
     final RexNode argRef =
         rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
-    final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
 
     final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
         argRef, argRef);
@@ -432,7 +478,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
 
     final AggregateCall sumArgSquaredAggCall =
         createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM,
-            argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
+            argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal, -1);
 
     final RexNode sumArgSquared =
         rexBuilder.addAggCall(sumArgSquaredAggCall,
@@ -530,6 +576,207 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
         oldCall.getType(), result);
   }
 
+  /** Creates {@code SUM} over input field {@code argOrdinal}, filtered by
+   * {@code filterArg}, registers it via {@link RexBuilder#addAggCall}, and
+   * returns the expression for its result. */
+  private RexNode getSumAggregatedRexNode(Aggregate oldAggRel,
+      AggregateCall oldCall, List<AggregateCall> newCalls,
+      Map<AggregateCall, RexNode> aggCallMapping, RexBuilder rexBuilder,
+      int argOrdinal, int filterArg) {
+    final AggregateCall aggregateCall =
+        AggregateCall.create(SqlStdOperatorTable.SUM,
+            oldCall.isDistinct(),
+            oldCall.isApproximate(),
+            ImmutableIntList.of(argOrdinal),
+            filterArg,
+            oldAggRel.getGroupCount(),
+            oldAggRel.getInput(),
+            null,
+            null);
+    return rexBuilder.addAggCall(aggregateCall,
+        oldAggRel.getGroupCount(),
+        oldAggRel.indicator,
+        newCalls,
+        aggCallMapping,
+        ImmutableList.of(aggregateCall.getType()));
+  }
+
+  /** Like {@link #getSumAggregatedRexNode}, but derives the SUM's type from
+   * {@code operandType}; used for computed arguments such as {@code x * y}. */
+  private RexNode getSumAggregatedRexNodeWithBinding(Aggregate oldAggRel,
+      AggregateCall oldCall, List<AggregateCall> newCalls,
+      Map<AggregateCall, RexNode> aggCallMapping, RelDataType operandType,
+      int argOrdinal, int filter) {
+    final RelOptCluster cluster = oldAggRel.getCluster();
+    final AggregateCall sumAggCall =
+        createAggregateCallWithBinding(cluster.getTypeFactory(),
+            SqlStdOperatorTable.SUM, operandType, oldAggRel, oldCall, argOrdinal, filter);
+
+    return cluster.getRexBuilder().addAggCall(sumAggCall,
+        oldAggRel.getGroupCount(),
+        oldAggRel.indicator,
+        newCalls,
+        aggCallMapping,
+        ImmutableList.of(sumAggCall.getType()));
+  }
+
+  /** Creates {@code REGR_COUNT} over {@code argOrdinals}, filtered by
+   * {@code filterArg}; registers it and returns the expression for its result. */
+  private RexNode getRegrCountRexNode(Aggregate oldAggRel,
+      AggregateCall oldCall, List<AggregateCall> newCalls,
+      Map<AggregateCall, RexNode> aggCallMapping, ImmutableIntList argOrdinals,
+      ImmutableList<RelDataType> operandTypes, int filterArg) {
+    // Argument ordinals refer to fields of the aggregate's input, so pass the
+    // input (as getSumAggregatedRexNode does), not the aggregate itself.
+    final AggregateCall countArgAggCall =
+        AggregateCall.create(SqlStdOperatorTable.REGR_COUNT,
+            oldCall.isDistinct(),
+            oldCall.isApproximate(),
+            argOrdinals,
+            filterArg,
+            oldAggRel.getGroupCount(),
+            oldAggRel.getInput(),
+            null,
+            null);
+
+    return oldAggRel.getCluster().getRexBuilder().addAggCall(countArgAggCall,
+        oldAggRel.getGroupCount(),
+        oldAggRel.indicator,
+        newCalls,
+        aggCallMapping,
+        operandTypes);
+  }
+
+  /**
+   * Reduces {@code REGR_SXX} / {@code REGR_SYY} to {@code SUM(a * b) - SUM(a) *
+   * SUM(b) / REGR_COUNT} for a, b the inputs at {@code xIndex}, {@code yIndex};
+   * every aggregate skips rows where a, b or the input at {@code nullFilterIndex}
+   * is null. The quotient is NULL when the count is 0.
+   */
+  private RexNode reduceRegrSzz(Aggregate oldAggRel, AggregateCall oldCall,
+      List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping,
+      List<RexNode> inputExprs, int xIndex, int yIndex, int nullFilterIndex) {
+    // regr_sxx(x, y) ==>
+    //    sum(y * y, x) - sum(y, x) * sum(y, x) / regr_count(x, y)
+    // where sum(a, b) denotes the sum of a over rows where b is not null.
+
+    final RelOptCluster cluster = oldAggRel.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+    final RelDataType argXType = getFieldType(oldAggRel.getInput(), xIndex);
+    final RelDataType argYType =
+        xIndex == yIndex ? argXType : getFieldType(oldAggRel.getInput(), yIndex);
+    final RelDataType nullFilterIndexType =
+        nullFilterIndex == yIndex ? argYType : getFieldType(oldAggRel.getInput(), nullFilterIndex);
+
+    final RelDataType oldCallType =
+        typeFactory.createTypeWithNullability(oldCall.getType(),
+            argXType.isNullable() || argYType.isNullable() || nullFilterIndexType.isNullable());
+
+    final RexNode argX =
+        rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true);
+    final RexNode argY =
+        rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true);
+    final RexNode argNullFilter =
+        rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true);
+
+    final RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+    final int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY);
+
+    final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
+        rexBuilder.makeCall(SqlStdOperatorTable.AND,
+            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)),
+        rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter));
+    final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+    final RexNode sumXY = getSumAggregatedRexNodeWithBinding(
+        oldAggRel, oldCall, newCalls, aggCallMapping, argXArgY.getType(),
+        argSquaredOrdinal, argXAndYNotNullFilterOrdinal);
+    final RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true);
+
+    final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall,
+        newCalls, aggCallMapping, rexBuilder, xIndex, argXAndYNotNullFilterOrdinal);
+    final RexNode sumY = xIndex == yIndex
+        ? sumX
+        : getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+            aggCallMapping, rexBuilder, yIndex, argXAndYNotNullFilterOrdinal);
+
+    final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+
+    final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping,
+        ImmutableIntList.of(xIndex), ImmutableList.of(argXType), argXAndYNotNullFilterOrdinal);
+
+    RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
+    RexNode nul = rexBuilder.constantNull();
+    final RexNode avgSumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
+        rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero), nul,
+            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg));
+    final RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true);
+    final RexNode result =
+        rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast);
+    return rexBuilder.makeCast(oldCall.getType(), result);
+  }
+
+  private RexNode reduceCovariance(
+      Aggregate oldAggRel,
+      AggregateCall oldCall,
+      boolean biased,
+      List<AggregateCall> newCalls,
+      Map<AggregateCall, RexNode> aggCallMapping,
+      List<RexNode> inputExprs) {
+    // covar_pop(x, y) ==>
+    //     (sum(x * y) - sum(x) * sum(y) / regr_count(x, y)) / regr_count(x, y)
+    // covar_samp(x, y) ==>
+    //     (sum(x * y) - sum(x) * sum(y) / regr_count(x, y))
+    //     / CASE regr_count(x, y) WHEN 1 THEN NULL ELSE regr_count(x, y) - 1 END
+    // All aggregates are filtered to rows where both x and y are non-null;
+    // "biased" selects the covar_pop denominator.
+    final RelOptCluster cluster = oldAggRel.getCluster();
+    final RexBuilder rexBuilder = cluster.getRexBuilder();
+    final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+    assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+    final int argXOrdinal = oldCall.getArgList().get(0);
+    final int argYOrdinal = oldCall.getArgList().get(1);
+    final RelDataType argXOrdinalType = getFieldType(oldAggRel.getInput(), argXOrdinal);
+    final RelDataType argYOrdinalType = getFieldType(oldAggRel.getInput(), argYOrdinal);
+    final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(),
+        argXOrdinalType.isNullable() || argYOrdinalType.isNullable());
+    final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true);
+    final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true);
+    final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
+        rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+        rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY));
+    final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+    final RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+    final int argXYOrdinal = lookupOrAdd(inputExprs, argXY);
+    final RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls,
+        aggCallMapping, argXY.getType(), argXYOrdinal, argXAndYNotNullFilterOrdinal);
+    final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+        aggCallMapping, rexBuilder, argXOrdinal, argXAndYNotNullFilterOrdinal);
+    final RexNode sumY = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+        aggCallMapping, rexBuilder, argYOrdinal, argXAndYNotNullFilterOrdinal);
+    final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+    final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping,
+        ImmutableIntList.of(argXOrdinal, argYOrdinal),
+        ImmutableList.of(argXOrdinalType, argYOrdinalType),
+        argXAndYNotNullFilterOrdinal);
+    final RexNode avgSumSquaredArg =
+         rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg);
+    final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg);
+    final RexNode denominator;
+    if (biased) {
+      denominator = countArg;
+    } else {
+      final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
+      final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
+      final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
+      final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
+      denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
+    }
+    final RexNode result = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
+    return rexBuilder.makeCast(oldCall.getType(), result);
+  }
+
   /**
    * Finds the ordinal of an element in a list, or adds it.
    *

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/SqlKind.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index 54b390b..cbf201f 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -810,14 +810,11 @@ public enum SqlKind {
   /** The {@code GROUP_ID()} function. */
   GROUP_ID,
 
-  /**
-   * the internal permute function in match_recognize cluse
-   */
+  /** The internal "permute" function in a MATCH_RECOGNIZE clause. */
   PATTERN_PERMUTE,
 
-  /**
-   * the special patterns to exclude enclosing pattern from output in match_recognize clause
-   */
+  /** The special patterns to exclude enclosing pattern from output in a
+   * MATCH_RECOGNIZE clause. */
   PATTERN_EXCLUDED,
 
   // Aggregate functions
@@ -858,6 +855,9 @@ public enum SqlKind {
   /** The {@code COVAR_SAMP} aggregate function. */
   COVAR_SAMP,
 
+  /** The {@code REGR_COUNT} aggregate function. */
+  REGR_COUNT,
+
   /** The {@code REGR_SXX} aggregate function. */
   REGR_SXX,
 
@@ -1064,7 +1064,7 @@ public enum SqlKind {
    */
   public static final EnumSet<SqlKind> AGGREGATE =
       EnumSet.of(COUNT, SUM, SUM0, MIN, MAX, LEAD, LAG, FIRST_VALUE,
-          LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY,
+          LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY,
           AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT,
           FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK,
           CUME_DIST);
@@ -1180,6 +1180,15 @@ public enum SqlKind {
       EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP);
 
   /**
+   * Category of covariance and regression aggregate functions.
+   *
+   * <p>Consists of {@link #COVAR_POP}, {@link #COVAR_SAMP}, {@link #REGR_COUNT},
+   * {@link #REGR_SXX}, {@link #REGR_SYY}.
+   */
+  public static final Set<SqlKind> COVAR_AVG_AGG_FUNCTIONS =
+      EnumSet.of(COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY);
+
+  /**
    * Category of comparison operators.
    *
    * <p>Consists of:

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
index e053294..db54102 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlSplittableAggFunction;
 import org.apache.calcite.sql.SqlSyntax;
 import org.apache.calcite.sql.type.OperandTypes;
 import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.validate.SqlValidator;
 import org.apache.calcite.sql.validate.SqlValidatorScope;
@@ -45,11 +46,12 @@ public class SqlCountAggFunction extends SqlAggFunction {
   //~ Constructors -----------------------------------------------------------
 
   public SqlCountAggFunction(String name) {
+    this(name, SqlValidator.STRICT ? OperandTypes.ANY : OperandTypes.ONE_OR_MORE);
+  }
+
+  public SqlCountAggFunction(String name, SqlOperandTypeChecker sqlOperandTypeChecker) {
     super(name, null, SqlKind.COUNT, ReturnTypes.BIGINT, null,
-        SqlValidator.STRICT
-            ? OperandTypes.ANY
-            : OperandTypes.ONE_OR_MORE,
-        SqlFunctionCategory.NUMERIC, false, false);
+        sqlOperandTypeChecker, SqlFunctionCategory.NUMERIC, false, false);
   }
 
   //~ Methods ----------------------------------------------------------------

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
index 8c62290..8591959 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
@@ -43,16 +43,14 @@ public class SqlCovarAggFunction extends SqlAggFunction {
     super(kind.name(),
         null,
         kind,
-        ReturnTypes.COVAR_FUNCTION,
+        kind == SqlKind.REGR_COUNT ? ReturnTypes.BIGINT : ReturnTypes.COVAR_REGR_FUNCTION,
         null,
         OperandTypes.NUMERIC_NUMERIC,
         SqlFunctionCategory.NUMERIC,
         false,
         false);
-    Preconditions.checkArgument(kind == SqlKind.COVAR_POP
-        || kind == SqlKind.COVAR_SAMP
-        || kind == SqlKind.REGR_SXX
-        || kind == SqlKind.REGR_SYY);
+    Preconditions.checkArgument(SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind),
+        "unsupported sql kind: " + kind);
   }
 
   @Deprecated // to be removed before 2.0

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
new file mode 100644
index 0000000..4408272
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql.fun;
+
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.OperandTypes;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Definition of the SQL <code>REGR_COUNT</code> aggregation function, which
+ * returns the number of rows in which neither argument is <code>null</code>.
+ *
+ * <p>The {@code kind} argument is only validated; the function's kind is
+ * {@link SqlKind#COUNT}, inherited from {@link SqlCountAggFunction}. */
+public class SqlRegrCountAggFunction extends SqlCountAggFunction {
+  public SqlRegrCountAggFunction(SqlKind kind) {
+    super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC);
+    Preconditions.checkArgument(SqlKind.REGR_COUNT == kind, "unsupported sql kind: " + kind);
+  }
+}
+
+// End SqlRegrCountAggFunction.java

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 0064cba..ea17ec8 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -915,6 +915,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
       new SqlAvgAggFunction(SqlKind.STDDEV_POP);
 
   /**
+   * <code>REGR_COUNT</code> aggregate function (rows with both arguments non-null).
+   */
+  public static final SqlAggFunction REGR_COUNT =
+      new SqlRegrCountAggFunction(SqlKind.REGR_COUNT);
+
+  /**
    * <code>REGR_SXX</code> aggregate function.
    */
   public static final SqlAggFunction REGR_SXX =

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 8b07b83..fc0022d 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -766,7 +766,7 @@ public abstract class ReturnTypes {
     }
   };
 
-  public static final SqlReturnTypeInference COVAR_FUNCTION = opBinding -> {
+  public static final SqlReturnTypeInference COVAR_REGR_FUNCTION = opBinding -> {
     final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
     final RelDataType relDataType =
         typeFactory.getTypeSystem().deriveCovarType(typeFactory,

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index e9b7cf6..987e821 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -67,6 +67,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
 import org.apache.calcite.sql.type.SqlTypeName;
 import org.apache.calcite.sql.type.SqlTypeUtil;
 import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorImpl;
 import org.apache.calcite.util.Pair;
 import org.apache.calcite.util.Util;
 
@@ -77,6 +78,7 @@ import java.math.BigDecimal;
 import java.math.RoundingMode;
 import java.util.ArrayList;
 import java.util.List;
+import java.util.Objects;
 
 /**
  * Standard implementation of {@link SqlRexConvertletTable}.
@@ -237,6 +239,14 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
         new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
     registerOp(SqlStdOperatorTable.VARIANCE,
         new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
+    registerOp(SqlStdOperatorTable.COVAR_POP,
+        new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
+    registerOp(SqlStdOperatorTable.COVAR_SAMP,
+        new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
+    registerOp(SqlStdOperatorTable.REGR_SXX,
+        new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
+    registerOp(SqlStdOperatorTable.REGR_SYY,
+        new RegrCovarianceConvertlet(SqlKind.REGR_SYY));
 
     final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
     registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
@@ -342,14 +352,26 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
     SqlNodeList thenList = call.getThenOperands();
     assert whenList.size() == thenList.size();
 
+    RexBuilder rexBuilder = cx.getRexBuilder();
     final List<RexNode> exprList = new ArrayList<>();
     for (int i = 0; i < whenList.size(); i++) {
-      exprList.add(cx.convertExpression(whenList.get(i)));
-      exprList.add(cx.convertExpression(thenList.get(i)));
+      if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
+        exprList.add(rexBuilder.constantNull());
+      } else {
+        exprList.add(cx.convertExpression(whenList.get(i)));
+      }
+      if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
+        exprList.add(rexBuilder.constantNull());
+      } else {
+        exprList.add(cx.convertExpression(thenList.get(i)));
+      }
+    }
+    if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) {
+      exprList.add(rexBuilder.constantNull());
+    } else {
+      exprList.add(cx.convertExpression(call.getElseOperand()));
     }
-    exprList.add(cx.convertExpression(call.getElseOperand()));
 
-    RexBuilder rexBuilder = cx.getRexBuilder();
     RelDataType type =
         rexBuilder.deriveReturnType(call.getOperator(), exprList);
     for (int i : elseArgs(exprList.size())) {
@@ -473,11 +495,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       return castToValidatedType(cx, call, cx.convertExpression(left));
     }
     SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
+    RelDataType type = dataType.deriveType(typeFactory);
     if (SqlUtil.isNullLiteral(left, false)) {
+      final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
+      validator.setValidatedNodeType(left, type);
       return cx.convertExpression(left);
     }
     RexNode arg = cx.convertExpression(left);
-    RelDataType type = dataType.deriveType(typeFactory);
     if (type == null) {
       type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
     }
@@ -1061,6 +1085,133 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
     return rexBuilder.makeCast(type, e);
   }
 
+  /** Convertlet that expands the {@code COVAR_POP}, {@code COVAR_SAMP},
+   * {@code REGR_SXX} and {@code REGR_SYY} windowed aggregates into arithmetic
+   * over SUM and REGR_COUNT; {@code kind} selects the function. */
+  private static class RegrCovarianceConvertlet implements SqlRexConvertlet {
+    private final SqlKind kind; // one of the four kinds handled in convertCall
+
+    RegrCovarianceConvertlet(SqlKind kind) {
+      this.kind = kind;
+    }
+
+    public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+      assert call.operandCount() == 2;
+      final SqlNode arg1 = call.operand(0);
+      final SqlNode arg2 = call.operand(1);
+      final SqlNode expr;
+      final RelDataType type = // result type derived by the validator
+          cx.getValidator().getValidatedNodeType(call);
+      switch (kind) {
+      case COVAR_POP:
+        expr = expandCovariance(arg1, arg2, null, type, cx, true); // divide by n
+        break;
+      case COVAR_SAMP:
+        expr = expandCovariance(arg1, arg2, null, type, cx, false); // by n - 1
+        break;
+      case REGR_SXX:
+        expr = expandRegrSzz(arg2, arg1, type, cx, true); // squares of arg2
+        break;
+      case REGR_SYY:
+        expr = expandRegrSzz(arg1, arg2, type, cx, true); // squares of arg1
+        break;
+      default:
+        throw Util.unexpected(kind);
+      }
+      RexNode rex = cx.convertExpression(expr);
+      return cx.getRexBuilder().ensureType(type, rex, true); // restore validated type
+    }
+
+    private SqlNode expandRegrSzz( // REGR_COUNT(arg1, arg2) * biased covariance
+        final SqlNode arg1, final SqlNode arg2,
+        final RelDataType avgType, final SqlRexContext cx, boolean variance) {
+      final SqlParserPos pos = SqlParserPos.ZERO;
+      final SqlNode count = // rows where arg1 and arg2 are both non-null
+          SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2);
+      final SqlNode varPop = // covar_pop(arg1, variance ? arg1 : arg2) paired with arg2
+          expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true);
+      final RexNode varPopRex = cx.convertExpression(varPop);
+      final SqlNode varPopCast; // varPop, cast to avgType if its type differs
+      varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex);
+      return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count);
+    }
+
+    private SqlNode expandCovariance(
+        final SqlNode arg0Input,
+        final SqlNode arg1Input,
+        final SqlNode dependent, // null for COVAR_POP / COVAR_SAMP
+        final RelDataType varType, // operands and n are cast to this type
+        final SqlRexContext cx,
+        boolean biased) { // true: divide by n; false: by n - 1
+      // covar_pop(x1, x2)  ==> (sum(x1 * x2) - sum(x1) * sum(x2) / n) / n
+      // covar_samp(x1, x2) ==> (sum(x1 * x2) - sum(x1) * sum(x2) / n) / (n - 1)
+      // where n = regr_count(x1, x2), the covar_samp denominator is NULL when
+      // n = 1, and each single-argument sum is SUM(xi, other argument): rows
+      // where the other argument is null do not contribute.
+      // With a non-null 'dependent', sum(x1 * x2) is paired with 'dependent', as
+      // are sum(x1), sum(x2) and n unless 'dependent' is that argument itself.
+      final SqlParserPos pos = SqlParserPos.ZERO;
+      final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+
+      final RexNode arg0Rex = cx.convertExpression(arg0Input); // used only for its type
+      final RexNode arg1Rex = cx.convertExpression(arg1Input); // used only for its type
+
+      final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
+      final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
+      final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1);
+      final SqlNode sumArgSquared; // sum(x1 * x2)
+      final SqlNode sum0; // sum(x1)
+      final SqlNode sum1; // sum(x2)
+      final SqlNode count; // n
+      if (dependent == null) { // COVAR_POP / COVAR_SAMP
+        sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
+        sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1);
+        sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0);
+        count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1);
+      } else { // REGR_SXX / REGR_SYY, via expandRegrSzz
+        sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent);
+        sum0 = SqlStdOperatorTable.SUM.createCall(
+            pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
+        sum1 = SqlStdOperatorTable.SUM.createCall(
+            pos, arg1, Objects.equals(dependent, arg1Input) ? arg0 : dependent);
+        count = SqlStdOperatorTable.REGR_COUNT.createCall(
+            pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
+      }
+
+      final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1);
+      final SqlNode countCasted = // n, cast to varType if its type differs
+          getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
+
+      final SqlNode avgSumSquared = // sum(x1) * sum(x2) / n
+          SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
+      final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
+      SqlNode denominator;
+      if (biased) { // covar_pop
+        denominator = countCasted;
+      } else { // covar_samp: NULL when n = 1, otherwise n - 1
+        final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
+        denominator = new SqlCase(SqlParserPos.ZERO, countCasted,
+            SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)),
+            SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
+            SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
+      }
+
+      return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator);
+    }
+
+    private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
+        SqlParserPos pos, RexNode argRex) {
+      SqlNode arg; // argInput, or CAST(argInput AS varType) when the types differ
+      if (argRex != null && !argRex.getType().equals(varType)) { // null argRex: no cast
+        arg = SqlStdOperatorTable.CAST.createCall(
+            pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
+      } else {
+        arg = argInput;
+      }
+      return arg;
+    }
+  }
+
   /** Convertlet that handles {@code AVG} and {@code VARIANCE}
    * windowed aggregate functions. */
   private static class AvgVarianceConvertlet implements SqlRexConvertlet {
@@ -1106,14 +1257,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
           SqlStdOperatorTable.SUM.createCall(pos, arg);
       final RexNode sumRex = cx.convertExpression(sum);
       final SqlNode sumCast;
-      if (!sumRex.getType().equals(avgType)) {
-        sumCast = SqlStdOperatorTable.CAST.createCall(pos,
-            new SqlDataTypeSpec(
-                new SqlIdentifier(avgType.getSqlTypeName().getName(), pos),
-                avgType.getPrecision(), avgType.getScale(), null, null, pos));
-      } else {
-        sumCast = sum;
-      }
+      sumCast = getCastedSqlNode(sum, avgType, pos, sumRex);
       final SqlNode count =
           SqlStdOperatorTable.COUNT.createCall(pos, arg);
       return SqlStdOperatorTable.DIVIDE.createCall(
@@ -1147,54 +1291,66 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
       //     / (count(x) - 1)
       final SqlParserPos pos = SqlParserPos.ZERO;
 
-      final RexNode argRex = cx.convertExpression(argInput);
-      final SqlNode arg;
-      if (!argRex.getType().equals(varType)) {
-        arg = SqlStdOperatorTable.CAST.createCall(pos,
-            new SqlDataTypeSpec(new SqlIdentifier(varType.getSqlTypeName().getName(), pos),
-                varType.getPrecision(), varType.getScale(), null, null, pos));
-      } else {
-        arg = argInput;
-      }
+      final SqlNode arg = getCastedSqlNode(argInput, varType, pos, cx.convertExpression(argInput));
 
-      final SqlNode argSquared =
-          SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
-      final SqlNode sumArgSquared =
-          SqlStdOperatorTable.SUM.createCall(pos, argSquared);
-      final SqlNode sum =
-          SqlStdOperatorTable.SUM.createCall(pos, arg);
+      final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
+      final SqlNode argSquaredCasted =
+          getCastedSqlNode(argSquared, varType, pos, cx.convertExpression(argSquared));
+      final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
+      final SqlNode sumArgSquaredCasted =
+          getCastedSqlNode(sumArgSquared, varType, pos, cx.convertExpression(sumArgSquared));
+      final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
+      final SqlNode sumCasted = getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
       final SqlNode sumSquared =
-          SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
-      final SqlNode count =
-          SqlStdOperatorTable.COUNT.createCall(pos, arg);
+          SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
+      final SqlNode sumSquaredCasted =
+          getCastedSqlNode(sumSquared, varType, pos, cx.convertExpression(sumSquared));
+      final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
+      final SqlNode countCasted =
+          getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
       final SqlNode avgSumSquared =
-          SqlStdOperatorTable.DIVIDE.createCall(
-              pos, sumSquared, count);
+          SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted, countCasted);
+      final SqlNode avgSumSquaredCasted =
+          getCastedSqlNode(avgSumSquared, varType, pos, cx.convertExpression(avgSumSquared));
       final SqlNode diff =
-          SqlStdOperatorTable.MINUS.createCall(
-              pos, sumArgSquared, avgSumSquared);
+          SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted, avgSumSquaredCasted);
+      final SqlNode diffCasted =
+          getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
       final SqlNode denominator;
       if (biased) {
-        denominator = count;
+        denominator = countCasted;
       } else {
-        final SqlNumericLiteral one =
-            SqlLiteral.createExactNumeric("1", pos);
-        denominator =
-            SqlStdOperatorTable.MINUS.createCall(
-                pos, count, one);
+        final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
+        final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+        denominator = new SqlCase(SqlParserPos.ZERO,
+            count,
+            SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
+            SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
+            SqlStdOperatorTable.MINUS.createCall(pos, count, one));
       }
       final SqlNode div =
-          SqlStdOperatorTable.DIVIDE.createCall(
-              pos, diff, denominator);
+          SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
+      final SqlNode divCasted = getCastedSqlNode(div, varType, pos, cx.convertExpression(div));
+
       SqlNode result = div;
       if (sqrt) {
-        final SqlNumericLiteral half =
-            SqlLiteral.createExactNumeric("0.5", pos);
-        result =
-            SqlStdOperatorTable.POWER.createCall(pos, div, half);
+        final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
+        result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half);
       }
       return result;
     }
+
+    private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
+        SqlParserPos pos, RexNode argRex) { // also defined in RegrCovarianceConvertlet
+      SqlNode arg; // argInput, or CAST(argInput AS varType) when the types differ
+      if (argRex != null && !argRex.getType().equals(varType)) { // null argRex: no cast
+        arg = SqlStdOperatorTable.CAST.createCall(
+            pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
+      } else {
+        arg = argInput;
+      }
+      return arg;
+    }
   }
 
   /** Convertlet that converts {@code LTRIM} and {@code RTRIM} to

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
index 0742731..634c5ff 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
@@ -183,6 +183,7 @@ public class SqlAdvisorTest extends SqlValidatorTestCase {
           "KEYWORD(POWER)",
           "KEYWORD(PREV)",
           "KEYWORD(RANK)",
+          "KEYWORD(REGR_COUNT)",
           "KEYWORD(REGR_SXX)",
           "KEYWORD(REGR_SYY)",
           "KEYWORD(ROW)",

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/agg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 6c26a89..997ac94 100755
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2284,4 +2284,101 @@ EnumerableCalc(expr#0..1=[{inputs}], ANYEMPNO=[$t1])
     EnumerableTableScan(table=[[scott, EMP]])
 !plan
 
+# [CALCITE-1776, CALCITE-2402] REGR_COUNT
+SELECT regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)",
+   regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)"
+from "scott".emp;
++-----------------------+------------------------+
+| REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) |
++-----------------------+------------------------+
+|                     4 |                     14 |
++-----------------------+------------------------+
+(1 row)
+
+!ok
+
+EnumerableAggregate(group=[{}], REGR_COUNT(COMM, SAL)=[REGR_COUNT($6, $5)], REGR_COUNT(EMPNO, SAL)=[REGR_COUNT($5)])
+  EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# [CALCITE-1776, CALCITE-2402] REGR_SXX, REGR_SYY
+SELECT
+  regr_sxx(COMM, SAL) as "REGR_SXX(COMM, SAL)",
+  regr_syy(COMM, SAL) as "REGR_SYY(COMM, SAL)",
+  regr_sxx(SAL, COMM) as "REGR_SXX(SAL, COMM)",
+  regr_syy(SAL, COMM) as "REGR_SYY(SAL, COMM)"
+from "scott".emp;
++---------------------+---------------------+---------------------+---------------------+
+| REGR_SXX(COMM, SAL) | REGR_SYY(COMM, SAL) | REGR_SXX(SAL, COMM) | REGR_SYY(SAL, COMM) |
++---------------------+---------------------+---------------------+---------------------+
+|          95000.0000 |        1090000.0000 |        1090000.0000 |          95000.0000 |
++---------------------+---------------------+---------------------+---------------------+
+(1 row)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] COVAR_POP, COVAR_SAMP, VAR_SAMP, VAR_POP
+SELECT
+  covar_pop(COMM, COMM) as "COVAR_POP(COMM, COMM)",
+  covar_samp(SAL, SAL) as "COVAR_SAMP(SAL, SAL)",
+  var_pop(COMM) as "VAR_POP(COMM)",
+  var_samp(SAL) as "VAR_SAMP(SAL)"
+from "scott".emp;
++-----------------------+----------------------+---------------+-------------------+
+| COVAR_POP(COMM, COMM) | COVAR_SAMP(SAL, SAL) | VAR_POP(COMM) | VAR_SAMP(SAL)     |
++-----------------------+----------------------+---------------+-------------------+
+|           272500.0000 |    1398313.873626374 |   272500.0000 | 1398313.873626374 |
++-----------------------+----------------------+---------------+-------------------+
+(1 row)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] REGR_COUNT with group by
+SELECT SAL, regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)",
+   regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)"
+from "scott".emp group by SAL;
++---------+-----------------------+------------------------+
+| SAL     | REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) |
++---------+-----------------------+------------------------+
+| 1100.00 |                     0 |                      1 |
+| 1250.00 |                     2 |                      2 |
+| 1300.00 |                     0 |                      1 |
+| 1500.00 |                     1 |                      1 |
+| 1600.00 |                     1 |                      1 |
+| 2450.00 |                     0 |                      1 |
+| 2850.00 |                     0 |                      1 |
+| 2975.00 |                     0 |                      1 |
+| 3000.00 |                     0 |                      2 |
+| 5000.00 |                     0 |                      1 |
+|  800.00 |                     0 |                      1 |
+|  950.00 |                     0 |                      1 |
++---------+-----------------------+------------------------+
+(12 rows)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] COVAR_POP, COVAR_SAMP, VAR_SAMP, VAR_POP with group by
+SELECT
+  MONTH(HIREDATE) as "MONTH",
+  covar_samp(SAL, COMM) as "COVAR_SAMP(SAL, COMM)",
+  var_pop(COMM) as "VAR_POP(COMM)",
+  var_samp(SAL) as "VAR_SAMP(SAL)"
+from "scott".emp
+group by MONTH(HIREDATE);
++-------+-----------------------+---------------+-------------------+
+| MONTH | COVAR_SAMP(SAL, COMM) | VAR_POP(COMM) | VAR_SAMP(SAL)     |
++-------+-----------------------+---------------+-------------------+
+|     1 |                       |               |      1201250.0000 |
+|    11 |                       |               |                   |
+|    12 |                       |               | 1510833.333333334 |
+|     2 |           -35000.0000 |    10000.0000 |  831458.333333335 |
+|     4 |                       |               |                   |
+|     5 |                       |               |                   |
+|     6 |                       |               |                   |
+|     9 |          -175000.0000 |   490000.0000 |        31250.0000 |
++-------+-----------------------+---------------+-------------------+
+(8 rows)
+
+!ok
+
 # End agg.iq

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/winagg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq
index eac5822..fbd0dde 100644
--- a/core/src/test/resources/sql/winagg.iq
+++ b/core/src/test/resources/sql/winagg.iq
@@ -455,4 +455,137 @@ from emp order by emp."ENAME";
 
 !ok
 
+# [CALCITE-2402] COVAR_POP, REGR_COUNT functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO",
+ sum(emps."AGE" * emps."DEPTNO") over() as "sum(age * deptno)",
+ regr_count(emps."AGE", emps."DEPTNO") over() as "regr_count(age, deptno)",
+ covar_pop(emps."DEPTNO", emps."AGE") over() as "covar_pop"
+from emps order by emps."AGE";
++-----+--------+-------------------+-------------------------+-----------+
+| AGE | DEPTNO | sum(age * deptno) | regr_count(age, deptno) | covar_pop |
++-----+--------+-------------------+-------------------------+-----------+
+|   5 |     20 |              1950 |                       3 |        39 |
+|  25 |     10 |              1950 |                       3 |        39 |
+|  80 |     20 |              1950 |                       3 |        39 |
+|     |     40 |              1950 |                       3 |        39 |
+|     |     40 |              1950 |                       3 |        39 |
++-----+--------+-------------------+-------------------------+-----------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] COVAR_POP, REGR_COUNT functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ sum(emps."AGE" * emps."DEPTNO") over(partition by emps."GENDER") as "sum(age * deptno)",
+ regr_count(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_count(age, deptno)",
+ covar_pop(emps."DEPTNO", emps."AGE") over(partition by emps."GENDER") as "covar_pop"
+from emps order by emps."GENDER";
++-----+--------+--------+-------------------+-------------------------+-----------+
+| AGE | DEPTNO | GENDER | sum(age * deptno) | regr_count(age, deptno) | covar_pop |
++-----+--------+--------+-------------------+-------------------------+-----------+
+|   5 |     20 | F      |               100 |                       1 |         0 |
+|     |     40 | F      |               100 |                       1 |         0 |
+|  80 |     20 | M      |              1600 |                       1 |         0 |
+|     |     40 | M      |              1600 |                       1 |         0 |
+|  25 |     10 |        |               250 |                       1 |         0 |
++-----+--------+--------+-------------------+-------------------------+-----------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] COVAR_SAMP functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_SAMP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / (REGR_COUNT(x, y) - 1)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ covar_samp(emps."AGE", emps."AGE") over() as "var_samp",
+ covar_samp(emps."DEPTNO", emps."AGE") over() as "covar_samp",
+ covar_samp(emps."EMPNO", emps."DEPTNO") over(partition by emps."MANAGER") as "covar_samp partitioned"
+from emps order by emps."AGE";
++-----+--------+--------+----------+------------+------------------------+
+| AGE | DEPTNO | GENDER | var_samp | covar_samp | covar_samp partitioned |
++-----+--------+--------+----------+------------+------------------------+
+|   5 |     20 | F      |     1508 |         58 |                      0 |
+|  25 |     10 |        |     1508 |         58 |                     50 |
+|  80 |     20 | M      |     1508 |         58 |                     50 |
+|     |     40 | M      |     1508 |         58 |                      0 |
+|     |     40 | F      |     1508 |         58 |                      0 |
++-----+--------+--------+----------+------------+------------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] VAR_POP, VAR_SAMP functions
+# VAR_POP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
+# VAR_SAMP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / (COUNT(x) - 1)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ var_pop(emps."AGE") over() as "var_pop",
+ var_pop(emps."AGE") over(partition by emps."AGE") as "var_pop by age",
+ var_samp(emps."AGE") over() as "var_samp",
+ var_samp(emps."AGE") over(partition by emps."GENDER") as "var_samp by gender"
+from emps order by emps."AGE";
++-----+--------+--------+---------+----------------+----------+--------------------+
+| AGE | DEPTNO | GENDER | var_pop | var_pop by age | var_samp | var_samp by gender |
++-----+--------+--------+---------+----------------+----------+--------------------+
+|   5 |     20 | F      |    1005 |              0 |     1508 |                    |
+|  25 |     10 |        |    1005 |              0 |     1508 |                    |
+|  80 |     20 | M      |    1005 |              0 |     1508 |                    |
+|     |     40 | F      |    1005 |                |     1508 |                    |
+|     |     40 | M      |    1005 |                |     1508 |                    |
++-----+--------+--------+---------+----------------+----------+--------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] REGR_SXX, REGR_SYY functions (REGR_SXY formula for reference)
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# REGR_SXX(x, y) = REGR_COUNT(x, y) * VAR_POP(y, y)
+# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y)
+# REGR_SYY(x, y) = REGR_COUNT(x, y) * VAR_POP(x, x)
+## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+## VAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO",
+ regr_sxx(emps."AGE", emps."DEPTNO") over() as "regr_sxx(age, deptno)",
+ regr_syy(emps."AGE", emps."DEPTNO") over() as "regr_syy(age, deptno)"
+from emps order by emps."AGE";
++-----+--------+-----------------------+-----------------------+
+| AGE | DEPTNO | regr_sxx(age, deptno) | regr_syy(age, deptno) |
++-----+--------+-----------------------+-----------------------+
+|   5 |     20 |                    66 |                  3015 |
+|  25 |     10 |                    66 |                  3015 |
+|  80 |     20 |                    66 |                  3015 |
+|     |     40 |                    66 |                  3015 |
+|     |     40 |                    66 |                  3015 |
++-----+--------+-----------------------+-----------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] REGR_SXX, REGR_SYY functions (REGR_SXY formula for reference)
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# REGR_SXX(x, y) = REGR_COUNT(x, y) * COVAR_POP(y, y)
+# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y)
+# REGR_SYY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, x)
+## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+## COVAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ regr_sxx(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_sxx(age, deptno)",
+ regr_syy(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_syy(age, deptno)"
+from emps order by emps."GENDER";
++-----+--------+--------+-----------------------+-----------------------+
+| AGE | DEPTNO | GENDER | regr_sxx(age, deptno) | regr_syy(age, deptno) |
++-----+--------+--------+-----------------------+-----------------------+
+|   5 |     20 | F      |                     0 |                     0 |
+|     |     40 | F      |                     0 |                     0 |
+|  80 |     20 | M      |                     0 |                     0 |
+|     |     40 | M      |                     0 |                     0 |
+|  25 |     10 |        |                     0 |                     0 |
++-----+--------+--------+-----------------------+-----------------------+
+(5 rows)
+
+!ok
+
 # End winagg.iq

http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/site/_docs/reference.md
----------------------------------------------------------------------
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index acf8f29..e76e7d6 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -1510,6 +1510,7 @@ passed to the aggregate function.
 | VAR_SAMP( [ ALL &#124; DISTINCT ] numeric)    | Returns the sample variance (square of the sample standard deviation) of *numeric* across all input values
 | COVAR_POP(numeric1, numeric2)      | Returns the population covariance of the pair (*numeric1*, *numeric2*) across all input values
 | COVAR_SAMP(numeric1, numeric2)     | Returns the sample covariance of the pair (*numeric1*, *numeric2*) across all input values
+| REGR_COUNT(numeric1, numeric2)     | Returns the number of rows where both dependent and independent expressions are not null
 | REGR_SXX(numeric1, numeric2)       | Returns the sum of squares of the independent expression (*numeric2*) in a linear regression model
 | REGR_SYY(numeric1, numeric2)       | Returns the sum of squares of the dependent expression (*numeric1*) in a linear regression model
 
@@ -1517,7 +1518,6 @@ Not implemented:
 
 * REGR_AVGX(numeric1, numeric2)
 * REGR_AVGY(numeric1, numeric2)
-* REGR_COUNT(numeric1, numeric2)
 * REGR_INTERCEPT(numeric1, numeric2)
 * REGR_R2(numeric1, numeric2)
 * REGR_SLOPE(numeric1, numeric2)