#!/usr/bin/env python

from threading import Thread
import threading
import inspect
from dummy_ber_simulation import *


class SimulationThread(Thread):
    """Run one BER simulation flowgraph in its own thread.

    The constructor builds the flowgraph (``self.tb``) from ``kwargs``.
    ``run()`` executes it, writes error/item counts, BER and total runtime into
    a private copy of ``kwargs`` and saves that dict with ``np.save``.  The
    progress/result accessors are polled from a monitoring thread both while
    the simulation runs and after ``run()`` has released ``self.tb``.
    """

    def __init__(self, kwargs):
        # ``verbose`` only existed on Python 2's Thread (where it already
        # defaulted to off); passing it raises TypeError on Python 3.
        Thread.__init__(self, group=None, target=None, name=kwargs['name'], args=[], kwargs={})
        self.tb = None
        self.kwargs = kwargs.copy()  # make sure to have a local copy. Otherwise setup will alter it all the time.
        self.kwargs['total_time'] = 0.0
        self.setup_start = None  # set to time.time() when run() starts
        self.setup_fg(kwargs)

    def setup_fg(self, kwargs):
        """Build the flowgraph and record its code rate and codeword size in self.kwargs."""
        print('DEFAULT FLOWGRAPH. Using DUMMY')
        self.setup_dummy_fg(kwargs)
        self.kwargs['code_rate'] = self.get_code_rate()
        self.kwargs['codeword_size'] = self.get_codeword_size()

    def setup_dummy_fg(self, kwargs):
        """Instantiate dummy_ber_simulation with the kwargs its constructor names."""
        reqs = get_required_kwargs(get_class_arg_list(dummy_ber_simulation), kwargs)
        self.tb = dummy_ber_simulation(**reqs)

    def get_code_rate(self):
        """Return the flowgraph's ``code_rate`` attribute."""
        return self.tb.code_rate

    def get_codeword_size(self):
        """Return the flowgraph's ``block_size`` attribute."""
        return self.tb.block_size

    def result(self):
        """Return ``[total_errors, total_items]``."""
        return [self.total_errors(), self.total_items()]

    def runtime(self):
        """Return the elapsed simulation time in seconds.

        After run() has finished this is the recorded total; while running it
        is the time since run() started; before run() it is 0.0.
        """
        if self.kwargs['total_time'] > 0.0:
            return self.kwargs['total_time']
        if self.setup_start:
            return time.time() - self.setup_start
        return 0.0

    def common_static_info_string(self):
        """Return the channel / Eb/N0 / codeword size / code rate summary."""
        info_string = 'chan: {}, Eb/N0={:.2f}dB, N={:<5d}, R={:.2f}'
        return info_string.format(self.kwargs['channel'], self.kwargs['ebno'], self.kwargs['codeword_size'], self.kwargs['code_rate'])

    def static_info_string(self):
        """Return the static description of this simulation (hook for subclasses)."""
        return self.common_static_info_string()

    def progress_status(self):
        """Return progress in [0, 1]: collected errors relative to min_error_count."""
        bc = min(self.kwargs['min_error_count'], self.total_errors())
        return 1. * bc / self.kwargs['min_error_count']

    def total_errors(self):
        """Return the bit errors counted so far (or the saved total after run())."""
        # Read self.tb once: run() sets it to None when it finishes, which can
        # happen between a check and a use when called from another thread.
        tb = self.tb
        if tb is None:
            return self.kwargs['total_errors']
        return tb.ber_counter.total_errors()

    def total_items(self):
        """Return the bits compared so far (or the saved total after run())."""
        tb = self.tb  # single read; see total_errors()
        if tb is None:
            return self.kwargs['total_items']
        nitems = tb.ber_counter.nitems_read(0)
        # Nothing consumed yet: report 1 so that ber() never divides by zero.
        return nitems if nitems > 0 else 1

    def ber(self):
        """Return the current bit error rate, total_errors / total_items."""
        return 1. * self.total_errors() / self.total_items()

    def save_results(self):
        """Store final counts and BER in self.kwargs and save that dict with np.save.

        The file name is ``result_dir + 'SIMRES_<fg_type>_<N>_<ebno>_<time>.npy'``;
        ``result_dir`` is concatenated directly (so it should end with a path
        separator) and the directory is not created here.
        """
        self.kwargs['total_errors'] = self.total_errors()
        self.kwargs['total_items'] = self.total_items()
        self.kwargs['ber'] = self.ber()
        filename = self.kwargs['result_dir'] + 'SIMRES_' + self.kwargs['fg_type']
        filename += '_' + str(self.kwargs['codeword_size'])
        filename += '_' + str(self.kwargs['ebno'])
        filename += '_' + str(time.time())  # make sure previous results are not overwritten.
        filename += '.npy'
        np.save(filename, self.kwargs)

    def progress_info_string(self):
        """Return a one-line status: static info, runtime, BER and counts."""
        info_string = 'SIM [' + self.static_info_string() + '] {:.2f}s BER={:.4e} ({}/{})'
        return info_string.format(self.runtime(), self.ber(), self.total_errors(), self.total_items())

    def run(self):
        """Thread body: run the flowgraph, save the results, then release it."""
        self.setup_start = time.time()
        print('SIM [' + self.static_info_string() + '] minErrs={} START'.format(self.kwargs['min_error_count']))
        self.tb.run()
        self.kwargs['total_time'] = time.time() - self.setup_start
        self.save_results()
        print('\n' + self.progress_info_string())
        self.tb = None  # try to free resources early! Accessors fall back to self.kwargs.


