Posted to commits@spark.apache.org by jk...@apache.org on 2016/11/28 23:14:50 UTC

spark git commit: [SPARK-18408][ML] API Improvements for LSH

Repository: spark
Updated Branches:
  refs/heads/master 8b1609beb -> 05f7c6ffa


[SPARK-18408][ML] API Improvements for LSH

## What changes were proposed in this pull request?

(1) Change output schema to `Array of Vector` instead of `Vectors`
(2) Use `numHashTables` as the dimension of Array
(3) Rename `RandomProjection` to `BucketedRandomProjectionLSH`, `MinHash` to `MinHashLSH`
(4) Make `randUnitVectors/randCoefficients` private
(5) Make Multi-Probe NN Search and `hashDistance` private for future discussion
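
For illustration, a minimal usage sketch of the renamed API described in (1)-(3) above. The data,
column names, and parameter values are made up for the example (a SparkSession `spark` is assumed);
only the class and method names come from this patch.

```scala
import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
import org.apache.spark.ml.linalg.Vectors

// Toy data; "keys"/"values" are arbitrary column names.
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.dense(1.0, 1.0)),
  Tuple1(Vectors.dense(1.0, -1.0)),
  Tuple1(Vectors.dense(-1.0, -1.0))
)).toDF("keys")

val brp = new BucketedRandomProjectionLSH()  // formerly RandomProjection
  .setNumHashTables(3)                       // formerly setOutputDim
  .setInputCol("keys")
  .setOutputCol("values")
  .setBucketLength(2.0)
  .setSeed(12345)

// The output column now holds an Array of Vector: one hash vector per hash table.
brp.fit(df).transform(df).show(false)
```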

Saved for future PRs:
(1) AND-amplification and `numHashFunctions` as the dimension of Vector are saved for a future PR.
(2) `hashDistance` and Multi-Probe NN Search need more discussion. The current implementation is just a backward-compatible one.

## How was this patch tested?
Related unit tests were modified to make sure the performance of LSH is maintained and that the outputs of the APIs meet expectations.

Author: Yun Ni <yu...@uber.com>
Author: Yunni <Eu...@gmail.com>

Closes #15874 from Yunni/SPARK-18408-yunn-api-improvements.


Project: http://git-wip-us.apache.org/repos/asf/spark/repo
Commit: http://git-wip-us.apache.org/repos/asf/spark/commit/05f7c6ff
Tree: http://git-wip-us.apache.org/repos/asf/spark/tree/05f7c6ff
Diff: http://git-wip-us.apache.org/repos/asf/spark/diff/05f7c6ff

Branch: refs/heads/master
Commit: 05f7c6ffab2a6be548375cd624dc27092677232f
Parents: 8b1609b
Author: Yun Ni <yu...@uber.com>
Authored: Mon Nov 28 15:14:46 2016 -0800
Committer: Joseph K. Bradley <jo...@databricks.com>
Committed: Mon Nov 28 15:14:46 2016 -0800

----------------------------------------------------------------------
 .../feature/BucketedRandomProjectionLSH.scala   | 234 +++++++++++++++++++
 .../scala/org/apache/spark/ml/feature/LSH.scala | 138 ++++++-----
 .../org/apache/spark/ml/feature/MinHash.scala   | 195 ----------------
 .../apache/spark/ml/feature/MinHashLSH.scala    | 201 ++++++++++++++++
 .../spark/ml/feature/RandomProjection.scala     | 225 ------------------
 .../BucketedRandomProjectionLSHSuite.scala      | 213 +++++++++++++++++
 .../org/apache/spark/ml/feature/LSHTest.scala   |  17 +-
 .../spark/ml/feature/MinHashLSHSuite.scala      | 161 +++++++++++++
 .../apache/spark/ml/feature/MinHashSuite.scala  | 126 ----------
 .../ml/feature/RandomProjectionSuite.scala      | 197 ----------------
 10 files changed, 896 insertions(+), 811 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
new file mode 100644
index 0000000..cbac163
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSH.scala
@@ -0,0 +1,234 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import breeze.linalg.normalize
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg._
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ *
+ * Params for [[BucketedRandomProjectionLSH]].
+ */
+private[ml] trait BucketedRandomProjectionLSHParams extends Params {
+
+  /**
+   * The length of each hash bucket; a larger bucket lowers the false negative rate. The number of
+   * buckets will be `(max L2 norm of input vectors) / bucketLength`.
+   *
+   *
+   * If input vectors are normalized, 1-10 times of pow(numRecords, -1/inputDim) would be a
+   * reasonable value.
+   * @group param
+   */
+  val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength",
+    "the length of each hash bucket, a larger bucket lowers the false negative rate.",
+    ParamValidators.gt(0))
+
+  /** @group getParam */
+  final def getBucketLength: Double = $(bucketLength)
+}
+
+/**
+ * :: Experimental ::
+ *
+ * Model produced by [[BucketedRandomProjectionLSH]], where multiple random vectors are stored. The
+ * vectors are normalized to be unit vectors and each vector is used in a hash function:
+ *    `h_i(x) = floor(r_i.dot(x) / bucketLength)`
+ * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
+ * vectors) / bucketLength`.
+ *
+ * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function.
+ */
+@Experimental
+@Since("2.1.0")
+class BucketedRandomProjectionLSHModel private[ml](
+    override val uid: String,
+    private[ml] val randUnitVectors: Array[Vector])
+  extends LSHModel[BucketedRandomProjectionLSHModel] with BucketedRandomProjectionLSHParams {
+
+  @Since("2.1.0")
+  override protected[ml] val hashFunction: Vector => Array[Vector] = {
+    key: Vector => {
+      val hashValues: Array[Double] = randUnitVectors.map({
+        randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
+      })
+      // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
+      hashValues.map(Vectors.dense(_))
+    }
+  }
+
+  @Since("2.1.0")
+  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
+    Math.sqrt(Vectors.sqdist(x, y))
+  }
+
+  @Since("2.1.0")
+  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
+    // Since it's generated by hashing, it will be a pair of dense vectors.
+    x.zip(y).map(vectorPair => Vectors.sqdist(vectorPair._1, vectorPair._2)).min
+  }
+
+  @Since("2.1.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+
+  @Since("2.1.0")
+  override def write: MLWriter = {
+    new BucketedRandomProjectionLSHModel.BucketedRandomProjectionLSHModelWriter(this)
+  }
+}
+
+/**
+ * :: Experimental ::
+ *
+ * This [[BucketedRandomProjectionLSH]] implements Locality Sensitive Hashing functions for
+ * Euclidean distance metrics.
+ *
+ * The input is dense or sparse vectors, each of which represents a point in the Euclidean
+ * distance space. The output will be vectors of configurable dimension. Hash values in the
+ * same dimension are calculated by the same hash function.
+ *
+ * References:
+ *
+ * 1. <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions">
+ * Wikipedia on Stable Distributions</a>
+ *
+ * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
+ * arXiv:1408.2927 (2014).
+ */
+@Experimental
+@Since("2.1.0")
+class BucketedRandomProjectionLSH(override val uid: String)
+  extends LSH[BucketedRandomProjectionLSHModel]
+    with BucketedRandomProjectionLSHParams with HasSeed {
+
+  @Since("2.1.0")
+  override def setInputCol(value: String): this.type = super.setInputCol(value)
+
+  @Since("2.1.0")
+  override def setOutputCol(value: String): this.type = super.setOutputCol(value)
+
+  @Since("2.1.0")
+  override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)
+
+  @Since("2.1.0")
+  def this() = {
+    this(Identifiable.randomUID("brp-lsh"))
+  }
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setBucketLength(value: Double): this.type = set(bucketLength, value)
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  @Since("2.1.0")
+  override protected[this] def createRawLSHModel(
+    inputDim: Int): BucketedRandomProjectionLSHModel = {
+    val rand = new Random($(seed))
+    val randUnitVectors: Array[Vector] = {
+      Array.fill($(numHashTables)) {
+        val randArray = Array.fill(inputDim)(rand.nextGaussian())
+        Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
+      }
+    }
+    new BucketedRandomProjectionLSHModel(uid, randUnitVectors)
+  }
+
+  @Since("2.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    validateAndTransformSchema(schema)
+  }
+
+  @Since("2.1.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+@Since("2.1.0")
+object BucketedRandomProjectionLSH extends DefaultParamsReadable[BucketedRandomProjectionLSH] {
+
+  @Since("2.1.0")
+  override def load(path: String): BucketedRandomProjectionLSH = super.load(path)
+}
+
+@Since("2.1.0")
+object BucketedRandomProjectionLSHModel extends MLReadable[BucketedRandomProjectionLSHModel] {
+
+  @Since("2.1.0")
+  override def read: MLReader[BucketedRandomProjectionLSHModel] = {
+    new BucketedRandomProjectionLSHModelReader
+  }
+
+  @Since("2.1.0")
+  override def load(path: String): BucketedRandomProjectionLSHModel = super.load(path)
+
+  private[BucketedRandomProjectionLSHModel] class BucketedRandomProjectionLSHModelWriter(
+    instance: BucketedRandomProjectionLSHModel) extends MLWriter {
+
+    // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
+    private case class Data(randUnitVectors: Matrix)
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val numRows = instance.randUnitVectors.length
+      require(numRows > 0)
+      val numCols = instance.randUnitVectors.head.size
+      val values = instance.randUnitVectors.map(_.toArray).reduce(Array.concat(_, _))
+      val randMatrix = Matrices.dense(numRows, numCols, values)
+      val data = Data(randMatrix)
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class BucketedRandomProjectionLSHModelReader
+    extends MLReader[BucketedRandomProjectionLSHModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[BucketedRandomProjectionLSHModel].getName
+
+    override def load(path: String): BucketedRandomProjectionLSHModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath)
+      val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
+        .select("randUnitVectors")
+        .head()
+      val model = new BucketedRandomProjectionLSHModel(metadata.uid,
+        randUnitVectors.rowIter.toArray)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}
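
As a hedged aside on the hash family above: the mapping `h_i(x) = floor(r_i.dot(x) / bucketLength)`
can be checked by hand. The unit vectors, bucket length, and input point below are the ones used in
the new `hashFunction` test in BucketedRandomProjectionLSHSuite; the standalone computation itself
is only a sketch, not part of the patch.

```scala
import org.apache.spark.ml.linalg.Vectors

// Values taken from the "hashFunction" test in BucketedRandomProjectionLSHSuite.
val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
val bucketLength = 0.5
val key = Vectors.dense(1.23, 4.56)

// h_i(x) = floor(r_i . x / bucketLength): one single-element hash vector per hash table.
val hashes = randUnitVectors.map { r =>
  val dot = r.toArray.zip(key.toArray).map { case (a, b) => a * b }.sum
  Vectors.dense(math.floor(dot / bucketLength))
}
// hashes = Array([9.0], [2.0]),
// since floor(4.56 / 0.5) = 9.0 and floor(1.23 / 0.5) = 2.0
```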

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
index eb117c4..309cc2e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/LSH.scala
@@ -33,28 +33,28 @@ import org.apache.spark.sql.types._
  */
 private[ml] trait LSHParams extends HasInputCol with HasOutputCol {
   /**
-   * Param for the dimension of LSH OR-amplification.
+   * Param for the number of hash tables used in LSH OR-amplification.
    *
-   * In this implementation, we use LSH OR-amplification to reduce the false negative rate. The
-   * higher the dimension is, the lower the false negative rate.
+   * LSH OR-amplification can be used to reduce the false negative rate. Higher values for this
+   * param lead to a reduced false negative rate, at the expense of added computational complexity.
    * @group param
    */
-  final val outputDim: IntParam = new IntParam(this, "outputDim", "output dimension, where" +
-    " increasing dimensionality lowers the false negative rate, and decreasing dimensionality" +
-    " improves the running performance", ParamValidators.gt(0))
+  final val numHashTables: IntParam = new IntParam(this, "numHashTables", "number of hash " +
+    "tables, where increasing number of hash tables lowers the false negative rate, and " +
+    "decreasing it improves the running performance", ParamValidators.gt(0))
 
   /** @group getParam */
-  final def getOutputDim: Int = $(outputDim)
+  final def getNumHashTables: Int = $(numHashTables)
 
-  setDefault(outputDim -> 1)
+  setDefault(numHashTables -> 1)
 
   /**
    * Transform the Schema for LSH
-   * @param schema The schema of the input dataset without [[outputCol]]
-   * @return A derived schema with [[outputCol]] added
+   * @param schema The schema of the input dataset without [[outputCol]].
+   * @return A derived schema with [[outputCol]] added.
    */
   protected[this] final def validateAndTransformSchema(schema: StructType): StructType = {
-    SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+    SchemaUtils.appendColumn(schema, $(outputCol), DataTypes.createArrayType(new VectorUDT))
   }
 }
 
@@ -66,32 +66,32 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   self: T =>
 
   /**
-   * The hash function of LSH, mapping a predefined KeyType to a Vector
+   * The hash function of LSH, mapping an input feature vector to multiple hash vectors.
    * @return The mapping of LSH function.
    */
-  protected[ml] val hashFunction: Vector => Vector
+  protected[ml] val hashFunction: Vector => Array[Vector]
 
   /**
    * Calculate the distance between two different keys using the distance metric corresponding
-   * to the hashFunction
-   * @param x One input vector in the metric space
-   * @param y One input vector in the metric space
-   * @return The distance between x and y
+   * to the hashFunction.
+   * @param x One input vector in the metric space.
+   * @param y One input vector in the metric space.
+   * @return The distance between x and y.
    */
   protected[ml] def keyDistance(x: Vector, y: Vector): Double
 
   /**
    * Calculate the distance between two different hash Vectors.
    *
-   * @param x One of the hash vector
-   * @param y Another hash vector
-   * @return The distance between hash vectors x and y
+   * @param x One of the hash vector.
+   * @param y Another hash vector.
+   * @return The distance between hash vectors x and y.
    */
-  protected[ml] def hashDistance(x: Vector, y: Vector): Double
+  protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double
 
   override def transform(dataset: Dataset[_]): DataFrame = {
     transformSchema(dataset.schema, logging = true)
-    val transformUDF = udf(hashFunction, new VectorUDT)
+    val transformUDF = udf(hashFunction, DataTypes.createArrayType(new VectorUDT))
     dataset.withColumn($(outputCol), transformUDF(dataset($(inputCol))))
   }
 
@@ -99,29 +99,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
     validateAndTransformSchema(schema)
   }
 
-  /**
-   * Given a large dataset and an item, approximately find at most k items which have the closest
-   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
-   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
-   * transformed data when necessary.
-   *
-   * This method implements two ways of fetching k nearest neighbors:
-   *  - Single Probing: Fast, return at most k elements (Probing only one buckets)
-   *  - Multiple Probing: Slow, return exact k elements (Probing multiple buckets close to the key)
-   *
-   * @param dataset the dataset to search for nearest neighbors of the key
-   * @param key Feature vector representing the item to search for
-   * @param numNearestNeighbors The maximum number of nearest neighbors
-   * @param singleProbing True for using Single Probing; false for multiple probing
-   * @param distCol Output column for storing the distance between each result row and the key
-   * @return A dataset containing at most k items closest to the key. A distCol is added to show
-   *         the distance between each row and the key.
-   */
-  def approxNearestNeighbors(
+  // TODO: Fix the MultiProbe NN Search in SPARK-18454
+  private[feature] def approxNearestNeighbors(
       dataset: Dataset[_],
       key: Vector,
       numNearestNeighbors: Int,
-      singleProbing: Boolean,
+      singleProbe: Boolean,
       distCol: String): Dataset[_] = {
     require(numNearestNeighbors > 0, "The number of nearest neighbors cannot be less than 1")
     // Get Hash Value of the key
@@ -132,14 +115,24 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
         dataset.toDF()
       }
 
-    // In the origin dataset, find the hash value that is closest to the key
-    val hashDistUDF = udf((x: Vector) => hashDistance(x, keyHash), DataTypes.DoubleType)
-    val hashDistCol = hashDistUDF(col($(outputCol)))
+    val modelSubset = if (singleProbe) {
+      def sameBucket(x: Seq[Vector], y: Seq[Vector]): Boolean = {
+        x.zip(y).exists(tuple => tuple._1 == tuple._2)
+      }
+
+      // In the original dataset, find the rows whose hash values share a bucket with the key
+      val sameBucketWithKeyUDF = udf((x: Seq[Vector]) =>
+        sameBucket(x, keyHash), DataTypes.BooleanType)
 
-    val modelSubset = if (singleProbing) {
-      modelDataset.filter(hashDistCol === 0.0)
+      modelDataset.filter(sameBucketWithKeyUDF(col($(outputCol))))
     } else {
+      // In the original dataset, find the hash value that is closest to the key
+      // Limit the use of hashDist since it's controversial
+      val hashDistUDF = udf((x: Seq[Vector]) => hashDistance(x, keyHash), DataTypes.DoubleType)
+      val hashDistCol = hashDistUDF(col($(outputCol)))
+
       // Compute threshold to get exact k elements.
+      // TODO: SPARK-18409: Use approxQuantile to get the threshold
       val modelDatasetSortedByHash = modelDataset.sort(hashDistCol).limit(numNearestNeighbors)
       val thresholdDataset = modelDatasetSortedByHash.select(max(hashDistCol))
       val hashThreshold = thresholdDataset.take(1).head.getDouble(0)
@@ -155,8 +148,30 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
   }
 
   /**
-   * Overloaded method for approxNearestNeighbors. Use Single Probing as default way to search
-   * nearest neighbors and "distCol" as default distCol.
+   * Given a large dataset and an item, approximately find at most k items which have the closest
+   * distance to the item. If the [[outputCol]] is missing, the method will transform the data; if
+   * the [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the
+   * transformed data when necessary.
+   *
+   * @note This method is experimental and will likely change behavior in the next release.
+   *
+   * @param dataset The dataset to search for nearest neighbors of the key.
+   * @param key Feature vector representing the item to search for.
+   * @param numNearestNeighbors The maximum number of nearest neighbors.
+   * @param distCol Output column for storing the distance between each result row and the key.
+   * @return A dataset containing at most k items closest to the key. A column "distCol" is added
+   *         to show the distance between each row and the key.
+   */
+  def approxNearestNeighbors(
+    dataset: Dataset[_],
+    key: Vector,
+    numNearestNeighbors: Int,
+    distCol: String): Dataset[_] = {
+    approxNearestNeighbors(dataset, key, numNearestNeighbors, true, distCol)
+  }
+
+  /**
+   * Overloaded method for approxNearestNeighbors. Use "distCol" as default distCol.
    */
   def approxNearestNeighbors(
       dataset: Dataset[_],
@@ -172,31 +187,28 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    *
    * @param dataset The dataset to transform and explode.
    * @param explodeCols The alias for the exploded columns, must be a seq of two strings.
-   * @return A dataset containing idCol, inputCol and explodeCols
+   * @return A dataset containing idCol, inputCol and explodeCols.
    */
   private[this] def processDataset(
       dataset: Dataset[_],
       inputName: String,
       explodeCols: Seq[String]): Dataset[_] = {
     require(explodeCols.size == 2, "explodeCols must be two strings.")
-    val vectorToMap = udf((x: Vector) => x.asBreeze.iterator.toMap,
-      MapType(DataTypes.IntegerType, DataTypes.DoubleType))
     val modelDataset: DataFrame = if (!dataset.columns.contains($(outputCol))) {
       transform(dataset)
     } else {
       dataset.toDF()
     }
     modelDataset.select(
-      struct(col("*")).as(inputName),
-      explode(vectorToMap(col($(outputCol)))).as(explodeCols))
+      struct(col("*")).as(inputName), posexplode(col($(outputCol))).as(explodeCols))
   }
 
   /**
    * Recreate a column using the same column name but different attribute id. Used in approximate
    * similarity join.
-   * @param dataset The dataset where a column need to recreate
-   * @param colName The name of the column to recreate
-   * @param tmpColName A temporary column name which does not conflict with existing columns
+   * @param dataset The dataset where a column need to recreate.
+   * @param colName The name of the column to recreate.
+   * @param tmpColName A temporary column name which does not conflict with existing columns.
    * @return
    */
   private[this] def recreateCol(
@@ -215,12 +227,12 @@ private[ml] abstract class LSHModel[T <: LSHModel[T]]
    * [[outputCol]] exists, it will use the [[outputCol]]. This allows caching of the transformed
    * data when necessary.
    *
-   * @param datasetA One of the datasets to join
-   * @param datasetB Another dataset to join
-   * @param threshold The threshold for the distance of row pairs
-   * @param distCol Output column for storing the distance between each result row and the key
+   * @param datasetA One of the datasets to join.
+   * @param datasetB Another dataset to join.
+   * @param threshold The threshold for the distance of row pairs.
+   * @param distCol Output column for storing the distance between each result row and the key.
    * @return A joined dataset containing pairs of rows. The original rows are in columns
-   *         "datasetA" and "datasetB", and a distCol is added to show the distance of each pair
+   *         "datasetA" and "datasetB", and a distCol is added to show the distance of each pair.
    */
   def approxSimilarityJoin(
       datasetA: Dataset[_],
@@ -293,7 +305,7 @@ private[ml] abstract class LSH[T <: LSHModel[T]]
   def setOutputCol(value: String): this.type = set(outputCol, value)
 
   /** @group setParam */
-  def setOutputDim(value: Int): this.type = set(outputDim, value)
+  def setNumHashTables(value: Int): this.type = set(numHashTables, value)
 
   /**
    * Validate and create a new instance of concrete LSHModel. Because different LSHModel may have
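
A hedged sketch of how the public search methods that remain on LSHModel are meant to be called
after this change. `model`, `dfA`, and `dfB` are assumptions for the example (a fitted
BucketedRandomProjectionLSHModel and two DataFrames whose vector column matches the model's input
column); the method names and the "datasetA"/"datasetB"/distCol output columns are the ones
documented above.

```scala
import org.apache.spark.ml.linalg.Vectors

// Approximate kNN. Only single-probe search stays public; multi-probe search is
// private[feature] until SPARK-18454 is settled.
val key = Vectors.dense(1.2, 3.4)
val neighbors = model.approxNearestNeighbors(dfA, key, 5, "distCol")

// Approximate similarity join: row pairs whose key distance is below the threshold.
val joined = model.approxSimilarityJoin(dfA, dfB, 1.5, "distCol")
joined.select("datasetA", "datasetB", "distCol").show(false)
```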

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
deleted file mode 100644
index f37233e..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHash.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import scala.util.Random
-
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
-import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.HasSeed
-import org.apache.spark.ml.util._
-import org.apache.spark.sql.types.StructType
-
-/**
- * :: Experimental ::
- *
- * Model produced by [[MinHash]], where multiple hash functions are stored. Each hash function is
- * a perfect hash function:
- *    `h_i(x) = (x * k_i mod prime) mod numEntries`
- * where `k_i` is the i-th coefficient, and both `x` and `k_i` are from `Z_prime^*`
- *
- * Reference:
- * <a href="https://en.wikipedia.org/wiki/Perfect_hash_function">
- * Wikipedia on Perfect Hash Function</a>
- *
- * @param numEntries The number of entries of the hash functions.
- * @param randCoefficients An array of random coefficients, each used by one hash function.
- */
-@Experimental
-@Since("2.1.0")
-class MinHashModel private[ml] (
-    override val uid: String,
-    @Since("2.1.0") val numEntries: Int,
-    @Since("2.1.0") val randCoefficients: Array[Int])
-  extends LSHModel[MinHashModel] {
-
-  @Since("2.1.0")
-  override protected[ml] val hashFunction: Vector => Vector = {
-    elems: Vector =>
-      require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
-      val elemsList = elems.toSparse.indices.toList
-      val hashValues = randCoefficients.map({ randCoefficient: Int =>
-          elemsList.map({elem: Int =>
-            (1 + elem) * randCoefficient.toLong % MinHash.prime % numEntries
-          }).min.toDouble
-      })
-      Vectors.dense(hashValues)
-  }
-
-  @Since("2.1.0")
-  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
-    val xSet = x.toSparse.indices.toSet
-    val ySet = y.toSparse.indices.toSet
-    val intersectionSize = xSet.intersect(ySet).size.toDouble
-    val unionSize = xSet.size + ySet.size - intersectionSize
-    assert(unionSize > 0, "The union of two input sets must have at least 1 elements")
-    1 - intersectionSize / unionSize
-  }
-
-  @Since("2.1.0")
-  override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
-    // Since it's generated by hashing, it will be a pair of dense vectors.
-    x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
-  }
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
-
-  @Since("2.1.0")
-  override def write: MLWriter = new MinHashModel.MinHashModelWriter(this)
-}
-
-/**
- * :: Experimental ::
- *
- * LSH class for Jaccard distance.
- *
- * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
- *    `Vectors.sparse(10, Array[(2, 1.0), (3, 1.0), (5, 1.0)])`
- * means there are 10 elements in the space. This set contains elem 2, elem 3 and elem 5.
- * Also, any input vector must have at least 1 non-zero indices, and all non-zero values are treated
- * as binary "1" values.
- *
- * References:
- * <a href="https://en.wikipedia.org/wiki/MinHash">Wikipedia on MinHash</a>
- */
-@Experimental
-@Since("2.1.0")
-class MinHash(override val uid: String) extends LSH[MinHashModel] with HasSeed {
-
-
-  @Since("2.1.0")
-  override def setInputCol(value: String): this.type = super.setInputCol(value)
-
-  @Since("2.1.0")
-  override def setOutputCol(value: String): this.type = super.setOutputCol(value)
-
-  @Since("2.1.0")
-  override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
-
-  @Since("2.1.0")
-  def this() = {
-    this(Identifiable.randomUID("min hash"))
-  }
-
-  /** @group setParam */
-  @Since("2.1.0")
-  def setSeed(value: Long): this.type = set(seed, value)
-
-  @Since("2.1.0")
-  override protected[ml] def createRawLSHModel(inputDim: Int): MinHashModel = {
-    require(inputDim <= MinHash.prime / 2,
-      s"The input vector dimension $inputDim exceeds the threshold ${MinHash.prime / 2}.")
-    val rand = new Random($(seed))
-    val numEntry = inputDim * 2
-    val randCoofs: Array[Int] = Array.fill($(outputDim))(1 + rand.nextInt(MinHash.prime - 1))
-    new MinHashModel(uid, numEntry, randCoofs)
-  }
-
-  @Since("2.1.0")
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
-    validateAndTransformSchema(schema)
-  }
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
-}
-
-@Since("2.1.0")
-object MinHash extends DefaultParamsReadable[MinHash] {
-  // A large prime smaller than sqrt(2^63 − 1)
-  private[ml] val prime = 2038074743
-
-  @Since("2.1.0")
-  override def load(path: String): MinHash = super.load(path)
-}
-
-@Since("2.1.0")
-object MinHashModel extends MLReadable[MinHashModel] {
-
-  @Since("2.1.0")
-  override def read: MLReader[MinHashModel] = new MinHashModelReader
-
-  @Since("2.1.0")
-  override def load(path: String): MinHashModel = super.load(path)
-
-  private[MinHashModel] class MinHashModelWriter(instance: MinHashModel) extends MLWriter {
-
-    private case class Data(numEntries: Int, randCoefficients: Array[Int])
-
-    override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val data = Data(instance.numEntries, instance.randCoefficients)
-      val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
-    }
-  }
-
-  private class MinHashModelReader extends MLReader[MinHashModel] {
-
-    /** Checked against metadata when loading model */
-    private val className = classOf[MinHashModel].getName
-
-    override def load(path: String): MinHashModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-
-      val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath).select("numEntries", "randCoefficients").head()
-      val numEntries = data.getAs[Int](0)
-      val randCoefficients = data.getAs[Seq[Int]](1).toArray
-      val model = new MinHashModel(metadata.uid, numEntries, randCoefficients)
-
-      DefaultParamsReader.getAndSetParams(model, metadata)
-      model
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
new file mode 100644
index 0000000..620e1fb
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/MinHashLSH.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import scala.util.Random
+
+import org.apache.hadoop.fs.Path
+
+import org.apache.spark.annotation.{Experimental, Since}
+import org.apache.spark.ml.linalg.{Vector, Vectors, VectorUDT}
+import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasSeed
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.types.StructType
+
+/**
+ * :: Experimental ::
+ *
+ * Model produced by [[MinHashLSH]], where multiple hash functions are stored. Each hash function
+ * is picked from the following family of hash functions, where a_i and b_i are randomly chosen
+ * integers less than prime:
+ *    `h_i(x) = ((x \cdot a_i + b_i) \mod prime)`
+ *
+ * This hash family is approximately min-wise independent according to the reference.
+ *
+ * Reference:
+ * Tom Bohman, Colin Cooper, and Alan Frieze. "Min-wise independent linear permutations."
+ * Electronic Journal of Combinatorics 7 (2000): R26.
+ *
+ * @param randCoefficients Pairs of random coefficients. Each pair is used by one hash function.
+ */
+@Experimental
+@Since("2.1.0")
+class MinHashLSHModel private[ml](
+    override val uid: String,
+    private[ml] val randCoefficients: Array[(Int, Int)])
+  extends LSHModel[MinHashLSHModel] {
+
+  @Since("2.1.0")
+  override protected[ml] val hashFunction: Vector => Array[Vector] = {
+    elems: Vector => {
+      require(elems.numNonzeros > 0, "Must have at least 1 non zero entry.")
+      val elemsList = elems.toSparse.indices.toList
+      val hashValues = randCoefficients.map { case (a, b) =>
+        elemsList.map { elem: Int =>
+          ((1 + elem) * a + b) % MinHashLSH.HASH_PRIME
+        }.min.toDouble
+      }
+      // TODO: Output vectors of dimension numHashFunctions in SPARK-18450
+      hashValues.map(Vectors.dense(_))
+    }
+  }
+
+  @Since("2.1.0")
+  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
+    val xSet = x.toSparse.indices.toSet
+    val ySet = y.toSparse.indices.toSet
+    val intersectionSize = xSet.intersect(ySet).size.toDouble
+    val unionSize = xSet.size + ySet.size - intersectionSize
+    assert(unionSize > 0, "The union of two input sets must have at least 1 element")
+    1 - intersectionSize / unionSize
+  }
+
+  @Since("2.1.0")
+  override protected[ml] def hashDistance(x: Seq[Vector], y: Seq[Vector]): Double = {
+    // Since it's generated by hashing, it will be a pair of dense vectors.
+    // TODO: This hashDistance function requires more discussion in SPARK-18454
+    x.zip(y).map(vectorPair =>
+      vectorPair._1.toArray.zip(vectorPair._2.toArray).count(pair => pair._1 != pair._2)
+    ).min
+  }
+
+  @Since("2.1.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+
+  @Since("2.1.0")
+  override def write: MLWriter = new MinHashLSHModel.MinHashLSHModelWriter(this)
+}
+
+/**
+ * :: Experimental ::
+ *
+ * LSH class for Jaccard distance.
+ *
+ * The input can be dense or sparse vectors, but it is more efficient if it is sparse. For example,
+ *    `Vectors.sparse(10, Array((2, 1.0), (3, 1.0), (5, 1.0)))`
+ * means there are 10 elements in the space. This set contains elements 2, 3, and 5. Also, any
+ * input vector must have at least 1 non-zero index, and all non-zero values are
+ * treated as binary "1" values.
+ *
+ * References:
+ * <a href="https://en.wikipedia.org/wiki/MinHash">Wikipedia on MinHash</a>
+ */
+@Experimental
+@Since("2.1.0")
+class MinHashLSH(override val uid: String) extends LSH[MinHashLSHModel] with HasSeed {
+
+  @Since("2.1.0")
+  override def setInputCol(value: String): this.type = super.setInputCol(value)
+
+  @Since("2.1.0")
+  override def setOutputCol(value: String): this.type = super.setOutputCol(value)
+
+  @Since("2.1.0")
+  override def setNumHashTables(value: Int): this.type = super.setNumHashTables(value)
+
+  @Since("2.1.0")
+  def this() = {
+    this(Identifiable.randomUID("mh-lsh"))
+  }
+
+  /** @group setParam */
+  @Since("2.1.0")
+  def setSeed(value: Long): this.type = set(seed, value)
+
+  @Since("2.1.0")
+  override protected[ml] def createRawLSHModel(inputDim: Int): MinHashLSHModel = {
+    require(inputDim <= MinHashLSH.HASH_PRIME,
+      s"The input vector dimension $inputDim exceeds the threshold ${MinHashLSH.HASH_PRIME}.")
+    val rand = new Random($(seed))
+    val randCoefs: Array[(Int, Int)] = Array.fill($(numHashTables)) {
+        (1 + rand.nextInt(MinHashLSH.HASH_PRIME - 1), rand.nextInt(MinHashLSH.HASH_PRIME - 1))
+      }
+    new MinHashLSHModel(uid, randCoefs)
+  }
+
+  @Since("2.1.0")
+  override def transformSchema(schema: StructType): StructType = {
+    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
+    validateAndTransformSchema(schema)
+  }
+
+  @Since("2.1.0")
+  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
+}
+
+@Since("2.1.0")
+object MinHashLSH extends DefaultParamsReadable[MinHashLSH] {
+  // A large prime smaller than sqrt(2^63 − 1)
+  private[ml] val HASH_PRIME = 2038074743
+
+  @Since("2.1.0")
+  override def load(path: String): MinHashLSH = super.load(path)
+}
+
+@Since("2.1.0")
+object MinHashLSHModel extends MLReadable[MinHashLSHModel] {
+
+  @Since("2.1.0")
+  override def read: MLReader[MinHashLSHModel] = new MinHashLSHModelReader
+
+  @Since("2.1.0")
+  override def load(path: String): MinHashLSHModel = super.load(path)
+
+  private[MinHashLSHModel] class MinHashLSHModelWriter(instance: MinHashLSHModel)
+    extends MLWriter {
+
+    private case class Data(randCoefficients: Array[Int])
+
+    override protected def saveImpl(path: String): Unit = {
+      DefaultParamsWriter.saveMetadata(instance, path, sc)
+      val data = Data(instance.randCoefficients.flatMap(tuple => Array(tuple._1, tuple._2)))
+      val dataPath = new Path(path, "data").toString
+      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
+    }
+  }
+
+  private class MinHashLSHModelReader extends MLReader[MinHashLSHModel] {
+
+    /** Checked against metadata when loading model */
+    private val className = classOf[MinHashLSHModel].getName
+
+    override def load(path: String): MinHashLSHModel = {
+      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
+
+      val dataPath = new Path(path, "data").toString
+      val data = sparkSession.read.parquet(dataPath).select("randCoefficients").head()
+      val randCoefficients = data.getAs[Seq[Int]](0).grouped(2)
+        .map(tuple => (tuple(0), tuple(1))).toArray
+      val model = new MinHashLSHModel(metadata.uid, randCoefficients)
+
+      DefaultParamsReader.getAndSetParams(model, metadata)
+      model
+    }
+  }
+}
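
As with the bucketed random projection file above, a hedged usage sketch for the renamed MinHashLSH.
The sparse-vector encoding follows the class doc above, while the data, column names, and parameter
values are made up for the example (a SparkSession `spark` is assumed).

```scala
import org.apache.spark.ml.feature.MinHashLSH
import org.apache.spark.ml.linalg.Vectors

// Each set is a sparse binary vector over a universe of 10 elements; the first
// row is the set {2, 3, 5}, and all non-zero values are treated as "1".
val df = spark.createDataFrame(Seq(
  Tuple1(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0)))),
  Tuple1(Vectors.sparse(10, Seq((3, 1.0), (5, 1.0), (7, 1.0))))
)).toDF("keys")

val mh = new MinHashLSH()     // formerly MinHash
  .setNumHashTables(3)        // formerly setOutputDim
  .setInputCol("keys")
  .setOutputCol("values")
  .setSeed(12345)

// "values" holds an Array of 3 one-element hash vectors per row, one per hash table.
mh.fit(df).transform(df).show(false)
```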

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
----------------------------------------------------------------------
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
deleted file mode 100644
index 2bff59a..0000000
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/RandomProjection.scala
+++ /dev/null
@@ -1,225 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import scala.util.Random
-
-import breeze.linalg.normalize
-import org.apache.hadoop.fs.Path
-
-import org.apache.spark.annotation.{Experimental, Since}
-import org.apache.spark.ml.linalg._
-import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.HasSeed
-import org.apache.spark.ml.util._
-import org.apache.spark.mllib.util.MLUtils
-import org.apache.spark.sql.Row
-import org.apache.spark.sql.types.StructType
-
-/**
- * :: Experimental ::
- *
- * Params for [[RandomProjection]].
- */
-private[ml] trait RandomProjectionParams extends Params {
-
-  /**
-   * The length of each hash bucket, a larger bucket lowers the false negative rate. The number of
-   * buckets will be `(max L2 norm of input vectors) / bucketLength`.
-   *
-   *
-   * If input vectors are normalized, 1-10 times of pow(numRecords, -1/inputDim) would be a
-   * reasonable value
-   * @group param
-   */
-  val bucketLength: DoubleParam = new DoubleParam(this, "bucketLength",
-    "the length of each hash bucket, a larger bucket lowers the false negative rate.",
-    ParamValidators.gt(0))
-
-  /** @group getParam */
-  final def getBucketLength: Double = $(bucketLength)
-}
-
-/**
- * :: Experimental ::
- *
- * Model produced by [[RandomProjection]], where multiple random vectors are stored. The vectors
- * are normalized to be unit vectors and each vector is used in a hash function:
- *    `h_i(x) = floor(r_i.dot(x) / bucketLength)`
- * where `r_i` is the i-th random unit vector. The number of buckets will be `(max L2 norm of input
- * vectors) / bucketLength`.
- *
- * @param randUnitVectors An array of random unit vectors. Each vector represents a hash function.
- */
-@Experimental
-@Since("2.1.0")
-class RandomProjectionModel private[ml] (
-    override val uid: String,
-    @Since("2.1.0") val randUnitVectors: Array[Vector])
-  extends LSHModel[RandomProjectionModel] with RandomProjectionParams {
-
-  @Since("2.1.0")
-  override protected[ml] val hashFunction: (Vector) => Vector = {
-    key: Vector => {
-      val hashValues: Array[Double] = randUnitVectors.map({
-        randUnitVector => Math.floor(BLAS.dot(key, randUnitVector) / $(bucketLength))
-      })
-      Vectors.dense(hashValues)
-    }
-  }
-
-  @Since("2.1.0")
-  override protected[ml] def keyDistance(x: Vector, y: Vector): Double = {
-    Math.sqrt(Vectors.sqdist(x, y))
-  }
-
-  @Since("2.1.0")
-  override protected[ml] def hashDistance(x: Vector, y: Vector): Double = {
-    // Since it's generated by hashing, it will be a pair of dense vectors.
-    x.toDense.values.zip(y.toDense.values).map(pair => math.abs(pair._1 - pair._2)).min
-  }
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
-
-  @Since("2.1.0")
-  override def write: MLWriter = new RandomProjectionModel.RandomProjectionModelWriter(this)
-}
-
-/**
- * :: Experimental ::
- *
- * This [[RandomProjection]] implements Locality Sensitive Hashing functions for Euclidean
- * distance metrics.
- *
- * The input is dense or sparse vectors, each of which represents a point in the Euclidean
- * distance space. The output will be vectors of configurable dimension. Hash value in the same
- * dimension is calculated by the same hash function.
- *
- * References:
- *
- * 1. <a href="https://en.wikipedia.org/wiki/Locality-sensitive_hashing#Stable_distributions">
- * Wikipedia on Stable Distributions</a>
- *
- * 2. Wang, Jingdong et al. "Hashing for similarity search: A survey." arXiv preprint
- * arXiv:1408.2927 (2014).
- */
-@Experimental
-@Since("2.1.0")
-class RandomProjection(override val uid: String) extends LSH[RandomProjectionModel]
-  with RandomProjectionParams with HasSeed {
-
-  @Since("2.1.0")
-  override def setInputCol(value: String): this.type = super.setInputCol(value)
-
-  @Since("2.1.0")
-  override def setOutputCol(value: String): this.type = super.setOutputCol(value)
-
-  @Since("2.1.0")
-  override def setOutputDim(value: Int): this.type = super.setOutputDim(value)
-
-  @Since("2.1.0")
-  def this() = {
-    this(Identifiable.randomUID("random projection"))
-  }
-
-  /** @group setParam */
-  @Since("2.1.0")
-  def setBucketLength(value: Double): this.type = set(bucketLength, value)
-
-  /** @group setParam */
-  @Since("2.1.0")
-  def setSeed(value: Long): this.type = set(seed, value)
-
-  @Since("2.1.0")
-  override protected[this] def createRawLSHModel(inputDim: Int): RandomProjectionModel = {
-    val rand = new Random($(seed))
-    val randUnitVectors: Array[Vector] = {
-      Array.fill($(outputDim)) {
-        val randArray = Array.fill(inputDim)(rand.nextGaussian())
-        Vectors.fromBreeze(normalize(breeze.linalg.Vector(randArray)))
-      }
-    }
-    new RandomProjectionModel(uid, randUnitVectors)
-  }
-
-  @Since("2.1.0")
-  override def transformSchema(schema: StructType): StructType = {
-    SchemaUtils.checkColumnType(schema, $(inputCol), new VectorUDT)
-    validateAndTransformSchema(schema)
-  }
-
-  @Since("2.1.0")
-  override def copy(extra: ParamMap): this.type = defaultCopy(extra)
-}
-
-@Since("2.1.0")
-object RandomProjection extends DefaultParamsReadable[RandomProjection] {
-
-  @Since("2.1.0")
-  override def load(path: String): RandomProjection = super.load(path)
-}
-
-@Since("2.1.0")
-object RandomProjectionModel extends MLReadable[RandomProjectionModel] {
-
-  @Since("2.1.0")
-  override def read: MLReader[RandomProjectionModel] = new RandomProjectionModelReader
-
-  @Since("2.1.0")
-  override def load(path: String): RandomProjectionModel = super.load(path)
-
-  private[RandomProjectionModel] class RandomProjectionModelWriter(instance: RandomProjectionModel)
-    extends MLWriter {
-
-    // TODO: Save using the existing format of Array[Vector] once SPARK-12878 is resolved.
-    private case class Data(randUnitVectors: Matrix)
-
-    override protected def saveImpl(path: String): Unit = {
-      DefaultParamsWriter.saveMetadata(instance, path, sc)
-      val numRows = instance.randUnitVectors.length
-      require(numRows > 0)
-      val numCols = instance.randUnitVectors.head.size
-      val values = instance.randUnitVectors.map(_.toArray).reduce(Array.concat(_, _))
-      val randMatrix = Matrices.dense(numRows, numCols, values)
-      val data = Data(randMatrix)
-      val dataPath = new Path(path, "data").toString
-      sparkSession.createDataFrame(Seq(data)).repartition(1).write.parquet(dataPath)
-    }
-  }
-
-  private class RandomProjectionModelReader extends MLReader[RandomProjectionModel] {
-
-    /** Checked against metadata when loading model */
-    private val className = classOf[RandomProjectionModel].getName
-
-    override def load(path: String): RandomProjectionModel = {
-      val metadata = DefaultParamsReader.loadMetadata(path, sc, className)
-
-      val dataPath = new Path(path, "data").toString
-      val data = sparkSession.read.parquet(dataPath)
-      val Row(randUnitVectors: Matrix) = MLUtils.convertMatrixColumnsToML(data, "randUnitVectors")
-        .select("randUnitVectors")
-        .head()
-      val model = new RandomProjectionModel(metadata.uid, randUnitVectors.rowIter.toArray)
-
-      DefaultParamsReader.getAndSetParams(model, metadata)
-      model
-    }
-  }
-}

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
new file mode 100644
index 0000000..ab93768
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/BucketedRandomProjectionLSHSuite.scala
@@ -0,0 +1,213 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import breeze.numerics.{cos, sin}
+import breeze.numerics.constants.Pi
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.TestingUtils._
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+class BucketedRandomProjectionLSHSuite
+  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val data = {
+      for (i <- -10 until 10; j <- -10 until 10) yield Vectors.dense(i.toDouble, j.toDouble)
+    }
+    dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new BucketedRandomProjectionLSH)
+    val model = new BucketedRandomProjectionLSHModel(
+      "brp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
+    ParamsSuite.checkParams(model)
+  }
+
+  test("BucketedRandomProjectionLSH: default params") {
+    val brp = new BucketedRandomProjectionLSH
+    assert(brp.getNumHashTables === 1.0)
+  }
+
+  test("read/write") {
+    def checkModelData(
+      model: BucketedRandomProjectionLSHModel,
+      model2: BucketedRandomProjectionLSHModel): Unit = {
+      model.randUnitVectors.zip(model2.randUnitVectors)
+        .foreach(pair => assert(pair._1 === pair._2))
+    }
+    val mh = new BucketedRandomProjectionLSH()
+    val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
+    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+  }
+
+  test("hashFunction") {
+    val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
+    val model = new BucketedRandomProjectionLSHModel("brp", randUnitVectors)
+    model.set(model.bucketLength, 0.5)
+    val res = model.hashFunction(Vectors.dense(1.23, 4.56))
+    assert(res.length == 2)
+    assert(res(0).equals(Vectors.dense(9.0)))
+    assert(res(1).equals(Vectors.dense(2.0)))
+  }
+
+  test("keyDistance") {
+    val model = new BucketedRandomProjectionLSHModel("brp", Array(Vectors.dense(0.0, 1.0)))
+    val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2))
+    assert(keyDist === 5)
+  }
+
+  test("BucketedRandomProjectionLSH: randUnitVectors") {
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(20)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(1.0)
+      .setSeed(12345)
+    val unitVectors = brp.fit(dataset).randUnitVectors
+    unitVectors.foreach { v: Vector =>
+      assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
+    }
+  }
+
+  test("BucketedRandomProjectionLSH: test of LSH property") {
+    // Project from 2 dimensional Euclidean Space to 1 dimensions
+    val brp = new BucketedRandomProjectionLSH()
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(1.0)
+      .setSeed(12345)
+
+    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, brp, 8.0, 2.0)
+    assert(falsePositive < 0.4)
+    assert(falseNegative < 0.4)
+  }
+
+  test("BucketedRandomProjectionLSH with high dimension data: test of LSH property") {
+    val numDim = 100
+    val data = {
+      for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
+        yield Vectors.sparse(numDim, Seq((i, j.toDouble)))
+    }
+    val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+
+    // Project from 100 dimensional Euclidean Space to 10 dimensions
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(10)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(2.5)
+      .setSeed(12345)
+
+    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, brp, 3.0, 2.0)
+    assert(falsePositive < 0.3)
+    assert(falseNegative < 0.3)
+  }
+
+  test("approxNearestNeighbors for bucketed random projection") {
+    val key = Vectors.dense(1.2, 3.4)
+
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(2)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(4.0)
+      .setSeed(12345)
+
+    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(brp, dataset, key, 100,
+      singleProbe = true)
+    assert(precision >= 0.6)
+    assert(recall >= 0.6)
+  }
+
+  test("approxNearestNeighbors with multiple probing") {
+    val key = Vectors.dense(1.2, 3.4)
+
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(20)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(1.0)
+      .setSeed(12345)
+
+    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(brp, dataset, key, 100,
+      singleProbe = false)
+    assert(precision >= 0.7)
+    assert(recall >= 0.7)
+  }
+
+  test("approxNearestNeighbors for numNeighbors <= 0") {
+    val key = Vectors.dense(1.2, 3.4)
+
+    val model = new BucketedRandomProjectionLSHModel(
+      "brp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
+
+    intercept[IllegalArgumentException] {
+      model.approxNearestNeighbors(dataset, key, 0)
+    }
+    intercept[IllegalArgumentException] {
+      model.approxNearestNeighbors(dataset, key, -1)
+    }
+  }
+
+  test("approxSimilarityJoin for bucketed random projection on different dataset") {
+    val data2 = {
+      for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
+    }
+    val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
+
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(2)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(4.0)
+      .setSeed(12345)
+
+    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, dataset, dataset2, 1.0)
+    assert(precision == 1.0)
+    assert(recall >= 0.7)
+  }
+
+  test("approxSimilarityJoin for self join") {
+    val data = {
+      for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
+    }
+    val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+
+    val brp = new BucketedRandomProjectionLSH()
+      .setNumHashTables(2)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setBucketLength(4.0)
+      .setSeed(12345)
+
+    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(brp, df, df, 3.0)
+    assert(precision == 1.0)
+    assert(recall >= 0.7)
+  }
+}

http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
index 5c02554..a9b559f 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/LSHTest.scala
@@ -58,12 +58,18 @@ private[ml] object LSHTest {
     val outputCol = model.getOutputCol
     val transformedData = model.transform(dataset)
 
-    SchemaUtils.checkColumnType(transformedData.schema, model.getOutputCol, new VectorUDT)
+    // Check output column type
+    SchemaUtils.checkColumnType(
+      transformedData.schema, model.getOutputCol, DataTypes.createArrayType(new VectorUDT))
+
+    // Check output column dimensions
+    val headHashValue = transformedData.select(outputCol).head().get(0).asInstanceOf[Seq[Vector]]
+    assert(headHashValue.length == model.getNumHashTables)
 
     // Perform a cross join and label each pair of same_bucket and distance
     val pairs = transformedData.as("a").crossJoin(transformedData.as("b"))
     val distUDF = udf((x: Vector, y: Vector) => model.keyDistance(x, y), DataTypes.DoubleType)
-    val sameBucket = udf((x: Vector, y: Vector) => model.hashDistance(x, y) == 0.0,
+    val sameBucket = udf((x: Seq[Vector], y: Seq[Vector]) => model.hashDistance(x, y) == 0.0,
       DataTypes.BooleanType)
     val result = pairs
       .withColumn("same_bucket", sameBucket(col(s"a.$outputCol"), col(s"b.$outputCol")))
@@ -83,6 +89,7 @@ private[ml] object LSHTest {
    * @param dataset the dataset to look for the key
    * @param key The key to hash for the item
    * @param k The maximum number of items closest to the key
+   * @param singleProbe True for using single-probe; false for multi-probe
    * @tparam T The class type of lsh
    * @return A tuple of two doubles, representing precision and recall rate
    */
@@ -91,7 +98,7 @@ private[ml] object LSHTest {
       dataset: Dataset[_],
       key: Vector,
       k: Int,
-      singleProbing: Boolean): (Double, Double) = {
+      singleProbe: Boolean): (Double, Double) = {
     val model = lsh.fit(dataset)
 
     // Compute expected
@@ -99,14 +106,14 @@ private[ml] object LSHTest {
     val expected = dataset.sort(distUDF(col(model.getInputCol))).limit(k)
 
     // Compute actual
-    val actual = model.approxNearestNeighbors(dataset, key, k, singleProbing, "distCol")
+    val actual = model.approxNearestNeighbors(dataset, key, k, singleProbe, "distCol")
 
     assert(actual.schema.sameType(model
       .transformSchema(dataset.schema)
       .add("distCol", DataTypes.DoubleType))
     )
 
-    if (!singleProbing) {
+    if (!singleProbe) {
       assert(actual.count() == k)
     }
 

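With the output column type changing from Vector to an array of vectors, downstream code reads the hash column as a Seq[Vector] with one entry per hash table, which is exactly what the new checks above assert. A small end-to-end sketch under those assumptions (the toy data, app name and column names are illustrative; "keys"/"values" just follow the suites' convention):

    import org.apache.spark.ml.feature.BucketedRandomProjectionLSH
    import org.apache.spark.ml.linalg.{Vector, Vectors}
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder().master("local[2]").appName("lsh-output-schema").getOrCreate()

    // Toy data, following the suites' Tuple1/"keys" convention.
    val data = Seq(Vectors.dense(1.0, 1.0), Vectors.dense(1.0, -1.0), Vectors.dense(-1.0, 1.0))
    val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")

    val brp = new BucketedRandomProjectionLSH()
      .setNumHashTables(3)
      .setInputCol("keys")
      .setOutputCol("values")
      .setBucketLength(2.0)
      .setSeed(12345)

    val hashed = brp.fit(df).transform(df)

    // Each "values" cell is an array of numHashTables one-element vectors, mirroring
    // the column-type and length checks added to LSHTest above.
    val headHashes = hashed.select("values").head().getAs[Seq[Vector]](0)
    assert(headHashes.length == brp.getNumHashTables)
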
http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
new file mode 100644
index 0000000..3461cdf
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashLSHSuite.scala
@@ -0,0 +1,161 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Dataset
+
+class MinHashLSHSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+  @transient var dataset: Dataset[_] = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+
+    val data = {
+      for (i <- 0 to 95) yield Vectors.sparse(100, (i until i + 5).map((_, 1.0)))
+    }
+    dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+  }
+
+  test("params") {
+    ParamsSuite.checkParams(new MinHashLSH)
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+    ParamsSuite.checkParams(model)
+  }
+
+  test("MinHashLSH: default params") {
+    val rp = new MinHashLSH
+    assert(rp.getNumHashTables === 1.0)
+  }
+
+  test("read/write") {
+    def checkModelData(model: MinHashLSHModel, model2: MinHashLSHModel): Unit = {
+      assertResult(model.randCoefficients)(model2.randCoefficients)
+    }
+    val mh = new MinHashLSH()
+    val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
+    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
+  }
+
+  test("hashFunction") {
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
+    val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
+    assert(res.length == 3)
+    assert(res(0).equals(Vectors.dense(1.0)))
+    assert(res(1).equals(Vectors.dense(5.0)))
+    assert(res(2).equals(Vectors.dense(9.0)))
+  }
+
+  test("hashFunction: empty vector") {
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((0, 1), (1, 2), (3, 0)))
+    intercept[IllegalArgumentException] {
+      model.hashFunction(Vectors.sparse(10, Seq()))
+    }
+  }
+
+  test("keyDistance") {
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+    val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
+    val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
+    val keyDist = model.keyDistance(v1, v2)
+    assert(keyDist === 0.5)
+  }
+
+  test("MinHashLSH: test of LSH property") {
+    val mh = new MinHashLSH()
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setSeed(12344)
+
+    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, mh, 0.75, 0.5)
+    assert(falsePositive < 0.3)
+    assert(falseNegative < 0.3)
+  }
+
+  test("MinHashLSH: test of inputDim > prime") {
+    val mh = new MinHashLSH()
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setSeed(12344)
+
+    val data = {
+      for (i <- 0 to 2) yield Vectors.sparse(Int.MaxValue, (i until i + 5).map((_, 1.0)))
+    }
+    val badDataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
+    intercept[IllegalArgumentException] {
+      mh.fit(badDataset)
+    }
+  }
+
+  test("approxNearestNeighbors for min hash") {
+    val mh = new MinHashLSH()
+      .setNumHashTables(20)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setSeed(12345)
+
+    val key: Vector = Vectors.sparse(100,
+      (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
+
+    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
+      singleProbe = true)
+    assert(precision >= 0.7)
+    assert(recall >= 0.7)
+  }
+
+  test("approxNearestNeighbors for numNeighbors <= 0") {
+    val model = new MinHashLSHModel("mh", randCoefficients = Array((1, 0)))
+
+    val key: Vector = Vectors.sparse(100,
+      (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
+
+    intercept[IllegalArgumentException] {
+      model.approxNearestNeighbors(dataset, key, 0)
+    }
+    intercept[IllegalArgumentException] {
+      model.approxNearestNeighbors(dataset, key, -1)
+    }
+  }
+
+  test("approxSimilarityJoin for min hash on different dataset") {
+    val data1 = {
+      for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0)))
+    }
+    val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys")
+
+    val data2 = {
+      for (i <- 0 until 30) yield Vectors.sparse(100, (3 * i until 3 * i + 3).map((_, 1.0)))
+    }
+    val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
+
+    val mh = new MinHashLSH()
+      .setNumHashTables(20)
+      .setInputCol("keys")
+      .setOutputCol("values")
+      .setSeed(12345)
+
+    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(mh, df1, df2, 0.5)
+    assert(precision == 1.0)
+    assert(recall >= 0.7)
+  }
+}

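Taken together, the hashFunction and keyDistance expectations above pin down what MinHashLSH computes: each (a, b) coefficient pair defines one hash table whose value is the minimum of (1 + i) * a + b over the non-zero indices i of the input (the real implementation also reduces modulo a prime, as the "inputDim > prime" test indicates, which leaves these small expected values unchanged), and keyDistance is the Jaccard distance between the sets of non-zero indices. A standalone sketch consistent with those tests (minHash and jaccardDistance are illustrative helpers, not the model's methods):

    import org.apache.spark.ml.linalg.{Vector, Vectors}

    // Illustrative min-hash: one one-element vector per (a, b) pair, minimized over
    // the non-zero indices of the input. The modulo-prime step of the real model is
    // omitted because it does not affect the small values checked in the suite.
    def minHash(x: Vector, randCoefficients: Seq[(Int, Int)]): Seq[Vector] = {
      val activeIndices = x.toArray.zipWithIndex.collect { case (value, i) if value != 0.0 => i }
      require(activeIndices.nonEmpty, "Must have at least one non-zero entry.")
      randCoefficients.map { case (a, b) =>
        Vectors.dense(activeIndices.map(i => (1.0 + i) * a + b).min)
      }
    }

    // Jaccard distance over the non-zero index sets; evaluates to 0.5 for the two
    // vectors in the keyDistance test above.
    def jaccardDistance(x: Vector, y: Vector): Double = {
      val xs = x.toArray.zipWithIndex.collect { case (value, i) if value != 0.0 => i }.toSet
      val ys = y.toArray.zipWithIndex.collect { case (value, i) if value != 0.0 => i }.toSet
      1.0 - xs.intersect(ys).size.toDouble / xs.union(ys).size
    }

    // minHash(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))),
    //         Seq((0, 1), (1, 2), (3, 0)))
    //   == Seq(Vectors.dense(1.0), Vectors.dense(5.0), Vectors.dense(9.0))
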
http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
deleted file mode 100644
index c32ca7d..0000000
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/MinHashSuite.scala
+++ /dev/null
@@ -1,126 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
-
-class MinHashSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
-
-  @transient var dataset: Dataset[_] = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-
-    val data = {
-      for (i <- 0 to 95) yield Vectors.sparse(100, (i until i + 5).map((_, 1.0)))
-    }
-    dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
-  }
-
-  test("params") {
-    ParamsSuite.checkParams(new MinHash)
-    val model = new MinHashModel("mh", numEntries = 2, randCoefficients = Array(1))
-    ParamsSuite.checkParams(model)
-  }
-
-  test("MinHash: default params") {
-    val rp = new MinHash
-    assert(rp.getOutputDim === 1.0)
-  }
-
-  test("read/write") {
-    def checkModelData(model: MinHashModel, model2: MinHashModel): Unit = {
-      assert(model.numEntries === model2.numEntries)
-      assertResult(model.randCoefficients)(model2.randCoefficients)
-    }
-    val mh = new MinHash()
-    val settings = Map("inputCol" -> "keys", "outputCol" -> "values")
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
-  }
-
-  test("hashFunction") {
-    val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(0, 1, 3))
-    val res = model.hashFunction(Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0))))
-    assert(res.equals(Vectors.dense(0.0, 3.0, 4.0)))
-  }
-
-  test("keyDistance and hashDistance") {
-    val model = new MinHashModel("mh", numEntries = 20, randCoefficients = Array(1))
-    val v1 = Vectors.sparse(10, Seq((2, 1.0), (3, 1.0), (5, 1.0), (7, 1.0)))
-    val v2 = Vectors.sparse(10, Seq((1, 1.0), (3, 1.0), (5, 1.0), (7, 1.0), (9, 1.0)))
-    val keyDist = model.keyDistance(v1, v2)
-    val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
-    assert(keyDist === 0.5)
-    assert(hashDist === 3)
-  }
-
-  test("MinHash: test of LSH property") {
-    val mh = new MinHash()
-      .setOutputDim(1)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setSeed(12344)
-
-    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, mh, 0.75, 0.5)
-    assert(falsePositive < 0.3)
-    assert(falseNegative < 0.3)
-  }
-
-  test("approxNearestNeighbors for min hash") {
-    val mh = new MinHash()
-      .setOutputDim(20)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setSeed(12345)
-
-    val key: Vector = Vectors.sparse(100,
-      (0 until 100).filter(_.toString.contains("1")).map((_, 1.0)))
-
-    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(mh, dataset, key, 20,
-      singleProbing = true)
-    assert(precision >= 0.7)
-    assert(recall >= 0.7)
-  }
-
-  test("approxSimilarityJoin for minhash on different dataset") {
-    val data1 = {
-      for (i <- 0 until 20) yield Vectors.sparse(100, (5 * i until 5 * i + 5).map((_, 1.0)))
-    }
-    val df1 = spark.createDataFrame(data1.map(Tuple1.apply)).toDF("keys")
-
-    val data2 = {
-      for (i <- 0 until 30) yield Vectors.sparse(100, (3 * i until 3 * i + 3).map((_, 1.0)))
-    }
-    val df2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
-
-    val mh = new MinHash()
-      .setOutputDim(20)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setSeed(12345)
-
-    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(mh, df1, df2, 0.5)
-    assert(precision == 1.0)
-    assert(recall >= 0.7)
-  }
-}

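The removed suite above exercised the old MinHash API; its replacement, MinHashLSHSuite, appears earlier in this diff. The visible changes: MinHash/MinHashModel become MinHashLSH/MinHashLSHModel, setOutputDim/getOutputDim become setNumHashTables/getNumHashTables, and the model is constructed from (a, b) coefficient pairs instead of single coefficients plus a numEntries argument. A rough before/after of estimator construction, using only settings that the two suites show:

    import org.apache.spark.ml.feature.MinHashLSH

    // Before this patch (removed suite above):
    //   val mh = new MinHash()
    //     .setOutputDim(20)
    //     .setInputCol("keys")
    //     .setOutputCol("values")
    //     .setSeed(12345)

    // After this patch (MinHashLSHSuite earlier in this diff):
    val mh = new MinHashLSH()
      .setNumHashTables(20)
      .setInputCol("keys")
      .setOutputCol("values")
      .setSeed(12345)
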
http://git-wip-us.apache.org/repos/asf/spark/blob/05f7c6ff/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
----------------------------------------------------------------------
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
deleted file mode 100644
index cd82ee2..0000000
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/RandomProjectionSuite.scala
+++ /dev/null
@@ -1,197 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.ml.feature
-
-import breeze.numerics.{cos, sin}
-import breeze.numerics.constants.Pi
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.linalg.{Vector, Vectors}
-import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
-import org.apache.spark.ml.util.TestingUtils._
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.sql.Dataset
-
-class RandomProjectionSuite
-  extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
-
-  @transient var dataset: Dataset[_] = _
-
-  override def beforeAll(): Unit = {
-    super.beforeAll()
-
-    val data = {
-      for (i <- -10 until 10; j <- -10 until 10) yield Vectors.dense(i.toDouble, j.toDouble)
-    }
-    dataset = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
-  }
-
-  test("params") {
-    ParamsSuite.checkParams(new RandomProjection)
-    val model = new RandomProjectionModel("rp", randUnitVectors = Array(Vectors.dense(1.0, 0.0)))
-    ParamsSuite.checkParams(model)
-  }
-
-  test("RandomProjection: default params") {
-    val rp = new RandomProjection
-    assert(rp.getOutputDim === 1.0)
-  }
-
-  test("read/write") {
-    def checkModelData(model: RandomProjectionModel, model2: RandomProjectionModel): Unit = {
-      model.randUnitVectors.zip(model2.randUnitVectors)
-        .foreach(pair => assert(pair._1 === pair._2))
-    }
-    val mh = new RandomProjection()
-    val settings = Map("inputCol" -> "keys", "outputCol" -> "values", "bucketLength" -> 1.0)
-    testEstimatorAndModelReadWrite(mh, dataset, settings, checkModelData)
-  }
-
-  test("hashFunction") {
-    val randUnitVectors = Array(Vectors.dense(0.0, 1.0), Vectors.dense(1.0, 0.0))
-    val model = new RandomProjectionModel("rp", randUnitVectors)
-    model.set(model.bucketLength, 0.5)
-    val res = model.hashFunction(Vectors.dense(1.23, 4.56))
-    assert(res.equals(Vectors.dense(9.0, 2.0)))
-  }
-
-  test("keyDistance and hashDistance") {
-    val model = new RandomProjectionModel("rp", Array(Vectors.dense(0.0, 1.0)))
-    val keyDist = model.keyDistance(Vectors.dense(1, 2), Vectors.dense(-2, -2))
-    val hashDist = model.hashDistance(Vectors.dense(-5, 5), Vectors.dense(1, 2))
-    assert(keyDist === 5)
-    assert(hashDist === 3)
-  }
-
-  test("RandomProjection: randUnitVectors") {
-    val rp = new RandomProjection()
-      .setOutputDim(20)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(1.0)
-      .setSeed(12345)
-    val unitVectors = rp.fit(dataset).randUnitVectors
-    unitVectors.foreach { v: Vector =>
-      assert(Vectors.norm(v, 2.0) ~== 1.0 absTol 1e-14)
-    }
-  }
-
-  test("RandomProjection: test of LSH property") {
-    // Project from 2 dimensional Euclidean Space to 1 dimensions
-    val rp = new RandomProjection()
-      .setOutputDim(1)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(1.0)
-      .setSeed(12345)
-
-    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(dataset, rp, 8.0, 2.0)
-    assert(falsePositive < 0.4)
-    assert(falseNegative < 0.4)
-  }
-
-  test("RandomProjection with high dimension data: test of LSH property") {
-    val numDim = 100
-    val data = {
-      for (i <- 0 until numDim; j <- Seq(-2, -1, 1, 2))
-        yield Vectors.sparse(numDim, Seq((i, j.toDouble)))
-    }
-    val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
-
-    // Project from 100 dimensional Euclidean Space to 10 dimensions
-    val rp = new RandomProjection()
-      .setOutputDim(10)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(2.5)
-      .setSeed(12345)
-
-    val (falsePositive, falseNegative) = LSHTest.calculateLSHProperty(df, rp, 3.0, 2.0)
-    assert(falsePositive < 0.3)
-    assert(falseNegative < 0.3)
-  }
-
-  test("approxNearestNeighbors for random projection") {
-    val key = Vectors.dense(1.2, 3.4)
-
-    val rp = new RandomProjection()
-      .setOutputDim(2)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(4.0)
-      .setSeed(12345)
-
-    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
-      singleProbing = true)
-    assert(precision >= 0.6)
-    assert(recall >= 0.6)
-  }
-
-  test("approxNearestNeighbors with multiple probing") {
-    val key = Vectors.dense(1.2, 3.4)
-
-    val rp = new RandomProjection()
-      .setOutputDim(20)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(1.0)
-      .setSeed(12345)
-
-    val (precision, recall) = LSHTest.calculateApproxNearestNeighbors(rp, dataset, key, 100,
-      singleProbing = false)
-    assert(precision >= 0.7)
-    assert(recall >= 0.7)
-  }
-
-  test("approxSimilarityJoin for random projection on different dataset") {
-    val data2 = {
-      for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
-    }
-    val dataset2 = spark.createDataFrame(data2.map(Tuple1.apply)).toDF("keys")
-
-    val rp = new RandomProjection()
-      .setOutputDim(2)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(4.0)
-      .setSeed(12345)
-
-    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, dataset, dataset2, 1.0)
-    assert(precision == 1.0)
-    assert(recall >= 0.7)
-  }
-
-  test("approxSimilarityJoin for self join") {
-    val data = {
-      for (i <- 0 until 24) yield Vectors.dense(10 * sin(Pi / 12 * i), 10 * cos(Pi / 12 * i))
-    }
-    val df = spark.createDataFrame(data.map(Tuple1.apply)).toDF("keys")
-
-    val rp = new RandomProjection()
-      .setOutputDim(2)
-      .setInputCol("keys")
-      .setOutputCol("values")
-      .setBucketLength(4.0)
-      .setSeed(12345)
-
-    val (precision, recall) = LSHTest.calculateApproxSimilarityJoin(rp, df, df, 3.0)
-    assert(precision == 1.0)
-    assert(recall >= 0.7)
-  }
-}

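The removed RandomProjectionSuite maps onto BucketedRandomProjectionLSHSuite earlier in this diff in the same way: RandomProjection/RandomProjectionModel become BucketedRandomProjectionLSH/BucketedRandomProjectionLSHModel, setOutputDim becomes setNumHashTables, and the singleProbing flag in LSHTest becomes singleProbe. A rough before/after, again limited to settings that appear in the two suites:

    import org.apache.spark.ml.feature.BucketedRandomProjectionLSH

    // Before this patch (removed suite above):
    //   val rp = new RandomProjection()
    //     .setOutputDim(2)
    //     .setInputCol("keys")
    //     .setOutputCol("values")
    //     .setBucketLength(4.0)
    //     .setSeed(12345)

    // After this patch (BucketedRandomProjectionLSHSuite earlier in this diff):
    val brp = new BucketedRandomProjectionLSH()
      .setNumHashTables(2)
      .setInputCol("keys")
      .setOutputCol("values")
      .setBucketLength(4.0)
      .setSeed(12345)
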
