This is an automated email from the ASF dual-hosted git repository. hoslo pushed a commit to branch add-ascii-auth in repository https://gitbox.apache.org/repos/asf/opendal.git
commit d7b5d73ac45d923dfe20e68faa5c9bc3d9fc2bcc
Author: hoslo <[email protected]>
AuthorDate: Thu Feb 22 17:08:29 2024 +1000

    feat(services/memcached): add binary protocal support
---
 core/Cargo.toml                        |   1 +
 core/src/services/memcached/backend.rs |  23 ++-
 core/src/services/memcached/binary.rs  | 270 +++++++++++++++++++++++++++++++++
 core/src/services/memcached/mod.rs     |   1 +
 4 files changed, 292 insertions(+), 3 deletions(-)

Review notes (reviewer commentary only; git am/apply ignore text placed
between the diffstat and the first "diff --git" line, so the patch below
is unchanged):

  * backend.rs, is_valid(): authenticates with the literal credentials
    ("test", "test"). The username/password fields this patch adds to
    MemcachedConfig are never read anywhere in this diff, and SASL auth
    is re-sent on every is_valid() call instead of once in connect().
    Presumably connect() should call auth() with the configured
    credentials only when both are set - TODO confirm how
    MemcacheConnectionManager receives the config (not visible here).
  * Cargo.toml: adds "services-memcached" to the default feature list.
    Looks like a local-testing leftover - confirm it is intended.
  * binary.rs, parse_response(): returns Err for any status other than
    OK_STATUS/KEY_NOT_FOUND *before* reading the extras/key/value bytes,
    so those bytes stay unread on the socket and the next request on the
    same connection would parse them as a header. Read the whole body
    first, then check the status.
  * binary.rs, parse_response(): total_body_length - key_length -
    extras_length is unchecked u32 arithmetic; a malformed header panics
    in debug builds and wraps in release. Use checked_sub and return an
    error instead.
  * binary.rs, parse_response(): KEY_NOT_FOUND is accepted as success for
    every caller, but only get() inspects the status. auth(), set() and
    delete() therefore return Ok(()) on status 0x1 - confirm that is
    intended at least for auth() and set(). Rejected logins (presumably a
    non-OK status) surface only as the generic ErrorKind::Unexpected
    "unexpected status received".
  * binary.rs, get(): compares the status against the literal 0x1
    rather than the KEY_NOT_FOUND constant defined at the top of the file.
  * binary.rs: Connection wraps the stream in a BufReader, but every
    method goes through get_mut(), so the buffer is never used.
    PacketHeader::write()/read() also issue one write_uN/read_uN call per
    field on the raw TcpStream. Encoding the 24-byte header into a single
    buffer and reading through the BufReader would cut per-request I/O
    calls.
  * binary.rs: `key.len() as u16` and the `as u32` body-length casts
    truncate silently for oversized keys/values; check with
    u16::try_from / u32::try_from and return an error instead.
  * binary.rs: the unused imports (futures::TryFutureExt, log::debug),
    the unused Connection::buf field, Response::key/extras (never read)
    and Magic::Response (never constructed) trigger rustc unused/dead-code
    warnings. The trailing `return Ok(..);` in PacketHeader::write/read
    trips clippy::needless_return. The new file also has no license
    header, which is presumably required in an ASF repository - confirm
    against the project's license check.
  * Commit subject typo: "protocal" -> "protocol".

diff --git a/core/Cargo.toml b/core/Cargo.toml
index 3796268c93..e62360df2b 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -56,6 +56,7 @@ default = [
   "services-webdav",
   "services-webhdfs",
   "services-azfile",
+  "services-memcached",
 ]
 
 # Build test utils or not.
diff --git a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs
index f0d48e4a91..d7d0629f69 100644
--- a/core/src/services/memcached/backend.rs
+++ b/core/src/services/memcached/backend.rs
@@ -24,7 +24,7 @@ use serde::Deserialize;
 use tokio::net::TcpStream;
 use tokio::sync::OnceCell;
 
-use super::ascii;
+use super::{ascii, binary};
 use crate::raw::adapters::kv;
 use crate::raw::*;
 use crate::*;
@@ -42,6 +42,10 @@ pub struct MemcachedConfig {
     ///
     /// default is "/"
     root: Option<String>,
+    /// Memcached username, optional.
+    username: Option<String>,
+    /// Memcached password, optional.
+    password: Option<String>,
     /// The default ttl for put operations.
     default_ttl: Option<Duration>,
 }
