You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2018/08/09 01:16:38 UTC

lucene-solr:branch_7x: SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression

Repository: lucene-solr
Updated Branches:
  refs/heads/branch_7x 13b9e28f9 -> 19647d802


SOLR-11863: Add knnRegress Stream Evaluator to support nearest neighbor regression


Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/19647d80
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/19647d80
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/19647d80

Branch: refs/heads/branch_7x
Commit: 19647d8023f23de6acbdc21426bf32afd8ffc1b1
Parents: 13b9e28f
Author: Joel Bernstein <jb...@apache.org>
Authored: Wed Aug 8 21:05:02 2018 -0400
Committer: Joel Bernstein <jb...@apache.org>
Committed: Wed Aug 8 21:11:50 2018 -0400

----------------------------------------------------------------------
 .../org/apache/solr/client/solrj/io/Lang.java   |   1 +
 .../solr/client/solrj/io/eval/KnnEvaluator.java |  18 +-
 .../solrj/io/eval/KnnRegressionEvaluator.java   | 194 +++++++++++++++++++
 .../solrj/io/eval/MinMaxScaleEvaluator.java     |   2 +-
 .../client/solrj/io/eval/PredictEvaluator.java  |  30 ++-
 .../apache/solr/client/solrj/io/TestLang.java   |   2 +-
 .../solrj/io/stream/MathExpressionTest.java     |  71 +++++++
 7 files changed, 311 insertions(+), 7 deletions(-)
----------------------------------------------------------------------


http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
index a01a841..6f170c4 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/Lang.java
@@ -246,6 +246,7 @@ public class Lang {
         .withFunctionName("zeros", ZerosEvaluator.class)
         .withFunctionName("getValue", GetValueEvaluator.class)
         .withFunctionName("setValue", SetValueEvaluator.class)
+        .withFunctionName("knnRegress", KnnRegressionEvaluator.class)
 
         // Boolean Stream Evaluators
 

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
index 251e092..81607cf 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
@@ -67,8 +67,6 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
       throw new IOException("The third parameter for knn should be k.");
     }
 
-    double[][] data = matrix.getData();
-
     DistanceMeasure distanceMeasure = null;
 
     if(values.length == 4) {
@@ -77,6 +75,15 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
       distanceMeasure = new EuclideanDistance();
     }
 
+    return search(matrix, vec, k, distanceMeasure);
+  }
+
+  public static Matrix search(Matrix observations,
+                              double[] vec,
+                              int k,
+                              DistanceMeasure distanceMeasure) {
+
+    double[][] data = observations.getData();
     TreeSet<Neighbor> neighbors = new TreeSet();
     for(int i=0; i<data.length; i++) {
       double distance = distanceMeasure.compute(vec, data[i]);
@@ -87,8 +94,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
     }
 
     double[][] out = new double[neighbors.size()][];
-    List<String> rowLabels = matrix.getRowLabels();
+    List<String> rowLabels = observations.getRowLabels();
     List<String> newRowLabels = new ArrayList();
+    List<Number> indexes = new ArrayList();
     List<Number> distances = new ArrayList();
     int i=-1;
 
@@ -102,6 +110,7 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
 
       out[++i] = data[rowIndex];
       distances.add(neighbor.getDistance());
+      indexes.add(rowIndex);
     }
 
     Matrix knn = new Matrix(out);
@@ -110,8 +119,9 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
       knn.setRowLabels(newRowLabels);
     }
 
-    knn.setColumnLabels(matrix.getColumnLabels());
+    knn.setColumnLabels(observations.getColumnLabels());
     knn.setAttribute("distances", distances);
+    knn.setAttribute("indexes", indexes);
     return knn;
   }
 

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
new file mode 100644
index 0000000..957936e
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -0,0 +1,226 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.List;
+import java.util.Map;
+import java.util.HashMap;
+
+import org.apache.commons.math3.linear.Array2DRowRealMatrix;
+import org.apache.commons.math3.ml.distance.DistanceMeasure;
+import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Implements the {@code knnRegress} stream evaluator, which builds a k-nearest-neighbor
+ * regression model from an observation matrix (one row per observation, one column per
+ * feature), an outcomes array (one value per observation row) and {@code k}. An optional
+ * fourth parameter supplies the {@link DistanceMeasure}; Euclidean distance is used when
+ * it is omitted.
+ *
+ * <p>The returned {@link KnnRegressionTuple} is consumed by the {@code predict} evaluator.
+ */
+public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
+    super(expression, factory);
+  }
+
+  /**
+   * Validates the parameters and builds the regression model.
+   *
+   * @param values the observation {@link Matrix}, the outcomes list, k and an optional
+   *               {@link DistanceMeasure}
+   * @return a {@link KnnRegressionTuple} holding the model
+   * @throws IOException if a parameter is missing or has the wrong type, k is not positive,
+   *                     or the number of outcomes differs from the number of observation rows
+   */
+  @Override
+  public Object doWork(Object ... values) throws IOException {
+    if(values.length < 3) {
+      throw new IOException("knnRegress expects atleast three parameters: an observation matrix, an outcomes vector and k.");
+    }
+
+    if(!(values[0] instanceof Matrix)) {
+      throw new IOException("The first parameter for knnRegress should be the observation matrix.");
+    }
+    Matrix observations = (Matrix) values[0];
+
+    if(!(values[1] instanceof List)) {
+      throw new IOException("The second parameter for knnRegress should be outcome array. ");
+    }
+    List<?> outcomes = (List<?>) values[1];
+
+    if(!(values[2] instanceof Number)) {
+      throw new IOException("The third parameter for knnRegress should be k. ");
+    }
+    int k = ((Number) values[2]).intValue();
+    if(k < 1) {
+      throw new IOException("The third parameter for knnRegress should be a positive k but was " + k + ".");
+    }
+
+    DistanceMeasure distanceMeasure = new EuclideanDistance();
+    if(values.length == 4) {
+      // Reject a fourth parameter that is not a DistanceMeasure rather than silently ignoring it.
+      if(!(values[3] instanceof DistanceMeasure)) {
+        throw new IOException("The fourth parameter for knnRegress should be a distance measure. ");
+      }
+      distanceMeasure = (DistanceMeasure) values[3];
+    }
+
+    // predict() looks outcomes up by observation row index, so the counts must match.
+    if(outcomes.size() != observations.getRowCount()) {
+      throw new IOException("knnRegress expects one outcome per observation row but found "
+          + outcomes.size() + " outcomes for " + observations.getRowCount() + " observations.");
+    }
+
+    double[] outcomeData = new double[outcomes.size()];
+    for(int i=0; i<outcomeData.length; i++) {
+      outcomeData[i] = ((Number) outcomes.get(i)).doubleValue();
+    }
+
+    Map<String, Object> map = new HashMap<>();
+    map.put("k", k);
+    map.put("observations", observations.getRowCount());
+    map.put("features", observations.getColumnCount());
+    map.put("distance", distanceMeasure.getClass().getSimpleName());
+
+    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
+  }
+
+  /**
+   * The knn regression model. Its tuple fields describe the model (k, observation count,
+   * feature count and distance measure class name); the model data is held privately.
+   *
+   * <p>One of the {@code scale} methods must be called before {@link #predict(double[])}:
+   * scaling min-max normalizes each feature column of the observations together with the
+   * predictors and caches the scaled observations that {@code predict} searches.
+   */
+  public static class KnnRegressionTuple extends Tuple {
+
+    private Matrix observations;
+    private Matrix scaledObservations;
+    private double[] outcomes;
+    private int k;
+    private DistanceMeasure distanceMeasure;
+
+    public KnnRegressionTuple(Matrix observations,
+                              double[] outcomes,
+                              int k,
+                              DistanceMeasure distanceMeasure,
+                              Map<?,?> map) {
+      super(map);
+      this.observations = observations;
+      this.outcomes = outcomes;
+      this.k = k;
+      this.distanceMeasure = distanceMeasure;
+    }
+
+    /**
+     * Min-max scales a single predictor vector jointly with the observations.
+     *
+     * @param predictors one value per feature column
+     * @return the scaled predictors; the scaled observations are cached for {@link #predict}
+     * @throws IllegalArgumentException if the predictor count differs from the feature count
+     */
+    public double[] scale(double[] predictors) {
+      return scaleJointly(new double[][] {predictors})[0];
+    }
+
+    /**
+     * Min-max scales a matrix of predictor rows jointly with the observations.
+     *
+     * @param predictors one row per prediction, one column per feature
+     * @return the scaled predictor rows; the scaled observations are cached for {@link #predict}
+     * @throws IllegalArgumentException if a row's length differs from the feature count
+     */
+    public Matrix scale(Matrix predictors) {
+      return new Matrix(scaleJointly(predictors.getData()));
+    }
+
+    /**
+     * Scales each feature column by passing the observation values and the predictor values
+     * of that column, combined, to {@code MinMaxScaleEvaluator.scale(values, 0, 1)}, so
+     * observations and predictors share one scale. Caches the scaled observations in
+     * {@code scaledObservations} and returns the scaled predictor rows. Both inputs are
+     * copied into new matrices first, so neither is modified.
+     */
+    private double[][] scaleJointly(double[][] predictorRows) {
+      int featureCount = observations.getColumnCount();
+      for(double[] row : predictorRows) {
+        if(row.length != featureCount) {
+          throw new IllegalArgumentException("knnRegress expects " + featureCount
+              + " predictors per prediction but found " + row.length + ".");
+        }
+      }
+
+      int observationCount = observations.getRowCount();
+      int predictorCount = predictorRows.length;
+
+      // Transpose so that each row holds every value of one feature.
+      double[][] observationColumns = new Array2DRowRealMatrix(observations.getData()).transpose().getData();
+      double[][] predictorColumns = new Array2DRowRealMatrix(predictorRows).transpose().getData();
+
+      // NOTE(review): MinMaxScaleEvaluator.scale seeds its running max with Double.MIN_VALUE
+      // (the smallest positive double), so a column whose values are all <= 0 may be
+      // mis-scaled; confirm and fix it there.
+      for(int col=0; col<featureCount; col++) {
+        double[] combined = new double[observationCount + predictorCount];
+        System.arraycopy(observationColumns[col], 0, combined, 0, observationCount);
+        System.arraycopy(predictorColumns[col], 0, combined, observationCount, predictorCount);
+
+        double[] scaled = MinMaxScaleEvaluator.scale(combined, 0, 1);
+        System.arraycopy(scaled, 0, observationColumns[col], 0, observationCount);
+        System.arraycopy(scaled, observationCount, predictorColumns[col], 0, predictorCount);
+      }
+
+      this.scaledObservations = new Matrix(new Array2DRowRealMatrix(observationColumns).transpose().getData());
+      return new Array2DRowRealMatrix(predictorColumns).transpose().getData();
+    }
+
+    /**
+     * Predicts an outcome as the unweighted mean of the outcomes of the nearest scaled
+     * observations that {@link KnnEvaluator#search} returns for this model's k.
+     *
+     * @param values predictors already scaled by one of the {@code scale} methods
+     * @return the mean outcome of the nearest observations
+     * @throws IllegalStateException if no {@code scale} method has been called yet
+     */
+    public double predict(double[] values) {
+      if(scaledObservations == null) {
+        throw new IllegalStateException("scale must be called before predict on a knnRegress model.");
+      }
+
+      Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
+      // The "indexes" attribute holds the observation row numbers of the nearest neighbors.
+      @SuppressWarnings("unchecked")
+      List<Number> indexes = (List<Number>) knn.getAttribute("indexes");
+
+      double sum = 0;
+      for(Number index : indexes) {
+        sum += outcomes[index.intValue()];
+      }
+      return sum / indexes.size();
+    }
+  }
+}
+

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
index 60c6377..3996910 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/MinMaxScaleEvaluator.java
@@ -76,7 +76,7 @@ public class MinMaxScaleEvaluator extends RecursiveObjectEvaluator implements Ma
     }
   }
 
