
Commit 88b80f8

Make sure the real sizes are accounted for when averaging reduces, to deal with dynamic shapes.

dlibenzi committed Dec 6, 2019
1 parent d70d9d5 commit 88b80f8
Showing 12 changed files with 190 additions and 134 deletions.
19 changes: 19 additions & 0 deletions torch_xla/csrc/helpers.cpp
@@ -3,6 +3,7 @@
#include <limits>

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/client/lib/constants.h"
#include "tensorflow/compiler/xla/primitive_util.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
@@ -127,6 +128,24 @@ xla::int64 XlaHelpers::GetDynamicDimension(const xla::Shape& shape) {
  return dynamic_dimension;
}

+xla::XlaOp XlaHelpers::GetDimensionsSize(
+    tensorflow::gtl::ArraySlice<const xla::XlaOp> inputs,
+    tensorflow::gtl::ArraySlice<const xla::int64> dimensions) {
+  XLA_CHECK(!inputs.empty());
+  xla::XlaOp size;
+  for (auto& input : inputs) {
+    for (auto dim : dimensions) {
+      if (size.valid()) {
+        size = size * xla::GetDimensionSize(input, dim);
+      } else {
+        size = xla::GetDimensionSize(input, dim);
+      }
+    }
+  }
+  return size.valid() ? size
+                      : xla::One(inputs[0].builder(), xla::PrimitiveType::S32);
+}

XlaHelpers::MinMax XlaHelpers::MinMaxValues(xla::PrimitiveType type) {
  switch (type) {
    case xla::PrimitiveType::S8:
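Note: a minimal Python sketch of what the new GetDimensionsSize helper computes, assuming only that xla::GetDimensionSize yields the runtime size of a single dimension. The get_dimensions_size function and the Box class below are illustrative stand-ins, not part of the codebase:

```python
def get_dimensions_size(inputs, dimensions):
    """Product of the selected dimension sizes over all inputs; 1 if none."""
    size = None
    for inp in inputs:
        for dim in dimensions:
            d = inp.shape[dim]  # stand-in for the dynamic xla::GetDimensionSize
            size = d if size is None else size * d
    return 1 if size is None else size

class Box:  # minimal tensor-like stand-in that only carries a shape
    def __init__(self, shape):
        self.shape = shape

print(get_dimensions_size([Box((4, 7))], (0, 1)))  # 28
print(get_dimensions_size([Box((4, 7))], ()))      # 1 (no dimensions selected)
```

The reason the real helper returns an xla::XlaOp rather than an integer is that each factor may be a dynamic dimension whose size is only known when the computation executes.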
12 changes: 12 additions & 0 deletions torch_xla/csrc/helpers.h
@@ -89,6 +89,14 @@ class XlaHelpers {
  // Returns the value type of given XLA operation.
  static xla::PrimitiveType TypeOfXlaOp(const xla::XlaOp& op);

+  static std::vector<xla::int64> GetAllDimensions(size_t rank) {
+    return xla::util::Iota<xla::int64>(rank);
+  }
+
+  static std::vector<xla::int64> GetAllDimensions(const xla::Shape& shape) {
+    return xla::util::Iota<xla::int64>(shape.rank());
+  }
+
  static xla::XlaOp CreateReturnValue(xla::XlaBuilder* builder,
                                      const std::vector<xla::XlaOp>& outputs);

@@ -150,6 +158,10 @@ class XlaHelpers {
  // Retrieves the dynamic dimension of an input shape, or returns -1 if none.
  static xla::int64 GetDynamicDimension(const xla::Shape& shape);

+  static xla::XlaOp GetDimensionsSize(
+      tensorflow::gtl::ArraySlice<const xla::XlaOp> inputs,
+      tensorflow::gtl::ArraySlice<const xla::int64> dimensions);
+
  // Retrieves type's minimum and maximum values.
  static MinMax MinMaxValues(xla::PrimitiveType type);

15 changes: 12 additions & 3 deletions torch_xla/csrc/init_python_bindings.cpp
@@ -217,6 +217,12 @@ std::vector<at::Tensor> GetXlaTensorsFromAten(
  return xla_tensors;
}

+at::Tensor GetXlaTensorDimensionSize(const at::Tensor& tensor, xla::int64 dim) {
+  XLATensor xtensor = bridge::GetXlaTensor(tensor);
+  return bridge::AtenFromXlaTensor(
+      XLATensor::get_dimensions_size(xtensor, {dim}));
+}

py::object GetMetricData(const std::string& name) {
  xla::metrics::MetricData* data = xla::metrics::GetMetric(name);
  if (data == nullptr) {
@@ -327,9 +333,12 @@ void InitXlaModuleBindings(py::module m) {
m.def("_initialize_aten_bindings",
[]() { AtenXlaType::InitializeAtenBindings(); });
m.def("_get_git_revs", []() { return GetRevisions(); });
m.def("_get_xla_tensor", [](const at::Tensor& tensor) -> XLATensor {
return bridge::GetXlaTensor(tensor);
});
m.def("_get_xla_tensor",
[](const at::Tensor& tensor) { return bridge::GetXlaTensor(tensor); });
m.def("_get_xla_tensor_dimension_size",
[](const at::Tensor& tensor, int dim) {
return GetXlaTensorDimensionSize(tensor, dim);
});
m.def("_get_xla_tensors_dot",
[](const std::vector<at::Tensor>& tensors) -> std::string {
auto coverter =
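Note: a usage sketch for the binding added above, assuming a typical torch_xla setup of this era; the raw torch_xla._XLAC call is shown because the commit adds only the binding itself, not a public wrapper:

```python
import torch
import torch_xla
import torch_xla.core.xla_model as xm

t = torch.rand(4, 7, device=xm.xla_device())

# Unlike t.size(1), which is a Python int baked into the traced graph, this
# returns the size of dimension 1 as an XLA tensor, so under dynamic shapes
# it tracks the size the data actually has at run time.
s = torch_xla._XLAC._get_xla_tensor_dimension_size(t, 1)
print(s)  # a scalar integer tensor holding 7
```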
36 changes: 0 additions & 36 deletions torch_xla/csrc/ops/get_dimension_size.cpp

This file was deleted.

40 changes: 40 additions & 0 deletions torch_xla/csrc/ops/get_dimensions_size.cpp
@@ -0,0 +1,40 @@
#include "torch_xla/csrc/ops/get_dimensions_size.h"

#include "absl/strings/str_join.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/lowering_context.h"
#include "torch_xla/csrc/ops/xla_ops.h"

namespace torch_xla {
namespace ir {
namespace ops {

GetDimensionsSize::GetDimensionsSize(const Value& input,
                                     std::vector<xla::int64> dimensions)
    : Node(xla_get_dimensions_size, {input},
           xla::ShapeUtil::MakeShape(xla::PrimitiveType::S32, {}),
           /*num_outputs=*/1, xla::util::MHash(dimensions)),
      dimensions_(std::move(dimensions)) {}

NodePtr GetDimensionsSize::Clone(OpList operands) const {
  return MakeNode<GetDimensionsSize>(operands.at(0), dimensions_);
}

XlaOpVector GetDimensionsSize::Lower(LoweringContext* loctx) const {
  xla::XlaOp input = loctx->GetOutputOp(operand(0));
  xla::XlaOp output = XlaHelpers::GetDimensionsSize({input}, dimensions_);
  return ReturnOp(output, loctx);
}

std::string GetDimensionsSize::ToString() const {
  std::stringstream ss;
  ss << Node::ToString() << ", dimensions=(" << absl::StrJoin(dimensions_, ", ")
     << ")";
  return ss.str();
}

} // namespace ops
} // namespace ir
} // namespace torch_xla
torch_xla/csrc/ops/get_dimension_size.h → torch_xla/csrc/ops/get_dimensions_size.h
@@ -1,25 +1,27 @@
#pragma once

+#include <vector>
+
#include "torch_xla/csrc/ir.h"

namespace torch_xla {
namespace ir {
namespace ops {

-class GetDimensionSize : public Node {
+class GetDimensionsSize : public Node {
 public:
-  GetDimensionSize(const Value& input, xla::int64 dimension);
+  GetDimensionsSize(const Value& input, std::vector<xla::int64> dimensions);

  NodePtr Clone(OpList operands) const override;

  XlaOpVector Lower(LoweringContext* loctx) const override;

  std::string ToString() const override;

-  xla::int64 dimension() const { return dimension_; }
+  const std::vector<xla::int64>& dimensions() const { return dimensions_; }

 private:
-  xla::int64 dimension_;
+  std::vector<xla::int64> dimensions_;
};
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/xla_ops.cpp
@@ -10,7 +10,7 @@ const OpKindWrapper xla_cross_replica_sum("xla::cross_replica_sum");
const OpKindWrapper xla_device_data("xla::device_data");
const OpKindWrapper xla_diagonal_view_update("xla::diagonal_view_update");
const OpKindWrapper xla_generic_slice("xla::generic_slice");
-const OpKindWrapper xla_get_dimension_size("xla::xla_get_dimension_size");
+const OpKindWrapper xla_get_dimensions_size("xla::xla_get_dimensions_size");
const OpKindWrapper xla_moving_average("xla::moving_average");
const OpKindWrapper xla_not_supported("xla::not_supported");
const OpKindWrapper xla_select("xla::select");
2 changes: 1 addition & 1 deletion torch_xla/csrc/ops/xla_ops.h
@@ -34,7 +34,7 @@ extern const OpKindWrapper xla_cross_replica_sum;
extern const OpKindWrapper xla_device_data;
extern const OpKindWrapper xla_diagonal_view_update;
extern const OpKindWrapper xla_generic_slice;
-extern const OpKindWrapper xla_get_dimension_size;
+extern const OpKindWrapper xla_get_dimensions_size;
extern const OpKindWrapper xla_moving_average;
extern const OpKindWrapper xla_not_supported;
extern const OpKindWrapper xla_select;
56 changes: 21 additions & 35 deletions torch_xla/csrc/reduction.cpp
@@ -112,6 +112,19 @@ xla::XlaOp CreateProduct(
  return result;
}

+xla::XlaOp AverageValue(const xla::XlaOp& input, const xla::XlaOp& reduced) {
+  const xla::Shape& input_shape = XlaHelpers::ShapeOfXlaOp(input);
+  xla::XlaOp num_elements = XlaHelpers::GetDimensionsSize(
+      {input}, XlaHelpers::GetAllDimensions(input_shape));
+  xla::XlaOp zero =
+      xla::Zero(input.builder(), XlaHelpers::TypeOfXlaOp(num_elements));
+  return xla::Select(
+      xla::Ne(num_elements, zero),
+      reduced /
+          xla::ConvertElementType(num_elements, input_shape.element_type()),
+      xla::NanValue(input.builder(), input_shape.element_type()));
+}

} // namespace

xla::XlaOp BuildBinaryCrossEntropy(const xla::XlaOp& input,
@@ -133,21 +146,13 @@ xla::XlaOp BuildBinaryCrossEntropy(const xla::XlaOp& input,
  if (reduction == ReductionMode::kNone) {
    return result;
  }
-  result = xla::ReduceAll(
+  xla::XlaOp reduced_result = xla::ReduceAll(
      result, xla::Zero(input.builder(), input_shape.element_type()),
      XlaHelpers::CreateAddComputation(input_shape.element_type()));
  if (reduction == ReductionMode::kMean) {
-    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
-    if (num_elements == 0) {
-      return xla::NanValue(input.builder(), input_shape.element_type());
-    } else {
-      xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
-          1.0 / static_cast<double>(num_elements), input_shape.element_type(),
-          input.builder());
-      result = result * scale_value;
-    }
+    reduced_result = AverageValue(result, reduced_result);
  }
-  return result;
+  return reduced_result;
}

xla::XlaOp BuildBinaryCrossEntropyBackward(
@@ -168,18 +173,11 @@ xla::XlaOp BuildBinaryCrossEntropyBackward(
  if (reduction == ReductionMode::kNone) {
    return result * grad_output;
  }
+  result = result * grad_output;
  if (reduction == ReductionMode::kMean) {
-    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
-    if (num_elements == 0) {
-      return xla::NanValue(input.builder(), input_shape.element_type());
-    } else {
-      xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
-          1.0 / static_cast<double>(num_elements), input_shape.element_type(),
-          input.builder());
-      result = result * scale_value;
-    }
+    result = AverageValue(input, result);
  }
-  return result * grad_output;
+  return result;
}

xla::XlaOp BuildL1Loss(const xla::XlaOp& input, const xla::XlaOp& target,
@@ -193,15 +191,7 @@ xla::XlaOp BuildL1Loss(const xla::XlaOp& input, const xla::XlaOp& target,
      result, xla::Zero(input.builder(), input_shape.element_type()),
      XlaHelpers::CreateAddComputation(input_shape.element_type()));
  if (reduction == ReductionMode::kMean) {
-    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
-    if (num_elements == 0) {
-      return xla::NanValue(input.builder(), input_shape.element_type());
-    } else {
-      xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
-          1.0 / static_cast<double>(num_elements), input_shape.element_type(),
-          input.builder());
-      result = result * scale_value;
-    }
+    result = AverageValue(input, result);
  }
  return result;
}
@@ -218,11 +208,7 @@ xla::XlaOp BuildL1LossBackward(const xla::XlaOp& grad_output,
  }
  xla::XlaOp grad_value = grad_output;
  if (reduction == ReductionMode::kMean) {
-    xla::int64 num_elements = xla::ShapeUtil::ElementsIn(input_shape);
-    xla::XlaOp scale_value = XlaHelpers::ScalarValue<double>(
-        1.0 / static_cast<double>(num_elements), input_shape.element_type(),
-        input.builder());
-    grad_value = grad_output * scale_value;
+    grad_value = AverageValue(input, grad_value);
  }
  return xla::Select(xla::Ge(input, target), grad_value, -grad_value);
}
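Note: the effect of routing the kMean branches through AverageValue is that the divisor is now built from runtime dimension sizes instead of the compile-time xla::ShapeUtil::ElementsIn, and the empty-input case is decided at execution time by the Select rather than at trace time. Eager PyTorch on plain CPU tensors shows the semantics being matched:

```python
import torch

# The mean divides by the element count the data actually has ...
print(torch.tensor([1.0, 2.0, 3.0]).mean())  # tensor(2.)

# ... and the mean of an empty tensor is NaN, which is what the
# Select(Ne(num_elements, zero), reduced / num_elements, NanValue) reproduces.
print(torch.tensor([]).mean())  # tensor(nan)
```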
32 changes: 19 additions & 13 deletions torch_xla/csrc/tensor.h
@@ -160,6 +160,25 @@ class XLATensor {
      const std::vector<at::Tensor>& tensors,
      const std::vector<std::string>& devices);

+  //////////////////////////////////////////////////////////////////////////////
+  // XLA dedicated operators follow here, listed in alphabetical order.
+  //////////////////////////////////////////////////////////////////////////////
+  static std::pair<XLATensor, ir::Value> all_reduce(
+      const XLATensor& input, const ir::Value& token, AllReduceType reduce_type,
+      double scale, const std::vector<std::vector<xla::int64>>& groups);
+
+  static ir::Value all_reduce_(
+      XLATensor& input, const ir::Value& token, AllReduceType reduce_type,
+      double scale, const std::vector<std::vector<xla::int64>>& groups);
+
+  static ir::Value all_reduce(
+      std::vector<XLATensor>* inputs, const ir::Value& token,
+      AllReduceType reduce_type, double scale,
+      const std::vector<std::vector<xla::int64>>& groups);
+
+  static XLATensor get_dimensions_size(const XLATensor& input,
+                                       std::vector<xla::int64> dimensions);

  //////////////////////////////////////////////////////////////////////////////
  // ATEN operators follow here, listed in alphabetical order.
  //////////////////////////////////////////////////////////////////////////////
@@ -223,19 +242,6 @@ class XLATensor {
                       std::vector<xla::int64> dimensions,
                       bool keep_reduced_dimensions);

-  static std::pair<XLATensor, ir::Value> all_reduce(
-      const XLATensor& input, const ir::Value& token, AllReduceType reduce_type,
-      double scale, const std::vector<std::vector<xla::int64>>& groups);
-
-  static ir::Value all_reduce_(
-      XLATensor& input, const ir::Value& token, AllReduceType reduce_type,
-      double scale, const std::vector<std::vector<xla::int64>>& groups);
-
-  static ir::Value all_reduce(
-      std::vector<XLATensor>* inputs, const ir::Value& token,
-      AllReduceType reduce_type, double scale,
-      const std::vector<std::vector<xla::int64>>& groups);
-
  static XLATensor any(const XLATensor& input,
                       std::vector<xla::int64> dimensions,
                       bool keep_reduced_dimensions);
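Note: taken together, the commit wires a single call chain: _get_xla_tensor_dimension_size (Python binding) → XLATensor::get_dimensions_size → ir::ops::GetDimensionsSize → XlaHelpers::GetDimensionsSize → xla::GetDimensionSize. For intuition on why a static element count is the wrong divisor under dynamic shapes, here is an eager-mode PyTorch illustration (nothing XLA-specific); the boolean indexing produces a size that no compile-time count could know:

```python
import torch

x = torch.arange(6.0)
y = x[x > 2.0]   # data-dependent size: 3 elements here, unknowable at trace time
print(y.mean())  # tensor(4.) -- divides by the size the data actually has
```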
(Diffs for the remaining changed files were not loaded.)