mirror of
https://github.com/slackhq/nebula.git
synced 2025-01-25 17:48:25 +00:00
147 lines
2.3 KiB
Go
147 lines
2.3 KiB
Go
package nebula
|
|
|
|
import (
|
|
"encoding/binary"
|
|
"fmt"
|
|
"net"
|
|
)
|
|
|
|
type CIDRNode struct {
|
|
left *CIDRNode
|
|
right *CIDRNode
|
|
parent *CIDRNode
|
|
value interface{}
|
|
}
|
|
|
|
type CIDRTree struct {
|
|
root *CIDRNode
|
|
}
|
|
|
|
const (
|
|
startbit = uint32(0x80000000)
|
|
)
|
|
|
|
func NewCIDRTree() *CIDRTree {
|
|
tree := new(CIDRTree)
|
|
tree.root = &CIDRNode{}
|
|
return tree
|
|
}
|
|
|
|
func (tree *CIDRTree) AddCIDR(cidr *net.IPNet, val interface{}) {
|
|
bit := startbit
|
|
node := tree.root
|
|
next := tree.root
|
|
|
|
ip := ip2int(cidr.IP)
|
|
mask := ip2int(cidr.Mask)
|
|
|
|
// Find our last ancestor in the tree
|
|
for bit&mask != 0 {
|
|
if ip&bit != 0 {
|
|
next = node.right
|
|
} else {
|
|
next = node.left
|
|
}
|
|
|
|
if next == nil {
|
|
break
|
|
}
|
|
|
|
bit = bit >> 1
|
|
node = next
|
|
}
|
|
|
|
// We already have this range so update the value
|
|
if next != nil {
|
|
node.value = val
|
|
return
|
|
}
|
|
|
|
// Build up the rest of the tree we don't already have
|
|
for bit&mask != 0 {
|
|
next = &CIDRNode{}
|
|
next.parent = node
|
|
|
|
if ip&bit != 0 {
|
|
node.right = next
|
|
} else {
|
|
node.left = next
|
|
}
|
|
|
|
bit >>= 1
|
|
node = next
|
|
}
|
|
|
|
// Final node marks our cidr, set the value
|
|
node.value = val
|
|
}
|
|
|
|
// Finds the first match, which way be the least specific
|
|
func (tree *CIDRTree) Contains(ip uint32) (value interface{}) {
|
|
bit := startbit
|
|
node := tree.root
|
|
|
|
for node != nil {
|
|
if node.value != nil {
|
|
return node.value
|
|
}
|
|
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
|
|
}
|
|
|
|
return value
|
|
}
|
|
|
|
// Finds the most specific match
|
|
func (tree *CIDRTree) Match(ip uint32) (value interface{}) {
|
|
bit := startbit
|
|
node := tree.root
|
|
lastNode := node
|
|
|
|
for node != nil {
|
|
lastNode = node
|
|
if ip&bit != 0 {
|
|
node = node.right
|
|
} else {
|
|
node = node.left
|
|
}
|
|
|
|
bit >>= 1
|
|
}
|
|
|
|
if bit == 0 && lastNode != nil {
|
|
value = lastNode.value
|
|
}
|
|
return value
|
|
}
|
|
|
|
// A helper type to avoid converting to IP when logging
|
|
type IntIp uint32
|
|
|
|
func (ip IntIp) String() string {
|
|
return fmt.Sprintf("%v", int2ip(uint32(ip)))
|
|
}
|
|
|
|
func (ip IntIp) MarshalJSON() ([]byte, error) {
|
|
return []byte(fmt.Sprintf("\"%s\"", int2ip(uint32(ip)).String())), nil
|
|
}
|
|
|
|
func ip2int(ip []byte) uint32 {
|
|
if len(ip) == 16 {
|
|
return binary.BigEndian.Uint32(ip[12:16])
|
|
}
|
|
return binary.BigEndian.Uint32(ip)
|
|
}
|
|
|
|
func int2ip(nn uint32) net.IP {
|
|
ip := make(net.IP, 4)
|
|
binary.BigEndian.PutUint32(ip, nn)
|
|
return ip
|
|
}
|