mirror of
https://github.com/slackhq/nebula.git
synced 2025-01-11 03:48:12 +00:00
9af242dc47
These new helpers make the code a lot cleaner. I confirmed that the simple helpers like `atomic.Int64` don't add any extra overhead as they get inlined by the compiler. `atomic.Pointer` adds an extra method call as it no longer gets inlined, but we aren't using these on the hot path so it is probably okay.
209 lines
5.4 KiB
Go
209 lines
5.4 KiB
Go
//go:build windows
|
|
// +build windows
|
|
|
|
/* SPDX-License-Identifier: MIT
|
|
*
|
|
* Copyright (C) 2018-2021 WireGuard LLC. All Rights Reserved.
|
|
*/
|
|
|
|
//NOTE: This file was forked from https://git.zx2c4.com/wireguard-go/tree/tun/tun_windows.go?id=851efb1bb65555e0f765a3361c8eb5ac47435b19
|
|
// Mainly to shed functionality we won't be using and to fix the names the system displays
|
|
|
|
package wintun
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"os"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
_ "unsafe"
|
|
|
|
"golang.org/x/sys/windows"
|
|
|
|
"golang.zx2c4.com/wintun"
|
|
)
|
|
|
|
// Tuning constants for the receive path. All durations are expressed as raw
// nanosecond counts so they can be compared directly against differences of
// nanotime() readings.
const (
	// rateMeasurementGranularity is the length of one throughput measurement
	// window: half a second, in nanoseconds.
	rateMeasurementGranularity = uint64(time.Second / 2 / time.Nanosecond)

	// spinloopRateThreshold is the measured throughput, in bytes per second,
	// at or above which Read busy-polls before blocking (800 Mbit/s / 8).
	spinloopRateThreshold = 800_000_000 / 8 // 800mbps

	// spinloopDuration bounds how long Read busy-polls, in nanoseconds
	// (1/80 ms = 12.5µs). The "~1gbit/s" label presumably refers to roughly
	// the gap between MTU-sized packets at 1 Gbit/s — TODO confirm.
	spinloopDuration = uint64(time.Millisecond / 80 / time.Nanosecond) // ~1gbit/s
)
|
|
|
|
// rateJuggler keeps an approximate throughput figure, in bytes per second,
// for the traffic reported to update. Every field is read and written only
// through sync/atomic functions.
//
// NOTE(review): keep the 64-bit fields first and do not reorder them. 64-bit
// atomic operations require 64-bit alignment on 32-bit platforms (see the
// "Bugs" section of the sync/atomic docs), and this layout relies on that.
type rateJuggler struct {
	current       uint64 // throughput published at the last window rollover, bytes/second
	nextByteCount uint64 // bytes accumulated in the current measurement window
	nextStartTime int64  // nanotime() reading at which the current window began
	changing      int32  // 1 while one caller is rolling the window over, 0 otherwise
}
|
|
|
|
// NativeTun is a TUN device backed by a Wintun adapter and session, as
// returned by CreateTUN / CreateTUNWithRequestedGUID.
type NativeTun struct {
	wt     *wintun.Adapter // adapter created at construction; closed by Close
	name   string          // interface name given at creation, returned by Name
	handle windows.Handle  // set to windows.InvalidHandle at creation; not read elsewhere in this file
	// NOTE(review): rate is updated with 64-bit atomics. On 32-bit targets its
	// offset must stay 8-byte aligned, so re-check alignment before adding or
	// reordering fields above it.
	rate      rateJuggler
	session   wintun.Session // ring session used by Read/Write; ended by Close
	readWait  windows.Handle // session's read-wait event; Close also signals it to wake a blocked Read
	running   sync.WaitGroup // counts in-flight Read/Write/LUID calls so Close can wait for them
	closeOnce sync.Once      // makes Close's teardown run only once
	close     int32          // set to 1 (atomically) by Close; checked by Read/Write/LUID
}
|
|
|
|
var WintunTunnelType = "Nebula"
|
|
var WintunStaticRequestedGUID *windows.GUID
|
|
|
|
// procyield is bound to the Go runtime's unexported runtime.procyield through
// go:linkname, which requires the blank "unsafe" import. Read calls
// procyield(1) between polls of the receive ring while spinning.
//
// NOTE(review): linking to runtime internals is fragile across Go releases,
// and newer toolchains restrict which runtime symbols may be pulled this way.
// Confirm that the supported Go versions still permit it.
//
//go:linkname procyield runtime.procyield
func procyield(cycles uint32)

