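# libkosokoso/kosokoso.py
#
# Tagging support for SQLAlchemy declarative models: Tag rows linked to
# arbitrary tagged rows through TagAssociation rows.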

import sqlalchemy as sa
from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.associationproxy import association_proxy
import itertools
import logging
import pprint
import sys
# import traceback
# Declarative base shared by Tag, TagAssociation and Taggable; its class
# registry is also scanned to map table names back to classes.
Base = declarative_base()
def _get_mapped_class_by_tablename(tablename):
    """Return the declarative class whose ``__tablename__`` is *tablename*.

    Shared lookup behind ``Tag._get_class_by_tablename`` and
    ``Tag.TaggedObjectCollection._get_class_by_tablename``.

    :param tablename: String with name of table.
    :return: Class reference, or None if no registered class uses that table.
    """
    # Registry values are not guaranteed to all carry __tablename__, hence
    # the hasattr() guard.
    for candidate in Base._decl_class_registry.values():
        if (hasattr(candidate, '__tablename__')
                and candidate.__tablename__ == tablename):
            return candidate
    return None


class Tag(Base):
    """A free-text tag stored in ``kk_tags``; ``text`` is unique.

    Tags are linked to tagged rows through :class:`TagAssociation` rows.
    """

    class TaggedObjectCollection(object):
        """List-like collection class backing :attr:`Tag.collection`.

        ``self.data`` holds the association itself when it was built for an
        in-memory object (``tracked_obj`` set), or otherwise a stub instance
        of the tagged class carrying only its ``db_id``.
        """

        __emulates__ = list

        def __init__(self):
            self.data = []

        @sa.orm.collections.collection.internally_instrumented
        def append(self, item, _sa_initiator=None):
            """Add *item*, firing SQLAlchemy's append event for associations.

            This method is marked internally_instrumented so that SQLAlchemy
            only notices the add when a TagAssociation is used.  Any other
            *item* is treated as a tagged object: a TagAssociation pointing
            at it is generated and this method is re-called with that
            association, which then triggers the append event.

            :param item: A TagAssociation or a tagged (Taggable) object.
            :param _sa_initiator: SQLAlchemy event initiator, passed through.
            :raises TypeError: if no mapped class exists for an association's
                ``target_table``.
            """
            if not isinstance(item, TagAssociation):
                logging.debug("adding %s", item)
                # this seems kind of redundant.
                assoc = TagAssociation()
                assoc.target_id = item.db_id
                assoc.target_table = item.__tablename__
                assoc.tracked_obj = item
                self.append(assoc, _sa_initiator)
                return

            logging.debug("adding assoc %s w/ %s", item, _sa_initiator)
            # The event has to fire first; keep whatever object it hands back.
            adapter = sa.orm.collections.collection_adapter(self)
            item = adapter.fire_append_event(item, _sa_initiator)
            if item.tracked_obj is not None:
                self.data.append(item)
                return

            # No in-memory object to point at: build a stub of the tagged
            # class that carries only its primary key.
            cls = self._get_class_by_tablename(item.target_table)
            if cls is None:
                # Previously this surfaced as "'NoneType' object is not
                # callable"; keep the exception type, improve the message.
                raise TypeError("no mapped class for table %r"
                                % (item.target_table,))
            stub = cls()
            stub.db_id = item.target_id
            self.data.append(stub)

        def remove(self, item):
            """Drop *item* from the in-memory list if it is stored there.

            Matching is by identity so it does not depend on any __eq__
            override.  Items not found (for example, stubs built for loaded
            associations) are ignored.
            NOTE(review): SQLAlchemy's list instrumentation is expected to
            fire the remove event around this method -- confirm.
            """
            for index, existing in enumerate(self.data):
                if existing is item:
                    del self.data[index]
                    return

        def extend(self, item):
            """Append every element of the iterable *item* via :meth:`append`.

            The parameter keeps its historical name even though it is an
            iterable.
            """
            for element in item:
                self.append(element)

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

        def _get_class_by_tablename(self, tablename):
            """Return class reference mapped to table.

            :param tablename: String with name of table.
            :return: Class reference or None.
            """
            return _get_mapped_class_by_tablename(tablename)

    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255, convert_unicode=False),
                     unique=True)
    # this & collection should be read-only?
    collection = sa.orm.relationship(
        "TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=True,
        primaryjoin="TagAssociation.tag_id==Tag.db_id",
    )
    __tablename__ = "kk_tags"
    # Batch inserts are disabled so that parents are inserted before
    # children.
    __mapper_args__ = {
        "batch": False
    }

    def __init__(self, tag_text=None):
        super(Tag, self).__init__()
        self.text = tag_text

    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)

    # NOTE(review): comparing tags by text was flagged by the original
    # author as possibly not what we want.
    def __eq__(self, other):
        """Tags are equal when their ``text`` is equal.

        Non-Tag operands return NotImplemented instead of raising
        AttributeError (e.g. ``tag == None``).
        """
        if not isinstance(other, Tag):
            return NotImplemented
        return self.text == other.text

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __hash__(self):
        # Must agree with __eq__; Python 3 otherwise makes the class
        # unhashable once __eq__ is overridden.  text is only assigned in
        # __init__ here, so the hash is stable in normal use.
        return hash(self.text)

    def _get_class_by_tablename(self, tablename):
        """Return class reference mapped to table.

        :param tablename: String with name of table.
        :return: Class reference or None.
        """
        return _get_mapped_class_by_tablename(tablename)
class TagAssociation(Base):
    """Link row between a :class:`Tag` and one row of some tagged table.

    ``target_table``/``target_id`` form a generic reference to the tagged
    row.  ``target_table`` is also the polymorphic discriminator.
    """
    __tablename__ = "kk_tag_associations"
    db_id = sa.Column(sa.Integer, primary_key=True)
    tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
    # Name of the tagged row's table; also the polymorphic discriminator.
    target_table = sa.Column(sa.Unicode(255, convert_unicode=False))
    # Primary key of the tagged row.  There is no ForeignKey, presumably
    # because the target table varies per row.
    target_id = sa.Column(sa.Integer)
    __mapper_args__ = {
        "polymorphic_on": "target_table"
    }
    tag_obj = sa.orm.relationship(Tag)
    # FIXME: declaring tag_obj with back_populates="collection" did not
    # work; it may not be needed.
    tag = association_proxy("tag_obj", "text")
    # Unmapped, in-memory link to the object this association was built
    # for; stays None unless set on the instance.
    tracked_obj = None

    def __eq__(self, other):
        """Associations are equal when they share a primary key.

        Unsaved associations (``db_id`` is None) only equal themselves.
        Previously ``None == None`` made all pending associations equal.
        """
        if not isinstance(other, TagAssociation):
            return False
        if self.db_id is None or other.db_id is None:
            return self is other
        return self.db_id == other.db_id

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__.
        return not self.__eq__(other)

    # NOTE(review): no __hash__ is defined, so instances are unhashable on
    # Python 3.  A db_id-based hash would change when a flush assigns the
    # key, so none is added here.

    @property
    def parent(self):
        """Return the tagged object via the ``<target_table>_parent`` attribute.

        NOTE(review): no such attribute is defined in this module.  It
        presumably must be provided elsewhere (e.g. a backref); otherwise
        this raises AttributeError.
        """
        return getattr(self, "%s_parent" % self.target_table)

    def __repr__(self):
        return "tag assoc %s: %s:%s" % (self.db_id,
                                        self.target_table, self.target_id)
class Taggable(Base):
    """Abstract declarative base that makes a model taggable.

    For every concrete subclass, the ``kk_tag_associations`` declared_attr:

    * generates a ``<ClassName>TagAssociation`` subclass of TagAssociation,
      with polymorphic identity equal to the subclass's ``__tablename__``;
    * sets ``cls.tags``, an association proxy over the associations'
      ``tag`` (the tag text);
    * returns a relationship to the generated association class.

    NOTE(review): the primaryjoin uses ``<ClassName>.db_id``, so subclasses
    presumably must define a ``db_id`` column -- confirm.
    """
    __abstract__ = True
    @declared_attr
    def kk_tag_associations(cls):
        name = cls.__name__
        table = cls.__tablename__
        # Constructor for the generated association subclass.  *tag* may be
        # a Tag instance or raw text, which is wrapped in a new Tag.  A falsy
        # *tag* leaves tag_obj unset.
        def cls_init(self, tag=None):
            # traceback.print_stack()
            # print 'cls_init called, type %s' % table
            # Skips TagAssociation in the MRO (it defines no __init__) and
            # calls the next initializer, presumably the declarative base's.
            super(TagAssociation, self).__init__()
            self.target_table = table
            # NOTE(review): presumably called with one positional value by
            # the ``tags`` association proxy's default creator when a value
            # is appended to ``obj.tags`` -- confirm.
            if tag:
                if isinstance(tag, Tag):
                    self.tag_obj = tag
                else:
                    logging.debug("cls init called with text %s" % tag)
                    self.tag_obj = Tag(tag)
        # Per-table association subclass, discriminated on target_table.
        assoc_cls = type(
            "%sTagAssociation" % name,
            (TagAssociation, ),
            dict(
                __mapper_args__ = {
                    "polymorphic_identity": table,
                    "polymorphic_on": "target_table",
                },
                __init__ = cls_init,
            )
        )
        # Side effect: also installs the ``tags`` proxy on the subclass.
        cls.tags = association_proxy("kk_tag_associations", "tag")
        # Join on discriminator + target_id.  target_id is marked foreign()
        # because it has no ForeignKey constraint.
        return sa.orm.relationship(
            assoc_cls,
            primaryjoin=(
                "and_(TagAssociation.target_table=='%s', "
                "foreign(TagAssociation.target_id)==%s.db_id)"
                %(table, name)),
        )
def delete_before_insert(mapper, conn, target):
    """``before_insert`` listener for Tag: delete any stored row with the
    same text, so the pending insert does not hit the unique constraint on
    ``kk_tags.text``.

    The ``before_flush`` listener ``enforce_unique_text`` copies the old
    row's ``db_id`` onto the new instance, so the re-inserted row keeps
    that id.

    :param mapper: The Tag mapper (unused).
    :param conn: Connection the flush is executing on.
    :param target: The Tag instance about to be inserted.
    """
    # TODO figure out where exactly transactions happen.
    # The text is bound as a parameter.  Interpolating it broke on quotes
    # (e.g. "don't"), allowed SQL injection, and turned None into 'None'.
    params = {"text": target.text}
    existing = conn.execute(
        sa.text("SELECT db_id FROM kk_tags WHERE text = :text"), params
    ).first()
    # The old `if r:` tested the result object itself rather than whether a
    # row came back.  first() also consumes and closes the result.
    if existing is not None:
        conn.execute(sa.text("DELETE FROM kk_tags WHERE text = :text"),
                     params)
def enforce_unique_text(session, flush_context, instances):
    """``before_flush`` listener: for each pending Tag, adopt the ``db_id``
    of an already-stored Tag with the same text, if one exists.

    This keeps the id stable when a tag is re-created, so links to the old
    id are not broken.  The original author was unsure whether it is
    necessary.
    NOTE(review): if the stored Tag stays in the identity map, the flush may
    report an identity conflict -- confirm.
    """
    pending_tags = [obj for obj in session.new if isinstance(obj, Tag)]
    for pending in pending_tags:
        stored = session.query(Tag).filter_by(text=pending.text).first()
        if not stored:
            continue
        pending.db_id = stored.db_id
# Register the tag-uniqueness hooks.  The flush hook is attached to the
# Session class, so it applies to every Session; the insert hook fires only
# for Tag rows.
sa.event.listen(
    sa.orm.session.Session,
    'before_flush',
    enforce_unique_text,
)
sa.event.listen(
    Tag,
    'before_insert',
    delete_before_insert,
)