Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 71e52c27 authored by Ard Biesheuvel's avatar Ard Biesheuvel Committed by Herbert Xu
Browse files

crypto: arm64/aes-ce-gcm - operate on two input blocks at a time



Update the core AES/GCM transform and the associated plumbing to operate
on 2 AES/GHASH blocks at a time. By itself, this is not expected to
result in a noticeable speedup, but it paves the way for reimplementing
the GHASH component using 2-way aggregation.

Signed-off-by: default avatarArd Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: default avatarHerbert Xu <herbert@gondor.apana.org.au>
parent 3465893d
Loading
Loading
Loading
Loading
+97 −30
Original line number Diff line number Diff line
@@ -286,9 +286,10 @@ ENTRY(pmull_ghash_update_p8)
	__pmull_ghash	p8
ENDPROC(pmull_ghash_update_p8)

	KS		.req	v8
	CTR		.req	v9
	INP		.req	v10
	KS0		.req	v8
	KS1		.req	v9
	INP0		.req	v10
	INP1		.req	v11

	.macro		load_round_keys, rounds, rk
	cmp		\rounds, #12
@@ -336,84 +337,146 @@ CPU_LE( rev x8, x8 )

	.if		\enc == 1
	ldr		x10, [sp]
	ld1		{KS.16b}, [x10]
	ld1		{KS0.16b-KS1.16b}, [x10]
	.endif

0:	ld1		{CTR.8b}, [x5]			// load upper counter
	ld1		{INP.16b}, [x3], #16
0:	ld1		{INP0.16b-INP1.16b}, [x3], #32

	rev		x9, x8
	add		x8, x8, #1
	sub		w0, w0, #1
	ins		CTR.d[1], x9			// set lower counter
	add		x11, x8, #1
	add		x8, x8, #2

	.if		\enc == 1
	eor		INP.16b, INP.16b, KS.16b	// encrypt input
	st1		{INP.16b}, [x2], #16
	eor		INP0.16b, INP0.16b, KS0.16b	// encrypt input
	eor		INP1.16b, INP1.16b, KS1.16b
	.endif

	rev64		T1.16b, INP.16b
	ld1		{KS0.8b}, [x5]			// load upper counter
	rev		x11, x11
	sub		w0, w0, #2
	mov		KS1.8b, KS0.8b
	ins		KS0.d[1], x9			// set lower counter
	ins		KS1.d[1], x11

	rev64		T1.16b, INP0.16b

	cmp		w7, #12
	b.ge		2f				// AES-192/256?

