Skip to content

Instantly share code, notes, and snippets.

@regexyl
Created March 18, 2026 20:57
Show Gist options
  • Select an option

  • Save regexyl/8a5ec296f00df6f2caf1f0c8dd41d2f9 to your computer and use it in GitHub Desktop.

Select an option

Save regexyl/8a5ec296f00df6f2caf1f0c8dd41d2f9 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
from __future__ import annotations

import argparse
import base64
import json
import os
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Optional
from urllib import error, parse, request
DEFAULT_LIMIT = 10
DEFAULT_RETRIES = 2
DEFAULT_RETRY_DELAY_SECONDS = 1.0
DEFAULT_TIMEOUT_SECONDS = 20.0
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
OAUTH_BASE_URL = "https://oauth.reddit.com"
SKILL_DIR = Path(__file__).resolve().parents[1]
ENV_FILE = SKILL_DIR / ".env"
class RedditSearchError(RuntimeError):
    """Base error for the Reddit search CLI."""
class RedditConfigurationError(RedditSearchError):
    """Raised when required credentials or parameters are missing or invalid."""
class RedditRequestError(RedditSearchError):
    """Raised when Reddit returns an unsuccessful HTTP response."""
class RedditResponseError(RedditSearchError):
    """Raised when Reddit returns an unexpected or malformed payload."""
@dataclass(frozen=True)
class Credentials:
    """Reddit OAuth application credentials, loaded from the environment or .env."""

    client_id: str  # REDDIT_CLIENT_ID
    client_secret: str  # REDDIT_CLIENT_SECRET
    user_agent: str  # REDDIT_USER_AGENT, sent on every request
@dataclass(frozen=True)
class SearchRequest:
    """Validated search parameters derived from the CLI arguments."""

    resource: str  # one of: posts, comments, subreddits, users, raw
    query: str  # Reddit search query string (the q parameter)
    subreddit: Optional[str]  # restricts posts/comments search to one subreddit
    path: Optional[str]  # explicit OAuth API path; only used when resource == "raw"
    sort: Optional[str]  # e.g. relevance, hot, top, new, comments
    time_window: Optional[str]  # e.g. hour, day, week, month, year, all
    limit: int  # maximum number of results to request
    after: Optional[str]  # pagination cursor for the next page
    before: Optional[str]  # pagination cursor for the previous page
    include_over_18: bool  # when True, include_over_18=on is sent
    extra_params: tuple[tuple[str, str], ...]  # extra query params; override the defaults
    retries: int  # retry budget for 429/5xx and transient connection failures
    retry_delay_seconds: float  # base delay; multiplied by the attempt number
    timeout_seconds: float  # per-request HTTP timeout
@dataclass(frozen=True)
class SearchResult:
    """One normalized item from a Reddit listing (post, comment, subreddit, or user)."""

    kind: str  # Reddit "thing" kind, e.g. t1 (comment), t3 (link), t5 (subreddit)
    id: str  # short id; empty string when absent in the payload
    name: str  # fullname (kind_id); empty string when absent in the payload
    title: Optional[str]  # post/subreddit title, when present and non-empty
    text: Optional[str]  # first non-empty of selftext, body, public_description, description
    subreddit: Optional[str]
    author: Optional[str]
    score: Optional[int]
    comment_count: Optional[int]  # Reddit's num_comments field
    created_utc: Optional[float]  # Unix timestamp as reported by Reddit
    permalink: Optional[str]  # absolute www.reddit.com URL built from the permalink
    url: Optional[str]  # external or content URL, when present
    over_18: Optional[bool]  # NSFW flag; None when not a bool in the payload
@dataclass(frozen=True)
class SearchOutput:
    """Full result of one search: echo of the request plus parsed listing data."""

    resource: str  # resource that was searched
    query: str  # query string that was sent
    path: str  # OAuth API path that was hit
    params: dict[str, str]  # query parameters that were sent
    returned_count: int  # number of parsed results
    after: Optional[str]  # pagination cursor for the next page, if any
    before: Optional[str]  # pagination cursor for the previous page, if any
    results: list[SearchResult]