// nanotime is bound to runtime.nanotime through go:linkname. Read and
// rateJuggler.update use it as a nanosecond clock: differences between two
// readings are compared against the nanosecond constants in this file.
//
//go:linkname nanotime runtime.nanotime
func nanotime() int64
|
|
|
|
// CreateTUN creates a Wintun interface with the given name. Should a Wintun
|
|
// interface with the same name exist, it is reused.
|
|
func CreateTUN(ifname string, mtu int) (Device, error) {
|
|
return CreateTUNWithRequestedGUID(ifname, WintunStaticRequestedGUID, mtu)
|
|
}
|
|
|
|
// CreateTUNWithRequestedGUID creates a Wintun interface with the given name and
|
|
// a requested GUID. Should a Wintun interface with the same name exist, it is reused.
|
|
func CreateTUNWithRequestedGUID(ifname string, requestedGUID *windows.GUID, mtu int) (Device, error) {
|
|
wt, err := wintun.CreateAdapter(ifname, WintunTunnelType, requestedGUID)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("Error creating interface: %w", err)
|
|
}
|
|
|
|
tun := &NativeTun{
|
|
wt: wt,
|
|
name: ifname,
|
|
handle: windows.InvalidHandle,
|
|
}
|
|
|
|
tun.session, err = wt.StartSession(0x800000) // Ring capacity, 8 MiB
|
|
if err != nil {
|
|
tun.wt.Close()
|
|
return nil, fmt.Errorf("Error starting session: %w", err)
|
|
}
|
|
tun.readWait = tun.session.ReadWaitEvent()
|
|
return tun, nil
|
|
}
|
|
|
|
func (tun *NativeTun) Name() (string, error) {
|
|
return tun.name, nil
|
|
}
|
|
|
|
func (tun *NativeTun) File() *os.File {
|
|
return nil
|
|
}
|
|
|
|
func (tun *NativeTun) Close() error {
|
|
var err error
|
|
tun.closeOnce.Do(func() {
|
|
atomic.StoreInt32(&tun.close, 1)
|
|
windows.SetEvent(tun.readWait)
|
|
tun.running.Wait()
|
|
tun.session.End()
|
|
if tun.wt != nil {
|
|
tun.wt.Close()
|
|
}
|
|
})
|
|
return err
|
|
}
|
|
|
|
// Note: Read() and Write() assume the caller comes only from a single thread; there's no locking.
|
|
|
|
func (tun *NativeTun) Read(buff []byte, offset int) (int, error) {
|
|
tun.running.Add(1)
|
|
defer tun.running.Done()
|
|
retry:
|
|
if atomic.LoadInt32(&tun.close) == 1 {
|
|
return 0, os.ErrClosed
|
|
}
|
|
start := nanotime()
|
|
shouldSpin := atomic.LoadUint64(&tun.rate.current) >= spinloopRateThreshold && uint64(start-atomic.LoadInt64(&tun.rate.nextStartTime)) <= rateMeasurementGranularity*2
|
|
for {
|
|
if atomic.LoadInt32(&tun.close) == 1 {
|
|
return 0, os.ErrClosed
|
|
}
|
|
packet, err := tun.session.ReceivePacket()
|
|
switch err {
|
|
case nil:
|
|
packetSize := len(packet)
|
|
copy(buff[offset:], packet)
|
|
tun.session.ReleaseReceivePacket(packet)
|
|
tun.rate.update(uint64(packetSize))
|
|
return packetSize, nil
|
|
case windows.ERROR_NO_MORE_ITEMS:
|
|
if !shouldSpin || uint64(nanotime()-start) >= spinloopDuration {
|
|
windows.WaitForSingleObject(tun.readWait, windows.INFINITE)
|
|
goto retry
|
|
}
|
|
procyield(1)
|
|
continue
|
|
case windows.ERROR_HANDLE_EOF:
|
|
return 0, os.ErrClosed
|
|
case windows.ERROR_INVALID_DATA:
|
|
return 0, errors.New("Send ring corrupt")
|
|
}
|
|
return 0, fmt.Errorf("Read failed: %w", err)
|
|
}
|
|
}
|
|
|
|
func (tun *NativeTun) Flush() error {
|
|
return nil
|
|
}
|
|
|
|
func (tun *NativeTun) Write(buff []byte, offset int) (int, error) {
|
|
tun.running.Add(1)
|
|
defer tun.running.Done()
|
|
if atomic.LoadInt32(&tun.close) == 1 {
|
|
return 0, os.ErrClosed
|
|
}
|
|
|
|
packetSize := len(buff) - offset
|
|
tun.rate.update(uint64(packetSize))
|
|
|
|
packet, err := tun.session.AllocateSendPacket(packetSize)
|
|
if err == nil {
|
|
copy(packet, buff[offset:])
|
|
tun.session.SendPacket(packet)
|
|
return packetSize, nil
|
|
}
|
|
switch err {
|
|
case windows.ERROR_HANDLE_EOF:
|
|
return 0, os.ErrClosed
|
|
case windows.ERROR_BUFFER_OVERFLOW:
|
|
return 0, nil // Dropping when ring is full.
|
|
}
|
|
return 0, fmt.Errorf("Write failed: %w", err)
|
|
}
|
|
|
|
// LUID returns Windows interface instance ID.
|
|
func (tun *NativeTun) LUID() uint64 {
|
|
tun.running.Add(1)
|
|
defer tun.running.Done()
|
|
if atomic.LoadInt32(&tun.close) == 1 {
|
|
return 0
|
|
}
|
|
return tun.wt.LUID()
|
|
}
|
|
|
|
// RunningVersion returns the running version of the Wintun driver.
|
|
func (tun *NativeTun) RunningVersion() (version uint32, err error) {
|
|
return wintun.RunningVersion()
|
|
}
|
|
|
|
func (rate *rateJuggler) update(packetLen uint64) {
|
|
now := nanotime()
|
|
total := atomic.AddUint64(&rate.nextByteCount, packetLen)
|
|
period := uint64(now - atomic.LoadInt64(&rate.nextStartTime))
|
|
if period >= rateMeasurementGranularity {
|
|
if !atomic.CompareAndSwapInt32(&rate.changing, 0, 1) {
|
|
return
|
|
}
|
|
atomic.StoreInt64(&rate.nextStartTime, now)
|
|
atomic.StoreUint64(&rate.current, total*uint64(time.Second/time.Nanosecond)/period)
|
|
atomic.StoreUint64(&rate.nextByteCount, 0)
|
|
atomic.StoreInt32(&rate.changing, 0)
|
|
}
|
|
}
|