Posted to commits@flink.apache.org by li...@apache.org on 2022/09/07 03:32:51 UTC

[flink-ml] branch master updated: [FLINK-28806] Add Transformer and Estimator for IDF

This is an automated email from the ASF dual-hosted git repository.

lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new d224f56  [FLINK-28806] Add Transformer and Estimator for IDF
d224f56 is described below

commit d224f565dd7adfa59f10ae208c49159abf106635
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Sep 7 11:32:46 2022 +0800

    [FLINK-28806] Add Transformer and Estimator for IDF
    
    This closes #150.
---
 docs/content/docs/operators/feature/idf.md         | 172 ++++++++++++++++
 .../ml/common/datastream/DataStreamUtils.java      | 131 +++++++++++++
 .../flink/ml/common/datastream/TableUtils.java     |  11 ++
 .../ml/common/datastream/DataStreamUtilsTest.java  |  34 ++++
 .../flink/ml/examples/feature/IDFExample.java      |  64 ++++++
 .../java/org/apache/flink/ml/feature/idf/IDF.java  | 169 ++++++++++++++++
 .../org/apache/flink/ml/feature/idf/IDFModel.java  | 149 ++++++++++++++
 .../apache/flink/ml/feature/idf/IDFModelData.java  | 124 ++++++++++++
 .../flink/ml/feature/idf/IDFModelParams.java       |  29 +++
 .../org/apache/flink/ml/feature/idf/IDFParams.java |  45 +++++
 .../java/org/apache/flink/ml/feature/IDFTest.java  | 216 +++++++++++++++++++++
 .../pyflink/examples/ml/feature/idf_example.py     |  60 ++++++
 flink-ml-python/pyflink/ml/lib/feature/idf.py      | 106 ++++++++++
 .../pyflink/ml/lib/feature/tests/test_idf.py       | 128 ++++++++++++
 14 files changed, 1438 insertions(+)

diff --git a/docs/content/docs/operators/feature/idf.md b/docs/content/docs/operators/feature/idf.md
new file mode 100644
index 0000000..5a29a4c
--- /dev/null
+++ b/docs/content/docs/operators/feature/idf.md
@@ -0,0 +1,172 @@
+---
+title: "IDF"
+weight: 1
+type: docs
+aliases:
+- /operators/feature/IDF.html
+---
+
+<!--
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements.  See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership.  The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License.  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied.  See the License for the
+specific language governing permissions and limitations
+under the License.
+-->
+
+## IDF
+
+IDF computes the inverse document frequency (IDF) for the
+input documents. IDF is computed as
+`idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
+number of documents and `d(t)` is the number of documents
+that contain the term `t`.
+
+IDFModel further uses the computed inverse document frequency
+to compute [tf-idf](https://en.wikipedia.org/wiki/Tf%E2%80%93idf).
+
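+For example, with the default `minDocFreq = 0`, the three documents used in
+the examples below, whose term-frequency vectors are `(0, 1, 0, 2)`,
+`(0, 1, 2, 3)` and `(0, 1, 0, 0)`, give `m = 3` and `d(t) = (0, 3, 1, 2)`:
+
+```
+idf = (log(4/1), log(4/4), log(4/2), log(4/3))
+    ≈ (1.3862944, 0, 0.6931472, 0.2876821)
+```
+
+The tf-idf of the first document is then
+`(0 * 1.3862944, 1 * 0, 0 * 0.6931472, 2 * 0.2876821) ≈ (0, 0, 0, 0.5753641)`.
+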
+### Input Columns
+
+| Param name | Type   | Default   | Description      |
+|:-----------|:-------|:----------|:-----------------|
+| inputCol   | Vector | `"input"` | Input documents. |
+
+### Output Columns
+
+| Param name | Type   | Default    | Description                 |
+|:-----------|:-------|:-----------|:----------------------------|
+| outputCol  | Vector | `"output"` | Tf-idf values of the input. |
+
+### Parameters
+
+Below are the parameters required by `IDFModel`.
+
+| Key       | Default    | Type   | Required | Description         |
+|:----------|:-----------|:-------|:---------|:--------------------|
+| inputCol  | `"input"`  | String | no       | Input column name.  |
+| outputCol | `"output"` | String | no       | Output column name. |
+
+
+In addition to the parameters above, `IDF` also supports the parameter below.
+
+| Key        | Default    | Type    | Required | Description                                                          |
+|:-----------|:-----------|:--------|:---------|:---------------------------------------------------------------------|
+| minDocFreq | `0`        | Integer | no       | Minimum number of documents in which a term should appear for filtering. |
+
+### Examples
+
+{{< tabs examples >}}
+
+{{< tab "Java">}}
+
+```java
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains an IDF model and uses it for feature engineering. */
+public class IDFExample {
+	public static void main(String[] args) {
+		StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+		StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+		// Generates input data.
+		DataStream<Row> inputStream =
+			env.fromElements(
+				Row.of(Vectors.dense(0, 1, 0, 2)),
+				Row.of(Vectors.dense(0, 1, 2, 3)),
+				Row.of(Vectors.dense(0, 1, 0, 0)));
+
+		Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+		// Creates an IDF object and initializes its parameters.
+		IDF idf = new IDF().setMinDocFreq(2);
+
+		// Trains the IDF Model.
+		IDFModel model = idf.fit(inputTable);
+
+		// Uses the IDF Model for predictions.
+		Table outputTable = model.transform(inputTable)[0];
+
+		// Extracts and displays the results.
+		for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+			Row row = it.next();
+			DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
+			DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+			System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+		}
+	}
+}
+
+```
+
+{{< /tab>}}
+
+{{< tab "Python">}}
+
+```python
+# Simple program that trains an IDF model and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.idf import IDF
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(0, 1, 0, 2),),
+        (Vectors.dense(0, 1, 2, 3),),
+        (Vectors.dense(0, 1, 0, 0),),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [DenseVectorTypeInfo(), ])))
+
+# Creates an IDF object and initializes its parameters.
+idf = IDF().set_min_doc_freq(2)
+
+# Trains the IDF Model.
+model = idf.fit(input_table)
+
+# Uses the IDF Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_index = field_names.index(idf.get_input_col())
+    output_index = field_names.index(idf.get_output_col())
+    print('Input Value: ' + str(result[input_index]) +
+          '\tOutput Value: ' + str(result[output_index]))
+```
+
+{{< /tab>}}
+
+{{< /tabs>}}
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
index f8df7e3..b34dbdf 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/DataStreamUtils.java
@@ -19,6 +19,7 @@
 package org.apache.flink.ml.common.datastream;
 
 import org.apache.flink.annotation.Internal;
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.FlatMapFunction;
 import org.apache.flink.api.common.functions.MapFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
@@ -116,6 +117,36 @@ public class DataStreamUtils {
         }
     }
 
+    /**
+     * Applies an {@link AggregateFunction} on a bounded stream. The output stream contains the
+     * aggregated result and its parallelism is one.
+     *
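+     * <p>A minimal usage sketch, assuming a bounded {@code DataStream<Long>} named {@code
+     * longStream}; the summing function here is illustrative:
+     *
+     * <pre>{@code
+     * DataStream<Long> sum =
+     *         DataStreamUtils.aggregate(
+     *                 longStream,
+     *                 new AggregateFunction<Long, Long, Long>() {
+     *                     public Long createAccumulator() { return 0L; }
+     *                     public Long add(Long value, Long acc) { return acc + value; }
+     *                     public Long getResult(Long acc) { return acc; }
+     *                     public Long merge(Long a, Long b) { return a + b; }
+     *                 });
+     * }</pre>
+     *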
+     * @param input The input data stream.
+     * @param func The user defined aggregate function.
+     * @return The result data stream.
+     * @param <IN> The class type of the input.
+     * @param <ACC> The class type of the accumulated values.
+     * @param <OUT> The class type of the output values.
+     */
+    public static <IN, ACC, OUT> DataStream<OUT> aggregate(
+            DataStream<IN> input, AggregateFunction<IN, ACC, OUT> func) {
+        TypeInformation<ACC> accType =
+                TypeExtractor.getAggregateFunctionAccumulatorType(
+                        func, input.getType(), null, true);
+        TypeInformation<OUT> outType =
+                TypeExtractor.getAggregateFunctionReturnType(func, input.getType(), null, true);
+
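+        // Two-phase aggregation: each parallel task pre-aggregates its partition, then a
+        // single-parallelism operator merges the partial accumulators and emits the final result.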
+        DataStream<ACC> partialAggregatedStream =
+                input.transform(
+                        "partialAggregate", accType, new PartialAggregateOperator<>(func, accType));
+        DataStream<OUT> aggregatedStream =
+                partialAggregatedStream.transform(
+                        "aggregate", outType, new AggregateOperator<>(func, accType));
+        aggregatedStream.getTransformation().setParallelism(1);
+
+        return aggregatedStream;
+    }
+
     /**
      * Performs a uniform sampling over the elements in a bounded data stream.
      *
@@ -263,6 +294,106 @@ public class DataStreamUtils {
         }
     }
 
+    /**
+     * A stream operator to apply {@link AggregateFunction#add(IN, ACC)} on each partition of the
+     * input bounded data stream.
+     */
+    private static class PartialAggregateOperator<IN, ACC, OUT>
+            extends AbstractUdfStreamOperator<ACC, AggregateFunction<IN, ACC, OUT>>
+            implements OneInputStreamOperator<IN, ACC>, BoundedOneInput {
+        /** Type information of the accumulated result. */
+        private final TypeInformation<ACC> accType;
+        /** The accumulated result of the aggregate function in one partition. */
+        private ACC acc;
+        /** State of acc. */
+        private ListState<ACC> accState;
+
+        public PartialAggregateOperator(
+                AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
+            super(userFunction);
+            this.accType = accType;
+        }
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(acc));
+        }
+
+        @Override
+        public void processElement(StreamRecord<IN> streamRecord) throws Exception {
+            acc = userFunction.add(streamRecord.getValue(), acc);
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("accState", accType));
+            acc =
+                    OperatorStateUtils.getUniqueElement(accState, "accState")
+                            .orElse(userFunction.createAccumulator());
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accState.clear();
+            accState.add(acc);
+        }
+    }
+
+    /**
+     * A stream operator to apply {@link AggregateFunction#merge(ACC, ACC)} and {@link
+     * AggregateFunction#getResult(ACC)} on the input bounded data stream.
+     */
+    private static class AggregateOperator<IN, ACC, OUT>
+            extends AbstractUdfStreamOperator<OUT, AggregateFunction<IN, ACC, OUT>>
+            implements OneInputStreamOperator<ACC, OUT>, BoundedOneInput {
+        /** Type information of the accumulated result. */
+        private final TypeInformation<ACC> accType;
+        /** The accumulated result of the aggregate function in the final partition. */
+        private ACC acc;
+        /** State of acc. */
+        private ListState<ACC> accState;
+
+        public AggregateOperator(
+                AggregateFunction<IN, ACC, OUT> userFunction, TypeInformation<ACC> accType) {
+            super(userFunction);
+            this.accType = accType;
+        }
+
+        @Override
+        public void endInput() {
+            output.collect(new StreamRecord<>(userFunction.getResult(acc)));
+        }
+
+        @Override
+        public void processElement(StreamRecord<ACC> streamRecord) throws Exception {
+            if (acc == null) {
+                acc = streamRecord.getValue();
+            } else {
+                acc = userFunction.merge(streamRecord.getValue(), acc);
+            }
+        }
+
+        @Override
+        public void initializeState(StateInitializationContext context) throws Exception {
+            super.initializeState(context);
+            accState =
+                    context.getOperatorStateStore()
+                            .getListState(new ListStateDescriptor<>("accState", accType));
+            acc = OperatorStateUtils.getUniqueElement(accState, "accState").orElse(null);
+        }
+
+        @Override
+        public void snapshotState(StateSnapshotContext context) throws Exception {
+            super.snapshotState(context);
+            accState.clear();
+            accState.add(acc);
+        }
+    }
+
     /**
      * Splits the input data into global batches of batchSize. After splitting, each global batch is
      * further split into local batches for downstream operators, with each worker receiving one batch.
diff --git a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
index 4dc175a..9245c77 100644
--- a/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
+++ b/flink-ml-core/src/main/java/org/apache/flink/ml/common/datastream/TableUtils.java
@@ -43,6 +43,17 @@ public class TableUtils {
         return new RowTypeInfo(types, names);
     }
 
+    // Retrieves the TypeInformation of a column by name. Returns null if the name does not exist in
+    // the input schema.
+    public static TypeInformation<?> getTypeInfoByName(ResolvedSchema schema, String name) {
+        for (Column column : schema.getColumns()) {
+            if (column.getName().equals(name)) {
+                return TypeInformation.of(column.getDataType().getConversionClass());
+            }
+        }
+        return null;
+    }
+
     public static StreamExecutionEnvironment getExecutionEnvironment(StreamTableEnvironment tEnv) {
         Table table = tEnv.fromValues();
         DataStream<Row> dataStream = tEnv.toDataStream(table);
diff --git a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
index a968a0e..fa48bb5 100644
--- a/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
+++ b/flink-ml-core/src/test/java/org/apache/flink/ml/common/datastream/DataStreamUtilsTest.java
@@ -18,6 +18,7 @@
 
 package org.apache.flink.ml.common.datastream;
 
+import org.apache.flink.api.common.functions.AggregateFunction;
 import org.apache.flink.api.common.functions.MapPartitionFunction;
 import org.apache.flink.api.common.functions.ReduceFunction;
 import org.apache.flink.api.common.functions.RichMapPartitionFunction;
@@ -75,6 +76,16 @@ public class DataStreamUtilsTest {
         assertArrayEquals(new long[] {190L}, sum.stream().mapToLong(Long::longValue).toArray());
     }
 
+    @Test
+    public void testAggregate() throws Exception {
+        DataStream<Long> dataStream =
+                env.fromParallelCollection(new NumberSequenceIterator(0L, 19L), Types.LONG);
+        DataStream<String> result = DataStreamUtils.aggregate(dataStream, new TestAggregateFunc());
+        List<String> stringSum = IteratorUtils.toList(result.executeAndCollect());
+        assertEquals(1, stringSum.size());
+        assertEquals("190", stringSum.get(0));
+    }
+
     @Test
     public void testGenerateBatchData() throws Exception {
         DataStream<Long> dataStream =
@@ -99,4 +110,27 @@ public class DataStreamUtilsTest {
             out.collect(cnt);
         }
     }
+
+    /** A simple implementation of {@link AggregateFunction}. */
+    private static class TestAggregateFunc implements AggregateFunction<Long, Long, String> {
+        @Override
+        public Long createAccumulator() {
+            return 0L;
+        }
+
+        @Override
+        public Long add(Long element, Long acc) {
+            return element + acc;
+        }
+
+        @Override
+        public String getResult(Long acc) {
+            return String.valueOf(acc);
+        }
+
+        @Override
+        public Long merge(Long acc1, Long acc2) {
+            return acc1 + acc2;
+        }
+    }
 }
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
new file mode 100644
index 0000000..ffa4d71
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/IDFExample.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+/** Simple program that trains an IDF model and uses it for feature engineering. */
+public class IDFExample {
+    public static void main(String[] args) {
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> inputStream =
+                env.fromElements(
+                        Row.of(Vectors.dense(0, 1, 0, 2)),
+                        Row.of(Vectors.dense(0, 1, 2, 3)),
+                        Row.of(Vectors.dense(0, 1, 0, 0)));
+
+        Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+        // Creates an IDF object and initializes its parameters.
+        IDF idf = new IDF().setMinDocFreq(2);
+
+        // Trains the IDF Model.
+        IDFModel model = idf.fit(inputTable);
+
+        // Uses the IDF Model for predictions.
+        Table outputTable = model.transform(inputTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+            DenseVector inputValue = (DenseVector) row.getField(idf.getInputCol());
+            DenseVector outputValue = (DenseVector) row.getField(idf.getOutputCol());
+            System.out.printf("Input Value: %s\tOutput Value: %s\n", inputValue, outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
new file mode 100644
index 0000000..4a242f0
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDF.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.functions.AggregateFunction;
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.tuple.Tuple2;
+import org.apache.flink.ml.api.Estimator;
+import org.apache.flink.ml.common.datastream.DataStreamUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * An Estimator that computes the inverse document frequency (IDF) for the input documents. IDF is
+ * computed as `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total number of documents and
+ * `d(t)` is the number of documents that contain the term `t`.
+ *
+ * <p>Users can filter out terms that appear in only a few documents by setting {@link
+ * IDFParams#getMinDocFreq()}.
+ *
+ * <p>See https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
+ */
+public class IDF implements Estimator<IDF, IDFModel>, IDFParams<IDF> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    public IDF() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public IDFModel fit(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        final String inputCol = getInputCol();
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Vector> inputData =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                (MapFunction<Row, Vector>)
+                                        value -> ((Vector) value.getField(inputCol)));
+
+        DataStream<IDFModelData> modelData =
+                DataStreamUtils.aggregate(inputData, new IDFAggregator(getMinDocFreq()));
+
+        IDFModel model = new IDFModel().setModelData(tEnv.fromDataStream(modelData));
+        ReadWriteUtils.updateExistingParams(model, getParamMap());
+        return model;
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static IDF load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /** The main logic to compute the model data of IDF. */
+    private static class IDFAggregator
+            implements AggregateFunction<Vector, Tuple2<Long, DenseVector>, IDFModelData> {
+        private final int minDocFreq;
+
+        public IDFAggregator(int minDocFreq) {
+            this.minDocFreq = minDocFreq;
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> createAccumulator() {
+            return Tuple2.of(0L, new DenseVector(new double[0]));
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> add(
+                Vector vector, Tuple2<Long, DenseVector> numDocsAndDocFreq) {
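+            // Lazily sizes the document frequency vector when the first document arrives.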
+            if (numDocsAndDocFreq.f0 == 0) {
+                numDocsAndDocFreq.f1 = new DenseVector(vector.size());
+            }
+            numDocsAndDocFreq.f0 += 1L;
+
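+            // Binarizes the term frequencies in place so that each document contributes at most
+            // 1 to each term's document frequency.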
+            double[] values;
+            if (vector instanceof SparseVector) {
+                values = ((SparseVector) vector).values;
+            } else {
+                values = ((DenseVector) vector).values;
+            }
+            for (int i = 0; i < values.length; i++) {
+                values[i] = values[i] > 0 ? 1 : 0;
+            }
+
+            BLAS.axpy(1, vector, numDocsAndDocFreq.f1);
+
+            return numDocsAndDocFreq;
+        }
+
+        @Override
+        public IDFModelData getResult(Tuple2<Long, DenseVector> numDocsAndDocFreq) {
+            long numDocs = numDocsAndDocFreq.f0;
+            DenseVector docFreq = numDocsAndDocFreq.f1;
+            Preconditions.checkState(numDocs > 0, "The training set is empty.");
+
+            long[] filteredDocFreq = new long[docFreq.size()];
+            double[] df = docFreq.values;
+            double[] idf = new double[df.length];
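+            // Terms that appear in fewer than minDocFreq documents keep idf = 0 and docFreq = 0.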
+            for (int i = 0; i < idf.length; i++) {
+                if (df[i] >= minDocFreq) {
+                    idf[i] = Math.log((numDocs + 1) / (df[i] + 1));
+                    filteredDocFreq[i] = (long) df[i];
+                }
+            }
+            return new IDFModelData(Vectors.dense(idf), filteredDocFreq, numDocs);
+        }
+
+        @Override
+        public Tuple2<Long, DenseVector> merge(
+                Tuple2<Long, DenseVector> numDocsAndDocFreq1,
+                Tuple2<Long, DenseVector> numDocsAndDocFreq2) {
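+            // An accumulator that has seen no documents still holds the initial zero-length
+            // vector, so return the other accumulator instead of merging via axpy.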
+            if (numDocsAndDocFreq1.f0 == 0) {
+                return numDocsAndDocFreq2;
+            }
+
+            if (numDocsAndDocFreq2.f0 == 0) {
+                return numDocsAndDocFreq1;
+            }
+
+            numDocsAndDocFreq2.f0 += numDocsAndDocFreq1.f0;
+            BLAS.axpy(1, numDocsAndDocFreq1.f1, numDocsAndDocFreq2.f1);
+            return numDocsAndDocFreq2;
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java
new file mode 100644
index 0000000..87a2f25
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModel.java
@@ -0,0 +1,149 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.functions.RichMapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Model;
+import org.apache.flink.ml.common.broadcast.BroadcastUtils;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.BLAS;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+
+/** A Model which transforms data using the model data computed by {@link IDF}. */
+public class IDFModel implements Model<IDFModel>, IDFModelParams<IDFModel> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private Table modelDataTable;
+
+    public IDFModel() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    @SuppressWarnings("unchecked")
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        DataStream<Row> data = tEnv.toDataStream(inputs[0]);
+        DataStream<IDFModelData> idfModelData = IDFModelData.getModelDataStream(modelDataTable);
+
+        final String broadcastModelKey = "broadcastModelKey";
+        ResolvedSchema schema = inputs[0].getResolvedSchema();
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(schema);
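+        // The output schema is the input schema plus the output column, which reuses the type
+        // of the input column.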
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(),
+                                TableUtils.getTypeInfoByName(schema, getInputCol())),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
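+        // Broadcasts the model data so that every parallel instance of ComputeTfIdfFunction can
+        // read the idf vector at runtime.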
+        DataStream<Row> output =
+                BroadcastUtils.withBroadcastStream(
+                        Collections.singletonList(data),
+                        Collections.singletonMap(broadcastModelKey, idfModelData),
+                        inputList -> {
+                            DataStream input = inputList.get(0);
+                            return input.map(
+                                    new ComputeTfIdfFunction(broadcastModelKey, getInputCol()),
+                                    outputTypeInfo);
+                        });
+
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    @Override
+    public IDFModel setModelData(Table... inputs) {
+        modelDataTable = inputs[0];
+        return this;
+    }
+
+    @Override
+    public Table[] getModelData() {
+        return new Table[] {modelDataTable};
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+        ReadWriteUtils.saveModelData(
+                IDFModelData.getModelDataStream(modelDataTable),
+                path,
+                new IDFModelData.ModelDataEncoder());
+    }
+
+    public static IDFModel load(StreamTableEnvironment tEnv, String path) throws IOException {
+        IDFModel model = ReadWriteUtils.loadStageParam(path);
+
+        Table modelDataTable =
+                ReadWriteUtils.loadModelData(tEnv, path, new IDFModelData.ModelDataDecoder());
+        return model.setModelData(modelDataTable);
+    }
+
+    /** Computes the tf-idf for each term in the input document. */
+    private static class ComputeTfIdfFunction extends RichMapFunction<Row, Row> {
+        private final String inputCol;
+        private final String broadcastKey;
+        private DenseVector idf;
+
+        public ComputeTfIdfFunction(String broadcastKey, String inputCol) {
+            this.broadcastKey = broadcastKey;
+            this.inputCol = inputCol;
+        }
+
+        @Override
+        public Row map(Row row) {
+            if (idf == null) {
+                IDFModelData idfModelData =
+                        (IDFModelData)
+                                getRuntimeContext().getBroadcastVariable(broadcastKey).get(0);
+                idf = idfModelData.idf;
+            }
+
+            Vector outputVec = ((Vector) Objects.requireNonNull(row.getField(inputCol))).clone();
+            BLAS.hDot(idf, outputVec);
+            return Row.join(row, Row.of(outputVec));
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java
new file mode 100644
index 0000000..a808454
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelData.java
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.api.common.serialization.Encoder;
+import org.apache.flink.api.common.typeinfo.TypeInformation;
+import org.apache.flink.api.common.typeutils.base.LongSerializer;
+import org.apache.flink.api.common.typeutils.base.array.LongPrimitiveArraySerializer;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.connector.file.src.reader.SimpleStreamFormat;
+import org.apache.flink.core.fs.FSDataInputStream;
+import org.apache.flink.core.memory.DataInputView;
+import org.apache.flink.core.memory.DataInputViewStreamWrapper;
+import org.apache.flink.core.memory.DataOutputView;
+import org.apache.flink.core.memory.DataOutputViewStreamWrapper;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.typeinfo.DenseVectorSerializer;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+
+import java.io.EOFException;
+import java.io.IOException;
+import java.io.OutputStream;
+
+/**
+ * Model data of {@link IDFModel}.
+ *
+ * <p>This class also provides methods to convert model data from a Table to a data stream, and
+ * classes to save/load model data.
+ */
+public class IDFModelData {
+    /** Inverse document frequency for all terms. */
+    public DenseVector idf;
+    /** Document frequency for all terms after filtering out infrequent terms. */
+    public long[] docFreq;
+    /** Number of docs in the training set. */
+    public long numDocs;
+
+    public IDFModelData() {}
+
+    public IDFModelData(DenseVector idf, long[] docFreq, long numDocs) {
+        this.idf = idf;
+        this.docFreq = docFreq;
+        this.numDocs = numDocs;
+    }
+
+    /**
+     * Converts the model data table to a data stream.
+     *
+     * @param modelDataTable The model data table.
+     * @return The model data stream.
+     */
+    public static DataStream<IDFModelData> getModelDataStream(Table modelDataTable) {
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) modelDataTable).getTableEnvironment();
+        return tEnv.toDataStream(modelDataTable)
+                .map(x -> new IDFModelData(x.getFieldAs(0), x.getFieldAs(1), x.getFieldAs(2)));
+    }
+
+    /** Encoder for {@link IDFModelData}. */
+    public static class ModelDataEncoder implements Encoder<IDFModelData> {
+        private final DenseVectorSerializer denseVectorSerializer = new DenseVectorSerializer();
+
+        @Override
+        public void encode(IDFModelData modelData, OutputStream outputStream) throws IOException {
+            DataOutputView dataOutputView = new DataOutputViewStreamWrapper(outputStream);
+            denseVectorSerializer.serialize(modelData.idf, dataOutputView);
+            LongPrimitiveArraySerializer.INSTANCE.serialize(modelData.docFreq, dataOutputView);
+            LongSerializer.INSTANCE.serialize(modelData.numDocs, dataOutputView);
+        }
+    }
+
+    /** Decoder for {@link IDFModelData}. */
+    public static class ModelDataDecoder extends SimpleStreamFormat<IDFModelData> {
+        @Override
+        public Reader<IDFModelData> createReader(Configuration config, FSDataInputStream stream) {
+            return new Reader<IDFModelData>() {
+                private final DenseVectorSerializer denseVectorSerializer =
+                        new DenseVectorSerializer();
+
+                @Override
+                public IDFModelData read() throws IOException {
+                    DataInputView source = new DataInputViewStreamWrapper(stream);
+                    try {
+                        DenseVector idf = denseVectorSerializer.deserialize(source);
+                        long[] docFreq = LongPrimitiveArraySerializer.INSTANCE.deserialize(source);
+                        long numDocs = LongSerializer.INSTANCE.deserialize(source);
+                        return new IDFModelData(idf, docFreq, numDocs);
+                    } catch (EOFException e) {
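+                        // Reaching the end of the stream means all model data has been read.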
+                        return null;
+                    }
+                }
+
+                @Override
+                public void close() throws IOException {
+                    stream.close();
+                }
+            };
+        }
+
+        @Override
+        public TypeInformation<IDFModelData> getProducedType() {
+            return TypeInformation.of(IDFModelData.class);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java
new file mode 100644
index 0000000..168efa5
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFModelParams.java
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params for {@link IDFModel}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IDFModelParams<T> extends HasInputCol<T>, HasOutputCol<T> {}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java
new file mode 100644
index 0000000..14ca550
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/idf/IDFParams.java
@@ -0,0 +1,45 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.idf;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+
+/**
+ * Params for {@link IDF}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface IDFParams<T> extends IDFModelParams<T> {
+    Param<Integer> MIN_DOC_FREQ =
+            new IntParam(
+                    "minDocFreq",
+                    "Minimum number of documents that a term should appear for filtering.",
+                    0,
+                    ParamValidators.gtEq(0));
+
+    default int getMinDocFreq() {
+        return get(MIN_DOC_FREQ);
+    }
+
+    default T setMinDocFreq(Integer value) {
+        return set(MIN_DOC_FREQ, value);
+    }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
new file mode 100644
index 0000000..e4cae5a
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/IDFTest.java
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.idf.IDF;
+import org.apache.flink.ml.feature.idf.IDFModel;
+import org.apache.flink.ml.feature.idf.IDFModelData;
+import org.apache.flink.ml.linalg.DenseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.apache.commons.lang3.exception.ExceptionUtils;
+import org.junit.Before;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.TemporaryFolder;
+
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
+import static org.apache.flink.table.api.Expressions.$;
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/** Tests {@link IDF} and {@link IDFModel}. */
+public class IDFTest extends AbstractTestBase {
+    @Rule public final TemporaryFolder tempFolder = new TemporaryFolder();
+    private StreamExecutionEnvironment env;
+    private StreamTableEnvironment tEnv;
+    private Table inputTable;
+
+    private static final List<DenseVector> expectedOutput =
+            Arrays.asList(
+                    Vectors.dense(0, 0, 0, 0.5753641),
+                    Vectors.dense(0, 0, 1.3862943, 0.8630462),
+                    Vectors.dense(0, 0, 0, 0));
+    private static final List<DenseVector> expectedOutputMinDocFreqAsTwo =
+            Arrays.asList(
+                    Vectors.dense(0, 0, 0, 0.5753641),
+                    Vectors.dense(0, 0, 0, 0.8630462),
+                    Vectors.dense(0, 0, 0, 0));
+    private static final double TOLERANCE = 1e-7;
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+
+        List<DenseVector> input =
+                Arrays.asList(
+                        Vectors.dense(0, 1, 0, 2),
+                        Vectors.dense(0, 1, 2, 3),
+                        Vectors.dense(0, 1, 0, 0));
+        inputTable = tEnv.fromDataStream(env.fromCollection(input).map(x -> x)).as("input");
+    }
+
+    @SuppressWarnings("unchecked")
+    private void verifyPredictionResult(
+            List<DenseVector> expectedOutput, Table output, String predictionCol) throws Exception {
+        List<Row> collectedResult =
+                IteratorUtils.toList(
+                        tEnv.toDataStream(output.select($(predictionCol))).executeAndCollect());
+        List<DenseVector> actualOutputs = new ArrayList<>(expectedOutput.size());
+        collectedResult.forEach(x -> actualOutputs.add((x.getFieldAs(0))));
+
+        actualOutputs.sort(TestUtils::compare);
+        expectedOutput.sort(TestUtils::compare);
+        assertEquals(expectedOutput.size(), collectedResult.size());
+        for (int i = 0; i < expectedOutput.size(); i++) {
+            assertArrayEquals(expectedOutput.get(i).values, actualOutputs.get(i).values, TOLERANCE);
+        }
+    }
+
+    @Test
+    public void testParam() {
+        IDF idf = new IDF();
+        assertEquals("input", idf.getInputCol());
+        assertEquals(0, idf.getMinDocFreq());
+        assertEquals("output", idf.getOutputCol());
+
+        idf.setInputCol("test_input").setMinDocFreq(2).setOutputCol("test_output");
+
+        assertEquals("test_input", idf.getInputCol());
+        assertEquals(2, idf.getMinDocFreq());
+        assertEquals("test_output", idf.getOutputCol());
+    }
+
+    @Test
+    public void testOutputSchema() {
+        Table tempTable =
+                tEnv.fromDataStream(env.fromElements(Row.of("", "")))
+                        .as("test_input", "dummy_input");
+        IDF idf = new IDF().setInputCol("test_input").setOutputCol("test_output");
+        Table output = idf.fit(tempTable).transform(tempTable)[0];
+
+        assertEquals(
+                Arrays.asList("test_input", "dummy_input", "test_output"),
+                output.getResolvedSchema().getColumnNames());
+    }
+
+    @Test
+    public void testFitAndPredict() throws Exception {
+        IDF idf = new IDF();
+        Table output;
+
+        // Tests minDocFreq = 0.
+        output = idf.fit(inputTable).transform(inputTable)[0];
+        verifyPredictionResult(expectedOutput, output, idf.getOutputCol());
+
+        // Tests minDocFreq = 2.
+        idf.setMinDocFreq(2);
+        output = idf.fit(inputTable).transform(inputTable)[0];
+        verifyPredictionResult(expectedOutputMinDocFreqAsTwo, output, idf.getOutputCol());
+    }
+
+    @Test
+    public void testSaveLoadAndPredict() throws Exception {
+        IDF idf = new IDF();
+        idf = TestUtils.saveAndReload(tEnv, idf, tempFolder.newFolder().getAbsolutePath());
+
+        IDFModel model = idf.fit(inputTable);
+        model = TestUtils.saveAndReload(tEnv, model, tempFolder.newFolder().getAbsolutePath());
+
+        assertEquals(
+                Arrays.asList("idf", "docFreq", "numDocs"),
+                model.getModelData()[0].getResolvedSchema().getColumnNames());
+
+        Table output = model.transform(inputTable)[0];
+        verifyPredictionResult(expectedOutput, output, idf.getOutputCol());
+    }
+
+    @Test
+    @SuppressWarnings("unchecked")
+    public void testGetModelData() throws Exception {
+        IDFModel model = new IDF().fit(inputTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        assertEquals(
+                Arrays.asList("idf", "docFreq", "numDocs"),
+                modelDataTable.getResolvedSchema().getColumnNames());
+
+        List<IDFModelData> collectedModelData =
+                (List<IDFModelData>)
+                        IteratorUtils.toList(
+                                IDFModelData.getModelDataStream(modelDataTable)
+                                        .executeAndCollect());
+
+        assertEquals(1, collectedModelData.size());
+        IDFModelData modelData = collectedModelData.get(0);
+        assertEquals(3, modelData.numDocs);
+        assertArrayEquals(new long[] {0, 3, 1, 2}, modelData.docFreq);
+        assertArrayEquals(
+                new double[] {1.3862943, 0, 0.6931471, 0.2876820}, modelData.idf.values, TOLERANCE);
+    }
+
+    @Test
+    public void testSetModelData() throws Exception {
+        IDFModel model = new IDF().fit(inputTable);
+
+        IDFModel newModel = new IDFModel();
+        ReadWriteUtils.updateExistingParams(newModel, model.getParamMap());
+        newModel.setModelData(model.getModelData());
+        Table output = newModel.transform(inputTable)[0];
+
+        verifyPredictionResult(expectedOutput, output, model.getOutputCol());
+    }
+
+    @Test
+    public void testFitOnEmptyData() {
+        Table emptyTable =
+                tEnv.fromDataStream(env.fromElements(Row.of(1, 2)).filter(x -> x.getArity() == 0))
+                        .as("input");
+        IDFModel model = new IDF().fit(emptyTable);
+        Table modelDataTable = model.getModelData()[0];
+
+        try {
+            modelDataTable.execute().collect().next();
+            fail();
+        } catch (Throwable e) {
+            assertEquals("The training set is empty.", ExceptionUtils.getRootCause(e).getMessage());
+        }
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/idf_example.py b/flink-ml-python/pyflink/examples/ml/feature/idf_example.py
new file mode 100644
index 0000000..fb8e778
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/idf_example.py
@@ -0,0 +1,60 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that trains an IDF model and uses it for feature
+# engineering.
+
+from pyflink.common import Types
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.idf import IDF
+from pyflink.table import StreamTableEnvironment
+
+# Creates a new StreamExecutionEnvironment.
+env = StreamExecutionEnvironment.get_execution_environment()
+
+# Creates a StreamTableEnvironment.
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input for training and prediction.
+input_table = t_env.from_data_stream(
+    env.from_collection([
+        (Vectors.dense(0, 1, 0, 2),),
+        (Vectors.dense(0, 1, 2, 3),),
+        (Vectors.dense(0, 1, 0, 0),),
+    ],
+        type_info=Types.ROW_NAMED(
+            ['input', ],
+            [DenseVectorTypeInfo(), ])))
+
+# Creates an IDF object and initializes its parameters.
+idf = IDF().set_min_doc_freq(2)
+
+# Trains the IDF Model.
+model = idf.fit(input_table)
+
+# Uses the IDF Model for predictions.
+output = model.transform(input_table)[0]
+
+# Extracts and displays the results.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+    input_index = field_names.index(idf.get_input_col())
+    output_index = field_names.index(idf.get_output_col())
+    print('Input Value: ' + str(result[input_index]) +
+          '\tOutput Value: ' + str(result[output_index]))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/idf.py b/flink-ml-python/pyflink/ml/lib/feature/idf.py
new file mode 100644
index 0000000..dd876e2
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/idf.py
@@ -0,0 +1,106 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import IntParam, ParamValidators
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureModel, JavaFeatureEstimator
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol
+
+
+class _IDFModelParams(
+    JavaWithParams,
+    HasInputCol,
+    HasOutputCol
+):
+    """
+    Params for :class:`IDFModel`.
+    """
+
+    def __init__(self, java_params):
+        super(_IDFModelParams, self).__init__(java_params)
+
+
+class _IDFParams(_IDFModelParams):
+    """
+    Params for :class:`IDF`.
+    """
+
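+    # Terms that appear in fewer than min_doc_freq documents are assigned an
+    # IDF of 0 and are therefore zeroed out in the transformed output (see the
+    # expectations in tests/test_idf.py).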
+    MIN_DOC_FREQ: IntParam = IntParam(
+        "min_doc_freq",
+        "Minimum number of documents in which a term should appear for filtering.",
+        0,
+        ParamValidators.gt_eq(0))
+
+    def __init__(self, java_params):
+        super(_IDFParams, self).__init__(java_params)
+
+    def set_min_doc_freq(self, value: int):
+        return typing.cast(_IDFParams, self.set(self.MIN_DOC_FREQ, value))
+
+    def get_min_doc_freq(self) -> int:
+        return self.get(self.MIN_DOC_FREQ)
+
+    @property
+    def min_doc_freq(self):
+        return self.get_min_doc_freq()
+
+
+class IDFModel(JavaFeatureModel, _IDFModelParams):
+    """
+    A Model which transforms data using the model data computed by :class:`IDF`.
+    """
+
+    def __init__(self, java_model=None):
+        super(IDFModel, self).__init__(java_model)
+
+    @classmethod
+    def _java_model_package_name(cls) -> str:
+        return "idf"
+
+    @classmethod
+    def _java_model_class_name(cls) -> str:
+        return "IDFModel"
+
+
+class IDF(JavaFeatureEstimator, _IDFParams):
+    """
+    An Estimator that computes the inverse document frequency (IDF) for the input documents.
+    IDF is computed following `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
+    number of documents and `d(t)` is the number of documents that contains `t`.
+
+    Users can filter out terms that appear in only a few documents by setting
+    :func:`set_min_doc_freq`.
+
+    See https://en.wikipedia.org/wiki/Tf%E2%80%93idf.
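+
+    For example, with `m = 3` documents and a term `t` that appears in
+    `d(t) = 2` of them, `idf = log(4 / 3) ≈ 0.2877`, so a document containing
+    `t` twice gets a tf-idf value of about `2 * 0.2877 ≈ 0.5754` for that term.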
+    """
+
+    def __init__(self):
+        super(IDF, self).__init__()
+
+    @classmethod
+    def _create_model(cls, java_model) -> IDFModel:
+        return IDFModel(java_model)
+
+    @classmethod
+    def _java_estimator_package_name(cls) -> str:
+        return "idf"
+
+    @classmethod
+    def _java_estimator_class_name(cls) -> str:
+        return "IDF"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
new file mode 100644
index 0000000..d192ef3
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_idf.py
@@ -0,0 +1,128 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+#  limitations under the License.
+################################################################################
+
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo
+from pyflink.ml.lib.feature.idf import IDF, IDFModel
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class IDFTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(IDFTest, self).setUp()
+        self.input_data = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(0, 1, 0, 2),),
+                (Vectors.dense(0, 1, 2, 3),),
+                (Vectors.dense(0, 1, 0, 0),),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['input', ],
+                    [DenseVectorTypeInfo(), ])))
+
+        self.expected_output = [
+            Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
+            Vectors.dense(0.0, 0.0, 1.3862943, 0.8630462),
+            Vectors.dense(0.0, 0.0, 0.0, 0.0),
+        ]
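+        # Derivation with m = 3 documents: the 4th term appears in d(t) = 2
+        # documents, so idf = log(4 / 3) ≈ 0.2876821 and tf-idf = tf * idf
+        # (e.g. 2 * 0.2876821 ≈ 0.5753641); the 3rd term appears in only 1
+        # document, so idf = log(4 / 2) ≈ 0.6931472 and 2 * 0.6931472
+        # ≈ 1.3862943.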
+
+        self.expected_output_min_doc_freq_as_two = [
+            Vectors.dense(0.0, 0.0, 0.0, 0.5753641),
+            Vectors.dense(0.0, 0.0, 0.0, 0.8630462),
+            Vectors.dense(0.0, 0.0, 0.0, 0.0),
+        ]
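+        # With min_doc_freq = 2, the 3rd term (present in only one document)
+        # is filtered out, so its 1.3862943 entry above becomes 0.0.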
+
+        self.tolerance = 1e-7
+
+    def verify_prediction_result(self, expected, output_table):
+        predicted_results = [result[1] for result in
+                             self.t_env.to_data_stream(output_table).execute_and_collect()]
+
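+        # Collected rows can arrive in any order, so sort both lists by the
+        # 4th vector component, which is distinct across the expected rows.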
+        predicted_results.sort(key=lambda x: x[3])
+        expected.sort(key=lambda x: x[3])
+
+        self.assertEqual(len(expected), len(predicted_results))
+        for i in range(len(expected)):
+            expected_row = expected[i]
+            predicted_row = predicted_results[i]
+            self.assertEqual(len(expected_row), len(predicted_row))
+            for j in range(len(expected_row)):
+                self.assertAlmostEqual(expected_row[j], predicted_row[j], delta=self.tolerance)
+
+    def test_param(self):
+        idf = IDF()
+        self.assertEqual("input", idf.input_col)
+        self.assertEqual(0, idf.min_doc_freq)
+        self.assertEqual("output", idf.output_col)
+
+        idf \
+            .set_input_col("test_input") \
+            .set_min_doc_freq(2) \
+            .set_output_col("test_output")
+
+        self.assertEqual("test_input", idf.input_col)
+        self.assertEqual(2, idf.min_doc_freq)
+        self.assertEqual("test_output", idf.output_col)
+
+    def test_output_schema(self):
+        idf = IDF() \
+            .set_input_col("test_input") \
+            .set_output_col("test_output")
+        input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (Vectors.dense(1), ''),
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['test_input', 'dummy_input'],
+                    [DenseVectorTypeInfo(), Types.STRING()])))
+        output = idf \
+            .fit(input_data_table) \
+            .transform(input_data_table)[0]
+
+        self.assertEqual(
+            [idf.input_col, 'dummy_input', idf.output_col],
+            output.get_schema().get_field_names())
+
+    def test_fit_and_predict(self):
+        idf = IDF()
+        # Tests minDocFreq = 0.
+        output = idf.fit(self.input_data).transform(self.input_data)[0]
+        self.verify_prediction_result(self.expected_output, output)
+
+        # Tests minDocFreq = 2.
+        idf.set_min_doc_freq(2)
+        output = idf.fit(self.input_data).transform(self.input_data)[0]
+        self.verify_prediction_result(self.expected_output_min_doc_freq_as_two, output)
+
+    def test_save_load_predict(self):
+        idf = IDF()
+        estimator_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf')
+        idf.save(estimator_path)
+        idf = IDF.load(self.t_env, estimator_path)
+
+        model = idf.fit(self.input_data)
+        model_path = os.path.join(self.temp_dir, 'test_save_load_predict_idf_model')
+        model.save(model_path)
+        self.env.execute('save_model')
+        model = IDFModel.load(self.t_env, model_path)
+        output = model.transform(self.input_data)[0]
+
+        self.verify_prediction_result(self.expected_output, output)