enable writable collections
This commit is contained in:
parent
9b3b9810c0
commit
f683d07793
37
kosokoso.py
37
kosokoso.py
|
|
@ -5,6 +5,8 @@ from sqlalchemy.ext.associationproxy import association_proxy
|
||||||
import itertools
|
import itertools
|
||||||
import logging
|
import logging
|
||||||
import pprint
|
import pprint
|
||||||
|
import sys
|
||||||
|
# import traceback
|
||||||
|
|
||||||
Base = declarative_base()
|
Base = declarative_base()
|
||||||
|
|
||||||
|
|
@ -20,23 +22,39 @@ class Tag(Base):
|
||||||
def __init__(self):
|
def __init__(self):
|
||||||
self.data = []
|
self.data = []
|
||||||
|
|
||||||
def append(self, item):
|
@sa.orm.collections.collection.internally_instrumented
|
||||||
|
def append(self, item, _sa_initiator=None):
|
||||||
|
"""This method is marked internally_instrumented to ensure
|
||||||
|
that sqlalchemy only notices the add when a TagAssociation
|
||||||
|
is used. If a class is added directly, we generate a
|
||||||
|
TagAssociation and re-call this method with the
|
||||||
|
association, which then triggers the append event."""
|
||||||
|
|
||||||
if isinstance(item, TagAssociation):
|
if isinstance(item, TagAssociation):
|
||||||
try:
|
try:
|
||||||
|
logging.debug("adding assoc %s w/ %s" % (item,
|
||||||
|
_sa_initiator))
|
||||||
|
# this has to be done first.
|
||||||
|
adapter = sa.orm.collections.collection_adapter(self)
|
||||||
|
item = adapter.fire_append_event(item, _sa_initiator)
|
||||||
|
if item.tracked_obj:
|
||||||
|
self.data.append(item)
|
||||||
|
else:
|
||||||
cls = self._get_class_by_tablename(item.target_table)
|
cls = self._get_class_by_tablename(item.target_table)
|
||||||
thingy = cls()
|
thingy = cls()
|
||||||
thingy.db_id = item.target_id
|
thingy.db_id = item.target_id
|
||||||
self.data.append(thingy)
|
self.data.append(thingy)
|
||||||
|
# pprint.pprint(self.data)
|
||||||
except TypeError:
|
except TypeError:
|
||||||
raise
|
raise
|
||||||
else:
|
else:
|
||||||
# this is complete madness
|
logging.debug("adding %s" % item)
|
||||||
print("got item")
|
# this seems kind of redundant.
|
||||||
assoc = TagAssociation()
|
assoc = TagAssociation()
|
||||||
assoc.target_id = item.db_id
|
assoc.target_id = item.db_id
|
||||||
assoc.target_table = item.__tablename__
|
assoc.target_table = item.__tablename__
|
||||||
self.append(assoc)
|
assoc.tracked_obj = item
|
||||||
# self.data.append(item)
|
self.append(assoc, _sa_initiator)
|
||||||
|
|
||||||
def remove(self, item):
|
def remove(self, item):
|
||||||
pass
|
pass
|
||||||
|
|
@ -63,12 +81,9 @@ class Tag(Base):
|
||||||
text = sa.Column(sa.Unicode(255, convert_unicode=False),
|
text = sa.Column(sa.Unicode(255, convert_unicode=False),
|
||||||
unique=True)
|
unique=True)
|
||||||
# this & collection should be read-only?
|
# this & collection should be read-only?
|
||||||
# associations = sa.orm.relationship("TagAssociation",
|
|
||||||
# back_populates="tag_obj")
|
|
||||||
collection = sa.orm.relationship("TagAssociation",
|
collection = sa.orm.relationship("TagAssociation",
|
||||||
collection_class=TaggedObjectCollection,
|
collection_class=TaggedObjectCollection,
|
||||||
enable_typechecks=False,
|
enable_typechecks=True,
|
||||||
# back_populates="tag_obj",
|
|
||||||
primaryjoin="TagAssociation.tag_id==Tag.db_id"
|
primaryjoin="TagAssociation.tag_id==Tag.db_id"
|
||||||
)
|
)
|
||||||
__tablename__ = "kk_tags"
|
__tablename__ = "kk_tags"
|
||||||
|
|
@ -122,6 +137,7 @@ class TagAssociation(Base):
|
||||||
# tag_obj = sa.orm.relationship("Tag", back_populates="collection")
|
# tag_obj = sa.orm.relationship("Tag", back_populates="collection")
|
||||||
tag = association_proxy("tag_obj", "text")
|
tag = association_proxy("tag_obj", "text")
|
||||||
|
|
||||||
|
tracked_obj = None
|
||||||
|
|
||||||
def __eq__(self, other):
|
def __eq__(self, other):
|
||||||
if not isinstance(other, TagAssociation):
|
if not isinstance(other, TagAssociation):
|
||||||
|
|
@ -145,8 +161,11 @@ class Taggable(Base):
|
||||||
table = cls.__tablename__
|
table = cls.__tablename__
|
||||||
|
|
||||||
def cls_init(self, tag=None):
|
def cls_init(self, tag=None):
|
||||||
|
# traceback.print_stack()
|
||||||
|
# print 'cls_init called, type %s' % table
|
||||||
super(TagAssociation, self).__init__()
|
super(TagAssociation, self).__init__()
|
||||||
self.target_table = table
|
self.target_table = table
|
||||||
|
# is this necessary? how does this get called?
|
||||||
if tag:
|
if tag:
|
||||||
if isinstance(tag, Tag):
|
if isinstance(tag, Tag):
|
||||||
self.tag_obj = tag
|
self.tag_obj = tag
|
||||||
|
|
|
||||||
|
|
@ -22,6 +22,9 @@ class Foo(DummyBase, kk.Taggable):
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return "foo: id %s" % self.db_id
|
return "foo: id %s" % self.db_id
|
||||||
|
|
||||||
|
def __eq__(self, other):
|
||||||
|
return self.db_id == other.db_id
|
||||||
|
|
||||||
class Bar(DummyBase, kk.Taggable):
|
class Bar(DummyBase, kk.Taggable):
|
||||||
__tablename__ = 'bars'
|
__tablename__ = 'bars'
|
||||||
db_id = sa.Column(sa.Integer, primary_key=True)
|
db_id = sa.Column(sa.Integer, primary_key=True)
|
||||||
|
|
@ -125,15 +128,14 @@ class ks_basic(unittest.TestCase):
|
||||||
self.assertEqual(3, len(t.collection))
|
self.assertEqual(3, len(t.collection))
|
||||||
self.assertEqual(2, len(t2.collection))
|
self.assertEqual(2, len(t2.collection))
|
||||||
self.assertEqual(1, len(t3.collection))
|
self.assertEqual(1, len(t3.collection))
|
||||||
# self.assertEqual(3, len(t.collection))
|
|
||||||
l = list(t.collection)
|
l = list(t.collection)
|
||||||
|
|
||||||
# TODO do we need to test these?
|
# TODO do we need to test these?
|
||||||
ta1 = self.session.query(kk.TagAssociation).get(1)
|
ta1 = self.session.query(kk.TagAssociation).get(1)
|
||||||
ta2 = self.session.query(kk.TagAssociation).get(2)
|
ta2 = self.session.query(kk.TagAssociation).get(2)
|
||||||
self.assertEqual(l[0].db_id, ta1.db_id)
|
self.assertEqual(l[0].db_id, ta1.target_id)
|
||||||
self.assertNotEqual(l[1].db_id, ta1.db_id)
|
self.assertNotEqual(l[1].db_id, ta1.target_id)
|
||||||
self.assertEqual(l[1].db_id, ta2.db_id)
|
self.assertEqual(l[1].db_id, ta2.target_id)
|
||||||
|
|
||||||
self.assertTrue(isinstance(l[0], Foo))
|
self.assertTrue(isinstance(l[0], Foo))
|
||||||
self.assertTrue(isinstance(l[1], Foo))
|
self.assertTrue(isinstance(l[1], Foo))
|
||||||
|
|
@ -146,7 +148,6 @@ class ks_basic(unittest.TestCase):
|
||||||
l3 = list(t3.collection)
|
l3 = list(t3.collection)
|
||||||
self.assertTrue(isinstance(l3[0], Bar))
|
self.assertTrue(isinstance(l3[0], Bar))
|
||||||
|
|
||||||
@unittest.skip("doesn't work.")
|
|
||||||
def test_writable_collection(self):
|
def test_writable_collection(self):
|
||||||
"""Test adding objects to a tag's collection member"""
|
"""Test adding objects to a tag's collection member"""
|
||||||
f1 = Foo()
|
f1 = Foo()
|
||||||
|
|
@ -160,6 +161,7 @@ class ks_basic(unittest.TestCase):
|
||||||
self.session.add(t)
|
self.session.add(t)
|
||||||
self.session.commit()
|
self.session.commit()
|
||||||
|
|
||||||
|
# t.collection.append(f1)
|
||||||
t.collection.append(f1)
|
t.collection.append(f1)
|
||||||
t.collection.append(f2)
|
t.collection.append(f2)
|
||||||
t.collection.append(b1)
|
t.collection.append(b1)
|
||||||
|
|
@ -170,24 +172,12 @@ class ks_basic(unittest.TestCase):
|
||||||
# print("collection:")
|
# print("collection:")
|
||||||
# for i in t.collection:
|
# for i in t.collection:
|
||||||
# print(i)
|
# print(i)
|
||||||
# print("collection:")
|
|
||||||
# for i in t.collection:
|
|
||||||
# print(i)
|
|
||||||
|
|
||||||
self.assertEqual(3, len(t.collection))
|
self.assertEqual(3, len(t.collection))
|
||||||
# self.assertEqual(3, len(t.collection))
|
|
||||||
l = list(t.collection)
|
l = list(t.collection)
|
||||||
|
|
||||||
# TODO do we need to test these?
|
|
||||||
# these are the db_ids of the tag associations.
|
|
||||||
# ta1 = self.session.query(kk.TagAssociation).get(1)
|
|
||||||
# ta2 = self.session.query(kk.TagAssociation).get(2)
|
|
||||||
# self.assertEqual(l[0], ta1)
|
|
||||||
# self.assertNotEqual(l[1], ta1)
|
|
||||||
# self.assertEqual(l[1], ta2)
|
|
||||||
|
|
||||||
self.assertTrue(isinstance(l[0], Foo))
|
self.assertTrue(isinstance(l[0], Foo))
|
||||||
self.assertTrue(isinstance(l[1], Foo))
|
self.assertEqual(l[0], f1)
|
||||||
self.assertTrue(isinstance(l[2], Bar))
|
self.assertTrue(isinstance(l[2], Bar))
|
||||||
|
|
||||||
def test_association(self):
|
def test_association(self):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue