#include "UdpLayer.hh"

#include <cstdlib>
#include <cstring>
#include <vector>

#include "loggers.hh"

/**
 * Builds a UDP layer from a parameter string.
 *
 * The string is parsed with Params::convert(); afterwards each endpoint key that the
 * caller did not supply (src_ip, src_port, dst_ip, dst_port) receives a loopback default,
 * so sendData() can always look these keys up.
 *
 * @param p_type layer type name, forwarded to the Layer base.
 * @param param  raw parameter string handed to Params::convert().
 */
UdpLayer::UdpLayer(const std::string & p_type, const std::string & param) : Layer(p_type), _params() {
  loggers::get_instance().log(">>> UdpLayer::UdpLayer: %s, %s", to_string().c_str(), param.c_str());
  Params::convert(_params, param);

  // Defaults applied (in this order) only when the key is absent; user values always win.
  static const std::pair<const char *, const char *> endpoint_defaults[] = {
    { "src_ip",   "127.0.0.1" },
    { "src_port", "12345" },
    { "dst_ip",   "127.0.0.1" },
    { "dst_port", "12346" }
  };
  for (const auto & entry : endpoint_defaults) {
    const Params::const_iterator existing = _params.find(entry.first);
    if (existing != _params.cend()) {
      continue;
    }
    _params.insert(std::pair<std::string, std::string>(std::string(entry.first), entry.second));
  }
}

/**
 * Wraps a payload into an IPv4 + UDP datagram and hands it to the lower layers.
 *
 * Source/destination addresses and ports come from _params (defaults are guaranteed by
 * the constructor). The IPv4 header checksum is computed; the UDP checksum is left at 0
 * (RFC 768 allows this for IPv4).
 *
 * @param data   payload to encapsulate.
 * @param params per-message parameters forwarded unchanged to sendToAllLayers().
 *
 * Datagrams whose total length cannot be encoded in the 16-bit IPv4 total-length field
 * are dropped with a log message instead of being sent with a wrapped length.
 */
void UdpLayer::sendData(OCTETSTRING& data, Params& params) {
  loggers::get_instance().log_msg(">>> UdpLayer::sendData: ", data);

  const size_t payload_len = static_cast<size_t>(data.lengthof());
  const size_t len = sizeof(struct iphdr) + sizeof(struct udphdr) + payload_len;
  if (len > 0xFFFF) { // tot_len / udp len are 16-bit fields
    loggers::get_instance().log("UdpLayer::sendData: Payload too large (%d bytes), datagram dropped", static_cast<int>(payload_len));
    return;
  }

  // Zero-initialised buffer released automatically, even if a lower layer throws
  std::vector<unsigned char> buffer(len, 0);

  // Resolve endpoints from the layer parameters
  _daddr.sin_family = AF_INET;
  _saddr.sin_family = AF_INET;
  Params::const_iterator it = _params.find("dst_port");
  _daddr.sin_port = htons(static_cast<unsigned short>(std::strtoul(it->second.c_str(), NULL, 10)));
  it = _params.find("src_port");
  _saddr.sin_port = htons(static_cast<unsigned short>(std::strtoul(it->second.c_str(), NULL, 10)));
  it = _params.find("dst_ip");
  if (inet_pton(AF_INET, it->second.c_str(), &_daddr.sin_addr) != 1) {
    loggers::get_instance().log("UdpLayer::sendData: Invalid dst_ip '%s'", it->second.c_str());
  }
  it = _params.find("src_ip");
  if (inet_pton(AF_INET, it->second.c_str(), &_saddr.sin_addr) != 1) {
    loggers::get_instance().log("UdpLayer::sendData: Invalid src_ip '%s'", it->second.c_str());
  }

  // IPv4 header (20 bytes, no options)
  _iphdr = (struct iphdr *)buffer.data();
  _iphdr->ihl = 5;
  _iphdr->version = 4;
  _iphdr->tos = 0x0;
  _iphdr->id = 0;
  _iphdr->frag_off = htons(0x4000); /* Don't fragment */
  _iphdr->ttl = 64;
  _iphdr->tot_len = htons(static_cast<unsigned short>(len));
  _iphdr->protocol = IPPROTO_UDP;
  _iphdr->saddr = _saddr.sin_addr.s_addr;
  _iphdr->daddr = _daddr.sin_addr.s_addr;
  // The checksum covers the whole header with the checksum field itself set to 0
  _iphdr->check = 0;
  _iphdr->check = inetCsum((const void *)_iphdr, sizeof(struct iphdr));

  // UDP header
  _udphdr = (struct udphdr *)(buffer.data() + sizeof(struct iphdr));
  _udphdr->source = _saddr.sin_port;
  _udphdr->dest = _daddr.sin_port;
  _udphdr->len = htons(static_cast<unsigned short>(sizeof(struct udphdr) + payload_len));
  _udphdr->check = 0; // No UDP checksum

  // Payload
  if (payload_len > 0) {
    std::memcpy(buffer.data() + sizeof(struct iphdr) + sizeof(struct udphdr), static_cast<const unsigned char *>(data), payload_len);
  }

  // Send data to lower layers
  OCTETSTRING udp(static_cast<int>(len), buffer.data());
  sendToAllLayers(udp, params);
  // NOTE: _iphdr/_udphdr point into `buffer` and are only valid until this function returns.
}

