You are viewing a plain-text version of this content. (The link to the canonical version was not preserved in this copy.)
Posted to commits@lucene.apache.org by cp...@apache.org on 2021/05/25 15:01:54 UTC

[lucene-solr] branch branch_8x updated: SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)

This is an automated email from the ASF dual-hosted git repository.

cpoerschke pushed a commit to branch branch_8x
in repository https://gitbox.apache.org/repos/asf/lucene-solr.git


The following commit(s) were added to refs/heads/branch_8x by this push:
     new eb8b1b5  SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)
eb8b1b5 is described below

commit eb8b1b5a1e864e262f9619ed2697074c45d9daf3
Author: Christine Poerschke <cp...@apache.org>
AuthorDate: Tue May 25 14:09:02 2021 +0100

    SOLR-11134: restructure TestLTRReRankingPipeline and fix testDifferentTopN test (#145)
    
    (Stanislav Livotov, Tom Gilke, Christine Poerschke)
---
 .../test-files/solr/collection1/conf/schema.xml    |   3 +
 .../apache/solr/ltr/TestLTRReRankingPipeline.java  | 345 +++++++++------------
 .../org/apache/solr/ltr/model/TestLinearModel.java |   8 +-
 3 files changed, 156 insertions(+), 200 deletions(-)

diff --git a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
index 4699b0f..b6f5b3b 100644
--- a/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
+++ b/solr/contrib/ltr/src/test-files/solr/collection1/conf/schema.xml
@@ -19,6 +19,9 @@
 <schema name="example" version="1.5">
   <fields>
     <field name="id" type="string" indexed="true" stored="true" required="true" multiValued="false" />
+    <field name="finalScore" type="string" indexed="true" stored="true" multiValued="false"/>
+    <field name="finalScoreFloat" type="float" indexed="true" stored="true" multiValued="false"/>
+    <field name="field" type="text_general" indexed="true" stored="false" multiValued="false"/>
     <field name="title" type="text_general" indexed="true" stored="true"/>
     <field name="description" type="text_general" indexed="true" stored="true"/>
     <field name="keywords" type="text_general" indexed="true" stored="true" multiValued="true"/>
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
index 8501944..e8a6942 100644
--- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
@@ -25,24 +25,18 @@ import java.util.HashMap;
 import java.util.List;
 import java.util.Map;
 
-import org.apache.lucene.document.Document;
-import org.apache.lucene.document.Field;
-import org.apache.lucene.document.FloatDocValuesField;
-import org.apache.lucene.index.IndexReader;
 import org.apache.lucene.index.LeafReaderContext;
-import org.apache.lucene.index.RandomIndexWriter;
 import org.apache.lucene.index.Term;
 import org.apache.lucene.search.BooleanClause;
 import org.apache.lucene.search.BooleanQuery;
 import org.apache.lucene.search.Explanation;
-import org.apache.lucene.search.IndexSearcher;
 import org.apache.lucene.search.Scorable;
 import org.apache.lucene.search.ScoreDoc;
 import org.apache.lucene.search.ScoreMode;
 import org.apache.lucene.search.TermQuery;
 import org.apache.lucene.search.TopDocs;
-import org.apache.lucene.store.Directory;
-import org.apache.solr.SolrTestCase;
+import org.apache.solr.SolrTestCaseJ4;
+import org.apache.solr.common.params.ModifiableSolrParams;
 import org.apache.solr.core.SolrResourceLoader;
 import org.apache.solr.ltr.feature.Feature;
 import org.apache.solr.ltr.feature.FieldValueFeature;
@@ -50,25 +44,23 @@ import org.apache.solr.ltr.model.LTRScoringModel;
 import org.apache.solr.ltr.model.TestLinearModel;
 import org.apache.solr.ltr.norm.IdentityNormalizer;
 import org.apache.solr.ltr.norm.Normalizer;
+import org.apache.solr.request.LocalSolrQueryRequest;
+import org.apache.solr.request.SolrQueryRequest;
+import org.apache.solr.search.SolrIndexSearcher;
+import org.junit.BeforeClass;
 import org.junit.Test;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
-public class TestLTRReRankingPipeline extends SolrTestCase {
+public class TestLTRReRankingPipeline extends SolrTestCaseJ4 {
 
   private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
 
   private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader(Paths.get("").toAbsolutePath());
 
-  private IndexSearcher getSearcher(IndexReader r) {
-    // 'yes' to maybe wrapping in general
-    final boolean maybeWrap = true;
-    final boolean wrapWithAssertions = false;
-     // 'no' to asserting wrap because lucene AssertingWeight
-     // cannot be cast to solr LTRScoringQuery$ModelWeight
-    final IndexSearcher searcher = newSearcher(r, maybeWrap, wrapWithAssertions);
-
-    return searcher;
+  @BeforeClass
+  public static void setup() throws Exception {
+    initCore("solrconfig-ltr.xml", "schema.xml");
   }
 
   private static List<Feature> makeFieldValueFeatures(int[] featureIds,
@@ -109,199 +101,156 @@ public class TestLTRReRankingPipeline extends SolrTestCase {
   }
 
   @Test
-  public void testRescorer() throws IOException {
-    final Directory dir = newDirectory();
-    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
-
-    Document doc = new Document();
-    doc.add(newStringField("id", "0", Field.Store.YES));
-    doc.add(newTextField("field", "wizard the the the the the oz",
-        Field.Store.NO));
-    doc.add(newStringField("final-score", "F", Field.Store.YES)); // TODO: change to numeric field
-
-    w.addDocument(doc);
-    doc = new Document();
-    doc.add(newStringField("id", "1", Field.Store.YES));
-    // 1 extra token, but wizard and oz are close;
-    doc.add(newTextField("field", "wizard oz the the the the the the",
-        Field.Store.NO));
-    doc.add(newStringField("final-score", "T", Field.Store.YES)); // TODO: change to numeric field
-    w.addDocument(doc);
-
-    final IndexReader r = w.getReader();
-    w.close();
-
-    // Do ordinary BooleanQuery:
-    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
-    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
-    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
-    final IndexSearcher searcher = getSearcher(r);
-    // first run the standard query
-    TopDocs hits = searcher.search(bqBuilder.build(), 10);
-    assertEquals(2, hits.totalHits.value);
-    assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
-    assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-
-    final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
-        "final-score");
-    final List<Normalizer> norms =
-        new ArrayList<Normalizer>(
-            Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
-    final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
-        2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
-    final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
-        features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
-
-    final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
-    hits = rescorer.rescore(searcher, hits, 2);
-
-    // rerank using the field final-score
-    assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
-    assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-
-    r.close();
-    dir.close();
-
+  public void testRescorer() throws Exception {
+    assertU(delQ("*:*"));
+    assertU(adoc("id", "0", "field", "wizard the the the the the oz", "finalScore", "F"));
+    assertU(adoc("id", "1", "field", "wizard oz the the the the the the", "finalScore", "T"));
+    assertU(commit());
+
+    try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+
+      final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+      bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+      bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+      final SolrIndexSearcher searcher = solrQueryRequest.getSearcher();
+      // first run the standard query
+      TopDocs hits = searcher.search(bqBuilder.build(), 10);
+      assertEquals(2, hits.totalHits.value);
+      assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+      assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+      final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+              "finalScore");
+      final List<Normalizer> norms =
+              new ArrayList<Normalizer>(
+                      Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+      final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+              2, 3, 4, 5, 6, 7, 8, 9}, "finalScore");
+      final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+              features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
+
+      LTRScoringQuery ltrScoringQuery = new LTRScoringQuery(ltrScoringModel);
+      ltrScoringQuery.setRequest(solrQueryRequest);
+      final LTRRescorer rescorer = new LTRRescorer(ltrScoringQuery);
+      hits = rescorer.rescore(searcher, hits, 2);
+
+      // rerank using the field finalScore
+      assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+      assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+    }
   }
 
-  @AwaitsFix(bugUrl = "https://issues.apache.org/jira/browse/SOLR-11134")
   @Test
   public void testDifferentTopN() throws IOException {
-    final Directory dir = newDirectory();
-    final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
-
-    Document doc = new Document();
-    doc.add(newStringField("id", "0", Field.Store.YES));
-    doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO));
-    doc.add(new FloatDocValuesField("final-score", 1.0f));
-    w.addDocument(doc);
-
-    doc = new Document();
-    doc.add(newStringField("id", "1", Field.Store.YES));
-    doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO));
-    doc.add(new FloatDocValuesField("final-score", 2.0f));
-    w.addDocument(doc);
-    doc = new Document();
-    doc.add(newStringField("id", "2", Field.Store.YES));
-    doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO));
-    doc.add(new FloatDocValuesField("final-score", 3.0f));
-    w.addDocument(doc);
-    doc = new Document();
-    doc.add(newStringField("id", "3", Field.Store.YES));
-    doc.add(newTextField("field", "wizard oz oz the the the the ",
-        Field.Store.NO));
-    doc.add(new FloatDocValuesField("final-score", 4.0f));
-    w.addDocument(doc);
-    doc = new Document();
-    doc.add(newStringField("id", "4", Field.Store.YES));
-    doc.add(newTextField("field", "wizard oz the the the the the the",
-        Field.Store.NO));
-    doc.add(new FloatDocValuesField("final-score", 5.0f));
-    w.addDocument(doc);
-
-    final IndexReader r = w.getReader();
-    w.close();
-
-    // Do ordinary BooleanQuery:
-    final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
-    bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
-    bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
-    final IndexSearcher searcher = getSearcher(r);
+    assertU(delQ("*:*"));
+    assertU(adoc("id", "0", "field", "wizard oz oz oz oz oz", "finalScoreFloat", "1.0"));
+    assertU(adoc("id", "1", "field", "wizard oz oz oz oz the", "finalScoreFloat", "2.0"));
+    assertU(adoc("id", "2", "field", "wizard oz oz oz the the ", "finalScoreFloat", "3.0"));
+    assertU(adoc("id", "3", "field", "wizard oz oz the the the the ", "finalScoreFloat", "4.0"));
+    assertU(adoc("id", "4", "field", "wizard oz the the the the the the", "finalScoreFloat", "5.0"));
+    assertU(commit());
+
+    try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+      // Do ordinary BooleanQuery:
+      final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+      bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+      bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+      final SolrIndexSearcher searcher = solrQueryRequest.getSearcher();
+
+      // first run the standard query
+      TopDocs hits = searcher.search(bqBuilder.build(), 10);
+      assertEquals(5, hits.totalHits.value);
+
+      assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+      assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+      assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+      assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+      assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+      final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+              "finalScoreFloat");
+      final List<Normalizer> norms =
+              new ArrayList<Normalizer>(
+                      Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+      final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+              2, 3, 4, 5, 6, 7, 8, 9}, "finalScoreFloat");
+      final Double featureWeight = 0.1;
+      final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+              features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features, featureWeight));
+
+      LTRScoringQuery scoringQuery = new LTRScoringQuery(ltrScoringModel);
+      scoringQuery.setRequest(solrQueryRequest);
+      final LTRRescorer rescorer = new LTRRescorer(scoringQuery);
+
+      // rerank @ 0 should not change the order
+      hits = rescorer.rescore(searcher, hits, 0);
+      assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+      assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+      assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+      assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+      assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+      // test rerank with different topN cuts
+
+      for (int topN = 1; topN <= 5; topN++) {
+        log.info("rerank {} documents ", topN);
+        hits = searcher.search(bqBuilder.build(), 10);
+
+        final ScoreDoc[] slice = new ScoreDoc[topN];
+        System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
+        hits = new TopDocs(hits.totalHits, slice);
+        hits = rescorer.rescore(searcher, hits, topN);
+        for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
+          if (log.isInfoEnabled()) {
+            log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
+                .get("id"), j);
+          }
+          assertEquals(i,
+                  Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
+          assertEquals((i + 1) * features.size()*featureWeight, hits.scoreDocs[j].score, 0.00001);
 
-    // first run the standard query
-    TopDocs hits = searcher.search(bqBuilder.build(), 10);
-    assertEquals(5, hits.totalHits.value);
-
-    assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
-    assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-    assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
-    assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
-    assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
-
-    final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
-        "final-score");
-    final List<Normalizer> norms =
-        new ArrayList<Normalizer>(
-            Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
-    final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
-        2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
-    final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
-        features, norms, "test", allFeatures, TestLinearModel.makeFeatureWeights(features));
-
-    final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
-
-    // rerank @ 0 should not change the order
-    hits = rescorer.rescore(searcher, hits, 0);
-    assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
-    assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
-    assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
-    assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
-    assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
-
-    // test rerank with different topN cuts
-
-    for (int topN = 1; topN <= 5; topN++) {
-      log.info("rerank {} documents ", topN);
-      hits = searcher.search(bqBuilder.build(), 10);
-
-      final ScoreDoc[] slice = new ScoreDoc[topN];
-      System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
-      hits = new TopDocs(hits.totalHits, slice);
-      hits = rescorer.rescore(searcher, hits, topN);
-      for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
-        if (log.isInfoEnabled()) {
-          log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
-              .get("id"), j);
         }
