You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@spark.apache.org by jk...@apache.org on 2018/04/11 00:26:09 UTC

spark git commit: [SPARK-23944][ML] Add the set method for the two LSHModel

Repository: spark
Updated Branches:
  refs/heads/master 4f1e8b9bb -> 7c7570d46


[SPARK-23944][ML] Add the set method for the two LSHModel

## What changes were proposed in this pull request?

Add two setter methods, `setInputCol` and `setOutputCol`, to LSHModel in LSH.scala, and override them in the model classes in BucketedRandomProjectionLSH.scala and MinHashLSH.scala.

## How was this patch tested?

New tests for the parameter setters were added to:

- BucketedRandomProjectionLSHSuite.scala

- MinHashLSHSuite.scala

Please review http://spark.apache.org/contributing.html before opening a pull request.

Author: Lu WANG <lu...@databricks.com>

Closes #21015 from ludatabricks/SPARK-23944.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/7c7570d4
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/7c7570d4
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/7c7570d4

Branch: refs/heads/master
Commit: 7c7570d466a8ded51e580eb6a28583bd9a9c5337
Parents: 4f1e8b9
Author: Lu WANG <lu...@databricks.com>
Authored: Tue Apr 10 17:26:06 2018 -0700
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Tue Apr 10 17:26:06 2018 -0700

----------------------------------------------------------------------
 .../spark/ml/feature/BucketedRandomProjectionLSH.scala       | 8 ++++++++
 mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala   | 6 ++++++
 .../main/scala/org/apache/spark/ml/feature/MinHashLSH.scala  | 8 ++++++++
 .../spark/ml/feature/BucketedRandomProjectionLSHSuite.scala  | 8 ++++++++
 .../scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala  | 8 ++++++++
 5 files changed, 38 insertions(+)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
index 36a46ca..41eaaf9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -73,6 +73,14 @@ class BucketedRandomProjectionLSHModel private[ml](
     private[ml] val randUnitVectors: Array[Vector])
   extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
 
+  /** @group setParam */
+  @Since("2.4.0")
+  override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
   @Since("2.1.0")
   override protected[ml] val hashFunction: Vector => Array[Vector] = {
     key: Vector => {

http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index 1c9f47a..a70931f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -65,6 +65,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   extends Model[T] with LSHParams with MLWritable {
   self: T =>
 
+  /** @group setParam */
+  def setInputCol(value: String): this.type = set(inputCol, value)
+
+  /** @group setParam */
+  def setOutputCol(value: String): this.type = set(outputCol, value)
+
   /**
    * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
    * @return The mapping of LSH function.

http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
index 145422a..556848e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -51,6 +51,14 @@ class MinHashLSHModel private[ml](
     private[ml] val randCoefficients: Array[(Int, Int)])
   extends LSHModel[MinHashLSHModel] {
 
+  /** @group setParam */
+  @Since("2.4.0")
+  override def setInputCol(value: String): this.type = super.set(inputCol, value)
+
+  /** @group setParam */
+  @Since("2.4.0")
+  override def setOutputCol(value: String): this.type = super.set(outputCol, value)
+
   @Since("2.1.0")
   override protected[ml] val hashFunction: Vector => Array[Vector] = {
     elems: Vector => {

http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
index ed9a39d..9b82325 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -48,6 +48,14 @@ class BucketedRandomProjectionLSHSuite extends MLTest with DefaultReadWriteTest
     ParamsSuite.checkParams(model)
   }
 
+  test("setters") {
+    val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
+      .setInputCol("testkeys")
+      .setOutputCol("testvalues")
+    assert(model.getInputCol  === "testkeys")
+    assert(model.getOutputCol === "testvalues")
+  }
+
   test("BucketedRandomProjectionLSH: default params") {
     val brp = new BucketedRandomProjectionLSH
     assert(brp.getNumHashTables === 1.0)

http://git-wip-us.apache.org/repos/asf/spark/blob/7c7570d4/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
index 96df68d..3da0fb7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -43,6 +43,14 @@ class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
     ParamsSuite.checkParams(model)
   }
 
+  test("setters") {
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+      .setInputCol("testkeys")
+      .setOutputCol("testvalues")
+    assert(model.getInputCol  === "testkeys")
+    assert(model.getOutputCol === "testvalues")
+  }
+
   test("MinHashLSH: default params") {
     val rp = new MinHashLSH
     assert(rp.getNumHashTables === 1.0)


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@spark.apache.org
For additional commands, e-mail: commits-help@spark.apache.org