You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2018/08/09 01:07:45 UTC
lucene-solr:master: SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression
Repository: lucene-solr
Updated Branches:
refs/heads/master e9f3a3ce1 -> cb1db4825
SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/cb1db482
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/cb1db482
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/cb1db482
Branch: refs/heads/master
Commit: cb1db482523cf33b7927b5155d506202d8ddbd89
Parents: e9f3a3c
Author: Joel Bernstein <jb...@apache.org>
Authored: Wed Aug 8 21:05:02 2018 -0400
Committer: Joel Bernstein <jb...@apache.org>
Committed: Wed Aug 8 21:05:21 2018 -0400
----------------------------------------------------------------------
.../org/apache/solr/client/solrj/io/Lang.java | 1 +
.../solr/client/solrj/io/eval/KnnEvaluator.java | 18 +-
.../solrj/io/eval/KnnRegressionEvaluator.java | 194 +++++++++++++++++++
.../solrj/io/eval/MinMaxScaleEvaluator.java | 2 +-
.../client/solrj/io/eval/PredictEvaluator.java | 30 ++-
.../apache/solr/client/solrj/io/TestLang.java | 2 +-
.../solrj/io/stream/MathExpressionTest.java | 71 +++++++
7 files changed, 311 insertions(+), 7 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
index a01a841..6f170c4 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
@@ -246,6 +246,7 @@ public class Lang {
.withFunctionName("zeros", ZerosEvaluator.class)
.withFunctionName("getValue", GetValueEvaluator.class)
.withFunctionName("setValue", SetValueEvaluator.class)
+ .withFunctionName("knnRegress", KnnRegressionEvaluator.class)
// Boolean Stream Evaluators
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
index 251e092..81607cf 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
@@ -67,8 +67,6 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
throw new IOException("The third parameter for knn should be k.");
}
- double[][] data = matrix.getData();
-
DistanceMeasure distanceMeasure = null;
if(values.length == 4) {
@@ -77,6 +75,15 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
distanceMeasure = new EuclideanDistance();
}
+ return search(matrix, vec, k, distanceMeasure);
+ }
+
+ public static Matrix search(Matrix observations,
+ double[] vec,
+ int k,
+ DistanceMeasure distanceMeasure) {
+
+ double[][] data = observations.getData();
TreeSet<Neighbor> neighbors = new TreeSet();
for(int i=0; i<data.length; i++) {
double distance = distanceMeasure.compute(vec, data[i]);
@@ -87,8 +94,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
}
double[][] out = new double[neighbors.size()][];
- List<String> rowLabels = matrix.getRowLabels();
+ List<String> rowLabels = observations.getRowLabels();
List<String> newRowLabels = new ArrayList();
+ List<Number> indexes = new ArrayList();
List<Number> distances = new ArrayList();
int i=-1;
@@ -102,6 +110,7 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
out[++i] = data[rowIndex];
distances.add(neighbor.getDistance());
+ indexes.add(rowIndex);
}
Matrix knn = new Matrix(out);
@@ -110,8 +119,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
knn.setRowLabels(newRowLabels);
}
- knn.setColumnLabels(matrix.getColumnLabels());
+ knn.setColumnLabels(observations.getColumnLabels());
knn.setAttribute("distances", distances);
+ knn.setAttribute("indexes", indexes);
return knn;
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
new file mode 100644
index 0000000..957936e
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -0,0 +1,194 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Builds a nearest-neighbor regression model.
+   *
+   * <p>Parameters: {@code knnRegress(observations, outcomes, k [, distanceMeasure])}, where
+   * {@code observations} is a {@link Matrix} with one row per observation, {@code outcomes} is a
+   * list of numbers with one entry per observation row, {@code k} is the number of nearest
+   * neighbors whose outcomes are averaged, and the optional distance measure defaults to
+   * {@link EuclideanDistance}.
+   *
+   * @return a {@link KnnRegressionTuple} holding the model, for use with {@code predict}
+   * @throws IOException if a parameter is missing or has the wrong type, if {@code k} is less
+   *     than 1, if an outcome is not numeric, or if the number of outcomes differs from the
+   *     number of observation rows
+   */
+  @Override
+  public Object doWork(Object ... values) throws IOException {
+
+    if(values.length < 3) {
+      throw new IOException("knnRegress expects atleast three parameters: an observation matrix, an outcomes vector and k.");
+    }
+
+    if(!(values[0] instanceof Matrix)) {
+      throw new IOException("The first parameter for knnRegress should be the observation matrix.");
+    }
+    Matrix observations = (Matrix) values[0];
+
+    if(!(values[1] instanceof List)) {
+      throw new IOException("The second parameter for knnRegress should be outcome array. ");
+    }
+    List<?> outcomes = (List<?>) values[1];
+
+    if(!(values[2] instanceof Number)) {
+      throw new IOException("The third parameter for knnRegress should be k. ");
+    }
+    int k = ((Number) values[2]).intValue();
+    if(k < 1) {
+      // With no neighbors the prediction would be 0/0 = NaN.
+      throw new IOException("The k parameter for knnRegress must be at least 1, found " + k + ".");
+    }
+
+    DistanceMeasure distanceMeasure = new EuclideanDistance();
+    if(values.length == 4) {
+      // Only reject the optional fourth parameter when it is not a distance measure.
+      if(!(values[3] instanceof DistanceMeasure)) {
+        throw new IOException("The fourth parameter for knnRegress should be a distance measure. ");
+      }
+      distanceMeasure = (DistanceMeasure) values[3];
+    }
+
+    // predict() indexes outcomes by observation row, so the sizes must agree.
+    if(outcomes.size() != observations.getRowCount()) {
+      throw new IOException("knnRegress expects one outcome per observation row, found "
+          + outcomes.size() + " outcomes and " + observations.getRowCount() + " observations.");
+    }
+
+    double[] outcomeData = new double[outcomes.size()];
+    for(int i=0; i<outcomeData.length; i++) {
+      Object outcome = outcomes.get(i);
+      if(!(outcome instanceof Number)) {
+        throw new IOException("The outcomes for knnRegress should be numeric, found " + outcome + " at index " + i + ".");
+      }
+      outcomeData[i] = ((Number) outcome).doubleValue();
+    }
+
+    Map<String, Object> map = new HashMap<>();
+    map.put("k", k);
+    map.put("observations", observations.getRowCount());
+    map.put("features", observations.getColumnCount());
+    map.put("distance", distanceMeasure.getClass().getSimpleName());
+
+    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
+  }
+
+
+  /**
+   * Tuple holding a nearest-neighbor regression model: the training observations, their
+   * outcomes, k and the distance measure, plus the metadata map built by {@code doWork}.
+   *
+   * <p>Each {@code scale} call min-max scales the training observations together with the
+   * supplied predictors and stores the scaled observations on this instance. {@code predict}
+   * searches those stored scaled observations, so {@code scale} must be called first. Because
+   * {@code scale} writes instance state without synchronization, an instance should not be
+   * shared across concurrent predictions.
+   */
+  public static class KnnRegressionTuple extends Tuple {
+
+    private final Matrix observations;
+    private Matrix scaledObservations;
+    private final double[] outcomes;
+    private final int k;
+    private final DistanceMeasure distanceMeasure;
+
+    public KnnRegressionTuple(Matrix observations,
+                              double[] outcomes,
+                              int k,
+                              DistanceMeasure distanceMeasure,
+                              Map<?,?> map) {
+      super(map);
+      this.observations = observations;
+      this.outcomes = outcomes;
+      this.k = k;
+      this.distanceMeasure = distanceMeasure;
+    }
+
+    /**
+     * Min-max scales every feature column of the training observations together with the
+     * matching predictor value to [0, 1]. The scaled observations are stored for
+     * {@link #predict(double[])}.
+     *
+     * @param predictors one value per feature column
+     * @return the scaled predictors
+     * @throws IllegalArgumentException if the predictor count differs from the feature count
+     */
+    public double[] scale(double[] predictors) {
+      double[][] featureColumns = transpose(observations.getData());
+      if(predictors.length != featureColumns.length) {
+        throw new IllegalArgumentException("Expected " + featureColumns.length
+            + " predictors for knnRegress but found " + predictors.length + ".");
+      }
+
+      double[] scaledPredictors = new double[predictors.length];
+      for(int i=0; i<featureColumns.length; i++) {
+        double[] predictor = {predictors[i]};
+        scaleColumn(featureColumns[i], predictor);
+        scaledPredictors[i] = predictor[0];
+      }
+
+      this.scaledObservations = new Matrix(transpose(featureColumns));
+      return scaledPredictors;
+    }
+
+    /**
+     * Min-max scales every feature column of the training observations together with the same
+     * column of the predictor matrix to [0, 1]. The scaled observations are stored for
+     * {@link #predict(double[])}.
+     *
+     * @param predictors a matrix with one row per prediction and one column per feature
+     * @return the scaled predictor matrix
+     * @throws IllegalArgumentException if the predictor column count differs from the feature count
+     */
+    public Matrix scale(Matrix predictors) {
+      double[][] observationColumns = transpose(observations.getData());
+      double[][] predictorColumns = transpose(predictors.getData());
+      if(predictorColumns.length != observationColumns.length) {
+        throw new IllegalArgumentException("Expected " + observationColumns.length
+            + " predictor columns for knnRegress but found " + predictorColumns.length + ".");
+      }
+
+      for(int i=0; i<observationColumns.length; i++) {
+        scaleColumn(observationColumns[i], predictorColumns[i]);
+      }
+
+      this.scaledObservations = new Matrix(transpose(observationColumns));
+      return new Matrix(transpose(predictorColumns));
+    }
+
+    /**
+     * Predicts an outcome as the mean outcome of the k nearest scaled training observations.
+     *
+     * @param values predictor values, already scaled by one of the {@code scale} methods
+     * @return the average outcome of the nearest neighbors
+     * @throws IllegalStateException if {@code scale} has not been called yet
+     */
+    @SuppressWarnings("unchecked")
+    public double predict(double[] values) {
+      if(scaledObservations == null) {
+        throw new IllegalStateException("scale must be called before predict on a knnRegress model.");
+      }
+
+      Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
+      // KnnEvaluator.search records the source row of each neighbor under "indexes".
+      List<Number> indexes = (List<Number>) knn.getAttribute("indexes");
+
+      double sum = 0;
+      for(Number index : indexes) {
+        sum += outcomes[index.intValue()];
+      }
+      return sum / indexes.size();
+    }
+
+    // Min-max scales an observation column and its matching predictor values as one series,
+    // so both share the same min and max, and writes the scaled values back into both arrays.
+    private static void scaleColumn(double[] observationColumn, double[] predictorColumn) {
+      int n = observationColumn.length;
+      double[] combined = new double[n + predictorColumn.length];
+      System.arraycopy(observationColumn, 0, combined, 0, n);
+      System.arraycopy(predictorColumn, 0, combined, n, predictorColumn.length);
+      double[] scaled = MinMaxScaleEvaluator.scale(combined, 0, 1);
+      System.arraycopy(scaled, 0, observationColumn, 0, n);
+      System.arraycopy(scaled, n, predictorColumn, 0, predictorColumn.length);
+    }
+
+    // Returns the transpose of data in a new array; Array2DRowRealMatrix copies its input,
+    // so data itself is left untouched.
+    private static double[][] transpose(double[][] data) {
+      Array2DRowRealMatrix matrix = new Array2DRowRealMatrix(data);
+      return ((Array2DRowRealMatrix) matrix.transpose()).getDataRef();
+    }
+  }
+}
+
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
index 60c6377..3996910 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
@@ -76,7 +76,7 @@ public class MinMaxScaleEvaluator extends RecursiveObjectEvaluator implements Ma
}
}
- private double[] scale(double[] values, double min, double max) {
+ public static double[] scale(double[] values, double min, double max) {
double localMin = Double.MAX_VALUE;
double localMax = Double.MIN_VALUE;
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index 2444370..9385928 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -43,7 +43,11 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
Object first = objects[0];
Object second = objects[1];
- if (!(first instanceof BivariateFunction) && !(first instanceof VectorFunction) && !(first instanceof RegressionEvaluator.RegressionTuple) && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)) {
+ if (!(first instanceof BivariateFunction) &&
+ !(first instanceof VectorFunction) &&
+ !(first instanceof RegressionEvaluator.RegressionTuple) &&
+ !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) &&
+ !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), first.getClass().getSimpleName()));
}
@@ -83,6 +87,30 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
return predictions;
}
+ } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
+ KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
+ if (second instanceof List) {
+ List<Number> list = (List<Number>) second;
+ double[] predictors = new double[list.size()];
+
+ for (int i = 0; i < list.size(); i++) {
+ predictors[i] = list.get(i).doubleValue();
+ }
+
+ predictors = regressedTuple.scale(predictors);
+
+ return regressedTuple.predict(predictors);
+ } else if (second instanceof Matrix) {
+
+ Matrix m = (Matrix) second;
+ m = regressedTuple.scale(m);
+ double[][] data = m.getData();
+ List<Number> predictions = new ArrayList();
+ for (double[] predictors : data) {
+ predictions.add(regressedTuple.predict(predictors));
+ }
+ return predictions;
+ }
} else if (first instanceof VectorFunction) {
VectorFunction vectorFunction = (VectorFunction) first;
UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction();
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
index 22b432f..df56844 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
@@ -69,7 +69,7 @@ public class TestLang extends LuceneTestCase {
TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
"mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
"cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
- "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue"};
+ "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue", "knnRegress"};
@Test
public void testLang() {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/cb1db482/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index 98a52a6..a9be57e 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -3395,6 +3395,77 @@ public class MathExpressionTest extends SolrCloudTestCase {
}
@Test
+  public void testKnnRegress() throws Exception {
+    String cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
+        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 1)," +
+        "g=predict(f, e))";
+    List<Number> predictions = (List<Number>) evalKnnRegressExpression(cexpr).get("g");
+    //k=1 should bring back only one prediction for the exact match in the training set
+    double[] expected = {85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924,
+        30.29999924, 79.40000153, 91, 135.3999939, 89.30000305};
+    assertEquals(expected.length, predictions.size());
+    for (int i = 0; i < expected.length; i++) {
+      assertEquals("prediction " + i, expected[i], predictions.get(i).doubleValue(), 0);
+    }
+
+    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
+        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 1)," +
+        "g=predict(f, array(8, 5, 4)))";
+    Number prediction = (Number) evalKnnRegressExpression(cexpr).get("g");
+    assertEquals(85.09999847, prediction.doubleValue(), 0);
+
+    // k=2: the prediction is the mean outcome of the two nearest observations.
+    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
+        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 2)," +
+        "g=predict(f, array(8, 5, 4)))";
+    prediction = (Number) evalKnnRegressExpression(cexpr).get("g");
+    assertEquals(87.20000076, prediction.doubleValue(), 0);
+  }
+
+  // Runs a streaming expression through the /stream handler of the first Jetty node, asserts
+  // that exactly one tuple comes back, and returns that tuple.
+  private Tuple evalKnnRegressExpression(String expr) throws Exception {
+    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", expr);
+    paramsLoc.set("qt", "/stream");
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    TupleStream solrStream = new SolrStream(url, paramsLoc);
+    StreamContext context = new StreamContext();
+    solrStream.setStreamContext(context);
+    List<Tuple> tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    return tuples.get(0);
+  }
+
+ @Test
public void testPlot() throws Exception {
String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();