Created
July 19, 2024 13:30
-
-
Save Keshav13142/5e0c17c6ef4200e1af9982e057c12ce7 to your computer and use it in GitHub Desktop.
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
import sqlite3 | |
from datetime import datetime, timedelta | |
import requests | |
from datasets import DatasetDict, load_dataset | |
from dotenv import load_dotenv | |
from flask import Flask, jsonify, redirect, render_template, request | |
from transformers import ( | |
DataCollatorForLanguageModeling, | |
GPT2LMHeadModel, | |
GPT2Tokenizer, | |
Trainer, | |
TrainingArguments, | |
) | |
# Load configuration from a local .env file into the process environment.
load_dotenv()
# GitLab credentials and deployment settings (each is None when unset).
PRIVATE_TOKEN = os.getenv("PRIVATE_TOKEN")  # private token for the issues API
PROJECT_ID = os.getenv("PROJECT_ID")  # default GitLab project to pull issues from
SERVER_URL = os.getenv("SERVER_URL")  # public base URL of this app (used in OAuth redirect and webhook URL)
GITLAB_APP_ID = os.getenv("GITLAB_APP_ID")  # OAuth application client id
GITLAB_OAUTH_SCOPES = os.getenv("GITLAB_OAUTH_SCOPES")  # OAuth scopes requested at /connect
app = Flask(__name__)
# Mutable module-level session state (single-user prototype; not thread-safe).
# NOTE(review): OAUTH_ACCESS_TOKEN is assigned the raw OAuth *code* from the
# callback, not an exchanged access token — confirm intended.
OAUTH_ACCESS_TOKEN = None
PROJECT_ACCESS_TOKEN = None  # project access token created via the GitLab API
@app.route("/webhook", methods=["POST"])
def webhook():
    """Receive a GitLab webhook delivery and acknowledge it."""
    if request.method != "POST":
        # Flask already restricts this route to POST; kept for parity.
        return jsonify({"status": "method not allowed"}), 405
    payload = request.json
    print(f"Received webhook data: {payload}")
    # Process the webhook payload here
    return jsonify({"status": "success"}), 200
@app.route("/", methods=["GET"])
def index_page():
    """Serve the landing page."""
    if request.method != "GET":
        return None  # unreachable: route only accepts GET
    return render_template("index.html")
@app.route("/connect", methods=["GET"])
def connect_page():
    """Start the GitLab OAuth flow, or skip to /success if already authorized."""
    if request.method == "GET":
        # A previous OAuth round-trip already populated the token.
        if OAUTH_ACCESS_TOKEN is not None:
            return redirect("/success")
        # Send the user to GitLab's authorization endpoint; GitLab redirects
        # back to /oauth/callback with a ?code=... parameter.
        # NOTE(review): state=random is a fixed string, so it provides no CSRF
        # protection — should be a per-session random value; confirm and fix.
        return redirect(
            f"https://gitlab.com/oauth/authorize?client_id={GITLAB_APP_ID}&redirect_uri={SERVER_URL}/oauth/callback&response_type=code&state=random&scope={GITLAB_OAUTH_SCOPES}"
        )
    return "Hello"
@app.route("/oauth/callback", methods=["GET"])
def get_oauth_token():
    """OAuth redirect target: capture the ?code=... parameter GitLab sends back.

    Redirects to /success when a code is present, otherwise back to /.
    """
    global OAUTH_ACCESS_TOKEN
    if request.method == "GET":
        # .get() returns None when the parameter is missing instead of
        # aborting the request with a 400 like request.args["code"] would.
        OAUTH_ACCESS_TOKEN = request.args.get("code")
        # NOTE(review): this stores the authorization *code*, not an access
        # token — a real flow must exchange the code at /oauth/token with the
        # app secret. Downstream API calls using it will be unauthorized.
        if OAUTH_ACCESS_TOKEN is not None:
            return redirect("/success")
    return redirect("/")
@app.route("/success", methods=["GET"])
def success_page():
    """Render the post-authorization success page."""
    global OAUTH_ACCESS_TOKEN
    if request.method == "GET":
        return render_template("success.html")
@app.route("/add-project", methods=["GET", "POST"])
def add_project():
    """GET: show the add-project form (requires auth). POST: register a webhook.

    On POST, reads project_id from the form and installs an issues webhook on
    that GitLab project; redirects back to the form on success.
    """
    if request.method == "GET":
        # The OAuth flow must be completed before adding a project.
        if OAUTH_ACCESS_TOKEN is None:
            return redirect("/")
        return render_template("add-project.html")
    elif request.method == "POST":
        # .get() yields None for a missing field, so the guard below actually
        # fires (request.form["project_id"] would abort with a 400 instead,
        # making the original None-check dead code).
        project_id = request.form.get("project_id")
        print(project_id)
        if project_id is None:
            return redirect("/add-project")
        if add_webhook(project_id):
            return redirect("/add-project")
        return "Failed to add project!"
def create_project_access_token(project_id):
    """Create a 30-day project access token for *project_id* via the GitLab API.

    On success (201 Created) stores the token in the module-level
    PROJECT_ACCESS_TOKEN and returns True; returns False otherwise, leaving
    any previously stored token untouched.
    """
    global PROJECT_ACCESS_TOKEN
    # Build the payload once so the logged value is exactly what is sent
    # (the original duplicated the literal, letting the two drift apart).
    payload = {
        "name": "flask access token",
        "scopes": ["api"],
        "expires_at": (datetime.now() + timedelta(days=30)).isoformat(),
    }
    print(f"https://gitlab.com/api/v4/projects/{project_id}/access_tokens")
    # NOTE(review): printing the bearer token leaks a credential into logs.
    print(f"BEARER {OAUTH_ACCESS_TOKEN}")
    print(payload)
    response = requests.post(
        f"https://gitlab.com/api/v4/projects/{project_id}/access_tokens",
        json=payload,
        headers={
            "Authorization": f"BEARER {OAUTH_ACCESS_TOKEN}",
            "Content-Type": "application/json",
        },
        timeout=30,  # never hang the request thread on an unresponsive API
    )
    if response.status_code != 201:
        # Do not parse the error body / clobber the token on failure.
        print("Failed to create project access token:", response.status_code)
        return False
    PROJECT_ACCESS_TOKEN = response.json().get("token")
    print("Got project access token ", PROJECT_ACCESS_TOKEN)
    return True
def add_webhook(project_id):
    """Register this server's /webhook endpoint as an issues webhook.

    Creates a project access token first, then installs the hook on the
    project. Returns True when GitLab responds 201 Created, False otherwise.
    """
    print("ACCESS_TOKEN ", OAUTH_ACCESS_TOKEN)
    # Bail out early when no project token could be created — the hook
    # request below would be unauthorized anyway (the original ignored
    # this return value and proceeded regardless).
    if not create_project_access_token(project_id):
        return False
    response = requests.post(
        f"https://gitlab.com/api/v4/projects/{project_id}/hooks",
        json={
            "url": f"{SERVER_URL}/webhook",
            "name": "IntelliOps Webhook",
            "description": "Custom webhook to notify the server about new issues",
            "issues_events": True,
        },
        headers={
            "Authorization": f"BEARER {PROJECT_ACCESS_TOKEN}",
            "Content-Type": "application/json",
        },
        timeout=30,  # never hang the request thread on an unresponsive API
    )
    print(response)
    return response.status_code == 201
# @app.route('/webhook', methods=['POST']) | |
# def webhook_handler(): | |
# if request.headers['Content-Type'] == 'application/json': | |
# payload = request.json | |
# # Send a quick response indicating the webhook was received | |
# response_data = {'message': 'Webhook received. Starting training...'} | |
# status_code = 200 | |
# # Start processing the payload asynchronously | |
# import threading | |
# threading.Thread(target=main, args=(payload,)).start() | |
# return jsonify(response_data), status_code | |
def get_gitlab_issues():
    """Fetch the configured project's issues from the GitLab API.

    Returns the decoded JSON list on HTTP 200, otherwise logs and returns None.
    """
    response = requests.get(
        f"https://git.virtusa.com/api/v4/projects/{PROJECT_ID}/issues",
        headers={"PRIVATE-TOKEN": PRIVATE_TOKEN},
    )
    if response.status_code != 200:
        print(f"Failed to fetch issues, status code: {response.status_code}")
        return None
    return response.json()
def initialize_database(db_path="issues.db"):
    """Create the ``issues`` table in *db_path* if it does not already exist.

    Args:
        db_path: SQLite database file. Defaults to "issues.db", matching the
            path used by the rest of this module (backward compatible).
    """
    conn = sqlite3.connect(db_path)
    try:
        # try/finally fixes the original's leak: the connection was never
        # closed when the DDL raised.
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS issues (
                issue_id INTEGER PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                solution TEXT);
            """
        )
        conn.commit()
    finally:
        conn.close()
def update_issues_txt(db_path="issues.db", txt_path="issues.txt"):
    """Append every row of the ``issues`` table to the plain-text dataset file.

    Args:
        db_path: SQLite database to read from (default matches the module).
        txt_path: text file to append to (default matches the module).

    NOTE(review): because the file is opened in append mode and *all* rows are
    written each call, repeated calls duplicate entries — confirm intended.
    """
    conn = sqlite3.connect(db_path)
    try:
        # try/finally fixes the original's leak: the connection stayed open
        # when the SELECT raised.
        issues = conn.execute("SELECT * FROM issues").fetchall()
    finally:
        conn.close()
    with open(txt_path, "a", encoding="utf-8") as f:  # Append mode
        for issue_id, title, description, solution in issues:
            f.write(f"Issue ID: {issue_id}\n")
            f.write(f"Title: {title}\n")
            f.write(f"Description: {description}\n")
            f.write(f"Solution: {solution}\n\n")
def load_and_tokenize_dataset(file_path, tokenizer):
    """Load the plain-text dataset at *file_path* and tokenize every line.

    Raises ValueError when the file is missing or empty; returns a
    DatasetDict with a single "train" split of 512-token padded examples.
    """
    if not os.path.exists(file_path) or os.path.getsize(file_path) == 0:
        raise ValueError("The dataset file is empty or does not exist.")
    raw = load_dataset("text", data_files={"train": file_path}, split="train")

    def _tokenize(batch):
        # Pad/truncate each example to a fixed 512-token window.
        return tokenizer(
            batch["text"], truncation=True, padding="max_length", max_length=512
        )

    tokenized = raw.map(_tokenize, batched=True)
    return DatasetDict({"train": tokenized})
def fine_tune_model(dataset, model, tokenizer):
    """Run one causal-LM fine-tuning epoch over dataset["train"]."""
    config = TrainingArguments(
        output_dir="./fine_tuned_model",
        overwrite_output_dir=True,
        num_train_epochs=1,
        per_device_train_batch_size=1,
        save_steps=10,
        save_total_limit=1,
    )
    # mlm=False selects plain next-token (causal) language modeling.
    collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
    Trainer(
        model=model,
        args=config,
        data_collator=collator,
        train_dataset=dataset["train"],
    ).train()
def generate_solution(issue_description, model, tokenizer):
    """Generate a candidate solution text for *issue_description* with the LM.

    Returns the fixed string "No input text provided." for blank input.
    """
    # Guard: a whitespace-only description gives the model nothing to condition on.
    if not issue_description.strip():
        return "No input text provided."
    encoded = tokenizer.encode(issue_description, return_tensors="pt")
    sequences = model.generate(
        encoded,
        max_length=150,
        num_return_sequences=1,
        pad_token_id=tokenizer.eos_token_id,
    )
    return tokenizer.decode(sequences[0], skip_special_tokens=True)
def main(payload=None):
    """Collect issues, generate solutions, and fine-tune the model.

    When *payload* is given (a webhook delivery) its "issues" list is used;
    otherwise issues are fetched from the GitLab API.
    """
    initialize_database()
    issues = payload.get("issues") if payload else get_gitlab_issues()
    if not issues:
        print("No issues fetched from GitLab.")
        return
    model_name = "distilgpt2"
    model = GPT2LMHeadModel.from_pretrained(model_name)
    tokenizer = GPT2Tokenizer.from_pretrained(model_name)
    # GPT-2 has no dedicated pad token; reuse EOS for padding.
    tokenizer.pad_token = tokenizer.eos_token
    for issue in issues:
        process_issue(issue, model, tokenizer)
    dataset = load_and_tokenize_dataset("issues.txt", tokenizer)
    fine_tune_model(dataset, model, tokenizer)
    output_dir = "./fine_tuned_model"
    os.makedirs(output_dir, exist_ok=True)
    model.save_pretrained(output_dir)
    tokenizer.save_pretrained(output_dir)
def process_issue(issue, model, tokenizer):
    """Generate and persist a solution for one GitLab issue.

    Skips issues already present in the database; otherwise inserts the issue
    plus a generated solution into SQLite and appends it to issues.txt.

    Args:
        issue: issue dict from the GitLab API ("id", "title", optional
            "description").
        model / tokenizer: GPT-2 model and tokenizer for solution generation.
    """
    issue_id = issue["id"]
    title = issue["title"]
    description = issue.get("description", "No description provided")
    conn = sqlite3.connect("issues.db")
    try:
        c = conn.cursor()
        c.execute("SELECT * FROM issues WHERE issue_id=?", (issue_id,))
        if c.fetchone():
            # BUG FIX: the original returned here without closing the
            # connection, leaking one handle per already-seen issue; the
            # try/finally now closes it on every exit path.
            print(f"Issue '{title}' already exists in the database. Skipping.")
            return
        generated_solution = generate_solution(description, model, tokenizer)
        c.execute(
            "INSERT INTO issues (issue_id, title, description, solution) VALUES (?, ?, ?, ?)",
            (issue_id, title, description, generated_solution),
        )
        conn.commit()
    finally:
        conn.close()
    with open("issues.txt", "a", encoding="utf-8") as f:
        f.write(f"Issue ID: {issue_id}\n")
        f.write(f"Title: {title}\n")
        f.write(f"Description: {description}\n")
        f.write(f"Solution: {generated_solution}\n\n")
if __name__ == "__main__":
    # Development server only. NOTE(review): debug=True enables the Werkzeug
    # interactive debugger; combined with host="0.0.0.0" this exposes remote
    # code execution to anyone who can reach port 5000 — disable debug (or
    # bind to localhost) before deploying.
    app.run(debug=True, host="0.0.0.0", port=5000)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment