
#include <petscdmplex.h>
#include <mpi.h>
#include <algorithm>
#include <cassert>
#include <iostream>
#include <sstream>
#include <string>
#include <utility>
#include <vector>

/**
 * Print the local Hasse diagram of @p dm: for every cell (height 0) one line
 * "cell --> conePoint", and for every point of that cone one line
 * "conePoint --> subConePoint (c0, c1, ...)" with the coordinates stored for
 * the sub-cone point. For the interpolated 2D mesh built in this file these are
 * cell --> edge and edge --> vertex relations.
 *
 * The lines are sorted lexicographically (as strings) and emitted with
 * PetscSynchronizedPrintf on PETSC_COMM_WORLD. Nothing appears until the caller
 * runs PetscSynchronizedFlush on that communicator.
 *
 * @param dm   the DMPlex to inspect; only locally stored points are visited
 * @param rank label printed in the header line (expected to be the MPI rank)
 */
void debugDM(DM dm, int rank){
    // Local coordinates and their layout. The coordinate Vec may be NULL when no
    // coordinates were set; in that case the topology is printed without them.
    Vec coordinates = nullptr;
    PetscSection cs = nullptr;
    DMGetCoordinatesLocal(dm, &coordinates);
    DMGetCoordinateSection(dm, &cs);

    // Read-only access: this routine never modifies the coordinates.
    const PetscScalar *coords = nullptr;
    if (coordinates) {
        VecGetArrayRead(coordinates, &coords);
    }

    std::vector<std::string> res;
    PetscInt cStart = 0, cEnd = 0;
    DMPlexGetHeightStratum(dm, 0, &cStart, &cEnd);
    for (PetscInt cell = cStart; cell < cEnd; ++cell) {
        const PetscInt *edges = nullptr;
        PetscInt        numEdges = 0;
        DMPlexGetConeSize(dm, cell, &numEdges);
        DMPlexGetCone(dm, cell, &edges);
        for (PetscInt e = 0; e < numEdges; ++e) {
            const PetscInt edge = edges[e];
            res.push_back(std::to_string(cell) + " --> " + std::to_string(edge));

            const PetscInt *vertices = nullptr;
            PetscInt        numVertices = 0;
            DMPlexGetConeSize(dm, edge, &numVertices);
            DMPlexGetCone(dm, edge, &vertices);
            for (PetscInt v = 0; v < numVertices; ++v) {
                const PetscInt vertex = vertices[v];
                std::string line = std::to_string(edge) + " --> " + std::to_string(vertex);
                if (coords && cs) {
                    PetscInt dof = 0, off = 0;
                    PetscSectionGetDof(cs, vertex, &dof);
                    PetscSectionGetOffset(cs, vertex, &off);
                    // Print every component the section stores for this point
                    // (rather than assuming 2D). PetscRealPart keeps this
                    // compiling in complex-scalar PETSc builds.
                    line += " (";
                    for (PetscInt j = 0; j < dof; ++j) {
                        if (j > 0) {
                            line += ", ";
                        }
                        line += std::to_string(PetscRealPart(coords[off + j]));
                    }
                    line += ")";
                }
                res.push_back(std::move(line));
            }
        }
    }
    if (coordinates) {
        VecRestoreArrayRead(coordinates, &coords);
    }

    // sort and print
    std::sort(res.begin(), res.end());
    std::ostringstream ss;
    for (const auto &s : res) {
        ss << s << '\n';
    }
    const std::string fullStr = ss.str();
    PetscSynchronizedPrintf(PETSC_COMM_WORLD,"Rank %d hasse diagram:\n%s\n", rank, fullStr.c_str());
}

void twoTrianglesQuestion(){
    int rank;
    MPI_Comm_rank (PETSC_COMM_WORLD, &rank);	// get current process id


    // from tutorial ( http://www.mcs.anl.gov/petsc/documentation/tutorials/ParisTutorial.pdf )
    //
    //         8
    //       / | \
    //      2  |  5
    //     /   |   \
    //    7  0 4 1  10
    //     \   |   /
    //      3  |  6
    //       \ | /
    //         9
    PetscInt depth= 2;
    PetscInt numPoints[3] =         {4, 5, 2};
    PetscInt coneSize[11] =         {3,3,2,2,2,2,2,0,0,0,0};
    PetscInt cones[16] =            {2,3,4,  5,4,6,  8,7,  7,9,  9,8,  10,8,  9,10};
    PetscInt coneOrientations[16] = {0,0,0,  0,-2,0,  0,0,  0,0,  0,0,  0,0,  0,0};
    PetscScalar vertexCoords[] =    {-1,0,  0,1,  0,-1,  1,0};

    DM dm;
    DMPlexCreate(PETSC_COMM_WORLD, &dm);
    DMSetType(dm, DMPLEX);

    DMSetDimension(dm, 2);
    DMPlexCreateFromDAG(dm, depth, numPoints, coneSize, cones, coneOrientations, vertexCoords);

    PetscSynchronizedPrintf(PETSC_COMM_WORLD,"Non distributed\n");
    debugDM(dm,rank);

    DM dmDist = nullptr;
    int overlap = 0;
    std::string err = "";
    DMPlexDistribute(dm, overlap, nullptr, &dmDist);
    if (dmDist) {
        DMDestroy(&dm);
        dm = dmDist;
    } else {
        err = "Error";
    }
    PetscSynchronizedPrintf(PETSC_COMM_WORLD,"%s\n",err.c_str());

    PetscSynchronizedPrintf(PETSC_COMM_WORLD,"Distributed\n");
    debugDM(dm,rank);
}

/**
 * Initialize PETSc, run the two-triangle distribution demo, flush the
 * synchronized output to PETSC_STDOUT, and finalize PETSc.
 *
 * @return 0 on success, otherwise the PETSc error code of the failing call.
 */
int main(int argc, char **argv) {
    PetscErrorCode ierr;
    static char help[] = "Please";
    ierr = PetscInitialize(&argc, &argv, nullptr, help);
    // CHKERRQ relies on PETSc's error handling, which cannot be trusted if
    // initialization itself failed, so return the code directly.
    if (ierr) return ierr;

    twoTrianglesQuestion();
    ierr = PetscSynchronizedFlush(PETSC_COMM_WORLD,PETSC_STDOUT); CHKERRQ(ierr);

    // Propagate finalization errors to the exit status instead of returning 0.
    ierr = PetscFinalize();
    return ierr;
}