def build_parser() -> argparse.ArgumentParser:
    """Create the command-line argument parser for the Reddit search CLI."""
    parser = argparse.ArgumentParser(
        description="Search Reddit posts, comments, subreddits, or users with OAuth.",
    )
    parser.add_argument(
        "--query",
        required=True,
        help="Reddit search query string.",
    )
    parser.add_argument(
        "--resource",
        required=True,
        choices=("posts", "comments", "subreddits", "users", "raw"),
        help="Reddit search resource to query.",
    )
    # Plain optional string flags, registered data-driven to cut repetition.
    simple_string_flags = (
        ("--subreddit", "Optional subreddit name when searching posts or comments inside one subreddit."),
        ("--path", "OAuth API path for --resource raw, for example /search or /subreddits/search."),
        ("--sort", "Reddit sort value, for example relevance, hot, top, new, or comments."),
        ("--time", "Reddit time window, for example hour, day, week, month, year, or all."),
    )
    for flag, help_text in simple_string_flags:
        parser.add_argument(flag, help=help_text)
    parser.add_argument(
        "--limit",
        type=int,
        help=f"Maximum results to return. Defaults to {DEFAULT_LIMIT}.",
    )
    parser.add_argument("--after", help="Pagination cursor for the next page.")
    parser.add_argument("--before", help="Pagination cursor for the previous page.")
    parser.add_argument(
        "--include-over-18",
        action="store_true",
        help="Pass include_over_18=on to Reddit.",
    )
    parser.add_argument(
        "--param",
        action="append",
        help="Additional query parameter in key=value form. Repeat as needed.",
    )
    parser.add_argument(
        "--retries",
        type=int,
        help=f"Retry count for 429, 5xx, and transient connection failures. Defaults to {DEFAULT_RETRIES}.",
    )
    parser.add_argument(
        "--retry-delay-seconds",
        type=float,
        help=f"Base delay between retries. Defaults to {DEFAULT_RETRY_DELAY_SECONDS}.",
    )
    parser.add_argument(
        "--timeout-seconds",
        type=float,
        help=f"HTTP timeout per request. Defaults to {DEFAULT_TIMEOUT_SECONDS}.",
    )
    parser.add_argument(
        "--json",
        action="store_true",
        help="Emit machine-readable JSON instead of text output.",
    )
    return parser
def parse_args(argv: list[str]) -> tuple[argparse.Namespace, SearchRequest]:
    """Parse argv, validate option values, and build the SearchRequest.

    Returns the raw namespace (needed for output flags such as --json)
    together with the validated SearchRequest.

    Raises:
        RedditConfigurationError: for out-of-range or inconsistent options.
    """
    parsed = build_parser().parse_args(argv)

    def fallback(value, default):
        # argparse leaves unset numeric flags as None; substitute our default.
        return default if value is None else value

    limit = fallback(parsed.limit, DEFAULT_LIMIT)
    retries = fallback(parsed.retries, DEFAULT_RETRIES)
    retry_delay_seconds = fallback(parsed.retry_delay_seconds, DEFAULT_RETRY_DELAY_SECONDS)
    timeout_seconds = fallback(parsed.timeout_seconds, DEFAULT_TIMEOUT_SECONDS)
    if limit <= 0:
        raise RedditConfigurationError("--limit must be greater than 0.")
    if retries < 0:
        raise RedditConfigurationError("--retries must be 0 or greater.")
    if retry_delay_seconds < 0:
        raise RedditConfigurationError("--retry-delay-seconds must be 0 or greater.")
    if timeout_seconds <= 0:
        raise RedditConfigurationError("--timeout-seconds must be greater than 0.")
    if parsed.resource == "raw" and parsed.path is None:
        raise RedditConfigurationError("--path is required when --resource raw is used.")
    if parsed.subreddit is not None and parsed.resource in ("subreddits", "users"):
        raise RedditConfigurationError("--subreddit is only valid for posts or comments.")
    search_request = SearchRequest(
        resource=parsed.resource,
        query=parsed.query,
        subreddit=parsed.subreddit,
        path=parsed.path,
        sort=parsed.sort,
        time_window=parsed.time,
        limit=limit,
        after=parsed.after,
        before=parsed.before,
        include_over_18=parsed.include_over_18,
        extra_params=parse_extra_params(parsed.param),
        retries=retries,
        retry_delay_seconds=retry_delay_seconds,
        timeout_seconds=timeout_seconds,
    )
    return parsed, search_request
