This is an automated email from the ASF dual-hosted git repository.

dzamo pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new 55e94c4  DRILL-8094: Support reverse truncation for split_part udf (#2416)
55e94c4 is described below

commit 55e94c4e1c4a05ac7010391daea8f4f0804b0286
Author: leon <[email protected]>
AuthorDate: Mon Jan 17 23:57:35 2022 +0800

    DRILL-8094: Support reverse truncation for split_part udf (#2416)
    
    * DRILL-8094: Support reverse truncation for split_part udf
    
    * fix ut
    
    Co-authored-by: feiteng.wtf <[email protected]>
---
 .../drill/exec/expr/fn/impl/StringFunctions.java   |  70 +++++++++----
 .../exec/expr/fn/impl/TestStringFunctions.java     | 113 ++++++++++++++++++++-
 2 files changed, 159 insertions(+), 24 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
index 4dca322..27b0644 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
@@ -384,7 +384,8 @@ public class StringFunctions{
 
   /**
    * Return the string part at index after splitting the input string using the
-   * specified delimiter. The index must be a positive integer.
+   * specified delimiter. The index starts at 1 or -1: a positive index counts
+   * from the beginning, a negative index counts from the end.
    */
   @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = 
NullHandling.NULL_IF_NULL,
                     outputWidthCalculatorType = 
OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
@@ -416,16 +417,25 @@ public class StringFunctions{
 
     @Override
     public void eval() {
-      if (index.value < 1) {
+      if (index.value == 0) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("Index in split_part must be positive, value provided was "
-            + index.value).build();
+          .message("Index in split_part can not be zero").build();
       }
       String inputString = org.apache.drill.exec.expr.fn.impl.
         StringFunctionHelpers.getStringFromVarCharHolder(in);
-      int arrayIndex = index.value - 1;
-      String result =
-              (String) 
com.google.common.collect.Iterables.get(splitter.split(inputString), 
arrayIndex, "");
+      String result = "";
+      if (index.value < 0) {
+        java.util.List<String> splits = splitter.splitToList(inputString);
+        int size = splits.size();
+        int arrayIndex = size + index.value;
+        if (arrayIndex >= 0) {
+          result = (String) splits.get(arrayIndex);
+        }
+      } else {
+        int arrayIndex = index.value - 1;
+        result =
+          (String) 
com.google.common.collect.Iterables.get(splitter.split(inputString), 
arrayIndex, "");
+      }
       byte[] strBytes = result.getBytes(com.google.common.base.Charsets.UTF_8);
 
       out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
@@ -438,8 +448,10 @@ public class StringFunctions{
 
   /**
    * Return the string part from start to end after splitting the input string
-   * using the specified delimiter. The start must be a positive integer. The
-   * end is included and must be greater than or equal to the start index.
+   * using the specified delimiter. The start and end indexes can be positive
+   * or negative: positive values count from the beginning, negative values
+   * count from the end. The end index is inclusive, must have the same sign
+   * as the start index, and must be greater than or equal to it.
    */
   @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls =
     NullHandling.NULL_IF_NULL, outputWidthCalculatorType =
@@ -476,26 +488,44 @@ public class StringFunctions{
 
     @Override
     public void eval() {
-      if (start.value < 1) {
+      if (start.value == 0) {
+        throw org.apache.drill.common.exceptions.UserException.functionError()
+          .message("Start index in split_part can not be zero, value provided 
was " +
+            "[start:" + start.value + "]").build();
+      }
+      if (start.value * end.value <= 0) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("Start in split_part must be positive, value provided was "
-            + start.value).build();
+          .message("End index in split_part must has the same sign as the 
start " +
+            "index, value provided was [start:" + start.value + ",end:" + 
end.value + "]").build();
       }
       if (end.value < start.value) {
         throw org.apache.drill.common.exceptions.UserException.functionError()
-          .message("End in split_part must be greater than or equal to start, 
" +
-            "value provided was start:" + start.value + ",end:" + 
end.value).build();
+          .message("End index in split_part must be greater or equal to start 
" +
+            "index, value provided was [start:" + start.value + ",end:" + 
end.value + "]").build();
       }
+
       String inputString = org.apache.drill.exec.expr.fn.impl.
         StringFunctionHelpers.getStringFromVarCharHolder(in);
-      int arrayIndex = start.value - 1;
-      java.util.Iterator<String> iterator = com.google.common.collect.Iterables
-        .limit(com.google.common.collect.Iterables.skip(splitter
-            .split(inputString), arrayIndex),end.value - start.value + 1)
-        .iterator();
+      java.util.Iterator<String> iterator = 
java.util.Collections.emptyIterator();
+      if (start.value < 0) {
+        java.util.List<String> splits = splitter.splitToList(inputString);
+        int size = splits.size();
+        int startIndex = size + start.value;
+        int endIndex = size + end.value + 1;
+        if (startIndex >= 0) {
+          iterator = splits.subList(startIndex, endIndex).iterator();
+        } else if (endIndex > 0) {
+          iterator = splits.subList(0, endIndex).iterator();
+        }
+      } else {
+        int arrayIndex = start.value - 1;
+        iterator = com.google.common.collect.Iterables
+          .limit(com.google.common.collect.Iterables.skip(splitter
+            .split(inputString), arrayIndex), end.value - start.value + 1)
+          .iterator();
+      }
       byte[] strBytes = joiner.join(iterator).getBytes(
         com.google.common.base.Charsets.UTF_8);
-
       out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
       out.start = 0;
       out.end = strBytes.length;
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
index f7a09ce..555323b 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
@@ -70,6 +70,14 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("rty")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', -2) res1 from 
(values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("def")
+      .baselineValues("rty")
+      .go();
+
     // with a multi-byte splitter
     testBuilder()
         .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 
2) res1 from (values(1))")
@@ -78,6 +86,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("drill")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2) 
res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("drill")
+      .go();
+
     // going beyond the last available index, returns empty string
     testBuilder()
         .sqlQuery("select split_part('a,b,c', ',', 4) res1 from (values(1))")
@@ -86,6 +101,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("")
         .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -4) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("")
+      .go();
+
     // if the delimiter does not appear in the string, 1 returns the whole 
string
     testBuilder()
         .sqlQuery("select split_part('a,b,c', ' ', 1) res1 from (values(1))")
@@ -93,6 +115,13 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineColumns("res1")
         .baselineValues("a,b,c")
         .go();
+
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ' ', -1) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
   }
 
   @Test
@@ -115,6 +144,15 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("rty~@~uio")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', -2, -1) res1 from (" +
+        "values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("def~@~ghi")
+      .baselineValues("rty~@~uio")
+      .go();
+
     // with a multi-byte splitter
     testBuilder()
       .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2, 
2) " +
@@ -124,6 +162,14 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("drill")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', -2, 
-2) " +
+        "res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("drill")
+      .go();
+
     // start index going beyond the last available index, returns empty string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ',', 4, 5) res1 from (values(1))")
@@ -132,6 +178,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -5, -4) res1 from 
(values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("")
+      .go();
+
     // end index going beyond the last available index, returns remaining 
string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ',', 1, 10) res1 from (values(1))")
@@ -140,6 +193,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineValues("a,b,c")
       .go();
 
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', -10, -1) res1 from 
(values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
+
     // if the delimiter does not appear in the string, 1 returns the whole 
string
     testBuilder()
       .sqlQuery("select split_part('a,b,c', ' ', 1, 2) res1 from (values(1))")
@@ -147,6 +207,13 @@ public class TestStringFunctions extends BaseTestQuery {
       .baselineColumns("res1")
       .baselineValues("a,b,c")
       .go();
+
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ' ', -2, -1) res1 from 
(values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
   }
 
   @Test
@@ -162,8 +229,8 @@ public class TestStringFunctions extends BaseTestQuery {
         .go();
       expectedErrorEncountered = false;
     } catch (Exception ex) {
-      assertTrue(ex.getMessage().contains("Index in split_part must be 
positive, " +
-        "value provided was 0"));
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("Index in split_part can not be zero"));
       expectedErrorEncountered = true;
     }
     if (!expectedErrorEncountered) {
@@ -181,8 +248,46 @@ public class TestStringFunctions extends BaseTestQuery {
         .go();
       expectedErrorEncountered = false;
     } catch (Exception ex) {
-      assertTrue(ex.getMessage().contains("End in split_part must be greater " 
+
-        "than or equal to start"));
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must be greater or 
equal to start index"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for 
" +
+        "split_part function");
+    }
+
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, -2) res1 
from " +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must be greater or 
equal to start index"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for 
" +
+        "split_part function");
+    }
+
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', -1, 2) res1 
from " +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage(),
+        ex.getMessage().contains("End index in split_part must has the same 
sign as the start index"));
       expectedErrorEncountered = true;
     }
     if (!expectedErrorEncountered) {

Reply via email to