Donate to e Foundation | Murena handsets with /e/OS | Own a part of Murena! Learn more

Commit 6b18662e authored by Al Viro's avatar Al Viro
Browse files

9p connect fixes



* if we fail in p9_conn_create(), we shouldn't leak references to struct file.
  Logics in ->close() doesn't help - ->trans is already gone by the time it's
  called.
* sock_create_kern() can fail.
* use of sock_map_fd() is all fscked up; I'd fixed most of that, but the
  rest will have to wait for a bit more work in net/socket.c (we still are
  violating the basic rule of working with descriptor table: "once the reference
  is installed there, don't rely on finding it there again").

Signed-off-by: default avatarAl Viro <viro@zeniv.linux.org.uk>
parent 7cbe66b6
Loading
Loading
Loading
Loading
+46 −66
Original line number Original line Diff line number Diff line
@@ -42,6 +42,8 @@
#include <net/9p/client.h>
#include <net/9p/client.h>
#include <net/9p/transport.h>
#include <net/9p/transport.h>


#include <linux/syscalls.h> /* killme */

#define P9_PORT 564
#define P9_PORT 564
#define MAX_SOCK_BUF (64*1024)
#define MAX_SOCK_BUF (64*1024)
#define MAXPOLLWADDR	2
#define MAXPOLLWADDR	2
@@ -788,24 +790,41 @@ static int p9_fd_open(struct p9_client *client, int rfd, int wfd)


static int p9_socket_open(struct p9_client *client, struct socket *csocket)
static int p9_socket_open(struct p9_client *client, struct socket *csocket)
{
{
	int fd, ret;
	struct p9_trans_fd *p;
	int ret, fd;

	p = kmalloc(sizeof(struct p9_trans_fd), GFP_KERNEL);
	if (!p)
		return -ENOMEM;


	csocket->sk->sk_allocation = GFP_NOIO;
	csocket->sk->sk_allocation = GFP_NOIO;
	fd = sock_map_fd(csocket, 0);
	fd = sock_map_fd(csocket, 0);
	if (fd < 0) {
	if (fd < 0) {
		P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to map fd\n");
		P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to map fd\n");
		sock_release(csocket);
		kfree(p);
		return fd;
		return fd;
	}
	}


	ret = p9_fd_open(client, fd, fd);
	get_file(csocket->file);
	if (ret < 0) {
	get_file(csocket->file);
		P9_EPRINTK(KERN_ERR, "p9_socket_open: failed to open fd\n");
	p->wr = p->rd = csocket->file;
	client->trans = p;
	client->status = Connected;

	sys_close(fd);	/* still racy */

	p->rd->f_flags |= O_NONBLOCK;

	p->conn = p9_conn_create(client);
	if (IS_ERR(p->conn)) {
		ret = PTR_ERR(p->conn);
		p->conn = NULL;
		kfree(p);
		sockfd_put(csocket);
		sockfd_put(csocket);
		sockfd_put(csocket);
		return ret;
		return ret;
	}
	}

	((struct p9_trans_fd *)client->trans)->rd->f_flags |= O_NONBLOCK;

	return 0;
	return 0;
}
}


@@ -883,7 +902,6 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
	struct socket *csocket;
	struct socket *csocket;
	struct sockaddr_in sin_server;
	struct sockaddr_in sin_server;
	struct p9_fd_opts opts;
	struct p9_fd_opts opts;
	struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */


	err = parse_opts(args, &opts);
	err = parse_opts(args, &opts);
	if (err < 0)
	if (err < 0)
@@ -897,12 +915,11 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
	sin_server.sin_family = AF_INET;
	sin_server.sin_family = AF_INET;
	sin_server.sin_addr.s_addr = in_aton(addr);
	sin_server.sin_addr.s_addr = in_aton(addr);
	sin_server.sin_port = htons(opts.port);
	sin_server.sin_port = htons(opts.port);
	sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket);
	err = sock_create_kern(PF_INET, SOCK_STREAM, IPPROTO_TCP, &csocket);


	if (!csocket) {
	if (err) {
		P9_EPRINTK(KERN_ERR, "p9_trans_tcp: problem creating socket\n");
		P9_EPRINTK(KERN_ERR, "p9_trans_tcp: problem creating socket\n");
		err = -EIO;
		return err;
		goto error;
	}
	}


	err = csocket->ops->connect(csocket,
	err = csocket->ops->connect(csocket,
@@ -912,89 +929,54 @@ p9_fd_create_tcp(struct p9_client *client, const char *addr, char *args)
		P9_EPRINTK(KERN_ERR,
		P9_EPRINTK(KERN_ERR,
			"p9_trans_tcp: problem connecting socket to %s\n",
			"p9_trans_tcp: problem connecting socket to %s\n",
			addr);
			addr);
		goto error;
	}

	err = p9_socket_open(client, csocket);
	if (err < 0)
		goto error;

	p = (struct p9_trans_fd *) client->trans;
	p->conn = p9_conn_create(client);
	if (IS_ERR(p->conn)) {
		err = PTR_ERR(p->conn);
		p->conn = NULL;
		goto error;
	}

	return 0;

error:
	if (csocket)
		sock_release(csocket);
		sock_release(csocket);

	kfree(p);

		return err;
		return err;
	}
	}


	return p9_socket_open(client, csocket);
}

