/*
Copyright (c) 2013, Felix J. Ogris
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met: 

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer. 
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/*
 * tcpoptd - modify arbitrary tcp header options on systems offering
 *           divert(4) sockets, e.g. FreeBSD
 */

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <pwd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/syslog.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/tcp.h>

/*
 * Logging helpers.
 *
 * LOGX(level, msg, ...) formats msg printf-style, prefixes it with
 * "[file:line] " and appends a newline. The result goes to syslog() with the
 * given priority if the global log_to_syslog is non-zero, otherwise to stderr
 * (level is not used in that case).
 *
 * msg has to be a string literal, because it is concatenated with the prefix
 * and the newline at compile time. The named variadic parameter
 * ("params ...") and the ", ## params" form are GNU C extensions.
 */
#define LOGX(level, msg, params ...) do { \
  if (log_to_syslog) \
    syslog(level, "[%s:%03i] " msg "\n", __FILE__, __LINE__, ## params); \
  else \
    fprintf(stderr, "[%s:%03i] " msg "\n", __FILE__, __LINE__, ## params); \
} while (0)

/* log msg with priority LOG_ERR, then terminate the process via exit(1) */
#define ERRDIEX(msg, params ...) do { \
  LOGX(LOG_ERR, msg, ## params); \
  exit(1); \
} while (0)
/* log msg with priority LOG_WARNING resp. LOG_INFO */
#define LOGWARNX(msg, params ...) LOGX(LOG_WARNING, msg, ## params)
#define LOGINFOX(msg, params ...) LOGX(LOG_INFO, msg, ## params)
/* log msg with priority LOG_DEBUG, but only if the global debug is >= 2 */
#define LOGDEBUGX(msg, params ...) do { if (debug >= 2) \
  LOGX(LOG_DEBUG, msg, ## params); \
} while (0)

/*
 * The variants without the trailing X append ": " and strerror(errno) to the
 * message, so they must be used while errno still holds the error of interest.
 */
#define ERRDIE(msg, params ...) do { \
  LOGX(LOG_ERR, msg ": %s", ## params, strerror(errno)); \
  exit(1); \
} while (0)
#define LOGWARN(msg, params ...) LOGX(LOG_WARNING, msg ": %s", \
                                      ## params, strerror(errno))

/* verbosity; read by LOGDEBUGX, which only logs if debug >= 2 */
int debug = 0;
/* if non-zero, LOGX writes to syslog() instead of stderr */
int log_to_syslog = 0;
/*
 * main loop flag. It is cleared from a signal handler, therefore it has to be
 * a volatile sig_atomic_t to be written and read reliably.
 */
volatile sig_atomic_t running = 1;

/*
 * Signal handler (main() installs it for SIGTERM): requests termination of
 * the main loop by clearing the running flag.
 *
 * Nothing is logged here on purpose: syslog() and fprintf() are not
 * async-signal-safe and must not be called from a signal handler. main()
 * logs "exiting..." as soon as its loop has noticed the cleared flag.
 */
void kill_handler(int sig __attribute__((unused)))
{
  running = 0;
}

/*
 * tcp/ip checksum (16 bit one's complement of the one's complement sum)
 *
 * chksum_field: points to the 2 byte checksum field; it is zeroed before
 *               summing and receives the new checksum afterwards. It is
 *               expected to lie within one of the given fields.
 * num_fields:   number of entries in fields
 * fields:       memory areas which are summed up as if they were one
 *               contiguous byte stream; empty areas are skipped
 *
 * 16 bit words are read and written in memory order, so the result is
 * correct regardless of the byte order of the host.
 */
void chksum(uint8_t *chksum_field, unsigned int num_fields,
            struct iovec *fields)
{
  uint16_t word = 0;
  uint16_t result;
  unsigned int i;
  size_t pos;
  uint32_t sum = 0;
  int have_carry = 0;
  const uint8_t *data;

  memcpy(&word, chksum_field, sizeof(word));
  LOGDEBUGX("old chksum=0x%hX", word);

  /* to calculate the checksum set checksum field to zero */
  chksum_field[0] = 0;
  chksum_field[1] = 0;

  /* avoid a pseudo header, use a scatter/gather approach instead */
  for (i = 0; i < num_fields; i++) {
    if (fields[i].iov_len == 0)
      continue;

    data = (const uint8_t*) fields[i].iov_base;

    if (have_carry) {
      /* previous field had odd data length: its last byte is still waiting
         in the first half of word, so the first byte of this field has to
         go into the second half */
      ((uint8_t*) &word)[1] = data[0];
      sum += word;
      pos = 1;
    } else {
      pos = 0;
    }

    /* sum up all complete 16 bit words */
    for (; pos + 1 < fields[i].iov_len; pos += 2) {
      memcpy(&word, data + pos, sizeof(word));
      sum += word;
    }

    if (pos == fields[i].iov_len) {
      have_carry = 0;
    } else {
      /* remember carry, i.e. the single byte left over */
      ((uint8_t*) &word)[0] = data[pos];
      have_carry = 1;
    }
  }

  if (have_carry) {
    /* odd total length: pad the last byte with a zero byte */
    ((uint8_t*) &word)[1] = 0;
    sum += word;
  }

  /* fold checksum */
  while (sum > 0xffff)
    sum = (sum & 0xffff) + (sum >> 16);

  /* take the lower 16 bits explicitly instead of the first two bytes of the
     32 bit sum, which would be the wrong half on big endian hosts */
  result = (uint16_t) ~sum;
  memcpy(chksum_field, &result, sizeof(result));

  LOGDEBUGX("new chksum=0x%hX", result);
}

/*
 * removes + adds tcp header options
 *
 * opts, opts_len:         option bytes of the original tcp header
 * new_opts:               output buffer for the rewritten options
 * new_opts_len:           receives the number of bytes written to new_opts
 * new_opts_size:          capacity of new_opts in bytes
 * remove_opts:            table indexed by option type (0..255); options whose
 *                         entry is non-zero are dropped
 * add_opts, add_opts_len: raw bytes appended after the retained options
 *
 * Existing no-ops (type 1) are dropped and parsing stops at the first
 * end-of-options (type 0). Retained options of odd length get a leading
 * no-op; the result is padded with zero bytes to a multiple of 4 bytes.
 *
 * Returns 0 on success, or -1 if opts is malformed or the result does not
 * fit into new_opts.
 */
int build_opts(uint8_t *opts, unsigned int opts_len, uint8_t *new_opts,
               unsigned int *new_opts_len, unsigned int new_opts_size,
               char *remove_opts, uint8_t *add_opts,
               unsigned int add_opts_len)
{
  unsigned int i;
  uint8_t type;
  uint8_t len;

  *new_opts_len = 0;

  for (i = 0; i < opts_len; ) {
    type = opts[i];
    LOGDEBUGX("opt=%u remove=%i", type, remove_opts[type]);

    if (type == 0) /* end-of-options, padding */
      break;

    if (type == 1) {
      /* no-op */
      i++;
      continue;
    }

    if (i + 1 >= opts_len) /* no room for length */
      return -1;

    len = opts[i + 1];
    LOGDEBUGX("opt_len=%u", len);

    if (len < 2) /* too short */
      return -1;
    if (i + len > opts_len) /* too large to fit into opts */
      return -1;

    /* skip the whole option if its type is subject to removal */
    if (remove_opts[type]) {
      i += len;
      continue;
    }

    if (*new_opts_len + len + (len % 2) > new_opts_size)
      return -1;

    /* most options are aligned to 2 bytes */
    if (len % 2)
      new_opts[(*new_opts_len)++] = 1;

    /* copy option type, length and data */
    memcpy(new_opts + *new_opts_len, opts + i, len);
    *new_opts_len += len;
    i += len;
  }

  /* new_opts and add_opts exceed the size of the output buffer */
  if (*new_opts_len + add_opts_len > new_opts_size)
    return -1;

  /* copy add_opts */
  for (i = 0; i < add_opts_len; i++)
    new_opts[(*new_opts_len)++] = add_opts[i];

  /* pad to 4 bytes; may only run out of space if new_opts_size is not a
     multiple of 4 */
  while (*new_opts_len % 4 != 0) {
    if (*new_opts_len >= new_opts_size)
      return -1;
    new_opts[(*new_opts_len)++] = 0;
  }

  return 0;
}

/*
 * parse ipv4 packet and rewrite its tcp header options in place
 *
 * buf:          the packet, starting with the ip header
 * len:          in: number of valid bytes in buf, out: new packet length
 * size:         capacity of buf in bytes
 * remove_opts:  table of option types to remove, see build_opts()
 * add_opts,
 * add_opts_len: raw option bytes to append, see build_opts()
 *
 * The packet is left untouched if it is not tcp, if its headers are
 * truncated or malformed, or if the rewritten packet would not fit into buf
 * or into the 16 bit ip length field. Otherwise the options are replaced,
 * the payload is moved accordingly and the tcp checksum is recalculated; ip
 * length and ip header checksum are only updated if the length has changed.
 *
 * NOTE(review): ip_off is not inspected, so the bytes following the ip
 * header are always treated as a tcp header - presumably the firewall rule
 * is expected to divert unfragmented packets only; confirm.
 */
void parse_v4(void *buf, unsigned int *len, unsigned int size,
              char *remove_opts, uint8_t *add_opts, int add_opts_len)
{
  struct ip *ip = (struct ip*) buf;
  struct tcphdr *tcp;
  uint8_t *opts;
  unsigned int ip_hlen;
  unsigned int tcp_hlen;
  unsigned int opts_len;
  uint8_t new_opts[40];
  unsigned int new_opts_len;
  unsigned int new_len;
  uint8_t tcp_pseudo[4];
  uint16_t tcp_len;
  struct iovec fields[3];

  /* ip sanity checks */
  if (*len < sizeof(struct ip))
    return;

  ip_hlen = (unsigned) ip->ip_hl * 4;
  LOGDEBUGX("ip_p=%i ip_hl=%i", ip->ip_p, ip->ip_hl * 4);

  if (ip->ip_p != IPPROTO_TCP)
    return;

  if (ip_hlen < sizeof(struct ip))
    return;

  if (*len < ip_hlen + sizeof(struct tcphdr))
    return;

  /* tcp sanity checks */
  tcp = (struct tcphdr*) ((uint8_t*) buf + ip_hlen);
  tcp_hlen = (unsigned) tcp->th_off * 4;
  LOGDEBUGX("th_off=%i", tcp->th_off * 4);

  if (tcp_hlen < sizeof(struct tcphdr))
    return;

  if (*len < ip_hlen + tcp_hlen)
    return;

  /* the options follow the fixed part of the tcp header */
  opts = (uint8_t*) (tcp + 1);
  opts_len = tcp_hlen - sizeof(struct tcphdr);

  LOGDEBUGX("opts_len=%u", opts_len);

  if (build_opts(opts, opts_len, new_opts, &new_opts_len,
                 sizeof(new_opts)/sizeof(new_opts[0]), remove_opts,
                 add_opts, add_opts_len) < 0)
    return;

  LOGDEBUGX("new_opts_len=%u", new_opts_len);

  new_len = *len + new_opts_len - opts_len;

  /* the ip length field and the length within the tcp pseudo header are
     16 bits wide, larger values would be truncated silently */
  if (new_len > IP_MAXPACKET)
    return;

  if (new_len != *len) {
    /* new options exceed buffer */
    if (new_len > size)
      return;

    /* move away payload (tcp_hlen is still the old header length) */
    memmove(opts + new_opts_len, opts + opts_len,
            *len - ip_hlen - tcp_hlen);
  }

  /* copy in the new options */
  memcpy(opts, new_opts, new_opts_len);

  tcp->th_off = (sizeof(struct tcphdr) + new_opts_len) / 4;
  tcp_len = new_len - ip_hlen;

  /* remainder of the tcp pseudo header: zero, protocol, tcp length in
     network byte order */
  tcp_pseudo[0] = 0;
  tcp_pseudo[1] = ip->ip_p;
  tcp_pseudo[2] = tcp_len >> 8;
  tcp_pseudo[3] = tcp_len & 0xff;

  /* tcp header checksum: 8 bytes starting at ip_src (source followed by
     destination address), rest of the pseudo header, tcp header + payload */
  fields[0].iov_base = &ip->ip_src;
  fields[0].iov_len = 8;
  fields[1].iov_base = tcp_pseudo;
  fields[1].iov_len = 4;
  fields[2].iov_base = tcp;
  fields[2].iov_len = tcp_len;

  LOGDEBUGX("tcp_len=%u", tcp_len);
  chksum((uint8_t*) &tcp->th_sum, 3, fields);

  if (new_len != *len) {
    /* store the new total length in network byte order */
    *((uint8_t*) &ip->ip_len) = (new_len >> 8) & 0xFF;
    *((uint8_t*) &ip->ip_len + 1) = new_len & 0xFF;

    /* ip header checksum */
    fields[0].iov_base = buf;
    fields[0].iov_len = ip_hlen;
    chksum((uint8_t*) &ip->ip_sum, 1, fields);
    *len = new_len;
  }
}

/*
 * parse commandline option: read one value of a comma-separated list
 *
 * argv: the whole list, e.g. "2,5,0xA,6,0x1d,9"
 * pos:  offset into argv where the value starts; advanced past the value and
 *       its trailing comma if another value may follow
 * val:  receives the parsed value; left untouched on error
 *
 * Values may be given in decimal, hexadecimal (0x...) or octal (0...)
 * notation and have to be in the range 0..255. Characters other than a comma
 * after the number are not inspected.
 *
 * Returns  0 if a value was read and a comma follows it,
 *          1 if a value was read and it is the last one,
 *         -1 if pos is at or beyond the end of argv, if there is no number
 *            at pos or if the number does not fit into a byte.
 */
int parse_byte_list(char *argv, unsigned int *pos, uint8_t *val)
{
  const char *start;
  char *end;
  long num;

  /* argv: 2,5,0xA,6,0x1d,9
     pos:      ^
     => val = 10 */
  if (*pos >= strlen(argv))
    return -1;

  start = argv + *pos;
  num = strtol(start, &end, 0);
  if (end == start) /* not a number */
    return -1;

  /* reject instead of silently truncating to 8 bits; this also covers the
     LONG_MIN/LONG_MAX results strtol() yields on overflow */
  if (num < 0 || num > 255)
    return -1;

  *val = (uint8_t) num;
  if (*end != ',')
    return 1;

  /* skip the number and the comma */
  *pos += (unsigned int) (end - start) + 1;
  return 0;
}

/*
 * parse commandline option -r: marks every option type listed in argv
 * (comma-separated, see parse_byte_list()) by setting remove_opts[type] to 1.
 * Parsing stops silently at the last value or at the first value that
 * parse_byte_list() rejects.
 */
void parse_remove_opts(char *argv, char *remove_opts)
{
  unsigned int offset = 0;
  uint8_t type;
  int status;

  for (;;) {
    status = parse_byte_list(argv, &offset, &type);
    if (status < 0)
      break; /* nothing parsed */

    LOGDEBUGX("remove_opt=%u", type);
    remove_opts[type] = 1;

    if (status != 0)
      break; /* that was the last value */
  }
}

/*
 * parse commandline option -a: stores the values listed in argv
 * (comma-separated, see parse_byte_list()) in add_opts one after another.
 *
 * add_opts_size is the capacity of add_opts; the process is terminated via
 * ERRDIEX if a further value is requested while add_opts is already full.
 * Parsing stops at the last value or at the first value that
 * parse_byte_list() rejects.
 *
 * Returns the number of bytes stored in add_opts.
 */
unsigned int parse_add_opts(char *argv, uint8_t *add_opts,
                            unsigned int add_opts_size)
{
  unsigned int offset = 0;
  unsigned int count = 0;
  int status;

  for (;;) {
    /* no room left for the value that is parsed next */
    if (count >= add_opts_size) {
      ERRDIEX("too many options to add (max. %u)", add_opts_size);
    }

    status = parse_byte_list(argv, &offset, &add_opts[count]);
    if (status < 0)
      break; /* nothing parsed */

    LOGDEBUGX("add_opt=%u", add_opts[count]);
    count++;

    if (status != 0)
      break; /* that was the last value */
  }

  return count;
}

/* writes the usage text, i.e. the description of all commandline options,
   to the stream f */
void print_help(FILE *f)
{
  fputs(
"tcpoptd V0.1 - modifies tcp header options\n\n"
"usage: tcpoptd [-a <opts>] [-r <opts>] [-p <port>]\n"
"               [-d [-d]] [-i <pidfile>] [-u <user>]\n"
"       tcpoptd -h\n\n"
"  -a <opts>      comma-separated values that are inserted into every tcp\n"
"                 packet; decimal and hexadecimal values are accepted, e.g.\n"
"                 to insert tcp option 0x4e (78) comprising two values 23 "
                                                                       "and\n"
"                 42:\n"
"                 0x4e,4,23,42\n"
"                  |   | |  +-- option value\n"
"                  |   | +----- option value\n"
"                  |   +------- total length incl. type and length\n"
"                  +----------- option type\n"
"                 if you are inserting an odd number of bytes, then you "
                                                                    "should\n"
"                 prefix these values with a nop opcode (1), e.g. instead of\n"
"                 45,3,4 use 1,45,3,4\n\n"
"  -r <opts>      comma-separated list of tcp header options that get removed\n"
"                 from every tcp packet; decimal and hexadecimal values are\n"
"                 accepted, e.g. to remove any occurrences of option type\n"
"                 0xf (15) and 11:\n"
"                 0xf,11\n"
"                 Note that option type 0 (end-of-options, padding) and\n"
"                 1 (nop) are always removed and added as needed at the end\n"
"                 of the new tcp header options\n\n"
"  -p <port>      divert port number to listen on; defaults to 10003, e.g.\n"
"                 to redirect all outgoing tcp packets with syn flag set and\n"
"                 destined to port 25 through tcpoptd, add this to your ipfw\n"
"                 ruleset:\n"
"                 add divert 10003 tcp from me to any 25 out setup\n\n"
"  -d             run in foreground, don't write pidfile\n"
"  -d -d          run in foreground, don't write pidfile, "
                                                      "debug log to stderr,\n"
"                 don't drop privileges\n"
"  -i <pidfile>   write pid to <pidfile> instead of /var/run/tcpoptd.pid\n"
"  -u <user>      run as user <user> instead of uid 65535 and gid 65535\n"
"  -h             show this help ;-)\n\n", f);
}

/*
 * Entry point: parses the commandline, daemonizes and writes a pidfile
 * (unless -d is given), binds a divert socket, drops privileges (unless -d is
 * given twice) and then passes every received packet through parse_v4()
 * before it is written back to the divert socket. The loop ends once the
 * signal handler installed for SIGTERM has cleared the running flag.
 *
 * Returns 0 on regular termination and for -h, 1 for an unknown option; all
 * fatal errors terminate the process via ERRDIE/ERRDIEX.
 */
int main(int argc, char **argv)
{
  /* has to be an int: getopt() returns -1 at the end of the options, which
     never compares equal to a char on platforms where char is unsigned */
  int ch;
  char *pidfile = "/var/run/tcpoptd.pid";
  gid_t gid = 65535;
  uid_t uid = 65535;
  struct passwd *pw;
  pid_t pid;
  struct stat st;
  FILE *f;
  char buf[65536];
  struct sigaction sa;
  struct sockaddr_in sin;
  int port = 10003;
  long portval;
  char *portend;
  int sock;
  ssize_t recvlen;
  ssize_t sendlen;
  unsigned int totalsendlen;
  unsigned int buflen;
  socklen_t sinsize;
  struct ip *ip;
  char remove_opts[256];
  uint8_t add_opts[40];
  unsigned int add_opts_len = 0;

  memset(remove_opts, 0, sizeof(remove_opts));
  memset(add_opts, 0, sizeof(add_opts));

  while ((ch = getopt(argc, argv, "a:r:p:dhi:u:")) != -1) {
    switch (ch) {
      case 'a':
        /* array of bytes which gets added to every tcp header option field */
        add_opts_len = parse_add_opts(optarg, add_opts,
                                      sizeof(add_opts)/sizeof(add_opts[0]));
        break;
      case 'r':
        /* list of tcp header options which get removed from every packet */
        parse_remove_opts(optarg, remove_opts);
        break;
      case 'p':
        /* divert port; refuse anything that is not a plain 16 bit number
           instead of binding to whatever atoi() would make of it */
        portval = strtol(optarg, &portend, 10);
        if ((portend == optarg) || (*portend != '\0') ||
            (portval < 0) || (portval > 65535))
          ERRDIEX("invalid port: %s", optarg);
        port = (int) portval;
        break;
      case 'd':
        debug++;
        break;
      case 'i':
        pidfile = optarg;
        break;
      case 'u':
        if ((pw = getpwnam(optarg)) == NULL)
          ERRDIEX("no such user: %s", optarg);
        gid = pw->pw_gid;
        uid = pw->pw_uid;
        break;
      default:
        /* -h as well as unknown options; only the latter is an error */
        print_help(ch != 'h' ? stderr : stdout);
        return (ch != 'h');
    }
  }

  close(0);
  close(1);

  umask(0022);

  if (debug <= 0) {
    /* check pidfile */
    if (lstat(pidfile, &st) == 0) {
      if (!S_ISREG(st.st_mode))
        ERRDIEX("%s exists, but is not a regular file", pidfile);
      if ((f = fopen(pidfile, "r")) == NULL) ERRDIE("fopen(%s, r)", pidfile);
      if (fgets(buf, sizeof(buf), f) == NULL) ERRDIE("fgets(%s)", pidfile);
      pid = atoi(buf);
      fclose(f);
      /* kill() treats pid 0 and negative pids as process groups resp. as a
         broadcast, so only probe pids that denote a single process */
      if ((pid > 0) && (kill(pid, 0) == 0))
        ERRDIEX("still running (pid=%li)", (long)pid);
    } else if (errno != ENOENT) ERRDIE("%s", pidfile);

    /* daemonize */
    pid = fork();
    if (pid < 0) ERRDIE("fork()");
    if (pid > 0) return 0;
    if (setsid() == -1) ERRDIE("setsid()");

    /* save pid */
    if ((f = fopen(pidfile, "w")) == NULL) ERRDIE("fopen(%s, w)", pidfile);
    fprintf(f, "%li\n", (long)getpid());
    if (fclose(f) != 0) ERRDIE("fclose(%s)", pidfile);
  }

  /* create divert socket */
  if ((sock = socket(PF_INET, SOCK_RAW, IPPROTO_DIVERT)) < 0)
    ERRDIE("socket()");

  memset(&sin, 0, sizeof(struct sockaddr_in));
  sin.sin_family = AF_INET;
  sin.sin_port = htons(port);

  if (bind(sock, (struct sockaddr*) &sin, sizeof(struct sockaddr_in)) < 0)
    ERRDIE("bind()");

  if (debug <= 1) {
    /* drop privileges; the group has to be changed first, as long as we are
       still allowed to */
    if (setgid(gid) != 0) ERRDIE("setgid(%lu)", (long unsigned)gid);
    if (setuid(uid) != 0) ERRDIE("setuid(%lu)", (long unsigned)uid);
    if (chdir("/") != 0) ERRDIE("chdir(/)");
  }

  /* set up signal handlers; SA_RESTART is not set, so a blocking recvfrom()
     fails with EINTR and the loop gets the chance to check the running flag */
  memset(&sa, 0, sizeof(sa));
  sa.sa_handler = kill_handler;
  if (sigaction(SIGTERM, &sa, NULL) < 0)
    ERRDIE("sigaction(SIGTERM)");

  if (debug <= 1) {
    /* open syslog */
    close(2);
    openlog("tcpoptd", LOG_PID, LOG_DAEMON);
    log_to_syslog = 1;
  }

  LOGINFOX("starting...");

  while (running) {
    /* sin receives the origin of the packet and is passed back to sendto()
       unchanged */
    sinsize = sizeof(sin);
    recvlen = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr*) &sin,
                       &sinsize);
    LOGDEBUGX("recvlen=%zi", recvlen);

    if (recvlen < 0) {
      /* recvfrom() failed */
      if (errno != EINTR) ERRDIE("recvfrom()");
      continue;
    }
    buflen = recvlen;
    if (buflen > 0) {
      /* determine ip version */
      ip = (struct ip*) buf;
      LOGDEBUGX("ip_v=%i", ip->ip_v);
      switch (ip->ip_v) {
        case 4:
          parse_v4(buf, &buflen, sizeof(buf), remove_opts, add_opts,
                   add_opts_len);
          break;
#if 0
        case 6:
          parse_v6(buf, &buflen, sizeof(buf), remove_opts, add_opts,
                   add_opts_len);
          break;
#endif
        default: break;
      }
    }

    /* packets which were not modified are written back as received */
    totalsendlen = 0;
    do {
      /* flush buffer */
      sendlen = sendto(sock, buf + totalsendlen, buflen - totalsendlen, 0,
                       (struct sockaddr*) &sin, sinsize);
      LOGDEBUGX("sendlen=%zi totalsendlen=%u", sendlen, totalsendlen);
      if (sendlen >= 0) totalsendlen += sendlen;
      else if (errno != EINTR) ERRDIE("sendto()");
    } while (running && (totalsendlen < buflen));
  }

  /* bye bye */
  LOGINFOX("exiting...");
  if (log_to_syslog) closelog();

  return 0;
}
