-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathheapify.py
139 lines (102 loc) · 3.69 KB
/
heapify.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
import ast
import sys

from constants import BUILTIN_FUNCS
class FreeVarFinder(ast.NodeVisitor):
    """Collect the names that occur free in any lambda of a tree.

    Every ``ast.Lambda`` in the visited tree, nested ones included, is scanned
    with ``LocalFreeVarFinder``. The union of the free names each scan reports
    is accumulated in ``f_vars`` (presumably the variables some closure
    captures).
    """

    def __init__(self):
        # Union of free names over all lambdas visited so far.
        self.f_vars = set()

    def visit_Lambda(self, lda):
        """Scan lambdas nested in ``lda`` first, then ``lda`` itself."""
        self.generic_visit(lda)
        scanner = LocalFreeVarFinder()
        scanner.visit(lda)
        self.f_vars |= scanner.free_vars()

    def free_vars(self):
        """Return a copy of the free names collected so far."""
        # Return a copy, as LocalFreeVarFinder's accessors do, so callers
        # cannot mutate this visitor's internal state.
        return set(self.f_vars)
class LocalFreeVarFinder(ast.NodeVisitor):
    """Classify the names used in a lambda as parameters, locals, or free.

    After visiting an ``ast.Lambda``:

    * ``param_vars()`` holds the positional parameter names and the
      ``*vararg`` name of every lambda encountered. Nested lambdas are
      included, because the visit descends into them.
    * ``local_vars()`` holds every name that appears in ``Store`` context.
    * ``free_vars()`` holds names loaded while not yet known as a parameter,
      a local, or a key of ``BUILTIN_FUNCS``. A later assignment to such a
      name removes it again.
    """

    def __init__(self):
        self.f_vars = set()  # free names
        self.l_vars = set()  # assigned (local) names
        self.p_vars = set()  # parameter names

    def visit_Lambda(self, lda):
        """Record ``lda``'s parameters, then scan its arguments and body."""
        self.p_vars.update(a.arg for a in lda.args.args)
        if lda.args.vararg:
            self.p_vars.add(lda.args.vararg.arg)
        self.generic_visit(lda)

    def visit_Name(self, name):
        """Classify one occurrence of a name according to its context."""
        if isinstance(name.ctx, ast.Load):
            # Test each source directly rather than building their union
            # (plus a copy of BUILTIN_FUNCS' keys) for every Name node.
            if (name.id not in self.l_vars
                    and name.id not in self.p_vars
                    and name.id not in BUILTIN_FUNCS):
                self.f_vars.add(name.id)
        elif isinstance(name.ctx, ast.Store):
            # Any assignment makes the name local, even if an earlier load
            # had already recorded it as free.
            self.l_vars.add(name.id)
            self.f_vars.discard(name.id)

    def free_vars(self):
        """Return a copy of the free names found."""
        return set(self.f_vars)

    def param_vars(self):
        """Return a copy of the parameter names found."""
        return set(self.p_vars)

    def local_vars(self):
        """Return a copy of the assigned names found."""
        return set(self.l_vars)
def free_vars(n):
    """Return the names that occur free in some lambda anywhere inside ``n``."""
    finder = FreeVarFinder()
    finder.visit(n)
    found = finder.free_vars()
    return found
def _int_literal(value):
    """Return an AST node for the integer literal ``value``.

    Since Python 3.8, ``ast.Num(v)`` is a deprecated alias that already
    returns ``ast.Constant(v)``, and Python 3.14 removed it. Older
    interpreters still need ``ast.Num``.
    """
    if sys.version_info >= (3, 8):
        return ast.Constant(value)
    return ast.Num(value)


def _index(node):
    """Wrap ``node`` as a subscript index for the running interpreter.

    Since Python 3.9, ``Subscript.slice`` holds the expression directly and
    ``ast.Index(x)`` merely returns ``x``. Earlier versions need the wrapper.
    """
    if sys.version_info >= (3, 9):
        return node
    return ast.Index(node)


