Skip to content

Commit

Permalink
Support named arguments (#304)
Browse files Browse the repository at this point in the history
Co-authored-by: Aljaž Mur Eržen <aljaz@edgedb.com>
  • Loading branch information
MrFoxPro and aljazerzen authored Apr 12, 2024
1 parent cf5043a commit 783438e
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 27 deletions.
40 changes: 22 additions & 18 deletions edgedb-protocol/src/codec.rs
Original file line number Diff line number Diff line change
Expand Up @@ -785,24 +785,28 @@ impl Codec for ArrayAdapter {
impl<'a> From<&'a [descriptors::ShapeElement]> for ObjectShape {
fn from(shape: &'a [descriptors::ShapeElement]) -> ObjectShape {
ObjectShape(Arc::new(ObjectShapeInfo {
elements: shape.iter().map(|e| {
let descriptors::ShapeElement {
flag_implicit,
flag_link_property,
flag_link,
cardinality,
name,
type_pos: _,
} = e;
ShapeElement {
flag_implicit: *flag_implicit,
flag_link_property: *flag_link_property,
flag_link: *flag_link,
cardinality: *cardinality,
name: name.clone(),
}
}).collect(),
}))
elements: shape.iter().map(ShapeElement::from).collect(),
}))
}
}

impl<'a> From<&'a descriptors::ShapeElement> for ShapeElement {
fn from(e: &'a descriptors::ShapeElement) -> ShapeElement {
let descriptors::ShapeElement {
flag_implicit,
flag_link_property,
flag_link,
cardinality,
name,
type_pos: _,
} = e;
ShapeElement {
flag_implicit: *flag_implicit,
flag_link_property: *flag_link_property,
flag_link: *flag_link,
cardinality: *cardinality,
name: name.clone(),
}
}
}

Expand Down
17 changes: 9 additions & 8 deletions edgedb-protocol/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -59,20 +59,21 @@ pub enum Value {

mod query_result; // sealed trait should remain non-public

pub mod encoding;
pub mod client_message;
pub mod codec;
pub mod common;
pub mod descriptors;
pub mod encoding;
pub mod error_response;
pub mod errors;
pub mod features;
pub mod queryable;
pub mod serialization;
pub mod client_message;
pub mod server_message;
pub mod errors;
pub mod error_response;
pub mod descriptors;
pub mod value;
pub mod codec;
pub mod queryable;
#[macro_use]
pub mod value_opt;
pub mod query_arg;
pub mod model;


pub use query_result::QueryResult;
2 changes: 1 addition & 1 deletion edgedb-protocol/src/query_result.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ use edgedb_errors::{ProtocolEncodingError, DescriptorMismatch};

use crate::codec::Codec;
use crate::queryable::{Queryable, Decoder, DescriptorContext};
use crate::descriptors::{TypePos};
use crate::descriptors::TypePos;
use crate::value::Value;

pub trait Sealed: Sized {}
Expand Down
125 changes: 125 additions & 0 deletions edgedb-protocol/src/value_opt.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,125 @@
use std::collections::HashMap;

use edgedb_errors::{ClientEncodingError, Error, ErrorKind};

use crate::codec::{ObjectShape, ShapeElement};
use crate::descriptors::Descriptor;
use crate::query_arg::{Encoder, QueryArgs};
use crate::value::Value;

/// An optional [Value] that can be constructed from `impl Into<Value>`,
/// `Option<impl Into<Value>>`, `Vec<impl Into<Value>>` or
/// `Option<Vec<impl Into<Value>>>`.
/// Used by [eargs!] macro.
pub struct ValueOpt(Option<Value>);

impl<V: Into<Value>> From<V> for ValueOpt {
fn from(value: V) -> Self {
ValueOpt(Some(value.into()))
}
}
impl<V: Into<Value>> From<Option<V>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Option<V>) -> Self {
ValueOpt(value.map(Value::from))
}
}
impl<V: Into<Value>> From<Vec<V>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Vec<V>) -> Self {
ValueOpt(Some(Value::Array(
value.into_iter().map(Value::from).collect(),
)))
}
}
impl<V: Into<Value>> From<Option<Vec<V>>> for ValueOpt
where
Value: From<V>,
{
fn from(value: Option<Vec<V>>) -> Self {
let mapped = value.map(|value| Value::Array(value.into_iter().map(Value::from).collect()));
ValueOpt(mapped)
}
}
impl From<ValueOpt> for Option<Value> {
fn from(value: ValueOpt) -> Self {
value.0
}
}

