-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy pathutils.py
132 lines (109 loc) · 4.41 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
import base64
import httpx
from dotenv import load_dotenv
import jwt
import os
import time
# Populate os.environ from a .env file via python-dotenv before the settings
# below are read.
load_dotenv()

# GitHub App identifier used as the JWT "iss" claim; None when APP_ID is unset.
APP_ID = os.getenv("APP_ID")

# PEM private key used to sign app JWTs; "" when PRIVATE_KEY is unset.
# Alternative (not active): load it from a local 'private-key.pem' file instead.
PRIVATE_KEY = os.getenv("PRIVATE_KEY", "")
def generate_jwt():
    """Create a short-lived RS256 JWT that authenticates as the GitHub App.

    Claims: ``iss`` is ``APP_ID``; ``iat`` is backdated 60 seconds and ``exp``
    is ``iat`` + 10 minutes. Backdating follows GitHub's advice for tolerating
    clock drift, and keeping ``exp`` 60 s short of ``now + 10 min`` avoids the
    "exp too far in the future" rejection when the local clock runs ahead.

    Returns:
        The encoded token, as returned by ``jwt.encode``.

    Raises:
        ValueError: if ``PRIVATE_KEY`` is empty or ``APP_ID`` is unset/empty.
    """
    if not PRIVATE_KEY:
        raise ValueError("PRIVATE_KEY not found.")
    if not APP_ID:
        # Without an issuer GitHub rejects the token; fail early and clearly.
        raise ValueError("APP_ID not found.")
    # Read the clock once so both timestamps derive from the same instant.
    now = int(time.time())
    payload = {
        "iat": now - 60,
        "exp": now + 9 * 60,
        "iss": APP_ID,
    }
    return jwt.encode(payload, PRIVATE_KEY, algorithm="RS256")
async def get_installation_access_token(jwt, installation_id):
    """Exchange a GitHub App JWT for an installation access token.

    Sends ``POST /app/installations/{installation_id}/access_tokens``.

    Args:
        jwt: Signed GitHub App JWT, sent as a Bearer token. NOTE: this
            parameter shadows the ``jwt`` module inside this function; the
            name is kept so keyword callers keep working.
        installation_id: ID of the app installation to mint a token for.

    Returns:
        The ``"token"`` field of the JSON response.

    Raises:
        httpx.HTTPStatusError: if GitHub answers with a 4xx/5xx status
            (e.g. an invalid or expired JWT, unknown installation).
    """
    url = f"https://api.github.com/app/installations/{installation_id}/access_tokens"
    headers = {
        "Authorization": f"Bearer {jwt}",
        "Accept": "application/vnd.github.v3+json",
    }
    async with httpx.AsyncClient() as client:
        response = await client.post(url, headers=headers)
    # Error bodies carry "message" rather than "token"; report the HTTP failure
    # itself instead of an opaque KeyError on the missing field.
    response.raise_for_status()
    return response.json()["token"]
def get_diff_url(pr):
    """Return the raw ``.diff`` URL for the pull request described by *pr*.

    GitHub 302s to this URL (presumably from the PR's ``diff_url``), so it is
    built directly to skip the redirect.

    Args:
        pr: dict whose ``"url"`` ends in ``.../{owner}/{repo}/<segment>/{number}``.

    Returns:
        ``https://patch-diff.githubusercontent.com/raw/{owner}/{repo}/pull/{number}.diff``
    """
    api_url = pr.get("url")
    segments = api_url.split("/")
    owner = segments[-4]
    repo = segments[-3]
    number = segments[-1]
    return (
        "https://patch-diff.githubusercontent.com/raw/"
        f"{owner}/{repo}/pull/{number}.diff"
    )
async def get_branch_files(pr, branch, headers):
    """Fetch the text of every file on *branch* of the PR's repository.

    Lists the branch with the Git Trees API (``recursive=1``), then fetches
    each blob entry individually and base64-decodes its ``content``.

    Args:
        pr: dict whose ``"url"`` ends in ``.../{owner}/{repo}/<segment>/{number}``.
        branch: Branch name placed in the ``git/trees/{branch}`` request path.
        headers: HTTP headers (e.g. authorization) sent with every request.

    Returns:
        dict mapping each blob's repository path to its decoded text. Bytes
        that are not valid UTF-8 (e.g. binary files) become U+FFFD instead of
        raising, so one binary file no longer aborts the whole fetch. A tree
        response without a ``"tree"`` key yields an empty dict.
    """
    parts = pr.get("url").split("/")
    owner, repo = parts[-4], parts[-3]
    url = f"https://api.github.com/repos/{owner}/{repo}/git/trees/{branch}?recursive=1"
    files = {}
    async with httpx.AsyncClient() as client:
        response = await client.get(url, headers=headers)
        tree = response.json().get("tree", [])
        for item in tree:
            # Only blob entries carry file content; everything else is skipped.
            if item["type"] != "blob":
                continue
            file_response = await client.get(item["url"], headers=headers)
            content = file_response.json().get("content", "")
            # errors="replace" keeps non-UTF-8 / binary blobs from raising
            # UnicodeDecodeError; the path stays present in the result.
            files[item["path"]] = base64.b64decode(content).decode(
                "utf-8", errors="replace"
            )
    return files
