On Wed, Nov 20, 2024 at 12:34:50PM +0530, Shijith Thotton wrote:
> @@ -63,44 +80,53 @@ static irqreturn_t octep_vdpa_intr_handler(int irq, void *data)
>  static void octep_free_irqs(struct octep_hw *oct_hw)
>  {
>       struct pci_dev *pdev = oct_hw->pdev;
> +     int irq;
> +
> +     for (irq = 0; irq < oct_hw->nb_irqs && oct_hw->irqs; irq++) {
> +             if (oct_hw->irqs[irq] < 0)
> +                     continue;
>  
> -     if (oct_hw->irq != -1) {
> -             devm_free_irq(&pdev->dev, oct_hw->irq, oct_hw);
> -             oct_hw->irq = -1;
> +             devm_free_irq(&pdev->dev, oct_hw->irqs[irq], oct_hw);
>       }
> +
>       pci_free_irq_vectors(pdev);
> +     kfree(oct_hw->irqs);

You should add:

        oct_hw->nb_irqs = 0;
        oct_hw->irqs = NULL;

Otherwise, if reset is called twice in a row without the IRQs being
re-initialized in between, the second call reads and kfree()s the
already-freed oct_hw->irqs array: a use-after-free and a double free.

>  }
>  
>  static int octep_request_irqs(struct octep_hw *oct_hw)
>  {
>       struct pci_dev *pdev = oct_hw->pdev;
> -     int ret, irq;
> +     int ret, irq, idx;
>  
> -     /* Currently HW device provisions one IRQ per VF, hence
> -      * allocate one IRQ for all virtqueues call interface.
> -      */
> -     ret = pci_alloc_irq_vectors(pdev, 1, 1, PCI_IRQ_MSIX);
> +     ret = pci_alloc_irq_vectors(pdev, 1, oct_hw->nb_irqs, PCI_IRQ_MSIX);
>       if (ret < 0) {
>               dev_err(&pdev->dev, "Failed to alloc msix vector");
>               return ret;
>       }
>  
> -     snprintf(oct_hw->vqs->msix_name, sizeof(oct_hw->vqs->msix_name),
> -              OCTEP_VDPA_DRIVER_NAME "-vf-%d", pci_iov_vf_id(pdev));
> +     oct_hw->irqs = kcalloc(oct_hw->nb_irqs, sizeof(int), GFP_KERNEL);

This isn't freed on the ->release() path.  octep_free_irqs() is called on
reset(), but we rely on devm_ to free the IRQs on ->release().  Use
devm_kcalloc() here as well, probably (and then use devm_kfree() instead
of kfree() in octep_free_irqs()).

> +     if (!oct_hw->irqs) {
> +             ret = -ENOMEM;
> +             goto free_irqs;
> +     }
>  
> -     irq = pci_irq_vector(pdev, 0);
> -     ret = devm_request_irq(&pdev->dev, irq, octep_vdpa_intr_handler, 0,
> -                            oct_hw->vqs->msix_name, oct_hw);
> -     if (ret) {
> -             dev_err(&pdev->dev, "Failed to register interrupt handler\n");
> -             goto free_irq_vec;
> +     memset(oct_hw->irqs, -1, sizeof(oct_hw->irqs));

This works, but it would be more normal to just leave it zeroed and check for
zero instead of checking for negatives.  There is never a zero IRQ.  See my blog
for more details:
https://staticthinking.wordpress.com/2023/08/07/writing-a-check-for-zero-irq-error-codes/

regards,
dan carpenter

> +
> +     for (idx = 0; idx < oct_hw->nb_irqs; idx++) {
> +             irq = pci_irq_vector(pdev, idx);
> +             ret = devm_request_irq(&pdev->dev, irq, octep_vdpa_intr_handler, 0,
> +                                    dev_name(&pdev->dev), oct_hw);
> +             if (ret) {
> +                     dev_err(&pdev->dev, "Failed to register interrupt handler\n");
> +                     goto free_irqs;
> +             }
> +             oct_hw->irqs[idx] = irq;
>       }
> -     oct_hw->irq = irq;
>  
>       return 0;
>  
> -free_irq_vec:
> -     pci_free_irq_vectors(pdev);
> +free_irqs:
> +     octep_free_irqs(oct_hw);
>       return ret;
>  }
>  

regards,
dan carpenter

Reply via email to