Skip to content

Instantly share code, notes, and snippets.

@pessom
Forked from KurimuzonAkuma/storages.py
Last active January 22, 2025 04:50
Show Gist options
  • Save pessom/06d0078a2a9df0860b049621babf60ba to your computer and use it in GitHub Desktop.
Save pessom/06d0078a2a9df0860b049621babf60ba to your computer and use it in GitHub Desktop.
MySQL and SQLite storages for aiogram framework
import pickle
from asyncio import Lock
from typing import Any, Dict, Optional
import aiomysql
import aiosqlite
from aiogram.fsm.state import State
from aiogram.fsm.storage.base import BaseStorage, StateType, StorageKey
class MySQLStorage(BaseStorage):
"""
MySQL storage backend for FSM.\n
If database and table does not exist, it will be created automatically.
"""
def __init__(
self,
host: str,
user: str,
password: str,
database: str,
) -> None:
self.host = host
self.user = user
self.password = password
self.database = database
self.connection = None
self.cursor = None
self._lock = Lock()
async def __execute(self, query: str, values: tuple = None, commit: bool = False):
async with self._lock:
await self.connect()
await self.cursor.execute(query, values)
if commit:
await self.connection.commit()
async def __create_tables(self) -> None:
await self.cursor.execute(
f"CREATE DATABASE IF NOT EXISTS `{self.database}`;\n"
f"USE `{self.database}`;\n"
"CREATE TABLE IF NOT EXISTS `aiogram_fsm_states` ("
"`chat_id` BIGINT NOT NULL,"
"`user_id` BIGINT NOT NULL,"
"`state` TEXT,"
"PRIMARY KEY (`chat_id`)"
");\n"
"CREATE TABLE IF NOT EXISTS `aiogram_fsm_data` ("
"`chat_id` BIGINT NOT NULL,"
"`user_id` BIGINT NOT NULL,"
"`data` BLOB,"
"PRIMARY KEY (`chat_id`)"
");"
)
async def connect(self) -> aiomysql.Connection:
if self.connection is None:
self.connection = await aiomysql.connect(
host=self.host,
user=self.user,
password=self.password,
db=self.database,
)
self.cursor = await self.connection.cursor(aiomysql.DictCursor)
await self.__create_tables()
return self.connection
async def close(self) -> None:
if isinstance(self.connection, aiomysql.Connection):
await self.cursor.close()
self.connection.close()
self.connection = None
self.cursor = None
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
_state = state.state if isinstance(state, State) else state
if _state is None:
await self.__execute(
"DELETE FROM `aiogram_fsm_states` WHERE `chat_id` = %s AND `user_id` = %s;",
(key.chat_id, key.user_id),
commit=True,
)
return
await self.__execute(
"INSERT INTO `aiogram_fsm_states` (`chat_id`, `user_id`, `state`) VALUES (%s, %s, %s) "
"ON DUPLICATE KEY UPDATE `state` = %s;",
(key.chat_id, key.user_id, _state, _state),
commit=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
await self.__execute(
"SELECT `state` FROM `aiogram_fsm_states` WHERE `chat_id` = %s AND `user_id` = %s;",
(key.chat_id, key.user_id),
)
result = await self.cursor.fetchone()
return result["state"] if result else None
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
if not data:
await self.__execute(
"DELETE FROM `aiogram_fsm_data` WHERE `chat_id` = %s AND `user_id` = %s;",
(key.chat_id, key.user_id),
commit=True,
)
return
serialized_data = pickle.dumps(data)
await self.__execute(
"INSERT INTO `aiogram_fsm_data` (`chat_id`, `user_id`, `data`) VALUES (%s, %s, %s) "
"ON DUPLICATE KEY UPDATE `data` = %s;",
(key.chat_id, key.user_id, serialized_data, serialized_data),
commit=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
await self.__execute(
"SELECT `data` FROM `aiogram_fsm_data` WHERE `chat_id` = %s AND `user_id` = %s;",
(key.chat_id, key.user_id),
)
result = await self.cursor.fetchone()
return pickle.loads(result["data"]) if result else {}
class SQLiteStorage(BaseStorage):
"""
SQLite storage backend for FSM.\n
Tables will be created automatically.
"""
def __init__(self, path: str) -> None:
self.path = path
self.connection = None
self.cursor = None
self._lock = Lock()
async def __execute(self, query: str, values: tuple = None, commit: bool = False):
async with self._lock:
await self.connect()
self.cursor.execute(query, values)
if commit:
await self.connection.commit()
async def __create_tables(self) -> None:
await self.cursor.execute(
"CREATE TABLE IF NOT EXISTS `aiogram_fsm_states` ("
"`chat_id` BIGINT NOT NULL,"
"`user_id` BIGINT NOT NULL,"
"`state` TEXT,"
"PRIMARY KEY (`chat_id`, `user_id`)"
");\n"
"CREATE TABLE IF NOT EXISTS `aiogram_fsm_data` ("
"`chat_id` BIGINT NOT NULL,"
"`user_id` BIGINT NOT NULL,"
"`data` BLOB,"
"PRIMARY KEY (`chat_id`, `user_id`)"
");"
)
async def connect(self) -> aiosqlite.Connection:
if self.connection is None:
self.connection = await aiosqlite.connect(self.path)
self.connection.row_factory = aiosqlite.Row
self.cursor = await self.connection.cursor()
await self.__create_tables()
return self.connection
async def close(self) -> None:
if isinstance(self.connection, aiosqlite.Connection):
await self.connection.close()
await self.cursor.close()
self.connection = None
self.cursor = None
async def set_state(self, key: StorageKey, state: StateType = None) -> None:
_state = state.state if isinstance(state, State) else state
if _state is None:
await self.__execute(
"DELETE FROM `aiogram_fsm_states` WHERE `chat_id` = ? AND `user_id` = ?;",
(key.chat_id, key.user_id),
commit=True,
)
return
await self.__execute(
"INSERT INTO `aiogram_fsm_states` (`chat_id`, `user_id`, `state`) VALUES (?, ?, ?) "
"ON CONFLICT(`chat_id`, `user_id`) DO UPDATE SET `state` = ?;",
(key.chat_id, key.user_id, _state, _state),
commit=True,
)
async def get_state(self, key: StorageKey) -> Optional[str]:
await self.__execute(
"SELECT `state` FROM `aiogram_fsm_states` WHERE `chat_id` = ? AND `user_id` = ?;",
(key.chat_id, key.user_id),
)
result = await self.cursor.fetchone()
return result["state"] if result else None
async def set_data(self, key: StorageKey, data: Dict[str, Any]) -> None:
if not data:
await self.__execute(
"DELETE FROM `aiogram_fsm_data` WHERE `chat_id` = ? AND `user_id` = ?;",
(key.chat_id, key.user_id),
commit=True,
)
return
serialized_data = pickle.dumps(data)
await self.__execute(
"INSERT INTO `aiogram_fsm_data` (`chat_id`, `user_id`, `data`) VALUES (?, ?, ?) "
"ON CONFLICT(`chat_id`, `user_id`) DO UPDATE SET `data` = ?;",
(key.chat_id, key.user_id, serialized_data, serialized_data),
commit=True,
)
async def get_data(self, key: StorageKey) -> Dict[str, Any]:
await self.__execute(
"SELECT `data` FROM `aiogram_fsm_data` WHERE `chat_id` = ? AND `user_id` = ?;",
(key.chat_id, key.user_id),
)
result = await self.cursor.fetchone()
return pickle.loads(result["data"]) if result else {}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment