This is an automated email from the ASF dual-hosted git repository.

hoslo pushed a commit to branch add-ascii-auth
in repository https://gitbox.apache.org/repos/asf/opendal.git

commit d7b5d73ac45d923dfe20e68faa5c9bc3d9fc2bcc
Author: hoslo <[email protected]>
AuthorDate: Thu Feb 22 17:08:29 2024 +1000

    feat(services/memcached): add binary protocol support
---
 core/Cargo.toml                        |   1 +
 core/src/services/memcached/backend.rs |  23 ++-
 core/src/services/memcached/binary.rs  | 270 +++++++++++++++++++++++++++++++++
 core/src/services/memcached/mod.rs     |   1 +
 4 files changed, 292 insertions(+), 3 deletions(-)

diff --git a/core/Cargo.toml b/core/Cargo.toml
index 3796268c93..e62360df2b 100644
--- a/core/Cargo.toml
+++ b/core/Cargo.toml
@@ -56,6 +56,7 @@ default = [
   "services-webdav",
   "services-webhdfs",
   "services-azfile",
+  "services-memcached",
 ]
 
 # Build test utils or not.
diff --git a/core/src/services/memcached/backend.rs 
b/core/src/services/memcached/backend.rs
index f0d48e4a91..d7d0629f69 100644
--- a/core/src/services/memcached/backend.rs
+++ b/core/src/services/memcached/backend.rs
@@ -24,7 +24,7 @@ use serde::Deserialize;
 use tokio::net::TcpStream;
 use tokio::sync::OnceCell;
 
-use super::ascii;
+use super::{ascii, binary};
 use crate::raw::adapters::kv;
 use crate::raw::*;
 use crate::*;
@@ -42,6 +42,10 @@ pub struct MemcachedConfig {
     ///
     /// default is "/"
     root: Option<String>,
+    /// Memcached username, optional.
+    username: Option<String>,
+    /// Memcached password, optional.
+    password: Option<String>,
     /// The default ttl for put operations.
     default_ttl: Option<Duration>,
 }
@@ -74,6 +78,18 @@ impl MemcachedBuilder {
         self
     }
 
