diff --git a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/TextQuery.java b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/TextQuery.java index 84a5b9d47..bccbbc72e 100644 --- a/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/TextQuery.java +++ b/spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/TextQuery.java @@ -16,9 +16,9 @@ package org.springframework.data.mongodb.core.query; import java.util.Locale; +import java.util.Map.Entry; import org.bson.Document; - import org.springframework.data.mongodb.util.BsonUtils; import org.springframework.lang.Nullable; @@ -37,6 +37,7 @@ public class TextQuery extends Query { private String scoreFieldName = DEFAULT_SCORE_FIELD_FIELDNAME; private boolean includeScore = false; private boolean sortByScore = false; + private int sortByScoreIndex = 0; /** * Creates new {@link TextQuery} using the the given {@code wordsAndPhrases} with {@link TextCriteria} @@ -101,6 +102,7 @@ public class TextQuery extends Query { */ public TextQuery sortByScore() { + this.sortByScoreIndex = getSortObject().size(); this.includeScore(); this.sortByScore = true; return this; @@ -173,15 +175,35 @@ public class TextQuery extends Query { public Document getSortObject() { if (this.sortByScore) { - Document sort = new Document(); - sort.put(getScoreFieldName(), META_TEXT_SCORE); - sort.putAll(super.getSortObject()); - return sort; + if (sortByScoreIndex == 0) { + Document sort = new Document(); + sort.put(getScoreFieldName(), META_TEXT_SCORE); + sort.putAll(super.getSortObject()); + return sort; + } + return fitInSortByScoreAtPosition(super.getSortObject()); } return super.getSortObject(); } + private Document fitInSortByScoreAtPosition(Document source) { + + Document target = new Document(); + int i = 0; + for (Entry entry : source.entrySet()) { + if (i == sortByScoreIndex) { + target.put(getScoreFieldName(), META_TEXT_SCORE); + } + target.put(entry.getKey(), entry.getValue()); + i++; + } + if (i == sortByScoreIndex) { + target.put(getScoreFieldName(), META_TEXT_SCORE); + } + return target; + } + /* * (non-Javadoc) * @see org.springframework.data.mongodb.core.query.Query#isSorted() diff --git a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/TextQueryUnitTests.java b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/TextQueryUnitTests.java index 78616e792..ffdeaf39b 100644 --- a/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/TextQueryUnitTests.java +++ b/spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/TextQueryUnitTests.java @@ -17,8 +17,9 @@ package org.springframework.data.mongodb.core.query; import static org.springframework.data.mongodb.test.util.Assertions.*; -import org.junit.jupiter.api.Test; +import java.util.Map.Entry; +import org.junit.jupiter.api.Test; import org.springframework.data.domain.Sort; import org.springframework.data.domain.Sort.Direction; @@ -94,4 +95,28 @@ public class TextQueryUnitTests { assertThat(query.getSortObject()).containsKey("customFieldForScore"); } + @Test // GH-3896 + public void retainsSortOrderWhenUsingScore() { + + TextQuery query = new TextQuery(QUERY); + query.with(Sort.by(Direction.DESC, "one")); + query.sortByScore(); + query.with(Sort.by(Direction.DESC, "two")); + + assertThat(query.getSortObject().entrySet().stream().map(Entry::getKey)).containsExactly("one", "score", "two"); + + query = new TextQuery(QUERY); + query.with(Sort.by(Direction.DESC, "one")); + query.sortByScore(); + + assertThat(query.getSortObject().entrySet().stream().map(Entry::getKey)).containsExactly("one", "score"); + + query = new TextQuery(QUERY); + query.sortByScore(); + query.with(Sort.by(Direction.DESC, "one")); + query.with(Sort.by(Direction.DESC, "two")); + + assertThat(query.getSortObject().entrySet().stream().map(Entry::getKey)).containsExactly("score", "one", "two"); + } + }