You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/11/23 09:25:27 UTC

[flink-ml] branch master updated: [FLINK-29604] Add Estimator and Transformer for CountVectorizer

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 23d7646  [FLINK-29604] Add Estimator and Transformer for CountVectorizer
23d7646 is described below

commit 23d76467f9d0c089206019afde2d8f87409ec526
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Wed Nov 23 17:25:22 2022 +0800

    [FLINK-29604] Add Estimator and Transformer for CountVectorizer
    
    This closes #174.
---
 .../docs/operators/feature/countvectorizer.md      | 182 ++++++++++
 .../examples/feature/CountVectorizerExample.java   |  71 ++++
 .../feature/countvectorizer/CountVectorizer.java   | 218 ++++++++++++
 .../countvectorizer/CountVectorizerModel.java      | 188 ++++++++++
 .../countvectorizer/CountVectorizerModelData.java  | 110 ++++++
 .../CountVectorizerModelParams.java                |  67 ++++
 .../countvectorizer/CountVectorizerParams.java     |  86 +++++
 .../flink/ml/feature/CountVectorizerTest.java      | 390 +++++++++++++++++++++
 .../examples/ml/feature/countvectorizer_example.py |  62 ++++
 .../pyflink/ml/lib/feature/countvectorizer.py      | 189 ++++++++++
 .../ml/lib/feature/tests/test_countvectorizer.py   | 144 ++++++++
 11 files changed, 1707 insertions(+)

