GitHub user sohami commented on a diff in the pull request:

    https://github.com/apache/drill/pull/950#discussion_r141230974
  
    --- Diff: contrib/native/client/src/clientlib/channel.cpp ---
    @@ -0,0 +1,452 @@
    +/*
    + * Licensed to the Apache Software Foundation (ASF) under one
    + * or more contributor license agreements.  See the NOTICE file
    + * distributed with this work for additional information
    + * regarding copyright ownership.  The ASF licenses this file
    + * to you under the Apache License, Version 2.0 (the
    + * "License"); you may not use this file except in compliance
    + * with the License.  You may obtain a copy of the License at
    + *
    + * http://www.apache.org/licenses/LICENSE-2.0
    + *
    + * Unless required by applicable law or agreed to in writing, software
    + * distributed under the License is distributed on an "AS IS" BASIS,
    + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    + * See the License for the specific language governing permissions and
    + * limitations under the License.
    + */
    +
    +#include <cstring>
    +
    +#include <boost/lexical_cast.hpp>
    +#include <boost/regex.hpp>
    +
    +#include "drill/drillConfig.hpp"
    +#include "drill/drillError.hpp"
    +#include "drill/userProperties.hpp"
    +#include "channel.hpp"
    +#include "errmsgs.hpp"
    +#include "logger.hpp"
    +#include "utils.hpp"
    +#include "zookeeperClient.hpp"
    +
    +#include "GeneralRPC.pb.h"
    +
    +namespace Drill{
    +
    +// Endpoint described by a connect string ("<protocol>=<host>:<port>[/<path>]").
    +// Host and port are resolved later by getDrillbitEndpoint().
    +ConnectionEndpoint::ConnectionEndpoint(const char* connStr){
    +    m_pError=NULL;
    +    m_connectString=connStr;
    +}
    +
    +// Endpoint given directly as a host/port pair; it is treated as a
    +// "drillbit" (direct) connection.
    +ConnectionEndpoint::ConnectionEndpoint(const char* host, const char* port){
    +    m_pError=NULL;
    +    m_protocol="drillbit"; // direct connection
    +    m_host=host;
    +    m_port=port;
    +}
    +
    +// Releases the last recorded error, if any (deleting NULL is a no-op).
    +ConnectionEndpoint::~ConnectionEndpoint(){
    +    delete m_pError;
    +    m_pError=NULL;
    +}
    +
    +// Works out which drillbit host/port to connect to.
    +// - Without a connect string, the host/port given at construction must both be set.
    +// - With a connect string, it is parsed. "zk" connect strings are resolved through
    +//   ZooKeeper, and "local"/"drillbit" strings are used as-is. Any other protocol is
    +//   rejected with CONN_INVALID_INPUT.
    +// On failure the error is recorded via handleError() and its status is returned.
    +connectionStatus_t ConnectionEndpoint::getDrillbitEndpoint(){
    +    if(m_connectString.empty()){
    +        if(m_host.empty() || m_port.empty()){
    +            return handleError(CONN_INVALID_INPUT, getMessage(ERR_CONN_NOCONNSTR));
    +        }
    +        return CONN_SUCCESS;
    +    }
    +
    +    parseConnectString();
    +    if(m_protocol.empty()){
    +        return handleError(CONN_INVALID_INPUT,
    +                getMessage(ERR_CONN_UNKPROTO, "<invalid_string>"));
    +    }
    +    if(isZookeeperConnection()){
    +        // Host and port only become known after querying ZooKeeper.
    +        return getDrillbitEndpointFromZk();
    +    }
    +    if(!isDirectConnection()){
    +        return handleError(CONN_INVALID_INPUT,
    +                getMessage(ERR_CONN_UNKPROTO, getProtocol().c_str()));
    +    }
    +    return CONN_SUCCESS;
    +}
    +
    +// Splits m_connectString of the form "<protocol>=<host>:<port>[/<path>]".
    +// - Sets m_protocol, m_hostPortStr ("<host>:<port>") and, if a path is present,
    +//   m_pathToDrill ("/<path>").
    +// - For direct connections ("local"/"drillbit"), m_host and m_port are set here too.
    +//   For ZooKeeper they are only known after querying ZooKeeper.
    +// If the regex does not match, nothing is assigned; only a debug message is logged.
    +void ConnectionEndpoint::parseConnectString(){
    +    boost::regex connStrExpr("(.*)=(.*):([0-9]+)(?:/(.+))?");
    +    boost::cmatch matched;
    +
    +    if(!boost::regex_match(m_connectString.c_str(), matched, connStrExpr)){
    +        DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG) << "Invalid connect string. Regexp did not match" << std::endl;)
    +        return;
    +    }
    +
    +    m_protocol.assign(matched[1].first, matched[1].second);
    +    std::string host(matched[2].first, matched[2].second);
    +    std::string port(matched[3].first, matched[3].second);
    +    // isDirectConnection() asserts that the protocol is non-empty, but a string
    +    // such as "=host:31010" matches with an empty protocol. Skip the check in that
    +    // case so getDrillbitEndpoint() can report it as an invalid protocol instead of
    +    // aborting.
    +    if(!m_protocol.empty() && isDirectConnection()){
    +        // if the connection is to a zookeeper,
    +        // we will get the host and the port only after connecting to the Zookeeper
    +        m_host=host;
    +        m_port=port;
    +    }
    +    m_hostPortStr=host+std::string(":")+port;
    +    // Group 4 (the path) is optional. When it participates it is non-empty
    +    // because of "(.+)". Note that matched.size() is always 5 for this regex,
    +    // so matched[4].matched is the meaningful test.
    +    if(matched[4].matched){
    +        m_pathToDrill=std::string("/")+std::string(matched[4].first, matched[4].second);
    +    }
    +    DRILL_MT_LOG(DRILL_LOG(LOG_DEBUG)
    +            << "Conn str: "<< m_connectString
    +            << ";  protocol: " << m_protocol
    +            << ";  host: " << host
    +            << "; port: " << port
    +            << ";  path to drill: " << m_pathToDrill
    +            << std::endl;)
    +}
    +
    +// True for the "local" and "drillbit" protocols: the connect string itself
    +// carries the drillbit's host:port. Requires a non-empty protocol.
    +bool ConnectionEndpoint::isDirectConnection(){
    +    assert(!m_protocol.empty());
    +    return m_protocol == "local" || m_protocol == "drillbit";
    +}
    +
    +// True for the "zk" protocol: the drillbit is looked up through ZooKeeper.
    +// Requires a non-empty protocol.
    +bool ConnectionEndpoint::isZookeeperConnection(){
    +    assert(!m_protocol.empty());
    +    return m_protocol == "zk";
    +}
    +
    +// Asks ZooKeeper (quorum m_hostPortStr, root m_pathToDrill) for the registered
    +// drillbits and picks one after shuffling the list. On success, m_host/m_port are
    +// set to the chosen drillbit's address and user port.
    +// Error statuses returned:
    +// - CONN_ZOOKEEPER_ERROR if the drillbit list or the chosen endpoint cannot be read.
    +// - CONN_FAILURE if no drillbits are registered.
    +connectionStatus_t ConnectionEndpoint::getDrillbitEndpointFromZk(){
    +    ZookeeperClient zook(m_pathToDrill);
    +    assert(!m_hostPortStr.empty());
    +    std::vector<std::string> drillbits;
    +    if(zook.getAllDrillbits(m_hostPortStr.c_str(), drillbits)!=0){
    +        return handleError(CONN_ZOOKEEPER_ERROR,
    +                getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str()));
    +    }
    +    if (drillbits.empty()){
    +        zook.close();
    +        return handleError(CONN_FAILURE, getMessage(ERR_CONN_ZKNODBIT));
    +    }
    +    // Randomize the order, then take the last entry (presumably to spread
    +    // clients across drillbits).
    +    Utils::shuffle(drillbits);
    +    exec::DrillbitEndpoint endpoint;
    +    int err = zook.getEndPoint(drillbits[drillbits.size() -1], endpoint);// get the last one in the list
    +    if(err){
    +        // Previously this fell through and reported CONN_SUCCESS with host/port
    +        // never assigned, deferring the failure to name resolution.
    +        // NOTE(review): assumes getEndPoint records its failure in getError() - confirm.
    +        connectionStatus_t status = handleError(CONN_ZOOKEEPER_ERROR,
    +                getMessage(ERR_CONN_ZOOKEEPER, zook.getError().c_str()));
    +        zook.close();
    +        return status;
    +    }
    +    m_host=boost::lexical_cast<std::string>(endpoint.address());
    +    m_port=boost::lexical_cast<std::string>(endpoint.user_port());
    +    DRILL_MT_LOG(DRILL_LOG(LOG_TRACE) << "Choosing drillbit <" << (drillbits.size() - 1)  << ">. Selected " << endpoint.DebugString() << std::endl;)
    +    zook.close();
    +    return CONN_SUCCESS;
    +}
    +
    +// Records a new DrillClientError for `status`/`msg`, replacing (and freeing)
    +// any previously recorded one, and returns `status` for convenient chaining.
    +connectionStatus_t ConnectionEndpoint::handleError(connectionStatus_t status, std::string msg){
    +    // Build the replacement first, then free the old error (delete NULL is a no-op).
    +    DrillClientError* pNewError =
    +        new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
    +    delete m_pError;
    +    m_pError=pNewError;
    +    return status;
    +}
    +
    +/****************************
    + * Channel Context Factory
    + ****************************/
    +// Creates the channel context matching channel type `t`. The caller owns the
    +// returned object. Returns NULL (after logging) for unsupported types. SSL
    +// contexts exist only when built with IS_SSL_ENABLED.
    +ChannelContext* ChannelContextFactory::getChannelContext(channelType_t t, DrillUserProperties* props){
    +    switch(t){
    +        case CHANNEL_TYPE_SOCKET:
    +            return new ChannelContext(props);
    +#if defined(IS_SSL_ENABLED)
    +        case CHANNEL_TYPE_SSLSTREAM: {
    +            // TLS protocol requested through the user properties.
    +            std::string tlsProtocolName;
    +            props->getProp(USERPROP_TLSPROTOCOL, tlsProtocolName);
    +            boost::asio::ssl::context::method tlsMethod =
    +                SSLChannelContext::getTlsVersion(tlsProtocolName);
    +
    +            // Peers are verified unless certificate verification is disabled.
    +            std::string certVerifyDisabled;
    +            props->getProp(USERPROP_DISABLE_CERTVERIFICATION, certVerifyDisabled);
    +            boost::asio::ssl::context::verify_mode peerVerifyMode =
    +                boost::asio::ssl::context::verify_peer;
    +            if (certVerifyDisabled == "true") {
    +                peerVerifyMode = boost::asio::ssl::context::verify_none;
    +            }
    +
    +            return new SSLChannelContext(props, tlsMethod, peerVerifyMode);
    +        }
    +#endif
    +        default:
    +            DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
    +            return NULL;
    +    }
    +}
    +
    +/*******************
    + *  ChannelFactory
    + * *****************/
    +// The ChannelFactory::getChannel overloads build a socket or (with
    +// IS_SSL_ENABLED) SSL-stream channel. Each overload forwards its remaining
    +// arguments to the chosen channel's constructor. The caller owns the result.
    +// Unsupported channel types are logged and yield NULL.
    +
    +// Channel described by a connect string, running on its own io_service.
    +Channel* ChannelFactory::getChannel(channelType_t t, const char* connStr){
    +    switch(t){
    +        case CHANNEL_TYPE_SOCKET:
    +            return new SocketChannel(connStr);
    +#if defined(IS_SSL_ENABLED)
    +        case CHANNEL_TYPE_SSLSTREAM:
    +            return new SSLStreamChannel(connStr);
    +#endif
    +        default:
    +            DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
    +            return NULL;
    +    }
    +}
    +
    +// Channel to an explicit host/port, running on its own io_service.
    +Channel* ChannelFactory::getChannel(channelType_t t, const char* host, const char* port){
    +    switch(t){
    +        case CHANNEL_TYPE_SOCKET:
    +            return new SocketChannel(host, port);
    +#if defined(IS_SSL_ENABLED)
    +        case CHANNEL_TYPE_SSLSTREAM:
    +            return new SSLStreamChannel(host, port);
    +#endif
    +        default:
    +            DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
    +            return NULL;
    +    }
    +}
    +
    +// Channel described by a connect string, running on the caller's io_service.
    +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* connStr){
    +    switch(t){
    +        case CHANNEL_TYPE_SOCKET:
    +            return new SocketChannel(ioService, connStr);
    +#if defined(IS_SSL_ENABLED)
    +        case CHANNEL_TYPE_SSLSTREAM:
    +            return new SSLStreamChannel(ioService, connStr);
    +#endif
    +        default:
    +            DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
    +            return NULL;
    +    }
    +}
    +
    +// Channel to an explicit host/port, running on the caller's io_service.
    +Channel* ChannelFactory::getChannel(channelType_t t, boost::asio::io_service& ioService, const char* host, const char* port){
    +    switch(t){
    +        case CHANNEL_TYPE_SOCKET:
    +            return new SocketChannel(ioService, host, port);
    +#if defined(IS_SSL_ENABLED)
    +        case CHANNEL_TYPE_SSLSTREAM:
    +            return new SSLStreamChannel(ioService, host, port);
    +#endif
    +        default:
    +            DRILL_LOG(LOG_ERROR) << "Channel type " << t << " is not supported." << std::endl;
    +            return NULL;
    +    }
    +}
    +
    +/*******************
    + *  Channel
    + * *****************/
    +
    +// Connect-string channel without a caller-supplied io_service: it runs on the
    +// embedded fallback io_service and is marked as owning it.
    +Channel::Channel(const char* connStr) : m_ioService(m_ioServiceFallback){
    +    m_ownIoService = true;
    +    m_pEndpoint=new ConnectionEndpoint(connStr);
    +    m_pSocket=NULL;
    +    m_pError=NULL;
    +    m_state=CHANNEL_UNINITIALIZED;
    +}
    +
    +// Host/port channel without a caller-supplied io_service: it runs on the
    +// embedded fallback io_service and is marked as owning it.
    +Channel::Channel(const char* host, const char* port) : m_ioService(m_ioServiceFallback){
    +    m_ownIoService = true;
    +    m_pEndpoint=new ConnectionEndpoint(host, port);
    +    m_pSocket=NULL;
    +    m_pError=NULL;
    +    m_state=CHANNEL_UNINITIALIZED;
    +}
    +
    +// Connect-string channel running on the caller's io_service (not owned).
    +Channel::Channel(boost::asio::io_service& ioService, const char* connStr):m_ioService(ioService){
    +    m_ownIoService = false;
    +    m_pEndpoint=new ConnectionEndpoint(connStr);
    +    m_pSocket=NULL;
    +    m_pError=NULL;
    +    m_state=CHANNEL_UNINITIALIZED;
    +}
    +
    +// Host/port channel running on the caller's io_service.
    +// The io_service belongs to the caller, so m_ownIoService is false. The original
    +// set it to true, which disagreed with the io_service + connect-string constructor.
    +Channel::Channel(boost::asio::io_service& ioService, const char* host, const char* port) : m_ioService(ioService){
    +    m_pEndpoint=new ConnectionEndpoint(host, port);
    +    m_ownIoService = false;
    +    m_pSocket=NULL;
    +    m_state=CHANNEL_UNINITIALIZED;
    +    m_pError=NULL;
    +}
    +
    +Channel::~Channel(){
    +    if(m_pEndpoint!=NULL){
    +        delete m_pEndpoint; m_pEndpoint=NULL;
    +    }
    +    if(m_pSocket!=NULL){
    +        delete m_pSocket; m_pSocket=NULL;
    +    }
    +    if(m_pError!=NULL){
    +        delete m_pError; m_pError=NULL;
    +    }
    +}
    +
    +// Placeholder for setting arbitrary socket options; it is not implemented and
    +// asserts if ever reached. The options the channel needs (keep-alive, no-delay,
    +// reuse-address) are applied in connectInternal() once connected.
    +template <typename SettableSocketOption> void Channel::setOption(SettableSocketOption& option){
    +    (void)option; // unused until this is implemented
    +    assert(0);
    +}
    +
    +// Stores the channel context (not owned here) and marks the channel as
    +// initialized so connect() may proceed. Always returns CONN_SUCCESS.
    +connectionStatus_t Channel::init(ChannelContext_t* pContext){
    +    m_pContext = pContext;
    +    m_state = CHANNEL_INITIALIZED;
    +    return CONN_SUCCESS;
    +}
    +
    +// Works out the drillbit endpoint (possibly via ZooKeeper) and connects to it.
    +// - Only allowed after init(). Otherwise it returns CONN_FAILURE without
    +//   recording an error.
    +// - If the endpoint lookup fails, the endpoint's error is copied into this channel.
    +// - The state becomes CHANNEL_CONNECTED only on success.
    +connectionStatus_t Channel::connect(){
    +    if(m_state!=CHANNEL_INITIALIZED){
    +        return CONN_FAILURE;
    +    }
    +
    +    connectionStatus_t status=m_pEndpoint->getDrillbitEndpoint();
    +    if(status!=CONN_SUCCESS){
    +        handleError(status, m_pEndpoint->getError()->msg);
    +        return status;
    +    }
    +
    +    DRILL_LOG(LOG_TRACE) << "Connecting to drillbit: "
    +        << m_pEndpoint->getHost()
    +        << ":" << m_pEndpoint->getPort()
    +        << "." << std::endl;
    +    status=connectInternal();
    +    if(status==CONN_SUCCESS){
    +        m_state=CHANNEL_CONNECTED;
    +    }
    +    return status;
    +}
    +
    +// Records a new DrillClientError for `status`/`msg`, replacing (and freeing)
    +// any previously recorded one, and returns `status` for convenient chaining.
    +connectionStatus_t Channel::handleError(connectionStatus_t status, std::string msg){
    +    // Build the replacement first, then free the old error (delete NULL is a no-op).
    +    DrillClientError* pNewError =
    +        new DrillClientError(status, DrillClientError::CONN_ERROR_START+status, msg);
    +    delete m_pError;
    +    m_pError=pNewError;
    +    return status;
    +}
    +
    +// Connects the channel to its drillbit:
    +// 1. Resolves the endpoint's host/port (IPv4 only) and connects the inner socket.
    +// 2. Enables keep-alive, TCP no-delay and address reuse.
    +// 3. Runs protocolHandshake(), passing whether the system trust store should be used.
    +// Returns:
    +// - CONN_HOSTNAME_RESOLUTION_ERROR if name resolution throws.
    +// - CONN_FAILURE for other exceptions or a failed connect.
    +// - Otherwise, the handshake's status.
    +// NOTE(review): set_option() calls are outside the try block and throw on failure.
    +connectionStatus_t Channel::connectInternal() {
    +    using boost::asio::ip::tcp;
    +    tcp::endpoint endpoint;
    +    // Copy host/port instead of holding c_str() pointers into the getters'
    +    // results, which would dangle if the getters ever return by value.
    +    const std::string host = m_pEndpoint->getHost();
    +    const std::string port = m_pEndpoint->getPort();
    +    try {
    +        tcp::resolver resolver(m_ioService);
    +        tcp::resolver::query query(tcp::v4(), host, port);
    +        tcp::resolver::iterator iter = resolver.resolve(query);
    +        tcp::resolver::iterator end;
    +        // NOTE(review): every resolved address is logged, but only the last one
    +        // is used for the connection attempt.
    +        while (iter != end) {
    +            endpoint = *iter++;
    +            DRILL_LOG(LOG_TRACE) << endpoint << std::endl;
    +        }
    +        boost::system::error_code ec;
    +        m_pSocket->getInnerSocket().connect(endpoint, ec);
    +        if (ec) {
    +            return handleError(CONN_FAILURE, getMessage(ERR_CONN_FAILURE,
    +                    host.c_str(), port.c_str(), ec.message().c_str()));
    +        }
    +    } catch (const std::exception& e) {
    +        // Catch by reference. Catching by value slices boost::system::system_error
    +        // down to std::exception and can lose its what() text.
    +        // The Boost.Asio resolver throws with the location "resolve", and
    +        // system_error::what() reads "resolve: <message>". So match that prefix;
    +        // the previous exact strcmp() against "resolve" could never succeed.
    +        const char kResolveLocation[] = "resolve";
    +        if (!strncmp(e.what(), kResolveLocation, sizeof(kResolveLocation) - 1)) {
    +            return handleError(CONN_HOSTNAME_RESOLUTION_ERROR, getMessage(ERR_CONN_EXCEPT, e.what()));
    +        }
    +        return handleError(CONN_FAILURE, getMessage(ERR_CONN_EXCEPT, e.what()));
    +    }
    +
    +    // set socket keep alive
    +    boost::asio::socket_base::keep_alive keepAlive(true);
    +    m_pSocket->getInnerSocket().set_option(keepAlive);
    +    // set no_delay
    +    boost::asio::ip::tcp::no_delay noDelay(true);
    +    m_pSocket->getInnerSocket().set_option(noDelay);
    +    // set reuse addr
    +    boost::asio::socket_base::reuse_address reuseAddr(true);
    +    m_pSocket->getInnerSocket().set_option(reuseAddr);
    +
    +    std::string useSystemTrustStore;
    +    m_pContext->getUserProperties()->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore);
    +
    +    return this->protocolHandshake(useSystemTrustStore=="true");
    +}
    +
    +// Allocates the plain TCP socket on m_ioService, then performs the common
    +// Channel::init(). A NULL socket is logged and reported as CONN_FAILURE,
    +// with a CONN_NOSOCKET error recorded.
    +connectionStatus_t SocketChannel::init(ChannelContext_t* pContext){
    +    m_pSocket=new Socket(m_ioService);
    +    if(m_pSocket==NULL){
    +        DRILL_LOG(LOG_ERROR) << "Channel initialization failure. " << std::endl;
    +        handleError(CONN_NOSOCKET, getMessage(ERR_CONN_NOSOCKET));
    +        return CONN_FAILURE;
    +    }
    +    return Channel::init(pContext);
    +}
    +
    +#if defined(IS_SSL_ENABLED)
    +connectionStatus_t SSLStreamChannel::init(ChannelContext_t* pContext){
    +    connectionStatus_t ret=CONN_SUCCESS;
    +
    +    const DrillUserProperties* props = pContext->getUserProperties();
    +   std::string useSystemTrustStore;
    +   props->getProp(USERPROP_USESYSTEMTRUSTSTORE, useSystemTrustStore);
    +   if (useSystemTrustStore != "true"){
    +           std::string certFile;
    +           props->getProp(USERPROP_CERTFILEPATH, certFile);
    +           try{
    +                   
((SSLChannelContext_t*)pContext)->getSslContext().load_verify_file(certFile);
    +           }
    +           catch (boost::system::system_error e){
    +                   DRILL_LOG(LOG_ERROR) << "Channel initialization 
failure. Certificate file  "
    +                           << certFile
    +                           << " could not be loaded."
    +                           << std::endl;
    +                   handleError(CONN_SSLERROR, 
getMessage(ERR_CONN_SSLCERTFAIL, certFile.c_str(), e.what()));
    +                   ret = CONN_FAILURE;
    +           }
    +   }
    +
    +    std::string disableHostVerification;
    +    props->getProp(USERPROP_DISABLE_HOSTVERIFICATION, 
disableHostVerification);
    +    if (disableHostVerification != "true") {
    +        std::string hostPortStr = m_pEndpoint->getHost() + ":" + 
m_pEndpoint->getPort();
    +        ((SSLChannelContext_t *) 
pContext)->getSslContext().set_verify_callback(
    +                
boost::asio::ssl::rfc2818_verification(hostPortStr.c_str()));
    +    }
    +
    +    std::string disableCertificateVerification;
    +    props->getProp(USERPROP_DISABLE_CERTVERIFICATION, 
disableCertificateVerification);
    +    if (disableCertificateVerification == "true") {
    +        ((SSLChannelContext_t *) 
pContext)->getSslContext().set_verify_mode(boost::asio::ssl::context::verify_none);
    +    }
    --- End diff --
    
    Setting the `verifyMode` is duplicated inside the `getChannelContext(..)` call,
where we create an `SSLChannelContext` object. How about moving all of this
`sslContext` setup into the `SSLChannelContext` constructor? We have access to the
`props` object there.


---

Reply via email to