Created
March 18, 2026 20:57
-
-
Save regexyl/8a5ec296f00df6f2caf1f0c8dd41d2f9 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| #!/usr/bin/env python3 | |
from __future__ import annotations

import argparse
import base64
import json
import os
import sys
import time
from dataclasses import asdict, dataclass
from pathlib import Path
from typing import Any, Optional
from urllib import error, parse, request
# Default request tuning applied when the corresponding CLI flag is omitted.
DEFAULT_LIMIT = 10
DEFAULT_RETRIES = 2
DEFAULT_RETRY_DELAY_SECONDS = 1.0
DEFAULT_TIMEOUT_SECONDS = 20.0
# Reddit OAuth2 endpoints: token minting, then the authenticated API host.
TOKEN_URL = "https://www.reddit.com/api/v1/access_token"
OAUTH_BASE_URL = "https://oauth.reddit.com"
# The skill root is one directory above this script; an optional .env file
# there can supply credentials (see load_dotenv / get_credentials).
SKILL_DIR = Path(__file__).resolve().parents[1]
ENV_FILE = SKILL_DIR / ".env"
class RedditSearchError(RuntimeError):
    """Base error for the Reddit search CLI.

    Every error this module raises derives from this class, so the
    top-level handler can catch them uniformly.
    """


class RedditConfigurationError(RedditSearchError):
    """Raised when required credentials or parameters are missing."""


class RedditRequestError(RedditSearchError):
    """Raised when Reddit returns an unsuccessful response."""


class RedditResponseError(RedditSearchError):
    """Raised when Reddit returns an unexpected payload."""
@dataclass(frozen=True)
class Credentials:
    """OAuth2 client-credentials used to authenticate against Reddit."""

    client_id: str      # Reddit app client id (REDDIT_CLIENT_ID)
    client_secret: str  # Reddit app client secret (REDDIT_CLIENT_SECRET)
    user_agent: str     # User-Agent header value (REDDIT_USER_AGENT)
@dataclass(frozen=True)
class SearchRequest:
    """Normalized, validated CLI inputs for one search invocation."""

    resource: str                 # posts | comments | subreddits | users | raw
    query: str                    # Reddit search query string (q=)
    subreddit: Optional[str]      # restrict posts/comments to one subreddit
    path: Optional[str]           # explicit API path, only for resource "raw"
    sort: Optional[str]           # Reddit sort value (relevance, top, ...)
    time_window: Optional[str]    # Reddit t= window (hour, day, ..., all)
    limit: int                    # maximum results to request (> 0)
    after: Optional[str]          # pagination cursor: next page
    before: Optional[str]         # pagination cursor: previous page
    include_over_18: bool         # send include_over_18=on when True
    extra_params: tuple[tuple[str, str], ...]  # extra params from --param
    retries: int                  # retry budget for 429/5xx/transient errors
    retry_delay_seconds: float    # base delay; retry n sleeps base * n
    timeout_seconds: float        # per-request HTTP timeout
@dataclass(frozen=True)
class SearchResult:
    """One normalized listing child from a Reddit search response."""

    kind: str                      # Reddit thing kind, e.g. t1/t3/t5
    id: str                        # thing id ("" when absent)
    name: str                      # thing fullname, e.g. t3_abc ("" when absent)
    title: Optional[str]           # title when present and non-empty
    text: Optional[str]            # first of selftext/body/descriptions
    subreddit: Optional[str]       # subreddit name when present
    author: Optional[str]          # author name when present
    score: Optional[int]           # score when Reddit supplies an int
    comment_count: Optional[int]   # from num_comments when an int
    created_utc: Optional[float]   # POSIX timestamp when numeric
    permalink: Optional[str]       # absolute www.reddit.com permalink
    url: Optional[str]             # target url when present
    over_18: Optional[bool]        # NSFW flag when Reddit supplies a bool
@dataclass(frozen=True)
class SearchOutput:
    """Full result of one search: echoed request info plus parsed results."""

    resource: str              # resource that was searched
    query: str                 # original query string
    path: str                  # API path actually requested
    params: dict[str, str]     # query parameters actually sent
    returned_count: int        # len(results)
    after: Optional[str]       # next-page cursor when Reddit returned one
    before: Optional[str]      # previous-page cursor when Reddit returned one
    results: list[SearchResult]
def build_parser() -> argparse.ArgumentParser:
    """Construct the argument parser for the Reddit search CLI."""
    cli = argparse.ArgumentParser(
        description="Search Reddit posts, comments, subreddits, or users with OAuth.",
    )
    add = cli.add_argument
    add("--query", required=True, help="Reddit search query string.")
    add(
        "--resource",
        required=True,
        choices=("posts", "comments", "subreddits", "users", "raw"),
        help="Reddit search resource to query.",
    )
    add(
        "--subreddit",
        help="Optional subreddit name when searching posts or comments inside one subreddit.",
    )
    add(
        "--path",
        help="OAuth API path for --resource raw, for example /search or /subreddits/search.",
    )
    add("--sort", help="Reddit sort value, for example relevance, hot, top, new, or comments.")
    add("--time", help="Reddit time window, for example hour, day, week, month, year, or all.")
    add("--limit", type=int, help=f"Maximum results to return. Defaults to {DEFAULT_LIMIT}.")
    add("--after", help="Pagination cursor for the next page.")
    add("--before", help="Pagination cursor for the previous page.")
    add("--include-over-18", action="store_true", help="Pass include_over_18=on to Reddit.")
    add(
        "--param",
        action="append",
        help="Additional query parameter in key=value form. Repeat as needed.",
    )
    add(
        "--retries",
        type=int,
        help=f"Retry count for 429, 5xx, and transient connection failures. Defaults to {DEFAULT_RETRIES}.",
    )
    add(
        "--retry-delay-seconds",
        type=float,
        help=f"Base delay between retries. Defaults to {DEFAULT_RETRY_DELAY_SECONDS}.",
    )
    add(
        "--timeout-seconds",
        type=float,
        help=f"HTTP timeout per request. Defaults to {DEFAULT_TIMEOUT_SECONDS}.",
    )
    add("--json", action="store_true", help="Emit machine-readable JSON instead of text output.")
    return cli
def parse_args(argv: list[str]) -> tuple[argparse.Namespace, SearchRequest]:
    """Parse *argv*, validate values, and build a SearchRequest.

    Returns the raw namespace (needed for output flags such as --json)
    alongside the normalized request.

    Raises:
        RedditConfigurationError: for invalid values or flag combinations.
    """
    namespace = build_parser().parse_args(argv)

    def pick(value, fallback):
        # The CLI value wins; fall back to the module default when omitted.
        return fallback if value is None else value

    limit = pick(namespace.limit, DEFAULT_LIMIT)
    retries = pick(namespace.retries, DEFAULT_RETRIES)
    retry_delay_seconds = pick(namespace.retry_delay_seconds, DEFAULT_RETRY_DELAY_SECONDS)
    timeout_seconds = pick(namespace.timeout_seconds, DEFAULT_TIMEOUT_SECONDS)

    if limit <= 0:
        raise RedditConfigurationError("--limit must be greater than 0.")
    if retries < 0:
        raise RedditConfigurationError("--retries must be 0 or greater.")
    if retry_delay_seconds < 0:
        raise RedditConfigurationError("--retry-delay-seconds must be 0 or greater.")
    if timeout_seconds <= 0:
        raise RedditConfigurationError("--timeout-seconds must be greater than 0.")
    if namespace.resource == "raw" and namespace.path is None:
        raise RedditConfigurationError("--path is required when --resource raw is used.")
    if namespace.resource in ("subreddits", "users") and namespace.subreddit is not None:
        raise RedditConfigurationError("--subreddit is only valid for posts or comments.")

    search_request = SearchRequest(
        resource=namespace.resource,
        query=namespace.query,
        subreddit=namespace.subreddit,
        path=namespace.path,
        sort=namespace.sort,
        time_window=namespace.time,
        limit=limit,
        after=namespace.after,
        before=namespace.before,
        include_over_18=namespace.include_over_18,
        extra_params=parse_extra_params(namespace.param),
        retries=retries,
        retry_delay_seconds=retry_delay_seconds,
        timeout_seconds=timeout_seconds,
    )
    return namespace, search_request
def parse_extra_params(raw_params: Optional[list[str]]) -> tuple[tuple[str, str], ...]:
    """Turn repeated ``--param key=value`` strings into (key, value) pairs.

    Keys are stripped of surrounding whitespace; values are kept verbatim.

    Raises:
        RedditConfigurationError: when an item has no '=' or an empty key.
    """
    if raw_params is None:
        return ()
    pairs: list[tuple[str, str]] = []
    for item in raw_params:
        key, separator, value = item.partition("=")
        if separator == "":
            raise RedditConfigurationError(
                f"Invalid --param value '{item}'. Expected key=value.",
            )
        name = key.strip()
        if not name:
            raise RedditConfigurationError(
                f"Invalid --param value '{item}'. Parameter name cannot be empty.",
            )
        pairs.append((name, value))
    return tuple(pairs)
def load_dotenv(env_file: Path) -> dict[str, str]:
    """Read KEY=VALUE pairs from *env_file*, skipping blanks and # comments.

    Returns an empty dict when the file does not exist. Keys and values are
    both stripped of surrounding whitespace.

    Raises:
        RedditConfigurationError: on a non-comment line without '=' or with
            an empty key.
    """
    if not env_file.exists():
        return {}
    parsed: dict[str, str] = {}
    for raw_line in env_file.read_text(encoding="utf-8").splitlines():
        entry = raw_line.strip()
        if not entry or entry.startswith("#"):
            continue
        key, separator, value = entry.partition("=")
        if separator == "":
            raise RedditConfigurationError(
                f"Invalid line in {env_file}: '{raw_line}'. Expected KEY=VALUE.",
            )
        name = key.strip()
        if not name:
            raise RedditConfigurationError(
                f"Invalid line in {env_file}: '{raw_line}'. Key cannot be empty.",
            )
        parsed[name] = value.strip()
    return parsed
def get_credentials() -> Credentials:
    """Load Reddit OAuth credentials from the environment or the .env file.

    Process environment variables take precedence over values in ENV_FILE.

    Raises:
        RedditConfigurationError: when any required credential is missing
            or empty.
    """
    file_values = load_dotenv(ENV_FILE)

    def lookup(name: str) -> Optional[str]:
        # Environment wins; fall back to the .env entry when unset.
        return os.environ.get(name, file_values.get(name))

    resolved = {
        "REDDIT_CLIENT_ID": lookup("REDDIT_CLIENT_ID"),
        "REDDIT_CLIENT_SECRET": lookup("REDDIT_CLIENT_SECRET"),
        "REDDIT_USER_AGENT": lookup("REDDIT_USER_AGENT"),
    }
    missing_keys = [name for name, value in resolved.items() if not value]
    if missing_keys:
        missing_list = ", ".join(missing_keys)
        raise RedditConfigurationError(
            "Missing Reddit credentials. "
            f"Set {missing_list} in the environment or in {ENV_FILE}.",
        )
    return Credentials(
        client_id=resolved["REDDIT_CLIENT_ID"],
        client_secret=resolved["REDDIT_CLIENT_SECRET"],
        user_agent=resolved["REDDIT_USER_AGENT"],
    )
def resolve_path(search_request: SearchRequest) -> str:
    """Map a SearchRequest onto the Reddit OAuth API path to call."""
    resource = search_request.resource
    if resource == "raw":
        assert search_request.path is not None  # enforced by parse_args
        return normalize_path(search_request.path)
    fixed_paths = {"subreddits": "/subreddits/search", "users": "/users/search"}
    if resource in fixed_paths:
        return fixed_paths[resource]
    # posts/comments: scope the search to one subreddit when requested.
    if search_request.subreddit is None:
        return "/search"
    return f"/r/{search_request.subreddit}/search"
def normalize_path(raw_path: str) -> str:
    """Trim whitespace and ensure the API path starts with a slash.

    Raises:
        RedditConfigurationError: when the path is empty after stripping.
    """
    cleaned = raw_path.strip()
    if not cleaned:
        raise RedditConfigurationError("--path cannot be empty.")
    return cleaned if cleaned.startswith("/") else f"/{cleaned}"
def build_query_params(search_request: SearchRequest) -> dict[str, str]:
    """Translate a SearchRequest into Reddit listing query parameters.

    --param extras are applied last, so they can override computed values.
    """
    params = {
        "q": search_request.query,
        "limit": str(search_request.limit),
    }
    optional = (
        ("sort", search_request.sort),
        ("t", search_request.time_window),
        ("after", search_request.after),
        ("before", search_request.before),
    )
    for name, value in optional:
        if value is not None:
            params[name] = value
    if search_request.include_over_18:
        params["include_over_18"] = "on"
    # /search mixes posts and comments; a type filter separates them.
    type_by_resource = {"comments": "comment", "posts": "link"}
    if search_request.resource in type_by_resource:
        params["type"] = type_by_resource[search_request.resource]
    params.update(search_request.extra_params)
    return params
