Last active
December 14, 2020 13:10
-
-
Save peolic/398c43fd673a254e58edc2ce4bda9a9a to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
"""Manually submit fingerprints to StashDB.org""" | |
import argparse | |
import hashlib | |
import json | |
import math | |
import os | |
import struct | |
import subprocess | |
import sys | |
from functools import partial | |
from pathlib import Path | |
from textwrap import dedent | |
from typing import Any, Callable, Dict, List, Literal, Optional | |
import requests | |
import yaml | |
# Directory this script reads Stash's configuration (and default database) from.
STASH_APP_PATH: Path = Path.home().joinpath('.stash')
CONFIG_PATH = STASH_APP_PATH.joinpath('config.yml')

# Hash algorithms handled by this script; ALGORITHMS fixes the order in which
# they are computed, printed and submitted.
ALGORITHMS = ('OSHASH', 'MD5')
# ALGORITHMS = ('OSHASH',)
AnyHashAlgorithm = Literal['OSHASH', 'MD5']
class Config:
    """The subset of Stash's ``config.yml`` that this script uses.

    Attributes:
        database: Value of the ``database`` key; when that key is missing or
            empty, ``stash-go.sqlite`` inside ``STASH_APP_PATH`` is used.
        stash_boxes: Value of the ``stash_boxes`` key, or an empty list when
            the key is missing or null.
    """

    def __init__(self, cfg: Dict[str, Any]) -> None:
        # yaml.safe_load() returns None for an empty document; treat that like
        # an empty mapping instead of failing on ``None.get``.
        cfg = cfg or {}
        self.database: str = cfg.get('database') or str(STASH_APP_PATH / 'stash-go.sqlite')
        # ``or []`` also covers an explicit ``stash_boxes: null`` entry, which
        # ``cfg.get('stash_boxes', [])`` would hand back as None.
        self.stash_boxes: List[Dict[str, str]] = cfg.get('stash_boxes') or []

    @classmethod
    def read(cls, path: Path) -> 'Config':
        """Parse the YAML file at *path* and build a ``Config`` from it."""
        # Explicit encoding so the result does not depend on the platform default.
        with path.open(encoding='utf-8') as fh:
            return cls(yaml.safe_load(fh))
def md5(fname: str) -> str:
    """Return the hexadecimal MD5 digest of the file at *fname*.

    The file is consumed in 10 MiB pieces, so it is never loaded whole.
    """
    # Note: read size was raised from 8192 to 1024 * 1024 * 10 bytes.
    read_size = 1024 * 1024 * 10
    digest = hashlib.md5()
    with open(fname, 'rb') as stream:
        while True:
            piece = stream.read(read_size)
            if piece == b'':
                break
            digest.update(piece)
    return digest.hexdigest()
def oshash(fname: str) -> str:
    """Return the OSHash of the file at *fname* as 16 lowercase hex digits.

    The hash is the file size plus the sum of the first 64 KiB and the last
    64 KiB of the file, each interpreted as little-endian 64-bit integers,
    truncated to 64 bits.

    Raises:
        ValueError: with the message ``'SizeError'`` when the file is smaller
            than two 64 KiB blocks. (``ValueError`` is still an ``Exception``,
            so handlers written for the previous bare ``Exception`` keep
            working.)
        struct.error: if a block cannot be read in full.
    """
    block_size = 65536
    words_per_block = block_size // struct.calcsize('<q')
    # One format describing a whole block of little-endian long longs, so each
    # block needs a single read and a single unpack instead of 8192 of each.
    block_format = '<%dq' % words_per_block

    filesize = os.path.getsize(fname)
    if filesize < block_size * 2:
        raise ValueError('SizeError')

    hash_value = filesize
    with open(fname, 'rb') as fh:
        hash_value += sum(struct.unpack(block_format, fh.read(block_size)))
        # The size check above guarantees this offset is not negative.
        fh.seek(filesize - block_size, 0)
        hash_value += sum(struct.unpack(block_format, fh.read(block_size)))

    # Masking once at the end equals masking after every addition, because the
    # arithmetic is modulo 2**64 either way.
    return '%016x' % (hash_value & 0xFFFFFFFFFFFFFFFF)