def parse_extra_params(raw_params: Optional[list[str]]) -> tuple[tuple[str, str], ...]:
    """Parse repeated --param key=value strings into an ordered tuple of pairs.

    Only the first '=' splits; the value may itself contain '='. Keys are
    stripped of surrounding whitespace; values are kept verbatim.

    Raises:
        RedditConfigurationError: for an entry without '=' or with an empty key.
    """
    if not raw_params:
        return ()
    pairs: list[tuple[str, str]] = []
    for entry in raw_params:
        key, separator, value = entry.partition("=")
        if not separator:
            raise RedditConfigurationError(
                f"Invalid --param value '{entry}'. Expected key=value.",
            )
        normalized_key = key.strip()
        if not normalized_key:
            raise RedditConfigurationError(
                f"Invalid --param value '{entry}'. Parameter name cannot be empty.",
            )
        pairs.append((normalized_key, value))
    return tuple(pairs)
def load_dotenv(env_file: Path) -> dict[str, str]:
    """Read KEY=VALUE pairs from env_file, skipping blank lines and # comments.

    Returns an empty dict when the file does not exist. Keys and values are
    stripped of surrounding whitespace.

    Raises:
        RedditConfigurationError: for a malformed line (no '=' or empty key).
    """
    if not env_file.exists():
        return {}
    parsed: dict[str, str] = {}
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        stripped = raw_line.strip()
        if not stripped or stripped.startswith("#"):
            continue
        key, separator, value = stripped.partition("=")
        if not separator:
            raise RedditConfigurationError(
                f"Invalid line in {env_file}: '{raw_line}'. Expected KEY=VALUE.",
            )
        normalized_key = key.strip()
        if not normalized_key:
            raise RedditConfigurationError(
                f"Invalid line in {env_file}: '{raw_line}'. Key cannot be empty.",
            )
        parsed[normalized_key] = value.strip()
    return parsed
def get_credentials() -> Credentials:
    """Load Reddit API credentials, preferring the environment over the .env file.

    Raises:
        RedditConfigurationError: when any required value is absent or empty.
    """
    file_values = load_dotenv(ENV_FILE)

    def lookup(key: str) -> Optional[str]:
        # Real environment variables take precedence over the .env file.
        return os.environ.get(key, file_values.get(key))

    resolved = {
        "REDDIT_CLIENT_ID": lookup("REDDIT_CLIENT_ID"),
        "REDDIT_CLIENT_SECRET": lookup("REDDIT_CLIENT_SECRET"),
        "REDDIT_USER_AGENT": lookup("REDDIT_USER_AGENT"),
    }
    missing_keys = [key for key, value in resolved.items() if not value]
    if missing_keys:
        missing_list = ", ".join(missing_keys)
        raise RedditConfigurationError(
            "Missing Reddit credentials. "
            f"Set {missing_list} in the environment or in {ENV_FILE}.",
        )
    return Credentials(
        client_id=resolved["REDDIT_CLIENT_ID"],
        client_secret=resolved["REDDIT_CLIENT_SECRET"],
        user_agent=resolved["REDDIT_USER_AGENT"],
    )
def resolve_path(search_request: SearchRequest) -> str:
    """Map the requested resource (and optional subreddit) to an OAuth API path."""
    resource = search_request.resource
    if resource == "raw":
        # parse_args guarantees --path is present for the raw resource.
        assert search_request.path is not None
        return normalize_path(search_request.path)
    fixed_paths = {
        "subreddits": "/subreddits/search",
        "users": "/users/search",
    }
    if resource in fixed_paths:
        return fixed_paths[resource]
    if search_request.subreddit is None:
        return "/search"
    return f"/r/{search_request.subreddit}/search"
def normalize_path(raw_path: str) -> str:
    """Return raw_path stripped of whitespace and guaranteed to start with '/'.

    Raises:
        RedditConfigurationError: when the stripped path is empty.
    """
    cleaned = raw_path.strip()
    if not cleaned:
        raise RedditConfigurationError("--path cannot be empty.")
    return cleaned if cleaned.startswith("/") else f"/{cleaned}"
def build_query_params(search_request: SearchRequest) -> dict[str, str]:
    """Translate a SearchRequest into Reddit search query parameters.

    Extra params are applied last, so they can override the generated values.
    """
    params: dict[str, str] = {
        "q": search_request.query,
        "limit": str(search_request.limit),
    }
    optional_values = (
        ("sort", search_request.sort),
        ("t", search_request.time_window),
        ("after", search_request.after),
        ("before", search_request.before),
    )
    for name, value in optional_values:
        if value is not None:
            params[name] = value
    if search_request.include_over_18:
        params["include_over_18"] = "on"
    # /search distinguishes posts from comments via the type parameter.
    type_by_resource = {"comments": "comment", "posts": "link"}
    if search_request.resource in type_by_resource:
        params["type"] = type_by_resource[search_request.resource]
    params.update(search_request.extra_params)
    return params
