From 5ec97e7cd0b137b27a5f4c7503bff518518e0c6a Mon Sep 17 00:00:00 2001 From: chris t Date: Thu, 16 Mar 2023 03:51:46 -0700 Subject: [PATCH] tests pass with sqlalchemy 2.0 --- __init__.py | 3 ++- kosokoso.py | 34 +++++++++++++++++----------------- tests/test_basic.py | 33 ++++++++++++++++++++++++--------- 3 files changed, 43 insertions(+), 27 deletions(-) diff --git a/__init__.py b/__init__.py index 321df17..e0650f8 100644 --- a/__init__.py +++ b/__init__.py @@ -1 +1,2 @@ -from kosokoso import * +# from kosokoso import * +from kosokoso import init_base diff --git a/kosokoso.py b/kosokoso.py index e875728..44e9fe5 100644 --- a/kosokoso.py +++ b/kosokoso.py @@ -1,6 +1,7 @@ import sqlalchemy as sa -from sqlalchemy.ext.declarative import declared_attr +from sqlalchemy.orm import declared_attr from sqlalchemy.ext.associationproxy import association_proxy +from sqlalchemy import text import itertools import logging @@ -71,16 +72,19 @@ class TagBase(object): :param tablename: String with name of table. :return: Class reference or None. """ - for c in Base._decl_class_registry.values(): - if (hasattr(c, '__tablename__') - and c.__tablename__ == tablename): + for c in Base.registry._class_registry.values(): + if getattr(c, '__tablename__', None) == tablename: return c + # for c in Base._decl_class_registry.values(): + # if (hasattr(c, '__tablename__') + # and c.__tablename__ == tablename): + # return c db_id = sa.Column(sa.Integer, primary_key=True) - text = sa.Column(sa.Unicode(255, convert_unicode=False), - unique=True) - # this & collection should be read-only? + # text = sa.Column(sa.Unicode(255, convert_unicode=False), + # unique=True) + text = sa.Column(sa.Unicode(255), unique=True) collection = sa.orm.relationship("TagAssociation", collection_class=TaggedObjectCollection, enable_typechecks=True, @@ -110,22 +114,17 @@ class TagAssociationBase(object): __tablename__ = "kk_tag_associations" db_id = sa.Column(sa.Integer, primary_key=True) tag_id = sa.Column(sa.Integer, sa.ForeignKey("kk_tags.db_id")) - target_table = sa.Column(sa.Unicode(255, convert_unicode=False)) + # target_table = sa.Column(sa.Unicode(255, convert_unicode=False)) + target_table = sa.Column(sa.Unicode(255)) target_id = sa.Column(sa.Integer) __mapper_args__ = { "polymorphic_on": "target_table" } - # m = sa.orm.mapper(Tag, Tag.__table__, non_primary=True, - # primary_key=[Tag.text]) - # tag_obj = sa.orm.relationship(m) @declared_attr def tag_obj(self): return sa.orm.relationship(Tag) - # tag_obj = sa.orm.relationship(Tag) - # FIXME doesn't work, may not be needed. - # tag_obj = sa.orm.relationship("Tag", back_populates="collection") tag = association_proxy("tag_obj", "text") tracked_obj = None @@ -220,10 +219,11 @@ def init_base(new_base): def delete_before_insert(mapper, conn, target): # TODO figure out where exactly transactions happen. # TODO can we just upsert? Or, for that matter, skip the insert? - r = conn.execute("SELECT db_id FROM kk_tags WHERE text='%s'" % - target.text) + r = conn.execute(text("SELECT db_id FROM kk_tags WHERE text='%s'" % + target.text)) if r: - conn.execute("DELETE FROM kk_tags WHERE text='%s'" % target.text) + conn.execute(text("DELETE FROM kk_tags WHERE text='%s'" % + target.text)) # i'm not sure if this is necessary -- its purpose is to set the id on # our new instance to ensure it matches the old instance. if this diff --git a/tests/test_basic.py b/tests/test_basic.py index d6c64f7..e6a5f07 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -6,13 +6,18 @@ import mock import pandas as pd import pprint import sqlalchemy as sa +import sqlalchemy.orm as orm import sqlite3 import textwrap import unittest +from sqlalchemy import text + +# from . import kosokoso as kk import libkosokoso as kk -DummyBase = sa.ext.declarative.declarative_base() +# DummyBase = sa.orm.declarative_base() +DummyBase = orm.declarative_base() kk.__dict__.update(kk.init_base(DummyBase)) # TODO make it so DummyBase doesn't have to come after the mixin. @@ -83,7 +88,7 @@ class ks_basic(unittest.TestCase): self.session.add(t1) self.session.add(t2) - self.session.commit() + # self.session.commit() # del t1, t2 ts = self.session.query(kk.Tag).all() @@ -286,6 +291,7 @@ class ks_basic(unittest.TestCase): a.tags.extend(l2) self.assertEqual(l, a.tags) + # @unittest.skip def test_pretty(self): """verify that the example in the readme works.""" a = Foo() @@ -311,12 +317,13 @@ class ks_basic(unittest.TestCase): ) actual = [] - actual.append(str(pd.read_sql_query("SELECT * FROM foos", - self.engine))) - actual.append(str(pd.read_sql_query("SELECT * FROM kk_tag_associations", - self.engine))) - actual.append(str(pd.read_sql_query("SELECT * FROM kk_tags", - self.engine))) + with self.engine.begin() as conn: + actual.append(str(pd.read_sql_query(text("SELECT * FROM foos"), + conn))) + actual.append(str(pd.read_sql_query( + text("SELECT * FROM kk_tag_associations"), conn))) + actual.append(str(pd.read_sql_query(text("SELECT * FROM kk_tags"), + conn))) for i in a.tags: actual.append(i) actual = '\n'.join(actual) @@ -376,5 +383,13 @@ class kk_mocktest(unittest.TestCase): @mock.patch('libkosokoso.Tag', autospec=True) def test_tag_mock(self, tag_mock): - t1 = kk.Tag() + t1 = kk.Tag('test') self.assertIsInstance(t1, mock.NonCallableMagicMock) + + f1 = Foo() + f1.tags.append(t1) + self.assertIsInstance(f1.kk_tag_associations[0].tag_obj, + mock.NonCallableMagicMock) + + # FIXME bypasses mock + f1.tags.append('test2')