// ---------------------------------------------------------------------
//
// Copyright (C) 2013 - 2015 by the deal.II authors
//
// This file is part of the deal.II library.
//
// The deal.II library is free software; you can use it, redistribute
// it, and/or modify it under the terms of the GNU Lesser General
// Public License as published by the Free Software Foundation; either
// version 2.1 of the License, or (at your option) any later version.
// The full text of the license can be found in the file LICENSE at
// the top level of the deal.II distribution.
//
// ---------------------------------------------------------------------

#include <deal.II/grid/tria.h>
#include <deal.II/grid/tria_iterator.h>
#include <deal.II/grid/tria_accessor.h>
#include <deal.II/base/tensor.h>
#include <deal.II/lac/vector.h>
#include <cmath>

#include "manifold.h"

// Construct the manifold around the given center point. The periodicity
// vector produced by get_periodicity() is handed to the ChartManifold base
// class, and the center is stored for use by the coordinate maps below.
template <int dim, int spacedim>
RotatedSphericalManifold<dim,spacedim>::RotatedSphericalManifold(const Point<spacedim> center)
   :
   ChartManifold<dim,spacedim,spacedim>(get_periodicity()),
   center(center)
{}



// Return the periodicity vector that the constructor passes on to the
// ChartManifold base class.
//
// Only the last chart coordinate is marked as periodic, with period 2*pi.
// In two dimensions that coordinate is theta. In three dimensions it is phi;
// theta ought to be confined to [0, pi], but there is currently no way to
// express such a bound, so it is left alone.
template <int dim, int spacedim>
Tensor<1,spacedim>
RotatedSphericalManifold<dim,spacedim>::get_periodicity()
{
   // Only the last entry is assigned below; the remaining entries keep the
   // value given by default construction (presumably zero -- TODO confirm).
   Tensor<1,spacedim> period;
   const unsigned int last_coordinate = spacedim-1;
   period[last_coordinate] = 2*numbers::PI;
   return period;
}


// Compute a new point on the manifold from the weighted points in quad.
//
// For spacedim == 2 the computation is delegated to
// ChartManifold::get_new_point(). Otherwise the weighted average of the
// points is projected radially (as seen from the center) onto the sphere
// whose radius is the weighted average of the points' distances from the
// center.
//
// If the weighted average coincides with the center, the radial direction
// is undefined; in that case the center itself is returned instead of
// dividing by zero and producing a NaN point.
template <int dim, int spacedim>
Point<spacedim>
RotatedSphericalManifold<dim,spacedim>::get_new_point(const Quadrature<spacedim> &quad) const
{
   if (spacedim == 2)
      return ChartManifold<dim,spacedim,spacedim>::get_new_point(quad);
   else
   {
      double rho_average = 0;
      Point<spacedim> mid_point;
      for (unsigned int i=0; i<quad.size(); ++i)
      {
         rho_average += quad.weight(i)*(quad.point(i)-center).norm();
         mid_point += quad.weight(i)*quad.point(i);
      }
      // Direction from the center to the weighted average of the points.
      Tensor<1,spacedim> R = mid_point-center;
      const double R_norm = R.norm();
      // Guard the division below: a zero-length direction cannot be
      // rescaled to the target radius.
      if (R_norm == 0.0)
         return center;
      // Scale it to have radius rho_average
      R *= rho_average/R_norm;
      // And return it.
      return center+R;
   }
}



// Map a point given in chart coordinates (rho, theta, phi) to Cartesian
// space.
//
// The standard spherical-to-Cartesian formulas are evaluated first, and the
// result is then turned by a quarter-turn about the x axis that sends the
// z axis onto the negative y axis, i.e. (x, y, z) -> (x, -z, y). The center
// of the manifold is added at the end.
//
// A radius that is not larger than 1e-10 yields the center itself. Only
// spacedim == 3 is implemented; a negative radius trips an assertion.
template <int dim, int spacedim>
Point<spacedim>
RotatedSphericalManifold<dim,spacedim>::push_forward(const Point<spacedim> &spherical_point) const
{
   Assert(spherical_point[0] >=0.0,
          ExcMessage("Negative radius for given point."));
   const double rho = spherical_point[0];
   const double theta = spherical_point[1];

   // Offset from the center; remains at its default value for tiny radii.
   Point<spacedim> cartesian;
   if (rho > 1e-10)
   {
      switch (spacedim)
      {
         case 3:
         {
            const double phi = spherical_point[2];

            // Components in the unrotated spherical frame.
            const double x = rho*std::sin(theta)*std::cos(phi);
            const double y = rho*std::sin(theta)*std::sin(phi);
            const double z = rho*std::cos(theta);

            // Apply the axis permutation (x, y, z) -> (x, -z, y).
            cartesian[0] =  x;
            cartesian[1] = -z;
            cartesian[2] =  y;
            break;
         }
         default:
            Assert(false, ExcNotImplemented());
      }
   }
   return cartesian+center;
}

