Attaching a slightly modified sample that reproduces the problem (the
previous one sometimes did not work).

It can be built with:
g++ -o dtlstest main.cpp -std=c++11 -lssl -lcrypto -lpthread -g


On Wed, Dec 18, 2013 at 3:06 PM, Dmitry Sobinov via RT <r...@openssl.org> wrote:

> Got some more info on this bug. It's a use-after-free.
>
> The problem is with ssl_st::write_hash. It is cached by the
> dtls1_buffer_message() function for each handshake message, and is then
> freed and replaced by a new hash context when the Change Cipher Spec
> message is formed (in ssl_replace_hash(); see the stack trace below). So
> when we want to resend a lost packet, the old (already freed) hash context
> is used to create MACs for the messages being resent, which leads to a
> crash on Win32 and undefined behavior on Linux.
>
> It's not a problem for the initial handshake, because s->write_hash is NULL
> and nothing gets destroyed when the Change Cipher Spec is formed.
>
> Any ideas on how to work around or fix this are appreciated.
>
> Here is the report from clang AddressSanitizer:
>
> ==32258==ERROR: AddressSanitizer: heap-use-after-free on address 0x60400000b890 at pc 0xa541d0 bp 0x7f9225188c50 sp 0x7f9225188c48
> READ of size 8 at 0x60400000b890 thread T1
>     #0 0xa541cf in EVP_MD_CTX_md evp_lib.c:285
>     #1 0x66e790 in dtls1_do_write d1_both.c:275
>     #2 0x682231 in dtls1_retransmit_message d1_both.c:1290
>     #3 0x67fe96 in dtls1_retransmit_buffered_messages d1_both.c:1145
>     #4 0x64f934 in dtls1_handle_timeout d1_lib.c:451
>     #5 0x65b6d4 in dtls1_read_bytes d1_pkt.c:832
>     #6 0x67640a in dtls1_get_message_fragment d1_both.c:789
>     #7 0x674459 in dtls1_get_message d1_both.c:437
>     #8 0x79835b in ssl3_get_new_session_ticket s3_clnt.c:2040
>     #9 0x63a5c1 in dtls1_connect d1_clnt.c:641
>     #10 0x6bef29 in SSL_do_handshake ssl_lib.c:2564
>     #11 0x4e11aa in _ZN17DtlsSrtpTransport18handshakeIterationEv test.cpp:348
>     #12 0x5070b2 in _ZN17DtlsSrtpTransport19receiveTimerExpiredEv test.cpp:423
>     #13 0x505128 in operator() test.cpp:392
>     #14 0x504951 in _ZNSt3__18__invokeIRZN17DtlsSrtpTransport18handshakeIterationEvEUlvE_JEEEDTclclsr3std3__1E7forwardIT_Efp_Espclsr3std3__1E7forwardIT0_Efp0_EEEOS4_DpOS5_ __functional_base:341
>     #15 0x6069e9 in _ZNKSt3__18functionIFvvEEclEv functional:1436
>     #16 0x603c64 in _ZN15AsyncDispatcher3runEv test.cpp:170
>     #17 0x603928 in operator() test.cpp:88
>     #18 0x6021c5 in _ZNSt3__17forwardIZN15AsyncDispatcherC1EvEUlvE_EEOT_RNS_16remove_referenceIS3_E4typeE type_traits:1341
>     #19 0x460083 in _ZN6__asan10AsanThread11ThreadStartEm ??:?
>     #20 0x7f922847e0a1 in start_thread pthread_create.c:?
>     #21 0x7f922788b49c in __clone ??:?
> 0x60400000b890 is located 0 bytes inside of 48-byte region [0x60400000b890,0x60400000b8c0)
> freed by thread T1 here:
>     #0 0x459a04 in __interceptor_free ??:?
>     #1 0x897bdf in CRYPTO_free mem.c:397
>     #2 0x9f97cf in EVP_MD_CTX_destroy digest.c:370
>     #3 0x693983 in ssl_clear_hash_ctx ssl_lib.c:3244
>     #4 0x6ca125 in ssl_replace_hash ssl_lib.c:3236
>     #5 0x83e213 in tls1_change_cipher_state t1_enc.c:425
>     #6 0x639971 in dtls1_connect d1_clnt.c:560
>     #7 0x6bef29 in SSL_do_handshake ssl_lib.c:2564
>     #8 0x4e11aa in _ZN17DtlsSrtpTransport18handshakeIterationEv test.cpp:348
>     #9 0x4d7d98 in operator() test.cpp:219
>     #10 0x4d52c1 in _ZNSt3__18__invokeIRZN17DtlsSrtpTransport18handleIncomingDataERKNS_6vectorIhNS_9allocatorIhEEEEEUlvE_JEEEDTclclsr3std3__1E7forwardIT_Efp_Espclsr3std3__1E7forwardIT0_Efp0_EEEOSA_DpOSB_ __functional_base:341
>     #11 0x6069e9 in _ZNKSt3__18functionIFvvEEclEv functional:1436
>     #12 0x603c64 in _ZN15AsyncDispatcher3runEv test.cpp:170
>     #13 0x603928 in operator() test.cpp:88
>     #14 0x6021c5 in _ZNSt3__17forwardIZN15AsyncDispatcherC1EvEUlvE_EEOT_RNS_16remove_referenceIS3_E4typeE type_traits:1341
>     #15 0x460083 in _ZN6__asan10AsanThread11ThreadStartEm ??:?
> previously allocated by thread T1 here:
>     #0 0x459ae4 in __interceptor_malloc ??:?
>     #1 0x89228c in default_malloc_ex mem.c:79
>     #2 0x895a5c in CRYPTO_malloc mem.c:308
>     #3 0x9f4515 in EVP_MD_CTX_create digest.c:131
>     #4 0x6ca12a in ssl_replace_hash ssl_lib.c:3237
>     #5 0x83e213 in tls1_change_cipher_state t1_enc.c:425
>     #6 0x639971 in dtls1_connect d1_clnt.c:560
>     #7 0x6bef29 in SSL_do_handshake ssl_lib.c:2564
>     #8 0x4e11aa in _ZN17DtlsSrtpTransport18handshakeIterationEv test.cpp:348
>     #9 0x4d7d98 in operator() test.cpp:219
>     #10 0x4d52c1 in _ZNSt3__18__invokeIRZN17DtlsSrtpTransport18handleIncomingDataERKNS_6vectorIhNS_9allocatorIhEEEEEUlvE_JEEEDTclclsr3std3__1E7forwardIT_Efp_Espclsr3std3__1E7forwardIT0_Efp0_EEEOSA_DpOSB_ __functional_base:341
>     #11 0x6069e9 in _ZNKSt3__18functionIFvvEEclEv functional:1436
>     #12 0x603c64 in _ZN15AsyncDispatcher3runEv test.cpp:170
>     #13 0x603928 in operator() test.cpp:88
>     #14 0x6021c5 in _ZNSt3__17forwardIZN15AsyncDispatcherC1EvEUlvE_EEOT_RNS_16remove_referenceIS3_E4typeE type_traits:1341
>     #15 0x460083 in _ZN6__asan10AsanThread11ThreadStartEm ??:?
> Thread T1 created by T0 here:
>     #0 0x455a50 in __interceptor_pthread_create ??:?
>     #1 0x5ff729 in thread<<lambda at test.cpp:88:31>, , void> thread:355
>     #2 0x5fd7b9 in thread<<lambda at test.cpp:88:31>, , void> thread:360
>     #3 0x5fd068 in AsyncDispatcher test.cpp:88
>     #4 0x4a3e3c in AsyncDispatcher test.cpp:89
>     #5 0x4688f7 in main test.cpp:516
>     #6 0x7f92277c7bc4 in __libc_start_main ??:?
>
>
>
> On Fri, Dec 13, 2013 at 5:55 PM, Dmitry Sobinov via RT <r...@openssl.org> wrote:
>
> > Hello
> >
> > While testing DTLS-SRTP renegotiation, I found a crash on Windows.
> > The OpenSSL version is 1.0.1e; I also tested the latest 1.0.1 snapshot.
> > There were two possible stack traces:
> >
> >   AddLiveService.dll!EVP_MD_size(const env_md_st * md) Line 273 C
> > > AddLiveService.dll!dtls1_do_write(ssl_st * s, int type) Line 275 C
> >   AddLiveService.dll!dtls1_retransmit_message(ssl_st * s, unsigned short seq, unsigned long frag_off, int * found) Line 1293 C
> >   AddLiveService.dll!dtls1_retransmit_buffered_messages(ssl_st * s) Line 1145 C
> >   AddLiveService.dll!dtls1_handle_timeout(ssl_st * s) Line 450 C
> >   AddLiveService.dll!dtls1_read_bytes(ssl_st * s, int type, unsigned char * buf, int len, int peek) Line 832 C
> >   AddLiveService.dll!dtls1_get_message_fragment(ssl_st * s, int st1, int stn, long max, int * ok) Line 789 C
> >   AddLiveService.dll!dtls1_get_message(ssl_st * s, int st1, int stn, int mt, long max, int * ok) Line 436 C
> >   AddLiveService.dll!ssl3_get_new_session_ticket(ssl_st * s) Line 2046 C
> >   AddLiveService.dll!dtls1_connect(ssl_st * s) Line 631 C
> >   AddLiveService.dll!SSL_do_handshake(ssl_st * s) Line 2562 C
> >
> > and
> >
> >   msvcr120d.dll!memcpy(unsigned char * dst, unsigned char * src, unsigned long count) Line 188 Unknown
> > > dtls_test.exe!dtls1_get_message_fragment(ssl_st * s, int st1, int stn, long max, int * ok) Line 789 C
> >   dtls_test.exe!dtls1_get_message(ssl_st * s, int st1, int stn, int mt, long max, int * ok) Line 436 C
> >   dtls_test.exe!ssl3_get_new_session_ticket(ssl_st * s) Line 2046 C
> >   dtls_test.exe!dtls1_connect(ssl_st * s) Line 631 C
> >   dtls_test.exe!SSL_do_handshake(ssl_st * s) Line 2562 C
> >
> > Both are segfaults (access violations). On Linux the renegotiation
> > handshake doesn't finish at all (it fails on timeout after 1-2 minutes).
> >
> > Attached is a sample C++11 source file that reproduces this issue.
> > An in-memory BIO pair is used, with the client and server in the same
> > process. When no flights are dropped, everything works fine.
> >
> > The sample can be compiled with MSVC 2013 on Windows, with g++ 4.7+
> > (g++ -o dtlstest main.cpp -std=c++11 -lssl -lcrypto -lpthread -g), or
> > with clang 3.2+.
> >
> >
> > ---
> > Dmitry Sobinov
> > AddLive.com
> > Live video and voice for your application
> >
> >
>
>
> ---
> Dmitry Sobinov
> AddLive.com
> Live video and voice for your application
>
> ______________________________________________________________________
> OpenSSL Project                                 http://www.openssl.org
> Development Mailing List                       openssl-dev@openssl.org
> Automated List Manager                           majord...@openssl.org
>
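The ownership problem described in the quoted analysis can be illustrated with a small standalone C++11 model. This is a sketch only, not OpenSSL code: `HashCtx`, `BufferedMessage` and `Connection` are hypothetical stand-ins for `EVP_MD_CTX`, a message buffered by `dtls1_buffer_message()`, and `ssl_st`. In the report, the buffered message keeps a plain copy of the `write_hash` pointer, so destroying the context in `ssl_replace_hash()` leaves that copy dangling. The model assumes shared ownership (`std::shared_ptr`) instead, so a context retired at Change Cipher Spec stays alive for as long as a buffered message may still need it for retransmission. It shows one way to reason about the lifetime, not a proposed patch.

```cpp
#include <cassert>  // used by the usage checks below
#include <memory>
#include <vector>

// Hypothetical stand-in for EVP_MD_CTX; only records which context it is.
struct HashCtx
{
    int generation;
    explicit HashCtx(int g) : generation(g) {}
};

// Hypothetical stand-in for a handshake message kept for retransmission.
struct BufferedMessage
{
    int seq;
    // Shared ownership instead of a raw pointer copy: the context cannot be
    // destroyed while this message may still be retransmitted.
    std::shared_ptr<HashCtx> savedWriteHash;
};

// Hypothetical stand-in for the relevant parts of ssl_st.
struct Connection
{
    std::shared_ptr<HashCtx> writeHash;      // cf. s->write_hash; empty at first
    std::vector<BufferedMessage> sentQueue;  // buffered handshake messages
    int nextGeneration = 1;

    // cf. dtls1_buffer_message(): remember the current context with the message.
    void bufferMessage(int seq)
    {
        sentQueue.push_back({ seq, writeHash });
    }

    // cf. ssl_replace_hash() at Change Cipher Spec: the connection releases its
    // reference and installs a fresh context. With a raw pointer, this is the
    // point where previously buffered messages would start to dangle.
    void replaceHash()
    {
        writeHash = std::make_shared<HashCtx>(nextGeneration++);
    }

    // Generation of the context a retransmission of `seq` would use:
    // 0 if no context existed when it was buffered, -1 if `seq` is unknown.
    int retransmit(int seq) const
    {
        for (const auto& m : sentQueue)
            if (m.seq == seq)
                return m.savedWriteHash ? m.savedWriteHash->generation : 0;
        return -1;
    }
};
```

Usage: in the initial handshake the buffered message caches an empty context (nothing is freed at the first Change Cipher Spec), while during renegotiation it caches the first context, which survives the second replacement:

```cpp
Connection c;
c.bufferMessage(1);  // initial handshake: writeHash is still empty
c.replaceHash();     // first Change Cipher Spec installs generation 1
c.bufferMessage(2);  // renegotiation: caches generation 1
c.replaceHash();     // second Change Cipher Spec installs generation 2
assert(c.retransmit(2) == 1 && c.writeHash->generation == 2);
```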



-- 
---
Dmitry Sobinov
AddLive.com
Live video and voice for your application
#include <iostream>
#include <string>
#include <mutex>
#include <thread>
#include <condition_variable>
#include <future>
#include <memory>
#include <vector>
#include <deque>
#include <chrono>
#include <algorithm>
#include <functional>
#include <stdint.h>
#include <assert.h>

#include <openssl/ssl.h>
#include <openssl/bio.h>
#include <openssl/err.h>
#include <openssl/x509.h>

// Can be built in MSVC 2013;
// with gcc:
// g++ -o dtlstest dtlstest.cpp -std=c++11 -lssl -lcrypto -lpthread -g -D_DEBUG
// clang with libc++:
// clang++ -o dtlstest dtlstest.cpp -std=c++11 -D_DEBUG -lssl -lcrypto -lpthread -stdlib=libc++ -lc++abi -g


std::chrono::steady_clock::time_point logStartingTime = std::chrono::steady_clock::now();

#define MLOG_D(x) std::cout << std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::steady_clock::now() - logStartingTime).count() << _label << x << std::endl;
#define LOG_E(x) std::cout << "[ERROR] " << x << std::endl;
#define MLOG_E(x) LOG_E(_label << x)

#ifdef X509_NAME
#undef X509_NAME // disable macro from wincrypt.h (included from dtls1.h/winsock.h)
#endif


struct DtlsIdentity
{
    EVP_PKEY* key;
    X509* certificate;
};

namespace
{
    /**
     * Helper functions (defined at the bottom of the file)
     */

    unsigned long idFunction();

    void opensslLockingFunc(int mode, int n,
        const char* /*file*/, int /*line*/);

    void opensslInit();

    void opensslCleanup();

    EVP_PKEY* generateRsaKeyPair();

    X509* generateCertificate(EVP_PKEY* pkey, const char* commonName);

    DtlsIdentity generateIdentity();

    void logOpenSslErrors(const std::string& prefix);
}

typedef std::function<void()> DispatcherTask;

/**
 * Helper class that serializes all requests and data transmissions through a
 * queue processed on one dedicated thread (an implementation of the Active
 * Object pattern).
 */
class AsyncDispatcher
{
    struct TimedTask
    {
        DispatcherTask task;
        std::chrono::steady_clock::time_point timeToFire;
        int id;
    };

public:

    AsyncDispatcher()
    {
        _thread = std::thread([this](){ run(); });
    }

    int push(const DispatcherTask& task,
        std::chrono::milliseconds delay = std::chrono::milliseconds(0))
    {
        std::unique_lock<std::mutex> lk(_queueMutex);
        int id = _idCounter++;
        _queue.push_back({ task, std::chrono::steady_clock::now() + delay, id });
        std::stable_sort(_queue.begin(), _queue.end(),
            [](const TimedTask& t1, const TimedTask& t2) -> bool { return t1.timeToFire < t2.timeToFire; });
        lk.unlock();
        _condVar.notify_one();
        return id;
    }

    void stop()
    {
        std::unique_lock<std::mutex> lk(_queueMutex);
        _active = false;
        _queue.clear();
        lk.unlock();
        _condVar.notify_one();

        _thread.join();
    }

    void cancelTimedTask(int id)
    {
        std::unique_lock<std::mutex> lk(_queueMutex);
        _queue.erase(std::remove_if(_queue.begin(), _queue.end(),
            [=](const TimedTask& elem){ return elem.id == id; }),
            _queue.end());
        lk.unlock();
        _condVar.notify_one();
    }

private:

    bool waitAndPop(TimedTask& poppedValue)
    {
        std::unique_lock<std::mutex> lk(_queueMutex);

        while (true)
        {
            if (_queue.empty())
            {
                // queue is empty and no new tasks may be added => exit
                if (!_active)
                    return false;

                // wait for pushed element if no elements to process right away
                _condVar.wait(lk);
                continue;
            }

            auto nearestExpireTime = _queue.front().timeToFire;
            if (nearestExpireTime >= std::chrono::steady_clock::now())
            {
                _condVar.wait_until(lk, nearestExpireTime);
                continue;
            }

            // have some expired callbacks
            break;
        }

        poppedValue = _queue.front();
        _queue.pop_front();

        return true;
    }

    void run()
    {
        // process tasks until stop() has been called and the queue is empty
        while (true)
        {
            TimedTask rec;
            if (!waitAndPop(rec))
                return;
            rec.task();
        }
    }

    std::thread _thread;
    std::mutex _queueMutex;
    std::condition_variable _condVar;
    std::deque<TimedTask> _queue;

    bool _active = true;
    int _idCounter = 0;
};

enum DtlsRole
{
    DTLS_CLIENT = 0,
    DTLS_SERVER = 1
};

typedef std::function<void(const std::vector<uint8_t>&)> DtlsSendFunc;
typedef std::function<void(bool)> DtlsConnectResultHandler;

class DtlsSrtpTransport
{
public:

    DtlsSrtpTransport(DtlsRole role, const std::string& label,
        AsyncDispatcher& dispatcher) :
        _role(role),
        _label(label),
        _dispatcher(dispatcher)
    {
        auto identity = generateIdentity();
        _certificate = identity.certificate;
        _pkey = identity.key;

        MLOG_D("Starting DTLS");
        _sslCtx = createSslContext();
        assert(_sslCtx);
        _ssl = ::SSL_new(_sslCtx);
        assert(_ssl);

        _inBio = BIO_new(BIO_s_mem());
        _outBio = BIO_new(BIO_s_mem());

        SSL_set_app_data(_ssl, this);

        if (_role == DTLS_CLIENT)
            ::SSL_set_connect_state(_ssl);
        else
            ::SSL_set_accept_state(_ssl);

        ::SSL_set_bio(_ssl, _inBio, _outBio);  //< the SSL object owns the bio now

        MLOG_D("DTLS context initialization finished");
    }

    ~DtlsSrtpTransport()
    {
        stopInternal();
    }

    void handleIncomingData(const std::vector<uint8_t>& data)
    {
        // handle data in dispatcher thread
        _dispatcher.push([this, data]()
        {
            MLOG_D("INCOMING DATA of size " << data.size());

            (void)BIO_reset(_inBio);
            (void)BIO_reset(_outBio);
            // data() is valid even for an empty vector, unlike &data[0]
            ::BIO_write(_inBio, data.data(), static_cast<int>(data.size()));
            handshakeIteration();
        });
    }

    void setResultHandler(const DtlsConnectResultHandler& h)
    {
        _resultHandler = h;
    }

    void setSendFunction(const DtlsSendFunc& sendFunc)
    {
        _sendFunc = sendFunc;
    }

    void start()
    {
        // perform on dispatcher thread
        _dispatcher.push([this](){ handshakeIteration(); });
    }

    void renegotiate()
    {
        // perform on dispatcher thread
        _dispatcher.push([this]()
        {
            MLOG_D("<<<<Renegotiation requested>>>>");
            assert(_handshakeCompleted);
            assert(!_activeRenegotiation);

            _activeRenegotiation = true;
            (void)BIO_reset(_inBio);
            (void)BIO_reset(_outBio);
            //SSL_renegotiate_abbreviated(_ssl);
            SSL_renegotiate(_ssl);
            handshakeIteration();
        });
    }

private:

    SSL_CTX* createSslContext()
    {
        SSL_CTX *ctx = (_role == DTLS_CLIENT) ?
            ::SSL_CTX_new(DTLSv1_client_method()) :
            ::SSL_CTX_new(DTLSv1_server_method());

        assert(ctx);
        assert(_certificate);
        assert(_pkey);

        ::SSL_CTX_use_certificate(ctx, _certificate);
        ::SSL_CTX_use_PrivateKey(ctx, _pkey);

        if (_role == DTLS_SERVER)
        {
            SSL_CTX_set_session_cache_mode(ctx, SSL_SESS_CACHE_OFF);
        }

        ::SSL_CTX_set_info_callback(ctx, &DtlsSrtpTransport::sslInfoCallback);

        ::SSL_CTX_set_verify(ctx, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT,
            &DtlsSrtpTransport::sslVerifyCallback);
        ::SSL_CTX_set_verify_depth(ctx, 1);
        ::SSL_CTX_set_cipher_list(ctx, "ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH");

        //::SSL_CTX_set_tlsext_use_srtp(ctx, "SRTP_AES128_CM_SHA1_32:SRTP_AES128_CM_SHA1_80");
        SSL_CTX_set_read_ahead(ctx, 1);

        // "bad decompression" error fix for some linuxes:
        SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION);
        // disable tickets for simplicity
        SSL_CTX_set_options(ctx, SSL_OP_NO_TICKET);

        return ctx;
    }

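    void stopInternal()
    {
        if (_ssl)
        {
            ::SSL_shutdown(_ssl);
            ::SSL_free(_ssl);  // also frees _inBio and _outBio
            _ssl = NULL;
        }
        if (_sslCtx)
        {
            ::SSL_CTX_free(_sslCtx);
            _sslCtx = NULL;
        }
        // the SSL_CTX took its own references to the certificate and key,
        // so release the ones returned by generateIdentity()
        if (_certificate)
        {
            ::X509_free(_certificate);
            _certificate = NULL;
        }
        if (_pkey)
        {
            ::EVP_PKEY_free(_pkey);
            _pkey = NULL;
        }
    }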

    void handshakeIteration()
    {
        // application data is ignored here, but SSL_read needs a buffer:
        uint8_t buf[4096];

        // use SSL_do_handshake until the initial handshake completes and while a
        // renegotiation is active; otherwise use SSL_read (which also processes
        // a renegotiation requested by the peer)
        int res = (_handshakeCompleted && !_activeRenegotiation) ?
            ::SSL_read(_ssl, buf, sizeof(buf)) : ::SSL_do_handshake(_ssl);

        // get pointer to data written by handshake
        int outBioLen = 0;
        uint8_t *outBioData;
        outBioLen = BIO_get_mem_data(_outBio, &outBioData);

        int err = ::SSL_get_error(_ssl, res);
        struct timeval timeout;

        // check if remote side requested renegotiation
        if (!_activeRenegotiation && _handshakeCompleted && SSL_renegotiate_pending(_ssl) == 1)
        {
            MLOG_D("Remote renegotiation detected");
            _activeRenegotiation = true;
        }

        // check if renegotiation finished
        bool renegotiationFinished = _activeRenegotiation && SSL_renegotiate_pending(_ssl) == 0;

        // handle handshake errors
        switch (err)
        {
        case SSL_ERROR_NONE:
            if (!_handshakeCompleted || renegotiationFinished)
            {
                _handshakeCompleted = true;
                _activeRenegotiation = false;
                reportSuccess();
                _dispatcher.cancelTimedTask(_currentTimerId);
            }
            break;

        case SSL_ERROR_WANT_READ:
            if (renegotiationFinished)
            {
                _activeRenegotiation = false;
                _dispatcher.cancelTimedTask(_currentTimerId);
                reportSuccess();
            }
            else if (DTLSv1_get_timeout(_ssl, &timeout))
            {
                int delay = timeout.tv_sec * 1000 + timeout.tv_usec / 1000;
                MLOG_D("WANT_READ: setting new timer for " << delay << "ms");
                _currentTimerId = _dispatcher.push([this](){ receiveTimerExpired(); },
                    std::chrono::milliseconds(delay));
            }
            break;

        default:
            MLOG_E("Unexpected error while processing DTLS: " << err);
            logOpenSslErrors("SSL reading");
            _resultHandler(false);
            _dispatcher.cancelTimedTask(_currentTimerId);
            assert(false);
            // don't write any data, just return:
            return;
        }

        if (outBioLen)
        {
            MLOG_D("Sending handshake data for DTLS; size " << outBioLen);
            std::vector<uint8_t> data(outBioData, outBioData + outBioLen);
            _sendFunc(data);
        }
    }

    void receiveTimerExpired()
    {
        MLOG_D("DTLS timer expired. Asking OpenSSL to repeat operations");

        (void)BIO_reset(_inBio);
        (void)BIO_reset(_outBio);

        handshakeIteration();
    }

    void reportSuccess()
    {
        MLOG_D("Negotiation success");
        _resultHandler(true);
    }

    static int sslVerifyCallback(int ok, X509_STORE_CTX* store)
    {
        // we don't verify certificate here for simplicity
        return 1;
    }

    static void sslInfoCallback(const SSL* s, int where, int ret)
    {
        auto this_ = static_cast<DtlsSrtpTransport*>(SSL_get_app_data(s));
        this_->sslInfoCallbackInternal(s, where, ret);
    }

    void sslInfoCallbackInternal(const SSL* s, int where, int ret)
    {
        std::string method = "undefined";
        int w = where & ~SSL_ST_MASK;
        if (w & SSL_ST_CONNECT)
        {
            method = "SSL_connect";
        }
        else if (w & SSL_ST_ACCEPT)
        {
            method = "SSL_accept";
        }


        if (where & SSL_CB_LOOP)
        {
            MLOG_D(method << ": " << SSL_state_string_long(s));
        }
        else if (where & SSL_CB_ALERT)
        {
            const char* direction = (where & SSL_CB_READ) ? "read" : "write";
            MLOG_D("SSL3 alert " << direction
                << ":" << SSL_alert_type_string_long(ret)
                << ":" << SSL_alert_desc_string_long(ret));
        }
        else if (where & SSL_CB_EXIT)
        {
            if (ret == 0)
            {
                MLOG_D(method << " failed in " << SSL_state_string_long(s));
            }
            else if (ret < 0)
            {
                MLOG_D(method << " error in " << SSL_state_string_long(s));
            }
        }
    }


    SSL* _ssl = nullptr;
    SSL_CTX* _sslCtx = nullptr;
    BIO* _inBio = nullptr;
    BIO* _outBio = nullptr;
    X509* _certificate = nullptr;
    EVP_PKEY* _pkey = nullptr;

    bool _handshakeCompleted = false;
    DtlsConnectResultHandler _resultHandler;
    DtlsSendFunc _sendFunc;


    DtlsRole _role;
    std::string _label;

    bool _activeRenegotiation = false;
    int _currentTimerId = 0;

    AsyncDispatcher& _dispatcher;
};



int main()
{
    opensslInit();

    AsyncDispatcher dispatcher;
    
    DtlsSrtpTransport client(DTLS_CLIENT, " [C] ", dispatcher);
    DtlsSrtpTransport server(DTLS_SERVER, " [S] ", dispatcher);

    std::promise<bool> clientResultPromise;
    std::promise<bool> serverResultPromise;

    // when negotiation completes, fulfil the promises so main() can continue
    client.setResultHandler([&](bool result){ clientResultPromise.set_value(result); });
    server.setResultHandler([&](bool result){ serverResultPromise.set_value(result); });

    int clientCounter = 0;
    client.setSendFunction([&](const std::vector<uint8_t>& data)
    {
        clientCounter++;
        if (clientCounter == 4) //< drop specific flight
            return;
        
        server.handleIncomingData(data);
    });

    server.setSendFunction([&](const std::vector<uint8_t>& data)
    {
        client.handleIncomingData(data);
    });

    client.start();
    server.start();

    // block until both handshake results are available
    auto clientResult = clientResultPromise.get_future().get();
    auto serverResult = serverResultPromise.get_future().get();

    assert(clientResult);
    assert(serverResult);

    /// renegotiation

    // reset promises
    clientResultPromise = std::promise<bool>();
    serverResultPromise = std::promise<bool>();

    // ask for renegotiation
    client.renegotiate();

    // block until both renegotiation results are available
    clientResult = clientResultPromise.get_future().get();
    serverResult = serverResultPromise.get_future().get();

    assert(clientResult);
    assert(serverResult);

    dispatcher.stop();
    opensslCleanup();
}

// Helper functions implementation

namespace
{
    std::vector<std::shared_ptr<std::mutex>> opensslMutexes;


    unsigned long idFunction()
    {
        return std::hash<std::thread::id>()(std::this_thread::get_id());
    }

    void opensslLockingFunc(int mode, int n,
        const char* /*file*/, int /*line*/)
    {
        if (mode & CRYPTO_LOCK)
            opensslMutexes[n]->lock();
        else
            opensslMutexes[n]->unlock();
    }


    void opensslInit()
    {
        ::SSL_library_init();
        ::SSL_load_error_strings();
        ::OpenSSL_add_all_algorithms();

        opensslMutexes.resize(::CRYPTO_num_locks());
        for (auto& mutex : opensslMutexes)
            mutex.reset(new std::mutex());
        ::CRYPTO_set_locking_callback(&opensslLockingFunc);
        ::CRYPTO_set_id_callback(&idFunction);
    }

    void opensslCleanup()
    {
        ::CRYPTO_set_id_callback(0);
        ::CRYPTO_set_locking_callback(0);
        ::ERR_free_strings();
        ::ERR_remove_state(0);
        ::EVP_cleanup();
        ::CRYPTO_cleanup_all_ex_data();
    }


    const int gKeyLength = 1024;

    // number of random bits for certificate serial number
    const int gRandomBitsNum = 64;

    // one year certificate validity
    const int gCertificateLifetime = 60 * 60 * 24 * 365;

    // to compensate for slightly incorrect system clocks
    const int gCertificateValidationWindow = -60 * 60 * 24;

    EVP_PKEY* generateRsaKeyPair()
    {
        EVP_PKEY* pkey = EVP_PKEY_new();
        BIGNUM* exponent = BN_new();
        RSA* rsa = RSA_new();
        if (!pkey || !exponent || !rsa ||
            !BN_set_word(exponent, 0x10001) ||
            !RSA_generate_key_ex(rsa, gKeyLength, exponent, NULL) ||
            !EVP_PKEY_assign_RSA(pkey, rsa))
        {
            EVP_PKEY_free(pkey);
            BN_free(exponent);
            RSA_free(rsa);
            return NULL;
        }

        BN_free(exponent);
        return pkey;
    }

    X509* generateCertificate(EVP_PKEY* pkey, const char* commonName)
    {
        X509* x509 = NULL;
        BIGNUM* serialNumber = NULL;
        X509_NAME* name = NULL;

        if ((x509 = X509_new()) == NULL)
            goto error;

        if (!X509_set_pubkey(x509, pkey))
            goto error;

        ASN1_INTEGER* asn1SerialNumber;
        if ((serialNumber = BN_new()) == NULL ||
            !BN_pseudo_rand(serialNumber, gRandomBitsNum, 0, 0) ||
            (asn1SerialNumber = X509_get_serialNumber(x509)) == NULL ||
            !BN_to_ASN1_INTEGER(serialNumber, asn1SerialNumber))
            goto error;

        if (!X509_set_version(x509, 0L))
            goto error;

        if ((name = X509_NAME_new()) == NULL ||
            !X509_NAME_add_entry_by_NID(name, NID_commonName, MBSTRING_UTF8,
            (unsigned char*)commonName, -1, -1, 0) ||
            !X509_set_subject_name(x509, name) ||
            !X509_set_issuer_name(x509, name))
            goto error;

        if (!X509_gmtime_adj(X509_get_notBefore(x509), gCertificateValidationWindow) ||
            !X509_gmtime_adj(X509_get_notAfter(x509), gCertificateLifetime))
            goto error;

        if (!X509_sign(x509, pkey, EVP_sha256()))
            goto error;

        BN_free(serialNumber);
        X509_NAME_free(name);
        return x509;

    error:
        BN_free(serialNumber);
        X509_NAME_free(name);
        X509_free(x509);
        return NULL;
    }

    DtlsIdentity generateIdentity()
    {
        DtlsIdentity id;
        id.key = generateRsaKeyPair();
        id.certificate = generateCertificate(id.key, "TestCompany Inc");
        return id;
    }

    void logOpenSslErrors(const std::string& prefix)
    {
        char errorBuf[200];
        unsigned long err;

        while ((err = ERR_get_error()) != 0)
        {
            ERR_error_string_n(err, errorBuf, sizeof(errorBuf));
            LOG_E(prefix << ": " << errorBuf);
        }
    }
}
