package de.fraunhofer.sit.c2x.pki.ca.utils;

import java.io.ByteArrayOutputStream;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.lang.reflect.Array;
import java.lang.reflect.Constructor;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

import de.fraunhofer.sit.c2x.pki.ca.core.interfaces.WaveType;

/**
 * Useful utility functions regarding WAVE certificates and certificate
 * revocation lists which are useful on server as well as on client side.
 * <p>
 * Most helpers implement the self-delimiting "leading ones" encoding that this
 * class uses for variable length fields, flags and IntX values: the number of
 * leading 1 bits in the first octet gives the number of additional octets, the
 * prefix is terminated by a 0 bit, and the remaining bits carry the value in
 * big-endian order (see IEEE 1609.2 v2 sections 5.1.6.2 and 5.1.13).
 */
public class WaveUtils {

	/**
	 * Returns the number of octets
	 * {@link #writeVariableLengthField(int, DataOutputStream)} emits for the
	 * given length.
	 * 
	 * @param length
	 *            the length value to encode
	 * @return 1 to 4, or 0 if {@code length} is negative or not below 2^28
	 */
	public static int getVariableLengthFieldLength(int length) {
		if (length < 0) {
			return 0;
		}
		for (int octets = 1; octets <= 4; octets++) {
			// each additional octet adds 7 usable value bits
			if (length < (1 << (7 * octets))) {
				return octets;
			}
		}
		return 0;
	}

	/**
	 * Decodes the leading-ones prefix of a first octet.
	 * 
	 * @param firstByte
	 *            first octet of a flags / variable length field
	 * @return number of additional octets (0 to 3), or -1 for the unsupported
	 *         prefix 1111 xxxx
	 */
	private static int additionalOctets(byte firstByte) {
		if ((firstByte & 0x80) == 0) {
			return 0; // 0xxx xxxx
		}
		if ((firstByte & 0xC0) == 0x80) {
			return 1; // 10xx xxxx
		}
		if ((firstByte & 0xE0) == 0xC0) {
			return 2; // 110x xxxx
		}
		if ((firstByte & 0xF0) == 0xE0) {
			return 3; // 1110 xxxx
		}
		return -1;
	}

	/**
	 * Read flags from DataInputStream and remove the length bits in front of
	 * the flags. See IEEE 1609.2 v2 section 5.1.13
	 * 
	 * @param in
	 *            stream positioned at the first octet of the flags field
	 * @return the flag octets with the prefix bits cleared (1 to 4 octets), or
	 *         an empty array if the first octet starts with 1111
	 * @throws IOException
	 *             if reading fails; {@link java.io.EOFException} if the stream
	 *             ends inside the field
	 */
	public static byte[] readFlags(DataInputStream in) throws IOException {
		byte firstByte = in.readByte();
		int extra = additionalOctets(firstByte);
		if (extra < 0) {
			return new byte[0];
		}
		byte[] flags = new byte[extra + 1];
		// 0x7F >> extra keeps exactly the value bits after the prefix:
		// 0x7F, 0x3F, 0x1F, 0x0F for 0..3 additional octets
		flags[0] = (byte) (firstByte & (0x7F >> extra));
		// readFully (unlike read) never returns a partially filled buffer
		in.readFully(flags, 1, extra);
		return flags;
	}

	/**
	 * Decodes a variable length field from the beginning of a byte array.
	 * 
	 * @param in
	 *            bytes starting with the length field
	 * @return the decoded length, or 0 if the array is empty, truncated or
	 *         starts with the unsupported prefix 1111
	 */
	public static int readVariableLengthField(byte[] in) {
		if (in.length == 0) {
			return 0;
		}
		int extra = additionalOctets(in[0]);
		if (extra < 0 || in.length <= extra) {
			return 0;
		}
		int length = in[0] & (0x7F >> extra);
		for (int i = 1; i <= extra; i++) {
			length = (length << 8) | (in[i] & 0xFF);
		}
		return length;
	}

	/**
	 * Reads variable length fields from DataInputStream. See IEEE 1609.2 v2
	 * section 5.1.6.2
	 * 
	 * @param in
	 *            stream positioned at the first octet of the length field
	 * @return content of length field without length of length, or 0 if the
	 *         first octet starts with the unsupported prefix 1111
	 * @throws IOException
	 *             if reading fails; {@link java.io.EOFException} if the stream
	 *             ends inside the field
	 */
	public static int readVariableLengthField(DataInputStream in)
			throws IOException {
		byte firstByte = in.readByte();
		int extra = additionalOctets(firstByte);
		if (extra < 0) {
			return 0;
		}
		int length = firstByte & (0x7F >> extra);
		for (int i = 0; i < extra; i++) {
			// readUnsignedByte throws EOFException instead of silently
			// leaving zero bytes on a short read
			length = (length << 8) | in.readUnsignedByte();
		}
		return length;
	}

	/**
	 * Write flags into DataOutputStream and add bits for the length in front of
	 * the flags. See IEEE 1609.2 v2 section 5.1.13
	 * <p>
	 * If the data bits of {@code flags[0]} overlap the prefix bits needed for
	 * {@code flags.length} octets, a leading zero octet is added so that the
	 * value survives a round trip through {@link #readFlags(DataInputStream)}.
	 * The caller's array is not modified.
	 * 
	 * @param flags
	 *            1 to 4 flag octets, most significant first
	 * @param out
	 *            destination stream
	 * @return number of octets written
	 * @throws IllegalArgumentException
	 *             if {@code flags} is empty, longer than 4 octets, or a 4 octet
	 *             value uses the top 4 bits of its first octet
	 * @throws IOException
	 *             if writing fails
	 */
	public static int writeFlags(byte[] flags, DataOutputStream out)
			throws IOException {

		if (flags.length <= 0) {
			throw new IllegalArgumentException("Empty flags array not allowed");
		}

		if (flags.length > 4) {
			throw new IllegalArgumentException(
					"Length of flag field bigger than allowed " + flags.length
							+ " bytes > 4 bytes");
		}

		byte[] encoded;
		if ((flags[0] & prefixMask(flags.length)) == 0) {
			// prefix bits are free: encode on a copy so the caller's array is
			// left untouched
			encoded = flags.clone();
		} else if (flags.length < 4) {
			// data bits would collide with the prefix: prepend a zero octet
			// and use the next larger encoding
			encoded = new byte[flags.length + 1];
			System.arraycopy(flags, 0, encoded, 1, flags.length);
		} else {
			throw new IllegalArgumentException(
					"Flag value does not fit into 4 bytes: top 4 bits of the first byte must be 0");
		}
		encoded[0] |= prefixBits(encoded.length);
		out.write(encoded);
		return encoded.length;
	}

	/**
	 * Bits of the first octet that are reserved for the prefix of an encoding
	 * with {@code totalOctets} octets (0x80, 0xC0, 0xE0, 0xF0 for 1..4).
	 */
	private static int prefixMask(int totalOctets) {
		return (0xFF << (8 - totalOctets)) & 0xFF;
	}

	/**
	 * Prefix pattern ({@code totalOctets - 1} ones followed by a zero) placed
	 * in the first octet (0x00, 0x80, 0xC0, 0xE0 for 1..4).
	 */
	private static int prefixBits(int totalOctets) {
		return (0xFF << (9 - totalOctets)) & 0xFF;
	}

	/**
	 * Writes {@code value} in the leading-ones encoding using the minimal
	 * number of octets (1 to 8). Callers validate the range beforehand.
	 * 
	 * @return number of octets written
	 */
	private static int writeLeadingOnesEncoded(long value, DataOutputStream out)
			throws IOException {
		for (int octets = 1; octets <= 8; octets++) {
			if (value < (1L << (7 * octets))) {
				// (2^octets - 2) is "octets - 1" ones followed by a zero; it
				// sits directly above the 7 * octets value bits
				long encoded = value | (((1L << octets) - 2) << (7 * octets));
				for (int shift = 8 * (octets - 1); shift >= 0; shift -= 8) {
					out.writeByte((int) (encoded >>> shift));
				}
				return octets;
			}
		}
		throw new IllegalArgumentException("Value " + value
				+ " does not fit into 8 bytes");
	}

