Skip to content

Commit

Permalink
Merge pull request #106 from sunng87/fix/client-type-inference
Browse files Browse the repository at this point in the history
feat!: allow server inferenced type for portal
  • Loading branch information
sunng87 authored Aug 4, 2023
2 parents b74d134 + 3ed58e6 commit 5843411
Show file tree
Hide file tree
Showing 5 changed files with 33 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,6 @@ jobs:
- uses: actions/checkout@v3
- uses: actions-rs/toolchain@v1
with:
toolchain: "1.65"
toolchain: "1.67"
override: true
- run: cargo build --all-features
2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ homepage = "https://github.com/sunng87/pgwire"
repository = "https://github.com/sunng87/pgwire"
documentation = "https://docs.rs/crate/pgwire/"
readme = "README.md"
rust-version = "1.65"
rust-version = "1.67"

[dependencies]
log = "0.4"
Expand Down
14 changes: 7 additions & 7 deletions examples/sqlite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,31 +147,31 @@ fn get_params(portal: &Portal<String>) -> Vec<Box<dyn ToSql>> {
// we only support a small amount of types for demo
match param_type {
&Type::BOOL => {
let param = portal.parameter::<bool>(i).unwrap();
let param = portal.parameter::<bool>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::INT2 => {
let param = portal.parameter::<i16>(i).unwrap();
let param = portal.parameter::<i16>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::INT4 => {
let param = portal.parameter::<i32>(i).unwrap();
let param = portal.parameter::<i32>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::INT8 => {
let param = portal.parameter::<i64>(i).unwrap();
let param = portal.parameter::<i64>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::TEXT | &Type::VARCHAR => {
let param = portal.parameter::<String>(i).unwrap();
let param = portal.parameter::<String>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::FLOAT4 => {
let param = portal.parameter::<f32>(i).unwrap();
let param = portal.parameter::<f32>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
&Type::FLOAT8 => {
let param = portal.parameter::<f64>(i).unwrap();
let param = portal.parameter::<f64>(i, param_type).unwrap();
results.push(Box::new(param) as Box<dyn ToSql>);
}
_ => {
Expand Down
38 changes: 24 additions & 14 deletions src/api/portal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ use bytes::Bytes;
use postgres_types::FromSqlOwned;

use crate::{
api::Type,
error::{PgWireError, PgWireResult},
messages::{data::FORMAT_CODE_BINARY, extendedquery::Bind},
};
Expand Down Expand Up @@ -101,33 +102,27 @@ impl<S: Clone> Portal<S> {

/// Attempt to get parameter at given index as type `T`.
///
pub fn parameter<T>(&self, idx: usize) -> PgWireResult<Option<T>>
pub fn parameter<T>(&self, idx: usize, pg_type: &Type) -> PgWireResult<Option<T>>
where
T: FromSqlOwned,
{
if !T::accepts(pg_type) {
return Err(PgWireError::InvalidRustTypeForParameter(
pg_type.name().to_owned(),
));
}

let param = self
.parameters()
.get(idx)
.ok_or_else(|| PgWireError::ParameterIndexOutOfBound(idx))?;

let _format = self.parameter_format().format_for(idx);

let ty = self
.statement
.parameter_types()
.get(idx)
.ok_or_else(|| PgWireError::ParameterTypeIndexOutOfBound(idx))?;

if !T::accepts(ty) {
return Err(PgWireError::InvalidRustTypeForParameter(
ty.name().to_owned(),
));
}

if let Some(ref param) = param {
// TODO: from_sql only works with binary format
// here we need to check format code first and seek to support text
T::from_sql(ty, param)
T::from_sql(pg_type, param)
.map(|v| Some(v))
.map_err(PgWireError::FailedToParseParameter)
} else {
Expand All @@ -136,3 +131,18 @@ impl<S: Clone> Portal<S> {
}
}
}

#[cfg(test)]
mod tests {
use postgres_types::FromSql;

use super::*;

#[test]
fn test_from_sql() {
assert_eq!(
"helloworld",
String::from_sql(&Type::UNKNOWN, "helloworld".as_bytes()).unwrap()
)
}
}
2 changes: 0 additions & 2 deletions src/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,6 @@ pub enum PgWireError {
UnknownTypeId(Oid),
#[error("Parameter index out of bound: {0:?}")]
ParameterIndexOutOfBound(usize),
#[error("Parameter type index out of bound: {0:?}")]
ParameterTypeIndexOutOfBound(usize),
#[error("Cannot convert postgre type {0:?} to given rust type")]
InvalidRustTypeForParameter(String),
#[error("Failed to parse parameter: {0:?}")]
Expand Down

0 comments on commit 5843411

Please sign in to comment.