amaliujia commented on a change in pull request #35352:
URL: https://github.com/apache/spark/pull/35352#discussion_r829353129




##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -2095,7 +2095,9 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+    failOnError: Boolean = SQLConf.get.ansiEnabled,
+    // The value to return if the index is out of bounds
+    defaultValueOutOfBound: Any = null)

Review comment:
   Is there an example of how to create a `Literal("")` as an `Expression` for `defaultValueOutOfBound`?
   
   I tried `Literal("")`, etc. It always gives me this error: 
   
   ```
   requirement failed: Literal must have a corresponding value to string, but class String found.
   java.lang.IllegalArgumentException: requirement failed: Literal must have a corresponding value to string, but class String found.
        at scala.Predef$.require(Predef.scala:281)
        at org.apache.spark.sql.catalyst.expressions.Literal$.validateLiteralValue(literals.scala:242)
        at org.apache.spark.sql.catalyst.expressions.Literal.<init>(literals.scala:331)
        at org.apache.spark.sql.catalyst.expressions.SplitPart.replacement$lzycompute(stringExpressions.scala:3017)
        at org.apache.spark.sql.catalyst.expressions.SplitPart.replacement(stringExpressions.scala:3016)
        at org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable.dataType(Expression.scala:361)
        at org.apache.spark.sql.catalyst.expressions.RuntimeReplaceable.dataType$(Expression.scala:361)
        at org.apache.spark.sql.catalyst.expressions.SplitPart.dataType(stringExpressions.scala:3011)
        at org.apache.spark.sql.catalyst.expressions.Alias.toAttribute(namedExpressions.scala:197)
        at org.apache.spark.sql.catalyst.plans.logical.Project.$anonfun$output$1(basicLogicalOperators.scala:70)
        at scala.collection.immutable.List.map(List.scala:293)
   ```
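   
   For context, the stack trace points at the `Literal` case-class constructor (`Literal.<init>`), which validates that the value is already in Spark's internal representation (`UTF8String` for string types). A minimal sketch of two ways to build an empty-string literal that should pass that check, assuming only the standard catalyst `Literal` API rather than anything specific to this PR:
   
   ```scala
   import org.apache.spark.sql.catalyst.expressions.Literal
   import org.apache.spark.sql.types.StringType
   import org.apache.spark.unsafe.types.UTF8String

   // Literal.create converts the external Java value ("") to the internal
   // representation before constructing the Literal.
   val empty1 = Literal.create("", StringType)

   // Or hand the constructor the internal value directly, so that
   // validateLiteralValue sees a UTF8String rather than a java.lang.String.
   val empty2 = Literal(UTF8String.fromString(""), StringType)
   ```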

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -2095,7 +2095,9 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+    failOnError: Boolean = SQLConf.get.ansiEnabled,
+    // The value to return if the index is out of bounds
+    defaultValueOutOfBound: Any = null)

Review comment:
   I think the existing code base has been using `Literal("")` for an empty string as an `Expression`. Not sure why it fails in my case when doing the replacement for `split_part`.

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
##########
@@ -2943,3 +2943,85 @@ case class Sentences(
     copy(str = newFirst, language = newSecond, country = newThird)
 
 }
+
+/**
+ * Splits a given string by a specified delimiter.
+ */
+case class SplitByDelimiter(
+    str: Expression,
+    delimiter: Expression)
+  extends BinaryExpression with NullIntolerant {
+  override def dataType: DataType = ArrayType(StringType, containsNull = false)
+  override def left: Expression = str
+  override def right: Expression = delimiter
+
+  override def nullSafeEval(string: Any, delimiter: Any): Any = {
+    val strings = {
+      // If the delimiter is an empty string, skip the regex-based splitting:
+      // a regex treats the empty string as matching anything, so use the
+      // input directly.
+      if (delimiter.asInstanceOf[UTF8String].numBytes() == 0) {
+        Array(string)
+      } else {
+        string.asInstanceOf[UTF8String].split_delimiter(
+          delimiter.asInstanceOf[UTF8String], -1)
+      }
+    }
+    new GenericArrayData(strings.asInstanceOf[Array[Any]])
+  }
+
+  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val arrayClass = classOf[GenericArrayData].getName
+    nullSafeCodeGen(ctx, ev, (str, delimiter) => {
+      // `str` and `delimiter` are the names of variables in the generated Java
+      // code, so the empty-delimiter check must be emitted into that code
+      // rather than evaluated here at codegen time.
+      // Arrays in Java are covariant, so we don't need to cast UTF8String[]
+      // to Object[].
+      s"""
+         |if ($delimiter.numBytes() == 0) {
+         |  ${ev.value} = new $arrayClass(new UTF8String[]{$str});
+         |} else {
+         |  ${ev.value} = new $arrayClass($str.split_delimiter($delimiter, -1));
+         |}
+       """.stripMargin
+    })
+  }
+
+  override protected def withNewChildrenInternal(
+    newFirst: Expression, newSecond: Expression): SplitByDelimiter =
+    copy(str = newFirst, delimiter = newSecond)
+}
+
+/**
+ * Splits a given string by a specified delimiter and returns the requested part.
+ * If any input is null, or index is out of range of split parts, returns null.
+ * If index is 0, throws an ArrayIndexOutOfBoundsException.
+ */
+@ExpressionDescription(
+  usage =
+    """
+    _FUNC_(str, delimiter, partNum) - Splits `str` by delimiter and returns the
+      requested part of the split (1-based). If any input is null, returns null.
+      If `partNum` is out of range of split parts, returns null. If `partNum` is 0,

Review comment:
   Oops, this comment is stale. I will update it.
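   
   As a quick illustration of the semantics above, a rough usage sketch of `SplitByDelimiter` in interpreted (eval) mode; it assumes the `split_delimiter` helper this PR adds to `UTF8String`, so it only runs on top of this change:
   
   ```scala
   import org.apache.spark.sql.catalyst.expressions.Literal

   // Non-empty delimiter: each part comes back as a UTF8String element.
   val parts = SplitByDelimiter(Literal("11ab12ab13"), Literal("ab")).eval()
   // parts is a GenericArrayData holding ["11", "12", "13"]

   // Empty delimiter: the whole input is returned as a single element.
   val whole = SplitByDelimiter(Literal("11.12.13"), Literal("")).eval()
   // whole is a GenericArrayData holding ["11.12.13"]
   ```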

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -2095,7 +2095,9 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+    failOnError: Boolean = SQLConf.get.ansiEnabled,
+    // The value to return if the index is out of bounds
+    defaultValueOutOfBound: Any = null)

Review comment:
   I updated my change. Hopefully you can give me some suggestions.

##########
File path: sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collectionOperations.scala
##########
@@ -2095,7 +2095,9 @@ case class ArrayPosition(left: Expression, right: Expression)
 case class ElementAt(
     left: Expression,
     right: Expression,
-    failOnError: Boolean = SQLConf.get.ansiEnabled)
+    failOnError: Boolean = SQLConf.get.ansiEnabled,
+    // The value to return if the index is out of bounds
+    defaultValueOutOfBound: Any = null)

Review comment:
   I pushed my change. Hopefully you can give me some suggestions.

##########
File path: sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
##########
@@ -661,4 +661,53 @@ class StringFunctionsSuite extends QueryTest with SharedSparkSession {
     }.getMessage
     assert(m.contains("data type mismatch: argument 1 requires string type"))
   }
+
+  test("SPARK-38063: string split_part function") {
+    checkAnswer(
+      sql("select split_part('11,12,13', ',', 1)"),
+      Row("11"))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', 2)"),
+      Row("12"))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', -1)"),
+      Row("13"))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', -3)"),
+      Row("11"))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', 4)"),
+      Row(""))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', 5)"),
+      Row(""))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '.', -5)"),
+      Row(""))
+
+    checkAnswer(
+      sql("select split_part('11.12.13', '', 1)"),
+      Row("11.12.13"))
+
+    checkAnswer(
+      sql("select split_part('11ab12ab13', 'ab', 1)"),
+      Row("11"))
+
+    val m = intercept[ArrayIndexOutOfBoundsException] {
+      checkAnswer(
+        sql("select split_part('11.12.13', '.', 0)"),
+        Row("11"))
+    }.getMessage
+    assert(m.contains("SQL array indices start at 1"))

Review comment:
   Hmm, is this what the code does even when ANSI mode is off? Maybe a bug?
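   
   For context, a paraphrased sketch of how the `ElementAt`-style index handling appears to be shaped (the function name and error messages here are assumptions, not the verbatim code): the `index == 0` check is not gated on `failOnError`, so it fires regardless of ANSI mode, while only the out-of-range case consults it.
   
   ```scala
   // Paraphrased sketch, not the verbatim ElementAt code: index 0 is rejected
   // unconditionally (hence the error even with ANSI off), while only the
   // out-of-range case consults failOnError.
   def resolveIndex(numElements: Int, index: Int, failOnError: Boolean): Option[Int] = {
     if (index == 0) {
       throw new ArrayIndexOutOfBoundsException("SQL array indices start at 1")
     } else if (numElements < math.abs(index)) {
       if (failOnError) {
         throw new ArrayIndexOutOfBoundsException(s"Invalid index: $index")
       }
       None // non-ANSI: an out-of-range index yields null
     } else if (index > 0) {
       Some(index - 1) // convert the 1-based SQL index to a 0-based ordinal
     } else {
       Some(numElements + index) // negative indices count from the end
     }
   }
   ```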





-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]



---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]
