Commit 8f780044 authored by David S. Miller

Merge branch 'net-tls-Combined-memory-allocation-for-decryption-request'



Vakul Garg says:

====================
net/tls: Combined memory allocation for decryption request

This patch does a combined memory allocation from the heap for the
scatterlists, aead_request, aad and iv needed on the TLS record
decryption path. In the present code, the aead_request is allocated
from the heap, while the scatterlists are conditionally allocated
either on the heap or on the stack. This is inefficient as it may
require multiple kmalloc/kfree calls.

The initialization vector passed in the decryption request is
allocated on the stack. This is a problem since stack memory is not
DMA-able by crypto accelerators.

Doing one combined memory allocation for each decryption request fixes
both of the above issues. It also paves the way for submitting multiple
async decryption requests while the previous one is still pending, i.e.
being processed or queued.
====================

Signed-off-by: David S. Miller <davem@davemloft.net>
parents 78cbac64 0b243d00
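
For illustration, here is a minimal sketch of the single-allocation scheme that the new decrypt_internal() in the diff below carves up: one kmalloc() whose block is laid out as aead_req || sgin[] || sgout[] || aad || iv, so the IV lives in DMA-able heap memory and the whole request is released with a single kfree(). The helper tls_alloc_decrypt_mem() is hypothetical, written only for this note; it is not part of the patch or of the kernel API.

#include <linux/slab.h>
#include <linux/scatterlist.h>
#include <crypto/aead.h>
#include <net/tls.h>

/* Hypothetical helper (illustration only): do one kmalloc() laid out as
 * aead_req || sgin[] || sgout[] || aad || iv and hand back pointers into
 * that block. Putting aead_req and the scatterlists first keeps them
 * naturally aligned; the caller frees everything with kfree(aead_req).
 */
static struct aead_request *tls_alloc_decrypt_mem(struct crypto_aead *aead,
						  int n_sgin, int n_sgout,
						  struct scatterlist **sgin,
						  struct scatterlist **sgout,
						  u8 **aad, u8 **iv,
						  gfp_t flags)
{
	int aead_size = sizeof(struct aead_request) + crypto_aead_reqsize(aead);
	int mem_size = aead_size +
		       (n_sgin + n_sgout) * sizeof(struct scatterlist) +
		       TLS_AAD_SPACE_SIZE +
		       crypto_aead_ivsize(aead);
	u8 *mem = kmalloc(mem_size, flags);

	if (!mem)
		return NULL;

	/* Segment the allocated block in the same order it was sized. */
	*sgin = (struct scatterlist *)(mem + aead_size);
	*sgout = *sgin + n_sgin;
	*aad = (u8 *)(*sgout + n_sgout);
	*iv = *aad + TLS_AAD_SPACE_SIZE;

	return (struct aead_request *)mem;
}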
include/net/tls.h
+0 −4
@@ -124,10 +124,6 @@ struct tls_sw_context_rx {
	struct sk_buff *recv_pkt;
	u8 control;
	bool decrypted;

	char rx_aad_ciphertext[TLS_AAD_SPACE_SIZE];
	char rx_aad_plaintext[TLS_AAD_SPACE_SIZE];

};

struct tls_record_info {
net/tls/tls_sw.c
+142 −96
@@ -48,19 +48,13 @@ static int tls_do_decryption(struct sock *sk,
			     struct scatterlist *sgout,
			     char *iv_recv,
			     size_t data_len,
			     struct sk_buff *skb,
			     gfp_t flags)
			     struct aead_request *aead_req)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct aead_request *aead_req;

	int ret;

	aead_req = aead_request_alloc(ctx->aead_recv, flags);
	if (!aead_req)
		return -ENOMEM;

	aead_request_set_tfm(aead_req, ctx->aead_recv);
	aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
	aead_request_set_crypt(aead_req, sgin, sgout,
			       data_len + tls_ctx->rx.tag_size,
@@ -69,8 +63,6 @@ static int tls_do_decryption(struct sock *sk,
				  crypto_req_done, &ctx->async_wait);

	ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);

	aead_request_free(aead_req);
	return ret;
}

@@ -657,8 +649,132 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
	return skb;
}

/* This function decrypts the input skb into either out_iov or in out_sg
 * or in skb buffers itself. The input parameter 'zc' indicates if
 * zero-copy mode needs to be tried or not. With zero-copy mode, either
 * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
 * NULL, then the decryption happens inside skb buffers itself, i.e.
 * zero-copy gets disabled and 'zc' is updated.
 */
static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
			    struct iov_iter *out_iov,
			    struct scatterlist *out_sg,
			    int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	struct strp_msg *rxm = strp_msg(skb);
	int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
	struct aead_request *aead_req;
	struct sk_buff *unused;
	u8 *aad, *iv, *mem = NULL;
	struct scatterlist *sgin = NULL;
	struct scatterlist *sgout = NULL;
	const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;

	if (*zc && (out_iov || out_sg)) {
		if (out_iov)
			n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
		else
			n_sgout = sg_nents(out_sg);
	} else {
		n_sgout = 0;
		*zc = false;
	}

	n_sgin = skb_cow_data(skb, 0, &unused);
	if (n_sgin < 1)
		return -EBADMSG;

	/* Increment to accommodate AAD */
	n_sgin = n_sgin + 1;

	nsg = n_sgin + n_sgout;

	aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
	mem_size = aead_size + (nsg * sizeof(struct scatterlist));
	mem_size = mem_size + TLS_AAD_SPACE_SIZE;
	mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);

	/* Allocate a single block of memory which contains
	 * aead_req || sgin[] || sgout[] || aad || iv.
	 * This order achieves correct alignment for aead_req, sgin, sgout.
	 */
	mem = kmalloc(mem_size, sk->sk_allocation);
	if (!mem)
		return -ENOMEM;

	/* Segment the allocated memory */
	aead_req = (struct aead_request *)mem;
	sgin = (struct scatterlist *)(mem + aead_size);
	sgout = sgin + n_sgin;
	aad = (u8 *)(sgout + n_sgout);
	iv = aad + TLS_AAD_SPACE_SIZE;

	/* Prepare IV */
	err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			    tls_ctx->rx.iv_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}
	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);

	/* Prepare AAD */
	tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
		     tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
		     ctx->control);

	/* Prepare sgin */
	sg_init_table(sgin, n_sgin);
	sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
	err = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + tls_ctx->rx.prepend_size,
			   rxm->full_len - tls_ctx->rx.prepend_size);
	if (err < 0) {
		kfree(mem);
		return err;
	}

	if (n_sgout) {
		if (out_iov) {
			sg_init_table(sgout, n_sgout);
			sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);

			*chunk = 0;
			err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
						 chunk, &sgout[1],
						 (n_sgout - 1), false);
			if (err < 0)
				goto fallback_to_reg_recv;
		} else if (out_sg) {
			memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
		} else {
			goto fallback_to_reg_recv;
		}
	} else {
fallback_to_reg_recv:
		sgout = sgin;
		pages = 0;
		*chunk = 0;
		*zc = false;
	}

	/* Prepare and submit AEAD request */
	err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);

	/* Release the pages in case iov was mapped to pages */
	for (; pages > 0; pages--)
		put_page(sg_page(&sgout[pages]));

	kfree(mem);
	return err;
}

static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
			      struct scatterlist *sgout, bool *zc)
			      struct iov_iter *dest, int *chunk, bool *zc)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -671,7 +787,7 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
		return err;
#endif
	if (!ctx->decrypted) {
		err = decrypt_skb(sk, skb, sgout);
		err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
		if (err < 0)
			return err;
	} else {
@@ -690,54 +806,10 @@ static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
int decrypt_skb(struct sock *sk, struct sk_buff *skb,
		struct scatterlist *sgout)
{
	struct tls_context *tls_ctx = tls_get_ctx(sk);
	struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
	char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE];
	struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
	struct scatterlist *sgin = &sgin_arr[0];
	struct strp_msg *rxm = strp_msg(skb);
	int ret, nsg = ARRAY_SIZE(sgin_arr);
	struct sk_buff *unused;

	ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
			    iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
			    tls_ctx->rx.iv_size);
	if (ret < 0)
		return ret;

	memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
	if (!sgout) {
		nsg = skb_cow_data(skb, 0, &unused) + 1;
		sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
		sgout = sgin;
	}

	sg_init_table(sgin, nsg);
	sg_set_buf(&sgin[0], ctx->rx_aad_ciphertext, TLS_AAD_SPACE_SIZE);

	nsg = skb_to_sgvec(skb, &sgin[1],
			   rxm->offset + tls_ctx->rx.prepend_size,
			   rxm->full_len - tls_ctx->rx.prepend_size);
	if (nsg < 0) {
		ret = nsg;
		goto out;
	}

	tls_make_aad(ctx->rx_aad_ciphertext,
		     rxm->full_len - tls_ctx->rx.overhead_size,
		     tls_ctx->rx.rec_seq,
		     tls_ctx->rx.rec_seq_size,
		     ctx->control);

	ret = tls_do_decryption(sk, sgin, sgout, iv,
				rxm->full_len - tls_ctx->rx.overhead_size,
				skb, sk->sk_allocation);

out:
	if (sgin != &sgin_arr[0])
		kfree(sgin);
	bool zc = true;
	int chunk;

	return ret;
	return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
}

static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
@@ -816,44 +888,18 @@ int tls_sw_recvmsg(struct sock *sk,
		}

		if (!ctx->decrypted) {
			int page_count;
			int to_copy;

			page_count = iov_iter_npages(&msg->msg_iter,
						     MAX_SKB_FRAGS);
			to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
			if (!is_kvec && to_copy <= len && page_count < MAX_SKB_FRAGS &&
			    likely(!(flags & MSG_PEEK)))  {
				struct scatterlist sgin[MAX_SKB_FRAGS + 1];
				int pages = 0;
			int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;

			if (!is_kvec && to_copy <= len &&
			    likely(!(flags & MSG_PEEK)))
				zc = true;
				sg_init_table(sgin, MAX_SKB_FRAGS + 1);
				sg_set_buf(&sgin[0], ctx->rx_aad_plaintext,
					   TLS_AAD_SPACE_SIZE);

				err = zerocopy_from_iter(sk, &msg->msg_iter,
							 to_copy, &pages,
							 &chunk, &sgin[1],
							 MAX_SKB_FRAGS,	false);
				if (err < 0)
					goto fallback_to_reg_recv;

				err = decrypt_skb_update(sk, skb, sgin, &zc);
				for (; pages > 0; pages--)
					put_page(sg_page(&sgin[pages]));
				if (err < 0) {
					tls_err_abort(sk, EBADMSG);
					goto recv_end;
				}
			} else {
fallback_to_reg_recv:
				err = decrypt_skb_update(sk, skb, NULL, &zc);
			err = decrypt_skb_update(sk, skb, &msg->msg_iter,
						 &chunk, &zc);
			if (err < 0) {
				tls_err_abort(sk, EBADMSG);
				goto recv_end;
			}
			}
			ctx->decrypted = true;
		}

@@ -903,7 +949,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	int err = 0;
	long timeo;
	int chunk;
	bool zc;
	bool zc = false;

	lock_sock(sk);

@@ -920,7 +966,7 @@ ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
	}

	if (!ctx->decrypted) {
		err = decrypt_skb_update(sk, skb, NULL, &zc);
		err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);

		if (err < 0) {
			tls_err_abort(sk, EBADMSG);