Skip to content

Instantly share code, notes, and snippets.

@crustymonkey
Created March 23, 2024 21:41
Show Gist options
  • Save crustymonkey/6f66025f5d25fce1f4bd382167519424 to your computer and use it in GitHub Desktop.
Save crustymonkey/6f66025f5d25fce1f4bd382167519424 to your computer and use it in GitHub Desktop.
#!/usr/bin/env python3
import logging
import os
import re
import subprocess as sp
import sys
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter
from http.client import HTTPResponse
from tempfile import NamedTemporaryFile
from typing import Union, Any
from urllib.request import Request, build_opener, urlopen, HTTPErrorProcessor
# Version assumed to be installed when the state file does not exist yet;
# get_cur_version() writes this value into a freshly created state file.
DEF_VERS: str = '0.0.46'
class NoOpRedirHandler(HTTPErrorProcessor):
    """
    Response processor that returns redirect responses to the caller instead
    of letting urllib follow them.

    Because it subclasses HTTPErrorProcessor, passing it to ``build_opener``
    replaces the default processor. Redirect responses (with their
    ``Location`` header) are then handed back untouched. Every other status
    gets the stock HTTPErrorProcessor treatment.
    """

    # HTTP status codes that denote a redirect carrying a Location header.
    REDIRECT_CODES = frozenset((301, 302, 303, 307, 308))

    def http_response(self, request: Request, response: HTTPResponse) -> Any:
        if response.status in self.REDIRECT_CODES:
            return response
        return super().http_response(request, response)

    # Apply the same handling to https:// URLs; previously only HTTPS was
    # overridden, so a plain-http base URL would still follow the redirect.
    https_response = http_response
def get_args(argv: Union[list, None] = None):
    """
    Parse the command-line options.

    :param argv: Argument list to parse, without the program name. The
        default ``None`` makes argparse read ``sys.argv[1:]``, exactly as
        before. Pass a list to parse arguments programmatically.
    :returns: The parsed ``argparse.Namespace``.
    """
    p = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
    p.add_argument('-a', '--ua-string', default='curl/7.81.0',
                   help='Override the default user agent string with this')
    p.add_argument('-u', '--base-url',
                   default='https://discord.com/api/download?platform=linux&format=deb',
                   help='The URL to fetch the location of the current download')
    p.add_argument('-f', '--state-file', default='/var/tmp/discord.state',
                   help='The path to the file that contains the current version of '
                        'discord')
    p.add_argument('-D', '--debug', action='store_true', default=False,
                   help='Add debug output')
    return p.parse_args(argv)
def setup_logging(args):
    """
    Configure the root logger via ``logging.basicConfig``.

    The level is DEBUG when ``args.debug`` is truthy, otherwise INFO. Each
    record shows its timestamp, level, file:line and function name.
    """
    fmt = ('%(asctime)s - %(levelname)s - '
           '%(filename)s:%(lineno)d %(funcName)s - %(message)s')
    if args.debug:
        level = logging.DEBUG
    else:
        level = logging.INFO
    logging.basicConfig(format=fmt, level=level)
def write_vers(cur_vers: str, args) -> None:
    """
    Overwrite ``args.state_file`` with ``cur_vers``, stripped of surrounding
    whitespace.
    """
    path = args.state_file
    logging.debug(f'Writing vers to {path}: {cur_vers}')
    with open(path, 'w') as out:
        out.write(cur_vers.strip())
def get_cur_version(args) -> str:
    """
    Return the locally recorded Discord version.

    If ``args.state_file`` is a regular file, its whitespace-stripped
    contents are returned. Otherwise the file is created with ``DEF_VERS``
    and ``DEF_VERS`` is returned.
    """
    state_path = args.state_file
    if os.path.isfile(state_path):
        with open(state_path) as fh:
            stored = fh.read().strip()
        logging.debug(
            f'Found the current version in the state file: {stored}'
        )
        return stored

    logging.debug(f'State file {state_path} not found, writing '
                  f'default version to file: {DEF_VERS}')
    write_vers(DEF_VERS, args)
    return DEF_VERS
def get_vers_from_loc(url: str) -> Union[str, None]:
    """
    Extract the version number from a Discord ``.deb`` download location.

    :param url: A URL or path containing ``/discord-<version>.deb``.
    :returns: The dotted version string (e.g. ``'0.0.46'``), or ``None`` if
        no such file name is present.
    """
    # The '.' before "deb" is escaped so it matches only a literal dot;
    # unescaped it matched any character (e.g. "discord-0.0.46Xdeb").
    reg = re.compile(r'/discord-([\d.]+?)\.deb')
    if m := reg.search(url):
        vers = m.group(1)
        logging.debug(f'Remote version: {vers}')
        return vers
    return None
