Skip to content

Commit

Permalink
Merge pull request google#383 from alanvgreen/gen2g
Browse files Browse the repository at this point in the history
  • Loading branch information
alanvgreen authored Dec 22, 2021
2 parents 89e7e94 + db66f7e commit fd2124b
Show file tree
Hide file tree
Showing 21 changed files with 10,992 additions and 112 deletions.
2 changes: 2 additions & 0 deletions common/src/tflite.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,8 @@ static uint8_t tensor_arena[kTensorArenaSize];
#endif
} // anonymous namespace

uint8_t *tflite_tensor_arena = tensor_arena;

static void tflite_init() {
static bool initialized = false;
if (initialized) {
Expand Down
2 changes: 2 additions & 0 deletions common/src/tflite.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,4 +45,6 @@ void tflite_classify();
int8_t* tflite_get_output();
float* tflite_get_output_float();

// Pointer to the TFLite tensor arena, exported so other modules can access it.
extern uint8_t *tflite_tensor_arena;
#endif // _TFLITE_H
2 changes: 1 addition & 1 deletion proj/hps_accel/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ export DEFINES :=

# Define which generation gateware to use = gen1 or gen2
#GATEWARE_GEN ?= 1
GATEWARE_GEN = 2
GATEWARE_GEN := 1
DEFINES += GATEWARE_GEN=$(GATEWARE_GEN)

# Uncomment this line to use the custom accelerated convolution operation
Expand Down
19 changes: 13 additions & 6 deletions proj/hps_accel/gateware/gen2/accelerator.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,16 +131,19 @@ def build_filter_store(self, m):
]
return store.values_out

def build_input_fetcher(self, m):
def build_input_fetcher(self, m, stop):
m.submodules['fetcher'] = fetcher = InputFetcher()
# We reset the fetcher when finished calculating to avoid
# spurious first and last signals that might corrupt the next
# accelerator run.
m.d.comb += [
fetcher.reset.eq(self.reset),
fetcher.reset.eq(self.reset | stop),
fetcher.start.eq(self.start),
fetcher.base_addr.eq(self.config.input_base_addr),
fetcher.num_pixels_x.eq(self.config.num_pixels_x),
fetcher.pixel_advance_x.eq(self.config.pixel_advance_x),
fetcher.pixel_advance_y.eq(self.config.pixel_advance_y),
fetcher.depth.eq(self.config.output_channel_depth >> 4), # divide by 16
fetcher.depth.eq(self.config.output_channel_depth >> 4),
fetcher.num_repeats.eq(self.config.num_repeats),
]
for i in range(4):
Expand Down Expand Up @@ -195,12 +198,13 @@ def build_accumulator_reader(self, m, accumulators, accumulator_news):
m.d.comb += connect(ar.output, al.stream_in)
m.d.comb += al.num_allowed.eq(self.config.num_output_values)
m.d.comb += al.start.eq(self.start)
return al.stream_out
return al.stream_out, al.finished

def elab(self, m):
# Create filter store and input fetcher
filter_values = self.build_filter_store(m)
first, last, activations = self.build_input_fetcher(m)
stop_input = Signal()
first, last, activations = self.build_input_fetcher(m, stop_input)

# Plumb in sysarray and its inputs
m.submodules['sysarray'] = sa = SystolicArray()
Expand All @@ -217,10 +221,13 @@ def elab(self, m):
m.d.comb += sa.last.eq(last)

# Get pipeline inputs from systolic array and parameters
accumulator_stream = self.build_accumulator_reader(
accumulator_stream, finished = self.build_accumulator_reader(
m, sa.accumulator, sa.accumulator_new)
param_stream = self.build_param_store(m)

# When last accumulator read, stop input
m.d.comb += stop_input.eq(finished)

# Plumb in pipeline
m.submodules['ppp'] = ppp = PostProcessPipeline()

Expand Down
3 changes: 3 additions & 0 deletions proj/hps_accel/gateware/gen2/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ class Constants:
# Gets the output value
REG_OUTPUT_WORD = 19

# Number of items in FIFO
REG_FIFO_ITEMS = 20

# Maximum number of 8-bit channels per pixel
MAX_CHANNEL_DEPTH = 512

Expand Down
30 changes: 19 additions & 11 deletions proj/hps_accel/gateware/gen2/hps_cfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ class GetInstruction(InstructionBase):
Attributes
----------
reg_fifo_items_value: Signal(32), in
The value to return for the fifo item count register.
reg_verify_value: Signal(32), in
The value to return for the verify register.
Expand All @@ -57,30 +59,35 @@ class GetInstruction(InstructionBase):

def __init__(self):
super().__init__()
self.reg_fifo_items_value = Signal(32)
self.reg_verify_value = Signal(32)
self.output_words = Endpoint(unsigned(32))

def elab(self, m):
m.d.sync += self.done.eq(0)

def get_output():
m.d.comb += self.output_words.ready.eq(1)
with m.If(self.output_words.is_transferring()):
m.d.sync += self.output.eq(self.output_words.payload)
m.d.sync += self.done.eq(1)
m.next = "WAIT_START"
with m.Else():
m.next = "WAIT_OUTPUT"

with m.FSM():
with m.State("WAIT_START"):
with m.If(self.start):
with m.If(self.funct7 == Constants.REG_VERIFY):
m.d.sync += self.output.eq(self.reg_verify_value)
m.d.sync += self.done.eq(1)
with m.Elif(self.funct7 == Constants.REG_FIFO_ITEMS):
m.d.sync += self.output.eq(self.reg_fifo_items_value)
m.d.sync += self.done.eq(1)
with m.Elif(self.funct7 == Constants.REG_OUTPUT_WORD):
m.d.comb += self.output_words.ready.eq(1)
with m.If(self.output_words.is_transferring()):
m.d.sync += self.output.eq(self.output_words.payload)
m.d.sync += self.done.eq(1)
with m.Else():
m.next = "WAIT_OUTPUT"
get_output()
with m.State("WAIT_OUTPUT"):
m.d.comb += self.output_words.ready.eq(1)
with m.If(self.output_words.is_transferring()):
m.d.sync += self.output.eq(self.output_words.payload)
m.d.sync += self.done.eq(1)
m.next = "WAIT_START"
get_output()


class SetInstruction(InstructionBase):
Expand Down Expand Up @@ -208,6 +215,7 @@ def elab_instructions(self, m):
core.start.eq(set_.accelerator_start),
core.reset.eq(set_.accelerator_reset),
core.config.eq(set_.config),
get.reg_fifo_items_value.eq(fifo.r_level),
]
m.d.comb += connect(set_.filter_output, core.write_filter_input)
m.d.comb += connect(set_.post_process_params, core.post_process_params)
Expand Down
7 changes: 7 additions & 0 deletions proj/hps_accel/gateware/gen2/post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -500,6 +500,9 @@ class StreamLimiter(SimpleElaboratable):
running: Signal(), out
Indicates that items are being allowed to pass
finished: Signal(), out
Indicates that last item has been handled.
"""

def __init__(self, payload_type=signed(32)):
Expand All @@ -508,8 +511,10 @@ def __init__(self, payload_type=signed(32)):
self.num_allowed = Signal(18)
self.start = Signal()
self.running = Signal()
self.finished = Signal()

def elab(self, m):
m.d.sync += self.finished.eq(0)
m.d.comb += [
self.stream_in.ready.eq(self.running),
self.stream_out.valid.eq(self.stream_in.is_transferring()),
Expand All @@ -520,6 +525,8 @@ def elab(self, m):
m.d.sync += counter.eq(self.num_allowed)
with m.If(self.stream_in.is_transferring()):
m.d.sync += counter.eq(counter - 1)
with m.If(counter == 1):
m.d.sync += self.finished.eq(1)
m.d.comb += self.running.eq(counter != 0)


Expand Down
9 changes: 5 additions & 4 deletions proj/hps_accel/gateware/gen2/test_hps_cfu.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def load_filter_data(self):
filter_data = self.data.filter_data
num_filter_words_per_output = dims[1] * dims[2] * dims[3] // 4
num_output_channels = dims[0]
store = 0
for chan in range(num_output_channels):
store = chan & 1
chan_start = chan * num_filter_words_per_output
Expand Down Expand Up @@ -72,6 +71,11 @@ def configure(self):
output_depth = data.output_dims[3]
num_filter_values = reduce(lambda a, b: a * b, data.filter_dims, 1)
filter_words_per_store = num_filter_values // 4 // 2

# Toggle reset
yield do_set(C.REG_ACCELERATOR_RESET, 0)

# Set simple values
yield do_set(C.REG_INPUT_OFFSET, data.input_offset)
yield do_set(C.REG_NUM_FILTER_WORDS, filter_words_per_store)
yield do_set(C.REG_OUTPUT_OFFSET, data.output_offset)
Expand All @@ -87,9 +91,6 @@ def configure(self):
yield do_set(C.REG_NUM_OUTPUT_VALUES,
self.NUM_OUTPUT_PIXELS * output_depth)

# Toggle reset
yield do_set(Constants.REG_ACCELERATOR_RESET, 0)

# load post process parameters
for i in range(output_depth):
yield do_set(C.REG_POST_PROCESS_BIAS, data.output_biases[i])
Expand Down
39 changes: 20 additions & 19 deletions proj/hps_accel/gateware/gen2/test_post_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,34 +527,33 @@ def create_dut(self):
def test_simple_case(self):
dut = self.dut

# send 15, check it limits to 10
data = [
# (num_allowed, start, valid, running)
(3, 0, 1, 0),
(3, 0, 1, 0),
# (num_allowed, start, valid, running, finished)
(3, 0, 1, 0, 0),
(3, 0, 1, 0, 0),
# Pass 3 items
(3, 1, 1, 0),
(2, 0, 1, 1),
(2, 0, 1, 1),
(2, 0, 1, 1),
(3, 1, 1, 0, 0),
(2, 0, 1, 1, 0),
(2, 0, 1, 1, 0),
(2, 0, 1, 1, 0),

# Do not allow next few
(2, 0, 1, 0),
(2, 0, 0, 0),
(2, 0, 0, 0),
(2, 0, 1, 0, 1),
(2, 0, 0, 0, 0),
(2, 0, 0, 0, 0),

# Start running again, but do not pass on every cycle
(2, 1, 0, 0),
(2, 0, 0, 1),
(2, 0, 1, 1),
(2, 0, 0, 1),
(2, 0, 0, 1),
(2, 0, 1, 1),
(2, 0, 0, 0),
(2, 1, 0, 0, 0),
(2, 0, 0, 1, 0),
(2, 0, 1, 1, 0),
(2, 0, 0, 1, 0),
(2, 0, 0, 1, 0),
(2, 0, 1, 1, 0),
(2, 0, 0, 0, 1),
]

def process():
for num_allowed, start, input_valid, running in data:
for num_allowed, start, input_valid, running, finished in data:
# Set inputs
payload = random.randrange(256)
yield dut.num_allowed.eq(num_allowed)
Expand All @@ -571,6 +570,8 @@ def process():
self.assertEqual(payload, (yield dut.stream_out.payload))
self.assertEqual(running and input_valid,
(yield dut.stream_out.valid))

self.assertEqual(finished, (yield dut.finished))
yield

self.run_sim(process, False)
Expand Down
4 changes: 4 additions & 0 deletions proj/hps_accel/gateware/stream/fifo.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,16 @@ class StreamFifo(SimpleElaboratable):
Incoming stream of items
output: Endpoint(type), out
Outgoing stream of items
r_level: Signal(depth.bit_length())
Number of items available for reading
"""

def __init__(self, *, type, depth):
    """Construct a StreamFifo.

    Parameters
    ----------
    type:
        Payload type carried by the input and output stream endpoints.
        (Note: shadows the ``type`` builtin, per existing API.)
    depth:
        Maximum number of items the FIFO can hold.
    """
    self.depth = depth
    self.input = Endpoint(type)
    self.output = Endpoint(type)
    # Number of items currently available for reading; sized with
    # depth.bit_length() so it can represent values up to `depth`.
    self.r_level = Signal(depth.bit_length())

def elab(self, m: Module):
m.submodules.wrapped = fifo = SyncFIFOBuffered(
Expand All @@ -60,4 +63,5 @@ def elab(self, m: Module):
self.output.valid.eq(fifo.r_rdy),
self.output.payload.eq(fifo.r_data),
fifo.r_en.eq(self.output.ready),
self.r_level.eq(fifo.r_level),
]
Loading

0 comments on commit fd2124b

Please sign in to comment.