/*
 * Copyright 2013 Red Hat Inc.
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2 of
 * the License, or (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 *
 * Authors: Jérôme Glisse <jglisse@redhat.com>
 */
#define _GNU_SOURCE

#include <assert.h>
#include <time.h>
#include <errno.h>
#include <fcntl.h>
#include <stdio.h>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <setjmp.h>
#include <strings.h>
#include <sys/mman.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <sys/ioctl.h>
#include <pthread.h>
#include <sys/time.h>
#include <sys/types.h>
#include <sys/syscall.h>

#include "hmm_test_framework.h"

/* Kernel thread id of the calling thread, via the raw gettid syscall. */
#define gettid() ((pid_t)syscall(SYS_gettid))

/*
 * Compile-time switches: when nonzero, hmm_buffer_mirror_migrate_to() and
 * hmm_buffer_mirror_read() take their pthread-per-chunk code path instead of
 * the serial ioctl loop run on the calling thread.
 */
#define MULTITHREADED_MIGRATE 0
#define MULTITHREADED_READ 0

/* NOTE(review): not referenced by the code visible in this file. */
#define MAX_RETRY 16

/*
 * Round x up to the next multiple of a, where a must be a power of two.
 * Every use of a is parenthesized so that expression arguments such as
 * `1 << 12` expand correctly.
 */
#define ALIGN(x, a) (((x) + ((a) - 1)) & ~((a) - 1))

/*
 * When nonzero, hmm_exit() longjmp()s to _hmm_exit_env instead of calling
 * exit(). Set it only while the frame that called setjmp() is still live.
 */
static int _hmm_exit_ok = 0;
static jmp_buf _hmm_exit_env;

/* Host page geometry, filled in once by hmm_init_page_info(). */
static unsigned page_size = 0;
static unsigned page_shift;
static unsigned long page_mask;

/*
 * Context stored by the mirror read/migrate entry points;
 * migrate_thread_body() issues its ioctl on g_ctx->fd.
 */
static struct hmm_ctx *g_ctx;

static inline void hmm_init_page_info(void)
{
    long queried;

    /* A nonzero page_size means the geometry has already been computed. */
    if (page_size != 0)
        return;

    queried = sysconf(_SC_PAGE_SIZE);
    page_size = queried;

    /* Position of the lowest set bit: log2(page_size) for a power of two. */
    page_shift = ffs(page_size) - 1;

    /* Clears the in-page offset bits of an address. */
    page_mask = ~((unsigned long)(page_size - 1));
}

static void hmm_exit(void)
{
    /* Without an armed jump target, terminate the process outright. */
    if (!_hmm_exit_ok)
        exit(-1);

    /* Otherwise unwind to the setjmp() point recorded in _hmm_exit_env. */
    longjmp(_hmm_exit_env, -1);
}

/*
 * Record the calling process id in ctx and make sure the page geometry
 * globals are initialized. Always returns 0.
 */
static int hmm_dummy_ctx_register(struct hmm_ctx *ctx)
{
    pid_t self = getpid();

    hmm_init_page_info();
    ctx->pid = self;

    return 0;
}

/*
 * Open the dummy mirror device and record the calling process in ctx.
 *
 * Returns 0 on success, with ctx->fd holding the open descriptor (release it
 * with hmm_ctx_fini()). Returns -1 if the device cannot be opened or
 * registration fails; in that case ctx->fd is left negative.
 */
