From aec97601075052d9e3603a4e96d90e746db2c3a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Alja=C5=BE=20Mur=20Er=C5=BEen?= Date: Tue, 9 Apr 2024 11:52:31 +0200 Subject: [PATCH] Fix lints suggested by clippy (#309) --- .gitignore | 7 + .vscode/settings.json | 3 + edgedb-derive/src/json.rs | 4 +- edgedb-derive/src/lib.rs | 8 +- edgedb-derive/src/shape.rs | 6 +- edgedb-derive/tests/json.rs | 2 +- edgedb-derive/tests/list_scalar_types.rs | 2 +- edgedb-errors/src/bin/edgedb_gen_errors.rs | 14 +- edgedb-errors/src/error.rs | 13 +- edgedb-errors/src/traits.rs | 3 +- edgedb-protocol/src/client_message.rs | 4 +- edgedb-protocol/src/codec.rs | 28 ++-- edgedb-protocol/src/common.rs | 4 +- edgedb-protocol/src/descriptors.rs | 2 +- edgedb-protocol/src/encoding.rs | 8 +- edgedb-protocol/src/error_response.rs | 10 +- edgedb-protocol/src/errors.rs | 1 - edgedb-protocol/src/model/bignum.rs | 48 +++--- edgedb-protocol/src/model/json.rs | 6 +- edgedb-protocol/src/model/time.rs | 148 +++++++++--------- edgedb-protocol/src/query_arg.rs | 4 +- edgedb-protocol/src/query_result.rs | 2 +- edgedb-protocol/src/queryable.rs | 11 +- .../src/serialization/decode/raw_composite.rs | 4 +- .../src/serialization/decode/raw_scalar.rs | 49 +++--- .../src/serialization/test_scalars.rs | 2 +- edgedb-protocol/src/server_message.rs | 22 +-- edgedb-protocol/src/value.rs | 4 +- edgedb-protocol/tests/client_messages.rs | 8 +- edgedb-protocol/tests/codecs.rs | 8 +- edgedb-protocol/tests/datetime_chrono.rs | 10 +- edgedb-protocol/tests/decode.rs | 2 - edgedb-tokio/examples/transaction_errors.rs | 4 +- edgedb-tokio/src/builder.rs | 114 +++++++------- edgedb-tokio/src/client.rs | 12 +- edgedb-tokio/src/credentials.rs | 3 +- edgedb-tokio/src/errors.rs | 3 +- edgedb-tokio/src/lib.rs | 2 +- edgedb-tokio/src/options.rs | 9 +- edgedb-tokio/src/raw/connection.rs | 61 ++++---- edgedb-tokio/src/raw/dumps.rs | 45 +++--- edgedb-tokio/src/raw/mod.rs | 24 ++- edgedb-tokio/src/raw/queries.rs | 114 +++++++------- edgedb-tokio/src/raw/response.rs | 4 +- edgedb-tokio/src/raw/state.rs | 16 +- edgedb-tokio/src/server_params.rs | 3 - edgedb-tokio/src/tls.rs | 2 +- edgedb-tokio/src/transaction.rs | 24 +-- edgedb-tokio/tests/func/globals.rs | 27 ++-- edgedb-tokio/tests/func/main.rs | 2 +- edgedb-tokio/tests/func/raw.rs | 21 ++- edgedb-tokio/tests/func/server.rs | 75 ++++----- 52 files changed, 482 insertions(+), 530 deletions(-) create mode 100644 .vscode/settings.json diff --git a/.gitignore b/.gitignore index 9b9c5aa2..6b4fb3fc 100644 --- a/.gitignore +++ b/.gitignore @@ -5,3 +5,10 @@ __pycache__ /Cargo.lock /.idea +/.vscode + +# nix stuff +/flake.nix +/flake.lock +/.envrc +/.direnv diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 00000000..bfd35bb2 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,3 @@ +{ + "rust-analyzer.cargo.features": ["unstable", "chrono"] +} \ No newline at end of file diff --git a/edgedb-derive/src/json.rs b/edgedb-derive/src/json.rs index 0e2b0476..36a0d19a 100644 --- a/edgedb-derive/src/json.rs +++ b/edgedb-derive/src/json.rs @@ -26,8 +26,8 @@ pub fn derive(item: &syn::Item) -> syn::Result { { let json: ::edgedb_protocol::model::Json = ::edgedb_protocol::queryable::Queryable::decode(decoder, buf)?; - Ok(::serde_json::from_str(json.as_ref()) - .map_err(::edgedb_protocol::errors::decode_error)?) + ::serde_json::from_str(json.as_ref()) + .map_err(::edgedb_protocol::errors::decode_error) } fn check_descriptor( ctx: &::edgedb_protocol::queryable::DescriptorContext, diff --git a/edgedb-derive/src/lib.rs b/edgedb-derive/src/lib.rs index 87989896..618a34a1 100644 --- a/edgedb-derive/src/lib.rs +++ b/edgedb-derive/src/lib.rs @@ -78,7 +78,7 @@ let query_res: Vec = client.query(query, &()).await?; extern crate proc_macro; use proc_macro::TokenStream; -use syn::{self, parse_macro_input}; +use syn::parse_macro_input; mod attrib; mod enums; @@ -105,7 +105,7 @@ fn derive(item: &syn::Item) -> syn::Result { )); } }; - let attrs = attrib::ContainerAttrs::from_syn(&attrs)?; + let attrs = attrib::ContainerAttrs::from_syn(attrs)?; if attrs.json { json::derive(item) } else { @@ -113,10 +113,10 @@ fn derive(item: &syn::Item) -> syn::Result { syn::Item::Struct(s) => shape::derive_struct(s), syn::Item::Enum(s) => enums::derive_enum(s), _ => { - return Err(syn::Error::new_spanned(item, + Err(syn::Error::new_spanned(item, "can only derive Queryable for a struct and enum \ in non-JSON mode" - )); + )) } } } diff --git a/edgedb-derive/src/shape.rs b/edgedb-derive/src/shape.rs index f2d658e9..79106147 100644 --- a/edgedb-derive/src/shape.rs +++ b/edgedb-derive/src/shape.rs @@ -80,7 +80,7 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result { } }); let field_decoders = fields.iter().map(|field| { - let ref fieldname = field.name; + let fieldname = &field.name; if field.attrs.json { quote!{ let #fieldname: ::edgedb_protocol::model::Json = @@ -99,7 +99,7 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result { } }).collect::(); let field_checks = fields.iter().map(|field| { - let ref name_str = field.str_name; + let name_str = &field.str_name; let mut result = quote!{ let el = &shape.elements[idx]; if(el.name != #name_str) { @@ -107,7 +107,7 @@ pub fn derive_struct(s: &syn::ItemStruct) -> syn::Result { } idx += 1; }; - let ref fieldtype = field.ty; + let fieldtype = &field.ty; if field.attrs.json { result.extend(quote!{ <::edgedb_protocol::model::Json as diff --git a/edgedb-derive/tests/json.rs b/edgedb-derive/tests/json.rs index de8ca579..72a07ac0 100644 --- a/edgedb-derive/tests/json.rs +++ b/edgedb-derive/tests/json.rs @@ -25,7 +25,7 @@ fn old_decoder() -> Decoder { let mut dec = Decoder::default(); dec.has_implicit_id = true; dec.has_implicit_tid = true; - return dec; + dec } #[test] diff --git a/edgedb-derive/tests/list_scalar_types.rs b/edgedb-derive/tests/list_scalar_types.rs index f5690285..49a924d8 100644 --- a/edgedb-derive/tests/list_scalar_types.rs +++ b/edgedb-derive/tests/list_scalar_types.rs @@ -12,7 +12,7 @@ fn old_decoder() -> Decoder { let mut dec = Decoder::default(); dec.has_implicit_id = true; dec.has_implicit_tid = true; - return dec; + dec } #[test] diff --git a/edgedb-errors/src/bin/edgedb_gen_errors.rs b/edgedb-errors/src/bin/edgedb_gen_errors.rs index 4f5f234d..b8496207 100644 --- a/edgedb-errors/src/bin/edgedb_gen_errors.rs +++ b/edgedb-errors/src/bin/edgedb_gen_errors.rs @@ -6,8 +6,8 @@ use std::env::args; fn find_tag<'x>(template: &'x str, tag: &str) -> (usize, usize, &'x str) { let tag_line = format!("// <{}>\n", tag); let pos = template.find(&tag_line) - .expect(&format!("missing tag <{}>", tag)); - let indent = template[..pos].rfind("\n").unwrap_or(0) + 1; + .unwrap_or_else(|| panic!("missing tag <{}>", tag)); + let indent = template[..pos].rfind('\n').unwrap_or(0) + 1; (pos, pos + tag_line.len(), &template[indent..pos]) } @@ -15,9 +15,9 @@ fn find_macro<'x>(template: &'x str, name: &str) -> &'x str { let macro_line = format!("macro_rules! {} {{", name); let pos = template.find(¯o_line) .map(|pos| pos + macro_line.len()) - .expect(&format!("missing macro {}", name)); + .unwrap_or_else(|| panic!("missing macro {}", name)); let body = template[pos..] - .find("{").map(|x| pos + x + 1) + .find('{').map(|x| pos + x + 1) .and_then(|open| { let mut level = 0; for (idx, c) in template[open..].char_indices() { @@ -32,11 +32,11 @@ fn find_macro<'x>(template: &'x str, name: &str) -> &'x str { }) .map(|(begin, end)| template[begin..end].trim()) .expect("invalid macro"); - return body; + body } fn main() -> Result<(), Box> { - let filename = args().skip(1).next().expect("single argument"); + let filename = args().nth(1).expect("single argument"); let mut all_errors = Vec::new(); let mut all_tags = BTreeSet::<&str>::new(); let data = fs::read_to_string(filename)?; @@ -49,7 +49,7 @@ fn main() -> Result<(), Box> { let code = u32::from_str_radix( &parts.next().expect("code always specified") .strip_prefix("0x").expect("code contains 0x") - .replace("_", ""), + .replace('_', ""), 16 ).expect("code is valid hex"); let name = parts.next().expect("name always specified"); diff --git a/edgedb-errors/src/error.rs b/edgedb-errors/src/error.rs index b863f18b..edd6ac85 100644 --- a/edgedb-errors/src/error.rs +++ b/edgedb-errors/src/error.rs @@ -6,7 +6,7 @@ use std::fmt; use std::str; use crate::kinds::{tag_check, error_name}; -use crate::kinds::{UserError}; +use crate::kinds::UserError; use crate::traits::{ErrorKind, Field}; @@ -53,9 +53,6 @@ pub(crate) struct Inner { pub fields: HashMap<(&'static str, TypeId), Box>, } -trait Assert: Send + Sync + 'static {} -impl Assert for Error {} - impl Error { pub fn is(&self) -> bool { T::is_superclass_of(self.0.code) @@ -182,12 +179,10 @@ impl fmt::Display for Error { } } + } else if let Some(last) = self.0.messages.last() { + write!(f, "{}: {}", kind, last)?; } else { - if let Some(last) = self.0.messages.last() { - write!(f, "{}: {}", kind, last)?; - } else { - write!(f, "{}", kind)?; - } + write!(f, "{}", kind)?; } if let Some((line, col)) = self.line().zip(self.column()) { write!(f, " (on line {}, column {})", line, col)?; diff --git a/edgedb-errors/src/traits.rs b/edgedb-errors/src/traits.rs index 8668c999..d622f37c 100644 --- a/edgedb-errors/src/traits.rs +++ b/edgedb-errors/src/traits.rs @@ -85,8 +85,7 @@ pub trait Sealed { const TAGS: u32; // TODO(tailhook) use uuids of errors instead fn is_superclass_of(code: u32) -> bool { - let mask = 0xFFFFFFFF_u32 - << (Self::CODE.trailing_zeros() / 8)*8; + let mask = 0xFFFFFFFF_u32 << ((Self::CODE.trailing_zeros() / 8)*8); code & mask == Self::CODE } fn has_tag(bit: u32) -> bool { diff --git a/edgedb-protocol/src/client_message.rs b/edgedb-protocol/src/client_message.rs index ab9d8e17..04645895 100644 --- a/edgedb-protocol/src/client_message.rs +++ b/edgedb-protocol/src/client_message.rs @@ -738,7 +738,7 @@ impl Decode for Restore { let jobs = buf.get_u16(); let data = buf.copy_to_bytes(buf.remaining()); - return Ok(Restore { jobs, headers, data }) + Ok(Restore { jobs, headers, data }) } } @@ -754,7 +754,7 @@ impl Encode for RestoreBlock { impl Decode for RestoreBlock { fn decode(buf: &mut Input) -> Result { let data = buf.copy_to_bytes(buf.remaining()); - return Ok(RestoreBlock { data }) + Ok(RestoreBlock { data }) } } diff --git a/edgedb-protocol/src/codec.rs b/edgedb-protocol/src/codec.rs index d47396a6..473fd305 100644 --- a/edgedb-protocol/src/codec.rs +++ b/edgedb-protocol/src/codec.rs @@ -247,14 +247,14 @@ impl ObjectShape { impl Deref for ObjectShape { type Target = ObjectShapeInfo; fn deref(&self) -> &ObjectShapeInfo { - &*self.0 + &self.0 } } impl Deref for NamedTupleShape { type Target = NamedTupleShapeInfo; fn deref(&self) -> &NamedTupleShapeInfo { - &*self.0 + &self.0 } } @@ -293,7 +293,7 @@ impl<'a> CodecBuilder<'a> { D::TypeAnnotation(..) => unreachable!(), } } else { - return errors::UnexpectedTypePos { position: pos.0 }.fail()?; + errors::UnexpectedTypePos { position: pos.0 }.fail()? } } } @@ -332,7 +332,7 @@ pub fn scalar_codec(uuid: &UuidVal) -> Result, CodecError> { STD_BIGINT => Ok(Arc::new(BigInt {})), CFG_MEMORY => Ok(Arc::new(ConfigMemory {})), PGVECTOR_VECTOR => Ok(Arc::new(Vector {})), - _ => return errors::UndefinedBaseScalar { uuid: uuid.clone() }.fail()?, + _ => errors::UndefinedBaseScalar { uuid: uuid.clone() }.fail()?, } } @@ -607,7 +607,7 @@ impl Tuple { fn build(d: &descriptors::TupleTypeDescriptor, dec: &CodecBuilder) -> Result { - return Ok(Tuple { + Ok(Tuple { elements: d.element_types.iter() .map(|&t| dec.build(t)) .collect::>()?, @@ -628,14 +628,14 @@ impl NamedTuple { } } -fn decode_tuple<'t>(mut elements:DecodeTupleLike, codecs:&Vec>) -> Result, DecodeError>{ +fn decode_tuple(mut elements:DecodeTupleLike, codecs: &[Arc]) -> Result, DecodeError>{ codecs .iter() .map(|codec| codec.decode(elements.read()?.ok_or_else(|| errors::MissingRequiredElement.build())?)) .collect::, DecodeError>>() } -fn decode_array_like<'t>(elements: DecodeArrayLike<'t>, codec:&dyn Codec) -> Result, DecodeError>{ +fn decode_array_like(elements: DecodeArrayLike<'_>, codec:&dyn Codec) -> Result, DecodeError>{ elements .map(|element| codec.decode(element?)) .collect::, DecodeError>>() @@ -763,7 +763,7 @@ impl Codec for ArrayAdapter { let len = buf.get_i32() as usize; ensure!(buf.remaining() >= len, errors::Underflow); ensure!(buf.remaining() <= len, errors::ExtraData); - return self.0.decode(buf); + self.0.decode(buf) } fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> @@ -831,7 +831,7 @@ impl From<&str> for EnumValue { impl std::ops::Deref for EnumValue { type Target = str; fn deref(&self) -> &str { - &*self.0 + &self.0 } } @@ -1068,7 +1068,7 @@ pub(crate) fn encode_local_time(buf: &mut BytesMut, val: &model::LocalTime) impl Codec for Json { fn decode(&self, buf: &[u8]) -> Result { - RawCodec::decode(buf).map(|json: model::Json| Value::Json(json.into())) + RawCodec::decode(buf).map(|json: model::Json| Value::Json(json)) } fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> @@ -1099,7 +1099,7 @@ impl Codec for Tuple { fn decode(&self, buf: &[u8]) -> Result { let elements = DecodeTupleLike::new_object(buf, self.elements.len())?; let items = decode_tuple(elements, &self.elements)?; - return Ok(Value::Tuple(items)) + Ok(Value::Tuple(items)) } fn encode(&self, buf: &mut BytesMut, val: &Value) -> Result<(), EncodeError> @@ -1132,7 +1132,7 @@ impl Codec for NamedTuple { fn decode(&self, buf: &[u8]) -> Result { let elements = DecodeTupleLike::new_tuple(buf, self.codecs.len())?; let fields = decode_tuple(elements, &self.codecs)?; - return Ok(Value::NamedTuple { + Ok(Value::NamedTuple { shape: self.shape.clone(), fields, }) @@ -1296,7 +1296,7 @@ impl Codec for Range { let pos = buf.len(); buf.reserve(4); buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, &lower)?; + self.element.encode(buf, lower)?; let len = buf.len()-pos-4; buf[pos..pos+4].copy_from_slice( &u32::try_from(len) @@ -1308,7 +1308,7 @@ impl Codec for Range { let pos = buf.len(); buf.reserve(4); buf.put_u32(0); // replaced after serializing a value - self.element.encode(buf, &upper)?; + self.element.encode(buf, upper)?; let len = buf.len()-pos-4; buf[pos..pos+4].copy_from_slice( &u32::try_from(len) diff --git a/edgedb-protocol/src/common.rs b/edgedb-protocol/src/common.rs index 0b0b6c05..2f2215e6 100644 --- a/edgedb-protocol/src/common.rs +++ b/edgedb-protocol/src/common.rs @@ -77,7 +77,7 @@ impl RawTypedesc { } } pub fn decode(&self) -> Result { - let ref mut cur = Input::new( + let cur = &mut Input::new( self.proto.clone(), self.data.clone(), ); @@ -134,6 +134,6 @@ impl CompilationOptions { cflags |= CompilationFlags::INJECT_OUTPUT_TYPE_IDS; } // TODO(tailhook) object ids - return cflags; + cflags } } diff --git a/edgedb-protocol/src/descriptors.rs b/edgedb-protocol/src/descriptors.rs index bd9581a7..9fc1783a 100644 --- a/edgedb-protocol/src/descriptors.rs +++ b/edgedb-protocol/src/descriptors.rs @@ -436,7 +436,7 @@ fn serialize_variables(enc: &mut Encoder, variables: &BTreeMap, let mut serialized = 0; for (idx, el) in desc.elements.iter().enumerate() { if let Some(value) = variables.get(&el.name) { - value.check_descriptor(&enc.ctx, el.type_pos)?; + value.check_descriptor(enc.ctx, el.type_pos)?; serialized += 1; enc.buf.reserve(8); enc.buf.put_u32(idx as u32); diff --git a/edgedb-protocol/src/encoding.rs b/edgedb-protocol/src/encoding.rs index 6fe6409f..7f983613 100644 --- a/edgedb-protocol/src/encoding.rs +++ b/edgedb-protocol/src/encoding.rs @@ -99,7 +99,7 @@ impl Output<'_> { } } pub fn proto(&self) -> &ProtocolVersion { - &self.proto + self.proto } pub fn reserve(&mut self, size: usize) { self.bytes.reserve(size) @@ -170,10 +170,10 @@ impl Decode for String { ensure!(buf.remaining() >= len, errors::Underflow); let mut data = vec![0u8; len]; buf.copy_to_slice(&mut data[..]); - let result = String::from_utf8(data) + + String::from_utf8(data) .map_err(|e| e.utf8_error()) - .context(errors::InvalidUtf8); - return result; + .context(errors::InvalidUtf8) } } diff --git a/edgedb-protocol/src/error_response.rs b/edgedb-protocol/src/error_response.rs index 314bc480..d4359336 100644 --- a/edgedb-protocol/src/error_response.rs +++ b/edgedb-protocol/src/error_response.rs @@ -3,10 +3,10 @@ use edgedb_errors::Error; use crate::server_message::ErrorResponse; -impl Into for ErrorResponse { - fn into(self) -> Error { - Error::from_code(self.code) - .context(self.message) - .with_headers(self.attributes) +impl From for Error { + fn from(val: ErrorResponse) -> Self { + Error::from_code(val.code) + .context(val.message) + .with_headers(val.attributes) } } diff --git a/edgedb-protocol/src/errors.rs b/edgedb-protocol/src/errors.rs index 87386222..831d5601 100644 --- a/edgedb-protocol/src/errors.rs +++ b/edgedb-protocol/src/errors.rs @@ -2,7 +2,6 @@ use std::str; use std::error::Error; use snafu::{Snafu, Backtrace, IntoError}; -use uuid; use crate::value::Value; diff --git a/edgedb-protocol/src/model/bignum.rs b/edgedb-protocol/src/model/bignum.rs index 47cc536e..5b30409f 100644 --- a/edgedb-protocol/src/model/bignum.rs +++ b/edgedb-protocol/src/model/bignum.rs @@ -32,7 +32,7 @@ impl BigInt { self.digits.remove(0); self.weight -= 1; } - return self + self } fn trailing_zero_groups(&self) -> i16 { @@ -64,17 +64,17 @@ impl std::fmt::Display for BigInt { impl From for BigInt { fn from(v: u64) -> BigInt { - return BigInt { + BigInt { negative: false, weight: 4, digits: vec![ - (v / 10000_0000_0000_0000 % 10000) as u16, - (v / 10000_0000_0000 % 10000) as u16, - (v / 10000_0000 % 10000) as u16, + (v / 10_000_000_000_000_000 % 10000) as u16, + (v / 1_000_000_000_000 % 10000) as u16, + (v / 100_000_000 % 10000) as u16, (v / 10000 % 10000) as u16, (v % 10000) as u16, ], - }.normalize(); + }.normalize() } } @@ -85,31 +85,31 @@ impl From for BigInt { } else { (v as u64, false) }; - return BigInt { + BigInt { negative, weight: 4, digits: vec![ - (abs / 10000_0000_0000_0000 % 10000) as u16, - (abs / 10000_0000_0000 % 10000) as u16, - (abs / 10000_0000 % 10000) as u16, + (abs / 10_000_000_000_000_000 % 10000) as u16, + (abs / 1_000_000_000_000 % 10000) as u16, + (abs / 100_000_000 % 10000) as u16, (abs / 10000 % 10000) as u16, (abs % 10000) as u16, ], - }.normalize(); + }.normalize() } } impl From for BigInt { fn from(v: u32) -> BigInt { - return BigInt { + BigInt { negative: false, weight: 2, digits: vec![ - (v / 10000_0000) as u16, + (v / 100_000_000) as u16, (v / 10000 % 10000) as u16, (v % 10000) as u16, ], - }.normalize(); + }.normalize() } } @@ -120,15 +120,15 @@ impl From for BigInt { } else { (v as u32, false) }; - return BigInt { + BigInt { negative, weight: 2, digits: vec![ - (abs / 10000_0000) as u16, + (abs / 100_000_000) as u16, (abs / 10000 % 10000) as u16, (abs % 10000) as u16, ], - }.normalize(); + }.normalize() } } @@ -142,25 +142,25 @@ impl Decimal { self.digits.remove(0); self.weight -= 1; } - return self + self } } #[cfg(test)] #[allow(dead_code)] // used by optional tests -pub(self) mod test_helpers{ + mod test_helpers{ use rand::Rng; pub fn gen_u64(rng: &mut T) -> u64 { // change distribution to generate different length more frequently let max = 10_u64.pow(rng.gen_range(0..20)); - return rng.gen_range(0..max); + rng.gen_range(0..max) } pub fn gen_i64(rng: &mut T) -> i64 { // change distribution to generate different length more frequently let max = 10_i64.pow(rng.gen_range(0..19)); - return rng.gen_range(-max..max); + rng.gen_range(-max..max) } } @@ -209,7 +209,7 @@ mod test { assert_eq!(BigInt::from(u64::MAX).weight, 4); assert_eq!( BigInt::from(u64::MAX).digits, - &[1844, 6744, 0737, 0955, 1615] + &[1844, 6744, 737, 955, 1615] ); assert_eq!(BigInt::from(125i64).weight, 0); @@ -221,7 +221,7 @@ mod test { assert_eq!(BigInt::from(i64::MAX).weight, 4); assert_eq!( BigInt::from(i64::MAX).digits, - &[922, 3372, 0368, 5477, 5807] + &[922, 3372, 368, 5477, 5807] ); assert_eq!(BigInt::from(-125i64).weight, 0); @@ -233,7 +233,7 @@ mod test { assert_eq!(BigInt::from(i64::MIN).weight, 4); assert_eq!( BigInt::from(i64::MIN).digits, - &[922, 3372, 0368, 5477, 5808] + &[922, 3372, 368, 5477, 5808] ); } diff --git a/edgedb-protocol/src/model/json.rs b/edgedb-protocol/src/model/json.rs index cd534255..5e624522 100644 --- a/edgedb-protocol/src/model/json.rs +++ b/edgedb-protocol/src/model/json.rs @@ -37,8 +37,8 @@ impl std::ops::Deref for Json { } } -impl Into for Json { - fn into(self) -> String { - self.0 +impl From for String { + fn from(val: Json) -> Self { + val.0 } } diff --git a/edgedb-protocol/src/model/time.rs b/edgedb-protocol/src/model/time.rs index 05be23a3..6de1bd7f 100644 --- a/edgedb-protocol/src/model/time.rs +++ b/edgedb-protocol/src/model/time.rs @@ -112,15 +112,15 @@ impl Duration { // Note: `std::time::Duration` can't be negative pub fn abs_duration(&self) -> std::time::Duration { if self.micros.is_negative() { - return std::time::Duration::from_micros( - u64::MAX - self.micros as u64 + 1); + std::time::Duration::from_micros( + u64::MAX - self.micros as u64 + 1) } else { - return std::time::Duration::from_micros(self.micros as u64); + std::time::Duration::from_micros(self.micros as u64) } } fn try_from_pg_simple_format(input: &str) -> Result { - let mut split = input.trim_end().splitn(3, ":"); + let mut split = input.trim_end().splitn(3, ':'); let mut value: i64 = 0; let negative; let mut pos: usize = 0; @@ -128,7 +128,7 @@ impl Duration { { let hour_str = split .next() - .filter(|s| s.len() > 0) + .filter(|s| !s.is_empty()) .ok_or_else(|| ParseDurationError::new( "EOF met, expecting `+`, `-` or int") .not_final() @@ -137,13 +137,13 @@ impl Duration { pos += hour_str.len() - 1; let hour_str = hour_str.trim_start(); let hour = hour_str - .strip_prefix("-") + .strip_prefix('-') .unwrap_or(hour_str) .parse::() .map_err(|e| ParseDurationError::from(e).not_final().pos(pos) )?; - negative = hour_str.starts_with("-"); + negative = hour_str.starts_with('-'); value += (hour.abs() as i64) * MICROS_PER_HOUR; } @@ -156,7 +156,7 @@ impl Duration { .not_final() .pos(pos) )?; - if minute_str.len() > 0 { + if !minute_str.is_empty() { pos += minute_str.len(); let minute = minute_str .parse::() @@ -177,7 +177,7 @@ impl Duration { if let Some(remaining) = split.last() { pos += 1; - let mut sec_split = remaining.splitn(2, "."); + let mut sec_split = remaining.splitn(2, '.'); { let second_str = sec_split.next().unwrap(); @@ -233,7 +233,7 @@ impl Duration { let mut current = parts.next(); if let Some(part) = current { - if let Some(hour_str) = part.strip_suffix("H") { + if let Some(hour_str) = part.strip_suffix('H') { let hour = hour_str .parse::() .map_err(|e| ParseDurationError::from(e) @@ -246,7 +246,7 @@ impl Duration { } if let Some(part) = current { - if let Some(minute_str) = part.strip_suffix("M") { + if let Some(minute_str) = part.strip_suffix('M') { let minute = minute_str .parse::() .map_err(|e| ParseDurationError::from(e) @@ -259,11 +259,11 @@ impl Duration { } if let Some(part) = current { - if let Some(second_str) = part.strip_suffix("S") { + if let Some(second_str) = part.strip_suffix('S') { let (second_str, subsec_str) = second_str .split_once('.') .map(|(sec, sub)| - (sec, sub.get(..6).or_else(||Some(sub)))) + (sec, sub.get(..6).or(Some(sub)))) .unwrap_or_else(|| (second_str, None)); let second = second_str @@ -433,7 +433,7 @@ impl FromStr for Duration { if let Ok(seconds) = input.trim().parse::() { seconds .checked_mul(MICROS_PER_SECOND) - .map(|micros| Self::from_micros(micros)) + .map(Self::from_micros) .ok_or_else(|| Self::Err::new("seconds value out of range") .pos(input.len() - 1)) } else { @@ -465,7 +465,7 @@ impl LocalDatetime { pub(crate) fn from_postgres_micros(micros: i64) -> Result { - if micros < Self::MIN.micros || micros > Self::MAX.micros { + if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { return Err(OutOfRangeError); } Ok(LocalDatetime { micros }) @@ -476,8 +476,7 @@ impl LocalDatetime { note="use Datetime::try_from_unix_micros(v).into() instead", )] pub fn from_micros(micros: i64) -> LocalDatetime { - Self::from_postgres_micros(micros).expect(&format!( - "LocalDatetime::from_micros({}) is outside the valid datetime range", + Self::from_postgres_micros(micros).unwrap_or_else(|_| panic!("LocalDatetime::from_micros({}) is outside the valid datetime range", micros)) } @@ -510,7 +509,7 @@ impl LocalDatetime { impl From for LocalDatetime { fn from(d: Datetime) -> LocalDatetime { - return LocalDatetime { micros: d.micros } + LocalDatetime { micros: d.micros } } } @@ -533,14 +532,14 @@ impl LocalTime { pub(crate) fn try_from_micros(micros: u64) -> Result { if micros < MICROS_PER_DAY { - Ok(LocalTime { micros: micros }) + Ok(LocalTime { micros }) } else { Err(OutOfRangeError) } } pub fn from_micros(micros: u64) -> LocalTime { - Self::try_from_micros(micros).ok().expect("LocalTime is out of range") + Self::try_from_micros(micros).expect("LocalTime is out of range") } pub fn to_micros(self) -> u64 { @@ -568,14 +567,14 @@ impl LocalTime { #[cfg(test)] // currently only used by tests, will be used by parsing later fn from_hmsu(hour: u8, minute: u8, second:u8, microsecond: u32) -> LocalTime { - assert!(microsecond < 1000_000); + assert!(microsecond < 1_000_000); assert!(second < 60); assert!(minute < 60); assert!(hour < 24); let micros = microsecond as u64 - + 1000_000 * (second as u64 + + 1_000_000 * (second as u64 + 60 * (minute as u64 + 60 * (hour as u64))); LocalTime::from_micros(micros) @@ -610,7 +609,7 @@ impl LocalDate { pub const UNIX_EPOCH : LocalDate = LocalDate { days: -(30 * 365 + 7) }; // 1970-01-01 fn try_from_days(days: i32) -> Result { - if days < Self::MIN.days || days > Self::MAX.days { + if !(Self::MIN.days..=Self::MAX.days).contains(&days) { return Err(OutOfRangeError); } Ok(LocalDate { days }) @@ -618,7 +617,7 @@ impl LocalDate { pub fn from_days(days: i32) -> LocalDate { Self::try_from_days(days) - .expect(&format!("LocalDate::from_days({}) is outside the valid date range", days)) + .unwrap_or_else(|_| panic!("LocalDate::from_days({}) is outside the valid date range", days)) } pub fn to_days(self) -> i32 { @@ -626,19 +625,18 @@ impl LocalDate { } pub fn from_ymd(year:i32, month: u8, day:u8) -> LocalDate { - Self::try_from_ymd(year, month, day).expect(&format!( - "invalid date {:04}-{:02}-{:02}", + Self::try_from_ymd(year, month, day).unwrap_or_else(|_| panic!("invalid date {:04}-{:02}-{:02}", year, month, day)) } fn try_from_ymd(year:i32, month: u8, day:u8) -> Result { - if day < 1 || day > 31 { + if !(1..=31).contains(&day) { return Err(OutOfRangeError); } - if month < 1 || month > 12 { + if !(1..=12).contains(&month) { return Err(OutOfRangeError); } - if year < MIN_YEAR || year > MAX_YEAR { + if !(MIN_YEAR..=MAX_YEAR).contains(&year) { return Err(OutOfRangeError); } @@ -753,7 +751,7 @@ impl Datetime { pub(crate) fn from_postgres_micros(micros: i64) -> Result { - if micros < Self::MIN.micros || micros > Self::MAX.micros { + if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { return Err(OutOfRangeError); } Ok(Datetime { micros }) @@ -761,7 +759,7 @@ impl Datetime { fn _from_micros(micros: i64) -> Option { let micros = micros.checked_add(Self::UNIX_EPOCH.micros)?; - if micros < Self::MIN.micros || micros > Self::MAX.micros { + if !(Self::MIN.micros..=Self::MAX.micros).contains(µs) { return None; } Some(Datetime { micros }) @@ -772,8 +770,7 @@ impl Datetime { note="use from_unix_micros instead", )] pub fn from_micros(micros: i64) -> Datetime { - Self::from_postgres_micros(micros).expect(&format!( - "Datetime::from_micros({}) is outside the valid datetime range", + Self::from_postgres_micros(micros).unwrap_or_else(|_| panic!("Datetime::from_micros({}) is outside the valid datetime range", micros)) } @@ -930,17 +927,18 @@ impl std::ops::Add<&'_ std::time::Duration> for Datetime { return Datetime::MAX; }; if let Some(micros) = self.micros.checked_add(duration.micros) { - return Datetime { micros }; + Datetime { micros } } else { debug_assert!(false, "duration is out of range"); - return Datetime::MAX; + Datetime::MAX } } } impl std::ops::Add for Datetime { type Output = Datetime; + #[allow(clippy::op_ref)] fn add(self, other: std::time::Duration) -> Datetime { self + &other } @@ -1008,7 +1006,7 @@ mod test { assert_eq!(0, LocalDate::from_ymd(2000, 1, 1).to_days()); assert_eq!(-365, LocalDate::from_ymd(1999, 1, 1).to_days()); assert_eq!(366, LocalDate::from_ymd(2001, 1, 1).to_days()); - assert_eq!(-730119, LocalDate::from_ymd(0001, 1, 1).to_days()); + assert_eq!(-730119, LocalDate::from_ymd(1, 1, 1).to_days()); assert_eq!(2921575, LocalDate::from_ymd(9999, 1, 1).to_days()); assert_eq!(Err(OutOfRangeError), LocalDate::try_from_ymd(2001, 1, 32)); @@ -1046,7 +1044,7 @@ mod test { let days_in_current_month = DAYS_IN_MONTH_LEAP[month - 1]; total_days += days_in_current_month as i32; - let end_of_month = LocalDate::from_ymd(2001, month as u8, days_in_current_month as u8).to_days(); + let end_of_month = LocalDate::from_ymd(2001, month as u8, days_in_current_month).to_days(); assert_eq!(total_days - 1, end_of_month - start_of_year); } assert_eq!(365, total_days); @@ -1114,9 +1112,9 @@ mod test { 0, 10, 10_020, - 12345 * 1000_000, - 12345 * 1001_000, - 12345 * 1001_001, + 12345 * 1_000_000, + 12345 * 1_001_000, + 12345 * 1_001_001, MICROS_PER_DAY - 1, ]; TIMES.iter().copied() @@ -1206,8 +1204,8 @@ mod test { assert_eq!(dur_str(1_000_000), "0:00:01"); assert_eq!(dur_str(1), "0:00:00.000001"); assert_eq!(dur_str(7_015_000), "0:00:07.015"); - assert_eq!(dur_str(10_000_000__015_000), "2777:46:40.015"); - assert_eq!(dur_str(12_345_678__000_000), "3429:21:18"); + assert_eq!(dur_str(10_000_000_015_000), "2777:46:40.015"); + assert_eq!(dur_str(12_345_678_000_000), "3429:21:18"); } #[test] @@ -1218,19 +1216,19 @@ mod test { assert_eq!(micros(" 100 "), 100_000_000); assert_eq!(micros("123"), 123_000_000); assert_eq!(micros("-123"), -123_000_000); - assert_eq!(micros(" 20 mins 1hr "), 4800_000_000); - assert_eq!(micros(" 20 mins -1hr "), -2400_000_000); - assert_eq!(micros(" 20us 1h 20 "), 3620_000_020); - assert_eq!(micros(" -20us 1h 20 "), 3619_999_980); - assert_eq!(micros(" -20US 1H 20 "), 3619_999_980); - assert_eq!(micros("1 hour 20 minutes 30 seconds 40 milliseconds 50 microseconds"), 4830_040_050); - assert_eq!(micros("1 hour 20 minutes +30seconds 40 milliseconds -50microseconds"), 4830_039_950); - assert_eq!(micros("1 houR 20 minutes 30SECOND 40 milliseconds 50 us"), 4830_040_050); - assert_eq!(micros(" 20 us 1H 20 minutes "), 4800_000_020); - assert_eq!(micros("-1h"), -3600_000_000); - assert_eq!(micros("100h"), 3600_000_000_00); - let h12 = 12 * 3600_000_000 as i64; - let m12 = 12 * 60_000_000 as i64; + assert_eq!(micros(" 20 mins 1hr "), 4_800_000_000); + assert_eq!(micros(" 20 mins -1hr "), -2_400_000_000); + assert_eq!(micros(" 20us 1h 20 "), 3_620_000_020); + assert_eq!(micros(" -20us 1h 20 "), 3_619_999_980); + assert_eq!(micros(" -20US 1H 20 "), 3_619_999_980); + assert_eq!(micros("1 hour 20 minutes 30 seconds 40 milliseconds 50 microseconds"), 4_830_040_050); + assert_eq!(micros("1 hour 20 minutes +30seconds 40 milliseconds -50microseconds"), 4_830_039_950); + assert_eq!(micros("1 houR 20 minutes 30SECOND 40 milliseconds 50 us"), 4_830_040_050); + assert_eq!(micros(" 20 us 1H 20 minutes "), 4_800_000_020); + assert_eq!(micros("-1h"), -3_600_000_000); + assert_eq!(micros("100h"), 360_000_000_000); + let h12 = 12 * 3_600_000_000_i64; + let m12 = 12 * 60_000_000_i64; assert_eq!(micros(" 12:12:12.2131 "), h12 + m12 + 12_213_100); assert_eq!(micros("-12:12:12.21313"), -(h12 + m12 + 12_213_130)); assert_eq!(micros("-12:12:12.213134"), -(h12 + m12 + 12_213_134)); @@ -1257,18 +1255,18 @@ mod test { assert_eq!(micros(" +00005"), 5_000_000); assert_eq!(micros(" -00005"), -5_000_000); assert_eq!(micros("PT"), 0); - assert_eq!(micros("PT1H1M1S"), 3661_000_000); + assert_eq!(micros("PT1H1M1S"), 3_661_000_000); assert_eq!(micros("PT1M1S"), 61_000_000); assert_eq!(micros("PT1S"), 1_000_000); - assert_eq!(micros("PT1H1S"), 3601_000_000); - assert_eq!(micros("PT1H1M1.1S"), 3661_100_000); - assert_eq!(micros("PT1H1M1.01S"), 3661_010_000); - assert_eq!(micros("PT1H1M1.10S"), 3661_100_000); - assert_eq!(micros("PT1H1M1.1234567S"), 3661_123_456); - assert_eq!(micros("PT1H1M1.1234564S"), 3661_123_456); - assert_eq!(micros("PT-1H1M1.1S"), -3538_900_000); - assert_eq!(micros("PT+1H-1M1.1S"), 3541_100_000); - assert_eq!(micros("PT1H+1M-1.1S"), 3658_900_000); + assert_eq!(micros("PT1H1S"), 3_601_000_000); + assert_eq!(micros("PT1H1M1.1S"), 3_661_100_000); + assert_eq!(micros("PT1H1M1.01S"), 3_661_010_000); + assert_eq!(micros("PT1H1M1.10S"), 3_661_100_000); + assert_eq!(micros("PT1H1M1.1234567S"), 3_661_123_456); + assert_eq!(micros("PT1H1M1.1234564S"), 3_661_123_456); + assert_eq!(micros("PT-1H1M1.1S"), -3_538_900_000); + assert_eq!(micros("PT+1H-1M1.1S"), 3_541_100_000); + assert_eq!(micros("PT1H+1M-1.1S"), 3_658_900_000); fn assert_error(input: &str, expected_pos: usize, pat: &str) { let ParseDurationError { @@ -1347,7 +1345,7 @@ impl RelativeDuration { -> Result { Ok(RelativeDuration { - months: months, + months, days: 0, micros: 0, }) @@ -1565,7 +1563,7 @@ impl DateDuration { -> Result { Ok(DateDuration { - months: months, + months, days: 0, }) } @@ -1643,14 +1641,14 @@ fn nanos_to_micros(nanos: i64) -> i64 { if remainder == 500 && micros % 2 == 1 || remainder > 500 { micros += 1; } - return micros; + micros } #[cfg(feature = "chrono")] mod chrono_interop { use super::*; - use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime }; - use std::convert::{From, Into, TryFrom}; + use chrono::naive::{NaiveDate, NaiveDateTime, NaiveTime}; + use chrono::DateTime; type ChronoDatetime = chrono::DateTime; @@ -1658,8 +1656,9 @@ mod chrono_interop { fn from(value: &LocalDatetime) -> NaiveDateTime { let timestamp_seconds = value.micros.wrapping_div_euclid(1000_000) - (Datetime::UNIX_EPOCH.micros / 1000_000); let timestamp_nanos = (value.micros.wrapping_rem_euclid(1000_000) * 1000) as u32; - NaiveDateTime::from_timestamp_opt(timestamp_seconds, timestamp_nanos) + DateTime::from_timestamp(timestamp_seconds, timestamp_nanos) .expect("NaiveDateTime range is bigger than LocalDatetime") + .naive_utc() } } @@ -1668,8 +1667,8 @@ mod chrono_interop { fn try_from(d: &NaiveDateTime) -> Result { - let secs = d.timestamp(); - let subsec_nanos = d.timestamp_subsec_nanos(); + let secs = d.and_utc().timestamp(); + let subsec_nanos = d.and_utc().timestamp_subsec_nanos(); let subsec_micros = nanos_to_micros(subsec_nanos.into()); let micros = secs.checked_mul(1_000_000) .and_then(|x| x.checked_add(subsec_micros)) @@ -1824,9 +1823,6 @@ mod chrono_interop { mod test { use super::*; use crate::model::time::test::{ test_times, valid_test_dates, to_debug, CHRONO_MAX_YEAR}; - use std::convert::{TryFrom, TryInto}; - use std::str::FromStr; - use std::fmt::{ Display, Debug }; #[test] fn chrono_roundtrips() -> Result<(), Box> { diff --git a/edgedb-protocol/src/query_arg.rs b/edgedb-protocol/src/query_arg.rs index 2cf317c5..2b618423 100644 --- a/edgedb-protocol/src/query_arg.rs +++ b/edgedb-protocol/src/query_arg.rs @@ -229,7 +229,7 @@ impl QueryArg for Value { } else { let members = { let mut members = members - .into_iter() + .iter() .map(|c| format!("'{c}'")) .collect::>(); members.sort_unstable(); @@ -253,7 +253,7 @@ impl QueryArgs for Value { fn encode(&self, enc: &mut Encoder) -> Result<(), Error> { let codec = enc.ctx.build_codec()?; codec - .encode(&mut enc.buf, self) + .encode(enc.buf, self) .map_err(ClientEncodingError::with_source) } } diff --git a/edgedb-protocol/src/query_result.rs b/edgedb-protocol/src/query_result.rs index f44c5a32..91be7964 100644 --- a/edgedb-protocol/src/query_result.rs +++ b/edgedb-protocol/src/query_result.rs @@ -51,7 +51,7 @@ impl QueryResult for T { fn decode(decoder: &mut Decoder, msg: &Bytes) -> Result { - Queryable::decode(&decoder, msg) + Queryable::decode(decoder, msg) .map_err(ProtocolEncodingError::with_source) } } diff --git a/edgedb-protocol/src/queryable.rs b/edgedb-protocol/src/queryable.rs index bcd0e92c..caf5398e 100644 --- a/edgedb-protocol/src/queryable.rs +++ b/edgedb-protocol/src/queryable.rs @@ -12,21 +12,14 @@ use crate::descriptors::{Descriptor, TypePos}; #[non_exhaustive] +#[derive(Default)] pub struct Decoder { pub has_implicit_id: bool, pub has_implicit_tid: bool, pub has_implicit_tname: bool, } -impl Default for Decoder { - fn default() -> Decoder { - Decoder { - has_implicit_id: false, - has_implicit_tid: false, - has_implicit_tname: false, - } - } -} + pub trait Queryable: Sized { fn decode(decoder: &Decoder, buf: &[u8]) diff --git a/edgedb-protocol/src/serialization/decode/raw_composite.rs b/edgedb-protocol/src/serialization/decode/raw_composite.rs index 1de7afd7..6e278f19 100644 --- a/edgedb-protocol/src/serialization/decode/raw_composite.rs +++ b/edgedb-protocol/src/serialization/decode/raw_composite.rs @@ -140,7 +140,7 @@ mod inner { ensure!(self.raw.len() >= position, self.underflow()); let result = &self.raw[..position]; self.raw.advance(position); - ensure!(self.count > 0 || self.raw.len() == 0, errors::ExtraData); + ensure!(self.count > 0 || self.raw.is_empty(), errors::ExtraData); Ok(result) } @@ -164,7 +164,7 @@ mod inner { pub fn read_array_like_element(&mut self) -> Result<&'t [u8], DecodeError> { ensure!(self.raw.remaining() >= 4, self.underflow()); let len = self.raw.get_i32() as usize; - Ok(self.read_element(len)?) + self.read_element(len) } pub fn read_tuple_like_header(mut buf:&'t [u8]) -> Result { diff --git a/edgedb-protocol/src/serialization/decode/raw_scalar.rs b/edgedb-protocol/src/serialization/decode/raw_scalar.rs index c21ebe3e..efbd37e8 100644 --- a/edgedb-protocol/src/serialization/decode/raw_scalar.rs +++ b/edgedb-protocol/src/serialization/decode/raw_scalar.rs @@ -194,7 +194,7 @@ impl ScalarArg for bool { impl<'t> RawCodec<'t> for i16 { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(buf.get_i16()); + Ok(buf.get_i16()) } } @@ -219,7 +219,7 @@ impl ScalarArg for i16 { impl<'t> RawCodec<'t> for i32 { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(buf.get_i32()); + Ok(buf.get_i32()) } } @@ -244,14 +244,14 @@ impl ScalarArg for i32 { impl<'t> RawCodec<'t> for i64 { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(buf.get_i64()); + Ok(buf.get_i64()) } } impl<'t> RawCodec<'t> for ConfigMemory { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(ConfigMemory(buf.get_i64())); + Ok(ConfigMemory(buf.get_i64())) } } @@ -276,7 +276,7 @@ impl ScalarArg for i64 { impl<'t> RawCodec<'t> for f32 { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(buf.get_f32()); + Ok(buf.get_f32()) } } @@ -301,7 +301,7 @@ impl ScalarArg for f32 { impl<'t> RawCodec<'t> for f64 { fn decode(mut buf: &[u8]) -> Result { ensure_exact_size(buf, size_of::())?; - return Ok(buf.get_f64()); + Ok(buf.get_f64()) } } @@ -425,7 +425,7 @@ impl ScalarArg for Decimal { -> Result<(), Error> { codec::encode_decimal(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -499,7 +499,7 @@ impl ScalarArg for BigInt { -> Result<(), Error> { codec::encode_big_int(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -558,7 +558,7 @@ impl ScalarArg for Duration { -> Result<(), Error> { codec::encode_duration(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -596,7 +596,7 @@ impl ScalarArg for RelativeDuration { -> Result<(), Error> { codec::encode_relative_duration(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -623,7 +623,7 @@ impl ScalarArg for SystemTime { .map_err(|e| ClientEncodingError::with_source(e) .context("cannot serialize SystemTime value"))?; codec::encode_datetime(encoder.buf, &val) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -641,8 +641,8 @@ impl ScalarArg for SystemTime { impl<'t> RawCodec<'t> for Datetime { fn decode(buf: &[u8]) -> Result { let micros = i64::decode(buf)?; - Ok(Datetime::from_postgres_micros(micros) - .map_err(|_| errors::InvalidDate.build())?) + Datetime::from_postgres_micros(micros) + .map_err(|_| errors::InvalidDate.build()) } } @@ -651,7 +651,7 @@ impl ScalarArg for Datetime { -> Result<(), Error> { codec::encode_datetime(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -676,7 +676,7 @@ impl ScalarArg for LocalDatetime { -> Result<(), Error> { codec::encode_local_datetime(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -700,7 +700,7 @@ impl ScalarArg for LocalDate { -> Result<(), Error> { codec::encode_local_date(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -715,7 +715,7 @@ impl ScalarArg for LocalDate { impl<'t> RawCodec<'t> for LocalTime { fn decode(buf: &[u8]) -> Result { let micros = i64::decode(buf)?; - ensure!(micros >= 0 && micros < 86_400 * 1_000_000, errors::InvalidDate); + ensure!((0..86_400 * 1_000_000).contains(µs), errors::InvalidDate); Ok(LocalTime { micros: micros as u64 }) } } @@ -725,7 +725,7 @@ impl ScalarArg for DateDuration { -> Result<(), Error> { codec::encode_date_duration(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -742,7 +742,7 @@ impl ScalarArg for LocalTime { -> Result<(), Error> { codec::encode_local_time(encoder.buf, self) - .map_err(|e| ClientEncodingError::with_source(e)) + .map_err(ClientEncodingError::with_source) } fn check_descriptor(ctx: &DescriptorContext, pos: TypePos) -> Result<(), Error> @@ -767,13 +767,10 @@ impl ScalarArg for EnumValue { use crate::descriptors::Descriptor::Enumeration; let desc = ctx.get(pos)?; - match desc { - Enumeration(_) => { - // Should we check enum members? - // Should we override `QueryArg` check descriptor for that? - // Or maybe implement just `QueryArg` for enum? - } - _ => {} + if let Enumeration(_) = desc { + // Should we check enum members? + // Should we override `QueryArg` check descriptor for that? + // Or maybe implement just `QueryArg` for enum? } Err(ctx.wrong_type(desc, "enum")) } diff --git a/edgedb-protocol/src/serialization/test_scalars.rs b/edgedb-protocol/src/serialization/test_scalars.rs index c9313986..df75e3f0 100644 --- a/edgedb-protocol/src/serialization/test_scalars.rs +++ b/edgedb-protocol/src/serialization/test_scalars.rs @@ -17,7 +17,7 @@ fn encode(val: impl ScalarArg) -> Bytes { let mut buf = BytesMut::new(); let mut encoder = Encoder::new(&ctx, &mut buf); ScalarArg::encode(&val, &mut encoder).expect("encoded"); - return buf.freeze(); + buf.freeze() } fn decode<'x, T: RawCodec<'x>>(bytes: &'x [u8]) -> T { diff --git a/edgedb-protocol/src/server_message.rs b/edgedb-protocol/src/server_message.rs index d4fd5799..34bc6171 100644 --- a/edgedb-protocol/src/server_message.rs +++ b/edgedb-protocol/src/server_message.rs @@ -269,18 +269,18 @@ impl StateDataDescription { impl ParameterStatus { pub fn parse_system_config(self) -> Result<(Typedesc, Bytes), DecodeError> { - let ref mut cur = Input::new( + let cur = &mut Input::new( self.proto.clone(), self.value, ); let typedesc_data = Bytes::decode(cur)?; let data = Bytes::decode(cur)?; - let ref mut typedesc_buf = Input::new( + let typedesc_buf = &mut Input::new( self.proto, typedesc_data, ); - let typedesc_id = Uuid::decode(typedesc_buf)?.into(); + let typedesc_id = Uuid::decode(typedesc_buf)?; let typedesc = Typedesc::decode_with_id(typedesc_id, typedesc_buf)?; Ok((typedesc, data)) } @@ -320,7 +320,7 @@ impl ServerMessage { /// in the buffer or if extra data is present. pub fn decode(buf: &mut Input) -> Result { use self::ServerMessage as M; - let ref mut data = buf.slice(5..); + let data = &mut buf.slice(5..); let result = match buf[0] { 0x76 => ServerHandshake::decode(data).map(M::ServerHandshake)?, 0x45 => ErrorResponse::decode(data).map(M::ErrorResponse)?, @@ -442,7 +442,7 @@ impl Decode for ErrorResponse { ensure!(buf.remaining() >= 4, errors::Underflow); attributes.insert(buf.get_u16(), Bytes::decode(buf)?); } - return Ok(ErrorResponse { + Ok(ErrorResponse { severity, code, message, attributes, }) } @@ -481,7 +481,7 @@ impl Decode for LogMessage { ensure!(buf.remaining() >= 4, errors::Underflow); attributes.insert(buf.get_u16(), Bytes::decode(buf)?); } - return Ok(LogMessage { + Ok(LogMessage { severity, code, text, attributes, }) } @@ -616,9 +616,9 @@ impl MessageSeverity { _ => Unknown(code), } } - fn to_u8(&self) -> u8 { + fn to_u8(self) -> u8 { use MessageSeverity::*; - match *self { + match self { Debug => 20, Info => 40, Notice => 60, @@ -946,7 +946,7 @@ impl Decode for Data { for _ in 0..num_chunks { data.push(Bytes::decode(buf)?); } - return Ok(Data { data }) + Ok(Data { data }) } } @@ -979,7 +979,7 @@ impl Decode for RestoreReady { } ensure!(buf.remaining() >= 2, errors::Underflow); let jobs = buf.get_u16(); - return Ok(RestoreReady { jobs, headers }) + Ok(RestoreReady { jobs, headers }) } } @@ -994,7 +994,7 @@ impl Encode for RawPacket { impl Decode for RawPacket { fn decode(buf: &mut Input) -> Result { - return Ok(RawPacket { data: buf.copy_to_bytes(buf.remaining()) }) + Ok(RawPacket { data: buf.copy_to_bytes(buf.remaining()) }) } } diff --git a/edgedb-protocol/src/value.rs b/edgedb-protocol/src/value.rs index 5ec01120..2a2d0ee4 100644 --- a/edgedb-protocol/src/value.rs +++ b/edgedb-protocol/src/value.rs @@ -1,8 +1,6 @@ /*! Contains the [Value](crate::value::Value) enum. */ -use std::iter::IntoIterator; - use bytes::Bytes; use crate::codec::{NamedTupleShape, ObjectShape, ShapeElement}; @@ -154,7 +152,7 @@ impl PartialEq for SparseObject { } } let other_num = other.fields.iter().filter(|e| e.is_some()).count(); - return num == other_num; + num == other_num } } diff --git a/edgedb-protocol/tests/client_messages.rs b/edgedb-protocol/tests/client_messages.rs index 607e5244..e0f93489 100644 --- a/edgedb-protocol/tests/client_messages.rs +++ b/edgedb-protocol/tests/client_messages.rs @@ -11,9 +11,9 @@ use edgedb_protocol::client_message::{ClientMessage, ClientHandshake}; use edgedb_protocol::client_message::{ExecuteScript, Execute0, Execute1}; use edgedb_protocol::client_message::{Parse, Prepare, IoFormat, Cardinality}; use edgedb_protocol::client_message::{DescribeStatement, DescribeAspect}; -use edgedb_protocol::client_message::{SaslInitialResponse}; -use edgedb_protocol::client_message::{SaslResponse}; -use edgedb_protocol::client_message::{OptimisticExecute}; +use edgedb_protocol::client_message::SaslInitialResponse; +use edgedb_protocol::client_message::SaslResponse; +use edgedb_protocol::client_message::OptimisticExecute; use edgedb_protocol::client_message::Restore; mod base; @@ -149,7 +149,7 @@ fn optimistic_execute() -> Result<(), Box> { input_typedesc_id: Uuid::from_u128(0xFF), output_typedesc_id: Uuid::from_u128(0x0), arguments: Bytes::new(), - }), b"O\0\0\06\0\0bo\0\0\0\x06COMMIT\ + }), b"O\0\0\x006\0\0bo\0\0\0\x06COMMIT\ \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\xff\ \0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0\0"); Ok(()) diff --git a/edgedb-protocol/tests/codecs.rs b/edgedb-protocol/tests/codecs.rs index 9d036732..8f654ab6 100644 --- a/edgedb-protocol/tests/codecs.rs +++ b/edgedb-protocol/tests/codecs.rs @@ -221,7 +221,7 @@ fn uuid() -> Result<(), Box> { ] )?; encoding_eq!(&codec, b"I(\xcc\x1e e\x11\xea\x88H{S\xa6\xad\xb3\x83", - Value::Uuid("4928cc1e-2065-11ea-8848-7b53a6adb383".parse::()?.into())); + Value::Uuid("4928cc1e-2065-11ea-8848-7b53a6adb383".parse::()?)); Ok(()) } @@ -237,10 +237,10 @@ fn duration() -> Result<(), Box> { // SELECT '2019-11-29T00:00:00Z'-'2000-01-01T00:00:00Z' encoding_eq!(&codec, b"\0\x02;o\xad\xff\0\0\0\0\0\0\0\0\0\0", - Value::Duration(Duration::from_micros(7272*86400*1000_000))); + Value::Duration(Duration::from_micros(7272*86400*1_000_000))); // SELECT '2019-11-29T00:00:00Z'-'2019-11-28T01:00:00Z' encoding_eq!(&codec, b"\0\0\0\x13GC\xbc\0\0\0\0\0\0\0\0\0", - Value::Duration(Duration::from_micros(82800*1000_000))); + Value::Duration(Duration::from_micros(82800*1_000_000))); encoding_eq!(&codec, b"\xff\xff\xff\xff\xd3,\xba\xe0\0\0\0\0\0\0\0\0", Value::Duration(Duration::from_micros(-752043296))); @@ -374,7 +374,7 @@ fn input_codec() -> Result<(), Box> { }; let out_desc = sdd.parse()?; let codec = build_codec(Some(TypePos(16)), - &out_desc.descriptors(), + out_desc.descriptors(), )?; encoding_eq!(&codec, b"\0\0\0\x03\0\0\0\0\0\0\0\x07default\0\0\0\x02\0\0\0\x1c\ diff --git a/edgedb-protocol/tests/datetime_chrono.rs b/edgedb-protocol/tests/datetime_chrono.rs index d66fcd0c..1172b328 100644 --- a/edgedb-protocol/tests/datetime_chrono.rs +++ b/edgedb-protocol/tests/datetime_chrono.rs @@ -2,7 +2,7 @@ mod chrono { use std::str::FromStr; - use std::convert::{TryInto, TryFrom}; + use std::convert::TryInto; use bytes::{BytesMut, Buf}; use edgedb_protocol::codec::{self, Codec}; @@ -164,7 +164,7 @@ mod chrono { assert_eq!(serialized_micros, micros); - let rev = chrono::DateTime::::try_from(edgedb).unwrap(); + let rev = chrono::DateTime::::from(edgedb); assert_eq!(format!("{:?}", rev), formatted); } @@ -322,7 +322,7 @@ mod chrono { assert_eq!(serialized_micros, micros); - let rev = chrono::NaiveDateTime::try_from(edgedb).unwrap(); + let rev = chrono::NaiveDateTime::from(edgedb); assert_eq!(format!("{:?}", rev), formatted); } @@ -386,7 +386,7 @@ mod chrono { )] fn local_time(input: &str, micros: i64, formatted: &str) { let chrono = chrono::NaiveTime::from_str(input).unwrap(); - let edgedb: LocalTime = chrono.try_into().unwrap(); + let edgedb: LocalTime = chrono.into(); assert_eq!(format!("{:?}", edgedb), formatted); let mut buf = BytesMut::new(); @@ -396,7 +396,7 @@ mod chrono { assert_eq!(serialized_micros, micros); - let rev = chrono::NaiveTime::try_from(edgedb).unwrap(); + let rev = chrono::NaiveTime::from(edgedb); assert_eq!(format!("{:?}", rev), formatted); } } diff --git a/edgedb-protocol/tests/decode.rs b/edgedb-protocol/tests/decode.rs index 7a36f191..3b179b1d 100644 --- a/edgedb-protocol/tests/decode.rs +++ b/edgedb-protocol/tests/decode.rs @@ -1,5 +1,3 @@ -use std::default::Default; - use edgedb_protocol::queryable::Queryable; use edgedb_protocol::model::Vector; diff --git a/edgedb-tokio/examples/transaction_errors.rs b/edgedb-tokio/examples/transaction_errors.rs index 5bc8c22d..3d6ba615 100644 --- a/edgedb-tokio/examples/transaction_errors.rs +++ b/edgedb-tokio/examples/transaction_errors.rs @@ -10,7 +10,7 @@ struct CounterError; fn check_val0(val: i64) -> anyhow::Result<()> { if val % 3 == 0 { if thread_rng().gen_bool(0.9) { - return Err(CounterError)?; + Err(CounterError)?; } } Ok(()) @@ -19,7 +19,7 @@ fn check_val0(val: i64) -> anyhow::Result<()> { fn check_val1(val: i64) -> Result<(), CounterError> { if val % 3 == 1 { if thread_rng().gen_bool(0.1) { - return Err(CounterError)?; + Err(CounterError)?; } } Ok(()) diff --git a/edgedb-tokio/src/builder.rs b/edgedb-tokio/src/builder.rs index 1aad8ec9..6975aa2a 100644 --- a/edgedb-tokio/src/builder.rs +++ b/edgedb-tokio/src/builder.rs @@ -198,11 +198,11 @@ pub async fn search_dir(base: &Path) -> Result, Error> { let mut path = base; if fs::metadata(path.join("edgedb.toml")).await.is_ok() { - return Ok(Some(path.into())); + return Ok(Some(path)); } while let Some(parent) = path.parent() { if fs::metadata(parent.join("edgedb.toml")).await.is_ok() { - return Ok(Some(parent.into())); + return Ok(Some(parent)); } path = parent; } @@ -210,7 +210,7 @@ pub async fn search_dir(base: &Path) -> Result, Error> } #[cfg(unix)] -fn path_bytes<'x>(path: &'x Path) -> &'x [u8] { +fn path_bytes(path: &Path) -> &'_ [u8] { use std::os::unix::ffi::OsStrExt; path.as_os_str().as_bytes() } @@ -230,7 +230,7 @@ fn stash_name(path: &Path) -> OsString { let mut base = base.to_os_string(); base.push("-"); base.push(&hash); - return base; + base } fn config_dir() -> Result { @@ -293,7 +293,7 @@ fn is_valid_local_instance_name(name: &str) -> bool { was_dash = false; } } - return !was_dash; + !was_dash } fn is_valid_cloud_name(name: &str) -> bool { @@ -321,7 +321,7 @@ fn is_valid_cloud_name(name: &str) -> bool { was_dash = false; } } - return !was_dash; + !was_dash } impl fmt::Display for InstanceName { @@ -396,15 +396,21 @@ impl fmt::Display for DisplayAddr<'_> { impl<'a> DsnHelper<'a> { fn from_url(url: &'a url::Url) -> Result { + use std::collections::hash_map::Entry::*; + let admin = url.scheme() == "edgedbadmin"; let mut query = HashMap::new(); for (k, v) in url.query_pairs() { - if query.contains_key(&k) { - return Err(ClientError::with_message(format!( - "{k:?} is defined multiple times in the DSN query" - )).context("invalid DSN")); - } else { - query.insert(k, v); + match query.entry(k) { + Vacant(e) => { + e.insert(v); + }, + Occupied(e) => { + return Err(ClientError::with_message(format!( + "{:?} is defined multiple times in the DSN query", + e.key() + )).context("invalid DSN")); + } } } Ok(Self { url, admin, query }) @@ -498,7 +504,7 @@ impl<'a> DsnHelper<'a> { } async fn retrieve_tls_server_name(&mut self) -> Result, Error> { - self.retrieve_value("tls_server_name", None, |s| Ok(s)).await + self.retrieve_value("tls_server_name", None, Ok).await } async fn retrieve_port(&mut self) -> Result, Error> { @@ -522,11 +528,11 @@ impl<'a> DsnHelper<'a> { async fn retrieve_password(&mut self) -> Result, Error> { let v = self.url.password().map(|s| s.to_owned()); - self.retrieve_value("password", v, |s| Ok(s)).await + self.retrieve_value("password", v, Ok).await } async fn retrieve_database(&mut self) -> Result, Error> { - let v = self.url.path().strip_prefix("/").and_then(|s| { + let v = self.url.path().strip_prefix('/').and_then(|s| { if s.is_empty() { None } else { @@ -534,14 +540,14 @@ impl<'a> DsnHelper<'a> { } }); self.retrieve_value("database", v, |s| { - let s = s.strip_prefix("/").unwrap_or(&s); + let s = s.strip_prefix('/').unwrap_or(&s); validate_database(&s)?; Ok(s.to_owned()) }).await } async fn retrieve_branch(&mut self) -> Result, Error> { - let v = self.url.path().strip_prefix("/").and_then(|s| { + let v = self.url.path().strip_prefix('/').and_then(|s| { if s.is_empty() { None } else { @@ -549,18 +555,18 @@ impl<'a> DsnHelper<'a> { } }); self.retrieve_value("branch", v, |s| { - let s = s.strip_prefix("/").unwrap_or(&s); + let s = s.strip_prefix('/').unwrap_or(&s); validate_branch(&s)?; Ok(s.to_owned()) }).await } async fn retrieve_secret_key(&mut self) -> Result, Error> { - self.retrieve_value("secret_key", None, |s| Ok(s)).await + self.retrieve_value("secret_key", None, Ok).await } async fn retrieve_tls_ca_file(&mut self) -> Result, Error> { - self.retrieve_value("tls_ca_file", None, |s| Ok(s)).await + self.retrieve_value("tls_ca_file", None, Ok).await } async fn retrieve_tls_security(&mut self) -> Result, Error> { @@ -623,8 +629,7 @@ impl Builder { -> Result<&mut Self, Error> { if let Some(cert_data) = &credentials.tls_ca { - validate_certs(&cert_data) - .context("invalid certificates in `tls_ca`")?; + validate_certs(cert_data).context("invalid certificates in `tls_ca`")?; } self.credentials = Some(credentials.clone()); Ok(self) @@ -834,15 +839,15 @@ impl Builder { .or_else(|| creds.map(|c| c.user.clone())) .unwrap_or_else(|| "edgedb".into()), password: self.password.clone() - .or_else(|| creds.map(|c| c.password.clone()).flatten()), + .or_else(|| creds.and_then(|c| c.password.clone())), secret_key: self.secret_key.clone(), cloud_profile: self.cloud_profile.clone(), cloud_certs: None, database: self.database.clone() - .or_else(|| creds.map(|c| c.database.clone()).flatten()) + .or_else(|| creds.and_then(|c| c.database.clone())) .unwrap_or_else(|| "edgedb".into()), branch: self.branch.clone() - .or_else(|| creds.map(|c| c.branch.clone()).flatten()) + .or_else(|| creds.and_then(|c| c.branch.clone())) .unwrap_or_else(|| "__default__".into()), instance_name: None, wait: self.wait_until_available.unwrap_or(DEFAULT_WAIT), @@ -851,7 +856,7 @@ impl Builder { extra_dsn_query_args: HashMap::new(), creds_file_outdated: false, pem_certificates: self.pem_certificates.clone() - .or_else(|| creds.map(|c| c.tls_ca.clone()).flatten()), + .or_else(|| creds.and_then(|c| c.tls_ca.clone())), // Pool configuration max_concurrency: self.max_concurrency, @@ -940,9 +945,8 @@ impl Builder { "port argument conflicts with {}", conflict ))); } - match &mut cfg.address { - Address::Tcp((_, ref mut portref)) => *portref = *port, - _ => {}, + if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { + *portref = *port } } if let Some(unix_path) = &self.unix_path { @@ -1076,23 +1080,21 @@ impl Builder { if let Some(host) = str_env("EDGEDB_HOST", errors) { match validate_host(&host) { Ok(_) => { - cfg.address = Address::Tcp(( - host.into(), - DEFAULT_PORT, - )); + cfg.address = Address::Tcp((host, DEFAULT_PORT)); } Err(e) => errors.push(e.context("EDGEDB_HOST is invalid")), } } if let Some(port_str) = str_env("EDGEDB_PORT", errors) { let port = port_str.parse() - .map_err(|e| ClientError::with_source(e)) + .map_err(ClientError::with_source) .and_then(validate_port) .context("EDGEDB_PORT is invalid"); match port { - Ok(port) => match &mut cfg.address { - Address::Tcp((_, ref mut portref)) => *portref = port, - _ => {}, + Ok(port) => { + if let Address::Tcp((_, ref mut portref)) = &mut cfg.address { + *portref = port + } }, Err(e) => { if port_str.starts_with("tcp://") { @@ -1216,7 +1218,7 @@ impl Builder { async fn read_dsn(&self, cfg: &mut ConfigInner, url: &url::Url, errors: &mut Vec) { - let mut dsn = match DsnHelper::from_url(&url) { + let mut dsn = match DsnHelper::from_url(url) { Ok(dsn) => dsn, Err(e) => { errors.push(e); @@ -1506,7 +1508,7 @@ impl Builder { .unwrap_or(TlsSecurity::Strict); cfg.verifier = cfg.make_verifier(tls_security); - return (complete, Config(Arc::new(cfg)), errors); + (complete, Config(Arc::new(cfg)), errors) } } @@ -1527,7 +1529,7 @@ fn resolve_unix(path: impl AsRef, port: u16, admin: bool) -> PathBuf { }; path.as_ref().join(socket_name) }; - return path; + path } async fn read_instance(cfg: &mut ConfigInner, name: &InstanceName) @@ -1547,7 +1549,7 @@ async fn read_instance(cfg: &mut ConfigInner, name: &InstanceName) secret_key.clone() } else { let profile = cfg.cloud_profile.as_deref().unwrap_or("default"); - let path = cloud_config_file(&profile)?; + let path = cloud_config_file(profile)?; let data = match fs::read(path).await { Ok(data) => data, Err(e) if e.kind() == io::ErrorKind::NotFound => { @@ -1574,12 +1576,10 @@ async fn read_instance(cfg: &mut ConfigInner, name: &InstanceName) config.secret_key }; let claims_b64 = secret_key - .splitn(3, ".") - .skip(1) - .next() + .split('.').nth(1) .ok_or(ClientError::with_message("Illegal JWT token"))?; let claims = base64::engine::general_purpose::URL_SAFE_NO_PAD - .decode(&claims_b64) + .decode(claims_b64) .map_err(ClientError::with_source)?; let claims: Claims = from_slice(&claims) .map_err(ClientError::with_source)?; @@ -1635,7 +1635,7 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials) -> Result<(), Error> { if let Some(cert_data) = &creds.tls_ca { - validate_certs(&cert_data) + validate_certs(cert_data) .context("invalid certificates in `tls_ca`")?; cfg.pem_certificates = Some(cert_data.into()); } @@ -1654,7 +1654,7 @@ fn set_credentials(cfg: &mut ConfigInner, creds: &Credentials) fn validate_certs(data: &str) -> Result<(), Error> { let root_store = tls::read_root_cert_pem(data) - .map_err(|e| ClientError::with_source_ref(e))?; + .map_err(ClientError::with_source_ref)?; if root_store.is_empty() { return Err(ClientError::with_message( "PEM data contains no certificate")); @@ -1667,7 +1667,7 @@ fn validate_host>(host: T) -> Result { return Err(InvalidArgumentError::with_message( "invalid host: empty string" )); - } else if host.as_ref().contains(",") { + } else if host.as_ref().contains(',') { return Err(InvalidArgumentError::with_message( "invalid host: multiple hosts" )); @@ -1713,7 +1713,7 @@ fn validate_user>(user: T) -> Result { impl Config { /// A displayable form for an address this builder will connect to - pub fn display_addr<'x>(&'x self) -> impl fmt::Display + 'x { + pub fn display_addr(&self) -> impl fmt::Display + '_ { DisplayAddr(Some(&self.0.address)) } @@ -1848,6 +1848,7 @@ impl Config { Ok(self) } + /// Return the same config with changed database branch pub fn with_branch(mut self, branch: &str) -> Result { if branch.is_empty() { return Err(InvalidArgumentError::with_message( @@ -2154,7 +2155,7 @@ fn test_instance_name() { Ok(InstanceName::Local(name)) => assert_eq!(name, inst_name), Ok(InstanceName::Cloud { org_slug, name }) => { let (o, i) = inst_name - .split_once("/") + .split_once('/') .expect("test case must have one slash"); assert_eq!(org_slug, o); assert_eq!(name, i); @@ -2190,21 +2191,20 @@ pub async fn get_project_dir(override_dir: Option<&Path>, search_parents: bool) None => { Cow::Owned(env::current_dir() .map_err(|e| ClientError::with_source(e) - .context("failed to get current directory"))? - .into()) + .context("failed to get current directory"))?) } }; if search_parents { if let Some(ancestor) = search_dir(&dir).await? { - return Ok(Some(ancestor.to_path_buf())); + Ok(Some(ancestor.to_path_buf())) } else { - return Ok(None); + Ok(None) } } else { - if !fs::metadata(dir.join("edgedb.toml")).await.is_ok() { + if fs::metadata(dir.join("edgedb.toml")).await.is_err() { return Ok(None) } - return Ok(Some(dir.to_path_buf())) - }; + Ok(Some(dir.to_path_buf())) + } } diff --git a/edgedb-tokio/src/client.rs b/edgedb-tokio/src/client.rs index 915c09e3..6a5d3dcc 100644 --- a/edgedb-tokio/src/client.rs +++ b/edgedb-tokio/src/client.rs @@ -219,9 +219,8 @@ impl Client { io_format: IoFormat::Json, expected_cardinality: Cardinality::Many, }; - let desc; - match conn.parse(&flags, query, &self.options.state).await { - Ok(parsed) => desc = parsed, + let desc = match conn.parse(&flags, query, &self.options.state).await { + Ok(parsed) => parsed, Err(e) => { if e.has_tag(SHOULD_RETRY) { let rule = self.options.retry.get_rule(&e); @@ -329,9 +328,8 @@ impl Client { io_format: IoFormat::Json, expected_cardinality: Cardinality::AtMostOne, }; - let desc; - match conn.parse(&flags, query, &self.options.state).await { - Ok(parsed) => desc = parsed, + let desc = match conn.parse(&flags, query, &self.options.state).await { + Ok(parsed) => parsed, Err(e) => { if e.has_tag(SHOULD_RETRY) { let rule = self.options.retry.get_rule(&e); @@ -434,7 +432,7 @@ impl Client { let state = &self.options.state; let caps = Capabilities::MODIFICATIONS | Capabilities::DDL; match conn.execute(query, arguments, state, caps).await { - Ok(resp) => return Ok(resp.data), + Ok(_) => return Ok(()), Err(e) => { let allow_retry = match e.get::() { // Error from a weird source, or just a bug diff --git a/edgedb-tokio/src/credentials.rs b/edgedb-tokio/src/credentials.rs index 07628a07..7d8655b9 100644 --- a/edgedb-tokio/src/credentials.rs +++ b/edgedb-tokio/src/credentials.rs @@ -1,5 +1,4 @@ //! Credentials file handling routines -use std::default::Default; use std::fmt; use std::str::FromStr; @@ -149,7 +148,7 @@ impl Serialize for Credentials { }, }; - return CredentialsCompat::serialize(&creds, serializer); + CredentialsCompat::serialize(&creds, serializer) } } diff --git a/edgedb-tokio/src/errors.rs b/edgedb-tokio/src/errors.rs index af84081e..7a6b6031 100644 --- a/edgedb-tokio/src/errors.rs +++ b/edgedb-tokio/src/errors.rs @@ -1,3 +1,2 @@ //! Errors that can be returned by a client -pub use edgedb_errors::{Error, Tag, ErrorKind, ResultExt, kinds::*}; -pub use edgedb_errors::display::*; +pub use edgedb_errors::{Error, ErrorKind, ResultExt, kinds::*}; diff --git a/edgedb-tokio/src/lib.rs b/edgedb-tokio/src/lib.rs index 95349d7f..b3484ab2 100644 --- a/edgedb-tokio/src/lib.rs +++ b/edgedb-tokio/src/lib.rs @@ -68,7 +68,7 @@ pub use client::Client; pub use errors::Error; pub use options::{TransactionOptions, RetryOptions, RetryCondition}; pub use state::{GlobalsDelta, ConfigDelta}; -pub use transaction::{Transaction}; +pub use transaction::Transaction; #[cfg(feature="unstable")] pub use builder::get_project_dir; diff --git a/edgedb-tokio/src/options.rs b/edgedb-tokio/src/options.rs index fdffd3fc..1e31aebf 100644 --- a/edgedb-tokio/src/options.rs +++ b/edgedb-tokio/src/options.rs @@ -1,5 +1,4 @@ use std::collections::HashMap; -use std::default::Default; use std::fmt; use std::sync::Arc; use std::time::Duration; @@ -9,10 +8,6 @@ use once_cell::sync::Lazy; use crate::errors::{Error, IdleSessionTimeoutError}; -trait Assert: Send + Sync + 'static {} -impl Assert for RetryOptions {} -impl Assert for TransactionOptions {} - /// Single immediate retry on idle is fine /// @@ -157,7 +152,7 @@ impl RetryOptions { use RetryCondition::*; if err.is::() { - return &*IDLE_TIMEOUT_RULE; + &IDLE_TIMEOUT_RULE } else if err.is::() { self.0.overrides.get(&TransactionConflict) .unwrap_or(&self.0.default) @@ -165,7 +160,7 @@ impl RetryOptions { self.0.overrides.get(&NetworkError).unwrap_or(&self.0.default) } else { &self.0.default - } + } } } diff --git a/edgedb-tokio/src/raw/connection.rs b/edgedb-tokio/src/raw/connection.rs index 88787355..1b3aab61 100644 --- a/edgedb-tokio/src/raw/connection.rs +++ b/edgedb-tokio/src/raw/connection.rs @@ -5,13 +5,13 @@ use std::error::Error as _; use std::future::{self, Future}; use std::io; use std::str; -use std::time::{Duration}; +use std::time::Duration; use bytes::{Bytes, BytesMut}; use rand::{thread_rng, Rng}; use scram::ScramClient; use tls_api::{TlsConnector, TlsConnectorBox, TlsStream, TlsStreamDyn}; -use tls_api::{TlsConnectorBuilder}; +use tls_api::TlsConnectorBuilder; use tls_api_not_tls::TlsConnector as PlainConnector; use tokio::io::{AsyncRead, AsyncReadExt}; use tokio::io::{AsyncWrite, AsyncWriteExt}; @@ -23,20 +23,17 @@ use rustls::pki_types::DnsName; use edgedb_protocol::client_message::{ClientMessage, ClientHandshake}; use edgedb_protocol::encoding::{Input, Output}; use edgedb_protocol::features::ProtocolVersion; -use edgedb_protocol::server_message::{ParameterStatus, RawTypedesc}; -use edgedb_protocol::server_message::{ServerHandshake}; -use edgedb_protocol::server_message::{Authentication, ErrorResponse, ServerMessage}; -use edgedb_protocol::server_message::{TransactionState, MessageSeverity}; +use edgedb_protocol::server_message::{ParameterStatus, RawTypedesc, TransactionState, MessageSeverity}; +use edgedb_protocol::server_message::{ServerHandshake, Authentication, ErrorResponse, ServerMessage}; use edgedb_protocol::value::Value; use crate::builder::{Config, Address}; -use crate::errors::{AuthenticationError, PasswordRequired}; -use crate::errors::{ClientConnectionError, ClientConnectionFailedError}; -use crate::errors::{ClientConnectionFailedTemporarilyError, ProtocolTlsError}; -use crate::errors::{ClientEncodingError, ClientConnectionEosError}; -use crate::errors::{Error, ClientError, ErrorKind}; -use crate::errors::{IdleSessionTimeoutError}; -use crate::errors::{ProtocolEncodingError, ProtocolError}; +use crate::errors::{ + AuthenticationError, PasswordRequired, ClientConnectionError, ClientConnectionFailedError, + ClientConnectionFailedTemporarilyError, ProtocolTlsError, ClientEncodingError, + ClientConnectionEosError, Error, ClientError, ErrorKind, IdleSessionTimeoutError, + ProtocolEncodingError, ProtocolError +}; use crate::raw::{Connection, PingInterval}; use crate::raw::queries::Guard; use crate::server_params::{ServerParams, ServerParam, SystemConfig}; @@ -269,7 +266,7 @@ async fn connect(cfg: &Config) -> Result { let start = Instant::now(); let wait = cfg.0.wait; - let ref mut warned = false; + let warned = &mut false; let conn = loop { match connect_timeout(cfg, connect2(cfg, &tls, warned)).await { Err(e) if is_temporary(&e) => { @@ -340,14 +337,14 @@ async fn connect3(cfg: &Config, tls: &TlsConnectorBox) Cow::from(server_name) }, None => { - if !DnsName::try_from(host.clone()).is_ok() { + if DnsName::try_from(host.clone()).is_err() { // FIXME: https://github.com/rustls/rustls/issues/184 // If self.host is neither an IP address nor a valid DNS // name, the hacks below won't make it valid anyways. let host = format!("{}.host-for-ip.edgedb.net", host); // for ipv6addr - let host = host.replace(":", "-").replace("%", "-"); - if host.starts_with("-") { + let host = host.replace([':', '%'], "-"); + if host.starts_with('-') { Cow::from(format!("i{}", host)) } else { Cow::from(host) @@ -459,8 +456,7 @@ async fn connect4(cfg: &Config, mut stream: TlsStream) b"pgaddr" => { use crate::server_params::PostgresAddress; - let pgaddr: PostgresAddress; - pgaddr = match serde_json::from_slice(&par.value[..]) { + let pgaddr: PostgresAddress = match serde_json::from_slice(&par.value[..]) { Ok(a) => a, Err(e) => { log::warn!("Can't decode param {:?}: {}", @@ -515,17 +511,17 @@ async fn scram( use edgedb_protocol::client_message::SaslInitialResponse; use edgedb_protocol::client_message::SaslResponse; - let scram = ScramClient::new(&user, &password, None); + let scram = ScramClient::new(user, password, None); let (scram, first) = scram.client_first(); - send_messages(stream, out_buf, &proto, &[ + send_messages(stream, out_buf, proto, &[ ClientMessage::AuthenticationSaslInitialResponse( SaslInitialResponse { method: "SCRAM-SHA-256".into(), data: Bytes::copy_from_slice(first.as_bytes()), }), ]).await?; - let msg = wait_message(stream, in_buf, &proto).await?; + let msg = wait_message(stream, in_buf, proto).await?; let data = match msg { ServerMessage::Authentication( Authentication::SaslContinue { data } @@ -541,16 +537,16 @@ async fn scram( let data = str::from_utf8(&data[..]) .map_err(|e| ProtocolError::with_source(e).context( "invalid utf-8 in SCRAM-SHA-256 auth"))?; - let scram = scram.handle_server_first(&data) + let scram = scram.handle_server_first(data) .map_err(AuthenticationError::with_source)?; let (scram, data) = scram.client_final(); - send_messages(stream, out_buf, &proto, &[ + send_messages(stream, out_buf, proto, &[ ClientMessage::AuthenticationSaslResponse( SaslResponse { data: Bytes::copy_from_slice(data.as_bytes()), }), ]).await?; - let msg = wait_message(stream, in_buf, &proto).await?; + let msg = wait_message(stream, in_buf, proto).await?; let data = match msg { ServerMessage::Authentication(Authentication::SaslFinal { data }) => data, @@ -565,11 +561,11 @@ async fn scram( let data = str::from_utf8(&data[..]) .map_err(|_| ProtocolError::with_message( "invalid utf-8 in SCRAM-SHA-256 auth"))?; - scram.handle_server_final(&data) + scram.handle_server_final(data) .map_err(|e| AuthenticationError::with_message(format!( "Authentication error: {}", e)))?; loop { - let msg = wait_message(stream, in_buf, &proto).await?; + let msg = wait_message(stream, in_buf, proto).await?; match msg { ServerMessage::Authentication(Authentication::Ok) => break, ServerMessage::ErrorResponse(ErrorResponse { @@ -732,7 +728,7 @@ async fn _wait_message<'x>(stream: &mut (impl AsyncRead + Unpin), log::debug!(target: "edgedb::incoming::frame", "Frame Contents: {:#?}", result); - return Ok(result) + Ok(result) } fn connect_sleep() -> Duration { @@ -753,9 +749,10 @@ async fn connect_timeout(cfg: &Config, f: F) -> Result } fn is_temporary(e: &Error) -> bool { - use io::ErrorKind::{ConnectionRefused, TimedOut, NotFound}; - use io::ErrorKind::{ConnectionAborted, ConnectionReset, UnexpectedEof}; - use io::ErrorKind::{AddrNotAvailable}; + use io::ErrorKind::{ + AddrNotAvailable, ConnectionAborted, ConnectionRefused, ConnectionReset, NotFound, + TimedOut, UnexpectedEof, + }; if e.is::() { return true; @@ -780,7 +777,7 @@ fn is_temporary(e: &Error) -> bool { } } } - return false; + false } fn tls_fail(e: anyhow::Error) -> Error { diff --git a/edgedb-tokio/src/raw/dumps.rs b/edgedb-tokio/src/raw/dumps.rs index 34c19740..6797af33 100644 --- a/edgedb-tokio/src/raw/dumps.rs +++ b/edgedb-tokio/src/raw/dumps.rs @@ -45,30 +45,26 @@ impl Connection { data: header, }), ]).await?; - loop { - let msg = self.message().await?; - match msg { - ServerMessage::RestoreReady(_) => { - log::info!("Schema applied in {:?}", - start_headers.elapsed()); - break; - } - ServerMessage::ErrorResponse(err) => { - self.send_messages(&[ClientMessage::Sync]).await?; - self.expect_ready_or_eos(guard).await - .map_err(|e| log::warn!( - "Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(Into::::into(err) - .context("error initiating restore protocol") - .into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "unsolicited message {:?}", msg)))?; - } + + match self.message().await? { + ServerMessage::RestoreReady(_) => { + log::info!("Schema applied in {:?}", start_headers.elapsed()); + } + ServerMessage::ErrorResponse(err) => { + self.send_messages(&[ClientMessage::Sync]).await?; + self.expect_ready_or_eos(guard).await + .map_err(|e| log::warn!( + "Error waiting for Ready after error: {e:#}")) + .ok(); + return Err(Into::::into(err) + .context("error initiating restore protocol")); + } + msg => { + return Err(ProtocolOutOfOrderError::with_message(format!( + "unsolicited message {:?}", msg)))?; } } + let start_blocks = Instant::now(); let mut num_blocks = 0; @@ -164,7 +160,7 @@ impl Connection { self.send_messages(&[ ClientMessage::Dump(Dump { - headers: headers, + headers, }), ClientMessage::Sync, ]).await?; @@ -177,8 +173,7 @@ impl Connection { "Error waiting for Ready after error: {e:#}")) .ok(); return Err(Into::::into(err) - .context("error receiving dump header") - .into()); + .context("error receiving dump header")); } _ => { return Err(ProtocolOutOfOrderError::with_message(format!( diff --git a/edgedb-tokio/src/raw/mod.rs b/edgedb-tokio/src/raw/mod.rs index 96a8475c..9fc0e9a2 100644 --- a/edgedb-tokio/src/raw/mod.rs +++ b/edgedb-tokio/src/raw/mod.rs @@ -5,30 +5,32 @@ mod options; mod queries; mod response; pub mod state; -#[cfg(feature="unstable")] mod dumps; +#[cfg(feature="unstable")] +mod dumps; use std::collections::VecDeque; use std::sync::{Arc, Mutex as BlockingMutex}; use std::time::Duration; use bytes::{Bytes, BytesMut}; -use tls_api::{TlsStream}; +use tls_api::TlsStream; use tokio::sync::{self, Semaphore}; use edgedb_protocol::features::ProtocolVersion; use edgedb_protocol::common::{RawTypedesc, Capabilities}; -use edgedb_protocol::server_message::{TransactionState}; -use edgedb_protocol::server_message::{CommandDataDescription1}; +use edgedb_protocol::server_message::TransactionState; +use edgedb_protocol::server_message::CommandDataDescription1; use crate::errors::{Error, ErrorKind, ClientError}; use crate::builder::Config; use crate::server_params::ServerParams; pub use options::Options; -pub use response::{ResponseStream}; +pub use response::ResponseStream; pub use state::{State, PoolState}; -#[cfg(feature="unstable")] pub use dumps::{DumpStream}; +#[cfg(feature="unstable")] +pub use dumps::DumpStream; #[derive(Clone, Debug)] pub struct Pool(Arc); @@ -82,12 +84,6 @@ pub(crate) enum PingInterval { Interval(Duration), } -trait AssertConn: Send + 'static {} -impl AssertConn for PoolConnection {} -impl AssertConn for Connection {} - -trait AssertPool: Send + Sync + 'static {} -impl AssertPool for Pool {} impl edgedb_errors::Field for QueryCapabilities { const NAME: &'static str = "capabilities"; @@ -142,11 +138,11 @@ impl PoolInner { // Make sure that connection is wrapped before we commit, // so that connection is returned into a pool if we fail // to commit because of async stuff - return Ok(PoolConnection { + Ok(PoolConnection { inner: Some(conn), permit, pool: self.clone(), - }); + }) } } diff --git a/edgedb-tokio/src/raw/queries.rs b/edgedb-tokio/src/raw/queries.rs index bfd57f30..2d4a1fd5 100644 --- a/edgedb-tokio/src/raw/queries.rs +++ b/edgedb-tokio/src/raw/queries.rs @@ -8,8 +8,8 @@ use edgedb_protocol::QueryResult; use edgedb_protocol::client_message::{ClientMessage, Parse, Prepare}; use edgedb_protocol::client_message::{DescribeStatement, DescribeAspect}; use edgedb_protocol::client_message::{Execute0, Execute1}; -use edgedb_protocol::client_message::{OptimisticExecute}; -use edgedb_protocol::common::{CompilationOptions}; +use edgedb_protocol::client_message::OptimisticExecute; +use edgedb_protocol::common::CompilationOptions; use edgedb_protocol::common::{IoFormat, Cardinality, Capabilities}; use edgedb_protocol::descriptors::Typedesc; use edgedb_protocol::features::ProtocolVersion; @@ -42,8 +42,7 @@ impl Connection { ::with_message("interrupted ping")), } } - pub(crate) fn end_request(&mut self, guard: Guard) { - drop(guard); + pub(crate) fn end_request(&mut self, _guard: Guard) { self.mode = Mode::Normal { idle_since: Instant::now() }; @@ -53,15 +52,14 @@ impl Connection { { loop { let msg = self.message().await?; - match msg { - ServerMessage::ReadyForCommand(ready) => { - self.transaction_state = ready.transaction_state; - self.end_request(guard); - return Ok(()) - } - // TODO(tailhook) should we react on messages somehow? - // At least parse LogMessage's? - _ => {}, + + // TODO(tailhook) should we react on messages somehow? + // At least parse LogMessage's? + + if let ServerMessage::ReadyForCommand(ready) = msg { + self.transaction_state = ready.transaction_state; + self.end_request(guard); + return Ok(()) } } } @@ -137,25 +135,22 @@ impl Connection { ClientMessage::Prepare(Prepare::new(flags, query)), ClientMessage::Sync, ]).await?; - - loop { - let msg = self.message().await?; - match msg { - ServerMessage::PrepareComplete(data) => { - self.expect_ready(guard).await?; - return Ok(data); - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard).await - .map_err(|e| log::warn!( - "Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", msg))); - } + + match self.message().await? { + ServerMessage::PrepareComplete(data) => { + self.expect_ready(guard).await?; + Ok(data) + } + ServerMessage::ErrorResponse(err) => { + self.expect_ready_or_eos(guard).await + .map_err(|e| log::warn!( + "Error waiting for Ready after error: {e:#}")) + .ok(); + Err(err.into()) + } + msg => { + Err(ProtocolOutOfOrderError::with_message(format!( + "Unsolicited message {:?}", msg))) } } } @@ -172,24 +167,21 @@ impl Connection { ClientMessage::Sync, ]).await?; - let desc = loop { - let msg = self.message().await?; - match msg { - ServerMessage::CommandDataDescription0(data_desc) => { - self.expect_ready(guard).await?; - break data_desc; - } - ServerMessage::ErrorResponse(err) => { - self.expect_ready_or_eos(guard).await - .map_err(|e| log::warn!( - "Error waiting for Ready after error: {e:#}")) - .ok(); - return Err(err.into()); - } - _ => { - return Err(ProtocolOutOfOrderError::with_message(format!( - "Unsolicited message {:?}", msg))); - } + let desc = match self.message().await? { + ServerMessage::CommandDataDescription0(data_desc) => { + self.expect_ready(guard).await?; + data_desc + } + ServerMessage::ErrorResponse(err) => { + self.expect_ready_or_eos(guard).await + .map_err(|e| log::warn!( + "Error waiting for Ready after error: {e:#}")) + .ok(); + return Err(err.into()); + } + msg => { + return Err(ProtocolOutOfOrderError::with_message(format!( + "Unsolicited message {:?}", msg))); } }; // normalize CommandDataDescription0 into Parse (proto 1.x) output @@ -371,7 +363,7 @@ impl Connection { let out_desc = desc.output() .map_err(ProtocolEncodingError::with_source)?; - return ResponseStream::new(self, &out_desc, guard).await; + ResponseStream::new(self, &out_desc, guard).await } pub async fn try_execute_stream(&mut self, opts: &CompilationOptions, query: &str, @@ -414,7 +406,7 @@ impl Connection { ]).await?; } - return ResponseStream::new(self, output, guard).await; + ResponseStream::new(self, output, guard).await } pub async fn statement(&mut self, flags: &CompilationOptions, query: &str, state: &dyn State) @@ -556,12 +548,12 @@ impl Connection { .map(|chunk| R::decode(&mut state, &chunk)) .collect::>() })?; - return Ok(rows) + Ok(rows) } - None => return Err(NoResultExpected::build()), + None => Err(NoResultExpected::build()), } }.await; - return result.map_err(|e| e.set::(caps)); + result.map_err(|e| e.set::(caps)) } pub async fn query_single(&mut self, query: &str, arguments: &A, @@ -604,7 +596,7 @@ impl Connection { Some(root_pos) => { let ctx = out_desc.as_queryable_context(); let mut state = R::prepare(&ctx, root_pos)?; - return response.map(|data| { + response.map(|data| { let bytes = data.into_iter().next() .and_then(|chunk| chunk.data.into_iter().next()); if let Some(bytes) = bytes { @@ -612,12 +604,12 @@ impl Connection { } else { Ok(None) } - }); + }) } - None => return Err(NoResultExpected::build()), + None => Err(NoResultExpected::build()), } }.await; - return result.map_err(|e| e.set::(caps)); + result.map_err(|e| e.set::(caps)) } pub async fn query_required_single( @@ -666,9 +658,9 @@ impl Connection { let res = self._execute( &flags, query, state, &desc, &arg_buf.freeze(), ).await?; - Ok(res.map(|_| Ok::<_, Error>(()))?) + res.map(|_| Ok::<_, Error>(())) }.await; - return result.map_err(|e| e.set::(caps)); + result.map_err(|e| e.set::(caps)) } } diff --git a/edgedb-tokio/src/raw/response.rs b/edgedb-tokio/src/raw/response.rs index b8bfb097..77dfb14b 100644 --- a/edgedb-tokio/src/raw/response.rs +++ b/edgedb-tokio/src/raw/response.rs @@ -236,7 +236,7 @@ impl<'a, T: QueryResult> ResponseStream<'a, T> .map_err(|e| log::warn!( "Error waiting for Ready after error: {e:#}")) .ok(); - self.buffer = ErrorResponse(err.into()); + self.buffer = ErrorResponse(err); return None; } Ok(msg) => { @@ -271,7 +271,7 @@ impl<'a, T: QueryResult> ResponseStream<'a, T> if let Some(desc) = self.description.take() { err = err.set::(desc); } - Err(err.into()) + Err(err) } Reset => panic!("process_complete() called twice"), } diff --git a/edgedb-tokio/src/raw/state.rs b/edgedb-tokio/src/raw/state.rs index 02b0fdf2..a23534d6 100644 --- a/edgedb-tokio/src/raw/state.rs +++ b/edgedb-tokio/src/raw/state.rs @@ -1,11 +1,10 @@ //! Connection state modification utilities use std::collections::{BTreeMap, HashMap}; -use std::default::Default; use std::sync::Arc; use arc_swap::ArcSwapOption; -use edgedb_protocol::client_message::{State as EncodedState}; +use edgedb_protocol::client_message::State as EncodedState; use edgedb_protocol::descriptors::{RawTypedesc,StateBorrow}; use edgedb_protocol::query_arg::QueryArg; use edgedb_protocol::value::Value; @@ -403,7 +402,7 @@ impl PoolState { globals: &self.raw_state.globals, })?; self.cache.store(Some(Arc::new(result.clone()))); - return Ok(result); + Ok(result) } } @@ -416,10 +415,9 @@ impl SealedState for &PoolState { } impl State for &PoolState {} impl SealedState for Arc { - fn encode(&self, desc: &RawTypedesc) - -> Result + fn encode(&self, desc: &RawTypedesc) -> Result { - (&**self).encode(desc) + PoolState::encode(self, desc) } } impl State for Arc {} @@ -433,9 +431,9 @@ impl SealedState for EncodedState { { return Ok((*self).clone()); } - return Err(ClientError::with_message( + Err(ClientError::with_message( "state doesn't match state descriptor" - )); + )) } } impl State for EncodedState {} @@ -443,7 +441,7 @@ impl SealedState for Arc { fn encode(&self, desc: &RawTypedesc) -> Result { - (&**self).encode(desc) + (**self).encode(desc) } } impl State for Arc {} diff --git a/edgedb-tokio/src/server_params.rs b/edgedb-tokio/src/server_params.rs index 45368457..3738e83a 100644 --- a/edgedb-tokio/src/server_params.rs +++ b/edgedb-tokio/src/server_params.rs @@ -11,9 +11,6 @@ use crate::sealed::SealedParam; #[derive(Debug)] pub(crate) struct ServerParams(HashMap>); -trait AssertParams: Send + Sync + 'static {} -impl AssertParams for ServerParams {} - /// Address of the underlying postgres, available only in dev mode. #[derive(Deserialize, Debug, Serialize)] pub struct PostgresAddress { diff --git a/edgedb-tokio/src/tls.rs b/edgedb-tokio/src/tls.rs index 1fda22fb..dafcbe1b 100644 --- a/edgedb-tokio/src/tls.rs +++ b/edgedb-tokio/src/tls.rs @@ -26,7 +26,7 @@ pub struct NoHostnameVerifier { impl NoHostnameVerifier { pub fn new(roots: Arc) -> Self { NoHostnameVerifier { - roots: roots, + roots, supported: ring::default_provider().signature_verification_algorithms, } } diff --git a/edgedb-tokio/src/transaction.rs b/edgedb-tokio/src/transaction.rs index 7a352d32..977446e7 100644 --- a/edgedb-tokio/src/transaction.rs +++ b/edgedb-tokio/src/transaction.rs @@ -10,7 +10,7 @@ use edgedb_protocol::query_arg::{QueryArgs, Encoder}; use tokio::sync::oneshot; use tokio::time::sleep; -use crate::errors::{ClientError}; +use crate::errors::ClientError; use crate::errors::{Error, ErrorKind, SHOULD_RETRY}; use crate::errors::{ProtocolEncodingError, NoResultExpected, NoDataError}; use crate::raw::{Pool, PoolConnection, Options, PoolState}; @@ -43,9 +43,6 @@ pub struct Inner { return_conn: oneshot::Sender, } -trait Assert: Send {} -impl Assert for Transaction {} - impl Drop for Transaction { fn drop(&mut self) { self.inner.take().map(|Inner { started, conn, return_conn }| { @@ -178,7 +175,7 @@ impl Transaction { expected_cardinality: Cardinality::Many, }; let state = self.state.clone(); // TODO: optimize, by careful borrow - let ref mut conn = self.inner().conn; + let conn = &mut self.inner().conn; let desc = conn.parse(&flags, query, &state).await?; let inp_desc = desc.input() .map_err(ProtocolEncodingError::with_source)?; @@ -241,7 +238,7 @@ impl Transaction { expected_cardinality: Cardinality::AtMostOne, }; let state = self.state.clone(); // TODO optimize using careful borrow - let ref mut conn = self.inner().conn; + let conn = &mut self.inner().conn; let desc = conn.parse(&flags, query, &state).await?; let inp_desc = desc.input() .map_err(ProtocolEncodingError::with_source)?; @@ -324,7 +321,7 @@ impl Transaction { expected_cardinality: Cardinality::Many, }; let state = self.state.clone(); // TODO optimize using careful borrow - let ref mut conn = self.inner().conn; + let conn = &mut self.inner().conn; let desc = conn.parse(&flags, query, &state).await?; let inp_desc = desc.input() .map_err(ProtocolEncodingError::with_source)?; @@ -382,7 +379,7 @@ impl Transaction { expected_cardinality: Cardinality::AtMostOne, }; let state = self.state.clone(); // TODO optimize using careful borrow - let ref mut conn = self.inner().conn; + let conn = &mut self.inner().conn; let desc = conn.parse(&flags, query, &state).await?; let inp_desc = desc.input() .map_err(ProtocolEncodingError::with_source)?; @@ -454,7 +451,7 @@ impl Transaction { expected_cardinality: Cardinality::Many, }; let state = self.state.clone(); // TODO: optimize, by careful borrow - let ref mut conn = self.inner().conn; + let conn = &mut self.inner().conn; let desc = conn.parse(&flags, query, &state).await?; let inp_desc = desc.input() .map_err(ProtocolEncodingError::with_source)?; @@ -470,12 +467,3 @@ impl Transaction { } } -#[allow(dead_code, unreachable_code)] -fn _transaction_assertions() { - let _cli: crate::Client = unimplemented!(); - assert_send( - _cli.transaction(|mut tx| async move { tx.query_json("SELECT 'hello'", &()).await }), - ); -} - -fn assert_send(_: T) {} diff --git a/edgedb-tokio/tests/func/globals.rs b/edgedb-tokio/tests/func/globals.rs index a4abde72..3ce7f2e5 100644 --- a/edgedb-tokio/tests/func/globals.rs +++ b/edgedb-tokio/tests/func/globals.rs @@ -10,27 +10,28 @@ async fn global_fn() -> anyhow::Result<()> { let value = client .with_default_module(Some("test")) .with_globals_fn(|m| m.set("str_val", "hello")) - .query::("SELECT (global str_val)", &()).await?; + .query::("SELECT (global str_val)", &()) + .await?; assert_eq!(value, vec![String::from("hello")]); let value = client .with_default_module(Some("test")) .with_globals_fn(|m| m.set("int_val", 127)) - .query::("SELECT (global int_val)", &()).await?; + .query::("SELECT (global int_val)", &()) + .await?; assert_eq!(value, vec![127]); Ok(()) } -#[cfg(feature="derive")] +#[derive(edgedb_derive::GlobalsDelta)] +struct Globals { + str_val: &'static str, + int_val: i32, +} + +#[cfg(feature = "derive")] #[tokio::test] async fn global_struct() -> anyhow::Result<()> { - - #[derive(edgedb_derive::GlobalsDelta)] - struct Globals { - str_val: &'static str, - int_val: i32, - } - let client = Client::new(&SERVER.config); client.ensure_connected().await?; @@ -40,7 +41,8 @@ async fn global_struct() -> anyhow::Result<()> { str_val: "value1", int_val: 345, }) - .query::("SELECT (global str_val)", &()).await?; + .query::("SELECT (global str_val)", &()) + .await?; assert_eq!(value, vec![String::from("value1")]); let value = client @@ -49,7 +51,8 @@ async fn global_struct() -> anyhow::Result<()> { str_val: "value2", int_val: 678, }) - .query::("SELECT (global int_val)", &()).await?; + .query::("SELECT (global int_val)", &()) + .await?; assert_eq!(value, vec![678]); Ok(()) } diff --git a/edgedb-tokio/tests/func/main.rs b/edgedb-tokio/tests/func/main.rs index bd21325a..282fdcad 100644 --- a/edgedb-tokio/tests/func/main.rs +++ b/edgedb-tokio/tests/func/main.rs @@ -1,7 +1,7 @@ #[cfg(not(windows))] mod server; -#[cfg(all(not(windows), features="unstable"))] +#[cfg(all(not(windows), feature="unstable"))] mod raw; #[cfg(not(windows))] diff --git a/edgedb-tokio/tests/func/raw.rs b/edgedb-tokio/tests/func/raw.rs index 684fbcbc..8d2f914c 100644 --- a/edgedb-tokio/tests/func/raw.rs +++ b/edgedb-tokio/tests/func/raw.rs @@ -1,8 +1,10 @@ +use std::sync::Arc; + use bytes::Bytes; -use edgedb_tokio::raw::Pool; -use edgedb_protocol::common::{CompliationOptions, IoFormat, Cardinality}; -use edgedb_protocol::common::{Capabilities}; +use edgedb_protocol::common::Capabilities; +use edgedb_protocol::common::{Cardinality, CompilationOptions, IoFormat}; +use edgedb_tokio::raw::{Pool, PoolState}; use crate::server::SERVER; @@ -11,7 +13,9 @@ async fn poll_connect() -> anyhow::Result<()> { let pool = Pool::new(&SERVER.config); let mut conn = pool.acquire().await?; assert!(conn.is_consistent()); - let _prepare = conn.prepare(&CompliationOptions { + + let state = Arc::new(PoolState::default()); + let options = CompilationOptions { implicit_limit: None, implicit_typenames: false, implicit_typeids: false, @@ -19,11 +23,12 @@ async fn poll_connect() -> anyhow::Result<()> { explicit_objectids: true, io_format: IoFormat::Binary, expected_cardinality: Cardinality::Many, - }, "SELECT 7*8").await; - assert!(conn.is_consistent()); - let _descr = conn.describe_data().await?; + }; + + let desc = conn.parse(&options, "SELECT 7*8", &state).await?; assert!(conn.is_consistent()); - let data = conn.execute(&Bytes::new()).await?; + + let data = conn.execute(&options, "SELECT 7*8", &state, &desc, &Bytes::new()).await?; assert!(conn.is_consistent()); assert_eq!(data.len(), 1); assert_eq!(data[0].data.len(), 1); diff --git a/edgedb-tokio/tests/func/server.rs b/edgedb-tokio/tests/func/server.rs index f7bf857d..cad82632 100644 --- a/edgedb-tokio/tests/func/server.rs +++ b/edgedb-tokio/tests/func/server.rs @@ -1,19 +1,17 @@ use std::env; use std::fs::{self, File}; -use std::io::{BufReader, BufRead}; +use std::io::{BufRead, BufReader}; use std::os::unix::io::FromRawFd; use std::process; use std::sync::Mutex; use command_fds::{CommandFdExt, FdMapping}; use once_cell::sync::Lazy; -use shutdown_hooks; use edgedb_tokio::{Builder, Config}; -pub static SHUTDOWN_INFO: Lazy>> = - Lazy::new(|| Mutex::new(Vec::new())); -pub static SERVER: Lazy = Lazy::new(|| ServerGuard::start()); +pub static SHUTDOWN_INFO: Lazy>> = Lazy::new(|| Mutex::new(Vec::new())); +pub static SERVER: Lazy = Lazy::new(ServerGuard::start); pub struct ShutdownInfo { process: process::Child, @@ -29,12 +27,8 @@ pub struct ServerInfo { tls_cert_file: String, } - impl ServerGuard { fn start() -> ServerGuard { - ServerGuard::_start().expect("can run server") - } - fn _start() -> anyhow::Result { use std::process::Command; let bin_name = if let Ok(ver) = env::var("EDGEDB_MAJOR_VERSION") { @@ -42,7 +36,7 @@ impl ServerGuard { } else { "edgedb-server".to_string() }; - let (pipe_read, pipe_write) = nix::unistd::pipe()?; + let (pipe_read, pipe_write) = nix::unistd::pipe().unwrap(); let mut cmd = Command::new(&bin_name); cmd.env("EDGEDB_SERVER_SECURITY", "insecure_dev_mode"); cmd.arg("--temp-dir"); @@ -50,9 +44,11 @@ impl ServerGuard { cmd.arg("--emit-server-status=fd://3"); cmd.arg("--port=auto"); cmd.arg("--tls-cert-mode=generate_self_signed"); - cmd.fd_mappings(vec![ - FdMapping { parent_fd: pipe_write, child_fd: 3 } - ])?; + cmd.fd_mappings(vec![FdMapping { + parent_fd: pipe_write, + child_fd: 3, + }]) + .unwrap(); if nix::unistd::Uid::effective().as_raw() == 0 { use std::os::unix::process::CommandExt; @@ -61,16 +57,15 @@ impl ServerGuard { cmd.uid(1); } - let process = cmd.spawn() - .expect(&format!("Can run {}", bin_name)); + let process = cmd.spawn().unwrap_or_else(|_| panic!("Can run {}", bin_name)); let pipe = BufReader::new(unsafe { File::from_raw_fd(pipe_read) }); let mut result = Err(anyhow::anyhow!("no server info emitted")); for line in pipe.lines() { match line { Ok(line) => { if let Some(data) = line.strip_prefix("READY=") { - let data: ServerInfo = serde_json::from_str(data) - .expect("valid server data"); + let data: ServerInfo = + serde_json::from_str(data).expect("valid server data"); println!("Server data {:?}", data); result = Ok(data); break; @@ -89,36 +84,48 @@ impl ServerGuard { shutdown_hooks::add_shutdown_hook(stop_processes); } sinfo.push(ShutdownInfo { process }); - let info = result?; + let info = result.unwrap(); fs::remove_file("tests/func/dbschema/migrations/00001.edgeql").ok(); assert!(Command::new("edgedb") .current_dir("./tests/func") - .arg("--tls-security").arg("insecure") - .arg("--port").arg(info.port.to_string()) + .arg("--tls-security") + .arg("insecure") + .arg("--port") + .arg(info.port.to_string()) .arg("migration") .arg("create") .arg("--non-interactive") - .status()?.success()); + .status() + .expect("cannot run edgedb-cli to create a migration") + .success()); + dbg!("2"); assert!(Command::new("edgedb") .current_dir("./tests/func") - .arg("--tls-security").arg("insecure") - .arg("--port").arg(info.port.to_string()) + .arg("--tls-security") + .arg("insecure") + .arg("--port") + .arg(info.port.to_string()) .arg("migration") .arg("apply") - .status()?.success()); + .status() + .expect("cannot run edgedb-cli to apply a migration") + .success()); - let cert_data = std::fs::read_to_string(&info.tls_cert_file) - .expect("cert file should be readable"); + let cert_data = + std::fs::read_to_string(&info.tls_cert_file).expect("cert file should be readable"); let config = Builder::new() - .port(info.port)? - .pem_certificates(&cert_data)? - .constrained_build()?; - Ok(ServerGuard { config }) + .port(info.port) + .unwrap() + .pem_certificates(&cert_data) + .unwrap() + .constrained_build() + .unwrap(); + ServerGuard { config } } } -extern fn stop_processes() { +extern "C" fn stop_processes() { let mut items = SHUTDOWN_INFO.lock().expect("shutdown mutex works"); for item in items.iter_mut() { term_process(&mut item.process); @@ -129,12 +136,10 @@ extern fn stop_processes() { } fn term_process(proc: &mut process::Child) { - use nix::unistd::Pid; use nix::sys::signal::{self, Signal}; + use nix::unistd::Pid; - if let Err(e) = signal::kill( - Pid::from_raw(proc.id() as i32), Signal::SIGTERM - ) { + if let Err(e) = signal::kill(Pid::from_raw(proc.id() as i32), Signal::SIGTERM) { eprintln!("could not send SIGTERM to edgedb-server: {:?}", e); }; }