import dolfin
import numpy
import sympy


def mi_factorial(alpha):
    """Return the multi-index factorial alpha! = alpha[0]! * alpha[1]! * ...

    Each entry's factorial is evaluated with ``sympy.factorial`` and the
    results are multiplied together.

    alpha -- non-empty sequence of multi-index entries.

    Raises ValueError (a subclass of Exception) when ``alpha`` is empty.
    """
    if len(alpha) == 0:
        # The call form of ``raise`` is valid on both Python 2 and Python 3;
        # the old ``raise Exception, "msg"`` spelling is Python-2-only.
        raise ValueError("mi_factorial requires a non-empty multi-index")

    factorials = [sympy.factorial(entry) for entry in alpha]

    # Explicit left fold: the builtin ``reduce`` does not exist on Python 3.
    product = factorials[0]
    for term in factorials[1:]:
        product = product * term
    return product
def mi_sum(alpha, beta):
    """Return the entrywise sum of the multi-indices alpha and beta as a list.

    The result has one entry per entry of ``alpha``; ``beta`` must be at
    least as long as ``alpha`` (extra entries of ``beta`` are ignored).
    """
    total = []
    for position, alpha_entry in enumerate(alpha):
        total.append(alpha_entry + beta[position])
    return total

# let's get the local-global mapping on a given dolfin mesh
# with Bernstein polynomials of a given degree

def berntuples(k, n):
    """Enumerate the multi-indices of length n + 1 whose entries sum to k.

    The result is a list of lists.  The leading entry runs from ``k`` down
    to ``0``; for each leading entry ``k - i`` the remaining entries are
    the multi-indices of ``berntuples(i, n - 1)``, in that order.

    k -- total degree (sum of the entries).
    n -- one less than the multi-index length (n == 0 gives ``[[k]]``).

    Returns an empty list when ``n >= 1`` and ``k < 0``.
    Raises ValueError when ``n`` is negative.
    """
    if n < 0:
        # Without this guard the recursion below would never reach n == 0.
        raise ValueError("berntuples requires n >= 0")
    if n == 0:
        return [[k]]

    # One recursive rule covers every n >= 1, so no n == 1 special case is
    # needed.  Appending to one list avoids the builtin ``reduce`` (absent
    # on Python 3) and repeated list concatenation.
    result = []
    for i in range(k + 1):
        lead = [k - i]
        for tail in berntuples(i, n - 1):
            result.append(lead + tail)
    return result

def bern_tuples(k, n):
    """Return the multi-indices of ``berntuples(k, n)`` as a list of tuples.

    Tuples are hashable, so the entries can be used as dictionary keys.
    """
    # Build a real list: callers take len() of the result and index into
    # it, which a lazy ``map`` object (Python 3) would not support.
    return [tuple(multi_index) for multi_index in berntuples(k, n)]

def idx_tuples(bts):
    """Build the inverse lookup of a list of tuples.

    Returns a dictionary mapping each tuple in ``bts`` to its position in
    the list.  If a tuple occurs more than once, its last position wins.
    """
    return dict((entry, position) for position, entry in enumerate(bts))

def nnz(t):
    """Count the entries of ``t`` that are truthy (i.e. nonzero)."""
    return sum(1 for entry in t if entry)

def first_nonzero_index(t):
    """Return the position of the first truthy (nonzero) entry of ``t``.

    Returns None when every entry is zero or ``t`` is empty.
    """
    for position, entry in enumerate(t):
        if entry:
            return position
    return None

def first_zero_index(t):
    """Return the position of the first falsy (zero) entry of ``t``.

    Returns None when every entry is nonzero or ``t`` is empty.
    """
    for position, entry in enumerate(t):
        if not entry:
            return position
    return None

