You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/02/04 14:23:44 UTC
[ignite] branch master updated: IGNITE-11144: [ML] Create example
for FeatureLabelExtractor
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 5b98080 IGNITE-11144: [ML] Create example for FeatureLabelExtractor
5b98080 is described below
commit 5b9808064bce6884f44ba6fbd169b2fde4621c67
Author: Artem Malykh <am...@gmail.com>
AuthorDate: Mon Feb 4 17:23:25 2019 +0300
IGNITE-11144: [ML] Create example for FeatureLabelExtractor
This closes #5993
---
.../linear/LinearRegressionLSQRTrainerExample.java | 17 ++-
.../ml/composition/bagging/BaggedTrainer.java | 11 +-
.../ml/math/primitives/vector/VectorUtils.java | 9 +-
.../apache/ignite/ml/trainers/DatasetTrainer.java | 18 +++
.../ignite/ml/trainers/FeatureLabelExtractor.java | 5 +-
.../ignite/ml/trainers/TrainerTransformers.java | 11 +-
.../apache/ignite/ml/composition/StackingTest.java | 3 +-
.../apache/ignite/ml/trainers/StackingTest.java | 169 ---------------------
8 files changed, 51 insertions(+), 192 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 1bb4146..772a35b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -27,6 +27,8 @@ import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -59,11 +61,22 @@ public class LinearRegressionLSQRTrainerExample {
LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
System.out.println(">>> Perform the training to get the model.");
+
+ // This object is used to extract features and vectors from upstream entities which are
+ // essentially tuples of the form (key, value) (in our case (Integer, Vector)).
+ // Key part of tuple in our example is ignored.
+ // Label is extracted from 0th entry of the value (which is a Vector)
+ // and features are all remaining vector part. Alternatively we could use
+ // DatasetTrainer#fit(Ignite, IgniteCache, IgniteBiFunction, IgniteBiFunction) method call
+ // where there is a separate lambda for extracting label from (key, value) and a separate lambda for
+ // extracting features.
+ FeatureLabelExtractor<Integer, Vector, Double> extractor =
+ (k, v) -> new LabeledVector<>(v.copyOfRange(1, v.size()), v.get(0));
+
LinearRegressionModel mdl = trainer.fit(
ignite,
dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
+ extractor
);
System.out.println(">>> Linear regression model: " + mdl);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
index a63ef62..b588b25 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
@@ -17,6 +17,11 @@
package org.apache.ignite.ml.composition.bagging;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
@@ -31,12 +36,6 @@ import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
/**
* Trainer encapsulating logic of bootstrap aggregating (bagging).
* This trainer accepts some other trainer and returns bagged version of it.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
index 0c12672..72f95af 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/primitives/vector/VectorUtils.java
@@ -17,6 +17,10 @@
package org.apache.ignite.ml.math.primitives.vector;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -25,11 +29,6 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DelegatingNamedVector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Objects;
-
/**
* Some utils for {@link Vector}.
*/
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index 7455ff1..a78396d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -198,6 +198,24 @@ public abstract class DatasetTrainer<M extends IgniteModel, L> {
}
/**
+ * Trains model based on the specified data.
+ *
+ * @param ignite Ignite instance.
+ * @param cache Ignite cache.
+ * @param extractor Features and labels extractor.
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ * @return Model.
+ */
+ public <K, V> M fit(Ignite ignite, IgniteCache<K, V> cache,
+ FeatureLabelExtractor<K, V, L> extractor) {
+ return fit(
+ new CacheBasedDatasetBuilder<>(ignite, cache),
+ extractor
+ );
+ }
+
+ /**
* Gets state of model in arguments, update in according to new data and return new model.
*
* @param mdl Learned model.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/FeatureLabelExtractor.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/FeatureLabelExtractor.java
index cd8a0ae..37a2e57 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/FeatureLabelExtractor.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/FeatureLabelExtractor.java
@@ -17,13 +17,12 @@
package org.apache.ignite.ml.trainers;
+import java.io.Serializable;
+import java.util.Objects;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
-import java.io.Serializable;
-import java.util.Objects;
-
/**
* Class for extracting features and labels from upstream.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
index 0cba06c..8661d4b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
@@ -17,6 +17,11 @@
package org.apache.ignite.ml.trainers;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
@@ -34,12 +39,6 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-
/**
* Class containing various trainer transformers.
*/
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
index 1203cfb..3267790 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/StackingTest.java
@@ -18,6 +18,7 @@
package org.apache.ignite.ml.composition;
import java.util.Arrays;
+import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
@@ -167,6 +168,6 @@ public class StackingTest extends TrainerTest {
StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
new StackedDatasetTrainer<>();
thrown.expect(IllegalStateException.class);
- trainer.fit(null, null, null);
+ trainer.fit(null, (IgniteCache<Object, Object>)null, null);
}
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
deleted file mode 100644
index 9c089ce..0000000
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
+++ /dev/null
@@ -1,169 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.trainers;
-
-import java.util.Arrays;
-import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.TestUtils;
-import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
-import org.apache.ignite.ml.composition.stacking.StackedModel;
-import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.matrix.Matrix;
-import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.nn.Activators;
-import org.apache.ignite.ml.nn.MLPTrainer;
-import org.apache.ignite.ml.nn.MultilayerPerceptron;
-import org.apache.ignite.ml.nn.UpdatesStrategy;
-import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
-import org.apache.ignite.ml.optimization.LossFunctions;
-import org.apache.ignite.ml.optimization.SmoothParametrized;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
-import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
-import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
-import org.junit.Rule;
-import org.junit.Test;
-import org.junit.rules.ExpectedException;
-
-import static junit.framework.TestCase.assertEquals;
-
-/**
- * Tests stacked trainers.
- */
-public class StackingTest extends TrainerTest {
- /** Rule to check exceptions. */
- @Rule
- public ExpectedException thrown = ExpectedException.none();
-
- /**
- * Tests simple stack training.
- */
- @Test
- public void testSimpleStack() {
- StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
- new StackedDatasetTrainer<>();
-
- UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- );
-
- MLPArchitecture arch = new MLPArchitecture(2).
- withAddedLayer(10, true, Activators.RELU).
- withAddedLayer(1, false, Activators.SIGMOID);
-
- MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
- arch,
- LossFunctions.MSE,
- updatesStgy,
- 3000,
- 10,
- 50,
- 123L
- );
-
- // Convert model trainer to produce Vector -> Vector model
- DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
- AdaptableDatasetTrainer.of(trainer1)
- .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
- .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
- .withConvertedLabels(VectorUtils::num2Arr);
-
- final double factor = 3;
-
- StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
- .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
- .addTrainer(mlpTrainer)
- .withAggregatorInputMerger(VectorUtils::concat)
- .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
- .withVector2SubmodelInputConverter(IgniteFunction.identity())
- .withOriginalFeaturesKept(IgniteFunction.identity())
- .withEnvironmentBuilder(TestUtils.testEnvBuilder())
- .fit(getCacheMock(xor),
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
- (k, v) -> v[v.length - 1]);
-
- assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
- assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
- assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
- assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
- }
-
- /**
- * Tests simple stack training.
- */
- @Test
- public void testSimpleVectorStack() {
- StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
- new StackedVectorDatasetTrainer<>();
-
- UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- );
-
- MLPArchitecture arch = new MLPArchitecture(2).
- withAddedLayer(10, true, Activators.RELU).
- withAddedLayer(1, false, Activators.SIGMOID);
-
- DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
- arch,
- LossFunctions.MSE,
- updatesStgy,
- 3000,
- 10,
- 50,
- 123L
- ).withConvertedLabels(VectorUtils::num2Arr);
-
- final double factor = 3;
-
- StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
- .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
- .addMatrix2MatrixTrainer(mlpTrainer)
- .withEnvironmentBuilder(TestUtils.testEnvBuilder())
- .fit(getCacheMock(xor),
- parts,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
- (k, v) -> v[v.length - 1]);
-
- assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
- assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
- assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
- assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
- }
-
- /**
- * Tests that if there is no any way for input of first layer to propagate to second layer,
- * exception will be thrown.
- */
- @Test
- public void testINoWaysOfPropagation() {
- StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
- new StackedDatasetTrainer<>();
- thrown.expect(IllegalStateException.class);
- trainer.fit(null, null, null);
- }
-}