190 lines
5.8 KiB
Python
190 lines
5.8 KiB
Python
import sqlalchemy as sa
|
|
from sqlalchemy import event
|
|
from sqlalchemy.ext.declarative import declarative_base, declared_attr
|
|
from sqlalchemy.ext.associationproxy import association_proxy
|
|
|
|
import itertools
|
|
import pprint
|
|
|
|
Base = declarative_base()
|
|
|
|
|
|
class Tag(Base):
    """A text label stored in the ``kk_tags`` table.

    ``text`` is declared unique.  ``collection`` relates a tag to its
    ``TagAssociation`` rows (joined on ``TagAssociation.tag_id``) and is
    materialised through ``TaggedObjectCollection``.
    """

    class TaggedObjectCollection(object):
        """Collection class used for ``Tag.collection``.

        Declares ``__emulates__ = list``.  ``append`` does not store the
        association it is given; it stores a new instance of the class
        whose ``__tablename__`` equals ``item.target_table``, with
        ``db_id`` set to ``item.target_id``.
        """

        __emulates__ = list

        def __init__(self):
            self.data = []

        def append(self, item):
            """Store a stand-in object for the row ``item`` points at.

            :param item: Object exposing ``target_table`` and ``target_id``.
            :raises TypeError: If no class is mapped to ``item.target_table``.
            """
            cls = self._get_class_by_tablename(item.target_table)
            if cls is None:
                # Previously this surfaced as an opaque "'NoneType' object
                # is not callable"; the exception type is unchanged.
                raise TypeError(
                    "no mapped class found for table %r"
                    % (item.target_table,))
            thingy = cls()
            thingy.db_id = item.target_id
            self.data.append(thingy)

        def remove(self, item):
            """Currently a no-op: nothing is removed from ``data``."""
            pass

        def extend(self, item):
            """Currently a no-op: nothing is added to ``data``."""
            pass

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

        def __eq__(self, other):
            """Compare stored items with another collection or a list.

            The collection has no ``text`` attribute, so equality is
            defined on ``data``; unrelated types yield ``NotImplemented``.
            """
            if isinstance(other, type(self)):
                return self.data == other.data
            if isinstance(other, list):
                return self.data == other
            return NotImplemented

        def _get_class_by_tablename(self, tablename):
            """Return class reference mapped to table.

            :param tablename: String with name of table.
            :return: Class reference or None.
            """
            for c in Base._decl_class_registry.values():
                if (hasattr(c, '__tablename__')
                        and c.__tablename__ == tablename):
                    return c
            return None

    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255, convert_unicode=False),
                     unique=True)
    # this & collection should be read-only?
    # associations = sa.orm.relationship("TagAssociation",
    #                                    back_populates="tag_obj")
    collection = sa.orm.relationship(
        "TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=False,
        # back_populates="tag_obj",
        primaryjoin="TagAssociation.tag_id==Tag.db_id"
    )
    __tablename__ = "kk_tags"
    __mapper_args__ = {
        "primary_key": [db_id],
        "batch": False
    }

    def __init__(self, tag_text=None):
        """Create a tag.

        :param tag_text: Value assigned to ``text``.
        """
        super(Tag, self).__init__()
        self.text = tag_text

    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)

    # i'm not sure this is really what we want.
    # NOTE: under Python 3, defining __eq__ without __hash__ makes
    # instances unhashable.
    def __eq__(self, other):
        """Tags are equal when their ``text`` values are equal.

        Non-Tag operands yield ``NotImplemented`` instead of raising
        ``AttributeError`` on the missing ``text`` attribute.
        """
        if not isinstance(other, Tag):
            return NotImplemented
        return self.text == other.text

    def _get_class_by_tablename(self, tablename):
        """Return class reference mapped to table.

        :param tablename: String with name of table.
        :return: Class reference or None.
        """
        for c in Base._decl_class_registry.values():
            if (hasattr(c, '__tablename__')
                    and c.__tablename__ == tablename):
                return c
        return None
