You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/11 00:26:09 UTC
spark git commit: [SPARK-23944][ML] Add the set method for the two LSHModel
Repository: spark
Updated Branches:
refs/heads/master 4f1e8b9bb -> 7c7570d46
[SPARK-23944][ML] Add the set method for the two LSHModel
## What changes were proposed in this pull request?
Add two setter methods, setInputCol and setOutputCol, to LSHModel in LSH.scala, and override them in the two concrete models in BucketedRandomProjectionLSH.scala and MinHashLSH.scala.
## How was this patch tested?
New tests for the param setters were added to:
- BucketedRandomProjectionLSHSuite.scala
- MinHashLSHSuite.scala
Please review http://spark.apache.org/contributing.html before opening a pull request.
Author: Lu WANG <lu...@databricks.com>
Closes #21015 from ludatabricks/SPARK-23944.
Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c7570d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c7570d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c7570d4
Branch: refs/heads/master
Commit: 7c7570d466a8ded51e580eb6a28583bd9a9c5337
Parents: 4f1e8b9
Author: Lu WANG <lu...@databricks.com>
Authored: Tue Apr 10 17:26:06 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Apr 10 17:26:06 2018 -0700
----------------------------------------------------------------------
.../spark/ml/feature/BucketedRandomProjectionLSH.scala | 8 ++++++++
mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala | 6 ++++++
.../main/scala/org/apache/spark/ml/feature/MinHashLSH.scala | 8 ++++++++
.../spark/ml/feature/BucketedRandomProjectionLSHSuite.scala | 8 ++++++++
.../scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala | 8 ++++++++
5 files changed, 38 insertions(+)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 36a46ca..41eaaf9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml](
private[ml] val randUnitVectors: Array[Vector])
extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
key: Vector => {
http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 1c9f47a..a70931f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
extends Model[T] with LSHParams with MLWritable {
self: T =>
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
/**
* The hash function of LSH, mapping an input feature vector to multiple hash vectors.
* @return The mapping of LSH function.
http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 145422a..556848e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -51,6 +51,14 @@ class MinHashLSHModel private[ml](
private[ml] val randCoefficients: Array[(Int, Int)])
extends LSHModel[MinHashLSHModel] {
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+ /** @group setParam */
+ @Since("2.4.0")
+ override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
@Since("2.1.0")
override protected[ml] val hashFunction: Vector => Array[Vector] = {
elems: Vector => {
http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ed9a39d..9b82325 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest
ParamsSuite.checkParams(model)
}
+ test("setters") {
+ val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
+ .setInputCol("testkeys")
+ .setOutputCol("testvalues")
+ assert(model.getInputCol === "testkeys")
+ assert(model.getOutputCol === "testvalues")
+ }
+
test("BucketedRandomProjectionLSH: default params") {
val brp = new BucketedRandomProjectionLSH
assert(brp.getNumHashTables === 1.0)
http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 96df68d..3da0fb7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
ParamsSuite.checkParams(model)
}
+ test("setters") {
+ val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+ .setInputCol("testkeys")
+ .setOutputCol("testvalues")
+ assert(model.getInputCol === "testkeys")
+ assert(model.getOutputCol === "testvalues")
+ }
+
test("MinHashLSH: default params") {
val rp = new MinHashLSH
assert(rp.getNumHashTables === 1.0)
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org