Commit eca3a04f authored by Linus Torvalds's avatar Linus Torvalds
Browse files
Pull dlm updates from David Teigland:
 "This fixes some races in the lowcomms startup and shutdown code that
  were found by targeted stress testing that quickly and repeatedly
  joins and leaves lockspaces"

* tag 'dlm-6.3' of git://git.kernel.org/pub/scm/linux/kernel/git/teigland/linux-dlm:
  fs: dlm: remove unnecessary waker_up() calls
  fs: dlm: move state change into else branch
  fs: dlm: remove newline in log_print
  fs: dlm: reduce the shutdown timeout to 5 secs
  fs: dlm: make dlm sequence id more robust
  fs: dlm: wait until all midcomms nodes detect version
  fs: dlm: ignore unexpected non dlm opts msgs
  fs: dlm: bring back previous shutdown handling
  fs: dlm: send FIN ack back in right cases
  fs: dlm: move sending fin message into state change handling
  fs: dlm: don't set stop rx flag after node reset
  fs: dlm: fix race setting stop tx flag
  fs: dlm: be sure to call dlm_send_queue_flush()
  fs: dlm: fix use after free in midcomms commit
  fs: dlm: start midcomms before scand
  fs/dlm: Remove "select SRCU"
  fs: dlm: fix return value check in dlm_memory_init()
parents 885ce487 723b197b
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@ menuconfig DLM
	depends on INET
	depends on SYSFS && CONFIGFS_FS && (IPV6 || IPV6=n)
	select IP_SCTP
	select SRCU
	help
	A general purpose distributed lock manager for kernel or userspace
	applications.
+12 −9
Original line number Diff line number Diff line
@@ -381,23 +381,23 @@ static int threads_start(void)
{
	int error;

	error = dlm_scand_start();
	/* Thread for sending/receiving messages for all lockspace's */
	error = dlm_midcomms_start();
	if (error) {
		log_print("cannot start dlm_scand thread %d", error);
		log_print("cannot start dlm midcomms %d", error);
		goto fail;
	}

	/* Thread for sending/receiving messages for all lockspace's */
	error = dlm_midcomms_start();
	error = dlm_scand_start();
	if (error) {
		log_print("cannot start dlm midcomms %d", error);
		goto scand_fail;
		log_print("cannot start dlm_scand thread %d", error);
		goto midcomms_fail;
	}

	return 0;

 scand_fail:
	dlm_scand_stop();
 midcomms_fail:
	dlm_midcomms_stop();
 fail:
	return error;
}
@@ -572,7 +572,7 @@ static int new_lockspace(const char *name, const char *cluster,
	spin_lock_init(&ls->ls_rcom_spin);
	get_random_bytes(&ls->ls_rcom_seq, sizeof(uint64_t));
	ls->ls_recover_status = 0;
	ls->ls_recover_seq = 0;
	ls->ls_recover_seq = get_random_u64();
	ls->ls_recover_args = NULL;
	init_rwsem(&ls->ls_in_recovery);
	init_rwsem(&ls->ls_recv_active);
@@ -820,6 +820,9 @@ static int release_lockspace(struct dlm_ls *ls, int force)
		return rv;
	}

	if (ls_count == 1)
		dlm_midcomms_version_wait();

	dlm_device_deregister(ls);

	if (force < 3 && dlm_user_daemon_available())
+56 −21
Original line number Diff line number Diff line
@@ -61,6 +61,7 @@
#include "memory.h"
#include "config.h"

#define DLM_SHUTDOWN_WAIT_TIMEOUT msecs_to_jiffies(5000)
#define NEEDED_RMEM (4*1024*1024)

struct connection {
@@ -99,6 +100,7 @@ struct connection {
	struct connection *othercon;
	struct work_struct rwork; /* receive worker */
	struct work_struct swork; /* send worker */
	wait_queue_head_t shutdown_wait;
	unsigned char rx_leftover_buf[DLM_MAX_SOCKET_BUFSIZE];
	int rx_leftover;
	int mark;
@@ -282,6 +284,7 @@ static void dlm_con_init(struct connection *con, int nodeid)
	INIT_WORK(&con->swork, process_send_sockets);
	INIT_WORK(&con->rwork, process_recv_sockets);
	spin_lock_init(&con->addrs_lock);
	init_waitqueue_head(&con->shutdown_wait);
}

/*
@@ -790,6 +793,43 @@ static void close_connection(struct connection *con, bool and_other)
	up_write(&con->sock_lock);
}

static void shutdown_connection(struct connection *con, bool and_other)
{
	int ret;

	if (con->othercon && and_other)
		shutdown_connection(con->othercon, false);

	flush_workqueue(io_workqueue);
	down_read(&con->sock_lock);
	/* nothing to shutdown */
	if (!con->sock) {
		up_read(&con->sock_lock);
		return;
	}

