#include "common.h"
#include <stdbool.h>            
#include <time.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <netinet/in.h>
#include <netinet/tcp.h>
#include <sys/wait.h>
#include <unistd.h>
#include "cJSON.h"

// Certificate authority list and server certificate/key file, passed to
// COMMON_InitializeSSLCtx() in main().
#define CA_LIST "ca.pem"
#define KEYFILE "server.pem"
// Passed together with KEYFILE to COMMON_InitializeSSLCtx(); presumably the
// private-key passphrase — TODO confirm. NOTE(review): hard-coded secret.
#define PASSWORD "password"
// Diffie-Hellman parameter file, loaded by load_dh_params() in main().
#define DHFILE "dh1024.pem"
// Size in bytes of the buffer one SPP record is read into in send_response().
#define BUFSIZZ 20000
// Intended switch for disabling Nagle's algorithm in tcp_listen();
// 0 keeps the default (Nagle enabled).
#define DISABLE_NAGLE 0


// DEBUG_PRINT(fmt, ...): printf-style trace to stdout when DEBUG is defined;
// otherwise it expands to an empty statement and its arguments are not evaluated.
#ifdef DEBUG
#define DEBUG_PRINT(...) do{ fprintf( stdout, __VA_ARGS__ ); } while( false )
#else
#define DEBUG_PRINT(...) do{ } while ( false )
#endif

// Command line run by call_response_handler(), with the request JSON file name
// appended as the last argument. Set once in main(): "cat" when no arguments
// are given, otherwise the program's arguments, each followed by a space.
char* handler_command = NULL;

// -------------------------------------------------------------
// Load Diffie-Hellman parameters from a PEM file (e.g. "dh1024.pem")
// and install them on the SSL context.
//   ctx  - SSL context that receives the DH parameters
//   file - path of a PEM file containing DH parameters
// Terminates the process with EXIT_FAILURE if the file cannot be
// opened, contains no DH parameters, or the parameters cannot be set.
// -------------------------------------------------------------
void load_dh_params(SSL_CTX *ctx, char *file)
{
    BIO *bio = BIO_new_file(file, "r");
    if (bio == NULL)
    {
        printf("[SERVER ERROR] Couldn't open DH file %s\n", file);
        exit(EXIT_FAILURE);
    }

    DH *dh = PEM_read_bio_DHparams(bio, NULL, NULL, NULL);
    BIO_free(bio);
    if (dh == NULL)
    {
        printf("[SERVER ERROR] Couldn't read DH parameters from %s\n", file);
        exit(EXIT_FAILURE);
    }

    // SSL_CTX_set_tmp_dh() returns 1 on success and 0 on failure (never a
    // negative value), so success must be tested as == 1.
    long ok = SSL_CTX_set_tmp_dh(ctx, dh);

    // The context keeps its own copy/reference of the parameters; release ours.
    DH_free(dh);

    if (ok != 1)
    {
        printf("[SERVER ERROR] Couldn't set DH parameters\n");
        exit(EXIT_FAILURE);
    }
}

// -------------------------------------------------------------
// Create the server's listening TCP socket on DEFAULT_SERVER_PORT,
// bound to all local IPv4 addresses (INADDR_ANY), with SO_REUSEADDR.
// Returns the listening socket descriptor. Terminates the process
// (status -1) if the socket cannot be created, bound or listened on.
// -------------------------------------------------------------
int tcp_listen(void)
{
    int sock;
    int val = 1;
    struct sockaddr_in sin;

    // Create socket
    if ((sock = socket(AF_INET, SOCK_STREAM, 0)) < 0)
    {
        printf("[SERVER] Couldn't make socket\n");
        exit(-1);
    }

    memset(&sin, 0, sizeof(sin));
    sin.sin_family = AF_INET;
    sin.sin_addr.s_addr = htonl(INADDR_ANY);
    sin.sin_port = htons(DEFAULT_SERVER_PORT);

    // Allow an immediate restart while old connections sit in TIME_WAIT.
    // Non-fatal: the server can still run without it.
    if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)) < 0)
    {
        printf("[SERVER] Couldn't set SO_REUSEADDR\n");
    }

