/* A simple TCP proxy, similar to simpleproxy by Vadim Zaliva et al. */
/* I'm writing this in order to learn how to use libevent 1.4. */

/* The basic outline of the program is as follows. Initially, it opens
 * a listening socket, attaches an event handler to it, and goes into
 * an infinite event loop. So initially there’s a single event
 * handler, for new connections on the socket. When a connection is
 * accepted, an outbound connection to the remote host is opened, and a
 * pair of “pipe” objects is created, one per socket, each pointing at
 * the other as its partner. Each pipe owns a bufferevent that registers
 * read events (and, when needed, write events); its only job is to copy
 * whatever its socket receives into its partner’s bufferevent. When
 * either side closes or fails, both pipes are queued for deletion and
 * freed after the current pass through the event loop.
 *
 * Why that takes >300 lines of code is anybody’s fucking guess.
 */

typedef unsigned char u_char;   /* for event.h */
#include <sys/types.h>

#include <assert.h>
#include <errno.h>
#include <event.h>
#include <fcntl.h>
#include <getopt.h>
#include <netdb.h>
#include <signal.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/time.h>
#include <time.h>
#include <unistd.h>
#include <arpa/inet.h>


/* XXX comment stuff */
/* XXX make the error handling a little saner */

/* Basic logging and error reporting */
/*-----------------------------------*/

/* Program name used as the prefix of every log and panic line.  Set once
 * in main(); file-local because nothing outside this file needs it. */
static char *argv0;

/* Portable spelling of GCC's __attribute__ (expands to nothing elsewhere),
 * used to get printf-style format checking on the logging helpers. */
#ifdef __GNUC__
#define attribute(x) __attribute__(x)
#else
#define attribute(x)
#endif

static void
panic(char *format, ...)
    attribute ((format (printf, 1, 2)));

/* Fatal-error reporter: prints "<argv0>: PANIC: <formatted message>"
 * followed by a newline on stderr, then exits the process with status 1.
 * Never returns to the caller. */
static void
panic(char *format, ...)
{
    va_list args;

    fprintf(stderr, "%s: PANIC: ", argv0);
    va_start(args, format);
    vfprintf(stderr, format, args);
    va_end(args);
    fputs("\n", stderr);

    exit(1);
}

static void
log_msg(char *format, ...)
    attribute ((format (printf, 1, 2)));

enum { time_str_max = 30 };

/* printf-style logging to stderr.  Each line has the form
 * "<YYYY-MM-DD HH:MM:SS> <argv0>: <message>\n" using local time.
 * Calls panic() if the formatted timestamp does not fit in time_str_max
 * bytes; if the current time cannot be converted at all, the line is
 * still written with a placeholder timestamp. */
static void
log_msg(char *format, ...)
{
    va_list ap;
    time_t now;
    char time_str[time_str_max];
    struct tm *now_tm;

    time(&now);
    now_tm = localtime(&now);
    if (!now_tm) {
        /* localtime() returns NULL when it cannot convert the time;
         * passing NULL to strftime() would be undefined behavior, and
         * losing the timestamp is better than losing the log line. */
        snprintf(time_str, sizeof(time_str), "%s", "(unknown time)");
    } else if (!strftime(time_str, time_str_max, "%Y-%m-%d %H:%M:%S",
                         now_tm)) {
        panic("strftime ran out of space in log_msg");
    }

    fprintf(stderr, "%s %s: ", time_str, argv0);

    va_start(ap, format);
    vfprintf(stderr, format, ap);
    va_end(ap);

    fprintf(stderr, "\n");
}


/* Command-line parsing stuff */
/*----------------------------*/

/* Parse `number` as a decimal value in the range of unsigned short.
 * The entire string must be consumed by strtol() (which tolerates leading
 * whitespace and a sign); values outside 0..USHRT_MAX are rejected.
 * On success stores the value in *dest and returns 1; otherwise returns 0
 * and leaves *dest untouched. */
static int
parse_ushort(const char *number, unsigned short *dest)
{
    char *rest;
    long value = strtol(number, &rest, 10);
    int ok = rest != number                    /* saw at least one digit */
             && *rest == '\0'                  /* no trailing garbage */
             && value == (unsigned short)value; /* fits the target type */

    if (ok) {
        *dest = (unsigned short)value;
    }
    return ok;
}

