-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpr_dataset.py
104 lines (78 loc) · 2.92 KB
/
pr_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
#!/usr/bin/env python3
import json
import os
import re
from collections import defaultdict
from pathlib import Path
from typing import Any
from get_tokens_from_directory import get_tokens_from_file, should_process_file
from tokenizer import Tokenizer
from torch.utils.data import Dataset
# Header that opens each file section of a git patch: group "a" is the path
# before the change, group "b" the path after it.
DIFF_REGEX = re.compile(r"^diff --git a/(?P<a>.+) b/(?P<b>.+)$")
# Unified-diff hunk header "@@ -l[,s] +l[,s] @@ ...".  "begin" captures the
# start line on the "-" (old) side and "end" the start line on the "+" (new)
# side.  The ",s" line counts are optional: diff omits a count when that range
# spans exactly one line (e.g. "@@ -3 +3 @@"), so requiring them would silently
# drop such hunks.  ``match`` anchors at the start only, so any trailing
# section text after the closing "@@" is allowed.
LINENO_REGEX = re.compile(
    r"@@ -(?P<begin>\d+)(?:,\d+)? \+(?P<end>\d+)(?:,\d+)? @@"
)
class PyTorchPRDataset(Dataset):
    """Dataset whose items are the tokens of the files each pull request touches.

    Item ``idx`` is computed from the ``"patch"`` text of the ``idx``-th
    record in the pull-request JSON file. Every ``diff --git`` section whose
    file passes ``should_process_file`` is tokenized with
    ``get_tokens_from_file``, restricted to the line numbers taken from that
    section's hunk headers.
    """

    def __init__(self, pull_requests_file: str, pull_requests_dir: str):
        """Load the pull-request records.

        Args:
            pull_requests_file: Path to a JSON document with the PR records;
                each record is read with the ``"patch"`` and ``"number"`` keys.
            pull_requests_dir: Directory holding the PR files, looked up as
                ``<pull_requests_dir>/<number>/<path>``.
        """
        # NOTE(review): not used inside this class; kept for external callers.
        self.tokenizer = Tokenizer()
        # JSON is UTF-8 text; don't let the platform locale pick the codec.
        with open(pull_requests_file, encoding="utf-8") as f:
            self.pull_requests = json.load(f)
        self.dir = pull_requests_dir

    def __len__(self):
        """Return the number of pull-request records."""
        return len(self.pull_requests)

    def _tokenize_file(self, path: str, selected_lines: list):
        """Return what ``get_tokens_from_file`` extracts from *path*'s selected lines."""
        return get_tokens_from_file(
            Path(path),
            repo_dir=None,
            tests_only=False,
            selected_lines=selected_lines,
        )

    def __getitem__(self, idx):
        """Return the tokens of the files touched by pull request ``idx``.

        Args:
            idx: Position of the record in the loaded pull-request data.

        Returns:
            A ``defaultdict(list)`` into which the result of
            ``get_tokens_from_file`` for each processed file is merged with
            ``dict.update``.
        """
        record = self.pull_requests[idx]
        patch = record["patch"]
        pr_number = record["number"]

        all_tokens = defaultdict(list)
        current_file = None  # file whose hunks are being collected; None = skip
        selected_lines = []  # hunk line numbers collected for current_file

        for line in patch.split("\n"):
            mf = DIFF_REGEX.match(line)
            if mf:
                # A new file section starts: flush the previous one first.
                if current_file is not None:
                    # NOTE(review): ``update`` overwrites keys already set by an
                    # earlier file instead of extending their lists, so the
                    # defaultdict(list) default is never exercised; confirm.
                    all_tokens.update(
                        self._tokenize_file(current_file, selected_lines)
                    )
                # Reset unconditionally: if the previous section was skipped,
                # its hunks must not leak into the next file's selection.
                current_file = None
                selected_lines = []

                # NOTE(review): uses the pre-change path ("a"); for renamed
                # files the post-change path is mf["b"] -- confirm which one
                # the per-PR directory stores.
                candidate = f"{self.dir}/{pr_number}/{mf['a']}"
                if should_process_file(
                    os.path.basename(candidate),
                    os.path.dirname(candidate),
                    "",
                ):
                    current_file = candidate
                # Otherwise we're not interested in this file, i.e. cpp.
                continue

            if current_file is None:
                # Hunks of a skipped file (or before any header) are ignored.
                continue
            ml = LINENO_REGEX.match(line)
            if ml:
                # NOTE(review): stores (old-side start, new-side start) as
                # captured by LINENO_REGEX; confirm this is the range format
                # get_tokens_from_file expects for ``selected_lines``.
                selected_lines.append((int(ml["begin"]), int(ml["end"])))

        if current_file is not None:
            all_tokens.update(self._tokenize_file(current_file, selected_lines))
        return all_tokens
def parse_args() -> Any:
    """Parse this script's command-line options.

    Returns:
        An ``argparse.Namespace`` with ``input`` (path to the PR JSON file)
        and ``pr_dir`` (root of the per-PR file tree).

    Raises:
        SystemExit: If an option is missing or ``--help`` is requested.
    """
    from argparse import ArgumentParser

    # The text is a description; passed positionally it would become ``prog``
    # and replace the script name in the usage line.
    parser = ArgumentParser(description="GitHub PR tokenization")
    # Both options are mandatory: an omitted one would otherwise reach the
    # caller as None and only fail later when the path is opened.
    parser.add_argument(
        "--input", type=str, required=True, help="the input JSON file"
    )
    parser.add_argument(
        "--pr-dir", type=str, required=True, help="the PR data directory"
    )
    return parser.parse_args()
def main() -> None:
    """Print the token mapping of every pull request listed in ``--input``."""
    cli = parse_args()
    # NOTE(review): iteration presumably uses the legacy __getitem__ protocol
    # (indices 0, 1, ... until IndexError), since no __iter__ is defined here.
    for pr_tokens in PyTorchPRDataset(cli.input, cli.pr_dir):
        print(pr_tokens)


if __name__ == "__main__":
    main()