small fixes.

This commit is contained in:
chris t 2018-12-11 00:42:49 -08:00
parent d6ffe96b30
commit 9b3b9810c0
3 changed files with 95 additions and 18 deletions

View File

@ -53,7 +53,8 @@ objects, iterate the kk_tag_associations member::
return [ta.tag_obj for ta in a.kk_tag_associations] return [ta.tag_obj for ta in a.kk_tag_associations]
This will bypass the association proxy and return a list of Tag This will bypass the association proxy and return a list of Tag
objects. objects. The tag object has a useful member, "collection", which
provides direct access to the objects with this tag.
It's also possible to add kk.Tag objects directly to the tags It's also possible to add kk.Tag objects directly to the tags
property on a taggable object, and they'll be handled correctly. property on a taggable object, and they'll be handled correctly.
@ -79,3 +80,4 @@ buzzing side channel to me.
(Originally, it was "guzuguzu", which is, like, slow. This took me (Originally, it was "guzuguzu", which is, like, slow. This took me
way longer to write than I was expecting.) way longer to write than I was expecting.)

View File

@ -1,9 +1,9 @@
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import event
from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.declarative import declarative_base, declared_attr
from sqlalchemy.ext.associationproxy import association_proxy from sqlalchemy.ext.associationproxy import association_proxy
import itertools import itertools
import logging
import pprint import pprint
Base = declarative_base() Base = declarative_base()
@ -11,6 +11,9 @@ Base = declarative_base()
class Tag(Base): class Tag(Base):
def __eq__(self, other):
return self.text == other.text
class TaggedObjectCollection(object): class TaggedObjectCollection(object):
__emulates__ = list __emulates__ = list
@ -18,6 +21,7 @@ class Tag(Base):
self.data = [] self.data = []
def append(self, item): def append(self, item):
if isinstance(item, TagAssociation):
try: try:
cls = self._get_class_by_tablename(item.target_table) cls = self._get_class_by_tablename(item.target_table)
thingy = cls() thingy = cls()
@ -25,6 +29,14 @@ class Tag(Base):
self.data.append(thingy) self.data.append(thingy)
except TypeError: except TypeError:
raise raise
else:
# this is complete madness
print("got item")
assoc = TagAssociation()
assoc.target_id = item.db_id
assoc.target_table = item.__tablename__
self.append(assoc)
# self.data.append(item)
def remove(self, item): def remove(self, item):
pass pass
@ -35,9 +47,6 @@ class Tag(Base):
def __len__(self): def __len__(self):
return len(self.data) return len(self.data)
def __eq__(self, other):
return self.text == other.text
def _get_class_by_tablename(self, tablename): def _get_class_by_tablename(self, tablename):
"""Return class reference mapped to table. """Return class reference mapped to table.
@ -58,13 +67,15 @@ class Tag(Base):
# back_populates="tag_obj") # back_populates="tag_obj")
collection = sa.orm.relationship("TagAssociation", collection = sa.orm.relationship("TagAssociation",
collection_class=TaggedObjectCollection, collection_class=TaggedObjectCollection,
enable_typechecks = False, enable_typechecks=False,
# back_populates="tag_obj", # back_populates="tag_obj",
primaryjoin="TagAssociation.tag_id==Tag.db_id" primaryjoin="TagAssociation.tag_id==Tag.db_id"
) )
__tablename__ = "kk_tags" __tablename__ = "kk_tags"
# we disable batch to ensure that parents are inserted before
# children.
__mapper_args__ = { __mapper_args__ = {
"primary_key": [db_id], # "primary_key": [db_id],
"batch": False "batch": False
} }
@ -140,7 +151,7 @@ class Taggable(Base):
if isinstance(tag, Tag): if isinstance(tag, Tag):
self.tag_obj = tag self.tag_obj = tag
else: else:
# print "cls init called with text %s" % tag logging.debug("cls init called with text %s" % tag)
self.tag_obj = Tag(tag) self.tag_obj = Tag(tag)
@ -184,6 +195,6 @@ def enforce_unique_text(session, flush_context, instances):
if t: if t:
i.db_id = t.db_id i.db_id = t.db_id
event.listen(Tag, 'before_insert', delete_before_insert) sa.event.listen(Tag, 'before_insert', delete_before_insert)
event.listen(sa.orm.session.Session, 'before_flush', sa.event.listen(sa.orm.session.Session, 'before_flush',
enforce_unique_text) enforce_unique_text)

View File

@ -11,7 +11,10 @@ import unittest
import libkosokoso as kk import libkosokoso as kk
class DummyBase = sa.ext.declarative.declarative_base() # TODO get this working -- merge bases?
# DummyBase = sa.ext.declarative.declarative_base()
class DummyBase(object):
"""dummy"""
class Foo(DummyBase, kk.Taggable): class Foo(DummyBase, kk.Taggable):
__tablename__ = 'foos' __tablename__ = 'foos'
@ -96,9 +99,70 @@ class ks_basic(unittest.TestCase):
f2 = self.session.query(Foo).get(2) f2 = self.session.query(Foo).get(2)
b1 = self.session.query(Bar).get(1) b1 = self.session.query(Bar).get(1)
t = self.session.query(kk.Tag).get(1) t = self.session.query(kk.Tag).get(1)
t2 = kk.Tag('tag2')
t3 = kk.Tag('tag3')
self.session.add(t2)
self.session.add(t3)
f1.tags.append(t) f1.tags.append(t)
f2.tags.append(t) f2.tags.append(t)
b1.tags.append(t) b1.tags.append(t)
f2.tags.append(t2)
b1.tags.append(t3)
b1.tags.append(t2)
self.session.commit()
# print(pd.read_sql_query("SELECT * FROM kk_tag_associations",
# self.engine))
# print("collection:")
# for i in t.collection:
# print(i)
# print("collection:")
# for i in t.collection:
# print(i)
self.assertEqual(3, len(t.collection))
self.assertEqual(2, len(t2.collection))
self.assertEqual(1, len(t3.collection))
# self.assertEqual(3, len(t.collection))
l = list(t.collection)
# TODO do we need to test these?
ta1 = self.session.query(kk.TagAssociation).get(1)
ta2 = self.session.query(kk.TagAssociation).get(2)
self.assertEqual(l[0].db_id, ta1.db_id)
self.assertNotEqual(l[1].db_id, ta1.db_id)
self.assertEqual(l[1].db_id, ta2.db_id)
self.assertTrue(isinstance(l[0], Foo))
self.assertTrue(isinstance(l[1], Foo))
self.assertTrue(isinstance(l[2], Bar))
l2 = list(t2.collection)
self.assertTrue(isinstance(l2[0], Foo))
self.assertTrue(isinstance(l2[1], Bar))
l3 = list(t3.collection)
self.assertTrue(isinstance(l3[0], Bar))
@unittest.skip("doesn't work.")
def test_writable_collection(self):
"""Test adding objects to a tag's collection member"""
f1 = Foo()
f2 = Foo()
b1 = Bar()
t = kk.Tag('tag1')
self.session.add(f1)
self.session.add(f2)
self.session.add(b1)
self.session.add(t)
self.session.commit()
t.collection.append(f1)
t.collection.append(f2)
t.collection.append(b1)
self.session.commit() self.session.commit()
# print(pd.read_sql_query("SELECT * FROM kk_tag_associations", # print(pd.read_sql_query("SELECT * FROM kk_tag_associations",