Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement tanh activation function #55

Open
wants to merge 1 commit into
base: develop
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions nngen/onnx/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
'Relu': act_func.Relu,
'LeakyRelu': act_func.LeakyRelu,
'Sigmoid': act_func.Sigmoid,
'Tanh': act_func.Tanh,
'BatchNormalization': batchnormalization.BatchNormalization,
'Shape': shape.Shape,
'Reshape': reshape.Reshape,
Expand Down
3 changes: 3 additions & 0 deletions nngen/onnx/act_func.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,6 @@ def LeakyRelu(visitor, node):

def Sigmoid(visitor, node):
    # Map an ONNX Sigmoid node onto nngen's sigmoid operator via the
    # shared activation-function visitor helper.
    return _act_func(operator.sigmoid, visitor, node)

def Tanh(visitor, node):
    # Map an ONNX Tanh node onto nngen's tanh operator via the
    # shared activation-function visitor helper.
    return _act_func(operator.tanh, visitor, node)
1 change: 1 addition & 0 deletions nngen/operator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from .relu import relu, relu6
from .leaky_relu import leaky_relu, get_leaky_relu_op, leaky_relu_base
from .sigmoid import sigmoid
from .tanh import tanh
from .matmul import matmul
from .conv2d import conv2d
from .log_weight_conv2d import log_weight_conv2d
Expand Down
118 changes: 118 additions & 0 deletions nngen/operator/tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import functools
import math
import numpy as np
from collections import OrderedDict

import nngen.basic_types as bt
from nngen.quantizer import util

class tanh(bt._ActFuncOperator):
    """LUT-based hyperbolic tangent activation operator.

    The input is re-scaled in hardware so that the signed LUT address range
    [-2**(lut_addrwidth-1), 2**(lut_addrwidth-1)) covers the real interval
    [-lut_clip, lut_clip]. Inputs beyond that range saturate to
    +/- round(full_scale * range_rate).
    """

    def __init__(self, features,
                 lut_addrwidth=8, lut_clip=6.0, range_rate=0.95,
                 dtype=None, name=None, par=1):

        shape = None
        # A narrow input type cannot address a larger LUT, so shrink it.
        if features.dtype is not None and features.dtype.width < 8:
            lut_addrwidth = features.dtype.width

        self.lut_addrwidth = lut_addrwidth
        self.lut_clip = lut_clip
        self.range_rate = range_rate
        bt._ActFuncOperator.__init__(self, features,
                                     dtype=dtype, shape=shape, name=name, par=par)

    def _get_expected_scale_factor(self):
        # LUT index units per unit of real input value.
        return (2 ** (self.lut_addrwidth - 1)) / self.lut_clip

    def _get_features_scale_shamt(self):
        # Quantize (expected_scale / input_scale) into an integer multiplier
        # plus an arithmetic right-shift amount applied in the stream.
        expected_scale_factor = self._get_expected_scale_factor()

        features_scale = np.array([expected_scale_factor / self.args[0].scale_factor])
        q_features_scale, scale_factor = util.quantize_linear_scale(features_scale, 32)
        q_features_scale = int(q_features_scale[0])
        q_features_shamt = round(math.log(scale_factor, 2))
        return q_features_scale, q_features_shamt

    def get_local_control_param_values(self):
        q_features_scale, q_features_shamt = self._get_features_scale_shamt()
        return OrderedDict([('features_scale_cparam', q_features_scale),
                            ('features_shamt_cparam', q_features_shamt)])

    def get_stream_hash(self):
        base = bt._ActFuncOperator.get_stream_hash(self)
        return (base, self.lut_addrwidth, self.lut_clip, self.range_rate)

    def op(self, strm, *args, **kwargs):
        features_signed = self.args[0].get_signed()

        features_scale = strm.ReinterpretCast(self.features_scale_cparam,
                                              width=self.features_scale_cparam.width + 1,
                                              signed=features_signed)
        mul = strm.Times(args[0], features_scale)
        mul.width = mul.width + features_scale.width

        features_shamt = strm.ReinterpretCast(self.features_shamt_cparam,
                                              width=self.features_shamt_cparam.width,
                                              signed=False)
        sra = strm.Sra(mul, features_shamt)
        # Low lut_addrwidth bits of the shifted value; negative values wrap
        # to the upper half of the address space (two's complement).
        lut_addr = strm.Slice(sra, self.lut_addrwidth - 1, 0)

        out_width = self.dtype.width
        out_point = self.dtype.point
        out_signed = self.dtype.signed
        if out_signed:
            out_scale = round((2 ** (out_width - 1)) * self.range_rate)
        else:
            out_scale = round((2 ** out_width) * self.range_rate)

        def _tanh(x):
            return int((np.tanh(x) * out_scale).astype(np.int64))

        addr_scale = 1 / self._get_expected_scale_factor()
        # Addresses 0 .. 2**(w-1)-1 hold tanh of non-negative inputs; the
        # reversed negative patterns fill 2**(w-1) .. 2**w-1 so that the
        # two's-complement wrap of a negative address hits the right entry.
        patterns_p = [_tanh(i * addr_scale)
                      for i in range(2 ** (self.lut_addrwidth - 1))]
        patterns_n = [_tanh((-i - 1) * addr_scale)
                      for i in range(2 ** (self.lut_addrwidth - 1))]
        patterns_n.reverse()

        patterns = patterns_p + patterns_n

        lut = strm.LUT(lut_addr, patterns, out_width, out_point, out_signed)

        p_th = 2 ** (self.lut_addrwidth - 1) - 1
        n_th = -1 * p_th

        if out_point == 0:
            th_scale = out_scale
        elif out_point > 0:
            th_scale = out_scale >> out_point
        else:
            th_scale = out_scale << (-1 * out_point)

        p = strm.Mux(sra > p_th, th_scale, lut)
        # BUGFIX: tanh(-inf) -> -1, so inputs below the LUT range must
        # saturate to -th_scale. Saturating to 0 was a copy of sigmoid's
        # negative-side rule (sigmoid(-inf) -> 0) and is wrong for tanh.
        n = strm.Mux(sra < n_th, -th_scale, lut)
        out = strm.Mux(sra >= 0, p, n)

        return out

    def get_eval_method(self):
        # Bind this instance's LUT/quantization parameters into the
        # software model (nngen.verify.tanh) used for result checking.
        import nngen.verify as verify

        name = self.__class__.__name__
        method = getattr(verify, name, None)

        features_scale, features_shamt = self._get_features_scale_shamt()

        method = functools.partial(method,
                                   lut_addrwidth=self.lut_addrwidth,
                                   lut_clip=self.lut_clip,
                                   range_rate=self.range_rate,
                                   features_dtype=self.args[0].dtype,
                                   features_scale=features_scale,
                                   features_shamt=features_shamt)
        return method
