When the device offers DEVICE_INIT_ON_INFLATE (bit 6), the device
initializes inflated pages and returns a per-page bitmap indicating
which pages were successfully initialized.

The driver appends a device-writable bitmap buffer to each inflate
descriptor chain via virtqueue_add_sgs(). After the host acknowledges,
the driver checks the bitmap bits (bounded by used_len) and marks a
page with __SetPageZeroed() only if all of its
VIRTIO_BALLOON_PAGES_PER_PAGE bits are set.

tell_host() returns used_len from virtqueue_get_buf(). Bitmap reads
are bounded: fill_balloon() and virtballoon_migratepage() only trust
bits within the used_len range.

On deflate, release_pages_balloon checks PageZeroed per page and
uses put_page_zeroed for pages the host initialized, propagating
the zeroed hint to the buddy allocator.

If the feature is negotiated but inflate_vq has fewer than 2
descriptors, probe fails with -ENOSPC. If PAGE_POISON is negotiated
and init_on_free is not enabled (i.e. poison_val would be non-zero),
the feature is cleared in validate().

See the virtio spec change:
https://lore.kernel.org/all/9c69b992c3dd83dfef3db92cd86b2fd8a0730d48.1777731396.git....@redhat.com

Signed-off-by: Michael S. Tsirkin <[email protected]>
Assisted-by: Claude:claude-opus-4-6
Assisted-by: cursor-agent:GPT-5.4-xhigh
---
 drivers/virtio/virtio_balloon.c     | 97 +++++++++++++++++++++++++----
 include/uapi/linux/virtio_balloon.h |  1 +
 2 files changed, 87 insertions(+), 11 deletions(-)

diff --git a/drivers/virtio/virtio_balloon.c b/drivers/virtio/virtio_balloon.c
index e99ffbbdd2bd..708b0c344ae9 100644
--- a/drivers/virtio/virtio_balloon.c
+++ b/drivers/virtio/virtio_balloon.c
@@ -120,6 +120,9 @@ struct virtio_balloon {
        struct virtqueue *reporting_vq;
        struct page_reporting_dev_info pr_dev_info;
 
+       /* Bitmap returned by host for DEVICE_INIT_ON_INFLATE */
+       DECLARE_BITMAP(inflate_bitmap, VIRTIO_BALLOON_ARRAY_PFNS_MAX);
+
        /* State for keeping the wakeup_source active while adjusting the 
balloon */
        spinlock_t wakeup_lock;
        bool processing_wakeup_event;
@@ -180,20 +183,30 @@ static void balloon_ack(struct virtqueue *vq)
        wake_up(&vb->acked);
 }
 
