Created
March 23, 2025 01:37
-
-
Save laksjdjf/52d01161412478a602dd6c8540335ad3 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import gradio as gr | |
import pandas as pd | |
import random | |
# --- Query cache (owned by get_images) ---------------------------------------
# Parsed tag sets of the most recent query and the rows they selected.
# All three stay None until get_images() runs for the first time.
query_general_cache = query_character_cache = None
target_df = None

# --- Current round ------------------------------------------------------------
# Expected side ("left"/"right", padded with None) for each question; written by
# sample_unique_items() and read by submit().
answers = None
# Tag string shown for each question of the current round.
tags = []

# --- Running score (updated by submit, cleared by reset) ----------------------
total = correct = 0
wrong_tags = []

# --- Dataset ------------------------------------------------------------------
# Post table; the code in this file reads its "id", "file_url", "tags" and
# "characters" columns. The CSV is downloaded at import time.
# To work offline, point read_csv at a local copy instead, e.g. "aes6_5.csv".
df = pd.read_csv("https://huggingface.co/datasets/furusu/aesthetic_score_danbooru2024/resolve/main/part/aes6_5.csv")
# Force both tag columns to str — presumably so that missing cells cannot break
# the later .split(",") calls (TODO confirm against the dataset).
df[["tags", "characters"]] = df[["tags", "characters"]].astype(str)
def sample_unique_items(list_a, list_b, sample_size=5):
    """Randomly pick items that occur in exactly one of the two lists.

    Items found only in ``list_a`` are labelled ``"left"``, items found only
    in ``list_b`` are labelled ``"right"``. Up to ``sample_size`` labelled
    items are drawn without replacement.

    Side effect: the module-level ``answers`` list is replaced with the labels
    of the drawn items, in the same order as the returned items.

    Args:
        list_a: Items belonging to the left side.
        list_b: Items belonging to the right side.
        sample_size: Length of the returned list.

    Returns:
        A list of exactly ``sample_size`` entries: the drawn items followed by
        ``None`` padding when fewer candidates than ``sample_size`` exist.
    """
    global answers

    # Left-only candidates first, then right-only ones, each in input order.
    candidates = []
    for element in list_a:
        if element not in list_b:
            candidates.append((element, 'left'))
    for element in list_b:
        if element not in list_a:
            candidates.append((element, 'right'))

    draw_count = min(sample_size, len(candidates))
    drawn = random.sample(candidates, draw_count)

    # Pad both parallel lists so callers always get sample_size entries.
    padding = [None] * (sample_size - draw_count)
    answers = [side for _, side in drawn] + padding
    return [element for element, _ in drawn] + padding
def has_all_required_tags(tag_string, required_tags):
    """Return True when every required tag occurs in a comma-separated tag string.

    Args:
        tag_string: Tags joined by commas, e.g. ``"1girl,solo"``. Whitespace
            around each tag is ignored, so ``"1girl, solo"`` is equivalent.
        required_tags: Iterable of tags that must all be present. An empty
            iterable always matches.

    Returns:
        bool: True if all of ``required_tags`` are present in ``tag_string``.
    """
    # Strip each tag: the query side is stripped before it gets here, so an
    # unstripped " solo" on the data side would otherwise never match "solo".
    present = {tag.strip() for tag in tag_string.split(',')}
    return present.issuperset(required_tags)
def _parse_tag_query(query):
    """Split a comma-separated query into a set of stripped, non-empty tags."""
    return {tag.strip() for tag in (query or "").split(',')} - {""}


def get_images(query_general, query_character):
    """Start a new round: pick two random matching posts and quiz tags for them.

    Rows of the module-level ``df`` are kept when their "tags" column contains
    every tag of ``query_general`` and their "characters" column contains every
    tag of ``query_character``. The filtered frame is cached in ``target_df``
    and reused for as long as both parsed queries stay the same.

    Side effects: updates ``query_general_cache``, ``query_character_cache``,
    ``target_df`` and ``tags``; ``sample_unique_items`` updates ``answers``.

    Args:
        query_general: Comma-separated general tags (may be empty).
        query_character: Comma-separated character tags (may be empty).

    Returns:
        A flat tuple for the Gradio outputs: the gallery items (file URL and
        post id as caption), one tag (or None) per question, one radio reset
        per question, and one empty result string per question.

    Raises:
        gr.Error: If fewer than two posts match the query.
    """
    global query_general_cache, query_character_cache, target_df, tags

    general_required = _parse_tag_query(query_general)
    character_required = _parse_tag_query(query_character)

    # Compare parsed sets with the cached parsed sets. Comparing the raw query
    # string against the cached set is never equal, which forces a full
    # re-filter of the DataFrame on every click.
    cache_is_stale = (
        target_df is None
        or general_required != query_general_cache
        or character_required != query_character_cache
    )
    if cache_is_stale:
        filtered = df
        if general_required:
            filtered = filtered[filtered["tags"].apply(
                lambda x: has_all_required_tags(x, general_required))]
        if character_required:
            filtered = filtered[filtered["characters"].apply(
                lambda x: has_all_required_tags(x, character_required))]
        target_df = filtered
        query_general_cache = general_required
        query_character_cache = character_required

    # DataFrame.sample(2) raises a bare ValueError on fewer than two rows;
    # report the problem to the user instead.
    if len(target_df) < 2:
        raise gr.Error("条件に合う画像が2枚未満だよ!タグを見直してね!")

    sample = target_df.sample(2)
    images = [(row["file_url"], str(row["id"])) for _, row in sample.iterrows()]

    left_tags = sample["tags"].iloc[0].split(",")
    right_tags = sample["tags"].iloc[1].split(",")
    sampled_tags = sample_unique_items(left_tags, right_tags)
    tags = sampled_tags

    # One radio reset and one cleared result per question slot.
    question_count = len(sampled_tags)
    return (
        images,
        *sampled_tags,
        *[gr.update(value=None, interactive=True) for _ in range(question_count)],
        *["" for _ in range(question_count)],
    )
def submit(*player_answers):
    """Grade the current round and update the running score.

    Each player answer is compared with the expected side stored in the
    module-level ``answers`` list; ``tags`` supplies the tag text for the
    feedback links. Questions left unanswered are not counted.

    Side effects: increments ``total`` and ``correct`` and appends missed tags
    to ``wrong_tags``.

    Args:
        *player_answers: One value per question: ``"left"``, ``"right"`` or
            ``None`` when nothing was selected.

    Returns:
        A tuple with one Markdown feedback string per question (empty for
        padded slots) followed by a Markdown summary of the running score.
    """
    global total, correct

    if answers is None:
        # No round has been generated yet, so there is nothing to grade.
        return (*("" for _ in player_answers), "まずは New を押してね!")

    output = []
    for player_answer, correct_answer, tag in zip(player_answers, answers, tags):
        if correct_answer is None:
            # Padded slot: the round had fewer distinct tags than questions.
            output.append("")
        elif player_answer is None:
            # Unanswered questions do not count towards the total.
            output.append(f"ちゃんと答えてね! [{tag}](https://danbooru.donmai.us/wiki_pages/{tag}?) ")
        elif player_answer == correct_answer:
            total += 1
            correct += 1
            output.append(f"**正解**だよ!おめでとう! 知ってると思うけど⇒[{tag}](https://danbooru.donmai.us/wiki_pages/{tag}?)")
        else:
            total += 1
            wrong_tags.append(tag)
            output.append(f"ばーか!**不正解**だよ!ちゃんと勉強しろ⇒ [{tag}](https://danbooru.donmai.us/wiki_pages/{tag}?)")

    # Guard the division: total is still 0 when nothing has been answered yet.
    accuracy = correct / total * 100 if total else 0.0
    info = f"**{total}** 問中 **{correct}** 問正解! 正解率は **{accuracy:.2f}%** だよ!"
    if wrong_tags:
        # Join outside the f-string: a backslash inside an f-string expression
        # is a SyntaxError before Python 3.12.
        missed = "\n".join(wrong_tags)
        info += f"\n間違えたタグは:\n {missed} だよ!"
    return (*output, info)
def reset():
    """Clear the running score.

    Sets the module-level ``total`` and ``correct`` counters back to zero and
    rebinds ``wrong_tags`` to a fresh empty list.

    Returns:
        str: An empty string, which the Reset button writes to the info panel.
    """
    global total, correct, wrong_tags
    total = correct = 0
    wrong_tags = list()
    return ""
# Build the quiz UI: two query boxes, a two-image gallery, five question rows
# (left/right radio + tag text, each followed by a result line) and the
# New / Submit / Reset buttons.
# NOTE(review): the original indentation of this section was lost; the nesting
# below is reconstructed — confirm the intended layout.
with gr.Blocks() as demo:
    with gr.Row():
        query_general = gr.Textbox("1girl, solo", label="query_general")
        query_character = gr.Textbox("", label="query_character")
    images = gr.Gallery(columns=2)

    radios = []
    # Named tag_boxes rather than tags: the module-level name `tags` holds the
    # tag strings of the current round and must not be rebound to UI components.
    tag_boxes = []
    results = []
    # Five question slots; get_images returns five values per output group.
    for i in range(5):
        with gr.Row():
            # `scale` is documented as an integer; 3:7 keeps the 0.3:0.7 ratio.
            radios.append(gr.Radio(["left", "right"], label=f"radio_{i}", scale=3))
            tag_boxes.append(gr.Textbox("", label=f"tag_{i}", scale=7))
        results.append(gr.Markdown(""))

    with gr.Row():
        new_button = gr.Button("New", variant="primary")
        submit_button = gr.Button("Submit", variant="secondary")
        reset_button = gr.Button("Reset")
    info = gr.Markdown("")

    # New: load two images, fill the tag boxes, clear radios and results.
    new_button.click(
        get_images,
        inputs=[query_general, query_character],
        outputs=[images] + tag_boxes + radios + results,
    )
    # Submit: grade the selected radios and show per-question feedback + score.
    submit_button.click(
        submit,
        inputs=radios,
        outputs=results + [info],
    )
    # Reset: clear the running score and blank the score line.
    reset_button.click(
        reset,
        outputs=[info],
    )

demo.launch()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment