You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2018/11/20 14:55:00 UTC
[17/50] [abbrv] ignite git commit: IGNITE-8867: [ML] Bagging on
learning sample
IGNITE-8867: [ML] Bagging on learning sample
this closes #5058
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/355ce6fe
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/355ce6fe
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/355ce6fe
Branch: refs/heads/ignite-10044
Commit: 355ce6fe8839ea707bded79a6c21a2f74451366b
Parents: 28cb3a0
Author: Artem Malykh <am...@gmail.com>
Authored: Mon Nov 19 00:59:56 2018 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Mon Nov 19 00:59:56 2018 +0300
----------------------------------------------------------------------
...ggedLogisticRegressionSGDTrainerExample.java | 108 ++++++
.../logistic/bagged/package-info.java | 22 ++
.../ml/composition/BaggingModelTrainer.java | 200 ----------
.../ignite/ml/dataset/DatasetBuilder.java | 11 +
.../ml/dataset/PartitionContextBuilder.java | 21 ++
.../ignite/ml/dataset/PartitionDataBuilder.java | 11 +-
.../ignite/ml/dataset/UpstreamTransformer.java | 42 +++
.../ml/dataset/UpstreamTransformerChain.java | 154 ++++++++
.../dataset/impl/cache/CacheBasedDataset.java | 15 +-
.../impl/cache/CacheBasedDatasetBuilder.java | 15 +-
.../dataset/impl/cache/util/ComputeUtils.java | 95 +++--
.../ml/dataset/impl/local/LocalDataset.java | 2 +-
.../dataset/impl/local/LocalDatasetBuilder.java | 95 +++--
.../environment/LearningEnvironmentBuilder.java | 2 +-
.../binomial/LogisticRegressionSGDTrainer.java | 9 +-
.../ignite/ml/trainers/DatasetTrainer.java | 1 +
.../ignite/ml/trainers/TrainerTransformers.java | 376 +++++++++++++++++++
.../BaggingUpstreamTransformer.java | 58 +++
.../ml/trainers/transformers/package-info.java | 22 ++
.../impurity/ImpurityHistogramsComputer.java | 2 +-
.../java/org/apache/ignite/ml/util/Utils.java | 32 ++
.../org/apache/ignite/ml/IgniteMLTestSuite.java | 4 +-
.../impl/cache/util/ComputeUtilsTest.java | 3 +
.../apache/ignite/ml/trainers/BaggingTest.java | 218 +++++++++++
24 files changed, 1261 insertions(+), 257 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
new file mode 100644
index 0000000..baf513a
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.regression.logistic.bagged;
+
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.binomial.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.cv.CrossValidation;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.TrainerTransformers;
+
+/**
+ * This example shows how the bagging technique may be applied to an arbitrary trainer.
+ * As an (admittedly synthetic) example, logistic regression is considered.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p>
+ * <p>
+ * After that it trains a bootstrapped (or bagged) version of the logistic regression trainer. Bootstrapping is done
+ * on both samples and features (<a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating">Samples bagging</a>,
+ * <a href="https://en.wikipedia.org/wiki/Random_subspace_method">Features bagging</a>).</p>
+ * <p>
+ * Finally, this example applies cross-validation to the resulting model and prints the accuracy of each fold.</p>
+ */
+public class BaggedLogisticRegressionSGDTrainerExample {
+ /**
+ * Runs the example.
+ *
+ * @param args Command line arguments (not used).
+ * @throws FileNotFoundException If a file required by the example cannot be found.
+ */
+ public static void main(String[] args) throws FileNotFoundException {
+ System.out.println();
+ System.out.println(">>> Bagged logistic regression model over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+
+ System.out.println(">>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer<?> trainer = new LogisticRegressionSGDTrainer<>()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ ))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
+
+ System.out.println(">>> Perform the training to get the model.");
+
+ // Wrap the base trainer into a bagged one. NOTE(review): the positional arguments are defined by
+ // TrainerTransformers.makeBagged; presumably ensemble size (10), subsample ratio (0.6), feature vector
+ // size (4), features per model (3), predictions aggregator and seed - confirm against its signature.
+ DatasetTrainer<ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged(
+ trainer,
+ 10,
+ 0.6,
+ 4,
+ 3,
+ new OnMajorityPredictionsAggregator(),
+ 123L);
+
+ System.out.println(">>> Perform evaluation of the model.");
+
+ // Vector component 0 is used as the label; the remaining components are the features.
+ double[] score = new CrossValidation<ModelsComposition, Double, Integer, Vector>().score(
+ baggedTrainer,
+ new Accuracy<>(),
+ ignite,
+ dataCache,
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0),
+ 3
+ );
+
+ System.out.println(">>> ---------------------------------");
+
+ Arrays.stream(score).forEach(sc -> System.out.println("\n>>> Accuracy " + sc));
+
+ System.out.println(">>> Bagged logistic regression model over partitioned dataset usage example completed.");
+ }
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java
----------------------------------------------------------------------
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java
new file mode 100644
index 0000000..ea0d19e
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * ML bagged logistic regression examples.
+ */
+package org.apache.ignite.examples.ml.regression.logistic.bagged;
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
deleted file mode 100644
index 493c1da..0000000
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/BaggingModelTrainer.java
+++ /dev/null
@@ -1,200 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.ignite.ml.composition;
-
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import org.apache.ignite.ml.Model;
-import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
-import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.environment.logging.MLLogger;
-import org.apache.ignite.ml.environment.parallelism.Promise;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.util.Utils;
-import org.jetbrains.annotations.NotNull;
-
-/**
- * Abstract trainer implementing bagging logic. In each learning iteration the algorithm trains one model on subset of
- * learning sample and subspace of features space. Each model is produced from same model-class [e.g. Decision Trees].
- * <p>
- * NOTE(review): the learning sample is reduced by filtering (see learnModel), i.e. sub-sampling without
- * replacement rather than bootstrapping; entries are never duplicated.</p>
- */
-public abstract class BaggingModelTrainer extends DatasetTrainer<ModelsComposition, Double> {
- /**
- * Predictions aggregator.
- */
- private final PredictionsAggregator predictionsAggregator;
- /**
- * Number of features to draw from original features vector to train each model.
- */
- private final int maximumFeaturesCntPerMdl;
- /**
- * Ensemble size.
- */
- private final int ensembleSize;
- /**
- * Size of sample part in percent to train one model.
- * NOTE(review): compared directly against SHA256UniformMapper output in learnModel, so this is presumably
- * a fraction in [0, 1] rather than a percentage - confirm.
- */
- private final double samplePartSizePerMdl;
- /**
- * Feature vector size.
- */
- private final int featureVectorSize;
-
- /**
- * Constructs new instance of BaggingModelTrainer.
- *
- * @param predictionsAggregator Predictions aggregator.
- * @param featureVectorSize Feature vector size.
- * @param maximumFeaturesCntPerMdl Number of features to draw from original features vector to train each model.
- * @param ensembleSize Ensemble size.
- * @param samplePartSizePerMdl Size of sample part in percent to train one model.
- */
- public BaggingModelTrainer(PredictionsAggregator predictionsAggregator,
- int featureVectorSize,
- int maximumFeaturesCntPerMdl,
- int ensembleSize,
- double samplePartSizePerMdl) {
-
- this.predictionsAggregator = predictionsAggregator;
- this.maximumFeaturesCntPerMdl = maximumFeaturesCntPerMdl;
- this.ensembleSize = ensembleSize;
- this.samplePartSizePerMdl = samplePartSizePerMdl;
- this.featureVectorSize = featureVectorSize;
- }
-
- /** {@inheritDoc} */
- @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
-
- MLLogger log = environment.logger(getClass());
- log.log(MLLogger.VerboseLevel.LOW, "Start learning");
-
- Long startTs = System.currentTimeMillis();
-
- // One independent training task per ensemble member; all tasks are handed to the parallelism strategy.
- List<IgniteSupplier<ModelOnFeaturesSubspace>> tasks = new ArrayList<>();
- for(int i = 0; i < ensembleSize; i++)
- tasks.add(() -> learnModel(datasetBuilder, featureExtractor, lbExtractor));
-
- List<Model<Vector, Double>> models = environment.parallelismStrategy().submit(tasks)
- .stream().map(Promise::unsafeGet)
- .collect(Collectors.toList());
-
- double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
- log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
- log.log(MLLogger.VerboseLevel.LOW, "Learning finished");
- return new ModelsComposition(models, predictionsAggregator);
- }
-
- /**
- * Trains one model on part of sample and features subspace.
- *
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- */
- @NotNull private <K, V> ModelOnFeaturesSubspace learnModel(
- DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
-
- // NOTE(review): unseeded RNG, so the sample and feature subsets are not reproducible between runs.
- Random rnd = new Random();
- SHA256UniformMapper<K, V> sampleFilter = new SHA256UniformMapper<>(rnd);
- long featureExtractorSeed = rnd.nextLong();
- Map<Integer, Integer> featuresMapping = createFeaturesMapping(featureExtractorSeed, featureVectorSize);
-
- //TODO: IGNITE-8867 Need to implement bootstrapping algorithm
- Long startTs = System.currentTimeMillis();
- // The filter can only drop entries (those whose mapped value is not below samplePartSizePerMdl),
- // never duplicate them.
- Model<Vector, Double> mdl = buildDatasetTrainerForModel().fit(
- datasetBuilder.withFilter((features, answer) -> sampleFilter.map(features, answer) < samplePartSizePerMdl),
- wrapFeatureExtractor(featureExtractor, featuresMapping),
- lbExtractor);
- double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
- environment.logger(getClass()).log(MLLogger.VerboseLevel.HIGH, "One model training time was %.2fs", learningTime);
-
- return new ModelOnFeaturesSubspace(featuresMapping, mdl);
- }
-
- /**
- * Constructs mapping from original feature vector to subspace.
- *
- * @param seed Seed.
- * @param featuresVectorSize Features vector size.
- */
- private Map<Integer, Integer> createFeaturesMapping(long seed, int featuresVectorSize) {
- int[] featureIdxs = Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
- Map<Integer, Integer> locFeaturesMapping = new HashMap<>();
-
- // Maps position in the sub-vector (localId) to the index of the original feature.
- IntStream.range(0, maximumFeaturesCntPerMdl)
- .forEach(localId -> locFeaturesMapping.put(localId, featureIdxs[localId]));
-
- return locFeaturesMapping;
- }
-
- /**
- * Creates trainer specific to ensemble.
- */
- protected abstract DatasetTrainer<? extends Model<Vector, Double>, Double> buildDatasetTrainerForModel();
-
- /**
- * Wraps the original feature extractor with features subspace mapping applying.
- *
- * @param featureExtractor Feature extractor.
- * @param featureMapping Feature mapping.
- */
- private <K, V> IgniteBiFunction<K, V, Vector> wrapFeatureExtractor(
- IgniteBiFunction<K, V, Vector> featureExtractor,
- Map<Integer, Integer> featureMapping) {
-
- return featureExtractor.andThen((IgniteFunction<Vector, Vector>)featureValues -> {
- double[] newFeaturesValues = new double[featureMapping.size()];
- // Position localId of the new vector receives the value of original feature featureValueId.
- featureMapping.forEach((localId, featureValueId) -> newFeaturesValues[localId] = featureValues.get(featureValueId));
- return VectorUtils.of(newFeaturesValues);
- });
- }
-
- /**
- * Learn new models on dataset and create new Compositions over them and already learned models.
- *
- * @param mdl Learned model.
- * @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- * @param <K> Type of a key in {@code upstream} data.
- * @param <V> Type of a value in {@code upstream} data.
- * @return New models composition.
- */
- @Override public <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
-
- ArrayList<Model<Vector, Double>> newModels = new ArrayList<>(mdl.getModels());
- newModels.addAll(fit(datasetBuilder, featureExtractor, lbExtractor).getModels());
-
- // The resulting composition uses this trainer's aggregator, not the one stored in mdl.
- return new ModelsComposition(newModels, predictionsAggregator);
- }
-}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
index 19bdde9..4dd0a96 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
@@ -21,6 +21,7 @@ import java.io.Serializable;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
/**
* A builder constructing instances of a {@link Dataset}. Implementations of this interface encapsulate logic of
@@ -48,6 +49,16 @@ public interface DatasetBuilder<K, V> {
public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder);
+ /**
+ * Returns the chain of upstream transformers. The chain is applied to upstream data before the data is passed
+ * to {@link PartitionDataBuilder} and {@link PartitionContextBuilder}. This allows transformations of the
+ * upstream data that are independent of whatever happens to the data afterwards.
+ * Such transformations may be used to build meta-algorithms such as bagging
+ * (see {@link BaggingUpstreamTransformer}).
+ *
+ * @return Upstream transformers chain.
+ */
+ public UpstreamTransformerChain<K, V> upstreamTransformersChain();
/**
* Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
index 027ec34..6e1fec3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionContextBuilder.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset;
import java.io.Serializable;
import java.util.Iterator;
+import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.math.functions.IgniteFunction;
@@ -37,6 +38,10 @@ import org.apache.ignite.ml.math.functions.IgniteFunction;
public interface PartitionContextBuilder<K, V, C extends Serializable> extends Serializable {
/**
* Builds a new partition {@code context} from an {@code upstream} data.
+ * Important: upstream entries are not guaranteed to have unique keys, i.e. the same key may appear in several
+ * {@link UpstreamEntry} instances. Treat the upstream as a container holding all upstream data without a
+ * uniqueness constraint. The constraint is relaxed so that upstream transformers in {@link DatasetBuilder} can
+ * replicate entries, which is useful, for example, for bootstrapping.
*
* @param upstreamData Partition {@code upstream} data.
* @param upstreamDataSize Partition {@code upstream} data size.
@@ -44,6 +49,22 @@ public interface PartitionContextBuilder<K, V, C extends Serializable> extends S
*/
public C build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize);
+
+ /**
+ * Builds a new partition {@code context} from an {@code upstream} data.
+ * Important: there is no guarantee that there will be no more than one UpstreamEntry with given key,
+ * UpstreamEntry should be thought rather as a container saving all data from upstream, but omitting uniqueness
+ * constraint. This constraint is omitted to allow upstream data transformers in {@link DatasetBuilder} replicating
+ * entries. For example it can be useful for bootstrapping.
+ *
+ * @param upstreamData Partition {@code upstream} data.
+ * @param upstreamDataSize Partition {@code upstream} data size.
+ * @return Partition {@code context}.
+ */
+ public default C build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize) {
+ return build(upstreamData.iterator(), upstreamDataSize);
+ }
+
/**
* Makes a composed partition {@code context} builder that first builds a {@code context} and then applies the
* specified function on the result.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
index c1391b1..54c7611 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/PartitionDataBuilder.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.dataset;
import java.io.Serializable;
import java.util.Iterator;
+import java.util.stream.Stream;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@@ -39,7 +40,11 @@ import org.apache.ignite.ml.math.functions.IgniteBiFunction;
@FunctionalInterface
public interface PartitionDataBuilder<K, V, C extends Serializable, D extends AutoCloseable> extends Serializable {
/**
- * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}
+ * Builds a new partition {@code data} from a partition {@code upstream} data and partition {@code context}.
+ * Important: upstream entries are not guaranteed to have unique keys, i.e. the same key may appear in several
+ * {@link UpstreamEntry} instances. Treat the upstream as a container holding all upstream data without a
+ * uniqueness constraint. The constraint is relaxed so that upstream transformers in {@link DatasetBuilder} can
+ * replicate entries, which is useful, for example, for bootstrapping.
*
* @param upstreamData Partition {@code upstream} data.
* @param upstreamDataSize Partition {@code upstream} data size.
@@ -48,6 +53,10 @@ public interface PartitionDataBuilder<K, V, C extends Serializable, D extends Au
*/
public D build(Iterator<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx);
+
+ /**
+ * Builds a new partition {@code data} from a partition {@code upstream} data supplied as a stream and a
+ * partition {@code context}.
+ * Important: upstream entries are not guaranteed to have unique keys, i.e. the same key may appear in several
+ * {@link UpstreamEntry} instances. Treat the upstream as a container holding all upstream data without a
+ * uniqueness constraint. The constraint is relaxed so that upstream transformers in {@link DatasetBuilder} can
+ * replicate entries, which is useful, for example, for bootstrapping.
+ * <p>
+ * The default implementation adapts the stream to an iterator and delegates to the iterator-based
+ * {@code build} method.</p>
+ *
+ * @param upstreamData Partition {@code upstream} data.
+ * @param upstreamDataSize Partition {@code upstream} data size.
+ * @param ctx Partition {@code context}.
+ * @return Partition {@code data}.
+ */
+ public default D build(Stream<UpstreamEntry<K, V>> upstreamData, long upstreamDataSize, C ctx) {
+ return build(upstreamData.iterator(), upstreamDataSize, ctx);
+ }
+
/**
* Makes a composed partition {@code data} builder that first builds a {@code data} and then applies the specified
* function on the result.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
new file mode 100644
index 0000000..ba70e2e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset;
+
+import java.io.Serializable;
+import java.util.Random;
+import java.util.stream.Stream;
+
+/**
+ * Transformer of upstream data: turns a stream of upstream entries into another stream of upstream entries.
+ * Transformers are combined into a {@link UpstreamTransformerChain}.
+ *
+ * @param <K> Type of keys in the upstream.
+ * @param <V> Type of values in the upstream.
+ */
+@FunctionalInterface
+public interface UpstreamTransformer<K, V> extends Serializable {
+ // TODO: IGNITE-10296: Inject capabilities of randomization through learning environment.
+ // TODO: IGNITE-10297: Investigate possibility of API change.
+
+ /**
+ * Transforms the given upstream.
+ *
+ * @param rnd Source of pseudo-randomness for randomized transformations.
+ * @param upstream Upstream to transform.
+ * @return Transformed upstream.
+ */
+ public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream);
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java
new file mode 100644
index 0000000..dc83926
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerChain.java
@@ -0,0 +1,154 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Objects;
+import java.util.Random;
+import java.util.stream.Stream;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Chain of {@link UpstreamTransformer}s applied one after another to an upstream.
+ * <p>
+ * All transformers of the chain share a single {@link Random} instance that is created from {@link #seed()} on
+ * every call of {@link #transform(Stream)}.</p>
+ * <p>
+ * The chain is mutable: {@code addUpstreamTransformer}, {@link #setSeed(long)} and
+ * {@link #modifySeed(IgniteFunction)} modify this instance and return it.</p>
+ *
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ */
+public class UpstreamTransformerChain<K, V> implements Serializable {
+ /** Seed used for transformations. Never {@code null}. */
+ private Long seed;
+
+ /** List of upstream transformations, applied in insertion order. */
+ private final List<UpstreamTransformer<K, V>> list;
+
+ /**
+ * Creates an empty upstream transformers chain; its {@link #transform(Stream)} returns the upstream unchanged.
+ *
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ * @return Empty upstream transformers chain.
+ */
+ public static <K, V> UpstreamTransformerChain<K, V> empty() {
+ return new UpstreamTransformerChain<>();
+ }
+
+ /**
+ * Creates an upstream transformers chain consisting of the single specified transformer.
+ *
+ * @param trans Transformer to put into the chain.
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ * @return Upstream transformers chain consisting of the specified transformer.
+ * @throws NullPointerException If {@code trans} is {@code null}.
+ */
+ public static <K, V> UpstreamTransformerChain<K, V> of(UpstreamTransformer<K, V> trans) {
+ UpstreamTransformerChain<K, V> res = new UpstreamTransformerChain<>();
+
+ return res.addUpstreamTransformer(trans);
+ }
+
+ /**
+ * Constructs an empty chain with a randomly chosen seed.
+ */
+ private UpstreamTransformerChain() {
+ list = new ArrayList<>();
+ seed = new Random().nextLong();
+ }
+
+ /**
+ * Appends an upstream transformer to the end of this chain.
+ *
+ * @param next Transformer to add.
+ * @return This chain with added transformer.
+ * @throws NullPointerException If {@code next} is {@code null}.
+ */
+ public UpstreamTransformerChain<K, V> addUpstreamTransformer(UpstreamTransformer<K, V> next) {
+ // Fail fast: a null element would otherwise only surface as an NPE inside transform().
+ list.add(Objects.requireNonNull(next, "next"));
+
+ return this;
+ }
+
+ /**
+ * Appends an upstream transformer that ignores the RNG and applies the given function to the upstream.
+ *
+ * @param transformer Function transforming the upstream.
+ * @return This object.
+ * @throws NullPointerException If {@code transformer} is {@code null}.
+ */
+ public UpstreamTransformerChain<K, V> addUpstreamTransformer(IgniteFunction<Stream<UpstreamEntry<K, V>>,
+ Stream<UpstreamEntry<K, V>>> transformer) {
+ // Checked eagerly because the wrapping lambda below would defer the failure to transform().
+ Objects.requireNonNull(transformer, "transformer");
+
+ return addUpstreamTransformer((rnd, upstream) -> transformer.apply(upstream));
+ }
+
+ /**
+ * Applies all transformers of this chain, in insertion order, to the given upstream. A new {@link Random}
+ * seeded with {@link #seed()} is created on every call and passed to each transformer.
+ * <p>
+ * Note: transformers may build lazy stream pipelines, so the RNG may be consumed only when the returned
+ * stream is evaluated.</p>
+ *
+ * @param upstream Upstream.
+ * @return Transformed upstream.
+ */
+ public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
+ Random rnd = new Random(seed);
+
+ Stream<UpstreamEntry<K, V>> res = upstream;
+
+ for (UpstreamTransformer<K, V> trans : list)
+ res = trans.transform(rnd, res);
+
+ return res;
+ }
+
+ /**
+ * Checks if this chain is empty.
+ *
+ * @return {@code true} if this chain contains no transformers.
+ */
+ public boolean isEmpty() {
+ return list.isEmpty();
+ }
+
+ /**
+ * Sets the seed for transformations.
+ *
+ * @param seed Seed.
+ * @return This object.
+ */
+ public UpstreamTransformerChain<K, V> setSeed(long seed) {
+ this.seed = seed;
+
+ return this;
+ }
+
+ /**
+ * Replaces the seed for transformations with the result of applying the given function to the current seed.
+ *
+ * @param f Modification function; must not return {@code null}.
+ * @return This object.
+ * @throws NullPointerException If {@code f} returns {@code null}.
+ */
+ public UpstreamTransformerChain<K, V> modifySeed(IgniteFunction<Long, Long> f) {
+ // Reject null eagerly: a null seed would otherwise fail later, when transform() unboxes it.
+ seed = Objects.requireNonNull(f.apply(seed), "Seed modification function returned null");
+
+ return this;
+ }
+
+ /**
+ * Returns the seed used for the RNG in transformations.
+ *
+ * @return Seed used for RNG in transformations.
+ */
+ public Long seed() {
+ return seed;
+ }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index e5eb483..0736906 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -26,7 +26,9 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
@@ -59,6 +61,9 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
/** Filter for {@code upstream} data. */
private final IgniteBiPredicate<K, V> filter;
+ /** Chain of transformers applied to upstream. */
+ private final UpstreamTransformerChain<K, V> upstreamTransformers;
+
/** Ignite Cache with partition {@code context}. */
private final IgniteCache<Integer, C> datasetCache;
@@ -75,16 +80,22 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
* @param filter Filter for {@code upstream} data.
+ * @param upstreamTransformers Transformers of upstream data (see description in {@link DatasetBuilder}).
* @param datasetCache Ignite Cache with partition {@code context}.
* @param partDataBuilder Partition {@code data} builder.
* @param datasetId Dataset ID.
*/
- public CacheBasedDataset(Ignite ignite, IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+ public CacheBasedDataset(
+ Ignite ignite,
+ IgniteCache<K, V> upstreamCache,
+ IgniteBiPredicate<K, V> filter,
+ UpstreamTransformerChain<K, V> upstreamTransformers,
IgniteCache<Integer, C> datasetCache, PartitionDataBuilder<K, V, C, D> partDataBuilder,
UUID datasetId) {
this.ignite = ignite;
this.upstreamCache = upstreamCache;
this.filter = filter;
+ this.upstreamTransformers = upstreamTransformers;
this.datasetCache = datasetCache;
this.partDataBuilder = partDataBuilder;
this.datasetId = datasetId;
@@ -102,6 +113,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
Ignition.localIgnite(),
upstreamCacheName,
filter,
+ upstreamTransformers,
datasetCacheName,
datasetId,
part,
@@ -131,6 +143,7 @@ public class CacheBasedDataset<K, V, C extends Serializable, D extends AutoClose
Ignition.localIgnite(),
upstreamCacheName,
filter,
+ upstreamTransformers,
datasetCacheName,
datasetId,
part,
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
index 335ce63..1d00875 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
@@ -27,6 +27,7 @@ import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionContextBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
import org.apache.ignite.ml.dataset.impl.cache.util.ComputeUtils;
import org.apache.ignite.ml.dataset.impl.cache.util.DatasetAffinityFunctionWrapper;
@@ -56,6 +57,9 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
/** Filter for {@code upstream} data. */
private final IgniteBiPredicate<K, V> filter;
+ /** Chain of upstream transformers. */
+ private final UpstreamTransformerChain<K, V> transformersChain;
+
/**
* Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
* predicate that passes all upstream entries to dataset.
@@ -78,6 +82,7 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
this.ignite = ignite;
this.upstreamCache = upstreamCache;
this.filter = filter;
+ transformersChain = UpstreamTransformerChain.empty();
}
/** {@inheritDoc} */
@@ -102,16 +107,24 @@ public class CacheBasedDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
ignite,
upstreamCache.getName(),
filter,
+ transformersChain,
datasetCache.getName(),
partCtxBuilder,
RETRIES,
RETRY_INTERVAL
);
- return new CacheBasedDataset<>(ignite, upstreamCache, filter, datasetCache, partDataBuilder, datasetId);
+ return new CacheBasedDataset<>(ignite, upstreamCache, filter, transformersChain, datasetCache, partDataBuilder, datasetId);
}
/** {@inheritDoc} */
+ @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+ // Returned by reference: callers may mutate this builder's chain (e.g. change its seed).
+ return this.transformersChain;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
@Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
return new CacheBasedDatasetBuilder<>(ignite, upstreamCache,
(e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2));
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index a5cdd3b..6646e89 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -27,6 +27,7 @@ import java.util.Iterator;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.locks.LockSupport;
+import java.util.stream.Stream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.IgniteException;
@@ -40,13 +41,17 @@ import org.apache.ignite.lang.IgniteFuture;
import org.apache.ignite.ml.dataset.PartitionContextBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.util.Utils;
/**
* Util class that provides common methods to perform computations on top of the Ignite Compute Grid.
*/
public class ComputeUtils {
/** Template of the key used to store partition {@code data} in local storage. */
private static final String DATA_STORAGE_KEY_TEMPLATE = "part_data_storage_%s";
/**
@@ -136,6 +141,7 @@ public class ComputeUtils {
* @param ignite Ignite instance.
* @param upstreamCacheName Name of an {@code upstream} cache.
* @param filter Filter for {@code upstream} data.
+ * @param transformersChain Upstream transformers.
* @param datasetCacheName Name of a partition {@code context} cache.
* @param datasetId Dataset ID.
* @param part Partition index.
@@ -146,8 +152,13 @@ public class ComputeUtils {
* @param <D> Type of a partition {@code data}.
* @return Partition {@code data}.
*/
- public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(Ignite ignite,
- String upstreamCacheName, IgniteBiPredicate<K, V> filter, String datasetCacheName, UUID datasetId, int part,
+ public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(
+ Ignite ignite,
+ String upstreamCacheName, IgniteBiPredicate<K, V> filter,
+ UpstreamTransformerChain<K, V> transformersChain,
+ String datasetCacheName,
+ UUID datasetId,
+ int part,
PartitionDataBuilder<K, V, C, D> partDataBuilder) {
PartitionDataStorage dataStorage = (PartitionDataStorage)ignite
@@ -166,13 +177,22 @@ public class ComputeUtils {
qry.setPartition(part);
qry.setFilter(filter);
- long cnt = computeCount(upstreamCache, qry);
+ UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
+ chainCopy.modifySeed(s -> s + part);
+
+ long cnt = computeCount(upstreamCache, qry, chainCopy);
if (cnt > 0) {
try (QueryCursor<UpstreamEntry<K, V>> cursor = upstreamCache.query(qry,
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
- Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(), cnt,
+ Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
+ if (!chainCopy.isEmpty()) {
+ Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
+ it = transformedStream.iterator();
+ }
+
+ Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
"Cache expected to be not modified during dataset data building [partition=" + part + ']');
return partDataBuilder.build(iter, cnt, ctx);
@@ -193,21 +213,25 @@ public class ComputeUtils {
ignite.cluster().nodeLocalMap().remove(String.format(DATA_STORAGE_KEY_TEMPLATE, datasetId));
}
-
/**
* Initializes partition {@code context} by loading it from a partition {@code upstream}.
*
+ * @param <K> Type of a key in {@code upstream} data.
+ * @param <V> Type of a value in {@code upstream} data.
+ * @param <C> Type of a partition {@code context}.
* @param ignite Ignite instance.
* @param upstreamCacheName Name of an {@code upstream} cache.
* @param filter Filter for {@code upstream} data.
+ * @param transformersChain Upstream data {@link Stream} transformers chain.
* @param datasetCacheName Name of a partition {@code context} cache.
* @param ctxBuilder Partition {@code context} builder.
+ * @param retries Number of retries for the case when one of partitions not found on the node.
+ * @param interval Interval between retries for the case when one of partitions not found on the node.
- * @param <K> Type of a key in {@code upstream} data.
- * @param <V> Type of a value in {@code upstream} data.
- * @param <C> Type of a partition {@code context}.
*/
- public static <K, V, C extends Serializable> void initContext(Ignite ignite, String upstreamCacheName,
- IgniteBiPredicate<K, V> filter, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder, int retries,
+ public static <K, V, C extends Serializable> void initContext(
+ Ignite ignite,
+ String upstreamCacheName,
+ IgniteBiPredicate<K, V> filter,
+ UpstreamTransformerChain<K, V> transformersChain,
+ String datasetCacheName,
+ PartitionContextBuilder<K, V, C> ctxBuilder,
+ int retries,
int interval) {
affinityCallWithRetries(ignite, Arrays.asList(datasetCacheName, upstreamCacheName), part -> {
Ignite locIgnite = Ignition.localIgnite();
@@ -219,13 +243,23 @@ public class ComputeUtils {
qry.setPartition(part);
qry.setFilter(filter);
- long cnt = computeCount(locUpstreamCache, qry);
-
C ctx;
+ UpstreamTransformerChain<K, V> chainCopy = Utils.copy(transformersChain);
+ chainCopy.modifySeed(s -> s + part);
+
+ // Count with the same per-partition chain copy (same seed) that transforms the data below; with the original
+ // un-offset chain, cnt may not match the number of entries the transformed iterator actually yields.
+ long cnt = computeCount(locUpstreamCache, qry, chainCopy);
+
try (QueryCursor<UpstreamEntry<K, V>> cursor = locUpstreamCache.query(qry,
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
- Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(cursor.iterator(), cnt,
+ Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
+ if (!chainCopy.isEmpty()) {
+ Stream<UpstreamEntry<K, V>> transformedStream = chainCopy.transform(Utils.asStream(it, cnt));
+ it = transformedStream.iterator();
+ }
+ Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(
+ it,
+ cnt,
"Cache expected to be not modified during dataset data building [partition=" + part + ']');
ctx = ctxBuilder.build(iter, cnt);
@@ -245,6 +279,7 @@ public class ComputeUtils {
* @param ignite Ignite instance.
* @param upstreamCacheName Name of an {@code upstream} cache.
* @param filter Filter for {@code upstream} data.
+ * @param transformersChain Transformers of upstream data.
* @param datasetCacheName Name of a partition {@code context} cache.
* @param ctxBuilder Partition {@code context} builder.
* @param retries Number of retries for the case when one of partitions not found on the node.
@@ -252,10 +287,15 @@ public class ComputeUtils {
* @param <V> Type of a value in {@code upstream} data.
* @param <C> Type of a partition {@code context}.
*/
- public static <K, V, C extends Serializable> void initContext(Ignite ignite, String upstreamCacheName,
- IgniteBiPredicate<K, V> filter, String datasetCacheName, PartitionContextBuilder<K, V, C> ctxBuilder,
+ public static <K, V, C extends Serializable> void initContext(
+ Ignite ignite,
+ String upstreamCacheName,
+ IgniteBiPredicate<K, V> filter,
+ UpstreamTransformerChain<K, V> transformersChain,
+ String datasetCacheName,
+ PartitionContextBuilder<K, V, C> ctxBuilder,
int retries) {
- initContext(ignite, upstreamCacheName, filter, datasetCacheName, ctxBuilder, retries, 0);
+ // Same as the full overload with a zero interval between retries.
+ initContext(
+ ignite,
+ upstreamCacheName,
+ filter,
+ transformersChain,
+ datasetCacheName,
+ ctxBuilder,
+ retries,
+ 0);
}
/**
@@ -288,16 +328,25 @@ public class ComputeUtils {
/**
* Computes number of entries selected from the cache by the query.
*
- * @param cache Ignite cache with upstream data.
- * @param qry Cache query.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
+ * @param cache Ignite cache with upstream data.
+ * @param qry Cache query.
+ * @param transformersChain Transformers applied to the stream of upstream entries before counting.
* @return Number of entries supplied by the iterator.
*/
- private static <K, V> long computeCount(IgniteCache<K, V> cache, ScanQuery<K, V> qry) {
+ private static <K, V> long computeCount(
+ IgniteCache<K, V> cache,
+ ScanQuery<K, V> qry,
+ UpstreamTransformerChain<K, V> transformersChain) {
try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry,
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
- return computeCount(cursor.iterator());
+ Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
+
+ // Skip the iterator -> stream -> iterator round trip when there is nothing to transform.
+ if (!transformersChain.isEmpty())
+ it = transformersChain.transform(Utils.asStream(it)).iterator();
+
+ return computeCount(it);
}
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
index e312b20..975beda 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDataset.java
@@ -25,7 +25,7 @@ import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteTriFunction;
/**
- * An implementation of dataset based on local data structures such as {@code Map} and {@code List} and doesn't requires
- * Ignite environment. Introduces for testing purposes mostly, but can be used for simple local computations as well.
+ * An implementation of dataset based on local data structures such as {@code Map} and {@code List} that doesn't require
+ * an Ignite environment. Intended mostly for testing purposes, but can be used for simple local computations as well.
*
* @param <C> Type of a partition {@code context}.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
index 6e0df2f..ce909ff 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
@@ -19,7 +19,6 @@ package org.apache.ignite.ml.dataset.impl.local;
import java.io.Serializable;
import java.util.ArrayList;
-import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
@@ -28,7 +27,9 @@ import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionContextBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.util.Utils;
/**
* A dataset builder that makes {@link LocalDataset}. Encapsulate logic of building local dataset such as allocation
@@ -47,6 +48,9 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
/** Filter for {@code upstream} data. */
private final IgniteBiPredicate<K, V> filter;
+ /** Upstream transformers. */
+ private final UpstreamTransformerChain<K, V> upstreamTransformers;
+
/**
* Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that
* passes all upstream entries to dataset.
@@ -69,6 +73,7 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
this.upstreamMap = upstreamMap;
this.filter = filter;
this.partitions = partitions;
+ this.upstreamTransformers = UpstreamTransformerChain.empty();
}
/** {@inheritDoc} */
@@ -77,28 +82,55 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
List<C> ctxList = new ArrayList<>();
List<D> dataList = new ArrayList<>();
- Map<K, V> filteredMap = new HashMap<>();
- upstreamMap.forEach((key, val) -> {
- if (filter.apply(key, val))
- filteredMap.put(key, val);
- });
+ List<UpstreamEntry<K, V>> entriesList = new ArrayList<>();
+
+ upstreamMap
+ .entrySet()
+ .stream()
+ .filter(en -> filter.apply(en.getKey(), en.getValue()))
+ .map(en -> new UpstreamEntry<>(en.getKey(), en.getValue()))
+ .forEach(entriesList::add);
- int partSize = Math.max(1, filteredMap.size() / partitions);
+ int partSize = Math.max(1, entriesList.size() / partitions);
- Iterator<K> firstKeysIter = filteredMap.keySet().iterator();
- Iterator<K> secondKeysIter = filteredMap.keySet().iterator();
int ptr = 0;
- for (int part = 0; part < partitions; part++) {
- int cnt = part == partitions - 1 ? filteredMap.size() - ptr : Math.min(partSize, filteredMap.size() - ptr);
- C ctx = cnt > 0 ? partCtxBuilder.build(
- new IteratorWindow<>(firstKeysIter, k -> new UpstreamEntry<>(k, filteredMap.get(k)), cnt),
- cnt
- ) : null;
+ for (int part = 0; part < partitions; part++) {
+ // Raw (untransformed) size of this partition. Deliberately never overwritten: the formula relies on ptr being
+ // advanced by the raw size. NOTE(review): the loop tail is outside this hunk — confirm it does `ptr += cnt`.
+ int cnt = part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
+
+ // View of this partition's raw entries; every pass below starts from the same entries no matter how much of
+ // a previous pass the builders consumed.
+ List<UpstreamEntry<K, V>> partEntries = entriesList.subList(ptr, ptr + cnt);
+
+ // Per-partition copy: the builder's own chain (and its seed) stays untouched, so the seed used for a
+ // partition depends only on its index, not on earlier partitions or earlier build() calls.
+ UpstreamTransformerChain<K, V> chain = Utils.copy(upstreamTransformers);
+ int p = part;
+ chain.modifySeed(s -> s + p);
+
+ // transform() re-creates its RNG from the seed on every call, so (assuming each transformer is deterministic
+ // given the RNG) the counting, context and data passes all see the same transformed sample.
+ int transformedCnt = chain.isEmpty() ? cnt : (int)chain.transform(partEntries.stream()).count();
+
+ Iterator<UpstreamEntry<K, V>> ctxIter = chain.isEmpty() ?
+ partEntries.iterator() :
+ chain.transform(partEntries.stream()).iterator();
+ C ctx = transformedCnt > 0 ? partCtxBuilder.build(ctxIter, transformedCnt) : null;
+
+ Iterator<UpstreamEntry<K, V>> dataIter = chain.isEmpty() ?
+ partEntries.iterator() :
+ chain.transform(partEntries.stream()).iterator();
- D data = cnt > 0 ? partDataBuilder.build(
+ D data = transformedCnt > 0 ? partDataBuilder.build(
- new IteratorWindow<>(secondKeysIter, k -> new UpstreamEntry<>(k, filteredMap.get(k)), cnt),
+ dataIter,
- cnt,
+ transformedCnt,
ctx
) : null;
@@ -113,6 +145,13 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
}
/** {@inheritDoc} */
+ @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+ // Returned by reference: callers may mutate this builder's chain (e.g. change its seed).
+ return this.upstreamTransformers;
+ }
+
+ /**
+ * {@inheritDoc}
+ */
@Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
return new LocalDatasetBuilder<>(upstreamMap,
(e1, e2) -> filter.apply(e1, e2) && filterToAdd.apply(e1, e2), partitions);
@@ -126,16 +165,24 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
* @param <T> Target type of entries.
*/
private static class IteratorWindow<K, T> implements Iterator<T> {
/** Delegate iterator. */
private final Iterator<K> delegate;
/** Transformer that transforms entries from one type to another. */
private final IgniteFunction<K, T> map;
/** Count of entries to produce. */
private final int cnt;
/** Number of already produced entries. */
private int ptr;
/**
@@ -151,12 +198,16 @@ public class LocalDatasetBuilder<K, V> implements DatasetBuilder<K, V> {
this.cnt = cnt;
}
/** {@inheritDoc} */
@Override public boolean hasNext() {
return delegate.hasNext() && ptr < cnt;
}
- /** {@inheritDoc} */
+ /**
+ * {@inheritDoc}
+ */
@Override public T next() {
++ptr;
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
index 91e832d..98f584f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/LearningEnvironmentBuilder.java
@@ -35,7 +35,7 @@ public class LearningEnvironmentBuilder {
/**
- * Creates an instance of LearningEnvironmentBuilder.
+ * Creates an instance of LearningEnvironmentBuilder with default settings: {@link NoParallelismStrategy#INSTANCE}
+ * as parallelism strategy and {@link NoOpLogger#factory()} as logging factory.
*/
- LearningEnvironmentBuilder() {
+ public LearningEnvironmentBuilder() {
parallelismStgy = NoParallelismStrategy.INSTANCE;
loggingFactory = NoOpLogger.factory();
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
index 74a296d..47fa59d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/binomial/LogisticRegressionSGDTrainer.java
@@ -74,16 +74,15 @@ public class LogisticRegressionSGDTrainer<P extends Serializable> extends Single
IgniteBiFunction<K, V, Double> lbExtractor) {
IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
- int cols = dataset.compute(data -> {
+ Integer cols = dataset.compute(data -> {
if (data.getFeatures() == null)
return null;
return data.getFeatures().length / data.getRows();
}, (a, b) -> {
+ // If both are null then zero will be propagated, no good.
if (a == null)
- return b == null ? 0 : b;
- if (b == null)
- return a;
- return b;
+ return b;
+ return a;
});
MLPArchitecture architecture = new MLPArchitecture(cols);
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index 5c3913e..f321744 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -310,4 +310,5 @@ public abstract class DatasetTrainer<M extends Model, L> {
super("Cannot train model on empty dataset");
}
}
+
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
new file mode 100644
index 0000000..4f11327
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
@@ -0,0 +1,376 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.ignite.lang.IgniteBiPredicate;
+import org.apache.ignite.ml.Model;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.PartitionContextBuilder;
+import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
+import org.apache.ignite.ml.environment.LearningEnvironment;
+import org.apache.ignite.ml.environment.logging.MLLogger;
+import org.apache.ignite.ml.environment.parallelism.Promise;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.functions.IgniteTriFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Class containing various trainer transformers.
+ */
+public class TrainerTransformers {
+ /**
+ * Add bagging logic to a given trainer. No feature subspace mapping is applied, and the transformation seed is
+ * drawn randomly once, when this method is called.
+ *
+ * @param trainer Trainer of a single ensemble member.
+ * @param ensembleSize Size of ensemble.
+ * @param subsampleRatio Subsample ratio to whole dataset.
+ * @param aggregator Aggregator.
+ * @param <M> Type of one model in ensemble.
+ * @param <L> Type of labels.
+ * @return Bagged trainer.
+ */
+ public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+ DatasetTrainer<M, L> trainer,
+ int ensembleSize,
+ double subsampleRatio,
+ PredictionsAggregator aggregator) {
+ // Non-positive feature vector size (-1) disables feature subspace mapping in runOnEnsemble.
+ return makeBagged(trainer, ensembleSize, subsampleRatio, -1, -1, aggregator, new Random().nextLong());
+ }
+
+ /**
+ * Add bagging logic to a given trainer.
+ *
+ * @param trainer Trainer of a single ensemble member.
+ * @param ensembleSize Size of ensemble.
+ * @param subsampleRatio Subsample ratio to whole dataset.
+ * @param aggregator Aggregator.
+ * @param featureVectorSize Feature vector dimensionality; a non-positive value disables feature subspace mapping.
+ * @param featuresSubspaceDim Feature subspace dimensionality.
+ * @param transformationSeed Transformations seed; if {@code null}, a random seed is drawn on every {@code fit}.
+ * Note that {@code fit} writes the seed into the passed dataset builder's transformers chain.
+ * @param <M> Type of one model in ensemble.
+ * @param <L> Type of labels.
+ * @return Bagged trainer.
+ */
+ // TODO: IGNITE-10296: Inject capabilities of seeding through learning environment (remove).
+ public static <M extends Model<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+ DatasetTrainer<M, L> trainer,
+ int ensembleSize,
+ double subsampleRatio,
+ int featureVectorSize,
+ int featuresSubspaceDim,
+ PredictionsAggregator aggregator,
+ Long transformationSeed) {
+ return new DatasetTrainer<ModelsComposition, L>() {
+ /** {@inheritDoc} */
+ @Override public <K, V> ModelsComposition fit(
+ DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ datasetBuilder.upstreamTransformersChain().setSeed(
+ transformationSeed == null
+ ? new Random().nextLong()
+ : transformationSeed);
+
+ return runOnEnsemble(
+ (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)),
+ datasetBuilder,
+ ensembleSize,
+ subsampleRatio,
+ featureVectorSize,
+ featuresSubspaceDim,
+ featureExtractor,
+ aggregator,
+ environment);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected boolean checkState(ModelsComposition mdl) {
+ // Ensemble members are stored wrapped in ModelWithMapping (see runOnEnsemble and updateModel), so unwrap
+ // before delegating; passing the wrapper itself would fail in trainers that expect their concrete model type.
+ return mdl.getModels().stream()
+ .allMatch(m -> trainer.checkState(((ModelWithMapping<Vector, Double, M>)m).model()));
+ }
+
+ /** {@inheritDoc} */
+ @Override protected <K, V> ModelsComposition updateModel(
+ ModelsComposition mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ return runOnEnsemble(
+ (db, i, fe) -> (() -> trainer.updateModel(
+ ((ModelWithMapping<Vector, Double, M>)mdl.getModels().get(i)).model(),
+ db,
+ fe,
+ lbExtractor)),
+ datasetBuilder,
+ ensembleSize,
+ subsampleRatio,
+ featureVectorSize,
+ featuresSubspaceDim,
+ featureExtractor,
+ aggregator,
+ environment);
+ }
+ };
+ }
+
+ /**
+ * This method accepts function which for given dataset builder and index of model in ensemble generates
+ * task of training this model.
+ *
+ * @param trainingTaskGenerator Training task generator: maps (dataset builder, model index, feature extractor) to a
+ * supplier that trains that model.
+ * @param datasetBuilder Dataset builder. Its transformers chain is read (seed and existing transformers) but is not
+ * modified.
+ * @param ensembleSize Size of ensemble.
+ * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
+ * @param featuresVectorSize Dimensionality of feature vector; if not positive, no feature subspace mapping is used.
+ * @param featureSubspaceDim Dimensionality of feature subspace.
+ * @param extractor Feature extractor.
+ * @param aggregator Aggregator of models.
+ * @param environment Environment.
+ * @param <K> Type of keys in dataset builder.
+ * @param <V> Type of values in dataset builder.
+ * @param <M> Type of model.
+ * @return Composition of models trained on bagged dataset.
+ */
+ private static <K, V, M extends Model<Vector, Double>> ModelsComposition runOnEnsemble(
+ IgniteTriFunction<DatasetBuilder<K, V>, Integer, IgniteBiFunction<K, V, Vector>, IgniteSupplier<M>> trainingTaskGenerator,
+ DatasetBuilder<K, V> datasetBuilder,
+ int ensembleSize,
+ double subsampleRatio,
+ int featuresVectorSize,
+ int featureSubspaceDim,
+ IgniteBiFunction<K, V, Vector> extractor,
+ PredictionsAggregator aggregator,
+ LearningEnvironment environment) {
+
+ MLLogger log = environment.logger(datasetBuilder.getClass());
+ log.log(MLLogger.VerboseLevel.LOW, "Start learning.");
+
+ List<int[]> mappings = null;
+ if (featuresVectorSize > 0) {
+ mappings = IntStream.range(0, ensembleSize).mapToObj(
+ modelIdx -> getMapping(
+ featuresVectorSize,
+ featureSubspaceDim,
+ datasetBuilder.upstreamTransformersChain().seed() + modelIdx))
+ .collect(Collectors.toList());
+ }
+
+ long startTs = System.currentTimeMillis();
+
+ // Add the bagging transformer to a copy of the builder's chain: mutating the caller's chain would stack one more
+ // bagging stage on every fit/update call that reuses the same dataset builder.
+ UpstreamTransformerChain<K, V> baggingChain = Utils.copy(datasetBuilder.upstreamTransformersChain());
+ baggingChain.addUpstreamTransformer(new BaggingUpstreamTransformer<>(subsampleRatio));
+
+ List<IgniteSupplier<M>> tasks = new ArrayList<>();
+ List<IgniteBiFunction<K, V, Vector>> extractors = new ArrayList<>();
+ if (mappings != null) {
+ for (int[] mapping : mappings) {
+ extractors.add(wrapExtractor(extractor, mapping));
+ }
+ }
+
+ for (int i = 0; i < ensembleSize; i++) {
+ // Every model gets its own copy of the chain with a model-specific seed.
+ UpstreamTransformerChain<K, V> newChain = Utils.copy(baggingChain);
+ DatasetBuilder<K, V> newBuilder = withNewChain(datasetBuilder, newChain);
+ int j = i;
+ newChain.modifySeed(s -> s * s + j);
+ tasks.add(
+ trainingTaskGenerator.apply(newBuilder, i, mappings != null ? extractors.get(i) : extractor));
+ }
+
+ List<ModelWithMapping<Vector, Double, M>> models = environment.parallelismStrategy().submit(tasks)
+ .stream()
+ .map(Promise::unsafeGet)
+ .map(ModelWithMapping<Vector, Double, M>::new)
+ .collect(Collectors.toList());
+
+ // If we need to do projection, do it.
+ if (mappings != null) {
+ for (int i = 0; i < models.size(); i++) {
+ models.get(i).setMapping(getProjector(mappings.get(i)));
+ }
+ }
+
+ double learningTime = (double)(System.currentTimeMillis() - startTs) / 1000.0;
+ log.log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs.", learningTime);
+ log.log(MLLogger.VerboseLevel.LOW, "Learning finished.");
+
+ return new ModelsComposition(models, aggregator);
+ }
+
+ /**
+ * Get mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+ *
+ * @param featuresVectorSize Features vector size (Dimension of initial space).
+ * @param maximumFeaturesCntPerMdl Dimension of target space.
+ * @param seed Seed.
+ * @return Mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+ */
+ public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
+ return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
+ }
+
+ /**
+ * Get projector from index mapping.
+ *
+ * @param mapping Index mapping.
+ * @return Projector.
+ */
+ public static IgniteFunction<Vector, Vector> getProjector(int[] mapping) {
+ return v -> {
+ Vector res = VectorUtils.zeroes(mapping.length);
+ for (int i = 0; i < mapping.length; i++) {
+ res.set(i, v.get(mapping[i]));
+ }
+ return res;
+ };
+ }
+
+    /**
+     * Composes a feature extractor with a projection onto the coordinates listed in {@code featureMapping}:
+     * the extracted vector's coordinate {@code featureMapping[i]} becomes coordinate {@code i} of the result.
+     *
+     * @param featureExtractor Initial feature extractor.
+     * @param featureMapping Coordinate indexes mapping.
+     * @param <K> Type of keys.
+     * @param <V> Type of values.
+     * @return Feature extractor that extracts features and then projects them.
+     */
+    private static <K, V> IgniteBiFunction<K, V, Vector> wrapExtractor(IgniteBiFunction<K, V, Vector> featureExtractor,
+        int[] featureMapping) {
+        IgniteFunction<Vector, Vector> projection = featureValues -> {
+            double[] projected = new double[featureMapping.length];
+            int pos = 0;
+
+            for (int srcIdx : featureMapping)
+                projected[pos++] = featureValues.get(srcIdx);
+
+            return VectorUtils.of(projected);
+        };
+
+        return featureExtractor.andThen(projection);
+    }
+
+ /**
+ * Model with mapping from X to X.
+ *
+ * @param <X> Input space.
+ * @param <Y> Output space.
+ * @param <M> Model.
+ */
+ private static class ModelWithMapping<X, Y, M extends Model<X, Y>> implements Model<X, Y> {
+ /** Model. */
+ private final M model;
+
+ /** Mapping. */
+ private IgniteFunction<X, X> mapping;
+
+ /**
+ * Create instance of this class from a given model.
+ * Identity mapping will be used as a mapping.
+ *
+ * @param model Model.
+ */
+ public ModelWithMapping(M model) {
+ this(model, x -> x);
+ }
+
+ /**
+ * Create instance of this class from given model and mapping.
+ *
+ * @param model Model.
+ * @param mapping Mapping.
+ */
+ public ModelWithMapping(M model, IgniteFunction<X, X> mapping) {
+ this.model = model;
+ this.mapping = mapping;
+ }
+
+ /**
+ * Sets mapping.
+ *
+ * @param mapping Mapping.
+ */
+ public void setMapping(IgniteFunction<X, X> mapping) {
+ this.mapping = mapping;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Y apply(X x) {
+ return model.apply(mapping.apply(x));
+ }
+
+ /**
+ * Gets model.
+ *
+ * @return Model.
+ */
+ public M model() {
+ return model;
+ }
+
+ /**
+ * Gets mapping.
+ *
+ * @return Mapping.
+ */
+ public IgniteFunction<X, X> mapping() {
+ return mapping;
+ }
+ }
+
+    /**
+     * Creates new dataset builder which is delegate of a given dataset builder in everything except
+     * the transformations chain: {@link DatasetBuilder#upstreamTransformersChain()} returns {@code chain}.
+     * Builders obtained through {@link DatasetBuilder#withFilter} keep reporting the same {@code chain}.
+     * <p>
+     * NOTE(review): {@code build} is forwarded to {@code builder} unchanged, so whether {@code chain} affects the
+     * built dataset depends on how {@code builder} obtains its chain - confirm.
+     *
+     * @param builder Initial builder.
+     * @param chain New chain.
+     * @param <K> Type of keys.
+     * @param <V> Type of values.
+     * @return new dataset builder which is delegate of a given dataset builder in everything except
+     * new transformations chain.
+     */
+    private static <K, V> DatasetBuilder<K, V> withNewChain(
+        DatasetBuilder<K, V> builder,
+        UpstreamTransformerChain<K, V> chain) {
+        return new DatasetBuilder<K, V>() {
+            /** {@inheritDoc} */
+            @Override public <C extends Serializable, D extends AutoCloseable> Dataset<C, D> build(
+                PartitionContextBuilder<K, V, C> partCtxBuilder, PartitionDataBuilder<K, V, C, D> partDataBuilder) {
+                return builder.build(partCtxBuilder, partDataBuilder);
+            }
+
+            /** {@inheritDoc} */
+            @Override public UpstreamTransformerChain<K, V> upstreamTransformersChain() {
+                return chain;
+            }
+
+            /** {@inheritDoc} */
+            @Override public DatasetBuilder<K, V> withFilter(IgniteBiPredicate<K, V> filterToAdd) {
+                // Re-wrap so the filtered builder still reports the replaced chain instead of the original one.
+                return withNewChain(builder.withFilter(filterToAdd), chain);
+            }
+        };
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
new file mode 100644
index 0000000..f935ebd
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
@@ -0,0 +1,58 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers.transformers;
+
+import java.util.Random;
+import java.util.stream.Stream;
+import org.apache.commons.math3.distribution.PoissonDistribution;
+import org.apache.commons.math3.random.Well19937c;
+import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformer;
+
+/**
+ * This class encapsulates the logic needed to do bagging (bootstrap aggregating) by samples.
+ * The action of this class on a given upstream is to replicate each entry a number of times drawn from
+ * a Poisson distribution whose mean is the subsample ratio.
+ *
+ * @param <K> Type of upstream keys.
+ * @param <V> Type of upstream values.
+ */
+public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> {
+    /** Ratio of subsample to entire upstream size. */
+    private final double subsampleRatio;
+
+    /**
+     * Construct instance of this transformer with a given subsample ratio.
+     *
+     * @param subsampleRatio Subsample ratio; used as the Poisson mean, so it must be strictly positive.
+     * @throws IllegalArgumentException If {@code subsampleRatio} is not strictly positive (including NaN).
+     */
+    public BaggingUpstreamTransformer(double subsampleRatio) {
+        // Fail fast: otherwise an invalid ratio would only be rejected by PoissonDistribution inside transform().
+        if (!(subsampleRatio > 0))
+            throw new IllegalArgumentException("Subsample ratio should be strictly positive: " + subsampleRatio);
+
+        this.subsampleRatio = subsampleRatio;
+    }
+
+    /** {@inheritDoc} */
+    @Override public Stream<UpstreamEntry<K, V>> transform(Random rnd, Stream<UpstreamEntry<K, V>> upstream) {
+        PoissonDistribution poisson = new PoissonDistribution(
+            new Well19937c(rnd.nextLong()),
+            subsampleRatio,
+            PoissonDistribution.DEFAULT_EPSILON,
+            PoissonDistribution.DEFAULT_MAX_ITERATIONS);
+
+        // A single distribution instance is sampled for every entry; the stream is forced to be sequential,
+        // presumably because the underlying random generator is not meant for concurrent use.
+        return upstream.sequential().flatMap(en -> Stream.generate(() -> en).limit(poisson.sample()));
+    }
+}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java
new file mode 100644
index 0000000..b698ead
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Upstream transformers, i.e. transformations applied to the stream of upstream entries (e.g. bagging).
+ */
+package org.apache.ignite.ml.trainers.transformers;
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
index 8320461..d202441 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
@@ -45,7 +45,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
private static final long serialVersionUID = -4984067145908187508L;
/**
- * Computes histograms for each features.
+ * Computes histograms for each feature.
*
* @param roots Random forest roots.
* @param histMeta Histograms meta.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index ed0ebd3..63a9f3c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -22,7 +22,12 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import java.util.Iterator;
import java.util.Random;
+import java.util.Spliterator;
+import java.util.Spliterators;
+import java.util.stream.Stream;
+import java.util.stream.StreamSupport;
import org.apache.ignite.IgniteException;
/**
@@ -98,4 +103,31 @@ public class Utils {
public static int[] selectKDistinct(int n, int k) {
return selectKDistinct(n, k, new Random());
}
+
+    /**
+     * Wraps an iterator into a sequential, ordered stream whose size is reported as {@code cnt}.
+     * The count is exposed as the exact stream size, so it should match the number of elements
+     * produced by {@code iter}.
+     *
+     * @param iter Source iterator.
+     * @param cnt Number of elements the iterator yields.
+     * @param <T> Type of entries.
+     * @return Stream constructed from iterator.
+     */
+    public static <T> Stream<T> asStream(Iterator<T> iter, long cnt) {
+        Spliterator<T> spliter = Spliterators.spliterator(iter, cnt, Spliterator.ORDERED);
+
+        return StreamSupport.stream(spliter, false);
+    }
+
+    /**
+     * Wraps an iterator of unknown length into a sequential, ordered stream.
+     *
+     * @param iter Source iterator.
+     * @param <T> Iterator content type.
+     * @return Stream constructed from iterator.
+     */
+    public static <T> Stream<T> asStream(Iterator<T> iter) {
+        Spliterator<T> spliter = Spliterators.spliteratorUnknownSize(iter, Spliterator.ORDERED);
+
+        return StreamSupport.stream(spliter, false);
+    }
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
index 481e1fa..e26b5b8 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/IgniteMLTestSuite.java
@@ -32,6 +32,7 @@ import org.apache.ignite.ml.regressions.RegressionsTestSuite;
import org.apache.ignite.ml.selection.SelectionTestSuite;
import org.apache.ignite.ml.structures.StructuresTestSuite;
import org.apache.ignite.ml.svm.SVMTestSuite;
+import org.apache.ignite.ml.trainers.BaggingTest;
import org.apache.ignite.ml.tree.DecisionTreeTestSuite;
import org.junit.runner.RunWith;
import org.junit.runners.Suite;
@@ -57,7 +58,8 @@ import org.junit.runners.Suite;
CompositionTestSuite.class,
EnvironmentTestSuite.class,
StructuresTestSuite.class,
- CommonTestSuite.class
+ CommonTestSuite.class,
+ BaggingTest.class
})
public class IgniteMLTestSuite {
// No-op.
http://git-wip-us.apache.org/repos/asf/ignite/blob/355ce6fe/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
index 952fc43..cee8f4f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtilsTest.java
@@ -33,6 +33,7 @@ import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.UpstreamTransformerChain;
import org.apache.ignite.testframework.junits.common.GridCommonAbstractTest;
/**
@@ -178,6 +179,7 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
ignite,
upstreamCacheName,
(k, v) -> true,
+ UpstreamTransformerChain.empty(),
datasetCacheName,
datasetId,
0,
@@ -227,6 +229,7 @@ public class ComputeUtilsTest extends GridCommonAbstractTest {
ignite,
upstreamCacheName,
(k, v) -> true,
+ UpstreamTransformerChain.empty(),
datasetCacheName,
(upstream, upstreamSize) -> {