async def get_pr_head_branch(pr, headers):
    """Look up the head branch name of the pull request described by *pr*.

    Re-fetches ``GET /repos/{owner}/{repo}/pulls/{number}`` and reads
    ``head.ref`` from the JSON body.

    Args:
        pr: dict whose ``"url"`` ends in ``.../{owner}/{repo}/<segment>/{number}``.
        headers: HTTP headers sent with the request.

    Returns:
        The head ref name; ``''`` when the payload has no ``head.ref``, or when
        the status is not 200 (in which case status and body are printed).
    """
    segments = pr.get("url").split("/")
    owner = segments[-4]
    repo = segments[-3]
    number = segments[-1]
    endpoint = f"https://api.github.com/repos/{owner}/{repo}/pulls/{number}"

    async with httpx.AsyncClient() as client:
        response = await client.get(endpoint, headers=headers)

    if response.status_code == 200:
        return response.json().get('head', {}).get('ref', '')

    # Non-200: report the failure and fall back to an empty branch name.
    print(f"Error: Received status code {response.status_code}")
    print("Response body:", response.text)
    return ''
def files_to_diff_dict(diff):
    """Collect the added lines of a ``git diff``-style patch, grouped by file.

    Args:
        diff: Patch text made of ``diff --git a/<path> b/<path>`` sections.

    Returns:
        dict mapping each file path to ``{"text": [...]}``. The list holds the
        content of every added (``+``) hunk line, in order, without the leading
        ``+``. Every file in the diff gets an entry, even with no added lines.
        Paths come from the ``a/`` side of the header (NOTE(review): for
        renames this is the old path). Paths containing spaces are not
        supported.
    """
    files_with_diff = {}
    current_file = None
    # True between a "@@" hunk header and the next "diff --git" line. The
    # "+++ b/<path>" file header always precedes the first hunk, so tracking this
    # keeps added lines whose content starts with "++" (diff line "+++...")
    # instead of mistaking them for the header.
    in_hunk = False
    for line in diff.split("\n"):
        if line.startswith("diff --git"):
            current_file = line.split(" ")[2][2:]
            files_with_diff[current_file] = {"text": []}
            in_hunk = False
        elif line.startswith("@@"):
            # A hunk before any file header has nowhere to go; ignore it.
            in_hunk = current_file is not None
        elif in_hunk and line.startswith("+"):
            files_with_diff[current_file]["text"].append(line[1:])
    return files_with_diff
def parse_diff_to_line_numbers(diff):
    """Map each file in a ``git diff``-style patch to its added line indices.

    Args:
        diff: Patch text made of ``diff --git a/<path> b/<path>`` sections with
            ``@@ -a,b +c,d @@`` hunks.

    Returns:
        dict mapping each file path (taken from the ``a/`` side of the header)
        to a list of 0-based line indices, in the new version of the file, of
        every added (``+``) line. Files without added lines map to ``[]``.
    """
    files_with_line_numbers = {}
    current_file = None
    line_number = 0
    # True between a "@@" hunk header and the next "diff --git" line. Outside a
    # hunk, lines are file headers ("index", "--- a/", "+++ b/", ...). Inside one,
    # every "+" line is an addition, even when its content starts with "++".
    in_hunk = False
    for line in diff.split("\n"):
        if line.startswith("diff --git"):
            current_file = line.split(" ")[2][2:]
            files_with_line_numbers[current_file] = []
            line_number = 0
            in_hunk = False
        elif line.startswith("@@"):
            # "+c,d" gives the 1-based start line in the new file; store it 0-based.
            line_number = int(line.split(" ")[2].split(",")[0][1:]) - 1
            in_hunk = current_file is not None
        elif not in_hunk or line.startswith("\\"):
            # Skip file headers, and "\ No newline at end of file" markers,
            # which are not file lines and must not advance the counter.
            continue
        elif line.startswith("+"):
            files_with_line_numbers[current_file].append(line_number)
            line_number += 1
        elif not line.startswith("-"):
            # Context line: present in the new file. Removed lines are not.
            line_number += 1
    return files_with_line_numbers
def get_context_from_files(files, files_with_line_numbers, context_lines=2):
    """Extract a window of surrounding lines around each given line index.

    Args:
        files: dict mapping file path to the full text of that file.
        files_with_line_numbers: dict mapping file path to a list of 0-based
            line indices into that file.
        context_lines: Number of lines to include before and after each index.

    Returns:
        dict with the same keys as *files_with_line_numbers*. Each value is a
        list with one snippet per index: lines
        ``index - context_lines`` through ``index + context_lines``, clamped
        to the file bounds and joined with newlines.

    Raises:
        KeyError: if a path has line indices but is missing from *files*.
    """
    context_data = {}
    for file, lines in files_with_line_numbers.items():
        if not lines:
            # Nothing to extract. Such files (e.g. deletions or binary changes)
            # may be absent from ``files``, so don't look them up.
            context_data[file] = []
            continue
        file_content = files[file].split("\n")
        snippets = []
        for line in lines:
            start = max(line - context_lines, 0)
            end = min(line + context_lines + 1, len(file_content))
            snippets.append("\n".join(file_content[start:end]))
        context_data[file] = snippets
    return context_data