#include "udp_layer_factory.hh"

#include <vector>

#include "loggers.hh"

/**
 * @brief Builds a UDP layer from its textual parameter string.
 *
 * The string is parsed into _params by params::convert. Any of the keys
 * src_ip, src_port, dst_ip, dst_port that the caller did not provide is
 * filled in with a loopback default, so later lookups of these keys always succeed.
 *
 * @param p_type Layer type name, forwarded to the layer base class
 * @param param  Layer parameter string
 */
udp_layer::udp_layer(const std::string & p_type, const std::string & param) : layer(p_type), _params() {
  loggers::get_instance().log(">>> udp_layer::udp_layer: %s, %s", to_string().c_str(), param.c_str());

  params::convert(_params, param);

  // Default endpoint settings, applied only for keys absent after parsing
  const char * const fallback_settings[][2] = {
    { "src_ip",   "127.0.0.1" },
    { "src_port", "12345"     },
    { "dst_ip",   "127.0.0.1" },
    { "dst_port", "12346"     }
  };
  for (const auto & setting : fallback_settings) {
    const std::string key(setting[0]);
    const bool missing = (_params.find(key) == _params.cend());
    if (missing) {
      _params.insert(std::pair<std::string, std::string>(key, setting[1]));
    }
  }
}

/**
 * @brief Wraps @p data in an IPv4/UDP packet and forwards it to the lower layers.
 *
 * Addresses and ports come from the layer parameters src_ip, src_port,
 * dst_ip and dst_port; the constructor inserts defaults, so all four are present.
 * Both the IPv4 header checksum and the UDP checksum (including the RFC 768
 * pseudo-header) are filled in. A datagram whose total length does not fit the
 * 16-bit IPv4 total-length field is logged and dropped instead of being sent
 * with truncated length fields.
 *
 * @param data   UDP payload to send
 * @param params Parameters forwarded unchanged to the lower layers
 */
void udp_layer::send_data(OCTETSTRING& data, params& params) {
  loggers::get_instance().log_msg(">>> udp_layer::send_data: ", data);

  const size_t payload_len = static_cast<size_t>(data.lengthof());
  const size_t udp_len = sizeof(struct udphdr) + payload_len;
  const size_t len = sizeof(struct iphdr) + udp_len;
  // tot_len and the UDP len field are 16 bits wide
  if (len > 0xFFFF) {
    loggers::get_instance().log("udp_layer::send_data: Datagram too large (%d payload bytes), dropped", data.lengthof());
    return;
  }

  // Buffer owned by a vector: released even if a lower layer throws; zero-initialised
  std::vector<unsigned char> buffer(len, 0);

  // Resolve endpoints (keys guaranteed by the constructor)
  _daddr.sin_family = AF_INET;
  _saddr.sin_family = AF_INET;
  params::const_iterator it = _params.find("dst_port");
  _daddr.sin_port = htons(std::strtoul(it->second.c_str(), NULL, 10));
  it = _params.find("src_port");
  _saddr.sin_port = htons(std::strtoul(it->second.c_str(), NULL, 10));
  it = _params.find("dst_ip");
  inet_pton(AF_INET, it->second.c_str(), (struct in_addr *)&_daddr.sin_addr.s_addr);
  it = _params.find("src_ip");
  inet_pton(AF_INET, it->second.c_str(), (struct in_addr *)&_saddr.sin_addr.s_addr);

  // Set IP header
  // NOTE: _iphdr/_udphdr point into the local buffer and are not valid after this call returns
  _iphdr = reinterpret_cast<struct iphdr *>(buffer.data());
  _iphdr->ihl = 5;
  _iphdr->version = 4;
  _iphdr->tos = IPTOS_LOWDELAY;
  _iphdr->id = 0;
  _iphdr->frag_off = htons(0x4000); /* Don't fragment */
  _iphdr->ttl = 64;
  _iphdr->tot_len = htons(static_cast<uint16_t>(len));
  _iphdr->protocol = IPPROTO_UDP;
  _iphdr->saddr = _saddr.sin_addr.s_addr;
  _iphdr->daddr = _daddr.sin_addr.s_addr;
  _iphdr->check = 0;
  _iphdr->check = inet_check_sum((const void *)_iphdr, sizeof(struct iphdr));

  // Set UDP header; the checksum field must be zero while the checksum is computed
  _udphdr = reinterpret_cast<struct udphdr *>(buffer.data() + sizeof(struct iphdr));
  _udphdr->source = _saddr.sin_port;
  _udphdr->dest = _daddr.sin_port;
  _udphdr->len = htons(static_cast<uint16_t>(udp_len));
  _udphdr->check = 0;

  // Set payload (skip memcpy for an empty payload: its source pointer may be null)
  if (payload_len > 0) {
    memcpy(buffer.data() + sizeof(struct iphdr) + sizeof(struct udphdr), static_cast<const unsigned char *>(data), payload_len);
  }

  // UDP checksum: one's complement sum over the pseudo-header (src addr, dst addr,
  // zero, protocol, UDP length), then the UDP header and payload.
  // saddr/daddr are the last 8 bytes of the 20-byte IP header, so they are
  // contiguous with the UDP header and the sum can be done in a single pass.
  // Protocol and UDP length are added as the initial (host-order) partial sum.
  // Chaining separate inet_check_sum calls is incorrect, because each call returns
  // a complemented, network-order value rather than a partial sum.
  // No overflow: len <= 0xFFFF, so IPPROTO_UDP + udp_len < 0xFFFF.
  const unsigned short udp_check = inet_check_sum(
                                                  (const void *)(&(_iphdr->saddr)),
                                                  2 * sizeof(_iphdr->saddr) + udp_len,
                                                  static_cast<unsigned short>(IPPROTO_UDP + udp_len)
                                                  );
  // RFC 768: a computed checksum of zero is transmitted as all ones (0 means "no checksum")
  _udphdr->check = (udp_check == 0) ? 0xFFFF : udp_check;

  // Send data to lower layers
  OCTETSTRING udp(static_cast<int>(len), buffer.data());
  send_to_all_layers(udp, params);
}

