diff --git a/kosokoso.py b/kosokoso.py index ba97608..a67f254 100644 --- a/kosokoso.py +++ b/kosokoso.py @@ -5,6 +5,8 @@ from sqlalchemy.ext.associationproxy import association_proxy import itertools import logging import pprint +import sys +# import traceback Base = declarative_base() @@ -20,23 +22,39 @@ class Tag(Base): def __init__(self): self.data = [] - def append(self, item): + @sa.orm.collections.collection.internally_instrumented + def append(self, item, _sa_initiator=None): + """This method is marked internally_instrumented to ensure + that sqlalchemy only notices the add when a TagAssociation + is used. If a class is added directly, we generate a + TagAssociation and re-call this method with the + association, which then triggers the append event.""" + if isinstance(item, TagAssociation): try: - cls = self._get_class_by_tablename(item.target_table) - thingy = cls() - thingy.db_id = item.target_id - self.data.append(thingy) + logging.debug("adding assoc %s w/ %s" % (item, + _sa_initiator)) + # this has to be done first. + adapter = sa.orm.collections.collection_adapter(self) + item = adapter.fire_append_event(item, _sa_initiator) + if item.tracked_obj: + self.data.append(item) + else: + cls = self._get_class_by_tablename(item.target_table) + thingy = cls() + thingy.db_id = item.target_id + self.data.append(thingy) + # pprint.pprint(self.data) except TypeError: raise else: - # this is complete madness - print("got item") + logging.debug("adding %s" % item) + # this seems kind of redundant. assoc = TagAssociation() assoc.target_id = item.db_id assoc.target_table = item.__tablename__ - self.append(assoc) - # self.data.append(item) + assoc.tracked_obj = item + self.append(assoc, _sa_initiator) def remove(self, item): pass @@ -63,12 +81,9 @@ class Tag(Base): text = sa.Column(sa.Unicode(255, convert_unicode=False), unique=True) # this & collection should be read-only? - # associations = sa.orm.relationship("TagAssociation", - # back_populates="tag_obj") collection = sa.orm.relationship("TagAssociation", collection_class=TaggedObjectCollection, - enable_typechecks=False, - # back_populates="tag_obj", + enable_typechecks=True, primaryjoin="TagAssociation.tag_id==Tag.db_id" ) __tablename__ = "kk_tags" @@ -122,6 +137,7 @@ class TagAssociation(Base): # tag_obj = sa.orm.relationship("Tag", back_populates="collection") tag = association_proxy("tag_obj", "text") + tracked_obj = None def __eq__(self, other): if not isinstance(other, TagAssociation): @@ -145,8 +161,11 @@ class Taggable(Base): table = cls.__tablename__ def cls_init(self, tag=None): + # traceback.print_stack() + # print 'cls_init called, type %s' % table super(TagAssociation, self).__init__() self.target_table = table + # is this necessary? how does this get called? if tag: if isinstance(tag, Tag): self.tag_obj = tag diff --git a/tests/test_basic.py b/tests/test_basic.py index a0f6349..54a7d56 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -22,6 +22,9 @@ class Foo(DummyBase, kk.Taggable): def __repr__(self): return "foo: id %s" % self.db_id + def __eq__(self, other): + return self.db_id == other.db_id + class Bar(DummyBase, kk.Taggable): __tablename__ = 'bars' db_id = sa.Column(sa.Integer, primary_key=True) @@ -125,15 +128,14 @@ class ks_basic(unittest.TestCase): self.assertEqual(3, len(t.collection)) self.assertEqual(2, len(t2.collection)) self.assertEqual(1, len(t3.collection)) - # self.assertEqual(3, len(t.collection)) l = list(t.collection) # TODO do we need to test these? ta1 = self.session.query(kk.TagAssociation).get(1) ta2 = self.session.query(kk.TagAssociation).get(2) - self.assertEqual(l[0].db_id, ta1.db_id) - self.assertNotEqual(l[1].db_id, ta1.db_id) - self.assertEqual(l[1].db_id, ta2.db_id) + self.assertEqual(l[0].db_id, ta1.target_id) + self.assertNotEqual(l[1].db_id, ta1.target_id) + self.assertEqual(l[1].db_id, ta2.target_id) self.assertTrue(isinstance(l[0], Foo)) self.assertTrue(isinstance(l[1], Foo)) @@ -146,7 +148,6 @@ class ks_basic(unittest.TestCase): l3 = list(t3.collection) self.assertTrue(isinstance(l3[0], Bar)) - @unittest.skip("doesn't work.") def test_writable_collection(self): """Test adding objects to a tag's collection member""" f1 = Foo() @@ -160,6 +161,7 @@ class ks_basic(unittest.TestCase): self.session.add(t) self.session.commit() + # t.collection.append(f1) t.collection.append(f1) t.collection.append(f2) t.collection.append(b1) @@ -170,24 +172,12 @@ class ks_basic(unittest.TestCase): # print("collection:") # for i in t.collection: # print(i) - # print("collection:") - # for i in t.collection: - # print(i) self.assertEqual(3, len(t.collection)) - # self.assertEqual(3, len(t.collection)) l = list(t.collection) - # TODO do we need to test these? - # these are the db_ids of the tag associations. - # ta1 = self.session.query(kk.TagAssociation).get(1) - # ta2 = self.session.query(kk.TagAssociation).get(2) - # self.assertEqual(l[0], ta1) - # self.assertNotEqual(l[1], ta1) - # self.assertEqual(l[1], ta2) - self.assertTrue(isinstance(l[0], Foo)) - self.assertTrue(isinstance(l[1], Foo)) + self.assertEqual(l[0], f1) self.assertTrue(isinstance(l[2], Bar)) def test_association(self):