Skip to content

Instantly share code, notes, and snippets.

@Keshav13142
Created July 23, 2024 07:22
Show Gist options
  • Save Keshav13142/445536b2b43eda637f037fa40c694534 to your computer and use it in GitHub Desktop.
Save Keshav13142/445536b2b43eda637f037fa40c694534 to your computer and use it in GitHub Desktop.

app.py

import os
import pickle
import sqlite3

import google.generativeai as genai
import markdown
import numpy as np
import requests
import torch
import urllib3
from flask import Flask, redirect, render_template, request, session, url_for
from transformers import (
    AutoModel,
    AutoTokenizer,
    T5ForConditionalGeneration,
    T5Tokenizer,
)

# Suppress insecure request warnings for development
# NOTE(review): the outbound `requests` calls in this file pass verify=False,
# which is what triggers these warnings; TLS verification should be enabled
# outside of local development.
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)

app = Flask(__name__)

# API Configurations

# Base URL for GitLab REST calls (used to list a project's issues).
GITLAB_API_URL = "https://gitlab.com/api/v4"
# API key handed to genai.configure().
# NOTE(review): "key" looks like a placeholder -- presumably meant to come
# from the environment like the OAuth settings below; confirm.
GENAI_API_KEY = "key"
# Secret used by Flask to sign the session cookie.
# NOTE(review): hard-coded in source; anyone who knows it can forge sessions.
app.secret_key = "tech_hackers"

# OAuth Configuration
# All three values are read from the environment and are None when unset.
CLIENT_SECRET = os.getenv("GITLAB_APP_SECRET")
CLIENT_ID = os.getenv("GITLAB_APP_ID")
SERVER_URL = os.getenv("SERVER_URL")

# Callback URL handed to GitLab; becomes "None/oauth/callback" if SERVER_URL
# is not set.
REDIRECT_URI = f"{SERVER_URL}/oauth/callback"
GITLAB_AUTHORIZE_URL = "https://gitlab.com/oauth/authorize"
GITLAB_TOKEN_URL = "https://gitlab.com/oauth/token"
GITLAB_USER_API_URL = "https://gitlab.com/api/v4/user"

# Access token obtained in the OAuth callback. This is a module-level global,
# so it is shared by every request handled by this process.
OAUTH_ACCESS_TOKEN = None
# Module-level cache of issue dicts, filled by get_gitlab_issues() and passed
# to the issueTitles.html template.
issues = []


class MemoryVectorStore:
    """In-memory store of documents and their embedding vectors.

    ``documents`` and ``embeddings`` are parallel sequences: the embedding at
    index ``i`` belongs to the document at index ``i``.
    """

    def __init__(self):
        self.documents = []
        self.embeddings = []

    def add_documents(self, documents, embeddings):
        """Append ``documents`` and their matching ``embeddings`` to the store."""
        # load() leaves ``embeddings`` as a NumPy array, which has no
        # extend(); convert it back to a list of rows before appending.
        if not isinstance(self.embeddings, list):
            self.embeddings = list(self.embeddings)
        self.documents.extend(documents)
        self.embeddings.extend(embeddings)

    def similarity_search(self, query_embedding, top_k=3):
        """Return the ``top_k`` documents most similar to ``query_embedding``.

        Args:
            query_embedding: A 1-D vector, or a 2-D array with a single row.
            top_k: Maximum number of results; clamped to the number of stored
                embeddings.

        Returns:
            A list of ``(document, cosine_similarity)`` tuples ordered from
            most to least similar. Empty when the store holds no embeddings
            or ``top_k`` is not positive.
        """
        if len(self.embeddings) == 0 or top_k <= 0:
            return []

        query = torch.as_tensor(query_embedding, dtype=torch.float32)
        # Build the matrix in a local variable: the stored embeddings must not
        # be replaced by a tensor, otherwise add_documents() stops working
        # after the first search.
        matrix = torch.as_tensor(
            np.asarray(self.embeddings, dtype=np.float32), dtype=torch.float32
        )

        # Collapse a (1, dim) query to (dim,) so the comparison below always
        # broadcasts one query row against every stored row.
        if query.ndim == 2:
            query = query.squeeze(0)
        similarities = torch.cosine_similarity(query.unsqueeze(0), matrix, dim=1)

        # torch.topk raises if k exceeds the number of candidates.
        best = torch.topk(similarities, min(top_k, similarities.shape[0]))
        return [
            (self.documents[index], score)
            for index, score in zip(best.indices.tolist(), best.values.tolist())
        ]

    def save(self, docs_file_path, embeddings_file_path):
        """Pickle the documents and write the embeddings with ``np.save``."""
        with open(docs_file_path, "wb") as f:
            pickle.dump(self.documents, f)
        np.save(embeddings_file_path, self.embeddings)

    def load(self, docs_file_path, embeddings_file_path):
        """Replace the store's contents with data previously written by save()."""
        with open(docs_file_path, "rb") as f:
            # SECURITY: pickle.load executes arbitrary code embedded in the
            # file; only load document files produced by a trusted source.
            self.documents = pickle.load(f)
        self.embeddings = np.load(embeddings_file_path)


# Model identifier passed to AutoTokenizer/AutoModel.from_pretrained when
# embedding the query text in interact_with_gemini_model().
EMBEDDING_MODEL_NAME = "bert-base-uncased"


@app.route("/", methods=["GET"])
def index_page():
    """Render the landing page."""
    # The route is registered for GET only, and Flask additionally answers
    # HEAD for any GET route. The previous `if request.method == "GET"` guard
    # made HEAD requests fall through and return None, which Flask rejects as
    # an invalid response, so the template is now returned unconditionally.
    return render_template("index.html")


@app.route("/connect", methods=["GET"])
def connect_page():
    """Send the visitor to GitLab's OAuth consent page.

    Visitors who already completed the flow (a token is held and their session
    carries a user) are redirected straight to the project form.
    """
    from urllib.parse import urlencode

    # The token is a process-wide global, so it may have been set by another
    # visitor. Also require this session's "user": /form redirects back here
    # when it is missing, and checking the token alone caused a redirect loop.
    if OAUTH_ACCESS_TOKEN and "user" in session:
        return redirect("/form")

    # urlencode percent-escapes the values; the redirect URI in particular
    # contains "://" and "/" which must not appear raw in a query string.
    # NOTE(review): "state" is a fixed string and is never checked in the
    # callback, so it provides no CSRF protection for the OAuth flow.
    query = urlencode(
        {
            "client_id": CLIENT_ID,
            "redirect_uri": REDIRECT_URI,
            "response_type": "code",
            "state": "random",
            "scope": "api",
        }
    )
    return redirect(f"{GITLAB_AUTHORIZE_URL}?{query}")


@app.route("/oauth/callback", methods=["GET"])
def get_oauth_token():
    """Handle GitLab's OAuth redirect: swap the code for a token and sign in.

    Returns:
        A redirect to "/" when no ``code`` query parameter is present, a
        redirect to "/form" on success, or an ``(error message, 400)`` tuple
        when the token or the user profile cannot be fetched.
    """
    global OAUTH_ACCESS_TOKEN
    code = request.args.get("code")
    if not code:
        return redirect("/")

    # Without a timeout `requests` can block this worker indefinitely when
    # GitLab does not answer.
    timeout_seconds = 10

    # NOTE(review): verify=False disables TLS certificate checks for requests
    # that carry the client secret and the access token; kept as in the
    # original (development setup) but should be removed for deployment.
    try:
        token_response = requests.post(
            GITLAB_TOKEN_URL,
            data={
                "client_id": CLIENT_ID,
                "client_secret": CLIENT_SECRET,
                "code": code,
                "grant_type": "authorization_code",
                "redirect_uri": REDIRECT_URI,
            },
            headers={"Content-Type": "application/x-www-form-urlencoded"},
            verify=False,
            timeout=timeout_seconds,
        )
    except requests.RequestException:
        return "Error: Unable to fetch access token", 400

    if token_response.status_code != 200:
        return "Error: Unable to fetch access token", 400

    # A 200 response with an unexpected body must not crash the callback.
    try:
        access_token = token_response.json().get("access_token")
    except (ValueError, AttributeError):
        access_token = None
    if not access_token:
        return "Error: Unable to fetch access token", 400

    try:
        user_response = requests.get(
            GITLAB_USER_API_URL,
            headers={"Authorization": f"Bearer {access_token}"},
            verify=False,
            timeout=timeout_seconds,
        )
    except requests.RequestException:
        return "Error: Unable to fetch user information", 400

    if user_response.status_code != 200:
        return "Error: Unable to fetch user information", 400

    # Publish the token only once the whole flow succeeded, so a failed user
    # lookup does not leave a token behind without a signed-in session.
    OAUTH_ACCESS_TOKEN = access_token
    session["user"] = user_response.json()
    return redirect("/form")