1:	enc_round	CTR, v21
1:	enc_round	KS0, v21

	ext		T2.16b, XL.16b, XL.16b, #8
	ext		IN1.16b, T1.16b, T1.16b, #8

	enc_round	CTR, v22
	enc_round	KS1, v21

	eor		T1.16b, T1.16b, T2.16b
	eor		XL.16b, XL.16b, IN1.16b

	enc_round	CTR, v23
	enc_round	KS0, v22

	pmull2		XH.1q, SHASH.2d, XL.2d		// a1 * b1
	eor		T1.16b, T1.16b, XL.16b

	enc_round	CTR, v24
	enc_round	KS1, v22

	pmull		XL.1q, SHASH.1d, XL.1d		// a0 * b0
	pmull		XM.1q, SHASH2.1d, T1.1d		// (a1 + a0)(b1 + b0)

	enc_round	CTR, v25
	enc_round	KS0, v23

	ext		T1.16b, XL.16b, XH.16b, #8
	eor		T2.16b, XL.16b, XH.16b
	eor		XM.16b, XM.16b, T1.16b

	enc_round	CTR, v26
	enc_round	KS1, v23

	eor		XM.16b, XM.16b, T2.16b
	pmull		T2.1q, XL.1d, MASK.1d

	enc_round	CTR, v27
	enc_round	KS0, v24

	mov		XH.d[0], XM.d[1]
	mov		XM.d[1], XL.d[0]

	enc_round	CTR, v28
	enc_round	KS1, v24

	eor		XL.16b, XM.16b, T2.16b

	enc_round	CTR, v29
	enc_round	KS0, v25

	ext		T2.16b, XL.16b, XL.16b, #8

	aese		CTR.16b, v30.16b
	enc_round	KS1, v25

	pmull		XL.1q, XL.1d, MASK.1d
	eor		T2.16b, T2.16b, XH.16b

	eor		KS.16b, CTR.16b, v31.16b
	enc_round	KS0, v26

	eor		XL.16b, XL.16b, T2.16b
	rev64		T1.16b, INP1.16b

	enc_round	KS1, v26

	ext		T2.16b, XL.16b, XL.16b, #8
	ext		IN1.16b, T1.16b, T1.16b, #8

	enc_round	KS0, v27

	eor		T1.16b, T1.16b, T2.16b
	eor		XL.16b, XL.16b, IN1.16b

	enc_round	KS1, v27

	pmull2		XH.1q, SHASH.2d, XL.2d		// a1 * b1
	eor		T1.16b, T1.16b, XL.16b

	enc_round	KS0, v28

	pmull		XL.1q, SHASH.1d, XL.1d		// a0 * b0
	pmull		XM.1q, SHASH2.1d, T1.1d		// (a1 + a0)(b1 + b0)

	enc_round	KS1, v28

	ext		T1.16b, XL.16b, XH.16b, #8
	eor		T2.16b, XL.16b, XH.16b
	eor		XM.16b, XM.16b, T1.16b

	enc_round	KS0, v29

	eor		XM.16b, XM.16b, T2.16b
	pmull		T2.1q, XL.1d, MASK.1d

	enc_round	KS1, v29

	mov		XH.d[0], XM.d[1]
	mov		XM.d[1], XL.d[0]

	aese		KS0.16b, v30.16b

	eor		XL.16b, XM.16b, T2.16b

	aese		KS1.16b, v30.16b

	ext		T2.16b, XL.16b, XL.16b, #8

	eor		KS0.16b, KS0.16b, v31.16b

	pmull		XL.1q, XL.1d, MASK.1d
	eor		T2.16b, T2.16b, XH.16b

	eor		KS1.16b, KS1.16b, v31.16b

	eor		XL.16b, XL.16b, T2.16b

	.if		\enc == 0
	eor		INP.16b, INP.16b, KS.16b
	st1		{INP.16b}, [x2], #16
	eor		INP0.16b, INP0.16b, KS0.16b
	eor		INP1.16b, INP1.16b, KS1.16b
	.endif

	st1		{INP0.16b-INP1.16b}, [x2], #32

	cbnz		w0, 0b

CPU_LE(	rev		x8, x8		)
@@ -421,16 +484,20 @@ CPU_LE( rev x8, x8 )
	str		x8, [x5, #8]			// store lower counter

	.if		\enc == 1
	st1		{KS.16b}, [x10]
	st1		{KS0.16b-KS1.16b}, [x10]
	.endif

	ret

2:	b.eq		3f				// AES-192?
	enc_round	CTR, v17
	enc_round	CTR, v18
3:	enc_round	CTR, v19
	enc_round	CTR, v20
	enc_round	KS0, v17
	enc_round	KS1, v17
	enc_round	KS0, v18
	enc_round	KS1, v18
3:	enc_round	KS0, v19
	enc_round	KS1, v19
	enc_round	KS0, v20
	enc_round	KS1, v20
	b		1b
	.endm

+64 −39
Original line number Diff line number Diff line
@@ -348,9 +348,10 @@ static int gcm_encrypt(struct aead_request *req)
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
	struct skcipher_walk walk;
	u8 iv[AES_BLOCK_SIZE];
	u8 ks[AES_BLOCK_SIZE];
	u8 ks[2 * AES_BLOCK_SIZE];
	u8 tag[AES_BLOCK_SIZE];
	u64 dg[2] = {};
	int nrounds = num_rounds(&ctx->aes_key);
	int err;

	if (req->assoclen)
@@ -362,32 +363,31 @@ static int gcm_encrypt(struct aead_request *req)
	if (likely(may_use_simd())) {
		kernel_neon_begin();

		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc,
					num_rounds(&ctx->aes_key));
		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
		put_unaligned_be32(2, iv + GCM_IV_SIZE);
		pmull_gcm_encrypt_block(ks, iv, NULL,
					num_rounds(&ctx->aes_key));
		pmull_gcm_encrypt_block(ks, iv, NULL, nrounds);
		put_unaligned_be32(3, iv + GCM_IV_SIZE);
		pmull_gcm_encrypt_block(ks + AES_BLOCK_SIZE, iv, NULL, nrounds);
		put_unaligned_be32(4, iv + GCM_IV_SIZE);
		kernel_neon_end();

		err = skcipher_walk_aead_encrypt(&walk, req, false);

		while (walk.nbytes >= AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / AES_BLOCK_SIZE;
		while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;

			kernel_neon_begin();
			pmull_gcm_encrypt(blocks, dg, walk.dst.virt.addr,
					  walk.src.virt.addr, &ctx->ghash_key,
					  iv, ctx->aes_key.key_enc,
					  num_rounds(&ctx->aes_key), ks);
					  iv, ctx->aes_key.key_enc, nrounds,
					  ks);
			kernel_neon_end();

			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
					walk.nbytes % (2 * AES_BLOCK_SIZE));
		}
	} else {
		__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv,
				    num_rounds(&ctx->aes_key));
		__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
		put_unaligned_be32(2, iv + GCM_IV_SIZE);

		err = skcipher_walk_aead_encrypt(&walk, req, false);
@@ -399,8 +399,7 @@ static int gcm_encrypt(struct aead_request *req)

			do {
				__aes_arm64_encrypt(ctx->aes_key.key_enc,
						    ks, iv,
						    num_rounds(&ctx->aes_key));
						    ks, iv, nrounds);
				crypto_xor_cpy(dst, src, ks, AES_BLOCK_SIZE);
				crypto_inc(iv, AES_BLOCK_SIZE);

@@ -417,19 +416,28 @@ static int gcm_encrypt(struct aead_request *req)
		}
		if (walk.nbytes)
			__aes_arm64_encrypt(ctx->aes_key.key_enc, ks, iv,
					    num_rounds(&ctx->aes_key));
					    nrounds);
	}

	/* handle the tail */
	if (walk.nbytes) {
		u8 buf[GHASH_BLOCK_SIZE];
		unsigned int nbytes = walk.nbytes;
		u8 *dst = walk.dst.virt.addr;
		u8 *head = NULL;

		crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, ks,
			       walk.nbytes);

		memcpy(buf, walk.dst.virt.addr, walk.nbytes);
		memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes);
		ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL);
		if (walk.nbytes > GHASH_BLOCK_SIZE) {
			head = dst;
			dst += GHASH_BLOCK_SIZE;
			nbytes %= GHASH_BLOCK_SIZE;
		}

		memcpy(buf, dst, nbytes);
		memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
		ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);

		err = skcipher_walk_done(&walk, 0);
	}
@@ -452,10 +460,11 @@ static int gcm_decrypt(struct aead_request *req)
	struct gcm_aes_ctx *ctx = crypto_aead_ctx(aead);
	unsigned int authsize = crypto_aead_authsize(aead);
	struct skcipher_walk walk;
	u8 iv[AES_BLOCK_SIZE];
	u8 iv[2 * AES_BLOCK_SIZE];
	u8 tag[AES_BLOCK_SIZE];
	u8 buf[GHASH_BLOCK_SIZE];
	u8 buf[2 * GHASH_BLOCK_SIZE];
	u64 dg[2] = {};
	int nrounds = num_rounds(&ctx->aes_key);
	int err;

	if (req->assoclen)
@@ -466,37 +475,44 @@ static int gcm_decrypt(struct aead_request *req)

	if (likely(may_use_simd())) {
		kernel_neon_begin();

		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc,
					num_rounds(&ctx->aes_key));
		pmull_gcm_encrypt_block(tag, iv, ctx->aes_key.key_enc, nrounds);
		put_unaligned_be32(2, iv + GCM_IV_SIZE);
		kernel_neon_end();

		err = skcipher_walk_aead_decrypt(&walk, req, false);

		while (walk.nbytes >= AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / AES_BLOCK_SIZE;
		while (walk.nbytes >= 2 * AES_BLOCK_SIZE) {
			int blocks = walk.nbytes / (2 * AES_BLOCK_SIZE) * 2;

			kernel_neon_begin();
			pmull_gcm_decrypt(blocks, dg, walk.dst.virt.addr,
					  walk.src.virt.addr, &ctx->ghash_key,
					  iv, ctx->aes_key.key_enc,
					  num_rounds(&ctx->aes_key));
					  iv, ctx->aes_key.key_enc, nrounds);
			kernel_neon_end();

			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
					walk.nbytes % (2 * AES_BLOCK_SIZE));
		}

		if (walk.nbytes) {
			u8 *iv2 = iv + AES_BLOCK_SIZE;

			if (walk.nbytes > AES_BLOCK_SIZE) {
				memcpy(iv2, iv, AES_BLOCK_SIZE);
				crypto_inc(iv2, AES_BLOCK_SIZE);
			}

			kernel_neon_begin();
			pmull_gcm_encrypt_block(iv, iv, ctx->aes_key.key_enc,
						num_rounds(&ctx->aes_key));
						nrounds);

			if (walk.nbytes > AES_BLOCK_SIZE)
				pmull_gcm_encrypt_block(iv2, iv2, NULL,
							nrounds);
			kernel_neon_end();
		}

	} else {
		__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv,
				    num_rounds(&ctx->aes_key));
		__aes_arm64_encrypt(ctx->aes_key.key_enc, tag, iv, nrounds);
		put_unaligned_be32(2, iv + GCM_IV_SIZE);

		err = skcipher_walk_aead_decrypt(&walk, req, false);
@@ -511,8 +527,7 @@ static int gcm_decrypt(struct aead_request *req)

			do {
				__aes_arm64_encrypt(ctx->aes_key.key_enc,
						    buf, iv,
						    num_rounds(&ctx->aes_key));
						    buf, iv, nrounds);
				crypto_xor_cpy(dst, src, buf, AES_BLOCK_SIZE);
				crypto_inc(iv, AES_BLOCK_SIZE);

@@ -525,14 +540,24 @@ static int gcm_decrypt(struct aead_request *req)
		}
		if (walk.nbytes)
			__aes_arm64_encrypt(ctx->aes_key.key_enc, iv, iv,
					    num_rounds(&ctx->aes_key));
					    nrounds);
	}

	/* handle the tail */
	if (walk.nbytes) {
		memcpy(buf, walk.src.virt.addr, walk.nbytes);
		memset(buf + walk.nbytes, 0, GHASH_BLOCK_SIZE - walk.nbytes);
		ghash_do_update(1, dg, buf, &ctx->ghash_key, NULL);
		const u8 *src = walk.src.virt.addr;
		const u8 *head = NULL;
		unsigned int nbytes = walk.nbytes;

		if (walk.nbytes > GHASH_BLOCK_SIZE) {
			head = src;
			src += GHASH_BLOCK_SIZE;
			nbytes %= GHASH_BLOCK_SIZE;
		}

		memcpy(buf, src, nbytes);
		memset(buf + nbytes, 0, GHASH_BLOCK_SIZE - nbytes);
		ghash_do_update(!!nbytes, dg, buf, &ctx->ghash_key, head);

		crypto_xor_cpy(walk.dst.virt.addr, walk.src.virt.addr, iv,
			       walk.nbytes);
@@ -557,7 +582,7 @@ static int gcm_decrypt(struct aead_request *req)

static struct aead_alg gcm_aes_alg = {
	.ivsize			= GCM_IV_SIZE,
	.chunksize		= AES_BLOCK_SIZE,
	.chunksize		= 2 * AES_BLOCK_SIZE,
	.maxauthsize		= AES_BLOCK_SIZE,
	.setkey			= gcm_setkey,
	.setauthsize		= gcm_setauthsize,