You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by jh...@apache.org on 2018/08/13 01:06:04 UTC
[4/4] calcite git commit: [CALCITE-2402] Implement regr functions:
COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY
[CALCITE-2402] Implement regr functions: COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY
Use filters in case of AggregateReduceFunctionsRule expansions.
Close apache/calcite#779
Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/ca858dd7
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/ca858dd7
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/ca858dd7
Branch: refs/heads/master
Commit: ca858dd725dea6bf9b4a9059cf1c3ba98bd82f26
Parents: 5574873
Author: snuyanzin <sn...@gmail.com>
Authored: Fri Jun 29 11:41:24 2018 +0300
Committer: Julian Hyde <jh...@apache.org>
Committed: Sun Aug 12 18:04:44 2018 -0700
----------------------------------------------------------------------
core/src/main/codegen/templates/Parser.jj | 1 +
.../calcite/adapter/enumerable/RexImpTable.java | 3 +
.../rel/rules/AggregateReduceFunctionsRule.java | 259 ++++++++++++++++++-
.../java/org/apache/calcite/sql/SqlKind.java | 23 +-
.../calcite/sql/fun/SqlCountAggFunction.java | 10 +-
.../calcite/sql/fun/SqlCovarAggFunction.java | 8 +-
.../sql/fun/SqlRegrCountAggFunction.java | 37 +++
.../calcite/sql/fun/SqlStdOperatorTable.java | 6 +
.../apache/calcite/sql/type/ReturnTypes.java | 2 +-
.../sql2rel/StandardConvertletTable.java | 250 ++++++++++++++----
.../apache/calcite/sql/test/SqlAdvisorTest.java | 1 +
core/src/test/resources/sql/agg.iq | 97 +++++++
core/src/test/resources/sql/winagg.iq | 133 ++++++++++
site/_docs/reference.md | 2 +-
14 files changed, 761 insertions(+), 71 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/codegen/templates/Parser.jj
----------------------------------------------------------------------
diff --git a/core/src/main/codegen/templates/Parser.jj b/core/src/main/codegen/templates/Parser.jj
index be05d9c..0dc23eb 100644
--- a/core/src/main/codegen/templates/Parser.jj
+++ b/core/src/main/codegen/templates/Parser.jj
@@ -5171,6 +5171,7 @@ SqlIdentifier ReservedFunctionName() :
| <PERCENT_RANK>
| <POWER>
| <RANK>
+ | <REGR_COUNT>
| <REGR_SXX>
| <REGR_SYY>
| <ROW_NUMBER>
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
index 5ba5959..80d5541 100644
--- a/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
+++ b/core/src/main/java/org/apache/calcite/adapter/enumerable/RexImpTable.java
@@ -191,6 +191,7 @@ import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RADIANS;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RAND_INTEGER;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.RANK;
+import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REGR_COUNT;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REINTERPRET;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.REPLACE;
import static org.apache.calcite.sql.fun.SqlStdOperatorTable.ROUND;
@@ -438,6 +439,7 @@ public class RexImpTable {
map.put(LOCALTIMESTAMP, systemFunctionImplementor);
aggMap.put(COUNT, constructorSupplier(CountImplementor.class));
+ aggMap.put(REGR_COUNT, constructorSupplier(CountImplementor.class));
aggMap.put(SUM0, constructorSupplier(SumImplementor.class));
aggMap.put(SUM, constructorSupplier(SumImplementor.class));
Supplier<MinMaxImplementor> minMax =
@@ -464,6 +466,7 @@ public class RexImpTable {
winAggMap.put(LAG, constructorSupplier(LagImplementor.class));
winAggMap.put(NTILE, constructorSupplier(NtileImplementor.class));
winAggMap.put(COUNT, constructorSupplier(CountWinImplementor.class));
+ winAggMap.put(REGR_COUNT, constructorSupplier(CountWinImplementor.class));
}
private <T> Supplier<T> constructorSupplier(Class<T> klass) {
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 8bdd6c1..68f6b16 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -72,6 +72,17 @@ import java.util.Map;
*
* <li>VAR_SAMP(x) → (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x))
* / CASE COUNT(x) WHEN 1 THEN NULL ELSE COUNT(x) - 1 END
+ *
+ * <li>COVAR_POP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x)
+ * / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+ *
+ * <li>COVAR_SAMP(x, y) → (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y))
+ * / CASE REGR_COUNT(x, y) WHEN 1 THEN NULL ELSE REGR_COUNT(x, y) - 1 END
+ *
+ * <li>REGR_SXX(x, y) → REGR_COUNT(x, y) * VAR_POP(y)
+ *
+ * <li>REGR_SYY(x, y) → REGR_COUNT(x, y) * VAR_POP(x)
+ *
* </ul>
*
* <p>Since many of these rewrites introduce multiple occurrences of simpler
@@ -127,7 +138,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
* Returns whether the aggregate call is a reducible function
*/
private boolean isReducible(final SqlKind kind) {
- if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) {
+ if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)
+ || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind)) {
return true;
}
switch (kind) {
@@ -201,6 +213,8 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
List<RexNode> inputExprs) {
final SqlKind kind = oldCall.getAggregation().getKind();
if (isReducible(kind)) {
+ final Integer y;
+ final Integer x;
switch (kind) {
case SUM:
// replace original SUM(x) with
@@ -209,6 +223,37 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
case AVG:
// replace original AVG(x) with SUM(x) / COUNT(x)
return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
+ case COVAR_POP:
+ // replace original COVAR_POP(x, y) with
+ // (SUM(x * y) - SUM(x) * SUM(y) / COUNT(x))
+ // / COUNT(x)
+ return reduceCovariance(oldAggRel, oldCall, true, newCalls,
+ aggCallMapping, inputExprs);
+ case COVAR_SAMP:
+ // replace original COVAR_SAMP(x, y) with
+ // (SUM(x * y) - SUM(x) * SUM(y) / COUNT(x))
+ // / CASE COUNT(x) WHEN 1 THEN NULL
+ // ELSE COUNT(x) - 1 END
+ return reduceCovariance(oldAggRel, oldCall, false, newCalls,
+ aggCallMapping, inputExprs);
+ case REGR_SXX:
+ // replace original REGR_SXX(x, y) with
+ // REGR_COUNT(x, y) * VAR_POP(y)
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ x = oldCall.getArgList().get(0);
+ y = oldCall.getArgList().get(1);
+ //noinspection SuspiciousNameCombination
+ return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping,
+ inputExprs, y, y, x);
+ case REGR_SYY:
+ // replace original REGR_SYY(x, y) with
+ // REGR_COUNT(x, y) * VAR_POP(x)
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ x = oldCall.getArgList().get(0);
+ y = oldCall.getArgList().get(1);
+ //noinspection SuspiciousNameCombination
+ return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping,
+ inputExprs, x, x, y);
case STDDEV_POP:
// replace original STDDEV_POP(x) with
// SQRT(
@@ -260,16 +305,17 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
RelDataType operandType,
Aggregate oldAggRel,
AggregateCall oldCall,
- int argOrdinal) {
+ int argOrdinal,
+ int filter) {
final Aggregate.AggCallBinding binding =
new Aggregate.AggCallBinding(typeFactory, aggFunction,
ImmutableList.of(operandType), oldAggRel.getGroupCount(),
- oldCall.filterArg >= 0);
+ filter >= 0);
return AggregateCall.create(aggFunction,
oldCall.isDistinct(),
oldCall.isApproximate(),
ImmutableIntList.of(argOrdinal),
- oldCall.filterArg,
+ filter,
aggFunction.inferReturnType(binding),
null);
}
@@ -346,6 +392,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
getFieldType(
oldAggRel.getInput(),
arg);
+
final AggregateCall sumZeroCall =
AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(),
oldCall.isApproximate(), oldCall.getArgList(), oldCall.filterArg,
@@ -424,7 +471,6 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
final RexNode argRef =
rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
- final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
argRef, argRef);
@@ -432,7 +478,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
final AggregateCall sumArgSquaredAggCall =
createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM,
- argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
+ argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal, -1);
final RexNode sumArgSquared =
rexBuilder.addAggCall(sumArgSquaredAggCall,
@@ -530,6 +576,207 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
oldCall.getType(), result);
}
+ private RexNode getSumAggregatedRexNode(Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ RexBuilder rexBuilder,
+ int argOrdinal,
+ int filterArg) {
+ final AggregateCall aggregateCall =
+ AggregateCall.create(SqlStdOperatorTable.SUM,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ ImmutableIntList.of(argOrdinal),
+ filterArg,
+ oldAggRel.getGroupCount(),
+ oldAggRel.getInput(),
+ null,
+ null);
+ return rexBuilder.addAggCall(aggregateCall,
+ oldAggRel.getGroupCount(),
+ oldAggRel.indicator,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(aggregateCall.getType()));
+ }
+
+ private RexNode getSumAggregatedRexNodeWithBinding(Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ RelDataType operandType,
+ int argOrdinal,
+ int filter) {
+ RelOptCluster cluster = oldAggRel.getCluster();
+ final AggregateCall sumArgSquaredAggCall =
+ createAggregateCallWithBinding(cluster.getTypeFactory(),
+ SqlStdOperatorTable.SUM, operandType, oldAggRel, oldCall, argOrdinal, filter);
+
+ return cluster.getRexBuilder().addAggCall(sumArgSquaredAggCall,
+ oldAggRel.getGroupCount(),
+ oldAggRel.indicator,
+ newCalls,
+ aggCallMapping,
+ ImmutableList.of(sumArgSquaredAggCall.getType()));
+ }
+
+ private RexNode getRegrCountRexNode(Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ ImmutableIntList argOrdinals,
+ ImmutableList<RelDataType> operandTypes,
+ int filterArg) {
+ final AggregateCall countArgAggCall =
+ AggregateCall.create(SqlStdOperatorTable.REGR_COUNT,
+ oldCall.isDistinct(),
+ oldCall.isApproximate(),
+ argOrdinals,
+ filterArg,
+ oldAggRel.getGroupCount(),
+ oldAggRel,
+ null,
+ null);
+
+ return oldAggRel.getCluster().getRexBuilder().addAggCall(countArgAggCall,
+ oldAggRel.getGroupCount(),
+ oldAggRel.indicator,
+ newCalls,
+ aggCallMapping,
+ operandTypes);
+ }
+
+ private RexNode reduceRegrSzz(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ List<RexNode> inputExprs,
+ int xIndex,
+ int yIndex,
+ int nullFilterIndex) {
+ // regr_sxx(x, y) ==>
+ // sum(y * y, x) - sum(y, x) * sum(y, x) / regr_count(x, y)
+ //
+
+ final RelOptCluster cluster = oldAggRel.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ final RelDataType argXType = getFieldType(oldAggRel.getInput(), xIndex);
+ final RelDataType argYType =
+ xIndex == yIndex ? argXType : getFieldType(oldAggRel.getInput(), yIndex);
+ final RelDataType nullFilterIndexType =
+ nullFilterIndex == yIndex ? argYType : getFieldType(oldAggRel.getInput(), yIndex);
+
+ final RelDataType oldCallType =
+ typeFactory.createTypeWithNullability(oldCall.getType(),
+ argXType.isNullable() || argYType.isNullable() || nullFilterIndexType.isNullable());
+
+ final RexNode argX =
+ rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true);
+ final RexNode argY =
+ rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true);
+ final RexNode argNullFilter =
+ rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true);
+
+ final RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+ final int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY);
+
+ final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter));
+ final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+ final RexNode sumXY = getSumAggregatedRexNodeWithBinding(
+ oldAggRel, oldCall, newCalls, aggCallMapping, argXArgY.getType(),
+ argSquaredOrdinal, argXAndYNotNullFilterOrdinal);
+ final RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true);
+
+ final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall,
+ newCalls, aggCallMapping, rexBuilder, xIndex, argXAndYNotNullFilterOrdinal);
+ final RexNode sumY = xIndex == yIndex
+ ? sumX
+ : getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+ aggCallMapping, rexBuilder, yIndex, argXAndYNotNullFilterOrdinal);
+
+ final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+
+ final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping,
+ ImmutableIntList.of(xIndex), ImmutableList.of(argXType), argXAndYNotNullFilterOrdinal);
+
+ RexLiteral zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
+ RexNode nul = rexBuilder.constantNull();
+ final RexNode avgSumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
+ rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero), nul,
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg));
+ final RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true);
+ final RexNode result =
+ rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast);
+ return rexBuilder.makeCast(oldCall.getType(), result);
+ }
+
+ private RexNode reduceCovariance(
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ boolean biased,
+ List<AggregateCall> newCalls,
+ Map<AggregateCall, RexNode> aggCallMapping,
+ List<RexNode> inputExprs) {
+ // covar_pop(x, y) ==>
+ // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y))
+ // / regr_count(x, y)
+ //
+ // covar_samp(x, y) ==>
+ // (sum(x * y) - sum(x) * sum(y) / regr_count(x, y))
+ // / CASE regr_count(x, y) WHEN 1 THEN NULL ELSE regr_count(x, y) - 1 END
+ final RelOptCluster cluster = oldAggRel.getCluster();
+ final RexBuilder rexBuilder = cluster.getRexBuilder();
+ final RelDataTypeFactory typeFactory = cluster.getTypeFactory();
+ assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
+ final int argXOrdinal = oldCall.getArgList().get(0);
+ final int argYOrdinal = oldCall.getArgList().get(1);
+ final RelDataType argXOrdinalType = getFieldType(oldAggRel.getInput(), argXOrdinal);
+ final RelDataType argYOrdinalType = getFieldType(oldAggRel.getInput(), argYOrdinal);
+ final RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(),
+ argXOrdinalType.isNullable() || argYOrdinalType.isNullable());
+ final RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true);
+ final RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true);
+ final RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
+ rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY));
+ final int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);
+ final RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
+ final int argXYOrdinal = lookupOrAdd(inputExprs, argXY);
+ final RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls,
+ aggCallMapping, argXY.getType(), argXYOrdinal, argXAndYNotNullFilterOrdinal);
+ final RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+ aggCallMapping, rexBuilder, argXOrdinal, argXAndYNotNullFilterOrdinal);
+ final RexNode sumY = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls,
+ aggCallMapping, rexBuilder, argYOrdinal, argXAndYNotNullFilterOrdinal);
+ final RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);
+ final RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping,
+ ImmutableIntList.of(argXOrdinal, argYOrdinal),
+ ImmutableList.of(argXOrdinalType, argYOrdinalType),
+ argXAndYNotNullFilterOrdinal);
+ final RexNode avgSumSquaredArg =
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg);
+ final RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg);
+ final RexNode denominator;
+ if (biased) {
+ denominator = countArg;
+ } else {
+ final RexLiteral one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
+ final RexNode nul = rexBuilder.makeCast(countArg.getType(), rexBuilder.constantNull());
+ final RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
+ final RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
+ denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
+ }
+ final RexNode result = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
+ return rexBuilder.makeCast(oldCall.getType(), result);
+ }
+
/**
* Finds the ordinal of an element in a list, or adds it.
*
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/SqlKind.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index 54b390b..cbf201f 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -810,14 +810,11 @@ public enum SqlKind {
/** The {@code GROUP_ID()} function. */
GROUP_ID,
- /**
- * the internal permute function in match_recognize cluse
- */
+ /** The internal "permute" function in a MATCH_RECOGNIZE clause. */
PATTERN_PERMUTE,
- /**
- * the special patterns to exclude enclosing pattern from output in match_recognize clause
- */
+ /** The special patterns to exclude enclosing pattern from output in a
+ * MATCH_RECOGNIZE clause. */
PATTERN_EXCLUDED,
// Aggregate functions
@@ -858,6 +855,9 @@ public enum SqlKind {
/** The {@code COVAR_SAMP} aggregate function. */
COVAR_SAMP,
+ /** The {@code REGR_COUNT} aggregate function. */
+ REGR_COUNT,
+
/** The {@code REGR_SXX} aggregate function. */
REGR_SXX,
@@ -1064,7 +1064,7 @@ public enum SqlKind {
*/
public static final EnumSet<SqlKind> AGGREGATE =
EnumSet.of(COUNT, SUM, SUM0, MIN, MAX, LEAD, LAG, FIRST_VALUE,
- LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_SXX, REGR_SYY,
+ LAST_VALUE, COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY,
AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP, NTILE, COLLECT,
FUSION, SINGLE_VALUE, ROW_NUMBER, RANK, PERCENT_RANK, DENSE_RANK,
CUME_DIST);
@@ -1180,6 +1180,15 @@ public enum SqlKind {
EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP);
/**
+ * Category of SqlCovarAggFunction.
+ *
+ * <p>Consists of {@link #COVAR_POP}, {@link #COVAR_SAMP},
+ * {@link #REGR_COUNT}, {@link #REGR_SXX}, {@link #REGR_SYY}.
+ */
+ public static final Set<SqlKind> COVAR_AVG_AGG_FUNCTIONS =
+ EnumSet.of(COVAR_POP, COVAR_SAMP, REGR_COUNT, REGR_SXX, REGR_SYY);
+
+ /**
* Category of comparison operators.
*
* <p>Consists of:
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
index e053294..db54102 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCountAggFunction.java
@@ -26,6 +26,7 @@ import org.apache.calcite.sql.SqlSplittableAggFunction;
import org.apache.calcite.sql.SqlSyntax;
import org.apache.calcite.sql.type.OperandTypes;
import org.apache.calcite.sql.type.ReturnTypes;
+import org.apache.calcite.sql.type.SqlOperandTypeChecker;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.validate.SqlValidator;
import org.apache.calcite.sql.validate.SqlValidatorScope;
@@ -45,11 +46,12 @@ public class SqlCountAggFunction extends SqlAggFunction {
//~ Constructors -----------------------------------------------------------
public SqlCountAggFunction(String name) {
+ this(name, SqlValidator.STRICT ? OperandTypes.ANY : OperandTypes.ONE_OR_MORE);
+ }
+
+ public SqlCountAggFunction(String name, SqlOperandTypeChecker sqlOperandTypeChecker) {
super(name, null, SqlKind.COUNT, ReturnTypes.BIGINT, null,
- SqlValidator.STRICT
- ? OperandTypes.ANY
- : OperandTypes.ONE_OR_MORE,
- SqlFunctionCategory.NUMERIC, false, false);
+ sqlOperandTypeChecker, SqlFunctionCategory.NUMERIC, false, false);
}
//~ Methods ----------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
index 8c62290..8591959 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
@@ -43,16 +43,14 @@ public class SqlCovarAggFunction extends SqlAggFunction {
super(kind.name(),
null,
kind,
- ReturnTypes.COVAR_FUNCTION,
+ kind == SqlKind.REGR_COUNT ? ReturnTypes.BIGINT : ReturnTypes.COVAR_REGR_FUNCTION,
null,
OperandTypes.NUMERIC_NUMERIC,
SqlFunctionCategory.NUMERIC,
false,
false);
- Preconditions.checkArgument(kind == SqlKind.COVAR_POP
- || kind == SqlKind.COVAR_SAMP
- || kind == SqlKind.REGR_SXX
- || kind == SqlKind.REGR_SYY);
+ Preconditions.checkArgument(SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(kind),
+ "unsupported sql kind: " + kind);
}
@Deprecated // to be removed before 2.0
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
new file mode 100644
index 0000000..4408272
--- /dev/null
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlRegrCountAggFunction.java
@@ -0,0 +1,37 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to you under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.calcite.sql.fun;
+
+import org.apache.calcite.sql.SqlKind;
+import org.apache.calcite.sql.type.OperandTypes;
+
+import com.google.common.base.Preconditions;
+
+/**
+ * Definition of the SQL <code>REGR_COUNT</code> aggregation function.
+ *
+ * <p><code>REGR_COUNT</code> is an aggregator which returns the number of rows
+ * for which both arguments are not <code>null</code>.
+ */
+public class SqlRegrCountAggFunction extends SqlCountAggFunction {
+ public SqlRegrCountAggFunction(SqlKind kind) {
+ super("REGR_COUNT", OperandTypes.NUMERIC_NUMERIC);
+ Preconditions.checkArgument(SqlKind.REGR_COUNT == kind, "unsupported sql kind: " + kind);
+ }
+}
+
+// End SqlRegrCountAggFunction.java
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 0064cba..ea17ec8 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -915,6 +915,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
new SqlAvgAggFunction(SqlKind.STDDEV_POP);
/**
+ * <code>REGR_COUNT</code> aggregate function.
+ */
+ public static final SqlAggFunction REGR_COUNT =
+ new SqlRegrCountAggFunction(SqlKind.REGR_COUNT);
+
+ /**
* <code>REGR_SXX</code> aggregate function.
*/
public static final SqlAggFunction REGR_SXX =
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 8b07b83..fc0022d 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -766,7 +766,7 @@ public abstract class ReturnTypes {
}
};
- public static final SqlReturnTypeInference COVAR_FUNCTION = opBinding -> {
+ public static final SqlReturnTypeInference COVAR_REGR_FUNCTION = opBinding -> {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
final RelDataType relDataType =
typeFactory.getTypeSystem().deriveCovarType(typeFactory,
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index e9b7cf6..987e821 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -67,6 +67,7 @@ import org.apache.calcite.sql.type.SqlTypeFamily;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.sql.validate.SqlValidator;
+import org.apache.calcite.sql.validate.SqlValidatorImpl;
import org.apache.calcite.util.Pair;
import org.apache.calcite.util.Util;
@@ -77,6 +78,7 @@ import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.ArrayList;
import java.util.List;
+import java.util.Objects;
/**
* Standard implementation of {@link SqlRexConvertletTable}.
@@ -237,6 +239,14 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
registerOp(SqlStdOperatorTable.VARIANCE,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
+ registerOp(SqlStdOperatorTable.COVAR_POP,
+ new RegrCovarianceConvertlet(SqlKind.COVAR_POP));
+ registerOp(SqlStdOperatorTable.COVAR_SAMP,
+ new RegrCovarianceConvertlet(SqlKind.COVAR_SAMP));
+ registerOp(SqlStdOperatorTable.REGR_SXX,
+ new RegrCovarianceConvertlet(SqlKind.REGR_SXX));
+ registerOp(SqlStdOperatorTable.REGR_SYY,
+ new RegrCovarianceConvertlet(SqlKind.REGR_SYY));
final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
@@ -342,14 +352,26 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
SqlNodeList thenList = call.getThenOperands();
assert whenList.size() == thenList.size();
+ RexBuilder rexBuilder = cx.getRexBuilder();
final List<RexNode> exprList = new ArrayList<>();
for (int i = 0; i < whenList.size(); i++) {
- exprList.add(cx.convertExpression(whenList.get(i)));
- exprList.add(cx.convertExpression(thenList.get(i)));
+ if (SqlUtil.isNullLiteral(whenList.get(i), false)) {
+ exprList.add(rexBuilder.constantNull());
+ } else {
+ exprList.add(cx.convertExpression(whenList.get(i)));
+ }
+ if (SqlUtil.isNullLiteral(thenList.get(i), false)) {
+ exprList.add(rexBuilder.constantNull());
+ } else {
+ exprList.add(cx.convertExpression(thenList.get(i)));
+ }
+ }
+ if (SqlUtil.isNullLiteral(call.getElseOperand(), false)) {
+ exprList.add(rexBuilder.constantNull());
+ } else {
+ exprList.add(cx.convertExpression(call.getElseOperand()));
}
- exprList.add(cx.convertExpression(call.getElseOperand()));
- RexBuilder rexBuilder = cx.getRexBuilder();
RelDataType type =
rexBuilder.deriveReturnType(call.getOperator(), exprList);
for (int i : elseArgs(exprList.size())) {
@@ -473,11 +495,13 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
return castToValidatedType(cx, call, cx.convertExpression(left));
}
SqlDataTypeSpec dataType = (SqlDataTypeSpec) right;
+ RelDataType type = dataType.deriveType(typeFactory);
if (SqlUtil.isNullLiteral(left, false)) {
+ final SqlValidatorImpl validator = (SqlValidatorImpl) cx.getValidator();
+ validator.setValidatedNodeType(left, type);
return cx.convertExpression(left);
}
RexNode arg = cx.convertExpression(left);
- RelDataType type = dataType.deriveType(typeFactory);
if (type == null) {
type = cx.getValidator().getValidatedNodeType(dataType.getTypeName());
}
@@ -1061,6 +1085,133 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
return rexBuilder.makeCast(type, e);
}
+ /** Convertlet that handles the {@code COVAR_POP}, {@code COVAR_SAMP},
+ * {@code REGR_SXX} and {@code REGR_SYY} windowed aggregate functions.
+ *
+ * <p>Each two-operand call is rewritten into a SQL expression built from
+ * {@code SUM}, {@code REGR_COUNT}, {@code CAST}, {@code CASE} and arithmetic
+ * operators; that expression is then converted to a {@link RexNode}.
+ */
+ private static class RegrCovarianceConvertlet implements SqlRexConvertlet {
+ private final SqlKind kind;
+
+ RegrCovarianceConvertlet(SqlKind kind) {
+ this.kind = kind;
+ }
+ /**
+ * Expands {@code call} according to {@code kind}, converts the expanded
+ * expression and passes the result through {@code RexBuilder.ensureType}
+ * together with the validated type of the original call.
+ *
+ * @param cx conversion context; supplies the validator, the expression
+ * converter and the {@code RexBuilder}
+ * @param call aggregate call; asserted to have exactly two operands
+ * @return the converted expansion
+ */
+ public RexNode convertCall(SqlRexContext cx, SqlCall call) {
+ assert call.operandCount() == 2;
+ final SqlNode arg1 = call.operand(0);
+ final SqlNode arg2 = call.operand(1);
+ final SqlNode expr;
+ final RelDataType type =
+ cx.getValidator().getValidatedNodeType(call);
+ switch (kind) {
+ case COVAR_POP:
+ expr = expandCovariance(arg1, arg2, null, type, cx, true);
+ break;
+ case COVAR_SAMP:
+ expr = expandCovariance(arg1, arg2, null, type, cx, false);
+ break;
+ case REGR_SXX:
+ expr = expandRegrSzz(arg2, arg1, type, cx, true); // operands swapped: the expansion is built over the call's second operand
+ break;
+ case REGR_SYY:
+ expr = expandRegrSzz(arg1, arg2, type, cx, true); // the expansion is built over the call's first operand
+ break;
+ default:
+ throw Util.unexpected(kind); // any kind other than the four handled above
+ }
+ RexNode rex = cx.convertExpression(expr);
+ return cx.getRexBuilder().ensureType(type, rex, true);
+ }
+ /**
+ * Builds {@code <covariance term> * REGR_COUNT(arg1, arg2)}. The covariance
+ * term comes from {@link #expandCovariance} with {@code biased = true} and
+ * with {@code arg2} passed as the {@code dependent} node; it is wrapped in a
+ * CAST to {@code avgType} when its converted type differs from that type.
+ *
+ * @param variance if true, both covariance inputs are {@code arg1};
+ * otherwise they are {@code arg1} and {@code arg2}. Both call sites in
+ * {@code convertCall} pass {@code true}.
+ */
+ private SqlNode expandRegrSzz(
+ final SqlNode arg1, final SqlNode arg2,
+ final RelDataType avgType, final SqlRexContext cx, boolean variance) {
+ final SqlParserPos pos = SqlParserPos.ZERO;
+ final SqlNode count =
+ SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg1, arg2);
+ final SqlNode varPop =
+ expandCovariance(arg1, variance ? arg1 : arg2, arg2, avgType, cx, true);
+ final RexNode varPopRex = cx.convertExpression(varPop);
+ final SqlNode varPopCast;
+ varPopCast = getCastedSqlNode(varPop, avgType, pos, varPopRex);
+ return SqlStdOperatorTable.MULTIPLY.createCall(pos, varPopCast, count);
+ }
+ /**
+ * Builds the SQL expression for a covariance of two inputs.
+ *
+ * @param arg0Input first input expression
+ * @param arg1Input second input expression
+ * @param dependent optional extra node, may be null; when non-null it is
+ * supplied as the second operand of the generated {@code SUM} and
+ * {@code REGR_COUNT} calls, except where it equals the corresponding
+ * input, in which case the other (casted) input is used instead
+ * @param varType type the inputs and the count are cast to when their
+ * converted type differs
+ * @param cx conversion context, used to obtain the types of sub-expressions
+ * @param biased if true the divisor is the count; if false the divisor is
+ * a CASE expression over the count (see the formulas below)
+ * @return the division node; it is not converted here
+ */
+ private SqlNode expandCovariance(
+ final SqlNode arg0Input,
+ final SqlNode arg1Input,
+ final SqlNode dependent,
+ final RelDataType varType,
+ final SqlRexContext cx,
+ boolean biased) {
+ // covar_pop(x1, x2) ==>
+ // (sum(x1 * x2) - sum(x1, x2) * sum(x2, x1) / regr_count(x1, x2))
+ // / regr_count(x1, x2)
+ //
+ // covar_samp(x1, x2) ==>
+ // (sum(x1 * x2) - sum(x1, x2) * sum(x2, x1) / regr_count(x1, x2))
+ // / (regr_count(x1, x2) - 1), with a NULL divisor when the count equals 1
+ final SqlParserPos pos = SqlParserPos.ZERO;
+ final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+
+ final RexNode arg0Rex = cx.convertExpression(arg0Input); // converted only to learn the input's type
+ final RexNode arg1Rex = cx.convertExpression(arg1Input);
+
+ final SqlNode arg0 = getCastedSqlNode(arg0Input, varType, pos, arg0Rex);
+ final SqlNode arg1 = getCastedSqlNode(arg1Input, varType, pos, arg1Rex);
+ final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg0, arg1); // product of the two inputs, despite the name
+ final SqlNode sumArgSquared;
+ final SqlNode sum0;
+ final SqlNode sum1;
+ final SqlNode count;
+ if (dependent == null) {
+ sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared);
+ sum0 = SqlStdOperatorTable.SUM.createCall(pos, arg0, arg1); // the winagg.iq test notes describe SUM(x, y) as SUM(x) WHERE y IS NOT NULL
+ sum1 = SqlStdOperatorTable.SUM.createCall(pos, arg1, arg0);
+ count = SqlStdOperatorTable.REGR_COUNT.createCall(pos, arg0, arg1);
+ } else {
+ sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquared, dependent);
+ sum0 = SqlStdOperatorTable.SUM.createCall(
+ pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
+ sum1 = SqlStdOperatorTable.SUM.createCall(
+ pos, arg1, Objects.equals(dependent, arg1Input) ? arg0 : dependent);
+ count = SqlStdOperatorTable.REGR_COUNT.createCall(
+ pos, arg0, Objects.equals(dependent, arg0Input) ? arg1 : dependent);
+ }
+
+ final SqlNode sumSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, sum0, sum1); // product of the two sums
+ final SqlNode countCasted =
+ getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
+
+ final SqlNode avgSumSquared =
+ SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquared, countCasted);
+ final SqlNode diff = SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquared, avgSumSquared);
+ SqlNode denominator;
+ if (biased) {
+ denominator = countCasted;
+ } else {
+ final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
+ denominator = new SqlCase(SqlParserPos.ZERO, countCasted,
+ SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, countCasted, one)),
+ SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)), // argRex is null, so the NULL literal comes back without a CAST
+ SqlStdOperatorTable.MINUS.createCall(pos, countCasted, one));
+ }
+
+ return SqlStdOperatorTable.DIVIDE.createCall(pos, diff, denominator);
+ }
+ /**
+ * Returns {@code argInput} wrapped in a CAST to {@code varType} when
+ * {@code argRex} is non-null and its type is not equal to {@code varType};
+ * otherwise returns {@code argInput} unchanged.
+ *
+ * @param argRex converted form of {@code argInput}, consulted only for its
+ * type; pass null to skip the cast
+ */
+ private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
+ SqlParserPos pos, RexNode argRex) {
+ SqlNode arg;
+ if (argRex != null && !argRex.getType().equals(varType)) {
+ arg = SqlStdOperatorTable.CAST.createCall(
+ pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
+ } else {
+ arg = argInput;
+ }
+ return arg;
+ }
+ }
+
/** Convertlet that handles {@code AVG} and {@code VARIANCE}
* windowed aggregate functions. */
private static class AvgVarianceConvertlet implements SqlRexConvertlet {
@@ -1106,14 +1257,7 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
SqlStdOperatorTable.SUM.createCall(pos, arg);
final RexNode sumRex = cx.convertExpression(sum);
final SqlNode sumCast;
- if (!sumRex.getType().equals(avgType)) {
- sumCast = SqlStdOperatorTable.CAST.createCall(pos,
- new SqlDataTypeSpec(
- new SqlIdentifier(avgType.getSqlTypeName().getName(), pos),
- avgType.getPrecision(), avgType.getScale(), null, null, pos));
- } else {
- sumCast = sum;
- }
+ sumCast = getCastedSqlNode(sum, avgType, pos, sumRex);
final SqlNode count =
SqlStdOperatorTable.COUNT.createCall(pos, arg);
return SqlStdOperatorTable.DIVIDE.createCall(
@@ -1147,54 +1291,66 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
// / (count(x) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
- final RexNode argRex = cx.convertExpression(argInput);
- final SqlNode arg;
- if (!argRex.getType().equals(varType)) {
- arg = SqlStdOperatorTable.CAST.createCall(pos,
- new SqlDataTypeSpec(new SqlIdentifier(varType.getSqlTypeName().getName(), pos),
- varType.getPrecision(), varType.getScale(), null, null, pos));
- } else {
- arg = argInput;
- }
+ final SqlNode arg = getCastedSqlNode(argInput, varType, pos, cx.convertExpression(argInput));
- final SqlNode argSquared =
- SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
- final SqlNode sumArgSquared =
- SqlStdOperatorTable.SUM.createCall(pos, argSquared);
- final SqlNode sum =
- SqlStdOperatorTable.SUM.createCall(pos, arg);
+ final SqlNode argSquared = SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
+ final SqlNode argSquaredCasted =
+ getCastedSqlNode(argSquared, varType, pos, cx.convertExpression(argSquared));
+ final SqlNode sumArgSquared = SqlStdOperatorTable.SUM.createCall(pos, argSquaredCasted);
+ final SqlNode sumArgSquaredCasted =
+ getCastedSqlNode(sumArgSquared, varType, pos, cx.convertExpression(sumArgSquared));
+ final SqlNode sum = SqlStdOperatorTable.SUM.createCall(pos, arg);
+ final SqlNode sumCasted = getCastedSqlNode(sum, varType, pos, cx.convertExpression(sum));
final SqlNode sumSquared =
- SqlStdOperatorTable.MULTIPLY.createCall(pos, sum, sum);
- final SqlNode count =
- SqlStdOperatorTable.COUNT.createCall(pos, arg);
+ SqlStdOperatorTable.MULTIPLY.createCall(pos, sumCasted, sumCasted);
+ final SqlNode sumSquaredCasted =
+ getCastedSqlNode(sumSquared, varType, pos, cx.convertExpression(sumSquared));
+ final SqlNode count = SqlStdOperatorTable.COUNT.createCall(pos, arg);
+ final SqlNode countCasted =
+ getCastedSqlNode(count, varType, pos, cx.convertExpression(count));
final SqlNode avgSumSquared =
- SqlStdOperatorTable.DIVIDE.createCall(
- pos, sumSquared, count);
+ SqlStdOperatorTable.DIVIDE.createCall(pos, sumSquaredCasted, countCasted);
+ final SqlNode avgSumSquaredCasted =
+ getCastedSqlNode(avgSumSquared, varType, pos, cx.convertExpression(avgSumSquared));
final SqlNode diff =
- SqlStdOperatorTable.MINUS.createCall(
- pos, sumArgSquared, avgSumSquared);
+ SqlStdOperatorTable.MINUS.createCall(pos, sumArgSquaredCasted, avgSumSquaredCasted);
+ final SqlNode diffCasted =
+ getCastedSqlNode(diff, varType, pos, cx.convertExpression(diff));
final SqlNode denominator;
if (biased) {
- denominator = count;
+ denominator = countCasted;
} else {
- final SqlNumericLiteral one =
- SqlLiteral.createExactNumeric("1", pos);
- denominator =
- SqlStdOperatorTable.MINUS.createCall(
- pos, count, one);
+ final SqlNumericLiteral one = SqlLiteral.createExactNumeric("1", pos);
+ final SqlLiteral nullLiteral = SqlLiteral.createNull(SqlParserPos.ZERO);
+ denominator = new SqlCase(SqlParserPos.ZERO,
+ count,
+ SqlNodeList.of(SqlStdOperatorTable.EQUALS.createCall(pos, count, one)),
+ SqlNodeList.of(getCastedSqlNode(nullLiteral, varType, pos, null)),
+ SqlStdOperatorTable.MINUS.createCall(pos, count, one));
}
final SqlNode div =
- SqlStdOperatorTable.DIVIDE.createCall(
- pos, diff, denominator);
+ SqlStdOperatorTable.DIVIDE.createCall(pos, diffCasted, denominator);
+ final SqlNode divCasted = getCastedSqlNode(div, varType, pos, cx.convertExpression(div));
+
SqlNode result = div;
if (sqrt) {
- final SqlNumericLiteral half =
- SqlLiteral.createExactNumeric("0.5", pos);
- result =
- SqlStdOperatorTable.POWER.createCall(pos, div, half);
+ final SqlNumericLiteral half = SqlLiteral.createExactNumeric("0.5", pos);
+ result = SqlStdOperatorTable.POWER.createCall(pos, divCasted, half);
}
return result;
}
+
+ private SqlNode getCastedSqlNode(SqlNode argInput, RelDataType varType,
+ SqlParserPos pos, RexNode argRex) { // argRex is consulted only for its type and may be null
+ SqlNode arg;
+ if (argRex != null && !argRex.getType().equals(varType)) { // cast only when a converted node is given and its type differs
+ arg = SqlStdOperatorTable.CAST.createCall(
+ pos, argInput, SqlTypeUtil.convertTypeToSpec(varType));
+ } else {
+ arg = argInput; // null argRex or matching type: return the node unchanged
+ }
+ return arg;
+ }
}
/** Convertlet that converts {@code LTRIM} and {@code RTRIM} to
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
index 0742731..634c5ff 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlAdvisorTest.java
@@ -183,6 +183,7 @@ public class SqlAdvisorTest extends SqlValidatorTestCase {
"KEYWORD(POWER)",
"KEYWORD(PREV)",
"KEYWORD(RANK)",
+ "KEYWORD(REGR_COUNT)",
"KEYWORD(REGR_SXX)",
"KEYWORD(REGR_SYY)",
"KEYWORD(ROW)",
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/agg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index 6c26a89..997ac94 100755
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -2284,4 +2284,101 @@ EnumerableCalc(expr#0..1=[{inputs}], ANYEMPNO=[$t1])
EnumerableTableScan(table=[[scott, EMP]])
!plan
+# [CALCITE-1776, CALCITE-2402] REGR_COUNT
+SELECT regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)",
+ regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)"
+from "scott".emp;
++-----------------------+------------------------+
+| REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) |
++-----------------------+------------------------+
+| 4 | 14 |
++-----------------------+------------------------+
+(1 row)
+
+!ok
+
+EnumerableAggregate(group=[{}], REGR_COUNT(COMM, SAL)=[REGR_COUNT($6, $5)], REGR_COUNT(EMPNO, SAL)=[REGR_COUNT($5)])
+ EnumerableTableScan(table=[[scott, EMP]])
+!plan
+
+# [CALCITE-1776, CALCITE-2402] REGR_SXX, REGR_SYY
+SELECT
+ regr_sxx(COMM, SAL) as "REGR_SXX(COMM, SAL)",
+ regr_syy(COMM, SAL) as "REGR_SYY(COMM, SAL)",
+ regr_sxx(SAL, COMM) as "REGR_SXX(SAL, COMM)",
+ regr_syy(SAL, COMM) as "REGR_SYY(SAL, COMM)"
+from "scott".emp;
++---------------------+---------------------+---------------------+---------------------+
+| REGR_SXX(COMM, SAL) | REGR_SYY(COMM, SAL) | REGR_SXX(SAL, COMM) | REGR_SYY(SAL, COMM) |
++---------------------+---------------------+---------------------+---------------------+
+| 95000.0000 | 1090000.0000 | 1090000.0000 | 95000.0000 |
++---------------------+---------------------+---------------------+---------------------+
+(1 row)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] COVAR_POP, COVAR_SAMP, VAR_SAMP, VAR_POP
+SELECT
+ covar_pop(COMM, COMM) as "COVAR_POP(COMM, COMM)",
+ covar_samp(SAL, SAL) as "COVAR_SAMP(SAL, SAL)",
+ var_pop(COMM) as "VAR_POP(COMM)",
+ var_samp(SAL) as "VAR_SAMP(SAL)"
+from "scott".emp;
++-----------------------+----------------------+---------------+-------------------+
+| COVAR_POP(COMM, COMM) | COVAR_SAMP(SAL, SAL) | VAR_POP(COMM) | VAR_SAMP(SAL) |
++-----------------------+----------------------+---------------+-------------------+
+| 272500.0000 | 1398313.873626374 | 272500.0000 | 1398313.873626374 |
++-----------------------+----------------------+---------------+-------------------+
+(1 row)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] REGR_COUNT with group by
+SELECT SAL, regr_count(COMM, SAL) as "REGR_COUNT(COMM, SAL)",
+ regr_count(EMPNO, SAL) as "REGR_COUNT(EMPNO, SAL)"
+from "scott".emp group by SAL;
++---------+-----------------------+------------------------+
+| SAL | REGR_COUNT(COMM, SAL) | REGR_COUNT(EMPNO, SAL) |
++---------+-----------------------+------------------------+
+| 1100.00 | 0 | 1 |
+| 1250.00 | 2 | 2 |
+| 1300.00 | 0 | 1 |
+| 1500.00 | 1 | 1 |
+| 1600.00 | 1 | 1 |
+| 2450.00 | 0 | 1 |
+| 2850.00 | 0 | 1 |
+| 2975.00 | 0 | 1 |
+| 3000.00 | 0 | 2 |
+| 5000.00 | 0 | 1 |
+| 800.00 | 0 | 1 |
+| 950.00 | 0 | 1 |
++---------+-----------------------+------------------------+
+(12 rows)
+
+!ok
+
+# [CALCITE-1776, CALCITE-2402] COVAR_SAMP, VAR_POP, VAR_SAMP with group by
+SELECT
+ MONTH(HIREDATE) as "MONTH",
+ covar_samp(SAL, COMM) as "COVAR_SAMP(SAL, COMM)",
+ var_pop(COMM) as "VAR_POP(COMM)",
+ var_samp(SAL) as "VAR_SAMP(SAL)"
+from "scott".emp
+group by MONTH(HIREDATE);
++-------+-----------------------+---------------+-------------------+
+| MONTH | COVAR_SAMP(SAL, COMM) | VAR_POP(COMM) | VAR_SAMP(SAL) |
++-------+-----------------------+---------------+-------------------+
+| 1 | | | 1201250.0000 |
+| 11 | | | |
+| 12 | | | 1510833.333333334 |
+| 2 | -35000.0000 | 10000.0000 | 831458.333333335 |
+| 4 | | | |
+| 5 | | | |
+| 6 | | | |
+| 9 | -175000.0000 | 490000.0000 | 31250.0000 |
++-------+-----------------------+---------------+-------------------+
+(8 rows)
+
+!ok
+
# End agg.iq
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/core/src/test/resources/sql/winagg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/winagg.iq b/core/src/test/resources/sql/winagg.iq
index eac5822..fbd0dde 100644
--- a/core/src/test/resources/sql/winagg.iq
+++ b/core/src/test/resources/sql/winagg.iq
@@ -455,4 +455,137 @@ from emp order by emp."ENAME";
!ok
+# [CALCITE-2402] COVAR_POP, REGR_COUNT functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO",
+ sum(emps."AGE" * emps."DEPTNO") over() as "sum(age * deptno)",
+ regr_count(emps."AGE", emps."DEPTNO") over() as "regr_count(age, deptno)",
+ covar_pop(emps."DEPTNO", emps."AGE") over() as "covar_pop"
+from emps order by emps."AGE";
++-----+--------+-------------------+-------------------------+-----------+
+| AGE | DEPTNO | sum(age * deptno) | regr_count(age, deptno) | covar_pop |
++-----+--------+-------------------+-------------------------+-----------+
+| 5 | 20 | 1950 | 3 | 39 |
+| 25 | 10 | 1950 | 3 | 39 |
+| 80 | 20 | 1950 | 3 | 39 |
+| | 40 | 1950 | 3 | 39 |
+| | 40 | 1950 | 3 | 39 |
++-----+--------+-------------------+-------------------------+-----------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] COVAR_POP, REGR_COUNT functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ sum(emps."AGE" * emps."DEPTNO") over(partition by emps."GENDER") as "sum(age * deptno)",
+ regr_count(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_count(age, deptno)",
+ covar_pop(emps."DEPTNO", emps."AGE") over(partition by emps."GENDER") as "covar_pop"
+from emps order by emps."GENDER";
++-----+--------+--------+-------------------+-------------------------+-----------+
+| AGE | DEPTNO | GENDER | sum(age * deptno) | regr_count(age, deptno) | covar_pop |
++-----+--------+--------+-------------------+-------------------------+-----------+
+| 5 | 20 | F | 100 | 1 | 0 |
+| | 40 | F | 100 | 1 | 0 |
+| 80 | 20 | M | 1600 | 1 | 0 |
+| | 40 | M | 1600 | 1 | 0 |
+| 25 | 10 | | 250 | 1 | 0 |
++-----+--------+--------+-------------------+-------------------------+-----------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] COVAR_SAMP functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# COVAR_SAMP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / (REGR_COUNT(x, y) - 1)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ covar_samp(emps."AGE", emps."AGE") over() as "var_samp",
+ covar_samp(emps."DEPTNO", emps."AGE") over() as "covar_samp",
+ covar_samp(emps."EMPNO", emps."DEPTNO") over(partition by emps."MANAGER") as "covar_samp partitioned"
+from emps order by emps."AGE";
++-----+--------+--------+----------+------------+------------------------+
+| AGE | DEPTNO | GENDER | var_samp | covar_samp | covar_samp partitioned |
++-----+--------+--------+----------+------------+------------------------+
+| 5 | 20 | F | 1508 | 58 | 0 |
+| 25 | 10 | | 1508 | 58 | 50 |
+| 80 | 20 | M | 1508 | 58 | 50 |
+| | 40 | M | 1508 | 58 | 0 |
+| | 40 | F | 1508 | 58 | 0 |
++-----+--------+--------+----------+------------+------------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] VAR_POP, VAR_SAMP functions
+# VAR_POP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / COUNT(x)
+# VAR_SAMP(x) = (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / (COUNT(x) - 1)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ var_pop(emps."AGE") over() as "var_pop",
+ var_pop(emps."AGE") over(partition by emps."AGE") as "var_pop by age",
+ var_samp(emps."AGE") over() as "var_samp",
+ var_samp(emps."AGE") over(partition by emps."GENDER") as "var_samp by gender"
+from emps order by emps."AGE";
++-----+--------+--------+---------+----------------+----------+--------------------+
+| AGE | DEPTNO | GENDER | var_pop | var_pop by age | var_samp | var_samp by gender |
++-----+--------+--------+---------+----------------+----------+--------------------+
+| 5 | 20 | F | 1005 | 0 | 1508 | |
+| 25 | 10 | | 1005 | 0 | 1508 | |
+| 80 | 20 | M | 1005 | 0 | 1508 | |
+| | 40 | F | 1005 | | 1508 | |
+| | 40 | M | 1005 | | 1508 | |
++-----+--------+--------+---------+----------------+----------+--------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] REGR_SXX, REGR_SXY, REGR_SYY functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# REGR_SXX(x, y) = REGR_COUNT(x, y) * COVAR_POP(y, y)
+# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y)
+# REGR_SYY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, x)
+## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+## COVAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO",
+ regr_sxx(emps."AGE", emps."DEPTNO") over() as "regr_sxx(age, deptno)",
+ regr_syy(emps."AGE", emps."DEPTNO") over() as "regr_syy(age, deptno)"
+from emps order by emps."AGE";
++-----+--------+-----------------------+-----------------------+
+| AGE | DEPTNO | regr_sxx(age, deptno) | regr_syy(age, deptno) |
++-----+--------+-----------------------+-----------------------+
+| 5 | 20 | 66 | 3015 |
+| 25 | 10 | 66 | 3015 |
+| 80 | 20 | 66 | 3015 |
+| | 40 | 66 | 3015 |
+| | 40 | 66 | 3015 |
++-----+--------+-----------------------+-----------------------+
+(5 rows)
+
+!ok
+
+# [CALCITE-2402] REGR_SXX, REGR_SXY, REGR_SYY functions
+# SUM(x, y) = SUM(x) WHERE y IS NOT NULL
+# REGR_SXX(x, y) = REGR_COUNT(x, y) * COVAR_POP(y, y)
+# REGR_SXY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, y)
+# REGR_SYY(x, y) = REGR_COUNT(x, y) * COVAR_POP(x, x)
+## COVAR_POP(x, y) = (SUM(x * y) - SUM(x, y) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+## COVAR_POP(y, y) = (SUM(y * y, x) - SUM(y, x) * SUM(y, x) / REGR_COUNT(x, y)) / REGR_COUNT(x, y)
+select emps."AGE", emps."DEPTNO", emps."GENDER",
+ regr_sxx(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_sxx(age, deptno)",
+ regr_syy(emps."AGE", emps."DEPTNO") over(partition by emps."GENDER") as "regr_syy(age, deptno)"
+from emps order by emps."GENDER";
++-----+--------+--------+-----------------------+-----------------------+
+| AGE | DEPTNO | GENDER | regr_sxx(age, deptno) | regr_syy(age, deptno) |
++-----+--------+--------+-----------------------+-----------------------+
+| 5 | 20 | F | 0 | 0 |
+| | 40 | F | 0 | 0 |
+| 80 | 20 | M | 0 | 0 |
+| | 40 | M | 0 | 0 |
+| 25 | 10 | | 0 | 0 |
++-----+--------+--------+-----------------------+-----------------------+
+(5 rows)
+
+!ok
+
# End winagg.iq
http://git-wip-us.apache.org/repos/asf/calcite/blob/ca858dd7/site/_docs/reference.md
----------------------------------------------------------------------
diff --git a/site/_docs/reference.md b/site/_docs/reference.md
index acf8f29..e76e7d6 100644
--- a/site/_docs/reference.md
+++ b/site/_docs/reference.md
@@ -1510,6 +1510,7 @@ passed to the aggregate function.
| VAR_SAMP( [ ALL | DISTINCT ] numeric) | Returns the sample variance (square of the sample standard deviation) of *numeric* across all input values
| COVAR_POP(numeric1, numeric2) | Returns the population covariance of the pair (*numeric1*, *numeric2*) across all input values
| COVAR_SAMP(numeric1, numeric2) | Returns the sample covariance of the pair (*numeric1*, *numeric2*) across all input values
+| REGR_COUNT(numeric1, numeric2) | Returns the number of rows where both dependent and independent expressions are not null
| REGR_SXX(numeric1, numeric2) | Returns the sum of squares of the independent expression (*numeric2*) in a linear regression model
| REGR_SYY(numeric1, numeric2) | Returns the sum of squares of the dependent expression (*numeric1*) in a linear regression model
@@ -1517,7 +1518,6 @@ Not implemented:
* REGR_AVGX(numeric1, numeric2)
* REGR_AVGY(numeric1, numeric2)
-* REGR_COUNT(numeric1, numeric2)
* REGR_INTERCEPT(numeric1, numeric2)
* REGR_R2(numeric1, numeric2)
* REGR_SLOPE(numeric1, numeric2)