0
0
Fork 0
mirror of https://github.com/slackhq/nebula.git synced 2025-01-11 03:48:12 +00:00
slackhq_nebula/handshake_manager.go
Wade Simmons 4eb1da0958
remove deadlock in GetOrHandshake (#1151)
We had a rare deadlock in GetOrHandshake because we kept the hostmap
lock when we do the call to StartHandshake. StartHandshake can block
while sending to the lighthouse query worker channel, and that worker
needs to be able to grab the hostmap lock to do its work. Other calls
for StartHandshake don't hold the hostmap lock so we should be able to
drop it here.

This lock was originally added with: https://github.com/slackhq/nebula/pull/954
2024-05-29 12:52:52 -04:00

661 lines
21 KiB
Go

package nebula
import (
"bytes"
"context"
"crypto/rand"
"encoding/binary"
"errors"
"net"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/udp"
)
// Defaults used to populate defaultHandshakeConfig.
const (
	// DefaultHandshakeTryInterval is the default base interval between handshake attempts.
	DefaultHandshakeTryInterval = 100 * time.Millisecond
	// DefaultHandshakeRetries is the default number of attempts made before a handshake times out.
	DefaultHandshakeRetries = 10
	// DefaultHandshakeTriggerBuffer is the default capacity of the handshake trigger channel.
	DefaultHandshakeTriggerBuffer = 64
	// DefaultUseRelays is the default for whether relays are tried while handshaking.
	DefaultUseRelays = true
)
// defaultHandshakeConfig is a HandshakeConfig filled in from the package
// defaults. messageMetrics is left at its zero value (nil).
var defaultHandshakeConfig = HandshakeConfig{
	tryInterval:   DefaultHandshakeTryInterval,
	retries:       DefaultHandshakeRetries,
	triggerBuffer: DefaultHandshakeTriggerBuffer,
	useRelays:     DefaultUseRelays,
}
// HandshakeConfig holds the tunables used by a HandshakeManager.
type HandshakeConfig struct {
// tryInterval is the tick period of the manager's Run loop and the base unit
// of the linear backoff applied between outbound attempts.
tryInterval time.Duration
// retries is the number of attempts made before a pending handshake is timed out.
retries int
// triggerBuffer is the capacity of the manager's trigger channel.
triggerBuffer int
// useRelays enables attempting the handshake through known relays.
useRelays bool
// messageMetrics is copied into the manager by NewHandshakeManager and used
// to count transmitted handshake messages.
messageMetrics *MessageMetrics
}
// HandshakeManager tracks pending (not yet completed) handshakes and drives
// their outbound attempts.
type HandshakeManager struct {
// Mutex for interacting with the vpnIps and indexes maps
sync.RWMutex
// vpnIps holds the pending handshakes keyed by the remote vpn ip
vpnIps map[iputil.VpnIp]*HandshakeHostInfo
// indexes holds the pending handshakes keyed by their allocated local index id
indexes map[uint32]*HandshakeHostInfo
// mainHostMap is the hostmap that completed handshakes are added to
mainHostMap *HostMap
lightHouse *LightHouse
// outside is the udp connection that handshake packets are written to
outside udp.Conn
config HandshakeConfig
// OutboundHandshakeTimer schedules the next outbound attempt for each pending vpn ip
OutboundHandshakeTimer *LockingTimerWheel[iputil.VpnIp]
messageMetrics *MessageMetrics
// metricInitiated is incremented each time StartHandshake creates a new pending handshake
metricInitiated metrics.Counter
// metricTimedOut is incremented each time a pending handshake runs out of retries
metricTimedOut metrics.Counter
// f is not set by NewHandshakeManager; presumably assigned by the caller
// after construction - TODO confirm
f *Interface
l *logrus.Logger
// can be used to trigger outbound handshake for the given vpnIp
trigger chan iputil.VpnIp
}
// HandshakeHostInfo is the bookkeeping kept for a single pending handshake.
type HandshakeHostInfo struct {
// The embedded mutex is held by handleOutbound for the duration of an attempt.
sync.Mutex
startTime time.Time // Time that we first started trying with this handshake
ready bool // Is the handshake ready
counter int // How many attempts have we made so far
lastRemotes []*udp.Addr // Remotes that we sent to during the previous attempt
packetStore []*cachedPacket // A set of packets to be transmitted once the handshake completes
hostinfo *HostInfo // The pending HostInfo this handshake is being performed for
}
// cachePacket stores a copy of packet in hh.packetStore so it can be sent
// once the handshake completes. At most 100 packets are kept; beyond that
// the packet is not stored and m.dropped is incremented instead. The outcome
// is logged at debug level.
//
// cachePacket does not take hh's mutex itself.
func (hh *HandshakeHostInfo) cachePacket(l *logrus.Logger, t header.MessageType, st header.MessageSubType, packet []byte, f packetCallback, m *cachedPacketMetrics) {
	stored := len(hh.packetStore) < 100
	if stored {
		// Copy the payload, the caller's buffer is not retained.
		buf := make([]byte, len(packet))
		copy(buf, packet)
		hh.packetStore = append(hh.packetStore, &cachedPacket{t, st, f, buf})
	} else {
		m.dropped.Inc(1)
	}

	if l.Level >= logrus.DebugLevel {
		hh.hostinfo.logger(l).
			WithField("length", len(hh.packetStore)).
			WithField("stored", stored).
			Debugf("Packet store")
	}
}
// NewHandshakeManager builds a HandshakeManager with empty pending maps, a
// trigger channel of capacity config.triggerBuffer and a timer wheel sized
// from config.tryInterval and config.retries. The interface field f is left
// unset.
func NewHandshakeManager(l *logrus.Logger, mainHostMap *HostMap, lightHouse *LightHouse, outside udp.Conn, config HandshakeConfig) *HandshakeManager {
	trigger := make(chan iputil.VpnIp, config.triggerBuffer)
	wheel := NewLockingTimerWheel[iputil.VpnIp](config.tryInterval, hsTimeout(config.retries, config.tryInterval))
	initiated := metrics.GetOrRegisterCounter("handshake_manager.initiated", nil)
	timedOut := metrics.GetOrRegisterCounter("handshake_manager.timed_out", nil)

	hm := &HandshakeManager{
		vpnIps:                 map[iputil.VpnIp]*HandshakeHostInfo{},
		indexes:                map[uint32]*HandshakeHostInfo{},
		mainHostMap:            mainHostMap,
		lightHouse:             lightHouse,
		outside:                outside,
		config:                 config,
		trigger:                trigger,
		OutboundHandshakeTimer: wheel,
		messageMetrics:         config.messageMetrics,
		metricInitiated:        initiated,
		metricTimedOut:         timedOut,
		l:                      l,
	}
	return hm
}
// Run is the manager's event loop. It handles triggered handshakes from the
// trigger channel and advances the outbound timer wheel every
// config.tryInterval, until ctx is done.
func (c *HandshakeManager) Run(ctx context.Context) {
	ticker := time.NewTicker(c.config.tryInterval)
	defer ticker.Stop()

	for {
		select {
		case <-ctx.Done():
			return

		case ip := <-c.trigger:
			c.handleOutbound(ip, true)

		case tick := <-ticker.C:
			c.NextOutboundHandshakeTimerTick(tick)
		}
	}
}
// HandleIncoming dispatches a received handshake packet to the matching IX
// handshake stage. Packets with a non-nil addr that the lighthouse remote
// allow list rejects are dropped. Subtypes other than HandshakeIXPSK0 and
// message counters other than 1 and 2 are ignored.
func (hm *HandshakeManager) HandleIncoming(addr *udp.Addr, via *ViaSender, packet []byte, h *header.H) {
	// First remote allow list check before we know the vpnIp
	if addr != nil && !hm.lightHouse.GetRemoteAllowList().AllowUnknownVpnIp(addr.IP) {
		hm.l.WithField("udpAddr", addr).Debug("lighthouse.remote_allow_list denied incoming handshake")
		return
	}

	if h.Subtype != header.HandshakeIXPSK0 {
		return
	}

	switch h.MessageCounter {
	case 1:
		ixHandshakeStage1(hm.f, addr, via, packet, h)

	case 2:
		pending := hm.queryIndex(h.RemoteIndex)
		// Stage 2 reports whether the pending handshake should be torn down.
		if ixHandshakeStage2(hm.f, addr, via, pending, packet, h) && pending != nil {
			hm.DeleteHostInfo(pending.hostinfo)
		}
	}
}
// NextOutboundHandshakeTimerTick advances the outbound timer wheel to now and
// runs an outbound attempt for every vpn ip the wheel purges.
func (c *HandshakeManager) NextOutboundHandshakeTimerTick(now time.Time) {
	wheel := c.OutboundHandshakeTimer
	wheel.Advance(now)

	for vpnIp, ok := wheel.Purge(); ok; vpnIp, ok = wheel.Purge() {
		c.handleOutbound(vpnIp, false)
	}
}
// handleOutbound performs one outbound handshake attempt for vpnIp. It is a
// no-op when there is no pending handshake for vpnIp.
//
// lighthouseTriggered is true when the attempt came in through the trigger
// channel rather than from the timer wheel. Such attempts are skipped unless
// the list of known remotes changed, and they never re-add vpnIp to the timer
// wheel.
//
// Each attempt, while holding the pending handshake's mutex:
//   - gives up and deletes the pending handshake once config.retries attempts
//     have been made,
//   - calls ixHandshakeStage0 when the handshake is not ready yet, and retries
//     later if that reports failure,
//   - writes the handshake packet to every known remote address,
//   - when relays are enabled, sends the packet via established relays, or
//     (re-)sends a CreateRelayRequest to relays that are not established yet,
//   - re-adds vpnIp to the timer wheel with a delay of tryInterval multiplied
//     by the attempt count (linear backoff).
func (hm *HandshakeManager) handleOutbound(vpnIp iputil.VpnIp, lighthouseTriggered bool) {
	hh := hm.queryVpnIp(vpnIp)
	if hh == nil {
		return
	}
	hh.Lock()
	defer hh.Unlock()

	hostinfo := hh.hostinfo
	// If we are out of time, clean up
	if hh.counter >= hm.config.retries {
		hh.hostinfo.logger(hm.l).WithField("udpAddrs", hh.hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())).
			WithField("initiatorIndex", hh.hostinfo.localIndexId).
			WithField("remoteIndex", hh.hostinfo.remoteIndexId).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			WithField("durationNs", time.Since(hh.startTime).Nanoseconds()).
			Info("Handshake timed out")
		hm.metricTimedOut.Inc(1)
		hm.DeleteHostInfo(hostinfo)
		return
	}

	// Increment the counter to increase our delay, linear backoff
	hh.counter++

	// Check if we have a handshake packet to transmit yet
	if !hh.ready {
		if !ixHandshakeStage0(hm.f, hh) {
			hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
			return
		}
	}

	// Get a remotes object if we don't already have one.
	// This is mainly to protect us as this should never be the case
	// NB ^ This comment doesn't jive. It's how the thing gets initialized.
	// It's the common path. Should it update every time, in case a future LH query/queries give us more info?
	if hostinfo.remotes == nil {
		hostinfo.remotes = hm.lightHouse.QueryCache(vpnIp)
	}

	remotes := hostinfo.remotes.CopyAddrs(hm.mainHostMap.GetPreferredRanges())
	remotesHaveChanged := !udp.AddrSlice(remotes).Equal(hh.lastRemotes)

	// We only care about a lighthouse trigger if we have new remotes to send to.
	// This is a very specific optimization for a fast lighthouse reply.
	if lighthouseTriggered && !remotesHaveChanged {
		// If we didn't return here a lighthouse could cause us to aggressively send handshakes
		return
	}

	hh.lastRemotes = remotes

	// TODO: this will generate a load of queries for hosts with only 1 ip
	// (such as ones registered to the lighthouse with only a private IP)
	// So we only do it one time after attempting 5 handshakes already.
	if len(remotes) <= 1 && hh.counter == 5 {
		// If we only have 1 remote it is highly likely our query raced with the other host registered within the lighthouse
		// Our vpnIp here has a tunnel with a lighthouse but has yet to send a host update packet there so we only know about
		// the learned public ip for them. Query again to short circuit the promotion counter
		hm.lightHouse.QueryServer(vpnIp)
	}

	// Send the handshake to all known ips, stage 2 takes care of assigning the hostinfo.remote based on the first to reply
	var sentTo []*udp.Addr
	hostinfo.remotes.ForEach(hm.mainHostMap.GetPreferredRanges(), func(addr *udp.Addr, _ bool) {
		hm.messageMetrics.Tx(header.Handshake, header.MessageSubType(hostinfo.HandshakePacket[0][1]), 1)
		err := hm.outside.WriteTo(hostinfo.HandshakePacket[0], addr)
		if err != nil {
			hostinfo.logger(hm.l).WithField("udpAddr", addr).
				WithField("initiatorIndex", hostinfo.localIndexId).
				WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
				WithError(err).Error("Failed to send handshake message")
		} else {
			sentTo = append(sentTo, addr)
		}
	})

	// Don't be too noisy or confusing if we fail to send a handshake - if we don't get through we'll eventually log a timeout,
	// so only log when the list of remotes has changed
	if remotesHaveChanged {
		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
			WithField("initiatorIndex", hostinfo.localIndexId).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Info("Handshake message sent")
	} else if hm.l.IsLevelEnabled(logrus.DebugLevel) {
		hostinfo.logger(hm.l).WithField("udpAddrs", sentTo).
			WithField("initiatorIndex", hostinfo.localIndexId).
			WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
			Debug("Handshake message sent")
	}

	if hm.config.useRelays && len(hostinfo.remotes.relays) > 0 {
		hostinfo.logger(hm.l).WithField("relays", hostinfo.remotes.relays).Info("Attempt to relay through hosts")
		// Send a RelayRequest to all known Relay IP's
		for _, relay := range hostinfo.remotes.relays {
			// Don't relay to myself, and don't relay through the host I'm trying to connect to
			if *relay == vpnIp || *relay == hm.lightHouse.myVpnIp {
				continue
			}
			relayHostInfo := hm.mainHostMap.QueryVpnIp(*relay)
			if relayHostInfo == nil || relayHostInfo.remote == nil {
				hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Establish tunnel to relay target")
				hm.f.Handshake(*relay)
				continue
			}
			// Check the relay HostInfo to see if we already established a relay through it
			if existingRelay, ok := relayHostInfo.relayState.QueryRelayForByIp(vpnIp); ok {
				switch existingRelay.State {
				case Established:
					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Send handshake via relay")
					hm.f.SendVia(relayHostInfo, existingRelay, hostinfo.HandshakePacket[0], make([]byte, 12), make([]byte, mtu), false)
				case Requested:
					hostinfo.logger(hm.l).WithField("relay", relay.String()).Info("Re-send CreateRelay request")
					// Re-send the CreateRelay request, in case the previous one was lost.
					req := NebulaControl{
						Type:                NebulaControl_CreateRelayRequest,
						InitiatorRelayIndex: existingRelay.LocalIndex,
						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
						RelayToIp:           uint32(vpnIp),
					}
					msg, err := req.Marshal()
					if err != nil {
						hostinfo.logger(hm.l).
							WithError(err).
							Error("Failed to marshal Control message to create relay")
					} else {
						// This must send over the hostinfo, not over hm.Hosts[ip]
						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
						hm.l.WithFields(logrus.Fields{
							"relayFrom":           hm.lightHouse.myVpnIp,
							"relayTo":             vpnIp,
							"initiatorRelayIndex": existingRelay.LocalIndex,
							"relay":               *relay}).
							Info("send CreateRelayRequest")
					}
				default:
					hostinfo.logger(hm.l).
						WithField("vpnIp", vpnIp).
						WithField("state", existingRelay.State).
						WithField("relay", relayHostInfo.vpnIp).
						Errorf("Relay unexpected state")
				}
			} else {
				// No relays exist or requested yet.
				if relayHostInfo.remote != nil {
					idx, err := AddRelay(hm.l, relayHostInfo, hm.mainHostMap, vpnIp, nil, TerminalType, Requested)
					if err != nil {
						hostinfo.logger(hm.l).WithField("relay", relay.String()).WithError(err).Info("Failed to add relay to hostmap")
						// Without a relay entry the index is meaningless, so do not
						// request a relay with it. A later attempt will try again.
						continue
					}

					req := NebulaControl{
						Type:                NebulaControl_CreateRelayRequest,
						InitiatorRelayIndex: idx,
						RelayFromIp:         uint32(hm.lightHouse.myVpnIp),
						RelayToIp:           uint32(vpnIp),
					}
					msg, err := req.Marshal()
					if err != nil {
						hostinfo.logger(hm.l).
							WithError(err).
							Error("Failed to marshal Control message to create relay")
					} else {
						hm.f.SendMessageToHostInfo(header.Control, 0, relayHostInfo, msg, make([]byte, 12), make([]byte, mtu))
						hm.l.WithFields(logrus.Fields{
							"relayFrom":           hm.lightHouse.myVpnIp,
							"relayTo":             vpnIp,
							"initiatorRelayIndex": idx,
							"relay":               *relay}).
							Info("send CreateRelayRequest")
					}
				}
			}
		}
	}

	// If a lighthouse triggered this attempt then we are still in the timer wheel and do not need to re-add
	if !lighthouseTriggered {
		hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval*time.Duration(hh.counter))
	}
}
// GetOrHandshake returns the hostinfo for vpnIp from the main hostmap when a
// tunnel already exists, otherwise it starts (or joins) a handshake via
// StartHandshake and returns the pending hostinfo.
// The 2nd return value is true only when the hostinfo came from the main
// hostmap, i.e. it is ready to transmit traffic.
func (hm *HandshakeManager) GetOrHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) (*HostInfo, bool) {
	hm.mainHostMap.RLock()
	existing, found := hm.mainHostMap.Hosts[vpnIp]
	hm.mainHostMap.RUnlock()

	if !found {
		// The hostmap lock is released before this call on purpose.
		return hm.StartHandshake(vpnIp, cacheCb), false
	}

	// Do not attempt promotion if you are a lighthouse
	if !hm.lightHouse.amLighthouse {
		existing.TryPromoteBest(hm.mainHostMap.GetPreferredRanges(), hm.f)
	}
	return existing, true
}
// StartHandshake makes sure a handshake is being attempted for vpnIp and
// returns its pending hostinfo.
//
// If a handshake is already pending, it is reused. Otherwise a new pending
// entry is created, scheduled on the timer wheel and counted in
// metricInitiated, and the lighthouse is queried for the host after the
// manager lock has been released.
//
// cacheCb, when non-nil, is invoked with the pending handshake while the
// manager lock is held.
func (hm *HandshakeManager) StartHandshake(vpnIp iputil.VpnIp, cacheCb func(*HandshakeHostInfo)) *HostInfo {
	hm.Lock()

	if pending, found := hm.vpnIps[vpnIp]; found {
		// We are already trying to handshake with this vpn ip
		if cacheCb != nil {
			cacheCb(pending)
		}
		hm.Unlock()
		return pending.hostinfo
	}

	newHostinfo := &HostInfo{
		vpnIp:           vpnIp,
		HandshakePacket: make(map[uint8][]byte),
		relayState: RelayState{
			relays:        map[iputil.VpnIp]struct{}{},
			relayForByIp:  map[iputil.VpnIp]*Relay{},
			relayForByIdx: map[uint32]*Relay{},
		},
	}
	pending := &HandshakeHostInfo{
		hostinfo:  newHostinfo,
		startTime: time.Now(),
	}

	hm.vpnIps[vpnIp] = pending
	hm.metricInitiated.Inc(1)
	hm.OutboundHandshakeTimer.Add(vpnIp, hm.config.tryInterval)

	if cacheCb != nil {
		cacheCb(pending)
	}

	// If this is a static host, we don't need to wait for the HostQueryReply
	// and can trigger the handshake right now. Otherwise add any calculated
	// remotes and trigger early if one was found.
	_, isStatic := hm.lightHouse.GetStaticHostList()[vpnIp]
	if isStatic || hm.lightHouse.addCalculatedRemotes(vpnIp) {
		// Non-blocking send, the trigger is dropped when the channel is full.
		select {
		case hm.trigger <- vpnIp:
		default:
		}
	}

	hm.Unlock()

	hm.lightHouse.QueryServer(vpnIp)
	return newHostinfo
}
// Errors returned by HandshakeManager.CheckAndComplete.
var (
	// ErrExistingHostInfo means a tunnel for the vpn ip already exists and
	// the new handshake was not newer than it.
	ErrExistingHostInfo = errors.New("existing hostinfo")
	// ErrAlreadySeen means the exact same handshake packet was already processed.
	ErrAlreadySeen = errors.New("already seen")
	// ErrLocalIndexCollision means the local index id is already in use by
	// another hostinfo.
	ErrLocalIndexCollision = errors.New("local index collision")
)
// CheckAndComplete checks for any conflicts in the main and pending hostmap
// before adding hostinfo to main. If err is nil, it was added and the
// returned hostinfo is the tunnel that previously existed for this vpn ip
// (nil if there was none). Otherwise err will be:
//
// ErrAlreadySeen if we already have an entry in the hostmap that has seen the
// exact same handshake packet. The matching hostinfo is returned.
//
// ErrExistingHostInfo if we already have an entry in the hostmap for this
// VpnIp and the new handshake was older than the one we currently have. The
// existing hostinfo is returned.
//
// ErrLocalIndexCollision if we already have an entry in the main or pending
// hostmap for the hostinfo.localIndexId. The colliding hostinfo is returned.
//
// handshakePacket selects which stored handshake packet is compared when
// looking for an already seen handshake.
func (c *HandshakeManager) CheckAndComplete(hostinfo *HostInfo, handshakePacket uint8, f *Interface) (*HostInfo, error) {
	c.mainHostMap.Lock()
	defer c.mainHostMap.Unlock()
	c.Lock()
	defer c.Unlock()

	// Check if we already have a tunnel with this vpn ip
	existingHostInfo, found := c.mainHostMap.Hosts[hostinfo.vpnIp]
	if found && existingHostInfo != nil {
		// Walk the chain of hostinfos for this vpn ip.
		for testHostInfo := existingHostInfo; testHostInfo != nil; testHostInfo = testHostInfo.next {
			// Is it just a delayed handshake packet?
			if bytes.Equal(hostinfo.HandshakePacket[handshakePacket], testHostInfo.HandshakePacket[handshakePacket]) {
				return testHostInfo, ErrAlreadySeen
			}
		}

		// Is this a newer handshake?
		if existingHostInfo.lastHandshakeTime >= hostinfo.lastHandshakeTime && !existingHostInfo.ConnectionState.initiator {
			return existingHostInfo, ErrExistingHostInfo
		}

		existingHostInfo.logger(c.l).Info("Taking new handshake")
	}

	existingIndex, found := c.mainHostMap.Indexes[hostinfo.localIndexId]
	if found {
		// We have a collision, but for a different hostinfo
		return existingIndex, ErrLocalIndexCollision
	}

	existingPendingIndex, found := c.indexes[hostinfo.localIndexId]
	if found && existingPendingIndex.hostinfo != hostinfo {
		// We have a collision, but for a different hostinfo. Return the
		// pending hostinfo that owns the index; existingIndex is always nil
		// at this point and must not be handed to the caller.
		return existingPendingIndex.hostinfo, ErrLocalIndexCollision
	}

	existingRemoteIndex, found := c.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]
	if found && existingRemoteIndex != nil && existingRemoteIndex.vpnIp != hostinfo.vpnIp {
		// We have a collision, but this can happen since we can't control
		// the remote ID. Just log about the situation as a note.
		hostinfo.logger(c.l).
			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", existingRemoteIndex.vpnIp).
			Info("New host shadows existing host remoteIndex")
	}

	c.mainHostMap.unlockedAddHostInfo(hostinfo, f)
	return existingHostInfo, nil
}
// Complete is a simpler version of CheckAndComplete for when we already know
// we won't have a localIndexId collision because we already have an entry in
// the pending hostmap. It moves hostinfo from the pending hostmap into the
// main hostmap and returns nothing.
func (hm *HandshakeManager) Complete(hostinfo *HostInfo, f *Interface) {
	hm.mainHostMap.Lock()
	defer hm.mainHostMap.Unlock()
	hm.Lock()
	defer hm.Unlock()

	if shadowed, ok := hm.mainHostMap.RemoteIndexes[hostinfo.remoteIndexId]; ok && shadowed != nil {
		// We have a collision, but this can happen since we can't control
		// the remote ID. Just log about the situation as a note.
		hostinfo.logger(hm.l).
			WithField("remoteIndex", hostinfo.remoteIndexId).WithField("collision", shadowed.vpnIp).
			Info("New host shadows existing host remoteIndex")
	}

	// Remove from the pending hostmap first to avoid undoing work after
	// adding to the main hostmap.
	hm.unlockedDeleteHostInfo(hostinfo)
	hm.mainHostMap.unlockedAddHostInfo(hostinfo, f)
}
// allocateIndex picks a localIndexId for hh.hostinfo that is unused in both
// the main hostmap and the pending hostmap, stores it on the hostinfo and
// registers hh under it in the pending index map. It gives up with an error
// after 32 candidates, or as soon as index generation itself fails.
func (hm *HandshakeManager) allocateIndex(hh *HandshakeHostInfo) error {
	hm.mainHostMap.RLock()
	defer hm.mainHostMap.RUnlock()
	hm.Lock()
	defer hm.Unlock()

	for attempt := 0; attempt < 32; attempt++ {
		candidate, err := generateIndex(hm.l)
		if err != nil {
			return err
		}

		if _, taken := hm.indexes[candidate]; taken {
			continue
		}
		if _, taken := hm.mainHostMap.Indexes[candidate]; taken {
			continue
		}

		hh.hostinfo.localIndexId = candidate
		hm.indexes[candidate] = hh
		return nil
	}

	return errors.New("failed to generate unique localIndexId")
}
// DeleteHostInfo removes hostinfo from the pending hostmap while holding the
// manager's write lock.
func (c *HandshakeManager) DeleteHostInfo(hostinfo *HostInfo) {
	c.Lock()
	defer c.Unlock()

	c.unlockedDeleteHostInfo(hostinfo)
}
// unlockedDeleteHostInfo removes hostinfo from the pending vpnIps and indexes
// maps, keyed by its vpnIp and localIndexId. The caller must already hold the
// manager's write lock.
//
// A map that becomes empty is replaced with a fresh one. Each map is checked
// against its own length so that entries still present in one map are never
// discarded because the other map happens to be empty.
func (c *HandshakeManager) unlockedDeleteHostInfo(hostinfo *HostInfo) {
	delete(c.vpnIps, hostinfo.vpnIp)
	if len(c.vpnIps) == 0 {
		c.vpnIps = map[iputil.VpnIp]*HandshakeHostInfo{}
	}

	delete(c.indexes, hostinfo.localIndexId)
	if len(c.indexes) == 0 {
		c.indexes = map[uint32]*HandshakeHostInfo{}
	}

	if c.l.Level >= logrus.DebugLevel {
		c.l.WithField("hostMap", m{"mapTotalSize": len(c.vpnIps),
			"vpnIp": hostinfo.vpnIp, "indexNumber": hostinfo.localIndexId, "remoteIndexNumber": hostinfo.remoteIndexId}).
			Debug("Pending hostmap hostInfo deleted")
	}
}
// QueryVpnIp returns the pending hostinfo for vpnIp, or nil when no handshake
// is pending for it.
func (hm *HandshakeManager) QueryVpnIp(vpnIp iputil.VpnIp) *HostInfo {
	if pending := hm.queryVpnIp(vpnIp); pending != nil {
		return pending.hostinfo
	}
	return nil
}
// queryVpnIp looks up the pending handshake for vpnIp under the read lock.
// It returns nil when there is none.
func (hm *HandshakeManager) queryVpnIp(vpnIp iputil.VpnIp) *HandshakeHostInfo {
	hm.RLock()
	pending := hm.vpnIps[vpnIp]
	hm.RUnlock()
	return pending
}
// QueryIndex returns the pending hostinfo registered under the local index
// id, or nil when there is none.
func (hm *HandshakeManager) QueryIndex(index uint32) *HostInfo {
	if pending := hm.queryIndex(index); pending != nil {
		return pending.hostinfo
	}
	return nil
}
// queryIndex looks up the pending handshake registered under the local index
// id while holding the read lock. It returns nil when there is none.
func (hm *HandshakeManager) queryIndex(index uint32) *HandshakeHostInfo {
	hm.RLock()
	pending := hm.indexes[index]
	hm.RUnlock()
	return pending
}
// GetPreferredRanges returns the preferred ranges of the main hostmap.
func (c *HandshakeManager) GetPreferredRanges() []*net.IPNet {
	ranges := c.mainHostMap.GetPreferredRanges()
	return ranges
}
// ForEachVpnIp calls f with the hostinfo of every pending handshake in the
// vpnIps map. The read lock is held for the whole iteration.
func (c *HandshakeManager) ForEachVpnIp(f controlEach) {
	c.RLock()
	defer c.RUnlock()

	for _, pending := range c.vpnIps {
		f(pending.hostinfo)
	}
}
// ForEachIndex calls f with the hostinfo of every pending handshake in the
// indexes map. The read lock is held for the whole iteration.
func (c *HandshakeManager) ForEachIndex(f controlEach) {
	c.RLock()
	defer c.RUnlock()

	for _, pending := range c.indexes {
		f(pending.hostinfo)
	}
}
// EmitStats publishes the sizes of the pending vpnIps and indexes maps as
// gauges and then emits the main hostmap's stats.
func (c *HandshakeManager) EmitStats() {
	// Only hold the lock long enough to read the sizes.
	c.RLock()
	pendingHosts, pendingIndexes := len(c.vpnIps), len(c.indexes)
	c.RUnlock()

	metrics.GetOrRegisterGauge("hostmap.pending.hosts", nil).Update(int64(pendingHosts))
	metrics.GetOrRegisterGauge("hostmap.pending.indexes", nil).Update(int64(pendingIndexes))
	c.mainHostMap.EmitStats()
}
// Utility functions below
// generateIndex returns a random non-zero index read from crypto/rand. A read
// failure is logged on l and returned. The generated index is logged at debug
// level.
func generateIndex(l *logrus.Logger) (uint32, error) {
	var (
		raw   [4]byte
		index uint32
	)

	// Let zero mean we don't know the ID, so don't generate zero
	for index == 0 {
		if _, err := rand.Read(raw[:]); err != nil {
			l.Errorln(err)
			return 0, err
		}
		index = binary.BigEndian.Uint32(raw[:])
	}

	if l.Level >= logrus.DebugLevel {
		l.WithField("index", index).
			Debug("Generated index")
	}
	return index, nil
}
// hsTimeout returns the total time spent waiting across tries handshake
// attempts when the delay grows linearly, i.e.
// interval * (1 + 2 + ... + tries) = interval * tries*(tries+1)/2.
//
// tries*(tries+1) is always even, so dividing before multiplying by interval
// is exact for odd and even tries alike. The multiplication is done on
// time.Duration (64 bit) so the result does not overflow where int is 32 bit.
// The result is only meaningful for tries >= 0.
func hsTimeout(tries int, interval time.Duration) time.Duration {
	return time.Duration(tries*(tries+1)/2) * interval
}