Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Better errors for illegal PyO3 FieldValue values. #676

Merged
merged 1 commit into from
Sep 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 4 additions & 1 deletion pytrustfall/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
use std::{collections::BTreeMap, sync::Arc};

use pyo3::{
exceptions::PyStopIteration, prelude::*, types::PyIterator, types::PyTuple, wrap_pyfunction,
exceptions::PyStopIteration,
prelude::*,
types::{PyIterator, PyTuple},
wrap_pyfunction,
};
use trustfall_core::{
frontend::{error::FrontendError, parse},
Expand Down
95 changes: 71 additions & 24 deletions pytrustfall/src/value.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
use std::sync::Arc;
use std::{fmt::Display, sync::Arc};

use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList};
use pyo3::{exceptions::PyValueError, prelude::*, types::PyList};

use crate::errors::QueryArgumentsError;

// TODO: apply https://pyo3.rs/v0.22.3/conversions/traits#deriving-frompyobject-for-enums
#[derive(Debug, Clone)]
pub(crate) enum FieldValue {
Null,
Expand All @@ -18,6 +15,38 @@ pub(crate) enum FieldValue {
List(Vec<FieldValue>),
}

impl FieldValue {
#[inline]
pub(crate) fn is_null(&self) -> bool {
matches!(self, FieldValue::Null)
}
}

impl Display for FieldValue {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
FieldValue::Null => write!(f, "null"),
FieldValue::Int64(v) => write!(f, "{v}"),
FieldValue::Uint64(v) => write!(f, "{v}"),
FieldValue::Float64(v) => write!(f, "{v}"),
FieldValue::String(v) => write!(f, "\"{v}\""),
FieldValue::Boolean(v) => write!(f, "{v}"),
FieldValue::Enum(v) => write!(f, "{v}"),
FieldValue::List(v) => {
write!(f, "[")?;
let mut iter = v.iter();
if let Some(next) = iter.next() {
write!(f, "{next}")?;
}
for elem in iter {
write!(f, ", {elem}")?;
}
write!(f, "]")
}
}
}
}

impl IntoPy<Py<PyAny>> for FieldValue {
fn into_py(self, py: Python<'_>) -> Py<PyAny> {
match self {
Expand Down Expand Up @@ -46,37 +75,55 @@ impl<'py> pyo3::FromPyObject<'py> for FieldValue {
} else if let Ok(inner) = value.extract::<u64>() {
Ok(FieldValue::Uint64(inner))
} else if let Ok(inner) = value.extract::<f64>() {
// TODO: disallow and error on nan and infinite values
Ok(FieldValue::Float64(inner))
if inner.is_finite() {
Ok(FieldValue::Float64(inner))
} else {
Err(PyValueError::new_err(format!(
"{inner} is not a valid query argument value: \
float values may not be NaN or infinity"
)))
}
} else if let Ok(inner) = value.extract::<String>() {
Ok(FieldValue::String(inner.into()))
} else if let Ok(list) = value.downcast::<PyList>() {
let converted = list.iter().map(|element| element.extract::<FieldValue>()).try_fold(
vec![],
|mut acc, item| {
if let Ok(value) = item {
acc.push(value);
Some(acc)
} else {
None
}
},
);
let mut converted = Vec::with_capacity(list.len());
for element in list.iter() {
let value = element.extract::<FieldValue>()?;
converted.push(value);
}

// TODO: handle conversion errors properly
if let Some(inner_values) = converted {
Ok(FieldValue::List(inner_values))
} else {
Err(PyErr::new::<PyTypeError, &str>("first"))
// Ensure all non-null items in the list are of the same type.
let mut iter = converted.iter();
let first_non_null = loop {
let Some(next) = iter.next() else { break None };
if !next.is_null() {
break Some(next);
}
};
if let Some(first) = first_non_null {
let expected = std::mem::discriminant(first);
for other in iter {
if !other.is_null() {
let next_discriminant = std::mem::discriminant(other);
if expected != next_discriminant {
return Err(PyValueError::new_err(format!(
"Found elements of different (non-null) types in the same list, \
which is not allowed: {first} {other}"
)));
}
}
}
}

Ok(FieldValue::List(converted))
} else {
let repr = value.repr();
let display = repr
.as_ref()
.map_err(|_| ())
.and_then(|x| x.to_str().map_err(|_| ()))
.unwrap_or("<repr unavailable>");
Err(QueryArgumentsError::new_err(format!(
Err(PyValueError::new_err(format!(
"Value {display} of type {} is not supported by Trustfall",
value.get_type()
)))
Expand Down
4 changes: 3 additions & 1 deletion pytrustfall/trustfall/tests/test_execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -202,7 +202,9 @@ def test_unrepresentable_field_value(self) -> None:
"required": object(),
}

self.assertRaises(QueryArgumentsError, execute_query, NumbersAdapter(), SCHEMA, query, args)
self.assertRaises(
ValueError, execute_query, NumbersAdapter(), SCHEMA, query, args
)

def test_bad_query_input_type(self) -> None:
query = 123
Expand Down
Loading