#if DISABLE_NAGLE
    // Disable Nagle's algorithm so small records are sent immediately.
    // NOTE(review): set on the listening socket; accepted sockets inherit
    // TCP_NODELAY on Linux/BSD — verify on other platforms.
    if (setsockopt(sock, IPPROTO_TCP, TCP_NODELAY, &val, sizeof(val)) < 0)
    {
        printf("[SERVER] Couldn't set TCP_NODELAY\n");
    }
#endif

    // Bind to socket. This is fatal on failure: listen() on an unbound
    // socket would silently pick a random ephemeral port instead.
    if (bind(sock, (struct sockaddr *)&sin, sizeof(sin)) < 0)
    {
        printf("[SERVER] Could not bind to port %d\n", DEFAULT_SERVER_PORT);
        close(sock);
        exit(-1);
    }

    // Listen to socket (backlog of 5 pending connections)
    if (listen(sock, 5) < 0)
    {
        printf("[SERVER] Could not listen on port %d\n", DEFAULT_SERVER_PORT);
        close(sock);
        exit(-1);
    }

    return sock;
}

// -------------------------------------------------------------
// Validate the outcome of an SSL/SPP write and abort on failure.
//   ssl         - connection the write was issued on
//   r           - value returned by the write call
//   request_len - number of bytes the caller asked to write
// Exits the process with status 1 when SSL reports any error, or
// when the byte count written does not match request_len.
// TODO: check behavior per slice
// -------------------------------------------------------------
void check_SSL_write_error(SSL *ssl, int r, int request_len)
{
    const int sslStatus = SSL_get_error(ssl, r);

    // Any SSL-level failure is fatal.
    if (sslStatus != SSL_ERROR_NONE)
    {
        printf("[SERVER ERROR] SSL error %d", sslStatus);
        exit(1);
    }

    // The write "succeeded" but did not transfer the whole request.
    if (r != request_len)
    {
        printf("[SERVER ERROR] Incomplete write!");
        exit(1);
    }
}



// -------------------------------------------------------------
// Build a fallback response that answers every slice of a request
// with the string "Error".
//   recvData - request object holding a "slices" array whose entries
//              carry a numeric "slice" id (as built by send_response)
// Returns a new object { "slices": [ { "slice": id, "data": "Error" }, ... ] }
// owned by the caller (release with cJSON_Delete), or NULL if the
// top-level allocation fails. Entries without a numeric "slice" are
// skipped instead of dereferencing a missing field.
// -------------------------------------------------------------
cJSON* createErrorResponse (cJSON* recvData)
{
    cJSON* root = cJSON_CreateObject();
    if (root == NULL)
    {
        return NULL;
    }

    cJSON* sliceArray = cJSON_AddArrayToObject(root, "slices");
    if (sliceArray == NULL)
    {
        cJSON_Delete(root);
        return NULL;
    }

    // cJSON_GetObjectItem/cJSON_ArrayForEach tolerate NULL, so a request
    // without "slices" simply yields an empty array.
    cJSON* inputSlices = cJSON_GetObjectItem(recvData, "slices");
    cJSON* inputSlice = NULL;
    cJSON_ArrayForEach(inputSlice, inputSlices)
    {
        cJSON* sliceId = cJSON_GetObjectItem(inputSlice, "slice");
        if (!cJSON_IsNumber(sliceId))
        {
            continue;
        }

        cJSON* outputSlice = cJSON_CreateObject();
        if (outputSlice == NULL)
        {
            break; // out of memory: return what was built so far
        }
        cJSON_AddNumberToObject(outputSlice, "slice", sliceId->valueint);
        cJSON_AddStringToObject(outputSlice, "data", "Error");
        cJSON_AddItemToArray(sliceArray, outputSlice);
    }
    return root;
}


