You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@druid.apache.org by fj...@apache.org on 2019/06/10 16:40:20 UTC
[incubator-druid] branch master updated: Support var_pop, var_samp,
stddev_pop and stddev_samp etc in sql (#7801)
This is an automated email from the ASF dual-hosted git repository.
fjy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-druid.git
The following commit(s) were added to refs/heads/master by this push:
new ce591d1 Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)
ce591d1 is described below
commit ce591d14574cf0cb81d83ddf405a446aeee7648f
Author: Xue Yu <27...@qq.com>
AuthorDate: Tue Jun 11 00:40:09 2019 +0800
Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)
* support var_pop, stddev_pop etc in sql
* fix sql compatible
* rebase on master
* update doc
---
docs/content/querying/sql.md | 7 +
extensions-core/stats/pom.xml | 24 +
.../query/aggregation/stats/DruidStatsModule.java | 11 +
.../variance/StandardDeviationPostAggregator.java | 29 ++
.../aggregation/variance/VarianceAggregator.java | 9 +-
.../variance/VarianceBufferAggregator.java | 42 +-
.../variance/sql/BaseVarianceSqlAggregator.java | 193 ++++++++
.../variance/sql/VarianceSqlAggregatorTest.java | 518 +++++++++++++++++++++
8 files changed, 812 insertions(+), 21 deletions(-)
diff --git a/docs/content/querying/sql.md b/docs/content/querying/sql.md
index e16fb40..8921fe0 100644
--- a/docs/content/querying/sql.md
+++ b/docs/content/querying/sql.md
@@ -129,6 +129,13 @@ Only the COUNT aggregation can accept DISTINCT.
|`APPROX_QUANTILE_DS(expr, probability, [k])`|Computes approximate quantiles on numeric or [Quantiles sketch](../development/extensions-core/datasketches-quantiles.html) exprs. The "probability" should be between 0 and 1 (exclusive). The `k` parameter is described in the Quantiles sketch documentation. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.|
|`APPROX_QUANTILE_FIXED_BUCKETS(expr, probability, numBuckets, lowerLimit, upperLimit, [outlierHandlingMode])`|Computes approximate quantiles on numeric or [fixed buckets histogram](../development/extensions-core/approximate-histograms.html#fixed-buckets-histogram) exprs. The "probability" should be between 0 and 1 (exclusive). The `numBuckets`, `lowerLimit`, `upperLimit`, and `outlierHandlingMode` parameters are described in the fixed buckets histogram documentation. The [approximate hi [...]
|`BLOOM_FILTER(expr, numEntries)`|Computes a bloom filter from values produced by `expr`, with `numEntries` maximum number of distinct values before false positve rate increases. See [bloom filter extension](../development/extensions-core/bloom-filter.html) documentation for additional details.|
+|`VAR_POP(expr)`|Computes the population variance of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+|`VAR_SAMP(expr)`|Computes the sample variance of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+|`VARIANCE(expr)`|Computes the sample variance of `expr` (same as `VAR_SAMP`). See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+|`STDDEV_POP(expr)`|Computes the population standard deviation of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+|`STDDEV_SAMP(expr)`|Computes the sample standard deviation of `expr`. See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+|`STDDEV(expr)`|Computes the sample standard deviation of `expr` (same as `STDDEV_SAMP`). See the [stats extension](../development/extensions-core/stats.html) documentation for additional details; the extension must be loaded to use this function.|
+
For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx).
diff --git a/extensions-core/stats/pom.xml b/extensions-core/stats/pom.xml
index 7f1b11a..d830e5e 100644
--- a/extensions-core/stats/pom.xml
+++ b/extensions-core/stats/pom.xml
@@ -41,6 +41,12 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>org.apache.druid</groupId>
+ <artifactId>druid-sql</artifactId>
+ <version>${project.parent.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>org.apache.commons</groupId>
<artifactId>commons-math3</artifactId>
</dependency>
@@ -53,6 +59,15 @@
<scope>test</scope>
<type>test-jar</type>
</dependency>
+
+ <dependency>
+ <groupId>org.apache.druid</groupId>
+ <artifactId>druid-server</artifactId>
+ <version>${project.parent.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+
<dependency>
<groupId>org.apache.druid</groupId>
<artifactId>druid-processing</artifactId>
@@ -60,6 +75,15 @@
<scope>test</scope>
<type>test-jar</type>
</dependency>
+
+ <dependency>
+ <groupId>org.apache.druid</groupId>
+ <artifactId>druid-sql</artifactId>
+ <version>${project.parent.version}</version>
+ <scope>test</scope>
+ <type>test-jar</type>
+ </dependency>
+
<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
index d23c524..1679073 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
@@ -30,7 +30,9 @@ import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregat
import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceFoldingAggregatorFactory;
import org.apache.druid.query.aggregation.variance.VarianceSerde;
+import org.apache.druid.query.aggregation.variance.sql.BaseVarianceSqlAggregator;
import org.apache.druid.segment.serde.ComplexMetrics;
+import org.apache.druid.sql.guice.SqlBindings;
import java.util.List;
@@ -55,6 +57,15 @@ public class DruidStatsModule implements DruidModule
@Override
public void configure(Binder binder)
{
+ if (binder != null) {
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarPopSqlAggregator.class);
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarSampSqlAggregator.class);
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarianceSqlAggregator.class);
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevPopSqlAggregator.class);
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSampSqlAggregator.class);
+ SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSqlAggregator.class);
+ }
+
if (ComplexMetrics.getSerdeForType("variance") == null) {
ComplexMetrics.registerSerde("variance", new VarianceSerde());
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
index ed1aa21..cef9bca 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
@@ -32,6 +32,7 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
import java.util.Comparator;
import java.util.Map;
+import java.util.Objects;
import java.util.Set;
/**
@@ -121,4 +122,32 @@ public class StandardDeviationPostAggregator implements PostAggregator
.appendBoolean(isVariancePop)
.build();
}
+
+  /**
+   * Two post-aggregators are equal when they share the output name, the input field name, the estimator
+   * and the population/sample flag.
+   */
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    final StandardDeviationPostAggregator that = (StandardDeviationPostAggregator) o;
+    return Objects.equals(name, that.name)
+           && Objects.equals(fieldName, that.fieldName)
+           && Objects.equals(estimator, that.estimator)
+           && Objects.equals(isVariancePop, that.isVariancePop);
+  }
+
+  /**
+   * Hash code consistent with {@link #equals(Object)}: combines exactly the fields compared there, so equal
+   * instances behave correctly as keys in hash-based collections.
+   */
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(name, fieldName, estimator, isVariancePop);
+  }
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
index 438ee1f..935a1b8 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
@@ -19,6 +19,7 @@
package org.apache.druid.query.aggregation.variance;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.Aggregator;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
@@ -76,7 +77,9 @@ public abstract class VarianceAggregator implements Aggregator
@Override
public void aggregate()
{
- holder.add(selector.getFloat());
+ if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+ holder.add(selector.getFloat());
+ }
}
}
@@ -93,7 +96,9 @@ public abstract class VarianceAggregator implements Aggregator
@Override
public void aggregate()
{
- holder.add(selector.getLong());
+ if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+ holder.add(selector.getLong());
+ }
}
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
index 8c871ec..ae30992 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
@@ -20,12 +20,12 @@
package org.apache.druid.query.aggregation.variance;
import com.google.common.base.Preconditions;
+import org.apache.druid.common.config.NullHandling;
import org.apache.druid.query.aggregation.BufferAggregator;
import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
import org.apache.druid.segment.BaseFloatColumnValueSelector;
import org.apache.druid.segment.BaseLongColumnValueSelector;
import org.apache.druid.segment.BaseObjectColumnValueSelector;
-
import java.nio.ByteBuffer;
/**
@@ -89,15 +89,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
- float v = selector.getFloat();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
- if (count > 1) {
- double t = count * v - sum;
- double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
- buf.putDouble(position + NVARIANCE_OFFSET, variance);
+ if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+ float v = selector.getFloat();
+ long count = buf.getLong(position + COUNT_OFFSET) + 1;
+ double sum = buf.getDouble(position + SUM_OFFSET) + v;
+ buf.putLong(position, count);
+ buf.putDouble(position + SUM_OFFSET, sum);
+ if (count > 1) {
+ double t = count * v - sum;
+ double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
+ buf.putDouble(position + NVARIANCE_OFFSET, variance);
+ }
}
}
@@ -120,15 +122,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
@Override
public void aggregate(ByteBuffer buf, int position)
{
- long v = selector.getLong();
- long count = buf.getLong(position + COUNT_OFFSET) + 1;
- double sum = buf.getDouble(position + SUM_OFFSET) + v;
- buf.putLong(position, count);
- buf.putDouble(position + SUM_OFFSET, sum);
- if (count > 1) {
- double t = count * v - sum;
- double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
- buf.putDouble(position + NVARIANCE_OFFSET, variance);
+ if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+ long v = selector.getLong();
+ long count = buf.getLong(position + COUNT_OFFSET) + 1;
+ double sum = buf.getDouble(position + SUM_OFFSET) + v;
+ buf.putLong(position, count);
+ buf.putDouble(position + SUM_OFFSET, sum);
+ if (count > 1) {
+ double t = count * v - sum;
+ double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
+ buf.putDouble(position + NVARIANCE_OFFSET, variance);
+ }
}
}
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
new file mode 100644
index 0000000..f2da37a
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance.sql;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.PostAggregator;
+import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.sql.calcite.aggregation.Aggregation;
+import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
+import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.expression.Expressions;
+import org.apache.druid.sql.calcite.planner.Calcites;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
+import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
+import org.apache.druid.sql.calcite.table.RowSignature;
+
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Base class for the SQL variance and standard deviation aggregators ({@code VAR_POP}, {@code VAR_SAMP},
+ * {@code VARIANCE}, {@code STDDEV_POP}, {@code STDDEV_SAMP} and {@code STDDEV}).
+ *
+ * <p>Each concrete subclass only names the Calcite function it implements. This class translates a call into a
+ * {@link VarianceAggregatorFactory}. The standard deviation variants also get a
+ * {@link StandardDeviationPostAggregator}. {@code VAR_POP} and {@code STDDEV_POP} use the "population"
+ * estimator; all other functions use "sample".
+ */
+public abstract class BaseVarianceSqlAggregator implements SqlAggregator
+{
+  /**
+   * Translates a variance / standard deviation call into a Druid aggregation.
+   *
+   * @return the aggregation, or null if the argument cannot be converted to a Druid expression
+   * @throws IAE if the argument's type is not LONG, FLOAT or DOUBLE
+   */
+  @Nullable
+  @Override
+  public Aggregation toDruidAggregation(
+      PlannerContext plannerContext,
+      RowSignature rowSignature,
+      VirtualColumnRegistry virtualColumnRegistry,
+      RexBuilder rexBuilder,
+      String name,
+      AggregateCall aggregateCall,
+      Project project,
+      List<Aggregation> existingAggregations,
+      boolean finalizeAggregations
+  )
+  {
+    final RexNode inputOperand = Expressions.fromFieldAccess(
+        rowSignature,
+        project,
+        aggregateCall.getArgList().get(0)
+    );
+    final DruidExpression input = Expressions.toDruidExpression(
+        plannerContext,
+        rowSignature,
+        inputOperand
+    );
+    if (input == null) {
+      return null;
+    }
+
+    final SqlTypeName sqlTypeName = inputOperand.getType().getSqlTypeName();
+    final ValueType inputType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
+    final SqlAggFunction func = calciteFunction();
+
+    // Validate the input type first, so an unsupported call fails before any virtual column is requested
+    // from the registry.
+    final String inputTypeName;
+    if (inputType == ValueType.LONG) {
+      inputTypeName = "long";
+    } else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
+      // DOUBLE inputs are aggregated through the "float" input type.
+      // NOTE(review): this narrows doubles to float precision.
+      inputTypeName = "float";
+    } else {
+      throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
+    }
+
+    final List<VirtualColumn> virtualColumns = new ArrayList<>();
+    final DimensionSpec dimensionSpec;
+    if (input.isSimpleExtraction() && input.getSimpleExtraction().getExtractionFn() == null) {
+      // Plain column reference: aggregate the underlying column directly.
+      dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
+    } else {
+      // Everything else is computed by a virtual column. That includes a simple extraction that applies an
+      // extractionFn: the aggregator factory below only receives a column name, so the extractionFn would
+      // otherwise be silently dropped.
+      final VirtualColumn virtualColumn =
+          virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, sqlTypeName);
+      dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
+      virtualColumns.add(virtualColumn);
+    }
+
+    final String estimator;
+    if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
+      estimator = "population";
+    } else {
+      estimator = "sample";
+    }
+
+    final String aggName = StringUtils.format("%s:agg", name);
+    final AggregatorFactory aggregatorFactory = new VarianceAggregatorFactory(
+        aggName,
+        dimensionSpec.getDimension(),
+        estimator,
+        inputTypeName
+    );
+
+    // Standard deviation functions add a post-aggregator named after the SQL output that reads the variance
+    // aggregator. The variance functions need no post-aggregator.
+    PostAggregator postAggregator = null;
+    if (func == SqlStdOperatorTable.STDDEV_POP
+        || func == SqlStdOperatorTable.STDDEV_SAMP
+        || func == SqlStdOperatorTable.STDDEV) {
+      postAggregator = new StandardDeviationPostAggregator(
+          name,
+          aggregatorFactory.getName(),
+          estimator
+      );
+    }
+
+    return Aggregation.create(
+        virtualColumns,
+        ImmutableList.of(aggregatorFactory),
+        postAggregator
+    );
+  }
+
+  /** {@code VAR_POP}: population variance. */
+  public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VAR_POP;
+    }
+  }
+
+  /** {@code VAR_SAMP}: sample variance. */
+  public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VAR_SAMP;
+    }
+  }
+
+  /** {@code VARIANCE}: sample variance (same estimator as {@code VAR_SAMP}). */
+  public static class VarianceSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VARIANCE;
+    }
+  }
+
+  /** {@code STDDEV_POP}: population standard deviation. */
+  public static class StdDevPopSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV_POP;
+    }
+  }
+
+  /** {@code STDDEV_SAMP}: sample standard deviation. */
+  public static class StdDevSampSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV_SAMP;
+    }
+  }
+
+  /** {@code STDDEV}: sample standard deviation (same estimator as {@code STDDEV_SAMP}). */
+  public static class StdDevSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV;
+    }
+  }
+}
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
new file mode 100644
index 0000000..50b7177
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -0,0 +1,518 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance.sql;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.impl.DimensionSchema;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.DoubleDimensionSchema;
+import org.apache.druid.data.input.impl.FloatDimensionSchema;
+import org.apache.druid.data.input.impl.InputRowParser;
+import org.apache.druid.data.input.impl.LongDimensionSchema;
+import org.apache.druid.data.input.impl.MapInputRowParser;
+import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
+import org.apache.druid.data.input.impl.TimestampSpec;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.io.Closer;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.QueryRunnerFactoryConglomerate;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
+import org.apache.druid.segment.IndexBuilder;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.incremental.IncrementalIndexSchema;
+import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
+import org.apache.druid.server.security.AuthTestUtils;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.apache.druid.sql.SqlLifecycle;
+import org.apache.druid.sql.SqlLifecycleFactory;
+import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
+import org.apache.druid.sql.calcite.filtration.Filtration;
+import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
+import org.apache.druid.sql.calcite.planner.PlannerConfig;
+import org.apache.druid.sql.calcite.planner.PlannerFactory;
+import org.apache.druid.sql.calcite.schema.DruidSchema;
+import org.apache.druid.sql.calcite.schema.SystemSchema;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.apache.druid.sql.calcite.util.QueryLogHook;
+import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
+import org.apache.druid.timeline.DataSegment;
+import org.apache.druid.timeline.partition.LinearShardSpec;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.List;
+
+public class VarianceSqlAggregatorTest
+{
+ // Identity passed to every SQL query these tests run.
+ private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
+ // Datasource name of the segment built in setUp(); the test SQL selects FROM numfoo.
+ private static final String DATA_SOURCE = "numfoo";
+
+ // Query runner conglomerate shared by all tests; created in setUpClass().
+ private static QueryRunnerFactoryConglomerate conglomerate;
+ // Releases the resources created alongside the conglomerate; closed in tearDownClass().
+ private static Closer resourceCloser;
+
+  /**
+   * Creates the query runner conglomerate shared by every test in this class and keeps its closer for
+   * tearDownClass().
+   */
+  @BeforeClass
+  public static void setUpClass()
+  {
+    final Pair<QueryRunnerFactoryConglomerate, Closer> created = CalciteTests.createQueryRunnerFactoryConglomerate();
+    resourceCloser = created.rhs;
+    conglomerate = created.lhs;
+  }
+
+  /**
+   * Closes the resources created together with the conglomerate.
+   *
+   * <p>JUnit still runs this method when setUpClass() fails. In that case resourceCloser is null, and the guard
+   * avoids reporting a spurious NullPointerException alongside the real setup failure.
+   */
+  @AfterClass
+  public static void tearDownClass() throws IOException
+  {
+    if (resourceCloser != null) {
+      resourceCloser.close();
+    }
+  }
+
+ // Scratch directory the test segment is written into (see setUp()).
+ @Rule
+ public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+ // Records the native queries the SQL statements are planned into, so tests can assert on them.
+ @Rule
+ public QueryLogHook queryLogHook = QueryLogHook.create();
+
+ // Serves the single "numfoo" segment built in setUp(); closed in tearDown().
+ private SpecificSegmentsQuerySegmentWalker walker;
+ // Produces SqlLifecycle instances whose operator table contains the variance SQL aggregators.
+ private SqlLifecycleFactory sqlLifecycleFactory;
+
+  /**
+   * Builds the "numfoo" test segment and the SQL machinery that queries it.
+   *
+   * <p>The segment comes from CalciteTests.ROWS1_WITH_NUMERIC_DIMS. Its columns are the string dims dim1-dim3,
+   * a double d1, a float f1 and a long l1. The segment is registered with a segment walker. The SQL lifecycle
+   * factory uses an operator table that knows the six variance / standard deviation aggregators.
+   */
+  @Before
+  public void setUp() throws Exception
+  {
+    final List<DimensionSchema> dimensionSchemas = ImmutableList.<DimensionSchema>builder()
+        .addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
+        .add(new DoubleDimensionSchema("d1"))
+        .add(new FloatDimensionSchema("f1"))
+        .add(new LongDimensionSchema("l1"))
+        .build();
+    final InputRowParser parser = new MapInputRowParser(
+        new TimeAndDimsParseSpec(
+            new TimestampSpec("t", "iso", null),
+            new DimensionsSpec(dimensionSchemas, null, null)
+        )
+    );
+
+    final IncrementalIndexSchema indexSchema = new IncrementalIndexSchema.Builder()
+        .withMetrics(
+            new CountAggregatorFactory("cnt"),
+            new DoubleSumAggregatorFactory("m1", "m1")
+        )
+        .withDimensionsSpec(parser)
+        .withRollup(false)
+        .build();
+
+    final QueryableIndex index = IndexBuilder
+        .create()
+        .tmpDir(temporaryFolder.newFolder())
+        .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
+        .schema(indexSchema)
+        .rows(CalciteTests.ROWS1_WITH_NUMERIC_DIMS)
+        .buildMMappedIndex();
+
+    final DataSegment segment = DataSegment
+        .builder()
+        .dataSource(DATA_SOURCE)
+        .interval(index.getDataInterval())
+        .version("1")
+        .shardSpec(new LinearShardSpec(0))
+        .build();
+    walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(segment, index);
+
+    final PlannerConfig plannerConfig = new PlannerConfig();
+    final DruidSchema druidSchema = CalciteTests.createMockSchema(conglomerate, walker, plannerConfig);
+    final SystemSchema systemSchema = CalciteTests.createMockSystemSchema(druidSchema, walker, plannerConfig);
+    final DruidOperatorTable operatorTable = new DruidOperatorTable(
+        ImmutableSet.of(
+            new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
+            new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
+            new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevSqlAggregator()
+        ),
+        ImmutableSet.of()
+    );
+
+    final PlannerFactory plannerFactory = new PlannerFactory(
+        druidSchema,
+        systemSchema,
+        CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+        operatorTable,
+        CalciteTests.createExprMacroTable(),
+        plannerConfig,
+        AuthTestUtils.TEST_AUTHORIZER_MAPPER,
+        CalciteTests.getJsonMapper()
+    );
+    sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(plannerFactory);
+  }
+
+  /**
+   * Closes the segment walker created in setUp().
+   *
+   * <p>JUnit still runs this method when setUp() fails. The null guard keeps that case from adding a spurious
+   * NullPointerException to the real setup failure.
+   */
+  @After
+  public void tearDown() throws Exception
+  {
+    if (walker != null) {
+      walker.close();
+      walker = null;
+    }
+  }
+
+  /**
+   * Adds {@code raw} to {@code holder} without scaling; equivalent to {@code addToHolder(holder, raw, 1)}.
+   */
+  public void addToHolder(VarianceAggregatorCollector holder, Object raw)
+  {
+    final int noScaling = 1;
+    addToHolder(holder, raw, noScaling);
+  }
+
+  /**
+   * Feeds one raw row value into {@code holder} so tests can compute expected variances by hand.
+   *
+   * <p>Handling by type:
+   * <ul>
+   *   <li>null: counts as 0 when {@code NullHandling.replaceWithDefault()} is true, otherwise it is skipped.</li>
+   *   <li>Double: multiplied by {@code multiply}, then narrowed to float.</li>
+   *   <li>Float, Long and Integer: multiplied in their own arithmetic.</li>
+   *   <li>Any other type: ignored.</li>
+   * </ul>
+   */
+  public void addToHolder(VarianceAggregatorCollector holder, Object raw, int multiply)
+  {
+    if (raw == null) {
+      if (NullHandling.replaceWithDefault()) {
+        holder.add(0.0f);
+      }
+      return;
+    }
+
+    if (raw instanceof Double) {
+      final double scaled = ((Double) raw).doubleValue() * multiply;
+      holder.add((float) scaled);
+    } else if (raw instanceof Float) {
+      final float scaled = ((Float) raw).floatValue() * multiply;
+      holder.add(scaled);
+    } else if (raw instanceof Long) {
+      final long scaled = ((Long) raw).longValue() * multiply;
+      holder.add(scaled);
+    } else if (raw instanceof Integer) {
+      final int scaled = ((Integer) raw).intValue() * multiply;
+      holder.add(scaled);
+    }
+    // Values of any other type are ignored.
+  }
+
+  /**
+   * VAR_POP over the double, float and long columns should match a population variance computed from the raw
+   * rows. The query should plan to a timeseries query with "population" variance aggregators.
+   */
+  @Test
+  public void testVarPop() throws Exception
+  {
+    final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "VAR_POP(d1),\n"
+                       + "VAR_POP(f1),\n"
+                       + "VAR_POP(l1)\n"
+                       + "FROM numfoo";
+    final List<Object[]> results = lifecycle
+        .runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult)
+        .toList();
+
+    // Expected values: replay each column's raw values into its own collector.
+    final String[] columns = {"d1", "f1", "l1"};
+    final VarianceAggregatorCollector[] holders = {
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector()
+    };
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      for (int i = 0; i < columns.length; i++) {
+        addToHolder(holders[i], row.getRaw(columns[i]));
+      }
+    }
+    // Expected as double for d1, float for f1 and long for l1.
+    final Object[] expectedRow = {
+        holders[0].getVariance(true),
+        (float) holders[1].getVariance(true),
+        (long) holders[2].getVariance(true)
+    };
+
+    Assert.assertEquals(1, results.size());
+    Assert.assertArrayEquals(expectedRow, results.get(0));
+
+    // Planned as one timeseries query with a "population" variance aggregator per column.
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  /**
+   * VAR_SAMP over the double, float and long columns should match a sample variance computed from the raw rows.
+   * The query should plan to a timeseries query with "sample" variance aggregators.
+   */
+  @Test
+  public void testVarSamp() throws Exception
+  {
+    final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "VAR_SAMP(d1),\n"
+                       + "VAR_SAMP(f1),\n"
+                       + "VAR_SAMP(l1)\n"
+                       + "FROM numfoo";
+    final List<Object[]> results = lifecycle
+        .runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult)
+        .toList();
+
+    // Expected values: replay each column's raw values into its own collector.
+    final String[] columns = {"d1", "f1", "l1"};
+    final VarianceAggregatorCollector[] holders = {
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector()
+    };
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      for (int i = 0; i < columns.length; i++) {
+        addToHolder(holders[i], row.getRaw(columns[i]));
+      }
+    }
+    // Expected as double for d1, float for f1 and long for l1.
+    final Object[] expectedRow = {
+        holders[0].getVariance(false),
+        (float) holders[1].getVariance(false),
+        (long) holders[2].getVariance(false)
+    };
+
+    Assert.assertEquals(1, results.size());
+    Assert.assertArrayEquals(expectedRow, results.get(0));
+
+    // Planned as one timeseries query with a "sample" variance aggregator per column.
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  /**
+   * STDDEV_POP over the double, float and long columns should equal the square root of the population variance
+   * computed from the raw rows. The query should plan to "population" variance aggregators plus one
+   * StandardDeviationPostAggregator per output.
+   */
+  @Test
+  public void testStdDevPop() throws Exception
+  {
+    final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV_POP(d1),\n"
+                       + "STDDEV_POP(f1),\n"
+                       + "STDDEV_POP(l1)\n"
+                       + "FROM numfoo";
+    final List<Object[]> results = lifecycle
+        .runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult)
+        .toList();
+
+    // Expected values: replay each column's raw values into its own collector.
+    final String[] columns = {"d1", "f1", "l1"};
+    final VarianceAggregatorCollector[] holders = {
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector(),
+        new VarianceAggregatorCollector()
+    };
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      for (int i = 0; i < columns.length; i++) {
+        addToHolder(holders[i], row.getRaw(columns[i]));
+      }
+    }
+    // Expected as double for d1, float for f1 and long for l1.
+    final Object[] expectedRow = {
+        Math.sqrt(holders[0].getVariance(true)),
+        (float) Math.sqrt(holders[1].getVariance(true)),
+        (long) Math.sqrt(holders[2].getVariance(true))
+    };
+
+    Assert.assertEquals(1, results.size());
+    Assert.assertArrayEquals(expectedRow, results.get(0));
+
+    // Planned as "population" variance aggregators, each finalized by a standard deviation post-aggregator.
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+                  )
+              )
+              .postAggregators(
+                  ImmutableList.of(
+                      new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
+                      new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
+                      new StandardDeviationPostAggregator("a2", "a2:agg", "population")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  /**
+   * Checks STDDEV_SAMP on the d1, f1 and l1 columns of "numfoo".
+   *
+   * <p>The single SQL result row must equal the square root of the sample variance
+   * ({@code getVariance(false)}) accumulated locally from {@code CalciteTests.ROWS1_WITH_NUMERIC_DIMS},
+   * and the planned native query must use sample-mode variance aggregators followed by
+   * sample-mode {@link StandardDeviationPostAggregator}s.
+   */
+  @Test
+  public void testStdDevSamp() throws Exception
+  {
+    queryLogHook.clearRecordedQueries();
+    final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV_SAMP(d1),\n"
+                       + "STDDEV_SAMP(f1),\n"
+                       + "STDDEV_SAMP(l1)\n"
+                       + "FROM numfoo";
+    final List<Object[]> actualRows =
+        lifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    // Accumulate one collector per column from the raw fixture rows.
+    final VarianceAggregatorCollector d1Collector = new VarianceAggregatorCollector();
+    final VarianceAggregatorCollector f1Collector = new VarianceAggregatorCollector();
+    final VarianceAggregatorCollector l1Collector = new VarianceAggregatorCollector();
+    for (InputRow inputRow : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      addToHolder(d1Collector, inputRow.getRaw("d1"));
+      addToHolder(f1Collector, inputRow.getRaw("f1"));
+      addToHolder(l1Collector, inputRow.getRaw("l1"));
+    }
+
+    final double d1StdDev = Math.sqrt(d1Collector.getVariance(false));
+    final double f1StdDev = Math.sqrt(f1Collector.getVariance(false));
+    final double l1StdDev = Math.sqrt(l1Collector.getVariance(false));
+    // Boxed as Double / Float / Long so element-wise equals() matches the result row.
+    final List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{d1StdDev, (float) f1StdDev, (long) l1StdDev}
+    );
+
+    Assert.assertEquals(expectedRows.size(), actualRows.size());
+    for (int rowIndex = 0; rowIndex < expectedRows.size(); rowIndex++) {
+      Assert.assertArrayEquals(expectedRows.get(rowIndex), actualRows.get(rowIndex));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+                  )
+              )
+              .postAggregators(
+                  new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
+                  new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
+                  new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  /**
+   * Checks STDDEV over expressions ({@code d1*7}, {@code f1*7}, {@code l1*7}) on "numfoo".
+   *
+   * <p>The expressions must be planned as expression virtual columns v0..v2 feeding sample-mode
+   * variance aggregators and sample-mode {@link StandardDeviationPostAggregator}s. The result row
+   * must equal the square root of the sample variance of the scaled raw fixture values.
+   */
+  @Test
+  public void testStdDevWithVirtualColumns() throws Exception
+  {
+    queryLogHook.clearRecordedQueries();
+    final SqlLifecycle lifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV(d1*7),\n"
+                       + "STDDEV(f1*7),\n"
+                       + "STDDEV(l1*7)\n"
+                       + "FROM numfoo";
+    final List<Object[]> actualRows =
+        lifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    // Same multiplier as the literal in the SQL above, applied while accumulating locally.
+    final int scale = 7;
+    final VarianceAggregatorCollector d1Collector = new VarianceAggregatorCollector();
+    final VarianceAggregatorCollector f1Collector = new VarianceAggregatorCollector();
+    final VarianceAggregatorCollector l1Collector = new VarianceAggregatorCollector();
+    for (InputRow inputRow : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      addToHolder(d1Collector, inputRow.getRaw("d1"), scale);
+      addToHolder(f1Collector, inputRow.getRaw("f1"), scale);
+      addToHolder(l1Collector, inputRow.getRaw("l1"), scale);
+    }
+
+    final double d1StdDev = Math.sqrt(d1Collector.getVariance(false));
+    final double f1StdDev = Math.sqrt(f1Collector.getVariance(false));
+    final double l1StdDev = Math.sqrt(l1Collector.getVariance(false));
+    // Boxed as Double / Float / Long so element-wise equals() matches the result row.
+    final List<Object[]> expectedRows = ImmutableList.of(
+        new Object[]{d1StdDev, (float) f1StdDev, (long) l1StdDev}
+    );
+
+    Assert.assertEquals(expectedRows.size(), actualRows.size());
+    for (int rowIndex = 0; rowIndex < expectedRows.size(); rowIndex++) {
+      Assert.assertArrayEquals(expectedRows.get(rowIndex), actualRows.get(rowIndex));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .virtualColumns(
+                  BaseCalciteQueryTest.expressionVirtualColumn("v0", "(\"d1\" * 7)", ValueType.DOUBLE),
+                  BaseCalciteQueryTest.expressionVirtualColumn("v1", "(\"f1\" * 7)", ValueType.FLOAT),
+                  BaseCalciteQueryTest.expressionVirtualColumn("v2", "(\"l1\" * 7)", ValueType.LONG)
+              )
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
+                  )
+              )
+              .postAggregators(
+                  new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
+                  new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
+                  new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+}
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@druid.apache.org
For additional commands, e-mail: commits-help@druid.apache.org