mirror of
https://github.com/slackhq/nebula.git
synced 2025-05-13 13:11:55 +00:00
Generic timerwheel (#804)
This commit is contained in:
parent
c177126ed0
commit
5278b6f926
8 changed files with 116 additions and 431 deletions
|
@ -19,12 +19,12 @@ type connectionManager struct {
|
|||
inLock *sync.RWMutex
|
||||
out map[iputil.VpnIp]struct{}
|
||||
outLock *sync.RWMutex
|
||||
TrafficTimer *SystemTimerWheel
|
||||
TrafficTimer *LockingTimerWheel[iputil.VpnIp]
|
||||
intf *Interface
|
||||
|
||||
pendingDeletion map[iputil.VpnIp]int
|
||||
pendingDeletionLock *sync.RWMutex
|
||||
pendingDeletionTimer *SystemTimerWheel
|
||||
pendingDeletionTimer *LockingTimerWheel[iputil.VpnIp]
|
||||
|
||||
checkInterval int
|
||||
pendingDeletionInterval int
|
||||
|
@ -40,11 +40,11 @@ func newConnectionManager(ctx context.Context, l *logrus.Logger, intf *Interface
|
|||
inLock: &sync.RWMutex{},
|
||||
out: make(map[iputil.VpnIp]struct{}),
|
||||
outLock: &sync.RWMutex{},
|
||||
TrafficTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||
TrafficTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
|
||||
intf: intf,
|
||||
pendingDeletion: make(map[iputil.VpnIp]int),
|
||||
pendingDeletionLock: &sync.RWMutex{},
|
||||
pendingDeletionTimer: NewSystemTimerWheel(time.Millisecond*500, time.Second*60),
|
||||
pendingDeletionTimer: NewLockingTimerWheel[iputil.VpnIp](time.Millisecond*500, time.Second*60),
|
||||
checkInterval: checkInterval,
|
||||
pendingDeletionInterval: pendingDeletionInterval,
|
||||
l: l,
|
||||
|
@ -160,15 +160,13 @@ func (n *connectionManager) Run(ctx context.Context) {
|
|||
}
|
||||
|
||||
func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte) {
|
||||
n.TrafficTimer.advance(now)
|
||||
n.TrafficTimer.Advance(now)
|
||||
for {
|
||||
ep := n.TrafficTimer.Purge()
|
||||
if ep == nil {
|
||||
vpnIp, has := n.TrafficTimer.Purge()
|
||||
if !has {
|
||||
break
|
||||
}
|
||||
|
||||
vpnIp := ep.(iputil.VpnIp)
|
||||
|
||||
// Check for traffic coming back in from this host.
|
||||
traf := n.CheckIn(vpnIp)
|
||||
|
||||
|
@ -214,15 +212,13 @@ func (n *connectionManager) HandleMonitorTick(now time.Time, p, nb, out []byte)
|
|||
}
|
||||
|
||||
func (n *connectionManager) HandleDeletionTick(now time.Time) {
|
||||
n.pendingDeletionTimer.advance(now)
|
||||
n.pendingDeletionTimer.Advance(now)
|
||||
for {
|
||||
ep := n.pendingDeletionTimer.Purge()
|
||||
if ep == nil {
|
||||
vpnIp, has := n.pendingDeletionTimer.Purge()
|
||||
if !has {
|
||||
break
|
||||
}
|
||||
|
||||
vpnIp := ep.(iputil.VpnIp)
|
||||
|
||||
hostinfo, err := n.hostMap.QueryVpnIp(vpnIp)
|
||||
if err != nil {
|
||||
n.l.Debugf("Not found in hostmap: %s", vpnIp)
|
||||
|
|
|
@ -77,7 +77,7 @@ type FirewallConntrack struct {
|
|||
sync.Mutex
|
||||
|
||||
Conns map[firewall.Packet]*conn
|
||||
TimerWheel *TimerWheel
|
||||
TimerWheel *TimerWheel[firewall.Packet]
|
||||
}
|
||||
|
||||
type FirewallTable struct {
|
||||
|
@ -145,7 +145,7 @@ func NewFirewall(l *logrus.Logger, tcpTimeout, UDPTimeout, defaultTimeout time.D
|
|||
return &Firewall{
|
||||
Conntrack: &FirewallConntrack{
|
||||
Conns: make(map[firewall.Packet]*conn),
|
||||
TimerWheel: NewTimerWheel(min, max),
|
||||
TimerWheel: NewTimerWheel[firewall.Packet](min, max),
|
||||
},
|
||||
InRules: newFirewallTable(),
|
||||
OutRules: newFirewallTable(),
|
||||
|
@ -510,6 +510,7 @@ func (f *Firewall) addConn(packet []byte, fp firewall.Packet, incoming bool) {
|
|||
conntrack := f.Conntrack
|
||||
conntrack.Lock()
|
||||
if _, ok := conntrack.Conns[fp]; !ok {
|
||||
conntrack.TimerWheel.Advance(time.Now())
|
||||
conntrack.TimerWheel.Add(fp, timeout)
|
||||
}
|
||||
|
||||
|
@ -537,6 +538,7 @@ func (f *Firewall) evict(p firewall.Packet) {
|
|||
|
||||
// Timeout is in the future, re-add the timer
|
||||
if newT > 0 {
|
||||
conntrack.TimerWheel.Advance(time.Now())
|
||||
conntrack.TimerWheel.Add(p, newT)
|
||||
return
|
||||
}
|
||||
|
|
|
@ -47,7 +47,7 @@ type HandshakeManager struct {
|
|||
lightHouse *LightHouse
|
||||
outside *udp.Conn
|
||||
config HandshakeConfig
|
||||
OutboundHandshakeTimer *SystemTimerWheel
|
||||
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
|
||||
messageMetrics *MessageMetrics
|
||||
metricInitiated metrics.Counter
|
||||
metricTimedOut metrics.Counter
|
||||
|
@ -65,7 +65,7 @@ func NewHandshakeManager(l *logrus.Logger, tunCidr *net.IPNet, preferredRanges [
|
|||
outside: outside,
|
||||
config: config,
|
||||
trigger: make(chan iputil.VpnIp, config.triggerBuffer),
|
||||
OutboundHandshakeTimer: NewSystemTimerWheel(config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
||||
OutboundHandshakeTimer: NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval)),
|
||||
messageMetrics: config.messageMetrics,
|
||||
metricInitiated: metrics.GetOrRegisterCounter("handshake_manager.initiated", nil),
|
||||
metricTimedOut: metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil),
|
||||
|
@ -90,13 +90,12 @@ func (c *HandshakeManager) Run(ctx context.Context, f udp.EncWriter) {
|
|||
}
|
||||
|
||||
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time, f udp.EncWriter) {
|
||||
c.OutboundHandshakeTimer.advance(now)
|
||||
c.OutboundHandshakeTimer.Advance(now)
|
||||
for {
|
||||
ep := c.OutboundHandshakeTimer.Purge()
|
||||
if ep == nil {
|
||||
vpnIp, has := c.OutboundHandshakeTimer.Purge()
|
||||
if !has {
|
||||
break
|
||||
}
|
||||
vpnIp := ep.(iputil.VpnIp)
|
||||
c.handleOutbound(vpnIp, f, false)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -106,8 +106,8 @@ func Test_NewHandshakeManagerTrigger(t *testing.T) {
|
|||
assert.Equal(t, 1, testCountTimerWheelEntries(blah.OutboundHandshakeTimer))
|
||||
}
|
||||
|
||||
func testCountTimerWheelEntries(tw *SystemTimerWheel) (c int) {
|
||||
for _, i := range tw.wheel {
|
||||
func testCountTimerWheelEntries(tw *LockingTimerWheel[iputil.VpnIp]) (c int) {
|
||||
for _, i := range tw.t.wheel {
|
||||
n := i.Head
|
||||
for n != nil {
|
||||
c++
|
||||
|
|
95
timeout.go
95
timeout.go
|
@ -1,17 +1,14 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
)
|
||||
|
||||
// How many timer objects should be cached
|
||||
const timerCacheMax = 50000
|
||||
|
||||
var emptyFWPacket = firewall.Packet{}
|
||||
|
||||
type TimerWheel struct {
|
||||
type TimerWheel[T any] struct {
|
||||
// Current tick
|
||||
current int
|
||||
|
||||
|
@ -26,31 +23,38 @@ type TimerWheel struct {
|
|||
wheelDuration time.Duration
|
||||
|
||||
// The actual wheel which is just a set of singly linked lists, head/tail pointers
|
||||
wheel []*TimeoutList
|
||||
wheel []*TimeoutList[T]
|
||||
|
||||
// Singly linked list of items that have timed out of the wheel
|
||||
expired *TimeoutList
|
||||
expired *TimeoutList[T]
|
||||
|
||||
// Item cache to avoid garbage collect
|
||||
itemCache *TimeoutItem
|
||||
itemCache *TimeoutItem[T]
|
||||
itemsCached int
|
||||
}
|
||||
|
||||
type LockingTimerWheel[T any] struct {
|
||||
m sync.Mutex
|
||||
t *TimerWheel[T]
|
||||
}
|
||||
|
||||
// TimeoutList Represents a tick in the wheel
|
||||
type TimeoutList struct {
|
||||
Head *TimeoutItem
|
||||
Tail *TimeoutItem
|
||||
type TimeoutList[T any] struct {
|
||||
Head *TimeoutItem[T]
|
||||
Tail *TimeoutItem[T]
|
||||
}
|
||||
|
||||
// TimeoutItem Represents an item within a tick
|
||||
type TimeoutItem struct {
|
||||
Packet firewall.Packet
|
||||
Next *TimeoutItem
|
||||
type TimeoutItem[T any] struct {
|
||||
Item T
|
||||
Next *TimeoutItem[T]
|
||||
}
|
||||
|
||||
// NewTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
|
||||
// Purge must be called once per entry to actually remove anything
|
||||
func NewTimerWheel(min, max time.Duration) *TimerWheel {
|
||||
// The TimerWheel does not handle concurrency on its own.
|
||||
// Locks around access to it must be used if multiple routines are manipulating it.
|
||||
func NewTimerWheel[T any](min, max time.Duration) *TimerWheel[T] {
|
||||
//TODO provide an error
|
||||
//if min >= max {
|
||||
// return nil
|
||||
|
@ -61,26 +65,31 @@ func NewTimerWheel(min, max time.Duration) *TimerWheel {
|
|||
// timeout
|
||||
wLen := int((max / min) + 2)
|
||||
|
||||
tw := TimerWheel{
|
||||
tw := TimerWheel[T]{
|
||||
wheelLen: wLen,
|
||||
wheel: make([]*TimeoutList, wLen),
|
||||
wheel: make([]*TimeoutList[T], wLen),
|
||||
tickDuration: min,
|
||||
wheelDuration: max,
|
||||
expired: &TimeoutList{},
|
||||
expired: &TimeoutList[T]{},
|
||||
}
|
||||
|
||||
for i := range tw.wheel {
|
||||
tw.wheel[i] = &TimeoutList{}
|
||||
tw.wheel[i] = &TimeoutList[T]{}
|
||||
}
|
||||
|
||||
return &tw
|
||||
}
|
||||
|
||||
// Add will add a firewall.Packet to the wheel in it's proper timeout
|
||||
func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem {
|
||||
// Check and see if we should progress the tick
|
||||
tw.advance(time.Now())
|
||||
// NewLockingTimerWheel is version of TimerWheel that is safe for concurrent use with a small performance penalty
|
||||
func NewLockingTimerWheel[T any](min, max time.Duration) *LockingTimerWheel[T] {
|
||||
return &LockingTimerWheel[T]{
|
||||
t: NewTimerWheel[T](min, max),
|
||||
}
|
||||
}
|
||||
|
||||
// Add will add an item to the wheel in its proper timeout.
|
||||
// Caller should Advance the wheel prior to ensure the proper slot is used.
|
||||
func (tw *TimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
|
||||
i := tw.findWheel(timeout)
|
||||
|
||||
// Try to fetch off the cache
|
||||
|
@ -90,11 +99,11 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
|
|||
tw.itemsCached--
|
||||
ti.Next = nil
|
||||
} else {
|
||||
ti = &TimeoutItem{}
|
||||
ti = &TimeoutItem[T]{}
|
||||
}
|
||||
|
||||
// Relink and return
|
||||
ti.Packet = v
|
||||
ti.Item = v
|
||||
if tw.wheel[i].Tail == nil {
|
||||
tw.wheel[i].Head = ti
|
||||
tw.wheel[i].Tail = ti
|
||||
|
@ -106,9 +115,12 @@ func (tw *TimerWheel) Add(v firewall.Packet, timeout time.Duration) *TimeoutItem
|
|||
return ti
|
||||
}
|
||||
|
||||
func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
|
||||
// Purge removes and returns the first available expired item from the wheel and the 2nd argument is true.
|
||||
// If no item is available then an empty T is returned and the 2nd argument is false.
|
||||
func (tw *TimerWheel[T]) Purge() (T, bool) {
|
||||
if tw.expired.Head == nil {
|
||||
return emptyFWPacket, false
|
||||
var na T
|
||||
return na, false
|
||||
}
|
||||
|
||||
ti := tw.expired.Head
|
||||
|
@ -128,11 +140,11 @@ func (tw *TimerWheel) Purge() (firewall.Packet, bool) {
|
|||
tw.itemsCached++
|
||||
}
|
||||
|
||||
return ti.Packet, true
|
||||
return ti.Item, true
|
||||
}
|
||||
|
||||
// advance will move the wheel forward by proper number of ticks. The caller _should_ lock the wheel before calling this
|
||||
func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
|
||||
// findWheel find the next position in the wheel for the provided timeout given the current tick
|
||||
func (tw *TimerWheel[T]) findWheel(timeout time.Duration) (i int) {
|
||||
if timeout < tw.tickDuration {
|
||||
// Can't track anything below the set resolution
|
||||
timeout = tw.tickDuration
|
||||
|
@ -154,8 +166,9 @@ func (tw *TimerWheel) findWheel(timeout time.Duration) (i int) {
|
|||
return tick
|
||||
}
|
||||
|
||||
// advance will lock and move the wheel forward by proper number of ticks.
|
||||
func (tw *TimerWheel) advance(now time.Time) {
|
||||
// Advance will move the wheel forward by the appropriate number of ticks for the provided time and all items
|
||||
// passed over will be moved to the expired list. Calling Purge is necessary to remove them entirely.
|
||||
func (tw *TimerWheel[T]) Advance(now time.Time) {
|
||||
if tw.lastTick == nil {
|
||||
tw.lastTick = &now
|
||||
}
|
||||
|
@ -192,3 +205,21 @@ func (tw *TimerWheel) advance(now time.Time) {
|
|||
newTick := tw.lastTick.Add(tw.tickDuration * time.Duration(adv))
|
||||
tw.lastTick = &newTick
|
||||
}
|
||||
|
||||
func (lw *LockingTimerWheel[T]) Add(v T, timeout time.Duration) *TimeoutItem[T] {
|
||||
lw.m.Lock()
|
||||
defer lw.m.Unlock()
|
||||
return lw.t.Add(v, timeout)
|
||||
}
|
||||
|
||||
func (lw *LockingTimerWheel[T]) Purge() (T, bool) {
|
||||
lw.m.Lock()
|
||||
defer lw.m.Unlock()
|
||||
return lw.t.Purge()
|
||||
}
|
||||
|
||||
func (lw *LockingTimerWheel[T]) Advance(now time.Time) {
|
||||
lw.m.Lock()
|
||||
defer lw.m.Unlock()
|
||||
lw.t.Advance(now)
|
||||
}
|
||||
|
|
|
@ -1,199 +0,0 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
)
|
||||
|
||||
// How many timer objects should be cached
|
||||
const systemTimerCacheMax = 50000
|
||||
|
||||
type SystemTimerWheel struct {
|
||||
// Current tick
|
||||
current int
|
||||
|
||||
// Cheat on finding the length of the wheel
|
||||
wheelLen int
|
||||
|
||||
// Last time we ticked, since we are lazy ticking
|
||||
lastTick *time.Time
|
||||
|
||||
// Durations of a tick and the entire wheel
|
||||
tickDuration time.Duration
|
||||
wheelDuration time.Duration
|
||||
|
||||
// The actual wheel which is just a set of singly linked lists, head/tail pointers
|
||||
wheel []*SystemTimeoutList
|
||||
|
||||
// Singly linked list of items that have timed out of the wheel
|
||||
expired *SystemTimeoutList
|
||||
|
||||
// Item cache to avoid garbage collect
|
||||
itemCache *SystemTimeoutItem
|
||||
itemsCached int
|
||||
|
||||
lock sync.Mutex
|
||||
}
|
||||
|
||||
// SystemTimeoutList Represents a tick in the wheel
|
||||
type SystemTimeoutList struct {
|
||||
Head *SystemTimeoutItem
|
||||
Tail *SystemTimeoutItem
|
||||
}
|
||||
|
||||
// SystemTimeoutItem Represents an item within a tick
|
||||
type SystemTimeoutItem struct {
|
||||
Item iputil.VpnIp
|
||||
Next *SystemTimeoutItem
|
||||
}
|
||||
|
||||
// NewSystemTimerWheel Builds a timer wheel and identifies the tick duration and wheel duration from the provided values
|
||||
// Purge must be called once per entry to actually remove anything
|
||||
func NewSystemTimerWheel(min, max time.Duration) *SystemTimerWheel {
|
||||
//TODO provide an error
|
||||
//if min >= max {
|
||||
// return nil
|
||||
//}
|
||||
|
||||
// Round down and add 2 so we can have the smallest # of ticks in the wheel and still account for a full
|
||||
// max duration, even if our current tick is at the maximum position and the next item to be added is at maximum
|
||||
// timeout
|
||||
wLen := int((max / min) + 2)
|
||||
|
||||
tw := SystemTimerWheel{
|
||||
wheelLen: wLen,
|
||||
wheel: make([]*SystemTimeoutList, wLen),
|
||||
tickDuration: min,
|
||||
wheelDuration: max,
|
||||
expired: &SystemTimeoutList{},
|
||||
}
|
||||
|
||||
for i := range tw.wheel {
|
||||
tw.wheel[i] = &SystemTimeoutList{}
|
||||
}
|
||||
|
||||
return &tw
|
||||
}
|
||||
|
||||
func (tw *SystemTimerWheel) Add(v iputil.VpnIp, timeout time.Duration) *SystemTimeoutItem {
|
||||
tw.lock.Lock()
|
||||
defer tw.lock.Unlock()
|
||||
|
||||
// Check and see if we should progress the tick
|
||||
//tw.advance(time.Now())
|
||||
|
||||
i := tw.findWheel(timeout)
|
||||
|
||||
// Try to fetch off the cache
|
||||
ti := tw.itemCache
|
||||
if ti != nil {
|
||||
tw.itemCache = ti.Next
|
||||
ti.Next = nil
|
||||
tw.itemsCached--
|
||||
} else {
|
||||
ti = &SystemTimeoutItem{}
|
||||
}
|
||||
|
||||
// Relink and return
|
||||
ti.Item = v
|
||||
ti.Next = tw.wheel[i].Head
|
||||
tw.wheel[i].Head = ti
|
||||
|
||||
if tw.wheel[i].Tail == nil {
|
||||
tw.wheel[i].Tail = ti
|
||||
}
|
||||
|
||||
return ti
|
||||
}
|
||||
|
||||
func (tw *SystemTimerWheel) Purge() interface{} {
|
||||
tw.lock.Lock()
|
||||
defer tw.lock.Unlock()
|
||||
|
||||
if tw.expired.Head == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
ti := tw.expired.Head
|
||||
tw.expired.Head = ti.Next
|
||||
|
||||
if tw.expired.Head == nil {
|
||||
tw.expired.Tail = nil
|
||||
}
|
||||
|
||||
p := ti.Item
|
||||
|
||||
// Clear out the items references
|
||||
ti.Item = 0
|
||||
ti.Next = nil
|
||||
|
||||
// Maybe cache it for later
|
||||
if tw.itemsCached < systemTimerCacheMax {
|
||||
ti.Next = tw.itemCache
|
||||
tw.itemCache = ti
|
||||
tw.itemsCached++
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func (tw *SystemTimerWheel) findWheel(timeout time.Duration) (i int) {
|
||||
if timeout < tw.tickDuration {
|
||||
// Can't track anything below the set resolution
|
||||
timeout = tw.tickDuration
|
||||
} else if timeout > tw.wheelDuration {
|
||||
// We aren't handling timeouts greater than the wheels duration
|
||||
timeout = tw.wheelDuration
|
||||
}
|
||||
|
||||
// Find the next highest, rounding up
|
||||
tick := int(((timeout - 1) / tw.tickDuration) + 1)
|
||||
|
||||
// Add another tick since the current tick may almost be over then map it to the wheel from our
|
||||
// current position
|
||||
tick += tw.current + 1
|
||||
if tick >= tw.wheelLen {
|
||||
tick -= tw.wheelLen
|
||||
}
|
||||
|
||||
return tick
|
||||
}
|
||||
|
||||
func (tw *SystemTimerWheel) advance(now time.Time) {
|
||||
tw.lock.Lock()
|
||||
defer tw.lock.Unlock()
|
||||
|
||||
if tw.lastTick == nil {
|
||||
tw.lastTick = &now
|
||||
}
|
||||
|
||||
// We want to round down
|
||||
ticks := int(now.Sub(*tw.lastTick) / tw.tickDuration)
|
||||
//l.Infoln("Ticks: ", ticks)
|
||||
for i := 0; i < ticks; i++ {
|
||||
tw.current++
|
||||
//l.Infoln("Tick: ", tw.current)
|
||||
if tw.current >= tw.wheelLen {
|
||||
tw.current = 0
|
||||
}
|
||||
|
||||
// We need to append the expired items as to not starve evicting the oldest ones
|
||||
if tw.expired.Tail == nil {
|
||||
tw.expired.Head = tw.wheel[tw.current].Head
|
||||
tw.expired.Tail = tw.wheel[tw.current].Tail
|
||||
} else {
|
||||
tw.expired.Tail.Next = tw.wheel[tw.current].Head
|
||||
if tw.wheel[tw.current].Tail != nil {
|
||||
tw.expired.Tail = tw.wheel[tw.current].Tail
|
||||
}
|
||||
}
|
||||
|
||||
//l.Infoln("Head: ", tw.expired.Head, "Tail: ", tw.expired.Tail)
|
||||
tw.wheel[tw.current].Head = nil
|
||||
tw.wheel[tw.current].Tail = nil
|
||||
|
||||
tw.lastTick = &now
|
||||
}
|
||||
}
|
|
@ -1,156 +0,0 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewSystemTimerWheel(t *testing.T) {
|
||||
// Make sure we get an object we expect
|
||||
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
||||
assert.Equal(t, 12, tw.wheelLen)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
assert.Nil(t, tw.lastTick)
|
||||
assert.Equal(t, time.Second*1, tw.tickDuration)
|
||||
assert.Equal(t, time.Second*10, tw.wheelDuration)
|
||||
assert.Len(t, tw.wheel, 12)
|
||||
|
||||
// Assert the math is correct
|
||||
tw = NewSystemTimerWheel(time.Second*3, time.Second*10)
|
||||
assert.Equal(t, 5, tw.wheelLen)
|
||||
|
||||
tw = NewSystemTimerWheel(time.Second*120, time.Minute*10)
|
||||
assert.Equal(t, 7, tw.wheelLen)
|
||||
}
|
||||
|
||||
func TestSystemTimerWheel_findWheel(t *testing.T) {
|
||||
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
||||
assert.Len(t, tw.wheel, 12)
|
||||
|
||||
// Current + tick + 1 since we don't know how far into current we are
|
||||
assert.Equal(t, 2, tw.findWheel(time.Second*1))
|
||||
|
||||
// Scale up to min duration
|
||||
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
|
||||
|
||||
// Make sure we hit that last index
|
||||
assert.Equal(t, 11, tw.findWheel(time.Second*10))
|
||||
|
||||
// Scale down to max duration
|
||||
assert.Equal(t, 11, tw.findWheel(time.Second*11))
|
||||
|
||||
tw.current = 1
|
||||
// Make sure we account for the current position properly
|
||||
assert.Equal(t, 3, tw.findWheel(time.Second*1))
|
||||
assert.Equal(t, 0, tw.findWheel(time.Second*10))
|
||||
|
||||
// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
|
||||
for min := time.Duration(1); min < 100; min++ {
|
||||
for max := min; max < 100; max++ {
|
||||
tw = NewSystemTimerWheel(min, max)
|
||||
|
||||
for current := 0; current < tw.wheelLen; current++ {
|
||||
tw.current = current
|
||||
for timeout := time.Duration(0); timeout <= tw.wheelDuration; timeout++ {
|
||||
tick := tw.findWheel(timeout)
|
||||
if tick >= tw.wheelLen {
|
||||
t.Errorf("Min: %v; Max: %v; Wheel len: %v; Current Tick: %v; Insert timeout: %v; Calc tick: %v", min, max, tw.wheelLen, current, timeout, tick)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSystemTimerWheel_Add(t *testing.T) {
|
||||
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
||||
|
||||
fp1 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
|
||||
tw.Add(fp1, time.Second*1)
|
||||
|
||||
// Make sure we set head and tail properly
|
||||
assert.NotNil(t, tw.wheel[2])
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Item)
|
||||
assert.Nil(t, tw.wheel[2].Head.Next)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
|
||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||
|
||||
// Make sure we only modify head
|
||||
fp2 := iputil.Ip2VpnIp(net.ParseIP("1.2.3.4"))
|
||||
tw.Add(fp2, time.Second*1)
|
||||
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
|
||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||
|
||||
// Make sure we use free'd items first
|
||||
tw.itemCache = &SystemTimeoutItem{}
|
||||
tw.itemsCached = 1
|
||||
tw.Add(fp2, time.Second*1)
|
||||
assert.Nil(t, tw.itemCache)
|
||||
assert.Equal(t, 0, tw.itemsCached)
|
||||
}
|
||||
|
||||
func TestSystemTimerWheel_Purge(t *testing.T) {
|
||||
// First advance should set the lastTick and do nothing else
|
||||
tw := NewSystemTimerWheel(time.Second, time.Second*10)
|
||||
assert.Nil(t, tw.lastTick)
|
||||
tw.advance(time.Now())
|
||||
assert.NotNil(t, tw.lastTick)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
|
||||
fps := []iputil.VpnIp{9, 10, 11, 12}
|
||||
|
||||
//fp1 := ip2int(net.ParseIP("1.2.3.4"))
|
||||
|
||||
tw.Add(fps[0], time.Second*1)
|
||||
tw.Add(fps[1], time.Second*1)
|
||||
tw.Add(fps[2], time.Second*2)
|
||||
tw.Add(fps[3], time.Second*2)
|
||||
|
||||
ta := time.Now().Add(time.Second * 3)
|
||||
lastTick := *tw.lastTick
|
||||
tw.advance(ta)
|
||||
assert.Equal(t, 3, tw.current)
|
||||
assert.True(t, tw.lastTick.After(lastTick))
|
||||
|
||||
// Make sure we get all 4 packets back
|
||||
for i := 0; i < 4; i++ {
|
||||
assert.Contains(t, fps, tw.Purge())
|
||||
}
|
||||
|
||||
// Make sure there aren't any leftover
|
||||
assert.Nil(t, tw.Purge())
|
||||
assert.Nil(t, tw.expired.Head)
|
||||
assert.Nil(t, tw.expired.Tail)
|
||||
|
||||
// Make sure we cached the free'd items
|
||||
assert.Equal(t, 4, tw.itemsCached)
|
||||
ci := tw.itemCache
|
||||
for i := 0; i < 4; i++ {
|
||||
assert.NotNil(t, ci)
|
||||
ci = ci.Next
|
||||
}
|
||||
assert.Nil(t, ci)
|
||||
|
||||
// Lets make sure we roll over properly
|
||||
ta = ta.Add(time.Second * 5)
|
||||
tw.advance(ta)
|
||||
assert.Equal(t, 8, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 2)
|
||||
tw.advance(ta)
|
||||
assert.Equal(t, 10, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 1)
|
||||
tw.advance(ta)
|
||||
assert.Equal(t, 11, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 1)
|
||||
tw.advance(ta)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
}
|
|
@ -10,7 +10,7 @@ import (
|
|||
|
||||
func TestNewTimerWheel(t *testing.T) {
|
||||
// Make sure we get an object we expect
|
||||
tw := NewTimerWheel(time.Second, time.Second*10)
|
||||
tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
|
||||
assert.Equal(t, 12, tw.wheelLen)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
assert.Nil(t, tw.lastTick)
|
||||
|
@ -19,15 +19,27 @@ func TestNewTimerWheel(t *testing.T) {
|
|||
assert.Len(t, tw.wheel, 12)
|
||||
|
||||
// Assert the math is correct
|
||||
tw = NewTimerWheel(time.Second*3, time.Second*10)
|
||||
tw = NewTimerWheel[firewall.Packet](time.Second*3, time.Second*10)
|
||||
assert.Equal(t, 5, tw.wheelLen)
|
||||
|
||||
tw = NewTimerWheel(time.Second*120, time.Minute*10)
|
||||
tw = NewTimerWheel[firewall.Packet](time.Second*120, time.Minute*10)
|
||||
assert.Equal(t, 7, tw.wheelLen)
|
||||
|
||||
// Test empty purge of non nil items
|
||||
i, ok := tw.Purge()
|
||||
assert.Equal(t, firewall.Packet{}, i)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Test empty purges of nil items
|
||||
tw2 := NewTimerWheel[*int](time.Second, time.Second*10)
|
||||
i2, ok := tw2.Purge()
|
||||
assert.Nil(t, i2)
|
||||
assert.False(t, ok)
|
||||
|
||||
}
|
||||
|
||||
func TestTimerWheel_findWheel(t *testing.T) {
|
||||
tw := NewTimerWheel(time.Second, time.Second*10)
|
||||
tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
|
||||
assert.Len(t, tw.wheel, 12)
|
||||
|
||||
// Current + tick + 1 since we don't know how far into current we are
|
||||
|
@ -49,28 +61,28 @@ func TestTimerWheel_findWheel(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestTimerWheel_Add(t *testing.T) {
|
||||
tw := NewTimerWheel(time.Second, time.Second*10)
|
||||
tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
|
||||
|
||||
fp1 := firewall.Packet{}
|
||||
tw.Add(fp1, time.Second*1)
|
||||
|
||||
// Make sure we set head and tail properly
|
||||
assert.NotNil(t, tw.wheel[2])
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Item)
|
||||
assert.Nil(t, tw.wheel[2].Head.Next)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
|
||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||
|
||||
// Make sure we only modify head
|
||||
fp2 := firewall.Packet{}
|
||||
tw.Add(fp2, time.Second*1)
|
||||
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
|
||||
assert.Equal(t, fp2, tw.wheel[2].Head.Item)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Item)
|
||||
assert.Equal(t, fp1, tw.wheel[2].Tail.Item)
|
||||
assert.Nil(t, tw.wheel[2].Tail.Next)
|
||||
|
||||
// Make sure we use free'd items first
|
||||
tw.itemCache = &TimeoutItem{}
|
||||
tw.itemCache = &TimeoutItem[firewall.Packet]{}
|
||||
tw.itemsCached = 1
|
||||
tw.Add(fp2, time.Second*1)
|
||||
assert.Nil(t, tw.itemCache)
|
||||
|
@ -79,7 +91,7 @@ func TestTimerWheel_Add(t *testing.T) {
|
|||
// Ensure that all configurations of a wheel does not result in calculating an overflow of the wheel
|
||||
for min := time.Duration(1); min < 100; min++ {
|
||||
for max := min; max < 100; max++ {
|
||||
tw = NewTimerWheel(min, max)
|
||||
tw = NewTimerWheel[firewall.Packet](min, max)
|
||||
|
||||
for current := 0; current < tw.wheelLen; current++ {
|
||||
tw.current = current
|
||||
|
@ -96,9 +108,9 @@ func TestTimerWheel_Add(t *testing.T) {
|
|||
|
||||
func TestTimerWheel_Purge(t *testing.T) {
|
||||
// First advance should set the lastTick and do nothing else
|
||||
tw := NewTimerWheel(time.Second, time.Second*10)
|
||||
tw := NewTimerWheel[firewall.Packet](time.Second, time.Second*10)
|
||||
assert.Nil(t, tw.lastTick)
|
||||
tw.advance(time.Now())
|
||||
tw.Advance(time.Now())
|
||||
assert.NotNil(t, tw.lastTick)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
|
||||
|
@ -116,7 +128,7 @@ func TestTimerWheel_Purge(t *testing.T) {
|
|||
|
||||
ta := time.Now().Add(time.Second * 3)
|
||||
lastTick := *tw.lastTick
|
||||
tw.advance(ta)
|
||||
tw.Advance(ta)
|
||||
assert.Equal(t, 3, tw.current)
|
||||
assert.True(t, tw.lastTick.After(lastTick))
|
||||
|
||||
|
@ -142,20 +154,20 @@ func TestTimerWheel_Purge(t *testing.T) {
|
|||
}
|
||||
assert.Nil(t, ci)
|
||||
|
||||
// Lets make sure we roll over properly
|
||||
// Let's make sure we roll over properly
|
||||
ta = ta.Add(time.Second * 5)
|
||||
tw.advance(ta)
|
||||
tw.Advance(ta)
|
||||
assert.Equal(t, 8, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 2)
|
||||
tw.advance(ta)
|
||||
tw.Advance(ta)
|
||||
assert.Equal(t, 10, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 1)
|
||||
tw.advance(ta)
|
||||
tw.Advance(ta)
|
||||
assert.Equal(t, 11, tw.current)
|
||||
|
||||
ta = ta.Add(time.Second * 1)
|
||||
tw.advance(ta)
|
||||
tw.Advance(ta)
|
||||
assert.Equal(t, 0, tw.current)
|
||||
}
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue