mirror of
https://github.com/slackhq/nebula.git
synced 2025-01-11 20:08:12 +00:00
aba42f9fa6
* enforce the use of goimports Instead of enforcing `gofmt`, enforce `goimports`, which also asserts a separate section for non-builtin packages. * run `goimports` everywhere * exclude generated .pb.go files
139 lines
3.5 KiB
Go
139 lines
3.5 KiB
Go
package nebula
|
|
|
|
import (
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/stretchr/testify/assert"
|
|
)
|
|
|
|
func TestNewTimerWheel(t *testing.T) {
|
|
// Make sure we get an object we expect
|
|
tw := NewTimerWheel(time.Second, time.Second*10)
|
|
assert.Equal(t, 11, tw.wheelLen)
|
|
assert.Equal(t, 0, tw.current)
|
|
assert.Nil(t, tw.lastTick)
|
|
assert.Equal(t, time.Second*1, tw.tickDuration)
|
|
assert.Equal(t, time.Second*10, tw.wheelDuration)
|
|
assert.Len(t, tw.wheel, 11)
|
|
|
|
// Assert the math is correct
|
|
tw = NewTimerWheel(time.Second*3, time.Second*10)
|
|
assert.Equal(t, 4, tw.wheelLen)
|
|
|
|
tw = NewTimerWheel(time.Second*120, time.Minute*10)
|
|
assert.Equal(t, 6, tw.wheelLen)
|
|
}
|
|
|
|
func TestTimerWheel_findWheel(t *testing.T) {
|
|
tw := NewTimerWheel(time.Second, time.Second*10)
|
|
assert.Len(t, tw.wheel, 11)
|
|
|
|
// Current + tick + 1 since we don't know how far into current we are
|
|
assert.Equal(t, 2, tw.findWheel(time.Second*1))
|
|
|
|
// Scale up to min duration
|
|
assert.Equal(t, 2, tw.findWheel(time.Millisecond*1))
|
|
|
|
// Make sure we hit that last index
|
|
assert.Equal(t, 0, tw.findWheel(time.Second*10))
|
|
|
|
// Scale down to max duration
|
|
assert.Equal(t, 0, tw.findWheel(time.Second*11))
|
|
|
|
tw.current = 1
|
|
// Make sure we account for the current position properly
|
|
assert.Equal(t, 3, tw.findWheel(time.Second*1))
|
|
assert.Equal(t, 1, tw.findWheel(time.Second*10))
|
|
}
|
|
|
|
func TestTimerWheel_Add(t *testing.T) {
|
|
tw := NewTimerWheel(time.Second, time.Second*10)
|
|
|
|
fp1 := FirewallPacket{}
|
|
tw.Add(fp1, time.Second*1)
|
|
|
|
// Make sure we set head and tail properly
|
|
assert.NotNil(t, tw.wheel[2])
|
|
assert.Equal(t, fp1, tw.wheel[2].Head.Packet)
|
|
assert.Nil(t, tw.wheel[2].Head.Next)
|
|
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
|
|
assert.Nil(t, tw.wheel[2].Tail.Next)
|
|
|
|
// Make sure we only modify head
|
|
fp2 := FirewallPacket{}
|
|
tw.Add(fp2, time.Second*1)
|
|
assert.Equal(t, fp2, tw.wheel[2].Head.Packet)
|
|
assert.Equal(t, fp1, tw.wheel[2].Head.Next.Packet)
|
|
assert.Equal(t, fp1, tw.wheel[2].Tail.Packet)
|
|
assert.Nil(t, tw.wheel[2].Tail.Next)
|
|
|
|
// Make sure we use free'd items first
|
|
tw.itemCache = &TimeoutItem{}
|
|
tw.itemsCached = 1
|
|
tw.Add(fp2, time.Second*1)
|
|
assert.Nil(t, tw.itemCache)
|
|
assert.Equal(t, 0, tw.itemsCached)
|
|
}
|
|
|
|
func TestTimerWheel_Purge(t *testing.T) {
|
|
// First advance should set the lastTick and do nothing else
|
|
tw := NewTimerWheel(time.Second, time.Second*10)
|
|
assert.Nil(t, tw.lastTick)
|
|
tw.advance(time.Now())
|
|
assert.NotNil(t, tw.lastTick)
|
|
assert.Equal(t, 0, tw.current)
|
|
|
|
fps := []FirewallPacket{
|
|
{LocalIP: 1},
|
|
{LocalIP: 2},
|
|
{LocalIP: 3},
|
|
{LocalIP: 4},
|
|
}
|
|
|
|
tw.Add(fps[0], time.Second*1)
|
|
tw.Add(fps[1], time.Second*1)
|
|
tw.Add(fps[2], time.Second*2)
|
|
tw.Add(fps[3], time.Second*2)
|
|
|
|
ta := time.Now().Add(time.Second * 3)
|
|
lastTick := *tw.lastTick
|
|
tw.advance(ta)
|
|
assert.Equal(t, 3, tw.current)
|
|
assert.True(t, tw.lastTick.After(lastTick))
|
|
|
|
// Make sure we get all 4 packets back
|
|
for i := 0; i < 4; i++ {
|
|
p, has := tw.Purge()
|
|
assert.True(t, has)
|
|
assert.Equal(t, fps[i], p)
|
|
}
|
|
|
|
// Make sure there aren't any leftover
|
|
_, ok := tw.Purge()
|
|
assert.False(t, ok)
|
|
assert.Nil(t, tw.expired.Head)
|
|
assert.Nil(t, tw.expired.Tail)
|
|
|
|
// Make sure we cached the free'd items
|
|
assert.Equal(t, 4, tw.itemsCached)
|
|
ci := tw.itemCache
|
|
for i := 0; i < 4; i++ {
|
|
assert.NotNil(t, ci)
|
|
ci = ci.Next
|
|
}
|
|
assert.Nil(t, ci)
|
|
|
|
// Lets make sure we roll over properly
|
|
ta = ta.Add(time.Second * 5)
|
|
tw.advance(ta)
|
|
assert.Equal(t, 8, tw.current)
|
|
|
|
ta = ta.Add(time.Second * 2)
|
|
tw.advance(ta)
|
|
assert.Equal(t, 10, tw.current)
|
|
|
|
ta = ta.Add(time.Second * 1)
|
|
tw.advance(ta)
|
|
assert.Equal(t, 0, tw.current)
|
|
}
|