"""Tagging support for SQLAlchemy declarative models.

Tags live in the ``kk_tags`` table and are attached to rows of other tables
through ``kk_tag_associations``; an association's ``target_table`` /
``target_id`` columns identify the tagged row, and ``target_table`` is also
the polymorphic discriminator for the per-model association subclasses.

The ``*Base`` classes below are templates: ``init_base`` copies each of them
onto the caller's declarative base to produce the real ``Tag``,
``TagAssociation`` and ``Taggable`` classes.
"""
import sqlalchemy as sa
from sqlalchemy.ext.declarative import declared_attr
from sqlalchemy.ext.associationproxy import association_proxy
import itertools
import logging
import pprint
import sys

# this will get overridden by init_base
Base = None


class TagBase(object):
    """A tag: a unique piece of text (template copied by init_base)."""

    class TaggedObjectCollection(object):
        """List-like collection class used for ``Tag.collection``.

        Appending a mapped object wraps it in a TagAssociation first.
        Appending a TagAssociation fires SQLAlchemy's append event and then
        stores either the association itself (when it tracks an in-memory
        object) or a placeholder instance of the target class that carries
        only ``db_id``.
        """

        __emulates__ = list

        def __init__(self):
            self.data = []

        @sa.orm.collections.collection.internally_instrumented
        def append(self, item, _sa_initiator=None):
            """This method is marked internally_instrumented to ensure that
            sqlalchemy only notices the add when a TagAssociation is used.
            If a class is added directly, we generate a TagAssociation and
            re-call this method with the association, which then triggers
            the append event.

            :raises TypeError: if no mapped class exists for the
                association's ``target_table``.
            """
            if not isinstance(item, TagAssociation):
                logging.debug("adding %s", item)
                # Wrap the raw object in an association pointing at it and
                # re-enter through the association branch below.
                assoc = TagAssociation()
                assoc.target_id = item.db_id
                assoc.target_table = item.__tablename__
                assoc.tracked_obj = item
                self.append(assoc, _sa_initiator)
                return

            logging.debug("adding assoc %s w/ %s", item, _sa_initiator)
            # this has to be done first: the event may substitute the item.
            adapter = sa.orm.collections.collection_adapter(self)
            item = adapter.fire_append_event(item, _sa_initiator)
            if item.tracked_obj:
                self.data.append(item)
                return

            cls = self._get_class_by_tablename(item.target_table)
            if cls is None:
                raise TypeError(
                    "no mapped class found for table %r" % (item.target_table,))
            # NOTE(review): this builds a fresh (transient) instance holding
            # only the id rather than loading the real row -- confirm intent.
            thingy = cls()
            thingy.db_id = item.target_id
            self.data.append(thingy)

        def remove(self, item):
            """No-op: the underlying data is never modified.

            NOTE(review): SQLAlchemy's list instrumentation may wrap this to
            fire remove events; confirm whether real removal is needed.
            """
            pass

        def extend(self, item):
            """No-op as written.

            NOTE(review): when instrumented as a list, SQLAlchemy may replace
            this with a version that calls append() per element -- confirm.
            """
            pass

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

        def _get_class_by_tablename(self, tablename):
            """Return class reference mapped to table.

            :param tablename: String with name of table.
            :return: Class reference or None.
            """
            # SQLAlchemy < 1.4 keeps the class registry on the base itself;
            # 1.4+ moved it to Base.registry._class_registry.
            registry = getattr(Base, "_decl_class_registry", None)
            if registry is None:
                registry = Base.registry._class_registry
            for c in registry.values():
                # Registry values can include non-class markers without a
                # __tablename__, hence the hasattr guard.
                if (hasattr(c, '__tablename__') and
                        c.__tablename__ == tablename):
                    return c
            return None

    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255, convert_unicode=False), unique=True)
    # this & collection should be read-only?
    collection = sa.orm.relationship(
        "TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=True,
        primaryjoin="TagAssociation.tag_id==Tag.db_id",
    )
    __tablename__ = "kk_tags"
    # we disable batch to ensure that parents are inserted before
    # children.  (With batch=False the mapper saves each instance fully
    # before the next, so before_insert runs right before each INSERT.)
    __mapper_args__ = {
        "batch": False,
    }

    def __init__(self, tag_text=None):
        super(Tag, self).__init__()
        self.text = tag_text

    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)

    # i'm not sure this is really what we want.
    # Tags compare equal by text.  Objects without a ``text`` attribute are
    # not comparable (NotImplemented) instead of raising AttributeError.
    # Defining __eq__ without __hash__ leaves instances unhashable on Py3.
    def __eq__(self, other):
        if not hasattr(other, "text"):
            return NotImplemented
        return self.text == other.text


