Commit 8329e2e7 authored by Andy Polyakov's avatar Andy Polyakov
Browse files

bn_exp.c: further optimizations using more ideas from

parent 3f66f204
Loading
Loading
Loading
Loading
+76 −7
Original line number Diff line number Diff line
@@ -607,7 +607,6 @@ $code.=<<___;
	add	$A[1],$N[1]		# np[j]*m1+ap[j]*bp[i]+tp[j]
	lea	4($j),$j		# j+=2
	adc	\$0,%rdx
	mov	$N[1],(%rsp)		# tp[j-1]
	mov	%rdx,$N[0]
	jmp	.Linner4x
.align	16
@@ -626,7 +625,7 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[0],$N[0]
	adc	\$0,%rdx
	mov	$N[0],-24(%rsp,$j,8)	# tp[j-1]
	mov	$N[1],-32(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[1]

	mulq	$m0			# ap[j]*bp[i]
@@ -643,7 +642,7 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[1],$N[1]
	adc	\$0,%rdx
	mov	$N[1],-16(%rsp,$j,8)	# tp[j-1]
	mov	$N[0],-24(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[0]

	mulq	$m0			# ap[j]*bp[i]
@@ -660,7 +659,7 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[0],$N[0]
	adc	\$0,%rdx
	mov	$N[0],-8(%rsp,$j,8)	# tp[j-1]
	mov	$N[1],-16(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[1]

	mulq	$m0			# ap[j]*bp[i]
@@ -678,7 +677,7 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[1],$N[1]
	adc	\$0,%rdx
	mov	$N[1],-32(%rsp,$j,8)	# tp[j-1]
	mov	$N[0],-40(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[0]
	cmp	$num,$j
	jl	.Linner4x
@@ -697,7 +696,7 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[0],$N[0]
	adc	\$0,%rdx
	mov	$N[0],-24(%rsp,$j,8)	# tp[j-1]
	mov	$N[1],-32(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[1]

	mulq	$m0			# ap[j]*bp[i]
@@ -715,10 +714,11 @@ $code.=<<___;
	adc	\$0,%rdx
	add	$A[1],$N[1]
	adc	\$0,%rdx
	mov	$N[1],-16(%rsp,$j,8)	# tp[j-1]
	mov	$N[0],-24(%rsp,$j,8)	# tp[j-1]
	mov	%rdx,$N[0]

	movq	%xmm0,$m0		# bp[i+1]
	mov	$N[1],-16(%rsp,$j,8)	# tp[j-1]

	xor	$N[1],$N[1]
	add	$A[0],$N[0]
@@ -831,6 +831,10 @@ ___
{
my ($inp,$num,$tbl,$idx)=$win64?("%rcx","%rdx","%r8", "%r9") : # Win64 order
				("%rdi","%rsi","%rdx","%rcx"); # Unix order
my $out=$inp;
my $STRIDE=2**5*8;
my $N=$STRIDE/4;

$code.=<<___;
.globl	bn_scatter5
.type	bn_scatter5,\@abi-omnipotent
@@ -849,6 +853,61 @@ bn_scatter5:
.Lscatter_epilogue:
	ret
.size	bn_scatter5,.-bn_scatter5

.globl	bn_gather5
.type	bn_gather5,\@abi-omnipotent
.align	16
bn_gather5:
___
$code.=<<___ if ($win64);
.LSEH_begin_bn_gather5:
	# I can't trust assembler to use specific encoding:-(
	.byte	0x48,0x83,0xec,0x28		#sub	\$0x28,%rsp
	.byte	0x0f,0x29,0x34,0x24		#movaps	%xmm6,(%rsp)
	.byte	0x0f,0x29,0x7c,0x24,0x10	#movdqa	%xmm7,0x10(%rsp)
___
$code.=<<___;
	mov	$idx,%r11
	shr	\$`log($N/8)/log(2)`,$idx
	and	\$`$N/8-1`,%r11
	not	$idx
	lea	.Lmagic_masks(%rip),%rax
	and	\$`2**5/($N/8)-1`,$idx	# 5 is "window size"
	lea	96($tbl,%r11,8),$tbl	# pointer within 1st cache line
	movq	0(%rax,$idx,8),%xmm4	# set of masks denoting which
	movq	8(%rax,$idx,8),%xmm5	# cache line contains element
	movq	16(%rax,$idx,8),%xmm6	# denoted by 7th argument
	movq	24(%rax,$idx,8),%xmm7
	jmp	.Lgather
.align	16
.Lgather:
	movq	`0*$STRIDE/4-96`($tbl),%xmm0
	movq	`1*$STRIDE/4-96`($tbl),%xmm1
	pand	%xmm4,%xmm0
	movq	`2*$STRIDE/4-96`($tbl),%xmm2
	pand	%xmm5,%xmm1
	movq	`3*$STRIDE/4-96`($tbl),%xmm3
	pand	%xmm6,%xmm2
	por	%xmm1,%xmm0
	pand	%xmm7,%xmm3
	por	%xmm2,%xmm0
	lea	$STRIDE($tbl),$tbl
	por	%xmm3,%xmm0

	movq	%xmm0,($out)		# m0=bp[0]
	lea	8($out),$out
	sub	\$1,$num
	jnz	.Lgather
___
$code.=<<___ if ($win64);
	movaps	%xmm6,(%rsp)
	movaps	%xmm7,0x10(%rsp)
	lea	0x28(%rsp),%rsp
___
$code.=<<___;
	ret
.LSEH_end_bn_gather5:
.size	bn_gather5,.-bn_gather5
___
}
$code.=<<___;
@@ -980,6 +1039,10 @@ mul_handler:
	.rva	.LSEH_end_bn_mul4x_mont_gather5
	.rva	.LSEH_info_bn_mul4x_mont_gather5

	.rva	.LSEH_begin_bn_gather5
	.rva	.LSEH_end_bn_gather5
	.rva	.LSEH_info_bn_gather5

.section	.xdata
.align	8
.LSEH_info_bn_mul_mont_gather5:
@@ -992,6 +1055,12 @@ mul_handler:
	.rva	mul_handler
	.rva	.Lmul4x_alloca,.Lmul4x_body,.Lmul4x_epilogue	# HandlerData[]
.align	8
.LSEH_info_bn_gather5:
        .byte   0x01,0x0d,0x05,0x00
        .byte   0x0d,0x78,0x01,0x00	#movaps	0x10(rsp),xmm7
        .byte   0x08,0x68,0x00,0x00	#movaps	(rsp),xmm6
        .byte   0x04,0x42,0x00,0x00	#sub	rsp,0x28
.align	8
___
}

+81 −107
Original line number Diff line number Diff line
@@ -535,23 +535,17 @@ err:
 * as cache lines are concerned.  The following functions are used to transfer a BIGNUM
 * from/to that table. */

static int MOD_EXP_CTIME_COPY_TO_PREBUF(BIGNUM *b, int top, unsigned char *buf, int idx, int width)
static int MOD_EXP_CTIME_COPY_TO_PREBUF(const BIGNUM *b, int top, unsigned char *buf, int idx, int width)
	{
	size_t i, j;

	if (bn_wexpand(b, top) == NULL)
		return 0;
	while (b->top < top)
		{
		b->d[b->top++] = 0;
		}
	
	if (top > b->top)
		top = b->top; /* this works because 'buf' is explicitly zeroed */
	for (i = 0, j=idx; i < top * sizeof b->d[0]; i++, j+=width)
		{
		buf[j] = ((unsigned char*)b->d)[i];
		}

	bn_correct_top(b);
	return 1;
	}

@@ -587,14 +581,13 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
	{
	int i,bits,ret=0,window,wvalue;
	int top;
 	BIGNUM *r;
	BN_MONT_CTX *mont=NULL;

	int numPowers;
	unsigned char *powerbufFree=NULL;
	int powerbufLen = 0;
	unsigned char *powerbuf=NULL;
	BIGNUM computeTemp, *am=NULL;
	BIGNUM tmp, am;

	bn_check_top(a);
	bn_check_top(p);
@@ -614,10 +607,7 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
		return ret;
		}

 	/* Initialize BIGNUM context and allocate intermediate result */
	BN_CTX_start(ctx);
	r = BN_CTX_get(ctx);
	if (r == NULL) goto err;

	/* Allocate a montgomery context if it was not supplied by the caller.
	 * If this is not done, things will break in the montgomery part.
@@ -635,25 +625,13 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
#if defined(OPENSSL_BN_ASM_MONT5)
	if (window==6 && bits<=1024) window=5;	/* ~5% improvement of 2048-bit RSA sign */
#endif
 	/* Adjust the number of bits up to a multiple of the window size.
 	 * If the exponent length is not a multiple of the window size, then
 	 * this pads the most significant bits with zeros to normalize the
 	 * scanning loop to there's no special cases.
 	 *
 	 * * NOTE: Making the window size a power of two less than the native
	 * * word size ensures that the padded bits won't go past the last
 	 * * word in the internal BIGNUM structure. Going past the end will
 	 * * still produce the correct result, but causes a different branch
 	 * * to be taken in the BN_is_bit_set function.
 	 */
 	bits = ((bits+window-1)/window)*window;

	/* Allocate a buffer large enough to hold all of the pre-computed
	 * powers of a, plus computeTemp.
	 * powers of am, am itself and tmp.
	 */
	numPowers = 1 << window;
	powerbufLen = sizeof(m->d[0])*(top*numPowers +
				(top>numPowers?top:numPowers));
				((2*top)>numPowers?(2*top):numPowers));
#ifdef alloca
	if (powerbufLen < 3072)
		powerbufFree = alloca(powerbufLen+MOD_EXP_CTIME_MIN_CACHE_LINE_WIDTH);
@@ -670,28 +648,31 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
		powerbufFree = NULL;
#endif

	computeTemp.d = (BN_ULONG *)(powerbuf + sizeof(m->d[0])*top*numPowers);
	computeTemp.top = computeTemp.dmax = top;
	computeTemp.neg = 0;
	computeTemp.flags = BN_FLG_STATIC_DATA;

 	/* Initialize the intermediate result. Do this early to save double conversion,
	 * once each for a^0 and intermediate result.
	 */
 	if (!BN_to_montgomery(r,BN_value_one(),mont,ctx)) goto err;

	/* Initialize computeTemp as a^1 with montgomery precalcs */
	am = BN_CTX_get(ctx);
	if (am==NULL) goto err;
	/* lay down tmp and am right after powers table */
	tmp.d     = (BN_ULONG *)(powerbuf + sizeof(m->d[0])*top*numPowers);
	am.d      = tmp.d + top;
	tmp.top   = am.top  = 0;
	tmp.dmax  = am.dmax = top;
	tmp.neg   = am.neg  = 0;
	tmp.flags = am.flags = BN_FLG_STATIC_DATA;

	/* prepare a^0 in Montgomery domain */
#if 1
 	if (!BN_to_montgomery(&tmp,BN_value_one(),mont,ctx))	goto err;
#else
	tmp.d[0] = (0-m->d[0])&BN_MASK2;	/* 2^(top*BN_BITS2) - m */
	for (i=1;i<top;i++)
		tmp.d[i] = (~m->d[i])&BN_MASK2;
	tmp.top = top;
#endif

	/* prepare a^1 in Montgomery domain */
	if (a->neg || BN_ucmp(a,m) >= 0)
		{
		if (!BN_mod(am,a,m,ctx))		goto err;
		if (!BN_to_montgomery(am,am,mont,ctx))	goto err;
		if (!BN_mod(&am,a,m,ctx))			goto err;
		if (!BN_to_montgomery(&am,&am,mont,ctx))	goto err;
		}
	else	if (!BN_to_montgomery(am,a,mont,ctx))	goto err;

	if (!BN_copy(&computeTemp, am)) goto err;
	else	if (!BN_to_montgomery(&am,a,mont,ctx))		goto err;

#if defined(OPENSSL_BN_ASM_MONT5)
    /* This optimization uses ideas from http://eprint.iacr.org/2011/239,
@@ -707,95 +688,83 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
			const BN_ULONG *n0,int num,int power);
	void bn_scatter5(const BN_ULONG *inp,size_t num,
			void *table,size_t power);
	void bn_gather5(BN_ULONG *out,size_t num,
			void *table,size_t power);

	BN_ULONG *acc, *np=mont->N.d, *n0=mont->n0;
	BN_ULONG *np=mont->N.d, *n0=mont->n0;

	bn_scatter5(r->d,r->top,powerbuf,0);
	bn_scatter5(am->d,am->top,powerbuf,1);
	bn_scatter5(tmp.d,top,powerbuf,0);
	bn_scatter5(am.d,am.top,powerbuf,1);
	bn_mul_mont(tmp.d,am.d,am.d,np,n0,top);
	bn_scatter5(tmp.d,top,powerbuf,2);

	acc = computeTemp.d;
	/* bn_mul_mont() and bn_mul_mont_gather5() assume fixed length inputs.
	 * Pad the inputs with zeroes.
	 */
	if (bn_wexpand(am,top)==NULL ||	bn_wexpand(r,top)==NULL ||
	    bn_wexpand(&computeTemp,top)==NULL)
		goto err;
	for (i = am->top; i < top; ++i)
		{
		am->d[i] = 0;
		}
	for (i = computeTemp.top; i < top; ++i)
		{
		computeTemp.d[i] = 0;
		}
	for (i = r->top; i < top; ++i)
		{
		r->d[i] = 0;
		}
#if 0
	for (i=2; i<32; i++)
	for (i=3; i<32; i++)
		{
		bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
		bn_scatter5(acc,top,powerbuf,i);
		/* Calculate a^i = a^(i-1) * a */
		bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
		bn_scatter5(tmp.d,top,powerbuf,i);
		}
#else
	/* same as above, but uses squaring for 1/2 of operations */
	for (i=2; i<32; i*=2)
	for (i=4; i<32; i*=2)
		{
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_scatter5(acc,top,powerbuf,i);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_scatter5(tmp.d,top,powerbuf,i);
		}
	for (i=3; i<8; i+=2)
		{
		int j;
		bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
		bn_scatter5(acc,top,powerbuf,i);
		bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
		bn_scatter5(tmp.d,top,powerbuf,i);
		for (j=2*i; j<32; j*=2)
			{
			bn_mul_mont(acc,acc,acc,np,n0,top);
			bn_scatter5(acc,top,powerbuf,j);
			bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
			bn_scatter5(tmp.d,top,powerbuf,j);
			}
		}
	for (; i<16; i+=2)
		{
		bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
		bn_scatter5(acc,top,powerbuf,i);
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_scatter5(acc,top,powerbuf,2*i);
		bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
		bn_scatter5(tmp.d,top,powerbuf,i);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_scatter5(tmp.d,top,powerbuf,2*i);
		}
	for (; i<32; i+=2)
		{
		bn_mul_mont_gather5(acc,am->d,powerbuf,np,n0,top,i-1);
		bn_scatter5(acc,top,powerbuf,i);
		bn_mul_mont_gather5(tmp.d,am.d,powerbuf,np,n0,top,i-1);
		bn_scatter5(tmp.d,top,powerbuf,i);
		}
#endif
	acc = r->d;
	bits--;
	for (wvalue=0, i=bits%5; i>=0; i--,bits--)
		wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
	bn_gather5(tmp.d,top,powerbuf,wvalue);

	/* Scan the exponent one window at a time starting from the most
	 * significant bits.
	 */
	bits--;
	while (bits >= 0)
		{
		for (wvalue=0, i=0; i<5; i++,bits--)
			wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);

		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_mul_mont(acc,acc,acc,np,n0,top);
		bn_mul_mont_gather5(acc,acc,powerbuf,np,n0,top,wvalue);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_mul_mont(tmp.d,tmp.d,tmp.d,np,n0,top);
		bn_mul_mont_gather5(tmp.d,tmp.d,powerbuf,np,n0,top,wvalue);
		}

	r->top=top;
	bn_correct_top(r);
	tmp.top=top;
	bn_correct_top(&tmp);
	}
    else
#endif
	{
	if (!MOD_EXP_CTIME_COPY_TO_PREBUF(r, top, powerbuf, 0, numPowers)) goto err;
	if (!MOD_EXP_CTIME_COPY_TO_PREBUF(am, top, powerbuf, 1, numPowers)) goto err;
	if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, 0, numPowers)) goto err;
	if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&am,  top, powerbuf, 1, numPowers)) goto err;

	/* If the window size is greater than 1, then calculate
	 * val[i=2..2^winsize-1]. Powers are computed as a*a^(i-1)
@@ -804,19 +773,25 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
	 */
	if (window > 1)
		{
		for (i=2; i<numPowers; i++)
		if (!BN_mod_mul_montgomery(&tmp,&am,&am,mont,ctx))	goto err;
		if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, 2, numPowers)) goto err;
		for (i=3; i<numPowers; i++)
			{
			/* Calculate a^i = a^(i-1) * a */
			if (!BN_mod_mul_montgomery(&computeTemp,am,&computeTemp,mont,ctx))
			if (!BN_mod_mul_montgomery(&tmp,&am,&tmp,mont,ctx))
				goto err;
			if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&computeTemp, top, powerbuf, i, numPowers)) goto err;
			if (!MOD_EXP_CTIME_COPY_TO_PREBUF(&tmp, top, powerbuf, i, numPowers)) goto err;
			}
		}

	bits--;
	for (wvalue=0, i=bits%window; i>=0; i--,bits--)
		wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
	if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&tmp,top,powerbuf,wvalue,numPowers)) goto err;
 
	/* Scan the exponent one window at a time starting from the most
	 * significant bits.
	 */
	bits--;
 	while (bits >= 0)
  		{
 		wvalue=0; /* The 'value' of the window */
@@ -824,20 +799,20 @@ int BN_mod_exp_mont_consttime(BIGNUM *rr, const BIGNUM *a, const BIGNUM *p,
 		/* Scan the window, squaring the result as we go */
 		for (i=0; i<window; i++,bits--)
 			{
			if (!BN_mod_mul_montgomery(r,r,r,mont,ctx))	goto err;
			if (!BN_mod_mul_montgomery(&tmp,&tmp,&tmp,mont,ctx))	goto err;
			wvalue = (wvalue<<1)+BN_is_bit_set(p,bits);
  			}
 		
		/* Fetch the appropriate pre-computed value from the pre-buf */
		if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&computeTemp, top, powerbuf, wvalue, numPowers)) goto err;
		if (!MOD_EXP_CTIME_COPY_FROM_PREBUF(&am, top, powerbuf, wvalue, numPowers)) goto err;

 		/* Multiply the result into the intermediate result */
 		if (!BN_mod_mul_montgomery(r,r,&computeTemp,mont,ctx)) goto err;
 		if (!BN_mod_mul_montgomery(&tmp,&tmp,&am,mont,ctx)) goto err;
  		}
	}

 	/* Convert the final result from montgomery to standard format */
	if (!BN_from_montgomery(rr,r,mont,ctx)) goto err;
	if (!BN_from_montgomery(rr,&tmp,mont,ctx)) goto err;
	ret=1;
err:
	if ((in_mont == NULL) && (mont != NULL)) BN_MONT_CTX_free(mont);
@@ -846,7 +821,6 @@ err:
		OPENSSL_cleanse(powerbuf,powerbufLen);
		if (powerbufFree) OPENSSL_free(powerbufFree);
		}
 	if (am!=NULL) BN_clear(am);
	BN_CTX_end(ctx);
	return(ret);
	}