import argparse
import libvirt
import logging
from multiprocessing import Process
import os
import random
import threading
import time

def manageVM(endSignal, workerNumber, statefile, loggingLevel):
    """Cycle one VM through restore -> run -> destroy until told to stop.

    Intended to run as a worker thread (see startTestRun). Each cycle first
    waits 3 s on ``endSignal``. It then restores the domain from
    ``statefile``, looks it up by name, leaves it running for a random
    20-30 s and destroys it. The libvirt connection is closed when the
    worker returns.

    :param endSignal: ``threading.Event``; once set, the worker cuts the
        current run phase short, destroys the VM and returns.
    :param workerNumber: index of this worker, used only in log messages.
    :param statefile: path of the saved domain state. The domain name is the
        file's base name up to the first '.'.
    :param loggingLevel: level handed to ``logging.basicConfig`` (a no-op
        when logging has already been configured).
    """
    workerTag = "Worker " + str(workerNumber)
    logging.basicConfig(level = loggingLevel)
    logging.debug(workerTag + ": begins work")
    domainName = os.path.basename(statefile).split(".")[0]
    try:
        connection = libvirt.open(None)
    except libvirt.libvirtError as err:
        logging.error('Failed to open connection to hypervisor: ' + str(err))
        return
    if connection is None:
        logging.error('Failed to open connection to hypervisor')
        # exit() inside a thread only ends this thread (threading swallows
        # SystemExit), so returning is the explicit equivalent.
        return
    logging.debug(workerTag + ": connected to libvirt")

    try:
        while not endSignal.wait(3):
            # Reset every cycle so a failed restore/lookup never reuses the
            # previous cycle's (already destroyed) domain handle.
            domain = None
            domainID = 0
            sleepTime = random.randint(20, 30)
            logging.debug(workerTag + ": starts VM from statefile " + str(statefile))
            try:
                restoreResult = connection.restore(statefile)
                if restoreResult < 0:
                    logging.error('Unable to restore guest from ' + statefile)
                    return
            except libvirt.libvirtError as err:
                logging.error(workerTag + ": Error when restoring domain " + domainName + " " + str(err))
            logging.debug(workerTag + ": Restored domain. Check if it succeeded.")
            try:
                domain = connection.lookupByName(domainName)
                if domain is None:
                    logging.error('Unable to find guest with name ' + domainName)
                    return
                domainID = domain.ID()
            except libvirt.libvirtError as err:
                logging.error(workerTag + ": Error when looking up domain " + domainName + " " + str(err))
            logging.debug(workerTag + ": VM " + domainName + " started (id " + str(domainID) + "), waiting " + str(sleepTime))
            # Wait on endSignal instead of time.sleep() so an interrupt does
            # not have to sit out the whole 20-30 s run phase.
            endSignal.wait(sleepTime)
            logging.debug(workerTag + ": sleep time up. Destroying VM")
            if domain is None:
                continue
            try:
                domain.destroy()
            except libvirt.libvirtError as err:
                logging.error(workerTag + ": Error when managing domain " + domainName + " " + str(err))
            else:
                # Only report success when destroy() actually succeeded.
                logging.debug(workerTag + ": VM destroyed")
    finally:
        # Release the hypervisor connection on every exit path.
        connection.close()
        logging.debug(workerTag + ": ending")

def startTestRun(stateFiles, loggingLevel):
    """Start one manageVM worker thread per state file and wait for shutdown.

    Workers are started 5 s apart. The function then waits until either the
    user presses Ctrl+C or every worker has exited on its own. Either way it
    sets the shared end signal and joins all workers before returning.

    :param stateFiles: sequence of state-file paths, one worker per entry.
    :param loggingLevel: level forwarded to each worker.
    """
    activeThreads = []
    endSignal = threading.Event()
    logging.debug("Starting worker threads")
    try:
        for workerNumber, statefile in enumerate(stateFiles):
            vmWorkerThread = threading.Thread(target = manageVM, args=(endSignal, workerNumber, statefile, loggingLevel))
            activeThreads.append(vmWorkerThread)
            vmWorkerThread.start()
            time.sleep(5)
        logging.debug("Waiting for program termination")
        # Stop waiting once no worker is left; otherwise this would spin
        # forever after all workers have failed.
        while any(thread.is_alive() for thread in activeThreads):
            time.sleep(.1)
    except KeyboardInterrupt:
        logging.debug("Got Interrupt")
    finally:
        # Always signal the workers, including when Ctrl+C arrives during the
        # staggered startup; otherwise the non-daemon threads never exit.
        endSignal.set()
    for thread in activeThreads:
        thread.join()
    logging.debug("All workers finished")

def inputSanityCheck(args):
    """Validate parsed command-line arguments and configure root logging.

    :param args: namespace with ``loggingLevel`` (a level name, any case) and
        ``stateFiles`` (list of paths).
    :raises ValueError: if the level name is unknown, no state files were
        given, or a state file does not exist.
    """
    numeric_level = getattr(logging, args.loggingLevel.upper(), None)
    if not isinstance(numeric_level, int):
        raise ValueError('Invalid log level: %s' % args.loggingLevel)
    # Use the resolved numeric level: basicConfig rejects lower-case names
    # such as 'debug' even though the check above accepts them.
    logging.basicConfig(level = numeric_level)

    if not args.stateFiles:
        raise ValueError("No state files given. Aborting")

    for stateFile in args.stateFiles:
        if not os.path.exists(stateFile):
            raise ValueError("State file " + stateFile + " does not exist. Aborting")

def main():
    """Parse the command line, validate it and run the restore/destroy loop."""
    parser = argparse.ArgumentParser(description='Restore and destroy VMs from  statefile based on virsh.')
    parser.add_argument('-FILE', action='append', dest='stateFiles',
                        default=[],
                        help='List of state files from which to restore the VMs')
    parser.add_argument('--loggingLevel',
                        # Normalise case so every consumer (including the
                        # workers' basicConfig call) sees e.g. 'DEBUG'.
                        type=str.upper,
                        default = 'WARNING',
                        help='Set the logging level. Valid levels are: DEBUG, INFO, WARNING (default), ERROR, CRITICAL')
    args = parser.parse_args()

    inputSanityCheck(args)
    # Debug dump of the arguments goes through logging (now configured)
    # instead of an unconditional print to stdout.
    logging.debug("Parsed arguments: %s", args)

    startTestRun(args.stateFiles, args.loggingLevel)

# Script entry point: run only when executed directly, not when imported.
if __name__ == "__main__":
    main()