diff --git a/torch_xla/csrc/batch_norm.cpp b/torch_xla/csrc/batch_norm.cpp index d594b4fa9845..6410d7599d26 100644 --- a/torch_xla/csrc/batch_norm.cpp +++ b/torch_xla/csrc/batch_norm.cpp @@ -5,7 +5,7 @@ namespace torch_xla { namespace { -xla::XlaOp VarianceRecover(const xla::XlaOp& invstd, float eps_value) { +xla::XlaOp VarianceRecover(xla::XlaOp invstd, float eps_value) { xla::XlaBuilder* builder = invstd.builder(); const xla::Shape& invstd_shape = XlaHelpers::ShapeOfXlaOp(invstd); xla::XlaOp eps = @@ -18,8 +18,7 @@ xla::XlaOp VarianceRecover(const xla::XlaOp& invstd, float eps_value) { } // namespace -xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance, - float eps_value) { +xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value) { xla::XlaBuilder* builder = variance.builder(); const xla::Shape& variance_shape = XlaHelpers::ShapeOfXlaOp(variance); xla::XlaOp eps = XlaHelpers::ScalarValue( @@ -29,10 +28,8 @@ xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance, return one / xla::Sqrt(variance + eps); } -BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& bias, - float eps_value) { +BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight, + xla::XlaOp bias, float eps_value) { xla::XlaOp outputs = xla::BatchNormTraining(input, weight, bias, eps_value, 1); xla::XlaOp output = xla::GetTupleElement(outputs, 0); @@ -41,19 +38,17 @@ BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input, return {output, batch_mean, batch_variance}; } -xla::XlaOp BuildBatchNormInference( - const xla::XlaOp& input, const xla::XlaOp& weight, const xla::XlaOp& bias, - const xla::XlaOp& mean, const xla::XlaOp& variance, float eps_value) { +xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight, + xla::XlaOp bias, xla::XlaOp mean, + xla::XlaOp variance, float eps_value) { return xla::BatchNormInference(input, weight, bias, mean, variance, eps_value, 1); } -BatchNormGrads BuildBatchNormBackward(const xla::XlaOp& grad, - const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& save_mean, - const xla::XlaOp& save_invstd, - bool training, float eps_value) { +BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input, + xla::XlaOp weight, xla::XlaOp save_mean, + xla::XlaOp save_invstd, bool training, + float eps_value) { xla::XlaOp grads = xla::BatchNormGrad(input, weight, save_mean, VarianceRecover(save_invstd, eps_value), grad, eps_value, 1); diff --git a/torch_xla/csrc/batch_norm.h b/torch_xla/csrc/batch_norm.h index 04f6f23a95c8..1e91602acf8c 100644 --- a/torch_xla/csrc/batch_norm.h +++ b/torch_xla/csrc/batch_norm.h @@ -16,23 +16,18 @@ struct BatchNormGrads { xla::XlaOp grad_bias; }; -xla::XlaOp BatchNormVarianceInvert(const xla::XlaOp& variance, float eps_value); - -BatchNormOutput BuildBatchNormTraining(const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& bias, float eps_value); - -xla::XlaOp BuildBatchNormInference(const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& bias, - const xla::XlaOp& mean, - const xla::XlaOp& variance, float eps_value); - -BatchNormGrads BuildBatchNormBackward(const xla::XlaOp& grad, - const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& save_mean, - const xla::XlaOp& save_invstd, - bool training, float eps_value); +xla::XlaOp BatchNormVarianceInvert(xla::XlaOp variance, float eps_value); + +BatchNormOutput BuildBatchNormTraining(xla::XlaOp input, xla::XlaOp weight, 
+ xla::XlaOp bias, float eps_value); + +xla::XlaOp BuildBatchNormInference(xla::XlaOp input, xla::XlaOp weight, + xla::XlaOp bias, xla::XlaOp mean, + xla::XlaOp variance, float eps_value); + +BatchNormGrads BuildBatchNormBackward(xla::XlaOp grad, xla::XlaOp input, + xla::XlaOp weight, xla::XlaOp save_mean, + xla::XlaOp save_invstd, bool training, + float eps_value); } // namespace torch_xla diff --git a/torch_xla/csrc/convert_ops.cpp b/torch_xla/csrc/convert_ops.cpp index 07baee0a5bb0..149a941e575a 100644 --- a/torch_xla/csrc/convert_ops.cpp +++ b/torch_xla/csrc/convert_ops.cpp @@ -9,15 +9,14 @@ namespace torch_xla { namespace { -xla::XlaOp ExplicitBooleanConvert(const xla::XlaOp& op, - xla::PrimitiveType from) { +xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) { xla::XlaOp zero = xla::Zero(op.builder(), from); return xla::Ne(op, zero); } } // namespace -xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from, +xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType to, const Device* device) { if (device == nullptr) { device = GetDefaultDevice(); @@ -51,7 +50,7 @@ xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from, } } -xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from) { +xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) { const Device* device = GetDefaultDevice(); return from != xla::PrimitiveType::PRED ? op @@ -60,11 +59,11 @@ xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from) { device); } -xla::XlaOp ConvertToNumeric(const xla::XlaOp& op) { +xla::XlaOp ConvertToNumeric(xla::XlaOp op) { return ConvertToNumeric(op, XlaHelpers::TypeOfXlaOp(op)); } -xla::XlaOp CastToScalarType(const xla::XlaOp& input, +xla::XlaOp CastToScalarType(xla::XlaOp input, c10::optional dtype) { if (dtype) { return ConvertTo(input, XlaHelpers::TypeOfXlaOp(input), diff --git a/torch_xla/csrc/convert_ops.h b/torch_xla/csrc/convert_ops.h index 4bc8c49672ab..15e5e447a4af 100644 --- a/torch_xla/csrc/convert_ops.h +++ b/torch_xla/csrc/convert_ops.h @@ -9,16 +9,16 @@ namespace torch_xla { -xla::XlaOp ConvertTo(const xla::XlaOp& op, xla::PrimitiveType from, +xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from, xla::PrimitiveType to, const Device* device); -xla::XlaOp ConvertToNumeric(const xla::XlaOp& op, xla::PrimitiveType from); +xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from); -xla::XlaOp ConvertToNumeric(const xla::XlaOp& op); +xla::XlaOp ConvertToNumeric(xla::XlaOp op); // Cast the input to the given dtype. If dtype is null, no-op with the exception // of predicates, which are converted to 8-bit unsigned integers. -xla::XlaOp CastToScalarType(const xla::XlaOp& input, +xla::XlaOp CastToScalarType(xla::XlaOp input, c10::optional dtype); } // namespace torch_xla diff --git a/torch_xla/csrc/convolution.cpp b/torch_xla/csrc/convolution.cpp index df8e37c587ce..0d8d14788f46 100644 --- a/torch_xla/csrc/convolution.cpp +++ b/torch_xla/csrc/convolution.cpp @@ -190,8 +190,7 @@ std::vector> MakePadding( // Computes the input gradient for a convolution. 
xla::XlaOp BuildConvBackwardInput( - const xla::XlaOp& grad_output, const xla::XlaOp& kernel, - const xla::Shape& input_shape, + xla::XlaOp grad_output, xla::XlaOp kernel, const xla::Shape& input_shape, tensorflow::gtl::ArraySlice spatial_stride, tensorflow::gtl::ArraySlice spatial_padding, tensorflow::gtl::ArraySlice spatial_dilation, @@ -209,8 +208,7 @@ xla::XlaOp BuildConvBackwardInput( // Computes the kernel gradient for a convolution. xla::XlaOp BuildConvBackwardWeight( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::Shape& kernel_shape, + xla::XlaOp grad_output, xla::XlaOp input, const xla::Shape& kernel_shape, tensorflow::gtl::ArraySlice spatial_stride, tensorflow::gtl::ArraySlice spatial_padding, tensorflow::gtl::ArraySlice spatial_dilation, @@ -247,7 +245,7 @@ xla::XlaOp BuildGradBias(xla::XlaOp grad_output) { } xla::XlaOp BuildTransposedConvolution( - const xla::XlaOp& input, const xla::XlaOp& kernel, + xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, @@ -275,8 +273,7 @@ xla::XlaOp BuildTransposedConvolution( } ConvGrads BuildTransposedConvolutionBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& kernel, + xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, @@ -295,7 +292,7 @@ ConvGrads BuildTransposedConvolutionBackward( } // namespace xla::XlaOp BuildConvolutionOverrideable( - const xla::XlaOp& input, const xla::XlaOp& kernel, + xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, @@ -320,7 +317,7 @@ xla::XlaOp BuildConvolutionOverrideable( } xla::XlaOp BuildConvolutionOverrideableBias( - const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, + xla::XlaOp input, xla::XlaOp kernel, xla::XlaOp bias, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, @@ -340,8 +337,7 @@ xla::XlaOp BuildConvolutionOverrideableBias( } ConvGrads BuildConvolutionBackwardOverrideable( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& kernel, + xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, diff --git a/torch_xla/csrc/convolution.h b/torch_xla/csrc/convolution.h index 7220e63817fd..23b0db657c3d 100644 --- a/torch_xla/csrc/convolution.h +++ b/torch_xla/csrc/convolution.h @@ -8,7 +8,7 @@ namespace torch_xla { // Computes the convolution of the given input and kernel with the given // precision, with the given stride and padding. xla::XlaOp BuildConvolutionOverrideable( - const xla::XlaOp& input, const xla::XlaOp& kernel, + xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, @@ -17,7 +17,7 @@ xla::XlaOp BuildConvolutionOverrideable( // Same as above, then broadcasts the bias and adds it to the result. 
xla::XlaOp BuildConvolutionOverrideableBias( - const xla::XlaOp& input, const xla::XlaOp& kernel, const xla::XlaOp& bias, + xla::XlaOp input, xla::XlaOp kernel, xla::XlaOp bias, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, @@ -32,8 +32,7 @@ struct ConvGrads { // Computes the gradients for a convolution with the given stride and padding. ConvGrads BuildConvolutionBackwardOverrideable( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& kernel, + xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp kernel, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, tensorflow::gtl::ArraySlice dilation, bool transposed, diff --git a/torch_xla/csrc/cross_replica_reduces.cpp b/torch_xla/csrc/cross_replica_reduces.cpp index 2a9f6588db64..1c48cfda178d 100644 --- a/torch_xla/csrc/cross_replica_reduces.cpp +++ b/torch_xla/csrc/cross_replica_reduces.cpp @@ -56,9 +56,8 @@ xla::XlaComputation GetReduceComutation(AllReduceType reduce_type, std::vector BuildAllReduce( AllReduceType reduce_type, - tensorflow::gtl::ArraySlice operands, - const xla::XlaOp& token, double scale, - const std::vector>& groups) { + tensorflow::gtl::ArraySlice operands, xla::XlaOp token, + double scale, const std::vector>& groups) { std::vector reduce_groups; for (auto& group : groups) { xla::ReplicaGroup rgroup; diff --git a/torch_xla/csrc/cross_replica_reduces.h b/torch_xla/csrc/cross_replica_reduces.h index ad463fa2b244..9f7fd40382a6 100644 --- a/torch_xla/csrc/cross_replica_reduces.h +++ b/torch_xla/csrc/cross_replica_reduces.h @@ -18,8 +18,7 @@ enum class AllReduceType { std::vector BuildAllReduce( AllReduceType reduce_type, - tensorflow::gtl::ArraySlice operands, - const xla::XlaOp& token, double scale, - const std::vector>& groups); + tensorflow::gtl::ArraySlice operands, xla::XlaOp token, + double scale, const std::vector>& groups); } // namespace torch_xla diff --git a/torch_xla/csrc/data_ops.cpp b/torch_xla/csrc/data_ops.cpp index f20bb8ce7af1..9a0b10caa252 100644 --- a/torch_xla/csrc/data_ops.cpp +++ b/torch_xla/csrc/data_ops.cpp @@ -41,8 +41,7 @@ std::vector GetReflectionPad2dSpatialDims(xla::int64 rank) { } // namespace -bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index, - xla::int64 dim) { +bool IsSparseGather(xla::XlaOp input, xla::XlaOp index, xla::int64 dim) { return IsSparseGather(XlaHelpers::ShapeOfXlaOp(input), XlaHelpers::ShapeOfXlaOp(index), dim); } @@ -105,7 +104,7 @@ absl::optional GetDynamicReshapeInfo( } xla::XlaOp BuildView( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_sizes) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); const auto complete_output_sizes = @@ -118,7 +117,7 @@ xla::XlaOp BuildView( return xla::Reshape(input, complete_output_sizes); } -xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim) { +xla::XlaOp SqueezeTrivialDimension(xla::XlaOp input, size_t dim) { auto input_sizes = XlaHelpers::SizesOfXlaOp(input); XLA_CHECK_LT(dim, input_sizes.size()); if (input_sizes[dim] != 1) { @@ -128,7 +127,7 @@ xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim) { return xla::Reshape(input, input_sizes); } -xla::XlaOp SqueezeAllTrivialDimensions(const xla::XlaOp& input) { +xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input) { auto input_sizes = XlaHelpers::SizesOfXlaOp(input); // Squeeze the trivial (of size 1) dimensions. 
std::vector non_singleton_dimensions; @@ -139,7 +138,7 @@ xla::XlaOp SqueezeAllTrivialDimensions(const xla::XlaOp& input) { } xla::XlaOp BuildExpand( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_sizes) { auto input_sizes = XlaHelpers::SizesOfXlaOp(input); // Adjust the rank of the input to match the rank of the output. @@ -159,7 +158,7 @@ std::vector BuildUnsqueezeDimensions( return unsqueeze_dimensions; } -xla::XlaOp BuildUnsqueeze(const xla::XlaOp& input, size_t dim) { +xla::XlaOp BuildUnsqueeze(xla::XlaOp input, size_t dim) { auto dimensions = BuildUnsqueezeDimensions(XlaHelpers::SizesOfXlaOp(input), dim); return xla::Reshape(input, dimensions); @@ -184,7 +183,7 @@ xla::XlaOp BuildCat(tensorflow::gtl::ArraySlice inputs, return xla::ConcatInDim(inputs[0].builder(), inputs, dim); } -xla::XlaOp BuildRepeat(const xla::XlaOp& input, +xla::XlaOp BuildRepeat(xla::XlaOp input, tensorflow::gtl::ArraySlice repeats) { const auto input_sizes = XlaHelpers::SizesOfXlaOp(input); XLA_CHECK_GE(repeats.size(), input_sizes.size()) @@ -220,8 +219,8 @@ size_t ComputeSplitCount( } std::vector BuildSplit( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice split_sizes, xla::int64 dim) { + xla::XlaOp input, tensorflow::gtl::ArraySlice split_sizes, + xla::int64 dim) { const auto input_sizes = XlaHelpers::SizesOfXlaOp(input); xla::int64 dim_size = input_sizes.at(dim); xla::int64 index = 0; @@ -237,7 +236,7 @@ std::vector BuildSplit( } xla::XlaOp BuildUpdateSlice( - const xla::XlaOp& input, const xla::XlaOp& source, + xla::XlaOp input, xla::XlaOp source, tensorflow::gtl::ArraySlice base_indices) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); const xla::Shape& source_shape = XlaHelpers::ShapeOfXlaOp(source); @@ -257,7 +256,7 @@ xla::XlaOp BuildUpdateSlice( } xla::XlaOp BuildSlice( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice base_indices, tensorflow::gtl::ArraySlice sizes) { XLA_CHECK_EQ(base_indices.size(), sizes.size()); @@ -269,14 +268,14 @@ xla::XlaOp BuildSlice( return xla::Slice(input, base_indices, limit_indices, strides); } -xla::XlaOp BoundIndices(const xla::XlaOp& index, const xla::XlaOp& max_index) { +xla::XlaOp BoundIndices(xla::XlaOp index, xla::XlaOp max_index) { const xla::Shape& index_shape = XlaHelpers::ShapeOfXlaOp(index); return xla::Select( xla::Ge(index, xla::Zero(index.builder(), index_shape.element_type())), index, index + max_index); } -xla::XlaOp BuildTake(const xla::XlaOp& input, const xla::XlaOp& index) { +xla::XlaOp BuildTake(xla::XlaOp input, xla::XlaOp index) { static const int take_dim = 0; xla::Shape input_shape; xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape); @@ -292,7 +291,7 @@ xla::XlaOp BuildTake(const xla::XlaOp& input, const xla::XlaOp& index) { return xla::Reshape(r1_result, index_shape.dimensions()); } -xla::XlaOp BuildResize(const xla::XlaOp& input, +xla::XlaOp BuildResize(xla::XlaOp input, tensorflow::gtl::ArraySlice size) { xla::Shape input_shape; xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape); @@ -313,9 +312,8 @@ xla::XlaOp BuildResize(const xla::XlaOp& input, return xla::Reshape(resized_input, size); } -xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source, - xla::int64 dim, xla::int64 start, xla::int64 end, - xla::int64 stride) { +xla::XlaOp BuildUnselect(xla::XlaOp target, xla::XlaOp source, xla::int64 dim, + xla::int64 start, xla::int64 end, xla::int64 stride) { const xla::Shape& target_shape = 
XlaHelpers::ShapeOfXlaOp(target); const xla::Shape& source_shape = XlaHelpers::ShapeOfXlaOp(source); if (target_shape.dimensions(dim) == source_shape.dimensions(dim)) { @@ -358,8 +356,7 @@ xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source, } xla::XlaOp BuildReflectionPad2d( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice padding) { + xla::XlaOp input, tensorflow::gtl::ArraySlice padding) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); std::vector spatial_dims = GetReflectionPad2dSpatialDims(input_shape.rank()); @@ -384,7 +381,7 @@ xla::XlaOp BuildReflectionPad2d( } xla::XlaOp BuildReflectionPad2dBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, + xla::XlaOp grad_output, xla::XlaOp input, tensorflow::gtl::ArraySlice padding) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); @@ -423,7 +420,7 @@ xla::XlaOp BuildReflectionPad2dBackward( return grad; } -xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo, +xla::XlaOp PadInDim(xla::XlaOp input, xla::int64 dim, xla::int64 pad_lo, xla::int64 pad_hi, const xla::XlaOp* pad_value) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero; diff --git a/torch_xla/csrc/data_ops.h b/torch_xla/csrc/data_ops.h index 76b08228ad8e..e9aaa5f0f38e 100644 --- a/torch_xla/csrc/data_ops.h +++ b/torch_xla/csrc/data_ops.h @@ -15,8 +15,7 @@ struct DynamicReshapeInfo { xla::int64 dynamic_dimension = -1; }; -bool IsSparseGather(const xla::XlaOp& input, const xla::XlaOp& index, - xla::int64 dim); +bool IsSparseGather(xla::XlaOp input, xla::XlaOp index, xla::int64 dim); // For input_sizes and a potentially incomplete output_sizes, return a complete // output shape. The complete output shape has same total number of elements as @@ -33,27 +32,27 @@ absl::optional GetDynamicReshapeInfo( // Creates a new tensor with the same data as the input tensor and the specified // output size. xla::XlaOp BuildView( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_sizes); // Squeezes the given dimension if trivial (size 1), returns the unchanged input // otherwise. -xla::XlaOp SqueezeTrivialDimension(const xla::XlaOp& input, size_t dim); +xla::XlaOp SqueezeTrivialDimension(xla::XlaOp input, size_t dim); // Squeezes out the trivial (size 1) dimensions of the input. -xla::XlaOp SqueezeAllTrivialDimensions(const xla::XlaOp& input); +xla::XlaOp SqueezeAllTrivialDimensions(xla::XlaOp input); // Creates a new tensor with the singleton dimensions expanded to the specified // output sizes. xla::XlaOp BuildExpand( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_sizes); std::vector BuildUnsqueezeDimensions( tensorflow::gtl::ArraySlice dimensions, size_t dim); // Insert a dimension of size one at the specified position. -xla::XlaOp BuildUnsqueeze(const xla::XlaOp& input, size_t dim); +xla::XlaOp BuildUnsqueeze(xla::XlaOp input, size_t dim); // Concatenates a list of tensors along a new dimension dim. xla::XlaOp BuildStack(tensorflow::gtl::ArraySlice inputs, @@ -65,7 +64,7 @@ xla::XlaOp BuildCat(tensorflow::gtl::ArraySlice inputs, xla::int64 dim); // Repeats the input tensor along each dimension by the given number of repeats. 
-xla::XlaOp BuildRepeat(const xla::XlaOp& input, +xla::XlaOp BuildRepeat(xla::XlaOp input, tensorflow::gtl::ArraySlice repeats); // Computes the number of splits with a dimension size and the split sizes. @@ -76,40 +75,38 @@ size_t ComputeSplitCount( // Splits a tensor into parts whose size is passed in split_sizes, along the dim // dimension. std::vector BuildSplit( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice split_sizes, xla::int64 dim); + xla::XlaOp input, tensorflow::gtl::ArraySlice split_sizes, + xla::int64 dim); // Creates an updated version of input, where, starting at base_indices, source // if overlapped with input. xla::XlaOp BuildUpdateSlice( - const xla::XlaOp& input, const xla::XlaOp& source, + xla::XlaOp input, xla::XlaOp source, tensorflow::gtl::ArraySlice base_indices); xla::XlaOp BuildSlice( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice base_indices, tensorflow::gtl::ArraySlice sizes); -xla::XlaOp BoundIndices(const xla::XlaOp& index, const xla::XlaOp& max_index); +xla::XlaOp BoundIndices(xla::XlaOp index, xla::XlaOp max_index); -xla::XlaOp BuildTake(const xla::XlaOp& input, const xla::XlaOp& index); +xla::XlaOp BuildTake(xla::XlaOp input, xla::XlaOp index); -xla::XlaOp BuildResize(const xla::XlaOp& input, +xla::XlaOp BuildResize(xla::XlaOp input, tensorflow::gtl::ArraySlice size); -xla::XlaOp BuildUnselect(const xla::XlaOp& target, const xla::XlaOp& source, - xla::int64 dim, xla::int64 start, xla::int64 end, - xla::int64 stride); +xla::XlaOp BuildUnselect(xla::XlaOp target, xla::XlaOp source, xla::int64 dim, + xla::int64 start, xla::int64 end, xla::int64 stride); xla::XlaOp BuildReflectionPad2d( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice padding); + xla::XlaOp input, tensorflow::gtl::ArraySlice padding); xla::XlaOp BuildReflectionPad2dBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, + xla::XlaOp grad_output, xla::XlaOp input, tensorflow::gtl::ArraySlice padding); -xla::XlaOp PadInDim(const xla::XlaOp& input, xla::int64 dim, xla::int64 pad_lo, +xla::XlaOp PadInDim(xla::XlaOp input, xla::int64 dim, xla::int64 pad_lo, xla::int64 pad_hi, const xla::XlaOp* pad_value = nullptr); } // namespace torch_xla diff --git a/torch_xla/csrc/elementwise.cpp b/torch_xla/csrc/elementwise.cpp index f1b3e59f4803..5c1ad7d16f4a 100644 --- a/torch_xla/csrc/elementwise.cpp +++ b/torch_xla/csrc/elementwise.cpp @@ -7,8 +7,7 @@ namespace torch_xla { namespace { -xla::XlaOp Between(const xla::XlaOp& input, at::Scalar min_val, - at::Scalar max_val) { +xla::XlaOp Between(xla::XlaOp input, at::Scalar min_val, at::Scalar max_val) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::PrimitiveType element_type = shape.element_type(); xla::XlaBuilder* builder = input.builder(); @@ -23,8 +22,8 @@ xla::XlaOp Between(const xla::XlaOp& input, at::Scalar min_val, } // namespace -xla::XlaOp BuildComparisonOp(c10::Symbol kind, const xla::XlaOp& input, - const xla::XlaOp& other) { +xla::XlaOp BuildComparisonOp(c10::Symbol kind, xla::XlaOp input, + xla::XlaOp other) { std::pair ops = XlaHelpers::Promote(input, other); switch (kind) { case at::aten::ne: @@ -45,7 +44,7 @@ xla::XlaOp BuildComparisonOp(c10::Symbol kind, const xla::XlaOp& input, } } -xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output, +xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, const float threshold, const float value) { xla::XlaBuilder* builder = input.builder(); const xla::Shape& input_shape = 
XlaHelpers::ShapeOfXlaOp(input); @@ -58,20 +57,20 @@ xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output, xla::Broadcast(xla_value, input_shape.dimensions())); } -xla::XlaOp BuildRelu(const xla::XlaOp& input) { +xla::XlaOp BuildRelu(xla::XlaOp input) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); return xla::Max(input, XlaHelpers::ScalarValue( 0, input_shape.element_type(), input.builder())); } -xla::XlaOp BuildHardshrink(const xla::XlaOp& input, at::Scalar lambda) { +xla::XlaOp BuildHardshrink(xla::XlaOp input, at::Scalar lambda) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); return xla::Select(Between(input, -lambda, lambda), XlaHelpers::ScalarBroadcast(0, shape, input.builder()), input); } -xla::XlaOp BuildSoftshrink(const xla::XlaOp& input, at::Scalar lambda) { +xla::XlaOp BuildSoftshrink(xla::XlaOp input, at::Scalar lambda) { xla::XlaBuilder* builder = input.builder(); const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = XlaHelpers::ScalarBroadcast(0, shape, builder); @@ -83,17 +82,16 @@ xla::XlaOp BuildSoftshrink(const xla::XlaOp& input, at::Scalar lambda) { input - xla_lambd); } -xla::XlaOp BuildShrinkBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, at::Scalar lambda) { +xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input, + at::Scalar lambda) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); return xla::Select(Between(input, -lambda, lambda), XlaHelpers::ScalarBroadcast(0, shape, input.builder()), grad_output); } -xla::XlaOp BuildHardtanhBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, at::Scalar min_val, - at::Scalar max_val) { +xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input, + at::Scalar min_val, at::Scalar max_val) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(grad_output); xla::XlaOp zero = xla::Broadcast( XlaHelpers::ScalarValue(0, shape.element_type(), grad_output.builder()), @@ -101,12 +99,11 @@ xla::XlaOp BuildHardtanhBackward(const xla::XlaOp& grad_output, return xla::Select(Between(input, min_val, max_val), grad_output, zero); } -xla::XlaOp BuildLeakyRelu(const xla::XlaOp& input, - double negative_slope_value) { +xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope_value) { return BuildLeakyReluBackward(input, input, negative_slope_value); } -std::vector BuildRrelu(const xla::XlaOp& input, at::Scalar lower, +std::vector BuildRrelu(xla::XlaOp input, at::Scalar lower, at::Scalar upper, bool training) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = @@ -131,10 +128,9 @@ std::vector BuildRrelu(const xla::XlaOp& input, at::Scalar lower, return {output, noise}; } -xla::XlaOp BuildRreluBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, const xla::XlaOp& noise, - at::Scalar lower, at::Scalar upper, - bool training) { +xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp noise, at::Scalar lower, + at::Scalar upper, bool training) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = XlaHelpers::ScalarValue(0, input_shape.element_type(), input.builder()); @@ -151,8 +147,7 @@ xla::XlaOp BuildRreluBackward(const xla::XlaOp& grad_output, return grad_input; } -xla::XlaOp BuildLeakyReluBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, +xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input, double negative_slope_value) { const xla::Shape& input_shape 
= XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = XlaHelpers::ScalarValue( @@ -163,21 +158,21 @@ xla::XlaOp BuildLeakyReluBackward(const xla::XlaOp& grad_output, negative_slope * grad_output); } -xla::XlaOp BuildSigmoid(const xla::XlaOp& input) { +xla::XlaOp BuildSigmoid(xla::XlaOp input) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp half = XlaHelpers::ScalarValue(0.5, shape.element_type(), input.builder()); return half + half * xla::Tanh(half * input); } -xla::XlaOp BuildReciprocal(const xla::XlaOp& input) { +xla::XlaOp BuildReciprocal(xla::XlaOp input) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp one = XlaHelpers::ScalarValue(1., shape.element_type(), input.builder()); return xla::Div(one, input); } -xla::XlaOp BuildSign(const xla::XlaOp& input) { +xla::XlaOp BuildSign(xla::XlaOp input) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp zero = XlaHelpers::ScalarValue(0., shape.element_type(), input.builder()); @@ -189,7 +184,7 @@ xla::XlaOp BuildSign(const xla::XlaOp& input) { xla::Broadcast(zero, shape.dimensions()), sign); } -xla::XlaOp BuildAbs(const xla::XlaOp& input) { +xla::XlaOp BuildAbs(xla::XlaOp input) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); if (xla::primitive_util::IsUnsignedIntegralType(shape.element_type())) { return input; diff --git a/torch_xla/csrc/elementwise.h b/torch_xla/csrc/elementwise.h index c5c0c3dd3d3d..3e88fa659ab7 100644 --- a/torch_xla/csrc/elementwise.h +++ b/torch_xla/csrc/elementwise.h @@ -8,55 +8,52 @@ namespace torch_xla { // Computes binary comparison operations. -xla::XlaOp BuildComparisonOp(c10::Symbol kind, const xla::XlaOp& input, - const xla::XlaOp& other); +xla::XlaOp BuildComparisonOp(c10::Symbol kind, xla::XlaOp input, + xla::XlaOp other); // Computes the elementwise threshold of the input: if the value is below the // threshold, replace it with the provided value, otherwise leave it unchanged. -xla::XlaOp BuildThreshold(const xla::XlaOp& input, const xla::XlaOp& output, +xla::XlaOp BuildThreshold(xla::XlaOp input, xla::XlaOp output, const float threshold, const float value); // Computes the rectified linear unit (replace negative elements with 0). 
-xla::XlaOp BuildRelu(const xla::XlaOp& input); +xla::XlaOp BuildRelu(xla::XlaOp input); -std::vector BuildRrelu(const xla::XlaOp& input, at::Scalar lower, +std::vector BuildRrelu(xla::XlaOp input, at::Scalar lower, at::Scalar upper, bool training); -xla::XlaOp BuildRreluBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, const xla::XlaOp& noise, - at::Scalar lower, at::Scalar upper, - bool training); +xla::XlaOp BuildRreluBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp noise, at::Scalar lower, + at::Scalar upper, bool training); -xla::XlaOp BuildHardshrink(const xla::XlaOp& input, at::Scalar lambda); -xla::XlaOp BuildSoftshrink(const xla::XlaOp& input, at::Scalar lambda); -xla::XlaOp BuildShrinkBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, at::Scalar lambda); +xla::XlaOp BuildHardshrink(xla::XlaOp input, at::Scalar lambda); +xla::XlaOp BuildSoftshrink(xla::XlaOp input, at::Scalar lambda); +xla::XlaOp BuildShrinkBackward(xla::XlaOp grad_output, xla::XlaOp input, + at::Scalar lambda); -xla::XlaOp BuildHardtanhBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, at::Scalar min_val, - at::Scalar max_val); +xla::XlaOp BuildHardtanhBackward(xla::XlaOp grad_output, xla::XlaOp input, + at::Scalar min_val, at::Scalar max_val); // Computes the leaky rectified linear unit: // LeakyReLU(x) = max(0, input) + negative_slope ∗ min(0, input). -xla::XlaOp BuildLeakyRelu(const xla::XlaOp& input, double negative_slope); +xla::XlaOp BuildLeakyRelu(xla::XlaOp input, double negative_slope); -xla::XlaOp BuildLeakyReluBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, +xla::XlaOp BuildLeakyReluBackward(xla::XlaOp grad_output, xla::XlaOp input, double negative_slope_value); // Computes the sigmoid function using Tanh // Sigmoid(x) = (tanh(x ∗ 0.5) + 1) ∗ 0.5 -xla::XlaOp BuildSigmoid(const xla::XlaOp& input); +xla::XlaOp BuildSigmoid(xla::XlaOp input); // Computes the reciprocal function. // Reciprocal(x) = 1 / x -xla::XlaOp BuildReciprocal(const xla::XlaOp& input); +xla::XlaOp BuildReciprocal(xla::XlaOp input); // Computes the sign of the input. // If x is NaN then 0, otherwise the actual sign -xla::XlaOp BuildSign(const xla::XlaOp& input); +xla::XlaOp BuildSign(xla::XlaOp input); // Computes the absolute value of the input. 
-xla::XlaOp BuildAbs(const xla::XlaOp& input); +xla::XlaOp BuildAbs(xla::XlaOp input); } // namespace torch_xla diff --git a/torch_xla/csrc/helpers.cpp b/torch_xla/csrc/helpers.cpp index 20020b3c86b0..2710098cb545 100644 --- a/torch_xla/csrc/helpers.cpp +++ b/torch_xla/csrc/helpers.cpp @@ -15,8 +15,8 @@ namespace torch_xla { namespace { -xla::XlaOp ConvertBinaryOpResult(const xla::XlaOp& op1, const xla::XlaOp& op2, - const xla::XlaOp& result) { +xla::XlaOp ConvertBinaryOpResult(xla::XlaOp op1, xla::XlaOp op2, + xla::XlaOp result) { xla::PrimitiveType type1 = XlaHelpers::TypeOfXlaOp(op1); xla::PrimitiveType type2 = XlaHelpers::TypeOfXlaOp(op2); xla::PrimitiveType result_type = XlaHelpers::TypeOfXlaOp(result); @@ -42,7 +42,7 @@ xla::PrecisionConfig XlaHelpers::BuildPrecisionConfig( xla::XlaComputation CreateComputation( const std::string& name, xla::PrimitiveType type, - const std::function& op) { + const std::function& op) { xla::XlaBuilder builder(name); xla::XlaOp x = xla::Parameter(&builder, 0, xla::ShapeUtil::MakeShape(type, {}), "x"); @@ -203,7 +203,7 @@ xla::PaddingConfig XlaHelpers::MakeXlaPaddingConfigFromNdPadding( xla::XlaComputation XlaHelpers::CreateAddComputation(xla::PrimitiveType type) { return CreateComputation( - "AddComputation", type, [&](const xla::XlaOp& x, const xla::XlaOp& y) { + "AddComputation", type, [&](xla::XlaOp x, xla::XlaOp y) { return type == xla::PrimitiveType::PRED ? xla::Or(x, y) : xla::Add(x, y); }); @@ -212,50 +212,49 @@ xla::XlaComputation XlaHelpers::CreateAddComputation(xla::PrimitiveType type) { xla::XlaComputation XlaHelpers::CreateMulComputation(xla::PrimitiveType type) { return CreateComputation( "MulComputation", type, - [&](const xla::XlaOp& x, const xla::XlaOp& y) { return xla::Mul(x, y); }); + [&](xla::XlaOp x, xla::XlaOp y) { return xla::Mul(x, y); }); } xla::XlaComputation XlaHelpers::CreateMaxComputation(xla::PrimitiveType type) { return CreateComputation( "MaxComputation", type, - [&](const xla::XlaOp& x, const xla::XlaOp& y) { return xla::Max(x, y); }); + [&](xla::XlaOp x, xla::XlaOp y) { return xla::Max(x, y); }); } xla::XlaComputation XlaHelpers::CreateMinComputation(xla::PrimitiveType type) { return CreateComputation( "MinComputation", type, - [&](const xla::XlaOp& x, const xla::XlaOp& y) { return xla::Min(x, y); }); + [&](xla::XlaOp x, xla::XlaOp y) { return xla::Min(x, y); }); } xla::XlaComputation XlaHelpers::CreateAndComputation(xla::PrimitiveType type) { return CreateComputation( "AndComputation", type, - [&](const xla::XlaOp& x, const xla::XlaOp& y) { return xla::And(x, y); }); + [&](xla::XlaOp x, xla::XlaOp y) { return xla::And(x, y); }); } xla::XlaComputation XlaHelpers::CreateOrComputation(xla::PrimitiveType type) { return CreateComputation( "OrComputation", type, - [&](const xla::XlaOp& x, const xla::XlaOp& y) { return xla::Or(x, y); }); + [&](xla::XlaOp x, xla::XlaOp y) { return xla::Or(x, y); }); } -const xla::Shape& XlaHelpers::ShapeOfXlaOp(const xla::XlaOp& op) { +const xla::Shape& XlaHelpers::ShapeOfXlaOp(xla::XlaOp op) { const xla::Shape* shape = ConsumeValue(op.builder()->GetShapePtr(op)); return *shape; } -std::vector XlaHelpers::SizesOfXlaOp(const xla::XlaOp& op) { +std::vector XlaHelpers::SizesOfXlaOp(xla::XlaOp op) { const xla::Shape& op_shape = ShapeOfXlaOp(op); return std::vector(op_shape.dimensions().begin(), op_shape.dimensions().end()); } -xla::PrimitiveType XlaHelpers::TypeOfXlaOp(const xla::XlaOp& op) { +xla::PrimitiveType XlaHelpers::TypeOfXlaOp(xla::XlaOp op) { return ShapeOfXlaOp(op).element_type(); } 
-xla::XlaOp XlaHelpers::ReshapeToRank(const xla::XlaOp& input, - xla::int64 expected_rank, +xla::XlaOp XlaHelpers::ReshapeToRank(xla::XlaOp input, xla::int64 expected_rank, xla::int64 offset) { const xla::Shape& shape = ShapeOfXlaOp(input); XLA_CHECK_LE(offset + shape.rank(), expected_rank); @@ -269,8 +268,7 @@ xla::XlaOp XlaHelpers::ReshapeToRank(const xla::XlaOp& input, return xla::Reshape(input, dimensions); } -xla::XlaOp XlaHelpers::Flatten(const xla::XlaOp& input, - xla::Shape* input_shape) { +xla::XlaOp XlaHelpers::Flatten(xla::XlaOp input, xla::Shape* input_shape) { xla::util::MaybePtr input_shape_tmp(input_shape); *input_shape_tmp = ShapeOfXlaOp(input); if (input_shape_tmp->rank() == 1) { @@ -290,8 +288,7 @@ std::vector XlaHelpers::MakeTransposePermutation(xla::int64 dim0, return permute_dims; } -xla::XlaOp XlaHelpers::LinearInterpolation(const xla::XlaOp& value0, - const xla::XlaOp& value1, +xla::XlaOp XlaHelpers::LinearInterpolation(xla::XlaOp value0, xla::XlaOp value1, double alpha) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(value0); xla::XlaOp one = ScalarValue(1.0, shape.element_type(), value0.builder()); @@ -300,8 +297,8 @@ xla::XlaOp XlaHelpers::LinearInterpolation(const xla::XlaOp& value0, return value0 * alpha_value + value1 * (one - alpha_value); } -std::pair XlaHelpers::PromoteValues( - const xla::XlaOp& op1, const xla::XlaOp& op2) { +std::pair XlaHelpers::PromoteValues(xla::XlaOp op1, + xla::XlaOp op2) { xla::PrimitiveType type1 = TypeOfXlaOp(op1); xla::PrimitiveType type2 = TypeOfXlaOp(op2); if (type1 == type2) { @@ -343,7 +340,7 @@ std::pair XlaHelpers::PromoteValues( } std::pair XlaHelpers::PromoteSecondValue( - const xla::XlaOp& op1, const xla::XlaOp& op2) { + xla::XlaOp op1, xla::XlaOp op2) { xla::PrimitiveType type1 = TypeOfXlaOp(op1); xla::PrimitiveType type2 = TypeOfXlaOp(op2); return type1 == type2 @@ -395,8 +392,8 @@ xla::Shape XlaHelpers::GetPromotedShape(const xla::Shape& shape1, GetPromotedShape(shape1.dimensions(), shape2.dimensions())); } -std::pair XlaHelpers::PromoteShapes( - const xla::XlaOp& op1, const xla::XlaOp& op2) { +std::pair XlaHelpers::PromoteShapes(xla::XlaOp op1, + xla::XlaOp op2) { const xla::Shape& shape1 = ShapeOfXlaOp(op1); const xla::Shape& shape2 = ShapeOfXlaOp(op2); if (xla::ShapeUtil::Compatible(shape1, shape2)) { @@ -412,19 +409,19 @@ std::pair XlaHelpers::PromoteShapes( ImplicitBroadcast(op2, shape2, shape)); } -std::pair XlaHelpers::Promote(const xla::XlaOp& op1, - const xla::XlaOp& op2) { +std::pair XlaHelpers::Promote(xla::XlaOp op1, + xla::XlaOp op2) { std::pair vops = PromoteValues(op1, op2); return PromoteShapes(vops.first, vops.second); } -std::pair XlaHelpers::PromoteSecond( - const xla::XlaOp& op1, const xla::XlaOp& op2) { +std::pair XlaHelpers::PromoteSecond(xla::XlaOp op1, + xla::XlaOp op2) { std::pair vops = PromoteSecondValue(op1, op2); return PromoteShapes(vops.first, vops.second); } -xla::XlaOp XlaHelpers::ImplicitBroadcast(const xla::XlaOp& op, +xla::XlaOp XlaHelpers::ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape) { const auto& op_shape_dims = op_shape.dimensions(); @@ -463,9 +460,8 @@ xla::XlaOp XlaHelpers::ImplicitBroadcast(const xla::XlaOp& op, } xla::XlaOp XlaHelpers::PromotedBinaryOp( - const xla::XlaOp& op1, const xla::XlaOp& op2, - const std::function& - bin_op) { + xla::XlaOp op1, xla::XlaOp op2, + const std::function& bin_op) { xla::XlaOp numeric_op1 = ConvertToNumeric(op1); xla::XlaOp numeric_op2 = ConvertToNumeric(op2); std::pair vops = diff --git 
a/torch_xla/csrc/helpers.h b/torch_xla/csrc/helpers.h index 281f1fcbcf11..b0abd52b29de 100644 --- a/torch_xla/csrc/helpers.h +++ b/torch_xla/csrc/helpers.h @@ -77,17 +77,17 @@ class XlaHelpers { // Performa a linear interpolation between value0 and value1, by calculating: // result = value0 * alpha + value1 * (1 - alpha) - static xla::XlaOp LinearInterpolation(const xla::XlaOp& value0, - const xla::XlaOp& value1, double alpha); + static xla::XlaOp LinearInterpolation(xla::XlaOp value0, xla::XlaOp value1, + double alpha); // Returns the shape of the given XLA operation. - static const xla::Shape& ShapeOfXlaOp(const xla::XlaOp& op); + static const xla::Shape& ShapeOfXlaOp(xla::XlaOp op); // Returns the list of dimension sizes for the given XLA operation. - static std::vector SizesOfXlaOp(const xla::XlaOp& op); + static std::vector SizesOfXlaOp(xla::XlaOp op); // Returns the value type of given XLA operation. - static xla::PrimitiveType TypeOfXlaOp(const xla::XlaOp& op); + static xla::PrimitiveType TypeOfXlaOp(xla::XlaOp op); static std::vector GetAllDimensions(size_t rank) { return xla::util::Iota(rank); @@ -183,11 +183,10 @@ class XlaHelpers { // appending 1s to the major dimension. If offset is greater than zero, 1s // will be prepened to the minor dimension as well. // Expected condition: rank(input) + offset <= expected_rank - static xla::XlaOp ReshapeToRank(const xla::XlaOp& input, - xla::int64 expected_rank, + static xla::XlaOp ReshapeToRank(xla::XlaOp input, xla::int64 expected_rank, xla::int64 offset = 0); - static xla::XlaOp Flatten(const xla::XlaOp& input, + static xla::XlaOp Flatten(xla::XlaOp input, xla::Shape* input_shape = nullptr); // Gathers the input using the order specified by the permutation. For each i, @@ -213,30 +212,30 @@ class XlaHelpers { xla::int64 rank); // Performs type promotion to make sure both operations return the same type. - static std::pair PromoteValues(const xla::XlaOp& op1, - const xla::XlaOp& op2); + static std::pair PromoteValues(xla::XlaOp op1, + xla::XlaOp op2); // Performs type promotion, by casting the second operation to the type of the // first, if different. - static std::pair PromoteSecondValue( - const xla::XlaOp& op1, const xla::XlaOp& op2); + static std::pair PromoteSecondValue(xla::XlaOp op1, + xla::XlaOp op2); // Eventually performs a broadcast to make sure the shapes of the returned // xla::XlaOp values have the same shape. The first returned xla::XlaOp is op1 // or a broadcast of it, and the second returned xla::XlaOp is either op2 or a // broadcast ot it. - static std::pair PromoteShapes(const xla::XlaOp& op1, - const xla::XlaOp& op2); + static std::pair PromoteShapes(xla::XlaOp op1, + xla::XlaOp op2); // Combines PromoteValues() and PromoteShapes() returning two operations which // match in shape and types. - static std::pair Promote(const xla::XlaOp& op1, - const xla::XlaOp& op2); + static std::pair Promote(xla::XlaOp op1, + xla::XlaOp op2); // Combines PromoteSecondValue() and PromoteShapes() returning two operations // which match in shape and types. - static std::pair PromoteSecond(const xla::XlaOp& op1, - const xla::XlaOp& op2); + static std::pair PromoteSecond(xla::XlaOp op1, + xla::XlaOp op2); // Calculates the protomoted shape to which the input shapes should be // broadcasted for an elementwise operation. The size of the common dimensions @@ -258,40 +257,34 @@ class XlaHelpers { // one that op is broadcast-able to (usually the result of a // GetPromotedShape() call). If op_shape matches shape, the op itself is // returned. 
- static xla::XlaOp ImplicitBroadcast(const xla::XlaOp& op, - const xla::Shape& op_shape, + static xla::XlaOp ImplicitBroadcast(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& shape); // Performs the bin_op binary operation by promoting types and shapes of the // two input operands. static xla::XlaOp PromotedBinaryOp( - const xla::XlaOp& op1, const xla::XlaOp& op2, - const std::function& - bin_op); + xla::XlaOp op1, xla::XlaOp op2, + const std::function& bin_op); // Basic promoted binary operation implementation follow. - static xla::XlaOp PromotedAdd(const xla::XlaOp& op1, const xla::XlaOp& op2) { + static xla::XlaOp PromotedAdd(xla::XlaOp op1, xla::XlaOp op2) { return PromotedBinaryOp( - op1, op2, - [](const xla::XlaOp& op1, const xla::XlaOp& op2) { return op1 + op2; }); + op1, op2, [](xla::XlaOp op1, xla::XlaOp op2) { return op1 + op2; }); } - static xla::XlaOp PromotedSub(const xla::XlaOp& op1, const xla::XlaOp& op2) { + static xla::XlaOp PromotedSub(xla::XlaOp op1, xla::XlaOp op2) { return PromotedBinaryOp( - op1, op2, - [](const xla::XlaOp& op1, const xla::XlaOp& op2) { return op1 - op2; }); + op1, op2, [](xla::XlaOp op1, xla::XlaOp op2) { return op1 - op2; }); } - static xla::XlaOp PromotedMul(const xla::XlaOp& op1, const xla::XlaOp& op2) { + static xla::XlaOp PromotedMul(xla::XlaOp op1, xla::XlaOp op2) { return PromotedBinaryOp( - op1, op2, - [](const xla::XlaOp& op1, const xla::XlaOp& op2) { return op1 * op2; }); + op1, op2, [](xla::XlaOp op1, xla::XlaOp op2) { return op1 * op2; }); } - static xla::XlaOp PromotedDiv(const xla::XlaOp& op1, const xla::XlaOp& op2) { + static xla::XlaOp PromotedDiv(xla::XlaOp op1, xla::XlaOp op2) { return PromotedBinaryOp( - op1, op2, - [](const xla::XlaOp& op1, const xla::XlaOp& op2) { return op1 / op2; }); + op1, op2, [](xla::XlaOp op1, xla::XlaOp op2) { return op1 / op2; }); } template diff --git a/torch_xla/csrc/lowering_context.cpp b/torch_xla/csrc/lowering_context.cpp index a64277210acf..877556c91671 100644 --- a/torch_xla/csrc/lowering_context.cpp +++ b/torch_xla/csrc/lowering_context.cpp @@ -86,8 +86,7 @@ xla::StatusOr LoweringContext::Build() { return builder()->Build(); } -xla::StatusOr LoweringContext::Build( - const xla::XlaOp& root) { +xla::StatusOr LoweringContext::Build(xla::XlaOp root) { XLA_CHECK(root_tuple_.empty()); return builder()->Build(root); } diff --git a/torch_xla/csrc/lowering_context.h b/torch_xla/csrc/lowering_context.h index b0b8ba2badaa..3a1010edb05f 100644 --- a/torch_xla/csrc/lowering_context.h +++ b/torch_xla/csrc/lowering_context.h @@ -57,7 +57,7 @@ class LoweringContext { // embedded XLA builder (returned by the builder() API). // Uses root as return value forthe computation. It is an error to use this // API after having called the AddResult() API. - xla::StatusOr Build(const xla::XlaOp& root); + xla::StatusOr Build(xla::XlaOp root); // Lowers a single IR node. All the inputs to the node must have a lowering // before calling this API. Returns the generated XLA operations. 
diff --git a/torch_xla/csrc/matrix.cpp b/torch_xla/csrc/matrix.cpp index 96dfff32c80c..c9fa25850de8 100644 --- a/torch_xla/csrc/matrix.cpp +++ b/torch_xla/csrc/matrix.cpp @@ -43,7 +43,7 @@ xla::PaddingConfig CreateDiagonalPaddingConfig(const xla::Shape& target_shape, return padding_config; } -DiagonalMask CreateDiagonalMask(const xla::XlaOp& input, +DiagonalMask CreateDiagonalMask(xla::XlaOp input, const xla::Shape& target_shape, xla::int64 offset) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); @@ -80,18 +80,18 @@ std::vector GetDiagonalPermutation(xla::int64 rank, xla::int64 dim1, } // namespace -xla::XlaOp BuildTriu(const xla::XlaOp& input, xla::int64 diagonal) { +xla::XlaOp BuildTriu(xla::XlaOp input, xla::int64 diagonal) { return xla::Select(xla::TriangleMask(input, diagonal - 1), xla::ZerosLike(input), input); } -xla::XlaOp BuildTril(const xla::XlaOp& input, xla::int64 diagonal) { +xla::XlaOp BuildTril(xla::XlaOp input, xla::int64 diagonal) { return xla::Select(xla::TriangleMask(input, diagonal), input, xla::ZerosLike(input)); } -xla::XlaOp BuildDiagonal(const xla::XlaOp& input, xla::int64 offset, - xla::int64 dim1, xla::int64 dim2) { +xla::XlaOp BuildDiagonal(xla::XlaOp input, xla::int64 offset, xla::int64 dim1, + xla::int64 dim2) { xla::XlaOp diag_input = input; if (dim1 != 0 || dim2 != 1) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); @@ -101,9 +101,9 @@ xla::XlaOp BuildDiagonal(const xla::XlaOp& input, xla::int64 offset, return xla::GetMatrixDiagonal(diag_input, offset); } -xla::XlaOp BuildDiagonalViewUpdate(const xla::XlaOp& target, - const xla::XlaOp& input, xla::int64 offset, - xla::int64 dim1, xla::int64 dim2) { +xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, + xla::int64 offset, xla::int64 dim1, + xla::int64 dim2) { const xla::Shape* target_shape = &XlaHelpers::ShapeOfXlaOp(target); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp diag_input = input; diff --git a/torch_xla/csrc/matrix.h b/torch_xla/csrc/matrix.h index be50b58f806d..537c71c04971 100644 --- a/torch_xla/csrc/matrix.h +++ b/torch_xla/csrc/matrix.h @@ -4,15 +4,15 @@ namespace torch_xla { -xla::XlaOp BuildTriu(const xla::XlaOp& input, xla::int64 diagonal); +xla::XlaOp BuildTriu(xla::XlaOp input, xla::int64 diagonal); -xla::XlaOp BuildTril(const xla::XlaOp& input, xla::int64 diagonal); +xla::XlaOp BuildTril(xla::XlaOp input, xla::int64 diagonal); -xla::XlaOp BuildDiagonal(const xla::XlaOp& input, xla::int64 offset, - xla::int64 dim1, xla::int64 dim2); +xla::XlaOp BuildDiagonal(xla::XlaOp input, xla::int64 offset, xla::int64 dim1, + xla::int64 dim2); -xla::XlaOp BuildDiagonalViewUpdate(const xla::XlaOp& target, - const xla::XlaOp& input, xla::int64 offset, - xla::int64 dim1, xla::int64 dim2); +xla::XlaOp BuildDiagonalViewUpdate(xla::XlaOp target, xla::XlaOp input, + xla::int64 offset, xla::int64 dim1, + xla::int64 dim2); } // namespace torch_xla diff --git a/torch_xla/csrc/nll_loss.cpp b/torch_xla/csrc/nll_loss.cpp index 4398d3de19b6..24337e57005b 100644 --- a/torch_xla/csrc/nll_loss.cpp +++ b/torch_xla/csrc/nll_loss.cpp @@ -45,8 +45,8 @@ xla::XlaOp OneHotIota(xla::XlaBuilder* builder, xla::int64 depth, int axis, // positions, respectively. If "ignore_index" is a valid class, it'll be // considered off. 
xla::XlaOp LabelsToOneHot(xla::XlaBuilder* builder, xla::int64 depth, int axis, - const xla::XlaOp& indices, const xla::XlaOp& on_value, - const xla::XlaOp& off_value, int ignore_index) { + xla::XlaOp indices, xla::XlaOp on_value, + xla::XlaOp off_value, int ignore_index) { const xla::Shape& indices_shape = XlaHelpers::ShapeOfXlaOp(indices); // Expand the labels with a depth dimension for the classes. @@ -71,10 +71,8 @@ xla::XlaOp LabelsToOneHot(xla::XlaBuilder* builder, xla::int64 depth, int axis, } WeightScale GetMaskedWeight(const absl::optional& weight, - const xla::Shape& logits_shape, - const xla::XlaOp& labels, - const xla::XlaOp& one_hot_labels, - int ignore_index) { + const xla::Shape& logits_shape, xla::XlaOp labels, + xla::XlaOp one_hot_labels, int ignore_index) { const xla::Shape& labels_shape = XlaHelpers::ShapeOfXlaOp(labels); xla::XlaOp valid_bitmap = xla::Ne( labels, XlaHelpers::ScalarValue( @@ -105,7 +103,7 @@ WeightScale GetMaskedWeight(const absl::optional& weight, } // namespace // Builds the NLLLoss for log-probabilities "logits" and class indices "labels". -xla::XlaOp BuildNllLoss(const xla::XlaOp& logits, const xla::XlaOp& labels, +xla::XlaOp BuildNllLoss(xla::XlaOp logits, xla::XlaOp labels, const absl::optional& weight, int ignore_index, ReductionMode reduction_mode) { const int classes_axis = 1; @@ -138,9 +136,8 @@ xla::XlaOp BuildNllLoss(const xla::XlaOp& logits, const xla::XlaOp& labels, // Builds the NLLLoss gradient for log-probabilities "logits" and class indices // "labels". -xla::XlaOp BuildNllLossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& logits, - const xla::XlaOp& labels, +xla::XlaOp BuildNllLossBackward(xla::XlaOp grad_output, xla::XlaOp logits, + xla::XlaOp labels, const absl::optional& weight, const absl::optional& total_weight, int ignore_index, diff --git a/torch_xla/csrc/nll_loss.h b/torch_xla/csrc/nll_loss.h index 3cdd25549106..a8340d6ec3c3 100644 --- a/torch_xla/csrc/nll_loss.h +++ b/torch_xla/csrc/nll_loss.h @@ -7,15 +7,14 @@ namespace torch_xla { // Builds the NLLLoss for log-probabilities "logits" and class indices "labels". -xla::XlaOp BuildNllLoss(const xla::XlaOp& logits, const xla::XlaOp& labels, +xla::XlaOp BuildNllLoss(xla::XlaOp logits, xla::XlaOp labels, const absl::optional& weight, int ignore_index, ReductionMode reduction_mode); // Builds the NLLLoss gradient for log-probabilities "logits" and class indices // "labels". 
-xla::XlaOp BuildNllLossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& logits, - const xla::XlaOp& labels, +xla::XlaOp BuildNllLossBackward(xla::XlaOp grad_output, xla::XlaOp logits, + xla::XlaOp labels, const absl::optional& weight, const absl::optional& total_weight, int ignore_index, ReductionMode reduction_mode); diff --git a/torch_xla/csrc/ops/as_strided.cpp b/torch_xla/csrc/ops/as_strided.cpp index 05dd2b04dbd9..b7e2ed8dce85 100644 --- a/torch_xla/csrc/ops/as_strided.cpp +++ b/torch_xla/csrc/ops/as_strided.cpp @@ -11,7 +11,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerAsStrided(const xla::XlaOp& input, +xla::XlaOp LowerAsStrided(xla::XlaOp input, tensorflow::gtl::ArraySlice size, xla::int64 storage_offset) { xla::int64 input_element_count = diff --git a/torch_xla/csrc/ops/as_strided_view_update.cpp b/torch_xla/csrc/ops/as_strided_view_update.cpp index 61e016632f42..5c23e0f4bc98 100644 --- a/torch_xla/csrc/ops/as_strided_view_update.cpp +++ b/torch_xla/csrc/ops/as_strided_view_update.cpp @@ -14,7 +14,7 @@ namespace ops { namespace { xla::XlaOp LowerAsStridedViewUpdate( - const xla::XlaOp& target, const xla::XlaOp& input, + xla::XlaOp target, xla::XlaOp input, tensorflow::gtl::ArraySlice size, xla::int64 storage_offset) { xla::int64 input_element_count = xla::ShapeUtil::ElementsIn(XlaHelpers::ShapeOfXlaOp(input)); diff --git a/torch_xla/csrc/ops/cumprod.cpp b/torch_xla/csrc/ops/cumprod.cpp index 3adaef0a20b3..19967afa2f01 100644 --- a/torch_xla/csrc/ops/cumprod.cpp +++ b/torch_xla/csrc/ops/cumprod.cpp @@ -13,7 +13,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerCumProd(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp LowerCumProd(xla::XlaOp input, xla::int64 dim, c10::optional dtype) { xla::XlaOp casted_input = CastToScalarType(input, dtype); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input); diff --git a/torch_xla/csrc/ops/cumsum.cpp b/torch_xla/csrc/ops/cumsum.cpp index 65ea1f070525..850cc3284042 100644 --- a/torch_xla/csrc/ops/cumsum.cpp +++ b/torch_xla/csrc/ops/cumsum.cpp @@ -13,7 +13,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerCumSum(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp LowerCumSum(xla::XlaOp input, xla::int64 dim, c10::optional dtype) { xla::XlaOp casted_input = CastToScalarType(input, dtype); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(casted_input); diff --git a/torch_xla/csrc/ops/index_put.cpp b/torch_xla/csrc/ops/index_put.cpp index a281c7801182..eb852cda8bd7 100644 --- a/torch_xla/csrc/ops/index_put.cpp +++ b/torch_xla/csrc/ops/index_put.cpp @@ -30,9 +30,7 @@ NodePtr IndexPut::Clone(OpList operands) const { XlaOpVector IndexPut::Lower(LoweringContext* loctx) const { std::function add_scatter_combiner = - [](const xla::XlaOp& x, const xla::XlaOp& y) -> xla::XlaOp { - return x + y; - }; + [](xla::XlaOp x, xla::XlaOp y) -> xla::XlaOp { return x + y; }; xla::XlaOp base = loctx->GetOutputOp(operand(0)); xla::XlaOp indices = loctx->GetOutputOp(operand(1)); diff --git a/torch_xla/csrc/ops/log_softmax.cpp b/torch_xla/csrc/ops/log_softmax.cpp index 6efbfc29cd37..4d49c1348acf 100644 --- a/torch_xla/csrc/ops/log_softmax.cpp +++ b/torch_xla/csrc/ops/log_softmax.cpp @@ -12,7 +12,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerLogSoftmax(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp LowerLogSoftmax(xla::XlaOp input, xla::int64 dim, const c10::optional& dtype) { xla::XlaOp result = BuildLogSoftmax(input, dim); return CastToScalarType(result, 
dtype); diff --git a/torch_xla/csrc/ops/mean.cpp b/torch_xla/csrc/ops/mean.cpp index d8a1362377d3..8ae956e109e5 100644 --- a/torch_xla/csrc/ops/mean.cpp +++ b/torch_xla/csrc/ops/mean.cpp @@ -14,7 +14,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerMean(const xla::XlaOp& input, +xla::XlaOp LowerMean(xla::XlaOp input, const std::vector& dimensions, bool keep_reduced_dimensions, const c10::optional& dtype) { diff --git a/torch_xla/csrc/ops/native_batch_norm_forward.cpp b/torch_xla/csrc/ops/native_batch_norm_forward.cpp index 776178e92d54..917f84c2d1b0 100644 --- a/torch_xla/csrc/ops/native_batch_norm_forward.cpp +++ b/torch_xla/csrc/ops/native_batch_norm_forward.cpp @@ -11,12 +11,10 @@ namespace ir { namespace ops { namespace { -std::vector LowerBatchNorm(const xla::XlaOp& input, - const xla::XlaOp& weight, - const xla::XlaOp& bias, - const xla::XlaOp& running_mean, - const xla::XlaOp& running_var, - bool training, double eps) { +std::vector LowerBatchNorm(xla::XlaOp input, xla::XlaOp weight, + xla::XlaOp bias, xla::XlaOp running_mean, + xla::XlaOp running_var, bool training, + double eps) { std::vector values; if (training) { BatchNormOutput batch_norm_output = diff --git a/torch_xla/csrc/ops/prod.cpp b/torch_xla/csrc/ops/prod.cpp index c5cddd8c345f..9badaae3486a 100644 --- a/torch_xla/csrc/ops/prod.cpp +++ b/torch_xla/csrc/ops/prod.cpp @@ -15,7 +15,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerProd(const xla::XlaOp& input, +xla::XlaOp LowerProd(xla::XlaOp input, const std::vector& dimensions, bool keep_reduced_dimensions, c10::optional dtype) { diff --git a/torch_xla/csrc/ops/qr.cpp b/torch_xla/csrc/ops/qr.cpp index 396f6f7ec974..60712bc68315 100644 --- a/torch_xla/csrc/ops/qr.cpp +++ b/torch_xla/csrc/ops/qr.cpp @@ -13,7 +13,7 @@ namespace ir { namespace ops { namespace { -std::vector LowerQR(const xla::XlaOp& input, bool some) { +std::vector LowerQR(xla::XlaOp input, bool some) { xla::QRDecompositionResult qr_result = xla::QRDecomposition(input, /*full_matrices=*/!some, /*block_size=*/128, XlaHelpers::mat_mul_precision()) diff --git a/torch_xla/csrc/ops/softmax.cpp b/torch_xla/csrc/ops/softmax.cpp index e0f9f5e111c9..50d07ea2e465 100644 --- a/torch_xla/csrc/ops/softmax.cpp +++ b/torch_xla/csrc/ops/softmax.cpp @@ -12,7 +12,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerSoftmax(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp LowerSoftmax(xla::XlaOp input, xla::int64 dim, const c10::optional& dtype) { xla::XlaOp result = BuildSoftmax(input, dim); return CastToScalarType(result, dtype); diff --git a/torch_xla/csrc/ops/squeeze.cpp b/torch_xla/csrc/ops/squeeze.cpp index c1d4f3c7c256..38f6534f6a3e 100644 --- a/torch_xla/csrc/ops/squeeze.cpp +++ b/torch_xla/csrc/ops/squeeze.cpp @@ -11,7 +11,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerSqueeze(const xla::XlaOp& input, int dim) { +xla::XlaOp LowerSqueeze(xla::XlaOp input, int dim) { if (dim == -1) { return SqueezeAllTrivialDimensions(input); } diff --git a/torch_xla/csrc/ops/sum.cpp b/torch_xla/csrc/ops/sum.cpp index 584751a557be..31c7a7b89a0e 100644 --- a/torch_xla/csrc/ops/sum.cpp +++ b/torch_xla/csrc/ops/sum.cpp @@ -15,7 +15,7 @@ namespace ir { namespace ops { namespace { -xla::XlaOp LowerSum(const xla::XlaOp& input, +xla::XlaOp LowerSum(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions, c10::optional dtype) { diff --git a/torch_xla/csrc/ops/svd.cpp b/torch_xla/csrc/ops/svd.cpp index b19bbc1df35a..849d2b98bcfd 100644 --- 
a/torch_xla/csrc/ops/svd.cpp +++ b/torch_xla/csrc/ops/svd.cpp @@ -13,8 +13,7 @@ namespace ir { namespace ops { namespace { -std::vector LowerSVD(const xla::XlaOp& input, bool some, - bool compute_uv) { +std::vector LowerSVD(xla::XlaOp input, bool some, bool compute_uv) { xla::SVDResult svd_result = xla::SVD(input, /*max_iter=*/100, /*epsilon=*/1e-6, XlaHelpers::mat_mul_precision()); diff --git a/torch_xla/csrc/ops/symeig.cpp b/torch_xla/csrc/ops/symeig.cpp index f8032abc1e7e..6a6721bed72d 100644 --- a/torch_xla/csrc/ops/symeig.cpp +++ b/torch_xla/csrc/ops/symeig.cpp @@ -12,7 +12,7 @@ namespace ir { namespace ops { namespace { -std::vector LowerSymEig(const xla::XlaOp& input, bool eigenvectors, +std::vector LowerSymEig(xla::XlaOp input, bool eigenvectors, bool lower) { xla::SelfAdjointEigResult self_adj_eig_result = xla::SelfAdjointEig(input, /*lower=*/lower, /*max_iter=*/100, diff --git a/torch_xla/csrc/ops/triangular_solve.cpp b/torch_xla/csrc/ops/triangular_solve.cpp index ba9287530a3a..c58da7eba751 100644 --- a/torch_xla/csrc/ops/triangular_solve.cpp +++ b/torch_xla/csrc/ops/triangular_solve.cpp @@ -45,8 +45,7 @@ std::pair InferTriangularSolveShape( lhs_batch_promoted_shape); } -std::vector LowerTriangularSolve(const xla::XlaOp& rhs, - const xla::XlaOp& lhs, +std::vector LowerTriangularSolve(xla::XlaOp rhs, xla::XlaOp lhs, bool left_side, bool lower, bool transpose, bool unit_diagonal) { diff --git a/torch_xla/csrc/pooling.cpp b/torch_xla/csrc/pooling.cpp index 1b40ef50cdac..4bc5a1cc10ec 100644 --- a/torch_xla/csrc/pooling.cpp +++ b/torch_xla/csrc/pooling.cpp @@ -81,8 +81,7 @@ struct BatchInput { // Adds a batch dimension of size 1 if the input tensor doesn't have a batch // dimension. -BatchInput CreateBatchInput(const xla::XlaOp& input, - xla::int64 spatial_dim_count) { +BatchInput CreateBatchInput(xla::XlaOp input, xla::int64 spatial_dim_count) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::int64 rank = input_shape.rank(); XLA_CHECK(rank == spatial_dim_count + 1 || rank == spatial_dim_count + 2) @@ -94,7 +93,7 @@ BatchInput CreateBatchInput(const xla::XlaOp& input, return {input, rank}; } -xla::XlaOp RemoveTrivialBatch(const xla::XlaOp& batch, xla::int64 original_rank, +xla::XlaOp RemoveTrivialBatch(xla::XlaOp batch, xla::int64 original_rank, xla::int64 spatial_dim_count) { if (original_rank == spatial_dim_count + 1) { return SqueezeTrivialDimension(batch, 0); @@ -162,7 +161,7 @@ bool IsSupportedAdaptiveAvgPool2d( } xla::XlaOp BuildMaxPoolNd( - const xla::XlaOp& input, xla::int64 spatial_dim_count, + xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode) { @@ -192,8 +191,7 @@ xla::XlaOp BuildMaxPoolNd( } xla::XlaOp BuildMaxPoolNdBackward( - const xla::XlaOp& out_backprop, const xla::XlaOp& input, - xla::int64 spatial_dim_count, + xla::XlaOp out_backprop, xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode) { @@ -232,7 +230,7 @@ xla::XlaOp BuildMaxPoolNdBackward( } xla::XlaOp BuildAvgPoolNd( - const xla::XlaOp& input, xla::int64 spatial_dim_count, + xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode, @@ -258,8 +256,7 @@ xla::XlaOp BuildAvgPoolNd( } xla::XlaOp BuildAvgPoolNdBackward( - const 
xla::XlaOp& out_backprop, const xla::XlaOp& input, - xla::int64 spatial_dim_count, + xla::XlaOp out_backprop, xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode, @@ -288,7 +285,7 @@ xla::XlaOp BuildAvgPoolNdBackward( } xla::XlaOp BuildAdaptiveAvgPool2d( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_size) { XLA_CHECK_EQ(output_size.size(), 2) << "Invalid output size rank"; const auto input_size = XlaHelpers::SizesOfXlaOp(input); @@ -310,8 +307,8 @@ xla::XlaOp BuildAdaptiveAvgPool2d( /*spatial_dim_count=*/2); } -xla::XlaOp BuildAdaptiveAvgPool2dBackward(const xla::XlaOp& out_backprop, - const xla::XlaOp& input) { +xla::XlaOp BuildAdaptiveAvgPool2dBackward(xla::XlaOp out_backprop, + xla::XlaOp input) { BatchInput batch_out_backprop_info = CreateBatchInput(/*input=*/out_backprop, /*spatial_dim_count=*/2); const auto out_backprop_size = diff --git a/torch_xla/csrc/pooling.h b/torch_xla/csrc/pooling.h index 00f67e874870..42c7f3e27a6a 100644 --- a/torch_xla/csrc/pooling.h +++ b/torch_xla/csrc/pooling.h @@ -7,22 +7,21 @@ namespace torch_xla { // Computes max pooling for the given input. xla::XlaOp BuildMaxPoolNd( - const xla::XlaOp& input, xla::int64 spatial_dim_count, + xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode); // Computes the gradient for max pooling. xla::XlaOp BuildMaxPoolNdBackward( - const xla::XlaOp& out_backprop, const xla::XlaOp& input, - xla::int64 spatial_dim_count, + xla::XlaOp out_backprop, xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode); // Computes average pooling for the given input. xla::XlaOp BuildAvgPoolNd( - const xla::XlaOp& input, xla::int64 spatial_dim_count, + xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode, @@ -30,8 +29,7 @@ xla::XlaOp BuildAvgPoolNd( // Computes the gradient for average pooling. xla::XlaOp BuildAvgPoolNdBackward( - const xla::XlaOp& out_backprop, const xla::XlaOp& input, - xla::int64 spatial_dim_count, + xla::XlaOp out_backprop, xla::XlaOp input, xla::int64 spatial_dim_count, tensorflow::gtl::ArraySlice kernel_size, tensorflow::gtl::ArraySlice stride, tensorflow::gtl::ArraySlice padding, bool ceil_mode, @@ -39,12 +37,12 @@ xla::XlaOp BuildAvgPoolNdBackward( // Computes adaptive average pooling for the given input and output size. xla::XlaOp BuildAdaptiveAvgPool2d( - const xla::XlaOp& input, + xla::XlaOp input, tensorflow::gtl::ArraySlice output_size); // Computes the gradient for adaptive average pooling. -xla::XlaOp BuildAdaptiveAvgPool2dBackward(const xla::XlaOp& out_backprop, - const xla::XlaOp& input); +xla::XlaOp BuildAdaptiveAvgPool2dBackward(xla::XlaOp out_backprop, + xla::XlaOp input); // Returns true if XLA lowering is supported for the given input and output size // combination. 
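The signature changes above are transparent to existing call sites: an xla::XlaOp lvalue binds to the new by-value parameters exactly as it bound to the former const references, so callers of the pooling builders compile unchanged. A minimal caller sketch, assuming an F32 NCHW parameter; the builder setup, the shape, and the function name below are illustrative assumptions, not part of this diff:

#include "tensorflow/compiler/xla/client/xla_builder.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "torch_xla/csrc/pooling.h"

namespace torch_xla {

// Hypothetical caller: a 2x2, stride-2 max pool over a 1x3x8x8 input. The
// call reads the same before and after this refactoring; only the callee
// parameter types changed from const xla::XlaOp& to xla::XlaOp.
xla::XlaOp BuildExampleMaxPool(xla::XlaBuilder* builder) {
  xla::XlaOp input = xla::Parameter(
      builder, /*parameter_number=*/0,
      xla::ShapeUtil::MakeShape(xla::F32, {1, 3, 8, 8}), "input");
  return BuildMaxPoolNd(input, /*spatial_dim_count=*/2,
                        /*kernel_size=*/{2, 2}, /*stride=*/{2, 2},
                        /*padding=*/{0, 0}, /*ceil_mode=*/false);
}

}  // namespace torch_xla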
diff --git a/torch_xla/csrc/reduction.cpp b/torch_xla/csrc/reduction.cpp index 6630ba8eff3e..697bd9af6f05 100644 --- a/torch_xla/csrc/reduction.cpp +++ b/torch_xla/csrc/reduction.cpp @@ -69,8 +69,7 @@ xla::XlaComputation CreateAnyComputation(xla::PrimitiveType type) { } SummationResult CreateSummation( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice dimensions, + xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions, bool scale) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp init_value = @@ -95,8 +94,7 @@ SummationResult CreateSummation( } xla::XlaOp CreateProduct( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice dimensions, + xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp init_value = @@ -127,8 +125,7 @@ xla::XlaOp AverageValue(const xla::XlaOp& input, const xla::XlaOp& reduced) { } // namespace -xla::XlaOp BuildBinaryCrossEntropy(const xla::XlaOp& input, - const xla::XlaOp& target, +xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target, const absl::optional& weight, ReductionMode reduction) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); @@ -156,9 +153,8 @@ xla::XlaOp BuildBinaryCrossEntropy(const xla::XlaOp& input, } xla::XlaOp BuildBinaryCrossEntropyBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& target, const absl::optional& weight, - ReductionMode reduction) { + xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target, + const absl::optional& weight, ReductionMode reduction) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp xweight; if (weight) { @@ -180,7 +176,7 @@ xla::XlaOp BuildBinaryCrossEntropyBackward( return result; } -xla::XlaOp BuildL1Loss(const xla::XlaOp& input, const xla::XlaOp& target, +xla::XlaOp BuildL1Loss(xla::XlaOp input, xla::XlaOp target, ReductionMode reduction) { xla::XlaOp result = xla::Abs(input - target); if (reduction == ReductionMode::kNone) { @@ -196,10 +192,8 @@ xla::XlaOp BuildL1Loss(const xla::XlaOp& input, const xla::XlaOp& target, return result; } -xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, - const xla::XlaOp& target, - ReductionMode reduction) { +xla::XlaOp BuildL1LossBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp target, ReductionMode reduction) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); if (reduction == ReductionMode::kNone) { xla::XlaOp one = xla::One(input.builder(), input_shape.element_type()); @@ -213,7 +207,7 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output, return xla::Select(xla::Ge(input, target), grad_value, -grad_value); } -xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target, +xla::XlaOp BuildMseLoss(xla::XlaOp input, xla::XlaOp target, ReductionMode reduction) { xla::XlaOp diff = input - target; xla::XlaOp result = diff * diff; @@ -238,10 +232,8 @@ xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target, return result; } -xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, - const xla::XlaOp& target, - ReductionMode reduction) { +xla::XlaOp BuildMseLossBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp target, ReductionMode reduction) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp two = XlaHelpers::ScalarValue( 2, 
input_shape.element_type(), input.builder()); @@ -260,9 +252,9 @@ xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output, return d_input * grad_value; } -xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, xla::int64 dim, const xla::XlaComputation& reducer, - const xla::XlaOp& init) { + xla::XlaOp init) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); std::vector window_strides(input_shape.rank(), 1); std::vector window_dims(input_shape.rank(), 1); @@ -274,7 +266,7 @@ xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim, /*base_dilations=*/{}, /*window_dilations=*/{}, padding); } -xla::XlaOp BuildMean(const xla::XlaOp& input, +xla::XlaOp BuildMean(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { return CreateSummation(input, dimensions, keep_reduced_dimensions, @@ -283,8 +275,7 @@ xla::XlaOp BuildMean(const xla::XlaOp& input, } xla::XlaOp BuildStdDeviation( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice dimensions, + xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions, bool unbiased) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp mean = @@ -316,7 +307,7 @@ xla::XlaOp BuildStdDeviation( return xla::Sqrt(squared_result); } -xla::XlaOp BuildSum(const xla::XlaOp& input, +xla::XlaOp BuildSum(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { return CreateSummation(input, dimensions, keep_reduced_dimensions, @@ -324,13 +315,13 @@ xla::XlaOp BuildSum(const xla::XlaOp& input, .result; } -xla::XlaOp BuildProd(const xla::XlaOp& input, +xla::XlaOp BuildProd(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { return CreateProduct(input, dimensions, keep_reduced_dimensions); } -xla::XlaOp BuildMaxInDim(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildMaxInDim(xla::XlaOp input, xla::int64 dim, bool keep_reduced_dimensions) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type()); @@ -347,7 +338,7 @@ xla::XlaOp BuildMaxInDim(const xla::XlaOp& input, xla::int64 dim, return result; } -xla::XlaOp BuildMinInDim(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildMinInDim(xla::XlaOp input, xla::int64 dim, bool keep_reduced_dimensions) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); XlaHelpers::MinMax min_max = XlaHelpers::MinMaxValues(shape.element_type()); @@ -364,7 +355,7 @@ xla::XlaOp BuildMinInDim(const xla::XlaOp& input, xla::int64 dim, return result; } -xla::XlaOp BuildArgMax(const xla::XlaOp& input, xla::int64 dim, bool keepdim) { +xla::XlaOp BuildArgMax(xla::XlaOp input, xla::int64 dim, bool keepdim) { const xla::Shape* shape = &XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp operand = input; if (dim < 0) { @@ -383,7 +374,7 @@ xla::XlaOp BuildArgMax(const xla::XlaOp& input, xla::int64 dim, bool keepdim) { return result; } -xla::XlaOp BuildArgMin(const xla::XlaOp& input, xla::int64 dim, bool keepdim) { +xla::XlaOp BuildArgMin(xla::XlaOp input, xla::int64 dim, bool keepdim) { const xla::Shape* shape = &XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp operand = input; if (dim < 0) { @@ -402,7 +393,7 @@ xla::XlaOp BuildArgMin(const xla::XlaOp& input, xla::int64 dim, bool keepdim) { return result; } -xla::XlaOp BuildAll(const xla::XlaOp& input, +xla::XlaOp BuildAll(xla::XlaOp input, 
tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); @@ -419,7 +410,7 @@ xla::XlaOp BuildAll(const xla::XlaOp& input, return result; } -xla::XlaOp BuildAny(const xla::XlaOp& input, +xla::XlaOp BuildAny(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); diff --git a/torch_xla/csrc/reduction.h b/torch_xla/csrc/reduction.h index 9aa853798a9d..5ad83e36d726 100644 --- a/torch_xla/csrc/reduction.h +++ b/torch_xla/csrc/reduction.h @@ -11,87 +11,80 @@ enum class ReductionMode { kSum, }; -xla::XlaOp BuildBinaryCrossEntropy(const xla::XlaOp& input, - const xla::XlaOp& target, +xla::XlaOp BuildBinaryCrossEntropy(xla::XlaOp input, xla::XlaOp target, const absl::optional& weight, ReductionMode reduction); xla::XlaOp BuildBinaryCrossEntropyBackward( - const xla::XlaOp& grad_output, const xla::XlaOp& input, - const xla::XlaOp& target, const absl::optional& weight, - ReductionMode reduction); + xla::XlaOp grad_output, xla::XlaOp input, xla::XlaOp target, + const absl::optional& weight, ReductionMode reduction); -xla::XlaOp BuildL1Loss(const xla::XlaOp& input, const xla::XlaOp& target, +xla::XlaOp BuildL1Loss(xla::XlaOp input, xla::XlaOp target, ReductionMode reduction); -xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, - const xla::XlaOp& target, - ReductionMode reduction); +xla::XlaOp BuildL1LossBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp target, ReductionMode reduction); -xla::XlaOp BuildMseLoss(const xla::XlaOp& input, const xla::XlaOp& target, +xla::XlaOp BuildMseLoss(xla::XlaOp input, xla::XlaOp target, ReductionMode reduction); -xla::XlaOp BuildMseLossBackward(const xla::XlaOp& grad_output, - const xla::XlaOp& input, - const xla::XlaOp& target, - ReductionMode reduction); +xla::XlaOp BuildMseLossBackward(xla::XlaOp grad_output, xla::XlaOp input, + xla::XlaOp target, ReductionMode reduction); // Builds a mean by reducing all the dimensions listed in dimensions. If // keep_reduced_dimensions is true, the reduced dimensions will be retained, // with value 1. -xla::XlaOp BuildMean(const xla::XlaOp& input, +xla::XlaOp BuildMean(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions); xla::XlaOp BuildStdDeviation( - const xla::XlaOp& input, - tensorflow::gtl::ArraySlice dimensions, + xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions, bool unbiased); // Builds the sum of all values by reducing all the dimensions listed in // dimensions. If keep_reduced_dimensions is true, the reduced dimensions will // be retained, with value 1. -xla::XlaOp BuildSum(const xla::XlaOp& input, +xla::XlaOp BuildSum(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions); // Builds the max of all values by reducing in the given dimension. If // keep_reduced_dimensions is true, the reduced dimension will be retained, with // value 1. -xla::XlaOp BuildMaxInDim(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildMaxInDim(xla::XlaOp input, xla::int64 dim, bool keep_reduced_dimensions); // Builds the min of all values by reducing in the given dimension. If // keep_reduced_dimensions is true, the reduced dimension will be retained, with // value 1. 
-xla::XlaOp BuildMinInDim(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildMinInDim(xla::XlaOp input, xla::int64 dim, bool keep_reduced_dimensions); // Compute the indices of the maximum values of a tensor across a dimension. -xla::XlaOp BuildArgMax(const xla::XlaOp& input, xla::int64 dim, bool keepdim); +xla::XlaOp BuildArgMax(xla::XlaOp input, xla::int64 dim, bool keepdim); // Compute the indices of the minimum values of a tensor across a dimension. -xla::XlaOp BuildArgMin(const xla::XlaOp& input, xla::int64 dim, bool keepdim); +xla::XlaOp BuildArgMin(xla::XlaOp input, xla::int64 dim, bool keepdim); // Builds the product of all values by reducing all the dimensions listed in // dimensions. If keep_reduced_dimensions is true, the reduced dimensions will // be retained, with value 1. -xla::XlaOp BuildProd(const xla::XlaOp& input, +xla::XlaOp BuildProd(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions); // Compute the cumulative computation specified by "reducer" and "init" in the // given dimension "dim". -xla::XlaOp BuildCumulativeComputation(const xla::XlaOp& input, xla::int64 dim, +xla::XlaOp BuildCumulativeComputation(xla::XlaOp input, xla::int64 dim, const xla::XlaComputation& reducer, - const xla::XlaOp& init); + xla::XlaOp init); -xla::XlaOp BuildAll(const xla::XlaOp& input, +xla::XlaOp BuildAll(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions); -xla::XlaOp BuildAny(const xla::XlaOp& input, +xla::XlaOp BuildAny(xla::XlaOp input, tensorflow::gtl::ArraySlice dimensions, bool keep_reduced_dimensions); diff --git a/torch_xla/csrc/resize_ops.cpp b/torch_xla/csrc/resize_ops.cpp index b351b1d71296..57b5ae1b2bab 100644 --- a/torch_xla/csrc/resize_ops.cpp +++ b/torch_xla/csrc/resize_ops.cpp @@ -41,7 +41,7 @@ xla::Shape GetBackwardOutputShape2d( return xla::ShapeUtil::MakeShape(input_shape.element_type(), input_size); } -xla::XlaOp LowerForward2d(const std::string& target, const xla::XlaOp& input, +xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input, const xla::Shape& output_shape, bool align_corners, bool half_pixel_centers) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); @@ -65,7 +65,7 @@ xla::XlaOp LowerForward2d(const std::string& target, const xla::XlaOp& input, return xla::Transpose(resised, inv_transpose_permute); } -xla::XlaOp LowerBackward2d(const std::string& target, const xla::XlaOp& input, +xla::XlaOp LowerBackward2d(const std::string& target, xla::XlaOp input, const xla::Shape& output_shape, bool align_corners, bool half_pixel_centers) { static double resiple_split_factor = diff --git a/torch_xla/csrc/resize_ops.h b/torch_xla/csrc/resize_ops.h index acc49662397d..f82f3d695030 100644 --- a/torch_xla/csrc/resize_ops.h +++ b/torch_xla/csrc/resize_ops.h @@ -16,11 +16,11 @@ xla::Shape GetBackwardOutputShape2d( const xla::Shape& input_shape, tensorflow::gtl::ArraySlice input_size); -xla::XlaOp LowerForward2d(const std::string& target, const xla::XlaOp& input, +xla::XlaOp LowerForward2d(const std::string& target, xla::XlaOp input, const xla::Shape& output_shape, bool align_corners, bool half_pixel_centers); -xla::XlaOp LowerBackward2d(const std::string& target, const xla::XlaOp& input, +xla::XlaOp LowerBackward2d(const std::string& target, xla::XlaOp input, const xla::Shape& output_shape, bool align_corners, bool half_pixel_centers); diff --git a/torch_xla/csrc/softmax_builder.cpp b/torch_xla/csrc/softmax_builder.cpp index d50dd4f1cc01..40387bfb944e 100644 --- 
a/torch_xla/csrc/softmax_builder.cpp +++ b/torch_xla/csrc/softmax_builder.cpp @@ -25,7 +25,7 @@ std::vector BroadcastDimensions(xla::int64 dims, return result_dims; } -SoftMaxPartials LogSoftmaxPartials(const xla::XlaOp& logits, xla::int64 dim) { +SoftMaxPartials LogSoftmaxPartials(xla::XlaOp logits, xla::int64 dim) { const xla::Shape& logits_shape = XlaHelpers::ShapeOfXlaOp(logits); std::vector broadcast_dimensions = BroadcastDimensions(logits_shape.rank(), dim); @@ -47,7 +47,7 @@ SoftMaxPartials LogSoftmaxPartials(const xla::XlaOp& logits, xla::int64 dim) { return {std::move(broadcast_dimensions), shifted_logits, exp_shifted, reduce}; } -xla::XlaOp SoftmaxSumOfGrad(const xla::XlaOp& grad_output, xla::int64 dim) { +xla::XlaOp SoftmaxSumOfGrad(xla::XlaOp grad_output, xla::int64 dim) { const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); auto broadcast_dimensions = BroadcastDimensions(grad_output_shape.rank(), dim); @@ -61,14 +61,14 @@ xla::XlaOp SoftmaxSumOfGrad(const xla::XlaOp& grad_output, xla::int64 dim) { } // namespace -xla::XlaOp BuildLogSoftmax(const xla::XlaOp& logits, xla::int64 dim) { +xla::XlaOp BuildLogSoftmax(xla::XlaOp logits, xla::int64 dim) { SoftMaxPartials parts = LogSoftmaxPartials(logits, dim); return xla::Sub(parts.shifted_logits, xla::Log(parts.reduce), parts.broadcast_dimensions); } -xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output, - const xla::XlaOp& output, xla::int64 dim) { +xla::XlaOp BuildLogSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, + xla::int64 dim) { // Inspired from tf2xla. xla::XlaOp sum = SoftmaxSumOfGrad(grad_output, dim); const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); @@ -78,13 +78,13 @@ xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output, xla::Mul(xla::Exp(output), sum, broadcast_dimensions)); } -xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim) { +xla::XlaOp BuildSoftmax(xla::XlaOp logits, xla::int64 dim) { SoftMaxPartials parts = LogSoftmaxPartials(logits, dim); return xla::Div(parts.exp_shifted, parts.reduce, parts.broadcast_dimensions); } -xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output, - const xla::XlaOp& output, xla::int64 dim) { +xla::XlaOp BuildSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, + xla::int64 dim) { xla::XlaOp sum = SoftmaxSumOfGrad(xla::Mul(grad_output, output), dim); const xla::Shape& grad_output_shape = XlaHelpers::ShapeOfXlaOp(grad_output); auto broadcast_dimensions = diff --git a/torch_xla/csrc/softmax_builder.h b/torch_xla/csrc/softmax_builder.h index d417c80120a6..f7c4d0ee2693 100644 --- a/torch_xla/csrc/softmax_builder.h +++ b/torch_xla/csrc/softmax_builder.h @@ -5,15 +5,15 @@ namespace torch_xla { // Computes log(softmax(logits)) along the dimension specified by "dim". -xla::XlaOp BuildLogSoftmax(const xla::XlaOp& logits, xla::int64 dim); +xla::XlaOp BuildLogSoftmax(xla::XlaOp logits, xla::int64 dim); // Computes the gradient of the input of the LogSoftmax function. 
-xla::XlaOp BuildLogSoftmaxGrad(const xla::XlaOp& grad_output, - const xla::XlaOp& output, xla::int64 dim); +xla::XlaOp BuildLogSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, + xla::int64 dim); -xla::XlaOp BuildSoftmax(const xla::XlaOp& logits, xla::int64 dim); +xla::XlaOp BuildSoftmax(xla::XlaOp logits, xla::int64 dim); -xla::XlaOp BuildSoftmaxGrad(const xla::XlaOp& grad_output, - const xla::XlaOp& output, xla::int64 dim); +xla::XlaOp BuildSoftmaxGrad(xla::XlaOp grad_output, xla::XlaOp output, + xla::int64 dim); } // namespace torch_xla diff --git a/torch_xla/csrc/xla_lower_util.cpp b/torch_xla/csrc/xla_lower_util.cpp index bc6b1f7f33ac..8ce20152cb1c 100644 --- a/torch_xla/csrc/xla_lower_util.cpp +++ b/torch_xla/csrc/xla_lower_util.cpp @@ -27,7 +27,7 @@ struct ConditionMaskData { xla::XlaOp length; }; -ConditionMaskData CreateConditionMaskData(const xla::XlaOp& condition) { +ConditionMaskData CreateConditionMaskData(xla::XlaOp condition) { xla::Shape iota_shape = XlaHelpers::ShapeOfXlaOp(condition); iota_shape.set_element_type(xla::PrimitiveType::S32); @@ -48,7 +48,7 @@ ConditionMaskData CreateConditionMaskData(const xla::XlaOp& condition) { length}; } -std::pair DotExpand(const xla::XlaOp& op, +std::pair DotExpand(xla::XlaOp op, const xla::Shape& op_shape, const xla::Shape& to_shape) { xla::int64 rank_delta = to_shape.rank() - op_shape.rank(); @@ -71,9 +71,9 @@ std::pair DotExpand(const xla::XlaOp& op, xla::ShapeUtil::MakeShape(op_shape.element_type(), broadcasted_sizes)); } -std::pair DotBroadcast(const xla::XlaOp& lhs, +std::pair DotBroadcast(xla::XlaOp lhs, const xla::Shape& lhs_shape, - const xla::XlaOp& rhs, + xla::XlaOp rhs, const xla::Shape& rhs_shape) { auto lhs_dimensions = xla::util::ToVector(lhs_shape.dimensions()); auto rhs_dimensions = xla::util::ToVector(rhs_shape.dimensions()); @@ -122,10 +122,9 @@ xla::XlaComputation MakeScatterComputation( } xla::XlaOp CreateIndexAlongDim( - const xla::XlaOp& buffer, xla::int64 dim, const xla::XlaOp& index, - const xla::XlaOp& value, bool broadcast_value_to_index, - const std::function& - combiner) { + xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index, xla::XlaOp value, + bool broadcast_value_to_index, + const std::function& combiner) { const xla::Shape& buffer_shape = XlaHelpers::ShapeOfXlaOp(buffer); xla::ScatterDimensionNumbers dim_numbers; dim_numbers.set_index_vector_dim(1); @@ -174,10 +173,8 @@ bool ScatterRequiresPadding(const xla::Shape& input_shape, } xla::XlaOp XlaDenseScatter( - const xla::XlaOp& input, const xla::XlaOp& index, const xla::XlaOp& src, - xla::int64 dim, - const std::function& - combiner) { + xla::XlaOp input, xla::XlaOp index, xla::XlaOp src, xla::int64 dim, + const std::function& combiner) { // Contribute back this code to xla::TorchScatterDense() once this has reached // a stable implementation. 
xla::XlaBuilder* builder = input.builder(); @@ -240,7 +237,7 @@ xla::XlaOp XlaDenseScatter( }); } -std::vector BuildConditionIndices(const xla::XlaOp& condition) { +std::vector BuildConditionIndices(xla::XlaOp condition) { ConditionMaskData cmd = CreateConditionMaskData(condition); std::vector to_sort = {cmd.reshaped_condition_int}; std::vector types_to_sort = {xla::PrimitiveType::S32}; @@ -270,7 +267,7 @@ std::vector BuildConditionIndices(const xla::XlaOp& condition) { } // namespace -xla::XlaOp PadToSize(const xla::XlaOp& input, const xla::XlaOp& pad_value, +xla::XlaOp PadToSize(xla::XlaOp input, xla::XlaOp pad_value, tensorflow::gtl::ArraySlice size) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); XLA_CHECK_EQ(input_shape.rank(), size.size()); @@ -286,7 +283,7 @@ xla::XlaOp PadToSize(const xla::XlaOp& input, const xla::XlaOp& pad_value, return xla::Pad(input, pad_value, padding_config); } -std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, +std::vector CreateKthValue(xla::XlaOp input, xla::int64 k, xla::int64 dim, bool keepdim) { // Here 'k' is 1 based (1...). const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); @@ -322,7 +319,7 @@ std::vector CreateKthValue(const xla::XlaOp& input, xla::int64 k, /*device=*/nullptr))}; } -std::vector CreateTopK(const xla::XlaOp& input, xla::int64 k, +std::vector CreateTopK(xla::XlaOp input, xla::int64 k, xla::int64 dim, bool largest, bool /* sorted */) { // Here 'k' is 1 based (1...). @@ -356,7 +353,7 @@ std::vector CreateTopK(const xla::XlaOp& input, xla::int64 k, /*device=*/nullptr))}; } -xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs) { +xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs) { const auto precision_level = XlaHelpers::mat_mul_precision(); xla::PrecisionConfig precision_config = XlaHelpers::BuildPrecisionConfig(precision_level); @@ -402,8 +399,7 @@ xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs) { << rhs_shape << ")"; } -xla::XlaOp BuildBernoulli(const xla::XlaOp& probability, - const xla::Shape& shape) { +xla::XlaOp BuildBernoulli(xla::XlaOp probability, const xla::Shape& shape) { const xla::Shape& probability_shape = XlaHelpers::ShapeOfXlaOp(probability); xla::XlaOp zero = XlaHelpers::ScalarValue( 0, probability_shape.element_type(), probability.builder()); @@ -414,7 +410,7 @@ xla::XlaOp BuildBernoulli(const xla::XlaOp& probability, shape.element_type()); } -xla::XlaOp BuildDropout(const xla::XlaOp& input, float probability) { +xla::XlaOp BuildDropout(xla::XlaOp input, float probability) { const xla::Shape& shape = XlaHelpers::ShapeOfXlaOp(input); xla::XlaOp prob = XlaHelpers::ScalarBroadcast(probability, shape, input.builder()); @@ -467,7 +463,7 @@ std::vector CreateBroadcastTensors( return result; } -xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices, +xla::XlaOp CreateIndex(xla::XlaOp input, xla::XlaOp indices, xla::int64 start_dim) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); const xla::Shape& indices_shape = XlaHelpers::ShapeOfXlaOp(indices); @@ -499,10 +495,9 @@ xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices, } xla::XlaOp CreateIndexUpdate( - const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim, - const xla::XlaOp& values, - const std::function& - combiner) { + xla::XlaOp buffer, xla::XlaOp indices, xla::int64 start_dim, + xla::XlaOp values, + const std::function& combiner) { const xla::Shape& buffer_shape = XlaHelpers::ShapeOfXlaOp(buffer); const 
xla::Shape& indices_shape = XlaHelpers::ShapeOfXlaOp(indices); const xla::Shape& values_shape = XlaHelpers::ShapeOfXlaOp(values); @@ -556,10 +551,9 @@ xla::XlaOp CreateIndexUpdate( dim_numbers); } -xla::XlaOp CreateIndexAdd(const xla::XlaOp& buffer, xla::int64 dim, - const xla::XlaOp& index, const xla::XlaOp& value) { - auto add_scatter_combiner = [](const xla::XlaOp& x, - const xla::XlaOp& y) -> xla::XlaOp { +xla::XlaOp CreateIndexAdd(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index, + xla::XlaOp value) { + auto add_scatter_combiner = [](xla::XlaOp x, xla::XlaOp y) -> xla::XlaOp { return x + y; }; return CreateIndexAlongDim(buffer, dim, index, value, @@ -567,20 +561,20 @@ xla::XlaOp CreateIndexAdd(const xla::XlaOp& buffer, xla::int64 dim, add_scatter_combiner); } -xla::XlaOp CreateIndexCopy(const xla::XlaOp& buffer, xla::int64 dim, - const xla::XlaOp& index, const xla::XlaOp& value) { +xla::XlaOp CreateIndexCopy(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index, + xla::XlaOp value) { return CreateIndexAlongDim(buffer, dim, index, value, /*broadcast_value_to_index=*/false, nullptr); } -xla::XlaOp CreateIndexFill(const xla::XlaOp& buffer, xla::int64 dim, - const xla::XlaOp& index, const xla::XlaOp& value) { +xla::XlaOp CreateIndexFill(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index, + xla::XlaOp value) { return CreateIndexAlongDim(buffer, dim, index, value, /*broadcast_value_to_index=*/true, nullptr); } XlaOpCombiner NumericAddCombiner() { - return [](const xla::XlaOp& x, const xla::XlaOp& y) -> xla::XlaOp { + return [](xla::XlaOp x, xla::XlaOp y) -> xla::XlaOp { xla::XlaOp numeric_x = ConvertToNumeric(x); xla::XlaOp numeric_y = ConvertToNumeric(y); xla::XlaOp numeric_sum = numeric_x + numeric_y; @@ -590,9 +584,8 @@ XlaOpCombiner NumericAddCombiner() { }; } -xla::XlaOp CreateScatter(const xla::XlaOp& input, const xla::XlaOp& index, - const xla::XlaOp& source, xla::int64 dim, - const XlaOpCombiner& combiner) { +xla::XlaOp CreateScatter(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source, + xla::int64 dim, const XlaOpCombiner& combiner) { static int dense_scatter_factor = xla::sys_util::GetEnvInt("XLA_DENSE_SCATTER_FACTOR", 100); const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); @@ -635,8 +628,8 @@ xla::XlaOp CreateScatter(const xla::XlaOp& input, const xla::XlaOp& index, scatter_dnums); } -xla::XlaOp CreatePut(const xla::XlaOp& input, const xla::XlaOp& index, - const xla::XlaOp& source, bool accumulate) { +xla::XlaOp CreatePut(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source, + bool accumulate) { xla::Shape input_shape; xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape); xla::Shape index_shape; @@ -655,14 +648,13 @@ xla::XlaOp CreatePut(const xla::XlaOp& input, const xla::XlaOp& index, return xla::Reshape(r1_scatter, input_shape.dimensions()); } -std::vector BuildNonZero(const xla::XlaOp& input) { +std::vector BuildNonZero(xla::XlaOp input) { const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input); return BuildConditionIndices( xla::Ne(input, xla::Zero(input.builder(), input_shape.element_type()))); } -std::vector BuildMaskedSelect(const xla::XlaOp& input, - const xla::XlaOp& mask) { +std::vector BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask) { xla::Shape input_shape; xla::XlaOp r1_input = XlaHelpers::Flatten(input, &input_shape); const xla::Shape& mask_shape = XlaHelpers::ShapeOfXlaOp(mask); diff --git a/torch_xla/csrc/xla_lower_util.h b/torch_xla/csrc/xla_lower_util.h index 6c342f96c90e..6c459094ff9c 100644 --- 
a/torch_xla/csrc/xla_lower_util.h
+++ b/torch_xla/csrc/xla_lower_util.h
@@ -7,21 +7,20 @@
 namespace torch_xla {
 
-xla::XlaOp PadToSize(const xla::XlaOp& input, const xla::XlaOp& pad_value,
+xla::XlaOp PadToSize(xla::XlaOp input, xla::XlaOp pad_value,
                      tensorflow::gtl::ArraySlice<xla::int64> size);
 
-std::vector<xla::XlaOp> CreateKthValue(const xla::XlaOp& input, xla::int64 k,
+std::vector<xla::XlaOp> CreateKthValue(xla::XlaOp input, xla::int64 k,
                                        xla::int64 dim, bool keepdim);
 
-std::vector<xla::XlaOp> CreateTopK(const xla::XlaOp& input, xla::int64 k,
+std::vector<xla::XlaOp> CreateTopK(xla::XlaOp input, xla::int64 k,
                                    xla::int64 dim, bool largest, bool sorted);
 
-xla::XlaOp CreateMatMul(const xla::XlaOp& lhs, const xla::XlaOp& rhs);
+xla::XlaOp CreateMatMul(xla::XlaOp lhs, xla::XlaOp rhs);
 
-xla::XlaOp BuildBernoulli(const xla::XlaOp& probability,
-                          const xla::Shape& shape);
+xla::XlaOp BuildBernoulli(xla::XlaOp probability, const xla::Shape& shape);
 
-xla::XlaOp BuildDropout(const xla::XlaOp& input, float probability);
+xla::XlaOp BuildDropout(xla::XlaOp input, float probability);
 
 xla::XlaOp BuildRandperm(xla::int64 n, xla::PrimitiveType element_type,
                          xla::XlaBuilder* builder);
@@ -30,41 +29,37 @@
 std::vector<xla::XlaOp> CreateBroadcastTensors(
     tensorflow::gtl::ArraySlice<xla::XlaOp> operands);
 
 // Similar to tf.gather_nd, used to implement advanced indexing.
-xla::XlaOp CreateIndex(const xla::XlaOp& input, const xla::XlaOp& indices,
+xla::XlaOp CreateIndex(xla::XlaOp input, xla::XlaOp indices,
                        xla::int64 start_dim);
 
 // Similar to tf.scatter_nd, used to implement advanced indexing updates.
 xla::XlaOp CreateIndexUpdate(
-    const xla::XlaOp& buffer, const xla::XlaOp& indices, xla::int64 start_dim,
-    const xla::XlaOp& updates,
-    const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&)>&
-        combiner);
+    xla::XlaOp buffer, xla::XlaOp indices, xla::int64 start_dim,
+    xla::XlaOp updates,
+    const std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp)>& combiner);
 
-xla::XlaOp CreateIndexAdd(const xla::XlaOp& buffer, xla::int64 dim,
-                          const xla::XlaOp& index, const xla::XlaOp& value);
+xla::XlaOp CreateIndexAdd(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index,
+                          xla::XlaOp value);
 
-xla::XlaOp CreateIndexCopy(const xla::XlaOp& buffer, xla::int64 dim,
-                           const xla::XlaOp& index, const xla::XlaOp& value);
+xla::XlaOp CreateIndexCopy(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index,
+                           xla::XlaOp value);
 
-xla::XlaOp CreateIndexFill(const xla::XlaOp& buffer, xla::int64 dim,
-                           const xla::XlaOp& index, const xla::XlaOp& values);
+xla::XlaOp CreateIndexFill(xla::XlaOp buffer, xla::int64 dim, xla::XlaOp index,
+                           xla::XlaOp values);
 
-using XlaOpCombiner =
-    std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&)>;
+using XlaOpCombiner = std::function<xla::XlaOp(xla::XlaOp, xla::XlaOp)>;
 
 XlaOpCombiner NumericAddCombiner();
 
 // Used to lower scatter and scatter_add.
-xla::XlaOp CreateScatter(const xla::XlaOp& input, const xla::XlaOp& index,
-                         const xla::XlaOp& source, xla::int64 dim,
-                         const XlaOpCombiner& combiner);
+xla::XlaOp CreateScatter(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source,
+                         xla::int64 dim, const XlaOpCombiner& combiner);
 
-xla::XlaOp CreatePut(const xla::XlaOp& input, const xla::XlaOp& index,
-                     const xla::XlaOp& source, bool accumulate);
+xla::XlaOp CreatePut(xla::XlaOp input, xla::XlaOp index, xla::XlaOp source,
+                     bool accumulate);
 
-std::vector<xla::XlaOp> BuildNonZero(const xla::XlaOp& input);
+std::vector<xla::XlaOp> BuildNonZero(xla::XlaOp input);
 
-std::vector<xla::XlaOp> BuildMaskedSelect(const xla::XlaOp& input,
-                                          const xla::XlaOp& mask);
+std::vector<xla::XlaOp> BuildMaskedSelect(xla::XlaOp input, xla::XlaOp mask);
 
 }  // namespace torch_xla
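Throughout the diff the rewrite is mechanical: xla::XlaOp is a small value type, essentially an opaque op handle plus a pointer to the xla::XlaBuilder that recorded it, so passing it by value costs the same as passing a const reference while reading more simply. The std::function combiner types and the lambdas passed to them are updated in lockstep so their signatures keep matching. A minimal sketch of the convention, using a hypothetical helper that is not part of this change:

#include "tensorflow/compiler/xla/client/xla_builder.h"

namespace torch_xla {

// Illustrative only. Copying the xla::XlaOp handle does not copy any
// computation; both copies refer to the same op recorded in the builder.
xla::XlaOp BuildSquare(xla::XlaOp input) {
  return input * input;
}

}  // namespace torch_xla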