impl QueryArgs for HashMap<&str, ValueOpt> {
fn encode(&self, encoder: &mut Encoder) -> Result<(), Error> {
if self.len() == 0 && encoder.ctx.root_pos.is_none() {
return Ok(());
}

let root_pos = encoder.ctx.root_pos.ok_or_else(|| {
ClientEncodingError::with_message(format!(
"provided {} named arguments, but no arguments were expected by the server",
self.len()
))
})?;

let Descriptor::ObjectShape(target_shape) = encoder.ctx.get(root_pos)? else {
return Err(ClientEncodingError::with_message(
"query didn't expect named arguments",
));
};

let mut shape_elements: Vec<ShapeElement> = Vec::new();
let mut fields: Vec<Option<Value>> = Vec::new();

for param_descriptor in target_shape.elements.iter() {
let value = self.get(param_descriptor.name.as_str());

let Some(value) = value else {
return Err(ClientEncodingError::with_message(format!(
"argument for ${} missing",
param_descriptor.name
)));
};

shape_elements.push(ShapeElement::from(param_descriptor));
fields.push(value.0.clone());
}

Value::Object {
shape: ObjectShape::new(shape_elements),
fields,
}
.encode(encoder)
}
}

/// Constructs named query arguments that implement [QueryArgs] so they can be passed
/// into any query method.
/// ```no_run
/// use edgedb_protocol::value::Value;
///
/// let query = "SELECT (<str>$my_str, <int64>$my_int)";
/// let args = edgedb_protocol::named_args! {
/// "my_str" => "Hello world!".to_string(),
/// "my_int" => Value::Int64(42),
/// };
/// ```
///
/// The value side of an argument must be `impl Into<ValueOpt>`.
/// The type of the returned object is `HashMap<&str, ValueOpt>`.
#[macro_export]
macro_rules! named_args {
($($key:expr => $value:expr,)+) => { $crate::named_args!($($key => $value),+) };
($($key:expr => $value:expr),*) => {
{
const CAP: usize = <[()]>::len(&[$({ stringify!($key); }),*]);
let mut map = ::std::collections::HashMap::<&str, $crate::value_opt::ValueOpt>::with_capacity(CAP);
$(
map.insert($key, $crate::value_opt::ValueOpt::from($value));
)*
map
}
};
}
17 changes: 17 additions & 0 deletions edgedb-tokio/tests/func/client.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use edgedb_protocol::named_args;
use edgedb_protocol::value::{EnumValue, Value};
use edgedb_tokio::Client;
use edgedb_errors::NoDataError;
Expand Down Expand Up @@ -70,6 +71,22 @@ async fn simple() -> anyhow::Result<()> {
true
);

// named args
let value = client.query_required_single::<String, _>(
"select (
std::array_join(<array<str>>$msg1, ' ')
++ (<optional str>$question ?? ' the ultimate question of life')
++ ': '
++ <str><int64>$answer
);",
&named_args! {
"msg1" => vec!["the".to_string(), "answer".to_string(), "to".to_string()],
"question" => None::<String>,
"answer" => 42 as i64,
}
).await.unwrap();
assert_eq!(value.as_str(), "the answer to the ultimate question of life: 42");

Ok(())
}

Expand Down

0 comments on commit 783438e

Please sign in to comment.