class Video:
    """A local video file whose hashes and duration are computed on demand.

    Each hash and the duration are computed at most once per instance and then
    served from the instance's cache.
    """

    # Maps every supported algorithm name to the function that hashes a path.
    HASH_FUNCTIONS: Dict[AnyHashAlgorithm, Callable[[str], str]] = {
        'MD5': md5,
        'OSHASH': oshash,
    }

    def __init__(self, file_path: Path) -> None:
        self.file_path: Path = file_path
        # One slot per supported algorithm; None means "not computed yet".
        self._hash_cache: Dict[AnyHashAlgorithm, Optional[str]] = dict.fromkeys(self.HASH_FUNCTIONS)
        self._duration: Optional[int] = None

    def hash(self, algorithm: AnyHashAlgorithm) -> str:
        """Return the file's hash for *algorithm*, computing it on first use.

        Raises:
            KeyError: if *algorithm* is not a key of ``HASH_FUNCTIONS``.
        """
        cached = self._hash_cache[algorithm]
        if cached is not None:
            return cached

        result = self.HASH_FUNCTIONS[algorithm](str(self.file_path))
        self._hash_cache[algorithm] = result
        return result

    @property
    def duration(self) -> int:
        """Duration in whole seconds as reported by ``ffprobe``.

        Raises:
            subprocess.CalledProcessError: if ``ffprobe`` exits with an error.
            KeyError: if the probe output has no ``format``/``duration`` entry.
        """
        # Compare against None: a legitimately computed duration of 0 must
        # count as cached, otherwise ffprobe is re-run on every access.
        if self._duration is None:
            probe_result = subprocess.check_output([
                'ffprobe',
                '-v', 'quiet',
                '-print_format', 'json',
                '-show_format',
                '-show_streams',
                '-show_error',
                # Plain string: not every supported Python/platform combination
                # accepts Path objects in an argument list.
                str(self.file_path),
            ])
            data = json.loads(probe_result)
            fmt = data['format']
            # Round to two decimals first, then to whole seconds with halves
            # going away from zero.
            duration = round(float(fmt['duration']) * 100) / 100
            if duration >= 0.0:
                self._duration = math.floor(duration + 0.5)
            else:
                self._duration = math.ceil(duration - 0.5)
        return self._duration

    def as_fingerprint_submission(self, scene_id: str, algorithm: AnyHashAlgorithm) -> 'FingerprintSubmission':
        """Build the submission payload for this file, *scene_id* and *algorithm*.

        Raises:
            ValueError: if the hash is empty or the duration is zero.
        """
        hash_value = self.hash(algorithm)
        if not hash_value:
            raise ValueError('Empty hash_value')

        duration = self.duration
        if not duration:
            raise ValueError('Empty duration')

        return FingerprintSubmission(
            scene_id=scene_id,
            fingerprint=FingerprintInput(
                algorithm=algorithm,
                hash=hash_value,
                duration=duration,
            ),
        )
class FingerprintInput(dict):
    """Fingerprint payload: a dict whose three entries are also attributes.

    Holds ``algorithm``, ``hash`` and ``duration`` both as dictionary keys and
    as instance attributes of the same names.
    """

    def __init__(self, algorithm: AnyHashAlgorithm, hash: str, duration: int) -> None:
        # Fill the mapping in one go, then mirror each entry as an attribute.
        super().__init__(algorithm=algorithm, hash=hash, duration=duration)
        self.algorithm = algorithm
        self.hash = hash
        self.duration = duration
