You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@lucene.apache.org by kr...@apache.org on 2016/11/03 14:40:52 UTC
[05/20] lucene-solr:jira/solr-8593: SOLR-8542: Adds Solr Learning to
Rank (LTR) plugin for reranking results with machine learning models.
(Michael Nilsson, Diego Ceccarelli, Joshua Pantony, Jon Dorando,
Naveen Santhapuri, Alessandro Benedetti, David Grohmann, Christine Poerschke)
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test-files/solr/solr.xml
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test-files/solr/solr.xml b/solr/contrib/ltr/src/test-files/solr/solr.xml
new file mode 100644
index 0000000..c8c3ebe
--- /dev/null
+++ b/solr/contrib/ltr/src/test-files/solr/solr.xml
@@ -0,0 +1,42 @@
+<?xml version="1.0" encoding="UTF-8" ?>
+<!--
+ Licensed to the Apache Software Foundation (ASF) under one or more
+ contributor license agreements. See the NOTICE file distributed with
+ this work for additional information regarding copyright ownership.
+ The ASF licenses this file to You under the Apache License, Version 2.0
+ (the "License"); you may not use this file except in compliance with
+ the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
+-->
+
+<solr>
+
+ <str name="shareSchema">${shareSchema:false}</str>
+ <str name="configSetBaseDir">${configSetBaseDir:configsets}</str>
+ <str name="coreRootDirectory">${coreRootDirectory:.}</str>
+
+ <shardHandlerFactory name="shardHandlerFactory" class="HttpShardHandlerFactory">
+ <str name="urlScheme">${urlScheme:}</str>
+ <int name="socketTimeout">${socketTimeout:90000}</int>
+ <int name="connTimeout">${connTimeout:15000}</int>
+ </shardHandlerFactory>
+
+ <solrcloud>
+ <str name="host">127.0.0.1</str>
+ <int name="hostPort">${hostPort:8983}</int>
+ <str name="hostContext">${hostContext:solr}</str>
+ <int name="zkClientTimeout">${solr.zkclienttimeout:30000}</int>
+ <bool name="genericCoreNodeNames">${genericCoreNodeNames:true}</bool>
+ <int name="leaderVoteWait">${leaderVoteWait:10000}</int>
+ <int name="distribUpdateConnTimeout">${distribUpdateConnTimeout:45000}</int>
+ <int name="distribUpdateSoTimeout">${distribUpdateSoTimeout:340000}</int>
+ </solrcloud>
+
+</solr>
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java
new file mode 100644
index 0000000..2e01a64
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTROnSolrCloud.java
@@ -0,0 +1,211 @@
+/* * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.File;
+import java.util.SortedMap;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.client.solrj.embedded.JettyConfig;
+import org.apache.solr.client.solrj.request.CollectionAdminRequest;
+import org.apache.solr.client.solrj.response.CollectionAdminResponse;
+import org.apache.solr.client.solrj.response.QueryResponse;
+import org.apache.solr.cloud.AbstractDistribZkTestBase;
+import org.apache.solr.cloud.MiniSolrCloudCluster;
+import org.apache.solr.common.SolrInputDocument;
+import org.apache.solr.common.cloud.ZkStateReader;
+import org.apache.solr.ltr.feature.SolrFeature;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.junit.AfterClass;
+import org.junit.Test;
+
+public class TestLTROnSolrCloud extends TestRerankBase {
+
+ private MiniSolrCloudCluster solrCluster;
+ String solrconfig = "solrconfig-ltr.xml";
+ String schema = "schema.xml";
+
+ SortedMap<ServletHolder,String> extraServlets = null;
+
+ @Override
+ public void setUp() throws Exception {
+ super.setUp();
+ extraServlets = setupTestInit(solrconfig, schema, true);
+ System.setProperty("enable.update.log", "true");
+
+ int numberOfShards = random().nextInt(4)+1;
+ int numberOfReplicas = random().nextInt(2)+1;
+ int maxShardsPerNode = numberOfShards+random().nextInt(4)+1;
+
+ int numberOfNodes = numberOfShards * maxShardsPerNode;
+
+ setupSolrCluster(numberOfShards, numberOfReplicas, numberOfNodes, maxShardsPerNode);
+
+
+ }
+
+
+ @Override
+ public void tearDown() throws Exception {
+ restTestHarness.close();
+ restTestHarness = null;
+ jetty.stop();
+ jetty = null;
+ solrCluster.shutdown();
+ super.tearDown();
+ }
+
+ @Test
+ public void testSimpleQuery() throws Exception {
+ // will randomly pick a configuration with [1..5] shards and [1..3] replicas
+
+ // Test regular query, it will sort the documents by inverse
+ // popularity (the less popular, docid == 1, will be in the first
+ // position
+ SolrQuery query = new SolrQuery("{!func}sub(8,field(popularity))");
+
+ query.setRequestHandler("/query");
+ query.setFields("*,score");
+ query.setParam("rows", "8");
+
+ QueryResponse queryResponse =
+ solrCluster.getSolrClient().query(COLLECTION,query);
+ assertEquals(8, queryResponse.getResults().getNumFound());
+ assertEquals("1", queryResponse.getResults().get(0).get("id").toString());
+ assertEquals("2", queryResponse.getResults().get(1).get("id").toString());
+ assertEquals("3", queryResponse.getResults().get(2).get("id").toString());
+ assertEquals("4", queryResponse.getResults().get(3).get("id").toString());
+
+ // Test re-rank and feature vectors returned
+ query.setFields("*,score,features:[fv]");
+ query.add("rq", "{!ltr model=powpularityS-model reRankDocs=8}");
+ queryResponse =
+ solrCluster.getSolrClient().query(COLLECTION,query);
+ assertEquals(8, queryResponse.getResults().getNumFound());
+ assertEquals("8", queryResponse.getResults().get(0).get("id").toString());
+ assertEquals("powpularityS:64.0;c3:2.0",
+ queryResponse.getResults().get(0).get("features").toString());
+ assertEquals("7", queryResponse.getResults().get(1).get("id").toString());
+ assertEquals("powpularityS:49.0;c3:2.0",
+ queryResponse.getResults().get(1).get("features").toString());
+ assertEquals("6", queryResponse.getResults().get(2).get("id").toString());
+ assertEquals("powpularityS:36.0;c3:2.0",
+ queryResponse.getResults().get(2).get("features").toString());
+ assertEquals("5", queryResponse.getResults().get(3).get("id").toString());
+ assertEquals("powpularityS:25.0;c3:2.0",
+ queryResponse.getResults().get(3).get("features").toString());
+ }
+
+ private void setupSolrCluster(int numShards, int numReplicas, int numServers, int maxShardsPerNode) throws Exception {
+ JettyConfig jc = buildJettyConfig("/solr");
+ jc = JettyConfig.builder(jc).withServlets(extraServlets).build();
+ solrCluster = new MiniSolrCloudCluster(numServers, tmpSolrHome.toPath(), jc);
+ File configDir = tmpSolrHome.toPath().resolve("collection1/conf").toFile();
+ solrCluster.uploadConfigSet(configDir.toPath(), "conf1");
+
+ solrCluster.getSolrClient().setDefaultCollection(COLLECTION);
+
+ createCollection(COLLECTION, "conf1", numShards, numReplicas, maxShardsPerNode);
+ indexDocuments(COLLECTION);
+
+ createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
+ "/solr", true, extraServlets);
+ loadModelsAndFeatures();
+ }
+
+
+ private void createCollection(String name, String config, int numShards, int numReplicas, int maxShardsPerNode)
+ throws Exception {
+ CollectionAdminResponse response;
+ CollectionAdminRequest.Create create =
+ CollectionAdminRequest.createCollection(name, config, numShards, numReplicas);
+ create.setMaxShardsPerNode(maxShardsPerNode);
+ response = create.process(solrCluster.getSolrClient());
+
+ if (response.getStatus() != 0 || response.getErrorMessages() != null) {
+ fail("Could not create collection. Response" + response.toString());
+ }
+ ZkStateReader zkStateReader = solrCluster.getSolrClient().getZkStateReader();
+ AbstractDistribZkTestBase.waitForRecoveriesToFinish(name, zkStateReader, false, true, 100);
+ }
+
+
+ void indexDocument(String collection, String id, String title, String description, int popularity)
+ throws Exception{
+ SolrInputDocument doc = new SolrInputDocument();
+ doc.setField("id", id);
+ doc.setField("title", title);
+ doc.setField("description", description);
+ doc.setField("popularity", popularity);
+ solrCluster.getSolrClient().add(collection, doc);
+ }
+
+ private void indexDocuments(final String collection)
+ throws Exception {
+ final int collectionSize = 8;
+ for (int docId = 1; docId <= collectionSize; docId++) {
+ final int popularity = docId;
+ indexDocument(collection, String.valueOf(docId), "a1", "bloom", popularity);
+ }
+ solrCluster.getSolrClient().commit(collection);
+ }
+
+
+ private void loadModelsAndFeatures() throws Exception {
+ final String featureStore = "test";
+ final String[] featureNames = new String[] {"powpularityS","c3"};
+ final String jsonModelParams = "{\"weights\":{\"powpularityS\":1.0,\"c3\":1.0}}";
+
+ loadFeature(
+ featureNames[0],
+ SolrFeature.class.getCanonicalName(),
+ featureStore,
+ "{\"q\":\"{!func}pow(popularity,2)\"}"
+ );
+ loadFeature(
+ featureNames[1],
+ ValueFeature.class.getCanonicalName(),
+ featureStore,
+ "{\"value\":2}"
+ );
+
+ loadModel(
+ "powpularityS-model",
+ LinearModel.class.getCanonicalName(),
+ featureNames,
+ featureStore,
+ jsonModelParams
+ );
+ reloadCollection(COLLECTION);
+ }
+
+ private void reloadCollection(String collection) throws Exception {
+ CollectionAdminRequest.Reload reloadRequest = CollectionAdminRequest.reloadCollection(collection);
+ CollectionAdminResponse response = reloadRequest.process(solrCluster.getSolrClient());
+ assertEquals(0, response.getStatus());
+ assertTrue(response.isSuccess());
+ }
+
+ @AfterClass
+ public static void after() throws Exception {
+ FileUtils.deleteDirectory(tmpSolrHome);
+ System.clearProperty("managed.schema.mutable");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java
new file mode 100644
index 0000000..2f90df8
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserExplain.java
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestLTRQParserExplain extends TestRerankBase {
+
+ @BeforeClass
+ public static void setup() throws Exception {
+ setuptest();
+ loadFeatures("features-store-test-model.json");
+ }
+
+ @AfterClass
+ public static void after() throws Exception {
+ aftertest();
+ }
+
+
+ @Test
+ public void testRerankedExplain() throws Exception {
+ loadModel("linear2", LinearModel.class.getCanonicalName(), new String[] {
+ "constant1", "constant2", "pop"},
+ "{\"weights\":{\"pop\":1.0,\"constant1\":1.5,\"constant2\":3.5}}");
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("title:bloomberg");
+ query.setParam("debugQuery", "on");
+ query.add("rows", "2");
+ query.add("rq", "{!ltr reRankDocs=2 model=linear2}");
+ query.add("fl", "*,score");
+
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/9=='\n13.5 = LinearModel(name=linear2,featureWeights=[constant1=1.5,constant2=3.5,pop=1.0]) model applied to features, sum of:\n 1.5 = prod of:\n 1.5 = weight on feature\n 1.0 = ValueFeature [name=constant1, params={value=1}]\n 7.0 = prod of:\n 3.5 = weight on feature\n 2.0 = ValueFeature [name=constant2, params={value=2}]\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = FieldValueFeature [name=pop, params={field=popularity}]\n'");
+ }
+
+ @Test
+ public void testRerankedExplainSameBetweenDifferentDocsWithSameFeatures() throws Exception {
+ loadFeatures("features-linear.json");
+ loadModels("linear-model.json");
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("title:bloomberg");
+ query.setParam("debugQuery", "on");
+ query.add("rows", "4");
+ query.add("rq", "{!ltr reRankDocs=4 model=6029760550880411648}");
+ query.add("fl", "*,score");
+ query.add("wt", "json");
+ final String expectedExplainNormalizer = "normalized using MinMaxNormalizer(min=0.0,max=10.0)";
+ final String expectedExplain = "\n3.5116758 = LinearModel(name=6029760550880411648,featureWeights=["
+ + "title=0.0,"
+ + "description=0.1,"
+ + "keywords=0.2,"
+ + "popularity=0.3,"
+ + "text=0.4,"
+ + "queryIntentPerson=0.1231231,"
+ + "queryIntentCompany=0.12121211"
+ + "]) model applied to features, sum of:\n 0.0 = prod of:\n 0.0 = weight on feature\n 1.0 = ValueFeature [name=title, params={value=1}]\n 0.2 = prod of:\n 0.1 = weight on feature\n 2.0 = ValueFeature [name=description, params={value=2}]\n 0.4 = prod of:\n 0.2 = weight on feature\n 2.0 = ValueFeature [name=keywords, params={value=2}]\n 0.09 = prod of:\n 0.3 = weight on feature\n 0.3 = "+expectedExplainNormalizer+"\n 3.0 = ValueFeature [name=popularity, params={value=3}]\n 1.6 = prod of:\n 0.4 = weight on feature\n 4.0 = ValueFeature [name=text, params={value=4}]\n 0.6156155 = prod of:\n 0.1231231 = weight on feature\n 5.0 = ValueFeature [name=queryIntentPerson, params={value=5}]\n 0.60606056 = prod of:\n 0.12121211 = weight on feature\n 5.0 = ValueFeature [name=queryIntentCompany, params={value=5}]\n";
+
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/7=='"+expectedExplain+"'}");
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/9=='"+expectedExplain+"'}");
+ }
+
+ @Test
+ public void LinearScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception {
+ loadFeatures("features-linear-efi.json");
+ loadModels("linear-model-efi.json");
+
+ SolrQuery query = new SolrQuery();
+ query.setQuery("title:bloomberg");
+ query.setParam("debugQuery", "on");
+ query.add("rows", "4");
+ query.add("rq", "{!ltr reRankDocs=4 model=linear-efi}");
+ query.add("fl", "*,score");
+ query.add("wt", "xml");
+
+ final String linearModelEfiString = "LinearModel(name=linear-efi,featureWeights=["
+ + "sampleConstant=1.0,"
+ + "search_number_of_nights=2.0])";
+
+ query.remove("wt");
+ query.add("wt", "json");
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/7=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" +
+ " 0.0 = prod of:\n" +
+ " 2.0 = weight on feature\n" +
+ " 0.0 = The feature has no value\n'}");
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/9=='\n5.0 = "+linearModelEfiString+" model applied to features, sum of:\n 5.0 = prod of:\n 1.0 = weight on feature\n 5.0 = ValueFeature [name=sampleConstant, params={value=5}]\n" +
+ " 0.0 = prod of:\n" +
+ " 2.0 = weight on feature\n" +
+ " 0.0 = The feature has no value\n'}");
+ }
+
+ @Test
+ public void multipleAdditiveTreesScoreExplainMissingEfiFeatureShouldReturnDefaultScore() throws Exception {
+ loadFeatures("external_features_for_sparse_processing.json");
+ loadModels("multipleadditivetreesmodel_external_binary_features.json");
+
+ SolrQuery query = new SolrQuery();
+ query.setQuery("title:bloomberg");
+ query.setParam("debugQuery", "on");
+ query.add("rows", "4");
+ query.add("rq", "{!ltr reRankDocs=4 model=external_model_binary_feature efi.user_device_tablet=1}");
+ query.add("fl", "*,score");
+
+ final String tree1 = "(weight=1.0,root=(feature=user_device_smartphone,threshold=0.5,left=0.0,right=50.0))";
+ final String tree2 = "(weight=1.0,root=(feature=user_device_tablet,threshold=0.5,left=0.0,right=65.0))";
+ final String trees = "["+tree1+","+tree2+"]";
+
+ query.add("wt", "json");
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/7=='\n" +
+ "65.0 = MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" +
+ " 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" +
+ " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
+ assertJQ(
+ "/query" + query.toQueryString(),
+ "/debug/explain/9=='\n" +
+ "65.0 = MultipleAdditiveTreesModel(name=external_model_binary_feature,trees="+trees+") model applied to features, sum of:\n" +
+ " 0.0 = tree 0 | \\'user_device_smartphone\\':0.0 <= 0.500001, Go Left | val: 0.0\n" +
+ " 65.0 = tree 1 | \\'user_device_tablet\\':1.0 > 0.500001, Go Right | val: 65.0\n'}");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java
new file mode 100644
index 0000000..f28ab0d
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRQParserPlugin.java
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestLTRQParserPlugin extends TestRerankBase {
+
+
+ @BeforeClass
+ public static void before() throws Exception {
+ setuptest("solrconfig-ltr.xml", "schema.xml");
+ // store = getModelStore();
+ bulkIndex();
+
+ loadFeatures("features-linear.json");
+ loadModels("linear-model.json");
+ }
+
+ @AfterClass
+ public static void after() throws Exception {
+ aftertest();
+ // store.clear();
+ }
+
+ @Test
+ public void ltrModelIdMissingTest() throws Exception {
+ final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
+ final SolrQuery query = new SolrQuery();
+ query.setQuery(solrQuery);
+ query.add("fl", "*, score");
+ query.add("rows", "4");
+ query.add("fv", "true");
+ query.add("rq", "{!ltr reRankDocs=100}");
+
+ final String res = restTestHarness.query("/query" + query.toQueryString());
+ assert (res.contains("Must provide model in the request"));
+ }
+
+ @Test
+ public void ltrModelIdDoesNotExistTest() throws Exception {
+ final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
+ final SolrQuery query = new SolrQuery();
+ query.setQuery(solrQuery);
+ query.add("fl", "*, score");
+ query.add("rows", "4");
+ query.add("fv", "true");
+ query.add("rq", "{!ltr model=-1 reRankDocs=100}");
+
+ final String res = restTestHarness.query("/query" + query.toQueryString());
+ assert (res.contains("cannot find model"));
+ }
+
+ @Test
+ public void ltrMoreResultsThanReRankedTest() throws Exception {
+ final String solrQuery = "_query_:{!edismax qf='title' mm=100% v='bloomberg' tie=0.1}";
+ final SolrQuery query = new SolrQuery();
+ query.setQuery(solrQuery);
+ query.add("fl", "*, score");
+ query.add("rows", "4");
+ query.add("fv", "true");
+
+ String nonRerankedScore = "0.09271725";
+
+ // Normal solr order
+ assertJQ("/query" + query.toQueryString(),
+ "/response/docs/[0]/id=='9'",
+ "/response/docs/[1]/id=='8'",
+ "/response/docs/[2]/id=='7'",
+ "/response/docs/[3]/id=='6'",
+ "/response/docs/[3]/score=="+nonRerankedScore
+ );
+
+ query.add("rq", "{!ltr model=6029760550880411648 reRankDocs=3}");
+
+ // Different order for top 3 reranked, but last one is the same top nonreranked doc
+ assertJQ("/query" + query.toQueryString(),
+ "/response/docs/[0]/id=='7'",
+ "/response/docs/[1]/id=='8'",
+ "/response/docs/[2]/id=='9'",
+ "/response/docs/[3]/id=='6'",
+ "/response/docs/[3]/score=="+nonRerankedScore
+ );
+ }
+
+ @Test
+ public void ltrNoResultsTest() throws Exception {
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("title:bloomberg23");
+ query.add("fl", "*,[fv]");
+ query.add("rows", "3");
+ query.add("debugQuery", "on");
+ query.add("rq", "{!ltr reRankDocs=3 model=6029760550880411648}");
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==0");
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
new file mode 100644
index 0000000..a98fc4f
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRReRankingPipeline.java
@@ -0,0 +1,300 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FloatDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.Explanation;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.ScoreDoc;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.FieldValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.TestLinearModel;
+import org.apache.solr.ltr.norm.IdentityNormalizer;
+import org.apache.solr.ltr.norm.Normalizer;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestLTRReRankingPipeline extends LuceneTestCase {
+
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ private static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
+
+ private IndexSearcher getSearcher(IndexReader r) {
+ final IndexSearcher searcher = newSearcher(r);
+
+ return searcher;
+ }
+
+ private static List<Feature> makeFieldValueFeatures(int[] featureIds,
+ String field) {
+ final List<Feature> features = new ArrayList<>();
+ for (final int i : featureIds) {
+ final Map<String,Object> params = new HashMap<String,Object>();
+ params.put("field", field);
+ final Feature f = Feature.getInstance(solrResourceLoader,
+ FieldValueFeature.class.getCanonicalName(),
+ "f" + i, params);
+ f.setIndex(i);
+ features.add(f);
+ }
+ return features;
+ }
+
+ private class MockModel extends LTRScoringModel {
+
+ public MockModel(String name, List<Feature> features,
+ List<Normalizer> norms,
+ String featureStoreName, List<Feature> allFeatures,
+ Map<String,Object> params) {
+ super(name, features, norms, featureStoreName, allFeatures, params);
+ }
+
+ @Override
+ public float score(float[] modelFeatureValuesNormalized) {
+ return modelFeatureValuesNormalized[2];
+ }
+
+ @Override
+ public Explanation explain(LeafReaderContext context, int doc,
+ float finalScore, List<Explanation> featureExplanations) {
+ return null;
+ }
+
+ }
+
+ @Ignore
+ @Test
+ public void testRescorer() throws IOException {
+ final Directory dir = newDirectory();
+ final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+ Document doc = new Document();
+ doc.add(newStringField("id", "0", Field.Store.YES));
+ doc.add(newTextField("field", "wizard the the the the the oz",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 1.0f));
+
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(newStringField("id", "1", Field.Store.YES));
+ // 1 extra token, but wizard and oz are close;
+ doc.add(newTextField("field", "wizard oz the the the the the the",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 2.0f));
+ w.addDocument(doc);
+
+ final IndexReader r = w.getReader();
+ w.close();
+
+ // Do ordinary BooleanQuery:
+ final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+ bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+ bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+ final IndexSearcher searcher = getSearcher(r);
+ // first run the standard query
+ TopDocs hits = searcher.search(bqBuilder.build(), 10);
+ assertEquals(2, hits.totalHits);
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+ final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+ "final-score");
+ final List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+ 2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
+ final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, null);
+
+ final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
+ hits = rescorer.rescore(searcher, hits, 2);
+
+ // rerank using the field final-score
+ assertEquals("1", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("0", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+ r.close();
+ dir.close();
+
+ }
+
+ @Ignore
+ @Test
+ public void testDifferentTopN() throws IOException {
+ final Directory dir = newDirectory();
+ final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+ Document doc = new Document();
+ doc.add(newStringField("id", "0", Field.Store.YES));
+ doc.add(newTextField("field", "wizard oz oz oz oz oz", Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 1.0f));
+ w.addDocument(doc);
+
+ doc = new Document();
+ doc.add(newStringField("id", "1", Field.Store.YES));
+ doc.add(newTextField("field", "wizard oz oz oz oz the", Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 2.0f));
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(newStringField("id", "2", Field.Store.YES));
+ doc.add(newTextField("field", "wizard oz oz oz the the ", Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 3.0f));
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(newStringField("id", "3", Field.Store.YES));
+ doc.add(newTextField("field", "wizard oz oz the the the the ",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 4.0f));
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(newStringField("id", "4", Field.Store.YES));
+ doc.add(newTextField("field", "wizard oz the the the the the the",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 5.0f));
+ w.addDocument(doc);
+
+ final IndexReader r = w.getReader();
+ w.close();
+
+ // Do ordinary BooleanQuery:
+ final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+ bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+ bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+ final IndexSearcher searcher = getSearcher(r);
+
+ // first run the standard query
+ TopDocs hits = searcher.search(bqBuilder.build(), 10);
+ assertEquals(5, hits.totalHits);
+
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+ assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+ assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+ assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+ final List<Feature> features = makeFieldValueFeatures(new int[] {0, 1, 2},
+ "final-score");
+ final List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ final List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0, 1,
+ 2, 3, 4, 5, 6, 7, 8, 9}, "final-score");
+ final LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, null);
+
+ final LTRRescorer rescorer = new LTRRescorer(new LTRScoringQuery(ltrScoringModel));
+
+ // rerank @ 0 should not change the order
+ hits = rescorer.rescore(searcher, hits, 0);
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+ assertEquals("2", searcher.doc(hits.scoreDocs[2].doc).get("id"));
+ assertEquals("3", searcher.doc(hits.scoreDocs[3].doc).get("id"));
+ assertEquals("4", searcher.doc(hits.scoreDocs[4].doc).get("id"));
+
+ // test rerank with different topN cuts
+
+ for (int topN = 1; topN <= 5; topN++) {
+ log.info("rerank {} documents ", topN);
+ hits = searcher.search(bqBuilder.build(), 10);
+
+ final ScoreDoc[] slice = new ScoreDoc[topN];
+ System.arraycopy(hits.scoreDocs, 0, slice, 0, topN);
+ hits = new TopDocs(hits.totalHits, slice, hits.getMaxScore());
+ hits = rescorer.rescore(searcher, hits, topN);
+ for (int i = topN - 1, j = 0; i >= 0; i--, j++) {
+ log.info("doc {} in pos {}", searcher.doc(hits.scoreDocs[j].doc)
+ .get("id"), j);
+
+ assertEquals(i,
+ Integer.parseInt(searcher.doc(hits.scoreDocs[j].doc).get("id")));
+ assertEquals(i + 1, hits.scoreDocs[j].score, 0.00001);
+
+ }
+ }
+
+ r.close();
+ dir.close();
+
+ }
+
+ @Test
+ public void testDocParam() throws Exception {
+ final Map<String,Object> test = new HashMap<String,Object>();
+ test.put("fake", 2);
+ List<Feature> features = makeFieldValueFeatures(new int[] {0},
+ "final-score");
+ List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ List<Feature> allFeatures = makeFieldValueFeatures(new int[] {0},
+ "final-score");
+ MockModel ltrScoringModel = new MockModel("test",
+ features, norms, "test", allFeatures, null);
+ LTRScoringQuery query = new LTRScoringQuery(ltrScoringModel);
+ LTRScoringQuery.ModelWeight wgt = query.createWeight(null, true, 1f);
+ LTRScoringQuery.ModelWeight.ModelScorer modelScr = wgt.scorer(null);
+ modelScr.getDocInfo().setOriginalDocScore(new Float(1f));
+ for (final Scorer.ChildScorer feat : modelScr.getChildren()) {
+ assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+ }
+
+ features = makeFieldValueFeatures(new int[] {0, 1, 2}, "final-score");
+ norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ allFeatures = makeFieldValueFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8,
+ 9}, "final-score");
+ ltrScoringModel = new MockModel("test", features, norms,
+ "test", allFeatures, null);
+ query = new LTRScoringQuery(ltrScoringModel);
+ wgt = query.createWeight(null, true, 1f);
+ modelScr = wgt.scorer(null);
+ modelScr.getDocInfo().setOriginalDocScore(new Float(1f));
+ for (final Scorer.ChildScorer feat : modelScr.getChildren()) {
+ assertNotNull(((Feature.FeatureWeight.FeatureScorer) feat.child).getDocInfo().getOriginalDocScore());
+ }
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
new file mode 100644
index 0000000..0576c99
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRScoringQuery.java
@@ -0,0 +1,319 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+import org.apache.lucene.document.Document;
+import org.apache.lucene.document.Field;
+import org.apache.lucene.document.FloatDocValuesField;
+import org.apache.lucene.index.IndexReader;
+import org.apache.lucene.index.LeafReaderContext;
+import org.apache.lucene.index.RandomIndexWriter;
+import org.apache.lucene.index.ReaderUtil;
+import org.apache.lucene.index.Term;
+import org.apache.lucene.search.BooleanClause;
+import org.apache.lucene.search.BooleanQuery;
+import org.apache.lucene.search.IndexSearcher;
+import org.apache.lucene.search.Scorer;
+import org.apache.lucene.search.TermQuery;
+import org.apache.lucene.search.TopDocs;
+import org.apache.lucene.search.Weight;
+import org.apache.lucene.store.Directory;
+import org.apache.lucene.util.LuceneTestCase;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.ModelException;
+import org.apache.solr.ltr.model.TestLinearModel;
+import org.apache.solr.ltr.norm.IdentityNormalizer;
+import org.apache.solr.ltr.norm.Normalizer;
+import org.apache.solr.ltr.norm.NormalizerException;
+import org.junit.Test;
+
+public class TestLTRScoringQuery extends LuceneTestCase {
+
+ public final static SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
+
+ private IndexSearcher getSearcher(IndexReader r) {
+ final IndexSearcher searcher = newSearcher(r, false, false);
+ return searcher;
+ }
+
+ private static List<Feature> makeFeatures(int[] featureIds) {
+ final List<Feature> features = new ArrayList<>();
+ for (final int i : featureIds) {
+ Map<String,Object> params = new HashMap<String,Object>();
+ params.put("value", i);
+ final Feature f = Feature.getInstance(solrResourceLoader,
+ ValueFeature.class.getCanonicalName(),
+ "f" + i, params);
+ f.setIndex(i);
+ features.add(f);
+ }
+ return features;
+ }
+
+ private static List<Feature> makeFilterFeatures(int[] featureIds) {
+ final List<Feature> features = new ArrayList<>();
+ for (final int i : featureIds) {
+ Map<String,Object> params = new HashMap<String,Object>();
+ params.put("value", i);
+ final Feature f = Feature.getInstance(solrResourceLoader,
+ ValueFeature.class.getCanonicalName(),
+ "f" + i, params);
+ f.setIndex(i);
+ features.add(f);
+ }
+ return features;
+ }
+
+ private static Map<String,Object> makeFeatureWeights(List<Feature> features) {
+ final Map<String,Object> nameParams = new HashMap<String,Object>();
+ final HashMap<String,Double> modelWeights = new HashMap<String,Double>();
+ for (final Feature feat : features) {
+ modelWeights.put(feat.getName(), 0.1);
+ }
+ nameParams.put("weights", modelWeights);
+ return nameParams;
+ }
+
+ private LTRScoringQuery.ModelWeight performQuery(TopDocs hits,
+ IndexSearcher searcher, int docid, LTRScoringQuery model) throws IOException,
+ ModelException {
+ final List<LeafReaderContext> leafContexts = searcher.getTopReaderContext()
+ .leaves();
+ final int n = ReaderUtil.subIndex(hits.scoreDocs[0].doc, leafContexts);
+ final LeafReaderContext context = leafContexts.get(n);
+ final int deBasedDoc = hits.scoreDocs[0].doc - context.docBase;
+
+ final Weight weight = searcher.createNormalizedWeight(model, true);
+ final Scorer scorer = weight.scorer(context);
+
+ // rerank using the field final-score
+ scorer.iterator().advance(deBasedDoc);
+ scorer.score();
+
+ // assertEquals(42.0f, score, 0.0001);
+ // assertTrue(weight instanceof AssertingWeight);
+ // (AssertingIndexSearcher)
+ assertTrue(weight instanceof LTRScoringQuery.ModelWeight);
+ final LTRScoringQuery.ModelWeight modelWeight = (LTRScoringQuery.ModelWeight) weight;
+ return modelWeight;
+
+ }
+
+ @Test
+ public void testLTRScoringQueryEquality() throws ModelException {
+ final List<Feature> features = makeFeatures(new int[] {0, 1, 2});
+ final List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ final List<Feature> allFeatures = makeFeatures(
+ new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9});
+ final Map<String,Object> modelParams = makeFeatureWeights(features);
+
+ final LTRScoringModel algorithm1 = TestLinearModel.createLinearModel(
+ "testModelName",
+ features, norms, "testStoreName", allFeatures, modelParams);
+
+ final LTRScoringQuery m0 = new LTRScoringQuery(algorithm1);
+
+ final HashMap<String,String[]> externalFeatureInfo = new HashMap<>();
+ externalFeatureInfo.put("queryIntent", new String[] {"company"});
+ externalFeatureInfo.put("user_query", new String[] {"abc"});
+ final LTRScoringQuery m1 = new LTRScoringQuery(algorithm1, externalFeatureInfo, false, null);
+
+ final HashMap<String,String[]> externalFeatureInfo2 = new HashMap<>();
+ externalFeatureInfo2.put("user_query", new String[] {"abc"});
+ externalFeatureInfo2.put("queryIntent", new String[] {"company"});
+ int totalPoolThreads = 10, numThreadsPerRequest = 10;
+ LTRThreadModule threadManager = new LTRThreadModule(totalPoolThreads, numThreadsPerRequest);
+ final LTRScoringQuery m2 = new LTRScoringQuery(algorithm1, externalFeatureInfo2, false, threadManager);
+
+
+ // Models with same algorithm and efis, just in different order should be the same
+ assertEquals(m1, m2);
+ assertEquals(m1.hashCode(), m2.hashCode());
+
+ // Models with same algorithm, but different efi content should not match
+ assertFalse(m1.equals(m0));
+ assertFalse(m1.hashCode() == m0.hashCode());
+
+
+ final LTRScoringModel algorithm2 = TestLinearModel.createLinearModel(
+ "testModelName2",
+ features, norms, "testStoreName", allFeatures, modelParams);
+ final LTRScoringQuery m3 = new LTRScoringQuery(algorithm2);
+
+ assertFalse(m1.equals(m3));
+ assertFalse(m1.hashCode() == m3.hashCode());
+
+ final LTRScoringModel algorithm3 = TestLinearModel.createLinearModel(
+ "testModelName",
+ features, norms, "testStoreName3", allFeatures, modelParams);
+ final LTRScoringQuery m4 = new LTRScoringQuery(algorithm3);
+
+ assertFalse(m1.equals(m4));
+ assertFalse(m1.hashCode() == m4.hashCode());
+ }
+
+
+ @Test
+ public void testLTRScoringQuery() throws IOException, ModelException {
+ final Directory dir = newDirectory();
+ final RandomIndexWriter w = new RandomIndexWriter(random(), dir);
+
+ Document doc = new Document();
+ doc.add(newStringField("id", "0", Field.Store.YES));
+ doc.add(newTextField("field", "wizard the the the the the oz",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 1.0f));
+
+ w.addDocument(doc);
+ doc = new Document();
+ doc.add(newStringField("id", "1", Field.Store.YES));
+ // 1 extra token, but wizard and oz are close;
+ doc.add(newTextField("field", "wizard oz the the the the the the",
+ Field.Store.NO));
+ doc.add(new FloatDocValuesField("final-score", 2.0f));
+ w.addDocument(doc);
+
+ final IndexReader r = w.getReader();
+ w.close();
+
+ // Do ordinary BooleanQuery:
+ final BooleanQuery.Builder bqBuilder = new BooleanQuery.Builder();
+ bqBuilder.add(new TermQuery(new Term("field", "wizard")), BooleanClause.Occur.SHOULD);
+ bqBuilder.add(new TermQuery(new Term("field", "oz")), BooleanClause.Occur.SHOULD);
+ final IndexSearcher searcher = getSearcher(r);
+ // first run the standard query
+ final TopDocs hits = searcher.search(bqBuilder.build(), 10);
+ assertEquals(2, hits.totalHits);
+ assertEquals("0", searcher.doc(hits.scoreDocs[0].doc).get("id"));
+ assertEquals("1", searcher.doc(hits.scoreDocs[1].doc).get("id"));
+
+ List<Feature> features = makeFeatures(new int[] {0, 1, 2});
+ final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5,
+ 6, 7, 8, 9});
+ List<Normalizer> norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ LTRScoringModel ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures,
+ makeFeatureWeights(features));
+
+ LTRScoringQuery.ModelWeight modelWeight = performQuery(hits, searcher,
+ hits.scoreDocs[0].doc, new LTRScoringQuery(ltrScoringModel));
+ assertEquals(3, modelWeight.getModelFeatureValuesNormalized().length);
+
+ for (int i = 0; i < 3; i++) {
+ assertEquals(i, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+ }
+ int[] posVals = new int[] {0, 1, 2};
+ int pos = 0;
+ for (LTRScoringQuery.FeatureInfo fInfo:modelWeight.getFeaturesInfo()) {
+ if (fInfo == null){
+ continue;
+ }
+ assertEquals(posVals[pos], fInfo.getValue(), 0.0001);
+ assertEquals("f"+posVals[pos], fInfo.getName());
+ pos++;
+ }
+
+ final int[] mixPositions = new int[] {8, 2, 4, 9, 0};
+ features = makeFeatures(mixPositions);
+ norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, makeFeatureWeights(features));
+
+ modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
+ new LTRScoringQuery(ltrScoringModel));
+ assertEquals(mixPositions.length,
+ modelWeight.getModelFeatureWeights().length);
+
+ for (int i = 0; i < mixPositions.length; i++) {
+ assertEquals(mixPositions[i],
+ modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+ }
+
+ final ModelException expectedModelException = new ModelException("no features declared for model test");
+ final int[] noPositions = new int[] {};
+ features = makeFeatures(noPositions);
+ norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),IdentityNormalizer.INSTANCE));
+ try {
+ ltrScoringModel = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures, makeFeatureWeights(features));
+ fail("unexpectedly got here instead of catching "+expectedModelException);
+ modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
+ new LTRScoringQuery(ltrScoringModel));
+ assertEquals(0, modelWeight.getModelFeatureWeights().length);
+ } catch (ModelException actualModelException) {
+ assertEquals(expectedModelException.toString(), actualModelException.toString());
+ }
+
+ // test normalizers
+ features = makeFilterFeatures(mixPositions);
+ final Normalizer norm = new Normalizer() {
+
+ @Override
+ public float normalize(float value) {
+ return 42.42f;
+ }
+
+ @Override
+ public LinkedHashMap<String,Object> paramsToMap() {
+ return null;
+ }
+
+ @Override
+ protected void validate() throws NormalizerException {
+ }
+
+ };
+ norms =
+ new ArrayList<Normalizer>(
+ Collections.nCopies(features.size(),norm));
+ final LTRScoringModel normMeta = TestLinearModel.createLinearModel("test",
+ features, norms, "test", allFeatures,
+ makeFeatureWeights(features));
+
+ modelWeight = performQuery(hits, searcher, hits.scoreDocs[0].doc,
+ new LTRScoringQuery(normMeta));
+ normMeta.normalizeFeaturesInPlace(modelWeight.getModelFeatureValuesNormalized());
+ assertEquals(mixPositions.length,
+ modelWeight.getModelFeatureWeights().length);
+ for (int i = 0; i < mixPositions.length; i++) {
+ assertEquals(42.42f, modelWeight.getModelFeatureValuesNormalized()[i], 0.0001);
+ }
+ r.close();
+ dir.close();
+
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java
new file mode 100644
index 0000000..ab519ec
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithFacet.java
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.feature.SolrFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestLTRWithFacet extends TestRerankBase {
+
+ @BeforeClass
+ public static void before() throws Exception {
+ setuptest("solrconfig-ltr.xml", "schema.xml");
+
+ assertU(adoc("id", "1", "title", "a1", "description", "E", "popularity",
+ "1"));
+ assertU(adoc("id", "2", "title", "a1 b1", "description",
+ "B", "popularity", "2"));
+ assertU(adoc("id", "3", "title", "a1 b1 c1", "description", "B", "popularity",
+ "3"));
+ assertU(adoc("id", "4", "title", "a1 b1 c1 d1", "description", "B", "popularity",
+ "4"));
+ assertU(adoc("id", "5", "title", "a1 b1 c1 d1 e1", "description", "E", "popularity",
+ "5"));
+ assertU(adoc("id", "6", "title", "a1 b1 c1 d1 e1 f1", "description", "B",
+ "popularity", "6"));
+ assertU(adoc("id", "7", "title", "a1 b1 c1 d1 e1 f1 g1", "description",
+ "C", "popularity", "7"));
+ assertU(adoc("id", "8", "title", "a1 b1 c1 d1 e1 f1 g1 h1", "description",
+ "D", "popularity", "8"));
+ assertU(commit());
+ }
+
+ @Test
+ public void testRankingSolrFacet() throws Exception {
+ // before();
+ loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
+ "{\"q\":\"{!func}pow(popularity,2)\"}");
+
+ loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
+ new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("title:a1");
+ query.add("fl", "*, score");
+ query.add("rows", "4");
+ query.add("facet", "true");
+ query.add("facet.field", "description");
+
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'");
+ // Normal term match
+ assertJQ("/query" + query.toQueryString(), ""
+ + "/facet_counts/facet_fields/description=="
+ + "['b', 4, 'e', 2, 'c', 1, 'd', 1]");
+
+ query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
+ query.set("debugQuery", "on");
+
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='4'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==16.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==9.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='2'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==4.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0");
+
+ assertJQ("/query" + query.toQueryString(), ""
+ + "/facet_counts/facet_fields/description=="
+ + "['b', 4, 'e', 2, 'c', 1, 'd', 1]");
+ // aftertest();
+
+ }
+
+ @AfterClass
+ public static void after() throws Exception {
+ aftertest();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java
new file mode 100644
index 0000000..1fbe1d5
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestLTRWithSort.java
@@ -0,0 +1,102 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.apache.solr.ltr.feature.SolrFeature;
+import org.apache.solr.ltr.model.LinearModel;
+import org.junit.AfterClass;
+import org.junit.BeforeClass;
+import org.junit.Test;
+
+public class TestLTRWithSort extends TestRerankBase {
+
+ @BeforeClass
+ public static void before() throws Exception {
+ setuptest("solrconfig-ltr.xml", "schema.xml");
+ assertU(adoc("id", "1", "title", "a1", "description", "E", "popularity",
+ "1"));
+ assertU(adoc("id", "2", "title", "a1 b1", "description",
+ "B", "popularity", "2"));
+ assertU(adoc("id", "3", "title", "a1 b1 c1", "description", "B", "popularity",
+ "3"));
+ assertU(adoc("id", "4", "title", "a1 b1 c1 d1", "description", "B", "popularity",
+ "4"));
+ assertU(adoc("id", "5", "title", "a1 b1 c1 d1 e1", "description", "E", "popularity",
+ "5"));
+ assertU(adoc("id", "6", "title", "a1 b1 c1 d1 e1 f1", "description", "B",
+ "popularity", "6"));
+ assertU(adoc("id", "7", "title", "a1 b1 c1 d1 e1 f1 g1", "description",
+ "C", "popularity", "7"));
+ assertU(adoc("id", "8", "title", "a1 b1 c1 d1 e1 f1 g1 h1", "description",
+ "D", "popularity", "8"));
+ assertU(commit());
+ }
+
+ @Test
+ public void testRankingSolrSort() throws Exception {
+ // before();
+ loadFeature("powpularityS", SolrFeature.class.getCanonicalName(),
+ "{\"q\":\"{!func}pow(popularity,2)\"}");
+
+ loadModel("powpularityS-model", LinearModel.class.getCanonicalName(),
+ new String[] {"powpularityS"}, "{\"weights\":{\"powpularityS\":1.0}}");
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("title:a1");
+ query.add("fl", "*, score");
+ query.add("rows", "4");
+
+ // Normal term match
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='2'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='3'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='4'");
+
+ //Add sort
+ query.add("sort", "description desc");
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='5'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='8'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='7'");
+
+ query.add("rq", "{!ltr model=powpularityS-model reRankDocs=4}");
+ query.set("debugQuery", "on");
+
+ assertJQ("/query" + query.toQueryString(), "/response/numFound/==8");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='8'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/score==64.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='7'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/score==49.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='5'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/score==25.0");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[3]/score==1.0");
+
+ // aftertest();
+
+ }
+
+ @AfterClass
+ public static void after() throws Exception {
+ aftertest();
+ }
+
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java
new file mode 100644
index 0000000..f4c21fd
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestParallelWeightCreation.java
@@ -0,0 +1,77 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import org.apache.solr.client.solrj.SolrQuery;
+import org.junit.Test;
+
+public class TestParallelWeightCreation extends TestRerankBase{
+
+ @Test
+ public void testLTRScoringQueryParallelWeightCreationResultOrder() throws Exception {
+ setuptest("solrconfig-ltr_Th10_10.xml", "schema.xml");
+
+ assertU(adoc("id", "1", "title", "w1 w3", "description", "w1", "popularity",
+ "1"));
+ assertU(adoc("id", "2", "title", "w2", "description", "w2", "popularity",
+ "2"));
+ assertU(adoc("id", "3", "title", "w3", "description", "w3", "popularity",
+ "3"));
+ assertU(adoc("id", "4", "title", "w4 w3", "description", "w4", "popularity",
+ "4"));
+ assertU(adoc("id", "5", "title", "w5", "description", "w5", "popularity",
+ "5"));
+ assertU(commit());
+
+ loadFeatures("external_features.json");
+ loadModels("external_model.json");
+ loadModels("external_model_store.json");
+
+ // check to make sure that the order of results will be the same when using parallel weight creation
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*,score");
+ query.add("rows", "4");
+
+ query.add("rq", "{!ltr reRankDocs=4 model=externalmodel efi.user_query=w3}");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='1'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[1]/id=='3'");
+ assertJQ("/query" + query.toQueryString(), "/response/docs/[2]/id=='4'");
+ aftertest();
+ }
+
+ @Test
+ public void testLTRQParserThreadInitialization() throws Exception {
+ // setting the value of number of threads to -ve should throw an exception
+ String msg1 = null;
+ try{
+ new LTRThreadModule(1,-1);
+ }catch(IllegalArgumentException iae){
+ msg1 = iae.getMessage();;
+ }
+ assertTrue(msg1.equals("numThreadsPerRequest cannot be less than 1"));
+
+ // set totalPoolThreads to 1 and numThreadsPerRequest to 2 and verify that an exception is thrown
+ String msg2 = null;
+ try{
+ new LTRThreadModule(1,2);
+ }catch(IllegalArgumentException iae){
+ msg2 = iae.getMessage();
+ }
+ assertTrue(msg2.equals("numThreadsPerRequest cannot be greater than totalPoolThreads"));
+ }
+}
http://git-wip-us.apache.org/repos/asf/lucene-solr/blob/5a66b3bc/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
----------------------------------------------------------------------
diff --git a/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
new file mode 100644
index 0000000..4914d28
--- /dev/null
+++ b/solr/contrib/ltr/src/test/org/apache/solr/ltr/TestRerankBase.java
@@ -0,0 +1,429 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.solr.ltr;
+
+import java.io.File;
+import java.io.FileNotFoundException;
+import java.io.IOException;
+import java.lang.invoke.MethodHandles;
+import java.net.URL;
+import java.nio.file.Files;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.Scanner;
+import java.util.SortedMap;
+import java.util.TreeMap;
+
+import org.apache.commons.io.FileUtils;
+import org.apache.commons.lang.StringUtils;
+import org.apache.solr.common.params.CommonParams;
+import org.apache.solr.common.util.ContentStream;
+import org.apache.solr.common.util.ContentStreamBase;
+import org.apache.solr.core.SolrResourceLoader;
+import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.FeatureException;
+import org.apache.solr.ltr.feature.ValueFeature;
+import org.apache.solr.ltr.model.LTRScoringModel;
+import org.apache.solr.ltr.model.LinearModel;
+import org.apache.solr.ltr.model.ModelException;
+import org.apache.solr.ltr.store.rest.ManagedFeatureStore;
+import org.apache.solr.ltr.store.rest.ManagedModelStore;
+import org.apache.solr.request.SolrQueryRequestBase;
+import org.apache.solr.response.SolrQueryResponse;
+import org.apache.solr.rest.ManagedResourceStorage;
+import org.apache.solr.rest.SolrSchemaRestApi;
+import org.apache.solr.util.RestTestBase;
+import org.eclipse.jetty.servlet.ServletHolder;
+import org.noggit.ObjectBuilder;
+import org.restlet.ext.servlet.ServerServlet;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+public class TestRerankBase extends RestTestBase {
+
+ private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
+
+ protected static final SolrResourceLoader solrResourceLoader = new SolrResourceLoader();
+
+ protected static File tmpSolrHome;
+ protected static File tmpConfDir;
+
+ public static final String FEATURE_FILE_NAME = "_schema_feature-store.json";
+ public static final String MODEL_FILE_NAME = "_schema_model-store.json";
+ public static final String PARENT_ENDPOINT = "/schema/*";
+
+ protected static final String COLLECTION = "collection1";
+ protected static final String CONF_DIR = COLLECTION + "/conf";
+
+ protected static File fstorefile = null;
+ protected static File mstorefile = null;
+
  /** Starts a core with the default LTR config/schema and indexes the sample docs. */
  public static void setuptest() throws Exception {
    setuptest("solrconfig-ltr.xml", "schema.xml");
    bulkIndex();
  }
