This is an automated email from the ASF dual-hosted git repository.

ruifengz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new 3755d51eb5b8 [SPARK-48892][ML] Avoid per-row param read in `Tokenizer`
3755d51eb5b8 is described below

commit 3755d51eb5b8ab17f2e68ff4114aa488e2815fdc
Author: Ruifeng Zheng <[email protected]>
AuthorDate: Wed Jul 17 07:18:18 2024 +0800

    [SPARK-48892][ML] Avoid per-row param read in `Tokenizer`
    
    ### What changes were proposed in this pull request?
    Inspired by https://github.com/apache/spark/pull/47258, I checked other ML implementations and found that `Tokenizer` can be optimized in the same way (the change is in `RegexTokenizer`, defined in `Tokenizer.scala`).
    
    ### Why are the changes needed?
    The function `createTransformFunc` builds the UDF used by `UnaryTransformer.transform`:
    https://github.com/apache/spark/blob/d679dabdd1b5ad04b8c7deb1c06ce886a154a928/mllib/src/main/scala/org/apache/spark/ml/Transformer.scala#L118
    
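    The existing implementation reads the params for every row. All `$(...)` lookups, including `$(pattern).r` (which also recompiles the regex), sit inside the returned closure, so they run once per input string. This PR reads them once, when the function is created, and captures the values in local `val`s.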
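    
    A Spark-free sketch of the difference, for illustration only. `CountingParams` and `HoistParamsDemo` are hypothetical names, not Spark APIs: `CountingParams` stands in for the `$(...)` param lookup and just counts reads, and plain `.toSeq` replaces Spark's `toImmutableArraySeq` helper. `perRowFunc` mirrors the old closure (including recompiling the regex on every call), and `hoistedFunc` mirrors the new one.
    
    ```scala
    // Hypothetical stand-in for Spark ML's param lookup: it only counts reads,
    // so the cost of reading params inside vs. outside the closure is visible.
    final class CountingParams(values: Map[String, Any]) {
      var reads: Int = 0
      def get[T](name: String): T = {
        reads += 1
        values(name).asInstanceOf[T]
      }
    }
    
    object HoistParamsDemo {
      // Old shape: the whole body is the closure, so every lookup
      // (and the regex compilation) runs once per input row.
      def perRowFunc(p: CountingParams): String => Seq[String] = { originStr =>
        val re = p.get[String]("pattern").r
        val str = if (p.get[Boolean]("toLowercase")) originStr.toLowerCase() else originStr
        val tokens = if (p.get[Boolean]("gaps")) re.split(str).toSeq else re.findAllIn(str).toSeq
        val minLength = p.get[Int]("minTokenLength")
        tokens.filter(_.length >= minLength)
      }
    
      // New shape: lookups run once, when the function is built; the closure
      // only captures the resulting local values.
      def hoistedFunc(p: CountingParams): String => Seq[String] = {
        val re = p.get[String]("pattern").r
        val localToLowercase = p.get[Boolean]("toLowercase")
        val localGaps = p.get[Boolean]("gaps")
        val localMinTokenLength = p.get[Int]("minTokenLength")
        (originStr: String) => {
          val str = if (localToLowercase) originStr.toLowerCase() else originStr
          val tokens = if (localGaps) re.split(str).toSeq else re.findAllIn(str).toSeq
          tokens.filter(_.length >= localMinTokenLength)
        }
      }
    
      def main(args: Array[String]): Unit = {
        val settings = Map[String, Any](
          "pattern" -> "-", "toLowercase" -> true, "gaps" -> true, "minTokenLength" -> 1)
        val rows = Seq.fill(1000)("AB-cd-EF")
    
        val before = new CountingParams(settings)
        val perRow = perRowFunc(before)
        val out1 = rows.map(perRow)
    
        val after = new CountingParams(settings)
        val hoisted = hoistedFunc(after)
        val out2 = rows.map(hoisted)
    
        println(s"param reads: per-row = ${before.reads}, hoisted = ${after.reads}")
        println(s"same tokens: ${out1 == out2}")
      }
    }
    ```
    
    Running `HoistParamsDemo.main` should report 4000 param reads for the per-row version (4 lookups for each of 1000 rows), 4 for the hoisted version, and `same tokens: true`. How much this saves inside Spark depends on the real cost of each `$(...)` lookup and regex compilation; the benchmark in this PR gives the measured numbers.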
    
    ### Does this PR introduce _any_ user-facing change?
    No
    
    ### How was this patch tested?
    CI and manual tests:
    
    Create the test dataset:
    ```
    spark.range(1000000).select(uuid().as("uuid")).write.mode("overwrite").parquet("/tmp/regex_tokenizer.parquet")
    ```
    
    Measure the duration (1000 runs of `transform(df).count()` after a warm-up):
    ```
    val df = spark.read.parquet("/tmp/regex_tokenizer.parquet")
    import org.apache.spark.ml.feature._
    val tokenizer = new RegexTokenizer().setPattern("-").setInputCol("uuid")
    Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()) // warm up
    val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
    ```
    
    Result (before this PR):
    ```
    scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
    val tic: Long = 1720613235068
    val res5: Long = 50397
    ```
    
    Result (after this PR), roughly 13% less total time than before (43748 ms vs. 50397 ms):
    ```
    scala> val tic = System.currentTimeMillis; Seq.range(0, 1000).foreach(i => tokenizer.transform(df).count()); System.currentTimeMillis - tic
    val tic: Long = 1720612871256
    val res5: Long = 43748
    ```
    
    ### Was this patch authored or co-authored using generative AI tooling?
    No
    
    Closes #47342 from zhengruifeng/opt_tokenizer.
    
    Authored-by: Ruifeng Zheng <[email protected]>
    Signed-off-by: Ruifeng Zheng <[email protected]>
---
 .../scala/org/apache/spark/ml/feature/Tokenizer.scala | 19 ++++++++++++-------
 1 file changed, 12 insertions(+), 7 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
index e7b3ff76a8d8..1acbfd781820 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/Tokenizer.scala
@@ -141,14 +141,19 @@ class RegexTokenizer @Since("1.4.0") (@Since("1.4.0") override val uid: String)
 
   setDefault(minTokenLength -> 1, gaps -> true, pattern -> "\\s+", toLowercase -> true)
 
-  override protected def createTransformFunc: String => Seq[String] = { originStr =>
+  override protected def createTransformFunc: String => Seq[String] = {
     val re = $(pattern).r
-    // scalastyle:off caselocale
-    val str = if ($(toLowercase)) originStr.toLowerCase() else originStr
-    // scalastyle:on caselocale
-    val tokens = if ($(gaps)) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
-    val minLength = $(minTokenLength)
-    tokens.filter(_.length >= minLength)
+    val localToLowercase = $(toLowercase)
+    val localGaps = $(gaps)
+    val localMinTokenLength = $(minTokenLength)
+
+    (originStr: String) => {
+      // scalastyle:off caselocale
+      val str = if (localToLowercase) originStr.toLowerCase() else originStr
+      // scalastyle:on caselocale
+      val tokens = if (localGaps) re.split(str).toImmutableArraySeq else re.findAllIn(str).toSeq
+      tokens.filter(_.length >= localMinTokenLength)
+    }
   }
 
   override protected def validateInputType(inputType: DataType): Unit = {

