diff --git a/drivers/pci/msi.c b/drivers/pci/msi.c
index 79e56906c7f1cb81da62c798a1fe2c0d2e3316ea..944e45e4a84f369636aad2f537c4f04c3855ccde 100644
--- a/drivers/pci/msi.c
+++ b/drivers/pci/msi.c
@@ -313,22 +313,22 @@ static void __pci_restore_msix_state(struct pci_dev *dev)
 
 	if (!dev->msix_enabled)
 		return;
+	BUG_ON(list_empty(&dev->msi_list));
+	entry = list_entry(dev->msi_list.next, struct msi_desc, list);
+	pos = entry->msi_attrib.pos;
+	pci_read_config_word(dev, pos + PCI_MSIX_FLAGS, &control);
 
 	/* route the table */
 	pci_intx_for_msi(dev, 0);
-	msix_set_enable(dev, 0);
+	control |= PCI_MSIX_FLAGS_ENABLE | PCI_MSIX_FLAGS_MASKALL;
+	pci_write_config_word(dev, pos + PCI_MSIX_FLAGS, control);
 
 	list_for_each_entry(entry, &dev->msi_list, list) {
 		write_msi_msg(entry->irq, &entry->msg);
 		msix_mask_irq(entry, entry->masked);
 	}
 
-	BUG_ON(list_empty(&dev->msi_list));
-	entry = list_entry(dev->msi_list.next, struct msi_desc, list);
-	pos = entry->msi_attrib.pos;
-	pci_read_config_word(dev, pos + PCI_MSIX_FLAGS, &control);
 	control &= ~PCI_MSIX_FLAGS_MASKALL;
-	control |= PCI_MSIX_FLAGS_ENABLE;
 	pci_write_config_word(dev, pos + PCI_MSIX_FLAGS, control);
 }
 
@@ -419,11 +419,14 @@ static int msix_capability_init(struct pci_dev *dev,
 	u8 bir;
 	void __iomem *base;
 
-	msix_set_enable(dev, 0);/* Ensure msix is disabled as I set it up */
-
    	pos = pci_find_capability(dev, PCI_CAP_ID_MSIX);
+	pci_read_config_word(dev, pos + PCI_MSIX_FLAGS, &control);
+
+	/* Ensure MSI-X is disabled while it is set up */
+	control &= ~PCI_MSIX_FLAGS_ENABLE;
+	pci_write_config_word(dev, pos + PCI_MSIX_FLAGS, control);
+
 	/* Request & Map MSI-X table region */
- 	pci_read_config_word(dev, msi_control_reg(pos), &control);
 	nr_entries = multi_msix_capable(control);
 
  	pci_read_config_dword(dev, msix_table_offset_reg(pos), &table_offset);
@@ -434,7 +437,6 @@ static int msix_capability_init(struct pci_dev *dev,
 	if (base == NULL)
 		return -ENOMEM;
 
-	/* MSI-X Table Initialization */
 	for (i = 0; i < nvec; i++) {
 		entry = alloc_msi_entry(dev);
 		if (!entry)
@@ -447,7 +449,6 @@ static int msix_capability_init(struct pci_dev *dev,
 		entry->msi_attrib.default_irq = dev->irq;
 		entry->msi_attrib.pos = pos;
 		entry->mask_base = base;
-		msix_mask_irq(entry, 1);
 
 		list_add_tail(&entry->list, &dev->msi_list);
 	}
@@ -472,22 +473,31 @@ static int msix_capability_init(struct pci_dev *dev,
 		return ret;
 	}
 
+	/*
+	 * Some devices require MSI-X to be enabled before we can touch the
+	 * MSI-X registers.  We need to mask all the vectors to prevent
+	 * interrupts coming in before they're fully set up.
+	 */
+	control |= PCI_MSIX_FLAGS_MASKALL | PCI_MSIX_FLAGS_ENABLE;
+	pci_write_config_word(dev, pos + PCI_MSIX_FLAGS, control);
+
 	i = 0;
 	list_for_each_entry(entry, &dev->msi_list, list) {
 		entries[i].vector = entry->irq;
 		set_irq_msi(entry->irq, entry);
+		j = entries[i].entry;
+		entry->masked = readl(base + j * PCI_MSIX_ENTRY_SIZE +
+					PCI_MSIX_ENTRY_VECTOR_CTRL_OFFSET);
+		msix_mask_irq(entry, 1);
 		i++;
 	}
-	/* Set MSI-X enabled bits */
+
+	/* Set MSI-X enabled bits and unmask the function */
 	pci_intx_for_msi(dev, 0);
-	msix_set_enable(dev, 1);
 	dev->msix_enabled = 1;
 
-	list_for_each_entry(entry, &dev->msi_list, list) {
-		int vector = entry->msi_attrib.entry_nr;
-		entry->masked = readl(base + vector * PCI_MSIX_ENTRY_SIZE +
-					PCI_MSIX_ENTRY_VECTOR_CTRL_OFFSET);
-	}
+	control &= ~PCI_MSIX_FLAGS_MASKALL;
+	pci_write_config_word(dev, pos + PCI_MSIX_FLAGS, control);
 
 	return 0;
 }