diff --git a/src/query.rs b/src/query.rs index 7d02a60..a54164c 100644 --- a/src/query.rs +++ b/src/query.rs @@ -10,7 +10,7 @@ use crate::{ request_body::RequestBody, response::Response, row::Row, - sql::{Bind, SqlBuilder}, + sql::{Bind, SqlBuilder, ser}, Client, }; @@ -196,9 +196,9 @@ impl Query { self } - pub fn with_param(self, name: &str, value: impl Bind + Serialize) -> Result { + pub fn with_param(self, name: &str, value: T) -> Result where T: Serialize { let mut param = String::from(""); - Bind::write(&value, &mut param)?; + ser::write_param(&mut param, &value)?; Ok(self.with_option(format!("param_{name}"), param)) } } diff --git a/src/sql/mod.rs b/src/sql/mod.rs index 7417be7..66330f6 100644 --- a/src/sql/mod.rs +++ b/src/sql/mod.rs @@ -9,7 +9,7 @@ pub use bind::{Bind, Identifier}; mod bind; pub(crate) mod escape; -mod ser; +pub(crate) mod ser; #[derive(Debug, Clone)] pub(crate) enum SqlBuilder { diff --git a/src/sql/ser.rs b/src/sql/ser.rs index 00ea606..46541e5 100644 --- a/src/sql/ser.rs +++ b/src/sql/ser.rs @@ -8,23 +8,23 @@ use thiserror::Error; use super::escape; -// === SqlSerializerError === +// === SerializerError === #[derive(Debug, Error)] -enum SqlSerializerError { +enum SerializerError { #[error("{0} is unsupported")] Unsupported(&'static str), #[error("{0}")] Custom(String), } -impl ser::Error for SqlSerializerError { +impl ser::Error for SerializerError { fn custom(msg: T) -> Self { Self::Custom(msg.to_string()) } } -impl From for SqlSerializerError { +impl From for SerializerError { fn from(err: fmt::Error) -> Self { Self::Custom(err.to_string()) } @@ -32,8 +32,8 @@ impl From for SqlSerializerError { // === SqlSerializer === -type Result = std::result::Result; -type Impossible = ser::Impossible<(), SqlSerializerError>; +type Result = std::result::Result; +type Impossible = ser::Impossible<(), SerializerError>; struct SqlSerializer<'a, W> { writer: &'a mut W, @@ -43,7 +43,7 @@ macro_rules! unsupported { ($ser_method:ident($ty:ty) -> $ret:ty, $($other:tt)*) => { #[inline] fn $ser_method(self, _v: $ty) -> $ret { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -53,7 +53,7 @@ macro_rules! unsupported { ($ser_method:ident, $($other:tt)*) => { #[inline] fn $ser_method(self) -> Result { - Err(SqlSerializerError::Unsupported(stringify!($ser_method))) + Err(SerializerError::Unsupported(stringify!($ser_method))) } unsupported!($($other)*); }; @@ -73,7 +73,7 @@ macro_rules! forward_to_display { } impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); type SerializeMap = Impossible; type SerializeSeq = SqlListSerializer<'a, W>; @@ -177,12 +177,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _value: &T, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_newtype_variant")) + Err(SerializerError::Unsupported("serialize_newtype_variant")) } #[inline] fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_struct")) + Err(SerializerError::Unsupported("serialize_tuple_struct")) } #[inline] @@ -193,12 +193,12 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_tuple_variant")) + Err(SerializerError::Unsupported("serialize_tuple_variant")) } #[inline] fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct")) + Err(SerializerError::Unsupported("serialize_struct")) } #[inline] @@ -209,7 +209,7 @@ impl<'a, W: Write> Serializer for SqlSerializer<'a, W> { _variant: &'static str, _len: usize, ) -> Result { - Err(SqlSerializerError::Unsupported("serialize_struct_variant")) + Err(SerializerError::Unsupported("serialize_struct_variant")) } #[inline] @@ -227,7 +227,7 @@ struct SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -254,7 +254,7 @@ impl<'a, W: Write> SerializeSeq for SqlListSerializer<'a, W> { } impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { - type Error = SqlSerializerError; + type Error = SerializerError; type Ok = (); #[inline] @@ -271,6 +271,160 @@ impl<'a, W: Write> SerializeTuple for SqlListSerializer<'a, W> { } } +// === ParamSerializer === + +struct ParamSerializer<'a, W> { + writer: &'a mut W, +} + +impl<'a, W: Write> Serializer for ParamSerializer<'a, W> { + type Error = SerializerError; + type Ok = (); + type SerializeMap = Impossible; + type SerializeSeq = SqlListSerializer<'a, W>; + type SerializeStruct = Impossible; + type SerializeStructVariant = Impossible; + type SerializeTuple = SqlListSerializer<'a, W>; + type SerializeTupleStruct = Impossible; + type SerializeTupleVariant = Impossible; + + unsupported!( + serialize_map(Option) -> Result, + serialize_bytes(&[u8]), + serialize_unit, + serialize_unit_struct(&'static str), + ); + + forward_to_display!( + serialize_i8(i8), + serialize_i16(i16), + serialize_i32(i32), + serialize_i64(i64), + serialize_i128(i128), + serialize_u8(u8), + serialize_u16(u16), + serialize_u32(u32), + serialize_u64(u64), + serialize_u128(u128), + serialize_f32(f32), + serialize_f64(f64), + serialize_bool(bool), + ); + + #[inline] + fn serialize_char(self, value: char) -> Result { + let mut tmp = [0u8; 4]; + self.serialize_str(value.encode_utf8(&mut tmp)) + } + + #[inline] + fn serialize_str(self, value: &str) -> Result { + // ClickHouse expects strings in params to be unquoted until inside a nested type + // nested types go through serialize_seq which'll quote strings + self.writer.write_str(value)?; + Ok(()) + } + + #[inline] + fn serialize_seq(self, _len: Option) -> Result> { + self.writer.write_char('[')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ']', + }) + } + + #[inline] + fn serialize_tuple(self, _len: usize) -> Result> { + self.writer.write_char('(')?; + Ok(SqlListSerializer { + writer: self.writer, + has_items: false, + closing_char: ')', + }) + } + + #[inline] + fn serialize_some(self, _value: &T) -> Result { + _value.serialize(self) + } + + #[inline] + fn serialize_none(self) -> std::result::Result { + self.writer.write_str("NULL")?; + Ok(()) + } + + #[inline] + fn serialize_unit_variant( + self, + _name: &'static str, + _variant_index: u32, + variant: &'static str, + ) -> Result { + escape::string(variant, self.writer)?; + Ok(()) + } + + #[inline] + fn serialize_newtype_struct( + self, + _name: &'static str, + value: &T, + ) -> Result { + value.serialize(self) + } + + #[inline] + fn serialize_newtype_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _value: &T, + ) -> Result { + Err(SerializerError::Unsupported("serialize_newtype_variant")) + } + + #[inline] + fn serialize_tuple_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_struct")) + } + + #[inline] + fn serialize_tuple_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_tuple_variant")) + } + + #[inline] + fn serialize_struct(self, _name: &'static str, _len: usize) -> Result { + Err(SerializerError::Unsupported("serialize_struct")) + } + + #[inline] + fn serialize_struct_variant( + self, + _name: &'static str, + _variant_index: u32, + _variant: &'static str, + _len: usize, + ) -> Result { + Err(SerializerError::Unsupported("serialize_struct_variant")) + } + + #[inline] + fn is_human_readable(&self) -> bool { + true + } +} + // === Public API === pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { @@ -279,6 +433,12 @@ pub(crate) fn write_arg(writer: &mut impl Write, value: &impl Serialize) -> Resu .map_err(|err| err.to_string()) } +pub(crate) fn write_param(writer: &mut impl Write, value: &impl Serialize) -> Result<(), String> { + value + .serialize(ParamSerializer { writer }) + .map_err(|err| err.to_string()) +} + #[cfg(test)] mod tests { use super::*;