You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by sa...@apache.org on 2016/12/06 23:13:03 UTC
[08/50] [abbrv] lucene-solr:apiv2: SOLR-8871 - various improvements to ClassificationURP
SOLR-8871 - various improvements to ClassificationURP
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/5ad741ee
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/5ad741ee
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/5ad741ee
Branch: refs/heads/apiv2
Commit: 5ad741eef8241de86945e710cdcb32e77a7183a3
Parents: e9e4715
Author: Tommaso Teofili <to...@apache.org>
Authored: Thu Nov 24 23:43:57 2016 +0100
Committer: Tommaso Teofili <to...@apache.org>
Committed: Thu Nov 24 23:43:57 2016 +0100
----------------------------------------------------------------------
.../UIMAUpdateRequestProcessorTest.java | 25 -
.../ClassificationUpdateProcessor.java | 59 ++-
.../ClassificationUpdateProcessorFactory.java | 197 +++----
.../ClassificationUpdateProcessorParams.java | 112 ++++
.../conf/solrconfig-classification.xml | 15 +
...lassificationUpdateProcessorFactoryTest.java | 208 ++------
...ificationUpdateProcessorIntegrationTest.java | 192 +++++++
.../ClassificationUpdateProcessorTest.java | 507 +++++++++++++++++++
.../SignatureUpdateProcessorFactoryTest.java | 28 +-
.../TestPartialUpdateDeduplication.java | 2 -
.../java/org/apache/solr/SolrTestCaseJ4.java | 22 +
11 files changed, 1012 insertions(+), 355 deletions(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
----------------------------------------------------------------------
diff --git a/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java b/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
index 5879c78..3833696 100644
--- a/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
+++ b/solr/contrib/uima/src/test/org/apache/solr/uima/processor/UIMAUpdateRequestProcessorTest.java
@@ -16,22 +16,12 @@
*/
package org.apache.solr.uima.processor;
-import java.util.ArrayList;
-import java.util.HashMap;
import java.util.Map;
import org.apache.lucene.util.LuceneTestCase.Slow;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
-import org.apache.solr.common.params.MultiMapSolrParams;
-import org.apache.solr.common.params.SolrParams;
-import org.apache.solr.common.params.UpdateParams;
-import org.apache.solr.common.util.ContentStream;
-import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.core.SolrCore;
-import org.apache.solr.handler.UpdateRequestHandler;
-import org.apache.solr.request.SolrQueryRequestBase;
-import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.uima.processor.SolrUIMAConfiguration.MapField;
import org.apache.solr.update.processor.UpdateRequestProcessor;
import org.apache.solr.update.processor.UpdateRequestProcessorChain;
@@ -188,19 +178,4 @@ public class UIMAUpdateRequestProcessorTest extends SolrTestCaseJ4 {
}
}
- private void addDoc(String chain, String doc) throws Exception {
- Map<String, String[]> params = new HashMap<>();
- params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
- MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
- SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), (SolrParams) mmparams) {
- };
-
- UpdateRequestHandler handler = new UpdateRequestHandler();
- handler.init(null);
- ArrayList<ContentStream> streams = new ArrayList<>(2);
- streams.add(new ContentStreamBase.StringStream(doc));
- req.setContentStreams(streams);
- handler.handleRequestBody(req, new SolrQueryResponse());
- }
-
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
index 050fff0..8ce9814 100644
--- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
@@ -19,6 +19,7 @@ package org.apache.solr.update.processor;
import java.io.IOException;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
import org.apache.lucene.analysis.Analyzer;
@@ -33,6 +34,7 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.update.AddUpdateCommand;
+import org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm;
/**
* This Class is a Request Update Processor to classify the document in input and add a field
@@ -42,43 +44,54 @@ import org.apache.solr.update.AddUpdateCommand;
class ClassificationUpdateProcessor
extends UpdateRequestProcessor {
- private String classFieldName; // the field to index the assigned class
-
+ private final String trainingClassField;
+ private final String predictedClassField;
+ private final int maxOutputClasses;
private DocumentClassifier<BytesRef> classifier;
/**
* Sole constructor
*
- * @param inputFieldNames fields to be used as classifier's inputs
- * @param classFieldName field to be used as classifier's output
- * @param minDf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
- * @param minTf setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
- * @param k setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
- * @param algorithm the name of the classifier to use
+ * @param classificationParams classification advanced params
* @param next next update processor in the chain
* @param indexReader index reader
* @param schema schema
*/
- public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
- UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
+ public ClassificationUpdateProcessor(ClassificationUpdateProcessorParams classificationParams, UpdateRequestProcessor next, IndexReader indexReader, IndexSchema schema) {
super(next);
- this.classFieldName = classFieldName;
- Map<String, Analyzer> field2analyzer = new HashMap<String, Analyzer>();
+ this.trainingClassField = classificationParams.getTrainingClassField();
+ this.predictedClassField = classificationParams.getPredictedClassField();
+ this.maxOutputClasses = classificationParams.getMaxPredictedClasses();
+ String[] inputFieldNamesWithBoost = classificationParams.getInputFieldNames();
+ Algorithm classificationAlgorithm = classificationParams.getAlgorithm();
+
+ Map<String, Analyzer> field2analyzer = new HashMap<>();
+ String[] inputFieldNames = this.removeBoost(inputFieldNamesWithBoost);
for (String fieldName : inputFieldNames) {
SchemaField fieldFromSolrSchema = schema.getField(fieldName);
Analyzer indexAnalyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
field2analyzer.put(fieldName, indexAnalyzer);
}
- switch (algorithm) {
- case "knn":
- classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
+ switch (classificationAlgorithm) {
+ case KNN:
+ classifier = new KNearestNeighborDocumentClassifier(indexReader, null, classificationParams.getTrainingFilterQuery(), classificationParams.getK(), classificationParams.getMinDf(), classificationParams.getMinTf(), trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
- case "bayes":
- classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
+ case BAYES:
+ classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, trainingClassField, field2analyzer, inputFieldNamesWithBoost);
break;
}
}
+ private String[] removeBoost(String[] inputFieldNamesWithBoost) {
+ String[] inputFieldNames = new String[inputFieldNamesWithBoost.length];
+ for (int i = 0; i < inputFieldNamesWithBoost.length; i++) {
+ String singleFieldNameWithBoost = inputFieldNamesWithBoost[i];
+ String[] fieldName2boost = singleFieldNameWithBoost.split("\\^");
+ inputFieldNames[i] = fieldName2boost[0];
+ }
+ return inputFieldNames;
+ }
+
/**
* @param cmd the update command in input containing the Document to classify
* @throws IOException If there is a low-level I/O error
@@ -89,12 +102,14 @@ class ClassificationUpdateProcessor
SolrInputDocument doc = cmd.getSolrInputDocument();
Document luceneDocument = cmd.getLuceneDocument();
String assignedClass;
- Object documentClass = doc.getFieldValue(classFieldName);
+ Object documentClass = doc.getFieldValue(trainingClassField);
if (documentClass == null) {
- ClassificationResult<BytesRef> classificationResult = classifier.assignClass(luceneDocument);
- if (classificationResult != null) {
- assignedClass = classificationResult.getAssignedClass().utf8ToString();
- doc.addField(classFieldName, assignedClass);
+ List<ClassificationResult<BytesRef>> assignedClassifications = classifier.getClasses(luceneDocument, maxOutputClasses);
+ if (assignedClassifications != null) {
+ for (ClassificationResult<BytesRef> singleClassification : assignedClassifications) {
+ assignedClass = singleClassification.getAssignedClass().utf8ToString();
+ doc.addField(predictedClassField, assignedClass);
+ }
}
}
super.processAdd(cmd);
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
index 81bec2f..19e0dfe 100644
--- a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
@@ -18,12 +18,17 @@
package org.apache.solr.update.processor;
import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.search.Query;
import org.apache.solr.common.SolrException;
import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.request.SolrQueryRequest;
import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
+import org.apache.solr.search.LuceneQParser;
+import org.apache.solr.search.SyntaxError;
+
+import static org.apache.solr.update.processor.ClassificationUpdateProcessorFactory.Algorithm.KNN;
/**
* This class implements an UpdateProcessorFactory for the Classification Update Processor.
@@ -33,49 +38,67 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
// Update Processor Config params
private static final String INPUT_FIELDS_PARAM = "inputFields";
- private static final String CLASS_FIELD_PARAM = "classField";
+ private static final String TRAINING_CLASS_FIELD_PARAM = "classField";
+ private static final String PREDICTED_CLASS_FIELD_PARAM = "predictedClassField";
+ private static final String MAX_CLASSES_TO_ASSIGN_PARAM = "predictedClass.maxCount";
private static final String ALGORITHM_PARAM = "algorithm";
private static final String KNN_MIN_TF_PARAM = "knn.minTf";
private static final String KNN_MIN_DF_PARAM = "knn.minDf";
private static final String KNN_K_PARAM = "knn.k";
+ private static final String KNN_FILTER_QUERY = "knn.filterQuery";
+
+ public enum Algorithm {KNN, BAYES}
//Update Processor Defaults
+ private static final int DEFAULT_MAX_CLASSES_TO_ASSIGN = 1;
private static final int DEFAULT_MIN_TF = 1;
private static final int DEFAULT_MIN_DF = 1;
private static final int DEFAULT_K = 10;
- private static final String DEFAULT_ALGORITHM = "knn";
-
- private String[] inputFieldNames; // the array of fields to be sent to the Classifier
-
- private String classFieldName; // the field containing the class for the Document
-
- private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
-
- private int minTf; // knn specific - the minimum Term Frequency for considering a term
-
- private int minDf; // knn specific - the minimum Document Frequency for considering a term
+ private static final Algorithm DEFAULT_ALGORITHM = KNN;
- private int k; // knn specific - thw window of top results to evaluate, when assigning the class
+ private SolrParams params;
+ private ClassificationUpdateProcessorParams classificationParams;
@Override
public void init(final NamedList args) {
if (args != null) {
- SolrParams params = SolrParams.toSolrParams(args);
+ params = SolrParams.toSolrParams(args);
+ classificationParams = new ClassificationUpdateProcessorParams();
String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
- inputFieldNames = fieldNames.split("\\,");
-
- classFieldName = params.get(CLASS_FIELD_PARAM);
- checkNotNull(CLASS_FIELD_PARAM, classFieldName);
-
- algorithm = params.get(ALGORITHM_PARAM);
- if (algorithm == null)
- algorithm = DEFAULT_ALGORITHM;
-
- minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
- minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
- k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
+ classificationParams.setInputFieldNames(fieldNames.split("\\,"));
+
+ String trainingClassField = (params.get(TRAINING_CLASS_FIELD_PARAM));
+ checkNotNull(TRAINING_CLASS_FIELD_PARAM, trainingClassField);
+ classificationParams.setTrainingClassField(trainingClassField);
+
+ String predictedClassField = (params.get(PREDICTED_CLASS_FIELD_PARAM));
+ if (predictedClassField == null || predictedClassField.isEmpty()) {
+ predictedClassField = trainingClassField;
+ }
+ classificationParams.setPredictedClassField(predictedClassField);
+
+ classificationParams.setMaxPredictedClasses(getIntParam(params, MAX_CLASSES_TO_ASSIGN_PARAM, DEFAULT_MAX_CLASSES_TO_ASSIGN));
+
+ String algorithmString = params.get(ALGORITHM_PARAM);
+ Algorithm classificationAlgorithm;
+ try {
+ if (algorithmString == null || Algorithm.valueOf(algorithmString.toUpperCase()) == null) {
+ classificationAlgorithm = DEFAULT_ALGORITHM;
+ } else {
+ classificationAlgorithm = Algorithm.valueOf(algorithmString.toUpperCase());
+ }
+ } catch (IllegalArgumentException e) {
+ throw new SolrException
+ (SolrException.ErrorCode.SERVER_ERROR,
+ "Classification UpdateProcessor Algorithm: '" + algorithmString + "' not supported");
+ }
+ classificationParams.setAlgorithm(classificationAlgorithm);
+
+ classificationParams.setMinTf(getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF));
+ classificationParams.setMinDf(getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF));
+ classificationParams.setK(getIntParam(params, KNN_K_PARAM, DEFAULT_K));
}
}
@@ -108,116 +131,34 @@ public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessor
@Override
public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
+ String trainingFilterQueryString = (params.get(KNN_FILTER_QUERY));
+ try {
+ if (trainingFilterQueryString != null && !trainingFilterQueryString.isEmpty()) {
+ Query trainingFilterQuery = this.parseFilterQuery(trainingFilterQueryString, params, req);
+ classificationParams.setTrainingFilterQuery(trainingFilterQuery);
+ }
+ } catch (SyntaxError | RuntimeException syntaxError) {
+ throw new SolrException
+ (SolrException.ErrorCode.SERVER_ERROR,
+ "Classification UpdateProcessor Training Filter Query: '" + trainingFilterQueryString + "' is not supported", syntaxError);
+ }
+
IndexSchema schema = req.getSchema();
IndexReader indexReader = req.getSearcher().getIndexReader();
- return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, indexReader, schema);
- }
- /**
- * get field names used as classifier's inputs
- *
- * @return the input field names
- */
- public String[] getInputFieldNames() {
- return inputFieldNames;
- }
-
- /**
- * set field names used as classifier's inputs
- *
- * @param inputFieldNames the input field names
- */
- public void setInputFieldNames(String[] inputFieldNames) {
- this.inputFieldNames = inputFieldNames;
+ return new ClassificationUpdateProcessor(classificationParams, next, indexReader, schema);
}
- /**
- * get field names used as classifier's output
- *
- * @return the output field name
- */
- public String getClassFieldName() {
- return classFieldName;
+ private Query parseFilterQuery(String trainingFilterQueryString, SolrParams params, SolrQueryRequest req) throws SyntaxError {
+ LuceneQParser parser = new LuceneQParser(trainingFilterQueryString, null, params, req);
+ return parser.parse();
}
- /**
- * set field names used as classifier's output
- *
- * @param classFieldName the output field name
- */
- public void setClassFieldName(String classFieldName) {
- this.classFieldName = classFieldName;
+ public ClassificationUpdateProcessorParams getClassificationParams() {
+ return classificationParams;
}
- /**
- * get the name of the classifier algorithm used
- *
- * @return the classifier algorithm used
- */
- public String getAlgorithm() {
- return algorithm;
- }
-
- /**
- * set the name of the classifier algorithm used
- *
- * @param algorithm the classifier algorithm used
- */
- public void setAlgorithm(String algorithm) {
- this.algorithm = algorithm;
- }
-
- /**
- * get the min term frequency value to be used in case algorithm is {@code "knn"}
- *
- * @return the min term frequency
- */
- public int getMinTf() {
- return minTf;
- }
-
- /**
- * set the min term frequency value to be used in case algorithm is {@code "knn"}
- *
- * @param minTf the min term frequency
- */
- public void setMinTf(int minTf) {
- this.minTf = minTf;
- }
-
- /**
- * get the min document frequency value to be used in case algorithm is {@code "knn"}
- *
- * @return the min document frequency
- */
- public int getMinDf() {
- return minDf;
- }
-
- /**
- * set the min document frequency value to be used in case algorithm is {@code "knn"}
- *
- * @param minDf the min document frequency
- */
- public void setMinDf(int minDf) {
- this.minDf = minDf;
- }
-
- /**
- * get the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
- *
- * @return the no. of neighbors to analyze
- */
- public int getK() {
- return k;
- }
-
- /**
- * set the the no. of nearest neighbor to analyze, to be used in case algorithm is {@code "knn"}
- *
- * @param k the no. of neighbors to analyze
- */
- public void setK(int k) {
- this.k = k;
+ public void setClassificationParams(ClassificationUpdateProcessorParams classificationParams) {
+ this.classificationParams = classificationParams;
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
new file mode 100644
index 0000000..536cec3
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorParams.java
@@ -0,0 +1,112 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.update.processor;
+
+import org.apache.lucene.search.Query;
+
+public class ClassificationUpdateProcessorParams {
+
+ private String[] inputFieldNames; // the array of fields to be sent to the Classifier
+
+ private Query trainingFilterQuery; // a filter query to reduce the training set to a subset
+
+ private String trainingClassField; // the field containing the class for the Document
+
+ private String predictedClassField; // the field that will contain the predicted class
+
+ private int maxPredictedClasses; // the max number of classes to assign
+
+ private ClassificationUpdateProcessorFactory.Algorithm algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
+
+ private int minTf; // knn specific - the minimum Term Frequency for considering a term
+
+ private int minDf; // knn specific - the minimum Document Frequency for considering a term
+
+ private int k; // knn specific - the window of top results to evaluate, when assigning the class
+
+ public String[] getInputFieldNames() {
+ return inputFieldNames;
+ }
+
+ public void setInputFieldNames(String[] inputFieldNames) {
+ this.inputFieldNames = inputFieldNames;
+ }
+
+ public Query getTrainingFilterQuery() {
+ return trainingFilterQuery;
+ }
+
+ public void setTrainingFilterQuery(Query trainingFilterQuery) {
+ this.trainingFilterQuery = trainingFilterQuery;
+ }
+
+ public String getTrainingClassField() {
+ return trainingClassField;
+ }
+
+ public void setTrainingClassField(String trainingClassField) {
+ this.trainingClassField = trainingClassField;
+ }
+
+ public String getPredictedClassField() {
+ return predictedClassField;
+ }
+
+ public void setPredictedClassField(String predictedClassField) {
+ this.predictedClassField = predictedClassField;
+ }
+
+ public int getMaxPredictedClasses() {
+ return maxPredictedClasses;
+ }
+
+ public void setMaxPredictedClasses(int maxPredictedClasses) {
+ this.maxPredictedClasses = maxPredictedClasses;
+ }
+
+ public ClassificationUpdateProcessorFactory.Algorithm getAlgorithm() {
+ return algorithm;
+ }
+
+ public void setAlgorithm(ClassificationUpdateProcessorFactory.Algorithm algorithm) {
+ this.algorithm = algorithm;
+ }
+
+ public int getMinTf() {
+ return minTf;
+ }
+
+ public void setMinTf(int minTf) {
+ this.minTf = minTf;
+ }
+
+ public int getMinDf() {
+ return minDf;
+ }
+
+ public void setMinDf(int minDf) {
+ this.minDf = minDf;
+ }
+
+ public int getK() {
+ return k;
+ }
+
+ public void setK(int k) {
+ this.k = k;
+ }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
----------------------------------------------------------------------
diff --git a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
index 3656335..f688ed1 100644
--- a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
+++ b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
@@ -47,6 +47,21 @@
<str name="knn.minTf">1</str>
<str name="knn.minDf">1</str>
<str name="knn.k">5</str>
+ <str name="knn.filterQuery">cat:(class1 OR class2)</str>
+ </processor>
+ <processor class="solr.RunUpdateProcessorFactory"/>
+ </updateRequestProcessorChain>
+
+ <updateRequestProcessorChain name="classification-unsupported-filterQuery">
+ <processor class="solr.ClassificationUpdateProcessorFactory">
+ <str name="inputFields">title,content,author</str>
+ <str name="classField">cat</str>
+ <!-- Knn algorithm specific-->
+ <str name="algorithm">knn</str>
+ <str name="knn.minTf">1</str>
+ <str name="knn.minDf">1</str>
+ <str name="knn.k">5</str>
+ <str name="knn.filterQuery">not valid ( lucene query</str>
</processor>
<processor class="solr.RunUpdateProcessorFactory"/>
</updateRequestProcessorChain>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
index 05d112f..fe22918 100644
--- a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
@@ -14,71 +14,31 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.solr.update.processor;
-import java.io.IOException;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Map;
-
-import org.apache.lucene.document.Document;
-import org.apache.lucene.index.Term;
-import org.apache.lucene.search.ScoreDoc;
-import org.apache.lucene.search.TermQuery;
-import org.apache.lucene.search.TopDocs;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.common.SolrException;
-import org.apache.solr.common.params.MultiMapSolrParams;
-import org.apache.solr.common.params.SolrParams;
-import org.apache.solr.common.params.UpdateParams;
-import org.apache.solr.common.util.ContentStream;
-import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
-import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.SolrQueryRequest;
-import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
-import org.apache.solr.search.SolrIndexSearcher;
import org.junit.Before;
-import org.junit.BeforeClass;
import org.junit.Test;
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Mockito.mock;
+
/**
- * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
+ * Tests for {@link ClassificationUpdateProcessorFactory}
*/
public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
- // field names are used in accordance with the solrconfig and schema supplied
- private static final String ID = "id";
- private static final String TITLE = "title";
- private static final String CONTENT = "content";
- private static final String AUTHOR = "author";
- private static final String CLASS = "cat";
-
- private static final String CHAIN = "classification";
-
-
private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
private NamedList args = new NamedList<String>();
- @BeforeClass
- public static void beforeClass() throws Exception {
- System.setProperty("enable.update.log", "false");
- initCore("solrconfig-classification.xml", "schema-classification.xml");
- }
-
- @Override
- @Before
- public void setUp() throws Exception {
- super.setUp();
- clearIndex();
- assertU(commit());
- }
-
@Before
public void initArgs() {
args.add("inputFields", "inputField1,inputField2");
args.add("classField", "classField1");
+ args.add("predictedClassField", "classFieldX");
args.add("algorithm", "bayes");
args.add("knn.k", "9");
args.add("knn.minDf", "8");
@@ -86,22 +46,23 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
- public void testFullInit() {
+ public void init_fullArgs_shouldInitFullClassificationParams() {
cFactoryToTest.init(args);
+ ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
- String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
+ String[] inputFieldNames = classificationParams.getInputFieldNames();
assertEquals("inputField1", inputFieldNames[0]);
assertEquals("inputField2", inputFieldNames[1]);
- assertEquals("classField1", cFactoryToTest.getClassFieldName());
- assertEquals("bayes", cFactoryToTest.getAlgorithm());
- assertEquals(8, cFactoryToTest.getMinDf());
- assertEquals(10, cFactoryToTest.getMinTf());
- assertEquals(9, cFactoryToTest.getK());
-
+ assertEquals("classField1", classificationParams.getTrainingClassField());
+ assertEquals("classFieldX", classificationParams.getPredictedClassField());
+ assertEquals(ClassificationUpdateProcessorFactory.Algorithm.BAYES, classificationParams.getAlgorithm());
+ assertEquals(8, classificationParams.getMinDf());
+ assertEquals(10, classificationParams.getMinTf());
+ assertEquals(9, classificationParams.getK());
}
@Test
- public void testInitEmptyInputField() {
+ public void init_emptyInputFields_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("inputFields");
try {
cFactoryToTest.init(args);
@@ -111,7 +72,7 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
- public void testInitEmptyClassField() {
+ public void init_emptyClassField_shouldThrowExceptionWithDetailedMessage() {
args.removeAll("classField");
try {
cFactoryToTest.init(args);
@@ -121,114 +82,53 @@ public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
}
@Test
- public void testDefaults() {
- args.removeAll("algorithm");
- args.removeAll("knn.k");
- args.removeAll("knn.minDf");
- args.removeAll("knn.minTf");
- cFactoryToTest.init(args);
- assertEquals("knn", cFactoryToTest.getAlgorithm());
- assertEquals(1, cFactoryToTest.getMinDf());
- assertEquals(1, cFactoryToTest.getMinTf());
- assertEquals(10, cFactoryToTest.getK());
- }
+ public void init_emptyPredictedClassField_shouldDefaultToTrainingClassField() {
+ args.removeAll("predictedClassField");
- @Test
- public void testBasicClassification() throws Exception {
- prepareTrainedIndex();
- // To be classified,we index documents without a class and verify the expected one is returned
- addDoc(adoc(ID, "10",
- TITLE, "word4 word4 word4",
- CONTENT, "word5 word5 ",
- AUTHOR, "Name1 Surname1"));
- addDoc(adoc(ID, "11",
- TITLE, "word1 word1",
- CONTENT, "word2 word2",
- AUTHOR, "Name Surname"));
- addDoc(commit());
+ cFactoryToTest.init(args);
- Document doc10 = getDoc("10");
- assertEquals("class2", doc10.get(CLASS));
- Document doc11 = getDoc("11");
- assertEquals("class1", doc11.get(CLASS));
+ ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
+ assertThat(classificationParams.getPredictedClassField(), is("classField1"));
}
- /**
- * Index some example documents with a class manually assigned.
- * This will be our trained model.
- *
- * @throws Exception If there is a low-level I/O error
- */
- private void prepareTrainedIndex() throws Exception {
- //class1
- addDoc(adoc(ID, "1",
- TITLE, "word1 word1 word1",
- CONTENT, "word2 word2 word2",
- AUTHOR, "Name Surname",
- CLASS, "class1"));
- addDoc(adoc(ID, "2",
- TITLE, "word1 word1",
- CONTENT, "word2 word2",
- AUTHOR, "Name Surname",
- CLASS, "class1"));
- addDoc(adoc(ID, "3",
- TITLE, "word1 word1 word1",
- CONTENT, "word2",
- AUTHOR, "Name Surname",
- CLASS, "class1"));
- addDoc(adoc(ID, "4",
- TITLE, "word1 word1 word1",
- CONTENT, "word2 word2 word2",
- AUTHOR, "Name Surname",
- CLASS, "class1"));
- //class2
- addDoc(adoc(ID, "5",
- TITLE, "word4 word4 word4",
- CONTENT, "word5 word5",
- AUTHOR, "Name1 Surname1",
- CLASS, "class2"));
- addDoc(adoc(ID, "6",
- TITLE, "word4 word4",
- CONTENT, "word5",
- AUTHOR, "Name1 Surname1",
- CLASS, "class2"));
- addDoc(adoc(ID, "7",
- TITLE, "word4 word4 word4",
- CONTENT, "word5 word5 word5",
- AUTHOR, "Name1 Surname1",
- CLASS, "class2"));
- addDoc(adoc(ID, "8",
- TITLE, "word4",
- CONTENT, "word5 word5 word5 word5",
- AUTHOR, "Name1 Surname1",
- CLASS, "class2"));
- addDoc(commit());
+  /**
+   * Initialising the factory with an algorithm name it does not know must raise a
+   * {@link SolrException} whose message names the rejected value.
+   */
+  @Test
+  public void init_unsupportedAlgorithm_shouldThrowExceptionWithDetailedMessage() {
+    args.removeAll("algorithm");
+    args.add("algorithm", "unsupported");
+    try {
+      cFactoryToTest.init(args);
+      // without this the test would silently pass when init accepts the value
+      fail("init should have thrown a SolrException for an unsupported algorithm");
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Algorithm: 'unsupported' not supported", e.getMessage());
+    }
+  }
- private Document getDoc(String id) throws IOException {
- try (SolrQueryRequest req = req()) {
- SolrIndexSearcher searcher = req.getSearcher();
- TermQuery query = new TermQuery(new Term(ID, id));
- TopDocs doc1 = searcher.search(query, 1);
- ScoreDoc scoreDoc = doc1.scoreDocs[0];
- return searcher.doc(scoreDoc.doc);
+  /**
+   * Creating a processor instance with a filter query that cannot be parsed must raise a
+   * {@link SolrException} whose message quotes the rejected query string.
+   */
+  @Test
+  public void init_unsupportedFilterQuery_shouldThrowExceptionWithDetailedMessage() {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    SolrQueryRequest mockRequest = mock(SolrQueryRequest.class);
+    SolrQueryResponse mockResponse = mock(SolrQueryResponse.class);
+    args.add("knn.filterQuery", "not supported query");
+    try {
+      cFactoryToTest.init(args);
+      /* parsing failure happens because of the mocks, fine enough to check a proper exception propagation */
+      cFactoryToTest.getInstance(mockRequest, mockResponse, mockProcessor);
+      // without this the test would silently pass when no exception is propagated
+      fail("getInstance should have thrown a SolrException for an unsupported filter query");
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Training Filter Query: 'not supported query' is not supported", e.getMessage());
+    }
}
- static void addDoc(String doc) throws Exception {
- Map<String, String[]> params = new HashMap<>();
- MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
- params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
- SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
- (SolrParams) mmparams) {
- };
+ @Test
+ public void init_emptyArgs_shouldDefaultClassificationParams() {
+ args.removeAll("algorithm");
+ args.removeAll("knn.k");
+ args.removeAll("knn.minDf");
+ args.removeAll("knn.minTf");
+ cFactoryToTest.init(args);
+ ClassificationUpdateProcessorParams classificationParams = cFactoryToTest.getClassificationParams();
- UpdateRequestHandler handler = new UpdateRequestHandler();
- handler.init(null);
- ArrayList<ContentStream> streams = new ArrayList<>(2);
- streams.add(new ContentStreamBase.StringStream(doc));
- req.setContentStreams(streams);
- handler.handleRequestBody(req, new SolrQueryResponse());
- req.close();
+ assertEquals(ClassificationUpdateProcessorFactory.Algorithm.KNN, classificationParams.getAlgorithm());
+ assertEquals(1, classificationParams.getMinDf());
+ assertEquals(1, classificationParams.getMinTf());
+ assertEquals(10, classificationParams.getK());
}
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
new file mode 100644
index 0000000..3aee1be
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorIntegrationTest.java
@@ -0,0 +1,192 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.update.processor;
+
+import java.io.IOException;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.hamcrest.core.Is.is;
+
+/**
+ * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
+ */
+public class ClassificationUpdateProcessorIntegrationTest extends SolrTestCaseJ4 {
+ /* field names are used in accordance with the solrconfig and schema supplied */
+ private static final String ID = "id";
+ private static final String TITLE = "title";
+ private static final String CONTENT = "content";
+ private static final String AUTHOR = "author";
+ private static final String CLASS = "cat";
+
+ private static final String CHAIN = "classification";
+ private static final String BROKEN_CHAIN_FILTER_QUERY = "classification-unsupported-filterQuery";
+
+ private ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
+ private NamedList args = new NamedList<String>();
+
+  /**
+   * Sets the {@code enable.update.log} system property to {@code false} and initialises the
+   * test core from the classification solrconfig and schema.
+   */
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    final String configFile = "solrconfig-classification.xml";
+    final String schemaFile = "schema-classification.xml";
+    System.setProperty("enable.update.log", "false");
+    initCore(configFile, schemaFile);
+  }
+
+  /** Runs the parent set-up, then clears the index and commits so each test starts empty. */
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    clearIndex();
+    final String commitCommand = commit();
+    assertU(commitCommand);
+  }
+
+  /**
+   * Indexes the training set, sends two documents without a class through the classification
+   * chain and checks the class stored for each of them.
+   */
+  @Test
+  public void classify_fullConfiguration_shouldAutoClassify() throws Exception {
+    indexTrainingSet();
+
+    // documents to classify: they carry no class field, the chain is expected to add it
+    final String unclassifiedDoc22 = adoc(
+        ID, "22",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 ",
+        AUTHOR, "Name1 Surname1");
+    final String unclassifiedDoc21 = adoc(
+        ID, "21",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname");
+    addDoc(unclassifiedDoc22, CHAIN);
+    addDoc(unclassifiedDoc21, CHAIN);
+    addDoc(commit());
+
+    assertThat(getDoc("22").get(CLASS), is("class2"));
+    assertThat(getDoc("21").get(CLASS), is("class1"));
+  }
+
+  /**
+   * Sending documents through the chain configured with an unsupported filter query must raise
+   * a {@link SolrException} whose message quotes the rejected query.
+   */
+  @Test
+  public void classify_unsupportedFilterQueryConfiguration_shouldThrowExceptionWithDetailedMessage() throws Exception {
+    indexTrainingSet();
+    try {
+      addDoc(adoc(ID, "21",
+          TITLE, "word4 word4 word4",
+          CONTENT, "word5 word5 ",
+          AUTHOR, "Name1 Surname1"), BROKEN_CHAIN_FILTER_QUERY);
+      addDoc(adoc(ID, "22",
+          TITLE, "word1 word1",
+          CONTENT, "word2 word2",
+          AUTHOR, "Name Surname"), BROKEN_CHAIN_FILTER_QUERY);
+      addDoc(commit());
+      // without this the test would silently pass when the broken chain accepts the documents
+      fail("indexing through the broken chain should have thrown a SolrException");
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor Training Filter Query: 'not valid ( lucene query' is not supported", e.getMessage());
+    }
+  }
+
+  /**
+   * Indexes twelve example documents, each with a manually assigned class, through the
+   * classification chain and commits. These documents act as the training set.
+   *
+   * @throws Exception if adding a document or committing fails
+   */
+  private void indexTrainingSet() throws Exception {
+    // one row per training document: id, title, content, author, class
+    final String[][] trainingRows = {
+        {"1", "word1 word1 word1", "word2 word2 word2", "Name Surname", "class1"},
+        {"2", "word1 word1", "word2 word2", "Name Surname", "class1"},
+        {"3", "word1 word1 word1", "word2", "Name Surname", "class1"},
+        {"4", "word1 word1 word1", "word2 word2 word2", "Name Surname", "class1"},
+        {"5", "word4 word4 word4", "word5 word5", "Name Surname", "class2"},
+        {"6", "word4 word4", "word5", "Name Surname", "class2"},
+        {"7", "word4 word4 word4", "word5 word5 word5", "Name Surname", "class2"},
+        {"8", "word4", "word5 word5 word5 word5", "Name Surname", "class2"},
+        {"9", "word4 word4 word4", "word5 word5", "Name1 Surname1", "class3"},
+        {"10", "word4 word4", "word5", "Name1 Surname1", "class3"},
+        {"11", "word4 word4 word4", "word5 word5 word5", "Name1 Surname1", "class3"},
+        {"12", "word4", "word5 word5 word5 word5", "Name1 Surname1", "class3"},
+    };
+    for (String[] row : trainingRows) {
+      addDoc(adoc(ID, row[0],
+          TITLE, row[1],
+          CONTENT, row[2],
+          AUTHOR, row[3],
+          CLASS, row[4]), CHAIN);
+    }
+    addDoc(commit());
+  }
+
+  /**
+   * Looks up the stored document whose id field holds the given value.
+   *
+   * @param id value of the id field to search for
+   * @return the first document matching the id
+   * @throws IOException if searching the index or loading the document fails
+   */
+  private Document getDoc(String id) throws IOException {
+    try (SolrQueryRequest req = req()) {
+      SolrIndexSearcher searcher = req.getSearcher();
+      TopDocs hits = searcher.search(new TermQuery(new Term(ID, id)), 1);
+      // fail with a readable message instead of an ArrayIndexOutOfBoundsException
+      assertTrue("no indexed document found for id " + id, hits.scoreDocs.length > 0);
+      ScoreDoc firstHit = hits.scoreDocs[0];
+      return searcher.doc(firstHit.doc);
+    }
+  }
+
+  /**
+   * Convenience overload that sends the given update message through the default
+   * classification chain.
+   *
+   * @param doc the update message to send
+   */
+  private void addDoc(String doc) throws Exception {
+    final String defaultChain = CHAIN;
+    addDoc(doc, defaultChain);
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
new file mode 100644
index 0000000..938dfc5
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorTest.java
@@ -0,0 +1,507 @@
+package org.apache.solr.update.processor;
+
+import java.io.IOException;
+import java.util.ArrayList;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.analysis.MockAnalyzer;
+import org.apache.lucene.analysis.MockTokenizer;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.store.Directory;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.update.AddUpdateCommand;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+import static org.hamcrest.core.Is.is;
+import static org.mockito.Mockito.mock;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * Tests for {@link ClassificationUpdateProcessor}
+ */
+public class ClassificationUpdateProcessorTest extends SolrTestCaseJ4 {
+ /* field names are used in accordance with the solrconfig and schema supplied */
+ private static final String ID = "id";
+ private static final String TITLE = "title";
+ private static final String CONTENT = "content";
+ private static final String AUTHOR = "author";
+ private static final String TRAINING_CLASS = "cat";
+ private static final String PREDICTED_CLASS = "predicted";
+ public static final String KNN = "knn";
+
+ protected Directory directory;
+ protected IndexReader reader;
+ protected IndexSearcher searcher;
+ protected Analyzer analyzer = new MockAnalyzer(random(), MockTokenizer.WHITESPACE, false);
+ private ClassificationUpdateProcessor updateProcessorToTest;
+
+  /**
+   * Sets the {@code enable.update.log} system property to {@code false} and initialises the
+   * test core from the classification solrconfig and schema.
+   */
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    final String configFile = "solrconfig-classification.xml";
+    final String schemaFile = "schema-classification.xml";
+    System.setProperty("enable.update.log", "false");
+    initCore(configFile, schemaFile);
+  }
+
+  /** Delegates to the parent set-up; the training index is built by each test itself. */
+  @Override
+  public void setUp() throws Exception {
+    super.setUp();
+  }
+
+  /**
+   * Releases the per-test index reader, directory and analyzer, then runs the parent tear-down.
+   * The reader and directory are only closed when a test actually created them, and the parent
+   * tear-down runs even if closing one of the resources fails.
+   */
+  @Override
+  public void tearDown() throws Exception {
+    try {
+      // both stay null when index preparation failed before assigning them
+      if (reader != null) {
+        reader.close();
+      }
+      if (directory != null) {
+        directory.close();
+      }
+      analyzer.close();
+    } finally {
+      super.tearDown();
+    }
+  }
+
+
+
+
+  /**
+   * With a dedicated predicted class field configured, the KNN prediction must be written to
+   * that field of the incoming document.
+   */
+  @Test
+  public void classificationMonoClass_predictedClassFieldSet_shouldAssignClassInPredictedClassField() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    ClassificationUpdateProcessorParams knnParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    knnParams.setPredictedClassField(PREDICTED_CLASS);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(knnParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(PREDICTED_CLASS), is("class1"));
+  }
+
+  /**
+   * KNN with the baseline parameters: the unseen document is expected to receive
+   * {@code class1} in the training class field.
+   */
+  @Test
+  public void knnMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    ClassificationUpdateProcessorParams knnParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(knnParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(TRAINING_CLASS), is("class1"));
+  }
+
+  /**
+   * KNN with per-field boosts (title 1.5, content 0.5, author 2.5): the unseen document is
+   * expected to receive {@code class2}.
+   */
+  @Test
+  public void knnMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    final String[] boostedInputFields = {TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"};
+    ClassificationUpdateProcessorParams knnParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    knnParams.setInputFieldNames(boostedInputFields);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(knnParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(TRAINING_CLASS), is("class2"));
+  }
+
+  /**
+   * Bayes with the baseline parameters: the unseen document is expected to receive
+   * {@code class1} in the training class field.
+   */
+  @Test
+  public void bayesMonoClass_sampleParams_shouldAssignCorrectClass() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    ClassificationUpdateProcessorParams bayesParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(bayesParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(TRAINING_CLASS), is("class1"));
+  }
+
+  /**
+   * KNN with a training filter query on title term {@code word6}: the unseen document is
+   * expected to receive {@code class3}.
+   */
+  @Test
+  public void knnMonoClass_contextQueryFiltered_shouldAssignCorrectClass() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    ClassificationUpdateProcessorParams knnParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    // in the mono-class training index only the class3 documents have "word6" as title
+    Query onlyClass3Docs = new TermQuery(new Term(TITLE, "word6"));
+    knnParams.setTrainingFilterQuery(onlyClass3Docs);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "a");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(knnParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(TRAINING_CLASS), is("class3"));
+  }
+
+  /**
+   * Bayes with per-field boosts (title 1.5, content 0.5, author 2.5): the unseen document is
+   * expected to receive {@code class2}.
+   */
+  @Test
+  public void bayesMonoClass_boostFields_shouldAssignCorrectClass() throws Exception {
+    prepareTrainedIndexMonoClass();
+
+    final String[] boostedInputFields = {TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"};
+    ClassificationUpdateProcessorParams bayesParams =
+        initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    bayesParams.setInputFieldNames(boostedInputFields);
+
+    SolrInputDocument toClassify = sdoc(
+        ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    AddUpdateCommand addCommand = new AddUpdateCommand(req());
+    addCommand.solrDoc = toClassify;
+
+    UpdateRequestProcessor nextInChain = mock(UpdateRequestProcessor.class);
+    updateProcessorToTest =
+        new ClassificationUpdateProcessor(bayesParams, nextInChain, reader, req().getSchema());
+    updateProcessorToTest.processAdd(addCommand);
+
+    assertThat(toClassify.getFieldValue(TRAINING_CLASS), is("class2"));
+  }
+
+  /**
+   * KNN on the multi-class index with a maximum of 100 predicted classes: the first two
+   * assigned classes are expected to be {@code class2} then {@code class1}.
+   */
+  @Test
+  public void knnClassification_maxOutputClassesGreaterThanAvailable_shouldAssignCorrectClass() throws Exception {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update = new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc = unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setMaxPredictedClasses(100);
+
+    updateProcessorToTest = new ClassificationUpdateProcessor(params, mockProcessor, reader, req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    // copy into a typed list instead of a raw cast to the concrete collection class
+    ArrayList<Object> assignedClasses = new ArrayList<>(unseenDocument1.getFieldValues(TRAINING_CLASS));
+    assertThat(assignedClasses.get(0), is("class2"));
+    assertThat(assignedClasses.get(1), is("class1"));
+  }
+
+  /**
+   * KNN on the multi-class index limited to two predicted classes: exactly two classes are
+   * expected, {@code class2} then {@code class1}.
+   */
+  @Test
+  public void knnMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update = new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc = unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest = new ClassificationUpdateProcessor(params, mockProcessor, reader, req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    // copy into a typed list instead of a raw cast to the concrete collection class
+    ArrayList<Object> assignedClasses = new ArrayList<>(unseenDocument1.getFieldValues(TRAINING_CLASS));
+    assertThat(assignedClasses.size(), is(2));
+    assertThat(assignedClasses.get(0), is("class2"));
+    assertThat(assignedClasses.get(1), is("class1"));
+  }
+
+  /**
+   * Bayes on the multi-class index limited to two predicted classes: exactly two classes are
+   * expected, {@code class2} then {@code class1}.
+   */
+  @Test
+  public void bayesMultiClass_maxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update = new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc = unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest = new ClassificationUpdateProcessor(params, mockProcessor, reader, req().getSchema());
+    updateProcessorToTest.processAdd(update);
+
+    // copy into a typed list instead of a raw cast to the concrete collection class
+    ArrayList<Object> assignedClasses = new ArrayList<>(unseenDocument1.getFieldValues(TRAINING_CLASS));
+    assertThat(assignedClasses.size(), is(2));
+    assertThat(assignedClasses.get(0), is("class2"));
+    assertThat(assignedClasses.get(1), is("class1"));
+  }
+
+  /**
+   * KNN on the multi-class index with per-field boosts (title 1.5, content 0.5, author 2.5)
+   * and at most two predicted classes: {@code class4} then {@code class6} are expected.
+   */
+  @Test
+  public void knnMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update = new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc = unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.KNN);
+    params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest = new ClassificationUpdateProcessor(params, mockProcessor, reader, req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    // copy into a typed list instead of a raw cast to the concrete collection class
+    ArrayList<Object> assignedClasses = new ArrayList<>(unseenDocument1.getFieldValues(TRAINING_CLASS));
+    assertThat(assignedClasses.size(), is(2));
+    assertThat(assignedClasses.get(0), is("class4"));
+    assertThat(assignedClasses.get(1), is("class6"));
+  }
+
+  /**
+   * Bayes on the multi-class index with per-field boosts (title 1.5, content 0.5, author 2.5)
+   * and at most two predicted classes: {@code class4} then {@code class6} are expected.
+   */
+  @Test
+  public void bayesMultiClass_boostFieldsMaxOutputClasses2_shouldAssignMax2Classes() throws Exception {
+    UpdateRequestProcessor mockProcessor = mock(UpdateRequestProcessor.class);
+    prepareTrainedIndexMultiClass();
+
+    AddUpdateCommand update = new AddUpdateCommand(req());
+    SolrInputDocument unseenDocument1 = sdoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word2 word2 ",
+        AUTHOR, "unseenAuthor");
+    update.solrDoc = unseenDocument1;
+
+    ClassificationUpdateProcessorParams params = initParams(ClassificationUpdateProcessorFactory.Algorithm.BAYES);
+    params.setInputFieldNames(new String[]{TITLE + "^1.5", CONTENT + "^0.5", AUTHOR + "^2.5"});
+    params.setMaxPredictedClasses(2);
+
+    updateProcessorToTest = new ClassificationUpdateProcessor(params, mockProcessor, reader, req().getSchema());
+
+    updateProcessorToTest.processAdd(update);
+
+    // copy into a typed list instead of a raw cast to the concrete collection class
+    ArrayList<Object> assignedClasses = new ArrayList<>(unseenDocument1.getFieldValues(TRAINING_CLASS));
+    assertThat(assignedClasses.size(), is(2));
+    assertThat(assignedClasses.get(0), is("class4"));
+    assertThat(assignedClasses.get(1), is("class6"));
+  }
+
+  /**
+   * Builds the baseline parameters the tests start from: title, content and author as input
+   * fields, the training class field used as both training and predicted field, minTf 1,
+   * minDf 1, k 5 and a single predicted class.
+   *
+   * @param classificationAlgorithm the algorithm to store in the parameters
+   * @return a freshly created parameter object
+   */
+  private ClassificationUpdateProcessorParams initParams(ClassificationUpdateProcessorFactory.Algorithm classificationAlgorithm) {
+    final String[] plainInputFields = {TITLE, CONTENT, AUTHOR};
+    ClassificationUpdateProcessorParams baseline = new ClassificationUpdateProcessorParams();
+    baseline.setInputFieldNames(plainInputFields);
+    baseline.setTrainingClassField(TRAINING_CLASS);
+    baseline.setPredictedClassField(TRAINING_CLASS);
+    baseline.setMinTf(1);
+    baseline.setMinDf(1);
+    baseline.setK(5);
+    baseline.setAlgorithm(classificationAlgorithm);
+    baseline.setMaxPredictedClasses(1);
+    return baseline;
+  }
+
+  /**
+   * Builds a fresh Lucene index of twelve example documents, each with exactly one manually
+   * assigned class, and opens the reader and searcher on it. This index is the trained model
+   * for the mono-class tests.
+   *
+   * @throws Exception if writing to or opening the index fails
+   */
+  private void prepareTrainedIndexMonoClass() throws Exception {
+    directory = newDirectory();
+    RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
+
+    // one row per training document: id, title, content, author, class
+    final String[][] trainingRows = {
+        {"1", "word1 word1 word1", "word2 word2 word2", "a", "class1"},
+        {"2", "word1 word1", "word2 word2", "a", "class1"},
+        {"3", "word1 word1 word1", "word2", "a", "class1"},
+        {"4", "word1 word1 word1", "word2 word2 word2", "a", "class1"},
+        {"5", "word4 word4 word4", "word5 word5", "c", "class2"},
+        {"6", "word4 word4", "word5", "c", "class2"},
+        {"7", "word4 word4 word4", "word5 word5 word5", "c", "class2"},
+        {"8", "word4", "word5 word5 word5 word5", "c", "class2"},
+        {"9", "word6", "word7", "a", "class3"},
+        {"10", "word6", "word7", "a", "class3"},
+        {"11", "word6", "word7", "a", "class3"},
+        {"12", "word6", "word7", "a", "class3"},
+    };
+    for (String[] row : trainingRows) {
+      addDoc(indexWriter, buildLuceneDocument(ID, row[0],
+          TITLE, row[1],
+          CONTENT, row[2],
+          AUTHOR, row[3],
+          TRAINING_CLASS, row[4]));
+    }
+
+    reader = indexWriter.getReader();
+    indexWriter.close();
+    searcher = newSearcher(reader);
+  }
+
+  /**
+   * Builds a fresh Lucene index of eight example documents, each carrying two manually
+   * assigned classes, and opens the reader and searcher on it. This index is the trained model
+   * for the multi-class tests.
+   *
+   * @throws Exception if writing to or opening the index fails
+   */
+  private void prepareTrainedIndexMultiClass() throws Exception {
+    directory = newDirectory();
+    RandomIndexWriter indexWriter = new RandomIndexWriter(random(), directory);
+
+    // one row per training document: id, title, content, author, first class, second class
+    final String[][] trainingRows = {
+        {"1", "word1 word1 word1", "word2 word2 word2", "Name Surname", "class1", "class2"},
+        {"2", "word1 word1", "word2 word2", "Name Surname", "class3", "class2"},
+        {"3", "word1 word1 word1", "word2", "Name Surname", "class1", "class2"},
+        {"4", "word1 word1 word1", "word2 word2 word2", "Name Surname", "class1", "class2"},
+        {"5", "word4 word4 word4", "word5 word5", "Name1 Surname1", "class6", "class4"},
+        {"6", "word4 word4", "word5", "Name1 Surname1", "class5", "class4"},
+        {"7", "word4 word4 word4", "word5 word5 word5", "Name1 Surname1", "class6", "class4"},
+        {"8", "word4", "word5 word5 word5 word5", "Name1 Surname1", "class6", "class4"},
+    };
+    for (String[] row : trainingRows) {
+      addDoc(indexWriter, buildLuceneDocument(ID, row[0],
+          TITLE, row[1],
+          CONTENT, row[2],
+          AUTHOR, row[3],
+          TRAINING_CLASS, row[4],
+          TRAINING_CLASS, row[5]));
+    }
+
+    reader = indexWriter.getReader();
+    indexWriter.close();
+    searcher = newSearcher(reader);
+  }
+
+  /**
+   * Builds a Lucene document from alternating field names and values, adding every pair as a
+   * stored text field.
+   *
+   * @param fieldsAndValues field name followed by its value, repeated; both are cast to String
+   * @return the document holding one stored text field per pair
+   * @throws IllegalArgumentException if an odd number of arguments is supplied
+   */
+  public static Document buildLuceneDocument(Object... fieldsAndValues) {
+    // a dangling field name would otherwise surface as an ArrayIndexOutOfBoundsException
+    if (fieldsAndValues.length % 2 != 0) {
+      throw new IllegalArgumentException(
+          "Expected field/value pairs but got an odd number of arguments: " + fieldsAndValues.length);
+    }
+    Document luceneDoc = new Document();
+    for (int i = 0; i < fieldsAndValues.length; i += 2) {
+      String fieldName = (String) fieldsAndValues[i];
+      String fieldValue = (String) fieldsAndValues[i + 1];
+      luceneDoc.add(newTextField(fieldName, fieldValue, Field.Store.YES));
+    }
+    return luceneDoc;
+  }
+
+  /**
+   * Adds the document through the given writer.
+   *
+   * @param writer the index writer to add to
+   * @param doc the document to add
+   * @return the writer's document count after the add, minus one
+   * @throws IOException if the writer fails
+   */
+  private int addDoc(RandomIndexWriter writer, Document doc) throws IOException {
+    writer.addDocument(doc);
+    final int docCountAfterAdd = writer.numDocs();
+    return docCountAfterAdd - 1;
+  }
+}
\ No newline at end of file
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java b/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java
index 0bef1a0..012b8ce 100644
--- a/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java
+++ b/solr/core/src/test/org/apache/solr/update/processor/SignatureUpdateProcessorFactoryTest.java
@@ -16,31 +16,28 @@
*/
package org.apache.solr.update.processor;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+
import org.apache.lucene.util.Constants;
import org.apache.solr.SolrTestCaseJ4;
import org.apache.solr.client.solrj.impl.BinaryRequestWriter;
import org.apache.solr.client.solrj.request.UpdateRequest;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.MultiMapSolrParams;
-import org.apache.solr.common.params.SolrParams;
import org.apache.solr.common.params.UpdateParams;
import org.apache.solr.common.util.ContentStream;
-import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.NamedList;
import org.apache.solr.core.SolrCore;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
-import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.response.SolrQueryResponse;
import org.junit.Before;
import org.junit.BeforeClass;
import org.junit.Test;
-import java.util.ArrayList;
-import java.util.HashMap;
-import java.util.Map;
-
/**
*
*/
@@ -359,21 +356,4 @@ public class SignatureUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
private void addDoc(String doc) throws Exception {
addDoc(doc, chain);
}
-
- static void addDoc(String doc, String chain) throws Exception {
- Map<String, String[]> params = new HashMap<>();
- MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
- params.put(UpdateParams.UPDATE_CHAIN, new String[] { chain });
- SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
- (SolrParams) mmparams) {
- };
-
- UpdateRequestHandler handler = new UpdateRequestHandler();
- handler.init(null);
- ArrayList<ContentStream> streams = new ArrayList<>(2);
- streams.add(new ContentStreamBase.StringStream(doc));
- req.setContentStreams(streams);
- handler.handleRequestBody(req, new SolrQueryResponse());
- req.close();
- }
}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java b/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
index d494eb6..bab5cd3 100644
--- a/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
+++ b/solr/core/src/test/org/apache/solr/update/processor/TestPartialUpdateDeduplication.java
@@ -25,8 +25,6 @@ import org.junit.Test;
import java.util.Map;
-import static org.apache.solr.update.processor.SignatureUpdateProcessorFactoryTest.addDoc;
-
public class TestPartialUpdateDeduplication extends SolrTestCaseJ4 {
@BeforeClass
public static void beforeClass() throws Exception {
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5ad741ee/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
----------------------------------------------------------------------
diff --git a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
index 3adad49..19bf601 100644
--- a/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
+++ b/solr/test-framework/src/java/org/apache/solr/SolrTestCaseJ4.java
@@ -83,7 +83,11 @@ import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.SolrInputField;
import org.apache.solr.common.params.CommonParams;
import org.apache.solr.common.params.ModifiableSolrParams;
+import org.apache.solr.common.params.MultiMapSolrParams;
import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.params.UpdateParams;
+import org.apache.solr.common.util.ContentStream;
+import org.apache.solr.common.util.ContentStreamBase;
import org.apache.solr.common.util.ObjectReleaseTracker;
import org.apache.solr.common.util.XML;
import org.apache.solr.core.CoreContainer;
@@ -96,7 +100,9 @@ import org.apache.solr.core.SolrXmlConfig;
import org.apache.solr.handler.UpdateRequestHandler;
import org.apache.solr.request.LocalSolrQueryRequest;
import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.request.SolrQueryRequestBase;
import org.apache.solr.request.SolrRequestHandler;
+import org.apache.solr.response.SolrQueryResponse;
import org.apache.solr.schema.IndexSchema;
import org.apache.solr.schema.SchemaField;
import org.apache.solr.search.SolrIndexSearcher;
@@ -1009,6 +1015,22 @@ public abstract class SolrTestCaseJ4 extends LuceneTestCase {
return out.toString();
}
+ /**
+  * Sends {@code doc} to a freshly initialised {@link UpdateRequestHandler}, using a request
+  * bound to the core returned by {@code h.getCore()} whose only parameter is
+  * {@link UpdateParams#UPDATE_CHAIN}.
+  *
+  * @param doc the update message, handed to the handler as a single string content stream
+  * @param updateRequestProcessorChain value for the {@link UpdateParams#UPDATE_CHAIN}
+  *                                    request parameter
+  * @throws Exception whatever the update handler throws while processing the request
+  */
+ public static void addDoc(String doc, String updateRequestProcessorChain) throws Exception {
+   Map<String, String[]> params = new HashMap<>();
+   params.put(UpdateParams.UPDATE_CHAIN, new String[]{updateRequestProcessorChain});
+   MultiMapSolrParams mmparams = new MultiMapSolrParams(params);
+   SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(),
+       (SolrParams) mmparams) {
+   };
+
+   // Close the request even when the handler throws, so a failing update does not leak it.
+   try {
+     UpdateRequestHandler handler = new UpdateRequestHandler();
+     handler.init(null);
+     ArrayList<ContentStream> streams = new ArrayList<>(1);
+     streams.add(new ContentStreamBase.StringStream(doc));
+     req.setContentStreams(streams);
+     handler.handleRequestBody(req, new SolrQueryResponse());
+   } finally {
+     req.close();
+   }
+ }
/**
* Generates an <add><doc>... XML String with options