+
  /**
   * Same as {@link #setuptest()} but remembers the feature/model store files so
   * persistence tests can inspect them after the core restarts.
   */
  public static void setupPersistenttest() throws Exception {
    setupPersistentTest("solrconfig-ltr.xml", "schema.xml");
    bulkIndex();
  }
+
  /** Returns the managed feature store registered with the current test core. */
  public static ManagedFeatureStore getManagedFeatureStore() {
    return ManagedFeatureStore.getManagedFeatureStore(h.getCore());
  }
+
  /** Returns the managed model store registered with the current test core. */
  public static ManagedModelStore getManagedModelStore() {
    return ManagedModelStore.getManagedModelStore(h.getCore());
  }
+
  /**
   * Prepares a fresh temporary solr home for a test run: copies the test home,
   * deletes any stale feature/model store files, installs the requested
   * solrconfig/schema as the active ones, and builds the extra servlet map
   * exposing the schema REST API used by the managed feature/model stores.
   *
   * @param solrconfig name of the config file under collection1/conf to activate
   * @param schema name of the schema file under collection1/conf to activate
   * @param isPersistent when true, the store file handles are kept in
   *        {@code fstorefile}/{@code mstorefile} for later inspection
   * @return servlets to register on the test Jetty instance
   */
  protected static SortedMap<ServletHolder,String> setupTestInit(
      String solrconfig, String schema,
      boolean isPersistent) throws Exception {
    tmpSolrHome = createTempDir().toFile();
    tmpConfDir = new File(tmpSolrHome, CONF_DIR);
    tmpConfDir.deleteOnExit();
    FileUtils.copyDirectory(new File(TEST_HOME()),
        tmpSolrHome.getAbsoluteFile());

    final File fstore = new File(tmpConfDir, FEATURE_FILE_NAME);
    final File mstore = new File(tmpConfDir, MODEL_FILE_NAME);

    if (isPersistent) {
      fstorefile = fstore;
      mstorefile = mstore;
    }

    // remove any store files left over from a previous run so each test
    // starts from empty feature/model stores
    if (fstore.exists()) {
      log.info("remove feature store config file in {}",
          fstore.getAbsolutePath());
      Files.delete(fstore.toPath());
    }
    if (mstore.exists()) {
      log.info("remove model store config file in {}",
          mstore.getAbsolutePath());
      Files.delete(mstore.toPath());
    }
    // make the requested config/schema the active solrconfig.xml / schema.xml
    if (!solrconfig.equals("solrconfig.xml")) {
      FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
          + "/collection1/conf/" + solrconfig),
          new File(tmpSolrHome.getAbsolutePath()
              + "/collection1/conf/solrconfig.xml"));
    }
    if (!schema.equals("schema.xml")) {
      FileUtils.copyFile(new File(tmpSolrHome.getAbsolutePath()
          + "/collection1/conf/" + schema),
          new File(tmpSolrHome.getAbsolutePath()
              + "/collection1/conf/schema.xml"));
    }

    // expose the schema REST API, backed by in-memory storage; the managed
    // feature and model stores are reached through this endpoint
    final SortedMap<ServletHolder,String> extraServlets = new TreeMap<>();
    final ServletHolder solrRestApi = new ServletHolder("SolrSchemaRestApi",
        ServerServlet.class);
    solrRestApi.setInitParameter("org.restlet.application",
        SolrSchemaRestApi.class.getCanonicalName());
    solrRestApi.setInitParameter("storageIO",
        ManagedResourceStorage.InMemoryStorageIO.class.getCanonicalName());
    extraServlets.put(solrRestApi, PARENT_ENDPOINT);

    System.setProperty("managed.schema.mutable", "true");

    return extraServlets;
  }
+
  /**
   * Starts a test core with the given config/schema plus the schema REST API
   * servlet; the update log is disabled for these tests.
   */
  public static void setuptest(String solrconfig, String schema)
      throws Exception {
    initCore(solrconfig, schema);

    SortedMap<ServletHolder,String> extraServlets =
        setupTestInit(solrconfig,schema,false);
    System.setProperty("enable.update.log", "false");

    createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
        "/solr", true, extraServlets);
  }
+
  /**
   * Like {@link #setuptest(String, String)} but keeps handles to the
   * feature/model store files (isPersistent=true) for persistence tests.
   */
  public static void setupPersistentTest(String solrconfig, String schema)
      throws Exception {
    initCore(solrconfig, schema);

    SortedMap<ServletHolder,String> extraServlets =
        setupTestInit(solrconfig,schema,true);

    createJettyAndHarness(tmpSolrHome.getAbsolutePath(), solrconfig, schema,
        "/solr", true, extraServlets);
  }
+
  /** Tears down the REST harness and Jetty, then removes the temp solr home. */
  protected static void aftertest() throws Exception {
    restTestHarness.close();
    restTestHarness = null;
    jetty.stop();
    jetty = null;
    FileUtils.deleteDirectory(tmpSolrHome);
    System.clearProperty("managed.schema.mutable");
    // System.clearProperty("enable.update.log");
  }
+
+ /**
+  * Forgets the current REST harness reference without closing it, so a later
+  * aftertest() does not touch it. NOTE(review): the caller appears to be
+  * responsible for closing the harness first — confirm at call sites.
+  */
+ public static void makeRestTestHarnessNull() {
+ restTestHarness = null;
+ }
+
+ /**
+  * Produces a model encoded in json.
+  *
+  * @param name model name
+  * @param type fully qualified class name of the model
+  * @param features names of the features the model uses (may be empty)
+  * @param fstore name of the feature store the features live in
+  * @param params model params as a json object string, or null for none
+  * @return the model as a json string
+  */
+ public static String getModelInJson(String name, String type,
+     String[] features, String fstore, String params) {
+   final StringBuilder sb = new StringBuilder();
+   sb.append("{\n");
+   sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
+   sb.append("\"store\":").append('"').append(fstore).append('"')
+       .append(",\n");
+   sb.append("\"class\":").append('"').append(type).append('"').append(",\n");
+   sb.append("\"features\":").append('[');
+   for (final String feature : features) {
+     sb.append("\n\t{ ");
+     sb.append("\"name\":").append('"').append(feature).append('"')
+         .append("},");
+   }
+   // Strip the trailing comma only when at least one feature was appended;
+   // the unconditional deleteCharAt of the original removed the '[' itself
+   // for an empty feature array, yielding malformed json.
+   if (features.length > 0) {
+     sb.deleteCharAt(sb.length() - 1);
+   }
+   sb.append("\n]\n");
+   if (params != null) {
+     sb.append(",\n");
+     sb.append("\"params\":").append(params);
+   }
+   sb.append("\n}\n");
+   return sb.toString();
+ }
+
+ /** produces a model encoded in json **/
+ public static String getFeatureInJson(String name, String type,
+ String fstore, String params) {
+ final StringBuilder sb = new StringBuilder();
+ sb.append("{\n");
+ sb.append("\"name\":").append('"').append(name).append('"').append(",\n");
+ sb.append("\"store\":").append('"').append(fstore).append('"')
+ .append(",\n");
+ sb.append("\"class\":").append('"').append(type).append('"');
+ if (params != null) {
+ sb.append(",\n");
+ sb.append("\"params\":").append(params);
+ }
+ sb.append("\n}\n");
+ return sb.toString();
+ }
+
+ /**
+  * Uploads a feature to the default "test" feature store via the managed
+  * REST endpoint and asserts the update succeeded.
+  */
+ protected static void loadFeature(String name, String type, String params)
+     throws Exception {
+   // Delegate to the fstore-aware overload instead of duplicating the PUT
+   // logic — consistent with how loadModel's overloads are structured.
+   loadFeature(name, type, "test", params);
+ }
+
+ /**
+  * Uploads a feature to the given feature store via the managed REST
+  * endpoint and asserts a zero response status.
+  *
+  * @param name   feature name
+  * @param type   fully qualified feature class name
+  * @param fstore feature store to add the feature to
+  * @param params feature params as a json object string, or null for none
+  */
+ protected static void loadFeature(String name, String type, String fstore,
+     String params) throws Exception {
+   final String feature = getFeatureInJson(name, type, fstore, params);
+   log.info("loading feature \n{} ", feature); // was misspelled "feauture"
+   assertJPut(ManagedFeatureStore.REST_END_POINT, feature,
+       "/responseHeader/status==0");
+ }
+
+ /** Uploads a model to the default "test" feature store. */
+ protected static void loadModel(String name, String type, String[] features,
+ String params) throws Exception {
+ loadModel(name, type, features, "test", params);
+ }
+
+ /**
+  * Uploads a model definition to the given feature store through the managed
+  * model REST endpoint and asserts a zero response status.
+  */
+ protected static void loadModel(String name, String type, String[] features,
+     String fstore, String params) throws Exception {
+   final String modelJson = getModelInJson(name, type, features, fstore,
+       params);
+   log.info("loading model \n{} ", modelJson);
+   assertJPut(ManagedModelStore.REST_END_POINT, modelJson,
+       "/responseHeader/status==0");
+ }
+
+ /**
+  * Reads the named json resource from /modelExamples (may contain one or
+  * several models) and PUTs it to the model endpoint, asserting success.
+  */
+ public static void loadModels(String fileName) throws Exception {
+   final URL resource = TestRerankBase.class.getResource("/modelExamples/"
+       + fileName);
+   final String payload = FileUtils.readFileToString(
+       new File(resource.toURI()), "UTF-8");
+   assertJPut(ManagedModelStore.REST_END_POINT, payload,
+       "/responseHeader/status==0");
+ }
+
+ /**
+  * Builds an LTRScoringModel directly from a model json resource and a
+  * feature json resource, bypassing the REST endpoints: the features are
+  * applied to the managed feature store and the parsed model is registered
+  * with the managed model store before being returned.
+  *
+  * @throws ModelException if the feature json cannot be parsed
+  */
+ public static LTRScoringModel createModelFromFiles(String modelFileName,
+ String featureFileName) throws ModelException, Exception {
+ URL url = TestRerankBase.class.getResource("/modelExamples/"
+ + modelFileName);
+ final String modelJson = FileUtils.readFileToString(new File(url.toURI()),
+ "UTF-8");
+ final ManagedModelStore ms = getManagedModelStore();
+
+ url = TestRerankBase.class.getResource("/featureExamples/"
+ + featureFileName);
+ final String featureJson = FileUtils.readFileToString(
+ new File(url.toURI()), "UTF-8");
+
+ Object parsedFeatureJson = null;
+ try {
+ parsedFeatureJson = ObjectBuilder.fromJSON(featureJson);
+ } catch (final IOException ioExc) {
+ throw new ModelException("ObjectBuilder failed parsing json", ioExc);
+ }
+
+ final ManagedFeatureStore fs = getManagedFeatureStore();
+ // Clear all existing feature stores before applying the new features.
+ // NOTE(review): the need to delete "*" here suggests getManagedFeatureStore()
+ // returns a shared instance rather than a fresh store per call — confirm
+ // before relying on isolation between tests.
+ fs.doDeleteChild(null, "*");
+ fs.applyUpdatesToManagedData(parsedFeatureJson);
+ ms.setManagedFeatureStore(fs); // wire the populated feature store into the model store
+
+ final LTRScoringModel ltrScoringModel = ManagedModelStore.fromLTRScoringModelMap(
+ solrResourceLoader, mapFromJson(modelJson), ms.getManagedFeatureStore());
+ ms.addModel(ltrScoringModel);
+ return ltrScoringModel;
+ }
+
+ @SuppressWarnings("unchecked")
+ static private Map<String,Object> mapFromJson(String json) throws ModelException {
+ Object parsedJson = null;
+ try {
+ parsedJson = ObjectBuilder.fromJSON(json);
+ } catch (final IOException ioExc) {
+ throw new ModelException("ObjectBuilder failed parsing json", ioExc);
+ }
+ return (Map<String,Object>) parsedJson;
+ }
+
+ /**
+  * Reads the named json resource from /featureExamples (may contain several
+  * features) and PUTs it to the feature endpoint, asserting success.
+  */
+ public static void loadFeatures(String fileName) throws Exception {
+   final URL resource = TestRerankBase.class.getResource("/featureExamples/"
+       + fileName);
+   final String payload = FileUtils.readFileToString(
+       new File(resource.toURI()), "UTF-8");
+   log.info("send \n{}", payload);
+   assertJPut(ManagedFeatureStore.REST_END_POINT, payload,
+       "/responseHeader/status==0");
+ }
+
+ /**
+  * Builds one ValueFeature (with params value=10) per name; each feature's
+  * index is set to its position in the returned list.
+  *
+  * @param names feature names, in index order
+  * @return the instantiated features
+  * @throws FeatureException if a feature cannot be instantiated
+  */
+ protected List<Feature> getFeatures(List<String> names)
+     throws FeatureException {
+   final List<Feature> features = new ArrayList<>();
+   int pos = 0;
+   for (final String name : names) {
+     // Diamond operator, consistent with the rest of this file.
+     final Map<String,Object> params = new HashMap<>();
+     params.put("value", 10);
+     final Feature f = Feature.getInstance(solrResourceLoader,
+         ValueFeature.class.getCanonicalName(),
+         name, params);
+     f.setIndex(pos);
+     features.add(f);
+     ++pos;
+   }
+   return features;
+ }
+
+ /** Array convenience overload of {@link #getFeatures(List)}. */
+ protected List<Feature> getFeatures(String[] names) throws FeatureException {
+ return getFeatures(Arrays.asList(names));
+ }
+
+ /**
+  * Loads allFeatureCount ValueFeatures ("c0", "c1", ...) into the default
+  * store and a LinearModel over the first modelFeatureCount of them, each
+  * weighted 1.0.
+  */
+ protected static void loadModelAndFeatures(String name, int allFeatureCount,
+     int modelFeatureCount) throws Exception {
+   final String[] features = new String[modelFeatureCount];
+   final String[] weights = new String[modelFeatureCount];
+   for (int i = 0; i < allFeatureCount; i++) {
+     final String featureName = "c" + i;
+     if (i < modelFeatureCount) {
+       features[i] = featureName;
+       weights[i] = "\"" + featureName + "\":1.0";
+     }
+     // Register the Feature class itself, not its inner Weight class: the
+     // json "class" attribute must name a Feature subclass, matching how
+     // getFeatures instantiates ValueFeature directly.
+     loadFeature(featureName, ValueFeature.class.getCanonicalName(),
+         "{\"value\":" + i + "}");
+   }
+
+   loadModel(name, LinearModel.class.getCanonicalName(), features,
+       "{\"weights\":{" + StringUtils.join(weights, ",") + "}}");
+ }
+
+ protected static void bulkIndex() throws Exception {
+ assertU(adoc("title", "bloomberg different bla", "description",
+ "bloomberg", "id", "6", "popularity", "1"));
+ assertU(adoc("title", "bloomberg bloomberg ", "description", "bloomberg",
+ "id", "7", "popularity", "2"));
+ assertU(adoc("title", "bloomberg bloomberg bloomberg", "description",
+ "bloomberg", "id", "8", "popularity", "3"));
+ assertU(adoc("title", "bloomberg bloomberg bloomberg bloomberg",
+ "description", "bloomberg", "id", "9", "popularity", "5"));
+ assertU(commit());
+ }
+
+ /**
+  * Streams the given xml file to the update handler as a single request,
+  * then commits. Handler failures are logged and swallowed so indexing is
+  * best-effort.
+  */
+ protected static void bulkIndex(String filePath) throws Exception {
+ final SolrQueryRequestBase req = lrf.makeRequest(
+ CommonParams.STREAM_CONTENTTYPE, "application/xml");
+
+ final List<ContentStream> streams = new ArrayList<ContentStream>();
+ final File file = new File(filePath);
+ streams.add(new ContentStreamBase.FileStream(file));
+ req.setContentStreams(streams);
+
+ try {
+ final SolrQueryResponse res = new SolrQueryResponse();
+ h.updater.handleRequest(req, res);
+ } catch (final Throwable ex) {
+ // Best-effort: log and continue so a bad input file does not abort test
+ // setup. NOTE(review): catching Throwable also swallows assertion errors
+ // from the handler — confirm this breadth is intended.
+ log.error(ex.getMessage(), ex);
+ }
+ assertU(commit());
+
+ }
+
+ /**
+  * Reads an xml file of &lt;doc&gt; blocks (skipping the first three lines),
+  * wraps each doc in its own &lt;add&gt; element, indexes each via assertU
+  * and commits at the end.
+  *
+  * @param filepath path of the xml file to read
+  * @throws FileNotFoundException if the file does not exist
+  */
+ protected static void buildIndexUsingAdoc(String filepath)
+     throws FileNotFoundException {
+   // try-with-resources: the original leaked the Scanner if parsing or an
+   // assertU failed before the explicit close().
+   try (Scanner scn = new Scanner(new File(filepath), "UTF-8")) {
+     // Single-threaded use: StringBuilder instead of synchronized StringBuffer.
+     StringBuilder buff = new StringBuilder();
+     scn.nextLine();
+     scn.nextLine();
+     scn.nextLine(); // Skip the first 3 lines then add everything else
+     final ArrayList<String> docsToAdd = new ArrayList<String>();
+     while (scn.hasNext()) {
+       String curLine = scn.nextLine();
+       if (curLine.contains("</doc>")) {
+         buff.append(curLine + "\n");
+         // Re-wrap the accumulated doc as a standalone <add> block.
+         docsToAdd.add(buff.toString().replace("</add>", "")
+             .replace("<doc>", "<add>\n<doc>")
+             .replace("</doc>", "</doc>\n</add>"));
+         if (!scn.hasNext()) {
+           break;
+         } else {
+           curLine = scn.nextLine();
+         }
+         buff = new StringBuilder();
+       }
+       buff.append(curLine + "\n");
+     }
+     for (final String doc : docsToAdd) {
+       assertU(doc.trim());
+     }
+     assertU(commit());
+   }
+ }
+
+}