/**
 * @brief Strips the IPv4 and UDP headers from a received packet and forwards the UDP payload upward.
 *
 * The packet is validated before any header field is used:
 * - it must hold a complete IPv4 header (length taken from ihl, so options are skipped);
 * - it must hold a complete UDP header;
 * - the UDP length field must cover the UDP header and must not exceed the bytes received.
 *
 * Malformed packets are logged and dropped rather than read out of bounds.
 * On success, @p data is replaced by the UDP payload.
 *
 * @param data   Raw IPv4 packet on input; UDP payload on output
 * @param params Parameters forwarded unchanged to the upper layers
 */
void udp_layer::receive_data(OCTETSTRING& data, params& params) {
  loggers::get_instance().log_msg(">>> udp_layer::receive_data: ", data);

  const size_t available = static_cast<size_t>(data.lengthof());
  if (available < sizeof(struct iphdr)) {
    loggers::get_instance().log("udp_layer::receive_data: Packet too short for an IPv4 header (%d bytes), dropped", data.lengthof());
    return;
  }

  // Decode UDP packet
  const unsigned char* buffer = static_cast<const unsigned char *>(data);
  _iphdr = (struct iphdr*)buffer;
  // IHL counts 32-bit words; options, if any, follow the fixed 20-byte header
  const size_t ip_header_len = static_cast<size_t>(_iphdr->ihl) * 4;
  if ((ip_header_len < sizeof(struct iphdr)) || (available < ip_header_len + sizeof(struct udphdr))) {
    loggers::get_instance().log("udp_layer::receive_data: Invalid IP header length or truncated UDP header (%d bytes), dropped", data.lengthof());
    return;
  }

  _udphdr = (struct udphdr*)(buffer + ip_header_len);
  loggers::get_instance().log("udp_layer::receive_data: src_port = %d, payload length = %d", ntohs(_udphdr->source), ntohs(_udphdr->len));
  const size_t udp_len = ntohs(_udphdr->len);
  // Reject lengths that would underflow (less than the header) or overrun the received bytes
  if ((udp_len < sizeof(struct udphdr)) || (ip_header_len + udp_len > available)) {
    loggers::get_instance().log("udp_layer::receive_data: Inconsistent UDP length %d (%d bytes received), dropped", static_cast<int>(udp_len), data.lengthof());
    return;
  }

  // The temporary copies the payload before the assignment releases the old octets
  data = OCTETSTRING(static_cast<int>(udp_len - sizeof(struct udphdr)), buffer + ip_header_len + sizeof(struct udphdr));

  receive_to_all_layers(data, params);
}

/**
 * @brief Computes the Internet checksum (RFC 1071) of a byte buffer.
 *
 * The buffer is summed as big-endian 16-bit words, with end-around carry folding.
 * An odd trailing byte is treated as the high byte of a zero-padded word.
 *
 * @param buf           Data to checksum
 * @param len           Number of bytes in @p buf
 * @param p_initial_sum Host-order partial sum to start from (e.g. pseudo-header
 *                      fields). This is NOT a previous return value: the return
 *                      value is complemented and in network order.
 * @return One's complement of the folded sum, in network byte order, ready to
 *         store in a header checksum field.
 */
unsigned short udp_layer::inet_check_sum(const void *buf, size_t len, const unsigned short p_initial_sum) {
  const unsigned char *bytes = static_cast<const unsigned char *>(buf);
  unsigned long sum = p_initial_sum;
  size_t i = 0;

  // Checksum all the pairs of bytes first. Bytes are combined explicitly instead of
  // loading through a u_int16_t pointer, which avoids unaligned access and
  // strict-aliasing UB; (hi << 8) | lo equals ntohs() of the word.
  for (; i + 1 < len; i += 2) {
    sum += (static_cast<unsigned long>(bytes[i]) << 8) | bytes[i + 1];
    if (sum > 0xFFFF) {
      sum -= 0xFFFF; // end-around carry
    }
  } // End of 'for' statement

  // If there's a single byte left over, checksum it as a zero-padded word
  if (i < len) {
    sum += static_cast<unsigned long>(bytes[i]) << 8;
    if (sum > 0xFFFF) {
      sum -= 0xFFFF;
    }
  }

  return htons(static_cast<unsigned short>(~sum & 0xFFFF));
}

// Definition of the single static udp_layer_factory instance.
// NOTE(review): presumably its constructor registers this factory with the layer
// stack at static-initialisation time — confirm in udp_layer_factory.hh.
udp_layer_factory udp_layer_factory::_f;
