package de.fraunhofer.sit.c2x.pki.ca.validator.pseudonym;

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.TimeZone;
import java.util.TreeSet;

import org.apache.log4j.Logger;
import org.joda.time.DateTime;
import org.joda.time.DateTimeZone;

import de.fraunhofer.sit.c2x.pki.ca.utils.DateUtils;
import de.fraunhofer.sit.c2x.pki.ca.validator.pseudonym.intervaltree.Interval;
import de.fraunhofer.sit.c2x.pki.ca.validator.pseudonym.intervaltree.IntervalTree;

/**
 * Helpers that distribute a number of requested pseudonym certificates over consecutive time
 * slots of length {@code plt}. They take into account certificates that already exist in
 * overlapping intervals and a per-slot upper bound {@code ppn} (parallel pseudonym number).
 *
 * @author Daniel Quanz (daniel.quanz@sit.fraunhofer.de)
 */
public class PPEUtils {

	/**
	 * Orders result intervals by their start time. Used for every result set this class creates,
	 * so that empty and non-empty results behave the same way.
	 */
	private static final Comparator<Interval<Integer>> BY_START = new Comparator<Interval<Integer>>() {
		@Override
		public int compare(Interval<Integer> o1, Interval<Integer> o2) {
			return Long.compare(o1.getStart(), o2.getStart());
		}
	};

	/**
	 * Immutable pair of certificate counts attached to an interval.
	 * <p>
	 * In {@link #fillMinimum}, {@code oldD} holds the summed {@code oldD} of all input intervals
	 * that overlap the slot. {@code newD} holds the number of certificates newly assigned to the
	 * slot.
	 */
	public static class IntervalDataPair {
		private final long oldD;
		private final long newD;

		/**
		 * @param oldD
		 *            number of certificates already present
		 * @param newD
		 *            number of certificates newly assigned
		 */
		public IntervalDataPair(long oldD, long newD) {
			this.oldD = oldD;
			this.newD = newD;
		}

		/** @return the number of certificates already present */
		public long getOldD() {
			return oldD;
		}

		/** @return the number of certificates newly assigned */
		public long getNewD() {
			return newD;
		}
	}

	/**
	 * Distributes the requested number of certificates over {@code plt}-sized slots inside the
	 * (normalised) requested interval. No slot may exceed {@code ppn} certificates, counting both
	 * existing and new ones.
	 * <p>
	 * The caller's {@code lst} is not modified.
	 *
	 * @param ppn
	 *            maximum number of certificates allowed in parallel per slot
	 * @param plt
	 *            slot length, in the same unit as the interval bounds (milliseconds, as produced
	 *            by {@link #reqInterval})
	 * @param ppt
	 *            offset added to "now" that caps the expiration time (see {@link #reqInterval})
	 * @param lst
	 *            already existing intervals; their {@code oldD} counts as occupied capacity. May
	 *            be {@code null}, which is treated as empty.
	 * @param reqeuestedInterval
	 *            requested start/end (epoch milliseconds); its data is the number of requested
	 *            certificates
	 * @param logger
	 *            optional logger for debug output; may be {@code null}
	 * @return the total number of assigned certificates, {@code plt}, and the joined intervals
	 *         carrying a positive per-slot certificate count
	 */
	public static PPEResult computeIntervals(int ppn, long plt, long ppt,
			List<Interval<IntervalDataPair>> lst, Interval<Integer> reqeuestedInterval, Logger logger) {

		// Align/cap the requested window to the slot grid.
		Interval<IntervalDataPair> reqInterval = reqInterval(reqeuestedInterval.getStart(),
				reqeuestedInterval.getEnd(), reqeuestedInterval.getData(), plt, ppt, logger);

		// Work on a private copy so the caller's list is never mutated. getBounds() relies on the
		// requested interval being the LAST element of this list.
		List<Interval<IntervalDataPair>> intervals = new ArrayList<Interval<IntervalDataPair>>();
		if (lst != null) {
			intervals.addAll(lst);
		}
		intervals.add(reqInterval);

		int requestedCerts = reqeuestedInterval.getData();
		long slots = (reqInterval.getEnd() - reqInterval.getStart()) / plt;
		if (slots <= 0 || requestedCerts <= 0) {
			return new PPEResult(0, plt, new TreeSet<Interval<Integer>>(BY_START));
		}

		// Spread evenly: every slot gets 'min', the remainder 'rest' is placed afterwards.
		long min = requestedCerts / slots;
		int rest = (int) (requestedCerts % slots);

		List<Interval<IntervalDataPair>> newIntervals = new ArrayList<Interval<IntervalDataPair>>();
		rest = fillMinimum(ppn, plt, min, rest, intervals, newIntervals);
		fillRest(ppn, rest, newIntervals);

		return joinIntervals(newIntervals, plt);
	}

	/**
	 * Builds an interval tree over the given intervals.
	 *
	 * @param lst
	 *            list of intervals
	 * @return a built tree containing all intervals of {@code lst}
	 */
	private static IntervalTree<IntervalDataPair> buildIntervalTree(List<Interval<IntervalDataPair>> lst) {
		IntervalTree<IntervalDataPair> tree = new IntervalTree<>();
		tree.addAll(lst);
		tree.build();
		return tree;
	}

	/**
	 * Merges consecutive slots that carry the same number of new certificates.
	 * <p>
	 * Only merged intervals with a positive new-certificate count end up in the result set.
	 * The total certificate count is summed over all slots.
	 *
	 * @param ap
	 *            consecutive slots in ascending start order
	 * @param plt
	 *            slot length
	 * @return the total count, {@code plt}, and the merged intervals ordered by start
	 */
	private static PPEResult joinIntervals(List<Interval<IntervalDataPair>> ap, long plt) {
		Set<Interval<Integer>> joined = new TreeSet<Interval<Integer>>(BY_START);
		if (ap.isEmpty()) {
			return new PPEResult(0, plt, joined);
		}

		Interval<IntervalDataPair> last = ap.get(0);
		int certs = computeNumberOfCerts(last, plt);
		for (int i = 1; i < ap.size(); i++) {
			Interval<IntervalDataPair> current = ap.get(i);
			if (current.getData().getNewD() != last.getData().getNewD()) {
				// A different count starts a new run; emit the finished one.
				addIfPositive(joined, last);
				last = current;
			} else {
				// Same count: extend the current run up to the end of this slot.
				last = new Interval<IntervalDataPair>(last.getStart(), current.getEnd(),
						new IntervalDataPair(0, last.getData().getNewD()));
			}
			certs += computeNumberOfCerts(current, plt);
		}
		addIfPositive(joined, last);
		return new PPEResult(certs, plt, joined);
	}

	/**
	 * Adds {@code interval} (converted to its new-certificate count) to {@code target}. Intervals
	 * without new certificates are skipped.
	 */
	private static void addIfPositive(Set<Interval<Integer>> target, Interval<IntervalDataPair> interval) {
		long newCerts = interval.getData().getNewD();
		if (newCerts > 0) {
			target.add(new Interval<Integer>(interval.getStart(), interval.getEnd(), (int) newCerts));
		}
	}

	/**
	 * Number of new certificates in {@code interval}: whole {@code plt} slots times the per-slot
	 * count.
	 * <p>
	 * The {@code +1} makes an interval of exactly {@code plt} length count as one slot under
	 * integer division (for {@code plt > 1}).
	 */
	private static int computeNumberOfCerts(Interval<IntervalDataPair> interval, long plt) {
		return (int) ((interval.getEnd() - interval.getStart() + 1) / plt) * (int) interval.getData().getNewD();
	}

