Here's a first take on having sabridge use the systemd-native event
library. The current, full diff is also visible on GitHub [1].

Obviously, this work still needs considerable cleanup and tightening.
I like that we're currently hammering out the basics, such as which event
library to use and where the multiprocess/multithreaded logic should
go in the longer run.

I'm open to better ideas for the data structures. Right now, the
priority is to hammer everything into symmetric structures so that the
proxy's bidirectionality is abstracted away from the transfer
function. This is useful for ensuring we have consistent support for
server-first (MySQL) and client-first (HTTP) protocols.

[1] https://github.com/systemd/systemd/pull/5/files
/*-*- Mode: C; c-basic-offset: 8; indent-tabs-mode: nil -*-*/

/***
  This file is part of systemd.

  Copyright 2013 David Strauss

  systemd is free software; you can redistribute it and/or modify it
  under the terms of the GNU Lesser General Public License as published by
  the Free Software Foundation; either version 2.1 of the License, or
  (at your option) any later version.

  systemd is distributed in the hope that it will be useful, but
  WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU
  Lesser General Public License for more details.

  You should have received a copy of the GNU Lesser General Public License
  along with systemd; If not, see <http://www.gnu.org/licenses/>.
 ***/

#define __STDC_FORMAT_MACROS
#include <errno.h>
#include <netdb.h>
#include <poll.h>
#include <stdbool.h>
#include <stddef.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/epoll.h>
#include <sys/fcntl.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <unistd.h>

#include "log.h"
#include "sd-daemon.h"
#include "sd-event.h"

/* Maximum number of bytes relayed per recv() in transfer_data_cb(). */
#define BUFFER_SIZE 1024

/* Number of client connections accepted since startup. Only accept_cb()
 * touches it (and only to increment it), so it is kept file-local. */
static unsigned int total_clients = 0;

/* Static proxy configuration, filled in once by main(). */
struct proxy {
    int listen_fd;              /* socket-activated listening socket */
    bool remote_is_inet;        /* true: resolve remote_host/remote_service via getaddrinfo();
                                 * false: remote_host is a Unix domain socket path */
    const char *remote_host;
    const char *remote_service;
};

/* One direction of a proxied client<->server connection.
 *
 * Every accepted client gets two of these (client->server and server->client)
 * that point at each other through c_destination, share the same two fds and
 * are freed together.
 *
 * Note the cross-wiring of w_destination: it holds the read watcher of the
 * OPPOSITE direction, i.e. the event source whose userdata is c_destination.
 * Until the server connect completes, accept_cb() parks the server's EPOLLOUT
 * watcher in the client->server half's w_destination. */
struct connection {
    int origin_fd;                      /* fd this direction reads from */
    int destination_fd;                 /* fd this direction writes to */
    sd_event_source *w_destination;     /* watcher of the opposite direction (see above) */
    struct connection *c_destination;   /* the opposite direction */
};

/* Tear down both directions of a proxied connection pair.
 *
 * `s` is the read watcher of `connection` (the direction whose callback is
 * running); connection->w_destination is the read watcher of the opposite
 * direction. Both directions share the same two fds, so they are closed
 * exactly once here, and both halves are freed. */
static void connection_pair_free(sd_event_source *s, struct connection *connection) {
    sd_event_source_unref(connection->w_destination);
    sd_event_source_unref(s);
    close(connection->origin_fd);
    close(connection->destination_fd);
    free(connection->c_destination);
    free(connection);
}

/* Write all `len` bytes of `buf` to socket `fd`.
 *
 * send() on a non-blocking socket may write only part of the data or fail
 * with EAGAIN; both are retried here so no relayed bytes are lost.
 * MSG_NOSIGNAL turns a write to a closed peer into an EPIPE error instead of
 * a process-killing SIGPIPE.
 *
 * Returns 0 on success, or -1 with errno set. */
static int send_all(int fd, const char *buf, size_t len) {
    while (len > 0) {
        ssize_t n = send(fd, buf, len, MSG_NOSIGNAL);

        if (n < 0) {
            if (errno == EINTR)
                continue;

            if (errno == EAGAIN || errno == EWOULDBLOCK) {
                /* Destination is full: wait until it is writable rather than
                 * dropping data. TODO: this stalls the whole event loop (and
                 * can deadlock if the peer is simultaneously blocked writing
                 * to us); replace with a per-connection buffer drained by an
                 * EPOLLOUT watcher. */
                struct pollfd pfd = { .fd = fd, .events = POLLOUT };

                if (poll(&pfd, 1, -1) < 0 && errno != EINTR)
                    return -1;
                continue;
            }

            return -1;
        }

        buf += n;
        len -= (size_t) n;
    }

    return 0;
}

