From a9c468160e489dfb3546296b75dbf0854d13fb8e Mon Sep 17 00:00:00 2001 From: Predrag Gruevski Date: Thu, 19 Sep 2024 14:48:42 +0000 Subject: [PATCH] Better errors for illegal PyO3 `FieldValue` values. --- pytrustfall/src/shim.rs | 5 +- pytrustfall/src/value.rs | 95 ++++++++++++++----- pytrustfall/trustfall/tests/test_execution.py | 4 +- 3 files changed, 78 insertions(+), 26 deletions(-) diff --git a/pytrustfall/src/shim.rs b/pytrustfall/src/shim.rs index b25bd9b5..58800044 100644 --- a/pytrustfall/src/shim.rs +++ b/pytrustfall/src/shim.rs @@ -1,7 +1,10 @@ use std::{collections::BTreeMap, sync::Arc}; use pyo3::{ - exceptions::PyStopIteration, prelude::*, types::PyIterator, types::PyTuple, wrap_pyfunction, + exceptions::PyStopIteration, + prelude::*, + types::{PyIterator, PyTuple}, + wrap_pyfunction, }; use trustfall_core::{ frontend::{error::FrontendError, parse}, diff --git a/pytrustfall/src/value.rs b/pytrustfall/src/value.rs index 911a897d..1a73a516 100644 --- a/pytrustfall/src/value.rs +++ b/pytrustfall/src/value.rs @@ -1,10 +1,7 @@ -use std::sync::Arc; +use std::{fmt::Display, sync::Arc}; -use pyo3::{exceptions::PyTypeError, prelude::*, types::PyList}; +use pyo3::{exceptions::PyValueError, prelude::*, types::PyList}; -use crate::errors::QueryArgumentsError; - -// TODO: apply https://pyo3.rs/v0.22.3/conversions/traits#deriving-frompyobject-for-enums #[derive(Debug, Clone)] pub(crate) enum FieldValue { Null, @@ -18,6 +15,38 @@ pub(crate) enum FieldValue { List(Vec), } +impl FieldValue { + #[inline] + pub(crate) fn is_null(&self) -> bool { + matches!(self, FieldValue::Null) + } +} + +impl Display for FieldValue { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + FieldValue::Null => write!(f, "null"), + FieldValue::Int64(v) => write!(f, "{v}"), + FieldValue::Uint64(v) => write!(f, "{v}"), + FieldValue::Float64(v) => write!(f, "{v}"), + FieldValue::String(v) => write!(f, "\"{v}\""), + FieldValue::Boolean(v) => write!(f, "{v}"), + FieldValue::Enum(v) => write!(f, "{v}"), + FieldValue::List(v) => { + write!(f, "[")?; + let mut iter = v.iter(); + if let Some(next) = iter.next() { + write!(f, "{next}")?; + } + for elem in iter { + write!(f, ", {elem}")?; + } + write!(f, "]") + } + } + } +} + impl IntoPy> for FieldValue { fn into_py(self, py: Python<'_>) -> Py { match self { @@ -46,29 +75,47 @@ impl<'py> pyo3::FromPyObject<'py> for FieldValue { } else if let Ok(inner) = value.extract::() { Ok(FieldValue::Uint64(inner)) } else if let Ok(inner) = value.extract::() { - // TODO: disallow and error on nan and infinite values - Ok(FieldValue::Float64(inner)) + if inner.is_finite() { + Ok(FieldValue::Float64(inner)) + } else { + Err(PyValueError::new_err(format!( + "{inner} is not a valid query argument value: \ + float values may not be NaN or infinity" + ))) + } } else if let Ok(inner) = value.extract::() { Ok(FieldValue::String(inner.into())) } else if let Ok(list) = value.downcast::() { - let converted = list.iter().map(|element| element.extract::()).try_fold( - vec![], - |mut acc, item| { - if let Ok(value) = item { - acc.push(value); - Some(acc) - } else { - None - } - }, - ); + let mut converted = Vec::with_capacity(list.len()); + for element in list.iter() { + let value = element.extract::()?; + converted.push(value); + } - // TODO: handle conversion errors properly - if let Some(inner_values) = converted { - Ok(FieldValue::List(inner_values)) - } else { - Err(PyErr::new::("first")) + // Ensure all non-null items in the list are of the same type. + let mut iter = converted.iter(); + let first_non_null = loop { + let Some(next) = iter.next() else { break None }; + if !next.is_null() { + break Some(next); + } + }; + if let Some(first) = first_non_null { + let expected = std::mem::discriminant(first); + for other in iter { + if !other.is_null() { + let next_discriminant = std::mem::discriminant(other); + if expected != next_discriminant { + return Err(PyValueError::new_err(format!( + "Found elements of different (non-null) types in the same list, \ + which is not allowed: {first} {other}" + ))); + } + } + } } + + Ok(FieldValue::List(converted)) } else { let repr = value.repr(); let display = repr @@ -76,7 +123,7 @@ impl<'py> pyo3::FromPyObject<'py> for FieldValue { .map_err(|_| ()) .and_then(|x| x.to_str().map_err(|_| ())) .unwrap_or(""); - Err(QueryArgumentsError::new_err(format!( + Err(PyValueError::new_err(format!( "Value {display} of type {} is not supported by Trustfall", value.get_type() ))) diff --git a/pytrustfall/trustfall/tests/test_execution.py b/pytrustfall/trustfall/tests/test_execution.py index ebea3a1a..b0001844 100644 --- a/pytrustfall/trustfall/tests/test_execution.py +++ b/pytrustfall/trustfall/tests/test_execution.py @@ -202,7 +202,9 @@ def test_unrepresentable_field_value(self) -> None: "required": object(), } - self.assertRaises(QueryArgumentsError, execute_query, NumbersAdapter(), SCHEMA, query, args) + self.assertRaises( + ValueError, execute_query, NumbersAdapter(), SCHEMA, query, args + ) def test_bad_query_input_type(self) -> None: query = 123