You are viewing a plain-text version of this content; the hyperlink to its canonical version was lost in the plain-text conversion.
Posted to commits@lucene.apache.org by cp...@apache.org on 2021/05/25 15:01:54 UTC
[lucene-solr] branch branch_8x updated: SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)
This is an automated email from the ASF dual-hosted git repository.
cpoerschke pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git
The following commit(s) were added to refs/heads/branch_8x by this push:
new eb8b1b5 SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)
eb8b1b5 is described below
commit eb8b1b5a1e864e262f9619ed2697074c45d9daf3
Author: Christine Poerschke <cp...@apache.org>
AuthorDate: Tue May 25 14:09:02 2021 +0100
SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)
(Stanislav Livotov, Tom Gilke, Christine Poerschke)
---
.../test-files/solr/collection1/conf/schema.xml | 3 +
.../apache/solr/ltr/TestLTRReRankingPipeline.java | 345 +++++++++------------
.../org/apache/solr/ltr/model/TestLinearModel.java | 8 +-
3 files changed, 156 insertions(+), 200 deletions(-)
diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
index 4699b0f..b6f5b3b 100644
--- a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
+++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
@@ -19,6 +19,9 @@
<schema name="example" version="1.5">
<fields>
<field name="id" type="string" indexed="true" stored="true" required="true" multiValued="false" />
+ <field name="finalScore" type="string" indexed="true" stored="true" multiValued="false"/>
+ <field name="finalScoreFloat" type="float" indexed="true" stored="true" multiValued="false"/>
+ <field name="field" type="text_general" indexed="true" stored="false" multiValued="false"/>
<field name="title" type="text_general" indexed="true" stored="true"/>
<field name="description" type="text_general" indexed="true" stored="true"/>
<field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
index 8501944..e8a6942 100644
--- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
@@ -25,24 +25,18 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
-import org.apache.lucene.document.Document;
-import org.apache.lucene.document.Field;
-import org.apache.lucene.document.FloatDocValuesField;
-import org.apache.lucene.index.IndexReader;
import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.RandomIndexWriter;
import org.apache.lucene.index.Term;
import org.apache.lucene.search.BooleanClause;
import org.apache.lucene.search.BooleanQuery;
import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.Scorable;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.TermQuery;
import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.store.Directory;
-import org.apache.solr.SolrTestCase;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.core.SolrResourceLoader;
import org.apache.solr.ltr.feature.Feature;
import org.apache.solr.ltr.feature.FieldValueFeature;
@@ -50,25 +44,23 @@ import org.apache.solr.ltr.model.LTRScoringModel;
import org.apache.solr.ltr.model.TestLinearModel;
import org.apache.solr.ltr.norm.IdentityNormalizer;
import org.apache.solr.ltr.norm.Normalizer;
+import org.apache.solr.request.LocalSolrQueryRequest;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.junit.BeforeClass;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
-public class TestLTRReRankingPipeline extends SolrTestCase {
+public class TestLTRReRankingPipeline extends SolrTestCaseJ4 {
private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(Paths.get("").toAbsolutePath());
- private IndexSearcher getSearcher(IndexReader r) {
- // 'yes' to maybe wrapping in general
- final boolean maybeWrap = true;
- final boolean wrapWithAssertions = false;
- // 'no' to asserting wrap because lucene AssertingWeight
- // cannot be cast to solr LTRScoringQuery$ModelWeight
- final IndexSearcher searcher = newSearcher(r, maybeWrap, wrapWithAssertions);
-
- return searcher;
+ @BeforeClass
+ public static void setup() throws Exception {
+ initCore("solrconfig-ltr.xml", "schema.xml");
}
private static List<Feature> makeFieldValueFeatures(int[] featureIds,
@@ -109,199 +101,156 @@ public class TestLTRReRankingPipeline extends SolrTestCase {
}
@Test
- public void testRescorer() throws IOException {
- final Directory dir = newDirectory();
- final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
-
- Document doc = new Document();
- doc.add(newStringField("id", "0", Field.Store.YES));
- doc.add(newTextField("field", "wizard the the the the the oz",
- Field.Store.NO));
- doc.add(newStringField("final-score", "F", Field.Store.YES)); // TODO: change to numeric field
-
- w.addDocument(doc);
- doc = new Document();
- doc.add(newStringField("id", "1", Field.Store.YES));
- // 1 extra token, but wizard and oz are close;
- doc.add(newTextField("field", "wizard oz the the the the the the",
- Field.Store.NO));
- doc.add(newStringField("final-score", "T", Field.Store.YES)); // TODO: change to numeric field
- w.addDocument(doc);
-
- final IndexReader r = w.getReader();
- w.close();
-
- // Do ordinary BooleanQuery:
- final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
- bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
- bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
- final IndexSearcher searcher = getSearcher(r);
- // first run the standard query
- TopDocs hits = searcher.search(bqBuilder.build(), 10);
- assertEquals(2, hits.totalHits.value);
- assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
- assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-
- final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
- "final-score");
- final List<Normalizer> norms =
- new ArrayList<Normalizer>(
- Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
- final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
- 2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
- final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
- features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
-
- final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
- hits = rescorer.rescore(searcher, hits, 2);
-
- // rerank using the field final-score
- assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
- assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-
- r.close();
- dir.close();
-
+ public void testRescorer() throws Exception {
+ assertU(delQ("*:*"));
+ assertU(adoc("id", "0", "field", "wizard the the the the the oz", "finalScore", "F"));
+ assertU(adoc("id", "1", "field", "wizard oz the the the the the the", "finalScore", "T"));
+ assertU(commit());
+
+ try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+
+ final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+ bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+ bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+ final SolrIndexSearcher searcher = solrQueryRequest.getSearcher();
+ // first run the standard query
+ TopDocs hits = searcher.search(bqBuilder.build(), 10);
+ assertEquals(2, hits.totalHits.value);
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+ final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+ "finalScore");
+ final List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+ 2, 3, 4, 5, 6, 7, 8, 9}, "finalScore");
+ final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
+
+ LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel);
+ ltrScoringQuery.setRequest(solrQueryRequest);
+ final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery);
+ hits = rescorer.rescore(searcher, hits, 2);
+
+ // rerank using the field finalScore
+ assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+ }
}
- @AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/SOLR-11134")
@Test
public void testDifferentTopN() throws IOException {
- final Directory dir = newDirectory();
- final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
-
- Document doc = new Document();
- doc.add(newStringField("id", "0", Field.Store.YES));
- doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO));
- doc.add(new FloatDocValuesField("final-score", 1.0f));
- w.addDocument(doc);
-
- doc = new Document();
- doc.add(newStringField("id", "1", Field.Store.YES));
- doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO));
- doc.add(new FloatDocValuesField("final-score", 2.0f));
- w.addDocument(doc);
- doc = new Document();
- doc.add(newStringField("id", "2", Field.Store.YES));
- doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO));
- doc.add(new FloatDocValuesField("final-score", 3.0f));
- w.addDocument(doc);
- doc = new Document();
- doc.add(newStringField("id", "3", Field.Store.YES));
- doc.add(newTextField("field", "wizard oz oz the the the the ",
- Field.Store.NO));
- doc.add(new FloatDocValuesField("final-score", 4.0f));
- w.addDocument(doc);
- doc = new Document();
- doc.add(newStringField("id", "4", Field.Store.YES));
- doc.add(newTextField("field", "wizard oz the the the the the the",
- Field.Store.NO));
- doc.add(new FloatDocValuesField("final-score", 5.0f));
- w.addDocument(doc);
-
- final IndexReader r = w.getReader();
- w.close();
-
- // Do ordinary BooleanQuery:
- final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
- bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
- bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
- final IndexSearcher searcher = getSearcher(r);
+ assertU(delQ("*:*"));
+ assertU(adoc("id", "0", "field", "wizard oz oz oz oz oz", "finalScoreFloat", "1.0"));
+ assertU(adoc("id", "1", "field", "wizard oz oz oz oz the", "finalScoreFloat", "2.0"));
+ assertU(adoc("id", "2", "field", "wizard oz oz oz the the ", "finalScoreFloat", "3.0"));
+ assertU(adoc("id", "3", "field", "wizard oz oz the the the the ", "finalScoreFloat", "4.0"));
+ assertU(adoc("id", "4", "field", "wizard oz the the the the the the", "finalScoreFloat", "5.0"));
+ assertU(commit());
+
+ try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+ // Do ordinary BooleanQuery:
+ final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+ bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+ bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+ final SolrIndexSearcher searcher = solrQueryRequest.getSearcher();
+
+ // first run the standard query
+ TopDocs hits = searcher.search(bqBuilder.build(), 10);
+ assertEquals(5, hits.totalHits.value);
+
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+ assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+ assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+ assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+ final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+ "finalScoreFloat");
+ final List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+ 2, 3, 4, 5, 6, 7, 8, 9}, "finalScoreFloat");
+ final Double featureWeight = 0.1;
+ final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features, featureWeight));
+
+ LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel);
+ scoringQuery.setRequest(solrQueryRequest);
+ final LTRRescorer rescorer = new LTRRescorer(scoringQuery);
+
+ // rerank @ 0 should not change the order
+ hits = rescorer.rescore(searcher, hits, 0);
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+ assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+ assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+ assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+ // test rerank with different topN cuts
+
+ for (int topN = 1; topN <= 5; topN++) {
+ log.info("rerank {} documents ", topN);
+ hits = searcher.search(bqBuilder.build(), 10);
+
+ final ScoreDoc[] slice = new ScoreDoc[topN];
+ System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
+ hits = new TopDocs(hits.totalHits, slice);
+ hits = rescorer.rescore(searcher, hits, topN);
+ for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
+ if (log.isInfoEnabled()) {
+ log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
+ .get("id"), j);
+ }
+ assertEquals(i,
+ Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
+ assertEquals((i + 1) * features.size()*featureWeight, hits.scoreDocs[j].score, 0.00001);
- // first run the standard query
- TopDocs hits = searcher.search(bqBuilder.build(), 10);
- assertEquals(5, hits.totalHits.value);
-
- assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
- assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
- assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
- assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
- assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
-
- final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
- "final-score");
- final List<Normalizer> norms =
- new ArrayList<Normalizer>(
- Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
- final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
- 2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
- final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
- features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
-
- final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
-
- // rerank @ 0 should not change the order
- hits = rescorer.rescore(searcher, hits, 0);
- assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
- assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
- assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
- assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
- assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
-
- // test rerank with different topN cuts
-
- for (int topN = 1; topN <= 5; topN++) {
- log.info("rerank {} documents ", topN);
- hits = searcher.search(bqBuilder.build(), 10);
-
- final ScoreDoc[] slice = new ScoreDoc[topN];
- System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
- hits = new TopDocs(hits.totalHits, slice);
- hits = rescorer.rescore(searcher, hits, topN);
- for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
- if (log.isInfoEnabled()) {
- log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
- .get("id"), j);
}
-
- assertEquals(i,
- Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
- assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001);
-
}
}
-
- r.close();
- dir.close();
-
}
@Test
public void testDocParam() throws Exception {
- final Map<String,Object> test = new HashMap<String,Object>();
- test.put("fake", 2);
- List<Feature> features = makeFieldValueFeatures(new int[] {0},
- "final-score");
- List<Normalizer> norms =
- new ArrayList<Normalizer>(
- Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
- List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
- "final-score");
- MockModel ltrScoringModel = new MockModel("test",
- features, norms, "test", allFeatures, null);
- LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
- LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
- LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
- modelScr.getDocInfo().setOriginalDocScore(1f);
- for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
- assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
- }
+ try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+ List<Feature> features = makeFieldValueFeatures(new int[] {0},
+ "finalScore");
+ List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
+ "finalScore");
+ MockModel ltrScoringModel = new MockModel("test",
+ features, norms, "test", allFeatures, null);
+ LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
+ query.setRequest(solrQueryRequest);
+ LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
+ LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
+ modelScr.getDocInfo().setOriginalDocScore(1f);
+ for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
+ assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+ }
- features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score");
- norms =
- new ArrayList<Normalizer>(
- Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
- allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
- 9}, "final-score");
- ltrScoringModel = new MockModel("test", features, norms,
- "test", allFeatures, null);
- query = new LTRScoringQuery(ltrScoringModel);
- wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
- modelScr = wgt.scorer(null);
- modelScr.getDocInfo().setOriginalDocScore(1f);
- for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
- assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+ features = makeFieldValueFeatures(new int[] {0, 1, 2}, "finalScore");
+ norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
+ 9}, "finalScore");
+ ltrScoringModel = new MockModel("test", features, norms,
+ "test", allFeatures, null);
+ query = new LTRScoringQuery(ltrScoringModel);
+ query.setRequest(solrQueryRequest);
+ wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
+ modelScr = wgt.scorer(null);
+ modelScr.getDocInfo().setOriginalDocScore(1f);
+ for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
+ assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+ }
}
}
-
}
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
index df03896..f528af3 100644
--- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
@@ -47,10 +47,14 @@ public class TestLinearModel extends TestRerankBase {
}
public static Map<String,Object> makeFeatureWeights(List<Feature> features) {
+ return makeFeatureWeights(features, 0.1);
+ }
+
+ public static Map<String,Object> makeFeatureWeights(List<Feature> features, Number weight) {
final Map<String,Object> nameParams = new HashMap<String,Object>();
- final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
+ final HashMap<String,Number> modelWeights = new HashMap<String,Number>();
for (final Feature feat : features) {
- modelWeights.put(feat.getName(), 0.1);
+ modelWeights.put(feat.getName(), weight);
}
nameParams.put("weights", modelWeights);
return nameParams;