You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by da...@apache.org on 2018/01/11 08:01:08 UTC
[32/50] [abbrv] lucene-solr:jira/solr-11702: SOLR-10716: Add termVectors Stream Evaluator
SOLR-10716: Add termVectors Stream Evaluator
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/459ed850
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/459ed850
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/459ed850
Branch: refs/heads/jira/solr-11702
Commit: 459ed85052a72219631f0dcdeb1b6650b632a8fa
Parents: 07407a5
Author: Joel Bernstein <jb...@apache.org>
Authored: Mon Jan 8 19:38:37 2018 -0500
Committer: Joel Bernstein <jb...@apache.org>
Committed: Mon Jan 8 19:39:43 2018 -0500
----------------------------------------------------------------------
.../org/apache/solr/handler/StreamHandler.java | 4 +
.../solr/client/solrj/io/eval/Attributes.java | 26 +++
.../solrj/io/eval/CorrelationEvaluator.java | 6 +-
.../eval/CorrelationSignificanceEvaluator.java | 2 +-
.../solrj/io/eval/GetAttributeEvaluator.java | 43 ++++
.../solrj/io/eval/GetColumnLabelsEvaluator.java | 42 ++++
.../solrj/io/eval/GetRowLabelsEvaluator.java | 42 ++++
.../solr/client/solrj/io/eval/Matrix.java | 35 ++-
.../solrj/io/eval/TermVectorsEvaluator.java | 163 +++++++++++++
.../solrj/io/eval/TransposeEvaluator.java | 6 +-
.../solrj/io/stream/StreamExpressionTest.java | 229 +++++++++++++++++++
11 files changed, 585 insertions(+), 13 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
index 8f3c4d6..ee3a17b 100644
--- a/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
+++ b/solr/core/src/java/org/apache/solr/handler/StreamHandler.java
@@ -292,6 +292,10 @@ public class StreamHandler extends RequestHandlerBase implements SolrCoreAware,
.withFunctionName("lerp", LerpEvaluator.class)
.withFunctionName("chiSquareDataSet", ChiSquareDataSetEvaluator.class)
.withFunctionName("gtestDataSet", GTestDataSetEvaluator.class)
+ .withFunctionName("termVectors", TermVectorsEvaluator.class)
+ .withFunctionName("getColumnLabels", GetColumnLabelsEvaluator.class)
+ .withFunctionName("getRowLabels", GetRowLabelsEvaluator.class)
+ .withFunctionName("getAttribute", GetAttributeEvaluator.class)
// Boolean Stream Evaluators
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Attributes.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Attributes.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Attributes.java
new file mode 100644
index 0000000..10f3a33
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Attributes.java
@@ -0,0 +1,26 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.client.solrj.io.eval;
+
+import java.util.Map;
+
+
+/**
+ * Contract for values that expose a set of named attributes keyed by String.
+ */
+public interface Attributes {
+ // Looks up a single attribute by key; what is returned for an unknown key is left to the implementation.
+ Object getAttribute(String key);
+ // Associates value with key on this object.
+ void setAttribute(String key, Object value);
+ // Exposes the attributes as a (raw) Map; whether it is a live view or a copy is left to the implementation.
+ Map getAttributes();
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationEvaluator.java
index a5065d4..866c5d0 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationEvaluator.java
@@ -108,21 +108,21 @@ public class CorrelationEvaluator extends RecursiveObjectEvaluator implements Ma
RealMatrix corrMatrix = pearsonsCorrelation.getCorrelationMatrix();
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
- realMatrix.addToContext("corr", pearsonsCorrelation);
+ realMatrix.setAttribute("corr", pearsonsCorrelation);
return realMatrix;
} else if (type.equals(CorrelationType.kendalls)) {
KendallsCorrelation kendallsCorrelation = new KendallsCorrelation(data);
RealMatrix corrMatrix = kendallsCorrelation.getCorrelationMatrix();
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
- realMatrix.addToContext("corr", kendallsCorrelation);
+ realMatrix.setAttribute("corr", kendallsCorrelation);
return realMatrix;
} else if (type.equals(CorrelationType.spearmans)) {
SpearmansCorrelation spearmansCorrelation = new SpearmansCorrelation(new Array2DRowRealMatrix(data));
RealMatrix corrMatrix = spearmansCorrelation.getCorrelationMatrix();
double[][] corrMatrixData = corrMatrix.getData();
Matrix realMatrix = new Matrix(corrMatrixData);
- realMatrix.addToContext("corr", spearmansCorrelation.getRankCorrelation());
+ realMatrix.setAttribute("corr", spearmansCorrelation.getRankCorrelation());
return realMatrix;
} else {
return null;
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationSignificanceEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationSignificanceEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationSignificanceEvaluator.java
index 1726c9c..f534a8d 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationSignificanceEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/CorrelationSignificanceEvaluator.java
@@ -42,7 +42,7 @@ public class CorrelationSignificanceEvaluator extends RecursiveObjectEvaluator i
return null;
} else if(value instanceof Matrix) {
Matrix matrix = (Matrix) value;
- Object corr = matrix.getContextValue("corr");
+ Object corr = matrix.getAttribute("corr");
if(corr instanceof PearsonsCorrelation) {
PearsonsCorrelation pcorr = (PearsonsCorrelation)corr;
RealMatrix realMatrix = pcorr.getCorrelationPValues();
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributeEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributeEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributeEvaluator.java
new file mode 100644
index 0000000..81eea23
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetAttributeEvaluator.java
@@ -0,0 +1,43 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that reads one named attribute from a value implementing {@link Attributes}.
+ * The first operand is the attribute holder, the second operand is the attribute key.
+ */
+public class GetAttributeEvaluator extends RecursiveObjectEvaluator implements TwoValueWorker {
+  private static final long serialVersionUID = 1;
+
+  public GetAttributeEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
+    super(expression, factory);
+  }
+
+  /**
+   * Returns the attribute stored under {@code value2} on {@code value1}.
+   * Double-quote characters are stripped from the key before the lookup.
+   *
+   * @param value1 the attribute holder; must implement {@link Attributes}
+   * @param value2 the attribute key; must be a String
+   * @return whatever {@link Attributes#getAttribute(String)} returns for the unquoted key
+   * @throws IOException if {@code value1} is not an {@link Attributes} (null included)
+   *         or {@code value2} is not a String (null included)
+   */
+  @Override
+  public Object doWork(Object value1, Object value2) throws IOException {
+    if (!(value1 instanceof Attributes)) {
+      // instanceof is false for null, so value1 may be null here: never call getClass() on it unguarded.
+      String foundType = value1 == null ? "null" : value1.getClass().getSimpleName();
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for value, expecting an Attributes", toExpression(constructingFactory), foundType));
+    }
+
+    if (!(value2 instanceof String)) {
+      // Report a bad key as an expression error instead of a ClassCastException/NullPointerException.
+      String foundType = value2 == null ? "null" : value2.getClass().getSimpleName();
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for the attribute name, expecting a String", toExpression(constructingFactory), foundType));
+    }
+
+    String key = ((String) value2).replace("\"", "");
+    return ((Attributes) value1).getAttribute(key);
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetColumnLabelsEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetColumnLabelsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetColumnLabelsEvaluator.java
new file mode 100644
index 0000000..f2710d3
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetColumnLabelsEvaluator.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that returns the column labels of a {@link Matrix}.
+ */
+public class GetColumnLabelsEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
+  private static final long serialVersionUID = 1;
+
+  public GetColumnLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
+    super(expression, factory);
+  }
+
+  /**
+   * Returns {@link Matrix#getColumnLabels()} of the operand.
+   *
+   * @param value the evaluated operand; must be a {@link Matrix}
+   * @return the matrix's column labels, exactly as stored on the matrix
+   * @throws IOException if {@code value} is not a {@link Matrix} (null included)
+   */
+  @Override
+  public Object doWork(Object value) throws IOException {
+    if (!(value instanceof Matrix)) {
+      // instanceof is false for null, so value may be null here: never call getClass() on it unguarded.
+      String foundType = value == null ? "null" : value.getClass().getSimpleName();
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for value, expecting a Matrix", toExpression(constructingFactory), foundType));
+    }
+    return ((Matrix) value).getColumnLabels();
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetRowLabelsEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetRowLabelsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetRowLabelsEvaluator.java
new file mode 100644
index 0000000..9af25b4
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/GetRowLabelsEvaluator.java
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.Locale;
+
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that returns the row labels of a {@link Matrix}.
+ */
+public class GetRowLabelsEvaluator extends RecursiveObjectEvaluator implements OneValueWorker {
+  private static final long serialVersionUID = 1;
+
+  public GetRowLabelsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
+    super(expression, factory);
+  }
+
+  /**
+   * Returns {@link Matrix#getRowLabels()} of the operand.
+   *
+   * @param value the evaluated operand; must be a {@link Matrix}
+   * @return the matrix's row labels, exactly as stored on the matrix
+   * @throws IOException if {@code value} is not a {@link Matrix} (null included)
+   */
+  @Override
+  public Object doWork(Object value) throws IOException {
+    if (!(value instanceof Matrix)) {
+      // instanceof is false for null, so value may be null here: never call getClass() on it unguarded.
+      String foundType = value == null ? "null" : value.getClass().getSimpleName();
+      throw new IOException(String.format(Locale.ROOT, "Invalid expression %s - found type %s for value, expecting a Matrix", toExpression(constructingFactory), foundType));
+    }
+    return ((Matrix) value).getRowLabels();
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Matrix.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Matrix.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Matrix.java
index 365a018..7fcfca2 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Matrix.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/Matrix.java
@@ -23,25 +23,44 @@ import java.util.ArrayList;
import java.util.Iterator;
-public class Matrix implements Iterable {
+public class Matrix implements Iterable, Attributes {
private double[][] data;
- private Map context = new HashMap();
+ private List<String> columnLabels;
+ private List<String> rowLabels;
+
+ private Map<String, Object> attributes = new HashMap();
public Matrix(double[][] data) {
this.data = data;
}
- public Map getContext() {
- return this.context;
+ public Map getAttributes() {
+ return this.attributes;
+ }
+
+ public void setAttribute(String key, Object value) {
+ this.attributes.put(key, value);
+ }
+
+ public Object getAttribute(String key) {
+ return this.attributes.get(key);
+ }
+
+ public List<String> getColumnLabels() {
+ return this.columnLabels;
+ }
+
+ public void setColumnLabels(List<String> columnLabels) {
+ this.columnLabels = columnLabels;
}
- public void addToContext(Object key, Object value) {
- this.context.put(key, value);
+ public List<String> getRowLabels() {
+ return rowLabels;
}
- public Object getContextValue(Object key) {
- return this.context.get(key);
+ /**
+  * Replaces the row labels of this matrix. The list is stored by reference, not copied.
+  *
+  * @param rowLabels the labels to associate with the rows
+  */
+ public void setRowLabels(List<String> rowLabels) {
+   this.rowLabels = rowLabels;
}
public double[][] getData() {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java
new file mode 100644
index 0000000..343e65c
--- /dev/null
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TermVectorsEvaluator.java
@@ -0,0 +1,163 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.client.solrj.io.eval;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.HashSet;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.ArrayList;
+import java.util.Map;
+import java.util.TreeMap;
+
+import org.apache.solr.client.solrj.io.Tuple;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpression;
+import org.apache.solr.client.solrj.io.stream.expr.StreamExpressionNamedParameter;
+import org.apache.solr.client.solrj.io.stream.expr.StreamFactory;
+
+/**
+ * Stream evaluator that converts a list of document tuples into a document-by-term {@link Matrix}.
+ *
+ * <p>Every tuple must carry a {@code terms} field holding the document's terms; the {@code id}
+ * field is read as the row label. Terms shorter than {@code minTermLength} are ignored, and terms
+ * whose document frequency falls outside {@code [minDocFreq, maxDocFreq]} (fractions of the number
+ * of documents, truncated to whole document counts) are dropped.
+ *
+ * <p>Each cell is weighted as {@code sqrt(tf) * (ln((N + 1) / (df + 1)) + 1)}, where {@code tf} is
+ * the term's count in the document, {@code df} the number of documents containing the term and
+ * {@code N} the number of documents. Columns follow the sorted order of the surviving terms. The
+ * surviving document frequencies are attached to the matrix as the {@code docFreqs} attribute.
+ */
+public class TermVectorsEvaluator extends RecursiveObjectEvaluator implements ManyValueWorker {
+  protected static final long serialVersionUID = 1L;
+
+  private int minTermLength = 3;
+  private double minDocFreq = .05; // 5% of the docs min
+  private double maxDocFreq = .5; // 50% of the docs max
+
+  /**
+   * Reads the optional named parameters {@code minTermLength}, {@code minDocFreq} and
+   * {@code maxDocFreq} from the expression.
+   *
+   * @throws IOException if a doc frequency is outside [0, 1] or an unknown named parameter is given
+   */
+  public TermVectorsEvaluator(StreamExpression expression, StreamFactory factory) throws IOException {
+    super(expression, factory);
+
+    List<StreamExpressionNamedParameter> namedParams = factory.getNamedOperands(expression);
+
+    for (StreamExpressionNamedParameter namedParam : namedParams) {
+      switch (namedParam.getName()) {
+        case "minTermLength":
+          this.minTermLength = Integer.parseInt(namedParam.getParameter().toString().trim());
+          break;
+        case "minDocFreq":
+          this.minDocFreq = parseDocFreq(namedParam);
+          break;
+        case "maxDocFreq":
+          this.maxDocFreq = parseDocFreq(namedParam);
+          break;
+        default:
+          throw new IOException("Unexpected named parameter:" + namedParam.getName());
+      }
+    }
+  }
+
+  /** Parses a doc-frequency fraction and rejects values outside [0, 1]. */
+  private static double parseDocFreq(StreamExpressionNamedParameter namedParam) throws IOException {
+    double freq = Double.parseDouble(namedParam.getParameter().toString().trim());
+    if (freq < 0 || freq > 1) {
+      throw new IOException("Doc frequency percentage must be between 0 and 1");
+    }
+    return freq;
+  }
+
+  /** Returns the {@code terms} field of a document tuple, failing if the field is absent. */
+  @SuppressWarnings("unchecked")
+  private static List<String> termsOf(Tuple tuple) throws IOException {
+    Object terms = tuple.get("terms");
+    if (terms == null) {
+      throw new IOException("The document tuples must contain a terms field");
+    }
+    return (List<String>) terms;
+  }
+
+  /**
+   * Builds the term-vector matrix for the documents passed as the single positional operand.
+   *
+   * @param objects the evaluated operands; exactly one is expected, a list of document tuples
+   * @return a {@link Matrix} with one row per document and one column per surviving term, with
+   *         row labels, column labels and the {@code docFreqs} attribute set
+   * @throws IOException if the operand count is not one, the operand is not a list, or a tuple
+   *         has no {@code terms} field
+   */
+  @Override
+  public Object doWork(Object... objects) throws IOException {
+
+    if (objects.length != 1) {
+      throw new IOException("The termVectors function takes a single positional parameter.");
+    }
+
+    if (!(objects[0] instanceof List)) {
+      // Surface a wrong operand type as an expression error rather than a ClassCastException.
+      throw new IOException("The termVectors function expects a list of document tuples as its positional parameter.");
+    }
+
+    @SuppressWarnings("unchecked")
+    List<Tuple> tuples = (List<Tuple>) objects[0];
+    // TreeMap keeps the terms sorted, which fixes the column order of the matrix.
+    TreeMap<String, Integer> docFreqs = new TreeMap<>();
+    List<String> rowLabels = new ArrayList<>(tuples.size());
+
+    // Pass 1: count in how many documents each sufficiently long term occurs.
+    for (Tuple tuple : tuples) {
+      List<String> terms = termsOf(tuple);
+      rowLabels.add(tuple.getString("id"));
+
+      Set<String> docTerms = new HashSet<>();
+      for (String term : terms) {
+        // Set.add returns false for a repeat, so a term counts once per document.
+        if (term.length() >= minTermLength && docTerms.add(term)) {
+          docFreqs.merge(term, 1, Integer::sum);
+        }
+      }
+    }
+
+    //Eliminate terms based on frequency
+
+    int min = (int) (tuples.size() * minDocFreq);
+    int max = (int) (tuples.size() * maxDocFreq);
+
+    Iterator<Map.Entry<String, Integer>> it = docFreqs.entrySet().iterator();
+    while (it.hasNext()) {
+      int count = it.next().getValue();
+      if (count < min || count > max) {
+        it.remove();
+      }
+    }
+
+    List<String> features = new ArrayList<>(docFreqs.keySet());
+    int totalTerms = features.size();
+
+    // The inverse-document-frequency factor depends only on the term, so compute it once per column.
+    double[] idf = new double[totalTerms];
+    for (int i = 0; i < totalTerms; i++) {
+      int df = docFreqs.get(features.get(i));
+      idf[i] = Math.log((tuples.size() + 1) / (double) (df + 1)) + 1.0;
+    }
+
+    // Pass 2: one weighted vector per document.
+    double[][] docVec = new double[tuples.size()][];
+    for (int t = 0; t < tuples.size(); t++) {
+      Map<String, Integer> termFreq = new HashMap<>();
+      for (String term : termsOf(tuples.get(t))) {
+        if (docFreqs.containsKey(term)) {
+          termFreq.merge(term, 1, Integer::sum);
+        }
+      }
+
+      double[] termVec = new double[totalTerms];
+      for (int i = 0; i < totalTerms; i++) {
+        int tf = termFreq.getOrDefault(features.get(i), 0);
+        termVec[i] = Math.sqrt(tf) * idf[i];
+      }
+      docVec[t] = termVec;
+    }
+
+    Matrix matrix = new Matrix(docVec);
+    matrix.setColumnLabels(features);
+    matrix.setRowLabels(rowLabels);
+    matrix.setAttribute("docFreqs", docFreqs);
+    return matrix;
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TransposeEvaluator.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TransposeEvaluator.java b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TransposeEvaluator.java
index b206cc5..dfc2346 100644
--- a/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TransposeEvaluator.java
+++ b/solr/solrj/src/java/org/apache/solr/client/solrj/io/eval/TransposeEvaluator.java
@@ -44,7 +44,11 @@ public class TransposeEvaluator extends RecursiveObjectEvaluator implements OneV
double[][] data = matrix.getData();
Array2DRowRealMatrix amatrix = new Array2DRowRealMatrix(data);
Array2DRowRealMatrix tmatrix = (Array2DRowRealMatrix)amatrix.transpose();
- return new Matrix(tmatrix.getData());
+ Matrix newMatrix = new Matrix(tmatrix.getData());
+ //Switch the row and column labels
+ newMatrix.setColumnLabels(matrix.getRowLabels());
+ newMatrix.setRowLabels(matrix.getColumnLabels());
+ return newMatrix;
} else {
throw new IOException("matrix parameter expected for transpose function");
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/459ed850/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
----------------------------------------------------------------------
diff --git a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
index 9d41c54..2a9df01 100644
--- a/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
+++ b/solr/solrj/src/test/org/apache/solr/client/solrj/io/stream/StreamExpressionTest.java
@@ -6632,6 +6632,235 @@ public class StreamExpressionTest extends SolrCloudTestCase {
assertEquals(sd, 3.5, 0.0);
}
+ @Test
+ public void testTermVectors() throws Exception {
+   String url = cluster.getJettySolrRunners().get(0).getBaseUrl().toString()+"/"+COLLECTIONORALIAS;
+
+   // Weight of a term that occurs once in exactly one of the four documents.
+   final double single = 1.916290731874155;
+
+   // Default minTermLength with doc-frequency filtering opened up to keep every term
+   Tuple result = evalTermVectors(url, "minDocFreq=0, maxDocFreq=1");
+   List<List<Number>> termVectors = (List<List<Number>>)result.get("b");
+
+   assertEquals(4, termVectors.size());
+   assertTermVector(termVectors.get(0), 1.0, 0.0, 0.0, 0.0, single);
+   assertTermVector(termVectors.get(1), 1.0, 0.0, 0.0, single, 0.0);
+   // "jim" occurs twice in document 3, hence the larger weight
+   assertTermVector(termVectors.get(2), 1.0, 0.0, 2.7100443424662948, 0.0, 0.0);
+   assertTermVector(termVectors.get(3), 1.0, single, 0.0, 0.0, 0.0);
+
+   assertTermVectorLabels((List<String>)result.get("c"), "1", "2", "3", "4");
+   assertTermVectorLabels((List<String>)result.get("d"), "hello", "jack", "jim", "steve", "world");
+
+   Map<String, Number> docFreqs = (Map<String, Number>)result.get("e");
+   assertEquals(5, docFreqs.size());
+   assertEquals(4, docFreqs.get("hello").intValue());
+   assertEquals(1, docFreqs.get("jack").intValue());
+   assertEquals(1, docFreqs.get("jim").intValue());
+   assertEquals(1, docFreqs.get("steve").intValue());
+   assertEquals(1, docFreqs.get("world").intValue());
+
+   //Test minTermLength. This should drop off the term jim
+
+   result = evalTermVectors(url, "minTermLength=4, minDocFreq=0, maxDocFreq=1");
+   termVectors = (List<List<Number>>)result.get("b");
+
+   assertEquals(4, termVectors.size());
+   assertTermVector(termVectors.get(0), 1.0, 0.0, 0.0, single);
+   assertTermVector(termVectors.get(1), 1.0, 0.0, single, 0.0);
+   assertTermVector(termVectors.get(2), 1.0, 0.0, 0.0, 0.0);
+   assertTermVector(termVectors.get(3), 1.0, single, 0.0, 0.0);
+
+   assertTermVectorLabels((List<String>)result.get("c"), "1", "2", "3", "4");
+   assertTermVectorLabels((List<String>)result.get("d"), "hello", "jack", "steve", "world");
+
+   docFreqs = (Map<String, Number>)result.get("e");
+   assertEquals(4, docFreqs.size());
+   assertEquals(4, docFreqs.get("hello").intValue());
+   assertEquals(1, docFreqs.get("jack").intValue());
+   assertEquals(1, docFreqs.get("steve").intValue());
+   assertEquals(1, docFreqs.get("world").intValue());
+
+   //Test minDocFreq attribute at .5. This should eliminate all but the term hello
+
+   result = evalTermVectors(url, "minDocFreq=.5, maxDocFreq=1");
+   termVectors = (List<List<Number>>)result.get("b");
+
+   assertEquals(4, termVectors.size());
+   assertTermVector(termVectors.get(0), 1.0);
+   assertTermVector(termVectors.get(1), 1.0);
+   assertTermVector(termVectors.get(2), 1.0);
+   assertTermVector(termVectors.get(3), 1.0);
+
+   assertTermVectorLabels((List<String>)result.get("c"), "1", "2", "3", "4");
+   assertTermVectorLabels((List<String>)result.get("d"), "hello");
+
+   docFreqs = (Map<String, Number>)result.get("e");
+   assertEquals(1, docFreqs.size());
+   assertEquals(4, docFreqs.get("hello").intValue());
+
+   //Test maxDocFreq attribute at 0. This should eliminate all terms
+
+   result = evalTermVectors(url, "maxDocFreq=0");
+   termVectors = (List<List<Number>>)result.get("b");
+   assertEquals(4, termVectors.size());
+   assertEquals(0, termVectors.get(0).size());
+ }
+
+ // Runs the shared four-document termVectors expression with the given named parameters
+ // and returns the single tuple the let() expression emits.
+ private Tuple evalTermVectors(String url, String namedParams) throws Exception {
+   String cexpr = "let(echo=true," +
+       "a=select(list(tuple(id=\"1\", text=\"hello world\"), " +
+       "tuple(id=\"2\", text=\"hello steve\"), " +
+       "tuple(id=\"3\", text=\"hello jim jim\"), " +
+       "tuple(id=\"4\", text=\"hello jack\")), id, analyze(text, test_t) as terms)," +
+       " b=termVectors(a, " + namedParams + ")," +
+       " c=getRowLabels(b)," +
+       " d=getColumnLabels(b)," +
+       " e=getAttribute(b, docFreqs))";
+   ModifiableSolrParams paramsLoc = new ModifiableSolrParams();
+   paramsLoc.set("expr", cexpr);
+   paramsLoc.set("qt", "/stream");
+   TupleStream solrStream = new SolrStream(url, paramsLoc);
+   StreamContext context = new StreamContext();
+   solrStream.setStreamContext(context);
+   List<Tuple> tuples = getTuples(solrStream);
+   assertEquals(1, tuples.size());
+   return tuples.get(0);
+ }
+
+ // Asserts that a returned vector has exactly the expected values, in order.
+ private static void assertTermVector(List<Number> actual, double... expected) {
+   assertEquals(expected.length, actual.size());
+   for (int i = 0; i < expected.length; i++) {
+     assertEquals(expected[i], actual.get(i).doubleValue(), 0.0);
+   }
+ }
+
+ // Asserts that a returned label list has exactly the expected labels, in order.
+ private static void assertTermVectorLabels(List<String> actual, String... expected) {
+   assertEquals(expected.length, actual.size());
+   for (int i = 0; i < expected.length; i++) {
+     assertEquals(expected[i], actual.get(i));
+   }
+ }
@Test
public void testEBESubtract() throws Exception {