Skip to content

Commit

Permalink
feat: add support for loading F8_E5M2 weights (#460)
Browse files Browse the repository at this point in the history
  • Loading branch information
LostRuins authored Nov 23, 2024
1 parent 0758544 commit 8f94efa
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 1 deletion.
63 changes: 63 additions & 0 deletions model.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,48 @@ uint16_t f8_e4m3_to_f16(uint8_t f8) {
return ggml_fp32_to_fp16(*reinterpret_cast<const float*>(&result));
}


// Converts one F8_E5M2 value (1 sign bit, 5 exponent bits, 2 mantissa bits)
// into the bit pattern of an f16 value (1 sign bit, 5 exponent bits,
// 10 mantissa bits).
//
// @param fp8  Raw F8_E5M2 byte: bit 7 = sign, bits 6..2 = exponent,
//             bits 1..0 = mantissa.
// @return     Raw f16 bits representing the same value.
//
// Both formats use a 5-bit exponent with bias 15, so no re-biasing is needed
// and the f8 byte is exactly the upper byte of the equivalent f16:
//   - sign     bit 7     -> bit 15
//   - exponent bits 6..2 -> bits 14..10
//   - mantissa bits 1..0 -> bits 9..8 (the lower 8 mantissa bits are zero)
// This mapping holds for every class of input: signed zero, subnormals
// (exponent 0), normals, infinity (exponent 0x1F, mantissa 0) and NaN
// (exponent 0x1F, mantissa != 0). No range or rounding checks are required,
// since every E5M2 value is exactly representable in f16.
uint16_t f8_e5m2_to_f16(uint8_t fp8) {
    // Widen before shifting so the result is built in a 16-bit-wide value.
    return static_cast<uint16_t>(static_cast<uint16_t>(fp8) << 8);
}

void bf16_to_f32_vec(uint16_t* src, float* dst, int64_t n) {
// support inplace op
for (int64_t i = n - 1; i >= 0; i--) {
Expand All @@ -627,6 +669,12 @@ void f8_e4m3_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
dst[i] = f8_e4m3_to_f16(src[i]);
}
}
// Converts n F8_E5M2 bytes read from src into n f16 bit patterns written to dst.
//
// Elements are handled from the last index down to the first so that the
// conversion also works in place (src and dst pointing at the same buffer):
// the 2-byte write for element i never lands on a source byte with a lower
// index that has not been read yet.
// A non-positive n performs no work.
void f8_e5m2_to_f16_vec(uint8_t* src, uint16_t* dst, int64_t n) {
    int64_t remaining = n;
    while (remaining > 0) {
        --remaining;
        // Read the source byte before storing, since the store may overwrite it.
        const uint8_t packed = src[remaining];
        dst[remaining]       = f8_e5m2_to_f16(packed);
    }
}

void convert_tensor(void* src,
ggml_type src_type,
Expand Down Expand Up @@ -863,6 +911,8 @@ ggml_type str_to_ggml_type(const std::string& dtype) {
ttype = GGML_TYPE_F32;
} else if (dtype == "F8_E4M3") {
ttype = GGML_TYPE_F16;
} else if (dtype == "F8_E5M2") {
ttype = GGML_TYPE_F16;
}
return ttype;
}
Expand Down Expand Up @@ -976,6 +1026,10 @@ bool ModelLoader::init_from_safetensors_file(const std::string& file_path, const
tensor_storage.is_f8_e4m3 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else if (dtype == "F8_E5M2") {
tensor_storage.is_f8_e5m2 = true;
// f8 -> f16
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size * 2);
} else {
GGML_ASSERT(tensor_storage.nbytes() == tensor_data_size);
}
Expand Down Expand Up @@ -1644,6 +1698,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)dst_tensor->data, (uint16_t*)dst_tensor->data, tensor_storage.nelements());
}
} else {
read_buffer.resize(tensor_storage.nbytes());
Expand All @@ -1655,6 +1712,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}

convert_tensor((void*)read_buffer.data(), tensor_storage.type, dst_tensor->data,
Expand All @@ -1670,6 +1730,9 @@ bool ModelLoader::load_tensors(on_new_tensor_cb_t on_new_tensor_cb, ggml_backend
} else if (tensor_storage.is_f8_e4m3) {
// inplace op
f8_e4m3_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
} else if (tensor_storage.is_f8_e5m2) {
// inplace op
f8_e5m2_to_f16_vec((uint8_t*)read_buffer.data(), (uint16_t*)read_buffer.data(), tensor_storage.nelements());
}

if (tensor_storage.type == dst_tensor->type) {
Expand Down
5 changes: 4 additions & 1 deletion model.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ struct TensorStorage {
ggml_type type = GGML_TYPE_F32;
bool is_bf16 = false;
bool is_f8_e4m3 = false;
bool is_f8_e5m2 = false;
int64_t ne[SD_MAX_DIMS] = {1, 1, 1, 1, 1};
int n_dims = 0;

Expand Down Expand Up @@ -65,7 +66,7 @@ struct TensorStorage {
}

int64_t nbytes_to_read() const {
if (is_bf16 || is_f8_e4m3) {
if (is_bf16 || is_f8_e4m3 || is_f8_e5m2) {
return nbytes() / 2;
} else {
return nbytes();
Expand Down Expand Up @@ -115,6 +116,8 @@ struct TensorStorage {
type_name = "bf16";
} else if (is_f8_e4m3) {
type_name = "f8_e4m3";
} else if (is_f8_e5m2) {
type_name = "f8_e5m2";
}
ss << name << " | " << type_name << " | ";
ss << n_dims << " [";
Expand Down

0 comments on commit 8f94efa

Please sign in to comment.