def fetch_access_token(
    credentials: Credentials,
    retries: int,
    retry_delay_seconds: float,
    timeout_seconds: float,
) -> str:
    """Obtain an application-only OAuth token via the client_credentials grant.

    Retries 429/5xx responses and transient connection failures up to
    *retries* times; retry n sleeps retry_delay_seconds * n.

    Raises:
        RedditRequestError: when the token endpoint keeps failing.
        RedditResponseError: when the response is malformed or lacks a token.
    """
    request_body = parse.urlencode({"grant_type": "client_credentials"}).encode("utf-8")
    # Send HTTP Basic auth preemptively. urllib's HTTPBasicAuthHandler backed
    # by the default HTTPPasswordMgr only answers a 401 challenge whose realm
    # exactly matches the realm registered with add_password, and realm=None
    # never matches a real server realm — so the previous handler-based
    # approach could not authenticate. An explicit Authorization header works
    # unconditionally and avoids the extra challenge round-trip.
    basic_token = base64.b64encode(
        f"{credentials.client_id}:{credentials.client_secret}".encode("utf-8"),
    ).decode("ascii")
    last_error: Optional[BaseException] = None
    attempt_index = 0
    while attempt_index <= retries:
        token_request = request.Request(
            TOKEN_URL,
            data=request_body,
            headers={
                "User-Agent": credentials.user_agent,
                "Content-Type": "application/x-www-form-urlencoded",
                "Authorization": f"Basic {basic_token}",
            },
            method="POST",
        )
        try:
            with request.urlopen(token_request, timeout=timeout_seconds) as response:
                body = response.read().decode("utf-8", errors="replace")
                # Defensive: urlopen raises HTTPError for non-2xx statuses,
                # but guard in case a handler returned one anyway.
                if response.status < 200 or response.status >= 300:
                    raise RedditRequestError(
                        "Token request failed. "
                        f"url={TOKEN_URL} status={response.status} body={truncate_text(body)}",
                    )
                payload = json.loads(body)
                access_token = payload.get("access_token")
                if not isinstance(access_token, str) or access_token == "":
                    raise RedditResponseError(
                        "Token response did not include a usable access_token. "
                        f"url={TOKEN_URL} body={truncate_text(body)}",
                    )
                return access_token
        except error.HTTPError as exc:
            response_body = exc.read().decode("utf-8", errors="replace")
            last_error = RedditRequestError(
                "Token request failed. "
                f"url={TOKEN_URL} status={exc.code} body={truncate_text(response_body)}",
            )
            # Non-retryable status, or retry budget exhausted: fail now.
            if not should_retry_http(exc.code, attempt_index, retries):
                raise last_error
        except error.URLError as exc:
            last_error = RedditRequestError(
                "Token request failed. "
                f"url={TOKEN_URL} reason={exc.reason}",
            )
            if attempt_index >= retries:
                raise last_error
        except json.JSONDecodeError as exc:
            # Malformed JSON is not retryable; fail fast with context.
            raise RedditResponseError(
                f"Token response was not valid JSON. url={TOKEN_URL} error={exc}",
            ) from exc
        attempt_index += 1
        emit_retry_warning(
            "token",
            attempt_index,
            retries,
            retry_delay_seconds,
            last_error,
        )
        time.sleep(retry_delay_seconds * attempt_index)
    raise RedditRequestError("Token request failed without a captured error.")
def perform_search(
    credentials: Credentials,
    search_request: SearchRequest,
) -> SearchOutput:
    """Run one authenticated search against the Reddit OAuth API.

    Fetches a fresh app-only token, then issues the GET request, retrying
    429/5xx responses and transient connection failures; retry n sleeps
    retry_delay_seconds * n.

    Raises:
        RedditRequestError: when the request keeps failing after retries.
        RedditResponseError: when the response is not valid listing JSON.
    """
    path = resolve_path(search_request)
    params = build_query_params(search_request)
    access_token = fetch_access_token(
        credentials=credentials,
        retries=search_request.retries,
        retry_delay_seconds=search_request.retry_delay_seconds,
        timeout_seconds=search_request.timeout_seconds,
    )
    url = f"{OAUTH_BASE_URL}{path}?{parse.urlencode(params)}"
    last_error: Optional[BaseException] = None
    attempt_index = 0
    while attempt_index <= search_request.retries:
        api_request = request.Request(
            url,
            headers={
                "Authorization": f"Bearer {access_token}",
                "User-Agent": credentials.user_agent,
                "Accept": "application/json",
            },
            method="GET",
        )
        try:
            with request.urlopen(api_request, timeout=search_request.timeout_seconds) as response:
                body = response.read().decode("utf-8", errors="replace")
                # Defensive: urlopen raises HTTPError for non-2xx statuses,
                # but guard in case a handler returned one anyway.
                if response.status < 200 or response.status >= 300:
                    raise RedditRequestError(
                        "Reddit search request failed. "
                        f"path={path} params={json.dumps(params, sort_keys=True)} "
                        f"status={response.status} body={truncate_text(body)}",
                    )
                payload = json.loads(body)
                return parse_search_output(
                    resource=search_request.resource,
                    query=search_request.query,
                    path=path,
                    params=params,
                    payload=payload,
                )
        except error.HTTPError as exc:
            response_body = exc.read().decode("utf-8", errors="replace")
            last_error = RedditRequestError(
                "Reddit search request failed. "
                f"path={path} params={json.dumps(params, sort_keys=True)} "
                f"status={exc.code} body={truncate_text(response_body)}",
            )
            # Non-retryable status, or retry budget exhausted: fail now.
            if not should_retry_http(exc.code, attempt_index, search_request.retries):
                raise last_error
        except error.URLError as exc:
            last_error = RedditRequestError(
                "Reddit search request failed. "
                f"path={path} params={json.dumps(params, sort_keys=True)} reason={exc.reason}",
            )
            if attempt_index >= search_request.retries:
                raise last_error
        except json.JSONDecodeError as exc:
            # Malformed JSON is not retryable; fail fast with context.
            raise RedditResponseError(
                "Reddit search response was not valid JSON. "
                f"path={path} params={json.dumps(params, sort_keys=True)} error={exc}",
            ) from exc
        attempt_index += 1
        emit_retry_warning(
            "search",
            attempt_index,
            search_request.retries,
            search_request.retry_delay_seconds,
            last_error,
        )
        time.sleep(search_request.retry_delay_seconds * attempt_index)
    raise RedditRequestError("Search request failed without a captured error.")
def should_retry_http(status_code: int, attempt_index: int, retries: int) -> bool:
    """Return True when *status_code* is retryable and attempts remain.

    Retryable statuses are 429 (rate limited) and any 5xx server error.
    """
    if attempt_index >= retries:
        return False
    return status_code == 429 or 500 <= status_code < 600
def emit_retry_warning(
    stage: str,
    attempt_number: int,
    retries: int,
    retry_delay_seconds: float,
    last_error: Optional[BaseException],
) -> None:
    """Print a retry notice for *stage* ('token' or 'search') to stderr.

    Does nothing once the retry budget is exhausted.
    """
    if attempt_number > retries:
        return
    delay = retry_delay_seconds * attempt_number
    message = (
        f"warning: reddit {stage} request failed; "
        f"retry {attempt_number} of {retries} in {delay:.1f}s; "
        f"error={last_error}"
    )
    print(message, file=sys.stderr)
def parse_search_output(
    resource: str,
    query: str,
    path: str,
    params: dict[str, str],
    payload: dict[str, Any],
) -> SearchOutput:
    """Convert a raw Reddit listing payload into a SearchOutput.

    Raises:
        RedditResponseError: when the payload is not a listing with children.
    """
    data = payload.get("data")
    if not isinstance(data, dict):
        raise RedditResponseError(
            "Reddit search response did not include a listing data object. "
            f"path={path} params={json.dumps(params, sort_keys=True)} body={truncate_text(json.dumps(payload))}",
        )
    children = data.get("children")
    if not isinstance(children, list):
        raise RedditResponseError(
            "Reddit search response did not include listing children. "
            f"path={path} params={json.dumps(params, sort_keys=True)} body={truncate_text(json.dumps(payload))}",
        )
    parsed_results = [parse_result(item) for item in children]
    # Cursors may be null (or absent) in the payload; keep only real strings.
    cursor_after = data.get("after")
    cursor_before = data.get("before")
    return SearchOutput(
        resource=resource,
        query=query,
        path=path,
        params=params,
        returned_count=len(parsed_results),
        after=cursor_after if isinstance(cursor_after, str) else None,
        before=cursor_before if isinstance(cursor_before, str) else None,
        results=parsed_results,
    )
def parse_result(child: Any) -> SearchResult:
    """Normalize one listing child (post, comment, subreddit, or user).

    Raises:
        RedditResponseError: when the child is not a {kind, data} mapping.
    """
    if not isinstance(child, dict):
        raise RedditResponseError(f"Unexpected Reddit child item: {child!r}")
    kind = child.get("kind")
    data = child.get("data")
    if not isinstance(kind, str) or not isinstance(data, dict):
        raise RedditResponseError(f"Unexpected Reddit child shape: {child!r}")
    raw_permalink = data.get("permalink")
    # Reddit returns site-relative permalinks; make them absolute.
    absolute_permalink = (
        f"https://www.reddit.com{raw_permalink}" if isinstance(raw_permalink, str) else None
    )
    raw_created = data.get("created_utc")
    created_utc = float(raw_created) if isinstance(raw_created, (int, float)) else None
    raw_score = data.get("score")
    raw_comments = data.get("num_comments")
    return SearchResult(
        kind=kind,
        id=str(data.get("id", "")),
        name=str(data.get("name", "")),
        title=string_or_none(data.get("title")),
        # Posts carry selftext, comments carry body, subreddits carry
        # descriptions; take the first non-empty candidate.
        text=first_non_empty_string(
            string_or_none(data.get("selftext")),
            string_or_none(data.get("body")),
            string_or_none(data.get("public_description")),
            string_or_none(data.get("description")),
        ),
        subreddit=string_or_none(data.get("subreddit")),
        author=string_or_none(data.get("author")),
        score=raw_score if isinstance(raw_score, int) else None,
        comment_count=raw_comments if isinstance(raw_comments, int) else None,
        created_utc=created_utc,
        permalink=absolute_permalink,
        url=string_or_none(data.get("url")),
        over_18=bool_or_none(data.get("over_18")),
    )
def string_or_none(value: Any) -> Optional[str]:
    """Return *value* when it is a non-empty str, otherwise None."""
    return value if isinstance(value, str) and value != "" else None
def bool_or_none(value: Any) -> Optional[bool]:
    """Return *value* only when it is an actual bool, otherwise None."""
    return value if isinstance(value, bool) else None
def first_non_empty_string(*values: Optional[str]) -> Optional[str]:
    """Return the first argument that is a non-empty string, else None."""
    return next((candidate for candidate in values if candidate), None)
def truncate_text(value: str) -> str:
    """Clamp *value* to 500 characters, appending a truncation marker."""
    max_chars = 500
    if len(value) > max_chars:
        return f"{value[:max_chars]}...[truncated]"
    return value
def render_text(output: SearchOutput) -> str:
    """Render a SearchOutput as the human-readable text report."""
    lines = [
        f"resource: {output.resource}",
        f"query: {output.query}",
        f"path: {output.path}",
        f"params: {json.dumps(output.params, sort_keys=True)}",
        f"returned_count: {output.returned_count}",
        f"after: {output.after}",
        f"before: {output.before}",
    ]
    for position, item in enumerate(output.results, start=1):
        lines.append("")
        lines.append(f"{position}. [{item.kind}] {item.title or item.name or item.id}")
        optional_fields = (
            ("subreddit", item.subreddit),
            ("author", item.author),
            ("score", item.score),
            ("comment_count", item.comment_count),
            ("permalink", item.permalink),
        )
        for label, field_value in optional_fields:
            if field_value is not None:
                lines.append(f" {label}: {field_value}")
        # Skip the url when it merely duplicates the permalink.
        if item.url is not None and item.url != item.permalink:
            lines.append(f" url: {item.url}")
        if item.text is not None:
            lines.append(f" text: {truncate_text(item.text)}")
    return "\n".join(lines)
def main(argv: list[str]) -> int:
    """CLI entry point: parse args, authenticate, search, and print results.

    Returns 0 on success; configuration/request failures propagate as
    RedditSearchError subclasses for the __main__ handler to report.
    """
    parsed, search_request = parse_args(argv)
    credentials = get_credentials()
    output = perform_search(
        credentials=credentials,
        search_request=search_request,
    )
    if parsed.json:
        # asdict() recursively converts the nested dataclasses for JSON output.
        print(json.dumps(asdict(output), indent=2))
    else:
        print(render_text(output))
    return 0
if __name__ == "__main__":
    try:
        raise SystemExit(main(sys.argv[1:]))
    except RedditSearchError as exc:
        # Report known CLI errors on stderr and exit non-zero without a traceback.
        print(f"error: {exc}", file=sys.stderr)
        raise SystemExit(1) from exc
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment