Repository: spark
Updated Branches:
  refs/heads/master de726b0d5 -> d642b2735


[SPARK-15397][SQL] fix string udf locate as hive

## What changes were proposed in this pull request?

in hive, `locate("aa", "aaa", 0)` would yield 0, `locate("aa", "aaa", 1)` would 
yield 1 and `locate("aa", "aaa", 2)` would yield 2, while in Spark, 
`locate("aa", "aaa", 0)` would yield 1,  `locate("aa", "aaa", 1)` would yield 2 
and  `locate("aa", "aaa", 2)` would yield 0. This results from a different 
interpretation of the third parameter of the `locate` UDF: it is the 1-based 
starting index of the search, so passing 0 always yields 0 (as in Hive).

## How was this patch tested?

Tested with the modified `StringExpressionsSuite` and `StringFunctionsSuite`.

Author: Daoyuan Wang <daoyuan.w...@intel.com>

Closes #13186 from adrian-wang/locate.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/d642b273
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/d642b273
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/d642b273

Branch: refs/heads/master
Commit: d642b273544bb77ef7f584326aa2d214649ac61b
Parents: de726b0
Author: Daoyuan Wang <daoyuan.w...@intel.com>
Authored: Mon May 23 23:29:15 2016 -0700
Committer: Andrew Or <and...@databricks.com>
Committed: Mon May 23 23:29:15 2016 -0700

----------------------------------------------------------------------
 R/pkg/R/functions.R                              |  2 +-
 R/pkg/inst/tests/testthat/test_sparkSQL.R        |  2 +-
 python/pyspark/sql/functions.py                  |  2 +-
 .../catalyst/expressions/stringExpressions.scala | 19 +++++++++++++------
 .../expressions/StringExpressionsSuite.scala     | 16 +++++++++-------
 .../apache/spark/sql/StringFunctionsSuite.scala  | 10 +++++-----
 6 files changed, 30 insertions(+), 21 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/R/pkg/R/functions.R
----------------------------------------------------------------------
diff --git a/R/pkg/R/functions.R b/R/pkg/R/functions.R
index 4a0bdf3..2665d1d 100644
--- a/R/pkg/R/functions.R
+++ b/R/pkg/R/functions.R
@@ -2226,7 +2226,7 @@ setMethod("window", signature(x = "Column"),
 #' @export
 #' @examples \dontrun{locate('b', df$c, 1)}
 setMethod("locate", signature(substr = "character", str = "Column"),
-          function(substr, str, pos = 0) {
+          function(substr, str, pos = 1) {
             jc <- callJStatic("org.apache.spark.sql.functions",
                               "locate",
                               substr, str@jc, as.integer(pos))

http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/R/pkg/inst/tests/testthat/test_sparkSQL.R
----------------------------------------------------------------------
diff --git a/R/pkg/inst/tests/testthat/test_sparkSQL.R 
b/R/pkg/inst/tests/testthat/test_sparkSQL.R
index 6a99b43..b2d769f 100644
--- a/R/pkg/inst/tests/testthat/test_sparkSQL.R
+++ b/R/pkg/inst/tests/testthat/test_sparkSQL.R
@@ -1152,7 +1152,7 @@ test_that("string operators", {
   l2 <- list(list(a = "aaads"))
   df2 <- createDataFrame(sqlContext, l2)
   expect_equal(collect(select(df2, locate("aa", df2$a)))[1, 1], 1)
-  expect_equal(collect(select(df2, locate("aa", df2$a, 1)))[1, 1], 2)
+  expect_equal(collect(select(df2, locate("aa", df2$a, 2)))[1, 1], 2)
   expect_equal(collect(select(df2, lpad(df2$a, 8, "#")))[1, 1], "###aaads") # 
nolint
   expect_equal(collect(select(df2, rpad(df2$a, 8, "#")))[1, 1], "aaads###") # 
nolint
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/python/pyspark/sql/functions.py
----------------------------------------------------------------------
diff --git a/python/pyspark/sql/functions.py b/python/pyspark/sql/functions.py
index 1f15eec..64b8bc4 100644
--- a/python/pyspark/sql/functions.py
+++ b/python/pyspark/sql/functions.py
@@ -1359,7 +1359,7 @@ def levenshtein(left, right):
 
 
 @since(1.5)
-def locate(substr, str, pos=0):
+def locate(substr, str, pos=1):
     """
     Locate the position of the first occurrence of substr in a string column, 
after position pos.
 

http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
index 78e846d..44ff7fd 100644
--- 
a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
+++ 
b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/stringExpressions.scala
@@ -494,7 +494,7 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
   extends TernaryExpression with ImplicitCastInputTypes {
 
   def this(substr: Expression, str: Expression) = {
-    this(substr, str, Literal(0))
+    this(substr, str, Literal(1))
   }
 
   override def children: Seq[Expression] = substr :: str :: start :: Nil
@@ -516,9 +516,14 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
         if (l == null) {
           null
         } else {
-          l.asInstanceOf[UTF8String].indexOf(
-            r.asInstanceOf[UTF8String],
-            s.asInstanceOf[Int]) + 1
+          val sVal = s.asInstanceOf[Int]
+          if (sVal < 1) {
+            0
+          } else {
+            l.asInstanceOf[UTF8String].indexOf(
+              r.asInstanceOf[UTF8String],
+              s.asInstanceOf[Int] - 1) + 1
+          }
         }
       }
     }
@@ -537,8 +542,10 @@ case class StringLocate(substr: Expression, str: 
Expression, start: Expression)
         if (!${substrGen.isNull}) {
           ${strGen.code}
           if (!${strGen.isNull}) {
-            ${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
-              ${startGen.value}) + 1;
+            if (${startGen.value} > 0) {
+              ${ev.value} = ${strGen.value}.indexOf(${substrGen.value},
+                ${startGen.value} - 1) + 1;
+            }
           } else {
             ${ev.isNull} = true;
           }

http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
index c09c64f..29bf15b 100644
--- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
+++ 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/StringExpressionsSuite.scala
@@ -508,16 +508,18 @@ class StringExpressionsSuite extends SparkFunSuite with 
ExpressionEvalHelper {
     val s2 = 'b.string.at(1)
     val s3 = 'c.string.at(2)
     val s4 = 'd.int.at(3)
-    val row1 = create_row("aaads", "aa", "zz", 1)
-    val row2 = create_row(null, "aa", "zz", 0)
-    val row3 = create_row("aaads", null, "zz", 0)
-    val row4 = create_row(null, null, null, 0)
+    val row1 = create_row("aaads", "aa", "zz", 2)
+    val row2 = create_row(null, "aa", "zz", 1)
+    val row3 = create_row("aaads", null, "zz", 1)
+    val row4 = create_row(null, null, null, 1)
 
     checkEvaluation(new StringLocate(Literal("aa"), Literal("aaads")), 1, row1)
-    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 
2, row1)
-    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 
0, row1)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(0)), 
0, row1)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(1)), 
1, row1)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(2)), 
2, row1)
+    checkEvaluation(StringLocate(Literal("aa"), Literal("aaads"), Literal(3)), 
0, row1)
     checkEvaluation(new StringLocate(Literal("de"), Literal("aaads")), 0, row1)
-    checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 1), 0, row1)
+    checkEvaluation(StringLocate(Literal("de"), Literal("aaads"), 2), 0, row1)
 
     checkEvaluation(new StringLocate(s2, s1), 1, row1)
     checkEvaluation(StringLocate(s2, s1, s4), 2, row1)

http://git-wip-us.apache.org/repos/asf/spark/blob/d642b273/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
----------------------------------------------------------------------
diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala 
b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
index c7b95c2..1de2d9b 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/StringFunctionsSuite.scala
@@ -189,15 +189,15 @@ class StringFunctionsSuite extends QueryTest with 
SharedSQLContext {
   }
 
   test("string locate function") {
-    val df = Seq(("aaads", "aa", "zz", 1)).toDF("a", "b", "c", "d")
+    val df = Seq(("aaads", "aa", "zz", 2)).toDF("a", "b", "c", "d")
 
     checkAnswer(
-      df.select(locate("aa", $"a"), locate("aa", $"a", 1)),
-      Row(1, 2))
+      df.select(locate("aa", $"a"), locate("aa", $"a", 2), locate("aa", $"a", 
0)),
+      Row(1, 2, 0))
 
     checkAnswer(
-      df.selectExpr("locate(b, a)", "locate(b, a, d)"),
-      Row(1, 2))
+      df.selectExpr("locate(b, a)", "locate(b, a, d)", "locate(b, a, 3)"),
+      Row(1, 2, 0))
   }
 
   test("string padding functions") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to