This is an automated email from the ASF dual-hosted git repository.

laurent pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/drill.git


The following commit(s) were added to refs/heads/master by this push:
     new ad3f344  DRILL-7928: Add fourth parameter for split_part udf
ad3f344 is described below

commit ad3f344ac21e0462aa82f51f648a21a0554cf368
Author: feiteng <[email protected]>
AuthorDate: Sat May 15 12:05:53 2021 +0800

    DRILL-7928: Add fourth parameter for split_part udf
---
 .../drill/exec/expr/fn/impl/StringFunctions.java   |  86 ++++++++++++++--
 .../exec/expr/fn/impl/TestStringFunctions.java     | 114 +++++++++++++++++----
 2 files changed, 175 insertions(+), 25 deletions(-)

diff --git a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
index e58e286..4dca322 100644
--- a/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
+++ b/exec/java-exec/src/main/java/org/apache/drill/exec/expr/fn/impl/StringFunctions.java
@@ -382,7 +382,10 @@ public class StringFunctions{
 
   }
 
-
+  /**
+   * Return the string part at index after splitting the input string using the
+   * specified delimiter. The index must be a positive integer.
+   */
   @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls = 
NullHandling.NULL_IF_NULL,
                     outputWidthCalculatorType = 
OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
   public static class SplitPart implements DrillSimpleFunc {
@@ -405,10 +408,6 @@ public class StringFunctions{
 
     @Override
     public void setup() {
-      if (index.value < 1) {
-        throw org.apache.drill.common.exceptions.UserException.functionError()
-            .message("Index in split_part must be positive, value provided was 
" + index.value).build();
-      }
       String split = org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.
               toStringFromUTF8(delimiter.start, delimiter.end, 
delimiter.buffer);
       splitter = com.google.common.base.Splitter.on(split);
@@ -417,8 +416,13 @@ public class StringFunctions{
 
     @Override
     public void eval() {
-      String inputString =
-              
org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.toStringFromUTF8(in.start,
 in.end, in.buffer);
+      if (index.value < 1) {
+        throw org.apache.drill.common.exceptions.UserException.functionError()
+          .message("Index in split_part must be positive, value provided was "
+            + index.value).build();
+      }
+      String inputString = org.apache.drill.exec.expr.fn.impl.
+        StringFunctionHelpers.getStringFromVarCharHolder(in);
       int arrayIndex = index.value - 1;
       String result =
               (String) 
com.google.common.collect.Iterables.get(splitter.split(inputString), 
arrayIndex, "");
@@ -432,6 +436,74 @@ public class StringFunctions{
 
   }
 
+  /**
+   * Return the string part from start to end after splitting the input string
+   * using the specified delimiter. The start must be a positive integer. The
+   * end is included and must be greater than or equal to the start index.
+   */
+  @FunctionTemplate(name = "split_part", scope = FunctionScope.SIMPLE, nulls =
+    NullHandling.NULL_IF_NULL, outputWidthCalculatorType =
+    OutputWidthCalculatorType.CUSTOM_FIXED_WIDTH_DEFAULT)
+  public static class SplitPartStartEnd implements DrillSimpleFunc {
+    @Param
+    VarCharHolder in;
+    @Param
+    VarCharHolder delimiter;
+    @Param
+    IntHolder start;
+    @Param
+    IntHolder end;
+
+    @Workspace
+    com.google.common.base.Splitter splitter;
+
+    @Workspace
+    com.google.common.base.Joiner joiner;
+
+    @Inject
+    DrillBuf buffer;
+
+    @Output
+    VarCharHolder out;
+
+    @Override
+    public void setup() {
+      String split = org.apache.drill.exec.expr.fn.impl.StringFunctionHelpers.
+        toStringFromUTF8(delimiter.start, delimiter.end, delimiter.buffer);
+      splitter = com.google.common.base.Splitter.on(split);
+      joiner = com.google.common.base.Joiner.on(split);
+    }
+
+    @Override
+    public void eval() {
+      if (start.value < 1) {
+        throw org.apache.drill.common.exceptions.UserException.functionError()
+          .message("Start in split_part must be positive, value provided was "
+            + start.value).build();
+      }
+      if (end.value < start.value) {
+        throw org.apache.drill.common.exceptions.UserException.functionError()
+          .message("End in split_part must be greater than or equal to start, 
" +
+            "value provided was start:" + start.value + ",end:" + 
end.value).build();
+      }
+      String inputString = org.apache.drill.exec.expr.fn.impl.
+        StringFunctionHelpers.getStringFromVarCharHolder(in);
+      int arrayIndex = start.value - 1;
+      java.util.Iterator<String> iterator = com.google.common.collect.Iterables
+        .limit(com.google.common.collect.Iterables.skip(splitter
+            .split(inputString), arrayIndex),end.value - start.value + 1)
+        .iterator();
+      byte[] strBytes = joiner.join(iterator).getBytes(
+        com.google.common.base.Charsets.UTF_8);
+
+      out.buffer = buffer = buffer.reallocIfNeeded(strBytes.length);
+      out.start = 0;
+      out.end = strBytes.length;
+      out.buffer.setBytes(0, strBytes);
+    }
+
+  }
+
   // same as function "position(substr, str) ", except the reverse order of 
argument.
   @FunctionTemplate(name = "strpos", scope = FunctionScope.SIMPLE, nulls = 
NullHandling.NULL_IF_NULL)
   public static class Strpos implements DrillSimpleFunc {
diff --git a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
index 8965edf..f7a09ce 100644
--- a/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
+++ b/exec/java-exec/src/test/java/org/apache/drill/exec/expr/fn/impl/TestStringFunctions.java
@@ -70,24 +70,6 @@ public class TestStringFunctions extends BaseTestQuery {
         .baselineValues("rty")
         .go();
 
-    // invalid index
-    boolean expectedErrorEncountered;
-    try {
-      testBuilder()
-          .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 0) res1 from 
(values(1))")
-          .ordered()
-          .baselineColumns("res1")
-          .baselineValues("abc")
-          .go();
-      expectedErrorEncountered = false;
-    } catch (Exception ex) {
-      assertTrue(ex.getMessage().contains("Index in split_part must be 
positive, value provided was 0"));
-      expectedErrorEncountered = true;
-    }
-    if (!expectedErrorEncountered) {
-      throw new RuntimeException("Missing expected error on invalid index for 
split_part function");
-    }
-
     // with a multi-byte splitter
     testBuilder()
         .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 
2) res1 from (values(1))")
@@ -114,6 +96,102 @@ public class TestStringFunctions extends BaseTestQuery {
   }
 
   @Test
+  public void testSplitPartStartEnd() throws Exception {
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', 1, 2) res1 from (" +
+        "values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("abc~@~def")
+      .baselineValues("qwe~@~rty")
+      .go();
+
+    testBuilder()
+      .sqlQuery("select split_part(a, '~@~', 2, 3) res1 from (" +
+        "values('abc~@~def~@~ghi'), ('qwe~@~rty~@~uio')) as t(a)")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("def~@~ghi")
+      .baselineValues("rty~@~uio")
+      .go();
+
+    // with a multi-byte splitter
+    testBuilder()
+      .sqlQuery("select split_part('abc\\u1111drill\\u1111ghi', '\\u1111', 2, 
2) " +
+        "res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("drill")
+      .go();
+
+    // start index going beyond the last available index, returns empty string
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', 4, 5) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("")
+      .go();
+
+    // end index going beyond the last available index, returns remaining 
string
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ',', 1, 10) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
+
+    // if the delimiter does not appear in the string, 1 returns the whole 
string
+    testBuilder()
+      .sqlQuery("select split_part('a,b,c', ' ', 1, 2) res1 from (values(1))")
+      .ordered()
+      .baselineColumns("res1")
+      .baselineValues("a,b,c")
+      .go();
+  }
+
+  @Test
+  public void testInvalidSplitPartParameters() {
+    boolean expectedErrorEncountered;
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 0) res1 from " +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage().contains("Index in split_part must be 
positive, " +
+        "value provided was 0"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for 
" +
+        "split_part function");
+    }
+
+    try {
+      testBuilder()
+        .sqlQuery("select split_part('abc~@~def~@~ghi', '~@~', 2, 1) res1 from 
" +
+          "(values(1))")
+        .ordered()
+        .baselineColumns("res1")
+        .baselineValues("abc")
+        .go();
+      expectedErrorEncountered = false;
+    } catch (Exception ex) {
+      assertTrue(ex.getMessage().contains("End in split_part must be greater " 
+
+        "than or equal to start"));
+      expectedErrorEncountered = true;
+    }
+    if (!expectedErrorEncountered) {
+      throw new RuntimeException("Missing expected error on invalid index for 
" +
+        "split_part function");
+    }
+  }
+
+  @Test
   public void testRegexpMatches() throws Exception {
     testBuilder()
         .sqlQuery("select regexp_matches(a, '^a.*') res1, regexp_matches(b, 
'^a.*') res2 " +

Reply via email to