From fcda73933284169083ff435334f43d833956b168 Mon Sep 17 00:00:00 2001 From: chris t Date: Mon, 9 Apr 2018 02:49:45 -0700 Subject: [PATCH] fix identical rows bug. --- kosokoso.py | 72 +++++++++++++++++++++++------------- tests/test_basic.py | 89 +++++++++++++++++++++++++++++++++++++-------- 2 files changed, 121 insertions(+), 40 deletions(-) diff --git a/kosokoso.py b/kosokoso.py index 0dfc25a..10a3ecf 100644 --- a/kosokoso.py +++ b/kosokoso.py @@ -1,7 +1,9 @@ import sqlalchemy as sa +from sqlalchemy import event from sqlalchemy.ext.declarative import declarative_base, declared_attr from sqlalchemy.ext.associationproxy import association_proxy +import itertools import pprint Base = declarative_base() @@ -33,6 +35,9 @@ class Tag(Base): def __len__(self): return len(self.data) + def __eq__(self, other): + return self.text == other.text + def _get_class_by_tablename(self, tablename): """Return class reference mapped to table. @@ -45,37 +50,30 @@ class Tag(Base): return c - __tablename__ = "kk_tags" db_id = sa.Column(sa.Integer, primary_key=True) text = sa.Column(sa.Unicode(255, convert_unicode=False), unique=True) # this & collection should be read-only? + # associations = sa.orm.relationship("TagAssociation", + # back_populates="tag_obj") collection = sa.orm.relationship("TagAssociation", collection_class=TaggedObjectCollection, enable_typechecks = False, # back_populates="tag_obj", primaryjoin="TagAssociation.tag_id==Tag.db_id" ) + __tablename__ = "kk_tags" + __mapper_args__ = { + "primary_key": [db_id], + "batch": False + } def __init__(self, tag_text=None): super(Tag, self).__init__() self.text = tag_text - # if tag_text: - # select_expr = sa.select([self.__table__]).where( - # self.__table__.c.text == tag_text) - # session = sa.inspect(self).session - # extant = session.execute(select_expr) - - # if extant: - # self.db_id = extant[0].db_id - # self.text = extant[0].text - # else: - # self.text = tag_text - - def enforce_unique_text(self, key, value): - pass - # i have a feeling this may require a flush. - + + def __repr__(self): + return "kk_tag %s: %s" % (self.db_id, self.text) def _get_class_by_tablename(self, tablename): """Return class reference mapped to table. @@ -96,8 +94,15 @@ class TagAssociation(Base): tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id")) target_table = sa.Column(sa.Unicode(255, convert_unicode=False)) target_id = sa.Column(sa.Integer) + __mapper_args__ = { + "polymorphic_on": "target_table" + } + + # m = sa.orm.mapper(Tag, Tag.__table__, non_primary=True, + # primary_key=[Tag.text]) + # tag_obj = sa.orm.relationship(m) + tag_obj = sa.orm.relationship(Tag) - tag_obj = sa.orm.relationship("Tag") # FIXME doesn't work, may not be needed. # tag_obj = sa.orm.relationship("Tag", back_populates="collection") tag = association_proxy("tag_obj", "text") @@ -126,14 +131,13 @@ class Taggable(Base): def cls_init(self, tag=None): super(TagAssociation, self).__init__() - if self.args and self.args.get("target_table"): - target_table = self.args["target_table"] - self.target_table = u"%s" % target_table + self.target_table = table if tag: if isinstance(tag, Tag): self.tag_obj = tag else: - self.tag = tag + # print "cls init called with text %s" % tag + self.tag_obj = Tag(tag) assoc_cls = type( @@ -145,9 +149,6 @@ class Taggable(Base): "polymorphic_on": "target_table", }, __init__ = cls_init, - args = { - "target_table": table - }, ) ) cls.tags = association_proxy("kk_tag_associations", "tag") @@ -160,3 +161,24 @@ class Taggable(Base): "foreign(TagAssociation.target_id)==%s.db_id)" %(table, name)), ) + + +def delete_before_insert(mapper, conn, target): + r = conn.execute("SELECT db_id FROM kk_tags WHERE text='%s'" % + target.text) + if r: + conn.execute("DELETE FROM kk_tags WHERE text='%s'" % target.text) + +# i'm not sure if this is necessary -- its purpose is to set the id on +# our new instance to ensure it matches the old instance. if this +# isn't done, it seems like the id might change, breaking old links. +def enforce_unique_text(session, flush_context, instances): + for i in session.new: + if isinstance(i, Tag): + t = session.query(Tag).filter_by(text=i.text).first() + if t: + i.db_id = t.db_id + +event.listen(Tag, 'before_insert', delete_before_insert) +event.listen(sa.orm.session.Session, 'before_flush', + enforce_unique_text) diff --git a/tests/test_basic.py b/tests/test_basic.py index fadd6d8..2b220b0 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -28,6 +28,9 @@ class ks_basic(unittest.TestCase): kk.Base.metadata.create_all(self.engine) self.session = sa.orm.Session(self.engine) + def tearDown(self): + del self.engine + def test_add_as_object(self): a = Foo() self.session.add(a) @@ -52,11 +55,22 @@ class ks_basic(unittest.TestCase): a_id = a.db_id t_id = t.db_id + t2_id = t2.db_id del a, t, t2 a = self.session.query(Foo).get(a_id) t = self.session.query(kk.Tag).get(t_id) + t2 = self.session.query(kk.Tag).get(t2_id) self.assertEqual(['test_tag', 'test2'], a.tags) self.assertEqual(u'test_tag', t.text) + self.assertEqual(u'test2', t2.text) + + def test_unique(self): + t1 = kk.Tag('ttt') + t2 = kk.Tag('ttt') + + self.session.add(t1) + self.session.add(t2) + self.session.commit() def test_collection(self): """Test access to the collection of objects associated with a @@ -78,8 +92,6 @@ class ks_basic(unittest.TestCase): f2 = self.session.query(Foo).get(2) b1 = self.session.query(Bar).get(1) t = self.session.query(kk.Tag).get(1) - # FIXME there's a bug where adding via text doesn't check if a - # row with the same text already exists. f1.tags.append(t) f2.tags.append(t) b1.tags.append(t) @@ -98,6 +110,7 @@ class ks_basic(unittest.TestCase): # self.assertEqual(3, len(t.collection)) l = list(t.collection) + # TODO do we need to test these? # these are the db_ids of the tag associations. # ta1 = self.session.query(kk.TagAssociation).get(1) # ta2 = self.session.query(kk.TagAssociation).get(2) @@ -109,6 +122,26 @@ class ks_basic(unittest.TestCase): self.assertTrue(isinstance(l[1], Foo)) self.assertTrue(isinstance(l[2], Bar)) + def test_association(self): + a = Foo() + b = Foo() + c = Bar() + t1 = kk.Tag("bleh") + t2 = kk.Tag("blah") + + for i in [a,b]: + i.tags.append(t1) + for i in [b,c]: + i.tags.append(t2) + for i in [a,b,c,t1,t2]: + self.session.add(i) + self.session.commit() + + l = a.kk_tag_associations[0] + self.assertEqual(l.tag_id, t1.db_id) + self.assertEqual(l.target_table, "foos") + self.assertEqual(l.target_id, a.db_id) + def test_addstring(self): a = Foo() a.tags.append('testtag') @@ -136,7 +169,7 @@ class ks_basic(unittest.TestCase): b = Bar() b.tags.append('tagtest') self.session.add(b) - self.session.commit() + # self.session.commit() c = Foo() c.tags.append('thirdtag') @@ -158,19 +191,18 @@ class ks_basic(unittest.TestCase): self.assertEqual(['testtag'], a.tags) self.assertEqual(['tagtest'], b.tags) self.assertEqual(['thirdtag'], c.tags) - b.tags.append('tagtest2') + b.tags.append('testtag') # TODO figure out autocommit behavior. # self.session.commit() del b b = self.session.query(Bar).get(b_id) - self.assertEqual(['tagtest', 'tagtest2'], b.tags) - - # TODO test class inheritance. should work in simple cases. . - # . somewhat difficult, possibly infeasible in complex cases. + self.assertEqual(['tagtest', 'testtag'], b.tags) # print(pd.read_sql_query("SELECT * FROM kk_tag_associations", # self.engine)) + # print(pd.read_sql_query("SELECT * FROM kk_tags", + # self.engine)) def test_list(self): a = Foo() @@ -189,16 +221,43 @@ class ks_basic(unittest.TestCase): a.tags.extend(l2) self.assertEqual(l, a.tags) - # def test_addstring_repeated(self): - # a1 = Foo() - # a2 = Foo() + def test_addstring_repeated(self): + a1 = Foo() + a2 = Foo() - # a1.tags.append("tag1") - # a1.tags.append("tag1") - # self.session.add(a1, a2) - # self.session.commit() + a1.tags.append("tag1") + a2.tags.append("tag1") + self.session.add(a1) + # self.session.commit() + t = self.session.query(kk.Tag).filter_by(text="tag1").first() + self.assertEqual(t.text, "tag1") + self.assertEqual(t.db_id, 1) + self.session.add(a2) + # self.session.commit() + self.assertEqual(a2.tags, ['tag1']) + # print(pd.read_sql_query("SELECT * FROM kk_tags", + # self.engine)) + del a1, a2, t + (a1, a2) = tuple(self.session.query(Foo)) + t = self.session.query(kk.Tag).first() + self.assertEqual(len(list(self.session.query(kk.Tag))), 1) + self.assertEqual(a1.tags, ['tag1']) + self.assertEqual(a2.tags, ['tag1']) + self.assertEqual(t.db_id, 1) + + # print(pd.read_sql_query("SELECT * FROM kk_tag_associations", + # self.engine)) + # print(pd.read_sql_query("SELECT * FROM kk_tags", + # self.engine)) # TODO test access to Tag objects + # TODO test concurrent setting the same tag from different + # processes. + # TODO delete unused tags + + # TODO test class inheritance. should work in simple cases. . + # . somewhat difficult, possibly infeasible in complex cases. +