You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2018/11/26 11:10:59 UTC

ignite git commit: IGNITE-9145:[ML] Add different strategies to index labels in StringEncoderTrainer

Repository: ignite
Updated Branches:
  refs/heads/master cdaeda108 -> 9137af73e


IGNITE-9145:[ML] Add different strategies to index labels in
StringEncoderTrainer

this closes #5481


Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/9137af73
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/9137af73
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/9137af73

Branch: refs/heads/master
Commit: 9137af73ef20228ee98e4bc95a8ccb15dadd0010
Parents: cdaeda1
Author: zaleslaw <za...@gmail.com>
Authored: Mon Nov 26 14:10:51 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Mon Nov 26 14:10:51 2018 +0300

----------------------------------------------------------------------
 .../encoding/EncoderSortingStrategy.java        | 31 ++++++++++++++++++++
 .../preprocessing/encoding/EncoderTrainer.java  | 25 +++++++++++++++-
 .../encoding/EncoderTrainerTest.java            | 27 +++++++++++++++++
 3 files changed, 82 insertions(+), 1 deletion(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
new file mode 100644
index 0000000..22cca53
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderSortingStrategy.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.preprocessing.encoding;
+
+/**
+ * Strategy that orders the values of a categorical feature before integer codes are assigned to them.
+ *
+ * @see EncoderTrainer
+ */
+public enum EncoderSortingStrategy {
+    /** Descending order by value frequency: the most frequent value is assigned code 0. */
+    FREQUENCY_DESC,
+
+    /** Ascending order by value frequency: the least frequent value is assigned code 0. */
+    FREQUENCY_ASC
+}

http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
index 9a97a6d..d5668e4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainer.java
@@ -17,6 +17,8 @@
 
 package org.apache.ignite.ml.preprocessing.encoding;
 
+import java.util.Collections;
+import java.util.Comparator;
 import java.util.HashMap;
 import java.util.HashSet;
 import java.util.LinkedHashMap;
@@ -47,6 +49,9 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
     /** Encoder preprocessor type. */
     private EncoderType encoderType = EncoderType.ONE_HOT_ENCODER;
 
+    /** Strategy that orders categorical values before integer codes are assigned; FREQUENCY_DESC by default. */
+    private EncoderSortingStrategy encoderSortingStgy = EncoderSortingStrategy.FREQUENCY_DESC;
+
     /** {@inheritDoc} */
     @Override public EncoderPreprocessor<K, V> fit(DatasetBuilder<K, V> datasetBuilder,
                                                    IgniteBiFunction<K, V, Object[]> basePreprocessor) {
@@ -129,9 +134,16 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
      * @return Encoding values.
      */
     private Map<String, Integer> transformFrequenciesToEncodingValues(Map<String, Integer> frequencies) {
+        Comparator<Map.Entry<String, Integer>> comp;
+
+        if (encoderSortingStgy.equals(EncoderSortingStrategy.FREQUENCY_DESC))
+            comp = Map.Entry.comparingByValue();
+        else
+            comp = Collections.reverseOrder(Map.Entry.comparingByValue());
+
         final HashMap<String, Integer> resMap = frequencies.entrySet()
             .stream()
-            .sorted(Map.Entry.comparingByValue())
+            .sorted(comp)
             .collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue,
                 (oldValue, newValue) -> oldValue, LinkedHashMap::new));
 
@@ -211,6 +223,17 @@ public class EncoderTrainer<K, V> implements PreprocessingTrainer<K, V, Object[]
     }
 
     /**
+     * Sets the strategy that orders categorical values before integer codes are assigned to them.
+     *
+     * @param encoderSortingStgy Sorting strategy; should not be null (it is dereferenced during encoding). Default: FREQUENCY_DESC.
+     * @return This trainer, to allow call chaining.
+     */
+    public EncoderTrainer<K, V> withEncoderIndexingStrategy(EncoderSortingStrategy encoderSortingStgy) {
+        this.encoderSortingStgy = encoderSortingStgy;
+        return this;
+    }
+
+    /**
      * Sets the encoder preprocessor type.
      *
      * @param type The encoder preprocessor type.

http://git-wip-us.apache.org/repos/asf/ignite/blob/9137af73/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
index 23afd30..7c7eabe 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/preprocessing/encoding/EncoderTrainerTest.java
@@ -115,4 +115,31 @@ public class EncoderTrainerTest extends TrainerTest {
         }
         fail("UnknownCategorialFeatureValue");
     }
+
+    /** Tests {@code fit()} with the {@code FREQUENCY_ASC} strategy (least frequent value is encoded as 0). */
+    @Test
+    public void testFitOnStringCategorialFeaturesWithReversedOrder() {
+        Map<Integer, String[]> data = new HashMap<>();
+        data.put(1, new String[] {"Monday", "September"});
+        data.put(2, new String[] {"Monday", "August"});
+        data.put(3, new String[] {"Monday", "August"});
+        data.put(4, new String[] {"Friday", "June"});
+        data.put(5, new String[] {"Friday", "June"});
+        data.put(6, new String[] {"Sunday", "August"});
+        // Feature 0 counts: Monday=3, Friday=2, Sunday=1; feature 1 counts: August=3, June=2, September=1.
+        DatasetBuilder<Integer, String[]> datasetBuilder = new LocalDatasetBuilder<>(data, parts);
+
+        EncoderTrainer<Integer, String[]> strEncoderTrainer = new EncoderTrainer<Integer, String[]>()
+            .withEncoderType(EncoderType.STRING_ENCODER)
+            .withEncoderIndexingStrategy(EncoderSortingStrategy.FREQUENCY_ASC)
+            .withEncodedFeature(0)
+            .withEncodedFeature(1);
+
+        EncoderPreprocessor<Integer, String[]> preprocessor = strEncoderTrainer.fit(
+            datasetBuilder,
+            (k, v) -> v
+        );
+        // Least frequent value gets 0: Sunday=0, Friday=1, Monday=2; September (rarest month) = 0.
+        assertArrayEquals(new double[] {2.0, 0.0}, preprocessor.apply(7, new String[] {"Monday", "September"}).asArray(), 1e-8);
+    }
 }