Skip to content

Commit

Permalink
Pass XlaOp by value
Browse files Browse the repository at this point in the history
As recommended by Google style guide.
  • Loading branch information
asuhan authored and dlibenzi committed Dec 6, 2019
1 parent 7712819 commit 85c1dee
Show file tree
Hide file tree
Showing 46 changed files with 362 additions and 450 deletions.
27 changes: 11 additions & 16 deletions torch_xla/csrc/batch_norm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
namespace torch_xla {
namespace {

xla::XlaOp VarianceRecover(const xla::XlaOp& invstd, float eps_value) {
xla::XlaOp VarianceRecover(xla::XlaOp invstd, float eps_value) {
xla::XlaBuilder* builder = invstd.builder();
const xla::Shape& invstd_shape = XlaHelpers::ShapeOfXlaOp(invstd);
xla::XlaOp eps =
Expand All @@ -18,8 +18,7 @@ xla::XlaOp VarianceRecover(const xla::XlaOp& invstd, float eps_value) {

} // namespace

xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance,
float eps_value) {
xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) {
xla::XlaBuilder* builder = variance.builder();
const xla::Shape& variance_shape = XlaHelpers::ShapeOfXlaOp(variance);
xla::XlaOp eps = XlaHelpers::ScalarValue(
Expand All @@ -29,10 +28,8 @@ xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance,
return one / xla::Sqrt(variance + eps);
}

BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& bias,
float eps_value) {
BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
xla::XlaOp bias, float eps_value) {
xla::XlaOp outputs =
xla::BatchNormTraining(input, weight, bias, eps_value, 1);
xla::XlaOp output = xla::GetTupleElement(outputs, 0);
Expand All @@ -41,19 +38,17 @@ BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input,
return {output, batch_mean, batch_variance};
}

xla::XlaOp BuildBatchNormInference(
const xla::XlaOp& input, const xla::XlaOp& weight, const xla::XlaOp& bias,
const xla::XlaOp& mean, const xla::XlaOp& variance, float eps_value) {
xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
xla::XlaOp bias, xla::XlaOp mean,
xla::XlaOp variance, float eps_value) {
return xla::BatchNormInference(input, weight, bias, mean, variance, eps_value,
1);
}

BatchNormGrads BuildBatchNormBackward(const xla::XlaOp& grad,
const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& save_mean,
const xla::XlaOp& save_invstd,
bool training, float eps_value) {
BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
xla::XlaOp weight, xla::XlaOp save_mean,
xla::XlaOp save_invstd, bool training,
float eps_value) {
xla::XlaOp grads = xla::BatchNormGrad(input, weight, save_mean,
VarianceRecover(save_invstd, eps_value),
grad, eps_value, 1);
Expand Down
31 changes: 13 additions & 18 deletions torch_xla/csrc/batch_norm.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,23 +16,18 @@ struct BatchNormGrads {
xla::XlaOp grad_bias;
};

xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance, float eps_value);

BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& bias, float eps_value);

xla::XlaOp BuildBatchNormInference(const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& bias,
const xla::XlaOp& mean,
const xla::XlaOp& variance, float eps_value);

BatchNormGrads BuildBatchNormBackward(const xla::XlaOp& grad,
const xla::XlaOp& input,
const xla::XlaOp& weight,
const xla::XlaOp& save_mean,
const xla::XlaOp& save_invstd,
bool training, float eps_value);
xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value);

BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight,
xla::XlaOp bias, float eps_value);

xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight,
xla::XlaOp bias, xla::XlaOp mean,
xla::XlaOp variance, float eps_value);

BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input,
xla::XlaOp weight, xla::XlaOp save_mean,
xla::XlaOp save_invstd, bool training,
float eps_value);

} // namespace torch_xla
11 changes: 5 additions & 6 deletions torch_xla/csrc/convert_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,14 @@
namespace torch_xla {
namespace {

xla::XlaOp ExplicitBooleanConvert(const xla::XlaOp& op,
xla::PrimitiveType from) {
xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) {
xla::XlaOp zero = xla::Zero(op.builder(), from);
return xla::Ne(op, zero);
}

} // namespace

xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from,
xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to, const Device* device) {
if (device == nullptr) {
device = GetDefaultDevice();
Expand Down Expand Up @@ -51,7 +50,7 @@ xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from,
}
}

xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from) {
xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
const Device* device = GetDefaultDevice();
return from != xla::PrimitiveType::PRED
? op
Expand All @@ -60,11 +59,11 @@ xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from) {
device);
}