	/**
	 * Distributes the remaining certificates front to back over the slots. Each slot is filled up
	 * to its free capacity {@code ppn - (oldD + newD)}. Certificates that do not fit anywhere are
	 * dropped.
	 *
	 * @param ppn
	 *            maximum number of certificates per slot
	 * @param rest
	 *            number of certificates still to place
	 * @param ap
	 *            slots produced by {@link #fillMinimum}; updated in place
	 */
	private static void fillRest(int ppn, int rest, List<Interval<IntervalDataPair>> ap) {
		for (int i = 0; i < ap.size() && rest > 0; i++) {
			Interval<IntervalDataPair> slot = ap.get(i);
			IntervalDataPair before = slot.getData();
			int free = ppn - ((int) before.getOldD() + (int) before.getNewD());
			if (free > 0) {
				int take = Math.min(free, rest);
				slot.setData(new IntervalDataPair(before.getOldD(), before.getNewD() + take));
				rest -= take;
			}
		}
	}

	/**
	 * Splits the requested interval into {@code plt} slots and assigns up to
	 * {@code minNumberOfCerts} new certificates to each one. The amount is limited by the free
	 * capacity left by overlapping existing intervals.
	 *
	 * @param ppn
	 *            parallel pseudonym number (maximum certificates per slot)
	 * @param plt
	 *            pseudonym lifetime (slot length)
	 * @param minNumberOfCerts
	 *            computed minimum number of certificates to add to each slot
	 * @param restNumberOfCerts
	 *            computed number of certificates left over after the even split
	 * @param intervals
	 *            existing intervals with the requested interval as LAST element
	 * @param ap
	 *            output list; receives one entry per slot, in ascending order
	 * @return the left-over count, increased by whatever could not be placed due to capacity
	 */
	private static int fillMinimum(int ppn, long plt, long minNumberOfCerts, int restNumberOfCerts,
			List<Interval<IntervalDataPair>> intervals, List<Interval<IntervalDataPair>> ap) {

		Long[] bounds = getBounds(intervals);
		IntervalTree<IntervalDataPair> tree = buildIntervalTree(intervals);
		for (int i = 0; i < bounds.length - 1; i++) {
			long segmentStart = bounds[i];
			long segmentEnd = bounds[i + 1];

			// Certificates that already exist in this segment.
			List<IntervalDataPair> overlappingData = tree.get(segmentStart, segmentEnd);
			int sum = 0;
			for (IntervalDataPair existing : overlappingData) {
				sum += existing.getOldD();
			}

			// Free capacity is clamped at 0. If existing certificates already exceed ppn, we must
			// neither assign a negative count nor over-count the rest.
			long capacity = Math.max(0, ppn - sum);
			long add2;
			if (capacity < minNumberOfCerts) {
				add2 = capacity;
				restNumberOfCerts += (minNumberOfCerts - capacity) * ((segmentEnd - segmentStart) / plt);
			} else {
				add2 = minNumberOfCerts;
			}

			// NOTE(review): if the segment length is not a multiple of plt, the last slot extends
			// past segmentEnd, while the rest computation above uses floor division.
			for (long start = segmentStart; start < segmentEnd; start += plt) {
				ap.add(new Interval<IntervalDataPair>(start, start + plt, new IntervalDataPair(sum,
						(int) add2)));
			}
		}
		return restNumberOfCerts;
	}

	/**
	 * Logs {@code message} at debug level when a logger is given and debug logging is enabled.
	 *
	 * @return the message itself
	 */
	private static Object debug(Logger logger, String message) {
		if (logger != null && logger.isDebugEnabled()) {
			logger.debug(message);
		}
		return message;
	}

