Skip to content

Commit

Permalink
reuse the innerHit weight for the performance improvement (#581)
Browse files Browse the repository at this point in the history
  • Loading branch information
waziqi89 authored Jun 28, 2023
1 parent e0d8bc3 commit 574ec46
Show file tree
Hide file tree
Showing 4 changed files with 122 additions and 37 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ public Query getQuery(com.yelp.nrtsearch.server.grpc.Query query, IndexState sta
}

public Query applyQueryNestedPath(Query query, IndexState indexState, String path) {
if (path == null || path.isEmpty()) {
return query;
}
BooleanQuery.Builder builder = new BooleanQuery.Builder();
builder.add(getNestedPathQuery(indexState, path), BooleanClause.Occur.FILTER);
builder.add(query, BooleanClause.Occur.MUST);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,6 @@
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.TopFieldCollector;
import org.apache.lucene.search.TopScoreDocCollector;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.QueryBitSetProducer;

/**
* Object to store all necessary context information for {@link InnerHitFetchTask} to search and
Expand All @@ -52,17 +50,10 @@ public class InnerHitContext implements FieldFetchContext {

private static final int DEFAULT_INNER_HIT_TOP_HITS = 3;

public BitSetProducer getParentFilter() {
return parentFilter;
}

public String getInnerHitName() {
return innerHitName;
}

private final String innerHitName;
private final BitSetProducer parentFilter;
private final Query parentFilterQuery;
private final String queryNestedPath;
private final Query childFilterQuery;
private final Query query;
private final IndexState indexState;
private final ShardState shardState;
Expand All @@ -85,12 +76,11 @@ private InnerHitContext(InnerHitContextBuilder builder, boolean needValidation)
this.indexState = builder.indexState;
this.shardState = builder.shardState;
this.searcherAndTaxonomy = builder.searcherAndTaxonomy;
this.parentFilter =
new QueryBitSetProducer(
QueryNodeMapper.getInstance()
.getNestedPathQuery(indexState, builder.parentQueryNestedPath));
// rewrite the query in advance so that it won't be rewritten per hit.
this.query = searcherAndTaxonomy.searcher.rewrite(builder.query);
this.parentFilterQuery =
QueryNodeMapper.getInstance().getNestedPathQuery(indexState, builder.parentQueryNestedPath);
this.childFilterQuery =
QueryNodeMapper.getInstance().getNestedPathQuery(indexState, queryNestedPath);
this.query = builder.query;
this.startHit = builder.startHit;
// TODO: implement the totalCountCollector in case (topHits == 0 || startHit >= topHits).
// Currently, return DEFAULT_INNER_HIT_TOP_HITS results in case of 0.
Expand Down Expand Up @@ -156,13 +146,28 @@ private void validate() {
}
}

/**
 * Get the parent filter query: a nested-path query (built from the parent's queryNestedPath)
 * matching the parent documents for this innerHit.
 */
public Query getParentFilterQuery() {
  return parentFilterQuery;
}

/** Get the name of this innerHit task, as provided when the context was built. */
public String getInnerHitName() {
  return innerHitName;
}

/**
 * Get the nested path for the innerHit query. This path is the field name of the nested object
 * whose child documents the innerHit searches.
 */
public String getQueryNestedPath() {
  return queryNestedPath;
}

/**
 * Get the child filter query: a nested-path query matching all documents at {@code
 * queryNestedPath} (the child documents for this innerHit).
 */
public Query getChildFilterQuery() {
  return childFilterQuery;
}

/**
* Get the query for the innerHit. Should assume this query is directly searched against the child
documents only. Omit this field to retrieve all children for each hit.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,17 +27,28 @@
import com.yelp.nrtsearch.server.luceneserver.search.SortParser;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.concurrent.atomic.DoubleAdder;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.CollectionTerminatedException;
import org.apache.lucene.search.ConjunctionDISI;
import org.apache.lucene.search.DocIdSetIterator;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.LeafCollector;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.ScoreMode;
import org.apache.lucene.search.Scorer;
import org.apache.lucene.search.ScorerSupplier;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopDocsCollector;
import org.apache.lucene.search.Weight;
import org.apache.lucene.search.join.BitSetProducer;
import org.apache.lucene.search.join.ParentChildrenBlockJoinQuery;
import org.apache.lucene.search.join.QueryBitSetProducer;

/**
* InnerHit fetch task does a mini-scale search per hit against all child documents for this hit.
Expand All @@ -51,27 +62,56 @@ public InnerHitContext getInnerHitContext() {
}

private final InnerHitContext innerHitContext;
private final IndexSearcher searcher;
private final BitSetProducer parentFilter;
private final Weight innerHitWeight;

private final DoubleAdder getFieldsTimeMs = new DoubleAdder();
private final DoubleAdder firstPassSearchTimeMs = new DoubleAdder();

public InnerHitFetchTask(InnerHitContext innerHitContext) {
/**
 * Create the fetch task for a single innerHit definition.
 *
 * <p>The innerHit query weight and the parent filter bitset producer do not depend on any
 * individual parent hit, so both are created once here and reused for every hit, instead of
 * being rebuilt per hit.
 *
 * @param innerHitContext context holding the query, sort, topHits/startHit and searcher for this
 *     innerHit
 * @throws IOException on error rewriting the query or creating the weight
 */
public InnerHitFetchTask(InnerHitContext innerHitContext) throws IOException {
  this.innerHitContext = innerHitContext;
  this.searcher = innerHitContext.getSearcherAndTaxonomy().searcher;
  // Scores are only needed when the hit window is non-empty (topHits >= startHit) and the sort
  // is by relevance (null sort) or explicitly requires scores.
  boolean needScore =
      innerHitContext.getTopHits() >= innerHitContext.getStartHit()
          && (innerHitContext.getSort() == null || innerHitContext.getSort().needsScores());
  // We support TopDocsCollector only, so top_scores is good enough
  this.innerHitWeight =
      searcher
          .rewrite(innerHitContext.getQuery())
          .createWeight(
              searcher, needScore ? ScoreMode.TOP_SCORES : ScoreMode.COMPLETE_NO_SCORES, 1f);
  // Cache the parent bitset producer once; it is shared across all parent hits.
  this.parentFilter =
      new QueryBitSetProducer(searcher.rewrite(innerHitContext.getParentFilterQuery()));
}

/**
* Collect all inner hits for each parent hit. Normally, {@link IndexSearcher} creates a new weight
* each time we search. But for innerHit, the child query weight is reusable. Therefore, for max
* efficiency, we do not use search from the {@link IndexSearcher}. Instead, we create two
* weights separately - the non-reusable {@link ParentChildrenBlockJoinQuery}'s weight and the
* reusable innerHitQuery weight, and then intersect the {@link DocIdSetIterator}s created from
* them.
*/
public void processHit(
SearchContext searchContext, LeafReaderContext hitLeaf, SearchResponse.Hit.Builder hit)
throws IOException {
long startTime = System.nanoTime();
IndexSearcher searcher = innerHitContext.getSearcherAndTaxonomy().searcher;

// This is just a children selection query for each parent hit. And score is not needed for this
// filter query.
ParentChildrenBlockJoinQuery parentChildrenBlockJoinQuery =
new ParentChildrenBlockJoinQuery(
innerHitContext.getParentFilter(), innerHitContext.getQuery(), hit.getLuceneDocId());
parentFilter, innerHitContext.getChildFilterQuery(), hit.getLuceneDocId());
Weight filterWeight =
parentChildrenBlockJoinQuery.createWeight(searcher, ScoreMode.COMPLETE_NO_SCORES, 1f);
// All child documents are guaranteed to be stored in the same leaf as the parent document.
// Therefore, a single collector without reduce is sufficient to collect all.
TopDocsCollector topDocsCollector = innerHitContext.getTopDocsCollectorManager().newCollector();
searcher.search(parentChildrenBlockJoinQuery, topDocsCollector);

intersectWeights(filterWeight, innerHitWeight, topDocsCollector, hitLeaf);
TopDocs topDocs = topDocsCollector.topDocs();

if (innerHitContext.getStartHit() > 0) {
topDocs =
SearchHandler.getHitsFromOffset(
Expand Down Expand Up @@ -117,6 +157,44 @@ public void processHit(
getFieldsTimeMs.add(((System.nanoTime() - startTime) / NS_PER_MS));
}

/**
 * Intersect the per-hit parent/children filter weight with the shared innerHit query weight, and
 * feed every child doc matching both into the collector.
 *
 * <p>All child documents of a hit are stored in the same leaf as the parent document, so a single
 * leaf collector over {@code hitLeaf} is sufficient.
 *
 * @param filterWeight per-hit weight selecting the children of one parent hit; always created
 *     with COMPLETE_NO_SCORES, so it is never used for scoring
 * @param innerHitWeight shared, reusable innerHit query weight; supplies the scorer
 * @param topDocsCollector collector receiving the matching child docs
 * @param hitLeaf leaf segment containing the parent hit and its children
 * @throws IOException on error creating scorers or collecting
 */
private void intersectWeights(
    Weight filterWeight,
    Weight innerHitWeight,
    TopDocsCollector topDocsCollector,
    LeafReaderContext hitLeaf)
    throws IOException {
  ScorerSupplier filterScorerSupplier = filterWeight.scorerSupplier(hitLeaf);
  if (filterScorerSupplier == null) {
    // no children of this hit match the filter in this leaf
    return;
  }
  Scorer filterScorer = filterScorerSupplier.get(0);

  ScorerSupplier innerHitScorerSupplier = innerHitWeight.scorerSupplier(hitLeaf);
  if (innerHitScorerSupplier == null) {
    // innerHit query matches nothing in this leaf
    return;
  }
  Scorer innerHitScorer = innerHitScorerSupplier.get(0);

  DocIdSetIterator iterator =
      ConjunctionDISI.intersectIterators(
          Arrays.asList(filterScorer.iterator(), innerHitScorer.iterator()));

  LeafCollector leafCollector = topDocsCollector.getLeafCollector(hitLeaf);
  // Score with the innerHit scorer only; filterWeight is always COMPLETE_NO_SCORES.
  // Mirror IndexSearcher's contract: both setScorer and collect may throw
  // CollectionTerminatedException to signal that this leaf needs no further collection (e.g.
  // TopFieldCollector early termination when a sort is set), so the guard must cover the whole
  // collect loop, not just setScorer. Hits gathered before termination remain valid.
  try {
    leafCollector.setScorer(innerHitScorer);
    int docId;
    while ((docId = iterator.nextDoc()) != DocIdSetIterator.NO_MORE_DOCS) {
      leafCollector.collect(docId);
    }
  } catch (CollectionTerminatedException exception) {
    // Same as the indexSearcher, innerHit shall swallow this exception. No more docs to collect
    // in this leaf.
  }
}

public SearchResponse.Diagnostics getDiagnostic() {
Builder builder =
Diagnostics.newBuilder()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,20 +165,19 @@ public static SearchContext buildContextForRequest(

List<InnerHitFetchTask> innerHitFetchTasks = null;
if (searchRequest.getInnerHitsCount() > 0) {
innerHitFetchTasks =
searchRequest.getInnerHitsMap().keySet().stream()
.map(
innerHitName ->
buildInnerHitContext(
indexState,
shardState,
queryFields,
searcherAndTaxonomy,
rootQueryNestedPath,
innerHitName,
searchRequest.getInnerHitsOrThrow(innerHitName)))
.map(InnerHitFetchTask::new)
.collect(Collectors.toList());
innerHitFetchTasks = new ArrayList<>(searchRequest.getInnerHitsCount());
for (Entry<String, InnerHit> entry : searchRequest.getInnerHitsMap().entrySet()) {
innerHitFetchTasks.add(
new InnerHitFetchTask(
buildInnerHitContext(
indexState,
shardState,
queryFields,
searcherAndTaxonomy,
rootQueryNestedPath,
entry.getKey(),
entry.getValue())));
}
}

contextBuilder.setFetchTasks(
Expand Down Expand Up @@ -455,8 +454,8 @@ private static InnerHitContext buildInnerHitContext(
String parentQueryNestedPath,
String innerHitName,
InnerHit innerHit) {
Query childQuery =
extractQuery(indexState, "", innerHit.getInnerQuery(), innerHit.getQueryNestedPath());
// Do not apply nestedPath here. This query is used to create a shared weight.
Query childQuery = extractQuery(indexState, "", innerHit.getInnerQuery(), null);
return InnerHitContextBuilder.Builder()
.withInnerHitName(innerHitName)
.withQuery(childQuery)
Expand Down

0 comments on commit 574ec46

Please sign in to comment.