xla::XlaOp ConvertToNumeric(const xla::XlaOp& op) {
xla::XlaOp ConvertToNumeric(xla::XlaOp op) {
return ConvertToNumeric(op, XlaHelpers::TypeOfXlaOp(op));
}

xla::XlaOp CastToScalarType(const xla::XlaOp& input,
xla::XlaOp CastToScalarType(xla::XlaOp input,
c10::optional<at::ScalarType> dtype) {
if (dtype) {
return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input),
Expand Down
8 changes: 4 additions & 4 deletions torch_xla/csrc/convert_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,16 @@

namespace torch_xla {

xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from,
xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to, const Device* device);

xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from);
xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from);

xla::XlaOp ConvertToNumeric(const xla::XlaOp& op);
xla::XlaOp ConvertToNumeric(xla::XlaOp op);

// Cast the input to the given dtype. If dtype is null, no-op with the exception
// of predicates, which are converted to 8-bit unsigned integers.
xla::XlaOp CastToScalarType(const xla::XlaOp& input,
xla::XlaOp CastToScalarType(xla::XlaOp input,
c10::optional<at::ScalarType> dtype);

} // namespace torch_xla
18 changes: 7 additions & 11 deletions torch_xla/csrc/convolution.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -190,8 +190,7 @@ std::vector<std::pair<xla::int64, xla::int64>> MakePadding(

// Computes the input gradient for a convolution.
xla::XlaOp BuildConvBackwardInput(
const xla::XlaOp& grad_output, const xla::XlaOp& kernel,
const xla::Shape& input_shape,
xla::XlaOp grad_output, xla::XlaOp kernel, const xla::Shape& input_shape,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_dilation,
Expand All @@ -209,8 +208,7 @@ xla::XlaOp BuildConvBackwardInput(

// Computes the kernel gradient for a convolution.
xla::XlaOp BuildConvBackwardWeight(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::Shape& kernel_shape,
xla::XlaOp grad_output, xla::XlaOp input, const xla::Shape& kernel_shape,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_stride,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_padding,
tensorflow::gtl::ArraySlice<const xla::int64> spatial_dilation,
Expand Down Expand Up @@ -247,7 +245,7 @@ xla::XlaOp BuildGradBias(xla::XlaOp grad_output) {
}

xla::XlaOp BuildTransposedConvolution(
const xla::XlaOp& input, const xla::XlaOp& kernel,
xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation,
Expand Down Expand Up @@ -275,8 +273,7 @@ xla::XlaOp BuildTransposedConvolution(
}

ConvGrads BuildTransposedConvolutionBackward(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& kernel,
xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation,
Expand All @@ -295,7 +292,7 @@ ConvGrads BuildTransposedConvolutionBackward(
} // namespace

xla::XlaOp BuildConvolutionOverrideable(
const xla::XlaOp& input, const xla::XlaOp& kernel,
xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand All @@ -320,7 +317,7 @@ xla::XlaOp BuildConvolutionOverrideable(
}

xla::XlaOp BuildConvolutionOverrideableBias(
const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias,
xla::XlaOp input, xla::XlaOp kernel, xla::XlaOp bias,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand All @@ -340,8 +337,7 @@ xla::XlaOp BuildConvolutionOverrideableBias(
}

ConvGrads BuildConvolutionBackwardOverrideable(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& kernel,
xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand Down
7 changes: 3 additions & 4 deletions torch_xla/csrc/convolution.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ namespace torch_xla {
// Computes the convolution of the given input and kernel with the given
// precision, with the given stride and padding.
xla::XlaOp BuildConvolutionOverrideable(
const xla::XlaOp& input, const xla::XlaOp& kernel,
xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand All @@ -17,7 +17,7 @@ xla::XlaOp BuildConvolutionOverrideable(

// Same as above, then broadcasts the bias and adds it to the result.
xla::XlaOp BuildConvolutionOverrideableBias(
const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias,
xla::XlaOp input, xla::XlaOp kernel, xla::XlaOp bias,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand All @@ -32,8 +32,7 @@ struct ConvGrads {

// Computes the gradients for a convolution with the given stride and padding.
ConvGrads BuildConvolutionBackwardOverrideable(
const xla::XlaOp& grad_output, const xla::XlaOp& input,
const xla::XlaOp& kernel,
xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel,
tensorflow::gtl::ArraySlice<const xla::int64> stride,
tensorflow::gtl::ArraySlice<const xla::int64> padding,
tensorflow::gtl::ArraySlice<const xla::int64> dilation, bool transposed,
Expand Down
5 changes: 2 additions & 3 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,8 @@ xla::XlaComputation GetReduceComutation(AllReduceType reduce_type,

std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type,
tensorflow::gtl::ArraySlice<const xla::XlaOp> operands,
const xla::XlaOp& token, double scale,
const std::vector<std::vector<xla::int64>>& groups) {
tensorflow::gtl::ArraySlice<const xla::XlaOp> operands, xla::XlaOp token,
double scale, const std::vector<std::vector<xla::int64>>& groups) {
std::vector<xla::ReplicaGroup> reduce_groups;
for (auto& group : groups) {
xla::ReplicaGroup rgroup;
Expand Down
5 changes: 2 additions & 3 deletions torch_xla/csrc/cross_replica_reduces.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ enum class AllReduceType {

std::vector<xla::XlaOp> BuildAllReduce(
AllReduceType reduce_type,
tensorflow::gtl::ArraySlice<const xla::XlaOp> operands,
const xla::XlaOp& token, double scale,
const std::vector<std::vector<xla::int64>>& groups);
tensorflow::gtl::ArraySlice<const xla::XlaOp> operands, xla::XlaOp token,
double scale, const std::vector<std::vector<xla::int64>>& groups);

} // namespace torch_xla
Loading

0 comments on commit 85c1dee

Please sign in to comment.