You are viewing a plain-text version of this content. The link to the canonical version was not preserved in this plain-text copy.
Posted to commits@lucene.apache.org by jb...@apache.org on 2019/01/12 20:26:42 UTC

[lucene-solr] branch master updated: SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression

This is an automated email from the ASF dual-hosted git repository.

jbernste pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/master by this push:
     new 292e26b  SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
292e26b is described below

commit 292e26bc2d149c82c2ac55c4396364d561ea55e4
Author: Joel Bernstein <jb...@apache.org>
AuthorDate: Sat Jan 12 15:25:45 2019 -0500

    SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
---
 .../solrj/io/eval/KnnRegressionEvaluator.java      | 21 ++++++--
 .../client/solrj/io/eval/PredictEvaluator.java     | 58 +++++++++++++++-------
 .../client/solrj/io/stream/MathExpressionTest.java | 21 ++++++--
 3 files changed, 75 insertions(+), 25 deletions(-)

diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
index e298f45..e16e60e 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -64,11 +64,20 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
     List<Number> outcomes = null;
     int k = 5;
     DistanceMeasure distanceMeasure = new EuclideanDistance();
+    boolean bivariate = false;
 
     if(values[0] instanceof Matrix) {
       observations = (Matrix)values[0];
+    } else if(values[0] instanceof List) {
+      bivariate = true;
+      List<Number> vec = (List<Number>)values[0];
+      double[][] data = new double[vec.size()][1];
+      for(int i=0; i<vec.size(); i++) {
+        data[i][0] = vec.get(i).doubleValue();
+      }
+      observations = new Matrix(data);
     } else {
-      throw new IOException("The first parameter for knnRegress should be the observation matrix.");
+      throw new IOException("The first parameter for knnRegress should be the observation vector or matrix.");
     }
 
     if(values[1] instanceof List) {
@@ -104,7 +113,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
     map.put("robust", robust);
     map.put("scale", scale);
 
-    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
+    return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust, bivariate);
   }
 
 
@@ -117,6 +126,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
     private DistanceMeasure distanceMeasure;
     private boolean scale;
     private boolean robust;
+    private boolean bivariate;
 
     public KnnRegressionTuple(Matrix observations,
                               double[] outcomes,
@@ -124,7 +134,8 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
                               DistanceMeasure distanceMeasure,
                               Map<?,?> map,
                               boolean scale,
-                              boolean robust) {
+                              boolean robust,
+                              boolean bivariate) {
       super(map);
       this.observations = observations;
       this.outcomes = outcomes;
@@ -132,11 +143,15 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
       this.distanceMeasure = distanceMeasure;
       this.scale = scale;
       this.robust = robust;
+      this.bivariate = bivariate;
     }
 
     public boolean getScale() {
       return this.scale;
     }
+    public boolean getBivariate() {
+      return this.bivariate;
+    }
 
     //MinMax Scale both the observations and the predictors
 
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index c8e83ba..3d87687 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -89,31 +89,51 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
 
     } else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
       KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
-      if (second instanceof List) {
-        List<Number> list = (List<Number>) second;
-        double[] predictors = new double[list.size()];
 
-        for (int i = 0; i < list.size(); i++) {
-          predictors[i] = list.get(i).doubleValue();
+      if(regressedTuple.getBivariate()) {
+        //Handle bi-variate regression
+        if(second instanceof Number) {
+          double[] predictors = new double[1];
+          predictors[0] = ((Number)second).doubleValue();
+          return regressedTuple.predict(predictors);
+        } else if(second instanceof List) {
+          List<Number> vec = (List<Number>)second;
+          List<Number> predictions = new ArrayList();
+          for(Number num : vec) {
+            double[] predictors = new double[1];
+            predictors[0] = num.doubleValue();
+            predictions.add(regressedTuple.predict(predictors));
+          }
+          return predictions;
         }
+      } else {
+        //Handle multi-variate regression
+        if (second instanceof List) {
+          List<Number> list = (List<Number>) second;
+          double[] predictors = new double[list.size()];
 
-        if(regressedTuple.getScale()) {
-          predictors = regressedTuple.scale(predictors);
-        }
+          for (int i = 0; i < list.size(); i++) {
+            predictors[i] = list.get(i).doubleValue();
+          }
 
-        return regressedTuple.predict(predictors);
-      } else if (second instanceof Matrix) {
+          if (regressedTuple.getScale()) {
+            predictors = regressedTuple.scale(predictors);
+          }
 
-        Matrix m = (Matrix) second;
-        if(regressedTuple.getScale()) {
-          m = regressedTuple.scale(m);
-        }
-        double[][] data = m.getData();
-        List<Number> predictions = new ArrayList();
-        for (double[] predictors : data) {
-          predictions.add(regressedTuple.predict(predictors));
+          return regressedTuple.predict(predictors);
+        } else if (second instanceof Matrix) {
+
+          Matrix m = (Matrix) second;
+          if (regressedTuple.getScale()) {
+            m = regressedTuple.scale(m);
+          }
+          double[][] data = m.getData();
+          List<Number> predictions = new ArrayList();
+          for (double[] predictors : data) {
+            predictions.add(regressedTuple.predict(predictors));
+          }
+          return predictions;
         }
-        return predictions;
       }
     } else if (first instanceof VectorFunction) {
       VectorFunction vectorFunction = (VectorFunction) first;
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index 1e3c40b..6c61ad7 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -4136,9 +4136,8 @@ public class MathExpressionTest extends SolrCloudTestCase {
     //Test univariate regression with scaling off
 
     cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
-        "b=transpose(matrix(a))," +
-        "c=knnRegress(b, a, 3)," +
-        "d=predict(c, array(3)))";
+        "c=knnRegress(a, a, 3)," +
+        "d=predict(c, 3))";
     paramsLoc = new ModifiableSolrParams();
     paramsLoc.set("expr", cexpr);
     paramsLoc.set("qt", "/stream");
@@ -4151,6 +4150,22 @@ public class MathExpressionTest extends SolrCloudTestCase {
     prediction = (Number)tuples.get(0).get("d");
     assertEquals(prediction.doubleValue(), 3, 0);
 
+    cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
+        "c=knnRegress(a, a, 3)," +
+        "d=predict(c, array(3,4)))";
+    paramsLoc = new ModifiableSolrParams();
+    paramsLoc.set("expr", cexpr);
+    paramsLoc.set("qt", "/stream");
+    url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+    solrStream = new SolrStream(url, paramsLoc);
+    context = new StreamContext();
+    solrStream.setStreamContext(context);
+    tuples = getTuples(solrStream);
+    assertTrue(tuples.size() == 1);
+    predictions = (List<Number>)tuples.get(0).get("d");
+    assertEquals(predictions.size(), 2);
+    assertEquals(predictions.get(0).doubleValue(), 3, 0);
+    assertEquals(predictions.get(1).doubleValue(), 4, 0);
   }
 
   @Test