enum { max_hostname_size = 256 };

/* Parse "host:port" into an IPv4 socket address.
 *
 * host_port - text of the form "hostname:port"; the hostname is
 *             everything before the first ':' and must be shorter than
 *             max_hostname_size bytes.
 * sa        - receives the resolved address on success; untouched on
 *             failure.
 *
 * Returns 1 on success, 0 if the text is malformed, the port is not a
 * valid unsigned short, or the host does not resolve to an IPv4 address.
 */
static int
parse_host_port(const char *host_port, struct sockaddr_in *sa)
{
    const char *colon = strchr(host_port, ':');
    size_t hostname_len;
    char hostname[max_hostname_size];
    unsigned short port;
    struct hostent *host;

    /* Check for the colon before doing arithmetic with it: subtracting
     * from a null pointer is undefined behavior. */
    if (!colon) return 0;
    hostname_len = (size_t)(colon - host_port);
    if (hostname_len >= max_hostname_size) return 0;
    memcpy(hostname, host_port, hostname_len);
    hostname[hostname_len] = '\0';

    if (!parse_ushort(colon + 1, &port)) return 0;

    host = gethostbyname(hostname); /* XXX use getaddrinfo */
    if (!host) return 0;

    /* Only an IPv4 result fits in a sockaddr_in. */
    if (host->h_addrtype != AF_INET
        || host->h_length != (int)sizeof(sa->sin_addr)
        || !host->h_addr_list[0]) {
        return 0;
    }

    /* Zero the whole struct so padding (sin_zero etc.) is never garbage. */
    memset(sa, 0, sizeof(*sa));
    sa->sin_family = AF_INET;
    memcpy(&sa->sin_addr, host->h_addr_list[0], sizeof(sa->sin_addr));
    sa->sin_port = htons(port);
    return 1;
}


/* Basic networking stuff */
/*------------------------*/

/* Create an unconnected IPv4 stream (TCP) socket.
 * Returns the new descriptor, or -1 with errno set by socket(). */
static int
tcp_socket(void)
{
    return socket(PF_INET, SOCK_STREAM, 0);
}

/* Open a non-blocking TCP socket and start connecting it to `remote`.
 *
 * Returns the socket on success.  The connect may still be in progress
 * (EINPROGRESS), so its final outcome only shows up later as events on
 * the socket.  Returns -1 after logging the reason on failure; the
 * socket is closed in that case, so nothing leaks.
 */
static int
connecting_socket(struct sockaddr_in remote)
{
    int fd = tcp_socket();
    int flags;

    if (-1 == fd) {
        log_msg("failed to allocate socket: %s", strerror(errno));
        return -1;
    }

    /* Add O_NONBLOCK to the existing file-status flags instead of
     * replacing them wholesale. */
    flags = fcntl(fd, F_GETFL, 0);
    if (-1 == flags || -1 == fcntl(fd, F_SETFL, flags | O_NONBLOCK)) {
        log_msg("failed to set nonblocking mode: %s", strerror(errno));
        goto error;
    }

    /* On a non-blocking socket, EINPROGRESS just means "connecting". */
    if (-1 == connect(fd, (struct sockaddr *)&remote, sizeof(remote))
        && errno != EINPROGRESS) {
        log_msg("failed to connect(): %s", strerror(errno));
        goto error;
    }

    return fd;

 error:
    close(fd);
    return -1;
}


/* A “pipe” connects two sockets one way. This program exists to make pipes. */
/*---------------------------------------------------------------------------*/

/* One direction of a proxied connection: a pipe reads from its own
 * socket and forwards the data into its partner's bufferevent.  Every
 * accepted connection gets two pipes (client side and outbound side)
 * that point at each other through `partner`. */
struct pipe {
    struct pipe *next;          /* link in the pipes_to_delete list;
                                   NULL while not queued */
    struct pipe *partner;       /* pipe for the other socket of the same
                                   connection (linked in open_pipes) */
    struct bufferevent *buf;    /* libevent buffered I/O wrapping fd */
    int marked_for_deletion;    /* set by mark_for_deletion(); guards
                                   against queueing twice */
    int fd;                     /* socket; closed by delete_pipe() */
};

/* I’m not sure how safe it is to deallocate bufferevents and args
 * from inside their callbacks, so I link pipes to be deleted into a
 * list here and deallocate them at the top level of the event loop.
 * File-local: only the deletion helpers in this file touch it.
 */
static struct pipe *pipes_to_delete = NULL;

/* Queue p together with its partner for deletion at the top level of
 * the event loop (see delete_pipes_to_delete).
 *
 * Both halves of a connection are always deleted together, so BOTH are
 * flagged here.  Flagging only p would let a later callback on the
 * partner in the same event-loop pass (e.g. the remote side also hitting
 * EOF) queue the pair a second time.  That trips the assertions below,
 * or with NDEBUG turns the list into a cycle and double-frees the pipes.
 * Calling this again for either pipe of a queued pair is a no-op.
 */
static void
mark_for_deletion(struct pipe *p)
{
    struct pipe *partner = p->partner;

    if (p->marked_for_deletion) return;
    assert(!partner->marked_for_deletion);

    /* Push the pair onto the front of the list: p -> partner -> rest. */
    assert(!p->next);
    p->next = partner;

    assert(!partner->next);
    partner->next = pipes_to_delete;

    pipes_to_delete = p;

    p->marked_for_deletion = 1;
    partner->marked_for_deletion = 1;
}

/* Release everything one pipe owns, in order: log the close, free its
 * bufferevent, close its socket, then free the struct itself. */
static void
delete_pipe(struct pipe *p)
{
    int sock = p->fd;

    log_msg("closing connection %d", sock);
    bufferevent_free(p->buf);
    close(sock);
    free(p);
}

static void
delete_pipes_to_delete()
{
    while (pipes_to_delete) {
        struct pipe *next = pipes_to_delete->next;
        delete_pipe(pipes_to_delete);
        pipes_to_delete = next;
    }
}

/* Size of the stack bounce buffer used to move bytes between
 * bufferevents; larger chunks mean far fewer read/write calls. */
enum { net_buf_size = 4096 };

/* Move everything currently readable from `from` into `to`.
 *
 * Returns the number of bytes moved.  Returns 0 when nothing was
 * readable OR when queueing data on `to` failed.  read_callback treats 0
 * as "drop this connection", which is the right outcome for a failed
 * write too, since the bytes already drained from `from` cannot be put
 * back.
 */
static size_t
copy_data(struct bufferevent *from, struct bufferevent *to)
{
    char buf[net_buf_size];
    size_t total_bytes = 0;

    for (;;) {
        size_t bytes = bufferevent_read(from, buf, sizeof(buf));

        if (bytes == 0) return total_bytes;
        if (-1 == bufferevent_write(to, buf, bytes)) return 0;
        total_bytes += bytes;
    }
}

/* Read callback installed on every pipe's bufferevent (see new_pipe), so
 * it runs for both the client side and the outbound side.  It forwards
 * whatever arrived on this pipe's socket to the partner.  If nothing was
 * forwarded, the connection is treated as closed and the pair is queued
 * for deletion. */
static void
read_callback(struct bufferevent *bufev, void *arg)
{
    struct pipe *self = arg;
    size_t forwarded = copy_data(self->buf, self->partner->buf);

    if (forwarded != 0) return;

    log_msg("closed connection %d", self->fd);
    mark_for_deletion(self);
}

/* error callback from remote end of the pipe */
/* XXX EOF is reported as an error EVBUFFER_EOF */
/* XXX what about half-open sockets? */
static void
error_callback(struct bufferevent *bufev, short what, void *arg)
{
    struct pipe *p = arg;
    log_msg("error %d on conn %d", what, p->fd); /* WTF is "remote error 17"? */
    mark_for_deletion(p);
}

/* Allocate a pipe for socket `fd`, wrap the socket in a bufferevent that
 * calls read_callback / error_callback, and enable reading and writing.
 * Returns the pipe, or NULL (after logging) on failure.  On failure
 * nothing allocated here is leaked, and `fd` is NOT closed.
 * The `partner` field is left uninitialized; the caller must set it. */
static struct pipe *
new_pipe(int fd)
{
    struct pipe *p;

    p = malloc(sizeof *p);
    if (p == NULL) {
        log_msg("malloc failed (%s)", strerror(errno));
        return NULL;
    }
    p->next = NULL;
    p->marked_for_deletion = 0;
    p->fd = fd;

    p->buf = bufferevent_new(fd, read_callback, NULL, error_callback, p);
    if (p->buf == NULL) {
        log_msg("bufferevent_new failed");
        free(p);
        return NULL;
    }

    if (bufferevent_enable(p->buf, EV_READ | EV_WRITE) == -1) {
        log_msg("bufferevent_enable failed");
        bufferevent_free(p->buf);
        free(p);
        return NULL;
    }

    return p;
}

/* Pair the accepted socket `fd` with a new outbound connection to
 * *remote_sockaddr.  Builds one pipe per socket and links them as each
 * other's partner.  Returns 1 on success.  Returns 0 on failure, having
 * released everything it acquired, except `fd` itself, which stays open
 * for the caller to close. */
static int
open_pipes(int fd, struct sockaddr_in *remote_sockaddr)
{
    struct pipe *inbound;
    struct pipe *outbound;
    int out_fd;

    inbound = new_pipe(fd);
    if (!inbound) return 0;

    errno = 0;
    out_fd = connecting_socket(*remote_sockaddr);
    if (out_fd != -1) {
        outbound = new_pipe(out_fd);
        if (outbound) {
            inbound->partner = outbound;
            outbound->partner = inbound;
            return 1;
        }
        close(out_fd);
    }

    bufferevent_free(inbound->buf);
    free(inbound);
    return 0;
}


/* Handling of the listening socket. */
/*-----------------------------------*/

/* Event callback for the listening socket (registered by do_listen).
 * Accepts one pending connection and hands it to open_pipes(), which
 * pairs it with a new outbound connection.
 *
 * fd         - the listening socket
 * event_type - libevent event flags (unused)
 * arg        - the remote struct sockaddr_in to forward to
 *
 * On any failure the accepted socket is closed here.
 */
static void
accept_conn(int fd, short event_type, void *arg)
{
    struct sockaddr_in peer_addr;
    socklen_t pa_s = sizeof(peer_addr);
    int conn = accept(fd, (struct sockaddr *)&peer_addr, &pa_s);
    struct sockaddr_in *remote_sockaddr = arg;
    int flags;

    if (-1 == conn) {
        /* If the client went away between the readiness notification and
         * accept(), a non-blocking listener reports EAGAIN/EWOULDBLOCK.
         * That is routine and not worth a log line. */
        if (errno != EAGAIN && errno != EWOULDBLOCK) {
            log_msg("failed to accept a connection: %s", strerror(errno));
        }
        return;
    }

    log_msg("accepted connection %d from %s:%d", conn,
            inet_ntoa(peer_addr.sin_addr),
            ntohs(peer_addr.sin_port));

    /* Sockets returned by accept() are not guaranteed to inherit
     * O_NONBLOCK (on Linux they never do).  Everything runs in one
     * event-loop thread, so a blocking write to one slow client would
     * stall every other connection.  Put the socket in non-blocking mode
     * before any bufferevent touches it. */
    flags = fcntl(conn, F_GETFL, 0);
    if (-1 == flags || -1 == fcntl(conn, F_SETFL, flags | O_NONBLOCK)) {
        log_msg("failed to set nonblocking mode: %s", strerror(errno));
        close(conn);
        return;
    }

    if (!open_pipes(conn, remote_sockaddr)) close(conn);
}

static int
do_listen(struct event *listen_event, short port, 
          struct sockaddr_in *remote_sockaddr)
{
    int fd = tcp_socket();
    struct sockaddr_in local_sa;
    int so_reuseaddr_value = 1;

    if (-1 == fd) return 0;

    if (-1 == setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, 
                         &so_reuseaddr_value, sizeof(so_reuseaddr_value))) {
        goto error;
    }

    local_sa.sin_family = AF_INET;
    local_sa.sin_addr.s_addr = INADDR_ANY;
    local_sa.sin_port = htons(port);
    if (-1 == bind(fd, (struct sockaddr*)&local_sa, sizeof(local_sa))) {
        goto error;
    }
    if (-1 == listen(fd, 5)) goto error;

    event_set(listen_event, fd, EV_READ | EV_PERSIST, 
              accept_conn, remote_sockaddr);
    event_add(listen_event, NULL);

    return 1;
 error:
    close(fd);
    return 0;
}


/* Top-level control: event loop and CLI UI */
/*------------------------------------------*/

/* Initialize libevent, listen on `local_port_number`, and forward every
 * accepted connection to `remote_sockaddr`, forever.  Never returns: it
 * exits via panic() if listening fails or the event loop stops working.
 *
 * listen_event and remote_sockaddr live in this frame, which never
 * returns, so handing their addresses to libevent is safe.
 */
static void
run_simple_proxy(unsigned short local_port_number,
                 struct sockaddr_in remote_sockaddr)
{
    struct event listen_event;
    int rv;

    if (!event_init()) {
        panic("event_init failed");
    }

    if (!do_listen(&listen_event, local_port_number, &remote_sockaddr)) {
        panic("couldn't listen: %s", strerror(errno));
    }
    log_msg("listening on port %d", local_port_number);

    for (;;) {
        rv = event_loop(EVLOOP_ONCE);
        if (rv == -1) {
            panic("event_loop failed: %s", strerror(errno));
        }
        if (rv == 1) {
            /* Nothing left to wait for; looping would just spin the CPU. */
            panic("event_loop: no events registered");
        }
        log_msg("went through the event loop");
        /* Freeing pipes here, outside any callback, is the point of the
         * deferred-deletion list. */
        delete_pipes_to_delete();
    }
}

/* Entry point.  Usage: simpleproxy -L local_port -R remote_host:remote_port
 *
 * Parses the options, resolves the remote host once up front, then runs
 * the proxy forever (run_simple_proxy does not return).  Any bad argument
 * is fatal via panic().
 */
int main(int argc, char **argv)
{
    char *local_port = NULL, *remote_host_and_port = NULL;
    int opt;
    unsigned short local_port_number = 0;
    struct sockaddr_in remote_sockaddr;

    /* argv[0] may be NULL (argc == 0 is legal under execve).  The logging
     * helpers print argv0 with %s, so never leave it NULL. */
    argv0 = (argc > 0 && argv[0]) ? argv[0] : "simpleproxy";

    /* Start from an all-zero address so padding such as sin_zero is
     * never stack garbage. */
    memset(&remote_sockaddr, 0, sizeof(remote_sockaddr));

    /* Writing to a socket whose peer has gone away raises SIGPIPE, and
     * its default action kills the process: one vanished client would
     * take down every proxied connection.  Ignoring it makes such writes
     * fail with EPIPE instead, which the bufferevent layer can report for
     * just that connection. */
    if (SIG_ERR == signal(SIGPIPE, SIG_IGN)) {
        panic("couldn't ignore SIGPIPE: %s", strerror(errno));
    }

    while ((opt = getopt(argc, argv, "L:R:")) != -1) {
        switch (opt) {
        case 'L':
            local_port = optarg;
            break;
        case 'R':
            remote_host_and_port = optarg;
            break;
        default:
            /* getopt() has already printed its own diagnostic (opterr is
             * left at its default). */
            panic("usage: %s -L local_port -R remote_host:remote_port",
                  argv0);
            break;
        }
    }

    if (!local_port) panic("no local port given with -L");
    if (!remote_host_and_port) panic("no remote host:port given with -R");

    log_msg("local port is %s; remote host and port is %s",
            local_port, remote_host_and_port);

    if (!parse_ushort(local_port, &local_port_number)) {
        panic("failed to convert local port: %s", local_port);
    }

    if (!parse_host_port(remote_host_and_port, &remote_sockaddr)) {
        panic("couldn't parse remote host and port: %s", remote_host_and_port);
    }
    log_msg("remote host resolved to %s (port %d)",
            inet_ntoa(remote_sockaddr.sin_addr),
            ntohs(remote_sockaddr.sin_port));

    run_simple_proxy(local_port_number, remote_sockaddr);

    return 0;
}

/*
 * Local Variables:
 * compile-command: "make -k simpleproxy LDLIBS=-levent && ./simpleproxy -L 8080 -R www.google.com:80"
 * c-basic-offset: 4
 * End:
 */


