# -*- coding: utf-8 -*-
"""
Created on Fri Aug 05 15:32:49 2016

@author: korevma
"""

from fipy import (Grid1D, CellVariable, TransientTerm, DiffusionTerm,
                  ImplicitSourceTerm, UpwindConvectionTerm, PowerLawConvectionTerm, Viewer,
                  MatplotlibViewer, VanLeerConvectionTerm)
import fipy
import numpy as np

import matplotlib.pyplot as plt
import matplotlib as mpl
import matplotlib.mlab as mlab


class DiscreteSlider(mpl.widgets.Slider):
    """A matplotlib slider widget that snaps to a discrete set of values.

    Observers registered through ``on_changed`` are called with the snapped
    value, and only when the snapped value actually changes.

    NOTE(review): ``set_val`` edits ``self.poly.xy`` and iterates
    ``self.observers`` directly, so it depends on the internals of the
    (older) matplotlib ``Slider`` this was written against -- confirm
    against the installed matplotlib version.
    """

    def __init__(self, *args, **kwargs):
        """
        Identical to Slider.__init__, except for the new keyword 'allowed_vals'.

        allowed_vals : sequence of numbers, optional
            The positions the slider may take. Defaults to [valmin, valmax].
        """
        allowed_vals = kwargs.pop('allowed_vals', None)
        # Keep an ndarray so that ``val - allowed_vals`` in set_val
        # broadcasts; with a plain list it would raise TypeError.
        self.allowed_vals = (None if allowed_vals is None
                             else np.asarray(allowed_vals))
        # Slider itself defaults valinit to 0.5; mirror that instead of
        # failing with KeyError when the caller omits valinit.
        self.previous_val = kwargs.get('valinit', 0.5)
        mpl.widgets.Slider.__init__(self, *args, **kwargs)
        if self.allowed_vals is None:
            self.allowed_vals = np.array([self.valmin, self.valmax])

    def set_val(self, val):
        """Snap *val* to the nearest allowed value and redraw the slider.

        Both ``self.val`` and the displayed text hold the snapped value.
        Observers are notified (with the snapped value) only when it differs
        from the previously snapped value and events are enabled.
        """
        allowed = self.allowed_vals
        if allowed is None:
            # Depending on the matplotlib version, Slider.__init__ may call
            # set_val before the default allowed values are filled in.
            allowed = np.array([self.valmin, self.valmax])
        discrete_val = allowed[np.abs(val - allowed).argmin()]

        # Move the bar's moving edge (presumably vertices 2 and 3 of the
        # Slider's polygon) to the snapped position.
        xy = self.poly.xy
        xy[2] = discrete_val, 1
        xy[3] = discrete_val, 0
        self.poly.xy = xy
        self.valtext.set_text(self.valfmt % discrete_val)
        if self.drawon:
            self.ax.figure.canvas.draw()
        self.val = discrete_val

        if self.previous_val == discrete_val:
            return
        self.previous_val = discrete_val
        if not self.eventson:
            return
        # Copy the callbacks first so an observer may (dis)connect safely.
        for func in list(self.observers.values()):
            func(discrete_val)

