Skip to content

Commit

Permalink
move gen_aten and gen_aten_hip into shared build structure
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#77751

This requires two changes to rule generation:
 * pulling the cpu static dispatch prediction into the rules
 * disabling the Bazel-style generated file aliases

Differential Revision: [D36481918](https://our.internmc.facebook.com/intern/diff/D36481918/)

**NOTE FOR REVIEWERS**: This PR has internal Facebook specific changes or comments, please review them on [Phabricator](https://our.internmc.facebook.com/intern/diff/D36481918/)!

Approved by: https://github.com/kit1980, https://github.com/seemethere
  • Loading branch information
Michael Andreas Dagitses authored and pytorchmergebot committed Jun 15, 2022
1 parent b9bb52d commit eb5751d
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 0 deletions.
61 changes: 61 additions & 0 deletions build.bzl
Original file line number Diff line number Diff line change
@@ -1,3 +1,10 @@
load(
":ufunc_defs.bzl",
"aten_ufunc_generated_cpu_kernel_sources",
"aten_ufunc_generated_cpu_sources",
"aten_ufunc_generated_cuda_sources",
)

def define_targets(rules):
rules.cc_library(
name = "caffe2_serialize",
Expand All @@ -22,6 +29,60 @@ def define_targets(rules):
],
)

#
# ATen generated code
# You need to keep this is sync with the files written out
# by gen.py (in the cmake build system, we track generated files
# via generated_cpp.txt and generated_cpp.txt-cuda
#
# Sure would be nice to use gen.py to create this list dynamically
# instead of hardcoding, no? Well, we can't, as discussed in this
# thread:
# https://fb.facebook.com/groups/askbuck/permalink/1924258337622772/

gen_aten_srcs = [
"aten/src/ATen/native/native_functions.yaml",
"aten/src/ATen/native/tags.yaml",
] + rules.glob(["aten/src/ATen/templates/*"])

gen_aten_cmd = " ".join([
"$(location //torchgen:gen)",
"--install_dir=$(RULEDIR)",
"--source-path aten/src/ATen",
] + (["--static_dispatch_backend CPU"] if rules.is_cpu_static_dispatch_build() else []))

gen_aten_outs_cuda = (
GENERATED_H_CUDA + GENERATED_CPP_CUDA +
aten_ufunc_generated_cuda_sources()
)

gen_aten_outs = (
GENERATED_H + GENERATED_H_CORE +
GENERATED_CPP + GENERATED_CPP_CORE +
aten_ufunc_generated_cpu_sources() +
aten_ufunc_generated_cpu_kernel_sources() + [
"Declarations.yaml",
] + gen_aten_outs_cuda
)

rules.genrule(
name = "gen_aten",
srcs = gen_aten_srcs,
tools = ["//torchgen:gen"],
outs = gen_aten_outs,
cmd = gen_aten_cmd,
)

rules.genrule(
name = "gen_aten_hip",
srcs = gen_aten_srcs,
tools = ["//torchgen:gen"],
outs = gen_aten_outs_cuda,
cmd = gen_aten_cmd + " --rocm",
features = ["-create_bazel_outputs"],
tags = ["-bazel"],
)

rules.genrule(
name = "generate-code",
srcs = [
Expand Down
4 changes: 4 additions & 0 deletions tools/bazel.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,9 @@ def _genrule(**kwds):
if _enabled(**kwds):
native.genrule(**kwds)

def _is_cpu_static_dispatch_build():
return False

def _py_library(name, **kwds):
deps = [dep for dep in kwds.pop("deps", []) if dep != None]
native.py_library(name = name, deps = deps, **kwds)
Expand All @@ -26,6 +29,7 @@ rules = struct(
genrule = _genrule,
glob = native.glob,
if_cuda = if_cuda,
is_cpu_static_dispatch_build = _is_cpu_static_dispatch_build,
py_binary = native.py_binary,
py_library = _py_library,
requirement = _requirement,
Expand Down

0 comments on commit eb5751d

Please sign in to comment.