forked from google/struct2tensor
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtf_configure.bzl
223 lines (194 loc) · 6.99 KB
/
tf_configure.bzl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
# Copyright 2019 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Setup TensorFlow as external dependency"""
# Names of the environment variables read by _tf_pip_impl.
# Directory whose files are copied under "include" by the header genrule.
_TF_HEADER_DIR = "TF_HEADER_DIR"
# Directory that contains the shared library.
_TF_SHARED_LIBRARY_DIR = "TF_SHARED_LIBRARY_DIR"
# File name of the shared library inside _TF_SHARED_LIBRARY_DIR.
_TF_SHARED_LIBRARY_NAME = "TF_SHARED_LIBRARY_NAME"
def _tpl(repository_ctx, tpl, substitutions = {}, out = None):
    """Expands a template from the //tf package into the repository.

    Args:
      repository_ctx: the repository_ctx object.
      tpl: string, template base name; the template label is built as
        "//tf:<tpl>.tpl".
      substitutions: dict, substitutions handed to repository_ctx.template.
      out: string, output path; when empty or None, `tpl` is used instead.
    """
    destination = out if out else tpl
    template_label = Label("//tf:%s.tpl" % tpl)
    repository_ctx.template(destination, template_label, substitutions)
def _fail(msg):
    """Aborts with a configuration error whose banner is printed in red.

    Args:
      msg: the failure description appended after the banner.
    """
    ansi_red = "\033[0;31m"
    ansi_reset = "\033[0m"
    message = "%sPython Configuration Error:%s %s\n" % (ansi_red, ansi_reset, msg)
    fail(message)
def _is_windows(repository_ctx):
    """Reports whether the host operating system is Windows.

    Args:
      repository_ctx: the repository_ctx object.

    Returns:
      True if the lower-cased repository_ctx.os.name contains "windows",
      False otherwise.
    """
    return "windows" in repository_ctx.os.name.lower()
def _execute(
        repository_ctx,
        cmdline,
        error_msg = None,
        error_details = None,
        empty_stdout_fine = False):
    """Executes an arbitrary shell command and fails if it did not succeed.

    The command is treated as failed when it exits with a non-zero return
    code, writes anything to stderr, or produces no stdout while
    empty_stdout_fine is False.

    Args:
      repository_ctx: the repository_ctx object.
      cmdline: list of strings, the command to execute.
      error_msg: string, a summary of the error if the command fails.
      error_details: string, details about the error or steps to fix it.
      empty_stdout_fine: bool, if True, an empty stdout result is fine, otherwise
        it's an error.

    Returns:
      The result of repository_ctx.execute(cmdline).
    """
    result = repository_ctx.execute(cmdline)

    # A command can exit non-zero without writing to stderr, so the exit status
    # has to be checked in addition to the stderr content.
    command_failed = result.return_code != 0 or bool(result.stderr)
    missing_output = not (empty_stdout_fine or result.stdout)
    if command_failed or missing_output:
        lines = [error_msg.strip() if error_msg else "Repository command failed"]
        if result.return_code != 0:
            lines.append("Command exited with return code %d" % result.return_code)
        lines.append(result.stderr.strip())
        lines.append(error_details if error_details else "")
        _fail("\n".join(lines))
    return result
def _read_dir(repository_ctx, src_dir):
    """Lists every file below a directory, one path per line.

    On Windows the listing comes from `cmd.exe /c dir <dir> /b /s /a-d`, with
    backslashes in the output converted to forward slashes. Elsewhere it comes
    from `find <dir> -follow -type f` and is returned unmodified.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: directory to find files from.

    Returns:
      A string with the stdout of the listing command; empty output is
      accepted.
    """
    if not _is_windows(repository_ctx):
        listing = _execute(
            repository_ctx,
            ["find", src_dir, "-follow", "-type", "f"],
            empty_stdout_fine = True,
        )
        return listing.stdout

    # The `dir` command is handed a backslash-separated path.
    windows_dir = src_dir.replace("/", "\\")
    listing = _execute(
        repository_ctx,
        ["cmd.exe", "/c", "dir", windows_dir, "/b", "/s", "/a-d"],
        empty_stdout_fine = True,
    )

    # The listing is used in genrule.outs, where the paths must use forward
    # slashes.
    return listing.stdout.replace("\\", "/")
def _genrule(genrule_name, command, outs):
    """Renders the BUILD-file text of a genrule target.

    Args:
      genrule_name: string, a unique name for the genrule target.
      command: string, the command placed in the genrule's `cmd` attribute.
      outs: string, the already formatted entries placed between the brackets
        of the genrule's `outs` attribute.

    Returns:
      A string with the genrule definition.
    """
    pieces = [
        "genrule(\n",
        ' name = "',
        genrule_name,
        '",\n',
        " outs = [\n",
        outs,
        "\n ],\n",
        ' cmd = """\n',
        command,
        '\n """,\n',
        ")\n",
    ]
    return "".join(pieces)
