Skip to content

Commit

Permalink
Allow FieldRef in fold post-filtering operations. (#618)
Browse files Browse the repository at this point in the history
This lays the groundwork for `@fold @Transform(op: "count")` followed by more `@transform` operations before a subsequent `@filter`.
  • Loading branch information
obi1kenobi committed Jun 13, 2024
1 parent d8a7fdc commit 46b6ac2
Show file tree
Hide file tree
Showing 70 changed files with 463 additions and 109 deletions.
2 changes: 1 addition & 1 deletion trustfall_core/src/frontend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1141,7 +1141,7 @@ where
component_path,
tags,
starting_vid,
fold_specific_field.kind,
field_ref.clone(),
filter_directive,
) {
Ok(filter) => post_filters.push(filter),
Expand Down
95 changes: 68 additions & 27 deletions trustfall_core/src/interpreter/execution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -270,26 +270,42 @@ fn usize_from_field_value(field_value: &FieldValue) -> Option<usize> {
}
}

/// If this IRFold has a filter on the folded element count, and that filter imposes
/// If this [`IRFold`] has a filter on the folded element count, and that filter imposes
/// a max size that can be statically determined, return that max size so it can
/// be used for further optimizations. Otherwise, return None.
/// be used for further optimizations. Otherwise, return `None`.
fn get_max_fold_count_limit(carrier: &mut QueryCarrier, fold: &IRFold) -> Option<usize> {
let mut result: Option<usize> = None;

let query_arguments = &carrier.query.as_ref().expect("query was not returned").arguments;
for post_fold_filter in fold.post_filters.iter() {
let left = post_fold_filter.left();
if !left
.refers_to_fold_specific_field()
.is_some_and(|f| f.kind == FoldSpecificFieldKind::Count)
{
// This filter is not using the count of the fold.
continue;
}

if !matches!(left, FieldRef::FoldSpecificField(f) if f.kind == FoldSpecificFieldKind::Count)
{
// The filter expression is doing something more complex than we can currently analyze.
// Conservatively return `None` to disable optimizations here.
//
// TODO: Once `@transform` may be applied to property-like values, update the analysis
// here to be able to optimize in such cases as well.
return None;
}

let next_limit = match post_fold_filter {
Operation::Equals(FoldSpecificFieldKind::Count, Argument::Variable(var_ref))
| Operation::LessThanOrEqual(
FoldSpecificFieldKind::Count,
Argument::Variable(var_ref),
) => {
Operation::Equals(_, Argument::Variable(var_ref))
| Operation::LessThanOrEqual(_, Argument::Variable(var_ref)) => {
let variable_value =
usize_from_field_value(&query_arguments[&var_ref.variable_name])
.expect("for field value to be coercible to usize");
Some(variable_value)
}
Operation::LessThan(FoldSpecificFieldKind::Count, Argument::Variable(var_ref)) => {
Operation::LessThan(_, Argument::Variable(var_ref)) => {
let variable_value =
usize_from_field_value(&query_arguments[&var_ref.variable_name])
.expect("for field value to be coercible to usize");
Expand All @@ -300,7 +316,7 @@ fn get_max_fold_count_limit(carrier: &mut QueryCarrier, fold: &IRFold) -> Option
// The later full application of filters ensures correctness.
Some(variable_value.saturating_sub(1))
}
Operation::OneOf(FoldSpecificFieldKind::Count, Argument::Variable(var_ref)) => {
Operation::OneOf(_, Argument::Variable(var_ref)) => {
match &query_arguments[&var_ref.variable_name] {
FieldValue::List(v) => v
.iter()
Expand All @@ -325,25 +341,41 @@ fn get_max_fold_count_limit(carrier: &mut QueryCarrier, fold: &IRFold) -> Option
result
}

/// If this IRFold has a filter on the folded element count, and that filter imposes
/// If this [`IRFold`] has a filter on the folded element count, and that filter imposes
/// a min size that can be statically determined, return that min size so it can
/// be used for further optimizations. Otherwise, return None.
/// be used for further optimizations. Otherwise, return `None`.
fn get_min_fold_count_limit(carrier: &mut QueryCarrier, fold: &IRFold) -> Option<usize> {
let mut result: Option<usize> = None;

let query_arguments = &carrier.query.as_ref().expect("query was not returned").arguments;
for post_fold_filter in fold.post_filters.iter() {
let left = post_fold_filter.left();
if !left
.refers_to_fold_specific_field()
.is_some_and(|f| f.kind == FoldSpecificFieldKind::Count)
{
// This filter is not using the count of the fold.
continue;
}

if !matches!(left, FieldRef::FoldSpecificField(f) if f.kind == FoldSpecificFieldKind::Count)
{
// The filter expression is doing something more complex than we can currently analyze.
// Conservatively return `None` to disable optimizations here.
//
// TODO: Once `@transform` may be applied to property-like values, update the analysis
// here to be able to optimize in such cases as well.
return None;
}

let next_limit = match post_fold_filter {
Operation::GreaterThanOrEqual(
FoldSpecificFieldKind::Count,
Argument::Variable(var_ref),
) => {
Operation::GreaterThanOrEqual(_, Argument::Variable(var_ref)) => {
let variable_value =
usize_from_field_value(&query_arguments[&var_ref.variable_name])
.expect("for field value to be coercible to usize");
Some(variable_value)
}
Operation::GreaterThan(FoldSpecificFieldKind::Count, Argument::Variable(var_ref)) => {
Operation::GreaterThan(_, Argument::Variable(var_ref)) => {
let variable_value =
usize_from_field_value(&query_arguments[&var_ref.variable_name])
.expect("for field value to be coercible to usize");
Expand Down Expand Up @@ -587,15 +619,24 @@ fn compute_fold<'query, AdapterT: Adapter<'query> + 'query>(
let mut post_filtered_iterator: ContextIterator<'query, AdapterT::Vertex> =
Box::new(folded_iterator);
for post_fold_filter in fold.post_filters.iter() {
post_filtered_iterator = apply_fold_specific_filter(
adapter.as_ref(),
carrier,
parent_component,
fold.as_ref(),
expanding_from.vid,
post_fold_filter,
post_filtered_iterator,
);
let left = post_fold_filter.left();
match left {
FieldRef::ContextField(_) => {
unreachable!("unexpectedly found a fold post-filtering step that references a ContextField: {fold:#?}");
}
FieldRef::FoldSpecificField(fold_specific_field) => {
let remapped_operation = post_fold_filter.map(|_| fold_specific_field.kind, |x| x);
post_filtered_iterator = apply_fold_specific_filter(
adapter.as_ref(),
carrier,
parent_component,
fold.as_ref(),
expanding_from.vid,
&remapped_operation,
post_filtered_iterator,
);
}
}
}

// Compute the outputs from this fold.
Expand Down Expand Up @@ -772,7 +813,7 @@ fn apply_fold_specific_filter<'query, AdapterT: Adapter<'query>>(
component: &IRQueryComponent,
fold: &IRFold,
current_vid: Vid,
filter: &Operation<FoldSpecificFieldKind, Argument>,
filter: &Operation<FoldSpecificFieldKind, &Argument>,
iterator: ContextIterator<'query, AdapterT::Vertex>,
) -> ContextIterator<'query, AdapterT::Vertex> {
let fold_specific_field = filter.left();
Expand All @@ -792,7 +833,7 @@ fn apply_fold_specific_filter<'query, AdapterT: Adapter<'query>>(
carrier,
component,
current_vid,
&filter.map(|_| (), |r| r),
&filter.map(|_| (), |r| *r),
field_iterator,
)
}
Expand Down
9 changes: 6 additions & 3 deletions trustfall_core/src/interpreter/hints/filters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::{borrow::Cow, collections::BTreeMap, fmt::Debug, ops::Bound, sync::Arc}

use itertools::Itertools;

use crate::ir::{Argument, FieldValue, FoldSpecificFieldKind, IRFold, Operation};
use crate::ir::{Argument, FieldRef, FieldValue, FoldSpecificFieldKind, IRFold, Operation};

use super::{candidates::NullableValue, CandidateValue, Range};

Expand Down Expand Up @@ -114,8 +114,11 @@ pub(super) fn fold_requires_at_least_one_element(
query_variables: &BTreeMap<Arc<str>, FieldValue>,
fold: &IRFold,
) -> bool {
let relevant_filters =
fold.post_filters.iter().filter(|op| matches!(op.left(), FoldSpecificFieldKind::Count));
// TODO: When we support applying `@transform` to property-like values, we can update this logic
// to be smarter and less conservative.
let relevant_filters = fold.post_filters.iter().filter(|op| {
matches!(op.left(), FieldRef::FoldSpecificField(f) if f.kind == FoldSpecificFieldKind::Count)
});
let is_subject_field_nullable = false; // the "count" value can't be null
candidate_from_statically_evaluated_filters(
relevant_filters,
Expand Down
10 changes: 8 additions & 2 deletions trustfall_core/src/ir/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,18 @@ pub struct IRFold {
/// Outputs from this fold that are derived from fold-specific fields.
///
/// All [`FieldRef`] values in the map are guaranteed to have
/// `[FieldRef].refers_to_fold_specific_field().is_some() == true`.
/// `FieldRef.refers_to_fold_specific_field().is_some() == true`.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub fold_specific_outputs: BTreeMap<Arc<str>, FieldRef>,

/// Filters that are applied on the fold as a whole.
///
/// For example, as in `@fold @transform(op: "count") @filter(op: "=", value: ["$zero"])`.
///
/// All [`FieldRef`] values inside each [`Operation`] within the `Vec` are guaranteed to have
/// `FieldRef.refers_to_fold_specific_field().is_some() == true`.
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub post_filters: Vec<Operation<FoldSpecificFieldKind, Argument>>,
pub post_filters: Vec<Operation<FieldRef, Argument>>,
}

#[non_exhaustive]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -463,7 +463,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "one",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "one",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
LessThan(Count, Variable(VariableRef(
LessThan(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
LessThan(Count, Variable(VariableRef(
LessThan(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
LessThanOrEqual(Count, Variable(VariableRef(
LessThanOrEqual(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "one",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
LessThanOrEqual(Count, Variable(VariableRef(
LessThanOrEqual(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "one",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
OneOf(Count, Variable(VariableRef(
OneOf(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "counts",
variable_type: "[Int!]!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -421,7 +421,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
OneOf(Count, Variable(VariableRef(
OneOf(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "counts",
variable_type: "[Int!]!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "neg_two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
Equals(Count, Variable(VariableRef(
Equals(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "neg_two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,11 @@ Ok(TestIRQuery(
},
),
post_filters: [
GreaterThan(Count, Variable(VariableRef(
GreaterThan(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "neg_two",
variable_type: "Int!",
))),
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -581,7 +581,11 @@ TestInterpreterOutputTrace(
},
),
post_filters: [
GreaterThan(Count, Variable(VariableRef(
GreaterThan(FoldSpecificField(FoldSpecificField(
fold_eid: Eid(1),
fold_root_vid: Vid(2),
kind: Count,
)), Variable(VariableRef(
variable_name: "neg_two",
variable_type: "Int!",
))),
Expand Down
Loading

0 comments on commit 46b6ac2

Please sign in to comment.