/*
Copyright (c) 2014, 2019, Felix J. Ogris
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met: 

1. Redistributions of source code must retain the above copyright notice, this
   list of conditions and the following disclaimer. 
2. Redistributions in binary form must reproduce the above copyright notice,
   this list of conditions and the following disclaimer in the documentation
   and/or other materials provided with the distribution. 

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/

/*
 * ftp6proxy - IPv6 FTP firewall helper for systems offering
 *             pf and ipfw packet filters, e.g. FreeBSD
 */

#include <stdlib.h>
#include <stdio.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <signal.h>
#include <pwd.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/syslog.h>
#include <sys/socket.h>
#include <sys/event.h>
#include <sys/time.h>
#include <netinet/in.h>
#include <netinet/ip.h>
#include <netinet/ip6.h>
#include <netinet/tcp.h>
#include "fw.h"

/*
 * Logging helpers.  LOGX() prefixes every message with "[file:line] ",
 * appends a newline and sends it to syslog() when log_to_syslog is set,
 * otherwise to stderr.  The "## params" form (a GNU extension) lets callers
 * pass a format string without any further arguments.
 */
#define LOGX(level, msg, params ...) do { \
  if (log_to_syslog) \
    syslog(level, "[%s:%03i] " msg "\n", __FILE__, __LINE__, ## params); \
  else \
    fprintf(stderr, "[%s:%03i] " msg "\n", __FILE__, __LINE__, ## params); \
} while (0)

/* Log at LOG_ERR and terminate the process with exit status 1. */
#define ERRDIEX(msg, params ...) do { \
  LOGX(LOG_ERR, msg, ## params); \
  exit(1); \
} while (0)
/* Log at LOG_WARNING / LOG_INFO. */
#define LOGWARNX(msg, params ...) LOGX(LOG_WARNING, msg, ## params)
#define LOGINFOX(msg, params ...) LOGX(LOG_INFO, msg, ## params)
/* Log at LOG_DEBUG, only when debug >= 2 (-d given at least twice). */
#define LOGDEBUGX(msg, params ...) do { if (debug >= 2) \
  LOGX(LOG_DEBUG, msg, ## params); \
} while (0)

/*
 * Variants of the above that append ": " followed by strerror(errno), so
 * they must be used before anything else can overwrite errno.
 */
#define ERRDIE(msg, params ...) do { \
  LOGX(LOG_ERR, msg ": %s", ## params, strerror(errno)); \
  exit(1); \
} while (0)
#define LOGERR(msg, params ...) LOGX(LOG_ERR, msg ": %s", \
                                     ## params, strerror(errno))
#define LOGWARN(msg, params ...) LOGX(LOG_WARNING, msg ": %s", \
                                      ## params, strerror(errno))

/* number of -d options given on the command line */
int debug = 0;
/* nonzero once log output has been switched from stderr to syslog */
int log_to_syslog = 0;
/*
 * The main loop runs while this is nonzero; the SIGTERM handler clears it.
 * The C standard only allows a signal handler to write static objects of
 * type volatile sig_atomic_t, hence the type.
 */
volatile sig_atomic_t running = 1;

/*
 * SIGTERM handler: ask the main loop to shut down.
 *
 * It only clears the running flag.  syslog() and fprintf() are not
 * async-signal-safe and must not be called from here; the main loop logs
 * its own "exiting..." message once it has cleaned up.
 */
void kill_handler(int sig __attribute__((unused)))
{
  running = 0;
}

/*
 * Match an FTP command/reply line against a prefix and extract port numbers.
 *
 * buf      NUL-terminated line; modified temporarily, restored on return
 * code     prefix the line must start with (compared case-insensitively)
 * separator  field separator (',' for PORT/PASV style, '|' for EPRT/EPSV)
 * format   sscanf() format applied to the text after the second-to-last
 *          separator, storing into port1 and (optionally) port2
 *
 * Returns the sscanf() result, or 0 if the prefix does not match or the
 * line contains fewer than two separators.
 */
int parse_ftp_line(void *buf, char *code, char separator, char *format,
                   unsigned int *port1, unsigned int *port2)
{
  char *line = buf;
  char *last_sep;
  char *prev_sep;

  if (strncasecmp(line, code, strlen(code)) != 0)
    return 0;

  last_sep = strrchr(line, separator);
  if (!last_sep)
    return 0;

  /* cut the line at its last separator so strrchr() finds the previous one */
  *last_sep = '\0';
  prev_sep = strrchr(line, separator);
  *last_sep = separator;

  if (!prev_sep)
    return 0;
  return sscanf(prev_sep + 1, format, port1, port2);
}

/*
 * Look for an active mode command from the FTP client (EPRT, PORT or LPRT)
 * in the NUL-terminated line buf.  If one announces a valid data port, fill
 * in rule's ports (source 20-65535, destination = announced port) and set
 * rule->action to ADD; the caller fills in the addresses.  Otherwise rule
 * is left untouched.
 */
void parse_active_ftp(void *buf, struct rule *rule)
{
  unsigned int port, high, low;

  if (parse_ftp_line(buf, "EPRT", '|', "%u", &port, NULL) == 1) {
    LOGDEBUGX("EPRT port=%u", port);
  } else if ((parse_ftp_line(buf, "PORT", ',', "%u,%u", &high, &low) == 2) ||
             (parse_ftp_line(buf, "LPRT", ',', "%u,%u", &high, &low) == 2)) {
    /* both port bytes come from the network and must fit into an octet */
    if (high > 255 || low > 255) {
      LOGDEBUGX("PORT/LPRT invalid high_port=%u low_port=%u", high, low);
      return;
    }
    port = high * 256 + low;
    LOGDEBUGX("PORT/LPRT high_port=%u low_port=%u port=%u", high, low, port);
  } else {
    return;
  }

  /* never open a hole for port 0 or for a value that is no tcp port */
  if (port == 0 || port > 65535) {
    LOGDEBUGX("ignoring invalid data port %u", port);
    return;
  }

  rule->src_port_low = 20;
  rule->src_port_high = 65535; /* actually, 20, and 1024-65535 */
  rule->dst_port_low = port;
  rule->dst_port_high = port;
  rule->action = ADD;
}

/*
 * Look for a passive mode reply from the FTP server (229 = EPSV,
 * 227 = PASV, 228 = LPSV) in the NUL-terminated line buf.  If one announces
 * a valid data port, fill in rule's ports (source 1024-65535, destination =
 * announced port) and set rule->action to ADD; the caller fills in the
 * addresses.  Otherwise rule is left untouched.
 */
void parse_passive_ftp(void *buf, struct rule *rule)
{
  unsigned int port, high, low;

  if (parse_ftp_line(buf, "229", '|', "%u", &port, NULL) == 1) {
    LOGDEBUGX("EPSV port=%u", port);
  } else if ((parse_ftp_line(buf, "227", ',', "%u,%u", &high, &low) == 2) ||
             (parse_ftp_line(buf, "228", ',', "%u,%u", &high, &low) == 2)) {
    /* both port bytes come from the network and must fit into an octet */
    if (high > 255 || low > 255) {
      LOGDEBUGX("PASV/LPSV invalid high_port=%u low_port=%u", high, low);
      return;
    }
    port = high * 256 + low;
    LOGDEBUGX("PASV/LPSV high_port=%u low_port=%u port=%u", high, low, port);
  } else {
    return;
  }

  /* never open a hole for port 0 or for a value that is no tcp port */
  if (port == 0 || port > 65535) {
    LOGDEBUGX("ignoring invalid data port %u", port);
    return;
  }

  rule->src_port_low = 1024;
  rule->src_port_high = 65535;
  rule->dst_port_low = port;
  rule->dst_port_high = port;
  rule->action = ADD;
}

/*
 * Inspect a TCP segment (header plus payload, len bytes at buf) for FTP
 * port commands or replies.  Only segments whose payload ends in '\n' or
 * '\r' are considered; that last byte is temporarily replaced by NUL so the
 * payload can be parsed as a string, and restored afterwards.  Segments from
 * port 21 are checked for passive mode replies, segments to port 21 for
 * active mode commands.
 */
void parse_tcp(void *buf, unsigned int len, struct rule *rule)
{
  struct tcphdr *tcp = (struct tcphdr*) buf;
  unsigned int th_off;
  char *payload, *right, save;

  /* tcp sanity checks */
  if (len < sizeof(struct tcphdr))
    return;

  th_off = tcp->th_off * 4;
  LOGDEBUGX("th_off=%u", th_off);

  if (th_off < sizeof(struct tcphdr))
    return;

  if (len < th_off)
    return;

  /* we need some data; len is unsigned, so "no data" means exactly 0 */
  len -= th_off;
  if (len == 0)
    return;

  /* char arithmetic: arithmetic on void * is a GNU extension */
  payload = (char*) buf + th_off;
  right = payload + len - 1;

  if ((*right != '\n') && (*right != '\r'))
    return;

  save = *right;
  *right = '\0';

  if (tcp->th_sport == htons(21)) parse_passive_ftp(payload, rule);
  if (tcp->th_dport == htons(21)) parse_active_ftp(payload, rule);

  *right = save;
}

/* parse ipv4 packet */
void parse_v4(void *buf, unsigned int len, struct rule *rule)
{
  struct ip *ip = (struct ip*) buf;
  unsigned int ip_hl;

  /* ip sanity checks */
  if (len < sizeof(struct ip))
    return;

  ip_hl = ip->ip_hl * 4;
  LOGDEBUGX("ip_p=%i ip_hl=%i", ip->ip_p, ip_hl);

  if (ip->ip_p != IPPROTO_TCP)
    return;

  if (ip_hl < sizeof(struct ip))
    return;

  if (len < ip_hl)
    return;

  parse_tcp(buf + ip_hl, len - ip_hl, rule);

  rule->is_ipv6 = 0;
  rule->src_ip = &ip->ip_dst;
  rule->dst_ip = &ip->ip_src;
}

/*
 * parse ipv6 packet
 *
 * Skips the IPv6 extension headers in front of the TCP header of the len
 * bytes at buf, then hands the TCP segment to parse_tcp().  Packets are
 * ignored if they are truncated, carry ESP or "no next header", use an
 * unknown extension header, or are a non-initial fragment (which carries
 * no TCP header).  For accepted packets, rule->src_ip/dst_ip are set to the
 * packet's destination and source address respectively.
 */
void parse_v6(void *buf, unsigned int len, struct rule *rule)
{
  unsigned char *pkt = buf;
  struct ip6_hdr *ip6 = (struct ip6_hdr*) buf;
  struct ip6_ext *ip6_ext;
  struct ip6_frag *ip6_frag;
  unsigned int ip_hl;
  unsigned int next;
  unsigned int next_len;

  /* ip6 sanity checks */
  if (len < sizeof(struct ip6_hdr))
    return;

  /* skip ipv6 extension headers */
  ip_hl = sizeof(struct ip6_hdr);
  next = ip6->ip6_nxt;

  while (next != IPPROTO_TCP) {
    /* the next-header and length octets must lie inside the packet */
    if (len < ip_hl + sizeof(struct ip6_ext))
      return;
    ip6_ext = (struct ip6_ext*) (pkt + ip_hl);

    switch (next) {
      case 0:   /* hop by hop */
      case 43:  /* routing */
      case 60:  /* destination options */
      case 135: /* mobility */
        /* length in 8-octet units, not counting the first 8 octets */
        next_len = (ip6_ext->ip6e_len + 1) * 8;
        break;

      case 51:  /* ah - length in 4-octet units, minus 2 (RFC 4302) */
        next_len = (ip6_ext->ip6e_len + 2) * 4;
        break;

      case 44:  /* fragment - length field is reserved, fixed 8 octets */
        if (len < ip_hl + sizeof(struct ip6_frag))
          return;
        ip6_frag = (struct ip6_frag*) (pkt + ip_hl);
        /* only the first fragment starts with the tcp header */
        if ((ip6_frag->ip6f_offlg & IP6F_OFF_MASK) != 0)
          return;
        next_len = sizeof(struct ip6_frag);
        break;

      case 50:  /* esp */
      case 59:  /* no next header */
      default:
        return;
    }

    next = ip6_ext->ip6e_nxt;
    ip_hl += next_len;
  }

  /*
   * The last extension header may claim more octets than were received;
   * len - ip_hl would then wrap around to a huge unsigned length.
   */
  if (ip_hl > len)
    return;

  parse_tcp(pkt + ip_hl, len - ip_hl, rule);

  rule->is_ipv6 = 1;
  rule->src_ip = &ip6->ip6_dst;
  rule->dst_ip = &ip6->ip6_src;
}

/*
 * Write the usage text to f.  main() uses stdout for -h and stderr when an
 * unknown option was given.
 */
void print_help(FILE *f)
{
  static const char *const help_lines[] = {
    "ftp6proxy V0.2 - IPv6 FTP firewall helper\n\n",
    "usage: ftp6proxy [-b <base>] [-c <count>] [-p <port>] [-t <timeout>]\n",
    "                 [-q <queue>] [-d [-d]] [-i <pidfile>] [-u <user>]\n",
    "       ftp6proxy -h\n\n",
    "  -b <base>      punch holes into firewall starting at this rule number;\n",
    "                 default: 64000\n",
    "  -c <count>     use up to <count> firewall rules; default: 100\n",
    "  -p <port>      divert port number to listen on; defaults to 10004\n",
    "  -t <timeout>   wait up to <timeout> seconds for data connections;\n",
    "                 defaults to 2 seconds\n",
    "  -q <queue>     place data connections into this QoS class\n",
    "  -d             run in foreground, don't write pidfile\n",
    "  -d -d          run in foreground, don't write pidfile, "
      "debug log to stderr,\n",
    "                 don't drop privileges\n",
    "  -i <pidfile>   write pid to <pidfile> instead of /var/run/ftp6proxy.pid\n",
    "  -u <user>      run as user <user> instead of proxy\n",
    "  -h             show this help ;-)\n\n",
  };
  size_t i;

  for (i = 0; i < sizeof(help_lines) / sizeof(help_lines[0]); i++)
    fputs(help_lines[i], f);
}

int main(int argc, char **argv)
{
  char ch;
  char *pidfile = "/var/run/ftp6proxy.pid";
  char *user = "proxy";
  gid_t gid;
  uid_t uid;
  struct passwd *pw;
  pid_t pid;
  struct stat st;
  FILE *f;
  char buf[65536];
  struct sigaction sa;
  struct sockaddr_in sin;
  int port = 10004;
  int sock;
  ssize_t recvlen;
  ssize_t sentlen;
  socklen_t sinsize;
  struct ip *ip;
  time_t *rules; /* ring buffer */
  int rule_base = 64000;
  int rule_count = 100;
  int rule_start = 0;
  int rule_end = 0;
  struct rule rule;
  int timeout = 2;
  void *fw;
  int kq;
  struct kevent kev;
  struct timespec ktimeout;
  int numkevents;

  rule.queue = NULL;

  while ((ch = getopt(argc, argv, "b:c:p:t:q:dhi:u:")) != -1) {
    switch (ch) {
      case 'b':
        rule_base = atoi(optarg);
        if (rule_base < 0) ERRDIEX("rule base must be 0 or greater");
        break;
      case 'c':
        rule_count = atoi(optarg);
        if (rule_count <= 0) ERRDIEX("rule count must be 1 or greater");
        break;
      case 'p':
        /* divert port */
        port = atoi(optarg);
        if (port < 0) ERRDIEX("port must be 0 or greater");
        break;
      case 't':
        timeout = atoi(optarg);
        if (timeout < 0) ERRDIEX("timeout must be 0 or greater");
        break;
      case 'q':
        rule.queue = optarg;
        break;
      case 'd':
        debug++;
        break;
      case 'i':
        pidfile = optarg;
        break;
      case 'u':
        user = optarg;
        break;
      default:
        print_help(ch != 'h' ? stderr : stdout);
        return (ch != 'h');
    }
  }

  if ((pw = getpwnam(user)) == NULL)
    ERRDIEX("no such user: %s", user);
  gid = pw->pw_gid;
  uid = pw->pw_uid;

  close(0);
  close(1);

  umask(0022);

  if (debug <= 0) {
    /* check pidfile */
    if (lstat(pidfile, &st) == 0) {
      if (!S_ISREG(st.st_mode))
        ERRDIEX("%s exists, but is not a regular file", pidfile);
      if ((f = fopen(pidfile, "r")) == NULL) ERRDIE("fopen(%s, r)", pidfile);
      if (fgets(buf, sizeof(buf), f) == NULL) ERRDIE("fgets(%s)", pidfile);
      pid = atoi(buf);
      fclose(f);
      if (kill(pid, 0) == 0) ERRDIEX("still running (pid=%li)", (long)pid);
    } else if (errno != ENOENT) ERRDIE("%s", pidfile);

    /* daemonize */
    pid = fork();
    if (pid < 0) ERRDIE("fork()");
    if (pid > 0) return 0;
    if (setsid() == -1) ERRDIE("setsid()");

    /* save pid */
    if ((f = fopen(pidfile, "w")) == NULL) ERRDIE("fopen(%s, w)", pidfile);
    fprintf(f, "%li\n", (long)getpid());
    if (fclose(f) != 0) ERRDIE("fclose(%s)", pidfile);;
  }

  fw = fw_init();
  if (fw == NULL) ERRDIE("cannot initialize firewall");

  /* create divert socket */
  if ((sock = socket(PF_INET, SOCK_RAW, IPPROTO_DIVERT)) < 0)
    ERRDIE("socket()");
  
  memset(&sin, 0, sizeof(struct sockaddr_in));
  sin.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
  sin.sin_family = AF_INET;
  sin.sin_port = htons(port);

  if (bind(sock, (struct sockaddr*) &sin, sizeof(struct sockaddr_in)) < 0)
    ERRDIE("bind()");

  if (debug <= 1) {
    /* drop privileges */
    if (setgid(gid) != 0) ERRDIE("setgid(%lu)", (long unsigned)gid);
    if (setuid(uid) != 0) ERRDIE("setuid(%lu)", (long unsigned)uid);
    if (chdir("/") != 0) ERRDIE("chdir(/)");
  }

  /* allocate ring buffer */
  rules = malloc(rule_count * sizeof(time_t));
  if (rules == NULL) ERRDIE("cannot allocate rules buffer");

  /* we can't block on the socket in order to remove temporary rules */
  kq = kqueue();
  if (kq == -1) ERRDIE("kqueue()");

  EV_SET(&kev, sock, EVFILT_READ, EV_ADD, 0, 0, NULL);
  if (kevent(kq, &kev, 1, NULL, 0, NULL) == -1)
    ERRDIE("cannot add socket to kqueue");

  /* set up signal handlers */
  memset(&sa, 0, sizeof(sa));
  sa.sa_handler = kill_handler;
  if (sigaction(SIGTERM, &sa, NULL) < 0)
    ERRDIE("sigaction(SIGTERM)");

  if (debug <= 1) {
    /* open syslog */
    close(2);
    openlog("ftp6proxy", LOG_PID, LOG_DAEMON);
    log_to_syslog = 1;
  }

  LOGINFOX("starting...");

  while (running) {
    while (rule_start != rule_end) {
      /* delete expired rules */
      if (rules[rule_start] + timeout > time(NULL)) break;
      if (fw_del_rule(fw, rule_base + rule_start) == -1)
        LOGWARN("cannot delete fw rule #%u", rule_base + rule_start);
      else
        LOGDEBUGX("deleted fw rule #%u ts=%li",
                  rule_base + rule_start, rules[rule_start]);
      rule_start++;
      if (rule_start >= rule_count) rule_start = 0;
    }

    ktimeout.tv_sec = 1;
    ktimeout.tv_nsec = 0;
    numkevents = kevent(kq, NULL, 0, &kev, 1, &ktimeout);
    if (numkevents == 0) continue;
    if (numkevents != 1) {
      /* kevent() failed */
      /* check if we're stil running to avoid log message on SIGKILL */
      if (running) LOGERR("kevent()");
      if (errno != EINTR) running = 0;
      continue;
    }

    sinsize = sizeof(sin);
    recvlen = recvfrom(sock, buf, sizeof(buf), 0, (struct sockaddr*) &sin,
                       &sinsize);
    LOGDEBUGX("recvlen=%zi", recvlen);

    if (recvlen < 0) {
      /* recvfrom() failed */
      if (running) LOGERR("recvfrom()");
      if (errno != EINTR) running = 0;
      continue;
    }

    rule.action = NONE;

    if (recvlen > 0) {
      /* determine ip version */
      ip = (struct ip*) buf;
      LOGDEBUGX("ip_v=%i", ip->ip_v);
      switch (ip->ip_v) {
        case 4:
          parse_v4(buf, recvlen, &rule);
          break;
        case 6:
          parse_v6(buf, recvlen, &rule);
          break;
        default: break;
      }
    }

    if (rule.action == ADD) {
      /* add rule and remember current time */
      rules[rule_end] = time(NULL);
      if (fw_add_rule(fw, rule_base + rule_end, &rule) == -1)
        LOGWARN("cannot add fw rule #%u", rule_base + rule_end);
      else
        LOGDEBUGX("added fw rule #%u ts=%li",
                  rule_base + rule_end, rules[rule_end]);
      if (rule_end + 1 == rule_start)
        LOGWARNX("too many rules, try increasing the rule count");
      else {
        rule_end++;
        if (rule_end >= rule_count) rule_end = 0;
      }
    }

    /* flush buffer */
    sentlen = sendto(sock, buf, recvlen, 0, (struct sockaddr*) &sin, sinsize);
    LOGDEBUGX("sentlen=%zi", sentlen);
    if ((sentlen >= 0) && (sentlen != recvlen))
      LOGWARNX("sent %zi of %zi bytes", sentlen, recvlen);
  }

  while (rule_start != rule_end) {
    /* delete all rules */
    if (fw_del_rule(fw, rule_base + rule_start) == -1)
      LOGWARN("cannot delete fw rule #%u", rule_base + rule_start);
    else
      LOGDEBUGX("deleted fw rule #%u ts=%li",
                rule_base + rule_start, rules[rule_start]);
    rule_start++;
    if (rule_start >= rule_count) rule_start = 0;
  }

  fw_close(fw);
  free(rules);

  /* bye bye */
  LOGINFOX("exiting...");
  if (log_to_syslog) closelog();

  return 0;
}