-static void tell_host(struct virtio_balloon *vb, struct virtqueue *vq)
+static unsigned int tell_host(struct virtio_balloon *vb, struct virtqueue *vq)
 {
-       struct scatterlist sg;
+       struct scatterlist sg_out, sg_in;
+       struct scatterlist *sgs[] = { &sg_out, &sg_in };
        unsigned int len;
 
-       sg_init_one(&sg, vb->pfns, sizeof(vb->pfns[0]) * vb->num_pfns);
+       sg_init_one(&sg_out, vb->pfns, sizeof(vb->pfns[0]) * vb->num_pfns);
 
-       /* We should always be able to add one buffer to an empty queue. */
-       virtqueue_add_outbuf(vq, &sg, 1, vb, GFP_KERNEL);
+       if (vq == vb->inflate_vq &&
+           virtio_has_feature(vb->vdev,
+                              VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE)) {
+               unsigned int bitmap_bytes;
+
+               bitmap_bytes = DIV_ROUND_UP(vb->num_pfns, 8);
+               bitmap_zero(vb->inflate_bitmap, vb->num_pfns);
+               sg_init_one(&sg_in, vb->inflate_bitmap, bitmap_bytes);
+               virtqueue_add_sgs(vq, sgs, 1, 1, vb, GFP_KERNEL);
+       } else {
+               virtqueue_add_outbuf(vq, &sg_out, 1, vb, GFP_KERNEL);
+       }
        virtqueue_kick(vq);
 
-       /* When host has read buffer, this completes via balloon_ack */
        wait_event(vb->acked, virtqueue_get_buf(vq, &len));
-
+       return len;
 }
 
 static int virtballoon_free_page_report(struct page_reporting_dev_info 
*pr_dev_info,
@@ -290,8 +303,37 @@ static unsigned int fill_balloon(struct virtio_balloon 
*vb, size_t num)
 
        num_allocated_pages = vb->num_pfns;
        /* Did we get any? */
-       if (vb->num_pfns != 0)
-               tell_host(vb, vb->inflate_vq);
+       if (vb->num_pfns != 0) {
+               unsigned int used_len = tell_host(vb, vb->inflate_vq);
+
+               if (virtio_has_feature(vb->vdev,
+                                      
VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE)) {
+                       unsigned int i;
+                       unsigned int valid_bits = used_len * 8;
+
+                       for (i = 0; i < vb->num_pfns;
+                            i += VIRTIO_BALLOON_PAGES_PER_PAGE) {
+                               unsigned int pfn, j;
+                               bool zeroed = true;
+
+                               if (i + VIRTIO_BALLOON_PAGES_PER_PAGE > 
valid_bits)
+                                       break;
+                               for (j = 0; j < VIRTIO_BALLOON_PAGES_PER_PAGE; 
j++) {
+                                       if (!test_bit(i + j, 
vb->inflate_bitmap)) {
+                                               zeroed = false;
+                                               break;
+                                       }
+                               }
+                               if (zeroed) {
+                                       pfn = virtio32_to_cpu(vb->vdev,
+                                                             vb->pfns[i]);
+                                       __SetPageZeroed(pfn_to_page(pfn >>
+                                               (PAGE_SHIFT -
+                                                VIRTIO_BALLOON_PFN_SHIFT)));
+                               }
+                       }
+               }
+       }
        mutex_unlock(&vb->balloon_lock);
 
        return num_allocated_pages;
@@ -304,7 +346,12 @@ static void release_pages_balloon(struct virtio_balloon 
*vb,
 
        list_for_each_entry_safe(page, next, pages, lru) {
                list_del(&page->lru);
-               put_page(page); /* balloon reference */
+               if (PageZeroed(page)) {
+                       __ClearPageZeroed(page);
+                       put_page_zeroed(page);
+               } else {
+                       put_page(page);
+               }
        }
 }
 
@@ -851,7 +898,25 @@ static int virtballoon_migratepage(struct balloon_dev_info 
*vb_dev_info,
        /* balloon's page migration 1st step  -- inflate "newpage" */
        vb->num_pfns = VIRTIO_BALLOON_PAGES_PER_PAGE;
        set_page_pfns(vb, vb->pfns, newpage);
-       tell_host(vb, vb->inflate_vq);
+       {
+               unsigned int used_len = tell_host(vb, vb->inflate_vq);
+
+               if (virtio_has_feature(vb->vdev,
+                                      VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE) 
&&
+                   used_len >= DIV_ROUND_UP(VIRTIO_BALLOON_PAGES_PER_PAGE, 8)) 
{
+                       unsigned int j;
+                       bool zeroed = true;
+
+                       for (j = 0; j < VIRTIO_BALLOON_PAGES_PER_PAGE; j++) {
+                               if (!test_bit(j, vb->inflate_bitmap)) {
+                                       zeroed = false;
+                                       break;
+                               }
+                       }
+                       if (zeroed)
+                               __SetPageZeroed(newpage);
+               }
+       }
 
        /* balloon's page migration 2nd step -- deflate "page" */
        vb->num_pfns = VIRTIO_BALLOON_PAGES_PER_PAGE;
@@ -956,6 +1021,12 @@ static int virtballoon_probe(struct virtio_device *vdev)
        if (err)
                goto out_free_vb;
 
+       if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE) &&
+           virtqueue_get_vring_size(vb->inflate_vq) < 2) {
+               err = -ENOSPC;
+               goto out_del_vqs;
+       }
+
        if (!virtio_has_feature(vb->vdev, VIRTIO_BALLOON_F_DEFLATE_ON_OOM))
                vb->vb_dev_info.adjust_managed_page_count = true;
 #ifdef CONFIG_BALLOON_MIGRATION
@@ -1171,6 +1242,9 @@ static int virtballoon_validate(struct virtio_device 
*vdev)
        else if (!virtio_has_feature(vdev, VIRTIO_BALLOON_F_PAGE_POISON))
                __virtio_clear_bit(vdev, VIRTIO_BALLOON_F_REPORTING);
 
+       if (virtio_has_feature(vdev, VIRTIO_BALLOON_F_PAGE_POISON) &&
+           !want_init_on_free())
+               __virtio_clear_bit(vdev, 
VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE);
        __virtio_clear_bit(vdev, VIRTIO_F_ACCESS_PLATFORM);
        return 0;
 }
@@ -1182,6 +1256,7 @@ static unsigned int features[] = {
        VIRTIO_BALLOON_F_FREE_PAGE_HINT,
        VIRTIO_BALLOON_F_PAGE_POISON,
        VIRTIO_BALLOON_F_REPORTING,
+       VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE,
 };
 
 static struct virtio_driver virtio_balloon_driver = {
diff --git a/include/uapi/linux/virtio_balloon.h 
b/include/uapi/linux/virtio_balloon.h
index ee35a372805d..d129736cc3a8 100644
--- a/include/uapi/linux/virtio_balloon.h
+++ b/include/uapi/linux/virtio_balloon.h
@@ -37,6 +37,7 @@
 #define VIRTIO_BALLOON_F_FREE_PAGE_HINT        3 /* VQ to report free pages */
 #define VIRTIO_BALLOON_F_PAGE_POISON   4 /* Guest is using page poisoning */
 #define VIRTIO_BALLOON_F_REPORTING     5 /* Page reporting virtqueue */
+#define VIRTIO_BALLOON_F_DEVICE_INIT_ON_INFLATE        6 /* Device initializes 
pages on inflate */
 
 /* Size of a PFN in the balloon interface. */
 #define VIRTIO_BALLOON_PFN_SHIFT 12
-- 
MST


Reply via email to