You are viewing a plain text version of this content. The canonical link for it is here.
Posted to reviews@spark.apache.org by MrBago <gi...@git.apache.org> on 2017/12/05 00:23:10 UTC
[GitHub] spark pull request #19746: [SPARK-22346][ML] VectorSizeHint Transformer for ...
Github user MrBago commented on a diff in the pull request:
https://github.com/apache/spark/pull/19746#discussion_r154815581
--- Diff: mllib/src/test/scala/org/apache/spark/ml/feature/VectorSizeHintSuite.scala ---
@@ -0,0 +1,173 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.feature
+
+import org.apache.spark.{SparkException, SparkFunSuite}
+import org.apache.spark.ml.attribute.AttributeGroup
+import org.apache.spark.ml.linalg.{Vector, Vectors}
+import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.Row
+import org.apache.spark.sql.execution.streaming.MemoryStream
+import org.apache.spark.sql.streaming.StreamTest
+
+class VectorSizeHintSuite
+ extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
+
+ import testImplicits._
+
+ test("Test Param Validators") {
+ intercept[IllegalArgumentException] (new VectorSizeHint().setHandleInvalid("invalidValue"))
+ intercept[IllegalArgumentException] (new VectorSizeHint().setSize(-3))
+ }
+
+ test("Adding size to column of vectors.") {
+
+ val size = 3
+ val vectorColName = "vector"
+ val denseVector = Vectors.dense(1, 2, 3)
+ val sparseVector = Vectors.sparse(size, Array(), Array())
+
+ val data = Seq(denseVector, denseVector, sparseVector).map(Tuple1.apply)
+ val dataFrame = data.toDF(vectorColName)
+ assert(
+ AttributeGroup.fromStructField(dataFrame.schema(vectorColName)).size == -1,
+ "Transformer did not add expected size data.")
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ val withSize = transformer.transform(dataFrame)
+ assert(
+ AttributeGroup.fromStructField(withSize.schema(vectorColName)).size == size,
+ "Transformer did not add expected size data.")
+ withSize.collect
+ }
+ }
+
+ test("Size hint preserves attributes.") {
+
+ val size = 3
+ val vectorColName = "vector"
+ val data = Seq((1, 2, 3), (2, 3, 3))
+ val dataFrame = data.toDF("x", "y", "z")
+
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z"))
+ .setOutputCol(vectorColName)
+ val dataFrameWithMetadata = assembler.transform(dataFrame)
+ val group = AttributeGroup.fromStructField(dataFrameWithMetadata.schema(vectorColName))
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ val withSize = transformer.transform(dataFrameWithMetadata)
+
+ val newGroup = AttributeGroup.fromStructField(withSize.schema(vectorColName))
+ assert(newGroup.size === size, "Transformer did not add expected size data.")
+ assert(
+ newGroup.attributes.get.deep === group.attributes.get.deep,
+ "SizeHintTransformer did not preserve attributes.")
+ withSize.collect
+ }
+ }
+
+ test("Size mismatch between current and target size raises an error.") {
+ val size = 4
+ val vectorColName = "vector"
+ val data = Seq((1, 2, 3), (2, 3, 3))
+ val dataFrame = data.toDF("x", "y", "z")
+
+ val assembler = new VectorAssembler()
+ .setInputCols(Array("x", "y", "z"))
+ .setOutputCol(vectorColName)
+ val dataFrameWithMetadata = assembler.transform(dataFrame)
+
+ for (handleInvalid <- VectorSizeHint.supportedHandleInvalids) {
+ val transformer = new VectorSizeHint()
+ .setInputCol(vectorColName)
+ .setSize(size)
+ .setHandleInvalid(handleInvalid)
+ intercept[SparkException](transformer.transform(dataFrameWithMetadata))
+ }
+ }
+
+ test("Handle invalid does the right thing.") {
+
+ val vector = Vectors.dense(1, 2, 3)
+ val short = Vectors.dense(2)
+ val dataWithNull = Seq(vector, null).map(Tuple1.apply).toDF("vector")
+ val dataWithShort = Seq(vector, short).map(Tuple1.apply).toDF("vector")
+
+ val sizeHint = new VectorSizeHint()
+ .setInputCol("vector")
+ .setHandleInvalid("error")
+ .setSize(3)
+
+ intercept[SparkException](sizeHint.transform(dataWithNull).collect)
+ intercept[SparkException](sizeHint.transform(dataWithShort).collect)
+
+ sizeHint.setHandleInvalid("skip")
+ assert(sizeHint.transform(dataWithNull).count() === 1)
+ assert(sizeHint.transform(dataWithShort).count() === 1)
+ }
+
+ test("read/write") {
+ val sizeHint = new VectorSizeHint()
+ .setInputCol("myInputCol")
+ .setSize(11)
+ .setHandleInvalid("skip")
+ testDefaultReadWrite(sizeHint)
+ }
+}
+
+class VectorSizeHintStreamingSuite extends StreamTest {
+
+ import testImplicits._
+
+ test("Test assemble vectors with size hint in streaming.") {
+ val a = Vectors.dense(0, 1, 2)
+ val b = Vectors.sparse(4, Array(0, 3), Array(3, 6))
+
+ val stream = MemoryStream[(Vector, Vector)]
+ val streamingDF = stream.toDS.toDF("a", "b")
+ val sizeHintA = new VectorSizeHint()
+ .setSize(3)
+ .setInputCol("a")
+ val sizeHintB = new VectorSizeHint()
+ .setSize(4)
+ .setInputCol("b")
+ val vectorAssembler = new VectorAssembler()
+ .setInputCols(Array("a", "b"))
+ .setOutputCol("assembled")
+ val output = Seq(sizeHintA, sizeHintB, vectorAssembler).foldLeft(streamingDF) {
+ case (data, transform) => transform.transform(data)
+ }.select("assembled")
+
+ val expected = Vectors.dense(0, 1, 2, 3, 0, 0, 6)
+
+ testStream (output) (
+ AddData(stream, (a, b), (a, b)),
+ CheckAnswerRows(Seq(Row(expected), Row(expected)), false, false)
--- End diff --
The reason I didn't use `CheckAnswer` is that there isn't an implicit encoder in `testImplicits` that handles `Vector`. I tried `CheckAnswer[Vector](expected, expected)`, but that doesn't work either :(. Is there an encoder that works for Vectors?
---
---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscribe@spark.apache.org
For additional commands, e-mail: reviews-help@spark.apache.org