You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/07/26 02:02:45 UTC

[flink-ml] branch master updated: [FLINK-28501] Add Transformer and Estimator for VectorIndexer

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new e31101f  [FLINK-28501] Add Transformer and Estimator for VectorIndexer
e31101f is described below

commit e31101fdb614ea6366111f761e3176c5c2ab6ef0
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Tue Jul 26 10:02:41 2022 +0800

    [FLINK-28501] Add Transformer and Estimator for VectorIndexer
    
    This closes #134.
---
 .../org/apache/flink/ml/linalg/DenseVector.java    |   5 +
 .../org/apache/flink/ml/linalg/SparseVector.java   |  24 +-
 .../java/org/apache/flink/ml/linalg/Vector.java    |   3 +
 .../apache/flink/ml/linalg/DenseVectorTest.java    |  10 +
 .../apache/flink/ml/linalg/SparseVectorTest.java   |  13 +
 .../ml/examples/feature/VectorIndexerExample.java  |  81 ++++++
 .../ml/feature/vectorindexer/VectorIndexer.java    | 255 ++++++++++++++++++
 .../feature/vectorindexer/VectorIndexerModel.java  | 200 +++++++++++++++
 .../vectorindexer/VectorIndexerModelData.java      | 133 ++++++++++
 .../vectorindexer/VectorIndexerModelParams.java    |  34 +--
 .../feature/vectorindexer/VectorIndexerParams.java |  46 ++++
 .../apache/flink/ml/feature/VectorIndexerTest.java | 285 +++++++++++++++++++++
 .../examples/ml/feature/vectorindexer_example.py   |  80 ++++++
 flink-ml-python/pyflink/ml/core/linalg.py          |  27 +-
 .../pyflink/ml/core/tests/test_linalg.py           |  15 ++
 .../ml/lib/feature/tests/test_vectorindexer.py     | 103 ++++++++
 .../pyflink/ml/lib/feature/vectorindexer.py        | 127 +++++++++
 17 files changed, 1415 insertions(+), 26 deletions(-)

diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
index b4f0603..ff08503 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/DenseVector.java
@@ -46,6 +46,11 @@ public class DenseVector implements Vector {
         return values[i];
     }
 
