343 lines
11 KiB
Python
343 lines
11 KiB
Python
import sqlalchemy as sa
|
|
from sqlalchemy.orm import declared_attr
|
|
from sqlalchemy.ext.associationproxy import association_proxy
|
|
# from sqlalchemy import text
|
|
|
|
import itertools
|
|
import logging
|
|
import pprint
|
|
import sys
|
|
|
|
# Placeholder for the declarative base class; init_base() rebinds this
# module-level name to the base it is given.
Base = None
|
|
|
|
|
|
class TagBase(object):
    """Template for the ``Tag`` mapped class.

    ``init_base`` copies this class's namespace into a new class named
    ``Tag`` that derives from the declarative base.  Module-level names used
    in the methods below (``Tag``, ``TagAssociation``, ``Base``) therefore
    resolve at call time to the classes ``init_base`` publishes, not to the
    ``*Base`` templates.
    """

    def __init__(self, tag_text=None):
        # Two-argument super(): the generated Tag class does not inherit
        # from TagBase, so zero-argument super() (bound to TagBase) would
        # fail for Tag instances.
        super(Tag, self).__init__()
        self.text = tag_text

    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)

    # NOTE: defining __eq__ without __hash__ leaves instances unhashable.
    def __eq__(self, other):
        """Compare tags by their text.

        Objects without a ``text`` attribute (``None``, associations, ...)
        yield NotImplemented, so Python falls back to its default comparison
        instead of raising AttributeError.
        """
        if not hasattr(other, "text"):
            return NotImplemented
        return self.text == other.text

    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255), unique=True)

    __tablename__ = "kk_tags"
    # we disable batch to ensure that parents are inserted before
    # children.
    __mapper_args__ = {
        "batch": False,
    }

    class TaggedObjectCollection(object):
        """List-like SQLAlchemy collection class used for ``Tag.collection``.

        ``data`` holds one entry per tagged object: either the
        ``TagAssociation`` itself when it carries a ``tracked_obj``, or a
        placeholder instance of the tagged class (only ``db_id`` set) built
        from an association that has no ``tracked_obj`` -- presumably those
        loaded from the database; TODO confirm.
        """

        __emulates__ = list

        def __init__(self):
            self.data = []

        def _entry_key(self, entry):
            """Return a ``(table name, id)`` key for an entry of ``data``.

            Handles both entry kinds: associations (``target_table`` /
            ``target_id``) and tagged-class instances (``__tablename__`` /
            ``db_id``).
            """
            if isinstance(entry, TagAssociation):
                return (entry.target_table, entry.target_id)
            return (getattr(entry, "__tablename__", None),
                    getattr(entry, "db_id", None))

        @sa.orm.collections.collection.internally_instrumented
        def append(self, item, _sa_initiator=None):
            """Add *item*, firing SQLAlchemy's append event for associations.

            The method is marked internally_instrumented so SQLAlchemy does
            not wrap it; the append event is fired here, and only for
            ``TagAssociation`` objects.  Any other object is wrapped in a new
            ``TagAssociation`` whose ``tracked_obj`` points at it, and this
            method is called again with that association.

            NOTE: an association_proxy might make this wrapping unnecessary.

            :param item: a ``TagAssociation`` or a taggable mapped object.
            :param _sa_initiator: SQLAlchemy event initiator, passed through
                to ``fire_append_event``.
            :raises TypeError: if an association without ``tracked_obj``
                names a table that no class in ``Base``'s registry uses.
            """
            if not isinstance(item, TagAssociation):
                logging.debug("adding %s", item)
                assoc = TagAssociation()
                # NOTE(review): if *item* has not been flushed yet its db_id
                # is presumably still None, so target_id would be None here.
                assoc.target_id = item.db_id
                assoc.target_table = item.__tablename__
                assoc.tracked_obj = item
                self.append(assoc, _sa_initiator)
                return

            # Fire the event before storing anything; the adapter's return
            # value replaces the item.
            adapter = sa.orm.collections.collection_adapter(self)
            item = adapter.fire_append_event(item, _sa_initiator)

            if item.tracked_obj is not None:
                self.data.append(item)
                return

            cls = self._get_class_by_tablename(item.target_table)
            if cls is None:
                raise TypeError(
                    "no mapped class for table %r" % (item.target_table,))
            # Stand-in object carrying only the primary key of the target.
            placeholder = cls()
            placeholder.db_id = item.target_id
            self.data.append(placeholder)

        def remove(self, item):
            """Drop *item* (matched by identity) from the collection.

            Presumably wrapped by SQLAlchemy's list emulation to fire the
            remove event.  Items not present are ignored.
            """
            for index, entry in enumerate(self.data):
                if entry is item:
                    del self.data[index]
                    return

        # Named 'merge' rather than 'extend', presumably because SQLAlchemy's
        # list emulation instruments 'extend'.
        def merge(self, iterable):
            """Append the entries of *iterable* not already present.

            Entries are matched on their ``(table, id)`` key (see
            ``_entry_key``), so associations and placeholder objects can be
            compared with each other.  Duplicates within *iterable* are also
            added only once.
            """
            logging.debug('merge called')
            seen = {self._entry_key(entry) for entry in self.data}
            # Snapshot the source in case appending touches it.
            for entry in list(iterable):
                key = self._entry_key(entry)
                if key in seen:
                    logging.info('item %s already in collection', entry)
                    continue
                logging.info('adding %s', entry)
                self.append(entry)
                seen.add(key)

        def __iter__(self):
            return iter(self.data)

        def __len__(self):
            return len(self.data)

        def __eq__(self, other):
            """Compare collections by referenced ``(table, id)`` keys.

            Order is ignored, but multiplicity is not.  Non-collections yield
            NotImplemented.
            """
            if not isinstance(other, type(self)):
                return NotImplemented
            from collections import Counter
            return (Counter(map(self._entry_key, self.data))
                    == Counter(map(other._entry_key, other.data)))

        def __repr__(self):
            body = ''.join('%s\n' % (entry,) for entry in self.data)
            return 'TaggedObjectCollection: [' + body + ']'

        def _get_class_by_tablename(self, tablename):
            """Return the mapped class whose ``__tablename__`` is *tablename*.

            NOTE(review): relies on the private ``registry._class_registry``
            of the module-level ``Base``.

            :param tablename: String with name of table.
            :return: Class reference, or None if no registered class matches.
            """
            for candidate in Base.registry._class_registry.values():
                if getattr(candidate, '__tablename__', None) == tablename:
                    return candidate
            return None

    # One-to-many from a tag to its association rows, stored in a
    # TaggedObjectCollection; back_populates links it to
    # TagAssociation.tag_obj.
    collection = sa.orm.relationship(
        "TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=True,
        primaryjoin="TagAssociation.tag_id==Tag.db_id",
        back_populates="tag_obj",
    )
|
|
|
|
|
|
class TagAssociationBase(object):
    """Template for the ``TagAssociation`` mapped class.

    ``init_base`` copies this namespace into a class named ``TagAssociation``
    that derives from the declarative base.  Each row links one tag
    (``tag_id``) to one tagged row identified by ``target_table`` /
    ``target_id``.  ``target_table`` is also the polymorphic discriminator.
    """

    __tablename__ = "kk_tag_associations"
    db_id = sa.Column(sa.Integer, primary_key=True)
    tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
    target_table = sa.Column(sa.Unicode(255))
    target_id = sa.Column(sa.Integer)
    __mapper_args__ = {
        "polymorphic_on": "target_table",
    }

    # at import time, Tag doesn't exist, so we use declared_attr to
    # defer.
    @declared_attr
    def tag_obj(cls):
        return sa.orm.relationship(Tag)

    tag = association_proxy("tag_obj", "text")

    # Plain (unmapped) attribute: the in-memory object this association was
    # created for, if any; TaggedObjectCollection.append sets it.
    tracked_obj = None

    def _tag_text(self):
        """Return the linked tag's text, or None when no tag is attached."""
        tag_obj = self.tag_obj
        return None if tag_obj is None else tag_obj.text

    def _sort_key(self):
        """Return the ``(tag text, table, id)`` tuple used by ``__lt__``.

        A tuple avoids the ambiguity of concatenating the parts into one
        string.  Each part is stringified so None values still compare.
        """
        return (str(self._tag_text()), str(self.target_table),
                str(self.target_id))

    def __eq__(self, other):
        """Associations are equal when they share a database id.

        Unflushed associations (``db_id`` None) are equal only to
        themselves, so distinct pending rows are not confused.  Objects that
        are not associations compare unequal.
        """
        if not isinstance(other, TagAssociation):
            return False
        if self.db_id is None or other.db_id is None:
            return self is other
        return self.db_id == other.db_id

    def __lt__(self, other):
        # Non-associations are never "less than"; this keeps the previous
        # result for mixed comparisons.
        if not isinstance(other, TagAssociation):
            return False
        return self._sort_key() < other._sort_key()

    def __repr__(self):
        return "tag assoc %s: [%s] %s:%s" % (self.db_id, self._tag_text(),
                                             self.target_table, self.target_id)