+    /// Set the memcached username (optional).
+    pub fn username(&mut self, username: &str) -> &mut Self {
+        let owned: String = username.into();
+        self.config.username = Some(owned);
+        self
+    }
+
+    /// Set the memcached password (optional).
+    pub fn password(&mut self, password: &str) -> &mut Self {
+        let owned: String = password.into();
+        self.config.password = Some(owned);
+        self
+    }
+
     /// Set the default ttl for memcached services.
     pub fn default_ttl(&mut self, ttl: Duration) -> &mut Self {
         self.config.default_ttl = Some(ttl);
@@ -249,7 +265,7 @@ impl MemcacheConnectionManager {
 
 #[async_trait]
 impl bb8::ManageConnection for MemcacheConnectionManager {
-    type Connection = ascii::Connection;
+    type Connection = binary::Connection;
     type Error = Error;
 
     /// TODO: Implement unix stream support.
@@ -257,10 +273,11 @@ impl bb8::ManageConnection for MemcacheConnectionManager {
         let conn = TcpStream::connect(&self.address)
             .await
             .map_err(new_std_io_error)?;
-        Ok(ascii::Connection::new(conn))
+        Ok(binary::Connection::new(conn))
     }
 
     async fn is_valid(&self, conn: &mut Self::Connection) -> 
std::result::Result<(), Self::Error> {
+        conn.auth("test", "test").await?;
         conn.version().await.map(|_| ())
     }
 
diff --git a/core/src/services/memcached/binary.rs 
b/core/src/services/memcached/binary.rs
new file mode 100644
index 0000000000..dd5ae1ca0c
--- /dev/null
+++ b/core/src/services/memcached/binary.rs
@@ -0,0 +1,270 @@
+use futures::TryFutureExt;
+use log::debug;
+use tokio::io::{self, AsyncReadExt, AsyncWriteExt, BufReader};
+use tokio::net::TcpStream;
+
+use crate::raw::*;
+use crate::*;
+
+/// Response status: the request succeeded.
+const OK_STATUS: u16 = 0x0000;
+/// Response status: the requested key does not exist.
+const KEY_NOT_FOUND: u16 = 0x0001;
+
+/// Binary-protocol opcodes used by this client.
+///
+/// Each discriminant is the byte written into a request header's `opcode` field.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u8)]
+pub enum Opcode {
+    /// Fetch the value stored under a key.
+    Get = 0x00,
+    /// Store a value under a key.
+    Set = 0x01,
+    /// Remove a key.
+    Delete = 0x04,
+    /// Ask the server for its version string.
+    Version = 0x0b,
+    /// SASL authentication: mechanism name in the key, credentials in the value.
+    StartAuth = 0x21,
+}
+
+/// First byte of every binary-protocol packet, distinguishing direction.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+#[repr(u8)]
+pub enum Magic {
+    /// Packet sent by the client.
+    Request = 0x80,
+    /// Packet sent by the server.
+    Response = 0x81,
+}
+
+/// Extras carried by a `Set` request, sent in field order (flags, then expiration).
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub struct StoreExtras {
+    /// Opaque per-item flags stored alongside the value.
+    pub flags: u32,
+    /// Expiration value forwarded to the server as-is.
+    pub expiration: u32,
+}
+
+/// Fixed-size header shared by binary-protocol requests and responses.
+#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
+pub struct PacketHeader {
+    /// Direction marker; see `Magic`.
+    pub magic: u8,
+    /// Command code; see `Opcode`.
+    pub opcode: u8,
+    /// Length of the key section of the body.
+    pub key_length: u16,
+    /// Length of the extras section of the body.
+    pub extras_length: u8,
+    /// Data type byte; this client always sends 0.
+    pub data_type: u8,
+    /// vbucket id in requests; status code in responses.
+    pub vbucket_id_or_status: u16,
+    /// Total body length: extras + key + value.
+    pub total_body_length: u32,
+    /// Opaque value; this client always sends 0.
+    pub opaque: u32,
+    /// CAS value; this client always sends 0.
+    pub cas: u64,
+}
+
+impl PacketHeader {
+    /// Size of a binary-protocol header on the wire, in bytes.
+    pub const LEN: usize = 24;
+
+    /// Encode the header as one contiguous big-endian buffer.
+    fn to_bytes(&self) -> [u8; Self::LEN] {
+        let mut buf = [0u8; Self::LEN];
+        buf[0] = self.magic;
+        buf[1] = self.opcode;
+        buf[2..4].copy_from_slice(&self.key_length.to_be_bytes());
+        buf[4] = self.extras_length;
+        buf[5] = self.data_type;
+        buf[6..8].copy_from_slice(&self.vbucket_id_or_status.to_be_bytes());
+        buf[8..12].copy_from_slice(&self.total_body_length.to_be_bytes());
+        buf[12..16].copy_from_slice(&self.opaque.to_be_bytes());
+        buf[16..24].copy_from_slice(&self.cas.to_be_bytes());
+        buf
+    }
+
+    /// Decode a header from its big-endian wire form.
+    fn from_bytes(buf: &[u8; Self::LEN]) -> Self {
+        PacketHeader {
+            magic: buf[0],
+            opcode: buf[1],
+            key_length: u16::from_be_bytes([buf[2], buf[3]]),
+            extras_length: buf[4],
+            data_type: buf[5],
+            vbucket_id_or_status: u16::from_be_bytes([buf[6], buf[7]]),
+            total_body_length: u32::from_be_bytes([buf[8], buf[9], buf[10], buf[11]]),
+            opaque: u32::from_be_bytes([buf[12], buf[13], buf[14], buf[15]]),
+            cas: u64::from_be_bytes([
+                buf[16], buf[17], buf[18], buf[19], buf[20], buf[21], buf[22], buf[23],
+            ]),
+        }
+    }
+
+    /// Write the header to `writer` with a single `write_all`.
+    ///
+    /// Serializing up front avoids one small write per field on the raw socket.
+    pub async fn write(self, writer: &mut TcpStream) -> io::Result<()> {
+        writer.write_all(&self.to_bytes()).await
+    }
+
+    /// Read exactly one header from `reader`.
+    pub async fn read(reader: &mut TcpStream) -> std::result::Result<PacketHeader, io::Error> {
+        let mut buf = [0u8; Self::LEN];
+        reader.read_exact(&mut buf).await?;
+        Ok(Self::from_bytes(&buf))
+    }
+}
+
+/// A fully-read response packet, with the body split into its three sections.
+#[derive(Debug)]
+pub struct Response {
+    header: PacketHeader,
+    // Kept for completeness/debugging. No current caller reads the key or
+    // extras, so silence the dead-code lint on them explicitly.
+    #[allow(dead_code)]
+    key: Vec<u8>,
+    #[allow(dead_code)]
+    extras: Vec<u8>,
+    value: Vec<u8>,
+}
+
+/// A memcached connection speaking the binary protocol.
+pub struct Connection {
+    // NOTE(review): every request and response goes through `get_mut()`, so
+    // the `BufReader`'s read buffer is currently never used.
+    io: BufReader<TcpStream>,
+}
+
+impl Connection {
+    /// Wrap an already-connected TCP stream.
+    pub fn new(io: TcpStream) -> Self {
+        Self {
+            io: BufReader::new(io),
+        }
+    }
+
+    /// Authenticate this connection with SASL `PLAIN`.
+    ///
+    /// The payload is `\0username\0password` (empty authorization identity).
+    pub async fn auth(&mut self, username: &str, password: &str) -> Result<()> {
+        let mut payload = Vec::with_capacity(username.len() + password.len() + 2);
+        payload.push(0);
+        payload.extend_from_slice(username.as_bytes());
+        payload.push(0);
+        payload.extend_from_slice(password.as_bytes());
+        self.request(Opcode::StartAuth, &[], b"PLAIN", &payload)
+            .await?;
+        Ok(())
+    }
+
+    /// Return the server's version string.
+    ///
+    /// # Errors
+    /// Fails on I/O errors, on an error status from the server, or if the
+    /// version bytes are not valid UTF-8.
+    pub async fn version(&mut self) -> Result<String> {
+        let response = self.request(Opcode::Version, &[], &[], &[]).await?;
+        String::from_utf8(response.value).map_err(|e| {
+            Error::new(ErrorKind::Unexpected, "unexpected data received").set_source(e)
+        })
+    }
+
+    /// Fetch `key`, returning `Ok(None)` when the server reports it missing.
+    pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> {
+        let response = self.request(Opcode::Get, &[], key.as_bytes(), &[]).await?;
+        if response.header.vbucket_id_or_status == KEY_NOT_FOUND {
+            return Ok(None);
+        }
+        Ok(Some(response.value))
+    }
+
+    /// Store `val` under `key`.
+    ///
+    /// `expiration` is forwarded in the request extras. Flags are always 0.
+    pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> {
+        let extras = StoreExtras {
+            flags: 0,
+            expiration,
+        };
+        let mut extras_buf = [0u8; 8];
+        extras_buf[..4].copy_from_slice(&extras.flags.to_be_bytes());
+        extras_buf[4..].copy_from_slice(&extras.expiration.to_be_bytes());
+        self.request(Opcode::Set, &extras_buf, key.as_bytes(), val)
+            .await?;
+        Ok(())
+    }
+
+    /// Delete `key`.
+    ///
+    /// A missing key is not an error, because `parse_response` accepts
+    /// `KEY_NOT_FOUND`.
+    pub async fn delete(&mut self, key: &str) -> Result<()> {
+        self.request(Opcode::Delete, &[], key.as_bytes(), &[])
+            .await?;
+        Ok(())
+    }
+
+    /// Send one request and read back its response.
+    ///
+    /// The body is written as `extras`, `key`, `value` in that order, and the
+    /// header's length fields are derived from those slices. A length that does
+    /// not fit its fixed-width header field is rejected. Silently truncating it
+    /// would make the server misparse the body and desync the connection.
+    async fn request(
+        &mut self,
+        opcode: Opcode,
+        extras: &[u8],
+        key: &[u8],
+        value: &[u8],
+    ) -> Result<Response> {
+        let too_large = |field: &'static str| {
+            Error::new(
+                ErrorKind::Unexpected,
+                "request field too large for memcached binary protocol",
+            )
+            .with_context("field", field.to_string())
+        };
+        let key_length = u16::try_from(key.len()).map_err(|_| too_large("key"))?;
+        let extras_length = u8::try_from(extras.len()).map_err(|_| too_large("extras"))?;
+        let total_body_length = u32::try_from(extras.len() + key.len() + value.len())
+            .map_err(|_| too_large("body"))?;
+
+        let header = PacketHeader {
+            magic: Magic::Request as u8,
+            opcode: opcode as u8,
+            key_length,
+            extras_length,
+            total_body_length,
+            ..Default::default()
+        };
+
+        let writer = self.io.get_mut();
+        header.write(writer).await.map_err(new_std_io_error)?;
+        writer.write_all(extras).await.map_err(new_std_io_error)?;
+        writer.write_all(key).await.map_err(new_std_io_error)?;
+        writer.write_all(value).await.map_err(new_std_io_error)?;
+        writer.flush().await.map_err(new_std_io_error)?;
+        parse_response(writer).await
+    }
+}
+
+/// Read one complete response packet (header and body) from `reader`.
+///
+/// The whole body is consumed before the status is inspected. An error
+/// response therefore never leaves unread bytes that would corrupt the next
+/// exchange on the same connection. Both `OK_STATUS` and `KEY_NOT_FOUND` are
+/// returned as `Ok`; callers that care about a missing key check the status.
+///
+/// # Errors
+/// Returns an error in these cases:
+/// - an I/O failure;
+/// - a packet whose magic byte is not `Magic::Response`;
+/// - any other status (the server's message is attached);
+/// - a body shorter than its advertised extras and key lengths.
+pub async fn parse_response(reader: &mut TcpStream) -> Result<Response> {
+    let header = PacketHeader::read(reader).await.map_err(new_std_io_error)?;
+
+    if header.magic != Magic::Response as u8 {
+        // The stream is out of sync, so the length fields cannot be trusted.
+        return Err(
+            Error::new(ErrorKind::Unexpected, "unexpected magic received")
+                .with_context("magic", format!("{:#04x}", header.magic)),
+        );
+    }
+
+    let mut body = vec![0x0; header.total_body_length as usize];
+    reader
+        .read_exact(body.as_mut_slice())
+        .await
+        .map_err(new_std_io_error)?;
+
+    let status = header.vbucket_id_or_status;
+    if status != OK_STATUS && status != KEY_NOT_FOUND {
+        return Err(
+            Error::new(ErrorKind::Unexpected, "unexpected status received")
+                .with_context("status", format!("{:#06x}", status))
+                .with_context("message", String::from_utf8_lossy(&body).into_owned()),
+        );
+    }
+
+    // Body layout: extras, then key, then value. Check the advertised lengths
+    // before slicing; the old subtraction could underflow on a bad packet.
+    let extras_end = usize::from(header.extras_length);
+    let key_end = extras_end + usize::from(header.key_length);
+    if key_end > body.len() {
+        return Err(
+            Error::new(ErrorKind::Unexpected, "malformed response received")
+                .with_context("total_body_length", header.total_body_length.to_string())
+                .with_context("extras_length", header.extras_length.to_string())
+                .with_context("key_length", header.key_length.to_string()),
+        );
+    }
+
+    let value = body.split_off(key_end);
+    let key = body.split_off(extras_end);
+    let extras = body;
+
+    Ok(Response {
+        header,
+        key,
+        extras,
+        value,
+    })
+}
diff --git a/core/src/services/memcached/mod.rs 
b/core/src/services/memcached/mod.rs
index bbe45219a1..72788d1e71 100644
--- a/core/src/services/memcached/mod.rs
+++ b/core/src/services/memcached/mod.rs
@@ -19,3 +19,4 @@ mod backend;
 pub use backend::MemcachedBuilder as Memcached;
 pub use backend::MemcachedConfig;
 mod ascii;
+mod binary;

Reply via email to