def topologize_tuples(bts):
    """Group the list positions of 3-entry multi-indices by mesh entity.

    Returns a nested dictionary ``dofmap[dim][entity]`` of lists holding
    positions in ``bts``:

    * exactly one nonzero entry  -> ``dofmap[0][i]`` (vertex), where ``i``
      is the position of that nonzero entry;
    * exactly two nonzero entries -> ``dofmap[1][i]`` (edge), where ``i``
      is the position of the zero entry;
    * anything else              -> ``dofmap[2][0]`` (cell interior).
    """
    ordering = idx_tuples(bts)

    dofmap = {
        0: {0: [], 1: [], 2: []},
        1: {0: [], 1: [], 2: []},
        2: {0: []},
    }

    for t in bts:
        nonzeros = nnz(t)
        if nonzeros == 1:
            # Two entries are zero: the tuple sits on a vertex.
            bucket = dofmap[0][first_nonzero_index(t)]
        elif nonzeros == 2:
            bucket = dofmap[1][first_zero_index(t)]
        else:
            bucket = dofmap[2][0]
        bucket.append(ordering[t])

    return dofmap
        
def bern_mass_exact(d, m, n):
    """Build the exact Bernstein mass matrix between degrees m and n.

    Rows are indexed by ``bern_tuples(m, d)`` and columns by
    ``bern_tuples(n, d)``.  The entry for multi-indices (alpha, beta) is

        m! n! (alpha + beta)! / (alpha! beta! (m + n + d)!)

    evaluated with sympy factorials.  Returns a ``sympy.Matrix``.
    """
    m_factorial = sympy.factorial(m)
    n_factorial = sympy.factorial(n)
    total_factorial = sympy.factorial(m + n + d)

    row_indices = bern_tuples(m, d)
    col_indices = bern_tuples(n, d)

    rows = []
    for alpha in row_indices:
        row = []
        for beta in col_indices:
            numerator = m_factorial * n_factorial \
                * mi_factorial(mi_sum(alpha, beta))
            value = numerator / mi_factorial(alpha) \
                / mi_factorial(beta) / total_factorial
            row.append(value)
        rows.append(row)

    return sympy.Matrix(rows)

def bern_mass_numeric(d, m, n=-1):
    """Return ``bern_mass_exact(d, m, n)`` as a float64 numpy array.

    d -- number of dimensions passed through to ``bern_mass_exact``.
    m -- row degree.
    n -- column degree; the default -1 means "same as m".
    """
    if n == -1:
        n = m

    # Scale by 1.0 and wrap in an array first, then copy entry by entry
    # into a float64 array of the same shape.
    converted = numpy.array(1.0 * bern_mass_exact(d, m, n))
    result = numpy.zeros(converted.shape, dtype=numpy.float64)
    result[:, :] = converted
    return result

def poly_dim(d, k):
    """Return the number of degree-k multi-indices in d dimensions.

    This is the binomial coefficient

        C(k + d, d) = (k + 1)(k + 2)...(k + d) / d!

    computed with plain integer arithmetic.

    d -- number of dimensions (non-negative integer).
    k -- degree (integer).  ``k = -1`` gives 0 for ``d >= 1``, which the
         callers in this file rely on for their final slice update.

    Returns a Python int.
    """
    numerator = 1
    denominator = 1
    for i in range(1, d + 1):
        numerator *= k + i
        denominator *= i
    # A product of d consecutive integers is always divisible by d!, so
    # floor division is exact here (on Python 2 and Python 3 alike).
    return int(numerator // denominator)

def ltg2d(m, k):
    """Generate the local-global mapping for Bernstein polynomials of
    degree k on a mesh m.

    m -- dolfin mesh; it is initialized and, if not already ordered,
         ordered in place.
    k -- polynomial degree.

    Returns a 2d integer numpy array ``ltg[e, i]`` holding the global node
    number of the i:th local node on cell ``e``.  Global numbers are
    assigned in three consecutive ranges: vertex nodes (by vertex index),
    then edge nodes (grouped by edge index), then cell-interior nodes
    (grouped by cell index).
    """
    # '//' keeps the column count an integer on Python 3 as well; the
    # product of two consecutive integers is even, so nothing is lost.
    ltg_array = numpy.zeros((m.numCells(), (k + 1) * (k + 2) // 2), "i")

    # make sure I have the edges, etc ordered
    m.init()
    if not m.ordered():
        m.order()

    bts = bern_tuples(k, 2)
    localdof = topologize_tuples(bts)

    dof_per_vertex = len(localdof[0][0])
    dof_per_edge = len(localdof[1][0])
    dof_per_cell = len(localdof[2][0])

    nv = m.numVertices()
    ne = m.numEdges()

    # First global number of the edge range and of the cell range.
    edge_offset = dof_per_vertex * nv
    cell_offset = edge_offset + dof_per_edge * ne

    for c in dolfin.cells(m):
        cell_index = c.index()

        # Looping over the (possibly empty) vertex dof list, instead of
        # indexing its first entry, also handles k == 0, where no local
        # node sits on a vertex.  For k >= 1 each vertex has exactly one
        # local node, so the numbering is unchanged.
        for vcur, v in enumerate(dolfin.vertices(c)):
            for ldofcur, ldof in enumerate(localdof[0][vcur]):
                ltg_array[cell_index, ldof] = \
                    dof_per_vertex * v.index() + ldofcur

        # compare to v loop, but also need loop over internal dof
        for ecur, e in enumerate(dolfin.edges(c)):
            for ldofcur, ldof in enumerate(localdof[1][ecur]):
                ltg_array[cell_index, ldof] = \
                    edge_offset + dof_per_edge * e.index() + ldofcur

        for ldofcur, ldof in enumerate(localdof[2][0]):
            ltg_array[cell_index, ldof] = \
                cell_offset + dof_per_cell * cell_index + ldofcur

    return ltg_array
    
class BernMass(object):
    """Explicitly assembled Bernstein mass matrix of degree m.

    Attributes:
      d, m  -- dimension and polynomial degree given to the constructor.
      dfact -- d! as a float.
      Mref  -- reference mass matrix from ``bern_mass_numeric(d, m)``.

    ``dimension`` and ``assemble_matrix`` count vertex, edge and face
    nodes and use ``ltg2d``, i.e. they are written for 2D meshes.
    """

    def __init__(self, d, m):
        self.d, self.m = d, m
        self.dfact = float(sympy.factorial(self.d))
        self.Mref = bern_mass_numeric(d, m)

    def element_mass(self, c):
        """Return the mass matrix of cell ``c``: Mref * d! * c.volume()."""
        return self.Mref * self.dfact * c.volume()

    def dimension(self, mesh):
        """Return the global number of nodes on ``mesh`` as an integer.

        One node per vertex, ``m - 1`` per edge (if m > 1) and
        ``(m - 2)(m - 1) / 2`` per face (if m > 2).
        """
        dim_cur = mesh.numVertices()
        if self.m > 1:
            dim_cur += mesh.numEdges() * (self.m - 1)
        if self.m > 2:
            # (m-2)(m-1) is always even, so '//' is exact and keeps the
            # result an integer on Python 3 as well.
            dim_cur += mesh.numFaces() * ((self.m - 2) * (self.m - 1) // 2)
        return dim_cur

    def assemble_matrix(self, mesh):
        """Assemble and return the global mass matrix on ``mesh``.

        Each cell's element matrix is added into a
        ``dolfin.uBLASSparseMatrix`` through the local-global map from
        ``ltg2d``; ``apply()`` is called before the matrix is returned.
        """
        dim = self.dimension(mesh)

        A = dolfin.uBLASSparseMatrix(dim, dim)
        ltg = ltg2d(mesh, self.m)

        for c in dolfin.cells(mesh):
            elmass = self.element_mass(c)

            ltgcur = ltg[c.index(), :]
            for i, row in enumerate(ltgcur):
                for j, col in enumerate(ltgcur):
                    A[row, col] = A[row, col] + elmass[i, j]

        A.apply()

        return A

def apply_elevate_1d_transpose(n, x):
    """Combine neighbouring entries of ``x`` with degree-n weights.

    For j = 0 .. n-1 the result holds

        y[j] = ((n - j) * x[j] + (j + 1) * x[j + 1]) / n

    n -- number of entries to compute; ``x`` needs at least n + 1 entries.
    x -- 1d numpy array.

    Returns a float64 array of length ``x.shape[0] - 1``; entries beyond
    index n - 1 (if any) are left at zero.
    Raises IndexError when ``x`` has fewer than n + 1 entries.
    """
    y = numpy.zeros((x.shape[0] - 1,), numpy.float64)
    if n <= 0:
        # Nothing to compute (matches an empty ``range(n)`` loop).
        return y
    if x.shape[0] < n + 1:
        raise IndexError("x must have at least n + 1 entries")

    j = numpy.arange(n)
    # float(n) forces true division even for integer input on Python 2.
    y[:n] = ((n - j) * x[:n] + (j + 1) * x[1:n + 1]) / float(n)
    return y


class BernMassAction2D(dolfin.uBLASKrylovMatrix):
    """Matrix-free action of the 2D Bernstein mass matrix.

    The global vector is scattered to the cells through the local-global
    map from ``ltg2d``, the element mass matrix is applied cell by cell,
    and the cell results are summed back into a global vector.

    ``multold`` applies the dense reference matrix ``Mref``;
    ``mult`` applies ``applymasslocal`` instead.  Both scale by
    ``2.0 * cell.volume()``.
    """

    def __init__(self, mesh, degree):
        # NOTE(review): the dolfin.uBLASKrylovMatrix initializer is not
        # called here (it never was); confirm whether the wrapper needs
        # it before this object is handed to a Krylov solver.

        # Keep the mesh: mult/multold iterate over its cells.  They used
        # to read a module-level ``mesh`` global, which only exists when
        # this file is run as a script.
        self.mesh = mesh
        self.degree = degree
        self.ltg = ltg2d(mesh, degree)

        # '//' keeps the counts integral on Python 3; (degree-2)(degree-1)
        # and (degree+1)(degree+2) are even, so the values are unchanged.
        numdof = mesh.numVertices()
        if degree > 1:
            numdof += (degree - 1) * mesh.numEdges()
            if degree > 2:
                numdof += (degree - 2) * (degree - 1) // 2 * mesh.numCells()
        self.numdof = numdof
        self.localnumdof = (degree + 1) * (degree + 2) // 2

        # I should set up local storage to
        # hold the scattered degrees of freedom x
        self.xloc = numpy.zeros((mesh.numCells(), self.localnumdof),
                                numpy.float64)

        self.bglob = numpy.zeros((self.numdof, 1), numpy.float64)
        self.bloc = numpy.zeros(self.xloc.shape, numpy.float64)

        self.Mref = bern_mass_numeric(2, degree)

        # get an array of 1d mass matrices
        self.mass1d = [bern_mass_numeric(1, degree, degree - beta1)
                       for beta1 in range(degree + 1)]

        # set up local store y , z for application algorithm
        # (not read by the methods of this class)
        self.y = numpy.zeros((self.localnumdof, 1), numpy.float64)
        self.zold = numpy.zeros(self.y.shape, numpy.float64)
        self.znew = numpy.zeros(self.y.shape, numpy.float64)
        self.elz = numpy.zeros(self.y.shape, numpy.float64)

    def size(self, dim):
        """Return the number of global nodes (the same for every ``dim``)."""
        return self.numdof

    def _scatter(self, x):
        """Copy the entries of global vector ``x`` into ``self.xloc``."""
        for c in dolfin.cells(self.mesh):
            cellnum = c.index()
            ltgcur = self.ltg[cellnum, :]
            for i in range(len(ltgcur)):
                self.xloc[cellnum, i] = x[ltgcur[i]]

    def _gather(self, b):
        """Sum ``self.bloc`` into ``self.bglob`` and store the result in ``b``."""
        self.bglob *= 0
        for c in dolfin.cells(self.mesh):
            cellnum = c.index()
            ltgcur = self.ltg[cellnum, :]
            for i in range(len(ltgcur)):
                self.bglob[ltgcur[i]] += self.bloc[cellnum, i]

        # put values into b
        b.set(self.bglob)

    def multold(self, x, b):
        """Set ``b`` to the mass matrix times ``x`` using the dense ``Mref``."""
        # first, scatter x to elements
        self._scatter(x)

        # now, apply element mass matrix to each cell,
        # storing result in self.bloc
        for c in dolfin.cells(self.mesh):
            self.bloc[c.index(), :] = 2.0 \
                * numpy.dot(self.Mref, self.xloc[c.index(), :]) \
                * c.volume()

        # now, gather bloc elements into b
        self._gather(b)

    def mult(self, x, b):
        """Set ``b`` to the mass matrix times ``x`` using ``applymasslocal``."""
        # first, scatter x to elements
        self._scatter(x)

        # now, apply element mass matrix to each cell,
        # storing result in self.bloc
        for c in dolfin.cells(self.mesh):
            cellnum = c.index()
            self.bloc[cellnum, :] = 2.0 * c.volume() \
                * self.applymasslocal(self.degree, self.degree,
                                      self.xloc[cellnum, :])

        # now, gather bloc elements into b
        self._gather(b)

    def applymasslocal(self, m, n, x):
        """Apply the local mass matrix to the coefficient vector ``x``.

        ``x`` is consumed in blocks, starting from its tail: the block for
        ``beta1`` is multiplied by ``self.mass1d[beta1]`` and then passed
        repeatedly through ``apply_elevate_1d_transpose``; every stage is
        accumulated into the matching slice of the result.

        m -- degree of the result; the output has poly_dim(2, m) entries.
        n -- degree of ``x``; ``x`` has poly_dim(2, n) entries.
        NOTE(review): ``self.mass1d`` is built for ``self.degree``, so this
        presumably requires m == n == self.degree -- confirm before
        calling with other values.

        Returns a float64 array of length poly_dim(2, m).
        """
        ydim = poly_dim(2, m)
        # y slice for the first stage of every beta1 (loop invariant)
        ystart0 = ydim - poly_dim(1, m)
        ystop0 = ydim

        y = numpy.zeros((ydim,), numpy.float64)

        # set up indexing into x as function of beta_1
        xstart = poly_dim(2, n) - poly_dim(1, n)
        xstop = poly_dim(2, n)

        kappa = 1.0 / (m + n + 2.0)

        for beta1 in range(n + 1):
            # figure out start and stop of xbeta1
            xbeta1 = x[xstart:xstop]

            # update start and stop
            xstop = xstart
            xstart = xstart - poly_dim(1, n - beta1 - 1)

            # get lower-dimension mass matrix
            mlower = self.mass1d[beta1]

            znew = kappa * numpy.dot(mlower, xbeta1)
            kappa *= (n - beta1) / (m + n + 1.0 - beta1)

            # set up y indexing
            ystart = ystart0
            ystop = ystop0

            y[ystart:ystop] += znew

            for alpha1 in range(m):
                # move to the preceding, one-entry-shorter slice of y
                ystop = ystart
                ystart -= poly_dim(1, m - alpha1 - 1)
                zold = znew

                znew = (alpha1 + 1 + beta1) * (m - alpha1) \
                    * apply_elevate_1d_transpose(m - alpha1, zold) \
                    / (m + n + 1.0 - alpha1 - beta1) / (alpha1 + 1.0)
                y[ystart:ystop] += znew

        return y


if __name__ == "__main__":
    # Self-check: compare the assembled mass matrix with both matrix-free
    # applications on a small mesh, then run a Krylov solve with the
    # matrix-free operator.
    N = 2
    deg = 3
    mesh = dolfin.UnitSquare(N, N)
    mesh.init()
    mesh.order()

    BM = BernMass(2, deg)
    gdim = BM.dimension(mesh)

    BMA = BernMassAction2D(mesh, deg)

    A = BM.assemble_matrix(mesh)

    x = dolfin.uBLASVector(gdim)

    x.set(numpy.arange(gdim, dtype=numpy.float64))

    y1 = dolfin.uBLASVector(gdim)
    y2 = dolfin.uBLASVector(gdim)

    A.mult(x, y1)
    BMA.multold(x, y2)

    # Largest entrywise difference: assembled matrix vs. multold.
    # Single-argument print(...) gives the same output on Python 2 and 3.
    print(max(abs((y1 - y2).array())))

    BMA.mult(x, y2)

    # Largest entrywise difference: assembled matrix vs. mult.
    print(max(abs((y1 - y2).array())))

    S = dolfin.uBLASKrylovSolver()
    S.solve(BMA, y1, x)






    
