Commit cf2f225e authored by Kuniyuki Iwashima's avatar Kuniyuki Iwashima Committed by David S. Miller
Browse files

af_unix: Put a socket into a per-netns hash table.



This commit replaces the global hash table with a per-netns one and removes
the global one.

We now link a socket in each netns's hash table so we can save some netns
comparisons when iterating through a hash bucket.

Signed-off-by: default avatarKuniyuki Iwashima <kuniyu@amazon.com>
Signed-off-by: default avatarDavid S. Miller <davem@davemloft.net>
parent 79b05bea
Loading
Loading
Loading
Loading
+0 −1
Original line number Original line Diff line number Diff line
@@ -22,7 +22,6 @@ struct sock *unix_peer_get(struct sock *sk);


extern unsigned int unix_tot_inflight;
extern unsigned int unix_tot_inflight;
extern spinlock_t unix_table_locks[UNIX_HASH_SIZE];
extern spinlock_t unix_table_locks[UNIX_HASH_SIZE];
extern struct hlist_head unix_socket_table[UNIX_HASH_SIZE];


struct unix_address {
struct unix_address {
	refcount_t	refcnt;
	refcount_t	refcnt;
+20 −30
Original line number Original line Diff line number Diff line
@@ -120,8 +120,6 @@


spinlock_t unix_table_locks[UNIX_HASH_SIZE];
spinlock_t unix_table_locks[UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_table_locks);
EXPORT_SYMBOL_GPL(unix_table_locks);
struct hlist_head unix_socket_table[UNIX_HASH_SIZE];
EXPORT_SYMBOL_GPL(unix_socket_table);
static atomic_long_t unix_nr_socks;
static atomic_long_t unix_nr_socks;


/* SMP locking strategy:
/* SMP locking strategy:
@@ -308,20 +306,20 @@ static void __unix_remove_socket(struct sock *sk)
	sk_del_node_init(sk);
	sk_del_node_init(sk);
}
}


static void __unix_insert_socket(struct sock *sk)
static void __unix_insert_socket(struct net *net, struct sock *sk)
{
{
	DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk));
	DEBUG_NET_WARN_ON_ONCE(!sk_unhashed(sk));
	sk_add_node(sk, &unix_socket_table[sk->sk_hash]);
	sk_add_node(sk, &net->unx.table.buckets[sk->sk_hash]);
}
}


static void __unix_set_addr_hash(struct sock *sk, struct unix_address *addr,
static void __unix_set_addr_hash(struct net *net, struct sock *sk,
				 unsigned int hash)
				 struct unix_address *addr, unsigned int hash)
{
{
	__unix_remove_socket(sk);
	__unix_remove_socket(sk);
	smp_store_release(&unix_sk(sk)->addr, addr);
	smp_store_release(&unix_sk(sk)->addr, addr);


	sk->sk_hash = hash;
	sk->sk_hash = hash;
	__unix_insert_socket(sk);
	__unix_insert_socket(net, sk);
}
}


static void unix_remove_socket(struct net *net, struct sock *sk)
static void unix_remove_socket(struct net *net, struct sock *sk)
@@ -337,7 +335,7 @@ static void unix_insert_unbound_socket(struct net *net, struct sock *sk)
{
{
	spin_lock(&unix_table_locks[sk->sk_hash]);
	spin_lock(&unix_table_locks[sk->sk_hash]);
	spin_lock(&net->unx.table.locks[sk->sk_hash]);
	spin_lock(&net->unx.table.locks[sk->sk_hash]);
	__unix_insert_socket(sk);
	__unix_insert_socket(net, sk);
	spin_unlock(&net->unx.table.locks[sk->sk_hash]);
	spin_unlock(&net->unx.table.locks[sk->sk_hash]);
	spin_unlock(&unix_table_locks[sk->sk_hash]);
	spin_unlock(&unix_table_locks[sk->sk_hash]);
}
}
@@ -348,12 +346,9 @@ static struct sock *__unix_find_socket_byname(struct net *net,
{
{
	struct sock *s;
	struct sock *s;


	sk_for_each(s, &unix_socket_table[hash]) {
	sk_for_each(s, &net->unx.table.buckets[hash]) {
		struct unix_sock *u = unix_sk(s);
		struct unix_sock *u = unix_sk(s);


		if (!net_eq(sock_net(s), net))
			continue;

		if (u->addr->len == len &&
		if (u->addr->len == len &&
		    !memcmp(u->addr->name, sunname, len))
		    !memcmp(u->addr->name, sunname, len))
			return s;
			return s;
@@ -384,7 +379,7 @@ static struct sock *unix_find_socket_byinode(struct net *net, struct inode *i)


	spin_lock(&unix_table_locks[hash]);
	spin_lock(&unix_table_locks[hash]);
	spin_lock(&net->unx.table.locks[hash]);
	spin_lock(&net->unx.table.locks[hash]);
	sk_for_each(s, &unix_socket_table[hash]) {
	sk_for_each(s, &net->unx.table.buckets[hash]) {
		struct dentry *dentry = unix_sk(s)->path.dentry;
		struct dentry *dentry = unix_sk(s)->path.dentry;


		if (dentry && d_backing_inode(dentry) == i) {
		if (dentry && d_backing_inode(dentry) == i) {
@@ -1140,7 +1135,7 @@ static int unix_autobind(struct sock *sk)
		goto retry;
		goto retry;
	}
	}


	__unix_set_addr_hash(sk, addr, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	err = 0;
	err = 0;


@@ -1199,7 +1194,7 @@ static int unix_bind_bsd(struct sock *sk, struct sockaddr_un *sunaddr,
	unix_table_double_lock(net, old_hash, new_hash);
	unix_table_double_lock(net, old_hash, new_hash);
	u->path.mnt = mntget(parent.mnt);
	u->path.mnt = mntget(parent.mnt);
	u->path.dentry = dget(dentry);
	u->path.dentry = dget(dentry);
	__unix_set_addr_hash(sk, addr, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	mutex_unlock(&u->bindlock);
	mutex_unlock(&u->bindlock);
	done_path_create(&parent, dentry);
	done_path_create(&parent, dentry);
@@ -1246,7 +1241,7 @@ static int unix_bind_abstract(struct sock *sk, struct sockaddr_un *sunaddr,
	if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash))
	if (__unix_find_socket_byname(net, addr->name, addr->len, new_hash))
		goto out_spin;
		goto out_spin;


	__unix_set_addr_hash(sk, addr, new_hash);
	__unix_set_addr_hash(net, sk, addr, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	unix_table_double_unlock(net, old_hash, new_hash);
	mutex_unlock(&u->bindlock);
	mutex_unlock(&u->bindlock);
	return 0;
	return 0;
@@ -3239,12 +3234,11 @@ static struct sock *unix_from_bucket(struct seq_file *seq, loff_t *pos)
{
{
	unsigned long offset = get_offset(*pos);
	unsigned long offset = get_offset(*pos);
	unsigned long bucket = get_bucket(*pos);
	unsigned long bucket = get_bucket(*pos);
	struct sock *sk;
	unsigned long count = 0;
	unsigned long count = 0;
	struct sock *sk;


	for (sk = sk_head(&unix_socket_table[bucket]); sk; sk = sk_next(sk)) {
	for (sk = sk_head(&seq_file_net(seq)->unx.table.buckets[bucket]);
		if (sock_net(sk) != seq_file_net(seq))
	     sk; sk = sk_next(sk)) {
			continue;
		if (++count == offset)
		if (++count == offset)
			break;
			break;
	}
	}
@@ -3279,13 +3273,13 @@ static struct sock *unix_get_next(struct seq_file *seq, struct sock *sk,
				  loff_t *pos)
				  loff_t *pos)
{
{
	unsigned long bucket = get_bucket(*pos);
	unsigned long bucket = get_bucket(*pos);
	struct net *net = seq_file_net(seq);


	for (sk = sk_next(sk); sk; sk = sk_next(sk))
	sk = sk_next(sk);
		if (sock_net(sk) == net)
	if (sk)
		return sk;
		return sk;


	spin_unlock(&net->unx.table.locks[bucket]);

	spin_unlock(&seq_file_net(seq)->unx.table.locks[bucket]);
	spin_unlock(&unix_table_locks[bucket]);
	spin_unlock(&unix_table_locks[bucket]);


	*pos = set_bucket_offset(++bucket, 1);
	*pos = set_bucket_offset(++bucket, 1);
@@ -3406,7 +3400,6 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)


{
{
	struct bpf_unix_iter_state *iter = seq->private;
	struct bpf_unix_iter_state *iter = seq->private;
	struct net *net = seq_file_net(seq);
	unsigned int expected = 1;
	unsigned int expected = 1;
	struct sock *sk;
	struct sock *sk;


@@ -3414,9 +3407,6 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
	iter->batch[iter->end_sk++] = start_sk;
	iter->batch[iter->end_sk++] = start_sk;


	for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) {
	for (sk = sk_next(start_sk); sk; sk = sk_next(sk)) {
		if (sock_net(sk) != net)
			continue;

		if (iter->end_sk < iter->max_sk) {
		if (iter->end_sk < iter->max_sk) {
			sock_hold(sk);
			sock_hold(sk);
			iter->batch[iter->end_sk++] = sk;
			iter->batch[iter->end_sk++] = sk;
@@ -3425,7 +3415,7 @@ static int bpf_iter_unix_hold_batch(struct seq_file *seq, struct sock *start_sk)
		expected++;
		expected++;
	}
	}


	spin_unlock(&net->unx.table.locks[start_sk->sk_hash]);
	spin_unlock(&seq_file_net(seq)->unx.table.locks[start_sk->sk_hash]);
	spin_unlock(&unix_table_locks[start_sk->sk_hash]);
	spin_unlock(&unix_table_locks[start_sk->sk_hash]);


	return expected;
	return expected;
+3 −6
Original line number Original line Diff line number Diff line
@@ -210,9 +210,7 @@ static int unix_diag_dump(struct sk_buff *skb, struct netlink_callback *cb)
		num = 0;
		num = 0;
		spin_lock(&unix_table_locks[slot]);
		spin_lock(&unix_table_locks[slot]);
		spin_lock(&net->unx.table.locks[slot]);
		spin_lock(&net->unx.table.locks[slot]);
		sk_for_each(sk, &unix_socket_table[slot]) {
		sk_for_each(sk, &net->unx.table.buckets[slot]) {
			if (!net_eq(sock_net(sk), net))
				continue;
			if (num < s_num)
			if (num < s_num)
				goto next;
				goto next;
			if (!(req->udiag_states & (1 << sk->sk_state)))
			if (!(req->udiag_states & (1 << sk->sk_state)))
@@ -246,13 +244,14 @@ static struct sock *unix_lookup_by_ino(struct net *net, unsigned int ino)
	for (i = 0; i < UNIX_HASH_SIZE; i++) {
	for (i = 0; i < UNIX_HASH_SIZE; i++) {
		spin_lock(&unix_table_locks[i]);
		spin_lock(&unix_table_locks[i]);
		spin_lock(&net->unx.table.locks[i]);
		spin_lock(&net->unx.table.locks[i]);
		sk_for_each(sk, &unix_socket_table[i])
		sk_for_each(sk, &net->unx.table.buckets[i]) {
			if (ino == sock_i_ino(sk)) {
			if (ino == sock_i_ino(sk)) {
				sock_hold(sk);
				sock_hold(sk);
				spin_unlock(&net->unx.table.locks[i]);
				spin_unlock(&net->unx.table.locks[i]);
				spin_unlock(&unix_table_locks[i]);
				spin_unlock(&unix_table_locks[i]);
				return sk;
				return sk;
			}
			}
		}
		spin_unlock(&net->unx.table.locks[i]);
		spin_unlock(&net->unx.table.locks[i]);
		spin_unlock(&unix_table_locks[i]);
		spin_unlock(&unix_table_locks[i]);
	}
	}
@@ -277,8 +276,6 @@ static int unix_diag_get_exact(struct sk_buff *in_skb,
	err = -ENOENT;
	err = -ENOENT;
	if (sk == NULL)
	if (sk == NULL)
		goto out_nosk;
		goto out_nosk;
	if (!net_eq(sock_net(sk), net))
		goto out;


	err = sock_diag_check_cookie(sk, req->udiag_cookie);
	err = sock_diag_check_cookie(sk, req->udiag_cookie);
	if (err)
	if (err)