From e18ec5cc5456301a4fad1f14fb4cf98052fa7a8f Mon Sep 17 00:00:00 2001 From: Florent Pigout Date: Thu, 13 Oct 2011 18:52:32 +0000 Subject: add simple db management to replace old dict way to manage story keys --- diff --git a/atoidejouer/db/__init__.py b/atoidejouer/db/__init__.py new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/atoidejouer/db/__init__.py diff --git a/atoidejouer/db/story.py b/atoidejouer/db/story.py new file mode 100644 index 0000000..032dff6 --- /dev/null +++ b/atoidejouer/db/story.py @@ -0,0 +1,171 @@ +# python import +import logging + +# get application logger +logger = logging.getLogger('atoidejouer') + +# sqlite import +import sqlite3; + +# atoidejouer import +from atoidejouer.tools import storage + + +class Key(object): + + def __init__(self, id=None, name=None, type=None, time=-1, layer=-1, media=None): + self.id, self.name, self.type, self.media = id, name, type, media + self.time = 0 if time is None else time + self.layer = 0 if layer is None else layer + + def __repr__(self): + return "name=%s|type=%s|time=%s|layer=%s|media=%s"\ + % (self.name, self.type, self.time, self.layer, self.media) + + def __cmp__(self, other): + return cmp( + (self.name, self.type, self.time, self.layer, self.media), + (other.name, other.type, other.time, other.layer, other.media) + ) + + def create(self): + return "create table story("\ + "id integer primary key autoincrement not null,"\ + "name text,"\ + "type text,"\ + "time integer,"\ + "layer integer,"\ + "media text"\ + ")" + + def insert(self): + columns = list() + values = list() + for c in ['name', 'type', 'time', 'layer', 'media']: + v = getattr(self, c) + if v and v != -1: + columns.append(c) + values.append(str(v) if c in ['time', 'layer'] else "'%s'" % v) + return "insert into story (%s) values (%s)" % ( + ",".join(columns), + ",".join(values) + ) + + def _params(self, crit, joiner=" "): + values = list() + for c in ['name', 'type', 'time', 'layer', 'media']: + v = getattr(self, c) + if v and v != -1: + v = v if v in ['time', 'layer'] else "'%s'" % v + values.append("%s=%s" % (c, v)) + return "%s %s" % (crit, joiner.join(values)) + + def where(self): + """Prepares simple where query according OO parameters. + """ + return self._params("where") + + def set(self): + """Prepares simple where query according OO parameters. + """ + return self._params("set", joiner=",") + + def select(self): + """Prepares simple select query according OO parameters. + """ + return "select * from story %s" % self.where() + + def update(self): + return "update story %s where id=%s" % (self.set(), self.id) + + def delete(self, all=False): + """Prepares simple delete query according OO parameters. + """ + q = "delete from story" + if all is True: + return q + else: + return "%s %s" % (q, self.where()) + + +class DB(object): + + class __Singleton: + + def __init__(self, config=None, name="story", obj=Key): + self.name, self.obj = name, obj + db_path = storage.get_db_path('default') + self.con = sqlite3.connect(db_path, + detect_types=sqlite3.PARSE_DECLTYPES) + self.con.row_factory = sqlite3.Row + self.__check() + + def __check(self): + cur = self.con.cursor() + cur.execute( + "select count(*) from sqlite_master where name=?", + (self.name,) + ) + # remove all first + if cur.fetchone(): + cur.execute("drop table %s" % self.name) + # create fresh db + cur.execute(self.obj().create()) + # and close + cur.close() + + def add(self, obj): + cur = self.con.cursor() + cur.execute(obj.insert()) + count = cur.rowcount + cur.close() + return count + + def _fetch(self, cur): + row = cur.fetchone() + while(row): + yield self.obj(**row) + row = cur.fetchone() + + def all(self): + cur = self.con.cursor() + cur.execute("select * from story") + for obj in self._fetch(cur): + yield obj + cur.close() + + def get(self, obj): + cur = self.con.cursor() + cur.execute(obj.select()) + for obj in self._fetch(cur): + yield obj + cur.close() + + def update(self, obj): + cur = self.con.cursor() + cur.execute(obj.update()) + rowcount = cur.rowcount + cur.close() + return rowcount + + def _del(self, obj=None, all=False): + cur = self.con.cursor() + obj = self.obj() if obj is None else obj + cur.execute(obj.delete(all=all)) + rowcount = cur.rowcount + cur.close() + return rowcount + + # singleton instance + instance = None + + def __new__(c, force=False): + """Singleton new init. + """ + # if doesn't already initialized + if not DB.instance\ + or force is True: + # create a new instance + DB.instance = DB.__Singleton() + # return the manager object + return DB.instance diff --git a/tests/test_db_story.py b/tests/test_db_story.py new file mode 100644 index 0000000..08a8e71 --- /dev/null +++ b/tests/test_db_story.py @@ -0,0 +1,71 @@ +# python import - http://docs.python.org/library/unittest.html +import os, sys, unittest + +# add lib path to current python path +sys.path.append(os.path.join(os.path.dirname(__file__), '..')) + +# atoidejouer import +from atoidejouer.db import story +from atoidejouer.tools import storage + + +class TestDBStory(unittest.TestCase): + + def setUp(self): + story.DB() + + def tearDown(self): + story.DB()._del(all=True) + + def test_all(self): + all = [r for r in story.DB().all()] + self.assertEqual(len(all), 0, + "should not have row! found: %s" % len(all)) + + def test_add(self): + # second row + key = story.Key(None, 'helo', 'image', 0, 1, 'helo.png') + story.DB().add(key) + all = [r for r in story.DB().all()] + self.assertEqual(len(all), 1, + "should have 1 row! found: %s" % len(all)) + self.assertEqual(all[0], key, "not the same row: %s" % all[0]) + # second row + key = story.Key(None, 'hola', 'image', 0, 2, 'hola.png') + story.DB().add(key) + all = [r for r in story.DB().all()] + self.assertEqual(len(all), 2, + "should have 1 row! found: %s" % len(all)) + self.assertEqual(all[1], key, "not the same row: %s" % all[1]) + + def test_get(self): + key = story.Key(None, 'helo', 'image', 0, 1, 'helo.png') + story.DB().add(key) + all = [r for r in story.DB().get(story.Key(name='helo'))] + self.assertEqual(len(all), 1, + "should have 1 row! found: %s" % len(all)) + self.assertEqual(all[0], key, "not the same row: %s" % all[0]) + + def test_update(self): + key = story.Key(None, 'helo', 'image', 0, 1, 'helo.png') + story.DB().add(key) + all = [r for r in story.DB().all()] + key = story.Key(id=all[0].id, name='hola', layer=2, media='hola.png') + story.DB().update(key) + all = [r for r in story.DB().all()] + self.assertEqual(len(all), 1, + "should have 1 row! found: %s" % len(all)) + exp_key = story.Key(None, 'hola', 'image', 0, 2, 'hola.png') + self.assertEqual(all[0], exp_key, "not the same row: %s" % all[0]) + +def suite(): + suite = unittest.TestSuite() + suite.addTest(TestDBStory('test_all')) + suite.addTest(TestDBStory('test_add')) + suite.addTest(TestDBStory('test_get')) + suite.addTest(TestDBStory('test_update')) + return suite + + +if __name__ == '__main__': + unittest.TextTestRunner(verbosity=1).run(suite()) -- cgit v0.9.1