2 changes: 2 additions & 0 deletions nngen/quantizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from . import matmul
from . import normalize
from . import sigmoid
from . import tanh
from . import exp
from . import reduce

Expand All @@ -26,6 +27,7 @@
'scaled_multiply': normalize.scaled_multiply,
'scaled_div': normalize.scaled_div,
'sigmoid': sigmoid.sigmoid,
'tanh': tanh.tanh,
'exp': exp.exp,
'argmax': reduce.argmax,
'argmin': reduce.argmin,
Expand Down
8 changes: 8 additions & 0 deletions nngen/quantizer/tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

from . import sigmoid

def tanh(visitor, node):
    """Quantize a tanh operator by reusing the sigmoid quantization rule.

    Both operators share the same LUT-based scale parameters, so the
    sigmoid quantizer applies unchanged.
    """
    # BUGFIX: `sigmoid` (from `from . import sigmoid`) is the MODULE, so
    # calling it directly raises TypeError. The quantization entry point is
    # the module's `sigmoid` function (registered as `sigmoid.sigmoid` in
    # quantizer/__init__.py).
    sigmoid.sigmoid(visitor, node)
51 changes: 51 additions & 0 deletions nngen/verify/tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
from __future__ import absolute_import
from __future__ import print_function
from __future__ import division

import numpy as np


