You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by ch...@apache.org on 2019/01/23 08:23:52 UTC
[ignite] branch master updated: IGNITE-11010: [ML] Use seed from learningEnvironment for KMeans trainer
This is an automated email from the ASF dual-hosted git repository.
chief pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/ignite.git
The following commit(s) were added to refs/heads/master by this push:
new 4b0d581 IGNITE-11010: [ML] Use seed from learningEnvironment for KMeans trainer
4b0d581 is described below
commit 4b0d5818697db7be99fddf3d186cc6fa42405a56
Author: YuriBabak <y....@gmail.com>
AuthorDate: Wed Jan 23 11:23:33 2019 +0300
IGNITE-11010: [ML] Use seed from learningEnvironment for KMeans trainer
This closes #5884
---
.../ml/clustering/KMeansClusterizationExample.java | 3 +--
.../examples/ml/knn/ANNClassificationExample.java | 1 -
.../ignite/ml/clustering/kmeans/KMeansTrainer.java | 27 ++--------------------
.../ml/knn/ann/ANNClassificationTrainer.java | 24 -------------------
.../ml/knn/regression/KNNRegressionTrainer.java | 2 +-
.../java/org/apache/ignite/ml/nn/MLPTrainer.java | 2 +-
.../ignite/ml/clustering/KMeansTrainerTest.java | 4 +---
.../apache/ignite/ml/common/KeepBinaryTest.java | 2 +-
.../ignite/ml/knn/ANNClassificationTest.java | 10 ++++----
9 files changed, 11 insertions(+), 64 deletions(-)
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
index 46550f3..e748f4d 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/clustering/KMeansClusterizationExample.java
@@ -57,8 +57,7 @@ public class KMeansClusterizationExample {
IgniteCache<Integer, Vector> dataCache = new SandboxMLCache(ignite)
.fillCacheWith(MLSandboxDatasets.TWO_CLASSED_IRIS);
- KMeansTrainer trainer = new KMeansTrainer()
- .withSeed(7867L);
+ KMeansTrainer trainer = new KMeansTrainer();
KMeansModel mdl = trainer.fit(
ignite,
diff --git a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
index 71546e9..a5d15d1 100644
--- a/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
+++ b/examples/src/main/java/org/apache/ignite/examples/ml/knn/ANNClassificationExample.java
@@ -65,7 +65,6 @@ public class ANNClassificationExample {
.withDistance(new ManhattanDistance())
.withK(50)
.withMaxIterations(1000)
- .withSeed(1234L)
.withEpsilon(1e-2);
long startTrainingTime = System.currentTimeMillis();
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
index 3206b5f..4bd017f 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/clustering/kmeans/KMeansTrainer.java
@@ -61,9 +61,6 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
/** Distance measure. */
private DistanceMeasure distance = new EuclideanDistance();
- /** KMeans initializer. */
- private long seed;
-
/**
* Trains model based on the specified data.
*
@@ -235,7 +232,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
if (data.rowSize() != 0) {
if (data.rowSize() > k) { // If it's enough rows in partition to pick k vectors.
- final Random random = new Random(seed);
+ final Random random = environment.randomNumbersGenerator();
for (int i = 0; i < k; i++) {
Set<Integer> uniqueIndices = new HashSet<>();
@@ -272,7 +269,7 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
// Pick k vectors randomly.
if (rndPnts.size() >= k) {
for (int i = 0; i < k; i++) {
- final LabeledVector rndPnt = rndPnts.get(new Random(seed).nextInt(rndPnts.size()));
+ final LabeledVector rndPnt = rndPnts.get(environment.randomNumbersGenerator().nextInt(rndPnts.size()));
rndPnts.remove(rndPnt);
initCenters[i] = rndPnt.features();
}
@@ -394,24 +391,4 @@ public class KMeansTrainer extends SingleLabelDatasetTrainer<KMeansModel> {
this.distance = distance;
return this;
}
-
- /**
- * Gets the seed number.
- *
- * @return The parameter value.
- */
- public long getSeed() {
- return seed;
- }
-
- /**
- * Set up the seed.
- *
- * @param seed The parameter value.
- * @return Model with new seed parameter value.
- */
- public KMeansTrainer withSeed(long seed) {
- this.seed = seed;
- return this;
- }
}
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
index 0cdfc52..2da09db 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/ann/ANNClassificationTrainer.java
@@ -60,9 +60,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
/** Distance measure. */
private DistanceMeasure distance = new EuclideanDistance();
- /** KMeans initializer. */
- private long seed;
-
/**
* Trains model based on the specified data.
*
@@ -140,7 +137,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
KMeansTrainer trainer = new KMeansTrainer()
.withAmountOfClusters(k)
.withMaxIterations(maxIterations)
- .withSeed(seed)
.withDistance(distance)
.withEpsilon(epsilon);
@@ -334,26 +330,6 @@ public class ANNClassificationTrainer extends SingleLabelDatasetTrainer<ANNClass
return this;
}
- /**
- * Gets the seed number.
- *
- * @return The parameter value.
- */
- public long getSeed() {
- return seed;
- }
-
- /**
- * Set up the seed.
- *
- * @param seed The parameter value.
- * @return Model with new seed parameter value.
- */
- public ANNClassificationTrainer withSeed(long seed) {
- this.seed = seed;
- return this;
- }
-
/** Service class used for statistics. */
public static class CentroidStat implements Serializable {
/** Serial version uid. */
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
index e621801..111f1bb 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/knn/regression/KNNRegressionTrainer.java
@@ -35,7 +35,7 @@ public class KNNRegressionTrainer extends SingleLabelDatasetTrainer<KNNRegressio
* @param lbExtractor Label extractor.
* @return Model.
*/
- public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
+ @Override public <K, V> KNNRegressionModel fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, Double> lbExtractor) {
return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
index cf511ec..43df304 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/nn/MLPTrainer.java
@@ -108,7 +108,7 @@ public class MLPTrainer<P extends Serializable> extends MultiLabelDatasetTrainer
}
/** {@inheritDoc} */
- public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
+ @Override public <K, V> MultilayerPerceptron fit(DatasetBuilder<K, V> datasetBuilder,
IgniteBiFunction<K, V, Vector> featureExtractor, IgniteBiFunction<K, V, double[]> lbExtractor) {
return updateModel(null, datasetBuilder, featureExtractor, lbExtractor);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
index e33ad08..fe4c74d 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/clustering/KMeansTrainerTest.java
@@ -109,10 +109,8 @@ public class KMeansTrainerTest extends TrainerTest {
.withDistance(new EuclideanDistance())
.withAmountOfClusters(10)
.withMaxIterations(1)
- .withEpsilon(PRECISION)
- .withSeed(2);
+ .withEpsilon(PRECISION);
assertEquals(10, trainer.getAmountOfClusters());
- assertEquals(2, trainer.getSeed());
assertTrue(trainer.getDistance() instanceof EuclideanDistance);
return trainer;
}
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
index 1d1103f..bc2e3d5 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/common/KeepBinaryTest.java
@@ -81,7 +81,7 @@ public class KeepBinaryTest extends GridCommonAbstractTest {
IgniteBiFunction<Integer, BinaryObject, Double> lbExtractor = (k, v) -> (double) v.field("label");
- KMeansTrainer trainer = new KMeansTrainer().withSeed(123L);
+ KMeansTrainer trainer = new KMeansTrainer();
CacheBasedDatasetBuilder<Integer, BinaryObject> datasetBuilder =
new CacheBasedDatasetBuilder<>(ignite, dataCache).withKeepBinary(true);
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
index 9c75824..2f779a2 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/knn/ANNClassificationTest.java
@@ -44,14 +44,12 @@ public class ANNClassificationTest extends TrainerTest {
.withK(10)
.withMaxIterations(10)
.withEpsilon(1e-4)
- .withDistance(new EuclideanDistance())
- .withSeed(1234L);
+ .withDistance(new EuclideanDistance());
Assert.assertEquals(10, trainer.getK());
Assert.assertEquals(10, trainer.getMaxIterations());
TestUtils.assertEquals(1e-4, trainer.getEpsilon(), PRECISION);
Assert.assertEquals(new EuclideanDistance(), trainer.getDistance());
- Assert.assertEquals(1234L, trainer.getSeed());
NNClassificationModel mdl = trainer.fit(
cacheMock,
@@ -83,7 +81,7 @@ public class ANNClassificationTest extends TrainerTest {
.withEpsilon(1e-4)
.withDistance(new EuclideanDistance());
- ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.withSeed(1234L).fit(
+ ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.fit(
cacheMock,
parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 1, v.length)),
@@ -92,7 +90,7 @@ public class ANNClassificationTest extends TrainerTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(NNStrategy.SIMPLE);
- ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
+ ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.update(originalMdl,
cacheMock, parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]
@@ -100,7 +98,7 @@ public class ANNClassificationTest extends TrainerTest {
.withDistanceMeasure(new EuclideanDistance())
.withStrategy(NNStrategy.SIMPLE);
- ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.withSeed(1234L).update(originalMdl,
+ ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.update(originalMdl,
new HashMap<Integer, double[]>(), parts,
(k, v) -> VectorUtils.of(Arrays.copyOfRange(v, 0, v.length - 1)),
(k, v) -> v[2]