This is an automated email from the ASF dual-hosted git repository. xuanwo pushed a commit to branch refactor-memcached in repository https://gitbox.apache.org/repos/asf/incubator-opendal.git
commit 5868fc656118853bac6749627145415b8df9dda8 Author: Xuanwo <[email protected]> AuthorDate: Thu May 4 18:30:38 2023 +0800 refactor(services/memcached): Rewrite memcached connection entirely Signed-off-by: Xuanwo <[email protected]> --- core/src/services/memcached/MIT-ascii.txt | 20 ---- core/src/services/memcached/ascii.rs | 171 +++++++++++++++++------------- core/src/services/memcached/backend.rs | 54 +++------- 3 files changed, 113 insertions(+), 132 deletions(-) diff --git a/core/src/services/memcached/MIT-ascii.txt b/core/src/services/memcached/MIT-ascii.txt deleted file mode 100644 index c176da35..00000000 --- a/core/src/services/memcached/MIT-ascii.txt +++ /dev/null @@ -1,20 +0,0 @@ -The MIT License (MIT) - -Copyright (c) 2017 An Long - -Permission is hereby granted, free of charge, to any person obtaining a copy of -this software and associated documentation files (the "Software"), to deal in -the Software without restriction, including without limitation the rights to -use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of -the Software, and to permit persons to whom the Software is furnished to do so, -subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS -FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR -COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER -IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN -CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/core/src/services/memcached/ascii.rs b/core/src/services/memcached/ascii.rs index 85c74178..d5d8aa6f 100644 --- a/core/src/services/memcached/ascii.rs +++ b/core/src/services/memcached/ascii.rs @@ -1,55 +1,62 @@ -// Copyright 2017 vavrusa <[email protected]> +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at // -// Licensed under the MIT License (see MIT-ascii.txt); - -use core::fmt::Display; -use std::io::Error; -use std::io::ErrorKind; -use std::marker::Unpin; - -use futures::io::AsyncBufReadExt; -use futures::io::AsyncRead; -use futures::io::AsyncReadExt; -use futures::io::AsyncWrite; -use futures::io::AsyncWriteExt; -use futures::io::BufReader; - -/// Memcache ASCII protocol implementation. -pub struct Protocol<S> { - io: BufReader<S>, +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use crate::*; + +use super::backend::parse_io_error; +use tokio::io::AsyncBufReadExt; +use tokio::io::AsyncReadExt; +use tokio::io::AsyncWriteExt; +use tokio::io::BufReader; +use tokio::net::TcpStream; + +pub struct Connection { + io: BufReader<TcpStream>, buf: Vec<u8>, } -impl<S> Protocol<S> -where - S: AsyncRead + AsyncWrite + Unpin, -{ - /// Creates the ASCII protocol on a stream. 
- pub fn new(io: S) -> Self { +impl Connection { + pub fn new(io: TcpStream) -> Self { Self { io: BufReader::new(io), buf: Vec::new(), } } - /// Returns the value for given key as bytes. If the value doesn't exist, [`ErrorKind::NotFound`] is returned. - pub async fn get<K: AsRef<[u8]>>(&mut self, key: K) -> Result<Vec<u8>, Error> { + pub async fn get(&mut self, key: &str) -> Result<Option<Vec<u8>>> { // Send command let writer = self.io.get_mut(); writer - .write_all(&[b"get ", key.as_ref(), b"\r\n"].concat()) - .await?; - writer.flush().await?; + .write_all(&[b"get ", key.as_bytes(), b"\r\n"].concat()) + .await + .map_err(parse_io_error)?; + writer.flush().await.map_err(parse_io_error)?; // Read response header - let header = self.read_line().await?; - let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?; + let header = self.read_header().await?; // Check response header and parse value length if header.contains("ERROR") { - return Err(Error::new(ErrorKind::Other, header)); + return Err( + Error::new(ErrorKind::Unexpected, "unexpected data received") + .with_context("message", header), + ); } else if header.starts_with("END") { - return Err(ErrorKind::NotFound.into()); + return Ok(None); } // VALUE <key> <flags> <bytes> [<cas unique>]\r\n @@ -57,89 +64,109 @@ where .split(' ') .nth(3) .and_then(|len| len.trim_end().parse().ok()) - .ok_or(ErrorKind::InvalidData)?; + .ok_or_else(|| Error::new(ErrorKind::Unexpected, "invalid data received"))?; // Read value let mut buffer: Vec<u8> = vec![0; length]; - self.io.read_exact(&mut buffer).await?; + self.io + .read_exact(&mut buffer) + .await + .map_err(parse_io_error)?; // Read the trailing header self.read_line().await?; // \r\n self.read_line().await?; // END\r\n - Ok(buffer) + Ok(Some(buffer)) } - /// Set key to given value and don't wait for response. 
- pub async fn set<K: Display>( - &mut self, - key: K, - val: &[u8], - expiration: u32, - ) -> Result<(), Error> { + pub async fn set(&mut self, key: &str, val: &[u8], expiration: u32) -> Result<()> { let header = format!("set {} 0 {} {}\r\n", key, expiration, val.len()); - self.io.write_all(header.as_bytes()).await?; - self.io.write_all(val).await?; - self.io.write_all(b"\r\n").await?; - self.io.flush().await?; + self.io + .write_all(header.as_bytes()) + .await + .map_err(parse_io_error)?; + self.io.write_all(val).await.map_err(parse_io_error)?; + self.io.write_all(b"\r\n").await.map_err(parse_io_error)?; + self.io.flush().await.map_err(parse_io_error)?; // Read response header - let header = self.read_line().await?; - let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?; + let header = self.read_header().await?; + // Check response header and make sure we got a `STORED` if header.contains("STORED") { return Ok(()); } else if header.contains("ERROR") { - return Err(Error::new(ErrorKind::Other, header)); + return Err( + Error::new(ErrorKind::Unexpected, "unexpected data received") + .with_context("message", header), + ); } Ok(()) } - /// Delete a key and don't wait for response. 
- pub async fn delete<K: Display>(&mut self, key: K) -> Result<(), Error> { + pub async fn delete(&mut self, key: &str) -> Result<()> { let header = format!("delete {}\r\n", key); - self.io.write_all(header.as_bytes()).await?; - self.io.flush().await?; + self.io + .write_all(header.as_bytes()) + .await + .map_err(parse_io_error)?; + self.io.flush().await.map_err(parse_io_error)?; // Read response header - let header = self.read_line().await?; - let header = std::str::from_utf8(header).map_err(|_| ErrorKind::InvalidData)?; + let header = self.read_header().await?; + // Check response header and parse value length - if header.contains("NOT_FOUND") { + if header.contains("NOT_FOUND") || header.starts_with("END") { return Ok(()); - } else if header.starts_with("END") { - return Err(ErrorKind::NotFound.into()); } else if header.contains("ERROR") || !header.contains("DELETED") { - return Err(Error::new(ErrorKind::Other, header)); + return Err( + Error::new(ErrorKind::Unexpected, "unexpected data received") + .with_context("message", header), + ); } Ok(()) } - /// Return the version of the remote server. - pub async fn version(&mut self) -> Result<String, Error> { - self.io.write_all(b"version\r\n").await?; - self.io.flush().await?; + pub async fn version(&mut self) -> Result<String> { + self.io + .write_all(b"version\r\n") + .await + .map_err(parse_io_error)?; + self.io.flush().await.map_err(parse_io_error)?; // Read response header - let header = { - let buf = self.read_line().await?; - std::str::from_utf8(buf).map_err(|_| Error::from(ErrorKind::InvalidData))? 
- }; + let header = self.read_header().await?; if !header.starts_with("VERSION") { - return Err(Error::new(ErrorKind::Other, header)); + return Err( + Error::new(ErrorKind::Unexpected, "unexpected data received") + .with_context("message", header), + ); } let version = header.trim_start_matches("VERSION ").trim_end(); Ok(version.to_string()) } - async fn read_line(&mut self) -> Result<&[u8], Error> { + async fn read_line(&mut self) -> Result<&[u8]> { let Self { io, buf } = self; buf.clear(); - io.read_until(b'\n', buf).await?; + io.read_until(b'\n', buf).await.map_err(parse_io_error)?; if buf.last().copied() != Some(b'\n') { - return Err(ErrorKind::UnexpectedEof.into()); + return Err(Error::new( + ErrorKind::ContentIncomplete, + "unexpected eof, the response must be incomplete", + )); } Ok(&buf[..]) } + + async fn read_header(&mut self) -> Result<&str> { + let header = self.read_line().await?; + let header = std::str::from_utf8(header).map_err(|err| { + Error::new(ErrorKind::Unexpected, "invalid data received").set_source(err) + })?; + + Ok(header) + } } diff --git a/core/src/services/memcached/backend.rs b/core/src/services/memcached/backend.rs index 3920b239..ee6b30e7 100644 --- a/core/src/services/memcached/backend.rs +++ b/core/src/services/memcached/backend.rs @@ -18,7 +18,6 @@ use std::collections::HashMap; use std::time::Duration; -use async_compat::Compat; use async_trait::async_trait; use bb8::RunError; use tokio::net::TcpStream; @@ -220,7 +219,7 @@ impl Adapter { RunError::TimedOut => { Error::new(ErrorKind::Unexpected, "get connection from pool failed").set_temporary() } - RunError::User(err) => parse_io_error(err), + RunError::User(err) => err, }) } } @@ -243,12 +242,8 @@ impl kv::Adapter for Adapter { async fn get(&self, key: &str) -> Result<Option<Vec<u8>>> { let mut conn = self.conn().await?; - // TODO: memcache-async have `Sized` limit on key, can we remove it? 
- match conn.get(&percent_encode_path(key)).await { - Ok(bs) => Ok(Some(bs)), - Err(err) if err.kind() == std::io::ErrorKind::NotFound => Ok(None), - Err(err) => Err(parse_io_error(err)), - } + + conn.get(&percent_encode_path(key)).await } async fn set(&self, key: &str, value: &[u8]) -> Result<()> { @@ -263,40 +258,13 @@ impl kv::Adapter for Adapter { .unwrap_or_default(), ) .await - .map_err(parse_io_error)?; - - Ok(()) } async fn delete(&self, key: &str) -> Result<()> { let mut conn = self.conn().await?; - let _: () = conn - .delete(&percent_encode_path(key)) - .await - .map_err(parse_io_error)?; - Ok(()) - } -} - -fn parse_io_error(err: std::io::Error) -> Error { - use std::io::ErrorKind::*; - - let (kind, retryable) = match err.kind() { - NotFound => (ErrorKind::NotFound, false), - AlreadyExists => (ErrorKind::NotFound, false), - PermissionDenied => (ErrorKind::PermissionDenied, false), - Interrupted | UnexpectedEof | TimedOut | WouldBlock => (ErrorKind::Unexpected, true), - _ => (ErrorKind::Unexpected, true), - }; - - let mut err = Error::new(kind, &err.kind().to_string()).set_source(err); - - if retryable { - err = err.set_temporary(); + conn.delete(&percent_encode_path(key)).await } - - err } /// A `bb8::ManageConnection` for `memcache_async::ascii::Protocol`. @@ -317,13 +285,15 @@ impl MemcacheConnectionManager { #[async_trait] impl bb8::ManageConnection for MemcacheConnectionManager { - type Connection = ascii::Protocol<Compat<TcpStream>>; - type Error = std::io::Error; + type Connection = ascii::Connection; + type Error = Error; /// TODO: Implement unix stream support. 
async fn connect(&self) -> std::result::Result<Self::Connection, Self::Error> { - let sock = TcpStream::connect(&self.address).await?; - Ok(ascii::Protocol::new(Compat::new(sock))) + let conn = TcpStream::connect(&self.address) + .await + .map_err(parse_io_error)?; + Ok(ascii::Connection::new(conn)) } async fn is_valid(&self, conn: &mut Self::Connection) -> std::result::Result<(), Self::Error> { @@ -334,3 +304,7 @@ impl bb8::ManageConnection for MemcacheConnectionManager { false } } + +pub fn parse_io_error(err: std::io::Error) -> Error { + Error::new(ErrorKind::Unexpected, &err.kind().to_string()).set_source(err) +}
