From 387c1ff856c184506710a6d9e9102eedb8b6d4f1 Mon Sep 17 00:00:00 2001 From: Predrag Gruevski <2348618+obi1kenobi@users.noreply.github.com> Date: Sat, 15 Jun 2024 13:43:05 -0400 Subject: [PATCH] Implement execution for `@transform` applied to filters' left-hand operand. (#626) --- trustfall_core/src/interpreter/execution.rs | 163 ++++++++++++---- trustfall_core/src/interpreter/filtering.rs | 71 ++----- trustfall_core/src/interpreter/mod.rs | 2 + trustfall_core/src/interpreter/tags.rs | 72 +++++++ .../src/interpreter/transformation.rs | 180 ++++++++++++++++++ trustfall_core/src/ir/mod.rs | 2 +- 6 files changed, 394 insertions(+), 96 deletions(-) create mode 100644 trustfall_core/src/interpreter/tags.rs create mode 100644 trustfall_core/src/interpreter/transformation.rs diff --git a/trustfall_core/src/interpreter/execution.rs b/trustfall_core/src/interpreter/execution.rs index 483dd63f..8835e807 100644 --- a/trustfall_core/src/interpreter/execution.rs +++ b/trustfall_core/src/interpreter/execution.rs @@ -8,15 +8,17 @@ use crate::{ ir::{ Argument, ContextField, EdgeParameters, Eid, FieldRef, FieldValue, FoldSpecificFieldKind, IREdge, IRFold, IRQueryComponent, IRVertex, IndexedQuery, LocalField, Operation, - OperationSubject, Recursive, Vid, + OperationSubject, Recursive, TransformBase, Vid, }, util::BTreeMapTryInsertExt, }; use super::{ - error::QueryArgumentsError, filtering::apply_filter, Adapter, AsVertex, ContextIterator, - ContextOutcomeIterator, DataContext, InterpretedQuery, ResolveEdgeInfo, ResolveInfo, - TaggedValue, ValueOrVec, VertexIterator, + error::QueryArgumentsError, + filtering::apply_filter, + transformation::{apply_transforms, push_transform_argument_tag_values_onto_stack}, + Adapter, AsVertex, ContextIterator, ContextOutcomeIterator, DataContext, InterpretedQuery, + ResolveEdgeInfo, ResolveInfo, TaggedValue, ValueOrVec, VertexIterator, }; #[derive(Debug, Clone)] @@ -621,25 +623,15 @@ fn compute_fold<'query, AdapterT: Adapter<'query> + 'query>( let mut post_filtered_iterator: ContextIterator<'query, AdapterT::Vertex> = Box::new(folded_iterator); for post_fold_filter in fold.post_filters.iter() { - let left = post_fold_filter.left(); - match left { - OperationSubject::FoldSpecificField(fold_specific_field) => { - let remapped_operation = post_fold_filter.map(|_| fold_specific_field.kind, |x| x); - post_filtered_iterator = apply_fold_specific_filter( - adapter.as_ref(), - carrier, - parent_component, - fold.as_ref(), - expanding_from.vid, - &remapped_operation, - post_filtered_iterator, - ); - } - OperationSubject::TransformedField(_) => todo!(), - OperationSubject::LocalField(_) => { - unreachable!("unexpectedly found a fold post-filtering step that references a LocalField: {fold:#?}"); - } - } + post_filtered_iterator = apply_fold_specific_filter( + adapter.as_ref(), + carrier, + parent_component, + fold.as_ref(), + expanding_from.vid, + post_fold_filter, + post_filtered_iterator, + ); } // Compute the outputs from this fold. @@ -809,7 +801,63 @@ fn apply_filter_with_non_folded_field_subject<'query, AdapterT: Adapter<'query>> filter.map_left(|_| field), iterator, ), - OperationSubject::TransformedField(_) => todo!(), + OperationSubject::TransformedField(transformed) => { + let prepped_iterator = push_transform_argument_tag_values_onto_stack( + adapter, + carrier, + component, + current_vid, + &transformed.value.transforms, + iterator, + ); + + let query_variables = + Arc::clone(&carrier.query.as_ref().expect("query was not returned").arguments); + let transform_data = Arc::clone(&transformed.value); + + match &transformed.value.base { + TransformBase::ContextField(field) => { + assert_eq!(current_vid, field.vertex_id, "filter left-hand side was a transformed field from a different vertex: {current_vid:?} {filter:?}"); + let local_field = LocalField { + field_name: field.field_name.clone(), + field_type: field.field_type.clone(), + }; + + let filter_input_iterator = Box::new( + compute_local_field_with_separate_value( + adapter, + carrier, + component, + current_vid, + &local_field, + prepped_iterator, + ) + .map(move |(mut ctx, mut value)| { + value = apply_transforms( + &transform_data, + &query_variables, + &mut ctx.values, + value, + ); + ctx.values.push(value); + ctx + }), + ); + + apply_filter( + adapter, + carrier, + component, + current_vid, + &filter.map(|_| (), |r| r), + filter_input_iterator, + ) + } + TransformBase::FoldSpecificField(..) => unreachable!( + "illegal filter over fold-specific field passed to this function: {filter:?}" + ), + } + } OperationSubject::FoldSpecificField(..) => unreachable!( "illegal filter over fold-specific field passed to this function: {filter:?}" ), @@ -844,27 +892,70 @@ fn apply_fold_specific_filter<'query, AdapterT: Adapter<'query>>( component: &IRQueryComponent, fold: &IRFold, current_vid: Vid, - filter: &Operation, + filter: &Operation, iterator: ContextIterator<'query, AdapterT::Vertex>, ) -> ContextIterator<'query, AdapterT::Vertex> { - let fold_specific_field = filter.left(); - let field_iterator = Box::new(compute_fold_specific_field_with_separate_value(fold.eid, fold_specific_field, iterator).map(|(mut ctx, tagged_value)| { - let value = match tagged_value { - TaggedValue::Some(value) => value, - TaggedValue::NonexistentOptional => { - unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}") + let left = filter.left(); + let (fold_specific_field, transform_data) = match left { + OperationSubject::FoldSpecificField(field) => (field, None), + OperationSubject::TransformedField(transformed) => match &transformed.value.base { + TransformBase::FoldSpecificField(field) => (field, Some(&transformed.value)), + TransformBase::ContextField(_) => { + unreachable!("post-fold filter does not refer to a fold-specific field: {left:?}") } - }; - ctx.values.push(value); - ctx - })); + }, + OperationSubject::LocalField(_) => { + unreachable!("post-fold filter does not refer to a fold-specific field: {left:?}") + } + }; + + let field_iterator: ContextIterator<'query, AdapterT::Vertex> = if let Some(transform_data) = + transform_data + { + let prepped_iterator = push_transform_argument_tag_values_onto_stack( + adapter, + carrier, + component, + current_vid, + &transform_data.transforms, + iterator, + ); + + let query_variables = + Arc::clone(&carrier.query.as_ref().expect("query was not returned").arguments); + let transform_data = Arc::clone(transform_data); + Box::new(compute_fold_specific_field_with_separate_value(fold.eid, &fold_specific_field.kind, prepped_iterator).map(move |(mut ctx, tagged_value)| { + let mut value = match tagged_value { + TaggedValue::Some(value) => value, + TaggedValue::NonexistentOptional => { + unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}") + } + }; + + value = apply_transforms(&transform_data, &query_variables, &mut ctx.values, value); + + ctx.values.push(value); + ctx + })) + } else { + Box::new(compute_fold_specific_field_with_separate_value(fold.eid, &fold_specific_field.kind, iterator).map(|(mut ctx, tagged_value)| { + let value = match tagged_value { + TaggedValue::Some(value) => value, + TaggedValue::NonexistentOptional => { + unreachable!("while applying fold-specific filter, the @fold turned out to not exist: {ctx:?}") + } + }; + ctx.values.push(value); + ctx + })) + }; apply_filter( adapter, carrier, component, current_vid, - &filter.map(|_| (), |r| *r), + &filter.map(|_| (), |r| r), field_iterator, ) } diff --git a/trustfall_core/src/interpreter/filtering.rs b/trustfall_core/src/interpreter/filtering.rs index 11b1e058..3c38996a 100644 --- a/trustfall_core/src/interpreter/filtering.rs +++ b/trustfall_core/src/interpreter/filtering.rs @@ -2,14 +2,11 @@ use std::{fmt::Debug, mem}; use regex::Regex; -use crate::ir::{Argument, FieldRef, FieldValue, IRQueryComponent, LocalField, Operation, Vid}; +use crate::ir::{Argument, FieldValue, IRQueryComponent, Operation, Vid}; use super::{ - execution::{ - compute_context_field_with_separate_value, compute_fold_specific_field_with_separate_value, - compute_local_field_with_separate_value, QueryCarrier, - }, - Adapter, ContextIterator, ContextOutcomeIterator, TaggedValue, + execution::QueryCarrier, tags::compute_tag_with_separate_value, Adapter, ContextIterator, + ContextOutcomeIterator, TaggedValue, }; #[inline(always)] @@ -281,61 +278,17 @@ pub(super) fn apply_filter<'query, AdapterT: Adapter<'query>>( let right_value = query_arguments[var.variable_name.as_ref()].to_owned(); apply_filter_with_static_argument_value(filter, right_value, iterator) } - Some(Argument::Tag(FieldRef::ContextField(context_field))) => { - // TODO: Benchmark if it would be faster to duplicate the filtering code to special-case - // the situation when the tag is always known to exist, so we don't have to unwrap - // a TaggedValue enum, because we know it would be TaggedValue::Some. - let argument_value_iterator = if context_field.vertex_id == current_vid { - // This tag is from the vertex we're currently filtering. That means the field - // whose value we want to get is actually local, so there's no need to compute it - // using the more expensive approach we use for non-local fields. - let local_equivalent_field = LocalField { - field_name: context_field.field_name.clone(), - field_type: context_field.field_type.clone(), - }; - Box::new( - compute_local_field_with_separate_value( - adapter, - carrier, - component, - current_vid, - &local_equivalent_field, - iterator, - ) - .map(|(ctx, value)| (ctx, TaggedValue::Some(value))), - ) - } else { - compute_context_field_with_separate_value( - adapter, - carrier, - component, - context_field, - iterator, - ) - }; - apply_filter_with_tagged_argument_value(filter, argument_value_iterator) - } - Some(Argument::Tag(field_ref @ FieldRef::FoldSpecificField(fold_field))) => { - let argument_value_iterator = if component.folds.contains_key(&fold_field.fold_eid) { - compute_fold_specific_field_with_separate_value( - fold_field.fold_eid, - &fold_field.kind, - iterator, - ) - } else { - // This value represents an imported tag value from an outer component. - // Grab its value from the context itself. - let cloned_ref = field_ref.clone(); - Box::new(iterator.map(move |ctx| { - let right_value = ctx.imported_tags[&cloned_ref].clone(); - (ctx, right_value) - })) - }; + Some(Argument::Tag(field_ref)) => { + let argument_value_iterator = compute_tag_with_separate_value( + adapter, + carrier, + component, + current_vid, + field_ref, + iterator, + ); apply_filter_with_tagged_argument_value(filter, argument_value_iterator) } - Some(Argument::Tag(FieldRef::TransformedField(_))) => { - todo!() - } None => unreachable!( "no argument present for filter, but not handled in unary filters fn: {filter:?}" ), diff --git a/trustfall_core/src/interpreter/mod.rs b/trustfall_core/src/interpreter/mod.rs index 0679af3c..7b4c94f3 100644 --- a/trustfall_core/src/interpreter/mod.rs +++ b/trustfall_core/src/interpreter/mod.rs @@ -17,7 +17,9 @@ mod filtering; pub mod helpers; mod hints; pub mod replay; +mod tags; pub mod trace; +mod transformation; pub use hints::{ CandidateValue, DynamicallyResolvedValue, EdgeInfo, NeighborInfo, QueryInfo, Range, diff --git a/trustfall_core/src/interpreter/tags.rs b/trustfall_core/src/interpreter/tags.rs new file mode 100644 index 00000000..dec68995 --- /dev/null +++ b/trustfall_core/src/interpreter/tags.rs @@ -0,0 +1,72 @@ +use crate::ir::{FieldRef, IRQueryComponent, LocalField, Vid}; + +use super::{ + execution::{ + compute_context_field_with_separate_value, compute_fold_specific_field_with_separate_value, + compute_local_field_with_separate_value, QueryCarrier, + }, + Adapter, ContextIterator, DataContext, TaggedValue, +}; + +pub(super) fn compute_tag_with_separate_value<'query, AdapterT: Adapter<'query>>( + adapter: &AdapterT, + carrier: &mut QueryCarrier, + component: &IRQueryComponent, + current_vid: Vid, + field_ref: &FieldRef, + iterator: ContextIterator<'query, AdapterT::Vertex>, +) -> Box, TaggedValue)> + 'query> { + match field_ref { + FieldRef::ContextField(context_field) => { + // TODO: Benchmark if it would be faster to duplicate the code to special-case + // the situation when the tag is always known to exist, so we don't have to unwrap + // a TaggedValue enum, because we know it would be TaggedValue::Some. + if context_field.vertex_id == current_vid { + // This tag is from the vertex we're currently evaluating. That means the field + // whose value we want to get is actually local, so there's no need to compute it + // using the more expensive approach we use for non-local fields. + let local_equivalent_field = LocalField { + field_name: context_field.field_name.clone(), + field_type: context_field.field_type.clone(), + }; + Box::new( + compute_local_field_with_separate_value( + adapter, + carrier, + component, + current_vid, + &local_equivalent_field, + iterator, + ) + .map(|(ctx, value)| (ctx, TaggedValue::Some(value))), + ) + } else { + compute_context_field_with_separate_value( + adapter, + carrier, + component, + context_field, + iterator, + ) + } + } + FieldRef::FoldSpecificField(fold_field) => { + if component.folds.contains_key(&fold_field.fold_eid) { + compute_fold_specific_field_with_separate_value( + fold_field.fold_eid, + &fold_field.kind, + iterator, + ) + } else { + // This value represents an imported tag value from an outer component. + // Grab its value from the context itself. + let cloned_ref = field_ref.clone(); + Box::new(iterator.map(move |ctx| { + let right_value = ctx.imported_tags[&cloned_ref].clone(); + (ctx, right_value) + })) + } + } + FieldRef::TransformedField(_) => todo!(), + } +} diff --git a/trustfall_core/src/interpreter/transformation.rs b/trustfall_core/src/interpreter/transformation.rs new file mode 100644 index 00000000..e11de1eb --- /dev/null +++ b/trustfall_core/src/interpreter/transformation.rs @@ -0,0 +1,180 @@ +use std::{collections::BTreeMap, sync::Arc}; + +use crate::ir::{Argument, FieldValue, IRQueryComponent, Transform, TransformedValue, Vid}; + +use super::{ + execution::QueryCarrier, tags::compute_tag_with_separate_value, Adapter, ContextIterator, + TaggedValue, +}; + +pub(super) fn push_transform_argument_tag_values_onto_stack<'query, AdapterT: Adapter<'query>>( + adapter: &AdapterT, + carrier: &mut QueryCarrier, + component: &IRQueryComponent, + current_vid: Vid, + transforms: &[Transform], + mut iterator: ContextIterator<'query, AdapterT::Vertex>, +) -> ContextIterator<'query, AdapterT::Vertex> { + // Ensure any non-immediate operands (like values coming from tags) are pushed + // onto the each context's stack before we evaluate the transform. + // We push them on the stack in reverse order, since the stack is LIFO. + for transform in transforms.iter().rev() { + match transform { + Transform::Add(op) => match op { + Argument::Tag(tag) => { + iterator = Box::new( + compute_tag_with_separate_value( + adapter, + carrier, + component, + current_vid, + tag, + iterator, + ) + .map(|(mut ctx, tag_value)| { + let value = match tag_value { + TaggedValue::NonexistentOptional => FieldValue::Null, + TaggedValue::Some(value) => value, + }; + ctx.values.push(value); + ctx + }), + ); + } + Argument::Variable(..) => {} + }, + Transform::Len | Transform::Abs => { + // No tag arguments here! + } + } + } + + iterator +} + +pub(super) fn apply_transforms( + transformed_value: &TransformedValue, + variables: &BTreeMap, FieldValue>, + stack: &mut Vec, + mut value: FieldValue, +) -> FieldValue { + for transform in &transformed_value.transforms { + value = apply_one_transform(transform, variables, stack, &value); + } + + value +} + +#[inline] +fn apply_one_transform( + transform: &Transform, + variables: &BTreeMap, FieldValue>, + stack: &mut Vec, + value: &FieldValue, +) -> FieldValue { + match transform { + Transform::Len => apply_len_transform(value), + Transform::Abs => apply_abs_transform(value), + Transform::Add(argument) => match argument { + Argument::Variable(var) => { + let operand = &variables[&var.variable_name]; + apply_add_transform(value, operand) + } + Argument::Tag(_) => { + let operand = stack.pop().expect( + "empty stack while attempting to resolve transform operand: {transform:?}", + ); + apply_add_transform(value, &operand) + } + }, + } +} + +#[inline] +fn apply_len_transform(value: &FieldValue) -> FieldValue { + match value { + FieldValue::Null => FieldValue::Null, + FieldValue::List(l) => FieldValue::Int64(l.len() as i64), + _ => unreachable!("{value:?}"), + } +} + +#[inline] +fn apply_abs_transform(value: &FieldValue) -> FieldValue { + match value { + FieldValue::Null => FieldValue::Null, + FieldValue::Int64(x) => FieldValue::Uint64(x.unsigned_abs()), + FieldValue::Uint64(x) => FieldValue::Uint64(*x), + FieldValue::Float64(x) => FieldValue::Float64(x.abs()), + _ => unreachable!("{value:?}"), + } +} + +#[inline] +fn apply_add_transform(value: &FieldValue, operand: &FieldValue) -> FieldValue { + match (value, operand) { + (FieldValue::Null, _) => FieldValue::Null, + (_, FieldValue::Null) => FieldValue::Null, + (FieldValue::Int64(x), FieldValue::Int64(y)) => FieldValue::Int64(x.saturating_add(*y)), + (FieldValue::Uint64(x), FieldValue::Uint64(y)) => FieldValue::Uint64(x.saturating_add(*y)), + (FieldValue::Int64(signed), FieldValue::Uint64(unsigned)) + | (FieldValue::Uint64(unsigned), FieldValue::Int64(signed)) => { + add_unlike_signedness_integers(*signed, *unsigned) + } + (FieldValue::Float64(x), FieldValue::Float64(y)) => FieldValue::Float64(x + y), + (FieldValue::Float64(x), FieldValue::Int64(y)) + | (FieldValue::Int64(y), FieldValue::Float64(x)) => FieldValue::Float64(x + (*y as f64)), + (FieldValue::Float64(x), FieldValue::Uint64(y)) + | (FieldValue::Uint64(y), FieldValue::Float64(x)) => FieldValue::Float64(x + (*y as f64)), + _ => unreachable!("{value:?} {operand:?}"), + } +} + +#[inline] +fn add_unlike_signedness_integers(signed: i64, unsigned: u64) -> FieldValue { + if (unsigned > i64::MAX as u64) || !signed.is_negative() { + return FieldValue::Uint64(unsigned.saturating_add_signed(signed)); + } + + FieldValue::Int64(signed.saturating_add_unsigned(unsigned)) +} + +#[cfg(test)] +mod tests { + use crate::ir::FieldValue; + + use super::add_unlike_signedness_integers; + + #[test] + fn test_add_unlike_signedness_integers() { + let test_data = [ + // Adding two non-negative numbers results in a u64. + (123i64, 456u64, FieldValue::Uint64(579)), + (i64::MAX, 0, FieldValue::Uint64(i64::MAX as u64)), + (i64::MAX, 1, FieldValue::Uint64(i64::MAX as u64 + 1)), + // Adding a negative and positive number far from the numeric bounds results in i64. + (-123, 122, FieldValue::Int64(-1)), + (-123, 123, FieldValue::Int64(0)), + (-123, 124, FieldValue::Int64(1)), + // Adding a small negative number to a u64 above the i64 numeric bound results in u64. + (-1, u64::MAX, FieldValue::Uint64(u64::MAX - 1)), + // Addition right up to the numeric bounds. + (i64::MAX, u64::MAX - (i64::MAX as u64), FieldValue::Uint64(u64::MAX)), + (i64::MIN, 0, FieldValue::Int64(i64::MIN)), + // Saturation at the numeric bounds instead of overflow or underflow. + (i64::MAX, u64::MAX, FieldValue::Uint64(u64::MAX)), + ]; + + for (signed, unsigned, expected) in test_data { + let actual = add_unlike_signedness_integers(signed, unsigned); + assert_eq!( + expected, actual, + "{signed} + {unsigned} => {actual:?} but expected {expected:?}" + ); + assert!( + expected.structural_eq(&actual), + "values compare equal but are structurally different: {expected:?} {actual:?}" + ); + } + } +} diff --git a/trustfall_core/src/ir/mod.rs b/trustfall_core/src/ir/mod.rs index fc3debd5..04344409 100644 --- a/trustfall_core/src/ir/mod.rs +++ b/trustfall_core/src/ir/mod.rs @@ -811,7 +811,7 @@ pub enum TransformBase { pub enum Transform { Len, Abs, - Add(FieldRef), + Add(Argument), } impl Transform {