You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/11/20 14:54:59 UTC
[16/50] [abbrv] ignite git commit: IGNITE-8867: [ML] Bagging on
learning sample
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
new file mode 100644
index 0000000..c22da04
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/BaggingTest.java
@@ -0,0 +1,218 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.junit.Test;
+
+/**
+ * Tests for bagging algorithm.
+ */
+public class BaggingTest extends TrainerTest {
+ /**
+ * Tests that the entry count exposed through the dataset context is close to initial dataset size * subsampleRatio.
+ */
+ @Test
+ public void testBaggingContextCount() {
+ count((ctxCount, countData, integer) -> ctxCount); // Select the value held in the context.
+ }
+
+ /**
+ * Tests that the entry count held in {@link CountData} is close to initial dataset size * subsampleRatio.
+ */
+ @Test
+ public void testBaggingDataCount() {
+ count((ctxCount, countData, integer) -> countData.cnt); // Select the value held in the data object.
+ }
+
+ /**
+ * Tests that a bagged logistic regression built with {@link TrainerTransformers#makeBagged} predicts labels 0 and 1 for two probe points.
+ */
+ @Test
+ public void testNaiveBaggingLogRegression() {
+ Map<Integer, Double[]> cacheMock = getCacheMock();
+
+ DatasetTrainer<LogisticRegressionModel, Double> trainer =
+ (LogisticRegressionSGDTrainer<?>)new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ .withMaxIterations(30000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
+
+ DatasetTrainer<ModelsComposition, Double> baggedTrainer =
+ TrainerTransformers.makeBagged(
+ trainer,
+ 10, // Presumably the number of models in the ensemble - confirm against makeBagged.
+ 0.7, // Subsample ratio.
+ new OnMajorityPredictionsAggregator());
+
+ ModelsComposition mdl = baggedTrainer.fit(
+ cacheMock,
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)), // Features: every column after the first.
+ (k, v) -> v[0] // Label: the first column.
+ );
+
+ TestUtils.assertEquals(0, mdl.apply(VectorUtils.of(100, 10)), PRECISION);
+ TestUtils.assertEquals(1, mdl.apply(VectorUtils.of(10, 100)), PRECISION);
+ }
+
+ /**
+ * Trains a bagged {@link CountTrainer} and checks that the value selected by {@code counter} is close to dataset size * subsampleRatio.
+ *
+ * @param counter Function specifying which data we should count (the context value or the {@link CountData} value).
+ */
+ protected void count(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+ Map<Integer, Double[]> cacheMock = getCacheMock();
+
+ CountTrainer countTrainer = new CountTrainer(counter);
+
+ double subsampleRatio = 0.3;
+
+ ModelsComposition model = TrainerTransformers.makeBagged(countTrainer, 100, subsampleRatio, new MeanValuePredictionsAggregator())
+ .fit(cacheMock, parts, null, null); // CountTrainer never calls the extractors, so null is passed for both.
+
+ Double res = model.apply(null); // Models produced by CountTrainer ignore their input.
+
+ TestUtils.assertEquals(twoLinearlySeparableClasses.length * subsampleRatio, res, twoLinearlySeparableClasses.length / 10); // Tolerance: length / 10 (integer division).
+ }
+
+ /**
+ * Create cache mock: a map from row index to a boxed copy of the matching row of {@code twoLinearlySeparableClasses}.
+ *
+ * @return Cache mock with one entry per row of {@code twoLinearlySeparableClasses}.
+ */
+ private Map<Integer, Double[]> getCacheMock() {
+ Map<Integer, Double[]> cacheMock = new HashMap<>();
+
+ for (int i = 0; i < twoLinearlySeparableClasses.length; i++) {
+ double[] row = twoLinearlySeparableClasses[i];
+ Double[] convertedRow = new Double[row.length]; // Boxed copy, since the map values are Double[].
+ for (int j = 0; j < row.length; j++)
+ convertedRow[j] = row[j];
+ cacheMock.put(i, convertedRow);
+ }
+ return cacheMock;
+ }
+
+ /**
+ * Get sum of two Long values each of which can be null; a null operand is skipped rather than treated as an error.
+ *
+ * @param a First value, may be {@code null}.
+ * @param b Second value, may be {@code null}.
+ * @return Sum of parameters; the other parameter if one is {@code null}; {@code null} if both are {@code null}.
+ */
+ protected static Long plusOfNullables(Long a, Long b) {
+ if (a == null) {
+ return b;
+ }
+ if (b == null) {
+ return a;
+ }
+
+ return a + b;
+ }
+
+ /**
+ * Trainer used to count entries in context or in data; the model it produces returns that count for any input.
+ */
+ protected static class CountTrainer extends DatasetTrainer<Model<Vector, Double>, Double> {
+ /**
+ * Function specifying which entries to count; passed to {@code Dataset.computeWithCtx} in {@code fit}.
+ */
+ private final IgniteTriFunction<Long, CountData, Integer, Long> counter;
+
+ /**
+ * Construct instance of this class.
+ *
+ * @param counter Function specifying which entries to count.
+ */
+ public CountTrainer(IgniteTriFunction<Long, CountData, Integer, Long> counter) {
+ this.counter = counter;
+ }
+
+ /** {@inheritDoc} Neither extractor is used; only the upstream data size is recorded. */
+ @Override public <K, V> Model<Vector, Double> fit(
+ DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, Double> lbExtractor) {
+ Dataset<Long, CountData> dataset = datasetBuilder.build(
+ (upstreamData, upstreamDataSize) -> upstreamDataSize, // Context: the upstream data size itself.
+ (upstreamData, upstreamDataSize, ctx) -> new CountData(upstreamDataSize) // Data: the same size wrapped in CountData.
+ );
+
+ Long cnt = dataset.computeWithCtx(counter, BaggingTest::plusOfNullables); // Results are reduced with a null-tolerant sum.
+
+ return x -> Double.valueOf(cnt); // Constant model: the input is ignored.
+ }
+
+ /** {@inheritDoc} Always returns {@code true}. */
+ @Override protected boolean checkState(Model<Vector, Double> mdl) {
+ return true;
+ }
+
+ /** {@inheritDoc} The previous model is ignored; this simply delegates to {@code fit}. */
+ @Override protected <K, V> Model<Vector, Double> updateModel(
+ Model<Vector, Double> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ return fit(datasetBuilder, featureExtractor, lbExtractor);
+ }
+ }
+
+ /** Data for count trainer: a closeable holder of a single count value. */
+ protected static class CountData implements AutoCloseable {
+ /** Counter value supplied at construction. */
+ private long cnt;
+
+ /**
+ * Construct instance of this class.
+ *
+ * @param cnt Counter value to hold.
+ */
+ public CountData(long cnt) {
+ this.cnt = cnt;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() throws Exception {
+ // No-op: this object holds no resources to release.
+ }
+ }
+}