0
0
Fork 0
mirror of https://github.com/slackhq/nebula.git synced 2025-04-11 08:11:18 +00:00

Implement ECMP for unsafe_routes ()

This commit is contained in:
dioss-Machiel 2025-03-24 23:15:59 +01:00 committed by GitHub
parent 3de36c99b6
commit f86953ca56
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
21 changed files with 690 additions and 82 deletions

View file

@ -239,7 +239,28 @@ tun:
# Unsafe routes allows you to route traffic over nebula to non-nebula nodes
# Unsafe routes should be avoided unless you have hosts/services that cannot run nebula
# NOTE: The nebula certificate of the "via" node *MUST* have the "route" defined as a subnet in its certificate
# Supports weighted ECMP if you define a list of gateways, this can be used for load balancing or redundancy to hosts outside of nebula
# NOTES:
# * You will only see a single gateway in the routing table if you are not on linux
# * If a gateway is not reachable through the overlay another gateway will be selected to send the traffic through, ignoring weights
#
# unsafe_routes:
# # Multiple gateways without defining a weight defaults to a weight of 1, this will balance traffic equally between the three gateways
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# - gateway: 10.0.0.2
# - gateway: 10.0.0.3
# # Multiple gateways with a weight, this will balance traffic accordingly
# - route: 192.168.87.0/24
# via:
# - gateway: 10.0.0.1
# weight: 10
# - gateway: 10.0.0.2
# weight: 5
#
# NOTE: The nebula certificate of the "via" node(s) *MUST* have the "route" defined as a subnet in its certificate
# `via`: single node or list of gateways to use for this route
# `mtu`: will default to tun mtu if this option is not specified
# `metric`: will default to 0 if this option is not specified
# `install`: will default to true, controls whether this route is installed in the systems routing table.

View file