// -------------------------------------------------------------
// Run the external handler on a request JSON file and parse its output.
//   inputJsonFile - file name appended as the last argument of handler_command
//   recvData      - the request, used to build a fallback error response
// Returns the parsed handler output, or createErrorResponse(recvData)
// when there is no input file, the command cannot be built, the
// handler produces no output, or its output is not valid JSON.
// The caller owns the returned object (release with cJSON_Delete).
// -------------------------------------------------------------
cJSON* call_response_handler(char* inputJsonFile, cJSON* recvData)
{
    if (inputJsonFile == NULL)
    {
        printf("[SERVER ERROR] No input file for command handler %s\n", handler_command);
        return createErrorResponse(recvData);
    }

    DEBUG_PRINT("[SERVER DEBUG] Calling response handler with file %s\n", inputJsonFile);

    // Size the command line exactly: handler_command is built from argv and
    // can be arbitrarily long, so a fixed-size buffer could overflow.
    int cmdLen = snprintf(NULL, 0, "%s %s", handler_command, inputJsonFile);
    if (cmdLen < 0)
    {
        printf("[SERVER ERROR] Could not format command for handler %s\n", handler_command);
        return createErrorResponse(recvData);
    }
    char* commandBuf = malloc((size_t)cmdLen + 1);
    if (commandBuf == NULL)
    {
        printf("[SERVER ERROR] Out of memory building command for handler %s\n", handler_command);
        return createErrorResponse(recvData);
    }
    snprintf(commandBuf, (size_t)cmdLen + 1, "%s %s", handler_command, inputJsonFile);

    int bufLen = 0;
    // NOTE(review): ownership of the returned buffer is defined by
    // COMMON_CallExternalProcess; it is not freed here — confirm.
    char* response = COMMON_CallExternalProcess(commandBuf, &bufLen);

    cJSON* returnJson = NULL;
    if (response == NULL)
    {
        printf("[SERVER ERROR] Command handler %s returned no output\n", handler_command);
        returnJson = createErrorResponse(recvData);
    }
    else
    {
        DEBUG_PRINT("[SERVER DEBUG] Received %zu bytes from handler\n", strlen(response));
        returnJson = cJSON_Parse(response);
        if (returnJson == NULL)
        {
            printf ("[SERVER ERROR] Could not parse json from command handler %s  - received { %s }\n", handler_command, response);
            returnJson = createErrorResponse(recvData);
        }
    }

    free(commandBuf);
    return returnJson;
}


// -------------------------------------------------------------
// Serve requests in browser-like mode 
// -------------------------------------------------------------
int send_response(SSL *ssl, int s, char *proto)
{
	DEBUG_PRINT("[SERVER DEBUG] Send_response() called\n"); 
    
    SPP_SLICE *slice;       
    SPP_CTX *ctx;     
    char buf[BUFSIZZ];      
    int errorCode = SSL_ERROR_NONE;
    int expectedSlices = 0;
    int slicesSeen = 0;
    cJSON* recvData = cJSON_CreateObject();
    cJSON* sliceDataArray = cJSON_AddArrayToObject(recvData, "slices");

    while ( (errorCode == SSL_ERROR_NONE) && ( expectedSlices == 0 || expectedSlices > slicesSeen))
    {
        if (expectedSlices == 0)
        {
            DEBUG_PRINT("[SERVER DEBUG] Entering read loop for first time\n");
        }
        else
        {
            DEBUG_PRINT("[SERVER DEBUG] Entering read loop, seen %d / %d slices\n", slicesSeen, expectedSlices);
        }

    	DEBUG_PRINT("[SERVER DEBUG] Reading SPP record\n");
        int r = SPP_read_record(ssl, buf, BUFSIZZ, &slice, &ctx);
        errorCode = SSL_get_error(ssl, r);
        if (errorCode != SSL_ERROR_NONE)
        {
            printf("[SERVER ERROR] SSL read error code %d\n", errorCode);
            break;
        }

        if (expectedSlices == 0)
        {
            expectedSlices = ssl->slices_len;
            DEBUG_PRINT("[SERVER DEBUG] Expecting %lu slices\n", ssl->slices_len);
        }

        DEBUG_PRINT("[SERVER DEBUG] Read %d bytes from slice %d (%s)\n", r, slice->slice_id, slice->purpose);
        DEBUG_PRINT("[SERVER DEBUG] { %s }\n", buf);

        // Print proxies
        for (int i = 0; i < ssl->proxies_len; i++)
        {
            DEBUG_PRINT("   [SERVER DEBUG] Proxy: %s\n", ssl->proxies[i]->address); 
        }
        for (int i = 0; i < ssl->slices_len; i++)
        {
            DEBUG_PRINT("   [SERVER DEBUG] Slice with ID %d and purpose %s\n", ssl->slices[i]->slice_id, ssl->slices[i]->purpose); 
        }
        slicesSeen++;

        cJSON* sliceData = cJSON_CreateObject();
        cJSON_AddNumberToObject(sliceData, "slice", slice->slice_id);
        cJSON_AddStringToObject(sliceData, "slicePurpose", slice->purpose);
        cJSON_AddStringToObject(sliceData, "data", COMMON_MakeNullTerminatedCopy(buf, r));
        cJSON_AddItemToArray(sliceDataArray, sliceData);        
    }

    char* sliceDataString = cJSON_Print(recvData);
    printf("[SERVER] Received\n%s\n", sliceDataString);
    free(sliceDataString);

    char* filename = COMMON_WriteJSONFile(recvData, "Server");
    DEBUG_PRINT("[SERVER DEBUG] Data written to %s\n", filename);
    cJSON* responseData = call_response_handler(filename, recvData);

    char* responseDataStr = cJSON_Print(responseData);
    printf("[SERVER] Returning following data\n");
    printf("%s\n", responseDataStr);

    cJSON* responseSlices = cJSON_GetObjectItem(responseData, "slices");
    cJSON* outputSlice = NULL;
    cJSON_ArrayForEach(outputSlice, responseSlices)
    {
        int sliceIndex = cJSON_GetObjectItem(outputSlice, "slice")->valueint;  
        char* sliceData = cJSON_GetObjectItem(outputSlice, "data")->valuestring;

        DEBUG_PRINT("[SERVER DEBUG] Writing %lu bytes as record to slice %d\n", strlen(sliceData), sliceIndex);
        int r = SPP_write_record(ssl, sliceData, strlen(sliceData), SPP_get_slice_by_id(ssl, sliceIndex));
        check_SSL_write_error(ssl, r, strlen(sliceData));        
    }

    free(filename);
    cJSON_Delete(recvData);
    cJSON_Delete(responseData);

    DEBUG_PRINT ("[SERVER DEBUG] Shutting down SSL\n");
    int r = SSL_shutdown(ssl);
    if( !r )
    {
        shutdown(s, 1);
        r = SSL_shutdown(ssl);
    }

    // Verify that shutdown was good 
    switch(r){  
        case 1:
            break; // Success
        case 0:
        case -1:
        default: // Error 
            printf ("[SERVER ERROR] Shutdown failed with code %d\n", r); 
            exit(1);
    }
    // free SSL 
    SSL_free(ssl);

	return 0; 
}