int hmm_ctx_init(struct hmm_ctx *ctx)
{
    static const char pathname[] = "/dev/hmm_dmirror";
    int ret;

    hmm_init_page_info();

    /*
     * While _hmm_exit_ok is set, hmm_exit() longjmp()s back here. That is only
     * valid while this frame is live, so the flag is cleared on every return
     * path. Leaving it set would make a later hmm_exit() jump into a dead stack
     * frame (undefined behavior) instead of exiting.
     */
    if (setjmp(_hmm_exit_env)) {
        _hmm_exit_ok = 0;
        return -1;
    }
    _hmm_exit_ok = 1;

    ctx->fd = open(pathname, O_RDWR, 0);
    if (ctx->fd < 0) {
        fprintf(stderr, "could not open hmm dummy driver (%s): %s\n",
                pathname, strerror(errno));
        ret = -1;
    } else {
        ret = hmm_dummy_ctx_register(ctx);
        if (ret) {
            /* Do not leak the descriptor when registration fails. */
            close(ctx->fd);
            ctx->fd = -1;
        }
    }

    _hmm_exit_ok = 0;
    return ret;
}

/*
 * Close the descriptor opened by hmm_ctx_init() and mark ctx->fd as invalid
 * (-1). The close() result is not checked.
 */
void hmm_ctx_fini(struct hmm_ctx *ctx)
{
    int fd = ctx->fd;

    ctx->fd = -1;
    close(fd);
}

/*
 * Size in bytes of the buffer's page range: buffer->npages pages of
 * (1 << page_shift) bytes each.
 */
unsigned long hmm_buffer_nbytes(struct hmm_buffer *buffer)
{
    const unsigned shift = page_shift;

    return buffer->npages << shift;
}

struct hmm_buffer *hmm_buffer_new_anon(const char *name, unsigned long nbytes)
{
    struct hmm_buffer *buffer;
    unsigned long npages;

    hmm_init_page_info();

    if (!nbytes) {
        fprintf(stderr, "(EE) %s(%s).nbytes -> %ld\n", __func__, name, nbytes);
        hmm_exit();
    }

    npages = ALIGN(nbytes, page_size) >> page_shift;

    buffer = malloc(sizeof(*buffer));
    if (buffer == NULL) {
        fprintf(stderr, "(EE) %s(%s).malloc(struct)\n", __func__, name);
        hmm_exit();
    }

    buffer->fd = -1;
    buffer->name = name;
    buffer->npages = npages;

    // buffer->mirror = mmap(0, npages << page_shift,
    //                       PROT_READ | PROT_WRITE,
    //                       MAP_PRIVATE | MAP_ANONYMOUS,
    //                       -1, 0);
    // if (buffer->mirror == MAP_FAILED) {
    //     fprintf(stderr, "(EE) %s(%s).mmap(%ld)\n", __func__, name, npages);
    //     free(buffer);
    //     hmm_exit();
    // }

    // buffer->ptr = mmap(0, npages << page_shift,
    //                    PROT_READ | PROT_WRITE,
    //                    MAP_PRIVATE | MAP_ANONYMOUS,
    //                    -1, 0);
    // if (buffer->ptr == MAP_FAILED) {
    //     fprintf(stderr, "(EE) %s(%s).mmap(%ld)\n", __func__, name, npages);
    //     free(buffer);
    //     hmm_exit();
    // }

    buffer->mirror = malloc((npages + 1) * page_size);
    assert(buffer->mirror);
    buffer->mirror = (char*)(((unsigned long long)buffer->mirror + (page_size-1)) & ~((unsigned long long)page_size-1));  
//    memset(buffer->mirror, 0, npages * page_size);

    buffer->ptr = malloc((npages + 1) * page_size);
    assert(buffer->ptr);
    buffer->ptr = (char*)(((unsigned long long)buffer->ptr + (page_size-1)) & ~((unsigned long long)page_size-1));
//    memset(buffer->ptr, 0, npages * page_size);

    return buffer;
}

/*
 * Worker run by hmm_buffer_mirror_read() when MULTITHREADED_READ is nonzero.
 *
 * It does not issue the HMM_DMIRROR_READ ioctl. It performs one volatile
 * load of the long stored at p_read->addr and returns the loaded value, cast
 * to a pointer, as the thread result. In hmm_buffer_mirror_read(), addr is
 * the start of this thread's chunk of buffer->ptr.
 */
static void *read_thread_body(struct hmm_dmirror_read *p_read)
{
    const volatile long *first_word = (const volatile long *)p_read->addr;
    const long loaded = *first_word;

    return (void *)loaded;
}

/* Smaller of two unsigned long values. */
static inline unsigned long MIN(unsigned long x, unsigned long y)
{
    if (y < x)
        return y;
    return x;
}

int hmm_buffer_mirror_read(struct hmm_ctx *ctx,
                           struct hmm_buffer *buffer,
                           unsigned long spage,
                           unsigned long npages,
                           struct stats *stats)
{
    pthread_t threads[MAX_THREADS];
    struct hmm_dmirror_read reads[MAX_THREADS];
    int i = 0;
    unsigned long pages_per_thread = npages / MAX_THREADS;
    int ret;

//    printf("%s:%d [%s]: npages: %lu, threads: %d, pages_per_thread: %lu\n", __FUNCTION__, __LINE__, 
//        MULTITHREADED_READ? "parallel" : "serialized", npages, MAX_THREADS, pages_per_thread);

    // if (npages > MAX_THREADS) {
    //     printf("%s:%d: too many pages %lu\n", __FUNCTION__, __LINE__, buffer->npages);
    //     return -1;
    // }

    memset(threads, 0, sizeof(threads));

    g_ctx = ctx;

    for (i = 0; i < MAX_THREADS; i++) {
        reads[i].addr = (uintptr_t)buffer->ptr + (spage + i * pages_per_thread) * 4096;
        reads[i].ptr = (uintptr_t)buffer->mirror + (spage + i * pages_per_thread) * 4096;
        reads[i].cpages = 0;

        if (i * pages_per_thread < npages) {
            reads[i].npages = MIN(pages_per_thread, npages - i * pages_per_thread);
        }
        else {
            reads[i].npages = 0;
        }

        if (MULTITHREADED_READ) {
            pthread_create(&threads[i], NULL, (void*(*)(void*))&read_thread_body, &reads[i]);
        }
        else {
            ret = ioctl(ctx->fd, HMM_DMIRROR_READ, &reads[i]);

            if (ret && errno != EINTR) {
                // printf("%s:%d: HMM_DMIRROR_READ ret %d errno %d\n", __FUNCTION__, __LINE__, ret, errno);
                return ret;
            }

            if (reads[i].cpages != reads[i].npages) {
                printf("%s:%d: failed to read page at 0x%lx\n", __FUNCTION__, __LINE__, reads[i].addr);
                return ret;
            }
        }
    }

    if (MULTITHREADED_READ) {
        for (i = 0; i < MAX_THREADS; i++) {
            void *thr_ret;

            if (threads[i] == 0)
                continue;

            pthread_join(threads[i], &thr_ret);
            ret = (int)(long)thr_ret;

            // if (ret) {     
            //     printf("%s:%d: HMM_DMIRROR_READ error %d\n", __FUNCTION__, __LINE__, ret);  
            //     return ret;
            // }

            // if (reads[i].cpages != reads[i].npages) {
            //     printf("%s:%d: failed to read page at 0x%lx\n", __FUNCTION__, __LINE__, reads[i].addr);
            //     return ret;
            // }
        }
    }

    return 0;
}

/*
 * Worker run by hmm_buffer_mirror_migrate_to() when MULTITHREADED_MIGRATE is
 * nonzero. It issues HMM_DMIRROR_MIGRATE for p_migrate on g_ctx->fd,
 * retrying while the ioctl fails with EINTR.
 *
 * The thread result is the ioctl return value cast to a pointer (NULL on
 * success). A diagnostic is printed on failure.
 */
static void *migrate_thread_body(struct hmm_dmirror_migrate *p_migrate)
{
    long rc;

    for (;;) {
        rc = ioctl(g_ctx->fd, HMM_DMIRROR_MIGRATE, p_migrate);
        if (rc == 0 || errno != EINTR)
            break;
    }

    if (rc != 0)
        printf("%s:%d: HMM_DMIRROR_MIGRATE error %d\n", __FUNCTION__, __LINE__, (int)rc);

    return (void *)rc;
}