	ret = kernel_sock_shutdown(con->sock, SHUT_WR);
	up_read(&con->sock_lock);
	if (ret) {
		log_print("Connection %p failed to shutdown: %d will force close",
			  con, ret);
		goto force_close;
	} else {
		ret = wait_event_timeout(con->shutdown_wait, !con->sock,
					 DLM_SHUTDOWN_WAIT_TIMEOUT);
		if (ret == 0) {
			log_print("Connection %p shutdown timed out, will force close",
				  con);
			goto force_close;
		}
	}

	return;

force_close:
	close_connection(con, false);
}

static struct processqueue_entry *new_processqueue_entry(int nodeid,
							 int buflen)
{
@@ -1488,6 +1528,7 @@ static void process_recv_sockets(struct work_struct *work)
		break;
	case DLM_IO_EOF:
		close_connection(con, false);
		wake_up(&con->shutdown_wait);
		/* CF_RECV_PENDING cleared */
		break;
	case DLM_IO_RESCHED:
@@ -1695,6 +1736,9 @@ static int work_start(void)

void dlm_lowcomms_shutdown(void)
{
	struct connection *con;
	int i, idx;

	/* stop lowcomms_listen_data_ready calls */
	lock_sock(listen_con.sock->sk);
	listen_con.sock->sk->sk_data_ready = listen_sock.sk_data_ready;
@@ -1703,29 +1747,20 @@ void dlm_lowcomms_shutdown(void)
	cancel_work_sync(&listen_con.rwork);
	dlm_close_sock(&listen_con.sock);

	flush_workqueue(process_workqueue);
}

void dlm_lowcomms_shutdown_node(int nodeid, bool force)
{
	struct connection *con;
	int idx;

	idx = srcu_read_lock(&connections_srcu);
	con = nodeid2con(nodeid, 0);
	if (WARN_ON_ONCE(!con)) {
		srcu_read_unlock(&connections_srcu, idx);
		return;
	}

	flush_work(&con->swork);
	for (i = 0; i < CONN_HASH_SIZE; i++) {
		hlist_for_each_entry_rcu(con, &connection_hash[i], list) {
			shutdown_connection(con, true);
			stop_connection_io(con);
	WARN_ON_ONCE(!force && !list_empty(&con->writequeue));
			flush_workqueue(process_workqueue);
			close_connection(con, true);

			clean_one_writequeue(con);
			if (con->othercon)
				clean_one_writequeue(con->othercon);
			allow_connection_io(con);
		}
	}
	srcu_read_unlock(&connections_srcu, idx);
}

+1 −1
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ int __init dlm_memory_init(void)
	cb_cache = kmem_cache_create("dlm_cb", sizeof(struct dlm_callback),
				     __alignof__(struct dlm_callback), 0,
				     NULL);
	if (!rsb_cache)
	if (!cb_cache)
		goto cb;

	return 0;
+66 −65
Original line number Diff line number Diff line
@@ -146,8 +146,8 @@

/* init value for sequence numbers for testing purpose only e.g. overflows */
#define DLM_SEQ_INIT		0
/* 3 minutes wait to sync ending of dlm */
#define DLM_SHUTDOWN_TIMEOUT	msecs_to_jiffies(3 * 60 * 1000)
/* 5 seconds wait to sync ending of dlm */
#define DLM_SHUTDOWN_TIMEOUT	msecs_to_jiffies(5000)
#define DLM_VERSION_NOT_SET	0

struct midcomms_node {
@@ -375,7 +375,7 @@ static int dlm_send_ack(int nodeid, uint32_t seq)
	struct dlm_msg *msg;
	char *ppc;

	msg = dlm_lowcomms_new_msg(nodeid, mb_len, GFP_NOFS, &ppc,
	msg = dlm_lowcomms_new_msg(nodeid, mb_len, GFP_ATOMIC, &ppc,
				   NULL, NULL);
	if (!msg)
		return -ENOMEM;
@@ -402,10 +402,11 @@ static int dlm_send_fin(struct midcomms_node *node,
	struct dlm_mhandle *mh;
	char *ppc;

	mh = dlm_midcomms_get_mhandle(node->nodeid, mb_len, GFP_NOFS, &ppc);
	mh = dlm_midcomms_get_mhandle(node->nodeid, mb_len, GFP_ATOMIC, &ppc);
	if (!mh)
		return -ENOMEM;

	set_bit(DLM_NODE_FLAG_STOP_TX, &node->flags);
	mh->ack_rcv = ack_rcv;

	m_header = (struct dlm_header *)ppc;
@@ -417,7 +418,6 @@ static int dlm_send_fin(struct midcomms_node *node,

	pr_debug("sending fin msg to node %d\n", node->nodeid);
	dlm_midcomms_commit_mhandle(mh, NULL, 0);
	set_bit(DLM_NODE_FLAG_STOP_TX, &node->flags);

	return 0;
}
@@ -467,7 +467,7 @@ static void dlm_pas_fin_ack_rcv(struct midcomms_node *node)
		break;
	default:
		spin_unlock(&node->state_lock);
		log_print("%s: unexpected state: %d\n",
		log_print("%s: unexpected state: %d",
			  __func__, node->state);
		WARN_ON_ONCE(1);
		return;
@@ -498,18 +498,14 @@ static void dlm_midcomms_receive_buffer(union dlm_packet *p,

		switch (p->header.h_cmd) {
		case DLM_FIN:
			/* send ack before fin */
			dlm_send_ack(node->nodeid, node->seq_next);

			spin_lock(&node->state_lock);
			pr_debug("receive fin msg from node %d with state %s\n",
				 node->nodeid, dlm_state_str(node->state));

			switch (node->state) {
			case DLM_ESTABLISHED:
				node->state = DLM_CLOSE_WAIT;
				pr_debug("switch node %d to state %s\n",
					 node->nodeid, dlm_state_str(node->state));
				dlm_send_ack(node->nodeid, node->seq_next);

				/* passive shutdown DLM_LAST_ACK case 1
				 * additional we check if the node is used by
				 * cluster manager events at all.
@@ -518,34 +514,38 @@ static void dlm_midcomms_receive_buffer(union dlm_packet *p,
					node->state = DLM_LAST_ACK;
					pr_debug("switch node %d to state %s case 1\n",
						 node->nodeid, dlm_state_str(node->state));
					spin_unlock(&node->state_lock);
					goto send_fin;
					set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
					dlm_send_fin(node, dlm_pas_fin_ack_rcv);
				} else {
					node->state = DLM_CLOSE_WAIT;
					pr_debug("switch node %d to state %s\n",
						 node->nodeid, dlm_state_str(node->state));
				}
				break;
			case DLM_FIN_WAIT1:
				dlm_send_ack(node->nodeid, node->seq_next);
				node->state = DLM_CLOSING;
				set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
				pr_debug("switch node %d to state %s\n",
					 node->nodeid, dlm_state_str(node->state));
				break;
			case DLM_FIN_WAIT2:
				dlm_send_ack(node->nodeid, node->seq_next);
				midcomms_node_reset(node);
				pr_debug("switch node %d to state %s\n",
					 node->nodeid, dlm_state_str(node->state));
				wake_up(&node->shutdown_wait);
				break;
			case DLM_LAST_ACK:
				/* probably remove_member caught it, do nothing */
				break;
			default:
				spin_unlock(&node->state_lock);
				log_print("%s: unexpected state: %d\n",
				log_print("%s: unexpected state: %d",
					  __func__, node->state);
				WARN_ON_ONCE(1);
				return;
			}
			spin_unlock(&node->state_lock);

			set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
			break;
		default:
			WARN_ON_ONCE(test_bit(DLM_NODE_FLAG_STOP_RX, &node->flags));
@@ -564,12 +564,6 @@ static void dlm_midcomms_receive_buffer(union dlm_packet *p,
		log_print_ratelimited("ignore dlm msg because seq mismatch, seq: %u, expected: %u, nodeid: %d",
				      seq, node->seq_next, node->nodeid);
	}

	return;

send_fin:
	set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
	dlm_send_fin(node, dlm_pas_fin_ack_rcv);
}

static struct midcomms_node *
@@ -612,16 +606,8 @@ dlm_midcomms_recv_node_lookup(int nodeid, const union dlm_packet *p,
				case DLM_ESTABLISHED:
					break;
				default:
					/* some invalid state passive shutdown
					 * was failed, we try to reset and
					 * hope it will go on.
					 */
					log_print("reset node %d because shutdown stuck",
						  node->nodeid);

					midcomms_node_reset(node);
					node->state = DLM_ESTABLISHED;
					break;
					spin_unlock(&node->state_lock);
					return NULL;
				}
				spin_unlock(&node->state_lock);
			}
@@ -671,6 +657,7 @@ static int dlm_midcomms_version_check_3_2(struct midcomms_node *node)
	switch (node->version) {
	case DLM_VERSION_NOT_SET:
		node->version = DLM_VERSION_3_2;
		wake_up(&node->shutdown_wait);
		log_print("version 0x%08x for node %d detected", DLM_VERSION_3_2,
			  node->nodeid);
		break;
@@ -840,6 +827,7 @@ static int dlm_midcomms_version_check_3_1(struct midcomms_node *node)
	switch (node->version) {
	case DLM_VERSION_NOT_SET:
		node->version = DLM_VERSION_3_1;
		wake_up(&node->shutdown_wait);
		log_print("version 0x%08x for node %d detected", DLM_VERSION_3_1,
			  node->nodeid);
		break;
@@ -1214,8 +1202,15 @@ void dlm_midcomms_commit_mhandle(struct dlm_mhandle *mh,
		dlm_free_mhandle(mh);
		break;
	case DLM_VERSION_3_2:
		/* held rcu read lock here, because we sending the
		 * dlm message out, when we do that we could receive
		 * an ack back which releases the mhandle and we
		 * get a use after free.
		 */
		rcu_read_lock();
		dlm_midcomms_commit_msg_3_2(mh, name, namelen);
		srcu_read_unlock(&nodes_srcu, mh->idx);
		rcu_read_unlock();
		break;
	default:
		srcu_read_unlock(&nodes_srcu, mh->idx);
@@ -1266,7 +1261,6 @@ static void dlm_act_fin_ack_rcv(struct midcomms_node *node)
		midcomms_node_reset(node);
		pr_debug("switch node %d to state %s\n",
			 node->nodeid, dlm_state_str(node->state));
		wake_up(&node->shutdown_wait);
		break;
	case DLM_CLOSED:
		/* not valid but somehow we got what we want */
@@ -1274,7 +1268,7 @@ static void dlm_act_fin_ack_rcv(struct midcomms_node *node)
		break;
	default:
		spin_unlock(&node->state_lock);
		log_print("%s: unexpected state: %d\n",
		log_print("%s: unexpected state: %d",
			  __func__, node->state);
		WARN_ON_ONCE(1);
		return;
@@ -1362,11 +1356,11 @@ void dlm_midcomms_remove_member(int nodeid)
		case DLM_CLOSE_WAIT:
			/* passive shutdown DLM_LAST_ACK case 2 */
			node->state = DLM_LAST_ACK;
			spin_unlock(&node->state_lock);

			pr_debug("switch node %d to state %s case 2\n",
				 node->nodeid, dlm_state_str(node->state));
			goto send_fin;
			set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
			dlm_send_fin(node, dlm_pas_fin_ack_rcv);
			break;
		case DLM_LAST_ACK:
			/* probably receive fin caught it, do nothing */
			break;
@@ -1374,7 +1368,7 @@ void dlm_midcomms_remove_member(int nodeid)
			/* already gone, do nothing */
			break;
		default:
			log_print("%s: unexpected state: %d\n",
			log_print("%s: unexpected state: %d",
				  __func__, node->state);
			break;
		}
@@ -1382,12 +1376,6 @@ void dlm_midcomms_remove_member(int nodeid)
	spin_unlock(&node->state_lock);

	srcu_read_unlock(&nodes_srcu, idx);
	return;

send_fin:
	set_bit(DLM_NODE_FLAG_STOP_RX, &node->flags);
	dlm_send_fin(node, dlm_pas_fin_ack_rcv);
	srcu_read_unlock(&nodes_srcu, idx);
}

static void midcomms_node_release(struct rcu_head *rcu)
@@ -1395,9 +1383,31 @@ static void midcomms_node_release(struct rcu_head *rcu)
	struct midcomms_node *node = container_of(rcu, struct midcomms_node, rcu);

	WARN_ON_ONCE(atomic_read(&node->send_queue_cnt));
	dlm_send_queue_flush(node);
	kfree(node);
}

void dlm_midcomms_version_wait(void)
{
	struct midcomms_node *node;
	int i, idx, ret;

	idx = srcu_read_lock(&nodes_srcu);
	for (i = 0; i < CONN_HASH_SIZE; i++) {
		hlist_for_each_entry_rcu(node, &node_hash[i], hlist) {
			ret = wait_event_timeout(node->shutdown_wait,
						 node->version != DLM_VERSION_NOT_SET ||
						 node->state == DLM_CLOSED ||
						 test_bit(DLM_NODE_FLAG_CLOSE, &node->flags),
						 DLM_SHUTDOWN_TIMEOUT);
			if (!ret || test_bit(DLM_NODE_FLAG_CLOSE, &node->flags))
				pr_debug("version wait timed out for node %d with state %s\n",
					 node->nodeid, dlm_state_str(node->state));
		}
	}
	srcu_read_unlock(&nodes_srcu, idx);
}

static void midcomms_shutdown(struct midcomms_node *node)
{
	int ret;
@@ -1418,11 +1428,11 @@ static void midcomms_shutdown(struct midcomms_node *node)
		node->state = DLM_FIN_WAIT1;
		pr_debug("switch node %d to state %s case 2\n",
			 node->nodeid, dlm_state_str(node->state));
		dlm_send_fin(node, dlm_act_fin_ack_rcv);
		break;
	case DLM_CLOSED:
		/* we have what we want */
		spin_unlock(&node->state_lock);
		return;
		break;
	default:
		/* busy to enter DLM_FIN_WAIT1, wait until passive
		 * done in shutdown_wait to enter DLM_CLOSED.
@@ -1431,29 +1441,20 @@ static void midcomms_shutdown(struct midcomms_node *node)
	}
	spin_unlock(&node->state_lock);

	if (node->state == DLM_FIN_WAIT1) {
		dlm_send_fin(node, dlm_act_fin_ack_rcv);

	if (DLM_DEBUG_FENCE_TERMINATION)
		msleep(5000);
	}

	/* wait for other side dlm + fin */
	ret = wait_event_timeout(node->shutdown_wait,
				 node->state == DLM_CLOSED ||
				 test_bit(DLM_NODE_FLAG_CLOSE, &node->flags),
				 DLM_SHUTDOWN_TIMEOUT);
	if (!ret || test_bit(DLM_NODE_FLAG_CLOSE, &node->flags)) {
	if (!ret || test_bit(DLM_NODE_FLAG_CLOSE, &node->flags))
		pr_debug("active shutdown timed out for node %d with state %s\n",
			 node->nodeid, dlm_state_str(node->state));
		midcomms_node_reset(node);
		dlm_lowcomms_shutdown_node(node->nodeid, true);
		return;
	}

	else
		pr_debug("active shutdown done for node %d with state %s\n",
			 node->nodeid, dlm_state_str(node->state));
	dlm_lowcomms_shutdown_node(node->nodeid, false);
}

void dlm_midcomms_shutdown(void)
@@ -1461,8 +1462,6 @@ void dlm_midcomms_shutdown(void)
	struct midcomms_node *node;
	int i, idx;

	dlm_lowcomms_shutdown();

	mutex_lock(&close_lock);
	idx = srcu_read_lock(&nodes_srcu);
	for (i = 0; i < CONN_HASH_SIZE; i++) {
@@ -1480,6 +1479,8 @@ void dlm_midcomms_shutdown(void)
	}
	srcu_read_unlock(&nodes_srcu, idx);
	mutex_unlock(&close_lock);

	dlm_lowcomms_shutdown();
}

int dlm_midcomms_close(int nodeid)
Loading