You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/09/07 03:32:51 UTC
[flink-ml] branch master updated: [FLINK-28806] Add Transformer and Estimator for IDF
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new d224f56 [FLINK-28806] Add Transformer and Estimator for IDF
d224f56 is described below
commit d224f565dd7adfa59f10ae208c49159abf106635
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Sep 7 11:32:46 2022 +0800
[FLINK-28806] Add Transformer and Estimator for IDF
This closes #150.
---
docs/content/docs/operators/feature/idf.md | 172 ++++++++++++++++
.../ml/common/datastream/DataStreamUtils.java | 131 +++++++++++++
.../flink/ml/common/datastream/TableUtils.java | 11 ++
.../ml/common/datastream/DataStreamUtilsTest.java | 34 ++++
.../flink/ml/examples/feature/IDFExample.java | 64 ++++++
.../java/org/apache/flink/ml/feature/idf/IDF.java | 169 ++++++++++++++++
.../org/apache/flink/ml/feature/idf/IDFModel.java | 149 ++++++++++++++
.../apache/flink/ml/feature/idf/IDFModelData.java | 124 ++++++++++++
.../flink/ml/feature/idf/IDFModelParams.java | 29 +++
.../org/apache/flink/ml/feature/idf/IDFParams.java | 45 +++++
.../java/org/apache/flink/ml/feature/IDFTest.java | 216 +++++++++++++++++++++
.../pyflink/examples/ml/feature/idf_example.py | 60 ++++++
flink-ml-python/pyflink/ml/lib/feature/idf.py | 106 ++++++++++
.../pyflink/ml/lib/feature/tests/test_idf.py | 128 ++++++++++++
14 files changed, 1438 insertions(+)
diff --git a/docs/content/docs/operators/feature/idf.md b/docs/content/docs/operators/feature/idf.md
new file mode 100644
index 0000000..5a29a4c
--- /dev/null
+++ b/docs/content/docs/operators/feature/idf.md
@@ -0,0 +1,172 @@
+---
+title: "IDF"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/IDF.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## IDF
+
+IDF computes the inverse document frequency (IDF) for the
+input documents. IDF is computed following
+`idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
+number of documents and `d(t)` is the number of documents
+that contain the term `t`.
+
+IDFModel further uses the computed inverse document frequency
+to compute [tf-idf](https://en.wikipedia.org/wiki/Tf%E2%80%93idf).
+
+### Input Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:----------|:-----------------|
+| inputCol | Vector | `"input"` | Input documents. |
+
+### Output Columns
+
+| Param name | Type | Default | Description |
+|:-----------|:-------|:-----------|:----------------------------|
+| outputCol | Vector | `"output"` | Tf-idf values of the input. |
+
+### Parameters
+
+Below are the parameters required by `IDFModel`.
+
+| Key | Default | Type | Required | Description |
+|:----------|:-----------|:-------|:---------|:--------------------|
+| inputCol | `"input"` | String | no | Input column name. |
+| outputCol | `"output"` | String | no | Output column name. |
+
+
+`IDF` needs the parameters above and also the one below.
+
+| Key | Default | Type | Required | Description |
+|:-----------|:-----------|:--------|:---------|:---------------------------------------------------------------------|
+| minDocFreq | `0` | Integer | no | Minimum number of documents in which a term should appear for filtering. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains an IDF model and uses it for feature engineering.
+ *
+ * <p>The program fits an IDF estimator on three dense vectors, transforms the same vectors with
+ * the fitted model and prints every input vector next to its output vector.
+ */
+public class IDFExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(0, 1, 0, 2)),
+                        Row.of(Vectors.dense(0, 1, 2, 3)),
+                        Row.of(Vectors.dense(0, 1, 0, 0)));
+
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates an IDF object and initializes its parameters.
+        IDF idf = new IDF().setMinDocFreq(2);
+
+        // Trains the IDF Model.
+        IDFModel model = idf.fit(inputTable);
+
+        // Uses the IDF Model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results. The iterator is opened in a try-with-resources
+        // statement so that it is always closed, even if printing a row fails.
+        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
+                DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+                System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+            }
+        } catch (RuntimeException e) {
+            // Failures of the job itself are propagated unchanged.
+            throw e;
+        } catch (Exception e) {
+            // Closing the iterator may raise a checked exception, which main() does not declare.
+            throw new IllegalStateException("Failed to close the result iterator.", e);
+        }
+    }
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that trains an IDF model and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.idf import IDF
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(0, 1, 0, 2),),
+ (Vectors.dense(0, 1, 2, 3),),
+ (Vectors.dense(0, 1, 0, 0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+# Creates an IDF object and initializes its parameters.
+idf = IDF().set_min_doc_freq(2)
+
+# Trains the IDF Model.
+model = idf.fit(input_table)
+
+# Uses the IDF Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_index = field_names.index(idf.get_input_col())
+ output_index = field_names.index(idf.get_output_col())
+ print('Input Value: ' + str(result[input_index]) +
+ '\tOutput Value: ' + str(result[output_index]))
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index f8df7e3..b34dbdf 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -19,6 +19,7 @@
package org.apache.flink.ml.common.datastream;
import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.FlatMapFunction;
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
@@ -116,6 +117,36 @@ public class DataStreamUtils {
}
}
+    /**
+     * Aggregates all elements of a bounded stream with the given {@link AggregateFunction}.
+     *
+     * <p>The input is first folded into accumulators by a "partialAggregate" operator; the
+     * accumulators are then merged and converted into the final result by an "aggregate" operator
+     * whose parallelism is set to one.
+     *
+     * @param input The input data stream.
+     * @param func The user defined aggregate function.
+     * @param <IN> The class type of the input.
+     * @param <ACC> The class type of the accumulated values.
+     * @param <OUT> The class type of the output values.
+     * @return The result data stream.
+     */
+    public static <IN, ACC, OUT> DataStream<OUT> aggregate(
+            DataStream<IN> input, AggregateFunction<IN, ACC, OUT> func) {
+        // Resolve the accumulator type first and the result type second.
+        TypeInformation<ACC> accumulatorTypeInfo =
+                TypeExtractor.getAggregateFunctionAccumulatorType(
+                        func, input.getType(), null, true);
+        TypeInformation<OUT> resultTypeInfo =
+                TypeExtractor.getAggregateFunctionReturnType(func, input.getType(), null, true);
+
+        // Stage one: fold the input elements into accumulators.
+        PartialAggregateOperator<IN, ACC, OUT> partialOperator =
+                new PartialAggregateOperator<>(func, accumulatorTypeInfo);
+        DataStream<ACC> partials =
+                input.transform("partialAggregate", accumulatorTypeInfo, partialOperator);
+
+        // Stage two: merge the accumulators in a single task and emit the result.
+        AggregateOperator<IN, ACC, OUT> finalOperator =
+                new AggregateOperator<>(func, accumulatorTypeInfo);
+        DataStream<OUT> result = partials.transform("aggregate", resultTypeInfo, finalOperator);
+        result.getTransformation().setParallelism(1);
+
+        return result;
+    }
+
/**
* Performs a uniform sampling over the elements in a bounded data stream.
*
@@ -263,6 +294,106 @@ public class DataStreamUtils {
}
}
+ /**
+ * A stream operator to apply {@link AggregateFunction#add(IN, ACC)} on each partition of the
+ * input bounded data stream.
+ */
+ private static class PartialAggregateOperator<IN, ACC, OUT>
+ extends AbstractUdfStreamOperator<ACC, AggregateFunction<IN, ACC, OUT>>
+ implements OneInputStreamOperator<IN, ACC>, BoundedOneInput {
+ /** Type information of the accumulated result. */
+ private final TypeInformation<ACC> accType;
+ /** The accumulated result of the aggregate function in one partition. */
+ private ACC acc;
+ /** State of acc. */
+ private ListState<ACC> accState;
+
+ public PartialAggregateOperator(
+ AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
+ super(userFunction);
+ this.accType = accType;
+ }
+
+ @Override
+ public void endInput() {
+ output.collect(new StreamRecord<>(acc));
+ }
+
+ @Override
+ public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+ acc = userFunction.add(streamRecord.getValue(), acc);
+ }
+
+ @Override
+ public void initializeState(StateInitializationContext context) throws Exception {
+ super.initializeState(context);
+ accState =
+ context.getOperatorStateStore()
+ .getListState(new ListStateDescriptor<>("accState", accType));
+ acc =
+ OperatorStateUtils.getUniqueElement(accState, "accState")
+ .orElse(userFunction.createAccumulator());
+ }
+
+ @Override
+ public void snapshotState(StateSnapshotContext context) throws Exception {
+ super.snapshotState(context);
+ accState.clear();
+ accState.add(acc);
+ }
+ }
+
+    /**
+     * A stream operator that combines the accumulators it receives from the bounded input stream
+     * using {@link AggregateFunction#merge} and, once the input has ended, emits the value
+     * returned by {@link AggregateFunction#getResult} for the combined accumulator.
+     */
+    private static class AggregateOperator<IN, ACC, OUT>
+            extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>>
+            implements OneInputStreamOperator<ACC, OUT>, BoundedOneInput {
+        /** Type information of the accumulated result. */
+        private final TypeInformation<ACC> accType;
+        /**
+         * The accumulated result of the aggregate function in the final partition. It is null
+         * until the first accumulator has been received or restored from state.
+         */
+        private ACC acc;
+        /** State of acc. */
+        private ListState<ACC> accState;
+
+        public AggregateOperator(
+                AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
+            super(userFunction);
+            this.accType = accType;
+        }
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(userFunction.getResult(acc)));
+        }
+
+        @Override
+        public void processElement(StreamRecord<ACC> streamRecord) throws Exception {
+            if (acc == null) {
+                // The first accumulator is adopted as is; there is nothing to merge it with yet.
+                acc = streamRecord.getValue();
+            } else {
+                acc = userFunction.merge(streamRecord.getValue(), acc);
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("accState", accType));
+            acc = OperatorStateUtils.getUniqueElement(accState, "accState").orElse(null);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accState.clear();
+            // A checkpoint may be taken before any accumulator has arrived. In that case acc is
+            // still null and must not be added, because list state does not accept null
+            // elements; an empty state restores to a null acc in initializeState().
+            if (acc != null) {
+                accState.add(acc);
+            }
+        }
+    }
+
/**
* Splits the input data into global batches of batchSize. After splitting, each global batch is
* further split into local batches for downstream operators with each worker has one batch.
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index 4dc175a..9245c77 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -43,6 +43,17 @@ public class TableUtils {
return new RowTypeInfo(types, names);
}
+ // Retrieves the TypeInformation of a column by name. Returns null if the name does not exist in
+ // the input schema.
+ public static TypeInformation<?> getTypeInfoByName(ResolvedSchema schema, String name) {
+ for (Column column : schema.getColumns()) {
+ if (column.getName().equals(name)) {
+ return TypeInformation.of(column.getDataType().getConversionClass());
+ }
+ }
+ return null;
+ }
+
public static StreamExecutionEnvironment getExecutionEnvironment(StreamTableEnvironment tEnv) {
Table table = tEnv.fromValues();
DataStream<Row> dataStream = tEnv.toDataStream(table);
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index a968a0e..fa48bb5 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -18,6 +18,7 @@
package org.apache.flink.ml.common.datastream;
+import org.apache.flink.api.common.functions.AggregateFunction;
import org.apache.flink.api.common.functions.MapPartitionFunction;
import org.apache.flink.api.common.functions.ReduceFunction;
import org.apache.flink.api.common.functions.RichMapPartitionFunction;
@@ -75,6 +76,16 @@ public class DataStreamUtilsTest {
assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray());
}
+    /** Checks that aggregating the numbers 0..19 yields exactly one record holding their sum. */
+    @Test
+    public void testAggregate() throws Exception {
+        DataStream<Long> numbers =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG);
+        DataStream<String> aggregated =
+                DataStreamUtils.aggregate(numbers, new TestAggregateFunc());
+
+        List<String> collected = IteratorUtils.toList(aggregated.executeAndCollect());
+
+        assertEquals(1, collected.size());
+        assertEquals("190", collected.get(0));
+    }
+
@Test
public void testGenerateBatchData() throws Exception {
DataStream<Long> dataStream =
@@ -99,4 +110,27 @@ public class DataStreamUtilsTest {
out.collect(cnt);
}
}
+
+    /**
+     * A simple {@link AggregateFunction} for tests: it sums the input values and reports the sum
+     * as a string.
+     */
+    private static class TestAggregateFunc implements AggregateFunction<Long, Long, String> {
+        @Override
+        public Long createAccumulator() {
+            return Long.valueOf(0L);
+        }
+
+        @Override
+        public Long add(Long element, Long acc) {
+            return Long.sum(element, acc);
+        }
+
+        @Override
+        public Long merge(Long acc1, Long acc2) {
+            return Long.sum(acc1, acc2);
+        }
+
+        @Override
+        public String getResult(Long acc) {
+            return String.valueOf(acc);
+        }
+    }
}
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
new file mode 100644
index 0000000..ffa4d71
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/**
+ * Simple program that trains an IDF model and uses it for feature engineering.
+ *
+ * <p>The program fits an {@link IDF} estimator on three dense vectors, transforms the same
+ * vectors with the fitted {@link IDFModel} and prints every input vector next to its output
+ * vector.
+ */
+public class IDFExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(0, 1, 0, 2)),
+                        Row.of(Vectors.dense(0, 1, 2, 3)),
+                        Row.of(Vectors.dense(0, 1, 0, 0)));
+
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates an IDF object and initializes its parameters.
+        IDF idf = new IDF().setMinDocFreq(2);
+
+        // Trains the IDF Model.
+        IDFModel model = idf.fit(inputTable);
+
+        // Uses the IDF Model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results. The iterator is opened in a try-with-resources
+        // statement so that it is always closed, even if printing a row fails.
+        try (CloseableIterator<Row> it = outputTable.execute().collect()) {
+            while (it.hasNext()) {
+                Row row = it.next();
+                DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
+                DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+                System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+            }
+        } catch (RuntimeException e) {
+            // Failures of the job itself are propagated unchanged.
+            throw e;
+        } catch (Exception e) {
+            // Closing the iterator may raise a checked exception, which main() does not declare.
+            throw new IllegalStateException("Failed to close the result iterator.", e);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
new file mode 100644
index 0000000..4a242f0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator that computes the inverse document frequency (IDF) for the input documents. IDF is
+ * computed following `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total number of documents
+ * and `d(t)` is the number of documents that contain the term `t`.
+ *
+ * <p>Users could filter out terms that appear in too few documents by setting {@link
+ * IDFParams#getMinDocFreq()}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
+ */
+public class IDF implements Estimator<IDF, IDFModel>, IDFParams<IDF> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public IDF() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Computes the IDF model data from the input column of the given table.
+     *
+     * @param inputs Exactly one table whose input column holds the document vectors.
+     * @return A model carrying the computed model data and the parameters of this estimator.
+     */
+    @Override
+    public IDFModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final String inputCol = getInputCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Vector> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, Vector>)
+                                        value -> ((Vector) value.getField(inputCol)));
+
+        DataStream<IDFModelData> modelData =
+                DataStreamUtils.aggregate(inputData, new IDFAggregator(getMinDocFreq()));
+
+        IDFModel model = new IDFModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    /**
+     * Loads an IDF estimator from the given path.
+     *
+     * @param tEnv The table environment; it is not used by this method.
+     * @param path The path the estimator was saved to.
+     * @return The loaded estimator.
+     */
+    public static IDF load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /**
+     * The main logic to compute the model data of IDF. The accumulator is a pair of the number of
+     * documents seen so far and, per term, the number of documents with a positive value for it.
+     */
+    private static class IDFAggregator
+            implements AggregateFunction<Vector, Tuple2<Long, DenseVector>, IDFModelData> {
+        private final int minDocFreq;
+
+        public IDFAggregator(int minDocFreq) {
+            this.minDocFreq = minDocFreq;
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> createAccumulator() {
+            return Tuple2.of(0L, new DenseVector(new double[0]));
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> add(
+                Vector vector, Tuple2<Long, DenseVector> numDocsAndDocFreq) {
+            if (numDocsAndDocFreq.f0 == 0) {
+                // The first document determines the dimension of the document frequencies.
+                numDocsAndDocFreq.f1 = new DenseVector(vector.size());
+            }
+            numDocsAndDocFreq.f0 += 1L;
+
+            // Turn a copy of the document into a 0/1 indicator vector. The copy keeps the input
+            // record unmodified, so other consumers of the same record still see the original
+            // term frequencies.
+            Vector indicator = vector.clone();
+            double[] values;
+            if (indicator instanceof SparseVector) {
+                values = ((SparseVector) indicator).values;
+            } else {
+                values = ((DenseVector) indicator).values;
+            }
+            for (int i = 0; i < values.length; i++) {
+                values[i] = values[i] > 0 ? 1 : 0;
+            }
+
+            BLAS.axpy(1, indicator, numDocsAndDocFreq.f1);
+
+            return numDocsAndDocFreq;
+        }
+
+        @Override
+        public IDFModelData getResult(Tuple2<Long, DenseVector> numDocsAndDocFreq) {
+            long numDocs = numDocsAndDocFreq.f0;
+            DenseVector docFreq = numDocsAndDocFreq.f1;
+            Preconditions.checkState(numDocs > 0, "The training set is empty.");
+
+            long[] filteredDocFreq = new long[docFreq.size()];
+            double[] df = docFreq.values;
+            double[] idf = new double[df.length];
+            for (int i = 0; i < idf.length; i++) {
+                // Terms below the minimum document frequency keep an idf and a document
+                // frequency of zero.
+                if (df[i] >= minDocFreq) {
+                    idf[i] = Math.log((numDocs + 1) / (df[i] + 1));
+                    filteredDocFreq[i] = (long) df[i];
+                }
+            }
+            return new IDFModelData(Vectors.dense(idf), filteredDocFreq, numDocs);
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> merge(
+                Tuple2<Long, DenseVector> numDocsAndDocFreq1,
+                Tuple2<Long, DenseVector> numDocsAndDocFreq2) {
+            // An accumulator that has seen no document still holds the zero-length vector from
+            // createAccumulator(), so it is skipped instead of being added.
+            if (numDocsAndDocFreq1.f0 == 0) {
+                return numDocsAndDocFreq2;
+            }
+
+            if (numDocsAndDocFreq2.f0 == 0) {
+                return numDocsAndDocFreq1;
+            }
+
+            numDocsAndDocFreq2.f0 += numDocsAndDocFreq1.f0;
+            BLAS.axpy(1, numDocsAndDocFreq1.f1, numDocsAndDocFreq2.f1);
+            return numDocsAndDocFreq2;
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java
new file mode 100644
index 0000000..87a2f25
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * A Model which transforms data using the model data computed by {@link IDF}.
+ *
+ * <p>For every input row, the vector in the input column is multiplied element-wise with the idf
+ * vector of the model data and the product is appended to the row as the output column.
+ */
+public class IDFModel implements Model<IDFModel>, IDFModelParams<IDFModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public IDFModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Appends the tf-idf vector of the input column to every row of the given table.
+     *
+     * @param inputs Exactly one table containing the input column.
+     * @return A single table holding the input columns followed by the output column.
+     */
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<IDFModelData> idfModelData = IDFModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(schema);
+        // The output column is declared with the same type as the input column.
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                TableUtils.getTypeInfoByName(schema, getInputCol())),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, idfModelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new ComputeTfIdfFunction(broadcastModelKey, getInputCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /**
+     * Sets the table holding the model data.
+     *
+     * @param inputs Exactly one table with the model data.
+     * @return This model.
+     */
+    @Override
+    public IDFModel setModelData(Table... inputs) {
+        // Validate the argument count like fit() and transform() do, instead of silently
+        // ignoring extra tables or failing with an index error on an empty array.
+        Preconditions.checkArgument(inputs.length == 1);
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                IDFModelData.getModelDataStream(modelDataTable),
+                path,
+                new IDFModelData.ModelDataEncoder());
+    }
+
+    /**
+     * Loads an IDF model, including its model data, from the given path.
+     *
+     * @param tEnv The table environment used to load the model data.
+     * @param path The path the model was saved to.
+     * @return The loaded model.
+     */
+    public static IDFModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        IDFModel model = ReadWriteUtils.loadStageParam(path);
+
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new IDFModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    /** Computes the tf-idf for each term in the input document. */
+    private static class ComputeTfIdfFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+        private final String broadcastKey;
+        /** The idf vector of the model data; it is read from the broadcast variable on first use. */
+        private DenseVector idf;
+
+        public ComputeTfIdfFunction(String broadcastKey, String inputCol) {
+            this.broadcastKey = broadcastKey;
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (idf == null) {
+                IDFModelData modelData =
+                        (IDFModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                idf = modelData.idf;
+            }
+
+            // The product is computed on a copy so that the input column keeps its value.
+            Vector outputVec = ((Vector) Objects.requireNonNull(row.getField(inputCol))).clone();
+            BLAS.hDot(idf, outputVec);
+            return Row.join(row, Row.of(outputVec));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java
new file mode 100644
index 0000000..a808454
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.array.LongPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link IDFModel}.
+ *
+ * <p>This class also provides methods to convert model data from Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class IDFModelData {
+    /** Inverse document frequency for all terms. */
+    public DenseVector idf;
+    /** Document frequency for all terms after filtering out infrequent terms. */
+    public long[] docFreq;
+    /** Number of docs in the training set. */
+    public long numDocs;
+
+    /** Creates an instance whose fields are left at their default values. */
+    public IDFModelData() {}
+
+    public IDFModelData(DenseVector idf, long[] docFreq, long numDocs) {
+        this.idf = idf;
+        this.docFreq = docFreq;
+        this.numDocs = numDocs;
+    }
+
+    /**
+     * Converts the table model to a data stream.
+     *
+     * <p>Fields 0, 1 and 2 of every row are read as {@code idf}, {@code docFreq} and {@code
+     * numDocs} respectively.
+     *
+     * @param modelDataTable The table model data.
+     * @return The data stream model data.
+     */
+    public static DataStream<IDFModelData> getModelDataStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(x -> new IDFModelData(x.getFieldAs(0), x.getFieldAs(1), x.getFieldAs(2)));
+    }
+
+    /**
+     * Encoder for {@link IDFModelData}.
+     *
+     * <p>Writes {@code idf}, {@code docFreq} and {@code numDocs} in that order; {@link
+     * ModelDataDecoder} reads them back in the same order.
+     */
+    public static class ModelDataEncoder implements Encoder<IDFModelData> {
+        private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+
+        @Override
+        public void encode(IDFModelData modelData, OutputStream outputStream) throws IOException {
+            DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+            denseVectorSerializer.serialize(modelData.idf, dataOutputView);
+            LongPrimitiveArraySerializer.INSTANCE.serialize(modelData.docFreq, dataOutputView);
+            LongSerializer.INSTANCE.serialize(modelData.numDocs, dataOutputView);
+        }
+    }
+
+    /** Decoder for {@link IDFModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<IDFModelData> {
+        @Override
+        public Reader<IDFModelData> createReader(Configuration config, FSDataInputStream stream) {
+            return new Reader<IDFModelData>() {
+                private final DenseVectorSerializer denseVectorSerializer =
+                        new DenseVectorSerializer();
+                // One view is shared by all records of this reader instead of allocating a
+                // new wrapper per read() call; the wrapper does not read ahead of what is
+                // requested, so sharing it does not change what is consumed from the stream.
+                private final DataInputView source = new DataInputViewStreamWrapper(stream);
+
+                /**
+                 * Reads the next record.
+                 *
+                 * @return The decoded model data, or {@code null} once the end of the stream
+                 *     is hit while decoding.
+                 */
+                @Override
+                public IDFModelData read() throws IOException {
+                    try {
+                        DenseVector idf = denseVectorSerializer.deserialize(source);
+                        long[] docFreq = LongPrimitiveArraySerializer.INSTANCE.deserialize(source);
+                        long numDocs = LongSerializer.INSTANCE.deserialize(source);
+                        return new IDFModelData(idf, docFreq, numDocs);
+                    } catch (EOFException e) {
+                        // End of input is signalled to the caller by returning null.
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<IDFModelData> getProducedType() {
+            return TypeInformation.of(IDFModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java
new file mode 100644
index 0000000..168efa5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link IDFModel}.
+ *
+ * <p>Combines the params inherited from {@link HasInputCol} and {@link HasOutputCol}; it
+ * declares no params of its own.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IDFModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java
new file mode 100644
index 0000000..14ca550
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link IDF}.
+ *
+ * <p>Adds {@link #MIN_DOC_FREQ} on top of the params inherited from {@link IDFModelParams}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IDFParams<T> extends IDFModelParams<T> {
+ /** Param {@code minDocFreq}: an integer that defaults to 0 and is validated to be non-negative. */
+ Param<Integer> MIN_DOC_FREQ =
+ new IntParam(
+ "minDocFreq",
+ "Minimum number of documents that a term should appear for filtering.",
+ 0,
+ ParamValidators.gtEq(0));
+
+ /** Returns the current value of {@link #MIN_DOC_FREQ}. */
+ default int getMinDocFreq() {
+ return get(MIN_DOC_FREQ);
+ }
+
+ /** Sets {@link #MIN_DOC_FREQ} to {@code value} and returns whatever {@code set} returns. */
+ default T setMinDocFreq(Integer value) {
+ return set(MIN_DOC_FREQ, value);
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
new file mode 100644
index 0000000..e4cae5a
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.feature.idf.IDFModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link IDF} and {@link IDFModel}. */
+public class IDFTest extends AbstractTestBase {
+ @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+ private StreamExecutionEnvironment env;
+ private StreamTableEnvironment tEnv;
+ private Table inputTable;
+
+ private static final List<DenseVector> expectedOutput =
+ Arrays.asList(
+ Vectors.dense(0, 0, 0, 0.5753641),
+ Vectors.dense(0, 0, 1.3862943, 0.8630462),
+ Vectors.dense(0, 0, 0, 0));
+ private static final List<DenseVector> expectedOutputMinDocFreqAsTwo =
+ Arrays.asList(
+ Vectors.dense(0, 0, 0, 0.5753641),
+ Vectors.dense(0, 0, 0, 0.8630462),
+ Vectors.dense(0, 0, 0, 0));
+ private static final double TOLERANCE = 1e-7;
+
+ /**
+  * Creates the execution and table environments (parallelism 4, checkpointing every 100 ms, no
+  * restarts) and the three-row input table with a single column named "input".
+  */
+ @Before
+ public void before() {
+     Configuration conf = new Configuration();
+     conf.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+
+     env = StreamExecutionEnvironment.getExecutionEnvironment(conf);
+     env.setParallelism(4);
+     env.enableCheckpointing(100);
+     env.setRestartStrategy(RestartStrategies.noRestart());
+     tEnv = StreamTableEnvironment.create(env);
+
+     List<DenseVector> documents =
+             Arrays.asList(
+                     Vectors.dense(0, 1, 0, 2),
+                     Vectors.dense(0, 1, 2, 3),
+                     Vectors.dense(0, 1, 0, 0));
+     Table documentTable = tEnv.fromDataStream(env.fromCollection(documents).map(x -> x));
+     inputTable = documentTable.as("input");
+ }
+
+ /**
+  * Collects column {@code predictionCol} of {@code output} and compares it, ignoring row order,
+  * with {@code expectedOutput} up to {@link #TOLERANCE}.
+  *
+  * <p>The expected list is copied before sorting, so the caller's list (typically one of the
+  * shared static constants of this class) is never reordered.
+  *
+  * @param expectedOutput Expected vectors, in any order.
+  * @param output Table holding the prediction column.
+  * @param predictionCol Name of the column to verify.
+  * @throws Exception If executing the job fails.
+  */
+ @SuppressWarnings("unchecked")
+ private void verifyPredictionResult(
+         List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+     List<Row> collectedResult =
+             IteratorUtils.toList(
+                     tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
+     // Fail on a row-count mismatch before doing any further work.
+     assertEquals(expectedOutput.size(), collectedResult.size());
+
+     List<DenseVector> actualOutputs = new ArrayList<>(collectedResult.size());
+     collectedResult.forEach(x -> actualOutputs.add(x.getFieldAs(0)));
+     List<DenseVector> sortedExpected = new ArrayList<>(expectedOutput);
+
+     actualOutputs.sort(TestUtils::compare);
+     sortedExpected.sort(TestUtils::compare);
+     for (int i = 0; i < sortedExpected.size(); i++) {
+         assertArrayEquals(sortedExpected.get(i).values, actualOutputs.get(i).values, TOLERANCE);
+     }
+ }
+
+ /** Checks the default param values of {@link IDF} and that the setters take effect. */
+ @Test
+ public void testParam() {
+     IDF idf = new IDF();
+
+     // Defaults.
+     assertEquals("input", idf.getInputCol());
+     assertEquals(0, idf.getMinDocFreq());
+     assertEquals("output", idf.getOutputCol());
+
+     idf.setInputCol("test_input");
+     idf.setMinDocFreq(2);
+     idf.setOutputCol("test_output");
+
+     // Values after the setters ran.
+     assertEquals("test_input", idf.getInputCol());
+     assertEquals(2, idf.getMinDocFreq());
+     assertEquals("test_output", idf.getOutputCol());
+ }
+
+ /** Checks that the output column is appended after all columns of the input table. */
+ @Test
+ public void testOutputSchema() {
+     Table tempTable =
+             tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                     .as("test_input", "dummy_input");
+     IDFModel model =
+             new IDF().setInputCol("test_input").setOutputCol("test_output").fit(tempTable);
+     Table output = model.transform(tempTable)[0];
+
+     List<String> expectedColumns = Arrays.asList("test_input", "dummy_input", "test_output");
+     assertEquals(expectedColumns, output.getResolvedSchema().getColumnNames());
+ }
+
+ /** Fits and transforms the input table with minDocFreq left at 0 and then set to 2. */
+ @Test
+ public void testFitAndPredict() throws Exception {
+     IDF idf = new IDF();
+
+     // Tests minDocFreq = 0.
+     Table defaultOutput = idf.fit(inputTable).transform(inputTable)[0];
+     verifyPredictionResult(expectedOutput, defaultOutput, idf.getOutputCol());
+
+     // Tests minDocFreq = 2.
+     idf.setMinDocFreq(2);
+     Table filteredOutput = idf.fit(inputTable).transform(inputTable)[0];
+     verifyPredictionResult(expectedOutputMinDocFreqAsTwo, filteredOutput, idf.getOutputCol());
+ }
+
+ /**
+  * Saves and reloads both the estimator and the fitted model, then checks the model data schema
+  * and the predictions of the reloaded model.
+  */
+ @Test
+ public void testSaveLoadAndPredict() throws Exception {
+     String estimatorPath = tempFolder.newFolder().getAbsolutePath();
+     IDF idf = TestUtils.saveAndReload(tEnv, new IDF(), estimatorPath);
+
+     IDFModel fitted = idf.fit(inputTable);
+     String modelPath = tempFolder.newFolder().getAbsolutePath();
+     IDFModel model = TestUtils.saveAndReload(tEnv, fitted, modelPath);
+
+     List<String> expectedColumns = Arrays.asList("idf", "docFreq", "numDocs");
+     assertEquals(
+             expectedColumns, model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+     verifyPredictionResult(
+             expectedOutput, model.transform(inputTable)[0], idf.getOutputCol());
+ }
+
+ @Test
+ @SuppressWarnings("unchecked")
+ public void testGetModelData() throws Exception {
+ IDFModel model = new IDF().fit(inputTable);
+ Table modelDataTable = model.getModelData()[0];
+
+ assertEquals(
+ Arrays.asList("idf", "docFreq", "numDocs"),
+ modelDataTable.getResolvedSchema().getColumnNames());
+
+ List<IDFModelData> collectedModelData =
+ (List<IDFModelData>)
+ IteratorUtils.toList(
+ IDFModelData.getModelDataStream(modelDataTable)
+ .executeAndCollect());
+
+ assertEquals(1, collectedModelData.size());
+ IDFModelData modelData = collectedModelData.get(0);
+ assertEquals(3, modelData.numDocs);
+ assertArrayEquals(new long[] {0, 3, 1, 2}, modelData.docFreq);
+ assertArrayEquals(
+ new double[] {1.3862943, 0, 0.6931471, 0.2876820}, modelData.idf.values, TOLERANCE);
+ }
+
+ /**
+  * Checks that a fresh {@link IDFModel} given the params and model data of a fitted model
+  * produces the expected predictions.
+  */
+ @Test
+ public void testSetModelData() throws Exception {
+     IDFModel fitted = new IDF().fit(inputTable);
+
+     IDFModel copy = new IDFModel();
+     ReadWriteUtils.updateExistingParams(copy, fitted.getParamMap());
+     copy.setModelData(fitted.getModelData());
+
+     verifyPredictionResult(
+             expectedOutput, copy.transform(inputTable)[0], fitted.getOutputCol());
+ }
+
+ /**
+  * Checks that computing the model data from an empty training table fails with the message
+  * "The training set is empty.".
+  */
+ @Test
+ public void testFitOnEmptyData() {
+     Table emptyTable =
+             tEnv.fromDataStream(env.fromElements(Row.of(1, 2)).filter(x -> x.getArity() == 0))
+                     .as("input");
+     IDFModel model = new IDF().fit(emptyTable);
+     Table modelDataTable = model.getModelData()[0];
+
+     try {
+         modelDataTable.execute().collect().next();
+     } catch (Throwable e) {
+         assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+         return;
+     }
+     // Kept outside the try block: inside it, the AssertionError raised by fail() would be
+     // caught by the handler above and reported as a confusing message mismatch.
+     fail("Expected the job to fail because the training set is empty.");
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/idf_example.py b/flink-ml-python/pyflink/examples/ml/feature/idf_example.py
new file mode 100644
index 0000000..fb8e778
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/idf_example.py
@@ -0,0 +1,60 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains an IDF model and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.idf import IDF
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+ env.from_collection([
+ (Vectors.dense(0, 1, 0, 2),),
+ (Vectors.dense(0, 1, 2, 3),),
+ (Vectors.dense(0, 1, 0, 0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+# Creates an IDF object and initializes its parameters.
+idf = IDF().set_min_doc_freq(2)
+
+# Trains the IDF Model.
+model = idf.fit(input_table)
+
+# Uses the IDF Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_index = field_names.index(idf.get_input_col())
+ output_index = field_names.index(idf.get_output_col())
+ print('Input Value: ' + str(result[input_index]) +
+ '\tOutput Value: ' + str(result[output_index]))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/idf.py b/flink-ml-python/pyflink/ml/lib/feature/idf.py
new file mode 100644
index 0000000..dd876e2
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/idf.py
@@ -0,0 +1,106 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _IDFModelParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol
+):
+ """
+ Params for :class:`IDFModel`.
+
+ Combines :class:`HasInputCol` and :class:`HasOutputCol`; no further params are declared here.
+ """
+
+ # java_params is handed unchanged to the parent initializer.
+ def __init__(self, java_params):
+ super(_IDFModelParams, self).__init__(java_params)
+
+
+class _IDFParams(_IDFModelParams):
+ """
+ Params for :class:`IDF`.
+
+ Adds ``MIN_DOC_FREQ`` on top of the params of :class:`_IDFModelParams`.
+ """
+
+ # Integer param that defaults to 0 and is validated to be >= 0.
+ MIN_DOC_FREQ: IntParam = IntParam(
+ "min_doc_freq",
+ "Minimum number of documents that a term should appear for filtering.",
+ 0,
+ ParamValidators.gt_eq(0))
+
+ # java_params is handed unchanged to the parent initializer.
+ def __init__(self, java_params):
+ super(_IDFParams, self).__init__(java_params)
+
+ # Sets MIN_DOC_FREQ and returns the result of self.set(), typed as _IDFParams.
+ def set_min_doc_freq(self, value: int):
+ return typing.cast(_IDFParams, self.set(self.MIN_DOC_FREQ, value))
+
+ # Returns the current value of MIN_DOC_FREQ.
+ def get_min_doc_freq(self) -> int:
+ return self.get(self.MIN_DOC_FREQ)
+
+ # Read-only property alias of get_min_doc_freq().
+ @property
+ def min_doc_freq(self):
+ return self.get_min_doc_freq()
+
+
+class IDFModel(JavaFeatureModel, _IDFModelParams):
+ """
+ A Model which transforms data using the model data computed by :class:`IDF`.
+ """
+
+ def __init__(self, java_model=None):
+ super(IDFModel, self).__init__(java_model)
+
+ # Package name and class name below identify the Java counterpart of this model.
+ @classmethod
+ def _java_model_package_name(cls) -> str:
+ return "idf"
+
+ @classmethod
+ def _java_model_class_name(cls) -> str:
+ return "IDFModel"
+
+
+class IDF(JavaFeatureEstimator, _IDFParams):
+ """
+ An Estimator that computes the inverse document frequency (IDF) for the input documents.
+ IDF is computed following `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
+ number of documents and `d(t)` is the number of documents that contain `t`.
+
+ Users can filter out terms that appear in only a few documents by setting the
+ `min_doc_freq` param, see :meth:`set_min_doc_freq`.
+
+ See https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
+ """
+
+ def __init__(self):
+ super(IDF, self).__init__()
+
+ # Wraps the given Java model in the Python IDFModel class.
+ @classmethod
+ def _create_model(cls, java_model) -> IDFModel:
+ return IDFModel(java_model)
+
+ # Package name and class name below identify the Java counterpart of this estimator.
+ @classmethod
+ def _java_estimator_package_name(cls) -> str:
+ return "idf"
+
+ @classmethod
+ def _java_estimator_class_name(cls) -> str:
+ return "IDF"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
new file mode 100644
index 0000000..d192ef3
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
@@ -0,0 +1,128 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.idf import IDF, IDFModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class IDFTest(PyFlinkMLTestCase):
+ def setUp(self):
+ super(IDFTest, self).setUp()
+ self.input_data = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense(0, 1, 0, 2),),
+ (Vectors.dense(0, 1, 2, 3),),
+ (Vectors.dense(0, 1, 0, 0),),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', ],
+ [DenseVectorTypeInfo(), ])))
+
+ self.expected_output = [
+ Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
+ Vectors.dense(0.0, 0.0, 1.3862943, 0.8630462),
+ Vectors.dense(0.0, 0.0, 0.0, 0.0),
+ ]
+
+ self.expected_output_min_doc_freq_as_two = [
+ Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
+ Vectors.dense(0.0, 0.0, 0.0, 0.8630462),
+ Vectors.dense(0.0, 0.0, 0.0, 0.0),
+ ]
+
+ self.tolerance = 1e-7
+
+ def verify_prediction_result(self, expected, output_table):
+ predicted_results = [result[1] for result in
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+
+ predicted_results.sort(key=lambda x: x[3])
+ expected.sort(key=lambda x: x[3])
+
+ self.assertEqual(len(expected), len(predicted_results))
+ for i in range(len(expected)):
+ expected_row = expected[i]
+ predicted_row = predicted_results[i]
+ self.assertEqual(len(expected_row), len(predicted_row))
+ for j in range(len(expected_row)):
+ self.assertAlmostEqual(expected_row[j], predicted_row[j], delta=self.tolerance)
+
+ def test_param(self):
+ idf = IDF()
+ self.assertEqual("input", idf.input_col)
+ self.assertEqual(0, idf.min_doc_freq)
+ self.assertEqual("output", idf.output_col)
+
+ idf \
+ .set_input_col("test_input") \
+ .set_min_doc_freq(2) \
+ .set_output_col("test_output")
+
+ self.assertEqual("test_input", idf.input_col)
+ self.assertEqual(2, idf.min_doc_freq)
+ self.assertEqual("test_output", idf.output_col)
+
+ def test_output_schema(self):
+ idf = IDF() \
+ .set_input_col("test_input") \
+ .set_output_col("test_output")
+ input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (Vectors.dense(1), ''),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['test_input', 'dummy_input'],
+ [DenseVectorTypeInfo(), Types.STRING()])))
+ output = idf \
+ .fit(input_data_table) \
+ .transform(input_data_table)[0]
+
+ self.assertEqual(
+ [idf.input_col, 'dummy_input', idf.output_col],
+ output.get_schema().get_field_names())
+
+ def test_fit_and_predict(self):
+ idf = IDF()
+ # Tests minDocFreq = 0.
+ output = idf.fit(self.input_data).transform(self.input_data)[0]
+ self.verify_prediction_result(self.expected_output, output)
+
+ # Tests minDocFreq = 2.
+ idf.set_min_doc_freq(2)
+ output = idf.fit(self.input_data).transform(self.input_data)[0]
+ self.verify_prediction_result(self.expected_output_min_doc_freq_as_two, output)
+
+ def test_save_load_predict(self):
+ idf = IDF()
+ estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf')
+ idf.save(estimator_path)
+ idf = IDF.load(self.t_env, estimator_path)
+
+ model = idf.fit(self.input_data)
+ model_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf_model')
+ model.save(model_path)
+ self.env.execute('save_model')
+ model = IDFModel.load(self.t_env, model_path)
+ output = model.transform(self.input_data)[0]
+
+ self.verify_prediction_result(self.expected_output, output)