You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ab...@apache.org on 2017/12/04 17:49:01 UTC

[17/50] lucene-solr:jira/solr-11458-2: SOLR-11674: Support ranges in the probability Stream Evaluator

SOLR-11674: Support ranges in the probability Stream Evaluator


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/7acccd51
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/7acccd51
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/7acccd51

Branch: refs/heads/jira/solr-11458-2
Commit: 7acccd51576117e18d131d18cbe6373ea091b1b1
Parents: 360902e
Author: Joel Bernstein <jb...@apache.org>
Authored: Sun Nov 26 21:32:05 2017 -0500
Committer: Joel Bernstein <jb...@apache.org>
Committed: Sun Nov 26 21:32:05 2017 -0500

----------------------------------------------------------------------
 .../solrj/io/eval/ProbabilityEvaluator.java     | 68 +++++++++++++++-----
 .../solrj/io/stream/StreamExpressionTest.java   | 17 +++++
 2 files changed, 68 insertions(+), 17 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/7acccd51/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java
index f0c25cb..092760c 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/ProbabilityEvaluator.java
@@ -20,10 +20,11 @@ import java.io.IOException;
 import java.util.Locale;
 
 import org.apache.commons.math3.distribution.IntegerDistribution;
+import org.apache.commons.math3.distribution.AbstractRealDistribution;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
-public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
+public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
   protected static final long serialVersionUID = 1L;
 
   public ProbabilityEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@@ -31,22 +32,55 @@ public class ProbabilityEvaluator extends RecursiveObjectEvaluator implements Tw
   }
 
   @Override
-  public Object doWork(Object first, Object second) throws IOException{
-    if(null == first){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the first value",toExpression(constructingFactory)));
-    }
-    if(null == second){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
-    }
-    if(!(first instanceof IntegerDistribution)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a IntegerDistribution",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
-    if(!(second instanceof Number)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a Number",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
+  public Object doWork(Object... values) throws IOException{
+    // Supported forms:
+    //   probability(IntegerDistribution, x)          -> P(X = x)
+    //   probability(AbstractRealDistribution, a, b)  -> P(a < X <= b)
+    if(values.length != 2 && values.length != 3) {
+      throw new IOException("The probability function expects 2 or 3 parameters");
+    }
+    Object first = values[0];
+    Object second = values[1];

-    IntegerDistribution d = (IntegerDistribution) first;
-    Number predictOver = (Number) second;
-    return d.probability(predictOver.intValue());
+    // Reject nulls up front so that the type checks below can safely call getClass()
+    // when reporting a mismatched argument.
+    for(int i = 0; i < values.length; ++i) {
+      if(null == values[i]) {
+        String position = i == 0 ? "first" : (i == 1 ? "second" : "third");
+        throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - null found for the %s value", toExpression(constructingFactory), position));
+      }
+    }
+
+    if(values.length == 2) {
+      if (!(first instanceof IntegerDistribution)) {
+        throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting an IntegerDistribution for probability at a specific value", toExpression(constructingFactory), first.getClass().getSimpleName()));
+      }
+      if (!(second instanceof Number)) {
+        throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a Number", toExpression(constructingFactory), second.getClass().getSimpleName()));
+      }
+      IntegerDistribution d = (IntegerDistribution) first;
+      Number predictOver = (Number) second;
+      return d.probability(predictOver.intValue());
+    }
+
+    // Three arguments: probability that a continuous distribution falls in (start, end].
+    Object third = values[2];
+    if (!(first instanceof AbstractRealDistribution)) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RealDistribution for probability ranges", toExpression(constructingFactory), first.getClass().getSimpleName()));
+    }
+    if (!(second instanceof Number)) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the second value, expecting a Number", toExpression(constructingFactory), second.getClass().getSimpleName()));
+    }
+    if (!(third instanceof Number)) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the third value, expecting a Number", toExpression(constructingFactory), third.getClass().getSimpleName()));
+    }
+
+    AbstractRealDistribution realDistribution = (AbstractRealDistribution)first;
+    Number start = (Number) second;
+    Number end = (Number) third;
+    if (start.doubleValue() > end.doubleValue()) {
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - the start of the range (%s) must not be greater than the end (%s)", toExpression(constructingFactory), start, end));
+    }
+    return realDistribution.probability(start.doubleValue(), end.doubleValue());
   }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/7acccd51/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index a2b6b58..bfee198 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6431,6 +6431,23 @@ public class StreamExpressionTest extends SolrCloudTestCase {
   }
 
   @Test
+  public void testProbabilityRange() throws Exception {
+    // Expected value is P(520 < X <= 530) for X ~ N(500, 20), i.e. Phi(1.5) - Phi(1.0).
+    String cexpr = "let(a=normalDistribution(500, 20), b=probability(a, 520, 530))";
+    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    TupleStream solrStream = new SolrStream(url, paramsLoc);
+    StreamContext context = new StreamContext();
+    solrStream.setStreamContext(context);
+    List<Tuple> tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    Number prob = (Number)tuples.get(0).get("b");
+    assertEquals(0.09184805266259899, prob.doubleValue(), 1e-12);
+  }
+
+  @Test
   public void testDistributions() throws Exception {
     String cexpr = "let(a=normalDistribution(10, 2), " +
                        "b=sample(a, 250), " +