You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by za...@apache.org on 2019/03/29 12:33:04 UTC
[ignite] branch master updated: IGNITE-9497: [ML] Add Pipeline
support to Cross-Validation process
This is an automated email from the ASF dual-hosted git repository.
zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 7ddf776 IGNITE-9497: [ML] Add Pipeline support to Cross-Validation process
7ddf776 is described below
commit 7ddf776bf4f6320284ed7f2c883c37751e5d302d
Author: Zinoviev Alexey <za...@gmail.com>
AuthorDate: Fri Mar 29 14:26:03 2019 +0300
IGNITE-9497: [ML] Add Pipeline support to Cross-Validation process
This closes #6226
---
.../ml/tutorial/Step_5_Scaling_with_Pipeline.java | 11 +-
.../Step_8_CV_with_Param_Grid_and_metrics.java | 7 +-
..._with_Param_Grid_and_metrics_and_pipeline.java} | 111 ++++---------
.../org/apache/ignite/ml/pipeline/Pipeline.java | 36 +++--
.../ignite/ml/selection/cv/CrossValidation.java | 178 +++++++++++++++++++--
.../apache/ignite/ml/pipeline/PipelineTest.java | 15 +-
6 files changed, 234 insertions(+), 124 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
index e65a9a6..276418e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling_with_Pipeline.java
@@ -17,6 +17,7 @@
package org.apache.ignite.examples.ml.tutorial;
+import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,8 +33,6 @@ import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import java.io.FileNotFoundException;
-
/**
* {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
* distribution in columns and rows.
@@ -67,13 +66,13 @@ public class Step_5_Scaling_with_Pipeline {
PipelineMdl<Integer, Object[]> mdl = new Pipeline<Integer, Object[], Object[]>()
.addFeatureExtractor(featureExtractor)
.addLabelExtractor(lbExtractor)
- .addPreprocessor(new EncoderTrainer<Integer, Object[]>()
+ .addPreprocessingTrainer(new EncoderTrainer<Integer, Object[]>()
.withEncoderType(EncoderType.STRING_ENCODER)
.withEncodedFeature(1)
.withEncodedFeature(6))
- .addPreprocessor(new ImputerTrainer<Integer, Object[]>())
- .addPreprocessor(new MinMaxScalerTrainer<Integer, Object[]>())
- .addPreprocessor(new NormalizationTrainer<Integer, Object[]>()
+ .addPreprocessingTrainer(new ImputerTrainer<Integer, Object[]>())
+ .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Object[]>())
+ .addPreprocessingTrainer(new NormalizationTrainer<Integer, Object[]>()
.withP(1))
.addTrainer(new DecisionTreeClassificationTrainer(5, 0))
.fit(ignite, dataCache);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
index d8bb5ef..181a59e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
@@ -17,6 +17,8 @@
package org.apache.ignite.examples.ml.tutorial;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -40,9 +42,6 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
-
/**
* To choose the best hyperparameters the cross-validation with {@link ParamGrid} will be used in this example.
* <p>
@@ -125,7 +124,7 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
- .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0})
+ .addHyperParam("maxDeep", new Double[]{1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
.addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5});
BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics_and_pipeline.java
similarity index 62%
copy from examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics_and_pipeline.java
index d8bb5ef..b4e998e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics_and_pipeline.java
@@ -17,22 +17,21 @@
package org.apache.ignite.examples.ml.tutorial;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
-import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.pipeline.Pipeline;
import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerTrainer;
-import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.cv.CrossValidationResult;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
-import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
@@ -40,9 +39,6 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
-
/**
* To choose the best hyperparameters the cross-validation with {@link ParamGrid} will be used in this example.
* <p>
@@ -67,7 +63,7 @@ import java.util.Arrays;
* <p>
* All scenarios are described there: https://sebastianraschka.com/faq/docs/evaluate-a-model.html</p>
*/
-public class Step_8_CV_with_Param_Grid_and_metrics {
+public class Step_8_CV_with_Param_Grid_and_metrics_and_pipeline {
/** Run example. */
public static void main(String[] args) {
System.out.println();
@@ -79,54 +75,44 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
// Defines first preprocessor that extracts features from an upstream data.
// Extracts "pclass", "sibsp", "parch", "sex", "embarked", "age", "fare" .
- IgniteBiFunction<Integer, Object[], Object[]> featureExtractor
- = (k, v) -> new Object[] {v[0], v[3], v[4], v[5], v[6], v[8], v[10]};
-
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[]{
+ (double) v[0],
+ (double) v[4],
+ (double) v[5],
+ (double) v[6],
+ (double) v[8],
+ };
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ data[4] = Double.isNaN(data[4]) ? 0 : data[4];
+
+ return VectorUtils.of(data);
+ };
+
+
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
TrainTestSplit<Integer, Object[]> split = new TrainTestDatasetSplitter<Integer, Object[]>()
.split(0.75);
- IgniteBiFunction<Integer, Object[], Vector> strEncoderPreprocessor = new EncoderTrainer<Integer, Object[]>()
- .withEncoderType(EncoderType.STRING_ENCODER)
- .withEncodedFeature(1)
- .withEncodedFeature(6) // <--- Changed index here.
- .fit(ignite,
- dataCache,
- featureExtractor
- );
-
- IgniteBiFunction<Integer, Object[], Vector> imputingPreprocessor = new ImputerTrainer<Integer, Object[]>()
- .fit(ignite,
- dataCache,
- strEncoderPreprocessor
- );
-
- IgniteBiFunction<Integer, Object[], Vector> minMaxScalerPreprocessor = new MinMaxScalerTrainer<Integer, Object[]>()
- .fit(
- ignite,
- dataCache,
- imputingPreprocessor
- );
-
- IgniteBiFunction<Integer, Object[], Vector> normalizationPreprocessor = new NormalizationTrainer<Integer, Object[]>()
- .withP(2)
- .fit(
- ignite,
- dataCache,
- minMaxScalerPreprocessor
- );
+ Pipeline<Integer, Object[], Vector> pipeline = new Pipeline<Integer, Object[], Vector>()
+ .addFeatureExtractor(featureExtractor)
+ .addLabelExtractor(lbExtractor)
+ .addPreprocessingTrainer(new ImputerTrainer<Integer, Object[]>())
+ .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Object[]>())
+ .addTrainer(new DecisionTreeClassificationTrainer(5, 0));
// Tune hyperparams with K-fold Cross-Validation on the split training set.
- DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
-
CrossValidation<DecisionTreeNode, Double, Integer, Object[]> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
- .addHyperParam("maxDeep", new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0, 10.0})
- .addHyperParam("minImpurityDecrease", new Double[] {0.0, 0.25, 0.5});
+ .addHyperParam("maxDeep", new Double[]{1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
+ .addHyperParam("minImpurityDecrease", new Double[]{0.0, 0.25, 0.5});
BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
.withNegativeClsLb(0.0)
@@ -134,12 +120,11 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
.withMetric(BinaryClassificationMetricValues::accuracy);
CrossValidationResult crossValidationRes = scoreCalculator.score(
- trainerCV,
+ pipeline,
metrics,
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
lbExtractor,
3,
paramGrid
@@ -148,10 +133,6 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
System.out.println("Train with maxDeep: " + crossValidationRes.getBest("maxDeep")
+ " and minImpurityDecrease: " + crossValidationRes.getBest("minImpurityDecrease"));
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer()
- .withMaxDeep(crossValidationRes.getBest("maxDeep"))
- .withMinImpurityDecrease(crossValidationRes.getBest("minImpurityDecrease"));
-
System.out.println(crossValidationRes);
System.out.println("Best score: " + Arrays.toString(crossValidationRes.getBestScore()));
@@ -161,31 +142,7 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
crossValidationRes.getScoringBoard().forEach((hyperParams, score)
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
- // Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
- ignite,
- dataCache,
- split.getTrainFilter(),
- FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
- );
-
- System.out.println("\n>>> Trained model: " + bestMdl);
-
- double accuracy = Evaluator.evaluate(
- dataCache,
- split.getTestFilter(),
- bestMdl,
- normalizationPreprocessor,
- lbExtractor,
- new Accuracy<>()
- );
-
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
-
- System.out.println(">>> Tutorial step 8 (cross-validation with param grid) example started.");
- }
- catch (FileNotFoundException e) {
+ } catch (FileNotFoundException e) {
e.printStackTrace();
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java b/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java
index 947bd6a..9897788 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/pipeline/Pipeline.java
@@ -17,6 +17,9 @@
package org.apache.ignite.ml.pipeline;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Map;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.ml.IgniteModel;
@@ -29,10 +32,6 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.PreprocessingTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Map;
-
/**
* A simple pipeline, which acts as a global trainer which produce a Pipeline Model.
* A Pipeline consists of a sequence of stages, each of which is either a Preprocessing Stage or a Trainer.
@@ -44,13 +43,16 @@ import java.util.Map;
*/
public class Pipeline<K, V, R> {
/** Feature extractor. */
+ private IgniteBiFunction<K, V, R> featureExtractor;
+
+ /** Final Feature extractor. */
private IgniteBiFunction<K, V, R> finalFeatureExtractor;
/** Label extractor. */
private IgniteBiFunction<K, V, Double> lbExtractor;
/** Preprocessor stages. */
- private List<PreprocessingTrainer> preprocessors = new ArrayList<>();
+ private List<PreprocessingTrainer> preprocessingTrainers = new ArrayList<>();
/** Final trainer stage. */
private DatasetTrainer finalStage;
@@ -65,7 +67,7 @@ public class Pipeline<K, V, R> {
* @return The updated Pipeline.
*/
public Pipeline<K, V, R> addFeatureExtractor(IgniteBiFunction<K, V, R> featureExtractor) {
- this.finalFeatureExtractor = featureExtractor;
+ this.featureExtractor = featureExtractor;
return this;
}
@@ -83,11 +85,11 @@ public class Pipeline<K, V, R> {
/**
* Adds a preprocessor.
*
- * @param preprocessor The parameter value.
+ * @param preprocessingTrainer Preprocessing trainer to append to the list of preprocessing stages.
* @return The updated Pipeline.
*/
- public Pipeline<K, V, R> addPreprocessor(PreprocessingTrainer preprocessor) {
- preprocessors.add(preprocessor);
+ public Pipeline<K, V, R> addPreprocessingTrainer(PreprocessingTrainer preprocessingTrainer) {
+ preprocessingTrainers.add(preprocessingTrainer);
return this;
}
@@ -103,6 +105,13 @@ public class Pipeline<K, V, R> {
}
/**
+ * Returns the final trainer stage held by this pipeline; may be {@code null} when no trainer stage has been set.
+ */
+ public DatasetTrainer getTrainer() {
+ return finalStage;
+ }
+
+ /**
* Fits the pipeline to the input cache.
*
* @param ignite Ignite instance.
@@ -136,14 +145,17 @@ public class Pipeline<K, V, R> {
}
/** Fits the pipeline to the input dataset builder. */
- private PipelineMdl<K, V> fit(DatasetBuilder datasetBuilder) {
+ public PipelineMdl<K, V> fit(DatasetBuilder datasetBuilder) {
assert lbExtractor != null;
- assert finalFeatureExtractor != null;
+ assert featureExtractor != null;
if (finalStage == null)
throw new IllegalStateException("The Pipeline should be finished with the Training Stage.");
- preprocessors.forEach(e -> {
+ // Reset to the raw feature extractor so that each fit rebuilds the preprocessing chain from scratch.
+ finalFeatureExtractor = featureExtractor;
+
+ preprocessingTrainers.forEach(e -> {
finalFeatureExtractor = e.fit(
envBuilder,
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
index 59aeddc..1f64cce 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/selection/cv/CrossValidation.java
@@ -17,6 +17,14 @@
package org.apache.ignite.ml.selection.cv;
+import java.lang.reflect.InvocationTargetException;
+import java.lang.reflect.Method;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.function.BiFunction;
+import java.util.function.Function;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.lang.IgniteBiPredicate;
@@ -27,6 +35,8 @@ import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.pipeline.Pipeline;
+import org.apache.ignite.ml.pipeline.PipelineMdl;
import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.ParameterSetGenerator;
import org.apache.ignite.ml.selection.scoring.cursor.CacheBasedLabelPairCursor;
@@ -37,15 +47,6 @@ import org.apache.ignite.ml.selection.split.mapper.SHA256UniformMapper;
import org.apache.ignite.ml.selection.split.mapper.UniformMapper;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import java.lang.reflect.InvocationTargetException;
-import java.lang.reflect.Method;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.function.BiFunction;
-import java.util.function.Function;
-
/**
* Cross validation score calculator. Cross validation is an approach that allows to avoid overfitting that is made the
* following way: the training set is split into k smaller sets. The following procedure is followed for each of the k
@@ -113,14 +114,14 @@ public class CrossValidation<M extends IgniteModel<Vector, L>, L, K, V> {
* @param filter Base {@code upstream} data filter.
* @param featureExtractor Feature extractor.
* @param lbExtractor Label extractor.
- * @param cv Number of folds.
+ * @param amountOfFolds Amount of folds.
* @param paramGrid Parameter grid.
* @return Array of scores of the estimator for each run of the cross validation.
*/
public CrossValidationResult score(DatasetTrainer<M, L> trainer, Metric<L> scoreCalculator, Ignite ignite,
- IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int cv,
- ParamGrid paramGrid) {
+ IgniteCache<K, V> upstreamCache, IgniteBiPredicate<K, V> filter,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor, int amountOfFolds,
+ ParamGrid paramGrid) {
List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
@@ -159,7 +160,7 @@ public class CrossValidation<M extends IgniteModel<Vector, L>, L, K, V> {
}
double[] locScores = score(trainer, scoreCalculator, ignite, upstreamCache, filter, featureExtractor, lbExtractor,
- new SHA256UniformMapper<>(), cv);
+ new SHA256UniformMapper<>(), amountOfFolds);
cvRes.addScores(locScores, paramMap);
@@ -308,9 +309,9 @@ public class CrossValidation<M extends IgniteModel<Vector, L>, L, K, V> {
*/
private double[] score(DatasetTrainer<M, L> trainer, Function<IgniteBiPredicate<K, V>,
DatasetBuilder<K, V>> datasetBuilderSupplier,
- BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
- Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
+ BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
+ IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, L> lbExtractor,
+ Metric<L> scoreCalculator, UniformMapper<K, V> mapper, int cv) {
double[] scores = new double[cv];
@@ -336,4 +337,147 @@ public class CrossValidation<M extends IgniteModel<Vector, L>, L, K, V> {
return scores;
}
+
+
+ /**
+ * Computes cross-validated metrics for a {@link Pipeline} over every parameter set generated from the parameter grid.
+ *
+ * Each parameter is applied to the pipeline's trainer through a reflective {@code withXxx} call, then the pipeline is scored fold by fold.
+ *
+ * @param pipeline Pipeline of stages; its trainer (see {@code pipeline.getTrainer()}) receives the hyperparameters.
+ * @param scoreCalculator Base score calculator.
+ * @param ignite Ignite instance.
+ * @param upstreamCache Ignite cache with {@code upstream} data.
+ * @param filter Base {@code upstream} data filter.
+ * @param lbExtractor Label extractor.
+ * @param amountOfFolds Amount of folds.
+ * @param paramGrid Parameter grid.
+ * @return Cross-validation result holding the scores recorded for each parameter set and the best-scoring set found.
+ */
+ public CrossValidationResult score(Pipeline<K, V, Vector> pipeline,
+ Metric<L> scoreCalculator,
+ Ignite ignite,
+ IgniteCache<K, V> upstreamCache,
+ IgniteBiPredicate<K, V> filter,
+ IgniteBiFunction<K, V, L> lbExtractor,
+ int amountOfFolds,
+ ParamGrid paramGrid) {
+
+ List<Double[]> paramSets = new ParameterSetGenerator(paramGrid.getParamValuesByParamIdx()).generate();
+
+ CrossValidationResult cvRes = new CrossValidationResult();
+
+ DatasetTrainer trainer = pipeline.getTrainer(); // Trainer instance held by the pipeline; hyperparameters are set on it in place.
+
+ paramSets.forEach(paramSet -> {
+ Map<String, Double> paramMap = new HashMap<>();
+
+
+ for (int paramIdx = 0; paramIdx < paramSet.length; paramIdx++) {
+ String paramName = paramGrid.getParamNameByIndex(paramIdx);
+ Double paramVal = paramSet[paramIdx];
+
+ paramMap.put(paramName, paramVal);
+
+ try {
+ final String mtdName = "with" +
+ paramName.substring(0, 1).toUpperCase() +
+ paramName.substring(1);
+
+ Method trainerSetter = null;
+
+ // Match declared methods by name only, because the setter's parameter types are unknown here; the last match wins.
+ for (Method method : trainer.getClass().getDeclaredMethods()) {
+ if (method.getName().equals(mtdName))
+ trainerSetter = method;
+ }
+
+ if (trainerSetter != null)
+ trainerSetter.invoke(trainer, paramVal);
+ else
+ throw new NoSuchMethodException(mtdName);
+
+ } catch (NoSuchMethodException | IllegalAccessException | InvocationTargetException e) {
+ e.printStackTrace(); // NOTE(review): the failure is only printed; scoring still proceeds for this parameter set.
+ }
+ }
+
+ double[] locScores = scorePipeline(
+ pipeline,
+ predicate -> new CacheBasedDatasetBuilder<>(
+ ignite,
+ upstreamCache,
+ (k, v) -> filter.apply(k, v) && predicate.apply(k, v) // Training rows: pass the base filter and the fold's predicate.
+ ),
+ (predicate, mdl) -> new CacheBasedLabelPairCursor<>(
+ upstreamCache,
+ (k, v) -> filter.apply(k, v) && !predicate.apply(k, v), // Test rows: pass the base filter but not the fold's predicate.
+ ((PipelineMdl<K, V>) mdl).getFeatureExtractor(),
+ lbExtractor,
+ mdl
+ ),
+ scoreCalculator,
+ new SHA256UniformMapper<>(),
+ amountOfFolds
+ );
+
+
+ cvRes.addScores(locScores, paramMap);
+
+ final double locAvgScore = Arrays.stream(locScores).average().orElse(Double.MIN_VALUE); // NOTE(review): Double.MIN_VALUE is the smallest positive double, not the most negative value.
+
+ if (locAvgScore > cvRes.getBestAvgScore()) {
+ cvRes.setBestScore(locScores);
+ cvRes.setBestHyperParams(paramMap);
+ System.out.println(paramMap.toString()); // Prints every parameter set that becomes the new best to stdout.
+ }
+ });
+
+ return cvRes;
+
+ }
+
+ /**
+ * Fits the pipeline once per fold and scores each fitted model with the cursor built by the test data supplier.
+ *
+ * @param pipeline Pipeline of stages.
+ * @param datasetBuilderSupplier Builds the training dataset builder from a fold's train-set predicate.
+ * @param testDataIterSupplier Builds the test cursor from the same train-set predicate and the fitted model.
+ * @param scoreCalculator Base score calculator.
+ * @param mapper Mapper used to map a key-value pair to a point on the segment (0, 1).
+ * @param cv Number of folds.
+ * @return Array of {@code cv} scores, one per fold.
+ */
+ private double[] scorePipeline(Pipeline<K, V, Vector> pipeline,
+ Function<IgniteBiPredicate<K, V>, DatasetBuilder<K, V>> datasetBuilderSupplier,
+ BiFunction<IgniteBiPredicate<K, V>, M, LabelPairCursor<L>> testDataIterSupplier,
+ Metric<L> scoreCalculator,
+ UniformMapper<K, V> mapper,
+ int cv
+ ) {
+
+ double[] scores = new double[cv];
+
+ double foldSize = 1.0 / cv; // Each fold covers an equal-width slice of the mapper's output range.
+ for (int i = 0; i < cv; i++) {
+ double from = foldSize * i;
+ double to = foldSize * (i + 1);
+
+ IgniteBiPredicate<K, V> trainSetFilter = (k, v) -> {
+ double pnt = mapper.map(k, v);
+ return pnt < from || pnt > to; // Train on points outside the current fold's [from, to] slice.
+ };
+
+ DatasetBuilder<K, V> datasetBuilder = datasetBuilderSupplier.apply(trainSetFilter);
+ PipelineMdl<K, V> mdl = pipeline.fit(datasetBuilder);
+
+ try (LabelPairCursor<L> cursor = testDataIterSupplier.apply(trainSetFilter, (M) mdl)) { // NOTE(review): unchecked cast of PipelineMdl to M.
+ scores[i] = scoreCalculator.score(cursor.iterator());
+ } catch (Exception e) {
+ throw new RuntimeException(e); // Rethrown unchecked with the original cause attached.
+ }
+ }
+
+ return scores;
+ }
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
index b4c1712..341e574 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/pipeline/PipelineTest.java
@@ -17,6 +17,9 @@
package org.apache.ignite.ml.pipeline;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.Map;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -29,10 +32,6 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
import org.junit.Test;
-import java.util.Arrays;
-import java.util.HashMap;
-import java.util.Map;
-
/**
* Tests for {@link Pipeline}.
*/
@@ -63,8 +62,8 @@ public class PipelineTest extends TrainerTest {
PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
.addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
.addLabelExtractor((k, v) -> v[0])
- .addPreprocessor(new MinMaxScalerTrainer<Integer, Object[]>())
- .addPreprocessor(new NormalizationTrainer<Integer, Object[]>()
+ .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Object[]>())
+ .addPreprocessingTrainer(new NormalizationTrainer<Integer, Object[]>()
.withP(1))
.addTrainer(trainer)
.fit(
@@ -94,8 +93,8 @@ public class PipelineTest extends TrainerTest {
PipelineMdl<Integer, Double[]> mdl = new Pipeline<Integer, Double[], Vector>()
.addFeatureExtractor((k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)))
.addLabelExtractor((k, v) -> v[0])
- .addPreprocessor(new MinMaxScalerTrainer<Integer, Object[]>())
- .addPreprocessor(new NormalizationTrainer<Integer, Object[]>()
+ .addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Object[]>())
+ .addPreprocessingTrainer(new NormalizationTrainer<Integer, Object[]>()
.withP(1))
.fit(
cacheMock,