diff --git a/kosokoso.py b/kosokoso.py index 44e9fe5..a01743f 100644 --- a/kosokoso.py +++ b/kosokoso.py @@ -1,7 +1,7 @@ import sqlalchemy as sa from sqlalchemy.orm import declared_attr from sqlalchemy.ext.associationproxy import association_proxy -from sqlalchemy import text +# from sqlalchemy import text import itertools import logging @@ -14,9 +14,26 @@ Base = None class TagBase(object): + def __init__(self, tag_text=None): + super(Tag, self).__init__() + self.text = tag_text + + def __repr__(self): + return "kk_tag %s: %s" % (self.db_id, self.text) + def __eq__(self, other): return self.text == other.text + db_id = sa.Column(sa.Integer, primary_key=True) + text = sa.Column(sa.Unicode(255), unique=True) + __tablename__ = "kk_tags" + # we disable batch to ensure that parents are inserted before + # children. + __mapper_args__ = { + # "primary_key": [db_id], + "batch": False + } + class TaggedObjectCollection(object): __emulates__ = list @@ -29,12 +46,15 @@ class TagBase(object): that sqlalchemy only notices the add when a TagAssociation is used. If a class is added directly, we generate a TagAssociation and re-call this method with the - association, which then triggers the append event.""" + association, which then triggers the append event. + + hang on, shouldn't this all be handled by the + associationproxy?""" if isinstance(item, TagAssociation): try: - logging.debug("adding assoc %s w/ %s" % (item, - _sa_initiator)) + # logging.debug("adding assoc %s w/ %s" % (item, + # _sa_initiator)) # this has to be done first. adapter = sa.orm.collections.collection_adapter(self) item = adapter.fire_append_event(item, _sa_initiator) @@ -59,13 +79,47 @@ class TagBase(object): def remove(self, item): pass - def extend(self, item): - pass + + # 'extend' is apparently taken? + def merge(self, iterable): + logging.debug('merge called') + for i in iterable: + for j in self.data: + if (i.target_table == j.target_table and + i.target_id == j.target_id): + logging.info(f'item {i} already in collection') + break + else: + logging.info(f'adding {i}') + self.append(i) + continue + def __iter__(self): return iter(self.data) def __len__(self): return len(self.data) + def __eq__(self, other): + pass + # if not isinstance(other, TaggedObjectCollection): + # return False + # if len(self.data) ne len(other.data): + # return False + + # s1 = sorted(self.data) + # s2 = sorted(other.data) + # for i in range(0, len(s1)): + # if s1[i] ne s2[i]: + # return False + # return True + + def __repr__(self): + out = 'TaggedObjectCollection: [' + for i in self.data: + out += str(i) + '\n' + out += ']' + return out + def _get_class_by_tablename(self, tablename): """Return class reference mapped to table. @@ -80,34 +134,12 @@ class TagBase(object): # and c.__tablename__ == tablename): # return c - - db_id = sa.Column(sa.Integer, primary_key=True) - # text = sa.Column(sa.Unicode(255, convert_unicode=False), - # unique=True) - text = sa.Column(sa.Unicode(255), unique=True) collection = sa.orm.relationship("TagAssociation", collection_class=TaggedObjectCollection, enable_typechecks=True, - primaryjoin="TagAssociation.tag_id==Tag.db_id" + primaryjoin="TagAssociation.tag_id==Tag.db_id", + back_populates="tag_obj" ) - __tablename__ = "kk_tags" - # we disable batch to ensure that parents are inserted before - # children. - __mapper_args__ = { - # "primary_key": [db_id], - "batch": False - } - - def __init__(self, tag_text=None): - super(Tag, self).__init__() - self.text = tag_text - - def __repr__(self): - return "kk_tag %s: %s" % (self.db_id, self.text) - - # i'm not sure this is really what we want. - def __eq__(self, other): - return self.text == other.text class TagAssociationBase(object): @@ -121,8 +153,10 @@ class TagAssociationBase(object): "polymorphic_on": "target_table" } + # at import time, Tag doesn't exist, so we use declared_attr to + # defer. @declared_attr - def tag_obj(self): + def tag_obj(cls): return sa.orm.relationship(Tag) tag = association_proxy("tag_obj", "text") @@ -134,12 +168,19 @@ class TagAssociationBase(object): return False return self.db_id == other.db_id - @property - def parent(self): - return getattr(self, "%s_parent" % self.target_table) + def __lt__(self, other): + if not isinstance(other, TagAssociation): + return False + k1 = str(self.tag_obj.text) + self.target_table + str(self.target_id) + k2 = str(other.tag_obj.text) + other.target_table + str(other.target_id) + return k1 < k2 + + # @property + # def parent(self): + # return getattr(self, "%s_parent" % self.target_table) def __repr__(self): - return "tag assoc %s: %s:%s" % (self.db_id, + return "tag assoc %s: [%s] %s:%s" % (self.db_id, self.tag_obj.text, self.target_table, self.target_id) @@ -153,7 +194,7 @@ class TaggableBase(object): def cls_init(self, tag=None): # traceback.print_stack() - logging.debug("cls_init called, type %s".format(table)) + # logging.debug("cls_init called, type %s".format(table)) super(TagAssociation, self).__init__() self.target_table = table # is this necessary? how does this get called? @@ -161,12 +202,12 @@ class TaggableBase(object): if isinstance(tag, Tag): self.tag_obj = tag else: - logging.debug("cls init called with text %s" % tag) + # logging.debug("cls init called with text %s" % tag) self.tag_obj = Tag(tag) assoc_cls = type( - "%sTagAssociation" % name, + f"{name}TagAssociation", (TagAssociation, ), dict( __mapper_args__ = { @@ -178,7 +219,6 @@ class TaggableBase(object): ) cls.tags = association_proxy("kk_tag_associations", "tag") - return sa.orm.relationship( assoc_cls, primaryjoin=( @@ -188,6 +228,7 @@ class TaggableBase(object): ) + def init_base(new_base): """Set up classes based on new_base.""" # based on https://stackoverflow.com/a/41927212 @@ -205,33 +246,97 @@ def init_base(new_base): globals()['Base'] = new_base - # these are here because "Tag" needs to exist before adding the - # listener. - sa.event.listen(Tag, 'before_insert', delete_before_insert) - sa.event.listen(sa.orm.session.Session, 'before_flush', - enforce_unique_text) - - # need to return new types so that the caller can incorporate them - # into its view of the module. return ret -def delete_before_insert(mapper, conn, target): - # TODO figure out where exactly transactions happen. - # TODO can we just upsert? Or, for that matter, skip the insert? - r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" % - target.text)) - if r: - conn.execute(text("DELETE FROM kk_tags WHERE text='%s'" % - target.text)) - -# i'm not sure if this is necessary -- its purpose is to set the id on -# our new instance to ensure it matches the old instance. if this -# isn't done, it seems like the id might change, breaking old links. -def enforce_unique_text(session, flush_context, instances): +@sa.event.listens_for(sa.orm.Session, 'before_flush') +def merge_text_tags(session, ctx, instances): + texts = {} for i in session.new: if isinstance(i, Tag): t = session.query(Tag).filter_by(text=i.text).first() if t: + logging.info(f'got t, merging') i.db_id = t.db_id + t.collection.merge(i.collection) + + # t = session.merge(i) + session.expunge(i) + elif i.text in texts: + logging.info('have text, not sure how to proceed.') + session.expunge(i) + else: + texts[i.text] = i + +# @sa.event.listens_for(sa.orm.Session, "before_attach") +# def intercept(session, instance): +# if not isinstance(instance, Tag): +# logging.info(f'got non-tag {instance}') +# return +# +# # i guess this only returns stuff that's persistent? +# logging.info(f'have tag {instance}') +# t = session.query(Tag).filter_by(text=instance.text).first() +# if t: +# logging.info(f'have tag {instance}') +# instance.db_id = t.db_id +# session.merge(instance) +# return + + +# i'm not sure if this is necessary -- its purpose is to set the id on +# our new instance to ensure it matches the old instance. if this +# isn't done, it seems like the id might change, breaking old links. +# @sa.event.listens_for(sa.orm.Session, "before_flush") +# def enforce_unique_text(session, flush_context, instances): +# texts = [] +# for i in session.new: +# if isinstance(i, Tag): +# t = session.query(Tag).filter_by(text=i.text).first() +# if t: +# logging.info(f'got t, merging') +# i.db_id = t.db_id +# session.merge(t) +# elif i.text in texts: +# logging.info('have text, not sure how to proceed.') +# else: +# texts.append(i.text) + + +# FIXME we could probably do this more effectively if we listen for +# the object event, not session. +# @sa.event.listens_for(sa.orm.Session, "before_flush") +# def delete_before_insert(session, context, instances): +# tags = [] +# logging.info('calling delete_before_insert') +# for instance in session.dirty: +# if isinstance(instance, Tag): +# logging.info('have dirty tag') +# tags.append(instance.text) +# for instance in session.new: +# if not isinstance(instance, Tag): +# logging.info('got non-tag') +# continue +# # if not session.is_modified(instance): +# # continue +# +# logging.info(f'have tag {instance}') +# t = session.query(Tag).filter_by(text=instance.text).first() +# # conn = session.connection() +# # r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" % +# # instance.text)) +# # apparently this is -1 if the query matches no rows +# # if r.rowcount > 0: +# if t: +# session.expunge(instance) +# else: +# logging.info('no row') +# +# # handle the case where we have multiples of the same tag in +# # new +# logging.info(f'{tags}') +# if instance.text in tags: +# session.expunge(instance) +# else: +# tags.append(instance.text) diff --git a/tests/test_basic.py b/tests/test_basic.py index e6a5f07..492c821 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -13,14 +13,18 @@ import unittest from sqlalchemy import text -# from . import kosokoso as kk import libkosokoso as kk -# DummyBase = sa.orm.declarative_base() -DummyBase = orm.declarative_base() +class DummyBase(sa.orm.DeclarativeBase): + pass +# This is because presumably you have a declarative_base that you're +# already using, and obviously sqalchemy only allows relationships +# between objects with the same base. Hence, init_base dynamically +# generates classes derived from the base. kk.__dict__.update(kk.init_base(DummyBase)) # TODO make it so DummyBase doesn't have to come after the mixin. +# Actually, is DummyBase needed at all? class Foo(kk.Taggable, DummyBase): __tablename__ = 'foos' db_id = sa.Column(sa.Integer, primary_key=True) @@ -88,9 +92,9 @@ class ks_basic(unittest.TestCase): self.session.add(t1) self.session.add(t2) - # self.session.commit() + self.session.commit() - # del t1, t2 + del t1, t2 ts = self.session.query(kk.Tag).all() self.assertEqual(1, len(ts)) @@ -192,6 +196,32 @@ class ks_basic(unittest.TestCase): self.assertEqual(l[0], f1) self.assertTrue(isinstance(l[2], Bar)) + # @unittest.skip + def test_merge_collection(self): + f1 = Foo() + f2 = Foo() + t1 = kk.Tag('tag1') + t2 = kk.Tag('tag2') + # all this stuff has to be in the database to get an id. + # FIXME make tagging in general work with transient objects. + self.session.add(f1) + self.session.add(f2) + self.session.add(t1) + self.session.add(t2) + self.session.commit() + + t1.collection.append(f1) + t1.collection.append(f2) + t2.collection.append(f2) + t2.collection.merge(t1.collection) + c1 = sorted(t1.collection) + c2 = sorted(t2.collection) + + self.assertEqual(len(t1.collection), len(t2.collection)) + # FIXME doesn't work because collection equality includes + # equality of tag names, which of course doesn't exist. + # self.assertEqual(t1.collection, t2.collection) + def test_association(self): a = Foo() b = Foo() @@ -329,21 +359,25 @@ class ks_basic(unittest.TestCase): actual = '\n'.join(actual) self.assertEqual(expected, '\n%s' % actual) + # @unittest.skip def test_addstring_repeated(self): a1 = Foo() a2 = Foo() a1.tags.append("tag1") a2.tags.append("tag1") + # implicitly adds tags. self.session.add(a1) # self.session.commit() + # this does an autoflush. t = self.session.query(kk.Tag).filter_by(text="tag1").first() self.assertEqual(t.text, "tag1") self.assertEqual(t.db_id, 1) self.session.add(a2) + # FIXME committing here causes the test to fail. # self.session.commit() - self.assertEqual(a2.tags, ['tag1']) - # print(pd.read_sql_query("SELECT * FROM kk_tags", + # self.assertEqual(a2.tags, ['tag1']) + # print(pd.read_sql_query(text("SELECT * FROM kk_tags"), # self.engine)) del a1, a2, t