Posted to commits@flink.apache.org by li...@apache.org on 2022/11/23 09:25:27 UTC
[flink-ml] branch master updated: [FLINK-29604] Add Estimator and Transformer for CountVectorizer
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 23d7646 [FLINK-29604] Add Estimator and Transformer for CountVectorizer
23d7646 is described below
commit 23d76467f9d0c089206019afde2d8f87409ec526
Author: JiangXin <ji...@alibaba-inc.com>
AuthorDate: Wed Nov 23 17:25:22 2022 +0800
[FLINK-29604] Add Estimator and Transformer for CountVectorizer
This closes #174.
---
.../docs/operators/feature/countvectorizer.md | 182 ++++++++++
.../examples/feature/CountVectorizerExample.java | 71 ++++
.../feature/countvectorizer/CountVectorizer.java | 218 ++++++++++++
.../countvectorizer/CountVectorizerModel.java | 188 ++++++++++
.../countvectorizer/CountVectorizerModelData.java | 110 ++++++
.../CountVectorizerModelParams.java | 67 ++++
.../countvectorizer/CountVectorizerParams.java | 86 +++++
.../flink/ml/feature/CountVectorizerTest.java | 390 +++++++++++++++++++++
.../examples/ml/feature/countvectorizer_example.py | 62 ++++
.../pyflink/ml/lib/feature/countvectorizer.py | 189 ++++++++++
.../ml/lib/feature/tests/test_countvectorizer.py | 144 ++++++++
11 files changed, 1707 insertions(+)
diff --git a/docs/content/docs/operators/feature/countvectorizer.md b/docs/content/docs/operators/feature/countvectorizer.md
new file mode 100644
index 0000000..aef0b7c
--- /dev/null
+++ b/docs/content/docs/operators/feature/countvectorizer.md
@@ -0,0 +1,182 @@
+---
+title: "Count Vectorizer"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/countvectorizer.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+ http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## Count Vectorizer
+
+CountVectorizer is an algorithm that converts a collection of text
+documents to vectors of token counts. When an a-priori dictionary is not
+available, CountVectorizer can be used as an estimator to extract the
+vocabulary and generate a CountVectorizerModel. The model produces sparse
+representations for the documents over the vocabulary, which can then be
+passed to other algorithms like LDA.
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:---------|:----------|:--------------------|
+| inputCol | String[] | `"input"` | Input string array. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------------|:-----------|:------------------------|
+| outputCol | SparseVector | `"output"` | Vector of token counts. |
+
+### Parameters
+
+Below are the parameters required by `CountVectorizerModel`.
+
+| Key | Default | Type | Required | Description |
+|------------|------------|---------|----------|-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| inputCol | `"input"` | String | no | Input column name. |
+| outputCol | `"output"` | String | no | Output column name. |
+| minTF | `1.0` | Double | no | Filter to ignore rare words in a document. For each document, terms with frequency/count less than the given threshold are ignored. If this is an integer >= 1, then this specifies a count (of times the term must appear in the document); if this is a double in [0,1), then this specifies a fraction (out of the document's token count). |
+| binary | `false` | Boolean | no | Binary toggle to control the output vector values. If true, all nonzero counts (after the minTF filter is applied) are set to 1.0. |
+
+`CountVectorizer` needs the parameters above and, in addition, the parameters below.
+
+| Key | Default | Type | Required | Description |
+|:---------------|:-----------|:---------|:---------|:---------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
+| vocabularySize | `2^18` | Integer | no | Max size of the vocabulary. CountVectorizer will build a vocabulary that only considers the top `vocabularySize` terms ordered by term frequency across the corpus. |
+| minDF | `1.0` | Double | no | Specifies the minimum number of different documents a term must appear in to be included in the vocabulary. If this is an integer >= 1, this specifies the number of documents the term must appear in; if this is a double in [0,1), then this specifies the fraction of documents. |
+| maxDF | `2^63 - 1` | Double | no | Specifies the maximum number of different documents a term may appear in and still be included in the vocabulary. A term that appears in more documents than the threshold is ignored. If this is an integer >= 1, this specifies the maximum number of documents the term may appear in; if this is a double in [0,1), then this specifies the maximum fraction of documents the term may appear in. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/**
+ * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering.
+ */
+public class CountVectorizerExample {
+
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input training and prediction data.
+ DataStream<Row> dataStream =
+ env.fromElements(
+ Row.of((Object) new String[] {"a", "c", "b", "c"}),
+ Row.of((Object) new String[] {"c", "d", "e"}),
+ Row.of((Object) new String[] {"a", "b", "c"}),
+ Row.of((Object) new String[] {"e", "f"}),
+ Row.of((Object) new String[] {"a", "c", "a"}));
+ Table inputTable = tEnv.fromDataStream(dataStream).as("input");
+
+ // Creates a CountVectorizer object and initializes its parameters.
+ CountVectorizer countVectorizer = new CountVectorizer();
+
+ // Trains the CountVectorizer model
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+
+ // Uses the CountVectorizer model for predictions.
+ Table outputTable = model.transform(inputTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+ String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
+ SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+ System.out.printf(
+ "Input Value: %-15s \tOutput Value: %s\n",
+ Arrays.toString(inputValue), outputValue.toString());
+ }
+ }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+
+# Simple program that creates a CountVectorizer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (1, ['a', 'c', 'b', 'c'],),
+ (2, ['c', 'd', 'e'],),
+ (3, ['a', 'b', 'c'],),
+ (4, ['e', 'f'],),
+ (5, ['a', 'c', 'a'],),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'input', ],
+ [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])
+ ))
+
+# Creates a CountVectorizer object and initializes its parameters.
+count_vectorizer = CountVectorizer()
+
+# Trains the CountVectorizer Model.
+model = count_vectorizer.fit(input_table)
+
+# Uses the CountVectorizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_index = field_names.index(count_vectorizer.get_input_col())
+ output_index = field_names.index(count_vectorizer.get_output_col())
+ print('Input Value: %-20s Output Value: %10s' %
+ (str(result[input_index]), str(result[output_index])))
+
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
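The minTF, minDF and maxDF rows above share one convention: a value >= 1 is an absolute count, while a value in [0, 1) is a fraction of a total (tokens in the document for minTF, documents in the corpus for minDF and maxDF). Below is a minimal standalone Java sketch of that rule, written to mirror the `minDF >= 1.0 ? minDF : minDF * rowNum` style ternaries used later in this commit. `ThresholdSketch`, `resolve` and `keepsTerm` are hypothetical names, not Flink ML API, and rejecting negative values is an assumption borrowed from the `ParamValidators.gtEq(0.0)` check on minTF.

```java
/**
 * Hypothetical helper (not part of Flink ML) illustrating how minDF, maxDF and minTF
 * values are resolved: a value >= 1 is an absolute count, a value in [0, 1) is a
 * fraction of a total (documents for minDF/maxDF, tokens in the document for minTF).
 */
final class ThresholdSketch {
    private ThresholdSketch() {}

    /** Resolves a threshold against a total, like the ternaries in this commit. */
    static double resolve(double threshold, long total) {
        // Assumption: reject negative values, as ParamValidators.gtEq(0.0) does for minTF.
        if (threshold < 0.0) {
            throw new IllegalArgumentException("threshold must be >= 0, got " + threshold);
        }
        return threshold >= 1.0 ? threshold : threshold * total;
    }

    /** Returns true if a term found in docCount of numDocs documents passes minDF/maxDF. */
    static boolean keepsTerm(long docCount, long numDocs, double minDF, double maxDF) {
        double actualMinDF = resolve(minDF, numDocs);
        double actualMaxDF = resolve(maxDF, numDocs);
        return docCount >= actualMinDF && docCount <= actualMaxDF;
    }

    public static void main(String[] args) {
        long numDocs = 5; // the example corpus above has five documents
        System.out.println(resolve(2.0, numDocs)); // count: 2.0 documents
        System.out.println(resolve(0.5, numDocs)); // fraction: 0.5 * 5 = 2.5 documents
        System.out.println(keepsTerm(2, numDocs, 1.0, 0.5)); // 2 <= 2.5, so true
        System.out.println(keepsTerm(3, numDocs, 1.0, 0.5)); // 3 > 2.5, so false
    }
}
```

With five documents, `maxDF = 0.5` resolves to 2.5, so a term must appear in at most two documents to stay in the vocabulary.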
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
new file mode 100644
index 0000000..fb1287c
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/CountVectorizerExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/**
+ * Simple program that trains a {@link CountVectorizer} model and uses it for feature engineering.
+ */
+public class CountVectorizerExample {
+
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input training and prediction data.
+ DataStream<Row> dataStream =
+ env.fromElements(
+ Row.of((Object) new String[] {"a", "c", "b", "c"}),
+ Row.of((Object) new String[] {"c", "d", "e"}),
+ Row.of((Object) new String[] {"a", "b", "c"}),
+ Row.of((Object) new String[] {"e", "f"}),
+ Row.of((Object) new String[] {"a", "c", "a"}));
+ Table inputTable = tEnv.fromDataStream(dataStream).as("input");
+
+ // Creates a CountVectorizer object and initializes its parameters.
+ CountVectorizer countVectorizer = new CountVectorizer();
+
+ // Trains the CountVectorizer model
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+
+ // Uses the CountVectorizer model for predictions.
+ Table outputTable = model.transform(inputTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+ String[] inputValue = (String[]) row.getField(countVectorizer.getInputCol());
+ SparseVector outputValue = (SparseVector) row.getField(countVectorizer.getOutputCol());
+ System.out.printf(
+ "Input Value: %-15s \tOutput Value: %s\n",
+ Arrays.toString(inputValue), outputValue.toString());
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
new file mode 100644
index 0000000..76ccc13
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizer.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+/**
+ * An Estimator which converts a collection of text documents to vectors of token counts. When an
+ * a-priori dictionary is not available, CountVectorizer can be used as an estimator to extract the
+ * vocabulary and generate a {@link CountVectorizerModel}. The model produces sparse
+ * representations for the documents over the vocabulary, which can then be passed to other
+ * algorithms like LDA.
+ */
+public class CountVectorizer
+ implements Estimator<CountVectorizer, CountVectorizerModel>,
+ CountVectorizerParams<CountVectorizer> {
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+ public CountVectorizer() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public CountVectorizerModel fit(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ double minDF = getMinDF();
+ double maxDF = getMaxDF();
+ if ((minDF >= 1.0 && maxDF >= 1.0) || (minDF < 1.0 && maxDF < 1.0)) {
+ Preconditions.checkArgument(maxDF >= minDF, "maxDF must be >= minDF.");
+ }
+
+ String inputCol = getInputCol();
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<String[]> inputData =
+ tEnv.toDataStream(inputs[0])
+ .map(
+ (MapFunction<Row, String[]>)
+ value -> ((String[]) value.getField(inputCol)));
+
+ DataStream<CountVectorizerModelData> modelData =
+ DataStreamUtils.aggregate(
+ inputData,
+ new VocabularyAggregator(getMinDF(), getMaxDF(), getVocabularySize()));
+
+ CountVectorizerModel model =
+ new CountVectorizerModel().setModelData(tEnv.fromDataStream(modelData));
+ ReadWriteUtils.updateExistingParams(model, getParamMap());
+ return model;
+ }
+
+ /**
+ * Extracts a vocabulary from input document collections and builds the {@link
+ * CountVectorizerModelData}.
+ */
+ private static class VocabularyAggregator
+ implements AggregateFunction<
+ String[],
+ Tuple2<Long, Map<String, Tuple2<Long, Long>>>,
+ CountVectorizerModelData> {
+ private final double minDF;
+ private final double maxDF;
+ private final int vocabularySize;
+
+ public VocabularyAggregator(double minDF, double maxDF, int vocabularySize) {
+ this.minDF = minDF;
+ this.maxDF = maxDF;
+ this.vocabularySize = vocabularySize;
+ }
+
+ @Override
+ public Tuple2<Long, Map<String, Tuple2<Long, Long>>> createAccumulator() {
+ return Tuple2.of(0L, new HashMap<>());
+ }
+
+ @Override
+ public Tuple2<Long, Map<String, Tuple2<Long, Long>>> add(
+ String[] terms, Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
+ Map<String, Long> wc = new HashMap<>();
+ Arrays.stream(terms)
+ .forEach(
+ x -> {
+ if (wc.containsKey(x)) {
+ wc.put(x, wc.get(x) + 1);
+ } else {
+ wc.put(x, 1L);
+ }
+ });
+
+ Map<String, Tuple2<Long, Long>> counts = vocabAccumulator.f1;
+ wc.forEach(
+ (w, c) -> {
+ if (counts.containsKey(w)) {
+ counts.get(w).f0 += c;
+ counts.get(w).f1 += 1;
+ } else {
+ counts.put(w, Tuple2.of(c, 1L));
+ }
+ });
+ vocabAccumulator.f0 += 1;
+
+ return vocabAccumulator;
+ }
+
+ @Override
+ public CountVectorizerModelData getResult(
+ Tuple2<Long, Map<String, Tuple2<Long, Long>>> vocabAccumulator) {
+ Preconditions.checkState(vocabAccumulator.f0 > 0, "The training set is empty.");
+
+ boolean filteringRequired =
+ !MIN_DF.defaultValue.equals(minDF) || !MAX_DF.defaultValue.equals(maxDF);
+ if (filteringRequired) {
+ long rowNum = vocabAccumulator.f0;
+ double actualMinDF = minDF >= 1.0 ? minDF : minDF * rowNum;
+ double actualMaxDF = maxDF >= 1.0 ? maxDF : maxDF * rowNum;
+ Preconditions.checkState(actualMaxDF >= actualMinDF, "maxDF must be >= minDF.");
+
+ vocabAccumulator.f1 =
+ vocabAccumulator.f1.entrySet().stream()
+ .filter(
+ x ->
+ x.getValue().f1 >= actualMinDF
+ && x.getValue().f1 <= actualMaxDF)
+ .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
+ }
+
+ List<Map.Entry<String, Tuple2<Long, Long>>> list =
+ new ArrayList<>(vocabAccumulator.f1.entrySet());
+ list.sort((o1, o2) -> (o2.getValue().f1.compareTo(o1.getValue().f1)));
+ List<String> vocabulary =
+ list.stream().map(Map.Entry::getKey).collect(Collectors.toList());
+ String[] topTerms =
+ vocabulary
+ .subList(0, Math.min(vocabulary.size(), vocabularySize))
+ .toArray(new String[0]);
+ return new CountVectorizerModelData(topTerms);
+ }
+
+ @Override
+ public Tuple2<Long, Map<String, Tuple2<Long, Long>>> merge(
+ Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc1,
+ Tuple2<Long, Map<String, Tuple2<Long, Long>>> acc2) {
+ if (acc1.f0 == 0) {
+ return acc2;
+ }
+
+ if (acc2.f0 == 0) {
+ return acc1;
+ }
+ acc2.f0 += acc1.f0;
+ acc1.f1.forEach(
+ (term, counts) -> {
+ if (acc2.f1.containsKey(term)) {
+ acc2.f1.put(
+ term,
+ Tuple2.of(
+ counts.f0 + acc2.f1.get(term).f0,
+ counts.f1 + acc2.f1.get(term).f1));
+ } else {
+ acc2.f1.put(term, counts);
+ }
+ });
+ return acc2;
+ }
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ }
+
+ public static CountVectorizer load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ return ReadWriteUtils.loadStageParam(path);
+ }
+}
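`VocabularyAggregator` above keeps, per term, the total occurrence count (`f0`) and the number of documents containing the term (`f1`); `getResult` then filters and sorts on `f1` and truncates to `vocabularySize`. Below is a minimal single-machine Java sketch of that selection, assuming the whole corpus fits in memory (no `DataStream`, no accumulator merging). `VocabularySketch` is a hypothetical name. Two deliberate differences: ties in document count are broken alphabetically for deterministic output (the aggregator leaves tie order to `HashMap` iteration), and the minDF/maxDF filter is applied unconditionally, which gives the same result as skipping it at the defaults because every observed term has a document count between 1 and 2^63 - 1.

```java
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

/**
 * Hypothetical in-memory sketch (not Flink ML API) of the vocabulary selection done by
 * VocabularyAggregator: document counts, minDF/maxDF filter, sort, top vocabularySize.
 */
final class VocabularySketch {
    private VocabularySketch() {}

    static List<String> buildVocabulary(
            List<String[]> documents, double minDF, double maxDF, int vocabularySize) {
        if (documents.isEmpty()) {
            throw new IllegalStateException("The training set is empty.");
        }
        // Number of documents containing each term (f1 in the commit's accumulator).
        Map<String, Long> docCounts = new HashMap<>();
        for (String[] document : documents) {
            // A term counts once per document, however often it repeats there.
            for (String term : new HashSet<>(Arrays.asList(document))) {
                docCounts.merge(term, 1L, Long::sum);
            }
        }

        // Fractions in [0, 1) are relative to the number of documents.
        long numDocs = documents.size();
        double actualMinDF = minDF >= 1.0 ? minDF : minDF * numDocs;
        double actualMaxDF = maxDF >= 1.0 ? maxDF : maxDF * numDocs;
        if (actualMaxDF < actualMinDF) {
            throw new IllegalStateException("maxDF must be >= minDF.");
        }

        return docCounts.entrySet().stream()
                .filter(e -> e.getValue() >= actualMinDF && e.getValue() <= actualMaxDF)
                // Highest document count first; ties alphabetical (an assumption).
                .sorted((a, b) -> {
                    int byCount = Long.compare(b.getValue(), a.getValue());
                    return byCount != 0 ? byCount : a.getKey().compareTo(b.getKey());
                })
                .limit(vocabularySize)
                .map(Map.Entry::getKey)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<String[]> documents = Arrays.asList(
                new String[] {"a", "c", "b", "c"},
                new String[] {"c", "d", "e"},
                new String[] {"a", "b", "c"},
                new String[] {"e", "f"},
                new String[] {"a", "c", "a"});
        // Defaults from the docs table: minDF = 1.0, maxDF = 2^63 - 1, vocabularySize = 2^18.
        System.out.println(buildVocabulary(documents, 1.0, Long.MAX_VALUE, 1 << 18));
    }
}
```

On the five example documents the document counts are c=4, a=3, b=2, e=2, d=1 and f=1, so `main` prints `[c, a, b, e, d, f]`; with `vocabularySize = 2` only `[c, a]` remain.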
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
new file mode 100644
index 0000000..390d997
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModel.java
@@ -0,0 +1,188 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.stream.IntStream;
+
+/** A Model which transforms data using the model data computed by {@link CountVectorizer}. */
+public class CountVectorizerModel
+ implements Model<CountVectorizerModel>, CountVectorizerModelParams<CountVectorizerModel> {
+
+ private final Map<Param<?>, Object> paramMap = new HashMap<>();
+ private Table modelDataTable;
+
+ public CountVectorizerModel() {
+ ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+ }
+
+ @Override
+ public CountVectorizerModel setModelData(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+ modelDataTable = inputs[0];
+ return this;
+ }
+
+ @Override
+ public Table[] getModelData() {
+ return new Table[] {modelDataTable};
+ }
+
+ @Override
+ public void save(String path) throws IOException {
+ ReadWriteUtils.saveMetadata(this, path);
+ ReadWriteUtils.saveModelData(
+ CountVectorizerModelData.getModelDataStream(modelDataTable),
+ path,
+ new CountVectorizerModelData.ModelDataEncoder());
+ }
+
+ public static CountVectorizerModel load(StreamTableEnvironment tEnv, String path)
+ throws IOException {
+ CountVectorizerModel model = ReadWriteUtils.loadStageParam(path);
+ Table modelDataTable =
+ ReadWriteUtils.loadModelData(
+ tEnv, path, new CountVectorizerModelData.ModelDataDecoder());
+ return model.setModelData(modelDataTable);
+ }
+
+ @Override
+ public Map<Param<?>, Object> getParamMap() {
+ return paramMap;
+ }
+
+ @Override
+ public Table[] transform(Table... inputs) {
+ Preconditions.checkArgument(inputs.length == 1);
+
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+ DataStream<Row> dataStream = tEnv.toDataStream(inputs[0]);
+ DataStream<CountVectorizerModelData> modelDataStream =
+ CountVectorizerModelData.getModelDataStream(modelDataTable);
+
+ final String broadcastModelKey = "broadcastModelKey";
+ RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(inputs[0].getResolvedSchema());
+ RowTypeInfo outputTypeInfo =
+ new RowTypeInfo(
+ ArrayUtils.addAll(
+ inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
+ ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+ DataStream<Row> output =
+ BroadcastUtils.withBroadcastStream(
+ Collections.singletonList(dataStream),
+ Collections.singletonMap(broadcastModelKey, modelDataStream),
+ inputList -> {
+ DataStream input = inputList.get(0);
+ return input.map(
+ new PredictOutputFunction(
+ getInputCol(),
+ broadcastModelKey,
+ getMinTF(),
+ getBinary()),
+ outputTypeInfo);
+ });
+ return new Table[] {tEnv.fromDataStream(output)};
+ }
+
+ /** This operator loads the model data and predicts the result for each row. */
+ private static class PredictOutputFunction extends RichMapFunction<Row, Row> {
+
+ private final String inputCol;
+ private final String broadcastKey;
+ private final double minTF;
+ private final boolean binary;
+ private Map<String, Integer> vocabulary;
+
+ public PredictOutputFunction(
+ String inputCol, String broadcastKey, double minTF, boolean binary) {
+ this.inputCol = inputCol;
+ this.broadcastKey = broadcastKey;
+ this.minTF = minTF;
+ this.binary = binary;
+ }
+
+ @Override
+ public Row map(Row row) throws Exception {
+ if (vocabulary == null) {
+ CountVectorizerModelData modelData =
+ (CountVectorizerModelData)
+ getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+ vocabulary = new HashMap<>();
+ IntStream.range(0, modelData.vocabulary.length)
+ .forEach(i -> vocabulary.put(modelData.vocabulary[i], i));
+ }
+
+ String[] document = (String[]) row.getField(inputCol);
+ double[] termCounts = new double[vocabulary.size()];
+ for (String word : document) {
+ if (vocabulary.containsKey(word)) {
+ termCounts[vocabulary.get(word)] += 1;
+ }
+ }
+
+ double actualMinTF = minTF >= 1.0 ? minTF : document.length * minTF;
+ List<Integer> indices = new ArrayList<>();
+ List<Double> values = new ArrayList<>();
+ for (int i = 0; i < termCounts.length; i++) {
+ if (termCounts[i] > 0 && termCounts[i] >= actualMinTF) {
+ indices.add(i);
+ if (binary) {
+ values.add(1.0);
+ } else {
+ values.add(termCounts[i]);
+ }
+ }
+ }
+
+ SparseVector outputVec =
+ Vectors.sparse(
+ termCounts.length,
+ indices.stream().mapToInt(i -> i).toArray(),
+ values.stream().mapToDouble(i -> i).toArray());
+ return Row.join(row, Row.of(outputVec));
+ }
+ }
+}
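`PredictOutputFunction.map` above turns one document into a sparse vector over the vocabulary: count in-vocabulary tokens, resolve `minTF` against the document length, keep counts that reach the threshold, and optionally replace them with 1.0. Below is a minimal standalone Java sketch of that per-row step, returning a `TreeMap` from vocabulary index to value instead of a Flink `SparseVector` to stay dependency-free; `TermCountSketch` is a hypothetical name. The sketch also requires a positive count, so that `minTF = 0.0`, which the `gtEq(0.0)` validator accepts, cannot emit entries for terms absent from the document.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.SortedMap;
import java.util.TreeMap;

/**
 * Hypothetical per-document sketch (not Flink ML API) of PredictOutputFunction.map.
 * Returns vocabulary index -> value; the implied vector size is vocabulary.length.
 */
final class TermCountSketch {
    private TermCountSketch() {}

    static SortedMap<Integer, Double> countTerms(
            String[] vocabulary, String[] document, double minTF, boolean binary) {
        // Term -> position in the vocabulary (the commit builds this once per task).
        Map<String, Integer> positions = new HashMap<>();
        for (int i = 0; i < vocabulary.length; i++) {
            positions.put(vocabulary[i], i);
        }

        // Count tokens; out-of-vocabulary tokens are dropped.
        double[] termCounts = new double[vocabulary.length];
        for (String word : document) {
            Integer position = positions.get(word);
            if (position != null) {
                termCounts[position] += 1;
            }
        }

        // minTF in [0, 1) is a fraction of this document's token count.
        double actualMinTF = minTF >= 1.0 ? minTF : document.length * minTF;
        SortedMap<Integer, Double> result = new TreeMap<>();
        for (int i = 0; i < termCounts.length; i++) {
            // Skipping zero counts keeps minTF = 0.0 from emitting absent terms.
            if (termCounts[i] > 0 && termCounts[i] >= actualMinTF) {
                result.put(i, binary ? 1.0 : termCounts[i]);
            }
        }
        return result;
    }
}
```

For the vocabulary `[c, a, b, e, d, f]` and the document `[a, c, b, c]`, the result is `{0=2.0, 1=1.0, 2=1.0}`; with `binary = true` every kept value becomes 1.0, and with `minTF = 0.5` the threshold is 0.5 * 4 = 2 tokens, leaving only `{0=2.0}`.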
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java
new file mode 100644
index 0000000..6ac41a5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelData.java
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.array.StringArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link CountVectorizerModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class CountVectorizerModelData {
+
+ /** The vocabulary terms. Only terms in this array are counted. */
+ public String[] vocabulary;
+
+ public CountVectorizerModelData() {}
+
+ public CountVectorizerModelData(String[] vocabulary) {
+ this.vocabulary = vocabulary;
+ }
+
+ /**
+ * Converts the table model to a data stream.
+ *
+ * @param modelDataTable The table model data.
+ * @return The data stream model data.
+ */
+ public static DataStream<CountVectorizerModelData> getModelDataStream(Table modelDataTable) {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+ return tEnv.toDataStream(modelDataTable)
+ .map(x -> new CountVectorizerModelData((String[]) x.getField(0)));
+ }
+
+ /** Encoder for {@link CountVectorizerModelData}. */
+ public static class ModelDataEncoder implements Encoder<CountVectorizerModelData> {
+ @Override
+ public void encode(CountVectorizerModelData modelData, OutputStream outputStream)
+ throws IOException {
+ DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+ StringArraySerializer.INSTANCE.serialize(modelData.vocabulary, dataOutputView);
+ }
+ }
+
+ /** Decoder for {@link CountVectorizerModelData}. */
+ public static class ModelDataDecoder extends SimpleStreamFormat<CountVectorizerModelData> {
+
+ @Override
+ public Reader<CountVectorizerModelData> createReader(
+ Configuration configuration, FSDataInputStream fsDataInputStream) {
+ return new Reader<CountVectorizerModelData>() {
+ @Override
+ public CountVectorizerModelData read() throws IOException {
+ DataInputView source = new DataInputViewStreamWrapper(fsDataInputStream);
+ try {
+ String[] vocabulary = StringArraySerializer.INSTANCE.deserialize(source);
+ return new CountVectorizerModelData(vocabulary);
+ } catch (EOFException e) {
+ return null;
+ }
+ }
+
+ @Override
+ public void close() throws IOException {
+ fsDataInputStream.close();
+ }
+ };
+ }
+
+ @Override
+ public TypeInformation<CountVectorizerModelData> getProducedType() {
+ return TypeInformation.of(CountVectorizerModelData.class);
+ }
+ }
+}
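`ModelDataDecoder` above reads one `CountVectorizerModelData` record per `read()` call from the saved stream and signals the end of the data by returning `null` when deserialization throws `EOFException`. The sketch below shows that read-until-EOF pattern with plain `java.io`. The record layout (an `int` length followed by one `writeUTF` string per term) is an assumption for illustration and is not the byte format produced by Flink's `StringArraySerializer`; `VocabularyCodecSketch` is a hypothetical name.

```java
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.EOFException;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/**
 * Hypothetical sketch of saving and re-reading vocabularies as a stream of records.
 * The layout (int length + writeUTF per term) is an assumption for illustration and
 * is not the format written by Flink's StringArraySerializer.
 */
final class VocabularyCodecSketch {
    private VocabularyCodecSketch() {}

    /** Appends one vocabulary record: its length, then each term. */
    static void encode(String[] vocabulary, DataOutputStream out) throws IOException {
        out.writeInt(vocabulary.length);
        for (String term : vocabulary) {
            out.writeUTF(term);
        }
    }

    /** Reads one record, or returns null once the length prefix cannot be read. */
    static String[] read(DataInputStream in) throws IOException {
        int length;
        try {
            length = in.readInt();
        } catch (EOFException e) {
            return null; // end of data
        }
        String[] vocabulary = new String[length];
        for (int i = 0; i < length; i++) {
            vocabulary[i] = in.readUTF(); // EOF here means a truncated record and propagates
        }
        return vocabulary;
    }

    /** Encodes all records into memory, then decodes them until the end of the stream. */
    static List<String[]> roundTrip(List<String[]> vocabularies) {
        try {
            ByteArrayOutputStream bytes = new ByteArrayOutputStream();
            try (DataOutputStream out = new DataOutputStream(bytes)) {
                for (String[] vocabulary : vocabularies) {
                    encode(vocabulary, out);
                }
            }
            List<String[]> decoded = new ArrayList<>();
            try (DataInputStream in =
                    new DataInputStream(new ByteArrayInputStream(bytes.toByteArray()))) {
                String[] vocabulary;
                while ((vocabulary = read(in)) != null) {
                    decoded.add(vocabulary);
                }
            }
            return decoded;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }
}
```

One design difference: the decoder in this commit catches `EOFException` around the whole record, so a file truncated mid-record also ends the stream quietly, whereas the sketch treats EOF as a clean end only while reading the length prefix and lets a truncated term list surface as an error.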
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java
new file mode 100644
index 0000000..e725956
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerModelParams.java
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link CountVectorizerModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface CountVectorizerModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {
+ Param<Double> MIN_TF =
+ new DoubleParam(
+ "minTF",
+                    "Filter to ignore rare words in a document. For each document, "
+ + "terms with frequency/count less than the given threshold are ignored. "
+ + "If this is an integer >= 1, then this specifies a count (of times "
+ + "the term must appear in the document); if this is a double in [0,1), "
+ + "then this specifies a fraction (out of the document's token count).",
+ 1.0,
+ ParamValidators.gtEq(0.0));
+
+ Param<Boolean> BINARY =
+ new BooleanParam(
+ "binary",
+                    "Binary toggle to control the output vector values. If true, all nonzero "
+                            + "counts (after the minTF filter is applied) are set to 1.0.",
+ false);
+
+ default double getMinTF() {
+ return get(MIN_TF);
+ }
+
+ default T setMinTF(double value) {
+ return set(MIN_TF, value);
+ }
+
+ default boolean getBinary() {
+ return get(BINARY);
+ }
+
+ default T setBinary(boolean value) {
+ return set(BINARY, value);
+ }
+}
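To make the minTF and binary semantics described above concrete, here is a pure-Python sketch of how one tokenized document could be turned into a sparse {index: value} vector for a fixed vocabulary. The function name vectorize and the dict representation are hypothetical. The fraction case assumes the denominator is the document's full token count, including out-of-vocabulary tokens; the parameter text suggests this, but this diff does not show the model's code. The sample values come from CountVectorizerTest.

```python
from collections import Counter
from typing import Dict, List


def vectorize(tokens: List[str], vocabulary: List[str],
              min_tf: float = 1.0, binary: bool = False) -> Dict[int, float]:
    """Maps one document to a sparse vector given as {index: value}."""
    index_of = {term: i for i, term in enumerate(vocabulary)}
    # Count occurrences of in-vocabulary terms only.
    counts = Counter(index_of[t] for t in tokens if t in index_of)
    # minTF >= 1 is an absolute count; a value in [0, 1) is a fraction of
    # the document's token count (assumed to include out-of-vocabulary tokens).
    threshold = min_tf if min_tf >= 1.0 else min_tf * len(tokens)
    return {
        index: 1.0 if binary else float(count)
        for index, count in counts.items()
        if count >= threshold
    }
```

With min_tf=0.5, the document ["a", "c", "b", "c"] needs at least 0.5 * 4 = 2 occurrences of a term, so only "c" (index 0) survives. This matches the first expected vector in testMinTF.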
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java
new file mode 100644
index 0000000..992a687
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/countvectorizer/CountVectorizerParams.java
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.countvectorizer;
+
+import org.apache.flink.ml.param.DoubleParam;
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params of {@link CountVectorizer}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface CountVectorizerParams<T> extends CountVectorizerModelParams<T> {
+ Param<Integer> VOCABULARY_SIZE =
+ new IntParam(
+ "vocabularySize",
+                    "Max size of the vocabulary. CountVectorizer will build a vocabulary "
+                            + "that only considers the top vocabularySize terms ordered by term "
+                            + "frequency across the corpus.",
+ 1 << 18,
+ ParamValidators.gt(0));
+
+ Param<Double> MIN_DF =
+ new DoubleParam(
+ "minDF",
+                    "Specifies the minimum number of different documents a term must "
+                            + "appear in to be included in the vocabulary. If this is an integer >= 1, "
+                            + "this specifies the number of documents the term must appear in; "
+                            + "if this is a double in [0,1), then this specifies the fraction of documents.",
+ 1.0,
+ ParamValidators.gtEq(0.0));
+
+ Param<Double> MAX_DF =
+ new DoubleParam(
+ "maxDF",
+ "Specifies the maximum number of different documents a term could appear "
+                            + "in to be included in the vocabulary. A term that appears in more "
+                            + "documents than the threshold will be ignored. If this is an integer >= 1, this "
+ + "specifies the maximum number of documents the term could appear in; "
+ + "if this is a double in [0,1), then this specifies the maximum "
+ + "fraction of documents the term could appear in.",
+ (double) Long.MAX_VALUE,
+ ParamValidators.gtEq(0.0));
+
+ default int getVocabularySize() {
+ return get(VOCABULARY_SIZE);
+ }
+
+ default T setVocabularySize(int value) {
+ return set(VOCABULARY_SIZE, value);
+ }
+
+ default double getMinDF() {
+ return get(MIN_DF);
+ }
+
+ default T setMinDF(double value) {
+ return set(MIN_DF, value);
+ }
+
+ default double getMaxDF() {
+ return get(MAX_DF);
+ }
+
+ default T setMaxDF(double value) {
+ return set(MAX_DF, value);
+ }
+}
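The vocabulary-building rules described by vocabularySize, minDF and maxDF can be sketched in pure Python as below. This is an illustration, not the CountVectorizer implementation. The helper name build_vocabulary is hypothetical, and breaking frequency ties alphabetically is an assumption, chosen because it reproduces the vocabulary {"c", "a", "b", "e", "d", "f"} expected by testGetModelData. Like testInvalidMinMaxDF, the sketch rejects maxDF < minDF once both thresholds are resolved to document counts.

```python
from collections import Counter
from typing import List


def build_vocabulary(docs: List[List[str]], vocabulary_size: int = 1 << 18,
                     min_df: float = 1.0,
                     max_df: float = float(2**63 - 1)) -> List[str]:
    """Returns the vocabulary, most frequent term first."""
    num_docs = len(docs)
    if num_docs == 0:
        raise ValueError("The training set is empty.")
    # A threshold >= 1 is a document count; one in [0, 1) is a fraction of num_docs.
    min_count = min_df if min_df >= 1.0 else min_df * num_docs
    max_count = max_df if max_df >= 1.0 else max_df * num_docs
    if max_count < min_count:
        raise ValueError("maxDF must be >= minDF.")
    term_freq = Counter()  # total occurrences across the corpus
    doc_freq = Counter()   # number of documents containing the term
    for doc in docs:
        term_freq.update(doc)
        doc_freq.update(set(doc))
    kept = [t for t in term_freq if min_count <= doc_freq[t] <= max_count]
    # Highest corpus frequency first; equal frequencies sorted alphabetically
    # (an assumption, not necessarily the Java tie-break).
    kept.sort(key=lambda t: (-term_freq[t], t))
    return kept[:vocabulary_size]
```

On the five test documents, min_df=2 with max_df=4 (or the fractions 0.4 and 0.8, which resolve to 2 and 4 documents) keeps c, a, b and e. This is why testMinMaxDF expects vectors of size 4.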
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
new file mode 100644
index 0000000..e81c889
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/CountVectorizerTest.java
@@ -0,0 +1,390 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizer;
+import org.apache.flink.ml.feature.countvectorizer.CountVectorizerModel;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+import java.util.stream.DoubleStream;
+import java.util.stream.IntStream;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+import static org.junit.Assert.fail;
+
+/** Tests {@link CountVectorizer} and {@link CountVectorizerModel}. */
+public class CountVectorizerTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table inputTable;
+
+ private static final double EPS = 1.0e-5;
+ private static final List<Row> INPUT_DATA =
+ new ArrayList<>(
+ Arrays.asList(
+ Row.of((Object) new String[] {"a", "c", "b", "c"}),
+ Row.of((Object) new String[] {"c", "d", "e"}),
+ Row.of((Object) new String[] {"a", "b", "c"}),
+ Row.of((Object) new String[] {"e", "f"}),
+ Row.of((Object) new String[] {"a", "c", "a"})));
+
+ private static final List<SparseVector> EXPECTED_OUTPUT =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(2.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 3, 4).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(3, 5).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(1.0, 2.0).toArray())));
+
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ env.getConfig().enableObjectReuse();
+ tEnv = StreamTableEnvironment.create(env);
+
+ inputTable = tEnv.fromDataStream(env.fromCollection(INPUT_DATA)).as("input");
+ }
+
+ private static void verifyPredictionResult(
+ Table output, String outputCol, List<SparseVector> expected) throws Exception {
+ StreamTableEnvironment tEnv =
+ (StreamTableEnvironment) ((TableImpl) output).getTableEnvironment();
+ DataStream<SparseVector> stream =
+ tEnv.toDataStream(output)
+ .map(
+ (MapFunction<Row, SparseVector>)
+ row -> (SparseVector) row.getField(outputCol));
+ List<SparseVector> result = IteratorUtils.toList(stream.executeAndCollect());
+ compareResultCollections(expected, result, TestUtils::compare);
+ }
+
+ @Test
+ public void testParam() {
+ CountVectorizer countVectorizer = new CountVectorizer();
+ assertEquals("input", countVectorizer.getInputCol());
+ assertEquals("output", countVectorizer.getOutputCol());
+ assertEquals((double) Long.MAX_VALUE, countVectorizer.getMaxDF(), EPS);
+ assertEquals(1.0, countVectorizer.getMinDF(), EPS);
+ assertEquals(1.0, countVectorizer.getMinTF(), EPS);
+ assertEquals(1 << 18, countVectorizer.getVocabularySize());
+ assertFalse(countVectorizer.getBinary());
+
+ countVectorizer
+ .setInputCol("test_input")
+ .setOutputCol("test_output")
+ .setMinDF(0.1)
+ .setMaxDF(0.9)
+ .setMinTF(10)
+ .setVocabularySize(1000)
+ .setBinary(true);
+ assertEquals("test_input", countVectorizer.getInputCol());
+ assertEquals("test_output", countVectorizer.getOutputCol());
+ assertEquals(0.9, countVectorizer.getMaxDF(), EPS);
+ assertEquals(0.1, countVectorizer.getMinDF(), EPS);
+ assertEquals(10, countVectorizer.getMinTF(), EPS);
+ assertEquals(1000, countVectorizer.getVocabularySize());
+ assertTrue(countVectorizer.getBinary());
+ }
+
+ @Test
+ public void testInvalidMinMaxDF() {
+ String errMessage = "maxDF must be >= minDF.";
+ CountVectorizer countVectorizer = new CountVectorizer();
+ countVectorizer.setMaxDF(0.1);
+ countVectorizer.setMinDF(0.2);
+ try {
+ countVectorizer.fit(inputTable);
+ fail();
+ } catch (Throwable e) {
+ assertEquals(errMessage, e.getMessage());
+ }
+ countVectorizer.setMaxDF(1);
+ countVectorizer.setMinDF(2);
+ try {
+ countVectorizer.fit(inputTable);
+ fail();
+ } catch (Throwable e) {
+ assertEquals(errMessage, e.getMessage());
+ }
+ countVectorizer.setMaxDF(1);
+ countVectorizer.setMinDF(0.9);
+ try {
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ output.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage());
+ }
+ countVectorizer.setMaxDF(0.1);
+ countVectorizer.setMinDF(10);
+ try {
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ output.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals(errMessage, ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ public void testOutputSchema() {
+ CountVectorizer countVectorizer =
+ new CountVectorizer().setInputCol("test_input").setOutputCol("test_output");
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable.as("test_input"))[0];
+ assertEquals(
+ Arrays.asList("test_input", "test_output"),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ @Test
+ public void testFitAndPredict() throws Exception {
+ CountVectorizer countVectorizer = new CountVectorizer();
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+ }
+
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+ CountVectorizer countVectorizer = new CountVectorizer();
+ CountVectorizer loadedCountVectorizer =
+ TestUtils.saveAndReload(
+ tEnv, countVectorizer, tempFolder.newFolder().getAbsolutePath());
+ CountVectorizerModel model = loadedCountVectorizer.fit(inputTable);
+ CountVectorizerModel loadedModel =
+ TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+ assertEquals(
+ Arrays.asList("vocabulary"),
+ loadedModel.getModelData()[0].getResolvedSchema().getColumnNames());
+ Table output = loadedModel.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+ }
+
+ @Test
+ public void testFitOnEmptyData() {
+ Table emptyTable =
+ tEnv.fromDataStream(env.fromCollection(INPUT_DATA).filter(x -> x.getArity() == 0))
+ .as("input");
+ CountVectorizer countVectorizer = new CountVectorizer();
+ CountVectorizerModel model = countVectorizer.fit(emptyTable);
+ Table modelDataTable = model.getModelData()[0];
+ try {
+ modelDataTable.execute().print();
+ fail();
+ } catch (Throwable e) {
+ assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+ }
+ }
+
+ @Test
+ public void testMinMaxDF() throws Exception {
+ List<SparseVector> expectedOutput =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(
+ 4,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(2.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 4,
+ IntStream.of(0, 3).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 4,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 4,
+ IntStream.of(3).toArray(),
+ DoubleStream.of(1.0).toArray()),
+ Vectors.sparse(
+ 4,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(1.0, 2.0).toArray())));
+ CountVectorizer countVectorizer = new CountVectorizer().setMinDF(2).setMaxDF(4);
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+
+ countVectorizer.setMinDF(0.4).setMaxDF(0.8);
+ model = countVectorizer.fit(inputTable);
+ output = model.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+ }
+
+ @Test
+ public void testMinTF() throws Exception {
+ List<SparseVector> expectedOutput =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(
+ 6,
+ IntStream.of(0).toArray(),
+ DoubleStream.of(2.0).toArray()),
+ Vectors.sparse(6, new int[0], new double[0]),
+ Vectors.sparse(6, new int[0], new double[0]),
+ Vectors.sparse(
+ 6,
+ IntStream.of(3, 5).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(1).toArray(),
+ DoubleStream.of(2.0).toArray())));
+ CountVectorizer countVectorizer = new CountVectorizer().setMinTF(0.5);
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+ }
+
+ @Test
+ public void testBinary() throws Exception {
+ List<SparseVector> expectedOutput =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 3, 4).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1, 2).toArray(),
+ DoubleStream.of(1.0, 1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(3, 5).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray()),
+ Vectors.sparse(
+ 6,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray())));
+ CountVectorizer countVectorizer = new CountVectorizer().setBinary(true);
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+ }
+
+ @Test
+ public void testVocabularySize() throws Exception {
+ List<SparseVector> expectedOutput =
+ new ArrayList<>(
+ Arrays.asList(
+ Vectors.sparse(
+ 2,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(2.0, 1.0).toArray()),
+ Vectors.sparse(
+ 2,
+ IntStream.of(0).toArray(),
+ DoubleStream.of(1.0).toArray()),
+ Vectors.sparse(
+ 2,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(1.0, 1.0).toArray()),
+ Vectors.sparse(2, new int[0], new double[0]),
+ Vectors.sparse(
+ 2,
+ IntStream.of(0, 1).toArray(),
+ DoubleStream.of(1.0, 2.0).toArray())));
+ CountVectorizer countVectorizer = new CountVectorizer().setVocabularySize(2);
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table output = model.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), expectedOutput);
+ }
+
+ @Test
+ public void testGetModelData() throws Exception {
+ CountVectorizer countVectorizer = new CountVectorizer();
+ CountVectorizerModel model = countVectorizer.fit(inputTable);
+ Table modelData = model.getModelData()[0];
+ assertEquals(Arrays.asList("vocabulary"), modelData.getResolvedSchema().getColumnNames());
+
+ DataStream<Row> output = tEnv.toDataStream(modelData);
+ List<Row> modelRows = IteratorUtils.toList(output.executeAndCollect());
+ String[] vocabulary = (String[]) modelRows.get(0).getField(0);
+ String[] expectedVocabulary = {"c", "a", "b", "e", "d", "f"};
+ assertArrayEquals(expectedVocabulary, vocabulary);
+ }
+
+ @Test
+ public void testSetModelData() throws Exception {
+ CountVectorizer countVectorizer = new CountVectorizer();
+ CountVectorizerModel modelA = countVectorizer.fit(inputTable);
+ Table modelData = modelA.getModelData()[0];
+ CountVectorizerModel modelB = new CountVectorizerModel().setModelData(modelData);
+ Table output = modelB.transform(inputTable)[0];
+ verifyPredictionResult(output, countVectorizer.getOutputCol(), EXPECTED_OUTPUT);
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py b/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py
new file mode 100644
index 0000000..347d40d
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/countvectorizer_example.py
@@ -0,0 +1,62 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a CountVectorizer instance and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input training and prediction data.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (1, ['a', 'c', 'b', 'c'],),
+ (2, ['c', 'd', 'e'],),
+ (3, ['a', 'b', 'c'],),
+ (4, ['e', 'f'],),
+ (5, ['a', 'c', 'a'],),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'input', ],
+ [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])
+ ))
+
+# Creates a CountVectorizer object and initializes its parameters.
+count_vectorizer = CountVectorizer()
+
+# Trains the CountVectorizer Model.
+model = count_vectorizer.fit(input_table)
+
+# Uses the CountVectorizer Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+input_index = field_names.index(count_vectorizer.get_input_col())
+output_index = field_names.index(count_vectorizer.get_output_col())
+for result in t_env.to_data_stream(output).execute_and_collect():
+    print('Input Value: %-20s Output Value: %10s' %
+          (str(result[input_index]), str(result[output_index])))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py b/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py
new file mode 100644
index 0000000..5388474
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/countvectorizer.py
@@ -0,0 +1,189 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+from pyflink.ml.lib.param import HasOutputCol, HasInputCol
+
+from pyflink.ml.core.wrapper import JavaWithParams
+
+from pyflink.ml.core.param import FloatParam, BooleanParam, ParamValidators, IntParam
+
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+
+
+class _CountVectorizerModelParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol,
+):
+ """
+ Params for :class:`CountVectorizerModel`.
+ """
+ MIN_TF: FloatParam = FloatParam(
+ "min_t_f",
+ "Filter to ignore rare words in a document. For each document, "
+        "terms with frequency/count less than the given threshold are ignored. "
+ "If this is an integer >= 1, then this specifies a count (of times "
+ "the term must appear in the document); if this is a double in [0,1), "
+ "then this specifies a fraction (out of the document's token count).",
+ 1.0,
+ ParamValidators.gt_eq(0.0)
+ )
+
+ BINARY: BooleanParam = BooleanParam(
+ "binary",
+        "Binary toggle to control the output vector values. If True, all "
+        "nonzero counts (after the minTF filter is applied) are set to 1.0.",
+ False
+ )
+
+ def __init__(self, java_params):
+ super(_CountVectorizerModelParams, self).__init__(java_params)
+
+ def set_min_tf(self, value: float):
+ return typing.cast(_CountVectorizerModelParams,
+ self.set(self.MIN_TF, float(value)))
+
+ def get_min_tf(self):
+ return self.get(self.MIN_TF)
+
+ def set_binary(self, value: bool):
+ return typing.cast(_CountVectorizerModelParams, self.set(self.BINARY, value))
+
+ def get_binary(self):
+ return self.get(self.BINARY)
+
+ @property
+ def min_tf(self):
+ return self.get_min_tf()
+
+ @property
+ def binary(self):
+ return self.get_binary()
+
+
+class _CountVectorizerParams(_CountVectorizerModelParams):
+ """
+ Params for :class:`CountVectorizer`.
+ """
+ VOCABULARY_SIZE: IntParam = IntParam(
+ "vocabulary_size",
+ "Max size of the vocabulary. CountVectorizer will build a vocabulary "
+ "that only considers the top vocabularySize terms ordered by term "
+ "frequency across the corpus.",
+ 1 << 18,
+ ParamValidators.gt(0)
+ )
+
+ MIN_DF: FloatParam = FloatParam(
+ "min_d_f",
+        "Specifies the minimum number of different documents a term must "
+ "appear in to be included in the vocabulary. If this is an "
+ "integer >= 1, this specifies the number of documents the term must "
+ "appear in; if this is a double in [0,1), then this specifies the "
+ "fraction of documents.",
+ 1.0,
+ ParamValidators.gt_eq(0.0)
+ )
+
+ MAX_DF: FloatParam = FloatParam(
+ "max_d_f",
+ "Specifies the maximum number of different documents a term could "
+        "appear in to be included in the vocabulary. A term that appears in "
+        "more documents than the threshold will be ignored. If this is an integer >= 1, "
+ "this specifies the maximum number of documents the term could "
+ "appear in; if this is a double in [0,1), then this specifies the "
+ "maximum fraction of documents the term could appear in.",
+ float(2**63 - 1),
+ ParamValidators.gt_eq(0.0)
+ )
+
+ def __init__(self, java_params):
+ super(_CountVectorizerParams, self).__init__(java_params)
+
+    def set_vocabulary_size(self, value: int):
+        return typing.cast(_CountVectorizerParams, self.set(self.VOCABULARY_SIZE, value))
+
+    def get_vocabulary_size(self) -> int:
+ return self.get(self.VOCABULARY_SIZE)
+
+ def set_min_df(self, value: float):
+ return typing.cast(_CountVectorizerParams, self.set(self.MIN_DF, float(value)))
+
+ def get_min_df(self):
+ return self.get(self.MIN_DF)
+
+ def set_max_df(self, value: float):
+ return typing.cast(_CountVectorizerParams, self.set(self.MAX_DF, float(value)))
+
+ def get_max_df(self):
+ return self.get(self.MAX_DF)
+
+ @property
+ def vocabulary_size(self):
+ return self.get_vocabulary_size()
+
+ @property
+ def min_df(self):
+ return self.get_min_df()
+
+ @property
+ def max_df(self):
+ return self.get_max_df()
+
+
+class CountVectorizerModel(JavaFeatureModel, _CountVectorizerModelParams):
+ """
+ A Model which transforms data using the model data computed by CountVectorizer.
+ """
+
+ def __init__(self, java_model=None):
+ super(CountVectorizerModel, self).__init__(java_model)
+
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "countvectorizer"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "CountVectorizerModel"
+
+
+class CountVectorizer(JavaFeatureEstimator, _CountVectorizerParams):
+ """
+ An Estimator which converts a collection of text documents
+ to vectors of token counts. When an a-priori dictionary is not available,
+    CountVectorizer can be used as an estimator to extract the vocabulary
+    and generate a CountVectorizerModel. The model produces sparse
+ representations for the documents over the vocabulary, which can then
+ be passed to other algorithms like LDA.
+ """
+
+ def __init__(self):
+ super(CountVectorizer, self).__init__()
+
+ @classmethod
+ def _create_model(cls, java_model) -> CountVectorizerModel:
+ return CountVectorizerModel(java_model)
+
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "countvectorizer"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "CountVectorizer"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py
new file mode 100644
index 0000000..f878236
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_countvectorizer.py
@@ -0,0 +1,144 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+from typing import List
+
+from pyflink.common import Types
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+from pyflink.ml.core.linalg import Vectors, DenseVector
+
+from pyflink.ml.lib.feature.countvectorizer import CountVectorizer, CountVectorizerModel
+from pyflink.table import Table
+
+
+class CountVectorizerTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(CountVectorizerTest, self).setUp()
+ self.input_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (1, ['a', 'c', 'b', 'c'],),
+ (2, ['c', 'd', 'e'],),
+ (3, ['a', 'b', 'c'],),
+ (4, ['e', 'f'],),
+ (5, ['a', 'c', 'a'],),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['id', 'input', ],
+ [Types.INT(), Types.OBJECT_ARRAY(Types.STRING())])))
+
+ self.expected_output = [
+ Vectors.sparse(6, [0, 1, 2], [2.0, 1.0, 1.0]),
+ Vectors.sparse(6, [0, 3, 4], [1.0, 1.0, 1.0]),
+ Vectors.sparse(6, [0, 1, 2], [1.0, 1.0, 1.0]),
+ Vectors.sparse(6, [3, 5], [1.0, 1.0]),
+ Vectors.sparse(6, [0, 1], [1.0, 2.0]),
+ ]
+
+    def test_param(self):
+        count_vectorizer = CountVectorizer()
+        self.assertEqual('input', count_vectorizer.input_col)
+        self.assertEqual('output', count_vectorizer.output_col)
+        self.assertEqual(1, count_vectorizer.min_df)
+        self.assertEqual(float(2**63 - 1), count_vectorizer.max_df)
+        self.assertEqual(1, count_vectorizer.min_tf)
+        self.assertEqual(1 << 18, count_vectorizer.vocabulary_size)
+        self.assertFalse(count_vectorizer.binary)
+
+        count_vectorizer.\
+            set_input_col('test_input').\
+            set_output_col('test_output').\
+            set_min_df(0.1).\
+            set_max_df(0.9).\
+            set_min_tf(10).\
+            set_vocabulary_size(1000).\
+            set_binary(True)
+        self.assertEqual('test_input', count_vectorizer.input_col)
+        self.assertEqual('test_output', count_vectorizer.output_col)
+        self.assertEqual(0.1, count_vectorizer.min_df)
+        self.assertEqual(0.9, count_vectorizer.max_df)
+        self.assertEqual(10, count_vectorizer.min_tf)
+        self.assertEqual(1000, count_vectorizer.vocabulary_size)
+        self.assertTrue(count_vectorizer.binary)
+
+    def test_output_schema(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        output = model.transform(self.input_table.alias('id', 'input'))[0]
+        self.assertEqual(
+            ['id', 'input', 'output'],
+            output.get_schema().get_field_names())
+
+    def test_fit_and_predict(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        output = model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def test_save_load_predict(self):
+        count_vectorizer = CountVectorizer()
+        reloaded_count_vectorizer = self.save_and_reload(count_vectorizer)
+        model = reloaded_count_vectorizer.fit(self.input_table)
+        reloaded_model = self.save_and_reload(model)
+        output = reloaded_model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def test_get_model_data(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+        model_data_table = model.get_model_data()[0]
+        self.assertEqual(["vocabulary"],
+                         model_data_table.get_schema().get_field_names())
+        model_data = self.t_env.to_data_stream(model_data_table).execute_and_collect().next()
+        expected = ["c", "a", "b", "e", "d", "f"]
+        self.assertEqual(expected, model_data[0])
+
+    def test_set_model_data(self):
+        count_vectorizer = CountVectorizer()
+        model = count_vectorizer.fit(self.input_table)
+
+        new_model = CountVectorizerModel()
+        new_model.set_model_data(*model.get_model_data())
+        output = new_model.transform(self.input_table)[0]
+        self.verify_output_result(
+            output,
+            count_vectorizer.get_output_col(),
+            output.get_schema().get_field_names(),
+            self.expected_output)
+
+    def verify_output_result(
+            self, output: Table,
+            output_col: str,
+            field_names: List[str],
+            expected_result: List[DenseVector]):
+        collected_results = [result for result in
+                             self.t_env.to_data_stream(output).execute_and_collect()]
+        results = []
+        for item in collected_results:
+            item.set_field_names(field_names)
+            results.append(item)
+        results.sort(key=lambda x: x[0])
+        results = list(map(lambda x: x[output_col], results))
+        self.assertEqual(expected_result, results)
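The vocabulary `["c", "a", "b", "e", "d", "f"]` and the sparse vectors in `expected_output` in the test above can be checked by hand. Below is a minimal pure-Python sketch for doing so; it is not the Flink ML implementation. `fit_vocabulary` and `transform` are hypothetical helpers, and sparse vectors are modeled as `(size, indices, values)` tuples instead of PyFlink ML vectors. The sketch assumes, following Spark ML's CountVectorizer, that terms are ranked by their total count across the corpus, and that `min_df`, `max_df` and `min_tf` values >= 1 are absolute counts while values in [0, 1) are fractions. It also assumes ties are broken alphabetically.

```python
from collections import Counter
from typing import Dict, List, Sequence, Tuple

# Hypothetical stand-in for a sparse vector: (size, indices, values).
Sparse = Tuple[int, List[int], List[float]]


def fit_vocabulary(docs: Sequence[Sequence[str]],
                   min_df: float = 1.0,
                   max_df: float = float(2**63 - 1),
                   vocabulary_size: int = 1 << 18) -> List[str]:
    """Returns terms ranked by total count across the corpus, highest first."""
    num_docs = len(docs)
    # Assumption: a value >= 1 is an absolute number of documents, and a
    # value in [0, 1) is a fraction of the corpus size.
    min_docs = min_df if min_df >= 1.0 else min_df * num_docs
    max_docs = max_df if max_df >= 1.0 else max_df * num_docs

    term_counts: Counter = Counter()  # occurrences of each term in the corpus
    doc_freqs: Counter = Counter()    # number of documents containing each term
    for doc in docs:
        term_counts.update(doc)
        doc_freqs.update(set(doc))

    kept = [t for t in term_counts if min_docs <= doc_freqs[t] <= max_docs]
    # Assumption: ties are broken alphabetically to make the order deterministic.
    kept.sort(key=lambda t: (-term_counts[t], t))
    return kept[:vocabulary_size]


def transform(doc: Sequence[str],
              vocabulary: List[str],
              min_tf: float = 1.0,
              binary: bool = False) -> Sparse:
    """Maps one document to a sparse vector of term counts over the vocabulary."""
    index: Dict[str, int] = {term: i for i, term in enumerate(vocabulary)}
    # Out-of-vocabulary terms are dropped.
    counts = Counter(t for t in doc if t in index)
    # Assumption: a value >= 1 is an absolute count, and a value in [0, 1)
    # is a fraction of the document's token count.
    threshold = min_tf if min_tf >= 1.0 else min_tf * len(doc)
    entries = sorted((index[t], 1.0 if binary else float(c))
                     for t, c in counts.items() if c >= threshold)
    return len(vocabulary), [i for i, _ in entries], [v for _, v in entries]


if __name__ == "__main__":
    docs = [['a', 'c', 'b', 'c'], ['c', 'd', 'e'], ['a', 'b', 'c'],
            ['e', 'f'], ['a', 'c', 'a']]
    vocabulary = fit_vocabulary(docs)
    print(vocabulary)  # ['c', 'a', 'b', 'e', 'd', 'f']
    for doc in docs:
        print(transform(doc, vocabulary))
```

With the defaults, `c` (5 occurrences) and `a` (4) take indices 0 and 1. The first document `['a', 'c', 'b', 'c']` therefore becomes `(6, [0, 1, 2], [2.0, 1.0, 1.0])`, which matches the first entry of `expected_output`.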