Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit dd597fb3 authored by Ard Biesheuvel, committed by Herbert Xu
Browse files

crypto: arm64/aes-blk - add support for CTS-CBC mode



Currently, we rely on the generic CTS chaining mode wrapper to
instantiate the cts(cbc(aes)) skcipher. Due to the high performance
of the ARMv8 Crypto Extensions AES instructions (~1 cycle per byte),
any overhead in the chaining mode layers is amplified, and so it pays
off considerably to fold the CTS handling into the SIMD routines.

On Cortex-A53, this results in a ~50% speedup for smaller input sizes.

Signed-off-by: Ard Biesheuvel <ard.biesheuvel@linaro.org>
Signed-off-by: Herbert Xu <herbert@gondor.apana.org.au>
parent 6e7de6af
Loading
Loading
Loading
Loading
+165 −0
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@
#include <crypto/internal/hash.h>
#include <crypto/internal/simd.h>
#include <crypto/internal/skcipher.h>
#include <crypto/scatterwalk.h>
#include <linux/module.h>
#include <linux/cpufeature.h>
#include <crypto/xts.h>
@@ -31,6 +32,8 @@
#define aes_ecb_decrypt		ce_aes_ecb_decrypt
#define aes_cbc_encrypt		ce_aes_cbc_encrypt
#define aes_cbc_decrypt		ce_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	ce_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	ce_aes_cbc_cts_decrypt
#define aes_ctr_encrypt		ce_aes_ctr_encrypt
#define aes_xts_encrypt		ce_aes_xts_encrypt
#define aes_xts_decrypt		ce_aes_xts_decrypt
@@ -45,6 +48,8 @@ MODULE_DESCRIPTION("AES-ECB/CBC/CTR/XTS using ARMv8 Crypto Extensions");
#define aes_ecb_decrypt		neon_aes_ecb_decrypt
#define aes_cbc_encrypt		neon_aes_cbc_encrypt
#define aes_cbc_decrypt		neon_aes_cbc_decrypt
#define aes_cbc_cts_encrypt	neon_aes_cbc_cts_encrypt
#define aes_cbc_cts_decrypt	neon_aes_cbc_cts_decrypt
#define aes_ctr_encrypt		neon_aes_ctr_encrypt
#define aes_xts_encrypt		neon_aes_xts_encrypt
#define aes_xts_decrypt		neon_aes_xts_decrypt
@@ -73,6 +78,11 @@ asmlinkage void aes_cbc_encrypt(u8 out[], u8 const in[], u32 const rk[],
asmlinkage void aes_cbc_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 iv[]);

/*
 * CBC with ciphertext stealing, implemented in assembly.  These take the
 * input length in bytes rather than a block count, and the IV is declared
 * const.  The callers in this file hand them only the tail of the request
 * that remains after the bulk CBC pass.
 * NOTE(review): the assembly appears to expect 16 < bytes <= 32 -- confirm.
 */
asmlinkage void aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);
asmlinkage void aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int bytes, u8 const iv[]);

asmlinkage void aes_ctr_encrypt(u8 out[], u8 const in[], u32 const rk[],
				int rounds, int blocks, u8 ctr[]);

@@ -87,6 +97,12 @@ asmlinkage void aes_mac_update(u8 const in[], u32 const rk[], int rounds,
			       int blocks, u8 dg[], int enc_before,
			       int enc_after);

/*
 * Per-request scratch state for the CTS-CBC handlers.  Its size is
 * registered as the request context size by cts_cbc_init_tfm().
 */
struct cts_cbc_req_ctx {
	/* storage handed to scatterwalk_ffwd() when skipping ahead in req->src */
	struct scatterlist sg_src[2];
	/* same, for req->dst when it differs from req->src */
	struct scatterlist sg_dst[2];
	/* sub-request that the handlers re-target at each part of the request */
	struct skcipher_request subreq;
};

struct crypto_aes_xts_ctx {
	struct crypto_aes_ctx key1;
	struct crypto_aes_ctx __aligned(8) key2;
@@ -209,6 +225,136 @@ static int cbc_decrypt(struct skcipher_request *req)
	return err;
}

/*
 * Transform init hook: reserve room in every request for the
 * struct cts_cbc_req_ctx scratch area used by the CTS-CBC handlers.
 * Always succeeds.
 */
static int cts_cbc_init_tfm(struct crypto_skcipher *tfm)
{
	const unsigned int reqsize = sizeof(struct cts_cbc_req_ctx);

	crypto_skcipher_set_reqsize(tfm, reqsize);

	return 0;
}