def get_common_kwargs(ebno, min_error_count=2 ** 14, noise_seed=0, bit_seed=0, sig_ampl=1.0,
                      name=None, fg_type='DUMMY', channel=None, res_dir=None):
    """Return the keyword dict shared by every simulation type.

    A falsy ``res_dir`` becomes ``'sim_results/'`` and a falsy ``channel``
    falls back to ``fg_type``; all other arguments are stored unchanged.
    """
    return dict(
        name=name,
        ebno=ebno,
        min_error_count=min_error_count,
        noise_seed=noise_seed,
        bit_seed=bit_seed,
        sig_ampl=sig_ampl,
        result_dir=res_dir or 'sim_results/',
        channel=channel or fg_type,
        fg_type=fg_type,
    )
    

def get_dummy_kwargs(ebno,
                     min_error_count=2 ** 14, noise_seed=0, bit_seed=0, sig_ampl=1.0, name=None,
                     block_size=2048):
    """Return the kwargs for a dummy simulation at the given Eb/N0.

    Args:
        ebno: Eb/N0 value stored under ``'ebno'``.
        min_error_count, noise_seed, bit_seed, sig_ampl: forwarded to
            get_common_kwargs.
        name: thread name; defaults to ``'ST_dummy_<block_size>_<ebno>'``.
        block_size: value stored under ``'block_size'`` (previously hard-coded
            to 2048, which remains the default).
    """
    if not name:
        name = 'ST_dummy_' + str(block_size) + '_' + str(ebno)
    kwargs = get_common_kwargs(ebno, min_error_count, noise_seed, bit_seed, sig_ampl, name)
    kwargs['block_size'] = block_size
    return kwargs


def start_ber_curve_simulation(ebnos, kwargs):
    """Start one SimulationThread per Eb/N0 value and return the started threads.

    ``kwargs`` is used as a template and is not modified; each thread gets a
    copy with ``'ebno'`` set to its own value.
    """
    simulations = []
    for ebno in ebnos:
        # Copy instead of overwriting the caller's kwargs['ebno'] each iteration.
        sim_kwargs = dict(kwargs)
        sim_kwargs['ebno'] = ebno
        sim = SimulationThread(sim_kwargs)
        simulations.append(sim)
        sim.start()
    return simulations


def get_ber_curve_results(simulations):
    """Return the current BER of each simulation, in input order."""
    return [sim.ber() for sim in simulations]


def flatten_simulation_list(simulations):
    """Flatten a list of simulation lists into a single list, preserving order."""
    return [sim for row in simulations for sim in row]


def print_simulation_summary(sims, name):
    """Print a header, then a progress line (with percent done) per running simulation."""
    print('\nSimulation Summary: ' + name)
    running = (sim for sim in sims if sim.is_alive())
    for sim in running:
        info = sim.progress_info_string()
        print(info + ' {:.2f}'.format(100. * sim.progress_status()))


