Consolidate region type handling into the region creation path to
simplify the creation flow. Type determination (MMIO/pinned/movable)
now happens in mshv_partition_create_region() before the region is
allocated, and mshv_region_create() takes the type as a parameter and
initializes the type-specific fields during construction rather than
afterwards.

This eliminates the need for mshv_region_movable_init/fini() by
handling MMU interval notifier setup directly in the constructor and
teardown in the destructor. Region mapping is also unified through a
single mshv_map_region() dispatcher that routes to the appropriate
type-specific handler.

Changes improve code organization by:
- Reducing API surface (net four fewer functions declared in mshv_root.h)
- Centralizing type determination and validation
- Making region lifecycle more explicit and easier to follow
- Removing post-construction initialization steps

Apart from the change described below, the refactoring preserves
existing behaviour while making the code easier to maintain and less
error-prone.

One behavioural change: creating a movable region now fails if
mmu_interval_notifier_insert() returns an error, and that error is
propagated to the caller instead of the region silently falling back
to pinned memory. This fail-fast approach makes configuration issues
more visible.

Signed-off-by: Stanislav Kinsburskii <[email protected]>
---
 drivers/hv/mshv_regions.c   |   81 ++++++++++++++++++++++++++++---------------
 drivers/hv/mshv_root.h      |   14 +++----
 drivers/hv/mshv_root_main.c |   61 +++++++++++++-------------------
 3 files changed, 83 insertions(+), 73 deletions(-)

diff --git a/drivers/hv/mshv_regions.c b/drivers/hv/mshv_regions.c
index 6b703b269a4f..a85d18e2c279 100644
--- a/drivers/hv/mshv_regions.c
+++ b/drivers/hv/mshv_regions.c
@@ -20,6 +20,8 @@
 #define MSHV_MAP_FAULT_IN_PAGES                                PTRS_PER_PMD
 #define MSHV_INVALID_PFN                               ULONG_MAX
 
+static const struct mmu_interval_notifier_ops mshv_region_mni_ops;
+
 /**
  * mshv_chunk_stride - Compute stride for mapping guest memory
  * @page      : The page to check for huge page backing
@@ -241,16 +243,31 @@ static int mshv_region_process_range(struct mshv_mem_region *region,
        return 0;
 }
 
-struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pfns,
-                                          u64 uaddr, u32 flags)
+struct mshv_mem_region *mshv_region_create(enum mshv_region_type type,
+                                          u64 guest_pfn, u64 nr_pfns,
+                                          u64 uaddr, u32 flags,
+                                          ulong mmio_pfn)
 {
        struct mshv_mem_region *region;
+       int ret;
        u64 i;
 
+       switch (type) {
+       case MSHV_REGION_TYPE_MEM_MOVABLE:
+       case MSHV_REGION_TYPE_MEM_PINNED:
+       case MSHV_REGION_TYPE_MMIO:
+               break;
+       default:
+               return ERR_PTR(-EINVAL);
+       }
+
        region = vzalloc(sizeof(*region) + sizeof(unsigned long) * nr_pfns);
        if (!region)
                return ERR_PTR(-ENOMEM);
 
+       region->mreg_type = type;
+       if (type == MSHV_REGION_TYPE_MMIO)
+               region->mreg_mmio_pfn = mmio_pfn;
        region->nr_pfns = nr_pfns;
        region->start_gfn = guest_pfn;
        region->start_uaddr = uaddr;
@@ -263,9 +280,30 @@ struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pfns,
        for (i = 0; i < nr_pfns; i++)
                region->mreg_pfns[i] = MSHV_INVALID_PFN;
 
+       mutex_init(&region->mreg_mutex);
        kref_init(&region->mreg_refcount);
 
+       /*
+        * Register the notifier only after the region is fully initialised:
+        * once inserted, invalidate callbacks may run concurrently (e.g. an
+        * munmap() from another thread) and may access the region.
+        * NOTE(review): region->partition is assigned by the caller after
+        * this returns; confirm the invalidate path tolerates it being NULL.
+        */
+       if (type == MSHV_REGION_TYPE_MEM_MOVABLE) {
+               ret = mmu_interval_notifier_insert(&region->mreg_mni,
+                                                  current->mm, uaddr,
+                                                  nr_pfns << HV_HYP_PAGE_SHIFT,
+                                                  &mshv_region_mni_ops);
+               if (ret)
+                       goto free_region;
+       }
+
        return region;
