This is an automated email from the ASF dual-hosted git repository.

meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git


The following commit(s) were added to refs/heads/master by this push:
     new e248bc7  [SPARK-31610][SPARK-31668][ML] Address hashingTF 
saving&loading bug and expose hashFunc property in HashingTF
e248bc7 is described below

commit e248bc7af6086cde7dd89a51459ae6a221a600c8
Author: Weichen Xu <weichen...@databricks.com>
AuthorDate: Tue May 12 08:54:28 2020 -0700

    [SPARK-31610][SPARK-31668][ML] Address hashingTF saving&loading bug and 
expose hashFunc property in HashingTF
    
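    ### What changes were proposed in this pull request?
    Expose which hash function a HashingTF uses, through a new `hashFuncVersion` field
    (`HashingTF.SPARK_2_MURMUR3_HASH` or `HashingTF.SPARK_3_MURMUR3_HASH`), and make
    `save` fail on a HashingTF that was loaded from a model saved by Spark 2.x.
    
    Some third-party libraries, such as MLeap, need to access the hash function.
    See the background description here:
    https://github.com/combust/mleap/pull/665#issuecomment-621258623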
    
    ### Why are the changes needed?
    See https://github.com/combust/mleap/pull/665#issuecomment-621258623
    
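    ### Does this PR introduce any user-facing change?
    Yes, a small one. HashingTF gains a public read-only `hashFuncVersion` field, set
    through a new package-private constructor; the existing public constructors keep
    their previous behavior. In addition, calling `save` on a HashingTF that was loaded
    from a Spark 2.x model now throws an IllegalArgumentException.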
    
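    ### How was this patch tested?
    Extended the existing HashingTFSuite test that loads a HashingTF saved by Spark 2.4.4
    to also check that saving the loaded instance again throws an IllegalArgumentException.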
    
    Closes #28413 from WeichenXu123/hashing_tf_expose_hashfunc.
    
    Authored-by: Weichen Xu <weichen...@databricks.com>
    Signed-off-by: Xiangrui Meng <m...@databricks.com>
---
 .../org/apache/spark/ml/feature/HashingTF.scala    | 40 +++++++++++++++++-----
 .../apache/spark/ml/feature/HashingTFSuite.scala   |  4 +++
 2 files changed, 35 insertions(+), 9 deletions(-)

diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala 
b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 80bf859..d2bb013 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -42,14 +42,17 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion
  * otherwise the features will not be mapped evenly to the columns.
  */
 @Since("1.2.0")
-class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+class HashingTF @Since("3.0.0") private[ml] (
+    @Since("1.4.0") override val uid: String,
+    @Since("3.1.0") val hashFuncVersion: Int)
   extends Transformer with HasInputCol with HasOutputCol with HasNumFeatures
     with DefaultParamsWritable {
 
-  private var hashFunc: Any => Int = FeatureHasher.murmur3Hash
-
   @Since("1.2.0")
-  def this() = this(Identifiable.randomUID("hashingTF"))
+  def this() = this(Identifiable.randomUID("hashingTF"), 
HashingTF.SPARK_3_MURMUR3_HASH)
+
+  @Since("1.4.0")
+  def this(uid: String) = this(uid, hashFuncVersion = 
HashingTF.SPARK_3_MURMUR3_HASH)
 
   /** @group setParam */
   @Since("1.4.0")
@@ -122,7 +125,12 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override 
val uid: String)
    */
   @Since("3.0.0")
   def indexOf(term: Any): Int = {
-    Utils.nonNegativeMod(hashFunc(term), $(numFeatures))
+    val hashValue = hashFuncVersion match {
+      case HashingTF.SPARK_2_MURMUR3_HASH => OldHashingTF.murmur3Hash(term)
+      case HashingTF.SPARK_3_MURMUR3_HASH => FeatureHasher.murmur3Hash(term)
+      case _ => throw new IllegalArgumentException("Illegal hash function 
version setting.")
+    }
+    Utils.nonNegativeMod(hashValue, $(numFeatures))
   }
 
   @Since("1.4.1")
@@ -132,27 +140,41 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override 
val uid: String)
   override def toString: String = {
     s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}"
   }
+
+  @Since("3.0.0")
+  override def save(path: String): Unit = {
+    require(hashFuncVersion == HashingTF.SPARK_3_MURMUR3_HASH,
+      "Cannot save model which is loaded from lower version spark saved model. 
We can address " +
+      "it by (1) use old spark version to save the model, or (2) use new 
version spark to " +
+      "re-train the pipeline.")
+    super.save(path)
+  }
 }
 
 @Since("1.6.0")
 object HashingTF extends DefaultParamsReadable[HashingTF] {
 
+  private[ml] val SPARK_2_MURMUR3_HASH = 1
+  private[ml] val SPARK_3_MURMUR3_HASH = 2
+
   private class HashingTFReader extends MLReader[HashingTF] {
 
     private val className = classOf[HashingTF].getName
 
     override def load(path: String): HashingTF = {
       val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-      val hashingTF = new HashingTF(metadata.uid)
-      metadata.getAndSetParams(hashingTF)
 
       // We support loading old `HashingTF` saved by previous Spark versions.
       // Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but 
new `HashingTF` uses
       // `ml.Feature.FeatureHasher.murmur3Hash`.
       val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
-      if (majorVersion < 3) {
-        hashingTF.hashFunc = OldHashingTF.murmur3Hash
+      val hashFuncVersion = if (majorVersion < 3) {
+        SPARK_2_MURMUR3_HASH
+      } else {
+        SPARK_3_MURMUR3_HASH
       }
+      val hashingTF = new HashingTF(metadata.uid, hashFuncVersion = 
hashFuncVersion)
+      metadata.getAndSetParams(hashingTF)
       hashingTF
     }
   }
diff --git 
a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala 
b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 722302e..8fd192f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -100,6 +100,10 @@ class HashingTFSuite extends MLTest with 
DefaultReadWriteTest {
     val metadata = spark.read.json(s"$hashingTFPath/metadata")
     val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
     assert(sparkVersionStr == "2.4.4")
+
+    intercept[IllegalArgumentException] {
+      loadedHashingTF.save(hashingTFPath)
+    }
   }
 
   test("read/write") {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@spark.apache.org
For additional commands, e-mail: commits-h...@spark.apache.org

Reply via email to