Commit 294d1e36 authored by Emilia Kasper's avatar Emilia Kasper
Browse files

RT3066: rewrite RSA padding checks to be slightly more constant time.



Also tweak s3_cbc.c to use new constant-time methods.
Also fix memory leaks from internal errors in RSA_padding_check_PKCS1_OAEP_mgf1

This patch is based on the original RT submission by Adam Langley <agl@chromium.org>,
as well as code from BoringSSL and OpenSSL.

Reviewed-by: default avatarKurt Roeckx <kurt@openssl.org>
parent 51b7be8d
Loading
Loading
Loading
Loading
+34 −2
Original line number Diff line number Diff line
@@ -54,7 +54,7 @@ extern "C" {
#endif

/*
 * The following methods return a bitmask of all ones (0xff...f) for true
 * The boolean methods return a bitmask of all ones (0xff...f) for true
 * and 0 for false. This is useful for choosing a value based on the result
 * of a conditional in constant time. For example,
 *
@@ -67,7 +67,7 @@ extern "C" {
 * can be written as
 *
 * unsigned int lt = constant_time_lt(a, b);
 * c = a & lt | b & ~lt;
 * c = constant_time_select(lt, a, b);
 */

/*
@@ -107,6 +107,21 @@ static inline unsigned int constant_time_eq(unsigned int a, unsigned int b);
/* Convenience method for getting an 8-bit mask. */
static inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b);

/*
 * Returns (mask & a) | (~mask & b).
 *
 * When |mask| is all 1s or all 0s (as returned by the methods above),
 * the select methods return either |a| (if |mask| is nonzero) or |b|
 * (if |mask| is zero).
 */
static inline unsigned int constant_time_select(unsigned int mask,
	unsigned int a, unsigned int b);
/* Convenience method for unsigned chars. */
static inline unsigned char constant_time_select_8(unsigned char mask,
	unsigned char a, unsigned char b);
/* Convenience method for signed integers. */
static inline int constant_time_select_int(unsigned int mask, int a, int b);

static inline unsigned int constant_time_msb(unsigned int a)
	{
	return (unsigned int)((int)(a) >> (sizeof(int) * 8 - 1));
@@ -162,6 +177,23 @@ static inline unsigned char constant_time_eq_8(unsigned int a, unsigned int b)
	return (unsigned char)(constant_time_eq(a, b));
	}

static inline unsigned int constant_time_select(unsigned int mask,
	unsigned int a, unsigned int b)
	{
	return (mask & a) | (~mask & b);
	}

static inline unsigned char constant_time_select_8(unsigned char mask,
	unsigned char a, unsigned char b)
	{
	return (unsigned char)(constant_time_select(mask, a, b));
	}

inline int constant_time_select_int(unsigned int mask, int a, int b)
	{
	return (int)(constant_time_select(mask, (unsigned)(a), (unsigned)(b)));
	}

#ifdef __cplusplus
}
#endif
+101 −17
Original line number Diff line number Diff line
@@ -50,9 +50,9 @@
#include <stdio.h>
#include <stdlib.h>

static const unsigned int CONSTTIME_TRUE = ~0;
static const unsigned int CONSTTIME_TRUE = (unsigned)(~0);
static const unsigned int CONSTTIME_FALSE = 0;
static const unsigned char CONSTTIME_TRUE_8 = ~0;
static const unsigned char CONSTTIME_TRUE_8 = 0xff;
static const unsigned char CONSTTIME_FALSE_8 = 0;

static int test_binary_op(unsigned int (*op)(unsigned int a, unsigned int b),
@@ -133,13 +133,86 @@ static int test_is_zero_8(unsigned int a)
        return 0;
	}

static int test_select(unsigned int a, unsigned int b)
	{
	unsigned int selected = constant_time_select(CONSTTIME_TRUE, a, b);
	if (selected != a)
		{
		fprintf(stderr, "Test failed for constant_time_select(%du, %du,"
			"%du): expected %du(first value), got %du\n",
			CONSTTIME_TRUE, a, b, a, selected);
		return 1;
		}
	selected = constant_time_select(CONSTTIME_FALSE, a, b);
	if (selected != b)
		{
		fprintf(stderr, "Test failed for constant_time_select(%du, %du,"
			"%du): expected %du(second value), got %du\n",
			CONSTTIME_FALSE, a, b, b, selected);
		return 1;
		}
	return 0;
	}