+
+free_region:
+       vfree(region);
+       return ERR_PTR(ret);
 }
 
 static int mshv_region_chunk_share(struct mshv_mem_region *region,
@@ -462,7 +492,7 @@ static void mshv_region_destroy(struct kref *ref)
        int ret;
 
        if (region->mreg_type == MSHV_REGION_TYPE_MEM_MOVABLE)
-               mshv_region_movable_fini(region);
+               mmu_interval_notifier_remove(&region->mreg_mni);
 
        if (mshv_partition_encrypted(partition)) {
                ret = mshv_region_share(region);
@@ -736,27 +766,6 @@ static const struct mmu_interval_notifier_ops 
mshv_region_mni_ops = {
        .invalidate = mshv_region_interval_invalidate,
 };
 
-void mshv_region_movable_fini(struct mshv_mem_region *region)
-{
-       mmu_interval_notifier_remove(&region->mreg_mni);
-}
-
-bool mshv_region_movable_init(struct mshv_mem_region *region)
-{
-       int ret;
-
-       ret = mmu_interval_notifier_insert(&region->mreg_mni, current->mm,
-                                          region->start_uaddr,
-                                          region->nr_pfns << HV_HYP_PAGE_SHIFT,
-                                          &mshv_region_mni_ops);
-       if (ret)
-               return false;
-
-       mutex_init(&region->mreg_mutex);
-
-       return true;
-}
-
 /**
  * mshv_map_pinned_region - Pin and map memory regions
  * @region: Pointer to the memory region structure
@@ -770,7 +779,7 @@ bool mshv_region_movable_init(struct mshv_mem_region 
*region)
  *
  * Return: 0 on success, negative error code on failure.
  */
-int mshv_map_pinned_region(struct mshv_mem_region *region)
+static int mshv_map_pinned_region(struct mshv_mem_region *region)
 {
        struct mshv_partition *partition = region->partition;
        int ret;
@@ -826,17 +835,41 @@ int mshv_map_pinned_region(struct mshv_mem_region *region)
        return ret;
 }
 
-int mshv_map_movable_region(struct mshv_mem_region *region)
+static int mshv_map_movable_region(struct mshv_mem_region *region)
 {
        return mshv_region_collect_and_map(region, 0, region->nr_pfns,
                                           false);
 }
 
-int mshv_map_mmio_region(struct mshv_mem_region *region,
-                        unsigned long mmio_pfn)
+static int mshv_map_mmio_region(struct mshv_mem_region *region)
 {
        struct mshv_partition *partition = region->partition;
 
        return hv_call_map_mmio_pfns(partition->pt_id, region->start_gfn,
-                                    mmio_pfn, region->nr_pfns);
+                                    region->mreg_mmio_pfn,
+                                    region->nr_pfns);
+}
+
+/**
+ * mshv_map_region - Map a memory region into its partition
+ * @region: Region to map; region->mreg_type selects the mapping helper
+ *
+ * Dispatches to the pinned, movable or MMIO mapping routine according to
+ * the type recorded when the region was created by mshv_region_create().
+ *
+ * Return: the result of the type-specific helper, or -EINVAL if the
+ * region type is not recognised.
+ */
+int mshv_map_region(struct mshv_mem_region *region)
+{
+       switch (region->mreg_type) {
+       case MSHV_REGION_TYPE_MEM_PINNED:
+               return mshv_map_pinned_region(region);
+       case MSHV_REGION_TYPE_MEM_MOVABLE:
+               return mshv_map_movable_region(region);
+       case MSHV_REGION_TYPE_MMIO:
+               return mshv_map_mmio_region(region);
+       }
+
+       return -EINVAL;
 }
diff --git a/drivers/hv/mshv_root.h b/drivers/hv/mshv_root.h
index 1f92b9f85b60..2bcdfa070517 100644
--- a/drivers/hv/mshv_root.h
+++ b/drivers/hv/mshv_root.h
@@ -92,6 +92,7 @@ struct mshv_mem_region {
        enum mshv_region_type mreg_type;
        struct mmu_interval_notifier mreg_mni;
        struct mutex mreg_mutex;        /* protects region PFNs remapping */
+       u64 mreg_mmio_pfn;
        unsigned long mreg_pfns[];
 };
 
@@ -366,16 +367,13 @@ extern struct mshv_root mshv_root;
 extern enum hv_scheduler_type hv_scheduler_type;
 extern u8 * __percpu *hv_synic_eventring_tail;
 
-struct mshv_mem_region *mshv_region_create(u64 guest_pfn, u64 nr_pages,
-                                          u64 uaddr, u32 flags);
+struct mshv_mem_region *mshv_region_create(enum mshv_region_type type,
+                                          u64 guest_pfn, u64 nr_pfns,
+                                          u64 uaddr, u32 flags,
+                                          ulong mmio_pfn);
 void mshv_region_put(struct mshv_mem_region *region);
 int mshv_region_get(struct mshv_mem_region *region);
 bool mshv_region_handle_gfn_fault(struct mshv_mem_region *region, u64 gfn);
-void mshv_region_movable_fini(struct mshv_mem_region *region);
-bool mshv_region_movable_init(struct mshv_mem_region *region);
-int mshv_map_pinned_region(struct mshv_mem_region *region);
-int mshv_map_movable_region(struct mshv_mem_region *region);
-int mshv_map_mmio_region(struct mshv_mem_region *region,
-                        unsigned long mmio_pfn);
+int mshv_map_region(struct mshv_mem_region *region);
 
 #endif /* _MSHV_ROOT_H_ */
diff --git a/drivers/hv/mshv_root_main.c b/drivers/hv/mshv_root_main.c
index adb09350205a..3bfa9e9c575f 100644
--- a/drivers/hv/mshv_root_main.c
+++ b/drivers/hv/mshv_root_main.c
@@ -1217,11 +1217,14 @@ static void mshv_async_hvcall_handler(void *data, u64 
*status)
  */
 static int mshv_partition_create_region(struct mshv_partition *partition,
                                        struct mshv_user_mem_region *mem,
-                                       struct mshv_mem_region **regionpp,
-                                       bool is_mmio)
+                                       struct mshv_mem_region **regionpp)
 {
        struct mshv_mem_region *rg;
+       enum mshv_region_type type;
        u64 nr_pfns = HVPFN_DOWN(mem->size);
+       struct vm_area_struct *vma;
+       ulong mmio_pfn;
+       bool is_mmio;
 
        /* Reject overlapping regions */
        spin_lock(&partition->pt_mem_regions_lock);
@@ -1234,18 +1237,27 @@ static int mshv_partition_create_region(struct mshv_partition *partition,
        }
        spin_unlock(&partition->pt_mem_regions_lock);
 
-       rg = mshv_region_create(mem->guest_pfn, nr_pfns,
-                               mem->userspace_addr, mem->flags);
-       if (IS_ERR(rg))
-               return PTR_ERR(rg);
+       mmap_read_lock(current->mm);
+       vma = vma_lookup(current->mm, mem->userspace_addr);
+       is_mmio = vma && (vma->vm_flags & (VM_IO | VM_PFNMAP));
+       mmio_pfn = is_mmio ? vma->vm_pgoff : 0;
+       mmap_read_unlock(current->mm);
+
+       if (!vma)
+               return -EINVAL;
 
        if (is_mmio)
-               rg->mreg_type = MSHV_REGION_TYPE_MMIO;
-       else if (mshv_partition_encrypted(partition) ||
-                !mshv_region_movable_init(rg))
-               rg->mreg_type = MSHV_REGION_TYPE_MEM_PINNED;
+               type = MSHV_REGION_TYPE_MMIO;
+       else if (mshv_partition_encrypted(partition))
+               type = MSHV_REGION_TYPE_MEM_PINNED;
        else
-               rg->mreg_type = MSHV_REGION_TYPE_MEM_MOVABLE;
+               type = MSHV_REGION_TYPE_MEM_MOVABLE;
+
+       rg = mshv_region_create(type, mem->guest_pfn, nr_pfns,
+                               mem->userspace_addr, mem->flags,
+                               mmio_pfn);
+       if (IS_ERR(rg))
+               return PTR_ERR(rg);
 
        rg->partition = partition;
 
@@ -1271,40 +1283,17 @@ mshv_map_user_memory(struct mshv_partition *partition,
                     struct mshv_user_mem_region mem)
 {
        struct mshv_mem_region *region;
-       struct vm_area_struct *vma;
-       bool is_mmio;
-       ulong mmio_pfn;
        long ret;
 
        if (mem.flags & BIT(MSHV_SET_MEM_BIT_UNMAP) ||
            !access_ok((const void __user *)mem.userspace_addr, mem.size))
                return -EINVAL;
 
-       mmap_read_lock(current->mm);
-       vma = vma_lookup(current->mm, mem.userspace_addr);
-       is_mmio = vma ? !!(vma->vm_flags & (VM_IO | VM_PFNMAP)) : 0;
-       mmio_pfn = is_mmio ? vma->vm_pgoff : 0;
-       mmap_read_unlock(current->mm);
-
-       if (!vma)
-               return -EINVAL;
-
-       ret = mshv_partition_create_region(partition, &mem, &region,
-                                          is_mmio);
+       ret = mshv_partition_create_region(partition, &mem, &region);
        if (ret)
                return ret;
 
-       switch (region->mreg_type) {
-       case MSHV_REGION_TYPE_MEM_PINNED:
-               ret = mshv_map_pinned_region(region);
-               break;
-       case MSHV_REGION_TYPE_MEM_MOVABLE:
-               ret = mshv_map_movable_region(region);
-               break;
-       case MSHV_REGION_TYPE_MMIO:
-               ret = mshv_map_mmio_region(region, mmio_pfn);
-               break;
-       }
+       ret = mshv_map_region(region);
 
        trace_mshv_map_user_memory(partition->pt_id, region->start_uaddr,
                                   region->start_gfn, region->nr_pfns,



Reply via email to