You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by jb...@apache.org on 2018/08/17 19:49:09 UTC
lucene-solr:master: SOLR-12671: Add robust flag to knnRegress Stream Evaluator
Repository: lucene-solr
Updated Branches:
refs/heads/master 124be4e20 -> 52f9cee97
SOLR-12671: Add robust flag to knnRegress Stream Evaluator
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/52f9cee9
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/52f9cee9
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/52f9cee9
Branch: refs/heads/master
Commit: 52f9cee97b4f293af26de0e7b4ec534cb6b11b10
Parents: 124be4e
Author: Joel Bernstein <jb...@apache.org>
Authored: Fri Aug 17 14:26:05 2018 -0400
Committer: Joel Bernstein <jb...@apache.org>
Committed: Fri Aug 17 14:26:17 2018 -0400
----------------------------------------------------------------------
.../solr/client/solrj/io/eval/KnnEvaluator.java | 4 ++
.../solrj/io/eval/KnnRegressionEvaluator.java | 65 ++++++++++++++++----
.../client/solrj/io/eval/PredictEvaluator.java | 8 ++-
.../solrj/io/stream/MathExpressionTest.java | 36 ++++++++---
4 files changed, 92 insertions(+), 21 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
index 81607cf..17fb011 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnEvaluator.java
@@ -144,6 +144,10 @@ public class KnnEvaluator extends RecursiveObjectEvaluator implements ManyValueW
}
public int compareTo(Neighbor neighbor) {
+ if(this.distance.compareTo(neighbor.getDistance()) == 0) {
+ return row-neighbor.getRow();
+ }
+
return this.distance.compareTo(neighbor.getDistance());
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
index 957936e..e6f6d80 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/KnnRegressionEvaluator.java
@@ -25,15 +25,32 @@ import java.util.HashMap;
import org.apache.commons.math3.linear.Array2DRowRealMatrix;
import org.apache.commons.math3.ml.distance.DistanceMeasure;
import org.apache.commons.math3.ml.distance.EuclideanDistance;
+import org.apache.commons.math3.stat.descriptive.rank.Percentile;
import org.apache.solr.client.solrj.io.Tuple;
import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
protected static final long serialVersionUID = 1L;
+ private boolean robust=false;
+ private boolean scale=false;
+
public KnnRegressionEvaluator(StreamExpression expression, StreamFactory factory) throws IOException{
super(expression, factory);
+
+ List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+
+ for(StreamExpressionNamedParameter namedParam : namedParams){
+ if(namedParam.getName().equals("scale")){
+ this.scale = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
+ } else if(namedParam.getName().equals("robust")) {
+ this.robust = Boolean.parseBoolean(namedParam.getParameter().toString().trim());
+ } else {
+ throw new IOException("Unexpected named parameter:"+namedParam.getName());
+ }
+ }
}
@Override
@@ -84,7 +101,7 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
map.put("features", observations.getColumnCount());
map.put("distance", distanceMeasure.getClass().getSimpleName());
- return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map);
+ return new KnnRegressionTuple(observations, outcomeData, k, distanceMeasure, map, scale, robust);
}
@@ -95,17 +112,27 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
private double[] outcomes;
private int k;
private DistanceMeasure distanceMeasure;
+ private boolean scale;
+ private boolean robust;
public KnnRegressionTuple(Matrix observations,
double[] outcomes,
int k,
DistanceMeasure distanceMeasure,
- Map<?,?> map) {
+ Map<?,?> map,
+ boolean scale,
+ boolean robust) {
super(map);
this.observations = observations;
this.outcomes = outcomes;
this.k = k;
this.distanceMeasure = distanceMeasure;
+ this.scale = scale;
+ this.robust = robust;
+ }
+
+ public boolean getScale() {
+ return this.scale;
}
//MinMax Scale both the observations and the predictors
@@ -175,19 +202,33 @@ public class KnnRegressionEvaluator extends RecursiveObjectEvaluator implements
public double predict(double[] values) {
- Matrix knn = KnnEvaluator.search(scaledObservations, values, k, distanceMeasure);
+ Matrix obs = scaledObservations != null ? scaledObservations : observations;
+ Matrix knn = KnnEvaluator.search(obs, values, k, distanceMeasure);
List<Number> indexes = (List<Number>)knn.getAttribute("indexes");
- double sum = 0;
-
- //Collect the outcomes for the nearest neighbors
- for(Number n : indexes) {
- sum += outcomes[n.intValue()];
+ if(robust) {
+ //Get the median of the results.
+ double[] vals = new double[indexes.size()];
+ Percentile percentile = new Percentile();
+ int i=0;
+ for (Number n : indexes) {
+ vals[i++]=outcomes[n.intValue()];
+ }
+
+ //Return 50 percentile.
+ return percentile.evaluate(vals, 50);
+ } else {
+ //Get the average of the results
+ double sum = 0;
+
+ //Collect the outcomes for the nearest neighbors
+ for (Number n : indexes) {
+ sum += outcomes[n.intValue()];
+ }
+
+ //Return the average of the outcomes as the prediction.
+ return sum / ((double) indexes.size());
}
-
- //Return the average of the outcomes as the prediction.
-
- return sum/((double)indexes.size());
}
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
index 9385928..c8e83ba 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/PredictEvaluator.java
@@ -97,13 +97,17 @@ public class PredictEvaluator extends RecursiveObjectEvaluator implements ManyVa
predictors[i] = list.get(i).doubleValue();
}
- predictors = regressedTuple.scale(predictors);
+ if(regressedTuple.getScale()) {
+ predictors = regressedTuple.scale(predictors);
+ }
return regressedTuple.predict(predictors);
} else if (second instanceof Matrix) {
Matrix m = (Matrix) second;
- m = regressedTuple.scale(m);
+ if(regressedTuple.getScale()) {
+ m = regressedTuple.scale(m);
+ }
double[][] data = m.getData();
List<Number> predictions = new ArrayList();
for (double[] predictors : data) {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/52f9cee9/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
index bfd4160..6565b76 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/MathExpressionTest.java
@@ -3450,7 +3450,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
- "f=knnRegress(e, d, 1)," +
+ "f=knnRegress(e, d, 1, scale=true)," +
"g=predict(f, e))";
ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@@ -3480,7 +3480,7 @@ public class MathExpressionTest extends SolrCloudTestCase {
"c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 7.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
- "f=knnRegress(e, d, 1)," +
+ "f=knnRegress(e, d, 1, scale=true)," +
"g=predict(f, array(8, 5, 4)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@@ -3494,12 +3494,14 @@ public class MathExpressionTest extends SolrCloudTestCase {
Number prediction = (Number)tuples.get(0).get("g");
assertEquals(prediction.doubleValue(), 85.09999847, 0);
- cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 15.10000038, 8.19999981), " +
- "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 7.599999905, 5.699999809, 4.5)," +
- "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 20.79999924, 4.900000095)," +
+ //Test robust. Take the median rather than average
+
+ cexpr = "let(echo=true, a=array(8.5, 12.89999962, 5.199999809, 10.69999981, 3.099999905, 3.5, 9.199999809, 9, 8.10000038, 8.19999981), " +
+ "b=array(5.099999905, 5.800000191, 2.099999905, 8.399998665, 2.900000095, 1.200000048, 3.700000048, 5.599999905, 5.699999809, 4.5)," +
+ "c=array(4.699999809, 8.800000191, 15.10000038, 12.19999981, 10.60000038, 3.5, 9.699999809, 5.900000095, 4.79999924, 4.900000095)," +
"d=array(85.09999847, 106.3000031, 50.20000076, 130.6000061, 54.79999924, 30.29999924, 79.40000153, 91, 135.3999939, 89.30000305)," +
"e=transpose(matrix(a, b, c))," +
- "f=knnRegress(e, d, 2)," +
+ "f=knnRegress(e, d, 3, scale=true, robust=true)," +
"g=predict(f, array(8, 5, 4)))";
paramsLoc = new ModifiableSolrParams();
paramsLoc.set("expr", cexpr);
@@ -3511,7 +3513,27 @@ public class MathExpressionTest extends SolrCloudTestCase {
tuples = getTuples(solrStream);
assertTrue(tuples.size() == 1);
prediction = (Number)tuples.get(0).get("g");
- assertEquals(prediction.doubleValue(), 87.20000076, 0);
+ assertEquals(prediction.doubleValue(), 89.30000305, 0);
+
+
+ //Test univariate regression with scaling off
+
+ cexpr = "let(echo=true, a=sequence(10, 0, 1), " +
+ "b=transpose(matrix(a))," +
+ "c=knnRegress(b, a, 3)," +
+ "d=predict(c, array(3)))";
+ paramsLoc = new ModifiableSolrParams();
+ paramsLoc.set("expr", cexpr);
+ paramsLoc.set("qt", "/stream");
+ url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+ solrStream = new SolrStream(url, paramsLoc);
+ context = new StreamContext();
+ solrStream.setStreamContext(context);
+ tuples = getTuples(solrStream);
+ assertTrue(tuples.size() == 1);
+ prediction = (Number)tuples.get(0).get("d");
+ assertEquals(prediction.doubleValue(), 3, 0);
+
}
@Test