diff --git a/docs/content/docs/operators/feature/countvectorizer.md b/docs/content/docs/operators/feature/countvectorizer.md
new file mode 100644
index 0000000..aef0b7c
--- /dev/null
+++ b/docs/content/docs/operators/feature/countvectorizer.md
@@ -0,0 +1,182 @@
+---
+title: "Count Vectorizer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/countvectorizer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+  http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Count Vectorizer
+
+CountVectorizer is an algorithm that converts a collection of text
+documents to vectors of token counts. When an a-priori dictionary is not
+available, CountVectorizer can be used as an estimator to extract the
+vocabulary and generate a CountVectorizerModel. The model produces sparse
+representations for the documents over the vocabulary, which can then be
+passed to other algorithms like LDA.
+
+### Input Columns
+
+| Param name | Type     | Default   | Description         |
+|:-----------|:---------|:----------|:--------------------|
+| inputCol   | String[] | `"input"` | Input string array. |
+
+### Output Columns
+
+| Param name | Type         | Default    | Description             |
+|:-----------|:-------------|:-----------|:------------------------|
+| outputCol  | SparseVector | `"output"` | Vector of token counts. |
+
+### Parameters
+
+Below are the parameters required by `CountVectorizerModel`.
+
+| Key        | Default    | Type    | Required | Description                                                                                                                                                                                                                                                                                                                                     |
+|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| inputCol   | `"input"`  | String  | no       | Input column name.                                                                                                                                                                                                                                                                                                                              |
+| outputCol  | `"output"` | String  | no       | Output column name.                                                                                                                                                                                                                                                                                                                             |
+| minTF      | `1.0`      | Double  | no       | Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count).  |
+| binary     | `false`    | Boolean | no       | Binary toggle to control the output vector values. If True, all nonzero counts (after minTF filter applied) are set to 1.0.                                                                                                                                                                                                                     |
+
+`CountVectorizer` needs the parameters above and also the parameters below.
+
+| Key            | Default    | Type     | Required | Description                                                                                                                                                                                                                                                                                                                                                                                  |
+|:---------------|:-----------|:---------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| vocabularySize | `2^18`     | Integer  | no       | Max size of the vocabulary. CountVectorizer will build a vocabulary that only considers the top vocabulary size terms ordered by term frequency across the corpus.                                                                                                                                                                                                                           |
+| minDF          | `1.0`      | Double   | no       | Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents.                                                                                                         |
+| maxDF          | `2^63 - 1` | Double   | no       | Specifies the maximum number of different documents a term could appear in to be included in the vocabulary. A term that appears more than the threshold will be ignored. If this is an integer >= 1, this specifies the maximum number of documents the term could appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term could appear in. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/**
+ * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering.
+ */
+public class CountVectorizerExample {
+
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data. The (Object) cast keeps each String[]
+        // as a single field instead of being spread over Row.of's varargs.
+        DataStream<Row> dataStream =
+                env.fromElements(
+                        Row.of((Object) new String[] {"a", "c", "b", "c"}),
+                        Row.of((Object) new String[] {"c", "d", "e"}),
+                        Row.of((Object) new String[] {"a", "b", "c"}),
+                        Row.of((Object) new String[] {"e", "f"}),
+                        Row.of((Object) new String[] {"a", "c", "a"}));
+        Table inputTable = tEnv.fromDataStream(dataStream).as("input");
+
+        // Creates a CountVectorizer object with default parameters.
+        CountVectorizer countVectorizer = new CountVectorizer();
+
+        // Trains the CountVectorizer model.
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+
+        // Uses the CountVectorizer model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
+            SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+            System.out.printf(
+                    "Input Value: %-15s \tOutput Value: %s%n",
+                    Arrays.toString(inputValue), outputValue);
+        }
+    }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+
+# Simple program that creates an CountVectorizer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, ['a', 'c', 'b', 'c'],),
+        (2, ['c', 'd', 'e'],),
+        (3, ['a', 'b', 'c'],),
+        (4, ['e', 'f'],),
+        (5, ['a', 'c', 'a'],),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['id', 'input', ],
+            [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])
+    ))
+
+# Creates an CountVectorizer object and initializes its parameters.
+count_vectorizer = CountVectorizer()
+
+# Trains the CountVectorizer Model.
+model = count_vectorizer.fit(input_table)
+
+# Uses the CountVectorizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_index = field_names.index(count_vectorizer.get_input_col())
+    output_index = field_names.index(count_vectorizer.get_output_col())
+    print('Input Value: %-20s Output Value: %10s' %
+          (str(result[input_index]), str(result[output_index])))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
new file mode 100644
index 0000000..fb1287c
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/**
+ * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering.
+ */
+public class CountVectorizerExample {
+
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input training and prediction data. The (Object) cast keeps each String[]
+        // as a single field instead of being spread over Row.of's varargs.
+        DataStream<Row> dataStream =
+                env.fromElements(
+                        Row.of((Object) new String[] {"a", "c", "b", "c"}),
+                        Row.of((Object) new String[] {"c", "d", "e"}),
+                        Row.of((Object) new String[] {"a", "b", "c"}),
+                        Row.of((Object) new String[] {"e", "f"}),
+                        Row.of((Object) new String[] {"a", "c", "a"}));
+        Table inputTable = tEnv.fromDataStream(dataStream).as("input");
+
+        // Creates a CountVectorizer object with default parameters.
+        CountVectorizer countVectorizer = new CountVectorizer();
+
+        // Trains the CountVectorizer model.
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+
+        // Uses the CountVectorizer model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
+            SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+            System.out.printf(
+                    "Input Value: %-15s \tOutput Value: %s%n",
+                    Arrays.toString(inputValue), outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
new file mode 100644
index 0000000..76ccc13
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * An Estimator which converts a collection of text documents to vectors of token counts. When an
+ * a-priori dictionary is not available, CountVectorizer can be used as an estimator to extract the
+ * vocabulary and generate a {@link CountVectorizerModel}. The model produces sparse
+ * representations for the documents over the vocabulary, which can then be passed to other
+ * algorithms like LDA.
+ *
+ * <p>The extracted vocabulary keeps at most {@code vocabularySize} terms, ordered by descending
+ * total term frequency across the corpus, after dropping terms whose document frequency falls
+ * outside [minDF, maxDF].
+ */
+public class CountVectorizer
+        implements Estimator<CountVectorizer, CountVectorizerModel>,
+                CountVectorizerParams<CountVectorizer> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public CountVectorizer() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public CountVectorizerModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        double minDF = getMinDF();
+        double maxDF = getMaxDF();
+        // Absolute (>= 1.0) and fractional (< 1.0) thresholds are only comparable up front when
+        // both are of the same kind; mixed kinds are validated in the aggregator once the number
+        // of documents is known.
+        if (minDF >= 1.0 && maxDF >= 1.0 || minDF < 1.0 && maxDF < 1.0) {
+            Preconditions.checkArgument(maxDF >= minDF, "maxDF must be >= minDF.");
+        }
+
+        String inputCol = getInputCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<String[]> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, String[]>)
+                                        value -> ((String[]) value.getField(inputCol)));
+
+        DataStream<CountVectorizerModelData> modelData =
+                DataStreamUtils.aggregate(
+                        inputData, new VocabularyAggregator(minDF, maxDF, getVocabularySize()));
+
+        CountVectorizerModel model =
+                new CountVectorizerModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    /**
+     * Extracts a vocabulary from input document collections and builds the {@link
+     * CountVectorizerModelData}.
+     *
+     * <p>The accumulator's f0 is the number of documents seen so far; f1 maps each term to a
+     * tuple of (total occurrences across all documents, number of documents containing it).
+     */
+    private static class VocabularyAggregator
+            implements AggregateFunction<
+                    String[],
+                    Tuple2<Long, Map<String, Tuple2<Long, Long>>>,
+                    CountVectorizerModelData> {
+        private final double minDF;
+        private final double maxDF;
+        private final int vocabularySize;
+
+        public VocabularyAggregator(double minDF, double maxDF, int vocabularySize) {
+            this.minDF = minDF;
+            this.maxDF = maxDF;
+            this.vocabularySize = vocabularySize;
+        }
+
+        @Override
+        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> createAccumulator() {
+            return Tuple2.of(0L, new HashMap<>());
+        }
+
+        @Override
+        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> add(
+                String[] terms, Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
+            // Per-document term counts, so that each term adds 1 to its document frequency
+            // regardless of how often it occurs in this document.
+            Map<String, Long> wc = new HashMap<>();
+            Arrays.stream(terms).forEach(term -> wc.merge(term, 1L, Long::sum));
+
+            Map<String, Tuple2<Long, Long>> counts = vocabAccumulator.f1;
+            wc.forEach(
+                    (w, c) -> {
+                        Tuple2<Long, Long> existing = counts.get(w);
+                        if (existing == null) {
+                            counts.put(w, Tuple2.of(c, 1L));
+                        } else {
+                            existing.f0 += c;
+                            existing.f1 += 1;
+                        }
+                    });
+            vocabAccumulator.f0 += 1;
+
+            return vocabAccumulator;
+        }
+
+        @Override
+        public CountVectorizerModelData getResult(
+                Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
+            Preconditions.checkState(vocabAccumulator.f0 > 0, "The training set is empty.");
+
+            Map<String, Tuple2<Long, Long>> termStats = vocabAccumulator.f1;
+            // With both thresholds at their defaults every observed term passes the DF filter,
+            // so the filtering pass is skipped.
+            boolean filteringRequired =
+                    !MIN_DF.defaultValue.equals(minDF) || !MAX_DF.defaultValue.equals(maxDF);
+            if (filteringRequired) {
+                long rowNum = vocabAccumulator.f0;
+                // Thresholds below 1.0 are fractions of the number of documents.
+                double actualMinDF = minDF >= 1.0 ? minDF : minDF * rowNum;
+                double actualMaxDF = maxDF >= 1.0 ? maxDF : maxDF * rowNum;
+                Preconditions.checkState(actualMaxDF >= actualMinDF, "maxDF must be >= minDF.");
+
+                termStats =
+                        termStats.entrySet().stream()
+                                .filter(
+                                        x ->
+                                                x.getValue().f1 >= actualMinDF
+                                                        && x.getValue().f1 <= actualMaxDF)
+                                .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+            }
+
+            List<Map.Entry<String, Tuple2<Long, Long>>> list =
+                    new ArrayList<>(termStats.entrySet());
+            // Ranks terms by descending total term frequency across the corpus (f0), which is
+            // the ordering vocabularySize is documented to use.
+            list.sort((o1, o2) -> o2.getValue().f0.compareTo(o1.getValue().f0));
+            String[] topTerms =
+                    list.stream()
+                            .limit(vocabularySize)
+                            .map(Map.Entry::getKey)
+                            .toArray(String[]::new);
+            return new CountVectorizerModelData(topTerms);
+        }
+
+        @Override
+        public Tuple2<Long, Map<String, Tuple2<Long, Long>>> merge(
+                Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc1,
+                Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc2) {
+            if (acc1.f0 == 0) {
+                return acc2;
+            }
+
+            if (acc2.f0 == 0) {
+                return acc1;
+            }
+            acc2.f0 += acc1.f0;
+            // Sums both the term counts and the document frequencies of shared terms.
+            acc1.f1.forEach(
+                    (term, counts) ->
+                            acc2.f1.merge(
+                                    term,
+                                    counts,
+                                    (c2, c1) -> Tuple2.of(c2.f0 + c1.f0, c2.f1 + c1.f1)));
+            return acc2;
+        }
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static CountVectorizer load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
new file mode 100644
index 0000000..390d997
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/** A Model which transforms data using the model data computed by {@link CountVectorizer}. */
+public class CountVectorizerModel
+        implements Model<CountVectorizerModel>, CountVectorizerModelParams<CountVectorizerModel> {
+
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public CountVectorizerModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public CountVectorizerModel setModelData(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                CountVectorizerModelData.getModelDataStream(modelDataTable),
+                path,
+                new CountVectorizerModelData.ModelDataEncoder());
+    }
+
+    public static CountVectorizerModel load(StreamTableEnvironment tEnv, String path)
+            throws IOException {
+        CountVectorizerModel model = ReadWriteUtils.loadStageParam(path);
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(
+                        tEnv, path, new CountVectorizerModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+        DataStream<CountVectorizerModelData> modelDataStream =
+                CountVectorizerModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+        // Output schema = all input columns followed by the sparse count vector column.
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(dataStream),
+                        Collections.singletonMap(broadcastModelKey, modelDataStream),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new PredictOutputFunction(
+                                            getInputCol(),
+                                            broadcastModelKey,
+                                            getMinTF(),
+                                            getBinary()),
+                                    outputTypeInfo);
+                        });
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Appends to each input row a {@link SparseVector} holding the counts of the row's tokens over
+     * the broadcast vocabulary, after applying the minTF filter and the optional binary toggle.
+     */
+    private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+        private final String inputCol;
+        private final String broadcastKey;
+        private final double minTF;
+        private final boolean binary;
+        // Term -> index in the output vector; built lazily from the broadcast model data.
+        private Map<String, Integer> vocabulary;
+
+        public PredictOutputFunction(
+                String inputCol, String broadcastKey, double minTF, boolean binary) {
+            this.inputCol = inputCol;
+            this.broadcastKey = broadcastKey;
+            this.minTF = minTF;
+            this.binary = binary;
+        }
+
+        @Override
+        public Row map(Row row) throws Exception {
+            if (vocabulary == null) {
+                CountVectorizerModelData modelData =
+                        (CountVectorizerModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                vocabulary = new HashMap<>();
+                IntStream.range(0, modelData.vocabulary.length)
+                        .forEach(i -> vocabulary.put(modelData.vocabulary[i], i));
+            }
+
+            String[] document = (String[]) row.getField(inputCol);
+            double[] termCounts = new double[vocabulary.size()];
+            for (String word : document) {
+                Integer index = vocabulary.get(word);
+                if (index != null) {
+                    termCounts[index] += 1;
+                }
+            }
+
+            // A minTF below 1.0 is a fraction of the document's token count (out-of-vocabulary
+            // tokens included); otherwise it is an absolute count.
+            double actualMinTF = minTF >= 1.0 ? minTF : document.length * minTF;
+            List<Integer> indices = new ArrayList<>();
+            List<Double> values = new ArrayList<>();
+            for (int i = 0; i < termCounts.length; i++) {
+                // Terms absent from the document are never emitted. Without the > 0 check, a
+                // minTF that resolves to 0 (e.g. minTF = 0.0) would emit every vocabulary term,
+                // with value 1.0 when binary is enabled.
+                if (termCounts[i] > 0 && termCounts[i] >= actualMinTF) {
+                    indices.add(i);
+                    values.add(binary ? 1.0 : termCounts[i]);
+                }
+            }
+
+            SparseVector outputVec =
+                    Vectors.sparse(
+                            termCounts.length,
+                            indices.stream().mapToInt(i -> i).toArray(),
+                            values.stream().mapToDouble(i -> i).toArray());
+            return Row.join(row, Row.of(outputVec));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java
new file mode 100644
index 0000000..6ac41a5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.array.StringArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link CountVectorizerModel}, consisting of the vocabulary terms.
+ *
+ * <p>Besides holding the vocabulary, this class offers a helper that turns a model data {@link
+ * Table} into a {@link DataStream}, as well as the encoder/decoder pair used to save and load the
+ * model data.
+ */
+public class CountVectorizerModelData {
+
+    /**
+     * The vocabulary terms. The position of a term in this array is the index it is mapped to;
+     * terms that are not in this array are not counted.
+     */
+    public String[] vocabulary;
+
+    /** Creates an instance whose {@link #vocabulary} is left {@code null}. */
+    public CountVectorizerModelData() {}
+
+    /**
+     * Creates model data backed by the given vocabulary.
+     *
+     * @param vocabulary The vocabulary terms; the array is stored as-is, not copied.
+     */
+    public CountVectorizerModelData(String[] vocabulary) {
+        this.vocabulary = vocabulary;
+    }
+
+    /**
+     * Converts a model data table into a stream of {@link CountVectorizerModelData}.
+     *
+     * <p>The vocabulary of each row is read from the row's first field as a {@code String[]}.
+     *
+     * @param modelDataTable The table holding the model data.
+     * @return A data stream with one {@link CountVectorizerModelData} per table row.
+     */
+    public static DataStream<CountVectorizerModelData> getModelDataStream(Table modelDataTable) {
+        TableImpl tableImpl = (TableImpl) modelDataTable;
+        StreamTableEnvironment tEnv = (StreamTableEnvironment) tableImpl.getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(row -> new CountVectorizerModelData((String[]) row.getField(0)));
+    }
+
+    /** Encoder that writes the vocabulary of {@link CountVectorizerModelData} as a string array. */
+    public static class ModelDataEncoder implements Encoder<CountVectorizerModelData> {
+        @Override
+        public void encode(CountVectorizerModelData modelData, OutputStream outputStream)
+                throws IOException {
+            DataOutputView outputView = new DataOutputViewStreamWrapper(outputStream);
+            String[] terms = modelData.vocabulary;
+            StringArraySerializer.INSTANCE.serialize(terms, outputView);
+        }
+    }
+
+    /**
+     * Decoder that reads {@link CountVectorizerModelData} records written by {@link
+     * ModelDataEncoder}. The reader returns {@code null} once the input is exhausted.
+     */
+    public static class ModelDataDecoder extends SimpleStreamFormat<CountVectorizerModelData> {
+
+        @Override
+        public Reader<CountVectorizerModelData> createReader(
+                Configuration configuration, FSDataInputStream fsDataInputStream) {
+            return new Reader<CountVectorizerModelData>() {
+                // View over the raw stream that the string array serializer reads from.
+                private final DataInputView inputView =
+                        new DataInputViewStreamWrapper(fsDataInputStream);
+
+                @Override
+                public CountVectorizerModelData read() throws IOException {
+                    String[] terms;
+                    try {
+                        terms = StringArraySerializer.INSTANCE.deserialize(inputView);
+                    } catch (EOFException endOfInput) {
+                        // End of stream: there is no more model data to read.
+                        return null;
+                    }
+                    return new CountVectorizerModelData(terms);
+                }
+
+                @Override
+                public void close() throws IOException {
+                    fsDataInputStream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<CountVectorizerModelData> getProducedType() {
+            return TypeInformation.of(CountVectorizerModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java
new file mode 100644
index 0000000..e725956
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link CountVectorizerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface CountVectorizerModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {
+    /**
+     * Per-document term frequency threshold. A value >= 1 is an absolute count; a value in [0, 1)
+     * is a fraction of the document's token count. Defaults to 1.0.
+     */
+    Param<Double> MIN_TF =
+            new DoubleParam(
+                    "minTF",
+                    "Filter to ignore rare words in a document. For each document, "
+                            + "terms with frequency/count less than the given threshold are ignored. "
+                            + "If this is an integer >= 1, then this specifies a count (of times "
+                            + "the term must appear in the document); if this is a double in [0,1), "
+                            + "then this specifies a fraction (out of the document's token count).",
+                    1.0,
+                    ParamValidators.gtEq(0.0));
+
+    /**
+     * Whether to output binary values: if true, every nonzero count remaining after the minTF
+     * filter is set to 1.0. Defaults to false.
+     */
+    Param<Boolean> BINARY =
+            new BooleanParam(
+                    "binary",
+                    "Binary toggle to control the output vector values. If True, all nonzero "
+                            + "counts (after minTF filter applied) are set to 1.0.",
+                    false);
+
+    /** Returns the value of {@link #MIN_TF}. */
+    default double getMinTF() {
+        return get(MIN_TF);
+    }
+
+    /** Sets the value of {@link #MIN_TF}. */
+    default T setMinTF(double value) {
+        return set(MIN_TF, value);
+    }
+
+    /** Returns the value of {@link #BINARY}. */
+    default boolean getBinary() {
+        return get(BINARY);
+    }
+
+    /** Sets the value of {@link #BINARY}. */
+    default T setBinary(boolean value) {
+        return set(BINARY, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java
new file mode 100644
index 0000000..992a687
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link CountVectorizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface CountVectorizerParams<T> extends CountVectorizerModelParams<T> {
+    /** Maximum number of terms kept in the vocabulary. Must be > 0; defaults to 2^18. */
+    Param<Integer> VOCABULARY_SIZE =
+            new IntParam(
+                    "vocabularySize",
+                    "Max size of the vocabulary. CountVectorizer will build a vocabulary "
+                            + "that only considers the top vocabulary size terms ordered by term "
+                            + "frequency across the corpus.",
+                    1 << 18,
+                    ParamValidators.gt(0));
+
+    /**
+     * Minimum number (value >= 1) or fraction (value in [0, 1)) of documents a term must appear in
+     * to be included in the vocabulary. Defaults to 1.0.
+     */
+    Param<Double> MIN_DF =
+            new DoubleParam(
+                    "minDF",
+                    "Specifies the minimum number of different documents a term must "
+                            + "appear in to be included in the vocabulary. If this is an "
+                            + "integer >= 1, this specifies the number of documents the term "
+                            + "must appear in; if this is a double in [0,1), then this "
+                            + "specifies the fraction of documents.",
+                    1.0,
+                    ParamValidators.gtEq(0.0));
+
+    /**
+     * Maximum number (value >= 1) or fraction (value in [0, 1)) of documents a term may appear in
+     * to be included in the vocabulary. Defaults to {@code (double) Long.MAX_VALUE}.
+     */
+    Param<Double> MAX_DF =
+            new DoubleParam(
+                    "maxDF",
+                    "Specifies the maximum number of different documents a term could appear "
+                            + "in to be included in the vocabulary. A term that appears more than "
+                            + "the threshold will be ignored. If this is an integer >= 1, this "
+                            + "specifies the maximum number of documents the term could appear in; "
+                            + "if this is a double in [0,1), then this specifies the maximum "
+                            + "fraction of documents the term could appear in.",
+                    (double) Long.MAX_VALUE,
+                    ParamValidators.gtEq(0.0));
+
+    /** Returns the value of {@link #VOCABULARY_SIZE}. */
+    default int getVocabularySize() {
+        return get(VOCABULARY_SIZE);
+    }
+
+    /** Sets the value of {@link #VOCABULARY_SIZE}. */
+    default T setVocabularySize(int value) {
+        return set(VOCABULARY_SIZE, value);
+    }
+
+    /** Returns the value of {@link #MIN_DF}. */
+    default double getMinDF() {
+        return get(MIN_DF);
+    }
+
+    /** Sets the value of {@link #MIN_DF}. */
+    default T setMinDF(double value) {
+        return set(MIN_DF, value);
+    }
+
+    /** Returns the value of {@link #MAX_DF}. */
+    default double getMaxDF() {
+        return get(MAX_DF);
+    }
+
+    /** Sets the value of {@link #MAX_DF}. */
+    default T setMaxDF(double value) {
+        return set(MAX_DF, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
new file mode 100644
index 0000000..e81c889
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link CountVectorizer} and {@link CountVectorizerModel}. */
+public class CountVectorizerTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    // Single-column ("input") table built from INPUT_DATA in before().
+    private Table inputTable;
+
+    private static final double EPS = 1.0e-5;
+    // Five tokenized documents. The (Object) cast makes each String[] one field of the Row
+    // instead of being spread into Row.of's varargs.
+    private static final List<Row> INPUT_DATA =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Row.of((Object) new String[] {"a", "c", "b", "c"}),
+                            Row.of((Object) new String[] {"c", "d", "e"}),
+                            Row.of((Object) new String[] {"a", "b", "c"}),
+                            Row.of((Object) new String[] {"e", "f"}),
+                            Row.of((Object) new String[] {"a", "c", "a"})));
+
+    // Expected term-count vectors for INPUT_DATA with default params. Indices follow the
+    // vocabulary order asserted in testGetModelData: c=0, a=1, b=2, e=3, d=4, f=5.
+    private static final List<SparseVector> EXPECTED_OUTPUT =
+            new ArrayList<>(
+                    Arrays.asList(
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1, 2).toArray(),
+                                    DoubleStream.of(2.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 3, 4).toArray(),
+                                    DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1, 2).toArray(),
+                                    DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(3, 5).toArray(),
+                                    DoubleStream.of(1.0, 1.0).toArray()),
+                            Vectors.sparse(
+                                    6,
+                                    IntStream.of(0, 1).toArray(),
+                                    DoubleStream.of(1.0, 2.0).toArray())));
+
+    /**
+     * Sets up a checkpointed environment with parallelism 4 and restarts disabled, so that a
+     * failing job surfaces its exception instead of being retried.
+     */
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        env.getConfig().enableObjectReuse();
+        tEnv = StreamTableEnvironment.create(env);
+
+        inputTable = tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("input");
+    }
+
+    /**
+     * Collects the vectors in {@code outputCol} of {@code output} and compares them with {@code
+     * expected}, using {@code compareResultCollections} with {@code TestUtils::compare} as the
+     * comparator (order-insensitive presumably; verify against AbstractTestBase).
+     */
+    private static void verifyPredictionResult(
+            Table output, String outputCol, List<SparseVector> expected) throws Exception {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+        DataStream<SparseVector> stream =
+                tEnv.toDataStream(output)
+                        .map(
+                                (MapFunction<Row, SparseVector>)
+                                        row -> (SparseVector) row.getField(outputCol));
+        List<SparseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+        compareResultCollections(expected, result, TestUtils::compare);
+    }
+
+    /** Checks default param values and that every setter is reflected by its getter. */
+    @Test
+    public void testParam() {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        assertEquals("input", countVectorizer.getInputCol());
+        assertEquals("output", countVectorizer.getOutputCol());
+        assertEquals((double) Long.MAX_VALUE, countVectorizer.getMaxDF(), EPS);
+        assertEquals(1.0, countVectorizer.getMinDF(), EPS);
+        assertEquals(1.0, countVectorizer.getMinTF(), EPS);
+        assertEquals(1 << 18, countVectorizer.getVocabularySize());
+        assertFalse(countVectorizer.getBinary());
+
+        countVectorizer
+                .setInputCol("test_input")
+                .setOutputCol("test_output")
+                .setMinDF(0.1)
+                .setMaxDF(0.9)
+                .setMinTF(10)
+                .setVocabularySize(1000)
+                .setBinary(true);
+        assertEquals("test_input", countVectorizer.getInputCol());
+        assertEquals("test_output", countVectorizer.getOutputCol());
+        assertEquals(0.9, countVectorizer.getMaxDF(), EPS);
+        assertEquals(0.1, countVectorizer.getMinDF(), EPS);
+        assertEquals(10, countVectorizer.getMinTF(), EPS);
+        assertEquals(1000, countVectorizer.getVocabularySize());
+        assertTrue(countVectorizer.getBinary());
+    }
+
+    /**
+     * Checks that maxDF < minDF is rejected. When both values are fractions or both are counts,
+     * the error is expected directly from fit(). When one is a count and the other a fraction,
+     * it is expected only as the root cause of the failed job (presumably because the document
+     * count is needed to compare them; confirm in CountVectorizer).
+     *
+     * <p>NOTE(review): fail() throws an AssertionError that the surrounding catch (Throwable)
+     * also catches. An unexpected success still fails the test, but through the message
+     * assertion rather than fail() itself.
+     */
+    @Test
+    public void testInvalidMinMaxDF() {
+        String errMessage = "maxDF must be >= minDF.";
+        CountVectorizer countVectorizer = new CountVectorizer();
+        // Both fractions: rejected in fit().
+        countVectorizer.setMaxDF(0.1);
+        countVectorizer.setMinDF(0.2);
+        try {
+            countVectorizer.fit(inputTable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, e.getMessage());
+        }
+        // Both counts: rejected in fit().
+        countVectorizer.setMaxDF(1);
+        countVectorizer.setMinDF(2);
+        try {
+            countVectorizer.fit(inputTable);
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, e.getMessage());
+        }
+        // Count vs. fraction: the error surfaces only when the job runs.
+        countVectorizer.setMaxDF(1);
+        countVectorizer.setMinDF(0.9);
+        try {
+            CountVectorizerModel model = countVectorizer.fit(inputTable);
+            Table output = model.transform(inputTable)[0];
+            output.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage());
+        }
+        // Fraction vs. count: the error surfaces only when the job runs.
+        countVectorizer.setMaxDF(0.1);
+        countVectorizer.setMinDF(10);
+        try {
+            CountVectorizerModel model = countVectorizer.fit(inputTable);
+            Table output = model.transform(inputTable)[0];
+            output.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    /** Checks that the output table is the input columns followed by the output column. */
+    @Test
+    public void testOutputSchema() {
+        CountVectorizer countVectorizer =
+                new CountVectorizer().setInputCol("test_input").setOutputCol("test_output");
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable.as("test_input"))[0];
+        assertEquals(
+                Arrays.asList("test_input", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    /** Fits with default params and checks the transformed vectors. */
+    @Test
+    public void testFitAndPredict() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable)[0];
+
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+    }
+
+    /** Round-trips both the estimator and the fitted model through save/load. */
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizer loadedCountVectorizer =
+                TestUtils.saveAndReload(
+                        tEnv, countVectorizer, tempFolder.newFolder().getAbsolutePath());
+        CountVectorizerModel model = loadedCountVectorizer.fit(inputTable);
+        CountVectorizerModel loadedModel =
+                TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+        assertEquals(
+                Arrays.asList("vocabulary"),
+                loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
+        Table output = loadedModel.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+    }
+
+    /** Checks that computing model data from an empty training set fails with a clear message. */
+    @Test
+    public void testFitOnEmptyData() {
+        // Every Row in INPUT_DATA has arity 1, so this filter yields an empty stream.
+        Table emptyTable =
+                tEnv.fromDataStream(env.fromCollection(INPUT_DATA).filter(x -> x.getArity() == 0))
+                        .as("input");
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(emptyTable);
+        Table modelDataTable = model.getModelData()[0];
+        try {
+            modelDataTable.execute().print();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+
+    /**
+     * Checks minDF/maxDF filtering, first with absolute counts (2..4) and then with the equivalent
+     * fractions of the five documents (0.4..0.8). Both should yield the same 4-term vocabulary.
+     */
+    @Test
+    public void testMinMaxDF() throws Exception {
+        List<SparseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.sparse(
+                                        4,
+                                        IntStream.of(0, 1, 2).toArray(),
+                                        DoubleStream.of(2.0, 1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        4,
+                                        IntStream.of(0, 3).toArray(),
+                                        DoubleStream.of(1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        4,
+                                        IntStream.of(0, 1, 2).toArray(),
+                                        DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        4,
+                                        IntStream.of(3).toArray(),
+                                        DoubleStream.of(1.0).toArray()),
+                                Vectors.sparse(
+                                        4,
+                                        IntStream.of(0, 1).toArray(),
+                                        DoubleStream.of(1.0, 2.0).toArray())));
+        CountVectorizer countVectorizer = new CountVectorizer().setMinDF(2).setMaxDF(4);
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+
+        countVectorizer.setMinDF(0.4).setMaxDF(0.8);
+        model = countVectorizer.fit(inputTable);
+        output = model.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /**
+     * Checks minTF as a fraction. With 0.5, a term is kept only if its count is at least half of
+     * its document's token count.
+     */
+    @Test
+    public void testMinTF() throws Exception {
+        List<SparseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(0).toArray(),
+                                        DoubleStream.of(2.0).toArray()),
+                                Vectors.sparse(6, new int[0], new double[0]),
+                                Vectors.sparse(6, new int[0], new double[0]),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(3, 5).toArray(),
+                                        DoubleStream.of(1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(1).toArray(),
+                                        DoubleStream.of(2.0).toArray())));
+        CountVectorizer countVectorizer = new CountVectorizer().setMinTF(0.5);
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /** Checks that binary mode replaces every nonzero count with 1.0. */
+    @Test
+    public void testBinary() throws Exception {
+        List<SparseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(0, 1, 2).toArray(),
+                                        DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(0, 3, 4).toArray(),
+                                        DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(0, 1, 2).toArray(),
+                                        DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(3, 5).toArray(),
+                                        DoubleStream.of(1.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        6,
+                                        IntStream.of(0, 1).toArray(),
+                                        DoubleStream.of(1.0, 1.0).toArray())));
+        CountVectorizer countVectorizer = new CountVectorizer().setBinary(true);
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /** Checks that vocabularySize = 2 keeps only the two most frequent terms ("c" and "a"). */
+    @Test
+    public void testVocabularySize() throws Exception {
+        List<SparseVector> expectedOutput =
+                new ArrayList<>(
+                        Arrays.asList(
+                                Vectors.sparse(
+                                        2,
+                                        IntStream.of(0, 1).toArray(),
+                                        DoubleStream.of(2.0, 1.0).toArray()),
+                                Vectors.sparse(
+                                        2,
+                                        IntStream.of(0).toArray(),
+                                        DoubleStream.of(1.0).toArray()),
+                                Vectors.sparse(
+                                        2,
+                                        IntStream.of(0, 1).toArray(),
+                                        DoubleStream.of(1.0, 1.0).toArray()),
+                                Vectors.sparse(2, new int[0], new double[0]),
+                                Vectors.sparse(
+                                        2,
+                                        IntStream.of(0, 1).toArray(),
+                                        DoubleStream.of(1.0, 2.0).toArray())));
+        CountVectorizer countVectorizer = new CountVectorizer().setVocabularySize(2);
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table output = model.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+    }
+
+    /** Checks the model data schema and the learned vocabulary order. */
+    @Test
+    public void testGetModelData() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel model = countVectorizer.fit(inputTable);
+        Table modelData = model.getModelData()[0];
+        assertEquals(Arrays.asList("vocabulary"), modelData.getResolvedSchema().getColumnNames());
+
+        DataStream<Row> output = tEnv.toDataStream(modelData);
+        List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+        String[] vocabulary = (String[]) modelRows.get(0).getField(0);
+        String[] expectedVocabulary = {"c", "a", "b", "e", "d", "f"};
+        assertArrayEquals(expectedVocabulary, vocabulary);
+    }
+
+    /** Checks that a model built from another model's data produces the same output. */
+    @Test
+    public void testSetModelData() throws Exception {
+        CountVectorizer countVectorizer = new CountVectorizer();
+        CountVectorizerModel modelA = countVectorizer.fit(inputTable);
+        Table modelData = modelA.getModelData()[0];
+        CountVectorizerModel modelB = new CountVectorizerModel().setModelData(modelData);
+        Table output = modelB.transform(inputTable)[0];
+        verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py b/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py
new file mode 100644
index 0000000..347d40d
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py
@@ -0,0 +1,62 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a CountVectorizer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (1, ['a', 'c', 'b', 'c'],),
+        (2, ['c', 'd', 'e'],),
+        (3, ['a', 'b', 'c'],),
+        (4, ['e', 'f'],),
+        (5, ['a', 'c', 'a'],),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['id', 'input', ],
+            [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])
+    ))
+
+# Creates a CountVectorizer object; all of its parameters keep their default values.
+count_vectorizer = CountVectorizer()
+
+# Trains the CountVectorizer Model.
+model = count_vectorizer.fit(input_table)
+
+# Uses the CountVectorizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results. Column positions are the same for every row,
+# so they are looked up once; the `with` block closes the collect iterator.
+field_names = output.get_schema().get_field_names()
+input_index = field_names.index(count_vectorizer.get_input_col())
+output_index = field_names.index(count_vectorizer.get_output_col())
+with t_env.to_data_stream(output).execute_and_collect() as results:
+    for result in results:
+        print('Input Value: %-20s Output Value: %10s' %
+              (str(result[input_index]), str(result[output_index])))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py b/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py
new file mode 100644
index 0000000..5388474
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py
@@ -0,0 +1,189 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+from pyflink.ml.lib.param import HasOutputCol, HasInputCol
+
+from pyflink.ml.core.wrapper import JavaWithParams
+
+from pyflink.ml.core.param import FloatParam, BooleanParam, ParamValidators, IntParam
+
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+
+
+class _CountVectorizerModelParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol,
+):
+    """
+    Params for :class:`CountVectorizerModel`.
+
+    Besides the input/output columns it declares MIN_TF (per-document term
+    frequency filter) and BINARY (whether non-zero counts are emitted as 1.0).
+    """
+    # NOTE(review): "min_t_f" is presumably spelled this way so that it maps onto the
+    # Java-side param name "minTF"; confirm before renaming.
+    MIN_TF: FloatParam = FloatParam(
+        "min_t_f",
+        "Filter to ignore rare words in a document. For each document, "
+        "terms with frequency/count less than the given threshold are ignored. "
+        "If this is an integer >= 1, then this specifies a count (of times "
+        "the term must appear in the document); if this is a double in [0,1), "
+        "then this specifies a fraction (out of the document's token count).",
+        1.0,
+        ParamValidators.gt_eq(0.0)
+    )
+
+    BINARY: BooleanParam = BooleanParam(
+        "binary",
+        "Binary toggle to control the output vector values. If True, all "
+        "nonzero counts (after minTF filter applied) are set to 1.0.",
+        False
+    )
+
+    def __init__(self, java_params):
+        super(_CountVectorizerModelParams, self).__init__(java_params)
+
+    def set_min_tf(self, value: float):
+        """Sets MIN_TF; ``value`` is converted with ``float()`` before it is stored."""
+        return typing.cast(_CountVectorizerModelParams,
+                           self.set(self.MIN_TF, float(value)))
+
+    def get_min_tf(self):
+        """Returns the current value of MIN_TF."""
+        return self.get(self.MIN_TF)
+
+    def set_binary(self, value: bool):
+        """Sets BINARY."""
+        return typing.cast(_CountVectorizerModelParams, self.set(self.BINARY, value))
+
+    def get_binary(self):
+        """Returns the current value of BINARY."""
+        return self.get(self.BINARY)
+
+    @property
+    def min_tf(self):
+        """Read-only alias of :meth:`get_min_tf`."""
+        return self.get_min_tf()
+
+    @property
+    def binary(self):
+        """Read-only alias of :meth:`get_binary`."""
+        return self.get_binary()
+
+
+class _CountVectorizerParams(_CountVectorizerModelParams):
+    """
+    Params for :class:`CountVectorizer`.
+
+    Extends the model params with VOCABULARY_SIZE, MIN_DF and MAX_DF, which
+    control which terms are kept when the vocabulary is built.
+    """
+    VOCABULARY_SIZE: IntParam = IntParam(
+        "vocabulary_size",
+        "Max size of the vocabulary. CountVectorizer will build a vocabulary "
+        "that only considers the top vocabularySize terms ordered by term "
+        "frequency across the corpus.",
+        1 << 18,
+        ParamValidators.gt(0)
+    )
+
+    MIN_DF: FloatParam = FloatParam(
+        "min_d_f",
+        "Specifies the minimum number of different documents a term must "
+        "appear in to be included in the vocabulary. If this is an "
+        "integer >= 1, this specifies the number of documents the term must "
+        "appear in; if this is a double in [0,1), then this specifies the "
+        "fraction of documents.",
+        1.0,
+        ParamValidators.gt_eq(0.0)
+    )
+
+    # NOTE(review): the default float(2**63 - 1) presumably mirrors Java's
+    # Long.MAX_VALUE (no practical upper bound); confirm against the Java params.
+    MAX_DF: FloatParam = FloatParam(
+        "max_d_f",
+        "Specifies the maximum number of different documents a term could "
+        "appear in to be included in the vocabulary. A term that appears "
+        "more than the threshold will be ignored. If this is an integer >= 1, "
+        "this specifies the maximum number of documents the term could "
+        "appear in; if this is a double in [0,1), then this specifies the "
+        "maximum fraction of documents the term could appear in.",
+        float(2**63 - 1),
+        ParamValidators.gt_eq(0.0)
+    )
+
+    def __init__(self, java_params):
+        super(_CountVectorizerParams, self).__init__(java_params)
+
+    def set_vocabulary_size(self, value: int):
+        """Sets VOCABULARY_SIZE, the maximum number of terms kept in the vocabulary."""
+        return typing.cast(_CountVectorizerParams, self.set(self.VOCABULARY_SIZE, value))
+
+    def get_vocabulary_size(self) -> int:
+        """Returns the current value of VOCABULARY_SIZE."""
+        return self.get(self.VOCABULARY_SIZE)
+
+    def set_min_df(self, value: float):
+        """Sets MIN_DF; ``value`` is converted with ``float()`` before it is stored."""
+        return typing.cast(_CountVectorizerParams, self.set(self.MIN_DF, float(value)))
+
+    def get_min_df(self):
+        """Returns the current value of MIN_DF."""
+        return self.get(self.MIN_DF)
+
+    def set_max_df(self, value: float):
+        """Sets MAX_DF; ``value`` is converted with ``float()`` before it is stored."""
+        return typing.cast(_CountVectorizerParams, self.set(self.MAX_DF, float(value)))
+
+    def get_max_df(self):
+        """Returns the current value of MAX_DF."""
+        return self.get(self.MAX_DF)
+
+    @property
+    def vocabulary_size(self):
+        """Read-only alias of :meth:`get_vocabulary_size`."""
+        return self.get_vocabulary_size()
+
+    @property
+    def min_df(self):
+        """Read-only alias of :meth:`get_min_df`."""
+        return self.get_min_df()
+
+    @property
+    def max_df(self):
+        """Read-only alias of :meth:`get_max_df`."""
+        return self.get_max_df()
+
+
+class CountVectorizerModel(JavaFeatureModel, _CountVectorizerModelParams):
+    """
+    Model produced by :class:`CountVectorizer`. It applies the vocabulary fitted
+    by the estimator to turn each input document into a vector of term counts.
+    """
+
+    def __init__(self, java_model=None):
+        # The optional wrapped Java model is forwarded unchanged to the base class.
+        super().__init__(java_model)
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        """Simple name of the wrapped Java model class."""
+        return "CountVectorizerModel"
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        """Package segment of the wrapped Java model (presumably resolved by the base)."""
+        return "countvectorizer"
+
+
+class CountVectorizer(JavaFeatureEstimator, _CountVectorizerParams):
+    """
+    Estimator that learns a vocabulary from a collection of text documents and
+    produces a :class:`CountVectorizerModel`, which converts documents into
+    vectors of token counts. Use it when no a-priori dictionary is available;
+    the resulting sparse representations can be fed to algorithms such as LDA.
+    """
+
+    def __init__(self):
+        super().__init__()
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        """Simple name of the wrapped Java estimator class."""
+        return "CountVectorizer"
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        """Package segment of the wrapped Java estimator."""
+        return "countvectorizer"
+
+    @classmethod
+    def _create_model(cls, java_model) -> CountVectorizerModel:
+        """Wraps the fitted Java model in its Python counterpart."""
+        return CountVectorizerModel(java_model=java_model)
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py
new file mode 100644
index 0000000..f878236
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py
@@ -0,0 +1,144 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List, Union
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVector, SparseVector
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer, CountVectorizerModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+from pyflink.table import Table
+
+
+class CountVectorizerTest(PyFlinkMLTestCase):
+    """Tests for :class:`CountVectorizer` and :class:`CountVectorizerModel`."""
+
+    def setUp(self):
+        super(CountVectorizerTest, self).setUp()
+        # Five tokenized documents; the 'id' column is used to order collected results.
+        self.input_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (1, ['a', 'c', 'b', 'c'],),
+                (2, ['c', 'd', 'e'],),
+                (3, ['a', 'b', 'c'],),
+                (4, ['e', 'f'],),
+                (5, ['a', 'c', 'a'],),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['id', 'input', ],
+                    [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])))
+
+        # Expected outputs with default params, in 'id' order. Vector indices follow the
+        # fitted vocabulary ["c", "a", "b", "e", "d", "f"] (see test_get_model_data).
+        self.expected_output = [
+            Vectors.sparse(6, [0, 1, 2], [2.0, 1.0, 1.0]),
+            Vectors.sparse(6, [0, 3, 4], [1.0, 1.0, 1.0]),
+            Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),
+            Vectors.sparse(6, [3, 5], [1.0, 1.0]),
+            Vectors.sparse(6, [0, 1], [1.0, 2.0]),
+        ]
+
+    def test_param(self):
+        count_vectorizer = CountVectorizer()
+        # Default values.
+        self.assertEqual('input', count_vectorizer.input_col)
+        self.assertEqual('output', count_vectorizer.output_col)
+        self.assertEqual(1, count_vectorizer.min_df)
+        self.assertEqual(float(2**63 - 1), count_vectorizer.max_df)
+        self.assertEqual(1, count_vectorizer.min_tf)
+        self.assertEqual(1 << 18, count_vectorizer.vocabulary_size)
+        self.assertFalse(count_vectorizer.binary)
+
+        # Setters return the instance, so they can be chained.
+        count_vectorizer.\
+            set_input_col('test_input').\
+            set_output_col('test_output').\
+            set_min_df(0.1).\
+            set_max_df(0.9).\
+            set_min_tf(10).\
+            set_vocabulary_size(1000).\
+            set_binary(True)
+        self.assertEqual('test_input', count_vectorizer.input_col)
+        self.assertEqual('test_output', count_vectorizer.output_col)
+        self.assertEqual(0.1, count_vectorizer.min_df)
+        self.assertEqual(0.9, count_vectorizer.max_df)
+        self.assertEqual(10, count_vectorizer.min_tf)
+        self.assertEqual(1000, count_vectorizer.vocabulary_size)
+        self.assertTrue(count_vectorizer.binary)
+
+    def test_output_schema(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        output = model.transform(self.input_table.alias('id', 'input'))[0]
+        self.assertEqual(
+            ['id', 'input', 'output'],
+            output.get_schema().get_field_names())
+
+    def test_fit_and_predict(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        output = model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def test_save_load_predict(self):
+        count_vectorizer = CountVectorizer()
+        reloaded_count_vectorizer = self.save_and_reload(count_vectorizer)
+        model = reloaded_count_vectorizer.fit(self.input_table)
+        reloaded_model = self.save_and_reload(model)
+        output = reloaded_model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def test_get_model_data(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        model_data_table = model.get_model_data()[0]
+        self.assertEqual(["vocabulary"],
+                         model_data_table.get_schema().get_field_names())
+        # Only the first row is needed; the `with` block closes the collect iterator
+        # instead of leaving it open after that row has been read.
+        with self.t_env.to_data_stream(model_data_table).execute_and_collect() as results:
+            model_data = next(results)
+        expected = ["c", "a", "b", "e", "d", "f"]
+        self.assertEqual(expected, model_data[0])
+
+    def test_set_model_data(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+
+        new_model = CountVectorizerModel()
+        new_model.set_model_data(*model.get_model_data())
+        output = new_model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def verify_output_result(
+            self, output: Table,
+            output_col: str,
+            field_names: List[str],
+            expected_result: List[Union[DenseVector, SparseVector]]):
+        """
+        Collects all rows of ``output``, sorts them by their first field (the 'id'
+        column in these tests) and asserts that their ``output_col`` values equal
+        ``expected_result``.
+        """
+        with self.t_env.to_data_stream(output).execute_and_collect() as collected:
+            rows = list(collected)
+        for row in rows:
+            # Attach the column names so the rows can be indexed by ``output_col`` below.
+            row.set_field_names(field_names)
+        rows.sort(key=lambda row: row[0])
+        self.assertEqual(expected_result, [row[output_col] for row in rows])