You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/03/27 12:24:40 UTC
[ignite] branch master updated: IGNITE-11449: [ML] Umbrella: API
for Feature/Label extracting (part 1)
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new a0a15d6 IGNITE-11449: [ML] Umbrella: API for Feature/Label extracting (part 1)
a0a15d6 is described below
commit a0a15d62a250defb0db9ec72153ee287830f6a15
Author: Alexey Platonov <ap...@gmail.com>
AuthorDate: Wed Mar 27 15:24:03 2019 +0300
IGNITE-11449: [ML] Umbrella: API for Feature/Label extracting (part 1)
This closes #6232
---
.../ml/TrainingWithBinaryObjectExample.java | 32 +--
.../ml/clustering/GmmClusterizationExample.java | 92 +++---
.../ml/clustering/KMeansClusterizationExample.java | 62 ++--
.../dataset/AlgorithmSpecificDatasetExample.java | 101 ++++---
.../ml/dataset/CacheBasedDatasetExample.java | 31 +-
.../IgniteModelDistributedInferenceExample.java | 76 ++---
.../spark/LogRegFromSparkThroughPMMLExample.java | 33 ++-
.../modelparser/DecisionTreeFromSparkExample.java | 55 ++--
.../DecisionTreeRegressionFromSparkExample.java | 62 ++--
.../spark/modelparser/GBTFromSparkExample.java | 51 ++--
.../modelparser/GBTRegressionFromSparkExample.java | 62 ++--
.../spark/modelparser/KMeansFromSparkExample.java | 62 ++--
.../LinearRegressionFromSparkExample.java | 62 ++--
.../spark/modelparser/LogRegFromSparkExample.java | 49 ++--
.../modelparser/RandomForestFromSparkExample.java | 51 ++--
.../RandomForestRegressionFromSparkExample.java | 62 ++--
.../spark/modelparser/SVMFromSparkExample.java | 51 ++--
.../examples/ml/knn/ANNClassificationExample.java | 100 ++++---
.../examples/ml/knn/KNNClassificationExample.java | 43 +--
.../examples/ml/knn/KNNRegressionExample.java | 46 +--
.../multiclass/OneVsRestClassificationExample.java | 187 ++++++------
.../DiscreteNaiveBayesTrainerExample.java | 52 ++--
.../GaussianNaiveBayesTrainerExample.java | 49 ++--
.../ignite/examples/ml/nn/MLPTrainerExample.java | 122 ++++----
.../ml/preprocessing/BinarizationExample.java | 34 ++-
.../examples/ml/preprocessing/ImputingExample.java | 38 +--
.../ImputingWithMostFrequentValuesExample.java | 36 ++-
.../ml/preprocessing/MaxAbsScalerExample.java | 34 ++-
.../ml/preprocessing/MinMaxScalerExample.java | 34 ++-
.../ml/preprocessing/NormalizationExample.java | 36 ++-
.../ml/preprocessing/StandardScalerExample.java | 34 ++-
.../linear/LinearRegressionLSQRTrainerExample.java | 62 ++--
...gressionLSQRTrainerWithMinMaxScalerExample.java | 57 ++--
.../linear/LinearRegressionSGDTrainerExample.java | 60 ++--
.../BaggedLogisticRegressionSGDTrainerExample.java | 80 +++---
.../LogisticRegressionSGDTrainerExample.java | 66 +++--
.../ml/selection/cv/CrossValidationExample.java | 151 +++++-----
.../ml/selection/scoring/EvaluatorExample.java | 56 ++--
.../selection/scoring/MultipleMetricsExample.java | 44 +--
.../selection/scoring/RegressionMetricExample.java | 55 ++--
.../split/TrainTestDatasetSplitterExample.java | 69 ++---
...eeClassificationTrainerSQLInferenceExample.java | 155 +++++-----
...onTreeClassificationTrainerSQLTableExample.java | 163 ++++++-----
.../ml/svm/SVMBinaryClassificationExample.java | 44 +--
.../DecisionTreeClassificationTrainerExample.java | 110 +++-----
.../tree/DecisionTreeRegressionTrainerExample.java | 74 ++---
.../GDBOnTreesClassificationTrainerExample.java | 54 ++--
.../GDBOnTreesRegressionTrainerExample.java | 50 ++--
.../RandomForestClassificationExample.java | 85 +++---
.../RandomForestRegressionExample.java | 99 ++++---
.../ml/tutorial/Step_10_Scaling_With_Stacking.java | 6 +-
.../ml/tutorial/Step_1_Read_and_Learn.java | 11 +-
.../examples/ml/tutorial/Step_2_Imputing.java | 4 +-
.../examples/ml/tutorial/Step_3_Categorial.java | 4 +-
.../Step_3_Categorial_with_One_Hot_Encoder.java | 4 +-
.../examples/ml/tutorial/Step_4_Add_age_fare.java | 4 +-
.../examples/ml/tutorial/Step_5_Scaling.java | 4 +-
.../ignite/examples/ml/tutorial/Step_6_KNN.java | 4 +-
.../ml/tutorial/Step_7_Split_train_test.java | 4 +-
.../ignite/examples/ml/tutorial/Step_8_CV.java | 4 +-
.../ml/tutorial/Step_8_CV_with_Param_Grid.java | 4 +-
.../Step_8_CV_with_Param_Grid_and_metrics.java | 4 +-
.../examples/ml/tutorial/Step_9_Go_to_LogReg.java | 8 +-
.../ignite/examples/ml/tutorial/TitanicUtils.java | 13 +-
.../util/generators/DatasetCreationExamples.java | 34 ++-
.../gmm/CovarianceMatricesAggregator.java | 19 +-
.../ignite/ml/clustering/gmm/GmmPartitionData.java | 58 ++--
.../ignite/ml/clustering/gmm/GmmTrainer.java | 272 +++++++++---------
.../gmm/MeanWithClusterProbAggregator.java | 17 +-
.../gmm/NewComponentStatisticsAggregator.java | 11 +-
.../ignite/ml/clustering/kmeans/KMeansTrainer.java | 35 +--
.../ignite/ml/composition/CompositionUtils.java | 25 +-
.../ml/composition/bagging/BaggedTrainer.java | 26 +-
.../boosting/GDBBinaryClassifierTrainer.java | 20 +-
.../composition/boosting/GDBLearningStrategy.java | 53 ++--
.../composition/boosting/GDBRegressionTrainer.java | 13 +-
.../ignite/ml/composition/boosting/GDBTrainer.java | 70 ++---
.../boosting/convergence/ConvergenceChecker.java | 40 ++-
.../convergence/ConvergenceCheckerFactory.java | 15 +-
.../mean/MeanAbsValueConvergenceChecker.java | 30 +-
.../MeanAbsValueConvergenceCheckerFactory.java | 11 +-
.../median/MedianOfMedianConvergenceChecker.java | 32 +--
.../MedianOfMedianConvergenceCheckerFactory.java | 11 +-
.../convergence/simple/ConvergenceCheckerStub.java | 30 +-
.../simple/ConvergenceCheckerStubFactory.java | 13 +-
.../ignite/ml/composition/boosting/loss/Loss.java | 4 +-
.../parallel/TrainersParallelComposition.java | 59 ++--
.../sequential/TrainersSequentialComposition.java | 180 ++++++------
.../stacking/StackedDatasetTrainer.java | 24 +-
.../apache/ignite/ml/dataset/DatasetFactory.java | 201 +++++++------
.../ignite/ml/dataset/feature/BucketMeta.java | 4 +-
.../ignite/ml/dataset/feature/Histogram.java | 8 +-
.../ignite/ml/dataset/feature/ObjectHistogram.java | 10 +-
.../dataset/feature/extractor/ExtractionUtils.java | 133 +++++++++
.../ml/dataset/feature/extractor/Vectorizer.java | 307 ++++++++++++++++++++
.../feature/extractor/impl/ArraysVectorizer.java} | 42 +--
.../extractor/impl/BinaryObjectVectorizer.java | 142 ++++++++++
.../feature/extractor/impl/DummyVectorizer.java} | 35 ++-
.../impl/FeatureLabelExtractorWrapper.java | 84 ++++++
.../extractor/impl/LabeledDummyVectorizer.java | 65 +++++
.../feature/extractor/impl/package-info.java} | 22 +-
.../feature/extractor/package-info.java} | 23 +-
.../bootstrapping/BootstrappedDatasetBuilder.java | 12 +-
.../BootstrappedDatasetPartition.java | 5 +-
.../impl/bootstrapping/BootstrappedVector.java | 7 +-
.../FeatureMatrixWithLabelsOnHeapDataBuilder.java | 34 +--
.../builder/data/SimpleDatasetDataBuilder.java | 16 +-
.../data/SimpleLabeledDatasetDataBuilder.java | 31 +-
.../ml/environment/logging/ConsoleLogger.java | 12 +-
.../parallelism/ParallelismStrategy.java | 10 +-
.../java/org/apache/ignite/ml/genetic/Gene.java | 5 +-
.../java/org/apache/ignite/ml/knn/KNNUtils.java | 17 +-
.../ml/knn/ann/ANNClassificationTrainer.java | 63 ++---
.../classification/KNNClassificationTrainer.java | 18 +-
.../ml/knn/regression/KNNRegressionTrainer.java | 17 +-
.../ignite/ml/math/primitives/vector/Vector.java | 19 +-
.../apache/ignite/ml/math/stat/Distribution.java | 7 +-
.../ignite/ml/math/stat/DistributionMixture.java | 13 +-
.../stat/MultivariateGaussianDistribution.java | 4 +-
.../ignite/ml/multiclass/MultiClassModel.java | 11 +-
.../ignite/ml/multiclass/OneVsRestTrainer.java | 42 +--
.../discrete/DiscreteNaiveBayesModel.java | 5 +-
.../discrete/DiscreteNaiveBayesTrainer.java | 20 +-
.../gaussian/GaussianNaiveBayesTrainer.java | 25 +-
.../java/org/apache/ignite/ml/nn/MLPTrainer.java | 14 +-
.../ignite/ml/nn/ReplicatedVectorMatrix.java | 22 +-
.../updatecalculators/RPropParameterUpdate.java | 24 +-
.../updatecalculators/SimpleGDParameterUpdate.java | 16 +-
.../org/apache/ignite/ml/pipeline/Pipeline.java | 8 +-
.../linear/LinearRegressionLSQRTrainer.java | 38 ++-
.../linear/LinearRegressionSGDTrainer.java | 24 +-
.../logistic/LogisticRegressionSGDTrainer.java | 22 +-
.../ignite/ml/selection/cv/CrossValidation.java | 20 +-
.../ml/selection/scoring/metric/MetricValues.java | 11 +-
.../ignite/ml/sql/SQLFeatureLabelExtractor.java | 125 --------
.../partition/LabelPartitionDataBuilderOnHeap.java | 21 +-
.../LabeledDatasetPartitionDataBuilderOnHeap.java | 30 +-
.../ml/svm/SVMLinearClassificationTrainer.java | 92 ++++--
.../ml/trainers/AdaptableDatasetTrainer.java | 65 ++---
.../apache/ignite/ml/trainers/DatasetTrainer.java | 314 ++++++++-------------
.../org/apache/ignite/ml/tree/DecisionTree.java | 112 ++++----
.../tree/boosting/GDBOnTreesLearningStrategy.java | 33 ++-
.../ml/tree/data/DecisionTreeDataBuilder.java | 13 +-
.../ml/tree/randomforest/RandomForestTrainer.java | 35 +--
.../ml/tree/randomforest/data/NodeSplit.java | 2 +-
.../ignite/ml/tree/randomforest/data/TreeNode.java | 13 +-
.../ignite/ml/tree/randomforest/data/TreeRoot.java | 7 +-
.../randomforest/data/impurity/GiniHistogram.java | 15 +-
.../data/impurity/ImpurityHistogram.java | 7 +-
.../data/impurity/ImpurityHistogramsComputer.java | 17 +-
.../randomforest/data/impurity/MSEHistogram.java | 11 +-
.../data/statistics/LeafValuesComputer.java | 17 +-
.../data/statistics/MeanValueStatistic.java | 2 +-
.../statistics/NormalDistributionStatistics.java | 12 +-
.../NormalDistributionStatisticsComputer.java | 11 +-
.../primitives/vector/VectorGenerator.java | 17 +-
.../primitives/vector/VectorGeneratorsFamily.java | 11 +-
.../test/java/org/apache/ignite/ml/TestUtils.java | 15 +-
.../ignite/ml/clustering/KMeansTrainerTest.java | 22 +-
.../clustering/gmm/GmmTrainerIntegrationTest.java | 8 +-
.../ignite/ml/clustering/gmm/GmmTrainerTest.java | 23 +-
.../apache/ignite/ml/common/KeepBinaryTest.java | 13 +-
.../apache/ignite/ml/common/LocalModelsTest.java | 20 +-
.../org/apache/ignite/ml/common/TrainerTest.java | 11 +-
.../apache/ignite/ml/composition/StackingTest.java | 21 +-
.../ignite/ml/composition/bagging/BaggingTest.java | 27 +-
.../ml/composition/boosting/GDBTrainerTest.java | 35 ++-
.../convergence/ConvergenceCheckerTest.java | 25 +-
.../mean/MeanAbsValueConvergenceCheckerTest.java | 15 +-
.../MedianOfMedianConvergenceCheckerTest.java | 9 +-
.../dataset/feature/extractor/VectorizerTest.java | 108 +++++++
.../ml/dataset/primitive/SimpleDatasetTest.java | 8 +-
.../primitive/SimpleLabeledDatasetTest.java | 14 +-
.../ml/environment/LearningEnvironmentTest.java | 22 +-
.../ignite/ml/knn/ANNClassificationTest.java | 21 +-
.../ignite/ml/knn/KNNClassificationTest.java | 49 ++--
.../apache/ignite/ml/knn/KNNRegressionTest.java | 24 +-
.../apache/ignite/ml/knn/LabeledDatasetHelper.java | 7 +-
.../java/org/apache/ignite/ml/math/BlasTest.java | 11 +-
.../ignite/ml/math/isolve/lsqr/LSQROnHeapTest.java | 27 +-
.../ignite/ml/multiclass/OneVsRestTrainerTest.java | 40 ++-
.../discrete/DiscreteNaiveBayesTest.java | 11 +-
.../discrete/DiscreteNaiveBayesTrainerTest.java | 24 +-
.../gaussian/GaussianNaiveBayesTest.java | 14 +-
.../gaussian/GaussianNaiveBayesTrainerTest.java | 33 +--
.../ignite/ml/nn/MLPTrainerIntegrationTest.java | 72 ++---
.../org/apache/ignite/ml/nn/MLPTrainerTest.java | 74 ++---
.../MLPTrainerMnistIntegrationTest.java | 18 +-
.../ml/nn/performance/MLPTrainerMnistTest.java | 22 +-
.../apache/ignite/ml/pipeline/PipelineTest.java | 9 +-
.../linear/LinearRegressionLSQRTrainerTest.java | 25 +-
.../linear/LinearRegressionSGDTrainerTest.java | 37 ++-
.../logistic/LogisticRegressionSGDTrainerTest.java | 29 +-
.../BinaryClassificationEvaluatorTest.java | 19 +-
.../selection/scoring/evaluator/EvaluatorTest.java | 4 +-
.../scoring/evaluator/RegressionEvaluatorTest.java | 12 +-
.../apache/ignite/ml/svm/SVMBinaryTrainerTest.java | 27 +-
...onTreeClassificationTrainerIntegrationTest.java | 13 +-
.../DecisionTreeClassificationTrainerTest.java | 21 +-
...cisionTreeRegressionTrainerIntegrationTest.java | 10 +-
.../ml/tree/DecisionTreeRegressionTrainerTest.java | 17 +-
.../DecisionTreeMNISTIntegrationTest.java | 10 +-
.../ml/tree/performance/DecisionTreeMNISTTest.java | 16 +-
.../RandomForestClassifierTrainerTest.java | 26 +-
.../randomforest/RandomForestIntegrationTest.java | 14 +-
.../RandomForestRegressionTrainerTest.java | 16 +-
.../DataStreamGeneratorFillCacheTest.java | 9 +-
.../util/generators/DataStreamGeneratorTest.java | 37 +--
208 files changed, 4572 insertions(+), 3825 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java
index f8df0a8..59d96c5 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/TrainingWithBinaryObjectExample.java
@@ -26,10 +26,9 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
/**
* Example of support model training with binary objects.
@@ -43,23 +42,24 @@ public class TrainingWithBinaryObjectExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, BinaryObject> dataCache = populateCache(ignite);
+ IgniteCache<Integer, BinaryObject> dataCache = null;
+ try {
+ dataCache = populateCache(ignite);
- // Create dataset builder with enabled support of keeping binary for upstream cache.
- CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
- new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
+ // Create dataset builder with enabled support of keeping binary for upstream cache.
+ CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
+ new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
- //
- IgniteBiFunction<Integer, BinaryObject, Vector> featureExtractor
- = (k, v) -> VectorUtils.of(new double[] {v.field("feature1")});
+ Vectorizer<Integer, BinaryObject, String, Double> vectorizer =
+ new BinaryObjectVectorizer<Integer>("feature1").labeled("label");
- IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double)v.field("label");
+ KMeansTrainer trainer = new KMeansTrainer();
+ KMeansModel kmdl = trainer.fit(datasetBuilder, vectorizer);
- KMeansTrainer trainer = new KMeansTrainer();
-
- KMeansModel kmdl = trainer.fit(datasetBuilder, featureExtractor, lbExtractor);
-
- System.out.println(">>> Model trained over binary objects. Model " + kmdl);
+ System.out.println(">>> Model trained over binary objects. Model " + kmdl);
+ } finally {
+ dataCache.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
index d9f03c1..1769c36 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/GmmClusterizationExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.clustering;
-import java.util.concurrent.atomic.AtomicInteger;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,6 +24,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.clustering.gmm.GmmModel;
import org.apache.ignite.ml.clustering.gmm.GmmTrainer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -35,6 +35,8 @@ import org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProduce
import org.apache.ignite.ml.util.generators.primitives.scalar.RandomProducer;
import org.apache.ignite.ml.util.generators.primitives.vector.VectorGeneratorsFamily;
+import java.util.concurrent.atomic.AtomicInteger;
+
/**
* Example of using GMM clusterization algorithm. Gaussian Mixture Algorithm (GMM, see {@link GmmModel}, {@link
* GmmTrainer}) can be used for input dataset data distribution representation as mixture of multivariance gaussians.
@@ -56,52 +58,58 @@ public class GmmClusterizationExample {
System.out.println(">>> Ignite grid started.");
long seed = 0;
- IgniteCache<Integer, LabeledVector<Double>> dataCache = ignite.getOrCreateCache(
- new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
- .setAffinity(new RendezvousAffinityFunction(false, 10))
- );
- // Dataset consists of three gaussians where two from them are rotated onto PI/4.
- DataStreamGenerator dataStream = new VectorGeneratorsFamily.Builder().add(
- RandomProducer.vectorize(
- new GaussRandomProducer(0, 2., seed++),
- new GaussRandomProducer(0, 3., seed++)
- ).rotate(Math.PI / 4).move(VectorUtils.of(10., 10.))).add(
- RandomProducer.vectorize(
- new GaussRandomProducer(0, 1., seed++),
- new GaussRandomProducer(0, 2., seed++)
- ).rotate(-Math.PI / 4).move(VectorUtils.of(-10., 10.))).add(
- RandomProducer.vectorize(
- new GaussRandomProducer(0, 3., seed++),
- new GaussRandomProducer(0, 3., seed++)
- ).move(VectorUtils.of(0., -10.))
- ).build(seed++).asDataStream();
+ IgniteCache<Integer, LabeledVector<Double>> dataCache = null;
+ try {
+ dataCache = ignite.createCache(
+ new CacheConfiguration<Integer, LabeledVector<Double>>("GMM_EXAMPLE_CACHE")
+ .setAffinity(new RendezvousAffinityFunction(false, 10))
+ );
- AtomicInteger keyGen = new AtomicInteger();
- dataStream.fillCacheWithCustomKey(50000, dataCache, v -> keyGen.getAndIncrement());
- GmmTrainer trainer = new GmmTrainer(1);
+ // Dataset consists of three gaussians where two from them are rotated onto PI/4.
+ DataStreamGenerator dataStream = new VectorGeneratorsFamily.Builder().add(
+ RandomProducer.vectorize(
+ new GaussRandomProducer(0, 2., seed++),
+ new GaussRandomProducer(0, 3., seed++)
+ ).rotate(Math.PI / 4).move(VectorUtils.of(10., 10.))).add(
+ RandomProducer.vectorize(
+ new GaussRandomProducer(0, 1., seed++),
+ new GaussRandomProducer(0, 2., seed++)
+ ).rotate(-Math.PI / 4).move(VectorUtils.of(-10., 10.))).add(
+ RandomProducer.vectorize(
+ new GaussRandomProducer(0, 3., seed++),
+ new GaussRandomProducer(0, 3., seed++)
+ ).move(VectorUtils.of(0., -10.))
+ ).build(seed++).asDataStream();
- GmmModel mdl = trainer
- .withMaxCountIterations(10)
- .withMaxCountOfClusters(4)
- .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
- .fit(ignite, dataCache, (k, v) -> v.features(), (k, v) -> v.label());
+ AtomicInteger keyGen = new AtomicInteger();
+ dataStream.fillCacheWithCustomKey(50000, dataCache, v -> keyGen.getAndIncrement());
+ GmmTrainer trainer = new GmmTrainer(1);
- System.out.println(">>> GMM means and covariances");
- for (int i = 0; i < mdl.countOfComponents(); i++) {
- MultivariateGaussianDistribution distribution = mdl.distributions().get(i);
- System.out.println();
- System.out.println("============");
- System.out.println("Component #" + i);
- System.out.println("============");
- System.out.println("Mean vector = ");
- Tracer.showAscii(distribution.mean());
- System.out.println();
- System.out.println("Covariance matrix = ");
- Tracer.showAscii(distribution.covariance());
- }
+ GmmModel mdl = trainer
+ .withMaxCountIterations(10)
+ .withMaxCountOfClusters(4)
+ .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(seed))
+ .fit(ignite, dataCache, new LabeledDummyVectorizer<>());
+
+ System.out.println(">>> GMM means and covariances");
+ for (int i = 0; i < mdl.countOfComponents(); i++) {
+ MultivariateGaussianDistribution distribution = mdl.distributions().get(i);
+ System.out.println();
+ System.out.println("============");
+ System.out.println("Component #" + i);
+ System.out.println("============");
+ System.out.println("Mean vector = ");
+ Tracer.showAscii(distribution.mean());
+ System.out.println();
+ System.out.println("Covariance matrix = ");
+ Tracer.showAscii(distribution.covariance());
+ }
- System.out.println(">>>");
+ System.out.println(">>>");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
index e748f4d..31f36be 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.clustering;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -26,11 +24,16 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run KMeans clustering algorithm ({@link KMeansTrainer}) over distributed dataset.
* <p>
@@ -54,40 +57,43 @@ public class KMeansClusterizationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- KMeansTrainer trainer = new KMeansTrainer();
+ KMeansTrainer trainer = new KMeansTrainer();
- KMeansModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ KMeansModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
+ );
- System.out.println(">>> KMeans centroids");
- Tracer.showAscii(mdl.getCenters()[0]);
- Tracer.showAscii(mdl.getCenters()[1]);
- System.out.println(">>>");
+ System.out.println(">>> KMeans centroids");
+ Tracer.showAscii(mdl.getCenters()[0]);
+ Tracer.showAscii(mdl.getCenters()[1]);
+ System.out.println(">>>");
- System.out.println(">>> --------------------------------------------");
- System.out.println(">>> | Predicted cluster\t| Erased class label\t|");
- System.out.println(">>> --------------------------------------------");
+ System.out.println(">>> --------------------------------------------");
+ System.out.println(">>> | Predicted cluster\t| Erased class label\t|");
+ System.out.println(">>> --------------------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.predict(inputs);
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
+ }
+ } finally {
+ dataCache.destroy();
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java
index 5148d9a..af0d184 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/AlgorithmSpecificDatasetExample.java
@@ -18,29 +18,32 @@
package org.apache.ignite.examples.ml.dataset;
import com.github.fommil.netlib.BLAS;
-import java.io.Serializable;
-import java.util.Arrays;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
+import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.DatasetWrapper;
import org.apache.ignite.ml.dataset.primitive.builder.data.SimpleLabeledDatasetDataBuilder;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import java.io.Serializable;
+import java.util.Arrays;
+
/**
- * Example that shows how to implement your own algorithm
- * (<a href="https://en.wikipedia.org/wiki/Gradient_descent">gradient</a> descent trainer for linear regression)
- * which uses dataset as an underlying infrastructure.
+ * Example that shows how to implement your own algorithm (<a href="https://en.wikipedia.org/wiki/Gradient_descent">gradient</a>
+ * descent trainer for linear regression) which uses dataset as an underlying infrastructure.
* <p>
* Code in this example launches Ignite grid and fills the cache with simple test data.</p>
* <p>
- * After that it creates an algorithm specific dataset to perform linear regression as described in more detail below.</p>
+ * After that it creates an algorithm specific dataset to perform linear regression as described in more detail
+ * below.</p>
* <p>
* Finally, this example trains linear regression model using gradient descent and outputs the result.</p>
* <p>
@@ -52,13 +55,12 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
* {@link DatasetWrapper}) in a sequential manner.</p>
* <p>
* In this example we need to implement gradient descent. This is iterative method that involves calculation of gradient
- * on every step. In according with the common idea we define
- * {@link AlgorithmSpecificDatasetExample.AlgorithmSpecificDataset} - extended version of {@code Dataset} with
- * {@code gradient} method. As a result our gradient descent method looks like a simple loop where every iteration
- * includes call of the {@code gradient} method. In the example we want to keep iteration number as well for logging.
- * Iteration number cannot be recovered from the {@code upstream} data and we need to keep it in the custom
- * partition {@code context} which is represented by
- * {@link AlgorithmSpecificDatasetExample.AlgorithmSpecificPartitionContext} class.</p>
+ * on every step. In according with the common idea we define {@link AlgorithmSpecificDatasetExample.AlgorithmSpecificDataset}
+ * - extended version of {@code Dataset} with {@code gradient} method. As a result our gradient descent method looks
+ * like a simple loop where every iteration includes call of the {@code gradient} method. In the example we want to keep
+ * iteration number as well for logging. Iteration number cannot be recovered from the {@code upstream} data and we need
+ * to keep it in the custom partition {@code context} which is represented by {@link
+ * AlgorithmSpecificDatasetExample.AlgorithmSpecificPartitionContext} class.</p>
*/
public class AlgorithmSpecificDatasetExample {
/** Run example. */
@@ -66,49 +68,58 @@ public class AlgorithmSpecificDatasetExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Algorithm Specific Dataset example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Creates a algorithm specific dataset to perform linear regression. Here we define the way features and
- // labels are extracted, and partition data and context are created.
- try (AlgorithmSpecificDataset dataset = DatasetFactory.create(
- ignite,
- persons,
- (env, upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
- new SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext>(
- (k, v) -> VectorUtils.of(v.getAge()),
- (k, v) -> new double[] {v.getSalary()}
- ).andThen((data, ctx) -> {
- double[] features = data.getFeatures();
- int rows = data.getRows();
+ // Creates a algorithm specific dataset to perform linear regression. Here we define the way features and
+ // labels are extracted, and partition data and context are created.
+ SimpleLabeledDatasetDataBuilder<Integer, Person, AlgorithmSpecificPartitionContext, ? extends Serializable> builder =
+ new SimpleLabeledDatasetDataBuilder<>(new FeatureLabelExtractorWrapper<>(CompositionUtils.asFeatureLabelExtractor(
+ (k, v) -> VectorUtils.of(v.getAge()),
+ (k, v) -> new double[] {v.getSalary()}
+ )));
- // Makes a copy of features to supplement it by columns with values equal to 1.0.
- double[] a = new double[features.length + rows];
+ try (AlgorithmSpecificDataset dataset = DatasetFactory.create(
+ ignite,
+ persons,
+ (env, upstream, upstreamSize) -> new AlgorithmSpecificPartitionContext(),
+ builder.andThen((data, ctx) -> {
+ double[] features = data.getFeatures();
+ int rows = data.getRows();
- for (int i = 0; i < rows; i++)
- a[i] = 1.0;
+ // Makes a copy of features to supplement it by columns with values equal to 1.0.
+ double[] a = new double[features.length + rows];
- System.arraycopy(features, 0, a, rows, features.length);
+ for (int i = 0; i < rows; i++)
+ a[i] = 1.0;
- return new SimpleLabeledDatasetData(a, data.getLabels(), rows);
- })
- ).wrap(AlgorithmSpecificDataset::new)) {
- // Trains linear regression model using gradient descent.
- double[] linearRegressionMdl = new double[2];
+ System.arraycopy(features, 0, a, rows, features.length);
- for (int i = 0; i < 1000; i++) {
- double[] gradient = dataset.gradient(linearRegressionMdl);
+ return new SimpleLabeledDatasetData(a, data.getLabels(), rows);
+ })
+ ).wrap(AlgorithmSpecificDataset::new)) {
+ // Trains linear regression model using gradient descent.
+ double[] linearRegressionMdl = new double[2];
- if (BLAS.getInstance().dnrm2(gradient.length, gradient, 1) < 1e-4)
- break;
+ for (int i = 0; i < 1000; i++) {
+ double[] gradient = dataset.gradient(linearRegressionMdl);
- for (int j = 0; j < gradient.length; j++)
- linearRegressionMdl[j] -= 0.1 / persons.size() * gradient[j];
+ if (BLAS.getInstance().dnrm2(gradient.length, gradient, 1) < 1e-4)
+ break;
+
+ for (int j = 0; j < gradient.length; j++)
+ linearRegressionMdl[j] -= 0.1 / persons.size() * gradient[j];
+ }
+
+ System.out.println("Linear Regression Model: " + Arrays.toString(linearRegressionMdl));
}
- System.out.println("Linear Regression Model: " + Arrays.toString(linearRegressionMdl));
+ System.out.println(">>> Algorithm Specific Dataset example completed.");
+ }
+ finally {
+ persons.destroy();
}
-
- System.out.println(">>> Algorithm Specific Dataset example completed.");
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/CacheBasedDatasetExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/CacheBasedDatasetExample.java
index 3f75540..8e026a4 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/dataset/CacheBasedDatasetExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/dataset/CacheBasedDatasetExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -34,8 +35,8 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
* <p>
* Code in this example launches Ignite grid and fills the cache with simple test data.</p>
* <p>
- * After that it creates the dataset based on the data in the cache and uses Dataset API to find and output
- * various statistical metrics of the data.</p>
+ * After that it creates the dataset based on the data in the cache and uses Dataset API to find and output various
+ * statistical metrics of the data.</p>
* <p>
* You can change the test data used in this example and re-run it to explore this functionality further.</p>
*/
@@ -45,19 +46,25 @@ public class CacheBasedDatasetExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Cache Based Dataset example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
- ignite,
- persons,
- (k, v) -> VectorUtils.of(v.getAge(), v.getSalary())
- )) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(
+ ignite,
+ persons,
+ FeatureLabelExtractorWrapper.wrap((k, v) -> VectorUtils.of(v.getAge(), v.getSalary()))
+ )) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Cache Based Dataset example completed.");
+ System.out.println(">>> Cache Based Dataset example completed.");
+ } finally {
+ persons.destroy();
+ }
}
+
}
/** */
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteModelDistributedInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteModelDistributedInferenceExample.java
index 8a43a79..7126c69 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteModelDistributedInferenceExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/IgniteModelDistributedInferenceExample.java
@@ -17,16 +17,14 @@
package org.apache.ignite.examples.ml.inference;
-import java.io.IOException;
-import java.util.concurrent.ExecutionException;
-import java.util.concurrent.Future;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.examples.ml.regression.linear.LinearRegressionLSQRTrainerExample;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.inference.builder.IgniteDistributedModelBuilder;
import org.apache.ignite.ml.inference.parser.IgniteModelParser;
@@ -39,6 +37,11 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.IOException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.Future;
+
/**
* This example is based on {@link LinearRegressionLSQRTrainerExample}, but to perform inference it uses an approach
* implemented in {@link org.apache.ignite.ml.inference} package.
@@ -52,49 +55,52 @@ public class IgniteModelDistributedInferenceExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ System.out.println(">>> Perform the training to get the model.");
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
+ );
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> Preparing model reader and model parser.");
- ModelReader reader = new InMemoryModelReader(mdl);
- ModelParser<Vector, Double, ?> parser = new IgniteModelParser<>();
- try (Model<Vector, Future<Double>> infMdl = new IgniteDistributedModelBuilder(ignite, 4, 4)
- .build(reader, parser)) {
- System.out.println(">>> Inference model is ready.");
+ System.out.println(">>> Preparing model reader and model parser.");
+ ModelReader reader = new InMemoryModelReader(mdl);
+ ModelParser<Vector, Double, ?> parser = new IgniteModelParser<>();
+ try (Model<Vector, Future<Double>> infMdl = new IgniteDistributedModelBuilder(ignite, 4, 4)
+ .build(reader, parser)) {
+ System.out.println(">>> Inference model is ready.");
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = infMdl.predict(inputs).get();
+ double prediction = infMdl.predict(inputs).get();
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
index fdcea6a..46c20bd 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/LogRegFromSparkThroughPMMLExample.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.inference.spark;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
@@ -53,23 +54,29 @@ public class LogRegFromSparkThroughPMMLExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- LogisticRegressionModel mdl = PMMLParser.load("examples/src/main/resources/models/spark/iris.pmml");
+ String path = IgniteUtils.resolveIgnitePath("examples/src/main/resources/models/spark/iris.pmml")
+ .toPath().toAbsolutePath().toString();
+ LogisticRegressionModel mdl = PMMLParser.load(path);
- System.out.println(">>> Logistic regression model: " + mdl);
+ System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0),
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ (k, v) -> v.copyOfRange(1, v.size()),
+ (k, v) -> v.get(0),
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
index 3af9916..1e5afe6 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
@@ -21,6 +21,7 @@ import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.examples.ml.tutorial.TitanicUtils;
+import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -40,7 +41,8 @@ import java.io.FileNotFoundException;
*/
public class DecisionTreeFromSparkExample {
/** Path to Spark DT model. */
- public static final String SPARK_MDL_PATH = "examples/src/main/resources/models/spark/serialized/dt";
+ public static final String SPARK_MDL_PATH = IgniteUtils.resolveIgnitePath("examples/src/main/resources/models/spark/serialized/dt")
+ .toPath().toAbsolutePath().toString();
/** Run example. */
public static void main(String[] args) throws FileNotFoundException {
@@ -50,36 +52,41 @@ public class DecisionTreeFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- return VectorUtils.of(data);
- };
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.DECISION_TREE
- );
+ DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.DECISION_TREE
+ );
- System.out.println(">>> DT: " + mdl);
+ System.out.println(">>> DT: " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
index 0dbfd2a..881b767 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,6 +30,9 @@ import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run Decision tree regression model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -50,41 +51,46 @@ public class DecisionTreeRegressionFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- data[3] = Double.isNaN(data[3]) ? 0 : data[3];
- return VectorUtils.of(data);
- };
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
- DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.DECISION_TREE_REGRESSION
- );
+ DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.DECISION_TREE_REGRESSION
+ );
- System.out.println(">>> Decision tree regression model: " + mdl);
+ System.out.println(">>> Decision tree regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Object[]> observation : observations) {
- Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
- double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
- double prediction = mdl.predict(inputs);
+ try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Object[]> observation : observations) {
+ Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
+ double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
index 33e5cca..40be633 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTFromSparkExample.java
@@ -50,36 +50,41 @@ public class GBTFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- return VectorUtils.of(data);
- };
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.GRADIENT_BOOSTED_TREES
- );
+ ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.GRADIENT_BOOSTED_TREES
+ );
- System.out.println(">>> GBT: " + mdl.toString(true));
+ System.out.println(">>> GBT: " + mdl.toString(true));
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTRegressionFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTRegressionFromSparkExample.java
index d0b2548..0b9f14c 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTRegressionFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/GBTRegressionFromSparkExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,6 +30,9 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run GBT Regression model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -50,41 +51,46 @@ public class GBTRegressionFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- data[3] = Double.isNaN(data[3]) ? 0 : data[3];
- return VectorUtils.of(data);
- };
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
- ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION
- );
+ ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.GRADIENT_BOOSTED_TREES_REGRESSION
+ );
- System.out.println(">>> GBT Regression model: " + mdl);
+ System.out.println(">>> GBT Regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Object[]> observation : observations) {
- Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
- double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
- double prediction = mdl.predict(inputs);
+ try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Object[]> observation : observations) {
+ Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
+ double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/KMeansFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/KMeansFromSparkExample.java
index d76b158..07a491a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/KMeansFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/KMeansFromSparkExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,6 +30,9 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run KMeans model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -50,40 +51,45 @@ public class KMeansFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6], (double)v[4]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- data[3] = Double.isNaN(data[3]) ? 0 : data[3];
- return VectorUtils.of(data);
- };
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6], (double)v[4]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- KMeansModel mdl = (KMeansModel)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.KMEANS
- );
+ KMeansModel mdl = (KMeansModel)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.KMEANS
+ );
- System.out.println(">>> K-Means model: " + mdl);
- System.out.println(">>> ------------------------------------");
- System.out.println(">>> | Predicted cluster\t| Is survived\t|");
- System.out.println(">>> ------------------------------------");
+ System.out.println(">>> K-Means model: " + mdl);
+ System.out.println(">>> ------------------------------------");
+ System.out.println(">>> | Predicted cluster\t| Is survived\t|");
+ System.out.println(">>> ------------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Object[]> observation : observations) {
- Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
- double isSurvived = lbExtractor.apply(observation.getKey(), observation.getValue());
- double clusterId = mdl.predict(inputs);
+ try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Object[]> observation : observations) {
+ Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
+ double isSurvived = lbExtractor.apply(observation.getKey(), observation.getValue());
+ double clusterId = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", clusterId, isSurvived);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", clusterId, isSurvived);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LinearRegressionFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LinearRegressionFromSparkExample.java
index 8902c72..0c313a8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LinearRegressionFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LinearRegressionFromSparkExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,6 +30,9 @@ import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run linear regression model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -50,41 +51,46 @@ public class LinearRegressionFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- data[3] = Double.isNaN(data[3]) ? 0 : data[3];
- return VectorUtils.of(data);
- };
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
- LinearRegressionModel mdl = (LinearRegressionModel)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.LINEAR_REGRESSION
- );
+ LinearRegressionModel mdl = (LinearRegressionModel)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.LINEAR_REGRESSION
+ );
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Object[]> observation : observations) {
- Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
- double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
- double prediction = mdl.predict(inputs);
+ try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Object[]> observation : observations) {
+ Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
+ double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
index c927f44..a2bc5c3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/LogRegFromSparkExample.java
@@ -50,36 +50,41 @@ public class LogRegFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- return VectorUtils.of(data);
- };
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- LogisticRegressionModel mdl = (LogisticRegressionModel) SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.LOG_REGRESSION
+ LogisticRegressionModel mdl = (LogisticRegressionModel)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.LOG_REGRESSION
);
- System.out.println(">>> Logistic regression model: " + mdl);
+ System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
index 1bfe41f..819f559 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestFromSparkExample.java
@@ -50,36 +50,41 @@ public class RandomForestFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- return VectorUtils.of(data);
- };
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.RANDOM_FOREST
- );
+ ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.RANDOM_FOREST
+ );
- System.out.println(">>> Random Forest model: " + mdl.toString(true));
+ System.out.println(">>> Random Forest model: " + mdl.toString(true));
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestRegressionFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestRegressionFromSparkExample.java
index 42c6699..b5da71a 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestRegressionFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/RandomForestRegressionFromSparkExample.java
@@ -17,8 +17,6 @@
package org.apache.ignite.examples.ml.inference.spark.modelparser;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -32,6 +30,9 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run Random Forest regression model loaded from snappy.parquet file.
* The snappy.parquet file was generated by Spark MLLib model.write.overwrite().save(..) operator.
@@ -50,41 +51,46 @@ public class RandomForestRegressionFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- data[3] = Double.isNaN(data[3]) ? 0 : data[3];
- return VectorUtils.of(data);
- };
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[1], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ data[3] = Double.isNaN(data[3]) ? 0 : data[3];
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[4];
- ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.RANDOM_FOREST_REGRESSION
- );
+ ModelsComposition mdl = (ModelsComposition)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.RANDOM_FOREST_REGRESSION
+ );
- System.out.println(">>> Random Forest regression model: " + mdl);
+ System.out.println(">>> Random Forest regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Object[]> observation : observations) {
- Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
- double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
- double prediction = mdl.predict(inputs);
+ try (QueryCursor<Cache.Entry<Integer, Object[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Object[]> observation : observations) {
+ Vector inputs = featureExtractor.apply(observation.getKey(), observation.getValue());
+ double groundTruth = lbExtractor.apply(observation.getKey(), observation.getValue());
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
index 888bd54..704cda3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/SVMFromSparkExample.java
@@ -50,36 +50,41 @@ public class SVMFromSparkExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
+ IgniteCache<Integer, Object[]> dataCache = null;
+ try {
+ dataCache = TitanicUtils.readPassengers(ignite);
- IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
- data[0] = Double.isNaN(data[0]) ? 0 : data[0];
- data[1] = Double.isNaN(data[1]) ? 0 : data[1];
- data[2] = Double.isNaN(data[2]) ? 0 : data[2];
+ IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
+ data[0] = Double.isNaN(data[0]) ? 0 : data[0];
+ data[1] = Double.isNaN(data[1]) ? 0 : data[1];
+ data[2] = Double.isNaN(data[2]) ? 0 : data[2];
- return VectorUtils.of(data);
- };
+ return VectorUtils.of(data);
+ };
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
- SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)SparkModelParser.parse(
- SPARK_MDL_PATH,
- SupportedSparkModels.LINEAR_SVM
- );
+ SVMLinearClassificationModel mdl = (SVMLinearClassificationModel)SparkModelParser.parse(
+ SPARK_MDL_PATH,
+ SupportedSparkModels.LINEAR_SVM
+ );
- System.out.println(">>> SVM: " + mdl);
+ System.out.println(">>> SVM: " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Test Error " + (1 - accuracy));
+ System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Test Error " + (1 - accuracy));
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
index a5d15d1..a7d1a0b 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.knn;
-import java.util.Arrays;
-import java.util.UUID;
-import javax.cache.Cache;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
@@ -28,14 +25,19 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.ArraysVectorizer;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
-import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import javax.cache.Cache;
+import java.util.Arrays;
+import java.util.UUID;
+
/**
* Run ANN multi-class classification trainer ({@link ANNClassificationTrainer}) over distributed dataset.
* <p>
@@ -59,65 +61,69 @@ public class ANNClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, double[]> dataCache = getTestCache(ignite);
+ IgniteCache<Integer, double[]> dataCache = null;
+ try {
+ dataCache = getTestCache(ignite);
- ANNClassificationTrainer trainer = new ANNClassificationTrainer()
- .withDistance(new ManhattanDistance())
- .withK(50)
- .withMaxIterations(1000)
- .withEpsilon(1e-2);
+ ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+ .withDistance(new ManhattanDistance())
+ .withK(50)
+ .withMaxIterations(1000)
+ .withEpsilon(1e-2);
- long startTrainingTime = System.currentTimeMillis();
+ long startTrainingTime = System.currentTimeMillis();
- NNClassificationModel knnMdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
- (k, v) -> v[0]
- ).withK(5)
- .withDistanceMeasure(new EuclideanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ NNClassificationModel knnMdl = trainer.fit(
+ ignite,
+ dataCache,
+ new ArraysVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
+ ).withK(5)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- long endTrainingTime = System.currentTimeMillis();
+ long endTrainingTime = System.currentTimeMillis();
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- int amountOfErrors = 0;
- int totalAmount = 0;
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- long totalPredictionTime = 0L;
+ long totalPredictionTime = 0L;
- try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, double[]> observation : observations) {
- double[] val = observation.getValue();
- double[] inputs = Arrays.copyOfRange(val, 1, val.length);
- double groundTruth = val[0];
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
- long startPredictionTime = System.currentTimeMillis();
- double prediction = knnMdl.predict(new DenseVector(inputs));
- long endPredictionTime = System.currentTimeMillis();
+ long startPredictionTime = System.currentTimeMillis();
+ double prediction = knnMdl.predict(new DenseVector(inputs));
+ long endPredictionTime = System.currentTimeMillis();
- totalPredictionTime += (endPredictionTime - startPredictionTime);
+ totalPredictionTime += (endPredictionTime - startPredictionTime);
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
+ totalAmount++;
+ if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
+ amountOfErrors++;
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
- System.out.println("Training costs = " + (endTrainingTime - startTrainingTime));
- System.out.println("Prediction costs = " + totalPredictionTime);
+ System.out.println("Training costs = " + (endTrainingTime - startTrainingTime));
+ System.out.println("Prediction costs = " + totalPredictionTime);
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
- System.out.println(totalAmount);
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println(totalAmount);
- System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
+ System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
+ }
+ } finally {
+ dataCache.destroy();
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
index 8a2e095..c0e4905 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNClassificationExample.java
@@ -20,11 +20,13 @@ package org.apache.ignite.examples.ml.knn;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.math.distances.EuclideanDistance;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.util.MLSandboxDatasets;
@@ -55,31 +57,30 @@ public class KNNClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- KNNClassificationTrainer trainer = new KNNClassificationTrainer();
+ KNNClassificationTrainer trainer = new KNNClassificationTrainer();
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- NNClassificationModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- ).withK(3)
- .withDistanceMeasure(new EuclideanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ NNClassificationModel mdl = trainer.fit(ignite, dataCache, vectorizer).withK(3)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).accuracy();
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ CompositionUtils.asFeatureExtractor(vectorizer),
+ CompositionUtils.asLabelExtractor(vectorizer)
+ ).accuracy();
- System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Accuracy " + accuracy);
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
index fad238d..af450ca 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/KNNRegressionExample.java
@@ -20,11 +20,13 @@ package org.apache.ignite.examples.ml.knn;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
import org.apache.ignite.ml.math.distances.ManhattanDistance;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
@@ -57,32 +59,32 @@ public class KNNRegressionExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
- KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+ KNNRegressionTrainer trainer = new KNNRegressionTrainer();
- final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- ).withK(5)
- .withDistanceMeasure(new ManhattanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ KNNRegressionModel knnMdl = (KNNRegressionModel)trainer.fit(ignite, dataCache, vectorizer)
+ .withK(5)
+ .withDistanceMeasure(new ManhattanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
- double rmse = Evaluator.evaluate(
- dataCache,
- knnMdl,
- featureExtractor,
- lbExtractor,
- new RegressionMetrics()
- );
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ knnMdl,
+ CompositionUtils.asFeatureExtractor(vectorizer),
+ CompositionUtils.asLabelExtractor(vectorizer),
+ new RegressionMetrics()
+ );
- System.out.println("\n>>> Rmse = " + rmse);
+ System.out.println("\n>>> Rmse = " + rmse);
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java
index 080f45d..0674590 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/multiclass/OneVsRestClassificationExample.java
@@ -17,15 +17,15 @@
package org.apache.ignite.examples.ml.multiclass;
-import java.io.FileNotFoundException;
-import java.util.Arrays;
-import javax.cache.Cache;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.multiclass.MultiClassModel;
@@ -36,23 +36,27 @@ import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+import java.util.Arrays;
+
/**
* Run One-vs-Rest multi-class classification trainer ({@link OneVsRestTrainer}) parametrized by binary SVM classifier
- * ({@link SVMLinearClassificationTrainer}) over distributed dataset
- * to build two models: one with min-max scaling and one without min-max scaling.
+ * ({@link SVMLinearClassificationTrainer}) over distributed dataset to build two models: one with min-max scaling and
+ * one without min-max scaling.
* <p>
* Code in this example launches Ignite grid and fills the cache with test data points (preprocessed
* <a href="https://archive.ics.uci.edu/ml/datasets/Glass+Identification">Glass dataset</a>).</p>
* <p>
- * After that it trains two One-vs-Rest multi-class models based on the specified data - one model is with min-max scaling
- * and one without min-max scaling.</p>
+ * After that it trains two One-vs-Rest multi-class models based on the specified data - one model is with min-max
+ * scaling and one without min-max scaling.</p>
* <p>
- * Finally, this example loops over the test set of data points, applies the trained models to predict what cluster
- * does this point belong to, compares prediction to expected outcome (ground truth), and builds
+ * Finally, this example loops over the test set of data points, applies the trained models to predict what cluster does
+ * this point belong to, compares prediction to expected outcome (ground truth), and builds
* <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
* <p>
- * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
- * NOTE: the smallest 3rd class could not be classified via linear SVM here.
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p> NOTE: the
+ * smallest 3rd class could not be classified via linear SVM here.
*/
public class OneVsRestClassificationExample {
/** Run example. */
@@ -63,100 +67,105 @@ public class OneVsRestClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION);
-
- OneVsRestTrainer<SVMLinearClassificationModel> trainer
- = new OneVsRestTrainer<>(new SVMLinearClassificationTrainer()
- .withAmountOfIterations(20)
- .withAmountOfLocIterations(50)
- .withLambda(0.2)
- .withSeed(1234L)
- );
-
- MultiClassModel<SVMLinearClassificationModel> mdl = trainer.fit(
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.GLASS_IDENTIFICATION);
+
+ OneVsRestTrainer<SVMLinearClassificationModel> trainer
+ = new OneVsRestTrainer<>(new SVMLinearClassificationTrainer()
+ .withAmountOfIterations(20)
+ .withAmountOfLocIterations(50)
+ .withLambda(0.2)
+ .withSeed(1234L)
+ );
+
+ MultiClassModel<SVMLinearClassificationModel> mdl = trainer.fit(
+ ignite,
+ dataCache,
+ new DummyVectorizer<Integer>().labeled(0)
+ );
+
+ System.out.println(">>> One-vs-Rest SVM Multi-class model");
+ System.out.println(mdl.toString());
+
+ MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();
+
+ IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit(
+ ignite,
+ dataCache,
+ CompositionUtils.asFeatureExtractor(new DummyVectorizer<Integer>().exclude(0)) //TODO: IGNITE-11504
+ );
+
+ MultiClassModel<SVMLinearClassificationModel> mdlWithScaling = trainer.fit(
+ ignite,
+ dataCache,
+ FeatureLabelExtractorWrapper.wrap(
+ preprocessor,
+ CompositionUtils.asLabelExtractor(new DummyVectorizer<Integer>().labeled(0)) //TODO: IGNITE-11504
+ )
+ );
+
+ System.out.println(">>> One-vs-Rest SVM Multi-class model with MinMaxScaling");
+ System.out.println(mdlWithScaling.toString());
- System.out.println(">>> One-vs-Rest SVM Multi-class model");
- System.out.println(mdl.toString());
-
- MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();
-
- IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit(
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size())
- );
-
- MultiClassModel<SVMLinearClassificationModel> mdlWithScaling = trainer.fit(
- ignite,
- dataCache,
- preprocessor,
- (k, v) -> v.get(0)
- );
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println(">>> | Prediction\t| Prediction with MinMaxScaling\t| Ground Truth\t|");
+ System.out.println(">>> ----------------------------------------------------------------");
- System.out.println(">>> One-vs-Rest SVM Multi-class model with MinMaxScaling");
- System.out.println(mdlWithScaling.toString());
+ int amountOfErrors = 0;
+ int amountOfErrorsWithMinMaxScaling = 0;
+ int totalAmount = 0;
- System.out.println(">>> ----------------------------------------------------------------");
- System.out.println(">>> | Prediction\t| Prediction with MinMaxScaling\t| Ground Truth\t|");
- System.out.println(">>> ----------------------------------------------------------------");
+ // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
+ int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+ int[][] confusionMtxWithMinMaxScaling = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
- int amountOfErrors = 0;
- int amountOfErrorsWithMinMaxScaling = 0;
- int totalAmount = 0;
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
- int[][] confusionMtx = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
- int[][] confusionMtxWithMinMaxScaling = {{0, 0, 0}, {0, 0, 0}, {0, 0, 0}};
+ double prediction = mdl.predict(inputs);
+ double predictionWithMinMaxScaling = mdlWithScaling.predict(inputs);
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ totalAmount++;
- double prediction = mdl.predict(inputs);
- double predictionWithMinMaxScaling = mdlWithScaling.predict(inputs);
+ // Collect data for model
+ if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
+ amountOfErrors++;
- totalAmount++;
+ int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
+ int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
- // Collect data for model
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
+ confusionMtx[idx1][idx2]++;
- int idx1 = (int)prediction == 1 ? 0 : ((int)prediction == 3 ? 1 : 2);
- int idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+ // Collect data for model with min-max scaling
+ if (!Precision.equals(groundTruth, predictionWithMinMaxScaling, Precision.EPSILON))
+ amountOfErrorsWithMinMaxScaling++;
- confusionMtx[idx1][idx2]++;
+ idx1 = (int)predictionWithMinMaxScaling == 1 ? 0 : ((int)predictionWithMinMaxScaling == 3 ? 1 : 2);
+ idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
- // Collect data for model with min-max scaling
- if (!Precision.equals(groundTruth, predictionWithMinMaxScaling, Precision.EPSILON))
- amountOfErrorsWithMinMaxScaling++;
+ confusionMtxWithMinMaxScaling[idx1][idx2]++;
- idx1 = (int)predictionWithMinMaxScaling == 1 ? 0 : ((int)predictionWithMinMaxScaling == 3 ? 1 : 2);
- idx2 = (int)groundTruth == 1 ? 0 : ((int)groundTruth == 3 ? 1 : 2);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithMinMaxScaling, groundTruth);
+ }
+ System.out.println(">>> ----------------------------------------------------------------");
+ System.out.println("\n>>> -----------------One-vs-Rest SVM model-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
- confusionMtxWithMinMaxScaling[idx1][idx2]++;
+ System.out.println("\n>>> -----------------One-vs-Rest SVM model with MinMaxScaling-------------");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithMinMaxScaling);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithMinMaxScaling / (double)totalAmount));
+ System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithMinMaxScaling));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t\t\t\t\t| %.4f\t\t|\n", prediction, predictionWithMinMaxScaling, groundTruth);
+ System.out.println(">>> One-vs-Rest SVM model over cache based dataset usage example completed.");
}
- System.out.println(">>> ----------------------------------------------------------------");
- System.out.println("\n>>> -----------------One-vs-Rest SVM model-------------");
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));
-
- System.out.println("\n>>> -----------------One-vs-Rest SVM model with MinMaxScaling-------------");
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrorsWithMinMaxScaling);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrorsWithMinMaxScaling / (double)totalAmount));
- System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtxWithMinMaxScaling));
-
- System.out.println(">>> One-vs-Rest SVM model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
index 4114f2d..c19b537 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/DiscreteNaiveBayesTrainerExample.java
@@ -20,7 +20,9 @@ package org.apache.ignite.examples.ml.naivebayes;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
@@ -54,37 +56,35 @@ public class DiscreteNaiveBayesTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.ENGLISH_VS_SCOTTISH);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.ENGLISH_VS_SCOTTISH);
- double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
- System.out.println(">>> Create new Discrete naive Bayes classification trainer object.");
- DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
- .setBucketThresholds(thresholds);
+ double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
+ System.out.println(">>> Create new Discrete naive Bayes classification trainer object.");
+ DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
+ .setBucketThresholds(thresholds);
- System.out.println(">>> Perform the training to get the model.");
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ System.out.println(">>> Perform the training to get the model.");
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- DiscreteNaiveBayesModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ DiscreteNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+ System.out.println(">>> Discrete Naive Bayes model: " + mdl);
- System.out.println(">>> Discrete Naive Bayes model: " + mdl);
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ CompositionUtils.asFeatureExtractor(vectorizer),
+ CompositionUtils.asLabelExtractor(vectorizer)
+ ).accuracy();
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).accuracy();
+ System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Accuracy " + accuracy);
-
- System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
+ System.out.println(">>> Discrete Naive bayes model over partitioned dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
index c98ad62..2f6ed43 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/naivebayes/GaussianNaiveBayesTrainerExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.naivebayes;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
@@ -54,36 +57,36 @@ public class GaussianNaiveBayesTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- System.out.println(">>> Create new naive Bayes classification trainer object.");
- GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+ System.out.println(">>> Create new naive Bayes classification trainer object.");
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
- System.out.println(">>> Perform the training to get the model.");
+ System.out.println(">>> Perform the training to get the model.");
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- GaussianNaiveBayesModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ GaussianNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+ System.out.println(">>> Naive Bayes model: " + mdl);
- System.out.println(">>> Naive Bayes model: " + mdl);
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).accuracy();
+ System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println("\n>>> Accuracy " + accuracy);
-
- System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
+ System.out.println(">>> Naive bayes model over partitioned dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
index a6f177a..c137898 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/nn/MLPTrainerExample.java
@@ -23,6 +23,7 @@ import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ExampleNodeStartup;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -34,6 +35,7 @@ import org.apache.ignite.ml.nn.architecture.MLPArchitecture;
import org.apache.ignite.ml.optimization.LossFunctions;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.structures.LabeledVector;
/**
* Example of using distributed {@link MultilayerPerceptron}.
@@ -44,16 +46,16 @@ import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalcula
* <a href="https://en.wikipedia.org/wiki/Neural_network">neural network</a> trainer, trains neural network
* and obtains multilayer perceptron model.</p>
* <p>
- * Finally, this example loops over the test set, applies the trained model to predict the value and
- * compares prediction to expected outcome.</p>
+ * Finally, this example loops over the test set, applies the trained model to predict the value and compares prediction
+ * to expected outcome.</p>
* <p>
* You can change the test data used in this example and re-run it to explore this functionality further.</p>
* <p>
- * Remote nodes should always be started with special configuration file which
- * enables P2P class loading: {@code 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
+ * Remote nodes should always be started with special configuration file which enables P2P class loading: {@code
+ * 'ignite.{sh|bat} examples/config/example-ignite.xml'}.</p>
* <p>
- * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node
- * with {@code examples/config/example-ignite.xml} configuration.</p>
+ * Alternatively you can run {@link ExampleNodeStartup} in another JVM which will start node with {@code
+ * examples/config/example-ignite.xml} configuration.</p>
*/
public class MLPTrainerExample {
/**
@@ -70,64 +72,64 @@ public class MLPTrainerExample {
System.out.println(">>> Ignite grid started.");
// Create cache with training data.
- CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+ CacheConfiguration<Integer, LabeledVector<double[]>> trainingSetCfg = new CacheConfiguration<>();
trainingSetCfg.setName("TRAINING_SET");
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
- IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
-
- // Fill cache with training data.
- trainingSet.put(0, new LabeledPoint(0, 0, 0));
- trainingSet.put(1, new LabeledPoint(0, 1, 1));
- trainingSet.put(2, new LabeledPoint(1, 0, 1));
- trainingSet.put(3, new LabeledPoint(1, 1, 0));
-
- // Define a layered architecture.
- MLPArchitecture arch = new MLPArchitecture(2).
- withAddedLayer(10, true, Activators.RELU).
- withAddedLayer(1, false, Activators.SIGMOID);
-
- // Define a neural network trainer.
- MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
- arch,
- LossFunctions.MSE,
- new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.1),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ),
- 3000,
- 4,
- 50,
- 123L
- );
-
- // Train neural network and get multilayer perceptron model.
- MultilayerPerceptron mlp = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> new double[] {v.lb}
- );
-
- int totalCnt = 4;
- int failCnt = 0;
-
- // Calculate score.
- for (int i = 0; i < 4; i++) {
- LabeledPoint pnt = trainingSet.get(i);
- Matrix predicted = mlp.predict(new DenseMatrix(new double[][] {{pnt.x, pnt.y}}));
-
- double predictedVal = predicted.get(0, 0);
- double lbl = pnt.lb;
- System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl);
- failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 0 : 1;
+ IgniteCache<Integer, LabeledVector<double[]>> trainingSet = null;
+ try {
+ trainingSet = ignite.createCache(trainingSetCfg);
+
+ // Fill cache with training data.
+ trainingSet.put(0, new LabeledVector<>(VectorUtils.of(0, 0), new double[] {0}));
+ trainingSet.put(1, new LabeledVector<>(VectorUtils.of(0, 1), new double[] {1}));
+ trainingSet.put(2, new LabeledVector<>(VectorUtils.of(1, 0), new double[] {1}));
+ trainingSet.put(3, new LabeledVector<>(VectorUtils.of(1, 1), new double[] {0}));
+
+ // Define a layered architecture.
+ MLPArchitecture arch = new MLPArchitecture(2).
+ withAddedLayer(10, true, Activators.RELU).
+ withAddedLayer(1, false, Activators.SIGMOID);
+
+ // Define a neural network trainer.
+ MLPTrainer<SimpleGDParameterUpdate> trainer = new MLPTrainer<>(
+ arch,
+ LossFunctions.MSE,
+ new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.1),
+ SimpleGDParameterUpdate.SUM_LOCAL,
+ SimpleGDParameterUpdate.AVG
+ ),
+ 3000,
+ 4,
+ 50,
+ 123L
+ );
+
+ // Train neural network and get multilayer perceptron model.
+ MultilayerPerceptron mlp = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
+
+ int totalCnt = 4;
+ int failCnt = 0;
+
+ // Calculate score.
+ for (int i = 0; i < 4; i++) {
+ LabeledVector<double[]> pnt = trainingSet.get(i);
+ Matrix predicted = mlp.predict(new DenseMatrix(new double[][] {{pnt.features().get(0), pnt.features().get(1)}}));
+
+ double predictedVal = predicted.get(0, 0);
+ double lbl = pnt.label()[0];
+ System.out.printf(">>> key: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, predictedVal, lbl);
+ failCnt += Math.abs(predictedVal - lbl) < 0.5 ? 0 : 1;
+ }
+
+ double failRatio = (double)failCnt / totalCnt;
+
+ System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
+ System.out.println("\n>>> Distributed multilayer perceptron example completed.");
+ } finally {
+ trainingSet.destroy();
}
-
- double failRatio = (double)failCnt / totalCnt;
-
- System.out.println("\n>>> Fail percentage: " + (failRatio * 100) + "%.");
- System.out.println("\n>>> Distributed multilayer perceptron example completed.");
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/BinarizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/BinarizationExample.java
index a1e7672..4919a2f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/BinarizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/BinarizationExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -49,24 +50,29 @@ public class BinarizationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Binarization example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge()
+ );
- // Defines second preprocessor that normalizes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new BinarizationTrainer<Integer, Person>()
- .withThreshold(40)
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that normalizes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new BinarizationTrainer<Integer, Person>()
+ .withThreshold(40)
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Binarization example completed.");
+ System.out.println(">>> Binarization example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java
index eefe063..8fd9216 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -39,8 +40,8 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
* <p>
* After that it defines preprocessors that extract features from an upstream data and impute missing values.</p>
* <p>
- * Finally, it creates the dataset based on the processed data and uses Dataset API to find and output
- * various statistical metrics of the data.</p>
+ * Finally, it creates the dataset based on the processed data and uses Dataset API to find and output various
+ * statistical metrics of the data.</p>
* <p>
* You can change the test data used in this example and re-run it to explore this functionality further.</p>
*/
@@ -50,24 +51,29 @@ public class ImputingExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Imputing example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that imputing features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new ImputerTrainer<Integer, Person>()
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that imputing features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new ImputerTrainer<Integer, Person>()
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Imputing example completed.");
+ System.out.println(">>> Imputing example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingWithMostFrequentValuesExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingWithMostFrequentValuesExample.java
index 8e39409..11a5b79 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingWithMostFrequentValuesExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/ImputingWithMostFrequentValuesExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -51,25 +52,30 @@ public class ImputingWithMostFrequentValuesExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Imputing with most frequent values example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that normalizes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new ImputerTrainer<Integer, Person>()
- .withImputingStrategy(ImputingStrategy.MOST_FREQUENT)
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that normalizes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new ImputerTrainer<Integer, Person>()
+ .withImputingStrategy(ImputingStrategy.MOST_FREQUENT)
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Imputing with most frequent values example completed.");
+ System.out.println(">>> Imputing with most frequent values example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MaxAbsScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MaxAbsScalerExample.java
index 955702a..6c0d981 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MaxAbsScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MaxAbsScalerExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -44,24 +45,29 @@ public class MaxAbsScalerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Max abs example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that processes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new MaxAbsScalerTrainer<Integer, Person>()
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that processes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new MaxAbsScalerTrainer<Integer, Person>()
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Max abs example completed.");
+ System.out.println(">>> Max abs example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MinMaxScalerExample.java
index f73228f..746aa63 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/MinMaxScalerExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -53,24 +54,29 @@ public class MinMaxScalerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> MinMax preprocessing example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that normalizes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new MinMaxScalerTrainer<Integer, Person>()
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that normalizes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new MinMaxScalerTrainer<Integer, Person>()
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> MinMax preprocessing example completed.");
+ System.out.println(">>> MinMax preprocessing example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
index 3159845..41e0e16 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/NormalizationExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -49,25 +50,30 @@ public class NormalizationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Normalization example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that normalizes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new NormalizationTrainer<Integer, Person>()
- .withP(1)
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that normalizes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new NormalizationTrainer<Integer, Person>()
+ .withP(1)
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Normalization example completed.");
+ System.out.println(">>> Normalization example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
index 13d8635..9e02a2d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/StandardScalerExample.java
@@ -25,6 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.examples.ml.dataset.model.Person;
import org.apache.ignite.examples.ml.util.DatasetHelper;
import org.apache.ignite.ml.dataset.DatasetFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -44,24 +45,29 @@ public class StandardScalerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Standard scaler example started.");
- IgniteCache<Integer, Person> persons = createCache(ignite);
+ IgniteCache<Integer, Person> persons = null;
+ try {
+ persons = createCache(ignite);
- // Defines first preprocessor that extracts features from an upstream data.
- IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
- v.getAge(),
- v.getSalary()
- );
+ // Defines first preprocessor that extracts features from an upstream data.
+ IgniteBiFunction<Integer, Person, Vector> featureExtractor = (k, v) -> VectorUtils.of(
+ v.getAge(),
+ v.getSalary()
+ );
- // Defines second preprocessor that processes features.
- IgniteBiFunction<Integer, Person, Vector> preprocessor = new StandardScalerTrainer<Integer, Person>()
- .fit(ignite, persons, featureExtractor);
+ // Defines second preprocessor that processes features.
+ IgniteBiFunction<Integer, Person, Vector> preprocessor = new StandardScalerTrainer<Integer, Person>()
+ .fit(ignite, persons, featureExtractor);
- // Creates a cache based simple dataset containing features and providing standard dataset API.
- try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, preprocessor)) {
- new DatasetHelper(dataset).describe();
- }
+ // Creates a cache based simple dataset containing features and providing standard dataset API.
+ try (SimpleDataset<?> dataset = DatasetFactory.createSimpleDataset(ignite, persons, FeatureLabelExtractorWrapper.wrap(preprocessor))) {
+ new DatasetHelper(dataset).describe();
+ }
- System.out.println(">>> Standard scaler example completed.");
+ System.out.println(">>> Standard scaler example completed.");
+ } finally {
+ persons.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
index 6f1fe4c..26279c8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerExample.java
@@ -21,13 +21,13 @@ import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.regression.RegressionMetrics;
-import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
@@ -55,42 +55,42 @@ public class LinearRegressionLSQRTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Perform the training to get the model.");
+ System.out.println(">>> Perform the training to get the model.");
- // This object is used to extract features and vectors from upstream entities which are
- // essentialy tuples of the form (key, value) (in our case (Integer, Vector)).
- // Key part of tuple in our example is ignored.
- // Label is extracted from 0th entry of the value (which is a Vector)
- // and features are all remaining vector part. Alternatively we could use
- // DatasetTrainer#fit(Ignite, IgniteCache, IgniteBiFunction, IgniteBiFunction) method call
- // where there is a separate lambda for extracting label from (key, value) and a separate labmda for
- // extracting features.
- FeatureLabelExtractor<Integer, Vector, Double> extractor =
- (k, v) -> new LabeledVector<>(v.copyOfRange(1, v.size()), v.get(0));
+ // This object is used to extract features and vectors from upstream entities which are
+ // essentially tuples of the form (key, value) (in our case (Integer, Vector)).
+ // Key part of tuple in our example is ignored.
+ // Label is extracted from 0th entry of the value (which is a Vector)
+ // and features are all remaining vector part. Alternatively we could use
+ // DatasetTrainer#fit(Ignite, IgniteCache, IgniteBiFunction, IgniteBiFunction) method call
+ // where there is a separate lambda for extracting label from (key, value) and a separate lambda for
+ // extracting features.
+ Vectorizer<Integer, Vector, Integer, Double> extractor = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- extractor
- );
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, extractor);
- double rmse = Evaluator.evaluate(
- dataCache,
- mdl,
- CompositionUtils.asFeatureExtractor(extractor),
- CompositionUtils.asLabelExtractor(extractor),
- new RegressionMetrics()
- );
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ CompositionUtils.asFeatureExtractor(extractor),
+ CompositionUtils.asLabelExtractor(extractor),
+ new RegressionMetrics()
+ );
- System.out.println("\n>>> Rmse = " + rmse);
+ System.out.println("\n>>> Rmse = " + rmse);
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
index 6c7ec85..de9e3ff 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionLSQRTrainerWithMinMaxScalerExample.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.regression.linear;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.minmaxscaling.MinMaxScalerPreprocessor;
@@ -58,42 +59,46 @@ public class LinearRegressionLSQRTrainerWithMinMaxScalerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new MinMaxScaler trainer object.");
- MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();
+ System.out.println(">>> Create new MinMaxScaler trainer object.");
+ MinMaxScalerTrainer<Integer, Vector> minMaxScalerTrainer = new MinMaxScalerTrainer<>();
- System.out.println(">>> Perform the training to get the MinMaxScaler preprocessor.");
- IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit(
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size())
- );
+ System.out.println(">>> Perform the training to get the MinMaxScaler preprocessor.");
+ IgniteBiFunction<Integer, Vector, Vector> preprocessor = minMaxScalerTrainer.fit(
+ ignite,
+ dataCache,
+ (k, v) -> v.copyOfRange(1, v.size())
+ );
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Perform the training to get the model.");
+ System.out.println(">>> Perform the training to get the model.");
- final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
- LinearRegressionModel mdl = trainer.fit(ignite, dataCache, preprocessor, lbExtractor);
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, FeatureLabelExtractorWrapper.wrap(preprocessor, lbExtractor)); //TODO: IGNITE-11581
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- double rmse = Evaluator.evaluate(
- dataCache,
- mdl,
- preprocessor,
- lbExtractor,
- new RegressionMetrics()
- );
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ preprocessor,
+ lbExtractor,
+ new RegressionMetrics()
+ );
- System.out.println("\n>>> Rmse = " + rmse);
+ System.out.println("\n>>> Rmse = " + rmse);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model with MinMaxScaler preprocessor over cache based dataset usage example completed.");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model with MinMaxScaler preprocessor over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
index cb868b2..8e095f1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/LinearRegressionSGDTrainerExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.regression.linear;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -58,42 +61,43 @@ public class LinearRegressionSGDTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
- new RPropUpdateCalculator(),
- RPropParameterUpdate::sumLocal,
- RPropParameterUpdate::avg
- ), 100000, 10, 100, 123L);
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionSGDTrainer<?> trainer = new LinearRegressionSGDTrainer<>(new UpdatesStrategy<>(
+ new RPropUpdateCalculator(),
+ RPropParameterUpdate.SUM_LOCAL,
+ RPropParameterUpdate.AVG
+ ), 100000, 10, 100, 123L);
- System.out.println(">>> Perform the training to get the model.");
+ System.out.println(">>> Perform the training to get the model.");
- final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, vectorizer);
- System.out.println(">>> Linear regression model: " + mdl);
+ final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ final IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ System.out.println(">>> Linear regression model: " + mdl);
- double rmse = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new RegressionMetrics()
- );
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new RegressionMetrics()
+ );
- System.out.println("\n>>> Rmse = " + rmse);
+ System.out.println("\n>>> Rmse = " + rmse);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
index 1e3914a..3da20ca 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/bagged/BaggedLogisticRegressionSGDTrainerExample.java
@@ -20,9 +20,12 @@ package org.apache.ignite.examples.ml.regression.logistic.bagged;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.bagging.BaggedModel;
import org.apache.ignite.ml.composition.bagging.BaggedTrainer;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -60,49 +63,56 @@ public class BaggedLogisticRegressionSGDTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- System.out.println(">>> Create new logistic regression trainer object.");
- LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
- .withUpdatesStgy(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ))
- .withMaxIterations(100000)
- .withLocIterations(100)
- .withBatchSize(10)
- .withSeed(123L);
+ System.out.println(">>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate.SUM_LOCAL,
+ SimpleGDParameterUpdate.AVG
+ ))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
- System.out.println(">>> Perform the training to get the model.");
+ System.out.println(">>> Perform the training to get the model.");
- BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
- trainer,
- 10,
- 0.6,
- 4,
- 3,
- new OnMajorityPredictionsAggregator())
- .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1));
+ BaggedTrainer<Double> baggedTrainer = TrainerTransformers.makeBagged(
+ trainer,
+ 10,
+ 0.6,
+ 4,
+ 3,
+ new OnMajorityPredictionsAggregator())
+ .withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder().withRNGSeed(1));
- System.out.println(">>> Perform evaluation of the model.");
+ System.out.println(">>> Perform evaluation of the model.");
- double[] score = new CrossValidation<BaggedModel, Double, Integer, Vector>().score(
- baggedTrainer,
- new Accuracy<>(),
- ignite,
- dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0),
- 3
- );
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- System.out.println(">>> ---------------------------------");
+ double[] score = new CrossValidation<BaggedModel, Double, Integer, Vector>().score(
+ baggedTrainer,
+ new Accuracy<>(),
+ ignite,
+ dataCache,
+ CompositionUtils.asFeatureExtractor(vectorizer),
+ CompositionUtils.asLabelExtractor(vectorizer),
+ 3
+ );
- Arrays.stream(score).forEach(sc -> System.out.println("\n>>> Accuracy " + sc));
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> Bagged logistic regression model over partitioned dataset usage example completed.");
+ Arrays.stream(score).forEach(sc -> System.out.println("\n>>> Accuracy " + sc));
+
+ System.out.println(">>> Bagged logistic regression model over partitioned dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
index a7c5ba9..09c50f8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/logistic/binary/LogisticRegressionSGDTrainerExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.regression.logistic.binary;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -57,44 +60,45 @@ public class LogisticRegressionSGDTrainerExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- System.out.println(">>> Create new logistic regression trainer object.");
- LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
- .withUpdatesStgy(new UpdatesStrategy<>(
- new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal,
- SimpleGDParameterUpdate::avg
- ))
- .withMaxIterations(100000)
- .withLocIterations(100)
- .withBatchSize(10)
- .withSeed(123L);
+ System.out.println(">>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate.SUM_LOCAL,
+ SimpleGDParameterUpdate.AVG
+ ))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
- System.out.println(">>> Perform the training to get the model.");
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ System.out.println(">>> Perform the training to get the model.");
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- LogisticRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ LogisticRegressionModel mdl = trainer.fit(ignite, dataCache, vectorizer);
- System.out.println(">>> Logistic regression model: " + mdl);
+ System.out.println(">>> Logistic regression model: " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).accuracy();
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
+ System.out.println(">>> Logistic regression model over partitioned dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index 8f06e0f..a99186f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -22,11 +22,16 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetricValues;
import org.apache.ignite.ml.selection.scoring.metric.classification.BinaryClassificationMetrics;
+import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
@@ -35,7 +40,8 @@ import java.util.Random;
/**
* Run <a href="https://en.wikipedia.org/wiki/Decision_tree">decision tree</a> classification with
- * <a href="https://en.wikipedia.org/wiki/Cross-validation_(statistics)">cross validation</a> ({@link CrossValidation}).
+ * <a href="https://en.wikipedia.org/wiki/Cross-validation_(statistics)">cross validation</a> ({@link
+ * CrossValidation}).
* <p>
* Code in this example launches Ignite grid and fills the cache with pseudo random training data points.</p>
* <p>
@@ -56,54 +62,62 @@ public class CrossValidationExample {
System.out.println(">>> Ignite grid started.");
// Create cache with training data.
- CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+ CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
trainingSetCfg.setName("TRAINING_SET");
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
- IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
-
- Random rnd = new Random(0);
-
- // Fill training data.
- for (int i = 0; i < 1000; i++)
- trainingSet.put(i, generatePoint(rnd));
-
- // Create classification trainer.
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
-
- CrossValidation<DecisionTreeNode, Double, Integer, LabeledPoint> scoreCalculator
- = new CrossValidation<>();
-
- double[] accuracyScores = scoreCalculator.score(
- trainer,
- new Accuracy<>(),
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> v.lb,
- 4
- );
-
- System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores));
-
- BinaryClassificationMetrics metrics = (BinaryClassificationMetrics) new BinaryClassificationMetrics()
- .withNegativeClsLb(0.0)
- .withPositiveClsLb(1.0)
- .withMetric(BinaryClassificationMetricValues::balancedAccuracy);
-
- double[] balancedAccuracyScores = scoreCalculator.score(
- trainer,
- metrics,
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> v.lb,
- 4
- );
-
- System.out.println(">>> Balanced Accuracy: " + Arrays.toString(balancedAccuracyScores));
-
- System.out.println(">>> Cross validation score calculator example completed.");
+ IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
+ try {
+ trainingSet = ignite.createCache(trainingSetCfg);
+
+ Random rnd = new Random(0);
+
+ // Fill training data.
+ for (int i = 0; i < 1000; i++)
+ trainingSet.put(i, generatePoint(rnd));
+
+ // Create classification trainer.
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+ LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
+ CrossValidation<DecisionTreeNode, Double, Integer, LabeledVector<Double>> scoreCalculator
+ = new CrossValidation<>();
+
+ IgniteBiFunction<Integer, LabeledVector<Double>, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, LabeledVector<Double>, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double[] accuracyScores = scoreCalculator.score(
+ trainer,
+ new Accuracy<>(),
+ ignite,
+ trainingSet,
+ featureExtractor,
+ lbExtractor,
+ 4
+ );
+
+ System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores));
+
+ BinaryClassificationMetrics metrics = (BinaryClassificationMetrics)new BinaryClassificationMetrics()
+ .withNegativeClsLb(0.0)
+ .withPositiveClsLb(1.0)
+ .withMetric(BinaryClassificationMetricValues::balancedAccuracy);
+
+ double[] balancedAccuracyScores = scoreCalculator.score(
+ trainer,
+ metrics,
+ ignite,
+ trainingSet,
+ featureExtractor,
+ lbExtractor,
+ 4
+ );
+
+ System.out.println(">>> Balanced Accuracy: " + Arrays.toString(balancedAccuracyScores));
+
+ System.out.println(">>> Cross validation score calculator example completed.");
+ } finally {
+ trainingSet.destroy();
+ }
}
}
@@ -114,49 +128,14 @@ public class CrossValidationExample {
* @param rnd Random.
* @return Point with label.
*/
- private static LabeledPoint generatePoint(Random rnd) {
+ private static LabeledVector<Double> generatePoint(Random rnd) {
double x = rnd.nextDouble() - 0.5;
double y = rnd.nextDouble() - 0.5;
- return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
- }
-
- /** Point data class. */
- private static class Point {
- /** X coordinate. */
- final double x;
-
- /** Y coordinate. */
- final double y;
-
- /**
- * Constructs a new instance of point.
- *
- * @param x X coordinate.
- * @param y Y coordinate.
- */
- Point(double x, double y) {
- this.x = x;
- this.y = y;
- }
- }
-
- /** Labeled point data class. */
- private static class LabeledPoint extends Point {
- /** Point label. */
- final double lb;
-
- /**
- * Constructs a new instance of labeled point data.
- *
- * @param x X coordinate.
- * @param y Y coordinate.
- * @param lb Point label.
- */
- LabeledPoint(double x, double y, double lb) {
- super(x, y);
- this.lb = lb;
- }
+ return new LabeledVector<>(
+ VectorUtils.of(x, y),
+ x * y > 0 ? 1.0 : 0.0
+ );
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
index b5a6f89..31adf42 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/EvaluatorExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.selection.scoring;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
@@ -54,39 +57,40 @@ public class EvaluatorExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- SVMLinearClassificationModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ SVMLinearClassificationModel mdl = trainer.fit(ignite, dataCache, vectorizer);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor,
- new Accuracy<>()
- );
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor,
+ new Accuracy<>()
+ );
- System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Accuracy " + accuracy);
- double f1Score = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).f1Score();
+ double f1Score = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).f1Score();
- System.out.println("\n>>> F1-Score " + f1Score);
+ System.out.println("\n>>> F1-Score " + f1Score);
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
index 934fb32..a085476 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/MultipleMetricsExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.selection.scoring;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -53,31 +56,32 @@ public class MultipleMetricsExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- SVMLinearClassificationModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ SVMLinearClassificationModel mdl = trainer.fit(ignite, dataCache, vectorizer);
- Map<String, Double> scores = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).toMap();
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ Map<String, Double> scores = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).toMap();
- scores.forEach(
- (metricName, score) -> System.out.println("\n>>>" + metricName + ": " + score)
- );
+ scores.forEach(
+ (metricName, score) -> System.out.println("\n>>>" + metricName + ": " + score)
+ );
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
index a978078..dc39eae 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/scoring/RegressionMetricExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.selection.scoring;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
import org.apache.ignite.ml.knn.regression.KNNRegressionModel;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
@@ -40,11 +43,11 @@ import java.io.FileNotFoundException;
* After that it trains the model based on the specified data using
* <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">kNN</a> regression algorithm.</p>
* <p>
- * Finally, this example loops over the test set of data points, applies the trained model to predict what cluster
- * does this point belong to, and compares prediction to expected outcome (ground truth).</p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict what cluster does
+ * this point belong to, and compares prediction to expected outcome (ground truth).</p>
* <p>
- * You can change the test data used in this example or trainer object settings and re-run it to explore
- * this algorithm further.</p>
+ * You can change the test data used in this example or trainer object settings and re-run it to explore this algorithm
+ * further.</p>
*/
public class RegressionMetricExample {
/** Run example. */
@@ -55,33 +58,33 @@ public class RegressionMetricExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.CLEARED_MACHINES);
- KNNRegressionTrainer trainer = new KNNRegressionTrainer();
+ KNNRegressionTrainer trainer = new KNNRegressionTrainer();
- final IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- final IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- KNNRegressionModel knnMdl = (KNNRegressionModel) trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- ).withK(5)
- .withDistanceMeasure(new ManhattanDistance())
- .withStrategy(NNStrategy.WEIGHTED);
+ KNNRegressionModel knnMdl = (KNNRegressionModel)trainer.fit(ignite, dataCache, vectorizer).withK(5)
+ .withDistanceMeasure(new ManhattanDistance())
+ .withStrategy(NNStrategy.WEIGHTED);
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double mae = Evaluator.evaluate(
+ dataCache,
+ knnMdl,
+ featureExtractor,
+ lbExtractor,
+ new RegressionMetrics().withMetric(RegressionMetricValues::mae)
+ );
- double mae = Evaluator.evaluate(
- dataCache,
- knnMdl,
- featureExtractor,
- lbExtractor,
- new RegressionMetrics().withMetric(RegressionMetricValues::mae)
- );
-
- System.out.println("\n>>> Mae " + mae);
+ System.out.println("\n>>> Mae " + mae);
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
index c9a7ae4..c28edb9 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/split/TrainTestDatasetSplitterExample.java
@@ -17,13 +17,13 @@
package org.apache.ignite.examples.ml.selection.split;
-import java.io.FileNotFoundException;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
@@ -32,6 +32,9 @@ import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+
/**
* Run linear regression model over dataset split on train and test subsets ({@link TrainTestDatasetSplitter}).
* <p>
@@ -55,48 +58,48 @@ public class TrainTestDatasetSplitterExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
- System.out.println(">>> Create new linear regression trainer object.");
- LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+ System.out.println(">>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
- System.out.println(">>> Create new training dataset splitter object.");
- TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
- .split(0.75);
+ System.out.println(">>> Create new training dataset splitter object.");
+ TrainTestSplit<Integer, Vector> split = new TrainTestDatasetSplitter<Integer, Vector>()
+ .split(0.75);
- System.out.println(">>> Perform the training to get the model.");
- LinearRegressionModel mdl = trainer.fit(
- ignite,
- dataCache,
- split.getTrainFilter(),
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ System.out.println(">>> Perform the training to get the model.");
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+ LinearRegressionModel mdl = trainer.fit(ignite, dataCache, split.getTrainFilter(), vectorizer);
- System.out.println(">>> Linear regression model: " + mdl);
+ System.out.println(">>> Linear regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- ScanQuery<Integer, Vector> qry = new ScanQuery<>();
- qry.setFilter(split.getTestFilter());
+ ScanQuery<Integer, Vector> qry = new ScanQuery<>();
+ qry.setFilter(split.getTestFilter());
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(qry)) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(qry)) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = mdl.predict(inputs);
+ double prediction = mdl.predict(inputs);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
}
- }
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Linear regression model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
index 185c2c2..926512e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.sql;
-import java.util.List;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,13 +24,15 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
-import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
import org.apache.ignite.ml.sql.SQLFunctions;
import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.util.List;
+
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table and inference
* made as SQL select query.
@@ -59,79 +60,85 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
.setSqlSchema("PUBLIC")
.setSqlFunctionClasses(SQLFunctions.class);
- IgniteCache<?, ?> cache = ignite.createCache(cacheCfg);
-
- System.out.println(">>> Creating table with training data...");
- cache.query(new SqlFieldsQuery("create table titanik_train (\n" +
- " passengerid int primary key,\n" +
- " survived int,\n" +
- " pclass int,\n" +
- " name varchar(255),\n" +
- " sex varchar(255),\n" +
- " age float,\n" +
- " sibsp int,\n" +
- " parch int,\n" +
- " ticket varchar(255),\n" +
- " fare float,\n" +
- " cabin varchar(255),\n" +
- " embarked varchar(255)\n" +
- ") with \"template=partitioned\";")).getAll();
-
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanik_train select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
- System.out.println(">>> Creating table with test data...");
- cache.query(new SqlFieldsQuery("create table titanik_test (\n" +
- " passengerid int primary key,\n" +
- " pclass int,\n" +
- " name varchar(255),\n" +
- " sex varchar(255),\n" +
- " age float,\n" +
- " sibsp int,\n" +
- " parch int,\n" +
- " ticket varchar(255),\n" +
- " fare float,\n" +
- " cabin varchar(255),\n" +
- " embarked varchar(255)\n" +
- ") with \"template=partitioned\";")).getAll();
-
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanik_test select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
-
- System.out.println(">>> Prepare trainer...");
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
-
- System.out.println(">>> Perform training...");
- DecisionTreeNode mdl = trainer.fit(
- new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
- new SQLFeatureLabelExtractor()
- .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
- .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
- .withLabelField("survived")
- );
-
- System.out.println(">>> Saving model...");
-
- // Model storage is used to store raw serialized model.
- System.out.println("Saving model into model storage...");
- IgniteModelStorageUtil.saveModel(ignite, mdl, "titanik_model_tree");
-
- // Making inference using saved model.
- System.out.println("Inference...");
- try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
- "survived as truth, " +
- "predict('titanik_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction " +
- "from titanik_train"))) {
- // Print inference result.
- System.out.println("| Truth | Prediction |");
- System.out.println("|--------------------|");
- for (List<?> row : cursor)
- System.out.println("| " + row.get(0) + " | " + row.get(1) + " |");
+ IgniteCache<?, ?> cache = null;
+ try {
+ cache = ignite.getOrCreateCache(cacheCfg);
+
+ System.out.println(">>> Creating table with training data...");
+ cache.query(new SqlFieldsQuery("create table titanik_train (\n" +
+ " passengerid int primary key,\n" +
+ " survived int,\n" +
+ " pclass int,\n" +
+ " name varchar(255),\n" +
+ " sex varchar(255),\n" +
+ " age float,\n" +
+ " sibsp int,\n" +
+ " parch int,\n" +
+ " ticket varchar(255),\n" +
+ " fare float,\n" +
+ " cabin varchar(255),\n" +
+ " embarked varchar(255)\n" +
+ ") with \"template=partitioned\";")).getAll();
+
+ System.out.println(">>> Filling training data...");
+ cache.query(new SqlFieldsQuery("insert into titanik_train select * from csvread('" +
+ IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+ System.out.println(">>> Creating table with test data...");
+ cache.query(new SqlFieldsQuery("create table titanik_test (\n" +
+ " passengerid int primary key,\n" +
+ " pclass int,\n" +
+ " name varchar(255),\n" +
+ " sex varchar(255),\n" +
+ " age float,\n" +
+ " sibsp int,\n" +
+ " parch int,\n" +
+ " ticket varchar(255),\n" +
+ " fare float,\n" +
+ " cabin varchar(255),\n" +
+ " embarked varchar(255)\n" +
+ ") with \"template=partitioned\";")).getAll();
+
+ System.out.println(">>> Filling training data...");
+ cache.query(new SqlFieldsQuery("insert into titanik_test select * from csvread('" +
+ IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+ System.out.println(">>> Prepare trainer...");
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+ System.out.println(">>> Perform training...");
+ DecisionTreeNode mdl = trainer.fit(
+ new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+ new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
+ .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
+ .labeled("survived")
+ );
+
+ System.out.println(">>> Saving model...");
+
+ // Model storage is used to store raw serialized model.
+ System.out.println("Saving model into model storage...");
+ IgniteModelStorageUtil.saveModel(ignite, mdl, "titanik_model_tree");
+
+ // Making inference using saved model.
+ System.out.println("Inference...");
+ try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
+ "survived as truth, " +
+ "predict('titanik_model_tree', pclass, age, sibsp, parch, fare, case sex when 'male' then 1 else 0 end) as prediction " +
+ "from titanik_train"))) {
+ // Print inference result.
+ System.out.println("| Truth | Prediction |");
+ System.out.println("|--------------------|");
+ for (List<?> row : cursor)
+ System.out.println("| " + row.get(0) + " | " + row.get(1) + " |");
+ }
+
+ IgniteModelStorageUtil.removeModel(ignite, "titanik_model_tree");
+ } finally {
+ cache.query(new SqlFieldsQuery("DROP TABLE titanik_train"));
+ cache.query(new SqlFieldsQuery("DROP TABLE titanik_test"));
+ cache.destroy();
}
-
- IgniteModelStorageUtil.removeModel(ignite, "titanik_model_tree");
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
index a4f9a2d..23e4204 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
@@ -17,7 +17,6 @@
package org.apache.ignite.examples.ml.sql;
-import java.util.List;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,13 +24,15 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.SqlFieldsQuery;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.BinaryObjectVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.sql.SQLFeatureLabelExtractor;
import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.util.List;
+
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table.
*/
@@ -57,84 +58,90 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
CacheConfiguration<?, ?> cacheCfg = new CacheConfiguration<>(DUMMY_CACHE_NAME)
.setSqlSchema("PUBLIC");
- IgniteCache<?, ?> cache = ignite.createCache(cacheCfg);
-
- System.out.println(">>> Creating table with training data...");
- cache.query(new SqlFieldsQuery("create table titanik_train (\n" +
- " passengerid int primary key,\n" +
- " survived int,\n" +
- " pclass int,\n" +
- " name varchar(255),\n" +
- " sex varchar(255),\n" +
- " age float,\n" +
- " sibsp int,\n" +
- " parch int,\n" +
- " ticket varchar(255),\n" +
- " fare float,\n" +
- " cabin varchar(255),\n" +
- " embarked varchar(255)\n" +
- ") with \"template=partitioned\";")).getAll();
-
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanik_train select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
-
- System.out.println(">>> Creating table with test data...");
- cache.query(new SqlFieldsQuery("create table titanik_test (\n" +
- " passengerid int primary key,\n" +
- " pclass int,\n" +
- " name varchar(255),\n" +
- " sex varchar(255),\n" +
- " age float,\n" +
- " sibsp int,\n" +
- " parch int,\n" +
- " ticket varchar(255),\n" +
- " fare float,\n" +
- " cabin varchar(255),\n" +
- " embarked varchar(255)\n" +
- ") with \"template=partitioned\";")).getAll();
-
- System.out.println(">>> Filling training data...");
- cache.query(new SqlFieldsQuery("insert into titanik_test select * from csvread('" +
- IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
-
- System.out.println(">>> Prepare trainer...");
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
-
- System.out.println(">>> Perform training...");
- DecisionTreeNode mdl = trainer.fit(
- new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
- new SQLFeatureLabelExtractor()
- .withFeatureFields("pclass", "age", "sibsp", "parch", "fare")
- .withFeatureField("sex", e -> "male".equals(e) ? 1 : 0)
- .withLabelField("survived")
- );
-
- System.out.println(">>> Perform inference...");
- try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
- "pclass, " +
- "sex, " +
- "age, " +
- "sibsp, " +
- "parch, " +
- "fare from titanik_test"))) {
- for (List<?> passenger : cursor) {
- Vector input = VectorUtils.of(new Double[]{
- asDouble(passenger.get(0)),
- "male".equals(passenger.get(1)) ? 1.0 : 0.0,
- asDouble(passenger.get(2)),
- asDouble(passenger.get(3)),
- asDouble(passenger.get(4)),
- asDouble(passenger.get(5))
- });
-
- double prediction = mdl.predict(input);
-
- System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive");
+ IgniteCache<?, ?> cache = null;
+ try {
+ cache = ignite.getOrCreateCache(cacheCfg);
+
+ System.out.println(">>> Creating table with training data...");
+ cache.query(new SqlFieldsQuery("create table titanik_train (\n" +
+ " passengerid int primary key,\n" +
+ " survived int,\n" +
+ " pclass int,\n" +
+ " name varchar(255),\n" +
+ " sex varchar(255),\n" +
+ " age float,\n" +
+ " sibsp int,\n" +
+ " parch int,\n" +
+ " ticket varchar(255),\n" +
+ " fare float,\n" +
+ " cabin varchar(255),\n" +
+ " embarked varchar(255)\n" +
+ ") with \"template=partitioned\";")).getAll();
+
+ System.out.println(">>> Filling training data...");
+ cache.query(new SqlFieldsQuery("insert into titanik_train select * from csvread('" +
+ IgniteUtils.resolveIgnitePath(TRAIN_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+ System.out.println(">>> Creating table with test data...");
+ cache.query(new SqlFieldsQuery("create table titanik_test (\n" +
+ " passengerid int primary key,\n" +
+ " pclass int,\n" +
+ " name varchar(255),\n" +
+ " sex varchar(255),\n" +
+ " age float,\n" +
+ " sibsp int,\n" +
+ " parch int,\n" +
+ " ticket varchar(255),\n" +
+ " fare float,\n" +
+ " cabin varchar(255),\n" +
+ " embarked varchar(255)\n" +
+ ") with \"template=partitioned\";")).getAll();
+
+ System.out.println(">>> Filling training data...");
+ cache.query(new SqlFieldsQuery("insert into titanik_test select * from csvread('" +
+ IgniteUtils.resolveIgnitePath(TEST_DATA_RES).getAbsolutePath() + "')")).getAll();
+
+ System.out.println(">>> Prepare trainer...");
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+
+ System.out.println(">>> Perform training...");
+ DecisionTreeNode mdl = trainer.fit(
+ new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIK_TRAIN"),
+ new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
+ .withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
+ .labeled("survived")
+ );
+
+ System.out.println(">>> Perform inference...");
+ try (QueryCursor<List<?>> cursor = cache.query(new SqlFieldsQuery("select " +
+ "pclass, " +
+ "sex, " +
+ "age, " +
+ "sibsp, " +
+ "parch, " +
+ "fare from titanik_test"))) {
+ for (List<?> passenger : cursor) {
+ Vector input = VectorUtils.of(new Double[] {
+ asDouble(passenger.get(0)),
+ "male".equals(passenger.get(1)) ? 1.0 : 0.0,
+ asDouble(passenger.get(2)),
+ asDouble(passenger.get(3)),
+ asDouble(passenger.get(4)),
+ asDouble(passenger.get(5))
+ });
+
+ double prediction = mdl.predict(input);
+
+ System.out.printf("Passenger %s will %s.\n", passenger, prediction == 0 ? "die" : "survive");
+ }
}
- }
- System.out.println(">>> Example completed.");
+ System.out.println(">>> Example completed.");
+ } finally {
+ cache.query(new SqlFieldsQuery("DROP TABLE titanik_train"));
+ cache.query(new SqlFieldsQuery("DROP TABLE titanik_test"));
+ cache.destroy();
+ }
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
index 3d9c8ab..fd39c77 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/svm/SVMBinaryClassificationExample.java
@@ -20,6 +20,9 @@ package org.apache.ignite.examples.ml.svm;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
@@ -53,33 +56,34 @@ public class SVMBinaryClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
- IgniteBiFunction<Integer, Vector, Vector> featureExtractor = (k, v) -> v.copyOfRange(1, v.size());
- IgniteBiFunction<Integer, Vector, Double> lbExtractor = (k, v) -> v.get(0);
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- SVMLinearClassificationModel mdl = trainer.fit(
- ignite,
- dataCache,
- featureExtractor,
- lbExtractor
- );
+ SVMLinearClassificationModel mdl = trainer.fit(ignite, dataCache, vectorizer);
- System.out.println(">>> SVM model " + mdl);
+ System.out.println(">>> SVM model " + mdl);
- double accuracy = Evaluator.evaluate(
- dataCache,
- mdl,
- featureExtractor,
- lbExtractor
- ).accuracy();
+ IgniteBiFunction<Integer, Vector, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(vectorizer);
+ IgniteBiFunction<Integer, Vector, Double> lbExtractor = CompositionUtils.asLabelExtractor(vectorizer);
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ featureExtractor,
+ lbExtractor
+ ).accuracy();
- System.out.println("\n>>> Accuracy " + accuracy);
+ System.out.println("\n>>> Accuracy " + accuracy);
- System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
+ System.out.println(">>> SVM Binary classification model over cache based dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
+ }
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
index 606660f..9787719 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
@@ -17,17 +17,20 @@
package org.apache.ignite.examples.ml.tree;
-import java.util.Random;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
import org.apache.ignite.ml.tree.DecisionTreeNode;
+import java.util.Random;
+
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer}.
* <p>
@@ -54,48 +57,53 @@ public class DecisionTreeClassificationTrainerExample {
System.out.println(">>> Ignite grid started.");
// Create cache with training data.
- CacheConfiguration<Integer, LabeledPoint> trainingSetCfg = new CacheConfiguration<>();
+ CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
trainingSetCfg.setName("TRAINING_SET");
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
- IgniteCache<Integer, LabeledPoint> trainingSet = ignite.createCache(trainingSetCfg);
+ IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
+ try {
+ trainingSet = ignite.createCache(trainingSetCfg);
- Random rnd = new Random(0);
+ Random rnd = new Random(0);
- // Fill training data.
- for (int i = 0; i < 1000; i++)
- trainingSet.put(i, generatePoint(rnd));
+ // Fill training data.
+ for (int i = 0; i < 1000; i++)
+ trainingSet.put(i, generatePoint(rnd));
- // Create classification trainer.
- DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
+ // Create classification trainer.
+ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
- // Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x, v.y),
- (k, v) -> v.lb
- );
+ // Train decision tree model.
+ LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
+ DecisionTreeNode mdl = trainer.fit(
+ ignite,
+ trainingSet,
+ vectorizer
+ );
- System.out.println(">>> Decision tree classification model: " + mdl);
+ System.out.println(">>> Decision tree classification model: " + mdl);
- // Calculate score.
- int correctPredictions = 0;
- for (int i = 0; i < 1000; i++) {
- LabeledPoint pnt = generatePoint(rnd);
+ // Calculate score.
+ int correctPredictions = 0;
+ for (int i = 0; i < 1000; i++) {
+ LabeledVector<Double> pnt = generatePoint(rnd);
- double prediction = mdl.predict(VectorUtils.of(pnt.x, pnt.y));
- double lbl = pnt.lb;
+ double prediction = mdl.predict(pnt.features());
+ double lbl = pnt.label();
- if (i %50 == 1)
- System.out.printf(">>> test #: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, prediction, lbl);
+ if (i % 50 == 1)
+ System.out.printf(">>> test #: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, prediction, lbl);
- if (Precision.equals(prediction, lbl, Precision.EPSILON))
- correctPredictions++;
- }
+ if (Precision.equals(prediction, lbl, Precision.EPSILON))
+ correctPredictions++;
+ }
- System.out.println(">>> Accuracy: " + correctPredictions / 10.0 + "%");
- System.out.println(">>> Decision tree classification trainer example completed.");
+ System.out.println(">>> Accuracy: " + correctPredictions / 10.0 + "%");
+ System.out.println(">>> Decision tree classification trainer example completed.");
+ } finally {
+ trainingSet.destroy();
+ }
}
}
@@ -106,49 +114,11 @@ public class DecisionTreeClassificationTrainerExample {
* @param rnd Random.
* @return Point with label.
*/
- private static LabeledPoint generatePoint(Random rnd) {
+ private static LabeledVector<Double> generatePoint(Random rnd) {
double x = rnd.nextDouble() - 0.5;
double y = rnd.nextDouble() - 0.5;
- return new LabeledPoint(x, y, x * y > 0 ? 1 : 0);
- }
-
- /** Point data class. */
- private static class Point {
- /** X coordinate. */
- final double x;
-
- /** Y coordinate. */
- final double y;
-
- /**
- * Constructs a new instance of point.
- *
- * @param x X coordinate.
- * @param y Y coordinate.
- */
- Point(double x, double y) {
- this.x = x;
- this.y = y;
- }
- }
-
- /** Labeled point data class. */
- private static class LabeledPoint extends Point {
- /** Point label. */
- final double lb;
-
- /**
- * Constructs a new instance of labeled point data.
- *
- * @param x X coordinate.
- * @param y Y coordinate.
- * @param lb Point label.
- */
- LabeledPoint(double x, double y, double lb) {
- super(x, y);
- this.lb = lb;
- }
+ return new LabeledVector<>(VectorUtils.of(x,y), x * y > 0 ? 1. : 0.);
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
index 3e37646..568aac3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
@@ -22,7 +22,9 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeNode;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
@@ -53,74 +55,54 @@ public class DecisionTreeRegressionTrainerExample {
System.out.println(">>> Ignite grid started.");
// Create cache with training data.
- CacheConfiguration<Integer, Point> trainingSetCfg = new CacheConfiguration<>();
+ CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
trainingSetCfg.setName("TRAINING_SET");
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
- IgniteCache<Integer, Point> trainingSet = ignite.createCache(trainingSetCfg);
+ IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
+ try {
+ trainingSet = ignite.createCache(trainingSetCfg);
- // Fill training data.
- generatePoints(trainingSet);
+ // Fill training data.
+ generatePoints(trainingSet);
- // Create regression trainer.
- DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
+ // Create regression trainer.
+ DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
- // Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v.x),
- (k, v) -> v.y
- );
+ // Train decision tree model.
+ DecisionTreeNode mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
- System.out.println(">>> Decision tree regression model: " + mdl);
+ System.out.println(">>> Decision tree regression model: " + mdl);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Ground Truth\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
- // Calculate score.
- for (int x = 0; x < 10; x++) {
- double predicted = mdl.predict(VectorUtils.of(x));
+ // Calculate score.
+ for (int x = 0; x < 10; x++) {
+ double predicted = mdl.predict(VectorUtils.of(x));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x));
+ }
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> Decision tree regression trainer example completed.");
+ System.out.println(">>> Decision tree regression trainer example completed.");
+ } finally {
+ trainingSet.destroy();
+ }
}
}
/**
* Generates {@code sin(x)} on interval {@code [0, 10)} and loads into the specified cache.
*/
- private static void generatePoints(IgniteCache<Integer, Point> trainingSet) {
+ private static void generatePoints(IgniteCache<Integer, LabeledVector<Double>> trainingSet) {
for (int i = 0; i < 1000; i++) {
double x = i / 100.0;
double y = Math.sin(x);
- trainingSet.put(i, new Point(x, y));
- }
- }
-
- /** Point data class. */
- private static class Point {
- /** X coordinate. */
- final double x;
-
- /** Y coordinate. */
- final double y;
-
- /**
- * Constructs a new instance of point.
- *
- * @param x X coordinate.
- * @param y Y coordinate.
- */
- Point(double x, double y) {
- this.x = x;
- this.y = y;
+ trainingSet.put(i, new LabeledVector<>(VectorUtils.of(x), y));
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
index fd46556..1d58daa 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
@@ -24,6 +24,8 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.ArraysVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
@@ -51,36 +53,40 @@ public class GDBOnTreesClassificationTrainerExample {
// Create cache with training data.
CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
- IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
+ IgniteCache<Integer, double[]> trainingSet = null;
+ try {
+ trainingSet = fillTrainingData(ignite, trainingSetCfg);
- // Create regression trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
- .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
+ // Create regression trainer.
+ DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
+ .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
- // Train decision tree model.
- ModelsComposition mdl = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v[0]),
- (k, v) -> v[1]
- );
+ // Train decision tree model.
+ ModelsComposition mdl = trainer.fit(
+ ignite,
+ trainingSet,
+ new ArraysVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
+ );
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Valid answer\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Valid answer\t|");
+ System.out.println(">>> ---------------------------------");
- // Calculate score.
- for (int x = -5; x < 5; x++) {
- double predicted = mdl.predict(VectorUtils.of(x));
+ // Calculate score.
+ for (int x = -5; x < 5; x++) {
+ double predicted = mdl.predict(VectorUtils.of(x));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x) < 0 ? 0.0 : 1.0);
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x) < 0 ? 0.0 : 1.0);
+ }
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> Count of trees = " + mdl.getModels().size());
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Count of trees = " + mdl.getModels().size());
+ System.out.println(">>> ---------------------------------");
- System.out.println(">>> GDB classification trainer example completed.");
+ System.out.println(">>> GDB classification trainer example completed.");
+ } finally {
+ trainingSet.destroy();
+ }
}
}
@@ -102,7 +108,7 @@ public class GDBOnTreesClassificationTrainerExample {
*/
@NotNull private static IgniteCache<Integer, double[]> fillTrainingData(Ignite ignite,
CacheConfiguration<Integer, double[]> trainingSetCfg) {
- IgniteCache<Integer, double[]> trainingSet = ignite.createCache(trainingSetCfg);
+ IgniteCache<Integer, double[]> trainingSet = ignite.getOrCreateCache(trainingSetCfg);
for(int i = -50; i <= 50; i++) {
double x = ((double)i) / 10.0;
double y = Math.sin(x) < 0 ? 0.0 : 1.0;
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
index d04415a..9ba4b83 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
@@ -24,6 +24,8 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.ArraysVectorizer;
import org.apache.ignite.ml.inference.Model;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -53,33 +55,37 @@ public class GDBOnTreesRegressionTrainerExample {
// Create cache with training data.
CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
- IgniteCache<Integer, double[]> trainingSet = fillTrainingData(ignite, trainingSetCfg);
+ IgniteCache<Integer, double[]> trainingSet = null;
+ try {
+ trainingSet = fillTrainingData(ignite, trainingSetCfg);
- // Create regression trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
- .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
+ // Create regression trainer.
+ DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
+ .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
- // Train decision tree model.
- Model<Vector, Double> mdl = trainer.fit(
- ignite,
- trainingSet,
- (k, v) -> VectorUtils.of(v[0]),
- (k, v) -> v[1]
- );
+ // Train decision tree model.
+ Model<Vector, Double> mdl = trainer.fit(
+ ignite,
+ trainingSet,
+ new ArraysVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
+ );
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Valid answer \t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Valid answer \t|");
+ System.out.println(">>> ---------------------------------");
- // Calculate score.
- for (int x = -5; x < 5; x++) {
- double predicted = mdl.predict(VectorUtils.of(x));
+ // Calculate score.
+ for (int x = -5; x < 5; x++) {
+ double predicted = mdl.predict(VectorUtils.of(x));
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2));
- }
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2));
+ }
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> GDB regression trainer example completed.");
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> GDB regression trainer example completed.");
+ } finally {
+ trainingSet.destroy();
+ }
}
}
@@ -101,7 +107,7 @@ public class GDBOnTreesRegressionTrainerExample {
*/
@NotNull private static IgniteCache<Integer, double[]> fillTrainingData(Ignite ignite,
CacheConfiguration<Integer, double[]> trainingSetCfg) {
- IgniteCache<Integer, double[]> trainingSet = ignite.createCache(trainingSetCfg);
+ IgniteCache<Integer, double[]> trainingSet = ignite.getOrCreateCache(trainingSetCfg);
for(int i = -50; i <= 50; i++) {
double x = ((double)i) / 10.0;
double y = Math.pow(x, 2);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java
index fd95033..13e6792 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestClassificationExample.java
@@ -17,11 +17,6 @@
package org.apache.ignite.examples.ml.tree.randomforest;
-import java.io.FileNotFoundException;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import javax.cache.Cache;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
@@ -30,12 +25,20 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainer;
import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
/**
* Example represents a solution for the task of wine classification based on a
* <a href ="https://en.wikipedia.org/wiki/Random_forest">Random Forest</a> implementation for
@@ -63,50 +66,54 @@ public class RandomForestClassificationExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.WINE_RECOGNITION);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.WINE_RECOGNITION);
- AtomicInteger idx = new AtomicInteger(0);
- RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer(
- IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
- x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
- ).withAmountOfTrees(101)
- .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
- .withMaxDepth(4)
- .withMinImpurityDelta(0.)
- .withSubSampleSize(0.3)
- .withSeed(0);
+ AtomicInteger idx = new AtomicInteger(0);
+ RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer(
+ IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
+ x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
+ ).withAmountOfTrees(101)
+ .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+ .withMaxDepth(4)
+ .withMinImpurityDelta(0.)
+ .withSubSampleSize(0.3)
+ .withSeed(0);
- System.out.println(">>> Configured trainer: " + classifier.getClass().getSimpleName());
+ System.out.println(">>> Configured trainer: " + classifier.getClass().getSimpleName());
- ModelsComposition randomForestMdl = classifier.fit(ignite, dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+ ModelsComposition randomForestMdl = classifier.fit(ignite, dataCache, vectorizer);
- System.out.println(">>> Trained model: " + randomForestMdl.toString(true));
+ System.out.println(">>> Trained model: " + randomForestMdl.toString(true));
- int amountOfErrors = 0;
- int totalAmount = 0;
+ int amountOfErrors = 0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = randomForestMdl.predict(inputs);
+ double prediction = randomForestMdl.predict(inputs);
- totalAmount++;
- if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
- amountOfErrors++;
- }
+ totalAmount++;
+ if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
+ amountOfErrors++;
+ }
+
+ System.out.println("\n>>> Evaluated model on " + totalAmount + " data points.");
- System.out.println("\n>>> Evaluated model on " + totalAmount + " data points.");
+ System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
+ System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double)totalAmount));
+ System.out.println(">>> Random Forest multi-class classification algorithm over cached dataset usage example completed.");
+ }
- System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
- System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
- System.out.println(">>> Random Forest multi-class classification algorithm over cached dataset usage example completed.");
+ } finally {
+ dataCache.destroy();
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java
index e1bbc8b..42f818e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/randomforest/RandomForestRegressionExample.java
@@ -17,11 +17,6 @@
package org.apache.ignite.examples.ml.tree.randomforest;
-import java.io.FileNotFoundException;
-import java.util.concurrent.atomic.AtomicInteger;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
-import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -29,9 +24,10 @@ import org.apache.ignite.cache.query.QueryCursor;
import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.ConsoleLogger;
-import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
@@ -39,6 +35,12 @@ import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrateg
import org.apache.ignite.ml.util.MLSandboxDatasets;
import org.apache.ignite.ml.util.SandboxMLCache;
+import javax.cache.Cache;
+import java.io.FileNotFoundException;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
/**
* Example represents a solution for the task of price predictions for houses in Boston based on a
* <a href ="https://en.wikipedia.org/wiki/Random_forest">Random Forest</a> implementation for regression.
@@ -66,61 +68,64 @@ public class RandomForestRegressionExample {
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
System.out.println(">>> Ignite grid started.");
- IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
- .fillCacheWith(MLSandboxDatasets.BOSTON_HOUSE_PRICES);
+ IgniteCache<Integer, Vector> dataCache = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.BOSTON_HOUSE_PRICES);
- AtomicInteger idx = new AtomicInteger(0);
- RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
- IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
- x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
- ).withAmountOfTrees(101)
- .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
- .withMaxDepth(4)
- .withMinImpurityDelta(0.)
- .withSubSampleSize(0.3)
- .withSeed(0);
+ AtomicInteger idx = new AtomicInteger(0);
+ RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
+ IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
+ x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
+ ).withAmountOfTrees(101)
+ .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+ .withMaxDepth(4)
+ .withMinImpurityDelta(0.)
+ .withSubSampleSize(0.3)
+ .withSeed(0);
- trainer.withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder()
- .withParallelismStrategyTypeDependency(part -> ParallelismStrategy.Type.ON_DEFAULT_POOL)
- .withLoggingFactoryDependency(part -> ConsoleLogger.factory(MLLogger.VerboseLevel.LOW))
- );
+ trainer.withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder()
+ .withParallelismStrategyTypeDependency(ParallelismStrategy.ON_DEFAULT_POOL)
+ .withLoggingFactoryDependency(ConsoleLogger.Factory.LOW)
+ );
- System.out.println(">>> Configured trainer: " + trainer.getClass().getSimpleName());
+ System.out.println(">>> Configured trainer: " + trainer.getClass().getSimpleName());
- ModelsComposition randomForestMdl = trainer.fit(ignite, dataCache,
- (k, v) -> v.copyOfRange(1, v.size()),
- (k, v) -> v.get(0)
- );
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+ ModelsComposition randomForestMdl = trainer.fit(ignite, dataCache, vectorizer);
- System.out.println(">>> Trained model: " + randomForestMdl.toString(true));
+ System.out.println(">>> Trained model: " + randomForestMdl.toString(true));
- double mse = 0.0;
- double mae = 0.0;
- int totalAmount = 0;
+ double mse = 0.0;
+ double mae = 0.0;
+ int totalAmount = 0;
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
- double prediction = randomForestMdl.predict(inputs);
+ double prediction = randomForestMdl.predict(inputs);
- mse += Math.pow(prediction - groundTruth, 2.0);
- mae += Math.abs(prediction - groundTruth);
+ mse += Math.pow(prediction - groundTruth, 2.0);
+ mae += Math.abs(prediction - groundTruth);
- totalAmount++;
- }
+ totalAmount++;
+ }
- System.out.println("\n>>> Evaluated model on " + totalAmount + " data points.");
+ System.out.println("\n>>> Evaluated model on " + totalAmount + " data points.");
- mse = mse / totalAmount;
- System.out.println("\n>>> Mean squared error (MSE) " + mse);
+ mse = mse / totalAmount;
+ System.out.println("\n>>> Mean squared error (MSE) " + mse);
- mae = mae / totalAmount;
- System.out.println("\n>>> Mean absolute error (MAE) " + mae);
+ mae = mae / totalAmount;
+ System.out.println("\n>>> Mean absolute error (MAE) " + mae);
- System.out.println(">>> Random Forest regression algorithm over cached dataset usage example completed.");
+ System.out.println(">>> Random Forest regression algorithm over cached dataset usage example completed.");
+ }
+ } finally {
+ dataCache.destroy();
}
}
}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
index cb88ebf..6c9f3e5 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_10_Scaling_With_Stacking.java
@@ -22,6 +22,7 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.ml.composition.stacking.StackedModel;
import org.apache.ignite.ml.composition.stacking.StackedVectorDatasetTrainer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -106,7 +107,7 @@ public class Step_10_Scaling_With_Stacking {
LogisticRegressionSGDTrainer aggregator = new LogisticRegressionSGDTrainer()
.withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2),
- SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg));
+ SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG));
StackedModel<Vector, Vector, Double, LogisticRegressionModel> mdl =
new StackedVectorDatasetTrainer<>(aggregator)
@@ -116,8 +117,7 @@ public class Step_10_Scaling_With_Stacking {
.fit(
ignite,
dataCache,
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
index 34d6fe8..2bc6f82 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -52,7 +53,7 @@ public class Step_1_Read_and_Learn {
IgniteCache<Integer, Object[]> dataCache = TitanicUtils.readPassengers(ignite);
IgniteBiFunction<Integer, Object[], Vector> featureExtractor = (k, v) -> {
- double[] data = new double[]{(double) v[0], (double) v[5], (double) v[6]};
+ double[] data = new double[] {(double)v[0], (double)v[5], (double)v[6]};
data[0] = Double.isNaN(data[0]) ? 0 : data[0];
data[1] = Double.isNaN(data[1]) ? 0 : data[1];
data[2] = Double.isNaN(data[2]) ? 0 : data[2];
@@ -60,15 +61,17 @@ public class Step_1_Read_and_Learn {
return VectorUtils.of(data);
};
- IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double) v[1];
+ IgniteBiFunction<Integer, Object[], Double> lbExtractor = (k, v) -> (double)v[1];
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- featureExtractor, // "pclass", "sibsp", "parch"
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap( //TODO: IGNITE-11581
+ featureExtractor, // "pclass", "sibsp", "parch"
+ lbExtractor
+ )
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
index 72ae0cb..dfdd327 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
@@ -70,8 +71,7 @@ public class Step_2_Imputing {
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- imputingPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(imputingPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
index 337421e..54b6489 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -83,8 +84,7 @@ public class Step_3_Categorial {
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- imputingPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(imputingPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
index d390fec..c12edf8 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -86,8 +87,7 @@ public class Step_3_Categorial_with_One_Hot_Encoder {
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- imputingPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(imputingPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
index 6b7f6be..b541760 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -82,8 +83,7 @@ public class Step_4_Add_age_fare {
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- imputingPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(imputingPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
index ca595ef..67d5b43 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -100,8 +101,7 @@ public class Step_5_Scaling {
DecisionTreeNode mdl = trainer.fit(
ignite,
dataCache,
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
index a4ba699..1cc375f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_6_KNN.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.knn.NNClassificationModel;
import org.apache.ignite.ml.knn.classification.KNNClassificationTrainer;
import org.apache.ignite.ml.knn.classification.NNStrategy;
@@ -101,8 +102,7 @@ public class Step_6_KNN {
NNClassificationModel mdl = trainer.fit(
ignite,
dataCache,
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
).withK(1).withStrategy(NNStrategy.WEIGHTED);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
index 350145f..d728df1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -107,8 +108,7 @@ public class Step_7_Split_train_test {
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
index 175133fc..567f1b0 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -170,8 +171,7 @@ public class Step_8_CV {
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + bestMdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
index 325b656..1d1476f 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -158,8 +159,7 @@ public class Step_8_CV_with_Param_Grid {
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + bestMdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
index a12dcc2..d8bb5ef 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_metrics.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.encoding.EncoderTrainer;
@@ -165,8 +166,7 @@ public class Step_8_CV_with_Param_Grid_and_metrics {
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + bestMdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
index 5c0ad57..f479264 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_9_Go_to_LogReg.java
@@ -20,6 +20,7 @@ package org.apache.ignite.examples.ml.tutorial;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.nn.UpdatesStrategy;
@@ -127,7 +128,7 @@ public class Step_9_Go_to_LogReg {
LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
.withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(learningRate),
- SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
.withMaxIterations(maxIterations)
.withLocIterations(locIterations)
.withBatchSize(batchSize)
@@ -191,7 +192,7 @@ public class Step_9_Go_to_LogReg {
LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
.withUpdatesStgy(new UpdatesStrategy<>(new SimpleGDUpdateCalculator(bestLearningRate),
- SimpleGDParameterUpdate::sumLocal, SimpleGDParameterUpdate::avg))
+ SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG))
.withMaxIterations(bestMaxIterations)
.withLocIterations(bestLocIterations)
.withBatchSize(bestBatchSize)
@@ -202,8 +203,7 @@ public class Step_9_Go_to_LogReg {
ignite,
dataCache,
split.getTrainFilter(),
- normalizationPreprocessor,
- lbExtractor
+ FeatureLabelExtractorWrapper.wrap(normalizationPreprocessor, lbExtractor) //TODO: IGNITE-11581
);
System.out.println("\n>>> Trained model: " + bestMdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TitanicUtils.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TitanicUtils.java
index 1927a8c..8c5925e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TitanicUtils.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/TitanicUtils.java
@@ -17,17 +17,18 @@
package org.apache.ignite.examples.ml.tutorial;
-import java.io.File;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.internal.util.IgniteUtils;
+
import java.io.FileNotFoundException;
import java.text.NumberFormat;
import java.text.ParseException;
import java.util.Locale;
import java.util.Scanner;
import java.util.UUID;
-import org.apache.ignite.Ignite;
-import org.apache.ignite.IgniteCache;
-import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
-import org.apache.ignite.configuration.CacheConfiguration;
/**
* The utility class.
@@ -43,7 +44,7 @@ public class TitanicUtils {
public static IgniteCache<Integer, Object[]> readPassengers(Ignite ignite)
throws FileNotFoundException {
IgniteCache<Integer, Object[]> cache = getCache(ignite);
- Scanner scanner = new Scanner(new File("examples/src/main/resources/datasets/titanic.csv"));
+ Scanner scanner = new Scanner(IgniteUtils.resolveIgnitePath("examples/src/main/resources/datasets/titanic.csv"));
int cnt = 0;
while (scanner.hasNextLine()) {
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/util/generators/DatasetCreationExamples.java b/examples/src/main/java/org/apache/ignite/examples/ml/util/generators/DatasetCreationExamples.java
index 42f0500..f5ea6c5 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/util/generators/DatasetCreationExamples.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/util/generators/DatasetCreationExamples.java
@@ -17,9 +17,6 @@
package org.apache.ignite.examples.ml.util.generators;
-import java.util.UUID;
-import java.util.stream.DoubleStream;
-import java.util.stream.Stream;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -27,6 +24,7 @@ import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.configuration.IgniteConfiguration;
import org.apache.ignite.ml.dataset.Dataset;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.FeatureLabelExtractorWrapper;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
@@ -39,6 +37,10 @@ import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.util.generators.DataStreamGenerator;
import org.apache.ignite.ml.util.generators.primitives.scalar.UniformRandomProducer;
+import java.util.UUID;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+
/**
* Examples of using {@link DataStreamGenerator} methods for filling cache or creating local datasets.
*/
@@ -68,7 +70,7 @@ public class DatasetCreationExamples {
double meanFromLocDataset;
try (Dataset<EmptyContext, SimpleDatasetData> dataset = generator.asDatasetBuilder(DATASET_SIZE, 10)
.build(LearningEnvironmentBuilder.defaultBuilder(), new EmptyContextBuilder<>(),
- new SimpleDatasetDataBuilder<>((k, v) -> k))) {
+ new SimpleDatasetDataBuilder<>(FeatureLabelExtractorWrapper.wrap((k, v) -> k)))) {
meanFromLocDataset = dataset.compute(
data -> DoubleStream.of(data.getFeatures()).sum(),
@@ -82,15 +84,19 @@ public class DatasetCreationExamples {
IgniteConfiguration configuration = new IgniteConfiguration().setPeerClassLoadingEnabled(true);
try (Ignite ignite = Ignition.start(configuration)) {
String cacheName = "TEST_CACHE";
- IgniteCache<UUID, LabeledVector<Double>> withCustomKeyCache = ignite.getOrCreateCache(
- new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
- .setAffinity(new RendezvousAffinityFunction(false, 10))
- );
-
- // DataStreamGenerator can fill cache with vectors as values and HashCodes/random UUID/custom keys.
- generator.fillCacheWithVecUUIDAsKey(DATASET_SIZE, withCustomKeyCache);
- meanFromCache = computeMean(ignite, withCustomKeyCache);
- ignite.destroyCache(cacheName);
+ IgniteCache<UUID, LabeledVector<Double>> withCustomKeyCache = null;
+ try {
+ withCustomKeyCache = ignite.getOrCreateCache(
+ new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
+ .setAffinity(new RendezvousAffinityFunction(false, 10))
+ );
+
+ // DataStreamGenerator can fill cache with vectors as values and HashCodes/random UUID/custom keys.
+ generator.fillCacheWithVecUUIDAsKey(DATASET_SIZE, withCustomKeyCache);
+ meanFromCache = computeMean(ignite, withCustomKeyCache);
+ } finally {
+ ignite.destroyCache(cacheName);
+ }
}
// Results should be near to expected value.
@@ -135,7 +141,7 @@ public class DatasetCreationExamples {
try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
builder.build(LearningEnvironmentBuilder.defaultBuilder(),
new EmptyContextBuilder<>(),
- new SimpleDatasetDataBuilder<>((k, v) -> v.features()))) {
+ new SimpleDatasetDataBuilder<>(FeatureLabelExtractorWrapper.wrap((k, v) -> v.features())))) {
result = dataset.compute(
data -> DoubleStream.of(data.getFeatures()).sum(),
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java
index c36d030..9f7d6e4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/CovarianceMatricesAggregator.java
@@ -17,16 +17,17 @@
package org.apache.ignite.ml.clustering.gmm;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.matrix.Matrix;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+
/**
* This class encapsulates statistics aggregation logic for feature vector covariance matrix computation of one GMM
* component (cluster).
@@ -107,7 +108,7 @@ public class CovarianceMatricesAggregator implements Serializable {
/**
* @param other Other.
- * @return sum of aggregators.
+ * @return Sum of aggregators.
*/
CovarianceMatricesAggregator plus(CovarianceMatricesAggregator other) {
A.ensure(this.mean.equals(other.mean), "this.mean == other.mean");
@@ -143,7 +144,7 @@ public class CovarianceMatricesAggregator implements Serializable {
/**
* @param clusterProb GMM component probability.
- * @return computed covariance matrix.
+ * @return Computed covariance matrix.
*/
private Matrix covariance(double clusterProb) {
return weightedSum.divide(rowCount * clusterProb);
@@ -174,21 +175,21 @@ public class CovarianceMatricesAggregator implements Serializable {
}
/**
- * @return mean vector.
+ * @return Mean vector.
*/
Vector mean() {
return mean.copy();
}
/**
- * @return weighted sum.
+ * @return Weighted sum.
*/
Matrix weightedSum() {
return weightedSum.copy();
}
/**
- * @return rows count.
+ * @return Rows count.
*/
public int rowCount() {
return rowCount;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
index 1b8e50c..c83eb88 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmPartitionData.java
@@ -17,20 +17,22 @@
package org.apache.ignite.ml.clustering.gmm;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
/**
* Partition data for GMM algorithm. Unlike partition data for other algorithms this class aggregate probabilities of
@@ -65,10 +67,20 @@ class GmmPartitionData implements AutoCloseable {
}
/**
- * @return all vectors from partition.
+ * Updates P(c|xi) values in partitions and compute dataset likelihood.
+ *
+ * @param dataset Dataset.
+ * @param clusterProbs Component probabilities.
+ * @param components Components.
+ * @return Dataset likelihood.
*/
- public List<LabeledVector<Double>> getAllXs() {
- return Collections.unmodifiableList(xs);
+ static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset, Vector clusterProbs,
+ List<MultivariateGaussianDistribution> components) {
+
+ return dataset.compute(
+ data -> updatePcxi(data, clusterProbs, components),
+ (left, right) -> asPrimitive(left) + asPrimitive(right)
+ );
}
/**
@@ -90,10 +102,10 @@ class GmmPartitionData implements AutoCloseable {
}
/**
- * @return size of dataset partition.
+ * @return All vectors from partition.
*/
- public int size() {
- return pcxi.length;
+ public List<LabeledVector<Double>> getAllXs() {
+ return Collections.unmodifiableList(xs);
}
/** {@inheritDoc} */
@@ -104,12 +116,12 @@ class GmmPartitionData implements AutoCloseable {
/**
* Builder for GMM partition data.
*/
- public static class Builder<K, V> implements PartitionDataBuilder<K, V, EmptyContext, GmmPartitionData> {
+ public static class Builder<K, V, C extends Serializable> implements PartitionDataBuilder<K, V, EmptyContext, GmmPartitionData> {
/** Serial version uid. */
private static final long serialVersionUID = 1847063348042022561L;
- /** Extractor. */
- private final FeatureLabelExtractor<K, V, Double> extractor;
+ /** Upstream vectorizer. */
+ private final Vectorizer<K, V, C, Double> extractor;
/** Count of components of mixture. */
private final int countOfComponents;
@@ -120,7 +132,7 @@ class GmmPartitionData implements AutoCloseable {
* @param extractor Extractor.
* @param countOfComponents Count of components.
*/
- public Builder(FeatureLabelExtractor<K, V, Double> extractor, int countOfComponents) {
+ public Builder(Vectorizer<K, V, C, Double> extractor, int countOfComponents) {
this.extractor = extractor;
this.countOfComponents = countOfComponents;
}
@@ -169,20 +181,10 @@ class GmmPartitionData implements AutoCloseable {
}
/**
- * Updates P(c|xi) values in partitions and compute dataset likelihood.
- *
- * @param dataset Dataset.
- * @param clusterProbs Component probabilities.
- * @param components Components.
- * @return dataset likelihood.
+ * @return Size of dataset partition.
*/
- static double updatePcxiAndComputeLikelihood(Dataset<EmptyContext, GmmPartitionData> dataset, Vector clusterProbs,
- List<MultivariateGaussianDistribution> components) {
-
- return dataset.compute(
- data -> updatePcxi(data, clusterProbs, components),
- (left, right) -> asPrimitive(left) + asPrimitive(right)
- );
+ public int size() {
+ return pcxi.length;
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
index 09e93f6..c04ae3d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmTrainer.java
@@ -17,17 +17,10 @@
package org.apache.ignite.ml.clustering.gmm;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Comparator;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-import java.util.stream.DoubleStream;
-import java.util.stream.Stream;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
@@ -40,9 +33,14 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution;
import org.apache.ignite.ml.structures.DatasetRow;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.jetbrains.annotations.NotNull;
+import java.io.Serializable;
+import java.util.*;
+import java.util.stream.Collectors;
+import java.util.stream.DoubleStream;
+import java.util.stream.Stream;
+
/**
* Trainer for GMM model.
*/
@@ -103,16 +101,61 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/** {@inheritDoc} */
- @Override public <K, V> GmmModel fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override public <K, V, C extends Serializable> GmmModel fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
return updateModel(null, datasetBuilder, extractor);
}
/**
+ * Returns mapper for initial means selection.
+ *
+ * @param n Number of components.
+ * @return Mapper.
+ */
+ private static IgniteBiFunction<GmmPartitionData, LearningEnvironment, Vector[][]> selectNRandomXsMapper(int n) {
+ return (data, env) -> {
+ Vector[] result;
+
+ if (data.size() <= n) {
+ result = data.getAllXs().stream()
+ .map(DatasetRow::features)
+ .toArray(Vector[]::new);
+ }
+ else {
+ result = env.randomNumbersGenerator().ints(0, data.size())
+ .distinct().mapToObj(data::getX).limit(n)
+ .toArray(Vector[]::new);
+ }
+
+ return new Vector[][] {result};
+ };
+ }
+
+ /**
+ * Reducer for means selection.
+ *
+ * @return Reducer.
+ */
+ private static Vector[][] selectNRandomXsReducer(Vector[][] l, Vector[][] r) {
+ A.ensure(l != null || r != null, "l != null || r != null");
+
+ if (l == null)
+ return r;
+ if (r == null)
+ return l;
+
+ Vector[][] res = new Vector[l.length + r.length][];
+ System.arraycopy(l, 0, res, 0, l.length);
+ System.arraycopy(r, 0, res, l.length, r.length);
+
+ return res;
+ }
+
+ /**
* Sets numberOfComponents.
*
* @param numberOfComponents Number of components.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withInitialCountOfComponents(int numberOfComponents) {
A.ensure(numberOfComponents > 0, "Number of components in GMM cannot equal 0");
@@ -128,7 +171,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* Sets initial means.
*
* @param means Initial means for clusters.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withInitialMeans(List<Vector> means) {
A.notEmpty(means, "GMM should start with non empty initial components list");
@@ -144,7 +187,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* Sets max count of iterations
*
* @param maxCountOfIterations Max count of iterations.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withMaxCountIterations(int maxCountOfIterations) {
A.ensure(maxCountOfIterations > 0, "Max count iterations cannot be less or equal zero or negative");
@@ -157,7 +200,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* Sets min divergence between iterations.
*
* @param eps Eps.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withEps(double eps) {
A.ensure(eps > 0 && eps < 1.0, "Min divergence beween iterations should be between 0.0 and 1.0");
@@ -171,7 +214,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* means randomly MaxCountOfInitTries times.
*
* @param maxCountOfInitTries Max count of init tries.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withMaxCountOfInitTries(int maxCountOfInitTries) {
A.ensure(maxCountOfInitTries > 0, "Max initialization count should be great than zero.");
@@ -184,7 +227,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* Sets maximum number of clusters in GMM.
*
* @param maxCountOfClusters Max count of clusters.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withMaxCountOfClusters(int maxCountOfClusters) {
A.ensure(maxCountOfClusters >= countOfComponents, "Max count of components should be greater than " +
@@ -199,7 +242,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
* identification.
*
* @param maxLikelihoodDivergence Max likelihood divergence.
- * @return trainer.
+ * @return Trainer.
*/
public GmmTrainer withMaxLikelihoodDivergence(double maxLikelihoodDivergence) {
A.ensure(maxLikelihoodDivergence > 0, "Max likelihood divergence should be > 0");
@@ -209,31 +252,6 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
- * Sets minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster.
- *
- * @param minElementsForNewCluster Min elements for new cluster.
- * @return trainer.
- */
- public GmmTrainer withMinElementsForNewCluster(int minElementsForNewCluster) {
- A.ensure(minElementsForNewCluster > 0, "Min elements for new cluster should be > 0");
-
- this.minElementsForNewCluster = minElementsForNewCluster;
- return this;
- }
-
- /**
- * Sets minimum requred probability for cluster. If cluster has probability value less than this value then this
- * cluster will be eliminated.
- *
- * @param minClusterProbability Min cluster probability.
- * @return trainer.
- */
- public GmmTrainer withMinClusterProbability(double minClusterProbability) {
- this.minClusterProbability = minClusterProbability;
- return this;
- }
-
- /**
* Trains model based on the specified data.
*
* @param dataset Dataset.
@@ -280,10 +298,55 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
+ * Sets minimum required anomalies in terms of maxLikelihoodDivergence for creating new cluster.
+ *
+ * @param minElementsForNewCluster Min elements for new cluster.
+ * @return Trainer.
+ */
+ public GmmTrainer withMinElementsForNewCluster(int minElementsForNewCluster) {
+ A.ensure(minElementsForNewCluster > 0, "Min elements for new cluster should be > 0");
+
+ this.minElementsForNewCluster = minElementsForNewCluster;
+ return this;
+ }
+
+ /**
+ * Sets minimum required probability for cluster. If cluster has probability value less than this value then this
+ * cluster will be eliminated.
+ *
+ * @param minClusterProbability Min cluster probability.
+ * @return Trainer.
+ */
+ public GmmTrainer withMinClusterProbability(double minClusterProbability) {
+ this.minClusterProbability = minClusterProbability;
+ return this;
+ }
+
+ /**
+ * Result of current model update by EM-algorithm.
+ */
+ private static class UpdateResult {
+ /** Model. */
+ private final GmmModel model;
+
+ /** Max likelihood in dataset. */
+ private final double maxProbInDataset;
+
+ /**
+ * @param model Model.
+ * @param maxProbInDataset Max likelihood in dataset.
+ */
+ public UpdateResult(GmmModel model, double maxProbInDataset) {
+ this.model = model;
+ this.maxProbInDataset = maxProbInDataset;
+ }
+ }
+
+ /**
* Remove clusters with probability value < minClusterProbability
*
* @param model Model.
- * @return filtered model.
+ * @return Filtered model.
*/
private GmmModel filterModel(GmmModel model) {
List<Double> componentProbs = new ArrayList<>();
@@ -310,7 +373,7 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
*
* @param dataset Dataset.
* @param model Model.
- * @return updated model.
+ * @return Updated model.
*/
@NotNull private UpdateResult updateModel(Dataset<EmptyContext, GmmPartitionData> dataset, GmmModel model) {
boolean isConverged = false;
@@ -346,30 +409,10 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
- * Result of current model update by EM-algorithm.
- */
- private static class UpdateResult {
- /** Model. */
- private final GmmModel model;
-
- /** Max likelihood in dataset. */
- private final double maxProbInDataset;
-
- /**
- * @param model Model.
- * @param maxProbInDataset Max likelihood in dataset.
- */
- public UpdateResult(GmmModel model, double maxProbInDataset) {
- this.model = model;
- this.maxProbInDataset = maxProbInDataset;
- }
- }
-
- /**
* Init means and covariances.
*
* @param dataset Dataset.
- * @return initial model.
+ * @return Initial model.
*/
private Optional<GmmModel> init(Dataset<EmptyContext, GmmPartitionData> dataset) {
int countOfTries = 0;
@@ -426,53 +469,14 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
}
- /**
- * Create new model components with provided means and covariances.
- *
- * @param means Means.
- * @param covs Covariances.
- * @return gmm components.
- */
- private List<MultivariateGaussianDistribution> buildComponents(Vector[] means, List<Matrix> covs) {
- A.ensure(means.length == covs.size(), "means.size() == covs.size()");
-
- List<MultivariateGaussianDistribution> res = new ArrayList<>();
- for (int i = 0; i < means.length; i++)
- res.add(new MultivariateGaussianDistribution(means[i], covs.get(i)));
-
- return res;
- }
-
- /**
- * Check algorithm covergency. If it's true then algorithm stops.
- *
- * @param oldModel Old model.
- * @param newModel New model.
- * @return true if algorithm gonverged.
- */
- private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
- A.ensure(oldModel.countOfComponents() == newModel.countOfComponents(),
- "oldModel.countOfComponents() == newModel.countOfComponents()");
-
- for (int i = 0; i < oldModel.countOfComponents(); i++) {
- MultivariateGaussianDistribution d1 = oldModel.distributions().get(i);
- MultivariateGaussianDistribution d2 = newModel.distributions().get(i);
-
- if (Math.sqrt(d1.mean().getDistanceSquared(d2.mean())) >= eps)
- return false;
- }
-
- return true;
- }
-
/** {@inheritDoc} */
@Override public boolean isUpdateable(GmmModel mdl) {
return mdl.countOfComponents() == countOfComponents;
}
/** {@inheritDoc} */
- @Override protected <K, V> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override protected <K, V, C extends Serializable> GmmModel updateModel(GmmModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
try (Dataset<EmptyContext, GmmPartitionData> dataset = datasetBuilder.build(envBuilder,
new EmptyContextBuilder<>(),
@@ -500,47 +504,41 @@ public class GmmTrainer extends DatasetTrainer<GmmModel, Double> {
}
/**
- * Returns mapper for initial means selection.
+ * Create new model components with provided means and covariances.
*
- * @param n Number of components.
- * @return mapper.
+ * @param means Means.
+ * @param covs Covariances.
+ * @return Gmm components.
*/
- private static IgniteBiFunction<GmmPartitionData, LearningEnvironment, Vector[][]> selectNRandomXsMapper(int n) {
- return (data, env) -> {
- Vector[] result;
+ private List<MultivariateGaussianDistribution> buildComponents(Vector[] means, List<Matrix> covs) {
+ A.ensure(means.length == covs.size(), "means.size() == covs.size()");
- if (data.size() <= n) {
- result = data.getAllXs().stream()
- .map(DatasetRow::features)
- .toArray(Vector[]::new);
- }
- else {
- result = env.randomNumbersGenerator().ints(0, data.size())
- .distinct().mapToObj(data::getX).limit(n)
- .toArray(Vector[]::new);
- }
+ List<MultivariateGaussianDistribution> res = new ArrayList<>();
+ for (int i = 0; i < means.length; i++)
+ res.add(new MultivariateGaussianDistribution(means[i], covs.get(i)));
- return new Vector[][] {result};
- };
+ return res;
}
/**
- * Reducer for means selection.
+ * Check algorithm convergence. If it's true then algorithm stops.
*
- * @return reducer.
+ * @param oldModel Old model.
+ * @param newModel New model.
+ * @return True if algorithm converged.
*/
- private static Vector[][] selectNRandomXsReducer(Vector[][] l, Vector[][] r) {
- A.ensure(l != null || r != null, "l != null || r != null");
+ private boolean isConverged(GmmModel oldModel, GmmModel newModel) {
+ A.ensure(oldModel.countOfComponents() == newModel.countOfComponents(),
+ "oldModel.countOfComponents() == newModel.countOfComponents()");
- if (l == null)
- return r;
- if (r == null)
- return l;
+ for (int i = 0; i < oldModel.countOfComponents(); i++) {
+ MultivariateGaussianDistribution d1 = oldModel.distributions().get(i);
+ MultivariateGaussianDistribution d2 = newModel.distributions().get(i);
- Vector[][] res = new Vector[l.length + r.length][];
- System.arraycopy(l, 0, res, 0, l.length);
- System.arraycopy(r, 0, res, l.length, r.length);
+ if (Math.sqrt(d1.mean().getDistanceSquared(d2.mean())) >= eps)
+ return false;
+ }
- return res;
+ return true;
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
index 99e60ba..59e8e2b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/MeanWithClusterProbAggregator.java
@@ -17,16 +17,17 @@
package org.apache.ignite.ml.clustering.gmm;
-import java.io.Serializable;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
+
/**
* Statistics aggregator for mean values and cluster probabilities computing.
*/
@@ -64,14 +65,14 @@ class MeanWithClusterProbAggregator implements Serializable {
}
/**
- * @return compute mean value by aggregated data.
+ * @return Compute mean value by aggregated data.
*/
public Vector mean() {
return weightedXsSum.divide(pcxiSum);
}
/**
- * @return compute cluster probability by aggreated data.
+ * @return Compute cluster probability by aggregated data.
*/
public double clusterProb() {
return pcxiSum / rowCount;
@@ -192,14 +193,14 @@ class MeanWithClusterProbAggregator implements Serializable {
}
/**
- * @return clusters probabilities.
+ * @return Clusters probabilities.
*/
public Vector clusterProbabilities() {
return clusterProbs;
}
/**
- * @return means.
+ * @return Means.
*/
public List<Vector> means() {
return means;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
index 4fa5406..d46fc7e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/NewComponentStatisticsAggregator.java
@@ -17,12 +17,13 @@
package org.apache.ignite.ml.clustering.gmm;
-import java.io.Serializable;
import org.apache.ignite.internal.util.typedef.internal.A;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import java.io.Serializable;
+
/**
* Class for aggregate statistics for finding new mean for GMM.
*/
@@ -86,7 +87,7 @@ public class NewComponentStatisticsAggregator implements Serializable {
* @param maxXsProb Max likelihood between all xs.
* @param maxProbDivergence Max probability divergence between maximum value and others.
* @param currentModel Current model.
- * @return aggregated statistics for new mean.
+ * @return Aggregated statistics for new mean.
*/
static NewComponentStatisticsAggregator computeNewMean(Dataset<EmptyContext, GmmPartitionData> dataset,
double maxXsProb, double maxProbDivergence, GmmModel currentModel) {
@@ -104,7 +105,7 @@ public class NewComponentStatisticsAggregator implements Serializable {
* @param maxXsProb Max xs prob.
* @param maxProbDivergence Max prob divergence.
* @param currentModel Current model.
- * @return aggregator for partition.
+ * @return Aggregator for partition.
*/
static NewComponentStatisticsAggregator computeNewMeanMap(GmmPartitionData data, double maxXsProb,
double maxProbDivergence, GmmModel currentModel) {
@@ -141,7 +142,7 @@ public class NewComponentStatisticsAggregator implements Serializable {
*
* @param left Left argument of reduce.
* @param right Right argument of reduce.
- * @return sum of aggregators.
+ * @return Sum of aggregators.
*/
static NewComponentStatisticsAggregator computeNewMeanReduce(NewComponentStatisticsAggregator left,
NewComponentStatisticsAggregator right) {
@@ -157,7 +158,7 @@ public class NewComponentStatisticsAggregator implements Serializable {
/**
* @param other Other aggregator.
- * @return sum of aggregators.
+ * @return Sum of aggregators.
*/
NewComponentStatisticsAggregator plus(NewComponentStatisticsAggregator other) {
return new NewComponentStatisticsAggregator(
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index e1aa16b..8ac4f06 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -17,21 +17,11 @@
package org.apache.ignite.ml.clustering.kmeans;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Optional;
-import java.util.Random;
-import java.util.Set;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
import org.apache.ignite.lang.IgniteBiTuple;
-import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
@@ -44,9 +34,14 @@ import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
+import java.io.Serializable;
+import java.util.*;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
+
/**
* The trainer for KMeans algorithm.
*/
@@ -64,8 +59,8 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
private DistanceMeasure distance = new EuclideanDistance();
/** {@inheritDoc} */
- @Override public <K, V> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override public <K, V, C extends Serializable> KMeansModel fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
return updateModel(null, datasetBuilder, extractor);
}
@@ -75,18 +70,12 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
}
/** {@inheritDoc} */
- @Override protected <K, V> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override protected <K, V, C extends Serializable> KMeansModel updateModel(KMeansModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
assert datasetBuilder != null;
- IgniteBiFunction<K, V, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(extractor);
- IgniteBiFunction<K, V, Double> lbExtractor = CompositionUtils.asLabelExtractor(extractor);
-
PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder =
- new LabeledDatasetPartitionDataBuilderOnHeap<>(
- featureExtractor,
- lbExtractor
- );
+ new LabeledDatasetPartitionDataBuilderOnHeap<>(extractor);
Vector[] centers;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
index 95e4ee8..9b3e56a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/CompositionUtils.java
@@ -19,12 +19,15 @@ package org.apache.ignite.ml.composition;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+import java.io.Serializable;
+
/**
* Various utility functions for trainers composition.
*/
@@ -44,14 +47,15 @@ public class CompositionUtils {
DatasetTrainer<? extends M, L> trainer) {
return new DatasetTrainer<IgniteModel<I, O>, L>() {
/** {@inheritDoc} */
- @Override public <K, V> IgniteModel<I, O> fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> IgniteModel<I, O> fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
return trainer.fit(datasetBuilder, extractor);
}
/** {@inheritDoc} */
- @Override public <K, V> IgniteModel<I, O> update(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> IgniteModel<I, O> update(IgniteModel<I, O> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
DatasetTrainer<IgniteModel<I, O>, L> trainer1 = (DatasetTrainer<IgniteModel<I, O>, L>)trainer;
return trainer1.update(mdl, datasetBuilder, extractor);
}
@@ -72,14 +76,15 @@ public class CompositionUtils {
/**
* This method is never called, instead of constructing logic of update from
* {@link DatasetTrainer#isUpdateable(IgniteModel)} and
- * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, Vectorizer)}
* in this class we explicitly override update method.
*
* @param mdl Model.
* @return Updated model.
*/
- @Override protected <K, V> IgniteModel<I, O> updateModel(IgniteModel<I, O> mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override protected <K, V, C extends Serializable> IgniteModel<I, O> updateModel(IgniteModel<I, O> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
throw new IllegalStateException();
}
};
@@ -94,7 +99,8 @@ public class CompositionUtils {
* @param <L> Type of labels.
* @return Feature extractor created from given mapping {@code (key, value) -> LabeledVector}.
*/
- public static <K, V, L> IgniteBiFunction<K, V, Vector> asFeatureExtractor(FeatureLabelExtractor<K, V, L> extractor) {
+ public static <K, V, L> IgniteBiFunction<K, V, Vector> asFeatureExtractor(
+ FeatureLabelExtractor<K, V, L> extractor) {
return (k, v) -> extractor.extract(k, v).features();
}
@@ -121,7 +127,8 @@ public class CompositionUtils {
* @param <L> Type of labels.
* @return Label extractor created from given mapping {@code (key, value) -> LabeledVector}.
*/
- public static <K, V, L> FeatureLabelExtractor<K, V, L> asFeatureLabelExtractor(IgniteBiFunction<K, V, Vector> featureExtractor,
+ public static <K, V, L> FeatureLabelExtractor<K, V, L> asFeatureLabelExtractor(
+ IgniteBiFunction<K, V, Vector> featureExtractor,
IgniteBiFunction<K, V, L> lbExtractor) {
return (k, v) -> new LabeledVector<>(featureExtractor.apply(k, v), lbExtractor.apply(k, v));
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
index b588b25..6ce4688 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/bagging/BaggedTrainer.java
@@ -17,25 +17,27 @@
package org.apache.ignite.ml.composition.bagging;
-import java.util.Collections;
-import java.util.List;
-import java.util.Random;
-import java.util.stream.Collectors;
-import java.util.stream.IntStream;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.transformers.BaggingUpstreamTransformer;
import org.apache.ignite.ml.util.Utils;
+import java.io.Serializable;
+import java.util.Collections;
+import java.util.List;
+import java.util.Random;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
/**
* Trainer encapsulating logic of bootstrap aggregating (bagging).
* This trainer accepts some other trainer and returns bagged version of it.
@@ -148,15 +150,15 @@ public class BaggedTrainer<L> extends
/** {@inheritDoc} */
- @Override public <K, V> BaggedModel fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> BaggedModel fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
IgniteModel<Vector, Double> fit = getTrainer().fit(datasetBuilder, extractor);
return new BaggedModel(fit);
}
/** {@inheritDoc} */
- @Override public <K, V> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> BaggedModel update(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
IgniteModel<Vector, Double> updated = getTrainer().update(mdl.model(), datasetBuilder, extractor);
return new BaggedModel(updated);
}
@@ -189,8 +191,8 @@ public class BaggedTrainer<L> extends
* @param mdl Model.
* @return Updated model.
*/
- @Override protected <K, V> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override protected <K, V, C extends Serializable> BaggedModel updateModel(BaggedModel mdl, DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
// Should be never called.
throw new IllegalStateException();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
index 3acca14..abe3e16 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBBinaryClassifierTrainer.java
@@ -17,23 +17,24 @@
package org.apache.ignite.ml.composition.boosting;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.Set;
-import java.util.stream.Collectors;
import org.apache.ignite.ml.composition.boosting.loss.LogLoss;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Set;
+import java.util.stream.Collectors;
+
/**
* Trainer for binary classifier using Gradient Boosting. As preparing stage this algorithm learn labels in dataset and
* create mapping dataset labels to 0 and 1. This algorithm uses gradient of Logarithmic Loss metric [LogLoss] by
@@ -67,14 +68,13 @@ public abstract class GDBBinaryClassifierTrainer extends GDBTrainer {
}
/** {@inheritDoc} */
- @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lExtractor) {
+ @Override protected <V, K, C extends Serializable> boolean learnLabels(DatasetBuilder<K, V> builder,
+ Vectorizer<K, V, C, Double> vectorizer) {
Set<Double> uniqLabels = builder.build(
envBuilder,
new EmptyContextBuilder<>(),
- new LabeledDatasetPartitionDataBuilderOnHeap<>(featureExtractor, lExtractor))
+ new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer))
.compute((IgniteFunction<LabeledVectorSet<Double, LabeledVector>, Set<Double>>)x ->
Arrays.stream(x.labels()).boxed().collect(Collectors.toSet()), (a, b) -> {
if (a == null)
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
index 7e85449..a22bc8d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
@@ -17,9 +17,6 @@
package org.apache.ignite.ml.composition.boosting;
-import java.util.ArrayList;
-import java.util.Arrays;
-import java.util.List;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
@@ -28,18 +25,22 @@ import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueCo
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.jetbrains.annotations.NotNull;
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.List;
+
/**
* Learning strategy for gradient boosting.
*/
@@ -82,38 +83,35 @@ public class GDBLearningStrategy {
* model based on gradient of loss-function for current models composition.
*
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- * @return list of learned models.
+ * @param vectorizer Upstream vectorizer.
+ * @return List of learned models.
*/
- public <K, V> List<IgniteModel<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ public <K, V, C extends Serializable> List<IgniteModel<Vector, Double>> learnModels(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> vectorizer) {
- return update(null, datasetBuilder, featureExtractor, lbExtractor);
+ return update(null, datasetBuilder, vectorizer);
}
/**
- * Gets state of model in arguments, compare it with training parameters of trainer and if they are fit then
- * trainer updates model in according to new data and return new model. In other case trains new model.
+ * Gets the state of the model in arguments, compares it with the training parameters of the trainer, and if they fit, the trainer
+ * updates the model according to the new data and returns a new model. Otherwise it trains a new model.
*
* @param mdlToUpdate Learned model.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Upstream vectorizer.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @return Updated models list.
*/
- public <K,V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
- DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
+ public <K, V, C extends Serializable> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, Double> vectorizer) {
if (trainerEnvironment == null)
throw new IllegalStateException("Learning environment builder is not set.");
List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
- ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize,
- externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor);
+ ConvergenceChecker<K, V, C> convCheck = checkConvergenceStgyFactory.create(sampleSize,
+ externalLbToInternalMapping, loss, datasetBuilder, vectorizer);
DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
for (int i = 0; i < cntOfIterations; i++) {
@@ -124,11 +122,12 @@ public class GDBLearningStrategy {
if (convCheck.isConverged(envBuilder, datasetBuilder, currComposition))
break;
- FeatureLabelExtractor<K, V, Double> extractor = new FeatureLabelExtractor<K, V, Double>() {
+ Vectorizer<K, V, C, Double> extractor = new Vectorizer.VectorizerAdapter<K, V, C, Double>() {
/** {@inheritDoc} */
@Override public LabeledVector<Double> extract(K k, V v) {
- Vector features = featureExtractor.apply(k, v);
- Double realAnswer = externalLbToInternalMapping.apply(lbExtractor.apply(k, v));
+ LabeledVector<Double> labeledVector = vectorizer.extract(k, v);
+ Vector features = labeledVector.features();
+ Double realAnswer = externalLbToInternalMapping.apply(labeledVector.label());
Double mdlAnswer = currComposition.predict(features);
return new LabeledVector<>(features, -loss.gradient(sampleSize, realAnswer, mdlAnswer));
}
@@ -147,16 +146,16 @@ public class GDBLearningStrategy {
* Restores state of already learned model if can and sets learning parameters according to this state.
*
* @param mdlToUpdate Model to update.
- * @return list of already learned models.
+ * @return List of already learned models.
*/
@NotNull protected List<IgniteModel<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
List<IgniteModel<Vector, Double>> models = new ArrayList<>();
- if(mdlToUpdate != null) {
+ if (mdlToUpdate != null) {
models.addAll(mdlToUpdate.getModels());
- WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator) mdlToUpdate.getPredictionsAggregator();
+ WeightedPredictionsAggregator aggregator = (WeightedPredictionsAggregator)mdlToUpdate.getPredictionsAggregator();
meanLbVal = aggregator.getBias();
compositionWeights = new double[models.size() + cntOfIterations];
- for(int i = 0; i < models.size(); i++)
+ for (int i = 0; i < models.size(); i++)
compositionWeights[i] = aggregator.getWeights()[i];
}
else
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java
index 3dc95ee..f708b6a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBRegressionTrainer.java
@@ -19,13 +19,14 @@ package org.apache.ignite.ml.composition.boosting;
import org.apache.ignite.ml.composition.boosting.loss.SquaredError;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
- * Trainer for regressor using Gradient Boosting.
- * This algorithm uses gradient of Mean squared error loss metric [MSE] in each step of learning.
+ * Trainer for regressor using Gradient Boosting. This algorithm uses gradient of Mean squared error loss metric [MSE]
+ * in each step of learning.
*/
public abstract class GDBRegressionTrainer extends GDBTrainer {
/**
@@ -39,8 +40,8 @@ public abstract class GDBRegressionTrainer extends GDBTrainer {
}
/** {@inheritDoc} */
- @Override protected <V, K> boolean learnLabels(DatasetBuilder<K, V> builder, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lExtractor) {
+ @Override protected <V, K, C extends Serializable> boolean learnLabels(DatasetBuilder<K, V> builder,
+ Vectorizer<K, V, C, Double> vectorizer) {
return true;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index ff87e15..aa82d3a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -17,11 +17,8 @@
package org.apache.ignite.ml.composition.boosting;
-import java.util.Arrays;
-import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
@@ -29,24 +26,27 @@ import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
import org.apache.ignite.ml.regressions.linear.LinearRegressionSGDTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
import org.jetbrains.annotations.NotNull;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+
/**
* Abstract Gradient Boosting trainer. It implements gradient descent in functional space using user-selected regressor
* in child class. Each learning iteration the trainer evaluate gradient of error-function and fit regression model to
@@ -78,8 +78,8 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
*
* @param gradStepSize Grad step size.
* @param cntOfIterations Count of learning iterations.
- * @param loss Gradient of loss function. First argument is sample size, second argument is valid answer
- * third argument is current model prediction.
+ * @param loss Gradient of loss function. First argument is sample size, second argument is valid answer, third
+ * argument is current model prediction.
*/
public GDBTrainer(double gradStepSize, Integer cntOfIterations, Loss loss) {
gradientStep = gradStepSize;
@@ -88,29 +88,20 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
}
/** {@inheritDoc} */
- @Override public <K, V> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override public <K, V, C extends Serializable> ModelsComposition fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
return updateModel(null, datasetBuilder, extractor);
}
/** {@inheritDoc} */
- @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl,
+ @Override protected <K, V, C extends Serializable> ModelsComposition updateModel(ModelsComposition mdl,
DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
- if (!learnLabels(datasetBuilder, CompositionUtils.asFeatureExtractor(extractor), CompositionUtils.asLabelExtractor(extractor)))
+ Vectorizer<K, V, C, Double> extractor) {
+ if (!learnLabels(datasetBuilder, extractor))
return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
- IgniteBiFunction<K, V, Vector> featureExtractor =
- (k, v) -> extractor.extract(k, v).features();
- IgniteBiFunction<K, V, Double> lbExtractor =
- (k, v) -> extractor.extract(k, v).label();
-
- IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(
- envBuilder,
- datasetBuilder,
- featureExtractor,
- lbExtractor);
- if(initAndSampleSize == null)
+ IgniteBiTuple<Double, Long> initAndSampleSize = computeInitialValue(envBuilder, datasetBuilder, extractor);
+ if (initAndSampleSize == null)
return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
Double mean = initAndSampleSize.get1();
@@ -131,14 +122,9 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
List<IgniteModel<Vector, Double>> models;
if (mdl != null)
- models = stgy.update((GDBModel)mdl,
- datasetBuilder,
- featureExtractor,
- lbExtractor);
+ models = stgy.update((GDBModel)mdl, datasetBuilder, extractor);
else
- models = stgy.learnModels(datasetBuilder,
- featureExtractor,
- lbExtractor);
+ models = stgy.learnModels(datasetBuilder, extractor);
double learningTime = (double)(System.currentTimeMillis() - learningStartTs) / 1000.0;
environment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "The training time was %.2fs", learningTime);
@@ -164,13 +150,11 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
* Defines unique labels in dataset if need (useful in case of classification).
*
* @param builder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Labels extractor.
- * @return true if labels learning was successful.
+ * @param vectorizer Upstream vectorizer.
+ * @return True if labels learning was successful.
*/
- protected abstract <V, K> boolean learnLabels(DatasetBuilder<K, V> builder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor);
+ protected abstract <V, K, C extends Serializable> boolean learnLabels(DatasetBuilder<K, V> builder,
+ Vectorizer<K, V, C, Double> vectorizer);
/**
* Returns regressor model trainer for one step of GDB.
@@ -197,19 +181,17 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
*
* @param builder Dataset builder.
* @param envBuilder Learning environment builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Vectorizer.
*/
- protected <V, K> IgniteBiTuple<Double, Long> computeInitialValue(
+ protected <V, K, C extends Serializable> IgniteBiTuple<Double, Long> computeInitialValue(
LearningEnvironmentBuilder envBuilder,
DatasetBuilder<K, V> builder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
+ Vectorizer<K, V, C, Double> vectorizer) {
try (Dataset<EmptyContext, DecisionTreeData> dataset = builder.build(
envBuilder,
new EmptyContextBuilder<>(),
- new DecisionTreeDataBuilder<>(CompositionUtils.asFeatureLabelExtractor(featureExtractor, lbExtractor), false)
+ new DecisionTreeDataBuilder<>(vectorizer, false)
)) {
IgniteBiTuple<Double, Long> meanTuple = dataset.compute(
data -> {
@@ -241,7 +223,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
* Sets CheckConvergenceStgyFactory.
*
* @param factory Factory.
- * @return trainer.
+ * @return Trainer.
*/
public GDBTrainer withCheckConvergenceStgyFactory(ConvergenceCheckerFactory factory) {
this.checkConvergenceStgyFactory = factory;
@@ -251,7 +233,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
/**
* Returns learning strategy.
*
- * @return learning strategy.
+ * @return Learning strategy.
*/
protected GDBLearningStrategy getLearningStrategy() {
return new GDBLearningStrategy();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java
index f7da9a1..1924863 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceChecker.java
@@ -17,27 +17,28 @@
package org.apache.ignite.ml.composition.boosting.convergence;
-import java.io.Serializable;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapDataBuilder;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import java.io.Serializable;
+
/**
* Contains logic of error computing and convergence checking for Gradient Boosting algorithms.
*
* @param <K> Type of a key in upstream data.
* @param <V> Type of a value in upstream data.
*/
-public abstract class ConvergenceChecker<K, V> implements Serializable {
+public abstract class ConvergenceChecker<K, V, C extends Serializable> implements Serializable {
/** Serial version uid. */
private static final long serialVersionUID = 710762134746674105L;
@@ -50,11 +51,8 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
/** Loss function. */
private Loss loss;
- /** Feature extractor. */
- private IgniteBiFunction<K, V, Vector> featureExtractor;
-
- /** Label extractor. */
- private IgniteBiFunction<K, V, Double> lbExtractor;
+ /** Upstream vectorizer. */
+ private Vectorizer<K, V, C, Double> vectorizer;
/** Precision of convergence check. */
private double precision;
@@ -66,24 +64,21 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
* @param externalLbToInternalMapping External label to internal mapping.
* @param loss Loss gradient.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
- * @param precision Precision.
+ * @param vectorizer Upstream vectorizer.
+ * @param precision Precision.FeatureMatrixWithLabelsOnHeapDataBuilder.java
*/
public ConvergenceChecker(long sampleSize,
IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
- DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor,
- double precision) {
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> vectorizer, double precision) {
assert precision < 1 && precision >= 0;
this.sampleSize = sampleSize;
this.externalLbToInternalMapping = externalLbToInternalMapping;
this.loss = loss;
- this.featureExtractor = featureExtractor;
- this.lbExtractor = lbExtractor;
this.precision = precision;
+ this.vectorizer = vectorizer;
}
/**
@@ -91,7 +86,7 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
*
* @param envBuilder Learning environment builder.
* @param currMdl Current model.
- * @return true if GDB is converged.
+ * @return True if GDB is converged.
*/
public boolean isConverged(
LearningEnvironmentBuilder envBuilder,
@@ -100,7 +95,7 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
try (Dataset<EmptyContext, FeatureMatrixWithLabelsOnHeapData> dataset = datasetBuilder.build(
envBuilder,
new EmptyContextBuilder<>(),
- new FeatureMatrixWithLabelsOnHeapDataBuilder<>(featureExtractor, lbExtractor)
+ new FeatureMatrixWithLabelsOnHeapDataBuilder<>(vectorizer)
)) {
return isConverged(dataset, currMdl);
}
@@ -114,9 +109,10 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
*
* @param dataset Dataset.
* @param currMdl Current model.
- * @return true if GDB is converged.
+ * @return True if GDB is converged.
*/
- public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset, ModelsComposition currMdl) {
+ public boolean isConverged(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+ ModelsComposition currMdl) {
Double error = computeMeanErrorOnDataset(dataset, currMdl);
return error < precision || error.isNaN();
}
@@ -126,7 +122,7 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
*
* @param dataset Learning dataset.
* @param mdl Model.
- * @return error mean value.
+ * @return Error mean value.
*/
public abstract Double computeMeanErrorOnDataset(
Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
@@ -136,7 +132,7 @@ public abstract class ConvergenceChecker<K, V> implements Serializable {
* Compute error for the specific vector of dataset.
*
* @param currMdl Current model.
- * @return error.
+ * @return Error.
*/
public double computeError(Vector features, Double answer, ModelsComposition currMdl) {
Double realAnswer = externalLbToInternalMapping.apply(answer);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java
index 7592f50..c2be71e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/ConvergenceCheckerFactory.java
@@ -19,15 +19,16 @@ package org.apache.ignite.ml.composition.boosting.convergence;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
* Factory for ConvergenceChecker.
*/
public abstract class ConvergenceCheckerFactory {
- /** Precision of error checking. If error <= precision then it is equated to 0.0*/
+ /** Precision of error checking. If error <= precision then it is equated to 0.0 */
protected double precision;
/**
@@ -46,13 +47,11 @@ public abstract class ConvergenceCheckerFactory {
* @param externalLbToInternalMapping External label to internal mapping.
* @param loss Loss function.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Upstream vectorizer.
* @return ConvergenceCheckerFactory instance.
*/
- public abstract <K,V> ConvergenceChecker<K,V> create(long sampleSize,
+ public abstract <K, V, C extends Serializable> ConvergenceChecker<K, V, C> create(long sampleSize,
IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor);
+ Vectorizer<K, V, C, Double> vectorizer);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java
index 82b194b..6dbdf29 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceChecker.java
@@ -23,20 +23,21 @@ import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import java.io.Serializable;
+
/**
* Use mean value of errors for estimating error on dataset.
*
* @param <K> Type of a key in upstream data.
* @param <V> Type of a value in upstream data.
*/
-public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V> {
+public class MeanAbsValueConvergenceChecker<K, V, C extends Serializable> extends ConvergenceChecker<K, V, C> {
/** Serial version uid. */
private static final long serialVersionUID = 8534776439755210864L;
@@ -47,19 +48,17 @@ public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V>
* @param externalLbToInternalMapping External label to internal mapping.
* @param loss Loss.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Upstream vectorizer.
*/
public MeanAbsValueConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> externalLbToInternalMapping,
- Loss loss, DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor,
- double precision) {
+ Loss loss, DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, Double> vectorizer, double precision) {
- super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, featureExtractor, lbExtractor, precision);
+ super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer, precision);
}
/** {@inheritDoc} */
- @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+ @Override public Double computeMeanErrorOnDataset(
+ Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
ModelsComposition mdl) {
IgniteBiTuple<Double, Long> sumAndCnt = dataset.compute(
@@ -67,7 +66,7 @@ public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V>
this::reduce
);
- if(sumAndCnt == null || sumAndCnt.getValue() == 0)
+ if (sumAndCnt == null || sumAndCnt.getValue() == 0)
return Double.NaN;
return sumAndCnt.getKey() / sumAndCnt.getValue();
}
@@ -79,15 +78,16 @@ public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V>
* @param part Partition.
* @return Tuple (sum of errors, count of rows)
*/
- private IgniteBiTuple<Double, Long> computeStatisticOnPartition(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData part) {
+ private IgniteBiTuple<Double, Long> computeStatisticOnPartition(ModelsComposition mdl,
+ FeatureMatrixWithLabelsOnHeapData part) {
Double sum = 0.0;
- for(int i = 0; i < part.getFeatures().length; i++) {
+ for (int i = 0; i < part.getFeatures().length; i++) {
double error = computeError(VectorUtils.of(part.getFeatures()[i]), part.getLabels()[i], mdl);
sum += Math.abs(error);
}
- return new IgniteBiTuple<>(sum, (long) part.getLabels().length);
+ return new IgniteBiTuple<>(sum, (long)part.getLabels().length);
}
/**
@@ -95,7 +95,7 @@ public class MeanAbsValueConvergenceChecker<K,V> extends ConvergenceChecker<K,V>
*
* @param left Left.
* @param right Right.
- * @return merged value.
+ * @return Merged value.
*/
private IgniteBiTuple<Double, Long> reduce(IgniteBiTuple<Double, Long> left, IgniteBiTuple<Double, Long> right) {
if (left == null) {
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java
index f02a606..9308d9d 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/mean/MeanAbsValueConvergenceCheckerFactory.java
@@ -21,9 +21,10 @@ import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
* Factory for {@link MeanAbsValueConvergenceChecker}.
@@ -37,11 +38,11 @@ public class MeanAbsValueConvergenceCheckerFactory extends ConvergenceCheckerFac
}
/** {@inheritDoc} */
- @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+ @Override public <K, V, C extends Serializable> ConvergenceChecker<K, V, C> create(long sampleSize,
IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ Vectorizer<K, V, C, Double> vectorizer) {
return new MeanAbsValueConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss,
- datasetBuilder, featureExtractor, lbExtractor, precision);
+ datasetBuilder, vectorizer, precision);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java
index 7e66a9c..e41c37a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceChecker.java
@@ -17,19 +17,20 @@
package org.apache.ignite.ml.composition.boosting.convergence.median;
-import java.util.Arrays;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import java.io.Serializable;
+import java.util.Arrays;
+
/**
* Use median of median on partitions value of errors for estimating error on dataset. This algorithm may be less
* sensitive to
@@ -37,7 +38,7 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
* @param <K> Type of a key in upstream data.
* @param <V> Type of a value in upstream data.
*/
-public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K, V> {
+public class MedianOfMedianConvergenceChecker<K, V, C extends Serializable> extends ConvergenceChecker<K, V, C> {
/** Serial version uid. */
private static final long serialVersionUID = 4902502002933415287L;
@@ -48,19 +49,18 @@ public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K
* @param lblMapping External label to internal mapping.
* @param loss Loss function.
* @param datasetBuilder Dataset builder.
- * @param fExtr Feature extractor.
- * @param lbExtr Label extractor.
+ * @param vectorizer Upstream vectorizer.
* @param precision Precision.
*/
public MedianOfMedianConvergenceChecker(long sampleSize, IgniteFunction<Double, Double> lblMapping, Loss loss,
- DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> fExtr,
- IgniteBiFunction<K, V, Double> lbExtr, double precision) {
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, Double> vectorizer, double precision) {
- super(sampleSize, lblMapping, loss, datasetBuilder, fExtr, lbExtr, precision);
+ super(sampleSize, lblMapping, loss, datasetBuilder, vectorizer, precision);
}
/** {@inheritDoc} */
- @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+ @Override public Double computeMeanErrorOnDataset(
+ Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
ModelsComposition mdl) {
double[] medians = dataset.compute(
@@ -68,7 +68,7 @@ public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K
this::reduce
);
- if(medians == null)
+ if (medians == null)
return Double.POSITIVE_INFINITY;
return getMedian(medians);
}
@@ -78,7 +78,7 @@ public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K
*
* @param mdl Model.
* @param data Data.
- * @return median value.
+ * @return Median value.
*/
private double[] computeMedian(ModelsComposition mdl, FeatureMatrixWithLabelsOnHeapData data) {
double[] errors = new double[data.getLabels().length];
@@ -91,10 +91,10 @@ public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K
* Compute median value on array of errors.
*
* @param errors Error values.
- * @return median value of errors.
+ * @return Median value of errors.
*/
private double getMedian(double[] errors) {
- if(errors.length == 0)
+ if (errors.length == 0)
return Double.POSITIVE_INFINITY;
Arrays.sort(errors);
@@ -110,12 +110,12 @@ public class MedianOfMedianConvergenceChecker<K, V> extends ConvergenceChecker<K
*
* @param left Left partition.
* @param right Right partition.
- * @return merged median values.
+ * @return Merged median values.
*/
private double[] reduce(double[] left, double[] right) {
if (left == null)
return right;
- if(right == null)
+ if (right == null)
return left;
double[] res = new double[left.length + right.length];
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java
index a1affe0..2daa6e5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/median/MedianOfMedianConvergenceCheckerFactory.java
@@ -21,9 +21,10 @@ import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
* Factory for {@link MedianOfMedianConvergenceChecker}.
@@ -37,11 +38,11 @@ public class MedianOfMedianConvergenceCheckerFactory extends ConvergenceCheckerF
}
/** {@inheritDoc} */
- @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+ @Override public <K, V, C extends Serializable> ConvergenceChecker<K, V, C> create(long sampleSize,
IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss, DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ Vectorizer<K, V, C, Double> vectorizer) {
return new MedianOfMedianConvergenceChecker<>(sampleSize, externalLbToInternalMapping, loss,
- datasetBuilder, featureExtractor, lbExtractor, precision);
+ datasetBuilder, vectorizer, precision);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java
index 193afaf..4996902 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStub.java
@@ -22,21 +22,22 @@ import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.FeatureMatrixWithLabelsOnHeapData;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
- * This strategy skip estimating error on dataset step.
- * According to this strategy, training will stop after reaching the maximum number of iterations.
+ * This strategy skip estimating error on dataset step. According to this strategy, training will stop after reaching
+ * the maximum number of iterations.
*
* @param <K> Type of a key in upstream data.
* @param <V> Type of a value in upstream data.
*/
-public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> {
+public class ConvergenceCheckerStub<K, V, C extends Serializable> extends ConvergenceChecker<K, V, C> {
/** Serial version uid. */
private static final long serialVersionUID = 8534776439755210864L;
@@ -47,21 +48,17 @@ public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> {
* @param externalLbToInternalMapping External label to internal mapping.
* @param loss Loss function.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Upstream vectorizer.
*/
- public ConvergenceCheckerStub(long sampleSize,
- IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
- DatasetBuilder<K, V> datasetBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
+ public ConvergenceCheckerStub(long sampleSize, IgniteFunction externalLbToInternalMapping, Loss loss,
+ DatasetBuilder datasetBuilder, Vectorizer<K, V, C, Double> vectorizer, double precision) {
- super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder,
- featureExtractor, lbExtractor, 0.0);
+ super(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer, 0.0);
}
/** {@inheritDoc} */
- @Override public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder, ModelsComposition currMdl) {
+ @Override public boolean isConverged(LearningEnvironmentBuilder envBuilder, DatasetBuilder<K, V> datasetBuilder,
+ ModelsComposition currMdl) {
return false;
}
@@ -72,7 +69,8 @@ public class ConvergenceCheckerStub<K,V> extends ConvergenceChecker<K,V> {
}
/** {@inheritDoc} */
- @Override public Double computeMeanErrorOnDataset(Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
+ @Override public Double computeMeanErrorOnDataset(
+ Dataset<EmptyContext, ? extends FeatureMatrixWithLabelsOnHeapData> dataset,
ModelsComposition mdl) {
throw new UnsupportedOperationException();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java
index a0f0d5c..f7cd500 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/convergence/simple/ConvergenceCheckerStubFactory.java
@@ -21,9 +21,10 @@ import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
import org.apache.ignite.ml.dataset.DatasetBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
/**
* Factory for {@link ConvergenceCheckerStub}.
@@ -37,12 +38,10 @@ public class ConvergenceCheckerStubFactory extends ConvergenceCheckerFactory {
}
/** {@inheritDoc} */
- @Override public <K, V> ConvergenceChecker<K, V> create(long sampleSize,
+ @Override public <K, V, C extends Serializable> ConvergenceChecker<K, V, C> create(long sampleSize,
IgniteFunction<Double, Double> externalLbToInternalMapping, Loss loss,
- DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, Double> vectorizer) {
- return new ConvergenceCheckerStub<>(sampleSize, externalLbToInternalMapping, loss,
- datasetBuilder, featureExtractor, lbExtractor);
+ return new ConvergenceCheckerStub<>(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer, precision);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
index 72fff30..317105f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
@@ -29,7 +29,7 @@ public interface Loss extends Serializable {
* @param sampleSize Sample size.
* @param lb Label.
* @param mdlAnswer Model answer.
- * @return error value.
+ * @return Error value.
*/
public double error(long sampleSize, double lb, double mdlAnswer);
@@ -39,7 +39,7 @@ public interface Loss extends Serializable {
* @param sampleSize Sample size.
* @param lb Label.
* @param mdlAnswer Model answer.
- * @return error value.
+ * @return Error value.
*/
public double gradient(long sampleSize, double lb, double mdlAnswer);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
index 0b21b5b..65335b3 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/parallel/TrainersParallelComposition.java
@@ -17,30 +17,29 @@
package org.apache.ignite.ml.composition.combinators.parallel;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.parallelism.Promise;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteSupplier;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
/**
- * This class represents a parallel composition of trainers.
- * Parallel composition of trainers is a trainer itself which trains a list of trainers with same
- * input and output. Training is done in following manner:
+ * This class represents a parallel composition of trainers. Parallel composition of trainers is a trainer itself which
+ * trains a list of trainers with same input and output. Training is done in following manner:
* <pre>
* 1. Independently train all trainers on the same dataset and get a list of models.
* 2. Combine models produced in step (1) into a {@link ModelsParallelComposition}.
* </pre>
- * Updating is made in a similar fashion.
- * Like in other trainers combinators we avoid to include type of contained trainers in type parameters
- * because otherwise compositions of compositions would have a relatively complex generic type which will
- * reduce readability.
+ * Updating is made in a similar fashion. Like in other trainers combinators we avoid to include type of contained
+ * trainers in type parameters because otherwise compositions of compositions would have a relatively complex generic
+ * type which will reduce readability.
*
* @param <I> Type of trainers inputs.
* @param <O> Type of trainers outputs.
@@ -72,7 +71,8 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
* @param <L> Type of input of labels.
* @return Parallel composition of trainers contained in a given list.
*/
- public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(List<T> trainers) {
+ public static <I, O, M extends IgniteModel<I, O>, T extends DatasetTrainer<M, L>, L> TrainersParallelComposition<I, O, L> of(
+ List<T> trainers) {
List<DatasetTrainer<IgniteModel<I, O>, L>> trs =
trainers.stream().map(CompositionUtils::unsafeCoerce).collect(Collectors.toList());
@@ -80,12 +80,10 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
}
/** {@inheritDoc} */
- @Override public <K, V> IgniteModel<I, List<O>> fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> IgniteModel<I, List<O>> fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
List<IgniteSupplier<IgniteModel<I, O>>> tasks = trainers.stream()
- .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)(() -> tr.fit(datasetBuilder,
- CompositionUtils.asFeatureExtractor(extractor),
- CompositionUtils.asLabelExtractor(extractor))))
+ .map(tr -> (IgniteSupplier<IgniteModel<I, O>>)(() -> tr.fit(datasetBuilder, extractor)))
.collect(Collectors.toList());
List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
@@ -96,8 +94,9 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
}
/** {@inheritDoc} */
- @Override public <K, V> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> IgniteModel<I, List<O>> update(IgniteModel<I, List<O>> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
ModelsParallelComposition<I, O> typedMdl = (ModelsParallelComposition<I, O>)mdl;
assert typedMdl.submodels().size() == trainers.size();
@@ -105,9 +104,7 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
for (int i = 0; i < trainers.size(); i++) {
int j = i;
- tasks.add(() -> trainers.get(j).update(typedMdl.submodels().get(j), datasetBuilder,
- CompositionUtils.asFeatureExtractor(extractor),
- CompositionUtils.asLabelExtractor(extractor)));
+ tasks.add(() -> trainers.get(j).update(typedMdl.submodels().get(j), datasetBuilder, extractor));
}
List<IgniteModel<I, O>> mdls = environment.parallelismStrategy().submit(tasks).stream()
@@ -118,10 +115,8 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
}
/**
- * This method is never called, instead of constructing logic of update from
- * {@link DatasetTrainer#isUpdateable} and
- * {@link DatasetTrainer#updateModel}
- * in this class we explicitly override update method.
+ * This method is never called, instead of constructing logic of update from {@link DatasetTrainer#isUpdateable} and
+ * {@link DatasetTrainer#updateModel} in this class we explicitly override update method.
*
* @param mdl Model.
* @return True if current critical for training parameters correspond to parameters from last training.
@@ -132,16 +127,16 @@ public class TrainersParallelComposition<I, O, L> extends DatasetTrainer<IgniteM
}
/**
- * This method is never called, instead of constructing logic of update from
- * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
- * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * This method is never called, instead of constructing logic of update from {@link
+ * DatasetTrainer#isUpdateable(IgniteModel)} and {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, Vectorizer)}
* in this class we explicitly override update method.
*
* @param mdl Model.
* @return Updated model.
*/
- @Override protected <K, V> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override protected <K, V, C extends Serializable> IgniteModel<I, List<O>> updateModel(IgniteModel<I, List<O>> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
// Never called.
throw new IllegalStateException();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
index 1105e1b..9c1eb31 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/combinators/sequential/TrainersSequentialComposition.java
@@ -17,18 +17,20 @@
package org.apache.ignite.ml.composition.combinators.sequential;
-import java.util.ArrayList;
-import java.util.List;
import org.apache.ignite.lang.IgniteBiPredicate;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.DatasetMapping;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
/**
* Sequential composition of trainers.
@@ -82,69 +84,16 @@ public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<
out2In);
}
- /**
- * Sequential composition of same trainers.
- *
- * @param <I> Type of input of model produced by trainers.
- * @param <O> Type of output of model produced by trainers.
- * @param <L> Type of labels.
- */
- private static class SameTrainersSequentialComposition<I, O, L> extends TrainersSequentialComposition<I, O, O, L> {
- /** Trainer to sequentially compose. */
- private final DatasetTrainer<IgniteModel<I, O>, L> tr;
-
- /**
- * Predicate depending on index and model produced by the last trainer indicating if composition process should
- * stop
- */
- private final IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop;
-
- /** Function for conversion of output of model into input of next model. */
- private final IgniteFunction<O, I> out2Input;
-
- /**
- * Create instance of this class.
- *
- * @param tr Trainer to sequentially compose.
- * @param datasetMapping Dataaset mapping.
- * @param shouldStop Predicate depending on index and model produced by the last trainer
- * indicating if composition process should stop.
- * @param out2Input Function for conversion of output of model into input of next model.
- */
- public SameTrainersSequentialComposition(
- DatasetTrainer<IgniteModel<I, O>, L> tr,
- IgniteBiFunction<Integer, ? super IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping,
- IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop,
- IgniteFunction<O, I> out2Input) {
- super(null, null, datasetMapping);
- this.tr = tr;
- this.shouldStop = (iteration, model) -> iteration != 0 && shouldStop.apply(iteration, model);
- this.out2Input = out2Input;
- }
-
- /** {@inheritDoc} */
- @Override public <K, V> ModelsSequentialComposition<I, O, O> fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
-
- int i = 0;
- IgniteModel<I, O> currMdl = null;
- IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping =
- IgniteFunction.identity();
- List<IgniteModel<I, O>> mdls = new ArrayList<>();
-
- while (!shouldStop.apply(i, currMdl)) {
- currMdl = tr.fit(datasetBuilder, extractor.andThen(mapping));
- mdls.add(currMdl);
- if (shouldStop.apply(i, currMdl))
- break;
+ /** {@inheritDoc} */
+ @Override public <K, V, C extends Serializable> ModelsSequentialComposition<I, O1, O2> fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
- mapping = datasetMapping.apply(i, currMdl);
+ IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, extractor);
+ IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping = datasetMapping.apply(0, mdl1);
- i++;
- }
+ IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder, extractor.map(mapping));
- return ModelsSequentialComposition.ofSame(mdls, out2Input);
- }
+ return new ModelsSequentialComposition<>(mdl1, mdl2);
}
/**
@@ -176,29 +125,32 @@ public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<
this.datasetMapping = datasetMapping;
}
- /** {@inheritDoc} */
- @Override public <K, V> ModelsSequentialComposition<I, O1, O2> fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
-
- IgniteModel<I, O1> mdl1 = tr1.fit(datasetBuilder, extractor);
- IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping = datasetMapping.apply(0, mdl1);
-
- IgniteModel<O1, O2> mdl2 = tr2.fit(datasetBuilder, extractor.andThen(mapping));
-
- return new ModelsSequentialComposition<>(mdl1, mdl2);
+ /**
+ * This method is never called, instead of constructing logic of update from
+ * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, Vectorizer)}
+ * in this class we explicitly override update method.
+ *
+ * @param mdl Model.
+ * @return Updated model.
+ */
+ @Override protected <K, V, C extends Serializable> ModelsSequentialComposition<I, O1, O2> updateModel(
+ ModelsSequentialComposition<I, O1, O2> mdl,
+ DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
+ // Never called.
+ throw new IllegalStateException();
}
/** {@inheritDoc} */
- @Override public <K, V> ModelsSequentialComposition<I, O1, O2> update(
+ @Override public <K, V, C extends Serializable> ModelsSequentialComposition<I, O1, O2> update(
ModelsSequentialComposition<I, O1, O2> mdl, DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ Vectorizer<K, V, C, L> extractor) {
IgniteModel<I, O1> firstUpdated = tr1.update(mdl.firstModel(), datasetBuilder, extractor);
IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping = datasetMapping.apply(0, firstUpdated);
- IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(),
- datasetBuilder,
- extractor.andThen(mapping));
+ IgniteModel<O1, O2> secondUpdated = tr2.update(mdl.secondModel(), datasetBuilder, extractor.map(mapping));
return new ModelsSequentialComposition<>(firstUpdated, secondUpdated);
}
@@ -218,20 +170,68 @@ public class TrainersSequentialComposition<I, O1, O2, L> extends DatasetTrainer<
}
/**
- * This method is never called, instead of constructing logic of update from
- * {@link DatasetTrainer#isUpdateable(IgniteModel)} and
- * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
- * in this class we explicitly override update method.
+ * Sequential composition of same trainers.
*
- * @param mdl Model.
- * @return Updated model.
+ * @param <I> Type of input of model produced by trainers.
+ * @param <O> Type of output of model produced by trainers.
+ * @param <L> Type of labels.
*/
- @Override protected <K, V> ModelsSequentialComposition<I, O1, O2> updateModel(
- ModelsSequentialComposition<I, O1, O2> mdl,
- DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
- // Never called.
- throw new IllegalStateException();
+ private static class SameTrainersSequentialComposition<I, O, L> extends TrainersSequentialComposition<I, O, O, L> {
+ /** Trainer to sequentially compose. */
+ private final DatasetTrainer<IgniteModel<I, O>, L> tr;
+
+ /**
+ * Predicate depending on index and model produced by the last trainer indicating if composition process should
+ * stop
+ */
+ private final IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop;
+
+ /** Function for conversion of output of model into input of next model. */
+ private final IgniteFunction<O, I> out2Input;
+
+ /**
+ * Create instance of this class.
+ *
+ * @param tr Trainer to sequentially compose.
+ * @param datasetMapping Dataaset mapping.
+ * @param shouldStop Predicate depending on index and model produced by the last trainer
+ * indicating if composition process should stop.
+ * @param out2Input Function for conversion of output of model into input of next model.
+ */
+ public SameTrainersSequentialComposition(
+ DatasetTrainer<IgniteModel<I, O>, L> tr,
+ IgniteBiFunction<Integer, ? super IgniteModel<I, O>, IgniteFunction<LabeledVector<L>, LabeledVector<L>>> datasetMapping,
+ IgniteBiPredicate<Integer, IgniteModel<I, O>> shouldStop,
+ IgniteFunction<O, I> out2Input) {
+ super(null, null, datasetMapping);
+ this.tr = tr;
+ this.shouldStop = (iteration, model) -> iteration != 0 && shouldStop.apply(iteration, model);
+ this.out2Input = out2Input;
+ }
+
+ /** {@inheritDoc} */
+ @Override public <K, V, C extends Serializable> ModelsSequentialComposition<I, O, O> fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
+
+ int i = 0;
+ IgniteModel<I, O> currMdl = null;
+ IgniteFunction<LabeledVector<L>, LabeledVector<L>> mapping =
+ IgniteFunction.identity();
+ List<IgniteModel<I, O>> mdls = new ArrayList<>();
+
+ while (!shouldStop.apply(i, currMdl)) {
+ currMdl = tr.fit(datasetBuilder, extractor.map(mapping));
+ mdls.add(currMdl);
+ if (shouldStop.apply(i, currMdl))
+ break;
+
+ mapping = datasetMapping.apply(i, currMdl);
+
+ i++;
+ }
+
+ return ModelsSequentialComposition.ofSame(mdls, out2Input);
+ }
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
index 8d26f14..931ece8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/stacking/StackedDatasetTrainer.java
@@ -17,14 +17,12 @@
package org.apache.ignite.ml.composition.stacking;
-import java.util.ArrayList;
-import java.util.List;
-import java.util.stream.Collectors;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.composition.combinators.parallel.ModelsParallelComposition;
import org.apache.ignite.ml.composition.combinators.parallel.TrainersParallelComposition;
import org.apache.ignite.ml.dataset.DatasetBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.functions.IgniteBinaryOperator;
@@ -34,7 +32,11 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.AdaptableDatasetTrainer;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.stream.Collectors;
/**
* {@link DatasetTrainer} encapsulating stacking technique for model training.
@@ -230,15 +232,15 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
}
/** {@inheritDoc} */
- @Override public <K, V> StackedModel<IS, IA, O, AM> fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> StackedModel<IS, IA, O, AM> fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, L> extractor) {
return new StackedModel<>(getTrainer().fit(datasetBuilder, extractor));
}
/** {@inheritDoc} */
- @Override public <K, V> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl,
- DatasetBuilder<K, V> datasetBuilder, FeatureLabelExtractor<K, V, L> extractor) {
+ @Override public <K, V, C extends Serializable> StackedModel<IS, IA, O, AM> update(StackedModel<IS, IA, O, AM> mdl,
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, L> extractor) {
return new StackedModel<>(getTrainer().update(mdl, datasetBuilder, extractor));
}
@@ -347,15 +349,15 @@ public class StackedDatasetTrainer<IS, IA, O, AM extends IgniteModel<IA, O>, L>
/**
* This method is never called, instead of constructing logic of update from
* {@link DatasetTrainer#isUpdateable(IgniteModel)} and
- * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, IgniteBiFunction, IgniteBiFunction)}
+ * {@link DatasetTrainer#updateModel(IgniteModel, DatasetBuilder, Vectorizer)}
* in this class we explicitly override update method.
*
* @param mdl Model.
* @return Updated model.
*/
- @Override protected <K, V> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl,
+ @Override protected <K, V, C extends Serializable> StackedModel<IS, IA, O, AM> updateModel(StackedModel<IS, IA, O, AM> mdl,
DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, L> extractor) {
+ Vectorizer<K, V, C, L> extractor) {
// This method is never called, we override "update" instead.
throw new IllegalStateException();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java
index ef8eb23..1e5cf5e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/DatasetFactory.java
@@ -17,10 +17,9 @@
package org.apache.ignite.ml.dataset;
-import java.io.Serializable;
-import java.util.Map;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder;
import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.dataset.primitive.SimpleDataset;
@@ -32,8 +31,9 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+import java.io.Serializable;
+import java.util.Map;
/**
* Factory providing a client facing API that allows to construct basic and the most frequently used types of dataset.
@@ -53,19 +53,13 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
* dataset the following approach is used:
*
* <code>
- * {@code
- * Dataset<C, D> dataset = DatasetFactory.create(
- * ignite,
- * cache,
- * partitionContextBuilder,
- * partitionDataBuilder
- * );
+ * {@code Dataset<C, D> dataset = DatasetFactory.create( ignite, cache, partitionContextBuilder, partitionDataBuilder );
* }
* </code>
*
* <p>As well as the generic building method {@code create} this factory provides methods that allow to create a
- * specific dataset types such as method {@code createSimpleDataset} to create {@link SimpleDataset} and method
- * {@code createSimpleLabeledDataset} to create {@link SimpleLabeledDataset}.
+ * specific dataset type such as method {@code createSimpleDataset} to create {@link SimpleDataset} and method {@code
+ * createSimpleLabeledDataset} to create {@link SimpleLabeledDataset}.
*
* @see Dataset
* @see PartitionContextBuilder
@@ -73,9 +67,9 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
*/
public class DatasetFactory {
/**
- * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
- * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
- * any desired partition {@code context} and {@code data}.
+ * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and {@code
+ * partDataBuilder}. This is the generic method that allows creation of any Ignite Cache based datasets with any
+ * desired partition {@code context} and {@code data}.
*
* @param envBuilder Learning environment builder.
* @param datasetBuilder Dataset builder.
@@ -100,9 +94,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
- * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
- * any desired partition {@code context} and {@code data}.
+ * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and {@code
+ * partDataBuilder}. This is the generic method that allows creation of any Ignite Cache based datasets with any
+ * desired partition {@code context} and {@code data}.
*
* @param datasetBuilder Dataset builder.
* @param partCtxBuilder Partition {@code context} builder.
@@ -125,9 +119,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
- * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
- * any desired partition {@code context} and {@code data}.
+ * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and {@code
+ * partDataBuilder}. This is the generic method that allows creation of any Ignite Cache based datasets with any
+ * desired partition {@code context} and {@code data}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
@@ -154,9 +148,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and
- * {@code partDataBuilder}. This is the generic methods that allows to create any Ignite Cache based datasets with
- * any desired partition {@code context} and {@code data}.
+ * Creates a new instance of distributed dataset using the specified {@code partCtxBuilder} and {@code
+ * partDataBuilder}. This is the generic method that allows creation of any Ignite Cache based datasets with any
+ * desired partition {@code context} and {@code data}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
@@ -180,9 +174,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code partCtxBuilder} and
- * {@code featureExtractor}. This methods determines partition {@code data} to be {@link SimpleDatasetData}, but
- * allows to use any desired type of partition {@code context}.
+ * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code partCtxBuilder} and {@code
+ * featureExtractor}. This method determines partition {@code data} to be {@link SimpleDatasetData}, but allows to
+ * use any desired type of partition {@code context}.
*
* @param datasetBuilder Dataset builder.
* @param envBuilder Learning environment builder.
@@ -193,11 +187,11 @@ public class DatasetFactory {
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleDataset<C> createSimpleDataset(
DatasetBuilder<K, V> datasetBuilder,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return create(
datasetBuilder,
envBuilder,
@@ -207,9 +201,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code partCtxBuilder} and
- * {@code featureExtractor}. This methods determines partition {@code data} to be {@link SimpleDatasetData}, but
- * allows to use any desired type of partition {@code context}.
+ * Creates a new instance of distributed {@link SimpleDataset} using the specified {@code partCtxBuilder} and {@code
+ * featureExtractor}. This method determines partition {@code data} to be {@link SimpleDatasetData}, but allows to
+ * use any desired type of partition {@code context}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
@@ -221,11 +215,12 @@ public class DatasetFactory {
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(Ignite ignite,
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleDataset<C> createSimpleDataset(
+ Ignite ignite,
IgniteCache<K, V> upstreamCache,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
envBuilder,
@@ -236,69 +231,66 @@ public class DatasetFactory {
/**
* Creates a new instance of distributed {@link SimpleLabeledDataset} using the specified {@code partCtxBuilder},
- * {@code featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be
- * {@link SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
+ * {@code featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be {@link
+ * SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
*
* @param datasetBuilder Dataset builder.
* @param envBuilder Learning environment builder.
* @param partCtxBuilder Partition {@code context} builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
DatasetBuilder<K, V> datasetBuilder,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
+ Vectorizer<K, V, CO, double[]> vectorizer) {
return create(
datasetBuilder,
envBuilder,
partCtxBuilder,
- new SimpleLabeledDatasetDataBuilder<>(featureExtractor, lbExtractor)
+ new SimpleLabeledDatasetDataBuilder<>(vectorizer)
).wrap(SimpleLabeledDataset::new);
}
/**
* Creates a new instance of distributed {@link SimpleLabeledDataset} using the specified {@code partCtxBuilder},
- * {@code featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be
- * {@link SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
+ * {@code featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be {@link
+ * SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
* @param envBuilder Learning environment builder.
* @param partCtxBuilder Partition {@code context} builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
Ignite ignite,
IgniteCache<K, V> upstreamCache,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
+ Vectorizer<K, V, CO, double[]> vectorizer) {
return createSimpleLabeledDataset(
new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
envBuilder,
partCtxBuilder,
- featureExtractor,
- lbExtractor
+ vectorizer
);
}
/**
* Creates a new instance of distributed {@link SimpleDataset} using the specified {@code featureExtractor}. This
- * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be
- * {@link SimpleDatasetData}.
+ * method determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be {@link
+ * SimpleDatasetData}.
*
* @param datasetBuilder Dataset builder.
* @param envBuilder Learning environment builder.
@@ -307,10 +299,10 @@ public class DatasetFactory {
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+ public static <K, V, CO extends Serializable> SimpleDataset<EmptyContext> createSimpleDataset(
DatasetBuilder<K, V> datasetBuilder,
LearningEnvironmentBuilder envBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
datasetBuilder,
envBuilder,
@@ -321,8 +313,8 @@ public class DatasetFactory {
/**
* Creates a new instance of distributed {@link SimpleDataset} using the specified {@code featureExtractor}. This
- * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be
- * {@link SimpleDatasetData}.
+ * method determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be {@link
+ * SimpleDatasetData}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
@@ -332,11 +324,11 @@ public class DatasetFactory {
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+ public static <K, V, CO extends Serializable> SimpleDataset<EmptyContext> createSimpleDataset(
Ignite ignite,
IgniteCache<K, V> upstreamCache,
LearningEnvironmentBuilder envBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
envBuilder,
@@ -346,8 +338,8 @@ public class DatasetFactory {
/**
* Creates a new instance of distributed {@link SimpleDataset} using the specified {@code featureExtractor}. This
- * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be
- * {@link SimpleDatasetData}.
+ * method determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be {@link
+ * SimpleDatasetData}.
*
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
@@ -356,10 +348,10 @@ public class DatasetFactory {
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(
+ public static <K, V, CO extends Serializable> SimpleDataset<EmptyContext> createSimpleDataset(
Ignite ignite,
IgniteCache<K, V> upstreamCache,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
LearningEnvironmentBuilder.defaultBuilder(),
@@ -374,23 +366,21 @@ public class DatasetFactory {
*
* @param datasetBuilder Dataset builder.
* @param envBuilder Learning environment builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
+ public static <K, V, CO extends Serializable> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
DatasetBuilder<K, V> datasetBuilder,
LearningEnvironmentBuilder envBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
+ Vectorizer<K, V, CO, double[]> vectorizer) {
return createSimpleLabeledDataset(
datasetBuilder,
envBuilder,
new EmptyContextBuilder<>(),
- featureExtractor,
- lbExtractor
+ vectorizer
);
}
@@ -402,22 +392,21 @@ public class DatasetFactory {
* @param ignite Ignite instance.
* @param upstreamCache Ignite Cache with {@code upstream} data.
* @param envBuilder Learning environment builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
+ public static <K, V, CO extends Serializable> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
Ignite ignite,
LearningEnvironmentBuilder envBuilder,
- IgniteCache<K, V> upstreamCache, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
+ IgniteCache<K, V> upstreamCache,
+ Vectorizer<K, V, CO, double[]> vectorizer) {
return createSimpleLabeledDataset(
new CacheBasedDatasetBuilder<>(ignite, upstreamCache),
envBuilder,
- featureExtractor,
- lbExtractor
+ vectorizer
);
}
@@ -451,9 +440,9 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of local {@link SimpleDataset} using the specified {@code partCtxBuilder} and
- * {@code featureExtractor}. This methods determines partition {@code data} to be {@link SimpleDatasetData}, but
- * allows to use any desired type of partition {@code context}.
+ * Creates a new instance of local {@link SimpleDataset} using the specified {@code partCtxBuilder} and {@code
+ * featureExtractor}. This method determines partition {@code data} to be {@link SimpleDatasetData}, but allows to
+ * use any desired type of partition {@code context}.
*
* @param upstreamMap {@code Map} with {@code upstream} data.
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
@@ -465,12 +454,12 @@ public class DatasetFactory {
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleDataset<C> createSimpleDataset(
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleDataset<C> createSimpleDataset(
Map<K, V> upstreamMap,
int partitions,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
new LocalDatasetBuilder<>(upstreamMap, partitions),
envBuilder,
@@ -480,39 +469,39 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of local {@link SimpleLabeledDataset} using the specified {@code partCtxBuilder},
- * {@code featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be
- * {@link SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
+ * Creates a new instance of local {@link SimpleLabeledDataset} using the specified {@code partCtxBuilder}, {@code
+ * featureExtractor} and {@code lbExtractor}. This method determines partition {@code data} to be {@link
+ * SimpleLabeledDatasetData}, but allows to use any desired type of partition {@code context}.
*
* @param upstreamMap {@code Map} with {@code upstream} data.
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
* @param envBuilder Learning environment builder.
* @param partCtxBuilder Partition {@code context} builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @param <C> Type of a partition {@code context}.
* @return Dataset.
*/
- public static <K, V, C extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
+ public static <K, V, C extends Serializable, CO extends Serializable> SimpleLabeledDataset<C> createSimpleLabeledDataset(
Map<K, V> upstreamMap,
int partitions,
LearningEnvironmentBuilder envBuilder,
PartitionContextBuilder<K, V, C> partCtxBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
+ Vectorizer<K, V, CO, double[]> vectorizer) {
return createSimpleLabeledDataset(
new LocalDatasetBuilder<>(upstreamMap, partitions),
envBuilder,
partCtxBuilder,
- featureExtractor, lbExtractor
+ vectorizer
);
}
/**
- * Creates a new instance of local {@link SimpleDataset} using the specified {@code featureExtractor}. This
- * methods determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be
- * {@link SimpleDatasetData}.
+ * Creates a new instance of local {@link SimpleDataset} using the specified {@code featureExtractor}. This method
+ * determines partition {@code context} to be {@link EmptyContext} and partition {@code data} to be {@link
+ * SimpleDatasetData}.
*
* @param upstreamMap {@code Map} with {@code upstream} data.
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
@@ -522,9 +511,10 @@ public class DatasetFactory {
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleDataset<EmptyContext> createSimpleDataset(Map<K, V> upstreamMap, int partitions,
+ public static <K, V, CO extends Serializable> SimpleDataset<EmptyContext> createSimpleDataset(Map<K, V> upstreamMap,
+ int partitions,
LearningEnvironmentBuilder envBuilder,
- IgniteBiFunction<K, V, Vector> featureExtractor) {
+ Vectorizer<K, V, CO, ?> featureExtractor) {
return createSimpleDataset(
new LocalDatasetBuilder<>(upstreamMap, partitions),
envBuilder,
@@ -533,28 +523,27 @@ public class DatasetFactory {
}
/**
- * Creates a new instance of local {@link SimpleLabeledDataset} using the specified {@code featureExtractor}
- * and {@code lbExtractor}. This methods determines partition {@code context} to be {@link EmptyContext} and
- * partition {@code data} to be {@link SimpleLabeledDatasetData}.
+ * Creates a new instance of local {@link SimpleLabeledDataset} using the specified {@code featureExtractor} and
+ * {@code lbExtractor}. This method determines partition {@code context} to be {@link EmptyContext} and partition
+ * {@code data} to be {@link SimpleLabeledDatasetData}.
*
* @param upstreamMap {@code Map} with {@code upstream} data.
* @param partitions Number of partitions {@code upstream} {@code Map} will be divided on.
* @param envBuilder Learning environment builder.
- * @param featureExtractor Feature extractor used to extract features and build {@link SimpleLabeledDatasetData}.
- * @param lbExtractor Label extractor used to extract labels and build {@link SimpleLabeledDatasetData}.
+ * @param vectorizer Upstream vectorizer used to extract features and labels and build {@link
+ * SimpleLabeledDatasetData}.
* @param <K> Type of a key in {@code upstream} data.
* @param <V> Type of a value in {@code upstream} data.
* @return Dataset.
*/
- public static <K, V> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(Map<K, V> upstreamMap,
+ public static <K, V, CO extends Serializable> SimpleLabeledDataset<EmptyContext> createSimpleLabeledDataset(
+ Map<K, V> upstreamMap,
LearningEnvironmentBuilder envBuilder,
- int partitions, IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
+ int partitions, Vectorizer<K, V, CO, double[]> vectorizer) {
return createSimpleLabeledDataset(
new LocalDatasetBuilder<>(upstreamMap, partitions),
envBuilder,
- featureExtractor,
- lbExtractor
+ vectorizer
);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/BucketMeta.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/BucketMeta.java
index 5dab662..a3c14ea 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/BucketMeta.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/BucketMeta.java
@@ -48,7 +48,7 @@ public class BucketMeta implements Serializable {
* Returns bucket id for feature value.
*
* @param val Value.
- * @return bucket id.
+ * @return Bucket id.
*/
public int getBucketId(Double val) {
if(featureMeta.isCategoricalFeature())
@@ -61,7 +61,7 @@ public class BucketMeta implements Serializable {
* Returns mean value by bucket id.
*
* @param bucketId Bucket id.
- * @return mean value of feature.
+ * @return Mean value of feature.
*/
public double bucketIdToValue(int bucketId) {
if(featureMeta.isCategoricalFeature())
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/Histogram.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/Histogram.java
index 6784af1..ef9a9fe 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/Histogram.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/Histogram.java
@@ -37,20 +37,20 @@ public interface Histogram<T, H extends Histogram<T, H>> extends Serializable {
/**
*
- * @return bucket ids.
+ * @return Bucket ids.
*/
public Set<Integer> buckets();
/**
*
* @param bucketId Bucket id.
- * @return value in according to bucket id.
+ * @return Value corresponding to the bucket id.
*/
public Optional<Double> getValue(Integer bucketId);
/**
* @param other Other histogram.
- * @return sum of this and other histogram.
+ * @return Sum of this and other histogram.
*/
public H plus(H other);
@@ -58,7 +58,7 @@ public interface Histogram<T, H extends Histogram<T, H>> extends Serializable {
* Compares histogram with other and returns true if they are equals
*
* @param other Other histogram.
- * @return true if histograms are equal.
+ * @return True if histograms are equal.
*/
public boolean isEqualTo(H other);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/ObjectHistogram.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/ObjectHistogram.java
index 697e79d..7bc12e2 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/ObjectHistogram.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/ObjectHistogram.java
@@ -17,11 +17,7 @@
package org.apache.ignite.ml.dataset.feature;
-import java.util.HashSet;
-import java.util.Map;
-import java.util.Optional;
-import java.util.Set;
-import java.util.TreeMap;
+import java.util.*;
/**
* Basic implementation of {@link Histogram} that implements also {@link DistributionComputer}.
@@ -118,14 +114,14 @@ public abstract class ObjectHistogram<T> implements Histogram<T, ObjectHistogram
* Counter mapping.
*
* @param obj Object.
- * @return counter.
+ * @return Counter.
*/
public abstract Double mapToCounter(T obj);
/**
* Creates an instance of ObjectHistogram from child class.
*
- * @return object histogram.
+ * @return Object histogram.
*/
public abstract ObjectHistogram<T> newInstance();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/ExtractionUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/ExtractionUtils.java
new file mode 100644
index 0000000..4728f5d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/ExtractionUtils.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.feature.extractor;
+
+import java.io.Serializable;
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Class aggregates helper classes and shortcut-classes with default types and behaviour for Vectorizers.
+ */
+public class ExtractionUtils {
+ /**
+ * Vectorizer with double-label containing on same level as feature values.
+ *
+ * @param <K> Key type.
+ * @param <V> Value type
+ * @param <C> Type of coordinate.
+ */
+ public abstract static class DefaultLabelVectorizer<K, V, C extends Serializable> extends Vectorizer<K, V, C, Double> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 2876703640636013770L;
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public DefaultLabelVectorizer(C... coords) {
+ super(coords);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double label(C coord, K key, V value) {
+ return feature(coord, key, value);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double zero() {
+ return 0.0;
+ }
+ }
+
+ /**
+ * Vectorizer with String-coordinates.
+ *
+ * @param <K> Type of key.
+ * @param <V> Type of value.
+ */
+ public abstract static class StringCoordVectorizer<K, V> extends DefaultLabelVectorizer<K, V, String> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 6989473570977667636L;
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public StringCoordVectorizer(String... coords) {
+ super(coords);
+ }
+ }
+
+ /**
+ * Vectorizer with integer coordinates.
+ *
+ * @param <K> Type of key.
+ * @param <V> Type of value.
+ */
+ public abstract static class IntCoordVectorizer<K, V> extends DefaultLabelVectorizer<K, V, Integer> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = -1734141133396507699L;
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public IntCoordVectorizer(Integer... coords) {
+ super(coords);
+ }
+ }
+
+ /**
+ * Vectorizer extracting vectors from array-like structure with finite size and integer coordinates.
+ *
+ * @param <K> Type of key.
+ * @param <V> Type of value.
+ */
+ public abstract static class ArrayLikeVectorizer<K, V> extends IntCoordVectorizer<K, V> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 5383770258177577358L;
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public ArrayLikeVectorizer(Integer... coords) {
+ super(coords);
+ }
+
+ /**
+ * Size of array-like structure of upstream object.
+ *
+ * @param key Key.
+ * @param value Value.
+ * @return Size.
+ */
+ protected abstract int sizeOf(K key, V value);
+
+ /** {@inheritDoc} */
+ @Override protected List<Integer> allCoords(K key, V value) {
+ return IntStream.range(0, sizeOf(key, value)).boxed().collect(Collectors.toList());
+ }
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/Vectorizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/Vectorizer.java
new file mode 100644
index 0000000..7c4ebb1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/Vectorizer.java
@@ -0,0 +1,307 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.feature.extractor;
+
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.internal.util.typedef.internal.A;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.HashSet;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Class for extracting labeled vectors from upstream. This is an abstract class providing API for extracting feature
+ * and label values by "coordinates" of them from upstream objects. For example {@link BinaryObject} can be upstream
+ * object and coordinates for them are names of fields with double-values.
+ *
+ * @param <K> Type of keys in upstream.
+ * @param <V> Type of values in upstream.
+ * @param <C> Type of "coordinate" - index of feature value in upstream object.
+ * @param <L> Type of label for resulting vectors.
+ */
+public abstract class Vectorizer<K, V, C extends Serializable, L> implements FeatureLabelExtractor<K, V, L>, Serializable {
+ /** Label coordinate shortcut. */
+ private LabelCoordinate lbCoordinateShortcut = null;
+
+ /** Serial version uid. */
+ private static final long serialVersionUID = 4301406952131379459L;
+
+ /** If useAllValues == true then Vectorizer extract all fields as features from upstream object (except label). */
+ private final boolean useAllValues;
+
+ /** Extraction coordinates. */
+ private List<C> extractionCoordinates;
+
+ /** Label coordinate. */
+ private C labelCoord;
+
+ /**
+ * Extracts labeled vector from upstream object.
+ *
+ * @param key Key.
+ * @param value Value.
+ * @return Vector.
+ */
+ public LabeledVector<L> apply(K key, V value) {
+ L lbl = isLabeled() ? label(labelCoord(key, value), key, value) : zero();
+
+ List<C> allCoords = null;
+ if (useAllValues) {
+ allCoords = allCoords(key, value).stream()
+ .filter(coord -> !coord.equals(labelCoord) && !excludedCoords.contains(coord))
+ .collect(Collectors.toList());
+ }
+
+ int vectorLength = useAllValues ? allCoords.size() : extractionCoordinates.size();
+ A.ensure(vectorLength >= 0, "vectorLength >= 0");
+
+ List<C> coordinatesForExtraction = useAllValues ? allCoords : extractionCoordinates;
+ Vector vector = createVector(vectorLength);
+ for (int i = 0; i < coordinatesForExtraction.size(); i++) {
+ Double feature = feature(coordinatesForExtraction.get(i), key, value);
+ if (feature != null)
+ vector.set(i, feature);
+ }
+ return new LabeledVector<>(vector, lbl);
+ }
+
+ /** Excluded coordinates. */
+ private HashSet<C> excludedCoords = new HashSet<>();
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates for feature extraction. If array is empty then Vectorizer will extract all fields from
+ * upstream object.
+ */
+ public Vectorizer(C... coords) {
+ extractionCoordinates = Arrays.asList(coords);
+ this.useAllValues = coords.length == 0;
+ }
+
+ /**
+ * @return true if label in vector is valid.
+ */
+ private boolean isLabeled() {
+ return labelCoord != null || lbCoordinateShortcut != null;
+ }
+
+ /**
+ * Evaluates label coordinate if need.
+ *
+ * @param key Key.
+ * @param value Value.
+ * @return label coordinate.
+ */
+ private C labelCoord(K key, V value) {
+ A.ensure(isLabeled(), "isLabeled");
+ if (labelCoord != null)
+ return labelCoord;
+ else {
+ List<C> allCoords = allCoords(key, value);
+ A.ensure(!allCoords.isEmpty(), "!allCoords.isEmpty()");
+
+ switch (lbCoordinateShortcut) {
+ case FIRST:
+ labelCoord = allCoords.get(0);
+ break;
+ case LAST:
+ labelCoord = allCoords.get(allCoords.size() - 1);
+ break;
+ default:
+ throw new IllegalArgumentException();
+ }
+
+ return labelCoord;
+ }
+ }
+
+ /**
+ * Sets label coordinate for Vectorizer. By default it equals null and zero() will be used as label value.
+ *
+ * @param labelCoord Label coordinate.
+ * @return this.
+ */
+ public Vectorizer<K, V, C, L> labeled(C labelCoord) {
+ this.labelCoord = labelCoord;
+ this.lbCoordinateShortcut = null;
+ return this;
+ }
+
+ /**
+ * Sets label coordinate for Vectorizer. By default it equals null and zero() will be used as label value.
+ *
+ * @param labelCoord Label coordinate.
+ * @return this.
+ */
+ public Vectorizer<K, V, C, L> labeled(LabelCoordinate labelCoord) {
+ this.lbCoordinateShortcut = labelCoord;
+ this.labelCoord = null;
+ return this;
+ }
+
+ /**
+ * Exclude these coordinates from result vector.
+ *
+ * @param coords Coordinates.
+ * @return this.
+ */
+ public Vectorizer<K, V, C, L> exclude(C... coords) {
+ this.excludedCoords.addAll(Arrays.asList(coords));
+ return this;
+ }
+
+ /**
+ * Map vectorizer answer. This method should be called after creating basic vectorizer.
+ * NOTE: function "func" should be on ignite servers.
+ *
+ * @param func Mapper.
+ * @param <L1> Type of new label.
+ * @return Mapped vectorizer.
+ */
+ public <L1> Vectorizer<K, V, C, L1> map(IgniteFunction<LabeledVector<L>, LabeledVector<L1>> func) {
+ return new MappedVectorizer<>(this, func);
+ }
+
+ /**
+ * Shortcuts for coordinates in feature vector.
+ */
+ public enum LabelCoordinate {
+ /** First. */FIRST,
+ /** Last. */LAST
+ }
+
+ /** {@inheritDoc} */
+ @Override public LabeledVector<L> extract(K k, V v) {
+ return apply(k, v);
+ }
+
+ /**
+ * Extracts feature value by given coordinate.
+ *
+ * @param coord Coordinate.
+ * @param key Key.
+ * @param value Value.
+ * @return feature value.
+ */
+ protected abstract Double feature(C coord, K key, V value);
+
+ /**
+ * Extract label value by given coordinate.
+ *
+ * @param coord Coordinate.
+ * @param key Key.
+ * @param value Value.
+ * @return label value.
+ */
+ protected abstract L label(C coord, K key, V value);
+
+ /**
+ * Returns default label value for unlabeled data.
+ *
+ * @return label value.
+ */
+ protected abstract L zero();
+
+ /**
+ * Returns list of all coordinate with feature values.
+ *
+ * @param key Key.
+ * @param value Value.
+ * @return all coordinates list.
+ */
+ protected abstract List<C> allCoords(K key, V value);
+
+ /**
+ * Create an instance of vector.
+ *
+ * @param size Vector size.
+ * @return vector.
+ */
+ protected Vector createVector(int size) {
+ return new DenseVector(size);
+ }
+
+ /**
+ * @param <K> Type of key.
+ * @param <V> Type of value.
+ * @param <C> Type of coordinates.
+ * @param <L0> Type of original label.
+ * @param <L1> Type of mapped label.
+ */
+ private static class MappedVectorizer<K, V, C extends Serializable, L0, L1> extends VectorizerAdapter<K, V, C, L1> {
+ /** Original vectorizer. */
+ protected final Vectorizer<K, V, C, L0> original;
+
+ /** Vectors mapping. */
+ private final IgniteFunction<LabeledVector<L0>, LabeledVector<L1>> mapping;
+
+ /**
+ * Creates an instance of MappedVectorizer.
+ */
+ public MappedVectorizer(Vectorizer<K, V, C, L0> original,
+ IgniteFunction<LabeledVector<L0>, LabeledVector<L1>> andThen) {
+
+ this.original = original;
+ this.mapping = andThen;
+ }
+
+ /** {@inheritDoc} */
+ @Override public LabeledVector<L1> apply(K key, V value) {
+ LabeledVector<L0> origVec = original.apply(key, value);
+ return mapping.apply(origVec);
+ }
+ }
+
+ /**
+ * Utility class for convenient overridings.
+ *
+ * @param <K> Type of key.
+ * @param <V> Type of value.
+ * @param <C> Type of coordinate.
+ * @param <L> Type od label.
+ */
+ public static class VectorizerAdapter<K, V, C extends Serializable, L> extends Vectorizer<K, V, C, L> {
+ /** {@inheritDoc} */
+ @Override protected Double feature(C coord, K key, V value) {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L label(C coord, K key, V value) {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L zero() {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected List<C> allCoords(K key, V value) {
+ throw new IllegalStateException();
+ }
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/ArraysVectorizer.java
similarity index 51%
copy from modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/ArraysVectorizer.java
index 72fff30..bd726fb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/loss/Loss.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/ArraysVectorizer.java
@@ -15,31 +15,35 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.composition.boosting.loss;
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
-import java.io.Serializable;
+import org.apache.ignite.ml.dataset.feature.extractor.ExtractionUtils;
/**
- * Loss interface of computing error or gradient of error on specific row in dataset.
+ * Vectorizer on arrays of doubles.
+ *
+ * @param <K> Key type.
*/
-public interface Loss extends Serializable {
- /**
- * Error value for model answer.
- *
- * @param sampleSize Sample size.
- * @param lb Label.
- * @param mdlAnswer Model answer.
- * @return error value.
- */
- public double error(long sampleSize, double lb, double mdlAnswer);
+public class ArraysVectorizer<K> extends ExtractionUtils.ArrayLikeVectorizer<K, double[]> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = -1177109334215177722L;
/**
- * Error gradient value for model answer.
+ * Creates an instance of Vectorizer.
*
- * @param sampleSize Sample size.
- * @param lb Label.
- * @param mdlAnswer Model answer.
- * @return error value.
+ * @param coords Coordinates.
*/
- public double gradient(long sampleSize, double lb, double mdlAnswer);
+ public ArraysVectorizer(Integer ... coords) {
+ super(coords);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double feature(Integer coord, K key, double[] value) {
+ return value[coord];
+ }
+
+ /** {@inheritDoc} */
+ @Override protected int sizeOf(K key, double[] value) {
+ return value.length;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/BinaryObjectVectorizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/BinaryObjectVectorizer.java
new file mode 100644
index 0000000..7326589
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/BinaryObjectVectorizer.java
@@ -0,0 +1,142 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
+
+import org.apache.ignite.binary.BinaryObject;
+import org.apache.ignite.internal.binary.BinaryUtils;
+import org.apache.ignite.internal.binary.GridBinaryMarshaller;
+import org.apache.ignite.ml.dataset.feature.extractor.ExtractionUtils;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.impl.SparseVector;
+
+import java.util.HashMap;
+import java.util.List;
+import java.util.stream.Collectors;
+
+/**
+ * Vectorizer on binary objects.
+ *
+ * @param <K> Type of key.
+ */
+public class BinaryObjectVectorizer<K> extends ExtractionUtils.StringCoordVectorizer<K, BinaryObject> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 2152161240934492838L;
+
+ /** Object for denoting default value of feature mapping. */
+ public static final String DEFAULT_VALUE = "DEFAULT";
+ /** Mapping for feature with non-number values. */
+ private HashMap<String, HashMap<Object, Double>> featureValueMappings = new HashMap<>();
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public BinaryObjectVectorizer(String... coords) {
+ super(coords);
+ }
+
+ /**
+ * Sets values mapping for feature.
+ *
+ * @param coord Feature coordinate.
+ * @param valuesMapping Mapping.
+ * @return this.
+ */
+ public BinaryObjectVectorizer withFeature(String coord, Mapping valuesMapping) {
+ featureValueMappings.put(coord, valuesMapping.toMap());
+ return this;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double feature(String coord, K key, BinaryObject value) {
+ HashMap<Object, Double> mapping = featureValueMappings.get(coord);
+ if (mapping != null)
+ return mapping.getOrDefault(value.field(coord), mapping.get(DEFAULT_VALUE));
+
+ Number val = value.field(coord);
+ return val != null ? val.doubleValue() : null;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected List<String> allCoords(K key, BinaryObject value) {
+ return value.type().fieldNames().stream()
+ .filter(fname -> fieldIsDouble(value, fname))
+ .collect(Collectors.toList());
+ }
+
+ /**
+ * @param value Value.
+ * @param fname Fname.
+ * @return true if field in binary object has double type.
+ */
+ private boolean fieldIsDouble(BinaryObject value, String fname) {
+ return value.type().fieldTypeName(fname).equals(BinaryUtils.fieldTypeName(GridBinaryMarshaller.DOUBLE));
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Vector createVector(int size) {
+ return new SparseVector(size, SparseVector.RANDOM_ACCESS_MODE);
+ }
+
+ /** Feature values mapping for non-number features. */
+ public static class Mapping {
+ /** Mapping. */
+ private HashMap<Object, Double> value = new HashMap<>();
+
+ /**
+ * Creates an instance of Mapping.
+ */
+ public static Mapping create() {
+ return new Mapping();
+ }
+
+ /**
+ * Add mapping.
+ *
+ * @param from From value.
+ * @param to To double value.
+ * @return this.
+ */
+ public Mapping map(Object from, Double to) {
+ this.value.put(from, to);
+ return this;
+ }
+
+ /**
+ * Default value for new feature values.
+ *
+ * @param value Default value.
+ * @return this.
+ */
+ public Mapping defaultValue(Double value) {
+ this.value.put(DEFAULT_VALUE, value);
+ return this;
+ }
+
+ /**
+ * Converts mapping to HashMap.
+ */
+ private HashMap<Object, Double> toMap() {
+ if (!value.containsKey(DEFAULT_VALUE))
+ value.put(DEFAULT_VALUE, null);
+
+ return value;
+ }
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/DummyVectorizer.java
similarity index 52%
copy from modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/DummyVectorizer.java
index fe236c3..a660eab 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/DummyVectorizer.java
@@ -15,23 +15,36 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.math.stat;
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
-import java.io.Serializable;
+import org.apache.ignite.ml.dataset.feature.extractor.ExtractionUtils;
import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
- * Interface for distributions.
+ * Vectorizer on Vector.
+ *
+ * @param <K> Type of key.
*/
-public interface Distribution extends Serializable {
- /**
- * @param x Vector.
- * @return probability of vector.
- */
- public double prob(Vector x);
+public class DummyVectorizer<K> extends ExtractionUtils.ArrayLikeVectorizer<K, Vector> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = -6225354615212148224L;
/**
- * @return dimension of vector space.
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
*/
- public int dimension();
+ public DummyVectorizer(Integer ... coords) {
+ super(coords);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double feature(Integer coord, K key, Vector value) {
+ return value.get(coord);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected int sizeOf(K key, Vector value) {
+ return value.size();
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/FeatureLabelExtractorWrapper.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/FeatureLabelExtractorWrapper.java
new file mode 100644
index 0000000..2407ce1
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/FeatureLabelExtractorWrapper.java
@@ -0,0 +1,84 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
+
+import org.apache.ignite.ml.composition.CompositionUtils;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.math.functions.IgniteBiFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+
+import java.io.Serializable;
+import java.util.List;
+
+/**
+ * Temporary class for Features/Label extracting.
+ */
+public class FeatureLabelExtractorWrapper<K, V, C extends Serializable, L> extends Vectorizer<K, V, C, L> {
+ /** Original extractor. */
+ private final FeatureLabelExtractor<K, V, L> extractor;
+
+ /**
+ * Creates an instance of FeatureLabelExtractorWrapper.
+ *
+ * @param extractor Features and labels extractor.
+ */
+ public FeatureLabelExtractorWrapper(FeatureLabelExtractor<K, V, L> extractor) {
+ this.extractor = extractor;
+ }
+
+ /**
+ * @param featuresEx Method for feature vector extracting.
+ * @return wrapper.
+ */
+ public static <K, V, C extends Serializable> FeatureLabelExtractorWrapper<K, V, C, Double> wrap(IgniteBiFunction<K, V, Vector> featuresEx) {
+ return new FeatureLabelExtractorWrapper<>((k, v) -> featuresEx.apply(k, v).labeled(0.0));
+ }
+
+ public static <K, V, C extends Serializable, L> FeatureLabelExtractorWrapper<K, V, C, L> wrap(IgniteBiFunction<K, V, Vector> featuresEx,
+ IgniteBiFunction<K, V, L> lbExtractor) {
+
+ return new FeatureLabelExtractorWrapper<>(CompositionUtils.asFeatureLabelExtractor(featuresEx, lbExtractor));
+ }
+
+ /** {@inheritDoc} */
+ @Override public LabeledVector<L> apply(K key, V value) {
+ return extractor.extract(key, value);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double feature(C coord, K key, V value) {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L label(C coord, K key, V value) {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L zero() {
+ throw new IllegalStateException();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected List<C> allCoords(K key, V value) {
+ throw new IllegalStateException();
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/LabeledDummyVectorizer.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/LabeledDummyVectorizer.java
new file mode 100644
index 0000000..2c1e704
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/LabeledDummyVectorizer.java
@@ -0,0 +1,65 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
+
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.structures.LabeledVector;
+
+import java.util.List;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+
+/**
+ * Vectorizer on LabeledVector.
+ *
+ * @param <K> Type of key.
+ */
+public class LabeledDummyVectorizer<K, L> extends Vectorizer<K, LabeledVector<L>, Integer, L> {
+ /** Serial version uid. */
+ private static final long serialVersionUID = -6225354615212148224L;
+
+ /**
+ * Creates an instance of Vectorizer.
+ *
+ * @param coords Coordinates.
+ */
+ public LabeledDummyVectorizer(Integer ... coords) {
+ super(coords);
+ labeled(-1);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected Double feature(Integer coord, K key, LabeledVector<L> value) {
+ return value.features().get(coord);
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L label(Integer coord, K key, LabeledVector<L> value) {
+ return value.label();
+ }
+
+ /** {@inheritDoc} */
+ @Override protected L zero() {
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override protected List<Integer> allCoords(K key, LabeledVector<L> value) {
+ return IntStream.range(0, value.features().size()).boxed().collect(Collectors.toList());
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/package-info.java
similarity index 66%
copy from modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/package-info.java
index fe236c3..a5ee23b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/impl/package-info.java
@@ -1,4 +1,5 @@
/*
+
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -15,23 +16,8 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.math.stat;
-
-import java.io.Serializable;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-
/**
- * Interface for distributions.
+ * <!-- Package description. -->
+ * Package contains default implementations of {@link org.apache.ignite.ml.dataset.feature.extractor.Vectorizer}.
*/
-public interface Distribution extends Serializable {
- /**
- * @param x Vector.
- * @return probability of vector.
- */
- public double prob(Vector x);
-
- /**
- * @return dimension of vector space.
- */
- public int dimension();
-}
+package org.apache.ignite.ml.dataset.feature.extractor.impl;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/package-info.java
similarity index 66%
copy from modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/package-info.java
index fe236c3..fdc2f8a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/Distribution.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/feature/extractor/package-info.java
@@ -1,4 +1,5 @@
/*
+
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
@@ -15,23 +16,9 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.math.stat;
-
-import java.io.Serializable;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-
/**
- * Interface for distributions.
+ * <!-- Package description. -->
+ * Package for upstream object vectorizations. This package contains {@link org.apache.ignite.ml.dataset.feature.extractor.Vectorizer}
+ * implementations allowing extract feature and label values from upstream object.
*/
-public interface Distribution extends Serializable {
- /**
- * @param x Vector.
- * @return probability of vector.
- */
- public double prob(Vector x);
-
- /**
- * @return dimension of vector space.
- */
- public int dimension();
-}
+package org.apache.ignite.ml.dataset.feature.extractor;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java
index bd5ef1c..c127c30 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetBuilder.java
@@ -17,18 +17,21 @@
package org.apache.ignite.ml.dataset.impl.bootstrapping;
-import java.util.Arrays;
-import java.util.Iterator;
import org.apache.commons.math3.distribution.PoissonDistribution;
import org.apache.commons.math3.random.Well19937c;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironment;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.Iterator;
+
/**
* Builder for bootstrapped dataset. Bootstrapped dataset consist of several subsamples created in according to random
* sampling with replacements selection of vectors from original dataset. This realization uses
@@ -58,10 +61,7 @@ public class BootstrappedDatasetBuilder<K,V> implements PartitionDataBuilder<K,V
* @param samplesCnt Samples count.
* @param subsampleSize Subsample size.
*/
- public BootstrappedDatasetBuilder(FeatureLabelExtractor<K, V, Double> extractor,
- int samplesCnt,
- double subsampleSize) {
-
+ public <C extends Serializable> BootstrappedDatasetBuilder(Vectorizer<K, V, C, Double> extractor, int samplesCnt, double subsampleSize) {
this.extractor = extractor;
this.samplesCnt = samplesCnt;
this.subsampleSize = subsampleSize;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetPartition.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetPartition.java
index 2155d1a..ac961bb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetPartition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedDatasetPartition.java
@@ -17,9 +17,10 @@
package org.apache.ignite.ml.dataset.impl.bootstrapping;
+import org.jetbrains.annotations.NotNull;
+
import java.util.Arrays;
import java.util.Iterator;
-import org.jetbrains.annotations.NotNull;
/**
* Partition of bootstrapped vectors.
@@ -50,7 +51,7 @@ public class BootstrappedDatasetPartition implements AutoCloseable, Iterable<Boo
/**
* Returns rows count.
*
- * @return rows count.
+ * @return Rows count.
*/
public int getRowsCount() {
return vectors.length;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java
index d4ecfb6..5bffe9a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/impl/bootstrapping/BootstrappedVector.java
@@ -17,12 +17,13 @@
package org.apache.ignite.ml.dataset.impl.bootstrapping;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.Arrays;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
-import org.apache.ignite.ml.structures.LabeledVector;
/**
* Represents vector with repetitions counters for subsamples in bootstrapped dataset.
@@ -48,7 +49,7 @@ public class BootstrappedVector extends LabeledVector<Double> {
}
/**
- * @return repetitions counters vector.
+ * @return Repetitions counters vector.
*/
public int[] counters() {
return counters;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
index 5273fa6..a815a9a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/FeatureMatrixWithLabelsOnHeapDataBuilder.java
@@ -17,43 +17,39 @@
package org.apache.ignite.ml.dataset.primitive;
-import java.io.Serializable;
-import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
+import java.io.Serializable;
+import java.util.Iterator;
+
/**
* A partition {@code data} builder that makes {@link DecisionTreeData}.
*
* @param <K> Type of a key in <tt>upstream</tt> data.
* @param <V> Type of a value in <tt>upstream</tt> data.
* @param <C> Type of a partition <tt>context</tt>.
+ * @param <CO> Type of COordinate for vectorizer.
*/
-public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializable>
+public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializable, CO extends Serializable>
implements PartitionDataBuilder<K, V, C, FeatureMatrixWithLabelsOnHeapData> {
/** Serial version uid. */
private static final long serialVersionUID = 6273736987424171813L;
- /** Function that extracts features from an {@code upstream} data. */
- private final IgniteBiFunction<K, V, Vector> featureExtractor;
-
- /** Function that extracts labels from an {@code upstream} data. */
- private final IgniteBiFunction<K, V, Double> lbExtractor;
+ /** Function that extracts features and labels from an {@code upstream} data. */
+ private final Vectorizer<K, V, CO, Double> vectorizer;
/**
* Constructs a new instance of decision tree data builder.
*
- * @param featureExtractor Function that extracts features from an {@code upstream} data.
- * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+ * @param vectorizer Function that extracts features with labels from an {@code upstream} data.
*/
- public FeatureMatrixWithLabelsOnHeapDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, Double> lbExtractor) {
- this.featureExtractor = featureExtractor;
- this.lbExtractor = lbExtractor;
+ public FeatureMatrixWithLabelsOnHeapDataBuilder(Vectorizer<K, V, CO, Double> vectorizer) {
+ this.vectorizer = vectorizer;
}
/** {@inheritDoc} */
@@ -69,9 +65,9 @@ public class FeatureMatrixWithLabelsOnHeapDataBuilder<K, V, C extends Serializab
while (upstreamData.hasNext()) {
UpstreamEntry<K, V> entry = upstreamData.next();
- features[ptr] = featureExtractor.apply(entry.getKey(), entry.getValue()).asArray();
-
- labels[ptr] = lbExtractor.apply(entry.getKey(), entry.getValue());
+ LabeledVector<Double> labeledVector = vectorizer.apply(entry.getKey(), entry.getValue());
+ features[ptr] = labeledVector.features().asArray();
+ labels[ptr] = labeledVector.label();
ptr++;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java
index b14d8a2..df315e0 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleDatasetDataBuilder.java
@@ -17,36 +17,38 @@
package org.apache.ignite.ml.dataset.primitive.builder.data;
-import java.io.Serializable;
-import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import java.io.Serializable;
+import java.util.Iterator;
+
/**
* A partition {@code data} builder that makes {@link SimpleDatasetData}.
*
* @param <K> Type of a key in <tt>upstream</tt> data.
* @param <V> Type of a value in <tt>upstream</tt> data.
* @param <C> Type of a partition <tt>context</tt>.
+ * @param <CO> Type of COordinate for vectorizer.
*/
-public class SimpleDatasetDataBuilder<K, V, C extends Serializable>
+public class SimpleDatasetDataBuilder<K, V, C extends Serializable, CO extends Serializable>
implements PartitionDataBuilder<K, V, C, SimpleDatasetData> {
/** */
private static final long serialVersionUID = 756800193212149975L;
/** Function that extracts features from an {@code upstream} data. */
- private final IgniteBiFunction<K, V, Vector> featureExtractor;
+ private final Vectorizer<K, V, CO, ?> featureExtractor;
/**
* Construct a new instance of partition {@code data} builder that makes {@link SimpleDatasetData}.
*
* @param featureExtractor Function that extracts features from an {@code upstream} data.
*/
- public SimpleDatasetDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor) {
+ public SimpleDatasetDataBuilder(Vectorizer<K, V, CO, ?> featureExtractor) {
this.featureExtractor = featureExtractor;
}
@@ -61,7 +63,7 @@ public class SimpleDatasetDataBuilder<K, V, C extends Serializable>
int ptr = 0;
while (upstreamData.hasNext()) {
UpstreamEntry<K, V> entry = upstreamData.next();
- Vector row = featureExtractor.apply(entry.getKey(), entry.getValue());
+ Vector row = featureExtractor.apply(entry.getKey(), entry.getValue()).features();
if (cols < 0) {
cols = row.size();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java
index 48166ee..364660a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/dataset/primitive/builder/data/SimpleLabeledDatasetDataBuilder.java
@@ -17,14 +17,16 @@
package org.apache.ignite.ml.dataset.primitive.builder.data;
-import java.io.Serializable;
-import java.util.Iterator;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
import org.apache.ignite.ml.dataset.UpstreamEntry;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.data.SimpleLabeledDatasetData;
import org.apache.ignite.ml.environment.LearningEnvironment;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.structures.LabeledVector;
+
+import java.io.Serializable;
+import java.util.Iterator;
/**
* A partition {@code data} builder that makes {@link SimpleLabeledDatasetData}.
@@ -33,27 +35,21 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
* @param <V> Type of a value in <tt>upstream</tt> data.
* @param <C> type of a partition <tt>context</tt>.
*/
-public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
+public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable, CO extends Serializable>
implements PartitionDataBuilder<K, V, C, SimpleLabeledDatasetData> {
/** */
private static final long serialVersionUID = 3678784980215216039L;
- /** Function that extracts features from an {@code upstream} data. */
- private final IgniteBiFunction<K, V, Vector> featureExtractor;
-
- /** Function that extracts labels from an {@code upstream} data. */
- private final IgniteBiFunction<K, V, double[]> lbExtractor;
+ /** Function that extracts labeled vectors from an {@code upstream} data. */
+ private final Vectorizer<K, V, CO, double[]> vectorizer;
/**
* Constructs a new instance of partition {@code data} builder that makes {@link SimpleLabeledDatasetData}.
*
- * @param featureExtractor Function that extracts features from an {@code upstream} data.
- * @param lbExtractor Function that extracts labels from an {@code upstream} data.
+ * @param vectorizer Function that extracts labeled vectors from an {@code upstream} data.
*/
- public SimpleLabeledDatasetDataBuilder(IgniteBiFunction<K, V, Vector> featureExtractor,
- IgniteBiFunction<K, V, double[]> lbExtractor) {
- this.featureExtractor = featureExtractor;
- this.lbExtractor = lbExtractor;
+ public SimpleLabeledDatasetDataBuilder(Vectorizer<K, V, CO, double[]> vectorizer) {
+ this.vectorizer = vectorizer;
}
/** {@inheritDoc} */
@@ -71,7 +67,8 @@ public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
while (upstreamData.hasNext()) {
UpstreamEntry<K, V> entry = upstreamData.next();
- Vector featureRow = featureExtractor.apply(entry.getKey(), entry.getValue());
+ LabeledVector<double[]> labeledVector = vectorizer.apply(entry.getKey(), entry.getValue());
+ Vector featureRow = labeledVector.features();
if (featureCols < 0) {
featureCols = featureRow.size();
@@ -84,7 +81,7 @@ public class SimpleLabeledDatasetDataBuilder<K, V, C extends Serializable>
for (int i = 0; i < featureCols; i++)
features[Math.toIntExact(i * upstreamDataSize) + ptr] = featureRow.get(i);
- double[] lbRow = lbExtractor.apply(entry.getKey(), entry.getValue());
+ double[] lbRow = labeledVector.label();
if (lbCols < 0) {
lbCols = lbRow.length;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java
index 940d8cf..78886bc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/logging/ConsoleLogger.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.environment.logging;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
@@ -81,7 +82,7 @@ public class ConsoleLogger implements MLLogger {
/**
* ConsoleLogger factory.
*/
- private static class Factory implements MLLogger.Factory {
+ public static class Factory implements MLLogger.Factory {
/** Serial version uuid. */
private static final long serialVersionUID = 5864605548782107893L;
@@ -101,5 +102,14 @@ public class ConsoleLogger implements MLLogger {
@Override public <T> MLLogger create(Class<T> targetCls) {
return new ConsoleLogger(maxVerboseLevel, targetCls.getName());
}
+
+ /** Low. */
+ public static final IgniteFunction<Integer, MLLogger.Factory> LOW = part -> new Factory(VerboseLevel.LOW);
+
+ /** High. */
+ public static final IgniteFunction<Integer, MLLogger.Factory> HIGH = part -> new Factory(VerboseLevel.HIGH);
+
+ /** Off. */
+ public static final IgniteFunction<Integer, MLLogger.Factory> OFF = part -> new Factory(VerboseLevel.OFF);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
index 329ce89..f27281f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/environment/parallelism/ParallelismStrategy.java
@@ -17,9 +17,11 @@
package org.apache.ignite.ml.environment.parallelism;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.functions.IgniteSupplier;
+
import java.util.ArrayList;
import java.util.List;
-import org.apache.ignite.ml.math.functions.IgniteSupplier;
/**
* Specifies the behaviour of processes in ML-algorithms that can may be parallelized such as parallel learning in
@@ -54,4 +56,10 @@ public interface ParallelismStrategy {
results.add(submit(task));
return results;
}
+
+ /** On default pool. */
+ public static IgniteFunction<Integer, Type> ON_DEFAULT_POOL = part -> Type.ON_DEFAULT_POOL;
+
+ /** No parallelism. */
+ public static IgniteFunction<Integer, Type> NO_PARALLELISM = part -> Type.NO_PARALLELISM;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Gene.java b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Gene.java
index af1af3c..e6bef2b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Gene.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/genetic/Gene.java
@@ -17,9 +17,10 @@
package org.apache.ignite.ml.genetic;
-import java.util.concurrent.atomic.AtomicLong;
import org.apache.ignite.cache.query.annotations.QuerySqlField;
+import java.util.concurrent.atomic.AtomicLong;
+
/**
* Represents the discrete parts of a potential solution (ie: Chromosome)
*
@@ -55,7 +56,7 @@ public class Gene {
}
/**
- * @return value for Gene
+ * @return Value for Gene
*/
public Object getVal() {
return val;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
index 8239ebd..a330344 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/KNNUtils.java
@@ -20,15 +20,16 @@ package org.apache.ignite.ml.knn;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.math.functions.IgniteBiFunction;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
import org.jetbrains.annotations.Nullable;
+import java.io.Serializable;
+
/**
* Helper class for KNNRegression.
*/
@@ -38,18 +39,14 @@ public class KNNUtils {
*
* @param envBuilder Learning environment builder.
* @param datasetBuilder Dataset builder.
- * @param featureExtractor Feature extractor.
- * @param lbExtractor Label extractor.
+ * @param vectorizer Upstream vectorizer.
* @return Dataset.
*/
- @Nullable public static <K, V> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(
+ @Nullable public static <K, V, C extends Serializable> Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> buildDataset(
LearningEnvironmentBuilder envBuilder,
- DatasetBuilder<K, V> datasetBuilder, IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K,V,C,Double> vectorizer) {
PartitionDataBuilder<K, V, EmptyContext, LabeledVectorSet<Double, LabeledVector>> partDataBuilder
- = new LabeledDatasetPartitionDataBuilderOnHeap<>(
- featureExtractor,
- lbExtractor
- );
+ = new LabeledDatasetPartitionDataBuilderOnHeap<>(vectorizer);
Dataset<EmptyContext, LabeledVectorSet<Double, LabeledVector>> dataset = null;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index 4b8677a..b4cef43 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -17,20 +17,13 @@
package org.apache.ignite.ml.knn.ann;
-import java.io.Serializable;
-import java.util.Arrays;
-import java.util.List;
-import java.util.TreeMap;
-import java.util.concurrent.ConcurrentHashMap;
-import java.util.concurrent.ConcurrentSkipListSet;
-import java.util.stream.Collectors;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
-import org.apache.ignite.ml.composition.CompositionUtils;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.PartitionDataBuilder;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
@@ -41,10 +34,17 @@ import org.apache.ignite.ml.math.util.MapUtil;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.structures.partition.LabeledDatasetPartitionDataBuilderOnHeap;
-import org.apache.ignite.ml.trainers.FeatureLabelExtractor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.jetbrains.annotations.NotNull;
+import java.io.Serializable;
+import java.util.Arrays;
+import java.util.List;
+import java.util.TreeMap;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.ConcurrentSkipListSet;
+import java.util.stream.Collectors;
+
/**
* ANN algorithm trainer to solve multi-class classification task. This trainer is based on ACD strategy and KMeans
* clustering algorithm to find centroids.
@@ -69,31 +69,28 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
* @param extractor Mapping from upstream entry to {@link LabeledVector}.
* @return Model.
*/
- @Override public <K, V> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
- FeatureLabelExtractor<K, V, Double> extractor) {
+ @Override public <K, V, C extends Serializable> ANNClassificationModel fit(DatasetBuilder<K, V> datasetBuilder,
+ Vectorizer<K, V, C, Double> extractor) {
return updateModel(null, datasetBuilder, extractor);
}
/** {@inheritDoc} */
- @Override protected <K, V> ANNClassificationModel updateModel(ANNClassificationModel mdl,
- DatasetBuilder<K, V> datasetBuilder, FeatureLabelExtractor<K, V, Double> extractor) {
-
- IgniteBiFunction<K, V, Vector> featureExtractor = CompositionUtils.asFeatureExtractor(extractor);
- IgniteBiFunction<K, V, Double> lbExtractor = CompositionUtils.asLabelExtractor(extractor);
+ @Override protected <K, V, C extends Serializable> ANNClassificationModel updateModel(ANNClassificationModel mdl,
+ DatasetBuilder<K, V> datasetBuilder, Vectorizer<K, V, C, Double> extractor) {
List<Vector> centers;
CentroidStat centroidStat;
if (mdl != null) {
centers = Arrays.stream(mdl.getCandidates().data()).map(x -> x.features()).collect(Collectors.toList());
- CentroidStat newStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
- if(newStat == null)
+ CentroidStat newStat = getCentroidStat(datasetBuilder, extractor, centers);
+ if (newStat == null)
return mdl;
CentroidStat oldStat = mdl.getCentroindsStat();
centroidStat = newStat.merge(oldStat);
} else {
- centers = getCentroids(featureExtractor, lbExtractor, datasetBuilder);
- centroidStat = getCentroidStat(datasetBuilder, featureExtractor, lbExtractor, centers);
+ centers = getCentroids(extractor, datasetBuilder);
+ centroidStat = getCentroidStat(datasetBuilder, extractor, centers);
}
final LabeledVectorSet<ProbableLabel, LabeledVector> dataset = buildLabelsForCandidates(centers, centroidStat);
@@ -128,27 +125,20 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
/**
* Perform KMeans clusterization algorithm to find centroids.
... 6335 lines suppressed ...