	/**
	 * Write variable length into DataOutputStream and add bits for length of
	 * length field. See IEEE 1609.2 v2 section 5.1.6.2
	 * 
	 * @param length
	 *            length value, 0 to 2^28 - 1
	 * @param out
	 *            destination stream
	 * @return number of octets written (1 to 4)
	 * @throws IllegalArgumentException
	 *             if {@code length} is out of range
	 * @throws IOException
	 *             if writing fails
	 */
	public static int writeVariableLengthField(int length, DataOutputStream out)
			throws IOException {

		if (length < 0) {
			throw new IllegalArgumentException("Length value " + length
					+ " smaller than 0");
		}

		if (length >= (1 << 28)) {
			throw new IllegalArgumentException("Length value " + length
					+ " bigger than " + (Math.pow(2, 28) - 1)
					+ " are not accepted");
		}

		return writeLeadingOnesEncoded(length, out);
	}

	/**
	 * Writes an unsigned IntX value using 1 to 8 octets. This is the same
	 * leading-ones encoding as
	 * {@link #writeVariableLengthField(int, DataOutputStream)}, extended to
	 * 8 octets, and is the counterpart of {@link #readIntX(DataInputStream)}.
	 * 
	 * @param length
	 *            value to encode, 0 to 2^56 - 1
	 * @param out
	 *            destination stream
	 * @return number of octets written
	 * @throws IllegalArgumentException
	 *             if {@code length} is out of range
	 * @throws IOException
	 *             if writing fails
	 */
	public static int writeIntX(long length, DataOutputStream out)
			throws IOException {

		if (length < 0) {
			throw new IllegalArgumentException("Length value " + length
					+ " smaller than 0");
		}

		// integer bound: Math.pow(2, 56) - 1 is not exactly representable as
		// a double and would let 2^56 slip through
		if (length >= (1L << 56)) {
			throw new IllegalArgumentException("Length value " + length
					+ " bigger than " + (Math.pow(2, 56) - 1)
					+ " are not accepted");
		}

		return writeLeadingOnesEncoded(length, out);
	}

	/**
	 * Reads a length-prefixed array of {@code returnType} elements, each
	 * constructed through a public {@code (DataInputStream)} constructor.
	 * 
	 * @see #getArrayFromStream(DataInputStream, Class, Object[])
	 */
	public static <T extends WaveType> T[] getArrayFromStream(
			DataInputStream in, Class<T> returnType) throws IOException {
		return getArrayFromStream(in, returnType, new Object[0]);
	}

	/**
	 * Same as {@link #getArrayFromStream(DataInputStream, Class)} but returns
	 * a mutable list.
	 */
	public static <T extends WaveType> ArrayList<T> getArrayListFromStream(
			DataInputStream in, Class<T> returnType) throws IOException {

		return new ArrayList<T>(Arrays.asList(getArrayFromStream(in,
				returnType, new Object[0])));
	}

	/**
	 * Parameter types of the reflective constructor: {@code DataInputStream}
	 * followed by the runtime classes of {@code args}.
	 */
	private static Class<?>[] constructorParameterTypes(Object[] args) {
		Class<?>[] params = new Class<?>[args.length + 1];
		params[0] = DataInputStream.class;
		for (int i = 0; i < args.length; i++) {
			params[i + 1] = args[i].getClass();
		}
		return params;
	}

	/** Actual arguments of the reflective constructor: {@code in, args...}. */
	private static Object[] constructorArguments(DataInputStream in,
			Object[] args) {
		Object[] input = new Object[args.length + 1];
		input[0] = in;
		System.arraycopy(args, 0, input, 1, args.length);
		return input;
	}