int hmm_buffer_mirror_migrate_to(struct hmm_ctx *ctx,
                                 struct hmm_buffer *buffer,
                                 struct stats *stats)
{
    pthread_t threads[MAX_THREADS];
    struct hmm_dmirror_migrate migrates[MAX_THREADS];
    int i = 0;
    unsigned long pages_per_thread = buffer->npages / MAX_THREADS;
    int ret;

    // if (buffer->npages > MAX_THREADS) {
    //     printf("%s:%d: too many pages %lu\n", __FUNCTION__, __LINE__, buffer->npages);
    //     return -1;
    // }

    // printf("%s:%d [%s]: npages: %lu, threads: %d, pages_per_thread: %lu\n", __FUNCTION__, __LINE__,
    //     MULTITHREADED_MIGRATE? "parallel" : "serialized", buffer->npages, MAX_THREADS, pages_per_thread);

    memset(threads, 0, sizeof(threads));

    g_ctx = ctx;

    for (i = 0; i < MAX_THREADS; i++) {
        size_t npages;

        if (i * pages_per_thread < buffer->npages) {
            npages = MIN(pages_per_thread, buffer->npages - i * pages_per_thread);
        }
        else {
            npages = 0;
        }

        migrates[i].npages = npages;
        migrates[i].addr = (uintptr_t)buffer->ptr;
        migrates[i].addr += 4096 * (i * pages_per_thread);

        if (MULTITHREADED_MIGRATE) {
            pthread_create(&threads[i], NULL, (void*(*)(void*))&migrate_thread_body, &migrates[i]);
        }
        else {
            do {
                ret = ioctl(ctx->fd, HMM_DMIRROR_MIGRATE, &migrates[i]);
            } while (ret && (errno == EINTR));

            if (ret) {     
                printf("%s:%d: HMM_DMIRROR_MIGRATE error %d\n", __FUNCTION__, __LINE__, ret);  
                return ret;
            }

            if (migrates[i].npages != npages) {
                // printf("%s:%d: failed to read pages at 0x%lx (migrates[%d].npages: %lu != npages: %lu)\n", __FUNCTION__, __LINE__,
                //     migrates[i].addr, i, migrates[i].npages, npages);
                return ret;
            }
        }
    }
 
    if (MULTITHREADED_MIGRATE) {
        for (i = 0; i < MAX_THREADS; i++) {
            void *thr_ret;
            size_t npages;

            if (i * pages_per_thread < buffer->npages) {
                npages = MIN(pages_per_thread, buffer->npages - i * pages_per_thread);
            }
            else {
                npages = 0;
            }

            if (threads[i] == 0)
                continue;

            pthread_join(threads[i], &thr_ret);
            ret = (int)(long)thr_ret;

            if (ret) {     
                printf("%s:%d: HMM_DMIRROR_MIGRATE error %d\n", __FUNCTION__, __LINE__, ret);  
                return ret;
            }

            if (migrates[i].npages != npages) {
                printf("%s:%d: failed to read pages at 0x%lx (migrates[%d].npages: %lu != npages: %lu)\n", __FUNCTION__, __LINE__,
                    migrates[i].addr, i, migrates[i].npages, npages);
                return ret;
            }
        }
    }

    return ret;
}

/*
 * Release the buffer descriptor. Passing NULL is a no-op, because free(NULL)
 * does nothing.
 *
 * NOTE(review): buffer->ptr and buffer->mirror are not released here. They
 * can only be passed to free() if they are the exact pointers an allocator
 * returned, which depends on how hmm_buffer_new_anon() produced them.
 */
void hmm_buffer_free(struct hmm_buffer *buffer)
{
    free(buffer);
}
