Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Order shape elements before decoding #380

Merged
merged 2 commits into from
Feb 4, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 73 additions & 26 deletions gel-derive/src/shape.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
let buf = syn::Ident::new("buf", Span::mixed_site());
let nfields = syn::Ident::new("nfields", Span::mixed_site());
let elements = syn::Ident::new("elements", Span::mixed_site());
let order = syn::Ident::new("order", Span::mixed_site());
let sub_args = syn::Ident::new("sub_args", Span::mixed_site());
let (impl_generics, ty_generics, _) = s.generics.split_for_impl();
let fields = match &s.fields {
syn::Fields::Named(named) => {
Expand Down Expand Up @@ -82,51 +84,87 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
});
let field_decoders = fields
.iter()
.map(|field| {
.enumerate()
.map(|(index, field)| {
let fieldname = &field.name;

let index_lit = syn::LitInt::new(&index.to_string(), Span::mixed_site());
let sub_arg = quote! { &#sub_args.#index_lit };
let buf = quote! { fields[#order[#index]].as_deref() };

if field.attrs.json {
quote! {
let #fieldname: ::gel_protocol::model::Json =
<::gel_protocol::model::Json as
::gel_protocol::queryable::Queryable>
::decode_optional(#decoder, &(), #elements.read()?)?;
::decode_optional(#decoder, #sub_arg, #buf)?;
let #fieldname = ::serde_json::from_str(#fieldname.as_ref())
.map_err(::gel_protocol::errors::decode_error)?;
}
} else {
quote! {
let #fieldname =
::gel_protocol::queryable::Queryable
::decode_optional(#decoder, &(), #elements.read()?)?;
::decode_optional(#decoder, #sub_arg, #buf)?;
}
}
})
.collect::<TokenStream>();
let field_checks = fields
.iter()
.map(|field| {
.enumerate()
.map(|(field_index, field)| {
let name_str = &field.str_name;
let mut result = quote! {
let el = &shape.elements[idx];
if(el.name != #name_str) {
return ::std::result::Result::Err(ctx.wrong_field(#name_str, &el.name));
}
idx += 1;
let description_str = syn::LitStr::new(
&format!("field {}", field.str_name.value()),
field.str_name.span(),
);
let get_element = quote! {
let ::std::option::Option::Some((position, el)) = elements.get(#name_str) else {
return ::std::result::Result::Err(ctx.expected(#description_str));
};
order.push(*position);
};

let fieldtype = &field.ty;
if field.attrs.json {
result.extend(quote! {
let check_descriptor = if field.attrs.json {
quote! {
<::gel_protocol::model::Json as
::gel_protocol::queryable::Queryable>
::check_descriptor(ctx, el.type_pos)?;
});
::check_descriptor(ctx, el.type_pos)?
}
} else {
result.extend(quote! {
quote! {
<#fieldtype as ::gel_protocol::queryable::Queryable>
::check_descriptor(ctx, el.type_pos)?;
});
::check_descriptor(ctx, el.type_pos)?
}
};

let arg_ident = quote::format_ident!("arg_{field_index}");

quote! {
#get_element
let #arg_ident = #check_descriptor;
}
})
.collect::<TokenStream>();
let construct_sub_args = fields
.iter()
.enumerate()
.map(|(field_index, _)| {
let arg_ident = quote::format_ident!("arg_{field_index}");
quote! { #arg_ident, }
})
.collect::<TokenStream>();
let args_ty = fields
.iter()
.map(|field| {
if field.attrs.json {
quote! { (), }
} else {
let ty = &field.ty;
quote! { <#ty as ::gel_protocol::queryable::Queryable>::Args, }
}
result
})
.collect::<TokenStream>();

Expand All @@ -135,11 +173,13 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
let expanded = quote! {
impl #impl_generics ::gel_protocol::queryable::Queryable
for #name #ty_generics {
type Args = ();
type Args = (::std::vec::Vec<usize>, (#args_ty));

fn decode(#decoder: &::gel_protocol::queryable::Decoder, _args: &(), #buf: &[u8])
-> ::std::result::Result<Self, ::gel_protocol::errors::DecodeError>
{
fn decode(
#decoder: &::gel_protocol::queryable::Decoder,
(#order, #sub_args): &Self::Args,
#buf: &[u8]
) -> ::std::result::Result<Self, ::gel_protocol::errors::DecodeError> {
let #nfields = #base_fields
+ if #decoder.has_implicit_id { 1 } else { 0 }
+ if #decoder.has_implicit_tid { 1 } else { 0 }
Expand All @@ -151,6 +191,7 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
#type_id_block
#type_name_block
#id_block
let fields = #elements.read_n(#field_count)?;
#field_decoders
::std::result::Result::Ok(#name {
#(
Expand All @@ -160,8 +201,8 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
}
fn check_descriptor(
ctx: &::gel_protocol::queryable::DescriptorContext,
type_pos: ::gel_protocol::descriptors::TypePos)
-> ::std::result::Result<(), ::gel_protocol::queryable::DescriptorMismatch>
type_pos: ::gel_protocol::descriptors::TypePos
) -> ::std::result::Result<Self::Args, ::gel_protocol::queryable::DescriptorMismatch>
{
use ::gel_protocol::descriptors::Descriptor::ObjectShape;
let desc = ctx.get(type_pos)?;
Expand All @@ -174,7 +215,6 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {

// TODO(tailhook) cache shape.id somewhere
let mut idx = 0;

#type_id_check
#type_name_check
#id_check
Expand All @@ -183,8 +223,15 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result<TokenStream> {
#field_count, shape.elements.len())
);
}

let mut elements = ::std::collections::HashMap::with_capacity(shape.elements.len());
use ::std::iter::Iterator;
for (position, element) in shape.elements.iter().enumerate() {
elements.insert(element.name.as_str(), (position, element));
}
let mut order = ::std::vec::Vec::with_capacity(shape.elements.len());
#field_checks
::std::result::Result::Ok(())
::std::result::Result::Ok((order, (#construct_sub_args)))
}
}
};
Expand Down
3 changes: 2 additions & 1 deletion gel-derive/tests/json.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ fn json_field() {
\0\0\x0b\x86\0\0\0\x10\xf2\xe6F9\xd7\x04\x11\xea\
\xa0<\x83\x9f\xd9\xbd\x88\x94\0\0\0\x19\
\0\0\0\x02id\0\0\x0e\xda\0\0\0\x10\x01{\"field1\": 123}";
let res = ShapeWithJson::decode(&old_decoder(), &(), data);
let order = (vec![0_usize, 1], ((), ()));
let res = ShapeWithJson::decode(&old_decoder(), &order, data);
assert_eq!(
res.unwrap(),
ShapeWithJson {
Expand Down
6 changes: 4 additions & 2 deletions gel-derive/tests/list_scalar_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ fn decode_new() {
let data = b"\0\0\0\x03\0\0\0\x19\0\0\0\x0fcal::local_date\
\0\0\0\x19\0\0\0 std::anyscalar, std::anydiscrete\
\0\0\0\x19\0\0\0\x06normal";
let res = ScalarType::decode(&Decoder::default(), &(), data);
let order = (vec![0, 1, 2], ((), (), ()));
let res = ScalarType::decode(&Decoder::default(), &order, data);
assert_eq!(
res.unwrap(),
ScalarType {
Expand All @@ -38,7 +39,8 @@ fn decode_old() {
\xee\xfc\xb6\x12\0\0\x0b\x86\0\0\0\x10\0\0\0\0\0\0\0\0\0\0\0\0\0\0\
\x01\x0c\0\0\0\x19\0\0\0\x0fcal::local_date\
\0\0\0\x19\0\0\0\x0estd::anyscalar\0\0\0\x19\0\0\0\x06normal";
let res = ScalarType::decode(&old_decoder(), &(), data);
let order = (vec![0, 1, 2], ((), (), ()));
let res = ScalarType::decode(&old_decoder(), &order, data);
assert_eq!(
res.unwrap(),
ScalarType {
Expand Down
3 changes: 2 additions & 1 deletion gel-derive/tests/varnames.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ fn decode() {
let data = b"\0\0\0\x04\0\0\0\x14\0\0\0\x08\0\0\0\0\0\0\x03\0\0\0\
\0\x19\0\0\0\0\0\0\0\x19\0\0\0\x0bSomeDecoder\
\0\0\0\x14\0\0\0\x08\0\0\0\0\0\0\0{";
let res = WeirdStruct::decode(&Decoder::default(), &(), data);
let order = (vec![0, 1, 2, 3], ((), (), (), ()));
let res = WeirdStruct::decode(&Decoder::default(), &order, data);
assert_eq!(
res.unwrap(),
WeirdStruct {
Expand Down
10 changes: 9 additions & 1 deletion gel-protocol/src/serialization/decode/raw_composite.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,18 @@ impl<'t> DecodeTupleLike<'t> {
Ok(elements)
}

pub fn read(&mut self) -> Result<Option<&[u8]>, DecodeError> {
pub fn read(&mut self) -> Result<Option<&'t [u8]>, DecodeError> {
self.inner.read_object_element()
}

pub fn read_n(&mut self, n: usize) -> Result<Vec<Option<&'t [u8]>>, DecodeError> {
let mut bufs = Vec::with_capacity(n);
for _ in 0..n {
bufs.push(self.read()?);
}
Ok(bufs)
}

pub fn skip_element(&mut self) -> Result<(), DecodeError> {
self.read()?;
Ok(())
Expand Down
1 change: 1 addition & 0 deletions gel-tokio/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ miette = { version = "7.2.0", features = ["fancy"] }
gel-errors = { path = "../gel-errors", features = ["miette"] }
test-utils = { git = "https://github.com/edgedb/test-utils.git" }
tempfile = "3.13.0"
tokio = { version = "1.15", features = ["rt"] }

[features]
default = ["derive", "env"]
Expand Down
57 changes: 53 additions & 4 deletions gel-tokio/tests/func/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -292,10 +292,7 @@ async fn wrong_field_number() -> anyhow::Result<()> {
.query_required_single::<Thing, _>("select { a := 'hello', c := 'world' }", &())
.await
.unwrap_err();
assert_eq!(
format!("{err:#}"),
"DescriptorMismatch: unexpected field c, expected b"
);
assert_eq!(format!("{err:#}"), "DescriptorMismatch: expected field b");

Ok(())
}
Expand Down Expand Up @@ -356,3 +353,55 @@ async fn vector() -> anyhow::Result<()> {

Ok(())
}

#[tokio::test]
async fn props_in_wrong_order() -> anyhow::Result<()> {
let client = Client::new(&SERVER.config);
client.ensure_connected().await?;

#[derive(Debug, PartialEq, Queryable)]
struct Foo {
hello: String,
world: i64,
}

let res = client
.query_required_single::<Foo, _>("select { world := 42, hello := 'hello' }", &())
.await
.unwrap();

assert_eq!(
res,
Foo {
hello: "hello".into(),
world: 42
}
);

#[derive(Debug, PartialEq, Queryable)]
struct Bar {
foo: Foo,
baz: i64,
}

let res = client
.query_required_single::<Bar, _>(
"select { baz := 3, foo := { world := 42, hello := 'hello' } }",
&(),
)
.await
.unwrap();

assert_eq!(
res,
Bar {
foo: Foo {
hello: "hello".into(),
world: 42
},
baz: 3
}
);

Ok(())
}