static int test_select_8(unsigned char a, unsigned char b)
	{
	unsigned char selected = constant_time_select_8(CONSTTIME_TRUE_8, a, b);
	if (selected != a)
		{
		fprintf(stderr, "Test failed for constant_time_select(%u, %u,"
			"%u): expected %u(first value), got %u\n",
			CONSTTIME_TRUE, a, b, a, selected);
		return 1;
		}
	selected = constant_time_select_8(CONSTTIME_FALSE_8, a, b);
	if (selected != b)
		{
		fprintf(stderr, "Test failed for constant_time_select(%u, %u,"
			"%u): expected %u(second value), got %u\n",
			CONSTTIME_FALSE, a, b, b, selected);
		return 1;
		}
	return 0;
	}

static int test_select_int(int a, int b)
	{
	int selected = constant_time_select_int(CONSTTIME_TRUE, a, b);
	if (selected != a)
		{
		fprintf(stderr, "Test failed for constant_time_select(%du, %d,"
			"%d): expected %d(first value), got %d\n",
			CONSTTIME_TRUE, a, b, a, selected);
		return 1;
		}
	selected = constant_time_select_int(CONSTTIME_FALSE, a, b);
	if (selected != b)
		{
		fprintf(stderr, "Test failed for constant_time_select(%du, %d,"
			"%d): expected %d(second value), got %d\n",
			CONSTTIME_FALSE, a, b, b, selected);
		return 1;
		}
	return 0;
	}


static unsigned int test_values[] = {0, 1, 1024, 12345, 32000, UINT_MAX/2-1,
                                     UINT_MAX/2, UINT_MAX/2+1, UINT_MAX-1,
                                     UINT_MAX};

static unsigned char test_values_8[] = {0, 1, 2, 20, 32, 127, 128, 129, 255};

static int signed_test_values[] = {0, 1, -1, 1024, -1024, 12345, -12345,
				   32000, -32000, INT_MAX, INT_MIN, INT_MAX-1,
				   INT_MIN+1};


int main(int argc, char *argv[])
	{
	unsigned int a, b, i, j;
	int c, d;
	unsigned char e, f;
	int num_failed = 0, num_all = 0;
	fprintf(stdout, "Testing constant time operations...\n");

@@ -148,20 +221,8 @@ int main(int argc, char *argv[])
		a = test_values[i];
		num_failed += test_is_zero(a);
		num_failed += test_is_zero_8(a);
		num_failed += test_binary_op(&constant_time_lt,
			"constant_time_lt", a, a, 0);
		num_failed += test_binary_op_8(&constant_time_lt_8,
			"constant_time_lt_8", a, a, 0);
		num_failed += test_binary_op(&constant_time_ge,
			"constant_time_ge", a, a, 1);
		num_failed += test_binary_op_8(&constant_time_ge_8,
			"constant_time_ge_8", a, a, 1);
		num_failed += test_binary_op(&constant_time_eq,
			"constant_time_eq", a, a, 1);
		num_failed += test_binary_op_8(&constant_time_eq_8,
			"constant_time_eq_8", a, a, 1);
		num_all += 8;
		for (j = i + 1; j < sizeof(test_values)/sizeof(int); ++j)
		num_all += 2;
		for (j = 0; j < sizeof(test_values)/sizeof(int); ++j)
			{
			b = test_values[j];
			num_failed += test_binary_op(&constant_time_lt,
@@ -188,7 +249,30 @@ int main(int argc, char *argv[])
				"constant_time_eq", b, a, b == a);
			num_failed += test_binary_op_8(&constant_time_eq_8,
				"constant_time_eq_8", b, a, b == a);
			num_all += 12;
			num_failed += test_select(a, b);
			num_all += 13;
			}
		}

	for (i = 0; i < sizeof(signed_test_values)/sizeof(int); ++i)
		{
		c = signed_test_values[i];
		for (j = 0; j < sizeof(signed_test_values)/sizeof(int); ++j)
			{
			d = signed_test_values[j];
			num_failed += test_select_int(c, d);
			num_all += 1;
			}
		}

	for (i = 0; i < sizeof(test_values_8); ++i)
		{
		e = test_values_8[i];
		for (j = 0; j < sizeof(test_values_8); ++j)
			{
			f = test_values_8[j];
			num_failed += test_select_8(e, f);
			num_all += 1;
			}
		}

+3 −2
Original line number Diff line number Diff line
@@ -206,7 +206,7 @@ rsa_oaep.o: ../../include/openssl/opensslv.h ../../include/openssl/ossl_typ.h
rsa_oaep.o: ../../include/openssl/rand.h ../../include/openssl/rsa.h
rsa_oaep.o: ../../include/openssl/safestack.h ../../include/openssl/sha.h
rsa_oaep.o: ../../include/openssl/stack.h ../../include/openssl/symhacks.h
rsa_oaep.o: ../cryptlib.h rsa_oaep.c
rsa_oaep.o: ../constant_time_locl.h ../cryptlib.h rsa_oaep.c
rsa_pk1.o: ../../e_os.h ../../include/openssl/asn1.h
rsa_pk1.o: ../../include/openssl/bio.h ../../include/openssl/bn.h
rsa_pk1.o: ../../include/openssl/buffer.h ../../include/openssl/crypto.h
@@ -215,7 +215,8 @@ rsa_pk1.o: ../../include/openssl/lhash.h ../../include/openssl/opensslconf.h
rsa_pk1.o: ../../include/openssl/opensslv.h ../../include/openssl/ossl_typ.h
rsa_pk1.o: ../../include/openssl/rand.h ../../include/openssl/rsa.h
rsa_pk1.o: ../../include/openssl/safestack.h ../../include/openssl/stack.h
rsa_pk1.o: ../../include/openssl/symhacks.h ../cryptlib.h rsa_pk1.c
rsa_pk1.o: ../../include/openssl/symhacks.h ../constant_time_locl.h
rsa_pk1.o: ../cryptlib.h rsa_pk1.c
rsa_pmeth.o: ../../e_os.h ../../include/openssl/asn1.h
rsa_pmeth.o: ../../include/openssl/asn1t.h ../../include/openssl/bio.h
rsa_pmeth.o: ../../include/openssl/bn.h ../../include/openssl/buffer.h
+2 −1
Original line number Diff line number Diff line
@@ -616,6 +616,7 @@ void ERR_load_RSA_strings(void);
#define RSA_R_OAEP_DECODING_ERROR			 121
#define RSA_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE	 148
#define RSA_R_PADDING_CHECK_FAILED			 114
#define RSA_R_PKCS_DECODING_ERROR			 159
#define RSA_R_P_NOT_PRIME				 128
#define RSA_R_Q_NOT_PRIME				 129
#define RSA_R_RSA_OPERATIONS_NOT_SUPPORTED		 130
@@ -624,7 +625,7 @@ void ERR_load_RSA_strings(void);
#define RSA_R_SSLV3_ROLLBACK_ATTACK			 115
#define RSA_R_THE_ASN1_OBJECT_IDENTIFIER_IS_NOT_KNOWN_FOR_THIS_MD 116
#define RSA_R_UNKNOWN_ALGORITHM_TYPE			 117
#define RSA_R_UNKNOWN_DIGEST				 159
#define RSA_R_UNKNOWN_DIGEST				 166
#define RSA_R_UNKNOWN_MASK_DIGEST			 151
#define RSA_R_UNKNOWN_PADDING_TYPE			 118
#define RSA_R_UNKNOWN_PSS_DIGEST			 152
+1 −0
Original line number Diff line number Diff line
@@ -181,6 +181,7 @@ static ERR_STRING_DATA RSA_str_reasons[]=
{ERR_REASON(RSA_R_OAEP_DECODING_ERROR)   ,"oaep decoding error"},
{ERR_REASON(RSA_R_OPERATION_NOT_SUPPORTED_FOR_THIS_KEYTYPE),"operation not supported for this keytype"},
{ERR_REASON(RSA_R_PADDING_CHECK_FAILED)  ,"padding check failed"},
{ERR_REASON(RSA_R_PKCS_DECODING_ERROR)   ,"pkcs decoding error"},
{ERR_REASON(RSA_R_P_NOT_PRIME)           ,"p not prime"},
{ERR_REASON(RSA_R_Q_NOT_PRIME)           ,"q not prime"},
{ERR_REASON(RSA_R_RSA_OPERATIONS_NOT_SUPPORTED),"rsa operations not supported"},
Loading