	/**
	 * Normalises a requested validity window to the slot grid.
	 * <p>
	 * Steps:
	 * <ol>
	 * <li>Start and end are rounded down with {@link DateUtils#floorDay} (in UTC).</li>
	 * <li>The end is capped at the rounded value of now + {@code ppt}.</li>
	 * <li>The expiration is set to {@code start + k * plt}, where {@code k} is the smaller of
	 * {@code reqCerts} and the number of whole {@code plt} slots in the capped window (at least
	 * one).</li>
	 * </ol>
	 * <p>
	 * NOTE(review): this method calls {@link TimeZone#setDefault} and therefore changes the
	 * JVM-wide default time zone to UTC as a side effect. It is kept for compatibility with code
	 * that may rely on it.
	 *
	 * @param start
	 *            requested start (epoch milliseconds)
	 * @param end
	 *            requested end (epoch milliseconds)
	 * @param reqCerts
	 *            number of requested certificates
	 * @param plt
	 *            slot length in milliseconds
	 * @param ppt
	 *            offset in milliseconds added to the current time to cap the end
	 * @param logger
	 *            optional logger for debug output; may be {@code null}
	 * @return the normalised interval, with {@code IntervalDataPair(0, 0)} as data
	 */
	public static Interval<IntervalDataPair> reqInterval(long start, long end, int reqCerts, long plt,
			long ppt, Logger logger) {

		debug(logger, "Start to adjust requested start- and expirationTime");
		debug(logger, "Requested startTime: "+new DateTime(start));
		debug(logger, "Requested expirationTime: "+new DateTime(end));
		TimeZone.setDefault(TimeZone.getTimeZone("UTC"));
		DateTime nowplusppt = new DateTime(new DateTime(DateTimeZone.UTC).getMillis()+ppt, DateTimeZone.UTC);
		start = DateUtils.floorDay(new DateTime(start, DateTimeZone.UTC)).getMillis();
		end = DateUtils.floorDay(new DateTime(end, DateTimeZone.UTC)).getMillis();
		debug(logger, "Rounded (down) startTime: "+new DateTime(start));
		debug(logger, "Rounded (down) expirationTime: "+new DateTime(end));

		debug(logger, "Check if expirationTime is less or equal NOW()+PPP: "+nowplusppt);
		if (end > nowplusppt.getMillis()) {
			end = DateUtils.floorDay(nowplusppt).getMillis();
			debug(logger, "Adjust expirationTime based on NOW()+PPP: "+new DateTime(end));
		}

		// Always grant at least one slot, even for a degenerate window.
		long slots = ((end - start) / plt);
		debug(logger, "Number of intervals: "+slots);
		if (slots <= 0) {
			slots = 1;
			debug(logger, "Adjust number of intervals: "+slots);
		}
		debug(logger, "Number of requested certs: "+reqCerts);

		long expiration = start;
		if (reqCerts > slots) {
			expiration += (slots * plt);
			debug(logger, "Adjust expirationTime based on number of intervals: "+new DateTime(expiration));
		} else {
			expiration += (reqCerts * plt);
			debug(logger, "Adjust expirationTime based on number of reqested certs: "+new DateTime(expiration));
		}

		debug(logger, "New startTime: "+new DateTime(start));
		debug(logger, "New expirationTime: "+new DateTime(expiration));

		return new Interval<IntervalDataPair>(start, expiration, new IntervalDataPair(0, 0));
	}

	/**
	 * Collects all interval boundaries that fall inside the requested interval, including its own
	 * start and end. The requested interval is taken to be the LAST element of
	 * {@code intervals}.
	 *
	 * @param intervals
	 *            existing intervals followed by the requested interval
	 * @return the distinct boundaries in ascending order
	 */
	private static Long[] getBounds(List<Interval<IntervalDataPair>> intervals) {
		Interval<IntervalDataPair> requested = intervals.get(intervals.size() - 1);
		long start = requested.getStart();
		long end = requested.getEnd();
		Set<Long> bounds = new TreeSet<Long>();
		for (Interval<IntervalDataPair> interval : intervals) {
			if (interval.getStart() >= start && interval.getStart() <= end) {
				bounds.add(interval.getStart());
			}
			if (interval.getEnd() >= start && interval.getEnd() <= end) {
				bounds.add(interval.getEnd());
			}
		}
		return bounds.toArray(new Long[bounds.size()]);
	}

}