@app.route("/form", methods=["GET"])
def form():
    """Show the project form to signed-in visitors; send others to /connect."""
    if "user" in session:
        return render_template("ProjectForm.html")
    return redirect("/connect")


def load_fine_tuned_model(model_dir):
    """Load the T5 model and tokenizer stored under ``model_dir``.

    Returns:
        A ``(model, tokenizer)`` tuple.

    Raises:
        ValueError: If ``model_dir`` is None.
    """
    if model_dir is not None:
        # Tuple elements are evaluated left to right: model first, then tokenizer.
        return (
            T5ForConditionalGeneration.from_pretrained(model_dir),
            T5Tokenizer.from_pretrained(model_dir),
        )
    raise ValueError("Model directory must be specified.")


def interact_with_llm_model(input_text, model, tokenizer):
    """Generate a response for ``input_text`` using ``model`` and ``tokenizer``.

    Blank input short-circuits with a fixed message; otherwise the text is
    encoded, one sequence is sampled from the model, and the decoded text is
    returned with special tokens stripped.
    """
    if input_text.strip() == "":
        return "No input text provided."

    encoded_prompt = tokenizer.encode(input_text, return_tensors="pt")

    # Sampling configuration forwarded verbatim to model.generate().
    generation_options = {
        "max_length": 150,
        "num_return_sequences": 1,
        "pad_token_id": tokenizer.eos_token_id,
        "temperature": 1.0,
        "repetition_penalty": 1.2,
        "top_k": 50,
        "top_p": 0.95,
        "do_sample": True,
    }
    sequences = model.generate(encoded_prompt, **generation_options)

    first_sequence = sequences[0]
    return tokenizer.decode(first_sequence, skip_special_tokens=True)


def interact_with_gemini_model(input_text):
    """Ask Gemini to solve an issue, using stored documents as context.

    The stored documents and embeddings are loaded from disk, ``input_text``
    is embedded by averaging the embedding model's last hidden state over
    dim 1, and the similarity-search results are placed in the prompt.

    Returns:
        The object returned by ``generate_content`` when it is truthy,
        otherwise None. Any exception is printed and also yields None.
    """
    try:
        store = MemoryVectorStore()
        store.load("documents.pkl", "code_embeddings.npy")

        # Embed the query with the model named by EMBEDDING_MODEL_NAME.
        embedding_tokenizer = AutoTokenizer.from_pretrained(EMBEDDING_MODEL_NAME)
        embedding_model = AutoModel.from_pretrained(EMBEDDING_MODEL_NAME)

        encoded = embedding_tokenizer(
            input_text,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=512,
        )
        with torch.no_grad():
            hidden_states = embedding_model(**encoded).last_hidden_state
            query_embedding = hidden_states.mean(dim=1).numpy()

        # Normalise the embedding to two dimensions before searching.
        dims = query_embedding.ndim
        if dims == 1:
            query_embedding = query_embedding.reshape(1, -1)
        elif dims == 3:
            query_embedding = query_embedding.squeeze(axis=0)

        results = store.similarity_search(query_embedding)

        genai.configure(api_key=GENAI_API_KEY)
        gemini = genai.GenerativeModel("gemini-1.5-flash")
        prompt = f"solve the GitLab project issue: {input_text} with context: {results}.If No context satisfying give response from your knowledge base"

        response = gemini.generate_content(prompt)
        if response:
            return response
        return None

    except Exception as e:
        print(f"Error generating content with Gemini model: {e}")
        return None


@app.route("/issueDetails", methods=["POST", "GET"])
def collect_info():
    """Make sure the issue cache is loaded, then render the issue list.

    Returns:
        The rendered ``issueTitles.html`` page, or the ``(message, status)``
        error produced by get_gitlab_issues() when loading failed.
    """
    # get_gitlab_issues() returns None on success and an error tuple when the
    # project ID is missing or GitLab could not be queried. That result used
    # to be discarded, so failures silently showed an empty issue list.
    error_response = get_gitlab_issues()
    if error_response is not None:
        return error_response

    # POST (form submission) and GET (returning to the list) render the same
    # page from the module-level issue cache.
    return render_template("issueTitles.html", issues_info=issues)


def _cache_issue_solution(issue_id, solution):
    """Copy a saved solution into the module-level ``issues`` cache entry."""
    for cached_issue in issues:
        if cached_issue["issue_id"] == issue_id:
            cached_issue["solution"] = solution
            break


@app.route("/issue/<int:issue_id>", methods=["GET", "POST"])
def issue_details(issue_id):
    """Show one issue and handle the solution actions posted for it.

    GET renders ``issueDetails.html``. POST dispatches on the ``action`` form
    field:

    * ``confirm_solution`` stores the posted ``solution`` and returns to the list.
    * ``generate_llm`` renders a solution from the fine-tuned T5 model.
    * ``generate_gemini`` renders a solution from Gemini (Markdown to HTML).
    * ``user_provided_solution`` stores ``user_solution`` and returns to the list.

    Any other action falls through to the details page. Returns
    ``("Issue not found", 404)`` when the issue is not in the database.
    """
    try:
        issue = get_issue_by_id(issue_id)

        if issue is None:
            return "Issue not found", 404

        if request.method == "POST":
            action = request.form.get("action")

            # Title/description may be NULL in the database; .get() with a
            # default still returns None in that case, so fall back explicitly
            # before concatenating.
            issue_text = (
                (issue.get("title") or "") + " " + (issue.get("description") or "")
            )

            if action == "confirm_solution":
                solution = request.form.get("solution", "")
                if solution:
                    update_solution(issue_id, solution)
                    # The list page renders the module-level cache, so update
                    # it there. The old code wrote to session["issues"], which
                    # nothing else in this module populates or reads.
                    _cache_issue_solution(issue_id, solution)
                return redirect(url_for("collect_info"))

            if action == "generate_llm":
                # Load the model only for the action that needs it instead of
                # on every request to this route.
                model, tokenizer = load_fine_tuned_model("./fine_tuned_model8")
                generated_solution = interact_with_llm_model(
                    issue_text, model, tokenizer
                )
                return render_template(
                    "generatedSolution.html",
                    issue=issue,
                    generated_solution=generated_solution,
                    issue_id=issue_id,
                    model_type="llm",
                )

            if action == "generate_gemini":
                gemini_response = interact_with_gemini_model(issue_text)
                # interact_with_gemini_model returns None on any failure.
                if gemini_response is None:
                    return "Error: Unable to generate a solution with Gemini", 502
                try:
                    # Public accessor for the first candidate's text; raises
                    # ValueError when the response carries no usable text.
                    text_content = str(gemini_response.text)
                except (ValueError, AttributeError, IndexError):
                    return "Error: Unable to generate a solution with Gemini", 502
                # NOTE(review): generatedSolution.html renders this value with
                # the `safe` filter, so raw HTML in the model output reaches
                # the browser unescaped; consider sanitising it.
                return render_template(
                    "generatedSolution.html",
                    issue=issue,
                    generated_solution=markdown.markdown(text_content),
                    issue_id=issue_id,
                    model_type="gemini",
                )

            if action == "user_provided_solution":
                user_solution = request.form.get("user_solution", "")
                insert_solution(
                    issue_id, issue["title"], issue["description"], user_solution
                )
                _cache_issue_solution(issue_id, user_solution)
                return redirect(url_for("collect_info"))

    except Exception as e:
        # Log for the console, then let Flask produce the error response.
        print(f"An error occurred: {e}")
        raise

    return render_template("issueDetails.html", issue=issue, issue_id=issue_id)


