ModernBERT_JP_CAUSE-EFFECT.ipynb
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"name": "ModernBERT_JP_CAUSE-EFFECT.ipynb",
"authorship_tag": "ABX9TyOwl6lTU9stRHzbPBEXnsF4",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chottokun/af798278c56d46109c16d27aa64383ca/cause-effect.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "f1c542d9"
},
"source": [
"# コンテキスト指向 NER パイプライン\n",
"\n",
"このノートブックでは、`cl-nagoya/ruri-v3-pt-30m` モデルを使用して、文脈に基づいた固有表現抽出(Context NER)を行う一連の流れを実装しています。具体的には、以下の3つのステップで構成されています。\n",
"\n",
"1. **データセット準備 (Dataset Preparation)**:\n",
" 「原因 (CAUSE)」と「結果 (EFFECT)」の関係を含む日本語のトイデータを作成し、BIOタグ形式でアノテーションを行います。\n",
"\n",
"2. **ファインチューニング (Fine-tuning)**:\n",
" 作成したデータセットを用いて Ruri-v3 モデルを学習させます。データ量が少ないため、パターンを確実に学習(過学習)させるために20エポック回します。\n",
"\n",
"3. **推論 (Inference)**:\n",
" 学習済みモデルを使用して、新しい文章(例:「システム障害によりサービスが停止した」)から原因と結果を抽出できるか検証します。"
]
},
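{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the BIO format concrete, this is how the first toy example below is annotated (shown purely as an illustration of the tagging scheme):\n",
"\n",
"| Token | 大雨 | により | 川 | が | 氾濫 | した | 。 |\n",
"| --- | --- | --- | --- | --- | --- | --- | --- |\n",
"| Tag | B-CAUSE | O | B-EFFECT | I-EFFECT | I-EFFECT | I-EFFECT | O |\n",
"\n",
"`B-*` marks the first token of a span, `I-*` its continuation, and `O` any token outside a cause or effect span."
]
},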
{
"cell_type": "code",
"metadata": {
"id": "123acae6"
},
"source": [
"# --- 0. 環境構築 (Setup) ---\n",
"# 必要なライブラリをインストール(既にインストール済みの場合はスキップされます)\n",
"!pip install -U \"transformers>=4.48.0\" accelerate datasets seqeval\n",
"\n",
"import os\n",
"# W&B (Weights & Biases) を無効化\n",
"os.environ[\"WANDB_DISABLED\"] = \"true\"\n",
"\n",
"import logging\n",
"import warnings\n",
"import torch\n",
"from datasets import Dataset\n",
"from transformers import (\n",
" AutoTokenizer,\n",
" AutoModelForTokenClassification,\n",
" TrainingArguments,\n",
" Trainer,\n",
" DataCollatorForTokenClassification,\n",
" pipeline,\n",
" logging as transformers_logging\n",
")\n",
"\n",
"# ログ出力を整理して見やすく設定\n",
"warnings.filterwarnings(\"ignore\")\n",
"transformers_logging.set_verbosity_error()\n",
"logging.getLogger(\"huggingface_hub\").setLevel(logging.ERROR)\n",
"\n",
"# --- 1. データセットの準備 (Data Preparation) ---\n",
"print(\"1. データセットを準備しています...\")\n",
"\n",
"# ラベルの定義 (0:その他, 1:原因開始, 2:原因継続, 3:結果開始, 4:結果継続)\n",
"id2label = {0: 'O', 1: 'B-CAUSE', 2: 'I-CAUSE', 3: 'B-EFFECT', 4: 'I-EFFECT'}\n",
"label2id = {v: k for k, v in id2label.items()}\n",
"\n",
"# サンプルデータ(ここを独自のデータに差し替えることで再学習可能)\n",
"japanese_toy_data = [\n",
" {\"tokens\": [\"大雨\", \"により\", \"\", \"\", \"氾濫\", \"した\", \"\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"システム\", \"障害\", \"\", \"ため\", \"\", \"サービス\", \"\", \"停止\", \"している\", \"\"], \"ner_tags\": [1, 2, 0, 0, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"寝不足\", \"\", \"集中力\", \"低下\", \"\", \"招く\", \"\"], \"ner_tags\": [1, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"インフレ\", \"によって\", \"物価\", \"\", \"上昇\", \"した\", \"\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"渋滞\", \"\", \"せいで\", \"到着\", \"\", \"遅れた\", \"\"], \"ner_tags\": [1, 0, 0, 3, 4, 4, 0]},\n",
" {\"tokens\": [\"火花\", \"\", \"火災\", \"\", \"原因\", \"\", \"なった\", \"\"], \"ner_tags\": [1, 0, 3, 4, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"過労\", \"\", \"健康\", \"悪化\", \"\", \"つながる\", \"\"], \"ner_tags\": [1, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"メモリ\", \"リーク\", \"により\", \"アプリ\", \"\", \"クラッシュ\", \"した\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"地球\", \"温暖化\", \"\", \"海面\", \"上昇\", \"\", \"引き起こす\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"ストレス\", \"\", \"\", \"\", \"痛く\", \"なる\", \"\"], \"ner_tags\": [1, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"構造\", \"\", \"欠陥\", \"により\", \"\", \"\", \"崩落\", \"した\", \"\"], \"ner_tags\": [1, 2, 2, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"猛暑\", \"のため\", \"作物\", \"\", \"枯れた\", \"\"], \"ner_tags\": [1, 0, 3, 4, 4, 0]},\n",
" {\"tokens\": [\"停電\", \"\", \"落雷\", \"\", \"原因\", \"だった\", \"\"], \"ner_tags\": [3, 0, 1, 0, 0, 0, 0]},\n",
" {\"tokens\": [\"栄養\", \"不足\", \"\", \"免疫力\", \"低下\", \"\", \"招く\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"古い\", \"ドライバ\", \"\", \"使う\", \"\", \"動作\", \"\", \"不安定\", \"\", \"なる\", \"\"], \"ner_tags\": [1, 2, 0, 0, 0, 3, 4, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"政治\", \"\", \"不安\", \"\", \"株価\", \"下落\", \"\", \"招いた\", \"\"], \"ner_tags\": [1, 2, 2, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"ウイルス\", \"感染\", \"により\", \"データ\", \"\", \"消失\", \"した\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 4, 0]},\n",
" {\"tokens\": [\"需要\", \"急増\", \"\", \"供給\", \"不足\", \"\", \"つながった\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"政策\", \"変更\", \"\", \"抗議\", \"デモ\", \"\", \"引き起こした\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 0, 0, 0]},\n",
" {\"tokens\": [\"大きな\", \"\", \"\", \"\", \"\", \"驚いた\", \"\"], \"ner_tags\": [1, 2, 0, 3, 4, 4, 0]}\n",
"]\n",
"\n",
"dataset = Dataset.from_list(japanese_toy_data)\n",
"\n",
"# --- 2. モデルとトークナイザーの準備 (Model Loading) ---\n",
"model_id = \"cl-nagoya/ruri-v3-pt-30m\"\n",
"print(f\"{model_id} をロード中...\")\n",
"\n",
"tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)\n",
"model = AutoModelForTokenClassification.from_pretrained(\n",
" model_id,\n",
" num_labels=len(id2label),\n",
" id2label=id2label,\n",
" label2id=label2id,\n",
" trust_remote_code=True\n",
")\n",
"\n",
"# データの前処理(トークン化とラベル位置合わせ)\n",
"def tokenize_and_align_labels(examples):\n",
" tokenized_inputs = tokenizer(\n",
" examples[\"tokens\"],\n",
" truncation=True,\n",
" is_split_into_words=True,\n",
" max_length=8192 # Ruri-v3の長文脈対応能力\n",
" )\n",
" labels = []\n",
" for i, label in enumerate(examples[\"ner_tags\"]):\n",
" word_ids = tokenized_inputs.word_ids(batch_index=i)\n",
" previous_word_idx = None\n",
" label_ids = []\n",
" for word_idx in word_ids:\n",
" if word_idx is None:\n",
" label_ids.append(-100) # 特殊トークンは無視\n",
" elif word_idx != previous_word_idx:\n",
" label_ids.append(label[word_idx]) # 単語の先頭\n",
" else:\n",
" label_ids.append(-100) # 単語の途中\n",
" previous_word_idx = word_idx\n",
" labels.append(label_ids)\n",
" tokenized_inputs[\"labels\"] = labels\n",
" return tokenized_inputs\n",
"\n",
"tokenized_dataset = dataset.map(tokenize_and_align_labels, batched=True)\n",
"\n",
"# --- 3. ファインチューニング (Training) ---\n",
"print(\"\\n2. モデルの学習を開始します...\")\n",
"\n",
"training_args = TrainingArguments(\n",
" output_dir=\"ruri_final_demo\",\n",
" learning_rate=5e-5,\n",
" per_device_train_batch_size=4,\n",
" num_train_epochs=20, # データが少ないため多めに設定(本番データでは3-5程度でOK)\n",
" weight_decay=0.01,\n",
" eval_strategy=\"epoch\",\n",
" save_strategy=\"epoch\",\n",
" logging_steps=10,\n",
" report_to=\"none\" # WandBを無効化\n",
")\n",
"\n",
"trainer = Trainer(\n",
" model=model,\n",
" args=training_args,\n",
" train_dataset=tokenized_dataset,\n",
" eval_dataset=tokenized_dataset,\n",
" processing_class=tokenizer,\n",
" data_collator=DataCollatorForTokenClassification(tokenizer)\n",
")\n",
"\n",
"trainer.train()\n",
"trainer.save_model(\"./ruri_final_demo\")\n",
"print(\"学習完了。モデルを保存しました。\")\n",
"\n",
"# --- 4. 推論の検証 (Inference) ---\n",
"print(\"\\n3. 推論テストを実行します...\")\n",
"\n",
"device = 0 if torch.cuda.is_available() else -1\n",
"ner_pipeline = pipeline(\n",
" \"ner\",\n",
" model=\"./ruri_final_demo\",\n",
" tokenizer=\"./ruri_final_demo\",\n",
" aggregation_strategy=\"simple\",\n",
" trust_remote_code=True,\n",
" device=device\n",
")\n",
"\n",
"# テスト用入力文\n",
"text = \"システム障害によりサービスが停止した。\"\n",
"results = ner_pipeline(text)\n",
"\n",
"print(f\"\\n入力文: {text}\")\n",
"print(\"抽出結果:\")\n",
"if not results:\n",
" print(\"エンティティが検出されませんでした。\")\n",
"else:\n",
" for entity in results:\n",
" print(f\" - [{entity['entity_group']}] {entity['word']} (信頼度: {entity['score']:.4f})\")"
],
"execution_count": null,
"outputs": []
}
]
}