def download(url: str, args) -> str:
    """
    Download ``url`` into a new temporary file and return that file's path.

    The temp file is created with ``delete=False``, so on success the caller
    owns it and must remove it. On any failure the partial file is deleted
    before the exception propagates.

    :param url: The URL to fetch.
    :param args: Parsed CLI args; ``args.ua_string`` is sent as User-Agent.
    :returns: Path of the temporary file holding the downloaded bytes.
    :raises RuntimeError: if the response reports a status above 299.
    """
    bytes_wrtn = 0
    req = Request(url, headers={'User-Agent': args.ua_string})
    with NamedTemporaryFile(delete=False) as tmp:
        try:
            # Use the response as a context manager so it is always closed.
            with urlopen(req) as res:
                # Responses for non-HTTP schemes (e.g. file://) may carry no
                # status; only compare when one is present.
                status = getattr(res, 'status', None)
                if status is not None and status > 299:
                    raise RuntimeError(
                        f'Failed to open "{url}": {res.status}/{res.reason}'
                    )
                while chunk := res.read(4096):
                    bytes_wrtn += len(chunk)
                    tmp.write(chunk)
        except BaseException:
            # delete=False means nothing else removes the partial download.
            tmp.close()
            try:
                os.unlink(tmp.name)
            except OSError:
                # Best-effort cleanup; the original error is re-raised below.
                pass
            raise
    logging.debug(f'Wrote {bytes_wrtn} to {tmp.name}')
    return tmp.name
def install(fname: str) -> None:
    """
    Install a local package file with ``/usr/bin/sudo /usr/bin/apt install``.

    :param fname: Path to the ``.deb`` file to install.
    :raises RuntimeError: if the apt command exits with a non-zero status.
        apt's stderr is logged first.
    """
    cmd = ['/usr/bin/sudo', '/usr/bin/apt', 'install', fname]
    logging.debug(f'Running cmd: {" ".join(cmd)}')
    # NOTE(review): no '-y' is passed, so any confirmation prompt from apt
    # reads from this process's stdin — confirm that is intended.
    # text mode so stderr is logged as readable text instead of a b'...'
    # repr; errors='replace' avoids crashing on undecodable output.
    p = sp.run(cmd, capture_output=True, text=True, errors='replace')
    if p.returncode != 0:
        logging.error(f'Failed to run cmd: {" ".join(cmd)}')
        logging.error(p.stderr.strip())
        raise RuntimeError('Failed to install package')
def chk_version(cur_vers: str, args) -> Union[str, None]:
"""
Returns either a string for the url to download a new version or None if
it matches the current version
"""
req = Request(
args.base_url,
headers={'User-Agent': args.ua_string},
method='HEAD',
)
# We don't want to actually follow the redirect here. We just want to
# parse the location.
opener = build_opener(NoOpRedirHandler)
try:
res = opener.open(req)
except Exception as e:
logging.exception(e)
return None
loc = res.getheader('location')
if loc is None:
logging.debug(f'Result headers: {res.getheaders()}')
raise RuntimeError(f'Failed to get latest version, no location header')
rem_vers = get_vers_from_loc(loc)
if rem_vers > cur_vers:
return loc
def rename(tmp_name: str, vers: str) -> str:
    """
    Move ``tmp_name`` to ``discord-<vers>.deb`` in the same directory.

    :param tmp_name: Path of the file to move.
    :param vers: Version string placed in the new file name.
    :returns: The destination path.
    """
    dest = os.path.join(
        os.path.dirname(tmp_name),
        'discord-{}.deb'.format(vers),
    )
    logging.debug(f'Moving {tmp_name} -> {dest}')
    os.rename(tmp_name, dest)
    return dest
def main() -> int:
    """
    Check for a newer Discord package and install it if one exists.

    :returns: 0 when no newer version exists or the install succeeded.
        1 when the download or install step raised (the error is logged).
    """
    args = get_args()
    setup_logging(args)

    installed = get_cur_version(args)
    latest_url = chk_version(installed, args)
    if latest_url is None:
        logging.debug(f'No version newer than {installed}')
        return 0

    latest = get_vers_from_loc(latest_url)

    try:
        downloaded = download(latest_url, args)
    except Exception as err:
        logging.exception(err)
        return 1

    deb_path = rename(downloaded, latest)
    try:
        install(deb_path)
    except Exception as err:
        logging.exception(err)
        return 1
    finally:
        # The package file is removed whether or not the install succeeded.
        os.unlink(deb_path)

    # Only record the new version once the install has succeeded.
    write_vers(latest, args)
    return 0
if __name__ == '__main__':
    # A Ctrl-C (KeyboardInterrupt) exits with status 0 instead of a traceback.
    try:
        exit_code = main()
    except KeyboardInterrupt:
        exit_code = 0
    sys.exit(exit_code)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment