"""Generic text tagging for SQLAlchemy models.

The ``*Base`` classes in this module are templates only.  ``init_base``
copies each template's namespace into a real class (``Tag``,
``TagAssociation``, ``Taggable``) derived from the caller's declarative
base and publishes those classes as module globals.
"""
import itertools
import logging
import pprint
import sys

import sqlalchemy as sa
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy.orm import declared_attr
# from sqlalchemy import text

# this will get overridden by init_base
Base = None


class TagBase(object):
    """Template for the ``Tag`` class: one row per distinct tag text."""

    def __init__(self, tag_text=None):
        """Create a tag carrying ``tag_text``.

        :param tag_text: text of the tag, or None.
        """
        # Zero-argument super() cannot be used here: init_base() copies this
        # function into a new class that does not inherit from TagBase, so
        # the generated ``Tag`` global has to be named explicitly.
        super(Tag, self).__init__()
        self.text = tag_text

    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)

    def __eq__(self, other):
        """Compare tags by their text.

        :return: True/False for objects exposing ``text``; NotImplemented
            for anything else (e.g. ``None``), so that such comparisons
            evaluate to False instead of raising AttributeError.
        """
        # NOTE(review): defining __eq__ without __hash__ makes Python set
        # __hash__ to None for this class -- confirm that is intended.
        if not hasattr(other, "text"):
            return NotImplemented
        return self.text == other.text

    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255), unique=True)
    __tablename__ = "kk_tags"

    # we disable batch to ensure that parents are inserted before
    # children.
    __mapper_args__ = {
        # "primary_key": [db_id],
        "batch": False
    }

    class TaggedObjectCollection(object):
        """List-like relationship collection holding a tag's targets.

        ``data`` holds either TagAssociation objects (when the association
        carries a ``tracked_obj``) or freshly built instances of the target
        class with only ``db_id`` set (see ``append``).
        """

        __emulates__ = list

        def __init__(self):
            self.data = []

        @sa.orm.collections.collection.internally_instrumented
        def append(self, item, _sa_initiator=None):
            """Add an association, or an object to be wrapped in one.

            This method is marked internally_instrumented to ensure that
            sqlalchemy only notices the add when a TagAssociation is
            used.  If a class is added directly, we generate a
            TagAssociation and re-call this method with the association,
            which then triggers the append event.

            hang on, shouldn't this all be handled by the
            associationproxy?

            :param item: a TagAssociation, or an object exposing ``db_id``
                and ``__tablename__``.
            :param _sa_initiator: passed through to the append event.
            """
            if isinstance(item, TagAssociation):
                # logging.debug("adding assoc %s w/ %s" % (item,
                #                                          _sa_initiator))
                # this has to be done first.
                adapter = sa.orm.collections.collection_adapter(self)
                item = adapter.fire_append_event(item, _sa_initiator)
                if item.tracked_obj:
                    self.data.append(item)
                else:
                    # No live object is attached to the association, so
                    # build a placeholder instance of the target class
                    # that carries only the target id.
                    cls = self._get_class_by_tablename(item.target_table)
                    thingy = cls()
                    thingy.db_id = item.target_id
                    self.data.append(thingy)
                # pprint.pprint(self.data)
            else:
                logging.debug("adding %s", item)
                # this seems kind of redundant.
                assoc = TagAssociation()
                assoc.target_id = item.db_id
                assoc.target_table = item.__tablename__
                assoc.tracked_obj = item
                self.append(assoc, _sa_initiator)

        def remove(self, item):
            # NOTE(review): intentionally(?) a no-op -- ``data`` is never
            # shrunk by this method.  Confirm before implementing.
            pass

        # 'extend' is apparently taken?
        def merge(self, iterable):
            """Append every item of ``iterable`` not already present.

            An item counts as present when an element of ``data`` has the
            same ``target_table`` and ``target_id``.

            :param iterable: items exposing ``target_table`` and
                ``target_id``.
            """
            logging.debug('merge called')
            for candidate in iterable:
                already_present = any(
                    candidate.target_table == existing.target_table
                    and candidate.target_id == existing.target_id
                    for existing in self.data
                )
                if already_present:
                    logging.info('item %s already in collection', candidate)
                else:
                    logging.info('adding %s', candidate)
                    self.append(candidate)

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

        def __eq__(self, other):
            # NOTE(review): unimplemented stub -- returns None, so every
            # ``==`` comparison (even with itself) is falsy.  The sketch
            # below is the original, never-finished idea.
            pass
            # if not isinstance(other, TaggedObjectCollection):
            #     return False
            # if len(self.data) ne len(other.data):
            #     return False
            # s1 = sorted(self.data)
            # s2 = sorted(other.data)
            # for i in range(0, len(s1)):
            #     if s1[i] ne s2[i]:
            #         return False
            # return True

        def __repr__(self):
            body = ''.join(str(i) + '\n' for i in self.data)
            return 'TaggedObjectCollection: [' + body + ']'

        def _get_class_by_tablename(self, tablename):
            """Return class reference mapped to table.

            :param tablename: String with name of table.
            :return: Class reference or None.
            """
            for c in Base.registry._class_registry.values():
                if getattr(c, '__tablename__', None) == tablename:
                    return c
            return None
            # for c in Base._decl_class_registry.values():
            #     if (hasattr(c, '__tablename__')
            #             and c.__tablename__ == tablename):
            #         return c

    collection = sa.orm.relationship(
        "TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=True,
        primaryjoin="TagAssociation.tag_id==Tag.db_id",
        back_populates="tag_obj"
    )


class TagAssociationBase(object):
    """Template for ``TagAssociation``: links a tag to a row of any table.

    The target is identified by ``target_table`` (also the polymorphic
    discriminator) together with ``target_id``.
    """

    __tablename__ = "kk_tag_associations"
    db_id = sa.Column(sa.Integer, primary_key=True)
    tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
    # target_table = sa.Column(sa.Unicode(255, convert_unicode=False))
    target_table = sa.Column(sa.Unicode(255))
    target_id = sa.Column(sa.Integer)
    __mapper_args__ = {
        "polymorphic_on": "target_table"
    }

    # at import time, Tag doesn't exist, so we use declared_attr to
    # defer.
    @declared_attr
    def tag_obj(cls):
        return sa.orm.relationship(Tag)

    tag = association_proxy("tag_obj", "text")

    # Live object this association points at; only set by
    # TaggedObjectCollection.append when an object is added directly.
    tracked_obj = None

    def __eq__(self, other):
        if not isinstance(other, TagAssociation):
            return False
        return self.db_id == other.db_id

    def __lt__(self, other):
        """Order associations by tag text, target table and target id."""
        if not isinstance(other, TagAssociation):
            return False
        k1 = str(self.tag_obj.text) + self.target_table + str(self.target_id)
        k2 = (str(other.tag_obj.text) + other.target_table
              + str(other.target_id))
        return k1 < k2

    # @property
    # def parent(self):
    #     return getattr(self, "%s_parent" % self.target_table)

    def __repr__(self):
        # tag_obj is None for associations that have not been given a tag
        # yet (e.g. the ones built in TaggedObjectCollection.append); a
        # repr must not raise for those, so fall back to None.
        tag_text = getattr(self.tag_obj, "text", None)
        return "tag assoc %s: [%s] %s:%s" % (self.db_id,
                                             tag_text,
                                             self.target_table,
                                             self.target_id)


class TaggableBase(object):
    """Template for the abstract ``Taggable`` mixin-style base class."""

    __abstract__ = True

    @declared_attr
    def kk_tag_associations(cls):
        """Build the per-class association subclass and its relationship.

        Also installs ``cls.tags`` as an association proxy onto the tag
        text of each association.
        """
        name = cls.__name__
        table = cls.__tablename__

        def cls_init(self, tag=None):
            """Initialise an association bound to this class's table.

            :param tag: a Tag instance, tag text, or None.
            """
            # traceback.print_stack()
            # logging.debug("cls_init called, type %s".format(table))
            super(TagAssociation, self).__init__()
            self.target_table = table
            # is this necessary? how does this get called?
            if tag:
                if isinstance(tag, Tag):
                    self.tag_obj = tag
                else:
                    # logging.debug("cls init called with text %s" % tag)
                    self.tag_obj = Tag(tag)

        assoc_cls = type(
            f"{name}TagAssociation",
            (TagAssociation, ),
            dict(
                __mapper_args__={
                    "polymorphic_identity": table,
                    "polymorphic_on": "target_table",
                },
                __init__=cls_init,
            )
        )

        cls.tags = association_proxy("kk_tag_associations", "tag")
        return sa.orm.relationship(
            assoc_cls,
            primaryjoin=(
                "and_(TagAssociation.target_table=='%s', "
                "foreign(TagAssociation.target_id)==%s.db_id)"
                % (table, name)),
        )


def init_base(new_base):
    """Set up classes based on new_base.

    For each of ``Tag``, ``TagAssociation`` and ``Taggable`` a class is
    created from the matching ``*Base`` template's namespace with
    ``new_base`` as its only base, and stored as a module global.  The
    module-level ``Base`` is then set to ``new_base``.

    :param new_base: the base class the generated classes derive from.
    :return: dict mapping each generated class name to the class.
    """
    # based on https://stackoverflow.com/a/41927212
    def init_cls(name, base):
        template = globals()[name + 'Base']
        namespace = template.__dict__.copy()
        # The template's own __dict__ descriptor must not be carried over
        # into the namespace of the new class.
        del namespace['__dict__']
        new_cls = type(name, (base,), namespace)
        globals()[name] = new_cls
        return {name: new_cls}

    ret = {}
    for c in ['Tag', 'TagAssociation', 'Taggable']:
        ret.update(init_cls(c, new_base))
    globals()['Base'] = new_base
    return ret


@sa.event.listens_for(sa.orm.Session, 'before_flush')
def merge_text_tags(session, ctx, instances):
    """Fold pending tags whose text already exists into the existing tag.

    For every new ``Tag`` in the session: if the query for a tag with the
    same text returns one, the pending tag's collection is merged into it
    and the pending tag is expunged; if the same text was already seen
    among the pending tags, the later one is expunged.

    :param session: the session about to flush.
    :param ctx: flush context (unused).
    :param instances: instances passed to flush (unused).
    """
    # This listener is attached to the Session class when the module is
    # imported, but Tag only exists once init_base() has run.  Without
    # this guard any flush with pending objects before init_base() would
    # raise NameError.
    tag_cls = globals().get('Tag')
    if tag_cls is None:
        return

    texts = {}
    for i in session.new:
        if not isinstance(i, tag_cls):
            continue
        t = session.query(tag_cls).filter_by(text=i.text).first()
        if t:
            logging.info('got t, merging')
            i.db_id = t.db_id
            t.collection.merge(i.collection)
            # t = session.merge(i)
            session.expunge(i)
        elif i.text in texts:
            logging.info('have text, not sure how to proceed.')
            session.expunge(i)
        else:
            texts[i.text] = i


# @sa.event.listens_for(sa.orm.Session, "before_attach")
# def intercept(session, instance):
#     if not isinstance(instance, Tag):
#         logging.info(f'got non-tag {instance}')
#         return
#
#     # i guess this only returns stuff that's persistent?
#     logging.info(f'have tag {instance}')
#     t = session.query(Tag).filter_by(text=instance.text).first()
#     if t:
#         logging.info(f'have tag {instance}')
#         instance.db_id = t.db_id
#         session.merge(instance)
#     return

# I'm not sure if this is necessary -- its purpose is to set the id on
# our new instance to ensure it matches the old instance. If this
# isn't done, it seems like the id might change, breaking old links.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def enforce_unique_text(session, flush_context, instances):
#     texts = []
#     for i in session.new:
#         if isinstance(i, Tag):
#             t = session.query(Tag).filter_by(text=i.text).first()
#             if t:
#                 logging.info(f'got t, merging')
#                 i.db_id = t.db_id
#                 session.merge(t)
#             elif i.text in texts:
#                 logging.info('have text, not sure how to proceed.')
#             else:
#                 texts.append(i.text)

# FIXME we could probably do this more effectively if we listen for
# the object event, not session.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def delete_before_insert(session, context, instances):
#     tags = []
#     logging.info('calling delete_before_insert')
#     for instance in session.dirty:
#         if isinstance(instance, Tag):
#             logging.info('have dirty tag')
#             tags.append(instance.text)
#     for instance in session.new:
#         if not isinstance(instance, Tag):
#             logging.info('got non-tag')
#             continue
#         # if not session.is_modified(instance):
#         #     continue
#         # logging.info(f'have tag {instance}')
#         t = session.query(Tag).filter_by(text=instance.text).first()
#         # conn = session.connection()
#         # r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" %
#         #                       instance.text))
#         # apparently this is -1 if the query matches no rows
#         # if r.rowcount > 0:
#         if t:
#             session.expunge(instance)
#         else:
#             logging.info('no row')
#
#         # handle the case where we have multiples of the same tag in
#         # new
#         logging.info(f'{tags}')
#         if instance.text in tags:
#             session.expunge(instance)
#         else:
#             tags.append(instance.text)