def fetch_access_token(
    credentials: Credentials,
    retries: int,
    retry_delay_seconds: float,
    timeout_seconds: float,
) -> str:
    """Obtain an application-only OAuth2 access token from Reddit.

    Uses the client_credentials grant with pre-emptive HTTP Basic
    authentication, retrying on 429/5xx and transient connection errors
    with a linearly growing delay.

    Raises:
        RedditRequestError: when the token endpoint keeps failing.
        RedditResponseError: when the payload lacks a usable access_token.
    """
    request_body = parse.urlencode({"grant_type": "client_credentials"}).encode("utf-8")
    # Send credentials pre-emptively in the Authorization header. The previous
    # HTTPBasicAuthHandler approach never authenticated: the handler's default
    # HTTPPasswordMgr matches realms exactly, so credentials registered with
    # realm=None are not found when Reddit's 401 challenge names a realm.
    # OAuth2 token endpoints expect Basic credentials on the first request
    # anyway (RFC 6749, section 2.3.1).
    credential_pair = f"{credentials.client_id}:{credentials.client_secret}".encode("utf-8")
    basic_auth_header = f"Basic {base64.b64encode(credential_pair).decode('ascii')}"
    last_error: Optional[BaseException] = None
    attempt_index = 0
    while attempt_index <= retries:
        token_request = request.Request(
            TOKEN_URL,
            data=request_body,
            headers={
                "User-Agent": credentials.user_agent,
                "Content-Type": "application/x-www-form-urlencoded",
                "Authorization": basic_auth_header,
            },
            method="POST",
        )
        try:
            with request.urlopen(token_request, timeout=timeout_seconds) as response:
                body = response.read().decode("utf-8", errors="replace")
                # urlopen raises HTTPError for non-2xx; this is a defensive check.
                if response.status < 200 or response.status >= 300:
                    raise RedditRequestError(
                        "Token request failed. "
                        f"url={TOKEN_URL} status={response.status} body={truncate_text(body)}",
                    )
                payload = json.loads(body)
                access_token = payload.get("access_token")
                if not isinstance(access_token, str) or access_token == "":
                    raise RedditResponseError(
                        "Token response did not include a usable access_token. "
                        f"url={TOKEN_URL} body={truncate_text(body)}",
                    )
                return access_token
        except error.HTTPError as exc:
            response_body = exc.read().decode("utf-8", errors="replace")
            last_error = RedditRequestError(
                "Token request failed. "
                f"url={TOKEN_URL} status={exc.code} body={truncate_text(response_body)}",
            )
            # Only 429 and 5xx are retried; other HTTP errors fail immediately.
            if not should_retry_http(exc.code, attempt_index, retries):
                raise last_error
        except error.URLError as exc:
            # Connection-level failure (DNS, refused, timeout): retry until budget is spent.
            last_error = RedditRequestError(
                "Token request failed. "
                f"url={TOKEN_URL} reason={exc.reason}",
            )
            if attempt_index >= retries:
                raise last_error
        except json.JSONDecodeError as exc:
            # A malformed body is not retryable; surface it immediately.
            raise RedditResponseError(
                f"Token response was not valid JSON. url={TOKEN_URL} error={exc}",
            ) from exc
        attempt_index += 1
        emit_retry_warning(
            "token",
            attempt_index,
            retries,
            retry_delay_seconds,
            last_error,
        )
        # Linear backoff: the delay grows with the attempt number.
        time.sleep(retry_delay_seconds * attempt_index)
    raise RedditRequestError("Token request failed without a captured error.")
def perform_search(
    credentials: Credentials,
    search_request: SearchRequest,
) -> SearchOutput:
    """Run one Reddit search and parse the listing response.

    Fetches an application-only OAuth token first, then issues a GET against
    the resolved search path, retrying on 429/5xx HTTP errors and transient
    connection failures with a linearly growing delay.

    Raises:
        RedditRequestError: when the request keeps failing after retries.
        RedditResponseError: when the response body is not a usable listing.
    """
    path = resolve_path(search_request)
    params = build_query_params(search_request)
    # The token is fetched once up front; the retry loop below reuses it.
    access_token = fetch_access_token(
        credentials=credentials,
        retries=search_request.retries,
        retry_delay_seconds=search_request.retry_delay_seconds,
        timeout_seconds=search_request.timeout_seconds,
    )
    url = f"{OAUTH_BASE_URL}{path}?{parse.urlencode(params)}"
    last_error: Optional[BaseException] = None
    attempt_index = 0
    while attempt_index <= search_request.retries:
        api_request = request.Request(
            url,
            headers={
                "Authorization": f"Bearer {access_token}",
                "User-Agent": credentials.user_agent,
                "Accept": "application/json",
            },
            method="GET",
        )
        try:
            with request.urlopen(api_request, timeout=search_request.timeout_seconds) as response:
                body = response.read().decode("utf-8", errors="replace")
                # urlopen raises HTTPError for non-2xx; this is a defensive check.
                if response.status < 200 or response.status >= 300:
                    raise RedditRequestError(
                        "Reddit search request failed. "
                        f"path={path} params={json.dumps(params, sort_keys=True)} "
                        f"status={response.status} body={truncate_text(body)}",
                    )
                payload = json.loads(body)
                return parse_search_output(
                    resource=search_request.resource,
                    query=search_request.query,
                    path=path,
                    params=params,
                    payload=payload,
                )
        except error.HTTPError as exc:
            response_body = exc.read().decode("utf-8", errors="replace")
            last_error = RedditRequestError(
                "Reddit search request failed. "
                f"path={path} params={json.dumps(params, sort_keys=True)} "
                f"status={exc.code} body={truncate_text(response_body)}",
            )
            # Only 429 and 5xx are retried; other HTTP errors fail immediately.
            if not should_retry_http(exc.code, attempt_index, search_request.retries):
                raise last_error
        except error.URLError as exc:
            # Connection-level failure (DNS, refused, timeout): retry until budget is spent.
            last_error = RedditRequestError(
                "Reddit search request failed. "
                f"path={path} params={json.dumps(params, sort_keys=True)} reason={exc.reason}",
            )
            if attempt_index >= search_request.retries:
                raise last_error
        except json.JSONDecodeError as exc:
            # A malformed body is not retryable; surface it immediately.
            raise RedditResponseError(
                "Reddit search response was not valid JSON. "
                f"path={path} params={json.dumps(params, sort_keys=True)} error={exc}",
            ) from exc
        attempt_index += 1
        emit_retry_warning(
            "search",
            attempt_index,
            search_request.retries,
            search_request.retry_delay_seconds,
            last_error,
        )
        # Linear backoff: the delay grows with the attempt number.
        time.sleep(search_request.retry_delay_seconds * attempt_index)
    raise RedditRequestError("Search request failed without a captured error.")
def should_retry_http(status_code: int, attempt_index: int, retries: int) -> bool:
    """Return True when status_code is retryable (429 or 5xx) and attempts remain."""
    if attempt_index >= retries:
        return False
    return status_code == 429 or 500 <= status_code < 600
def emit_retry_warning(
    stage: str,
    attempt_number: int,
    retries: int,
    retry_delay_seconds: float,
    last_error: Optional[BaseException],
) -> None:
    """Print a stderr warning describing the upcoming retry attempt."""
    if attempt_number > retries:
        # Retry budget exhausted; the caller raises instead of sleeping again.
        return
    delay = retry_delay_seconds * attempt_number
    message = (
        f"warning: reddit {stage} request failed; "
        f"retry {attempt_number} of {retries} in {delay:.1f}s; "
        f"error={last_error}"
    )
    print(message, file=sys.stderr)
def parse_search_output(
    resource: str,
    query: str,
    path: str,
    params: dict[str, str],
    payload: dict[str, Any],
) -> SearchOutput:
    """Convert a raw Reddit listing payload into a SearchOutput.

    Raises:
        RedditResponseError: when the payload is not shaped like a listing.
    """
    listing = payload.get("data")
    if not isinstance(listing, dict):
        raise RedditResponseError(
            "Reddit search response did not include a listing data object. "
            f"path={path} params={json.dumps(params, sort_keys=True)} body={truncate_text(json.dumps(payload))}",
        )
    children = listing.get("children")
    if not isinstance(children, list):
        raise RedditResponseError(
            "Reddit search response did not include listing children. "
            f"path={path} params={json.dumps(params, sort_keys=True)} body={truncate_text(json.dumps(payload))}",
        )

    def string_cursor(value: Any) -> Optional[str]:
        # Reddit sends null cursors; anything non-string is treated as absent.
        return value if isinstance(value, str) else None

    parsed_results = [parse_result(child) for child in children]
    return SearchOutput(
        resource=resource,
        query=query,
        path=path,
        params=params,
        returned_count=len(parsed_results),
        after=string_cursor(listing.get("after")),
        before=string_cursor(listing.get("before")),
        results=parsed_results,
    )
def parse_result(child: Any) -> SearchResult:
    """Convert one listing child into a normalized SearchResult.

    Raises:
        RedditResponseError: when the child is not a {kind, data} mapping.
    """
    if not isinstance(child, dict):
        raise RedditResponseError(f"Unexpected Reddit child item: {child!r}")
    kind = child.get("kind")
    data = child.get("data")
    if not isinstance(kind, str) or not isinstance(data, dict):
        raise RedditResponseError(f"Unexpected Reddit child shape: {child!r}")
    raw_permalink = data.get("permalink")
    absolute_permalink = (
        f"https://www.reddit.com{raw_permalink}" if isinstance(raw_permalink, str) else None
    )
    raw_created = data.get("created_utc")
    created_utc = float(raw_created) if isinstance(raw_created, (int, float)) else None
    raw_score = data.get("score")
    raw_comment_count = data.get("num_comments")
    return SearchResult(
        kind=kind,
        id=str(data.get("id", "")),
        name=str(data.get("name", "")),
        title=string_or_none(data.get("title")),
        # Posts use selftext, comments use body, subreddits use descriptions.
        text=first_non_empty_string(
            string_or_none(data.get("selftext")),
            string_or_none(data.get("body")),
            string_or_none(data.get("public_description")),
            string_or_none(data.get("description")),
        ),
        subreddit=string_or_none(data.get("subreddit")),
        author=string_or_none(data.get("author")),
        score=raw_score if isinstance(raw_score, int) else None,
        comment_count=raw_comment_count if isinstance(raw_comment_count, int) else None,
        created_utc=created_utc,
        permalink=absolute_permalink,
        url=string_or_none(data.get("url")),
        over_18=bool_or_none(data.get("over_18")),
    )
def string_or_none(value: Any) -> Optional[str]:
    """Return value when it is a non-empty str, otherwise None."""
    return value if isinstance(value, str) and value else None
def bool_or_none(value: Any) -> Optional[bool]:
    """Return value only when it is a genuine bool, otherwise None."""
    return value if isinstance(value, bool) else None
def first_non_empty_string(*values: Optional[str]) -> Optional[str]:
    """Return the first argument that is a non-empty string, or None."""
    return next((candidate for candidate in values if candidate), None)
def truncate_text(value: str) -> str:
    """Clip value to 500 characters, appending a truncation marker when clipped."""
    max_length = 500
    if len(value) > max_length:
        return f"{value[:max_length]}...[truncated]"
    return value
def render_text(output: SearchOutput) -> str:
    """Render a SearchOutput as the human-readable text report."""
    lines = [
        f"resource: {output.resource}",
        f"query: {output.query}",
        f"path: {output.path}",
        f"params: {json.dumps(output.params, sort_keys=True)}",
        f"returned_count: {output.returned_count}",
        f"after: {output.after}",
        f"before: {output.before}",
    ]
    for index, result in enumerate(output.results, start=1):
        display_title = result.title or result.name or result.id
        lines.append("")
        lines.append(f"{index}. [{result.kind}] {display_title}")
        detail_fields = (
            ("subreddit", result.subreddit),
            ("author", result.author),
            ("score", result.score),
            ("comment_count", result.comment_count),
            ("permalink", result.permalink),
        )
        for label, value in detail_fields:
            if value is not None:
                lines.append(f" {label}: {value}")
        # The url line is suppressed when it merely duplicates the permalink.
        if result.url is not None and result.url != result.permalink:
            lines.append(f" url: {result.url}")
        if result.text is not None:
            lines.append(f" text: {truncate_text(result.text)}")
    return "\n".join(lines)
def main(argv: list[str]) -> int:
    """CLI entry point: parse arguments, run the search, and print the result.

    Returns 0 on success; domain errors propagate as RedditSearchError.
    """
    parsed, search_request = parse_args(argv)
    output = perform_search(
        credentials=get_credentials(),
        search_request=search_request,
    )
    rendered = (
        json.dumps(asdict(output), indent=2) if parsed.json else render_text(output)
    )
    print(rendered)
    return 0
if __name__ == "__main__":
    # Translate domain errors into a short stderr message and exit code 1.
    try:
        exit_code = main(sys.argv[1:])
    except RedditSearchError as exc:
        print(f"error: {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
    raise SystemExit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment