You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2019/01/12 20:26:42 UTC
[lucene-solr] branch master updated: SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
This is an automated email from the ASF dual-hosted git repository.
jbernste pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
The following commit(s) were added to refs/heads/master by this push:
new 292e26b SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
292e26b is described below
commit 292e26bc2d149c82c2ac55c4396364d561ea55e4
Author: Joel Bernstein <jb...@apache.org>
AuthorDate: Sat Jan 12 15:25:45 2019 -0500
SOLR-13134: Allow the knnRegress Stream Evaluator to more easily perform bivariate regression
---
.../solrj/io/eval/KnnRegressionEvaluator.java | 21 ++++++--
.../client/solrj/io/eval/PredictEvaluator.java | 58 +++++++++++++++-------
.../client/solrj/io/stream/MathExpressionTest.java | 21 ++++++--
3 files changed, 75 insertions(+), 25 deletions(-)
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
index e298f45..e16e60e 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -64,11 +64,20 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
List<Number> outcomes = null;
int k = 5;
DistanceMeasure distanceMeasure = new EuclideanDistance();
+ boolean bivariate = false;
if(values[0] instanceof Matrix) {
observations = (Matrix)values[0];
+ } else if(values[0] instanceof List) {
+ bivariate = true;
+ List<Number> vec = (List<Number>)values[0];
+ double[][] data = new double[vec.size()][1];
+ for(int i=0; i<vec.size(); i++) {
+ data[i][0] = vec.get(i).doubleValue();
+ }
+ observations = new Matrix(data);
} else {
- throw new IOException("The first parameter for knnRegress should be the observation matrix.");
+ throw new IOException("The first parameter for knnRegress should be the observation vector or matrix.");
}
if(values[1] instanceof List) {
@@ -104,7 +113,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
map.put("robust", robust);
map.put("scale", scale);
- return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
+ return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust, bivariate);
}
@@ -117,6 +126,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
private DistanceMeasure distanceMeasure;
private boolean scale;
private boolean robust;
+ private boolean bivariate;
public KnnRegressionTuple(Matrix observations,
double[] outcomes,
@@ -124,7 +134,8 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
DistanceMeasure distanceMeasure,
Map<?,?> map,
boolean scale,
- boolean robust) {
+ boolean robust,
+ boolean bivariate) {
super(map);
this.observations = observations;
this.outcomes = outcomes;
@@ -132,11 +143,15 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
this.distanceMeasure = distanceMeasure;
this.scale = scale;
this.robust = robust;
+ this.bivariate = bivariate;
}
public boolean getScale() {
return this.scale;
}
+ public boolean getBivariate() {
+ return this.bivariate;
+ }
//MinMax Scale both the observations and the predictors
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index c8e83ba..3d87687 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -89,31 +89,51 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
} else if (first instanceof KnnRegressionEvaluator.KnnRegressionTuple) {
KnnRegressionEvaluator.KnnRegressionTuple regressedTuple = (KnnRegressionEvaluator.KnnRegressionTuple) first;
- if (second instanceof List) {
- List<Number> list = (List<Number>) second;
- double[] predictors = new double[list.size()];
- for (int i = 0; i < list.size(); i++) {
- predictors[i] = list.get(i).doubleValue();
+ if(regressedTuple.getBivariate()) {
+ //Handle bi-variate regression
+ if(second instanceof Number) {
+ double[] predictors = new double[1];
+ predictors[0] = ((Number)second).doubleValue();
+ return regressedTuple.predict(predictors);
+ } else if(second instanceof List) {
+ List<Number> vec = (List<Number>)second;
+ List<Number> predictions = new ArrayList();
+ for(Number num : vec) {
+ double[] predictors = new double[1];
+ predictors[0] = num.doubleValue();
+ predictions.add(regressedTuple.predict(predictors));
+ }
+ return predictions;
}
+ } else {
+ //Handle multi-variate regression
+ if (second instanceof List) {
+ List<Number> list = (List<Number>) second;
+ double[] predictors = new double[list.size()];
- if(regressedTuple.getScale()) {
- predictors = regressedTuple.scale(predictors);
- }
+ for (int i = 0; i < list.size(); i++) {
+ predictors[i] = list.get(i).doubleValue();
+ }
- return regressedTuple.predict(predictors);
- } else if (second instanceof Matrix) {
+ if (regressedTuple.getScale()) {
+ predictors = regressedTuple.scale(predictors);
+ }
- Matrix m = (Matrix) second;
- if(regressedTuple.getScale()) {
- m = regressedTuple.scale(m);
- }
- double[][] data = m.getData();
- List<Number> predictions = new ArrayList();
- for (double[] predictors : data) {
- predictions.add(regressedTuple.predict(predictors));
+ return regressedTuple.predict(predictors);
+ } else if (second instanceof Matrix) {
+
+ Matrix m = (Matrix) second;
+ if (regressedTuple.getScale()) {
+ m = regressedTuple.scale(m);
+ }
+ double[][] data = m.getData();
+ List<Number> predictions = new ArrayList();
+ for (double[] predictors : data) {
+ predictions.add(regressedTuple.predict(predictors));
+ }
+ return predictions;
}
- return predictions;
}
} else if (first instanceof VectorFunction) {
VectorFunction vectorFunction = (VectorFunction) first;
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index 1e3c40b..6c61ad7 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -4136,9 +4136,8 @@ public class MathExpressionTest extends SolrCloudTestCase {
//Test univariate regression with scaling off
cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
- "b=transpose(matrix(a))," +
- "c=knnRegress(b, a, 3)," +
- "d=predict(c, array(3)))";
+ "c=knnRegress(a, a, 3)," +
+ "d=predict(c, 3))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
paramsLoc.set("qt", "/stream");
@@ -4151,6 +4150,22 @@ public class MathExpressionTest extends SolrCloudTestCase {
prediction = (Number)tuples.get(0).get("d");
assertEquals(prediction.doubleValue(), 3, 0);
+ cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
+ "c=knnRegress(a, a, 3)," +
+ "d=predict(c, array(3,4)))";
+ paramsLoc = new ModifiableSolrParams();
+ paramsLoc.set("expr", cexpr);
+ paramsLoc.set("qt", "/stream");
+ url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+ solrStream = new SolrStream(url, paramsLoc);
+ context = new StreamContext();
+ solrStream.setStreamContext(context);
+ tuples = getTuples(solrStream);
+ assertTrue(tuples.size() == 1);
+ predictions = (List<Number>)tuples.get(0).get("d");
+ assertEquals(predictions.size(), 2);
+ assertEquals(predictions.get(0).doubleValue(), 3, 0);
+ assertEquals(predictions.get(1).doubleValue(), 4, 0);
}
@Test