# This demo program solves the equations of static linear elasticity on the
# unit square using a discontinuous Galerkin (interior penalty) formulation.

__author__ = "Jake Ostien (jtostie@sandia.gov)"
__date__ = "2008-10-13"
__copyright__ = "Copyright (C) 2008 Jake Ostien"
__license__  = "GNU LGPL Version 2.1"

from dolfin import *

# Build a coarse mesh of the unit square, plus a piecewise-linear
# discontinuous (DG) vector-valued finite element on triangles
mesh = UnitSquare(1, 1)
element = VectorElement("DG", "triangle", 1)

# Sub domain for x = 0
class BCX0(SubDomain):
    def inside(self, x, on_boundary):
        return bool(x[0] < DOLFIN_EPS)

# Sub domain for the right edge, x = 1
class BCX1(SubDomain):
    """Mark points whose first coordinate lies within DOLFIN_EPS of 1."""

    def inside(self, x, on_boundary):
        # Purely geometric test; the on_boundary flag is not consulted.
        distance = abs(x[0] - 1.0)
        return bool(distance < DOLFIN_EPS)

# Sub domain for the bottom edge, y = 0
class BCY0(SubDomain):
    """Mark points whose second coordinate is below DOLFIN_EPS."""

    def inside(self, x, on_boundary):
        # Purely geometric test; the on_boundary flag is not consulted.
        y = x[1]
        return bool(y < DOLFIN_EPS)


# Material parameters: Young's modulus and Poisson's ratio
E = 10.0
nu = 0.3

# Penalty parameter for the interior-facet stabilisation term
alpha = 4.0

# Lame parameters computed from E and nu
mu = E / (2*(1 + nu))
lmbda = E*nu / ((1 + nu)*(1 - 2*nu))

# Prescribed displacement value used for the tension (pull) boundary condition
pull = Function(mesh, 0.1)

# Source term f of the linear form, initialised to zero
f = Function(element, mesh, 2, 0.0)

# Scalar zero, used for the homogeneous Dirichlet conditions
zero = Function(mesh, 0.0)

# Test and trial functions of the variational problem
v = TestFunction(element)
u = TrialFunction(element)

# Facet normal and mesh-size quantity used by the interior-facet terms
n = FacetNormal("triangle", mesh)
h = AvgMeshSize("triangle", mesh)

def epsilon(v):
    """Return the symmetric gradient 0.5*(grad(v) + grad(v)^T) of v."""
    grad_v = grad(v)
    return 0.5*(grad_v + transp(grad_v))

def sigma(v):
    """Return the stress 2*mu*epsilon(v) + lmbda*tr(epsilon(v))*I.

    Reads the module-level parameters mu and lmbda; the identity's size is
    taken from len(v).
    """
    strain = epsilon(v)
    identity = Identity(len(v))
    return 2.0*mu*strain + lmbda*mult(trace(strain), identity)

# Bilinear form: symmetric interior penalty DG discretisation.
# On an interior facet jump(w) = w('+') - w('-'), and n('-') = -n('+').
# The consistency/symmetry terms therefore use the AVERAGE stress applied to
# n('+').  Averaging sigma*n directly would give 0.5*(sigma('+') -
# sigma('-'))*n('+'), i.e. half the traction jump, which makes the method
# inconsistent.
a = dot(grad(v), sigma(u))*dx \
    - dot(jump(v), mult(avg(sigma(u)), n('+')))*dS \
    - dot(jump(u), mult(avg(sigma(v)), n('+')))*dS \
    + alpha*E/h('+')*dot(jump(v), jump(u))*dS

# Linear form: source-term contribution
L = dot(v, f)*dx

# Sub systems 0 and 1 of the vector field, used to constrain one
# displacement component at a time
X = SubSystem(0)
Y = SubSystem(1)

# Dirichlet boundary conditions:
#   x = 0 : component X fixed to zero
#   y = 0 : component Y fixed to zero
#   x = 1 : component X set to the tension value `pull`
b1 = BCX0()
dbc1 = DirichletBC(zero, mesh, b1, X, geometric)
b2 = BCY0()
dbc2 = DirichletBC(zero, mesh, b2, Y, geometric)
b3 = BCX1()
dbc3 = DirichletBC(pull, mesh, b3, X, geometric)
bcs = [dbc1, dbc2, dbc3]

# Assemble the linear system, then impose the boundary conditions in order
A = assemble(a, mesh)
b = assemble(L, mesh)
for bc in bcs:
    bc.apply(A, b, a)

# Solve for the vector of degrees of freedom
x = Vector()
solve(A, x, b)

# Wrap the solution vector as a DG function.  This rebinds `u`, which named
# the trial function above; the bilinear form has already been assembled.
u = Function(element, mesh, x)

# Project the discontinuous solution onto continuous P1 for output/plotting
P1 = VectorElement("CG", "triangle", 1)
u_proj = project(u, P1)

# Save solution to VTK format
vtk_file = File("elasticity.pvd")
vtk_file << u_proj

# Save solution to XML format
xml_file = File("elasticity.xml")
xml_file << u_proj

# Plot solution
plot(u_proj, mode="displacement")
interactive()
