You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by zh...@apache.org on 2022/08/01 03:55:29 UTC

[flink-ml] branch master updated: [FLINK-28601] Add Transformer for FeatureHasher

This is an automated email from the ASF dual-hosted git repository.

zhangzp pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git


The following commit(s) were added to refs/heads/master by this push:
     new 41cccd1  [FLINK-28601] Add Transformer for FeatureHasher
41cccd1 is described below

commit 41cccd144cf8bfdb219b8e6a736697f8a04ad5a8
Author: weibo <wb...@pku.edu.cn>
AuthorDate: Mon Aug 1 11:55:25 2022 +0800

    [FLINK-28601] Add Transformer for FeatureHasher
    
    This closes #133.
---
 .../ml/examples/feature/FeatureHasherExample.java  |  71 +++++++
 .../flink/ml/common/param/HasCategoricalCols.java  |  42 +++++
 .../flink/ml/common/param/HasNumFeatures.java      |  43 +++++
 .../ml/feature/featurehasher/FeatureHasher.java    | 209 +++++++++++++++++++++
 .../feature/featurehasher/FeatureHasherParams.java |  32 ++++
 .../apache/flink/ml/feature/FeatureHasherTest.java | 152 +++++++++++++++
 ...sembler_example.py => featurehasher_example.py} |  41 ++--
 .../examples/ml/feature/vectorassembler_example.py |   2 +-
 .../pyflink/ml/lib/feature/featurehasher.py        |  67 +++++++
 .../ml/lib/feature/tests/test_feature_hasher.py    |  77 ++++++++
 flink-ml-python/pyflink/ml/lib/param.py            |  42 +++++
 11 files changed, 754 insertions(+), 24 deletions(-)

diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
new file mode 100644
index 0000000..c0f81c6
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/FeatureHasherExample.java
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
+import org.apache.flink.ml.linalg.Vector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+
+/** Simple program that creates a FeatureHasher instance and uses it for feature engineering. */
+public class FeatureHasherExample {
+    public static void main(String[] args) {
+
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+        // Generates input data.
+        DataStream<Row> dataStream =
+                env.fromCollection(
+                        Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false)));
+        Table inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");
+
+        // Creates a FeatureHasher object and initializes its parameters.
+        FeatureHasher featureHash =
+                new FeatureHasher()
+                        .setInputCols("f0", "f1", "f2")
+                        .setCategoricalCols("f0", "f2")
+                        .setOutputCol("vec")
+                        .setNumFeatures(1000);
+
+        // Uses the FeatureHasher object for feature transformations.
+        Table outputTable = featureHash.transform(inputDataTable)[0];
+
+        // Extracts and displays the results.
+        for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+            Row row = it.next();
+
+            Object[] inputValues = new Object[featureHash.getInputCols().length];
+            for (int i = 0; i < inputValues.length; i++) {
+                inputValues[i] = row.getField(featureHash.getInputCols()[i]);
+            }
+            Vector outputValue = (Vector) row.getField(featureHash.getOutputCol());
+
+            System.out.printf(
+                    "Input Values: %s \tOutput Value: %s\n",
+                    Arrays.toString(inputValues), outputValue);
+        }
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java
new file mode 100644
index 0000000..fcc340b
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasCategoricalCols.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.StringArrayParam;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared categoricalCols param. */
+public interface HasCategoricalCols<T> extends WithParams<T> {
+    Param<String[]> CATEGORICAL_COLS =
+            new StringArrayParam(
+                    "categoricalCols",
+                    "Categorical column names.",
+                    new String[] {},
+                    ParamValidators.notNull());
+
+    default String[] getCategoricalCols() {
+        return get(CATEGORICAL_COLS);
+    }
+
+    default T setCategoricalCols(String... value) {
+        return set(CATEGORICAL_COLS, value);
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java
new file mode 100644
index 0000000..7d8b699
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/common/param/HasNumFeatures.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.common.param;
+
+import org.apache.flink.ml.param.IntParam;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.param.ParamValidators;
+import org.apache.flink.ml.param.WithParams;
+
+/** Interface for the shared num features param. */
+public interface HasNumFeatures<T> extends WithParams<T> {
+    Param<Integer> NUM_FEATURES =
+            new IntParam(
+                    "numFeatures",
+                    "The number of features. It will be the length of the output vector.",
+                    262144,
+                    ParamValidators.gt(0));
+
+    default int getNumFeatures() {
+        return get(NUM_FEATURES);
+    }
+
+    default T setNumFeatures(int value) {
+        set(NUM_FEATURES, value);
+        return (T) this;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java
new file mode 100644
index 0000000..3ad50b7
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasher.java
@@ -0,0 +1,209 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.featurehasher;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.typeinfo.VectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.DataTypes;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.table.types.DataType;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.TreeMap;
+
+import static org.apache.flink.shaded.guava30.com.google.common.hash.Hashing.murmur3_32;
+
+/**
+ * A Transformer that transforms a set of categorical or numerical features into a sparse vector of
+ * a specified dimension. The rules of hashing categorical columns and numerical columns are as
+ * follows:
+ *
+ * <ul>
+ *   <li>For numerical columns, the index of this feature in the output vector is the hash value of
+ *       the column name and its corresponding value is the same as the input.
+ *   <li>For categorical columns, the index of this feature in the output vector is the hash value
+ *       of the string "column_name=value" and the corresponding value is 1.0.
+ * </ul>
+ *
+ * <p>If multiple features are projected into the same column, the output values are accumulated.
+ * For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
+ */
+public class FeatureHasher
+        implements Transformer<FeatureHasher>, FeatureHasherParams<FeatureHasher> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+    private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction HASH =
+            murmur3_32(0);
+
+    public FeatureHasher() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+        ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(inputTypeInfo.getFieldTypes(), VectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new HashFunction(
+                                        getInputCols(),
+                                        generateCategoricalCols(
+                                                tableSchema, getInputCols(), getCategoricalCols()),
+                                        getNumFeatures()),
+                                outputTypeInfo);
+        Table outputTable = tEnv.fromDataStream(output);
+        return new Table[] {outputTable};
+    }
+
+    /**
+     * The main logic for transforming the categorical and numerical features into a sparse vector.
+     * It uses MurmurHash3 to compute the transformed index in the output vector. If multiple
+     * features are projected to the same column, their values are accumulated.
+     */
+    private static class HashFunction implements MapFunction<Row, Row> {
+        private final String[] categoricalCols;
+        private final int numFeatures;
+        private final String[] numericCols;
+
+        public HashFunction(String[] inputCols, String[] categoricalCols, int numFeatures) {
+            this.categoricalCols = categoricalCols;
+            this.numFeatures = numFeatures;
+            this.numericCols = ArrayUtils.removeElements(inputCols, this.categoricalCols);
+        }
+
+        @Override
+        public Row map(Row row) {
+            TreeMap<Integer, Double> feature = new TreeMap<>();
+            for (String col : numericCols) {
+                if (null != row.getField(col)) {
+                    double value = ((Number) row.getFieldAs(col)).doubleValue();
+                    updateMap(col, value, feature, numFeatures);
+                }
+            }
+            for (String col : categoricalCols) {
+                if (null != row.getField(col)) {
+                    updateMap(col + "=" + row.getField(col), 1.0, feature, numFeatures);
+                }
+            }
+            int nnz = feature.size();
+            int[] indices = new int[nnz];
+            double[] values = new double[nnz];
+            int pos = 0;
+            for (Map.Entry<Integer, Double> entry : feature.entrySet()) {
+                indices[pos] = entry.getKey();
+                values[pos] = entry.getValue();
+                pos++;
+            }
+            return Row.join(row, Row.of(new SparseVector(numFeatures, indices, values)));
+        }
+    }
+
+    private String[] generateCategoricalCols(
+            ResolvedSchema tableSchema, String[] inputCols, String[] categoricalCols) {
+        if (null == inputCols) {
+            return categoricalCols;
+        }
+        List<String> categoricalList = Arrays.asList(categoricalCols);
+        List<String> inputList = Arrays.asList(inputCols);
+        if (categoricalCols.length > 0 && !inputList.containsAll(categoricalList)) {
+            throw new IllegalArgumentException("CategoricalCols must be included in inputCols!");
+        }
+        List<DataType> dataColTypes = tableSchema.getColumnDataTypes();
+        List<String> dataColNames = tableSchema.getColumnNames();
+        List<DataType> inputColTypes = new ArrayList<>();
+        for (String col : inputCols) {
+            for (int i = 0; i < dataColNames.size(); ++i) {
+                if (col.equals(dataColNames.get(i))) {
+                    inputColTypes.add(dataColTypes.get(i));
+                    break;
+                }
+            }
+        }
+        List<String> resultColList = new ArrayList<>();
+        for (int i = 0; i < inputCols.length; i++) {
+            boolean included = categoricalList.contains(inputCols[i]);
+            if (included
+                    || DataTypes.BOOLEAN().equals(inputColTypes.get(i))
+                    || DataTypes.STRING().equals(inputColTypes.get(i))) {
+                resultColList.add(inputCols[i]);
+            }
+        }
+        return resultColList.toArray(new String[0]);
+    }
+
+    /**
+     * Updates the treeMap that stores the key-value pairs of the final vector, using the hashed
+     * index of the string as the key and accumulating the corresponding value.
+     *
+     * @param s the string to hash
+     * @param value the accumulated value
+     */
+    private static void updateMap(
+            String s, double value, TreeMap<Integer, Double> feature, int numFeature) {
+        int hashValue = Math.abs(HASH.hashUnencodedChars(s).asInt());
+
+        int index = Math.floorMod(hashValue, numFeature);
+        if (feature.containsKey(index)) {
+            feature.put(index, feature.get(index) + value);
+        } else {
+            feature.put(index, value);
+        }
+    }
+
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    public static FeatureHasher load(StreamTableEnvironment env, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java
new file mode 100644
index 0000000..997d2d6
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/featurehasher/FeatureHasherParams.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.featurehasher;
+
+import org.apache.flink.ml.common.param.HasCategoricalCols;
+import org.apache.flink.ml.common.param.HasInputCols;
+import org.apache.flink.ml.common.param.HasNumFeatures;
+import org.apache.flink.ml.common.param.HasOutputCol;
+
+/**
+ * Params of {@link FeatureHasher}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface FeatureHasherParams<T>
+        extends HasInputCols<T>, HasOutputCol<T>, HasCategoricalCols<T>, HasNumFeatures<T> {}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
new file mode 100644
index 0000000..941b3b1
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/FeatureHasherTest.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.featurehasher.FeatureHasher;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.List;
+
+import static org.junit.Assert.assertArrayEquals;
+import static org.junit.Assert.assertEquals;
+
+/** Tests {@link FeatureHasher}. */
+public class FeatureHasherTest extends AbstractTestBase {
+
+    private StreamTableEnvironment tEnv;
+    private Table inputDataTable;
+
+    private static final List<Row> INPUT_DATA =
+            Arrays.asList(Row.of(0, "a", 1.0, true), Row.of(1, "c", 1.0, false));
+
+    private static final SparseVector EXPECTED_OUTPUT_DATA_1 =
+            Vectors.sparse(1000, new int[] {607, 635, 913}, new double[] {1.0, 1.0, 1.0});
+    private static final SparseVector EXPECTED_OUTPUT_DATA_2 =
+            Vectors.sparse(1000, new int[] {242, 869, 913}, new double[] {1.0, 1.0, 1.0});
+
+    @Before
+    public void before() {
+        Configuration config = new Configuration();
+        config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+        env.setParallelism(4);
+        env.enableCheckpointing(100);
+        env.setRestartStrategy(RestartStrategies.noRestart());
+        tEnv = StreamTableEnvironment.create(env);
+        DataStream<Row> dataStream = env.fromCollection(INPUT_DATA);
+        inputDataTable = tEnv.fromDataStream(dataStream).as("id", "f0", "f1", "f2");
+    }
+
+    private void verifyOutputResult(Table output, String outputCol) throws Exception {
+        DataStream<Row> dataStream = tEnv.toDataStream(output);
+        List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+        assertEquals(2, results.size());
+        for (Row result : results) {
+            if (result.getField(0) == (Object) 0) {
+                assertEquals(EXPECTED_OUTPUT_DATA_1, result.getField(outputCol));
+            } else if (result.getField(0) == (Object) 1) {
+                assertEquals(EXPECTED_OUTPUT_DATA_2, result.getField(outputCol));
+            } else {
+                throw new RuntimeException("unknown output value.");
+            }
+        }
+    }
+
+    @Test
+    public void testParam() {
+        FeatureHasher featureHasher = new FeatureHasher();
+        assertEquals("output", featureHasher.getOutputCol());
+        assertArrayEquals(new String[] {}, featureHasher.getCategoricalCols());
+        assertEquals(262144, featureHasher.getNumFeatures());
+        featureHasher
+                .setInputCols("f0", "f1", "f2")
+                .setOutputCol("vec")
+                .setCategoricalCols("f0", "f2")
+                .setNumFeatures(1000);
+        assertArrayEquals(new String[] {"f0", "f1", "f2"}, featureHasher.getInputCols());
+        assertEquals("vec", featureHasher.getOutputCol());
+        assertArrayEquals(new String[] {"f0", "f2"}, featureHasher.getCategoricalCols());
+        assertEquals(1000, featureHasher.getNumFeatures());
+    }
+
+    @Test
+    public void testSaveLoadAndTransform() throws Exception {
+        FeatureHasher featureHash =
+                new FeatureHasher()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCol("vec")
+                        .setCategoricalCols("f0", "f2")
+                        .setNumFeatures(1000);
+        FeatureHasher loadedFeatureHasher =
+                TestUtils.saveAndReload(
+                        tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+    }
+
+    @Test
+    public void testCategoricalColsNotSet() throws Exception {
+        FeatureHasher featureHash =
+                new FeatureHasher()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCol("vec")
+                        .setNumFeatures(1000);
+        FeatureHasher loadedFeatureHasher =
+                TestUtils.saveAndReload(
+                        tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+    }
+
+    @Test
+    public void testInputTypeConversion() throws Exception {
+        inputDataTable = TestUtils.convertDataTypesToSparseInt(tEnv, inputDataTable);
+        assertArrayEquals(
+                new Class<?>[] {Integer.class, String.class, Integer.class, Boolean.class},
+                TestUtils.getColumnDataTypes(inputDataTable));
+
+        FeatureHasher featureHash =
+                new FeatureHasher()
+                        .setInputCols("f0", "f1", "f2")
+                        .setOutputCol("vec")
+                        .setCategoricalCols("f0", "f2")
+                        .setNumFeatures(1000);
+        FeatureHasher loadedFeatureHasher =
+                TestUtils.saveAndReload(
+                        tEnv, featureHash, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+        Table output = loadedFeatureHasher.transform(inputDataTable)[0];
+        verifyOutputResult(output, loadedFeatureHasher.getOutputCol());
+    }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
similarity index 61%
copy from flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
copy to flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
index 8f45b17..582429c 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/featurehasher_example.py
@@ -16,7 +16,7 @@
 # limitations under the License.
 ################################################################################
 
-# Simple program that creates a VectorAssembler instance and uses it for feature
+# Simple program that creates a FeatureHasher instance and uses it for feature
 # engineering.
 #
 # Before executing this program, please make sure you have followed Flink ML's
@@ -27,8 +27,7 @@
 
 from pyflink.common import Types
 from pyflink.datastream import StreamExecutionEnvironment
-from pyflink.ml.core.linalg import Vectors, DenseVectorTypeInfo, SparseVectorTypeInfo
-from pyflink.ml.lib.feature.vectorassembler import VectorAssembler
+from pyflink.ml.lib.feature.featurehasher import FeatureHasher
 from pyflink.table import StreamTableEnvironment
 
 # create a new StreamExecutionEnvironment
@@ -40,32 +39,28 @@ t_env = StreamTableEnvironment.create(env)
 # generate input data
 input_data_table = t_env.from_data_stream(
     env.from_collection([
-        (Vectors.dense(2.1, 3.1),
-         1.0,
-         Vectors.sparse(5, [3], [1.0])),
-        (Vectors.dense(2.1, 3.1),
-         1.0,
-         Vectors.sparse(5, [1, 2, 3, 4],
-                        [1.0, 2.0, 3.0, 4.0])),
+        (0, 'a', 1.0, True),
+        (1, 'c', 1.0, False),
     ],
         type_info=Types.ROW_NAMED(
-            ['vec', 'num', 'sparse_vec'],
-            [DenseVectorTypeInfo(), Types.DOUBLE(), SparseVectorTypeInfo()])))
+            ['id', 'f0', 'f1', 'f2'],
+            [Types.INT(), Types.STRING(), Types.DOUBLE(), Types.BOOLEAN()])))
 
-# create a vector assembler object and initialize its parameters
-vector_assembler = VectorAssembler() \
-    .set_input_cols('vec', 'num', 'sparse_vec') \
-    .set_output_col('assembled_vec') \
-    .set_handle_invalid('keep')
+# create a feature hasher object and initialize its parameters
+feature_hasher = FeatureHasher() \
+    .set_input_cols('f0', 'f1', 'f2') \
+    .set_categorical_cols('f0', 'f2') \
+    .set_output_col('vec') \
+    .set_num_features(1000)
 
-# use the vector assembler model for feature engineering
-output = vector_assembler.transform(input_data_table)[0]
+# use the feature hasher for feature engineering
+output = feature_hasher.transform(input_data_table)[0]
 
 # extract and display the results
 field_names = output.get_schema().get_field_names()
-input_values = [None for _ in vector_assembler.get_input_cols()]
+input_values = [None for _ in feature_hasher.get_input_cols()]
 for result in t_env.to_data_stream(output).execute_and_collect():
-    for i in range(len(vector_assembler.get_input_cols())):
-        input_values[i] = result[field_names.index(vector_assembler.get_input_cols()[i])]
-    output_value = result[field_names.index(vector_assembler.get_output_col())]
+    for i in range(len(feature_hasher.get_input_cols())):
+        input_values[i] = result[field_names.index(feature_hasher.get_input_cols()[i])]
+    output_value = result[field_names.index(feature_hasher.get_output_col())]
     print('Input Values: ' + str(input_values) + '\tOutput Value: ' + str(output_value))
diff --git a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
index 8f45b17..7ae15ce 100644
--- a/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
+++ b/flink-ml-python/pyflink/examples/ml/feature/vectorassembler_example.py
@@ -58,7 +58,7 @@ vector_assembler = VectorAssembler() \
     .set_output_col('assembled_vec') \
     .set_handle_invalid('keep')
 
-# use the vector assembler model for feature engineering
+# use the vector assembler for feature engineering
 output = vector_assembler.transform(input_data_table)[0]
 
 # extract and display the results
diff --git a/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py b/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py
new file mode 100644
index 0000000..150cf1f
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/featurehasher.py
@@ -0,0 +1,67 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCols, HasOutputCol, HasCategoricalCols, HasNumFeatures
+
+
+class _FeatureHasherParams(
+    JavaWithParams,
+    HasInputCols,
+    HasCategoricalCols,
+    HasOutputCol,
+    HasNumFeatures
+):
+    """
+    Params for :class:`FeatureHasher`.
+    """
+
+    def __init__(self, java_params):
+        super(_FeatureHasherParams, self).__init__(java_params)
+
+
+class FeatureHasher(JavaFeatureTransformer, _FeatureHasherParams):
+    """
+    A Transformer that transforms a set of categorical or numerical features into
+    a sparse vector of a specified dimension. The rules of hashing categorical
+    columns and numerical columns are as follows:
+
+    For numerical columns, the index of this feature in the output vector is the
+    hash value of the column name and its corresponding value is the same as the
+    input.
+
+    For categorical columns, the index of this feature in the output vector is
+    the hash value of the string "column_name=value" and the corresponding
+    value is 1.0.
+
+    If multiple features are projected into the same column, the output values
+    are accumulated. For the hashing trick, see
+    https://en.wikipedia.org/wiki/Feature_hashing for details.
+    """
+
+    def __init__(self, java_model=None):
+        super(FeatureHasher, self).__init__(java_model)
+
+    @classmethod
+    def _java_transformer_package_name(cls) -> str:
+        return "featurehasher"
+
+    @classmethod
+    def _java_transformer_class_name(cls) -> str:
+        return "FeatureHasher"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py
new file mode 100644
index 0000000..7f32128
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_feature_hasher.py
@@ -0,0 +1,77 @@
+################################################################################
+#  Licensed to the Apache Software Foundation (ASF) under one
+#  or more contributor license agreements.  See the NOTICE file
+#  distributed with this work for additional information
+#  regarding copyright ownership.  The ASF licenses this file
+#  to you under the Apache License, Version 2.0 (the
+#  "License"); you may not use this file except in compliance
+#  with the License.  You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+#  Unless required by applicable law or agreed to in writing, software
+#  distributed under the License is distributed on an "AS IS" BASIS,
+#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+#  See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors
+from pyflink.ml.lib.feature.featurehasher import FeatureHasher
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class FeatureHasherTest(PyFlinkMLTestCase):
+    def setUp(self):
+        super(FeatureHasherTest, self).setUp()
+        self.input_data_table = self.t_env.from_data_stream(
+            self.env.from_collection([
+                (0, 'a', 1.0, True),
+                (1, 'c', 1.0, False)
+            ],
+                type_info=Types.ROW_NAMED(
+                    ['id', 'f0', 'f1', 'f2'],
+                    [Types.INT(), Types.STRING(), Types.DOUBLE(), Types.BOOLEAN()])))
+
+        self.expected_output_data_1 = Vectors.sparse(1000, [607, 635, 913], [1.0, 1.0, 1.0])
+        self.expected_output_data_2 = Vectors.sparse(1000, [242, 869, 913], [1.0, 1.0, 1.0])
+
+    def test_param(self):
+        feature_hasher = FeatureHasher()
+
+        self.assertEqual('output', feature_hasher.output_col)
+        self.assertEqual(262144, feature_hasher.num_features)
+
+        feature_hasher.set_input_cols('f0', 'f1', 'f2') \
+            .set_categorical_cols('f0', 'f2') \
+            .set_output_col('vec') \
+            .set_num_features(1000)
+
+        self.assertEqual(('f0', 'f1', 'f2'), feature_hasher.input_cols)
+        self.assertEqual(('f0', 'f2'), feature_hasher.categorical_cols)
+        self.assertEqual(1000, feature_hasher.num_features)
+        self.assertEqual('vec', feature_hasher.output_col)
+
+    def test_save_load_transform(self):
+        feature_hasher = FeatureHasher() \
+            .set_input_cols('f0', 'f1', 'f2') \
+            .set_categorical_cols('f0', 'f2') \
+            .set_output_col('vec') \
+            .set_num_features(1000)
+
+        path = os.path.join(self.temp_dir, 'test_save_load_transform_feature_hasher')
+        feature_hasher.save(path)
+        feature_hasher = FeatureHasher.load(self.t_env, path)
+
+        output_table = feature_hasher.transform(self.input_data_table)[0]
+        actual_outputs = [(result[0], result[4]) for result in
+                          self.t_env.to_data_stream(output_table).execute_and_collect()]
+        self.assertEqual(2, len(actual_outputs))
+        for actual_output in actual_outputs:
+            if actual_output[0] == 0:
+                self.assertEqual(self.expected_output_data_1, actual_output[1])
+            else:
+                self.assertEqual(self.expected_output_data_2, actual_output[1])
diff --git a/flink-ml-python/pyflink/ml/lib/param.py b/flink-ml-python/pyflink/ml/lib/param.py
index 64cd5ac..39c6277 100644
--- a/flink-ml-python/pyflink/ml/lib/param.py
+++ b/flink-ml-python/pyflink/ml/lib/param.py
@@ -155,6 +155,48 @@ class HasInputCols(WithParams, ABC):
         return self.get_input_cols()
 
 
+class HasCategoricalCols(WithParams, ABC):
+    """
+    Base class for the shared categorical cols param.
+    """
+    CATEGORICAL_COLS: Param[Tuple[str, ...]] = StringArrayParam(
+        "categorical_cols",
+        "Categorical column names.",
+        [],
+        ParamValidators.not_null())
+
+    def set_categorical_cols(self, *cols: str):
+        return self.set(self.CATEGORICAL_COLS, cols)
+
+    def get_categorical_cols(self) -> Tuple[str, ...]:
+        return self.get(self.CATEGORICAL_COLS)
+
+    @property
+    def categorical_cols(self) -> Tuple[str, ...]:
+        return self.get_categorical_cols()
+
+
+class HasNumFeatures(WithParams, ABC):
+    """
+    Base class for the shared numFeatures param.
+    """
+    NUM_FEATURES: Param[int] = IntParam(
+        "num_features",
+        "Number of features.",
+        262144,
+        ParamValidators.gt(0))
+
+    def set_num_features(self, num_features: int):
+        return self.set(self.NUM_FEATURES, num_features)
+
+    def get_num_features(self) -> int:
+        return self.get(self.NUM_FEATURES)
+
+    @property
+    def num_features(self) -> int:
+        return self.get_num_features()
+
+
 class HasLabelCol(WithParams, ABC):
     """
     Base class for the shared label column param.