Posted to commits@druid.apache.org by fj...@apache.org on 2019/06/10 16:40:20 UTC

[incubator-druid] branch master updated: Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)

This is an automated email from the ASF dual-hosted git repository.

fjy pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/incubator-druid.git


The following commit(s) were added to refs/heads/master by this push:
     new ce591d1  Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)
ce591d1 is described below

commit ce591d14574cf0cb81d83ddf405a446aeee7648f
Author: Xue Yu <27...@qq.com>
AuthorDate: Tue Jun 11 00:40:09 2019 +0800

    Support var_pop, var_samp, stddev_pop and stddev_samp etc in sql (#7801)
    
    * support var_pop, stddev_pop etc in sql
    
    * fix sql compatible
    
    * rebase on master
    
    * update doc
---
 docs/content/querying/sql.md                       |   7 +
 extensions-core/stats/pom.xml                      |  24 +
 .../query/aggregation/stats/DruidStatsModule.java  |  11 +
 .../variance/StandardDeviationPostAggregator.java  |  29 ++
 .../aggregation/variance/VarianceAggregator.java   |   9 +-
 .../variance/VarianceBufferAggregator.java         |  42 +-
 .../variance/sql/BaseVarianceSqlAggregator.java    | 193 ++++++++
 .../variance/sql/VarianceSqlAggregatorTest.java    | 518 +++++++++++++++++++++
 8 files changed, 812 insertions(+), 21 deletions(-)

diff --git a/docs/content/querying/sql.md b/docs/content/querying/sql.md
index e16fb40..8921fe0 100644
--- a/docs/content/querying/sql.md
+++ b/docs/content/querying/sql.md
@@ -129,6 +129,13 @@ Only the COUNT aggregation can accept DISTINCT.
 |`APPROX_QUANTILE_DS(expr, probability, [k])`|Computes approximate quantiles on numeric or [Quantiles sketch](../development/extensions-core/datasketches-quantiles.html) exprs. The "probability" should be between 0 and 1 (exclusive). The `k` parameter is described in the Quantiles sketch documentation. The [DataSketches extension](../development/extensions-core/datasketches-extension.html) must be loaded to use this function.|
 |`APPROX_QUANTILE_FIXED_BUCKETS(expr, probability, numBuckets, lowerLimit, upperLimit, [outlierHandlingMode])`|Computes approximate quantiles on numeric or [fixed buckets histogram](../development/extensions-core/approximate-histograms.html#fixed-buckets-histogram) exprs. The "probability" should be between 0 and 1 (exclusive). The `numBuckets`, `lowerLimit`, `upperLimit`, and `outlierHandlingMode` parameters are described in the fixed buckets histogram documentation. The [approximate hi [...]
+|`BLOOM_FILTER(expr, numEntries)`|Computes a bloom filter from values produced by `expr`, with `numEntries` as the maximum number of distinct values before the false positive rate increases. See [bloom filter extension](../development/extensions-core/bloom-filter.html) documentation for additional details.|
+|`VAR_POP(expr)`|Computes the population variance of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+|`VAR_SAMP(expr)`|Computes the sample variance of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+|`VARIANCE(expr)`|Computes the sample variance of `expr` (equivalent to `VAR_SAMP`). See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+|`STDDEV_POP(expr)`|Computes the population standard deviation of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+|`STDDEV_SAMP(expr)`|Computes the sample standard deviation of `expr`. See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+|`STDDEV(expr)`|Computes the sample standard deviation of `expr` (equivalent to `STDDEV_SAMP`). See [stats extension](../development/extensions-core/stats.html) documentation for additional details.|
+
 
 For advice on choosing approximate aggregation functions, check out our [approximate aggregations documentation](aggregations.html#approx).
 
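The six functions in the table above reduce to two estimators: the `*_POP` forms divide the summed squared deviations by `n`, the sample forms by `n - 1`, and the `STDDEV` variants take the square root of the corresponding variance. A minimal plain-Java sketch of that arithmetic (illustrative only, not Druid code; class and method names are made up for this example):

```java
// Population vs. sample estimators behind the new SQL functions.
// Standalone arithmetic sketch, not the Druid implementation.
public class VarianceEstimators
{
  public static double varPop(double[] xs)
  {
    return sumSqDev(xs) / xs.length;        // VAR_POP: divide by n
  }

  public static double varSamp(double[] xs)
  {
    return sumSqDev(xs) / (xs.length - 1);  // VAR_SAMP / VARIANCE: divide by n - 1
  }

  public static double stddevPop(double[] xs)
  {
    return Math.sqrt(varPop(xs));           // STDDEV_POP
  }

  public static double stddevSamp(double[] xs)
  {
    return Math.sqrt(varSamp(xs));          // STDDEV_SAMP / STDDEV
  }

  // Two-pass sum of squared deviations from the mean.
  private static double sumSqDev(double[] xs)
  {
    double sum = 0;
    for (double x : xs) {
      sum += x;
    }
    double mean = sum / xs.length;
    double sq = 0;
    for (double x : xs) {
      sq += (x - mean) * (x - mean);
    }
    return sq;
  }
}
```

For `{1, 2, 3, 4}` this gives a population variance of `1.25` and a sample variance of `5/3`.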
diff --git a/extensions-core/stats/pom.xml b/extensions-core/stats/pom.xml
index 7f1b11a..d830e5e 100644
--- a/extensions-core/stats/pom.xml
+++ b/extensions-core/stats/pom.xml
@@ -41,6 +41,12 @@
             <scope>provided</scope>
         </dependency>
         <dependency>
+          <groupId>org.apache.druid</groupId>
+          <artifactId>druid-sql</artifactId>
+          <version>${project.parent.version}</version>
+          <scope>provided</scope>
+        </dependency>
+        <dependency>
             <groupId>org.apache.commons</groupId>
             <artifactId>commons-math3</artifactId>
         </dependency>
@@ -53,6 +59,15 @@
             <scope>test</scope>
             <type>test-jar</type>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.druid</groupId>
+            <artifactId>druid-server</artifactId>
+            <version>${project.parent.version}</version>
+            <scope>test</scope>
+            <type>test-jar</type>
+        </dependency>
+
         <dependency>
             <groupId>org.apache.druid</groupId>
             <artifactId>druid-processing</artifactId>
@@ -60,6 +75,15 @@
             <scope>test</scope>
             <type>test-jar</type>
         </dependency>
+
+        <dependency>
+            <groupId>org.apache.druid</groupId>
+            <artifactId>druid-sql</artifactId>
+            <version>${project.parent.version}</version>
+            <scope>test</scope>
+            <type>test-jar</type>
+        </dependency>
+
         <dependency>
             <groupId>junit</groupId>
             <artifactId>junit</artifactId>
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
index d23c524..1679073 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/stats/DruidStatsModule.java
@@ -30,7 +30,9 @@ import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregat
 import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
 import org.apache.druid.query.aggregation.variance.VarianceFoldingAggregatorFactory;
 import org.apache.druid.query.aggregation.variance.VarianceSerde;
+import org.apache.druid.query.aggregation.variance.sql.BaseVarianceSqlAggregator;
 import org.apache.druid.segment.serde.ComplexMetrics;
+import org.apache.druid.sql.guice.SqlBindings;
 
 import java.util.List;
 
@@ -55,6 +57,15 @@ public class DruidStatsModule implements DruidModule
   @Override
   public void configure(Binder binder)
   {
+    if (binder != null) {
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarPopSqlAggregator.class);
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarSampSqlAggregator.class);
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.VarianceSqlAggregator.class);
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevPopSqlAggregator.class);
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSampSqlAggregator.class);
+      SqlBindings.addAggregator(binder, BaseVarianceSqlAggregator.StdDevSqlAggregator.class);
+    }
+
     if (ComplexMetrics.getSerdeForType("variance") == null) {
       ComplexMetrics.registerSerde("variance", new VarianceSerde());
     }
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
index ed1aa21..cef9bca 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/StandardDeviationPostAggregator.java
@@ -32,6 +32,7 @@ import org.apache.druid.query.cache.CacheKeyBuilder;
 
 import java.util.Comparator;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 
 /**
@@ -121,4 +122,32 @@ public class StandardDeviationPostAggregator implements PostAggregator
         .appendBoolean(isVariancePop)
         .build();
   }
+
+  @Override
+  public boolean equals(Object o)
+  {
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+
+    StandardDeviationPostAggregator that = (StandardDeviationPostAggregator) o;
+
+    if (!Objects.equals(name, that.name)) {
+      return false;
+    }
+    if (!Objects.equals(fieldName, that.fieldName)) {
+      return false;
+    }
+    if (!Objects.equals(estimator, that.estimator)) {
+      return false;
+    }
+    if (!Objects.equals(isVariancePop, that.isVariancePop)) {
+      return false;
+    }
+
+    return true;
+  }
 }
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
index 438ee1f..935a1b8 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceAggregator.java
@@ -19,6 +19,7 @@
 
 package org.apache.druid.query.aggregation.variance;
 
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.query.aggregation.Aggregator;
 import org.apache.druid.segment.BaseFloatColumnValueSelector;
 import org.apache.druid.segment.BaseLongColumnValueSelector;
@@ -76,7 +77,9 @@ public abstract class VarianceAggregator implements Aggregator
     @Override
     public void aggregate()
     {
-      holder.add(selector.getFloat());
+      if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+        holder.add(selector.getFloat());
+      }
     }
   }
 
@@ -93,7 +96,9 @@ public abstract class VarianceAggregator implements Aggregator
     @Override
     public void aggregate()
     {
-      holder.add(selector.getLong());
+      if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+        holder.add(selector.getLong());
+      }
     }
   }
 
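The guard added to both `aggregate()` methods encodes Druid's two null-handling modes: in default-value mode (`NullHandling.replaceWithDefault()` returns true) a null input contributes the default value 0 to the variance, while in SQL-compatible mode the row is skipped entirely. A sketch of those two behaviors with hypothetical names (`valuesReachingHolder` is illustrative, not a Druid API):

```java
import java.util.ArrayList;
import java.util.List;

// Sketch of the null-handling guard in VarianceAggregator.aggregate():
// default-value mode folds nulls in as 0, SQL-compatible mode skips them.
public class NullHandlingSketch
{
  // Returns the values that would actually reach holder.add(...).
  public static List<Double> valuesReachingHolder(Double[] rows, boolean replaceWithDefault)
  {
    List<Double> added = new ArrayList<>();
    for (Double row : rows) {
      if (replaceWithDefault || row != null) {
        // In default-value mode the selector yields 0 for a null row.
        added.add(row == null ? 0.0 : row);
      }
    }
    return added;
  }
}
```

With input `{1.0, null, 3.0}`, default-value mode aggregates three values (the null as 0), while SQL-compatible mode aggregates only two; the two modes therefore produce different counts and different variances.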
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
index 8c871ec..ae30992 100644
--- a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/VarianceBufferAggregator.java
@@ -20,12 +20,12 @@
 package org.apache.druid.query.aggregation.variance;
 
 import com.google.common.base.Preconditions;
+import org.apache.druid.common.config.NullHandling;
 import org.apache.druid.query.aggregation.BufferAggregator;
 import org.apache.druid.query.monomorphicprocessing.RuntimeShapeInspector;
 import org.apache.druid.segment.BaseFloatColumnValueSelector;
 import org.apache.druid.segment.BaseLongColumnValueSelector;
 import org.apache.druid.segment.BaseObjectColumnValueSelector;
-
 import java.nio.ByteBuffer;
 
 /**
@@ -89,15 +89,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
     @Override
     public void aggregate(ByteBuffer buf, int position)
     {
-      float v = selector.getFloat();
-      long count = buf.getLong(position + COUNT_OFFSET) + 1;
-      double sum = buf.getDouble(position + SUM_OFFSET) + v;
-      buf.putLong(position, count);
-      buf.putDouble(position + SUM_OFFSET, sum);
-      if (count > 1) {
-        double t = count * v - sum;
-        double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
-        buf.putDouble(position + NVARIANCE_OFFSET, variance);
+      if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+        float v = selector.getFloat();
+        long count = buf.getLong(position + COUNT_OFFSET) + 1;
+        double sum = buf.getDouble(position + SUM_OFFSET) + v;
+        buf.putLong(position, count);
+        buf.putDouble(position + SUM_OFFSET, sum);
+        if (count > 1) {
+          double t = count * v - sum;
+          double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
+          buf.putDouble(position + NVARIANCE_OFFSET, variance);
+        }
       }
     }
 
@@ -120,15 +122,17 @@ public abstract class VarianceBufferAggregator implements BufferAggregator
     @Override
     public void aggregate(ByteBuffer buf, int position)
     {
-      long v = selector.getLong();
-      long count = buf.getLong(position + COUNT_OFFSET) + 1;
-      double sum = buf.getDouble(position + SUM_OFFSET) + v;
-      buf.putLong(position, count);
-      buf.putDouble(position + SUM_OFFSET, sum);
-      if (count > 1) {
-        double t = count * v - sum;
-        double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
-        buf.putDouble(position + NVARIANCE_OFFSET, variance);
+      if (NullHandling.replaceWithDefault() || !selector.isNull()) {
+        long v = selector.getLong();
+        long count = buf.getLong(position + COUNT_OFFSET) + 1;
+        double sum = buf.getDouble(position + SUM_OFFSET) + v;
+        buf.putLong(position, count);
+        buf.putDouble(position + SUM_OFFSET, sum);
+        if (count > 1) {
+          double t = count * v - sum;
+          double variance = buf.getDouble(position + NVARIANCE_OFFSET) + (t * t) / ((double) count * (count - 1));
+          buf.putDouble(position + NVARIANCE_OFFSET, variance);
+        }
       }
     }
 
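Both buffer aggregators maintain the same three running quantities: a count, a sum, and the scaled sum of squared deviations held at `NVARIANCE_OFFSET`. The update `t = count * v - sum; nvariance += t*t / (count * (count - 1))` is the standard one-pass recurrence (algebraically equal to Welford's `(v - oldMean) * (v - newMean)` step), after which `nvariance / count` is the population variance and `nvariance / (count - 1)` the sample variance. A standalone sketch of the recurrence, written against plain fields instead of a `ByteBuffer` (class name is this example's, not Druid's):

```java
// One-pass variance recurrence used by VarianceBufferAggregator,
// with plain fields standing in for the buffer offsets.
public class OnePassVariance
{
  private long count;       // value at COUNT_OFFSET
  private double sum;       // value at SUM_OFFSET
  private double nvariance; // value at NVARIANCE_OFFSET: sum of squared deviations

  public void add(double v)
  {
    count++;
    sum += v;
    if (count > 1) {
      double t = count * v - sum;
      nvariance += (t * t) / ((double) count * (count - 1));
    }
  }

  public double variance(boolean population)
  {
    return population ? nvariance / count : nvariance / (count - 1);
  }
}
```

Feeding `1, 2, 3, 4` yields a population variance of `1.25` and a sample variance of `5/3`, matching a two-pass computation over the same values.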
diff --git a/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
new file mode 100644
index 0000000..f2da37a
--- /dev/null
+++ b/extensions-core/stats/src/main/java/org/apache/druid/query/aggregation/variance/sql/BaseVarianceSqlAggregator.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance.sql;
+
+import com.google.common.collect.ImmutableList;
+import org.apache.calcite.rel.core.AggregateCall;
+import org.apache.calcite.rel.core.Project;
+import org.apache.calcite.rex.RexBuilder;
+import org.apache.calcite.rex.RexNode;
+import org.apache.calcite.sql.SqlAggFunction;
+import org.apache.calcite.sql.fun.SqlStdOperatorTable;
+import org.apache.calcite.sql.type.SqlTypeName;
+import org.apache.druid.java.util.common.IAE;
+import org.apache.druid.java.util.common.StringUtils;
+import org.apache.druid.query.aggregation.AggregatorFactory;
+import org.apache.druid.query.aggregation.PostAggregator;
+import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.dimension.DefaultDimensionSpec;
+import org.apache.druid.query.dimension.DimensionSpec;
+import org.apache.druid.segment.VirtualColumn;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.sql.calcite.aggregation.Aggregation;
+import org.apache.druid.sql.calcite.aggregation.SqlAggregator;
+import org.apache.druid.sql.calcite.expression.DruidExpression;
+import org.apache.druid.sql.calcite.expression.Expressions;
+import org.apache.druid.sql.calcite.planner.Calcites;
+import org.apache.druid.sql.calcite.planner.PlannerContext;
+import org.apache.druid.sql.calcite.rel.VirtualColumnRegistry;
+import org.apache.druid.sql.calcite.table.RowSignature;
+
+import javax.annotation.Nullable;
+import java.util.ArrayList;
+import java.util.List;
+
+public abstract class BaseVarianceSqlAggregator implements SqlAggregator
+{
+  @Nullable
+  @Override
+  public Aggregation toDruidAggregation(
+      PlannerContext plannerContext,
+      RowSignature rowSignature,
+      VirtualColumnRegistry virtualColumnRegistry,
+      RexBuilder rexBuilder,
+      String name,
+      AggregateCall aggregateCall,
+      Project project,
+      List<Aggregation> existingAggregations,
+      boolean finalizeAggregations
+  )
+  {
+    final RexNode inputOperand = Expressions.fromFieldAccess(
+        rowSignature,
+        project,
+        aggregateCall.getArgList().get(0)
+    );
+    final DruidExpression input = Expressions.toDruidExpression(
+        plannerContext,
+        rowSignature,
+        inputOperand
+    );
+    if (input == null) {
+      return null;
+    }
+
+    final AggregatorFactory aggregatorFactory;
+    final SqlTypeName sqlTypeName = inputOperand.getType().getSqlTypeName();
+    final ValueType inputType = Calcites.getValueTypeForSqlTypeName(sqlTypeName);
+    final List<VirtualColumn> virtualColumns = new ArrayList<>();
+    final DimensionSpec dimensionSpec;
+    final String aggName = StringUtils.format("%s:agg", name);
+    final SqlAggFunction func = calciteFunction();
+    final String estimator;
+    final String inputTypeName;
+    PostAggregator postAggregator = null;
+
+    if (input.isSimpleExtraction()) {
+      dimensionSpec = input.getSimpleExtraction().toDimensionSpec(null, inputType);
+    } else {
+      VirtualColumn virtualColumn =
+          virtualColumnRegistry.getOrCreateVirtualColumnForExpression(plannerContext, input, sqlTypeName);
+      dimensionSpec = new DefaultDimensionSpec(virtualColumn.getOutputName(), null, inputType);
+      virtualColumns.add(virtualColumn);
+    }
+
+    if (inputType == ValueType.LONG) {
+      inputTypeName = "long";
+    } else if (inputType == ValueType.FLOAT || inputType == ValueType.DOUBLE) {
+      inputTypeName = "float";
+    } else {
+      throw new IAE("VarianceSqlAggregator[%s] has invalid inputType[%s]", func, inputType);
+    }
+
+    if (func == SqlStdOperatorTable.VAR_POP || func == SqlStdOperatorTable.STDDEV_POP) {
+      estimator = "population";
+    } else {
+      estimator = "sample";
+    }
+
+    aggregatorFactory = new VarianceAggregatorFactory(
+        aggName,
+        dimensionSpec.getDimension(),
+        estimator,
+        inputTypeName
+    );
+
+    if (func == SqlStdOperatorTable.STDDEV_POP 
+        || func == SqlStdOperatorTable.STDDEV_SAMP
+        || func == SqlStdOperatorTable.STDDEV) {
+      postAggregator = new StandardDeviationPostAggregator(
+          name,
+          aggregatorFactory.getName(),
+          estimator);
+    }
+
+    return Aggregation.create(
+        virtualColumns,
+        ImmutableList.of(aggregatorFactory),
+        postAggregator
+    );
+  }
+
+  public static class VarPopSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VAR_POP;
+    }
+  }
+  
+  public static class VarSampSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VAR_SAMP;
+    }
+  }
+
+  public static class VarianceSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.VARIANCE;
+    }
+  }
+
+  public static class StdDevPopSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV_POP;
+    }
+  }
+
+  public static class StdDevSampSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV_SAMP;
+    }
+  }
+
+  public static class StdDevSqlAggregator extends BaseVarianceSqlAggregator
+  {
+    @Override
+    public SqlAggFunction calciteFunction()
+    {
+      return SqlStdOperatorTable.STDDEV;
+    }
+  }
+}
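`BaseVarianceSqlAggregator.toDruidAggregation()` makes two decisions per function: which estimator string to hand to `VarianceAggregatorFactory`, and whether to wrap the result in a `StandardDeviationPostAggregator` (the STDDEV variants aggregate variance first and take the square root at post-aggregation time). That mapping, restated as a standalone sketch (the `Func` enum and method names are this example's, not Druid's):

```java
// Restates the function -> (estimator, post-aggregator) mapping from
// BaseVarianceSqlAggregator.toDruidAggregation() as a plain enum.
public class VarianceSqlMapping
{
  public enum Func { VAR_POP, VAR_SAMP, VARIANCE, STDDEV_POP, STDDEV_SAMP, STDDEV }

  // Only the *_POP functions use the population estimator; the rest use "sample".
  public static String estimator(Func f)
  {
    return (f == Func.VAR_POP || f == Func.STDDEV_POP) ? "population" : "sample";
  }

  // STDDEV variants add a StandardDeviationPostAggregator over the variance
  // aggregate; the VAR_*/VARIANCE functions return the variance directly.
  public static boolean needsStddevPostAggregator(Func f)
  {
    return f == Func.STDDEV_POP || f == Func.STDDEV_SAMP || f == Func.STDDEV;
  }
}
```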
diff --git a/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
new file mode 100644
index 0000000..50b7177
--- /dev/null
+++ b/extensions-core/stats/src/test/java/org/apache/druid/query/aggregation/variance/sql/VarianceSqlAggregatorTest.java
@@ -0,0 +1,518 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.druid.query.aggregation.variance.sql;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Iterables;
+import org.apache.druid.common.config.NullHandling;
+import org.apache.druid.data.input.InputRow;
+import org.apache.druid.data.input.impl.DimensionSchema;
+import org.apache.druid.data.input.impl.DimensionsSpec;
+import org.apache.druid.data.input.impl.DoubleDimensionSchema;
+import org.apache.druid.data.input.impl.FloatDimensionSchema;
+import org.apache.druid.data.input.impl.InputRowParser;
+import org.apache.druid.data.input.impl.LongDimensionSchema;
+import org.apache.druid.data.input.impl.MapInputRowParser;
+import org.apache.druid.data.input.impl.TimeAndDimsParseSpec;
+import org.apache.druid.data.input.impl.TimestampSpec;
+import org.apache.druid.java.util.common.Pair;
+import org.apache.druid.java.util.common.granularity.Granularities;
+import org.apache.druid.java.util.common.io.Closer;
+import org.apache.druid.query.Druids;
+import org.apache.druid.query.QueryRunnerFactoryConglomerate;
+import org.apache.druid.query.aggregation.CountAggregatorFactory;
+import org.apache.druid.query.aggregation.DoubleSumAggregatorFactory;
+import org.apache.druid.query.aggregation.variance.StandardDeviationPostAggregator;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorCollector;
+import org.apache.druid.query.aggregation.variance.VarianceAggregatorFactory;
+import org.apache.druid.query.spec.MultipleIntervalSegmentSpec;
+import org.apache.druid.segment.IndexBuilder;
+import org.apache.druid.segment.QueryableIndex;
+import org.apache.druid.segment.column.ValueType;
+import org.apache.druid.segment.incremental.IncrementalIndexSchema;
+import org.apache.druid.segment.writeout.OffHeapMemorySegmentWriteOutMediumFactory;
+import org.apache.druid.server.security.AuthTestUtils;
+import org.apache.druid.server.security.AuthenticationResult;
+import org.apache.druid.sql.SqlLifecycle;
+import org.apache.druid.sql.SqlLifecycleFactory;
+import org.apache.druid.sql.calcite.BaseCalciteQueryTest;
+import org.apache.druid.sql.calcite.filtration.Filtration;
+import org.apache.druid.sql.calcite.planner.DruidOperatorTable;
+import org.apache.druid.sql.calcite.planner.PlannerConfig;
+import org.apache.druid.sql.calcite.planner.PlannerFactory;
+import org.apache.druid.sql.calcite.schema.DruidSchema;
+import org.apache.druid.sql.calcite.schema.SystemSchema;
+import org.apache.druid.sql.calcite.util.CalciteTests;
+import org.apache.druid.sql.calcite.util.QueryLogHook;
+import org.apache.druid.sql.calcite.util.SpecificSegmentsQuerySegmentWalker;
+import org.apache.druid.timeline.DataSegment;
+import org.apache.druid.timeline.partition.LinearShardSpec;
+import org.junit.After;
+import org.junit.AfterClass;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.io.IOException;
+import java.util.List;
+
+public class VarianceSqlAggregatorTest
+{
+  private static AuthenticationResult authenticationResult = CalciteTests.REGULAR_USER_AUTH_RESULT;
+  private static final String DATA_SOURCE = "numfoo";
+
+  private static QueryRunnerFactoryConglomerate conglomerate;
+  private static Closer resourceCloser;
+
+  @BeforeClass
+  public static void setUpClass()
+  {
+    final Pair<QueryRunnerFactoryConglomerate, Closer> conglomerateCloserPair = CalciteTests
+        .createQueryRunnerFactoryConglomerate();
+    conglomerate = conglomerateCloserPair.lhs;
+    resourceCloser = conglomerateCloserPair.rhs;
+  }
+
+  @AfterClass
+  public static void tearDownClass() throws IOException
+  {
+    resourceCloser.close();
+  }
+
+  @Rule
+  public TemporaryFolder temporaryFolder = new TemporaryFolder();
+
+  @Rule
+  public QueryLogHook queryLogHook = QueryLogHook.create();
+
+  private SpecificSegmentsQuerySegmentWalker walker;
+  private SqlLifecycleFactory sqlLifecycleFactory;
+
+  @Before
+  public void setUp() throws Exception
+  {
+    InputRowParser parser = new MapInputRowParser(
+        new TimeAndDimsParseSpec(
+            new TimestampSpec("t", "iso", null),
+            new DimensionsSpec(
+                ImmutableList.<DimensionSchema>builder()
+                    .addAll(DimensionsSpec.getDefaultSchemas(ImmutableList.of("dim1", "dim2", "dim3")))
+                    .add(new DoubleDimensionSchema("d1"))
+                    .add(new FloatDimensionSchema("f1"))
+                    .add(new LongDimensionSchema("l1"))
+                    .build(),
+                null,
+                null
+            )
+        ));
+
+    final QueryableIndex index =
+        IndexBuilder.create()
+                    .tmpDir(temporaryFolder.newFolder())
+                    .segmentWriteOutMediumFactory(OffHeapMemorySegmentWriteOutMediumFactory.instance())
+                    .schema(
+                        new IncrementalIndexSchema.Builder()
+                            .withMetrics(
+                                new CountAggregatorFactory("cnt"),
+                                new DoubleSumAggregatorFactory("m1", "m1")
+                            )
+                            .withDimensionsSpec(parser)
+                            .withRollup(false)
+                            .build()
+                    )
+                    .rows(CalciteTests.ROWS1_WITH_NUMERIC_DIMS)
+                    .buildMMappedIndex();
+
+    walker = new SpecificSegmentsQuerySegmentWalker(conglomerate).add(
+        DataSegment.builder()
+                   .dataSource(DATA_SOURCE)
+                   .interval(index.getDataInterval())
+                   .version("1")
+                   .shardSpec(new LinearShardSpec(0))
+                   .build(),
+        index
+    );
+
+    final PlannerConfig plannerConfig = new PlannerConfig();
+    final DruidSchema druidSchema = CalciteTests.createMockSchema(conglomerate, walker, plannerConfig);
+    final SystemSchema systemSchema = CalciteTests.createMockSystemSchema(druidSchema, walker, plannerConfig);
+    final DruidOperatorTable operatorTable = new DruidOperatorTable(
+        ImmutableSet.of(
+            new BaseVarianceSqlAggregator.VarPopSqlAggregator(),
+            new BaseVarianceSqlAggregator.VarSampSqlAggregator(),
+            new BaseVarianceSqlAggregator.VarianceSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevPopSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevSampSqlAggregator(),
+            new BaseVarianceSqlAggregator.StdDevSqlAggregator()
+        ),
+        ImmutableSet.of()
+    );
+
+    sqlLifecycleFactory = CalciteTests.createSqlLifecycleFactory(
+        new PlannerFactory(
+            druidSchema,
+            systemSchema,
+            CalciteTests.createMockQueryLifecycleFactory(walker, conglomerate),
+            operatorTable,
+            CalciteTests.createExprMacroTable(),
+            plannerConfig,
+            AuthTestUtils.TEST_AUTHORIZER_MAPPER,
+            CalciteTests.getJsonMapper()
+        )
+    );
+  }
+
+  @After
+  public void tearDown() throws Exception
+  {
+    walker.close();
+    walker = null;
+  }
+
+  public void addToHolder(VarianceAggregatorCollector holder, Object raw)
+  {
+    addToHolder(holder, raw, 1);
+  }
+
+  public void addToHolder(VarianceAggregatorCollector holder, Object raw, int multiply)
+  {
+    if (raw != null) {
+      if (raw instanceof Double) {
+        double v = ((Double) raw).doubleValue() * multiply;
+        holder.add((float) v);
+      } else if (raw instanceof Float) {
+        float v = ((Float) raw).floatValue() * multiply;
+        holder.add(v);
+      } else if (raw instanceof Long) {
+        long v = ((Long) raw).longValue() * multiply;
+        holder.add(v);
+      } else if (raw instanceof Integer) {
+        int v = ((Integer) raw).intValue() * multiply;
+        holder.add(v);
+      }
+    } else {
+      if (NullHandling.replaceWithDefault()) {
+        holder.add(0.0f);
+      }
+    }
+  }
+
+  @Test
+  public void testVarPop() throws Exception
+  {
+    SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "VAR_POP(d1),\n"
+                       + "VAR_POP(f1),\n"
+                       + "VAR_POP(l1)\n"
+                       + "FROM numfoo";
+
+    final List<Object[]> results =
+        sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      Object raw1 = row.getRaw("d1");
+      Object raw2 = row.getRaw("f1");
+      Object raw3 = row.getRaw("l1");
+      addToHolder(holder1, raw1);
+      addToHolder(holder2, raw2);
+      addToHolder(holder3, raw3);
+    }
+
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            holder1.getVariance(true),
+            (float) holder2.getVariance(true),
+            (long) holder3.getVariance(true),
+        }
+    );
+    Assert.assertEquals(expectedResults.size(), results.size());
+    for (int i = 0; i < expectedResults.size(); i++) {
+      Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  @Test
+  public void testVarSamp() throws Exception
+  {
+    SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "VAR_SAMP(d1),\n"
+                       + "VAR_SAMP(f1),\n"
+                       + "VAR_SAMP(l1)\n"
+                       + "FROM numfoo";
+
+    final List<Object[]> results =
+        sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      Object raw1 = row.getRaw("d1");
+      Object raw2 = row.getRaw("f1");
+      Object raw3 = row.getRaw("l1");
+      addToHolder(holder1, raw1);
+      addToHolder(holder2, raw2);
+      addToHolder(holder3, raw3);
+    }
+
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            holder1.getVariance(false),
+            (float) holder2.getVariance(false),
+            (long) holder3.getVariance(false),
+        }
+    );
+    Assert.assertEquals(expectedResults.size(), results.size());
+    for (int i = 0; i < expectedResults.size(); i++) {
+      Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  @Test
+  public void testStdDevPop() throws Exception
+  {
+    SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV_POP(d1),\n"
+                       + "STDDEV_POP(f1),\n"
+                       + "STDDEV_POP(l1)\n"
+                       + "FROM numfoo";
+
+    final List<Object[]> results =
+        sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      Object raw1 = row.getRaw("d1");
+      Object raw2 = row.getRaw("f1");
+      Object raw3 = row.getRaw("l1");
+      addToHolder(holder1, raw1);
+      addToHolder(holder2, raw2);
+      addToHolder(holder3, raw3);
+    }
+
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            Math.sqrt(holder1.getVariance(true)),
+            (float) Math.sqrt(holder2.getVariance(true)),
+            (long) Math.sqrt(holder3.getVariance(true)),
+        }
+    );
+    Assert.assertEquals(expectedResults.size(), results.size());
+    for (int i = 0; i < expectedResults.size(); i++) {
+      Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                      new VarianceAggregatorFactory("a0:agg", "d1", "population", "float"),
+                      new VarianceAggregatorFactory("a1:agg", "f1", "population", "float"),
+                      new VarianceAggregatorFactory("a2:agg", "l1", "population", "long")
+                  )
+              )
+              .postAggregators(
+                  ImmutableList.of(
+                      new StandardDeviationPostAggregator("a0", "a0:agg", "population"),
+                      new StandardDeviationPostAggregator("a1", "a1:agg", "population"),
+                      new StandardDeviationPostAggregator("a2", "a2:agg", "population")
+                  )
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  @Test
+  public void testStdDevSamp() throws Exception
+  {
+    queryLogHook.clearRecordedQueries();
+    SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV_SAMP(d1),\n"
+                       + "STDDEV_SAMP(f1),\n"
+                       + "STDDEV_SAMP(l1)\n"
+                       + "FROM numfoo";
+
+    final List<Object[]> results =
+        sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      Object raw1 = row.getRaw("d1");
+      Object raw2 = row.getRaw("f1");
+      Object raw3 = row.getRaw("l1");
+      addToHolder(holder1, raw1);
+      addToHolder(holder2, raw2);
+      addToHolder(holder3, raw3);
+    }
+
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            Math.sqrt(holder1.getVariance(false)),
+            (float) Math.sqrt(holder2.getVariance(false)),
+            (long) Math.sqrt(holder3.getVariance(false)),
+        }
+    );
+    Assert.assertEquals(expectedResults.size(), results.size());
+    for (int i = 0; i < expectedResults.size(); i++) {
+      Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .aggregators(
+                  ImmutableList.of(
+                    new VarianceAggregatorFactory("a0:agg", "d1", "sample", "float"),
+                    new VarianceAggregatorFactory("a1:agg", "f1", "sample", "float"),
+                    new VarianceAggregatorFactory("a2:agg", "l1", "sample", "long")
+                  )
+              )
+              .postAggregators(
+                  new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
+                  new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
+                  new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+
+  @Test
+  public void testStdDevWithVirtualColumns() throws Exception
+  {
+    queryLogHook.clearRecordedQueries();
+    SqlLifecycle sqlLifecycle = sqlLifecycleFactory.factorize();
+    final String sql = "SELECT\n"
+                       + "STDDEV(d1*7),\n"
+                       + "STDDEV(f1*7),\n"
+                       + "STDDEV(l1*7)\n"
+                       + "FROM numfoo";
+
+    final List<Object[]> results =
+        sqlLifecycle.runSimple(sql, BaseCalciteQueryTest.QUERY_CONTEXT_DEFAULT, authenticationResult).toList();
+
+    VarianceAggregatorCollector holder1 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder2 = new VarianceAggregatorCollector();
+    VarianceAggregatorCollector holder3 = new VarianceAggregatorCollector();
+    for (InputRow row : CalciteTests.ROWS1_WITH_NUMERIC_DIMS) {
+      Object raw1 = row.getRaw("d1");
+      Object raw2 = row.getRaw("f1");
+      Object raw3 = row.getRaw("l1");
+      addToHolder(holder1, raw1, 7);
+      addToHolder(holder2, raw2, 7);
+      addToHolder(holder3, raw3, 7);
+    }
+
+    final List<Object[]> expectedResults = ImmutableList.of(
+        new Object[]{
+            Math.sqrt(holder1.getVariance(false)),
+            (float) Math.sqrt(holder2.getVariance(false)),
+            (long) Math.sqrt(holder3.getVariance(false)),
+        }
+    );
+    Assert.assertEquals(expectedResults.size(), results.size());
+    for (int i = 0; i < expectedResults.size(); i++) {
+      Assert.assertArrayEquals(expectedResults.get(i), results.get(i));
+    }
+
+    Assert.assertEquals(
+        Druids.newTimeseriesQueryBuilder()
+              .dataSource(CalciteTests.DATASOURCE3)
+              .intervals(new MultipleIntervalSegmentSpec(ImmutableList.of(Filtration.eternity())))
+              .granularity(Granularities.ALL)
+              .virtualColumns(
+                  BaseCalciteQueryTest.expressionVirtualColumn("v0", "(\"d1\" * 7)", ValueType.DOUBLE),
+                  BaseCalciteQueryTest.expressionVirtualColumn("v1", "(\"f1\" * 7)", ValueType.FLOAT),
+                  BaseCalciteQueryTest.expressionVirtualColumn("v2", "(\"l1\" * 7)", ValueType.LONG)
+              )
+              .aggregators(
+                  ImmutableList.of(
+                    new VarianceAggregatorFactory("a0:agg", "v0", "sample", "float"),
+                    new VarianceAggregatorFactory("a1:agg", "v1", "sample", "float"),
+                    new VarianceAggregatorFactory("a2:agg", "v2", "sample", "long")
+                  )
+              )
+              .postAggregators(
+                  new StandardDeviationPostAggregator("a0", "a0:agg", "sample"),
+                  new StandardDeviationPostAggregator("a1", "a1:agg", "sample"),
+                  new StandardDeviationPostAggregator("a2", "a2:agg", "sample")
+              )
+              .context(BaseCalciteQueryTest.TIMESERIES_CONTEXT_DEFAULT)
+              .build(),
+        Iterables.getOnlyElement(queryLogHook.getRecordedQueries())
+    );
+  }
+}
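
Each test above rebuilds the expected value by hand with a VarianceAggregatorCollector and compares `getVariance(true)` (population) against `getVariance(false)` (sample). As a minimal sketch of that semantics, not Druid's actual implementation, here is an online (Welford-style) accumulator where the only difference between the two estimators is the divisor, n versus n - 1:

```python
class VarianceAccumulator:
    """Online variance accumulator, analogous in spirit to
    VarianceAggregatorCollector (this class is illustrative only)."""

    def __init__(self):
        self.count = 0
        self.mean = 0.0
        self.nvariance = 0.0  # running sum of squared deviations from the mean

    def add(self, v):
        self.count += 1
        delta = v - self.mean
        self.mean += delta / self.count
        self.nvariance += delta * (v - self.mean)

    def variance(self, population):
        if self.count == 0:
            return None
        divisor = self.count if population else self.count - 1
        if divisor == 0:
            return 0.0
        return self.nvariance / divisor


acc = VarianceAccumulator()
for v in [1.0, 2.0, 3.0, 4.0]:
    acc.add(v)

print(acc.variance(True))   # population variance: 1.25
print(acc.variance(False))  # sample variance: 5/3 ~= 1.6667
```

This mirrors why STDDEV_POP and STDDEV_SAMP in the tests share the same VarianceAggregatorFactory and differ only in the "population" versus "sample" estimator passed to the post-aggregator.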

