Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Save chottokun/797453e27359664b5e7af364a41a0dcd to your computer and use it in GitHub Desktop.
Save chottokun/797453e27359664b5e7af364a41a0dcd to your computer and use it in GitHub Desktop.
Shuu12121/CodeSearch-ModernBERT-Crow-Plus
Display the source blob
Display the rendered blob
Raw
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4",
"name": "Shuu12121/CodeSearch-ModernBERT-Crow-Plus",
"authorship_tag": "ABX9TyP21U8nbqC7+N9+bX5lhYdW",
"include_colab_link": true
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "markdown",
"metadata": {
"id": "view-in-github",
"colab_type": "text"
},
"source": [
"<a href=\"https://colab.research.google.com/gist/chottokun/797453e27359664b5e7af364a41a0dcd/shuu12121-codeembed-modernbert-owl-preview.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "TbQW8vlIQaWf",
"outputId": "9bed0b7a-8d59-4acb-9388-9c64694dee8b"
},
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m1.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m58.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m34.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m45.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m664.8/664.8 MB\u001b[0m \u001b[31m2.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m211.5/211.5 MB\u001b[0m \u001b[31m6.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m56.3/56.3 MB\u001b[0m \u001b[31m13.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m127.9/127.9 MB\u001b[0m \u001b[31m11.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m207.5/207.5 MB\u001b[0m \u001b[31m6.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m21.1/21.1 MB\u001b[0m \u001b[31m80.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[?25h"
]
}
],
"source": [
"!pip install -U -q sentence-transformers"
]
},
{
"cell_type": "markdown",
"source": [
"# cosine類似度計算"
],
"metadata": {
"id": "o-e8ZA4WpN9Q"
}
},
{
"cell_type": "code",
"source": [
"from sentence_transformers import SentenceTransformer, util\n",
"\n",
"# モデルのロード\n",
"model_id = \"Shuu12121/CodeSearch-ModernBERT-Crow-Plus\"\n",
"model = SentenceTransformer(model_id)\n",
"\n",
"# コードスニペットの定義\n",
"code_snippet_1 = \"\"\"\n",
"def add(a, b):\n",
" return a + b\n",
"\"\"\"\n",
"\n",
"code_snippet_2 = \"\"\"\n",
"def sum_numbers(x, y):\n",
" return x + y\n",
"\"\"\"\n",
"\n",
"# コードスニペットをベクトルにエンコード\n",
"embedding_1 = model.encode(code_snippet_1, convert_to_tensor=True)\n",
"embedding_2 = model.encode(code_snippet_2, convert_to_tensor=True)\n",
"\n",
"# コサイン類似度の計算\n",
"cosine_similarity = util.pytorch_cos_sim(embedding_1, embedding_2)\n",
"\n",
"print(f\"Cosine Similarity: {cosine_similarity.item():.4f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "B6_iV2rYZ3mp",
"outputId": "40db7121-9780-4209-ca9a-c056f7a28946"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cosine Similarity: 0.6156\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# コードスニペットの定義\n",
"code_snippet_1 = \"\"\"\n",
"func (v *View) Render(m ...string) trees.Markup {\n",
"if len(m) <= 0 {\n",
"m = []string{\".\"}\n",
"}\n",
"\"\"\"\n",
"\n",
"code_snippet_2 = \"\"\"\n",
"def sum_numbers(x, y):\n",
" return x + y\n",
"\"\"\"\n",
"\n",
"# コードスニペットをベクトルにエンコード\n",
"embedding_1 = model.encode(code_snippet_1, convert_to_tensor=True)\n",
"embedding_2 = model.encode(code_snippet_2, convert_to_tensor=True)\n",
"\n",
"# コサイン類似度の計算\n",
"cosine_similarity = util.pytorch_cos_sim(embedding_1, embedding_2)\n",
"\n",
"print(f\"Cosine Similarity: {cosine_similarity.item():.4f}\")"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "rw8ULhbnaHpO",
"outputId": "fa51947d-acce-40e5-8d77-abeea6a31d6c"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"Cosine Similarity: 0.0063\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"# 類似コードのテスト2"
],
"metadata": {
"id": "u_PqMlVZl4y4"
}
},
{
"cell_type": "code",
"source": [
"from sentence_transformers import SentenceTransformer, util\n",
"\n",
"# モデルのロード(embeddingモデルとして利用)\n",
"model_id = \"Shuu12121/CodeEmbed-ModernBERT-Owl-Preview\"\n",
"model = SentenceTransformer(model_id)\n",
"\n",
"# サンプルのコードスニペットとそのコメント(要約)のデータセット\n",
"dataset = [\n",
" {\n",
" \"code\": \"\"\"\n",
"def add(a, b):\n",
" return a + b\n",
"\"\"\",\n",
" \"comment\": \"2つの数値を加算して結果を返す関数です。\"\n",
" },\n",
" {\n",
" \"code\": \"\"\"\n",
"def factorial(n):\n",
" if n == 0:\n",
" return 1\n",
" else:\n",
" return n * factorial(n - 1)\n",
"\"\"\",\n",
" \"comment\": \"再帰的にnの階乗を計算する関数です。\"\n",
" },\n",
" {\n",
" \"code\": \"\"\"\n",
"def is_even(n):\n",
" return n % 2 == 0\n",
"\"\"\",\n",
" \"comment\": \"与えられた数が偶数かどうかを判定する関数です。\"\n",
" }\n",
"]\n",
"\n",
"# データセット中のコードの埋め込みを事前計算\n",
"dataset_codes = [item[\"code\"] for item in dataset]\n",
"dataset_comments = [item[\"comment\"] for item in dataset]\n",
"dataset_embeddings = model.encode(dataset_codes, convert_to_tensor=True)\n",
"\n",
"def generate_comment_for_code(input_code, model, dataset_embeddings, dataset_comments):\n",
" # 入力コードの埋め込みを取得\n",
" input_embedding = model.encode(input_code, convert_to_tensor=True)\n",
" # 各データセットコードとのコサイン類似度を計算\n",
" cosine_scores = util.pytorch_cos_sim(input_embedding, dataset_embeddings)\n",
" # 類似度が最も高いコードのインデックスを取得\n",
" best_idx = cosine_scores.argmax()\n",
" return dataset_comments[best_idx], cosine_scores[0][best_idx].item()\n",
"\n",
"# テスト用の入力コードスニペット\n",
"input_code = \"\"\"\n",
"def sum_numbers(x, y):\n",
" return x + y\n",
"\"\"\"\n",
"\n",
"# コメント生成(類似コードに基づく要約)\n",
"generated_comment, score = generate_comment_for_code(input_code, model, dataset_embeddings, dataset_comments)\n",
"print(\"生成されたコメント:\", generated_comment)\n",
"print(\"類似度スコア:\", score)\n"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "tzmVSpkDly4H",
"outputId": "c33af416-82be-407f-a42b-425ae34f2776"
},
"execution_count": null,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"生成されたコメント: 2つの数値を加算して結果を返す関数です。\n",
"類似度スコア: 0.6156468987464905\n"
]
}
]
},
{
"cell_type": "code",
"source": [
"# Optional (Colab only): uncomment the two lines below to release the runtime when finished.\n",
"# from google.colab import runtime\n",
"# runtime.unassign()"
],
"metadata": {
"id": "CjQfniWHlNuj"
},
"execution_count": null,
"outputs": []
}
]
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment