diff --git a/breaker.go b/breaker.go index 1309f6d..7184453 100644 --- a/breaker.go +++ b/breaker.go @@ -27,13 +27,13 @@ const ( ) // State represents the internal state of CB. -type State string +type State[T any] string // State constants. const ( - StateClosed State = "closed" - StateOpen State = "open" - StateHalfOpen State = "half-open" + StateClosed State[any] = "closed" + StateOpen State[any] = "open" + StateHalfOpen State[any] = "half-open" ) // DefaultOpenBackOff returns defaultly used BackOff. @@ -77,7 +77,7 @@ func (c *Counters) incrementFailures() { } // StateChangeHook is a function which will be invoked when the state is changed. -type StateChangeHook func(oldState, newState State) +type StateChangeHook[T any] func(oldState, newState State[T]) // TripFunc is a function to determine if CircuitBreaker should open (trip) or // not. TripFunc is called when cb.Fail was called and the state was @@ -149,7 +149,7 @@ func MarkAsSuccess(err error) error { } // Options holds CircuitBreaker configuration options. -type options struct { +type options[T any] struct { // Clock to be used by CircuitBreaker. If nil, real-time clock is // used. clock clock.Clock @@ -191,7 +191,7 @@ type options struct { shouldTrip TripFunc // OnStateChange is a function which will be invoked when the state is changed. - onStateChange StateChangeHook + onStateChange StateChangeHook[T] // FailOnContextCancel controls if CircuitBreaker mark an error when the // passed context.Done() is context.Canceled as a fail. @@ -203,97 +203,97 @@ type options struct { } // CircuitBreaker provides circuit breaker pattern. -type CircuitBreaker struct { +type CircuitBreaker[T any] struct { clock clock.Clock interval time.Duration halfOpenMaxSuccesses int64 openBackOff backoff.BackOff shouldTrip TripFunc - onStateChange StateChangeHook + onStateChange StateChangeHook[T] failOnContextCancel bool failOnContextDeadline bool mu sync.RWMutex - state state + state state[T] cnt Counters } -type fnApplyOptions func(*options) +type fnApplyOptions[T any] func(*options[T]) // BreakerOption interface for applying configuration in the constructor -type BreakerOption interface { - apply(*options) +type BreakerOption[T any] interface { + apply(*options[T]) } -func (f fnApplyOptions) apply(options *options) { +func (f fnApplyOptions[T]) apply(options *options[T]) { f(options) } // WithTripFunc Set the function for counter -func WithTripFunc(tripFunc TripFunc) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithTripFunc[T any](tripFunc TripFunc) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.shouldTrip = tripFunc }) } // WithClock Set the clock -func WithClock(clock clock.Clock) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithClock[T any](clock clock.Clock) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.clock = clock }) } // WithOpenTimeoutBackOff Set the time backoff -func WithOpenTimeoutBackOff(backoff backoff.BackOff) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithOpenTimeoutBackOff[T any](backoff backoff.BackOff) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.openBackOff = backoff }) } // WithOpenTimeout Set the timeout of the circuit breaker -func WithOpenTimeout(timeout time.Duration) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithOpenTimeout[T any](timeout time.Duration) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.openTimeout = timeout }) } // WithHalfOpenMaxSuccesses Set the number of half open successes -func WithHalfOpenMaxSuccesses(maxSuccesses int64) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithHalfOpenMaxSuccesses[T any](maxSuccesses int64) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.halfOpenMaxSuccesses = maxSuccesses }) } // WithCounterResetInterval Set the interval of the circuit breaker, which is the cyclic time period to reset the internal counters -func WithCounterResetInterval(interval time.Duration) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithCounterResetInterval[T any](interval time.Duration) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.interval = interval }) } // WithFailOnContextCancel Set if the context should fail on cancel -func WithFailOnContextCancel(failOnContextCancel bool) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithFailOnContextCancel[T any](failOnContextCancel bool) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.failOnContextCancel = failOnContextCancel }) } // WithFailOnContextDeadline Set if the context should fail on deadline -func WithFailOnContextDeadline(failOnContextDeadline bool) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithFailOnContextDeadline[T any](failOnContextDeadline bool) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.failOnContextDeadline = failOnContextDeadline }) } // WithOnStateChangeHookFn set a hook function that trigger if the condition of the StateChangeHook is true -func WithOnStateChangeHookFn(hookFn StateChangeHook) BreakerOption { - return fnApplyOptions(func(options *options) { +func WithOnStateChangeHookFn[T any](hookFn StateChangeHook[T]) BreakerOption[T] { + return fnApplyOptions[T](func(options *options[T]) { options.onStateChange = hookFn }) } -func defaultOptions() *options { - return &options{ +func defaultOptions[T any]() *options[T] { + return &options[T]{ shouldTrip: DefaultTripFunc, clock: clock.New(), openBackOff: DefaultOpenBackOff(), @@ -324,8 +324,8 @@ func defaultOptions() *options { // ) // // The default options are described in the defaultOptions function -func New(opts ...BreakerOption) *CircuitBreaker { - cbOptions := defaultOptions() +func New[T any](opts ...BreakerOption[T]) *CircuitBreaker[T] { + cbOptions := defaultOptions[T]() for _, opt := range opts { opt.apply(cbOptions) @@ -335,7 +335,7 @@ func New(opts ...BreakerOption) *CircuitBreaker { cbOptions.openBackOff = backoff.NewConstantBackOff(cbOptions.openTimeout) } - cb := &CircuitBreaker{ + cb := &CircuitBreaker[T]{ shouldTrip: cbOptions.shouldTrip, onStateChange: cbOptions.onStateChange, clock: cbOptions.clock, @@ -345,12 +345,12 @@ func New(opts ...BreakerOption) *CircuitBreaker { failOnContextCancel: cbOptions.failOnContextCancel, failOnContextDeadline: cbOptions.failOnContextDeadline, } - cb.setState(&stateClosed{}) + cb.setState(&stateClosed[T]{}) return cb } // An Operation is executed by Do(). -type Operation func() (interface{}, error) +type Operation[T any] func() (T, error) // Do executes the Operation o and returns the return values if // cb.Ready() is true. If not ready, cb doesn't execute f and returns @@ -372,17 +372,18 @@ type Operation func() (interface{}, error) // If given Options' FailOnContextDeadline is false (default), cb.Do // doesn't mark the Operation's error as a failure if ctx.Err() returns // context.DeadlineExceeded. -func (cb *CircuitBreaker) Do(ctx context.Context, o Operation) (interface{}, error) { +func (cb *CircuitBreaker[T]) Do(ctx context.Context, o Operation[T]) (T, error) { + var ret T if !cb.Ready() { - return nil, ErrOpen + return ret, ErrOpen } - result, err := o() - return result, cb.Done(ctx, err) + ret, err := o() + return ret, cb.Done(ctx, err) } // Ready reports if cb is ready to execute an operation. Ready does not give // any change to cb. -func (cb *CircuitBreaker) Ready() bool { +func (cb *CircuitBreaker[T]) Ready() bool { cb.mu.RLock() defer cb.mu.RUnlock() return cb.state.ready(cb) @@ -390,7 +391,7 @@ func (cb *CircuitBreaker) Ready() bool { // Success signals that an execution of operation has been completed // successfully to cb. -func (cb *CircuitBreaker) Success() { +func (cb *CircuitBreaker[T]) Success() { cb.mu.Lock() defer cb.mu.Unlock() cb.cnt.incrementSuccesses() @@ -398,7 +399,7 @@ func (cb *CircuitBreaker) Success() { } // Fail signals that an execution of operation has been failed to cb. -func (cb *CircuitBreaker) Fail() { +func (cb *CircuitBreaker[T]) Fail() { cb.mu.Lock() defer cb.mu.Unlock() cb.cnt.incrementFailures() @@ -409,7 +410,7 @@ func (cb *CircuitBreaker) Fail() { // and ctx is done with context.Canceled error, no Fail() called. Similarly, if // FailOnContextDeadline is false and ctx is done with context.DeadlineExceeded // error, no Fail() called. -func (cb *CircuitBreaker) FailWithContext(ctx context.Context) { +func (cb *CircuitBreaker[T]) FailWithContext(ctx context.Context) { if ctxErr := ctx.Err(); ctxErr != nil { if ctxErr == context.Canceled && !cb.failOnContextCancel { return @@ -425,7 +426,7 @@ func (cb *CircuitBreaker) FailWithContext(ctx context.Context) { // Done calls Success and returns nil. If err is a SuccessMarkableError or // IgnorableError, Done returns wrapped error. Otherwise, Done calls // FailWithContext internally. -func (cb *CircuitBreaker) Done(ctx context.Context, err error) error { +func (cb *CircuitBreaker[T]) Done(ctx context.Context, err error) error { if err == nil { cb.Success() return nil @@ -445,7 +446,7 @@ func (cb *CircuitBreaker) Done(ctx context.Context, err error) error { } // State reports the curent State of cb. -func (cb *CircuitBreaker) State() State { +func (cb *CircuitBreaker[T]) State() State[T] { cb.mu.Lock() defer cb.mu.Unlock() return cb.state.State() @@ -453,41 +454,41 @@ func (cb *CircuitBreaker) State() State { // Counters returns internal counters. If current status is not // StateClosed, returns zero value. -func (cb *CircuitBreaker) Counters() Counters { +func (cb *CircuitBreaker[T]) Counters() Counters { cb.mu.Lock() defer cb.mu.Unlock() return cb.cnt } // Reset resets cb's state with StateClosed. -func (cb *CircuitBreaker) Reset() { +func (cb *CircuitBreaker[T]) Reset() { cb.mu.Lock() defer cb.mu.Unlock() cb.cnt.reset() - cb.setState(&stateClosed{}) + cb.setState(&stateClosed[T]{}) } // SetState set state of cb to st. -func (cb *CircuitBreaker) SetState(st State) { +func (cb *CircuitBreaker[T]) SetState(st State[T]) { switch st { - case StateClosed: - cb.setStateWithLock(&stateClosed{}) - case StateOpen: - cb.setStateWithLock(&stateOpen{}) - case StateHalfOpen: - cb.setStateWithLock(&stateHalfOpen{}) + case State[T](StateClosed): + cb.setStateWithLock(&stateClosed[T]{}) + case State[T](StateOpen): + cb.setStateWithLock(&stateOpen[T]{}) + case State[T](StateHalfOpen): + cb.setStateWithLock(&stateHalfOpen[T]{}) default: panic("undefined state") } } -func (cb *CircuitBreaker) setStateWithLock(s state) { +func (cb *CircuitBreaker[T]) setStateWithLock(s state[T]) { cb.mu.Lock() defer cb.mu.Unlock() cb.setState(s) } -func (cb *CircuitBreaker) setState(s state) { +func (cb *CircuitBreaker[T]) setState(s state[T]) { if cb.state != nil { cb.state.onExit(cb) } @@ -497,7 +498,7 @@ func (cb *CircuitBreaker) setState(s state) { cb.handleOnStateChange(from, s) } -func (cb *CircuitBreaker) handleOnStateChange(from, to state) { +func (cb *CircuitBreaker[T]) handleOnStateChange(from, to state[T]) { if from == nil || cb.onStateChange == nil { return } diff --git a/breaker_test.go b/breaker_test.go index 012f901..bd87826 100644 --- a/breaker_test.go +++ b/breaker_test.go @@ -9,7 +9,7 @@ import ( "time" "github.com/benbjohnson/clock" - "github.com/mercari/go-circuitbreaker" + "github.com/mercari/go-circuitbreaker/v2" "github.com/stretchr/testify/assert" ) @@ -23,7 +23,7 @@ func fetchUserInfo(ctx context.Context, name string) (*user, error) { } func ExampleCircuitBreaker() { - cb := circuitbreaker.New(nil) + cb := circuitbreaker.New[any](nil) ctx := context.Background() data, err := cb.Do(context.Background(), func() (interface{}, error) { @@ -44,40 +44,40 @@ func ExampleCircuitBreaker() { func TestDo(t *testing.T) { t.Run("success", func(t *testing.T) { - cb := circuitbreaker.New() - got, err := cb.Do(context.Background(), func() (interface{}, error) { + cb := circuitbreaker.New[string]() + got, err := cb.Do(context.Background(), func() (string, error) { return "data", nil }) assert.NoError(t, err) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) assert.Equal(t, int64(0), cb.Counters().Failures) }) t.Run("error", func(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() wantErr := errors.New("something happens") - got, err := cb.Do(context.Background(), func() (interface{}, error) { + got, err := cb.Do(context.Background(), func() (string, error) { return "data", wantErr }) assert.Equal(t, err, wantErr) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) assert.Equal(t, int64(1), cb.Counters().Failures) }) t.Run("ignore", func(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() wantErr := errors.New("something happens") - got, err := cb.Do(context.Background(), func() (interface{}, error) { return "data", circuitbreaker.Ignore(wantErr) }) + got, err := cb.Do(context.Background(), func() (string, error) { return "data", circuitbreaker.Ignore(wantErr) }) assert.Equal(t, err, wantErr) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) assert.Equal(t, int64(0), cb.Counters().Failures) }) t.Run("markassuccess", func(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() wantErr := errors.New("something happens") - got, err := cb.Do(context.Background(), func() (interface{}, error) { return "data", circuitbreaker.MarkAsSuccess(wantErr) }) + got, err := cb.Do(context.Background(), func() (string, error) { return "data", circuitbreaker.MarkAsSuccess(wantErr) }) assert.Equal(t, err, wantErr) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) assert.Equal(t, int64(0), cb.Counters().Failures) }) @@ -92,15 +92,15 @@ func TestDo(t *testing.T) { for _, test := range tests { cancelErr := errors.New("context's Done channel closed.") t.Run(fmt.Sprintf("FailOnContextCanceled=%t", test.FailOnContextCancel), func(t *testing.T) { - cb := circuitbreaker.New(circuitbreaker.WithFailOnContextCancel(test.FailOnContextCancel)) + cb := circuitbreaker.New[string](circuitbreaker.WithFailOnContextCancel[string](test.FailOnContextCancel)) ctx, cancel := context.WithCancel(context.Background()) cancel() - got, err := cb.Do(ctx, func() (interface{}, error) { + got, err := cb.Do(ctx, func() (string, error) { <-ctx.Done() return "", cancelErr }) assert.Equal(t, err, cancelErr) - assert.Equal(t, "", got.(string)) + assert.Equal(t, "", got) assert.Equal(t, test.ExpectedFailures, cb.Counters().Failures) }) } @@ -117,15 +117,15 @@ func TestDo(t *testing.T) { for _, test := range tests { timeoutErr := errors.New("context's Done channel closed") t.Run(fmt.Sprintf("FailOnContextDeadline=%t", test.FailOnContextDeadline), func(t *testing.T) { - cb := circuitbreaker.New(circuitbreaker.WithFailOnContextDeadline(test.FailOnContextDeadline)) + cb := circuitbreaker.New[string](circuitbreaker.WithFailOnContextDeadline[string](test.FailOnContextDeadline)) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) defer cancel() - got, err := cb.Do(ctx, func() (interface{}, error) { + got, err := cb.Do(ctx, func() (string, error) { <-ctx.Done() return "", timeoutErr }) assert.Equal(t, err, timeoutErr) - assert.Equal(t, "", got.(string)) + assert.Equal(t, "", got) assert.Equal(t, test.ExpectedFailures, cb.Counters().Failures) }) } @@ -133,10 +133,10 @@ func TestDo(t *testing.T) { t.Run("cyclic-state-transition", func(t *testing.T) { clkMock := clock.NewMock() - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clkMock), - circuitbreaker.WithOpenTimeout(1000*time.Millisecond), - circuitbreaker.WithHalfOpenMaxSuccesses(4)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clkMock), + circuitbreaker.WithOpenTimeout[any](1000*time.Millisecond), + circuitbreaker.WithHalfOpenMaxSuccesses[any](4)) wantErr := errors.New("something happens") @@ -148,7 +148,7 @@ func TestDo(t *testing.T) { assert.Equal(t, circuitbreaker.StateClosed, cb.State()) got, err := cb.Do(context.Background(), func() (interface{}, error) { return "data", wantErr }) assert.Equal(t, err, wantErr) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) } // State: Closed => Open. Should return nil and ErrOpen error. @@ -164,7 +164,7 @@ func TestDo(t *testing.T) { // State: HalfOpen => Open. got, err = cb.Do(context.Background(), func() (interface{}, error) { return "data", wantErr }) assert.Equal(t, err, wantErr) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) assert.Equal(t, circuitbreaker.StateOpen, cb.State()) // State: Open => HalfOpen. @@ -175,7 +175,7 @@ func TestDo(t *testing.T) { assert.Equal(t, circuitbreaker.StateHalfOpen, cb.State()) got, err = cb.Do(context.Background(), func() (interface{}, error) { return "data", nil }) assert.NoError(t, err) - assert.Equal(t, "data", got.(string)) + assert.Equal(t, "data", got) } assert.Equal(t, circuitbreaker.StateClosed, cb.State()) } @@ -240,7 +240,7 @@ func TestMarkAsSuccess(t *testing.T) { } func TestSuccess(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.Success() assert.Equal(t, circuitbreaker.Counters{Successes: 1, Failures: 0, ConsecutiveSuccesses: 1, ConsecutiveFailures: 0}, cb.Counters()) @@ -252,7 +252,7 @@ func TestSuccess(t *testing.T) { } func TestFail(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.Fail() assert.Equal(t, circuitbreaker.Counters{Successes: 0, Failures: 1, ConsecutiveSuccesses: 0, ConsecutiveFailures: 1}, cb.Counters()) @@ -264,7 +264,7 @@ func TestFail(t *testing.T) { // TestReset tests if Reset resets all counters. func TestReset(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.Success() cb.Reset() assert.Equal(t, circuitbreaker.Counters{}, cb.Counters()) @@ -276,7 +276,7 @@ func TestReset(t *testing.T) { func TestReportFunctions(t *testing.T) { t.Run("Failed if ctx.Err() == nil", func(t *testing.T) { - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.FailWithContext(context.Background()) assert.Equal(t, int64(1), cb.Counters().Failures) }) @@ -284,22 +284,22 @@ func TestReportFunctions(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) cancel() - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.FailWithContext(ctx) assert.Equal(t, int64(0), cb.Counters().Failures) - cb = circuitbreaker.New(circuitbreaker.WithFailOnContextCancel(true)) + cb = circuitbreaker.New[string](circuitbreaker.WithFailOnContextCancel[string](true)) cb.FailWithContext(ctx) assert.Equal(t, int64(1), cb.Counters().Failures) }) t.Run("ctx.Err() == context.DeadlineExceeded", func(t *testing.T) { ctx, cancel := context.WithDeadline(context.Background(), time.Time{}) defer cancel() - cb := circuitbreaker.New() + cb := circuitbreaker.New[string]() cb.FailWithContext(ctx) assert.Equal(t, int64(0), cb.Counters().Failures) - cb = circuitbreaker.New(circuitbreaker.WithFailOnContextDeadline(true)) + cb = circuitbreaker.New[string](circuitbreaker.WithFailOnContextDeadline[string](true)) cb.FailWithContext(ctx) assert.Equal(t, int64(1), cb.Counters().Failures) }) diff --git a/go.mod b/go.mod index 01e5e5a..5b2918e 100644 --- a/go.mod +++ b/go.mod @@ -1,9 +1,15 @@ -module github.com/mercari/go-circuitbreaker +module github.com/mercari/go-circuitbreaker/v2 -go 1.15 +go 1.18 require ( github.com/benbjohnson/clock v1.3.0 github.com/cenkalti/backoff/v3 v3.1.1 github.com/stretchr/testify v1.4.0 ) + +require ( + github.com/davecgh/go-spew v1.1.0 // indirect + github.com/pmezard/go-difflib v1.0.0 // indirect + gopkg.in/yaml.v2 v2.2.2 // indirect +) diff --git a/state.go b/state.go index b084561..93b58b3 100644 --- a/state.go +++ b/state.go @@ -8,13 +8,13 @@ import ( // each implementations of state represents State of circuit breaker. // // ref: https://docs.microsoft.com/en-us/azure/architecture/patterns/circuit-breaker -type state interface { - State() State - onEntry(cb *CircuitBreaker) - onExit(cb *CircuitBreaker) - ready(cb *CircuitBreaker) bool - onSuccess(cb *CircuitBreaker) - onFail(cb *CircuitBreaker) +type state[T any] interface { + State() State[T] + onEntry(cb *CircuitBreaker[T]) + onExit(cb *CircuitBreaker[T]) + ready(cb *CircuitBreaker[T]) bool + onSuccess(cb *CircuitBreaker[T]) + onFail(cb *CircuitBreaker[T]) } // [Closed state] @@ -30,13 +30,13 @@ type state interface { // - reset counters. // /onExit // - stop ticker. -type stateClosed struct { +type stateClosed[T any] struct { ticker *clock.Ticker done chan struct{} } -func (st *stateClosed) State() State { return StateClosed } -func (st *stateClosed) onEntry(cb *CircuitBreaker) { +func (st *stateClosed[T]) State() State[T] { return State[T](StateClosed) } +func (st *stateClosed[T]) onEntry(cb *CircuitBreaker[T]) { cb.cnt.resetFailures() cb.openBackOff.Reset() if cb.interval > 0 { @@ -56,23 +56,23 @@ func (st *stateClosed) onEntry(cb *CircuitBreaker) { } } -func (st *stateClosed) onExit(cb *CircuitBreaker) { +func (st *stateClosed[T]) onExit(cb *CircuitBreaker[T]) { if st.done != nil { close(st.done) } } -func (st *stateClosed) onTicker(cb *CircuitBreaker) { +func (st *stateClosed[T]) onTicker(cb *CircuitBreaker[T]) { cb.mu.Lock() defer cb.mu.Unlock() cb.cnt.reset() } -func (st *stateClosed) ready(cb *CircuitBreaker) bool { return true } -func (st *stateClosed) onSuccess(cb *CircuitBreaker) {} -func (st *stateClosed) onFail(cb *CircuitBreaker) { +func (st *stateClosed[T]) ready(cb *CircuitBreaker[T]) bool { return true } +func (st *stateClosed[T]) onSuccess(cb *CircuitBreaker[T]) {} +func (st *stateClosed[T]) onFail(cb *CircuitBreaker[T]) { if cb.shouldTrip(&cb.cnt) { - cb.setState(&stateOpen{}) + cb.setState(&stateOpen[T]{}) } } @@ -85,23 +85,23 @@ func (st *stateClosed) onFail(cb *CircuitBreaker) { // - Change state to [HalfOpen]. // /onExit // - Stop timer. -type stateOpen struct { +type stateOpen[T any] struct { timer *clock.Timer } -func (st *stateOpen) State() State { return StateOpen } -func (st *stateOpen) onEntry(cb *CircuitBreaker) { +func (st *stateOpen[T]) State() State[T] { return State[T](StateOpen) } +func (st *stateOpen[T]) onEntry(cb *CircuitBreaker[T]) { timeout := cb.openBackOff.NextBackOff() if timeout != backoff.Stop { st.timer = cb.clock.AfterFunc(timeout, func() { st.onTimer(cb) }) } } -func (st *stateOpen) onTimer(cb *CircuitBreaker) { cb.setStateWithLock(&stateHalfOpen{}) } -func (st *stateOpen) onExit(cb *CircuitBreaker) { st.timer.Stop() } -func (st *stateOpen) ready(cb *CircuitBreaker) bool { return false } -func (st *stateOpen) onSuccess(cb *CircuitBreaker) {} -func (st *stateOpen) onFail(cb *CircuitBreaker) {} +func (st *stateOpen[T]) onTimer(cb *CircuitBreaker[T]) { cb.setStateWithLock(&stateHalfOpen[T]{}) } +func (st *stateOpen[T]) onExit(cb *CircuitBreaker[T]) { st.timer.Stop() } +func (st *stateOpen[T]) ready(cb *CircuitBreaker[T]) bool { return false } +func (st *stateOpen[T]) onSuccess(cb *CircuitBreaker[T]) {} +func (st *stateOpen[T]) onFail(cb *CircuitBreaker[T]) {} // [HalfOpen state] // /ready @@ -111,17 +111,17 @@ func (st *stateOpen) onFail(cb *CircuitBreaker) {} // -> If threshold reached, change state to [Closed]. // /onFail // -> change state to [Open]. -type stateHalfOpen struct{} +type stateHalfOpen[T any] struct{} -func (st *stateHalfOpen) State() State { return StateHalfOpen } -func (st *stateHalfOpen) onEntry(cb *CircuitBreaker) { cb.cnt.resetSuccesses() } -func (st *stateHalfOpen) onExit(cb *CircuitBreaker) {} -func (st *stateHalfOpen) ready(cb *CircuitBreaker) bool { return true } -func (st *stateHalfOpen) onSuccess(cb *CircuitBreaker) { +func (st *stateHalfOpen[T]) State() State[T] { return State[T](StateHalfOpen) } +func (st *stateHalfOpen[T]) onEntry(cb *CircuitBreaker[T]) { cb.cnt.resetSuccesses() } +func (st *stateHalfOpen[T]) onExit(cb *CircuitBreaker[T]) {} +func (st *stateHalfOpen[T]) ready(cb *CircuitBreaker[T]) bool { return true } +func (st *stateHalfOpen[T]) onSuccess(cb *CircuitBreaker[T]) { if cb.cnt.Successes >= cb.halfOpenMaxSuccesses { - cb.setState(&stateClosed{}) + cb.setState(&stateClosed[T]{}) } } -func (st *stateHalfOpen) onFail(cb *CircuitBreaker) { - cb.setState(&stateOpen{}) +func (st *stateHalfOpen[T]) onFail(cb *CircuitBreaker[T]) { + cb.setState(&stateOpen[T]{}) } diff --git a/state_test.go b/state_test.go index 9ca2547..d4b5055 100644 --- a/state_test.go +++ b/state_test.go @@ -9,16 +9,16 @@ import ( "github.com/benbjohnson/clock" "github.com/cenkalti/backoff/v3" - "github.com/mercari/go-circuitbreaker" + "github.com/mercari/go-circuitbreaker/v2" "github.com/stretchr/testify/assert" ) func TestCircuitBreakerStateTransitions(t *testing.T) { clk := clock.NewMock() - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clk), - circuitbreaker.WithOpenTimeout(1000*time.Millisecond), - circuitbreaker.WithHalfOpenMaxSuccesses(4)) + cb := circuitbreaker.New(circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clk), + circuitbreaker.WithOpenTimeout[any](1000*time.Millisecond), + circuitbreaker.WithHalfOpenMaxSuccesses[any](4)) for i := 0; i < 10; i++ { // Scenario: 3 Fails. State changes to -> StateOpen. @@ -53,8 +53,8 @@ func TestCircuitBreakerStateTransitions(t *testing.T) { func TestCircuitBreakerOnStateChange(t *testing.T) { type stateChange struct { - from circuitbreaker.State - to circuitbreaker.State + from circuitbreaker.State[any] + to circuitbreaker.State[any] } expectedStateChanges := []stateChange{ @@ -82,12 +82,12 @@ func TestCircuitBreakerOnStateChange(t *testing.T) { var actualStateChanges []stateChange clock := clock.NewMock() - cb := circuitbreaker.New( - circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clock), - circuitbreaker.WithOpenTimeout(1000*time.Millisecond), - circuitbreaker.WithHalfOpenMaxSuccesses(4), - circuitbreaker.WithOnStateChangeHookFn(func(from, to circuitbreaker.State) { + cb := circuitbreaker.New[any]( + circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clock), + circuitbreaker.WithOpenTimeout[any](1000*time.Millisecond), + circuitbreaker.WithHalfOpenMaxSuccesses[any](4), + circuitbreaker.WithOnStateChangeHookFn[any](func(from, to circuitbreaker.State[any]) { actualStateChanges = append(actualStateChanges, stateChange{ from: from, to: to, @@ -124,9 +124,9 @@ func TestCircuitBreakerOnStateChange(t *testing.T) { // - Interval ticker reset the internal counter.. func TestStateClosed(t *testing.T) { clk := clock.NewMock() - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clk), - circuitbreaker.WithCounterResetInterval(1000*time.Millisecond)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clk), + circuitbreaker.WithCounterResetInterval[any](1000*time.Millisecond)) t.Run("Ready", func(t *testing.T) { assert.True(t, cb.Ready()) @@ -157,9 +157,9 @@ func TestStateClosed(t *testing.T) { // - Change state to StateHalfOpen after timer. func TestStateOpen(t *testing.T) { clk := clock.NewMock() - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clk), - circuitbreaker.WithOpenTimeout(500*time.Millisecond)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clk), + circuitbreaker.WithOpenTimeout[any](500*time.Millisecond)) t.Run("Ready", func(t *testing.T) { cb.SetState(circuitbreaker.StateOpen) assert.False(t, cb.Ready()) @@ -186,10 +186,10 @@ func TestStateOpen(t *testing.T) { MaxElapsedTime: 0, Clock: clkMock, } - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(1)), - circuitbreaker.WithHalfOpenMaxSuccesses(1), - circuitbreaker.WithClock(clkMock), - circuitbreaker.WithOpenTimeoutBackOff(backoffTest)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(1)), + circuitbreaker.WithHalfOpenMaxSuccesses[any](1), + circuitbreaker.WithClock[any](clkMock), + circuitbreaker.WithOpenTimeoutBackOff[any](backoffTest)) backoffTest.Reset() tests := []struct { @@ -223,10 +223,10 @@ func TestStateOpen(t *testing.T) { MaxElapsedTime: 0, Clock: clkMock, } - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(1)), - circuitbreaker.WithHalfOpenMaxSuccesses(1), - circuitbreaker.WithClock(clkMock), - circuitbreaker.WithOpenTimeoutBackOff(backoffTest)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(1)), + circuitbreaker.WithHalfOpenMaxSuccesses[any](1), + circuitbreaker.WithClock[any](clkMock), + circuitbreaker.WithOpenTimeoutBackOff[any](backoffTest)) backoffTest.Reset() tests := []struct { @@ -246,7 +246,7 @@ func TestStateOpen(t *testing.T) { }) } -func assertChangeStateToHalfOpenAfter(t *testing.T, cb *circuitbreaker.CircuitBreaker, clock *clock.Mock, after time.Duration) { +func assertChangeStateToHalfOpenAfter[T any](t *testing.T, cb *circuitbreaker.CircuitBreaker[T], clock *clock.Mock, after time.Duration) { clock.Add(after - 1) assert.Equal(t, circuitbreaker.StateOpen, cb.State()) @@ -260,9 +260,9 @@ func assertChangeStateToHalfOpenAfter(t *testing.T, cb *circuitbreaker.CircuitBr // - If get a success, the state changes to Closed. func TestHalfOpen(t *testing.T) { clkMock := clock.NewMock() - cb := circuitbreaker.New(circuitbreaker.WithTripFunc(circuitbreaker.NewTripFuncThreshold(3)), - circuitbreaker.WithClock(clkMock), - circuitbreaker.WithHalfOpenMaxSuccesses(4)) + cb := circuitbreaker.New[any](circuitbreaker.WithTripFunc[any](circuitbreaker.NewTripFuncThreshold(3)), + circuitbreaker.WithClock[any](clkMock), + circuitbreaker.WithHalfOpenMaxSuccesses[any](4)) t.Run("Ready", func(t *testing.T) { cb.Reset() cb.SetState(circuitbreaker.StateHalfOpen) @@ -302,10 +302,10 @@ func run(wg *sync.WaitGroup, f func()) { func TestRace(t *testing.T) { clock := clock.NewMock() - cb := circuitbreaker.New( - circuitbreaker.WithTripFunc(func(_ *circuitbreaker.Counters) bool { return true }), - circuitbreaker.WithClock(clock), - circuitbreaker.WithCounterResetInterval(1000*time.Millisecond), + cb := circuitbreaker.New[any]( + circuitbreaker.WithTripFunc[any](func(_ *circuitbreaker.Counters) bool { return true }), + circuitbreaker.WithClock[any](clock), + circuitbreaker.WithCounterResetInterval[any](1000*time.Millisecond), ) wg := &sync.WaitGroup{} run(wg, func() {