@app.route("/feedback")
def feedback():
    """Render the feedback page with a fixed confirmation message."""
    return render_template(
        "feedback.html",
        feedback_message="Your feedback has been recorded.",
    )


# Database functions
def initialize_database():
    """Create the ``issues`` table in ``issues.db`` if it does not exist yet."""
    conn = sqlite3.connect("issues.db")
    try:
        conn.execute(
            """
            CREATE TABLE IF NOT EXISTS issues (
                issue_id INTEGER PRIMARY KEY,
                title TEXT NOT NULL,
                description TEXT,
                solution TEXT
            );
            """
        )
        conn.commit()
    finally:
        # Close the connection even when the statement fails, so the file
        # handle is not leaked.
        conn.close()


def check_duplicate_issue(title):
    """Return the first stored issue row with the given ``title``, or None.

    The row is an ``(issue_id, title, description, solution)`` tuple.
    """
    conn = sqlite3.connect("issues.db")
    try:
        # Columns are listed explicitly because callers index the row by
        # position (e.g. row[3] for the solution).
        cursor = conn.execute(
            "SELECT issue_id, title, description, solution "
            "FROM issues WHERE title=?",
            (title,),
        )
        return cursor.fetchone()
    finally:
        # Close the connection even when the query raises.
        conn.close()


def insert_solution(issue_id, title, description, solution="No solution provided"):
    """Insert the issue row, or overwrite it when ``issue_id`` already exists.

    Args:
        issue_id: Primary key of the row.
        title: Issue title to store.
        description: Issue description to store.
        solution: Solution text to store.
    """
    conn = sqlite3.connect("issues.db")
    try:
        cursor = conn.cursor()
        cursor.execute("SELECT 1 FROM issues WHERE issue_id=?", (issue_id,))
        if cursor.fetchone():
            cursor.execute(
                "UPDATE issues SET title=?, description=?, solution=? WHERE issue_id=?",
                (title, description, solution, issue_id),
            )
        else:
            cursor.execute(
                "INSERT INTO issues (issue_id, title, description, solution) VALUES (?, ?, ?, ?)",
                (issue_id, title, description, solution),
            )
        conn.commit()
    finally:
        # Close the connection even when a statement raises; uncommitted
        # changes are discarded on close.
        conn.close()


def get_issue_by_id(issue_id):
    """Fetch one issue from ``issues.db``.

    Returns:
        A dict with ``issue_id``, ``title``, ``description`` and ``solution``
        keys, or None when no row has that ID.

    Raises:
        sqlite3.Error: Re-raised after being printed when the query fails.
    """
    try:
        conn = sqlite3.connect("issues.db")
        try:
            print(f"Querying for issue_id: {issue_id}")  # Debugging info
            row = conn.execute(
                "SELECT issue_id, title, description, solution "
                "FROM issues WHERE issue_id = ?",
                (issue_id,),
            ).fetchone()
        finally:
            # `with sqlite3.connect(...)` only manages the transaction and
            # leaves the connection open, so close it explicitly.
            conn.close()
    except sqlite3.Error as e:
        print(f"An error occurred: {e}")
        raise

    if row is None:
        print("No issue found")  # Debugging info
        return None

    print(f"Issue found: {row}")  # Debugging info
    return {
        "issue_id": row[0],
        "title": row[1],
        "description": row[2],
        "solution": row[3],
    }


def update_solution(issue_id, solution):
    """Store ``solution`` for an existing issue; do nothing if it is unknown."""
    stored = get_issue_by_id(issue_id)
    if not stored:
        return
    # Re-use the stored title and description so only the solution changes.
    insert_solution(issue_id, stored["title"], stored["description"], solution)


