You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by me...@apache.org on 2020/05/12 15:55:01 UTC
[spark] branch master updated: [SPARK-31610][SPARK-31668][ML]
Address hashingTF saving&loading bug and expose hashFunc property in
HashingTF
This is an automated email from the ASF dual-hosted git repository.
meng pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/spark.git
The following commit(s) were added to refs/heads/master by this push:
new e248bc7 [SPARK-31610][SPARK-31668][ML] Address hashingTF saving&loading bug and expose hashFunc property in HashingTF
e248bc7 is described below
commit e248bc7af6086cde7dd89a51459ae6a221a600c8
Author: Weichen Xu <we...@databricks.com>
AuthorDate: Tue May 12 08:54:28 2020 -0700
[SPARK-31610][SPARK-31668][ML] Address hashingTF saving&loading bug and expose hashFunc property in HashingTF
### What changes were proposed in this pull request?
Expose hashFunc property in HashingTF
Some third-party libraries, such as MLeap, need to access it.
See background description here:
https://github.com/combust/mleap/pull/665#issuecomment-621258623
### Why are the changes needed?
See https://github.com/combust/mleap/pull/665#issuecomment-621258623
### Does this PR introduce any user-facing change?
No. This only adds a package-private constructor.
### How was this patch tested?
N/A
Closes #28413 from WeichenXu123/hashing_tf_expose_hashfunc.
Authored-by: Weichen Xu <we...@databricks.com>
Signed-off-by: Xiangrui Meng <me...@databricks.com>
---
.../org/apache/spark/ml/feature/HashingTF.scala | 40 +++++++++++++++++-----
.../apache/spark/ml/feature/HashingTFSuite.scala | 4 +++
2 files changed, 35 insertions(+), 9 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 80bf859..d2bb013 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -42,14 +42,17 @@ import org.apache.spark.util.VersionUtils.majorMinorVersion
* otherwise the features will not be mapped evenly to the columns.
*/
@Since("1.2.0")
-class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
+class HashingTF @Since("3.0.0") private[ml] (
+ @Since("1.4.0") override val uid: String,
+ @Since("3.1.0") val hashFuncVersion: Int)
extends Transformer with HasInputCol with HasOutputCol with HasNumFeatures
with DefaultParamsWritable {
- private var hashFunc: Any => Int = FeatureHasher.murmur3Hash
-
@Since("1.2.0")
- def this() = this(Identifiable.randomUID("hashingTF"))
+ def this() = this(Identifiable.randomUID("hashingTF"), HashingTF.SPARK_3_MURMUR3_HASH)
+
+ @Since("1.4.0")
+ def this(uid: String) = this(uid, hashFuncVersion = HashingTF.SPARK_3_MURMUR3_HASH)
/** @group setParam */
@Since("1.4.0")
@@ -122,7 +125,12 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
*/
@Since("3.0.0")
def indexOf(term: Any): Int = {
- Utils.nonNegativeMod(hashFunc(term), $(numFeatures))
+ val hashValue = hashFuncVersion match {
+ case HashingTF.SPARK_2_MURMUR3_HASH => OldHashingTF.murmur3Hash(term)
+ case HashingTF.SPARK_3_MURMUR3_HASH => FeatureHasher.murmur3Hash(term)
+ case _ => throw new IllegalArgumentException("Illegal hash function version setting.")
+ }
+ Utils.nonNegativeMod(hashValue, $(numFeatures))
}
@Since("1.4.1")
@@ -132,27 +140,41 @@ class HashingTF @Since("1.4.0") (@Since("1.4.0") override val uid: String)
override def toString: String = {
s"HashingTF: uid=$uid, binary=${$(binary)}, numFeatures=${$(numFeatures)}"
}
+
+ @Since("3.0.0")
+ override def save(path: String): Unit = {
+ require(hashFuncVersion == HashingTF.SPARK_3_MURMUR3_HASH,
+ "Cannot save model which is loaded from lower version spark saved model. We can address " +
+ "it by (1) use old spark version to save the model, or (2) use new version spark to " +
+ "re-train the pipeline.")
+ super.save(path)
+ }
}
@Since("1.6.0")
object HashingTF extends DefaultParamsReadable[HashingTF] {
+ private[ml] val SPARK_2_MURMUR3_HASH = 1
+ private[ml] val SPARK_3_MURMUR3_HASH = 2
+
private class HashingTFReader extends MLReader[HashingTF] {
private val className = classOf[HashingTF].getName
override def load(path: String): HashingTF = {
val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
- val hashingTF = new HashingTF(metadata.uid)
- metadata.getAndSetParams(hashingTF)
// We support loading old `HashingTF` saved by previous Spark versions.
// Previous `HashingTF` uses `mllib.feature.HashingTF.murmur3Hash`, but new `HashingTF` uses
// `ml.Feature.FeatureHasher.murmur3Hash`.
val (majorVersion, _) = majorMinorVersion(metadata.sparkVersion)
- if (majorVersion < 3) {
- hashingTF.hashFunc = OldHashingTF.murmur3Hash
+ val hashFuncVersion = if (majorVersion < 3) {
+ SPARK_2_MURMUR3_HASH
+ } else {
+ SPARK_3_MURMUR3_HASH
}
+ val hashingTF = new HashingTF(metadata.uid, hashFuncVersion = hashFuncVersion)
+ metadata.getAndSetParams(hashingTF)
hashingTF
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 722302e..8fd192f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -100,6 +100,10 @@ class HashingTFSuite extends MLTest with DefaultReadWriteTest {
val metadata = spark.read.json(s"$hashingTFPath/metadata")
val sparkVersionStr = metadata.select("sparkVersion").first().getString(0)
assert(sparkVersionStr == "2.4.4")
+
+ intercept[IllegalArgumentException] {
+ loadedHashingTF.save(hashingTFPath)
+ }
}
test("read/write") {
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org