class FingerprintSubmission(dict):
    """Submission payload: a scene ID paired with one fingerprint.

    ``scene_id`` and ``fingerprint`` are stored both as dictionary keys and as
    instance attributes of the same names.
    """

    def __init__(self, scene_id: str, fingerprint: FingerprintInput) -> None:
        # Fill the mapping in one go, then mirror each entry as an attribute.
        super().__init__(scene_id=scene_id, fingerprint=fingerprint)
        self.scene_id = scene_id
        self.fingerprint = fingerprint
class SubmitFingerprint():
    """The ``submitFingerprint`` GraphQL mutation.

    ``str(instance)`` yields the mutation document; ``variables`` builds the
    matching variables mapping.
    """

    def __init__(self) -> None:
        self.operation_name: str = 'SubmitFingerprint'
        self.mutation_name: str = 'submitFingerprint'

    @staticmethod
    def variables(input: FingerprintSubmission):
        """Wrap *input* as the ``input`` variable of the mutation."""
        return {'input': input}

    def __str__(self) -> str:
        # The document contains no other percent signs, so %-substitution of
        # the two names is safe.
        document = dedent("""\
            mutation %s($input: FingerprintSubmission!) {
                %s(input: $input)
            }\
        """)
        return document % (self.operation_name, self.mutation_name)
class QueryScene():
    """The ``findScene`` GraphQL query.

    ``str(instance)`` yields the query document; ``variables`` builds the
    matching variables mapping. ``mutation_name`` holds the name of the query
    field, and is also the key of the scene in the response data.
    """

    def __init__(self) -> None:
        self.operation_name: str = 'Scene'
        self.mutation_name: str = 'findScene'

    @staticmethod
    def variables(scene_id: str):
        """Wrap *scene_id* as the ``id`` variable of the query."""
        return {'id': scene_id}

    def __str__(self) -> str:
        # https://github.com/stashapp/stash-box/blob/35979fa4365436b0669e66e4dfe6478e2e83ca0c/frontend/src/queries/Scene.gql
        # The document contains no other percent signs, so %-substitution of
        # the two names is safe.
        document = dedent("""\
            query %s($id: ID!) {
                %s(id: $id) {
                    id
                    date
                    title,
                    fingerprints {
                        hash
                        algorithm
                        duration
                    }
                    studio {
                        name
                    }
                    performers {
                        performer {
                            name
                            disambiguation
                            id
                            gender
                            aliases
                        }
                    }
                    urls {
                        url,
                        type
                    }
                    images {
                        url
                    }
                }
            }\
        """)
        return document % (self.operation_name, self.mutation_name)