+    @Override
+    public void set(int i, double value) {
+        values[i] = value;
+    }
+
     @Override
     public double[] toArray() {
         return values;
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
index f241e30..3c043fc 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/SparseVector.java
@@ -29,8 +29,8 @@ import java.util.Objects;
 @TypeInfo(SparseVectorTypeInfoFactory.class)
 public class SparseVector implements Vector {
     public final int n;
-    public final int[] indices;
-    public final double[] values;
+    public int[] indices;
+    public double[] values;
 
     public SparseVector(int n, int[] indices, double[] values) {
         this.n = n;
@@ -56,6 +56,26 @@ public class SparseVector implements Vector {
         return 0.;
     }
 
+    @Override
+    public void set(int i, double value) {
+        int pos = Arrays.binarySearch(indices, i);
+        if (pos >= 0) {
+            values[pos] = value;
+        } else if (value != 0.0) {
+            // A miss returns -(insertionPoint) - 1; a negative i would otherwise be inserted.
+            Preconditions.checkArgument(i >= 0 && i < n, "Index out of bounds: " + i);
+            int at = -pos - 1;
+            int[] newIndices = Arrays.copyOf(indices, indices.length + 1);
+            double[] newValues = Arrays.copyOf(values, indices.length + 1);
+            System.arraycopy(indices, at, newIndices, at + 1, indices.length - at);
+            System.arraycopy(values, at, newValues, at + 1, indices.length - at);
+            newIndices[at] = i;
+            newValues[at] = value;
+            indices = newIndices;
+            values = newValues;
+        }
+    }
+
     @Override
     public double[] toArray() {
         double[] result = new double[n];
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
index 21718b3..976e9a7 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
@@ -29,6 +29,9 @@ public interface Vector extends Serializable {
     /** Gets the value of the ith element. */
     double get(int i);
 
+    /** Sets the value of the ith element. */
+    void set(int i, double value);
+
     /** Converts the instance to a double array. */
     double[] toArray();
 
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
index 12ebc61..427403d 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/DenseVectorTest.java
@@ -21,6 +21,7 @@ package org.apache.flink.ml.linalg;
 import org.junit.Test;
 
 import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
 
 /** Tests the behavior of {@link DenseVector}. */
 public class DenseVectorTest {
@@ -37,4 +38,13 @@ public class DenseVectorTest {
         assertArrayEquals(denseVec.values, new double[] {1, 2, 3}, TOLERANCE);
         assertArrayEquals(clonedDenseVec.values, new double[] {-1, 2, 3}, TOLERANCE);
     }
+
+    @Test
+    public void testGetAndSet() {
+        DenseVector denseVec = Vectors.dense(1, 2, 3);
+        assertEquals(1, denseVec.get(0), TOLERANCE);
+
+        denseVec.set(0, 2);
+        assertEquals(2, denseVec.get(0), TOLERANCE);
+    }
 }
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
index 163586d..916963f 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/linalg/SparseVectorTest.java
@@ -147,4 +147,17 @@ public class SparseVectorTest {
         assertArrayEquals(clonedSparseVec.indices, new int[] {0, 2});
         assertArrayEquals(clonedSparseVec.values, new double[] {-1, 3}, TOLERANCE);
     }
+
+    @Test
+    public void testGetAndSet() {
+        SparseVector sparseVec = Vectors.sparse(4, new int[] {2}, new double[] {0.3});
+        assertEquals(0, sparseVec.get(0), TOLERANCE);
+        assertEquals(0.3, sparseVec.get(2), TOLERANCE);
+
+        sparseVec.set(2, 0.5);
+        assertEquals(0.5, sparseVec.get(2), TOLERANCE);
+
+        sparseVec.set(0, 0.1);
+        assertEquals(0.1, sparseVec.get(0), TOLERANCE);
+    }
 }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java
new file mode 100644
index 0000000..b80506e
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/VectorIndexerExample.java
@@ -0,0 +1,81 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Simple program that creates a VectorIndexer instance and uses it for feature engineering. */
+public class VectorIndexerExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        List<Row> trainInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1)),
+                        Row.of(Vectors.dense(2, -1)),
+                        Row.of(Vectors.dense(3, 1)),
+                        Row.of(Vectors.dense(4, 0)),
+                        Row.of(Vectors.dense(5, 0)));
+
+        List<Row> predictInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 2)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, -1)));
+
+        Table trainTable = tEnv.fromDataStream(env.fromCollection(trainInput)).as("input");
+        Table predictTable = tEnv.fromDataStream(env.fromCollection(predictInput)).as("input");
+
+        // Creates a VectorIndexer object and initializes its parameters.
+        VectorIndexer vectorIndexer =
+                new VectorIndexer()
+                        .setInputCol("input")
+                        .setOutputCol("output")
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID)
+                        .setMaxCategories(3);
+
+        // Trains the VectorIndexer Model.
+        VectorIndexerModel model = vectorIndexer.fit(trainTable);
+
+        // Uses the VectorIndexer Model for predictions.
+        Table outputTable = model.transform(predictTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            System.out.printf(
+                    "Input Value: %s \tOutput Value: %s\n",
+                    row.getField(vectorIndexer.getInputCol()),
+                    row.getField(vectorIndexer.getOutputCol()));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
new file mode 100644
index 0000000..74d6a6b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexer.java
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.functions.ReduceFunction;
+import org.apache.flink.api.common.state.ListState;
+import org.apache.flink.api.common.state.ListStateDescriptor;
+import org.apache.flink.api.common.typeinfo.TypeHint;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.iteration.operator.OperatorStateUtils;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.runtime.state.StateInitializationContext;
+import org.apache.flink.runtime.state.StateSnapshotContext;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.operators.AbstractStreamOperator;
+import org.apache.flink.streaming.api.operators.BoundedOneInput;
+import org.apache.flink.streaming.api.operators.OneInputStreamOperator;
+import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Map;
+
+/**
+ * An Estimator which implements the vector indexing algorithm.
+ *
+ * <p>A vector indexer maps each column of the input vector into a continuous/categorical feature.
+ * Whether one feature is transformed into a continuous or categorical feature depends on the number
+ * of distinct values in this column. If the number of distinct values in one column is greater than
+ * a specified parameter (i.e., maxCategories), the corresponding output column is unchanged.
+ * Otherwise, it is transformed into a categorical value. For categorical outputs, the indices are
+ * in [0, numDistinctValuesInThisColumn), with numDistinctValuesInThisColumn reserved for `keep`.
+ *
+ * <p>The output model is organized in ascending order except that 0.0 is always mapped to 0 (for
+ * sparsity). We list two examples here:
+ *
+ * <ul>
+ *   <li>If one column contains {-1.0, 1.0}, then -1.0 is encoded as 0 and 1.0 is encoded
+ *       as 1.
+ *   <li>If one column contains {-1.0, 0.0, 1.0}, then -1.0 is encoded as 1, 0.0 is encoded
+ *       as 0 and 1.0 is encoded as 2.
+ * </ul>
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we put the invalid entries in a
+ * special bucket, whose index is the number of distinct values in this column.
+ */
+public class VectorIndexer
+        implements Estimator<VectorIndexer, VectorIndexerModel>,
+                VectorIndexerParams<VectorIndexer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public VectorIndexer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public VectorIndexerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        int maxCategories = getMaxCategories();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        // Collects the distinct values of each column; a column over maxCategories becomes null.
+        DataStream<HashSet<Double>[]> localDistinctDoubles =
+                tEnv.toDataStream(inputs[0])
+                        .transform(
+                                "computeDistinctDoublesOperator",
+                                TypeInformation.of(new TypeHint<HashSet<Double>[]>() {}),
+                                new ComputeDistinctDoublesOperator(getInputCol(), maxCategories));
+        // Merges the collected sets column by column; null on either side keeps the column null.
+        DataStream<HashSet<Double>[]> distinctDoubles =
+                DataStreamUtils.reduce(
+                        localDistinctDoubles,
+                        (ReduceFunction<HashSet<Double>[]>)
+                                (value1, value2) -> {
+                                    for (int i = 0; i < value1.length; i++) {
+                                        if (value1[i] == null || value2[i] == null) {
+                                            value1[i] = null;
+                                        } else {
+                                            value1[i].addAll(value2[i]);
+                                        }
+                                    }
+                                    return value1;
+                                });
+        // Turns the merged sets into the model data; this map runs with parallelism 1.
+        DataStream<VectorIndexerModelData> modelData =
+                distinctDoubles.map(new ModelGenerator(maxCategories));
+        modelData.getTransformation().setParallelism(1);
+
+        VectorIndexerModel model =
+                new VectorIndexerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, paramMap);
+        return model;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static VectorIndexer load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Computes the distinct doubles by columns. If the number of distinct values in one column is
+     * greater than maxCategories, the corresponding returned HashSet is null.
+     */
+    private static class ComputeDistinctDoublesOperator
+            extends AbstractStreamOperator<HashSet<Double>[]>
+            implements OneInputStreamOperator<Row, HashSet<Double>[]>, BoundedOneInput {
+        /** The name of input column. */
+        private final String inputCol;
+        /** Max number of categories. */
+        private final int maxCategories;
+        /** The distinct doubles of each column. */
+        private HashSet<Double>[] doublesByColumn;
+        /** The state of doublesByColumn. */
+        private ListState<HashSet<Double>[]> doublesByColumnState;
+
+        public ComputeDistinctDoublesOperator(String inputCol, int maxCategories) {
+            this.inputCol = inputCol;
+            this.maxCategories = maxCategories;
+        }
+
+        @Override
+        public void endInput() {
+            if (doublesByColumn != null) {
+                output.collect(new StreamRecord<>(doublesByColumn));
+            }
+            doublesByColumnState.clear();
+        }
+
+        @Override
+        public void processElement(StreamRecord<Row> element) {
+            Vector vector = (Vector) element.getValue().getField(inputCol);
+            if (doublesByColumn == null) {
+                // First record: its size fixes the number of tracked columns.
+                doublesByColumn = new HashSet[vector.size()];
+                for (int col = 0; col < doublesByColumn.length; col++) {
+                    doublesByColumn[col] = new HashSet<>();
+                }
+            }
+
+            Preconditions.checkState(
+                    vector.size() == doublesByColumn.length,
+                    "The size of the all input vectors should be the same.");
+            double[] values = vector.toDense().values;
+            for (int col = 0; col < values.length; col++) {
+                HashSet<Double> seen = doublesByColumn[col];
+                if (seen != null) {
+                    seen.add(values[col]);
+                    if (seen.size() > maxCategories) {
+                        doublesByColumn[col] = null;
+                    }
+                }
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            doublesByColumnState =
+                    context.getOperatorStateStore()
+                            .getListState(
+                                    new ListStateDescriptor<>(
+                                            "doublesByColumnState",
+                                            TypeInformation.of(
+                                                    new TypeHint<HashSet<Double>[]>() {})));
+
+            OperatorStateUtils.getUniqueElement(doublesByColumnState, "doublesByColumnState")
+                    .ifPresent(x -> doublesByColumn = x);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            doublesByColumnState.update(Collections.singletonList(doublesByColumn));
+        }
+    }
+
+    /**
+     * Generates the {@link VectorIndexerModelData} from the merged per-column distinct doubles.
+     */
+    private static class ModelGenerator
+            implements MapFunction<HashSet<Double>[], VectorIndexerModelData> {
+        private final int maxCategories;
+
+        public ModelGenerator(int maxCategories) {
+            this.maxCategories = maxCategories;
+        }
+
+        @Override
+        public VectorIndexerModelData map(HashSet<Double>[] distinctDoubles) {
+            Map<Integer, Map<Double, Integer>> categoryMaps = new HashMap<>();
+            for (int col = 0; col < distinctDoubles.length; col++) {
+                HashSet<Double> distinct = distinctDoubles[col];
+                if (distinct == null || distinct.size() > maxCategories) {
+                    // Too many distinct values: no mapping is generated for this column.
+                    continue;
+                }
+                double[] sorted = distinct.stream().mapToDouble(Double::doubleValue).toArray();
+                Arrays.sort(sorted);
+                // Move 0.0 (if present) to the front by shifting the smaller values right.
+                int zeroAt = Arrays.binarySearch(sorted, 0);
+                if (zeroAt > 0) {
+                    System.arraycopy(sorted, 0, sorted, 1, zeroAt);
+                    sorted[0] = 0;
+                }
+                Map<Double, Integer> valueToIndex = new HashMap<>(sorted.length);
+                for (int idx = 0; idx < sorted.length; idx++) {
+                    valueToIndex.put(sorted[idx], idx);
+                }
+                categoryMaps.put(col, valueToIndex);
+            }
+
+            return new VectorIndexerModelData(categoryMaps);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
new file mode 100644
index 0000000..3becd23
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModel.java
@@ -0,0 +1,200 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.functions.RichFlatMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Collector;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * A Model which encodes input vector to an output vector using the model data computed by {@link
+ * VectorIndexer}.
+ *
+ * <p>The `keep` option of {@link HasHandleInvalid} means that we put the invalid entries in a
+ * special bucket, whose index is the number of distinct values in this column.
+ */
+public class VectorIndexerModel
+        implements Model<VectorIndexerModel>, VectorIndexerModelParams<VectorIndexerModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public VectorIndexerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked, rawtypes")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        String inputCol = getInputCol();
+        String outputCol = getOutputCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), outputCol));
+
+        final String broadcastModelKey = "broadcastModelKey";
+        DataStream<VectorIndexerModelData> modelDataStream =
+                VectorIndexerModelData.getModelDataStream(modelDataTable);
+
+        DataStream<Row> result =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(tEnv.toDataStream(inputs[0])),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream inputData = inputList.get(0);
+                            return inputData.flatMap(
+                                    new FindIndex(broadcastModelKey, inputCol, getHandleInvalid()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(result)};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                VectorIndexerModelData.getModelDataStream(modelDataTable),
+                path,
+                new VectorIndexerModelData.ModelDataEncoder());
+    }
+
+    public static VectorIndexerModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        VectorIndexerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new VectorIndexerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public VectorIndexerModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    /** Finds the index for the input vector using the model data. */
+    private static class FindIndex extends RichFlatMapFunction<Row, Row> {
+        private final String broadcastModelKey;
+        private final String inputCol;
+        private final String handleInValid;
+        private Map<Integer, Map<Double, Integer>> categoryMaps;
+
+        public FindIndex(String broadcastModelKey, String inputCol, String handleInValid) {
+            this.broadcastModelKey = broadcastModelKey;
+            this.inputCol = inputCol;
+            this.handleInValid = handleInValid;
+        }
+
+        @Override
+        public void flatMap(Row input, Collector<Row> out) {
+            // Lazily read the broadcast model data on the first record.
+            if (categoryMaps == null) {
+                VectorIndexerModelData modelData =
+                        (VectorIndexerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastModelKey).get(0);
+                categoryMaps = modelData.categoryMaps;
+            }
+
+            Vector indexed = ((Vector) input.getField(inputCol)).clone();
+            for (Map.Entry<Integer, Map<Double, Integer>> column : categoryMaps.entrySet()) {
+                int columnId = column.getKey();
+                Integer category =
+                        getMapping(indexed.get(columnId), column.getValue(), handleInValid);
+                if (category == null) {
+                    // A null mapping means the row is skipped: emit nothing for it.
+                    return;
+                }
+                indexed.set(columnId, category);
+            }
+
+            out.collect(Row.join(input, Row.of(indexed)));
+        }
+    }
+
+    /**
+     * Maps the input feature to a categorical value using the mappings.
+     *
+     * @param feature The input continuous feature.
+     * @param mapping The mappings from continuous features to categorical features.
+     * @param handleInValid The way to handle invalid features.
+     * @return The categorical value. Returns null if invalid values are skipped.
+     */
+    private static Integer getMapping(
+            double feature, Map<Double, Integer> mapping, String handleInValid) {
+        Integer category = mapping.get(feature);
+        if (category != null) {
+            return category;
+        }
+        switch (handleInValid) {
+            case SKIP_INVALID:
+                return null;
+            case ERROR_INVALID:
+                throw new RuntimeException(
+                        "The input contains unseen double: "
+                                + feature
+                                + ". See "
+                                + HANDLE_INVALID
+                                + " parameter for more options.");
+            case KEEP_INVALID:
+                return mapping.size();
+            default:
+                throw new UnsupportedOperationException(
+                        "Unsupported " + HANDLE_INVALID + " type: " + handleInValid);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java
new file mode 100644
index 0000000..2d896eb
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelData.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.DoubleSerializer;
+import org.apache.flink.api.common.typeutils.base.IntSerializer;
+import org.apache.flink.api.common.typeutils.base.MapSerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+import java.util.Map;
+
+/**
+ * Model data of {@link VectorIndexerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to DataStream, and classes
+ * to save/load model data.
+ */
+public class VectorIndexerModelData {
+    /**
+     * Index of feature values. Keys are column indices. Values are mappings from the original
+     * continuous feature values to 0-based categorical indices. If a column is not in this map,
+     * it is treated as a continuous feature.
+     */
+    public Map<Integer, Map<Double, Integer>> categoryMaps;
+
+    public VectorIndexerModelData(Map<Integer, Map<Double, Integer>> categoryMaps) {
+        this.categoryMaps = categoryMaps;
+    }
+
+    public VectorIndexerModelData() {}
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * @param modelData The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<VectorIndexerModelData> getModelDataStream(Table modelData) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelData).getTableEnvironment();
+        return tEnv.toDataStream(modelData)
+                .map(
+                        x ->
+                                new VectorIndexerModelData(
+                                        (Map<Integer, Map<Double, Integer>>) x.getField(0)));
+    }
+
+    /** Data encoder for {@link VectorIndexerModel}. */
+    public static class ModelDataEncoder implements Encoder<VectorIndexerModelData> {
+
+        @Override
+        public void encode(VectorIndexerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputViewStreamWrapper outputViewStreamWrapper =
+                    new DataOutputViewStreamWrapper(outputStream);
+
+            MapSerializer<Integer, Map<Double, Integer>> mapSerializer =
+                    new MapSerializer<>(
+                            IntSerializer.INSTANCE,
+                            new MapSerializer<>(DoubleSerializer.INSTANCE, IntSerializer.INSTANCE));
+
+            mapSerializer.serialize(modelData.categoryMaps, outputViewStreamWrapper);
+        }
+    }
+
+    /** Data decoder for {@link VectorIndexerModel}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<VectorIndexerModelData> {
+
+        @Override
+        public Reader<VectorIndexerModelData> createReader(
+                Configuration configuration, FSDataInputStream inputStream) {
+            return new Reader<VectorIndexerModelData>() {
+
+                @Override
+                public VectorIndexerModelData read() throws IOException {
+                    try {
+                        DataInputViewStreamWrapper inputViewStreamWrapper =
+                                new DataInputViewStreamWrapper(inputStream);
+                        MapSerializer<Integer, Map<Double, Integer>> mapSerializer =
+                                new MapSerializer<>(
+                                        IntSerializer.INSTANCE,
+                                        new MapSerializer<>(
+                                                DoubleSerializer.INSTANCE, IntSerializer.INSTANCE));
+                        Map<Integer, Map<Double, Integer>> categoryMaps =
+                                mapSerializer.deserialize(inputViewStreamWrapper);
+                        return new VectorIndexerModelData(categoryMaps);
+                    } catch (EOFException e) {
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    inputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<VectorIndexerModelData> getProducedType() {
+            return TypeInformation.of(VectorIndexerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
similarity index 58%
copy from flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
copy to flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
index 21718b3..e9e89ee 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/linalg/Vector.java
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerModelParams.java
@@ -16,28 +16,16 @@
  * limitations under the License.
  */
 
-package org.apache.flink.ml.linalg;
+package org.apache.flink.ml.feature.vectorindexer;
 
-import java.io.Serializable;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
 
-/** A vector of double values. */
-public interface Vector extends Serializable {
-
-    /** Gets the size of the vector. */
-    int size();
-
-    /** Gets the value of the ith element. */
-    double get(int i);
-
-    /** Converts the instance to a double array. */
-    double[] toArray();
-
-    /** Converts the instance to a dense vector. */
-    DenseVector toDense();
-
-    /** Converts the instance to a sparse vector. */
-    SparseVector toSparse();
-
-    /** Makes a deep copy of the vector. */
-    Vector clone();
-}
+/**
+ * Params for {@link VectorIndexerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorIndexerModelParams<T>
+        extends HasInputCol<T>, HasOutputCol<T>, HasHandleInvalid<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java
new file mode 100644
index 0000000..9af8b8a
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/vectorindexer/VectorIndexerParams.java
@@ -0,0 +1,46 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.vectorindexer;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link VectorIndexer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface VectorIndexerParams<T> extends VectorIndexerModelParams<T> {
+    Param<Integer> MAX_CATEGORIES =
+            new IntParam(
+                    "maxCategories",
+                    "Threshold for the number of values a categorical feature can take (>= 2). "
+                            + "If a feature is found to have > maxCategories values, then it is declared continuous.",
+                    20,
+                    ParamValidators.gtEq(2));
+
+    default T setMaxCategories(int value) {
+        return set(MAX_CATEGORIES, value);
+    }
+
+    default int getMaxCategories() {
+        return get(MAX_CATEGORIES);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
new file mode 100644
index 0000000..17ffda5
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/VectorIndexerTest.java
@@ -0,0 +1,285 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.common.param.HasHandleInvalid;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexer;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModel;
+import org.apache.flink.ml.feature.vectorindexer.VectorIndexerModelData;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.Comparator;
+import java.util.HashMap;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link VectorIndexer} and {@link VectorIndexerModel}. */
+public class VectorIndexerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table trainInputTable;
+    private Table testInputTable;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<Row> trainInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(1, 1)),
+                        Row.of(Vectors.dense(2, -1)),
+                        Row.of(Vectors.dense(3, 1)),
+                        Row.of(Vectors.dense(4, 0)),
+                        Row.of(Vectors.dense(5, 0)));
+        List<Row> testInput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 2)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, -1)));
+        trainInputTable = tEnv.fromDataStream(env.fromCollection(trainInput)).as("input");
+        testInputTable = tEnv.fromDataStream(env.fromCollection(testInput)).as("input");
+    }
+
+    @Test
+    public void testParam() {
+        VectorIndexer vectorIndexer = new VectorIndexer();
+        assertEquals("input", vectorIndexer.getInputCol());
+        assertEquals("output", vectorIndexer.getOutputCol());
+        assertEquals(20, vectorIndexer.getMaxCategories());
+        assertEquals(HasHandleInvalid.ERROR_INVALID, vectorIndexer.getHandleInvalid());
+
+        vectorIndexer
+                .setInputCol("test_input")
+                .setOutputCol("test_output")
+                .setMaxCategories(3)
+                .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+
+        assertEquals("test_input", vectorIndexer.getInputCol());
+        assertEquals("test_output", vectorIndexer.getOutputCol());
+        assertEquals(3, vectorIndexer.getMaxCategories());
+        assertEquals(HasHandleInvalid.KEEP_INVALID, vectorIndexer.getHandleInvalid());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        VectorIndexer vectorIndexer = new VectorIndexer();
+        Table output = vectorIndexer.fit(trainInputTable).transform(trainInputTable)[0];
+
+        assertEquals(
+                Arrays.asList(vectorIndexer.getInputCol(), vectorIndexer.getOutputCol()),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredictOnSparseInput() throws Exception {
+        List<Row> sparseTrainInput =
+                Arrays.asList(
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {1})),
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {2, -1})),
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {3, 1})),
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {4})),
+                        Row.of(Vectors.sparse(2, new int[] {0}, new double[] {5})));
+
+        List<Row> sparseTestInput =
+                Collections.singletonList(
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {0, 2})));
+        Table sparseTrainTable =
+                tEnv.fromDataStream(env.fromCollection(sparseTrainInput)).as("input");
+        Table sparseTestTable =
+                tEnv.fromDataStream(env.fromCollection(sparseTestInput)).as("input");
+
+        Table output =
+                new VectorIndexer()
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID)
+                        .setMaxCategories(3)
+                        .fit(sparseTrainTable)
+                        .transform(sparseTestTable)[0];
+
+        List<Row> expectedOutput =
+                Collections.singletonList(
+                        Row.of(Vectors.sparse(2, new int[] {0, 1}, new double[] {0, 3})));
+        verifyPredictionResult(expectedOutput, output, "output");
+    }
+
+    @Test
+    public void testFitAndPredictWithLargeMaxCategories() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer()
+                        .setMaxCategories(Integer.MAX_VALUE)
+                        .setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+
+        Table output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+    }
+
+    @Test
+    public void testFitAndPredictWithHandleInvalid() throws Exception {
+        Table output;
+        List<Row> expectedOutput;
+        VectorIndexer vectorIndexer = new VectorIndexer().setMaxCategories(3);
+
+        // Keeps invalid data.
+        expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(0, 3)),
+                        Row.of(Vectors.dense(0, 0)),
+                        Row.of(Vectors.dense(0, 1)));
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+
+        // Skips invalid data.
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.SKIP_INVALID);
+        expectedOutput = Arrays.asList(Row.of(Vectors.dense(0, 0)), Row.of(Vectors.dense(0, 1)));
+        output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+
+        // Throws an exception on invalid data.
+        vectorIndexer.setHandleInvalid(HasHandleInvalid.ERROR_INVALID);
+        try {
+            output = vectorIndexer.fit(trainInputTable).transform(testInputTable)[0];
+            IteratorUtils.toList(tEnv.toDataStream(output).executeAndCollect());
+            fail();
+        } catch (Throwable e) {
+            assertEquals(
+                    "The input contains unseen double: 2.0. "
+                            + "See "
+                            + HasHandleInvalid.HANDLE_INVALID
+                            + " parameter for more options.",
+                    ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer().setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        vectorIndexer =
+                TestUtils.saveAndReload(
+                        tEnv, vectorIndexer, tempFolder.newFolder().getAbsolutePath());
+
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+        model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Collections.singletonList("categoryMaps"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(testInputTable)[0];
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, vectorIndexer.getOutputCol());
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        VectorIndexer vectorIndexer = new VectorIndexer().setMaxCategories(3);
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        assertEquals(
+                Collections.singletonList("categoryMaps"),
+                modelDataTable.getResolvedSchema().getColumnNames());
+
+        List<VectorIndexerModelData> collectedModelData =
+                (List<VectorIndexerModelData>)
+                        IteratorUtils.toList(
+                                VectorIndexerModelData.getModelDataStream(modelDataTable)
+                                        .executeAndCollect());
+
+        assertEquals(1, collectedModelData.size());
+        HashMap<Double, Integer> column1ModelData = new HashMap<>();
+        column1ModelData.put(-1.0, 1);
+        column1ModelData.put(0.0, 0);
+        column1ModelData.put(1.0, 2);
+        assertEquals(
+                Collections.singletonMap(1, column1ModelData),
+                collectedModelData.get(0).categoryMaps);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        VectorIndexer vectorIndexer =
+                new VectorIndexer().setHandleInvalid(HasHandleInvalid.KEEP_INVALID);
+        VectorIndexerModel model = vectorIndexer.fit(trainInputTable);
+
+        VectorIndexerModel newModel = new VectorIndexerModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(testInputTable)[0];
+
+        List<Row> expectedOutput =
+                Arrays.asList(
+                        Row.of(Vectors.dense(5, 3)),
+                        Row.of(Vectors.dense(5, 0)),
+                        Row.of(Vectors.dense(5, 1)));
+        verifyPredictionResult(expectedOutput, output, newModel.getOutputCol());
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(List<Row> expectedOutput, Table output, String outputCol)
+            throws Exception {
+        List<Row> collectedResult =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output.select(Expressions.$(outputCol)))
+                                .executeAndCollect());
+        compareResultCollections(
+                expectedOutput,
+                collectedResult,
+                Comparator.comparingInt(o -> (o.getField(0)).hashCode()));
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py
new file mode 100644
index 0000000..a913444
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorindexer_example.py
@@ -0,0 +1,80 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains a VectorIndexer model and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.vectorindexer import VectorIndexer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+train_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(1, 1),),
+        (Vectors.dense(2, -1),),
+        (Vectors.dense(3, 1),),
+        (Vectors.dense(4, 0),),
+        (Vectors.dense(5, 0),)
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [DenseVectorTypeInfo(), ])))
+
+predict_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(0, 2),),
+        (Vectors.dense(0, 0),),
+        (Vectors.dense(0, -1),),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [DenseVectorTypeInfo(), ])))
+
+# Creates a VectorIndexer object and initializes its parameters.
+vector_indexer = VectorIndexer() \
+    .set_input_col('input') \
+    .set_output_col('output') \
+    .set_handle_invalid('keep') \
+    .set_max_categories(3)
+
+# Trains the VectorIndexer Model.
+model = vector_indexer.fit(train_table)
+
+# Uses the VectorIndexer Model for predictions.
+output = model.transform(predict_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    print('Input Value: ' + str(result[field_names.index(vector_indexer.get_input_col())])
+          + '\tOutput Value: ' + str(result[field_names.index(vector_indexer.get_output_col())]))
diff --git a/flink-ml-python/pyflink/ml/core/linalg.py b/flink-ml-python/pyflink/ml/core/linalg.py
index 196a8f4..71619a3 100644
--- a/flink-ml-python/pyflink/ml/core/linalg.py
+++ b/flink-ml-python/pyflink/ml/core/linalg.py
@@ -305,6 +305,9 @@ class DenseVector(Vector):
     def get(self, i: int):
         return self._values[i]
 
+    def set(self, i: int, value: np.float64):
+        self._values[i] = value
+
     def to_array(self) -> np.ndarray:
         return self._values
 
@@ -481,7 +484,29 @@ class SparseVector(Vector):
         return self._size
 
     def get(self, i: int):
-        return self._values[self._indices.searchsorted(i)]
+        idx = self._indices.searchsorted(i)
+        if idx < len(self._indices) and self._indices[idx] == i:
+            return self._values[idx]
+        else:
+            return 0.0
+
+    def set(self, i: int, value: np.float64):
+        """Sets the ith element, inserting a new entry when a non-zero value hits an unstored index."""
+        idx = self._indices.searchsorted(i)
+        if idx < len(self._indices) and self._indices[idx] == i:
+            self._values[idx] = value
+        elif value != 0:
+            assert i < self._size
+            # Grow both arrays by one slot and splice the new entry in at position idx.
+            indices = np.zeros(len(self._indices) + 1, dtype=np.int32)
+            values = np.zeros(len(self._indices) + 1, dtype=np.float64)
+            indices[0:idx] = self._indices[0:idx]
+            values[0:idx] = self._values[0:idx]
+            indices[idx] = i
+            values[idx] = value
+            indices[idx + 1:] = self._indices[idx:]
+            values[idx + 1:] = self._values[idx:]
+            self._indices, self._values = indices, values
 
     def to_array(self) -> np.ndarray:
         """
diff --git a/flink-ml-python/pyflink/ml/core/tests/test_linalg.py b/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
index 2095b4a..3ea015c 100644
--- a/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
+++ b/flink-ml-python/pyflink/ml/core/tests/test_linalg.py
@@ -77,3 +77,18 @@ class VectorTests(unittest.TestCase):
         self.assertFalse(v2 == v4)
         self.assertFalse(v1 == v5)
         self.assertFalse(v1 == v6)
+
+    def test_get_set(self):
+        v1 = DenseVector([0.0, 1.0, 0.0, 5.5])
+        self.assertEqual(0.0, v1.get(0))
+        v1.set(0, 1.0)
+        self.assertEqual(1.0, v1.get(0))
+
+        v2 = SparseVector(4, [(1, 1.0), (3, 5.5)])
+        self.assertEqual(0.0, v2.get(0))
+
+        v2.set(0, 1.0)
+        self.assertEqual(1.0, v2.get(0))
+
+        v2.set(1, 2.0)
+        self.assertEqual(2.0, v2.get(1))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
new file mode 100644
index 0000000..8e63427
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_vectorindexer.py
@@ -0,0 +1,103 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.vectorindexer import VectorIndexer, VectorIndexerModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class VectorIndexerTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(VectorIndexerTest, self).setUp()
+        self.train_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(1, 1),),
+                (Vectors.dense(2, -1),),
+                (Vectors.dense(3, 1),),
+                (Vectors.dense(4, 0),),
+                (Vectors.dense(5, 0),)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', ],
+                    [DenseVectorTypeInfo(), ])))
+
+        self.predict_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(0, 2),),
+                (Vectors.dense(0, 0),),
+                (Vectors.dense(0, -1),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', ],
+                    [DenseVectorTypeInfo(), ])))
+
+        self.expected_output = [
+            Vectors.dense(5, 3),
+            Vectors.dense(5, 0),
+            Vectors.dense(5, 1)]
+
+    def test_param(self):
+        vector_indexer = VectorIndexer()
+
+        self.assertEqual('input', vector_indexer.input_col)
+        self.assertEqual('output', vector_indexer.output_col)
+        self.assertEqual(20, vector_indexer.max_categories)
+        self.assertEqual('error', vector_indexer.handle_invalid)
+
+        vector_indexer.set_input_col('test_input') \
+            .set_output_col("test_output") \
+            .set_max_categories(3) \
+            .set_handle_invalid('skip')
+
+        self.assertEqual('test_input', vector_indexer.input_col)
+        self.assertEqual('test_output', vector_indexer.output_col)
+        self.assertEqual(3, vector_indexer.max_categories)
+        self.assertEqual('skip', vector_indexer.handle_invalid)
+
+    def test_output_schema(self):
+        vector_indexer = VectorIndexer()
+
+        output = vector_indexer.fit(self.train_table).transform(self.predict_table)[0]
+
+        self.assertEqual(
+            ['input', 'output'],
+            output.get_schema().get_field_names())
+
+    def test_save_load_predict(self):
+        vector_indexer = VectorIndexer().set_handle_invalid('keep')
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_vectorindexer')
+        vector_indexer.save(estimator_path)
+        vector_indexer = VectorIndexer.load(self.t_env, estimator_path)
+
+        model = vector_indexer.fit(self.train_table)
+        model_path = os.path.join(self.temp_dir, 'test_save_load_predict_vectorindexer_model')
+        model.save(model_path)
+        self.env.execute('save_model')
+        model = VectorIndexerModel.load(self.t_env, model_path)
+
+        output_table = model.transform(self.predict_table)[0]
+        predicted_results = [result[1] for result in
+                             self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+        predicted_results.sort(key=lambda x: x[1])
+        self.expected_output.sort(key=lambda x: x[1])
+        self.assertEqual(self.expected_output, predicted_results)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py b/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py
new file mode 100644
index 0000000..e397ecf
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/vectorindexer.py
@@ -0,0 +1,127 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol, HasHandleInvalid
+
+
+class _VectorIndexerModelParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol,
+    HasHandleInvalid
+):
+    """
+    Params for :class:`VectorIndexerModel`.
+    """
+
+    def __init__(self, java_params):
+        super(_VectorIndexerModelParams, self).__init__(java_params)
+
+
+class _VectorIndexerParams(_VectorIndexerModelParams):
+    """
+    Params for :class:`VectorIndexer`.
+    """
+
+    MAX_CATEGORIES: IntParam = IntParam(
+        "max_categories",
+        "Threshold for the number of values a categorical feature can take (>= 2). "
+        + "If a feature is found to have > maxCategories values, then it is declared continuous.",
+        20,
+        ParamValidators.gt_eq(2)
+    )
+
+    def __init__(self, java_params):
+        super(_VectorIndexerParams, self).__init__(java_params)
+
+    def set_max_categories(self, value: int):
+        return typing.cast(_VectorIndexerParams, self.set(self.MAX_CATEGORIES, value))
+
+    def get_max_categories(self) -> int:
+        return self.get(self.MAX_CATEGORIES)
+
+    @property
+    def max_categories(self):
+        return self.get_max_categories()
+
+
+class VectorIndexerModel(JavaFeatureModel, _VectorIndexerModelParams):
+    """
+    A Model which encodes an input vector to an output vector using the model data computed by
+    :class:`VectorIndexer`.
+
+    The `keep` option of :class:`HasHandleInvalid` means that we put the invalid entries in a
+    special bucket, whose index is the number of distinct values in this column.
+    """
+
+    def __init__(self, java_model=None):
+        super(VectorIndexerModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "vectorindexer"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "VectorIndexerModel"
+
+
+class VectorIndexer(JavaFeatureEstimator, _VectorIndexerParams):
+    """
+    An Estimator which implements the vector indexing algorithm.
+
+    A vector indexer maps each column of the input vector into a continuous/categorical
+    feature. Whether one feature is transformed into a continuous or categorical feature
+    depends on the number of distinct values in this column. If the number of distinct
+    values in one column is greater than a specified parameter (i.e., maxCategories),
+    the corresponding output column is unchanged. Otherwise, it is transformed into
+    a categorical value. For categorical outputs, the indices are
+    in [0, numDistinctValuesInThisColumn].
+
+    The output model is organized in ascending order except that 0.0 is always mapped
+    to 0 (for sparsity): if 0.0 is present it takes index 0 and the remaining values
+    are numbered from 1 in ascending order. We list two examples here:
+
+    - If one column contains {-1.0, 1.0}, then -1.0 should be encoded as 0
+      and 1.0 should be encoded as 1.
+
+    - If one column contains {-1.0, 0.0, 1.0}, then -1.0 should be encoded as 1,
+      0.0 should be encoded as 0 and 1.0 should be encoded as 2.
+
+    The `keep` option of :class:`HasHandleInvalid` means that we put the invalid entries
+    in a special bucket, whose index is the number of distinct values in this column.
+    """
+
+    def __init__(self):
+        super(VectorIndexer, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> VectorIndexerModel:
+        return VectorIndexerModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "vectorindexer"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "VectorIndexer"