// Map a point in Cartesian space to chart coordinates (rho, theta, phi).
//
// The offset from the center is first put through the axis permutation
// (x, y, z) -> (x, z, -y), which undoes the one applied in push_forward().
// The result is then converted with the usual formulas: rho is the length
// of the offset, theta is measured from the (unrotated) z axis, and phi is
// shifted into [0, 2*pi). Only spacedim == 3 is implemented.
template <int dim, int spacedim>
Point<spacedim>
RotatedSphericalManifold<dim,spacedim>::pull_back(const Point<spacedim> &space_point) const
{
   const Tensor<1,spacedim> offset = space_point-center;

   // Undo the axis permutation applied by push_forward().
   const Point<spacedim> unrotated (offset[0], offset[2], -offset[1]);

   Point<spacedim> spherical;
   spherical[0] = offset.norm();   // rho

   switch (spacedim)
   {
      case 3:
      {
         const double x = unrotated[0];
         const double y = unrotated[1];
         const double z = unrotated[2];

         // atan2 yields values in [-pi, pi]; move negative angles up by one
         // period since phi is periodic.
         double phi = std::atan2(y, x);
         if (phi < 0)
            phi += 2*numbers::PI;
         spherical[2] = phi;

         // theta: angle between the offset and the unrotated z axis.
         spherical[1] = std::atan2(std::sqrt(x*x+y*y), z);
         break;
      }
      default:
         Assert(false, ExcInternalError());
   }
   return spherical;
}


// Return the derivative of push_forward() with respect to the chart
// coordinates (rho, theta, phi), evaluated at spherical_point.
//
// Row i holds the partial derivatives of the i-th Cartesian component; the
// columns are ordered (rho, theta, phi). For a radius not larger than 1e-10
// the default-constructed form is returned unchanged. Only spacedim == 3 is
// implemented; a negative radius trips an assertion.
template <int dim, int spacedim>
DerivativeForm<1,spacedim,spacedim>
RotatedSphericalManifold<dim,spacedim>::push_forward_gradient(const Point<spacedim> &spherical_point) const
{
   Assert(spherical_point[0] >=0.0,
          ExcMessage("Negative radius for given point."));
   const double rho = spherical_point[0];
   const double theta = spherical_point[1];

   DerivativeForm<1,spacedim,spacedim> jacobian;
   if (rho > 1e-10)
   {
      switch (spacedim)
      {
         case 3:
         {
            const double phi = spherical_point[2];
            const double sin_theta = std::sin(theta);
            const double cos_theta = std::cos(theta);
            const double sin_phi = std::sin(phi);
            const double cos_phi = std::cos(phi);

            // First component: rho*sin(theta)*cos(phi)
            jacobian[0][0] =      sin_theta*cos_phi;
            jacobian[0][1] =  rho*cos_theta*cos_phi;
            jacobian[0][2] = -rho*sin_theta*sin_phi;

            // Second component: -rho*cos(theta)
            jacobian[1][0] =    -cos_theta;
            jacobian[1][1] = rho*sin_theta;
            jacobian[1][2] = 0;

            // Third component: rho*sin(theta)*sin(phi)
            jacobian[2][0] =     sin_theta*sin_phi;
            jacobian[2][1] = rho*cos_theta*sin_phi;
            jacobian[2][2] = rho*sin_theta*cos_phi;

            break;
         }
         default:
            Assert(false, ExcInternalError());
      }
   }
   return jacobian;
}


// Explicitly instantiate the class template for dim = 2, spacedim = 3.
template class RotatedSphericalManifold<2,3>;
