You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by dz...@apache.org on 2022/01/17 15:57:48 UTC

[drill] branch master updated: DRILL-8094: Support reverse truncation for split_part udf (#2416)

This is an automated email from the ASF dual-hosted git repository.

dzamo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 55e94c4  DRILL-8094: Support reverse truncation for split_part udf (#2416)
55e94c4 is described below

commit 55e94c4e1c4a05ac7010391daea8f4f0804b0286
Author: leon <32...@qq.com>
AuthorDate: Mon Jan 17 23:57:35 2022 +0800

    DRILL-8094: Support reverse truncation for split_part udf (#2416)
    
    * DRILL-8094: Support reverse truncation for split_part udf
    
    * Fix unit tests
    
    Co-authored-by: feiteng.wtf <fe...@cainiao.com>
---
 .../drill/exec/expr/fn/impl/StringFunctions.java   |  70 +++++++++----
 .../exec/expr/fn/impl/TestStringFunctions.java     | 113 ++++++++++++++++++++-
 2 files changed, 159 insertions(+), 24 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
index 4dca322..27b0644 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
@@ -384,7 +384,8 @@ public class StringFunctions{
 
   /**
    * Return the string part at index after splitting the input string using the
-   * specified delimiter. The index must be a positive integer.
+   * specified delimiter. A positive index counts from the start (1 = first part),
+   * a negative index counts from the end (-1 = last part); zero is rejected.
    */
   @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL,
                     outputWidthCalculatorType = OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
@@ -416,16 +417,25 @@ public class StringFunctions{
 
     @Override
     public void eval() {
-      if (index.value < 1) {
+      if (index.value == 0) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("Index in split_part must be positive, value provided was "
-            + index.value).build();
+          .message("Index in split_part can not be zero").build();
       }
       String inputString = org.apache.drill.exec.expr.fn.impl.
         StringFunctionHelpers.getStringFromVarCharHolder(in);
-      int arrayIndex = index.value - 1;
-      String result =
-              (String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, "");
+      String result = "";
+      if (index.value < 0) {
+        java.util.List<String> splits = splitter.splitToList(inputString);
+        int size = splits.size();
+        int arrayIndex = size + index.value;
+        if (arrayIndex >= 0) {
+          result = (String) splits.get(arrayIndex);
+        }
+      } else {
+        int arrayIndex = index.value - 1;
+        result =
+          (String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, "");
+      }
       byte[] strBytes = result.getBytes(com.google.common.base.Charsets.UTF_8);
 
       out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
@@ -438,8 +448,10 @@ public class StringFunctions{
 
   /**
    * Return the string part from start to end after splitting the input string
-   * using the specified delimiter. The start must be a positive integer. The
-   * end is included and must be greater than or equal to the start index.
+   * using the specified delimiter. Start and end may be positive (counting from
+   * the beginning, 1 = first part) or negative (counting from the end, -1 = last
+   * part), but not zero. The end index is inclusive, must have the same sign as
+   * the start index, and must be greater than or equal to it.
    */
   @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls =
     NullHandling.NULL_IF_NULL, outputWidthCalculatorType =
@@ -476,26 +488,44 @@ public class StringFunctions{
 
     @Override
     public void eval() {
-      if (start.value < 1) {
+      if (start.value == 0) {
+        throw org.apache.drill.common.exceptions.UserException.functionError()
+          .message("Start index in split_part can not be zero, value provided was " +
+            "[start:" + start.value + "]").build();
+      }
+      if (start.value * end.value <= 0) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("Start in split_part must be positive, value provided was "
-            + start.value).build();
+          .message("End index in split_part must has the same sign as the start " +
+            "index, value provided was [start:" + start.value + ",end:" + end.value + "]").build();
       }
       if (end.value < start.value) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("End in split_part must be greater than or equal to start, " +
-            "value provided was start:" + start.value + ",end:" + end.value).build();
+          .message("End index in split_part must be greater or equal to start " +
+            "index, value provided was [start:" + start.value + ",end:" + end.value + "]").build();
       }
+
       String inputString = org.apache.drill.exec.expr.fn.impl.
         StringFunctionHelpers.getStringFromVarCharHolder(in);
-      int arrayIndex = start.value - 1;
-      java.util.Iterator<String> iterator = com.google.common.collect.Iterables
-        .limit(com.google.common.collect.Iterables.skip(splitter
-            .split(inputString), arrayIndex),end.value - start.value + 1)
-        .iterator();
+      java.util.Iterator<String> iterator = java.util.Collections.emptyIterator();
+      if (start.value < 0) {
+        java.util.List<String> splits = splitter.splitToList(inputString);
+        int size = splits.size();
+        int startIndex = size + start.value;
+        int endIndex = size + end.value + 1;
+        if (startIndex >= 0) {
+          iterator = splits.subList(startIndex, endIndex).iterator();
+        } else if (endIndex > 0) {
+          iterator = splits.subList(0, endIndex).iterator();
+        }
+      } else {
+        int arrayIndex = start.value - 1;
+        iterator = com.google.common.collect.Iterables
+          .limit(com.google.common.collect.Iterables.skip(splitter
+            .split(inputString), arrayIndex), end.value - start.value + 1)
+          .iterator();
+      }
       byte[] strBytes = joiner.join(iterator).getBytes(
         com.google.common.base.Charsets.UTF_8);
-
       out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
       out.start = 0;
       out.end = strBytes.length;
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
index f7a09ce..555323b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
@@ -70,6 +70,14 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("rty")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', -2) res1 from (values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("def")
+      .baselineValues("rty")
+      .go();
+
     // with a multi-byte splitter
     testBuilder()
         .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2) res1 from (values(1))")
@@ -78,6 +86,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("drill")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("drill")
+      .go();
+
     // going beyond the last available index, returns empty string
     testBuilder()
         .sqlQuery("select split_part('a,b,c', ',', 4) res1 from (values(1))")
@@ -86,6 +101,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -4) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("")
+      .go();
+
     // if the delimiter does not appear in the string, 1 returns the whole string
     testBuilder()
         .sqlQuery("select split_part('a,b,c', ' ', 1) res1 from (values(1))")
@@ -93,6 +115,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineColumns("res1")
         .baselineValues("a,b,c")
         .go();
+
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ' ', -1) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
   }
 
   @Test
@@ -115,6 +144,15 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("rty~@~uio")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', -2, -1) res1 from (" +
+        "values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("def~@~ghi")
+      .baselineValues("rty~@~uio")
+      .go();
+
     // with a multi-byte splitter
     testBuilder()
       .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2, 2) " +
@@ -124,6 +162,14 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("drill")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2, -2) " +
+        "res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("drill")
+      .go();
+
     // start index going beyond the last available index, returns empty string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ',', 4, 5) res1 from (values(1))")
@@ -132,6 +178,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -5, -4) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("")
+      .go();
+
     // end index going beyond the last available index, returns remaining string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ',', 1, 10) res1 from (values(1))")
@@ -140,6 +193,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("a,b,c")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -10, -1) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
+
     // if the delimiter does not appear in the string, 1 returns the whole string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ' ', 1, 2) res1 from (values(1))")
@@ -147,6 +207,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineColumns("res1")
       .baselineValues("a,b,c")
       .go();
+
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ' ', -2, -1) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
   }
 
   @Test
@@ -162,8 +229,8 @@ public class TestStringFunctions extends BaseTestQuery {
         .go();
       expectedErrorEncountered = false;
     } catch (Exception ex) {
-      assertTrue(ex.getMessage().contains("Index in split_part must be positive, " +
-        "value provided was 0"));
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("Index in split_part can not be zero"));
       expectedErrorEncountered = true;
     }
     if (!expectedErrorEncountered) {
@@ -181,8 +248,46 @@ public class TestStringFunctions extends BaseTestQuery {
         .go();
       expectedErrorEncountered = false;
     } catch (Exception ex) {
-      assertTrue(ex.getMessage().contains("End in split_part must be greater " +
-        "than or equal to start"));
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must be greater or equal to start index"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for " +
+        "split_part function");
+    }
+
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, -2) res1 from " +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must be greater or equal to start index"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for " +
+        "split_part function");
+    }
+
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, 2) res1 from " +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must has the same sign as the start index"));
       expectedErrorEncountered = true;
     }
     if (!expectedErrorEncountered) {