You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/11/26 11:10:59 UTC
ignite git commit: IGNITE-9145:[ML] Add different strategies to index
labels in StringEncoderTrainer
Repository: ignite
Updated Branches:
refs/heads/master cdaeda108 -> 9137af73e
IGNITE-9145: [ML] Add different strategies to index labels in
StringEncoderTrainer
This closes #5481.
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/9137af73
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/9137af73
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/9137af73
Branch: refs/heads/master
Commit: 9137af73ef20228ee98e4bc95a8ccb15dadd0010
Parents: cdaeda1
Author: zaleslaw <za...@gmail.com>
Authored: Mon Nov 26 14:10:51 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Mon Nov 26 14:10:51 2018 +0300
----------------------------------------------------------------------
.../encoding/EncoderSortingStrategy.java | 31 ++++++++++++++++++++
.../preprocessing/encoding/EncoderTrainer.java | 25 +++++++++++++++-
.../encoding/EncoderTrainerTest.java | 27 +++++++++++++++++
3 files changed, 82 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
new file mode 100644
index 0000000..22cca53
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding;
+
+/**
+ * Defines how {@link EncoderTrainer} orders the distinct values of a categorical feature before mapping them to integer codes.
+ *
+ * @see EncoderTrainer
+ */
+public enum EncoderSortingStrategy {
+ /** Descending order by value frequency: the most frequent value is encoded as 0. */
+ FREQUENCY_DESC,
+
+ /** Ascending order by value frequency: the least frequent value is encoded as 0. */
+ FREQUENCY_ASC
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
index 9a97a6d..d5668e4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
@@ -17,6 +17,8 @@
package org.apache.ignite.ml.preprocessing.encoding;
+import java.util.Collections;
+import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
@@ -47,6 +49,9 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
/** Encoder preprocessor type. */
private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER;
+ /** Order in which categorical values receive integer codes; most frequent value first (code 0) by default. */
+ private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC;
+
/** {@inheritDoc} */
@Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Object[]> basePreprocessor) {
@@ -129,9 +134,16 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
* @return Encoding values.
*/
private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
+ Comparator<Map.Entry<String, Integer>> comp;
+
+ if (encoderSortingStgy.equals(EncoderSortingStrategy.FREQUENCY_DESC))
+ comp = Map.Entry.comparingByValue();
+ else
+ comp = Collections.reverseOrder(Map.Entry.comparingByValue());
+
final HashMap<String, Integer> resMap = frequencies.entrySet()
.stream()
- .sorted(Map.Entry.comparingByValue())
+ .sorted(comp)
.collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
(oldValue, newValue) -> oldValue, LinkedHashMap::new));
@@ -211,6 +223,17 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
}
/**
+ * Sets the strategy that orders categorical values before integer codes are assigned to them
+ * (named "indexing" here, but it sets the {@link EncoderSortingStrategy} of this trainer).
+ * Defaults to {@link EncoderSortingStrategy#FREQUENCY_DESC}.
+ *
+ * @param encoderSortingStgy The encoder sorting strategy; must not be {@code null}.
+ * @return The changed trainer.
+ * @throws NullPointerException If {@code encoderSortingStgy} is {@code null}.
+ */
+ public EncoderTrainer<K, V> withEncoderIndexingStrategy(EncoderSortingStrategy encoderSortingStgy) {
+ // Fail fast: a null strategy would otherwise only surface later as an NPE inside fit().
+ if (encoderSortingStgy == null)
+ throw new NullPointerException("Encoder sorting strategy must not be null");
+
+ this.encoderSortingStgy = encoderSortingStgy;
+ return this;
+ }
+
+ /**
* Sets the encoder preprocessor type.
*
* @param type The encoder preprocessor type.
http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
index 23afd30..7c7eabe 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
@@ -115,4 +115,31 @@ public class EncoderTrainerTest extends TrainerTest {
}
fail("UnknownCategorialFeatureValue");
}
+
+ /** Tests {@code fit()} with {@link EncoderSortingStrategy#FREQUENCY_ASC}: the least frequent value is encoded as 0. */
+ @Test
+ public void testFitOnStringCategorialFeaturesWithReversedOrder() {
+ Map<Integer, String[]> data = new HashMap<>();
+ data.put(1, new String[] {"Monday", "September"});
+ data.put(2, new String[] {"Monday", "August"});
+ data.put(3, new String[] {"Monday", "August"});
+ data.put(4, new String[] {"Friday", "June"});
+ data.put(5, new String[] {"Friday", "June"});
+ data.put(6, new String[] {"Sunday", "August"});
+
+ DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
+
+ EncoderTrainer<Integer, String[]> strEncoderTrainer = new EncoderTrainer<Integer, String[]>()
+ .withEncoderType(EncoderType.STRING_ENCODER)
+ .withEncoderIndexingStrategy(EncoderSortingStrategy.FREQUENCY_ASC)
+ .withEncodedFeature(0)
+ .withEncodedFeature(1);
+
+ EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit(datasetBuilder, (k, v) -> v);
+
+ // Feature 0 counts: Sunday 1, Friday 2, Monday 3; feature 1 counts: September 1, June 2, August 3.
+ assertArrayEquals(new double[] {2.0, 0.0}, preprocessor.apply(7, new String[] {"Monday", "September"}).asArray(), 1e-8);
+ assertArrayEquals(new double[] {1.0, 1.0}, preprocessor.apply(7, new String[] {"Friday", "June"}).asArray(), 1e-8);
+ assertArrayEquals(new double[] {0.0, 2.0}, preprocessor.apply(7, new String[] {"Sunday", "August"}).asArray(), 1e-8);
+ }
}