On Wed, 24 Mar 2021 21:05:52 -0400
Daniel Jordan <daniel.m.jor...@oracle.com> wrote:

> When vfio_pin_pages_remote() returns with a partial batch consisting of
> a single VM_PFNMAP pfn, a subsequent call will unfortunately try
> restoring it from batch->pages, resulting in vfio mapping the wrong page
> and unbalancing the page refcount.
> 
> Prevent the function from returning with this kind of partial batch to
> avoid the issue.  There's no explicit check for a VM_PFNMAP pfn because
> it's awkward to do so, so infer it from characteristics of the batch
> instead.  This may result in occasional false positives but keeps the
> code simpler.
> 
> Fixes: 4d83de6da265 ("vfio/type1: Batch page pinning")
> Link: https://lkml.kernel.org/r/20210323133254.33ed9...@omen.home.shazbot.org/
> Reported-by: Alex Williamson <alex.william...@redhat.com>
> Suggested-by: Alex Williamson <alex.william...@redhat.com>
> Signed-off-by: Daniel Jordan <daniel.m.jor...@oracle.com>
> ---
> 
> Alex, I couldn't immediately find a way to trigger this bug, but I can
> run your test case if you like.
> 
> This is the minimal fix, but it should still protect all calls of
> vfio_batch_unpin() from this problem.

Thanks, applied to my for-linus branch for v5.12.  The attached unit
test triggers the issue.  I don't have any real-world examples; I was
only just experimenting with this for another series earlier this week.
Thanks,

Alex
/*
 * Alternate pages of device memory and anonymous memory within a single DMA
 * mapping.
 *
 * Run with argv[1] as a fully specified PCI device already bound to vfio-pci,
 * e.g. "alternate-pfnmap 0000:01:00.0"
 */
#include <errno.h>
#include <libgen.h>
#include <fcntl.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/eventfd.h>
#include <sys/ioctl.h>
#include <sys/mman.h>
#include <sys/stat.h>
#include <sys/time.h>
#include <sys/types.h>

#include <linux/ioctl.h>
#include <linux/vfio.h>
#include <linux/pci_regs.h>

/*
 * Address hint for the mappings: the mmap helpers in this file request
 * vaddr + map_size so that successive pages land back to back.
 */
void *vaddr = (void *)0x100000000;
/* Bytes mapped so far; grows by one page after each successful mapping. */
size_t map_size = 0;

/*
 * Open the VFIO container node.
 *
 * Returns the descriptor handed back by open(); it is negative on failure,
 * in which case a diagnostic has already been written to stderr.
 */
int get_container(void)
{
	int fd;

	fd = open("/dev/vfio/vfio", O_RDWR);
	if (fd >= 0)
		return fd;

	fprintf(stderr, "Failed to open /dev/vfio/vfio, %d (%s)\n",
		fd, strerror(errno));
	return fd;
}

/*
 * Open the VFIO group that contains the PCI device @name.
 *
 * @name: PCI address in "ssss:bb:dd.f" form.
 *
 * The device's iommu_group symlink in sysfs is resolved to a group number,
 * /dev/vfio/<group> is opened, and the group is checked for viability.
 *
 * Returns the open group descriptor on success or a negative value on
 * failure (a diagnostic is printed to stderr).  The descriptor is closed
 * again on every failure path that occurs after it was opened.
 */
int get_group(char *name)
{
	unsigned int seg, bus, slot;
	int func;
	int ret, group, groupid;
	char path[50], iommu_group_path[50], *group_name;
	struct stat st;
	ssize_t len;
	struct vfio_group_status group_status = {
		.argsz = sizeof(group_status)
	};

	/* A missing argument must not reach sscanf() */
	if (!name) {
		fprintf(stderr, "Invalid device\n");
		return -EINVAL;
	}

	ret = sscanf(name, "%04x:%02x:%02x.%d", &seg, &bus, &slot, &func);
	if (ret != 4) {
		fprintf(stderr, "Invalid device\n");
		return -EINVAL;
	}

	snprintf(path, sizeof(path),
		 "/sys/bus/pci/devices/%04x:%02x:%02x.%01x/",
		 seg, bus, slot, (unsigned int)func);

	ret = stat(path, &st);
	if (ret < 0) {
		fprintf(stderr, "No such device\n");
		return ret;
	}

	strncat(path, "iommu_group", sizeof(path) - strlen(path) - 1);

	/*
	 * readlink() does not NUL-terminate and may fill the whole buffer,
	 * so leave one byte free for the terminator written below.
	 */
	len = readlink(path, iommu_group_path, sizeof(iommu_group_path) - 1);
	if (len <= 0) {
		fprintf(stderr, "No iommu_group for device\n");
		return -EINVAL;
	}

	iommu_group_path[len] = 0;
	group_name = basename(iommu_group_path);

	if (sscanf(group_name, "%d", &groupid) != 1) {
		fprintf(stderr, "Unknown group\n");
		return -EINVAL;
	}

	snprintf(path, sizeof(path), "/dev/vfio/%d", groupid);
	group = open(path, O_RDWR);
	if (group < 0) {
		fprintf(stderr, "Failed to open %s, %d (%s)\n",
		       path, group, strerror(errno));
		return group;
	}

	ret = ioctl(group, VFIO_GROUP_GET_STATUS, &group_status);
	if (ret) {
		fprintf(stderr, "ioctl(VFIO_GROUP_GET_STATUS) failed\n");
		close(group);
		return ret;
	}

	if (!(group_status.flags & VFIO_GROUP_FLAGS_VIABLE)) {
		fprintf(stderr,
			"Group not viable, all devices attached to vfio?\n");
		close(group);
		return -1;
	}

	return group;
}

/*
 * Issue VFIO_GROUP_SET_CONTAINER on @group, passing a pointer to the
 * @container descriptor.
 *
 * Returns the ioctl result unchanged; a non-zero result is also reported
 * on stderr.
 */
int group_set_container(int group, int container)
{
	int rc;

	rc = ioctl(group, VFIO_GROUP_SET_CONTAINER, &container);
	if (rc != 0)
		fprintf(stderr, "Failed to set group container\n");

	return rc;
}

/*
 * Issue VFIO_SET_IOMMU with VFIO_TYPE1_IOMMU on @container.
 *
 * Returns the ioctl result unchanged; a non-zero result is also reported
 * on stderr.
 */
int container_set_iommu(int container)
{
	int rc;

	rc = ioctl(container, VFIO_SET_IOMMU, VFIO_TYPE1_IOMMU);
	if (rc != 0)
		fprintf(stderr, "Failed to set IOMMU\n");

	return rc;
}

/*
 * Issue VFIO_GROUP_GET_DEVICE_FD on @group for the device called @name.
 *
 * Returns the ioctl result unchanged; a negative result is also reported
 * on stderr.
 */
int group_get_device(int group, char *name)
{
	int fd;

	fd = ioctl(group, VFIO_GROUP_GET_DEVICE_FD, name);
	if (fd >= 0)
		return fd;

	fprintf(stderr, "Failed to get device\n");
	return fd;
}

void *mmap_device_page(int device, int prot)
{
	struct vfio_region_info config_info = {
		.argsz = sizeof(config_info),
		.index = VFIO_PCI_CONFIG_REGION_INDEX
	};
	struct vfio_region_info region_info = {
		.argsz = sizeof(region_info)
	};
	void *map = MAP_FAILED;
	unsigned int bar;
	int i, ret;

	ret = ioctl(device, VFIO_DEVICE_GET_REGION_INFO, &config_info);
	if (ret) {
		fprintf(stderr, "Failed to get config space region info\n");
		return map;
	}

	for (i = 0; i < 6; i++) {
		if (pread(device, &bar, sizeof(bar), config_info.offset +
			  PCI_BASE_ADDRESS_0 + (4 * i)) != sizeof(bar)) {
			fprintf(stderr, "Error reading BAR%d\n", i);
			return map;
		}

		if (!(bar & PCI_BASE_ADDRESS_SPACE)) {
			break;
tryagain:
			if (bar & PCI_BASE_ADDRESS_MEM_TYPE_64)
				i++;
		}
	}

	if (i >= 6) {
		fprintf(stderr, "No memory BARs found\n");
		return map;
	}

	region_info.index = VFIO_PCI_BAR0_REGION_INDEX + i;
	ret = ioctl(device, VFIO_DEVICE_GET_REGION_INFO, &region_info);
	if (ret) {
		fprintf(stderr, "Failed to get BAR%d region info\n", i);
		return map;
	}
  
	if (!(region_info.flags & VFIO_REGION_INFO_FLAG_MMAP)) {
		printf("No mmap support, try next\n");
		goto tryagain;
	}

	if (region_info.size < getpagesize()) {
		printf("Too small for mmap, try next\n");
		goto tryagain;
	}

	map = mmap(vaddr + map_size, getpagesize(), prot, 
		   MAP_SHARED, device, region_info.offset);
	if (map == MAP_FAILED) {
		fprintf(stderr, "Error mmap'ing BAR: %m\n");
		goto tryagain;
	}

	fprintf(stderr, "\t\tmmap_device_page @0x%016lx\n",
						(unsigned long long)map);
	if (!vaddr) {
		vaddr = map;
	} else if (map != vaddr + map_size) {
		fprintf(stderr, "Did not get contiguous mmap\n");
		munmap(map, getpagesize());
		return MAP_FAILED;
	}

	map_size += getpagesize();

	return map;
}

void *mmap_mem_page(int prot)
{
	void *map = mmap(vaddr + map_size, getpagesize(), prot,
			 MAP_PRIVATE | MAP_ANONYMOUS, 0, 0);

	if (map == MAP_FAILED) {
		fprintf(stderr, "Map anonymous page failed: %m\n");
		return map;
	}

	fprintf(stderr, "\t\tmmap_mem_page @0x%016lx\n",
						(unsigned long long)map);
	if (!vaddr) {
		vaddr = map;
	} else if (map != vaddr + map_size) {
		fprintf(stderr, "Did not get contiguous mmap\n");
		munmap(map, getpagesize());
		return MAP_FAILED;
	}

	map_size += getpagesize();

	return map;
}

/*
 * Request a read/write DMA mapping of @size bytes starting at process
 * address @map, at I/O virtual address @iova, via VFIO_IOMMU_MAP_DMA on
 * @container.
 *
 * Returns the ioctl result: 0 on success, non-zero on failure (a
 * diagnostic is printed to stderr).
 */
int dma_map(int container, void *map, int size, unsigned long iova)
{
	struct vfio_iommu_type1_dma_map req = {
		.argsz = sizeof(req),
		.size = size,
		/*
		 * Go through uintptr_t so the address is zero-extended where
		 * pointers are narrower than 64 bits; a direct pointer cast
		 * may sign-extend.
		 */
		.vaddr = (__u64)(uintptr_t)map,
		.iova = iova,
		.flags = VFIO_DMA_MAP_FLAG_READ | VFIO_DMA_MAP_FLAG_WRITE
	};
	int ret;

	ret = ioctl(container, VFIO_IOMMU_MAP_DMA, &req);
	if (ret)
		fprintf(stderr, "Failed to DMA map: %m\n");

	return ret;
}

/*
 * Request removal of the DMA mapping covering @size bytes at I/O virtual
 * address @iova via VFIO_IOMMU_UNMAP_DMA on @container.
 *
 * Returns the size field of the request as left by the ioctl on success.
 * On failure a diagnostic is printed and the ioctl's result is returned,
 * so the caller can tell a failed request from a completed one.
 */
int dma_unmap(int container, int size, unsigned long iova)
{
	struct vfio_iommu_type1_dma_unmap req = {
		.argsz = sizeof(req),
		.iova = iova,
		.size = size,
	};
	int ret;

	ret = ioctl(container, VFIO_IOMMU_UNMAP_DMA, &req);
	if (ret) {
		fprintf(stderr, "Failed to DMA unmap: %m\n");
		/* Never report the requested size as if it had been unmapped */
		return ret < 0 ? ret : -1;
	}

	return req.size;
}

/*
 * Build a four-page virtual range that alternates device-BAR pages and
 * anonymous memory pages, then DMA map the whole range in one request at
 * IOVA 1GB.
 *
 * argv[1] is the PCI address of the device to use.  Returns 0 when the
 * DMA mapping succeeds and a negative value as soon as any step fails.
 */
int main(int argc, char **argv)
{
	int container1;
	int group1;
	int device1;
	int i;
	void *map, *map_base = NULL;

	/* argv[1] is used throughout; refuse to run without it */
	if (argc < 2) {
		fprintf(stderr, "Usage: %s <ssss:bb:dd.f>\n",
			argc > 0 ? argv[0] : "alternate-pfnmap");
		return -EINVAL;
	}

	group1 = get_group(argv[1]);
	if (group1 < 0) {
		fprintf(stderr, "Failed to get group for %s\n", argv[1]);
		return group1;
	}

	fprintf(stderr, "\tGot group for %s\n", argv[1]);

	container1 = get_container();

	if (container1 < 0) {
		fprintf(stderr, "Failed to get container\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot container\n");

	if (group_set_container(group1, container1)) {
		fprintf(stderr, "Failed to set container\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tAttached group to container\n");

	if (container_set_iommu(container1)) {
		fprintf(stderr, "Failed to set iommu\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tSet IOMMU model for container\n");

	device1 = group_get_device(group1, argv[1]);

	if (device1 < 0) {
		fprintf(stderr, "Failed to get devices\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tGot device file descriptors\n");

	/* Two rounds of (device page, memory page) give the 4-page layout */
	for (i = 0; i < 2; i++) {
		map = mmap_device_page(device1, PROT_READ | PROT_WRITE);
		if (map == MAP_FAILED) {
			fprintf(stderr, "Failed to mmap device page\n");
			return -EFAULT;
		}

		fprintf(stderr, "\tGot mmap to device %s\n", argv[1]);

		/* The first device page is the start of the DMA range */
		if (i == 0)
			map_base = map;

		map = mmap_mem_page(PROT_READ | PROT_WRITE);
		if (map == MAP_FAILED) {
			fprintf(stderr, "Failed to mmap memory page\n");
			return -EFAULT;
		}

		fprintf(stderr, "\tGot memory page\n");
	}

	if (dma_map(container1, map_base, getpagesize() * 4,
						1024 * 1024 * 1024)) {
		fprintf(stderr, "Failed to DMA map pages\n");
		return -EFAULT;
	}

	fprintf(stderr, "\tDMA mapped pages into container for device %s\n",
		argv[1]);

	return 0;
}

Reply via email to