You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@flink.apache.org by li...@apache.org on 2022/08/24 03:11:08 UTC
[flink-ml] branch master updated: [FLINK-28805] Add Transformer for HashingTF
This is an automated email from the ASF dual-hosted git repository.
lindong pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/flink-ml.git
The following commit(s) were added to refs/heads/master by this push:
new 3e4f5d1 [FLINK-28805] Add Transformer for HashingTF
3e4f5d1 is described below
commit 3e4f5d1d7bf6e7ee83233afe6d45d8958aac7908
Author: Zhipeng Zhang <zh...@gmail.com>
AuthorDate: Wed Aug 24 11:11:04 2022 +0800
[FLINK-28805] Add Transformer for HashingTF
This closes #141.
---
.../ml/examples/feature/HashingTFExample.java | 70 ++++++++
.../flink/ml/feature/hashingtf/HashingTF.java | 198 +++++++++++++++++++++
.../ml/feature/hashingtf/HashingTFParams.java | 57 ++++++
.../org/apache/flink/ml/feature/HashingTFTest.java | 186 +++++++++++++++++++
.../examples/ml/feature/hashingtf_example.py | 61 +++++++
.../pyflink/ml/lib/feature/hashingtf.py | 86 +++++++++
.../pyflink/ml/lib/feature/tests/test_hashingtf.py | 115 ++++++++++++
7 files changed, 773 insertions(+)
diff --git a/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
new file mode 100644
index 0000000..213ebd5
--- /dev/null
+++ b/flink-ml-examples/src/main/java/org/apache/flink/ml/examples/feature/HashingTFExample.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.examples.feature;
+
+import org.apache.flink.ml.feature.hashingtf.HashingTF;
+import org.apache.flink.ml.linalg.SparseVector;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.CloseableIterator;
+
+import java.util.Arrays;
+import java.util.List;
+
+/** Simple program that creates a HashingTF instance and uses it for feature engineering. */
+public class HashingTFExample {
+ public static void main(String[] args) {
+ StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
+ StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
+
+ // Generates input data.
+ DataStream<Row> inputStream =
+ env.fromElements(
+ Row.of(
+ Arrays.asList(
+ "HashingTFTest", "Hashing", "Term", "Frequency", "Test")),
+ Row.of(
+ Arrays.asList(
+ "HashingTFTest", "Hashing", "Hashing", "Test", "Test")));
+
+ Table inputTable = tEnv.fromDataStream(inputStream).as("input");
+
+ // Creates a HashingTF object and initializes its parameters.
+ HashingTF hashingTF =
+ new HashingTF().setInputCol("input").setOutputCol("output").setNumFeatures(128);
+
+ // Uses the HashingTF object for feature transformations.
+ Table outputTable = hashingTF.transform(inputTable)[0];
+
+ // Extracts and displays the results.
+ for (CloseableIterator<Row> it = outputTable.execute().collect(); it.hasNext(); ) {
+ Row row = it.next();
+
+ List<Object> inputValue = (List<Object>) row.getField(hashingTF.getInputCol());
+ SparseVector outputValue = (SparseVector) row.getField(hashingTF.getOutputCol());
+
+ System.out.printf(
+ "Input Value: %s \tOutput Value: %s\n",
+ Arrays.toString(inputValue.stream().toArray()), outputValue);
+ }
+ }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java
new file mode 100644
index 0000000..3920198
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTF.java
@@ -0,0 +1,198 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.hashingtf;
+
+import org.apache.flink.api.common.functions.MapFunction;
+import org.apache.flink.api.java.typeutils.RowTypeInfo;
+import org.apache.flink.ml.api.Transformer;
+import org.apache.flink.ml.common.datastream.TableUtils;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.linalg.typeinfo.SparseVectorTypeInfo;
+import org.apache.flink.ml.param.Param;
+import org.apache.flink.ml.util.ParamUtils;
+import org.apache.flink.ml.util.ReadWriteUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.table.api.internal.TableImpl;
+import org.apache.flink.table.catalog.ResolvedSchema;
+import org.apache.flink.types.Row;
+import org.apache.flink.util.Preconditions;
+
+import org.apache.commons.lang3.ArrayUtils;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+
+import static org.apache.flink.shaded.guava30.com.google.common.hash.Hashing.murmur3_32;
+
+/**
+ * A Transformer that maps a sequence of terms (strings, numbers, booleans) to a sparse vector with
+ * a specified dimension using the hashing trick.
+ *
+ * <p>If multiple features are projected into the same column, the output values are accumulated by
+ * default. Users could also enforce all non-zero output values as 1 by setting {@link
+ * HashingTFParams#BINARY} as true.
+ *
+ * <p>For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
+ */
+public class HashingTF implements Transformer<HashingTF>, HashingTFParams<HashingTF> {
+    private final Map<Param<?>, Object> paramMap = new HashMap<>();
+
+    // 32-bit Murmur3 with a fixed seed of 0; every term is hashed with this function.
+    private static final org.apache.flink.shaded.guava30.com.google.common.hash.HashFunction
+            HASH_FUNC = murmur3_32(0);
+
+    public HashingTF() {
+        ParamUtils.initializeMapWithDefaultValues(paramMap, this);
+    }
+
+    /**
+     * Appends a sparse vector column, named by the output column param, to the input table.
+     *
+     * @param inputs exactly one table that contains the input column
+     * @return a single-element array whose table holds the input columns followed by the output
+     *     column
+     * @throws IllegalArgumentException if the number of input tables is not one
+     */
+    @Override
+    public Table[] transform(Table... inputs) {
+        Preconditions.checkArgument(inputs.length == 1);
+        StreamTableEnvironment tEnv =
+                (StreamTableEnvironment) ((TableImpl) inputs[0]).getTableEnvironment();
+
+        ResolvedSchema tableSchema = inputs[0].getResolvedSchema();
+
+        // The output row type is the input row type plus one trailing sparse vector field.
+        RowTypeInfo inputTypeInfo = TableUtils.getRowTypeInfo(tableSchema);
+        RowTypeInfo outputTypeInfo =
+                new RowTypeInfo(
+                        ArrayUtils.addAll(
+                                inputTypeInfo.getFieldTypes(), SparseVectorTypeInfo.INSTANCE),
+                        ArrayUtils.addAll(inputTypeInfo.getFieldNames(), getOutputCol()));
+
+        DataStream<Row> output =
+                tEnv.toDataStream(inputs[0])
+                        .map(
+                                new HashTFFunction(getInputCol(), getBinary(), getNumFeatures()),
+                                outputTypeInfo);
+        return new Table[] {tEnv.fromDataStream(output)};
+    }
+
+    /** Saves the metadata of this stage to the given path. */
+    @Override
+    public void save(String path) throws IOException {
+        ReadWriteUtils.saveMetadata(this, path);
+    }
+
+    @Override
+    public Map<Param<?>, Object> getParamMap() {
+        return paramMap;
+    }
+
+    /**
+     * Loads a HashingTF from the given path.
+     *
+     * @param tEnv not used by this method
+     * @param path the path that a HashingTF was previously saved to
+     */
+    public static HashingTF load(StreamTableEnvironment tEnv, String path) throws IOException {
+        return ReadWriteUtils.loadStageParam(path);
+    }
+
+    /** The main logic of {@link HashingTF}, which converts the input to a sparse vector. */
+    public static class HashTFFunction implements MapFunction<Row, Row> {
+        private final String inputCol;
+        private final boolean binary;
+        private final int numFeatures;
+
+        public HashTFFunction(String inputCol, boolean binary, int numFeatures) {
+            this.inputCol = inputCol;
+            this.binary = binary;
+            this.numFeatures = numFeatures;
+        }
+
+        /**
+         * Hashes every term of the input column and appends the resulting sparse vector.
+         *
+         * @param row a row whose input column holds an array or an Iterable of terms
+         * @return the input row joined with a one-field row that holds the sparse vector
+         * @throws NullPointerException if the input column is null
+         * @throws IllegalArgumentException if the input column is neither an array nor an
+         *     Iterable
+         */
+        @Override
+        public Row map(Row row) throws Exception {
+            // Fail with a message that names the column instead of a bare NullPointerException.
+            Object inputObj =
+                    Preconditions.checkNotNull(
+                            row.getField(inputCol),
+                            "Input column " + inputCol + " must not be null.");
+
+            // Maps a vector index to the number of terms hashed to it (always 1 when binary).
+            Map<Integer, Integer> termCounts = new HashMap<>();
+            for (Object term : toTerms(inputObj, inputCol)) {
+                int index = nonNegativeMod(hash(term), numFeatures);
+                if (binary) {
+                    termCounts.put(index, 1);
+                } else {
+                    termCounts.merge(index, 1, Integer::sum);
+                }
+            }
+
+            // Converts from map to a sparse vector.
+            int[] indices = new int[termCounts.size()];
+            double[] values = new double[termCounts.size()];
+            int idx = 0;
+            for (Map.Entry<Integer, Integer> entry : termCounts.entrySet()) {
+                indices[idx] = entry.getKey();
+                values[idx] = entry.getValue();
+                idx++;
+            }
+            return Row.join(row, Row.of(Vectors.sparse(numFeatures, indices, values)));
+        }
+
+        /** Views the value of the input column as a sequence of terms. */
+        private static Iterable<?> toTerms(Object inputObj, String inputCol) {
+            if (inputObj instanceof Object[]) {
+                return Arrays.asList((Object[]) inputObj);
+            }
+            if (inputObj.getClass().isArray()) {
+                // Primitive arrays such as int[] cannot be cast to Object[], so each element is
+                // boxed through reflection.
+                int length = java.lang.reflect.Array.getLength(inputObj);
+                Object[] boxed = new Object[length];
+                for (int i = 0; i < length; i++) {
+                    boxed[i] = java.lang.reflect.Array.get(inputObj, i);
+                }
+                return Arrays.asList(boxed);
+            }
+            if (inputObj instanceof Iterable) {
+                return (Iterable<?>) inputObj;
+            }
+            throw new IllegalArgumentException(
+                    "Input format "
+                            + inputObj.getClass().getCanonicalName()
+                            + " is not supported for input column "
+                            + inputCol
+                            + ". Supported options are Array and Iterable.");
+        }
+    }
+
+    /**
+     * Hashes a single term. Null hashes to 0; booleans, bytes, shorts, ints and floats are hashed
+     * as ints, longs and doubles as longs, and strings by their unencoded chars.
+     *
+     * @throws UnsupportedOperationException if the term has any other type
+     */
+    private static int hash(Object obj) {
+        if (obj == null) {
+            return 0;
+        } else if (obj instanceof Boolean) {
+            int value = (Boolean) obj ? 1 : 0;
+            return HASH_FUNC.hashInt(value).asInt();
+        } else if (obj instanceof Byte) {
+            byte value = (Byte) obj;
+            return HASH_FUNC.hashInt(value).asInt();
+        } else if (obj instanceof Short) {
+            short value = (Short) obj;
+            return HASH_FUNC.hashInt(value).asInt();
+        } else if (obj instanceof Integer) {
+            int value = (Integer) obj;
+            return HASH_FUNC.hashInt(value).asInt();
+        } else if (obj instanceof Long) {
+            long value = (Long) obj;
+            return HASH_FUNC.hashLong(value).asInt();
+        } else if (obj instanceof Float) {
+            float value = (Float) obj;
+            return HASH_FUNC.hashInt(Float.floatToIntBits(value)).asInt();
+        } else if (obj instanceof Double) {
+            double value = (Double) obj;
+            return HASH_FUNC.hashLong(Double.doubleToLongBits(value)).asInt();
+        } else if (obj instanceof String) {
+            return HASH_FUNC.hashUnencodedChars((String) obj).asInt();
+        } else {
+            throw new UnsupportedOperationException(
+                    "HashingTF does not support type "
+                            + obj.getClass().getCanonicalName()
+                            + " of input data.");
+        }
+    }
+
+    /** Returns {@code x % mod}, shifted by {@code mod} when the remainder is negative. */
+    private static int nonNegativeMod(int x, int mod) {
+        int rawMod = x % mod;
+        return rawMod < 0 ? rawMod + mod : rawMod;
+    }
+}
diff --git a/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTFParams.java b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTFParams.java
new file mode 100644
index 0000000..f4a4df9
--- /dev/null
+++ b/flink-ml-lib/src/main/java/org/apache/flink/ml/feature/hashingtf/HashingTFParams.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature.hashingtf;
+
+import org.apache.flink.ml.common.param.HasInputCol;
+import org.apache.flink.ml.common.param.HasNumFeatures;
+import org.apache.flink.ml.common.param.HasOutputCol;
+import org.apache.flink.ml.param.BooleanParam;
+import org.apache.flink.ml.param.Param;
+
+/**
+ * Params of {@link HashingTF}.
+ *
+ * @param <T> The class type of this instance.
+ */
+public interface HashingTFParams<T> extends HasInputCol<T>, HasOutputCol<T>, HasNumFeatures<T> {
+
+ /**
+ * Supported options to decide whether each dimension of the output vector is binary or not.
+ *
+ * <ul>
+ * <li>true: the value at one dimension is set to 1 if any feature is hashed to this column.
+ * <li>false: the value at one dimension is set to the number of features that have been
+ * hashed to this column.
+ * </ul>
+ *
+ * <p>The default value is false.
+ */
+ Param<Boolean> BINARY =
+ new BooleanParam(
+ "binary",
+ "Whether each dimension of the output vector is binary or not.",
+ false);
+
+ /** Returns the current value of {@link #BINARY}. */
+ default boolean getBinary() {
+ return get(BINARY);
+ }
+
+ /** Sets {@link #BINARY} to the given value and returns the result of {@code set}. */
+ default T setBinary(boolean value) {
+ return set(BINARY, value);
+ }
+}
diff --git a/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
new file mode 100644
index 0000000..5b170ab
--- /dev/null
+++ b/flink-ml-lib/src/test/java/org/apache/flink/ml/feature/HashingTFTest.java
@@ -0,0 +1,186 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.flink.ml.feature;
+
+import org.apache.flink.api.common.restartstrategy.RestartStrategies;
+import org.apache.flink.configuration.Configuration;
+import org.apache.flink.ml.feature.hashingtf.HashingTF;
+import org.apache.flink.ml.linalg.Vectors;
+import org.apache.flink.ml.util.TestUtils;
+import org.apache.flink.streaming.api.datastream.DataStream;
+import org.apache.flink.streaming.api.environment.ExecutionCheckpointingOptions;
+import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
+import org.apache.flink.table.api.Expressions;
+import org.apache.flink.table.api.Table;
+import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
+import org.apache.flink.test.util.AbstractTestBase;
+import org.apache.flink.types.Row;
+
+import org.apache.commons.collections.IteratorUtils;
+import org.junit.Before;
+import org.junit.Test;
+
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertTrue;
+
+/** Tests {@link HashingTF}. */
+public class HashingTFTest extends AbstractTestBase {
+ private StreamTableEnvironment tEnv;
+ private StreamExecutionEnvironment env;
+ private Table inputDataTable;
+
+ // Two term lists; the second one repeats "Hashing" and "Test".
+ private static final List<Row> INPUT =
+ Arrays.asList(
+ Row.of(Arrays.asList("HashingTFTest", "Hashing", "Term", "Frequency", "Test")),
+ Row.of(Arrays.asList("HashingTFTest", "Hashing", "Hashing", "Test", "Test")));
+
+ // Expected vectors for INPUT when a HashingTF with default params (non-binary) is used.
+ private static final List<Row> EXPECTED_OUTPUT =
+ Arrays.asList(
+ Row.of(
+ Vectors.sparse(
+ 262144,
+ new int[] {67564, 89917, 113827, 131486, 228971},
+ new double[] {1.0, 1.0, 1.0, 1.0, 1.0})),
+ Row.of(
+ Vectors.sparse(
+ 262144,
+ new int[] {67564, 131486, 228971},
+ new double[] {1.0, 2.0, 2.0})));
+
+ // Same indices as EXPECTED_OUTPUT, but every value is 1.0 because binary is enabled.
+ private static final List<Row> EXPECTED_BINARY_OUTPUT =
+ Arrays.asList(
+ Row.of(
+ Vectors.sparse(
+ 262144,
+ new int[] {67564, 89917, 113827, 131486, 228971},
+ new double[] {1.0, 1.0, 1.0, 1.0, 1.0})),
+ Row.of(
+ Vectors.sparse(
+ 262144,
+ new int[] {67564, 131486, 228971},
+ new double[] {1.0, 1.0, 1.0})));
+
+ /** Creates the execution environments and wraps INPUT into a table with one "input" column. */
+ @Before
+ public void before() {
+ Configuration config = new Configuration();
+ config.set(ExecutionCheckpointingOptions.ENABLE_CHECKPOINTS_AFTER_TASKS_FINISH, true);
+ env = StreamExecutionEnvironment.getExecutionEnvironment(config);
+ env.setParallelism(4);
+ env.enableCheckpointing(100);
+ env.setRestartStrategy(RestartStrategies.noRestart());
+ tEnv = StreamTableEnvironment.create(env);
+ DataStream<Row> dataStream = env.fromCollection(INPUT);
+ inputDataTable = tEnv.fromDataStream(dataStream).as("input");
+ }
+
+ /** Checks the default param values and that every setter is reflected by its getter. */
+ @Test
+ public void testParam() {
+ HashingTF hashingTF = new HashingTF();
+ assertEquals("input", hashingTF.getInputCol());
+ assertFalse(hashingTF.getBinary());
+ assertEquals(262144, hashingTF.getNumFeatures());
+ assertEquals("output", hashingTF.getOutputCol());
+
+ hashingTF
+ .setInputCol("testInputCol")
+ .setBinary(true)
+ .setNumFeatures(1024)
+ .setOutputCol("testOutputCol");
+
+ assertEquals("testInputCol", hashingTF.getInputCol());
+ assertTrue(hashingTF.getBinary());
+ assertEquals(1024, hashingTF.getNumFeatures());
+ assertEquals("testOutputCol", hashingTF.getOutputCol());
+ }
+
+ /** Checks that the output keeps all input columns and appends the output column last. */
+ @Test
+ public void testOutputSchema() {
+ HashingTF hashingTF = new HashingTF();
+ inputDataTable =
+ tEnv.fromDataStream(env.fromElements(Row.of(Arrays.asList(""), Arrays.asList(""))))
+ .as("input", "dummyInput");
+
+ Table output = hashingTF.transform(inputDataTable)[0];
+ assertEquals(
+ Arrays.asList(hashingTF.getInputCol(), "dummyInput", hashingTF.getOutputCol()),
+ output.getResolvedSchema().getColumnNames());
+ }
+
+ /** Checks the transformed vectors with binary disabled and then with binary enabled. */
+ @Test
+ public void testTransform() throws Exception {
+ HashingTF hashingTF = new HashingTF();
+ Table output;
+
+ // Tests non-binary.
+ output = hashingTF.transform(inputDataTable)[0];
+ verifyOutputResult(output, hashingTF.getOutputCol(), EXPECTED_OUTPUT);
+
+ // Tests binary.
+ hashingTF.setBinary(true);
+ output = hashingTF.transform(inputDataTable)[0];
+ verifyOutputResult(output, hashingTF.getOutputCol(), EXPECTED_BINARY_OUTPUT);
+ }
+
+ /** Checks that String[] input yields the same vectors as the list-based input. */
+ @Test
+ public void testTransformArrayData() throws Exception {
+ HashingTF hashingTF = new HashingTF();
+ inputDataTable =
+ tEnv.fromDataStream(
+ env.fromElements(
+ new String[] {
+ "HashingTFTest", "Hashing", "Term", "Frequency", "Test"
+ },
+ new String[] {
+ "HashingTFTest", "Hashing", "Hashing", "Test", "Test"
+ }))
+ .as("input");
+
+ Table output = hashingTF.transform(inputDataTable)[0];
+ verifyOutputResult(output, hashingTF.getOutputCol(), EXPECTED_OUTPUT);
+ }
+
+ /** Checks that a saved and reloaded HashingTF produces the expected vectors. */
+ @Test
+ public void testSaveLoadAndTransform() throws Exception {
+ HashingTF hashingTF = new HashingTF();
+ HashingTF loadedHashingTF =
+ TestUtils.saveAndReload(
+ tEnv, hashingTF, TEMPORARY_FOLDER.newFolder().getAbsolutePath());
+
+ Table output = loadedHashingTF.transform(inputDataTable)[0];
+ verifyOutputResult(output, loadedHashingTF.getOutputCol(), EXPECTED_OUTPUT);
+ }
+
+ /**
+ * Collects the given output column and compares it with the expected rows.
+ *
+ * <p>Both lists are ordered by the hash code of their first field before the element-wise
+ * comparison. Note that {@code expectedOutput} is sorted in place.
+ */
+ private void verifyOutputResult(Table output, String outputCol, List<Row> expectedOutput)
+ throws Exception {
+ DataStream<Row> dataStream = tEnv.toDataStream(output.select(Expressions.$(outputCol)));
+ List<Row> results = IteratorUtils.toList(dataStream.executeAndCollect());
+ assertEquals(expectedOutput.size(), results.size());
+
+ results.sort(Comparator.comparingInt(o -> o.getField(0).hashCode()));
+ expectedOutput.sort(Comparator.comparingInt(o -> o.getField(0).hashCode()));
+ for (int i = 0; i < expectedOutput.size(); i++) {
+ assertEquals(expectedOutput.get(i).getField(0), results.get(i).getField(0));
+ }
+ }
+}
diff --git a/flink-ml-python/pyflink/examples/ml/feature/hashingtf_example.py b/flink-ml-python/pyflink/examples/ml/feature/hashingtf_example.py
new file mode 100644
index 0000000..9352c81
--- /dev/null
+++ b/flink-ml-python/pyflink/examples/ml/feature/hashingtf_example.py
@@ -0,0 +1,61 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+
+# Simple program that creates a HashingTF instance and uses it for feature
+# engineering.
+#
+# Before executing this program, please make sure you have followed Flink ML's
+# quick start guideline to set up Flink ML and Flink environment. The guideline
+# can be found at
+#
+# https://nightlies.apache.org/flink/flink-ml-docs-master/docs/try-flink-ml/quick-start/
+
+from pyflink.common import Types
+from pyflink.datastream import StreamExecutionEnvironment
+from pyflink.ml.lib.feature.hashingtf import HashingTF
+from pyflink.table import StreamTableEnvironment
+
+env = StreamExecutionEnvironment.get_execution_environment()
+
+t_env = StreamTableEnvironment.create(env)
+
+# Generates input data: two rows, each with a single 'input' field that is
+# declared as an array of strings.
+input_data_table = t_env.from_data_stream(
+ env.from_collection([
+ (['HashingTFTest', 'Hashing', 'Term', 'Frequency', 'Test'],),
+ (['HashingTFTest', 'Hashing', 'Hashing', 'Test', 'Test'],),
+ ],
+ type_info=Types.ROW_NAMED(
+ ["input", ],
+ [Types.OBJECT_ARRAY(Types.STRING())])))
+
+# Creates a HashingTF object and initializes its parameters.
+hashing_tf = HashingTF() \
+ .set_input_col('input') \
+ .set_num_features(128) \
+ .set_output_col('output')
+
+# Uses the HashingTF object for feature transformations.
+output = hashing_tf.transform(input_data_table)[0]
+
+# Extracts and displays the results. Fields are looked up by column name so
+# that the printing does not depend on the column order of the output table.
+field_names = output.get_schema().get_field_names()
+for result in t_env.to_data_stream(output).execute_and_collect():
+ input_value = result[field_names.index(hashing_tf.get_input_col())]
+ output_value = result[field_names.index(hashing_tf.get_output_col())]
+ print('Input Value: ' + ' '.join(input_value) + '\tOutput Value: ' + str(output_value))
diff --git a/flink-ml-python/pyflink/ml/lib/feature/hashingtf.py b/flink-ml-python/pyflink/ml/lib/feature/hashingtf.py
new file mode 100644
index 0000000..32a5b7a
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/hashingtf.py
@@ -0,0 +1,86 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import typing
+
+from pyflink.ml.core.param import BooleanParam
+from pyflink.ml.core.wrapper import JavaWithParams
+from pyflink.ml.lib.feature.common import JavaFeatureTransformer
+from pyflink.ml.lib.param import HasInputCol, HasOutputCol, HasNumFeatures
+
+
+class _HashingTFParams(
+ JavaWithParams,
+ HasInputCol,
+ HasOutputCol,
+ HasNumFeatures
+):
+ """
+ Params for :class:`HashingTF`.
+ """
+
+ """
+ Supported options to decide whether each dimension of the output vector is binary or not.
+ <ul>
+ <li>true: the value at one dimension is set as 1 if there are some features hashed to this
+ column.
+ <li>false: the value at one dimension is set as number of features that has been hashed to
+ this column.
+ </ul>
+ """
+ # Boolean param named "binary"; its default value is False.
+ BINARY: BooleanParam = BooleanParam(
+ "binary",
+ "Whether each dimension of the output vector is binary or not.",
+ False
+ )
+
+ def __init__(self, java_params):
+ super(_HashingTFParams, self).__init__(java_params)
+
+ def set_binary(self, value: bool):
+ """Sets the ``binary`` param and returns the result of ``set`` for call chaining."""
+ return typing.cast(_HashingTFParams, self.set(self.BINARY, value))
+
+ def get_binary(self) -> bool:
+ """Returns the current value of the ``binary`` param."""
+ return self.get(self.BINARY)
+
+ @property
+ def binary(self) -> bool:
+ """Property form of :meth:`get_binary`."""
+ return self.get_binary()
+
+
+class HashingTF(JavaFeatureTransformer, _HashingTFParams):
+ """
+ A Transformer that maps a sequence of terms (strings, numbers, booleans) to a sparse vector
+ with a specified dimension using the hashing trick.
+
+ If multiple features are projected into the same column, the output values are accumulated
+ by default. Users could also enforce all non-zero output values as 1 by setting the
+ ``binary`` param to true.
+
+ For the hashing trick, see https://en.wikipedia.org/wiki/Feature_hashing for details.
+ """
+
+ def __init__(self, java_model=None):
+ super(HashingTF, self).__init__(java_model)
+
+ @classmethod
+ def _java_transformer_package_name(cls) -> str:
+ """Returns the package name "hashingtf" for the Java transformer."""
+ return "hashingtf"
+
+ @classmethod
+ def _java_transformer_class_name(cls) -> str:
+ """Returns the class name "HashingTF" for the Java transformer."""
+ return "HashingTF"
diff --git a/flink-ml-python/pyflink/ml/lib/feature/tests/test_hashingtf.py b/flink-ml-python/pyflink/ml/lib/feature/tests/test_hashingtf.py
new file mode 100644
index 0000000..ade6e3b
--- /dev/null
+++ b/flink-ml-python/pyflink/ml/lib/feature/tests/test_hashingtf.py
@@ -0,0 +1,115 @@
+################################################################################
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+################################################################################
+import os
+
+from pyflink.common import Types
+
+from pyflink.ml.core.linalg import Vectors
+from pyflink.ml.lib.feature.hashingtf import HashingTF
+from pyflink.ml.tests.test_utils import PyFlinkMLTestCase
+
+
+class HashingTFTest(PyFlinkMLTestCase):
+ """Tests :class:`HashingTF`."""
+
+ def setUp(self):
+ """Builds the two-row input table and the expected vectors for both binary modes."""
+ super(HashingTFTest, self).setUp()
+ self.input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ (['HashingTFTest', 'Hashing', 'Term', 'Frequency', 'Test'],),
+ (['HashingTFTest', 'Hashing', 'Hashing', 'Test', 'Test'],),
+ ],
+ type_info=Types.ROW_NAMED(
+ ["input", ],
+ [Types.OBJECT_ARRAY(Types.STRING())])))
+
+ # Expected vectors when binary is disabled.
+ self.expected_output = [
+ Vectors.sparse(262144, [67564, 89917, 113827, 131486, 228971],
+ [1.0, 1.0, 1.0, 1.0, 1.0]),
+ Vectors.sparse(262144, [67564, 131486, 228971], [1.0, 2.0, 2.0])
+ ]
+
+ # Expected vectors when binary is enabled: same indices, every value is 1.0.
+ self.expected_binary_output = [
+ Vectors.sparse(262144, [67564, 89917, 113827, 131486, 228971],
+ [1.0, 1.0, 1.0, 1.0, 1.0]),
+ Vectors.sparse(262144, [67564, 131486, 228971], [1.0, 1.0, 1.0])
+ ]
+
+ def test_param(self):
+ """Checks the default param values and that every setter is reflected by its property."""
+ hashing_tf = HashingTF()
+ self.assertEqual('input', hashing_tf.input_col)
+ self.assertFalse(hashing_tf.binary)
+ self.assertEqual(262144, hashing_tf.num_features)
+ self.assertEqual('output', hashing_tf.output_col)
+
+ hashing_tf.set_input_col("test_input_col") \
+ .set_binary(True) \
+ .set_num_features(1024) \
+ .set_output_col("test_output_col")
+
+ self.assertEqual('test_input_col', hashing_tf.input_col)
+ self.assertTrue(hashing_tf.binary)
+ self.assertEqual(1024, hashing_tf.num_features)
+ self.assertEqual('test_output_col', hashing_tf.output_col)
+
+ def test_output_schema(self):
+ """Checks that the output keeps all input columns and appends the output column last."""
+ hashing_tf = HashingTF()
+ input_data_table = self.t_env.from_data_stream(
+ self.env.from_collection([
+ ([''], ''),
+ ],
+ type_info=Types.ROW_NAMED(
+ ['input', 'dummy_input'],
+ [Types.OBJECT_ARRAY(Types.STRING()), Types.STRING()])))
+
+ output = hashing_tf \
+ .set_input_col('input') \
+ .set_output_col('output') \
+ .transform(input_data_table)[0]
+
+ self.assertEqual(
+ [hashing_tf.input_col, 'dummy_input', hashing_tf.output_col],
+ output.get_schema().get_field_names())
+
+ def verify_output_result(self, output_table, expected_output):
+ """
+ Collects the second field of every output row and compares it with ``expected_output``.
+
+ Both lists are ordered by their value at index 89917 before the element-wise comparison.
+ Note that ``expected_output`` is sorted in place.
+ """
+ predicted_result = [result[1] for result in
+ self.t_env.to_data_stream(output_table).execute_and_collect()]
+ expected_output.sort(key=lambda x: x[89917])
+ predicted_result.sort(key=lambda x: x[89917])
+ self.assertEqual(len(expected_output), len(predicted_result))
+
+ for i in range(len(expected_output)):
+ self.assertEqual(expected_output[i], predicted_result[i])
+
+ def test_transform(self):
+ """Checks the transformed vectors with binary disabled and then with binary enabled."""
+ hashing_tf = HashingTF()
+
+ # Tests non-binary.
+ output = hashing_tf.transform(self.input_data_table)[0]
+ self.verify_output_result(output, self.expected_output)
+
+ # Tests binary.
+ hashing_tf.set_binary(True)
+ output = hashing_tf.transform(self.input_data_table)[0]
+ self.verify_output_result(output, self.expected_binary_output)
+
+ def test_save_load_transform(self):
+ """Checks that a saved and reloaded HashingTF produces the expected vectors."""
+ hashingtf = HashingTF()
+ path = os.path.join(self.temp_dir, 'test_save_load_transform_hashingtf')
+ hashingtf.save(path)
+ hashingtf = HashingTF.load(self.t_env, path)
+ output = hashingtf.transform(self.input_data_table)[0]
+ self.verify_output_result(output, self.expected_output)