|
|
|
|
|
|
|
|
class TagAssociation(Base):
    """Row of ``kk_tag_associations`` linking a ``Tag`` to a tagged row.

    ``target_table`` and ``target_id`` identify the tagged row;
    ``target_table`` is also the polymorphic discriminator column.
    """

    __tablename__ = "kk_tag_associations"
    __mapper_args__ = {"polymorphic_on": "target_table"}

    db_id = sa.Column(sa.Integer, primary_key=True)
    tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
    target_table = sa.Column(sa.Unicode(255, convert_unicode=False))
    target_id = sa.Column(sa.Integer)

    # Earlier attempts kept here as commented-out code -- a non-primary
    # mapper of Tag keyed on Tag.text, and back_populates="collection" --
    # were marked "FIXME doesn't work, may not be needed".
    tag_obj = sa.orm.relationship(Tag)

    # Exposes the related tag's text directly on the association.
    tag = association_proxy("tag_obj", "text")

    def __eq__(self, other):
        """Equal only to another TagAssociation with the same ``db_id``."""
        return isinstance(other, TagAssociation) and self.db_id == other.db_id

    @property
    def parent(self):
        """Return the attribute named ``<target_table>_parent`` on self."""
        attr_name = "%s_parent" % self.target_table
        return getattr(self, attr_name)

    def __repr__(self):
        fields = (self.db_id, self.target_table, self.target_id)
        return "tag assoc %s: %s:%s" % fields
|
|
|
|
class Taggable(Base):
    """Abstract base giving each concrete subclass tag associations.

    For every subclass a dedicated ``TagAssociation`` subclass is built
    whose polymorphic identity is the subclass's table name, and a
    ``tags`` association proxy is installed on the subclass.
    """

    __abstract__ = True

    @declared_attr
    def kk_tag_associations(cls):
        """Build the per-class association type and its relationship."""
        owner_name = cls.__name__
        owner_table = cls.__tablename__

        def cls_init(self, tag=None):
            # Constructor installed on the generated association class;
            # accepts either a Tag instance or the text for a new Tag.
            super(TagAssociation, self).__init__()
            self.target_table = owner_table
            if not tag:
                return
            self.tag_obj = tag if isinstance(tag, Tag) else Tag(tag)

        namespace = {
            "__mapper_args__": {
                "polymorphic_identity": owner_table,
                "polymorphic_on": "target_table",
            },
            "__init__": cls_init,
        }
        assoc_cls = type(
            "%sTagAssociation" % owner_name,
            (TagAssociation, ),
            namespace,
        )

        # Proxy each association's ``tag`` attribute as ``cls.tags``.
        cls.tags = association_proxy("kk_tag_associations", "tag")

        join_condition = (
            "and_(TagAssociation.target_table=='%s', "
            "foreign(TagAssociation.target_id)==%s.db_id)"
        ) % (owner_table, owner_name)
        return sa.orm.relationship(assoc_cls, primaryjoin=join_condition)
|
|
|
|
|
|
def delete_before_insert(mapper, conn, target):
    """Delete any existing ``kk_tags`` row whose text equals ``target.text``.

    Has the ``(mapper, connection, target)`` signature of a mapper-level
    ``before_insert`` listener.  Values are passed as bound parameters
    rather than interpolated into the SQL string, so text containing
    quotes is handled correctly and cannot inject SQL.

    :param mapper: Mapper for the target (unused).
    :param conn: Connection used to run the statements.
    :param target: Instance about to be inserted; only ``text`` is read.
    """
    # TODO figure out where exactly transactions happen.
    params = {"text": target.text}
    # Fetch the first row explicitly: a result object is always truthy,
    # so testing it directly never told us whether a row exists.
    existing = conn.execute(
        sa.text("SELECT db_id FROM kk_tags WHERE text = :text"),
        params).first()
    if existing is not None:
        conn.execute(
            sa.text("DELETE FROM kk_tags WHERE text = :text"), params)
|
|
|
|
# I'm not sure whether this is necessary. Its purpose is to set the id on
# a new instance so that it matches the old instance with the same text;
# if this isn't done, the id might change, breaking old links.
|
|
def enforce_unique_text(session, flush_context, instances):
    """Give pending Tags the ``db_id`` of a queried Tag with the same text.

    Has the ``(session, flush_context, instances)`` signature of a
    session ``before_flush`` listener; only ``session`` is used.
    """
    pending_tags = [obj for obj in session.new if isinstance(obj, Tag)]
    for pending in pending_tags:
        match = session.query(Tag).filter_by(text=pending.text).first()
        if not match:
            continue
        pending.db_id = match.db_id
|
|
|
|
# Register the listeners: the delete runs before each Tag insert, and the
# id-matching pass runs before every Session flush.
event.listen(Tag, "before_insert", delete_before_insert)
event.listen(sa.orm.session.Session, "before_flush", enforce_unique_text)
|