Last active
April 26, 2025 11:56
-
-
Save chottokun/797453e27359664b5e7af364a41a0dcd to your computer and use it in GitHub Desktop.
Shuu12121/CodeSearch-ModernBERT-Crow-Plus
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"nbformat": 4, | |
"nbformat_minor": 0, | |
"metadata": { | |
"colab": { | |
"provenance": [], | |
"gpuType": "T4", | |
"name": "Shuu12121/CodeSearch-ModernBERT-Crow-Plus", | |
"authorship_tag": "ABX9TyP21U8nbqC7+N9+bX5lhYdW", | |
"include_colab_link": true | |
}, | |
"kernelspec": { | |
"name": "python3", | |
"display_name": "Python 3" | |
}, | |
"language_info": { | |
"name": "python" | |
}, | |
"accelerator": "GPU" | |
}, | |
"cells": [ | |
{ | |
"cell_type": "markdown", | |
"metadata": { | |
"id": "view-in-github", | |
"colab_type": "text" | |
}, | |
"source": [ | |
"<a href=\"https://colab.research.google.com/gist/chottokun/797453e27359664b5e7af364a41a0dcd/shuu12121-codeembed-modernbert-owl-preview.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "TbQW8vlIQaWf", | |
"outputId": "9bed0b7a-8d59-4acb-9388-9c64694dee8b" | |
}, | |
"outputs": [], | |
"source": [ | |
"# Install with the %pip magic so the package goes into the running kernel's environment.\n", | |
"%pip install -U -q sentence-transformers" | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# cosine類似度計算" | |
], | |
"metadata": { | |
"id": "o-e8ZA4WpN9Q" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sentence_transformers import SentenceTransformer, util\n", | |
"\n", | |
"# Load the embedding model identified by model_id.\n", | |
"model_id = \"Shuu12121/CodeSearch-ModernBERT-Crow-Plus\"\n", | |
"model = SentenceTransformer(model_id)\n", | |
"\n", | |
"# Two code snippets to compare.\n", | |
"code_snippet_1 = \"\"\"\n", | |
"def add(a, b):\n", | |
" return a + b\n", | |
"\"\"\"\n", | |
"\n", | |
"code_snippet_2 = \"\"\"\n", | |
"def sum_numbers(x, y):\n", | |
" return x + y\n", | |
"\"\"\"\n", | |
"\n", | |
"# Encode each snippet into an embedding tensor.\n", | |
"embedding_1 = model.encode(code_snippet_1, convert_to_tensor=True)\n", | |
"embedding_2 = model.encode(code_snippet_2, convert_to_tensor=True)\n", | |
"\n", | |
"# Compute the cosine similarity with util.cos_sim (util.pytorch_cos_sim is a legacy alias of it).\n", | |
"# Only one pair is compared, so .item() below extracts the single score as a Python float.\n", | |
"cosine_similarity = util.cos_sim(embedding_1, embedding_2)\n", | |
"\n", | |
"print(f\"Cosine Similarity: {cosine_similarity.item():.4f}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "B6_iV2rYZ3mp", | |
"outputId": "40db7121-9780-4209-ca9a-c056f7a28946" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Cosine Similarity: 0.6156\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# NOTE: this cell relies on `model` and `util` created by the previous code cell.\n", | |
"\n", | |
"# Two code snippets to compare: a Go method fragment and a Python function.\n", | |
"code_snippet_1 = \"\"\"\n", | |
"func (v *View) Render(m ...string) trees.Markup {\n", | |
"if len(m) <= 0 {\n", | |
"m = []string{\".\"}\n", | |
"}\n", | |
"\"\"\"\n", | |
"\n", | |
"code_snippet_2 = \"\"\"\n", | |
"def sum_numbers(x, y):\n", | |
" return x + y\n", | |
"\"\"\"\n", | |
"\n", | |
"# Encode each snippet into an embedding tensor.\n", | |
"embedding_1 = model.encode(code_snippet_1, convert_to_tensor=True)\n", | |
"embedding_2 = model.encode(code_snippet_2, convert_to_tensor=True)\n", | |
"\n", | |
"# Compute the cosine similarity with util.cos_sim (util.pytorch_cos_sim is a legacy alias of it).\n", | |
"cosine_similarity = util.cos_sim(embedding_1, embedding_2)\n", | |
"\n", | |
"print(f\"Cosine Similarity: {cosine_similarity.item():.4f}\")" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "rw8ULhbnaHpO", | |
"outputId": "fa51947d-acce-40e5-8d77-abeea6a31d6c" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"Cosine Similarity: 0.0063\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "markdown", | |
"source": [ | |
"# 類似コードのテスト2" | |
], | |
"metadata": { | |
"id": "u_PqMlVZl4y4" | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"from sentence_transformers import SentenceTransformer, util\n", | |
"\n", | |
"# Load the embedding model. The same model id as the earlier cells is used so that\n", | |
"# the global `model` is not silently rebound to a different checkpoint.\n", | |
"model_id = \"Shuu12121/CodeSearch-ModernBERT-Crow-Plus\"\n", | |
"model = SentenceTransformer(model_id)\n", | |
"\n", | |
"# Sample dataset: code snippets paired with a summary comment.\n", | |
"dataset = [\n", | |
"    {\n", | |
"        \"code\": \"\"\"\n", | |
"def add(a, b):\n", | |
" return a + b\n", | |
"\"\"\",\n", | |
"        \"comment\": \"2つの数値を加算して結果を返す関数です。\"\n", | |
"    },\n", | |
"    {\n", | |
"        \"code\": \"\"\"\n", | |
"def factorial(n):\n", | |
" if n == 0:\n", | |
" return 1\n", | |
" else:\n", | |
" return n * factorial(n - 1)\n", | |
"\"\"\",\n", | |
"        \"comment\": \"再帰的にnの階乗を計算する関数です。\"\n", | |
"    },\n", | |
"    {\n", | |
"        \"code\": \"\"\"\n", | |
"def is_even(n):\n", | |
" return n % 2 == 0\n", | |
"\"\"\",\n", | |
"        \"comment\": \"与えられた数が偶数かどうかを判定する関数です。\"\n", | |
"    }\n", | |
"]\n", | |
"\n", | |
"# Precompute the embeddings of every code snippet in the dataset.\n", | |
"dataset_codes = [item[\"code\"] for item in dataset]\n", | |
"dataset_comments = [item[\"comment\"] for item in dataset]\n", | |
"dataset_embeddings = model.encode(dataset_codes, convert_to_tensor=True)\n", | |
"\n", | |
"def generate_comment_for_code(input_code, model, dataset_embeddings, dataset_comments):\n", | |
"    \"\"\"Return the comment of the dataset snippet most similar to ``input_code``.\n", | |
"\n", | |
"    Args:\n", | |
"        input_code: Source code string to look up.\n", | |
"        model: Object whose ``encode(text, convert_to_tensor=True)`` returns an embedding.\n", | |
"        dataset_embeddings: Embeddings of the dataset code snippets.\n", | |
"        dataset_comments: Comments aligned index-for-index with ``dataset_embeddings``.\n", | |
"\n", | |
"    Returns:\n", | |
"        A ``(comment, score)`` tuple: the best-matching comment and its cosine\n", | |
"        similarity to ``input_code`` as a Python float.\n", | |
"    \"\"\"\n", | |
"    # Embed the input code.\n", | |
"    input_embedding = model.encode(input_code, convert_to_tensor=True)\n", | |
"    # Cosine similarity against every dataset snippet; row 0 belongs to the single input.\n", | |
"    cosine_scores = util.cos_sim(input_embedding, dataset_embeddings)\n", | |
"    # Convert the argmax result to a plain int before indexing the Python list.\n", | |
"    best_idx = int(cosine_scores[0].argmax())\n", | |
"    return dataset_comments[best_idx], cosine_scores[0][best_idx].item()\n", | |
"\n", | |
"# Input code snippet used for the test.\n", | |
"input_code = \"\"\"\n", | |
"def sum_numbers(x, y):\n", | |
" return x + y\n", | |
"\"\"\"\n", | |
"\n", | |
"# Retrieve the comment of the most similar dataset snippet.\n", | |
"generated_comment, score = generate_comment_for_code(input_code, model, dataset_embeddings, dataset_comments)\n", | |
"print(\"生成されたコメント:\", generated_comment)\n", | |
"print(\"類似度スコア:\", score)\n" | |
], | |
"metadata": { | |
"colab": { | |
"base_uri": "https://localhost:8080/" | |
}, | |
"id": "tzmVSpkDly4H", | |
"outputId": "c33af416-82be-407f-a42b-425ae34f2776" | |
}, | |
"execution_count": null, | |
"outputs": [ | |
{ | |
"output_type": "stream", | |
"name": "stdout", | |
"text": [ | |
"生成されたコメント: 2つの数値を加算して結果を返す関数です。\n", | |
"類似度スコア: 0.6156468987464905\n" | |
] | |
} | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"source": [ | |
"# from google.colab import runtime\n", | |
"# runtime.unassign()" | |
], | |
"metadata": { | |
"id": "CjQfniWHlNuj" | |
}, | |
"execution_count": null, | |
"outputs": [] | |
} | |
] | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment