Created July 18, 2021 07:19.
Save MaLiN2223/810b902df5cdd18bc5155bc4f6cb1780 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters.
import sqlite3
from argparse import ArgumentError
from sqlite3 import Connection
from sqlite3.dbapi2 import NotSupportedError
from typing import Any, List, Mapping, Optional, Sequence, Tuple, Union
# Default database file, resolved relative to the current working directory.
DB_PATH = "data_out/database.db"

# Placeholder values accepted by sqlite3: a sequence for "?" placeholders or a
# mapping for ":name" placeholders.
SqlParams = Union[Sequence[Any], Mapping[str, Any]]


class SqlConnection:
    """Context-manager wrapper around a single ``sqlite3`` connection.

    The connection is opened in ``__enter__`` (with WAL journaling enabled) and
    closed in ``__exit__``. Nothing is committed automatically: call
    :meth:`commit` before leaving the ``with`` block, because sqlite3 does not
    commit a pending transaction when the connection is closed.

    Every query method takes an optional ``params`` argument that is handed to
    sqlite3 for placeholder binding; prefer it over formatting values into SQL.
    """

    def __init__(self, override_db_path: Optional[str] = DB_PATH, *, timeout: float = 8.0):
        """Store connection settings; no connection is opened until ``__enter__``.

        :param override_db_path: path of the database file to open.
        :param timeout: seconds sqlite3 waits on a locked database before
            raising ``sqlite3.OperationalError`` (previously hard-coded to 8.0).
        :raises argparse.ArgumentError: if ``override_db_path`` is ``None``
            (exception type kept for backward compatibility with callers).
        """
        self.connection: Optional[Connection] = None
        if override_db_path is None:
            raise ArgumentError(override_db_path, "Db path should be a valid string")
        self.db_path = override_db_path
        self.timeout = timeout
        # Paging state for fetchmany(): the active cursor, its execute() result,
        # and the (query, params) pair it was started with.
        self.c: Optional[sqlite3.Cursor] = None
        self.execute_results: Optional[sqlite3.Cursor] = None
        self._fetchmany_key: Optional[Tuple[str, Any]] = None

    def __enter__(self) -> "SqlConnection":
        """Open the connection and return ``self`` for use inside the ``with`` body."""
        self.connection = self.__new_connection()
        return self

    def fetchall(self, query: str, params: SqlParams = ()) -> List[Any]:
        """Execute ``query`` and return all result rows."""
        cursor = self._cursor()
        try:
            cursor.execute(query, params)
            return cursor.fetchall()
        finally:
            cursor.close()

    def fetchall_with_column_names(
        self, query: str, params: SqlParams = ()
    ) -> Tuple[List[Any], List[str]]:
        """Execute ``query`` and return ``(rows, column_names)``.

        ``column_names`` is empty for statements that produce no result set.
        """
        cursor = self._cursor()
        try:
            cursor.execute(query, params)
            column_names = self._column_names(cursor)
            return cursor.fetchall(), column_names
        finally:
            cursor.close()

    def fetchmany(self, query: str, num_fetch: int, params: SqlParams = ()) -> List[Any]:
        """Return the next ``num_fetch`` rows of ``query``, paging across calls.

        The first call executes ``query``. Later calls with the same query and
        params continue on the same cursor, so repeated calls walk the result
        set and return ``[]`` once it is exhausted. A call with a different
        query or params abandons the old cursor and starts a fresh one.
        """
        key = (query, params)
        if self.execute_results is None or key != self._fetchmany_key:
            if self.c is not None:
                self.c.close()  # release the statement behind the abandoned result set
            self.c = self._cursor()
            self.execute_results = self.c.execute(query, params)
            self._fetchmany_key = key
        return self.c.fetchmany(num_fetch)

    def fetchone(self, query: str, params: SqlParams = ()) -> Optional[Any]:
        """Execute ``query`` and return its first row, or ``None`` if there is none."""
        cursor = self._cursor()
        try:
            cursor.execute(query, params)
            return cursor.fetchone()
        finally:
            # Closing resets the unfinished statement instead of leaving it pending.
            cursor.close()

    def fetchone_with_column_names(
        self, query: str, params: SqlParams = ()
    ) -> Tuple[Optional[Any], List[str]]:
        """Execute ``query`` and return ``(first_row_or_None, column_names)``."""
        cursor = self._cursor()
        try:
            cursor.execute(query, params)
            column_names = self._column_names(cursor)
            return cursor.fetchone(), column_names
        finally:
            cursor.close()

    def execute(self, query: str, params: SqlParams = ()) -> sqlite3.Cursor:
        """Execute ``query`` and return the still-open cursor (e.g. for ``lastrowid``)."""
        return self._cursor().execute(query, params)

    def executemany(self, query: str, data: Any) -> sqlite3.Cursor:
        """Execute ``query`` once per parameter set in ``data`` and return the cursor."""
        return self._cursor().executemany(query, data)

    def commit(self) -> None:
        """Commit the current transaction."""
        self._require_connection().commit()

    def __exit__(self, type, value, traceback) -> None:
        """Close the connection and reset fetchmany() paging state.

        Uncommitted changes are not committed. Exceptions raised inside the
        ``with`` body are not suppressed.
        """
        if self.connection is not None:
            self.connection.close()
            self.connection = None
        # Cursors from the closed connection are unusable; forget them so a later
        # re-entry starts fetchmany() from scratch.
        self.c = None
        self.execute_results = None
        self._fetchmany_key = None

    def _require_connection(self) -> Connection:
        """Return the open connection, or raise if used outside a ``with`` block."""
        if self.connection is None:
            raise sqlite3.ProgrammingError(
                "SqlConnection is not open; use it as 'with SqlConnection(...) as db:'"
            )
        return self.connection

    def _cursor(self) -> sqlite3.Cursor:
        """Create a new cursor on the open connection."""
        return self._require_connection().cursor()

    @staticmethod
    def _column_names(cursor: sqlite3.Cursor) -> List[str]:
        """Return the result-set column names of the last statement run on ``cursor``."""
        # description is None for statements that return no rows (INSERT, UPDATE, ...).
        if cursor.description is None:
            return []
        return [column[0] for column in cursor.description]

    def __new_connection(self) -> Connection:
        """Open a new connection to ``self.db_path`` and switch it to WAL journaling."""
        print("Opening connection to", self.db_path)
        conn = sqlite3.connect(self.db_path, timeout=self.timeout)
        # WAL mode: readers do not block the writer and the writer does not block readers.
        conn.execute("pragma journal_mode=wal")
        return conn
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.