You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by GitBox <gi...@apache.org> on 2019/01/15 17:23:45 UTC
[ignite] Diff for: [GitHub] asfgit closed pull request #5767: [ML]
IGNITE-10573: Consistent API for Ensemble training
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
index 58f739d79010..c9b10b168785 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -22,7 +22,8 @@
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.bagging.BaggedModel;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -31,7 +32,6 @@
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.TrainerTransformers;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -75,7 +75,7 @@ public static void main(String[] args) throws FileNotFoundException {
System.out.println(">>> Perform the training to get the model.");
- DatasetTrainer< ModelsComposition, Double> baggedTrainer = TrainerTransformers.makeBagged(
+ BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
trainer,
10,
0.6,
@@ -85,7 +85,7 @@ public static void main(String[] args) throws FileNotFoundException {
System.out.println(">>> Perform evaluation of the model.");
- double[] score = new CrossValidation<ModelsComposition, Double, Integer, Vector>().score(
+ double[] score = new CrossValidation<BaggedModel, Double, Integer, Vector>().score(
baggedTrainer,
new Accuracy<>(),
ignite,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
new file mode 100644
index 000000000000..a2257bf078b6
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.tutorial;
+
+import java.io.FileNotFoundException;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
+import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
+import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
+import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
+import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.BinaryClassificationEvaluator;
+import org.apache.ignite.ml.selection.scoring.metric.Accuracy;
+import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
+import org.apache.ignite.ml.tree.DecisionTreeNode;
+
+/**
+ * {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
+ * distribution in columns and rows.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data (based on Titanic passengers data).</p>
+ * <p>
+ * After that it defines preprocessors that extract features from an upstream data and perform other desired changes
+ * over the extracted data, including the scaling.</p>
+ * <p>
+ * Then, it trains the model based on the processed data using stacking of decision tree classifiers with a logistic regression aggregator.</p>
+ * <p>
+ * Finally, this example uses {@link BinaryClassificationEvaluator} functionality to compute metrics from predictions.</p>
+ */
+public class Step_10_Scaling_With_Stacking {
+ /** Run example. */
+ public static void main(String[] args) {
+ System.out.println();
+ System.out.println(">>> Tutorial step 10 (scaling with stacking) example started.");
+
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ try {
+ IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+
+ // Defines first preprocessor that extracts features from an upstream data.
+ // Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare".
+ IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
+ = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
+
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+
+ IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
+ .withEncoderType(EncoderType.STRING_ENCODER)
+ .withEncodedFeature(1)
+ .withEncodedFeature(6) // <--- Changed index here.
+ .fit(ignite,
+ dataCache,
+ featureExtractor
+ );
+
+ IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
+ .fit(ignite,
+ dataCache,
+ strEncoderPreprocessor
+ );
+
+ IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
+ .fit(
+ ignite,
+ dataCache,
+ imputingPreprocessor
+ );
+
+ IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
+ .withP(1)
+ .fit(
+ ignite,
+ dataCache,
+ minMaxScalerPreprocessor
+ );
+
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
+ DecisionTreeClassificationTrainer trainer1 = new DecisionTreeClassificationTrainer(3, 0);
+ DecisionTreeClassificationTrainer trainer2 = new DecisionTreeClassificationTrainer(4, 0);
+
+ LogisticRegressionSGDTrainer aggregator = new LogisticRegressionSGDTrainer()
+ .withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg));
+
+ StackedModel<Vector, Vector, Double, LogisticRegressionModel> mdl =
+ new StackedVectorDatasetTrainer<>(aggregator)
+ .addTrainerWithDoubleOutput(trainer)
+ .addTrainerWithDoubleOutput(trainer1)
+ .addTrainerWithDoubleOutput(trainer2)
+ .fit(
+ ignite,
+ dataCache,
+ normalizationPreprocessor,
+ lbExtractor
+ );
+
+ System.out.println("\n>>> Trained model: " + mdl);
+
+ double accuracy = BinaryClassificationEvaluator.evaluate(
+ dataCache,
+ mdl,
+ normalizationPreprocessor,
+ lbExtractor,
+ new Accuracy<>()
+ );
+
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+
+ System.out.println(">>> Tutorial step 10 (scaling with stacking) example completed.");
+ }
+ catch (FileNotFoundException e) {
+ e.printStackTrace();
+ }
+ }
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java
index a1165e168e50..6268d065e706 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/IgniteModel.java
@@ -20,8 +20,10 @@
import java.io.Serializable;
import java.util.function.BiFunction;
import org.apache.ignite.ml.inference.Model;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
/** Basic interface for all models. */
+@FunctionalInterface
public interface IgniteModel<T, V> extends Model<T, V>, Serializable {
/**
* Combines this model with other model via specified combiner
@@ -37,12 +39,46 @@
/**
* Get a composition model of the form {@code x -> after(mdl(x))}.
*
- * @param after Function to apply after this model.
+ * @param after Model to apply after this model.
* @param <V1> Type of input of function applied before this model.
* @return Composition model of the form {@code x -> after(mdl(x))}.
*/
public default <V1> IgniteModel<T, V1> andThen(IgniteModel<V, V1> after) {
- return t -> after.predict(predict(t));
+ IgniteModel<T, V> self = this;
+ return new IgniteModel<T, V1>() {
+ /** {@inheritDoc} */
+ @Override public V1 predict(T input) {
+ return after.predict(self.predict(input));
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ self.close();
+ after.close();
+ }
+ };
+ }
+
+ /**
+ * Get a composition model of the form {@code x -> after(mdl(x))}.
+ *
+ * @param after Function to apply after this model.
+ * @param <V1> Type of output of the function applied after this model.
+ * @return Composition model of the form {@code x -> after(mdl(x))}.
+ */
+ public default <V1> IgniteModel<T, V1> andThen(IgniteFunction<V, V1> after) {
+ IgniteModel<T, V> self = this;
+ return new IgniteModel<T, V1>() {
+ /** {@inheritDoc} */
+ @Override public V1 predict(T input) {
+ return after.apply(self.predict(input));
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ self.close();
+ }
+ };
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index 88ea9b9c9b97..3206b5fed541 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -149,7 +149,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(KMeansModel mdl) {
+ @Override public boolean isUpdateable(KMeansModel mdl) {
return mdl.getCenters().length == k && mdl.distanceMeasure().equals(distance);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
new file mode 100644
index 000000000000..5a2f40a92fa8
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * Various utility functions for trainers composition.
+ */
+public class CompositionUtils {
+ /**
+ * Perform blurring of model type of given trainer to {@code IgniteModel<I, O>}, where I, O are input and output
+ * types of original model.
+ *
+ * @param trainer Trainer to coerce.
+ * @param <I> Type of input of model produced by coerced trainer.
+ * @param <O> Type of output of model produced by coerced trainer.
+ * @param <M> Type of model produced by coerced trainer.
+ * @param <L> Type of labels.
+ * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O>, L>}.
+ */
+ public static <I, O, M extends IgniteModel<I, O>, L> DatasetTrainer<IgniteModel<I, O>, L> unsafeCoerce(
+ DatasetTrainer<? extends M, L> trainer) {
+ return new DatasetTrainer<IgniteModel<I, O>, L>() {
+ /** {@inheritDoc} */
+ @Override public <K, V> IgniteModel<I, O> fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ return trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> IgniteModel<I, O> update(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ DatasetTrainer<IgniteModel<I, O>, L> trainer1 = (DatasetTrainer<IgniteModel<I, O>, L>)trainer;
+ return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor);
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return True if current critical for training parameters correspond to parameters from last training.
+ */
+ @Override public boolean isUpdateable(IgniteModel<I, O> mdl) {
+ throw new IllegalStateException();
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
+ @Override protected <K, V> IgniteModel<I, O> updateModel(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ throw new IllegalStateException();
+ }
+ };
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java
new file mode 100644
index 000000000000..9547d5475be7
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/DatasetMapping.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition;
+
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * This class represents dataset mapping. This is just a tuple of two mappings: one for features and one for labels.
+ *
+ * @param <L1> Type of labels before mapping.
+ * @param <L2> Type of labels after mapping.
+ */
+public interface DatasetMapping<L1, L2> {
+ /**
+ * Method used to map feature vectors.
+ *
+ * @param v Feature vector.
+ * @return Mapped feature vector.
+ */
+ public default Vector mapFeatures(Vector v) {
+ return v;
+ }
+
+ /**
+ * Method used to map labels.
+ *
+ * @param lbl Label.
+ * @return Mapped label.
+ */
+ public L2 mapLabels(L1 lbl);
+
+ /**
+ * Dataset mapping which maps features, leaving labels unaffected.
+ *
+ * @param mapper Function used to map features.
+ * @param <L> Type of labels.
+ * @return Dataset mapping which maps features, leaving labels unaffected.
+ */
+ public static <L> DatasetMapping<L, L> mappingFeatures(IgniteFunction<Vector, Vector> mapper) {
+ return new DatasetMapping<L, L>() {
+ /** {@inheritDoc} */
+ @Override public Vector mapFeatures(Vector v) {
+ return mapper.apply(v);
+ }
+
+ /** {@inheritDoc} */
+ @Override public L mapLabels(L lbl) {
+ return lbl;
+ }
+ };
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java
new file mode 100644
index 000000000000..c59a6342b0b4
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedModel.java
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.bagging;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * This class represents model produced by {@link BaggedTrainer}.
+ * It is a wrapper around inner representation of model produced by {@link BaggedTrainer}.
+ */
+public class BaggedModel implements IgniteModel<Vector, Double> {
+ /** Inner representation of model produced by {@link BaggedTrainer}. */
+ private IgniteModel<Vector, Double> mdl;
+
+ /**
+ * Construct instance of this class given specified model.
+ * @param mdl Model to wrap.
+ */
+ BaggedModel(IgniteModel<Vector, Double> mdl) {
+ this.mdl = mdl;
+ }
+
+ /**
+ * Get wrapped model.
+ *
+ * @return Wrapped model.
+ */
+ IgniteModel<Vector, Double> model() {
+ return mdl;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double predict(Vector i) {
+ return mdl.predict(i);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ mdl.close();
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
new file mode 100644
index 000000000000..5b0962a7a079
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
@@ -0,0 +1,212 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.bagging;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
+import org.apache.ignite.ml.util.Utils;
+
+/**
+ * Trainer encapsulating logic of bootstrap aggregating (bagging).
+ * This trainer accepts some other trainer and returns bagged version of it.
+ * Resulting model consists of submodels results of which are aggregated by a specified aggregator.
+ * <p>Bagging is done
+ * on both samples and features (<a href="https://en.wikipedia.org/wiki/Bootstrap_aggregating">Samples bagging</a>,
+ * <a href="https://en.wikipedia.org/wiki/Random_subspace_method">Features bagging</a>).</p>
+ *
+ * @param <L> Type of labels.
+ */
+public class BaggedTrainer<L> extends
+ DatasetTrainer<BaggedModel, L> {
+ /** Trainer for which bagged version is created. */
+ private final DatasetTrainer<? extends IgniteModel, L> tr;
+
+ /** Aggregator of submodels results. */
+ private final PredictionsAggregator aggregator;
+
+ /** Count of submodels in the ensemble. */
+ private final int ensembleSize;
+
+ /** Ratio determining which part of dataset will be taken as subsample for each submodel training. */
+ private final double subsampleRatio;
+
+ /** Dimensionality of feature vectors. */
+ private final int featuresVectorSize;
+
+ /** Dimension of subspace on which all samples from subsample are projected. */
+ private final int featureSubspaceDim;
+
+ /**
+ * Construct instance of this class with given parameters.
+ *
+ * @param tr Trainer for making bagged.
+ * @param aggregator Aggregator of models.
+ * @param ensembleSize Size of ensemble.
+ * @param subsampleRatio Ratio (subsample size) / (initial dataset size).
+ * @param featuresVectorSize Dimensionality of feature vector.
+ * @param featureSubspaceDim Dimensionality of feature subspace.
+ */
+ public BaggedTrainer(DatasetTrainer<? extends IgniteModel, L> tr,
+ PredictionsAggregator aggregator, int ensembleSize, double subsampleRatio, int featuresVectorSize,
+ int featureSubspaceDim) {
+ this.tr = tr;
+ this.aggregator = aggregator;
+ this.ensembleSize = ensembleSize;
+ this.subsampleRatio = subsampleRatio;
+ this.featuresVectorSize = featuresVectorSize;
+ this.featureSubspaceDim = featureSubspaceDim;
+ }
+
+ /**
+ * Create the bagged trainer.
+ *
+ * @return Bagged trainer.
+ */
+ private DatasetTrainer<IgniteModel<Vector, Double>, L> getTrainer() {
+ List<int[]> mappings = (featuresVectorSize > 0 && featureSubspaceDim != featuresVectorSize) ?
+ IntStream.range(0, ensembleSize).mapToObj(
+ modelIdx -> getMapping(
+ featuresVectorSize,
+ featureSubspaceDim,
+ environment.randomNumbersGenerator().nextLong()))
+ .collect(Collectors.toList()) :
+ null;
+
+ List<DatasetTrainer<? extends IgniteModel, L>> trainers = Collections.nCopies(ensembleSize, tr);
+
+ // Generate a list of trainers, each being a copy of the original trainer but trained on its own subspace and subsample.
+ List<DatasetTrainer<IgniteModel<Vector, Double>, L>> subspaceTrainers = IntStream.range(0, ensembleSize)
+ .mapToObj(mdlIdx -> {
+ AdaptableDatasetTrainer<Vector, Double, Vector, Double, ? extends IgniteModel, L> tr =
+ AdaptableDatasetTrainer.of(trainers.get(mdlIdx));
+ if (mappings != null) {
+ tr = tr.afterFeatureExtractor(featureValues -> {
+ int[] mapping = mappings.get(mdlIdx);
+ double[] newFeaturesValues = new double[mapping.length];
+ for (int j = 0; j < mapping.length; j++)
+ newFeaturesValues[j] = featureValues.get(mapping[j]);
+
+ return VectorUtils.of(newFeaturesValues);
+ }).beforeTrainedModel(getProjector(mappings.get(mdlIdx)));
+ }
+ return tr
+ .withUpstreamTransformerBuilder(BaggingUpstreamTransformer.builder(subsampleRatio, mdlIdx))
+ .withEnvironmentBuilder(envBuilder);
+ })
+ .map(CompositionUtils::unsafeCoerce)
+ .collect(Collectors.toList());
+
+ AdaptableDatasetTrainer<Vector, Double, Vector, List<Double>, IgniteModel<Vector, List<Double>>, L> finalTrainer = AdaptableDatasetTrainer.of(
+ new TrainersParallelComposition<>(
+ subspaceTrainers)).afterTrainedModel(l -> aggregator.apply(l.stream().mapToDouble(Double::valueOf).toArray()));
+
+ return CompositionUtils.unsafeCoerce(finalTrainer);
+ }
+
+ /**
+ * Get mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+ *
+ * @param featuresVectorSize Features vector size (Dimension of initial space).
+ * @param maximumFeaturesCntPerMdl Dimension of target space.
+ * @param seed Seed.
+ * @return Mapping R^featuresVectorSize -> R^maximumFeaturesCntPerMdl.
+ */
+ public static int[] getMapping(int featuresVectorSize, int maximumFeaturesCntPerMdl, long seed) {
+ return Utils.selectKDistinct(featuresVectorSize, maximumFeaturesCntPerMdl, new Random(seed));
+ }
+
+ /**
+ * Get projector from index mapping.
+ *
+ * @param mapping Index mapping.
+ * @return Projector.
+ */
+ public static IgniteFunction<Vector, Vector> getProjector(int[] mapping) {
+ return v -> {
+ Vector res = VectorUtils.zeroes(mapping.length);
+ for (int i = 0; i < mapping.length; i++)
+ res.set(i, v.get(mapping[i]));
+
+ return res;
+ };
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> BaggedModel fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ IgniteModel<Vector, Double> fit = getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor);
+ return new BaggedModel(fit);
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ IgniteModel<Vector, Double> updated = getTrainer().update(mdl.model(), datasetBuilder, featureExtractor, lbExtractor);
+ return new BaggedModel(updated);
+ }
+
+ /** {@inheritDoc} */
+ @Override public BaggedTrainer<L> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+ return (BaggedTrainer<L>)super.withEnvironmentBuilder(envBuilder);
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return True if current critical for training parameters correspond to parameters from last training.
+ */
+ @Override public boolean isUpdateable(BaggedModel mdl) {
+ // Should be never called.
+ throw new IllegalStateException();
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
+ @Override protected <K, V> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ // Should be never called.
+ throw new IllegalStateException();
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index 35502ab4f052..7d88ddbf9e35 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -141,7 +141,7 @@ public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
}
/** {@inheritDoc} */
- @Override protected boolean checkState(ModelsComposition mdl) {
+ @Override public boolean isUpdateable(ModelsComposition mdl) {
return mdl instanceof GDBModel;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java
new file mode 100644
index 000000000000..b39067dc8db9
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains various combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java
new file mode 100644
index 000000000000..7947ea947f5e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/ModelsParallelComposition.java
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.parallel;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.IgniteModel;
+
+/**
+ * Parallel composition of models.
+ * Parallel composition of models is a model which contains a list of submodels with same input and output types.
+ * Result of prediction in such model is a list of predictions of each of submodels.
+ *
+ * @param <I> Type of submodel input.
+ * @param <O> Type of submodel output.
+ */
+public class ModelsParallelComposition<I, O> implements IgniteModel<I, List<O>> {
+ /** List of submodels. */
+ private final List<IgniteModel<I, O>> submodels;
+
+ /**
+ * Construct an instance of this class from a list of submodels.
+ *
+ * @param submodels List of submodels constituting this model.
+ */
+ public ModelsParallelComposition(List<IgniteModel<I, O>> submodels) {
+ this.submodels = submodels;
+ }
+
+ /** {@inheritDoc} */
+ @Override public List<O> predict(I i) {
+ return submodels
+ .stream()
+ .map(m -> m.predict(i))
+ .collect(Collectors.toList());
+ }
+
+ /**
+ * List of submodels constituting this model.
+ *
+ * @return List of submodels constituting this model.
+ */
+ public List<IgniteModel<I, O>> submodels() {
+ return Collections.unmodifiableList(submodels);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ submodels.forEach(IgniteModel::close);
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
new file mode 100644
index 000000000000..3d4d99b57d46
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
@@ -0,0 +1,146 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.parallel;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.environment.parallelism.Promise;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+
+/**
+ * This class represents a parallel composition of trainers.
+ * Parallel composition of trainers is a trainer itself which trains a list of trainers with same
+ * input and output. Training is done in the following manner:
+ * <pre>
+ * 1. Independently train all trainers on the same dataset and get a list of models.
+ * 2. Combine models produced in step (1) into a {@link ModelsParallelComposition}.
+ * </pre>
+ * Updating is made in a similar fashion.
+ * As in other trainer combinators, we avoid including the types of contained trainers in type parameters
+ * because otherwise compositions of compositions would have a relatively complex generic type, which would
+ * reduce readability.
+ *
+ * @param <I> Type of trainers inputs.
+ * @param <O> Type of trainers outputs.
+ * @param <L> Type of dataset labels.
+ */
+public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteModel<I, List<O>>, L> {
+ /** List of trainers. */
+ private final List<DatasetTrainer<IgniteModel<I, O>, L>> trainers;
+
+ /**
+ * Construct an instance of this class from a list of trainers.
+ *
+ * @param trainers Trainers.
+ * @param <M> Type of model.
+ * @param <T> Type of trainer.
+ */
+ public <M extends IgniteModel<I, O>, T extends DatasetTrainer<? extends IgniteModel<I, O>, L>> TrainersParallelComposition(
+ List<T> trainers) {
+ this.trainers = trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
+ }
+
+ /**
+ * Create parallel composition of trainers contained in a given list.
+ *
+ * @param trainers List of trainers.
+ * @param <I> Type of input of model produced by trainers.
+ * @param <O> Type of output of model produced by trainers.
+ * @param <M> Type of model produced by trainers.
+ * @param <T> Type of trainers.
+ * @param <L> Type of labels.
+ * @return Parallel composition of trainers contained in a given list.
+ */
+ public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(List<T> trainers) {
+ List<DatasetTrainer<IgniteModel<I, O>, L>> trs =
+ trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
+
+ return new TrainersParallelComposition<>(trs);
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> IgniteModel<I, List<O>> fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ List<IgniteSupplier<IgniteModel<I, O>>> tasks = trainers.stream()
+ .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)(() -> tr.fit(datasetBuilder, featureExtractor, lbExtractor)))
+ .collect(Collectors.toList());
+
+ List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
+ .map(Promise::unsafeGet)
+ .collect(Collectors.toList());
+
+ return new ModelsParallelComposition<>(mdls);
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ // Unsafe.
+ ModelsParallelComposition<I, O> typedMdl = (ModelsParallelComposition<I, O>)mdl;
+
+ assert typedMdl.submodels().size() == trainers.size();
+ List<IgniteSupplier<IgniteModel<I, O>>> tasks = new ArrayList<>();
+
+ for (int i = 0; i < trainers.size(); i++) {
+ int j = i;
+ tasks.add(() -> trainers.get(j).update(typedMdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor));
+ }
+
+ List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
+ .map(Promise::unsafeGet)
+ .collect(Collectors.toList());
+
+ return new ModelsParallelComposition<>(mdls);
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return True if current critical for training parameters correspond to parameters from last training.
+ */
+ @Override public boolean isUpdateable(IgniteModel<I, List<O>> mdl) {
+ // Never called.
+ throw new IllegalStateException();
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
+ @Override protected <K, V> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ // Never called.
+ throw new IllegalStateException();
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java
new file mode 100644
index 000000000000..cb242509426d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains parallel combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators.parallel;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java
new file mode 100644
index 000000000000..78e2846e09ea
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/ModelsSequentialComposition.java
@@ -0,0 +1,100 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.sequential;
+
+import java.util.List;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+
+/**
+ * Sequential composition of models.
+ * Sequential composition is a model consisting of two models {@code mdl1 :: I -> O1, mdl2 :: O1 -> O2} with prediction
+ * corresponding to application of composition {@code mdl1 `andThen` mdl2} to input.
+ *
+ * @param <I> Type of input of the first model.
+ * @param <O1> Type of output of the first model (and input of second).
+ * @param <O2> Type of output of the second model.
+ */
+public class ModelsSequentialComposition<I, O1, O2> implements IgniteModel<I, O2> {
+ /** First model. */
+ private IgniteModel<I, O1> mdl1;
+
+ /** Second model. */
+ private IgniteModel<O1, O2> mdl2;
+
+ /**
+ * Get sequential composition of submodels with same type.
+ *
+ * @param lst List of submodels.
+ * @param output2Input Function for conversion output to input.
+ * @param <I> Type of input of submodel.
+ * @param <O> Type of output of submodel.
+ * @return Sequential composition of submodels with same type.
+ */
+ public static <I, O> ModelsSequentialComposition<I, I, O> ofSame(List<? extends IgniteModel<I, O>> lst,
+ IgniteFunction<O, I> output2Input) {
+ assert lst.size() >= 2;
+
+ if (lst.size() == 2)
+ return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input),
+ lst.get(1));
+
+ return new ModelsSequentialComposition<>(lst.get(0).andThen(output2Input),
+ ofSame(lst.subList(1, lst.size()), output2Input));
+ }
+
+ /**
+ * Construct instance of this class from two given models.
+ *
+ * @param mdl1 First model.
+ * @param mdl2 Second model.
+ */
+ public ModelsSequentialComposition(IgniteModel<I, O1> mdl1, IgniteModel<O1, O2> mdl2) {
+ this.mdl1 = mdl1;
+ this.mdl2 = mdl2;
+ }
+
+ /**
+ * Get first model.
+ *
+ * @return First model.
+ */
+ public IgniteModel<I, O1> firstModel() {
+ return mdl1;
+ }
+
+ /**
+ * Get second model.
+ *
+ * @return Second model.
+ */
+ public IgniteModel<O1, O2> secondModel() {
+ return mdl2;
+ }
+
+ /** {@inheritDoc} */
+ @Override public O2 predict(I i1) {
+ return mdl1.andThen(mdl2).predict(i1);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void close() {
+ mdl1.close();
+ mdl2.close();
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
new file mode 100644
index 000000000000..9aa37d9a1f59
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
@@ -0,0 +1,141 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.combinators.sequential;
+
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.DatasetTrainer;
+import sun.reflect.generics.reflectiveObjects.NotImplementedException;
+
+/**
+ * Sequential composition of trainers.
+ * Sequential composition of trainers is itself trainer which produces {@link ModelsSequentialComposition}.
+ * Training is done in the following fashion:
+ * <pre>
+ * 1. First trainer is trained and `mdl1` is produced.
+ * 2. From `mdl1` {@link DatasetMapping} is constructed. This mapping `dsM` encapsulates dependency between first
+ * training result and second trainer.
+ * 3. Second trainer is trained using the dataset acquired by applying `dsM` to the original dataset; `mdl2` is produced.
+ * 4. `mdl1` and `mdl2` are composed into {@link ModelsSequentialComposition}.
+ * </pre>
+ *
+ * @param <I> Type of input of model produced by first trainer.
+ * @param <O1> Type of output of model produced by first trainer.
+ * @param <O2> Type of output of model produced by second trainer.
+ * @param <L> Type of labels.
+ */
+public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<ModelsSequentialComposition<I, O1, O2>, L> {
+ /** First trainer. */
+ private DatasetTrainer<IgniteModel<I, O1>, L> tr1;
+
+ /** Second trainer. */
+ private DatasetTrainer<IgniteModel<O1, O2>, L> tr2;
+
+ /** Dataset mapping. */
+ private IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping;
+
+ /**
+ * Construct sequential composition of given two trainers.
+ *
+ * @param tr1 First trainer.
+ * @param tr2 Second trainer.
+ * @param datasetMapping Dataset mapping.
+ */
+ public TrainersSequentialComposition(DatasetTrainer<? extends IgniteModel<I, O1>, L> tr1,
+ DatasetTrainer<? extends IgniteModel<O1, O2>, L> tr2,
+ IgniteFunction<? super IgniteModel<I, O1>, DatasetMapping<L, L>> datasetMapping) {
+ this.tr1 = CompositionUtils.unsafeCoerce(tr1);
+ this.tr2 = CompositionUtils.unsafeCoerce(tr2);
+ this.datasetMapping = datasetMapping;
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> ModelsSequentialComposition<I, O1, O2> fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+
+ IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, featureExtractor, lbExtractor);
+ DatasetMapping<L, L> mapping = datasetMapping.apply(mdl1);
+
+ IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder,
+ featureExtractor.andThen(mapping::mapFeatures),
+ lbExtractor.andThen(mapping::mapLabels));
+
+ return new ModelsSequentialComposition<>(mdl1, mdl2);
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V> ModelsSequentialComposition<I, O1, O2> update(
+ ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+
+ IgniteModel<I, O1> firstUpdated = tr1.update(mdl.firstModel(), datasetBuilder, featureExtractor, lbExtractor);
+ DatasetMapping<L, L> mapping = datasetMapping.apply(firstUpdated);
+
+ IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(),
+ datasetBuilder,
+ featureExtractor.andThen(mapping::mapFeatures),
+ lbExtractor.andThen(mapping::mapLabels));
+
+ return new ModelsSequentialComposition<>(firstUpdated, secondUpdated);
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return True if current critical for training parameters correspond to parameters from last training.
+ */
+ @Override public boolean isUpdateable(ModelsSequentialComposition<I, O1, O2> mdl) {
+ // Never called.
+ throw new IllegalStateException();
+ }
+
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
+ @Override protected <K, V> ModelsSequentialComposition<I, O1, O2> updateModel(
+ ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ // Never called.
+ throw new IllegalStateException();
+ }
+
+ /**
+ * Performs coercion of this trainer to {@code DatasetTrainer<IgniteModel<I, O2>, L>}.
+ *
+ * @return Trainer coerced to {@code DatasetTrainer<IgniteModel<I, O2>, L>}.
+ */
+ public DatasetTrainer<IgniteModel<I, O2>, L> unsafeSimplyTyped() {
+ return CompositionUtils.unsafeCoerce(this);
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java
new file mode 100644
index 000000000000..02ca2df8c376
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/package-info.java
@@ -0,0 +1,22 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * <!-- Package description. -->
+ * Contains sequential combinators of trainers and models.
+ */
+package org.apache.ignite.ml.composition.combinators.sequential;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
index e58107d223f2..45fcecc2dc30 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
@@ -21,15 +21,18 @@
import java.util.List;
import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
+import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.environment.parallelism.Promise;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
/**
@@ -220,31 +223,7 @@ public StackedDatasetTrainer() {
// Unsafely coerce DatasetTrainer<M1, L> to DatasetTrainer<Model<IS, IA>, L>, but we fully control
// usages of this unsafely coerced object, on the other hand this makes work with
// submodelTrainers easier.
- submodelsTrainers.add(new DatasetTrainer<IgniteModel<IS, IA>, L>() {
- /** {@inheritDoc} */
- @Override public <K, V> IgniteModel<IS, IA> fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
- return trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
- }
-
- /** {@inheritDoc} */
- @Override public <K, V> IgniteModel<IS, IA> update(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
- DatasetTrainer<IgniteModel<IS, IA>, L> trainer1 = (DatasetTrainer<IgniteModel<IS, IA>, L>)trainer;
- return trainer1.update(mdl, datasetBuilder, featureExtractor, lbExtractor);
- }
-
- /** {@inheritDoc} */
- @Override protected boolean checkState(IgniteModel<IS, IA> mdl) {
- return true;
- }
-
- /** {@inheritDoc} */
- @Override protected <K, V> IgniteModel<IS, IA> updateModel(IgniteModel<IS, IA> mdl, DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
- return null;
- }
- });
+ submodelsTrainers.add(CompositionUtils.unsafeCoerce(trainer));
return this;
}
@@ -254,62 +233,60 @@ public StackedDatasetTrainer() {
IgniteBiFunction<K, V, Vector> featureExtractor,
IgniteBiFunction<K, V, L> lbExtractor) {
- return update(null, datasetBuilder, featureExtractor, lbExtractor);
+ return new StackedModel<>(getTrainer().fit(datasetBuilder, featureExtractor, lbExtractor));
}
/** {@inheritDoc} */
@Override public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl,
DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
IgniteBiFunction<K, V, L> lbExtractor) {
- return runOnSubmodels(
- ensemble -> {
- List<IgniteSupplier<IgniteModel<IS, IA>>> res = new ArrayList<>();
- for (int i = 0; i < ensemble.size(); i++) {
- final int j = i;
- res.add(() -> {
- DatasetTrainer<IgniteModel<IS, IA>, L> trainer = ensemble.get(j);
- return mdl == null ?
- trainer.fit(datasetBuilder, featureExtractor, lbExtractor) :
- trainer.update(mdl.submodels().get(j), datasetBuilder, featureExtractor, lbExtractor);
- });
- }
- return res;
- },
- (at, extr) -> mdl == null ?
- at.fit(datasetBuilder, extr, lbExtractor) :
- at.update(mdl.aggregatorModel(), datasetBuilder, extr, lbExtractor),
- featureExtractor
- );
- }
- /** {@inheritDoc} */
- @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(
- LearningEnvironmentBuilder envBuilder) {
- submodelsTrainers =
- submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
- aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder);
-
- return this;
+ return new StackedModel<>(getTrainer().update(mdl, datasetBuilder, featureExtractor, lbExtractor));
}
/**
- * <pre>
- * 1. Obtain models produced by running specified tasks;
- * 2. run other specified task on dataset augmented with results of models from step 2.
- * </pre>
+ * Get the trainer for stacking.
*
- * @param taskSupplier Function used to generate tasks for first step.
- * @param aggregatorProcessor Function used
- * @param featureExtractor Feature extractor.
- * @param <K> Type of keys in upstream.
- * @param <V> Type of values in upstream.
- * @return {@link StackedModel}.
+ * @return Trainer for stacking.
*/
- private <K, V> StackedModel<IS, IA, O, AM> runOnSubmodels(
- IgniteFunction<List<DatasetTrainer<IgniteModel<IS, IA>, L>>, List<IgniteSupplier<IgniteModel<IS, IA>>>> taskSupplier,
- IgniteBiFunction<DatasetTrainer<AM, L>, IgniteBiFunction<K, V, Vector>, AM> aggregatorProcessor,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ private DatasetTrainer<IgniteModel<IS, O>, L> getTrainer() {
+ checkConsistency();
+
+ List<DatasetTrainer<IgniteModel<IS, IA>, L>> subs = new ArrayList<>();
+ if (submodelInput2AggregatingInputConverter != null) {
+ DatasetTrainer<IgniteModel<IS, IS>, L> id = DatasetTrainer.identityTrainer();
+ DatasetTrainer<IgniteModel<IS, IA>, L> mappedId = CompositionUtils.unsafeCoerce(
+ AdaptableDatasetTrainer.of(id).afterTrainedModel(submodelInput2AggregatingInputConverter));
+ subs.add(mappedId);
+ }
+
+ subs.addAll(submodelsTrainers);
+
+ TrainersParallelComposition<IS, IA, L> composition = new TrainersParallelComposition<>(subs);
+ IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> featureMapper = getFeatureExtractorForAggregator(
+ submodelOutput2VectorConverter,
+ vector2SubmodelInputConverter);
+
+ return AdaptableDatasetTrainer
+ .of(composition)
+ .afterTrainedModel(lst -> lst.stream().reduce(aggregatingInputMerger).get())
+ .andThen(aggregatorTrainer, model -> new DatasetMapping<L, L>() {
+ @Override public Vector mapFeatures(Vector v) {
+ List<IgniteModel<IS, IA>> models = ((ModelsParallelComposition<IS, IA>)model.innerModel()).submodels();
+ return featureMapper.apply(models, v);
+ }
+
+ @Override public L mapLabels(L lbl) {
+ return lbl;
+ }
+ }).unsafeSimplyTyped();
+ }
+
+ /**
+ * Method checking consistency of this trainer.
+ */
+ private void checkConsistency() {
// Make sure there is at least one way for submodel input to propagate to aggregator.
if (submodelInput2AggregatingInputConverter == null && submodelsTrainers.isEmpty())
throw new IllegalStateException("There should be at least one way for submodels " +
@@ -321,60 +298,36 @@ public StackedDatasetTrainer() {
if (aggregatingInputMerger == null)
throw new IllegalStateException("Binary operator used to convert outputs of submodels is not specified");
+ }
- List<IgniteSupplier<IgniteModel<IS, IA>>> mdlSuppliers = taskSupplier.apply(submodelsTrainers);
-
- List<IgniteModel<IS, IA>> subMdls = environment.parallelismStrategy().submit(mdlSuppliers).stream()
- .map(Promise::unsafeGet)
- .collect(Collectors.toList());
-
- // Add new columns consisting in submodels output in features.
- IgniteBiFunction<K, V, Vector> augmentedExtractor = getFeatureExtractorForAggregator(featureExtractor,
- subMdls,
- submodelInput2AggregatingInputConverter,
- submodelOutput2VectorConverter,
- vector2SubmodelInputConverter);
-
- AM aggregator = aggregatorProcessor.apply(aggregatorTrainer, augmentedExtractor);
-
- StackedModel<IS, IA, O, AM> res = new StackedModel<>(
- aggregator,
- aggregatingInputMerger,
- submodelInput2AggregatingInputConverter);
-
- for (IgniteModel<IS, IA> subMdl : subMdls)
- res.addSubmodel(subMdl);
+ /** {@inheritDoc} */
+ @Override public StackedDatasetTrainer<IS, IA, O, AM, L> withEnvironmentBuilder(
+ LearningEnvironmentBuilder envBuilder) {
+ submodelsTrainers =
+ submodelsTrainers.stream().map(x -> x.withEnvironmentBuilder(envBuilder)).collect(Collectors.toList());
+ aggregatorTrainer = aggregatorTrainer.withEnvironmentBuilder(envBuilder);
- return res;
+ return this;
}
/**
* Get feature extractor which will be used for aggregator trainer from original feature extractor.
* This method is static to make sure that we will not grab context of instance in serialization.
*
- * @param featureExtractor Original feature extractor.
- * @param subMdls Submodels.
+ * @param <IS> Type of submodels input.
+ * @param <IA> Type of aggregator input.
* @param <K> Type of upstream keys.
- * @param <V> Type of upstream values.
+ * @param <V> Type of upstream values.
* @return Feature extractor which will be used for aggregator trainer from original feature extractor.
*/
- private static <IS, IA, K, V> IgniteBiFunction<K, V, Vector> getFeatureExtractorForAggregator(
- IgniteBiFunction<K, V, Vector> featureExtractor, List<IgniteModel<IS, IA>> subMdls,
- IgniteFunction<IS, IA> submodelInput2AggregatingInputConverter,
+ private static <IS, IA, K, V> IgniteBiFunction<List<IgniteModel<IS, IA>>, Vector, Vector> getFeatureExtractorForAggregator(
IgniteFunction<IA, Vector> submodelOutput2VectorConverter,
IgniteFunction<Vector, IS> vector2SubmodelInputConverter) {
- if (submodelInput2AggregatingInputConverter != null)
- return featureExtractor.andThen((Vector v) -> {
- Vector[] vs = subMdls.stream().map(sm ->
- applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
- return VectorUtils.concat(v, vs);
- });
- else
- return featureExtractor.andThen((Vector v) -> {
- Vector[] vs = subMdls.stream().map(sm ->
- applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
- return VectorUtils.concat(vs);
- });
+ return (List<IgniteModel<IS, IA>> subMdls, Vector v) -> {
+ Vector[] vs = subMdls.stream().map(sm ->
+ applyToVector(sm, submodelOutput2VectorConverter, vector2SubmodelInputConverter, v)).toArray(Vector[]::new);
+ return VectorUtils.concat(vs);
+ };
}
/**
@@ -396,17 +349,34 @@ public StackedDatasetTrainer() {
return vector2SubmodelInputConverter.andThen(mdl::predict).andThen(submodelOutput2VectorConverter).apply(v);
}
- /** {@inheritDoc} */
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
@Override protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl,
DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor,
IgniteBiFunction<K, V, L> lbExtractor) {
// This method is never called, we override "update" instead.
- return null;
+ throw new IllegalStateException();
}
- /** {@inheritDoc} */
- @Override protected boolean checkState(StackedModel<IS, IA, O, AM> mdl) {
- return true;
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Never returns normally; always throws {@link IllegalStateException} because this method must not be called.
+ */
+ @Override public boolean isUpdateable(StackedModel<IS, IA, O, AM> mdl) {
+ // Should never be called.
+ throw new IllegalStateException();
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java
index a9be8f8f91ac..34e1a97a91ff 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedModel.java
@@ -17,19 +17,17 @@
package org.apache.ignite.ml.composition.stacking;
-import java.util.ArrayList;
-import java.util.List;
import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
/**
+ * This is a wrapper for model produced by {@link StackedDatasetTrainer}.
* Model consisting of two layers:
* <pre>
* 1. Submodels layer {@code (IS -> IA)}.
* 2. Aggregator layer {@code (IA -> O)}.
* </pre>
- * Submodels layer is a "parallel" composition of several models {@code IS -> IA} each of them getting same input
+ * Submodels layer is a {@link ModelsParallelComposition} of several models {@code IS -> IA} each of them getting same input
* {@code IS} and produce own output, these outputs {@code [IA]}
* are combined into a single output with a given binary "merger" operator {@code IA -> IA -> IA}. Result of merge
* is then passed to the aggregator layer.
@@ -41,66 +39,24 @@
* @param <AM> Type of aggregator model.
*/
public class StackedModel<IS, IA, O, AM extends IgniteModel<IA, O>> implements IgniteModel<IS, O> {
- /** Submodels layer. */
- private IgniteModel<IS, IA> subModelsLayer;
-
- /** Aggregator model. */
- private final AM aggregatorMdl;
-
- /** Models constituting submodels layer. */
- private List<IgniteModel<IS, IA>> submodels;
-
- /** Binary operator merging submodels outputs. */
- private final IgniteBinaryOperator<IA> aggregatingInputMerger;
-
- /**
- * Constructs instance of this class.
- *
- * @param aggregatorMdl Aggregator model.
- * @param aggregatingInputMerger Binary operator used to merge submodels outputs.
- * @param subMdlInput2AggregatingInput Function converting submodels input to aggregator input. (This function
- * is needed when in {@link StackedDatasetTrainer} option to keep original features is chosen).
- */
- StackedModel(AM aggregatorMdl,
- IgniteBinaryOperator<IA> aggregatingInputMerger,
- IgniteFunction<IS, IA> subMdlInput2AggregatingInput) {
- this.aggregatorMdl = aggregatorMdl;
- this.aggregatingInputMerger = aggregatingInputMerger;
- this.subModelsLayer = subMdlInput2AggregatingInput != null ? subMdlInput2AggregatingInput::apply : null;
- submodels = new ArrayList<>();
- }
-
- /**
- * Get submodels constituting first layer of this model.
- *
- * @return Submodels constituting first layer of this model.
- */
- List<IgniteModel<IS, IA>> submodels() {
- return submodels;
- }
+ /** Model to wrap. */
+ private IgniteModel<IS, O> mdl;
/**
- * Get aggregator model.
- *
- * @return Aggregator model.
+ * Construct instance of this class from {@link IgniteModel}.
+ * @param mdl Model to wrap.
*/
- AM aggregatorModel() {
- return aggregatorMdl;
+ StackedModel(IgniteModel<IS, O> mdl) {
+ this.mdl = mdl;
}
- /**
- * Add submodel into first layer.
- *
- * @param subMdl Submodel to add.
- */
- void addSubmodel(IgniteModel<IS, IA> subMdl) {
- submodels.add(subMdl);
- subModelsLayer = subModelsLayer != null ? subModelsLayer.combine(subMdl, aggregatingInputMerger)
- : subMdl;
+ /** {@inheritDoc} */
+ @Override public O predict(IS is) {
+ return mdl.predict(is);
}
/** {@inheritDoc} */
- @Override public O predict(IS is) {
- return subModelsLayer.andThen(aggregatorMdl).predict(is);
+ @Override public void close() {
+ mdl.close();
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java
index 7230e3ceee30..c25b721c1952 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedVectorDatasetTrainer.java
@@ -81,6 +81,7 @@ public StackedVectorDatasetTrainer() {
}
/** {@inheritDoc} */
+ // TODO: IGNITE-10843 Add possibility to keep features with specific indices.
@Override public StackedVectorDatasetTrainer<O, AM, L> withOriginalFeaturesKept(
IgniteFunction<Vector, Vector> submodelInput2AggregatingInputConverter) {
return (StackedVectorDatasetTrainer<O, AM, L>)super.withOriginalFeaturesKept(
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
index 990065954b1c..c826a40632a8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetBuilder.java
@@ -67,7 +67,7 @@
* @return Returns new instance of {@link DatasetBuilder} with new {@link UpstreamTransformerBuilder} added
* to chain of upstream transformer builders.
*/
- public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder);
+ public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder);
/**
* Returns new instance of DatasetBuilder using conjunction of internal filter and {@code filterToAdd}.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
index 9c0e281f3419..c7fb92f17ded 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformer.java
@@ -22,29 +22,15 @@
/**
* Interface of transformer of upstream.
- *
- * @param <K> Type of keys in the upstream.
- * @param <V> Type of values in the upstream.
*/
// TODO: IGNITE-10297: Investigate possibility of API change.
@FunctionalInterface
-public interface UpstreamTransformer<K, V> extends Serializable {
+public interface UpstreamTransformer extends Serializable {
/**
* Transform upstream.
*
* @param upstream Upstream to transform.
* @return Transformed upstream.
*/
- public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream);
-
- /**
- * Get composition of this transformer and other transformer which is
- * itself is {@link UpstreamTransformer} applying this transformer and then other transformer.
- *
- * @param other Other transformer.
- * @return Composition of this and other transformer.
- */
- public default UpstreamTransformer<K, V> andThen(UpstreamTransformer<K, V> other) {
- return upstream -> other.transform(transform(upstream));
- }
+ public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java
index 9adfab56d048..ea9f12669ab3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/UpstreamTransformerBuilder.java
@@ -21,19 +21,17 @@
import org.apache.ignite.ml.environment.LearningEnvironment;
/**
- * Builder of {@link UpstreamTransformerBuilder}.
- * @param <K> Type of keys in upstream.
- * @param <V> Type of values in upstream.
+ * Builder of {@link UpstreamTransformer}.
*/
@FunctionalInterface
-public interface UpstreamTransformerBuilder<K, V> extends Serializable {
+public interface UpstreamTransformerBuilder extends Serializable {
/**
* Create {@link UpstreamTransformer} based on learning environment.
*
* @param env Learning environment.
* @return Upstream transformer.
*/
- public UpstreamTransformer<K, V> build(LearningEnvironment env);
+ public UpstreamTransformer build(LearningEnvironment env);
/**
* Combines two builders (this and other respectively)
@@ -49,11 +47,11 @@
* @param other Builder to combine with.
* @return Compositional builder.
*/
- public default UpstreamTransformerBuilder<K, V> andThen(UpstreamTransformerBuilder<K, V> other) {
- UpstreamTransformerBuilder<K, V> self = this;
+ public default UpstreamTransformerBuilder andThen(UpstreamTransformerBuilder other) {
+ UpstreamTransformerBuilder self = this;
return env -> {
- UpstreamTransformer<K, V> transformer1 = self.build(env);
- UpstreamTransformer<K, V> transformer2 = other.build(env);
+ UpstreamTransformer transformer1 = self.build(env);
+ UpstreamTransformer transformer2 = other.build(env);
return upstream -> transformer2.transform(transformer1.transform(upstream));
};
@@ -66,7 +64,7 @@
* @param <V> Type of values in upstream.
* @return Identity upstream transformer.
*/
- public static <K, V> UpstreamTransformerBuilder<K, V> identity() {
+ public static <K, V> UpstreamTransformerBuilder identity() {
return env -> upstream -> upstream;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
index bde4bb6d8f3b..b2aa00b6e890 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDataset.java
@@ -64,7 +64,7 @@
private final IgniteBiPredicate<K, V> filter;
/** Builder of transformation applied to upstream. */
- private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
+ private final UpstreamTransformerBuilder upstreamTransformerBuilder;
/** Ignite Cache with partition {@code context}. */
private final IgniteCache<Integer, C> datasetCache;
@@ -94,7 +94,7 @@ public CacheBasedDataset(
Ignite ignite,
IgniteCache<K, V> upstreamCache,
IgniteBiPredicate<K, V> filter,
- UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder,
+ UpstreamTransformerBuilder upstreamTransformerBuilder,
IgniteCache<Integer, C> datasetCache,
LearningEnvironmentBuilder envBuilder,
PartitionDataBuilder<K, V, C, D> partDataBuilder,
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
index be40158bbd8b..b85bfc2a70d5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/CacheBasedDatasetBuilder.java
@@ -59,7 +59,7 @@
private final IgniteBiPredicate<K, V> filter;
/** Upstream transformer builder. */
- private final UpstreamTransformerBuilder<K, V> transformerBuilder;
+ private final UpstreamTransformerBuilder transformerBuilder;
/**
* Constructs a new instance of cache based dataset builder that makes {@link CacheBasedDataset} with default
@@ -93,7 +93,7 @@ public CacheBasedDatasetBuilder(Ignite ignite, IgniteCache<K, V> upstreamCache,
public CacheBasedDatasetBuilder(Ignite ignite,
IgniteCache<K, V> upstreamCache,
IgniteBiPredicate<K, V> filter,
- UpstreamTransformerBuilder<K, V> transformerBuilder) {
+ UpstreamTransformerBuilder transformerBuilder) {
this.ignite = ignite;
this.upstreamCache = upstreamCache;
this.filter = filter;
@@ -136,7 +136,7 @@ public CacheBasedDatasetBuilder(Ignite ignite,
}
/** {@inheritDoc} */
- @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+ @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) {
return new CacheBasedDatasetBuilder<>(ignite, upstreamCache, filter, transformerBuilder.andThen(builder));
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
index 7fa1efa573b1..6bda657447c6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/cache/util/ComputeUtils.java
@@ -45,8 +45,8 @@
import org.apache.ignite.ml.dataset.UpstreamEntry;
import org.apache.ignite.ml.dataset.UpstreamTransformer;
import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
-import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.util.Utils;
@@ -185,7 +185,7 @@ public static LearningEnvironment getLearningEnvironment(Ignite ignite,
public static <K, V, C extends Serializable, D extends AutoCloseable> D getData(
Ignite ignite,
String upstreamCacheName, IgniteBiPredicate<K, V> filter,
- UpstreamTransformerBuilder<K, V> transformerBuilder,
+ UpstreamTransformerBuilder transformerBuilder,
String datasetCacheName, UUID datasetId,
PartitionDataBuilder<K, V, C, D> partDataBuilder,
LearningEnvironment env) {
@@ -208,8 +208,8 @@ public static LearningEnvironment getLearningEnvironment(Ignite ignite,
qry.setPartition(part);
qry.setFilter(filter);
- UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
- UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
+ UpstreamTransformer transformer = transformerBuilder.build(env);
+ UpstreamTransformer transformerCp = Utils.copy(transformer);
long cnt = computeCount(upstreamCache, qry, transformer);
@@ -218,9 +218,8 @@ public static LearningEnvironment getLearningEnvironment(Ignite ignite,
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
- Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
- it = transformedStream.iterator();
-
+ Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x));
+ it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator();
Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(it, cnt,
"Cache expected to be not modified during dataset data building [partition=" + part + ']');
@@ -268,7 +267,7 @@ public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
public static <K, V, C extends Serializable> void initContext(
Ignite ignite,
String upstreamCacheName,
- UpstreamTransformerBuilder<K, V> transformerBuilder,
+ UpstreamTransformerBuilder transformerBuilder,
IgniteBiPredicate<K, V> filter,
String datasetCacheName,
PartitionContextBuilder<K, V, C> ctxBuilder,
@@ -287,8 +286,8 @@ public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
qry.setFilter(filter);
C ctx;
- UpstreamTransformer<K, V> transformer = transformerBuilder.build(env);
- UpstreamTransformer<K, V> transformerCp = Utils.copy(transformer);
+ UpstreamTransformer transformer = transformerBuilder.build(env);
+ UpstreamTransformer transformerCp = Utils.copy(transformer);
long cnt = computeCount(locUpstreamCache, qry, transformer);
@@ -296,8 +295,8 @@ public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
Iterator<UpstreamEntry<K, V>> it = cursor.iterator();
- Stream<UpstreamEntry<K, V>> transformedStream = transformerCp.transform(Utils.asStream(it, cnt));
- it = transformedStream.iterator();
+ Stream<UpstreamEntry> transformedStream = transformerCp.transform(Utils.asStream(it, cnt).map(x -> (UpstreamEntry)x));
+ it = Utils.asStream(transformedStream.iterator()).map(x -> (UpstreamEntry<K, V>)x).iterator();
Iterator<UpstreamEntry<K, V>> iter = new IteratorWithConcurrentModificationChecker<>(
it,
@@ -334,7 +333,7 @@ public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
Ignite ignite,
String upstreamCacheName,
IgniteBiPredicate<K, V> filter,
- UpstreamTransformerBuilder<K, V> transformerBuilder,
+ UpstreamTransformerBuilder transformerBuilder,
String datasetCacheName,
PartitionContextBuilder<K, V, C> ctxBuilder,
LearningEnvironmentBuilder envBuilder,
@@ -382,11 +381,11 @@ public static void removeLearningEnv(Ignite ignite, UUID datasetId) {
private static <K, V> long computeCount(
IgniteCache<K, V> cache,
ScanQuery<K, V> qry,
- UpstreamTransformer<K, V> transformer) {
+ UpstreamTransformer transformer) {
try (QueryCursor<UpstreamEntry<K, V>> cursor = cache.query(qry,
e -> new UpstreamEntry<>(e.getKey(), e.getValue()))) {
- return computeCount(transformer.transform(Utils.asStream(cursor.iterator())).iterator());
+ return computeCount(transformer.transform(Utils.asStream(cursor.iterator()).map(x -> (UpstreamEntry<K, V>)x)).iterator());
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
index b8cd8dc685f8..84f3e087ccd3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/local/LocalDatasetBuilder.java
@@ -54,7 +54,7 @@
private final IgniteBiPredicate<K, V> filter;
/** Upstream transformers. */
- private final UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder;
+ private final UpstreamTransformerBuilder upstreamTransformerBuilder;
/**
* Constructs a new instance of local dataset builder that makes {@link LocalDataset} with default predicate that
@@ -78,7 +78,7 @@ public LocalDatasetBuilder(Map<K, V> upstreamMap, int partitions) {
public LocalDatasetBuilder(Map<K, V> upstreamMap,
IgniteBiPredicate<K, V> filter,
int partitions,
- UpstreamTransformerBuilder<K, V> upstreamTransformerBuilder) {
+ UpstreamTransformerBuilder upstreamTransformerBuilder) {
this.upstreamMap = upstreamMap;
this.filter = filter;
this.partitions = partitions;
@@ -129,23 +129,26 @@ public LocalDatasetBuilder(Map<K, V> upstreamMap,
int cntBeforeTransform =
part == partitions - 1 ? entriesList.size() - ptr : Math.min(partSize, entriesList.size() - ptr);
LearningEnvironment env = envs.get(part);
- UpstreamTransformer<K, V> transformer1 = upstreamTransformerBuilder.build(env);
- UpstreamTransformer<K, V> transformer2 = Utils.copy(transformer1);
- UpstreamTransformer<K, V> transformer3 = Utils.copy(transformer1);
+ UpstreamTransformer transformer1 = upstreamTransformerBuilder.build(env);
+ UpstreamTransformer transformer2 = Utils.copy(transformer1);
+ UpstreamTransformer transformer3 = Utils.copy(transformer1);
int cnt = (int)transformer1.transform(Utils.asStream(new IteratorWindow<>(thirdKeysIter, k -> k, cntBeforeTransform))).count();
- Iterator<UpstreamEntry<K, V>> iter =
- transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform))).iterator();
+ Iterator<UpstreamEntry> iter =
+ transformer2.transform(Utils.asStream(new IteratorWindow<>(firstKeysIter, k -> k, cntBeforeTransform)).map(x -> (UpstreamEntry)x)).iterator();
+ Iterator<UpstreamEntry<K, V>> convertedBack = Utils.asStream(iter).map(x -> (UpstreamEntry<K, V>)x).iterator();
- C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, iter, cnt) : null;
+ C ctx = cntBeforeTransform > 0 ? partCtxBuilder.build(env, convertedBack, cnt) : null;
- Iterator<UpstreamEntry<K, V>> iter1 = transformer3.transform(
+ Iterator<UpstreamEntry> iter1 = transformer3.transform(
Utils.asStream(new IteratorWindow<>(secondKeysIter, k -> k, cntBeforeTransform))).iterator();
+ Iterator<UpstreamEntry<K, V>> convertedBack1 = Utils.asStream(iter1).map(x -> (UpstreamEntry<K, V>)x).iterator();
+
D data = cntBeforeTransform > 0 ? partDataBuilder.build(
env,
- iter1,
+ convertedBack1,
cnt,
ctx
) : null;
@@ -160,7 +163,7 @@ public LocalDatasetBuilder(Map<K, V> upstreamMap,
}
/** {@inheritDoc} */
- @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder<K, V> builder) {
+ @Override public DatasetBuilder<K, V> withUpstreamTransformer(UpstreamTransformerBuilder builder) {
return new LocalDatasetBuilder<>(upstreamMap, filter, partitions, upstreamTransformerBuilder.andThen(builder));
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index c32ca5673ac2..0cdfc5235d75 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -102,7 +102,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(ANNClassificationModel mdl) {
+ @Override public boolean isUpdateable(ANNClassificationModel mdl) {
return mdl.getDistanceMeasure().equals(distance) && mdl.getCandidates().rowSize() == k;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
index c52ad2b3a7df..16bf1862547b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/classification/KNNClassificationTrainer.java
@@ -60,7 +60,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(KNNClassificationModel mdl) {
+ @Override public boolean isUpdateable(KNNClassificationModel mdl) {
return true;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
index 9b348f3d9c3a..e6218018d830 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
@@ -56,7 +56,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(KNNRegressionModel mdl) {
+ @Override public boolean isUpdateable(KNNRegressionModel mdl) {
return true;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
index 4eca27f23760..a44b5b437563 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/multiclass/OneVsRestTrainer.java
@@ -101,7 +101,7 @@ public OneVsRestTrainer(SingleLabelDatasetTrainer<M> classifier) {
}
/** {@inheritDoc} */
- @Override protected boolean checkState(MultiClassModel<M> mdl) {
+ @Override public boolean isUpdateable(MultiClassModel<M> mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
index 0779b84a221a..0179b31007ac 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesTrainer.java
@@ -59,7 +59,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(DiscreteNaiveBayesModel mdl) {
+ @Override public boolean isUpdateable(DiscreteNaiveBayesModel mdl) {
if (mdl.getBucketThresholds().length != bucketThresholds.length)
return false;
@@ -124,7 +124,7 @@
return a.merge(b);
});
- if (mdl != null && checkState(mdl)) {
+ if (mdl != null && isUpdateable(mdl)) {
if (checkSumsHolder(sumsHolder, mdl.getSumsHolder()))
sumsHolder = sumsHolder.merge(mdl.getSumsHolder());
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
index cdaac5ab1be1..c4ef1bd85743 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesTrainer.java
@@ -55,7 +55,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(GaussianNaiveBayesModel mdl) {
+ @Override public boolean isUpdateable(GaussianNaiveBayesModel mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
index ea0bb6cce586..cf511ec4aba5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
@@ -354,7 +354,7 @@ public long getSeed() {
}
/** {@inheritDoc} */
- @Override protected boolean checkState(MultilayerPerceptron mdl) {
+ @Override public boolean isUpdateable(MultilayerPerceptron mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
index 6b2b11e22b00..e2736330f432 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
@@ -79,7 +79,7 @@
}
/** {@inheritDoc} */
- @Override public boolean checkState(LinearRegressionModel mdl) {
+ @Override public boolean isUpdateable(LinearRegressionModel mdl) {
return true;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
index 4132d359047b..7dc4df6fb13f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -160,7 +160,7 @@ public LinearRegressionSGDTrainer(UpdatesStrategy<? super MultilayerPerceptron,
}
/** {@inheritDoc} */
- @Override protected boolean checkState(LinearRegressionModel mdl) {
+ @Override public boolean isUpdateable(LinearRegressionModel mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
index 864187db6ec2..16ffac318457 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionSGDTrainer.java
@@ -139,7 +139,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(LogisticRegressionModel mdl) {
+ @Override public boolean isUpdateable(LogisticRegressionModel mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
index 67484ea59870..90bbe379ffd9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
@@ -121,7 +121,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(SVMLinearClassificationModel mdl) {
+ @Override public boolean isUpdateable(SVMLinearClassificationModel mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
index 4205286a4062..4695946387c6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/AdaptableDatasetTrainer.java
@@ -18,7 +18,10 @@
package org.apache.ignite.ml.trainers;
import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.DatasetMapping;
+import org.apache.ignite.ml.composition.combinators.sequential.TrainersSequentialComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.UpstreamTransformerBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -46,6 +49,15 @@
/** Function used to convert output type of wrapped trainer. */
private final IgniteFunction<OW, O> after;
+ /** Function which is applied after feature extractor. */
+ private final IgniteFunction<Vector, Vector> afterFeatureExtractor;
+
+ /** Function which is applied after label extractor. */
+ private final IgniteFunction<L, L> afterLabelExtractor;
+
+ /** Upstream transformer builder which will be used in dataset builder. */
+ private final UpstreamTransformerBuilder upstreamTransformerBuilder;
+
/**
* Construct instance of this class from a given {@link DatasetTrainer}.
*
@@ -56,39 +68,65 @@
* @param <L> Type of labels.
* @return Instance of this class.
*/
- public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(DatasetTrainer<M, L> wrapped) {
- return new AdaptableDatasetTrainer<>(IgniteFunction.identity(), wrapped, IgniteFunction.identity());
+ public static <I, O, M extends IgniteModel<I, O>, L> AdaptableDatasetTrainer<I, O, I, O, M, L> of(
+ DatasetTrainer<M, L> wrapped) {
+ return new AdaptableDatasetTrainer<>(IgniteFunction.identity(),
+ wrapped,
+ IgniteFunction.identity(),
+ IgniteFunction.identity(),
+ IgniteFunction.identity(),
+ UpstreamTransformerBuilder.identity());
}
/**
* Construct instance of this class with specified wrapped trainer and converter functions.
*
* @param before Function used to convert input type of wrapped trainer.
- * @param wrapped Wrapped trainer.
+ * @param wrapped Wrapped trainer.
* @param after Function used to convert output type of wrapped trainer.
+ * @param extractor Function which is applied after label extractor.
+ * @param builder Upstream transformer builder which will be used in dataset builder.
*/
- private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped, IgniteFunction<OW, O> after) {
+ private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M, L> wrapped,
+ IgniteFunction<OW, O> after,
+ IgniteFunction<Vector, Vector> afterFeatureExtractor,
+ IgniteFunction<L, L> extractor, UpstreamTransformerBuilder builder) {
this.before = before;
this.wrapped = wrapped;
this.after = after;
+ this.afterFeatureExtractor = afterFeatureExtractor;
+ afterLabelExtractor = extractor;
+ upstreamTransformerBuilder = builder;
}
/** {@inheritDoc} */
@Override public <K, V> AdaptableDatasetModel<I, O, IW, OW, M> fit(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
- M fit = wrapped.fit(datasetBuilder, featureExtractor, lbExtractor);
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ M fit = wrapped.fit(
+ datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+ featureExtractor.andThen(afterFeatureExtractor),
+ lbExtractor.andThen(afterLabelExtractor));
+
return new AdaptableDatasetModel<>(before, fit, after);
}
/** {@inheritDoc} */
- @Override protected boolean checkState(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
- return wrapped.checkState(mdl.innerModel());
+ @Override public boolean isUpdateable(AdaptableDatasetModel<I, O, IW, OW, M> mdl) {
+ return wrapped.isUpdateable(mdl.innerModel());
}
/** {@inheritDoc} */
- @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder,
+ @Override protected <K, V> AdaptableDatasetModel<I, O, IW, OW, M> updateModel(
+ AdaptableDatasetModel<I, O, IW, OW, M> mdl, DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
- return mdl.withInnerModel(wrapped.updateModel(mdl.innerModel(), datasetBuilder, featureExtractor, lbExtractor));
+ M updated = wrapped.updateModel(
+ mdl.innerModel(),
+ datasetBuilder.withUpstreamTransformer(upstreamTransformerBuilder),
+ featureExtractor.andThen(afterFeatureExtractor),
+ lbExtractor.andThen(afterLabelExtractor));
+
+ return mdl.withInnerModel(updated);
}
/**
@@ -101,7 +139,12 @@ private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M,
* original trainer.
*/
public <O1> AdaptableDatasetTrainer<I, O1, IW, OW, M, L> afterTrainedModel(IgniteFunction<O, O1> after) {
- return new AdaptableDatasetTrainer<>(before, wrapped, i -> after.apply(this.after.apply(i)));
+ return new AdaptableDatasetTrainer<>(before,
+ wrapped,
+ i -> after.apply(this.after.apply(i)),
+ afterFeatureExtractor,
+ afterLabelExtractor,
+ upstreamTransformerBuilder);
}
/**
@@ -115,6 +158,116 @@ private AdaptableDatasetTrainer(IgniteFunction<I, IW> before, DatasetTrainer<M,
*/
public <I1> AdaptableDatasetTrainer<I1, O, IW, OW, M, L> beforeTrainedModel(IgniteFunction<I1, I> before) {
IgniteFunction<I1, IW> function = i -> this.before.apply(before.apply(i));
- return new AdaptableDatasetTrainer<>(function, wrapped, after);
+ return new AdaptableDatasetTrainer<>(function,
+ wrapped,
+ after,
+ afterFeatureExtractor,
+ afterLabelExtractor,
+ upstreamTransformerBuilder);
+ }
+
+ /**
+ * Specify {@link DatasetMapping} which will be applied to dataset before fitting and updating.
+ *
+ * @param mapping {@link DatasetMapping} which will be applied to dataset before fitting and updating.
+ * @return New trainer of the same type, but with specified mapping applied to dataset before fitting and updating.
+ */
+ public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withDatasetMapping(DatasetMapping<L, L> mapping) {
+ return of(new DatasetTrainer<M, L>() {
+ @Override public <K, V> M fit(
+ DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ IgniteBiFunction<K, V, Vector> fe = featureExtractor.andThen(mapping::mapFeatures);
+ IgniteBiFunction<K, V, L> le = lbExtractor.andThen(mapping::mapLabels);
+
+ return wrapped.fit(datasetBuilder,
+ fe,
+ le);
+ }
+
+ @Override public <K, V> M update(M mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ return wrapped.update(mdl, datasetBuilder,
+ featureExtractor.andThen(mapping::mapFeatures),
+ lbExtractor.andThen((IgniteFunction<L, L>)mapping::mapLabels));
+ }
+
+ @Override public boolean isUpdateable(M mdl) {
+ return false;
+ }
+
+ @Override protected <K, V> M updateModel(M mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ return null;
+ }
+ }).beforeTrainedModel(before).afterTrainedModel(after);
+ }
+
+ /**
+ * Create a {@link TrainersSequentialComposition} of this trainer and specified trainer.
+ *
+ * @param tr Trainer to compose with.
+ * @param datasetMappingProducer {@link DatasetMapping} producer specifying dependency between this trainer and
+ * trainer to compose with.
+ * @param <O1> Type of output of trainer to compose with.
+ * @param <M1> Type of model produced by the trainer to compose with.
+ * @return A {@link TrainersSequentialComposition} of this trainer and specified trainer.
+ */
+ public <O1, M1 extends IgniteModel<O, O1>> TrainersSequentialComposition<I, O, O1, L> andThen(
+ DatasetTrainer<M1, L> tr,
+ IgniteFunction<AdaptableDatasetModel<I, O, IW, OW, M>, DatasetMapping<L, L>> datasetMappingProducer) {
+ IgniteFunction<IgniteModel<I, O>, DatasetMapping<L, L>> coercedMapping = mdl ->
+ datasetMappingProducer.apply((AdaptableDatasetModel<I, O, IW, OW, M>)mdl);
+ return new TrainersSequentialComposition<>(this,
+ tr,
+ coercedMapping);
+ }
+
+ /**
+ * Specify function which will be applied after feature extractor.
+ *
+ * @param after Function which will be applied after feature extractor.
+ * @return New trainer with same parameters as this trainer except that specified function will be applied
+ * after feature extractor.
+ */
+ public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterFeatureExtractor(IgniteFunction<Vector, Vector> after) {
+ return new AdaptableDatasetTrainer<>(before,
+ wrapped,
+ this.after,
+ after,
+ afterLabelExtractor,
+ upstreamTransformerBuilder);
+ }
+
+ /**
+ * Specify function which will be applied after label extractor.
+ *
+ * @param after Function which will be applied after label extractor.
+ * @return New trainer with same parameters as this trainer has except that specified function will be applied
+ * after label extractor.
+ */
+ public AdaptableDatasetTrainer<I, O, IW, OW, M, L> afterLabelExtractor(IgniteFunction<L, L> after) {
+ return new AdaptableDatasetTrainer<>(before,
+ wrapped,
+ this.after,
+ afterFeatureExtractor,
+ after,
+ upstreamTransformerBuilder);
+ }
+
+ /**
+ * Specify which {@link UpstreamTransformerBuilder} will be used.
+ *
+ * @param upstreamTransformerBuilder {@link UpstreamTransformerBuilder} to use.
+ * @return New trainer with same parameters as this trainer has except that specified {@link UpstreamTransformerBuilder} will be used.
+ */
+ public AdaptableDatasetTrainer<I, O, IW, OW, M, L> withUpstreamTransformerBuilder(
+ UpstreamTransformerBuilder upstreamTransformerBuilder) {
+ return new AdaptableDatasetTrainer<>(before,
+ wrapped,
+ after,
+ afterFeatureExtractor,
+ afterLabelExtractor,
+ upstreamTransformerBuilder);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
index 88c4bcd153d5..0b4258bf657c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/DatasetTrainer.java
@@ -76,7 +76,7 @@
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
if(mdl != null) {
- if (checkState(mdl))
+ if (isUpdateable(mdl))
return updateModel(mdl, datasetBuilder, featureExtractor, lbExtractor);
else {
environment.logger(getClass()).log(
@@ -94,7 +94,7 @@
* @param mdl Model.
* @return true if current critical for training parameters correspond to parameters from last training.
*/
- protected abstract boolean checkState(M mdl);
+ public abstract boolean isUpdateable(M mdl);
/**
* Used on update phase when given dataset is empty.
@@ -308,12 +308,12 @@
}
/**
- * Creates {@code DatasetTrainer} with same training logic, but able to accept labels of given new type
+ * Creates {@link DatasetTrainer} with same training logic, but able to accept labels of given new type
* of labels.
*
* @param new2Old Converter of new labels to old labels.
* @param <L1> New labels type.
- * @return {@code DatasetTrainer} with same training logic, but able to accept labels of given new type
+ * @return {@link DatasetTrainer} with same training logic, but able to accept labels of given new type
* of labels.
*/
public <L1> DatasetTrainer<M, L1> withConvertedLabels(IgniteFunction<L1, L> new2Old) {
@@ -326,8 +326,8 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(M mdl) {
- return old.checkState(mdl);
+ @Override public boolean isUpdateable(M mdl) {
+ return old.isUpdateable(mdl);
}
/** {@inheritDoc} */
@@ -362,4 +362,31 @@ public EmptyDatasetException() {
}
}
+ /**
+ * Returns the trainer which returns identity model.
+ *
+ * @param <I> Type of model input.
+ * @param <L> Type of labels in dataset.
+ * @return Trainer which returns identity model.
+ */
+ public static <I, L> DatasetTrainer<IgniteModel<I, I>, L> identityTrainer() {
+ return new DatasetTrainer<IgniteModel<I, I>, L>() {
+ @Override public <K, V> IgniteModel<I, I> fit(DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+ return x -> x;
+ }
+
+ /** {@inheritDoc} */
+ @Override public boolean isUpdateable(IgniteModel<I, I> mdl) {
+ return true;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected <K, V> IgniteModel<I, I> updateModel(IgniteModel<I, I> mdl, DatasetBuilder<K, V> datasetBuilder,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor) {
+ return x -> x;
+ }
+ };
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
index 43c160020a7e..db5522ec5eb2 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/TrainerTransformers.java
@@ -24,6 +24,7 @@
import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.environment.LearningEnvironment;
@@ -48,12 +49,11 @@
* @param ensembleSize Size of ensemble.
* @param subsampleRatio Subsample ratio to whole dataset.
* @param aggregator Aggregator.
- * @param <M> Type of one model in ensemble.
* @param <L> Type of labels.
* @return Bagged trainer.
*/
- public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
- DatasetTrainer<M, L> trainer,
+ public static <L> BaggedTrainer<L> makeBagged(
+ DatasetTrainer<? extends IgniteModel, L> trainer,
int ensembleSize,
double subsampleRatio,
PredictionsAggregator aggregator) {
@@ -71,58 +71,19 @@
* @param <L> Type of labels.
* @return Bagged trainer.
*/
- public static <M extends IgniteModel<Vector, Double>, L> DatasetTrainer<ModelsComposition, L> makeBagged(
+ public static <M extends IgniteModel<Vector, Double>, L> BaggedTrainer<L> makeBagged(
DatasetTrainer<M, L> trainer,
int ensembleSize,
double subsampleRatio,
int featureVectorSize,
int featuresSubspaceDim,
PredictionsAggregator aggregator) {
- return new DatasetTrainer<ModelsComposition, L>() {
- /** {@inheritDoc} */
- @Override public <K, V> ModelsComposition fit(
- DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, L> lbExtractor) {
- return runOnEnsemble(
- (db, i, fe) -> (() -> trainer.fit(db, fe, lbExtractor)),
- datasetBuilder,
- ensembleSize,
- subsampleRatio,
- featureVectorSize,
- featuresSubspaceDim,
- featureExtractor,
- aggregator,
- environment);
- }
-
- /** {@inheritDoc} */
- @Override protected boolean checkState(ModelsComposition mdl) {
- return mdl.getModels().stream().allMatch(m -> trainer.checkState((M)m));
- }
-
- /** {@inheritDoc} */
- @Override protected <K, V> ModelsComposition updateModel(
- ModelsComposition mdl,
- DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, L> lbExtractor) {
- return runOnEnsemble(
- (db, i, fe) -> (() -> trainer.updateModel(
- ((ModelWithMapping<Vector, Double, M>)mdl.getModels().get(i)).model(),
- db,
- fe,
- lbExtractor)),
- datasetBuilder,
- ensembleSize,
- subsampleRatio,
- featureVectorSize,
- featuresSubspaceDim,
- featureExtractor,
- aggregator,
- environment);
- }
- }.withEnvironmentBuilder(trainer.envBuilder);
+ return new BaggedTrainer<>(trainer,
+ aggregator,
+ ensembleSize,
+ subsampleRatio,
+ featureVectorSize,
+ featuresSubspaceDim);
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
index 7f45fdddd01a..36e78678ee6d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trainers/transformers/BaggingUpstreamTransformer.java
@@ -28,11 +28,8 @@
* This class encapsulates the logic needed to do bagging (bootstrap aggregating) by features.
* The action of this class on a given upstream is to replicate each entry in accordance with a
* Poisson distribution.
- *
- * @param <K> Type of upstream keys.
- * @param <V> Type of upstream values.
*/
-public class BaggingUpstreamTransformer<K, V> implements UpstreamTransformer<K, V> {
+public class BaggingUpstreamTransformer implements UpstreamTransformer {
/** Serial version uid. */
private static final long serialVersionUID = -913152523469994149L;
@@ -51,8 +48,8 @@
* @param <V> Type of upstream values.
* @return Builder of {@link BaggingUpstreamTransformer}.
*/
- public static <K, V> UpstreamTransformerBuilder<K, V> builder(double subsampleRatio, int mdlIdx) {
- return env -> new BaggingUpstreamTransformer<>(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio);
+ public static <K, V> UpstreamTransformerBuilder builder(double subsampleRatio, int mdlIdx) {
+ return env -> new BaggingUpstreamTransformer(env.randomNumbersGenerator().nextLong() + mdlIdx, subsampleRatio);
}
/**
@@ -67,7 +64,7 @@ public BaggingUpstreamTransformer(long seed, double subsampleRatio) {
}
/** {@inheritDoc} */
- @Override public Stream<UpstreamEntry<K, V>> transform(Stream<UpstreamEntry<K, V>> upstream) {
+ @Override public Stream<UpstreamEntry> transform(Stream<UpstreamEntry> upstream) {
PoissonDistribution poisson = new PoissonDistribution(
new Well19937c(seed),
subsampleRatio,
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
index 35d1ea4a2e4a..f3fc4ce69945 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
@@ -106,7 +106,7 @@
}
/** {@inheritDoc} */
- @Override protected boolean checkState(DecisionTreeNode mdl) {
+ @Override public boolean isUpdateable(DecisionTreeNode mdl) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
index d9b8e3075580..6d92948fc0f4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
@@ -239,7 +239,7 @@ protected boolean init(Dataset<EmptyContext, BootstrappedDatasetPartition> datas
}
/** {@inheritDoc} */
- @Override protected boolean checkState(ModelsComposition mdl) {
+ @Override public boolean isUpdateable(ModelsComposition mdl) {
ModelsComposition fakeComposition = buildComposition(Collections.emptyList());
return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index 63a9f3c3bd3a..b02e3bef8f8f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -23,9 +23,11 @@
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Iterator;
+import java.util.Objects;
import java.util.Random;
import java.util.Spliterator;
import java.util.Spliterators;
+import java.util.function.BiFunction;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.ignite.IgniteException;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java
index c2fd6522da80..e57c5baa9285 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DataStreamGenerator.java
@@ -126,7 +126,7 @@ public default DataStreamGenerator blur(RandomProducer rnd) {
* @return Dataset builder.
*/
public default DatasetBuilder<Vector, Double> asDatasetBuilder(int datasetSize, IgniteBiPredicate<Vector, Double> filter,
- int partitions, UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) {
+ int partitions, UpstreamTransformerBuilder upstreamTransformerBuilder) {
return new DatasetBuilderAdapter(this, datasetSize, filter, partitions, upstreamTransformerBuilder);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java
index 189e053e3559..7e5060e55a1d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/generators/DatasetBuilderAdapter.java
@@ -48,7 +48,7 @@ public DatasetBuilderAdapter(DataStreamGenerator generator, int datasetSize, int
*/
public DatasetBuilderAdapter(DataStreamGenerator generator, int datasetSize,
IgniteBiPredicate<Vector, Double> filter, int partitions,
- UpstreamTransformerBuilder<Vector, Double> upstreamTransformerBuilder) {
+ UpstreamTransformerBuilder upstreamTransformerBuilder) {
super(generator.asMap(datasetSize), filter, partitions, upstreamTransformerBuilder);
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
index fc3bf5c1d0c6..ed23373474f4 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/TestUtils.java
@@ -429,7 +429,7 @@ public T val() {
}
/** {@inheritDoc} */
- @Override public boolean checkState(M mdl) {
+ @Override public boolean isUpdateable(M mdl) {
return true;
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
index dd4b11eeb3ec..4f8f412ef31e 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/BaggingTest.java
@@ -22,6 +22,8 @@
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.bagging.BaggedModel;
+import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
@@ -77,18 +79,16 @@ public void testNaiveBaggingLogRegression() {
.withBatchSize(10)
.withSeed(123L);
- trainer.withEnvironmentBuilder(TestUtils.testEnvBuilder());
-
- DatasetTrainer<ModelsComposition, Double> baggedTrainer =
- TrainerTransformers.makeBagged(
- trainer,
- 10,
- 0.7,
- 2,
- 2,
- new OnMajorityPredictionsAggregator());
+ BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
+ trainer,
+ 10,
+ 0.7,
+ 2,
+ 2,
+ new OnMajorityPredictionsAggregator())
+ .withEnvironmentBuilder(TestUtils.testEnvBuilder());
- ModelsComposition mdl = baggedTrainer.fit(
+ BaggedModel mdl = baggedTrainer.fit(
cacheMock,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -111,14 +111,17 @@ protected void count(IgniteTriFunction<Long, CountData, LearningEnvironment, Lon
double subsampleRatio = 0.3;
- ModelsComposition mdl = TrainerTransformers.makeBagged(
+ BaggedModel mdl = TrainerTransformers.makeBagged(
cntTrainer,
100,
subsampleRatio,
2,
2,
new MeanValuePredictionsAggregator())
- .fit(cacheMock, parts, null, null);
+ .fit(cacheMock,
+ parts,
+ (integer, doubles) -> VectorUtils.of(doubles),
+ (integer, doubles) -> doubles[doubles.length - 1]);
Double res = mdl.predict(null);
@@ -177,7 +180,7 @@ public CountTrainer(IgniteTriFunction<Long, CountData, LearningEnvironment, Long
}
/** {@inheritDoc} */
- @Override protected boolean checkState(IgniteModel<Vector, Double> mdl) {
+ @Override public boolean isUpdateable(IgniteModel<Vector, Double> mdl) {
return true;
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
index d253ea05078e..874547ffe56f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/environment/LearningEnvironmentTest.java
@@ -103,7 +103,7 @@ public void testRandomNumbersGenerator() {
}
/** {@inheritDoc} */
- @Override protected boolean checkState(IgniteModel<Object, Vector> mdl) {
+ @Override public boolean isUpdateable(IgniteModel<Object, Vector> mdl) {
return false;
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
new file mode 100644
index 000000000000..6b24a170a23e
--- /dev/null
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trainers/StackingTest.java
@@ -0,0 +1,169 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.trainers;
+
+import java.util.Arrays;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.TestUtils;
+import org.apache.ignite.ml.common.TrainerTest;
+import org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer;
+import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.matrix.Matrix;
+import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.nn.Activators;
+import org.apache.ignite.ml.nn.MLPTrainer;
+import org.apache.ignite.ml.nn.MultilayerPerceptron;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
+import org.apache.ignite.ml.optimization.LossFunctions;
+import org.apache.ignite.ml.optimization.SmoothParametrized;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.composition.stacking.StackedModel;
+import org.junit.Rule;
+import org.junit.Test;
+import org.junit.rules.ExpectedException;
+
+import static junit.framework.TestCase.assertEquals;
+
+/**
+ * Tests stacked trainers.
+ */
+public class StackingTest extends TrainerTest {
+ /** Rule to check exceptions. */
+ @Rule
+ public ExpectedException thrown = ExpectedException.none();
+
+ /**
+ * Tests simple stack training.
+ */
+ @Test
+ public void testSimpleStack() {
+ StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer =
+ new StackedDatasetTrainer<>();
+
+ UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ );
+
+ MLPArchitecture arch = new MLPArchitecture(2).
+ withAddedLayer(10, true, Activators.RELU).
+ withAddedLayer(1, false, Activators.SIGMOID);
+
+ MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(
+ arch,
+ LossFunctions.MSE,
+ updatesStgy,
+ 3000,
+ 10,
+ 50,
+ 123L
+ );
+
+ // Convert model trainer to produce Vector -> Vector model
+ DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
+ AdaptableDatasetTrainer.of(trainer1)
+ .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
+ .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
+ .withConvertedLabels(VectorUtils::num2Arr);
+
+ final double factor = 3;
+
+ StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+ .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+ .addTrainer(mlpTrainer)
+ .withAggregatorInputMerger(VectorUtils::concat)
+ .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
+ .withVector2SubmodelInputConverter(IgniteFunction.identity())
+ .withOriginalFeaturesKept(IgniteFunction.identity())
+ .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+ .fit(getCacheMock(xor),
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[v.length - 1]);
+
+ assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
+ assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
+ assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
+ assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
+ }
+
+ /**
+ * Tests simple stack training.
+ */
+ @Test
+ public void testSimpleVectorStack() {
+ StackedVectorDatasetTrainer<Double, LinearRegressionModel, Double> trainer =
+ new StackedVectorDatasetTrainer<>();
+
+ UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate::sumLocal,
+ SimpleGDParameterUpdate::avg
+ );
+
+ MLPArchitecture arch = new MLPArchitecture(2).
+ withAddedLayer(10, true, Activators.RELU).
+ withAddedLayer(1, false, Activators.SIGMOID);
+
+ DatasetTrainer<MultilayerPerceptron, Double> mlpTrainer = new MLPTrainer<>(
+ arch,
+ LossFunctions.MSE,
+ updatesStgy,
+ 3000,
+ 10,
+ 50,
+ 123L
+ ).withConvertedLabels(VectorUtils::num2Arr);
+
+ final double factor = 3;
+
+ StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
+ .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
+ .addMatrix2MatrixTrainer(mlpTrainer)
+ .withEnvironmentBuilder(TestUtils.testEnvBuilder())
+ .fit(getCacheMock(xor),
+ parts,
+ (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
+ (k, v) -> v[v.length - 1]);
+
+ assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
+ assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
+ assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
+ assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
+ }
+
+ /**
+ * Tests that if there is no way for the input of the first layer to propagate to the second layer,
+ * an exception will be thrown.
+ */
+ @Test
+ public void testINoWaysOfPropagation() {
+ StackedDatasetTrainer<Void, Void, Void, IgniteModel<Void, Void>, Void> trainer =
+ new StackedDatasetTrainer<>();
+ thrown.expect(IllegalStateException.class);
+ trainer.fit(null, null, null);
+ }
+}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java
index f2899c280e04..d711fc4ba44f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/util/generators/DataStreamGeneratorTest.java
@@ -147,8 +147,8 @@ public void testAsDatasetBuilder() throws Exception {
DatasetBuilder<Vector, Double> b2 = generator.asDatasetBuilder(N, (v, l) -> l == 0, 2);
counter.set(0);
DatasetBuilder<Vector, Double> b3 = generator.asDatasetBuilder(N, (v, l) -> l == 1, 2,
- new UpstreamTransformerBuilder<Vector, Double>() {
- @Override public UpstreamTransformer<Vector, Double> build(LearningEnvironment env) {
+ new UpstreamTransformerBuilder() {
+ @Override public UpstreamTransformer build(LearningEnvironment env) {
return new UpstreamTransformerForTest();
}
});
@@ -201,10 +201,10 @@ private void checkDataset(int sampleSize, DatasetBuilder<Vector, Double> dataset
}
/** */
- private static class UpstreamTransformerForTest implements UpstreamTransformer<Vector, Double> {
- @Override public Stream<UpstreamEntry<Vector, Double>> transform(
- Stream<UpstreamEntry<Vector, Double>> upstream) {
- return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -entry.getValue()));
+ private static class UpstreamTransformerForTest implements UpstreamTransformer {
+ @Override public Stream<UpstreamEntry> transform(
+ Stream<UpstreamEntry> upstream) {
+ return upstream.map(entry -> new UpstreamEntry<>(entry.getKey(), -((double)entry.getValue())));
}
}
}
With regards,
Apache Git Services