#include <unistd.h>
#include <stdlib.h>
#include <string.h>
#include <strings.h>
#include <err.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <net/if.h>
#include <netinet/ip_fw.h>
#include "fw.h"

/*
 * Create the handle used by the other fw_* calls.
 *
 * Returns a heap-allocated int holding a raw IPv4 socket
 * (AF_INET, SOCK_RAW, IPPROTO_RAW), or NULL if either the allocation or
 * socket() fails.  The handle is released with fw_close().
 */
void *fw_init ()
{
  int *sockp = malloc(sizeof *sockp);

  if (!sockp)
    return NULL;

  *sockp = socket(AF_INET, SOCK_RAW, IPPROTO_RAW);
  if (*sockp == -1) {
    /* socket() failed: drop the half-built handle. */
    free(sockp);
    return NULL;
  }

  return sockp;
}

/*
 * Release a handle obtained from fw_init().
 *
 * Reads the socket descriptor out of the handle, frees the handle, then
 * closes the descriptor.  The handle is freed even when close() fails.
 * Returns 0 on success, -1 if close() failed.
 */
int fw_close (void *data)
{
  int *sockp = data;
  int fd = *sockp;
  int rc;

  free(sockp);
  rc = close(fd);
  if (rc == -1)
    return -1;
  return 0;
}

int fw_add_rule (void *data, uint32_t rulenum, struct rule *rule)
{
  int ipfw_sock;
  uint32_t rulebuf[255];
  struct ip_fw *rule;
  ipfw_insn *cmd;
  ipfw_insn_u32 *ip4_cmd;
  ipfw_insn_ip6 *ip6_cmd;
  ipfw_insn_u16 *port_cmd;
  socklen_t rulesize;

  rule = (struct ip_fw*) rulebuf;
  cmd = rule->cmd;

  bzero(rulebuf, sizeof(rulebuf));
  rule->rulenum = rulenum;

  cmd->opcode = O_PROBE_STATE;
  cmd->len = 1;
  cmd++;

  cmd->opcode = O_PROTO;
  cmd->len = 1;
  cmd->arg1 = IPPROTO_TCP;
  cmd++;

  if (is_ipv6) {
    ip6_cmd = (ipfw_insn_ip6*) cmd;
    ip6_cmd->o.opcode = O_IP6_SRC;
    ip6_cmd->o.len = 5;
    memcpy(&ip6_cmd->addr6, rule->src_ip, 16);
    cmd += 5;
  } else {
    ip4_cmd = (ipfw_insn_u32*) cmd;
    ip4_cmd->o.opcode = O_IP_SRC;
    ip4_cmd->o.len = 2;
    memcpy(ip4_cmd->d, rule->src_ip, 4);
    cmd += 2;
  }

  port_cmd = (ipfw_insn_u16*) cmd;
  port_cmd->o.opcode = O_IP_SRCPORT;
  port_cmd->o.len = 2;
  port_cmd->ports[0] = rule->src_port_low;
  port_cmd->ports[1] = rule->src_port_high;
  cmd += 2;

  if (is_ipv6) {
    ip6_cmd = (ipfw_insn_ip6*) cmd;
    ip6_cmd->o.opcode = O_IP6_DST;
    ip6_cmd->o.len = 5;
    memcpy(&ip6_cmd->addr6, rule->dst_ip, 16);
    cmd += 5;
  } else {
    ip4_cmd = (ipfw_insn_u32*) cmd;
    ip4_cmd->o.opcode = O_IP_DST;
    ip4_cmd->o.len = 2;
    memcpy(ip4_cmd->d, rule->dst_ip, 4);
    cmd += 2;
  }

  port_cmd = (ipfw_insn_u16*) cmd;
  port_cmd->o.opcode = O_IP_DSTPORT;
  port_cmd->o.len = 2;
  port_cmd->ports[0] = rule->dst_port_low;
  port_cmd->ports[1] = rule->dst_port_high;
  cmd += 2;

  cmd->opcode = O_KEEP_STATE;
  cmd->len = 1;
  cmd++;

  rule->act_ofs = (uint32_t*) cmd - (uint32_t*) rule->cmd;

  cmd->opcode = O_ACCEPT;
  cmd->len = 1;
  cmd++;

  rule->cmd_len = (uint32_t*) cmd - (uint32_t*) rule->cmd;
  rulesize = RULESIZE(rule);

  ipfw_sock = *((int*) data);
  if (getsockopt(ipfw_sock, IPPROTO_IP, IP_FW_ADD, rule, &rulesize) == -1)
    return -1;

  return 0;
}

/*
 * Delete ipfw rule `rulenum` via the socket held in `data`
 * (the handle returned by fw_init()).
 *
 * `rulenum` is handed to setsockopt(IP_FW_DEL) unchanged.
 * Returns 0 on success, -1 if setsockopt() fails.
 */
int fw_del_rule (void *data, uint32_t rulenum)
{
  const int sock = *(const int *) data;

  if (setsockopt(sock, IPPROTO_IP, IP_FW_DEL, &rulenum, sizeof rulenum) == -1)
    return -1;
  return 0;
}

/*
 * Manual smoke test, compiled out.
 * NOTE(review): the fw_add_rule call below passes nine arguments, apparently
 * (data, rulenum, is_ipv6, src_ip, dst_ip, src ports low/high, dst ports
 * low/high). fw_add_rule in this file now takes
 * (void *data, uint32_t rulenum, struct rule *rule), so this block will not
 * compile if re-enabled as-is.
 */
#if 0
int main()
{
  void *fw_data;
  uint8_t testip1[4] = { 192, 168, 0, 10 };
  uint8_t testip2[4] = { 192, 168, 0, 15 };

  fw_data = fw_init();
  if (fw_data == NULL) err(-1, "fw_init");
//  if (fw_del_rule(fw_data, 1111)) err(-1, "fw_del_rule");

  if (fw_add_rule(fw_data, 1111, 0, testip1, testip2, 0, 65535, 22, 22))
    err(-1, "fw_add_rule");

  if (fw_close(fw_data)) err(-1, "fw_close");
  return 0;
}
#endif