def print_runtime_summary(sims, start_time):
    """Print one line with overall runtime, mean progress and lowest BER.

    Args:
        sims: flat list of simulation threads.
        start_time: timestamp (``time.time()`` units) the runtime is measured from.

    Progress is averaged over running simulations only and is reported as 0
    when none are running. The lowest BER is taken over all simulations and
    starts at 1.0.
    """
    prog = 0.0
    lowest_ber = 1.0
    cnt = 0
    for s in sims:
        if s.is_alive():
            prog += s.progress_status()
            cnt += 1
        ber = s.ber()  # read once; the value can change between calls
        if ber < lowest_ber:
            lowest_ber = ber
    # Every simulation may have finished since the caller's is_alive() check,
    # so guard against dividing by zero.
    if cnt:
        prog /= 1. * cnt
    nthreads = threading.active_count()
    nsims = (nthreads - 1) // 2  # -1 for main thread. // 2 because every simulation spawns 2 threads.
    runtime = time.time() - start_time
    print('RUNTIME {:.1f}s progress {:.2f}, lowest-BER = {:.4e}, active sims={} [nthreads={}]'.format(
        runtime, 100. * prog, lowest_ber, nsims, nthreads))


def monitor_simulation(simulations, start_time=None, name=''):
    """Block until every simulation thread has finished, printing progress.

    Args:
        simulations: nested list of simulation threads (flattened with
            flatten_simulation_list).
        start_time: reference timestamp for the RUNTIME column. Defaults to
            the moment this function is called.
        name: label printed in the periodic simulation summary.
    """
    # The former default ``time.time()`` was evaluated once at import time,
    # so the reported runtime included everything that ran before this call.
    if start_time is None:
        start_time = time.time()
    sims = flatten_simulation_list(simulations)
    summary_rate = 50  # print the detailed summary once per this many joins
    thread_timeout = 1.0  # seconds to wait on one thread before moving on
    summary_counter = summary_rate + 1  # forces a detailed summary on the first join

    # Wait on our own threads rather than threading.active_count(): any
    # unrelated thread would otherwise keep this loop busy-spinning forever.
    while any(s.is_alive() for s in sims):
        for s in sims:
            if s.is_alive():
                print_runtime_summary(sims, start_time)
                s.join(thread_timeout)
                summary_counter += 1
                if summary_counter > summary_rate:
                    summary_counter = 0
                    print_simulation_summary(sims, name)


def get_required_kwargs(reqs, kwargs):
    """Return a new dict with the entries of ``kwargs`` named in ``reqs``.

    This assumes constructor parameter names match the dict keys. ``kwargs``
    itself is not modified.

    If the selected ``channel`` is ``'TVBSC'``, it is rewritten to ``'BSC'``
    and ``design_snr`` is raised by 30.0.

    Raises:
        KeyError: a name in ``reqs`` is missing from ``kwargs``, or the
            channel is TVBSC but ``design_snr`` was not requested.
    """
    d = {r: kwargs[r] for r in reqs}  # this assumes dict entries have the same name. Seems legit.
    # Not every flowgraph constructor takes ``channel``; don't require it.
    if d.get('channel') == 'TVBSC':
        # NOTE(review): the rationale for the +30 dB design_snr offset is not
        # visible in this file.
        d['channel'] = 'BSC'
        d['design_snr'] += 30.0
    return d


def get_class_arg_list(ins_class):
    """Return the parameter names of ``ins_class.__init__`` without the first (``self``)."""
    arg_names = get_function_arg_list(ins_class.__init__)
    return arg_names[1:]


def get_function_arg_list(func):
    """Return the names of ``func``'s positional parameters, including ``self``.

    ``*args``, ``**kwargs`` and keyword-only parameters are not included.
    ``inspect.getargspec`` was removed in Python 3.11, so ``getfullargspec``
    is preferred and ``getargspec`` is used only where it is missing (Python 2).
    """
    getspec = getattr(inspect, 'getfullargspec', None) or inspect.getargspec
    return getspec(func).args


def get_common_elements(l, d):
    """Return the items of ``d`` that also occur in ``l``, in ``d``'s iteration order."""
    return [item for item in d if item in l]


def main():
    """Smoke test: build a dummy simulation at Eb/N0 = 0.0 and print its code rate."""
    sim = SimulationThread(get_dummy_kwargs(0.0))
    rate = sim.get_code_rate()
    print(rate)


if __name__ == '__main__':
    # Run the smoke test only when executed as a script, not on import.
    main()
