You are viewing a plain text version of this content. The canonical link for it is here.
Posted to commits@solr.apache.org by ab...@apache.org on 2023/01/30 15:05:53 UTC
[solr] branch branch_9x updated: SOLR-16596: Introduce support for null feature values in LTR MultipleAdditiveTreeModel (#1257)
This is an automated email from the ASF dual-hosted git repository.
abenedetti pushed a commit to branch branch_9x
in repository https://gitbox.apache.org/repos/asf/solr.git
The following commit(s) were added to refs/heads/branch_9x by this push:
new 750b3d5cb01 SOLR-16596: Introduce support for null feature values in LTR MultipleAdditiveTreeModel (#1257)
750b3d5cb01 is described below
commit 750b3d5cb0169dfad2042f1ec60ef447547052cc
Author: aruggero <57...@users.noreply.github.com>
AuthorDate: Mon Jan 30 15:30:44 2023 +0100
SOLR-16596: Introduce support for null feature values in LTR MultipleAdditiveTreeModel (#1257)
---
solr/CHANGES.txt | 2 +
.../java/org/apache/solr/ltr/CSVFeatureLogger.java | 2 +-
.../java/org/apache/solr/ltr/LTRScoringQuery.java | 8 +-
.../solr/ltr/model/MultipleAdditiveTreesModel.java | 92 +++++++++-
...ivetreesmodel_features_with_missing_branch.json | 26 +++
...ipleadditivetreesmodel_with_missing_branch.json | 48 +++++
...model_with_missing_branch_for_interleaving.json | 48 +++++
.../ltr/model/TestMultipleAdditiveTreesModel.java | 63 ++++++-
.../transform/TestFeatureLoggerTransformer.java | 197 +++++++++++++++++++++
.../query-guide/pages/learning-to-rank.adoc | 87 +++++++++
10 files changed, 564 insertions(+), 9 deletions(-)
diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt
index e39de2581a3..be1abc14bec 100644
--- a/solr/CHANGES.txt
+++ b/solr/CHANGES.txt
@@ -19,6 +19,8 @@ New Features
* SOLR-16532: New OpenTelemetry (OTEL) module with OTLP/gRPC trace exporter. See ref.guide. (janhoy, David Smiley)
+* SOLR-16596: Learning To Rank - Added support for null feature values in multiple additive trees models (Anna Ruggero via Alessandro Benedetti)
+
Improvements
---------------------
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
index 244b1b76d21..aea4d337b20 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java
@@ -43,7 +43,7 @@ public class CSVFeatureLogger extends FeatureLogger {
StringBuilder sb = new StringBuilder(featuresInfo.length * 3);
boolean isDense = featureFormat.equals(FeatureFormat.DENSE);
for (LTRScoringQuery.FeatureInfo featInfo : featuresInfo) {
- if (featInfo.isUsed() || isDense) {
+ if (isDense || featInfo.isUsed()) {
sb.append(featInfo.getName())
.append(keyValueSep)
.append(featInfo.getValue())
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
index a7b3d7bd7f3..d7787826991 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java
@@ -527,8 +527,8 @@ public class LTRScoringQuery extends Query implements Accountable {
public ModelScorer(Weight weight, List<Feature.FeatureWeight.FeatureScorer> featureScorers) {
super(weight);
docInfo = new DocInfo();
- for (final Feature.FeatureWeight.FeatureScorer subSocer : featureScorers) {
- subSocer.setDocInfo(docInfo);
+ for (final Feature.FeatureWeight.FeatureScorer subScorer : featureScorers) {
+ subScorer.setDocInfo(docInfo);
}
if (featureScorers.size() <= 1) {
// future enhancement: allow the use of dense features in other cases
@@ -593,7 +593,7 @@ public class LTRScoringQuery extends Query implements Accountable {
@Override
public float score() throws IOException {
final DisiWrapper topList = subScorers.topList();
- // If target doc we wanted to advance to matches the actual doc
+ // If target doc we wanted to advance to match the actual doc
// the underlying features advanced to, perform the feature
// calculations,
// otherwise just continue with the model's scoring process with empty
@@ -648,7 +648,7 @@ public class LTRScoringQuery extends Query implements Accountable {
@Override
public final int advance(int target) throws IOException {
- // If target doc we wanted to advance to matches the actual doc
+ // If target doc we wanted to advance to match the actual doc
// the underlying features advanced to, perform the feature
// calculations,
// otherwise just continue with the model's scoring process with
diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.java
index 4fa4ab77674..0d591bd0122 100644
--- a/solr/modules/ltr/src/java/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.java
+++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.java
@@ -22,9 +22,11 @@ import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.Explanation;
import org.apache.solr.ltr.feature.Feature;
+import org.apache.solr.ltr.feature.FeatureException;
import org.apache.solr.ltr.norm.Normalizer;
import org.apache.solr.util.SolrPluginUtils;
@@ -111,6 +113,8 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
*/
private List<RegressionTree> trees;
+ private boolean isNullSameAsZero = true;
+
private RegressionTree createRegressionTree(Map<String, Object> map) {
final RegressionTree rt = new RegressionTree();
if (map != null) {
@@ -127,6 +131,10 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
return rtn;
}
+ /**
+ * Chooses the scoring path used by this model's trees: when {@code true} each tree is scored
+ * through {@code scoreNode}, when {@code false} through {@code scoreNodeWithNullSupport}. The
+ * same flag decides whether NaN feature values are skipped during normalization.
+ *
+ * @param nullSameAsZero {@code false} to enable the NaN-aware scoring path
+ */
+ public void setIsNullSameAsZero(boolean nullSameAsZero) {
+ this.isNullSameAsZero = nullSameAsZero;
+ }
+
public class RegressionTreeNode {
private static final float NODE_SPLIT_SLACK = 1E-6f;
@@ -136,6 +144,7 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
private Float threshold;
private RegressionTreeNode left;
private RegressionTreeNode right;
+ private String missing;
public void setValue(float value) {
this.value = value;
@@ -145,6 +154,10 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
this.value = Float.parseFloat(value);
}
+ /**
+ * Sets the branch this node follows when the feature value it tests is NaN. The value
+ * {@code "left"} selects the left child; the NaN-aware scoring and explain code send other
+ * values to the right child. The direction is also included in {@code toString()} when set.
+ *
+ * @param direction branch name, as given by the {@code "missing"} entry of the model JSON
+ */
+ public void setMissing(String direction) {
+ this.missing = direction;
+ }
+
public void setFeature(String feature) {
this.feature = feature;
final Integer idx = fname2index.get(this.feature);
@@ -184,6 +197,9 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
} else {
sb.append("(feature=").append(feature);
sb.append(",threshold=").append(threshold.floatValue() - NODE_SPLIT_SLACK);
+ if (missing != null) {
+ sb.append(",missing=").append(missing);
+ }
sb.append(",left=").append(left);
sb.append(",right=").append(right);
sb.append(')');
@@ -213,7 +229,11 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
}
public float score(float[] featureVector) {
- return weight.floatValue() * scoreNode(featureVector, root);
+ if (isNullSameAsZero) {
+ return weight.floatValue() * scoreNode(featureVector, root);
+ } else {
+ return weight.floatValue() * scoreNodeWithNullSupport(featureVector, root);
+ }
}
public String explain(float[] featureVector) {
@@ -275,6 +295,31 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
}
}
+ /**
+ * Normalizes the given feature values in place, honouring this model's
+ * {@code isNullSameAsZero} setting.
+ *
+ * @param modelFeatureValues feature values to normalize; the array is modified
+ */
+ @Override
+ public void normalizeFeaturesInPlace(float[] modelFeatureValues) {
+ normalizeFeaturesInPlace(modelFeatureValues, this.isNullSameAsZero);
+ }
+
+ protected void normalizeFeaturesInPlace(float[] modelFeatureValues, boolean isNullSameAsZero) {
+ float[] modelFeatureValuesNormalized = modelFeatureValues;
+ if (modelFeatureValues.length != norms.size()) {
+ throw new FeatureException("Must have normalizer for every feature");
+ }
+ if (isNullSameAsZero) {
+ for (int idx = 0; idx < modelFeatureValuesNormalized.length; ++idx) {
+ modelFeatureValuesNormalized[idx] =
+ norms.get(idx).normalize(modelFeatureValuesNormalized[idx]);
+ }
+ } else {
+ for (int idx = 0; idx < modelFeatureValuesNormalized.length; ++idx) {
+ if (!Float.isNaN(modelFeatureValuesNormalized[idx])) {
+ modelFeatureValuesNormalized[idx] =
+ norms.get(idx).normalize(modelFeatureValuesNormalized[idx]);
+ }
+ }
+ }
+ }
+
@Override
public float score(float[] modelFeatureValuesNormalized) {
float score = 0;
@@ -303,6 +348,34 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
}
}
+ private static float scoreNodeWithNullSupport(
+ float[] featureVector, RegressionTreeNode regressionTreeNode) {
+ while (true) {
+ if (regressionTreeNode.isLeaf()) {
+ return regressionTreeNode.value;
+ }
+ // unsupported feature (tree is looking for a feature that does not exist)
+ if ((regressionTreeNode.featureIndex < 0)
+ || (regressionTreeNode.featureIndex >= featureVector.length)) {
+ return 0f;
+ }
+
+ if (featureVector[regressionTreeNode.featureIndex] <= regressionTreeNode.threshold) {
+ regressionTreeNode = regressionTreeNode.left;
+ } else if (featureVector[regressionTreeNode.featureIndex] > regressionTreeNode.threshold) {
+ regressionTreeNode = regressionTreeNode.right;
+ } else if (Float.isNaN(featureVector[regressionTreeNode.featureIndex])) {
+ switch (regressionTreeNode.missing) {
+ case "left":
+ regressionTreeNode = regressionTreeNode.left;
+ break;
+ default:
+ regressionTreeNode = regressionTreeNode.right;
+ }
+ }
+ }
+ }
+
private static void validateNode(RegressionTreeNode regressionTreeNode) throws ModelException {
// Create an empty stack and push root to it
@@ -359,7 +432,6 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
// could store extra information about how much training data supported
// each branch and report
// that here
-
if (featureVector[regressionTreeNode.featureIndex] <= regressionTreeNode.threshold) {
returnValueBuilder
.append("'")
@@ -370,7 +442,7 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
.append(regressionTreeNode.threshold)
.append(", Go Left | ");
regressionTreeNode = regressionTreeNode.left;
- } else {
+ } else if (featureVector[regressionTreeNode.featureIndex] > regressionTreeNode.threshold) {
returnValueBuilder
.append("'")
.append(regressionTreeNode.feature)
@@ -380,6 +452,20 @@ public class MultipleAdditiveTreesModel extends LTRScoringModel {
.append(regressionTreeNode.threshold)
.append(", Go Right | ");
regressionTreeNode = regressionTreeNode.right;
+ } else if (Float.isNaN(featureVector[regressionTreeNode.featureIndex])) {
+ if (Objects.equals(regressionTreeNode.missing, "left")) {
+ returnValueBuilder
+ .append("'")
+ .append(regressionTreeNode.feature)
+ .append("': NaN, Go Left | ");
+ regressionTreeNode = regressionTreeNode.left;
+ } else {
+ returnValueBuilder
+ .append("'")
+ .append(regressionTreeNode.feature)
+ .append("': NaN, Go Right | ");
+ regressionTreeNode = regressionTreeNode.right;
+ }
}
}
}
diff --git a/solr/modules/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features_with_missing_branch.json b/solr/modules/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features_with_missing_branch.json
new file mode 100644
index 00000000000..3e8b2b74d32
--- /dev/null
+++ b/solr/modules/ltr/src/test-files/featureExamples/multipleadditivetreesmodel_features_with_missing_branch.json
@@ -0,0 +1,26 @@
+[
+ {
+ "name": "matchedTitle",
+ "class": "org.apache.solr.ltr.feature.SolrFeature",
+ "params": {
+ "q": "{!terms f=title}${user_query}"
+ }
+ },
+ {
+ "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs",
+ "class": "org.apache.solr.ltr.feature.ValueFeature",
+ "params": {
+ "value": "1"
+ }
+ },
+ {
+ "name": "userDevice",
+ "class": "org.apache.solr.ltr.feature.ValueFeature",
+ "params": {
+ "value": "${user_device}",
+ "defaultValue": "NaN",
+ "required": false
+ }
+ }
+]
+
diff --git a/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch.json b/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch.json
new file mode 100644
index 00000000000..c8467c52089
--- /dev/null
+++ b/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch.json
@@ -0,0 +1,48 @@
+{
+ "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
+ "name":"modelA",
+ "features":[
+ { "name": "matchedTitle"},
+ { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"},
+ { "name": "userDevice"}
+ ],
+ "params":{
+ "isNullSameAsZero": false,
+ "trees": [
+ {
+ "weight" : "1f",
+ "root": {
+ "feature": "matchedTitle",
+ "threshold": "0.5f",
+ "left" : {
+ "value" : "-100"
+ },
+ "right": {
+ "feature" : "constantScoreToForceMultipleAdditiveTreesScoreAllDocs",
+ "threshold": "10.0f",
+ "left" : {
+ "feature" : "userDevice",
+ "threshold": "0f",
+ "missing": "left",
+ "left" : {
+ "value" : "50"
+ },
+ "right" : {
+ "value" : "65"
+ }
+ },
+ "right" : {
+ "value" : "75"
+ }
+ }
+ }
+ },
+ {
+ "weight" : "2f",
+ "root": {
+ "value" : "-10"
+ }
+ }
+ ]
+ }
+}
diff --git a/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch_for_interleaving.json b/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch_for_interleaving.json
new file mode 100644
index 00000000000..b11b74af911
--- /dev/null
+++ b/solr/modules/ltr/src/test-files/modelExamples/multipleadditivetreesmodel_with_missing_branch_for_interleaving.json
@@ -0,0 +1,48 @@
+{
+ "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
+ "name":"modelB",
+ "features":[
+ { "name": "matchedTitle"},
+ { "name": "constantScoreToForceMultipleAdditiveTreesScoreAllDocs"},
+ { "name": "userDevice"}
+ ],
+ "params":{
+ "isNullSameAsZero": false,
+ "trees": [
+ {
+ "weight" : "1f",
+ "root": {
+ "feature": "matchedTitle",
+ "threshold": "0.5f",
+ "left" : {
+ "feature" : "constantScoreToForceMultipleAdditiveTreesScoreAllDocs",
+ "threshold": "10.0f",
+ "left" : {
+ "feature" : "userDevice",
+ "threshold": "0f",
+ "missing": "left",
+ "left" : {
+ "value" : "20"
+ },
+ "right" : {
+ "value" : "15"
+ }
+ },
+ "right" : {
+ "value" : "85"
+ }
+ },
+ "right": {
+ "value" : "-5"
+ }
+ }
+ },
+ {
+ "weight" : "2f",
+ "root": {
+ "value" : "-20"
+ }
+ }
+ ]
+ }
+}
diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java
index 25d2578937f..37aa10fb12b 100644
--- a/solr/modules/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java
+++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/model/TestMultipleAdditiveTreesModel.java
@@ -169,7 +169,7 @@ public class TestMultipleAdditiveTreesModel extends TestRerankBase {
}
@Test
- public void multipleAdditiveTreesTestTreesParamDoesNotContatinTree() throws Exception {
+ public void multipleAdditiveTreesTestTreesParamDoesNotContainTree() throws Exception {
final ModelException expectedException =
new ModelException("MultipleAdditiveTreesModel tree doesn't contain a tree");
Exception ex =
@@ -296,4 +296,65 @@ public class TestMultipleAdditiveTreesModel extends TestRerankBase {
});
assertEquals(expectedException.toString(), ex.toString());
}
+
+ @Test
+ public void testMultipleAdditiveTreesWithNulls() throws Exception {
+ loadFeatures("multipleadditivetreesmodel_features_with_missing_branch.json");
+ loadModels("multipleadditivetreesmodel_with_missing_branch.json");
+
+ doTestMultipleAdditiveTreesWithNulls();
+ doTestMultipleAdditiveTreesExplainWithNulls();
+ }
+
+ /**
+ * Reranks the top three documents with modelA, leaving {@code user_device} unset, and checks
+ * the id of the first document and the scores of all three.
+ */
+ private void doTestMultipleAdditiveTreesWithNulls() throws Exception {
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("rows", "3");
+ query.add("fl", "*,score");
+
+ query.add("rq", "{!ltr reRankDocs=3 model=modelA efi.user_query=w3}");
+
+ // all checks target the same response, so the query is issued once
+ assertJQ(
+ "/query" + query.toQueryString(),
+ new String[] {
+ "/response/docs/[0]/id=='3'",
+ "/response/docs/[0]/score==30.0",
+ "/response/docs/[1]/score==-120.0",
+ "/response/docs/[2]/score==-120.0"
+ });
+ }
+
+ private void doTestMultipleAdditiveTreesExplainWithNulls() throws Exception {
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*,score,[fv]");
+ query.add("rows", "3");
+
+ query.add("rq", "{!ltr reRankDocs=3 model=modelA efi.user_query=w3}");
+
+ // test out the explain feature, make sure it returns something
+ query.setParam("debugQuery", "on");
+
+ String qryResult = JQ("/query" + query.toQueryString());
+ qryResult = qryResult.replaceAll("\n", " ");
+
+ MatcherAssert.assertThat(qryResult, containsString("\"debug\":{"));
+ qryResult = qryResult.substring(qryResult.indexOf("debug"));
+
+ MatcherAssert.assertThat(qryResult, containsString("\"explain\":{"));
+ qryResult = qryResult.substring(qryResult.indexOf("explain"));
+
+ MatcherAssert.assertThat(qryResult, containsString("modelA"));
+ MatcherAssert.assertThat(
+ qryResult, containsString(MultipleAdditiveTreesModel.class.getSimpleName()));
+
+ MatcherAssert.assertThat(qryResult, containsString("50.0 = tree 0"));
+ MatcherAssert.assertThat(qryResult, containsString("-20.0 = tree 1"));
+ MatcherAssert.assertThat(qryResult, containsString("'matchedTitle':1.0 > 0.5"));
+ MatcherAssert.assertThat(
+ qryResult,
+ containsString("'constantScoreToForceMultipleAdditiveTreesScoreAllDocs':1.0 <= 10.0"));
+ MatcherAssert.assertThat(qryResult, containsString("'userDevice': NaN"));
+
+ MatcherAssert.assertThat(qryResult, containsString(" Go Right "));
+ MatcherAssert.assertThat(qryResult, containsString(" Go Left "));
+ }
}
diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java
index cd1067dd7f3..8f2d611e2d3 100644
--- a/solr/modules/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java
+++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/response/transform/TestFeatureLoggerTransformer.java
@@ -122,6 +122,110 @@ public class TestFeatureLoggerTransformer extends TestRerankBase {
"{\"weights\":{\"featureC1\":5.0, \"featureC2\":25.0}}");
}
+ protected void loadFeaturesAndModelsWithNulls() throws Exception {
+ loadFeatures("multipleadditivetreesmodel_features_with_missing_branch.json");
+ loadModels("multipleadditivetreesmodel_with_missing_branch.json");
+ loadModels("multipleadditivetreesmodel_with_missing_branch_for_interleaving.json");
+ }
+
+ @Test
+ public void featureTransformer_shouldWorkInSparseFormat_withNulls() throws Exception {
+ loadFeaturesAndModelsWithNulls();
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*, score,features:[fv format=sparse]");
+ query.add("rows", "10");
+ query.add("debugQuery", "true");
+ query.add("rq", "{!ltr model=modelA reRankDocs=10 efi.user_query=w3}");
+
+ String[] expectedFeatureVectors =
+ new String[] {
+ "matchedTitle\\=1.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0"
+ };
+
+ int[] expectedIds = new int[] {7, 1, 2, 3, 4, 5, 6, 8};
+
+ String[] tests = new String[17];
+ tests[0] = "/response/numFound/==8";
+ for (int i = 1; i <= 8; i++) {
+ tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedIds[(i - 1)] + "\"";
+ tests[i + 8] =
+ "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
+ }
+ assertJQ("/query" + query.toQueryString(), tests);
+
+ // user_device has a different default value (NaN), if zero we would like to see the zero value
+ final SolrQuery query2 = new SolrQuery();
+ query2.setQuery("*:*");
+ query2.add("fl", "*, score,features:[fv format=sparse]");
+ query2.add("rows", "10");
+ query2.add("debugQuery", "true");
+ query2.add("rq", "{!ltr model=modelA reRankDocs=10 efi.user_query=w3 efi.user_device=0}");
+
+ expectedFeatureVectors =
+ new String[] {
+ "matchedTitle\\=1.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0"
+ };
+
+ expectedIds = new int[] {7, 1, 2, 3, 4, 5, 6, 8};
+
+ tests = new String[17];
+ tests[0] = "/response/numFound/==8";
+ for (int i = 1; i <= 8; i++) {
+ tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedIds[(i - 1)] + "\"";
+ tests[i + 8] =
+ "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
+ }
+ assertJQ("/query" + query2.toQueryString(), tests);
+ }
+
+ /**
+ * Dense feature logging with modelA and no {@code user_device}: every document must list all
+ * three features, with {@code userDevice} reported as NaN.
+ */
+ @Test
+ public void featureTransformer_shouldWorkInDenseFormat_withNulls() throws Exception {
+ loadFeaturesAndModelsWithNulls();
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*, score,features:[fv format=dense]");
+ query.add("rows", "10");
+ query.add("debugQuery", "true");
+ query.add("rq", "{!ltr model=modelA reRankDocs=10 efi.user_query=w3}");
+
+ String[] expectedFeatureVectors =
+ new String[] {
+ "matchedTitle\\=1.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN"
+ };
+
+ // one numFound check plus one feature-vector check per document (ids are not asserted
+ // here), so the array holds exactly the assertions made and no unused slots
+ String[] tests = new String[1 + expectedFeatureVectors.length];
+ tests[0] = "/response/numFound/==8";
+ for (int i = 0; i < expectedFeatureVectors.length; i++) {
+ tests[i + 1] = "/response/docs/[" + i + "]/features==" + expectedFeatureVectors[i];
+ }
+ assertJQ("/query" + query.toQueryString(), tests);
+ }
+
@Test
public void interleaving_featureTransformer_shouldWorkInSparseFormat() throws Exception {
TeamDraftInterleaving.setRANDOM(
@@ -212,6 +316,99 @@ public class TestFeatureLoggerTransformer extends TestRerankBase {
assertJQ("/query" + query.toQueryString(), tests);
}
+ // Interleaves the reranked lists of modelA and modelB using a fixed random seed, with
+ // user_device explicitly set to 0, and checks document ids and sparse feature vectors.
+ @Test
+ public void interleaving_featureTransformer_shouldWorkInSparseFormat_withNulls()
+ throws Exception {
+ TeamDraftInterleaving.setRANDOM(
+ new Random(10101011)); // Random Boolean Choices Generation from Seed: [0,0,1]
+ loadFeaturesAndModelsWithNulls();
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*, score,features:[fv format=sparse]");
+ query.add("rows", "10");
+ query.add("debugQuery", "true");
+ query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
+ query.add(
+ "rq",
+ "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5' efi.user_device=0}");
+
+ /*
+ Scores are the first tree's leaf value plus the second tree's weighted constant:
+ modelA adds 2 * (-10), modelB adds 2 * (-20).
+ Doc1 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=0.0", ScoreA(-120), ScoreB(-20)
+ Doc3 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=0.0", ScoreA(-120), ScoreB(-20)
+ Doc4 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=0.0", ScoreA(-120), ScoreB(-20)
+ Doc8 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=0.0", ScoreA(-120), ScoreB(-20)
+ Doc7 = "matchedTitle=1.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=0.0", ScoreA(30), ScoreB(-45)
+ ModelARerankedList = [7,1,3,4,8]
+ ModelBRerankedList = [1,3,4,8,7]
+
+ Random Boolean Choices Generation from Seed: [0,0,1]
+ */
+ String[] expectedFeatureVectors =
+ new String[] {
+ "matchedTitle\\=1.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0",
+ "constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=0.0"
+ };
+ int[] expectedInterleaved = new int[] {7, 1, 3, 4, 8};
+
+ // slot 0: numFound; slots 1-5: document ids; slots 6-10: feature vectors
+ String[] tests = new String[11];
+ tests[0] = "/response/numFound/==5";
+ for (int i = 1; i <= 5; i++) {
+ tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
+ tests[i + 5] =
+ "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
+ }
+ assertJQ("/query" + query.toQueryString(), tests);
+ }
+
+ // Interleaves the reranked lists of modelA and modelB using a fixed random seed, with
+ // user_device left unset, and checks document ids and dense feature vectors (userDevice=NaN).
+ @Test
+ public void interleaving_featureTransformer_shouldWorkInDenseFormat_withNulls() throws Exception {
+ TeamDraftInterleaving.setRANDOM(
+ new Random(10101011)); // Random Boolean Choices Generation from Seed: [0,0,1]
+ loadFeaturesAndModelsWithNulls();
+
+ final SolrQuery query = new SolrQuery();
+ query.setQuery("*:*");
+ query.add("fl", "*, score,features:[fv format=dense]");
+ query.add("rows", "10");
+ query.add("debugQuery", "true");
+ query.add("fq", "{!terms f=title}w1"); // 1,3,4,7,8
+ query.add("rq", "{!ltr model=modelA model=modelB reRankDocs=10 efi.user_query='w5'}");
+
+ /*
+ Scores are the first tree's leaf value plus the second tree's weighted constant:
+ modelA adds 2 * (-10), modelB adds 2 * (-20). A NaN userDevice follows the "missing": "left" branch.
+ Doc1 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=NaN", ScoreA(-120), ScoreB(-20)
+ Doc3 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=NaN", ScoreA(-120), ScoreB(-20)
+ Doc4 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=NaN", ScoreA(-120), ScoreB(-20)
+ Doc8 = "matchedTitle=0.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=NaN", ScoreA(-120), ScoreB(-20)
+ Doc7 = "matchedTitle=1.0,constantScoreToForceMultipleAdditiveTreesScoreAllDocs=1.0,userDevice=NaN", ScoreA(30), ScoreB(-45)
+ ModelARerankedList = [7,1,3,4,8]
+ ModelBRerankedList = [1,3,4,8,7]
+
+ Random Boolean Choices Generation from Seed: [0,0,1]
+ */
+ String[] expectedFeatureVectors =
+ new String[] {
+ "matchedTitle\\=1.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN",
+ "matchedTitle\\=0.0\\,constantScoreToForceMultipleAdditiveTreesScoreAllDocs\\=1.0\\,userDevice\\=NaN"
+ };
+ int[] expectedInterleaved = new int[] {7, 1, 3, 4, 8};
+
+ // slot 0: numFound; slots 1-5: document ids; slots 6-10: feature vectors
+ String[] tests = new String[11];
+ tests[0] = "/response/numFound/==5";
+ for (int i = 1; i <= 5; i++) {
+ tests[i] = "/response/docs/[" + (i - 1) + "]/id==\"" + expectedInterleaved[(i - 1)] + "\"";
+ tests[i + 5] =
+ "/response/docs/[" + (i - 1) + "]/features==" + expectedFeatureVectors[(i - 1)];
+ }
+ assertJQ("/query" + query.toQueryString(), tests);
+ }
+
@Test
public void interleaving_explicitNewFeatureStore_shouldExtractAllFeaturesFromNewStore()
throws Exception {
diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
index 8f4b6c31716..4546f0efbb7 100644
--- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
+++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc
@@ -555,6 +555,93 @@ For sparse CSV output such as `featureA:0.1 featureB:0.2 featureC:0.3` you can c
</transformer>
----
+==== Models handling features' null values
+This feature is available only for {solr-javadocs}/modules/ltr/org/apache/solr/ltr/model/MultipleAdditiveTreesModel.html[MultipleAdditiveTreesModel].
+
+In some scenarios a null value for a feature has a different meaning than a zero value. Some models are trained to distinguish the two (e.g. https://xgboost.readthedocs.io/en/stable/faq.html#how-to-deal-with-missing-values). To support this in Solr, an additional `"missing"` branch parameter has been introduced.
+
+This defines the branch to follow when the corresponding feature value is null. With the default configuration a null and a zero value have the same meaning.
+
+To handle null values, the `"myFeatures.json"` file needs to be modified. A `"defaultValue"` parameter with a `"NaN"` value needs to be added to each feature that can assume a null value.
+
+.Example: /path/myFeatures.json
+[source,json]
+----
+[
+ {
+ "name": "matchedTitle",
+ "class": "org.apache.solr.ltr.feature.SolrFeature",
+ "params": {
+ "q": "{!terms f=title}${user_query}"
+ }
+ },
+ {
+ "name": "productReviewScore",
+ "class": "org.apache.solr.ltr.feature.FieldValueFeature",
+ "params": {
+ "field": "product_review_score",
+ "defaultValue": "NaN"
+ }
+ }
+]
+----
+
+Also, the model configuration needs two additional parameters:
+
+* `"isNullSameAsZero"` needs to be defined in the model `"params"` and set to `"false"`;
+
+* the `"missing"` parameter needs to be added to each branch where the corresponding feature supports null values. Its value can be either `"left"` or `"right"`.
+
+.Example: /path/myModel.json
+[source,json]
+----
+{
+ "class":"org.apache.solr.ltr.model.MultipleAdditiveTreesModel",
+ "name":"multipleadditivetreesmodel",
+ "features":[
+ { "name": "matchedTitle"},
+ { "name": "productReviewScore"}
+ ],
+ "params":{
+ "isNullSameAsZero": "false",
+ "trees": [
+ {
+ "weight" : "1f",
+ "root": {
+ "feature": "matchedTitle",
+ "threshold": "0.5f",
+ "left" : {
+ "value" : "-100"
+ },
+ "right": {
+ "feature" : "productReviewScore",
+ "threshold": "0f",
+ "missing": "left",
+ "left" : {
+ "value" : "50"
+ },
+ "right" : {
+ "value" : "65"
+ }
+ }
+ }
+ }
+ ]
+ }
+}
+
+----
+
+When isNullSameAsZero is `"false"` for your model, the feature vector changes.
+
+* dense format: all feature values are shown, including default values (which can be zero or null).
+* sparse format: only non-default values are shown.
+
+e.g.
+
+given the features defined in <<models-handling-features-null-values>>.
+If their values are `matchedTitle=0` and `productReviewScore=0`, the sparse format will return `productReviewScore:0` (0 is the default value of `matchedTitle`, so it is not returned; 0 is not the default value of `productReviewScore`, so it is returned).
+
==== Implementation and Contributions
How does Solr Learning-To-Rank work under the hood?::