0
0
Fork 0
mirror of https://github.com/slackhq/nebula.git synced 2025-01-27 10:19:04 +00:00
slackhq_nebula/hostmap.go
Wade Simmons b37a91cfbc
add meta packet statistics (#230)
This change add more metrics around "meta" (non "message" type packets).
For lighthouse packets, we also record statistics around the specific
lighthouse meta type.

We don't keep statistics for the "message" type so that we don't slow
down the fast path (and you can just look at metrics on the tun
interface to find that information).
2020-06-26 13:45:48 -04:00

808 lines
20 KiB
Go

package nebula
import (
"encoding/json"
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/cert"
)
//const ProbeLen = 100
const PromoteEvery = 1000
const MaxRemotes = 10
// How long we should prevent roaming back to the previous IP.
// This helps prevent flapping due to packets already in flight
const RoamingSupressSeconds = 2
type HostMap struct {
sync.RWMutex //Because we concurrently read and write to our maps
name string
Indexes map[uint32]*HostInfo
Hosts map[uint32]*HostInfo
preferredRanges []*net.IPNet
vpnCIDR *net.IPNet
defaultRoute uint32
unsafeRoutes *CIDRTree
metricsEnabled bool
}
type HostInfo struct {
remote *udpAddr
Remotes []*HostInfoDest
promoteCounter uint32
ConnectionState *ConnectionState
handshakeStart time.Time
HandshakeReady bool
HandshakeCounter int
HandshakeComplete bool
HandshakePacket map[uint8][]byte
packetStore []*cachedPacket
remoteIndexId uint32
localIndexId uint32
hostId uint32
recvError int
remoteCidr *CIDRTree
lastRoam time.Time
lastRoamRemote *udpAddr
}
type cachedPacket struct {
messageType NebulaMessageType
messageSubType NebulaMessageSubType
callback packetCallback
packet []byte
}
type packetCallback func(t NebulaMessageType, st NebulaMessageSubType, h *HostInfo, p, nb, out []byte)
type HostInfoDest struct {
active bool
addr *udpAddr
//probes [ProbeLen]bool
probeCounter int
}
type Probe struct {
Addr *net.UDPAddr
Counter int
}
func NewHostMap(name string, vpnCIDR *net.IPNet, preferredRanges []*net.IPNet) *HostMap {
h := map[uint32]*HostInfo{}
i := map[uint32]*HostInfo{}
m := HostMap{
name: name,
Indexes: i,
Hosts: h,
preferredRanges: preferredRanges,
vpnCIDR: vpnCIDR,
defaultRoute: 0,
unsafeRoutes: NewCIDRTree(),
}
return &m
}
// UpdateStats takes a name and reports host and index counts to the stats collection system
func (hm *HostMap) EmitStats(name string) {
hm.RLock()
hostLen := len(hm.Hosts)
indexLen := len(hm.Indexes)
hm.RUnlock()
metrics.GetOrRegisterGauge("hostmap."+name+".hosts", nil).Update(int64(hostLen))
metrics.GetOrRegisterGauge("hostmap."+name+".indexes", nil).Update(int64(indexLen))
}
func (hm *HostMap) GetIndexByVpnIP(vpnIP uint32) (uint32, error) {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
index := i.localIndexId
hm.RUnlock()
return index, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) GetVpnIPByIndex(index uint32) (uint32, error) {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
vpnIP := i.hostId
hm.RUnlock()
return vpnIP, nil
}
hm.RUnlock()
return 0, errors.New("vpn IP not found")
}
func (hm *HostMap) Add(ip uint32, hostinfo *HostInfo) {
hm.Lock()
hm.Hosts[ip] = hostinfo
hm.Unlock()
}
func (hm *HostMap) AddVpnIP(vpnIP uint32) *HostInfo {
h := &HostInfo{}
hm.RLock()
if _, ok := hm.Hosts[vpnIP]; !ok {
hm.RUnlock()
h = &HostInfo{
Remotes: []*HostInfoDest{},
promoteCounter: 0,
hostId: vpnIP,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Lock()
hm.Hosts[vpnIP] = h
hm.Unlock()
return h
} else {
h = hm.Hosts[vpnIP]
hm.RUnlock()
return h
}
}
func (hm *HostMap) DeleteVpnIP(vpnIP uint32) {
hm.Lock()
delete(hm.Hosts, vpnIP)
if len(hm.Hosts) == 0 {
hm.Hosts = map[uint32]*HostInfo{}
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap vpnIp deleted")
}
}
func (hm *HostMap) AddIndex(index uint32, ci *ConnectionState) (*HostInfo, error) {
hm.Lock()
if _, ok := hm.Indexes[index]; !ok {
h := &HostInfo{
ConnectionState: ci,
Remotes: []*HostInfoDest{},
localIndexId: index,
HandshakePacket: make(map[uint8][]byte, 0),
}
hm.Indexes[index] = h
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": false, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
hm.Unlock()
return h, nil
}
hm.Unlock()
return nil, fmt.Errorf("refusing to overwrite existing index: %d", index)
}
func (hm *HostMap) AddIndexHostInfo(index uint32, h *HostInfo) {
hm.Lock()
h.localIndexId = index
hm.Indexes[index] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap index added")
}
}
func (hm *HostMap) AddVpnIPHostInfo(vpnIP uint32, h *HostInfo) {
hm.Lock()
h.hostId = vpnIP
hm.Hosts[vpnIP] = h
hm.Unlock()
if l.Level > logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIP), "mapTotalSize": len(hm.Hosts),
"hostinfo": m{"existing": true, "localIndexId": h.localIndexId, "hostId": IntIp(h.hostId)}}).
Debug("Hostmap vpnIp added")
}
}
func (hm *HostMap) DeleteIndex(index uint32) {
hm.Lock()
delete(hm.Indexes, index)
if len(hm.Indexes) == 0 {
hm.Indexes = map[uint32]*HostInfo{}
}
hm.Unlock()
if l.Level >= logrus.DebugLevel {
l.WithField("hostMap", m{"mapName": hm.name, "indexNumber": index, "mapTotalSize": len(hm.Indexes)}).
Debug("Hostmap index deleted")
}
}
func (hm *HostMap) QueryIndex(index uint32) (*HostInfo, error) {
//TODO: we probably just want ot return bool instead of error, or at least a static error
hm.RLock()
if h, ok := hm.Indexes[index]; ok {
hm.RUnlock()
return h, nil
} else {
hm.RUnlock()
return nil, errors.New("unable to find index")
}
}
// This function needs to range because we don't keep a map of remote indexes.
func (hm *HostMap) QueryReverseIndex(index uint32) (*HostInfo, error) {
hm.RLock()
for _, h := range hm.Indexes {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
for _, h := range hm.Hosts {
if h.ConnectionState != nil && h.remoteIndexId == index {
hm.RUnlock()
return h, nil
}
}
hm.RUnlock()
return nil, fmt.Errorf("unable to find reverse index or connectionstate nil in %s hostmap", hm.name)
}
func (hm *HostMap) AddRemote(vpnIp uint32, remote *udpAddr) *HostInfo {
hm.Lock()
i, v := hm.Hosts[vpnIp]
if v {
i.AddRemote(*remote)
} else {
i = &HostInfo{
Remotes: []*HostInfoDest{NewHostInfoDest(remote)},
promoteCounter: 0,
hostId: vpnIp,
HandshakePacket: make(map[uint8][]byte, 0),
}
i.remote = i.Remotes[0].addr
hm.Hosts[vpnIp] = i
l.WithField("hostMap", m{"mapName": hm.name, "vpnIp": IntIp(vpnIp), "udpAddr": remote, "mapTotalSize": len(hm.Hosts)}).
Debug("Hostmap remote ip added")
}
i.ForcePromoteBest(hm.preferredRanges)
hm.Unlock()
return i
}
func (hm *HostMap) QueryVpnIP(vpnIp uint32) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, nil)
}
// PromoteBestQueryVpnIP will attempt to lazily switch to the best remote every
// `PromoteEvery` calls to this function for a given host.
func (hm *HostMap) PromoteBestQueryVpnIP(vpnIp uint32, ifce *Interface) (*HostInfo, error) {
return hm.queryVpnIP(vpnIp, ifce)
}
func (hm *HostMap) queryVpnIP(vpnIp uint32, promoteIfce *Interface) (*HostInfo, error) {
hm.RLock()
if h, ok := hm.Hosts[vpnIp]; ok {
if promoteIfce != nil {
h.TryPromoteBest(hm.preferredRanges, promoteIfce)
}
//fmt.Println(h.remote)
hm.RUnlock()
return h, nil
} else {
//return &net.UDPAddr{}, nil, errors.New("Unable to find host")
hm.RUnlock()
/*
if lightHouse != nil {
lightHouse.Query(vpnIp)
return nil, errors.New("Unable to find host")
}
*/
return nil, errors.New("unable to find host")
}
}
func (hm *HostMap) queryUnsafeRoute(ip uint32) uint32 {
r := hm.unsafeRoutes.MostSpecificContains(ip)
if r != nil {
return r.(uint32)
} else {
return 0
}
}
func (hm *HostMap) CheckHandshakeCompleteIP(vpnIP uint32) bool {
hm.RLock()
if i, ok := hm.Hosts[vpnIP]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
}
hm.RUnlock()
return false
}
func (hm *HostMap) CheckHandshakeCompleteIndex(index uint32) bool {
hm.RLock()
if i, ok := hm.Indexes[index]; ok {
if i == nil {
hm.RUnlock()
return false
}
complete := i.HandshakeComplete
hm.RUnlock()
return complete
}
hm.RUnlock()
return false
}
func (hm *HostMap) ClearRemotes(vpnIP uint32) {
hm.Lock()
i := hm.Hosts[vpnIP]
if i == nil {
hm.Unlock()
return
}
i.remote = nil
i.Remotes = nil
hm.Unlock()
}
func (hm *HostMap) SetDefaultRoute(ip uint32) {
hm.defaultRoute = ip
}
func (hm *HostMap) PunchList() []*udpAddr {
var list []*udpAddr
hm.RLock()
for _, v := range hm.Hosts {
for _, r := range v.Remotes {
list = append(list, r.addr)
}
// if h, ok := hm.Hosts[vpnIp]; ok {
// hm.Hosts[vpnIp].PromoteBest(hm.preferredRanges, false)
//fmt.Println(h.remote)
// }
}
hm.RUnlock()
return list
}
func (hm *HostMap) Punchy(conn *udpConn) {
var metricsTxPunchy metrics.Counter
if hm.metricsEnabled {
metricsTxPunchy = metrics.GetOrRegisterCounter("messages.tx.punchy", nil)
} else {
metricsTxPunchy = metrics.NilCounter{}
}
for {
for _, addr := range hm.PunchList() {
metricsTxPunchy.Inc(1)
conn.WriteTo([]byte{1}, addr)
}
time.Sleep(time.Second * 30)
}
}
func (hm *HostMap) addUnsafeRoutes(routes *[]route) {
for _, r := range *routes {
l.WithField("route", r.route).WithField("via", r.via).Warn("Adding UNSAFE Route")
hm.unsafeRoutes.AddCIDR(r.route, ip2int(*r.via))
}
}
func (i *HostInfo) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"remote": i.remote,
"remotes": i.Remotes,
"promote_counter": i.promoteCounter,
"connection_state": i.ConnectionState,
"handshake_start": i.handshakeStart,
"handshake_ready": i.HandshakeReady,
"handshake_counter": i.HandshakeCounter,
"handshake_complete": i.HandshakeComplete,
"handshake_packet": i.HandshakePacket,
"packet_store": i.packetStore,
"remote_index": i.remoteIndexId,
"local_index": i.localIndexId,
"host_id": int2ip(i.hostId),
"receive_errors": i.recvError,
"last_roam": i.lastRoam,
"last_roam_remote": i.lastRoamRemote,
})
}
func (i *HostInfo) BindConnectionState(cs *ConnectionState) {
i.ConnectionState = cs
}
func (i *HostInfo) TryPromoteBest(preferredRanges []*net.IPNet, ifce *Interface) {
if i.remote == nil {
i.ForcePromoteBest(preferredRanges)
return
}
i.promoteCounter++
if i.promoteCounter%PromoteEvery == 0 {
// return early if we are already on a preferred remote
rIP := udp2ip(i.remote)
for _, l := range preferredRanges {
if l.Contains(rIP) {
return
}
}
// We re-query the lighthouse periodically while sending packets, so
// check for new remotes in our local lighthouse cache
ips := ifce.lightHouse.QueryCache(i.hostId)
for _, ip := range ips {
i.AddRemote(ip)
}
best, preferred := i.getBestRemote(preferredRanges)
if preferred && !best.Equals(i.remote) {
// Try to send a test packet to that host, this should
// cause it to detect a roaming event and switch remotes
ifce.send(test, testRequest, i.ConnectionState, i, best, []byte(""), make([]byte, 12, 12), make([]byte, mtu))
}
}
}
func (i *HostInfo) ForcePromoteBest(preferredRanges []*net.IPNet) {
best, _ := i.getBestRemote(preferredRanges)
if best != nil {
i.remote = best
}
}
func (i *HostInfo) getBestRemote(preferredRanges []*net.IPNet) (best *udpAddr, preferred bool) {
if len(i.Remotes) > 0 {
for _, r := range i.Remotes {
rIP := udp2ip(r.addr)
for _, l := range preferredRanges {
if l.Contains(rIP) {
return r.addr, true
}
}
if best == nil || !PrivateIP(rIP) {
best = r.addr
}
/*
for _, r := range i.Remotes {
// Must have > 80% probe success to be considered.
//fmt.Println("GRADE:", r.addr.IP, r.Grade())
if r.Grade() > float64(.8) {
if localToMe.Contains(r.addr.IP) == true {
best = r.addr
break
//i.remote = i.Remotes[c].addr
} else {
//}
}
*/
}
return best, false
}
return nil, false
}
// rotateRemote will move remote to the next ip in the list of remote ips for this host
// This is different than PromoteBest in that what is algorithmically best may not actually work.
// Only known use case is when sending a stage 0 handshake.
// It may be better to just send stage 0 handshakes to all known ips and sort it out in the receiver.
func (i *HostInfo) rotateRemote() {
// We have 0, can't rotate
if len(i.Remotes) < 1 {
return
}
if i.remote == nil {
i.remote = i.Remotes[0].addr
return
}
// We want to look at all but the very last entry since that is handled at the end
for x := 0; x < len(i.Remotes)-1; x++ {
// Find our current position and move to the next one in the list
if i.Remotes[x].addr.Equals(i.remote) {
i.remote = i.Remotes[x+1].addr
return
}
}
// Our current position was likely the last in the list, start over at 0
i.remote = i.Remotes[0].addr
}
func (i *HostInfo) cachePacket(t NebulaMessageType, st NebulaMessageSubType, packet []byte, f packetCallback) {
//TODO: return the error so we can log with more context
if len(i.packetStore) < 100 {
tempPacket := make([]byte, len(packet))
copy(tempPacket, packet)
//l.WithField("trace", string(debug.Stack())).Error("Caching packet", tempPacket)
i.packetStore = append(i.packetStore, &cachedPacket{t, st, f, tempPacket})
i.logger().
WithField("length", len(i.packetStore)).
WithField("stored", true).
Debugf("Packet store")
} else if l.Level >= logrus.DebugLevel {
i.logger().
WithField("length", len(i.packetStore)).
WithField("stored", false).
Debugf("Packet store")
}
}
// handshakeComplete will set the connection as ready to communicate, as well as flush any stored packets
func (i *HostInfo) handshakeComplete() {
//TODO: I'm not certain the distinction between handshake complete and ConnectionState being ready matters because:
//TODO: HandshakeComplete means send stored packets and ConnectionState.ready means we are ready to send
//TODO: if the transition from HandhsakeComplete to ConnectionState.ready happens all within this function they are identical
i.ConnectionState.queueLock.Lock()
i.HandshakeComplete = true
//TODO: this should be managed by the handshake state machine to set it based on how many handshake were seen.
// Clamping it to 2 gets us out of the woods for now
*i.ConnectionState.messageCounter = 2
i.logger().Debugf("Sending %d stored packets", len(i.packetStore))
nb := make([]byte, 12, 12)
out := make([]byte, mtu)
for _, cp := range i.packetStore {
cp.callback(cp.messageType, cp.messageSubType, i, cp.packet, nb, out)
}
i.packetStore = make([]*cachedPacket, 0)
i.ConnectionState.ready = true
i.ConnectionState.queueLock.Unlock()
i.ConnectionState.certState = nil
}
func (i *HostInfo) RemoteUDPAddrs() []*udpAddr {
var addrs []*udpAddr
for _, r := range i.Remotes {
addrs = append(addrs, r.addr)
}
return addrs
}
func (i *HostInfo) GetCert() *cert.NebulaCertificate {
if i.ConnectionState != nil {
return i.ConnectionState.peerCert
}
return nil
}
func (i *HostInfo) AddRemote(r udpAddr) *udpAddr {
remote := &r
//add := true
for _, r := range i.Remotes {
if r.addr.Equals(remote) {
return r.addr
//add = false
}
}
// Trim this down if necessary
if len(i.Remotes) > MaxRemotes {
i.Remotes = i.Remotes[len(i.Remotes)-MaxRemotes:]
}
i.Remotes = append(i.Remotes, NewHostInfoDest(remote))
return remote
//l.Debugf("Added remote %s for vpn ip", remote)
}
func (i *HostInfo) SetRemote(remote udpAddr) {
i.remote = i.AddRemote(remote)
}
func (i *HostInfo) ClearRemotes() {
i.remote = nil
i.Remotes = []*HostInfoDest{}
}
func (i *HostInfo) ClearConnectionState() {
i.ConnectionState = nil
}
func (i *HostInfo) RecvErrorExceeded() bool {
if i.recvError < 3 {
i.recvError += 1
return false
}
return true
}
func (i *HostInfo) CreateRemoteCIDR(c *cert.NebulaCertificate) {
if len(c.Details.Ips) == 1 && len(c.Details.Subnets) == 0 {
// Simple case, no CIDRTree needed
return
}
remoteCidr := NewCIDRTree()
for _, ip := range c.Details.Ips {
remoteCidr.AddCIDR(&net.IPNet{IP: ip.IP, Mask: net.IPMask{255, 255, 255, 255}}, struct{}{})
}
for _, n := range c.Details.Subnets {
remoteCidr.AddCIDR(n, struct{}{})
}
i.remoteCidr = remoteCidr
}
func (i *HostInfo) logger() *logrus.Entry {
if i == nil {
return logrus.NewEntry(l)
}
li := l.WithField("vpnIp", IntIp(i.hostId))
if connState := i.ConnectionState; connState != nil {
if peerCert := connState.peerCert; peerCert != nil {
li = li.WithField("certName", peerCert.Details.Name)
}
}
return li
}
//########################
func NewHostInfoDest(addr *udpAddr) *HostInfoDest {
i := &HostInfoDest{
addr: addr,
}
return i
}
func (hid *HostInfoDest) MarshalJSON() ([]byte, error) {
return json.Marshal(m{
"active": hid.active,
"address": hid.addr,
"probe_count": hid.probeCounter,
})
}
/*
func (hm *HostMap) DebugRemotes(vpnIp uint32) string {
s := "\n"
for _, h := range hm.Hosts {
for _, r := range h.Remotes {
s += fmt.Sprintf("%s : %d ## %v\n", r.addr.IP.String(), r.addr.Port, r.probes)
}
}
return s
}
func (d *HostInfoDest) Grade() float64 {
c1 := ProbeLen
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
}
}
return float64(c1) / float64(ProbeLen)
}
func (d *HostInfoDest) Grade() (float64, float64, float64) {
c1 := ProbeLen
c2 := ProbeLen / 2
c2c := ProbeLen - ProbeLen/2
c3 := ProbeLen / 5
c3c := ProbeLen - ProbeLen/5
for n := len(d.probes) - 1; n >= 0; n-- {
if d.probes[n] == true {
c1 -= 1
if n >= c2c {
c2 -= 1
if n >= c3c {
c3 -= 1
}
}
}
//if n >= d {
}
return float64(c3) / float64(ProbeLen/5), float64(c2) / float64(ProbeLen/2), float64(c1) / float64(ProbeLen)
//return float64(c1) / float64(ProbeLen), float64(c2) / float64(ProbeLen/2), float64(c3) / float64(ProbeLen/5)
}
func (i *HostInfo) HandleReply(addr *net.UDPAddr, counter int) {
for _, r := range i.Remotes {
if r.addr.IP.Equal(addr.IP) && r.addr.Port == addr.Port {
r.ProbeReceived(counter)
}
}
}
func (i *HostInfo) Probes() []*Probe {
p := []*Probe{}
for _, d := range i.Remotes {
p = append(p, &Probe{Addr: d.addr, Counter: d.Probe()})
}
return p
}
func (d *HostInfoDest) Probe() int {
//d.probes = append(d.probes, true)
d.probeCounter++
d.probes[d.probeCounter%ProbeLen] = true
return d.probeCounter
//return d.probeCounter
}
func (d *HostInfoDest) ProbeReceived(probeCount int) {
if probeCount >= (d.probeCounter - ProbeLen) {
//fmt.Println("PROBE WORKED", probeCount)
//fmt.Println(d.addr, d.Grade())
d.probes[probeCount%ProbeLen] = false
}
}
*/
// Utility functions
func localIps(allowList *AllowList) *[]net.IP {
//FIXME: This function is pretty garbage
var ips []net.IP
ifaces, _ := net.Interfaces()
for _, i := range ifaces {
allow := allowList.AllowName(i.Name)
l.WithField("interfaceName", i.Name).WithField("allow", allow).Debug("localAllowList.AllowName")
if !allow {
continue
}
addrs, _ := i.Addrs()
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
//continue
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip.To4() != nil && ip.IsLoopback() == false {
allow := allowList.Allow(ip2int(ip))
l.WithField("localIp", ip).WithField("allow", allow).Debug("localAllowList.Allow")
if !allow {
continue
}
ips = append(ips, ip)
}
}
}
return &ips
}
func PrivateIP(ip net.IP) bool {
private := false
_, private24BitBlock, _ := net.ParseCIDR("10.0.0.0/8")
_, private20BitBlock, _ := net.ParseCIDR("172.16.0.0/12")
_, private16BitBlock, _ := net.ParseCIDR("192.168.0.0/16")
private = private24BitBlock.Contains(ip) || private20BitBlock.Contains(ip) || private16BitBlock.Contains(ip)
return private
}