class ChangingPlot(object):
    """Line plot of one or more time-dependent datasets with a time-step slider."""

    def __init__(self, time_steps, xdata, ydata, labels):
        """Create the figure, one line per dataset, and the time-step slider.

        time_steps : sequence of int
            Time steps the slider can select. Each value is used to index the
            time axis of every dataset (see ``update``).
        xdata : sequence of N numbers
            Sample coordinates shared by all datasets.
        ydata : array-like, shape (n_data, t_n, N) or (t_n, N)
            n_data datasets (one line each), t_n time steps and N sample
            points. A 2-D input is treated as a single dataset.
        labels : sequence of str
            One label per dataset.

        Raises ValueError if ydata has fewer than two dimensions or if the
        number of labels differs from the number of datasets.
        """
        self.time_steps = time_steps
        self.xdata = xdata
        n_dims = len(np.shape(ydata))
        if n_dims < 2:
            raise ValueError('ydata should have at least two dimensions')
        elif n_dims == 2:
            ydata = [ydata]

        if len(ydata) != len(labels):
            raise ValueError('Number of labels should equal number of datasets')
        self.ydata = ydata

        # Number of datasets, i.e. number of plotted lines.
        self.ndim = len(ydata)
        self.fig, self.ax = plt.subplots()

        # Draw the state at the slider's initial step, so the plot and the
        # slider agree (the slider only redraws when its value changes).
        first_step = int(time_steps[0])
        self.lines = []
        for ii in range(self.ndim):
            line, = self.ax.plot(xdata, ydata[ii][first_step], label=labels[ii])
            self.lines.append(line)

        # x-range with a 10 % margin on both sides.
        min_x, max_x = min(xdata), max(xdata)
        width = max_x - min_x
        self.ax.set_xlim(min_x - 0.1 * width, max_x + 0.1 * width)

        rect = [0.2, 0.02, 0.6, 0.03]
        try:
            self.sliderax = self.fig.add_axes(rect, facecolor='cyan')
        except AttributeError:
            # matplotlib < 2.0 only understands the old 'axisbg' keyword.
            self.sliderax = self.fig.add_axes(rect, axisbg='cyan')

        # The slider range must span the allowed values; otherwise snapping
        # can never reach steps outside the range.
        valmin, valmax = min(time_steps), max(time_steps)
        if valmax == valmin:
            # A single time step: give the slider a non-degenerate range.
            valmax = valmin + 1
        self.slider = DiscreteSlider(self.sliderax, 'Time step', valmin, valmax,
                                     allowed_vals=time_steps,
                                     valinit=time_steps[0])

        self.slider.on_changed(self.update)

    def update(self, value):
        """Show time step *value* for all datasets and widen the y-range if needed.

        The y-limits are only ever enlarged (with a 10 % margin), never shrunk.
        """
        # The slider may deliver a float/numpy scalar; list indexing needs int.
        step = int(value)
        current = [self.ydata[ii][step] for ii in range(self.ndim)]
        for line, y in zip(self.lines, current):
            line.set_ydata(y)
        min_y = min(np.min(y) for y in current)
        max_y = max(np.max(y) for y in current)

        height = max_y - min_y
        lower, upper = self.ax.get_ylim()
        if min_y < lower + 0.1 * height:
            lower = min_y - 0.1 * height
        if max_y > upper - 0.1 * height:
            upper = max_y + 0.1 * height
        self.ax.set_ylim(lower, upper)
        # Request a redraw of the existing figure from inside the callback.
        self.fig.canvas.draw_idle()

    def show(self):
        """Display the figure (blocks until the window is closed)."""
        plt.show()


# Model parameters.
# A1..A4: rate coefficients of the exchange terms in the equations below.
# NOTE(review): their physical meaning/units are not stated here -- confirm.
A1 = 0.36
A2 = 0.48
A3 = 0.0148
A4 = 1.e-8
u = 1.6e-3      # convection coefficient of the bulk equation
K_Fr = 386      # Freundlich constant K
n_Fr = 0.33     # Freundlich exponent n

L = 2.          # domain length
nx = 100        # number of cells

t_tot = 20 * 60  # end time of the simulation (presumably 20 min in seconds)
dt = 0.5         # initial time step; it is adapted inside the time loop

c_in = 1.79e-3   # bulk concentration imposed at the inlet (left boundary)

# Create the mesh and the unknowns of the equations.
mesh = Grid1D(nx=nx, Lx=L)

# c_f starts at a tiny positive value rather than 0 because the linearized
# Freundlich term uses c_f ** (n_Fr - 1.), which is singular at c_f = 0.
# c_b is given the same initial value.
c_b = CellVariable(mesh=mesh, hasOld=True,
                   value=1.79e-12, name=r"$\bar{c}$")
c_f = CellVariable(mesh=mesh, hasOld=True,
                   value=1.79e-12, name=r"$c_{f}$")
q_b = CellVariable(mesh=mesh, hasOld=True,
                   value=0., name=r"$\bar{q}$")

# Linearize the Freundlich isotherm q_s = K_Fr * c_f**n_Fr around the current
# iterate of c_f:
#     q_s(c) ~= q_s0 + q_s1 * c,  with  q_s1 = dq_s/dc_f,  q_s0 = q_s - q_s1 * c_f
# FiPy builds these as lazy expressions of c_f, so they follow the latest
# iterate on every sweep. When a sweep converges (new c_f == old c_f) the
# linearized value equals q_s, so q_s1 only affects the convergence rate.
q_s = K_Fr * c_f ** n_Fr
# Derivative of K_Fr * c**n_Fr with respect to c.
q_s1 = (K_Fr * n_Fr) * c_f ** (n_Fr - 1.)
q_s0 = q_s - q_s1 * c_f

