Commit e8c8fd9b authored by Alexei Starovoitov's avatar Alexei Starovoitov
Browse files

Merge branch 'bpf, sockmap: Fix infinite recursion in sock_map_close'

Jakub Sitnicki says:

====================

This patch set addresses the syzbot report in [1].

Patch #1 has been suggested by Eric [2]. I extended it to cover the rest of
sock_map proto callbacks. Otherwise we would still overflow the stack.

Patch #2 contains the actual fix and bug analysis.
Patches #3 & #4 add coverage to selftests to trigger the bug.

[1] https://lore.kernel.org/all/00000000000073b14905ef2e7401@google.com/
[2] https://lore.kernel.org/all/CANn89iK2UN1FmdUcH12fv_xiZkv2G+Nskvmq7fG6aA_6VKRf6g@mail.gmail.com/
---
v1 -> v2:
v1: https://lore.kernel.org/r/20230113-sockmap-fix-v1-0-d3cad092ee10@cloudflare.com


[v1 didn't hit bpf@ ML by mistake]

 * pull in Eric's patch to protect against recursion loop bugs (Eric)
 * add a macro helper to check if pointer is inside a memory range (Eric)
====================

Signed-off-by: default avatarAlexei Starovoitov <ast@kernel.org>
parents 74bc3a5a c88ea16a
Loading
Loading
Loading
Loading
+12 −0
Original line number Diff line number Diff line
@@ -38,4 +38,16 @@
 */
#define find_closest_descending(x, a, as) __find_closest(x, a, as, >=)

/**
 * is_insidevar - check if the @ptr points inside the @var memory range.
 * @ptr:	the pointer to a memory address.
 * @var:	the variable which address and size identify the memory range.
 *
 * Evaluates to true if the address in @ptr lies within the memory
 * range allocated to @var.
 */
#define is_insidevar(ptr, var)						\
	((uintptr_t)(ptr) >= (uintptr_t)(var) &&			\
	 (uintptr_t)(ptr) <  (uintptr_t)(var) + sizeof(var))

#endif
+34 −27
Original line number Diff line number Diff line
@@ -1569,14 +1569,15 @@ void sock_map_unhash(struct sock *sk)
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

		saved_unhash = READ_ONCE(sk->sk_prot)->unhash;
	} else {
		saved_unhash = psock->saved_unhash;
		sock_map_remove_links(sk, psock);
		rcu_read_unlock();
	}
	if (WARN_ON_ONCE(saved_unhash == sock_map_unhash))
		return;
	if (saved_unhash)
		saved_unhash(sk);
}
EXPORT_SYMBOL_GPL(sock_map_unhash);
@@ -1590,16 +1591,17 @@ void sock_map_destroy(struct sock *sk)
	psock = sk_psock_get(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->destroy)
			sk->sk_prot->destroy(sk);
		return;
	}

		saved_destroy = READ_ONCE(sk->sk_prot)->destroy;
	} else {
		saved_destroy = psock->saved_destroy;
		sock_map_remove_links(sk, psock);
		rcu_read_unlock();
		sk_psock_stop(psock);
		sk_psock_put(sk, psock);
	}
	if (WARN_ON_ONCE(saved_destroy == sock_map_destroy))
		return;
	if (saved_destroy)
		saved_destroy(sk);
}
EXPORT_SYMBOL_GPL(sock_map_destroy);
@@ -1615,9 +1617,8 @@ void sock_map_close(struct sock *sk, long timeout)
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

		saved_close = READ_ONCE(sk->sk_prot)->close;
	} else {
		saved_close = psock->saved_close;
		sock_map_remove_links(sk, psock);
		rcu_read_unlock();
@@ -1625,6 +1626,12 @@ void sock_map_close(struct sock *sk, long timeout)
		release_sock(sk);
		cancel_work_sync(&psock->work);
		sk_psock_put(sk, psock);
	}
	/* Make sure we do not recurse. This is a bug.
	 * Leak the socket instead of crashing on a stack overflow.
	 */
	if (WARN_ON_ONCE(saved_close == sock_map_close))
		return;
	saved_close(sk, timeout);
}
EXPORT_SYMBOL_GPL(sock_map_close);
+2 −2
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@
#include <linux/bpf.h>
#include <linux/init.h>
#include <linux/wait.h>
#include <linux/util_macros.h>

#include <net/inet_common.h>
#include <net/tls.h>
@@ -639,10 +640,9 @@ EXPORT_SYMBOL_GPL(tcp_bpf_update_proto);
 */
