from contextlib import contextmanager
from sqlalchemy import create_engine, Column, Integer, Table, ForeignKey
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import sessionmaker, relationship
from sqlalchemy.sql import select, func
from sqlalchemy.sql.ddl import DDLElement



class CreateView(DDLElement):
    """DDL element describing a view to create.

    :param name: name of the view.
    :param selectable: the SELECT construct that defines the view's rows.
    """

    def __init__(self, name, selectable):
        self.name, self.selectable = name, selectable


class DropView(DDLElement):
    """DDL element describing a view to drop.

    :param name: name of the view.
    :param if_exists: when true, the statement is rendered with
        ``IF EXISTS`` so dropping a missing view is not an error.
    """

    def __init__(self, name, if_exists=False):
        self.name = name
        self.if_exists = if_exists


# Without a registered compiler hook a DropView cannot be turned into SQL,
# which makes the element unusable; this hook provides the rendering.
@compiles(DropView)
def visit_drop_view(element, compiler, **kw):
    """Render a DropView element as ``DROP VIEW [IF EXISTS] <name>``.

    :param element: the DropView being compiled; ``name`` and ``if_exists``
        are read from it.
    :param compiler: the DDL compiler invoking this hook (unused).
    :returns: the SQL statement as a string.
    """
    if element.if_exists:
        return 'DROP VIEW IF EXISTS {}'.format(element.name)
    return 'DROP VIEW {}'.format(element.name)


@compiles(CreateView)
def visit_create_view(element, compiler, **kw):
    """Render a CreateView element as ``CREATE VIEW <name> AS <select>``.

    :param element: the CreateView being compiled; ``name`` and
        ``selectable`` are read from it.
    :param compiler: the DDL compiler; its ``sql_compiler`` renders the
        SELECT statement.
    :returns: the SQL statement as a string.
    """
    # DDL is executed without bound parameters, so any values in the SELECT
    # (e.g. from a WHERE clause) must be rendered inline rather than left as
    # placeholders.
    select_sql = compiler.sql_compiler.process(element.selectable,
                                               literal_binds=True)
    return 'CREATE VIEW {} AS {}'.format(element.name, select_sql)



# Database setup

# SQLite engine created from the bare 'sqlite://' URL; add echo=True to the
# create_engine() call to log the emitted SQL.
engine = create_engine('sqlite://')
Session = sessionmaker(bind=engine)
Base = declarative_base()

@contextmanager
def session_scope():
    """Yield a new Session wrapped in a transactional scope.

    The session is committed when the ``with`` body finishes normally.  If
    the body (or the commit itself) raises, the session is rolled back and
    the exception is re-raised.  The session is closed in every case.
    """
    db_session = Session()
    try:
        yield db_session
        db_session.commit()
    except BaseException:
        # Undo pending work, then let the caller see the original error.
        db_session.rollback()
        raise
    finally:
        db_session.close()


# Bind the metadata to the engine so that later calls made without an
# explicit engine (create_all() and the autoload=True Table) can use it.
Base.metadata.bind = engine


# Define and create database tables
class Group(Base):
    """A group that owns a collection of ItemCount rows."""
    __tablename__ = 'groups'

    id = Column(Integer, primary_key=True)
    # Collection of ItemCount objects referencing this group; the backref
    # adds a ``group`` attribute on ItemCount pointing back here.
    items = relationship('ItemCount', backref='group')


class ItemCount(Base):
    """A count stored for one (item_id, group_id, value) combination."""
    __tablename__ = 'item_counts'

    # item_id, group_id and value together form the composite primary key.
    item_id = Column(Integer, primary_key=True)
    group_id = Column(Integer, ForeignKey(Group.id), primary_key=True)
    value = Column(Integer, primary_key=True)
    # The only non-key column: the count recorded for this key.
    count = Column(Integer)


# Create every table registered on Base.metadata; no engine is passed, so
# this relies on the engine bound to the metadata.
Base.metadata.create_all()


# Fill database with data
with session_scope() as session:
    # First group: items 1 and 2, each with a count for values 1, 2 and 3.
    group_a = Group()
    session.add(group_a)
    group_a.items.extend(
        ItemCount(item_id=item_id, value=value, count=count)
        for item_id, value, count in (
            (1, 1, 5), (1, 2, 2), (1, 3, 3),
            (2, 1, 1), (2, 2, 10), (2, 3, 100),
        )
    )

    # Second group: items 3 and 4, same three values.
    group_b = Group()
    session.add(group_b)
    group_b.items.extend(
        ItemCount(item_id=item_id, value=value, count=count)
        for item_id, value, count in (
            (3, 1, 50), (3, 2, 20), (3, 3, 30),
            (4, 1, 2), (4, 2, 20), (4, 3, 200),
        )
    )


# with session_scope() as session:
#     for row in session.query(ItemCount):
#         print(row.group_id, row.value, row.count)


# Create a view that sums counts for a particular group and value
view_name = 'group_counts'

# SELECT producing one row per (group_id, value) with the counts summed.
view_select = select([
    ItemCount.group_id.label('group_id'),
    ItemCount.value.label('value'),
    func.sum(ItemCount.count).label('count'),
])
view_select = view_select.select_from(ItemCount.__table__)
view_select = view_select.group_by(ItemCount.group_id, ItemCount.value)


# Emit the CREATE VIEW statement on the engine.
CreateView(view_name, view_select).execute(bind=engine)


# Describe the view as a Table so it can be mapped.  The view yields one row
# per (group_id, value) pair, so both columns are declared as the primary
# key: the ORM identifies objects by primary key, and with group_id alone
# every row of a group collapses into a single GroupCount object.  The
# remaining column is picked up by reflection (autoload=True).
view_table = Table(view_name, Base.metadata,
                   Column('group_id', ForeignKey(Group.id), primary_key=True),
                   Column('value', Integer, primary_key=True),
                   autoload=True)


class GroupCount(Base):
    """ORM class mapped onto the view described by ``view_table``."""
    __table__ = view_table


# Both queries below return the same three rows for group 1 (one per value)
with session_scope() as session:
    print('Raw select statement:')
    for group_count in session.execute('SELECT * FROM group_counts '
                                       'WHERE group_id = 1'):
        print(group_count)
    # Blank separator line; print('') behaves the same on Python 2 and 3,
    # whereas a bare ``print`` is a no-op expression on Python 3.
    print('')

    print('Using the GroupCount class:')
    for group_count in session.query(GroupCount).filter_by(group_id=1):
        print(group_count.group_id, group_count.value, group_count.count)


# Group.counts = relationship(GroupCount, backref='group',
#                             cascade='all, delete, delete-orphan',
#                             order_by=GroupCount.value)
#
#
# with session_scope() as session:
#     for group in session.query(Group):
#         print(group.counts)
