#include <petsc.h>
/* Help message passed to PetscInitialize() in main. */
static char help[] = "dmplex";

/*
 * Print every cell (height-0 point) of `dm` that appears as a leaf of the
 * DM's point SF -- the "ghost cells" of this rank -- and return in
 * *cEndInterior the smallest such cell number, or cEnd if there are none.
 *
 * NOTE(review): treating the minimum leaf cell as the end of the interior
 * cells assumes all SF-leaf cells are numbered after the locally owned ones;
 * confirm this holds for the distribution/overlap settings in use.
 */
static PetscErrorCode GetLeafCells(DM dm, PetscInt *cEndInterior)
{
    PetscSF         sf;
    const PetscInt *leaves;
    PetscInt        Nl, cStart, cEnd;
    PetscMPIInt     rank;

    PetscFunctionBeginUser;
    /* Use the DM's own communicator rather than assuming PETSC_COMM_WORLD. */
    PetscCallMPI(MPI_Comm_rank(PetscObjectComm((PetscObject)dm), &rank));
    PetscCall(DMPlexGetHeightStratum(dm, 0, &cStart, &cEnd));
    *cEndInterior = cEnd;
    PetscCall(DMGetPointSF(dm, &sf));
    PetscCall(PetscSFGetGraph(sf, NULL, &Nl, &leaves, NULL));
    for (PetscInt l = 0; l < Nl; ++l)
    {
        /* A NULL leaf array means the leaves are stored contiguously as 0..Nl-1. */
        const PetscInt leaf = leaves ? leaves[l] : l;

        if (leaf >= cStart && leaf < cEnd)
        {
            *cEndInterior = PetscMin(leaf, *cEndInterior);
            PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] ghost cell %" PetscInt_FMT "\n", rank, leaf));
        }
    }
    PetscFunctionReturn(0);
}

int main(int argc, char **argv)
{
    Mat J;
    DM dm, dm_dist;
    PetscSection section;
    PetscInt cStart, cEndInterior, cEnd, rank, m, n;
    PetscInt nField = 1, nDof = 3, field = 0;

    PetscCall(PetscInitialize(&argc, &argv, NULL, help));
    PetscCallMPI(MPI_Comm_rank(PETSC_COMM_WORLD, &rank));

    PetscCall(DMCreate(PETSC_COMM_WORLD, &dm));
    PetscCall(DMSetType(dm, DMPLEX));
    PetscCall(DMPlexDistributeSetDefault(dm, PETSC_FALSE));
    PetscCall(DMSetFromOptions(dm));

    PetscCall(GetLeafCells(dm, &cEndInterior));

    PetscCall(DMPlexGetHeightStratum(dm, 0, &cStart, &cEnd));
    PetscCall(PetscPrintf(PETSC_COMM_SELF, "Before Distribution Rank: %d, cStart: %d, cEndInterior: %d, cEnd: %d\n",
                          rank, cStart, cEndInterior, cEnd));

    PetscCall(DMPlexDistribute(dm, 1, NULL, &dm_dist));
    if (dm_dist)
    {
        PetscCall(DMDestroy(&dm));
        dm = dm_dist;
    }

    PetscCall(DMPlexGetHeightStratum(dm, 0, &cStart, &cEnd));
    PetscCall(PetscSectionCreate(PETSC_COMM_WORLD, &section));
    PetscCall(PetscSectionSetNumFields(section, nField));
    PetscCall(PetscSectionSetChart(section, cStart, cEnd));
    for (PetscInt p = cStart; p < cEnd; p++) {
        PetscCall(PetscSectionSetFieldDof(section, p, field, nDof));
        PetscCall(PetscSectionSetDof(section, p, nDof));
    }
    PetscCall(PetscSectionSetUp(section));
    PetscCall(DMSetLocalSection(dm, section));
    PetscCall(PetscSectionView(section, NULL));
    PetscCall(PetscSectionDestroy(&section));
    PetscCall(DMPlexGetHeightStratum(dm, 0, &cStart, &cEnd));
    PetscCall(GetLeafCells(dm, &cEndInterior));
    PetscCall(PetscPrintf(PETSC_COMM_SELF, "After Distribution Rank: %d, cStart: %d, cEndInterior: %d, cEnd: %d\n",
                          rank, cStart, cEndInterior, cEnd));

    PetscCall(DMSetAdjacency(dm, field, PETSC_TRUE, PETSC_FALSE));
    PetscCall(DMCreateMatrix(dm, &J));
    PetscCall(MatGetLocalSize(J, &m, &n));
    PetscCall(PetscPrintf(PETSC_COMM_SELF, "[%d] m: %d n: %d\n", rank, m, n));
    PetscCall(DMDestroy(&dm));
    PetscCall(MatDestroy(&J));
    PetscCall(PetscFinalize());
    return 0;
}

/*TEST

  test:
    args: -dm_plex_dim 3 -dm_plex_simplex 0 -dm_plex_box_faces 3,3,3

TEST*/