void tcp_bpf_clone(const struct sock *sk, struct sock *newsk)
{
	int family = sk->sk_family == AF_INET6 ? TCP_BPF_IPV6 : TCP_BPF_IPV4;
	struct proto *prot = newsk->sk_prot;

	if (prot == &tcp_bpf_prots[family][TCP_BPF_BASE])
	if (is_insidevar(prot, tcp_bpf_prots))
		newsk->sk_prot = sk->sk_prot_creator;
}
#endif /* CONFIG_BPF_SYSCALL */
+63 −18
Original line number Diff line number Diff line
@@ -30,6 +30,8 @@
#define MAX_STRERR_LEN 256
#define MAX_TEST_NAME 80

#define __always_unused	__attribute__((__unused__))

#define _FAIL(errnum, fmt...)                                                  \
	({                                                                     \
		error_at_line(0, (errnum), __func__, __LINE__, fmt);           \
@@ -321,7 +323,8 @@ static int socket_loopback(int family, int sotype)
	return socket_loopback_reuseport(family, sotype, -1);
}

static void test_insert_invalid(int family, int sotype, int mapfd)
static void test_insert_invalid(struct test_sockmap_listen *skel __always_unused,
				int family, int sotype, int mapfd)
{
	u32 key = 0;
	u64 value;
@@ -338,7 +341,8 @@ static void test_insert_invalid(int family, int sotype, int mapfd)
		FAIL_ERRNO("map_update: expected EBADF");
}

static void test_insert_opened(int family, int sotype, int mapfd)
static void test_insert_opened(struct test_sockmap_listen *skel __always_unused,
			       int family, int sotype, int mapfd)
{
	u32 key = 0;
	u64 value;
@@ -359,7 +363,8 @@ static void test_insert_opened(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_insert_bound(int family, int sotype, int mapfd)
static void test_insert_bound(struct test_sockmap_listen *skel __always_unused,
			      int family, int sotype, int mapfd)
{
	struct sockaddr_storage addr;
	socklen_t len;
@@ -386,7 +391,8 @@ static void test_insert_bound(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_insert(int family, int sotype, int mapfd)
static void test_insert(struct test_sockmap_listen *skel __always_unused,
			int family, int sotype, int mapfd)
{
	u64 value;
	u32 key;
@@ -402,7 +408,8 @@ static void test_insert(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_delete_after_insert(int family, int sotype, int mapfd)
static void test_delete_after_insert(struct test_sockmap_listen *skel __always_unused,
				     int family, int sotype, int mapfd)
{
	u64 value;
	u32 key;
@@ -419,7 +426,8 @@ static void test_delete_after_insert(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_delete_after_close(int family, int sotype, int mapfd)
static void test_delete_after_close(struct test_sockmap_listen *skel __always_unused,
				    int family, int sotype, int mapfd)
{
	int err, s;
	u64 value;
@@ -442,7 +450,8 @@ static void test_delete_after_close(int family, int sotype, int mapfd)
		FAIL_ERRNO("map_delete: expected EINVAL/EINVAL");
}

static void test_lookup_after_insert(int family, int sotype, int mapfd)
static void test_lookup_after_insert(struct test_sockmap_listen *skel __always_unused,
				     int family, int sotype, int mapfd)
{
	u64 cookie, value;
	socklen_t len;
@@ -470,7 +479,8 @@ static void test_lookup_after_insert(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_lookup_after_delete(int family, int sotype, int mapfd)
static void test_lookup_after_delete(struct test_sockmap_listen *skel __always_unused,
				     int family, int sotype, int mapfd)
{
	int err, s;
	u64 value;
@@ -493,7 +503,8 @@ static void test_lookup_after_delete(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_lookup_32_bit_value(int family, int sotype, int mapfd)
static void test_lookup_32_bit_value(struct test_sockmap_listen *skel __always_unused,
				     int family, int sotype, int mapfd)
{
	u32 key, value32;
	int err, s;
@@ -523,7 +534,8 @@ static void test_lookup_32_bit_value(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_update_existing(int family, int sotype, int mapfd)
static void test_update_existing(struct test_sockmap_listen *skel __always_unused,
				 int family, int sotype, int mapfd)
{
	int s1, s2;
	u64 value;
@@ -551,7 +563,7 @@ static void test_update_existing(int family, int sotype, int mapfd)
/* Exercise the code path where we destroy child sockets that never
 * got accept()'ed, aka orphans, when parent socket gets closed.
 */
static void test_destroy_orphan_child(int family, int sotype, int mapfd)
static void do_destroy_orphan_child(int family, int sotype, int mapfd)
{
	struct sockaddr_storage addr;
	socklen_t len;
@@ -582,10 +594,38 @@ static void test_destroy_orphan_child(int family, int sotype, int mapfd)
	xclose(s);
}

static void test_destroy_orphan_child(struct test_sockmap_listen *skel,
				      int family, int sotype, int mapfd)
{
	int msg_verdict = bpf_program__fd(skel->progs.prog_msg_verdict);
	int skb_verdict = bpf_program__fd(skel->progs.prog_skb_verdict);
	const struct test {
		int progfd;
		enum bpf_attach_type atype;
	} tests[] = {
		{ -1, -1 },
		{ msg_verdict, BPF_SK_MSG_VERDICT },
		{ skb_verdict, BPF_SK_SKB_VERDICT },
	};
	const struct test *t;

	for (t = tests; t < tests + ARRAY_SIZE(tests); t++) {
		if (t->progfd != -1 &&
		    xbpf_prog_attach(t->progfd, mapfd, t->atype, 0) != 0)
			return;

		do_destroy_orphan_child(family, sotype, mapfd);

		if (t->progfd != -1)
			xbpf_prog_detach2(t->progfd, mapfd, t->atype);
	}
}

/* Perform a passive open after removing listening socket from SOCKMAP
 * to ensure that callbacks get restored properly.
 */
static void test_clone_after_delete(int family, int sotype, int mapfd)
static void test_clone_after_delete(struct test_sockmap_listen *skel __always_unused,
				    int family, int sotype, int mapfd)
{
	struct sockaddr_storage addr;
	socklen_t len;
@@ -621,7 +661,8 @@ static void test_clone_after_delete(int family, int sotype, int mapfd)
 * SOCKMAP, but got accept()'ed only after the parent has been removed
 * from SOCKMAP, gets cloned without parent psock state or callbacks.
 */
static void test_accept_after_delete(int family, int sotype, int mapfd)
static void test_accept_after_delete(struct test_sockmap_listen *skel __always_unused,
				     int family, int sotype, int mapfd)
{
	struct sockaddr_storage addr;
	const u32 zero = 0;
@@ -675,7 +716,8 @@ static void test_accept_after_delete(int family, int sotype, int mapfd)
/* Check that child socket that got created and accepted while parent
 * was in a SOCKMAP is cloned without parent psock state or callbacks.
 */
static void test_accept_before_delete(int family, int sotype, int mapfd)
static void test_accept_before_delete(struct test_sockmap_listen *skel __always_unused,
				      int family, int sotype, int mapfd)
{
	struct sockaddr_storage addr;
	const u32 zero = 0, one = 1;
@@ -784,7 +826,8 @@ static void *connect_accept_thread(void *arg)
	return NULL;
}

static void test_syn_recv_insert_delete(int family, int sotype, int mapfd)
static void test_syn_recv_insert_delete(struct test_sockmap_listen *skel __always_unused,
					int family, int sotype, int mapfd)
{
	struct connect_accept_ctx ctx = { 0 };
	struct sockaddr_storage addr;
@@ -847,7 +890,8 @@ static void *listen_thread(void *arg)
	return NULL;
}

static void test_race_insert_listen(int family, int socktype, int mapfd)
static void test_race_insert_listen(struct test_sockmap_listen *skel __always_unused,
				    int family, int socktype, int mapfd)
{
	struct connect_accept_ctx ctx = { 0 };
	const u32 zero = 0;
@@ -1473,7 +1517,8 @@ static void test_ops(struct test_sockmap_listen *skel, struct bpf_map *map,
		     int family, int sotype)
{
	const struct op_test {
		void (*fn)(int family, int sotype, int mapfd);
		void (*fn)(struct test_sockmap_listen *skel,
			   int family, int sotype, int mapfd);
		const char *name;
		int sotype;
	} tests[] = {
@@ -1520,7 +1565,7 @@ static void test_ops(struct test_sockmap_listen *skel, struct bpf_map *map,
		if (!test__start_subtest(s))
			continue;

		t->fn(family, sotype, map_fd);
		t->fn(skel, family, sotype, map_fd);
		test_ops_cleanup(map);
	}
}