You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by ho...@apache.org on 2016/03/21 01:43:47 UTC
[25/50] lucene-solr:jira/SOLR-445: SOLR-7739 - applied patch from
Alessandro Benedetti for integrating Lucene classification into Solr
SOLR-7739 - applied patch from Alessandro Benedetti for integrating Lucene classification into Solr
Project: http://git-wip-us.apache.org/repos/asf/lucene-solr/repo
Commit: http://git-wip-us.apache.org/repos/asf/lucene-solr/commit/5801caab
Tree: http://git-wip-us.apache.org/repos/asf/lucene-solr/tree/5801caab
Diff: http://git-wip-us.apache.org/repos/asf/lucene-solr/diff/5801caab
Branch: refs/heads/jira/SOLR-445
Commit: 5801caab6c5fb881a590e06401096ea04111d905
Parents: 56ca641
Author: Tommaso Teofili <te...@adobe.com>
Authored: Tue Mar 15 12:29:07 2016 +0100
Committer: Tommaso Teofili <te...@adobe.com>
Committed: Tue Mar 15 12:29:07 2016 +0100
----------------------------------------------------------------------
dev-tools/idea/solr/core/src/java/solr-core.iml | 1 +
.../idea/solr/core/src/solr-core-tests.iml | 1 +
solr/common-build.xml | 4 +-
.../ClassificationUpdateProcessor.java | 102 ++++++++
.../ClassificationUpdateProcessorFactory.java | 223 ++++++++++++++++++
.../collection1/conf/schema-classification.xml | 43 ++++
.../conf/solrconfig-classification.xml | 53 +++++
...lassificationUpdateProcessorFactoryTest.java | 234 +++++++++++++++++++
8 files changed, 660 insertions(+), 1 deletion(-)
----------------------------------------------------------------------
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/dev-tools/idea/solr/core/src/java/solr-core.iml
----------------------------------------------------------------------
diff --git a/dev-tools/idea/solr/core/src/java/solr-core.iml b/dev-tools/idea/solr/core/src/java/solr-core.iml
index f03268c..822b24f 100644
--- a/dev-tools/idea/solr/core/src/java/solr-core.iml
+++ b/dev-tools/idea/solr/core/src/java/solr-core.iml
@@ -27,6 +27,7 @@
<orderEntry type="module" module-name="expressions" />
<orderEntry type="module" module-name="analysis-common" />
<orderEntry type="module" module-name="lucene-core" />
+ <orderEntry type="module" module-name="classification" />
<orderEntry type="module" module-name="queryparser" />
<orderEntry type="module" module-name="join" />
<orderEntry type="module" module-name="sandbox" />
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/dev-tools/idea/solr/core/src/solr-core-tests.iml
----------------------------------------------------------------------
diff --git a/dev-tools/idea/solr/core/src/solr-core-tests.iml b/dev-tools/idea/solr/core/src/solr-core-tests.iml
index c9f722a..56f768b 100644
--- a/dev-tools/idea/solr/core/src/solr-core-tests.iml
+++ b/dev-tools/idea/solr/core/src/solr-core-tests.iml
@@ -21,6 +21,7 @@
<orderEntry type="module" scope="TEST" module-name="solr-core" />
<orderEntry type="module" scope="TEST" module-name="solrj" />
<orderEntry type="module" scope="TEST" module-name="lucene-core" />
+ <orderEntry type="module" scope="TEST" module-name="classification" />
<orderEntry type="module" scope="TEST" module-name="analysis-common" />
<orderEntry type="module" scope="TEST" module-name="queryparser" />
<orderEntry type="module" scope="TEST" module-name="queries" />
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/common-build.xml
----------------------------------------------------------------------
diff --git a/solr/common-build.xml b/solr/common-build.xml
index 6a06928..78e10aa 100644
--- a/solr/common-build.xml
+++ b/solr/common-build.xml
@@ -108,6 +108,7 @@
<pathelement location="${queryparser.jar}"/>
<pathelement location="${join.jar}"/>
<pathelement location="${sandbox.jar}"/>
+ <pathelement location="${classification.jar}"/>
</path>
<path id="solr.base.classpath">
@@ -169,7 +170,7 @@
<target name="prep-lucene-jars"
depends="jar-lucene-core, jar-backward-codecs, jar-analyzers-phonetic, jar-analyzers-kuromoji, jar-codecs,jar-expressions, jar-suggest, jar-highlighter, jar-memory,
- jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox">
+ jar-misc, jar-spatial-extras, jar-grouping, jar-queries, jar-queryparser, jar-join, jar-sandbox, jar-classification">
<property name="solr.deps.compiled" value="true"/>
</target>
@@ -322,6 +323,7 @@
<link offline="true" href="${lucene.javadoc.url}highlighter" packagelistloc="${lucenedocs}/highlighter"/>
<link offline="true" href="${lucene.javadoc.url}memory" packagelistloc="${lucenedocs}/memory"/>
<link offline="true" href="${lucene.javadoc.url}misc" packagelistloc="${lucenedocs}/misc"/>
+ <link offline="true" href="${lucene.javadoc.url}classification" packagelistloc="${lucenedocs}/classification"/>
<link offline="true" href="${lucene.javadoc.url}spatial-extras" packagelistloc="${lucenedocs}/spatial-extras"/>
<links/>
<link href=""/>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
new file mode 100644
index 0000000..b752565
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessor.java
@@ -0,0 +1,102 @@
+package org.apache.solr.update.processor;
+
+import java.io.IOException;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.lucene.analysis.Analyzer;
+import org.apache.lucene.classification.ClassificationResult;
+import org.apache.lucene.classification.document.DocumentClassifier;
+import org.apache.lucene.classification.document.KNearestNeighborDocumentClassifier;
+import org.apache.lucene.classification.document.SimpleNaiveBayesDocumentClassifier;
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.LeafReader;
+import org.apache.lucene.util.BytesRef;
+import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.schema.IndexSchema;
+import org.apache.solr.schema.SchemaField;
+import org.apache.solr.update.AddUpdateCommand;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This Class is a Request Update Processor to classify the document in input and add a field
+ * containing the class to the Document.
+ * It uses the Lucene Document Classification module, see {@link DocumentClassifier}.
+ */
+class ClassificationUpdateProcessor extends UpdateRequestProcessor {
+
+  private final String classFieldName; // the field to index the assigned class
+
+  private final DocumentClassifier<BytesRef> classifier; // built in the constructor from the chosen algorithm
+
+  /**
+   * Sole constructor
+   *
+   * @param inputFieldNames fields to be used as classifier's inputs
+   * @param classFieldName  field to be used as classifier's output
+   * @param minDf           setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minDocFreq}, in case algorithm is {@code "knn"}
+   * @param minTf           setting for {@link org.apache.lucene.queries.mlt.MoreLikeThis#minTermFreq}, in case algorithm is {@code "knn"}
+   * @param k               setting for k nearest neighbors to analyze, in case algorithm is {@code "knn"}
+   * @param algorithm       the name of the classifier to use, either {@code "knn"} or {@code "bayes"}
+   * @param next            next update processor in the chain
+   * @param indexReader     index reader
+   * @param schema          schema
+   * @throws IllegalArgumentException if {@code algorithm} is not one of the supported classifier names
+   */
+  public ClassificationUpdateProcessor(String[] inputFieldNames, String classFieldName, int minDf, int minTf, int k, String algorithm,
+                                       UpdateRequestProcessor next, LeafReader indexReader, IndexSchema schema) {
+    super(next);
+    this.classFieldName = classFieldName;
+    Map<String, Analyzer> field2analyzer = new HashMap<>();
+    for (String fieldName : inputFieldNames) {
+      SchemaField fieldFromSolrSchema = schema.getField(fieldName);
+      Analyzer analyzer = fieldFromSolrSchema.getType().getQueryAnalyzer();
+      field2analyzer.put(fieldName, analyzer);
+    }
+    switch (algorithm) {
+      case "knn":
+        classifier = new KNearestNeighborDocumentClassifier(indexReader, null, null, k, minDf, minTf, classFieldName, field2analyzer, inputFieldNames);
+        break;
+      case "bayes":
+        classifier = new SimpleNaiveBayesDocumentClassifier(indexReader, null, classFieldName, field2analyzer, inputFieldNames);
+        break;
+      default:
+        // fail fast: with no classifier every later processAdd would end in a NullPointerException
+        throw new IllegalArgumentException("Unsupported classification algorithm: '" + algorithm + "'");
+    }
+  }
+
+  /**
+   * Assigns a class to the incoming document when it has none, then hands the command to the next processor.
+   * @param cmd the update command in input containing the Document to classify
+   * @throws IOException If there is a low-level I/O error
+   */
+  @Override
+  public void processAdd(AddUpdateCommand cmd) throws IOException {
+    SolrInputDocument doc = cmd.getSolrInputDocument();
+    // documents that already carry a value in the class field are passed through untouched
+    if (doc.getFieldValue(classFieldName) == null) {
+      ClassificationResult<BytesRef> classificationResult = classifier.assignClass(cmd.getLuceneDocument());
+      if (classificationResult != null) {
+        doc.addField(classFieldName, classificationResult.getAssignedClass().utf8ToString());
+      }
+    }
+    super.processAdd(cmd);
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
----------------------------------------------------------------------
diff --git a/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
new file mode 100644
index 0000000..79b3240
--- /dev/null
+++ b/solr/core/src/java/org/apache/solr/update/processor/ClassificationUpdateProcessorFactory.java
@@ -0,0 +1,223 @@
+package org.apache.solr.update.processor;
+
+import org.apache.lucene.index.LeafReader;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.response.SolrQueryResponse;
+import org.apache.solr.schema.IndexSchema;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+/**
+ * This class implements an UpdateProcessorFactory for the Classification Update Processor.
+ * It takes as input a series of parameters needed to instantiate and use the Classifier.
+ */
+public class ClassificationUpdateProcessorFactory extends UpdateRequestProcessorFactory {
+
+  // Update Processor Config params
+  private static final String INPUT_FIELDS_PARAM = "inputFields";
+  private static final String CLASS_FIELD_PARAM = "classField";
+  private static final String ALGORITHM_PARAM = "algorithm";
+  private static final String KNN_MIN_TF_PARAM = "knn.minTf";
+  private static final String KNN_MIN_DF_PARAM = "knn.minDf";
+  private static final String KNN_K_PARAM = "knn.k";
+
+  // Update Processor Defaults
+  private static final int DEFAULT_MIN_TF = 1;
+  private static final int DEFAULT_MIN_DF = 1;
+  private static final int DEFAULT_K = 10;
+  private static final String DEFAULT_ALGORITHM = "knn";
+
+  private String[] inputFieldNames; // the array of fields to be sent to the Classifier
+
+  private String classFieldName; // the field containing the class for the Document
+
+  private String algorithm; // the Classification Algorithm to use - currently 'knn' or 'bayes'
+
+  private int minTf; // knn specific - the minimum Term Frequency for considering a term
+
+  private int minDf; // knn specific - the minimum Document Frequency for considering a term
+
+  private int k; // knn specific - the window of top results to evaluate, when assigning the class
+
+  @Override
+  public void init(final NamedList args) {
+    if (args != null) {
+      SolrParams params = SolrParams.toSolrParams(args);
+
+      String fieldNames = params.get(INPUT_FIELDS_PARAM);// must be a comma separated list of fields
+      checkNotNull(INPUT_FIELDS_PARAM, fieldNames);
+      inputFieldNames = fieldNames.trim().split("\\s*,\\s*"); // blanks around the separators are not part of the names
+
+      classFieldName = params.get(CLASS_FIELD_PARAM);
+      checkNotNull(CLASS_FIELD_PARAM, classFieldName);
+
+      algorithm = params.get(ALGORITHM_PARAM);
+      if (algorithm == null) {
+        algorithm = DEFAULT_ALGORITHM;
+      }
+      minTf = getIntParam(params, KNN_MIN_TF_PARAM, DEFAULT_MIN_TF);
+      minDf = getIntParam(params, KNN_MIN_DF_PARAM, DEFAULT_MIN_DF);
+      k = getIntParam(params, KNN_K_PARAM, DEFAULT_K);
+    }
+  }
+
+  /**
+   * Returns an Int parsed param or a default if the param is null or empty
+   * @param params Solr params in input
+   * @param name the param name
+   * @param defaultValue the param default
+   * @return the Int value for the param
+   * @throws SolrException if the configured value is not a valid integer
+   */
+  private int getIntParam(SolrParams params, String name, int defaultValue) {
+    String paramString = params.get(name);
+    if (paramString == null || paramString.isEmpty()) {
+      return defaultValue;
+    }
+    try {
+      return Integer.parseInt(paramString.trim());
+    } catch (NumberFormatException e) {
+      throw new SolrException(SolrException.ErrorCode.SERVER_ERROR,
+          "Classification UpdateProcessor '" + name + "' must be an integer, found: '" + paramString + "'", e);
+    }
+  }
+
+  private void checkNotNull(String paramName, Object param) {
+    if (param == null) {
+      throw new SolrException(SolrException.ErrorCode.SERVER_ERROR, "Classification UpdateProcessor '" + paramName + "' can not be null");
+    }
+  }
+
+  @Override
+  public UpdateRequestProcessor getInstance(SolrQueryRequest req, SolrQueryResponse rsp, UpdateRequestProcessor next) {
+    IndexSchema schema = req.getSchema();
+    LeafReader leafReader = req.getSearcher().getLeafReader();
+    return new ClassificationUpdateProcessor(inputFieldNames, classFieldName, minDf, minTf, k, algorithm, next, leafReader, schema);
+  }
+
+  /**
+   * get field names used as classifier's inputs
+   *
+   * @return the input field names
+   */
+  public String[] getInputFieldNames() {
+    return inputFieldNames;
+  }
+
+  /**
+   * set field names used as classifier's inputs
+   *
+   * @param inputFieldNames the input field names
+   */
+  public void setInputFieldNames(String[] inputFieldNames) {
+    this.inputFieldNames = inputFieldNames;
+  }
+
+  /**
+   * get field names used as classifier's output
+   *
+   * @return the output field name
+   */
+  public String getClassFieldName() {
+    return classFieldName;
+  }
+
+  /**
+   * set field names used as classifier's output
+   *
+   * @param classFieldName the output field name
+   */
+  public void setClassFieldName(String classFieldName) {
+    this.classFieldName = classFieldName;
+  }
+
+  /**
+   * get the name of the classifier algorithm used
+   *
+   * @return the classifier algorithm used
+   */
+  public String getAlgorithm() {
+    return algorithm;
+  }
+
+  /**
+   * set the name of the classifier algorithm used
+   *
+   * @param algorithm the classifier algorithm used
+   */
+  public void setAlgorithm(String algorithm) {
+    this.algorithm = algorithm;
+  }
+
+  /**
+   * get the min term frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @return the min term frequency
+   */
+  public int getMinTf() {
+    return minTf;
+  }
+
+  /**
+   * set the min term frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @param minTf the min term frequency
+   */
+  public void setMinTf(int minTf) {
+    this.minTf = minTf;
+  }
+
+  /**
+   * get the min document frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @return the min document frequency
+   */
+  public int getMinDf() {
+    return minDf;
+  }
+
+  /**
+   * set the min document frequency value to be used in case algorithm is {@code "knn"}
+   *
+   * @param minDf the min document frequency
+   */
+  public void setMinDf(int minDf) {
+    this.minDf = minDf;
+  }
+
+  /**
+   * get the no. of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
+   *
+   * @return the no. of neighbors to analyze
+   */
+  public int getK() {
+    return k;
+  }
+
+  /**
+   * set the no. of nearest neighbors to analyze, to be used in case algorithm is {@code "knn"}
+   *
+   * @param k the no. of neighbors to analyze
+   */
+  public void setK(int k) {
+    this.k = k;
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml
----------------------------------------------------------------------
diff --git a/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml b/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml
new file mode 100644
index 0000000..89c27a6
--- /dev/null
+++ b/solr/core/src/test-files/solr/collection1/conf/schema-classification.xml
@@ -0,0 +1,43 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+<!--
+ Test Schema for a simple Classification Test
+ -->
+<schema name="classification" version="1.1">
+ <fieldType name="string" class="solr.StrField"/>
+ <fieldType name="text_en" class="solr.TextField" positionIncrementGap="100">
+ <analyzer type="index">
+ <tokenizer class="solr.StandardTokenizerFactory"/>
+ <filter class="solr.LowerCaseFilterFactory"/>
+ <filter class="solr.EnglishPossessiveFilterFactory"/>
+ <filter class="solr.PorterStemFilterFactory"/>
+ </analyzer>
+ <analyzer type="query">
+ <tokenizer class="solr.StandardTokenizerFactory"/>
+ <filter class="solr.LowerCaseFilterFactory"/>
+ <filter class="solr.EnglishPossessiveFilterFactory"/>
+ <filter class="solr.PorterStemFilterFactory"/>
+ </analyzer>
+ </fieldType>
+ <field name="id" type="string" indexed="true" stored="true" required="true"/>
+ <field name="title" type="text_en" indexed="true" stored="true"/>
+ <field name="content" type="text_en" indexed="true" stored="true"/>
+ <field name="author" type="string" indexed="true" stored="true"/>
+ <field name="cat" type="string" indexed="true" stored="true"/>
+ <uniqueKey>id</uniqueKey>
+</schema>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
----------------------------------------------------------------------
diff --git a/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
new file mode 100644
index 0000000..3656335
--- /dev/null
+++ b/solr/core/src/test-files/solr/collection1/conf/solrconfig-classification.xml
@@ -0,0 +1,53 @@
+<?xml version="1.0" ?>
+
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<!--
+ Test Config for a simple Classification Update Request Processor Chain
+ -->
+<config>
+ <luceneMatchVersion>${tests.luceneMatchVersion:LATEST}</luceneMatchVersion>
+ <xi:include xmlns:xi="http://www.w3.org/2001/XInclude" href="solrconfig.snippet.randomindexconfig.xml"/>
+ <requestHandler name="standard" class="solr.StandardRequestHandler"></requestHandler>
+ <directoryFactory name="DirectoryFactory" class="${solr.directoryFactory:solr.RAMDirectoryFactory}"/>
+ <schemaFactory class="ClassicIndexSchemaFactory"/>
+
+ <updateHandler class="solr.DirectUpdateHandler2">
+ <updateLog enable="${enable.update.log:true}">
+ <str name="dir">${solr.ulog.dir:}</str>
+ </updateLog>
+
+ <commitWithin>
+ <softCommit>${solr.commitwithin.softcommit:true}</softCommit>
+ </commitWithin>
+
+ </updateHandler>
+
+ <updateRequestProcessorChain name="classification">
+ <processor class="solr.ClassificationUpdateProcessorFactory">
+ <str name="inputFields">title,content,author</str>
+ <str name="classField">cat</str>
+ <!-- Knn algorithm specific-->
+ <str name="algorithm">knn</str>
+ <str name="knn.minTf">1</str>
+ <str name="knn.minDf">1</str>
+ <str name="knn.k">5</str>
+ </processor>
+ <processor class="solr.RunUpdateProcessorFactory"/>
+ </updateRequestProcessorChain>
+</config>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5801caab/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
----------------------------------------------------------------------
diff --git a/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
new file mode 100644
index 0000000..27d8dca
--- /dev/null
+++ b/solr/core/src/test/org/apache/solr/update/processor/ClassificationUpdateProcessorFactoryTest.java
@@ -0,0 +1,234 @@
+package org.apache.solr.update.processor;
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.SolrException;
+import org.apache.solr.common.params.MultiMapSolrParams;
+import org.apache.solr.common.params.SolrParams;
+import org.apache.solr.common.params.UpdateParams;
+import org.apache.solr.common.util.ContentStream;
+import org.apache.solr.common.util.ContentStreamBase;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.handler.UpdateRequestHandler;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.request.SolrQueryRequestBase;
+import org.apache.solr.response.SolrQueryResponse;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.junit.Before;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+/**
+ * Tests for {@link ClassificationUpdateProcessor} and {@link ClassificationUpdateProcessorFactory}
+ */
+public class ClassificationUpdateProcessorFactoryTest extends SolrTestCaseJ4 {
+  // field names are used in accordance with the solrconfig and schema supplied
+  private static final String ID = "id";
+  private static final String TITLE = "title";
+  private static final String CONTENT = "content";
+  private static final String AUTHOR = "author";
+  private static final String CLASS = "cat";
+
+  private static final String CHAIN = "classification";
+
+  private final ClassificationUpdateProcessorFactory cFactoryToTest = new ClassificationUpdateProcessorFactory();
+  private final NamedList<String> args = new NamedList<>();
+
+  @BeforeClass
+  public static void beforeClass() throws Exception {
+    System.setProperty("enable.update.log", "false");
+    initCore("solrconfig-classification.xml", "schema-classification.xml");
+  }
+
+  @Override
+  @Before
+  public void setUp() throws Exception {
+    super.setUp();
+    clearIndex();
+    assertU(commit());
+  }
+
+  @Before
+  public void initArgs() {
+    args.add("inputFields", "inputField1,inputField2");
+    args.add("classField", "classField1");
+    args.add("algorithm", "bayes");
+    args.add("knn.k", "9");
+    args.add("knn.minDf", "8");
+    args.add("knn.minTf", "10");
+  }
+
+  @Test
+  public void testFullInit() {
+    cFactoryToTest.init(args);
+
+    String[] inputFieldNames = cFactoryToTest.getInputFieldNames();
+    assertEquals("inputField1", inputFieldNames[0]);
+    assertEquals("inputField2", inputFieldNames[1]);
+    assertEquals("classField1", cFactoryToTest.getClassFieldName());
+    assertEquals("bayes", cFactoryToTest.getAlgorithm());
+    assertEquals(8, cFactoryToTest.getMinDf());
+    assertEquals(10, cFactoryToTest.getMinTf());
+    assertEquals(9, cFactoryToTest.getK());
+  }
+
+  @Test
+  public void testInitEmptyInputField() {
+    args.removeAll("inputFields");
+    try {
+      cFactoryToTest.init(args);
+      fail("init must reject a configuration without 'inputFields'"); // otherwise the test passes vacuously
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor 'inputFields' can not be null", e.getMessage());
+    }
+  }
+
+  @Test
+  public void testInitEmptyClassField() {
+    args.removeAll("classField");
+    try {
+      cFactoryToTest.init(args);
+      fail("init must reject a configuration without 'classField'"); // otherwise the test passes vacuously
+    } catch (SolrException e) {
+      assertEquals("Classification UpdateProcessor 'classField' can not be null", e.getMessage());
+    }
+  }
+
+  @Test
+  public void testDefaults() {
+    args.removeAll("algorithm");
+    args.removeAll("knn.k");
+    args.removeAll("knn.minDf");
+    args.removeAll("knn.minTf");
+    cFactoryToTest.init(args);
+    assertEquals("knn", cFactoryToTest.getAlgorithm());
+    assertEquals(1, cFactoryToTest.getMinDf());
+    assertEquals(1, cFactoryToTest.getMinTf());
+    assertEquals(10, cFactoryToTest.getK());
+  }
+
+  @Test
+  public void testBasicClassification() throws Exception {
+    prepareTrainedIndex();
+    // To be classified, we index documents without a class and verify the expected one is returned
+    addDoc(adoc(ID, "10",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 ",
+        AUTHOR, "Name1 Surname1"));
+    addDoc(adoc(ID, "11",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname"));
+    addDoc(commit());
+
+    Document doc10 = getDoc("10");
+    assertEquals("class2", doc10.get(CLASS));
+    Document doc11 = getDoc("11");
+    assertEquals("class1", doc11.get(CLASS));
+  }
+
+  /**
+   * Index some example documents with a class manually assigned.
+   * This will be our trained model.
+   *
+   * @throws Exception If there is a low-level I/O error
+   */
+  private void prepareTrainedIndex() throws Exception {
+    //class1
+    addDoc(adoc(ID, "1",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "2",
+        TITLE, "word1 word1",
+        CONTENT, "word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "3",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    addDoc(adoc(ID, "4",
+        TITLE, "word1 word1 word1",
+        CONTENT, "word2 word2 word2",
+        AUTHOR, "Name Surname",
+        CLASS, "class1"));
+    //class2
+    addDoc(adoc(ID, "5",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "6",
+        TITLE, "word4 word4",
+        CONTENT, "word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "7",
+        TITLE, "word4 word4 word4",
+        CONTENT, "word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(adoc(ID, "8",
+        TITLE, "word4",
+        CONTENT, "word5 word5 word5 word5",
+        AUTHOR, "Name1 Surname1",
+        CLASS, "class2"));
+    addDoc(commit());
+  }
+
+  private Document getDoc(String id) throws IOException {
+    try (SolrQueryRequest req = req()) {
+      SolrIndexSearcher searcher = req.getSearcher();
+      TermQuery query = new TermQuery(new Term(ID, id));
+      TopDocs doc1 = searcher.search(query, 1);
+      ScoreDoc scoreDoc = doc1.scoreDocs[0];
+      return searcher.doc(scoreDoc.doc);
+    }
+  }
+
+  static void addDoc(String doc) throws Exception {
+    Map<String, String[]> params = new HashMap<>();
+    params.put(UpdateParams.UPDATE_CHAIN, new String[]{CHAIN});
+    SolrQueryRequestBase req = new SolrQueryRequestBase(h.getCore(), new MultiMapSolrParams(params)) {
+    };
+    try {
+      UpdateRequestHandler handler = new UpdateRequestHandler();
+      handler.init(null);
+      ArrayList<ContentStream> streams = new ArrayList<>(1);
+      streams.add(new ContentStreamBase.StringStream(doc));
+      req.setContentStreams(streams);
+      handler.handleRequestBody(req, new SolrQueryResponse());
+    } finally {
+      req.close(); // always release the request, even when the update fails
+    }
+  }
+}