def get_gitlab_issues():
    """Fill the module-level ``issues`` cache from the GitLab issues API.

    The project ID is read from the ``id`` form field of the current request.
    Nothing is fetched when the cache already holds issues.

    Returns:
        None on success (or when the cache was already populated), otherwise
        a ``(message, 400)`` tuple describing the failure.
    """
    project_id = request.form.get("id")

    # NOTE(review): the cache is never cleared and is not keyed by project,
    # so the first project fetched is shown for every later request.
    if issues:
        print("Issues already populated")
        return None

    if not project_id:
        return "Project ID is required", 400

    print("Fetching issues from gitlab")

    # NOTE(review): verify=False disables TLS certificate checks (kept as in
    # the original). The timeout prevents a stalled GitLab call from hanging
    # the request forever. Only the first page of results is requested.
    try:
        response = requests.get(
            f"{GITLAB_API_URL}/projects/{project_id}/issues",
            headers={"Authorization": f"Bearer {OAUTH_ACCESS_TOKEN}"},
            verify=False,
            timeout=10,
        )
    except requests.RequestException:
        return "Error: Unable to fetch issues", 400

    if response.status_code != 200:
        return "Error: Unable to fetch issues", 400

    for issue in response.json() or []:
        # Anything that is not explicitly "opened" is reported as closed.
        state = "opened" if issue.get("state") == "opened" else "closed"

        issue_id = issue["id"]
        title = issue["title"]
        # GitLab sends "description": null for issues without one; .get()
        # with a default would return None, so fall back explicitly.
        issue_description = issue.get("description") or "No description provided"

        # Issues are matched against the database by title.
        existing_issue = check_duplicate_issue(title)
        if existing_issue:
            solution = existing_issue[3]
            status = "existing"
            update_prompt = True
        else:
            solution = "No Solution Provided"
            status = "new"
            update_prompt = False
            # Stored once; the original wrote the row twice and the second
            # write (with this solution text) was the one that persisted.
            insert_solution(issue_id, title, issue_description, solution)

        issues.append(
            {
                "issue_id": issue_id,
                "title": title,
                "description": issue_description,
                "solution": solution,
                "status": status,
                "update_prompt": update_prompt,
                "state": state,
            }
        )

    return None


if __name__ == "__main__":
    # app.run() blocks until the server shuts down, so the schema has to be
    # created first; otherwise every query fails with "no such table".
    initialize_database()
    # NOTE(review): debug=True enables the interactive debugger and reloader;
    # use it for local development only.
    app.run(debug=True)

generatedSolution.html

<!DOCTYPE html>
<html lang="en">
  <head>
    <meta charset="UTF-8" />
    <meta
      name="viewport"
      content="width=device-width, initial-scale=1, shrink-to-fit=no"
    />
    <title>Generated Solution</title>
    <link
      href="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/css/bootstrap.min.css"
      rel="stylesheet"
    />
    <link
      rel="stylesheet"
      href="https://cdnjs.cloudflare.com/ajax/libs/github-markdown-css/4.0.0/github-markdown.min.css"
    />
  </head>
  <body>
    <div class="container my-4">
      <h1>Generated Solution</h1>
      <h2>{{ issue.title }}</h2>
      <p>{{ issue.description }}</p>
      <strong>Generated Solution:</strong>
      <article class="markdown-body">{{ generated_solution|safe }}</article>
      <br />

      <form
        action="{{ url_for('issue_details', issue_id=issue_id) }}"
        method="POST"
        class="mb-3"
      >
        <input type="hidden" name="action" value="confirm_solution" />
        <input type="hidden" name="solution" value="{{ generated_solution }}" />
        <button type="submit" class="btn btn-success">Solution Works</button>
      </form>
      <form
        action="{{ url_for('issue_details', issue_id=issue_id) }}"
        method="GET"
        class="mb-3"
      >
        <button type="submit" class="btn btn-secondary">Go Back</button>
      </form>
    </div>

    <script src="https://code.jquery.com/jquery-3.5.1.slim.min.js"></script>
    <script src="https://cdn.jsdelivr.net/npm/popper.js@1.16.1/dist/umd/popper.min.js"></script>
    <script src="https://stackpath.bootstrapcdn.com/bootstrap/4.5.2/js/bootstrap.min.js"></script>
  </body>
</html>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment