You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/11/08 09:06:00 UTC

[flink-ml] branch master updated: [FLINK-29323] Add inputSizes parameter for VectorAssembler

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new ae5c011  [FLINK-29323] Add inputSizes parameter for VectorAssembler
ae5c011 is described below

commit ae5c0115e6f271e4bb125675a4e04f5d672814d5
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Tue Nov 8 17:05:55 2022 +0800

    [FLINK-29323] Add inputSizes parameter for VectorAssembler
    
    This closes #156.
---
 .../docs/operators/feature/vectorassembler.md      |  35 ++-
 .../examples/feature/VectorAssemblerExample.java   |   3 +-
 .../feature/vectorassembler/VectorAssembler.java   | 128 ++++++++---
 .../vectorassembler/VectorAssemblerParams.java     |  35 ++-
 .../flink/ml/feature/VectorAssemblerTest.java      | 253 +++++++++++++++++++--
 .../examples/ml/feature/vectorassembler_example.py |   1 +
 .../ml/lib/feature/tests/test_vectorassembler.py   |   5 +-
 .../pyflink/ml/lib/feature/vectorassembler.py      |  54 ++++-
 8 files changed, 441 insertions(+), 73 deletions(-)

diff --git a/docs/content/docs/operators/feature/vectorassembler.md b/docs/content/docs/operators/feature/vectorassembler.md
index 10f3fc3..84f064a 100644
--- a/docs/content/docs/operators/feature/vectorassembler.md
+++ b/docs/content/docs/operators/feature/vectorassembler.md
@@ -26,10 +26,20 @@ under the License.
 -->
 
 ## Vector Assembler
-
-Vector Assembler combines a given list of input columns into a vector column.
-Types of input columns must be either vector or numerical value.
-
+A Transformer which combines a given list of input columns into a vector column. Input columns
+must be numerical values or vectors whose sizes are specified by the `inputSizes` parameter.
+Invalid input data (null values or vectors with wrong sizes) is handled according to the
+strategy specified by the `handleInvalid` parameter as follows:
+
+- `keep`: If the input column data is null, a vector would be created with the specified size
+  and NaN values. The vector would be used in the assembling process to represent the input
+  column data. If the input column data is a vector, the data would be used in the assembling
+  process even if it has a wrong size.
+- `skip`: If the input column data is null or a vector with wrong size, the input row would be
+  filtered out and not be sent to downstream operators.
+- `error`: If the input column data is null or a vector with wrong size, an exception would be
+  thrown.
+
 ### Input Columns
 
 | Param name | Type          | Default | Description                     |
@@ -44,11 +54,12 @@ Types of input columns must be either vector or numerical value.
 
 ### Parameters
 
-| Key           | Default    | Type     | Required | Description                                                                    |
-|---------------|------------|----------|----------|--------------------------------------------------------------------------------|
-| inputCols     | `null`     | String[] | yes      | Input column names.                                                            |
-| outputCol     | `"output"` | String   | no       | Output column name.                                                            |
-| handleInvalid | `"error"`  | String   | no       | Strategy to handle invalid entries. Supported values: 'error', 'skip', 'keep'. |
+| Key             | Default    | Type      | Required | Description                                                                    |
+|-----------------|------------|-----------|----------|--------------------------------------------------------------------------------|
+| inputCols       | `null`     | String[]  | yes      | Input column names.                                                            |
+| outputCol       | `"output"` | String    | no       | Output column name.                                                            |
+| inputSizes      | `null`     | Integer[] | yes      | Sizes of the input elements to be assembled.                                   |
+| handleInvalid   | `"error"`  | String    | no       | Strategy to handle invalid entries. Supported values: 'error', 'skip', 'keep'. |
 
 ### Examples
 
@@ -95,7 +106,8 @@ public class VectorAssemblerExample {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
-                        .setOutputCol("assembledVec");
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5);
 
         // Uses the VectorAssembler object for feature transformations.
         Table outputTable = vectorAssembler.transform(inputTable)[0];
@@ -159,9 +171,10 @@ input_data_table = t_env.from_data_stream(
 vector_assembler = VectorAssembler() \
     .set_input_cols('vec', 'num', 'sparse_vec') \
     .set_output_col('assembled_vec') \
+    .set_input_sizes(2, 1, 5) \
     .set_handle_invalid('keep')
 
-# use the vector assembler model for feature engineering
+# use the vector assembler for feature engineering
 output = vector_assembler.transform(input_data_table)[0]
 
 # extract and display the results
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
index 50e51c2..0c14625 100644
--- a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorAssemblerExample.java
@@ -56,7 +56,8 @@ public class VectorAssemblerExample {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
-                        .setOutputCol("assembledVec");
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5);
 
         // Uses the VectorAssembler object for feature transformations.
         Table outputTable = vectorAssembler.transform(inputTable)[0];
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
index e3a01f1..e951f80 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssembler.java
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.feature.vectorassembler;
 
 import org.apache.flink.api.common.functions.FlatMapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
 import org.apache.flink.api.java.typeutils.RowTypeInfo;
 import org.apache.flink.ml.api.Transformer;
 import org.apache.flink.ml.common.datastream.TableUtils;
@@ -46,11 +47,21 @@ import java.util.HashMap;
 import java.util.Map;
 
 /**
- * A Transformer which combines a given list of input columns into a vector column. Types of input
- * columns must be either vector or numerical value.
+ * A Transformer which combines a given list of input columns into a vector column. Input columns
+ * must be numerical values or vectors whose sizes are specified by the {@link #INPUT_SIZES}
+ * parameter. Invalid input data (null values or vectors with wrong sizes) is handled according to
+ * the strategy specified by the {@link HasHandleInvalid} parameter as follows:
  *
- * <p>The `keep` option of {@link HasHandleInvalid} means that we output bad rows with output column
- * set to null.
+ * <ul>
+ *   <li>keep: If the input column data is null, a vector would be created with the specified size
+ *       and NaN values. The vector would be used in the assembling process to represent the input
+ *       column data. If the input column data is a vector, the data would be used in the assembling
+ *       process even if it has a wrong size.
+ *   <li>skip: If the input column data is null or a vector with wrong size, the input row would be
+ *       filtered out and not be sent to downstream operators.
+ *   <li>error: If the input column data is null or a vector with wrong size, an exception would be
+ *       thrown.
+ * </ul>
  */
 public class VectorAssembler
         implements Transformer<VectorAssembler>, VectorAssemblerParams<VectorAssembler> {
@@ -64,6 +75,7 @@ public class VectorAssembler
     @Override
     public Table[] transform(Table... inputs) {
         Preconditions.checkArgument(inputs.length == 1);
+        Preconditions.checkArgument(getInputSizes().length == getInputCols().length);
         StreamTableEnvironment tEnv =
                 (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
         RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
@@ -74,64 +86,107 @@ public class VectorAssembler
         DataStream<Row> output =
                 tEnv.toDataStream(inputs[0])
                         .flatMap(
-                                new AssemblerFunc(getInputCols(), getHandleInvalid()),
+                                new AssemblerFunction(
+                                        getInputCols(), getHandleInvalid(), getInputSizes()),
                                 outputTypeInfo);
         Table outputTable = tEnv.fromDataStream(output);
         return new Table[] {outputTable};
     }
 
-    private static class AssemblerFunc implements FlatMapFunction<Row, Row> {
+    private static class AssemblerFunction implements FlatMapFunction<Row, Row> {
         private final String[] inputCols;
         private final String handleInvalid;
+        private final Integer[] inputSizes;
+        private final boolean keepInvalid;
 
-        public AssemblerFunc(String[] inputCols, String handleInvalid) {
+        public AssemblerFunction(String[] inputCols, String handleInvalid, Integer[] inputSizes) {
             this.inputCols = inputCols;
             this.handleInvalid = handleInvalid;
+            this.inputSizes = inputSizes;
+            keepInvalid = handleInvalid.equals(HasHandleInvalid.KEEP_INVALID);
         }
 
         @Override
         public void flatMap(Row value, Collector<Row> out) {
-            int nnz = 0;
-            int vectorSize = 0;
             try {
-                for (String inputCol : inputCols) {
-                    Object object = value.getField(inputCol);
-                    Preconditions.checkNotNull(object, "Input column value should not be null.");
+                Tuple2<Integer, Integer> vectorSizeAndNnz = computeVectorSizeAndNnz(value);
+                int vectorSize = vectorSizeAndNnz.f0;
+                int nnz = vectorSizeAndNnz.f1;
+                Vector assembledVec =
+                        nnz * RATIO > vectorSize
+                                ? assembleDense(inputCols, value, vectorSize)
+                                : assembleSparse(inputCols, value, vectorSize, nnz);
+                out.collect(Row.join(value, Row.of(assembledVec)));
+            } catch (Exception e) {
+                if (handleInvalid.equals(ERROR_INVALID)) {
+                    throw new RuntimeException("Vector assembler failed with exception : " + e);
+                }
+            }
+        }
+
+        private Tuple2<Integer, Integer> computeVectorSizeAndNnz(Row value) {
+            int vectorSize = 0;
+            int nnz = 0;
+            for (int i = 0; i < inputCols.length; ++i) {
+                Object object = value.getField(inputCols[i]);
+                if (object != null) {
                     if (object instanceof Number) {
-                        nnz += 1;
+                        checkSize(inputSizes[i], 1);
+                        if (Double.isNaN(((Number) object).doubleValue()) && !keepInvalid) {
+                            throw new RuntimeException(
+                                    "Encountered NaN while assembling a row with handleInvalid = 'error'. Consider "
+                                            + "removing NaNs from dataset or using handleInvalid = 'keep' or 'skip'.");
+                        }
                         vectorSize += 1;
+                        nnz += 1;
                     } else if (object instanceof SparseVector) {
+                        int localSize = ((SparseVector) object).size();
+                        checkSize(inputSizes[i], localSize);
                         nnz += ((SparseVector) object).indices.length;
-                        vectorSize += ((SparseVector) object).size();
+                        vectorSize += localSize;
                     } else if (object instanceof DenseVector) {
+                        int localSize = ((DenseVector) object).size();
+                        checkSize(inputSizes[i], localSize);
+                        vectorSize += localSize;
                         nnz += ((DenseVector) object).size();
-                        vectorSize += ((DenseVector) object).size();
                     } else {
                         throw new IllegalArgumentException(
-                                "Input type has not been supported yet.");
+                                String.format(
+                                        "Input type %s has not been supported yet. Only Vector and Number types are supported.",
+                                        object.getClass()));
+                    }
+                } else {
+                    vectorSize += inputSizes[i];
+                    nnz += inputSizes[i];
+                    if (keepInvalid) {
+                        if (inputSizes[i] > 1) {
+                            DenseVector tmpVec = new DenseVector(inputSizes[i]);
+                            for (int j = 0; j < inputSizes[i]; ++j) {
+                                tmpVec.values[j] = Double.NaN;
+                            }
+                            value.setField(inputCols[i], tmpVec);
+                        } else {
+                            value.setField(inputCols[i], Double.NaN);
+                        }
+                    } else {
+                        throw new RuntimeException(
+                                "Input column value is null. Please check the input data or using handleInvalid = 'keep'.");
                     }
-                }
-            } catch (Exception e) {
-                switch (handleInvalid) {
-                    case ERROR_INVALID:
-                        throw e;
-                    case SKIP_INVALID:
-                        return;
-                    case KEEP_INVALID:
-                        out.collect(Row.join(value, Row.of((Object) null)));
-                        return;
-                    default:
-                        throw new UnsupportedOperationException(
-                                "Unsupported " + HANDLE_INVALID + " type: " + handleInvalid);
                 }
             }
+            return Tuple2.of(vectorSize, nnz);
+        }
 
-            boolean toDense = nnz * RATIO > vectorSize;
-            Vector assembledVec =
-                    toDense
-                            ? assembleDense(inputCols, value, vectorSize)
-                            : assembleSparse(inputCols, value, vectorSize, nnz);
-            out.collect(Row.join(value, Row.of(assembledVec)));
+        private void checkSize(int expectedSize, int currentSize) {
+            if (keepInvalid) {
+                return;
+            }
+            if (currentSize != expectedSize) {
+                throw new IllegalArgumentException(
+                        String.format(
+                                "Input vector/number size does not meet with expected. Expected size: %d, actual size: %s.",
+                                expectedSize, currentSize));
+            }
         }
     }
 
@@ -167,8 +222,7 @@ public class VectorAssembler
 
             } else {
                 DenseVector denseVector = (DenseVector) object;
-                System.arraycopy(
-                        denseVector.values, 0, values, currentOffset, denseVector.values.length);
+                System.arraycopy(denseVector.values, 0, values, currentOffset, denseVector.size());
 
                 currentOffset += denseVector.size();
             }
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
index cc3637e..d6b3c12 100644
--- a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorassembler/VectorAssemblerParams.java
@@ -21,6 +21,9 @@ package org.apache.flink.ml.feature.vectorassembler;
 import org.apache.flink.ml.common.param.HasHandleInvalid;
 import org.apache.flink.ml.common.param.HasInputCols;
 import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.IntArrayParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidator;
 
 /**
  * Params of {@link VectorAssembler}.
@@ -28,4 +31,34 @@ import org.apache.flink.ml.common.param.HasOutputCol;
  * @param <T> The class type of this instance.
  */
 public interface VectorAssemblerParams<T>
-        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
+        extends HasInputCols<T>, HasOutputCol<T>, HasHandleInvalid<T> {
+    Param<Integer[]> INPUT_SIZES =
+            new IntArrayParam(
+                    "inputSizes",
+                    "Sizes of the input elements to be assembled.",
+                    null,
+                    sizesValidator());
+
+    default Integer[] getInputSizes() {
+        return get(INPUT_SIZES);
+    }
+
+    default T setInputSizes(Integer... value) {
+        return set(INPUT_SIZES, value);
+    }
+
+    // Checks the inputSizes parameter.
+    static ParamValidator<Integer[]> sizesValidator() {
+        return inputSizes -> {
+            if (inputSizes == null) {
+                return false;
+            }
+            for (Integer size : inputSizes) {
+                if (size <= 0) {
+                    return false;
+                }
+            }
+            return inputSizes.length != 0;
+        };
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
index 420844b..f22d013 100644
--- a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorAssemblerTest.java
@@ -45,13 +45,14 @@ import java.util.List;
 
 import static org.junit.Assert.assertArrayEquals;
 import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.assertNull;
 
 /** Tests VectorAssembler. */
 public class VectorAssemblerTest extends AbstractTestBase {
 
     private StreamTableEnvironment tEnv;
     private Table inputDataTable;
+    private Table inputNullDataTable;
+    private Table inputNanDataTable;
 
     private static final List<Row> INPUT_DATA =
             Arrays.asList(
@@ -66,24 +67,82 @@ public class VectorAssemblerTest extends AbstractTestBase {
                             1.0,
                             Vectors.sparse(
                                     5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
-                    Row.of(2, null, null, null));
+                    Row.of(
+                            2,
+                            Vectors.dense(2.0, 2.1),
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
+
+    private static final List<Row> INPUT_NAN_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+                    Row.of(
+                            1,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+                    Row.of(
+                            2,
+                            Vectors.dense(2.0, 2.1),
+                            Double.NaN,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
+
+    private static final List<Row> INPUT_NULL_DATA =
+            Arrays.asList(
+                    Row.of(
+                            0,
+                            Vectors.dense(2.1, 3.1),
+                            1.0,
+                            Vectors.sparse(5, new int[] {3}, new double[] {1.0})),
+                    Row.of(
+                            1,
+                            null,
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})),
+                    Row.of(
+                            2,
+                            null,
+                            1.0,
+                            Vectors.sparse(
+                                    5, new int[] {4, 2, 3, 1}, new double[] {4.0, 2.0, 3.0, 1.0})));
 
     private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
             Vectors.sparse(8, new int[] {0, 1, 2, 6}, new double[] {2.1, 3.1, 1.0, 1.0});
     private static final DenseVector EXPECTED_OUTPUT_DATA_2 =
             Vectors.dense(2.1, 3.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+    private static final DenseVector EXPECTED_OUTPUT_DATA_3 =
+            Vectors.dense(2.0, 2.1, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+    private static final DenseVector EXPECTED_OUTPUT_DATA_4 =
+            Vectors.dense(Double.NaN, Double.NaN, 1.0, 0.0, 1.0, 2.0, 3.0, 4.0);
+    private static final DenseVector EXPECTED_OUTPUT_DATA_5 =
+            Vectors.dense(2.0, 2.1, Double.NaN, 0.0, 1.0, 2.0, 3.0, 4.0);
 
     @Before
     public void before() {
         Configuration config = new Configuration();
         config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
         StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
-        env.setParallelism(4);
+        env.setParallelism(2);
         env.enableCheckpointing(100);
         env.setRestartStrategy(RestartStrategies.noRestart());
+
         tEnv = StreamTableEnvironment.create(env);
         DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
         inputDataTable = tEnv.fromDataStream(dataStream).as("id", "vec", "num", "sparseVec");
+        DataStream<Row> nullDataStream = env.fromCollection(INPUT_NULL_DATA);
+        inputNullDataTable =
+                tEnv.fromDataStream(nullDataStream).as("id", "vec", "num", "sparseVec");
+        DataStream<Row> nanDataStream = env.fromCollection(INPUT_NAN_DATA);
+        inputNanDataTable = tEnv.fromDataStream(nanDataStream).as("id", "vec", "num", "sparseVec");
     }
 
     private void verifyOutputResult(Table output, String outputCol, int outputSize)
@@ -91,13 +150,14 @@ public class VectorAssemblerTest extends AbstractTestBase {
         DataStream<Row> dataStream = tEnv.toDataStream(output);
         List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
         assertEquals(outputSize, results.size());
+
         for (Row result : results) {
             if (result.getField(0) == (Object) 0) {
                 assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
             } else if (result.getField(0) == (Object) 1) {
                 assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
-            } else {
-                assertNull(result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 2) {
+                assertEquals(EXPECTED_OUTPUT_DATA_3, result.getField(outputCol));
             }
         }
     }
@@ -107,35 +167,147 @@ public class VectorAssemblerTest extends AbstractTestBase {
         VectorAssembler vectorAssembler = new VectorAssembler();
         assertEquals(HasHandleInvalid.ERROR_INVALID, vectorAssembler.getHandleInvalid());
         assertEquals("output", vectorAssembler.getOutputCol());
+
         vectorAssembler
                 .setInputCols("vec", "num", "sparseVec")
                 .setOutputCol("assembledVec")
+                .setInputSizes(2, 1, 5)
                 .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
         assertArrayEquals(new String[] {"vec", "num", "sparseVec"}, vectorAssembler.getInputCols());
         assertEquals(HasHandleInvalid.SKIP_INVALID, vectorAssembler.getHandleInvalid());
         assertEquals("assembledVec", vectorAssembler.getOutputCol());
+        assertArrayEquals(new Integer[] {2, 1, 5}, vectorAssembler.getInputSizes());
     }
 
     @Test
-    public void testKeepInvalid() throws Exception {
+    public void testOutputSchema() throws Exception {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
-                        .setInputCols("vec", "num", "sparseVec")
+                        .setInputCols("num")
                         .setOutputCol("assembledVec")
+                        .setInputSizes(1)
                         .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
         Table output = vectorAssembler.transform(inputDataTable)[0];
+
         assertEquals(
                 Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
                 output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testKeepInvalidWithNull() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputNullDataTable)[0];
+
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(3, results.size());
+
+        String outputCol = vectorAssembler.getOutputCol();
+        for (Row result : results) {
+            if (result.getField(0) == (Object) 0) {
+                assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 1) {
+                assertEquals(EXPECTED_OUTPUT_DATA_4, result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 2) {
+                assertEquals(EXPECTED_OUTPUT_DATA_4, result.getField(outputCol));
+            }
+        }
+    }
+
+    @Test
+    public void testKeepInvalidWithNaN() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputNanDataTable)[0];
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(3, results.size());
+
+        String outputCol = vectorAssembler.getOutputCol();
+        for (Row result : results) {
+            if (result.getField(0) == (Object) 0) {
+                assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 1) {
+                assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 2) {
+                assertEquals(EXPECTED_OUTPUT_DATA_5, result.getField(outputCol));
+            }
+        }
+    }
+
+    @Test
+    public void testKeepInvalidWithErrorSizes() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 4)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
         verifyOutputResult(output, vectorAssembler.getOutputCol(), 3);
     }
 
     @Test
-    public void testErrorInvalid() {
+    public void testErrorInvalidWithNull() {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
                         .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
+                        .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+
+        try {
+            Table outputTable = vectorAssembler.transform(inputNullDataTable)[0];
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Throwable e) {
+            assertEquals(
+                    "Vector assembler failed with exception : java.lang.RuntimeException: "
+                            + "Input column value is null. Please check the input data or using handleInvalid = 'keep'.",
+                    ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testErrorInvalidWithNaN() {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
+                        .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+
+        try {
+            Table outputTable = vectorAssembler.transform(inputNanDataTable)[0];
+            outputTable.execute().collect().next();
+            Assert.fail("Expected IllegalArgumentException");
+        } catch (Throwable e) {
+            assertEquals(
+                    "Vector assembler failed with exception : java.lang.RuntimeException: Encountered NaN "
+                            + "while assembling a row with handleInvalid = 'error'. Consider removing NaNs from "
+                            + "dataset or using handleInvalid = 'keep' or 'skip'.",
+                    ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testErrorInvalidWithErrorSizes() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 4)
                         .setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
         try {
             Table outputTable = vectorAssembler.transform(inputDataTable)[0];
@@ -143,37 +315,64 @@ public class VectorAssemblerTest extends AbstractTestBase {
             Assert.fail("Expected IllegalArgumentException");
         } catch (Throwable e) {
             assertEquals(
-                    "Input column value should not be null.",
+                    "Vector assembler failed with exception : java.lang.IllegalArgumentException: "
+                            + "Input vector/number size does not meet with expected. Expected size: 4, actual size: 5.",
                     ExceptionUtils.getRootCause(e).getMessage());
         }
     }
 
     @Test
-    public void testSkipInvalid() throws Exception {
+    public void testSkipInvalidWithNull() throws Exception {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
                         .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
                         .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
-        Table output = vectorAssembler.transform(inputDataTable)[0];
-        assertEquals(
-                Arrays.asList("id", "vec", "num", "sparseVec", "assembledVec"),
-                output.getResolvedSchema().getColumnNames());
+        Table output = vectorAssembler.transform(inputNullDataTable)[0];
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 1);
+    }
+
+    @Test
+    public void testSkipInvalidWithNaN() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
+                        .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        Table output = vectorAssembler.transform(inputNanDataTable)[0];
+
         verifyOutputResult(output, vectorAssembler.getOutputCol(), 2);
     }
 
+    @Test
+    public void testSkipInvalidWithErrorSizes() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("vec", "num", "sparseVec")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 4)
+                        .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+        verifyOutputResult(output, vectorAssembler.getOutputCol(), 0);
+    }
+
     @Test
     public void testSaveLoadAndTransform() throws Exception {
         VectorAssembler vectorAssembler =
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
                         .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
                         .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
         VectorAssembler loadedVectorAssembler =
                 TestUtils.saveAndReload(
                         tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
         Table output = loadedVectorAssembler.transform(inputDataTable)[0];
-        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
     }
 
     @Test
@@ -189,11 +388,33 @@ public class VectorAssemblerTest extends AbstractTestBase {
                 new VectorAssembler()
                         .setInputCols("vec", "num", "sparseVec")
                         .setOutputCol("assembledVec")
+                        .setInputSizes(2, 1, 5)
                         .setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+
         VectorAssembler loadedVectorAssembler =
                 TestUtils.saveAndReload(
                         tEnv, vectorAssembler, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
         Table output = loadedVectorAssembler.transform(inputDataTable)[0];
-        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 2);
+        verifyOutputResult(output, loadedVectorAssembler.getOutputCol(), 3);
+    }
+
+    @Test
+    public void testNumber2Vector() throws Exception {
+        VectorAssembler vectorAssembler =
+                new VectorAssembler()
+                        .setInputCols("num")
+                        .setOutputCol("assembledVec")
+                        .setInputSizes(1)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        Table output = vectorAssembler.transform(inputDataTable)[0];
+
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        for (Row result : results) {
+            if (result.getField(2) != null) {
+                assertEquals(result.getField(2), ((DenseVector) result.getField(4)).values[0]);
+            }
+        }
     }
 }
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
index eb12679..a37fc52 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
@@ -50,6 +50,7 @@ input_data_table = t_env.from_data_stream(
 vector_assembler = VectorAssembler() \
     .set_input_cols('vec', 'num', 'sparse_vec') \
     .set_output_col('assembled_vec') \
+    .set_input_sizes(2, 1, 5) \
     .set_handle_invalid('keep')
 
 # use the vector assembler for feature engineering
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
index 5362fcf..bcc7dd4 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorassembler.py
@@ -55,16 +55,19 @@ class VectorAssemblerTest(PyFlinkMLTestCase):
 
         vector_assembler.set_input_cols('vec', 'num', 'sparse_vec') \
             .set_output_col('assembled_vec') \
+            .set_input_sizes(2, 1, 5) \
             .set_handle_invalid('skip')
 
         self.assertEqual(('vec', 'num', 'sparse_vec'), vector_assembler.input_cols)
-        self.assertEqual('skip', vector_assembler.handle_invalid)
         self.assertEqual('assembled_vec', vector_assembler.output_col)
+        self.assertEqual((2, 1, 5), vector_assembler.input_sizes)
+        self.assertEqual('skip', vector_assembler.handle_invalid)
 
     def test_save_load_transform(self):
         vector_assembler = VectorAssembler() \
             .set_input_cols('vec', 'num', 'sparse_vec') \
             .set_output_col('assembled_vec') \
+            .set_input_sizes(2, 1, 5) \
             .set_handle_invalid('keep')
 
         path = os.path.join(self.temp_dir, 'test_save_load_transform_vector_assembler')
diff --git a/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py b/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
index 08bd689..04ae822 100644
--- a/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
+++ b/flink-ml-python/pyflink/ml/lib/feature/vectorassembler.py
@@ -16,9 +16,11 @@
 # limitations under the License.
 ################################################################################
 
+from typing import Tuple
 from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.core.param import IntArrayParam, ParamValidator
 from pyflink.ml.lib.feature.common import JavaFeatureTransformer
-from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasHandleInvalid
+from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasHandleInvalid, Param
 
 
 class _VectorAssemblerParams(
@@ -27,21 +29,61 @@ class _VectorAssemblerParams(
     HasOutputCol,
     HasHandleInvalid
 ):
+
+    """
+    Checks the inputSizes parameter.
+    """
+    def SizesValidator(self) -> ParamValidator[Tuple[int]]:
+        class SizesValidator(ParamValidator[Tuple[int]]):
+            def validate(self, indices: Tuple[int]) -> bool:
+                if indices is None:
+                    return False
+                for val in indices:
+                    if val <= 0:
+                        return False
+                return len(indices) != 0
+        return SizesValidator()
+
     """
     Params for :class:`VectorAssembler`.
     """
 
+    INPUT_SIZES: Param[Tuple[int, ...]] = IntArrayParam(
+        "input_sizes",
+        "Sizes of the input elements to be assembled.",
+        None,
+        SizesValidator(None))
+
     def __init__(self, java_params):
         super(_VectorAssemblerParams, self).__init__(java_params)
 
+    def set_input_sizes(self, *sizes: int):
+        return self.set(self.INPUT_SIZES, sizes)
+
+    def get_input_sizes(self) -> Tuple[int, ...]:
+        return self.get(self.INPUT_SIZES)
+
+    @property
+    def input_sizes(self) -> Tuple[int, ...]:
+        return self.get_input_sizes()
+
 
 class VectorAssembler(JavaFeatureTransformer, _VectorAssemblerParams):
     """
-    A Transformer which combines a given list of input columns into a vector column. Types of
-    input columns must be either vector or numerical value.
-
-    The `keep` option of :class:HasHandleInvalid means that we output bad rows with output column
-    set to null.
+    A Transformer which combines a given list of input columns into a vector column.
+
+    Input columns must be numerical values or vectors whose sizes are specified by the
+    `INPUT_SIZES` parameter. Invalid input data with null values or values with wrong sizes is
+    handled according to the strategy specified by the :class:`HasHandleInvalid` parameter:
+
+    - keep: If the input column data is null, a vector is created with the specified size
+      and NaN values. The vector is used in the assembling process to represent the input
+      column data. If the input column data is a vector, the data is used in the assembling
+      process even if it has a wrong size.
+    - skip: If the input column data is null or a vector with wrong size, the input row is
+      filtered out and not sent to downstream operators.
+    - error: If the input column data is null or a vector with wrong size, an exception is
+      thrown.
     """
 
     def __init__(self, java_model=None):