From f4e138bcb96fb7d68b099ea26fdf363c60e85d74 Mon Sep 17 00:00:00 2001 From: Marcelo Cantos <13160581+anzdaddy@users.noreply.github.com> Date: Fri, 29 Nov 2024 09:08:15 +1100 Subject: [PATCH] Faster add, etc. (#85) - Make Add faster. - Add checks that most operations never allocate. - Remove `-count=10` from profiler make rules. - Rename `cmp` to `cmp64` to avoid collision with standard package. - Rework test suite ops as a map of functions. - Remove `errorf` helper (no value-add). --- Makefile | 11 ++-- decimal64.go | 12 +--- decimal64_test.go | 10 ---- decimal64decParts.go | 54 ++++++++++++++---- decimal64math.go | 33 +++++++---- decimal64math_test.go | 22 +++++++ decimalSuite_test.go | 129 ++++++++++++++++++------------------------ uint128.go | 17 ++++-- uint128_test.go | 10 ++++ util_test.go | 26 ++++----- 10 files changed, 184 insertions(+), 140 deletions(-) diff --git a/Makefile b/Makefile index 2df5399..2a4ea8e 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,5 @@ .PHONY: all -all: test-all build-linux lint +all: test-all build-linux lint no-allocs .PHONY: ci ci: test-all no-allocs @@ -8,10 +8,13 @@ ci: test-all no-allocs test-all: test test-32 .PHONY: test -test: - go test $(GOTESTFLAGS) +test: test-release go test $(GOTESTFLAGS) -tags=decimal_debug +.PHONY: test-release +test-release: + go test $(GOTESTFLAGS) + .PHONY: test-32 test-32: if [ "$(shell go env GOOS)" = "linux" ]; then \ @@ -59,7 +62,7 @@ lint: build-linux .INTERMEDIATE: %.prof %.prof: $(wildcard *.go) - go test -$*profile $@ -count=10 $(GOPROFILEFLAGS) + go test -$*profile $@ $(GOPROFILEFLAGS) .PHONY: bench bench: bench.txt diff --git a/decimal64.go b/decimal64.go index 3641b9e..58f8204 100644 --- a/decimal64.go +++ b/decimal64.go @@ -222,8 +222,7 @@ func renormalize(exp int16, significand uint64) (int16, uint64) { } // roundStatus gives info about the truncated part of the significand that can't be fully stored in 16 decimal digits. 
-func roundStatus(significand uint64, exp int16, targetExp int16) discardedDigit { - expDiff := targetExp - exp +func roundStatus(significand uint64, expDiff int16) discardedDigit { if expDiff > 19 && significand != 0 { return lt5 } @@ -530,15 +529,6 @@ func (d Decimal64) Class() string { return "+Normal-Normal"[7*dp.sign : 7*(dp.sign+1)] } -// numDecimalDigitsU64 returns the magnitude (number of digits) of a uint64. -func numDecimalDigitsU64(n uint64) int16 { - numDigits := int16(bits.Len64(n) * 77 / 256) // ~ 3/10 - if n >= tenToThe[uint(numDigits)%uint(len(tenToThe))] { - numDigits++ - } - return numDigits -} - func checkNan(d, e Decimal64, dp, ep *decParts) (Decimal64, bool) { dp.fl = d.flavor() ep.fl = e.flavor() diff --git a/decimal64_test.go b/decimal64_test.go index 05efae2..fdb8023 100644 --- a/decimal64_test.go +++ b/decimal64_test.go @@ -315,16 +315,6 @@ func TestDecimal64isZero(t *testing.T) { check(t, !One64.IsZero()) } -func TestNumDecimalDigits(t *testing.T) { - t.Parallel() - - for i, num := range tenToThe { - for j := uint64(1); j < 10 && i < 19; j++ { - equal(t, i+1, int(numDecimalDigitsU64(num*j))) - } - } -} - func TestIsSubnormal(t *testing.T) { t.Parallel() diff --git a/decimal64decParts.go b/decimal64decParts.go index e722972..0f6f168 100644 --- a/decimal64decParts.go +++ b/decimal64decParts.go @@ -38,6 +38,38 @@ func (ans *decParts) add128(dp, ep *decParts) { } } +// add64 adds the low 64 bits of two decParts +func (ans *decParts) add64(dp, ep *decParts) { + ans.exp = dp.exp + switch { + case dp.sign == ep.sign: + ans.sign = dp.sign + ans.significand.lo = dp.significand.lo + ep.significand.lo + case dp.significand.lt(&ep.significand): + ans.sign = ep.sign + ans.significand.lo = ep.significand.lo - dp.significand.lo + case ep.significand.lt(&dp.significand): + ans.sign = dp.sign + ans.significand.lo = dp.significand.lo - ep.significand.lo + } +} + +// add128V2 adds two decParts with full precision in 128 bits of significand +func (ans 
*decParts) add128V2(dp, ep *decParts) { + ans.exp = dp.exp + switch { + case dp.sign == ep.sign: + ans.sign = dp.sign + ans.significand.add(&dp.significand, &ep.significand) + case dp.significand.lt(&ep.significand): + ans.sign = ep.sign + ans.significand.sub(&ep.significand, &dp.significand) + case ep.significand.lt(&dp.significand): + ans.sign = dp.sign + ans.significand.sub(&dp.significand, &ep.significand) + } +} + func (dp *decParts) matchScales128(ep *decParts) { expDiff := ep.exp - dp.exp if (ep.significand != uint128T{0, 0}) { @@ -56,10 +88,10 @@ func (dp *decParts) roundToLo() discardedDigit { if ds := &dp.significand; ds.hi > 0 || ds.lo >= 10*decimal64Base { var remainder uint64 - expDiff := ds.numDecimalDigits() - 16 + expDiff := int16(ds.numDecimalDigits()) - 16 dp.exp += expDiff remainder = ds.divrem64(ds, tenToThe[expDiff]) - rndStatus = roundStatus(remainder, 0, expDiff) + rndStatus = roundStatus(remainder, expDiff) } return rndStatus } @@ -74,7 +106,9 @@ func (dp *decParts) isSubnormal() bool { // separation gets the separation in decimal places of the MSD's of two decimal 64s func (dp *decParts) separation(ep *decParts) int16 { - return dp.significand.numDecimalDigits() + dp.exp - ep.significand.numDecimalDigits() - ep.exp + sep := int16(dp.significand.numDecimalDigits()) + dp.exp + sep -= int16(ep.significand.numDecimalDigits()) + ep.exp + return sep } // removeZeros removes zeros and increments the exponent to match. 
@@ -115,18 +149,17 @@ func (dp *decParts) isinf() bool { return dp.fl == flInf } -func (dp *decParts) rescale(targetExp int16) (rndStatus discardedDigit) { +func (dp *decParts) rescale(targetExp int16) discardedDigit { expDiff := targetExp - dp.exp - mag := dp.significand.numDecimalDigits() - rndStatus = roundStatus(dp.significand.lo, dp.exp, targetExp) - if expDiff > mag { + rndStatus := roundStatus(dp.significand.lo, expDiff) + if expDiff > int16(dp.significand.numDecimalDigits()) { dp.significand.lo, dp.exp = 0, targetExp - return + return rndStatus } divisor := tenToThe[expDiff] dp.significand.lo = dp.significand.lo / divisor dp.exp = targetExp - return + return rndStatus } func (dp *decParts) unpack(d Decimal64) { @@ -142,9 +175,6 @@ func (dp *decParts) unpackV2(d Decimal64) { // EE ∈ {00, 01, 10} dp.exp = int16((d.bits>>(63-10))&(1<<10-1)) - expOffset dp.significand.lo = d.bits & (1<<53 - 1) - if dp.significand.lo == 0 { - dp.exp = 0 - } case flNormal51: // s 11EEeeeeeeee (100)t tttttttttt tttttttttt tttttttttt tttttttttt tttttttttt // EE ∈ {00, 01, 10} diff --git a/decimal64math.go b/decimal64math.go index f00e516..b3342d7 100644 --- a/decimal64math.go +++ b/decimal64math.go @@ -61,7 +61,7 @@ func (d Decimal64) Cmp(e Decimal64) int { if _, nan := checkNan(d, e, &dp, &ep); nan { return -2 } - return cmp(d, e, &dp, &ep) + return cmp64(d, e, &dp, &ep) } // Cmp64 returns the same output as Cmp as a Decimal64, unless d or e is NaN, in @@ -71,7 +71,7 @@ func (d Decimal64) Cmp64(e Decimal64) Decimal64 { if nan, is := checkNan(d, e, &dp, &ep); is { return nan } - switch cmp(d, e, &dp, &ep) { + switch cmp64(d, e, &dp, &ep) { case -1: return NegOne64 case 1: @@ -81,9 +81,9 @@ func (d Decimal64) Cmp64(e Decimal64) Decimal64 { } } -func cmp(d, e Decimal64, dp, ep *decParts) int { +func cmp64(d, e Decimal64, dp, ep *decParts) int { switch { - case dp.isZero() && ep.isZero(), d == e: + case d == e, dp.isZero() && ep.isZero(): return 0 default: diff := d.Sub(e) @@ -112,7 
+112,7 @@ func (d Decimal64) min(e Decimal64, sign int) Decimal64 { switch { case !dnan && !enan: // Fast path for non-NaNs. - if sign*cmp(d, e, &dp, &ep) < 0 { + if sign*cmp64(d, e, &dp, &ep) < 0 { return d } return e @@ -152,7 +152,7 @@ func (d Decimal64) minMag(e Decimal64, sign int) Decimal64 { switch { case !dnan && !enan: // Fast path for non-NaNs. - switch sign * cmp(da, ea, &dp, &ep) { + switch sign * cmp64(da, ea, &dp, &ep) { case -1: return d case 1: @@ -334,7 +334,7 @@ func (ctx Context64) add(d, e Decimal64, dp, ep *decParts) Decimal64 { } else if ep.significand.lo == 0 { return d } - sep := dp.separation(ep) + sep := dp.exp - ep.exp if sep < -17 { return e @@ -348,13 +348,26 @@ func (ctx Context64) add(d, e Decimal64, dp, ep *decParts) Decimal64 { } var rndStatus discardedDigit var ans decParts - ans.add128(dp, ep) + switch { + case sep == 0: + ans.add64(dp, ep) + case sep < 4: + dp.significand.lo *= tenToThe[sep] + dp.exp -= sep + ans.add64(dp, ep) + default: + dp.significand.mul64(&dp.significand, tenToThe[17]) + dp.exp -= 17 + ep.significand.mul64(&ep.significand, tenToThe[17-sep]) + ep.exp -= 17 - sep + ans.add128V2(dp, ep) + } rndStatus = ans.roundToLo() if ans.exp < -expOffset { rndStatus = ans.rescale(-expOffset) } ans.significand.lo = ctx.Rounding.round(ans.significand.lo, rndStatus) - if ans.exp >= -expOffset && ans.significand.lo != 0 { + if ans.significand.lo != 0 { ans.exp, ans.significand.lo = renormalize(ans.exp, ans.significand.lo) } if ans.exp > expMax || ans.significand.lo > maxSig { @@ -365,7 +378,7 @@ func (ctx Context64) add(d, e Decimal64, dp, ep *decParts) Decimal64 { // Add computes d + e func (ctx Context64) Sub(d, e Decimal64) Decimal64 { - return d.Add(e.Neg()) + return d.Add(new64(neg64 ^ e.bits)) } // FMA computes d*e + f diff --git a/decimal64math_test.go b/decimal64math_test.go index 6f6e25e..498f4db 100644 --- a/decimal64math_test.go +++ b/decimal64math_test.go @@ -48,6 +48,28 @@ func TestDecimal64Add(t *testing.T) { 
func(a, b int64) int64 { return a + b }, func(a, b Decimal64) Decimal64 { return a.Add(b) }, ) + + add := func(a, b, expected string, ctx *Context64) func(*testing.T) { + return func(*testing.T) { + t.Helper() + + e := MustParse64(expected) + x := MustParse64(a) + y := MustParse64(b) + if ctx == nil { + ctx = &DefaultContext64 + } + replayOnFail(t, func() { + z := ctx.Add(x, y) + equalD64(t, e, z) + }) + } + } + + t.Run("tiny-neg", add("1E-383", "-1E-398", "9.99999999999999E-384", nil)) + + he := Context64{Rounding: HalfEven} + t.Run("round-even", add("12345678", "0.123456785", "12345678.12345678", &he)) } func TestDecimal64AddNaN(t *testing.T) { diff --git a/decimalSuite_test.go b/decimalSuite_test.go index d61c37b..462ab2b 100644 --- a/decimalSuite_test.go +++ b/decimalSuite_test.go @@ -265,88 +265,69 @@ func convertToDec64(testvals *testCase) (opResult, error) { } // runTest completes the tests and compares actual and expected results. -func runTest(t *testing.T, context Context64, expected opResult, testValStrings *testCase) bool { - actual := execOp(context, expected.val1, expected.val2, expected.val3, testValStrings.function) - - if actual.text != "" { - if testValStrings.function == "compare" && actual.text == "-2" && expected.result.IsNaN() { - return true - } - if actual.text != testValStrings.expectedResult { - t.Errorf("test:\n%s\ncalculated text: %s", testValStrings, actual.text) - return false - } - return true - } - if actual.result.IsNaN() || expected.result.IsNaN() { - e := expected.result.String() - a := actual.result.String() - if e != a { +func runTest(t *testing.T, context Context64, expected opResult, testValStrings *testCase) pass { + return replayOnFail(t, func() { + actual := execOp(context, expected.val1, expected.val2, expected.val3, testValStrings.function) + switch { + case actual.text != "": + if testValStrings.function == "compare" && actual.text == "-2" && expected.result.IsNaN() { + return + } + if actual.text != 
testValStrings.expectedResult { + t.Errorf("test:\n%s\ncalculated text: %s", testValStrings, actual.text) + } + case actual.result.IsNaN() || expected.result.IsNaN(): + e := expected.result.String() + a := actual.result.String() + if e != a { + t.Errorf("test:\n%s\ncalculated result: %v", testValStrings, actual.result) + } + case expected.result.Cmp(actual.result) != 0: t.Errorf("test:\n%s\ncalculated result: %v", testValStrings, actual.result) - return false } - return true - } - if expected.result.Cmp(actual.result) != 0 { - t.Errorf("test:\n%s\ncalculated result: %v", testValStrings, actual.result) - return false - } - return true + }) } var textResults = set{"class": {}} +var ops = map[string]func(ctx Context64, a, b, c Decimal64) any{ + "add": func(ctx Context64, a, b, c Decimal64) any { return ctx.Add(a, b) }, + "abs": func(ctx Context64, a, b, c Decimal64) any { return a.Abs() }, + "class": func(ctx Context64, a, b, c Decimal64) any { return a.Class() }, + "compare": func(ctx Context64, a, b, c Decimal64) any { return a.Cmp64(b) }, + "copysign": func(ctx Context64, a, b, c Decimal64) any { return a.CopySign(b) }, + "divide": func(ctx Context64, a, b, c Decimal64) any { return ctx.Quo(a, b) }, + "fma": func(ctx Context64, a, b, c Decimal64) any { return ctx.FMA(a, b, c) }, + "logb": func(ctx Context64, a, b, c Decimal64) any { return a.Logb() }, + "max": func(ctx Context64, a, b, c Decimal64) any { return a.Max(b) }, + "maxmag": func(ctx Context64, a, b, c Decimal64) any { return a.MaxMag(b) }, + "min": func(ctx Context64, a, b, c Decimal64) any { return a.Min(b) }, + "minmag": func(ctx Context64, a, b, c Decimal64) any { return a.MinMag(b) }, + "minus": func(ctx Context64, a, b, c Decimal64) any { return a.Neg() }, + "multiply": func(ctx Context64, a, b, c Decimal64) any { return ctx.Mul(a, b) }, + "nextminus": func(ctx Context64, a, b, c Decimal64) any { return a.NextMinus() }, + "nextplus": func(ctx Context64, a, b, c Decimal64) any { return a.NextPlus() 
}, + "plus": func(ctx Context64, a, b, c Decimal64) any { return a }, + "scaleb": func(ctx Context64, a, b, c Decimal64) any { return a.ScaleB(b) }, + "round": func(ctx Context64, a, b, c Decimal64) any { return ctx.Round(a, b) }, + "tointegralx": func(ctx Context64, a, b, c Decimal64) any { return ctx.ToIntegral(a) }, + "subtract": func(ctx Context64, a, b, c Decimal64) any { return ctx.Add(a, b.Neg()) }, + "squareroot": func(ctx Context64, a, b, c Decimal64) any { return a.Sqrt() }, + // "quantize": func(ctx Context64, a, b, c Decimal64) any { return ctx.Quantize(a, b) }, +} + // TODO: get runTest to run more functions such as FMA. // execOp returns the calculated answer to the operation as Decimal64. func execOp(ctx Context64, a, b, c Decimal64, op string) opResult { - switch op { - case "add": - return opResult{result: ctx.Add(a, b)} - case "abs": - return opResult{result: a.Abs()} - case "class": - return opResult{text: a.Class()} - case "compare": - return opResult{result: a.Cmp64(b)} - case "copysign": - return opResult{result: a.CopySign(b)} - case "divide": - return opResult{result: ctx.Quo(a, b)} - case "fma": - return opResult{result: ctx.FMA(a, b, c)} - case "logb": - return opResult{result: a.Logb()} - case "max": - return opResult{result: a.Max(b)} - case "maxmag": - return opResult{result: a.MaxMag(b)} - case "min": - return opResult{result: a.Min(b)} - case "minmag": - return opResult{result: a.MinMag(b)} - case "minus": - return opResult{result: a.Neg()} - case "multiply": - return opResult{result: ctx.Mul(a, b)} - case "nextminus": - return opResult{result: a.NextMinus()} - case "nextplus": - return opResult{result: a.NextPlus()} - case "plus": - return opResult{result: a} - // case "quantize": - // return opResult{result: ctx.Quantize(a, b)} - case "scaleb": - return opResult{result: a.ScaleB(b)} - case "round": - return opResult{result: ctx.Round(a, b)} - case "tointegralx": - return opResult{result: ctx.ToIntegral(a)} - case "subtract": - 
return opResult{result: ctx.Add(a, b.Neg())} - case "squareroot": - return opResult{result: a.Sqrt()} - default: - panic(fmt.Errorf("unhandled op: %s", op)) + if f, has := ops[op]; has { + switch a := f(ctx, a, b, c).(type) { + case string: + return opResult{text: a} + case Decimal64: + return opResult{result: a} + default: + panic("wat?") + } } + panic(fmt.Errorf("unhandled op: %s", op)) } diff --git a/uint128.go b/uint128.go index 1798641..47b73b5 100644 --- a/uint128.go +++ b/uint128.go @@ -10,16 +10,25 @@ type uint128T struct { lo, hi uint64 } -func (a *uint128T) numDecimalDigits() int16 { +func (a *uint128T) numDecimalDigits() int { if a.hi == 0 { return numDecimalDigitsU64(a.lo) } bitSize := 65 + bits.Len64(a.hi) - numDigitsEst := int16(bitSize * 77 / 256) - if !a.lt(&tenToThe128[uint(numDigitsEst)%uint(len(tenToThe128))]) { + numDigitsEst := uint(bitSize) * 77 / 256 + if !a.lt(&tenToThe128[numDigitsEst%uint(len(tenToThe128))]) { numDigitsEst++ } - return numDigitsEst + return int(numDigitsEst) +} + +// numDecimalDigitsU64 returns the magnitude (number of digits) of a uint64. 
+func numDecimalDigitsU64(n uint64) int { + numDigits := uint(bits.Len64(n)) * 77 / 256 // ~ 3/10 + if n >= tenToThe[numDigits%uint(len(tenToThe))] { + numDigits++ + } + return int(numDigits) } var tenToThe128 = func() [64]uint128T { diff --git a/uint128_test.go b/uint128_test.go index 2397a27..f6c6525 100644 --- a/uint128_test.go +++ b/uint128_test.go @@ -22,6 +22,16 @@ func TestUint128Shl(t *testing.T) { test(uint128T{0, 3}, uint128T{3, 42}, 64) } +func TestNumDecimalDigits(t *testing.T) { + t.Parallel() + + for i, num := range tenToThe { + for j := uint64(1); j < 10 && i < 19; j++ { + equal(t, i+1, int(numDecimalDigitsU64(num*j))) + } + } +} + func TestSqrtu64(t *testing.T) { t.Parallel() diff --git a/util_test.go b/util_test.go index ac2feb5..9d23652 100644 --- a/util_test.go +++ b/util_test.go @@ -2,24 +2,20 @@ package decimal import "testing" -func errorf(t *testing.T, format string, args ...any) { - t.Helper() - t.Errorf(format, args...) -} - -func replayOnFail(t *testing.T, f func()) { +func replayOnFail(t *testing.T, f func()) pass { t.Helper() alreadyFailed := t.Failed() defer func() { t.Helper() if r := recover(); r != nil { - errorf(t, "panic: %+v", r) + t.Errorf("panic: %+v", r) } if !alreadyFailed && t.Failed() { f() // Set a breakpoint here to replay the first failed test. 
} }() f() + return !pass(alreadyFailed && t.Failed()) } type pass bool @@ -33,7 +29,7 @@ func (p pass) Or(f func()) { func check(t *testing.T, ok bool) pass { t.Helper() if !ok { - errorf(t, "expected true") + t.Errorf("expected true") return false } return true @@ -42,7 +38,7 @@ func check(t *testing.T, ok bool) pass { func epsilon(t *testing.T, a, b float64) pass { t.Helper() if a/b-1 > 0.00000001 { - errorf(t, "%f and %f too dissimilar", a, b) + t.Errorf("%f and %f too dissimilar", a, b) return false } return true @@ -51,7 +47,7 @@ func epsilon(t *testing.T, a, b float64) pass { func equal[T comparable](t *testing.T, a, b T) pass { t.Helper() if a != b { - errorf(t, "expected %+v, got %+v", a, b) + t.Errorf("expected %+v, got %+v", a, b) return false } return true @@ -65,7 +61,7 @@ func equalD64(t *testing.T, expected, actual Decimal64) pass { func isnil(t *testing.T, a any) pass { t.Helper() if a != nil { - errorf(t, "expected nil, got %+v", a) + t.Errorf("expected nil, got %+v", a) return false } return true @@ -76,7 +72,7 @@ func nopanic(t *testing.T, f func()) (b pass) { defer func() { t.Helper() if r := recover(); r != nil { - errorf(t, "panic: %+v", r) + t.Errorf("panic: %+v", r) b = false } }() @@ -87,7 +83,7 @@ func nopanic(t *testing.T, f func()) (b pass) { func notequal[T comparable](t *testing.T, a, b T) pass { t.Helper() if a == b { - errorf(t, "equal values %+v", a) + t.Errorf("equal values %+v", a) return false } return true @@ -96,7 +92,7 @@ func notequal[T comparable](t *testing.T, a, b T) pass { func notnil(t *testing.T, a any) pass { t.Helper() if a == nil { - errorf(t, "expected non-nil") + t.Errorf("expected non-nil") return false } return true @@ -107,7 +103,7 @@ func panics(t *testing.T, f func()) (b pass) { defer func() { t.Helper() if r := recover(); r == nil { - errorf(t, "expected panic") + t.Errorf("expected panic") b = false } }()