Tree component (intel#155)
* general  reduction tree
desmonddak authored Feb 1, 2025
1 parent 1df633f commit 09b8a65
Showing 6 changed files with 364 additions and 2 deletions.
1 change: 1 addition & 0 deletions doc/README.md
@@ -96,6 +96,7 @@ Some in-development items will have opened issues, as well. Feel free to create
- NoC's
- Coherent
- Non-Coherent
- [Reduction Tree](./components/reduction_tree.md)
- Memory
- [Register File](./components/memory.md#register-files)
- [Masking](./components/memory.md#masks)
2 changes: 1 addition & 1 deletion doc/components/multiplier_components.md
@@ -51,7 +51,7 @@ row slice mult

A few things to note: first, that we are negating by ones' complement (so we need a -0) and second, these rows do not add up to (18: 10010). For Booth encoded rows to add up properly, they need to be in twos' complement form, and they need to be sign-extended.

- Here is the matrix with a crude sign extension `brute` (the table formatting is available from our [Partial Product Generator](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductGeneratorBase-class.html) component followed by [Brute Force Sign Extension]). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010).
+ Here is the matrix with a crude sign extension `brute` (the table formatting is available from our [PartialProductGenerator](https://intel.github.io/rohd-hcl/rohd_hcl/PartialProductGeneratorBase-class.html) component). With twos' complementation, and sign bits folded in (note the LSB of each row has a sign term from the previous row), these addends are correctly formed and add to (18: 10010).

```text
7 6 5 4 3 2 1 0
54 changes: 54 additions & 0 deletions doc/components/reduction_tree.md
@@ -0,0 +1,54 @@
# Reduction Tree

The `ReductionTree` component is a general tree generator that allows for an arbitrary radix (tree-branching factor) in the computation. It takes a sequence of `Logic` values and performs a specified operation at each node of the tree, taking in up to `radix` inputs and producing one output. If the operation widens the output (as addition can), then the `ReductionTree` will widen values using either sign-extension or zero-extension as specified.

The input sequence is provided in the form `List<Logic>`. The operation must be provided in the form:

```dart
Logic Function(List<Logic> operands, {String name})
```

This function should support any operand count in the range $[2, radix]$ if the tree is to support sequences of arbitrary length. Note that the `ReductionTree` itself does not require the sequence length to be a power of the radix; it will use shorter operations to balance the tree when the sequence length is not a power of the radix.
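To make the balancing concrete, here is a minimal plain-Dart sketch (no ROHD types) of the splitting rule used by `reductionTreeRecurse` in this commit: each node divides its segment into `radix` pieces of `length ~/ radix` items and gives the remainder to the last piece. The helper name `collectOperandCounts` is hypothetical and exists only for illustration; it records how many operands each call to the operation would receive.

```dart
/// Records every operand count that the node operation would receive when a
/// sequence of [length] items is reduced with the given [radix] (radix >= 2).
/// This is an integer-only model of the splitting rule, not the ROHD module.
void collectOperandCounts(int length, int radix, Set<int> seen) {
  if (length < radix) {
    // Short segment: reduced by a single call on `length` operands.
    seen.add(length);
    return;
  }
  final segment = length ~/ radix;
  for (var i = 0; i < radix; i++) {
    // The first radix-1 children take `segment` items each; the last child
    // takes whatever remains.
    final childLength =
        (i < radix - 1) ? segment : length - segment * (radix - 1);
    collectOperandCounts(childLength, radix, seen);
  }
  // The node itself combines one result per child.
  seen.add(radix);
}

void main() {
  final seen = <int>{};
  collectOperandCounts(79, 4, seen);
  print(seen.toList()..sort()); // [1, 2, 4]
}
```

In this model, single-element segments also reach the operation at the leaves; the `reduce`-based examples in this document handle that case by returning the lone operand unchanged.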

The resulting tree can be pipelined by specifying the depth of nodes before a pipestage is added. Since the input can be of arbitrary length, paths in the tree may not be balanced, and extra pipestages will be added in shorter sections of the tree to align the computation.
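The depth and latency bookkeeping can be sketched without any hardware types. The following plain-Dart model follows the accounting in `reductionTreeRecurse` from this commit: a node is flopped when the combinational depth below it is a non-zero multiple of `depthToFlop`, after which the depth counter restarts at zero. The function name `modelTree` is hypothetical, and the model only counts levels; it does not build logic.

```dart
import 'dart:math';

/// Returns the combinational depth since the last flop and the number of
/// flop stages (latency) for a tree over [length] inputs.
({int depth, int flopDepth}) modelTree(int length, int radix,
    {int? depthToFlop}) {
  if (length < radix) {
    // Leaves are counted as depth 0 with no flops, as in the commit's code.
    return (depth: 0, flopDepth: 0);
  }
  final segment = length ~/ radix;
  final children = [
    for (var i = 0; i < radix; i++)
      modelTree((i < radix - 1) ? segment : length - segment * (radix - 1),
          radix,
          depthToFlop: depthToFlop)
  ];
  final flopDepth = children.map((c) => c.flopDepth).reduce(max);
  final treeDepth = children.map((c) => c.depth).reduce(max);
  // Insert a pipestage once `depthToFlop` levels have accumulated.
  final doFlop = depthToFlop != null &&
      treeDepth > 0 &&
      treeDepth % depthToFlop == 0;
  return (
    depth: doFlop ? 0 : treeDepth + 1,
    flopDepth: flopDepth + (doFlop ? 1 : 0),
  );
}

void main() {
  final r = modelTree(79, 4, depthToFlop: 2);
  print('depth=${r.depth} latency=${r.flopDepth}'); // depth=1 latency=1
}
```

Under this model, 79 inputs at radix 4 with `depthToFlop: 2` yield one pipestage, with one further level of logic after it.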

Here is an example radix-4 computation tree using native addition on 79 13-bit inputs, pipelining every 2 operations deep, and producing a single 13-bit result.

```dart
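import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

Logic addReduce(List<Logic> inputs, {String name = 'native'}) {
  final a = inputs.reduce((v, e) => v + e);
  return a;
}

/// Tree reduction using addReduce
const width = 13;
const length = 79;
// Fill the sequence with constants, as the unit test in this commit does.
final vec = <Logic>[for (var i = 0; i < length; i++) Const(i, width: width)];
final clk = SimpleClockGenerator(10).clk;
final reductionTree =
    ReductionTree(vec, radix: 4, addReduce, clk: clk, depthToFlop: 2);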
```

Here is the same example radix-4 computation tree but using prefix adders on 79 13-bit inputs, pipelining every 2 operations deep, and producing a single 21-bit result. The extra width comes from the prefix adders: each adder widens its sum by 1 bit, so the result grows with every adder stage along the deepest path of the tree.

```dart
Logic addReduceAdders(List<Logic> inputs, {String name = 'prefix'}) {
  if (inputs.length < 4) {
    return inputs.reduce((v, e) => v + e);
  } else {
    final add0 =
        ParallelPrefixAdder(inputs[0], inputs[1], name: '${name}_add0');
    final add1 =
        ParallelPrefixAdder(inputs[2], inputs[3], name: '${name}_add1');
    final addf = ParallelPrefixAdder(add0.sum, add1.sum, name: '${name}_addf');
    return addf.sum;
  }
}

/// Tree reduction using addReduceAdders
const width = 13;
const length = 79;
// Fill the sequence with constants, as the unit test in this commit does.
final vec = <Logic>[for (var i = 0; i < length; i++) Const(i, width: width)];
final clk = SimpleClockGenerator(10).clk;
final reductionTree = ReductionTree(vec, radix: 4, addReduceAdders,
    clk: clk, depthToFlop: 2, signExtend: true);
```
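The 21-bit figure can be checked with a small width model. This is a sketch under stated assumptions rather than a description of the library: it assumes each `ParallelPrefixAdder` produces a sum one bit wider than its operands, that native `+` preserves width, and that child results are extended to the widest child before each node (as the tree's extension step does). The name `modelOutputWidth` is hypothetical, and the model is specific to the radix-4 `addReduceAdders` operation above.

```dart
import 'dart:math';

/// Models the output width of a radix-4 tree using `addReduceAdders`.
int modelOutputWidth(int length, int inputWidth) {
  const radix = 4;
  if (length < radix) {
    // Fewer than 4 operands: native `+`, assumed to keep the operand width.
    return inputWidth;
  }
  final segment = length ~/ radix;
  var aligned = 0;
  for (var i = 0; i < radix; i++) {
    final childLength =
        (i < radix - 1) ? segment : length - segment * (radix - 1);
    // Children are extended to the widest child result before the node.
    aligned = max(aligned, modelOutputWidth(childLength, inputWidth));
  }
  // Two adder stages per node (add0/add1, then addf), each assumed to
  // widen the result by one bit.
  return aligned + 2;
}

void main() {
  print(modelOutputWidth(79, 13)); // 21
}
```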
3 changes: 2 additions & 1 deletion lib/rohd_hcl.dart
@@ -1,4 +1,4 @@
- // Copyright (C) 2023-2024 Intel Corporation
+ // Copyright (C) 2023-2025 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause

export 'src/arbiters/arbiters.dart';
@@ -17,6 +17,7 @@ export 'src/find.dart';
export 'src/interfaces/interfaces.dart';
export 'src/memory/memories.dart';
export 'src/models/models.dart';
export 'src/reduction_tree.dart';
export 'src/rotate.dart';
export 'src/serialization/serialization.dart';
export 'src/shift_register.dart';
147 changes: 147 additions & 0 deletions lib/src/reduction_tree.dart
@@ -0,0 +1,147 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// reduction_tree.dart
// A generator for creating tree reduction computations.
//
// 2025 January 10
// Author: Desmond A Kirkpatrick <[email protected]

import 'dart:math';

import 'package:meta/meta.dart';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';

/// A generator which constructs a tree of radix-input / 1-output modules.
class ReductionTree extends Module {
/// The radix-sized input operation to be performed at each node.
@protected
final Logic Function(List<Logic> inputs, {String name}) operation;

/// Specified width of input to each reduction node (e.g., binary: radix=2)
@protected
late final int radix;

/// When [signExtend] is true, use sign-extension on values,
/// otherwise use zero-extension.
@protected
late final bool signExtend;

/// Specified depth of nodes at which to flop (requires [clk]).
@protected
late final int? depthToFlop;

/// Optional [clk] input to create pipeline.
@protected
late final Logic? clk;

/// Optional [reset] input to reset pipeline.
@protected
late final Logic? reset;

/// Optional [enable] input to enable pipeline.
@protected
late final Logic? enable;

/// The final output of the tree computation.
Logic get out => output('out');

/// The combinational depth since the last flop. The total compute depth of
/// the tree is: depth + flopDepth * depthToFlop.
int get depth => _computed.depth;

/// The flop depth of the tree from the output to the leaves.
int get latency => _computed.flopDepth;

/// Capture the record of compute: the final value, its depth (from last
/// flop or input), and its flopDepth if pipelined.
late final ({Logic value, int depth, int flopDepth}) _computed;

/// Generate a tree based on dividing the input [sequence] of a node into
/// segments, recursively constructing [radix] child nodes to operate
/// on each segment.
/// - [sequence] is the input sequence to be reduced using the tree of
/// operations.
/// - Logic Function(List<Logic> inputs, {String name}) [operation]
/// is the operation to be
/// performed at each node. Note that [operation] can widen the output. The
/// logic function must support the operation for 2 to radix inputs.
/// - [radix] is the width of reduction at each node in the tree (e.g.,
/// binary: radix=2).
/// - [signExtend] if true, use sign-extension to widen [Logic] values as
/// needed in the tree, otherwise use zero-extension (default).
///
/// Optional parameters to be used for creating a pipelined computation tree:
/// - [clk], [reset], [enable] are optionally provided to allow for flopping.
/// - [depthToFlop] specifies how many node levels of logic are placed
/// between flops.
ReductionTree(List<Logic> sequence, this.operation,
{this.radix = 2,
this.signExtend = false,
this.depthToFlop,
Logic? clk,
Logic? enable,
Logic? reset,
super.name = 'reduction_tree'})
: super(definitionName: 'ReductionTree_R${radix}_L${sequence.length}') {
if (sequence.isEmpty) {
throw RohdHclException("Don't use ReductionTree "
'with an empty sequence');
}
sequence = [
for (var i = 0; i < sequence.length; i++)
addInput('seq$i', sequence[i], width: sequence[i].width)
];
this.clk = (clk != null) ? addInput('clk', clk) : null;
this.enable = (enable != null) ? addInput('enable', enable) : null;
this.reset = (reset != null) ? addInput('reset', reset) : null;

_computed = reductionTreeRecurse(sequence);
addOutput('out', width: _computed.value.width) <= _computed.value;
}

/// Local conditional flop using module reset/enable
Logic localFlop(Logic d, {bool doFlop = false}) =>
condFlop(doFlop ? clk : null, reset: reset, en: enable, d);

/// Recursively construct the computation tree
({Logic value, int depth, int flopDepth}) reductionTreeRecurse(
List<Logic> seq) {
if (seq.length < radix) {
return (value: operation(seq), depth: 0, flopDepth: 0);
} else {
final results = <({Logic value, int depth, int flopDepth})>[];
final segment = seq.length ~/ radix;
var pos = 0;
for (var i = 0; i < radix; i++) {
final c = reductionTreeRecurse(seq
.getRange(pos, (i < radix - 1) ? pos + segment : seq.length)
.toList());
results.add(c);
pos += segment;
}
final flopDepth = results.map((c) => c.flopDepth).reduce(max);
final treeDepth = results.map((c) => c.depth).reduce(max);

final alignedResults = results
.map((c) => localFlop(c.value, doFlop: c.flopDepth < flopDepth));

final depthFlop = (depthToFlop != null) &&
(treeDepth > 0) && (treeDepth % depthToFlop! == 0);
final resultsFlop =
alignedResults.map((r) => localFlop(r, doFlop: depthFlop));

final alignWidth = results.map((c) => c.value.width).reduce(max);
final resultsExtend = resultsFlop.map((r) =>
signExtend ? r.signExtend(alignWidth) : r.zeroExtend(alignWidth));

final computed = operation(resultsExtend.toList(),
name: 'reduce_d${(treeDepth + 1) + flopDepth * (depthToFlop ?? 0)}');
return (
value: computed,
depth: depthFlop ? 0 : treeDepth + 1,
flopDepth: flopDepth + (depthFlop ? 1 : 0)
);
}
}
}
159 changes: 159 additions & 0 deletions test/reduction_tree_test.dart
@@ -0,0 +1,159 @@
// Copyright (C) 2025 Intel Corporation
// SPDX-License-Identifier: BSD-3-Clause
//
// reduction_tree_test.dart
// Tests of the ReductionTree generator.
//
// 2025 January 8
// Author: Desmond A Kirkpatrick <[email protected]

import 'dart:async';
import 'package:rohd/rohd.dart';
import 'package:rohd_hcl/rohd_hcl.dart';
import 'package:rohd_vf/rohd_vf.dart';
import 'package:test/test.dart';

Logic addReduceAdders(List<Logic> inputs, {String name = 'prefix'}) {
if (inputs.length < 4) {
return inputs.reduce((v, e) => v + e);
} else {
final add0 =
ParallelPrefixAdder(inputs[0], inputs[1], name: '${name}_add0');
final add1 =
ParallelPrefixAdder(inputs[2], inputs[3], name: '${name}_add1');
final addf = ParallelPrefixAdder(add0.sum, add1.sum, name: '${name}_addf');
return addf.sum;
}
}

void main() {
tearDown(() async {
await Simulator.reset();
});
Logic addReduce(List<Logic> inputs, {String name = ''}) {
final a = inputs.reduce((v, e) => v + e);
return a;
}

test('reduction tree of add operations -- quick test', () async {
const width = 13;
const length = 79;
final vec = <Logic>[];
// First sum will be length *(length-1) /2
var count = 0;
for (var i = 0; i < length; i++) {
vec.add(Const(i, width: width));
count = count + i;
}
for (var radix = 2; radix < length; radix++) {
final prefixAdd = ReductionTree(vec, radix: radix, addReduce);
expect(prefixAdd.out.value.toInt(), equals(count));
}
});
test('reduction tree of adders -- large', () async {
final clk = SimpleClockGenerator(10).clk;

const width = 17;
const length = 290;
final vec = <Logic>[];
// First sum will be length *(length-1) /2
var count = 0;
for (var i = 0; i < length; i++) {
vec.add(Const(i, width: width));
count = count + i;
}
for (var radix = 2; radix < length; radix++) {
final prefixAdd = ReductionTree(vec, radix: radix, addReduce, clk: clk);
expect(prefixAdd.out.value.toInt(), equals(count));
}
});

test('reduction tree of adders -- large, pipelined', () async {
final clk = SimpleClockGenerator(10).clk;

const width = 17;
const length = 290;
final vec = <Logic>[];
// First sum will be length *(length-1) /2
for (var i = 0; i < length; i++) {
vec.add(Const(i, width: width));
}
const radix = 4;
final prefixAdd =
ReductionTree(vec, radix: radix, addReduce, clk: clk, depthToFlop: 1);

await prefixAdd.build();
unawaited(Simulator.run());
var cycles = 0;
await clk.nextNegedge;
cycles++;
// second sum will be length
for (var i = 0; i < length; i++) {
vec[i].inject(1);
}
await clk.nextNegedge;
cycles++;
// third sum will be length *2
for (var i = 0; i < length; i++) {
vec[i].inject(2);
}
if (prefixAdd.latency > cycles) {
await clk.waitCycles(prefixAdd.latency - cycles);
await clk.nextNegedge;
}
expect(prefixAdd.out.value.toInt(), equals(length * (length - 1) ~/ 2));
await clk.nextNegedge;
expect(prefixAdd.out.value.toInt(), equals(length));
await clk.nextNegedge;
expect(prefixAdd.out.value.toInt(), equals(length * 2));
await clk.nextNegedge;
await clk.nextNegedge;
await clk.nextNegedge;
await Simulator.endSimulation();
});

test('reduction tree of prefix adders -- large, pipelined, radix 4',
() async {
final clk = SimpleClockGenerator(10).clk;

const width = 17;
const length = 290;
final vec = <Logic>[];
// First sum will be length *(length-1) /2
for (var i = 0; i < length; i++) {
vec.add(Const(i, width: width));
}
const reduce = 4;
final prefixAdd = ReductionTree(
vec, radix: reduce, addReduceAdders, clk: clk, depthToFlop: 1);

await prefixAdd.build();
unawaited(Simulator.run());
var cycles = 0;
await clk.nextNegedge;
cycles++;
// second sum will be length
for (var i = 0; i < length; i++) {
vec[i].inject(1);
}
await clk.nextNegedge;
cycles++;
// third sum will be length *2
for (var i = 0; i < length; i++) {
vec[i].inject(2);
}
if (prefixAdd.latency > cycles) {
await clk.waitCycles(prefixAdd.latency - cycles);
await clk.nextNegedge;
}
expect(prefixAdd.out.value.toInt(), equals(length * (length - 1) ~/ 2));
await clk.nextNegedge;
expect(prefixAdd.out.value.toInt(), equals(length));
await clk.nextNegedge;
expect(prefixAdd.out.value.toInt(), equals(length * 2));
await clk.nextNegedge;
await clk.nextNegedge;
await clk.nextNegedge;
await Simulator.endSimulation();
});
}
