Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use PyTuple for more idiomatic PyO3 code. #673

Merged
merged 1 commit into from
Sep 18, 2024
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
105 changes: 62 additions & 43 deletions pytrustfall/src/shim.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
use std::{collections::BTreeMap, sync::Arc};

use pyo3::{exceptions::PyStopIteration, prelude::*, types::PyIterator, wrap_pyfunction};
use pyo3::{
exceptions::PyStopIteration, prelude::*, types::PyIterator, types::PyTuple, wrap_pyfunction,
};
use trustfall_core::{
frontend::{error::FrontendError, parse},
interpreter::{
Expand Down Expand Up @@ -359,24 +361,27 @@ impl Iterator for PythonResolvePropertyIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, property_value) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();

// TODO: if this panics, we got an unrepresentable FieldValue,
// which should be a proper error
let value: FieldValue = output
.call_method_bound(py, "__getitem__", (1i64,), None)
.unwrap()
.extract(py)
.unwrap();

Some((context, value))
// `output` must be a (context, property_value) tuple here, or else we panic.
let tuple = output.downcast_bound(py).expect(
"resolve_property() did not yield a `(context, property_value)` tuple",
);

let tuple_size_error: &'static str =
"resolve_property() yielded a tuple that did not have exactly 2 elements";

let property_value: FieldValue =
tuple.get_borrowed_item(1).expect(tuple_size_error).extract().expect(
"resolve_property() tuple element at index 1 is not a property value",
);

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_property() tuple element at index 0 is not a context (Opaque) value");

Some((context, property_value))
}
Err(e) => {
if e.is_instance_of::<PyStopIteration>(py) {
Expand Down Expand Up @@ -407,21 +412,27 @@ impl Iterator for PythonResolveNeighborsIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, neighbor_iterator) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();
let neighbors_iterable =
output.call_method_bound(py, "__getitem__", (1i64,), None).unwrap();

// Allow returning iterables (e.g. []), not just iterators.
// Iterators return self when __iter__() is called.
// `output` must be a (context, neighbor_iterator) tuple here, or else we panic.
let tuple: &Bound<'_, PyTuple> = output.downcast_bound(py).expect(
"resolve_neighbors() did not yield a `(context, neighbor_iterator)` tuple",
);

let tuple_size_error: &'static str =
"resolve_neighbors() yielded a tuple that did not have exactly 2 elements";

let neighbors_iterable = tuple.get_borrowed_item(1).expect(tuple_size_error);

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_neighbors() tuple element at index 0 is not a context (Opaque) value");

// Support returning iterables (e.g. []), not just iterators.
// Iterators return self when `__iter__()` is called.
let neighbors_iter = make_iterator(
neighbors_iterable.bind(py),
&neighbors_iterable,
"resolve_neighbors() yielded tuple's second element",
);

Expand Down Expand Up @@ -458,19 +469,27 @@ impl Iterator for PythonResolveCoercionIterator {

fn next(&mut self) -> Option<Self::Item> {
Python::with_gil(|py| {
match self.underlying.call_method_bound(py, "__next__", (), None) {
match self.underlying.call_method0(py, "__next__") {
Ok(output) => {
// value is a (context, can_coerce) tuple here
let context: Opaque = output
.call_method_bound(py, "__getitem__", (0i64,), None)
.unwrap()
.extract(py)
.unwrap();
let can_coerce: bool = output
.call_method_bound(py, "__getitem__", (1i64,), None)
.unwrap()
.extract::<bool>(py)
.unwrap();
// `output` must be a (context, can_coerce) tuple here, or else we panic.
let tuple = output
.downcast_bound(py)
.expect("resolve_coercion() did not yield a `(context, can_coerce)` tuple");

let tuple_size_error: &'static str =
"resolve_coercion() yielded a tuple that did not have exactly 2 elements";

let can_coerce: bool = tuple
.get_borrowed_item(1)
.expect(tuple_size_error)
.extract()
.expect("resolve_coercion() tuple element at index 1 is not a bool");

let context: Opaque = tuple.get_borrowed_item(0)
.expect(tuple_size_error)
.extract()
.expect("resolve_coercion() tuple element at index 0 is not a context (Opaque) value");

Some((context, can_coerce))
}
Err(e) => {
Expand Down
Loading