class TagAssociationBase(object):
    """Link between one tag and one row of another table (template).

    ``target_table`` names the tagged row's table and doubles as the
    polymorphic discriminator; ``target_id`` is that row's ``db_id``.
    """

    __tablename__ = "kk_tag_associations"
    db_id = sa.Column(sa.Integer, primary_key=True)
    tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
    target_table = sa.Column(sa.Unicode(255, convert_unicode=False))
    target_id = sa.Column(sa.Integer)
    __mapper_args__ = {
        "polymorphic_on": "target_table"
    }

    # Declared lazily because it refers to the global ``Tag`` created by
    # init_base.  Earlier attempts -- a non-primary mapper keyed on Tag.text,
    # and relationship("Tag", back_populates="collection") -- did not work.
    # FIXME the back_populates variant may not be needed.
    @declared_attr
    def tag_obj(cls):
        return sa.orm.relationship(Tag)

    tag = association_proxy("tag_obj", "text")
    # In-memory object this association was created for (set by
    # TaggedObjectCollection.append); None for associations loaded from DB.
    tracked_obj = None

    def __eq__(self, other):
        # NOTE(review): two unsaved associations (db_id None) compare equal.
        if not isinstance(other, TagAssociation):
            return False
        return self.db_id == other.db_id

    @property
    def parent(self):
        """The tagged object, via the ``<target_table>_parent`` attribute."""
        return getattr(self, "%s_parent" % self.target_table)

    def __repr__(self):
        return "tag assoc %s: %s:%s" % (self.db_id, self.target_table,
                                        self.target_id)


class TaggableBase(object):
    """Mixin template for models that can carry tags.

    Each subclass gets its own ``<Name>TagAssociation`` polymorphic subclass
    (identity = its table name), a ``kk_tag_associations`` relationship to
    it, and a ``tags`` proxy exposing the tag texts.
    """

    __abstract__ = True

    @declared_attr
    def kk_tag_associations(cls):
        name = cls.__name__
        table = cls.__tablename__

        def cls_init(self, tag=None):
            """Create an association for this model, optionally with a tag.

            :param tag: a Tag instance, or text to wrap in a new Tag.
            """
            logging.debug("cls_init called, type %s", table)
            super(TagAssociation, self).__init__()
            self.target_table = table
            # is this necessary? how does this get called?
            if tag:
                if isinstance(tag, Tag):
                    self.tag_obj = tag
                else:
                    logging.debug("cls init called with text %s", tag)
                    self.tag_obj = Tag(tag)

        assoc_cls = type(
            "%sTagAssociation" % name,
            (TagAssociation,),
            dict(
                __mapper_args__={
                    "polymorphic_identity": table,
                    "polymorphic_on": "target_table",
                },
                __init__=cls_init,
            ),
        )

        cls.tags = association_proxy("kk_tag_associations", "tag")
        return sa.orm.relationship(
            assoc_cls,
            primaryjoin=(
                "and_(TagAssociation.target_table=='%s', "
                "foreign(TagAssociation.target_id)==%s.db_id)"
                % (table, name)),
        )


def init_base(new_base):
    """Set up classes based on new_base.

    Builds ``Tag``, ``TagAssociation`` and ``Taggable`` by copying the
    namespaces of the ``*Base`` templates onto subclasses of *new_base*,
    publishes them (and *new_base* as ``Base``) as module globals, and
    registers the event listeners.

    :param new_base: declarative base class the models should extend.
    :return: dict mapping 'Tag', 'TagAssociation' and 'Taggable' to the
        newly created classes.
    """
    # based on https://stackoverflow.com/a/41927212
    def init_cls(name, base):
        template = globals()[name + 'Base']
        namespace = template.__dict__.copy()
        # These descriptors belong to the template class; the new class
        # must get its own from ``base`` instead.
        namespace.pop('__dict__', None)
        namespace.pop('__weakref__', None)
        globals()[name] = type(name, (base,), namespace)

    # Order matters: TagAssociation.tag_obj is resolved at class creation
    # and refers to the global Tag, which must already exist.
    for c in ['Tag', 'TagAssociation', 'Taggable']:
        init_cls(c, new_base)
    globals()['Base'] = new_base

    # these are here because "Tag" needs to exist before adding the
    # listener.
    sa.event.listen(Tag, 'before_insert', delete_before_insert)
    # The Session-level listener is global; don't stack duplicates when
    # init_base is called more than once.
    if not sa.event.contains(sa.orm.session.Session, 'before_flush',
                             enforce_unique_text):
        sa.event.listen(sa.orm.session.Session, 'before_flush',
                        enforce_unique_text)

    # need to return new types so that the caller can incorporate them
    # into its view of the module.
    return {'Tag': globals()['Tag'],
            'TagAssociation': globals()['TagAssociation'],
            'Taggable': globals()['Taggable']}


def delete_before_insert(mapper, conn, target):
    """``before_insert`` listener for Tag: delete rows with the same text.

    Deleting zero rows is harmless, so no existence check is made first.
    The text goes through a bound parameter; it may contain quotes or be
    user supplied and must never be formatted into the SQL string.
    """
    # TODO figure out where exactly transactions happen.
    # TODO can we just upsert? Or, for that matter, skip the insert?
    table = mapper.local_table
    conn.execute(table.delete().where(table.c["text"] == target.text))


def enforce_unique_text(session, flush_context, instances):
    """``before_flush`` listener: give new Tags the id of an existing twin.

    For each pending Tag whose text already exists in the database, copy the
    existing row's ``db_id`` onto the new instance.

    i'm not sure if this is necessary -- its purpose is to set the id on
    our new instance to ensure it matches the old instance.  if this
    isn't done, it seems like the id might change, breaking old links.
    """
    for i in session.new:
        if isinstance(i, Tag):
            existing = session.query(Tag).filter_by(text=i.text).first()
            if existing is not None:
                i.db_id = existing.db_id