def _norm_path(path):
    """Returns a path that uses '/' separators and has no trailing slash.

    Args:
      path: string, the path to normalize; may be empty.

    Returns:
      The path with every backslash replaced by '/' and a single trailing '/'
      removed if present. An empty path is returned unchanged.
    """
    path = path.replace("\\", "/")

    # endswith() is safe on an empty string, whereas indexing path[-1] is not.
    if path.endswith("/"):
        path = path[:-1]
    return path
def _symlink_genrule_for_dir(
        repository_ctx,
        src_dir,
        dest_dir,
        genrule_name,
        src_files = [],
        dest_files = []):
    """Returns a genrule that copies a set of files with `cp -f`.

    When src_dir is not None, the files are listed from that directory and the
    src_files/dest_files arguments are ignored; otherwise src_files and
    dest_files are used as given.

    Args:
      repository_ctx: the repository_ctx object.
      src_dir: source directory, or None to use src_files and dest_files.
      dest_dir: directory prefix of the genrule outputs.
      genrule_name: genrule name.
      src_files: list of source files instead of src_dir.
      dest_files: list of corresponding destination files.

    Returns:
      A string with the genrule target that performs the copies.
    """
    if src_dir != None:
        src_dir = _norm_path(src_dir)
        dest_dir = _norm_path(dest_dir)
        listing = "\n".join(sorted(_read_dir(repository_ctx, src_dir).splitlines()))
        src_files = listing.splitlines()

        # Outputs are the listed paths with src_dir stripped out.
        dest_files = listing.replace(src_dir, "").splitlines()

    # With exactly one file, dest_dir is left out of the copy target because
    # $(@D) will include the full path to the file.
    single_file = len(dest_files) == 1
    copy_tool = "cp -f"
    copy_commands = []
    out_entries = []
    for index, relative_path in enumerate(dest_files):
        if relative_path == "":
            continue
        if single_file:
            target = "$(@D)/" + relative_path
        else:
            target = "$(@D)/" + dest_dir + relative_path

        # Copy the files to create a sandboxable setup.
        copy_commands.append(copy_tool + ' "%s" "%s"' % (src_files[index], target))
        out_entries.append(' "' + dest_dir + relative_path + '",')

    return _genrule(
        genrule_name,
        " && ".join(copy_commands),
        "\n".join(out_entries),
    )
def _tf_pip_impl(repository_ctx):
    """Implementation of the tf_configure repository rule.

    Reads the header directory and the shared library location from the
    environment, builds one genrule that copies the headers under "include"
    and one that copies the shared library to "libtensorflow_framework.so",
    and expands the BUILD template with both genrules.

    Args:
      repository_ctx: the repository_ctx object.
    """
    environ = repository_ctx.os.environ

    # Report every missing variable at once instead of aborting on the first
    # failed dictionary lookup with an unexplained key error.
    required = [_TF_HEADER_DIR, _TF_SHARED_LIBRARY_DIR, _TF_SHARED_LIBRARY_NAME]
    missing = [name for name in required if name not in environ]
    if missing:
        _fail("Required environment variable(s) not set: %s" % ", ".join(missing))

    tf_header_dir = environ[_TF_HEADER_DIR]
    tf_header_rule = _symlink_genrule_for_dir(
        repository_ctx,
        tf_header_dir,
        "include",
        "tf_header_include",
    )

    tf_shared_library_dir = environ[_TF_SHARED_LIBRARY_DIR]
    tf_shared_library_name = environ[_TF_SHARED_LIBRARY_NAME]
    tf_shared_library_path = "%s/%s" % (tf_shared_library_dir, tf_shared_library_name)

    # src_dir is None here, so the explicit one-element source and destination
    # lists are used and the library is always exposed under the fixed name
    # "libtensorflow_framework.so".
    tf_shared_library_rule = _symlink_genrule_for_dir(
        repository_ctx,
        None,
        "",
        "libtensorflow_framework.so",
        [tf_shared_library_path],
        ["libtensorflow_framework.so"],
    )

    _tpl(repository_ctx, "BUILD", {
        "%{TF_HEADER_GENRULE}": tf_header_rule,
        "%{TF_SHARED_LIBRARY_GENRULE}": tf_shared_library_rule,
    })
# Repository rule implemented by _tf_pip_impl, which generates a BUILD file
# with genrules for the headers and the shared library located through the
# environment variables listed below.
tf_configure = repository_rule(
    implementation = _tf_pip_impl,
    # Environment variables read by the implementation.
    environ = [
        _TF_HEADER_DIR,
        _TF_SHARED_LIBRARY_DIR,
        _TF_SHARED_LIBRARY_NAME,
    ],
)