From 630527170d6744cee3f8ed088dd6ab4b23e66a24 Mon Sep 17 00:00:00 2001 From: MaximFischuk Date: Fri, 12 Jan 2024 12:16:10 +0200 Subject: [PATCH 1/2] Add GetAllChildrenNumber functionality --- src/proto.rs | 24 ++++ src/zookeeper.rs | 33 +++++- tests/test_get_all_children_number.rs | 160 ++++++++++++++++++++++++++ 3 files changed, 214 insertions(+), 3 deletions(-) create mode 100644 tests/test_get_all_children_number.rs diff --git a/src/proto.rs b/src/proto.rs index 95e256ea8..bef53fa80 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -21,6 +21,8 @@ pub enum OpCode { Create2 = 15, CreateTtl = 21, + + GetAllChildrenNumber = 104, } pub type ByteBuf = Cursor>; @@ -433,6 +435,28 @@ impl ReadFrom for GetChildrenResponse { } } +pub struct GetAllChildrenNumberRequest { + pub path: String, +} + +impl WriteTo for GetAllChildrenNumberRequest { + fn write_to(&self, writer: &mut dyn Write) -> Result<()> { + self.path.write_to(writer) + } +} + +pub struct GetAllChildrenNumberResponse { + pub total_number: i32, +} + +impl ReadFrom for GetAllChildrenNumberResponse { + fn read_from(reader: &mut R) -> Result { + Ok(GetAllChildrenNumberResponse { + total_number: reader.read_i32::()?, + }) + } +} + pub type GetDataRequest = StringAndBoolRequest; pub struct GetDataResponse { diff --git a/src/zookeeper.rs b/src/zookeeper.rs index 62b937c0a..00c624db2 100644 --- a/src/zookeeper.rs +++ b/src/zookeeper.rs @@ -15,9 +15,10 @@ use crate::listeners::ListenerSet; use crate::proto::{ to_len_prefixed_buf, AuthRequest, ByteBuf, Create2Response, CreateRequest, CreateResponse, CreateTTLRequest, DeleteRequest, EmptyRequest, EmptyResponse, ExistsRequest, ExistsResponse, - GetAclRequest, GetAclResponse, GetChildrenRequest, GetChildrenResponse, GetDataRequest, - GetDataResponse, OpCode, ReadFrom, ReplyHeader, RequestHeader, SetAclRequest, SetAclResponse, - SetDataRequest, SetDataResponse, WriteTo, + GetAclRequest, GetAclResponse, GetAllChildrenNumberRequest, GetAllChildrenNumberResponse, + GetChildrenRequest, GetChildrenResponse, GetDataRequest, GetDataResponse, OpCode, ReadFrom, + ReplyHeader, RequestHeader, SetAclRequest, SetAclResponse, SetDataRequest, SetDataResponse, + WriteTo, }; use crate::watch::ZkWatch; use crate::{ @@ -585,6 +586,32 @@ impl ZooKeeper { Ok(response.children) } + /// Return the number of children of the node of the given `path`. + /// This operation returns the number of children recursively. + /// For example, given the following tree: + /// ```text + /// /test + /// /test/child1 + /// /test/child2/child21 + /// /test/child3/child31/child311 + /// ``` + /// The number of children of `/test` is 6. + /// + /// # Errors + /// If no node with the given path exists, `Err(ZkError::NoNode)` will be returned. + pub async fn get_all_children_number(&self, path: &str) -> ZkResult { + trace!("ZooKeeper::get_all_children_number"); + let req = GetAllChildrenNumberRequest { + path: self.path(path)?, + }; + + let response: GetAllChildrenNumberResponse = self + .request(OpCode::GetAllChildrenNumber, self.xid(), req, None) + .await?; + + Ok(response.total_number) + } + /// Return the data and the `Stat` of the node of the given path. /// /// If `watch` is `true` and the call is successful (no error is returned), a watch will be left diff --git a/tests/test_get_all_children_number.rs b/tests/test_get_all_children_number.rs new file mode 100644 index 000000000..80f824b65 --- /dev/null +++ b/tests/test_get_all_children_number.rs @@ -0,0 +1,160 @@ +use crate::test::ZkCluster; +use std::time::Duration; +use tracing::info; +use zookeeper_async::{Acl, CreateMode, WatchedEvent, Watcher, ZooKeeper}; + +mod test; + +struct LogWatcher; + +impl Watcher for LogWatcher { + fn handle(&self, event: WatchedEvent) { + info!("{:?}", event); + } +} + +async fn create_zk(connection_string: &str) -> ZooKeeper { + ZooKeeper::connect(connection_string, Duration::from_secs(10), LogWatcher) + .await + .unwrap() +} + +#[tokio::test] +async fn zk_get_all_children_number_test() { + // Create a test cluster + let mut cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + // Do the tests + let _ = zk + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + // create few children of /test + let _ = zk + .create( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child2", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child3", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + + let children = zk.get_all_children_number("/test").await; + + assert_eq!( + children, + Ok(3), + "get_all_children_number failed: {:?}", + children + ); + + cluster.kill_an_instance(); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_get_all_children_number_with_subtree_test() { + // Create a test cluster + let mut cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + // Do the tests + let _ = zk + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + // create few children of /test + let _ = zk + .create( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child2", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child2/child21", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child3", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child3/child31", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + let _ = zk + .create( + "/test/child3/child31/child311", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .await; + + let children = zk.get_all_children_number("/test").await; + + assert_eq!( + children, + Ok(6), + "get_all_children_number failed: {:?}", + children + ); + + cluster.kill_an_instance(); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} From be2eae842a36a2d203a5a3534d43edeb165bc659 Mon Sep 17 00:00:00 2001 From: MaximFischuk Date: Fri, 31 May 2024 13:55:58 +0300 Subject: [PATCH 2/2] feat: Add implementation of multple writes (transactions) and atomic reads --- examples/transactions.rs | 106 ++++++++++++++++ src/consts.rs | 3 + src/lib.rs | 2 + src/multi_op.rs | 162 +++++++++++++++++++++++ src/proto.rs | 208 ++++++++++++++++++++++++++++++ src/zookeeper.rs | 125 +++++++++++++++++- tests/test_multi.rs | 268 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 870 insertions(+), 4 deletions(-) create mode 100644 examples/transactions.rs create mode 100644 src/multi_op.rs create mode 100644 tests/test_multi.rs diff --git a/examples/transactions.rs b/examples/transactions.rs new file mode 100644 index 000000000..80fa83ccc --- /dev/null +++ b/examples/transactions.rs @@ -0,0 +1,106 @@ +use std::{env, time::Duration}; + +use zookeeper_async::{WatchedEvent, Watcher, ZooKeeper}; + +struct LoggingWatcher; +impl Watcher for LoggingWatcher { + fn handle(&self, e: WatchedEvent) { + println!("{:?}", e) + } +} + +#[tokio::main] +async fn main() { + let zk_urls = zk_server_urls(); + println!("connecting to {}", zk_urls); + + let zk = ZooKeeper::connect(&zk_urls, Duration::from_secs(15), LoggingWatcher) + .await + .unwrap(); + + // Create transaction that creates a node and a child node + let results = zk + .transaction() + .create( + "/test", + vec![], + zookeeper_async::Acl::open_unsafe().clone(), + zookeeper_async::CreateMode::Persistent, + ) + .create( + "/test/child1", + vec![], + zookeeper_async::Acl::open_unsafe().clone(), + zookeeper_async::CreateMode::Persistent, + ) + // Check that the node exists + .check("/test", None) + .commit() + .await + .unwrap(); + + for result in results { + println!("{:?}", result); + } + + // Create transaction that sets data on a node + let results = zk + .transaction() + .create( + "/test2", + vec![], + zookeeper_async::Acl::open_unsafe().clone(), + zookeeper_async::CreateMode::Persistent, + ) + .set_data("/test2", vec![1, 2, 3], None) + .create( + "/test2/child1", + vec![], + zookeeper_async::Acl::open_unsafe().clone(), + zookeeper_async::CreateMode::Persistent, + ) + .set_data("/test2/child1", vec![4, 5, 6], None) + .commit() + .await + .unwrap(); + + for result in results { + println!("{:?}", result); + } + + // Read the data from the node + let results = zk + .read() + .get_data("/test2", false) + .get_data("/test2/child1", false) + .execute() + .await + .unwrap(); + + for result in results { + println!("{:?}", result); + } + + // Delete all nodes + let results = zk + .transaction() + .delete("/test/child1", None) + .delete("/test", None) + .delete("/test2/child1", None) + .delete("/test2", None) + .commit() + .await + .unwrap(); + + for result in results { + println!("{:?}", result); + } +} + +fn zk_server_urls() -> String { + let key = "ZOOKEEPER_SERVERS"; + match env::var(key) { + Ok(val) => val, + Err(_) => "localhost:2181".to_string(), + } +} diff --git a/src/consts.rs b/src/consts.rs index 4d1784b6b..0fdb1263c 100644 --- a/src/consts.rs +++ b/src/consts.rs @@ -7,6 +7,9 @@ use num_enum::*; )] #[repr(i32)] pub enum ZkError { + /// Operation completed successfully. + /// This code is used to indicate success in transaction operations. + Ok = 0, /// This code is never returned from the server. It should not be used other than to indicate a /// range. Specifically error codes greater than this value are API errors (while values less /// than this indicate a system error). diff --git a/src/lib.rs b/src/lib.rs index dc4cc961a..840eb679f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,6 +5,7 @@ mod consts; mod data; mod io; mod listeners; +mod multi_op; mod paths; mod proto; pub mod recipes; @@ -16,6 +17,7 @@ pub use self::zookeeper::{ZkResult, ZooKeeper}; pub use acl::*; pub use consts::*; pub use data::*; +pub use multi_op::*; pub use watch::{Watch, WatchType, WatchedEvent, Watcher}; pub use zookeeper_ext::ZooKeeperExt; diff --git a/src/multi_op.rs b/src/multi_op.rs new file mode 100644 index 000000000..949c218ec --- /dev/null +++ b/src/multi_op.rs @@ -0,0 +1,162 @@ +use std::time::Duration; + +use crate::{ + proto::{ + CheckRequest, CreateRequest, CreateTTLRequest, DeleteRequest, GetDataRequest, Op, + SetDataRequest, + }, + Acl, CreateMode, Stat, ZkResult, ZooKeeper, +}; + +#[derive(Debug)] +pub enum OperationResult { + Create(String), + Create2(String, Stat), + CreateTtl(String, Stat), + SetData(Stat), + Delete, + Check, +} + +#[derive(Debug)] +pub enum ReadOperationResult { + GetData(Vec, Stat), + GetChildren(Vec), +} + +pub struct Transaction<'a> { + zookeeper: &'a ZooKeeper, + operations: Vec, +} + +pub struct Read<'a> { + zookeeper: &'a ZooKeeper, + operations: Vec, +} + +impl<'a> Transaction<'a> { + pub fn new(zookeeper: &'a ZooKeeper) -> Self { + Self { + zookeeper, + operations: Vec::new(), + } + } + + /// See [ZooKeeper::create] + pub fn create(mut self, path: &str, data: Vec, acl: Vec, mode: CreateMode) -> Self { + self.operations.push(Op::Create(CreateRequest { + path: path.to_string(), + data, + acl, + flags: mode as i32, + })); + self + } + + /// See [ZooKeeper::create2] + pub fn create2(mut self, path: &str, data: Vec, acl: Vec, mode: CreateMode) -> Self { + self.operations.push(Op::Create2(CreateRequest { + path: path.to_string(), + data, + acl, + flags: mode as i32, + })); + self + } + + /// See [ZooKeeper::create_ttl] + pub fn create_ttl( + mut self, + path: &str, + data: Vec, + acl: Vec, + mode: CreateMode, + ttl: Duration, + ) -> Self { + self.operations.push(Op::CreateTtl(CreateTTLRequest { + path: path.to_string(), + data, + acl, + flags: mode as i32, + ttl: ttl.as_millis() as i64, + })); + self + } + + /// See [ZooKeeper::set_data] + pub fn set_data(mut self, path: &str, data: Vec, version: Option) -> Self { + self.operations.push(Op::SetData(SetDataRequest { + path: path.to_string(), + data, + version: version.unwrap_or(-1), + })); + self + } + + /// See [ZooKeeper::delete] + pub fn delete(mut self, path: &str, version: Option) -> Self { + self.operations.push(Op::Delete(DeleteRequest { + path: path.to_string(), + version: version.unwrap_or(-1), + })); + self + } + + /// Check if the path exists and the version matches. If the version is not provided, it will + /// check if the path exists. + pub fn check(mut self, path: &str, version: Option) -> Self { + self.operations.push(Op::Check(CheckRequest { + path: path.to_string(), + version: version.unwrap_or(-1), + })); + self + } + + /// Commit the transaction + /// + /// # Errors + /// + /// If any of the operations fail, the first error will be returned. + /// + /// See [ZooKeeper] for more information on errors. + /// See [crate::ZkError] for list of possible errrors. + pub async fn commit(self) -> ZkResult> { + self.zookeeper.multi(self.operations).await + } +} + +impl<'a> Read<'a> { + pub fn new(zookeeper: &'a ZooKeeper) -> Self { + Self { + zookeeper, + operations: Vec::new(), + } + } + /// See [ZooKeeper::get_data] + pub fn get_data(mut self, path: &str, watch: bool) -> Self { + self.operations.push(Op::GetData(GetDataRequest { + path: path.to_string(), + watch, + })); + self + } + + /// See [ZooKeeper::get_children] + pub fn get_children(mut self, path: &str, watch: bool) -> Self { + self.operations.push(Op::GetChildren(GetDataRequest { + path: path.to_string(), + watch, + })); + self + } + + /// # Errors + /// + /// If any of the operations fail, the first error will be returned. + /// + /// See [ZooKeeper] for more information on errors. + /// See [crate::ZkError] for list of possible errrors. + pub async fn execute(self) -> ZkResult> { + self.zookeeper.multi_read(self.operations).await + } +} diff --git a/src/proto.rs b/src/proto.rs index bef53fa80..2cae55dfe 100644 --- a/src/proto.rs +++ b/src/proto.rs @@ -19,10 +19,43 @@ pub enum OpCode { Ping = 11, CloseSession = -11, + Check = 13, + Multi = 14, Create2 = 15, CreateTtl = 21, + MultiRead = 22, GetAllChildrenNumber = 104, + + Error = -1, +} + +impl TryFrom for OpCode { + type Error = Error; + + fn try_from(value: i32) -> Result { + match value { + 1 => Ok(OpCode::Create), + 2 => Ok(OpCode::Delete), + 3 => Ok(OpCode::Exists), + 4 => Ok(OpCode::GetData), + 5 => Ok(OpCode::SetData), + 6 => Ok(OpCode::GetAcl), + 7 => Ok(OpCode::SetAcl), + 8 => Ok(OpCode::GetChildren), + 11 => Ok(OpCode::Ping), + -11 => Ok(OpCode::CloseSession), + 13 => Ok(OpCode::Check), + 14 => Ok(OpCode::Multi), + 15 => Ok(OpCode::Create2), + 21 => Ok(OpCode::CreateTtl), + 22 => Ok(OpCode::MultiRead), + 100 => Ok(OpCode::Auth), + 104 => Ok(OpCode::GetAllChildrenNumber), + -1 => Ok(OpCode::Error), + _ => Err(error("Invalid op code")), + } + } } pub type ByteBuf = Cursor>; @@ -418,6 +451,168 @@ impl WriteTo for SetDataRequest { pub type SetDataResponse = StatResponse; +pub enum Op { + Create(CreateRequest), + Create2(CreateRequest), + CreateTtl(CreateTTLRequest), + SetData(SetDataRequest), + Delete(DeleteRequest), + Check(CheckRequest), + + // Read operations + GetData(StringAndBoolRequest), + GetChildren(StringAndBoolRequest), +} + +struct MultiHeader { + op: i32, + done: bool, + err: i32, +} + +impl MultiHeader { + pub fn new(op: OpCode) -> MultiHeader { + MultiHeader { + op: op as i32, + done: false, + err: -1, + } + } + + pub fn done() -> MultiHeader { + MultiHeader { + op: -1, + done: true, + err: -1, + } + } + + pub fn is_done(&self) -> bool { + self.done + } +} + +impl WriteTo for MultiHeader { + fn write_to(&self, writer: &mut dyn Write) -> Result<()> { + writer.write_i32::(self.op)?; + writer.write_u8(self.done as u8)?; + writer.write_i32::(self.err)?; + Ok(()) + } +} + +impl ReadFrom for MultiHeader { + fn read_from(reader: &mut R) -> Result { + Ok(MultiHeader { + op: reader.read_i32::()?, + done: reader.read_u8()? != 0, + err: reader.read_i32::()?, + }) + } +} + +pub struct MultiRequest { + pub ops: Vec, +} + +impl WriteTo for MultiRequest { + fn write_to(&self, writer: &mut dyn Write) -> Result<()> { + for op in &self.ops { + match op { + Op::Create(req) => { + MultiHeader::new(OpCode::Create).write_to(writer)?; + req.write_to(writer)?; + } + Op::Create2(req) => { + MultiHeader::new(OpCode::Create2).write_to(writer)?; + req.write_to(writer)?; + } + Op::CreateTtl(req) => { + MultiHeader::new(OpCode::CreateTtl).write_to(writer)?; + req.write_to(writer)?; + } + Op::SetData(req) => { + MultiHeader::new(OpCode::SetData).write_to(writer)?; + req.write_to(writer)?; + } + Op::Delete(req) => { + MultiHeader::new(OpCode::Delete).write_to(writer)?; + req.write_to(writer)?; + } + Op::Check(req) => { + MultiHeader::new(OpCode::Check).write_to(writer)?; + req.write_to(writer)?; + } + Op::GetData(req) => { + MultiHeader::new(OpCode::GetData).write_to(writer)?; + req.write_to(writer)?; + } + Op::GetChildren(req) => { + MultiHeader::new(OpCode::GetChildren).write_to(writer)?; + req.write_to(writer)?; + } + } + } + + MultiHeader::done().write_to(writer)?; + Ok(()) + } +} + +pub enum OpResponse { + Create(CreateResponse), + Create2(Create2Response), + CreateTtl(Create2Response), + SetData(SetDataResponse), + Delete, + Check, + + // Read operations + GetData(GetDataResponse), + GetChildren(GetChildrenResponse), + + // Error + Error(i32), +} + +pub struct MultiResponse { + pub results: Vec, +} + +impl ReadFrom for MultiResponse { + fn read_from(reader: &mut R) -> Result { + let mut results = Vec::new(); + loop { + let header = MultiHeader::read_from(reader)?; + if header.is_done() { + break; + } + + let result = match OpCode::try_from(header.op) { + Ok(OpCode::Create) => OpResponse::Create(CreateResponse::read_from(reader)?), + Ok(OpCode::Create2) => OpResponse::Create2(Create2Response::read_from(reader)?), + Ok(OpCode::CreateTtl) => OpResponse::CreateTtl(Create2Response::read_from(reader)?), + Ok(OpCode::SetData) => OpResponse::SetData(SetDataResponse::read_from(reader)?), + Ok(OpCode::Delete) => OpResponse::Delete, + Ok(OpCode::Check) => OpResponse::Check, + Ok(OpCode::GetData) => OpResponse::GetData(GetDataResponse::read_from(reader)?), + Ok(OpCode::GetChildren) => { + OpResponse::GetChildren(GetChildrenResponse::read_from(reader)?) + } + Ok(OpCode::Error) => { + let err = reader.read_i32::()?; + OpResponse::Error(err) + } + _ => return Err(error("Invalid op code")), + }; + + results.push(result); + } + + Ok(MultiResponse { results }) + } +} + pub type GetChildrenRequest = StringAndBoolRequest; pub struct GetChildrenResponse { @@ -518,3 +713,16 @@ impl ReadFrom for WatchedEvent { }) } } + +pub struct CheckRequest { + pub path: String, + pub version: i32, +} + +impl WriteTo for CheckRequest { + fn write_to(&self, writer: &mut dyn Write) -> Result<()> { + self.path.write_to(writer)?; + writer.write_i32::(self.version)?; + Ok(()) + } +} diff --git a/src/zookeeper.rs b/src/zookeeper.rs index 00c624db2..85545b999 100644 --- a/src/zookeeper.rs +++ b/src/zookeeper.rs @@ -1,3 +1,4 @@ +use num_enum::FromPrimitive; use std::convert::From; use std::fmt::{Debug, Formatter, Result as FmtResult}; use std::net::{SocketAddr, ToSocketAddrs}; @@ -16,13 +17,14 @@ use crate::proto::{ to_len_prefixed_buf, AuthRequest, ByteBuf, Create2Response, CreateRequest, CreateResponse, CreateTTLRequest, DeleteRequest, EmptyRequest, EmptyResponse, ExistsRequest, ExistsResponse, GetAclRequest, GetAclResponse, GetAllChildrenNumberRequest, GetAllChildrenNumberResponse, - GetChildrenRequest, GetChildrenResponse, GetDataRequest, GetDataResponse, OpCode, ReadFrom, - ReplyHeader, RequestHeader, SetAclRequest, SetAclResponse, SetDataRequest, SetDataResponse, - WriteTo, + GetChildrenRequest, GetChildrenResponse, GetDataRequest, GetDataResponse, MultiRequest, + MultiResponse, Op, OpCode, OpResponse, ReadFrom, ReplyHeader, RequestHeader, SetAclRequest, + SetAclResponse, SetDataRequest, SetDataResponse, WriteTo, }; use crate::watch::ZkWatch; use crate::{ - Acl, CreateMode, Stat, Subscription, Watch, WatchType, WatchedEvent, Watcher, ZkError, ZkState, + Acl, CreateMode, OperationResult, Read, ReadOperationResult, Stat, Subscription, Transaction, + Watch, WatchType, WatchedEvent, Watcher, ZkError, ZkState, }; /// Value returned from potentially-error operations. @@ -442,6 +444,121 @@ impl ZooKeeper { Ok(()) } + /// Create a new transaction builder. A transaction is a set of operations that are executed + /// atomically. + pub fn transaction(&self) -> Transaction { + Transaction::new(self) + } + + pub(crate) async fn multi(&self, mut ops: Vec) -> ZkResult> { + trace!("ZooKeeper::multi_op"); + for operation in ops.iter_mut() { + match operation { + Op::Check(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::Create(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::Create2(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::CreateTtl(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::SetData(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::Delete(op) => { + op.path = self.path(op.path.as_str())?; + } + _ => unreachable!(), + } + } + + let req = MultiRequest { ops }; + + let response: MultiResponse = self.request(OpCode::Multi, self.xid(), req, None).await?; + + let mut result = Vec::with_capacity(response.results.len()); + + for r in response.results { + let op_result = match r { + OpResponse::Create(r) => OperationResult::Create(self.cut_chroot(r.path)), + OpResponse::Create2(r) => OperationResult::Create2(self.cut_chroot(r.path), r.stat), + OpResponse::CreateTtl(r) => { + OperationResult::CreateTtl(self.cut_chroot(r.path), r.stat) + } + OpResponse::SetData(r) => OperationResult::SetData(r.stat), + OpResponse::Delete => OperationResult::Delete, + OpResponse::Check => OperationResult::Check, + + OpResponse::Error(0) => { + // ignore successful operations, because there is guarantee non zero error when some operation is failed + continue; + } + OpResponse::Error(err) => return Err(ZkError::from_primitive(err)), + + OpResponse::GetChildren(_) | OpResponse::GetData(_) => { + unreachable!() + } + }; + + result.push(op_result); + } + + Ok(result) + } + + /// Create a new read builder. This is a set of read operations that are executed atomically. + pub fn read(&self) -> Read { + Read::new(self) + } + + pub(crate) async fn multi_read(&self, mut ops: Vec) -> ZkResult> { + trace!("ZooKeeper::multi_read"); + for operation in ops.iter_mut() { + match operation { + Op::GetChildren(op) => { + op.path = self.path(op.path.as_str())?; + } + Op::GetData(op) => { + op.path = self.path(op.path.as_str())?; + } + _ => unreachable!(), + } + } + + let req = MultiRequest { ops }; + + let response: MultiResponse = self + .request(OpCode::MultiRead, self.xid(), req, None) + .await?; + + let mut result = Vec::with_capacity(response.results.len()); + + for r in response.results { + let op_result = match r { + OpResponse::GetData(r) => { + ReadOperationResult::GetData(r.data_stat.0, r.data_stat.1) + } + OpResponse::GetChildren(r) => ReadOperationResult::GetChildren(r.children), + + OpResponse::Error(0) => { + // ignore successful operations, because there is guarantee non zero error when some operation is failed + continue; + } + OpResponse::Error(err) => return Err(ZkError::from_primitive(err)), + + _ => unreachable!(), + }; + + result.push(op_result); + } + + Ok(result) + } + /// Return the `Stat` of the node of the given `path` or `None` if no such node exists. /// /// If the `watch` is `true` and the call is successful (no error is returned), a watch will be diff --git a/tests/test_multi.rs b/tests/test_multi.rs new file mode 100644 index 000000000..3428588bc --- /dev/null +++ b/tests/test_multi.rs @@ -0,0 +1,268 @@ +use crate::test::ZkCluster; +use std::time::Duration; +use zookeeper_async::{Acl, CreateMode, ZkError, ZooKeeper}; + +mod test; + +async fn create_zk(connection_string: &str) -> ZooKeeper { + ZooKeeper::connect(connection_string, Duration::from_secs(10), |_ev| {}) + .await + .unwrap() +} + +#[tokio::test] +async fn zk_multi() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .create( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .check("/test", Some(0)) + .commit() + .await; + + assert!(results.is_ok()); + let results = results.unwrap(); + assert_eq!(results.len(), 3); + + assert!(zk.exists("/test", false).await.unwrap().is_some()); + assert!(zk.exists("/test/child1", false).await.unwrap().is_some()); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_multi_w_set_data() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk + .transaction() + .create2( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .create2( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .set_data("/test", vec![1, 2, 3], None) + .check("/test", Some(1)) + .commit() + .await; + + assert!(results.is_ok()); + let results = results.unwrap(); + assert_eq!(results.len(), 4); + + assert!(zk.exists("/test", false).await.unwrap().is_some()); + assert!(zk.exists("/test/child1", false).await.unwrap().is_some()); + + let data = zk.get_data("/test", false).await.unwrap(); + assert_eq!(data.0, vec![1, 2, 3]); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_multi_w_delete() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .create( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .delete("/test/child1", None) + .commit() + .await; + + assert!(results.is_ok()); + let results = results.unwrap(); + assert_eq!(results.len(), 3); + + assert!(zk.exists("/test", false).await.unwrap().is_some()); + assert!(zk.exists("/test/child1", false).await.unwrap().is_none()); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_multi_error() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .check("/test", Some(2)) // This should fail because the version is wrong + .commit() + .await; + + let Err(error) = results else { + panic!("Expected an error"); + }; + + assert_eq!(error, ZkError::BadVersion); + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .check("/test-wrong", None) + .commit() + .await; + + let Err(error) = results else { + panic!("Expected an error"); + }; + + assert_eq!(error, ZkError::NoNode); + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .commit() + .await; + + let Err(error) = results else { + panic!("Expected an error"); + }; + + assert_eq!(error, ZkError::NodeExists); + + // Ensure that the transaction was not committed + + assert!(zk.exists("/test", false).await.unwrap().is_none()); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_multi_read() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk + .transaction() + .create( + "/test", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .create( + "/test/child1", + vec![], + Acl::open_unsafe().clone(), + CreateMode::Persistent, + ) + .commit() + .await; + + assert!(results.is_ok()); + let results = results.unwrap(); + assert_eq!(results.len(), 2); + + let results = zk + .read() + .get_data("/test", false) + .get_children("/test", false) + .execute() + .await; + + // assert!(results.is_ok()); + let results = results.unwrap(); + assert_eq!(results.len(), 2); + + assert!(zk.exists("/test", false).await.unwrap().is_some()); + assert!(zk.exists("/test/child1", false).await.unwrap().is_some()); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +} + +#[tokio::test] +async fn zk_multi_read_error() { + // Create a test cluster + let cluster = ZkCluster::start(3); + + // Connect to the test cluster + let zk = create_zk(&cluster.connect_string).await; + + let results = zk.read().get_data("/test", false).execute().await; + + let Err(error) = results else { + panic!("Expected an error"); + }; + + assert_eq!(error, ZkError::NoNode); + + // After closing the client all operations return Err + zk.close().await.unwrap(); +}