Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Add a bf16 datatype to sfnp #914

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 104 additions & 0 deletions shortfin/python/array_host_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,96 @@
#include "xtensor/xsort.hpp"
#include "xtl/xhalf_float.hpp"


#ifndef BFLOAT16_HPP
#define BFLOAT16_HPP

#include <cstdint>
#include <bit>
#include <limits>
#include <type_traits>

struct bfloat16_t {
uint16_t value;

constexpr bfloat16_t() noexcept : value(0) {}

explicit constexpr bfloat16_t(float f) noexcept {
uint32_t temp = std::bit_cast<uint32_t>(f);
value = static_cast<uint16_t>(temp >> 16);
}

template <typename T, typename = std::enable_if_t<std::is_arithmetic_v<T> &&
!std::is_same_v<T, float>>>
constexpr bfloat16_t(T value) noexcept : bfloat16_t(static_cast<float>(value)) {}

constexpr operator float() const noexcept {
uint32_t temp = static_cast<uint32_t>(value) << 16;
return std::bit_cast<float>(temp);
}

// Arithmetic operators (implemented via conversion to float)
constexpr bfloat16_t operator+(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) + float(other));
}
constexpr bfloat16_t operator-(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) - float(other));
}
constexpr bfloat16_t operator*(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) * float(other));
}
constexpr bfloat16_t operator/(const bfloat16_t& other) const noexcept {
return bfloat16_t(float(*this) / float(other));
}

constexpr bfloat16_t& operator+=(const bfloat16_t& other) noexcept {
*this = *this + other;
return *this;
}
constexpr bfloat16_t& operator-=(const bfloat16_t& other) noexcept {
*this = *this - other;
return *this;
}
constexpr bfloat16_t& operator*=(const bfloat16_t& other) noexcept {
*this = *this * other;
return *this;
}
constexpr bfloat16_t& operator/=(const bfloat16_t& other) noexcept {
*this = *this / other;
return *this;
}

// Comparison operators (using conversion to float)
constexpr bool operator==(const bfloat16_t& other) const noexcept {
return float(*this) == float(other);
}
constexpr bool operator!=(const bfloat16_t& other) const noexcept {
return !(*this == other);
}
constexpr bool operator<(const bfloat16_t& other) const noexcept {
return float(*this) < float(other);
}
constexpr bool operator<=(const bfloat16_t& other) const noexcept {
return float(*this) <= float(other);
}
constexpr bool operator>(const bfloat16_t& other) const noexcept {
return float(*this) > float(other);
}
constexpr bool operator>=(const bfloat16_t& other) const noexcept {
return float(*this) >= float(other);
}
};

// Mark bfloat16_t as a trivial, standard-layout type so that xtensor can use it.
namespace std {
template<> struct is_trivial<bfloat16_t> : std::true_type {};
template<> struct is_standard_layout<bfloat16_t> : std::true_type {};
template<> struct is_trivially_copyable<bfloat16_t> : std::true_type {};
}

#endif // BFLOAT16_HPP


using namespace shortfin::array;

namespace shortfin::python {
Expand Down Expand Up @@ -191,6 +281,7 @@ struct ConvertFunctor {
}
switch (dtype) {
SF_STORE_CASE(float16, half_float::half);
SF_STORE_CASE(bfloat16, bfloat16_t);
SF_STORE_CASE(float32, float);
SF_STORE_CASE(float64, double);
SF_STORE_CASE(uint8, uint8_t);
Expand All @@ -210,6 +301,7 @@ struct ConvertFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
SF_UNARY_THUNK_CASE(float64, double);
SF_UNARY_THUNK_CASE(uint8, uint8_t);
Expand Down Expand Up @@ -264,6 +356,7 @@ struct ConvertRoundFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -308,6 +401,7 @@ struct ConvertCeilFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -352,6 +446,7 @@ struct ConvertFloorFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -396,6 +491,7 @@ struct ConvertTruncFunctor {

switch (input.dtype()) {
SF_UNARY_THUNK_CASE(float16, half_float::half);
SF_UNARY_THUNK_CASE(bfloat16, bfloat16_t);
SF_UNARY_THUNK_CASE(float32, float);
default:
throw std::invalid_argument(fmt::format(
Expand Down Expand Up @@ -525,6 +621,11 @@ half_float::half ConvertPyToEltTy(py::handle py_value, half_float::half zero) {
return static_cast<half_float::half>(py::cast<double>(py_value));
}

bfloat16_t ConvertPyToEltTy(py::handle py_value, bfloat16_t zero) {
// Python can't cast directly to half so first go to double.
return static_cast<bfloat16_t>(py::cast<double>(py_value));
}

struct AddFunctor {
template <typename Lhs, typename Rhs>
static auto Invoke(Lhs &&lhs, Rhs &&rhs) {
Expand Down Expand Up @@ -610,6 +711,7 @@ device_array ElementwiseOperation(py::handle lhs, py::handle rhs,

switch (dtype) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
SF_UNARY_FUNCTION_CASE(float64, double);
SF_UNARY_FUNCTION_CASE(uint8, uint8_t);
Expand Down Expand Up @@ -661,6 +763,7 @@ void BindArrayHostOps(py::module_ &m) {

switch (input.dtype()) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
Expand Down Expand Up @@ -690,6 +793,7 @@ void BindArrayHostOps(py::module_ &m) {

switch (out.dtype()) {
SF_UNARY_FUNCTION_CASE(float16, half_float::half);
SF_UNARY_FUNCTION_CASE(bfloat16, bfloat16_t);
SF_UNARY_FUNCTION_CASE(float32, float);
default:
throw std::invalid_argument(
Expand Down
Loading