You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@drill.apache.org by dz...@apache.org on 2022/01/17 15:57:48 UTC
[drill] branch master updated: DRILL-8094: Support reverse truncation for split_part udf (#2416)
This is an automated email from the ASF dual-hosted git repository.
dzamo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git
The following commit(s) were added to refs/heads/master by this push:
new 55e94c4 DRILL-8094: Support reverse truncation for split_part udf (#2416)
55e94c4 is described below
commit 55e94c4e1c4a05ac7010391daea8f4f0804b0286
Author: leon <32...@qq.com>
AuthorDate: Mon Jan 17 23:57:35 2022 +0800
DRILL-8094: Support reverse truncation for split_part udf (#2416)
* DRILL-8094: Support reverse truncation for split_part udf
* fix ut
Co-authored-by: feiteng.wtf <fe...@cainiao.com>
---
.../drill/exec/expr/fn/impl/StringFunctions.java | 70 +++++++++----
.../exec/expr/fn/impl/TestStringFunctions.java | 113 ++++++++++++++++++++-
2 files changed, 159 insertions(+), 24 deletions(-)
diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
index 4dca322..27b0644 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
@@ -384,7 +384,8 @@ public class StringFunctions{
/**
* Return the string part at index after splitting the input string using the
- * specified delimiter. The index must be a positive integer.
+ * specified delimiter. The index starts at 1 or -1, counting from the beginning
+ * if it is positive and from the end if it is negative.
*/
@FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = NullHandling.NULL_IF_NULL,
outputWidthCalculatorType = OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
@@ -416,16 +417,25 @@ public class StringFunctions{
@Override
public void eval() {
- if (index.value < 1) {
+ if (index.value == 0) {
throw org.apache.drill.common.exceptions.UserException.functionError()
- .message("Index in split_part must be positive, value provided was "
- + index.value).build();
+ .message("Index in split_part can not be zero").build();
}
String inputString = org.apache.drill.exec.expr.fn.impl.
StringFunctionHelpers.getStringFromVarCharHolder(in);
- int arrayIndex = index.value - 1;
- String result =
- (String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, "");
+ String result = "";
+ if (index.value < 0) {
+ java.util.List<String> splits = splitter.splitToList(inputString);
+ int size = splits.size();
+ int arrayIndex = size + index.value;
+ if (arrayIndex >= 0) {
+ result = (String) splits.get(arrayIndex);
+ }
+ } else {
+ int arrayIndex = index.value - 1;
+ result =
+ (String) com.google.common.collect.Iterables.get(splitter.split(inputString), arrayIndex, "");
+ }
byte[] strBytes = result.getBytes(com.google.common.base.Charsets.UTF_8);
out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
@@ -438,8 +448,10 @@ public class StringFunctions{
/**
* Return the string part from start to end after splitting the input string
- * using the specified delimiter. The start must be a positive integer. The
- * end is included and must be greater than or equal to the start index.
+ * using the specified delimiter. The start and end indexes can be positive or
+ * negative, counting from the beginning if positive and from the end if negative.
+ * The end index is inclusive, must have the same sign as the start index, and
+ * must be greater than or equal to it.
*/
@FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls =
NullHandling.NULL_IF_NULL, outputWidthCalculatorType =
@@ -476,26 +488,44 @@ public class StringFunctions{
@Override
public void eval() {
- if (start.value < 1) {
+ if (start.value == 0) {
+ throw org.apache.drill.common.exceptions.UserException.functionError()
+ .message("Start index in split_part can not be zero, value provided was " +
+ "[start:" + start.value + "]").build();
+ }
+ if (start.value * end.value <= 0) {
throw org.apache.drill.common.exceptions.UserException.functionError()
- .message("Start in split_part must be positive, value provided was "
- + start.value).build();
+ .message("End index in split_part must has the same sign as the start " +
+ "index, value provided was [start:" + start.value + ",end:" + end.value + "]").build();
}
if (end.value < start.value) {
throw org.apache.drill.common.exceptions.UserException.functionError()
- .message("End in split_part must be greater than or equal to start, " +
- "value provided was start:" + start.value + ",end:" + end.value).build();
+ .message("End index in split_part must be greater or equal to start " +
+ "index, value provided was [start:" + start.value + ",end:" + end.value + "]").build();
}
+
String inputString = org.apache.drill.exec.expr.fn.impl.
StringFunctionHelpers.getStringFromVarCharHolder(in);
- int arrayIndex = start.value - 1;
- java.util.Iterator<String> iterator = com.google.common.collect.Iterables
- .limit(com.google.common.collect.Iterables.skip(splitter
- .split(inputString), arrayIndex),end.value - start.value + 1)
- .iterator();
+ java.util.Iterator<String> iterator = java.util.Collections.emptyIterator();
+ if (start.value < 0) {
+ java.util.List<String> splits = splitter.splitToList(inputString);
+ int size = splits.size();
+ int startIndex = size + start.value;
+ int endIndex = size + end.value + 1;
+ if (startIndex >= 0) {
+ iterator = splits.subList(startIndex, endIndex).iterator();
+ } else if (endIndex > 0) {
+ iterator = splits.subList(0, endIndex).iterator();
+ }
+ } else {
+ int arrayIndex = start.value - 1;
+ iterator = com.google.common.collect.Iterables
+ .limit(com.google.common.collect.Iterables.skip(splitter
+ .split(inputString), arrayIndex), end.value - start.value + 1)
+ .iterator();
+ }
byte[] strBytes = joiner.join(iterator).getBytes(
com.google.common.base.Charsets.UTF_8);
-
out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
out.start = 0;
out.end = strBytes.length;
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
index f7a09ce..555323b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
@@ -70,6 +70,14 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("rty")
.go();
+ testBuilder()
+ .sqlQuery("select split_part(a, '~@~', -2) res1 from (values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("def")
+ .baselineValues("rty")
+ .go();
+
// with a multi-byte splitter
testBuilder()
.sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2) res1 from (values(1))")
@@ -78,6 +86,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("drill")
.go();
+ testBuilder()
+ .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("drill")
+ .go();
+
// going beyond the last available index, returns empty string
testBuilder()
.sqlQuery("select split_part('a,b,c', ',', 4) res1 from (values(1))")
@@ -86,6 +101,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("")
.go();
+ testBuilder()
+ .sqlQuery("select split_part('a,b,c', ',', -4) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("")
+ .go();
+
// if the delimiter does not appear in the string, 1 returns the whole string
testBuilder()
.sqlQuery("select split_part('a,b,c', ' ', 1) res1 from (values(1))")
@@ -93,6 +115,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineColumns("res1")
.baselineValues("a,b,c")
.go();
+
+ testBuilder()
+ .sqlQuery("select split_part('a,b,c', ' ', -1) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("a,b,c")
+ .go();
}
@Test
@@ -115,6 +144,15 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("rty~@~uio")
.go();
+ testBuilder()
+ .sqlQuery("select split_part(a, '~@~', -2, -1) res1 from (" +
+ "values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("def~@~ghi")
+ .baselineValues("rty~@~uio")
+ .go();
+
// with a multi-byte splitter
testBuilder()
.sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2, 2) " +
@@ -124,6 +162,14 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("drill")
.go();
+ testBuilder()
+ .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2, -2) " +
+ "res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("drill")
+ .go();
+
// start index going beyond the last available index, returns empty string
testBuilder()
.sqlQuery("select split_part('a,b,c', ',', 4, 5) res1 from (values(1))")
@@ -132,6 +178,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("")
.go();
+ testBuilder()
+ .sqlQuery("select split_part('a,b,c', ',', -5, -4) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("")
+ .go();
+
// end index going beyond the last available index, returns remaining string
testBuilder()
.sqlQuery("select split_part('a,b,c', ',', 1, 10) res1 from (values(1))")
@@ -140,6 +193,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineValues("a,b,c")
.go();
+ testBuilder()
+ .sqlQuery("select split_part('a,b,c', ',', -10, -1) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("a,b,c")
+ .go();
+
// if the delimiter does not appear in the string, 1 returns the whole string
testBuilder()
.sqlQuery("select split_part('a,b,c', ' ', 1, 2) res1 from (values(1))")
@@ -147,6 +207,13 @@ public class TestStringFunctions extends BaseTestQuery {
.baselineColumns("res1")
.baselineValues("a,b,c")
.go();
+
+ testBuilder()
+ .sqlQuery("select split_part('a,b,c', ' ', -2, -1) res1 from (values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("a,b,c")
+ .go();
}
@Test
@@ -162,8 +229,8 @@ public class TestStringFunctions extends BaseTestQuery {
.go();
expectedErrorEncountered = false;
} catch (Exception ex) {
- assertTrue(ex.getMessage().contains("Index in split_part must be positive, " +
- "value provided was 0"));
+ assertTrue(ex.getMessage(),
+ ex.getMessage().contains("Index in split_part can not be zero"));
expectedErrorEncountered = true;
}
if (!expectedErrorEncountered) {
@@ -181,8 +248,46 @@ public class TestStringFunctions extends BaseTestQuery {
.go();
expectedErrorEncountered = false;
} catch (Exception ex) {
- assertTrue(ex.getMessage().contains("End in split_part must be greater " +
- "than or equal to start"));
+ assertTrue(ex.getMessage(),
+ ex.getMessage().contains("End index in split_part must be greater or equal to start index"));
+ expectedErrorEncountered = true;
+ }
+ if (!expectedErrorEncountered) {
+ throw new RuntimeException("Missing expected error on invalid index for " +
+ "split_part function");
+ }
+
+ try {
+ testBuilder()
+ .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, -2) res1 from " +
+ "(values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("abc")
+ .go();
+ expectedErrorEncountered = false;
+ } catch (Exception ex) {
+ assertTrue(ex.getMessage(),
+ ex.getMessage().contains("End index in split_part must be greater or equal to start index"));
+ expectedErrorEncountered = true;
+ }
+ if (!expectedErrorEncountered) {
+ throw new RuntimeException("Missing expected error on invalid index for " +
+ "split_part function");
+ }
+
+ try {
+ testBuilder()
+ .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, 2) res1 from " +
+ "(values(1))")
+ .ordered()
+ .baselineColumns("res1")
+ .baselineValues("abc")
+ .go();
+ expectedErrorEncountered = false;
+ } catch (Exception ex) {
+ assertTrue(ex.getMessage(),
+ ex.getMessage().contains("End index in split_part must has the same sign as the start index"));
expectedErrorEncountered = true;
}
if (!expectedErrorEncountered) {