Skip to content

Instantly share code, notes, and snippets.

@laksjdjf
Created March 23, 2025 01:37
Show Gist options
  • Save laksjdjf/52d01161412478a602dd6c8540335ad3 to your computer and use it in GitHub Desktop.
Save laksjdjf/52d01161412478a602dd6c8540335ad3 to your computer and use it in GitHub Desktop.
import gradio as gr
import pandas as pd
import random
# --- Module-level state shared by the Gradio callbacks ---

# Parsed query tag sets from the last filtering pass, and the rows of `df`
# that matched them (all written by get_images).
query_general_cache = None
query_character_cache = None
target_df = None

# Quiz data, fetched over the network at import time. To run offline,
# download the CSV and point DATASET_URL at the local copy (e.g. "aes6_5.csv").
DATASET_URL = "https://huggingface.co/datasets/furusu/aesthetic_score_danbooru2024/resolve/main/part/aes6_5.csv"
df = pd.read_csv(DATASET_URL)
# The tag columns are later split with str.split(","), so force them to str
# (missing cells become the string "nan").
df = df.astype({"tags": str, "characters": str})

# Current round: the correct side per quiz slot ("left"/"right"/None, written
# by sample_unique_items) and the sampled tag strings (written by get_images).
answers = None
tags = []

# Running score since the last reset.
total = 0
correct = 0
wrong_tags = []
def sample_unique_items(list_a, list_b, sample_size=5):
    """Pick up to *sample_size* items that appear in only one of the two lists.

    Every candidate is labelled with the side it came from: 'left' for
    *list_a*, 'right' for *list_b*. A random subset of the candidates is
    chosen; the chosen items are returned and their sides are stored in the
    global ``answers``. Both lists are padded with None to exactly
    *sample_size* entries, in matching order.

    Args:
        list_a: items of the left image (must be hashable).
        list_b: items of the right image (must be hashable).
        sample_size: number of quiz slots to fill.

    Returns:
        A list of length *sample_size*: sampled items, then None padding.
    """
    global answers
    # Build membership sets once; `x not in some_list` would rescan the
    # other list for every item (O(n*m)).
    set_a = set(list_a)
    set_b = set(list_b)
    candidates = [(item, "left") for item in list_a if item not in set_b]
    candidates += [(item, "right") for item in list_b if item not in set_a]

    sampled = random.sample(candidates, min(sample_size, len(candidates)))
    padding = [None] * (sample_size - len(sampled))
    answers = [side for _, side in sampled] + padding
    return [item for item, _ in sampled] + padding
def has_all_required_tags(tag_string, required_tags):
    """Return True if every tag in *required_tags* occurs in *tag_string*.

    Args:
        tag_string: comma-separated tag list from a DataFrame cell.
        required_tags: tags that must all be present; any iterable of
            strings (an empty one matches every row).

    Whitespace around each cell tag is stripped, mirroring the stripping
    applied to the query tags, so "a, b" and "a,b" are treated alike.
    """
    present = {tag.strip() for tag in tag_string.split(",")}
    return present.issuperset(required_tags)
def get_images(query_general, query_character):
    """Draw two random images matching the queries and start a new quiz round.

    Args:
        query_general: comma-separated tags that the "tags" column must contain.
        query_character: comma-separated tags that the "characters" column
            must contain. Blank entries in either query are ignored.

    Returns:
        A flat tuple for the "New" button outputs: the gallery items
        (file_url, id-as-string caption) for both images, the sampled tags
        (None for unused slots), one ``gr.update`` per radio that clears and
        re-enables it, and one empty string per result Markdown.

    Raises:
        gr.Error: if fewer than two images match the queries.

    Side effects: refreshes the query caches and ``target_df`` when the
    parsed queries change, and stores the sampled tags in the global ``tags``.
    """
    global query_general_cache, query_character_cache, target_df, tags

    general = {tag.strip() for tag in query_general.split(",")} - {""}
    character = {tag.strip() for tag in query_character.split(",")} - {""}

    # Compare the parsed query sets with the cached parsed sets, so clicking
    # "New" again with an unchanged query reuses target_df instead of
    # re-filtering the whole DataFrame.
    if target_df is None or general != query_general_cache or character != query_character_cache:
        filtered = df
        if general:
            filtered = filtered[filtered["tags"].apply(lambda cell: has_all_required_tags(cell, general))]
        if character:
            filtered = filtered[filtered["characters"].apply(lambda cell: has_all_required_tags(cell, character))]
        # Commit the cache only after filtering succeeded.
        target_df = filtered
        query_general_cache = general
        query_character_cache = character

    # DataFrame.sample(2) raises ValueError on fewer than two rows; surface a
    # readable message in the UI instead.
    if len(target_df) < 2:
        raise gr.Error(f"条件に合う画像が {len(target_df)} 枚しかないよ!クエリを変えてね")

    sample = target_df.sample(2)
    images = [(row["file_url"], str(row["id"])) for _, row in sample.iterrows()]
    left_tags = sample["tags"].iloc[0].split(",")
    right_tags = sample["tags"].iloc[1].split(",")
    sampled_tags = sample_unique_items(left_tags, right_tags)
    tags = sampled_tags
    return (
        images,
        *sampled_tags,
        *[gr.update(value=None, interactive=True) for _ in range(5)],
        *["" for _ in range(5)],
    )
def submit(*player_answers):
    """Score the player's answers for the current round and report running stats.

    Args:
        *player_answers: one value per quiz slot from the radio buttons,
            "left", "right", or None when the slot was left unanswered.

    Returns:
        A tuple with one Markdown string per slot, followed by the summary
        Markdown (running score and the list of wrongly answered tags).

    Slots without a question this round (correct answer None) yield "".
    Unanswered slots are reported but not counted toward ``total``.
    """
    global total, correct, wrong_tags

    # Submit pressed before "New" has drawn any round: nothing to score.
    if answers is None:
        return (*["" for _ in player_answers], "先に New を押して問題を出してね!")

    output = []
    for player_answer, correct_answer, tag in zip(player_answers, answers, tags):
        if correct_answer is None:
            output.append("")
            continue
        link = f"[{tag}](https://danbooru.donmai.us/wiki_pages/{tag}?)"
        if player_answer is None:
            output.append(f"ちゃんと答えてね! {link} ")
            continue
        total += 1
        if player_answer == correct_answer:
            correct += 1
            output.append(f"**正解**だよ!おめでとう! 知ってると思うけど⇒{link}")
        else:
            wrong_tags.append(tag)
            output.append(f"ばーか!**不正解**だよ!ちゃんと勉強しろ⇒ {link}")

    # Guard against ZeroDivisionError when nothing has been answered yet.
    rate = correct / total * 100 if total else 0.0
    info = f"**{total}** 問中 **{correct}** 問正解! 正解率は **{rate:.2f}%** だよ!"
    if wrong_tags:
        # Join outside the f-string: a backslash inside an f-string
        # expression is a SyntaxError before Python 3.12.
        wrong_list = "\n".join(wrong_tags)
        info += f"\n間違えたタグは:\n {wrong_list} だよ!"
    return (*output, info)
def reset():
    """Start the score over: zero both counters and forget wrongly answered tags.

    Returns:
        The empty string, used to blank out the summary Markdown.
    """
    global total, correct, wrong_tags
    total = correct = 0
    wrong_tags = []
    return ""
# Build the quiz UI: two query boxes, a two-column gallery, five quiz slots
# (side picker, tag text, verdict) and the New / Submit / Reset buttons.
with gr.Blocks() as demo:
    with gr.Row():
        query_general = gr.Textbox("1girl, solo", label="query_general")
        query_character = gr.Textbox("", label="query_character")
    images = gr.Gallery(columns=2)

    # These component lists are deliberately NOT named `tags`: at module scope
    # that would rebind the global list of sampled tag strings that submit()
    # reads, leaving Textbox components in it until the first "New".
    radios = []
    tag_boxes = []
    results = []
    for i in range(5):
        with gr.Row():
            # Gradio's `scale` should be an integer; 3:7 keeps the 30%/70% split.
            radio = gr.Radio(["left", "right"], label=f"radio_{i}", scale=3)
            tag = gr.Textbox("", label=f"tag_{i}", scale=7)
        radios.append(radio)
        tag_boxes.append(tag)
        # NOTE(review): indentation was lost in the source; the verdict
        # Markdown is placed under each row here — confirm intended layout.
        result = gr.Markdown("")
        results.append(result)

    with gr.Row():
        new_button = gr.Button("New", variant="primary")
        submit_button = gr.Button("Submit", variant="secondary")
        reset_button = gr.Button("Reset")
    info = gr.Markdown("")

    new_button.click(
        get_images,
        inputs=[query_general, query_character],
        outputs=[images] + tag_boxes + radios + results,
    )
    submit_button.click(
        submit,
        inputs=radios,
        outputs=results + [info],
    )
    reset_button.click(
        reset,
        outputs=[info],
    )

demo.launch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment