#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <errno.h>
#include <err.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <net/if.h>
#include <netinet/ip_fw.h>
#include <capsicum_helpers.h>
#include "fw.h"

/*
 * Identify this firewall backend.
 *
 * Returns a pointer to the string literal "ipfw"; the caller must not
 * modify or free it.
 */
const char *fw_ident (void)
{
  return "ipfw";
}

/*
 * Open the socket used to configure ipfw and restrict it with Capsicum.
 *
 * A raw AF_INET socket is opened, and its rights are limited to
 * getsockopt/setsockopt. Those are the only calls the rule add/delete
 * functions in this file make on it.
 *
 * Returns an opaque handle on success. The handle is a heap-allocated int
 * holding the socket descriptor. Pass it as `data` to the other fw_*
 * functions and release it with fw_close().
 *
 * Returns NULL on failure. errno is left as set by the call that failed
 * (malloc, socket or caph_rights_limit). The cleanup steps below save and
 * restore errno so that close()/free() cannot clobber it.
 */
void *fw_init (void)
{
  int save_errno;
  int *ipfw_sock;
  cap_rights_t rights;

  ipfw_sock = malloc(sizeof(*ipfw_sock));
  if (ipfw_sock == NULL)
    goto FW_INIT_ERR0;

  /* ipfw is driven through IPPROTO_IP-level socket options on this socket. */
  *ipfw_sock = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
  if (*ipfw_sock < 0)
    goto FW_INIT_ERR1;

  cap_rights_init(&rights, CAP_GETSOCKOPT, CAP_SETSOCKOPT);
  if (caph_rights_limit(*ipfw_sock, &rights) != 0)
    goto FW_INIT_ERR2;

  return ipfw_sock;

  /* Unwind in reverse order of acquisition; each step preserves errno. */
FW_INIT_ERR2:
  save_errno = errno;
  close(*ipfw_sock);
  errno = save_errno;

FW_INIT_ERR1:
  save_errno = errno;
  free(ipfw_sock);
  errno = save_errno;

FW_INIT_ERR0:
  return NULL;
}

/*
 * Tear down a handle obtained from fw_init().
 *
 * The descriptor is read out of the handle, the handle storage is freed, and
 * then the descriptor is closed.
 *
 * Returns 0 when close() succeeds and -1 when it fails (errno comes from
 * close()).
 */
int fw_close (void *data)
{
  int *handle = data;
  int fd = *handle;

  free(handle);

  if (close(fd) == -1)
    return -1;
  return 0;
}

int fw_add_rule (void *data, uint32_t rulenum, struct rule *rule)
{
  int ipfw_sock;
  uint32_t rulebuf[255];
  struct ip_fw *fwrule;
  ipfw_insn *cmd;
  ipfw_insn_u32 *ip4_cmd;
  ipfw_insn_ip6 *ip6_cmd;
  ipfw_insn_u16 *port_cmd;
  socklen_t rulesize;

  fwrule = (struct ip_fw*) rulebuf;
  cmd = fwrule->cmd;

  bzero(rulebuf, sizeof(rulebuf));
  fwrule->rulenum = rulenum;

  cmd->opcode = O_PROBE_STATE;
  cmd->len = 1;
  cmd++;

  cmd->opcode = O_PROTO;
  cmd->len = 1;
  cmd->arg1 = IPPROTO_TCP;
  cmd++;

  if (rule->is_ipv6) {
    ip6_cmd = (ipfw_insn_ip6*) cmd;
    ip6_cmd->o.opcode = O_IP6_SRC;
    ip6_cmd->o.len = 5;
    memcpy(&ip6_cmd->addr6, rule->src_ip, 16);
    cmd += 5;
  } else {
    ip4_cmd = (ipfw_insn_u32*) cmd;
    ip4_cmd->o.opcode = O_IP_SRC;
    ip4_cmd->o.len = 2;
    memcpy(ip4_cmd->d, rule->src_ip, 4);
    cmd += 2;
  }

  port_cmd = (ipfw_insn_u16*) cmd;
  port_cmd->o.opcode = O_IP_SRCPORT;
  port_cmd->o.len = 2;
  port_cmd->ports[0] = rule->src_port_low;
  port_cmd->ports[1] = rule->src_port_high;
  cmd += 2;

  if (rule->is_ipv6) {
    ip6_cmd = (ipfw_insn_ip6*) cmd;
    ip6_cmd->o.opcode = O_IP6_DST;
    ip6_cmd->o.len = 5;
    memcpy(&ip6_cmd->addr6, rule->dst_ip, 16);
    cmd += 5;
  } else {
    ip4_cmd = (ipfw_insn_u32*) cmd;
    ip4_cmd->o.opcode = O_IP_DST;
    ip4_cmd->o.len = 2;
    memcpy(ip4_cmd->d, rule->dst_ip, 4);
    cmd += 2;
  }

  port_cmd = (ipfw_insn_u16*) cmd;
  port_cmd->o.opcode = O_IP_DSTPORT;
  port_cmd->o.len = 2;
  port_cmd->ports[0] = rule->dst_port_low;
  port_cmd->ports[1] = rule->dst_port_high;
  cmd += 2;

  if (rule->ipfw_call) {
    fwrule->act_ofs = (uint32_t*) cmd - (uint32_t*) fwrule->cmd;

    cmd->opcode = O_CALLRETURN;
    cmd->len = 1;
    cmd->arg1 = rule->ipfw_call;
    cmd++;
  } else {
    cmd->opcode = O_KEEP_STATE;
    cmd->len = 1;
    cmd++;

    fwrule->act_ofs = (uint32_t*) cmd - (uint32_t*) fwrule->cmd;

    cmd->opcode = O_ACCEPT;
    cmd->len = 1;
    cmd++;
  }

  fwrule->cmd_len = (uint32_t*) cmd - (uint32_t*) fwrule->cmd;
  rulesize = RULESIZE(fwrule);

  ipfw_sock = *((int*) data);
  if (getsockopt(ipfw_sock, IPPROTO_IP, IP_FW_ADD, fwrule, &rulesize) == -1)
    return -1;

  return 0;
}

/*
 * Ask ipfw to remove rule `rulenum`. This issues the IP_FW_DEL socket option
 * (level IPPROTO_IP) on the descriptor held in the fw_init() handle `data`.
 *
 * Returns 0 on success and -1 if setsockopt() fails; errno is left as
 * setsockopt() set it.
 */
int fw_del_rule (void *data, uint32_t rulenum)
{
  const int fd = *(const int *) data;

  if (setsockopt(fd, IPPROTO_IP, IP_FW_DEL, &rulenum, sizeof rulenum) == -1)
    return -1;
  return 0;
}
