You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by cp...@apache.org on 2016/10/07 22:06:16 UTC
[6/8] lucene-solr:jira/solr-8542: SOLR-8542: Added Solr Learning to
Rank (LTR) plugin for reranking results with machine learning models.
(Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando,
Naveen Santhapuri, Alessandro Benedetti, David Groh
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
new file mode 100644
index 0000000..81efc81
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/FeatureLogger.java
@@ -0,0 +1,193 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.lang.invoke.MethodHandles;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+import org.apache.solr.ltr.LTRScoringQuery.FeatureInfo;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * FeatureLogger can be registered in a model and provides a strategy for logging
+ * the feature values.
+ *
+ * @param <FV_TYPE> type of the feature vector built by {@link #makeFeatureVector}
+ */
+public abstract class FeatureLogger<FV_TYPE> {
+
+  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+  /** the name of the cache used for storing the feature values */
+  private static final String QUERY_FV_CACHE_NAME = "QUERY_DOC_FV";
+
+  /** DENSE writes every feature; SPARSE writes only the features flagged as used. */
+  protected enum FeatureFormat {DENSE, SPARSE}
+
+  protected final FeatureFormat featureFormat;
+
+  protected FeatureLogger(FeatureFormat f) {
+    this.featureFormat = f;
+  }
+
+  /**
+   * Log will be called every time that the model generates the feature values
+   * for a document and a query. The vector built by {@link #makeFeatureVector}
+   * is inserted into the searcher cache named {@code QUERY_DOC_FV} under a key
+   * derived from the query hash code and the document id.
+   *
+   * @param docid
+   *          Solr document id whose features we are saving
+   * @param scoringQuery
+   *          query whose hash code contributes to the cache key
+   * @param searcher
+   *          searcher whose cache receives the feature vector
+   * @param featuresInfo
+   *          List of all the FeatureInfo objects which contain name and value
+   *          for all the features triggered by the result set
+   * @return false if no feature vector could be built or if the cache insert
+   *         returned null, true otherwise.
+   */
+  public boolean log(int docid, LTRScoringQuery scoringQuery,
+      SolrIndexSearcher searcher, FeatureInfo[] featuresInfo) {
+    final FV_TYPE featureVector = makeFeatureVector(featuresInfo);
+    if (featureVector == null) {
+      return false;
+    }
+
+    // NOTE(review): the result is whatever cacheInsert returns; its exact
+    // semantics are defined in SolrIndexSearcher - confirm that a non-null
+    // return really means "logged successfully".
+    return searcher.cacheInsert(QUERY_FV_CACHE_NAME,
+        fvCacheKey(scoringQuery, docid), featureVector) != null;
+  }
+
+  /**
+   * Returns a FeatureLogger that logs the features in output.
+   * <ul>
+   * <li>'stringFormat' param: 'csv' logs the features as a single string,
+   * 'json' logs the features as a Map of featureName keys to featureValue
+   * values; null or empty selects csv. Any other value is reported with a
+   * warning and yields null.</li>
+   * <li>'featureFormat' param: 'dense' writes all features, 'sparse' writes
+   * only the used features; null, empty or an unknown value (the latter with
+   * a warning) selects 'sparse'.</li>
+   * </ul>
+   *
+   * @param stringFormat
+   *          'csv', 'json', null or empty
+   * @param featureFormat
+   *          'dense', 'sparse', null or empty
+   * @return a feature logger for the format specified, or null if the string
+   *         format is not recognised.
+   */
+  public static FeatureLogger<?> createFeatureLogger(String stringFormat, String featureFormat) {
+    final FeatureFormat f;
+    if (featureFormat == null || featureFormat.isEmpty()
+        || featureFormat.equals("sparse")) {
+      f = FeatureFormat.SPARSE;
+    } else if (featureFormat.equals("dense")) {
+      f = FeatureFormat.DENSE;
+    } else {
+      f = FeatureFormat.SPARSE;
+      log.warn("unknown feature logger feature format {} | {}", stringFormat, featureFormat);
+    }
+
+    // csv is both an explicit choice and the default for a missing format
+    if (stringFormat == null || stringFormat.isEmpty() || stringFormat.equals("csv")) {
+      return new CSVFeatureLogger(f);
+    }
+    if (stringFormat.equals("json")) {
+      return new MapFeatureLogger(f);
+    }
+    log.warn("unknown feature logger string format {} | {}", stringFormat, featureFormat);
+    return null;
+  }
+
+  /**
+   * Builds the loggable representation of the given features.
+   *
+   * @param featuresInfo
+   *          name, value and used-flag of each feature
+   * @return the feature vector; {@link #log} treats null as "nothing to log"
+   */
+  public abstract FV_TYPE makeFeatureVector(FeatureInfo[] featuresInfo);
+
+  /**
+   * Cache key for a (query, document) pair. Plain int arithmetic, so distinct
+   * pairs are not guaranteed to produce distinct keys.
+   */
+  private static int fvCacheKey(LTRScoringQuery scoringQuery, int docid) {
+    return scoringQuery.hashCode() + (31 * docid);
+  }
+
+  /**
+   * Looks up the feature vector stored for a document and query.
+   *
+   * @param docid
+   *          Solr document id
+   * @param scoringQuery
+   *          query whose hash code contributes to the cache key
+   * @param searcher
+   *          searcher whose cache is consulted
+   * @return whatever the cache lookup returns for the key, cast to the feature
+   *         vector type of this logger
+   */
+  // The cache API is untyped; values under this cache name are expected to have
+  // been stored by log() of a logger with the same FV_TYPE.
+  @SuppressWarnings("unchecked")
+  public FV_TYPE getFeatureVector(int docid, LTRScoringQuery scoringQuery,
+      SolrIndexSearcher searcher) {
+    return (FV_TYPE) searcher.cacheLookup(QUERY_FV_CACHE_NAME, fvCacheKey(scoringQuery, docid));
+  }
+
+
+  /** Logs the features as a map of feature name to feature value. */
+  public static class MapFeatureLogger extends FeatureLogger<Map<String,Float>> {
+
+    public MapFeatureLogger(FeatureFormat f) {
+      super(f);
+    }
+
+    /**
+     * @return an immutable empty map when no feature info is given, otherwise
+     *         a new map holding every feature (dense) or only the used ones
+     *         (sparse)
+     */
+    @Override
+    public Map<String,Float> makeFeatureVector(FeatureInfo[] featuresInfo) {
+      final boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
+      Map<String,Float> hashmap = Collections.emptyMap();
+      if (featuresInfo.length > 0) {
+        hashmap = new HashMap<String,Float>(featuresInfo.length);
+        for (FeatureInfo featInfo : featuresInfo) {
+          if (featInfo.isUsed() || isDense) {
+            hashmap.put(featInfo.getName(), featInfo.getValue());
+          }
+        }
+      }
+      return hashmap;
+    }
+
+  }
+
+  /**
+   * Logs the features as one string of {@code name<keyValueSep>value} pairs
+   * joined by {@code featureSep}.
+   */
+  public static class CSVFeatureLogger extends FeatureLogger<String> {
+    char keyValueSep = ':';
+    char featureSep = ';';
+
+    public CSVFeatureLogger(FeatureFormat f) {
+      super(f);
+    }
+
+    public CSVFeatureLogger setKeyValueSep(char keyValueSep) {
+      this.keyValueSep = keyValueSep;
+      return this;
+    }
+
+    public CSVFeatureLogger setFeatureSep(char featureSep) {
+      this.featureSep = featureSep;
+      return this;
+    }
+
+    /**
+     * @return the selected features joined without a trailing separator, or
+     *         the empty string if no feature is selected
+     */
+    @Override
+    public String makeFeatureVector(FeatureInfo[] featuresInfo) {
+      final boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
+      // A per-call buffer: a buffer shared through a field would be corrupted
+      // if the same logger instance were used by two threads at once.
+      final StringBuilder sb = new StringBuilder(500);
+      for (FeatureInfo featInfo : featuresInfo) {
+        if (featInfo.isUsed() || isDense) {
+          sb.append(featInfo.getName()).append(keyValueSep)
+              .append(featInfo.getValue());
+          sb.append(featureSep);
+        }
+      }
+
+      if (sb.length() > 0) {
+        // drop the separator appended after the last feature
+        sb.setLength(sb.length() - 1);
+      }
+      return sb.toString();
+    }
+
+  }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
new file mode 100644
index 0000000..d1a7f69
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRRescorer.java
@@ -0,0 +1,251 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.util.Arrays;
+import java.util.Comparator;
+import java.util.List;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.ReaderUtil;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Rescorer;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.solr.ltr.LTRScoringQuery.ModelWeight;
+import org.apache.solr.ltr.LTRScoringQuery.ModelWeight.ModelScorer;
+import org.apache.solr.search.SolrIndexSearcher;
+
+
+/**
+ * Implements the rescoring logic. The top documents returned by solr with their
+ * original scores, will be processed by a {@link LTRScoringQuery} that will assign a
+ * new score to each document. The top documents will be resorted based on the
+ * new score.
+ * */
+public class LTRRescorer extends Rescorer {
+
+ LTRScoringQuery scoringQuery;
+ public LTRRescorer(LTRScoringQuery scoringQuery) {
+ this.scoringQuery = scoringQuery;
+ }
+
+ private void heapAdjust(ScoreDoc[] hits, int size, int root) {
+ final ScoreDoc doc = hits[root];
+ final float score = doc.score;
+ int i = root;
+ while (i <= ((size >> 1) - 1)) {
+ final int lchild = (i << 1) + 1;
+ final ScoreDoc ldoc = hits[lchild];
+ final float lscore = ldoc.score;
+ float rscore = Float.MAX_VALUE;
+ final int rchild = (i << 1) + 2;
+ ScoreDoc rdoc = null;
+ if (rchild < size) {
+ rdoc = hits[rchild];
+ rscore = rdoc.score;
+ }
+ if (lscore < score) {
+ if (rscore < lscore) {
+ hits[i] = rdoc;
+ hits[rchild] = doc;
+ i = rchild;
+ } else {
+ hits[i] = ldoc;
+ hits[lchild] = doc;
+ i = lchild;
+ }
+ } else if (rscore < score) {
+ hits[i] = rdoc;
+ hits[rchild] = doc;
+ i = rchild;
+ } else {
+ return;
+ }
+ }
+ }
+
+ private void heapify(ScoreDoc[] hits, int size) {
+ for (int i = (size >> 1) - 1; i >= 0; i--) {
+ heapAdjust(hits, size, i);
+ }
+ }
+
+ /**
+ * rescores the documents:
+ *
+ * @param searcher
+ * current IndexSearcher
+ * @param firstPassTopDocs
+ * documents to rerank;
+ * @param topN
+ * documents to return;
+ */
+ @Override
+ public TopDocs rescore(IndexSearcher searcher, TopDocs firstPassTopDocs,
+ int topN) throws IOException {
+ if ((topN == 0) || (firstPassTopDocs.totalHits == 0)) {
+ return firstPassTopDocs;
+ }
+ final ScoreDoc[] hits = firstPassTopDocs.scoreDocs;
+ Arrays.sort(hits, new Comparator<ScoreDoc>() {
+ @Override
+ public int compare(ScoreDoc a, ScoreDoc b) {
+ return a.doc - b.doc;
+ }
+ });
+
+ topN = Math.min(topN, firstPassTopDocs.totalHits);
+ final ScoreDoc[] reranked = new ScoreDoc[topN];
+ final List<LeafReaderContext> leaves = searcher.getIndexReader().leaves();
+ final ModelWeight modelWeight = (ModelWeight) searcher
+ .createNormalizedWeight(scoringQuery, true);
+
+ final SolrIndexSearcher solrIndexSearch = (SolrIndexSearcher) searcher;
+ scoreFeatures(solrIndexSearch, firstPassTopDocs,topN, modelWeight, hits, leaves, reranked);
+ // Must sort all documents that we reranked, and then select the top
+ Arrays.sort(reranked, new Comparator<ScoreDoc>() {
+ @Override
+ public int compare(ScoreDoc a, ScoreDoc b) {
+ // Sort by score descending, then docID ascending:
+ if (a.score > b.score) {
+ return -1;
+ } else if (a.score < b.score) {
+ return 1;
+ } else {
+ // This subtraction can't overflow int
+ // because docIDs are >= 0:
+ return a.doc - b.doc;
+ }
+ }
+ });
+
+ return new TopDocs(firstPassTopDocs.totalHits, reranked, reranked[0].score);
+ }
+
+ public void scoreFeatures(SolrIndexSearcher solrIndexSearch, TopDocs firstPassTopDocs,
+ int topN, ModelWeight modelWeight, ScoreDoc[] hits, List<LeafReaderContext> leaves,
+ ScoreDoc[] reranked) throws IOException {
+
+ int readerUpto = -1;
+ int endDoc = 0;
+ int docBase = 0;
+
+ ModelScorer scorer = null;
+ int hitUpto = 0;
+ final FeatureLogger<?> featureLogger = scoringQuery.getFeatureLogger();
+
+ while (hitUpto < hits.length) {
+ final ScoreDoc hit = hits[hitUpto];
+ final int docID = hit.doc;
+ LeafReaderContext readerContext = null;
+ while (docID >= endDoc) {
+ readerUpto++;
+ readerContext = leaves.get(readerUpto);
+ endDoc = readerContext.docBase + readerContext.reader().maxDoc();
+ }
+ // We advanced to another segment
+ if (readerContext != null) {
+ docBase = readerContext.docBase;
+ scorer = modelWeight.scorer(readerContext);
+ }
+ // Scorer for a ModelWeight should never be null since we always have to
+ // call score
+ // even if no feature scorers match, since a model might use that info to
+ // return a
+ // non-zero score. Same applies for the case of advancing a ModelScorer
+ // past the target
+ // doc since the model algorithm still needs to compute a potentially
+ // non-zero score from blank features.
+ assert (scorer != null);
+ final int targetDoc = docID - docBase;
+ scorer.docID();
+ scorer.iterator().advance(targetDoc);
+
+ scorer.getDocInfo().setOriginalDocScore(new Float(hit.score));
+ hit.score = scorer.score();
+ if (hitUpto < topN) {
+ reranked[hitUpto] = hit;
+ // if the heap is not full, maybe I want to log the features for this
+ // document
+ if (featureLogger != null) {
+ featureLogger.log(hit.doc, scoringQuery, solrIndexSearch,
+ modelWeight.getFeaturesInfo());
+ }
+ } else if (hitUpto == topN) {
+ // collected topN document, I create the heap
+ heapify(reranked, topN);
+ }
+ if (hitUpto >= topN) {
+ // once that heap is ready, if the score of this document is lower that
+ // the minimum
+ // i don't want to log the feature. Otherwise I replace it with the
+ // minimum and fix the
+ // heap.
+ if (hit.score > reranked[0].score) {
+ reranked[0] = hit;
+ heapAdjust(reranked, topN, 0);
+ if (featureLogger != null) {
+ featureLogger.log(hit.doc, scoringQuery, solrIndexSearch,
+ modelWeight.getFeaturesInfo());
+ }
+ }
+ }
+ hitUpto++;
+ }
+ }
+
+ @Override
+ public Explanation explain(IndexSearcher searcher,
+ Explanation firstPassExplanation, int docID) throws IOException {
+
+ final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
+ .leaves();
+ final int n = ReaderUtil.subIndex(docID, leafContexts);
+ final LeafReaderContext context = leafContexts.get(n);
+ final int deBasedDoc = docID - context.docBase;
+ final Weight modelWeight = searcher.createNormalizedWeight(scoringQuery,
+ true);
+ return modelWeight.explain(context, deBasedDoc);
+ }
+
+ public static LTRScoringQuery.FeatureInfo[] extractFeaturesInfo(ModelWeight modelWeight,
+ int docid,
+ Float originalDocScore,
+ List<LeafReaderContext> leafContexts)
+ throws IOException {
+ final int n = ReaderUtil.subIndex(docid, leafContexts);
+ final LeafReaderContext atomicContext = leafContexts.get(n);
+ final int deBasedDoc = docid - atomicContext.docBase;
+ final ModelScorer r = modelWeight.scorer(atomicContext);
+ if ( (r == null) || (r.iterator().advance(deBasedDoc) != docid) ) {
+ return new LTRScoringQuery.FeatureInfo[0];
+ } else {
+ if (originalDocScore != null) {
+ // If results have not been reranked, the score passed in is the original query's
+ // score, which some features can use instead of recalculating it
+ r.getDocInfo().setOriginalDocScore(originalDocScore);
+ }
+ r.score();
+ return modelWeight.getFeaturesInfo();
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
new file mode 100644
index 0000000..71a7ace
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
@@ -0,0 +1,730 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.Map.Entry;
+import java.util.Set;
+import java.util.concurrent.Callable;
+import java.util.concurrent.Future;
+import java.util.concurrent.FutureTask;
+import java.util.concurrent.RunnableFuture;
+import java.util.concurrent.Semaphore;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.DisiPriorityQueue;
+import org.apache.lucene.search.DisiWrapper;
+import org.apache.lucene.search.DisjunctionDISIApproximation;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.Feature.FeatureWeight;
+import org.apache.solr.ltr.feature.Feature.FeatureWeight.FeatureScorer;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.request.SolrQueryRequest;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * The ranking query that is run, reranking results using the
+ * LTRScoringModel algorithm
+ */
+public class LTRScoringQuery extends Query {
+
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ // contains a description of the model
+ final private LTRScoringModel ltrScoringModel;
+ final private boolean extractAllFeatures;
+ final private LTRThreadModule ltrThreadMgr;
+ final private Semaphore querySemaphore; // limits the number of threads per query, so that multiple requests can be serviced simultaneously
+
+ // feature logger to output the features.
+ protected FeatureLogger<?> fl;
+ // Map of external parameters, such as query intent, that can be used by
+ // features
+ protected final Map<String,String[]> efi;
+ // Original solr query used to fetch matching documents
+ protected Query originalQuery;
+ // Original solr request
+ protected SolrQueryRequest request;
+
+ /**
+  * Creates a scoring query for the given model with an empty external
+  * feature info map, extractAllFeatures disabled and no thread module.
+  *
+  * @param ltrScoringModel
+  *          model used to score the documents
+  */
+ public LTRScoringQuery(LTRScoringModel ltrScoringModel) {
+ this(ltrScoringModel, Collections.<String,String[]>emptyMap(), false, null);
+ }
+
+ /**
+  * Creates a scoring query for the given model with an empty external
+  * feature info map and no thread module.
+  *
+  * @param ltrScoringModel
+  *          model used to score the documents
+  * @param extractAllFeatures
+  *          if true, weights are created for all features returned by the
+  *          model's getAllFeatures() instead of only those of getFeatures()
+  */
+ public LTRScoringQuery(LTRScoringModel ltrScoringModel, boolean extractAllFeatures) {
+ this(ltrScoringModel, Collections.<String, String[]>emptyMap(), extractAllFeatures, null);
+ }
+
+ /**
+  * Creates a scoring query.
+  *
+  * @param ltrScoringModel
+  *          model used to score the documents
+  * @param externalFeatureInfo
+  *          external parameters made available to the features
+  * @param extractAllFeatures
+  *          if true, weights are created for all features returned by the
+  *          model's getAllFeatures() instead of only those of getFeatures()
+  * @param ltrThreadMgr
+  *          thread module used to create the feature weights in parallel;
+  *          when null no query semaphore is created and weights are created
+  *          sequentially
+  */
+ public LTRScoringQuery(LTRScoringModel ltrScoringModel,
+     Map<String, String[]> externalFeatureInfo,
+     boolean extractAllFeatures, LTRThreadModule ltrThreadMgr) {
+   this.ltrScoringModel = ltrScoringModel;
+   this.efi = externalFeatureInfo;
+   this.extractAllFeatures = extractAllFeatures;
+   this.ltrThreadMgr = ltrThreadMgr;
+   // the per-query semaphore only exists when a thread module is available
+   this.querySemaphore = (ltrThreadMgr == null)
+       ? null
+       : ltrThreadMgr.createQuerySemaphore();
+ }
+
+ /** @return the scoring model this query was constructed with */
+ public LTRScoringModel getScoringModel() {
+ return ltrScoringModel;
+ }
+
+ /**
+  * Sets the feature logger used to output the features.
+  *
+  * @param fl
+  *          the feature logger to store on this query
+  */
+ public void setFeatureLogger(FeatureLogger fl) {
+ this.fl = fl;
+ }
+
+ /** @return the feature logger stored on this query; null unless one has been assigned */
+ public FeatureLogger getFeatureLogger() {
+ return fl;
+ }
+
+ /**
+  * Sets the original solr query. It takes part in {@code hashCode()} and
+  * {@code equals()} and is passed to the features when their weights are created.
+  *
+  * @param originalQuery
+  *          the original query
+  */
+ public void setOriginalQuery(Query originalQuery) {
+ this.originalQuery = originalQuery;
+ }
+
+ /** @return the original query stored on this query; null unless one has been assigned */
+ public Query getOriginalQuery() {
+ return originalQuery;
+ }
+
+ /** @return the external feature info map given at construction (the same instance, not a copy) */
+ public Map<String,String[]> getExternalFeatureInfo() {
+ return efi;
+ }
+
+ /**
+  * Sets the solr request that is passed to the features when their weights
+  * are created.
+  *
+  * @param request
+  *          the solr request
+  */
+ public void setRequest(SolrQueryRequest request) {
+ this.request = request;
+ }
+
+ /** @return the solr request stored on this query; null unless one has been assigned */
+ public SolrQueryRequest getRequest() {
+ return request;
+ }
+
+ @Override
+ public int hashCode() {
+ final int prime = 31;
+ int result = classHash();
+ result = (prime * result) + ((ltrScoringModel == null) ? 0 : ltrScoringModel.hashCode());
+ result = (prime * result)
+ + ((originalQuery == null) ? 0 : originalQuery.hashCode());
+ if (efi == null) {
+ result = (prime * result) + 0;
+ }
+ else {
+ for (final Entry<String,String[]> entry : efi.entrySet()) {
+ final String key = entry.getKey();
+ final String[] values = entry.getValue();
+ result = (prime * result) + key.hashCode();
+ result = (prime * result) + Arrays.hashCode(values);
+ }
+ }
+ result = (prime * result) + this.toString().hashCode();
+ return result;
+ }
+ /**
+  * Two scoring queries are equal when they are of the same class and their
+  * scoring model, original query and external feature info are equal.
+  */
+ @Override
+ public boolean equals(Object o) {
+   if (!sameClassAs(o)) {
+     return false;
+   }
+   return equalsTo(getClass().cast(o));
+ }
+
+ private boolean equalsTo(LTRScoringQuery other) {
+ if (ltrScoringModel == null) {
+ if (other.ltrScoringModel != null) {
+ return false;
+ }
+ } else if (!ltrScoringModel.equals(other.ltrScoringModel)) {
+ return false;
+ }
+ if (originalQuery == null) {
+ if (other.originalQuery != null) {
+ return false;
+ }
+ } else if (!originalQuery.equals(other.originalQuery)) {
+ return false;
+ }
+ if (efi == null) {
+ if (other.efi != null) {
+ return false;
+ }
+ } else {
+ if (other.efi == null || efi.size() != other.efi.size()) {
+ return false;
+ }
+ for(final Entry<String,String[]> entry : efi.entrySet()) {
+ final String key = entry.getKey();
+ final String[] otherValues = other.efi.get(key);
+ if (otherValues == null || !Arrays.equals(otherValues,entry.getValue())) {
+ return false;
+ }
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public ModelWeight createWeight(IndexSearcher searcher, boolean needsScores, float boost)
+ throws IOException {
+ final Collection<Feature> modelFeatures = ltrScoringModel.getFeatures();
+ final Collection<Feature> allFeatures = ltrScoringModel.getAllFeatures();
+ int modelFeatSize = modelFeatures.size();
+
+ Collection<Feature> features = null;
+ if (this.extractAllFeatures) {
+ features = allFeatures;
+ }
+ else{
+ features = modelFeatures;
+ }
+ final FeatureWeight[] extractedFeatureWeights = new FeatureWeight[features.size()];
+ final FeatureWeight[] modelFeaturesWeights = new FeatureWeight[modelFeatSize];
+ List<FeatureWeight > featureWeights = new ArrayList<>(features.size());
+
+ if (querySemaphore == null) {
+ createWeights(searcher, needsScores, boost, featureWeights, features);
+ }
+ else{
+ createWeightsParallel(searcher, needsScores, boost, featureWeights, features);
+ }
+ int i=0, j = 0;
+ if (this.extractAllFeatures) {
+ for (final FeatureWeight fw : featureWeights) {
+ extractedFeatureWeights[i++] = fw;
+ }
+ for (final Feature f : modelFeatures){
+ modelFeaturesWeights[j++] = extractedFeatureWeights[f.getIndex()]; // we can lookup by featureid because all features will be extracted when this.extractAllFeatures is set
+ }
+ }
+ else{
+ for (final FeatureWeight fw: featureWeights){
+ extractedFeatureWeights[i++] = fw;
+ modelFeaturesWeights[j++] = fw;
+ }
+ }
+ return new ModelWeight(searcher, modelFeaturesWeights, extractedFeatureWeights, allFeatures.size());
+ }
+
+ /**
+  * Creates the weight of every feature sequentially, appending each one to
+  * {@code featureWeights} in the iteration order of {@code features}. Any
+  * exception raised by a feature is rethrown as a RuntimeException naming
+  * that feature.
+  */
+ private void createWeights(IndexSearcher searcher, boolean needsScores, float boost,
+     List<FeatureWeight > featureWeights, Collection<Feature> features) throws IOException {
+   final SolrQueryRequest solrRequest = getRequest();
+   // since the feature store is a linkedhashmap order is preserved
+   for (final Feature feature : features) {
+     try {
+       featureWeights.add(
+           feature.createWeight(searcher, needsScores, solrRequest, originalQuery, efi));
+     } catch (final Exception e) {
+       final String message = "Exception from createWeight for " + feature.toString() + " "
+           + e.getMessage();
+       throw new RuntimeException(message, e);
+     }
+   }
+ }
+
+ /**
+  * Task that creates the weight of a single feature. Whether it succeeds or
+  * fails, it releases one permit of the query semaphore and one of the
+  * thread module's LTR semaphore when it finishes.
+  */
+ class CreateWeightCallable implements Callable<FeatureWeight> {
+   private Feature f;
+   IndexSearcher searcher;
+   boolean needsScores;
+   SolrQueryRequest req;
+
+   public CreateWeightCallable(Feature f, IndexSearcher searcher, boolean needsScores, SolrQueryRequest req) {
+     this.req = req;
+     this.needsScores = needsScores;
+     this.searcher = searcher;
+     this.f = f;
+   }
+
+   @Override
+   public FeatureWeight call() throws Exception {
+     try {
+       return f.createWeight(searcher, needsScores, req, originalQuery, efi);
+     } catch (final Exception e) {
+       final String message = "Exception from createWeight for " + f.toString() + " "
+           + e.getMessage();
+       throw new RuntimeException(message, e);
+     } finally {
+       querySemaphore.release();
+       ltrThreadMgr.releaseLTRSemaphore();
+     }
+   }
+ } // end of class CreateWeightCallable
+
+ /**
+  * Creates the weight of every feature through the thread module. One task
+  * per feature is submitted after acquiring a permit from the query
+  * semaphore and from the thread module; the results are then collected in
+  * submission order into {@code featureWeights}.
+  *
+  * @throws RuntimeException
+  *           wrapping any failure while submitting or waiting for a task;
+  *           if the failure was an interruption, the interrupt status of the
+  *           current thread is restored first
+  */
+ private void createWeightsParallel(IndexSearcher searcher, boolean needsScores, float boost,
+     List<FeatureWeight > featureWeights, Collection<Feature> features) throws RuntimeException {
+
+   final SolrQueryRequest req = getRequest();
+   final List<Future<FeatureWeight> > futures = new ArrayList<>(features.size());
+   try {
+     for (final Feature f : features) {
+       final CreateWeightCallable callable = new CreateWeightCallable(f, searcher, needsScores, req);
+       final RunnableFuture<FeatureWeight> runnableFuture = new FutureTask<>(callable);
+       // always acquire before the ltrSemaphore is acquired, to guarantee
+       // that the current query is within the limit for max. threads
+       querySemaphore.acquire();
+       ltrThreadMgr.acquireLTRSemaphore(); // may block and/or interrupt
+       ltrThreadMgr.execute(runnableFuture); // releases semaphore when done
+       futures.add(runnableFuture);
+     }
+     // Loop over futures to get the feature weight objects
+     for (final Future<FeatureWeight> future : futures) {
+       featureWeights.add(future.get()); // future.get() will block if the job is still running
+     }
+   } catch (Exception e) { // To catch InterruptedException and ExecutionException
+     if (e instanceof InterruptedException) {
+       // keep the interruption visible to the code further up the stack
+       Thread.currentThread().interrupt();
+     }
+     log.info("Error while creating weights in LTR", e);
+     throw new RuntimeException("Error while creating weights in LTR: " + e.getMessage(), e);
+   }
+ }
+
+ /**
+  * Returns the given field name unchanged; no state of this query is
+  * included in the string form.
+  */
+ @Override
+ public String toString(String field) {
+ return field;
+ }
+
+ /**
+  * Mutable holder for the name and current value of a feature, plus a flag
+  * telling whether the feature is marked as used.
+  */
+ public class FeatureInfo {
+   String name;
+   float value;
+   boolean used;
+
+   FeatureInfo(String n, float v, boolean u) {
+     this.name = n;
+     this.value = v;
+     this.used = u;
+   }
+
+   /** Replaces the stored feature value. */
+   public void setScore(float score) {
+     value = score;
+   }
+
+   public String getName() {
+     return this.name;
+   }
+
+   public float getValue() {
+     return this.value;
+   }
+
+   public boolean isUsed() {
+     return this.used;
+   }
+
+   public void setUsed(boolean used) {
+     this.used = used;
+   }
+ }
+
+ public class ModelWeight extends Weight {
+
+ IndexSearcher searcher;
+
+ // List of the model's features used for scoring. This is a subset of the
+ // features used for logging.
+ FeatureWeight[] modelFeatureWeights;
+ float[] modelFeatureValuesNormalized;
+ FeatureWeight[] extractedFeatureWeights;
+
+ // List of all the feature names, values - used for both scoring and logging
+ /*
+ * What is the advantage of using a hashmap here instead of an array of objects?
+ * A set of arrays was used earlier and the elements were accessed using the featureId.
+ * With the updated logic to create weights selectively,
+ * the number of elements in the array can be fewer than the total number of features.
+ * When [features] are not requested, only the model features are extracted.
+ * In this case, the indexing by featureId, fails. For this reason,
+ * we need a map which holds just the features that were triggered by the documents in the result set.
+ *
+ */
+ FeatureInfo[] featuresInfo;
+ /**
+  * @param searcher
+  *          searcher this weight is created for
+  * @param modelFeatureWeights
+  *          - should be the same size as the number of features used by the model
+  * @param extractedFeatureWeights
+  *          - if features are requested from the same store as model feature store,
+  *          this will be the size of total number of features in the model feature store
+  *          else, this will be the size of the modelFeatureWeights
+  * @param allFeaturesSize
+  *          - total number of feature in the feature store used by this model
+  */
+ public ModelWeight(IndexSearcher searcher, FeatureWeight[] modelFeatureWeights,
+     FeatureWeight[] extractedFeatureWeights, int allFeaturesSize) {
+   super(LTRScoringQuery.this);
+   this.searcher = searcher;
+   this.modelFeatureWeights = modelFeatureWeights;
+   this.modelFeatureValuesNormalized = new float[modelFeatureWeights.length];
+   this.extractedFeatureWeights = extractedFeatureWeights;
+   // one slot per feature of the store; only the extracted ones get filled in
+   this.featuresInfo = new FeatureInfo[allFeaturesSize];
+   setFeaturesInfo();
+ }
+
+ private void setFeaturesInfo(){
+ for (int i = 0; i < extractedFeatureWeights.length;++i){
+ String featName = extractedFeatureWeights[i].getName();
+ int featId = extractedFeatureWeights[i].getIndex();
+ float value = extractedFeatureWeights[i].getDefaultValue();
+ featuresInfo[featId] = new FeatureInfo(featName,value,false);
+ }
+ }
+
+ /** @return the internal features info array (not a copy), indexed by feature index */
+ public FeatureInfo[] getFeaturesInfo(){
+ return featuresInfo;
+ }
+
+ /**
+  * Fills the array of model feature values - the stored value for a feature
+  * flagged as used, the feature's default value otherwise - and then lets
+  * the scoring model normalize that array in place.
+  */
+ private void makeNormalizedFeatures() {
+   for (int pos = 0; pos < modelFeatureWeights.length; pos++) {
+     final FeatureWeight feature = modelFeatureWeights[pos];
+     // not checking for a null entry as that would be a bug we should catch
+     final FeatureInfo info = featuresInfo[feature.getIndex()];
+     modelFeatureValuesNormalized[pos] =
+         info.isUsed() ? info.getValue() : feature.getDefaultValue();
+   }
+   ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized);
+ }
+
+ @Override
+ public Explanation explain(LeafReaderContext context, int doc)
+ throws IOException {
+
+ final Explanation[] explanations = new Explanation[this.featuresInfo.length];
+ for (final FeatureWeight feature : extractedFeatureWeights) {
+ explanations[feature.getIndex()] = feature.explain(context, doc);
+ }
+ final List<Explanation> featureExplanations = new ArrayList<>();
+ for (int idx = 0 ;idx < modelFeatureWeights.length; ++idx) {
+ final FeatureWeight f = modelFeatureWeights[idx];
+ Explanation e = ltrScoringModel.getNormalizerExplanation(explanations[f.getIndex()], idx);
+ featureExplanations.add(e);
+ }
+ final ModelScorer bs = scorer(context);
+ bs.iterator().advance(doc);
+
+ final float finalScore = bs.score();
+
+ return ltrScoringModel.explain(context, doc, finalScore, featureExplanations);
+
+ }
+
+ /** Collects into {@code terms} the terms of every extracted feature weight. */
+ @Override
+ public void extractTerms(Set<Term> terms) {
+   for (int i = 0; i < extractedFeatureWeights.length; i++) {
+     extractedFeatureWeights[i].extractTerms(terms);
+   }
+ }
+
+ protected void reset() {
+ for (int i = 0; i < extractedFeatureWeights.length;++i){
+ int featId = extractedFeatureWeights[i].getIndex();
+ float value = extractedFeatureWeights[i].getDefaultValue();
+ featuresInfo[featId].setScore(value); // need to set default value everytime as the default value is used in 'dense' mode even if used=false
+ featuresInfo[featId].setUsed(false);
+ }
+ }
+
+ /**
+  * Creates the model scorer for a segment from the scorers of the extracted
+  * feature weights; features whose weight yields no scorer are left out.
+  *
+  * @return a ModelScorer, never null
+  */
+ @Override
+ public ModelScorer scorer(LeafReaderContext context) throws IOException {
+
+   final List<FeatureScorer> featureScorers = new ArrayList<FeatureScorer>(
+       extractedFeatureWeights.length);
+   for (final FeatureWeight featureWeight : extractedFeatureWeights) {
+     // create the scorer once and reuse the instance that was null-checked
+     final FeatureScorer scorer = featureWeight.scorer(context);
+     if (scorer != null) {
+       featureScorers.add(scorer);
+     }
+   }
+   // Always return a ModelScorer, even if no features match, because we
+   // always need to call score on the model for every document, since 0
+   // features matching could return a non 0 score for a given model.
+   return new ModelScorer(this, featureScorers);
+ }
+
+ /**
+ * Scorer of the whole model for one segment. Every {@link Scorer} operation is
+ * delegated to an inner traversal scorer chosen in the constructor: a
+ * {@link DenseModelScorer} when there are zero or one feature scorers, a
+ * {@link SparseModelScorer} otherwise.
+ */
+ public class ModelScorer extends Scorer {
+ // Shared with every feature scorer through setDocInfo() in the constructor.
+ final private DocInfo docInfo;
+ // The dense or sparse implementation that does the actual work.
+ final private Scorer featureTraversalScorer;
+
+ /** @return the DocInfo instance that was handed to all feature scorers */
+ public DocInfo getDocInfo() {
+ return docInfo;
+ }
+
+ /**
+ * @param weight the weight passed on to the Scorer super class and to the inner scorer
+ * @param featureScorers the feature scorers to combine; each one receives this
+ * scorer's DocInfo
+ */
+ public ModelScorer(Weight weight, List<FeatureScorer> featureScorers) {
+ super(weight);
+ docInfo = new DocInfo();
+ for (final FeatureScorer subSocer : featureScorers) {
+ subSocer.setDocInfo(docInfo);
+ }
+ if (featureScorers.size() <= 1) { // TODO: Allow the use of dense
+ // features in other cases
+ featureTraversalScorer = new DenseModelScorer(weight, featureScorers);
+ } else {
+ featureTraversalScorer = new SparseModelScorer(weight, featureScorers);
+ }
+ }
+
+ @Override
+ public Collection<ChildScorer> getChildren() {
+ return featureTraversalScorer.getChildren();
+ }
+
+ @Override
+ public int docID() {
+ return featureTraversalScorer.docID();
+ }
+
+ @Override
+ public float score() throws IOException {
+ return featureTraversalScorer.score();
+ }
+
+ @Override
+ public int freq() throws IOException {
+ return featureTraversalScorer.freq();
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return featureTraversalScorer.iterator();
+ }
+
+ /**
+ * Traversal scorer backed by a priority queue of the feature scorers. Requires
+ * at least two feature scorers.
+ */
+ public class SparseModelScorer extends Scorer {
+ protected DisiPriorityQueue subScorers;
+ protected ScoringQuerySparseIterator itr;
+
+ // targetDoc: the doc most recently requested through the iterator.
+ // activeDoc: the doc the underlying disjunction was last moved to.
+ protected int targetDoc = -1;
+ protected int activeDoc = -1;
+
+ /**
+ * @throws IllegalArgumentException if fewer than two feature scorers are given
+ */
+ protected SparseModelScorer(Weight weight,
+ List<FeatureScorer> featureScorers) {
+ super(weight);
+ if (featureScorers.size() <= 1) {
+ throw new IllegalArgumentException(
+ "There must be at least 2 subScorers");
+ }
+ subScorers = new DisiPriorityQueue(featureScorers.size());
+ for (final Scorer scorer : featureScorers) {
+ final DisiWrapper w = new DisiWrapper(scorer);
+ subScorers.add(w);
+ }
+
+ itr = new ScoringQuerySparseIterator(subScorers);
+ }
+
+ @Override
+ public int docID() {
+ return itr.docID();
+ }
+
+ /**
+ * Resets all feature infos to their defaults, copies in the scores of the
+ * scorers in the queue's top list when the underlying scorers are on the
+ * requested doc, normalizes, and returns the model's score.
+ */
+ @Override
+ public float score() throws IOException {
+ final DisiWrapper topList = subScorers.topList();
+ // If target doc we wanted to advance to matches the actual doc
+ // the underlying features advanced to, perform the feature
+ // calculations,
+ // otherwise just continue with the model's scoring process with empty
+ // features.
+ reset();
+ if (activeDoc == targetDoc) {
+ for (DisiWrapper w = topList; w != null; w = w.next) {
+ final Scorer subScorer = w.scorer;
+ FeatureWeight scFW = (FeatureWeight) subScorer.getWeight();
+ final int featureId = scFW.getIndex();
+ featuresInfo[featureId].setScore(subScorer.score());
+ featuresInfo[featureId].setUsed(true);
+ }
+ }
+ makeNormalizedFeatures();
+ return ltrScoringModel.score(modelFeatureValuesNormalized);
+ }
+
+ /** @return the number of wrappers in the queue's top list */
+ @Override
+ public int freq() throws IOException {
+ final DisiWrapper subMatches = subScorers.topList();
+ // Start at 1 for the head of the list, then count the remaining links.
+ int freq = 1;
+ for (DisiWrapper w = subMatches.next; w != null; w = w.next) {
+ freq += 1;
+ }
+ return freq;
+ }
+
+ @Override
+ public DocIdSetIterator iterator() {
+ return itr;
+ }
+
+ @Override
+ public final Collection<ChildScorer> getChildren() {
+ final ArrayList<ChildScorer> children = new ArrayList<>();
+ for (final DisiWrapper scorer : subScorers) {
+ children.add(new ChildScorer(scorer.scorer, "SHOULD"));
+ }
+ return children;
+ }
+
+ /**
+ * Iterator that moves the underlying disjunction (tracked in activeDoc) and
+ * records the requested position in targetDoc. nextDoc() returns targetDoc + 1
+ * and advance() returns its argument, whether or not any feature scorer
+ * matches that doc.
+ */
+ protected class ScoringQuerySparseIterator extends
+ DisjunctionDISIApproximation {
+
+ public ScoringQuerySparseIterator(DisiPriorityQueue subIterators) {
+ super(subIterators);
+ }
+
+ @Override
+ public final int nextDoc() throws IOException {
+ // Only move the underlying scorers when they are not already past targetDoc.
+ if (activeDoc == targetDoc) {
+ activeDoc = super.nextDoc();
+ } else if (activeDoc < targetDoc) {
+ activeDoc = super.advance(targetDoc + 1);
+ }
+ return ++targetDoc;
+ }
+
+ @Override
+ public final int advance(int target) throws IOException {
+ // Move the underlying scorers only if they are behind the target; score()
+ // later compares activeDoc with targetDoc to decide whether feature values
+ // are available for the doc or the model is scored with default features.
+ if (activeDoc < target) {
+ activeDoc = super.advance(target);
+ }
+ targetDoc = target;
+ return targetDoc;
+ }
+ }
+
+ }
+
+ /**
+ * Traversal scorer that keeps the feature scorers in a plain list and checks
+ * each of them on every call. Used by ModelScorer when there are zero or one
+ * feature scorers.
+ */
+ public class DenseModelScorer extends Scorer {
+ int activeDoc = -1; // The doc that our scorer's are actually at
+ int targetDoc = -1; // The doc we were most recently told to go to
+ int freq = -1; // Set by score(); stays -1 until score() has been called
+ List<FeatureScorer> featureScorers;
+
+ protected DenseModelScorer(Weight weight,
+ List<FeatureScorer> featureScorers) {
+ super(weight);
+ this.featureScorers = featureScorers;
+ }
+
+ @Override
+ public int docID() {
+ return targetDoc;
+ }
+
+ /**
+ * Resets all feature infos to their defaults, copies in the score of every
+ * feature scorer positioned on the requested doc (counting them in freq),
+ * normalizes, and returns the model's score.
+ */
+ @Override
+ public float score() throws IOException {
+ reset();
+ freq = 0;
+ if (targetDoc == activeDoc) {
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() == activeDoc) {
+ freq++;
+ FeatureWeight scFW = (FeatureWeight) scorer.getWeight();
+ final int featureId = scFW.getIndex();
+ featuresInfo[featureId].setScore(scorer.score());
+ featuresInfo[featureId].setUsed(true);
+ }
+ }
+ }
+ makeNormalizedFeatures();
+ return ltrScoringModel.score(modelFeatureValuesNormalized);
+ }
+
+ @Override
+ public final Collection<ChildScorer> getChildren() {
+ final ArrayList<ChildScorer> children = new ArrayList<>();
+ for (final Scorer scorer : featureScorers) {
+ children.add(new ChildScorer(scorer, "SHOULD"));
+ }
+ return children;
+ }
+
+ /** @return the count computed by the most recent call to score() */
+ @Override
+ public int freq() throws IOException {
+ return freq;
+ }
+
+ /**
+ * Returns a new DenseIterator on every call. The iterator has no fields of
+ * its own; it reads and writes activeDoc and targetDoc of this scorer.
+ */
+ @Override
+ public DocIdSetIterator iterator() {
+ return new DenseIterator();
+ }
+
+ /**
+ * Iterator over the enclosing scorer's position. nextDoc() returns
+ * targetDoc + 1 and advance() returns its argument, whether or not any
+ * feature scorer matches that doc.
+ */
+ class DenseIterator extends DocIdSetIterator {
+
+ @Override
+ public int docID() {
+ return targetDoc;
+ }
+
+ @Override
+ public int nextDoc() throws IOException {
+ if (activeDoc <= targetDoc) {
+ // Step every non-exhausted feature scorer and keep the smallest doc reached.
+ activeDoc = NO_MORE_DOCS;
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() != NO_MORE_DOCS) {
+ activeDoc = Math.min(activeDoc, scorer.iterator().nextDoc());
+ }
+ }
+ }
+ return ++targetDoc;
+ }
+
+ @Override
+ public int advance(int target) throws IOException {
+ if (activeDoc < target) {
+ // Advance every non-exhausted feature scorer and keep the smallest doc reached.
+ activeDoc = NO_MORE_DOCS;
+ for (final Scorer scorer : featureScorers) {
+ if (scorer.docID() != NO_MORE_DOCS) {
+ activeDoc = Math.min(activeDoc,
+ scorer.iterator().advance(target));
+ }
+ }
+ }
+ targetDoc = target;
+ return target;
+ }
+
+ /** @return the sum of the costs of all feature scorer iterators */
+ @Override
+ public long cost() {
+ long sum = 0;
+ for (int i = 0; i < featureScorers.size(); i++) {
+ sum += featureScorers.get(i).iterator().cost();
+ }
+ return sum;
+ }
+
+ }
+ }
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java
new file mode 100644
index 0000000..0576f6f
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/LTRThreadModule.java
@@ -0,0 +1,159 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.util.Iterator;
+import java.util.Map;
+import java.util.concurrent.Executor;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.SynchronousQueue;
+import java.util.concurrent.TimeUnit;
+
+import org.apache.solr.common.util.ExecutorUtil;
+import org.apache.solr.common.util.NamedList;
+import org.apache.solr.util.DefaultSolrThreadFactory;
+import org.apache.solr.util.SolrPluginUtils;
+import org.apache.solr.util.plugin.NamedListInitializedPlugin;
+
+ /**
+  * Thread settings and executor used by the LTR plugin.
+  * <p>
+  * Settings are read from init args whose names start with {@code "threadModule."};
+  * the prefix is stripped and the remainder is applied through
+  * {@link SolrPluginUtils#invokeSetters}, so each configurable setting needs a setter here.
+  */
+ final public class LTRThreadModule implements NamedListInitializedPlugin {
+
+   /**
+    * Creates a thread module from the {@code "threadModule."}-prefixed entries of {@code args}
+    * and removes those entries from {@code args}.
+    *
+    * @param args general init args
+    * @return an initialized thread module, or {@code null} if {@code args} has no thread module entries
+    * @throws IllegalArgumentException if the resulting settings fail {@link #validate()}
+    */
+   public static LTRThreadModule getInstance(NamedList args) {
+
+     final LTRThreadModule threadManager;
+     final NamedList threadManagerArgs = extractThreadModuleParams(args);
+     // if and only if there are thread module args then we want a thread module!
+     if (threadManagerArgs.size() > 0) {
+       // create and initialize the new instance
+       threadManager = new LTRThreadModule();
+       threadManager.init(threadManagerArgs);
+     } else {
+       threadManager = null;
+     }
+
+     return threadManager;
+   }
+
+   /** Prefix that marks an init arg as belonging to the thread module. */
+   private static final String CONFIG_PREFIX = "threadModule.";
+
+   /**
+    * Moves the prefixed entries out of {@code args} into a new list, with the prefix stripped
+    * from their names.
+    */
+   private static NamedList extractThreadModuleParams(NamedList args) {
+
+     // gather the thread module args from amongst the general args
+     final NamedList extractedArgs = new NamedList();
+     for (Iterator<Map.Entry<String,Object>> it = args.iterator();
+         it.hasNext(); ) {
+       final Map.Entry<String,Object> entry = it.next();
+       final String key = entry.getKey();
+       if (key.startsWith(CONFIG_PREFIX)) {
+         extractedArgs.add(key.substring(CONFIG_PREFIX.length()), entry.getValue());
+       }
+     }
+
+     // remove consumed keys only once iteration is complete
+     // since NamedList iterator does not support 'remove'
+     for (Object key : extractedArgs.asShallowMap().keySet()) {
+       args.remove(CONFIG_PREFIX+key);
+     }
+
+     return extractedArgs;
+   }
+
+   // settings
+   private int totalPoolThreads = 1;
+   private int numThreadsPerRequest = 1;
+   private int maxPoolSize = Integer.MAX_VALUE;
+   private long keepAliveTimeSeconds = 10;
+   private String threadNamePrefix = "ltrExecutor";
+
+   // implementation
+   private Semaphore ltrSemaphore; // null when totalPoolThreads is 1
+   private Executor createWeightScoreExecutor;
+
+   public LTRThreadModule() {
+   }
+
+   // For test use only.
+   LTRThreadModule(int totalPoolThreads, int numThreadsPerRequest) {
+     this.totalPoolThreads = totalPoolThreads;
+     this.numThreadsPerRequest = numThreadsPerRequest;
+     init(null);
+   }
+
+   /**
+    * Applies {@code args} through the setters (when not {@code null}), validates the settings
+    * and creates the pool-wide semaphore and the executor.
+    *
+    * @throws IllegalArgumentException if the settings fail {@link #validate()}
+    */
+   @Override
+   public void init(NamedList args) {
+     if (args != null) {
+       SolrPluginUtils.invokeSetters(this, args);
+     }
+     validate();
+     if (this.totalPoolThreads > 1 ){
+       ltrSemaphore = new Semaphore(totalPoolThreads);
+     } else {
+       ltrSemaphore = null;
+     }
+     createWeightScoreExecutor = new ExecutorUtil.MDCAwareThreadPoolExecutor(
+         0,
+         maxPoolSize,
+         keepAliveTimeSeconds, TimeUnit.SECONDS, // terminate idle threads after keepAliveTimeSeconds
+         new SynchronousQueue<Runnable>(), // directly hand off tasks
+         new DefaultSolrThreadFactory(threadNamePrefix)
+     );
+   }
+
+   /**
+    * Checks that both thread counts are at least 1 and that the per-request count does not
+    * exceed the pool total.
+    *
+    * @throws IllegalArgumentException if a check fails
+    */
+   public void validate() {
+     if (totalPoolThreads <= 0){
+       throw new IllegalArgumentException("totalPoolThreads cannot be less than 1");
+     }
+     if (numThreadsPerRequest <= 0){
+       throw new IllegalArgumentException("numThreadsPerRequest cannot be less than 1");
+     }
+     if (totalPoolThreads < numThreadsPerRequest){
+       throw new IllegalArgumentException("numThreadsPerRequest cannot be greater than totalPoolThreads");
+     }
+   }
+
+   public void setTotalPoolThreads(int totalPoolThreads) {
+     this.totalPoolThreads = totalPoolThreads;
+   }
+
+   public void setNumThreadsPerRequest(int numThreadsPerRequest) {
+     this.numThreadsPerRequest = numThreadsPerRequest;
+   }
+
+   public void setMaxPoolSize(int maxPoolSize) {
+     this.maxPoolSize = maxPoolSize;
+   }
+
+   /**
+    * Sets how long idle executor threads are kept alive. Takes effect on the next
+    * {@link #init(NamedList)}, which is where the executor is created.
+    * NOTE(review): whether invokeSetters accepts an Integer config value for this
+    * long parameter has not been verified.
+    */
+   public void setKeepAliveTimeSeconds(long keepAliveTimeSeconds) {
+     this.keepAliveTimeSeconds = keepAliveTimeSeconds;
+   }
+
+   public void setThreadNamePrefix(String threadNamePrefix) {
+     this.threadNamePrefix = threadNamePrefix;
+   }
+
+   /**
+    * @return a new semaphore with {@code numThreadsPerRequest} permits, or {@code null} when
+    *         {@code numThreadsPerRequest} is 1
+    */
+   public Semaphore createQuerySemaphore() {
+     return (numThreadsPerRequest > 1 ? new Semaphore(numThreadsPerRequest) : null);
+   }
+
+   /**
+    * Acquires one permit of the pool-wide semaphore. The semaphore only exists when
+    * {@code totalPoolThreads} is greater than 1; callers must not call this otherwise.
+    */
+   public void acquireLTRSemaphore() throws InterruptedException {
+     ltrSemaphore.acquire();
+   }
+
+   /**
+    * Releases one permit of the pool-wide semaphore. The semaphore only exists when
+    * {@code totalPoolThreads} is greater than 1; callers must not call this otherwise.
+    */
+   public void releaseLTRSemaphore() throws InterruptedException {
+     ltrSemaphore.release();
+   }
+
+   /** Hands {@code command} to the executor created in {@link #init(NamedList)}. */
+   public void execute(Runnable command) {
+     createWeightScoreExecutor.execute(command);
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java
new file mode 100644
index 0000000..66426ea
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/SolrQueryRequestContextUtils.java
@@ -0,0 +1,83 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import org.apache.solr.request.SolrQueryRequest;
+
+ /**
+  * Static accessors for the LTR-related entries kept in the context map of a
+  * {@link SolrQueryRequest}. All keys start with {@code "ltr."}.
+  */
+ public class SolrQueryRequestContextUtils {
+
+   /** key prefix to reduce possibility of clash with other code's key choices **/
+   private static final String LTR_PREFIX = "ltr.";
+
+   /** key of the feature logger in the request context **/
+   private static final String FEATURE_LOGGER = LTR_PREFIX + "feature_logger";
+
+   /** key of the scoring query in the request context **/
+   private static final String SCORING_QUERY = LTR_PREFIX + "scoring_query";
+
+   /** key of the isExtractingFeatures flag in the request context **/
+   private static final String IS_EXTRACTING_FEATURES = LTR_PREFIX + "isExtractingFeatures";
+
+   /** key of the feature vector store name in the request context **/
+   private static final String STORE = LTR_PREFIX + "store";
+
+   // ---- feature logger ----
+
+   /** Stores the feature logger in the request context. */
+   public static void setFeatureLogger(SolrQueryRequest req, FeatureLogger<?> featureLogger) {
+     req.getContext().put(FEATURE_LOGGER, featureLogger);
+   }
+
+   /** Returns the feature logger stored in the request context, or null if none is stored. */
+   public static FeatureLogger<?> getFeatureLogger(SolrQueryRequest req) {
+     final Object stored = req.getContext().get(FEATURE_LOGGER);
+     return (FeatureLogger<?>) stored;
+   }
+
+   // ---- scoring query ----
+
+   /** Stores the scoring query in the request context. */
+   public static void setScoringQuery(SolrQueryRequest req, LTRScoringQuery scoringQuery) {
+     req.getContext().put(SCORING_QUERY, scoringQuery);
+   }
+
+   /** Returns the scoring query stored in the request context, or null if none is stored. */
+   public static LTRScoringQuery getScoringQuery(SolrQueryRequest req) {
+     final Object stored = req.getContext().get(SCORING_QUERY);
+     return (LTRScoringQuery) stored;
+   }
+
+   // ---- isExtractingFeatures flag ----
+
+   /** Sets the isExtractingFeatures flag to true. */
+   public static void setIsExtractingFeatures(SolrQueryRequest req) {
+     req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.TRUE);
+   }
+
+   /** Sets the isExtractingFeatures flag to false. */
+   public static void clearIsExtractingFeatures(SolrQueryRequest req) {
+     req.getContext().put(IS_EXTRACTING_FEATURES, Boolean.FALSE);
+   }
+
+   /** Returns true only if the flag is present in the context and equal to Boolean.TRUE. */
+   public static boolean isExtractingFeatures(SolrQueryRequest req) {
+     final Object flag = req.getContext().get(IS_EXTRACTING_FEATURES);
+     return Boolean.TRUE.equals(flag);
+   }
+
+   // ---- feature vector store name ----
+
+   /** Stores the feature vector store name in the request context. */
+   public static void setFvStoreName(SolrQueryRequest req, String fvStoreName) {
+     req.getContext().put(STORE, fvStoreName);
+   }
+
+   /** Returns the feature vector store name stored in the request context, or null if none is stored. */
+   public static String getFvStoreName(SolrQueryRequest req) {
+     final Object stored = req.getContext().get(STORE);
+     return (String) stored;
+   }
+
+ }
+
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java
new file mode 100644
index 0000000..87adde2
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/Feature.java
@@ -0,0 +1,335 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.DocInfo;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.request.macro.MacroExpander;
+import org.apache.solr.util.SolrPluginUtils;
+
+/**
+ * A recipe for computing a feature. Subclass this for specialized feature calculations.
+ * <p>
+ * A feature consists of
+ * <ul>
+ * <li> a name as the identifier
+ * <li> parameters to represent the specific feature
+ * </ul>
+ * <p>
+ * Example configuration (snippet):
+ * <pre>{
+ "class" : "...",
+ "name" : "myFeature",
+ "params" : {
+ ...
+ }
+}</pre>
+ * <p>
+ * {@link Feature} is an abstract class and concrete classes should implement
+ * the {@link #validate()} function, and must implement the {@link #paramsToMap()}
+ * and createWeight() methods.
+ */
+ public abstract class Feature extends Query {
+
+   final protected String name;
+   private int index = -1;
+   private float defaultValue = 0.0f;
+
+   final private Map<String,Object> params;
+
+   /**
+    * Instantiates the feature class {@code className} through its
+    * {@code (String, Map)} constructor, applies {@code params} via setters when
+    * {@code params} is not null, and validates the result.
+    */
+   public static Feature getInstance(SolrResourceLoader solrResourceLoader,
+       String className, String name, Map<String,Object> params) {
+     final String[] noSubPackages = new String[0];
+     final Class[] constructorSignature = { String.class, Map.class };
+     final Object[] constructorArgs = { name, params };
+     final Feature feature = solrResourceLoader.newInstance(
+         className, Feature.class, noSubPackages, constructorSignature, constructorArgs);
+     if (params != null) {
+       SolrPluginUtils.invokeSetters(feature, params.entrySet());
+     }
+     feature.validate();
+     return feature;
+   }
+
+   public Feature(String name, Map<String,Object> params) {
+     this.name = name;
+     this.params = params;
+   }
+
+   /**
+    * Hook called by {@link #getInstance} after the setters have been applied.
+    * The base implementation does nothing; subclasses may check their parameters here.
+    *
+    * @throws FeatureException
+    *           if the feature parameters are not valid
+    */
+   protected void validate() throws FeatureException {
+
+   }
+
+   /**
+    * Renders the simple class name, the feature name and, when
+    * {@link #paramsToMap()} is not null, the parameters. The {@code field}
+    * argument is not used.
+    */
+   @Override
+   public String toString(String field) {
+     // default initialCapacity of 16 won't be enough
+     final StringBuilder text = new StringBuilder(64);
+     text.append(getClass().getSimpleName()).append(" [name=").append(name);
+     final LinkedHashMap<String,Object> paramsMap = paramsToMap();
+     if (paramsMap != null) {
+       text.append(", params=").append(paramsMap);
+     }
+     return text.append(']').toString();
+   }
+
+   public abstract FeatureWeight createWeight(IndexSearcher searcher,
+       boolean needsScores, SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException;
+
+   public float getDefaultValue() {
+     return defaultValue;
+   }
+
+   /**
+    * @param value the default value as text, parsed with {@link Float#parseFloat(String)}
+    */
+   public void setDefaultValue(String value){
+     defaultValue = Float.parseFloat(value);
+   }
+
+
+   /** Combines the class hash, the index, the name and the constructor params. */
+   @Override
+   public int hashCode() {
+     int hash = classHash();
+     hash = (31 * hash) + index;
+     hash = (31 * hash) + ((name != null) ? name.hashCode() : 0);
+     hash = (31 * hash) + ((params != null) ? params.hashCode() : 0);
+     return hash;
+   }
+
+   @Override
+   public boolean equals(Object o) {
+     return sameClassAs(o) && equalsTo(getClass().cast(o));
+   }
+
+   /** Compares index, name and constructor params, in that order. */
+   private boolean equalsTo(Feature other) {
+     if (index != other.index) {
+       return false;
+     }
+     if ((name == null) ? (other.name != null) : !name.equals(other.name)) {
+       return false;
+     }
+     return (params == null) ? (other.params == null) : params.equals(other.params);
+   }
+
+   /**
+    * @return the name
+    */
+   public String getName() {
+     return name;
+   }
+
+   /**
+    * @return the index of this feature, -1 until {@link #setIndex(int)} is called
+    */
+   public int getIndex() {
+     return index;
+   }
+
+   /**
+    * @param index
+    *          Unique ID for this feature. Similar to feature name, except it can
+    *          be used to directly access the feature in the global list of
+    *          features.
+    */
+   public void setIndex(int index) {
+     this.index = index;
+   }
+
+   /** @return the parameters of this feature as an ordered map; used by {@link #toString(String)} */
+   public abstract LinkedHashMap<String,Object> paramsToMap();
+
+   /**
+    * Weight for a feature
+    **/
+   public abstract class FeatureWeight extends Weight {
+
+     final protected IndexSearcher searcher;
+     final protected SolrQueryRequest request;
+     final protected Map<String,String[]> efi;
+     final protected MacroExpander macroExpander;
+     final protected Query originalQuery;
+
+     /**
+      * Stores the arguments and creates a {@link MacroExpander} over {@code efi}.
+      *
+      * @param q
+      *          Solr query associated with this FeatureWeight
+      * @param searcher
+      *          Solr searcher available for features if they need them
+      * @param request
+      *          the request, kept for use by subclasses
+      * @param originalQuery
+      *          the original query, kept for use by subclasses
+      * @param efi
+      *          map handed to the macro expander and kept for use by subclasses
+      */
+     public FeatureWeight(Query q, IndexSearcher searcher,
+         SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
+       super(q);
+       this.searcher = searcher;
+       this.request = request;
+       this.originalQuery = originalQuery;
+       this.efi = efi;
+       this.macroExpander = new MacroExpander(efi,true);
+     }
+
+     public String getName() {
+       return Feature.this.getName();
+     }
+
+     public int getIndex() {
+       return Feature.this.getIndex();
+     }
+
+     public float getDefaultValue() {
+       return Feature.this.getDefaultValue();
+     }
+
+     @Override
+     public abstract FeatureScorer scorer(LeafReaderContext context)
+         throws IOException;
+
+     /**
+      * Explains the feature value for {@code doc}. The value is the scorer's
+      * score when the scorer lands exactly on {@code doc}; otherwise it is the
+      * default value. A null scorer yields the default value with a fixed description.
+      */
+     @Override
+     public Explanation explain(LeafReaderContext context, int doc)
+         throws IOException {
+       final FeatureScorer featureScorer = scorer(context);
+       final float fallback = getDefaultValue();
+       if (featureScorer == null) {
+         return Explanation.match(fallback, "The feature has no value");
+       }
+       featureScorer.iterator().advance(doc);
+       final float value = (featureScorer.docID() == doc) ? featureScorer.score() : fallback;
+       return Explanation.match(value, toString());
+     }
+
+     /**
+      * Used in the FeatureWeight's explain. Each feature should implement this
+      * returning properties of the specific scorer useful for an explain. For
+      * example "MyCustomClassFeature [name=" + name + "myVariable:" + myVariable +
+      * "]"; If not provided, a default implementation will return basic feature
+      * properties, which might not include query time specific values.
+      */
+     @Override
+     public String toString() {
+       return Feature.this.toString();
+     }
+
+     @Override
+     public void extractTerms(Set<Term> terms) {
+       // needs to be implemented by query subclasses
+       throw new UnsupportedOperationException();
+     }
+
+     /**
+      * A 'recipe' for computing a feature
+      */
+     public abstract class FeatureScorer extends Scorer {
+
+       final protected String name;
+       private DocInfo docInfo;
+       protected DocIdSetIterator itr;
+
+       /**
+        * @param weight the weight this scorer belongs to; also supplies the name
+        * @param itr the iterator returned by {@link #iterator()} and used for {@link #docID()}
+        */
+       public FeatureScorer(Feature.FeatureWeight weight,
+           DocIdSetIterator itr) {
+         super(weight);
+         this.itr = itr;
+         this.name = weight.getName();
+         this.docInfo = null;
+       }
+
+       @Override
+       public abstract float score() throws IOException;
+
+       /**
+        * Used to provide context from initial score steps to later reranking steps.
+        */
+       public void setDocInfo(DocInfo docInfo) {
+         this.docInfo = docInfo;
+       }
+
+       public DocInfo getDocInfo() {
+         return docInfo;
+       }
+
+       /** Not supported by feature scorers. */
+       @Override
+       public int freq() throws IOException {
+         throw new UnsupportedOperationException();
+       }
+
+       @Override
+       public int docID() {
+         return itr.docID();
+       }
+
+       @Override
+       public DocIdSetIterator iterator() {
+         return itr;
+       }
+     }
+
+     /**
+      * Default FeatureScorer class that returns the score passed in. Can be used
+      * as a simple ValueFeature, or to return a default scorer in case an
+      * underlying feature's scorer is null.
+      */
+     public class ValueFeatureScorer extends FeatureScorer {
+       float constScore;
+
+       public ValueFeatureScorer(FeatureWeight weight, float constScore,
+           DocIdSetIterator itr) {
+         super(weight,itr);
+         this.constScore = constScore;
+       }
+
+       /** @return the constant score given at construction */
+       @Override
+       public float score() {
+         return constScore;
+       }
+
+     }
+
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java
new file mode 100644
index 0000000..6c8f827
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FeatureException.java
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+ /**
+  * Unchecked exception signalling a problem with a feature, e.g. thrown from a
+  * feature's validation.
+  */
+ public class FeatureException extends RuntimeException {
+
+   private static final long serialVersionUID = 1L;
+
+   /**
+    * @param message description of the problem
+    */
+   public FeatureException(String message) {
+     super(message);
+   }
+
+   /**
+    * @param message description of the problem
+    * @param cause the underlying failure; any {@link Throwable} is accepted so that
+    *          errors as well as exceptions can be preserved as the cause
+    */
+   public FeatureException(String message, Throwable cause) {
+     super(message, cause);
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldLengthFeature.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldLengthFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldLengthFeature.java
new file mode 100644
index 0000000..20e5020
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldLengthFeature.java
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.NumericDocValues;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.util.SmallFloat;
+import org.apache.solr.request.SolrQueryRequest;
+/**
+ * This feature returns the length of a field (in terms) for the current document.
+ * Example configuration:
+ * <pre>{
+ "name": "titleLength",
+ "class": "org.apache.solr.ltr.feature.FieldLengthFeature",
+ "params": {
+ "field": "title"
+ }
+}</pre>
+ * Note: since this feature relies on norm values that are stored in a single byte,
+ * the value of the feature may differ slightly from the exact field length.
+ * (see also {@link org.apache.lucene.search.similarities.ClassicSimilarity})
+ **/
+ public class FieldLengthFeature extends Feature {
+
+   private String field;
+
+   public String getField() {
+     return field;
+   }
+
+   public void setField(String field) {
+     this.field = field;
+   }
+
+   /** @return a single-entry map holding the configured field name under "field" */
+   @Override
+   public LinkedHashMap<String,Object> paramsToMap() {
+     final LinkedHashMap<String,Object> params = new LinkedHashMap<>(1, 1.0f);
+     params.put("field", field);
+     return params;
+   }
+
+   /**
+    * Cache of decoded bytes: entry {@code i} holds {@code 1 / (norm * norm)} where
+    * {@code norm} is {@code SmallFloat.byte315ToFloat((byte) i)}; entry 0 is 0.
+    */
+   private static final float[] NORM_TABLE = new float[256];
+
+   static {
+     NORM_TABLE[0] = 0;
+     for (int i = 1; i < 256; i++) {
+       float norm = SmallFloat.byte315ToFloat((byte) i);
+       NORM_TABLE[i] = 1.0f / (norm * norm);
+     }
+   }
+
+   /**
+    * Decodes the norm value, assuming it is a single byte.
+    *
+    * @param norm the raw norm; only the lowest 8 bits are used
+    * @return the cached decoded value for that byte
+    */
+   private static float decodeNorm(long norm) {
+     // & 0xFF maps negative bytes to positive above 127
+     return NORM_TABLE[(int) (norm & 0xFF)];
+   }
+
+   public FieldLengthFeature(String name, Map<String,Object> params) {
+     super(name, params);
+   }
+
+   @Override
+   public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
+       SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
+       throws IOException {
+
+     return new FieldLengthFeatureWeight(searcher, request, originalQuery, efi);
+   }
+
+
+   public class FieldLengthFeatureWeight extends FeatureWeight {
+
+     public FieldLengthFeatureWeight(IndexSearcher searcher,
+         SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
+       super(FieldLengthFeature.this, searcher, request, originalQuery, efi);
+     }
+
+     /**
+      * Returns a scorer over the norms of the configured field, or a constant
+      * scorer with value 0 when the segment has no norms for that field.
+      */
+     @Override
+     public FeatureScorer scorer(LeafReaderContext context) throws IOException {
+       NumericDocValues norms = context.reader().getNormValues(field);
+       if (norms == null){
+         return new ValueFeatureScorer(this, 0f,
+             DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
+       }
+       return new FieldLengthFeatureScorer(this, norms);
+     }
+
+     public class FieldLengthFeatureScorer extends FeatureScorer {
+
+       NumericDocValues norms = null;
+
+       /**
+        * @throws IOException if document 0 carries the field and its field type
+        *           reports that norms are omitted
+        */
+       public FieldLengthFeatureScorer(FeatureWeight weight,
+           NumericDocValues norms) throws IOException {
+         super(weight, norms);
+         this.norms = norms;
+
+         // In the constructor, docId is -1, so using 0 as default lookup
+         final IndexableField idxF = searcher.doc(0).getField(field);
+         // Document 0 may not carry the field at all (getField then returns null);
+         // in that case there is nothing to check rather than failing with an NPE.
+         if (idxF != null && idxF.fieldType().omitNorms()) {
+           throw new IOException(
+               "FieldLengthFeatures can't be used if omitNorms is enabled (field="
+                   + field + ")");
+         }
+       }
+
+       /** @return the decoded norm of the current document */
+       @Override
+       public float score() throws IOException {
+         return decodeNorm(norms.longValue());
+       }
+     }
+   }
+
+ }
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java
new file mode 100644
index 0000000..49e0787
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/FieldValueFeature.java
@@ -0,0 +1,133 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+import java.util.Set;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.index.IndexableField;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.solr.request.SolrQueryRequest;
+
+import com.google.common.collect.Sets;
+/**
+ * This feature returns the value of a field in the current document.
+ * Numeric stored values are returned as floats, the strings "T" and "F" are
+ * mapped to 1 and 0, and anything else yields the feature's default value.
+ * Example configuration:
+ * <pre>{
+ "name": "rawHits",
+ "class": "org.apache.solr.ltr.feature.FieldValueFeature",
+ "params": {
+ "field": "hits"
+ }
+}</pre>
+ */
+public class FieldValueFeature extends Feature {
+
+  /** Name of the field whose stored value is read. */
+  private String field;
+  /** Single-element set holding {@link #field}, used to restrict document loading. */
+  private Set<String> fieldAsSet;
+
+  public String getField() {
+    return field;
+  }
+
+  /**
+   * Sets the field to read and refreshes the set used when loading documents.
+   *
+   * @param field name of the field
+   */
+  public void setField(String field) {
+    this.field = field;
+    fieldAsSet = Sets.newHashSet(field);
+  }
+
+  /**
+   * @return a map with the single entry {@code "field"}
+   */
+  @Override
+  public LinkedHashMap<String,Object> paramsToMap() {
+    final LinkedHashMap<String,Object> params = new LinkedHashMap<>(1, 1.0f);
+    params.put("field", field);
+    return params;
+  }
+
+  public FieldValueFeature(String name, Map<String,Object> params) {
+    super(name, params);
+  }
+
+  @Override
+  public FeatureWeight createWeight(IndexSearcher searcher, boolean needsScores,
+      SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi)
+      throws IOException {
+    return new FieldValueFeatureWeight(searcher, request, originalQuery, efi);
+  }
+
+  /** Weight that creates a {@link FieldValueFeatureScorer} for each segment. */
+  public class FieldValueFeatureWeight extends FeatureWeight {
+
+    public FieldValueFeatureWeight(IndexSearcher searcher,
+        SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) {
+      super(FieldValueFeature.this, searcher, request, originalQuery, efi);
+    }
+
+    /**
+     * @return a scorer over all documents of the given segment
+     */
+    @Override
+    public FeatureScorer scorer(LeafReaderContext context) throws IOException {
+      return new FieldValueFeatureScorer(this, context,
+          DocIdSetIterator.all(DocIdSetIterator.NO_MORE_DOCS));
+    }
+
+    /** Scorer that loads the field of the current document and converts it to a float. */
+    public class FieldValueFeatureScorer extends FeatureScorer {
+
+      LeafReaderContext context = null;
+
+      public FieldValueFeatureScorer(FeatureWeight weight,
+          LeafReaderContext context, DocIdSetIterator itr) {
+        super(weight, itr);
+        this.context = context;
+      }
+
+      /**
+       * Reads the field from the current document.
+       *
+       * @return the numeric value as a float, 1 for "T", 0 for "F", otherwise
+       *         the default value of the feature
+       * @throws FeatureException if the document cannot be loaded
+       */
+      @Override
+      public float score() throws IOException {
+
+        try {
+          // only the requested field is loaded from the document
+          final Document document = context.reader().document(itr.docID(),
+              fieldAsSet);
+          final IndexableField indexableField = document.getField(field);
+          if (indexableField == null) {
+            return getDefaultValue();
+          }
+          final Number number = indexableField.numericValue();
+          if (number != null) {
+            return number.floatValue();
+          }
+          final String string = indexableField.stringValue();
+          // boolean values in the index are encoded with the
+          // chars T/F
+          // Constant-first comparison: stringValue() is null for a field that
+          // has neither a numeric nor a string value, and such a field must
+          // fall through to the default instead of raising a NullPointerException.
+          if ("T".equals(string)) {
+            return 1;
+          }
+          if ("F".equals(string)) {
+            return 0;
+          }
+        } catch (final IOException e) {
+          throw new FeatureException(
+              e.toString() + ": " +
+              "Unable to extract feature for "
+              + name, e);
+        }
+        return getDefaultValue();
+      }
+    }
+  }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/f2a8e8ac/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java
new file mode 100644
index 0000000..f141474
--- /dev/null
+++ b/solr/contrib/ltr/src/java/org/apache/solr/ltr/feature/OriginalScoreFeature.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr.feature;
+
+import java.io.IOException;
+import java.util.LinkedHashMap;
+import java.util.Map;
+
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.search.DocIdSetIterator;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Query;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.Weight;
+import org.apache.solr.ltr.DocInfo;
+import org.apache.solr.request.SolrQueryRequest;
+/**
+ * This feature returns the original score that the document had before performing
+ * the reranking.
+ * Example configuration:
+ * <pre>{
+ "name": "originalScore",
+ "class": "org.apache.solr.ltr.feature.OriginalScoreFeature",
+ "params": { }
+}</pre>
+ */
+public class OriginalScoreFeature extends Feature {
+
+  public OriginalScoreFeature(String name, Map<String,Object> params) {
+    super(name, params);
+  }
+
+  /**
+   * @return always {@code null}; no parameters are exposed by this feature
+   */
+  @Override
+  public LinkedHashMap<String,Object> paramsToMap() {
+    return null;
+  }
+
+  @Override
+  public OriginalScoreWeight createWeight(IndexSearcher searcher,
+      boolean needsScores, SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
+    return new OriginalScoreWeight(searcher, request, originalQuery, efi);
+  }
+
+  /**
+   * Weight that wraps the normalized weight of the original query so that
+   * its score can be exposed as a feature value.
+   */
+  public class OriginalScoreWeight extends FeatureWeight {
+
+    /** Normalized weight created from the original query. */
+    final Weight w;
+
+    public OriginalScoreWeight(IndexSearcher searcher,
+        SolrQueryRequest request, Query originalQuery, Map<String,String[]> efi) throws IOException {
+      super(OriginalScoreFeature.this, searcher, request, originalQuery, efi);
+      w = searcher.createNormalizedWeight(originalQuery, true);
+    }
+
+    @Override
+    public String toString() {
+      return "OriginalScoreFeature [query:" + originalQuery.toString() + "]";
+    }
+
+    /**
+     * Creates the scorer for one segment.
+     *
+     * @param context the segment to score
+     * @return a scorer delegating to the scorer of the original query, or
+     *         {@code null} when the original query has no scorer for this
+     *         segment
+     */
+    @Override
+    public FeatureScorer scorer(LeafReaderContext context) throws IOException {
+      final Scorer originalScorer = w.scorer(context);
+      if (originalScorer == null) {
+        // Weight.scorer may return null when no document of the segment is
+        // scored. Wrapping that null would make docID() and iterator() of
+        // the wrapper throw a NullPointerException, so the null is passed on.
+        // NOTE(review): assumes callers treat a null FeatureScorer like a
+        // null Lucene Scorer - confirm against the callers of scorer().
+        return null;
+      }
+      return new OriginalScoreScorer(this, originalScorer);
+    }
+
+    /**
+     * Scorer that prefers the score already recorded for the document and
+     * otherwise asks the scorer of the original query.
+     */
+    public class OriginalScoreScorer extends FeatureScorer {
+      Scorer originalScorer;
+
+      public OriginalScoreScorer(FeatureWeight weight, Scorer originalScorer) {
+        // no iterator is handed to the parent; docID() and iterator() below
+        // delegate to the original scorer instead
+        super(weight,null);
+        this.originalScorer = originalScorer;
+      }
+
+      @Override
+      public float score() throws IOException {
+        // This is done to improve the speed of feature extraction. Since this
+        // was already scored in step 1
+        // we shouldn't need to calc original score again.
+        final DocInfo docInfo = getDocInfo();
+        return (docInfo.hasOriginalDocScore() ? docInfo.getOriginalDocScore() : originalScorer.score());
+      }
+
+      @Override
+      public int docID() {
+        return originalScorer.docID();
+      }
+
+      @Override
+      public DocIdSetIterator iterator() {
+        return originalScorer.iterator();
+      }
+    }
+
+  }
+
+}