	/**
	 * Reads a variable length field giving the total octet count of the
	 * array, then constructs elements via
	 * {@code returnType(DataInputStream, args...)} until that many octets have
	 * been consumed. The octets consumed per element are measured by
	 * re-serializing it with {@link #getBytesFromWaveType(WaveType)}.
	 * 
	 * @param in
	 *            source stream
	 * @param returnType
	 *            element class; needs a public constructor whose parameter
	 *            types are {@code DataInputStream} followed by the exact
	 *            runtime classes of {@code args}
	 * @param args
	 *            extra constructor arguments (must not contain null)
	 * @return the parsed elements
	 * @throws IOException
	 *             if reading fails, no matching constructor exists, an element
	 *             cannot be constructed, or an element consumes no octets
	 */
	@SuppressWarnings("unchecked")
	public static <T extends WaveType> T[] getArrayFromStream(
			DataInputStream in, Class<T> returnType, Object[] args)
			throws IOException {

		List<T> lst = new ArrayList<T>();
		int remaining = readVariableLengthField(in);

		if (remaining > 0) {
			Class<?>[] params = constructorParameterTypes(args);
			Object[] input = constructorArguments(in, args);
			Constructor<T> constructor;
			try {
				// looked up once instead of once per element
				constructor = returnType.getConstructor(params);
			} catch (NoSuchMethodException e) {
				throw new IOException("Can not parse stream to class", e);
			}
			while (remaining > 0) {
				T wave;
				try {
					wave = constructor.newInstance(input);
				} catch (Exception e) {
					throw new IOException("Can not parse stream to class", e);
				}
				int consumed = getBytesFromWaveType(wave).length;
				if (consumed <= 0) {
					// would otherwise loop forever on the same position
					throw new IOException("Element of type "
							+ returnType.getName() + " consumed no bytes");
				}
				lst.add(wave);
				remaining -= consumed;
			}
		}

		T[] arr = (T[]) Array.newInstance(returnType, lst.size());
		return lst.toArray(arr);
	}

	/**
	 * Constructs a single element via
	 * {@code returnType(DataInputStream, args...)}.
	 * 
	 * @throws IOException
	 *             if no matching constructor exists or construction fails;
	 *             the underlying exception is attached as the cause
	 */
	public static <T extends WaveType> T getElementFromStream(
			DataInputStream in, Class<T> returnType, Object... args)
			throws IOException {

		Class<?>[] params = constructorParameterTypes(args);
		Object[] input = constructorArguments(in, args);
		try {
			Constructor<T> constructor = returnType.getConstructor(params);
			return constructor.newInstance(input);
		} catch (Exception e) {
			throw new IOException("Can not parse stream to class", e);
		}
	}

	/**
	 * Constructs a single element from a byte array, passing {@code args} to
	 * the constructor after the stream.
	 */
	public static <T extends WaveType> T getElementFromBytes(byte[] bytes,
			Class<T> returnType, Object... args) throws IOException {

		return getElementFromStream(ByteUtils.bytesAsStream(bytes), returnType,
				args);

	}

	/** Constructs a single element from a byte array. */
	public static <T extends WaveType> T getElementFromBytes(byte[] bytes,
			Class<T> returnType) throws IOException {

		return getElementFromStream(ByteUtils.bytesAsStream(bytes), returnType,
				new Object[0]);

	}

	/**
	 * Reads a length-prefixed array whose elements are constructed as
	 * {@code ImplType} and returned as a list of the supertype {@code T}.
	 */
	public static <T extends WaveType, I extends T> ArrayList<T> getArrayListFromStream(
			DataInputStream in, Class<T> returnType, Class<I> ImplType,
			Object[] args) throws IOException {

		return new ArrayList<T>(Arrays.asList(getArrayFromStream(in, ImplType,
				args)));
	}

	/**
	 * Writes {@code elements} as a variable length field holding their total
	 * serialized size, followed by each element.
	 * 
	 * @see #writeArrayToStream(DataOutputStream, List)
	 */
	public static <T extends WaveType> int writeArrayToStream(
			DataOutputStream out, T[] elements) throws IOException {
		return writeArrayToStream(out, Arrays.asList(elements));
	}

	/**
	 * Writes {@code elements} as a variable length field holding their total
	 * serialized size, followed by each element.
	 * 
	 * @return octets written, as the length prefix size plus the sum of the
	 *         values returned by each element's {@code writeData}
	 */
	public static <T extends WaveType> int writeArrayToStream(
			DataOutputStream out, List<T> elements) throws IOException {

		int length = 0;
		for (T t : elements) {
			length += WaveUtils.getBytesFromWaveType(t).length;
		}
		int written = writeVariableLengthField(length, out);
		for (T t : elements) {
			written += t.writeData(out);
		}

		return written;
	}

	/**
	 * Writes each of {@code waves} in order.
	 * 
	 * @return sum of the values returned by each element's {@code writeData}
	 * @throws IllegalArgumentException
	 *             if an element is null (elements before it are already
	 *             written)
	 */
	public static int writeWave(DataOutputStream out, WaveType... waves)
			throws IOException {

		int written = 0;
		for (WaveType waveType : waves) {
			if (waveType == null) {
				throw new IllegalArgumentException("args may not be null");
			}
			written += waveType.writeData(out);
		}
		return written;
	}

	/**
	 * Serializes {@code waveType} into a new byte array.
	 * <p>
	 * NOTE(review): an {@link IOException} from {@code writeData} is only
	 * printed and whatever was written up to that point is returned.
	 */
	public static byte[] getBytesFromWaveType(WaveType waveType) {
		ByteArrayOutputStream bos = new ByteArrayOutputStream();
		DataOutputStream dos = new DataOutputStream(bos);
		try {
			waveType.writeData(dos);
			dos.flush();
		} catch (IOException e) {
			e.printStackTrace();
		}
		return bos.toByteArray();
	}

	/** Serializes all {@code waveTypes} and concatenates the results. */
	public static byte[] concatWaveTypes(WaveType... waveTypes) {

		byte[][] bytes = new byte[waveTypes.length][];
		int size = 0;
		for (int i = 0; i < waveTypes.length; i++) {
			bytes[i] = getBytesFromWaveType(waveTypes[i]);
			size += bytes[i].length;
		}
		byte[] concatBytes = new byte[size];
		int pointer = 0;
		for (int i = 0; i < bytes.length; i++) {
			System.arraycopy(bytes[i], 0, concatBytes, pointer,
					bytes[i].length);
			pointer += bytes[i].length;
		}
		return concatBytes;
	}

	/**
	 * Counts the leading 1 bits of {@code bytes}, scanning from the most
	 * significant bit of the first byte and continuing into following bytes
	 * until the first 0 bit.
	 */
	public static int countOneBits(byte[] bytes) {

		int n = 0;
		for (byte b : bytes) {
			boolean finish = false;
			for (int i = 7; i >= 0; i--) {
				if (((b >> i) & 0x01) == 1) {
					n++;
				} else {
					finish = true;
					break;
				}
			}
			if (finish)
				break;
		}

		return n;
	}

	/**
	 * Decodes a leading-ones encoded value (as written by
	 * {@link #writeIntX(long, DataOutputStream)}) from the start of
	 * {@code input}.
	 */
	public static long variableLength(byte[] input) {
		int cbytes = countOneBits(input);
		int firstIndex = cbytes / 8;
		// value bits remaining in the octet that holds the terminating 0 bit
		long valueMask = (1L << (7 - cbytes % 8)) - 1;

		long res = (input[firstIndex] & valueMask) << ((cbytes - firstIndex) * 8);
		for (int i = firstIndex; i < cbytes; i++) {
			res += (input[i + 1] & 0xFFL) << ((cbytes - i - 1) * 8);
		}

		return res;
	}

	/**
	 * Reads an IntX value written by {@link #writeIntX(long, DataOutputStream)}.
	 * 
	 * @return the decoded value, or 0 if the stream is already at its end
	 * @throws IOException
	 *             if reading fails; {@link java.io.EOFException} if the stream
	 *             ends after the prefix but inside the value octets
	 */
	public static long readIntX(DataInputStream in) throws IOException {

		int n = 0;
		boolean finish = false;
		int b = 0;
		// read() blocks until data or end of stream; available() is only an
		// estimate and may report 0 on streams that still have data
		while (!finish) {
			int next = in.read();
			if (next < 0) {
				break;
			}
			b = next;
			for (int i = 7; i >= 0; i--) {
				if (((b >> i) & 0x01) == 1) {
					n++;
				} else {
					finish = true;
					break;
				}
			}
		}

		long res = (b & ((1L << (7 - n % 8)) - 1)) << ((n - n / 8) * 8);
		for (int i = 0; i < n; i++) {
			res += ((long) in.readUnsignedByte()) << ((n - i - 1) * 8);
		}

		return res;
	}
}