static int cts_cbc_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&rctx->subreq, tfm);

	if (req->cryptlen == AES_BLOCK_SIZE)
		cbc_blocks = 1;

	if (cbc_blocks > 0) {
		unsigned int blocks;

		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &rctx->subreq, false);

		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
			kernel_neon_begin();
			aes_cbc_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_enc, rounds, blocks, walk.iv);
			kernel_neon_end();
			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
		}
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
					     rctx->subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
					       rctx->subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&rctx->subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_encrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_enc, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int cts_cbc_decrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
	struct crypto_aes_ctx *ctx = crypto_skcipher_ctx(tfm);
	struct cts_cbc_req_ctx *rctx = skcipher_request_ctx(req);
	int err, rounds = 6 + ctx->key_length / 4;
	int cbc_blocks = DIV_ROUND_UP(req->cryptlen, AES_BLOCK_SIZE) - 2;
	struct scatterlist *src = req->src, *dst = req->dst;
	struct skcipher_walk walk;

	skcipher_request_set_tfm(&rctx->subreq, tfm);

	if (req->cryptlen == AES_BLOCK_SIZE)
		cbc_blocks = 1;

	if (cbc_blocks > 0) {
		unsigned int blocks;

		skcipher_request_set_crypt(&rctx->subreq, req->src, req->dst,
					   cbc_blocks * AES_BLOCK_SIZE,
					   req->iv);

		err = skcipher_walk_virt(&walk, &rctx->subreq, false);

		while ((blocks = (walk.nbytes / AES_BLOCK_SIZE))) {
			kernel_neon_begin();
			aes_cbc_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
					ctx->key_dec, rounds, blocks, walk.iv);
			kernel_neon_end();
			err = skcipher_walk_done(&walk,
						 walk.nbytes % AES_BLOCK_SIZE);
		}
		if (err)
			return err;

		if (req->cryptlen == AES_BLOCK_SIZE)
			return 0;

		dst = src = scatterwalk_ffwd(rctx->sg_src, req->src,
					     rctx->subreq.cryptlen);
		if (req->dst != req->src)
			dst = scatterwalk_ffwd(rctx->sg_dst, req->dst,
					       rctx->subreq.cryptlen);
	}

	/* handle ciphertext stealing */
	skcipher_request_set_crypt(&rctx->subreq, src, dst,
				   req->cryptlen - cbc_blocks * AES_BLOCK_SIZE,
				   req->iv);

	err = skcipher_walk_virt(&walk, &rctx->subreq, false);
	if (err)
		return err;

	kernel_neon_begin();
	aes_cbc_cts_decrypt(walk.dst.virt.addr, walk.src.virt.addr,
			    ctx->key_dec, rounds, walk.nbytes, walk.iv);
	kernel_neon_end();

	return skcipher_walk_done(&walk, 0);
}

