Skip to content

Commit

Permalink
Enforce a layout for AllReduce and use Current instead of Default device
Browse files Browse the repository at this point in the history
  • Loading branch information
dlibenzi authored and ailzhang committed Dec 10, 2019
1 parent 9e56738 commit cf3adeb
Show file tree
Hide file tree
Showing 9 changed files with 93 additions and 51 deletions.
12 changes: 11 additions & 1 deletion torch_xla/csrc/aten_xla_bridge.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ Device AtenDeviceToXlaDevice(const c10::Device& device) {
}
}
if (ordinal < 0) {
return *GetDefaultDevice();
return GetCurrentDevice();
}
return AtenXlaDeviceMapper::Get()->GetDeviceFromOrdinal(ordinal);
}
Expand All @@ -193,6 +193,16 @@ c10::Device AtenDefaultDevice() {
return XlaDeviceToAtenDevice(*GetDefaultDevice());
}

// Sets the current device on the ATen side (thread-local state held by
// XLATensorImpl) and mirrors the change on the XLA side, keeping the two
// views of "current device" in sync. Returns the previously current ATen
// device so callers can restore it.
c10::Device SetCurrentDevice(const c10::Device& device) {
  c10::Device prev_device = XLATensorImpl::SetCurrentAtenDevice(device);
  // Different overload: takes the XLA Device converted from the ATen one.
  SetCurrentDevice(AtenDeviceToXlaDevice(device));
  return prev_device;
}

// Returns the current ATen device as tracked by XLATensorImpl
// (thread-local per SetCurrentDevice() above — TODO confirm scope).
c10::Device GetCurrentAtenDevice() {
  return XLATensorImpl::GetCurrentAtenDevice();
}

