You are viewing a plain-text version of this content. The canonical link to it was a hyperlink ("here") in the original page and is not preserved in this text rendering.
Posted to commits@lucene.apache.org by jb...@apache.org on 2018/03/05 01:23:00 UTC

lucene-solr:master: SOLR-12054: ebeAdd and ebeSubtract should support matrix operations

Repository: lucene-solr
Updated Branches:
  refs/heads/master 97299ed00 -> dc5db9b2f


SOLR-12054: ebeAdd and ebeSubtract should support matrix operations


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/dc5db9b2
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/dc5db9b2
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/dc5db9b2

Branch: refs/heads/master
Commit: dc5db9b2f1050f1d1fc545c33f117ae4ec867983
Parents: 97299ed
Author: Joel Bernstein <jb...@apache.org>
Authored: Sun Mar 4 20:22:33 2018 -0500
Committer: Joel Bernstein <jb...@apache.org>
Committed: Sun Mar 4 20:22:33 2018 -0500

----------------------------------------------------------------------
 .../client/solrj/io/eval/EBEAddEvaluator.java   |  39 ++++---
 .../solrj/io/eval/EBESubtractEvaluator.java     |  38 ++++---
 .../solrj/io/stream/StreamExpressionTest.java   | 103 ++++++++++++++-----
 3 files changed, 125 insertions(+), 55 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
index d385770..0c86a95 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
@@ -22,10 +22,12 @@ import java.util.List;
 import java.util.Locale;
 
 import org.apache.commons.math3.util.MathArrays;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
-public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+public class EBEAddEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
   protected static final long serialVersionUID = 1L;
 
   public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@@ -40,23 +42,28 @@ public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoVal
     if(null == second){
       throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
     }
-    if(!(first instanceof List<?>)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
-    if(!(second instanceof List<?>)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
 
-    double[] result =  MathArrays.ebeAdd(
-        ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
-        ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
-    );
+    if(first instanceof List && second instanceof List) {
+      double[] result = MathArrays.ebeAdd(
+          ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
+          ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
+      );
 
-    List<Number> numbers = new ArrayList();
-    for(double d : result) {
-      numbers.add(d);
-    }
+      List<Number> numbers = new ArrayList();
+      for (double d : result) {
+        numbers.add(d);
+      }
 
-    return numbers;
+      return numbers;
+    } else if(first instanceof Matrix && second instanceof Matrix) {
+      double[][] data1 = ((Matrix) first).getData();
+      double[][] data2 = ((Matrix) second).getData();
+      Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
+      Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
+      RealMatrix matrix3 = matrix1.add(matrix2);
+      return new Matrix(matrix3.getData());
+    } else {
+      throw new IOException("Parameters for ebeAdd must either be two numeric arrays or two matrices. ");
+    }
   }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
index cd36e23..ac7a968 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
@@ -21,11 +21,13 @@ import java.util.ArrayList;
 import java.util.List;
 import java.util.Locale;
 
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
 import org.apache.commons.math3.util.MathArrays;
 import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
 import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
 
-public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+public class EBESubtractEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
   protected static final long serialVersionUID = 1L;
 
   public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@@ -40,23 +42,27 @@ public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements T
     if(null == second){
       throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
     }
-    if(!(first instanceof List<?>)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
-    if(!(second instanceof List<?>)){
-      throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
-    }
+    if(first instanceof List && second instanceof List) {
+      double[] result = MathArrays.ebeSubtract(
+          ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
+          ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
+      );
 
-    double[] result =  MathArrays.ebeSubtract(
-        ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
-        ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
-    );
+      List<Number> numbers = new ArrayList();
+      for (double d : result) {
+        numbers.add(d);
+      }
 
-    List<Number> numbers = new ArrayList();
-    for(double d : result) {
-      numbers.add(d);
+      return numbers;
+    } else if(first instanceof Matrix && second instanceof Matrix) {
+      double[][] data1 = ((Matrix) first).getData();
+      double[][] data2 = ((Matrix) second).getData();
+      Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
+      Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
+      RealMatrix matrix3 = matrix1.subtract(matrix2);
+      return new Matrix(matrix3.getData());
+    } else {
+      throw new IOException("Parameters for ebeSubtract must either be two numeric arrays or two matrices. ");
     }
-
-    return numbers;
   }
 }

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index c44dedc..039abec 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6975,9 +6975,19 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     assertEquals(termVectors.get(0).size(), 0);
   }
 
+
+
   @Test
-  public void testEBESubtract() throws Exception {
-    String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
+  public void testEbeSubtract() throws Exception {
+    String cexpr = "let(echo=true," +
+        "               a=array(2, 4, 6, 8, 10, 12)," +
+        "               b=array(1, 2, 3, 4, 5, 6)," +
+        "               c=ebeSubtract(a,b)," +
+        "               d=array(10, 11, 12, 13, 14, 15)," +
+        "               e=array(100, 200, 300, 400, 500, 600)," +
+        "               f=matrix(a, b)," +
+        "               g=matrix(d, e)," +
+        "               h=ebeSubtract(f, g))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
     paramsLoc.set("qt", "/stream");
@@ -6987,18 +6997,37 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
-    List<Number> out = (List<Number>)tuples.get(0).get("return-value");
-    assertTrue(out.size() == 6);
-    assertTrue(out.get(0).intValue() == 1);
-    assertTrue(out.get(1).intValue() == 2);
-    assertTrue(out.get(2).intValue() == 3);
-    assertTrue(out.get(3).intValue() == 4);
-    assertTrue(out.get(4).intValue() == 5);
-    assertTrue(out.get(5).intValue() == 6);
+    List<Number> out = (List<Number>)tuples.get(0).get("c");
+    assertEquals(out.size(), 6);
+    assertEquals(out.get(0).doubleValue(), 1.0, 0.0);
+    assertEquals(out.get(1).doubleValue(), 2.0, 0.0);
+    assertEquals(out.get(2).doubleValue(), 3.0, 0.0);
+    assertEquals(out.get(3).doubleValue(), 4.0, 0.0);
+    assertEquals(out.get(4).doubleValue(), 5.0, 0.0);
+    assertEquals(out.get(5).doubleValue(), 6.0, 0.0);
+
+    List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
+    assertEquals(mout.size(), 2);
+    List<Number> row1 = mout.get(0);
+    assertEquals(row1.size(), 6);
+    assertEquals(row1.get(0).doubleValue(), -8.0, 0.0);
+    assertEquals(row1.get(1).doubleValue(), -7.0, 0.0);
+    assertEquals(row1.get(2).doubleValue(), -6.0, 0.0);
+    assertEquals(row1.get(3).doubleValue(), -5.0, 0.0);
+    assertEquals(row1.get(4).doubleValue(), -4.0, 0.0);
+    assertEquals(row1.get(5).doubleValue(), -3.0, 0.0);
+
+    List<Number> row2 = mout.get(1);
+    assertEquals(row2.size(), 6);
+    assertEquals(row2.get(0).doubleValue(), -99.0, 0.0);
+    assertEquals(row2.get(1).doubleValue(), -198.0, 0.0);
+    assertEquals(row2.get(2).doubleValue(), -297.0, 0.0);
+    assertEquals(row2.get(3).doubleValue(), -396.0, 0.0);
+    assertEquals(row2.get(4).doubleValue(), -495.0, 0.0);
+    assertEquals(row2.get(5).doubleValue(), -594.0, 0.0);
   }
 
 
-
   @Test
   public void testMatrixMult() throws Exception {
     String cexpr = "let(echo=true," +
@@ -7341,7 +7370,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
   }
 
   @Test
-  public void testEBEMultiply() throws Exception {
+  public void testEbeMultiply() throws Exception {
     String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
@@ -7364,8 +7393,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
 
 
   @Test
-  public void testEBEAdd() throws Exception {
-    String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
+  public void testEbeAdd() throws Exception {
+    String cexpr = "let(echo=true," +
+        "               a=array(2, 4, 6, 8, 10, 12)," +
+        "               b=array(1, 2, 3, 4, 5, 6)," +
+        "               c=ebeAdd(a,b)," +
+        "               d=array(10, 11, 12, 13, 14, 15)," +
+        "               e=array(100, 200, 300, 400, 500, 600)," +
+        "               f=matrix(a, b)," +
+        "               g=matrix(d, e)," +
+        "               h=ebeAdd(f, g))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
     paramsLoc.set("qt", "/stream");
@@ -7375,19 +7412,39 @@ public class StreamExpressionTest extends SolrCloudTestCase {
     solrStream.setStreamContext(context);
     List<Tuple> tuples = getTuples(solrStream);
     assertTrue(tuples.size() == 1);
-    List<Number> out = (List<Number>)tuples.get(0).get("return-value");
-    assertTrue(out.size() == 6);
-    assertTrue(out.get(0).intValue() == 3);
-    assertTrue(out.get(1).intValue() == 6);
-    assertTrue(out.get(2).intValue() == 9);
-    assertTrue(out.get(3).intValue() == 12);
-    assertTrue(out.get(4).intValue() == 15);
-    assertTrue(out.get(5).intValue() == 18);
+    List<Number> out = (List<Number>)tuples.get(0).get("c");
+    assertEquals(out.size(), 6);
+    assertEquals(out.get(0).doubleValue(), 3.0, 0.0);
+    assertEquals(out.get(1).doubleValue(), 6.0, 0.0);
+    assertEquals(out.get(2).doubleValue(), 9.0, 0.0);
+    assertEquals(out.get(3).doubleValue(), 12.0, 0.0);
+    assertEquals(out.get(4).doubleValue(), 15.0, 0.0);
+    assertEquals(out.get(5).doubleValue(), 18.0, 0.0);
+
+    List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
+    assertEquals(mout.size(), 2);
+    List<Number> row1 = mout.get(0);
+    assertEquals(row1.size(), 6);
+    assertEquals(row1.get(0).doubleValue(), 12.0, 0.0);
+    assertEquals(row1.get(1).doubleValue(), 15.0, 0.0);
+    assertEquals(row1.get(2).doubleValue(), 18.0, 0.0);
+    assertEquals(row1.get(3).doubleValue(), 21.0, 0.0);
+    assertEquals(row1.get(4).doubleValue(), 24.0, 0.0);
+    assertEquals(row1.get(5).doubleValue(), 27.0, 0.0);
+
+    List<Number> row2 = mout.get(1);
+    assertEquals(row2.size(), 6);
+    assertEquals(row2.get(0).doubleValue(), 101.0, 0.0);
+    assertEquals(row2.get(1).doubleValue(), 202.0, 0.0);
+    assertEquals(row2.get(2).doubleValue(), 303.0, 0.0);
+    assertEquals(row2.get(3).doubleValue(), 404.0, 0.0);
+    assertEquals(row2.get(4).doubleValue(), 505.0, 0.0);
+    assertEquals(row2.get(5).doubleValue(), 606.0, 0.0);
   }
 
 
   @Test
-  public void testEBEDivide() throws Exception {
+  public void testEbeDivide() throws Exception {
     String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);