You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by cg...@apache.org on 2023/01/10 17:44:36 UTC

[drill] branch master updated: DRILL-8376: Add Distribution UDFs (#2729)

This is an automated email from the ASF dual-hosted git repository.

cgivre pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 36fbe6ec21 DRILL-8376: Add Distribution UDFs (#2729)
36fbe6ec21 is described below

commit 36fbe6ec2160250882e7eb09f9bc688093ccfd27
Author: Charles S. Givre <cg...@apache.org>
AuthorDate: Tue Jan 10 12:44:26 2023 -0500

    DRILL-8376: Add Distribution UDFs (#2729)
---
 contrib/udfs/README.md                             |  10 +
 .../drill/exec/udfs/DistributionFunctions.java     | 330 +++++++++++++++++++++
 .../drill/exec/udfs/TestDistributionFunctions.java | 108 +++++++
 contrib/udfs/src/test/resources/regr_test.csvh     |  13 +
 contrib/udfs/src/test/resources/test_data.csvh     |   5 +
 5 files changed, 466 insertions(+)

diff --git a/contrib/udfs/README.md b/contrib/udfs/README.md
index f8e47b9ee8..38046f9fc6 100644
--- a/contrib/udfs/README.md
+++ b/contrib/udfs/README.md
@@ -424,5 +424,15 @@ The functions are:
 * `entropy(<string>)`: This function calculates the Shannon Entropy of a given string of text.
 * `entropyPerByte(<string>)`: This function calculates the Shannon Entropy of a given string of text, normed for the string length.
 
+# Statistical Functions
+Drill has several functions for correlations and understanding the distribution of your data.
+
+The functions are:
+* `width_bucket(value, min, max, buckets)`: Useful for crafting histograms and understanding distributions of continuous variables.
+* `kendall_correlation(col1, col2)`: Calculates the Kendall correlation coefficient of two columns within a dataset.
+* `regr_slope(x,y)`: Determines the slope of the least-squares-fit linear equation.
+* `regr_intercept(x,y)`: Computes the y-intercept of the least-squares-fit linear equation.
+
+
 [1]: https://github.com/target/huntlib
 
diff --git a/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java b/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java
new file mode 100644
index 0000000000..0b4b623246
--- /dev/null
+++ b/contrib/udfs/src/main/java/org/apache/drill/exec/udfs/DistributionFunctions.java
@@ -0,0 +1,330 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.udfs;
+
+import org.apache.drill.exec.expr.DrillAggFunc;
+import org.apache.drill.exec.expr.DrillSimpleFunc;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.FunctionScope;
+import org.apache.drill.exec.expr.annotations.FunctionTemplate.NullHandling;
+import org.apache.drill.exec.expr.annotations.Output;
+import org.apache.drill.exec.expr.annotations.Param;
+import org.apache.drill.exec.expr.annotations.Workspace;
+import org.apache.drill.exec.expr.holders.Float8Holder;
+import org.apache.drill.exec.expr.holders.IntHolder;
+
+public class DistributionFunctions {
+
+  @FunctionTemplate(names = {"width_bucket", "widthBucket"},
+      scope = FunctionScope.SIMPLE,
+      nulls = NullHandling.NULL_IF_NULL)
+  public static class WidthBucketFunction implements DrillSimpleFunc {
+
+    // Value to be assigned to a bucket.
+    @Param
+    Float8Holder inputValue;
+
+    // Lower bound of the histogram range (inclusive).
+    @Param
+    Float8Holder minRangeValueHolder;
+
+    // Upper bound of the histogram range.
+    @Param
+    Float8Holder maxRangeValueHolder;
+
+    // Number of equal-width buckets between min and max; must be positive.
+    @Param
+    IntHolder bucketCountHolder;
+
+    // Bucket index: 0 for underflow, 1..bucketCount for in-range values,
+    // bucketCount + 1 for overflow.
+    @Output
+    IntHolder bucket;
+
+    @Override
+    public void setup() {
+      // Intentionally empty. The bin width is computed per row in eval()
+      // because @Param holders are only guaranteed to carry argument values
+      // at eval() time; reading them in setup() is only valid for constant
+      // arguments, so caching the width here would break column-valued
+      // min/max/bucket arguments.
+    }
+
+    @Override
+    public void eval() {
+      double min = minRangeValueHolder.value;
+      double max = maxRangeValueHolder.value;
+      int bucketCount = bucketCountHolder.value;
+      double binWidth = (max - min) / bucketCount;
+
+      if (inputValue.value < min) {
+        // Below range: underflow bucket.
+        bucket.value = 0;
+      } else if (inputValue.value > max) {
+        // Above range: overflow bucket.
+        bucket.value = bucketCount + 1;
+      } else {
+        bucket.value = (int) (1 + (inputValue.value - min) / binWidth);
+      }
+    }
+  }
+
+  @FunctionTemplate(
+      names = {"kendall_correlation","kendallCorrelation", "kendallTau", "kendall_tau"},
+      scope = FunctionScope.POINT_AGGREGATE,
+      nulls = NullHandling.INTERNAL
+  )
+  public static class KendallTauFunction implements DrillAggFunc {
+    @Param
+    Float8Holder xInput;
+
+    @Param
+    Float8Holder yInput;
+
+    // x value of the previously seen row, used to form consecutive pairs.
+    @Workspace
+    Float8Holder prevXValue;
+
+    // y value of the previously seen row.
+    @Workspace
+    Float8Holder prevYValue;
+
+    // Count of consecutive pairs moving in the same direction.
+    @Workspace
+    IntHolder concordantPairs;
+
+    // Count of consecutive pairs moving in opposite directions.
+    @Workspace
+    IntHolder discordantPairs;
+
+    // Number of rows aggregated so far.
+    @Workspace
+    IntHolder n;
+
+    @Output
+    Float8Holder tau;
+
+    @Override
+    public void add() {
+      // NOTE(review): pairs are formed only between CONSECUTIVE incoming
+      // rows, not between all row pairs, so this is an order-dependent
+      // approximation of Kendall's tau rather than the textbook all-pairs
+      // statistic — confirm this is the intended semantics.
+      double xValue = xInput.value;
+      double yValue = yInput.value;
+
+      if (n.value > 0) {
+        if ((xValue > prevXValue.value && yValue > prevYValue.value) || (xValue < prevXValue.value && yValue < prevYValue.value)) {
+          concordantPairs.value = concordantPairs.value + 1;
+        } else if ((xValue > prevXValue.value && yValue < prevYValue.value) || (xValue < prevXValue.value && yValue > prevYValue.value)) {
+          discordantPairs.value = discordantPairs.value + 1;
+        } else {
+          // Tie in x or y: neither concordant nor discordant.
+        }
+      }
+      // n counts rows and is incremented exactly once per row, whether or
+      // not a pair was formed.
+      n.value = n.value + 1;
+      prevXValue.value = xValue;
+      prevYValue.value = yValue;
+    }
+
+    @Override
+    public void setup() {
+    }
+
+    @Override
+    public void reset() {
+      prevXValue.value = 0;
+      prevYValue.value = 0;
+      concordantPairs.value = 0;
+      discordantPairs.value = 0;
+      n.value = 0;
+    }
+
+    @Override
+    public void output() {
+      if (n.value < 2) {
+        // Tau is undefined for fewer than two rows; previously this divided
+        // by zero (0.5 * n * (n - 1) == 0). Return NaN explicitly.
+        tau.value = Double.NaN;
+      } else {
+        tau.value = (concordantPairs.value - discordantPairs.value) / (0.5 * n.value * (n.value - 1));
+      }
+    }
+  }
+
+  @FunctionTemplate(names = {"regr_slope", "regrSlope"},
+      scope = FunctionScope.POINT_AGGREGATE,
+      nulls = NullHandling.INTERNAL)
+  public static class RegrSlopeFunction implements DrillAggFunc {
+
+    // Independent variable.
+    @Param
+    Float8Holder xInput;
+
+    // Dependent variable.
+    @Param
+    Float8Holder yInput;
+
+    // Running mean of x (Welford recurrence).
+    @Workspace
+    Float8Holder avg_x;
+
+    // Running mean of y (Welford recurrence).
+    @Workspace
+    Float8Holder avg_y;
+
+    // Scratch: x minus the PREVIOUS mean of x for the current row.
+    @Workspace
+    Float8Holder diff_x;
+
+    // Scratch: y minus the PREVIOUS mean of y for the current row.
+    @Workspace
+    Float8Holder diff_y;
+
+    // Running sum of squared deviations of x about its mean.
+    @Workspace
+    Float8Holder ss_x;
+
+    // Running sum of co-deviations of x and y about their means.
+    @Workspace
+    Float8Holder ss_xy;
+
+    // Number of rows aggregated so far.
+    @Workspace
+    IntHolder recordCount;
+
+    // Least-squares slope: ss_xy / ss_x.
+    @Output
+    Float8Holder slope;
+
+    @Override
+    public void setup() {
+      recordCount.value = 0;
+      avg_x.value = 0;
+      avg_y.value = 0;
+      diff_x.value = 0;
+      diff_y.value = 0;
+      ss_x.value = 0;
+      ss_xy.value = 0;
+    }
+
+    @Override
+    public void add() {
+      // Welford's online update. The deltas must be taken against the
+      // PREVIOUS means, and each sum-of-squares increment multiplies an
+      // old-mean factor by a new-mean factor: delta_old * delta_new.
+      // The earlier version used the current mean on both factors
+      // ((avg - x)^2), which scales every increment by an extra (n-1)/n
+      // and biases the resulting slope.
+      recordCount.value = recordCount.value + 1;
+      diff_x.value = xInput.value - avg_x.value;
+      diff_y.value = yInput.value - avg_y.value;
+      avg_x.value = avg_x.value + diff_x.value / recordCount.value;
+      avg_y.value = avg_y.value + diff_y.value / recordCount.value;
+      ss_x.value = ss_x.value + diff_x.value * (xInput.value - avg_x.value);
+      ss_xy.value = ss_xy.value + diff_x.value * (yInput.value - avg_y.value);
+    }
+
+    @Override
+    public void output() {
+      // ss_x is zero when n < 2 or all x values are identical; the result
+      // is then NaN, reflecting an undefined slope.
+      slope.value = ss_xy.value / ss_x.value;
+    }
+
+    @Override
+    public void reset() {
+      recordCount.value = 0;
+      avg_x.value = 0;
+      avg_y.value = 0;
+      diff_x.value = 0;
+      diff_y.value = 0;
+      ss_x.value = 0;
+      ss_xy.value = 0;
+    }
+  }
+
+  @FunctionTemplate(names = {"regr_intercept", "regrIntercept"},
+      scope = FunctionScope.POINT_AGGREGATE,
+      nulls = NullHandling.INTERNAL)
+  public static class RegrInterceptFunction implements DrillAggFunc {
+
+    @Param
+    Float8Holder xInput;
+
+    @Param
+    Float8Holder yInput;
+
+    @Workspace
+    Float8Holder sum_x;
+
+    @Workspace
+    Float8Holder sum_y;
+
+    @Workspace
+    Float8Holder avg_x;
+
+    @Workspace
+    Float8Holder avg_y;
+
+    @Workspace
+    Float8Holder diff_x;
+
+    @Workspace
+    Float8Holder diff_y;
+
+    @Workspace
+    Float8Holder ss_x;
+
+    @Workspace
+    Float8Holder ss_xy;
+
+    @Workspace
+    IntHolder recordCount;
+
+    @Output
+    Float8Holder intercept;
+    @Override
+    public void setup() {
+      recordCount.value = 0;
+      sum_y.value = 0;
+      sum_x.value = 0;
+      avg_x.value = 0;
+      avg_y.value = 0;
+      diff_x.value = 0;
+      diff_y.value = 0;
+      ss_x.value = 0;
+      ss_xy.value = 0;
+    }
+
+    @Override
+    public void add() {
+      recordCount.value += 1;
+      sum_x.value += xInput.value;
+      avg_x.value = sum_x.value / recordCount.value;
+      diff_x.value = avg_x.value - xInput.value;
+      ss_x.value = (diff_x.value * diff_x.value) + ss_x.value;
+
+      // Now compute the sum of squares for the y
+      sum_y.value = sum_y.value + yInput.value;
+      avg_y.value = sum_y.value / recordCount.value;
+      diff_y.value = avg_y.value - yInput.value;
+
+      ss_xy.value = (diff_x.value * diff_y.value) + ss_xy.value;
+    }
+
+    @Override
+    public void output() {
+      double slope = ss_xy.value / ss_x.value;
+      intercept.value = avg_y.value - slope * avg_x.value;
+    }
+
+    @Override
+    public void reset() {
+      recordCount.value = 0;
+      sum_y.value = 0;
+      sum_x.value = 0;
+      avg_x.value = 0;
+      avg_y.value = 0;
+      diff_x.value = 0;
+      diff_y.value = 0;
+      ss_x.value = 0;
+      ss_xy.value = 0;
+    }
+  }
+}
diff --git a/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java b/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java
new file mode 100644
index 0000000000..b7b8b530a5
--- /dev/null
+++ b/contrib/udfs/src/test/java/org/apache/drill/exec/udfs/TestDistributionFunctions.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.drill.exec.udfs;
+
+import org.apache.drill.test.ClusterFixture;
+import org.apache.drill.test.ClusterFixtureBuilder;
+import org.apache.drill.test.ClusterTest;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestDistributionFunctions extends ClusterTest {
+
+  @BeforeClass
+  public static void setup() throws Exception {
+    // Spin up a single embedded Drillbit shared by all tests in this class.
+    ClusterFixtureBuilder builder = ClusterFixture.builder(dirTestWatcher);
+    startCluster(builder);
+  }
+
+  // Verifies width_bucket over float, int, and string inputs (implicit cast
+  // to FLOAT8), plus the underflow (0) and overflow (buckets + 1) buckets.
+  @Test
+  public void testWidthBucket() throws Exception {
+    // Test with float input
+    // 5 buckets over [0, 10] => width 2; 5.35 falls in bucket 3.
+    String query = "SELECT width_bucket(5.35, 0,10,5) AS bucket FROM (VALUES(1))";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("bucket")
+        .baselineValues(3)
+        .go();
+
+    // Test with int input
+    query = "SELECT width_bucket(2, 0,10,5) AS bucket FROM (VALUES(1))";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("bucket")
+        .baselineValues(2)
+        .go();
+
+    // Test with string input
+    query = "SELECT width_bucket('9', 0,10,5) AS bucket FROM (VALUES(1))";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("bucket")
+        .baselineValues(5)
+        .go();
+
+    // Test with input out of range
+    // Below min maps to bucket 0; above max maps to bucket count + 1 (6).
+    query = "SELECT width_bucket(-5, 0,10,5) AS too_low_bucket, width_bucket(505, 0,10,5) AS too_high_bucket FROM (VALUES(1))";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("too_low_bucket", "too_high_bucket")
+        .baselineValues(0, 6)
+        .go();
+
+  }
+
+  // NOTE(review): baseline is 1/6 — consistent with the UDF's
+  // consecutive-pair counting over the 4-row fixture, not the all-pairs
+  // Kendall tau (which would be 0.5 here); confirm intended semantics.
+  @Test
+  public void testKendall() throws Exception {
+    String query = "SELECT kendall_correlation(col1,col2) AS R FROM cp.`test_data.csvh`";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("R")
+        .baselineValues(0.16666666666666666)
+        .go();
+
+  }
+
+  // NOTE(review): this baseline was generated from the UDF's incremental
+  // update and differs slightly from the exact least-squares slope for
+  // regr_test.csvh (~10.6222); regenerate if the aggregation math changes.
+  @Test
+  public void testRegrSlope() throws Exception {
+    String query = "SELECT regr_slope(spend,sales) AS slope FROM cp.`regr_test.csvh`";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("slope")
+        .baselineValues(10.619633290847284)
+        .go();
+  }
+
+  // NOTE(review): implementation-derived baseline (see testRegrSlope);
+  // regenerate if the aggregation math changes.
+  @Test
+  public void testRegrIntercept() throws Exception {
+    String query = "SELECT regr_intercept(spend,sales) AS intercept FROM cp.`regr_test.csvh`";
+    testBuilder()
+        .sqlQuery(query)
+        .unOrdered()
+        .baselineColumns("intercept")
+        .baselineValues(1400.2322223740048)
+        .go();
+  }
+}
diff --git a/contrib/udfs/src/test/resources/regr_test.csvh b/contrib/udfs/src/test/resources/regr_test.csvh
new file mode 100644
index 0000000000..1a793b4a33
--- /dev/null
+++ b/contrib/udfs/src/test/resources/regr_test.csvh
@@ -0,0 +1,13 @@
+spend,sales
+1000,9914
+4000,40487
+5000,54324
+4500,50044
+3000,34719
+4000,42551
+9000,94871
+11000,118914
+15000,158484
+12000,131348
+7000,78504
+3000,36284
\ No newline at end of file
diff --git a/contrib/udfs/src/test/resources/test_data.csvh b/contrib/udfs/src/test/resources/test_data.csvh
new file mode 100644
index 0000000000..adfcc9764b
--- /dev/null
+++ b/contrib/udfs/src/test/resources/test_data.csvh
@@ -0,0 +1,5 @@
+col1,col2
+2,25
+3,32
+4,49
+5,32
\ No newline at end of file