class Heapifier(ast.NodeTransformer):
    """Move every variable in ``free_vars`` into a one-element list ("box").

    Every ``Name`` whose id is in ``free_vars`` becomes ``id[0]`` with its
    original load/store context. Each lambda then initialises boxes at the
    start of its body:

    * heapified parameters are renamed to ``<name>_h`` and boxed with
      ``name = [name_h]``;
    * heapified names that ``LocalFreeVarFinder`` reports as assigned (and
      that are not such a parameter) are boxed with ``name = [0]``.

    ``visit_Lambda`` prepends to ``lda.body`` as a list. That means the
    lambdas handled here are assumed to carry statement-list bodies, not
    Python's single-expression lambda body. NOTE(review): confirm against
    the pass that produces them.
    """

    def __init__(self, free_vars):
        # Names to heapify; presumably the result of ``free_vars(tree)``.
        self.free_vars = free_vars

    def visit_Lambda(self, lda):
        """Rewrite ``lda``'s body and box its heapified params and locals."""
        # Scan before transforming children: afterwards the heapified names
        # are wrapped in Subscript nodes.
        visitor = LocalFreeVarFinder()
        visitor.visit(lda)
        lda = self.generic_visit(lda)
        lda.free_vars = visitor.free_vars()

        # Original (pre-rename) names of this lambda's params that get boxed.
        free_params = [
            arg.arg for arg in lda.args.args if arg.arg in self.free_vars
        ]
        for arg in lda.args.args:
            if arg.arg in free_params:
                arg.arg = self.heapified_param_name(arg.arg)
        vararg = lda.args.vararg
        if vararg and vararg.arg in self.free_vars:
            free_params.append(vararg.arg)
            vararg.arg = self.heapified_param_name(vararg.arg)

        param_boxes = [
            ast.Assign(
                [ast.Name(name, ast.Store())],
                ast.List([
                    ast.Name(self.heapified_param_name(name), ast.Load())
                ], ast.Load())
            )
            for name in free_params
        ]
        # A parameter that is also assigned in the body is already boxed
        # above; boxing it again with [0] would discard the argument value.
        # Sorted so the generated code does not depend on set ordering.
        boxed_params = set(free_params)
        local_boxes = [
            ast.Assign(
                [ast.Name(name, ast.Store())],
                ast.List([_int_literal(0)], ast.Load())
            )
            for name in sorted(visitor.local_vars())
            if name in self.free_vars and name not in boxed_params
        ]
        lda.body = param_boxes + local_boxes + lda.body
        return lda

    def visit_Name(self, name):
        """Rewrite a heapified ``x`` as ``x[0]``, keeping ``name.ctx``."""
        if name.id not in self.free_vars:
            return name
        boxed = ast.Subscript(
            ast.Name(name.id, ast.Load()),  # the box itself is only read
            _index(_int_literal(0)),
            name.ctx,
        )
        return ast.copy_location(boxed, name)

    def heapified_param_name(self, pid):
        """Return the name that a boxed parameter ``pid`` is renamed to."""
        return pid + '_h'
def heapify_free_vars(node):
    """Box every lambda-captured variable of ``node`` and return the result."""
    captured = free_vars(node)
    transformer = Heapifier(captured)
    return transformer.visit(node)
# def heapify(node):
# fname = 'heapify_{}'.format(node.__class__.__name__)
# func = globals.get(fname, heapify_error)
# vars_to_heapify = free_vars(node)
# return func(node, vars_to_heapify)
#
#
# def heapify_error(node):
# context = ""
# if hasattr(node, 'lineno') and hasattr(node, 'col_offset'):
# context = "@({}:{})".format(node.lineno, node.col_offset)
# detail = ast.dump(node) if isinstance(node, ast.AST) else str(node)
# raise NotImplementedError("could not heapify {}{}: {}".format(
# node.__class__, context, detail
# ))