@ -8,6 +8,7 @@ import (
"github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/noiseutil"
"github.com/slackhq/nebula/routing"
)
func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet, nb, out []byte, q int, localCache firewall.ConntrackCache) {
@ -49,7 +50,7 @@ func (f *Interface) consumeInsidePacket(packet []byte, fwPacket *firewall.Packet
return
}
hostinfo, ready := f.getOrHandshake(fwPacket.RemoteAddr, func(hh *HandshakeHostInfo) {
hostinfo, ready := f.getOrHandshakeConsiderRouting(fwPacket, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, header.Message, 0, packet, f.sendMessageNow, f.cachedPacketMetrics)
})
@ -121,22 +122,94 @@ func (f *Interface) rejectOutside(packet []byte, ci *ConnectionState, hostinfo *
f.sendNoMetrics(header.Message, 0, ci, hostinfo, netip.AddrPort{}, out, nb, packet, q)
}
// Handshake will attempt to initiate a tunnel with the provided vpn address if it is within our vpn networks. This is a no-op if the tunnel is already established or being established
func (f *Interface) Handshake(vpnAddr netip.Addr) {
f.getOrHandshake(vpnAddr, nil)
f.getOrHandshakeNoRouting(vpnAddr, nil)
}
// getOrHandshake returns nil if the vpnAddr is not routable.
// getOrHandshakeNoRouting returns nil if the vpnAddr is not routable.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel
func (f *Interface) getOrHandshake(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
func (f *Interface) getOrHandshakeNoRouting(vpnAddr netip.Addr, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
_, found := f.myVpnNetworksTable.Lookup(vpnAddr)
if !found {
vpnAddr = f.inside.RouteFor(vpnAddr)
if !vpnAddr.IsValid() {
return nil, false
}
if found {
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
return nil, false
}
// getOrHandshakeConsiderRouting will try to find the HostInfo to handle this packet, starting a handshake if necessary.
// If the 2nd return var is false then the hostinfo is not ready to be used in a tunnel.
func (f *Interface) getOrHandshakeConsiderRouting(fwPacket *firewall.Packet, cacheCallback func(*HandshakeHostInfo)) (*HostInfo, bool) {
destinationAddr := fwPacket.RemoteAddr
hostinfo, ready := f.getOrHandshakeNoRouting(destinationAddr, cacheCallback)
// Host is inside the mesh, no routing required
if hostinfo != nil {
return hostinfo, ready
}
gateways := f.inside.RoutesFor(destinationAddr)
switch len(gateways) {
case 0:
return nil, false
case 1:
// Single gateway route
return f.handshakeManager.GetOrHandshake(gateways[0].Addr(), cacheCallback)
default:
// Multi gateway route, perform ECMP categorization
gatewayAddr, balancingOk := routing.BalancePacket(fwPacket, gateways)
if !balancingOk {
// This happens if the gateway buckets were not calculated, this _should_ never happen
f.l.Error("Gateway buckets not calculated, fallback from ECMP to random routing. Please report this bug.")
}
var handshakeInfoForChosenGateway *HandshakeHostInfo
var hhReceiver = func(hh *HandshakeHostInfo) {
handshakeInfoForChosenGateway = hh
}
// Store the handshakeHostInfo for later.
// If this node is not reachable we will attempt other nodes, if none are reachable we will
// cache the packet for this gateway.
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gatewayAddr, hhReceiver); ready {
return hostinfo, true
}
// It appears the selected gateway cannot be reached, find another gateway to fallback on.
// The current implementation breaks ECMP but that seems better than no connectivity.
// If ECMP is also required when a gateway is down then connectivity status
// for each gateway needs to be kept and the weights recalculated when they go up or down.
// This would also need to interact with unsafe_route updates through reloading the config or
// use of the use_system_route_table option
if f.l.Level >= logrus.DebugLevel {
f.l.WithField("destination", destinationAddr).
WithField("originalGateway", gatewayAddr).
Debugln("Calculated gateway for ECMP not available, attempting other gateways")
}
for i := range gateways {
// Skip the gateway that failed previously
if gateways[i].Addr() == gatewayAddr {
continue
}
// We do not need the HandshakeHostInfo since we cache the packet in the originally chosen gateway
if hostinfo, ready = f.handshakeManager.GetOrHandshake(gateways[i].Addr(), nil); ready {
return hostinfo, true
}
}
// No gateways reachable, cache the packet in the originally chosen gateway
cacheCallback(handshakeInfoForChosenGateway)
return hostinfo, false
}
return f.handshakeManager.GetOrHandshake(vpnAddr, cacheCallback)
}
func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubType, hostinfo *HostInfo, p, nb, out []byte) {
@ -163,7 +236,7 @@ func (f *Interface) sendMessageNow(t header.MessageType, st header.MessageSubTyp
// SendMessageToVpnAddr handles real addr:port lookup and sends to the current best known address for vpnAddr
func (f *Interface) SendMessageToVpnAddr(t header.MessageType, st header.MessageSubType, vpnAddr netip.Addr, p, nb, out []byte) {
hostInfo, ready := f.getOrHandshake(vpnAddr, func(hh *HandshakeHostInfo) {
hostInfo, ready := f.getOrHandshakeNoRouting(vpnAddr, func(hh *HandshakeHostInfo) {
hh.cachePacket(f.l, t, st, p, f.SendMessageToHostInfo, f.cachedPacketMetrics)
})

View file

@ -3,6 +3,8 @@ package overlay
import (
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
type Device interface {
@ -10,6 +12,6 @@ type Device interface {
Activate() error
Networks() []netip.Prefix
Name() string
RouteFor(netip.Addr) netip.Addr
RoutesFor(netip.Addr) routing.Gateways
NewMultiQueueReader() (io.ReadWriteCloser, error)
}

View file

@ -11,13 +11,14 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
type Route struct {
MTU int
Metric int
Cidr netip.Prefix
Via netip.Addr
Via routing.Gateways
Install bool
}
@ -47,15 +48,17 @@ func (r Route) String() string {
return s
}
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[netip.Addr], error) {
routeTree := new(bart.Table[netip.Addr])
func makeRouteTree(l *logrus.Logger, routes []Route, allowMTU bool) (*bart.Table[routing.Gateways], error) {
routeTree := new(bart.Table[routing.Gateways])
for _, r := range routes {
if !allowMTU && r.MTU > 0 {
l.WithField("route", r).Warnf("route MTU is not supported in %s", runtime.GOOS)
}
if r.Via.IsValid() {
routeTree.Insert(r.Cidr, r.Via)
gateways := r.Via
if len(gateways) > 0 {
routing.CalculateBucketsForGateways(gateways)
routeTree.Insert(r.Cidr, gateways)
}
}
return routeTree, nil
@ -201,14 +204,63 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not present", i+1)
}
via, ok := rVia.(string)
if !ok {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string: found %T", i+1, rVia)
}
var gateways routing.Gateways
viaVpnIp, err := netip.ParseAddr(via)
if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
switch via := rVia.(type) {
case string:
viaIp, err := netip.ParseAddr(via)
if err != nil {
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes failed to parse address: %v", i+1, err)
}
gateways = routing.Gateways{routing.NewGateway(viaIp, 1)}
case []interface{}:
gateways = make(routing.Gateways, len(via))
for ig, v := range via {
gatewayMap, ok := v.(map[interface{}]interface{})
if !ok {
return nil, fmt.Errorf("entry %v in tun.unsafe_routes[%v].via is invalid", i+1, ig+1)
}
rGateway, ok := gatewayMap["gateway"]
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not present", i+1, ig+1)
}
parsedGateway, ok := rGateway.(string)
if !ok {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] is not a string", i+1, ig+1)
}
gatewayIp, err := netip.ParseAddr(parsedGateway)
if err != nil {
return nil, fmt.Errorf("entry .gateway in tun.unsafe_routes[%v].via[%v] failed to parse address: %v", i+1, ig+1, err)
}
rGatewayWeight, ok := gatewayMap["weight"]
if !ok {
rGatewayWeight = 1
}
gatewayWeight, ok := rGatewayWeight.(int)
if !ok {
_, err = strconv.ParseInt(rGatewayWeight.(string), 10, 32)
if err != nil {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not an integer", i+1, ig+1)
}
}
if gatewayWeight < 1 || gatewayWeight > math.MaxInt32 {
return nil, fmt.Errorf("entry .weight in tun.unsafe_routes[%v].via[%v] is not in range (1-%d) : %v", i+1, ig+1, math.MaxInt32, gatewayWeight)
}
gateways[ig] = routing.NewGateway(gatewayIp, gatewayWeight)
}
default:
return nil, fmt.Errorf("entry %v.via in tun.unsafe_routes is not a string or list of gateways: found %T", i+1, rVia)
}
rRoute, ok := m["route"]
@ -226,7 +278,7 @@ func parseUnsafeRoutes(c *config.C, networks []netip.Prefix) ([]Route, error) {
}
r := Route{
Via: viaVpnIp,
Via: gateways,
MTU: mtu,
Metric: metric,
Install: install,

View file

@ -6,6 +6,7 @@ import (
"testing"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -158,15 +159,39 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string or list of gateways: found %T", invalidValue))
}
// Unparsable list of via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": []string{"1", "2"}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not a string or list of gateways: found []string")
// unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// unparsable gateway
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] failed to parse address: ParseAddr(\"1\"): unable to parse IP")
// missing gateway element
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"weight": "1"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .gateway in tun.unsafe_routes[1].via[1] is not present")
// unparsable weight element
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": []interface{}{map[interface{}]interface{}{"gateway": "10.0.0.1", "weight": "a"}}}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes)
require.EqualError(t, err, "entry .weight in tun.unsafe_routes[1].via[1] is not an integer")
// missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
@ -280,7 +305,7 @@ func Test_makeRouteTree(t *testing.T) {
nip, err := netip.ParseAddr("192.168.0.1")
require.NoError(t, err)
assert.Equal(t, nip, r)
assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("1.0.0.1")
require.NoError(t, err)
@ -289,10 +314,91 @@ func Test_makeRouteTree(t *testing.T) {
nip, err = netip.ParseAddr("192.168.0.2")
require.NoError(t, err)
assert.Equal(t, nip, r)
assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("1.1.0.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.False(t, ok)
}
func Test_makeMultipathUnsafeRouteTree(t *testing.T) {
l := test.NewLogger()
c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24")
require.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{
"unsafe_routes": []interface{}{
map[interface{}]interface{}{
"route": "192.168.86.0/24",
"via": "192.168.100.10",
},
map[interface{}]interface{}{
"route": "192.168.87.0/24",
"via": []interface{}{
map[interface{}]interface{}{
"gateway": "10.0.0.1",
},
map[interface{}]interface{}{
"gateway": "10.0.0.2",
},
map[interface{}]interface{}{
"gateway": "10.0.0.3",
},
},
},
map[interface{}]interface{}{
"route": "192.168.89.0/24",
"via": []interface{}{
map[interface{}]interface{}{
"gateway": "10.0.0.1",
"weight": 10,
},
map[interface{}]interface{}{
"gateway": "10.0.0.2",
"weight": 5,
},
},
},
},
}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
require.NoError(t, err)
assert.Len(t, routes, 3)
routeTree, err := makeRouteTree(l, routes, true)
require.NoError(t, err)
ip, err := netip.ParseAddr("192.168.86.1")
require.NoError(t, err)
r, ok := routeTree.Lookup(ip)
assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.100.10")
require.NoError(t, err)
assert.Equal(t, nip, r[0].Addr())
ip, err = netip.ParseAddr("192.168.87.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways := routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 1),
routing.NewGateway(netip.MustParseAddr("10.0.0.3"), 1)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
ip, err = netip.ParseAddr("192.168.89.1")
require.NoError(t, err)
r, ok = routeTree.Lookup(ip)
assert.True(t, ok)
expectedGateways = routing.Gateways{routing.NewGateway(netip.MustParseAddr("10.0.0.1"), 10),
routing.NewGateway(netip.MustParseAddr("10.0.0.2"), 5)}
routing.CalculateBucketsForGateways(expectedGateways)
assert.ElementsMatch(t, expectedGateways, r)
}

View file

@ -13,6 +13,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@ -21,7 +22,7 @@ type tun struct {
fd int
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
}
@ -56,7 +57,7 @@ func newTun(_ *config.C, _ *logrus.Logger, _ []netip.Prefix, _ bool) (*tun, erro
return nil, fmt.Errorf("newTun not supported in Android")
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View file

@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
netroute "golang.org/x/net/route"
"golang.org/x/sys/unix"
@ -28,7 +29,7 @@ type tun struct {
vpnNetworks []netip.Prefix
DefaultMTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
linkAddr *netroute.LinkAddr
l *logrus.Logger
@ -342,12 +343,12 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, ok := t.routeTree.Load().Lookup(ip)
if ok {
return r
}
return netip.Addr{}
return routing.Gateways{}
}
// Get the LinkAddr for the interface of the given name
@ -382,7 +383,7 @@ func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}

View file

@ -9,6 +9,7 @@ import (
"github.com/rcrowley/go-metrics"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/iputil"
"github.com/slackhq/nebula/routing"
)
type disabledTun struct {
@ -43,8 +44,8 @@ func (*disabledTun) Activate() error {
return nil
}
func (*disabledTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
func (*disabledTun) RoutesFor(addr netip.Addr) routing.Gateways {
return routing.Gateways{}
}
func (t *disabledTun) Networks() []netip.Prefix {

View file

@ -20,6 +20,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@ -50,7 +51,7 @@ type tun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
@ -242,7 +243,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@ -262,7 +263,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}

View file

@ -16,6 +16,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@ -23,7 +24,7 @@ type tun struct {
io.ReadWriteCloser
vpnNetworks []netip.Prefix
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
}
@ -79,7 +80,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View file

@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/vishvananda/netlink"
"golang.org/x/sys/unix"
@ -34,7 +35,7 @@ type tun struct {
ioctlFd uintptr
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
routeChan chan struct{}
useSystemRoutes bool
@ -231,7 +232,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return file, nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@ -550,20 +551,7 @@ func (t *tun) watchRoutes() {
}()
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
if r.Gw == nil {
// Not a gateway route, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not a gateway route")
return
}
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
return
}
gwAddr = gwAddr.Unmap()
func (t *tun) isGatewayInVpnNetworks(gwAddr netip.Addr) bool {
withinNetworks := false
for i := range t.vpnNetworks {
if t.vpnNetworks[i].Contains(gwAddr) {
@ -571,9 +559,68 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
break
}
}
if !withinNetworks {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our networks")
return withinNetworks
}
func (t *tun) getGatewaysFromRoute(r *netlink.Route) routing.Gateways {
var gateways routing.Gateways
link, err := netlink.LinkByName(t.Device)
if err != nil {
t.l.WithField("Devicename", t.Device).Error("Ignoring route update: failed to get link by name")
return gateways
}
// If this route is relevant to our interface and there is a gateway then add it
if r.LinkIndex == link.Attrs().Index && len(r.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(r.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
gateways = append(gateways, routing.NewGateway(gwAddr, 1))
}
}
}
for _, p := range r.MultiPath {
// If this route is relevant to our interface and there is a gateway then add it
if p.LinkIndex == link.Attrs().Index && len(p.Gw) > 0 {
gwAddr, ok := netip.AddrFromSlice(p.Gw)
if !ok {
t.l.WithField("route", r).Debug("Ignoring multipath route update, invalid gateway address")
} else {
gwAddr = gwAddr.Unmap()
if !t.isGatewayInVpnNetworks(gwAddr) {
// Gateway isn't in our overlay network, ignore
t.l.WithField("route", r).Debug("Ignoring route update, not in our network")
} else {
// p.Hops+1 = weight of the route
gateways = append(gateways, routing.NewGateway(gwAddr, p.Hops+1))
}
}
}
}
routing.CalculateBucketsForGateways(gateways)
return gateways
}
func (t *tun) updateRoutes(r netlink.RouteUpdate) {
gateways := t.getGatewaysFromRoute(&r.Route)
if len(gateways) == 0 {
// No gateways relevant to our network, no routing changes required.
t.l.WithField("route", r).Debug("Ignoring route update, no gateways")
return
}
@ -589,12 +636,12 @@ func (t *tun) updateRoutes(r netlink.RouteUpdate) {
newTree := t.routeTree.Load().Clone()
if r.Type == unix.RTM_NEWROUTE {
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Adding route")
newTree.Insert(dst, gwAddr)
t.l.WithField("destination", dst).WithField("via", gateways).Info("Adding route")
newTree.Insert(dst, gateways)
} else {
t.l.WithField("destination", dst).WithField("via", gateways).Info("Removing route")
newTree.Delete(dst)
t.l.WithField("destination", r.Dst).WithField("via", r.Gw).Info("Removing route")
}
t.routeTree.Store(newTree)
}

View file

@ -18,6 +18,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@ -31,7 +32,7 @@ type tun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
@ -177,7 +178,7 @@ func (t *tun) reload(c *config.C, initial bool) error {
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@ -197,7 +198,7 @@ func (t *tun) NewMultiQueueReader() (io.ReadWriteCloser, error) {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}

View file

@ -17,6 +17,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
)
@ -25,7 +26,7 @@ type tun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
io.ReadWriteCloser
@ -158,7 +159,7 @@ func (t *tun) Activate() error {
return nil
}
func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}
@ -166,7 +167,7 @@ func (t *tun) RouteFor(ip netip.Addr) netip.Addr {
func (t *tun) addRoutes(logErrors bool) error {
routes := *t.Routes.Load()
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}

View file

@ -13,13 +13,14 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
type TestTun struct {
Device string
vpnNetworks []netip.Prefix
Routes []Route
routeTree *bart.Table[netip.Addr]
routeTree *bart.Table[routing.Gateways]
l *logrus.Logger
closed atomic.Bool
@ -86,7 +87,7 @@ func (t *TestTun) Get(block bool) []byte {
// Below this is boilerplate implementation to make nebula actually work
//********************************************************************************************************************//
func (t *TestTun) RouteFor(ip netip.Addr) netip.Addr {
func (t *TestTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Lookup(ip)
return r
}

View file

@ -18,6 +18,7 @@ import (
"github.com/gaissmai/bart"
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
"github.com/slackhq/nebula/util"
"github.com/slackhq/nebula/wintun"
"golang.org/x/sys/windows"
@ -31,7 +32,7 @@ type winTun struct {
vpnNetworks []netip.Prefix
MTU int
Routes atomic.Pointer[[]Route]
routeTree atomic.Pointer[bart.Table[netip.Addr]]
routeTree atomic.Pointer[bart.Table[routing.Gateways]]
l *logrus.Logger
tun *wintun.NativeTun
@ -147,13 +148,16 @@ func (t *winTun) addRoutes(logErrors bool) error {
foundDefault4 := false
for _, r := range routes {
if !r.Via.IsValid() || !r.Install {
if len(r.Via) == 0 || !r.Install {
// We don't allow route MTUs so only install routes with a via
continue
}
// Add our unsafe route
err := luid.AddRoute(r.Cidr, r.Via, uint32(r.Metric))
// Windows does not support multipath routes natively, so we install only a single route.
// This is not a problem as traffic will always be sent to Nebula which handles the multipath routing internally.
// In effect this provides multipath routing support to windows supporting loadbalancing and redundancy.
err := luid.AddRoute(r.Cidr, r.Via[0].Addr(), uint32(r.Metric))
if err != nil {
retErr := util.NewContextualError("Failed to add route", map[string]interface{}{"route": r}, err)
if logErrors {
@ -198,7 +202,8 @@ func (t *winTun) removeRoutes(routes []Route) error {
continue
}
err := luid.DeleteRoute(r.Cidr, r.Via)
// See comment on luid.AddRoute
err := luid.DeleteRoute(r.Cidr, r.Via[0].Addr())
if err != nil {
t.l.WithError(err).WithField("route", r).Error("Failed to remove route")
} else {
@ -208,7 +213,7 @@ func (t *winTun) removeRoutes(routes []Route) error {
return nil
}
func (t *winTun) RouteFor(ip netip.Addr) netip.Addr {
func (t *winTun) RoutesFor(ip netip.Addr) routing.Gateways {
r, _ := t.routeTree.Load().Lookup(ip)
return r
}

View file

@ -6,6 +6,7 @@ import (
"github.com/sirupsen/logrus"
"github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/routing"
)
func NewUserDeviceFromConfig(c *config.C, l *logrus.Logger, vpnNetworks []netip.Prefix, routines int) (Device, error) {
@ -38,9 +39,13 @@ type UserDevice struct {
func (d *UserDevice) Activate() error {
return nil
}
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RouteFor(ip netip.Addr) netip.Addr { return ip }
func (d *UserDevice) Networks() []netip.Prefix { return d.vpnNetworks }
func (d *UserDevice) Name() string { return "faketun0" }
func (d *UserDevice) RoutesFor(ip netip.Addr) routing.Gateways {
return routing.Gateways{routing.NewGateway(ip, 1)}
}
func (d *UserDevice) NewMultiQueueReader() (io.ReadWriteCloser, error) {
return d, nil
}

39
routing/balance.go Normal file
View file

@ -0,0 +1,39 @@
package routing
import (
"net/netip"
"github.com/slackhq/nebula/firewall"
)
// Hashes the packet source and destination port and always returns a positive integer
// Based on 'Prospecting for Hash Functions'
// - https://nullprogram.com/blog/2018/07/31/
// - https://github.com/skeeto/hash-prospector
// [16 21f0aaad 15 d35a2d97 15] = 0.10760229515479501
func hashPacket(p *firewall.Packet) int {
x := (uint32(p.LocalPort) << 16) | uint32(p.RemotePort)
x ^= x >> 16
x *= 0x21f0aaad
x ^= x >> 15
x *= 0xd35a2d97
x ^= x >> 15
return int(x) & 0x7FFFFFFF
}
// For this function to work correctly it requires that the buckets for the gateways have been calculated
// If the contract is violated balancing will not work properly and the second return value will return false
func BalancePacket(fwPacket *firewall.Packet, gateways []Gateway) (netip.Addr, bool) {
hash := hashPacket(fwPacket)
for i := range gateways {
if hash <= gateways[i].BucketUpperBound() {
return gateways[i].Addr(), true
}
}
// If you land here then the buckets for the gateways are not properly calculated
// Fallback to random routing and let the caller know
return gateways[hash%len(gateways)].Addr(), false
}

144
routing/balance_test.go Normal file
View file

@ -0,0 +1,144 @@
package routing
import (
"net/netip"
"testing"
"github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert"
)
func TestPacketsAreBalancedEqually(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gw3Addr := netip.MustParseAddr("1.0.0.3")
gateways = append(gateways, NewGateway(gw1Addr, 1))
gateways = append(gateways, NewGateway(gw2Addr, 1))
gateways = append(gateways, NewGateway(gw3Addr, 1))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
gw3count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
case gw3Addr:
gw3count += 1
}
}
// Assert packets are balanced, allow variation of up to 100 packets per gateway
assert.InDeltaf(t, iterationCount/3, gw1count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw2count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
assert.InDeltaf(t, iterationCount/3, gw3count, 100, "Expected %d +/- 100, but got %d", iterationCount/3, gw1count)
}
func TestPacketsAreBalancedByPriority(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
CalculateBucketsForGateways(gateways)
gw1count := 0
gw2count := 0
iterationCount := uint16(65535)
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.True(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
iterationCountAsFloat := float32(iterationCount)
assert.InDeltaf(t, iterationCountAsFloat*(2.0/3.0), gw1count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(2.0/3.0), gw1count)
assert.InDeltaf(t, iterationCountAsFloat*(1.0/3.0), gw2count, 100, "Expected %d +/- 100, but got %d", iterationCountAsFloat*(1.0/3.0), gw2count)
}
func TestBalancePacketDistributsRandomlyAndReturnsFalseIfBucketsNotCalculated(t *testing.T) {
gateways := []Gateway{}
gw1Addr := netip.MustParseAddr("1.0.0.1")
gw2Addr := netip.MustParseAddr("1.0.0.2")
gateways = append(gateways, NewGateway(gw1Addr, 10))
gateways = append(gateways, NewGateway(gw2Addr, 5))
iterationCount := uint16(65535)
gw1count := 0
gw2count := 0
for i := uint16(0); i < iterationCount; i++ {
packet := firewall.Packet{
LocalAddr: netip.MustParseAddr("192.168.1.1"),
RemoteAddr: netip.MustParseAddr("10.0.0.1"),
LocalPort: i,
RemotePort: 65535 - i,
Protocol: 6, // TCP
Fragment: false,
}
selectedGw, ok := BalancePacket(&packet, gateways)
assert.False(t, ok)
switch selectedGw {
case gw1Addr:
gw1count += 1
case gw2Addr:
gw2count += 1
}
}
assert.Equal(t, int(iterationCount), (gw1count + gw2count))
assert.NotEqual(t, 0, gw1count)
assert.NotEqual(t, 0, gw2count)
}

70
routing/gateway.go Normal file
View file

@ -0,0 +1,70 @@
package routing
import (
"fmt"
"net/netip"
)
const (
// Sentinal value
BucketNotCalculated = -1
)
type Gateways []Gateway
func (g Gateways) String() string {
str := ""
for i, gw := range g {
str += gw.String()
if i < len(g)-1 {
str += ", "
}
}
return str
}
type Gateway struct {
addr netip.Addr
weight int
bucketUpperBound int
}
func NewGateway(addr netip.Addr, weight int) Gateway {
return Gateway{addr: addr, weight: weight, bucketUpperBound: BucketNotCalculated}
}
func (g *Gateway) BucketUpperBound() int {
return g.bucketUpperBound
}
func (g *Gateway) Addr() netip.Addr {
return g.addr
}
func (g *Gateway) String() string {
return fmt.Sprintf("{addr: %s, weight: %d}", g.addr, g.weight)
}
// Divide and round to nearest integer
func divideAndRound(v uint64, d uint64) uint64 {
var tmp uint64 = v + d/2
return tmp / d
}
// Implements Hash-Threshold mapping, equivalent to the implementation in the linux kernel.
// After this function returns each gateway will have a
// positive bucketUpperBound with a maximum value of 2147483647 (INT_MAX)
func CalculateBucketsForGateways(gateways []Gateway) {
var totalWeight int = 0
for i := range gateways {
totalWeight += gateways[i].weight
}
var loopWeight int = 0
for i := range gateways {
loopWeight += gateways[i].weight
gateways[i].bucketUpperBound = int(divideAndRound(uint64(loopWeight)<<31, uint64(totalWeight))) - 1
}
}

34
routing/gateway_test.go Normal file
View file

@ -0,0 +1,34 @@
package routing
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
)
func TestRebalance3_2Split(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 10})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 5})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 1431655764, gateways[0].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[1].bucketUpperBound) // INT_MAX
}
func TestRebalanceEqualSplit(t *testing.T) {
gateways := []Gateway{}
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
gateways = append(gateways, Gateway{addr: netip.Addr{}, weight: 1})
CalculateBucketsForGateways(gateways)
assert.Equal(t, 715827882, gateways[0].bucketUpperBound) // INT_MAX/3
assert.Equal(t, 1431655764, gateways[1].bucketUpperBound) // INT_MAX/3*2
assert.Equal(t, 2147483647, gateways[2].bucketUpperBound) // INT_MAX
}

View file

@ -4,12 +4,14 @@ import (
"errors"
"io"
"net/netip"
"github.com/slackhq/nebula/routing"
)
type NoopTun struct{}
func (NoopTun) RouteFor(addr netip.Addr) netip.Addr {
return netip.Addr{}
func (NoopTun) RoutesFor(addr netip.Addr) routing.Gateways {
return routing.Gateways{}
}
func (NoopTun) Activate() error {