/**
 * Strips the IPv4 and UDP headers from a received packet and forwards the payload upward.
 *
 * @param data   in: raw IPv4/UDP packet; out: replaced by the UDP payload.
 * @param params per-message parameters forwarded unchanged to receiveToAllLayers().
 *
 * Every header length is validated against the number of bytes actually received.
 * Truncated or inconsistent packets are logged and dropped rather than read out of bounds.
 */
void UdpLayer::receiveData(OCTETSTRING& data, Params& params) {
  loggers::get_instance().log_msg(">>> UdpLayer::receiveData: ", data);

  const unsigned char* buffer = static_cast<const unsigned char *>(data);
  const size_t received = static_cast<size_t>(data.lengthof());
  if (received < sizeof(struct iphdr)) {
    loggers::get_instance().log("UdpLayer::receiveData: Packet too short for an IPv4 header (%d bytes), dropped", static_cast<int>(received));
    return;
  }

  // Decode IPv4 header; IHL counts 32-bit words, so IP options are honoured
  _iphdr = (struct iphdr *)buffer;
  const size_t ip_hdr_len = static_cast<size_t>(_iphdr->ihl) * 4;
  if ((ip_hdr_len < sizeof(struct iphdr)) || (received < ip_hdr_len + sizeof(struct udphdr))) {
    loggers::get_instance().log("UdpLayer::receiveData: Invalid IP header length %d for a %d byte packet, dropped", static_cast<int>(ip_hdr_len), static_cast<int>(received));
    return;
  }

  // Decode UDP header; its length field covers header + payload
  _udphdr = (struct udphdr *)(buffer + ip_hdr_len);
  const size_t udp_len = ntohs(_udphdr->len);
  if ((udp_len < sizeof(struct udphdr)) || (ip_hdr_len + udp_len > received)) {
    loggers::get_instance().log("UdpLayer::receiveData: Invalid UDP length %d, dropped", static_cast<int>(udp_len));
    return;
  }
  const size_t payload_len = udp_len - sizeof(struct udphdr);
  loggers::get_instance().log("UdpLayer::receiveData: src_port = %d, payload length = %d", ntohs(_udphdr->source), static_cast<int>(payload_len));

  // TODO To be refined (no UDP checksum verification or port filtering is done here)
  data = OCTETSTRING(static_cast<int>(payload_len), buffer + ip_hdr_len + sizeof(struct udphdr));
  // NOTE: _iphdr/_udphdr pointed into the previous contents of `data` and must not be used after this point.
  loggers::get_instance().log_msg("UdpLayer::receiveData: message payload", data);

  receiveToAllLayers(data, params);
}

/**
 * Computes the Internet checksum (RFC 1071 one's-complement sum) of a byte range.
 *
 * @param buf     start of the data; no alignment requirement.
 * @param hdr_len number of bytes to sum. An odd trailing byte is zero-padded as RFC 1071 requires.
 * @return the one's complement of the folded 16-bit sum. It is in the same byte order as
 *         the data, so it can be stored directly into a header's checksum field.
 */
unsigned short UdpLayer::inetCsum(const void *buf, size_t hdr_len) {
  const unsigned char *bytes = static_cast<const unsigned char *>(buf);
  unsigned long sum = 0;

  while (hdr_len > 1) {
    // memcpy avoids strict-aliasing and misaligned-access issues of a direct unsigned short read
    unsigned short word;
    std::memcpy(&word, bytes, sizeof(word));
    sum += word;
    if (sum & 0x80000000) // fold early so the accumulator cannot overflow on long inputs
      sum = (sum & 0xFFFF) + (sum >> 16);
    bytes += 2;
    hdr_len -= 2;
  }

  if (hdr_len == 1) {
    // Odd length: treat the last byte as the first byte of a zero-padded 16-bit word
    unsigned short last = 0;
    std::memcpy(&last, bytes, 1);
    sum += last;
  }

  while (sum >> 16)
    sum = (sum & 0xFFFF) + (sum >> 16);

  return static_cast<unsigned short>(~sum);
}
/**
 * LayerFactory producing UdpLayer instances.
 *
 * A single static instance (_f) is defined in this file. Its constructor registers the
 * factory with LayerStackBuilder under the layer name "UDP".
 */
class UdpFactory: public LayerFactory {
public:
  UdpFactory();
  /// Creates a new UdpLayer; the caller receives the heap-allocated object.
  virtual Layer * createLayer(const std::string & type, const std::string & param);
private:
  static UdpFactory _f; ///< registration instance
};

/**
 * Registers this factory with LayerStackBuilder so "UDP" can be used in layer stack descriptions.
 *
 * NOTE(review): this runs during static initialization (via UdpFactory::_f). It assumes the
 * logger and LayerStackBuilder's registry are already usable at that point — confirm.
 */
UdpFactory::UdpFactory() {
  loggers::get_instance().log(">>> UdpFactory::UdpFactory");
  const char * const layer_name = "UDP";
  LayerStackBuilder::RegisterLayerFactory(layer_name, this);
}

/**
 * Instantiates a UdpLayer.
 *
 * @param type  layer type name passed to the UdpLayer constructor.
 * @param param raw parameter string passed to the UdpLayer constructor.
 * @return a newly allocated layer; the factory keeps no reference to it, so the caller owns it.
 */
Layer * UdpFactory::createLayer(const std::string & type, const std::string & param) {
  UdpLayer * const created = new UdpLayer(type, param);
  return created;
}

// Static registration object. Constructing it at program start-up runs
// UdpFactory::UdpFactory(), which registers this factory under the name "UDP".
UdpFactory UdpFactory::_f;