@@ -74,6 +78,18 @@ impl MemcachedBuilder {
         self
     }
 
+    /// set the username.
+    pub fn username(&mut self, username: &str) -> &mut Self {
+        self.config.username = Some(username.to_string());
+        self
+    }
+
+    /// set the password.
+    pub fn password(&mut self, password: &str) -> &mut Self {
+        self.config.password = Some(password.to_string());
+        self
+    }
+
     /// Set the default ttl for memcached services.
     pub fn default_ttl(&mut self, ttl: Duration) -> &mut Self {
         self.config.default_ttl = Some(ttl);
@@ -249,7 +265,7 @@ impl MemcacheConnectionManager {
 
 #[async_trait]
 impl bb8::ManageConnection for MemcacheConnectionManager {
-    type Connection = ascii::Connection;
+    type Connection = binary::Connection;
     type Error = Error;
 
     /// TODO: Implement unix stream support.
@@ -257,10 +273,11 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
         let conn = TcpStream::connect(&self.address)
             .await
             .map_err(new_std_io_error)?;
-        Ok(ascii::Connection::new(conn))
+        Ok(binary::Connection::new(conn))
     }
 
     async fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> {
+        conn.auth("test", "test").await?;
         conn.version().await.map(|_| ())
     }
 
diff --git a/core/src/services/memcached/binary.rs b/core/src/services/memcached/binary.rs
new file mode 100644
index 0000000000..dd5ae1ca0c
--- /dev/null
+++ b/core/src/services/memcached/binary.rs
@@ -0,0 +1,270 @@
+use futures::TryFutureExt;
+use log::debug;
+use tokio::io::{self, AsyncReadExt, AsyncWriteExt, BufReader};
+use tokio::net::TcpStream;
+
+use crate::raw::*;
+use crate::*;
+
+const OK_STATUS: u16 = 0x0;
+const KEY_NOT_FOUND: u16 = 0x1;
+
+pub enum Opcode {
+    Get = 0x00,
+    Set = 0x01,
+    Delete = 0x04,
+    Version = 0x0b,
+    StartAuth = 0x21,
+}
+
+pub enum Magic {
+    Request = 0x80,
+    Response = 0x81,
+}
+
+#[derive(Debug)]
+pub struct StoreExtras {
+    pub flags: u32,
+    pub expiration: u32,
+}
+
+#[derive(Debug, Default)]
+pub struct PacketHeader {
+    pub magic: u8,
+    pub opcode: u8,
+    pub key_length: u16,
+    pub extras_length: u8,
+    pub data_type: u8,
+    pub vbucket_id_or_status: u16,
+    pub total_body_length: u32,
+    pub opaque: u32,
+    pub cas: u64,
+}
+
+impl PacketHeader {
+    pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
+        writer.write_u8(self.magic).await?;
+        writer.write_u8(self.opcode).await?;
+        writer.write_u16(self.key_length).await?;
+        writer.write_u8(self.extras_length).await?;
+        writer.write_u8(self.data_type).await?;
+        writer.write_u16(self.vbucket_id_or_status).await?;
+        writer.write_u32(self.total_body_length).await?;
+        writer.write_u32(self.opaque).await?;
+        writer.write_u64(self.cas).await?;
+        return Ok(());
+    }
+
+    pub async fn read(reader: &mut TcpStream) -> std::result::Result<PacketHeader, io::Error> {
+        let header = PacketHeader {
+            magic: reader.read_u8().await?,
+            opcode: reader.read_u8().await?,
+            key_length: reader.read_u16().await?,
+            extras_length: reader.read_u8().await?,
+            data_type: reader.read_u8().await?,
+            vbucket_id_or_status: reader.read_u16().await?,
+            total_body_length: reader.read_u32().await?,
+            opaque: reader.read_u32().await?,
+            cas: reader.read_u64().await?,
+        };
+        return Ok(header);
+    }
+}
+
+pub struct Response {
+    header: PacketHeader,
+    key: Vec<u8>,
+    extras: Vec<u8>,
+    value: Vec<u8>,
+}
+
+pub struct Connection {
+    io: BufReader<TcpStream>,
+    buf: Vec<u8>,
+}
+
+impl Connection {
+    pub fn new(io: TcpStream) -> Self {
+        Self {
+            io: BufReader::new(io),
+            buf: Vec::new(),
+        }
+    }
+
+    pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
+        let writer = self.io.get_mut();
+        let key = "PLAIN";
+        let request_header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: Opcode::StartAuth as u8,
+            key_length: key.len() as u16,
+            total_body_length: (key.len() + username.len() + password.len() + 2) as u32,
+            ..Default::default()
+        };
+        request_header
+            .write(writer)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_all(key.as_bytes())
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_all(format!("\x00{}\x00{}", username, password).as_bytes())
+            .await
+            .map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+        parse_response(writer).await?;
+        Ok(())
+    }
+
+    pub async fn version(&mut self) -> Result<String> {
+        let writer = self.io.get_mut();
+        let request_header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: Opcode::Version as u8,
+            ..Default::default()
+        };
+        request_header
+            .write(writer)
+            .await
+            .map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+        let response = parse_response(writer).await?;
+        let version = String::from_utf8(response.value);
+        match version {
+            Ok(version) => Ok(version),
+            Err(e) => {
+                Err(Error::new(ErrorKind::Unexpected, "unexpected data received").set_source(e))
+            }
+        }
+    }
+
+    pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
+        let writer = self.io.get_mut();
+        let request_header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: Opcode::Get as u8,
+            key_length: key.len() as u16,
+            total_body_length: key.len() as u32,
+            ..Default::default()
+        };
+        request_header
+            .write(writer)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_all(key.as_bytes())
+            .await
+            .map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+        match parse_response(writer).await {
+            Ok(response) => {
+                if response.header.vbucket_id_or_status == 0x1 {
+                    return Ok(None);
+                }
+                Ok(Some(response.value))
+            }
+            Err(e) => Err(e),
+        }
+    }
+
+    pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
+        let writer = self.io.get_mut();
+        let request_header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: Opcode::Set as u8,
+            key_length: key.len() as u16,
+            extras_length: 8,
+            total_body_length: (8 + key.len() + val.len()) as u32,
+            ..Default::default()
+        };
+        let extras = StoreExtras {
+            flags: 0,
+            expiration,
+        };
+        request_header
+            .write(writer)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_u32(extras.flags)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_u32(extras.expiration)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_all(key.as_bytes())
+            .await
+            .map_err(new_std_io_error)?;
+        writer.write_all(val).await.map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+
+        parse_response(writer).await?;
+        Ok(())
+    }
+
+    pub async fn delete(&mut self, key: &str) -> Result<()> {
+        let writer = self.io.get_mut();
+        let request_header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: Opcode::Delete as u8,
+            key_length: key.len() as u16,
+            total_body_length: key.len() as u32,
+            ..Default::default()
+        };
+        request_header
+            .write(writer)
+            .await
+            .map_err(new_std_io_error)?;
+        writer
+            .write_all(key.as_bytes())
+            .await
+            .map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+        parse_response(writer).await?;
+        Ok(())
+    }
+}
+
+pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
+    let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;
+
+    if header.vbucket_id_or_status != OK_STATUS && header.vbucket_id_or_status != KEY_NOT_FOUND {
+        return Err(
+            Error::new(ErrorKind::Unexpected, "unexpected status received")
+                .with_context("message", format!("{}", header.vbucket_id_or_status)),
+        );
+    }
+
+    let mut extras = vec![0x0; header.extras_length as usize];
+    reader
+        .read_exact(extras.as_mut_slice())
+        .await
+        .map_err(new_std_io_error)?;
+
+    let mut key = vec![0x0; header.key_length as usize];
+    reader
+        .read_exact(key.as_mut_slice())
+        .await
+        .map_err(new_std_io_error)?;
+
+    let mut value = vec![
+        0x0;
+        (header.total_body_length - u32::from(header.key_length) - u32::from(header.extras_length))
+            as usize
+    ];
+    reader
+        .read_exact(value.as_mut_slice())
+        .await
+        .map_err(new_std_io_error)?;
+
+    Ok(Response {
+        header,
+        key,
+        extras,
+        value,
+    })
+}
diff --git a/core/src/services/memcached/mod.rs b/core/src/services/memcached/mod.rs
index bbe45219a1..72788d1e71 100644
--- a/core/src/services/memcached/mod.rs
+++ b/core/src/services/memcached/mod.rs
@@ -19,3 +19,4 @@ mod backend;
 pub use backend::MemcachedBuilder as Memcached;
 pub use backend::MemcachedConfig;
 mod ascii;
+mod binary;