// Main function  
int main(int argc, char **argv)
{
	BIO *sbio;
	SSL *ssl;
   	SSL_CTX *ctx;
    char* proto = "spp";
    int sock, newsock;
    int pid;
    int status;

    if (argc == 1)
    {
        // If no handler is specified, simply echo the data received from the client
        handler_command = "cat";
    }
    else
    {
        char tempBuf[1024];
        memset(tempBuf, 0, 1024);
        int written = 0;
        for (int i = 1; i < argc; i++)
        {
            written += snprintf(tempBuf + written, 1024 - written, "%s ", argv[i]);
        }
        handler_command = strdup(tempBuf);
    }


    // Print out configuration
    DEBUG_PRINT("[SERVER DEBUG] Server starting\n");
    DEBUG_PRINT("[SERVER DEBUG] Configuration:\n");
    DEBUG_PRINT("[SERVER DEBUG]   KeyFile: %s\n", KEYFILE);
    DEBUG_PRINT("[SERVER DEBUG]   DHFile:  %s\n", DHFILE);
    DEBUG_PRINT("[SERVER DEBUG]   Proto:   %s\n", proto);
    DEBUG_PRINT("[SERVER DEBUG]   Handler command { %s }\n", handler_command);

	// Build SSL context
    DEBUG_PRINT("[SERVER DEBUG] Initialising ctx\n");
	status = COMMON_InitializeSSLCtx(&ctx, KEYFILE, PASSWORD, CA_LIST, ID_SERVER);
    DEBUG_PRINT("[SERVER DEBUG] Loading DH parameters\n");
	load_dh_params(ctx, DHFILE);

	// Listen on socket
    DEBUG_PRINT("[SERVER DEBUG] Listening on TCP port %d\n", DEFAULT_SERVER_PORT);
	sock = tcp_listen();

	int nConn = 0; 
	bool report = true; 
	while (1)
    {		
		DEBUG_PRINT("[SERVER DEBUG] Waiting on TCP accept...\n");
        newsock = accept(sock, 0, 0);
		if (newsock < 0)
        {
			printf("[SERVER ERROR] Socket accept failed with code %d\n", newsock);
            exit(1);
		}
        else
        {
			DEBUG_PRINT("[SERVER DEBUG] Accepted new connection %d\n", sock); 
		}

		// Keep track of number of connections
		nConn++;
		DEBUG_PRINT("[SERVER DEBUG] %d connections\n", nConn); 

		// Fork a new process
		signal(SIGCHLD, SIG_IGN); 
		pid = fork(); 
		if (pid == 0)
        {
			// In child process
			if (pid == -1) 
            {
				printf ("[SERVER ERROR] Forking process for new connection failed");
				exit(1);
           	}
			
			DEBUG_PRINT("[SERVER DEBUG] child process close old socket and operate on new one\n");
			close(sock);
            sbio = BIO_new_socket(newsock, BIO_NOCLOSE);
            ssl = SSL_new(ctx);
            SSL_set_bio(ssl, sbio, sbio);

            // Wait on SSL Accept 
            int sslAcceptResult = SSL_accept(ssl);

            X509* server_cert = SSL_get_certificate(ssl);
            printf ("[SERVER] Server certificate:\n");
            COMMON_PrintCertificateDetails(server_cert);
            X509_free(server_cert);

            X509* client_cert = SSL_get_peer_certificate(ssl);
            printf ("[Server] Client certificate:\n");
            COMMON_PrintCertificateDetails(client_cert);
            X509_free(client_cert);

            if (sslAcceptResult <= 0)
            {
                printf("[SERVER ERROR] SSL accept error\n");
                exit(0);
            } 
            else 
            {
                DEBUG_PRINT("[SERVER DEBUG] SPP accept OK\n"); 
            }

            // Send a response back
            DEBUG_PRINT("[SERVER DEBUG] Sending response back\n");
            send_response(ssl, newsock, proto);
  			
			// Correctly end child process
			DEBUG_PRINT("[SERVER DEBUG] Ending child process\n");
            shutdown(newsock, SHUT_RDWR);
			close(newsock); 
			exit(0);  
		}
        else
        {
			DEBUG_PRINT("[SERVER DEBUG] Parent process closing new socket\n");
			close(newsock); 
		}
	}

    // Wait for forked process to complete
	wait(&status);
	
    // Clean context
	COMMON_DestroyCtx(ctx);
	
	// Correctly end parent process
	exit(0); 
}