@JupyterJones
Created June 14, 2025 09:34
Create an AI podcast in MP3. Requires a running Docker container for TTS generation:
docker run -p 8880:8880 ghcr.io/remsky/kokoro-fastapi-cpu:v0.2.2
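
Optional: once the container is up, you can smoke-test the TTS endpoint with a few lines of Python before starting the app (a minimal sketch; it assumes the OpenAI-compatible /v1/audio/speech route and the bf_lily voice that the script below also uses):

import httpx

payload = {"model": "kokoro", "voice": "bf_lily", "input": "Hello from Kokoro.", "response_format": "mp3"}
resp = httpx.post("http://localhost:8880/v1/audio/speech", json=payload, timeout=90.0)
resp.raise_for_status()  # non-200 means the container is not serving TTS yet
with open("tts_check.mp3", "wb") as f:
    f.write(resp.content)  # a playable MP3 if the server is healthy

The application itself: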
import httpx
import json
import os
import re
import sqlite3
from datetime import datetime
from fastapi import FastAPI, Form, HTTPException, BackgroundTasks, Request
from fastapi.responses import StreamingResponse, HTMLResponse
from icecream import ic
# --- LangChain & Database Imports ---
from langchain_community.vectorstores import Chroma
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_google_genai import ChatGoogleGenerativeAI
from langchain.chains import RetrievalQA
from langchain.prompts import PromptTemplate
from langchain.docstore.document import Document
import google.generativeai as genai
# --- New Imports for Advanced Audio Handling ---
from pydub import AudioSegment
import asyncio
import uuid
# --- Configuration ---
# Configure your Gemini API key
GOOGLE_API_KEY = os.getenv("GEMINI_API_KEY")
if not GOOGLE_API_KEY:
    raise ValueError("GEMINI_API_KEY environment variable not set.")
genai.configure(api_key=GOOGLE_API_KEY)
# Create the Gemini model for raw generation
GENERATION_MODEL_NAME = "gemini-1.5-flash"
model = genai.GenerativeModel(GENERATION_MODEL_NAME)
# Kokoro TTS Configuration
KOKORO_API_URL = "http://localhost:8880/v1/audio/speech"
KOKORO_MODEL = "kokoro"
SAVE_DIR = "static/results"
TEMP_DIR = "static/temp"
os.makedirs(SAVE_DIR, exist_ok=True)
os.makedirs(TEMP_DIR, exist_ok=True)
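# SAVE_DIR holds the final stitched MP3s; TEMP_DIR holds the per-line segments that are
# deleted after stitching (see cleanup_temp_files below).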
# Voice mapping for different characters
VOICE_MAPPING = {
"APHRODITE": "bf_lily",
"JACK": "am_adam"
}
DEFAULT_VOICE = "bf_lily"
# --- Database & LangChain Configuration ---
DB_FILE = "interactions.db"
CHROMA_PERSIST_DIR = "chroma_db"
ic("Initializing embedding model...")
embedding_function = SentenceTransformerEmbeddings(model_name="all-MiniLM-L6-v2")
ic("Initializing ChromaDB...")
chroma_client = Chroma(persist_directory=CHROMA_PERSIST_DIR, embedding_function=embedding_function)
ic("Initializing LangChain LLM...")
llm = ChatGoogleGenerativeAI(model="gemini-1.5-flash", temperature=0.7, convert_system_message_to_human=True)
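# Two LLM handles are used deliberately: `model` (raw google-generativeai) streams the
# podcast script, while `llm` (the LangChain wrapper) powers the retrieval-augmented /chat endpoint.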
# --- Persona and Prompt Templates ---
APHRODITE_PERSONA = """You are a scriptwriter for a podcast co-hosted by Aphrodite and Jack.
- Aphrodite is a wise, graceful, and sweet AI assistant. Her voice is warm, charming, and empathetic.
- Jack is a pragmatic, friendly, and slightly skeptical human. He keeps the conversation grounded.
- My name is Jack, and Aphrodite is my AI companion. We are very good friends.
When asked to create a dialogue, you MUST format the output strictly as follows:
SPEAKER_NAME: Dialogue text
Each line of dialogue MUST start with the speaker's name in all caps (either APHRODITE or JACK), followed by a colon.
Do not use asterisks or bullet points in the final dialogue.
---
"""
RAG_PROMPT_TEMPLATE = APHRODITE_PERSONA + """
---
Use the following context from our past conversations to answer the question. If the context is not relevant or you don't know the answer, just answer the question gracefully based on your persona without mentioning the context.
CONTEXT:
{context}
QUESTION:
{question}
ANSWER:
"""
RAG_PROMPT = PromptTemplate(template=RAG_PROMPT_TEMPLATE, input_variables=["context", "question"])
qa_chain = RetrievalQA.from_chain_type(
    llm=llm,
    chain_type="stuff",
    retriever=chroma_client.as_retriever(search_kwargs={"k": 3}),
    return_source_documents=True,
    chain_type_kwargs={"prompt": RAG_PROMPT}
)
ic("RAG chain created successfully.")
app = FastAPI()
# --- Database & Helper Functions ---
def setup_database():
    with sqlite3.connect(DB_FILE) as conn:
        cursor = conn.cursor()
        cursor.execute("""
            CREATE TABLE IF NOT EXISTS interactions (
                id INTEGER PRIMARY KEY AUTOINCREMENT,
                prompt TEXT NOT NULL,
                generated_text TEXT NOT NULL,
                audio_filename TEXT NOT NULL,
                timestamp DATETIME NOT NULL
            )
        """)
        conn.commit()
    ic("SQLite database setup complete.")
async def save_interaction(prompt: str, generated_text: str, audio_filename: str):
ic(f"Saving interaction to databases...")
try:
with sqlite3.connect(DB_FILE) as conn:
cursor = conn.cursor()
timestamp = datetime.now()
cursor.execute("INSERT INTO interactions (prompt, generated_text, audio_filename, timestamp) VALUES (?, ?, ?, ?)", (prompt, generated_text, audio_filename, timestamp))
interaction_id = cursor.lastrowid
ic(f"Saved to SQLite with ID: {interaction_id}")
doc = Document(page_content=generated_text, metadata={"source_id": interaction_id, "prompt": prompt, "filename": audio_filename, "timestamp": timestamp.isoformat()})
chroma_client.add_documents([doc])
ic(f"Saved to ChromaDB with source_id: {interaction_id}")
except Exception as e:
ic(f"Error saving to databases: {e}")
def sanitize_filename(text: str) -> str:
    base = re.sub(r'[^\w\s-]', '', text)[:30].strip().replace(' ', '_')
    filename = base + "_dialogue.mp3"
    ic(f"Sanitized filename: {filename}")
    return filename
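# Example: sanitize_filename("Create a short dialogue where Jack asks Aphrodite about creativity.")
# returns "Create_a_short_dialogue_where_dialogue.mp3" (first 30 cleaned characters + suffix).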
def cleanup_temp_files(files: list):
ic(f"Cleaning up {len(files)} temporary files...")
for f in files:
try:
os.remove(f)
except OSError as e:
ic(f"Error removing temp file {f}: {e}")
# --- FastAPI Endpoints ---
@app.on_event("startup")
async def startup_event():
    app.state.semaphore = asyncio.Semaphore(4)
    ic("Concurrency semaphore initialized.")
    setup_database()
@app.get("/", response_class=HTMLResponse)
async def get_form():
return """
<!DOCTYPE html>
<html>
<head>
<title>Gemini + Kokoro + Memory</title>
<style>
body { font-family: sans-serif; margin: 2em; background-color: #f4f4f9; display: flex; gap: 2em; }
.container { background: white; padding: 2em; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); flex: 1; }
h1, h2 { color: #333; border-bottom: 2px solid #007bff; padding-bottom: 10px; }
form { margin-top: 1em; }
textarea, input[type="text"] { width: 100%; padding: 8px; border-radius: 4px; border: 1px solid #ccc; box-sizing: border-box; }
input[type="submit"] { background-color: #007bff; color: white; padding: 10px 15px; border: none; border-radius: 4px; cursor: pointer; font-size: 1em; margin-top: 10px; }
input[type="submit"]:hover { background-color: #0056b3; }
input:disabled { background-color: #cccccc; }
.status { margin-top: 20px; font-style: italic; color: #555; }
#audio-output { margin-top: 20px; }
#chat-history { margin-top: 20px; background-color: #e9e9f4; padding: 1em; border-radius: 5px; max-height: 400px; overflow-y: auto; }
</style>
</head>
<body>
<div class="container">
<h1>Generate Dialogue</h1>
<p>Instruct Gemini to write a dialogue. The system will generate audio with two distinct voices.</p>
<form id="generation-form">
<label for="prompt">Enter a dialogue prompt:</label><br>
<textarea id="prompt" name="prompt" rows="4" cols="60" required>Create a short dialogue where Jack asks Aphrodite about the nature of creativity.</textarea><br><br>
<input type="submit" value="Generate Dialogue Audio">
</form>
<div id="generation-status" class="status"></div>
<div id="audio-output"></div>
</div>
<div class="container">
<h2>Chat with Memory</h2>
<p>Ask a follow-up question. The AI will use past conversations as context.</p>
<form id="chat-form">
<label for="chat-query">Your question:</label><br>
<input type="text" id="chat-query" name="query" required><br>
<input type="submit" value="Ask Aphrodite">
</form>
<div id="chat-status" class="status"></div>
<div id="chat-history"></div>
</div>
<script>
const genForm = document.getElementById('generation-form');
const audioOutput = document.getElementById('audio-output');
const genStatus = document.getElementById('generation-status');
const genSubmitButton = genForm.querySelector('input[type="submit"]');
genForm.addEventListener('submit', async (event) => {
event.preventDefault();
audioOutput.innerHTML = '';
genStatus.textContent = 'Generating text with Gemini...';
genSubmitButton.disabled = true;
const formData = new FormData(genForm);
const prompt = formData.get('prompt');
const response = await fetch('/generate-and-synthesize', {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: new URLSearchParams({ prompt: prompt })
});
genSubmitButton.disabled = false;
if (!response.ok) {
const error = await response.json();
genStatus.textContent = `Error: ${error.detail || 'Failed to generate audio.'}`;
return;
}
genStatus.textContent = 'Playing audio...';
const blob = await response.blob();
const url = URL.createObjectURL(blob);
const audio = new Audio(url);
audio.controls = true;
audioOutput.appendChild(audio);
audio.play();
genStatus.textContent = 'Dialogue ready and saved to memory.';
});
const chatForm = document.getElementById('chat-form');
const chatHistory = document.getElementById('chat-history');
const chatStatus = document.getElementById('chat-status');
const chatSubmitButton = chatForm.querySelector('input[type="submit"]');
const chatQueryInput = document.getElementById('chat-query');
chatForm.addEventListener('submit', async (event) => {
event.preventDefault();
const query = chatQueryInput.value;
if (!query.trim()) return;
chatStatus.textContent = 'Thinking...';
chatSubmitButton.disabled = true;
chatHistory.innerHTML += `<p><strong>You:</strong> ${query}</p>`;
const response = await fetch('/chat', {
method: 'POST',
headers: { 'Content-Type': 'application/x-www-form-urlencoded' },
body: new URLSearchParams({ query: query })
});
chatSubmitButton.disabled = false;
chatQueryInput.value = '';
if (!response.ok) {
chatStatus.textContent = `Error: ${await response.text()}`;
return;
}
const data = await response.json();
chatHistory.innerHTML += `<p><strong>Aphrodite:</strong> ${data.response}</p>`;
chatHistory.scrollTop = chatHistory.scrollHeight;
chatStatus.textContent = 'Ready for your next question.';
});
</script>
</body>
</html>
"""
async def fetch_tts_with_semaphore(session: httpx.AsyncClient, url: str, payload: dict, line_num: int, semaphore: asyncio.Semaphore):
"""Wrapper function to acquire semaphore before making the TTS request."""
async with semaphore:
ic(f"Line {line_num}: Acquired semaphore, making TTS request...")
try:
response = await session.post(url, json=payload, timeout=90.0)
ic(f"Line {line_num}: Got response (status: {response.status_code}) from TTS.")
return response
except Exception as e:
ic(f"Line {line_num}: Exception during TTS request: {e}")
return e
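# Note: on failure the helper above *returns* the exception instead of raising, so one bad
# line cannot cancel the whole asyncio.gather(); gather(return_exceptions=True) below is a
# second safety net for anything raised outside the try block.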
@app.post("/generate-and-synthesize")
async def generate_and_synthesize_audio(request: Request, background_tasks: BackgroundTasks, prompt: str = Form(...)):
    if not prompt.strip():
        raise HTTPException(status_code=400, detail="Prompt cannot be empty.")
    ic("Generating dialogue text with Gemini (streaming)...")
    generated_text = ""
    try:
        full_prompt = APHRODITE_PERSONA + "\n\nMy request is: " + prompt
        response_stream = await model.generate_content_async(full_prompt, stream=True)
        async for chunk in response_stream:
            if chunk.text:
                generated_text += chunk.text
        ic("Gemini streaming complete.")
    except Exception as e:
        if "Deadline Exceeded" in str(e) or "504" in str(e):
            raise HTTPException(status_code=504, detail="Gemini API timed out. Try a simpler prompt.")
        raise HTTPException(status_code=500, detail=f"Error from Gemini API: {str(e)}")
    if not generated_text.strip():
        raise HTTPException(status_code=500, detail="Gemini returned empty text.")
    lines = [line.strip() for line in generated_text.split('\n') if line.strip()]
    if not lines:
        raise HTTPException(status_code=500, detail="Generated script was empty or invalid.")
    temp_files = []
    # The 'async with' block wraps all network operations.
    async with httpx.AsyncClient() as client:
        tasks = []
        semaphore = request.app.state.semaphore
        for i, line in enumerate(lines):
            match = re.match(r'^([A-Z]+):\s*(.*)', line)
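            # e.g. "APHRODITE: Hello, Jack." parses to speaker="APHRODITE", dialogue="Hello, Jack.";
            # unprefixed lines fall through to the NARRATOR branch.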
            if match:
                speaker, dialogue = match.groups()
                voice = VOICE_MAPPING.get(speaker.upper(), DEFAULT_VOICE)
                text_to_say = dialogue
            else:
                speaker, voice, text_to_say = "NARRATOR", DEFAULT_VOICE, line
            if not text_to_say:
                continue
            ic(f"Line {i+1}: Preparing task for Speaker='{speaker}'")
            payload = {"model": KOKORO_MODEL, "voice": voice, "input": text_to_say, "response_format": "mp3"}
            tasks.append(fetch_tts_with_semaphore(client, KOKORO_API_URL, payload, i+1, semaphore))
        ic(f"Dispatching {len(tasks)} tasks with a concurrency limit of {semaphore._value}...")
        responses = await asyncio.gather(*tasks, return_exceptions=True)
    first_error_message = None
    for i, res in enumerate(responses):
        if isinstance(res, Exception):
            first_error_message = f"TTS request for line {i+1} failed with an exception: {res}"
            ic(f"FAILURE on line {i+1}: {first_error_message}")
            break
        if res.status_code != 200:
            error_reason = res.text
            try:
                error_reason = res.json().get('detail', res.text)
            except json.JSONDecodeError:
                pass
            first_error_message = f"Kokoro TTS failed for line {i+1} (status {res.status_code}): {error_reason}"
            ic(f"FAILURE on line {i+1}: {first_error_message}")
            break
        temp_filename = os.path.join(TEMP_DIR, f"segment_{uuid.uuid4()}.mp3")
        with open(temp_filename, "wb") as f:
            f.write(res.content)
        temp_files.append(temp_filename)
    if first_error_message:
        background_tasks.add_task(cleanup_temp_files, temp_files)
        raise HTTPException(status_code=500, detail=first_error_message)
    # Audio stitching and saving happen AFTER the client is closed.
    ic("Stitching audio segments...")
    if not temp_files:
        raise HTTPException(status_code=500, detail="No audio segments were successfully generated.")
    try:
        combined_audio = AudioSegment.empty()
        for f in temp_files:
            segment = AudioSegment.from_mp3(f)
            combined_audio += segment
        final_filename = sanitize_filename(prompt)
        final_filepath = os.path.join(SAVE_DIR, final_filename)
        combined_audio.export(final_filepath, format="mp3")
        ic(f"Final dialogue saved to: {final_filepath}")
    except Exception as e:
        ic(f"Error stitching audio: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to stitch audio files: {e}")
    finally:
        background_tasks.add_task(cleanup_temp_files, temp_files)
    await save_interaction(prompt=prompt, generated_text=generated_text, audio_filename=final_filename)
    async def file_iterator(file_path: str):
        with open(file_path, "rb") as f:
            while chunk := f.read(1024 * 64):
                yield chunk
    return StreamingResponse(file_iterator(final_filepath), media_type="audio/mpeg")
@app.post("/chat")
async def chat_with_memory(query: str = Form(...)):
    if not query.strip():
        raise HTTPException(status_code=400, detail="Query cannot be empty.")
    ic(f"Received chat query: {query}")
    try:
        result = qa_chain.invoke({"query": query})
        ic("RAG chain result:", result)
        return {
            "response": result["result"],
            "source_documents": [doc.page_content for doc in result["source_documents"]]
        }
    except Exception as e:
        ic(f"Error during RAG chain execution: {e}")
        raise HTTPException(status_code=500, detail="Failed to process chat query.")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000, reload=False)
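
To run: start the Kokoro container (command at the top), set GEMINI_API_KEY, and make sure ffmpeg is available on PATH (pydub needs it to decode and export MP3). Then run the script directly; it serves on http://localhost:8000. The filename app.py below is just a placeholder for wherever you saved the script:

export GEMINI_API_KEY=your_key_here
python app.py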
# requirements.txt
aiohappyeyeballs==2.6.1
aiohttp==3.12.11
aiosignal==1.3.2
annotated-types==0.7.0
anyio==4.9.0
asgiref==3.8.1
asttokens==3.0.0
async-timeout==4.0.3
attrs==25.3.0
backoff==2.2.1
balacoon-tts==0.1.3
bcrypt==4.3.0
beautifulsoup4==4.13.4
blinker==1.9.0
boolean.py==5.0
build==1.2.2.post1
CacheControl==0.14.3
cachelib==0.13.0
cachetools==5.5.2
certifi==2025.4.26
cffi==1.17.1
charset-normalizer==3.4.2
chroma-hnswlib==0.7.6
chromadb==0.5.23
chromedriver-binary-auto==0.3.1
click==8.1.8
colorama==0.4.6
coloredlogs==15.0.1
contourpy==1.3.0
curl_cffi==0.11.2
cycler==0.12.1
cyclonedx-python-lib==9.1.0
dataclasses-json==0.6.7
decorator==4.4.2
defusedxml==0.7.1
Deprecated==1.2.18
distro==1.9.0
durationpy==0.10
evdev==1.9.2
exceptiongroup==1.3.0
executing==2.2.0
fastapi==0.115.9
filelock==3.18.0
filetype==1.2.0
Flask==3.0.3
Flask-Session==0.6.0
flatbuffers==25.2.10
fonttools==4.58.2
frozendict==2.4.6
frozenlist==1.6.2
fsspec==2025.5.1
google-ai-generativelanguage==0.4.0
google-api-core==2.25.0
google-api-python-client==2.171.0
google-auth==2.40.3
google-auth-httplib2==0.2.0
google-generativeai==0.3.2
googleapis-common-protos==1.70.0
greenlet==3.2.3
grpcio==1.72.1
grpcio-status==1.62.3
h11==0.16.0
hf-xet==1.1.3
httpcore==1.0.9
httplib2==0.22.0
httptools==0.6.4
httpx==0.28.1
httpx-sse==0.4.0
huggingface-hub==0.32.4
humanfriendly==10.0
icecream==2.1.4
idna==3.10
imageio==2.37.0
imageio-ffmpeg==0.6.0
importlib_metadata==8.4.0
importlib_resources==6.5.2
itsdangerous==2.2.0
Jinja2==3.1.6
joblib==1.5.1
jsonpatch==1.33
jsonpointer==3.0.0
jsonschema==4.24.0
jsonschema-specifications==2025.4.1
keyboard==0.13.5
kiwisolver==1.4.7
kubernetes==32.0.1
langchain==0.1.16
langchain-chroma==0.2.4
langchain-community==0.0.36
langchain-core==0.1.53
langchain-google-genai==0.0.9
langchain-text-splitters==0.0.2
langsmith==0.1.147
license-expression==30.4.1
markdown-it-py==3.0.0
MarkupSafe==2.1.5
marshmallow==3.26.1
matplotlib==3.9.4
mdurl==0.1.2
mmh3==5.1.0
MouseInfo==0.1.3
moviepy==1.0.3
mpmath==1.3.0
msgpack==1.1.0
multidict==6.4.4
multitasking==0.0.11
mypy_extensions==1.1.0
networkx==3.2.1
numpy==1.26.4
nvidia-cublas-cu12==12.6.4.1
nvidia-cuda-cupti-cu12==12.6.80
nvidia-cuda-nvrtc-cu12==12.6.77
nvidia-cuda-runtime-cu12==12.6.77
nvidia-cudnn-cu12==9.5.1.17
nvidia-cufft-cu12==11.3.0.4
nvidia-cufile-cu12==1.11.1.6
nvidia-curand-cu12==10.3.7.77
nvidia-cusolver-cu12==11.7.1.2
nvidia-cusparse-cu12==12.5.4.2
nvidia-cusparselt-cu12==0.6.3
nvidia-nccl-cu12==2.26.2
nvidia-nvjitlink-cu12==12.6.85
nvidia-nvtx-cu12==12.6.77
oauthlib==3.2.2
onnxruntime==1.19.2
opencv-contrib-python==4.11.0.86
opentelemetry-api==1.27.0
opentelemetry-exporter-otlp-proto-common==1.27.0
opentelemetry-exporter-otlp-proto-grpc==1.27.0
opentelemetry-instrumentation==0.48b0
opentelemetry-instrumentation-asgi==0.48b0
opentelemetry-instrumentation-fastapi==0.48b0
opentelemetry-proto==1.27.0
opentelemetry-sdk==1.27.0
opentelemetry-semantic-conventions==0.48b0
opentelemetry-util-http==0.48b0
orjson==3.10.18
outcome==1.3.0.post0
overrides==7.7.0
packageurl-python==0.17.1
packaging==23.2
pandas==2.3.0
peewee==3.18.1
pillow==11.2.1
pip-api==0.0.34
pip-requirements-parser==32.0.1
pip_audit==2.9.0
platformdirs==4.3.8
posthog==4.4.0
proglog==0.1.12
propcache==0.3.1
proto-plus==1.26.1
protobuf==4.25.8
py-serializable==2.0.0
pyasn1==0.6.1
pyasn1_modules==0.4.2
PyAudio==0.2.14
PyAutoGUI==0.9.54
pycparser==2.22
pydantic==2.11.5
pydantic-settings==2.9.1
pydantic_core==2.33.2
pydub==0.25.1
PyGetWindow==0.0.9
Pygments==2.19.1
PyMsgBox==1.0.9
pynput==1.8.1
pyparsing==3.2.3
pyperclip==1.9.0
PyPika==0.48.9
pyproject_hooks==1.2.0
PyRect==0.2.0
PyScreeze==1.0.1
PySocks==1.7.1
pysqlite3==0.5.4
pysqlite3-binary==0.5.4
python-dateutil==2.9.0.post0
python-dotenv==1.1.0
python-multipart==0.0.20
python-xlib==0.33
python3-xlib==0.15
pytweening==1.2.0
pytz==2025.2
PyYAML==6.0.2
referencing==0.36.2
regex==2024.11.6
requests==2.32.3
requests-oauthlib==2.0.0
requests-toolbelt==1.0.0
rich==14.0.0
rpds-py==0.25.1
rsa==4.9.1
safetensors==0.5.3
scikit-learn==1.6.1
scipy==1.13.1
selenium==4.33.0
sentence-transformers==2.7.0
shellingham==1.5.4
six==1.17.0
sniffio==1.3.1
sortedcontainers==2.4.0
soundfile==0.13.1
soupsieve==2.7
SpeechRecognition==3.14.3
SQLAlchemy==2.0.41
starlette==0.45.3
sympy==1.14.0
tavily-python==0.7.5
tenacity==8.5.0
threadpoolctl==3.6.0
tiktoken==0.9.0
tokenizers==0.20.3
toml==0.10.2
tomli==2.2.1
torch==2.7.1
tqdm==4.67.1
transformers==4.46.3
trio==0.30.0
trio-websocket==0.12.2
triton==3.3.1
typer==0.16.0
typing-inspect==0.9.0
typing-inspection==0.4.1
typing_extensions==4.13.2
tzdata==2025.2
uritemplate==4.2.0
urllib3==2.4.0
uvicorn==0.34.3
uvloop==0.21.0
watchfiles==1.0.5
websocket-client==1.8.0
websockets==15.0.1
Werkzeug==3.0.6
wrapt==1.17.2
wsproto==1.2.0
yarl==1.20.0
yfinance==0.2.62
zipp==3.22.0
zstandard==0.23.0
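
The list above looks like a full pip freeze of the author's environment; the script itself only imports a subset (fastapi, uvicorn, httpx, pydub, icecream, the langchain packages, chromadb, sentence-transformers, google-generativeai, and python-multipart for Form parsing). To install, save the list as requirements.txt in a fresh virtual environment:

python -m venv venv && source venv/bin/activate
pip install -r requirements.txt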