class StashDB:
    """Small GraphQL client for one stash-box endpoint.

    Every request is a POST made through a shared ``requests.Session`` that
    carries the ``ApiKey`` header.
    """

    def __init__(self, endpoint: str, apikey: str, name: str = '', timeout: Optional[float] = 30.0):
        """Create the client.

        Args:
            endpoint: URL that GraphQL requests are posted to.
            apikey: Value sent in the ``ApiKey`` request header.
            name: Label for the endpoint; stored but not used by this class.
            timeout: Seconds handed to ``requests`` for every call, so a
                stalled server cannot hang the script; ``None`` disables it.
        """
        self.endpoint = endpoint
        self.name = name
        self.timeout = timeout

        self.session = requests.Session()
        self.session.headers.update({
            'Content-Type': 'application/json',
            'ApiKey': apikey,
        })

    def _call(self, query: str, operation_name: str, variables: Optional[Dict[str, Any]] = None):
        """POST one GraphQL operation and return its ``data`` member.

        GraphQL errors are printed and turn the result into ``None``; so does
        a response body that is not JSON.
        """
        json_data = {
            'operationName': operation_name,
            'query': query,
            'variables': variables or {},
        }

        response = self.session.post(self.endpoint, json=json_data, timeout=self.timeout)
        try:
            result = response.json()
        except ValueError:
            # E.g. an HTML error page from a proxy instead of a GraphQL reply.
            print(f'HTTP {response.status_code}: response is not valid JSON')
            return None

        if 'errors' in result:
            print('GraphQL Errors:')
            for error in result['errors']:
                if 'locations' in error:
                    for location in error['locations']:
                        print(f"At line {location['line']} column {location['column']}")
                    # ``extensions`` is optional in a GraphQL error, so the
                    # code is only shown when the server supplied one.
                    code = (error.get('extensions') or {}).get('code')
                    prefix = f'[{code}] ' if code else ''
                    print(f" {prefix}{error.get('message')}")
                elif 'path' in error:
                    print(f"At path /{'/'.join(map(str, error['path']))}: {error.get('message')}")
                else:
                    print(error)
            return None

        return result.get('data')

    def find_scene_by_id(self, scene_id: str):
        """Return the scene with *scene_id*, or ``None`` when the query fails.

        Raises:
            ValueError: if *scene_id* is empty.
        """
        if not scene_id:
            raise ValueError('Empty scene_id')

        query = QueryScene()
        results = self._call(
            query=str(query),
            operation_name=query.operation_name,
            variables=query.variables(scene_id),
        )
        if not results:
            print('Not found.')
            return None

        return results.get(query.mutation_name, None)

    def submit_fingerprints(self, fingerprints: List['FingerprintSubmission']):
        """Submit each fingerprint after asking for confirmation on stdin.

        Answering ``n`` skips an entry; any other answer submits it.
        Interrupting the prompt, or reaching end of input, cancels the
        remaining entries.
        """
        count = len(fingerprints)
        mutation = SubmitFingerprint()
        query = str(mutation)

        for idx, fingerprint in enumerate(fingerprints, 1):
            fp_data = fingerprint.fingerprint
            current = ' | '.join([fp_data.algorithm, fp_data.hash, str(fp_data.duration)])
            print(f'[{idx: 2}/{count: 2}] Current: {fingerprint.scene_id} [{current}]')

            try:
                answer = input('Submit? [Y/n] ').strip().lower()
            except (KeyboardInterrupt, SystemExit, EOFError):
                # EOFError: stdin was closed or is not interactive.
                print('Canceled.')
                return

            if answer == 'n':
                print('Skipped.')
                continue

            results = self._call(
                query=query,
                operation_name=mutation.operation_name,
                variables=mutation.variables(fingerprint),
            )
            if results:
                result = results.get(mutation.mutation_name, None)
                print(f'Result: {result}')
            else:
                print('Failed.')
class Arguments(argparse.Namespace):
    """Typed namespace that the parsed command line is stored in."""

    # Path of the video file (positional argument).
    video_path: str
    # Scene ID given on the command line, or None when it was omitted.
    stash_scene_id: Optional[str]
    # NOTE(review): no parser argument visible in this file fills ``data`` -
    # confirm whether it is still needed.
    data: bool
def _print_remote_scene(remote_data: Dict[str, Any], endpoint: str) -> None:
    """Print a readable summary of the scene returned by the remote query."""
    print('StashDB.org Data:')
    scene_url = endpoint.replace('/graphql', f'/scenes/{remote_data["id"]}')
    print(f' {"URL":<13}: {scene_url}')
    print(f' {"Title":<13}: {remote_data["title"]}')
    print(f' {"Date":<13}: {remote_data["date"]}')
    # A scene without a studio comes back as null; print None instead of crashing.
    studio = remote_data.get('studio') or {}
    print(f' {"Studio":<13}: {studio.get("name")}')

    performers = [p['performer'] for p in remote_data.get('performers') or []]
    if performers:
        print(f' {"Performers":<13}:')
        for pi, p in enumerate(performers, 1):
            name = p['name']
            if p.get('disambiguation'):
                name += ' (' + p['disambiguation'] + ')'
            if p.get('aliases'):
                name += ' [' + ', '.join(p['aliases']) + ']'
            # Gender may be null; show a placeholder initial in that case.
            initial = (p.get('gender') or '?')[0]
            print(f' {pi}. [{initial}] {name}')

    images = remote_data.get('images')
    if images:
        print(f' {"Images":<13}:')
        for ii, i in enumerate(images, 1):
            print(f' {ii}. {i["url"]}')

    remote_fps = remote_data.get('fingerprints')
    if remote_fps:
        print(f' {"Fingerprints":<13}:')
        for fpi, fp in enumerate(remote_fps, 1):
            print(f' {fpi}. [{fp["algorithm"]:<6}] {fp["hash"]} | Duration: {fp["duration"]}')


def run(args: Arguments):
    """Print the video's fingerprints and optionally submit them.

    The file name, every hash in ``ALGORITHMS`` and the duration are written
    to stderr, followed by one comma-separated line on stdout. When
    ``args.stash_scene_id`` is set, the scene is looked up on the configured
    stashdb.org endpoint, summarised, and every local fingerprint that the
    scene does not already list with the same duration is offered for
    submission.
    """
    cfg = Config.read(CONFIG_PATH)

    video = Video(
        file_path=Path(args.video_path).resolve(),
    )
    file_name = video.file_path.name

    print('File: '.ljust(10) + str(file_name), file=sys.stderr)
    for algorithm in ALGORITHMS:
        print(f'{algorithm}: '.ljust(10) + video.hash(algorithm), file=sys.stderr)
    print('Duration: '.ljust(10) + str(video.duration), file=sys.stderr)

    # Quote the name when it contains a space (as before) or a comma, which
    # would otherwise split the name across columns.
    needs_quotes = ' ' in file_name or ',' in file_name
    columns = [f'"{file_name}"' if needs_quotes else file_name, str(video.duration)]
    columns.extend(video.hash(algorithm) for algorithm in ALGORITHMS)
    print(','.join(columns))

    stash_id = args.stash_scene_id
    if not stash_id:
        return

    print()

    # Default of None: report a missing endpoint instead of letting
    # StopIteration escape as a traceback.
    stashdb = next(
        (
            sbc for sbc in cfg.stash_boxes or []
            if 'stashdb.org/' in sbc.get('endpoint', '')
        ),
        None,
    )
    if stashdb is None:
        print('No stashdb.org endpoint found in the stash_boxes configuration.', file=sys.stderr)
        return

    # Pass only the keys the client takes, so additional keys in the config
    # entry cannot break the constructor call.
    client = StashDB(
        endpoint=stashdb['endpoint'],
        apikey=stashdb['apikey'],
        name=stashdb.get('name', ''),
    )

    remote_data = client.find_scene_by_id(stash_id)
    if not remote_data:
        print('StashDB.org scene not found!')
        return

    _print_remote_scene(remote_data, stashdb['endpoint'])
    print()

    # algorithm -> {hash: duration} for everything the scene already has.
    # Built with setdefault so remote algorithms outside ALGORITHMS are accepted.
    remote_fingerprints: Dict[str, Dict[str, Any]] = {}
    for fp in remote_data.get('fingerprints') or []:
        remote_fingerprints.setdefault(fp['algorithm'], {})[fp['hash']] = fp['duration']

    fingerprints = []
    for algorithm in ALGORITHMS:
        fp_submission = video.as_fingerprint_submission(stash_id, algorithm)
        fp = fp_submission.fingerprint
        # A missing hash yields None, which never equals the (non-zero) duration.
        if remote_fingerprints.get(fp.algorithm, {}).get(fp.hash) == fp.duration:
            print(f'WARNING: [{fp.algorithm} | {fp.hash} | {fp.duration}] already exists')
            continue
        fingerprints.append(fp_submission)

    client.submit_fingerprints(fingerprints)
def main():
    """Parse the command line and hand the result to ``run``."""
    cli = argparse.ArgumentParser()
    # Both arguments are positional; the scene ID may be left out.
    positionals = (
        ('video_path', {'help': 'Video file'}),
        ('stash_scene_id', {'nargs': '?', 'help': 'Stash scene ID'}),
    )
    for dest, options in positionals:
        cli.add_argument(dest, **options)

    parsed = cli.parse_args(namespace=Arguments())
    run(parsed)


if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment