You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2018/03/05 01:23:00 UTC
lucene-solr:master: SOLR-12054: ebeAdd and ebeSubtract should support matrix operations
Repository: lucene-solr
Updated Branches:
refs/heads/master 97299ed00 -> dc5db9b2f
SOLR-12054: ebeAdd and ebeSubtract should support matrix operations
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/dc5db9b2
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/dc5db9b2
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/dc5db9b2
Branch: refs/heads/master
Commit: dc5db9b2f1050f1d1fc545c33f117ae4ec867983
Parents: 97299ed
Author: Joel Bernstein <jb...@apache.org>
Authored: Sun Mar 4 20:22:33 2018 -0500
Committer: Joel Bernstein <jb...@apache.org>
Committed: Sun Mar 4 20:22:33 2018 -0500
----------------------------------------------------------------------
.../client/solrj/io/eval/EBEAddEvaluator.java | 39 ++++---
.../solrj/io/eval/EBESubtractEvaluator.java | 38 ++++---
.../solrj/io/stream/StreamExpressionTest.java | 103 ++++++++++++++-----
3 files changed, 125 insertions(+), 55 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
index d385770..0c86a95 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBEAddEvaluator.java
@@ -22,10 +22,12 @@ import java.util.List;
import java.util.Locale;
import org.apache.commons.math3.util.MathArrays;
+import org.apache.commons.math3.linear.RealMatrix;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
-public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+public class EBEAddEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public EBEAddEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@@ -40,23 +42,28 @@ public class EBEAddEvaluator extends RecursiveNumericEvaluator implements TwoVal
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
- if(!(first instanceof List<?>)){
- throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
- }
- if(!(second instanceof List<?>)){
- throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
- }
- double[] result = MathArrays.ebeAdd(
- ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
- ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
- );
+ if(first instanceof List && second instanceof List) {
+ double[] result = MathArrays.ebeAdd(
+ ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
+ ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
+ );
- List<Number> numbers = new ArrayList();
- for(double d : result) {
- numbers.add(d);
- }
+ List<Number> numbers = new ArrayList();
+ for (double d : result) {
+ numbers.add(d);
+ }
- return numbers;
+ return numbers;
+ } else if(first instanceof Matrix && second instanceof Matrix) {
+ double[][] data1 = ((Matrix) first).getData();
+ double[][] data2 = ((Matrix) second).getData();
+ Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
+ Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
+ RealMatrix matrix3 = matrix1.add(matrix2);
+ return new Matrix(matrix3.getData());
+ } else {
+ throw new IOException("Parameters for ebeAdd must either be two numeric arrays or two matrices. ");
+ }
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
index cd36e23..ac7a968 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/EBESubtractEvaluator.java
@@ -21,11 +21,13 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.linear.RealMatrix;
import org.apache.commons.math3.util.MathArrays;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
-public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements TwoValueWorker {
+public class EBESubtractEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
protected static final long serialVersionUID = 1L;
public EBESubtractEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
@@ -40,23 +42,27 @@ public class EBESubtractEvaluator extends RecursiveNumericEvaluator implements T
if(null == second){
throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - null found for the second value",toExpression(constructingFactory)));
}
- if(!(first instanceof List<?>)){
- throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the first value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
- }
- if(!(second instanceof List<?>)){
- throw new IOException(String.format(Locale.ROOT,"Invalid expression %s - found type %s for the second value, expecting a list of numbers",toExpression(constructingFactory), first.getClass().getSimpleName()));
- }
+ if(first instanceof List && second instanceof List) {
+ double[] result = MathArrays.ebeSubtract(
+ ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
+ ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
+ );
- double[] result = MathArrays.ebeSubtract(
- ((List) first).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray(),
- ((List) second).stream().mapToDouble(value -> ((Number) value).doubleValue()).toArray()
- );
+ List<Number> numbers = new ArrayList();
+ for (double d : result) {
+ numbers.add(d);
+ }
- List<Number> numbers = new ArrayList();
- for(double d : result) {
- numbers.add(d);
+ return numbers;
+ } else if(first instanceof Matrix && second instanceof Matrix) {
+ double[][] data1 = ((Matrix) first).getData();
+ double[][] data2 = ((Matrix) second).getData();
+ Array2DRowRealMatrix matrix1 = new Array2DRowRealMatrix(data1);
+ Array2DRowRealMatrix matrix2 = new Array2DRowRealMatrix(data2);
+ RealMatrix matrix3 = matrix1.subtract(matrix2);
+ return new Matrix(matrix3.getData());
+ } else {
+ throw new IOException("Parameters for ebeSubtract must either be two numeric arrays or two matrices. ");
}
-
- return numbers;
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/dc5db9b2/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index c44dedc..039abec 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6975,9 +6975,19 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(termVectors.get(0).size(), 0);
}
+
+
@Test
- public void testEBESubtract() throws Exception {
- String cexpr = "ebeSubtract(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
+ public void testEbeSubtract() throws Exception {
+ String cexpr = "let(echo=true," +
+ " a=array(2, 4, 6, 8, 10, 12)," +
+ " b=array(1, 2, 3, 4, 5, 6)," +
+ " c=ebeSubtract(a,b)," +
+ " d=array(10, 11, 12, 13, 14, 15)," +
+ " e=array(100, 200, 300, 400, 500, 600)," +
+ " f=matrix(a, b)," +
+ " g=matrix(d, e)," +
+ " h=ebeSubtract(f, g))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@@ -6987,18 +6997,37 @@ public class StreamExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
- List<Number> out = (List<Number>)tuples.get(0).get("return-value");
- assertTrue(out.size() == 6);
- assertTrue(out.get(0).intValue() == 1);
- assertTrue(out.get(1).intValue() == 2);
- assertTrue(out.get(2).intValue() == 3);
- assertTrue(out.get(3).intValue() == 4);
- assertTrue(out.get(4).intValue() == 5);
- assertTrue(out.get(5).intValue() == 6);
+ List<Number> out = (List<Number>)tuples.get(0).get("c");
+ assertEquals(out.size(), 6);
+ assertEquals(out.get(0).doubleValue(), 1.0, 0.0);
+ assertEquals(out.get(1).doubleValue(), 2.0, 0.0);
+ assertEquals(out.get(2).doubleValue(), 3.0, 0.0);
+ assertEquals(out.get(3).doubleValue(), 4.0, 0.0);
+ assertEquals(out.get(4).doubleValue(), 5.0, 0.0);
+ assertEquals(out.get(5).doubleValue(), 6.0, 0.0);
+
+ List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
+ assertEquals(mout.size(), 2);
+ List<Number> row1 = mout.get(0);
+ assertEquals(row1.size(), 6);
+ assertEquals(row1.get(0).doubleValue(), -8.0, 0.0);
+ assertEquals(row1.get(1).doubleValue(), -7.0, 0.0);
+ assertEquals(row1.get(2).doubleValue(), -6.0, 0.0);
+ assertEquals(row1.get(3).doubleValue(), -5.0, 0.0);
+ assertEquals(row1.get(4).doubleValue(), -4.0, 0.0);
+ assertEquals(row1.get(5).doubleValue(), -3.0, 0.0);
+
+ List<Number> row2 = mout.get(1);
+ assertEquals(row2.size(), 6);
+ assertEquals(row2.get(0).doubleValue(), -99.0, 0.0);
+ assertEquals(row2.get(1).doubleValue(), -198.0, 0.0);
+ assertEquals(row2.get(2).doubleValue(), -297.0, 0.0);
+ assertEquals(row2.get(3).doubleValue(), -396.0, 0.0);
+ assertEquals(row2.get(4).doubleValue(), -495.0, 0.0);
+ assertEquals(row2.get(5).doubleValue(), -594.0, 0.0);
}
-
@Test
public void testMatrixMult() throws Exception {
String cexpr = "let(echo=true," +
@@ -7341,7 +7370,7 @@ public class StreamExpressionTest extends SolrCloudTestCase {
}
@Test
- public void testEBEMultiply() throws Exception {
+ public void testEbeMultiply() throws Exception {
String cexpr = "ebeMultiply(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@@ -7364,8 +7393,16 @@ public class StreamExpressionTest extends SolrCloudTestCase {
@Test
- public void testEBEAdd() throws Exception {
- String cexpr = "ebeAdd(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
+ public void testEbeAdd() throws Exception {
+ String cexpr = "let(echo=true," +
+ " a=array(2, 4, 6, 8, 10, 12)," +
+ " b=array(1, 2, 3, 4, 5, 6)," +
+ " c=ebeAdd(a,b)," +
+ " d=array(10, 11, 12, 13, 14, 15)," +
+ " e=array(100, 200, 300, 400, 500, 600)," +
+ " f=matrix(a, b)," +
+ " g=matrix(d, e)," +
+ " h=ebeAdd(f, g))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@@ -7375,19 +7412,39 @@ public class StreamExpressionTest extends SolrCloudTestCase {
solrStream.setStreamContext(context);
List<Tuple> tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
- List<Number> out = (List<Number>)tuples.get(0).get("return-value");
- assertTrue(out.size() == 6);
- assertTrue(out.get(0).intValue() == 3);
- assertTrue(out.get(1).intValue() == 6);
- assertTrue(out.get(2).intValue() == 9);
- assertTrue(out.get(3).intValue() == 12);
- assertTrue(out.get(4).intValue() == 15);
- assertTrue(out.get(5).intValue() == 18);
+ List<Number> out = (List<Number>)tuples.get(0).get("c");
+ assertEquals(out.size(), 6);
+ assertEquals(out.get(0).doubleValue(), 3.0, 0.0);
+ assertEquals(out.get(1).doubleValue(), 6.0, 0.0);
+ assertEquals(out.get(2).doubleValue(), 9.0, 0.0);
+ assertEquals(out.get(3).doubleValue(), 12.0, 0.0);
+ assertEquals(out.get(4).doubleValue(), 15.0, 0.0);
+ assertEquals(out.get(5).doubleValue(), 18.0, 0.0);
+
+ List<List<Number>> mout = (List<List<Number>>)tuples.get(0).get("h");
+ assertEquals(mout.size(), 2);
+ List<Number> row1 = mout.get(0);
+ assertEquals(row1.size(), 6);
+ assertEquals(row1.get(0).doubleValue(), 12.0, 0.0);
+ assertEquals(row1.get(1).doubleValue(), 15.0, 0.0);
+ assertEquals(row1.get(2).doubleValue(), 18.0, 0.0);
+ assertEquals(row1.get(3).doubleValue(), 21.0, 0.0);
+ assertEquals(row1.get(4).doubleValue(), 24.0, 0.0);
+ assertEquals(row1.get(5).doubleValue(), 27.0, 0.0);
+
+ List<Number> row2 = mout.get(1);
+ assertEquals(row2.size(), 6);
+ assertEquals(row2.get(0).doubleValue(), 101.0, 0.0);
+ assertEquals(row2.get(1).doubleValue(), 202.0, 0.0);
+ assertEquals(row2.get(2).doubleValue(), 303.0, 0.0);
+ assertEquals(row2.get(3).doubleValue(), 404.0, 0.0);
+ assertEquals(row2.get(4).doubleValue(), 505.0, 0.0);
+ assertEquals(row2.get(5).doubleValue(), 606.0, 0.0);
}
@Test
- public void testEBEDivide() throws Exception {
+ public void testEbeDivide() throws Exception {
String cexpr = "ebeDivide(array(2,4,6,8,10,12),array(1,2,3,4,5,6))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);