You are viewing a plain-text version of this content. The canonical link for it was a hyperlink labelled "here", which is not preserved in this plain-text copy.
Posted to commits@commons.apache.org by er...@apache.org on 2018/01/12 13:39:54 UTC
[4/5] commons-rng git commit: Avoid recomputations.
Avoid recomputations.
Project: http://git-wip-us.apache.org/repos/asf/commons-rng/repo
Commit: http://git-wip-us.apache.org/repos/asf/commons-rng/commit/daf0f7b0
Tree: http://git-wip-us.apache.org/repos/asf/commons-rng/tree/daf0f7b0
Diff: http://git-wip-us.apache.org/repos/asf/commons-rng/diff/daf0f7b0
Branch: refs/heads/master
Commit: daf0f7b0d3e0fa33d13743d35ddee18fca7b2bf6
Parents: cb807a1
Author: Gilles <er...@apache.org>
Authored: Fri Jan 12 13:51:26 2018 +0100
Committer: Gilles <er...@apache.org>
Committed: Fri Jan 12 13:51:26 2018 +0100
----------------------------------------------------------------------
.../AhrensDieterMarsagliaTsangGammaSampler.java | 34 +++++++++++++-------
1 file changed, 22 insertions(+), 12 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/commons-rng/blob/daf0f7b0/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java
----------------------------------------------------------------------
diff --git a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java
index 233b4af..146888a 100644
--- a/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java
+++ b/commons-rng-sampling/src/main/java/org/apache/commons/rng/sampling/distribution/AhrensDieterMarsagliaTsangGammaSampler.java
@@ -42,10 +42,20 @@ import org.apache.commons.rng.UniformRandomProvider;
public class AhrensDieterMarsagliaTsangGammaSampler
extends SamplerBase
implements ContinuousSampler {
+ /** 1/3 */
+ private static final double ONE_THIRD = 1d / 3;
/** The shape parameter. */
private final double theta;
/** The alpha parameter. */
private final double alpha;
+ /** Inverse of "theta". */
+ private final double oneOverTheta;
+ /** Optimization (see code). */
+ private final double bGSOptim;
+ /** Optimization (see code). */
+ private final double dOptim;
+ /** Optimization (see code). */
+ private final double cOptim;
/** Gaussian sampling. */
private final NormalizedGaussianSampler gaussian;
@@ -61,6 +71,10 @@ public class AhrensDieterMarsagliaTsangGammaSampler
this.alpha = alpha;
this.theta = theta;
gaussian = new ZigguratNormalizedGaussianSampler(rng);
+ oneOverTheta = 1 / theta;
+ bGSOptim = 1 + theta / Math.E;
+ dOptim = theta - ONE_THIRD;
+ cOptim = ONE_THIRD / Math.sqrt(dOptim);
}
/** {@inheritDoc} */
@@ -72,13 +86,12 @@ public class AhrensDieterMarsagliaTsangGammaSampler
while (true) {
// Step 1:
final double u = nextDouble();
- final double bGS = 1 + theta / Math.E;
- final double p = bGS * u;
+ final double p = bGSOptim * u;
if (p <= 1) {
// Step 2:
- final double x = Math.pow(p, 1 / theta);
+ final double x = Math.pow(p, oneOverTheta);
final double u2 = nextDouble();
if (u2 > Math.exp(-x)) {
@@ -90,7 +103,7 @@ public class AhrensDieterMarsagliaTsangGammaSampler
} else {
// Step 3:
- final double x = -1 * Math.log((bGS - p) / theta);
+ final double x = -Math.log((bGSOptim - p) * oneOverTheta);
final double u2 = nextDouble();
if (u2 > Math.pow(x, theta - 1)) {
@@ -104,13 +117,10 @@ public class AhrensDieterMarsagliaTsangGammaSampler
}
// Now theta >= 1.
-
- final double d = theta - 0.333333333333333333;
- final double c = 1 / (3 * Math.sqrt(d));
-
while (true) {
final double x = gaussian.sample();
- final double v = (1 + c * x) * (1 + c * x) * (1 + c * x);
+ final double oPcTx = (1 + cOptim * x);
+ final double v = oPcTx * oPcTx * oPcTx;
if (v <= 0) {
continue;
@@ -121,11 +131,11 @@ public class AhrensDieterMarsagliaTsangGammaSampler
// Squeeze.
if (u < 1 - 0.0331 * x2 * x2) {
- return alpha * d * v;
+ return alpha * dOptim * v;
}
- if (Math.log(u) < 0.5 * x2 + d * (1 - v + Math.log(v))) {
- return alpha * d * v;
+ if (Math.log(u) < 0.5 * x2 + dOptim * (1 - v + Math.log(v))) {
+ return alpha * dOptim * v;
}
}
}