Compare commits


No commits in common. "master" and "5ec97e7cd0b137b27a5f4c7503bff518518e0c6a" have entirely different histories.

3 changed files with 69 additions and 208 deletions

View File

@@ -11,7 +11,7 @@ Usage
import libkosokoso as kk
import sqlalchemy as sa
base = sa.orm.declarative_base()
base = sa.ext.declarative.declarative_base()
kk.__dict__.update(kk.init_base(base))
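
For context, a minimal usage sketch pieced together from the test file further down; the Note model and the in-memory SQLite engine are illustrative, not part of the library:

import sqlalchemy as sa
import libkosokoso as kk

base = sa.orm.declarative_base()
kk.__dict__.update(kk.init_base(base))

# A model becomes taggable by mixing kk.Taggable in ahead of the base.
class Note(kk.Taggable, base):
    __tablename__ = 'notes'
    db_id = sa.Column(sa.Integer, primary_key=True)

engine = sa.create_engine('sqlite://')
base.metadata.create_all(engine)
session = sa.orm.Session(engine)

note = Note()
note.tags.append('tag1')   # the association proxy accepts plain strings
session.add(note)
session.commit()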

View File

@@ -1,7 +1,7 @@
import sqlalchemy as sa
from sqlalchemy.orm import declared_attr
from sqlalchemy.ext.associationproxy import association_proxy
# from sqlalchemy import text
from sqlalchemy import text
import itertools
import logging
@@ -14,26 +14,9 @@ Base = None
class TagBase(object):
    def __init__(self, tag_text=None):
        super(Tag, self).__init__()
        self.text = tag_text
    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)
    def __eq__(self, other):
        return self.text == other.text
    db_id = sa.Column(sa.Integer, primary_key=True)
    text = sa.Column(sa.Unicode(255), unique=True)
    __tablename__ = "kk_tags"
    # we disable batch to ensure that parents are inserted before
    # children.
    __mapper_args__ = {
        # "primary_key": [db_id],
        "batch": False
    }
class TaggedObjectCollection(object):
    __emulates__ = list
@@ -46,15 +29,12 @@ class TagBase(object):
        that sqlalchemy only notices the add when a TagAssociation
        is used. If a class is added directly, we generate a
        TagAssociation and re-call this method with the
        association, which then triggers the append event.
        hang on, shouldn't this all be handled by the
        associationproxy?"""
        association, which then triggers the append event."""
        if isinstance(item, TagAssociation):
            try:
                # logging.debug("adding assoc %s w/ %s" % (item,
                #               _sa_initiator))
                logging.debug("adding assoc %s w/ %s" % (item,
                              _sa_initiator))
                # this has to be done first.
                adapter = sa.orm.collections.collection_adapter(self)
                item = adapter.fire_append_event(item, _sa_initiator)
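
The docstring above describes wrapping raw objects in a TagAssociation and re-entering append so the collection event fires for the association. The wrap-and-recurse shape of that, stripped of the SQLAlchemy specifics (Wrapped and WrappingCollection are illustrative names, not the library's):

class Wrapped:
    """Stand-in for TagAssociation: just holds the raw item."""
    def __init__(self, item):
        self.item = item

class WrappingCollection:
    """Illustrative collection: only Wrapped instances are stored;
    anything else is wrapped and append is re-entered, which is the
    point at which the real code fires the SQLAlchemy append event."""
    def __init__(self):
        self.data = []

    def append(self, item):
        if not isinstance(item, Wrapped):
            # wrap the raw object and go through append again
            return self.append(Wrapped(item))
        self.data.append(item)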
@@ -79,47 +59,13 @@ class TagBase(object):
    def remove(self, item):
        pass
    # 'extend' is apparently taken?
    def merge(self, iterable):
        logging.debug('merge called')
        for i in iterable:
            for j in self.data:
                if (i.target_table == j.target_table and
                    i.target_id == j.target_id):
                    logging.info(f'item {i} already in collection')
                    break
            else:
                logging.info(f'adding {i}')
                self.append(i)
                continue
    def extend(self, item):
        pass
    def __iter__(self):
        return iter(self.data)
    def __len__(self):
        return len(self.data)
    def __eq__(self, other):
        pass
        # if not isinstance(other, TaggedObjectCollection):
        #     return False
        # if len(self.data) != len(other.data):
        #     return False
        # s1 = sorted(self.data)
        # s2 = sorted(other.data)
        # for i in range(0, len(s1)):
        #     if s1[i] != s2[i]:
        #         return False
        # return True
    def __repr__(self):
        out = 'TaggedObjectCollection: ['
        for i in self.data:
            out += str(i) + '\n'
        out += ']'
        return out
    def _get_class_by_tablename(self, tablename):
        """Return class reference mapped to table.
@@ -134,12 +80,34 @@ class TagBase(object):
        #         and c.__tablename__ == tablename):
        #     return c
    db_id = sa.Column(sa.Integer, primary_key=True)
    # text = sa.Column(sa.Unicode(255, convert_unicode=False),
    #                  unique=True)
    text = sa.Column(sa.Unicode(255), unique=True)
    collection = sa.orm.relationship("TagAssociation",
        collection_class=TaggedObjectCollection,
        enable_typechecks=True,
        primaryjoin="TagAssociation.tag_id==Tag.db_id",
        back_populates="tag_obj"
        primaryjoin="TagAssociation.tag_id==Tag.db_id"
    )
    __tablename__ = "kk_tags"
    # we disable batch to ensure that parents are inserted before
    # children.
    __mapper_args__ = {
        # "primary_key": [db_id],
        "batch": False
    }
    def __init__(self, tag_text=None):
        super(Tag, self).__init__()
        self.text = tag_text
    def __repr__(self):
        return "kk_tag %s: %s" % (self.db_id, self.text)
    # i'm not sure this is really what we want.
    def __eq__(self, other):
        return self.text == other.text
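
On the "not sure this is really what we want" note above: comparing purely by text also leaves __hash__ set to None on Python 3 once __eq__ is defined. One possible shape, sketched on a stand-in class rather than the library's Tag; whether hashing by a mutable text attribute is acceptable depends on how Tag instances are used in sets and sessions:

class TagLike:
    """Illustrative stand-in showing one __eq__/__hash__ pairing."""
    def __init__(self, text):
        self.text = text

    def __eq__(self, other):
        # only compare against objects that carry a text attribute
        if not hasattr(other, 'text'):
            return NotImplemented
        return self.text == other.text

    def __hash__(self):
        # defining __eq__ disables the default hash on Python 3 unless
        # we provide one explicitly
        return hash(self.text)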
class TagAssociationBase(object):
@@ -153,10 +121,8 @@ class TagAssociationBase(object):
        "polymorphic_on": "target_table"
    }
    # at import time, Tag doesn't exist, so we use declared_attr to
    # defer.
    @declared_attr
    def tag_obj(cls):
    def tag_obj(self):
        return sa.orm.relationship(Tag)
    tag = association_proxy("tag_obj", "text")
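
The comment above notes that declared_attr defers building the relationship until classes are actually mapped, by which time Tag exists. A generic sketch of that mixin pattern (the class names here are illustrative):

import sqlalchemy as sa
from sqlalchemy.orm import declared_attr, declarative_base

Base = declarative_base()

class HasTagMixin:
    # evaluated per mapped subclass at mapping time, not at import time,
    # so the target class only has to exist by then
    @declared_attr
    def tag_id(cls):
        return sa.Column(sa.ForeignKey('tags.db_id'))

    @declared_attr
    def tag_obj(cls):
        return sa.orm.relationship('Tag')

class Tag(Base):
    __tablename__ = 'tags'
    db_id = sa.Column(sa.Integer, primary_key=True)

class Item(HasTagMixin, Base):
    __tablename__ = 'items'
    db_id = sa.Column(sa.Integer, primary_key=True)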
@@ -168,19 +134,12 @@ class TagAssociationBase(object):
            return False
        return self.db_id == other.db_id
    def __lt__(self, other):
        if not isinstance(other, TagAssociation):
            return False
        k1 = str(self.tag_obj.text) + self.target_table + str(self.target_id)
        k2 = str(other.tag_obj.text) + other.target_table + str(other.target_id)
        return k1 < k2
    # @property
    # def parent(self):
    #     return getattr(self, "%s_parent" % self.target_table)
    @property
    def parent(self):
        return getattr(self, "%s_parent" % self.target_table)
    def __repr__(self):
        return "tag assoc %s: [%s] %s:%s" % (self.db_id, self.tag_obj.text,
        return "tag assoc %s: %s:%s" % (self.db_id,
                                        self.target_table, self.target_id)
@@ -194,7 +153,7 @@ class TaggableBase(object):
        def cls_init(self, tag=None):
            # traceback.print_stack()
            # logging.debug("cls_init called, type %s".format(table))
            logging.debug("cls_init called, type %s".format(table))
            super(TagAssociation, self).__init__()
            self.target_table = table
            # is this necessary? how does this get called?
@@ -202,12 +161,12 @@ class TaggableBase(object):
            if isinstance(tag, Tag):
                self.tag_obj = tag
            else:
                # logging.debug("cls init called with text %s" % tag)
                logging.debug("cls init called with text %s" % tag)
                self.tag_obj = Tag(tag)
        assoc_cls = type(
            f"{name}TagAssociation",
            "%sTagAssociation" % name,
            (TagAssociation, ),
            dict(
                __mapper_args__ = {
@@ -219,6 +178,7 @@ class TaggableBase(object):
        )
        cls.tags = association_proxy("kk_tag_associations", "tag")
        return sa.orm.relationship(
            assoc_cls,
            primaryjoin=(
@@ -228,7 +188,6 @@ class TaggableBase(object):
        )
def init_base(new_base):
    """Set up classes based on new_base."""
    # based on https://stackoverflow.com/a/41927212
@@ -246,97 +205,33 @@ def init_base(new_base):
    globals()['Base'] = new_base
    # these are here because "Tag" needs to exist before adding the
    # listener.
    sa.event.listen(Tag, 'before_insert', delete_before_insert)
    sa.event.listen(sa.orm.session.Session, 'before_flush',
                    enforce_unique_text)
    # need to return new types so that the caller can incorporate them
    # into its view of the module.
    return ret
@sa.event.listens_for(sa.orm.Session, 'before_flush')
def merge_text_tags(session, ctx, instances):
    texts = {}
    for i in session.new:
        if isinstance(i, Tag):
            t = session.query(Tag).filter_by(text=i.text).first()
            if t:
                logging.info(f'got t, merging')
                i.db_id = t.db_id
                t.collection.merge(i.collection)
                # t = session.merge(i)
                session.expunge(i)
            elif i.text in texts:
                logging.info('have text, not sure how to proceed.')
                session.expunge(i)
            else:
                texts[i.text] = i
# @sa.event.listens_for(sa.orm.Session, "before_attach")
# def intercept(session, instance):
#     if not isinstance(instance, Tag):
#         logging.info(f'got non-tag {instance}')
#         return
#
#     # i guess this only returns stuff that's persistent?
#     logging.info(f'have tag {instance}')
#     t = session.query(Tag).filter_by(text=instance.text).first()
#     if t:
#         logging.info(f'have tag {instance}')
#         instance.db_id = t.db_id
#         session.merge(instance)
#         return
def delete_before_insert(mapper, conn, target):
    # TODO figure out where exactly transactions happen.
    # TODO can we just upsert? Or, for that matter, skip the insert?
    r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" %
                          target.text))
    if r:
        conn.execute(text("DELETE FROM kk_tags WHERE text='%s'" %
                          target.text))
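
The TODOs above ask whether an upsert could replace the select-then-delete; either way, bound parameters avoid interpolating tag text into the SQL. A sketch of both options, assuming SQLite and SQLAlchemy 1.4+ (not what the module currently does):

from sqlalchemy import text
from sqlalchemy.dialects.sqlite import insert as sqlite_insert

def delete_before_insert_params(mapper, conn, target):
    # same select/delete flow, but with bound parameters instead of
    # string interpolation
    row = conn.execute(
        text("SELECT db_id FROM kk_tags WHERE text = :t"),
        {"t": target.text},
    ).first()
    if row is not None:
        conn.execute(text("DELETE FROM kk_tags WHERE text = :t"),
                     {"t": target.text})

# or, on SQLite, skip the dance with ON CONFLICT DO NOTHING; Tag here is
# the mapped class defined earlier in this module
def upsert_tag(conn, tag_text):
    stmt = sqlite_insert(Tag.__table__).values(text=tag_text)
    conn.execute(stmt.on_conflict_do_nothing(index_elements=["text"]))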
# i'm not sure if this is necessary -- its purpose is to set the id on
# our new instance to ensure it matches the old instance. if this
# isn't done, it seems like the id might change, breaking old links.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def enforce_unique_text(session, flush_context, instances):
#     texts = []
#     for i in session.new:
#         if isinstance(i, Tag):
#             t = session.query(Tag).filter_by(text=i.text).first()
#             if t:
#                 logging.info(f'got t, merging')
#                 i.db_id = t.db_id
#                 session.merge(t)
#             elif i.text in texts:
#                 logging.info('have text, not sure how to proceed.')
#             else:
#                 texts.append(i.text)
# FIXME we could probably do this more effectively if we listen for
# the object event, not session.
# @sa.event.listens_for(sa.orm.Session, "before_flush")
# def delete_before_insert(session, context, instances):
#     tags = []
#     logging.info('calling delete_before_insert')
#     for instance in session.dirty:
#         if isinstance(instance, Tag):
#             logging.info('have dirty tag')
#             tags.append(instance.text)
#     for instance in session.new:
#         if not isinstance(instance, Tag):
#             logging.info('got non-tag')
#             continue
#         # if not session.is_modified(instance):
#         #     continue
#
#         logging.info(f'have tag {instance}')
#         t = session.query(Tag).filter_by(text=instance.text).first()
#         # conn = session.connection()
#         # r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" %
#         #                       instance.text))
#         # apparently this is -1 if the query matches no rows
#         # if r.rowcount > 0:
#         if t:
#             session.expunge(instance)
#         else:
#             logging.info('no row')
#
#         # handle the case where we have multiples of the same tag in
#         # new
#         logging.info(f'{tags}')
#         if instance.text in tags:
#             session.expunge(instance)
#         else:
#             tags.append(instance.text)
def enforce_unique_text(session, flush_context, instances):
    for i in session.new:
        if isinstance(i, Tag):
            t = session.query(Tag).filter_by(text=i.text).first()
            if t:
                i.db_id = t.db_id

View File

@@ -13,18 +13,14 @@ import unittest
from sqlalchemy import text
# from . import kosokoso as kk
import libkosokoso as kk
class DummyBase(sa.orm.DeclarativeBase):
    pass
# This is because presumably you have a declarative_base that you're
# already using, and obviously sqlalchemy only allows relationships
# between objects with the same base. Hence, init_base dynamically
# generates classes derived from the base.
# DummyBase = sa.orm.declarative_base()
DummyBase = orm.declarative_base()
kk.__dict__.update(kk.init_base(DummyBase))
# TODO make it so DummyBase doesn't have to come after the mixin.
# Actually, is DummyBase needed at all?
class Foo(kk.Taggable, DummyBase):
    __tablename__ = 'foos'
    db_id = sa.Column(sa.Integer, primary_key=True)
@@ -92,9 +88,9 @@ class ks_basic(unittest.TestCase):
        self.session.add(t1)
        self.session.add(t2)
        self.session.commit()
        # self.session.commit()
        del t1, t2
        # del t1, t2
        ts = self.session.query(kk.Tag).all()
        self.assertEqual(1, len(ts))
@@ -196,32 +192,6 @@ class ks_basic(unittest.TestCase):
        self.assertEqual(l[0], f1)
        self.assertTrue(isinstance(l[2], Bar))
    # @unittest.skip
    def test_merge_collection(self):
        f1 = Foo()
        f2 = Foo()
        t1 = kk.Tag('tag1')
        t2 = kk.Tag('tag2')
        # all this stuff has to be in the database to get an id.
        # FIXME make tagging in general work with transient objects.
        self.session.add(f1)
        self.session.add(f2)
        self.session.add(t1)
        self.session.add(t2)
        self.session.commit()
        t1.collection.append(f1)
        t1.collection.append(f2)
        t2.collection.append(f2)
        t2.collection.merge(t1.collection)
        c1 = sorted(t1.collection)
        c2 = sorted(t2.collection)
        self.assertEqual(len(t1.collection), len(t2.collection))
        # FIXME doesn't work because collection equality includes
        # equality of tag names, which of course doesn't exist.
        # self.assertEqual(t1.collection, t2.collection)
    def test_association(self):
        a = Foo()
        b = Foo()
@@ -359,25 +329,21 @@ class ks_basic(unittest.TestCase):
        actual = '\n'.join(actual)
        self.assertEqual(expected, '\n%s' % actual)
    # @unittest.skip
    def test_addstring_repeated(self):
        a1 = Foo()
        a2 = Foo()
        a1.tags.append("tag1")
        a2.tags.append("tag1")
        # implicitly adds tags.
        self.session.add(a1)
        # self.session.commit()
        # this does an autoflush.
        t = self.session.query(kk.Tag).filter_by(text="tag1").first()
        self.assertEqual(t.text, "tag1")
        self.assertEqual(t.db_id, 1)
        self.session.add(a2)
        # FIXME committing here causes the test to fail.
        # self.session.commit()
        # self.assertEqual(a2.tags, ['tag1'])
        # print(pd.read_sql_query(text("SELECT * FROM kk_tags"),
        self.assertEqual(a2.tags, ['tag1'])
        # print(pd.read_sql_query("SELECT * FROM kk_tags",
        #                         self.engine))
        del a1, a2, t