This is an automated email from the ASF dual-hosted git repository. meng pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push: new e248bc7 [SPARK-31610][SPARK-31668][ML] Address hashingTF saving&loading bug and expose hashFunc property in HashingTF e248bc7 is described below commit e248bc7af6086cde7dd89a51459ae6a221a600c8 Author: Weichen Xu <weichen...@databricks.com> AuthorDate: Tue May 12 08:54:28 2020 -0700 [SPARK-31610][SPARK-31668][ML] Address hashingTF saving&loading bug and expose hashFunc property in HashingTF ### What changes were proposed in this pull request? Expose hashFunc property in HashingTF. Some third-party libraries, such as MLeap, need to access it. See the background description here: https://github.com/combust/mleap/pull/665#issuecomment-621258623 ### Why are the changes needed? See https://github.com/combust/mleap/pull/665#issuecomment-621258623 ### Does this PR introduce any user-facing change? No. It only adds a package-private constructor. ### How was this patch tested? N/A Closes #28413 from WeichenXu123/hashing_tf_expose_hashfunc. Authored-by: Weichen Xu <weichen...@databricks.com> Signed-off-by: Xiangrui Meng <m...@databricks.com> --- .../org/apache/spark/ml/feature/HashingTF.scala | 40 +++++++++++++++++----- .../apache/spark/ml/feature/HashingTFSuite.scala | 4 +++ 2 files changed, 35 insertions(+), 9 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 80bf859..d2bb013 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -42,14 +42,17 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion * otherwise the features will not be mapped evenly to the columns. 
*/ @Since("1.2.0") -class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) +class HashingTF @Since("3.0.0") private[ml] ( + @Since("1.4.0") override val uid: String, + @Since("3.1.0") val hashFuncVersion: Int) extends Transformer with HasInputCol with HasOutputCol with HasNumFeatures with DefaultParamsWritable { - private var hashFunc: Any => Int = FeatureHasher.murmur3Hash - @Since("1.2.0") - def this() = this(Identifiable.randomUID("hashingTF")) + def this() = this(Identifiable.randomUID("hashingTF"), HashingTF.SPARK_3_MURMUR3_HASH) + + @Since("1.4.0") + def this(uid: String) = this(uid, hashFuncVersion = HashingTF.SPARK_3_MURMUR3_HASH) /** @group setParam */ @Since("1.4.0") @@ -122,7 +125,12 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) */ @Since("3.0.0") def indexOf(term: Any): Int = { - Utils.nonNegativeMod(hashFunc(term), $(numFeatures)) + val hashValue = hashFuncVersion match { + case HashingTF.SPARK_2_MURMUR3_HASH => OldHashingTF.murmur3Hash(term) + case HashingTF.SPARK_3_MURMUR3_HASH => FeatureHasher.murmur3Hash(term) + case _ => throw new IllegalArgumentException("Illegal hash function version setting.") + } + Utils.nonNegativeMod(hashValue, $(numFeatures)) } @Since("1.4.1") @@ -132,27 +140,41 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String) override def toString: String = { s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}" } + + @Since("3.0.0") + override def save(path: String): Unit = { + require(hashFuncVersion == HashingTF.SPARK_3_MURMUR3_HASH, + "Cannot save model which is loaded from lower version spark saved model. 
We can address " + + "it by (1) use old spark version to save the model, or (2) use new version spark to " + + "re-train the pipeline.") + super.save(path) + } } @Since("1.6.0") object HashingTF extends DefaultParamsReadable[HashingTF] { + private[ml] val SPARK_2_MURMUR3_HASH = 1 + private[ml] val SPARK_3_MURMUR3_HASH = 2 + private class HashingTFReader extends MLReader[HashingTF] { private val className = classOf[HashingTF].getName override def load(path: String): HashingTF = { val metadata = DefaultParamsReader.loadMetadata(path, sc, className) - val hashingTF = new HashingTF(metadata.uid) - metadata.getAndSetParams(hashingTF) // We support loading old `HashingTF` saved by previous Spark versions. // Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses // `ml.Feature.FeatureHasher.murmur3Hash`. val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion) - if (majorVersion < 3) { - hashingTF.hashFunc = OldHashingTF.murmur3Hash + val hashFuncVersion = if (majorVersion < 3) { + SPARK_2_MURMUR3_HASH + } else { + SPARK_3_MURMUR3_HASH } + val hashingTF = new HashingTF(metadata.uid, hashFuncVersion = hashFuncVersion) + metadata.getAndSetParams(hashingTF) hashingTF } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 722302e..8fd192f 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -100,6 +100,10 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest { val metadata = spark.read.json(s"$hashingTFPath/metadata") val sparkVersionStr = metadata.select("sparkVersion").first().getString(0) assert(sparkVersionStr == "2.4.4") + + intercept[IllegalArgumentException] { + loadedHashingTF.save(hashingTFPath) + } } test("read/write") { --------------------------------------------------------------------- To 
unsubscribe, e-mail: commits-unsubscr...@spark.apache.org For additional commands, e-mail: commits-h...@spark.apache.org