|
|
|
|
|
|
class TaggableBase(object):
    """Template for the abstract ``Taggable`` mixin that init_base() creates.

    Each mapped subclass gets a ``kk_tag_associations`` relationship to its
    own polymorphic ``TagAssociation`` subclass, plus a ``tags`` association
    proxy over it.
    """

    __abstract__ = True

    @declared_attr
    def kk_tag_associations(cls):
        """Build ``<ClassName>TagAssociation`` and return the relationship.

        Side effect: assigns ``cls.tags``, an association proxy exposing the
        ``tag`` value of each association.
        """
        owner_name = cls.__name__
        owner_table = cls.__tablename__

        def assoc_init(self, tag=None):
            # Explicit super(): zero-argument super() needs a class body.
            super(TagAssociation, self).__init__()
            self.target_table = owner_table
            # NOTE(review): presumably invoked with a single value by the
            # association proxy's default creator; confirm.
            if not tag:
                return
            self.tag_obj = tag if isinstance(tag, Tag) else Tag(tag)

        subclass_namespace = {
            "__mapper_args__": {
                "polymorphic_identity": owner_table,
                "polymorphic_on": "target_table",
            },
            "__init__": assoc_init,
        }
        assoc_cls = type(f"{owner_name}TagAssociation", (TagAssociation,),
                         subclass_namespace)

        cls.tags = association_proxy("kk_tag_associations", "tag")

        # Association rows whose discriminator names this class's table and
        # whose target_id points at this row's db_id.
        join_condition = ("and_(TagAssociation.target_table=='%s', "
                          "foreign(TagAssociation.target_id)==%s.db_id)"
                          % (owner_table, owner_name))
        return sa.orm.relationship(assoc_cls, primaryjoin=join_condition)
|
|
|
|
|
|
|
|
def init_base(new_base):
    """Set up classes based on new_base.

    For each template class ``<Name>Base`` in this module (``TagBase``,
    ``TagAssociationBase``, ``TaggableBase``), create a class ``<Name>``
    that inherits from *new_base* and carries a copy of the template's
    namespace, and publish it as a module global.  Finally, rebind the
    module-level ``Base`` to *new_base*.

    Based on https://stackoverflow.com/a/41927212

    :param new_base: base class for the generated classes (presumably a
        SQLAlchemy declarative base; confirm with callers).
    :return: dict mapping each generated class name to the class.
    """

    def init_cls(name, base):
        # Clone template ``<name>Base`` as class ``<name>`` deriving from base.
        template = globals()[name + 'Base']
        namespace = dict(template.__dict__)
        # The template's __dict__/__weakref__ descriptors only accept
        # instances of the template.  Left in the namespace, they would
        # shadow the descriptors the new class should use.
        namespace.pop('__dict__', None)
        namespace.pop('__weakref__', None)
        new_cls = type(name, (base,), namespace)
        globals()[name] = new_cls
        return {name: new_cls}

    ret = {}
    for c in ['Tag', 'TagAssociation', 'Taggable']:
        ret.update(init_cls(c, new_base))

    globals()['Base'] = new_base

    return ret
|
|
|
|
|
|
@sa.event.listens_for(sa.orm.Session, 'before_flush')
def merge_text_tags(session, ctx, instances):
    """Fold pending ``Tag`` objects into tags that share their text.

    Registered as a ``before_flush`` listener on every Session.  For each
    pending (``session.new``) Tag:

    * if a Tag with the same text is found by query, the pending tag takes
      its db_id, the pending tag's collection is merged into it, and the
      pending tag is expunged;
    * otherwise, if an earlier pending Tag in this flush has the same text,
      the duplicate's collection is merged into that tag and the duplicate
      is expunged;
    * otherwise the tag is remembered as the first pending tag for its text.

    :param session: the Session being flushed.
    :param ctx: flush context (unused).
    :param instances: objects passed to flush(), if any (unused).
    """
    # The listener is registered at import time, but Tag only exists once
    # init_base() has run; before that there is nothing to merge.
    tag_cls = globals().get('Tag')
    if tag_cls is None:
        return

    first_by_text = {}
    for pending in session.new:
        if not isinstance(pending, tag_cls):
            continue
        existing = session.query(tag_cls).filter_by(text=pending.text).first()
        if existing is not None:
            logging.info('got t, merging')
            pending.db_id = existing.db_id
            existing.collection.merge(pending.collection)
            session.expunge(pending)
        elif pending.text in first_by_text:
            # Keep the duplicate's tagged objects instead of dropping them.
            logging.info('merging duplicate pending tag %r', pending.text)
            first_by_text[pending.text].collection.merge(pending.collection)
            session.expunge(pending)
        else:
            first_by_text[pending.text] = pending
|
|
|
|
# @sa.event.listens_for(sa.orm.Session, "before_attach")
|
|
# def intercept(session, instance):
|
|
# if not isinstance(instance, Tag):
|
|
# logging.info(f'got non-tag {instance}')
|
|
# return
|
|
#
|
|
# # i guess this only returns stuff that's persistent?
|
|
# logging.info(f'have tag {instance}')
|
|
# t = session.query(Tag).filter_by(text=instance.text).first()
|
|
# if t:
|
|
# logging.info(f'have tag {instance}')
|
|
# instance.db_id = t.db_id
|
|
# session.merge(instance)
|
|
# return
|
|
|
|
|
|
# i'm not sure if this is necessary -- its purpose is to set the id on
|
|
# our new instance to ensure it matches the old instance. if this
|
|
# isn't done, it seems like the id might change, breaking old links.
|
|
# @sa.event.listens_for(sa.orm.Session, "before_flush")
|
|
# def enforce_unique_text(session, flush_context, instances):
|
|
# texts = []
|
|
# for i in session.new:
|
|
# if isinstance(i, Tag):
|
|
# t = session.query(Tag).filter_by(text=i.text).first()
|
|
# if t:
|
|
# logging.info(f'got t, merging')
|
|
# i.db_id = t.db_id
|
|
# session.merge(t)
|
|
# elif i.text in texts:
|
|
# logging.info('have text, not sure how to proceed.')
|
|
# else:
|
|
# texts.append(i.text)
|
|
|
|
|
|
# FIXME we could probably do this more effectively if we listen for
|
|
# the object event, not session.
|
|
# @sa.event.listens_for(sa.orm.Session, "before_flush")
|
|
# def delete_before_insert(session, context, instances):
|
|
# tags = []
|
|
# logging.info('calling delete_before_insert')
|
|
# for instance in session.dirty:
|
|
# if isinstance(instance, Tag):
|
|
# logging.info('have dirty tag')
|
|
# tags.append(instance.text)
|
|
# for instance in session.new:
|
|
# if not isinstance(instance, Tag):
|
|
# logging.info('got non-tag')
|
|
# continue
|
|
# # if not session.is_modified(instance):
|
|
# # continue
|
|
#
|
|
# logging.info(f'have tag {instance}')
|
|
# t = session.query(Tag).filter_by(text=instance.text).first()
|
|
# # conn = session.connection()
|
|
# # r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" %
|
|
# # instance.text))
|
|
# # apparently this is -1 if the query matches no rows
|
|
# # if r.rowcount > 0:
|
|
# if t:
|
|
# session.expunge(instance)
|
|
# else:
|
|
# logging.info('no row')
|
|
#
|
|
# # handle the case where we have multiples of the same tag in
|
|
# # new
|
|
# logging.info(f'{tags}')
|
|
# if instance.text in tags:
|
|
# session.expunge(instance)
|
|
# else:
|
|
# tags.append(instance.text)
|
|
|