static int
static int
p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
p9_fd_create_unix(struct p9_client *client, const char *addr, char *args)
{
{
	int err;
	int err;
	struct socket *csocket;
	struct socket *csocket;
	struct sockaddr_un sun_server;
	struct sockaddr_un sun_server;
	struct p9_trans_fd *p = NULL; /* this gets allocated in p9_fd_open */


	csocket = NULL;
	csocket = NULL;


	if (strlen(addr) > UNIX_PATH_MAX) {
	if (strlen(addr) > UNIX_PATH_MAX) {
		P9_EPRINTK(KERN_ERR, "p9_trans_unix: address too long: %s\n",
		P9_EPRINTK(KERN_ERR, "p9_trans_unix: address too long: %s\n",
			addr);
			addr);
		err = -ENAMETOOLONG;
		return -ENAMETOOLONG;
		goto error;
	}
	}


	sun_server.sun_family = PF_UNIX;
	sun_server.sun_family = PF_UNIX;
	strcpy(sun_server.sun_path, addr);
	strcpy(sun_server.sun_path, addr);
	sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket);
	err = sock_create_kern(PF_UNIX, SOCK_STREAM, 0, &csocket);
	if (err < 0) {
		P9_EPRINTK(KERN_ERR, "p9_trans_unix: problem creating socket\n");
		return err;
	}
	err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server,
	err = csocket->ops->connect(csocket, (struct sockaddr *)&sun_server,
			sizeof(struct sockaddr_un) - 1, 0);
			sizeof(struct sockaddr_un) - 1, 0);
	if (err < 0) {
	if (err < 0) {
		P9_EPRINTK(KERN_ERR,
		P9_EPRINTK(KERN_ERR,
			"p9_trans_unix: problem connecting socket: %s: %d\n",
			"p9_trans_unix: problem connecting socket: %s: %d\n",
			addr, err);
			addr, err);
		goto error;
	}

	err = p9_socket_open(client, csocket);
	if (err < 0)
		goto error;

	p = (struct p9_trans_fd *) client->trans;
	p->conn = p9_conn_create(client);
	if (IS_ERR(p->conn)) {
		err = PTR_ERR(p->conn);
		p->conn = NULL;
		goto error;
	}

	return 0;

error:
	if (csocket)
		sock_release(csocket);
		sock_release(csocket);

	kfree(p);
		return err;
		return err;
	}
	}


	return p9_socket_open(client, csocket);
}

static int
static int
p9_fd_create(struct p9_client *client, const char *addr, char *args)
p9_fd_create(struct p9_client *client, const char *addr, char *args)
{
{
	int err;
	int err;
	struct p9_fd_opts opts;
	struct p9_fd_opts opts;
	struct p9_trans_fd *p = NULL; /* this get allocated in p9_fd_open */
	struct p9_trans_fd *p;


	parse_opts(args, &opts);
	parse_opts(args, &opts);


@@ -1005,21 +987,19 @@ p9_fd_create(struct p9_client *client, const char *addr, char *args)


	err = p9_fd_open(client, opts.rfd, opts.wfd);
	err = p9_fd_open(client, opts.rfd, opts.wfd);
	if (err < 0)
	if (err < 0)
		goto error;
		return err;


	p = (struct p9_trans_fd *) client->trans;
	p = (struct p9_trans_fd *) client->trans;
	p->conn = p9_conn_create(client);
	p->conn = p9_conn_create(client);
	if (IS_ERR(p->conn)) {
	if (IS_ERR(p->conn)) {
		err = PTR_ERR(p->conn);
		err = PTR_ERR(p->conn);
		p->conn = NULL;
		p->conn = NULL;
		goto error;
		fput(p->rd);
		fput(p->wr);
		return err;
	}
	}


	return 0;
	return 0;

error:
	kfree(p);
	return err;
}
}


static struct p9_trans_module p9_tcp_trans = {
static struct p9_trans_module p9_tcp_trans = {