Posted to commits@mahout.apache.org by sm...@apache.org on 2013/11/24 08:02:17 UTC
svn commit: r1544930 - in /mahout/trunk/math/src:
main/java/org/apache/mahout/math/stats/
test/java/org/apache/mahout/math/stats/
Author: smarthi
Date: Sun Nov 24 07:02:17 2013
New Revision: 1544930
URL: http://svn.apache.org/r1544930
Log:
MAHOUT-1361: Updated with latest changes from Ted's git repo.
Added:
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/TDigest.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/TDigestTest.java
Removed:
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/Histo.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/HistoTest.java
Modified:
mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/GroupTree.java
mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/GroupTreeTest.java
Modified: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/GroupTree.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/GroupTree.java?rev=1544930&r1=1544929&r2=1544930&view=diff
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/GroupTree.java (original)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/GroupTree.java Sun Nov 24 07:02:17 2013
@@ -1,20 +1,3 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
package org.apache.mahout.math.stats;
import com.google.common.base.Preconditions;
@@ -25,14 +8,14 @@ import java.util.Deque;
import java.util.Iterator;
/**
- * A tree containing Histo.Group. This adds to the normal NavigableSet the
+ * A tree containing TDigest.Group. This adds to the normal NavigableSet the
* ability to sum up the size of elements to the left of a particular group.
*/
-public class GroupTree implements Iterable<Histo.Group> {
+public class GroupTree implements Iterable<TDigest.Group> {
private int count;
private int size;
private int depth;
- private Histo.Group leaf;
+ private TDigest.Group leaf;
private GroupTree left, right;
public GroupTree() {
@@ -41,7 +24,7 @@ public class GroupTree implements Iterab
left = right = null;
}
- public GroupTree(Histo.Group leaf) {
+ public GroupTree(TDigest.Group leaf) {
size = depth = 1;
this.leaf = leaf;
count = leaf.count();
@@ -57,7 +40,7 @@ public class GroupTree implements Iterab
leaf = this.right.first();
}
- public void add(Histo.Group group) {
+ public void add(TDigest.Group group) {
if (size == 0) {
leaf = group;
depth = 1;
@@ -126,7 +109,7 @@ public class GroupTree implements Iterab
/**
* @return the number of items strictly before base
*/
- public int headCount(Histo.Group base) {
+ public int headCount(TDigest.Group base) {
if (size == 0) {
return 0;
} else if (left == null) {
@@ -143,7 +126,7 @@ public class GroupTree implements Iterab
/**
* @return the sum of the size() function for all elements strictly before base.
*/
- public int headSum(Histo.Group base) {
+ public int headSum(TDigest.Group base) {
if (size == 0) {
return 0;
} else if (left == null) {
@@ -160,7 +143,7 @@ public class GroupTree implements Iterab
/**
* @return the first Group in this set
*/
- public Histo.Group first() {
+ public TDigest.Group first() {
Preconditions.checkState(size > 0, "No first element of empty set");
if (left == null) {
return leaf;
@@ -172,7 +155,7 @@ public class GroupTree implements Iterab
/**
* Iterates through all groups in the tree.
*/
- public Iterator<Histo.Group> iterator() {
+ public Iterator<TDigest.Group> iterator() {
return iterator(null);
}
@@ -182,8 +165,8 @@ public class GroupTree implements Iterab
* @return An iterator that goes through the groups in order of mean and id starting at or after the
* specified Group.
*/
- private Iterator<Histo.Group> iterator(final Histo.Group start) {
- return new AbstractIterator<Histo.Group>() {
+ private Iterator<TDigest.Group> iterator(final TDigest.Group start) {
+ return new AbstractIterator<TDigest.Group>() {
{
stack = Queues.newArrayDeque();
push(GroupTree.this, start);
@@ -193,7 +176,7 @@ public class GroupTree implements Iterab
// recurses down to the leaf that is >= start
// pending right hand branches on the way are put on the stack
- private void push(GroupTree z, Histo.Group start) {
+ private void push(GroupTree z, TDigest.Group start) {
while (z.left != null) {
if (start == null || start.compareTo(z.leaf) < 0) {
// remember we will have to process the right hand branch later
@@ -212,7 +195,7 @@ public class GroupTree implements Iterab
}
@Override
- protected Histo.Group computeNext() {
+ protected TDigest.Group computeNext() {
GroupTree r = stack.poll();
while (r != null && r.left != null) {
// unpack r onto the stack
@@ -231,7 +214,7 @@ public class GroupTree implements Iterab
};
}
- public void remove(Histo.Group base) {
+ public void remove(TDigest.Group base) {
Preconditions.checkState(size > 0, "Cannot remove from empty set");
if (size == 1) {
Preconditions.checkArgument(base.compareTo(leaf) == 0, "Element %s not found", base);
@@ -274,7 +257,7 @@ public class GroupTree implements Iterab
/**
* @return the largest element less than or equal to base
*/
- public Histo.Group floor(Histo.Group base) {
+ public TDigest.Group floor(TDigest.Group base) {
if (size == 0) {
return null;
} else {
@@ -284,7 +267,7 @@ public class GroupTree implements Iterab
if (base.compareTo(leaf) < 0) {
return left.floor(base);
} else {
- Histo.Group floor = right.floor(base);
+ TDigest.Group floor = right.floor(base);
if (floor == null) {
floor = left.last();
}
@@ -294,7 +277,7 @@ public class GroupTree implements Iterab
}
}
- public Histo.Group last() {
+ public TDigest.Group last() {
Preconditions.checkState(size > 0, "Cannot find last element of empty set");
if (size == 1) {
return leaf;
@@ -306,14 +289,14 @@ public class GroupTree implements Iterab
/**
* @return the smallest element greater than or equal to base.
*/
- public Histo.Group ceiling(Histo.Group base) {
+ public TDigest.Group ceiling(TDigest.Group base) {
if (size == 0) {
return null;
} else if (size == 1) {
return base.compareTo(leaf) <= 0 ? leaf : null;
} else {
if (base.compareTo(leaf) < 0) {
- Histo.Group r = left.ceiling(base);
+ TDigest.Group r = left.ceiling(base);
if (r == null) {
r = right.first();
}
@@ -327,10 +310,10 @@ public class GroupTree implements Iterab
/**
* @return the subset of elements equal to or greater than base.
*/
- public Iterable<Histo.Group> tailSet(final Histo.Group start) {
- return new Iterable<Histo.Group>() {
+ public Iterable<TDigest.Group> tailSet(final TDigest.Group start) {
+ return new Iterable<TDigest.Group>() {
@Override
- public Iterator<Histo.Group> iterator() {
+ public Iterator<TDigest.Group> iterator() {
return GroupTree.this.iterator(start);
}
};
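
A minimal sketch of how the renamed GroupTree behaves, using only the API shown in the diff above. This block is illustration only, not part of the commit, and the class name GroupTreeSketch is invented:

    import org.apache.mahout.math.stats.GroupTree;
    import org.apache.mahout.math.stats.TDigest;

    public class GroupTreeSketch {
      public static void main(String[] args) {
        GroupTree tree = new GroupTree();
        tree.add(new TDigest.Group(1.0));         // mean 1.0, count 1
        TDigest.Group g = new TDigest.Group(2.0);
        g.add(3.0, 1);                            // mean becomes 2.5, count 2
        tree.add(g);

        // headCount counts the groups strictly before the probe group;
        // headSum sums their weights (counts) instead
        System.out.println(tree.headCount(new TDigest.Group(5.0)));  // 2
        System.out.println(tree.headSum(new TDigest.Group(5.0)));    // 3
      }
    }

headSum is what TDigest.add (below) uses to turn a centroid's position into an approximate quantile.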
Added: mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/TDigest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/TDigest.java?rev=1544930&view=auto
==============================================================================
--- mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/TDigest.java (added)
+++ mahout/trunk/math/src/main/java/org/apache/mahout/math/stats/TDigest.java Sun Nov 24 07:02:17 2013
@@ -0,0 +1,564 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.base.Preconditions;
+import com.google.common.collect.Lists;
+import org.apache.mahout.common.RandomUtils;
+
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+import java.util.concurrent.atomic.AtomicInteger;
+
+/**
+ * Adaptive histogram based on something like streaming k-means crossed with Q-digest.
+ * <p/>
+ * The special characteristics of this algorithm are:
+ * <p/>
+ * a) smaller summaries than Q-digest
+ * <p/>
+ * b) works on doubles as well as integers.
+ * <p/>
+ * c) provides parts-per-million accuracy for extreme quantiles and typically <1000 ppm accuracy for middle quantiles
+ * <p/>
+ * d) fast
+ * <p/>
+ * e) simple
+ * <p/>
+ * f) test coverage > 90%
+ * <p/>
+ * g) easy to adapt for use with map-reduce
+ */
+public class TDigest {
+ private Random gen = RandomUtils.getRandom();
+
+ private double compression = 100;
+ private GroupTree summary = new GroupTree();
+ private int count = 0;
+ private boolean recordAllData = false;
+
+ /**
+ * A histogram structure that will record a sketch of a distribution.
+ *
+ * @param compression How should accuracy be traded for size? A value of N here will give quantile errors
+ * almost always less than 3/N with considerably smaller errors expected for extreme
+ * quantiles. Conversely, you should expect to track about 5 N centroids for this
+ * accuracy.
+ */
+ public TDigest(double compression) {
+ this.compression = compression;
+ }
+
+ /**
+ * Adds a sample to a histogram.
+ *
+ * @param x The value to add.
+ */
+ public void add(double x) {
+ add(x, 1);
+ }
+
+ /**
+ * Adds a sample to a histogram.
+ *
+ * @param x The value to add.
+ * @param w The weight of this point.
+ */
+ public void add(double x, int w) {
+ // note that because of a zero id, this will be sorted *before* any existing Group with the same mean
+ Group base = createGroup(x, 0);
+ add(x, w, base);
+ }
+
+ private void add(double x, int w, Group base) {
+ Group start = summary.floor(base);
+ if (start == null) {
+ start = summary.ceiling(base);
+ }
+
+ if (start == null) {
+ summary.add(Group.createWeighted(x, w, base.data()));
+ count = w;
+ } else {
+ Iterable<Group> neighbors = summary.tailSet(start);
+ double minDistance = Double.MAX_VALUE;
+ int lastNeighbor = 0;
+ int i = summary.headCount(start);
+ for (Group neighbor : neighbors) {
+ double z = Math.abs(neighbor.mean() - x);
+ if (z <= minDistance) {
+ minDistance = z;
+ lastNeighbor = i;
+ } else {
+ break;
+ }
+ i++;
+ }
+
+ Group closest = null;
+ int sum = summary.headSum(start);
+ i = summary.headCount(start);
+ double n = 1;
+ for (Group neighbor : neighbors) {
+ if (i > lastNeighbor) {
+ break;
+ }
+ double z = Math.abs(neighbor.mean() - x);
+ double q = (sum + neighbor.count() / 2.0) / count;
+ double k = 4 * count * q * (1 - q) / compression;
+
+ // this slightly clever selection method improves accuracy with lots of repeated points
+ if (z == minDistance && neighbor.count() + w <= k) {
+ if (gen.nextDouble() < 1 / n) {
+ closest = neighbor;
+ }
+ n++;
+ }
+ sum += neighbor.count();
+ i++;
+ }
+
+ if (closest == null) {
+ summary.add(Group.createWeighted(x, w, base.data()));
+ } else {
+ summary.remove(closest);
+ closest.add(x, w, base.data());
+ summary.add(closest);
+ }
+ count += w;
+
+ if (summary.size() > 100 * compression) {
+ // something such as sequential ordering of data points
+ // has caused a pathological expansion of our summary.
+ // To fight this, we simply replay the current centroids
+ // in random order.
+
+ // this causes us to forget the diagnostic recording of data points
+ compress();
+ }
+ }
+ }
+
+ public void add(TDigest other) {
+ List<Group> tmp = Lists.newArrayList(other.summary);
+ Collections.shuffle(tmp);
+ for (Group group : tmp) {
+ add(group.mean(), group.count(), group);
+ }
+ }
+
+ public static TDigest merge(double compression, Iterable<TDigest> subData) {
+ List<TDigest> elements = Lists.newArrayList(subData);
+ int n = Math.max(1, elements.size() / 4);
+ TDigest r = new TDigest(compression);
+ if (elements.size() > 0 && elements.get(0).recordAllData) {
+ r.recordAllData();
+ }
+ for (int i = 0; i < elements.size(); i += n) {
+ if (n > 1) {
+ r.add(merge(compression, elements.subList(i, Math.min(i + n, elements.size()))));
+ } else {
+ r.add(elements.get(i));
+ }
+ }
+ return r;
+ }
+
+ public void compress() {
+ compress(summary);
+ }
+
+ private void compress(GroupTree other) {
+ TDigest reduced = new TDigest(compression);
+ if (recordAllData) {
+ reduced.recordAllData();
+ }
+ List<Group> tmp = Lists.newArrayList(other);
+ Collections.shuffle(tmp);
+ for (Group group : tmp) {
+ reduced.add(group.mean(), group.count(), group);
+ }
+
+ summary = reduced.summary;
+ }
+
+ /**
+ * Returns the number of samples represented in this histogram. If you want to know how many
+ * centroids are being used, try centroidCount().
+ *
+ * @return the number of samples that have been added.
+ */
+ public int size() {
+ return count;
+ }
+
+ /**
+ * @param x the value at which the CDF should be evaluated
+ * @return the approximate fraction of all samples that were less than or equal to x.
+ */
+ public double cdf(double x) {
+ GroupTree values = summary;
+ if (values.size() == 0) {
+ return Double.NaN;
+ } else if (values.size() == 1) {
+ return x < values.first().mean() ? 0 : 1;
+ } else {
+ double r = 0;
+
+ // we scan across the centroids, with a as the current centroid
+ Iterator<Group> it = values.iterator();
+ Group a = it.next();
+
+ // b is the look-ahead to the next centroid
+ Group b = it.next();
+
+ // initially, we set left width equal to right width
+ double left = (b.mean() - a.mean()) / 2;
+ double right = left;
+
+ // scan to next to last element
+ while (it.hasNext()) {
+ if (x < a.mean() + right) {
+ return (r + a.count() * interpolate(x, a.mean() - left, a.mean() + right)) / count;
+ }
+ r += a.count();
+
+ a = b;
+ b = it.next();
+
+ left = right;
+ right = (b.mean() - a.mean()) / 2;
+ }
+
+ // for the last element, assume right width is same as left
+ left = right;
+ a = b;
+ if (x < a.mean() + right) {
+ return (r + a.count() * interpolate(x, a.mean() - left, a.mean() + right)) / count;
+ } else {
+ return 1;
+ }
+ }
+ }
+
+ /**
+ * @param q The quantile desired. Can be in the range [0,1].
+ * @return The smallest value x such that we think the fraction of samples <= x is q.
+ */
+ public double quantile(double q) {
+ GroupTree values = summary;
+ Preconditions.checkArgument(values.size() > 1);
+
+ Iterator<Group> it = values.iterator();
+ Group a = it.next();
+ Group b = it.next();
+ if (!it.hasNext()) {
+ // both a and b have to have just a single element
+ double diff = (b.mean() - a.mean()) / 2;
+ if (q > 0.75) {
+ return b.mean() + diff * (4 * q - 3);
+ } else {
+ return a.mean() + diff * (4 * q - 1);
+ }
+ } else {
+ q *= count;
+ double right = (b.mean() - a.mean()) / 2;
+ // we have nothing else to go on, so to start, make the left width the same as the right
+ double left = right;
+
+ if (q <= a.count()) {
+ return a.mean() + left * (2 * q - a.count()) / a.count();
+ } else {
+ double t = a.count();
+ while (it.hasNext()) {
+ if (t + b.count() / 2 >= q) {
+ // left of b
+ return b.mean() - left * 2 * (q - t) / b.count();
+ } else if (t + b.count() >= q) {
+ // right of b but left of the left-most thing beyond
+ return b.mean() + right * 2 * (q - t - b.count() / 2.0) / b.count();
+ }
+ t += b.count();
+
+ a = b;
+ b = it.next();
+ left = right;
+ right = (b.mean() - a.mean()) / 2;
+ }
+ // shouldn't be possible but we have an answer anyway
+ return b.mean() + right;
+ }
+ }
+ }
+
+ public int centroidCount() {
+ return summary.size();
+ }
+
+ public Iterable<? extends Group> centroids() {
+ return summary;
+ }
+
+ public double compression() {
+ return compression;
+ }
+
+ /**
+ * Sets up so that all centroids will record all data assigned to them. For testing only, really.
+ */
+ public TDigest recordAllData() {
+ recordAllData = true;
+ return this;
+ }
+
+ /**
+ * Returns an upper bound on the number of bytes that will be required to represent this histogram.
+ */
+ public int byteSize() {
+ return 4 + 8 + 4 + summary.size() * 12;
+ }
+
+ /**
+ * Returns an upper bound on the number of bytes that will be required to represent this histogram in
+ * the tighter representation.
+ */
+ public int smallByteSize() {
+ int bound = byteSize();
+ ByteBuffer buf = ByteBuffer.allocate(bound);
+ asSmallBytes(buf);
+ return buf.position();
+ }
+
+ public final static int VERBOSE_ENCODING = 1;
+ public final static int SMALL_ENCODING = 2;
+
+ /**
+ * Outputs a histogram as bytes using a particularly cheesy encoding.
+ */
+ public void asBytes(ByteBuffer buf) {
+ buf.putInt(VERBOSE_ENCODING);
+ buf.putDouble(compression());
+ buf.putInt(summary.size());
+ for (Group group : summary) {
+ buf.putDouble(group.mean());
+ }
+
+ for (Group group : summary) {
+ buf.putInt(group.count());
+ }
+ }
+
+ public void asSmallBytes(ByteBuffer buf) {
+ buf.putInt(SMALL_ENCODING);
+ buf.putDouble(compression());
+ buf.putInt(summary.size());
+
+ double x = 0;
+ for (Group group : summary) {
+ double delta = group.mean() - x;
+ x = group.mean();
+ buf.putFloat((float) delta);
+ }
+
+ for (Group group : summary) {
+ int n = group.count();
+ encode(buf, n);
+ }
+ }
+
+ public static void encode(ByteBuffer buf, int n) {
+ int k = 0;
+ while (n < 0 || n > 0x7f) {
+ byte b = (byte) (0x80 | (0x7f & n));
+ buf.put(b);
+ n = n >>> 7;
+ k++;
+ Preconditions.checkState(k < 6);
+ }
+ buf.put((byte) n);
+ }
+
+ public static int decode(ByteBuffer buf) {
+ int v = buf.get();
+ int z = 0x7f & v;
+ int shift = 7;
+ while ((v & 0x80) != 0) {
+ Preconditions.checkState(shift <= 28);
+ v = buf.get();
+ z += (v & 0x7f) << shift;
+ shift += 7;
+ }
+ return z;
+ }
+
+ /**
+ * Reads a histogram from a byte buffer.
+ *
+ * @return The new histogram structure
+ */
+ public static TDigest fromBytes(ByteBuffer buf) {
+ int encoding = buf.getInt();
+ if (encoding == VERBOSE_ENCODING) {
+ double compression = buf.getDouble();
+ TDigest r = new TDigest(compression);
+ int n = buf.getInt();
+ double[] means = new double[n];
+ for (int i = 0; i < n; i++) {
+ means[i] = buf.getDouble();
+ }
+ for (int i = 0; i < n; i++) {
+ r.add(means[i], buf.getInt());
+ }
+ return r;
+ } else if (encoding == SMALL_ENCODING) {
+ double compression = buf.getDouble();
+ TDigest r = new TDigest(compression);
+ int n = buf.getInt();
+ double[] means = new double[n];
+ double x = 0;
+ for (int i = 0; i < n; i++) {
+ double delta = buf.getFloat();
+ x += delta;
+ means[i] = x;
+ }
+
+ for (int i = 0; i < n; i++) {
+ int z = decode(buf);
+ r.add(means[i], z);
+ }
+ return r;
+ } else {
+ throw new IllegalStateException("Invalid format for serialized histogram");
+ }
+ }
+
+ private Group createGroup(double mean, int id) {
+ return new Group(mean, id, recordAllData);
+ }
+
+ private double interpolate(double x, double x0, double x1) {
+ return (x - x0) / (x1 - x0);
+ }
+
+ public static class Group implements Comparable<Group> {
+ private static final AtomicInteger uniqueCount = new AtomicInteger(1);
+
+ private double centroid = 0;
+ private int count = 0;
+ private int id;
+
+ private List<Double> actualData = null;
+
+ private Group(boolean record) {
+ id = uniqueCount.incrementAndGet();
+ if (record) {
+ actualData = Lists.newArrayList();
+ }
+ }
+
+ public Group(double x) {
+ this(false);
+ start(x, uniqueCount.getAndIncrement());
+ }
+
+ public Group(double x, int id) {
+ this(false);
+ start(x, id);
+ }
+
+ public Group(double x, int id, boolean record) {
+ this(record);
+ start(x, id);
+ }
+
+ private void start(double x, int id) {
+ this.id = id;
+ add(x, 1);
+ }
+
+ public void add(double x, int w) {
+ if (actualData != null) {
+ actualData.add(x);
+ }
+ count += w;
+ centroid += w * (x - centroid) / count;
+ }
+
+ public double mean() {
+ return centroid;
+ }
+
+ public int count() {
+ return count;
+ }
+
+ public int id() {
+ return id;
+ }
+
+ @Override
+ public String toString() {
+ return "Group{" +
+ "centroid=" + centroid +
+ ", count=" + count +
+ '}';
+ }
+
+ @Override
+ public int hashCode() {
+ return id;
+ }
+
+ @Override
+ public int compareTo(Group o) {
+ int r = Double.compare(centroid, o.centroid);
+ if (r == 0) {
+ r = id - o.id;
+ }
+ return r;
+ }
+
+ public Iterable<? extends Double> data() {
+ return actualData;
+ }
+
+ public static Group createWeighted(double x, int w, Iterable<? extends Double> data) {
+ Group r = new Group(data != null);
+ r.add(x, w, data);
+ return r;
+ }
+
+ private void add(double x, int w, Iterable<? extends Double> data) {
+ if (actualData != null) {
+ if (data != null) {
+ for (Double old : data) {
+ actualData.add(old);
+ }
+ } else {
+ actualData.add(x);
+ }
+ }
+ count += w;
+ centroid += w * (x - centroid) / count;
+ }
+ }
+
+}
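
Before the test changes, a hedged usage sketch of the TDigest API added above (illustration only, not part of the commit; the class name TDigestExample is invented). Per the constructor javadoc, compression = 100 should keep quantile errors almost always below 3/100 = 0.03 while tracking roughly 5 * 100 = 500 centroids:

    import java.nio.ByteBuffer;
    import java.util.Random;
    import org.apache.mahout.math.stats.TDigest;

    public class TDigestExample {
      public static void main(String[] args) {
        TDigest digest = new TDigest(100);         // compression = 100
        Random rand = new Random(42);
        for (int i = 0; i < 100000; i++) {
          digest.add(rand.nextDouble());           // uniform [0,1) samples
        }
        digest.compress();                         // optionally shrink the summary

        System.out.println(digest.quantile(0.5));  // close to 0.5 for uniform data
        System.out.println(digest.cdf(0.9));       // close to 0.9

        // round trip through the compact encoding; byteSize() bounds both formats
        ByteBuffer buf = ByteBuffer.allocate(digest.byteSize());
        digest.asSmallBytes(buf);
        buf.flip();
        TDigest copy = TDigest.fromBytes(buf);
        System.out.println(copy.quantile(0.5));    // agrees closely with the original
      }
    }

On the compact encoding itself: encode() writes each centroid count as a seven-bits-per-byte varint, so, for example, n = 200 is stored as the two bytes 0xC8 0x01 (the low seven bits with a continuation flag set, then the remaining high bit).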
Modified: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/GroupTreeTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/GroupTreeTest.java?rev=1544930&r1=1544929&r2=1544930&view=diff
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/GroupTreeTest.java (original)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/GroupTreeTest.java Sun Nov 24 07:02:17 2013
@@ -1,20 +1,3 @@
-/**
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
package org.apache.mahout.math.stats;
import com.google.common.collect.Lists;
@@ -32,13 +15,13 @@ public class GroupTreeTest {
@Test
public void testSimpleAdds() {
GroupTree x = new GroupTree();
- assertNull(x.floor(new Histo.Group(34)));
- assertNull(x.ceiling(new Histo.Group(34)));
+ assertNull(x.floor(new TDigest.Group(34)));
+ assertNull(x.ceiling(new TDigest.Group(34)));
assertEquals(0, x.size());
assertEquals(0, x.sum());
- x.add(new Histo.Group(1));
- Histo.Group group = new Histo.Group(2);
+ x.add(new TDigest.Group(1));
+ TDigest.Group group = new TDigest.Group(2);
group.add(3, 1);
group.add(4, 1);
x.add(group);
@@ -51,7 +34,7 @@ public class GroupTreeTest {
public void testBalancing() {
GroupTree x = new GroupTree();
for (int i = 0; i < 101; i++) {
- x.add(new Histo.Group(i));
+ x.add(new TDigest.Group(i));
}
assertEquals(101, x.sum());
@@ -64,59 +47,59 @@ public class GroupTreeTest {
public void testIterators() {
GroupTree x = new GroupTree();
for (int i = 0; i < 101; i++) {
- x.add(new Histo.Group(i / 2));
+ x.add(new TDigest.Group(i / 2));
}
assertEquals(0, x.first().mean(), 0);
assertEquals(50, x.last().mean(), 0);
- Iterator<Histo.Group> ix = x.iterator();
+ Iterator<TDigest.Group> ix = x.iterator();
for (int i = 0; i < 101; i++) {
assertTrue(ix.hasNext());
- Histo.Group z = ix.next();
+ TDigest.Group z = ix.next();
assertEquals(i / 2, z.mean(), 0);
}
assertFalse(ix.hasNext());
// 34 is special since it is the smallest element of the right hand sub-tree
- Iterable<Histo.Group> z = x.tailSet(new Histo.Group(34, 0));
+ Iterable<TDigest.Group> z = x.tailSet(new TDigest.Group(34, 0));
ix = z.iterator();
for (int i = 68; i < 101; i++) {
assertTrue(ix.hasNext());
- Histo.Group v = ix.next();
+ TDigest.Group v = ix.next();
assertEquals(i / 2, v.mean(), 0);
}
assertFalse(ix.hasNext());
ix = z.iterator();
for (int i = 68; i < 101; i++) {
- Histo.Group v = ix.next();
+ TDigest.Group v = ix.next();
assertEquals(i / 2, v.mean(), 0);
}
- z = x.tailSet(new Histo.Group(33, 0));
+ z = x.tailSet(new TDigest.Group(33, 0));
ix = z.iterator();
for (int i = 66; i < 101; i++) {
assertTrue(ix.hasNext());
- Histo.Group v = ix.next();
+ TDigest.Group v = ix.next();
assertEquals(i / 2, v.mean(), 0);
}
assertFalse(ix.hasNext());
- z = x.tailSet(x.ceiling(new Histo.Group(34, 0)));
+ z = x.tailSet(x.ceiling(new TDigest.Group(34, 0)));
ix = z.iterator();
for (int i = 68; i < 101; i++) {
assertTrue(ix.hasNext());
- Histo.Group v = ix.next();
+ TDigest.Group v = ix.next();
assertEquals(i / 2, v.mean(), 0);
}
assertFalse(ix.hasNext());
- z = x.tailSet(x.floor(new Histo.Group(34, 0)));
+ z = x.tailSet(x.floor(new TDigest.Group(34, 0)));
ix = z.iterator();
for (int i = 67; i < 101; i++) {
assertTrue(ix.hasNext());
- Histo.Group v = ix.next();
+ TDigest.Group v = ix.next();
assertEquals(i / 2, v.mean(), 0);
}
assertFalse(ix.hasNext());
@@ -127,10 +110,10 @@ public class GroupTreeTest {
// mostly tested in other tests
GroupTree x = new GroupTree();
for (int i = 0; i < 101; i++) {
- x.add(new Histo.Group(i / 2));
+ x.add(new TDigest.Group(i / 2));
}
- assertNull(x.floor(new Histo.Group(-30)));
+ assertNull(x.floor(new TDigest.Group(-30)));
}
@@ -138,40 +121,40 @@ public class GroupTreeTest {
public void testRemoveAndSums() {
GroupTree x = new GroupTree();
for (int i = 0; i < 101; i++) {
- x.add(new Histo.Group(i / 2));
+ x.add(new TDigest.Group(i / 2));
}
- Histo.Group g = x.ceiling(new Histo.Group(2, 0));
+ TDigest.Group g = x.ceiling(new TDigest.Group(2, 0));
x.remove(g);
g.add(3, 1);
x.add(g);
- assertEquals(0, x.headCount(new Histo.Group(-1)));
- assertEquals(0, x.headSum(new Histo.Group(-1)));
- assertEquals(0, x.headCount(new Histo.Group(0, 0)));
- assertEquals(0, x.headSum(new Histo.Group(0, 0)));
- assertEquals(0, x.headCount(x.ceiling(new Histo.Group(0, 0))));
- assertEquals(0, x.headSum(x.ceiling(new Histo.Group(0, 0))));
- assertEquals(2, x.headCount(new Histo.Group(1, 0)));
- assertEquals(2, x.headSum(new Histo.Group(1, 0)));
+ assertEquals(0, x.headCount(new TDigest.Group(-1)));
+ assertEquals(0, x.headSum(new TDigest.Group(-1)));
+ assertEquals(0, x.headCount(new TDigest.Group(0, 0)));
+ assertEquals(0, x.headSum(new TDigest.Group(0, 0)));
+ assertEquals(0, x.headCount(x.ceiling(new TDigest.Group(0, 0))));
+ assertEquals(0, x.headSum(x.ceiling(new TDigest.Group(0, 0))));
+ assertEquals(2, x.headCount(new TDigest.Group(1, 0)));
+ assertEquals(2, x.headSum(new TDigest.Group(1, 0)));
- g = x.tailSet(new Histo.Group(2.1)).iterator().next();
+ g = x.tailSet(new TDigest.Group(2.1)).iterator().next();
assertEquals(2.5, g.mean(), 1e-9);
int i = 0;
- for (Histo.Group gx : x) {
+ for (TDigest.Group gx : x) {
if (i > 10) {
break;
}
System.out.printf("%d:%.1f(%d)\t", i++, gx.mean(), gx.count());
}
- assertEquals(5, x.headCount(new Histo.Group(2.1, 0)));
- assertEquals(5, x.headSum(new Histo.Group(2.1, 0)));
+ assertEquals(5, x.headCount(new TDigest.Group(2.1, 0)));
+ assertEquals(5, x.headSum(new TDigest.Group(2.1, 0)));
- assertEquals(6, x.headCount(new Histo.Group(2.7, 0)));
- assertEquals(7, x.headSum(new Histo.Group(2.7, 0)));
+ assertEquals(6, x.headCount(new TDigest.Group(2.7, 0)));
+ assertEquals(7, x.headSum(new TDigest.Group(2.7, 0)));
- assertEquals(101, x.headCount(new Histo.Group(200)));
- assertEquals(102, x.headSum(new Histo.Group(200)));
+ assertEquals(101, x.headCount(new TDigest.Group(200)));
+ assertEquals(102, x.headSum(new TDigest.Group(200)));
}
@Test
@@ -182,7 +165,7 @@ public class GroupTreeTest {
List<Double> y = Lists.newArrayList();
for (int i = 0; i < 1000; i++) {
double v = gen.nextDouble();
- x.add(new Histo.Group(v));
+ x.add(new TDigest.Group(v));
y.add(v);
x.checkBalance();
}
@@ -190,26 +173,26 @@ public class GroupTreeTest {
Collections.sort(y);
Iterator<Double> i = y.iterator();
- for (Histo.Group group : x) {
+ for (TDigest.Group group : x) {
assertEquals(i.next(), group.mean(), 0.0);
}
for (int j = 0; j < 100; j++) {
double v = y.get(gen.nextInt(y.size()));
y.remove(v);
- x.remove(x.floor(new Histo.Group(v)));
+ x.remove(x.floor(new TDigest.Group(v)));
}
Collections.sort(y);
i = y.iterator();
- for (Histo.Group group : x) {
+ for (TDigest.Group group : x) {
assertEquals(i.next(), group.mean(), 0.0);
}
for (int j = 0; j < y.size(); j++) {
double v = y.get(j);
y.set(j, v + 10);
- Histo.Group g = x.floor(new Histo.Group(v));
+ TDigest.Group g = x.floor(new TDigest.Group(v));
x.remove(g);
x.checkBalance();
g.add(g.mean() + 20, 1);
@@ -218,7 +201,7 @@ public class GroupTreeTest {
}
i = y.iterator();
- for (Histo.Group group : x) {
+ for (TDigest.Group group : x) {
assertEquals(i.next(), group.mean(), 0.0);
}
}
Added: mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/TDigestTest.java
URL: http://svn.apache.org/viewvc/mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/TDigestTest.java?rev=1544930&view=auto
==============================================================================
--- mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/TDigestTest.java (added)
+++ mahout/trunk/math/src/test/java/org/apache/mahout/math/stats/TDigestTest.java Sun Nov 24 07:02:17 2013
@@ -0,0 +1,483 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.mahout.math.stats;
+
+import com.google.common.collect.HashMultiset;
+import com.google.common.collect.Iterables;
+import com.google.common.collect.Lists;
+import com.google.common.collect.Multiset;
+import org.apache.mahout.common.RandomUtils;
+import org.apache.mahout.common.RandomWrapper;
+import org.apache.mahout.math.jet.random.AbstractContinousDistribution;
+import org.apache.mahout.math.jet.random.Gamma;
+import org.apache.mahout.math.jet.random.Normal;
+import org.apache.mahout.math.jet.random.Uniform;
+import org.junit.*;
+
+import java.io.FileWriter;
+import java.io.IOException;
+import java.io.PrintWriter;
+import java.nio.ByteBuffer;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Random;
+
+import static org.junit.Assert.*;
+
+public class TDigestTest {
+
+ private static PrintWriter sizeDump;
+ private static PrintWriter errorDump;
+ private static PrintWriter deviationDump;
+
+ @BeforeClass
+ public static void setup() throws IOException {
+ sizeDump = new PrintWriter(new FileWriter("sizes.csv"));
+ sizeDump.printf("tag\ti\tq\tk\tactual\n");
+
+ errorDump = new PrintWriter((new FileWriter("errors.csv")));
+ errorDump.printf("dist\ttag\tx\tQ\terror\n");
+
+ deviationDump = new PrintWriter((new FileWriter("deviation.csv")));
+ deviationDump.printf("tag\tQ\tk\tx\tmean\tleft\tright\tdeviation\n");
+ }
+
+ @AfterClass
+ public static void teardown() {
+ sizeDump.close();
+ errorDump.close();
+ deviationDump.close();
+ }
+
+ @After
+ public void flush() {
+ sizeDump.flush();
+ errorDump.flush();
+ deviationDump.flush();
+ }
+
+ @Test
+ public void testUniform() {
+ RandomWrapper gen = RandomUtils.getRandom();
+ for (int i = 0; i < 5; i++) {
+ runTest(new Uniform(0, 1, gen), 100,
+// new double[]{0.0001, 0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999, 0.9999},
+ new double[]{0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999},
+ "uniform", true);
+ }
+ }
+
+ @Test
+ public void testGamma() {
+ // this Gamma distribution is very heavily skewed: the 0.1%-ile is 6.07e-30,
+ // the median is 0.006, and the 99.9th %-ile is 33.6, while the mean is 1.
+ // this severe skew means that we have to have positional accuracy that
+ // varies by over 11 orders of magnitude.
+ RandomWrapper gen = RandomUtils.getRandom();
+ for (int i = 0; i < 5; i++) {
+ runTest(new Gamma(0.1, 0.1, gen), 100,
+// new double[]{6.0730483624079e-30, 6.0730483624079e-20, 6.0730483627432e-10, 5.9339110446023e-03,
+// 2.6615455373884e+00, 1.5884778179295e+01, 3.3636770117188e+01},
+ new double[]{0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999},
+ "gamma", true);
+ }
+ }
+
+ @Test
+ public void testNarrowNormal() {
+ // this mixture of a uniform and normal distribution has a very narrow peak which is centered
+ // near the median. Our system should be scale invariant and work well regardless.
+ final RandomWrapper gen = RandomUtils.getRandom();
+ AbstractContinousDistribution mix = new AbstractContinousDistribution() {
+ AbstractContinousDistribution normal = new Normal(0, 1e-5, gen);
+ AbstractContinousDistribution uniform = new Uniform(-1, 1, gen);
+
+ @Override
+ public double nextDouble() {
+ double x;
+ if (gen.nextDouble() < 0.5) {
+ x = uniform.nextDouble();
+ } else {
+ x = normal.nextDouble();
+ }
+ return x;
+ }
+ };
+
+ for (int i = 0; i < 5; i++) {
+ runTest(mix, 100, new double[]{0.001, 0.01, 0.1, 0.3, 0.5, 0.7, 0.9, 0.99, 0.999}, "mixture", false);
+ }
+ }
+
+ @Test
+ public void testRepeatedValues() {
+ final RandomWrapper gen = RandomUtils.getRandom();
+
+ // 5% of samples will be 0 or 1.0. 10% for each of the values 0.1 through 0.9
+ AbstractContinousDistribution mix = new AbstractContinousDistribution() {
+ @Override
+ public double nextDouble() {
+ return Math.rint(gen.nextDouble() * 10) / 10.0;
+ }
+ };
+
+ TDigest dist = new TDigest((double) 1000);
+ long t0 = System.nanoTime();
+ Multiset<Double> data = HashMultiset.create();
+ for (int i1 = 0; i1 < 100000; i1++) {
+ double x = mix.nextDouble();
+ data.add(x);
+ dist.add(x);
+ }
+
+ System.out.printf("# %fus per point\n", (System.nanoTime() - t0) * 1e-3 / 100000);
+ System.out.printf("# %d centroids\n", dist.centroidCount());
+
+ // I would be happier with 5x compression, but repeated values make things kind of weird
+ assertTrue("Summary is too large", dist.centroidCount() < 10 * (double) 1000);
+
+ // all quantiles should round to nearest actual value
+ for (int i = 0; i < 10; i++) {
+ double z = i / 10.0;
+ // we skip over troublesome points that are exactly halfway between
+ for (double q = z + 0.002; q < z + 0.09; q += 0.005) {
+ double cdf = dist.cdf(q);
+ // we also relax the tolerances for repeated values
+ assertEquals(String.format("z=%.1f, q = %.3f, cdf = %.3f", z, q, cdf), z + 0.05, cdf, 0.005);
+
+ double estimate = dist.quantile(q);
+ assertEquals(String.format("z=%.1f, q = %.3f, cdf = %.3f, estimate = %.3f", z, q, cdf, estimate), Math.rint(q * 10) / 10.0, estimate, 0.001);
+ }
+ }
+ }
+
+ @Test
+ public void testSequentialPoints() {
+ for (int i = 0; i < 5; i++) {
+ runTest(new AbstractContinousDistribution() {
+ double base = 0;
+
+ @Override
+ public double nextDouble() {
+ base += Math.PI * 1e-5;
+ return base;
+ }
+ }, 100, new double[]{0.001, 0.01, 0.1, 0.5, 0.9, 0.99, 0.999},
+ "sequential", true);
+ }
+ }
+
+ @Test
+ public void testSerialization() {
+ Random gen = RandomUtils.getRandom();
+ TDigest dist = new TDigest(100);
+ for (int i = 0; i < 100000; i++) {
+ double x = gen.nextDouble();
+ dist.add(x);
+ }
+ dist.compress();
+
+ ByteBuffer buf = ByteBuffer.allocate(20000);
+ dist.asBytes(buf);
+ assertTrue(buf.position() < 11000);
+ assertEquals(buf.position(), dist.byteSize());
+ buf.clear();
+
+ dist.asSmallBytes(buf);
+ assertTrue(buf.position() < 6000);
+ assertEquals(buf.position(), dist.smallByteSize());
+
+ System.out.printf("# big %d bytes\n", buf.position());
+
+ buf.flip();
+ TDigest dist2 = TDigest.fromBytes(buf);
+ assertEquals(dist.centroidCount(), dist2.centroidCount());
+ assertEquals(dist.compression(), dist2.compression(), 0);
+ assertEquals(dist.size(), dist2.size());
+
+ for (double q = 0; q < 1; q += 0.01) {
+ assertEquals(dist.quantile(q), dist2.quantile(q), 1e-8);
+ }
+
+ Iterator<? extends TDigest.Group> ix = dist2.centroids().iterator();
+ for (TDigest.Group group : dist.centroids()) {
+ assertTrue(ix.hasNext());
+ assertEquals(group.count(), ix.next().count());
+ }
+ assertFalse(ix.hasNext());
+
+ buf.flip();
+ dist.asSmallBytes(buf);
+ assertTrue(buf.position() < 6000);
+ System.out.printf("# small %d bytes\n", buf.position());
+
+ buf.flip();
+ dist2 = TDigest.fromBytes(buf);
+ assertEquals(dist.centroidCount(), dist2.centroidCount());
+ assertEquals(dist.compression(), dist2.compression(), 0);
+ assertEquals(dist.size(), dist2.size());
+
+ for (double q = 0; q < 1; q += 0.01) {
+ assertEquals(dist.quantile(q), dist2.quantile(q), 1e-6);
+ }
+
+ ix = dist2.centroids().iterator();
+ for (TDigest.Group group : dist.centroids()) {
+ assertTrue(ix.hasNext());
+ assertEquals(group.count(), ix.next().count());
+ }
+ assertFalse(ix.hasNext());
+ }
+
+ @Test
+ public void testIntEncoding() {
+ Random gen = RandomUtils.getRandom();
+ ByteBuffer buf = ByteBuffer.allocate(10000);
+ List<Integer> ref = Lists.newArrayList();
+ for (int i = 0; i < 3000; i++) {
+ int n = gen.nextInt();
+ n = n >>> (i / 100);
+ ref.add(n);
+ TDigest.encode(buf, n);
+ }
+
+ buf.flip();
+
+ for (int i = 0; i < 3000; i++) {
+ int n = TDigest.decode(buf);
+ assertEquals(String.format("%d:", i), ref.get(i).intValue(), n);
+ }
+ }
+
+ //@Test()
+ // very slow running data generator
+ public void testSizeControl() {
+ RandomWrapper gen = RandomUtils.getRandom();
+ System.out.printf("k\tsamples\tcompression\tsize1\tsize2\n");
+ for (int k = 0; k < 40; k++) {
+ for (int size : new int[]{10, 100, 1000, 10000}) {
+ for (double compression : new double[]{2, 5, 10, 20, 50, 100, 200, 500, 1000}) {
+ TDigest dist = new TDigest(compression);
+ for (int i = 0; i < size * 1000; i++) {
+ dist.add(gen.nextDouble());
+ }
+ System.out.printf("%d\t%d\t%.0f\t%d\t%d\n", k, size, compression, dist.smallByteSize(), dist.byteSize());
+ }
+ }
+ }
+ System.out.printf("\n");
+ }
+
+ @Test
+ public void testScaling() {
+ RandomWrapper gen = RandomUtils.getRandom();
+
+ System.out.printf("pass\tcompression\tq\terror\tsize\n");
+ // change to 50 passes for better graphs
+ for (int k = 0; k < 3; k++) {
+ List<Double> data = Lists.newArrayList();
+ for (int i = 0; i < 100000; i++) {
+ data.add(gen.nextDouble());
+ }
+ Collections.sort(data);
+
+ for (double compression : new double[]{2, 5, 10, 20, 50, 100, 200, 500, 1000}) {
+ TDigest dist = new TDigest(compression);
+ for (Double x : data) {
+ dist.add(x);
+ }
+ dist.compress();
+
+ for (double q : new double[]{0.001, 0.01, 0.1, 0.5}) {
+ double estimate = dist.quantile(q);
+ double actual = data.get((int) (q * data.size()));
+ System.out.printf("%d\t%.0f\t%.3f\t%.9f\t%d\n", k, compression, q, estimate - actual, dist.byteSize());
+ }
+ }
+ }
+ }
+
+ /**
+ * Builds estimates of the CDF of a bunch of data points and checks that the centroids are accurately
+ * positioned. Accuracy is assessed in terms of the estimated CDF, which is much more stringent than
+ * checking quantile positions against a single accuracy target.
+ *
+ * @param gen Random number generator that generates desired values.
+ * @param sizeGuide Control for size of the histogram.
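+ * @param qValues Quantiles at which accuracy is checked.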
+ * @param tag Label for the output lines
+ * @param recordAllData True if the internal histogrammer should be set up to record all data it sees for
+ * diagnostic purposes.
+ */
+ private void runTest(AbstractContinousDistribution gen, double sizeGuide, double[] qValues, String tag, boolean recordAllData) {
+ TDigest dist = new TDigest(sizeGuide);
+ if (recordAllData) {
+ dist.recordAllData();
+ }
+
+ long t0 = System.nanoTime();
+ List<Double> data = Lists.newArrayList();
+ for (int i = 0; i < 100000; i++) {
+ double x = gen.nextDouble();
+ data.add(x);
+ dist.add(x);
+ }
+ dist.compress();
+ Collections.sort(data);
+
+ double[] xValues = qValues.clone();
+ for (int i = 0; i < qValues.length; i++) {
+ double ix = data.size() * qValues[i] - 0.5;
+ int index = (int) Math.floor(ix);
+ double p = ix - index;
+ xValues[i] = data.get(index) * (1 - p) + data.get(index + 1) * p;
+ }
+
+ double qz = 0;
+ int iz = 0;
+ for (TDigest.Group group : dist.centroids()) {
+ double q = (qz + group.count() / 2.0) / dist.size();
+ sizeDump.printf("%s\t%d\t%.6f\t%.3f\t%d\n", tag, iz, q, 4 * q * (1 - q) * dist.size() / dist.compression(), group.count());
+ qz += group.count();
+ iz++;
+ }
+
+ System.out.printf("# %fus per point\n", (System.nanoTime() - t0) * 1e-3 / 100000);
+ System.out.printf("# %d centroids\n", dist.centroidCount());
+
+ assertTrue("Summary is too large", dist.centroidCount() < 10 * sizeGuide);
+ for (int i = 0; i < xValues.length; i++) {
+ double x = xValues[i];
+ double q = qValues[i];
+ double estimate = dist.cdf(x);
+ errorDump.printf("%s\t%s\t%.8g\t%.8f\t%.8f\n", tag, "cdf", x, q, estimate - q);
+ assertEquals(q, estimate, 0.005);
+
+ estimate = cdf(dist.quantile(q), data);
+ errorDump.printf("%s\t%s\t%.8g\t%.8f\t%.8f\n", tag, "quantile", x, q, estimate - q);
+ assertEquals(q, estimate, 0.005);
+ }
+
+ if (recordAllData) {
+ Iterator<? extends TDigest.Group> ix = dist.centroids().iterator();
+ TDigest.Group b = ix.next();
+ TDigest.Group c = ix.next();
+ qz = b.count();
+ while (ix.hasNext()) {
+ TDigest.Group a = b;
+ b = c;
+ c = ix.next();
+ double left = (b.mean() - a.mean()) / 2;
+ double right = (c.mean() - b.mean()) / 2;
+
+ double q = (qz + b.count() / 2.0) / dist.size();
+ for (Double x : b.data()) {
+ deviationDump.printf("%s\t%.5f\t%d\t%.5g\t%.5g\t%.5g\t%.5g\t%.5f\n", tag, q, b.count(), x, b.mean(), left, right, (x - b.mean()) / (right + left));
+ }
+ qz += a.count();
+ }
+ }
+ }
+
+ @Test
+ public void testMerge() {
+ RandomWrapper gen = RandomUtils.getRandom();
+
+ for (int parts : new int[]{2, 5, 10, 20, 50, 100}) {
+ List<Double> data = Lists.newArrayList();
+
+ TDigest dist = new TDigest(100);
+ dist.recordAllData();
+
+ List<TDigest> many = Lists.newArrayList();
+ for (int i = 0; i < 100; i++) {
+ many.add(new TDigest(100).recordAllData());
+ }
+
+ // we accumulate the data into multiple sub-digests
+ List<TDigest> subs = Lists.newArrayList();
+ for (int i = 0; i < parts; i++) {
+ subs.add(new TDigest(50).recordAllData());
+ }
+
+ for (int i = 0; i < 100000; i++) {
+ double x = gen.nextDouble();
+ data.add(x);
+ dist.add(x);
+ subs.get(i % parts).add(x);
+ }
+ dist.compress();
+ Collections.sort(data);
+
+ // collect the raw data from the sub-digests
+ List<Double> data2 = Lists.newArrayList();
+ for (TDigest digest : subs) {
+ for (TDigest.Group group : digest.centroids()) {
+ Iterables.addAll(data2, group.data());
+ }
+ }
+ Collections.sort(data2);
+
+ // verify that the raw data all got recorded
+ assertEquals(data.size(), data2.size());
+ Iterator<Double> ix = data.iterator();
+ for (Double x : data2) {
+ assertEquals(ix.next(), x);
+ }
+
+ // now merge the sub-digests
+ TDigest dist2 = TDigest.merge(50, subs);
+
+ for (double q : new double[]{0.001, 0.01, 0.1, 0.2, 0.3, 0.5}) {
+ double z = quantile(q, data);
+ double e1 = dist.quantile(q) - z;
+ double e2 = dist2.quantile(q) - z;
+ System.out.printf("quantile\t%d\t%.6f\t%.6f\t%.6f\t%.6f\t%.6f\n", parts, q, z - q, e1, e2, Math.abs(e2) / q);
+ assertTrue(String.format("parts=%d, q=%.4f, e1=%.5f, e2=%.5f, rel=%.4f", parts, q, e1, e2, Math.abs(e2) / q), Math.abs(e2) / q < 0.1 && Math.abs(e2) < 0.01);
+ }
+
+ for (double x : new double[]{0.001, 0.01, 0.1, 0.2, 0.3, 0.5}) {
+ double z = cdf(x, data);
+ double e1 = dist.cdf(x) - z;
+ double e2 = dist2.cdf(x) - z;
+
+ System.out.printf("cdf\t%d\t%.6f\t%.6f\t%.6f\t%.6f\t%.6f\n", parts, x, z - x, e1, e2, Math.abs(e2) / x);
+ assertTrue(String.format("parts=%d, x=%.4f, e1=%.5f, e2=%.5f", parts, x, e1, e2), Math.abs(e2) / x < 0.1 && Math.abs(e2) < 0.01);
+ }
+ }
+ }
+
+ private double cdf(final double x, List<Double> data) {
+ int n1 = 0;
+ int n2 = 0;
+ for (Double v : data) {
+ n1 += (v < x) ? 1 : 0;
+ n2 += (v <= x) ? 1 : 0;
+ }
+ return (n1 + n2) / 2.0 / data.size();
+ }
+
+ private double quantile(final double q, List<Double> data) {
+ return data.get((int) Math.floor(data.size() * q));
+ }
+
+ @Before
+ public void setUp() {
+ RandomUtils.useTestSeed();
+ }
+}
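
Since point (g) of the TDigest javadoc advertises easy adaptation to map-reduce, a closing sketch of the merge path that testMerge exercises above: build one digest per partition, then combine with TDigest.merge. Illustration only, not part of the commit; the class name MergeExample is invented:

    import com.google.common.collect.Lists;
    import java.util.List;
    import java.util.Random;
    import org.apache.mahout.math.stats.TDigest;

    public class MergeExample {
      public static void main(String[] args) {
        Random rand = new Random(42);
        List<TDigest> parts = Lists.newArrayList();
        for (int p = 0; p < 10; p++) {
          TDigest d = new TDigest(50);             // one digest per partition
          for (int i = 0; i < 10000; i++) {
            d.add(rand.nextDouble());
          }
          parts.add(d);
        }
        // merge recombines the partial summaries, grouping them in batches
        // and replaying their centroids into a fresh digest
        TDigest combined = TDigest.merge(50, parts);
        System.out.println(combined.quantile(0.01));  // close to 0.01
      }
    }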