Interim commit: tests pass but are brittle. Added comments describing some of the known issues.

This commit is contained in:
chris t 2023-03-28 22:06:44 -07:00
parent 86becf9f1a
commit 46aae238d2
2 changed files with 206 additions and 67 deletions

View File

@ -1,7 +1,7 @@
import sqlalchemy as sa
from sqlalchemy.orm import declared_attr
from sqlalchemy.ext.associationproxy import association_proxy
from sqlalchemy import text
# from sqlalchemy import text
import itertools
import logging
@ -14,9 +14,26 @@ Base = None
class TagBase(object):
# Create a tag, optionally setting its text.
# NOTE(review): `Tag` is not defined in the visible code; presumably it is
# the concrete class that init_base() installs into this module's globals,
# so the name is only resolved when a tag is instantiated -- confirm.
def __init__(self, tag_text=None):
super(Tag, self).__init__()
self.text = tag_text
def __repr__(self):
    """Return a short debug string holding the tag's db_id and text."""
    return f"kk_tag {self.db_id}: {self.text}"
def __eq__(self, other):
    """Compare two tags by their text.

    Returns True/False when ``other`` exposes a ``text`` attribute.
    For anything else (None, ints, unrelated objects) NotImplemented is
    returned, so ``tag == other`` evaluates to False through Python's
    default fallback instead of raising AttributeError.
    """
    # NOTE(review): a class that defines __eq__ without __hash__ is
    # unhashable in Python 3 unless __hash__ is supplied elsewhere --
    # confirm the generated Tag class does not need to be hashed.
    if not hasattr(other, "text"):
        return NotImplemented
    return self.text == other.text
db_id = sa.Column(sa.Integer, primary_key=True)
text = sa.Column(sa.Unicode(255), unique=True)
__tablename__ = "kk_tags"
# we disable batch to ensure that parents are inserted before
# children.
__mapper_args__ = {
# "primary_key": [db_id],
"batch": False
}
class TaggedObjectCollection(object):
__emulates__ = list
@ -29,12 +46,15 @@ class TagBase(object):
that sqlalchemy only notices the add when a TagAssociation
is used. If a class is added directly, we generate a
TagAssociation and re-call this method with the
association, which then triggers the append event."""
association, which then triggers the append event.
hang on, shouldn't this all be handled by the
associationproxy?"""
if isinstance(item, TagAssociation):
try:
logging.debug("adding assoc %s w/ %s" % (item,
_sa_initiator))
# logging.debug("adding assoc %s w/ %s" % (item,
# _sa_initiator))
# this has to be done first.
adapter = sa.orm.collections.collection_adapter(self)
item = adapter.fire_append_event(item, _sa_initiator)
@ -59,13 +79,47 @@ class TagBase(object):
def remove(self, item):
pass
def extend(self, item):
pass
# 'extend' is apparently taken?
def merge(self, iterable):
    """Append every item of ``iterable`` not already in the collection.

    An item counts as present when some member of ``self.data`` has the
    same ``target_table`` and ``target_id``.
    """
    logging.debug('merge called')
    for candidate in iterable:
        already_present = any(
            candidate.target_table == member.target_table
            and candidate.target_id == member.target_id
            for member in self.data
        )
        if already_present:
            logging.info(f'item {candidate} already in collection')
        else:
            logging.info(f'adding {candidate}')
            self.append(candidate)
# Iterate over the members stored in self.data.
def __iter__(self):
return iter(self.data)
# Number of members stored in self.data.
def __len__(self):
return len(self.data)
# Collection equality is not implemented: the body is a stub, so every
# comparison returns None (falsy).
def __eq__(self, other):
pass
# Draft of an order-insensitive comparison, kept for reference:
# if not isinstance(other, TaggedObjectCollection):
# return False
# if len(self.data) != len(other.data):
# return False
# s1 = sorted(self.data)
# s2 = sorted(other.data)
# for i in range(0, len(s1)):
# if s1[i] != s2[i]:
# return False
# return True
def __repr__(self):
    """Render the collection for debugging, one member per line."""
    member_lines = ''.join(str(member) + '\n' for member in self.data)
    return 'TaggedObjectCollection: [' + member_lines + ']'
def _get_class_by_tablename(self, tablename):
"""Return class reference mapped to table.
@ -80,34 +134,12 @@ class TagBase(object):
# and c.__tablename__ == tablename):
# return c
db_id = sa.Column(sa.Integer, primary_key=True)
# text = sa.Column(sa.Unicode(255, convert_unicode=False),
# unique=True)
text = sa.Column(sa.Unicode(255), unique=True)
collection = sa.orm.relationship("TagAssociation",
collection_class=TaggedObjectCollection,
enable_typechecks=True,
primaryjoin="TagAssociation.tag_id==Tag.db_id"
primaryjoin="TagAssociation.tag_id==Tag.db_id",
back_populates="tag_obj"
)
__tablename__ = "kk_tags"
# we disable batch to ensure that parents are inserted before
# children.
__mapper_args__ = {
# "primary_key": [db_id],
"batch": False
}
def __init__(self, tag_text=None):
super(Tag, self).__init__()
self.text = tag_text
def __repr__(self):
return "kk_tag %s: %s" % (self.db_id, self.text)
# i'm not sure this is really what we want.
def __eq__(self, other):
return self.text == other.text
class TagAssociationBase(object):
@ -121,8 +153,10 @@ class TagAssociationBase(object):
"polymorphic_on": "target_table"
}
# at import time, Tag doesn't exist, so we use declared_attr to
# defer.
@declared_attr
def tag_obj(self):
def tag_obj(cls):
return sa.orm.relationship(Tag)
tag = association_proxy("tag_obj", "text")
@ -134,12 +168,19 @@ class TagAssociationBase(object):
return False
return self.db_id == other.db_id
@property
def parent(self):
return getattr(self, "%s_parent" % self.target_table)
def __lt__(self, other):
    """Order associations by tag text, then target table, then target id.

    Returns False when ``other`` is not a TagAssociation, so comparing
    against a foreign type never raises.
    """
    if not isinstance(other, TagAssociation):
        return False

    def sort_key(assoc):
        # Compare field by field. Gluing the parts into one string lets
        # different (text, table, id) triples collide, e.g. "ab" + "c"
        # and "a" + "bc" produce the same key. Text and id are still
        # stringified so an unset (None) value compares without raising.
        return (str(assoc.tag_obj.text), assoc.target_table,
                str(assoc.target_id))

    return sort_key(self) < sort_key(other)
# @property
# def parent(self):
# return getattr(self, "%s_parent" % self.target_table)
def __repr__(self):
return "tag assoc %s: %s:%s" % (self.db_id,
return "tag assoc %s: [%s] %s:%s" % (self.db_id, self.tag_obj.text,
self.target_table, self.target_id)
@ -153,7 +194,7 @@ class TaggableBase(object):
def cls_init(self, tag=None):
# traceback.print_stack()
logging.debug("cls_init called, type %s".format(table))
# logging.debug("cls_init called, type %s".format(table))
super(TagAssociation, self).__init__()
self.target_table = table
# is this necessary? how does this get called?
@ -161,12 +202,12 @@ class TaggableBase(object):
if isinstance(tag, Tag):
self.tag_obj = tag
else:
logging.debug("cls init called with text %s" % tag)
# logging.debug("cls init called with text %s" % tag)
self.tag_obj = Tag(tag)
assoc_cls = type(
"%sTagAssociation" % name,
f"{name}TagAssociation",
(TagAssociation, ),
dict(
__mapper_args__ = {
@ -178,7 +219,6 @@ class TaggableBase(object):
)
cls.tags = association_proxy("kk_tag_associations", "tag")
return sa.orm.relationship(
assoc_cls,
primaryjoin=(
@ -188,6 +228,7 @@ class TaggableBase(object):
)
def init_base(new_base):
"""Set up classes based on new_base."""
# based on https://stackoverflow.com/a/41927212
@ -205,33 +246,97 @@ def init_base(new_base):
globals()['Base'] = new_base
# these are here because "Tag" needs to exist before adding the
# listener.
sa.event.listen(Tag, 'before_insert', delete_before_insert)
sa.event.listen(sa.orm.session.Session, 'before_flush',
enforce_unique_text)
# need to return new types so that the caller can incorporate them
# into its view of the module.
return ret
def delete_before_insert(mapper, conn, target):
    """Delete any stored kk_tags row whose text matches ``target.text``.

    Takes the (mapper, connection, target) arguments of a mapper-level
    'before_insert' listener.
    """
    # TODO figure out where exactly transactions happen.
    # TODO can we just upsert? Or, for that matter, skip the insert?
    # Bind the tag text as a parameter: interpolating it into the SQL
    # string broke on text containing a quote and allowed SQL injection.
    params = {"tag_text": target.text}
    existing = conn.execute(
        sa.text("SELECT db_id FROM kk_tags WHERE text = :tag_text"), params)
    # Look for an actual row rather than testing the result object itself.
    if existing.first() is not None:
        conn.execute(
            sa.text("DELETE FROM kk_tags WHERE text = :tag_text"), params)
# i'm not sure if this is necessary -- its purpose is to set the id on
# our new instance to ensure it matches the old instance. if this
# isn't done, it seems like the id might change, breaking old links.
def enforce_unique_text(session, flush_context, instances):
# Session 'before_flush' hook: for each new Tag, look for an existing Tag
# with the same text and keep only one instance per text.
@sa.event.listens_for(sa.orm.Session, 'before_flush')
def merge_text_tags(session, ctx, instances):
# maps tag text -> the first new Tag seen with that text during this call
texts = {}
for i in session.new:
if isinstance(i, Tag):
t = session.query(Tag).filter_by(text=i.text).first()
if t:
# the query found a Tag with this text: copy its id onto the new
# instance, fold the new instance's associations into the found tag's
# collection, then remove the new instance from the session.
logging.info(f'got t, merging')
i.db_id = t.db_id
t.collection.merge(i.collection)
# t = session.merge(i)
session.expunge(i)
elif i.text in texts:
# NOTE(review): this duplicate is expunged without merging i.collection
# into texts[i.text].collection, so its associations look like they are
# dropped -- confirm whether that is intended.
logging.info('have text, not sure how to proceed.')
session.expunge(i)
else:
texts[i.text] = i
# @sa.event.listens_for(sa.orm.Session, "before_attach")
# def intercept(session, instance):
# if not isinstance(instance, Tag):
# logging.info(f'got non-tag {instance}')
# return
#
# # i guess this only returns stuff that's persistent?
# logging.info(f'have tag {instance}')
# t = session.query(Tag).filter_by(text=instance.text).first()
# if t:
# logging.info(f'have tag {instance}')
# instance.db_id = t.db_id
# session.merge(instance)
# return
# i'm not sure if this is necessary -- its purpose is to set the id on
# our new instance to ensure it matches the old instance. if this
# isn't done, it seems like the id might change, breaking old links.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def enforce_unique_text(session, flush_context, instances):
# texts = []
# for i in session.new:
# if isinstance(i, Tag):
# t = session.query(Tag).filter_by(text=i.text).first()
# if t:
# logging.info(f'got t, merging')
# i.db_id = t.db_id
# session.merge(t)
# elif i.text in texts:
# logging.info('have text, not sure how to proceed.')
# else:
# texts.append(i.text)
# FIXME we could probably do this more effectively if we listen for
# the object event, not session.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def delete_before_insert(session, context, instances):
# tags = []
# logging.info('calling delete_before_insert')
# for instance in session.dirty:
# if isinstance(instance, Tag):
# logging.info('have dirty tag')
# tags.append(instance.text)
# for instance in session.new:
# if not isinstance(instance, Tag):
# logging.info('got non-tag')
# continue
# # if not session.is_modified(instance):
# # continue
#
# logging.info(f'have tag {instance}')
# t = session.query(Tag).filter_by(text=instance.text).first()
# # conn = session.connection()
# # r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" %
# # instance.text))
# # apparently this is -1 if the query matches no rows
# # if r.rowcount > 0:
# if t:
# session.expunge(instance)
# else:
# logging.info('no row')
#
# # handle the case where we have multiples of the same tag in
# # new
# logging.info(f'{tags}')
# if instance.text in tags:
# session.expunge(instance)
# else:
# tags.append(instance.text)

View File

@ -13,14 +13,18 @@ import unittest
from sqlalchemy import text
# from . import kosokoso as kk
import libkosokoso as kk
# DummyBase = sa.orm.declarative_base()
DummyBase = orm.declarative_base()
class DummyBase(sa.orm.DeclarativeBase):
pass
# This is because you presumably already have a declarative base that
# you're using, and SQLAlchemy only allows relationships between
# objects that share the same base. Hence, init_base dynamically
# generates classes derived from that base.
kk.__dict__.update(kk.init_base(DummyBase))
# TODO make it so DummyBase doesn't have to come after the mixin.
# Actually, is DummyBase needed at all?
class Foo(kk.Taggable, DummyBase):
__tablename__ = 'foos'
db_id = sa.Column(sa.Integer, primary_key=True)
@ -88,9 +92,9 @@ class ks_basic(unittest.TestCase):
self.session.add(t1)
self.session.add(t2)
# self.session.commit()
self.session.commit()
# del t1, t2
del t1, t2
ts = self.session.query(kk.Tag).all()
self.assertEqual(1, len(ts))
@ -192,6 +196,32 @@ class ks_basic(unittest.TestCase):
self.assertEqual(l[0], f1)
self.assertTrue(isinstance(l[2], Bar))
# @unittest.skip
# Merging t1's collection into t2's should leave both collections with
# the same number of members (t2 starts with f2 only).
def test_merge_collection(self):
f1 = Foo()
f2 = Foo()
t1 = kk.Tag('tag1')
t2 = kk.Tag('tag2')
# all this stuff has to be in the database to get an id.
# FIXME make tagging in general work with transient objects.
self.session.add(f1)
self.session.add(f2)
self.session.add(t1)
self.session.add(t2)
self.session.commit()
t1.collection.append(f1)
t1.collection.append(f2)
t2.collection.append(f2)
t2.collection.merge(t1.collection)
# NOTE: c1 and c2 are never asserted on; the sorted() calls only check
# that the members can be ordered without raising.
c1 = sorted(t1.collection)
c2 = sorted(t2.collection)
self.assertEqual(len(t1.collection), len(t2.collection))
# FIXME doesn't work because collection equality includes
# equality of tag names, which of course doesn't exist.
# self.assertEqual(t1.collection, t2.collection)
def test_association(self):
a = Foo()
b = Foo()
@ -329,21 +359,25 @@ class ks_basic(unittest.TestCase):
actual = '\n'.join(actual)
self.assertEqual(expected, '\n%s' % actual)
# @unittest.skip
def test_addstring_repeated(self):
a1 = Foo()
a2 = Foo()
a1.tags.append("tag1")
a2.tags.append("tag1")
# implicitly adds tags.
self.session.add(a1)
# self.session.commit()
# this does an autoflush.
t = self.session.query(kk.Tag).filter_by(text="tag1").first()
self.assertEqual(t.text, "tag1")
self.assertEqual(t.db_id, 1)
self.session.add(a2)
# FIXME committing here causes the test to fail.
# self.session.commit()
self.assertEqual(a2.tags, ['tag1'])
# print(pd.read_sql_query("SELECT * FROM kk_tags",
# self.assertEqual(a2.tags, ['tag1'])
# print(pd.read_sql_query(text("SELECT * FROM kk_tags"),
# self.engine))
del a1, a2, t