Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 3143edd3 authored by Sage Weil
Browse files

ceph: clean up statfs



Avoid unnecessary msgpool.  Preallocate reply.  Fix use-after-free race.

Signed-off-by: Sage Weil <sage@newdream.net>
parent 6f46cb29
Loading
Loading
Loading
Loading
+94 −55
Original line number Diff line number Diff line
@@ -393,16 +393,64 @@ static void __insert_statfs(struct ceph_mon_client *monc,
	rb_insert_color(&new->node, &monc->statfs_request_tree);
}

/*
 * Final kref release for a statfs request: drop the references held on
 * the reply and request messages, then free the request itself.
 *
 * ceph_monc_do_statfs() allocates the request with kmalloc() and only
 * ever gives it up through kref_put(), so this callback is the one place
 * the structure can be freed.  That function can also get here with
 * req->request or req->reply still holding the ERR_PTR() returned by a
 * failed ceph_msg_new(), so error pointers are skipped as well as NULL.
 */
static void release_statfs_request(struct kref *kref)
{
	struct ceph_mon_statfs_request *req =
		container_of(kref, struct ceph_mon_statfs_request, kref);

	if (req->reply && !IS_ERR(req->reply))
		ceph_msg_put(req->reply);
	if (req->request && !IS_ERR(req->request))
		ceph_msg_put(req->request);
	kfree(req);
}

/*
 * Drop one reference on a statfs request.  kref_put() calls
 * release_statfs_request() when the last reference goes away.
 */
static void put_statfs_request(struct ceph_mon_statfs_request *req)
{
	kref_put(&req->kref, release_statfs_request);
}

/*
 * Take an additional reference on a statfs request; pair with
 * put_statfs_request().
 */
static void get_statfs_request(struct ceph_mon_statfs_request *req)
{
	kref_get(&req->kref);
}

/*
 * Choose the message that an incoming statfs reply is read into.
 *
 * The pending request is looked up, under monc->mutex, by the tid carried
 * in @hdr.  If a request is registered, its preallocated reply message is
 * returned with an extra reference taken by ceph_msg_get().  If there is
 * none, *skip is set to 1 and NULL is returned (presumably telling the
 * messenger to discard the message -- confirm against the caller).
 */
static struct ceph_msg *get_statfs_reply(struct ceph_connection *con,
					 struct ceph_msg_header *hdr,
					 int *skip)
{
	struct ceph_mon_client *monc = con->private;
	struct ceph_mon_statfs_request *req;
	u64 tid = le64_to_cpu(hdr->tid);
	struct ceph_msg *m = NULL;

	mutex_lock(&monc->mutex);
	req = __lookup_statfs(monc, tid);
	if (req) {
		/* tid is unsigned; print it the way handle_statfs_reply does */
		dout("get_statfs_reply %llu got %p\n", tid, req->reply);
		m = ceph_msg_get(req->reply);
		/*
		 * we don't need to track the connection reading into
		 * this reply because we only have one open connection
		 * at a time, ever.
		 */
	} else {
		dout("get_statfs_reply %llu dne\n", tid);
		*skip = 1;
	}
	mutex_unlock(&monc->mutex);
	return m;
}

static void handle_statfs_reply(struct ceph_mon_client *monc,
				struct ceph_msg *msg)
{
	struct ceph_mon_statfs_request *req;
	struct ceph_mon_statfs_reply *reply = msg->front.iov_base;
	u64 tid;
	u64 tid = le64_to_cpu(msg->hdr.tid);

	if (msg->front.iov_len != sizeof(*reply))
		goto bad;
	tid = le64_to_cpu(msg->hdr.tid);
	dout("handle_statfs_reply %p tid %llu\n", msg, tid);

