From 2a1fd2e01875babf3cad299c2ae860aed0f72a2d Mon Sep 17 00:00:00 2001
From: Joachim Bauch <bauch@struktur.de>
Date: Tue, 28 May 2024 12:18:08 +0200
Subject: [PATCH] Support reloading allowed stats IPs.

---
 backend_server.go     | 31 +++++++++++++++++++++++++------
 proxy/proxy_server.go | 22 ++++++++++++++++++----
 server/main.go        |  1 +
 3 files changed, 44 insertions(+), 10 deletions(-)

diff --git a/backend_server.go b/backend_server.go
index 309c21a..016f0eb 100644
--- a/backend_server.go
+++ b/backend_server.go
@@ -68,7 +68,7 @@ type BackendServer struct {
 	turnvalid   time.Duration
 	turnservers []string
 
-	statsAllowedIps *AllowedIps
+	statsAllowedIps atomic.Pointer[AllowedIps]
 	invalidSecret   []byte
 }
 
@@ -120,7 +120,7 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac
 		return nil, err
 	}
 
-	return &BackendServer{
+	result := &BackendServer{
 		hub:          hub,
 		events:       hub.events,
 		roomSessions: hub.roomSessions,
@@ -131,9 +131,27 @@ func NewBackendServer(config *goconf.ConfigFile, hub *Hub, version string) (*Bac
 		turnvalid:   turnvalid,
 		turnservers: turnserverslist,
 
-		statsAllowedIps: statsAllowedIps,
-		invalidSecret:   invalidSecret,
-	}, nil
+		invalidSecret: invalidSecret,
+	}
+
+	result.statsAllowedIps.Store(statsAllowedIps)
+
+	return result, nil
+}
+
+func (b *BackendServer) Reload(config *goconf.ConfigFile) {
+	statsAllowed, _ := config.GetString("stats", "allowed_ips")
+	if statsAllowedIps, err := ParseAllowedIps(statsAllowed); err == nil {
+		if !statsAllowedIps.Empty() {
+			log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed)
+		} else {
+			log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1")
+			statsAllowedIps = DefaultAllowedIps()
+		}
+		b.statsAllowedIps.Store(statsAllowedIps)
+	} else {
+		log.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err)
+	}
 }
 
 func (b *BackendServer) Start(r *mux.Router) error {
@@ -899,7 +917,8 @@ func (b *BackendServer) allowStatsAccess(r *http.Request) bool {
 		return false
 	}
 
-	return b.statsAllowedIps.Allowed(ip)
+	allowed := b.statsAllowedIps.Load()
+	return allowed != nil && allowed.Allowed(ip)
 }
 
 func (b *BackendServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
diff --git a/proxy/proxy_server.go b/proxy/proxy_server.go
index 2dfa774..b679536 100644
--- a/proxy/proxy_server.go
+++ b/proxy/proxy_server.go
@@ -111,7 +111,7 @@ type ProxyServer struct {
 	upgrader websocket.Upgrader
 
 	tokens          ProxyTokens
-	statsAllowedIps *signaling.AllowedIps
+	statsAllowedIps atomic.Pointer[signaling.AllowedIps]
 	trustedProxies  atomic.Pointer[signaling.AllowedIps]
 
 	sid          atomic.Uint64
@@ -319,8 +319,7 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
 			WriteBufferSize: websocketWriteBufferSize,
 		},
 
-		tokens:          tokens,
-		statsAllowedIps: statsAllowedIps,
+		tokens: tokens,
 
 		cookie:   securecookie.New(hashKey, blockKey).MaxAge(0),
 		sessions: make(map[uint64]*ProxySession),
@@ -335,6 +334,7 @@ func NewProxyServer(r *mux.Router, version string, config *goconf.ConfigFile) (*
 		remoteConnections: make(map[string]*RemoteConnection),
 	}
 
+	result.statsAllowedIps.Store(statsAllowedIps)
 	result.trustedProxies.Store(trustedProxiesIps)
 	result.upgrader.CheckOrigin = result.checkOrigin
 
@@ -548,6 +548,19 @@ func (s *ProxyServer) ScheduleShutdown() {
 }
 
 func (s *ProxyServer) Reload(config *goconf.ConfigFile) {
+	statsAllowed, _ := config.GetString("stats", "allowed_ips")
+	if statsAllowedIps, err := signaling.ParseAllowedIps(statsAllowed); err == nil {
+		if !statsAllowedIps.Empty() {
+			log.Printf("Only allowing access to the stats endpoint from %s", statsAllowed)
+		} else {
+			log.Printf("No IPs configured for the stats endpoint, only allowing access from 127.0.0.1")
+			statsAllowedIps = signaling.DefaultAllowedIps()
+		}
+		s.statsAllowedIps.Store(statsAllowedIps)
+	} else {
+		log.Printf("Error parsing allowed stats ips from \"%s\": %s", statsAllowedIps, err)
+	}
+
 	trustedProxies, _ := config.GetString("app", "trustedproxies")
 	if trustedProxiesIps, err := signaling.ParseAllowedIps(trustedProxies); err == nil {
 		if !trustedProxiesIps.Empty() {
@@ -1396,7 +1409,8 @@ func (s *ProxyServer) allowStatsAccess(r *http.Request) bool {
 		return false
 	}
 
-	return s.statsAllowedIps.Allowed(ip)
+	allowed := s.statsAllowedIps.Load()
+	return allowed != nil && allowed.Allowed(ip)
 }
 
 func (s *ProxyServer) validateStatsRequest(f func(http.ResponseWriter, *http.Request)) func(http.ResponseWriter, *http.Request) {
diff --git a/server/main.go b/server/main.go
index 2a5dc60..b8ace97 100644
--- a/server/main.go
+++ b/server/main.go
@@ -417,6 +417,7 @@ loop:
 					log.Printf("Could not read configuration from %s: %s", *configFlag, err)
 				} else {
 					hub.Reload(config)
+					server.Reload(config)
 				}
 			case syscall.SIGUSR1:
 				log.Printf("Received SIGUSR1, scheduling server to shutdown")