You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@ignite.apache.org by sb...@apache.org on 2017/12/04 07:52:40 UTC
[2/9] ignite git commit: IGNITE-7007: Decision tree code cleanup
IGNITE-7007: Decision tree code cleanup
This closes #3084
Project: http://git-wip-us.apache.org/repos/asf/ignite/repo
Commit: http://git-wip-us.apache.org/repos/asf/ignite/commit/a29fe352
Tree: http://git-wip-us.apache.org/repos/asf/ignite/tree/a29fe352
Diff: http://git-wip-us.apache.org/repos/asf/ignite/diff/a29fe352
Branch: refs/heads/ignite-zk
Commit: a29fe352de4fa3f66f471a4315fff097fe06c786
Parents: 3979e6a
Author: artemmalykh <am...@gridgain.com>
Authored: Fri Dec 1 20:54:59 2017 +0300
Committer: Yury Babak <yb...@gridgain.com>
Committed: Fri Dec 1 20:54:59 2017 +0300
----------------------------------------------------------------------
.../ignite/ml/math/distributed/CacheUtils.java | 2 --
.../columnbased/ColumnDecisionTreeTrainer.java | 33 +++++++++++++-------
.../org/apache/ignite/ml/util/MnistUtils.java | 17 +++++-----
.../java/org/apache/ignite/ml/util/Utils.java | 6 ++--
.../ml/trees/ColumnDecisionTreeTrainerTest.java | 3 +-
.../ColumnDecisionTreeTrainerBenchmark.java | 31 +++++++++---------
6 files changed, 51 insertions(+), 41 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
index 6baa865..9ca167c 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/math/distributed/CacheUtils.java
@@ -484,9 +484,7 @@ public class CacheUtils {
m.put(k, v);
}
- long before = System.currentTimeMillis();
cache.putAll(m);
- System.out.println("PutAll took: " + (System.currentTimeMillis() - before));
});
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
index 32e33f3..fec0a83 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/trees/trainers/columnbased/ColumnDecisionTreeTrainer.java
@@ -26,6 +26,7 @@ import java.util.HashMap;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
@@ -37,12 +38,12 @@ import java.util.stream.Stream;
import javax.cache.Cache;
import org.apache.ignite.Ignite;
import org.apache.ignite.IgniteCache;
+import org.apache.ignite.IgniteLogger;
import org.apache.ignite.Ignition;
import org.apache.ignite.cache.CachePeekMode;
import org.apache.ignite.cache.affinity.Affinity;
import org.apache.ignite.cluster.ClusterNode;
import org.apache.ignite.internal.processors.cache.CacheEntryImpl;
-import org.apache.ignite.internal.util.typedef.X;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Trainer;
import org.apache.ignite.ml.math.Vector;
@@ -115,6 +116,9 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
/** Ignite instance. */
private final Ignite ignite;
+ /** Logger. */
+ private final IgniteLogger log;
+
/**
* Construct {@link ColumnDecisionTreeTrainer}.
*
@@ -135,6 +139,7 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
this.categoricalCalculatorProvider = categoricalCalculatorProvider;
this.regCalc = regCalc;
this.ignite = ignite;
+ this.log = ignite.log();
}
/**
@@ -329,7 +334,8 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
regsCnt++;
- X.println(">>> Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
+ if (log.isDebugEnabled())
+ log.debug("Globally best: " + best.info + " idx time: " + findBestRegIdx + ", calculate best: " + findBestSplit + " fi: " + best.featureIdx + ", regs: " + regsCnt);
// Request bitset for split region.
int ind = best.info.regionIndex();
@@ -361,8 +367,10 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
if (d > curDepth) {
curDepth = d;
- X.println(">>> Depth: " + curDepth);
- X.println(">>> Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
+ if (log.isDebugEnabled()) {
+ log.debug("Depth: " + curDepth);
+ log.debug("Cache size: " + prjsCache.size(CachePeekMode.PRIMARY));
+ }
}
before = System.currentTimeMillis();
@@ -415,16 +423,19 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
},
bestRegsKeys);
- X.println(">>> Update of projs cache took " + (System.currentTimeMillis() - before));
+ if (log.isDebugEnabled())
+ log.debug("Update of projections cache time: " + (System.currentTimeMillis() - before));
before = System.currentTimeMillis();
updateSplitCache(ind, rc, featuresCnt, ig -> i -> input.affinityKey(i, ig), uuid);
- X.println(">>> Update of split cache took " + (System.currentTimeMillis() - before));
+ if (log.isDebugEnabled())
+ log.debug("Update of split cache time: " + (System.currentTimeMillis() - before));
}
else {
- X.println(">>> Best feature index: " + bestFeatureIdx + ", best infoGain " + bestInfoGain);
+ if (log.isDebugEnabled())
+ log.debug("Best split [bestFeatureIdx=" + bestFeatureIdx + ", bestInfoGain=" + bestInfoGain + "]");
break;
}
}
@@ -541,15 +552,15 @@ public class ColumnDecisionTreeTrainer<D extends ContinuousRegionInfo> implement
double[] values = ctx.values(fIdx, ign);
double[] labels = ctx.labels();
- IgniteBiTuple<Integer, Double> max = toCompare.entrySet().stream().
+ Optional<IgniteBiTuple<Integer, Double>> max = toCompare.entrySet().stream().
map(ent -> {
SplitInfo bestSplit = ctx.featureProcessor(fIdx).findBestSplit(ent.getValue(), values, labels, ent.getKey());
return new IgniteBiTuple<>(ent.getKey(), bestSplit != null ? bestSplit.infoGain() : Double.NEGATIVE_INFINITY);
}).
- max(Comparator.comparingDouble(IgniteBiTuple::get2)).
- get();
+ max(Comparator.comparingDouble(IgniteBiTuple::get2));
- return Stream.of(new CacheEntryImpl<>(e.getKey(), max));
+ return max.<Stream<Cache.Entry<SplitKey, IgniteBiTuple<Integer, Double>>>>
+ map(objects -> Stream.of(new CacheEntryImpl<>(e.getKey(), objects))).orElseGet(Stream::empty);
},
() -> IntStream.range(0, featuresCnt).mapToObj(fIdx -> SplitCache.key(fIdx, affinity.apply(ignite).apply(fIdx), trainingUUID)).collect(Collectors.toSet())
);
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
index d69781e..a3f1d21 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/MnistUtils.java
@@ -25,6 +25,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Random;
import java.util.stream.Stream;
+import org.apache.ignite.IgniteException;
import org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector;
/**
@@ -40,14 +41,14 @@ public class MnistUtils {
* @param rnd Random numbers generator.
* @param cnt Count of samples to read.
* @return Stream of MNIST samples.
- * @throws IOException
+ * @throws IOException In case of exception.
*/
public static Stream<DenseLocalOnHeapVector> mnist(String imagesPath, String labelsPath, Random rnd, int cnt)
throws IOException {
FileInputStream isImages = new FileInputStream(imagesPath);
FileInputStream isLabels = new FileInputStream(labelsPath);
- int magic = read4Bytes(isImages); // Skip magic number.
+ read4Bytes(isImages); // Skip magic number.
int numOfImages = read4Bytes(isImages);
int imgHeight = read4Bytes(isImages);
int imgWidth = read4Bytes(isImages);
@@ -57,10 +58,6 @@ public class MnistUtils {
int numOfPixels = imgHeight * imgWidth;
- System.out.println("Magic: " + magic);
- System.out.println("Num of images: " + numOfImages);
- System.out.println("Num of pixels: " + numOfPixels);
-
double[][] vecs = new double[numOfImages][numOfPixels + 1];
for (int imgNum = 0; imgNum < numOfImages; imgNum++) {
@@ -88,7 +85,7 @@ public class MnistUtils {
* @param outPath Path to the output file.
* @param rnd Random numbers generator.
* @param cnt Count of samples to read.
- * @throws IOException
+ * @throws IOException In case of exception.
*/
public static void asLIBSVM(String imagesPath, String labelsPath, String outPath, Random rnd, int cnt)
throws IOException {
@@ -109,7 +106,7 @@ public class MnistUtils {
}
catch (IOException e) {
- e.printStackTrace();
+ throw new IgniteException("Error while converting to LIBSVM.");
}
});
}
@@ -119,9 +116,9 @@ public class MnistUtils {
* Utility method for reading 4 bytes from input stream.
*
* @param is Input stream.
- * @throws IOException
+ * @throws IOException In case of exception.
*/
private static int read4Bytes(FileInputStream is) throws IOException {
return (is.read() << 24) | (is.read() << 16) | (is.read() << 8) | (is.read());
}
-}
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
index bb779e3..847b1f1 100644
--- a/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
+++ b/modules/ml/src/main/java/org/apache/ignite/ml/util/Utils.java
@@ -22,6 +22,7 @@ import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
+import org.apache.ignite.IgniteException;
/**
* Class with various utility methods.
@@ -34,8 +35,9 @@ public class Utils {
* @param <T> Class of original object;
* @return Deep copy of original object.
*/
+ @SuppressWarnings({"unchecked"})
public static <T> T copy(T orig) {
- Object obj = null;
+ Object obj;
try {
ByteArrayOutputStream baos = new ByteArrayOutputStream();
@@ -50,7 +52,7 @@ public class Utils {
obj = in.readObject();
}
catch (IOException | ClassNotFoundException e) {
- e.printStackTrace();
+ throw new IgniteException("Couldn't copy the object.");
}
return (T)obj;
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
index 2b03b47..929ded9 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/ColumnDecisionTreeTrainerTest.java
@@ -26,6 +26,7 @@ import java.util.Random;
import java.util.stream.Collectors;
import java.util.stream.DoubleStream;
import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.internal.util.typedef.X;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.math.StorageConstants;
import org.apache.ignite.ml.math.Tracer;
@@ -183,7 +184,7 @@ public class ColumnDecisionTreeTrainerTest extends BaseDecisionTreeTest {
byRegion.keySet().forEach(k -> {
LabeledVectorDouble sp = byRegion.get(k).get(0);
Tracer.showAscii(sp.vector());
- System.out.println("Act: " + sp.label() + " " + " pred: " + mdl.predict(sp.vector()));
+ X.println("Actual and predicted vectors [act=" + sp.label() + " " + ", pred=" + mdl.predict(sp.vector()) + "]");
assert mdl.predict(sp.vector()) == sp.doubleLabel();
});
}
http://git-wip-us.apache.org/repos/asf/ignite/blob/a29fe352/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
----------------------------------------------------------------------
diff --git a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
index 4e7cc24..7ca5d38 100644
--- a/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
+++ b/modules/ml/src/test/java/org/apache/ignite/ml/trees/performance/ColumnDecisionTreeTrainerBenchmark.java
@@ -45,6 +45,7 @@ import org.apache.ignite.configuration.CacheConfiguration;
import org.apache.ignite.configuration.IgniteConfiguration;
import org.apache.ignite.internal.processors.cache.GridCacheProcessor;
import org.apache.ignite.internal.util.IgniteUtils;
+import org.apache.ignite.internal.util.typedef.X;
import org.apache.ignite.lang.IgniteBiTuple;
import org.apache.ignite.ml.Model;
import org.apache.ignite.ml.estimators.Estimators;
@@ -163,14 +164,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
- System.out.println(">>> Training started");
+ X.println("Training started.");
long before = System.currentTimeMillis();
DecisionTreeModel mdl = trainer.train(new BiIndexedCacheColumnDecisionTreeTrainerInput(cache, new HashMap<>(), ptsCnt, featCnt));
- System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+ X.println("Training finished in " + (System.currentTimeMillis() - before));
IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
- System.out.println(">>> Errs percentage: " + accuracy);
+ X.println("Errors percentage: " + accuracy);
Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
@@ -204,14 +205,14 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> trainer =
new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.GINI.apply(ignite), RegionCalculators.GINI, RegionCalculators.MOST_COMMON, ignite);
- System.out.println(">>> Training started");
+ X.println("Training started");
long before = System.currentTimeMillis();
DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>()));
- System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+ X.println("Training finished in " + (System.currentTimeMillis() - before));
IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.errorsPercentage();
Double accuracy = mse.apply(mdl, testMnistStream.map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
- System.out.println(">>> Errs percentage: " + accuracy);
+ X.println("Errors percentage: " + accuracy);
Assert.assertEquals(0, SplitCache.getOrCreate(ignite).size());
Assert.assertEquals(0, FeaturesCache.getOrCreate(ignite).size());
@@ -252,10 +253,10 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
SparseDistributedMatrixStorage sto = (SparseDistributedMatrixStorage)m.getStorage();
long before = System.currentTimeMillis();
- System.out.println(">>> Batch loading started...");
+ X.println("Batch loading started...");
loadVectorsIntoSparseDistributedMatrixCache(sto.cache().getName(), sto.getUUID(), gen.
points(ptsPerReg, (i, rn) -> i).map(IgniteBiTuple::get2).iterator(), featCnt + 1);
- System.out.println(">>> Batch loading took " + (System.currentTimeMillis() - before) + " ms.");
+ X.println("Batch loading took " + (System.currentTimeMillis() - before) + " ms.");
for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
byRegion.putIfAbsent(bt.get1(), new LinkedList<>());
@@ -268,12 +269,12 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
before = System.currentTimeMillis();
DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));
- System.out.println(">>> Took time(ms): " + (System.currentTimeMillis() - before));
+ X.println("Training took: " + (System.currentTimeMillis() - before) + " ms.");
byRegion.keySet().forEach(k -> {
LabeledVectorDouble sp = byRegion.get(k).get(0);
Tracer.showAscii(sp.vector());
- System.out.println("Prediction: " + mdl.predict(sp.vector()) + "label: " + sp.doubleLabel());
+ X.println("Predicted value and label [pred=" + mdl.predict(sp.vector()) + ", label=" + sp.doubleLabel() + "]");
assert mdl.predict(sp.vector()) == sp.doubleLabel();
});
}
@@ -307,16 +308,16 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
ColumnDecisionTreeTrainer<VarianceSplitCalculator.VarianceData> trainer =
new ColumnDecisionTreeTrainer<>(10, ContinuousSplitCalculators.VARIANCE, RegionCalculators.VARIANCE, regCalc, ignite);
- System.out.println(">>> Training started");
+ X.println("Training started.");
long before = System.currentTimeMillis();
DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, new HashMap<>()));
- System.out.println(">>> Training finished in " + (System.currentTimeMillis() - before));
+ X.println("Training finished in: " + (System.currentTimeMillis() - before) + " ms.");
Vector[] testVectors = vecsFromRanges(ranges, featCnt, defRng, new Random(123L), 20, f1);
IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> mse = Estimators.MSE();
Double accuracy = mse.apply(mdl, Arrays.stream(testVectors).map(v -> new IgniteBiTuple<>(v.viewPart(0, featCnt), v.getX(featCnt))), Function.identity());
- System.out.println(">>> MSE: " + accuracy);
+ X.println("MSE: " + accuracy);
}
/**
@@ -358,7 +359,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
for (int i = 0; i < vectorSize; i++)
batch.get(i).put(sampleIdx, next.getX(i));
- System.out.println(sampleIdx);
+ X.println("Sample index: " + sampleIdx);
if (sampleIdx % batchSize == 0) {
batch.keySet().forEach(fi -> streamer.addData(new SparseMatrixKey(fi, uuid, fi), batch.get(fi)));
IntStream.range(0, vectorSize).forEach(i -> batch.put(i, new HashMap<>()));
@@ -396,7 +397,7 @@ public class ColumnDecisionTreeTrainerBenchmark extends BaseDecisionTreeTest {
sampleIdx++;
if (sampleIdx % 1000 == 0)
- System.out.println(">>> Loaded " + sampleIdx + " vectors.");
+ System.out.println("Loaded: " + sampleIdx + " vectors.");
}
}
}