You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ag...@apache.org on 2018/11/28 11:53:14 UTC
[43/50] [abbrv] ignite git commit: IGNITE-8542: [ML] Add OneVsRest
Trainer to handle cases with multiple class labels in dataset.
IGNITE-8542: [ML] Add OneVsRest Trainer to handle cases with
multiple class labels in dataset.
This closes #5512
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/c3fd4a93
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/c3fd4a93
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/c3fd4a93
Branch: refs/heads/ignite-9720
Commit: c3fd4a930cc1a76b4d1fbccc6d764bdfe88da941
Parents: 3885f3f
Author: zaleslaw <za...@gmail.com>
Authored: Wed Nov 28 01:45:11 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Wed Nov 28 01:45:11 2018 +0300
----------------------------------------------------------------------
.../ignite/ml/multiclass/MultiClassModel.java | 115 +++++++++++++++
.../ignite/ml/multiclass/OneVsRestTrainer.java | 147 +++++++++++++++++++
.../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +-
.../ml/multiclass/MultiClassTestSuite.java | 32 ++++
.../ml/multiclass/OneVsRestTrainerTest.java | 126 ++++++++++++++++
5 files changed, 423 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
new file mode 100644
index 0000000..8520aa9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/MultiClassModel.java
@@ -0,0 +1,115 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.io.Serializable;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.TreeMap;
+import org.apache.ignite.ml.Exportable;
+import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/** Multi-class model backed by one binary classifier per class label. */
+public class MultiClassModel<M extends Model<Vector, Double>> implements Model<Vector, Double>, Exportable<MultiClassModel>, Serializable {
+ /** Serial version UID. */
+ private static final long serialVersionUID = -114986533359917L;
+
+ /** Binary classifiers keyed by class label. */
+ private Map<Double, M> models;
+
+ /** Creates an empty model; classifiers are registered later via {@code add}. */
+ public MultiClassModel() {
+ this.models = new HashMap<>();
+ }
+
+ /**
+ * Registers the binary classifier for a class label, replacing any classifier already stored for it.
+ *
+ * @param clsLb The class label for the added model.
+ * @param mdl The model.
+ */
+ public void add(double clsLb, M mdl) {
+ models.put(clsLb, mdl);
+ }
+
+ /**
+ * @param clsLb Class label.
+ * @return Model registered for the class label, or an empty optional if there is none.
+ */
+ public Optional<M> getModel(Double clsLb) {
+ return Optional.ofNullable(models.get(clsLb));
+ }
+
+ /** {@inheritDoc} Throws {@link IllegalStateException} if no classifier has been added. */
+ @Override public Double apply(Vector input) {
+ if (models.isEmpty())
+ throw new IllegalStateException("Multi-class model contains no binary classifiers.");
+ // Keyed by classifier output, so on equal outputs the label iterated last wins the tie.
+ TreeMap<Double, Double> maxMargins = new TreeMap<>();
+ models.forEach((k, v) -> maxMargins.put(v.apply(input), k));
+ return maxMargins.lastEntry().getValue(); // Label whose classifier gave the largest output.
+ }
+
+ /** {@inheritDoc} */
+ @Override public <P> void saveModel(Exporter<MultiClassModel, P> exporter, P path) {
+ exporter.save(this, path);
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean equals(Object o) {
+ if (this == o)
+ return true;
+
+ if (o == null || getClass() != o.getClass())
+ return false;
+
+ MultiClassModel mdl = (MultiClassModel)o;
+
+ return Objects.equals(models, mdl.models);
+ }
+
+ /** {@inheritDoc} */
+ @Override public int hashCode() {
+ return Objects.hash(models);
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ StringBuilder wholeStr = new StringBuilder();
+
+ models.forEach((clsLb, mdl) ->
+ wholeStr
+ .append("The class with label ")
+ .append(clsLb)
+ .append(" has classifier: ")
+ .append(mdl.toString())
+ .append(System.lineSeparator())
+ );
+
+ return wholeStr.toString();
+ }
+
+ /** {@inheritDoc} The {@code pretty} flag is ignored; the output equals {@link #toString()}. */
+ @Override public String toString(boolean pretty) {
+ return toString();
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
new file mode 100644
index 0000000..7426506
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
@@ -0,0 +1,147 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataBuilderOnHeap;
+import org.apache.ignite.ml.structures.partition.LabelPartitionDataOnHeap;
+import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+
+/**
+ * One-vs-rest trainer: fits one binary classifier per class label to build a multi-class model.
+ *
+ * NOTE: The current implementation suffers from unbalanced training over the dataset, because labels are
+ * re-assigned from the whole range of labels to 0 and 1 without any weighting.
+ */
+public class OneVsRestTrainer<M extends Model<Vector, Double>>
+ extends SingleLabelDatasetTrainer<MultiClassModel<M>> {
+ /** The common binary classifier whose hyper-parameters are shared by all the separate per-class trainings. */
+ private SingleLabelDatasetTrainer<M> classifier;
+
+ /** Creates a trainer that uses {@code classifier} to fit or update the binary model of every class label. */
+ public OneVsRestTrainer(SingleLabelDatasetTrainer<M> classifier) {
+ this.classifier = classifier;
+ }
+
+ /**
+ * Trains a new multi-class model on the specified data by delegating to {@code updateModel} with no previous model.
+ *
+ * @param datasetBuilder Dataset builder.
+ * @param featureExtractor Feature extractor.
+ * @param lbExtractor Label extractor.
+ * @return Model.
+ */
+ @Override public <K, V> MultiClassModel<M> fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
+
+ return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
+ }
+
+ /** {@inheritDoc} Here {@code newMdl} is the previously trained model to update; it may be {@code null}. */
+ @Override public <K, V> MultiClassModel<M> updateModel(MultiClassModel<M> newMdl,
+ DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
+
+ List<Double> classes = extractClassLabels(datasetBuilder, lbExtractor);
+
+ if (classes.isEmpty())
+ return getLastTrainedModelOrThrowEmptyDatasetException(newMdl);
+
+ MultiClassModel<M> multiClsMdl = new MultiClassModel<>();
+
+ classes.forEach(clsLb -> {
+ IgniteBiFunction<K, V, Double> lbTransformer = (k, v) -> {
+ Double lb = lbExtractor.apply(k, v);
+
+ if (lb.equals(clsLb))
+ return 1.0;
+ else
+ return 0.0;
+ };
+ // Update the model already stored for this label if there is one, otherwise train a new one.
+ M mdl = Optional.ofNullable(newMdl)
+ .flatMap(multiClassModel -> multiClassModel.getModel(clsLb))
+ .map(learnedModel -> classifier.update(learnedModel, datasetBuilder, featureExtractor, lbTransformer))
+ .orElseGet(() -> classifier.fit(datasetBuilder, featureExtractor, lbTransformer));
+
+ multiClsMdl.add(clsLb, mdl);
+ });
+
+ return multiClsMdl;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected boolean checkState(MultiClassModel<M> mdl) {
+ return true;
+ }
+
+ /** Iterates over the dataset and collects the distinct class labels; the list is empty if none are found. */
+ private <K, V> List<Double> extractClassLabels(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
+ assert datasetBuilder != null;
+
+ PartitionDataBuilder<K, V, EmptyContext, LabelPartitionDataOnHeap> partDataBuilder = new LabelPartitionDataBuilderOnHeap<>(lbExtractor);
+
+ List<Double> res = new ArrayList<>();
+
+ try (Dataset<EmptyContext, LabelPartitionDataOnHeap> dataset = datasetBuilder.build(
+ (upstream, upstreamSize) -> new EmptyContext(),
+ partDataBuilder
+ )) {
+ final Set<Double> clsLabels = dataset.compute(data -> {
+ final Set<Double> locClsLabels = new HashSet<>();
+
+ final double[] lbs = data.getY();
+
+ for (double lb : lbs)
+ locClsLabels.add(lb);
+
+ return locClsLabels;
+ }, (a, b) -> {
+ if (a == null)
+ return b == null ? new HashSet<>() : b;
+ if (b == null)
+ return a;
+ return Stream.of(a, b).flatMap(Collection::stream).collect(Collectors.toSet());
+ });
+
+ if (clsLabels != null)
+ res.addAll(clsLabels);
+
+ }
+ catch (Exception e) {
+ throw new RuntimeException(e);
+ }
+ return res;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index f9645d8..78d6659 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -26,6 +26,7 @@ import org.apache.ignite.ml.genetic.GAGridTestSuite;
import org.apache.ignite.ml.inference.InferenceTestSuite;
import org.apache.ignite.ml.knn.KNNTestSuite;
import org.apache.ignite.ml.math.MathImplMainTestSuite;
+import org.apache.ignite.ml.multiclass.MultiClassTestSuite;
import org.apache.ignite.ml.nn.MLPTestSuite;
import org.apache.ignite.ml.pipeline.PipelineTestSuite;
import org.apache.ignite.ml.preprocessing.PreprocessingTestSuite;
@@ -61,7 +62,8 @@ import org.junit.runners.Suite;
StructuresTestSuite.class,
CommonTestSuite.class,
InferenceTestSuite.class,
- BaggingTest.class
+ BaggingTest.class,
+ MultiClassTestSuite.class
})
public class IgniteMLTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
new file mode 100644
index 0000000..551597f
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/MultiClassTestSuite.java
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import org.junit.runner.RunWith;
+import org.junit.runners.Suite;
+
+/**
+ * Test suite for multi-class classification trainers.
+ */
+@RunWith(Suite.class)
+@Suite.SuiteClasses({
+ OneVsRestTrainerTest.class
+})
+public class MultiClassTestSuite {
+ // No-op.
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/c3fd4a93/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
new file mode 100644
index 0000000..9842d92
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/multiclass/OneVsRestTrainerTest.java
@@ -0,0 +1,126 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.multiclass;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.junit.Assert;
+import org.junit.Test;
+
+/**
+ * Tests for {@link OneVsRestTrainer}.
+ */
+public class OneVsRestTrainerTest extends TrainerTest {
+ /**
+ * Trains on 2 linearly separable sets and checks the predicted label on either side of the boundary.
+ */
+ @Test
+ public void testTrainWithTheLinearlySeparableCase() {
+ Map<Integer, double[]> cacheMock = new HashMap<>();
+
+ for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+ cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+ LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(1000)
+ .withLocIterations(10)
+ .withBatchSize(100)
+ .withSeed(123L);
+
+ OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);
+
+ MultiClassModel mdl = trainer.fit(
+ cacheMock,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
+
+ Assert.assertTrue(mdl.toString().length() > 0);
+ Assert.assertTrue(mdl.toString(true).length() > 0);
+ Assert.assertTrue(mdl.toString(false).length() > 0);
+
+ TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(-100, 0)), PRECISION);
+ TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 0)), PRECISION);
+ }
+
+ /** Checks that updating a model on the same dataset or on an empty one leaves its predictions unchanged. */
+ @Test
+ public void testUpdate() {
+ Map<Integer, double[]> cacheMock = new HashMap<>();
+
+ for (int i = 0; i < twoLinearlySeparableClasses.length; i++)
+ cacheMock.put(i, twoLinearlySeparableClasses[i]);
+
+ LogisticRegressionSGDTrainer<?> binaryTrainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(1000)
+ .withLocIterations(10)
+ .withBatchSize(100)
+ .withSeed(123L);
+
+ OneVsRestTrainer<LogisticRegressionModel> trainer = new OneVsRestTrainer<>(binaryTrainer);
+
+ MultiClassModel originalMdl = trainer.fit(
+ cacheMock,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
+
+ MultiClassModel updatedOnSameDS = trainer.update(
+ originalMdl,
+ cacheMock,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
+
+ MultiClassModel updatedOnEmptyDS = trainer.update(
+ originalMdl,
+ new HashMap<Integer, double[]>(),
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
+ (k, v) -> v[0]
+ );
+
+ List<Vector> vectors = Arrays.asList(
+ VectorUtils.of(-100, 0),
+ VectorUtils.of(100, 0)
+ );
+
+ for (Vector vec : vectors) {
+ TestUtils.assertEquals(originalMdl.apply(vec), updatedOnSameDS.apply(vec), PRECISION);
+ TestUtils.assertEquals(originalMdl.apply(vec), updatedOnEmptyDS.apply(vec), PRECISION);
+ }
+ }
+}