ModernBERT_JP_CAUSE-EFFECT.ipynb
{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "name": "ModernBERT_JP_CAUSE-EFFECT.ipynb",
      "authorship_tag": "ABX9TyOwl6lTU9stRHzbPBEXnsF4",
      "include_colab_link": true
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "view-in-github",
        "colab_type": "text"
      },
      "source": [
        "<a href=\"https://colab.research.google.com/gist/chottokun/af798278c56d46109c16d27aa64383ca/cause-effect.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "f1c542d9"
      },
      "source": [
| "# コンテキスト指向 NER パイプライン\n", | |
| "\n", | |
| "このノートブックでは、`cl-nagoya/ruri-v3-pt-30m` モデルを使用して、文脈に基づいた固有表現抽出(Context NER)を行う一連の流れを実装しています。具体的には、以下の3つのステップで構成されています。\n", | |
| "\n", | |
| "1. **データセット準備 (Dataset Preparation)**:\n", | |
| " 「原因 (CAUSE)」と「結果 (EFFECT)」の関係を含む日本語のトイデータを作成し、BIOタグ形式でアノテーションを行います。\n", | |
| "\n", | |
| "2. **ファインチューニング (Fine-tuning)**:\n", | |
| " 作成したデータセットを用いて Ruri-v3 モデルを学習させます。データ量が少ないため、パターンを確実に学習(過学習)させるために20エポック回します。\n", | |
| "\n", | |
| "3. **推論 (Inference)**:\n", | |
| " 学習済みモデルを使用して、新しい文章(例:「システム障害によりサービスが停止した」)から原因と結果を抽出できるか検証します。" | |
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "123acae6"
      },
      "source": [
| "# --- 0. 環境構築 (Setup) ---\n", | |
| "# 必要なライブラリをインストール(既にインストール済みの場合はスキップされます)\n", | |
| "!pip install -U \"transformers>=4.48.0\" accelerate datasets seqeval\n", | |
| "\n", | |
| "import os\n", | |
| "# W&B (Weights & Biases) を無効化\n", | |
| "os.environ[\"WANDB_DISABLED\"] = \"true\"\n", | |
| "\n", | |
| "import logging\n", | |
| "import warnings\n", | |
| "import torch\n", | |
| "from datasets import Dataset\n", | |
| "from transformers import (\n", | |
| " AutoTokenizer,\n", | |
| " AutoModelForTokenClassification,\n", | |
| " TrainingArguments,\n", | |
| " Trainer,\n", | |
| " DataCollatorForTokenClassification,\n", | |
| " pipeline,\n", | |
| " logging as transformers_logging\n", | |
| ")\n", | |
| "\n", | |
| "# ログ出力を整理して見やすく設定\n", | |
| "warnings.filterwarnings(\"ignore\")\n", | |
| "transformers_logging.set_verbosity_error()\n", | |
| "logging.getLogger(\"huggingface_hub\").setLevel(logging.ERROR)\n", | |
| "\n", | |
| "# --- 1. データセットの準備 (Data Preparation) ---\n", | |
| "print(\"1. データセットを準備しています...\")\n", | |
| "\n", | |
| "# ラベルの定義 (0:その他, 1:原因開始, 2:原因継続, 3:結果開始, 4:結果継続)\n", | |
| "id2label = {0: 'O', 1: 'B-CAUSE', 2: 'I-CAUSE', 3: 'B-EFFECT', 4: 'I-EFFECT'}\n", | |
| "label2id = {v: k for k, v in id2label.items()}\n", | |
| "\n", | |
| "# サンプルデータ(ここを独自のデータに差し替えることで再学習可能)\n", | |
| "japanese_toy_data = [\n", | |
| " {\"tokens\": [\"大雨\", \"により\", \"川\", \"が\", \"氾濫\", \"した\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"システム\", \"障害\", \"の\", \"ため\", \"、\", \"サービス\", \"が\", \"停止\", \"している\", \"。\"], \"ner_tags\": [1, 2, 0, 0, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"寝不足\", \"は\", \"集中力\", \"低下\", \"を\", \"招く\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"インフレ\", \"によって\", \"物価\", \"が\", \"上昇\", \"した\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"渋滞\", \"の\", \"せいで\", \"到着\", \"が\", \"遅れた\", \"。\"], \"ner_tags\": [1, 0, 0, 3, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"火花\", \"が\", \"火災\", \"の\", \"原因\", \"と\", \"なった\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"過労\", \"が\", \"健康\", \"悪化\", \"に\", \"つながる\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"メモリ\", \"リーク\", \"により\", \"アプリ\", \"が\", \"クラッシュ\", \"した\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"地球\", \"温暖化\", \"は\", \"海面\", \"上昇\", \"を\", \"引き起こす\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"ストレス\", \"で\", \"胃\", \"が\", \"痛く\", \"なる\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"構造\", \"的\", \"欠陥\", \"により\", \"橋\", \"が\", \"崩落\", \"した\", \"。\"], \"ner_tags\": [1, 2, 2, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"猛暑\", \"のため\", \"作物\", \"が\", \"枯れた\", \"。\"], \"ner_tags\": [1, 0, 3, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"停電\", \"は\", \"落雷\", \"が\", \"原因\", \"だった\", \"。\"], \"ner_tags\": [3, 0, 1, 0, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"栄養\", \"不足\", \"は\", \"免疫力\", \"低下\", \"を\", \"招く\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"古い\", \"ドライバ\", \"を\", \"使う\", \"と\", \"動作\", \"が\", \"不安定\", \"に\", \"なる\", \"。\"], \"ner_tags\": [1, 2, 0, 0, 0, 3, 4, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"政治\", \"的\", \"不安\", \"が\", \"株価\", \"下落\", \"を\", \"招いた\", \"。\"], \"ner_tags\": [1, 2, 2, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"ウイルス\", \"感染\", \"により\", \"データ\", \"が\", \"消失\", \"した\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 4, 0]},\n", | |
| " {\"tokens\": [\"需要\", \"急増\", \"が\", \"供給\", \"不足\", \"に\", \"つながった\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"政策\", \"変更\", \"は\", \"抗議\", \"デモ\", \"を\", \"引き起こした\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n", | |
| " {\"tokens\": [\"大きな\", \"音\", \"に\", \"猫\", \"が\", \"驚いた\", \"。\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 0]}\n", | |
| "]\n", | |
| "\n", | |
| "dataset = Dataset.from_list(japanese_toy_data)\n", | |
| "\n", | |
| "# --- 2. モデルとトークナイザーの準備 (Model Loading) ---\n", | |
| "model_id = \"cl-nagoya/ruri-v3-pt-30m\"\n", | |
| "print(f\"{model_id} をロード中...\")\n", | |
| "\n", | |
| "tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n", | |
| "model = AutoModelForTokenClassification.from_pretrained(\n", | |
| " model_id,\n", | |
| " num_labels=len(id2label),\n", | |
| " id2label=id2label,\n", | |
| " label2id=label2id,\n", | |
| " trust_remote_code=True\n", | |
| ")\n", | |
| "\n", | |
| "# データの前処理(トークン化とラベル位置合わせ)\n", | |
| "def tokenize_and_align_labels(examples):\n", | |
| " tokenized_inputs = tokenizer(\n", | |
| " examples[\"tokens\"],\n", | |
| " truncation=True,\n", | |
| " is_split_into_words=True,\n", | |
| " max_length=8192 # Ruri-v3の長文脈対応能力\n", | |
| " )\n", | |
| " labels = []\n", | |
| " for i, label in enumerate(examples[\"ner_tags\"]):\n", | |
| " word_ids = tokenized_inputs.word_ids(batch_index=i)\n", | |
| " previous_word_idx = None\n", | |
| " label_ids = []\n", | |
| " for word_idx in word_ids:\n", | |
| " if word_idx is None:\n", | |
| " label_ids.append(-100) # 特殊トークンは無視\n", | |
| " elif word_idx != previous_word_idx:\n", | |
| " label_ids.append(label[word_idx]) # 単語の先頭\n", | |
| " else:\n", | |
| " label_ids.append(-100) # 単語の途中\n", | |
| " previous_word_idx = word_idx\n", | |
| " labels.append(label_ids)\n", | |
| " tokenized_inputs[\"labels\"] = labels\n", | |
| " return tokenized_inputs\n", | |
| "\n", | |
| "tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)\n", | |
| "\n", | |
| "# --- 3. ファインチューニング (Training) ---\n", | |
| "print(\"\\n2. モデルの学習を開始します...\")\n", | |
| "\n", | |
| "training_args = TrainingArguments(\n", | |
| " output_dir=\"ruri_final_demo\",\n", | |
| " learning_rate=5e-5,\n", | |
| " per_device_train_batch_size=4,\n", | |
| " num_train_epochs=20, # データが少ないため多めに設定(本番データでは3-5程度でOK)\n", | |
| " weight_decay=0.01,\n", | |
| " eval_strategy=\"epoch\",\n", | |
| " save_strategy=\"epoch\",\n", | |
| " logging_steps=10,\n", | |
| " report_to=\"none\" # WandBを無効化\n", | |
| ")\n", | |
| "\n", | |
| "trainer = Trainer(\n", | |
| " model=model,\n", | |
| " args=training_args,\n", | |
| " train_dataset=tokenized_dataset,\n", | |
| " eval_dataset=tokenized_dataset,\n", | |
| " processing_class=tokenizer,\n", | |
| " data_collator=DataCollatorForTokenClassification(tokenizer)\n", | |
| ")\n", | |
| "\n", | |
| "trainer.train()\n", | |
| "trainer.save_model(\"./ruri_final_demo\")\n", | |
| "print(\"学習完了。モデルを保存しました。\")\n", | |
| "\n", | |
| "# --- 4. 推論の検証 (Inference) ---\n", | |
| "print(\"\\n3. 推論テストを実行します...\")\n", | |
| "\n", | |
| "device = 0 if torch.cuda.is_available() else -1\n", | |
| "ner_pipeline = pipeline(\n", | |
| " \"ner\",\n", | |
| " model=\"./ruri_final_demo\",\n", | |
| " tokenizer=\"./ruri_final_demo\",\n", | |
| " aggregation_strategy=\"simple\",\n", | |
| " trust_remote_code=True,\n", | |
| " device=device\n", | |
| ")\n", | |
| "\n", | |
| "# テスト用入力文\n", | |
| "text = \"システム障害によりサービスが停止した。\"\n", | |
| "results = ner_pipeline(text)\n", | |
| "\n", | |
| "print(f\"\\n入力文: {text}\")\n", | |
| "print(\"抽出結果:\")\n", | |
| "if not results:\n", | |
| " print(\"エンティティが検出されませんでした。\")\n", | |
| "else:\n", | |
| " for entity in results:\n", | |
| " print(f\" - [{entity['entity_group']}] {entity['word']} (信頼度: {entity['score']:.4f})\")" | |
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}