

from sqlalchemy import *


# Connect to 'sqlite://' (a SQLite URL with no file path, i.e. an in-memory
# database).  The tables below are declared without an explicit engine or
# metadata argument, so they presumably rely on this global connection.
global_connect('sqlite://')


# Base table: one row per element, holding the primary key and the name.
# The 'details', 'assemblies' and 'specification' tables all reference
# elements.element_id.  The table is created as part of the same expression;
# the name is bound to whatever .create() returns (assumed to be the Table).
element_table = Table('elements',
    Column('element_id', Integer, primary_key=True),
    Column('name', String(128))).create()

# Extra columns for elements that are details.  element_id is both this
# table's primary key and a foreign key to elements.element_id, so each row
# here pairs with exactly one row in 'elements'.
detail_table = Table('details',
    Column('element_id', Integer, ForeignKey('elements.element_id'),
        primary_key=True),
    Column('material', String(128), default=''),
    Column('weight', Float, default=0)).create()

# Marker table for elements that are assemblies.  It has no columns of its
# own beyond element_id, which is the primary key and a foreign key to
# elements.element_id.
assembly_table = Table('assemblies',
    Column('element_id', Integer, ForeignKey('elements.element_id'),
        primary_key=True)).create()

# Specification lines: each row links a master element to a slave element
# with a quantity.  Both master_id and slave_id reference
# elements.element_id; master_id is required, slave_id may be NULL.
# quantity has a column default of 1.
specification_table = Table('specification',
    Column('spec_line_id', Integer, primary_key=True),
    Column('master_id', Integer, ForeignKey("elements.element_id"),
        nullable=False),
    Column('slave_id', Integer, ForeignKey("elements.element_id"),
        nullable=True),
    Column('quantity', Float, default=1)).create()




class Element(object):
    """Base class for elements; instances are expected to carry a ``name``."""

    def __repr__(self):
        """Return ``<ClassName name>``, using the instance's concrete class."""
        class_name = type(self).__name__
        return '<%s %s>' % (class_name, self.name)

class Detail(Element):
    """Element subclass mapped to the 'details' table (material, weight)."""

class Assembly(Element):
    """Element subclass mapped to the 'assemblies' table."""

class SpecLine(object):
    """One specification line: a quantity of a ``slave`` element."""

    def __repr__(self):
        """Return ``<SpecLine quantity slave-name>``.

        ``quantity`` is shown with one decimal place.  When it is ``None``
        (presumably possible on an object whose column default has not been
        applied yet -- confirm against the mapper), the text ``None`` is
        shown instead of letting ``'%.01f' % None`` raise ``TypeError``.
        A missing slave, or a slave without a ``name``, is shown as ``None``.
        """
        quantity = self.quantity
        if quantity is None:
            quantity_text = 'None'
        else:
            quantity_text = '%.01f' % quantity
        slave_name = getattr(self.slave, 'name', None)
        return '<%s %s %s>' % (self.__class__.__name__,
            quantity_text, slave_name)




# Map the base class onto the 'elements' table, explicitly exposing the
# ``name`` column as a property.
assign_mapper(
    Element,
    element_table,
    properties={'name': element_table.c.name})

# Each subclass is mapped onto its own table and inherits from the base
# Element mapper.
assign_mapper(Assembly, assembly_table,
    inherits=Element.mapper)

assign_mapper(Detail, detail_table,
    inherits=Element.mapper)





# Select every detail: all 'elements' columns plus the detail-only columns,
# tagged with the literal string 'detail' in a column labeled ``type``.
detail_join = select([element_table,
    detail_table.c.material,
    detail_table.c.weight,
    column("'detail'").label('type')],
    element_table.c.element_id==detail_table.c.element_id)

# Select every assembly with the same column list as detail_join: material
# and weight are filled with NULL so both selects line up for the union, and
# the ``type`` column carries the literal string 'assembly'.
assembly_join = select([element_table,
    null().label('material'),
    null().label('weight'),
    column("'assembly'").label('type')],
    element_table.c.element_id==assembly_table.c.element_id)

# UNION ALL of both selects, aliased as 'pjoin'.  Extension.create_instance
# reads the type tag from result rows under the key 'pjoin_type'.
element_type_join = detail_join.union_all(assembly_join).alias('pjoin')






class Extension(MapperExtension):
    """Mapper extension that picks the class to instantiate for each row."""

    def create_instance(self, mapper, row, imap, class_):
        """Return a new, empty instance chosen by the row's ``pjoin_type``.

        ``'detail'`` yields a ``Detail``, ``'assembly'`` an ``Assembly``, and
        any other value a plain ``Element``.  The ``mapper``, ``imap`` and
        ``class_`` arguments are not used.
        """
        class_by_tag = {
            'detail': Detail,
            'assembly': Assembly,
        }
        chosen = class_by_tag.get(row['pjoin_type'], Element)
        return chosen()

# Second mapper for Element, selecting from the 'pjoin' union alias rather
# than from the 'elements' table.  The Extension instance decides which
# class each loaded row becomes.  Used as the target of SpecLine's
# master/slave relations.
element_type_mapper = mapper(Element, element_type_join, extension=Extension())


# Map SpecLine onto the 'specification' table.  Both element references go
# through element_type_mapper and are scalar (uselist=False).
assign_mapper(
    SpecLine,
    specification_table,
    properties={
        # Owning element, loaded lazily.
        'master': relation(
            element_type_mapper,
            primaryjoin=(specification_table.c.master_id ==
                element_type_mapper.c.element_id),
            foreignkey=specification_table.c.master_id,
            lazy=True,
            uselist=False),
        # Referenced element, loaded eagerly (lazy=False).
        'slave': relation(
            element_type_mapper,
            primaryjoin=(specification_table.c.slave_id ==
                element_type_mapper.c.element_id),
            foreignkey=specification_table.c.slave_id,
            lazy=False,
            uselist=False),
        'quantity': specification_table.c.quantity})

# Give every Element a lazily loaded ``specification`` collection of SpecLine
# objects, joined on specification.master_id.  The backref is named
# 'master', the same name as the relation already declared on the SpecLine
# mapping.
# NOTE(review): private=True presumably makes the lines owned by their
# element (deleted along with it) -- confirm against the library version.
# NOTE(review): the primaryjoin compares against the 'pjoin' alias column
# (element_type_mapper.c.element_id) although this property is added to
# Element.mapper, which is mapped to element_table -- confirm this is
# intended.
Element.mapper.add_property('specification',
    relation(SpecLine,
        primaryjoin=specification_table.c.master_id==element_type_mapper.c.element_id,
        foreignkey=specification_table.c.master_id,
        lazy=True, private=True, backref='master'))


# Build one assembly whose specification contains a single detail, then
# flush everything to the database.
a1 = Assembly(name='a1')
# ``specification`` is a collection of SpecLine objects (see the relation on
# Element.mapper), so the Detail is wrapped in a SpecLine as its ``slave``
# rather than appended directly.
a1.specification.append(SpecLine(slave=Detail(name='d1')))
objectstore.commit()
