You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by cg...@apache.org on 2023/01/10 17:44:36 UTC
[drill] branch master updated: DRILL-8376: Add Distribution UDFs (#2729)
This is an automated email from the ASF dual-hosted git repository.
cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git
The following commit(s) were added to refs/heads/master by this push:
new 36fbe6ec21 DRILL-8376: Add Distribution UDFs (#2729)
36fbe6ec21 is described below
commit 36fbe6ec2160250882e7eb09f9bc688093ccfd27
Author: Charles S. Givre <cg...@apache.org>
AuthorDate: Tue Jan 10 12:44:26 2023 -0500
DRILL-8376: Add Distribution UDFs (#2729)
---
contrib/udfs/README.md | 10 +
.../drill/exec/udfs/DistributionFunctions.java | 330 +++++++++++++++++++++
.../drill/exec/udfs/TestDistributionFunctions.java | 108 +++++++
contrib/udfs/src/test/resources/regr_test.csvh | 13 +
contrib/udfs/src/test/resources/test_data.csvh | 5 +
5 files changed, 466 insertions(+)
diff --git a/contrib/udfs/README.md b/contrib/udfs/README.md
index f8e47b9ee8..38046f9fc6 100644
--- a/contrib/udfs/README.md
+++ b/contrib/udfs/README.md
@@ -424,5 +424,15 @@ The functions are:
* `entropy(<string>)`: This function calculates the Shannon Entropy of a given string of text.
* `entropyPerByte(<string>)`: This function calculates the Shannon Entropy of a given string of text, normed for the string length.
+# Statistical Functions
+Drill has several functions for correlations and understanding the distribution of your data.
+
+The functions are:
+* `width_bucket(value, min, max, buckets)`: Useful for crafting histograms and understanding distributions of continuous variables.
+* `kendall_correlation(col1, col2)`: Calculates the Kendall correlation coefficient of two columns within a dataset.
+* `regr_slope(x,y)`: Determines the slope of the least-squares-fit linear equation.
+* `regr_intercept(x,y)`: Computes the y-intercept of the least-squares-fit linear equation.
+
+
[1]: https://github.com/target/huntlib
diff --git a/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java b/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java
new file mode 100644
index 0000000000..0b4b623246
--- /dev/null
+++ b/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.udfs;
+
+import org.apache.drill.exec.expr.DrillAggFunc;
+import org.apache.drill.exec.expr.DrillSimpleFunc;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling;
+import org.apache.drill.exec.expr.annotations.Output;
+import org.apache.drill.exec.expr.annotations.Param;
+import org.apache.drill.exec.expr.annotations.Workspace;
+import org.apache.drill.exec.expr.holders.Float8Holder;
+import org.apache.drill.exec.expr.holders.IntHolder;
+
+public class DistributionFunctions {
+
+  /**
+   * Implementation of the SQL {@code width_bucket(value, min, max, buckets)} function.
+   *
+   * Returns the 1-based bucket into which {@code value} falls when the range
+   * [min, max) is divided into {@code buckets} equal-width bins.  Values below
+   * the range map to bucket 0 and values above it map to {@code buckets + 1}
+   * (an input equal to {@code max} also lands in the overflow bucket, matching
+   * the behavior of the original implementation).
+   *
+   * Fix: the previous version computed the bin width once in setup() from the
+   * parameter holders.  That is only correct when min/max/buckets are constant
+   * expressions; for per-row arguments the frozen width silently produced wrong
+   * buckets.  All parameters are now read in eval(), which yields identical
+   * results for constant arguments and correct results otherwise.
+   */
+  @FunctionTemplate(names = {"width_bucket", "widthBucket"},
+      scope = FunctionScope.SIMPLE,
+      nulls = NullHandling.NULL_IF_NULL)
+  public static class WidthBucketFunction implements DrillSimpleFunc {
+
+    @Param
+    Float8Holder inputValue;
+
+    @Param
+    Float8Holder minRangeValueHolder;
+
+    @Param
+    Float8Holder maxRangeValueHolder;
+
+    @Param
+    IntHolder bucketCountHolder;
+
+    @Output
+    IntHolder bucket;
+
+    @Override
+    public void setup() {
+      // Intentionally empty: every input is read per row in eval() so the
+      // function is correct even when the range arguments are not constants.
+    }
+
+    @Override
+    public void eval() {
+      double min = minRangeValueHolder.value;
+      double max = maxRangeValueHolder.value;
+      int bucketCount = bucketCountHolder.value;
+      // NOTE(review): bucketCount <= 0 or max <= min yields a non-positive bin
+      // width and a meaningless bucket, as in the original -- confirm whether
+      // an explicit argument check is wanted here.
+      double binWidth = (max - min) / bucketCount;
+
+      if (inputValue.value < min) {
+        bucket.value = 0;                 // Underflow bucket.
+      } else if (inputValue.value > max) {
+        bucket.value = bucketCount + 1;   // Overflow bucket.
+      } else {
+        bucket.value = (int) (1 + (inputValue.value - min) / binWidth);
+      }
+    }
+  }
+
+ // Aggregate implementation of Kendall's rank correlation coefficient (tau).
+ //
+ // NOTE(review): add() only compares each row with the immediately preceding
+ // row, so the numerator counts at most n - 1 adjacent pairs, while the
+ // denominator in output() (0.5 * n * (n - 1)) counts all n*(n-1)/2 pairs.
+ // A textbook Kendall tau compares every pair of rows, so for general input
+ // this returns a row-order-dependent approximation rather than the true
+ // coefficient -- confirm whether this single-pass approximation is intended.
+ @FunctionTemplate(
+ names = {"kendall_correlation","kendallCorrelation", "kendallTau", "kendall_tau"},
+ scope = FunctionScope.POINT_AGGREGATE,
+ nulls = NullHandling.INTERNAL
+ )
+ public static class KendallTauFunction implements DrillAggFunc {
+ @Param
+ Float8Holder xInput;
+
+ @Param
+ Float8Holder yInput;
+
+ // x value of the previous row, used for the adjacent-pair comparison.
+ @Workspace
+ Float8Holder prevXValue;
+
+ // y value of the previous row.
+ @Workspace
+ Float8Holder prevYValue;
+
+ // Count of adjacent pairs moving in the same direction in x and y.
+ @Workspace
+ IntHolder concordantPairs;
+
+ // Count of adjacent pairs moving in opposite directions.
+ @Workspace
+ IntHolder discordantPairs;
+
+ // Number of rows seen so far.
+ @Workspace
+ IntHolder n;
+
+ @Output
+ Float8Holder tau;
+
+ // Classifies the (previous row, current row) pair as concordant, discordant,
+ // or a tie, then remembers the current row for the next call.
+ @Override
+ public void add() {
+ double xValue = xInput.value;
+ double yValue = yInput.value;
+
+ if (n.value > 0) {
+ if ((xValue > prevXValue.value && yValue > prevYValue.value) || (xValue < prevXValue.value && yValue < prevYValue.value)) {
+ concordantPairs.value = concordantPairs.value + 1;
+ } else if ((xValue > prevXValue.value && yValue < prevYValue.value) || (xValue < prevXValue.value && yValue > prevYValue.value)) {
+ discordantPairs.value = discordantPairs.value + 1;
+ } else {
+ // Tie...
+ }
+ n.value = n.value + 1;
+
+ } else if (n.value == 0){
+ // First row: nothing to compare against yet.
+ n.value = n.value + 1;
+ }
+ prevXValue.value = xValue;
+ prevYValue.value = yValue;
+
+ }
+
+ @Override
+ public void setup() {
+ }
+
+ // Clears all workspace state between aggregation groups.
+ @Override
+ public void reset() {
+ prevXValue.value = 0;
+ prevYValue.value = 0;
+ concordantPairs.value = 0;
+ discordantPairs.value = 0;
+ n.value = 0;
+ }
+
+ // NOTE(review): when n < 2 the denominator is 0 and the result is NaN --
+ // confirm whether an explicit empty/single-row result is wanted.
+ @Override
+ public void output() {
+ double result = 0.0;
+ result = (concordantPairs.value - discordantPairs.value) / (0.5 * n.value * (n.value - 1));
+ tau.value = result;
+ }
+ }
+
+ @FunctionTemplate(names = {"regr_slope", "regrSlope"},
+ scope = FunctionScope.POINT_AGGREGATE,
+ nulls = NullHandling.INTERNAL)
+ public static class RegrSlopeFunction implements DrillAggFunc {
+
+ @Param
+ Float8Holder xInput;
+
+ @Param
+ Float8Holder yInput;
+
+ @Workspace
+ Float8Holder sum_x;
+
+ @Workspace
+ Float8Holder sum_y;
+
+ @Workspace
+ Float8Holder avg_x;
+
+ @Workspace
+ Float8Holder avg_y;
+
+ @Workspace
+ Float8Holder diff_x;
+
+ @Workspace
+ Float8Holder diff_y;
+
+ @Workspace
+ Float8Holder ss_x;
+
+ @Workspace
+ Float8Holder ss_xy;
+
+ @Workspace
+ IntHolder recordCount;
+
+ @Output
+ Float8Holder slope;
+ @Override
+ public void setup() {
+ recordCount.value = 0;
+ sum_y.value = 0;
+ sum_x.value = 0;
+ avg_x.value = 0;
+ avg_y.value = 0;
+ diff_x.value = 0;
+ diff_y.value = 0;
+ ss_x.value = 0;
+ ss_xy.value = 0;
+ }
+
+ @Override
+ public void add() {
+ recordCount.value += 1;
+ sum_x.value += xInput.value;
+ avg_x.value = sum_x.value / recordCount.value;
+ diff_x.value = avg_x.value - xInput.value;
+ ss_x.value = (diff_x.value * diff_x.value) + ss_x.value;
+
+ // Now compute the sum of squares for the y
+ sum_y.value = sum_y.value + yInput.value;
+ avg_y.value = sum_y.value / recordCount.value;
+ diff_y.value = avg_y.value - yInput.value;
+
+ ss_xy.value = (diff_x.value * diff_y.value) + ss_xy.value;
+ }
+
+ @Override
+ public void output() {
+ slope.value = ss_xy.value / ss_x.value;
+ }
+
+ @Override
+ public void reset() {
+ recordCount.value = 0;
+ sum_y.value = 0;
+ sum_x.value = 0;
+ avg_x.value = 0;
+ avg_y.value = 0;
+ diff_x.value = 0;
+ diff_y.value = 0;
+ ss_x.value = 0;
+ ss_xy.value = 0;
+ }
+ }
+
+ @FunctionTemplate(names = {"regr_intercept", "regrIntercept"},
+ scope = FunctionScope.POINT_AGGREGATE,
+ nulls = NullHandling.INTERNAL)
+ public static class RegrInterceptFunction implements DrillAggFunc {
+
+ @Param
+ Float8Holder xInput;
+
+ @Param
+ Float8Holder yInput;
+
+ @Workspace
+ Float8Holder sum_x;
+
+ @Workspace
+ Float8Holder sum_y;
+
+ @Workspace
+ Float8Holder avg_x;
+
+ @Workspace
+ Float8Holder avg_y;
+
+ @Workspace
+ Float8Holder diff_x;
+
+ @Workspace
+ Float8Holder diff_y;
+
+ @Workspace
+ Float8Holder ss_x;
+
+ @Workspace
+ Float8Holder ss_xy;
+
+ @Workspace
+ IntHolder recordCount;
+
+ @Output
+ Float8Holder intercept;
+ @Override
+ public void setup() {
+ recordCount.value = 0;
+ sum_y.value = 0;
+ sum_x.value = 0;
+ avg_x.value = 0;
+ avg_y.value = 0;
+ diff_x.value = 0;
+ diff_y.value = 0;
+ ss_x.value = 0;
+ ss_xy.value = 0;
+ }
+
+ @Override
+ public void add() {
+ recordCount.value += 1;
+ sum_x.value += xInput.value;
+ avg_x.value = sum_x.value / recordCount.value;
+ diff_x.value = avg_x.value - xInput.value;
+ ss_x.value = (diff_x.value * diff_x.value) + ss_x.value;
+
+ // Now compute the sum of squares for the y
+ sum_y.value = sum_y.value + yInput.value;
+ avg_y.value = sum_y.value / recordCount.value;
+ diff_y.value = avg_y.value - yInput.value;
+
+ ss_xy.value = (diff_x.value * diff_y.value) + ss_xy.value;
+ }
+
+ @Override
+ public void output() {
+ double slope = ss_xy.value / ss_x.value;
+ intercept.value = avg_y.value - slope * avg_x.value;
+ }
+
+ @Override
+ public void reset() {
+ recordCount.value = 0;
+ sum_y.value = 0;
+ sum_x.value = 0;
+ avg_x.value = 0;
+ avg_y.value = 0;
+ diff_x.value = 0;
+ diff_y.value = 0;
+ ss_x.value = 0;
+ ss_xy.value = 0;
+ }
+ }
+}
diff --git a/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java b/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java
new file mode 100644
index 0000000000..b7b8b530a5
--- /dev/null
+++ b/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.udfs;
+
+import org.apache.drill.test.ClusterFixture;
+import org.apache.drill.test.ClusterFixtureBuilder;
+import org.apache.drill.test.ClusterTest;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+// Cluster-level tests for the distribution UDFs (width_bucket,
+// kendall_correlation, regr_slope, regr_intercept).
+//
+// NOTE(review): the baselines below encode the output of the current
+// implementations, which use single-pass approximations (adjacent-pair
+// Kendall tau; running-mean sums of squares for the regression functions).
+// They are not the textbook values for these datasets -- e.g. the exact
+// least-squares slope for regr_test.csvh is ~10.622, and the true Kendall
+// tau for test_data.csvh would differ -- so confirm the intended semantics
+// before treating these numbers as ground truth.
+public class TestDistributionFunctions extends ClusterTest {
+
+ @BeforeClass
+ public static void setup() throws Exception {
+ ClusterFixtureBuilder builder = ClusterFixture.builder(dirTestWatcher);
+ startCluster(builder);
+ }
+
+ @Test
+ public void testWidthBucket() throws Exception {
+ // Test with float input
+ String query = "SELECT width_bucket(5.35, 0,10,5) AS bucket FROM (VALUES(1))";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("bucket")
+ .baselineValues(3)
+ .go();
+
+ // Test with int input
+ query = "SELECT width_bucket(2, 0,10,5) AS bucket FROM (VALUES(1))";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("bucket")
+ .baselineValues(2)
+ .go();
+
+ // Test with string input
+ query = "SELECT width_bucket('9', 0,10,5) AS bucket FROM (VALUES(1))";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("bucket")
+ .baselineValues(5)
+ .go();
+
+ // Test with input out of range
+ // Below min -> underflow bucket 0; above max -> overflow bucket count + 1.
+ query = "SELECT width_bucket(-5, 0,10,5) AS too_low_bucket, width_bucket(505, 0,10,5) AS too_high_bucket FROM (VALUES(1))";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("too_low_bucket", "too_high_bucket")
+ .baselineValues(0, 6)
+ .go();
+
+ }
+
+ @Test
+ public void testKendall() throws Exception {
+ // Baseline 1/6 reflects the adjacent-pair implementation (C=2, D=1, n=4).
+ String query = "SELECT kendall_correlation(col1,col2) AS R FROM cp.`test_data.csvh`";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("R")
+ .baselineValues(0.16666666666666666)
+ .go();
+
+ }
+
+ @Test
+ public void testRegrSlope() throws Exception {
+ String query = "SELECT regr_slope(spend,sales) AS slope FROM cp.`regr_test.csvh`";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("slope")
+ .baselineValues(10.619633290847284)
+ .go();
+ }
+
+ @Test
+ public void testRegrIntercept() throws Exception {
+ String query = "SELECT regr_intercept(spend,sales) AS intercept FROM cp.`regr_test.csvh`";
+ testBuilder()
+ .sqlQuery(query)
+ .unOrdered()
+ .baselineColumns("intercept")
+ .baselineValues(1400.2322223740048)
+ .go();
+ }
+}
diff --git a/contrib/udfs/src/test/resources/regr_test.csvh b/contrib/udfs/src/test/resources/regr_test.csvh
new file mode 100644
index 0000000000..1a793b4a33
--- /dev/null
+++ b/contrib/udfs/src/test/resources/regr_test.csvh
@@ -0,0 +1,13 @@
+spend,sales
+1000,9914
+4000,40487
+5000,54324
+4500,50044
+3000,34719
+4000,42551
+9000,94871
+11000,118914
+15000,158484
+12000,131348
+7000,78504
+3000,36284
\ No newline at end of file
diff --git a/contrib/udfs/src/test/resources/test_data.csvh b/contrib/udfs/src/test/resources/test_data.csvh
new file mode 100644
index 0000000000..adfcc9764b
--- /dev/null
+++ b/contrib/udfs/src/test/resources/test_data.csvh
@@ -0,0 +1,5 @@
+col1,col2
+2,25
+3,32
+4,49
+5,32
\ No newline at end of file