static int ctr_encrypt(struct skcipher_request *req)
{
	struct crypto_skcipher *tfm = crypto_skcipher_reqtfm(req);
@@ -334,6 +480,25 @@ static struct skcipher_alg aes_algs[] = { {
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cbc_encrypt,
	.decrypt	= cbc_decrypt,
}, {
	.base = {
		.cra_name		= "__cts(cbc(aes))",
		.cra_driver_name	= "__cts-cbc-aes-" MODE,
		.cra_priority		= PRIO,
		.cra_flags		= CRYPTO_ALG_INTERNAL,
		.cra_blocksize		= 1,
		.cra_ctxsize		= sizeof(struct crypto_aes_ctx),
		.cra_module		= THIS_MODULE,
	},
	.min_keysize	= AES_MIN_KEY_SIZE,
	.max_keysize	= AES_MAX_KEY_SIZE,
	.ivsize		= AES_BLOCK_SIZE,
	.chunksize	= AES_BLOCK_SIZE,
	.walksize	= 2 * AES_BLOCK_SIZE,
	.setkey		= skcipher_aes_setkey,
	.encrypt	= cts_cbc_encrypt,
	.decrypt	= cts_cbc_decrypt,
	.init		= cts_cbc_init_tfm,
}, {
	.base = {
		.cra_name		= "__ctr(aes)",
+78 −1
Original line number Diff line number Diff line
@@ -170,6 +170,84 @@ AES_ENTRY(aes_cbc_decrypt)
AES_ENDPROC(aes_cbc_decrypt)


	/*
	 * aes_cbc_cts_encrypt(u8 out[], u8 const in[], u32 const rk[],
	 *		       int rounds, int bytes, u8 const iv[])
	 * aes_cbc_cts_decrypt(u8 out[], u8 const in[], u32 const rk[],
	 *		       int rounds, int bytes, u8 const iv[])
	 */

AES_ENTRY(aes_cbc_cts_encrypt)
	/*
	 * Registers on entry, as used below:
	 *   x0 = out, x1 = in, x2 = rk, w3 = rounds, x4 = bytes, x5 = iv
	 * Let n = bytes - 16 be the length of the trailing partial block.
	 * NOTE(review): the overlapping 16-byte accesses presumably rely on
	 * 16 < bytes <= 32 -- confirm against the callers.
	 */
	adr_l		x8, .Lcts_permute_table
	sub		x4, x4, #16			/* x4 = n */
	add		x9, x8, #32
	add		x8, x8, x4			/* x8 = table + n */
	sub		x9, x9, x4			/* x9 = table + 32 - n */
	ld1		{v3.16b}, [x8]			/* v3: n low bytes -> top n lanes, 0xff elsewhere */
	ld1		{v4.16b}, [x9]			/* v4: top n lanes -> lanes 0..n-1, 0xff elsewhere */

	/* v0 = first 16 input bytes, v1 = last 16 input bytes (they overlap) */
	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
	ld1		{v1.16b}, [x1]

	ld1		{v5.16b}, [x5]			/* get iv */
	enc_prepare	w3, x2, x6

	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */
	/* move the n-byte tail to the front; 0xff indices make tbl write zero */
	tbl		v1.16b, {v1.16b}, v4.16b
	encrypt_block	v0, w3, x2, x6, w7

	/* chain: zero-padded tail xor ciphertext of the first block */
	eor		v1.16b, v1.16b, v0.16b
	/* shift the first n ciphertext bytes of v0 into the top n lanes */
	tbl		v0.16b, {v0.16b}, v3.16b
	encrypt_block	v1, w3, x2, x6, w7

	/*
	 * Store v0 at out + n first, then v1 at out: the second store
	 * overwrites the overlap, leaving v1 followed by n bytes of v0.
	 */
	add		x4, x0, x4
	st1		{v0.16b}, [x4]			/* overlapping stores */
	st1		{v1.16b}, [x0]
	ret
AES_ENDPROC(aes_cbc_cts_encrypt)

AES_ENTRY(aes_cbc_cts_decrypt)
	/*
	 * Registers on entry, as used below:
	 *   x0 = out, x1 = in, x2 = rk, w3 = rounds, x4 = bytes, x5 = iv
	 * Let n = bytes - 16 be the length of the trailing partial block.
	 * NOTE(review): the overlapping 16-byte accesses presumably rely on
	 * 16 < bytes <= 32 -- confirm against the callers.
	 */
	adr_l		x8, .Lcts_permute_table
	sub		x4, x4, #16			/* x4 = n */
	add		x9, x8, #32
	add		x8, x8, x4			/* x8 = table + n */
	sub		x9, x9, x4			/* x9 = table + 32 - n */
	ld1		{v3.16b}, [x8]			/* v3: n low bytes -> top n lanes, 0xff elsewhere */
	ld1		{v4.16b}, [x9]			/* v4: top n lanes -> lanes 0..n-1, 0xff elsewhere */

	/* v0 = first 16 input bytes, v1 = last 16 input bytes (they overlap) */
	ld1		{v0.16b}, [x1], x4		/* overlapping loads */
	ld1		{v1.16b}, [x1]

	ld1		{v5.16b}, [x5]			/* get iv */
	dec_prepare	w3, x2, x6

	/* v2 = n-byte ciphertext tail moved to the front, zero padded */
	tbl		v2.16b, {v1.16b}, v4.16b
	decrypt_block	v0, w3, x2, x6, w7
	/* low n bytes of v2 now hold the tail xor the decrypted first block */
	eor		v2.16b, v2.16b, v0.16b

	/*
	 * tbx leaves lanes with out-of-range (0xff) indices untouched, so
	 * this replaces only the low n bytes of v0 with the ciphertext tail
	 * and keeps the remaining decrypted bytes.
	 */
	tbx		v0.16b, {v1.16b}, v4.16b
	/* shift the low n bytes of v2 into the top n lanes for the store */
	tbl		v2.16b, {v2.16b}, v3.16b
	decrypt_block	v0, w3, x2, x6, w7
	eor		v0.16b, v0.16b, v5.16b		/* xor with iv */

	/*
	 * Store v2 at out + n first, then v0 at out: the second store
	 * overwrites the overlap, leaving v0 followed by n bytes of v2.
	 */
	add		x4, x0, x4
	st1		{v2.16b}, [x4]			/* overlapping stores */
	st1		{v0.16b}, [x0]
	ret
AES_ENDPROC(aes_cbc_cts_decrypt)

	/*
	 * Index table for the tbl/tbx permutes in the CTS routines: the
	 * identity permutation 0x0..0xf with 16 bytes of 0xff on either
	 * side.  The routines load 16-byte windows at offsets n and 32 - n
	 * (n = bytes - 16), which yields vectors that shift a block by n or
	 * 16 - n lanes; 0xff is an out-of-range index, so tbl writes zero
	 * and tbx leaves the destination lane unchanged there.
	 */
	.section	".rodata", "a"
	.align		6
.Lcts_permute_table:
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		 0x0,  0x1,  0x2,  0x3,  0x4,  0x5,  0x6,  0x7
	.byte		 0x8,  0x9,  0xa,  0xb,  0xc,  0xd,  0xe,  0xf
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.byte		0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff
	.previous


	/*
	 * aes_ctr_encrypt(u8 out[], u8 const in[], u8 const rk[], int rounds,
	 *		   int blocks, u8 ctr[])
@@ -253,7 +331,6 @@ AES_ENTRY(aes_ctr_encrypt)
	ins		v4.d[0], x7
	b		.Lctrcarrydone
AES_ENDPROC(aes_ctr_encrypt)
	.ltorg


	/*