-
-        assertEquals(i,
-            Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
-        assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001);
-
       }
     }
-
-    r.close();
-    dir.close();
-
   }
 
   @Test
   public void testDocParam() throws Exception {
-    final Map<String,Object> test = new HashMap<String,Object>();
-    test.put("fake", 2);
-    List<Feature> features = makeFieldValueFeatures(new int[] {0},
-        "final-score");
-    List<Normalizer> norms =
-        new ArrayList<Normalizer>(
-            Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
-    List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
-        "final-score");
-    MockModel ltrScoringModel = new MockModel("test",
-        features, norms, "test", allFeatures, null);
-    LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
-    LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
-    LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
-    modelScr.getDocInfo().setOriginalDocScore(1f);
-    for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
-      assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
-    }
+    try (SolrQueryRequest solrQueryRequest = new LocalSolrQueryRequest(h.getCore(), new ModifiableSolrParams())) {
+      List<Feature> features = makeFieldValueFeatures(new int[] {0},
+              "finalScore");
+      List<Normalizer> norms =
+              new ArrayList<Normalizer>(
+                      Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+      List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
+              "finalScore");
+      MockModel ltrScoringModel = new MockModel("test",
+              features, norms, "test", allFeatures, null);
+      LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
+      query.setRequest(solrQueryRequest);
+      LTRScoringQuery.ModelWeight wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
+      LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
+      modelScr.getDocInfo().setOriginalDocScore(1f);
+      for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
+        assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+      }
 
