You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by za...@apache.org on 2020/12/02 10:34:41 UTC
[ignite] branch master updated: IGNITE-13672 [ML]: Add initial JSON
export/import support for all models (#8521)
This is an automated email from the ASF dual-hosted git repository.
zaleslaw pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 6f9052d IGNITE-13672 [ML]: Add initial JSON export/import support for all models (#8521)
6f9052d is described below
commit 6f9052d6a117a2c851081946d5a9a9095c71d7cb
Author: Alexey Zinoviev <za...@gmail.com>
AuthorDate: Wed Dec 2 13:34:24 2020 +0300
IGNITE-13672 [ML]: Add initial JSON export/import support for all models (#8521)
* [IGNITE-13672] Initial solution
* [IGNITE-13672] Added an example
* [IGNITE-13672] Added a draft solution
* [IGNITE-13672] Updated JSON model
* [IGNITE-13672] Updated JSON model
* [IGNITE-13672] Removed GMM support
* [IGNITE-13672] Fixed blank lines
* [IGNITE-13672] Fixed licenses
* [IGNITE-13672] Fixed whitespaces
* [IGNITE-13672] Fixed whitespaces
* [IGNITE-13672] Fixed whitespaces
* [IGNITE-13672] Fixed examples
* [IGNITE-13672] Fixed examples
* [IGNITE-13672] Fixed test
---
.../binary-classification/decision-trees.adoc | 6 +-
.../model-import-from-apache-spark.adoc | 2 +-
.../model-selection/cross-validation.adoc | 4 +-
.../model-selection/pipeline-api.adoc | 4 +-
.../regression/decision-trees-regression.adoc | 6 +-
.../ml/clustering/KMeansClusterizationExample.java | 4 +-
.../ANNClassificationExportImportExample.java | 339 +++++++++++++++++++++
.../CompoundNaiveBayesExportImportExample.java | 129 ++++++++
...sionTreeClassificationExportImportExample.java} | 66 ++--
...DecisionTreeRegressionExportImportExample.java} | 32 +-
.../DiscreteNaiveBayesExportImportExample.java | 117 +++++++
...BOnTreesClassificationExportImportExample.java} | 64 ++--
.../GDBOnTreesRegressionExportImportExample.java} | 61 ++--
.../GaussianNaiveBayesExportImportExample.java | 117 +++++++
.../KMeansClusterizationExportImportExample.java} | 43 ++-
.../LinearRegressionExportImportExample.java | 116 +++++++
.../LogisticRegressionExportImportExample.java | 122 ++++++++
...domForestClassificationExportImportExample.java | 144 +++++++++
.../RandomForestRegressionExportImportExample.java | 151 +++++++++
.../exchange/SVMExportImportExample.java} | 79 ++---
.../modelparser/DecisionTreeFromSparkExample.java | 4 +-
.../DecisionTreeRegressionFromSparkExample.java | 4 +-
.../ml/preprocessing/encoding/EncoderExample.java | 4 +-
.../encoding/EncoderExampleWithNormalization.java | 4 +-
.../encoding/LabelEncoderExample.java | 4 +-
.../linear/BostonHousePricesPredictionExample.java | 4 +-
.../ml/selection/cv/CrossValidationExample.java | 4 +-
...eeClassificationTrainerSQLInferenceExample.java | 4 +-
...onTreeClassificationTrainerSQLTableExample.java | 4 +-
.../DecisionTreeClassificationTrainerExample.java | 4 +-
.../tree/DecisionTreeRegressionTrainerExample.java | 4 +-
.../GDBOnTreesClassificationTrainerExample.java | 8 +-
.../GDBOnTreesRegressionTrainerExample.java | 10 +-
.../examples/ml/tutorial/Step_11_Boosting.java | 8 +-
.../ml/tutorial/Step_1_Read_and_Learn.java | 4 +-
.../examples/ml/tutorial/Step_2_Imputing.java | 4 +-
.../examples/ml/tutorial/Step_3_Categorial.java | 4 +-
.../Step_3_Categorial_with_One_Hot_Encoder.java | 4 +-
.../examples/ml/tutorial/Step_4_Add_age_fare.java | 4 +-
.../examples/ml/tutorial/Step_5_Scaling.java | 4 +-
.../ml/tutorial/Step_7_Split_train_test.java | 4 +-
.../ignite/examples/ml/tutorial/Step_8_CV.java | 6 +-
.../ml/tutorial/Step_8_CV_with_Param_Grid.java | 6 +-
.../Step_8_CV_with_Param_Grid_and_pipeline.java | 4 +-
.../hyperparametertuning/Step_13_RandomSearch.java | 6 +-
.../Step_14_Parallel_Brute_Force_Search.java | 6 +-
.../Step_15_Parallel_Random_Search.java | 6 +-
.../Step_16_Genetic_Programming_Search.java | 6 +-
...tep_17_Parallel_Genetic_Programming_Search.java | 6 +-
modules/ml/pom.xml | 25 ++
.../ml/sparkmodelparser/SparkModelParser.java | 81 +----
.../apache/ignite/ml/clustering/gmm/GmmModel.java | 6 +
.../ml/clustering/kmeans/ClusterizationModel.java | 4 +-
.../ignite/ml/clustering/kmeans/KMeansModel.java | 125 +++++++-
.../ignite/ml/clustering/kmeans/KMeansTrainer.java | 4 +-
.../ignite/ml/composition/ModelsComposition.java | 16 +-
.../ml/composition/ModelsCompositionFormat.java | 6 +-
.../composition/boosting/GDBLearningStrategy.java | 4 +-
.../ignite/ml/composition/boosting/GDBModel.java | 118 +++++++
.../ignite/ml/composition/boosting/GDBTrainer.java | 43 +--
.../PredictionsAggregator.java | 9 +
.../WeightedPredictionsAggregator.java | 7 +-
.../apache/ignite/ml/inference/json/JSONModel.java | 55 ++++
.../json/JSONModelMixIn.java} | 22 +-
.../json/JSONWritable.java} | 32 +-
.../ignite/ml/inference/json/JacksonHelper.java | 39 +++
.../ignite/ml/knn/NNClassificationModel.java | 11 +
.../ignite/ml/knn/ann/ANNClassificationModel.java | 130 +++++++-
.../ml/knn/ann/ANNClassificationTrainer.java | 14 +-
.../apache/ignite/ml/knn/ann/ProbableLabel.java | 5 +-
.../ml/math/distances/BrayCurtisDistance.java | 4 +
.../ignite/ml/math/distances/DistanceMeasure.java | 17 ++
.../ml/math/distances/MinkowskiDistance.java | 16 +-
.../math/distances/WeightedMinkowskiDistance.java | 35 ++-
.../ignite/ml/math/stat/DistributionMixture.java | 9 +-
.../compound/CompoundNaiveBayesModel.java | 73 ++++-
.../discrete/DiscreteNaiveBayesModel.java | 83 ++++-
.../discrete/DiscreteNaiveBayesSumsHolder.java | 11 +
.../gaussian/GaussianNaiveBayesModel.java | 75 ++++-
.../gaussian/GaussianNaiveBayesSumsHolder.java | 15 +
.../linear/LinearRegressionLSQRTrainer.java | 8 +-
.../regressions/linear/LinearRegressionModel.java | 114 ++++++-
.../linear/LinearRegressionSGDTrainer.java | 4 +-
.../logistic/LogisticRegressionModel.java | 112 ++++++-
.../apache/ignite/ml/structures/DatasetRow.java | 4 +
.../apache/ignite/ml/structures/LabeledVector.java | 4 +
.../ml/svm/SVMLinearClassificationModel.java | 112 ++++++-
.../ml/svm/SVMLinearClassificationTrainer.java | 2 +-
.../ml/tree/DecisionTreeClassificationTrainer.java | 2 +-
.../ml/tree/DecisionTreeConditionalNode.java | 16 +-
.../ignite/ml/tree/DecisionTreeLeafNode.java | 10 +-
.../apache/ignite/ml/tree/DecisionTreeModel.java | 111 +++++++
.../apache/ignite/ml/tree/DecisionTreeNode.java | 15 +-
.../ml/tree/DecisionTreeRegressionTrainer.java | 2 +-
...{DecisionTree.java => DecisionTreeTrainer.java} | 20 +-
.../java/org/apache/ignite/ml/tree/NodeData.java | 90 ++++++
.../tree/boosting/GDBOnTreesLearningStrategy.java | 10 +-
.../RandomForestClassifierTrainer.java | 7 +-
.../ml/tree/randomforest/RandomForestModel.java | 106 +++++++
.../RandomForestRegressionTrainer.java | 7 +-
.../ml/tree/randomforest/RandomForestTrainer.java | 37 ++-
.../ignite/ml/tree/randomforest/data/NodeId.java | 11 +-
.../ml/tree/randomforest/data/NodeSplit.java | 9 +-
.../{TreeRoot.java => RandomForestTreeModel.java} | 25 +-
.../ignite/ml/tree/randomforest/data/TreeNode.java | 9 +-
.../data/impurity/ImpurityHistogramsComputer.java | 8 +-
.../data/statistics/LeafValuesComputer.java | 8 +-
.../ignite/ml/clustering/KMeansModelTest.java | 4 +-
.../apache/ignite/ml/common/KeepBinaryTest.java | 2 +-
.../ml/composition/boosting/GDBTrainerTest.java | 6 +-
.../ignite/ml/math/distances/DistanceTest.java | 6 +-
.../distances/WeightedMinkowskiDistanceTest.java | 10 +-
.../linear/LinearRegressionLSQRTrainerTest.java | 16 +-
.../linear/LinearRegressionSGDTrainerTest.java | 16 +-
.../ml/selection/cv/CrossValidationTest.java | 8 +-
...onTreeClassificationTrainerIntegrationTest.java | 7 +-
.../DecisionTreeClassificationTrainerTest.java | 6 +-
...cisionTreeRegressionTrainerIntegrationTest.java | 8 +-
.../ml/tree/DecisionTreeRegressionTrainerTest.java | 6 +-
.../RandomForestClassifierTrainerTest.java | 13 +-
.../randomforest/RandomForestIntegrationTest.java | 3 +-
.../RandomForestRegressionTrainerTest.java | 9 +-
.../ml/tree/randomforest/data/TreeNodeTest.java | 14 +-
123 files changed, 3249 insertions(+), 592 deletions(-)
diff --git a/docs/_docs/machine-learning/binary-classification/decision-trees.adoc b/docs/_docs/machine-learning/binary-classification/decision-trees.adoc
index 57ab7bf..bc9ff05 100644
--- a/docs/_docs/machine-learning/binary-classification/decision-trees.adoc
+++ b/docs/_docs/machine-learning/binary-classification/decision-trees.adoc
@@ -39,12 +39,12 @@ The model works this way - the split process stops when either the algorithm has
== Model
-The Model in a decision tree classification is represented by the class `DecisionTreeNode`. We can make a prediction for a given vector of features in the following way:
+The Model in a decision tree classification is represented by the class `DecisionTreeModel`. We can make a prediction for a given vector of features in the following way:
[source, java]
----
-DecisionTreeNode mdl = ...;
+DecisionTreeModel mdl = ...;
double prediction = mdl.apply(observation);
----
@@ -68,7 +68,7 @@ DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTraine
);
// Train model.
-DecisionTreeNode mdl = trainer.fit(ignite, dataCache, vectorizer);
+DecisionTreeModel mdl = trainer.fit(ignite, dataCache, vectorizer);
----
diff --git a/docs/_docs/machine-learning/importing-model/model-import-from-apache-spark.adoc b/docs/_docs/machine-learning/importing-model/model-import-from-apache-spark.adoc
index 92992f8..065cb78 100644
--- a/docs/_docs/machine-learning/importing-model/model-import-from-apache-spark.adoc
+++ b/docs/_docs/machine-learning/importing-model/model-import-from-apache-spark.adoc
@@ -71,7 +71,7 @@ To load in Ignite ML you should use SparkModelParser class via method parse() ca
[source, java]
----
-DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
+DecisionTreeModel mdl = (DecisionTreeModel)SparkModelParser.parse(
SPARK_MDL_PATH,
SupportedSparkModels.DECISION_TREE
);
diff --git a/docs/_docs/machine-learning/model-selection/cross-validation.adoc b/docs/_docs/machine-learning/model-selection/cross-validation.adoc
index 8e64c68..39e00f1 100644
--- a/docs/_docs/machine-learning/model-selection/cross-validation.adoc
+++ b/docs/_docs/machine-learning/model-selection/cross-validation.adoc
@@ -27,7 +27,7 @@ Let’s imagine that we have a trainer, a training set and we want to make cross
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
// Create cross-validation instance
-CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
// Set up the cross-validation process
@@ -67,7 +67,7 @@ Pipeline<Integer, Vector, Integer, Double> pipeline
// Create cross-validation instance
-CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
// Set up the cross-validation process
diff --git a/docs/_docs/machine-learning/model-selection/pipeline-api.adoc b/docs/_docs/machine-learning/model-selection/pipeline-api.adoc
index 7f0cb93..9b2798c 100644
--- a/docs/_docs/machine-learning/model-selection/pipeline-api.adoc
+++ b/docs/_docs/machine-learning/model-selection/pipeline-api.adoc
@@ -64,7 +64,7 @@ Preprocessor<Integer, Vector> normalizationPreprocessor = new NormalizationTrain
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
-CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator = new CrossValidation<>();
+CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator = new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
.addHyperParam("maxDeep", trainerCV::withMaxDeep, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
@@ -101,7 +101,7 @@ Pipeline<Integer, Vector, Integer, Double> pipeline = new Pipeline<Integer, Vect
.addPreprocessingTrainer(new MinMaxScalerTrainer<Integer, Vector>())
.addTrainer(trainer);
-CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator = new CrossValidation<>();
+CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator = new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
.addHyperParam("maxDeep", trainer::withMaxDeep, new Double[] {1.0, 2.0, 3.0, 4.0, 5.0, 10.0})
diff --git a/docs/_docs/machine-learning/regression/decision-trees-regression.adoc b/docs/_docs/machine-learning/regression/decision-trees-regression.adoc
index 48f9d5c..2abbaa8 100644
--- a/docs/_docs/machine-learning/regression/decision-trees-regression.adoc
+++ b/docs/_docs/machine-learning/regression/decision-trees-regression.adoc
@@ -39,12 +39,12 @@ The model works this way - the split process stops when either the algorithm has
== Model
-The Model in a decision tree classification is represented by the class `DecisionTreeNode`. We can make a prediction for a given vector of features in the following way:
+The Model in a decision tree regression is represented by the class `DecisionTreeModel`. We can make a prediction for a given vector of features in the following way:
[source, java]
----
-DecisionTreeNode mdl = ...;
+DecisionTreeModel mdl = ...;
double prediction = mdl.apply(observation);
----
@@ -67,7 +67,7 @@ DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(
);
// Train model.
-DecisionTreeNode mdl = trainer.fit(ignite, dataCache, vectorizer);
+DecisionTreeModel mdl = trainer.fit(ignite, dataCache, vectorizer);
----
== Examples
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
index beee4f6..3127418 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
@@ -73,8 +73,8 @@ public class KMeansClusterizationExample {
);
System.out.println(">>> KMeans centroids");
- Tracer.showAscii(mdl.getCenters()[0]);
- Tracer.showAscii(mdl.getCenters()[1]);
+ Tracer.showAscii(mdl.centers()[0]);
+ Tracer.showAscii(mdl.centers()[1]);
System.out.println(">>>");
System.out.println(">>> --------------------------------------------");
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/ANNClassificationExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/ANNClassificationExportImportExample.java
new file mode 100644
index 0000000..618e4c6
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/ANNClassificationExportImportExample.java
@@ -0,0 +1,339 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.Arrays;
+import java.util.UUID;
+import javax.cache.Cache;
+import org.apache.commons.math3.util.Precision;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.configuration.CacheConfiguration;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
+import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationModel;
+import org.apache.ignite.ml.knn.ann.ANNClassificationTrainer;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
+import org.apache.ignite.ml.math.distances.ManhattanDistance;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
+
+/**
+ * Run ANN multi-class classification trainer ({@link ANNClassificationTrainer}) over distributed dataset.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p>
+ * <p>
+ * After that it trains the model based on the specified data using
+ * <a href="https://en.wikipedia.org/wiki/K-nearest_neighbors_algorithm">kNN</a> algorithm.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict what cluster does
+ * this point belong to, and compares prediction to expected outcome (ground truth).</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class ANNClassificationExportImportExample {
+ /**
+ * Runs the example: trains an ANN classification model on the cached dataset, exports it to a temporary
+ * JSON file, imports it back and prints the accuracy of both the trained and the re-imported model.
+ *
+ * @param args Command line arguments; not read by this method.
+ * @throws IOException If the temporary model file cannot be created or deleted; presumably also propagated
+ * from the JSON export/import calls - TODO confirm against {@code toJSON}/{@code fromJSON}.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ // Both resources are declared outside the inner try so that the finally block can release them
+ // even if the failure happens before they are assigned.
+ IgniteCache<Integer, double[]> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = getTestCache(ignite);
+
+ ANNClassificationTrainer trainer = new ANNClassificationTrainer()
+ .withDistance(new ManhattanDistance())
+ .withK(50)
+ .withMaxIterations(1000)
+ .withEpsilon(1e-2);
+
+ // The cast applies to the result of the whole call chain (fit(...).withK(...)...), not to fit(...) alone.
+ // The label is the first element of every cached array, the remaining elements are features.
+ ANNClassificationModel mdl = (ANNClassificationModel) trainer.fit(
+ ignite,
+ dataCache,
+ new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
+ ).withK(5)
+ .withDistanceMeasure(new EuclideanDistance())
+ .withWeighted(true);
+
+ System.out.println("\n>>> Exported ANN model: " + mdl.toString(true));
+
+ double accuracy = evaluateModel(dataCache, mdl);
+
+ System.out.println("\n>>> Accuracy for exported ANN model:" + accuracy);
+
+ // Export: write the model as JSON into a fresh temporary file (removed in the finally block).
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ // Import: rebuild a model from the file that has just been written.
+ ANNClassificationModel modelImportedFromJSON = ANNClassificationModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported ANN model: " + modelImportedFromJSON.toString(true));
+
+ // Evaluate the imported model on the same cache so the two accuracy values can be compared.
+ accuracy = evaluateModel(dataCache, modelImportedFromJSON);
+
+ System.out.println("\n>>> Accuracy for imported ANN model:" + accuracy);
+
+ System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
+ }
+ finally {
+ // Clean up regardless of the outcome: drop the temporary cache and the exported model file.
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+
+ private static double evaluateModel(IgniteCache<Integer, double[]> dataCache, NNClassificationModel knnMdl) {
+ int amountOfErrors = 0;
+ int totalAmount = 0;
+
+ double accuracy;
+ try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Ground Truth\t|");
+ System.out.println(">>> ---------------------------------");
+
+ for (Cache.Entry<Integer, double[]> observation : observations) {
+ double[] val = observation.getValue();
+ double[] inputs = Arrays.copyOfRange(val, 1, val.length);
+ double groundTruth = val[0];
+
+ double prediction = knnMdl.predict(new DenseVector(inputs));
+
+ totalAmount++;
+ if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
+ amountOfErrors++;
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
+ }
+
+ System.out.println(">>> ---------------------------------");
+
+ accuracy = 1 - amountOfErrors / (double) totalAmount;
+
+ }
+ return accuracy;
+ }
+
+    /**
+     * Creates a uniquely named cache and stores ten copies of the dataset in it.
+     * <p>
+     * Copy number {@code c} of row number {@code r} is stored under the key {@code c * 10000 + r}; each row is
+     * passed through {@code mutate(row, c)} before it is put into the cache.</p>
+     *
+     * @param ignite Ignite instance.
+     * @return Filled Ignite Cache.
+     */
+    private static IgniteCache<Integer, double[]> getTestCache(Ignite ignite) {
+        CacheConfiguration<Integer, double[]> cfg = new CacheConfiguration<>();
+        cfg.setName("TEST_" + UUID.randomUUID());
+        cfg.setAffinity(new RendezvousAffinityFunction(false, 10));
+
+        IgniteCache<Integer, double[]> filled = ignite.createCache(cfg);
+
+        for (int copy = 0; copy < 10; copy++) {
+            // Every copy gets its own key range so that copies never overwrite each other.
+            int keyBase = copy * 10000;
+            int row = 0;
+
+            for (double[] datum : data)
+                filled.put(keyBase + row++, mutate(datum, copy));
+        }
+
+        return filled;
+    }
+
+    /**
+     * Returns a slightly shifted copy of one dataset row, so that every replica of the dataset differs a little.
+     * <p>
+     * The input row is never modified: rows come from the shared static {@code data} table and are reused for
+     * every value of {@code k}, so in-place changes would accumulate from one replica to the next.
+     * Element 0 is copied unchanged because this example reads it as the class label
+     * ({@code Vectorizer.LabelCoordinate.FIRST} during training, {@code val[0]} during evaluation);
+     * every other element is increased by {@code k / 100000.0}.</p>
+     *
+     * @param datum The vector data (label first, then features).
+     * @param k The passed parameter (number of the dataset replica).
+     * @return A new array holding the changed vector data.
+     */
+    private static double[] mutate(double[] datum, int k) {
+        double[] shifted = Arrays.copyOf(datum, datum.length);
+
+        // Floating-point division is required: the integer expression "k / 100000" is 0 for every k in [0, 10),
+        // which turned this method into a no-op.
+        double delta = k / 100000.0;
+
+        // Start at 1: index 0 is the label and must keep its exact value.
+        for (int i = 1; i < shifted.length; i++)
+            shifted[i] += delta;
+
+        return shifted;
+    }
+
+ /**
+ * The Iris dataset.
+ */
+ private static final double[][] data = {
+ {1, 5.1, 3.5, 1.4, 0.2},
+ {1, 4.9, 3, 1.4, 0.2},
+ {1, 4.7, 3.2, 1.3, 0.2},
+ {1, 4.6, 3.1, 1.5, 0.2},
+ {1, 5, 3.6, 1.4, 0.2},
+ {1, 5.4, 3.9, 1.7, 0.4},
+ {1, 4.6, 3.4, 1.4, 0.3},
+ {1, 5, 3.4, 1.5, 0.2},
+ {1, 4.4, 2.9, 1.4, 0.2},
+ {1, 4.9, 3.1, 1.5, 0.1},
+ {1, 5.4, 3.7, 1.5, 0.2},
+ {1, 4.8, 3.4, 1.6, 0.2},
+ {1, 4.8, 3, 1.4, 0.1},
+ {1, 4.3, 3, 1.1, 0.1},
+ {1, 5.8, 4, 1.2, 0.2},
+ {1, 5.7, 4.4, 1.5, 0.4},
+ {1, 5.4, 3.9, 1.3, 0.4},
+ {1, 5.1, 3.5, 1.4, 0.3},
+ {1, 5.7, 3.8, 1.7, 0.3},
+ {1, 5.1, 3.8, 1.5, 0.3},
+ {1, 5.4, 3.4, 1.7, 0.2},
+ {1, 5.1, 3.7, 1.5, 0.4},
+ {1, 4.6, 3.6, 1, 0.2},
+ {1, 5.1, 3.3, 1.7, 0.5},
+ {1, 4.8, 3.4, 1.9, 0.2},
+ {1, 5, 3, 1.6, 0.2},
+ {1, 5, 3.4, 1.6, 0.4},
+ {1, 5.2, 3.5, 1.5, 0.2},
+ {1, 5.2, 3.4, 1.4, 0.2},
+ {1, 4.7, 3.2, 1.6, 0.2},
+ {1, 4.8, 3.1, 1.6, 0.2},
+ {1, 5.4, 3.4, 1.5, 0.4},
+ {1, 5.2, 4.1, 1.5, 0.1},
+ {1, 5.5, 4.2, 1.4, 0.2},
+ {1, 4.9, 3.1, 1.5, 0.1},
+ {1, 5, 3.2, 1.2, 0.2},
+ {1, 5.5, 3.5, 1.3, 0.2},
+ {1, 4.9, 3.1, 1.5, 0.1},
+ {1, 4.4, 3, 1.3, 0.2},
+ {1, 5.1, 3.4, 1.5, 0.2},
+ {1, 5, 3.5, 1.3, 0.3},
+ {1, 4.5, 2.3, 1.3, 0.3},
+ {1, 4.4, 3.2, 1.3, 0.2},
+ {1, 5, 3.5, 1.6, 0.6},
+ {1, 5.1, 3.8, 1.9, 0.4},
+ {1, 4.8, 3, 1.4, 0.3},
+ {1, 5.1, 3.8, 1.6, 0.2},
+ {1, 4.6, 3.2, 1.4, 0.2},
+ {1, 5.3, 3.7, 1.5, 0.2},
+ {1, 5, 3.3, 1.4, 0.2},
+ {2, 7, 3.2, 4.7, 1.4},
+ {2, 6.4, 3.2, 4.5, 1.5},
+ {2, 6.9, 3.1, 4.9, 1.5},
+ {2, 5.5, 2.3, 4, 1.3},
+ {2, 6.5, 2.8, 4.6, 1.5},
+ {2, 5.7, 2.8, 4.5, 1.3},
+ {2, 6.3, 3.3, 4.7, 1.6},
+ {2, 4.9, 2.4, 3.3, 1},
+ {2, 6.6, 2.9, 4.6, 1.3},
+ {2, 5.2, 2.7, 3.9, 1.4},
+ {2, 5, 2, 3.5, 1},
+ {2, 5.9, 3, 4.2, 1.5},
+ {2, 6, 2.2, 4, 1},
+ {2, 6.1, 2.9, 4.7, 1.4},
+ {2, 5.6, 2.9, 3.6, 1.3},
+ {2, 6.7, 3.1, 4.4, 1.4},
+ {2, 5.6, 3, 4.5, 1.5},
+ {2, 5.8, 2.7, 4.1, 1},
+ {2, 6.2, 2.2, 4.5, 1.5},
+ {2, 5.6, 2.5, 3.9, 1.1},
+ {2, 5.9, 3.2, 4.8, 1.8},
+ {2, 6.1, 2.8, 4, 1.3},
+ {2, 6.3, 2.5, 4.9, 1.5},
+ {2, 6.1, 2.8, 4.7, 1.2},
+ {2, 6.4, 2.9, 4.3, 1.3},
+ {2, 6.6, 3, 4.4, 1.4},
+ {2, 6.8, 2.8, 4.8, 1.4},
+ {2, 6.7, 3, 5, 1.7},
+ {2, 6, 2.9, 4.5, 1.5},
+ {2, 5.7, 2.6, 3.5, 1},
+ {2, 5.5, 2.4, 3.8, 1.1},
+ {2, 5.5, 2.4, 3.7, 1},
+ {2, 5.8, 2.7, 3.9, 1.2},
+ {2, 6, 2.7, 5.1, 1.6},
+ {2, 5.4, 3, 4.5, 1.5},
+ {2, 6, 3.4, 4.5, 1.6},
+ {2, 6.7, 3.1, 4.7, 1.5},
+ {2, 6.3, 2.3, 4.4, 1.3},
+ {2, 5.6, 3, 4.1, 1.3},
+ {2, 5.5, 2.5, 4, 1.3},
+ {2, 5.5, 2.6, 4.4, 1.2},
+ {2, 6.1, 3, 4.6, 1.4},
+ {2, 5.8, 2.6, 4, 1.2},
+ {2, 5, 2.3, 3.3, 1},
+ {2, 5.6, 2.7, 4.2, 1.3},
+ {2, 5.7, 3, 4.2, 1.2},
+ {2, 5.7, 2.9, 4.2, 1.3},
+ {2, 6.2, 2.9, 4.3, 1.3},
+ {2, 5.1, 2.5, 3, 1.1},
+ {2, 5.7, 2.8, 4.1, 1.3},
+ {3, 6.3, 3.3, 6, 2.5},
+ {3, 5.8, 2.7, 5.1, 1.9},
+ {3, 7.1, 3, 5.9, 2.1},
+ {3, 6.3, 2.9, 5.6, 1.8},
+ {3, 6.5, 3, 5.8, 2.2},
+ {3, 7.6, 3, 6.6, 2.1},
+ {3, 4.9, 2.5, 4.5, 1.7},
+ {3, 7.3, 2.9, 6.3, 1.8},
+ {3, 6.7, 2.5, 5.8, 1.8},
+ {3, 7.2, 3.6, 6.1, 2.5},
+ {3, 6.5, 3.2, 5.1, 2},
+ {3, 6.4, 2.7, 5.3, 1.9},
+ {3, 6.8, 3, 5.5, 2.1},
+ {3, 5.7, 2.5, 5, 2},
+ {3, 5.8, 2.8, 5.1, 2.4},
+ {3, 6.4, 3.2, 5.3, 2.3},
+ {3, 6.5, 3, 5.5, 1.8},
+ {3, 7.7, 3.8, 6.7, 2.2},
+ {3, 7.7, 2.6, 6.9, 2.3},
+ {3, 6, 2.2, 5, 1.5},
+ {3, 6.9, 3.2, 5.7, 2.3},
+ {3, 5.6, 2.8, 4.9, 2},
+ {3, 7.7, 2.8, 6.7, 2},
+ {3, 6.3, 2.7, 4.9, 1.8},
+ {3, 6.7, 3.3, 5.7, 2.1},
+ {3, 7.2, 3.2, 6, 1.8},
+ {3, 6.2, 2.8, 4.8, 1.8},
+ {3, 6.1, 3, 4.9, 1.8},
+ {3, 6.4, 2.8, 5.6, 2.1},
+ {3, 7.2, 3, 5.8, 1.6},
+ {3, 7.4, 2.8, 6.1, 1.9},
+ {3, 7.9, 3.8, 6.4, 2},
+ {3, 6.4, 2.8, 5.6, 2.2},
+ {3, 6.3, 2.8, 5.1, 1.5},
+ {3, 6.1, 2.6, 5.6, 1.4},
+ {3, 7.7, 3, 6.1, 2.3},
+ {3, 6.3, 3.4, 5.6, 2.4},
+ {3, 6.4, 3.1, 5.5, 1.8},
+ {3, 6, 3, 4.8, 1.8},
+ {3, 6.9, 3.1, 5.4, 2.1},
+ {3, 6.7, 3.1, 5.6, 2.4},
+ {3, 6.9, 3.1, 5.1, 2.3},
+ {3, 5.8, 2.7, 5.1, 1.9},
+ {3, 6.8, 3.2, 5.9, 2.3},
+ {3, 6.7, 3.3, 5.7, 2.5},
+ {3, 6.7, 3, 5.2, 2.3},
+ {3, 6.3, 2.5, 5, 1.9},
+ {3, 6.5, 3, 5.2, 2},
+ {3, 6.2, 3.4, 5.4, 2.3},
+ {3, 5.9, 3, 5.1, 1.8}
+ };
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/CompoundNaiveBayesExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/CompoundNaiveBayesExportImportExample.java
new file mode 100644
index 0000000..7d05f5e
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/CompoundNaiveBayesExportImportExample.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.compound.CompoundNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+
+import static java.util.Arrays.asList;
+
+/**
+ * Run naive Compound Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier">
+ * Naive Bayes classifier</a> algorithm ({@link GaussianNaiveBayesTrainer}) and <a
+ * href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes"> Discrete naive Bayes
+ * classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points.</p>
+ * <p>
+ * After that it trains the naive Bayes classification model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict the target value,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class CompoundNaiveBayesExportImportExample {
+ /**
+ * Runs the example: trains a compound naive Bayes model, exports it to a temporary JSON file, imports it
+ * back and prints the accuracy of both the trained and the re-imported model.
+ *
+ * @param args Command line arguments; not read by this method.
+ * @throws IOException If the temporary model file cannot be created or deleted; presumably also propagated
+ * from the JSON export/import calls - TODO confirm against {@code toJSON}/{@code fromJSON}.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Compound Naive Bayes classification model over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ // Both resources are declared outside the inner try so that the finally block can release them
+ // even if the failure happens before they are assigned.
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite)
+ .fillCacheWith(MLSandboxDatasets.MIXED_DATASET);
+
+ double[] priorProbabilities = new double[]{.5, .5};
+ double[][] thresholds = new double[][]{{.5}, {.5}, {.5}, {.5}, {.5}};
+
+ System.out.println("\n>>> Create new naive Bayes classification trainer object.");
+ // The Gaussian part skips feature ids 3..7 and the discrete part skips ids 0..2, so presumably
+ // features 0..2 are continuous and 3..7 are discrete - TODO confirm against MIXED_DATASET.
+ CompoundNaiveBayesTrainer trainer = new CompoundNaiveBayesTrainer()
+ .withPriorProbabilities(priorProbabilities)
+ .withGaussianNaiveBayesTrainer(new GaussianNaiveBayesTrainer())
+ .withGaussianFeatureIdsToSkip(asList(3, 4, 5, 6, 7))
+ .withDiscreteNaiveBayesTrainer(new DiscreteNaiveBayesTrainer()
+ .setBucketThresholds(thresholds))
+ .withDiscreteFeatureIdsToSkip(asList(0, 1, 2));
+ System.out.println("\n>>> Perform the training to get the model.");
+
+ // The same vectorizer (label taken from the first coordinate) is used for training and for both evaluations.
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+
+ CompoundNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+
+ System.out.println("\n>>> Exported Compound Naive Bayes model: " + mdl.toString(true));
+
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for exported Compound Naive Bayes model:" + accuracy);
+
+ // Export: write the model as JSON into a fresh temporary file (removed in the finally block).
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ // Import: rebuild a model from the file that has just been written.
+ CompoundNaiveBayesModel modelImportedFromJSON = CompoundNaiveBayesModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported Compound Naive Bayes model: " + modelImportedFromJSON.toString(true));
+
+ // Evaluate the imported model on the same cache so the two accuracy values can be compared.
+ accuracy = Evaluator.evaluate(
+ dataCache,
+ modelImportedFromJSON,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for imported Compound Naive Bayes model:" + accuracy);
+
+ System.out.println("\n>>> Compound Naive Bayes model over partitioned dataset usage example completed.");
+ }
+ finally {
+ // Clean up regardless of the outcome: drop the temporary cache and the exported model file.
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeClassificationExportImportExample.java
similarity index 64%
copy from examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeClassificationExportImportExample.java
index 600f4a5..e7ad7ca 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeClassificationExportImportExample.java
@@ -15,8 +15,11 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.tree;
+package org.apache.ignite.examples.ml.inference.exchange;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import java.util.Random;
import org.apache.commons.math3.util.Precision;
import org.apache.ignite.Ignite;
@@ -28,7 +31,7 @@ import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorize
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer}.
@@ -42,18 +45,18 @@ import org.apache.ignite.ml.tree.DecisionTreeNode;
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
-public class DecisionTreeClassificationTrainerExample {
+public class DecisionTreeClassificationExportImportExample {
/**
* Executes example.
*
* @param args Command line arguments, none required.
*/
- public static void main(String... args) {
+ public static void main(String[] args) throws IOException {
System.out.println(">>> Decision tree classification trainer example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
+ System.out.println("\n>>> Ignite grid started.");
// Create cache with training data.
CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
@@ -61,6 +64,7 @@ public class DecisionTreeClassificationTrainerExample {
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
+ Path jsonMdlPath = null;
try {
trainingSet = ignite.createCache(trainingSetCfg);
@@ -75,34 +79,36 @@ public class DecisionTreeClassificationTrainerExample {
// Train decision tree model.
LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
trainingSet,
vectorizer
);
- System.out.println(">>> Decision tree classification model: " + mdl);
+ System.out.println("\n>>> Exported Decision tree classification model: " + mdl);
- // Calculate score.
- int correctPredictions = 0;
- for (int i = 0; i < 1000; i++) {
- LabeledVector<Double> pnt = generatePoint(rnd);
+ int correctPredictions = evaluateModel(rnd, mdl);
- double prediction = mdl.predict(pnt.features());
- double lbl = pnt.label();
+ System.out.println("\n>>> Accuracy for exported Decision tree classification model: " + correctPredictions / 10.0 + "%");
- if (i % 50 == 1)
- System.out.printf(">>> test #: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, prediction, lbl);
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
- if (Precision.equals(prediction, lbl, Precision.EPSILON))
- correctPredictions++;
- }
+ DecisionTreeModel modelImportedFromJSON = DecisionTreeModel.fromJSON(jsonMdlPath);
- System.out.println(">>> Accuracy: " + correctPredictions / 10.0 + "%");
- System.out.println(">>> Decision tree classification trainer example completed.");
+ System.out.println("\n>>> Imported Decision tree classification model: " + modelImportedFromJSON);
+
+ correctPredictions = evaluateModel(rnd, modelImportedFromJSON);
+
+ System.out.println("\n>>> Accuracy for imported Decision tree classification model: " + correctPredictions / 10.0 + "%");
+
+ System.out.println("\n>>> Decision tree classification trainer example completed.");
}
finally {
- trainingSet.destroy();
+ if (trainingSet != null)
+ trainingSet.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
@@ -110,6 +116,24 @@ public class DecisionTreeClassificationTrainerExample {
}
}
+ private static int evaluateModel(Random rnd, DecisionTreeModel mdl) {
+ // Calculate score.
+ int correctPredictions = 0;
+ for (int i = 0; i < 1000; i++) {
+ LabeledVector<Double> pnt = generatePoint(rnd);
+
+ double prediction = mdl.predict(pnt.features());
+ double lbl = pnt.label();
+
+ if (i % 50 == 1)
+ System.out.printf(">>> test #: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, prediction, lbl);
+
+ if (Precision.equals(prediction, lbl, Precision.EPSILON))
+ correctPredictions++;
+ }
+ return correctPredictions;
+ }
+
/**
* Generate point with {@code x} in (-0.5, 0.5) and {@code y} in the same interval. If {@code x * y > 0} then label
* is 1, otherwise 0.
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeRegressionExportImportExample.java
similarity index 77%
copy from examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeRegressionExportImportExample.java
index 1a19771..9857ba9 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DecisionTreeRegressionExportImportExample.java
@@ -15,8 +15,11 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.tree;
+package org.apache.ignite.examples.ml.inference.exchange;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
@@ -25,7 +28,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
/**
@@ -41,18 +44,18 @@ import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
-public class DecisionTreeRegressionTrainerExample {
+public class DecisionTreeRegressionExportImportExample {
/**
* Executes example.
*
* @param args Command line arguments, none required.
*/
- public static void main(String... args) {
+ public static void main(String... args) throws IOException {
System.out.println(">>> Decision tree regression trainer example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
+ System.out.println("\n>>> Ignite grid started.");
// Create cache with training data.
CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
@@ -60,6 +63,7 @@ public class DecisionTreeRegressionTrainerExample {
trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));
IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;
+ Path jsonMdlPath = null;
try {
trainingSet = ignite.createCache(trainingSetCfg);
@@ -70,9 +74,16 @@ public class DecisionTreeRegressionTrainerExample {
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
+ DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
- System.out.println(">>> Decision tree regression model: " + mdl);
+ System.out.println("\n>>> Exported Decision tree regression model: " + mdl);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ DecisionTreeModel modelImportedFromJSON = DecisionTreeModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported Decision tree regression model: " + modelImportedFromJSON);
System.out.println(">>> ---------------------------------");
System.out.println(">>> | Prediction\t| Ground Truth\t|");
@@ -87,10 +98,13 @@ public class DecisionTreeRegressionTrainerExample {
System.out.println(">>> ---------------------------------");
- System.out.println(">>> Decision tree regression trainer example completed.");
+ System.out.println("\n>>> Decision tree regression trainer example completed.");
}
finally {
- trainingSet.destroy();
+ if (trainingSet != null)
+ trainingSet.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DiscreteNaiveBayesExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DiscreteNaiveBayesExportImportExample.java
new file mode 100644
index 0000000..c4d44c4
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/DiscreteNaiveBayesExportImportExample.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+
+/**
+ * Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier#Multinomial_naive_Bayes">
+ * naive Bayes classifier</a> algorithm ({@link DiscreteNaiveBayesTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points.
+ * </p>
+ * <p>
+ * After that it trains the Discrete naive Bayes classification model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict the target value,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class DiscreteNaiveBayesExportImportExample {
+ /**
+ * Run example.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println(">>> Discrete naive Bayes classification model over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.ENGLISH_VS_SCOTTISH);
+
+ double[][] thresholds = new double[][] {{.5}, {.5}, {.5}, {.5}, {.5}};
+ System.out.println(">>> Create new Discrete naive Bayes classification trainer object.");
+ DiscreteNaiveBayesTrainer trainer = new DiscreteNaiveBayesTrainer()
+ .setBucketThresholds(thresholds);
+
+ System.out.println("\n>>> Perform the training to get the model.");
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+
+ DiscreteNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+ System.out.println("\n>>> Exported Discrete Naive Bayes model: " + mdl.toString(true));
+
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for exported Discrete Naive Bayes model:" + accuracy);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ DiscreteNaiveBayesModel modelImportedFromJSON = DiscreteNaiveBayesModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported Discrete Naive Bayes model: " + modelImportedFromJSON.toString(true));
+
+ accuracy = Evaluator.evaluate(
+ dataCache,
+ modelImportedFromJSON,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for imported Discrete Naive Bayes model:" + accuracy);
+
+ System.out.println("\n>>> Discrete Naive bayes model over partitioned dataset usage example completed.");
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesClassificationExportImportExample.java
similarity index 65%
copy from examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesClassificationExportImportExample.java
index a2eaf47..9aa8f22 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesClassificationExportImportExample.java
@@ -15,19 +15,23 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.tree.boosting;
+package org.apache.ignite.examples.ml.inference.exchange;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
import org.jetbrains.annotations.NotNull;
@@ -38,55 +42,58 @@ import org.jetbrains.annotations.NotNull;
* <p>
* In this example dataset is created automatically by meander function {@code f(x) = [sin(x) > 0]}.</p>
*/
-public class GDBOnTreesClassificationTrainerExample {
+public class GDBOnTreesClassificationExportImportExample {
/**
* Run example.
*
* @param args Command line arguments, none required.
*/
- public static void main(String... args) {
+ public static void main(String[] args) throws IOException {
System.out.println();
System.out.println(">>> GDB classification trainer example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
+ System.out.println("\n>>> Ignite grid started.");
// Create cache with training data.
CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
IgniteCache<Integer, double[]> trainingSet = null;
+ Path jsonMdlPath = null;
try {
trainingSet = fillTrainingData(ignite, trainingSetCfg);
// Create classification trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
+ GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
// Train decision tree model.
- ModelsComposition mdl = trainer.fit(
+ GDBModel mdl = trainer.fit(
ignite,
trainingSet,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Valid answer\t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Exported GDB classification model: " + mdl.toString(true));
- // Calculate score.
- for (int x = -5; x < 5; x++) {
- double predicted = mdl.predict(VectorUtils.of(x));
+ predictOnGeneratedData(mdl);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x) < 0 ? 0.0 : 1.0);
- }
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> Count of trees = " + mdl.getModels().size());
- System.out.println(">>> ---------------------------------");
+ IgniteFunction<Double, Double> lbMapper = lb -> lb > 0.5 ? 1.0 : 0.0;
+ GDBModel modelImportedFromJSON = GDBModel.fromJSON(jsonMdlPath).withLblMapping(lbMapper);
+
+ System.out.println("\n>>> Imported GDB classification model: " + modelImportedFromJSON.toString(true));
+
+ predictOnGeneratedData(modelImportedFromJSON);
System.out.println(">>> GDB classification trainer example completed.");
}
finally {
- trainingSet.destroy();
+ if (trainingSet != null)
+ trainingSet.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
@@ -94,6 +101,23 @@ public class GDBOnTreesClassificationTrainerExample {
}
}
+ private static void predictOnGeneratedData(GDBModel mdl) {
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Valid answer\t|");
+ System.out.println(">>> ---------------------------------");
+
+ // Calculate score.
+ for (int x = -5; x < 5; x++) {
+ double predicted = mdl.predict(VectorUtils.of(x));
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.sin(x) < 0 ? 0.0 : 1.0);
+ }
+
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> Count of trees = " + mdl.getModels().size());
+ System.out.println(">>> ---------------------------------");
+ }
+
/**
* Create cache configuration.
*/
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesRegressionExportImportExample.java
similarity index 67%
copy from examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesRegressionExportImportExample.java
index 09dd708..14233e3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GDBOnTreesRegressionExportImportExample.java
@@ -15,21 +15,23 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.tree.boosting;
+package org.apache.ignite.examples.ml.inference.exchange;
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
-import org.apache.ignite.ml.inference.Model;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
import org.jetbrains.annotations.NotNull;
@@ -40,13 +42,13 @@ import org.jetbrains.annotations.NotNull;
* <p>
* In this example dataset is created automatically by parabolic function {@code f(x) = x^2}.</p>
*/
-public class GDBOnTreesRegressionTrainerExample {
+public class GDBOnTreesRegressionExportImportExample {
/**
* Run example.
*
* @param args Command line arguments, none required.
*/
- public static void main(String... args) {
+ public static void main(String[] args) throws IOException {
System.out.println();
System.out.println(">>> GDB regression trainer example started.");
// Start ignite grid.
@@ -56,36 +58,42 @@ public class GDBOnTreesRegressionTrainerExample {
// Create cache with training data.
CacheConfiguration<Integer, double[]> trainingSetCfg = createCacheConfiguration();
IgniteCache<Integer, double[]> trainingSet = null;
+ Path jsonMdlPath = null;
try {
trainingSet = fillTrainingData(ignite, trainingSetCfg);
// Create regression trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
+ GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
// Train decision tree model.
- Model<Vector, Double> mdl = trainer.fit(
+ GDBModel mdl = trainer.fit(
ignite,
trainingSet,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> | Prediction\t| Valid answer \t|");
- System.out.println(">>> ---------------------------------");
+ System.out.println("\n>>> Exported GDB regression model: " + mdl.toString(true));
- // Calculate score.
- for (int x = -5; x < 5; x++) {
- double predicted = mdl.predict(VectorUtils.of(x));
+ predictOnGeneratedData(mdl);
- System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2));
- }
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ IgniteFunction<Double, Double> lbMapper = lb -> lb;
+ GDBModel modelImportedFromJSON = GDBModel.fromJSON(jsonMdlPath).withLblMapping(lbMapper);
+
+ System.out.println("\n>>> Imported GDB regression model: " + modelImportedFromJSON.toString(true));
+
+ predictOnGeneratedData(modelImportedFromJSON);
- System.out.println(">>> ---------------------------------");
System.out.println(">>> GDB regression trainer example completed.");
}
finally {
- trainingSet.destroy();
+ if (trainingSet != null)
+ trainingSet.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
@@ -93,6 +101,21 @@ public class GDBOnTreesRegressionTrainerExample {
}
}
+ private static void predictOnGeneratedData(GDBModel mdl) {
+ System.out.println(">>> ---------------------------------");
+ System.out.println(">>> | Prediction\t| Valid answer \t|");
+ System.out.println(">>> ---------------------------------");
+
+ // Calculate score.
+ for (int x = -5; x < 5; x++) {
+ double predicted = mdl.predict(VectorUtils.of(x));
+
+ System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, Math.pow(x, 2));
+ }
+
+ System.out.println(">>> ---------------------------------");
+ }
+
/**
* Create cache configuration.
*/
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GaussianNaiveBayesExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GaussianNaiveBayesExportImportExample.java
new file mode 100644
index 0000000..b6fb9c9
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/GaussianNaiveBayesExportImportExample.java
@@ -0,0 +1,117 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
+import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+
+/**
+ * Run naive Bayes classification model based on <a href="https://en.wikipedia.org/wiki/Naive_Bayes_classifier"> naive
+ * Bayes classifier</a> algorithm ({@link GaussianNaiveBayesTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p>
+ * <p>
+ * After that it trains the naive Bayes classification model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict the target value,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class GaussianNaiveBayesExportImportExample {
+ /**
+ * Run example.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Naive Bayes classification model over partitioned dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+
+ System.out.println(">>> Create new Gaussian Naive Bayes classification trainer object.");
+ GaussianNaiveBayesTrainer trainer = new GaussianNaiveBayesTrainer();
+
+ System.out.println("\n>>> Perform the training to get the model.");
+
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+
+ GaussianNaiveBayesModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+ System.out.println("\n>>> Exported Gaussian Naive Bayes model: " + mdl.toString(true));
+
+ double accuracy = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for exported Gaussian Naive Bayes model:" + accuracy);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ GaussianNaiveBayesModel modelImportedFromJSON = GaussianNaiveBayesModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported Gaussian Naive Bayes model: " + modelImportedFromJSON.toString(true));
+
+ accuracy = Evaluator.evaluate(
+ dataCache,
+ modelImportedFromJSON,
+ vectorizer,
+ MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for imported Gaussian Naive Bayes model:" + accuracy);
+
+ System.out.println("\n>>> Gaussian Naive bayes model over partitioned dataset usage example completed.");
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/KMeansClusterizationExportImportExample.java
similarity index 67%
copy from examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/KMeansClusterizationExportImportExample.java
index beee4f6..ec5e689 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/KMeansClusterizationExportImportExample.java
@@ -15,22 +15,21 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.clustering;
+package org.apache.ignite.examples.ml.inference.exchange;
import java.io.IOException;
-import javax.cache.Cache;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
import org.apache.ignite.examples.ml.util.SandboxMLCache;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
-import org.apache.ignite.ml.math.Tracer;
+import org.apache.ignite.ml.math.distances.WeightedMinkowskiDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
@@ -47,7 +46,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
-public class KMeansClusterizationExample {
+public class KMeansClusterizationExportImportExample {
/**
* Run example.
*/
@@ -59,12 +58,15 @@ public class KMeansClusterizationExample {
System.out.println(">>> Ignite grid started.");
IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
try {
dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);
- KMeansTrainer trainer = new KMeansTrainer();
+ KMeansTrainer trainer = new KMeansTrainer()
+ .withDistance(new WeightedMinkowskiDistance(2, new double[] {5.9360, 2.7700, 4.2600, 1.3260}));
+ //.withDistance(new MinkowskiDistance(2));
KMeansModel mdl = trainer.fit(
ignite,
@@ -72,33 +74,22 @@ public class KMeansClusterizationExample {
vectorizer
);
- System.out.println(">>> KMeans centroids");
- Tracer.showAscii(mdl.getCenters()[0]);
- Tracer.showAscii(mdl.getCenters()[1]);
- System.out.println(">>>");
+ System.out.println("\n>>> Exported KMeans model: " + mdl);
- System.out.println(">>> --------------------------------------------");
- System.out.println(">>> | Predicted cluster\t| Erased class label\t|");
- System.out.println(">>> --------------------------------------------");
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ KMeansModel modelImportedFromJSON = KMeansModel.fromJSON(jsonMdlPath);
- double prediction = mdl.predict(inputs);
+ System.out.println("\n>>> Imported KMeans model: " + modelImportedFromJSON);
- System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
-
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
- }
+ System.out.println("\n>>> KMeans clustering algorithm over cached dataset usage example completed.");
}
finally {
if (dataCache != null)
dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LinearRegressionExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LinearRegressionExportImportExample.java
new file mode 100644
index 0000000..723784b
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LinearRegressionExportImportExample.java
@@ -0,0 +1,116 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
+import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+
+/**
+ * Run linear regression model based on <a href="http://web.stanford.edu/group/SOL/software/lsqr/">LSQR algorithm</a>
+ * ({@link LinearRegressionLSQRTrainer}) over cached dataset.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with simple test data.</p>
+ * <p>
+ * After that it trains the linear regression model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict the target value
+ * and compares prediction to expected outcome (ground truth).</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class LinearRegressionExportImportExample {
+ /**
+ * Trains a linear regression model, round-trips it through a temporary JSON file and prints RMSE for both models.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Linear regression model over cache based dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println(">>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.MORTALITY_DATA);
+
+ System.out.println("\n>>> Create new linear regression trainer object.");
+ LinearRegressionLSQRTrainer trainer = new LinearRegressionLSQRTrainer();
+
+ System.out.println("\n>>> Perform the training to get the model.");
+
+ LinearRegressionModel mdl = trainer.fit(
+ ignite,
+ dataCache,
+ new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
+ );
+
+ System.out.println("\n>>> Exported LinearRegression model: " + mdl);
+
+ double rmse = Evaluator.evaluate(
+ dataCache,
+ mdl,
+ new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST),
+ MetricName.RMSE
+ );
+
+ System.out.println("\n>>> RMSE for exported LinearRegression model: " + rmse);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ LinearRegressionModel modelImportedFromJSON = LinearRegressionModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported LinearRegression model: " + modelImportedFromJSON);
+ // Score the model restored from JSON (not the original one) so the printed RMSE reflects the round trip.
+ rmse = Evaluator.evaluate(
+ dataCache,
+ modelImportedFromJSON,
+ new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST),
+ MetricName.RMSE
+ );
+
+ System.out.println("\n>>> RMSE for imported LinearRegression model: " + rmse);
+
+ System.out.println("\n>>> Linear regression model over cache based dataset usage example completed.");
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LogisticRegressionExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LogisticRegressionExportImportExample.java
new file mode 100644
index 0000000..6491f7e
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/LogisticRegressionExportImportExample.java
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.nn.UpdatesStrategy;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate;
+import org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
+import org.apache.ignite.ml.regressions.logistic.LogisticRegressionSGDTrainer;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+
+/**
+ * Run logistic regression model based on <a href="https://en.wikipedia.org/wiki/Stochastic_gradient_descent">
+ * stochastic gradient descent</a> algorithm ({@link LogisticRegressionSGDTrainer}) over distributed cache.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set">Iris dataset</a>).</p>
+ * <p>
+ * After that it trains the logistic regression model based on the specified data.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, applies the trained model to predict the target value,
+ * compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class LogisticRegressionExportImportExample {
+ /**
+ * Trains a logistic regression model, round-trips it through a temporary JSON file and prints accuracy for both models.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Logistic regression model over partitioned dataset usage example started.");
+ // Start ignite grid; try-with-resources closes the node when the example finishes.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println("\n>>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
+
+ System.out.println("\n>>> Create new logistic regression trainer object.");
+ LogisticRegressionSGDTrainer trainer = new LogisticRegressionSGDTrainer()
+ .withUpdatesStgy(new UpdatesStrategy<>(
+ new SimpleGDUpdateCalculator(0.2),
+ SimpleGDParameterUpdate.SUM_LOCAL,
+ SimpleGDParameterUpdate.AVG
+ ))
+ .withMaxIterations(100000)
+ .withLocIterations(100)
+ .withBatchSize(10)
+ .withSeed(123L);
+
+ System.out.println("\n>>> Perform the training to get the model.");
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+
+ LogisticRegressionModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+
+ System.out.println("\n>>> Exported logistic regression model: " + mdl);
+
+ double accuracy = Evaluator.evaluate(dataCache,
+ mdl, vectorizer, MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for exported logistic regression model " + accuracy);
+ // Export the trained model to a temporary JSON file, then read it back as a new model instance.
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ LogisticRegressionModel modelImportedFromJSON = LogisticRegressionModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported logistic regression model: " + modelImportedFromJSON);
+ // Re-evaluate on the same cache using the imported model so both accuracies can be compared.
+ accuracy = Evaluator.evaluate(dataCache,
+ modelImportedFromJSON, vectorizer, MetricName.ACCURACY
+ );
+
+ System.out.println("\n>>> Accuracy for imported logistic regression model " + accuracy);
+
+ System.out.println("\n>>> Logistic regression model over partitioned dataset usage example completed.");
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestClassificationExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestClassificationExportImportExample.java
new file mode 100644
index 0000000..6bb368f
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestClassificationExportImportExample.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import javax.cache.Cache;
+import org.apache.commons.math3.util.Precision;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.tree.randomforest.RandomForestClassifierTrainer;
+import org.apache.ignite.ml.tree.randomforest.RandomForestModel;
+import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
+
+/**
+ * Example represents a solution for the task of wine classification based on a
+ * <a href ="https://en.wikipedia.org/wiki/Random_forest">Random Forest</a> implementation for
+ * multi-classification.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://archive.ics.uci.edu/ml/machine-learning-databases/wine/">Wine recognition dataset</a>).</p>
+ * <p>
+ * After that it initializes the {@link RandomForestClassifierTrainer} with thread pool for multi-thread learning and
+ * trains the model based on the specified data using random forest classification algorithm.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, compares prediction of the trained model to the
+ * expected outcome (ground truth), and evaluates accuracy of the model.</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class RandomForestClassificationExportImportExample {
+ /**
+ * Trains a Random Forest classifier, round-trips it through a temporary JSON file and prints accuracy for both models.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Random Forest multi-class classification algorithm over cached dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println("\n>>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.WINE_RECOGNITION);
+
+ AtomicInteger idx = new AtomicInteger(0);
+ RandomForestClassifierTrainer classifier = new RandomForestClassifierTrainer(
+ IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
+ x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
+ ).withAmountOfTrees(101)
+ .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+ .withMaxDepth(4)
+ .withMinImpurityDelta(0.)
+ .withSubSampleSize(0.3)
+ .withSeed(0);
+
+ System.out.println(">>> Configured trainer: " + classifier.getClass().getSimpleName());
+
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+ RandomForestModel mdl = classifier.fit(ignite, dataCache, vectorizer);
+
+ System.out.println(">>> Exported Random Forest classification model: " + mdl.toString(true));
+
+ double accuracy = evaluateModel(dataCache, mdl);
+
+ System.out.println("\n>>> Accuracy for exported Random Forest classification model " + accuracy);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ RandomForestModel modelImportedFromJSON = RandomForestModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Imported Random Forest classification model: " + modelImportedFromJSON.toString(true));
+ // Score the model restored from JSON (not the original one) so the printed accuracy reflects the round trip.
+ accuracy = evaluateModel(dataCache, modelImportedFromJSON);
+
+ System.out.println("\n>>> Accuracy for imported Random Forest classification model " + accuracy);
+
+ System.out.println("\n>>> Random Forest multi-class classification algorithm over cached dataset usage example completed.");
+
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+ /** Returns accuracy: 1 minus the share of cache entries whose prediction differs from the label at vector index 0. */
+ private static double evaluateModel(IgniteCache<Integer, Vector> dataCache, RandomForestModel randomForestMdl) {
+ int amountOfErrors = 0;
+ int totalAmount = 0;
+
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
+
+ double prediction = randomForestMdl.predict(inputs);
+
+ totalAmount++;
+ if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
+ amountOfErrors++;
+ }
+ }
+
+ return 1 - amountOfErrors / (double) totalAmount;
+ }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestRegressionExportImportExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestRegressionExportImportExample.java
new file mode 100644
index 0000000..4d7d4ad
--- /dev/null
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/RandomForestRegressionExportImportExample.java
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.examples.ml.inference.exchange;
+
+import java.io.IOException;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.util.concurrent.atomic.AtomicInteger;
+import java.util.stream.Collectors;
+import java.util.stream.IntStream;
+import javax.cache.Cache;
+import org.apache.ignite.Ignite;
+import org.apache.ignite.IgniteCache;
+import org.apache.ignite.Ignition;
+import org.apache.ignite.cache.query.QueryCursor;
+import org.apache.ignite.cache.query.ScanQuery;
+import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
+import org.apache.ignite.examples.ml.util.SandboxMLCache;
+import org.apache.ignite.ml.dataset.feature.FeatureMeta;
+import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
+import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
+import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
+import org.apache.ignite.ml.environment.logging.ConsoleLogger;
+import org.apache.ignite.ml.environment.parallelism.ParallelismStrategy;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.tree.randomforest.RandomForestModel;
+import org.apache.ignite.ml.tree.randomforest.RandomForestRegressionTrainer;
+import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
+
+/**
+ * Example represents a solution for the task of price predictions for houses in Boston based on a
+ * <a href ="https://en.wikipedia.org/wiki/Random_forest">Random Forest</a> implementation for regression.
+ * <p>
+ * Code in this example launches Ignite grid and fills the cache with test data points (based on the
+ * <a href="https://archive.ics.uci.edu/ml/machine-learning-databases/housing/">Boston Housing dataset</a>).</p>
+ * <p>
+ * After that it initializes the {@link RandomForestRegressionTrainer} and trains the model based on the specified data
+ * using random forest regression algorithm.</p>
+ * <p>
+ * Finally, this example loops over the test set of data points, compares prediction of the trained model to the
+ * expected outcome (ground truth), and evaluates model quality in terms of Mean Squared Error (MSE) and Mean Absolute
+ * Error (MAE).</p>
+ * <p>
+ * You can change the test data used in this example and re-run it to explore this algorithm further.</p>
+ */
+public class RandomForestRegressionExportImportExample {
+ /**
+ * Run example.
+ */
+ public static void main(String[] args) throws IOException {
+ System.out.println();
+ System.out.println(">>> Random Forest regression algorithm over cached dataset usage example started.");
+ // Start ignite grid.
+ try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
+ System.out.println("\n>>> Ignite grid started.");
+
+ IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
+ try {
+ dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.BOSTON_HOUSE_PRICES);
+
+ AtomicInteger idx = new AtomicInteger(0);
+ RandomForestRegressionTrainer trainer = new RandomForestRegressionTrainer(
+ IntStream.range(0, dataCache.get(1).size() - 1).mapToObj(
+ x -> new FeatureMeta("", idx.getAndIncrement(), false)).collect(Collectors.toList())
+ ).withAmountOfTrees(101)
+ .withFeaturesCountSelectionStrgy(FeaturesCountSelectionStrategies.ONE_THIRD)
+ .withMaxDepth(4)
+ .withMinImpurityDelta(0.)
+ .withSubSampleSize(0.3)
+ .withSeed(0);
+
+ trainer.withEnvironmentBuilder(LearningEnvironmentBuilder.defaultBuilder()
+ .withParallelismStrategyTypeDependency(ParallelismStrategy.ON_DEFAULT_POOL)
+ .withLoggingFactoryDependency(ConsoleLogger.Factory.LOW)
+ );
+
+ System.out.println("\n>>> Configured trainer: " + trainer.getClass().getSimpleName());
+
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
+ RandomForestModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+
+ System.out.println("\n>>> Exported Random Forest regression model: " + mdl.toString(true));
+
+ double mae = evaluateModel(dataCache, mdl);
+
+ System.out.println("\n>>> Mean absolute error (MAE) for exported Random Forest regression model " + mae);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
+
+ RandomForestModel modelImportedFromJSON = RandomForestModel.fromJSON(jsonMdlPath);
+
+ System.out.println("\n>>> Exported Random Forest regression model: " + modelImportedFromJSON.toString(true));
+
+ mae = evaluateModel(dataCache, modelImportedFromJSON);
+
+ System.out.println("\n>>> Mean absolute error (MAE) for exported Random Forest regression model " + mae);
+
+ System.out.println("\n>>> Random Forest regression algorithm over cached dataset usage example completed.");
+ }
+ finally {
+ if (dataCache != null)
+ dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
+ }
+ }
+ finally {
+ System.out.flush();
+ }
+ }
+
+ private static double evaluateModel(IgniteCache<Integer, Vector> dataCache, RandomForestModel randomForestMdl) {
+ double mae = 0.0;
+ int totalAmount = 0;
+
+ try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
+ for (Cache.Entry<Integer, Vector> observation : observations) {
+ Vector val = observation.getValue();
+ Vector inputs = val.copyOfRange(1, val.size());
+ double groundTruth = val.get(0);
+
+ double prediction = randomForestMdl.predict(inputs);
+
+ mae += Math.abs(prediction - groundTruth);
+
+ totalAmount++;
+ }
+
+ mae /= totalAmount;
+ }
+ return mae;
+ }
+}
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/SVMExportImportExample.java
similarity index 51%
copy from examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
copy to examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/SVMExportImportExample.java
index beee4f6..2426290 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/exchange/SVMExportImportExample.java
@@ -15,90 +15,95 @@
* limitations under the License.
*/
-package org.apache.ignite.examples.ml.clustering;
+package org.apache.ignite.examples.ml.inference.exchange;
import java.io.IOException;
-import javax.cache.Cache;
+import java.nio.file.Files;
+import java.nio.file.Path;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.cache.query.QueryCursor;
-import org.apache.ignite.cache.query.ScanQuery;
import org.apache.ignite.examples.ml.util.MLSandboxDatasets;
import org.apache.ignite.examples.ml.util.SandboxMLCache;
-import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
-import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
-import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
+import org.apache.ignite.ml.selection.scoring.metric.MetricName;
+import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
+import org.apache.ignite.ml.svm.SVMLinearClassificationTrainer;
/**
- * Run KMeans clustering algorithm ({@link KMeansTrainer}) over distributed dataset.
+ * Run SVM binary-class classification model ({@link SVMLinearClassificationModel}) over distributed dataset.
* <p>
* Code in this example launches Ignite grid and fills the cache with test data points (based on the
* <a href="https://en.wikipedia.org/wiki/Iris_flower_data_set"></a>Iris dataset</a>).</p>
* <p>
- * After that it trains the model based on the specified data using
- * <a href="https://en.wikipedia.org/wiki/K-means_clustering">KMeans</a> algorithm.</p>
+ * After that it trains the model based on the specified data using the linear SVM algorithm.</p>
* <p>
* Finally, this example loops over the test set of data points, applies the trained model to predict what cluster does
- * this point belong to, and compares prediction to expected outcome (ground truth).</p>
+ * this point belong to, compares prediction to expected outcome (ground truth), and builds
+ * <a href="https://en.wikipedia.org/wiki/Confusion_matrix">confusion matrix</a>.</p>
* <p>
* You can change the test data used in this example and re-run it to explore this algorithm further.</p>
*/
-public class KMeansClusterizationExample {
+public class SVMExportImportExample {
/**
* Run example.
*/
public static void main(String[] args) throws IOException {
System.out.println();
- System.out.println(">>> KMeans clustering algorithm over cached dataset usage example started.");
+ System.out.println(">>> SVM Binary classification model over cached dataset usage example started.");
// Start ignite grid.
try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
- System.out.println(">>> Ignite grid started.");
+ System.out.println("\n>>> Ignite grid started.");
IgniteCache<Integer, Vector> dataCache = null;
+ Path jsonMdlPath = null;
try {
dataCache = new SandboxMLCache(ignite).fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);
+ SVMLinearClassificationTrainer trainer = new SVMLinearClassificationTrainer();
- KMeansTrainer trainer = new KMeansTrainer();
+ Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>()
+ .labeled(Vectorizer.LabelCoordinate.FIRST);
- KMeansModel mdl = trainer.fit(
- ignite,
+ SVMLinearClassificationModel mdl = trainer.fit(ignite, dataCache, vectorizer);
+
+ System.out.println("\n>>> Exported SVM model: " + mdl);
+
+ double accuracy = Evaluator.evaluate(
dataCache,
- vectorizer
+ mdl,
+ vectorizer,
+ MetricName.ACCURACY
);
- System.out.println(">>> KMeans centroids");
- Tracer.showAscii(mdl.getCenters()[0]);
- Tracer.showAscii(mdl.getCenters()[1]);
- System.out.println(">>>");
+ System.out.println("\n>>> Accuracy for exported SVM model: " + accuracy);
+
+ jsonMdlPath = Files.createTempFile(null, null);
+ mdl.toJSON(jsonMdlPath);
- System.out.println(">>> --------------------------------------------");
- System.out.println(">>> | Predicted cluster\t| Erased class label\t|");
- System.out.println(">>> --------------------------------------------");
+ SVMLinearClassificationModel modelImportedFromJSON = SVMLinearClassificationModel.fromJSON(jsonMdlPath);
- try (QueryCursor<Cache.Entry<Integer, Vector>> observations = dataCache.query(new ScanQuery<>())) {
- for (Cache.Entry<Integer, Vector> observation : observations) {
- Vector val = observation.getValue();
- Vector inputs = val.copyOfRange(1, val.size());
- double groundTruth = val.get(0);
+ System.out.println("\n>>> Imported SVM model: " + modelImportedFromJSON);
- double prediction = mdl.predict(inputs);
+ accuracy = Evaluator.evaluate(
+ dataCache,
+ modelImportedFromJSON,
+ vectorizer,
+ MetricName.ACCURACY
+ );
- System.out.printf(">>> | %.4f\t\t\t| %.4f\t\t|\n", prediction, groundTruth);
- }
+ System.out.println("\n>>> Accuracy for imported SVM model: " + accuracy);
- System.out.println(">>> ---------------------------------");
- System.out.println(">>> KMeans clustering algorithm over cached dataset usage example completed.");
- }
+ System.out.println("\n>>> SVM Binary classification model over cache based dataset usage example completed.");
}
finally {
if (dataCache != null)
dataCache.destroy();
+ if (jsonMdlPath != null)
+ Files.deleteIfExists(jsonMdlPath);
}
}
finally {
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
index 3340ed9..d03bb96 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeFromSparkExample.java
@@ -34,7 +34,7 @@ import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Run Decision Tree model loaded from snappy.parquet file. The snappy.parquet file was generated by Spark MLLib
@@ -69,7 +69,7 @@ public class DecisionTreeFromSparkExample {
final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 5, 6, 4).labeled(1);
- DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
+ DecisionTreeModel mdl = (DecisionTreeModel)SparkModelParser.parse(
SPARK_MDL_PATH,
SupportedSparkModels.DECISION_TREE,
env
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
index 9c36198..5fd4461 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/inference/spark/modelparser/DecisionTreeRegressionFromSparkExample.java
@@ -35,7 +35,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.sparkmodelparser.SparkModelParser;
import org.apache.ignite.ml.sparkmodelparser.SupportedSparkModels;
import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Run Decision tree regression model loaded from snappy.parquet file. The snappy.parquet file was generated by Spark
@@ -69,7 +69,7 @@ public class DecisionTreeRegressionFromSparkExample {
final Vectorizer<Integer, Vector, Integer, Double> vectorizer = new DummyVectorizer<Integer>(0, 1, 5, 6).labeled(4);
- DecisionTreeNode mdl = (DecisionTreeNode)SparkModelParser.parse(
+ DecisionTreeModel mdl = (DecisionTreeModel)SparkModelParser.parse(
SPARK_MDL_PATH,
SupportedSparkModels.DECISION_TREE_REGRESSION,
env
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExample.java
index c24091c..233cb13 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExample.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example that shows how to use String Encoder preprocessor to encode features presented as a strings.
@@ -73,7 +73,7 @@ public class EncoderExample {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
encoderPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExampleWithNormalization.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExampleWithNormalization.java
index d9482a5..7270b03 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExampleWithNormalization.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/EncoderExampleWithNormalization.java
@@ -32,7 +32,7 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example that shows how to combine together two preprocessors: String Encoder preprocessor to encode features presented as a strings
@@ -80,7 +80,7 @@ public class EncoderExampleWithNormalization {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
normalizer
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/LabelEncoderExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/LabelEncoderExample.java
index d97c49c..3547d7e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/LabelEncoderExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/preprocessing/encoding/LabelEncoderExample.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.encoding.EncoderType;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example that shows how to use Label Encoder preprocessor to encode labels presented as a strings.
@@ -79,7 +79,7 @@ public class LabelEncoderExample {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
lbEncoderPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/BostonHousePricesPredictionExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/BostonHousePricesPredictionExample.java
index 511eb05..c572d81 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/BostonHousePricesPredictionExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/regression/linear/BostonHousePricesPredictionExample.java
@@ -105,7 +105,7 @@ public class BostonHousePricesPredictionExample {
private static String toString(LinearRegressionModel mdl) {
BiFunction<Integer, Double, String> formatter = (idx, val) -> String.format("%.2f*f%d", val, idx);
- Vector weights = mdl.getWeights();
+ Vector weights = mdl.weights();
StringBuilder sb = new StringBuilder(formatter.apply(0, weights.get(0)));
for (int fid = 1; fid < weights.size(); fid++) {
@@ -114,7 +114,7 @@ public class BostonHousePricesPredictionExample {
.append(formatter.apply(fid, Math.abs(w)));
}
- double intercept = mdl.getIntercept();
+ double intercept = mdl.intercept();
sb.append(" ").append(intercept > 0 ? "+" : "-").append(" ")
.append(String.format("%.2f", Math.abs(intercept)));
return sb.toString();
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
index e6a4461..93dc051 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/selection/cv/CrossValidationExample.java
@@ -30,7 +30,7 @@ import org.apache.ignite.ml.selection.cv.CrossValidation;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Run <a href="https://en.wikipedia.org/wiki/Decision_tree">decision tree</a> classification with
@@ -75,7 +75,7 @@ public class CrossValidationExample {
LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
- CrossValidation<DecisionTreeNode, Integer, LabeledVector<Double>> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, LabeledVector<Double>> scoreCalculator
= new CrossValidation<>();
double[] accuracyScores = scoreCalculator
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
index 543e211..68058b7 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLInferenceExample.java
@@ -30,7 +30,7 @@ import org.apache.ignite.ml.inference.IgniteModelStorageUtil;
import org.apache.ignite.ml.sql.SQLFunctions;
import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
import static org.apache.ignite.examples.ml.sql.DecisionTreeClassificationTrainerSQLTableExample.loadTitanicDatasets;
@@ -101,7 +101,7 @@ public class DecisionTreeClassificationTrainerSQLInferenceExample {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
System.out.println(">>> Perform training...");
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"),
new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
.withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
index 083608e..d05d1a9 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/sql/DecisionTreeClassificationTrainerSQLTableExample.java
@@ -34,7 +34,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.sql.SqlDatasetBuilder;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer} on a data stored in SQL table.
@@ -101,7 +101,7 @@ public class DecisionTreeClassificationTrainerSQLTableExample {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);
System.out.println(">>> Perform training...");
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
new SqlDatasetBuilder(ignite, "SQL_PUBLIC_TITANIC_TRAIN"),
new BinaryObjectVectorizer<>("pclass", "age", "sibsp", "parch", "fare")
.withFeature("sex", BinaryObjectVectorizer.Mapping.create().map("male", 1.0).defaultValue(0.0))
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
index 600f4a5..b1cf23e 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeClassificationTrainerExample.java
@@ -28,7 +28,7 @@ import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorize
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Example of using distributed {@link DecisionTreeClassificationTrainer}.
@@ -75,7 +75,7 @@ public class DecisionTreeClassificationTrainerExample {
// Train decision tree model.
LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
trainingSet,
vectorizer
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
index 1a19771..5cfb828 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/DecisionTreeRegressionTrainerExample.java
@@ -25,7 +25,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
import org.apache.ignite.ml.tree.DecisionTreeRegressionTrainer;
/**
@@ -70,7 +70,7 @@ public class DecisionTreeRegressionTrainerExample {
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(10, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
+ DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, new LabeledDummyVectorizer<>());
System.out.println(">>> Decision tree regression model: " + mdl);
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
index a2eaf47..7e6c5d3 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesClassificationTrainerExample.java
@@ -22,12 +22,12 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
import org.jetbrains.annotations.NotNull;
@@ -58,11 +58,11 @@ public class GDBOnTreesClassificationTrainerExample {
trainingSet = fillTrainingData(ignite, trainingSetCfg);
// Create classification trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
+ GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(1.0, 300, 2, 0.)
.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.1));
// Train decision tree model.
- ModelsComposition mdl = trainer.fit(
+ GDBModel mdl = trainer.fit(
ignite,
trainingSet,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
index 09dd708..a6ea135 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tree/boosting/GDBOnTreesRegressionTrainerExample.java
@@ -22,14 +22,12 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
-import org.apache.ignite.ml.inference.Model;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
import org.jetbrains.annotations.NotNull;
@@ -60,11 +58,11 @@ public class GDBOnTreesRegressionTrainerExample {
trainingSet = fillTrainingData(ignite, trainingSetCfg);
// Create regression trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
+ GDBTrainer trainer = new GDBRegressionOnTreesTrainer(1.0, 2000, 1, 0.)
.withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.001));
// Train decision tree model.
- Model<Vector, Double> mdl = trainer.fit(
+ GDBModel mdl = trainer.fit(
ignite,
trainingSet,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_11_Boosting.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_11_Boosting.java
index b9006f5..b8e1d00 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_11_Boosting.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_11_Boosting.java
@@ -21,7 +21,8 @@ import java.io.FileNotFoundException;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
import org.apache.ignite.Ignition;
-import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
+import org.apache.ignite.ml.composition.boosting.GDBTrainer;
import org.apache.ignite.ml.composition.boosting.convergence.median.MedianOfMedianConvergenceCheckerFactory;
import org.apache.ignite.ml.dataset.feature.extractor.Vectorizer;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer;
@@ -36,7 +37,6 @@ import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
-import org.apache.ignite.ml.trainers.DatasetTrainer;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
/**
@@ -102,11 +102,11 @@ public class Step_11_Boosting {
);
// Create classification trainer.
- DatasetTrainer<ModelsComposition, Double> trainer = new GDBBinaryClassifierOnTreesTrainer(0.5, 500, 4, 0.)
+ GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.5, 500, 4, 0.)
.withCheckConvergenceStgyFactory(new MedianOfMedianConvergenceCheckerFactory(0.1));
// Train decision tree model.
- ModelsComposition mdl = trainer.fit(
+ GDBModel mdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
index b6df5d6..97ccb58 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_1_Read_and_Learn.java
@@ -27,7 +27,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Usage of {@link DecisionTreeClassificationTrainer} to predict death in the disaster.
@@ -56,7 +56,7 @@ public class Step_1_Read_and_Learn {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
vectorizer
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
index 094a966..a020dbe 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_2_Imputing.java
@@ -29,7 +29,7 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Usage of {@link ImputerTrainer} to fill missed data ({@code Double.NaN}) values in the chosen columns.
@@ -66,7 +66,7 @@ public class Step_2_Imputing {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
vectorizer
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
index 68b05a4..c97ee38 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Let's add two categorial features "sex", "embarked" to predict more precisely than in {@link Step_1_Read_and_Learn}.
@@ -80,7 +80,7 @@ public class Step_3_Categorial {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
imputingPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
index 206d2dc..1355979 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_3_Categorial_with_One_Hot_Encoder.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Let's add two categorial features "sex", "embarked" to predict more precisely than in {@link
@@ -83,7 +83,7 @@ public class Step_3_Categorial_with_One_Hot_Encoder {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
imputingPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
index 1d85a14..f4763a1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_4_Add_age_fare.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.preprocessing.imputing.ImputerTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* Add yet two numerical features "age", "fare" to improve our model over {@link Step_3_Categorial}.
@@ -79,7 +79,7 @@ public class Step_4_Add_age_fare {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
imputingPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
index dfb6de0..05d0137 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_5_Scaling.java
@@ -33,7 +33,7 @@ import org.apache.ignite.ml.preprocessing.normalization.NormalizationTrainer;
import org.apache.ignite.ml.selection.scoring.evaluator.Evaluator;
import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* {@link MinMaxScalerTrainer} and {@link NormalizationTrainer} are used in this example due to different values
@@ -97,7 +97,7 @@ public class Step_5_Scaling {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
normalizationPreprocessor
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
index e104c51..a60a8ba 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_7_Split_train_test.java
@@ -35,7 +35,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* The highest accuracy in the previous example ({@link Step_6_KNN}) is the result of
@@ -103,7 +103,7 @@ public class Step_7_Split_train_test {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(5, 0);
// Train decision tree model.
- DecisionTreeNode mdl = trainer.fit(
+ DecisionTreeModel mdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
index 0da797d..20f4a72 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV.java
@@ -38,7 +38,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation will be used in this example.
@@ -126,7 +126,7 @@ public class Step_8_CV {
DecisionTreeClassificationTrainer trainer
= new DecisionTreeClassificationTrainer(maxDeep, 0);
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
double[] scores = scoreCalculator
@@ -167,7 +167,7 @@ public class Step_8_CV {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(bestMaxDeep, 0);
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
index 5b62714..963e1b7 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid.java
@@ -40,7 +40,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -119,7 +119,7 @@ public class Step_8_CV_with_Param_Grid {
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -156,7 +156,7 @@ public class Step_8_CV_with_Param_Grid {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java
index 6be8496..1aa2d57 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/Step_8_CV_with_Param_Grid_and_pipeline.java
@@ -36,7 +36,7 @@ import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -91,7 +91,7 @@ public class Step_8_CV_with_Param_Grid_and_pipeline {
// Tune hyper-parameters with K-fold Cross-Validation on the split training set.
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_13_RandomSearch.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_13_RandomSearch.java
index d7e2f27..c489fc9 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_13_RandomSearch.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_13_RandomSearch.java
@@ -42,7 +42,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -123,7 +123,7 @@ public class Step_13_RandomSearch {
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -166,7 +166,7 @@ public class Step_13_RandomSearch {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_14_Parallel_Brute_Force_Search.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_14_Parallel_Brute_Force_Search.java
index 017f123..b63bf96 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_14_Parallel_Brute_Force_Search.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_14_Parallel_Brute_Force_Search.java
@@ -45,7 +45,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -126,7 +126,7 @@ public class Step_14_Parallel_Brute_Force_Search {
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -168,7 +168,7 @@ public class Step_14_Parallel_Brute_Force_Search {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_15_Parallel_Random_Search.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_15_Parallel_Random_Search.java
index 3a3e9e8..ac6c1eb 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_15_Parallel_Random_Search.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_15_Parallel_Random_Search.java
@@ -45,7 +45,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -125,7 +125,7 @@ public class Step_15_Parallel_Random_Search {
// Tune hyper-parameters with K-fold Cross-Validation on the split training set.
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -171,7 +171,7 @@ public class Step_15_Parallel_Random_Search {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_16_Genetic_Programming_Search.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_16_Genetic_Programming_Search.java
index bee51e4..408eb48 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_16_Genetic_Programming_Search.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_16_Genetic_Programming_Search.java
@@ -42,7 +42,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -123,7 +123,7 @@ public class Step_16_Genetic_Programming_Search {
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -162,7 +162,7 @@ public class Step_16_Genetic_Programming_Search {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_17_Parallel_Genetic_Programming_Search.java b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_17_Parallel_Genetic_Programming_Search.java
index 34a8158..a9d39bd 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_17_Parallel_Genetic_Programming_Search.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/tutorial/hyperparametertuning/Step_17_Parallel_Genetic_Programming_Search.java
@@ -45,7 +45,7 @@ import org.apache.ignite.ml.selection.scoring.metric.classification.Accuracy;
import org.apache.ignite.ml.selection.split.TrainTestDatasetSplitter;
import org.apache.ignite.ml.selection.split.TrainTestSplit;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
/**
* To choose the best hyper-parameters the cross-validation with {@link ParamGrid} will be used in this example.
@@ -126,7 +126,7 @@ public class Step_17_Parallel_Genetic_Programming_Search {
DecisionTreeClassificationTrainer trainerCV = new DecisionTreeClassificationTrainer();
- CrossValidation<DecisionTreeNode, Integer, Vector> scoreCalculator
+ CrossValidation<DecisionTreeModel, Integer, Vector> scoreCalculator
= new CrossValidation<>();
ParamGrid paramGrid = new ParamGrid()
@@ -168,7 +168,7 @@ public class Step_17_Parallel_Genetic_Programming_Search {
-> System.out.println("Score " + Arrays.toString(score) + " for hyper params " + hyperParams));
// Train decision tree model.
- DecisionTreeNode bestMdl = trainer.fit(
+ DecisionTreeModel bestMdl = trainer.fit(
ignite,
dataCache,
split.getTrainFilter(),
diff --git a/modules/ml/pom.xml b/modules/ml/pom.xml
index 338d254..37d9c10 100644
--- a/modules/ml/pom.xml
+++ b/modules/ml/pom.xml
@@ -160,6 +160,31 @@
<artifactId>slf4j-api</artifactId>
<version>1.7.7</version>
</dependency>
+ <dependency>
+ <groupId>javax.xml.bind</groupId>
+ <artifactId>jaxb-api</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.sun.xml.bind</groupId>
+ <artifactId>jaxb-core</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>com.sun.xml.bind</groupId>
+ <artifactId>jaxb-impl</artifactId>
+ <version>2.3.0</version>
+ </dependency>
+ <dependency>
+ <groupId>javax.activation</groupId>
+ <artifactId>activation</artifactId>
+ <version>1.1.1</version>
+ </dependency>
+ <dependency>
+ <groupId>com.fasterxml.jackson.core</groupId>
+ <artifactId>jackson-databind</artifactId>
+ <version>2.10.3</version>
+ </dependency>
</dependencies>
<profiles>
diff --git a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
index 8d349a1..373da3a 100644
--- a/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
+++ b/modules/ml/spark-model-parser/src/main/java/org/apache/ignite/ml/sparkmodelparser/SparkModelParser.java
@@ -25,7 +25,6 @@ import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import java.util.NavigableMap;
import java.util.Scanner;
import java.util.TreeMap;
import org.apache.hadoop.conf.Configuration;
@@ -34,7 +33,7 @@ import org.apache.ignite.internal.util.IgniteUtils;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.composition.ModelsComposition;
-import org.apache.ignite.ml.composition.boosting.GDBTrainer;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
@@ -49,9 +48,7 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.regressions.linear.LinearRegressionModel;
import org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel;
import org.apache.ignite.ml.svm.SVMLinearClassificationModel;
-import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
-import org.apache.ignite.ml.tree.DecisionTreeLeafNode;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.NodeData;
import org.apache.parquet.column.page.PageReadStore;
import org.apache.parquet.example.data.Group;
import org.apache.parquet.example.data.simple.SimpleGroup;
@@ -66,6 +63,8 @@ import org.apache.parquet.schema.Type;
import org.jetbrains.annotations.NotNull;
import org.jetbrains.annotations.Nullable;
+import static org.apache.ignite.ml.tree.NodeData.buildDecisionTreeModel;
+
/** Parser of Spark models. */
public class SparkModelParser {
/**
@@ -497,7 +496,7 @@ public class SparkModelParser {
final List<IgniteModel<Vector, Double>> models = new ArrayList<>();
nodesByTreeId.forEach((key, nodes) -> models.add(buildDecisionTreeModel(nodes)));
- return new GDBTrainer.GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
+ return new GDBModel(models, new WeightedPredictionsAggregator(treeWeights), lbMapper);
}
catch (IOException e) {
String msg = "Error reading parquet file: " + e.getMessage();
@@ -604,42 +603,13 @@ public class SparkModelParser {
return null;
}
- /**
- * Builds the DT model by the given sorted map of nodes.
- *
- * @param nodes The sorted map of nodes.
- */
- private static DecisionTreeNode buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
- DecisionTreeNode mdl = null;
- if (!nodes.isEmpty()) {
- NodeData rootNodeData = (NodeData)((NavigableMap)nodes).firstEntry().getValue();
- mdl = buildTree(nodes, rootNodeData);
- return mdl;
- }
- return mdl;
- }
-
- /**
- * Build tree or sub-tree based on indices and nodes sorted map as a dictionary.
- *
- * @param nodes The sorted map of nodes.
- * @param rootNodeData Root node data.
- */
- @NotNull private static DecisionTreeNode buildTree(Map<Integer, NodeData> nodes,
- NodeData rootNodeData) {
- return rootNodeData.isLeafNode ? new DecisionTreeLeafNode(rootNodeData.prediction) : new DecisionTreeConditionalNode(rootNodeData.featureIdx,
- rootNodeData.threshold,
- buildTree(nodes, nodes.get(rootNodeData.rightChildId)),
- buildTree(nodes, nodes.get(rootNodeData.leftChildId)),
- null);
- }
/**
* Form the node data according data in parquet row.
*
* @param g The given group presenting the node data from Spark DT model.
*/
- @NotNull private static SparkModelParser.NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
+ @NotNull private static NodeData extractNodeDataFromParquetRow(SimpleGroup g) {
NodeData nodeData = new NodeData();
nodeData.id = g.getInteger(0, 0);
@@ -888,43 +858,4 @@ public class SparkModelParser {
}
return coefficients;
}
-
- /**
- * Presenting data from one parquet row filled with NodeData in Spark DT model.
- */
- private static class NodeData {
- /** Id. */
- int id;
-
- /** Prediction. */
- double prediction;
-
- /** Left child id. */
- int leftChildId;
-
- /** Right child id. */
- int rightChildId;
-
- /** Threshold. */
- double threshold;
-
- /** Feature index. */
- int featureIdx;
-
- /** Is leaf node. */
- boolean isLeafNode;
-
- /** {@inheritDoc} */
- @Override public String toString() {
- return "NodeData{" +
- "id=" + id +
- ", prediction=" + prediction +
- ", leftChildId=" + leftChildId +
- ", rightChildId=" + rightChildId +
- ", threshold=" + threshold +
- ", featureIdx=" + featureIdx +
- ", isLeafNode=" + isLeafNode +
- '}';
- }
- }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java
index fda08b3..2546d0c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/gmm/GmmModel.java
@@ -19,6 +19,7 @@ package org.apache.ignite.ml.clustering.gmm;
import java.util.Collections;
import java.util.List;
+import com.fasterxml.jackson.annotation.JsonIgnore;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -47,12 +48,17 @@ public class GmmModel extends DistributionMixture<MultivariateGaussianDistributi
super(componentProbs, distributions);
}
+ /** */
+ public GmmModel() {
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector input) {
return (double)likelihood(input).maxElement().index();
}
/** {@inheritDoc} */
+ @JsonIgnore
@Override public List<Object> getDependencies() {
return Collections.emptyList();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java
index 42b0823..4fba739 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java
@@ -22,8 +22,8 @@ import org.apache.ignite.ml.IgniteModel;
/** Base interface for all clusterization models. */
public interface ClusterizationModel<P, V> extends IgniteModel<P, V> {
/** Gets the clusters count. */
- public int getAmountOfClusters();
+ public int amountOfClusters();
/** Get cluster centers. */
- public P[] getCenters();
+ public P[] centers();
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java
index f1f677f..de473c9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansModel.java
@@ -17,28 +17,41 @@
package org.apache.ignite.ml.clustering.kmeans;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
+import java.util.UUID;
import java.util.stream.Collectors;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.Tracer;
import org.apache.ignite.ml.math.distances.DistanceMeasure;
+import org.apache.ignite.ml.math.distances.EuclideanDistance;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
+import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
import org.apache.ignite.ml.util.ModelTrace;
/**
* This class encapsulates result of clusterization by KMeans algorithm.
*/
public final class KMeansModel implements ClusterizationModel<Vector, Integer>, Exportable<KMeansModelFormat>,
- DeployableObject {
+ JSONWritable, DeployableObject {
/** Centers of clusters. */
- private final Vector[] centers;
+ private Vector[] centers;
/** Distance measure. */
- private final DistanceMeasure distanceMeasure;
+ private DistanceMeasure distanceMeasure = new EuclideanDistance();
/**
* Construct KMeans model with given centers and distanceMeasure measure.
@@ -51,18 +64,45 @@ public final class KMeansModel implements ClusterizationModel<Vector, Integer>,
this.distanceMeasure = distanceMeasure;
}
+ /** Creates a model without centers; used when restoring a model from its JSON representation. */
+ private KMeansModel() {
+
+ }
+
/** Distance measure. */
public DistanceMeasure distanceMeasure() {
return distanceMeasure;
}
/** {@inheritDoc} */
- @Override public int getAmountOfClusters() {
+ @Override public int amountOfClusters() {
return centers.length;
}
+ /**
+ * Set up the centroids.
+ *
+ * @param centers New cluster centers (centroids).
+ * @return This model instance with the new centers set.
+ */
+ public KMeansModel withCentroids(Vector[] centers) {
+ this.centers = centers;
+ return this;
+ }
+
+ /**
+ * Set up the distance measure.
+ *
+ * @param distanceMeasure New distance measure.
+ * @return This model instance with the new distance measure set.
+ */
+ public KMeansModel withDistanceMeasure(DistanceMeasure distanceMeasure) {
+ this.distanceMeasure = distanceMeasure;
+ return this;
+ }
+
/** {@inheritDoc} */
- @Override public Vector[] getCenters() {
+ @Override public Vector[] centers() {
return Arrays.copyOf(centers, centers.length);
}
@@ -119,12 +159,11 @@ public final class KMeansModel implements ClusterizationModel<Vector, Integer>,
/** {@inheritDoc} */
@Override public String toString(boolean pretty) {
- String measureName = distanceMeasure.getClass().getSimpleName();
List<String> centersList = Arrays.stream(centers).map(x -> Tracer.asAscii(x, "%.4f", false))
.collect(Collectors.toList());
return ModelTrace.builder("KMeansModel", pretty)
- .addField("distance measure", measureName)
+ .addField("distance measure", distanceMeasure.toString())
.addField("centroids", centersList)
.toString();
}
@@ -133,4 +172,76 @@ public final class KMeansModel implements ClusterizationModel<Vector, Integer>,
@Override public List<Object> getDependencies() {
return Collections.singletonList(distanceMeasure);
}
+
+ /** Loads KMeansModel from JSON file; prints the stack trace and returns {@code null} if an {@link IOException} occurs. */
+ public static KMeansModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+
+ KMeansJSONExportModel exportModel;
+ try {
+ exportModel = mapper
+ .readValue(new File(path.toAbsolutePath().toString()), KMeansJSONExportModel.class);
+
+ return exportModel.convert();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+
+ // TODO: see Spark's KMeans PMML export for reference: https://github.com/apache/spark/blob/master/mllib/src/main/scala/org/apache/spark/mllib/pmml/export/KMeansPMMLModelExport.scala
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ try {
+ KMeansJSONExportModel exportModel = new KMeansJSONExportModel(System.currentTimeMillis(), "ann_" + UUID.randomUUID().toString(), KMeansModel.class.getSimpleName());
+ List<double[]> listOfCenters = new ArrayList<>();
+ for (int i = 0; i < centers.length; i++) {
+ listOfCenters.add(centers[i].asArray());
+ }
+
+ exportModel.mdlCenters = listOfCenters;
+ exportModel.distanceMeasure = distanceMeasure;
+
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, exportModel);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** JSON-serializable representation of {@link KMeansModel}, used for export and import. */
+ public static class KMeansJSONExportModel extends JSONModel {
+ /** Centers of clusters. */
+ public List<double[]> mdlCenters;
+
+ /** Distance measure. */
+ public DistanceMeasure distanceMeasure;
+
+ /** Creates an export model with the given timestamp, uid and model class name. */
+ public KMeansJSONExportModel(Long timestamp, String uid, String modelClass) {
+ super(timestamp, uid, modelClass);
+ }
+
+ /** No-arg constructor used by Jackson when reading the model from JSON. */
+ @JsonCreator
+ public KMeansJSONExportModel() {
+ }
+
+ /** {@inheritDoc} */
+ @Override public KMeansModel convert() {
+ KMeansModel mdl = new KMeansModel();
+ Vector[] centers = new DenseVector[mdlCenters.size()];
+ for (int i = 0; i < mdlCenters.size(); i++) {
+ centers[i] = VectorUtils.of(mdlCenters.get(i));
+ }
+
+ DistanceMeasure distanceMeasure = this.distanceMeasure;
+
+ mdl.withCentroids(centers);
+ mdl.withDistanceMeasure(distanceMeasure);
+ return mdl;
+ }
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index caec370e..c36dd34 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -102,7 +102,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
return getLastTrainedModelOrThrowEmptyDatasetException(mdl);
centers = Optional.ofNullable(mdl)
- .map(KMeansModel::getCenters)
+ .map(KMeansModel::centers)
.orElseGet(() -> initClusterCentersRandomly(dataset, k));
boolean converged = false;
@@ -139,7 +139,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
/** {@inheritDoc} */
@Override public boolean isUpdateable(KMeansModel mdl) {
- return mdl.getCenters().length == k && mdl.distanceMeasure().equals(distance);
+ return mdl.centers().length == k && mdl.distanceMeasure().equals(distance);
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
index 3942b9e..190203c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsComposition.java
@@ -19,6 +19,8 @@ package org.apache.ignite.ml.composition;
import java.util.Collections;
import java.util.List;
+
+import com.fasterxml.jackson.annotation.JsonIgnore;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
@@ -30,17 +32,17 @@ import org.apache.ignite.ml.util.ModelTrace;
/**
* Model consisting of several models and prediction aggregation strategy.
*/
-public class ModelsComposition implements IgniteModel<Vector, Double>, Exportable<ModelsCompositionFormat>,
+public class ModelsComposition<M extends IgniteModel<Vector, Double>> implements IgniteModel<Vector, Double>, Exportable<ModelsCompositionFormat>,
DeployableObject {
/**
* Predictions aggregator.
*/
- private final PredictionsAggregator predictionsAggregator;
+ protected PredictionsAggregator predictionsAggregator;
/**
* Models.
*/
- private final List<IgniteModel<Vector, Double>> models;
+ protected List<M> models;
/**
* Constructs a new instance of composition of models.
@@ -48,11 +50,14 @@ public class ModelsComposition implements IgniteModel<Vector, Double>, Exportabl
* @param models Basic models.
* @param predictionsAggregator Predictions aggregator.
*/
- public ModelsComposition(List<? extends IgniteModel<Vector, Double>> models, PredictionsAggregator predictionsAggregator) {
+ public ModelsComposition(List<M> models, PredictionsAggregator predictionsAggregator) {
this.predictionsAggregator = predictionsAggregator;
this.models = Collections.unmodifiableList(models);
}
+ public ModelsComposition() {
+ }
+
/**
* Applies containing models to features and aggregate them to one prediction.
*
@@ -78,7 +83,7 @@ public class ModelsComposition implements IgniteModel<Vector, Double>, Exportabl
/**
* Returns containing models.
*/
- public List<IgniteModel<Vector, Double>> getModels() {
+ public List<M> getModels() {
return models;
}
@@ -102,6 +107,7 @@ public class ModelsComposition implements IgniteModel<Vector, Double>, Exportabl
}
/** {@inheritDoc} */
+ @JsonIgnore
@Override public List<Object> getDependencies() {
return Collections.singletonList(predictionsAggregator);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
index ba71afa..c49638c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/ModelsCompositionFormat.java
@@ -33,7 +33,7 @@ public class ModelsCompositionFormat implements Serializable {
private static final long serialVersionUID = 9115341364082681837L;
/** Models. */
- private List<IgniteModel<Vector, Double>> models;
+ private List<? extends IgniteModel<Vector, Double>> models;
/** Predictions aggregator. */
private PredictionsAggregator predictionsAggregator;
@@ -44,13 +44,13 @@ public class ModelsCompositionFormat implements Serializable {
* @param models Models.
* @param predictionsAggregator Predictions aggregator.
*/
- public ModelsCompositionFormat(List<IgniteModel<Vector, Double>> models,PredictionsAggregator predictionsAggregator) {
+ public ModelsCompositionFormat(List<? extends IgniteModel<Vector, Double>> models, PredictionsAggregator predictionsAggregator) {
this.models = models;
this.predictionsAggregator = predictionsAggregator;
}
/** */
- public List<IgniteModel<Vector, Double>> models() {
+ public List<? extends IgniteModel<Vector, Double>> models() {
return models;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
index 44137f7..45b4318 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBLearningStrategy.java
@@ -103,7 +103,7 @@ public class GDBLearningStrategy {
* @param <V> Type of a value in {@code upstream} data.
* @return Updated models list.
*/
- public <K, V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
+ public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate,
DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> preprocessor) {
if (trainerEnvironment == null)
throw new IllegalStateException("Learning environment builder is not set.");
@@ -148,7 +148,7 @@ public class GDBLearningStrategy {
* @param mdlToUpdate Model to update.
* @return List of already learned models.
*/
- @NotNull protected List<IgniteModel<Vector, Double>> initLearningState(GDBTrainer.GDBModel mdlToUpdate) {
+ @NotNull protected List<IgniteModel<Vector, Double>> initLearningState(GDBModel mdlToUpdate) {
List<IgniteModel<Vector, Double>> models = new ArrayList<>();
if (mdlToUpdate != null) {
models.addAll(mdlToUpdate.getModels());
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBModel.java
new file mode 100644
index 0000000..35cb70e
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBModel.java
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.composition.boosting;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.List;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
+import org.apache.ignite.ml.math.functions.IgniteFunction;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
+
+/**
+ * GDB model.
+ */
+public final class GDBModel extends ModelsComposition<DecisionTreeModel> implements JSONWritable {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 3476661240155508004L;
+
+ /** Internal to external lbl mapping. */
+ @JsonIgnore private IgniteFunction<Double, Double> internalToExternalLblMapping;
+
+ /**
+ * Creates an instance of GDBModel.
+ *
+ * @param models Models.
+ * @param predictionsAggregator Predictions aggregator.
+ * @param internalToExternalLblMapping Internal to external lbl mapping.
+ */
+ public GDBModel(List<? extends IgniteModel<Vector, Double>> models,
+ WeightedPredictionsAggregator predictionsAggregator,
+ IgniteFunction<Double, Double> internalToExternalLblMapping) {
+
+ super((List<DecisionTreeModel>) models, predictionsAggregator);
+ this.internalToExternalLblMapping = internalToExternalLblMapping;
+ }
+
+ private GDBModel() {
+ }
+
+ public GDBModel withLblMapping(IgniteFunction<Double, Double> internalToExternalLblMapping) {
+ this.internalToExternalLblMapping = internalToExternalLblMapping;
+ return this;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double predict(Vector features) {
+ if (internalToExternalLblMapping == null) {
+ throw new IllegalArgumentException("The mapping should not be empty. Initialize it with apropriate function. ");
+ } else {
+ return internalToExternalLblMapping.apply(super.predict(features));
+ }
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(GDBModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(GDBModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", GDBModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads GDBModel from JSON file. */
+ public static GDBModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ GDBModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, GDBModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), GDBModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
index ad35d80..a36feec 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/boosting/GDBTrainer.java
@@ -22,7 +22,6 @@ import java.util.Arrays;
import java.util.List;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory;
import org.apache.ignite.ml.composition.boosting.loss.Loss;
@@ -34,7 +33,6 @@ import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
import org.apache.ignite.ml.environment.logging.MLLogger;
import org.apache.ignite.ml.knn.regression.KNNRegressionTrainer;
-import org.apache.ignite.ml.math.functions.IgniteFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer;
@@ -57,7 +55,7 @@ import org.jetbrains.annotations.NotNull;
*
* But in practice Decision Trees is most used regressors (see: {@link DecisionTreeRegressionTrainer}).
*/
-public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Double> {
+public abstract class GDBTrainer extends DatasetTrainer<GDBModel, Double> {
/** Gradient step. */
private final double gradientStep;
@@ -87,13 +85,13 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
}
/** {@inheritDoc} */
- @Override public <K, V> ModelsComposition fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
+ @Override public <K, V> GDBModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
Preprocessor<K, V> preprocessor) {
return updateModel(null, datasetBuilder, preprocessor);
}
/** {@inheritDoc} */
- @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl,
+ @Override protected <K, V> GDBModel updateModel(GDBModel mdl,
DatasetBuilder<K, V> datasetBuilder,
Preprocessor<K, V> preprocessor) {
if (!learnLabels(datasetBuilder, preprocessor))
@@ -121,7 +119,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
List<IgniteModel<Vector, Double>> models;
if (mdl != null)
- models = stgy.update((GDBModel) mdl, datasetBuilder, preprocessor);
+ models = stgy.update(mdl, datasetBuilder, preprocessor);
else
models = stgy.learnModels(datasetBuilder, preprocessor);
@@ -136,7 +134,7 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
}
/** {@inheritDoc} */
- @Override public boolean isUpdateable(ModelsComposition mdl) {
+ @Override public boolean isUpdateable(GDBModel mdl) {
return mdl instanceof GDBModel;
}
@@ -239,35 +237,4 @@ public abstract class GDBTrainer extends DatasetTrainer<ModelsComposition, Doubl
protected GDBLearningStrategy getLearningStrategy() {
return new GDBLearningStrategy();
}
-
- /**
- * GDB model.
- */
- public static final class GDBModel extends ModelsComposition {
- /** Serial version uid. */
- private static final long serialVersionUID = 3476661240155508004L;
-
- /** Internal to external lbl mapping. */
- private final IgniteFunction<Double, Double> internalToExternalLblMapping;
-
- /**
- * Creates an instance of GDBModel.
- *
- * @param models Models.
- * @param predictionsAggregator Predictions aggregator.
- * @param internalToExternalLblMapping Internal to external lbl mapping.
- */
- public GDBModel(List<? extends IgniteModel<Vector, Double>> models,
- WeightedPredictionsAggregator predictionsAggregator,
- IgniteFunction<Double, Double> internalToExternalLblMapping) {
-
- super(models, predictionsAggregator);
- this.internalToExternalLblMapping = internalToExternalLblMapping;
- }
-
- /** {@inheritDoc} */
- @Override public Double predict(Vector features) {
- return internalToExternalLblMapping.apply(super.predict(features));
- }
- }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java
index d996a2a..1490b7c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/PredictionsAggregator.java
@@ -17,11 +17,20 @@
package org.apache.ignite.ml.composition.predictionsaggregator;
+import com.fasterxml.jackson.annotation.JsonSubTypes;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.ignite.ml.math.functions.IgniteFunction;
/**
* Predictions aggregator interface.
*/
+@JsonTypeInfo( use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
+@JsonSubTypes(
+ {
+ @JsonSubTypes.Type(value = MeanValuePredictionsAggregator.class, name = "MeanValuePredictionsAggregator"),
+ @JsonSubTypes.Type(value = OnMajorityPredictionsAggregator.class, name = "OnMajorityPredictionsAggregator"),
+ @JsonSubTypes.Type(value = WeightedPredictionsAggregator.class, name = "WeightedPredictionsAggregator"),
+ })
public interface PredictionsAggregator extends IgniteFunction<double[], Double> {
/**
* Represents aggregator as String.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java
index 555ff3c..257c635 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/composition/predictionsaggregator/WeightedPredictionsAggregator.java
@@ -25,10 +25,13 @@ import org.apache.ignite.internal.util.typedef.internal.A;
*/
public final class WeightedPredictionsAggregator implements PredictionsAggregator {
/** Weights for predictions. */
- private final double[] weights;
+ private double[] weights;
/** Bias. */
- private final double bias;
+ private double bias;
+
+ public WeightedPredictionsAggregator() {
+ }
/**
* Constructs WeightedPredictionsAggregator instance.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModel.java
new file mode 100644
index 0000000..ac73398
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModel.java
@@ -0,0 +1,55 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.inference.json;
+
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import org.apache.ignite.ml.IgniteModel;
+
+/** Basic class for all non-trivial model data serialization. */
+public abstract class JSONModel {
+ /** Current version of the JSON model format. */
+ @JsonIgnore
+ public static final String JSON_MODEL_FORMAT_VERSION = "1";
+
+ /** Format version of this serialized model. */
+ public String formatVersion = JSON_MODEL_FORMAT_VERSION;
+
+ /** Timestamp in ms from System.currentTimeMillis() method. */
+ public Long timestamp;
+
+ /** Unique string identifier. */
+ public String uid;
+
+ /** String description of model class. */
+ public String modelClass;
+
+ /** Converts this JSON model representation to an IgniteModel object. */
+ public abstract IgniteModel convert();
+
+ /** */
+ public JSONModel(Long timestamp, String uid, String modelClass) {
+ this.timestamp = timestamp;
+ this.uid = uid;
+ this.modelClass = modelClass;
+ }
+
+ @JsonCreator
+ public JSONModel() {
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModelMixIn.java
similarity index 65%
copy from modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModelMixIn.java
index 42b0823..843b594 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/ClusterizationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONModelMixIn.java
@@ -15,15 +15,17 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.clustering.kmeans;
+package org.apache.ignite.ml.inference.json;
-import org.apache.ignite.ml.IgniteModel;
+import com.fasterxml.jackson.databind.annotation.JsonAppend;
-/** Base interface for all clusterization models. */
-public interface ClusterizationModel<P, V> extends IgniteModel<P, V> {
- /** Gets the clusters count. */
- public int getAmountOfClusters();
-
- /** Get cluster centers. */
- public P[] getCenters();
-}
+/** Just a mixin class to add a few configuration properties. */
+@JsonAppend(
+ attrs = {
+ @JsonAppend.Attr(value = "formatVersion"),
+ @JsonAppend.Attr(value = "timestamp"),
+ @JsonAppend.Attr(value = "uid"),
+ @JsonAppend.Attr(value = "modelClass")
+ }
+)
+public class JSONModelMixIn { }
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONWritable.java
similarity index 57%
copy from modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
copy to modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONWritable.java
index 1fee123..fcc3037 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JSONWritable.java
@@ -15,25 +15,23 @@
* limitations under the License.
*/
-package org.apache.ignite.ml.knn.ann;
+package org.apache.ignite.ml.inference.json;
-import java.util.TreeMap;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
-/**
- * The special class for fuzzy labels presenting the probability distribution
- * over the class labels.
- */
-public class ProbableLabel {
- /** Key is label, value is probability to be this class */
- TreeMap<Double, Double> clsLbls;
+public interface JSONWritable {
+ default void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
- /**
- * The key is class label,
- * the value is the probability to be an item of this class.
- *
- * @param clsLbls Class labels.
- */
- public ProbableLabel(TreeMap<Double, Double> clsLbls) {
- this.clsLbls = clsLbls;
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JacksonHelper.java b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JacksonHelper.java
new file mode 100644
index 0000000..654ade4
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/inference/json/JacksonHelper.java
@@ -0,0 +1,39 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.inference.json;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import com.fasterxml.jackson.databind.ObjectMapper;
+
+public class JacksonHelper {
+ public static void readAndValidateBasicJsonModelProperties(Path path, ObjectMapper mapper, String className) throws IOException {
+ Map jsonAsMap = mapper.readValue(new File(path.toAbsolutePath().toString()), LinkedHashMap.class);
+ String formatVersion = jsonAsMap.get("formatVersion").toString();
+ Long timestamp = (Long) jsonAsMap.get("timestamp");
+ String uid = jsonAsMap.get("uid").toString();
+ String modelClass = jsonAsMap.get("modelClass").toString();
+
+ if (!modelClass.equals(className)) {
+ throw new IllegalArgumentException("You are trying to load " + modelClass + " model to " + className);
+ }
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
index 2ad0c46..922630e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/NNClassificationModel.java
@@ -180,6 +180,17 @@ public abstract class NNClassificationModel implements IgniteModel<Vector, Doubl
return distanceMeasure;
}
+ /** */
+ public int getK() {
+ return k;
+ }
+
+ /** */
+ public boolean isWeighted() {
+ return weighted;
+ }
+
+
/** {@inheritDoc} */
@Override public int hashCode() {
int res = 1;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
index 2c820b7..6015900 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationModel.java
@@ -17,33 +17,47 @@
package org.apache.ignite.ml.knn.ann;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
import java.util.Arrays;
+import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeMap;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exporter;
+import org.apache.ignite.ml.environment.deploy.DeployableObject;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.knn.NNClassificationModel;
+import org.apache.ignite.ml.math.distances.DistanceMeasure;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.structures.LabeledVector;
import org.apache.ignite.ml.structures.LabeledVectorSet;
import org.apache.ignite.ml.util.ModelTrace;
-import org.jetbrains.annotations.NotNull;
/**
* ANN model to predict labels in multi-class classification task.
*/
-public final class ANNClassificationModel extends NNClassificationModel {
+public final class ANNClassificationModel extends NNClassificationModel implements JSONWritable, DeployableObject {
/** */
private static final long serialVersionUID = -127312378991350345L;
/** The labeled set of candidates. */
- private final LabeledVectorSet<LabeledVector> candidates;
+ private LabeledVectorSet<LabeledVector> candidates;
/** Centroid statistics. */
- private final ANNClassificationTrainer.CentroidStat centroindsStat;
+ private ANNClassificationTrainer.CentroidStat centroindsStat;
/**
* Build the model based on a candidates set.
@@ -57,6 +71,10 @@ public final class ANNClassificationModel extends NNClassificationModel {
}
/** */
+ private ANNClassificationModel() {
+ }
+
+ /** */
public LabeledVectorSet<LabeledVector> getCandidates() {
return candidates;
}
@@ -94,7 +112,7 @@ public final class ANNClassificationModel extends NNClassificationModel {
* @param distanceIdxPairs The distance map.
* @return K-nearest neighbors.
*/
- @NotNull private LabeledVector[] getKClosestVectors(
+ private LabeledVector[] getKClosestVectors(
TreeMap<Double, Set<Integer>> distanceIdxPairs) {
LabeledVector[] res;
@@ -129,7 +147,7 @@ public final class ANNClassificationModel extends NNClassificationModel {
* @return Key - distanceMeasure from given features before features with idx stored in value. Value is presented
* with Set because there can be a few vectors with the same distance.
*/
- @NotNull private TreeMap<Double, Set<Integer>> getDistances(Vector v) {
+ private TreeMap<Double, Set<Integer>> getDistances(Vector v) {
TreeMap<Double, Set<Integer>> distanceIdxPairs = new TreeMap<>();
for (int i = 0; i < candidates.rowSize(); i++) {
@@ -203,4 +221,104 @@ public final class ANNClassificationModel extends NNClassificationModel {
.addField("amount of candidates", String.valueOf(candidates.rowSize()))
.toString();
}
+
+ /** {@inheritDoc} */
+ @JsonIgnore
+ @Override public List<Object> getDependencies() {
+ return Collections.emptyList();
+ }
+
+ /** Loads ANNClassificationModel from JSON file. */
+ public static ANNClassificationModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+
+ ANNJSONExportModel exportModel;
+ try {
+ exportModel = mapper
+ .readValue(new File(path.toAbsolutePath().toString()), ANNJSONExportModel.class);
+
+ return exportModel.convert();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+
+ try {
+ ANNJSONExportModel exportModel = new ANNJSONExportModel(System.currentTimeMillis(), "ann_" + UUID.randomUUID().toString(), ANNClassificationModel.class.getSimpleName());
+ List<double[]> listOfCandidates = new ArrayList<>();
+ ProbableLabel[] labels = new ProbableLabel[candidates.rowSize()];
+ for (int i = 0; i < candidates.rowSize(); i++) {
+ labels[i] = (ProbableLabel) candidates.getRow(i).getLb();
+ listOfCandidates.add(candidates.features(i).asArray());
+ }
+
+ exportModel.candidateFeatures = listOfCandidates;
+ exportModel.distanceMeasure = distanceMeasure;
+ exportModel.k = k;
+ exportModel.weighted = weighted;
+ exportModel.candidateLabels = labels;
+ exportModel.centroindsStat = centroindsStat;
+
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, exportModel);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** */
+ public static class ANNJSONExportModel extends JSONModel {
+ /** Feature vectors of the candidates. */
+ public List<double[]> candidateFeatures;
+
+ public ProbableLabel[] candidateLabels;
+
+ /** Distance measure. */
+ public DistanceMeasure distanceMeasure;
+
+ /** Amount of nearest neighbors. */
+ public int k;
+
+ /** kNN strategy. */
+ public boolean weighted;
+
+ /** Centroid statistics. */
+ public ANNClassificationTrainer.CentroidStat centroindsStat;
+
+ /** */
+ public ANNJSONExportModel(Long timestamp, String uid, String modelClass) {
+ super(timestamp, uid, modelClass);
+ }
+
+ /** */
+ @JsonCreator
+ public ANNJSONExportModel() {
+ }
+
+ /** {@inheritDoc} */
+ @Override public ANNClassificationModel convert() {
+ if (candidateFeatures == null || candidateFeatures.isEmpty())
+ throw new IllegalArgumentException("Loaded list of candidates is empty. It should be not empty.");
+
+ double[] firstRow = candidateFeatures.get(0);
+ LabeledVectorSet<LabeledVector> candidatesForANN = new LabeledVectorSet<>(candidateFeatures.size(), firstRow.length);
+ LabeledVector<Double>[] data = new LabeledVector[candidateFeatures.size()];
+ for (int i = 0; i < candidateFeatures.size(); i++) {
+ data[i] = new LabeledVector(VectorUtils.of(candidateFeatures.get(i)), candidateLabels[i]);
+ }
+ candidatesForANN.setData(data);
+
+ ANNClassificationModel mdl = new ANNClassificationModel(candidatesForANN, centroindsStat);
+
+ mdl.withDistanceMeasure(distanceMeasure);
+ mdl.withK(k);
+ mdl.withWeighted(weighted);
+ return mdl;
+ }
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index 2219222..eec8713 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -24,6 +24,7 @@ import java.util.TreeMap;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentSkipListSet;
import java.util.stream.Collectors;
+import com.fasterxml.jackson.annotation.JsonIgnore;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.clustering.kmeans.KMeansModel;
import org.apache.ignite.ml.clustering.kmeans.KMeansTrainer;
@@ -139,7 +140,7 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
.withEpsilon(epsilon);
KMeansModel mdl = trainer.fit(datasetBuilder, vectorizer);
- return Arrays.asList(mdl.getCenters());
+ return Arrays.asList(mdl.centers());
}
/** */
@@ -324,16 +325,21 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
/** Service class used for statistics. */
public static class CentroidStat implements Serializable {
/** Serial version uid. */
+ @JsonIgnore
private static final long serialVersionUID = 7624883170532045144L;
+ /** */
+ public CentroidStat() {
+ }
+
/** Count of points closest to the center with a given index. */
- ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();
+ public ConcurrentHashMap<Integer, ConcurrentHashMap<Double, Integer>> centroidStat = new ConcurrentHashMap<>();
/** Count of points closest to the center with a given index. */
- ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
+ public ConcurrentHashMap<Integer, Integer> counts = new ConcurrentHashMap<>();
/** Set of unique labels. */
- ConcurrentSkipListSet<Double> clsLblsSet = new ConcurrentSkipListSet<>();
+ public ConcurrentSkipListSet<Double> clsLblsSet = new ConcurrentSkipListSet<>();
/** Merge current */
CentroidStat merge(CentroidStat other) {
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
index 1fee123..49f56b8 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ProbableLabel.java
@@ -25,7 +25,10 @@ import java.util.TreeMap;
*/
public class ProbableLabel {
/** Key is label, value is probability to be this class */
- TreeMap<Double, Double> clsLbls;
+ public TreeMap<Double, Double> clsLbls;
+
+ public ProbableLabel() {
+ }
/**
* The key is class label,
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java
index 0b43159..2c32ee6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/BrayCurtisDistance.java
@@ -51,4 +51,8 @@ public class BrayCurtisDistance implements DistanceMeasure {
@Override public int hashCode() {
return getClass().hashCode();
}
+
+ @Override public String toString() {
+ return "BrayCurtisDistance{}";
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
index 392e7b0..4176d97 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/DistanceMeasure.java
@@ -20,6 +20,8 @@ import java.io.Externalizable;
import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
+import com.fasterxml.jackson.annotation.JsonSubTypes;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.ignite.ml.math.exceptions.math.CardinalityException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
@@ -28,6 +30,21 @@ import org.apache.ignite.ml.math.primitives.vector.impl.DenseVector;
* This class is based on the corresponding class from Apache Common Math lib. Interface for distance measures of
* n-dimensional vectors.
*/
+@JsonTypeInfo( use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
+@JsonSubTypes(
+ {
+ @JsonSubTypes.Type(value = BrayCurtisDistance.class, name = "BrayCurtisDistance"),
+ @JsonSubTypes.Type(value = CanberraDistance.class, name = "CanberraDistance"),
+ @JsonSubTypes.Type(value = ChebyshevDistance.class, name = "ChebyshevDistance"),
+ @JsonSubTypes.Type(value = CosineSimilarity.class, name = "CosineSimilarity"),
+ @JsonSubTypes.Type(value = EuclideanDistance.class, name = "EuclideanDistance"),
+ @JsonSubTypes.Type(value = HammingDistance.class, name = "HammingDistance"),
+ @JsonSubTypes.Type(value = JaccardIndex.class, name = "JaccardIndex"),
+ @JsonSubTypes.Type(value = JensenShannonDistance.class, name = "JensenShannonDistance"),
+ @JsonSubTypes.Type(value = ManhattanDistance.class, name = "ManhattanDistance"),
+ @JsonSubTypes.Type(value = MinkowskiDistance.class, name = "MinkowskiDistance"),
+ @JsonSubTypes.Type(value = WeightedMinkowskiDistance.class, name = "WeightedMinkowskiDistance"),
+ })
public interface DistanceMeasure extends Externalizable {
/**
* Compute the distance between two n-dimensional vectors.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/MinkowskiDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/MinkowskiDistance.java
index b382112..20c1c02 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/MinkowskiDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/MinkowskiDistance.java
@@ -17,6 +17,8 @@
package org.apache.ignite.ml.math.distances;
import java.util.Objects;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.ignite.ml.math.exceptions.math.CardinalityException;
import org.apache.ignite.ml.math.functions.IgniteDoubleFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
@@ -35,10 +37,16 @@ public class MinkowskiDistance implements DistanceMeasure {
private final double p;
/** @param p norm */
- public MinkowskiDistance(double p) {
+ @JsonCreator
+ public MinkowskiDistance(@JsonProperty("p")double p) {
this.p = p;
}
+ /** Returns the order {@code p} of the norm. */
+ public double getP() {
+ return p;
+ }
+
/** {@inheritDoc} */
@Override public double compute(Vector a, Vector b) throws CardinalityException {
assert a.size() == b.size();
@@ -60,4 +68,10 @@ public class MinkowskiDistance implements DistanceMeasure {
@Override public int hashCode() {
return Objects.hash(p);
}
+
+ @Override public String toString() {
+ return "MinkowskiDistance{" +
+ "p=" + p +
+ '}';
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java
index 662bf90..61e2125 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistance.java
@@ -16,8 +16,13 @@
*/
package org.apache.ignite.ml.math.distances;
+import java.util.Arrays;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.annotation.JsonProperty;
import org.apache.ignite.ml.math.exceptions.math.CardinalityException;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.math.util.MatrixUtil;
/**
@@ -29,13 +34,18 @@ public class WeightedMinkowskiDistance implements DistanceMeasure {
*/
private static final long serialVersionUID = 1771556549784040096L;
- private final int p;
+ private int p = 1;
- private final Vector weight;
+ private final double[] weights;
- public WeightedMinkowskiDistance(int p, Vector weight) {
+ @JsonIgnore
+ private final Vector internalWeights;
+
+ @JsonCreator
+ public WeightedMinkowskiDistance(@JsonProperty("p")int p, @JsonProperty("weights")double[] weights) {
this.p = p;
- this.weight = weight.copy().map(x -> Math.pow(Math.abs(x), p));
+ this.weights = weights.clone();
+ internalWeights = VectorUtils.of(weights).copy().map(x -> Math.pow(Math.abs(x), p));
}
/**
@@ -47,12 +57,20 @@ public class WeightedMinkowskiDistance implements DistanceMeasure {
return Math.pow(
MatrixUtil.localCopyOf(a).minus(b)
.map(x -> Math.pow(Math.abs(x), p))
- .times(weight)
+ .times(internalWeights)
.sum(),
1 / (double) p
);
}
+ /** Returns the order {@code p} of the norm. */
+ public int getP() {
+ return p;
+ }
+
+ /** Returns weights. */
+ public double[] getWeights() { return weights.clone(); }
+
/**
* {@inheritDoc}
*/
@@ -70,4 +88,11 @@ public class WeightedMinkowskiDistance implements DistanceMeasure {
@Override public int hashCode() {
return getClass().hashCode();
}
+
+ @Override public String toString() {
+ return "WeightedMinkowskiDistance{" +
+ "p=" + p +
+ ", weights=" + Arrays.toString(weights) +
+ '}';
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
index abd39df..4a915fa 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/stat/DistributionMixture.java
@@ -32,13 +32,13 @@ import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
*/
public abstract class DistributionMixture<C extends Distribution> implements Distribution {
/** Component probabilities. */
- private final Vector componentProbs;
+ private Vector componentProbs;
/** Distributions. */
- private final List<C> distributions;
+ private List<C> distributions;
/** Dimension. */
- private final int dimension;
+ private int dimension;
/**
* Creates an instance of DistributionMixture.
@@ -61,6 +61,9 @@ public abstract class DistributionMixture<C extends Distribution> implements Dis
this.dimension = dimension;
}
+ public DistributionMixture() {
+ }
+
/** {@inheritDoc} */
@Override public double prob(Vector x) {
return likelihood(x).sum();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
index 6cdc637..a9fc2d0 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/compound/CompoundNaiveBayesModel.java
@@ -17,14 +17,27 @@
package org.apache.ignite.ml.naivebayes.compound;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.naivebayes.discrete.DiscreteNaiveBayesModel;
@@ -34,7 +47,8 @@ import org.apache.ignite.ml.naivebayes.gaussian.GaussianNaiveBayesModel;
* A compound Naive Bayes model which uses a composition of{@code GaussianNaiveBayesModel} and {@code
* DiscreteNaiveBayesModel}.
*/
-public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>, DeployableObject {
+public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exportable<CompoundNaiveBayesModel>,
+ JSONWritable, DeployableObject {
/** Serial version uid. */
private static final long serialVersionUID = -5045925321135798960L;
@@ -56,6 +70,10 @@ public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exp
/** Feature ids which should be skipped in Discrete model. */
private Collection<Integer> discreteFeatureIdsToSkip = Collections.emptyList();
+ /** */
+ public CompoundNaiveBayesModel() {
+ }
+
/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<CompoundNaiveBayesModel, P> exporter, P path) {
exporter.save(this, path);
@@ -91,6 +109,22 @@ public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exp
return discreteModel;
}
+ public double[] getPriorProbabilities() {
+ return priorProbabilities;
+ }
+
+ public double[] getLabels() {
+ return labels;
+ }
+
+ public Collection<Integer> getGaussianFeatureIdsToSkip() {
+ return gaussianFeatureIdsToSkip;
+ }
+
+ public Collection<Integer> getDiscreteFeatureIdsToSkip() {
+ return discreteFeatureIdsToSkip;
+ }
+
/** Sets prior probabilities. */
public CompoundNaiveBayesModel withPriorProbabilities(double[] priorProbabilities) {
this.priorProbabilities = priorProbabilities.clone();
@@ -155,7 +189,44 @@ public class CompoundNaiveBayesModel implements IgniteModel<Vector, Double>, Exp
}
/** {@inheritDoc} */
+ @JsonIgnore
@Override public List<Object> getDependencies() {
return Arrays.asList(discreteModel, gaussianModel);
}
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(CompoundNaiveBayesModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(CompoundNaiveBayesModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", CompoundNaiveBayesModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads CompoundNaiveBayesModel from JSON file; returns {@code null} if an IOException occurs. */
+ public static CompoundNaiveBayesModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ CompoundNaiveBayesModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, CompoundNaiveBayesModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), CompoundNaiveBayesModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
index b7eb5d3..3d5edce4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesModel.java
@@ -17,10 +17,23 @@
package org.apache.ignite.ml.naivebayes.discrete;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.BayesModel;
@@ -29,7 +42,8 @@ import org.apache.ignite.ml.naivebayes.BayesModel;
* {@code p(C_k,y) =x_1*p_k1^x *...*x_i*p_ki^x_i}. Where {@code x_i} is a discrete feature, {@code p_ki} is a prior
* probability probability of class {@code p(x|C_k)}. Returns the number of the most possible class.
*/
-public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesModel, Vector, Double>, DeployableObject {
+public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesModel, Vector, Double>,
+ JSONWritable, DeployableObject {
/** Serial version uid. */
private static final long serialVersionUID = -127386523291350345L;
@@ -37,23 +51,23 @@ public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesMod
* Probabilities of features for all classes for each label. {@code labels[c][f][b]} contains a probability for
* class {@code c} for feature {@code f} for bucket {@code b}.
*/
- private final double[][][] probabilities;
+ private double[][][] probabilities;
/** Prior probabilities of each class */
- private final double[] clsProbabilities;
+ private double[] clsProbabilities;
/** Labels. */
- private final double[] labels;
+ private double[] labels;
/**
* The bucket thresholds to convert a features to discrete values. {@code bucketThresholds[f][b]} contains the right
* border for feature {@code f} for bucket {@code b}. Everything which is above the last thresdold goes to the next
* bucket.
*/
- private final double[][] bucketThresholds;
+ private double[][] bucketThresholds;
/** Amount values in each buckek for each feature per label. */
- private final DiscreteNaiveBayesSumsHolder sumsHolder;
+ private DiscreteNaiveBayesSumsHolder sumsHolder;
/**
* @param probabilities Probabilities of features for classes.
@@ -64,13 +78,17 @@ public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesMod
*/
public DiscreteNaiveBayesModel(double[][][] probabilities, double[] clsProbabilities, double[] labels,
double[][] bucketThresholds, DiscreteNaiveBayesSumsHolder sumsHolder) {
- this.probabilities = probabilities;
- this.clsProbabilities = clsProbabilities;
- this.labels = labels;
- this.bucketThresholds = bucketThresholds;
+ this.probabilities = probabilities.clone();
+ this.clsProbabilities = clsProbabilities.clone();
+ this.labels = labels.clone();
+ this.bucketThresholds = bucketThresholds.clone();
this.sumsHolder = sumsHolder;
}
+ /** */
+ public DiscreteNaiveBayesModel() {
+ }
+
/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<DiscreteNaiveBayesModel, P> exporter, P path) {
exporter.save(this, path);
@@ -111,22 +129,22 @@ public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesMod
/** A getter for probabilities.*/
public double[][][] getProbabilities() {
- return probabilities;
+ return probabilities.clone();
}
/** A getter for clsProbabilities.*/
public double[] getClsProbabilities() {
- return clsProbabilities;
+ return clsProbabilities.clone();
}
/** A getter for bucketThresholds.*/
public double[][] getBucketThresholds() {
- return bucketThresholds;
+ return bucketThresholds.clone();
}
/** A getter for labels.*/
public double[] getLabels() {
- return labels;
+ return labels.clone();
}
/** A getter for sumsHolder.*/
@@ -145,7 +163,44 @@ public class DiscreteNaiveBayesModel implements BayesModel<DiscreteNaiveBayesMod
}
/** {@inheritDoc} */
+ @JsonIgnore
@Override public List<Object> getDependencies() {
return Collections.emptyList();
}
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(DiscreteNaiveBayesModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(DiscreteNaiveBayesModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", DiscreteNaiveBayesModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads DiscreteNaiveBayesModel from JSON file; returns {@code null} if an IOException occurs. */
+ public static DiscreteNaiveBayesModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ DiscreteNaiveBayesModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, DiscreteNaiveBayesModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), DiscreteNaiveBayesModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
index 50b335e..060d188 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/discrete/DiscreteNaiveBayesSumsHolder.java
@@ -32,6 +32,17 @@ public class DiscreteNaiveBayesSumsHolder implements AutoCloseable, Serializable
/** Rows count for each label */
Map<Double, Integer> featureCountersPerLbl = new HashMap<>();
+ public DiscreteNaiveBayesSumsHolder() {
+ }
+
+ public Map<Double, long[][]> getValuesInBucketPerLbl() {
+ return valuesInBucketPerLbl;
+ }
+
+ public Map<Double, Integer> getFeatureCountersPerLbl() {
+ return featureCountersPerLbl;
+ }
+
/** Merge to current */
DiscreteNaiveBayesSumsHolder merge(DiscreteNaiveBayesSumsHolder other) {
valuesInBucketPerLbl = MapUtil.mergeMaps(valuesInBucketPerLbl, other.valuesInBucketPerLbl, this::sum, HashMap::new);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
index d0a6470..0627ce5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesModel.java
@@ -17,10 +17,23 @@
package org.apache.ignite.ml.naivebayes.gaussian;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
import java.util.Collections;
import java.util.List;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonIgnore;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.environment.deploy.DeployableObject;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.naivebayes.BayesModel;
@@ -28,24 +41,25 @@ import org.apache.ignite.ml.naivebayes.BayesModel;
* Simple naive Bayes model which predicts result value {@code y} belongs to a class {@code C_k, k in [0..K]} as {@code
* p(C_k,y) = p(C_k)*p(y_1,C_k) *...*p(y_n,C_k) / p(y)}. Return the number of the most possible class.
*/
-public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesModel, Vector, Double>, DeployableObject {
+public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesModel, Vector, Double>,
+ JSONWritable, DeployableObject {
/** Serial version uid. */
private static final long serialVersionUID = -127386523291350345L;
/** Means of features for all classes. kth row contains means for labels[k] class. */
- private final double[][] means;
+ private double[][] means;
/** Variances of features for all classes. kth row contains variances for labels[k] class */
- private final double[][] variances;
+ private double[][] variances;
/** Prior probabilities of each class */
- private final double[] classProbabilities;
+ private double[] classProbabilities;
/** Labels. */
- private final double[] labels;
+ private double[] labels;
/** Feature sum, squared sum and count per label. */
- private final GaussianNaiveBayesSumsHolder sumsHolder;
+ private GaussianNaiveBayesSumsHolder sumsHolder;
/**
* @param means Means of features for all classes.
@@ -56,13 +70,17 @@ public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesMod
*/
public GaussianNaiveBayesModel(double[][] means, double[][] variances,
double[] classProbabilities, double[] labels, GaussianNaiveBayesSumsHolder sumsHolder) {
- this.means = means;
- this.variances = variances;
- this.classProbabilities = classProbabilities;
- this.labels = labels;
+ this.means = means.clone();
+ this.variances = variances.clone();
+ this.classProbabilities = classProbabilities.clone();
+ this.labels = labels.clone();
this.sumsHolder = sumsHolder;
}
+ /** */
+ public GaussianNaiveBayesModel() {
+ }
+
/** {@inheritDoc} */
@Override public <P> void saveModel(Exporter<GaussianNaiveBayesModel, P> exporter, P path) {
exporter.save(this, path);
@@ -127,7 +145,44 @@ public class GaussianNaiveBayesModel implements BayesModel<GaussianNaiveBayesMod
}
/** {@inheritDoc} */
+ @JsonIgnore
@Override public List<Object> getDependencies() {
return Collections.emptyList();
}
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(GaussianNaiveBayesModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(GaussianNaiveBayesModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", GaussianNaiveBayesModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads GaussianNaiveBayesModel from JSON file; returns {@code null} if an IOException occurs. */
+ public static GaussianNaiveBayesModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ GaussianNaiveBayesModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, GaussianNaiveBayesModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), GaussianNaiveBayesModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
index 7b95ff8..1d85832 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/naivebayes/gaussian/GaussianNaiveBayesSumsHolder.java
@@ -35,6 +35,21 @@ class GaussianNaiveBayesSumsHolder implements Serializable, AutoCloseable {
/** Rows count for each label */
Map<Double, Integer> featureCountersPerLbl = new HashMap<>();
+ public GaussianNaiveBayesSumsHolder() {
+ }
+
+ public Map<Double, double[]> getFeatureSumsPerLbl() {
+ return featureSumsPerLbl;
+ }
+
+ public Map<Double, double[]> getFeatureSquaredSumsPerLbl() {
+ return featureSquaredSumsPerLbl;
+ }
+
+ public Map<Double, Integer> getFeatureCountersPerLbl() {
+ return featureCountersPerLbl;
+ }
+
/** Merge to current */
GaussianNaiveBayesSumsHolder merge(GaussianNaiveBayesSumsHolder other) {
featureSumsPerLbl = MapUtil.mergeMaps(featureSumsPerLbl, other.featureSumsPerLbl, this::sum, HashMap::new);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
index 9ecc257..d28a2a9 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainer.java
@@ -72,10 +72,10 @@ public class LinearRegressionLSQRTrainer extends SingleLabelDatasetTrainer<Linea
double[] x0 = null;
if (mdl != null) {
- int x0Size = mdl.getWeights().size() + 1;
- Vector weights = mdl.getWeights().like(x0Size);
- mdl.getWeights().nonZeroes().forEach(ith -> weights.set(ith.index(), ith.get()));
- weights.set(weights.size() - 1, mdl.getIntercept());
+ int x0Size = mdl.weights().size() + 1;
+ Vector weights = mdl.weights().like(x0Size);
+ mdl.weights().nonZeroes().forEach(ith -> weights.set(ith.index(), ith.get()));
+ weights.set(weights.size() - 1, mdl.intercept());
x0 = weights.asArray();
}
res = lsqr.solve(0, 1e-12, 1e-12, 1e8, -1, false, x0);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
index 150b6d7..4cb5340 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionModel.java
@@ -17,25 +17,35 @@
package org.apache.ignite.ml.regressions.linear;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
import java.util.Objects;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
/**
* Simple linear regression model which predicts result value Y as a linear combination of input variables:
* Y = weights * X + intercept.
*/
-public final class LinearRegressionModel implements IgniteModel<Vector, Double>, Exportable<LinearRegressionModel> {
+public final class LinearRegressionModel implements IgniteModel<Vector, Double>, Exportable<LinearRegressionModel>,
+ JSONWritable {
/** */
private static final long serialVersionUID = -105984600091550226L;
/** Multiplier of the objects's vector required to make prediction. */
- private final Vector weights;
+ private Vector weights;
/** Intercept of the linear regression model */
- private final double intercept;
+ private double intercept;
/** */
public LinearRegressionModel(Vector weights, double intercept) {
@@ -44,15 +54,41 @@ public final class LinearRegressionModel implements IgniteModel<Vector, Double>,
}
/** */
- public Vector getWeights() {
+ private LinearRegressionModel() {
+ }
+
+ /** */
+ public Vector weights() {
return weights;
}
/** */
- public double getIntercept() {
+ public double intercept() {
return intercept;
}
+ /**
+ * Set up the weights.
+ *
+ * @param weights The parameter value.
+ * @return Model with new weights parameter value.
+ */
+ public LinearRegressionModel withWeights(Vector weights) {
+ this.weights = weights;
+ return this;
+ }
+
+ /**
+ * Set up the intercept.
+ *
+ * @param intercept The parameter value.
+ * @return Model with new intercept parameter value.
+ */
+ public LinearRegressionModel withIntercept(double intercept) {
+ this.intercept = intercept;
+ return this;
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector input) {
return input.dot(weights) + intercept;
@@ -108,4 +144,72 @@ public final class LinearRegressionModel implements IgniteModel<Vector, Double>,
@Override public String toString(boolean pretty) {
return toString();
}
+
+ /** Loads LinearRegressionModel from JSON file; returns {@code null} if an IOException occurs. */
+ public static LinearRegressionModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ LinearRegressionModelJSONExportModel linearRegressionJSONExportModel;
+ try {
+ linearRegressionJSONExportModel = mapper
+ .readValue(new File(path.toAbsolutePath().toString()), LinearRegressionModelJSONExportModel.class);
+
+ return linearRegressionJSONExportModel.convert();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ try {
+ LinearRegressionModelJSONExportModel exportModel = new LinearRegressionModelJSONExportModel(
+ System.currentTimeMillis(),
+ "linreg_" + UUID.randomUUID().toString(),
+ LinearRegressionModel.class.getSimpleName()
+ );
+ exportModel.intercept = intercept;
+ exportModel.weights = weights.asArray();
+
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, exportModel);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** */
+ public static class LinearRegressionModelJSONExportModel extends JSONModel {
+ /**
+ * Multiplier of the object's vector required to make prediction.
+ */
+ public double[] weights;
+
+ /**
+ * Intercept of the linear regression model.
+ */
+ public double intercept;
+
+ /** */
+ public LinearRegressionModelJSONExportModel(Long timestamp, String uid, String modelClass) {
+ super(timestamp, uid, modelClass);
+ }
+
+ /** */
+ @JsonCreator
+ public LinearRegressionModelJSONExportModel() {
+ }
+
+ /** {@inheritDoc} */
+ @Override public LinearRegressionModel convert() {
+ LinearRegressionModel linRegMdl = new LinearRegressionModel();
+ linRegMdl.withWeights(VectorUtils.of(weights));
+ linRegMdl.withIntercept(intercept);
+
+ return linRegMdl;
+ }
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
index da813fc..d982671 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainer.java
@@ -148,8 +148,8 @@ public class LinearRegressionSGDTrainer<P extends Serializable> extends SingleLa
* @return State of MLP from last learning.
*/
@NotNull private MultilayerPerceptron restoreMLPState(LinearRegressionModel mdl) {
- Vector weights = mdl.getWeights();
- double intercept = mdl.getIntercept();
+ Vector weights = mdl.weights();
+ double intercept = mdl.intercept();
MLPArchitecture architecture1 = new MLPArchitecture(weights.size());
architecture1 = architecture1.withAddedLayer(1, true, Activators.LINEAR);
MLPArchitecture architecture = architecture1;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java
index 52ae0dc..9282dfe 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/regressions/logistic/LogisticRegressionModel.java
@@ -17,16 +17,27 @@
package org.apache.ignite.ml.regressions.logistic;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Arrays;
import java.util.Objects;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
/**
* Logistic regression (logit model) is a generalized linear model used for binomial regression.
*/
-public final class LogisticRegressionModel implements IgniteModel<Vector, Double>, Exportable<LogisticRegressionModel> {
+public final class LogisticRegressionModel implements IgniteModel<Vector, Double>, Exportable<LogisticRegressionModel>,
+ JSONWritable {
/** */
private static final long serialVersionUID = -133984600091550776L;
@@ -43,6 +54,10 @@ public final class LogisticRegressionModel implements IgniteModel<Vector, Double
private double threshold = 0.5;
/** */
+ private LogisticRegressionModel() {
+ }
+
+ /** */
public LogisticRegressionModel(Vector weights, double intercept) {
this.weights = weights;
this.intercept = intercept;
@@ -201,4 +216,99 @@ public final class LogisticRegressionModel implements IgniteModel<Vector, Double
@Override public String toString(boolean pretty) {
return toString();
}
+
+ /** Loads LogisticRegressionModel from JSON file. */
+ public static LogisticRegressionModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ LogisticRegressionJSONExportModel logisticRegressionJSONExportModel;
+ try {
+ logisticRegressionJSONExportModel = mapper
+ .readValue(new File(path.toAbsolutePath().toString()), LogisticRegressionJSONExportModel.class);
+
+ return logisticRegressionJSONExportModel.convert();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ try {
+ LogisticRegressionJSONExportModel exportModel = new LogisticRegressionJSONExportModel(
+ System.currentTimeMillis(),
+ "logReg_" + UUID.randomUUID().toString(),
+ LogisticRegressionModel.class.getSimpleName());
+ exportModel.intercept = intercept;
+ exportModel.isKeepingRawLabels = isKeepingRawLabels;
+ exportModel.threshold = threshold;
+ exportModel.weights = weights.asArray();
+
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, exportModel);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ }
+
+ /** */
+ public static class LogisticRegressionJSONExportModel extends JSONModel {
+ /**
+ * Multiplier of the object's vector required to make prediction.
+ */
+ public double[] weights;
+
+ /**
+ * Intercept of the linear regression model.
+ */
+ public double intercept;
+
+ /**
+ * Output label format. 0 and 1 for false value and raw sigmoid regression value otherwise.
+ */
+ public boolean isKeepingRawLabels;
+
+ /**
+ * Threshold to assign '1' label to the observation if raw value more than this threshold.
+ */
+ public double threshold = 0.5;
+
+ /** */
+ public LogisticRegressionJSONExportModel(Long timestamp, String uid, String modelClass) {
+ super(timestamp, uid, modelClass);
+ }
+
+ /** */
+ @JsonCreator
+ public LogisticRegressionJSONExportModel() {
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ return "LogisticRegressionJSONExportModel{" +
+ "weights=" + Arrays.toString(weights) +
+ ", intercept=" + intercept +
+ ", isKeepingRawLabels=" + isKeepingRawLabels +
+ ", threshold=" + threshold +
+ '}';
+ }
+
+ /** {@inheritDoc} */
+ @Override public LogisticRegressionModel convert() {
+ LogisticRegressionModel logRegMdl = new LogisticRegressionModel();
+ logRegMdl.withWeights(VectorUtils.of(weights));
+ logRegMdl.withIntercept(intercept);
+ logRegMdl.withRawLabels(isKeepingRawLabels);
+ logRegMdl.withThreshold(threshold);
+
+ return logRegMdl;
+ }
+ }
}
+
+
+
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java
index a0e83c4..d511399 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/DatasetRow.java
@@ -118,4 +118,8 @@ public class DatasetRow<V extends Vector> implements Externalizable {
public void set(int idx, double val) {
vector.set(idx, val);
}
+
+ public V getVector() {
+ return vector;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java
index c00236f..9b3ddde 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/structures/LabeledVector.java
@@ -98,4 +98,8 @@ public class LabeledVector<L> extends DatasetRow<Vector> {
vector = (Vector)in.readObject();
lb = (L)in.readObject();
}
+
+ public L getLb() {
+ return lb;
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
index a624504..0cadf53 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationModel.java
@@ -17,16 +17,27 @@
package org.apache.ignite.ml.svm;
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.Arrays;
import java.util.Objects;
+import java.util.UUID;
+import com.fasterxml.jackson.annotation.JsonCreator;
+import com.fasterxml.jackson.databind.ObjectMapper;
import org.apache.ignite.ml.Exportable;
import org.apache.ignite.ml.Exporter;
import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONWritable;
import org.apache.ignite.ml.math.primitives.vector.Vector;
+import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
/**
* Base class for SVM linear classification model.
*/
-public final class SVMLinearClassificationModel implements IgniteModel<Vector, Double>, Exportable<SVMLinearClassificationModel> {
+public final class SVMLinearClassificationModel implements IgniteModel<Vector, Double>, Exportable<SVMLinearClassificationModel>,
+ JSONWritable {
/** */
private static final long serialVersionUID = -996984622291440226L;
@@ -42,6 +53,9 @@ public final class SVMLinearClassificationModel implements IgniteModel<Vector, D
/** Intercept of the linear regression model. */
private double intercept;
+ public SVMLinearClassificationModel() {
+ }
+
/** */
public SVMLinearClassificationModel(Vector weights, double intercept) {
this.weights = weights;
@@ -190,4 +204,100 @@ public final class SVMLinearClassificationModel implements IgniteModel<Vector, D
@Override public String toString(boolean pretty) {
return toString();
}
+
+ /** Loads SVMLinearClassificationModel from JSON file. */
+ public static SVMLinearClassificationModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ SVMLinearClassificationJSONExportModel exportModel;
+ try {
+ exportModel = mapper
+ .readValue(new File(path.toAbsolutePath().toString()), SVMLinearClassificationJSONExportModel.class);
+
+ return exportModel.convert();
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+
+ return null;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+
+ try {
+ SVMLinearClassificationJSONExportModel exportModel = new SVMLinearClassificationJSONExportModel(
+ System.currentTimeMillis(),
+ "svm_" + UUID.randomUUID().toString(),
+ SVMLinearClassificationModel.class.getSimpleName());
+ exportModel.intercept = intercept;
+ exportModel.isKeepingRawLabels = isKeepingRawLabels;
+ exportModel.threshold = threshold;
+ exportModel.weights = weights.asArray();
+
+ File file = new File(path.toAbsolutePath().toString());
+ mapper.writeValue(file, exportModel);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** */
+ public static class SVMLinearClassificationJSONExportModel extends JSONModel {
+ /**
+ * Multiplier of the object's vector required to make prediction.
+ */
+ public double[] weights;
+
+ /**
+ * Intercept of the linear regression model.
+ */
+ public double intercept;
+
+ /**
+ * Output label format. 0 and 1 for false value and raw sigmoid regression value otherwise.
+ */
+ public boolean isKeepingRawLabels;
+
+ /**
+ * Threshold to assign '1' label to the observation if raw value more than this threshold.
+ */
+ public double threshold = 0.5;
+
+ /**
+ *
+ */
+ public SVMLinearClassificationJSONExportModel(Long timestamp, String uid, String modelClass) {
+ super(timestamp, uid, modelClass);
+ }
+
+ /**
+ *
+ */
+ @JsonCreator
+ public SVMLinearClassificationJSONExportModel() {
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ return "SVMLinearClassificationJSONExportModel{" +
+ "weights=" + Arrays.toString(weights) +
+ ", intercept=" + intercept +
+ ", isKeepingRawLabels=" + isKeepingRawLabels +
+ ", threshold=" + threshold +
+ '}';
+ }
+
+ /** {@inheritDoc} */
+ @Override public SVMLinearClassificationModel convert() {
+ SVMLinearClassificationModel mdl = new SVMLinearClassificationModel();
+ mdl.withWeights(VectorUtils.of(weights));
+ mdl.withIntercept(intercept);
+ mdl.withRawLabels(isKeepingRawLabels);
+ mdl.withThreshold(threshold);
+
+ return mdl;
+ }
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
index 5c53fff..266403f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/svm/SVMLinearClassificationTrainer.java
@@ -120,7 +120,7 @@ public class SVMLinearClassificationTrainer extends SingleLabelDatasetTrainer<SV
} catch (Exception e) {
throw new RuntimeException(e);
}
- return new SVMLinearClassificationModel(weights.viewPart(1, weights.size() - 1), weights.get(0));
+ return new SVMLinearClassificationModel(weights.copyOfRange(1, weights.size()), weights.get(0));
}
/** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
index cbd5089..a8daf7c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainer.java
@@ -35,7 +35,7 @@ import org.apache.ignite.ml.tree.leaf.MostCommonDecisionTreeLeafBuilder;
* Decision tree classifier based on distributed decision tree trainer that allows to fit trees using row-partitioned
* dataset.
*/
-public class DecisionTreeClassificationTrainer extends DecisionTree<GiniImpurityMeasure> {
+public class DecisionTreeClassificationTrainer extends DecisionTreeTrainer<GiniImpurityMeasure> {
/**
* Constructs a new decision tree classifier with default impurity function compressor.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
index e26eb6f..b148b31 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeConditionalNode.java
@@ -22,20 +22,20 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Decision tree conditional (non-leaf) node.
*/
-public final class DecisionTreeConditionalNode implements DecisionTreeNode {
+public final class DecisionTreeConditionalNode extends DecisionTreeNode {
/** */
private static final long serialVersionUID = 981630737007982172L;
/** Column of the value to be tested. */
- private final int col;
+ private int col;
/** Threshold. */
- private final double threshold;
+ private double threshold;
- /** Node that will be used in case tested value is greater then threshold. */
+ /** Right node that will be used in case tested value is greater than threshold. */
private DecisionTreeNode thenNode;
- /** Node that will be used in case tested value is not greater then threshold. */
+ /** Left node that will be used in case tested value is not greater than threshold. */
private DecisionTreeNode elseNode;
/** Node that will be used in case tested value is not presented. */
@@ -59,6 +59,10 @@ public final class DecisionTreeConditionalNode implements DecisionTreeNode {
this.missingNode = missingNode;
}
+ /** For jackson serialization needs. */
+ public DecisionTreeConditionalNode() {
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector features) {
double val = features.get(col);
@@ -120,6 +124,6 @@ public final class DecisionTreeConditionalNode implements DecisionTreeNode {
/** {@inheritDoc} */
@Override public String toString(boolean pretty) {
- return DecisionTree.printTree(this, pretty);
+ return DecisionTreeTrainer.printTree(this, pretty);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
index bde8e95..c3fe444 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeLeafNode.java
@@ -22,12 +22,12 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Decision tree leaf node which contains value.
*/
-public final class DecisionTreeLeafNode implements DecisionTreeNode {
+public final class DecisionTreeLeafNode extends DecisionTreeNode {
/** */
private static final long serialVersionUID = -472145568088482206L;
/** Value of the node. */
- private final double val;
+ private double val;
/**
* Constructs a new decision tree leaf node.
@@ -38,6 +38,10 @@ public final class DecisionTreeLeafNode implements DecisionTreeNode {
this.val = val;
}
+ /** For jackson serialization needs. */
+ public DecisionTreeLeafNode() {
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector doubles) {
return val;
@@ -55,6 +59,6 @@ public final class DecisionTreeLeafNode implements DecisionTreeNode {
/** {@inheritDoc} */
@Override public String toString(boolean pretty) {
- return DecisionTree.printTree(this, pretty);
+ return DecisionTreeTrainer.printTree(this, pretty);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeModel.java
new file mode 100644
index 0000000..22e361d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeModel.java
@@ -0,0 +1,111 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.UUID;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import org.apache.ignite.ml.IgniteModel;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
+import org.apache.ignite.ml.math.primitives.vector.Vector;
+
+/**
+ * Base class for decision tree models.
+ */
+public class DecisionTreeModel implements IgniteModel<Vector, Double>, JSONWritable {
+ /** Root node. */
+ private DecisionTreeNode rootNode;
+
+ /**
+ * Creates the model.
+ *
+ * @param rootNode Root node of the tree.
+ */
+ public DecisionTreeModel(DecisionTreeNode rootNode) {
+ this.rootNode = rootNode;
+ }
+
+ /** */
+ private DecisionTreeModel() {
+
+ }
+
+ /** Returns the root node. */
+ public DecisionTreeNode getRootNode() {
+ return rootNode;
+ }
+
+ /** {@inheritDoc} */
+ @Override public Double predict(Vector features) {
+ return rootNode.predict(features);
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ return toString(false);
+ }
+
+ /** {@inheritDoc} */
+ @Override public String toString(boolean pretty) {
+ return DecisionTreeTrainer.printTree(rootNode, pretty);
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(DecisionTreeModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(DecisionTreeModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", DecisionTreeModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads DecisionTreeModel from JSON file. */
+ public static DecisionTreeModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ DecisionTreeModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, DecisionTreeModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), DecisionTreeModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
index 80036ba..8d705e4 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeNode.java
@@ -17,11 +17,24 @@
package org.apache.ignite.ml.tree;
+import com.fasterxml.jackson.annotation.JsonSubTypes;
+import com.fasterxml.jackson.annotation.JsonTypeInfo;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Base interface for decision tree nodes.
*/
-public interface DecisionTreeNode extends IgniteModel<Vector, Double> {
+@JsonTypeInfo( use = JsonTypeInfo.Id.NAME, include = JsonTypeInfo.As.PROPERTY, property = "type")
+@JsonSubTypes(
+ {
+ @JsonSubTypes.Type(value = DecisionTreeLeafNode.class, name = "leaf"),
+ @JsonSubTypes.Type(value = DecisionTreeConditionalNode.class, name = "conditional"),
+ })
+public abstract class DecisionTreeNode implements IgniteModel<Vector, Double> {
+ /**
+ * Empty constructor for serialization needs.
+ */
+ protected DecisionTreeNode() {
+ }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
index 2b259f2..7ae86fc 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainer.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.tree.leaf.MeanDecisionTreeLeafBuilder;
* Decision tree regressor based on distributed decision tree trainer that allows to fit trees using row-partitioned
* dataset.
*/
-public class DecisionTreeRegressionTrainer extends DecisionTree<MSEImpurityMeasure> {
+public class DecisionTreeRegressionTrainer extends DecisionTreeTrainer<MSEImpurityMeasure> {
/**
* Constructs a new decision tree regressor with default impurity function compressor.
*
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeTrainer.java
similarity index 92%
rename from modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeTrainer.java
index eb2f1e5..0692ec6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTree.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/DecisionTreeTrainer.java
@@ -41,7 +41,7 @@ import org.apache.ignite.ml.tree.leaf.DecisionTreeLeafBuilder;
*
* @param <T> Type of impurity measure.
*/
-public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleLabelDatasetTrainer<DecisionTreeNode> {
+public abstract class DecisionTreeTrainer<T extends ImpurityMeasure<T>> extends SingleLabelDatasetTrainer<DecisionTreeModel> {
/** Max tree deep. */
int maxDeep;
@@ -65,8 +65,8 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleL
* @param compressor Impurity function compressor.
* @param decisionTreeLeafBuilder Decision tree leaf builder.
*/
- DecisionTree(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor,
- DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
+ DecisionTreeTrainer(int maxDeep, double minImpurityDecrease, StepFunctionCompressor<T> compressor,
+ DecisionTreeLeafBuilder decisionTreeLeafBuilder) {
this.maxDeep = maxDeep;
this.minImpurityDecrease = minImpurityDecrease;
this.compressor = compressor;
@@ -108,7 +108,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleL
}
/** {@inheritDoc} */
- @Override public <K, V> DecisionTreeNode fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
+ @Override public <K, V> DecisionTreeModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
Preprocessor<K, V> preprocessor) {
try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(
envBuilder,
@@ -124,13 +124,13 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleL
}
/** {@inheritDoc} */
- @Override public boolean isUpdateable(DecisionTreeNode mdl) {
+ @Override public boolean isUpdateable(DecisionTreeModel mdl) {
return true;
}
/** {@inheritDoc} */
- @Override public DecisionTree<T> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
- return (DecisionTree<T>)super.withEnvironmentBuilder(envBuilder);
+ @Override public DecisionTreeTrainer<T> withEnvironmentBuilder(LearningEnvironmentBuilder envBuilder) {
+ return (DecisionTreeTrainer<T>)super.withEnvironmentBuilder(envBuilder);
}
/**
@@ -143,7 +143,7 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleL
* @param <V> Type of a value in {@code upstream} data.
* @return New model based on new dataset.
*/
- @Override protected <K, V> DecisionTreeNode updateModel(DecisionTreeNode mdl,
+ @Override protected <K, V> DecisionTreeModel updateModel(DecisionTreeModel mdl,
DatasetBuilder<K, V> datasetBuilder,
Preprocessor<K, V> preprocessor) {
@@ -151,8 +151,8 @@ public abstract class DecisionTree<T extends ImpurityMeasure<T>> extends SingleL
}
/** */
- public <K, V> DecisionTreeNode fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
- return split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset));
+ public <K, V> DecisionTreeModel fit(Dataset<EmptyContext, DecisionTreeData> dataset) {
+ return new DecisionTreeModel(split(dataset, e -> true, 0, getImpurityMeasureCalculator(dataset)));
}
/**
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/NodeData.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/NodeData.java
new file mode 100644
index 0000000..885a14d
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/NodeData.java
@@ -0,0 +1,90 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree;
+
+import java.util.Map;
+import java.util.NavigableMap;
+
+/**
+ * Presents decision tree node data in a plain (flat) manner (for example: one parquet row filled with NodeData in a Spark DT model).
+ */
+public class NodeData {
+ /** Id. */
+ public int id;
+
+ /** Prediction. */
+ public double prediction;
+
+ /** Left child id. */
+ public int leftChildId;
+
+ /** Right child id. */
+ public int rightChildId;
+
+ /** Threshold. */
+ public double threshold;
+
+ /** Feature index. */
+ public int featureIdx;
+
+ /** Is leaf node. */
+ public boolean isLeafNode;
+
+ /** {@inheritDoc} */
+ @Override public String toString() {
+ return "NodeData{" +
+ "id=" + id +
+ ", prediction=" + prediction +
+ ", leftChildId=" + leftChildId +
+ ", rightChildId=" + rightChildId +
+ ", threshold=" + threshold +
+ ", featureIdx=" + featureIdx +
+ ", isLeafNode=" + isLeafNode +
+ '}';
+ }
+
+ /**
+ * Build tree or sub-tree based on indices and nodes sorted map as a dictionary.
+ *
+ * @param nodes The sorted map of nodes.
+ * @param rootNodeData Root node data.
+ */
+ public static DecisionTreeNode buildTree(Map<Integer, NodeData> nodes,
+ NodeData rootNodeData) {
+ return rootNodeData.isLeafNode ? new DecisionTreeLeafNode(rootNodeData.prediction) : new DecisionTreeConditionalNode(rootNodeData.featureIdx,
+ rootNodeData.threshold,
+ buildTree(nodes, nodes.get(rootNodeData.rightChildId)),
+ buildTree(nodes, nodes.get(rootNodeData.leftChildId)),
+ null);
+ }
+
+ /**
+ * Builds the DT model by the given sorted map of nodes.
+ *
+ * @param nodes The sorted map of nodes.
+ */
+ public static DecisionTreeModel buildDecisionTreeModel(Map<Integer, NodeData> nodes) {
+ DecisionTreeModel mdl = null;
+ if (!nodes.isEmpty()) {
+ NodeData rootNodeData = (NodeData)((NavigableMap)nodes).firstEntry().getValue();
+ mdl = new DecisionTreeModel(buildTree(nodes, rootNodeData));
+ return mdl;
+ }
+ return mdl;
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
index 1c25f73..a2438e5 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/boosting/GDBOnTreesLearningStrategy.java
@@ -22,7 +22,7 @@ import java.util.List;
import org.apache.ignite.ml.IgniteModel;
import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.boosting.GDBLearningStrategy;
-import org.apache.ignite.ml.composition.boosting.GDBTrainer;
+import org.apache.ignite.ml.composition.boosting.GDBModel;
import org.apache.ignite.ml.composition.boosting.convergence.ConvergenceChecker;
import org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
@@ -35,7 +35,7 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.DatasetTrainer;
-import org.apache.ignite.ml.tree.DecisionTree;
+import org.apache.ignite.ml.tree.DecisionTreeTrainer;
import org.apache.ignite.ml.tree.data.DecisionTreeData;
import org.apache.ignite.ml.tree.data.DecisionTreeDataBuilder;
@@ -57,15 +57,15 @@ public class GDBOnTreesLearningStrategy extends GDBLearningStrategy {
}
/** {@inheritDoc} */
- @Override public <K, V> List<IgniteModel<Vector, Double>> update(GDBTrainer.GDBModel mdlToUpdate,
+ @Override public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate,
DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
LearningEnvironment environment = envBuilder.buildForTrainer();
environment.initDeployingContext(vectorizer);
DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
- assert trainer instanceof DecisionTree;
- DecisionTree decisionTreeTrainer = (DecisionTree)trainer;
+ assert trainer instanceof DecisionTreeTrainer;
+ DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer)trainer;
List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
index fb118ec..ab8db2e 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainer.java
@@ -22,7 +22,6 @@ import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
@@ -31,7 +30,7 @@ import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartit
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.environment.LearningEnvironmentBuilder;
-import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.GiniHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogramsComputer;
@@ -98,8 +97,8 @@ public class RandomForestClassifierTrainer
}
/** {@inheritDoc} */
- @Override protected ModelsComposition buildComposition(List<TreeRoot> models) {
- return new ModelsComposition(models, new OnMajorityPredictionsAggregator());
+ @Override protected RandomForestModel buildComposition(List<RandomForestTreeModel> models) {
+ return new RandomForestModel(models, new OnMajorityPredictionsAggregator());
}
/** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestModel.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestModel.java
new file mode 100644
index 0000000..1ae9576
--- /dev/null
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestModel.java
@@ -0,0 +1,106 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.ignite.ml.tree.randomforest;
+
+import java.io.File;
+import java.io.IOException;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.List;
+import java.util.UUID;
+import com.fasterxml.jackson.databind.DeserializationFeature;
+import com.fasterxml.jackson.databind.ObjectMapper;
+import com.fasterxml.jackson.databind.ObjectWriter;
+import com.fasterxml.jackson.databind.SerializationFeature;
+import org.apache.ignite.ml.composition.ModelsComposition;
+import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
+import org.apache.ignite.ml.composition.predictionsaggregator.PredictionsAggregator;
+import org.apache.ignite.ml.inference.json.JSONModel;
+import org.apache.ignite.ml.inference.json.JSONModelMixIn;
+import org.apache.ignite.ml.inference.json.JSONWritable;
+import org.apache.ignite.ml.inference.json.JacksonHelper;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
+
+/**
+ * Random Forest Model class.
+ */
+public class RandomForestModel extends ModelsComposition<RandomForestTreeModel> implements JSONWritable {
+ /** Serial version uid. */
+ private static final long serialVersionUID = 3476345240155508004L;
+
+ /** */
+ public RandomForestModel() {
+ super(new ArrayList<>(), new MeanValuePredictionsAggregator());
+
+ }
+
+ /** */
+ public RandomForestModel(List<RandomForestTreeModel> oldModels, PredictionsAggregator predictionsAggregator) {
+ super(oldModels, predictionsAggregator);
+ }
+
+ /**
+ * Returns predictions aggregator.
+ */
+ @Override public PredictionsAggregator getPredictionsAggregator() {
+ return predictionsAggregator;
+ }
+
+ /**
+ * Returns containing models.
+ */
+ @Override public List<RandomForestTreeModel> getModels() {
+ return models;
+ }
+
+ /** {@inheritDoc} */
+ @Override public void toJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper().configure(SerializationFeature.FAIL_ON_EMPTY_BEANS, false);
+ mapper.addMixIn(RandomForestModel.class, JSONModelMixIn.class);
+
+ ObjectWriter writer = mapper
+ .writerFor(RandomForestModel.class)
+ .withAttribute("formatVersion", JSONModel.JSON_MODEL_FORMAT_VERSION)
+ .withAttribute("timestamp", System.currentTimeMillis())
+ .withAttribute("uid", "dt_" + UUID.randomUUID().toString())
+ .withAttribute("modelClass", RandomForestModel.class.getSimpleName());
+
+ try {
+ File file = new File(path.toAbsolutePath().toString());
+ writer.writeValue(file, this);
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ }
+
+ /** Loads RandomForestModel from JSON file. */
+ public static RandomForestModel fromJSON(Path path) {
+ ObjectMapper mapper = new ObjectMapper();
+ mapper.configure(DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES, false);
+
+ RandomForestModel mdl;
+ try {
+ JacksonHelper.readAndValidateBasicJsonModelProperties(path, mapper, RandomForestModel.class.getSimpleName());
+ mdl = mapper.readValue(new File(path.toAbsolutePath().toString()), RandomForestModel.class);
+ return mdl;
+ } catch (IOException e) {
+ e.printStackTrace();
+ }
+ return null;
+ }
+}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
index ab1d036..4b0499f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainer.java
@@ -18,10 +18,9 @@
package org.apache.ignite.ml.tree.randomforest;
import java.util.List;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
-import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.MSEHistogram;
import org.apache.ignite.ml.tree.randomforest.data.impurity.MSEHistogramComputer;
@@ -49,8 +48,8 @@ public class RandomForestRegressionTrainer
}
/** {@inheritDoc} */
- @Override protected ModelsComposition buildComposition(List<TreeRoot> models) {
- return new ModelsComposition(models, new MeanValuePredictionsAggregator());
+ @Override protected RandomForestModel buildComposition(List<RandomForestTreeModel> models) {
+ return new RandomForestModel(models, new MeanValuePredictionsAggregator());
}
/** {@inheritDoc} */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
index fe860ca..481c22b 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/RandomForestTrainer.java
@@ -30,8 +30,6 @@ import java.util.Set;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
-import org.apache.ignite.ml.IgniteModel;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.dataset.Dataset;
import org.apache.ignite.ml.dataset.DatasetBuilder;
import org.apache.ignite.ml.dataset.feature.BucketMeta;
@@ -41,14 +39,13 @@ import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartit
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
-import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.preprocessing.Preprocessor;
import org.apache.ignite.ml.trainers.SingleLabelDatasetTrainer;
import org.apache.ignite.ml.tree.randomforest.data.FeaturesCountSelectionStrategies;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
-import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityComputer;
import org.apache.ignite.ml.tree.randomforest.data.impurity.ImpurityHistogramsComputer;
import org.apache.ignite.ml.tree.randomforest.data.statistics.LeafValuesComputer;
@@ -68,7 +65,7 @@ import org.apache.ignite.ml.tree.randomforest.data.statistics.NormalDistribution
* @param <T> Type of child of RandomForestTrainer using in with-methods.
*/
public abstract class RandomForestTrainer<L, S extends ImpurityComputer<BootstrappedVector, S>,
- T extends RandomForestTrainer<L, S, T>> extends SingleLabelDatasetTrainer<ModelsComposition> {
+ T extends RandomForestTrainer<L, S, T>> extends SingleLabelDatasetTrainer<RandomForestModel> {
/** Bucket size factor. */
private static final double BUCKET_SIZE_FACTOR = (1 / 10.0);
@@ -110,9 +107,9 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
}
/** {@inheritDoc} */
- @Override public <K, V> ModelsComposition fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
- Preprocessor<K, V> preprocessor) {
- List<TreeRoot> models = null;
+ @Override public <K, V> RandomForestModel fitWithInitializedDeployingContext(DatasetBuilder<K, V> datasetBuilder,
+ Preprocessor<K, V> preprocessor) {
+ List<RandomForestTreeModel> models = null;
try (Dataset<EmptyContext, BootstrappedDatasetPartition> dataset = datasetBuilder.build(
envBuilder,
new EmptyContextBuilder<>(),
@@ -215,9 +212,9 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
* @param dataset Dataset.
* @return list of decision trees.
*/
- private List<TreeRoot> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
+ private List<RandomForestTreeModel> fit(Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
Queue<TreeNode> treesQueue = createRootsQueue();
- ArrayList<TreeRoot> roots = initTrees(treesQueue);
+ ArrayList<RandomForestTreeModel> roots = initTrees(treesQueue);
Map<Integer, BucketMeta> histMeta = computeHistogramMeta(meta, dataset);
if (histMeta.isEmpty())
return Collections.emptyList();
@@ -239,20 +236,20 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
}
/** {@inheritDoc} */
- @Override public boolean isUpdateable(ModelsComposition mdl) {
- ModelsComposition fakeComposition = buildComposition(Collections.emptyList());
+ @Override public boolean isUpdateable(RandomForestModel mdl) {
+ RandomForestModel fakeComposition = buildComposition(Collections.emptyList());
return mdl.getPredictionsAggregator().getClass() == fakeComposition.getPredictionsAggregator().getClass();
}
/** {@inheritDoc} */
- @Override protected <K, V> ModelsComposition updateModel(ModelsComposition mdl, DatasetBuilder<K, V> datasetBuilder,
+ @Override protected <K, V> RandomForestModel updateModel(RandomForestModel mdl, DatasetBuilder<K, V> datasetBuilder,
Preprocessor<K, V> preprocessor) {
- ArrayList<IgniteModel<Vector, Double>> oldModels = new ArrayList<>(mdl.getModels());
- ModelsComposition newModels = fit(datasetBuilder, preprocessor);
+ List<RandomForestTreeModel> oldModels = new ArrayList<>(mdl.getModels());
+ RandomForestModel newModels = fit(datasetBuilder, preprocessor);
oldModels.addAll(newModels.getModels());
- return new ModelsComposition(oldModels, mdl.getPredictionsAggregator());
+ return new RandomForestModel(oldModels, mdl.getPredictionsAggregator());
}
/**
@@ -297,16 +294,16 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
* @param treesQueue Trees queue.
* @return List of trees.
*/
- protected ArrayList<TreeRoot> initTrees(Queue<TreeNode> treesQueue) {
+ protected ArrayList<RandomForestTreeModel> initTrees(Queue<TreeNode> treesQueue) {
assert featuresPerTree > 0;
- ArrayList<TreeRoot> roots = new ArrayList<>();
+ ArrayList<RandomForestTreeModel> roots = new ArrayList<>();
List<Integer> allFeatureIds = IntStream.range(0, meta.size()).boxed().collect(Collectors.toList());
for (TreeNode node : treesQueue) {
Collections.shuffle(allFeatureIds, random);
Set<Integer> featuresSubspace = allFeatureIds.stream()
.limit(featuresPerTree).collect(Collectors.toSet());
- roots.add(new TreeRoot(node, featuresSubspace));
+ roots.add(new RandomForestTreeModel(node, featuresSubspace));
}
return roots;
@@ -394,6 +391,6 @@ public abstract class RandomForestTrainer<L, S extends ImpurityComputer<Bootstra
* @param models Models.
* @return composition of built trees.
*/
- protected abstract ModelsComposition buildComposition(List<TreeRoot> models);
+ protected abstract RandomForestModel buildComposition(List<RandomForestTreeModel> models);
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeId.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeId.java
index f0ecd62..a8bc849 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeId.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeId.java
@@ -29,10 +29,10 @@ public class NodeId implements Serializable {
private static final long serialVersionUID = 4400852013136423333L;
/** Tree id. */
- private final int treeId;
+ private int treeId;
/** Node id. */
- private final long nodeId;
+ private long nodeId;
/**
* Create an instance of NodeId.
@@ -45,11 +45,14 @@ public class NodeId implements Serializable {
this.nodeId = nodeId;
}
+ public NodeId() {
+ }
+
/**
*
* @return Tree id.
*/
- public int treeId() {
+ public int getTreeId() {
return treeId;
}
@@ -57,7 +60,7 @@ public class NodeId implements Serializable {
*
* @return Node id.
*/
- public long nodeId() {
+ public long getNodeId() {
return nodeId;
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java
index 6bdf9a9..8146df0 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/NodeSplit.java
@@ -28,13 +28,16 @@ public class NodeSplit implements Serializable {
private static final long serialVersionUID = 1331311529596106124L;
/** Feature id in feature vector. */
- private final int featureId;
+ private int featureId;
/** Feature split value. */
- private final double val;
+ private double val;
/** Impurity at this split point. */
- private final double impurity;
+ private double impurity;
+
+ public NodeSplit() {
+ }
/**
* Creates an instance of NodeSplit.
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/RandomForestTreeModel.java
similarity index 88%
rename from modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java
rename to modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/RandomForestTreeModel.java
index 53a2d66..563080a 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeRoot.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/RandomForestTreeModel.java
@@ -27,12 +27,12 @@ import org.apache.ignite.ml.math.primitives.vector.Vector;
/**
* Tree root class.
*/
-public class TreeRoot implements IgniteModel<Vector, Double> {
+public class RandomForestTreeModel implements IgniteModel<Vector, Double> {
/** Serial version uid. */
private static final long serialVersionUID = 531797299171329057L;
/** Root node. */
- private TreeNode node;
+ private TreeNode rootNode;
/** Used features. */
private Set<Integer> usedFeatures;
@@ -43,14 +43,17 @@ public class TreeRoot implements IgniteModel<Vector, Double> {
* @param root Root.
* @param usedFeatures Used features.
*/
- public TreeRoot(TreeNode root, Set<Integer> usedFeatures) {
- this.node = root;
+ public RandomForestTreeModel(TreeNode root, Set<Integer> usedFeatures) {
+ this.rootNode = root;
this.usedFeatures = usedFeatures;
}
+ public RandomForestTreeModel() {
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector vector) {
- return node.predict(vector);
+ return rootNode.predict(vector);
}
/** */
@@ -60,15 +63,15 @@ public class TreeRoot implements IgniteModel<Vector, Double> {
/** */
public TreeNode getRootNode() {
- return node;
+ return rootNode;
}
/**
* @return All leafs in tree.
*/
- public List<TreeNode> getLeafs() {
+ public List<TreeNode> leafs() {
List<TreeNode> res = new ArrayList<>();
- getLeafs(node, res);
+ leafs(rootNode, res);
return res;
}
@@ -76,12 +79,12 @@ public class TreeRoot implements IgniteModel<Vector, Double> {
* @param root Root.
* @param res Result list.
*/
- private void getLeafs(TreeNode root, List<TreeNode> res) {
+ private void leafs(TreeNode root, List<TreeNode> res) {
if (root.getType() == TreeNode.Type.LEAF)
res.add(root);
else {
- getLeafs(root.getLeft(), res);
- getLeafs(root.getRight(), res);
+ leafs(root.getLeft(), res);
+ leafs(root.getRight(), res);
}
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java
index b373596..7a480e6 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/TreeNode.java
@@ -44,7 +44,7 @@ public class TreeNode implements IgniteModel<Vector, Double> {
}
/** Id. */
- private final NodeId id;
+ private NodeId id;
/** Feature id. */
private int featureId;
@@ -81,6 +81,9 @@ public class TreeNode implements IgniteModel<Vector, Double> {
this.depth = 1;
}
+ public TreeNode() {
+ }
+
/** {@inheritDoc} */
@Override public Double predict(Vector features) {
assert type != Type.UNKNOWN;
@@ -125,8 +128,8 @@ public class TreeNode implements IgniteModel<Vector, Double> {
assert type == Type.UNKNOWN;
toLeaf(val);
- left = new TreeNode(2 * id.nodeId(), id.treeId());
- right = new TreeNode(2 * id.nodeId() + 1, id.treeId());
+ left = new TreeNode(2 * id.getNodeId(), id.getTreeId());
+ right = new TreeNode(2 * id.getNodeId() + 1, id.getTreeId());
this.type = Type.CONDITIONAL;
this.featureId = featureId;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
index bc22ee1..521b426 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/impurity/ImpurityHistogramsComputer.java
@@ -32,8 +32,8 @@ import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
import org.apache.ignite.ml.tree.randomforest.data.NodeSplit;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
-import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
/**
* Class containing logic of aggregation impurity statistics within learning dataset.
@@ -52,7 +52,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
* @param nodesToLearn Nodes to learn.
* @param dataset Dataset.
*/
- public Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatistics(ArrayList<TreeRoot> roots,
+ public Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatistics(ArrayList<RandomForestTreeModel> roots,
Map<Integer, BucketMeta> histMeta, Map<NodeId, TreeNode> nodesToLearn,
Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
@@ -73,7 +73,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
* @return Leaf statistics for impurity computing.
*/
private Map<NodeId, NodeImpurityHistograms<S>> aggregateImpurityStatisticsOnPartition(
- BootstrappedDatasetPartition dataset, ArrayList<TreeRoot> roots,
+ BootstrappedDatasetPartition dataset, ArrayList<RandomForestTreeModel> roots,
Map<Integer, BucketMeta> histMeta,
Map<NodeId, TreeNode> part) {
@@ -85,7 +85,7 @@ public abstract class ImpurityHistogramsComputer<S extends ImpurityComputer<Boot
if (vector.counters()[sampleId] == 0)
continue;
- TreeRoot root = roots.get(sampleId);
+ RandomForestTreeModel root = roots.get(sampleId);
NodeId key = root.getRootNode().predictNextNodeKey(vector.features());
if (!part.containsKey(key)) //if we didn't take all nodes from learning queue
continue;
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java
index 98c2aba..7c8f7e7 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/tree/randomforest/data/statistics/LeafValuesComputer.java
@@ -30,8 +30,8 @@ import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedDatasetPartit
import org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector;
import org.apache.ignite.ml.dataset.primitive.context.EmptyContext;
import org.apache.ignite.ml.tree.randomforest.data.NodeId;
+import org.apache.ignite.ml.tree.randomforest.data.RandomForestTreeModel;
import org.apache.ignite.ml.tree.randomforest.data.TreeNode;
-import org.apache.ignite.ml.tree.randomforest.data.TreeRoot;
/**
* Class containing logic of leaf values computing after building of all trees in random forest.
@@ -49,11 +49,11 @@ public abstract class LeafValuesComputer<T> implements Serializable {
* @param roots Learned trees.
* @param dataset Dataset.
*/
- public void setValuesForLeaves(ArrayList<TreeRoot> roots,
+ public void setValuesForLeaves(ArrayList<RandomForestTreeModel> roots,
Dataset<EmptyContext, BootstrappedDatasetPartition> dataset) {
Map<NodeId, TreeNode> leafs = roots.stream()
- .flatMap(r -> r.getLeafs().stream())
+ .flatMap(r -> r.leafs().stream())
.collect(Collectors.toMap(TreeNode::getId, Function.identity()));
Map<NodeId, T> stats = dataset.compute(
@@ -78,7 +78,7 @@ public abstract class LeafValuesComputer<T> implements Serializable {
* @param data Data.
* @return Statistics on labels for each leaf nodes.
*/
- private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<TreeRoot> roots,
+ private Map<NodeId, T> computeLeafsStatisticsInPartition(ArrayList<RandomForestTreeModel> roots,
Map<NodeId, TreeNode> leafs, BootstrappedDatasetPartition data) {
Map<NodeId, T> res = new HashMap<>();
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java
index cc652e8..5c7f8da 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansModelTest.java
@@ -54,7 +54,7 @@ public class KMeansModelTest {
Assert.assertEquals(mdl.predict(new DenseVector(new double[]{-1.1, -1.1})), 3.0, PRECISION);
Assert.assertEquals(mdl.distanceMeasure(), distanceMeasure);
- Assert.assertEquals(mdl.getAmountOfClusters(), 4);
- Assert.assertArrayEquals(mdl.getCenters(), centers);
+ Assert.assertEquals(mdl.amountOfClusters(), 4);
+ Assert.assertArrayEquals(mdl.centers(), centers);
}
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
index 0d35df5..ef33aca 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
@@ -83,7 +83,7 @@ public class KeepBinaryTest extends GridCommonAbstractTest {
Integer zeroCentre = mdl.predict(VectorUtils.num2Vec(0.0));
- assertTrue(mdl.getCenters()[zeroCentre].get(0) == 0);
+ assertTrue(mdl.centers()[zeroCentre].get(0) == 0);
}
/**
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
index e517050..9bd9509 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/composition/boosting/GDBTrainerTest.java
@@ -32,7 +32,7 @@ import org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder;
import org.apache.ignite.ml.math.functions.IgniteBiFunction;
import org.apache.ignite.ml.math.primitives.vector.Vector;
import org.apache.ignite.ml.math.primitives.vector.VectorUtils;
-import org.apache.ignite.ml.tree.DecisionTreeConditionalNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
import org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer;
import org.apache.ignite.ml.tree.boosting.GDBRegressionOnTreesTrainer;
import org.junit.Test;
@@ -83,7 +83,7 @@ public class GDBTrainerTest extends TrainerTest {
assertTrue(!composition.toString(true).isEmpty());
assertTrue(!composition.toString(false).isEmpty());
- composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode));
+ composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));
assertEquals(2000, composition.getModels().size());
assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
@@ -145,7 +145,7 @@ public class GDBTrainerTest extends TrainerTest {
assertTrue(mdl instanceof ModelsComposition);
ModelsComposition composition = (ModelsComposition)mdl;
- composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeConditionalNode));
+ composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));
assertTrue(composition.getModels().size() < 500);
assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java
index 0be0b54..40949c4 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/DistanceTest.java
@@ -43,7 +43,7 @@ public class DistanceTest {
new BrayCurtisDistance(),
new CanberraDistance(),
new JensenShannonDistance(),
- new WeightedMinkowskiDistance(4, new DenseVector(new double[]{1, 1, 1})),
+ new WeightedMinkowskiDistance(4, new double[]{1, 1, 1}),
new MinkowskiDistance(Math.random()));
/** */
@@ -197,9 +197,9 @@ public class DistanceTest {
double precistion = 0.01;
int p = 2;
double expRes = 5.0;
- Vector v = new DenseVector(new double[]{2, 3, 4});
+ double[] weights = new double[]{2, 3, 4};
- DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(p, v);
+ DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(p, weights);
assertEquals(expRes, distanceMeasure.compute(v1, data2), precistion);
assertEquals(expRes, distanceMeasure.compute(v1, v2), precistion);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java
index 1ab93a1..c6a1d18 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/math/distances/WeightedMinkowskiDistanceTest.java
@@ -72,7 +72,7 @@ public class WeightedMinkowskiDistanceTest {
/** */
@Test
public void testWeightedMinkowski() {
- DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(testData.p, testData.weight);
+ DistanceMeasure distanceMeasure = new WeightedMinkowskiDistance(testData.p, testData.weights);
assertEquals(testData.expRes,
distanceMeasure.compute(testData.vectorA, testData.vectorB), PRECISION);
@@ -87,15 +87,15 @@ public class WeightedMinkowskiDistanceTest {
public final Integer p;
- public final Vector weight;
+ public final double[] weights;
public final Double expRes;
- private TestData(double[] vectorA, double[] vectorB, Integer p, double[] weight, double expRes) {
+ private TestData(double[] vectorA, double[] vectorB, Integer p, double[] weights, double expRes) {
this.vectorA = new DenseVector(vectorA);
this.vectorB = new DenseVector(vectorB);
this.p = p;
- this.weight = new DenseVector(weight);
+ this.weights = weights;
this.expRes = expRes;
}
@@ -104,7 +104,7 @@ public class WeightedMinkowskiDistanceTest {
Arrays.toString(vectorA.asArray()),
Arrays.toString(vectorB.asArray()),
p,
- Arrays.toString(weight.asArray()),
+ Arrays.toString(weights),
expRes
);
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
index 96c7158..a64651a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionLSQRTrainerTest.java
@@ -59,11 +59,11 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
assertArrayEquals(
new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781},
- mdl.getWeights().getStorage().data(),
+ mdl.weights().getStorage().data(),
1e-6
);
- assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-6);
+ assertEquals(2.8421709430404007e-14, mdl.intercept(), 1e-6);
}
/**
@@ -95,9 +95,9 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
);
- assertArrayEquals(coef, mdl.getWeights().getStorage().data(), 1e-6);
+ assertArrayEquals(coef, mdl.weights().getStorage().data(), 1e-6);
- assertEquals(intercept, mdl.getIntercept(), 1e-6);
+ assertEquals(intercept, mdl.intercept(), 1e-6);
}
/** */
@@ -142,10 +142,10 @@ public class LinearRegressionLSQRTrainerTest extends TrainerTest {
vectorizer
);
- assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnSameDS.getWeights().getStorage().data(), 1e-6);
- assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1e-6);
+ assertArrayEquals(originalMdl.weights().getStorage().data(), updatedOnSameDS.weights().getStorage().data(), 1e-6);
+ assertEquals(originalMdl.intercept(), updatedOnSameDS.intercept(), 1e-6);
- assertArrayEquals(originalMdl.getWeights().getStorage().data(), updatedOnEmptyDS.getWeights().getStorage().data(), 1e-6);
- assertEquals(originalMdl.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-6);
+ assertArrayEquals(originalMdl.weights().getStorage().data(), updatedOnEmptyDS.weights().getStorage().data(), 1e-6);
+ assertEquals(originalMdl.intercept(), updatedOnEmptyDS.intercept(), 1e-6);
}
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
index 22b16d1..9f50369 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/regressions/linear/LinearRegressionSGDTrainerTest.java
@@ -64,11 +64,11 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
assertArrayEquals(
new double[]{72.26948107, 15.95144674, 24.07403921, 66.73038781},
- mdl.getWeights().getStorage().data(),
+ mdl.weights().getStorage().data(),
1e-1
);
- assertEquals(2.8421709430404007e-14, mdl.getIntercept(), 1e-1);
+ assertEquals(2.8421709430404007e-14, mdl.intercept(), 1e-1);
}
/** */
@@ -112,19 +112,19 @@ public class LinearRegressionSGDTrainerTest extends TrainerTest {
);
assertArrayEquals(
- originalMdl.getWeights().getStorage().data(),
- updatedOnSameDS.getWeights().getStorage().data(),
+ originalMdl.weights().getStorage().data(),
+ updatedOnSameDS.weights().getStorage().data(),
1.0
);
- assertEquals(originalMdl.getIntercept(), updatedOnSameDS.getIntercept(), 1.0);
+ assertEquals(originalMdl.intercept(), updatedOnSameDS.intercept(), 1.0);
assertArrayEquals(
- originalMdl.getWeights().getStorage().data(),
- updatedOnEmptyDS.getWeights().getStorage().data(),
+ originalMdl.weights().getStorage().data(),
+ updatedOnEmptyDS.weights().getStorage().data(),
1e-1
);
- assertEquals(originalMdl.getIntercept(), updatedOnEmptyDS.getIntercept(), 1e-1);
+ assertEquals(originalMdl.intercept(), updatedOnEmptyDS.intercept(), 1e-1);
}
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
index 7122c69..bfccc71 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/selection/cv/CrossValidationTest.java
@@ -31,7 +31,7 @@ import org.apache.ignite.ml.selection.paramgrid.ParamGrid;
import org.apache.ignite.ml.selection.paramgrid.RandomStrategy;
import org.apache.ignite.ml.selection.scoring.metric.MetricName;
import org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer;
-import org.apache.ignite.ml.tree.DecisionTreeNode;
+import org.apache.ignite.ml.tree.DecisionTreeModel;
import org.junit.Test;
import static org.apache.ignite.ml.common.TrainerTest.twoLinearlySeparableClasses;
@@ -53,7 +53,7 @@ public class CrossValidationTest {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
- DebugCrossValidation<DecisionTreeNode, Integer, double[]> scoreCalculator =
+ DebugCrossValidation<DecisionTreeModel, Integer, double[]> scoreCalculator =
new DebugCrossValidation<>();
Vectorizer<Integer, double[], Integer, Double> vectorizer = new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST);
@@ -84,7 +84,7 @@ public class CrossValidationTest {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
- DebugCrossValidation<DecisionTreeNode, Integer, double[]> scoreCalculator =
+ DebugCrossValidation<DecisionTreeModel, Integer, double[]> scoreCalculator =
new DebugCrossValidation<>();
int folds = 4;
@@ -298,7 +298,7 @@ public class CrossValidationTest {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
- DebugCrossValidation<DecisionTreeNode, Integer, double[]> scoreCalculator =
+ DebugCrossValidation<DecisionTreeModel, Integer, double[]> scoreCalculator =
new DebugCrossValidation<>();
int folds = 4;
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
index d64c35e..1c3f140 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerIntegrationTest.java
@@ -72,11 +72,12 @@ public class DecisionTreeClassificationTrainerIntegrationTest extends GridCommon
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0);
- DecisionTreeNode tree = trainer.fit(ignite, data, new DoubleArrayVectorizer<Integer>().labeled(1));
+ DecisionTreeModel tree = trainer.fit(ignite, data, new DoubleArrayVectorizer<Integer>().labeled(1));
- assertTrue(tree instanceof DecisionTreeConditionalNode);
+ DecisionTreeNode decisionTreeNode = tree.getRootNode();
+ assertTrue(decisionTreeNode instanceof DecisionTreeConditionalNode);
- DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) decisionTreeNode;
assertEquals(0, node.getThreshold(), 1e-3);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
index ed7c4fe..e618f63 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeClassificationTrainerTest.java
@@ -75,11 +75,11 @@ public class DecisionTreeClassificationTrainerTest {
DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(1, 0)
.withUseIndex(useIdx == 1);
- DecisionTreeNode tree = trainer.fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(1));
+ DecisionTreeNode treeNode = trainer.fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(1)).getRootNode();
- assertTrue(tree instanceof DecisionTreeConditionalNode);
+ assertTrue(treeNode instanceof DecisionTreeConditionalNode);
- DecisionTreeConditionalNode node = (DecisionTreeConditionalNode)tree;
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode)treeNode;
assertEquals(0, node.getThreshold(), 1e-3);
assertEquals(0, node.getCol());
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
index 587dacd..686949f 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerIntegrationTest.java
@@ -78,15 +78,15 @@ public class DecisionTreeRegressionTrainerIntegrationTest extends GridCommonAbst
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0);
- DecisionTreeNode tree = trainer.fit(
+ DecisionTreeNode treeNode = trainer.fit(
ignite,
data,
new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST)
- );
+ ).getRootNode();
- assertTrue(tree instanceof DecisionTreeConditionalNode);
+ assertTrue(treeNode instanceof DecisionTreeConditionalNode);
- DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) treeNode;
assertEquals(0, node.getThreshold(), 1e-3);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
index 6466350..98e3e7a 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/DecisionTreeRegressionTrainerTest.java
@@ -74,11 +74,11 @@ public class DecisionTreeRegressionTrainerTest {
DecisionTreeRegressionTrainer trainer = new DecisionTreeRegressionTrainer(1, 0)
.withUsingIdx(useIdx == 1);
- DecisionTreeNode tree = trainer.fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(1));
+ DecisionTreeNode treeNode = trainer.fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(1)).getRootNode();
- assertTrue(tree instanceof DecisionTreeConditionalNode);
+ assertTrue(treeNode instanceof DecisionTreeConditionalNode);
- DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) tree;
+ DecisionTreeConditionalNode node = (DecisionTreeConditionalNode) treeNode;
assertEquals(0, node.getThreshold(), 1e-3);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
index c94799a..cb5961d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestClassifierTrainerTest.java
@@ -22,7 +22,6 @@ import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.TestUtils;
import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
@@ -56,12 +55,12 @@ public class RandomForestClassifierTrainerTest extends TrainerTest {
ArrayList<FeatureMeta> meta = new ArrayList<>();
for (int i = 0; i < 4; i++)
meta.add(new FeatureMeta("", i, false));
- DatasetTrainer<ModelsComposition, Double> trainer = new RandomForestClassifierTrainer(meta)
+ DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(meta)
.withAmountOfTrees(5)
.withFeaturesCountSelectionStrgy(x -> 2)
.withEnvironmentBuilder(TestUtils.testEnvBuilder());
- ModelsComposition mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
assertTrue(mdl.getPredictionsAggregator() instanceof OnMajorityPredictionsAggregator);
assertEquals(5, mdl.getModels().size());
@@ -84,14 +83,14 @@ public class RandomForestClassifierTrainerTest extends TrainerTest {
ArrayList<FeatureMeta> meta = new ArrayList<>();
for (int i = 0; i < 4; i++)
meta.add(new FeatureMeta("", i, false));
- DatasetTrainer<ModelsComposition, Double> trainer = new RandomForestClassifierTrainer(meta)
+ DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(meta)
.withAmountOfTrees(100)
.withFeaturesCountSelectionStrgy(x -> 2)
.withEnvironmentBuilder(TestUtils.testEnvBuilder());
- ModelsComposition originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
- ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());
- ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<Integer, LabeledVector<Double>>(), parts, new LabeledDummyVectorizer<>());
+ RandomForestModel originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel updatedOnSameDS = trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<Integer, LabeledVector<Double>>(), parts, new LabeledDummyVectorizer<>());
Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
assertEquals(originalMdl.predict(v), updatedOnSameDS.predict(v), 0.01);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java
index 8bb0894..dc2be85 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestIntegrationTest.java
@@ -24,7 +24,6 @@ import org.apache.ignite.IgniteCache;
import org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction;
import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.internal.util.IgniteUtils;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer;
@@ -85,7 +84,7 @@ public class RandomForestIntegrationTest extends GridCommonAbstractTest {
.withAmountOfTrees(5)
.withFeaturesCountSelectionStrgy(x -> 2);
- ModelsComposition mdl = trainer.fit(ignite, data, new DoubleArrayVectorizer<Integer>().labeled(1));
+ RandomForestModel mdl = trainer.fit(ignite, data, new DoubleArrayVectorizer<Integer>().labeled(1));
assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator);
assertEquals(5, mdl.getModels().size());
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
index 8ea027f..d501dba 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/RandomForestRegressionTrainerTest.java
@@ -21,7 +21,6 @@ import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;
import org.apache.ignite.ml.common.TrainerTest;
-import org.apache.ignite.ml.composition.ModelsComposition;
import org.apache.ignite.ml.composition.predictionsaggregator.MeanValuePredictionsAggregator;
import org.apache.ignite.ml.dataset.feature.FeatureMeta;
import org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer;
@@ -58,7 +57,7 @@ public class RandomForestRegressionTrainerTest extends TrainerTest {
.withAmountOfTrees(5)
.withFeaturesCountSelectionStrgy(x -> 2);
- ModelsComposition mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel mdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
assertTrue(mdl.getPredictionsAggregator() instanceof MeanValuePredictionsAggregator);
assertEquals(5, mdl.getModels().size());
}
@@ -84,9 +83,9 @@ public class RandomForestRegressionTrainerTest extends TrainerTest {
.withAmountOfTrees(100)
.withFeaturesCountSelectionStrgy(x -> 2);
- ModelsComposition originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
- ModelsComposition updatedOnSameDS = trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());
- ModelsComposition updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<Double, LabeledVector<Double>>(), parts, new LabeledDummyVectorizer<>());
+ RandomForestModel originalMdl = trainer.fit(sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel updatedOnSameDS = trainer.update(originalMdl, sample, parts, new LabeledDummyVectorizer<>());
+ RandomForestModel updatedOnEmptyDS = trainer.update(originalMdl, new HashMap<Double, LabeledVector<Double>>(), parts, new LabeledDummyVectorizer<>());
Vector v = VectorUtils.of(5, 0.5, 0.05, 0.005);
assertEquals(originalMdl.predict(v), updatedOnSameDS.predict(v), 0.1);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/TreeNodeTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/TreeNodeTest.java
index 0b199ff..0550eca 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/TreeNodeTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/tree/randomforest/data/TreeNodeTest.java
@@ -38,8 +38,8 @@ public class TreeNodeTest {
TreeNode node = new TreeNode(5, 1);
assertEquals(TreeNode.Type.UNKNOWN, node.getType());
- assertEquals(5, node.predictNextNodeKey(features1).nodeId());
- assertEquals(5, node.predictNextNodeKey(features2).nodeId());
+ assertEquals(5, node.predictNextNodeKey(features1).getNodeId());
+ assertEquals(5, node.predictNextNodeKey(features2).getNodeId());
}
/** */
@@ -49,8 +49,8 @@ public class TreeNodeTest {
node.toLeaf(0.5);
assertEquals(TreeNode.Type.LEAF, node.getType());
- assertEquals(5, node.predictNextNodeKey(features1).nodeId());
- assertEquals(5, node.predictNextNodeKey(features2).nodeId());
+ assertEquals(5, node.predictNextNodeKey(features1).getNodeId());
+ assertEquals(5, node.predictNextNodeKey(features2).getNodeId());
}
/** */
@@ -60,8 +60,8 @@ public class TreeNodeTest {
root.toConditional(0, 0.1);
assertEquals(TreeNode.Type.CONDITIONAL, root.getType());
- assertEquals(2, root.predictNextNodeKey(features1).nodeId());
- assertEquals(3, root.predictNextNodeKey(features2).nodeId());
+ assertEquals(2, root.predictNextNodeKey(features1).getNodeId());
+ assertEquals(3, root.predictNextNodeKey(features2).getNodeId());
}
/** */
@@ -69,7 +69,7 @@ public class TreeNodeTest {
public void testPredictProba() {
TreeNode root = new TreeNode(1, 1);
List<TreeNode> leaves = root.toConditional(0, 0.1);
- leaves.forEach(leaf -> leaf.toLeaf(leaf.getId().nodeId() % 2));
+ leaves.forEach(leaf -> leaf.toLeaf(leaf.getId().getNodeId() % 2));
assertEquals(TreeNode.Type.CONDITIONAL, root.getType());
assertEquals(0.0, root.predict(features1), 0.001);