implement base merging

This commit is contained in:
chris t 2020-07-19 15:02:58 -07:00
parent f683d07793
commit f4a7441e86
2 changed files with 40 additions and 17 deletions

View File

@ -8,10 +8,11 @@ import pprint
import sys import sys
# import traceback # import traceback
Base = declarative_base() # this will get overridden by init_base
Base = None
class Tag(Base): class TagBase(object):
def __eq__(self, other): def __eq__(self, other):
return self.text == other.text return self.text == other.text
@ -117,8 +118,7 @@ class Tag(Base):
return c return c
class TagAssociationBase(object):
class TagAssociation(Base):
__tablename__ = "kk_tag_associations" __tablename__ = "kk_tag_associations"
db_id = sa.Column(sa.Integer, primary_key=True) db_id = sa.Column(sa.Integer, primary_key=True)
tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id")) tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id"))
@ -131,7 +131,10 @@ class TagAssociation(Base):
# m = sa.orm.mapper(Tag, Tag.__table__, non_primary=True, # m = sa.orm.mapper(Tag, Tag.__table__, non_primary=True,
# primary_key=[Tag.text]) # primary_key=[Tag.text])
# tag_obj = sa.orm.relationship(m) # tag_obj = sa.orm.relationship(m)
tag_obj = sa.orm.relationship(Tag) @declared_attr
def tag_obj(self):
return sa.orm.relationship(Tag)
# tag_obj = sa.orm.relationship(Tag)
# FIXME doesn't work, may not be needed. # FIXME doesn't work, may not be needed.
# tag_obj = sa.orm.relationship("Tag", back_populates="collection") # tag_obj = sa.orm.relationship("Tag", back_populates="collection")
@ -152,7 +155,8 @@ class TagAssociation(Base):
return "tag assoc %s: %s:%s" % (self.db_id, return "tag assoc %s: %s:%s" % (self.db_id,
self.target_table, self.target_id) self.target_table, self.target_id)
class Taggable(Base):
class TaggableBase(object):
__abstract__ = True __abstract__ = True
@declared_attr @declared_attr
@ -197,6 +201,26 @@ class Taggable(Base):
) )
def init_base(new_base):
def init_cls(name, base):
current = globals()[name + 'Base']
new_ns = current.__dict__.copy()
del new_ns['__dict__']
new_ns = type(name, (new_base,), new_ns)
globals()[name] = new_ns
for c in ['Tag', 'TagAssociation', 'Taggable']:
init_cls(c, new_base)
globals()['Base'] = new_base
sa.event.listen(Tag, 'before_insert', delete_before_insert)
sa.event.listen(sa.orm.session.Session, 'before_flush',
enforce_unique_text)
return {'Tag': globals()['Tag'],
'TagAssociation': globals()['TagAssociation'],
'Taggable': globals()['Taggable']}
def delete_before_insert(mapper, conn, target): def delete_before_insert(mapper, conn, target):
# TODO figure out where exactly transactions happen. # TODO figure out where exactly transactions happen.
r = conn.execute("SELECT db_id FROM kk_tags WHERE text='%s'" % r = conn.execute("SELECT db_id FROM kk_tags WHERE text='%s'" %
@ -214,6 +238,3 @@ def enforce_unique_text(session, flush_context, instances):
if t: if t:
i.db_id = t.db_id i.db_id = t.db_id
sa.event.listen(Tag, 'before_insert', delete_before_insert)
sa.event.listen(sa.orm.session.Session, 'before_flush',
enforce_unique_text)

View File

@ -11,12 +11,11 @@ import unittest
import libkosokoso as kk import libkosokoso as kk
# TODO get this working -- merge bases? DummyBase = sa.ext.declarative.declarative_base()
# DummyBase = sa.ext.declarative.declarative_base() kk.__dict__.update(kk.init_base(DummyBase))
class DummyBase(object):
"""dummy"""
class Foo(DummyBase, kk.Taggable): # TODO make it so DummyBase doesn't have to come after the mixin.
class Foo(kk.Taggable, DummyBase):
__tablename__ = 'foos' __tablename__ = 'foos'
db_id = sa.Column(sa.Integer, primary_key=True) db_id = sa.Column(sa.Integer, primary_key=True)
def __repr__(self): def __repr__(self):
@ -25,7 +24,7 @@ class Foo(DummyBase, kk.Taggable):
def __eq__(self, other): def __eq__(self, other):
return self.db_id == other.db_id return self.db_id == other.db_id
class Bar(DummyBase, kk.Taggable): class Bar(kk.Taggable, DummyBase):
__tablename__ = 'bars' __tablename__ = 'bars'
db_id = sa.Column(sa.Integer, primary_key=True) db_id = sa.Column(sa.Integer, primary_key=True)
def __repr__(self): def __repr__(self):
@ -35,12 +34,15 @@ class ks_basic(unittest.TestCase):
def setUp(self): def setUp(self):
# self.engine = sa.create_engine('sqlite://', echo=True) # self.engine = sa.create_engine('sqlite://', echo=True)
self.engine = sa.create_engine('sqlite://', echo=False) self.engine = sa.create_engine('sqlite://', echo=False)
kk.Base.metadata.create_all(self.engine) DummyBase.metadata.create_all(self.engine)
self.session = sa.orm.Session(self.engine) self.session = sa.orm.Session(self.engine)
def tearDown(self): def tearDown(self):
del self.engine del self.engine
# def debug(self):
# pprint.pprint(kk.__dict__)
def test_add_as_object(self): def test_add_as_object(self):
a = Foo() a = Foo()
self.session.add(a) self.session.add(a)
@ -355,7 +357,7 @@ class ks_basic(unittest.TestCase):
self.assertEqual(t3, kk.Tag('tag3')) self.assertEqual(t3, kk.Tag('tag3'))
# TODO test concurrent setting the same tag from different # TODO test concurrently setting the same tag from different
# processes. # processes.
# TODO delete unused tags # TODO delete unused tags