-    features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score");
-    norms =
-        new ArrayList<Normalizer>(
-            Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
-    allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
-        9}, "final-score");
-    ltrScoringModel = new MockModel("test", features, norms,
-        "test", allFeatures, null);
-    query = new LTRScoringQuery(ltrScoringModel);
-    wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
-    modelScr = wgt.scorer(null);
-    modelScr.getDocInfo().setOriginalDocScore(1f);
-    for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
-      assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+      features = makeFieldValueFeatures(new int[] {0, 1, 2}, "finalScore");
+      norms =
+              new ArrayList<Normalizer>(
+                      Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+      allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
+              9}, "finalScore");
+      ltrScoringModel = new MockModel("test", features, norms,
+              "test", allFeatures, null);
+      query = new LTRScoringQuery(ltrScoringModel);
+      query.setRequest(solrQueryRequest);
+      wgt = query.createWeight(null, ScoreMode.COMPLETE, 1f);
+      modelScr = wgt.scorer(null);
+      modelScr.getDocInfo().setOriginalDocScore(1f);
+      for (final Scorable.ChildScorable feat : modelScr.getChildren()) {
+        assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+      }
     }
   }
-
 }
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
index df03896..f528af3 100644
--- a/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/model/TestLinearModel.java
@@ -47,10 +47,14 @@ public class TestLinearModel extends TestRerankBase {
   }
 
   public static Map<String,Object> makeFeatureWeights(List<Feature> features) {
+    return makeFeatureWeights(features, 0.1);
+  }
+
+  public static Map<String,Object> makeFeatureWeights(List<Feature> features, Number weight) {
     final Map<String,Object> nameParams = new HashMap<String,Object>();
-    final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
+    final HashMap<String,Number> modelWeights = new HashMap<String,Number>();
     for (final Feature feat : features) {
-      modelWeights.put(feat.getName(), 0.1);
+      modelWeights.put(feat.getName(), weight);
     }
     nameParams.put("weights", modelWeights);
     return nameParams;