Skip to content

Instantly share code, notes, and snippets.

@Notookk
Created April 9, 2025 12:08
Show Gist options
  • Save Notookk/b1abe1f14917644d6f0ad9e55bbcfcac to your computer and use it in GitHub Desktop.
Save Notookk/b1abe1f14917644d6f0ad9e55bbcfcac to your computer and use it in GitHub Desktop.
import sqlite3
import aiosqlite
from config import DB_PATH, ALERT_CHANNEL_ID
from typing import List, Tuple, Dict, Optional, Union
import logging
from datetime import datetime
# Module-level logger, named after this module so log records can be filtered
# per component.
logger = logging.getLogger(name=__name__)
class Database:
    """Async SQLite data-access layer for the NSFW-moderation bot.

    The schema is created synchronously when an instance is constructed.
    Every coroutine then opens its own short-lived ``aiosqlite`` connection
    (foreign keys on, foreign-key checks deferred to COMMIT), so no connection
    is held between calls. Multi-statement writes go through
    :meth:`_execute_transaction`, which applies them all-or-nothing.

    Timestamp columns are returned exactly as SQLite stores them
    (``CURRENT_TIMESTAMP`` text, ``'YYYY-MM-DD HH:MM:SS'`` in UTC); no
    ``detect_types`` conversion to ``datetime`` is configured.
    """

    # Inserts a bare ``users`` row, or does nothing if it already exists, so
    # rows in tables whose ``user_id`` references ``users`` pass the deferred
    # foreign-key check at COMMIT.
    _ENSURE_USER_SQL = "INSERT OR IGNORE INTO users (user_id) VALUES (?)"

    def __init__(self, db_path: str = DB_PATH):
        """Store *db_path* and create the schema (blocking).

        Args:
            db_path: Filesystem path of the SQLite database file.

        Raises:
            sqlite3.Error: If the database cannot be opened or initialised.
        """
        self.db_path = db_path
        self._init_sync_db()

    def _init_sync_db(self) -> None:
        """Create all tables and indexes that do not exist yet (blocking).

        Also switches the file to WAL journal mode, which SQLite persists in
        the database file itself.

        Raises:
            sqlite3.Error: If the file cannot be opened or the DDL fails
                (the error is logged first).
        """
        try:
            conn = sqlite3.connect(self.db_path)
            try:
                # Enable foreign keys and WAL mode for better performance
                conn.execute("PRAGMA foreign_keys = ON")
                conn.execute("PRAGMA journal_mode = WAL")
                conn.executescript("""
                    -- Original tables (unchanged)
                    CREATE TABLE IF NOT EXISTS users (
                        user_id INTEGER PRIMARY KEY,
                        username TEXT,
                        first_name TEXT,
                        last_name TEXT,
                        started_bot BOOLEAN DEFAULT 0,
                        start_date TIMESTAMP,
                        last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        violation_count INTEGER DEFAULT 0
                    );
                    CREATE TABLE IF NOT EXISTS approved_users (
                        user_id INTEGER PRIMARY KEY,
                        date_added TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        added_by INTEGER,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
                    );
                    CREATE TABLE IF NOT EXISTS user_violations (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER NOT NULL,
                        category TEXT NOT NULL,
                        count INTEGER DEFAULT 1,
                        last_updated TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE,
                        UNIQUE(user_id, category)
                    );
                    CREATE TABLE IF NOT EXISTS groups (
                        group_id INTEGER PRIMARY KEY,
                        title TEXT,
                        created_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        is_active BOOLEAN DEFAULT 1,
                        member_count INTEGER DEFAULT 0
                    );
                    CREATE TABLE IF NOT EXISTS alerts (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER,
                        category TEXT,
                        timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        message TEXT,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
                    );
                    CREATE TABLE IF NOT EXISTS group_memberships (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER NOT NULL,
                        group_id INTEGER NOT NULL,
                        join_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        last_active TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        is_active BOOLEAN DEFAULT 1,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE,
                        FOREIGN KEY(group_id) REFERENCES groups(group_id) ON DELETE CASCADE,
                        UNIQUE(user_id, group_id)
                    );
                    CREATE TABLE IF NOT EXISTS bot_start_events (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        user_id INTEGER NOT NULL,
                        start_date TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        referral_source TEXT,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
                    );
                    -- New broadcast tables (added without modifying existing ones)
                    CREATE TABLE IF NOT EXISTS broadcast_messages (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        message TEXT NOT NULL,
                        creator_id INTEGER NOT NULL,
                        target TEXT NOT NULL CHECK(target IN ('all', 'approved', 'group')),
                        group_id INTEGER,
                        created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        completed_at TIMESTAMP,
                        status TEXT NOT NULL CHECK(status IN ('pending', 'processing', 'completed', 'failed')),
                        sent_count INTEGER DEFAULT 0,
                        failed_count INTEGER DEFAULT 0,
                        FOREIGN KEY(creator_id) REFERENCES users(user_id),
                        FOREIGN KEY(group_id) REFERENCES groups(group_id)
                    );
                    CREATE TABLE IF NOT EXISTS broadcast_deliveries (
                        id INTEGER PRIMARY KEY AUTOINCREMENT,
                        broadcast_id INTEGER NOT NULL,
                        user_id INTEGER NOT NULL,
                        status TEXT NOT NULL CHECK(status IN ('success', 'failed')),
                        error TEXT,
                        delivered_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                        FOREIGN KEY(broadcast_id) REFERENCES broadcast_messages(id) ON DELETE CASCADE,
                        FOREIGN KEY(user_id) REFERENCES users(user_id) ON DELETE CASCADE
                    );
                    -- Original indexes (unchanged)
                    CREATE INDEX IF NOT EXISTS idx_user_violations_user_id ON user_violations(user_id);
                    CREATE INDEX IF NOT EXISTS idx_alerts_user_id ON alerts(user_id);
                    CREATE INDEX IF NOT EXISTS idx_alerts_timestamp ON alerts(timestamp);
                    CREATE INDEX IF NOT EXISTS idx_group_memberships_user ON group_memberships(user_id);
                    CREATE INDEX IF NOT EXISTS idx_group_memberships_group ON group_memberships(group_id);
                    CREATE INDEX IF NOT EXISTS idx_bot_start_events_user ON bot_start_events(user_id);
                    -- New indexes for broadcast system
                    CREATE INDEX IF NOT EXISTS idx_broadcast_status ON broadcast_messages(status);
                    CREATE INDEX IF NOT EXISTS idx_broadcast_deliveries ON broadcast_deliveries(broadcast_id);
                """)
                conn.commit()
            finally:
                # ``with sqlite3.connect(...)`` only commits/rolls back; it
                # never closes the connection, so close it explicitly.
                conn.close()
        except sqlite3.Error as e:
            logger.error("Database initialization failed: %s", e)
            raise

    async def init_db(self) -> None:
        """Re-run schema creation without blocking the event loop.

        :meth:`_init_sync_db` uses the blocking :mod:`sqlite3` module, so it
        is executed on the loop's default executor thread.

        Raises:
            sqlite3.Error: Propagated from :meth:`_init_sync_db`.
        """
        import asyncio  # local import: only this coroutine needs it

        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, self._init_sync_db)

    async def _execute(self, query: str, params: tuple = (), commit: bool = False) -> List[Tuple]:
        """Run a single statement on a fresh connection and return its rows.

        Args:
            query: SQL text with ``?`` placeholders.
            params: Values bound to the placeholders.
            commit: Commit after the statement. Required for writes; without
                it the implicit transaction is discarded when the connection
                closes.

        Returns:
            All result rows as tuples (empty for statements with no output).

        Raises:
            aiosqlite.Error: Any SQLite error, after logging it with the query.
        """
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute("PRAGMA foreign_keys = ON")
                # Check FK constraints at COMMIT instead of per statement.
                await db.execute("PRAGMA defer_foreign_keys = ON")
                cursor = await db.execute(query, params)
                # Drain the cursor BEFORE committing: an ``INSERT ... RETURNING``
                # statement stays in progress until all its rows are read, and
                # committing while it is pending can fail or lose those rows.
                rows = await cursor.fetchall()
                await cursor.close()
                if commit:
                    await db.commit()
                return rows
        except aiosqlite.Error as e:
            logger.error("Database error: %s\nQuery: %s\nParams: %s", e, query, params)
            raise

    async def _execute_transaction(self, statements: List[Tuple[str, tuple]]) -> List[Tuple]:
        """Run several statements in ONE transaction and commit once.

        Either every statement takes effect or none does: on any error the
        transaction is rolled back before the exception propagates.

        Args:
            statements: ``(query, params)`` pairs, executed in order.

        Returns:
            Rows produced by the last statement (empty if it produces none or
            if *statements* is empty).

        Raises:
            aiosqlite.Error: Any SQLite error, including a deferred
                foreign-key failure at COMMIT, after logging it.
        """
        rows: List[Tuple] = []
        query, params = "", ()
        try:
            async with aiosqlite.connect(self.db_path) as db:
                await db.execute("PRAGMA foreign_keys = ON")
                await db.execute("PRAGMA defer_foreign_keys = ON")
                try:
                    for query, params in statements:
                        cursor = await db.execute(query, params)
                        # Fully drain each statement (see _execute) so none is
                        # still in progress at COMMIT.
                        rows = await cursor.fetchall()
                        await cursor.close()
                    await db.commit()
                except Exception:
                    await db.rollback()
                    raise
                return rows
        except aiosqlite.Error as e:
            logger.error(
                "Database transaction error: %s\nLast query: %s\nParams: %s",
                e, query, params,
            )
            raise

    # ----------- Core NSFW functions -----------

    async def is_approved(self, user_id: int) -> bool:
        """Return True if *user_id* is in ``approved_users``.

        The database is queried on every call; nothing is cached.
        """
        rows = await self._execute(
            "SELECT 1 FROM approved_users WHERE user_id = ? LIMIT 1",
            (user_id,)
        )
        return bool(rows)

    async def update_violations(self, user_id: int, category: str) -> None:
        """Record one violation of *category* for *user_id*, atomically.

        In a single transaction: ensures the ``users`` row exists, increments
        (or creates) the per-category counter in ``user_violations``, and
        increments ``users.violation_count``.

        Raises:
            aiosqlite.Error: If the transaction fails (nothing is applied).
        """
        try:
            await self._execute_transaction([
                (self._ENSURE_USER_SQL, (user_id,)),
                (
                    """INSERT INTO user_violations (user_id, category)
                       VALUES (?, ?)
                       ON CONFLICT(user_id, category) DO UPDATE SET
                           count = count + 1,
                           last_updated = CURRENT_TIMESTAMP""",
                    (user_id, category),
                ),
                (
                    "UPDATE users SET violation_count = violation_count + 1 WHERE user_id = ?",
                    (user_id,),
                ),
            ])
        except Exception as e:
            logger.error("Failed to update violations for user %s: %s", user_id, e)
            raise

    async def add_approved_user(self, user_id: int, added_by: Optional[int] = None) -> None:
        """Add *user_id* to the approved list, recording which admin added it.

        A user that is already approved is left unchanged. A bare ``users``
        row is created first if needed, because ``approved_users.user_id``
        references ``users`` and the deferred FK check would otherwise make
        the COMMIT fail for users the bot has never seen.
        """
        await self._execute_transaction([
            (self._ENSURE_USER_SQL, (user_id,)),
            (
                """INSERT OR IGNORE INTO approved_users (user_id, added_by)
                   VALUES (?, ?)""",
                (user_id, added_by),
            ),
        ])

    async def remove_approved_user(self, user_id: int) -> None:
        """Remove *user_id* from the approved list (no-op if absent)."""
        await self._execute(
            "DELETE FROM approved_users WHERE user_id = ?",
            (user_id,),
            commit=True
        )

    async def get_user_violations(self, user_id: int) -> List[Tuple[str, int, datetime]]:
        """Return ``(category, count, last_updated)`` rows for *user_id*.

        ``last_updated`` is SQLite timestamp text, not a ``datetime``.
        """
        return await self._execute(
            "SELECT category, count, last_updated FROM user_violations WHERE user_id = ?",
            (user_id,)
        )

    async def get_all_approved_users(self) -> List[Dict[str, Union[int, datetime]]]:
        """Return all approved users, newest first.

        Each dict has keys ``user_id``, ``date_added`` (timestamp text) and
        ``added_by`` (may be None).
        """
        rows = await self._execute(
            """SELECT user_id, date_added, added_by
               FROM approved_users
               ORDER BY date_added DESC, user_id DESC"""
        )
        return [{'user_id': r[0], 'date_added': r[1], 'added_by': r[2]} for r in rows]

    # ----------- User management -----------

    async def upsert_user(self, user_id: int, username: Optional[str] = None,
                          first_name: Optional[str] = None, last_name: Optional[str] = None) -> None:
        """Insert or update a user's profile fields and bump ``last_active``.

        On update, every name field is overwritten with the given value,
        including None.
        """
        await self._execute(
            """INSERT INTO users (user_id, username, first_name, last_name)
               VALUES (?, ?, ?, ?)
               ON CONFLICT(user_id) DO UPDATE SET
                   username = excluded.username,
                   first_name = excluded.first_name,
                   last_name = excluded.last_name,
                   last_active = CURRENT_TIMESTAMP""",
            (user_id, username, first_name, last_name),
            commit=True
        )

    async def get_user_info(self, user_id: int) -> Optional[Dict[str, Union[int, str, bool]]]:
        """Return the ``users`` row for *user_id* as a dict, or None if absent."""
        rows = await self._execute(
            """SELECT user_id, username, first_name, last_name,
                      started_bot, start_date, last_active, violation_count
               FROM users WHERE user_id = ? LIMIT 1""",
            (user_id,)
        )
        if not rows:
            return None
        row = rows[0]
        return {
            'user_id': row[0],
            'username': row[1],
            'first_name': row[2],
            'last_name': row[3],
            'started_bot': bool(row[4]),
            'start_date': row[5],
            'last_active': row[6],
            'violation_count': row[7]
        }

    # ----------- Alert system -----------

    async def log_alert(self, user_id: int, category: str, message: str) -> None:
        """Store an NSFW alert for *user_id*.

        A bare ``users`` row is created first when needed (``alerts.user_id``
        references ``users``). A None *user_id* is stored as NULL and never
        passed to the ensure-user insert, because on an INTEGER PRIMARY KEY a
        NULL would allocate a new rowid instead of being ignored.
        """
        statements: List[Tuple[str, tuple]] = []
        if user_id is not None:
            statements.append((self._ENSURE_USER_SQL, (user_id,)))
        statements.append((
            """INSERT INTO alerts (user_id, category, message)
               VALUES (?, ?, ?)""",
            (user_id, category, message),
        ))
        await self._execute_transaction(statements)

    async def get_recent_alerts(self, limit: int = 10) -> List[Dict[str, Union[int, str, datetime]]]:
        """Return up to *limit* most recent alerts, newest first.

        ``username`` is None when the user has no ``users`` row (LEFT JOIN).
        Alerts with the same second-resolution timestamp are ordered by id,
        so the result is deterministic.
        """
        rows = await self._execute(
            """SELECT a.id, a.user_id, u.username, a.category,
                      a.message, a.timestamp
               FROM alerts a
               LEFT JOIN users u ON a.user_id = u.user_id
               ORDER BY a.timestamp DESC, a.id DESC
               LIMIT ?""",
            (limit,)
        )
        return [{
            'id': r[0],
            'user_id': r[1],
            'username': r[2],
            'category': r[3],
            'message': r[4],
            'timestamp': r[5]
        } for r in rows]

    # ----------- Tracking -----------

    async def record_bot_start(self, user_id: int, referral_source: Optional[str] = None) -> None:
        """Record that *user_id* started the bot, atomically.

        Ensures the ``users`` row exists, appends a ``bot_start_events`` row,
        and sets ``started_bot``/``start_date``/``last_active`` on the user.

        Raises:
            aiosqlite.Error: If the transaction fails (nothing is applied).
        """
        try:
            await self._execute_transaction([
                (self._ENSURE_USER_SQL, (user_id,)),
                (
                    """INSERT INTO bot_start_events (user_id, referral_source)
                       VALUES (?, ?)""",
                    (user_id, referral_source),
                ),
                (
                    """UPDATE users
                       SET started_bot = 1,
                           start_date = CURRENT_TIMESTAMP,
                           last_active = CURRENT_TIMESTAMP
                       WHERE user_id = ?""",
                    (user_id,),
                ),
            ])
        except Exception as e:
            logger.error("Failed to record bot start for user %s: %s", user_id, e)
            raise

    async def record_group_join(self, user_id: int, group_id: int, group_title: str) -> None:
        """Record that *user_id* joined (or re-joined) *group_id*, atomically.

        Ensures the user and group rows exist (refreshing the group's title
        when a non-NULL one is given), marks the membership active, and
        recomputes the group's ``member_count``.

        Raises:
            aiosqlite.Error: If the transaction fails (nothing is applied).
        """
        try:
            await self._execute_transaction([
                (self._ENSURE_USER_SQL, (user_id,)),
                (
                    # Upsert rather than INSERT OR IGNORE so a renamed group
                    # does not keep its stale title forever.
                    """INSERT INTO groups (group_id, title)
                       VALUES (?, ?)
                       ON CONFLICT(group_id) DO UPDATE SET
                           title = COALESCE(excluded.title, title)""",
                    (group_id, group_title),
                ),
                (
                    """INSERT INTO group_memberships (user_id, group_id)
                       VALUES (?, ?)
                       ON CONFLICT(user_id, group_id) DO UPDATE SET
                           is_active = 1,
                           last_active = CURRENT_TIMESTAMP""",
                    (user_id, group_id),
                ),
                (
                    """UPDATE groups
                       SET member_count = (
                               SELECT COUNT(*)
                               FROM group_memberships
                               WHERE group_id = ? AND is_active = 1
                           ),
                           last_active = CURRENT_TIMESTAMP
                       WHERE group_id = ?""",
                    (group_id, group_id),
                ),
            ])
        except Exception as e:
            logger.error("Failed to record group join for user %s: %s", user_id, e)
            raise

    async def record_group_leave(self, user_id: int, group_id: int) -> None:
        """Mark *user_id*'s membership of *group_id* inactive, atomically.

        Also recomputes the group's ``member_count``. Unknown memberships are
        a no-op.

        Raises:
            aiosqlite.Error: If the transaction fails (nothing is applied).
        """
        try:
            await self._execute_transaction([
                (
                    """UPDATE group_memberships
                       SET is_active = 0,
                           last_active = CURRENT_TIMESTAMP
                       WHERE user_id = ? AND group_id = ?""",
                    (user_id, group_id),
                ),
                (
                    """UPDATE groups
                       SET member_count = (
                               SELECT COUNT(*)
                               FROM group_memberships
                               WHERE group_id = ? AND is_active = 1
                           )
                       WHERE group_id = ?""",
                    (group_id, group_id),
                ),
            ])
        except Exception as e:
            logger.error("Failed to record group leave for user %s: %s", user_id, e)
            raise

    async def get_user_groups(self, user_id: int) -> List[Dict[str, Union[int, str, datetime]]]:
        """Return the groups *user_id* is an active member of, most recently active first."""
        rows = await self._execute(
            """SELECT g.group_id, g.title, gm.join_date, gm.last_active
               FROM group_memberships gm
               JOIN groups g ON gm.group_id = g.group_id
               WHERE gm.user_id = ? AND gm.is_active = 1
               ORDER BY gm.last_active DESC""",
            (user_id,)
        )
        return [{
            'group_id': r[0],
            'title': r[1],
            'join_date': r[2],
            'last_active': r[3]
        } for r in rows]

    async def get_group_members(self, group_id: int) -> List[Dict[str, Union[int, str, datetime]]]:
        """Return the active members of *group_id*, most recently active first."""
        rows = await self._execute(
            """SELECT u.user_id, u.username, u.first_name, u.last_name,
                      gm.join_date, gm.last_active
               FROM group_memberships gm
               JOIN users u ON gm.user_id = u.user_id
               WHERE gm.group_id = ? AND gm.is_active = 1
               ORDER BY gm.last_active DESC""",
            (group_id,)
        )
        return [{
            'user_id': r[0],
            'username': r[1],
            'first_name': r[2],
            'last_name': r[3],
            'join_date': r[4],
            'last_active': r[5]
        } for r in rows]

    async def get_user_activity(self, user_id: int) -> Dict[str, Union[int, List[Dict]]]:
        """Return a combined activity report for *user_id*.

        Keys: ``user_info``, ``start_events``, ``groups``, ``violations``.
        Returns an empty dict when the user has no ``users`` row.
        """
        user_info = await self.get_user_info(user_id)
        if not user_info:
            return {}
        start_events = await self._execute(
            "SELECT start_date, referral_source FROM bot_start_events WHERE user_id = ?",
            (user_id,)
        )
        groups = await self.get_user_groups(user_id)
        violations = await self.get_user_violations(user_id)
        return {
            'user_info': user_info,
            'start_events': [{
                'start_date': e[0],
                'referral_source': e[1]
            } for e in start_events],
            'groups': groups,
            'violations': [{
                'category': v[0],
                'count': v[1],
                'last_updated': v[2]
            } for v in violations]
        }

    # ----------- Maintenance -----------

    async def backup_database(self, backup_path: str) -> bool:
        """Copy the database into *backup_path* using SQLite's backup API.

        Returns:
            True on success, False if an SQLite error occurred (logged).
        """
        try:
            async with aiosqlite.connect(self.db_path) as source:
                async with aiosqlite.connect(backup_path) as target:
                    await source.backup(target)
            return True
        except aiosqlite.Error as e:
            logger.error("Backup failed: %s", e)
            return False

    # ----------- Broadcast system -----------

    async def add_broadcast_message(self, message: str, creator_id: int,
                                    target: str = 'all', group_id: Optional[int] = None) -> Optional[int]:
        """Queue a broadcast in ``pending`` state and return its id.

        A bare ``users`` row is created for *creator_id* if needed, because
        ``broadcast_messages.creator_id`` references ``users``. Requires
        SQLite 3.35+ for ``RETURNING``.

        Args:
            message: Text to broadcast.
            creator_id: User id of the admin creating the broadcast.
            target: One of ``'all'``, ``'approved'``, ``'group'`` (enforced
                by a CHECK constraint).
            group_id: Target group when *target* is ``'group'``.

        Returns:
            The new row id, or None if no row was returned.
        """
        rows = await self._execute_transaction([
            (self._ENSURE_USER_SQL, (creator_id,)),
            (
                """INSERT INTO broadcast_messages
                   (message, creator_id, target, group_id, created_at, status)
                   VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP, 'pending')
                   RETURNING id""",
                (message, creator_id, target, group_id),
            ),
        ])
        return rows[0][0] if rows else None

    async def get_pending_broadcasts(self, limit: int = 10) -> List[Dict]:
        """Return up to *limit* pending broadcasts, oldest first (ties by id)."""
        rows = await self._execute(
            """SELECT id, message, creator_id, target, group_id, created_at
               FROM broadcast_messages
               WHERE status = 'pending'
               ORDER BY created_at ASC, id ASC
               LIMIT ?""",
            (limit,)
        )
        return [{
            'id': r[0],
            'message': r[1],
            'creator_id': r[2],
            'target': r[3],
            'group_id': r[4],
            'created_at': r[5]
        } for r in rows]

    async def update_broadcast_status(self, broadcast_id: int, status: str,
                                      sent_count: int = 0, failed_count: int = 0) -> None:
        """Overwrite a broadcast's status and counters.

        ``completed_at`` is set to the current time when *status* is
        ``'completed'`` and reset to NULL for any other status.
        """
        await self._execute(
            """UPDATE broadcast_messages
               SET status = ?,
                   sent_count = ?,
                   failed_count = ?,
                   completed_at = CASE WHEN ? = 'completed' THEN CURRENT_TIMESTAMP ELSE NULL END
               WHERE id = ?""",
            (status, sent_count, failed_count, status, broadcast_id),
            commit=True
        )

    async def get_recipients_for_broadcast(self, target: str, group_id: Optional[int] = None) -> List[int]:
        """Return recipient user ids for a broadcast *target*.

        ``'all'`` selects users with ``started_bot = 1``; ``'approved'`` all
        approved users; ``'group'`` the active members of *group_id*. Any
        other target, or ``'group'`` without a *group_id*, yields ``[]``.
        """
        if target == 'all':
            rows = await self._execute(
                "SELECT user_id FROM users WHERE started_bot = 1"
            )
        elif target == 'approved':
            rows = await self._execute(
                "SELECT user_id FROM approved_users"
            )
        elif target == 'group' and group_id is not None:
            rows = await self._execute(
                """SELECT user_id FROM group_memberships
                   WHERE group_id = ? AND is_active = 1""",
                (group_id,)
            )
        else:
            return []
        return [r[0] for r in rows]

    async def log_broadcast_delivery(self, broadcast_id: int, user_id: int,
                                     status: str, error: Optional[str] = None) -> None:
        """Store one delivery attempt (*status* is ``'success'`` or ``'failed'``)."""
        await self._execute(
            """INSERT INTO broadcast_deliveries
               (broadcast_id, user_id, status, error, delivered_at)
               VALUES (?, ?, ?, ?, CURRENT_TIMESTAMP)""",
            (broadcast_id, user_id, status, error),
            commit=True
        )

    async def get_broadcast_stats(self, broadcast_id: int) -> Dict:
        """Return a broadcast's status, counters and per-status delivery counts.

        ``deliveries`` maps each delivery status to its row count. Returns an
        empty dict if *broadcast_id* does not exist.
        """
        rows = await self._execute(
            """SELECT status, sent_count, failed_count, created_at, completed_at
               FROM broadcast_messages
               WHERE id = ?""",
            (broadcast_id,)
        )
        if not rows:
            return {}
        deliveries = await self._execute(
            """SELECT status, COUNT(*)
               FROM broadcast_deliveries
               WHERE broadcast_id = ?
               GROUP BY status""",
            (broadcast_id,)
        )
        row = rows[0]
        return {
            'status': row[0],
            'sent_count': row[1],
            'failed_count': row[2],
            'created_at': row[3],
            'completed_at': row[4],
            'deliveries': {d[0]: d[1] for d in deliveries}
        }
# Shared instance, created at import time (this also creates the schema).
db = Database()

# Legacy function-style API: module-level names bound to the shared
# instance's methods so older call sites keep working unchanged.
(
    is_approved,
    update_violations,
    add_approved_user,
    remove_approved_user,
    get_user_violations,
    get_all_users,
) = (
    db.is_approved,
    db.update_violations,
    db.add_approved_user,
    db.remove_approved_user,
    db.get_user_violations,
    db.get_all_approved_users,  # exported under its historical name
)

# Public names exported by ``from <module> import *``.
__all__ = [
    'Database',
    'db',
    'is_approved',
    'update_violations',
    'add_approved_user',
    'remove_approved_user',
    'get_user_violations',
    'get_all_users',
]
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment