import junit.framework.TestCase;

// import org.apache.log4j.Logger;

import org.apache.ojb.broker.PBKey;
import org.apache.ojb.broker.PersistenceBroker;
import org.apache.ojb.broker.PersistenceBrokerFactory;
import org.apache.ojb.broker.metadata.*;

import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.ResultSet;
import java.sql.SQLException;

import java.util.*;


public abstract class TestOJBSchemaSanity extends TestCase
{
    // private static final Logger LOG = Logger.getLogger(TestOJBSchemaSanity.class);

    /**
     * <String (table name)> : <Table>
     */
    private Map tables;

    /**
     * <String (Class name)> : <ClassDescriptor>
     */
    private Map descriptors;
    private DatabaseMetaData meta;

    /**
     * Subclasses must provide a PBKey whose mappins will be used to verify appropriate sanity =)
     */
    public abstract PBKey getPBKey();

    public void testSanity() throws Exception
    {
        final PBKey key = getPBKey();
        final PersistenceBroker broker = PersistenceBrokerFactory.createPersistenceBroker(key);
        final Connection conn = broker.serviceConnectionManager().getConnection();

        meta = conn.getMetaData();
        tables = buildTableMetadata(meta);

        final MetadataManager mm = MetadataManager.getInstance();
        final DescriptorRepository dr = mm.getRepository();
        descriptors = dr.getDescriptorTable();
        for (Iterator iterator = descriptors.entrySet().iterator(); iterator.hasNext();)
        {
            final Map.Entry entry = (Map.Entry)iterator.next();
            final String class_name = (String)entry.getKey();
            final ClassDescriptor cd = (ClassDescriptor)entry.getValue();
            if (class_name.startsWith("org.apache.ojb") || (cd.getFullTableName() == null))
            {
                // Skip ojb internal mappings and unmapped model instances (extent only)
                continue;
            }

            final String table_name = cd.getFullTableName().toLowerCase();

            verifyTableExists(cd, tables, table_name);
            verifyColumnsExist(cd, (Table)tables.get(table_name));
            verifyObjectReferenceForeignKeys(cd);
            verifyCollectionForeignKeys(cd);
        }

        broker.close();
    }

    private void verifyCollectionForeignKeys(final ClassDescriptor classDescriptor)
        throws SQLException
    {
        final List collection_descriptors = classDescriptor.getCollectionDescriptors();
        for (int i = 0; i < collection_descriptors.size(); i++)
        {
            final CollectionDescriptor col_desc =
                (CollectionDescriptor)collection_descriptors.get(i);
            if (col_desc.isMtoNRelation())
            {
                checkMtoNRelation(classDescriptor, col_desc);
            }
            else
            {
                checkOneToManyRelation(classDescriptor, col_desc);
            }
        }
    }

    private void checkMtoNRelation(final ClassDescriptor refering_cd,
        final CollectionDescriptor col_desc) throws SQLException
    {
        final String indirection_table = col_desc.getIndirectionTable();
        final ResultSet rs =
            meta.getTables(null, null, indirection_table.toLowerCase(), new String[] { "TABLE" });
        assertTrue("Indirection table [" + indirection_table + "] does not exist and "
            + "is referenced in M:N relation from [" + refering_cd.getClassNameOfObject() + "] "
            + "in the attribute [" + col_desc.getAttributeName() + "]", rs.next()); // only one row
        assertFalse(rs.next());

        final String[] fks = col_desc.getFksToThisClass();
        for (int i = 0; i < fks.length; i++)
        {
            final String fk = fks[i];
            final ResultSet column_rs =
                meta.getColumns(null, null, indirection_table.toLowerCase(), fk.toLowerCase());
            assertTrue("Foriegn key column [" + fk + "] in table [" + indirection_table + "] "
                + "referenced within class [" + refering_cd.getClassNameOfObject() + "] "
                + "in attribute [" + col_desc.getAttributeName() + "] does not exist!",
                column_rs.next());
            assertFalse(column_rs.next());
        }
    }

    private void checkOneToManyRelation(final ClassDescriptor refering_cd,
        final CollectionDescriptor col_desc)
    {
        // FK's are on other side from refering_cd
        final String item_class_name = col_desc.getItemClassName();
        final ClassDescriptor item_cd = (ClassDescriptor)descriptors.get(item_class_name);
        final List fk_field_names = col_desc.getForeignKeyFields();
        for (int i = 0; i < fk_field_names.size(); i++)
        {
            final String fk_field_name = (String)fk_field_names.get(i);
            checkFieldExistsIn(refering_cd.getClassNameOfObject(), item_cd, fk_field_name);
        }
    }

    public void checkFieldExistsIn(final String parent_name, final ClassDescriptor this_cd,
        final String field_name)
    {
        final boolean is_mapped = this_cd.getFullTableName() != null;
        if (is_mapped)
        {
            final FieldDescriptor fd = this_cd.getFieldDescriptorByName(field_name);
            assertNotNull("Field referenced as fk from  [" + parent_name + "] " + "in field ["
                + field_name + "] " + "does not exist in [" + this_cd.getClassNameOfObject() + "]",
                fd);

            final String column_name = fd.getColumnName().toLowerCase();
            final Table table = (Table)tables.get(this_cd.getFullTableName().toLowerCase());
            assertTrue("Field referenced as fk from  [" + parent_name + "] " + "in field ["
                + field_name + "] " + "does not actually map to column [" + column_name + "]",
                table.containsColumn(column_name));
        }
        if (this_cd.isExtent())
        {
            final List extent_class_names = this_cd.getExtentClassNames();
            for (int i = 0; i < extent_class_names.size(); i++)
            {
                final String extent_class_name = (String)extent_class_names.get(i);
                final ClassDescriptor extent_cd =
                    (ClassDescriptor)descriptors.get(extent_class_name);
                checkFieldExistsIn(parent_name, extent_cd, field_name);
            }
        }
    }

    private void verifyObjectReferenceForeignKeys(final ClassDescriptor cd)
        throws NoSuchFieldException
    {
        final List reference_descriptors = cd.getObjectReferenceDescriptors();
        for (int i = 0; i < reference_descriptors.size(); i++)
        {
            final ObjectReferenceDescriptor rd =
                (ObjectReferenceDescriptor)reference_descriptors.get(i);
            final String item_class_name = rd.getItemClassName();
            final ClassDescriptor item_descriptor =
                (ClassDescriptor)descriptors.get(item_class_name);
            checkTargetObjectTypeInReferenceIsMapped(item_descriptor, rd);
        }
    }

    private void checkTargetObjectTypeInReferenceIsMapped(final ClassDescriptor cd,
        final ObjectReferenceDescriptor rd) throws NoSuchFieldException
    {
        final boolean is_mapped = cd.getFullTableName() != null;

        if (is_mapped)
        {
            final List field_names_in_refering_class = rd.getForeignKeyFields();
            for (int i = 0; i < field_names_in_refering_class.size(); i++)
            {
                final String item_class_name = rd.getItemClassName();
                final ClassDescriptor item_class_descriptor =
                    (ClassDescriptor)descriptors.get(item_class_name);
                assertNotNull(item_class_descriptor);
            }
        }
        if (cd.isExtent())
        {
            final List extents_class_names = cd.getExtentClassNames();
            for (int i = 0; i < extents_class_names.size(); i++)
            {
                final String class_name = (String)extents_class_names.get(i);
                checkTargetObjectTypeInReferenceIsMapped((ClassDescriptor)descriptors.get(
                        class_name), rd);
            }
        }
    }

    private void verifyColumnsExist(final ClassDescriptor cd, final Table table)
    {
        final FieldDescriptor[] field_descriptors = cd.getFieldDescriptions();
        for (int i = 0; i < field_descriptors.length; i++)
        {
            final FieldDescriptor fd = field_descriptors[i];
            assertTrue("Class [" + cd.getClassNameOfObject() + "] Field [" + fd.getAttributeName()
                + "]" + " references column [" + fd.getColumnName() + "] which doesn't exist!",
                table.containsColumn(fd.getColumnName()));
        }
    }

    private void verifyTableExists(final ClassDescriptor cd, final Map tables, final String table)
    {
        if (! (tables.containsKey(table.toLowerCase())))
        {
            fail("Table [" + table + "] doesn't exist and is used by [" + cd.getClassNameOfObject()
                + "]");
        }
    }

    public Map buildTableMetadata(final DatabaseMetaData meta)
        throws SQLException
    {
        final HashMap tables = new HashMap();
        final ResultSet rs = meta.getTables(null, null, "%", new String[] { "TABLE" });

        while (rs.next())
        {
            final String table_name = rs.getString("TABLE_NAME");
            final Table table = new Table(table_name.toLowerCase());

            // if (LOG.isDebugEnabled())
            //{
            //    LOG.debug("Working on table: " + table_name);
            //}

            final ResultSet column_rs = meta.getColumns(null, null, table_name, "%");
            while (column_rs.next())
            {
                final String column_name = column_rs.getString("COLUMN_NAME");
                final String column_type = column_rs.getString("DATA_TYPE");
                final String column_size = column_rs.getString("COLUMN_SIZE");

                /*if (LOG.isDebugEnabled()) {
                    LOG.debug("column_name: " + column_name);
                    LOG.debug("column_type: " + column_type);
                    LOG.debug("column_size: " + column_size);
                }*/
                final Column column =
                    new Column(column_name.toLowerCase(), column_type.toLowerCase(), column_size);
                table.getColumns().add(column);
            }
            tables.put(table_name, table);
        }

        return tables;
    }

    private static class Table
    {
        String name;
        private Set columns = new HashSet();

        Table(final String name)
        {
            this.name = name;
        }

        public boolean containsColumn(final String name)
        {
            for (Iterator iterator = columns.iterator(); iterator.hasNext();)
            {
                final Column column = (Column)iterator.next();
                if (name.toLowerCase().equals(column.getName()))
                {
                    return true;
                }
            }

            return false;
        }

        public String getName()
        {
            return name;
        }

        public Set getColumns()
        {
            return columns;
        }
    }

    private static class Column
    {
        String type;
        String size;
        String name;

        Column(final String name, final String type, final String size)
        {
            this.name = name.toLowerCase();
            this.type = type;
            this.size = size;
        }

        public String getName()
        {
            return name;
        }

        public String getType()
        {
            return type;
        }

        public String getSize()
        {
            return size;
        }

        public boolean equals(final Object o)
        {
            if (this == o)
            {
                return true;
            }
            if (! (o instanceof Column))
            {
                return false;
            }

            final Column column = (Column)o;

            if ((name != null) ? (! name.equals(column.name)) : (column.name != null))
            {
                return false;
            }

            return true;
        }

        public int hashCode()
        {
            return ((name != null) ? name.hashCode() : 0);
        }
    }
}