# Bulk equation: convection plus exchange with the film, -A1 * (c_b - c_f).
# The exchange uses ImplicitSourceTerms (like the film equation) so that the
# coupled solve below treats c_b and c_f implicitly.
# (VanLeerConvectionTerm was tried instead of PowerLawConvectionTerm and did
# not help.)
eq_bulk = (TransientTerm(var=c_b) ==
           - PowerLawConvectionTerm(coeff=[u], var=c_b)
           - (ImplicitSourceTerm(coeff=A1, var=c_b) -
              ImplicitSourceTerm(coeff=A1, var=c_f)))

# Film equation: exchange with the bulk, +A2 * (c_b - c_f), minus the
# film/solid exchange A3 * (q_s - q_b) with q_s replaced by its linearization.
eq_film = (TransientTerm(var=c_f) ==
           (ImplicitSourceTerm(coeff=A2, var=c_b) -
            ImplicitSourceTerm(coeff=A2, var=c_f)) -
           ((A3 * q_s0 +
             ImplicitSourceTerm(coeff=A3 * q_s1, var=c_f)) -
            ImplicitSourceTerm(coeff=A3, var=q_b)))

# Solid equation: dq_b/dt = A4 * (q_s - q_b), again with linearized q_s.
eq_surf_sol = (TransientTerm(var=q_b) ==
               (A4 * q_s0 +
                ImplicitSourceTerm(coeff=A4 * q_s1, var=c_f)) -
               ImplicitSourceTerm(coeff=A4, var=q_b))

# Boundary conditions:
# - Dirichlet at the inlet (left) for the bulk concentration only
c_b.constrain(c_in, mesh.facesLeft)
# - zero gradient (Neumann) for all components at the outlet (right)
c_b.faceGrad.constrain(0., mesh.facesRight)
c_f.faceGrad.constrain(0., mesh.facesRight)
q_b.faceGrad.constrain(0., mesh.facesRight)

# Couple the three equations so they are solved together in each sweep.
eq = eq_bulk & eq_film & eq_surf_sol

# Select the linear solver (several were tried; the LU solver is used).
solver = fipy.solvers.LinearLUSolver()

# Time-loop controls.
t = 0.
max_residual_nonlin = 1e-6   # sweep residual below which a step is accepted
n_max_allowed_iter = 50      # max sweeps per step before the step is retried
dt_min = 1e-8                # abort instead of halving dt forever
restart = False
plt.close('all')
fig_h, ax_h = plt.subplots()
coord = mesh.cellCenters.value[0]

while t < t_tot:

    if restart:
        dt /= 2.
        if dt < dt_min:
            raise RuntimeError("Time step fell below dt_min = " + str(dt_min) +
                               " without nonlinear convergence")
        # Reset the unknowns to the last accepted solution before retrying.
        for var in [c_b, c_f, q_b]:
            var.value = var.old.value
        restart = False

    n_iter_nonlin = 0
    residual = 999.
    # Sweep until the residual is small enough or the iteration cap is hit.
    while residual > max_residual_nonlin and n_iter_nonlin < n_max_allowed_iter:
        residual = eq.sweep(dt=dt, solver=solver)
        n_iter_nonlin += 1

        # Disabled check: retry the step with a smaller dt when any
        # concentration becomes negative.
        # if any(np.hstack([c_b.value, c_f.value, q_s.value, q_b.value]) < 0.):
        #     print("Negative concentrations, restart this timestep with smaller dt = " + str(dt / 2.))
        #     restart = True
        #     break

    if restart:
        continue

    if residual > max_residual_nonlin:
        print("********************************************************************")
        print("Nonlinear did not converge, restart this timestep with smaller dt = " + str(dt / 2.))
        print("********************************************************************")
        restart = True
        continue

    # Step accepted: advance time by the dt the sweeps actually used, and
    # only then enlarge dt for the next step if convergence was quick.
    dt_used = dt
    t += dt_used
    if n_iter_nonlin <= 2:
        dt = 1.1 * dt

    plt.cla()
    line_h = [plt.plot(coord, c_b.value, label=r'$\bar{c}$')[0],
              plt.plot(coord, c_f.value, label=r'$c_{f}$')[0]]
    plt.legend(handles=line_h)
    # NOTE(review): c_b equals c_in at the inlet, so an upper limit of
    # c_in / 2 clips the profile there -- confirm this zoom is intended.
    plt.ylim(0., c_in / 2.)
    plt.draw()
    plt.pause(0.015)

    # The accepted solution becomes the "old" value for the next step.
    c_b.updateOld()
    c_f.updateOld()
    q_b.updateOld()

    print("time =" + str(t) + ' with dt is ' + str(dt_used) + ' number of iterations is ' + str(n_iter_nonlin))