	mutex_lock(&monc->mutex);
@@ -410,10 +458,13 @@ static void handle_statfs_reply(struct ceph_mon_client *monc,
	if (req) {
		*req->buf = reply->st;
		req->result = 0;
		get_statfs_request(req);
	}
	mutex_unlock(&monc->mutex);
	if (req)
	if (req) {
		complete(&req->completion);
		put_statfs_request(req);
	}
	return;

bad:
@@ -422,67 +473,63 @@ static void handle_statfs_reply(struct ceph_mon_client *monc,
}

/*
 * (re)send a statfs request
 * Do a synchronous statfs().
 */
static int send_statfs(struct ceph_mon_client *monc,
		       struct ceph_mon_statfs_request *req)
int ceph_monc_do_statfs(struct ceph_mon_client *monc, struct ceph_statfs *buf)
{
	struct ceph_msg *msg;
	struct ceph_mon_statfs_request *req;
	struct ceph_mon_statfs *h;
	int err;

	req = kmalloc(sizeof(*req), GFP_NOFS);
	if (!req)
		return -ENOMEM;

	memset(req, 0, sizeof(*req));
	kref_init(&req->kref);
	req->buf = buf;
	init_completion(&req->completion);

	dout("send_statfs tid %llu\n", req->tid);
	msg = ceph_msg_new(CEPH_MSG_STATFS, sizeof(*h), 0, 0, NULL);
	if (IS_ERR(msg))
		return PTR_ERR(msg);
	req->request = msg;
	msg->hdr.tid = cpu_to_le64(req->tid);
	h = msg->front.iov_base;
	req->request = ceph_msg_new(CEPH_MSG_STATFS, sizeof(*h), 0, 0, NULL);
	if (IS_ERR(req->request)) {
		err = PTR_ERR(req->request);
		goto out;
	}
	req->reply = ceph_msg_new(CEPH_MSG_STATFS_REPLY, 1024, 0, 0, NULL);
	if (IS_ERR(req->reply)) {
		err = PTR_ERR(req->reply);
		goto out;
	}

	/* fill out request */
	h = req->request->front.iov_base;
	h->monhdr.have_version = 0;
	h->monhdr.session_mon = cpu_to_le16(-1);
	h->monhdr.session_mon_tid = 0;
	h->fsid = monc->monmap->fsid;
	ceph_con_send(monc->con, msg);
	return 0;
}

/*
 * Do a synchronous statfs().
 */
int ceph_monc_do_statfs(struct ceph_mon_client *monc, struct ceph_statfs *buf)
{
	struct ceph_mon_statfs_request req;
	int err;

	req.buf = buf;
	init_completion(&req.completion);

	/* allocate memory for reply */
	err = ceph_msgpool_resv(&monc->msgpool_statfs_reply, 1);
	if (err)
		return err;

	/* register request */
	mutex_lock(&monc->mutex);
	req.tid = ++monc->last_tid;
	req.last_attempt = jiffies;
	req.delay = BASE_DELAY_INTERVAL;
	__insert_statfs(monc, &req);
	req->tid = ++monc->last_tid;
	req->request->hdr.tid = cpu_to_le64(req->tid);
	__insert_statfs(monc, req);
	monc->num_statfs_requests++;
	mutex_unlock(&monc->mutex);

	/* send request and wait */
	err = send_statfs(monc, &req);
	if (!err)
		err = wait_for_completion_interruptible(&req.completion);
	ceph_con_send(monc->con, ceph_msg_get(req->request));
	err = wait_for_completion_interruptible(&req->completion);

	mutex_lock(&monc->mutex);
	rb_erase(&req.node, &monc->statfs_request_tree);
	rb_erase(&req->node, &monc->statfs_request_tree);
	monc->num_statfs_requests--;
	ceph_msgpool_resv(&monc->msgpool_statfs_reply, -1);
	mutex_unlock(&monc->mutex);

	if (!err)
		err = req.result;
		err = req->result;

out:
	kref_put(&req->kref, release_statfs_request);
	return err;
}

@@ -496,7 +543,7 @@ static void __resend_statfs(struct ceph_mon_client *monc)

	for (p = rb_first(&monc->statfs_request_tree); p; p = rb_next(p)) {
		req = rb_entry(p, struct ceph_mon_statfs_request, node);
		send_statfs(monc, req);
		ceph_con_send(monc->con, ceph_msg_get(req->request));
	}
}

@@ -591,13 +638,9 @@ int ceph_monc_init(struct ceph_mon_client *monc, struct ceph_client *cl)
			       sizeof(struct ceph_mon_subscribe_ack), 1, false);
	if (err < 0)
		goto out_monmap;
	err = ceph_msgpool_init(&monc->msgpool_statfs_reply,
				sizeof(struct ceph_mon_statfs_reply), 0, false);
	if (err < 0)
		goto out_pool1;
	err = ceph_msgpool_init(&monc->msgpool_auth_reply, 4096, 1, false);
	if (err < 0)
		goto out_pool2;
		goto out_pool;

	monc->m_auth = ceph_msg_new(CEPH_MSG_AUTH, 4096, 0, 0, NULL);
	monc->pending_auth = 0;