def tanh(features,
         lut_addrwidth=8, lut_clip=6.0, range_rate=0.95,
         dtype=None, name=None, par=1,
         features_dtype=None, features_scale=1, features_shamt=0):
    """Software reference model of the LUT-based hardware tanh operator.

    Args:
        features: integer ndarray of pre-activation values.
        lut_addrwidth: LUT address width; the signed address range
            [-2**(lut_addrwidth-1), 2**(lut_addrwidth-1)) maps to
            [-lut_clip, lut_clip].
        lut_clip: real input magnitude covered by the LUT.
        range_rate: fraction of the output full scale used for +/-1.
        dtype: output dtype object providing width/point/signed (required).
        name, par: accepted for interface parity with the operator; unused.
        features_dtype: accepted for interface parity; unused here.
        features_scale, features_shamt: quantized input scale and the
            arithmetic right-shift applied before LUT addressing.

    Returns:
        int64 ndarray of quantized tanh values.

    Raises:
        ValueError: if dtype is None.
    """
    if dtype is None:
        raise ValueError('tanh requires dtype to determine the value range.')

    # Emulate the hardware input scaling: multiply then arithmetic shift.
    mul = features * features_scale
    sra = mul >> features_shamt

    out_width = dtype.width
    out_point = dtype.point
    out_signed = dtype.signed
    if out_signed:
        out_scale = round((2 ** (out_width - 1)) * range_rate)
    else:
        out_scale = round((2 ** out_width) * range_rate)

    def _tanh(x):
        return (np.tanh(x) * out_scale).astype(np.int64)

    # Real input value represented by one LUT address step.
    addr_scale = lut_clip / (2 ** (lut_addrwidth - 1))
    lut = _tanh(sra * addr_scale)

    p_th = 2 ** (lut_addrwidth - 1) - 1
    n_th = -1 * p_th

    if out_point == 0:
        th_scale = out_scale
    elif out_point > 0:
        th_scale = out_scale >> out_point
    else:
        th_scale = out_scale << (-1 * out_point)

    p = np.where(sra > p_th, th_scale, lut)
    # BUGFIX: tanh(-inf) -> -1 (unlike sigmoid's 0), so inputs below the
    # LUT range saturate to -th_scale rather than 0.
    n = np.where(sra < n_th, -th_scale, lut)
    out = np.where(sra >= 0, p, n)

    return out
88 changes: 88 additions & 0 deletions tests/matrix_conv2d/test_matrix_conv2d_int16_3x3_stride1_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import absolute_import
from __future__ import print_function

import os
import sys

# the next line can be removed after installation
sys.path.insert(0, os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))))

import nngen as ng
import veriloggen

import matrix_conv2d


# Test configuration: 3x3 conv2d, stride 1, int16 data, tanh activation.
act_shape = (1, 7, 7, 15)  # activation tensor shape (presumably NHWC -- confirm)
weight_shape = (7, 3, 3, 15)  # 7 output channels, 3x3 kernel, 15 input channels
bias_shape = None  # no bias term
scale_shape = None  # no per-channel scale
act_dtype = ng.int16
weight_dtype = ng.int16
bias_dtype = ng.int32
scale_dtype = ng.int16
out_dtype = ng.int16
stride = (1, 1, 1, 1)
rshift_mul = None  # use default right-shift amounts
rshift_sum = None
rshift_out = None
act_func = ng.tanh  # the activation under test
par_ich = 1  # no parallelization
par_och = 1
par_col = 1
par_row = 1
concur_och = None
stationary = 'filter'
input_ram_size = None  # use default RAM sizes
filter_ram_size = None
bias_ram_size = None
scale_ram_size = None
out_ram_size = None
axi_datawidth = 32


def test(request, silent=True):
    # pytest entry point: build and simulate the conv2d + tanh design,
    # then check the simulator's verification verdict.
    veriloggen.reset()

    # simulator backend is chosen via the pytest '--sim' option
    simtype = request.config.getoption('--sim')

    rslt = matrix_conv2d.run(act_shape, weight_shape,
                             bias_shape, scale_shape,
                             act_dtype, weight_dtype,
                             bias_dtype, scale_dtype,
                             out_dtype,
                             stride,
                             rshift_mul, rshift_sum, rshift_out,
                             act_func,
                             par_ich, par_och, par_col, par_row,
                             concur_och, stationary,
                             input_ram_size, filter_ram_size,
                             bias_ram_size, scale_ram_size,
                             out_ram_size,
                             axi_datawidth, silent,
                             filename=None, simtype=simtype,
                             outputfile=os.path.splitext(os.path.basename(__file__))[0] + '.out')

    # the runner prints '# verify: PASSED' as its last line on success
    verify_rslt = rslt.splitlines()[-1]
    assert(verify_rslt == '# verify: PASSED')


