Skip to content

Commit

Permalink
FloatingPointConverter (intel#161)
Browse files Browse the repository at this point in the history
* fpconverter initial component

* api and doc for FloatingPointConverter

* patch from main to pass link check and confapp expect due to naming changes
  • Loading branch information
desmonddak authored Feb 1, 2025
1 parent 6f0d448 commit 1df633f
Show file tree
Hide file tree
Showing 8 changed files with 553 additions and 27 deletions.
1 change: 1 addition & 0 deletions doc/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ Some in-development items will have opened issues, as well. Feel free to create
- [Simple Floating-Point Adder](./components/floating_point.md#floatingpointadder)
- [Rounding Floating-Point Adder](./components/floating_point.md#floatingpointadder)
- [Simple Floating-Point Multiplier](./components/floating_point.md#floatingpointmultiplier)
- [Floating-Point Converter](./components/floating_point.md#floatingpointconverter)
- [Fixed point](./components/fixed_point.md)
- [FloatToFixed](./components/fixed_point.md#floattofixed)
- [FixedToFloat](./components/fixed_point.md#fixedtofloat)
Expand Down
18 changes: 18 additions & 0 deletions doc/components/floating_point.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,3 +77,21 @@ It has options to control its performance:
- `adderGen`: used to specify the kind of [Adder] used for key functions like the mantissa addition. Defaults to [NativeAdder], but you can select a [ParallelPrefixAdder] of your choice.
- `seGen`: type of sign extension routine used, base class is [PartialProductSignExtension].
- `ppTree`: used to specify the type of ['ParallelPrefix'](https://intel.github.io/rohd-hcl/rohd_hcl/ParallelPrefix-class.html) used in the other critical functions like leading-one detect.

## FloatingPointConverter

A [FloatingPointConverter] component translates arbitrary width floating-point logic structures from one size to another, including handling sub-normals, infinities, and performs RNE rounding.

Here is an example using the converter to translate from 32-bit single-precision floating point to 16-bit brain (bfloat16) floating-point format.

```dart
final fp32 = FloatingPoint32();
final bf16 = FloatingPointBF16();
final one = FloatingPoint32Value.getFloatingPointConstant(
FloatingPointConstants.one);
fp32.put(one);
FloatingPointConverter(fp32, bf16);
expect(bf16.floatingPointValue.toDouble(), equals(1.0));
```
2 changes: 2 additions & 0 deletions lib/src/arithmetic/floating_point/floating_point.dart
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,7 @@
export 'floating_point_adder.dart';
export 'floating_point_adder_round.dart';
export 'floating_point_adder_simple.dart';
export 'floating_point_converter.dart';
export 'floating_point_multiplier.dart';
export 'floating_point_multiplier_simple.dart';
export 'floating_point_rounding.dart';
36 changes: 24 additions & 12 deletions lib/src/arithmetic/floating_point/floating_point_adder_round.dart
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,6 @@ import 'dart:math';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

// TODO(desmonddak): factor rounding into a utility by merging
// the near and far bits and creating a LGRS algorithm on that word.

// TODO(desmondak): investigate how to implement other forms of rounding.

/// An adder module for variable FloatingPoint type with rounding.
// This is a Seidel/Even adder, dual-path implementation.
class FloatingPointAdderRound extends FloatingPointAdder {
Expand Down Expand Up @@ -315,6 +310,15 @@ class FloatingPointAdderRound extends FloatingPointAdder {
~effectiveSubtractionFlopped)
.named('isR');

final inf = outputSum.inf(sign: largerSignFlopped);
final infExponent = inf.exponent;

final realIsInfRPath =
exponentRPath.eq(infExponent).named('realIsInfRPath');

final realIsInfNPath =
exponentNPath.eq(infExponent).named('realIsInfNPath');

Combinational([
If(isNaNFlopped, then: [
outputSum < outputSum.nan,
Expand All @@ -323,14 +327,22 @@ class FloatingPointAdderRound extends FloatingPointAdder {
outputSum < outputSum.inf(sign: largerSignFlopped),
], orElse: [
If(isR, then: [
outputSum.sign < largerSignFlopped,
outputSum.exponent < exponentRPath,
outputSum.mantissa <
mantissaRPath.slice(mantissaRPath.width - 2, 1),
If(realIsInfRPath, then: [
outputSum < inf,
], orElse: [
outputSum.sign < largerSignFlopped,
outputSum.exponent < exponentRPath,
outputSum.mantissa <
mantissaRPath.slice(mantissaRPath.width - 2, 1),
]),
], orElse: [
outputSum.sign < signNPath,
outputSum.exponent < exponentNPath,
outputSum.mantissa < finalSignificandNPath,
If(realIsInfNPath, then: [
outputSum < inf,
], orElse: [
outputSum.sign < signNPath,
outputSum.exponent < exponentNPath,
outputSum.mantissa < finalSignificandNPath,
])
])
])
])
Expand Down
216 changes: 216 additions & 0 deletions lib/src/arithmetic/floating_point/floating_point_converter.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// floating_point_converter.dart
// A Floating-point to floating-point arbitrary width converter.
//
// 2025 January 28 2025
// Author: Desmond A. Kirkpatrick <[email protected]>

import 'dart:math';

import 'package:meta/meta.dart';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// A converter module for FloatingPoint values
class FloatingPointConverter extends Module {
/// Source exponent width
final int sourceExponentWidth;

/// Source mantissa width
final int sourceMantissaWidth;

/// Destination exponent width
late final int destExponentWidth;

/// Destination mantissa width
late final int destMantissaWidth;

/// Output [FloatingPoint] computed
late final FloatingPoint destination = FloatingPoint(
exponentWidth: destExponentWidth,
mantissaWidth: destMantissaWidth,
name: 'dest')
..gets(output('destination'));

/// The result of [FloatingPoint] conversion
@protected
late final FloatingPoint _destination = FloatingPoint(
exponentWidth: destExponentWidth,
mantissaWidth: destMantissaWidth,
name: 'dest');

/// Convert a [FloatingPoint] logic structure from one format to another.
/// - [source] is the source format [FloatingPoint] logic structure.
/// - [destination] is the destination format [FloatingPoint] logic
/// structure.
/// - [ppTree] can be specified to use a specific [ParallelPrefix] tree
/// for the leading-1 detection during normalization.
/// - [adderGen] can specify the [Adder] to use for exponent calculations.
FloatingPointConverter(FloatingPoint source, FloatingPoint destination,
{ParallelPrefix Function(
List<Logic> inps, Logic Function(Logic term1, Logic term2) op)
ppTree = KoggeStone.new,
Adder Function(Logic a, Logic b, {Logic? carryIn}) adderGen =
NativeAdder.new,
super.name})
: sourceExponentWidth = source.exponent.width,
sourceMantissaWidth = source.mantissa.width {
destExponentWidth = destination.exponent.width;
destMantissaWidth = destination.mantissa.width;
source = source.clone(name: 'source')
..gets(addInput('source', source, width: source.width));
addOutput('destination', width: _destination.width) <= _destination;
destination <= output('destination');

// maxExpWidth: mantissa +2:
// 1 for the hidden jbit and 1 for going past with leadingOneDetect
final maxExpWidth = [
source.exponent.width,
destExponentWidth,
log2Ceil(source.mantissa.width + 2),
log2Ceil(destMantissaWidth + 2)
].reduce(max);
final sBias = source.bias.zeroExtend(maxExpWidth).named('sourceBias');
final dBias = Const(FloatingPointValue.computeBias(destExponentWidth),
width: maxExpWidth)
.named('destBias');
final se = source.exponent.zeroExtend(maxExpWidth).named('sourceExp');
final mantissa =
[source.isNormal, source.mantissa].swizzle().named('mantissa');

final nan = source.isNaN;
final Logic infinity;
final Logic destExponent;
final Logic destMantissa;
if (destExponentWidth >= source.exponent.width) {
// Narrow to Wide
infinity = source.isInfinity;
final biasDiff = (dBias - sBias).named('biasDiff');
final predictExp = (se + biasDiff).named('predictExp');

final leadOneValid = Logic(name: 'leadOne_valid');
final leadOnePre = ParallelPrefixPriorityEncoder(mantissa.reversed,
ppGen: ppTree, valid: leadOneValid, name: 'lead_one_encoder')
.out;
final leadOne =
mux(leadOneValid, leadOnePre.zeroExtend(biasDiff.width), biasDiff)
.named('leadOne');

final predictSub = mux(
biasDiff.gte(leadOne) & leadOneValid,
biasDiff - (leadOne - Const(1, width: leadOne.width)),
Const(0, width: biasDiff.width))
.named('predictSubExp');

final shift =
mux(biasDiff.gte(leadOne), leadOne, biasDiff).named('shift');

final newMantissa = (mantissa << shift).named('mantissaShift');

final Logic roundedMantissa;
final Logic roundIncExp;
if (destMantissaWidth < source.mantissa.width) {
final rounder =
RoundRNE(newMantissa, source.mantissa.width - destMantissaWidth);

final roundAdder = adderGen(
newMantissa.reversed.getRange(1, destMantissaWidth + 1).reversed,
rounder.doRound.zeroExtend(destMantissaWidth));
roundedMantissa = roundAdder.sum
.named('roundedMantissa')
.getRange(0, destMantissaWidth);
roundIncExp = roundAdder.sum[-1];
} else {
roundedMantissa = newMantissa;
roundIncExp = Const(0);
}

destMantissa = ((destMantissaWidth >= source.mantissa.width)
? [
newMantissa.slice(-2, 0),
Const(0, width: destMantissaWidth - newMantissa.width + 1)
].swizzle().named('clippedMantissa')
: roundedMantissa)
.named('destMantissa');

final preExponent =
mux(shift.gt(Const(0, width: shift.width)), predictSub, predictExp)
.named('unRndDestExponent') +
roundIncExp.zeroExtend(predictSub.width).named('rndIncExp');
destExponent =
preExponent.getRange(0, destExponentWidth).named('destExponent');
} else {
// Wide to Narrow exponent
final biasDiff = (sBias - dBias).named('biasDiff');
final predictE = mux(biasDiff.gte(se), Const(0, width: biasDiff.width),
(se - biasDiff).named('sourceRebiased'))
.named('predictExponent');

final shift = mux(
biasDiff.gte(se),
(source.isNormal.zeroExtend(biasDiff.width).named('srcIsNormal') +
(biasDiff - se).named('negSourceRebiased'))
.named('shiftSubnormal'),
Const(0, width: biasDiff.width));

final fullMantissa = [mantissa, Const(0, width: destMantissaWidth + 2)]
.swizzle()
.named('fullMantissa');

final shiftMantissa = (fullMantissa >>> shift).named('shiftMantissa');
final rounder =
RoundRNE(shiftMantissa, fullMantissa.width - destMantissaWidth - 1);

final postPredRndMantissa = shiftMantissa
.slice(-2, shiftMantissa.width - destMantissaWidth - 1)
.named('preRndMantissa');

final roundAdder = adderGen(
postPredRndMantissa,
[Const(0, width: destMantissaWidth - 1), rounder.doRound]
.swizzle()
.named('rndIncrement'));
final roundIncExp = roundAdder.sum[-1];
final roundedMantissa = roundAdder.sum.getRange(0, destMantissaWidth);

destExponent = (predictE + roundIncExp.zeroExtend(predictE.width))
.named('predictExpRounded')
.getRange(0, destExponentWidth);
destMantissa =
roundedMantissa.getRange(0, destMantissaWidth).named('destMantissa');

final maxDestExp = Const(
FloatingPointValue.computeMaxExponent(destExponentWidth) +
FloatingPointValue.computeBias(destExponentWidth),
width: maxExpWidth);

infinity = source.isInfinity |
(se.gt(biasDiff) & (se - biasDiff).gt(maxDestExp));
}
Combinational([
If.block([
Iff(nan, [
_destination <
FloatingPoint(
exponentWidth: destExponentWidth,
mantissaWidth: destMantissaWidth)
.nan,
]),
ElseIf(infinity, [
_destination <
FloatingPoint(
exponentWidth: destExponentWidth,
mantissaWidth: destMantissaWidth)
.inf(sign: source.sign),
]),
ElseIf(Const(1), [
_destination.sign < source.sign,
_destination.exponent < destExponent,
_destination.mantissa < destMantissa,
]),
]),
]);
}
}
31 changes: 31 additions & 0 deletions lib/src/arithmetic/floating_point/floating_point_rounding.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// floating_point_rounding.dart
// Floating-point rounding support.
//
// 2025 January 28 2025
// Author: Desmond A. Kirkpatrick <[email protected]>

import 'package:rohd/rohd.dart';

/// A rounding class that performs rounding-nearest-even
class RoundRNE {
/// Return whether to round the input or not.
Logic get doRound => _doRound;

late final Logic _doRound;

/// Determine whether the input should be rounded up given
/// - [inp] the input bitvector to consider rounding
/// - [lsb] the bit position at which to consider rounding
RoundRNE(Logic inp, int lsb) {
final last = inp[lsb];
final guard = (lsb > 0) ? inp[lsb - 1] : Const(0);
final round = (lsb > 1) ? inp[lsb - 2] : Const(0);
final sticky = (lsb > 2) ? inp.getRange(0, lsb - 2).or() : Const(0);

_doRound = guard & (last | round | sticky);
}
}
// TODO(desmondak): investigate how to implement other forms of rounding.
Loading

0 comments on commit 1df633f

Please sign in to comment.