/* sd-event IO callback: relay one chunk of data from connection->origin_fd to
 * connection->destination_fd.
 *
 * On orderly shutdown or an I/O error on either side, only this connection
 * pair is torn down; the proxy keeps serving other clients. Spurious wakeups
 * (EAGAIN/EINTR) are ignored. Always returns 0. */
static int transfer_data_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
    struct connection *connection = userdata;
    char buffer[BUFFER_SIZE];
    ssize_t buffer_len;

    /* epoll reports EPOLLERR/EPOLLHUP even when only EPOLLIN was requested;
     * recv() below turns those into an error or EOF. */
    assert(revents & (EPOLLIN | EPOLLHUP | EPOLLERR));
    assert(fd == connection->origin_fd);

    log_info("About to transfer up to %u bytes from %d to %d.", BUFFER_SIZE, connection->origin_fd, connection->destination_fd);

    buffer_len = recv(connection->origin_fd, buffer, sizeof(buffer), 0);
    if (buffer_len < 0) {
        if (errno == EAGAIN || errno == EWOULDBLOCK || errno == EINTR)
            return 0;   /* nothing to read after all; wait for the next event */

        log_error("Error %d in recv from fd=%d: %s", errno, connection->origin_fd, strerror(errno));
        connection_pair_free(s, connection);
        return 0;
    }

    if (buffer_len == 0) {
        log_info("Clean disconnection.");
        connection_pair_free(s, connection);
        return 0;
    }

    if (send_all(connection->destination_fd, buffer, (size_t) buffer_len) < 0) {
        log_error("Error %d in send to fd=%d: %s", errno, connection->destination_fd, strerror(errno));
        connection_pair_free(s, connection);
    }

    return 0;
}

/* sd-event IO callback: fires once when the non-blocking connect() to the
 * server makes the socket writable.
 *
 * userdata is the server->client half of the connection pair; its
 * c_destination is the client->server half. On success, the one-shot
 * EPOLLOUT watcher is dropped and EPOLLIN watchers are installed for both
 * directions. If the connect failed or a watcher cannot be added, the whole
 * pair (both fds and both structs) is torn down. Always returns 0. */
static int connected_to_server_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
    struct connection *c_server_to_client = userdata;
    struct connection *c_client_to_server = c_server_to_client->c_destination;
    struct sd_event *e = sd_event_get(s);
    int so_error = 0;
    socklen_t so_error_len = sizeof(so_error);
    int r = 0;

    assert(fd == c_server_to_client->origin_fd);

    /* Writability is also reported when a non-blocking connect() FAILS, so
     * the actual outcome has to be read back via SO_ERROR. */
    if (getsockopt(fd, SOL_SOCKET, SO_ERROR, &so_error, &so_error_len) < 0)
        so_error = errno;

    /* Cancel the write watcher for the server. accept_cb() stored it in the
     * client->server half, so clear that reference as well. */
    sd_event_source_unref(s);
    c_client_to_server->w_destination = NULL;

    if (so_error != 0) {
        log_error("Could not connect to server on fd=%d: %s", fd, strerror(so_error));
        goto fail;
    }

    log_info("Connected to server. Initializing watchers for sending data.");

    /* Start listening for data sent by the client. */
    r = sd_event_add_io(e, c_server_to_client->destination_fd, EPOLLIN, transfer_data_cb, c_client_to_server, &c_server_to_client->w_destination);
    if (r < 0)
        goto fail_watch;

    /* Start listening for data sent by the server. */
    r = sd_event_add_io(e, c_server_to_client->origin_fd, EPOLLIN, transfer_data_cb, c_server_to_client, &c_client_to_server->w_destination);
    if (r < 0)
        goto fail_watch;

    return 0;

fail_watch:
    log_error("Failed to add data watchers: %s", strerror(-r));
fail:
    /* Either watcher may still be NULL here. */
    sd_event_source_unref(c_server_to_client->w_destination);
    sd_event_source_unref(c_client_to_server->w_destination);
    close(c_server_to_client->origin_fd);
    close(c_server_to_client->destination_fd);
    free(c_client_to_server);
    free(c_server_to_client);
    return 0;
}


/* Add O_NONBLOCK to the file status flags of fd, leaving the other flags
 * untouched.
 *
 * Returns 0 on success (including when fd is already non-blocking), or -1
 * with errno set by fcntl(). A failing F_GETFL is reported instead of being
 * OR'd into a bogus flag word for F_SETFL. */
static int set_nonblock(int fd) {
    int flags = fcntl(fd, F_GETFL);

    if (flags < 0)
        return -1;

    if (flags & O_NONBLOCK)
        return 0;

    return fcntl(fd, F_SETFL, flags | O_NONBLOCK);
}

