You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/11/08 09:06:00 UTC
[flink-ml] branch master updated: [FLINK-29323] Add inputSizes parameter for VectorAssembler
This is an automated email from the ASF dual-hosted git repository.
zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new ae5c011 [FLINK-29323] Add inputSizes parameter for VectorAssembler
ae5c011 is described below
commit ae5c0115e6f271e4bb125675a4e04f5d672814d5
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Tue Nov 8 17:05:55 2022 +0800
[FLINK-29323] Add inputSizes parameter for VectorAssembler
This closes #156.
---
.../docs/operators/feature/vectorassembler.md | 35 ++-
.../examples/feature/VectorAssemblerExample.java | 3 +-
.../feature/vectorassembler/VectorAssembler.java | 128 ++++++++---
.../vectorassembler/VectorAssemblerParams.java | 35 ++-
.../flink/ml/feature/VectorAssemblerTest.java | 253 +++++++++++++++++++--
.../examples/ml/feature/vectorassembler_example.py | 1 +
.../ml/lib/feature/tests/test_vectorassembler.py | 5 +-
.../pyflink/ml/lib/feature/vectorassembler.py | 54 ++++-
8 files changed, 441 insertions(+), 73 deletions(-)
diff --git a/docs/content/docs/operators/feature/vectorassembler.md b/docs/content/docs/operators/feature/vectorassembler.md
index 10f3fc3..84f064a 100644
--- a/docs/content/docs/operators/feature/vectorassembler.md
+++ b/docs/content/docs/operators/feature/vectorassembler.md
@@ -26,10 +26,20 @@ under the License.
-->
## Vector Assembler
-
-Vector Assembler combines a given list of input columns into a vector column.
-Types of input columns must be either vector or numerical value.
-
+A Transformer which combines a given list of input columns into a vector column. Input columns
+must be numerical values or vectors whose sizes are specified by the `inputSizes` parameter.
+Invalid input data with null values or values with wrong sizes is handled according to
+the strategy specified by the `handleInvalid` parameter as follows:
+<ul>
+ <li>keep: If the input column data is null, a vector would be created with the specified size
+ and NaN values. The vector would be used in the assembling process to represent the input
+ column data. If the input column data is a vector, the data would be used in the assembling
+ process even if it has a wrong size.
+ <li>skip: If the input column data is null or a vector with wrong size, the input row would be
+ filtered out and not be sent to downstream operators.
+ <li>error: If the input column data is null or a vector with wrong size, an exception would be
+ thrown.
+</ul>
### Input Columns
| Param name | Type | Default | Description |
@@ -44,11 +54,12 @@ Types of input columns must be either vector or numerical value.
### Parameters
-| Key | Default | Type | Required | Description |
-|---------------|------------|----------|----------|--------------------------------------------------------------------------------|
-| inputCols | `null` | String[] | yes | Input column names. |
-| outputCol | `"output"` | String | no | Output column name. |
-| handleInvalid | `"error"` | String | no | Strategy to handle invalid entries. Supported values: 'error', 'skip', 'keep'. |
+| Key | Default | Type | Required | Description |
+|-----------------|------------|-----------|----------|--------------------------------------------------------------------------------|
+| inputCols | `null` | String[] | yes | Input column names. |
+| outputCol | `"output"` | String | no | Output column name. |
+| inputSizes | `null` | Integer[] | yes | Sizes of the input elements to be assembled. |
+| handleInvalid | `"error"` | String | no | Strategy to handle invalid entries. Supported values: 'error', 'skip', 'keep'. |
### Examples
@@ -95,7 +106,8 @@ public class VectorAssemblerExample {
VectorAssembler vectorAssembler =
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
- .setOutputCol("assembledVec");
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5);
// Uses the VectorAssembler object for feature transformations.
Table outputTable = vectorAssembler.transform(inputTable)[0];
@@ -159,9 +171,10 @@ input_data_table = t_env.from_data_stream(
vector_assembler = VectorAssembler() \
.set_input_cols('vec', 'num', 'sparse_vec') \
.set_output_col('assembled_vec') \
+ .set_input_sizes(2, 1, 5) \
.set_handle_invalid('keep')
-# use the vector assembler model for feature engineering
+# use the vector assembler for feature engineering
output = vector_assembler.transform(input_data_table)[0]
# extract and display the results
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
index 50e51c2..0c14625 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
@@ -56,7 +56,8 @@ public class VectorAssemblerExample {
VectorAssembler vectorAssembler =
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
- .setOutputCol("assembledVec");
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5);
// Uses the VectorAssembler object for feature transformations.
Table outputTable = vectorAssembler.transform(inputTable)[0];
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
index e3a01f1..e951f80 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -19,6 +19,7 @@
package org.apache.flink.ml.feature.vectorassembler;
import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.api.java.typeutils.RowTypeInfo;
import org.apache.flink.ml.api.Transformer;
import org.apache.flink.ml.common.datastream.TableUtils;
@@ -46,11 +47,21 @@ import java.util.HashMap;
import java.util.Map;
/**
- * A Transformer which combines a given list of input columns into a vector column. Types of input
- * columns must be either vector or numerical value.
+ * A Transformer which combines a given list of input columns into a vector column. Input columns
+ * must be numerical values or vectors whose sizes are specified by the {@link #INPUT_SIZES}
+ * parameter. Invalid input data with null values or values with wrong sizes is handled according
+ * to the strategy specified by the {@link HasHandleInvalid} parameter as follows:
*
- * <p>The `keep` option of {@link HasHandleInvalid} means that we output bad rows with output column
- * set to null.
+ * <ul>
+ * <li>keep: If the input column data is null, a vector would be created with the specified size
+ * and NaN values. The vector would be used in the assembling process to represent the input
+ * column data. If the input column data is a vector, the data would be used in the assembling
+ * process even if it has a wrong size.
+ * <li>skip: If the input column data is null or a vector with wrong size, the input row would be
+ * filtered out and not be sent to downstream operators.
+ * <li>error: If the input column data is null or a vector with wrong size, an exception would be
+ * thrown.
+ * </ul>
*/
public class VectorAssembler
implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> {
@@ -64,6 +75,7 @@ public class VectorAssembler
@Override
public Table[] transform(Table... inputs) {
Preconditions.checkArgument(inputs.length == 1);
+ Preconditions.checkArgument(getInputSizes().length == getInputCols().length);
StreamTableEnvironment tEnv =
(StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
@@ -74,64 +86,107 @@ public class VectorAssembler
DataStream<Row> output =
tEnv.toDataStream(inputs[0])
.flatMap(
- new AssemblerFunc(getInputCols(), getHandleInvalid()),
+ new AssemblerFunction(
+ getInputCols(), getHandleInvalid(), getInputSizes()),
outputTypeInfo);
Table outputTable = tEnv.fromDataStream(output);
return new Table[] {outputTable};
}
- private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+ private static class AssemblerFunction implements FlatMapFunction<Row, Row> {
private final String[] inputCols;
private final String handleInvalid;
+ private final Integer[] inputSizes;
+ private final boolean keepInvalid;
- public AssemblerFunc(String[] inputCols, String handleInvalid) {
+ public AssemblerFunction(String[] inputCols, String handleInvalid, Integer[] inputSizes) {
this.inputCols = inputCols;
this.handleInvalid = handleInvalid;
+ this.inputSizes = inputSizes;
+ keepInvalid = handleInvalid.equals(HasHandleInvalid.KEEP_INVALID);
}
@Override
public void flatMap(Row value, Collector<Row> out) {
- int nnz = 0;
- int vectorSize = 0;
try {
- for (String inputCol : inputCols) {
- Object object = value.getField(inputCol);
- Preconditions.checkNotNull(object, "Input column value should not be null.");
+ Tuple2<Integer, Integer> vectorSizeAndNnz = computeVectorSizeAndNnz(value);
+ int vectorSize = vectorSizeAndNnz.f0;
+ int nnz = vectorSizeAndNnz.f1;
+ Vector assembledVec =
+ nnz * RATIO > vectorSize
+ ? assembleDense(inputCols, value, vectorSize)
+ : assembleSparse(inputCols, value, vectorSize, nnz);
+ out.collect(Row.join(value, Row.of(assembledVec)));
+ } catch (Exception e) {
+ if (handleInvalid.equals(ERROR_INVALID)) {
+ throw new RuntimeException("Vector assembler failed with exception : " + e);
+ }
+ }
+ }
+
+ private Tuple2<Integer, Integer> computeVectorSizeAndNnz(Row value) {
+ int vectorSize = 0;
+ int nnz = 0;
+ for (int i = 0; i < inputCols.length; ++i) {
+ Object object = value.getField(inputCols[i]);
+ if (object != null) {
if (object instanceof Number) {
- nnz += 1;
+ checkSize(inputSizes[i], 1);
+ if (Double.isNaN(((Number) object).doubleValue()) && !keepInvalid) {
+ throw new RuntimeException(
+ "Encountered NaN while assembling a row with handleInvalid = 'error'. Consider "
+ + "removing NaNs from dataset or using handleInvalid = 'keep' or 'skip'.");
+ }
vectorSize += 1;
+ nnz += 1;
} else if (object instanceof SparseVector) {
+ int localSize = ((SparseVector) object).size();
+ checkSize(inputSizes[i], localSize);
nnz += ((SparseVector) object).indices.length;
- vectorSize += ((SparseVector) object).size();
+ vectorSize += localSize;
} else if (object instanceof DenseVector) {
+ int localSize = ((DenseVector) object).size();
+ checkSize(inputSizes[i], localSize);
+ vectorSize += localSize;
nnz += ((DenseVector) object).size();
- vectorSize += ((DenseVector) object).size();
} else {
throw new IllegalArgumentException(
- "Input type has not been supported yet.");
+ String.format(
+ "Input type %s has not been supported yet. Only Vector and Number types are supported.",
+ object.getClass()));
+ }
+ } else {
+ vectorSize += inputSizes[i];
+ nnz += inputSizes[i];
+ if (keepInvalid) {
+ if (inputSizes[i] > 1) {
+ DenseVector tmpVec = new DenseVector(inputSizes[i]);
+ for (int j = 0; j < inputSizes[i]; ++j) {
+ tmpVec.values[j] = Double.NaN;
+ }
+ value.setField(inputCols[i], tmpVec);
+ } else {
+ value.setField(inputCols[i], Double.NaN);
+ }
+ } else {
+ throw new RuntimeException(
+ "Input column value is null. Please check the input data or using handleInvalid = 'keep'.");
}
- }
- } catch (Exception e) {
- switch (handleInvalid) {
- case ERROR_INVALID:
- throw e;
- case SKIP_INVALID:
- return;
- case KEEP_INVALID:
- out.collect(Row.join(value, Row.of((Object) null)));
- return;
- default:
- throw new UnsupportedOperationException(
- "Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
}
}
+ return Tuple2.of(vectorSize, nnz);
+ }
- boolean toDense = nnz * RATIO > vectorSize;
- Vector assembledVec =
- toDense
- ? assembleDense(inputCols, value, vectorSize)
- : assembleSparse(inputCols, value, vectorSize, nnz);
- out.collect(Row.join(value, Row.of(assembledVec)));
+ private void checkSize(int expectedSize, int currentSize) {
+ if (keepInvalid) {
+ return;
+ }
+ if (currentSize != expectedSize) {
+ throw new IllegalArgumentException(
+ String.format(
+ "Input vector/number size does not meet with expected. Expected size: %d, actual size: %s.",
+ expectedSize, currentSize));
+ }
}
}
@@ -167,8 +222,7 @@ public class VectorAssembler
} else {
DenseVector denseVector = (DenseVector) object;
- System.arraycopy(
- denseVector.values, 0, values, currentOffset, denseVector.values.length);
+ System.arraycopy(denseVector.values, 0, values, currentOffset, denseVector.size());
currentOffset += denseVector.size();
}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
index cc3637e..d6b3c12 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
@@ -21,6 +21,9 @@ package org.apache.flink.ml.feature.vectorassembler;
import org.apache.flink.ml.common.param.HasHandleInvalid;
import org.apache.flink.ml.common.param.HasInputCols;
import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
/**
* Params of {@link VectorAssembler}.
@@ -28,4 +31,34 @@ import org.apache.flink.ml.common.param.HasOutputCol;
* @param <T> The class type of this instance.
*/
public interface VectorAssemblerParams<T>
- extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+ extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+ Param<Integer[]> INPUT_SIZES =
+ new IntArrayParam(
+ "inputSizes",
+ "Sizes of the input elements to be assembled.",
+ null,
+ sizesValidator());
+
+ default Integer[] getInputSizes() {
+ return get(INPUT_SIZES);
+ }
+
+ default T setInputSizes(Integer... value) {
+ return set(INPUT_SIZES, value);
+ }
+
+ // Checks the inputSizes parameter.
+ static ParamValidator<Integer[]> sizesValidator() {
+ return inputSizes -> {
+ if (inputSizes == null) {
+ return false;
+ }
+ for (Integer size : inputSizes) {
+ if (size <= 0) {
+ return false;
+ }
+ }
+ return inputSizes.length != 0;
+ };
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
index 420844b..f22d013 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -45,13 +45,14 @@ import java.util.List;
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
/** Tests VectorAssembler. */
public class VectorAssemblerTest extends AbstractTestBase {
private StreamTableEnvironment tEnv;
private Table inputDataTable;
+ private Table inputNullDataTable;
+ private Table inputNanDataTable;
private static final List<Row> INPUT_DATA =
Arrays.asList(
@@ -66,24 +67,82 @@ public class VectorAssemblerTest extends AbstractTestBase {
1.0,
Vectors.sparse(
5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
- Row.of(2, null, null, null));
+ Row.of(
+ 2,
+ Vectors.dense(2.0, 2.1),
+ 1.0,
+ Vectors.sparse(
+ 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
+
+ private static final List<Row> INPUT_NAN_DATA =
+ Arrays.asList(
+ Row.of(
+ 0,
+ Vectors.dense(2.1, 3.1),
+ 1.0,
+ Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+ Row.of(
+ 1,
+ Vectors.dense(2.1, 3.1),
+ 1.0,
+ Vectors.sparse(
+ 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+ Row.of(
+ 2,
+ Vectors.dense(2.0, 2.1),
+ Double.NaN,
+ Vectors.sparse(
+ 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
+
+ private static final List<Row> INPUT_NULL_DATA =
+ Arrays.asList(
+ Row.of(
+ 0,
+ Vectors.dense(2.1, 3.1),
+ 1.0,
+ Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+ Row.of(
+ 1,
+ null,
+ 1.0,
+ Vectors.sparse(
+ 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+ Row.of(
+ 2,
+ null,
+ 1.0,
+ Vectors.sparse(
+ 5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0});
private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+ private static final DenseVector EXPECTED_OUTPUT_DATA_3 =
+ Vectors.dense(2.0, 2.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+ private static final DenseVector EXPECTED_OUTPUT_DATA_4 =
+ Vectors.dense(Double.NaN, Double.NaN, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+ private static final DenseVector EXPECTED_OUTPUT_DATA_5 =
+ Vectors.dense(2.0, 2.1, Double.NaN, 0.0, 1.0, 2.0, 3.0, 4.0);
@Before
public void before() {
Configuration config = new Configuration();
config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
- env.setParallelism(4);
+ env.setParallelism(2);
env.enableCheckpointing(100);
env.setRestartStrategy(RestartStrategies.noRestart());
+
tEnv = StreamTableEnvironment.create(env);
DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec");
+ DataStream<Row> nullDataStream = env.fromCollection(INPUT_NULL_DATA);
+ inputNullDataTable =
+ tEnv.fromDataStream(nullDataStream).as("id", "vec", "num", "sparseVec");
+ DataStream<Row> nanDataStream = env.fromCollection(INPUT_NAN_DATA);
+ inputNanDataTable = tEnv.fromDataStream(nanDataStream).as("id", "vec", "num", "sparseVec");
}
private void verifyOutputResult(Table output, String outputCol, int outputSize)
@@ -91,13 +150,14 @@ public class VectorAssemblerTest extends AbstractTestBase {
DataStream<Row> dataStream = tEnv.toDataStream(output);
List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
assertEquals(outputSize, results.size());
+
for (Row result : results) {
if (result.getField(0) == (Object) 0) {
assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
} else if (result.getField(0) == (Object) 1) {
assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
- } else {
- assertNull(result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 2) {
+ assertEquals(EXPECTED_OUTPUT_DATA_3, result.getField(outputCol));
}
}
}
@@ -107,35 +167,147 @@ public class VectorAssemblerTest extends AbstractTestBase {
VectorAssembler vectorAssembler = new VectorAssembler();
assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid());
assertEquals("output", vectorAssembler.getOutputCol());
+
vectorAssembler
.setInputCols("vec", "num", "sparseVec")
.setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols());
assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid());
assertEquals("assembledVec", vectorAssembler.getOutputCol());
+ assertArrayEquals(new Integer[] {2, 1, 5}, vectorAssembler.getInputSizes());
}
@Test
- public void testKeepInvalid() throws Exception {
+ public void testOutputSchema() throws Exception {
VectorAssembler vectorAssembler =
new VectorAssembler()
- .setInputCols("vec", "num", "sparseVec")
+ .setInputCols("num")
.setOutputCol("assembledVec")
+ .setInputSizes(1)
.setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
Table output = vectorAssembler.transform(inputDataTable)[0];
+
assertEquals(
Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testKeepInvalidWithNull() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
+ .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+ Table output = vectorAssembler.transform(inputNullDataTable)[0];
+
+ DataStream<Row> dataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+ assertEquals(3, results.size());
+
+ String outputCol = vectorAssembler.getOutputCol();
+ for (Row result : results) {
+ if (result.getField(0) == (Object) 0) {
+ assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 1) {
+ assertEquals(EXPECTED_OUTPUT_DATA_4, result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 2) {
+ assertEquals(EXPECTED_OUTPUT_DATA_4, result.getField(outputCol));
+ }
+ }
+ }
+
+ @Test
+ public void testKeepInvalidWithNaN() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
+ .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+ Table output = vectorAssembler.transform(inputNanDataTable)[0];
+ DataStream<Row> dataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+ assertEquals(3, results.size());
+
+ String outputCol = vectorAssembler.getOutputCol();
+ for (Row result : results) {
+ if (result.getField(0) == (Object) 0) {
+ assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 1) {
+ assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+ } else if (result.getField(0) == (Object) 2) {
+ assertEquals(EXPECTED_OUTPUT_DATA_5, result.getField(outputCol));
+ }
+ }
+ }
+
+ @Test
+ public void testKeepInvalidWithErrorSizes() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 4)
+ .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+ Table output = vectorAssembler.transform(inputDataTable)[0];
verifyOutputResult(output, vectorAssembler.getOutputCol(), 3);
}
@Test
- public void testErrorInvalid() {
+ public void testErrorInvalidWithNull() {
VectorAssembler vectorAssembler =
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
.setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
+ .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+
+ try {
+ Table outputTable = vectorAssembler.transform(inputNullDataTable)[0];
+ outputTable.execute().collect().next();
+ Assert.fail("Expected IllegalArgumentException");
+ } catch (Throwable e) {
+ assertEquals(
+ "Vector assembler failed with exception : java.lang.RuntimeException: "
+ + "Input column value is null. Please check the input data or using handleInvalid = 'keep'.",
+ ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ public void testErrorInvalidWithNaN() {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
+ .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+
+ try {
+ Table outputTable = vectorAssembler.transform(inputNanDataTable)[0];
+ outputTable.execute().collect().next();
+ Assert.fail("Expected IllegalArgumentException");
+ } catch (Throwable e) {
+ assertEquals(
+ "Vector assembler failed with exception : java.lang.RuntimeException: Encountered NaN "
+ + "while assembling a row with handleInvalid = 'error'. Consider removing NaNs from "
+ + "dataset or using handleInvalid = 'keep' or 'skip'.",
+ ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ public void testErrorInvalidWithErrorSizes() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 4)
.setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
try {
Table outputTable = vectorAssembler.transform(inputDataTable)[0];
@@ -143,37 +315,64 @@ public class VectorAssemblerTest extends AbstractTestBase {
Assert.fail("Expected IllegalArgumentException");
} catch (Throwable e) {
assertEquals(
- "Input column value should not be null.",
+ "Vector assembler failed with exception : java.lang.IllegalArgumentException: "
+ + "Input vector/number size does not meet with expected. Expected size: 4, actual size: 5.",
ExceptionUtils.getRootCause(e).getMessage());
}
}
@Test
- public void testSkipInvalid() throws Exception {
+ public void testSkipInvalidWithNull() throws Exception {
VectorAssembler vectorAssembler =
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
.setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
- Table output = vectorAssembler.transform(inputDataTable)[0];
- assertEquals(
- Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
- output.getResolvedSchema().getColumnNames());
+ Table output = vectorAssembler.transform(inputNullDataTable)[0];
+ verifyOutputResult(output, vectorAssembler.getOutputCol(), 1);
+ }
+
+ @Test
+ public void testSkipInvalidWithNaN() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
+ .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+ Table output = vectorAssembler.transform(inputNanDataTable)[0];
+
verifyOutputResult(output, vectorAssembler.getOutputCol(), 2);
}
+ @Test
+ public void testSkipInvalidWithErrorSizes() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("vec", "num", "sparseVec")
+ .setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 4)
+ .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+ Table output = vectorAssembler.transform(inputDataTable)[0];
+ verifyOutputResult(output, vectorAssembler.getOutputCol(), 0);
+ }
+
@Test
public void testSaveLoadAndTransform() throws Exception {
VectorAssembler vectorAssembler =
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
.setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
VectorAssembler loadedVectorAssembler =
TestUtils.saveAndReload(
tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
Table output = loadedVectorAssembler.transform(inputDataTable)[0];
- verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+ verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
}
@Test
@@ -189,11 +388,33 @@ public class VectorAssemblerTest extends AbstractTestBase {
new VectorAssembler()
.setInputCols("vec", "num", "sparseVec")
.setOutputCol("assembledVec")
+ .setInputSizes(2, 1, 5)
.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
VectorAssembler loadedVectorAssembler =
TestUtils.saveAndReload(
tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
Table output = loadedVectorAssembler.transform(inputDataTable)[0];
- verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+ verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
+ }
+
+ @Test
+ public void testNumber2Vector() throws Exception {
+ VectorAssembler vectorAssembler =
+ new VectorAssembler()
+ .setInputCols("num")
+ .setOutputCol("assembledVec")
+ .setInputSizes(1)
+ .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+ Table output = vectorAssembler.transform(inputDataTable)[0];
+
+ DataStream<Row> dataStream = tEnv.toDataStream(output);
+ List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+ for (Row result : results) {
+ if (result.getField(2) != null) {
+ assertEquals(result.getField(2), ((DenseVector) result.getField(4)).values[0]);
+ }
+ }
}
}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
index eb12679..a37fc52 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
@@ -50,6 +50,7 @@ input_data_table = t_env.from_data_stream(
vector_assembler = VectorAssembler() \
.set_input_cols('vec', 'num', 'sparse_vec') \
.set_output_col('assembled_vec') \
+ .set_input_sizes(2, 1, 5) \
.set_handle_invalid('keep')
# use the vector assembler for feature engineering
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
index 5362fcf..bcc7dd4 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
@@ -55,16 +55,19 @@ class VectorAssemblerTest(PyFlinkMLTestCase):
vector_assembler.set_input_cols('vec', 'num', 'sparse_vec') \
.set_output_col('assembled_vec') \
+ .set_input_sizes(2, 1, 5) \
.set_handle_invalid('skip')
self.assertEqual(('vec', 'num', 'sparse_vec'), vector_assembler.input_cols)
- self.assertEqual('skip', vector_assembler.handle_invalid)
self.assertEqual('assembled_vec', vector_assembler.output_col)
+ self.assertEqual((2, 1, 5), vector_assembler.input_sizes)
+ self.assertEqual('skip', vector_assembler.handle_invalid)
def test_save_load_transform(self):
vector_assembler = VectorAssembler() \
.set_input_cols('vec', 'num', 'sparse_vec') \
.set_output_col('assembled_vec') \
+ .set_input_sizes(2, 1, 5) \
.set_handle_invalid('keep')
path = os.path.join(self.temp_dir, 'test_save_load_transform_vector_assembler')
diff --git a/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py b/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
index 08bd689..04ae822 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
@@ -16,9 +16,11 @@
# limitations under the License.
################################################################################
+from typing import Tuple
from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.core.param import IntArrayParam, ParamValidator
from pyflink.ml.lib.feature.common import JavaFeatureTransformer
-from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasHandleInvalid
+from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasHandleInvalid, Param
class _VectorAssemblerParams(
@@ -27,21 +29,61 @@ class _VectorAssemblerParams(
HasOutputCol,
HasHandleInvalid
):
+
+ """
+ Checks the inputSizes parameter.
+ """
+ def SizesValidator(self) -> ParamValidator[Tuple[int]]:
+ class SizesValidator(ParamValidator[Tuple[int]]):
+ def validate(self, indices: Tuple[int]) -> bool:
+ if indices is None:
+ return False
+ for val in indices:
+ if val <= 0:
+ return False
+ return len(indices) != 0
+ return SizesValidator()
+
"""
Params for :class:`VectorAssembler`.
"""
+ INPUT_SIZES: Param[Tuple[int, ...]] = IntArrayParam(
+ "input_sizes",
+ "Sizes of the input elements to be assembled.",
+ None,
+ SizesValidator(None))
+
def __init__(self, java_params):
super(_VectorAssemblerParams, self).__init__(java_params)
+ def set_input_sizes(self, *sizes: int):
+ return self.set(self.INPUT_SIZES, sizes)
+
+ def get_input_sizes(self) -> Tuple[int, ...]:
+ return self.get(self.INPUT_SIZES)
+
+ @property
+ def input_sizes(self) -> Tuple[int, ...]:
+ return self.get_input_sizes()
+
class VectorAssembler(JavaFeatureTransformer, _VectorAssemblerParams):
"""
- A Transformer which combines a given list of input columns into a vector column. Types of
- input columns must be either vector or numerical value.
-
- The `keep` option of :class:HasHandleInvalid means that we output bad rows with output column
- set to null.
+ A Transformer which combines a given list of input columns into a vector column. Input columns
+ must be numerical values or vectors whose sizes are specified by the ``input_sizes`` parameter.
+ Invalid input data with null values or values with wrong sizes is handled according to the
+ strategy specified by the :class:`HasHandleInvalid` parameter as follows:
+ <ul>
+ <li>keep: If the input column data is null, a vector would be created with the specified size
+ and NaN values. The vector would be used in the assembling process to represent the input
+ column data. If the input column data is a vector, the data would be used in the
+ assembling process even if it has a wrong size.
+ <li>skip: If the input column data is null or a vector with wrong size, the input row would
+ be filtered out and not be sent to downstream operators.
+ <li>error: If the input column data is null or a vector with wrong size, an exception would
+ be thrown.
+ </ul>
"""
def __init__(self, java_model=None):