commit dac0ab6263ee157c1199e206a6b3f12fcd35fd8e
Author: Cecylia Bocovich <[email protected]>
Date:   Wed Feb 19 14:36:56 2020 -0500

    Revert to using twisted for sqlite3 db
---
 README.md          |  2 --
 gettor/utils/db.py | 86 ++++++++++++++++++++++++++++--------------------------
 tests/conftests.py |  1 -
 tests/test_db.py   | 67 ------------------------------------------
 4 files changed, 45 insertions(+), 111 deletions(-)

diff --git a/README.md b/README.md
index 6ad9603..f31d113 100644
--- a/README.md
+++ b/README.md
@@ -82,7 +82,5 @@ GetTor includes PyTest unit tests. To run the tests, first install the dependenc
 
 
 ```
-$ python3 scripts/create_db -n -c -o -f tests/gettor.db
-$ python3 scripts/add_links_to_db -f tests/gettor.db
 $ pytest-3 tests/
 ```
diff --git a/gettor/utils/db.py b/gettor/utils/db.py
index 7c3853f..0ca11aa 100644
--- a/gettor/utils/db.py
+++ b/gettor/utils/db.py
@@ -9,10 +9,10 @@
 
 from __future__ import absolute_import
 
-import sqlite3
 from datetime import datetime
 
 from twisted.python import log
+from twisted.enterprise import adbapi
 
 class SQLite3(object):
        """
@@ -20,90 +20,94 @@ class SQLite3(object):
        """
        def __init__(self, dbname):
                """Constructor."""
-               self.conn = sqlite3.connect(dbname)
+               self.dbpool = adbapi.ConnectionPool(
+                       "sqlite3", dbname, check_same_thread=False
+               )
+
+       def query_callback(self, results=None):
+               """
+               Query callback
+               Log that the database query has been executed and return results
+               """
+               log.msg("Database query executed successfully.")
+               return results
+
+       def query_errback(self, error=None):
+               """
+        Query error callback
+               Logs database error
+               """
+               if error:
+                       log.msg("Database error: {}".format(error))
+               return None
 
        def new_request(self, id, command, service, platform, language, date, 
status):
                """
                Perform a new request to the database
                """
-               c = self.conn.cursor()
                query = "INSERT INTO requests VALUES(?, ?, ?, ?, ?, ?, ?)"
 
-               c.execute(query, (id, command, platform, language, service,
-                    date, status))
-               self.conn.commit()
-               return
+               return self.dbpool.runQuery(
+                       query, (id, command, platform, language, service, date, 
status)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def get_requests(self, status, command, service):
                """
                Perform a SELECT request to the database
                """
-               c = self.conn.cursor()
                query = "SELECT * FROM requests WHERE service=? AND command=? 
AND "\
                "status = ?"
 
-               c.execute(query, (service, command, status))
-
-               return c.fetchall()
+               return self.dbpool.runQuery(
+                       query, (service, command, status)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def get_num_requests(self, id, service):
                """
                Get number of requests for statistics
                """
-               c = self.conn.cursor()
-               query = "SELECT COUNT(rowid) FROM requests WHERE id=? AND "\
-               "service=?"
+               query = "SELECT COUNT(rowid) FROM requests WHERE id=? AND 
service=?"
 
-               c.execute(query, (id, service))
-               return c.fetchone()[0]
+               return self.dbpool.runQuery(
+                       query, (id, service)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def remove_request(self, id, service, date):
                """
                Removes completed request record from the database
                """
-               c = self.conn.cursor()
-               query = "DELETE FROM requests WHERE id=? AND service=? AND "\
-                "date=?"
+               query = "DELETE FROM requests WHERE id=? AND service=? AND 
date=?"
 
-               c.execute(query, (id, service, date))
-               self.conn.commit()
-               return
+               return self.dbpool.runQuery(
+                       query, (id, service, date)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def update_stats(self, command, service, platform=None, language='en'):
                """
                Update statistics to the database
                """
-               c = self.conn.cursor()
                now_str = datetime.now().strftime("%Y%m%d")
                query = "INSERT INTO stats(num_requests, platform, language, 
command, "\
-                       "service, date) VALUES (1, ?, ?, ?, ?, ?) ON "\
-                        "CONFLICT(platform, language, command, service, date) 
"\
-                        "DO UPDATE SET num_requests=num_requests+1"
+                       "service, date) VALUES (1, ?, ?, ?, ?, ?) ON 
CONFLICT(platform, "\
+                               "language, command, service, date) DO UPDATE 
SET num_requests=num_requests+1"
 
-               c.execute(query, (platform, language, command, service,
-                   now_str))
-               self.conn.commit()
-               return
+               return self.dbpool.runQuery(
+                       query, (platform, language, command, service, now_str)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def get_links(self, platform, language, status):
                """
                Get links from the database per platform
                """
-               c = self.conn.cursor()
                query = "SELECT * FROM links WHERE platform=? AND language=? 
AND status=?"
-               c.execute(query, (platform, language, status))
-
-               return c.fetchall()
+               return self.dbpool.runQuery(
+                       query, (platform, language, status)
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
 
        def get_locales(self):
                """
                Get a list of the supported tor browser binary locales
                """
-               c = self.conn.cursor()
                query = "SELECT DISTINCT language FROM links"
-               c.execute(query)
-
-               locales = []
-               for locale in c.fetchall():
-                   locales.append(locale[0])
-               return locales
+               return self.dbpool.runQuery(query
+               
).addCallback(self.query_callback).addErrback(self.query_errback)
diff --git a/tests/conftests.py b/tests/conftests.py
index d509776..cbb4d28 100644
--- a/tests/conftests.py
+++ b/tests/conftests.py
@@ -5,7 +5,6 @@ from __future__ import unicode_literals
 from gettor.utils import options
 from gettor.utils import strings
 from gettor.utils import twitter
-from gettor.utils.db import SQLite3
 from gettor.services.email.sendmail import Sendmail
 from gettor.services.twitter import twitterdm
 from gettor.parse.email import EmailParser, AddressError, DKIMError
diff --git a/tests/test_db.py b/tests/test_db.py
deleted file mode 100644
index d663d89..0000000
--- a/tests/test_db.py
+++ /dev/null
@@ -1,67 +0,0 @@
-#!/usr/bin/env python3
-import pytest
-from datetime import datetime
-from twisted.trial import unittest
-
-from . import conftests
-
-class DatabaseTests(unittest.TestCase):
-
-    # Fail any tests which take longer than 15 seconds.
-    timeout = 15
-    def setUp(self):
-        self.settings = 
conftests.options.parse_settings("en","tests/test.conf.json")
-        print(self.settings.get("dbname"))
-        self.db = conftests.SQLite3(self.settings.get("dbname"))
-
-    def tearDown(self):
-        print("tearDown()")
-
-    def add_dummy_requests(self, num):
-        now_str = datetime.now().strftime("%Y%m%d")
-        for i in (0, num):
-            self.db.new_request(
-                id='testid',
-                command='links',
-                platform='linux',
-                language='en',
-                service='email',
-                date=now_str,
-                status="ONHOLD",
-            )
-
-    def test_stored_locales(self):
-        locales = self.db.get_locales()
-        self.assertIn('en-US', locales)
-
-    def test_requests(self):
-        now_str = datetime.now().strftime("%Y%m%d")
-        self.add_dummy_requests(2)
-        num = self.db.get_num_requests("testid", "email")
-        self.assertEqual(num, 2)
-
-        requests = self.db.get_requests("ONHOLD", "links", "email")
-        for request in requests:
-            print(request)
-            self.assertEqual(request[1], "links")
-            self.assertEqual(request[4], "email")
-            self.assertEqual(request[5], now_str)
-            self.assertEqual(request[6], "ONHOLD")
-        self.assertEqual(len(requests), 2)
-
-        self.db.remove_request("testid", "email", now_str)
-        num = self.db.get_num_requests("testid", "email")
-        self.assertEqual(num, 0)
-
-    def test_links(self):
-        links = self.db.get_links("linux", "en-US", "ACTIVE")
-        self.assertEqual(len(links), 2) # Right now we have github and gitlab
-
-        for link in links:
-            self.assertEqual(link[1], "linux")
-            self.assertEqual(link[2], "en-US")
-            self.assertEqual(link[6], "ACTIVE")
-            self.assertIn(link[5], ["github", "gitlab"])
-
-if __name__ == "__main__":
-    unittest.main()



_______________________________________________
tor-commits mailing list
[email protected]
https://lists.torproject.org/cgi-bin/mailman/listinfo/tor-commits

Reply via email to