/* Create a non-blocking stream socket to the configured backend and start
 * connecting it.
 *
 * Inet remotes: every getaddrinfo() result is tried in order until one either
 * connects immediately or reports EINPROGRESS. A connect that is still in
 * progress completes (or fails) asynchronously, and the caller must check
 * the outcome once the fd becomes writable.
 *
 * Unix remotes: remote_host is the socket path.
 *
 * On any failure this logs and terminates the process with EXIT_FAILURE, so
 * the returned value is always a valid fd. */
static int get_server_connection_fd(const struct proxy *proxy) {
    int server_fd = -1;

    if (proxy->remote_is_inet) {
        struct addrinfo hints;
        struct addrinfo *result, *ai;
        int s;

        memset(&hints, 0, sizeof(hints));
        hints.ai_family = AF_UNSPEC;      /* IPv4 or IPv6 */
        hints.ai_socktype = SOCK_STREAM;  /* TCP */
        /* No AI_PASSIVE: it selects addresses suitable for bind(), not connect(). */

        s = getaddrinfo(proxy->remote_host, proxy->remote_service, &hints, &result);
        if (s != 0) {
            log_error("getaddrinfo error (%d): %s", s, gai_strerror(s));
            exit(EXIT_FAILURE);
        }

        if (result == NULL) {
            log_error("getaddrinfo: no result");
            exit(EXIT_FAILURE);
        }

        for (ai = result; ai != NULL; ai = ai->ai_next) {
            server_fd = socket(ai->ai_family, ai->ai_socktype, ai->ai_protocol);
            if (server_fd < 0)
                continue;   /* e.g. address family not supported here */

            if (-1 == set_nonblock(server_fd)) {
                log_error("Error setting socket to non-blocking.");
                exit(EXIT_FAILURE);
            }

            /* A non-blocking connect() normally returns -1/EINPROGRESS;
             * anything else negative is an immediate failure for this address. */
            if (connect(server_fd, ai->ai_addr, ai->ai_addrlen) == 0 || errno == EINPROGRESS)
                break;

            close(server_fd);
            server_fd = -1;
        }

        freeaddrinfo(result);

        if (server_fd < 0) {
            log_error("Could not connect to socket: %s:%s", proxy->remote_host, proxy->remote_service);
            exit(EXIT_FAILURE);
        }
    }
    else {
        struct sockaddr_un remote;
        size_t path_len = strlen(proxy->remote_host);
        socklen_t len;

        /* sun_path must hold the path plus its terminating NUL. */
        if (path_len >= sizeof(remote.sun_path)) {
            log_error("Unix domain socket path too long: %s", proxy->remote_host);
            exit(EXIT_FAILURE);
        }

        server_fd = socket(AF_UNIX, SOCK_STREAM, 0);
        if (-1 == server_fd) {
            log_error("Error %d while initializing socket: %s", errno, strerror(errno));
            exit(EXIT_FAILURE);
        }

        if (-1 == set_nonblock(server_fd)) {
            log_error("Error setting socket to non-blocking.");
            exit(EXIT_FAILURE);
        }

        memset(&remote, 0, sizeof(remote));
        remote.sun_family = AF_UNIX;
        memcpy(remote.sun_path, proxy->remote_host, path_len + 1);
        len = (socklen_t) (offsetof(struct sockaddr_un, sun_path) + path_len);
        if (connect(server_fd, (struct sockaddr *) &remote, len) < 0 && errno != EINPROGRESS) {
            log_error("Could not connect to Unix domain socket: %s", proxy->remote_host);
            exit(EXIT_FAILURE);
        }
    }

    log_info("Server connection is fd=%d", server_fd);

    return server_fd;
}

/* sd-event IO callback for the listening socket: accept one client, open a
 * connection to the backend and allocate the two cross-linked connection
 * halves.
 *
 * Data relaying only starts once connected_to_server_cb() sees the server
 * socket become writable. On any failure, every resource acquired so far
 * (fds and structs) is released and the client is dropped. Always returns 0. */
static int accept_cb(sd_event_source *s, int fd, uint32_t revents, void *userdata) {
    struct proxy *proxy = userdata;
    struct connection *c_server_to_client = NULL;
    struct connection *c_client_to_server = NULL;
    sd_event_source *w_server = NULL;
    struct sockaddr_storage client_addr;   /* large enough for IPv4, IPv6 or Unix peers */
    socklen_t client_len = sizeof(client_addr);
    int client_fd = -1, server_fd = -1;
    int r;

    assert(revents & EPOLLIN);

    /* Accept first: if this fails there is nothing to clean up, and no backend
     * connection is opened (and leaked) for a client that never arrived. */
    client_fd = accept(fd, (struct sockaddr *) &client_addr, &client_len);
    if (client_fd < 0) {
        log_error("Error accepting client connection.");
        return 0;
    }

    log_info("Client connection accepted with fd=%d", client_fd);

    c_server_to_client = malloc(sizeof(*c_server_to_client));
    c_client_to_server = malloc(sizeof(*c_client_to_server));
    if (!c_server_to_client || !c_client_to_server) {
        log_error("Out of memory allocating connection state.");
        goto fail;
    }

    server_fd = get_server_connection_fd(proxy);
    if (server_fd < 0) {
        log_error("Error initiating server connection.");
        goto fail;
    }

    c_server_to_client->origin_fd = server_fd;
    c_server_to_client->destination_fd = client_fd;
    c_server_to_client->w_destination = NULL;
    c_server_to_client->c_destination = c_client_to_server;

    c_client_to_server->origin_fd = client_fd;
    c_client_to_server->destination_fd = server_fd;
    c_client_to_server->w_destination = NULL;
    c_client_to_server->c_destination = c_server_to_client;

    /* Wait for the server socket to be writable (connect finished) before
     * initializing read events for either socket. */
    r = sd_event_add_io(sd_event_get(s), server_fd, EPOLLOUT, connected_to_server_cb, c_server_to_client, &w_server);
    if (r < 0) {
        log_error("Failed to watch server connection fd=%d: %s", server_fd, strerror(-r));
        goto fail;
    }
    /* Parked here until connected_to_server_cb() replaces it. */
    c_client_to_server->w_destination = w_server;

    total_clients++;
    log_info("Client successfully connected. Total clients: %u", total_clients);

    return 0;

fail:
    if (server_fd >= 0)
        close(server_fd);
    close(client_fd);
    free(c_client_to_server);
    free(c_server_to_client);
    return 0;
}

/* Create the event loop, register the listening socket and run until the
 * loop exits.
 *
 * Returns EXIT_SUCCESS or EXIT_FAILURE (suitable as a process exit status);
 * the specific error is logged before returning. */
static int run_main_loop(struct proxy *proxy) {
    struct sd_event *e = NULL;
    sd_event_source *w_accept = NULL;
    int r;

    r = sd_event_new(&e);
    if (r < 0) {
        log_error("Failed to allocate event loop: %s", strerror(-r));
        goto finish;
    }

    r = set_nonblock(proxy->listen_fd);
    if (r < 0) {
        log_error("Failed to make listener fd=%d non-blocking: %s", proxy->listen_fd, strerror(errno));
        goto finish;
    }

    log_info("Initializing main listener fd=%d", proxy->listen_fd);

    r = sd_event_add_io(e, proxy->listen_fd, EPOLLIN, accept_cb, proxy, &w_accept);
    if (r < 0) {
        log_error("Failed to watch listener fd=%d: %s", proxy->listen_fd, strerror(-r));
        goto finish;
    }

    log_info("Initialized main listener. Entering loop.");

    r = sd_event_loop(e);
    if (r < 0)
        log_error("Event loop failed: %s", strerror(-r));

finish:
    sd_event_source_unref(w_accept);
    sd_event_unref(e);

    /* Negative errno-style codes are not valid process exit statuses. */
    return r < 0 ? EXIT_FAILURE : EXIT_SUCCESS;
}

int main(int argc, char *argv[]) {
    struct proxy proxy;
    int n, r;

    log_info("Starting up.");

    if (argc != 3) {
        fprintf(stderr, "usage: %s hostname service-or-port\n", argv[0]);
        exit(1);
    }

    proxy.listen_fd = SD_LISTEN_FDS_START;
    proxy.remote_host = argv[1];
    proxy.remote_service = argv[2];
    proxy.remote_is_inet = true;  // @TODO: Support Unix domain sockets.

    assert(proxy.remote_host);
    assert(proxy.remote_service);

    n = sd_listen_fds(1);
    if (n < 0) {
        log_error("Failed to determine passed sockets: %s", strerror(-n));
        exit(EXIT_FAILURE);
    } else if (n > 1) {
        log_error("Can't listen on more than one socket.");
        exit(EXIT_FAILURE);
    }

    log_info("Initializing main loop.");

    r = run_main_loop(&proxy);

    log_info("Exiting with status %d.", r);
    return r;
}
_______________________________________________
systemd-devel mailing list
systemd-devel@lists.freedesktop.org
http://lists.freedesktop.org/mailman/listinfo/systemd-devel

Reply via email to