You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@calcite.apache.org by mm...@apache.org on 2017/09/05 14:36:50 UTC
[07/16] calcite git commit: [CALCITE-1945] Make return types of AVG,
VARIANCE, STDDEV and COVAR customizable via RelDataTypeSystem
[CALCITE-1945] Make return types of AVG, VARIANCE, STDDEV and COVAR customizable via RelDataTypeSystem
* Introduce VARIANCE and STDDEV as aliases for VAR_SAMP and STDDEV_SAMP
Close apache/calcite#518
Project: http://git-wip-us.apache.org/repos/asf/calcite/repo
Commit: http://git-wip-us.apache.org/repos/asf/calcite/commit/4208d802
Tree: http://git-wip-us.apache.org/repos/asf/calcite/tree/4208d802
Diff: http://git-wip-us.apache.org/repos/asf/calcite/diff/4208d802
Branch: refs/heads/branch-1.14
Commit: 4208d8021b4978a3f0a259ec299fa7a62c582180
Parents: 6d2fc4e
Author: MinJi Kim <mi...@apache.org>
Authored: Sun Aug 27 16:21:22 2017 -0700
Committer: Julian Hyde <jh...@apache.org>
Committed: Tue Aug 29 10:15:17 2017 -0700
----------------------------------------------------------------------
.../calcite/rel/rel2sql/SqlImplementor.java | 6 +-
.../rel/rules/AggregateReduceFunctionsRule.java | 127 ++++++++++++-------
.../calcite/rel/type/RelDataTypeSystem.java | 18 ++-
.../calcite/rel/type/RelDataTypeSystemImpl.java | 14 +-
.../apache/calcite/runtime/SqlFunctions.java | 4 +
.../java/org/apache/calcite/sql/SqlKind.java | 9 ++
.../calcite/sql/fun/SqlAvgAggFunction.java | 15 ++-
.../calcite/sql/fun/SqlCovarAggFunction.java | 2 +-
.../calcite/sql/fun/SqlStdOperatorTable.java | 12 ++
.../apache/calcite/sql/type/ReturnTypes.java | 34 ++++-
.../sql2rel/StandardConvertletTable.java | 49 +++++--
.../calcite/sql/test/SqlOperatorBaseTest.java | 70 ++++++++++
core/src/test/resources/sql/agg.iq | 24 +++-
13 files changed, 303 insertions(+), 81 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
index d227310..57155b7 100644
--- a/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
+++ b/core/src/main/java/org/apache/calcite/rel/rel2sql/SqlImplementor.java
@@ -665,7 +665,11 @@ public abstract class SqlImplementor {
}
final RexCall call = (RexCall) stripCastFromString(rex);
- final SqlOperator op = call.getOperator();
+ SqlOperator op = call.getOperator();
+ switch (op.getKind()) {
+ case SUM0:
+ op = SqlStdOperatorTable.SUM; // SUM0 calls are written out as SUM
+ }
final List<SqlNode> nodeList = toSql(program, call.getOperands());
switch (call.getKind()) {
case CAST:
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
index 8fceff0..7e6e4a1 100644
--- a/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
+++ b/core/src/main/java/org/apache/calcite/rel/rules/AggregateReduceFunctionsRule.java
@@ -31,10 +31,9 @@ import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
-import org.apache.calcite.sql.fun.SqlAvgAggFunction;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
-import org.apache.calcite.sql.fun.SqlSumAggFunction;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
@@ -117,8 +116,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
*/
private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
for (AggregateCall call : aggCallList) {
- if (call.getAggregation() instanceof SqlAvgAggFunction
- || call.getAggregation() instanceof SqlSumAggFunction) {
+ if (isReducible(call.getAggregation().getKind())) {
return true;
}
}
@@ -126,6 +124,20 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
}
/**
+ * Returns whether an aggregate function of the given kind is reduced by this rule.
+ */
+ private boolean isReducible(final SqlKind kind) {
+ if (SqlKind.AVG_AGG_FUNCTIONS.contains(kind)) {
+ return true;
+ }
+ switch (kind) {
+ case SUM:
+ return true;
+ }
+ return false;
+ }
+
+ /**
* Reduces all calls to AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP in
* the aggregates list to.
*
@@ -187,17 +199,16 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
List<AggregateCall> newCalls,
Map<AggregateCall, RexNode> aggCallMapping,
List<RexNode> inputExprs) {
- if (oldCall.getAggregation() instanceof SqlSumAggFunction) {
- // replace original SUM(x) with
- // case COUNT(x) when 0 then null else SUM0(x) end
- return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
- }
- if (oldCall.getAggregation() instanceof SqlAvgAggFunction) {
- final SqlKind kind = oldCall.getAggregation().getKind();
+ final SqlKind kind = oldCall.getAggregation().getKind();
+ if (isReducible(kind)) {
switch (kind) {
+ case SUM:
+ // replace original SUM(x) with
+ // case COUNT(x) when 0 then null else SUM0(x) end
+ return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
case AVG:
// replace original AVG(x) with SUM(x) / COUNT(x)
- return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping);
+ return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
case STDDEV_POP:
// replace original STDDEV_POP(x) with
// SQRT(
@@ -243,19 +254,39 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
}
}
+ private AggregateCall createAggregateCallWithBinding(
+ RelDataTypeFactory typeFactory,
+ SqlAggFunction aggFunction,
+ RelDataType operandType,
+ Aggregate oldAggRel,
+ AggregateCall oldCall,
+ int argOrdinal) {
+ final Aggregate.AggCallBinding binding =
+ new Aggregate.AggCallBinding(typeFactory, aggFunction,
+ ImmutableList.of(operandType), oldAggRel.getGroupCount(),
+ oldCall.filterArg >= 0);
+ return AggregateCall.create(aggFunction,
+ oldCall.isDistinct(),
+ ImmutableIntList.of(argOrdinal),
+ oldCall.filterArg,
+ aggFunction.inferReturnType(binding),
+ null);
+ }
+
private RexNode reduceAvg(
Aggregate oldAggRel,
AggregateCall oldCall,
List<AggregateCall> newCalls,
- Map<AggregateCall, RexNode> aggCallMapping) {
+ Map<AggregateCall, RexNode> aggCallMapping,
+ List<RexNode> inputExprs) {
final int nGroups = oldAggRel.getGroupCount();
- RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
- int iAvgInput = oldCall.getArgList().get(0);
- RelDataType avgInputType =
+ final RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
+ final int iAvgInput = oldCall.getArgList().get(0);
+ final RelDataType avgInputType =
getFieldType(
oldAggRel.getInput(),
iAvgInput);
- AggregateCall sumCall =
+ final AggregateCall sumCall =
AggregateCall.create(
SqlStdOperatorTable.SUM,
oldCall.isDistinct(),
@@ -265,7 +296,7 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
oldAggRel.getInput(),
null,
null);
- AggregateCall countCall =
+ final AggregateCall countCall =
AggregateCall.create(
SqlStdOperatorTable.COUNT,
oldCall.isDistinct(),
@@ -285,17 +316,20 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
newCalls,
aggCallMapping,
ImmutableList.of(avgInputType));
- RexNode denominatorRef =
+ final RexNode denominatorRef =
rexBuilder.addAggCall(countCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
ImmutableList.of(avgInputType));
+
+ final RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
+ final RelDataType avgType = typeFactory.createTypeWithNullability(
+ oldCall.getType(), numeratorRef.getType().isNullable());
+ numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
final RexNode divideRef =
- rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE,
- numeratorRef,
- denominatorRef);
+ rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
return rexBuilder.makeCast(oldCall.getType(), divideRef);
}
@@ -381,36 +415,30 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
final int argOrdinal = oldCall.getArgList().get(0);
- final RelDataType argType =
- getFieldType(
- oldAggRel.getInput(),
- argOrdinal);
+ final RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
+ final RelDataType oldCallType =
+ typeFactory.createTypeWithNullability(oldCall.getType(),
+ argOrdinalType.isNullable());
- final RexNode argRef = inputExprs.get(argOrdinal);
- final RexNode argSquared =
- rexBuilder.makeCall(
- SqlStdOperatorTable.MULTIPLY, argRef, argRef);
+ final RexNode argRef =
+ rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
+ final int argRefOrdinal = lookupOrAdd(inputExprs, argRef);
+
+ final RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY,
+ argRef, argRef);
final int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);
- final Aggregate.AggCallBinding binding =
- new Aggregate.AggCallBinding(typeFactory, SqlStdOperatorTable.SUM,
- ImmutableList.of(argRef.getType()), oldAggRel.getGroupCount(),
- oldCall.filterArg >= 0);
final AggregateCall sumArgSquaredAggCall =
- AggregateCall.create(
- SqlStdOperatorTable.SUM,
- oldCall.isDistinct(),
- ImmutableIntList.of(argSquaredOrdinal),
- oldCall.filterArg,
- SqlStdOperatorTable.SUM.inferReturnType(binding),
- null);
+ createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM,
+ argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal);
+
final RexNode sumArgSquared =
rexBuilder.addAggCall(sumArgSquaredAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
- ImmutableList.of(argType));
+ ImmutableList.of(sumArgSquaredAggCall.getType()));
final AggregateCall sumArgAggCall =
AggregateCall.create(
@@ -422,17 +450,18 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
oldAggRel.getInput(),
null,
null);
+
final RexNode sumArg =
rexBuilder.addAggCall(sumArgAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
- ImmutableList.of(argType));
-
+ ImmutableList.of(sumArgAggCall.getType()));
+ final RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
final RexNode sumSquaredArg =
rexBuilder.makeCall(
- SqlStdOperatorTable.MULTIPLY, sumArg, sumArg);
+ SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);
final AggregateCall countArgAggCall =
AggregateCall.create(
@@ -441,21 +470,21 @@ public class AggregateReduceFunctionsRule extends RelOptRule {
oldCall.getArgList(),
oldCall.filterArg,
oldAggRel.getGroupCount(),
oldAggRel.getInput(),
null,
null);
+
final RexNode countArg =
rexBuilder.addAggCall(countArgAggCall,
nGroups,
oldAggRel.indicator,
newCalls,
aggCallMapping,
- ImmutableList.of(argType));
+ ImmutableList.of(argOrdinalType));
final RexNode avgSumSquaredArg =
rexBuilder.makeCall(
- SqlStdOperatorTable.DIVIDE,
- sumSquaredArg, countArg);
+ SqlStdOperatorTable.DIVIDE, sumSquaredArg, countArg);
final RexNode diff =
rexBuilder.makeCall(
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
index 858567c..b8a8088 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystem.java
@@ -69,11 +69,21 @@ public interface RelDataTypeSystem {
* 0 means "not applicable". */
int getNumTypeRadix(SqlTypeName typeName);
- /**
- * Returns the return type of a call to the {@code SUM} aggregate function
- * inferred from its argument type.
+ /** Returns the return type of a call to the {@code SUM} aggregate function,
+ * inferred from its argument type. */
+ RelDataType deriveSumType(RelDataTypeFactory typeFactory,
+ RelDataType argumentType);
+
+ /** Returns the return type of a call to the {@code AVG}, {@code STDDEV_*} or
+ * {@code VAR_*} aggregate functions (and their aliases), inferred from the argument type.
*/
- RelDataType deriveSumType(RelDataTypeFactory typeFactory, RelDataType argumentType);
+ RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
+ RelDataType argumentType);
+
+ /** Returns the return type of a call to the {@code COVAR} aggregate function,
+ * inferred from its argument types. */
+ RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
+ RelDataType arg0Type, RelDataType arg1Type);
/** Returns the return type of the {@code CUME_DIST} and {@code PERCENT_RANK}
* aggregate functions. */
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
index ef89895..3e0eebd 100644
--- a/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
+++ b/core/src/main/java/org/apache/calcite/rel/type/RelDataTypeSystemImpl.java
@@ -207,11 +207,21 @@ public abstract class RelDataTypeSystemImpl implements RelDataTypeSystem {
return 0;
}
- @Override public RelDataType deriveSumType(
- RelDataTypeFactory typeFactory, RelDataType argumentType) {
+ @Override public RelDataType deriveSumType(RelDataTypeFactory typeFactory,
+ RelDataType argumentType) {
return argumentType;
}
+ @Override public RelDataType deriveAvgAggType(RelDataTypeFactory typeFactory,
+ RelDataType argumentType) {
+ return argumentType;
+ }
+
+ @Override public RelDataType deriveCovarType(RelDataTypeFactory typeFactory,
+ RelDataType arg0Type, RelDataType arg1Type) {
+ return arg0Type;
+ }
+
@Override public RelDataType deriveFractionalRankType(RelDataTypeFactory typeFactory) {
return typeFactory.createTypeWithNullability(
typeFactory.createSqlType(SqlTypeName.DOUBLE), false);
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
index 69c6154..6832ee4 100644
--- a/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
+++ b/core/src/main/java/org/apache/calcite/runtime/SqlFunctions.java
@@ -823,6 +823,10 @@ public class SqlFunctions {
return Math.pow(b0, b1);
}
+ public static double power(double b0, BigDecimal b1) {
+ return Math.pow(b0, b1.doubleValue()); // exponent narrowed to double; may lose precision
+ }
+
public static double power(long b0, long b1) {
return Math.pow(b0, b1);
}
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/SqlKind.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/SqlKind.java b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
index ad7c4e2..8d7c8aa 100644
--- a/core/src/main/java/org/apache/calcite/sql/SqlKind.java
+++ b/core/src/main/java/org/apache/calcite/sql/SqlKind.java
@@ -1119,6 +1119,15 @@ public enum SqlKind {
EnumSet.of(OTHER_FUNCTION, ROW, TRIM, LTRIM, RTRIM, CAST, JDBC_FN);
/**
+ * Category of aggregate functions implemented by SqlAvgAggFunction.
+ *
+ * <p>Consists of {@link #AVG}, {@link #STDDEV_POP}, {@link #STDDEV_SAMP},
+ * {@link #VAR_POP}, {@link #VAR_SAMP}.
+ */
+ public static final Set<SqlKind> AVG_AGG_FUNCTIONS =
+ EnumSet.of(AVG, STDDEV_POP, STDDEV_SAMP, VAR_POP, VAR_SAMP);
+
+ /**
* Category of comparison operators.
*
* <p>Consists of:
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
index 95f8049..6be1ce9 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlAvgAggFunction.java
@@ -32,26 +32,27 @@ import com.google.common.base.Preconditions;
* double</code>), and the result is the same type.
*/
public class SqlAvgAggFunction extends SqlAggFunction {
+
//~ Constructors -----------------------------------------------------------
/**
* Creates a SqlAvgAggFunction.
*/
public SqlAvgAggFunction(SqlKind kind) {
- super(kind.name(),
+ this(kind.name(), kind);
+ }
+
+ SqlAvgAggFunction(String name, SqlKind kind) { // name may differ from kind, e.g. STDDEV
+ super(name,
null,
kind,
- ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+ ReturnTypes.AVG_AGG_FUNCTION,
null,
OperandTypes.NUMERIC,
SqlFunctionCategory.NUMERIC,
false,
false);
- Preconditions.checkArgument(kind == SqlKind.AVG
- || kind == SqlKind.STDDEV_POP
- || kind == SqlKind.STDDEV_SAMP
- || kind == SqlKind.VAR_POP
- || kind == SqlKind.VAR_SAMP);
+ Preconditions.checkArgument(SqlKind.AVG_AGG_FUNCTIONS.contains(kind), "unsupported sql kind");
}
@Deprecated // to be removed before 2.0
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
index ea23300..8c62290 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlCovarAggFunction.java
@@ -43,7 +43,7 @@ public class SqlCovarAggFunction extends SqlAggFunction {
super(kind.name(),
null,
kind,
- ReturnTypes.ARG0_NULLABLE_IF_EMPTY,
+ ReturnTypes.COVAR_FUNCTION, // type comes from RelDataTypeSystem.deriveCovarType
null,
OperandTypes.NUMERIC_NUMERIC,
SqlFunctionCategory.NUMERIC,
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
index 3f125bd..39a45b3 100644
--- a/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
+++ b/core/src/main/java/org/apache/calcite/sql/fun/SqlStdOperatorTable.java
@@ -918,6 +918,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
new SqlAvgAggFunction(SqlKind.STDDEV_SAMP);
/**
+ * <code>STDDEV</code> aggregate function; an alias for <code>STDDEV_SAMP</code>.
+ */
+ public static final SqlAggFunction STDDEV =
+ new SqlAvgAggFunction("STDDEV", SqlKind.STDDEV_SAMP);
+
+ /**
* <code>VAR_POP</code> aggregate function.
*/
public static final SqlAggFunction VAR_POP =
@@ -929,6 +935,12 @@ public class SqlStdOperatorTable extends ReflectiveSqlOperatorTable {
public static final SqlAggFunction VAR_SAMP =
new SqlAvgAggFunction(SqlKind.VAR_SAMP);
+ /**
+ * <code>VARIANCE</code> aggregate function; an alias for <code>VAR_SAMP</code>.
+ */
+ public static final SqlAggFunction VARIANCE =
+ new SqlAvgAggFunction("VARIANCE", SqlKind.VAR_SAMP);
+
//-------------------------------------------------------------
// WINDOW Aggregate Functions
//-------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
index 73e99f8..15ca544 100644
--- a/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
+++ b/core/src/main/java/org/apache/calcite/sql/type/ReturnTypes.java
@@ -779,8 +779,10 @@ public abstract class ReturnTypes {
@Override public RelDataType
inferReturnType(SqlOperatorBinding opBinding) {
final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
- return typeFactory.getTypeSystem()
+ final RelDataType sumType = typeFactory.getTypeSystem()
.deriveSumType(typeFactory, opBinding.getOperandType(0));
+ // SUM0 should not return null.
+ return typeFactory.createTypeWithNullability(sumType, false);
}
};
@@ -809,6 +811,36 @@ public abstract class ReturnTypes {
return typeFactory.getTypeSystem().deriveRankType(typeFactory);
}
};
+
+ public static final SqlReturnTypeInference AVG_AGG_FUNCTION = // via deriveAvgAggType
+ new SqlReturnTypeInference() {
+ @Override public RelDataType
+ inferReturnType(SqlOperatorBinding opBinding) {
+ final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+ final RelDataType relDataType = typeFactory.getTypeSystem().deriveAvgAggType(
+ typeFactory, opBinding.getOperandType(0));
+ if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) { // force nullable
+ return typeFactory.createTypeWithNullability(relDataType, true);
+ } else {
+ return relDataType;
+ }
+ }
+ };
+
+ public static final SqlReturnTypeInference COVAR_FUNCTION = // via deriveCovarType
+ new SqlReturnTypeInference() {
+ @Override public RelDataType
+ inferReturnType(SqlOperatorBinding opBinding) {
+ final RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
+ final RelDataType relDataType = typeFactory.getTypeSystem().deriveCovarType(
+ typeFactory, opBinding.getOperandType(0), opBinding.getOperandType(1));
+ if (opBinding.getGroupCount() == 0 || opBinding.hasFilter()) { // force nullable
+ return typeFactory.createTypeWithNullability(relDataType, true);
+ } else {
+ return relDataType;
+ }
+ }
+ };
}
// End ReturnTypes.java
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
----------------------------------------------------------------------
diff --git a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
index 0d62f9f..8940629 100644
--- a/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
+++ b/core/src/main/java/org/apache/calcite/sql2rel/StandardConvertletTable.java
@@ -299,11 +299,17 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
SqlStdOperatorTable.STDDEV_SAMP,
new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
registerOp(
+ SqlStdOperatorTable.STDDEV,
+ new AvgVarianceConvertlet(SqlKind.STDDEV_SAMP));
+ registerOp(
SqlStdOperatorTable.VAR_POP,
new AvgVarianceConvertlet(SqlKind.VAR_POP));
registerOp(
SqlStdOperatorTable.VAR_SAMP,
new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
+ registerOp(
+ SqlStdOperatorTable.VARIANCE,
+ new AvgVarianceConvertlet(SqlKind.VAR_SAMP));
final SqlRexConvertlet floorCeilConvertlet = new FloorCeilConvertlet();
registerOp(SqlStdOperatorTable.FLOOR, floorCeilConvertlet);
@@ -1272,44 +1278,56 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
assert call.operandCount() == 1;
final SqlNode arg = call.operand(0);
final SqlNode expr;
+ final RelDataType type =
+ cx.getValidator().getValidatedNodeType(call);
switch (kind) {
case AVG:
- expr = expandAvg(arg);
+ expr = expandAvg(arg, type, cx);
break;
case STDDEV_POP:
- expr = expandVariance(arg, true, true);
+ expr = expandVariance(arg, type, cx, true, true);
break;
case STDDEV_SAMP:
- expr = expandVariance(arg, false, true);
+ expr = expandVariance(arg, type, cx, false, true);
break;
case VAR_POP:
- expr = expandVariance(arg, true, false);
+ expr = expandVariance(arg, type, cx, true, false);
break;
case VAR_SAMP:
- expr = expandVariance(arg, false, false);
+ expr = expandVariance(arg, type, cx, false, false);
break;
default:
throw Util.unexpected(kind);
}
- RelDataType type =
- cx.getValidator().getValidatedNodeType(call);
RexNode rex = cx.convertExpression(expr);
return cx.getRexBuilder().ensureType(type, rex, true);
}
private SqlNode expandAvg(
- final SqlNode arg) {
+ final SqlNode arg, final RelDataType avgType, final SqlRexContext cx) {
final SqlParserPos pos = SqlParserPos.ZERO;
final SqlNode sum =
SqlStdOperatorTable.SUM.createCall(pos, arg);
+ final RexNode sumRex = cx.convertExpression(sum);
+ final SqlNode sumCast;
+ if (!sumRex.getType().equals(avgType)) {
+ sumCast = SqlStdOperatorTable.CAST.createCall(pos, sum,
+ new SqlDataTypeSpec(
+ new SqlIdentifier(avgType.getSqlTypeName().getName(), pos),
+ avgType.getPrecision(), avgType.getScale(), null, null, pos));
+ } else {
+ sumCast = sum;
+ }
final SqlNode count =
SqlStdOperatorTable.COUNT.createCall(pos, arg);
return SqlStdOperatorTable.DIVIDE.createCall(
- pos, sum, count);
+ pos, sumCast, count);
}
private SqlNode expandVariance(
- final SqlNode arg,
+ final SqlNode argInput,
+ final RelDataType varType,
+ final SqlRexContext cx,
boolean biased,
boolean sqrt) {
// stddev_pop(x) ==>
@@ -1332,6 +1350,17 @@ public class StandardConvertletTable extends ReflectiveConvertletTable {
// (sum(x * x) - sum(x) * sum(x) / count(x))
// / (count(x) - 1)
final SqlParserPos pos = SqlParserPos.ZERO;
+
+ final RexNode argRex = cx.convertExpression(argInput);
+ final SqlNode arg;
+ if (!argRex.getType().equals(varType)) {
+ arg = SqlStdOperatorTable.CAST.createCall(pos, argInput,
+ new SqlDataTypeSpec(new SqlIdentifier(varType.getSqlTypeName().getName(), pos),
+ varType.getPrecision(), varType.getScale(), null, null, pos));
+ } else {
+ arg = argInput;
+ }
+
final SqlNode argSquared =
SqlStdOperatorTable.MULTIPLY.createCall(pos, arg, arg);
final SqlNode sumArgSquared =
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
----------------------------------------------------------------------
diff --git a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
index f73921f..15ddb13 100644
--- a/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
+++ b/core/src/test/java/org/apache/calcite/sql/test/SqlOperatorBaseTest.java
@@ -6417,6 +6417,33 @@ public abstract class SqlOperatorBaseTest {
0d);
}
+ @Test public void testStddevFunc() {
+ tester.setFor(SqlStdOperatorTable.STDDEV, VM_EXPAND);
+ tester.checkFails(
+ "stddev(^*^)",
+ "Unknown identifier '\\*'",
+ false);
+ tester.checkFails(
+ "^stddev(cast(null as varchar(2)))^",
+ "(?s)Cannot apply 'STDDEV' to arguments of type 'STDDEV\\(<VARCHAR\\(2\\)>\\)'\\. Supported form\\(s\\): 'STDDEV\\(<NUMERIC>\\)'.*",
+ false);
+ tester.checkType("stddev(CAST(NULL AS INTEGER))", "INTEGER");
+ checkAggType(tester, "stddev(DISTINCT 1.5)", "DECIMAL(2, 1) NOT NULL");
+ final String[] values = {"0", "CAST(null AS FLOAT)", "3", "3"}; // NOTE(review): unused here
+ // with one value (NOTE(review): unlike testVarFunc, no 'enable' guard precedes checkAgg)
+ tester.checkAgg(
+ "stddev(x)",
+ new String[]{"5"},
+ null,
+ 0d);
+ // with zero values
+ tester.checkAgg(
+ "stddev(x)",
+ new String[]{},
+ null,
+ 0d);
+ }
+
@Test public void testVarPopFunc() {
tester.setFor(SqlStdOperatorTable.VAR_POP, VM_EXPAND);
tester.checkFails(
@@ -6505,6 +6532,49 @@ public abstract class SqlOperatorBaseTest {
0d);
}
+ @Test public void testVarFunc() {
+ tester.setFor(SqlStdOperatorTable.VARIANCE, VM_EXPAND);
+ tester.checkFails(
+ "variance(^*^)",
+ "Unknown identifier '\\*'",
+ false);
+ tester.checkFails(
+ "^variance(cast(null as varchar(2)))^",
+ "(?s)Cannot apply 'VARIANCE' to arguments of type 'VARIANCE\\(<VARCHAR\\(2\\)>\\)'\\. Supported form\\(s\\): 'VARIANCE\\(<NUMERIC>\\)'.*",
+ false);
+ tester.checkType("variance(CAST(NULL AS INTEGER))", "INTEGER");
+ checkAggType(tester, "variance(DISTINCT 1.5)", "DECIMAL(2, 1) NOT NULL");
+ final String[] values = {"0", "CAST(null AS FLOAT)", "3", "3"};
+ if (!enable) {
+ return;
+ }
+ tester.checkAgg(
+ "variance(x)", values, 3d, // verified on Oracle 10g
+ 0d);
+ tester.checkAgg(
+ "variance(DISTINCT x)", // Oracle does not allow distinct
+ values,
+ 4.5d,
+ 0.0001d);
+ tester.checkAgg(
+ "variance(DISTINCT CASE x WHEN 0 THEN NULL ELSE -1 END)",
+ values,
+ null,
+ 0d);
+ // with one value
+ tester.checkAgg(
+ "variance(x)",
+ new String[]{"5"},
+ null,
+ 0d);
+ // with zero values
+ tester.checkAgg(
+ "variance(x)",
+ new String[]{},
+ null,
+ 0d);
+ }
+
@Test public void testMinFunc() {
tester.setFor(SqlStdOperatorTable.MIN, VM_EXPAND);
tester.checkFails(
http://git-wip-us.apache.org/repos/asf/calcite/blob/4208d802/core/src/test/resources/sql/agg.iq
----------------------------------------------------------------------
diff --git a/core/src/test/resources/sql/agg.iq b/core/src/test/resources/sql/agg.iq
index e4ec228..28e8b4b 100755
--- a/core/src/test/resources/sql/agg.iq
+++ b/core/src/test/resources/sql/agg.iq
@@ -85,19 +85,31 @@ select stddev_pop(deptno) as s from emp;
!ok
+# stddev (alias for stddev_samp)
+select stddev(deptno) as s from emp;
++----+
+| S |
++----+
+| 19 |
++----+
+(1 row)
+
+!ok
+
# both
select gender,
stddev_pop(deptno) as p,
stddev_samp(deptno) as s,
+ stddev(deptno) as ss,
count(deptno) as c
from emp
group by gender;
-+--------+----+----+---+
-| GENDER | P | S | C |
-+--------+----+----+---+
-| F | 17 | 19 | 5 |
-| M | 17 | 20 | 3 |
-+--------+----+----+---+
++--------+----+----+----+---+
+| GENDER | P | S | SS | C |
++--------+----+----+----+---+
+| F | 17 | 19 | 19 | 5 |
+| M | 17 | 20 | 20 | 3 |
++--------+----+----+----+---+
(2 rows)
!ok