if __name__ == '__main__':
    # Standalone run: emit Verilog to 'tmp.v' and print the full report.
    rslt = matrix_conv2d.run(act_shape, weight_shape,
                             bias_shape, scale_shape,
                             act_dtype, weight_dtype,
                             bias_dtype, scale_dtype,
                             out_dtype,
                             stride,
                             rshift_mul, rshift_sum, rshift_out,
                             act_func,
                             par_ich, par_och, par_col, par_row,
                             concur_och, stationary,
                             input_ram_size, filter_ram_size,
                             bias_ram_size, scale_ram_size,
                             out_ram_size,
                             axi_datawidth, silent=False,
                             filename='tmp.v',
                             outputfile=os.path.splitext(os.path.basename(__file__))[0] + '.out')
    print(rslt)
88 changes: 88 additions & 0 deletions tests/matrix_conv2d/test_matrix_conv2d_int32_3x3_stride1_tanh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
from __future__ import absolute_import
from __future__ import print_function

import os
import sys

# the next line can be removed after installation
sys.path.insert(0, os.path.dirname(os.path.dirname(
os.path.dirname(os.path.abspath(__file__)))))

import nngen as ng
import veriloggen

import matrix_conv2d


# Test configuration: 3x3 conv2d, stride 1, int32 data, tanh activation.
act_shape = (1, 7, 7, 15)  # activation tensor shape (presumably NHWC -- confirm)
weight_shape = (7, 3, 3, 15)  # 7 output channels, 3x3 kernel, 15 input channels
bias_shape = None  # no bias term
scale_shape = None  # no per-channel scale
act_dtype = ng.int32
weight_dtype = ng.int32
bias_dtype = ng.int32
scale_dtype = ng.int32
out_dtype = ng.int32
stride = (1, 1, 1, 1)
rshift_mul = None  # use default right-shift amounts
rshift_sum = None
rshift_out = None
act_func = ng.tanh  # the activation under test
par_ich = 1  # no parallelization
par_och = 1
par_col = 1
par_row = 1
concur_och = None
stationary = 'filter'
input_ram_size = None  # use default RAM sizes
filter_ram_size = None
bias_ram_size = None
scale_ram_size = None
out_ram_size = None
axi_datawidth = 32


def test(request, silent=True):
    # pytest entry point: build and simulate the conv2d + tanh design,
    # then check the simulator's verification verdict.
    veriloggen.reset()

    # simulator backend is chosen via the pytest '--sim' option
    simtype = request.config.getoption('--sim')

    rslt = matrix_conv2d.run(act_shape, weight_shape,
                             bias_shape, scale_shape,
                             act_dtype, weight_dtype,
                             bias_dtype, scale_dtype,
                             out_dtype,
                             stride,
                             rshift_mul, rshift_sum, rshift_out,
                             act_func,
                             par_ich, par_och, par_col, par_row,
                             concur_och, stationary,
                             input_ram_size, filter_ram_size,
                             bias_ram_size, scale_ram_size,
                             out_ram_size,
                             axi_datawidth, silent,
                             filename=None, simtype=simtype,
                             outputfile=os.path.splitext(os.path.basename(__file__))[0] + '.out')

    # the runner prints '# verify: PASSED' as its last line on success
    verify_rslt = rslt.splitlines()[-1]
    assert(verify_rslt == '# verify: PASSED')


if __name__ == '__main__':
    # Standalone run: emit Verilog to 'tmp.v' and print the full report.
    rslt = matrix_conv2d.run(act_shape, weight_shape,
                             bias_shape, scale_shape,
                             act_dtype, weight_dtype,
                             bias_dtype, scale_dtype,
                             out_dtype,
                             stride,
                             rshift_mul, rshift_sum, rshift_out,
                             act_func,
                             par_ich, par_och, par_col, par_row,
                             concur_och, stationary,
                             input_ram_size, filter_ram_size,
                             bias_ram_size, scale_ram_size,
                             out_ram_size,
                             axi_datawidth, silent=False,
                             filename='tmp.v',
                             outputfile=os.path.splitext(os.path.basename(__file__))[0] + '.out')
    print(rslt)
Loading