You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2019/03/05 14:38:14 UTC

[lucene-solr] branch master updated: SOLR-13287: Allow zplot to visualize probability distributions in Apache Zeppelin

This is an automated email from the ASF dual-hosted git repository.

jbernste pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/master by this push:
     new c34c56b  SOLR-13287: Allow zplot to visualize probability distributions in Apache Zeppelin
c34c56b is described below

commit c34c56b7b25197ae21b5d3b330e53f1c86d26751
Author: Joel Bernstein <jb...@apache.org>
AuthorDate: Tue Mar 5 09:18:47 2019 -0500

    SOLR-13287: Allow zplot to visualize probability distributions in Apache Zeppelin
---
 .../io/eval/EmpiricalDistributionEvaluator.java    |  24 ++--
 .../solr/client/solrj/io/stream/TupStream.java     |  45 +++++--
 .../solr/client/solrj/io/stream/ZplotStream.java   | 106 +++++++++++++++-
 .../client/solrj/io/stream/MathExpressionTest.java | 139 ++++++++++++++-------
 4 files changed, 247 insertions(+), 67 deletions(-)

diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EmpiricalDistributionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EmpiricalDistributionEvaluator.java
index 945cdb0..19da016 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EmpiricalDistributionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EmpiricalDistributionEvaluator.java
@@ -14,6 +14,7 @@
  * See the License for the specific language governing permissions and
  * limitations under the License.
  */
+
 package org.apache.solr.client.solrj.io.eval;
 
 import java.io.IOException;
@@ -24,27 +25,34 @@ import org.apache.commons.math3.random.EmpiricalDistribution;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
-public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements OneValueWorker {
+public class EmpiricalDistributionEvaluator extends RecursiveNumericEvaluator implements ManyValueWorker {
   protected static final long serialVersionUID = 1L;
+  private int bins = 99;
   
   public EmpiricalDistributionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
     super(expression, factory);
     
-    if(1 != containedEvaluators.size()){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting exactly one value but found %d",expression,containedEvaluators.size()));
+    if(2 != containedEvaluators.size() && 1 != containedEvaluators.size()) {
+      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - expecting one or two values but found %d",expression,containedEvaluators.size()));
     }
   }
   
   @Override
-  public Object doWork(Object value) throws IOException {
+  public Object doWork(Object[] values) throws IOException {
+
+    if(!(values[0] instanceof List<?>)){
+      throw new StreamEvaluatorException("List value expected but found type %s for value %s", values[0].getClass().getName(), values[0].toString());
+    }
 
-    if(!(value instanceof List<?>)){
-      throw new StreamEvaluatorException("List value expected but found type %s for value %s", value.getClass().getName(), value.toString());
+    if(values.length == 2) {
+      Number n = (Number)values[1];
+      bins = n.intValue();
     }
 
-    EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution();
+    EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(bins);
     
-    double[] backingValues = ((List<?>)value).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
+    double[] backingValues = ((List<?>)values[0]).stream().mapToDouble(innerValue -> ((Number)innerValue).doubleValue()).sorted().toArray();
+
     empiricalDistribution.load(backingValues);
 
     return empiricalDistribution;
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TupStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TupStream.java
index fde8298..a7bca77 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TupStream.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/TupStream.java
@@ -23,6 +23,7 @@ import java.util.List;
 import java.util.Locale;
 import java.util.Map;
 import java.util.Map.Entry;
+import java.util.Iterator;
 
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.comp.SingleValueComparator;
@@ -45,16 +46,19 @@ public class TupStream extends TupleStream implements Expressible {
 
   private static final long serialVersionUID = 1;
   private StreamContext streamContext;
-  
+
   private Map<String,String> stringParams = new HashMap<>();
   private Map<String,StreamEvaluator> evaluatorParams = new HashMap<>();
   private Map<String,TupleStream> streamParams = new HashMap<>();
   private List<String> fieldNames = new ArrayList();
   private Map<String, String> fieldLabels = new HashMap();
   private Tuple tup = null;
+  private Tuple unnestedTuple = null;
+  private Iterator<Tuple>  unnestedTuples = null;
   
   private boolean finished;
 
+
   public TupStream(StreamExpression expression, StreamFactory factory) throws IOException {
 
     List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
@@ -146,13 +150,27 @@ public class TupStream extends TupleStream implements Expressible {
 
   public Tuple read() throws IOException {
 
-    if(finished) {
-      Map<String,Object> m = new HashMap<>();
-      m.put("EOF", true);
-      return new Tuple(m);
+    if(unnestedTuples == null) {
+      if (finished) {
+        Map<String, Object> m = new HashMap<>();
+        m.put("EOF", true);
+        return new Tuple(m);
+      } else {
+        finished = true;
+        if(unnestedTuple != null) {
+          return unnestedTuple;
+        } else {
+          return tup;
+        }
+      }
     } else {
-      finished = true;
-      return tup;
+      if(unnestedTuples.hasNext()) {
+        return unnestedTuples.next();
+      } else {
+        Map<String, Object> m = new HashMap<>();
+        m.put("EOF", true);
+        return new Tuple(m);
+      }
     }
   }
 
@@ -202,6 +220,19 @@ public class TupStream extends TupleStream implements Expressible {
       }
     }
 
+    if(values.size() == 1) {
+      for(Object o :values.values()) {
+        if(o instanceof Tuple) {
+          unnestedTuple = (Tuple)o;
+        } else if(o instanceof List) {
+          List l = (List)o;
+          if(l.size() > 0 && l.get(0) instanceof Tuple) {
+            List<Tuple> tl = (List<Tuple>)l;
+            unnestedTuples = tl.iterator();
+          }
+        }
+      }
+    }
     this.tup = new Tuple(values);
     tup.fieldNames = fieldNames;
     tup.fieldLabels = fieldLabels;
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ZplotStream.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ZplotStream.java
index c5280dc..8f7165a 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ZplotStream.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/stream/ZplotStream.java
@@ -25,6 +25,12 @@ import java.util.List;
 import java.util.Map;
 import java.util.Set;
 
+import org.apache.commons.math3.distribution.IntegerDistribution;
+import org.apache.commons.math3.distribution.RealDistribution;
+import org.apache.commons.math3.random.EmpiricalDistribution;
+import org.apache.commons.math3.stat.Frequency;
+import org.apache.commons.math3.stat.descriptive.SummaryStatistics;
+import org.apache.commons.math3.util.Precision;
 import org.apache.solr.client.solrj.io.Tuple;
 import org.apache.solr.client.solrj.io.comp.StreamComparator;
 import org.apache.solr.client.solrj.io.eval.StreamEvaluator;
@@ -119,12 +125,15 @@ public class ZplotStream extends TupleStream implements Expressible {
     int numTuples = -1;
     int columns = 0;
     boolean table = false;
+    boolean distribution = false;
     for(Map.Entry<String, Object> entry : entries) {
       ++columns;
 
       String name = entry.getKey();
       if(name.equals("table")) {
         table = true;
+      } else if(name.equals("dist")) {
+        distribution = true;
       }
 
       Object o = entry.getValue();
@@ -145,6 +154,8 @@ public class ZplotStream extends TupleStream implements Expressible {
           evaluated.put(name, l);
         } else if (eo instanceof Tuple) {
           evaluated.put(name, eo);
+        } else {
+          evaluated.put(name, eo);
         }
       } else {
         Object eval = lets.get(o);
@@ -164,13 +175,13 @@ public class ZplotStream extends TupleStream implements Expressible {
       }
     }
 
-    if(columns > 1 && table) {
-      throw new IOException("If the table parameter is set there can only be one parameter.");
+    if(columns > 1 && (table || distribution)) {
+      throw new IOException("If the table or dist parameter is set there can only be one parameter.");
     }
     //Load the values into tuples
 
     List<Tuple> outTuples = new ArrayList();
-    if(!table) {
+    if(!table && !distribution) {
       //Handle the vectors
       for (int i = 0; i < numTuples; i++) {
         Tuple tuple = new Tuple(new HashMap());
@@ -181,7 +192,94 @@ public class ZplotStream extends TupleStream implements Expressible {
 
         outTuples.add(tuple);
       }
-    } else {
+    } else if(distribution) {
+      Object o = evaluated.get("dist");
+      if(o instanceof RealDistribution) {
+        RealDistribution realDistribution = (RealDistribution) o;
+        List<SummaryStatistics> binStats = null;
+        if(realDistribution instanceof  EmpiricalDistribution) {
+          EmpiricalDistribution empiricalDistribution = (EmpiricalDistribution)realDistribution;
+          binStats = empiricalDistribution.getBinStats();
+        } else {
+          double[] samples = realDistribution.sample(500000);
+          EmpiricalDistribution empiricalDistribution = new EmpiricalDistribution(32);
+          empiricalDistribution.load(samples);
+          binStats = empiricalDistribution.getBinStats();
+        }
+        double[] x = new double[binStats.size()];
+        double[] y = new double[binStats.size()];
+        for (int i = 0; i < binStats.size(); i++) {
+          x[i] = binStats.get(i).getMean();
+          y[i] = realDistribution.density(x[i]);
+        }
+
+        for (int i = 0; i < x.length; i++) {
+          Tuple tuple = new Tuple(new HashMap());
+          if(!Double.isNaN(x[i])) {
+            tuple.put("x", Precision.round(x[i], 2));
+            if(y[i] == Double.NEGATIVE_INFINITY || y[i] == Double.POSITIVE_INFINITY) {
+              tuple.put("y", 0);
+
+            } else {
+              tuple.put("y", y[i]);
+            }
+            outTuples.add(tuple);
+          }
+        }
+      } else if(o instanceof IntegerDistribution) {
+        IntegerDistribution integerDistribution = (IntegerDistribution)o;
+        int[] samples = integerDistribution.sample(50000);
+        Frequency frequency = new Frequency();
+        for(int i : samples) {
+          frequency.addValue(i);
+        }
+
+        Iterator it = frequency.valuesIterator();
+        List<Long> values = new ArrayList();
+        while(it.hasNext()) {
+          values.add((Long)it.next());
+        }
+        System.out.println(values);
+        int[] x = new int[values.size()];
+        double[] y = new double[values.size()];
+        for(int i=0; i<values.size(); i++) {
+          x[i] = values.get(i).intValue();
+          y[i] = integerDistribution.probability(x[i]);
+        }
+
+        for (int i = 0; i < x.length; i++) {
+          Tuple tuple = new Tuple(new HashMap());
+          tuple.put("x", x[i]);
+          tuple.put("y", y[i]);
+          outTuples.add(tuple);
+        }
+      } else if(o instanceof List) {
+        System.out.print("Is list");
+        List list = (List)o;
+        if(list.get(0) instanceof Tuple) {
+          System.out.print("Are tuples");
+          List<Tuple> tlist = (List<Tuple>)o;
+          Tuple tuple = tlist.get(0);
+          if(tuple.fields.containsKey("N")) {
+            System.out.println("Is hist");
+            for(Tuple t : tlist) {
+              Tuple outtuple = new Tuple(new HashMap());
+              outtuple.put("x", Precision.round(((double)t.get("mean")), 2));
+              outtuple.put("y", t.get("prob"));
+              outTuples.add(outtuple);
+            }
+          } else if(tuple.fields.containsKey("count")) {
+            System.out.println("Is freq");
+            for(Tuple t : tlist) {
+              Tuple outtuple = new Tuple(new HashMap());
+              outtuple.put("x", t.get("value"));
+              outtuple.put("y", t.get("pct"));
+              outTuples.add(outtuple);
+            }
+          }
+        }
+      }
+    } else if(table){
       //Handle the Tuple and List of Tuples
       Object o = evaluated.get("table");
       if(o instanceof List) {
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index 6d8c9a8..6c7387b 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -369,11 +369,9 @@ public class MathExpressionTest extends SolrCloudTestCase {
     StreamContext context = new StreamContext();
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
-    assertTrue(tuples.size() == 1);
-    List<Map> hist = (List<Map>)tuples.get(0).get("return-value");
-    assertTrue(hist.size() == 10);
-    for(int i=0; i<hist.size(); i++) {
-      Map stats = hist.get(i);
+    assertTrue(tuples.size() == 10);
+    for(int i=0; i<tuples.size(); i++) {
+      Tuple stats = tuples.get(i);
       assertTrue(((Number)stats.get("N")).intValue() == 10);
       assertTrue(((Number)stats.get("min")).intValue() == 10*i);
       assertTrue(((Number)stats.get("var")).doubleValue() == 9.166666666666666);
@@ -388,11 +386,10 @@ public class MathExpressionTest extends SolrCloudTestCase {
     solrStream = new SolrStream(url, paramsLoc);
     solrStream.setStreamContext(context);
     tuples = getTuples(solrStream);
-    assertTrue(tuples.size() == 1);
-    hist = (List<Map>)tuples.get(0).get("return-value");
-    assertTrue(hist.size() == 5);
-    for(int i=0; i<hist.size(); i++) {
-      Map stats = hist.get(i);
+    assertTrue(tuples.size() == 5);
+
+    for(int i=0; i<tuples.size(); i++) {
+      Tuple stats = tuples.get(i);
       assertTrue(((Number)stats.get("N")).intValue() == 20);
       assertTrue(((Number)stats.get("min")).intValue() == 20*i);
       assertTrue(((Number)stats.get("var")).doubleValue() == 35);
@@ -1476,6 +1473,56 @@ public class MathExpressionTest extends SolrCloudTestCase {
     assertEquals(out.getDouble("x").doubleValue(), 4.0, 0.0);
     assertEquals(out.getDouble("y").doubleValue(), 13.0, 0.0);
 
+    cexpr = "zplot(dist=binomialDistribution(10, .50))";
+
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertEquals(tuples.size(),11);
+    long x = tuples.get(5).getLong("x");
+    double y = tuples.get(5).getDouble("y");
+
+    assertEquals(x, 5);
+    assertEquals(y,     0.24609375000000003, 0);
+
+    //Due to random errors (bugs) in Apache Commons Math EmpiricalDistribution
+    //there are times when tuples are discarded because
+    //they contain NaN values. This will occur
+    //only on the very end of the tails of the normal distribution or other
+    //real distributions and doesn't affect the visual quality of the curve very much.
+    //But it does affect the reliability of tests.
+    //For this reason the loop below is in place to run the test N times looking
+    //for the correct number of tuples before asserting the mean.
+
+    int n = 0;
+    int limit = 15;
+    while(true) {
+      cexpr = "zplot(dist=normalDistribution(100, 10))";
+      paramsLoc = new ModifiableSolrParams();
+      paramsLoc.set("expr", cexpr);
+      paramsLoc.set("qt", "/stream");
+      solrStream = new SolrStream(url, paramsLoc);
+      context = new StreamContext();
+      solrStream.setStreamContext(context);
+      tuples = getTuples(solrStream);
+      //Assert the mean
+      if (tuples.size() == 32) {
+        double x1 = tuples.get(15).getDouble("x");
+        double y1 = tuples.get(15).getDouble("y");
+        assertEquals(x1, 100, 10);
+        assertEquals(y1, .039, .02);
+        break;
+      } else {
+        ++n;
+        if(n == limit) {
+          throw new Exception("Reached iterations limit without correct tuple count.");
+        }
+      }
+    }
   }
 
 
@@ -1751,14 +1798,13 @@ public class MathExpressionTest extends SolrCloudTestCase {
     StreamContext context = new StreamContext();
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
-    assertTrue(tuples.size() == 1);
-    List<Map<String, Number>> out = (List<Map<String, Number>>)tuples.get(0).get("f");
-    assertEquals(out.size(), 2);
-    Map<String, Number> bin0 = out.get(0);
-    double state0Pct = bin0.get("pct").doubleValue();
+    assertTrue(tuples.size() == 2);
+
+    Tuple bin0 = tuples.get(0);
+    double state0Pct = bin0.getDouble("pct");
     assertEquals(state0Pct, .5, .015);
-    Map<String, Number> bin1 = out.get(1);
-    double state1Pct = bin1.get("pct").doubleValue();
+    Tuple bin1 = tuples.get(1);
+    double state1Pct = bin1.getDouble("pct");
     assertEquals(state1Pct, .5, .015);
   }
 
@@ -2933,32 +2979,30 @@ public class MathExpressionTest extends SolrCloudTestCase {
     StreamContext context = new StreamContext();
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
-    assertTrue(tuples.size() == 1);
-    List<Map<String,Number>> out = (List<Map<String, Number>>)tuples.get(0).get("return-value");
-    assertTrue(out.size() == 6);
-    Map<String, Number> bucket = out.get(0);
-    assertEquals(bucket.get("value").longValue(), 2);
-    assertEquals(bucket.get("count").longValue(), 2);
-
-    bucket = out.get(1);
-    assertEquals(bucket.get("value").longValue(), 4);
-    assertEquals(bucket.get("count").longValue(), 2);
-
-    bucket = out.get(2);
-    assertEquals(bucket.get("value").longValue(), 6);
-    assertEquals(bucket.get("count").longValue(), 1);
-
-    bucket = out.get(3);
-    assertEquals(bucket.get("value").longValue(), 8);
-    assertEquals(bucket.get("count").longValue(), 4);
-
-    bucket = out.get(4);
-    assertEquals(bucket.get("value").longValue(), 10);
-    assertEquals(bucket.get("count").longValue(), 1);
-
-    bucket = out.get(5);
-    assertEquals(bucket.get("value").longValue(), 12);
-    assertEquals(bucket.get("count").longValue(), 2);
+    assertTrue(tuples.size() == 6);
+    Tuple bucket = tuples.get(0);
+    assertEquals(bucket.getLong("value").longValue(), 2);
+    assertEquals(bucket.getLong("count").longValue(), 2);
+
+    bucket = tuples.get(1);
+    assertEquals(bucket.getLong("value").longValue(), 4);
+    assertEquals(bucket.getLong("count").longValue(), 2);
+
+    bucket = tuples.get(2);
+    assertEquals(bucket.getLong("value").longValue(), 6);
+    assertEquals(bucket.getLong("count").longValue(), 1);
+
+    bucket = tuples.get(3);
+    assertEquals(bucket.getLong("value").longValue(), 8);
+    assertEquals(bucket.getLong("count").longValue(), 4);
+
+    bucket = tuples.get(4);
+    assertEquals(bucket.getLong("value").longValue(), 10);
+    assertEquals(bucket.getLong("count").longValue(), 1);
+
+    bucket = tuples.get(5);
+    assertEquals(bucket.getLong("value").longValue(), 12);
+    assertEquals(bucket.getLong("count").longValue(), 2);
   }
 
   @Test
@@ -4062,7 +4106,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
-    Map out = (Map)tuples.get(0).get("return-value");
+    Tuple out = tuples.get(0);
     assertEquals((double) out.get("p-value"), 0.788298D, .0001);
     assertEquals((double) out.get("f-ratio"), 0.24169D, .0001);
   }
@@ -4347,7 +4391,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
-    Map out = (Map)tuples.get(0).get("return-value");
+    Tuple out = tuples.get(0);
     assertEquals((double) out.get("u-statistic"), 52.5, .1);
     assertEquals((double) out.get("p-value"), 0.7284, .001);
   }
@@ -5142,7 +5186,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
 
     String expr1 = "search("+COLLECTIONORALIAS+", q=\"col_s:a\", fl=\"price_f, order_i\", sort=\"order_i asc\")";
 
-    String cexpr = "let(a="+expr1+", b=col(a, price_f),  tuple(stats=describe(b)))";
+    String cexpr = "let(a="+expr1+", b=col(a, price_f),  stats=describe(b))";
 
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
@@ -5155,8 +5199,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
-    Tuple tuple = tuples.get(0);
-    Map stats = (Map)tuple.get("stats");
+    Tuple stats = tuples.get(0);
     Number min = (Number)stats.get("min");
     Number max = (Number)stats.get("max");
     Number mean = (Number)stats.get("mean");