at::Tensor XlaToAtenTensor(XLATensor xla_tensor,
const at::TensorOptions& tensor_options) {
if (tensor_options.has_device()) {
Expand Down
4 changes: 4 additions & 0 deletions torch_xla/csrc/aten_xla_bridge.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,10 @@ std::string ToXlaString(const c10::Device& device);

c10::Device AtenDefaultDevice();

c10::Device SetCurrentDevice(const c10::Device& device);

c10::Device GetCurrentAtenDevice();

at::Tensor XlaToAtenTensor(XLATensor xla_tensor,
const at::TensorOptions& tensor_options);

Expand Down
2 changes: 1 addition & 1 deletion torch_xla/csrc/aten_xla_type.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ struct XlaOptions {
}
}

Device get_device() const { return device ? *device : *GetDefaultDevice(); }
Device get_device() const { return device ? *device : GetCurrentDevice(); }

at::ScalarType get_scalar_type(
at::ScalarType defval = at::ScalarType::Float) const {
Expand Down
14 changes: 6 additions & 8 deletions torch_xla/csrc/convert_ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,7 @@ xla::XlaOp ExplicitBooleanConvert(xla::XlaOp op, xla::PrimitiveType from) {

xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
xla::PrimitiveType to, const Device* device) {
if (device == nullptr) {
device = GetDefaultDevice();
}
if (device->hw_type != DeviceType::TPU) {
if (GetDeviceOrCurrent(device).hw_type != DeviceType::TPU) {
return xla::ConvertElementType(op, to);
}
switch (from) {
Expand Down Expand Up @@ -51,12 +48,13 @@ xla::XlaOp ConvertTo(xla::XlaOp op, xla::PrimitiveType from,
}

xla::XlaOp ConvertToNumeric(xla::XlaOp op, xla::PrimitiveType from) {
const Device* device = GetDefaultDevice();
Device xla_device = GetCurrentDevice();
return from != xla::PrimitiveType::PRED
? op
: ConvertTo(op, from,
GetDevicePrimitiveType(xla::PrimitiveType::U8, device),
device);
: ConvertTo(
op, from,
GetDevicePrimitiveType(xla::PrimitiveType::U8, &xla_device),
&xla_device);
}

xla::XlaOp ConvertToNumeric(xla::XlaOp op) {
Expand Down
37 changes: 29 additions & 8 deletions torch_xla/csrc/cross_replica_reduces.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,32 +2,47 @@

#include <map>

#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/xla_client/debug_macros.h"
#include "tensorflow/compiler/xla/xla_client/util.h"
#include "torch_xla/csrc/device.h"
#include "torch_xla/csrc/helpers.h"
#include "torch_xla/csrc/layout_manager.h"

namespace torch_xla {
namespace {

// Groups, for a single XLA element type, the operands taking part in a
// cross-replica reduce, together with their positions in the original
// operand list and their shapes (all three vectors are kept in step).
struct PerTypeContext {
  std::vector<xla::XlaOp> ops;             // Operands of this element type.
  std::vector<size_t> indices;             // Original index of each op.
  std::vector<xla::Shape> operand_shapes;  // Shape of each op, same order.
};

struct ReduceContext {
std::map<xla::PrimitiveType, PerTypeContext> contexts;
std::vector<xla::Shape> operand_shapes;
};

// Builds the tuple shape used to enforce the output layout of the
// AllReduce: each operand shape is rebuilt with the layout appropriate for
// the current device's hardware type, then all are wrapped in a tuple.
xla::Shape MakeReduceShape(
    tensorflow::gtl::ArraySlice<const xla::Shape> operand_shapes) {
  Device xla_device = GetCurrentDevice();
  std::vector<xla::Shape> shapes_and_layouts;
  shapes_and_layouts.reserve(operand_shapes.size());
  for (auto& shape : operand_shapes) {
    // Keep dimensions and element type; only the layout is device-driven.
    shapes_and_layouts.push_back(MakeArrayShapeFromDimensions(
        shape.dimensions(), shape.element_type(), xla_device.hw_type));
  }
  return xla::ShapeUtil::MakeTupleShape(shapes_and_layouts);
}

ReduceContext GetReduceContext(
tensorflow::gtl::ArraySlice<const xla::XlaOp> operands) {
ReduceContext redux;
for (size_t i = 0; i < operands.size(); ++i) {
redux.operand_shapes.push_back(XlaHelpers::ShapeOfXlaOp(operands[i]));
PerTypeContext& ctx =
redux.contexts[redux.operand_shapes.back().element_type()];
xla::Shape operand_shape = XlaHelpers::ShapeOfXlaOp(operands[i]);
PerTypeContext& ctx = redux.contexts[operand_shape.element_type()];
ctx.ops.push_back(operands[i]);
ctx.indices.push_back(i);
ctx.operand_shapes.push_back(std::move(operand_shape));
}
return redux;
}
Expand Down Expand Up @@ -73,18 +88,24 @@ std::vector<xla::XlaOp> BuildAllReduce(
ReduceContext redux = GetReduceContext(operands);
std::vector<xla::XlaOp> result(operands.size());
for (auto& type_ctx : redux.contexts) {
type_ctx.second.ops.push_back(
xla::ConvertElementType(chained_token, type_ctx.first));
xla::XlaOp token_op =
xla::ConvertElementType(chained_token, type_ctx.first);
type_ctx.second.ops.push_back(token_op);
type_ctx.second.operand_shapes.push_back(
XlaHelpers::ShapeOfXlaOp(token_op));

xla::XlaOp reduce = xla::AllReduce(
xla::Tuple(operands[0].builder(), type_ctx.second.ops),
GetReduceComutation(reduce_type, type_ctx.first), reduce_groups);
GetReduceComutation(reduce_type, type_ctx.first), reduce_groups,
/*channel_id=*/absl::nullopt,
MakeReduceShape(type_ctx.second.operand_shapes));
for (size_t i = 0; i < type_ctx.second.indices.size(); ++i) {
size_t op_idx = type_ctx.second.indices[i];
xla::XlaOp gte = xla::GetTupleElement(reduce, i);
if (scale != 1.0) {
xla::XlaOp scaling_value = XlaHelpers::ScalarValue<float>(
scale, redux.operand_shapes[op_idx].element_type(), gte.builder());
scale, type_ctx.second.operand_shapes[i].element_type(),
gte.builder());
gte = gte * scaling_value;
}
result[op_idx] = gte;
Expand Down
10 changes: 10 additions & 0 deletions torch_xla/csrc/device.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
namespace torch_xla {
namespace {

thread_local Device g_current_device;

std::string DeviceTypeToString(DeviceType hw_type) {
switch (hw_type) {
case DeviceType::CPU:
Expand Down Expand Up @@ -66,4 +68,12 @@ const Device* GetDefaultDevice() {
return default_device;
}

Device GetCurrentDevice() { return g_current_device; }

// Installs 'device' as this thread's current device and returns the one
// previously installed, so callers can save/restore around a scoped change.
Device SetCurrentDevice(const Device& device) {
  Device current = g_current_device;
  g_current_device = device;
  return current;
}

} // namespace torch_xla
8 changes: 6 additions & 2 deletions torch_xla/csrc/device.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,12 @@ struct Device {

const Device* GetDefaultDevice();

static inline const Device& GetDeviceOrDefault(const Device* device) {
return device != nullptr ? *device : *GetDefaultDevice();
Device GetCurrentDevice();

Device SetCurrentDevice(const Device& device);

// Convenience for APIs taking an optional device: dereferences 'device'
// when non-null, otherwise falls back to the thread's current device.
static inline Device GetDeviceOrCurrent(const Device* device) {
  return device != nullptr ? *device : GetCurrentDevice();
}

} // namespace torch_xla
16 changes: 8 additions & 8 deletions torch_xla/csrc/init_python_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,16 @@ std::string GetTensorsDump(
return coverter(nodes);
}

std::string SetCurrentDevice(const std::string& device_str) {
c10::Device prev_device =
XLATensorImpl::SetCurrentAtenDevice(c10::Device(device_str));
std::string SetCurrentThreadDevice(const std::string& device_str) {
c10::Device prev_device = bridge::SetCurrentDevice(c10::Device(device_str));
std::stringstream ss;
ss << prev_device;
return ss.str();
}

std::string GetCurrentDevice() {
std::string GetCurrentThreadDevice() {
std::stringstream ss;
ss << XLATensorImpl::GetCurrentAtenDevice();
ss << bridge::GetCurrentAtenDevice();
return ss.str();
}

Expand Down Expand Up @@ -417,9 +416,10 @@ void InitXlaModuleBindings(py::module m) {
}
return new_token;
});
m.def("_xla_set_default_device",
[](const std::string& device) { return SetCurrentDevice(device); });
m.def("_xla_get_default_device", []() { return GetCurrentDevice(); });
m.def("_xla_set_default_device", [](const std::string& device) {
return SetCurrentThreadDevice(device);
});
m.def("_xla_get_default_device", []() { return GetCurrentThreadDevice(); });
m.def("_xla_sync_multi",
[](const std::vector<at::Tensor>& tensors,
const std::vector<std::string>& devices, bool wait,
Expand Down
41 changes: 18 additions & 23 deletions torch_xla/csrc/tensor_util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -546,19 +546,18 @@ std::vector<xla::ComputationClient::DataPtr> CreateTensorsData(

xla::Literal GetTensorLiteral(const at::Tensor& tensor, const xla::Shape* shape,
const Device* device) {
if (device == nullptr) {
device = GetDefaultDevice();
}
Device xla_device = GetDeviceOrCurrent(device);
xla::Shape computed_shape;
if (shape == nullptr) {
auto dimensions = XlaHelpers::I64List(tensor.sizes());
computed_shape = MakeTorchTensorLayout(
dimensions, XlaTypeFromTensorType(tensor.type().scalarType(), *device));
dimensions,
XlaTypeFromTensorType(tensor.type().scalarType(), xla_device));
shape = &computed_shape;
}
xla::Literal literal(*shape);
PopulateTensorBuffer(tensor, *shape, literal.untyped_data(),
literal.size_bytes(), *device);
literal.size_bytes(), xla_device);
return literal;
}

Expand Down Expand Up @@ -617,13 +616,11 @@ xla::Shape MakeShapeWithDeviceLayout(const xla::Shape& shape,

xla::Shape CreateComputationShapeFromTensor(const at::Tensor& tensor,
const Device* device) {
if (device == nullptr) {
device = GetDefaultDevice();
}
Device xla_device = GetDeviceOrCurrent(device);
return MakeArrayShapeFromDimensions(
XlaHelpers::I64List(tensor.sizes()),
MakeXlaPrimitiveType(tensor.type().scalarType(), device),
device->hw_type);
MakeXlaPrimitiveType(tensor.type().scalarType(), &xla_device),
xla_device.hw_type);
}

at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {
Expand Down Expand Up @@ -655,32 +652,30 @@ at::ScalarType TensorTypeFromXlaType(xla::PrimitiveType xla_type) {

xla::PrimitiveType GetDevicePrimitiveType(xla::PrimitiveType type,
const Device* device) {
if (device == nullptr) {
device = GetDefaultDevice();
}
Device xla_device = GetDeviceOrCurrent(device);
switch (type) {
case xla::PrimitiveType::F64:
if (UseBF16()) {
return xla::PrimitiveType::BF16;
}
return device->hw_type != DeviceType::TPU ? xla::PrimitiveType::F64
: xla::PrimitiveType::F32;
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::F64
: xla::PrimitiveType::F32;
case xla::PrimitiveType::F32:
// When PyTorch will support native BF16 type, the global configuration
// can be replaced (or augmented) with the proper mapping.
return UseBF16() ? xla::PrimitiveType::BF16 : xla::PrimitiveType::F32;
case xla::PrimitiveType::U8:
return device->hw_type != DeviceType::TPU ? xla::PrimitiveType::U8
: xla::PrimitiveType::S32;
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::U8
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S8:
return device->hw_type != DeviceType::TPU ? xla::PrimitiveType::S8
: xla::PrimitiveType::S32;
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::S8
: xla::PrimitiveType::S32;
case xla::PrimitiveType::U16:
return device->hw_type != DeviceType::TPU ? xla::PrimitiveType::U16
: xla::PrimitiveType::S32;
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::U16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S16:
return device->hw_type != DeviceType::TPU ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
return xla_device.hw_type != DeviceType::TPU ? xla::PrimitiveType::S16
: xla::PrimitiveType::S32;
case xla::PrimitiveType::S64:
return Use32BitLong() ? xla::PrimitiveType::S32 : xla::PrimitiveType::S64;
default:
Expand Down

0 comments on commit cf3adeb

Please sign in to comment.