@@ -624,10 +667,8 @@ int ceph_monc_init(struct ceph_mon_client *monc, struct ceph_client *cl)

out_pool3:
	ceph_msgpool_destroy(&monc->msgpool_auth_reply);
out_pool2:
out_pool:
	ceph_msgpool_destroy(&monc->msgpool_subscribe_ack);
out_pool1:
	ceph_msgpool_destroy(&monc->msgpool_statfs_reply);
out_monmap:
	kfree(monc->monmap);
out:
@@ -652,7 +693,6 @@ void ceph_monc_stop(struct ceph_mon_client *monc)

	ceph_msg_put(monc->m_auth);
	ceph_msgpool_destroy(&monc->msgpool_subscribe_ack);
	ceph_msgpool_destroy(&monc->msgpool_statfs_reply);
	ceph_msgpool_destroy(&monc->msgpool_auth_reply);

	kfree(monc->monmap);
@@ -773,8 +813,7 @@ static struct ceph_msg *mon_alloc_msg(struct ceph_connection *con,
		m = ceph_msgpool_get(&monc->msgpool_subscribe_ack, front_len);
		break;
	case CEPH_MSG_STATFS_REPLY:
		m = ceph_msgpool_get(&monc->msgpool_statfs_reply, front_len);
		break;
		return get_statfs_reply(con, hdr, skip);
	case CEPH_MSG_AUTH_REPLY:
		m = ceph_msgpool_get(&monc->msgpool_auth_reply, front_len);
		break;
+3 −2
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@
#define _FS_CEPH_MON_CLIENT_H

#include <linux/completion.h>
#include <linux/kref.h>
#include <linux/rbtree.h>

#include "messenger.h"
@@ -44,13 +45,14 @@ struct ceph_mon_request {
 * to the caller
 */
struct ceph_mon_statfs_request {
	/* reference count; release_statfs_request() runs on the final put */
	struct kref kref;
	/* transaction id, matched against the tid in reply message headers */
	u64 tid;
	/* linkage in monc->statfs_request_tree */
	struct rb_node node;
	/* set to 0 when a reply is handled; returned to the statfs caller */
	int result;
	/* caller's buffer, filled from the reply's st field */
	struct ceph_statfs *buf;
	/* completed after a matching reply has been handled */
	struct completion completion;
	unsigned long last_attempt, delay; /* jiffies */
	struct ceph_msg *request;  /* original request */
	struct ceph_msg *reply;    /* and reply */
};

struct ceph_mon_client {
@@ -72,7 +74,6 @@ struct ceph_mon_client {

	/* msg pools */
	struct ceph_msgpool msgpool_subscribe_ack;
	struct ceph_msgpool msgpool_statfs_reply;
	struct ceph_msgpool msgpool_auth_reply;

	/* pending statfs requests */