-  private double[] scale(double[] values, double min, double max) {
+  public static double[] scale(double[] values, double min, double max) {
 
     double localMin = Double.MAX_VALUE;
     double localMax = Double.MIN_VALUE;

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index 2444370..9385928 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -43,7 +43,11 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
     Object first = objects[0];
     Object second = objects[1];
 
-    if (!(first instanceof BivariateFunction) && !(first instanceof VectorFunction) && !(first instanceof RegressionEvaluator.RegressionTuple) && !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple)) {
+    if (!(first instanceof BivariateFunction) &&
+        !(first instanceof VectorFunction) &&
+        !(first instanceof RegressionEvaluator.RegressionTuple) &&
+        !(first instanceof OLSRegressionEvaluator.MultipleRegressionTuple) &&
+        !(first instanceof KnnRegressionEvaluator.KnnRegressionTuple)) {
       throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the first value, expecting a RegressionTuple", toExpression(constructingFactory), first.getClass().getSimpleName()));
     }
 
@@ -83,6 +87,30 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
         return predictions;
       }
 
+    } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
+      KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
+      if (second instanceof List) {
+        List<Number> list = (List<Number>) second;
+        double[] predictors = new double[list.size()];
+
+        for (int i = 0; i < list.size(); i++) {
+          predictors[i] = list.get(i).doubleValue();
+        }
+
+        predictors = regressedTuple.scale(predictors);
+
+        return regressedTuple.predict(predictors);
+      } else if (second instanceof Matrix) {
+
+        Matrix m = (Matrix) second;
+        m = regressedTuple.scale(m);
+        double[][] data = m.getData();
+        List<Number> predictions = new ArrayList();
+        for (double[] predictors : data) {
+          predictions.add(regressedTuple.predict(predictors));
+        }
+        return predictions;
+      }
     } else if (first instanceof VectorFunction) {
       VectorFunction vectorFunction = (VectorFunction) first;
       UnivariateFunction univariateFunction = (UnivariateFunction)vectorFunction.getFunction();

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
index 22b432f..df56844 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/TestLang.java
@@ -69,7 +69,7 @@ public class TestLang extends LuceneTestCase {
        TemporalEvaluatorDayOfQuarter.FUNCTION_NAME, "abs", "add", "div", "mult", "sub", "log", "pow",
       "mod", "ceil", "floor", "sin", "asin", "sinh", "cos", "acos", "cosh", "tan", "atan", "tanh", "round", "sqrt",
       "cbrt", "coalesce", "uuid", "if", "convert", "valueAt", "memset", "fft", "ifft", "euclidean","manhattan",
-      "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue"};
+      "earthMovers", "canberra", "chebyshev", "ones", "zeros", "setValue", "getValue", "knnRegress"};
 
   @Test
   public void testLang() {

http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/19647d80/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index 98a52a6..a9be57e 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -3395,6 +3395,77 @@ public class MathExpressionTest extends SolrCloudTestCase {
   }
 
   @Test
+  public void testKnnRegress() throws Exception {
+    String cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
+                                  "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
+                                  "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
+                                  "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 1)," +
+        "g=predict(f, e))";
+    ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    TupleStream solrStream = new SolrStream(url, paramsLoc);
+    StreamContext context = new StreamContext();
+    solrStream.setStreamContext(context);
+    List<Tuple> tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    List<Number> predictions = (List<Number>)tuples.get(0).get("g");
+    assertEquals(10, predictions.size());
+    // With k=1 each training row is its own nearest neighbor, so each prediction is that row's own outcome.
+    assertEquals(85.09999847, predictions.get(0).doubleValue(), 0);
+    assertEquals(106.3000031, predictions.get(1).doubleValue(), 0);
+    assertEquals(50.20000076, predictions.get(2).doubleValue(), 0);
+    assertEquals(130.6000061, predictions.get(3).doubleValue(), 0);
+    assertEquals(54.79999924, predictions.get(4).doubleValue(), 0);
+    assertEquals(30.29999924, predictions.get(5).doubleValue(), 0);
+    assertEquals(79.40000153, predictions.get(6).doubleValue(), 0);
+    assertEquals(91, predictions.get(7).doubleValue(), 0);
+    assertEquals(135.3999939, predictions.get(8).doubleValue(), 0);
+    assertEquals(89.30000305, predictions.get(9).doubleValue(), 0);
+
+    // k=1 with a single predictor vector: (8, 5, 4) is nearest to the first training row, so its outcome is returned.
+    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 10.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 7.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
+        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 1)," +
+        "g=predict(f, array(8, 5, 4)))";
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    Number prediction = (Number)tuples.get(0).get("g");
+    assertEquals(85.09999847, prediction.doubleValue(), 0);
+
+    // k=2 with the last row moved close to (8, 5, 4): the prediction is (85.09999847 + 89.30000305) / 2.
+    cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
+        "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
+        "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
+        "d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
+        "e=transpose(matrix(a, b, c))," +
+        "f=knnRegress(e, d, 2)," +
+        "g=predict(f, array(8, 5, 4)))";
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertEquals(1, tuples.size());
+    prediction = (Number)tuples.get(0).get("g");
+    assertEquals(87.20000076, prediction.doubleValue(), 0);
+  }
+
+  @Test
   public void testPlot() throws Exception {
     String cexpr = "let(a=array(3,2,3), plot(type=scatter, x=a, y=array(5,6,3)))";
     ModifiableSolrParams paramsLoc = new ModifiableSolrParams();