Add OpenAI integration with CLI for generating Mermaid diagrams. #8

Open · wants to merge 6 commits into base: main
Changes from 4 commits
7 changes: 6 additions & 1 deletion .env.example
@@ -5,4 +5,9 @@ NEXT_PUBLIC_API_DEV_URL=http://localhost:8000
ANTHROPIC_API_KEY=

# OPTIONAL: providing your own GitHub PAT increases rate limits for the GitHub API from 60/hr to 5000/hr
GITHUB_PAT=

# OpenAI API configuration for CLI usage
OPENAI_API_KEY=""
OPENAI_BASE_URL="https://api.openai.com/v1"
OPENAI_MODEL="gpt-4o-mini"
2 changes: 1 addition & 1 deletion .github/workflows/deploy.yml
@@ -10,7 +10,7 @@ on:
jobs:
  deploy:
    runs-on: ubuntu-latest

    if: github.repository_owner == 'ahmedkhaleel2004'
    # Add concurrency to prevent multiple deployments running at once
    concurrency:
      group: production
65 changes: 65 additions & 0 deletions backend/app/services/openai_service.py
@@ -0,0 +1,65 @@
import os

import openai
from dotenv import load_dotenv

load_dotenv()


class OpenAIService:
    def __init__(self):
        self.api_key = os.getenv("OPENAI_API_KEY")
        self.base_url = os.getenv("OPENAI_BASE_URL")
        # Fall back to the default model listed in .env.example if OPENAI_MODEL is unset
        self.model = os.getenv("OPENAI_MODEL", "gpt-4o-mini")
        self.client = openai.OpenAI(
            api_key=self.api_key,
            base_url=self.base_url,
        )

    def call_openai_api(self, system_prompt: str, data: dict) -> str:
        """
        Makes an API call to OpenAI and returns the response.

        Args:
            system_prompt (str): The instruction/system prompt
            data (dict): Dictionary of variables to format into the user message

        Returns:
            str: OpenAI's response text
        """
        # Format the user message
        user_message = self._format_user_message(data)

        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_message},
        ]

        try:
            response = self.client.chat.completions.create(
                model=self.model,
                messages=messages,
                max_tokens=4096,
                temperature=0,
            )
            return response.choices[0].message.content
        except Exception as e:
            raise Exception(f"API call failed: {e}") from e

    def _format_user_message(self, data: dict[str, str]) -> str:
        """Helper method to format the data into a tagged user message."""
        parts = []
        for key, value in data.items():
            if key == 'file_tree':
                parts.append(f"<file_tree>\n{value}\n</file_tree>")
            elif key == 'readme':
                parts.append(f"<readme>\n{value}\n</readme>")
            elif key == 'explanation':
                parts.append(f"<explanation>\n{value}\n</explanation>")
            elif key == 'component_mapping':
                parts.append(f"<component_mapping>\n{value}\n</component_mapping>")
            elif key == 'instructions' and value != "":
                parts.append(f"<instructions>\n{value}\n</instructions>")
            elif key == 'diagram':
                parts.append(f"<diagram>\n{value}\n</diagram>")
        return "\n\n".join(parts)
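For orientation, here is a minimal usage sketch of the service above, assuming the OpenAI variables from .env.example are set; the prompt text and data values are placeholders for illustration, not part of this PR.

# Minimal sketch of calling the new service directly (placeholder data),
# assuming OPENAI_API_KEY (and optionally OPENAI_BASE_URL / OPENAI_MODEL) are set.
from app.services.openai_service import OpenAIService

service = OpenAIService()
explanation = service.call_openai_api(
    system_prompt="Explain the architecture of this repository.",  # placeholder prompt
    data={
        "file_tree": "backend/app/main.py\nbackend/cli/generate.py",  # placeholder tree
        "readme": "# Example project",                               # placeholder README
        "instructions": "",  # empty instructions are omitted from the message
    },
)
print(explanation)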
1 change: 1 addition & 0 deletions backend/cli/__init__.py
@@ -0,0 +1 @@
from .local_git import *
128 changes: 128 additions & 0 deletions backend/cli/generate.py
@@ -0,0 +1,128 @@
import argparse
import os
import sys

from app.prompts import (
    SYSTEM_FIRST_PROMPT,
    SYSTEM_SECOND_PROMPT,
    SYSTEM_THIRD_PROMPT,
    ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT,
)
from app.services.openai_service import OpenAIService
from cli import build_file_tree, get_readme, print_stat


def main():
    parser = argparse.ArgumentParser(
        description="Generate Mermaid diagrams from local Git repositories.")
    parser.add_argument("repo_path", help="Path to the local Git repository")
    parser.add_argument(
        "--instructions", help="Instructions for diagram generation", default=None)
    parser.add_argument(
        "--output", help="Output file for the Mermaid diagram", default="diagram.mmd")
    parser.add_argument(
        "--stat",
        help="Only output the file list and statistics",
        action="store_true"
    )

    args = parser.parse_args()

    repo_path = args.repo_path
    instructions = args.instructions
    output_file = args.output

    if not os.path.isdir(repo_path):
        print(f"Error: The path '{repo_path}' is not a valid directory.")
        sys.exit(1)

    openai_service = OpenAIService()

    if args.stat:
        print_stat(repo_path)
        return

    # Build the file tree and get the README
    file_tree = build_file_tree(repo_path)
    readme = get_readme(repo_path)

    if not file_tree and not readme:
        print("Error: The repository is empty or unreadable.")
        sys.exit(1)

    # build_file_tree returns a list; join it so the prompt receives plain text
    file_tree_text = "\n".join(file_tree)

    # Prepare system prompts with instructions if provided
    first_system_prompt = SYSTEM_FIRST_PROMPT
    third_system_prompt = SYSTEM_THIRD_PROMPT
    if instructions:
        first_system_prompt += "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
        third_system_prompt += "\n" + ADDITIONAL_SYSTEM_INSTRUCTIONS_PROMPT
    else:
        instructions = ""

    # Call the OpenAI API to get the explanation
    try:
        explanation = openai_service.call_openai_api(
            system_prompt=first_system_prompt,
            data={
                "file_tree": file_tree_text,
                "readme": readme,
                "instructions": instructions
            },
        )
    except Exception as e:
        print(f"Error generating explanation: {e}")
        sys.exit(1)

    if "BAD_INSTRUCTIONS" in explanation:
        print("Error: Invalid or unclear instructions provided.")
        sys.exit(1)

    # Call the API to get the component mapping
    try:
        full_second_response = openai_service.call_openai_api(
            system_prompt=SYSTEM_SECOND_PROMPT,
            data={
                "explanation": explanation,
                "file_tree": file_tree_text
            }
        )
    except Exception as e:
        print(f"Error generating component mapping: {e}")
        sys.exit(1)

    # Extract the component mapping from the response
    start_tag = "<component_mapping>"
    end_tag = "</component_mapping>"
    start_index = full_second_response.find(start_tag)
    end_index = full_second_response.find(end_tag)
    if start_index == -1 or end_index == -1:
        # str.find returns -1 when a tag is missing, so slicing alone would not fail
        print("Error extracting component mapping.")
        sys.exit(1)
    component_mapping_text = full_second_response[start_index:end_index + len(end_tag)]

    # Call the API to get the Mermaid diagram
    try:
        mermaid_code = openai_service.call_openai_api(
            system_prompt=third_system_prompt,
            data={
                "explanation": explanation,
                "component_mapping": component_mapping_text,
                "instructions": instructions
            }
        )
    except Exception as e:
        print(f"Error generating Mermaid diagram: {e}")
        sys.exit(1)

    if "BAD_INSTRUCTIONS" in mermaid_code:
        print("Error: Invalid or unclear instructions provided.")
        sys.exit(1)

    # Save the diagram to the output file
    try:
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(mermaid_code)
        print(f"Mermaid diagram generated and saved to '{output_file}'.")
    except Exception as e:
        print(f"Error saving diagram: {e}")
        sys.exit(1)


if __name__ == "__main__":
    main()
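As a quick illustration of how the new CLI is driven, here is a sketch that invokes the entry point programmatically, assuming it runs from the backend/ directory with the OpenAI environment variables configured; the repository path is hypothetical.

# Rough sketch, equivalent to running `python -m cli.generate ./my-repo --stat`
# from the backend/ directory; ./my-repo is a hypothetical local repository.
import sys

from cli.generate import main

sys.argv = ["generate", "./my-repo", "--stat"]  # print only the file list and statistics
main()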
76 changes: 76 additions & 0 deletions backend/cli/local_git.py
@@ -0,0 +1,76 @@
import os
from collections import Counter


def build_file_tree(repo_path):
    """
    Traverse the local repository and build a file tree list.
    """
    excluded_patterns = [
        'node_modules', 'vendor', 'venv',
        '.min.', '.pyc', '.pyo', '.pyd', '.so', '.dll', '.class', ".o",
        '.jpg', '.jpeg', '.png', '.gif', '.ico', '.svg', '.ttf', '.woff', '.webp',
        '.pdf', '.xml', '.wav', '.doc', '.docx', '.xls', '.xlsx', '.ppt', '.pptx', ".txt", ".log",
        '__pycache__', '.cache', '.tmp',
        'yarn.lock', 'poetry.lock',
        '.vscode', '.idea', '.git', "test", "activate"
    ]

Reviewer comment:
I think it would be an idea to respect the local .gitignore contents as well.

Author reply:
Thank you for the feedback! 👍 I’ve gone ahead and updated the code to respect the local .gitignore contents.


    file_paths = []
    for root, dirs, files in os.walk(repo_path):
        # Modify dirs in-place to skip excluded directories
        dirs[:] = [d for d in dirs if not any(
            excl in d.lower() for excl in excluded_patterns)]
        for file in files:
            file_path = os.path.join(root, file)
            relative_path = os.path.relpath(file_path, repo_path)
            if not any(excl in relative_path.lower() for excl in excluded_patterns):
                # Normalize separators for Windows compatibility
                file_paths.append(relative_path.replace("\\", "/"))

    return file_paths  # Return a list instead of a string for easier processing


def get_readme(repo_path):
    """
    Fetch the README content from the local repository.
    """
    readme_path = os.path.join(repo_path, "README.md")
    if os.path.exists(readme_path):
        with open(readme_path, 'r', encoding='utf-8') as f:
            return f.read()
    return ""


def analyze_extension_percentage(file_paths):
    """
    Analyze the percentage distribution of file extensions in the provided file list.

    Args:
        file_paths (list): List of file paths.

    Returns:
        dict: Dictionary mapping file extensions to their percentage occurrence.
    """
    extensions = [os.path.splitext(file)[1].lower()
                  for file in file_paths if os.path.splitext(file)[1]]
    total = len(extensions)
    if total == 0:
        return {}
    counts = Counter(extensions)
    percentages = {ext: (count / total) * 100 for ext, count in counts.items()}

    sorted_percentages = dict(
        sorted(percentages.items(), key=lambda item: item[1], reverse=True))
    return sorted_percentages


def print_stat(repo_path):
    """Print the filtered file list and the extension percentage distribution."""
    file_list = build_file_tree(repo_path)
    for f in file_list:
        print(f)
    extension_percentages = analyze_extension_percentage(file_list)

    print("File Extension Percentage Distribution:")
    for ext, percent in extension_percentages.items():
        print(f"{ext or 'No Extension'}: {percent:.2f}%")
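Following up on the review exchange above about respecting .gitignore: the code at this commit still filters only on the hard-coded excluded_patterns. A rough sketch of the suggested idea, not the PR's actual follow-up implementation, could read the top-level .gitignore and filter with stdlib fnmatch, which only approximates real gitignore semantics.

# Illustrative sketch only (not the PR's implementation): approximate .gitignore
# filtering with fnmatch. Real gitignore matching is more involved (negations,
# anchored patterns, directory-only rules, nested .gitignore files).
import fnmatch
import os


def load_gitignore_patterns(repo_path):
    """Collect simple patterns from the top-level .gitignore, skipping comments and negations."""
    gitignore_path = os.path.join(repo_path, ".gitignore")
    patterns = []
    if os.path.exists(gitignore_path):
        with open(gitignore_path, "r", encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#") and not line.startswith("!"):
                    patterns.append(line.rstrip("/"))
    return patterns


def is_ignored(relative_path, patterns):
    """Return True if the path, or any of its components, matches a collected pattern."""
    parts = relative_path.split("/")
    return any(
        fnmatch.fnmatch(candidate, pattern)
        for pattern in patterns
        for candidate in [relative_path, *parts]
    )

With such helpers, build_file_tree could skip any relative_path for which is_ignored(relative_path, patterns) is true, in addition to the existing excluded_patterns check.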
1 change: 1 addition & 0 deletions backend/requirements.txt
@@ -48,3 +48,4 @@ uvloop==0.21.0
watchfiles==1.0.3
websockets==14.1
wrapt==1.17.0
openai==1.58.1