This is an automated email from the ASF dual-hosted git repository. abenedetti pushed a commit to branch branch_9x in repository https://gitbox.apache.org/repos/asf/solr.git
commit da35f09a5f81efa7d9530e1a208d1b24c1137622 Author: aruggero <[email protected]> AuthorDate: Mon Jun 30 11:03:04 2025 +0200 SOLR-17760: solving bug in LTR dense/sparse format (#3354) * Fixed field value feature (cherry picked from commit ba981cd76c21df23ee518d3ee74f26e6f673ebe5) --- solr/CHANGES.txt | 2 + .../java/org/apache/solr/ltr/CSVFeatureLogger.java | 2 +- .../java/org/apache/solr/ltr/LTRScoringQuery.java | 37 +++--- .../solr/ltr/TestSelectiveWeightCreation.java | 19 +-- .../solr/ltr/feature/TestFieldValueFeature.java | 146 +++++++++++++-------- .../query-guide/pages/learning-to-rank.adoc | 30 +++++ 6 files changed, 153 insertions(+), 83 deletions(-) diff --git a/solr/CHANGES.txt b/solr/CHANGES.txt index f8364b40eac..37e8beff119 100644 --- a/solr/CHANGES.txt +++ b/solr/CHANGES.txt @@ -132,6 +132,8 @@ Bug Fixes * SOLR-17726: MoreLikeThis to support copy-fields (Ilaria Petreti via Alessandro Benedetti) +* SOLR-16667: Fixed dense/sparse representation in LTR module. (Anna Ruggero, Alessandro Benedetti) + Dependency Upgrades --------------------- * SOLR-17471: Upgrade Lucene to 9.12.1. 
(Pierre Salagnac, Christine Poerschke) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java index 73e98b249ed..22ddcb8724a 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/CSVFeatureLogger.java @@ -44,7 +44,7 @@ public class CSVFeatureLogger extends FeatureLogger { StringBuilder sb = new StringBuilder(featuresInfo.length * 3); boolean isDense = featureFormat.equals(FeatureFormat.DENSE); for (LTRScoringQuery.FeatureInfo featInfo : featuresInfo) { - if (featInfo != null && (isDense || featInfo.isUsed())) { + if (featInfo != null && (isDense || !featInfo.isDefaultValue())) { sb.append(featInfo.getName()) .append(keyValueSep) .append(featInfo.getValue()) diff --git a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java index 8731cccb346..0ce2e21217d 100644 --- a/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java +++ b/solr/modules/ltr/src/java/org/apache/solr/ltr/LTRScoringQuery.java @@ -331,12 +331,12 @@ public class LTRScoringQuery extends Query implements Accountable { public static class FeatureInfo { private final String name; private float value; - private boolean used; + private boolean isDefaultValue; - FeatureInfo(String n, float v, boolean u) { - name = n; - value = v; - used = u; + FeatureInfo(String name, float value, boolean isDefaultValue) { + this.name = name; + this.value = value; + this.isDefaultValue = isDefaultValue; } public void setValue(float value) { @@ -351,12 +351,12 @@ public class LTRScoringQuery extends Query implements Accountable { return value; } - public boolean isUsed() { - return used; + public boolean isDefaultValue() { + return isDefaultValue; } - public void setUsed(boolean used) { - this.used = used; + public void setIsDefaultValue(boolean isDefaultValue) 
{ + this.isDefaultValue = isDefaultValue; } } @@ -408,7 +408,7 @@ public class LTRScoringQuery extends Query implements Accountable { String featName = extractedFeatureWeights[i].getName(); int featId = extractedFeatureWeights[i].getIndex(); float value = extractedFeatureWeights[i].getDefaultValue(); - featuresInfo[featId] = new FeatureInfo(featName, value, false); + featuresInfo[featId] = new FeatureInfo(featName, value, true); } } @@ -440,12 +440,7 @@ public class LTRScoringQuery extends Query implements Accountable { for (final Feature.FeatureWeight feature : modelFeatureWeights) { final int featureId = feature.getIndex(); FeatureInfo fInfo = featuresInfo[featureId]; - // not checking for finfo == null as that would be a bug we should catch - if (fInfo.isUsed()) { - modelFeatureValuesNormalized[pos] = fInfo.getValue(); - } else { - modelFeatureValuesNormalized[pos] = feature.getDefaultValue(); - } + modelFeatureValuesNormalized[pos] = fInfo.getValue(); pos++; } ltrScoringModel.normalizeFeaturesInPlace(modelFeatureValuesNormalized); @@ -480,7 +475,7 @@ public class LTRScoringQuery extends Query implements Accountable { // need to set default value everytime as the default value is used in 'dense' // mode even if used=false featuresInfo[featId].setValue(value); - featuresInfo[featId].setUsed(false); + featuresInfo[featId].setIsDefaultValue(true); } } @@ -598,7 +593,9 @@ public class LTRScoringQuery extends Query implements Accountable { Feature.FeatureWeight scFW = (Feature.FeatureWeight) subScorer.getWeight(); final int featureId = scFW.getIndex(); featuresInfo[featureId].setValue(subScorer.score()); - featuresInfo[featureId].setUsed(true); + if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { + featuresInfo[featureId].setIsDefaultValue(false); + } } } return makeNormalizedFeaturesAndScore(); @@ -683,7 +680,9 @@ public class LTRScoringQuery extends Query implements Accountable { Feature.FeatureWeight scFW = (Feature.FeatureWeight) 
scorer.getWeight(); final int featureId = scFW.getIndex(); featuresInfo[featureId].setValue(scorer.score()); - featuresInfo[featureId].setUsed(true); + if (featuresInfo[featureId].getValue() != scFW.getDefaultValue()) { + featuresInfo[featureId].setIsDefaultValue(false); + } } } } diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java index 6f730bf86a3..d96bcc61886 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/TestSelectiveWeightCreation.java @@ -142,7 +142,10 @@ public class TestSelectiveWeightCreation extends TestRerankBase { assertEquals("11", searcher.storedFields().document(hits.scoreDocs[1].doc).get("id")); List<Feature> features = makeFeatures(new int[] {0, 1, 2}); + List<Feature> expectedNonDefaultFeatures = makeFeatures(new int[] {1, 2}); final List<Feature> allFeatures = makeFeatures(new int[] {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}); + List<Feature> expectedNonDefaultAllFeatures = + makeFeatures(new int[] {1, 2, 3, 4, 5, 6, 7, 8, 9}); final List<Normalizer> norms = new ArrayList<>(); for (int k = 0; k < features.size(); ++k) { norms.add(IdentityNormalizer.INSTANCE); @@ -167,13 +170,13 @@ public class TestSelectiveWeightCreation extends TestRerankBase { LTRScoringQuery.FeatureInfo[] featuresInfo = modelWeight.getFeaturesInfo(); assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); - int validFeatures = 0; + int nonDefaultFeatures = 0; for (int i = 0; i < featuresInfo.length; ++i) { - if (featuresInfo[i] != null && featuresInfo[i].isUsed()) { - validFeatures += 1; + if (featuresInfo[i] != null && !featuresInfo[i].isDefaultValue()) { + nonDefaultFeatures += 1; } } - assertEquals(validFeatures, features.size()); + assertEquals(expectedNonDefaultFeatures.size(), nonDefaultFeatures); // when features are requested in the 
response, weights should be created for all features final LTRScoringModel ltrScoringModel2 = @@ -194,13 +197,13 @@ public class TestSelectiveWeightCreation extends TestRerankBase { assertEquals(features.size(), modelWeight.getModelFeatureValuesNormalized().length); assertEquals(allFeatures.size(), modelWeight.getExtractedFeatureWeights().length); - validFeatures = 0; + nonDefaultFeatures = 0; for (int i = 0; i < featuresInfo.length; ++i) { - if (featuresInfo[i] != null && featuresInfo[i].isUsed()) { - validFeatures += 1; + if (featuresInfo[i] != null && !featuresInfo[i].isDefaultValue()) { + nonDefaultFeatures += 1; } } - assertEquals(validFeatures, allFeatures.size()); + assertEquals(expectedNonDefaultAllFeatures.size(), nonDefaultFeatures); assertU(delI("10")); assertU(delI("11")); diff --git a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java index b10d9d7f952..d8d9d308ac1 100644 --- a/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java +++ b/solr/modules/ltr/src/test/org/apache/solr/ltr/feature/TestFieldValueFeature.java @@ -250,6 +250,31 @@ public class TestFieldValueFeature extends TestRerankBase { assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='42'"); + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "popularity", + "0.0", + "dvIntPopularity", + "0.0", + "dvLongPopularity", + "0.0", + "dvFloatPopularity", + "0.0", + "dvDoublePopularity", + "0.0", + "dvStringPopularity", + "0.0", + "isTrendy", + "0.0", + "dvIsTrendy", + "0.0", + "storedDvIsTrendy", + "0.0"); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + query = new SolrQuery(); query.setQuery("id:42"); 
query.add("rq", "{!ltr model=model reRankDocs=4}"); @@ -262,9 +287,7 @@ public class TestFieldValueFeature extends TestRerankBase { assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'popularity=0.0,dvIntPopularity=0.0,dvLongPopularity=0.0," - + "dvFloatPopularity=0.0,dvDoublePopularity=0.0," - + "dvStringPopularity=0.0,isTrendy=0.0,dvIsTrendy=0.0,storedDvIsTrendy=0.0'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); } @Test @@ -292,16 +315,22 @@ public class TestFieldValueFeature extends TestRerankBase { assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ("/query" + query.toQueryString(), "/response/docs/[0]/id=='42'"); + + final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector(field + "42", "42.0"); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + query = new SolrQuery(); query.setQuery("id:42"); query.add("rq", "{!ltr model=" + field + "-model42 reRankDocs=4}"); query.add("fl", "[fv]"); + assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector(field + "42", "42.0") - + "'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); } } @@ -360,6 +389,14 @@ public class TestFieldValueFeature extends TestRerankBase { fstore, "{\"weights\":{\"not-existing-field\":1.0}}"); + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + "not-existing-field", Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + final SolrQuery query 
= new SolrQuery(); query.setQuery("id:42"); query.add("rq", "{!ltr model=not-existing-field-model reRankDocs=4}"); @@ -368,10 +405,7 @@ public class TestFieldValueFeature extends TestRerankBase { assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - "not-existing-field", Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); assertEquals( FieldValueFeatureScorer.class.getName(), ObservingFieldValueFeature.usedScorerClass); } @@ -394,6 +428,14 @@ public class TestFieldValueFeature extends TestRerankBase { fstore, "{\"weights\":{\"" + field + "\":1.0}}"); + final String docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + final SolrQuery query = new SolrQuery("id:42"); query.add("rq", "{!ltr model=" + field + "-model reRankDocs=4}"); query.add("fl", "[fv]"); @@ -402,10 +444,7 @@ public class TestFieldValueFeature extends TestRerankBase { assertJQ("/query" + query.toQueryString(), "/response/numFound/==1"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); assertEquals( DefaultValueFieldValueFeatureScorer.class.getName(), ObservingFieldValueFeature.usedScorerClass); @@ -442,15 +481,19 @@ public class TestFieldValueFeature extends TestRerankBase { fstore, "{\"weights\":{\"trendy\":1.0}}"); + final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector("trendy", "0.0"); + final String 
docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + SolrQuery query = new SolrQuery(); query.setQuery("id:4"); query.add("rq", "{!ltr model=trendy-model reRankDocs=4}"); query.add("fl", "[fv]"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector("trendy", "0.0") - + "'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); query = new SolrQuery(); query.setQuery("id:5"); @@ -469,9 +512,7 @@ public class TestFieldValueFeature extends TestRerankBase { query.add("fl", "[fv]"); assertJQ( "/query" + query.toQueryString(), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector("trendy", "0.0") - + "'}"); + "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}"); } @Test @@ -583,6 +624,17 @@ public class TestFieldValueFeature extends TestRerankBase { @Test public void testThatStringValuesAreCorrectlyParsed() throws Exception { for (String field : new String[] {"dvStrNumField", "noDvStrNumField"}) { + final String false_docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector(field, "0.0"); + final String default_docs0fv_dense_csv = + FeatureLoggerTestUtils.toFeatureVector( + field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String false_docs0fv_default_csv = + chooseDefaultFeatureVector(false_docs0fv_dense_csv, docs0fv_sparse_csv); + final String default_docs0fv_default_csv = + chooseDefaultFeatureVector(default_docs0fv_dense_csv, docs0fv_sparse_csv); + final String[][] inputsAndTests = { new String[] { "T", @@ -590,39 +642,17 @@ public class TestFieldValueFeature extends TestRerankBase { + FeatureLoggerTestUtils.toFeatureVector(field, "1.0") + "'}" }, + new String[] {"F", "/response/docs/[0]/=={'[fv]':'" + false_docs0fv_default_csv + 
"'}"}, new String[] { - "F", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector(field, "0.0") - + "'}" - }, - new String[] { - "-7324.427", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}" - }, - new String[] { - "532", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}" + "-7324.427", "/response/docs/[0]/=={'[fv]':'" + default_docs0fv_default_csv + "'}" }, + new String[] {"532", "/response/docs/[0]/=={'[fv]':'" + default_docs0fv_default_csv + "'}"}, new String[] { Float.toString(Float.NaN), - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}" + "/response/docs/[0]/=={'[fv]':'" + default_docs0fv_default_csv + "'}" }, new String[] { - "notanumber", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector( - field, Float.toString(FIELD_VALUE_FEATURE_DEFAULT_VAL)) - + "'}" + "notanumber", "/response/docs/[0]/=={'[fv]':'" + default_docs0fv_default_csv + "'}" } }; @@ -651,12 +681,15 @@ public class TestFieldValueFeature extends TestRerankBase { @Test public void testThatDateValuesAreCorrectlyParsed() throws Exception { for (String field : new String[] {"dvDateField", "noDvDateField"}) { + final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector(field, "0.0"); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + final String[][] inputsAndTests = { new String[] { - "1970-01-01T00:00:00.000Z", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector(field, "0.0") - + "'}" + "1970-01-01T00:00:00.000Z", "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}" }, new String[] 
{ "1970-01-01T00:00:00.001Z", @@ -753,12 +786,15 @@ public class TestFieldValueFeature extends TestRerankBase { public void testRelativeDateFieldValueFeature() throws Exception { final String field = "dvDateField"; for (boolean since : new boolean[] {false, true}) { + final String docs0fv_dense_csv = FeatureLoggerTestUtils.toFeatureVector(field, "0.0"); + final String docs0fv_sparse_csv = FeatureLoggerTestUtils.toFeatureVector(""); + + final String docs0fv_default_csv = + chooseDefaultFeatureVector(docs0fv_dense_csv, docs0fv_sparse_csv); + final String[][] inputsAndTests = { new String[] { - "2000-01-01T00:00:00.000Z", - "/response/docs/[0]/=={'[fv]':'" - + FeatureLoggerTestUtils.toFeatureVector(field, "0.0") - + "'}" + "2000-01-01T00:00:00.000Z", "/response/docs/[0]/=={'[fv]':'" + docs0fv_default_csv + "'}" }, new String[] { "2000-01-01T00:01:02.003Z", diff --git a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc index e40deb89cc1..27895e77829 100644 --- a/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc +++ b/solr/solr-ref-guide/modules/query-guide/pages/learning-to-rank.adoc @@ -328,6 +328,36 @@ http://localhost:8983/solr/techproducts/schema/feature-store/_DEFAULT_ ] ---- +==== Feature Parameters + +All the ltr feature types accept the parameters described below. + +`defaultValue`:: ++ +[%autowidth,frame=none] +|=== +|Optional |Default: `0.0` +|=== ++ +This parameter specifies the default value of the feature to use for both logging and reranking. 
++.Example: /path/myFeatures.json +[source,json] +---- +[ + { + "name": "productReviewScore", + "class": "org.apache.solr.ltr.feature.FieldValueFeature", + "params": { + "field": "product_review_score", + "defaultValue": "5.2" + } + } +] +---- ++ +CAUTION: When defining a `defaultValue` for a `FieldValueFeature`, check that no `default` is assigned to that field in the schema; otherwise, the feature value will be the one defined in the schema rather than the one defined in the feature store. + === Logging Features To log features as part of a query, add `[features]` to the `fl` parameter, for example:
