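/* Program usage: mpiexec -n 1 toy [-help] [all TAO options] */

/* ----------------------------------------------------------------------
   Solves a small nonlinear program (Hock-Schittkowski problem 71) with
   TAO's interior-point solver (TAOIPM), using a direct LU factorization
   for the linear systems.
---------------------------------------------------------------------- */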

#include <petsctao.h>

/* Description printed by PetscInitialize() when the program is run with -help. */
static char help[] = "Solves a small nonlinearly constrained problem (Hock-Schittkowski 71) with TAO's interior-point method (TAOIPM).\n";

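/*
   User-defined application context - holds the vectors and matrices that
   main() registers with TAO together with the application-provided call-back
   routines FormFunctionGradient(), FormHessian(), FormEqualityConstraints(),
   FormInequalityConstraints(), FormEqualityJacobian(), and
   FormInequalityJacobian().
*/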

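/*
   Problem solved (Hock-Schittkowski problem 71), x = (x0,x1,x2,x3) in R^4:

   min   f(x)  = x0*x3*(x0 + x1 + x2) + x2
   s.t.  ce(x) = x0^2 + x1^2 + x2^2 + x3^2 - 40 == 0    (ne = 1 equality)
         ci(x) = x0*x1*x2*x3 - 25             >= 0    (ni = 1 inequality)
         1 <= xi <= 5,  i = 0..3                       (variable bounds)
*/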
typedef struct {
  PetscInt n;    /* number of unknowns (length of x) */
  PetscInt ne;   /* number of equality constraints */
  PetscInt ni;   /* number of inequality constraints */

  Vec      x;    /* starting point handed to TAO; main() prints it after TaoSolve() */
  Vec      xl;   /* lower variable bounds */
  Vec      xu;   /* upper variable bounds */

  Vec      ce;   /* equality constraint values, length ne */
  Vec      ci;   /* inequality constraint values, length ni */
  Vec      bl;   /* not used by this example */
  Vec      bu;   /* not used by this example */

  Mat      Ae;   /* equality constraint Jacobian, ne x n */
  Mat      Ai;   /* inequality constraint Jacobian, ni x n */
  Mat      H;    /* Hessian, n x n */
} AppCtx;

/* -------- User-defined Routines --------- */

PetscErrorCode InitializeProblem(AppCtx *);
PetscErrorCode DestroyProblem(AppCtx *);
PetscErrorCode FormFunctionGradient(Tao,Vec,PetscReal *,Vec,void *);
PetscErrorCode FormHessian(Tao,Vec,Mat,Mat, void*);
PetscErrorCode FormInequalityConstraints(Tao,Vec,Vec,void*);
PetscErrorCode FormEqualityConstraints(Tao,Vec,Vec,void*);
PetscErrorCode FormInequalityJacobian(Tao,Vec,Mat,Mat, void*);
PetscErrorCode FormEqualityJacobian(Tao,Vec,Mat,Mat, void*);



PetscErrorCode main(int argc,char **argv)
{
  PetscErrorCode     ierr;                /* used to check for functions returning nonzeros */
  Tao                tao;
  KSP                ksp;
  PC                 pc;
  AppCtx             user;                /* application context */

  ierr = PetscInitialize(&argc,&argv,(char *)0,help);if (ierr) return ierr;
  ierr = PetscPrintf(PETSC_COMM_WORLD,"\n---- TOY Problem -----\n");CHKERRQ(ierr);
  ierr = PetscPrintf(PETSC_COMM_WORLD,"Solution should be f(1,1)=-2\n");CHKERRQ(ierr);
  ierr = InitializeProblem(&user);CHKERRQ(ierr);
  ierr = TaoCreate(PETSC_COMM_WORLD,&tao);CHKERRQ(ierr);
  ierr = TaoSetType(tao,TAOIPM);CHKERRQ(ierr);
  ierr = TaoSetInitialVector(tao,user.x);CHKERRQ(ierr);
  ierr = TaoSetVariableBounds(tao,user.xl,user.xu);CHKERRQ(ierr);
  ierr = TaoSetObjectiveAndGradientRoutine(tao,FormFunctionGradient,(void*)&user);CHKERRQ(ierr);

  ierr = TaoSetEqualityConstraintsRoutine(tao,user.ce,FormEqualityConstraints,(void*)&user);CHKERRQ(ierr);
  ierr = TaoSetInequalityConstraintsRoutine(tao,user.ci,FormInequalityConstraints,(void*)&user);CHKERRQ(ierr);

  ierr = TaoSetJacobianEqualityRoutine(tao,user.Ae,user.Ae,FormEqualityJacobian,(void*)&user);CHKERRQ(ierr);
  ierr = TaoSetJacobianInequalityRoutine(tao,user.Ai,user.Ai,FormInequalityJacobian,(void*)&user);CHKERRQ(ierr);
  ierr = TaoSetHessianRoutine(tao,user.H,user.H,FormHessian,(void*)&user);CHKERRQ(ierr);
  /* ierr = TaoSetTolerances(tao,0,0,0);CHKERRQ(ierr); */

  ierr = TaoSetFromOptions(tao);CHKERRQ(ierr);

  ierr = TaoGetKSP(tao,&ksp);CHKERRQ(ierr);
  ierr = KSPGetPC(ksp,&pc);CHKERRQ(ierr);
  ierr = PCSetType(pc,PCLU);CHKERRQ(ierr);
  /*
      This algorithm produces matrices with zeros along the diagonal therefore we need to use
    SuperLU which does partial pivoting
  */
  ierr = PCFactorSetMatSolverType(pc,MATSOLVERMUMPS);CHKERRQ(ierr);
  ierr = KSPSetType(ksp,KSPPREONLY);CHKERRQ(ierr);
  ierr = KSPSetFromOptions(ksp);CHKERRQ(ierr);

  /* ierr = TaoSetTolerances(tao,0,0,0);CHKERRQ(ierr); */
  ierr = TaoSolve(tao);CHKERRQ(ierr);
  ierr = VecView(user.x, PETSC_VIEWER_STDOUT_WORLD);

  ierr = DestroyProblem(&user);CHKERRQ(ierr);
  ierr = TaoDestroy(&tao);CHKERRQ(ierr);
  ierr = PetscFinalize();
  return ierr;
}

/*
   InitializeProblem - creates the vectors and matrices stored in the context.

   Sets n = 4 unknowns with starting point x = 1 and bounds 1 <= x <= 5,
   ne = 1 equality and ni = 1 inequality constraint. It also creates the
   (ne x n) and (ni x n) constraint Jacobians and the (n x n) Hessian as
   SeqAIJ matrices, preallocated with n nonzeros per row.

   Input/Output:
   user - application context to fill; release it with DestroyProblem()
*/
PetscErrorCode InitializeProblem(AppCtx *user)
{
  PetscErrorCode ierr;

  PetscFunctionBegin;
  user->n = 4;
  ierr = VecCreateSeq(PETSC_COMM_SELF,user->n,&user->x);CHKERRQ(ierr);
  ierr = VecDuplicate(user->x,&user->xl);CHKERRQ(ierr);
  ierr = VecDuplicate(user->x,&user->xu);CHKERRQ(ierr);
  ierr = VecSet(user->x,1.0);CHKERRQ(ierr);
  ierr = VecSet(user->xl,1.0);CHKERRQ(ierr);
  ierr = VecSet(user->xu,5.0);CHKERRQ(ierr);

  user->ne = 1;
  ierr = VecCreateSeq(PETSC_COMM_SELF,user->ne,&user->ce);CHKERRQ(ierr);

  user->ni = 1;
  ierr = VecCreateSeq(PETSC_COMM_SELF,user->ni,&user->ci);CHKERRQ(ierr);

  /* bl/bu are not used by this example; give them a defined value instead of
     leaving indeterminate handles in the (stack-allocated) context. */
  user->bl = NULL;
  user->bu = NULL;

  ierr = MatCreateSeqAIJ(PETSC_COMM_SELF,user->ne,user->n,user->n,NULL,&user->Ae);CHKERRQ(ierr);
  ierr = MatCreateSeqAIJ(PETSC_COMM_SELF,user->ni,user->n,user->n,NULL,&user->Ai);CHKERRQ(ierr);
  ierr = MatSetFromOptions(user->Ae);CHKERRQ(ierr);
  ierr = MatSetFromOptions(user->Ai);CHKERRQ(ierr);

  ierr = MatCreateSeqAIJ(PETSC_COMM_SELF,user->n,user->n,user->n,NULL,&user->H);CHKERRQ(ierr);
  ierr = MatSetFromOptions(user->H);CHKERRQ(ierr);

  PetscFunctionReturn(0);
}

/*
   DestroyProblem - releases the PETSc objects created by InitializeProblem().

   Input/Output:
   user - application context whose matrix and vector handles are passed to
          MatDestroy() / VecDestroy()
*/
PetscErrorCode DestroyProblem(AppCtx *user)
{
  Mat            *mats[] = {&user->Ae,&user->Ai,&user->H};
  Vec            *vecs[] = {&user->x,&user->ce,&user->ci,&user->xl,&user->xu};
  size_t         k;
  PetscErrorCode ierr;

  PetscFunctionBegin;
  /* Matrices are destroyed first, then vectors. */
  for (k = 0; k < sizeof(mats)/sizeof(mats[0]); k++) {
    ierr = MatDestroy(mats[k]);CHKERRQ(ierr);
  }
  for (k = 0; k < sizeof(vecs)/sizeof(vecs[0]); k++) {
    ierr = VecDestroy(vecs[k]);CHKERRQ(ierr);
  }
  PetscFunctionReturn(0);
}

/*
   FormFunctionGradient - evaluates f(X) = x0*x3*(x0 + x1 + x2) + x2 and its
   gradient.

   Input:
   tao - Tao context (unused)
   X   - current iterate (entries 0..3 are read)
   ctx - user context (unused)

   Output:
   f - objective value
   G - gradient of f at X
*/
PetscErrorCode FormFunctionGradient(Tao tao, Vec X, PetscReal *f, Vec G, void *ctx)
{
  const PetscScalar *xv;
  PetscScalar       *gv;
  PetscScalar       x0,x1,x2,x3;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  /* Copy the four unknowns out so the array can be returned right away. */
  ierr = VecGetArrayRead(X,&xv);CHKERRQ(ierr);
  x0 = xv[0];
  x1 = xv[1];
  x2 = xv[2];
  x3 = xv[3];
  ierr = VecRestoreArrayRead(X,&xv);CHKERRQ(ierr);

  *f = x0*x3*(x0 + x1 + x2) + x2;

  ierr = VecGetArray(G,&gv);CHKERRQ(ierr);
  gv[0] = 2.*x0*x3 + x1*x3 + x2*x3;   /* df/dx0 */
  gv[1] = x0*x3;                      /* df/dx1 */
  gv[2] = x0*x3 + 1.;                 /* df/dx2 */
  gv[3] = x0*x0 + x0*x1 + x0*x2;      /* df/dx3 */
  ierr = VecRestoreArray(G,&gv);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   FormHessian - assembles the 4x4 Hessian used by TAO into H.

   Objective part: Hessian of x0*x3*(x0 + x1 + x2) + x2.
   Constraint parts are added only when TaoGetDualVariables() returns
   non-NULL vectors:
     DE -> DE[0] times the Hessian of sum(xi^2) - 40, i.e. 2*DE[0] on the diagonal
     DI -> DI[0] times the Hessian of x0*x1*x2*x3 - 25 (zero diagonal)
   NOTE(review): multipliers are added with a "+" sign; confirm this matches
   the Lagrangian sign convention used by TAOIPM.

   Input:
   tao  - Tao context (source of the dual variables)
   x    - current iterate (entries 0..3 are read)
   Hpre - unused
   ctx  - user context (unused)

   Output:
   H - assembled Hessian (all 16 entries written with INSERT_VALUES)
*/
PetscErrorCode FormHessian(Tao tao, Vec x, Mat H, Mat Hpre, void *ctx)
{
  Vec               DE,DI;
  const PetscScalar *ptr,*de,*di;
  const PetscInt    rows[4] = {0,1,2,3};
  /* Row-major 4x4 block. It must start at zero: entries (1,1), (1,2),
     (2,1), (2,2) and (3,3) get no objective term. They are only ever
     updated with "+=" (or not at all), so otherwise indeterminate stack
     values would be read and inserted into H. */
  PetscScalar       vals[16] = {0};
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  ierr = TaoGetDualVariables(tao,&DE,&DI);CHKERRQ(ierr);

  ierr = VecGetArrayRead(x,&ptr);CHKERRQ(ierr);

  /* Hessian of the objective */
  vals[0*4 + 0] = 2*ptr[3];
  vals[0*4 + 1] = ptr[3];
  vals[0*4 + 2] = ptr[3];
  vals[0*4 + 3] = 2*ptr[0] + ptr[1] + ptr[2];

  vals[1*4 + 0] = ptr[3];
  vals[1*4 + 3] = ptr[0];

  vals[2*4 + 0] = ptr[3];
  vals[2*4 + 3] = ptr[0];

  vals[3*4 + 0] = 2*ptr[0] + ptr[1] + ptr[2];
  vals[3*4 + 1] = ptr[0];
  vals[3*4 + 2] = ptr[0];

  /* Equality constraint contribution: 2*DE[0] on the diagonal */
  if (DE) {
    ierr = VecGetArrayRead(DE,&de);CHKERRQ(ierr);
    vals[0*4 + 0] += 2*de[0];
    vals[1*4 + 1] += 2*de[0];
    vals[2*4 + 2] += 2*de[0];
    vals[3*4 + 3] += 2*de[0];
    ierr = VecRestoreArrayRead(DE,&de);CHKERRQ(ierr);
  }

  /* Inequality constraint contribution: DI[0] * d2(x0*x1*x2*x3)/dxi dxj, i != j */
  if (DI) {
    ierr = VecGetArrayRead(DI,&di);CHKERRQ(ierr);
    vals[0*4 + 1] += di[0]*(ptr[2]*ptr[3]);
    vals[0*4 + 2] += di[0]*(ptr[1]*ptr[3]);
    vals[0*4 + 3] += di[0]*(ptr[1]*ptr[2]);

    vals[1*4 + 0] += di[0]*(ptr[2]*ptr[3]);
    vals[1*4 + 2] += di[0]*(ptr[0]*ptr[3]);
    vals[1*4 + 3] += di[0]*(ptr[0]*ptr[2]);

    vals[2*4 + 0] += di[0]*(ptr[1]*ptr[3]);
    vals[2*4 + 1] += di[0]*(ptr[0]*ptr[3]);
    vals[2*4 + 3] += di[0]*(ptr[0]*ptr[1]);

    vals[3*4 + 0] += di[0]*(ptr[1]*ptr[2]);
    vals[3*4 + 1] += di[0]*(ptr[0]*ptr[2]);
    vals[3*4 + 2] += di[0]*(ptr[0]*ptr[1]);
    ierr = VecRestoreArrayRead(DI,&di);CHKERRQ(ierr);
  }
  ierr = VecRestoreArrayRead(x,&ptr);CHKERRQ(ierr);

  ierr = MatSetValues(H,4,rows,4,rows,vals,INSERT_VALUES);CHKERRQ(ierr);

  ierr = MatAssemblyBegin(H,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  ierr = MatAssemblyEnd(H,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   FormInequalityConstraints - evaluates ci(X) = x0*x1*x2*x3 - 25.

   Input:
   tao - Tao context (unused)
   X   - current iterate (entries 0..3 are read)
   ctx - user context (unused)

   Output:
   CI - entry 0 receives the constraint value
*/
PetscErrorCode FormInequalityConstraints(Tao tao, Vec X, Vec CI, void *ctx)
{
  const PetscScalar *xv;
  PetscScalar       *cv;
  PetscScalar       prod;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  ierr = VecGetArrayRead(X,&xv);CHKERRQ(ierr);
  prod = xv[0]*xv[1]*xv[2]*xv[3];
  ierr = VecRestoreArrayRead(X,&xv);CHKERRQ(ierr);

  ierr = VecGetArray(CI,&cv);CHKERRQ(ierr);
  cv[0] = prod - 25;
  ierr = VecRestoreArray(CI,&cv);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   FormEqualityConstraints - evaluates ce(X) = x0^2 + x1^2 + x2^2 + x3^2 - 40.

   X is only read, so it is accessed with VecGetArrayRead(). The previous
   VecGetArray()/VecRestoreArray() pair marked the solver's iterate as
   modified (object state increase) on every evaluation. The access now also
   matches the other call-backs in this file.

   Input:
   tao - Tao context (unused)
   X   - current iterate (entries 0..3 are read)
   ctx - user context (unused)

   Output:
   CE - entry 0 receives the constraint value
*/
PetscErrorCode FormEqualityConstraints(Tao tao, Vec X, Vec CE,void *ctx)
{
  const PetscScalar *x;
  PetscScalar       *c;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  ierr = VecGetArrayRead(X,&x);CHKERRQ(ierr);
  ierr = VecGetArray(CE,&c);CHKERRQ(ierr);
  c[0] = x[0]*x[0] + x[1]*x[1] + x[2]*x[2] + x[3]*x[3] - 40;
  ierr = VecRestoreArrayRead(X,&x);CHKERRQ(ierr);
  ierr = VecRestoreArray(CE,&c);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   FormInequalityJacobian - fills row 0 of JI with the gradient of
   ci(X) = x0*x1*x2*x3 - 25, i.e. the product of the other three unknowns.

   Input:
   tao   - Tao context (unused)
   X     - current iterate (entries 0..3 are read)
   JIpre - unused
   ctx   - user context (unused)

   Output:
   JI - assembled 1 x 4 Jacobian
*/
PetscErrorCode FormInequalityJacobian(Tao tao, Vec X, Mat JI, Mat JIpre,  void *ctx)
{
  const PetscInt    row     = 0;
  const PetscInt    cols[4] = {0,1,2,3};
  PetscScalar       grad[4];
  const PetscScalar *xv;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  ierr = VecGetArrayRead(X,&xv);CHKERRQ(ierr);
  grad[0] = xv[1]*xv[2]*xv[3];
  grad[1] = xv[0]*xv[2]*xv[3];
  grad[2] = xv[0]*xv[1]*xv[3];
  grad[3] = xv[0]*xv[1]*xv[2];
  ierr = VecRestoreArrayRead(X,&xv);CHKERRQ(ierr);

  ierr = MatSetValues(JI,1,&row,4,cols,grad,INSERT_VALUES);CHKERRQ(ierr);
  ierr = MatAssemblyBegin(JI,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  ierr = MatAssemblyEnd(JI,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}

/*
   FormEqualityJacobian - fills row 0 of JE with the gradient of
   ce(X) = x0^2 + x1^2 + x2^2 + x3^2 - 40, i.e. 2*xi in column i.

   Input:
   tao   - Tao context (unused)
   X     - current iterate (entries 0..3 are read)
   JEpre - unused
   ctx   - user context (unused)

   Output:
   JE - assembled 1 x 4 Jacobian
*/
PetscErrorCode FormEqualityJacobian(Tao tao, Vec X, Mat JE, Mat JEpre, void *ctx)
{
  const PetscInt    row     = 0;
  const PetscInt    cols[4] = {0,1,2,3};
  PetscScalar       grad[4];
  const PetscScalar *xv;
  PetscInt          i;
  PetscErrorCode    ierr;

  PetscFunctionBegin;
  ierr = VecGetArrayRead(X,&xv);CHKERRQ(ierr);
  for (i = 0; i < 4; i++) grad[i] = 2*xv[i];
  ierr = VecRestoreArrayRead(X,&xv);CHKERRQ(ierr);

  ierr = MatSetValues(JE,1,&row,4,cols,grad,INSERT_VALUES);CHKERRQ(ierr);
  ierr = MatAssemblyBegin(JE,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  ierr = MatAssemblyEnd(JE,MAT_FINAL_ASSEMBLY);CHKERRQ(ierr);
  PetscFunctionReturn(0);
}


/*TEST

   build:
      requires: !complex !define(PETSC_USE_CXX)

   test:
      requires: superlu
      args: -tao_smonitor -tao_view -tao_gatol 1.e-5

TEST*/
