0
0
Fork 0
mirror of https://github.com/netdata/netdata.git synced 2025-05-04 01:10:03 +00:00

chore(go.d/pkg/socket): add err to callback return values ()

This commit is contained in:
Ilya Mashchenko 2024-11-29 12:43:26 +02:00 committed by GitHub
parent 28ee226945
commit 61fb6d77d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 476 additions and 475 deletions

View file

@ -88,9 +88,10 @@ func newBeanstalkConn(conf Config, log *logger.Logger) beanstalkConn {
return &beanstalkClient{ return &beanstalkClient{
Logger: log, Logger: log,
client: socket.New(socket.Config{ client: socket.New(socket.Config{
Address: conf.Address, Address: conf.Address,
Timeout: conf.Timeout.Duration(), Timeout: conf.Timeout.Duration(),
TLSConf: nil, MaxReadLines: 2000,
TLSConf: nil,
}), }),
} }
} }
@ -180,43 +181,34 @@ func (c *beanstalkClient) queryStatsTube(tubeName string) (*tubeStats, error) {
} }
func (c *beanstalkClient) query(command string) (string, []byte, error) { func (c *beanstalkClient) query(command string) (string, []byte, error) {
var resp string
var length int
var body []byte
var err error
c.Debugf("executing command: %s", command) c.Debugf("executing command: %s", command)
const limitReadLines = 1000 var (
var num int resp string
body []byte
length int
err error
)
clientErr := c.client.Command(command+"\r\n", func(line []byte) bool { if err := c.client.Command(command+"\r\n", func(line []byte) (bool, error) {
if resp == "" { if resp == "" {
s := string(line) s := string(line)
c.Debugf("command '%s' response: '%s'", command, s) c.Debugf("command '%s' response: '%s'", command, s)
resp, length, err = parseResponseLine(s) resp, length, err = parseResponseLine(s)
if err != nil { if err != nil {
err = fmt.Errorf("command '%s' line '%s': %v", command, s, err) return false, fmt.Errorf("command '%s' line '%s': %v", command, s, err)
} }
return err == nil && resp == "OK"
}
if num++; num >= limitReadLines { return resp == "OK", nil
err = fmt.Errorf("command '%s': read line limit exceeded (%d)", command, limitReadLines)
return false
} }
body = append(body, line...) body = append(body, line...)
body = append(body, '\n') body = append(body, '\n')
return len(body) < length return len(body) < length, nil
}) }); err != nil {
if clientErr != nil { return "", nil, fmt.Errorf("command '%s': %v", command, err)
return "", nil, fmt.Errorf("command '%s' client error: %v", command, clientErr)
}
if err != nil {
return "", nil, err
} }
return resp, body, nil return resp, body, nil

View file

@ -111,25 +111,20 @@ func (c *boincClient) send(req *boincRequest) (*boincReply, error) {
var b bytes.Buffer var b bytes.Buffer
clientErr := c.conn.Command(string(reqData), func(bs []byte) bool { if err := c.conn.Command(string(reqData), func(bs []byte) (bool, error) {
s := strings.TrimSpace(string(bs)) s := strings.TrimSpace(string(bs))
if s == "" { if s == "" {
return true return true, nil
} }
if b.Len() == 0 && s != respStart { if b.Len() == 0 && s != respStart {
err = fmt.Errorf("unexpected response first line: %s", s) return false, fmt.Errorf("unexpected response first line: %s", s)
return false
} }
b.WriteString(s) b.WriteString(s)
return s != respEnd return s != respEnd, nil
}) }); err != nil {
if clientErr != nil {
return nil, fmt.Errorf("failed to send command: %v", clientErr)
}
if err != nil {
return nil, fmt.Errorf("failed to send command: %v", err) return nil, fmt.Errorf("failed to send command: %v", err)
} }

View file

@ -37,14 +37,13 @@ func (c *dovecotClient) queryExportGlobal() ([]byte, error) {
var b bytes.Buffer var b bytes.Buffer
var n int var n int
err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) bool { if err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) (bool, error) {
b.Write(bs) b.Write(bs)
b.WriteByte('\n') b.WriteByte('\n')
n++ n++
return n < 2 return n < 2, nil
}) }); err != nil {
if err != nil {
return nil, err return nil, err
} }

View file

@ -19,8 +19,9 @@ type gearmanConn interface {
func newGearmanConn(conf Config) gearmanConn { func newGearmanConn(conf Config) gearmanConn {
return &gearmanClient{conn: socket.New(socket.Config{ return &gearmanClient{conn: socket.New(socket.Config{
Address: conf.Address, Address: conf.Address,
Timeout: conf.Timeout.Duration(), Timeout: conf.Timeout.Duration(),
MaxReadLines: 10000,
})} })}
} }
@ -45,32 +46,20 @@ func (c *gearmanClient) queryPriorityStatus() ([]byte, error) {
} }
func (c *gearmanClient) query(cmd string) ([]byte, error) { func (c *gearmanClient) query(cmd string) ([]byte, error) {
const limitReadLines = 10000
var num int
var err error
var b bytes.Buffer var b bytes.Buffer
clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool { if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s := string(bs) s := string(bs)
if strings.HasPrefix(s, "ERR") { if strings.HasPrefix(s, "ERR") {
err = fmt.Errorf("command '%s': %s", cmd, s) return false, fmt.Errorf("command '%s': %s", cmd, s)
return false
} }
b.WriteString(s) b.WriteString(s)
b.WriteByte('\n') b.WriteByte('\n')
if num++; num >= limitReadLines { return !strings.HasPrefix(s, "."), nil
err = fmt.Errorf("command '%s': read line limit exceeded (%d)", cmd, limitReadLines) }); err != nil {
return false
}
return !strings.HasPrefix(s, ".")
})
if clientErr != nil {
return nil, fmt.Errorf("command '%s' client error: %v", cmd, clientErr)
}
if err != nil {
return nil, err return nil, err
} }

View file

@ -25,21 +25,15 @@ type hddtempClient struct {
} }
func (c *hddtempClient) queryHddTemp() (string, error) { func (c *hddtempClient) queryHddTemp() (string, error) {
var i int
var s string
cfg := socket.Config{ cfg := socket.Config{
Address: c.address, Address: c.address,
Timeout: c.timeout, Timeout: c.timeout,
} }
err := socket.ConnectAndRead(cfg, func(bs []byte) bool { var s string
if i++; i > 1 { err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) {
return false
}
s = string(bs) s = string(bs)
return true return false, nil
}) })
if err != nil { if err != nil {
return "", err return "", err

View file

@ -36,13 +36,13 @@ func (c *memcachedClient) disconnect() {
func (c *memcachedClient) queryStats() ([]byte, error) { func (c *memcachedClient) queryStats() ([]byte, error) {
var b bytes.Buffer var b bytes.Buffer
err := c.conn.Command("stats\r\n", func(bytes []byte) bool { if err := c.conn.Command("stats\r\n", func(bytes []byte) (bool, error) {
s := strings.TrimSpace(string(bytes)) s := strings.TrimSpace(string(bytes))
b.WriteString(s) b.WriteString(s)
b.WriteByte('\n') b.WriteByte('\n')
return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR"))
}) return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR")), nil
if err != nil { }); err != nil {
return nil, err return nil, err
} }
return b.Bytes(), nil return b.Bytes(), nil

View file

@ -57,29 +57,26 @@ func (c *Client) Version() (*Version, error) {
func (c *Client) get(command string, stopRead stopReadFunc) (output []string, err error) { func (c *Client) get(command string, stopRead stopReadFunc) (output []string, err error) {
var num int var num int
var maxLinesErr error if err := c.Command(command, func(bytes []byte) (bool, error) {
err = c.Command(command, func(bytes []byte) bool {
line := string(bytes) line := string(bytes)
num++ num++
if num > maxLinesToRead { if num > maxLinesToRead {
maxLinesErr = fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead) return false, fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead)
return false
} }
// skip real-time messages // skip real-time messages
if strings.HasPrefix(line, ">") { if strings.HasPrefix(line, ">") {
return true return true, nil
} }
line = strings.Trim(line, "\r\n ") line = strings.Trim(line, "\r\n ")
output = append(output, line) output = append(output, line)
if stopRead != nil && stopRead(line) { if stopRead != nil && stopRead(line) {
return false return false, nil
} }
return true return true, nil
}) }); err != nil {
if maxLinesErr != nil { return nil, err
return nil, maxLinesErr
} }
return output, err return output, err
} }

View file

@ -98,7 +98,9 @@ func (m *mockSocketClient) Command(command string, process socket.Processor) err
} }
for s.Scan() { for s.Scan() {
process(s.Bytes()) if _, err := process(s.Bytes()); err != nil {
return err
}
} }
return nil return nil
} }

View file

@ -58,9 +58,9 @@ func (c *torControlClient) authenticate() error {
} }
var s string var s string
err := c.conn.Command(cmd+"\n", func(bs []byte) bool { err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s = string(bs) s = string(bs)
return false return false, nil
}) })
if err != nil { if err != nil {
return fmt.Errorf("authentication failed: %v", err) return fmt.Errorf("authentication failed: %v", err)
@ -74,7 +74,7 @@ func (c *torControlClient) authenticate() error {
func (c *torControlClient) disconnect() { func (c *torControlClient) disconnect() {
// https://spec.torproject.org/control-spec/commands.html#quit // https://spec.torproject.org/control-spec/commands.html#quit
_ = c.conn.Command(cmdQuit+"\n", func(bs []byte) bool { return false }) _ = c.conn.Command(cmdQuit+"\n", func(bs []byte) (bool, error) { return false, nil })
_ = c.conn.Disconnect() _ = c.conn.Disconnect()
} }
@ -87,27 +87,21 @@ func (c *torControlClient) getInfo(keywords ...string) ([]byte, error) {
cmd := fmt.Sprintf("%s %s", cmdGetInfo, strings.Join(keywords, " ")) cmd := fmt.Sprintf("%s %s", cmdGetInfo, strings.Join(keywords, " "))
var buf bytes.Buffer var buf bytes.Buffer
var err error
clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool { if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s := string(bs) s := string(bs)
switch { switch {
case strings.HasPrefix(s, "250-"): case strings.HasPrefix(s, "250-"):
buf.WriteString(strings.TrimPrefix(s, "250-")) buf.WriteString(strings.TrimPrefix(s, "250-"))
buf.WriteByte('\n') buf.WriteByte('\n')
return true return true, nil
case strings.HasPrefix(s, "250 "): case strings.HasPrefix(s, "250 "):
return false return false, nil
default: default:
err = errors.New(s) return false, errors.New(s)
return false
} }
}) }); err != nil {
if clientErr != nil {
return nil, fmt.Errorf("command '%s' failed: %v", cmd, clientErr)
}
if err != nil {
return nil, fmt.Errorf("command '%s' failed: %v", cmd, err) return nil, fmt.Errorf("command '%s' failed: %v", cmd, err)
} }

View file

@ -36,9 +36,9 @@ func (c *Collector) scrapeUnboundStats() ([]entry, error) {
} }
defer func() { _ = c.client.Disconnect() }() defer func() { _ = c.client.Disconnect() }()
err := c.client.Command(command+"\n", func(bytes []byte) bool { err := c.client.Command(command+"\n", func(bytes []byte) (bool, error) {
output = append(output, string(bytes)) output = append(output, string(bytes))
return true return true, nil
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("send command '%s': %w", command, err) return nil, fmt.Errorf("send command '%s': %w", command, err)

View file

@ -133,7 +133,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) {
var errMsg string var errMsg string
endLine := getEndLine(cmd) endLine := getEndLine(cmd)
err := c.conn.Command(cmd+"\n", func(bytes []byte) bool { err := c.conn.Command(cmd+"\n", func(bytes []byte) (bool, error) {
line := string(bytes) line := string(bytes)
resp = append(resp, line) resp = append(resp, line)
@ -141,7 +141,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) {
errMsg = strings.TrimPrefix(line, "ERR ") errMsg = strings.TrimPrefix(line, "ERR ")
} }
return line != endLine && errMsg == "" return line != endLine && errMsg == "", nil
}) })
if err != nil { if err != nil {
return nil, err return nil, err

View file

@ -4,7 +4,6 @@ package uwsgi
import ( import (
"bytes" "bytes"
"fmt"
"time" "time"
"github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket" "github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket"
@ -28,30 +27,19 @@ type uwsgiClient struct {
func (c *uwsgiClient) queryStats() ([]byte, error) { func (c *uwsgiClient) queryStats() ([]byte, error) {
var b bytes.Buffer var b bytes.Buffer
var n int64
var err error
const readLineLimit = 1000 * 10
cfg := socket.Config{ cfg := socket.Config{
Address: c.address, Address: c.address,
Timeout: c.timeout, Timeout: c.timeout,
MaxReadLines: 1000 * 10,
} }
clientErr := socket.ConnectAndRead(cfg, func(bs []byte) bool { if err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) {
b.Write(bs) b.Write(bs)
b.WriteByte('\n') b.WriteByte('\n')
if n++; n >= readLineLimit {
err = fmt.Errorf("read line limit exceeded %d", readLineLimit)
return false
}
// The server will close the connection when it has finished sending data. // The server will close the connection when it has finished sending data.
return true return true, nil
}) }); err != nil {
if clientErr != nil {
return nil, clientErr
}
if err != nil {
return nil, err return nil, err
} }

View file

@ -4,14 +4,11 @@ package zookeeper
import ( import (
"bytes" "bytes"
"fmt"
"unsafe" "unsafe"
"github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket" "github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket"
) )
const limitReadLines = 2000
type fetcher interface { type fetcher interface {
fetch(command string) ([]string, error) fetch(command string) ([]string, error)
} }
@ -26,21 +23,12 @@ func (c *zookeeperFetcher) fetch(command string) (rows []string, err error) {
} }
defer func() { _ = c.Disconnect() }() defer func() { _ = c.Disconnect() }()
var num int if err := c.Command(command, func(b []byte) (bool, error) {
clientErr := c.Command(command, func(b []byte) bool {
if !isZKLine(b) || isMntrLineOK(b) { if !isZKLine(b) || isMntrLineOK(b) {
rows = append(rows, string(b)) rows = append(rows, string(b))
} }
if num += 1; num >= limitReadLines { return true, nil
err = fmt.Errorf("read line limit exceeded (%d)", limitReadLines) }); err != nil {
return false
}
return true
})
if clientErr != nil {
return nil, clientErr
}
if err != nil {
return nil, err return nil, err
} }

View file

@ -22,14 +22,6 @@ func Test_clientFetch(t *testing.T) {
assert.Len(t, rows, 10) assert.Len(t, rows, 10)
} }
func Test_clientFetchReadLineLimitExceeded(t *testing.T) {
c := &zookeeperFetcher{Client: &mockSocket{rowsNumResp: limitReadLines + 1}}
rows, err := c.fetch("whatever\n")
assert.Error(t, err)
assert.Len(t, rows, 0)
}
type mockSocket struct { type mockSocket struct {
rowsNumResp int rowsNumResp int
} }
@ -44,7 +36,9 @@ func (m *mockSocket) Disconnect() error {
func (m *mockSocket) Command(command string, process socket.Processor) error { func (m *mockSocket) Command(command string, process socket.Processor) error {
for i := 0; i < m.rowsNumResp; i++ { for i := 0; i < m.rowsNumResp; i++ {
process([]byte(command)) if _, err := process([]byte(command)); err != nil {
return err
}
} }
return nil return nil
} }

View file

@ -30,9 +30,10 @@ func (c *Collector) initZookeeperFetcher() (fetcher, error) {
} }
sock := socket.New(socket.Config{ sock := socket.New(socket.Config{
Address: c.Address, Address: c.Address,
Timeout: c.Timeout.Duration(), Timeout: c.Timeout.Duration(),
TLSConf: tlsConf, TLSConf: tlsConf,
MaxReadLines: 2000,
}) })
return &zookeeperFetcher{Client: sock}, nil return &zookeeperFetcher{Client: sock}, nil

View file

@ -4,25 +4,29 @@ package socket
import ( import (
"bufio" "bufio"
"context"
"crypto/tls" "crypto/tls"
"errors" "errors"
"fmt"
"net" "net"
"time" "time"
) )
// Processor function passed to the Socket.Command function. // Processor is a callback function passed to the Socket.Command method.
// It is passed by the caller to process a command's response line by line. // It processes each response line received from the server.
type Processor func([]byte) bool type Processor func([]byte) (bool, error)
// Client is the interface that wraps the basic socket client operations // Client defines an interface for socket clients, abstracting the underlying implementation.
// and hides the implementation details from the users. // Implementations should provide connections for various socket types such as TCP, UDP, or Unix domain sockets.
// Implementations should return TCP, UDP or Unix ready sockets.
type Client interface { type Client interface {
Connect() error Connect() error
Disconnect() error Disconnect() error
Command(command string, process Processor) error Command(command string, process Processor) error
} }
// ConnectAndRead establishes a connection using the given configuration,
// executes the provided processor function on the incoming response lines,
// and ensures the connection is properly closed after use.
func ConnectAndRead(cfg Config, process Processor) error { func ConnectAndRead(cfg Config, process Processor) error {
sock := New(cfg) sock := New(cfg)
@ -35,46 +39,33 @@ func ConnectAndRead(cfg Config, process Processor) error {
return sock.read(process) return sock.read(process)
} }
// New returns a new pointer to a socket client given the socket // New creates and returns a new Socket instance configured with the provided settings.
// type (IP, TCP, UDP, UNIX), a network address (IP/domain:port), // The socket supports multiple types (TCP, UDP, UNIX), addresses (IPv4, IPv6, domain names),
// a timeout and a TLS config. It supports both IPv4 and IPv6 address // and optional TLS encryption. Connections are reused where possible.
// and reuses connection where possible.
func New(cfg Config) *Socket { func New(cfg Config) *Socket {
return &Socket{Config: cfg} return &Socket{Config: cfg}
} }
// Socket is the implementation of a socket client. // Socket is a concrete implementation of the Client interface, managing a network connection
// based on the specified configuration (address, type, timeout, and optional TLS settings).
type Socket struct { type Socket struct {
Config Config
conn net.Conn conn net.Conn
} }
// Config holds the network ip v4 or v6 address, port, // Config encapsulates the settings required to establish a network connection.
// Socket type(ip, tcp, udp, unix), timeout and TLS configuration for a Socket
type Config struct { type Config struct {
Address string Address string
Timeout time.Duration Timeout time.Duration
TLSConf *tls.Config TLSConf *tls.Config
MaxReadLines int64
} }
// Connect connects to the Socket address on the named network. // Connect establishes a connection to the specified address using the configuration details.
// If the address is a domain name it will also perform the DNS resolution.
// Address like :80 will attempt to connect to the localhost.
// The config timeout and TLS config will be used.
func (s *Socket) Connect() error { func (s *Socket) Connect() error {
network, address := networkType(s.Address) conn, err := s.dial()
var conn net.Conn
var err error
if s.TLSConf == nil {
conn, err = net.DialTimeout(network, address, s.timeout())
} else {
var d net.Dialer
d.Timeout = s.timeout()
conn, err = tls.DialWithDialer(&d, network, address, s.TLSConf)
}
if err != nil { if err != nil {
return err return fmt.Errorf("socket.Connect: %w", err)
} }
s.conn = conn s.conn = conn
@ -82,22 +73,19 @@ func (s *Socket) Connect() error {
return nil return nil
} }
// Disconnect closes the connection. // Disconnect terminates the active connection if one exists.
// Any in-flight commands will be cancelled and return errors. func (s *Socket) Disconnect() error {
func (s *Socket) Disconnect() (err error) { if s.conn == nil {
if s.conn != nil { return nil
err = s.conn.Close()
s.conn = nil
} }
err := s.conn.Close()
s.conn = nil
return err return err
} }
// Command writes the command string to the connection and passed the // Command sends a command string to the connected server and processes its response line by line
// response bytes line by line to the process function. It uses the // using the provided Processor function. This method respects the timeout configuration
// timeout value from the Socket config and returns read, write and // for write and read operations. If a timeout or processing error occurs, it stops and returns the error.
// timeout errors if any. If a timeout occurs during the processing
// of the responses this function will stop processing and return a
// timeout error.
func (s *Socket) Command(command string, process Processor) error { func (s *Socket) Command(command string, process Processor) error {
if s.conn == nil { if s.conn == nil {
return errors.New("cannot send command on nil connection") return errors.New("cannot send command on nil connection")
@ -112,10 +100,10 @@ func (s *Socket) Command(command string, process Processor) error {
func (s *Socket) write(command string) error { func (s *Socket) write(command string) error {
if s.conn == nil { if s.conn == nil {
return errors.New("attempt to write on nil connection") return errors.New("write: nil connection")
} }
if err := s.conn.SetWriteDeadline(time.Now().Add(s.timeout())); err != nil { if err := s.conn.SetWriteDeadline(s.deadline()); err != nil {
return err return err
} }
@ -126,25 +114,53 @@ func (s *Socket) write(command string) error {
func (s *Socket) read(process Processor) error { func (s *Socket) read(process Processor) error {
if process == nil { if process == nil {
return errors.New("process func is nil") return errors.New("read: process func is nil")
} }
if s.conn == nil { if s.conn == nil {
return errors.New("attempt to read on nil connection") return errors.New("read: nil connection")
} }
if err := s.conn.SetReadDeadline(time.Now().Add(s.timeout())); err != nil { if err := s.conn.SetReadDeadline(s.deadline()); err != nil {
return err return err
} }
sc := bufio.NewScanner(s.conn) sc := bufio.NewScanner(s.conn)
for sc.Scan() && process(sc.Bytes()) { var n int64
limit := s.MaxReadLines
for sc.Scan() {
more, err := process(sc.Bytes())
if err != nil {
return err
}
if n++; limit > 0 && n > limit {
return fmt.Errorf("read line limit exceeded (%d", limit)
}
if !more {
break
}
} }
return sc.Err() return sc.Err()
} }
func (s *Socket) dial() (net.Conn, error) {
network, address := parseAddress(s.Address)
var d net.Dialer
d.Timeout = s.timeout()
if s.TLSConf != nil {
return tls.DialWithDialer(&d, network, address, s.TLSConf)
}
return d.DialContext(context.Background(), network, address)
}
func (s *Socket) deadline() time.Time {
return time.Now().Add(s.timeout())
}
func (s *Socket) timeout() time.Duration { func (s *Socket) timeout() time.Duration {
if s.Timeout == 0 { if s.Timeout == 0 {
return time.Second return time.Second

View file

@ -3,152 +3,86 @@
package socket package socket
import ( import (
"crypto/tls"
"testing" "testing"
"time" "time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
const ( func TestSocket_Command(t *testing.T) {
testServerAddress = "127.0.0.1:9999" const (
testUdpServerAddress = "udp://127.0.0.1:9999" testServerAddress = "tcp://127.0.0.1:9999"
testUnixServerAddress = "/tmp/testSocketFD" testUdpServerAddress = "udp://127.0.0.1:9999"
defaultTimeout = 100 * time.Millisecond testUnixServerAddress = "unix:///tmp/testSocketFD"
) defaultTimeout = 1000 * time.Millisecond
)
var tcpConfig = Config{ type server interface {
Address: testServerAddress, Run() error
Timeout: defaultTimeout, Close() error
TLSConf: nil, }
}
tests := map[string]struct {
var udpConfig = Config{ srv server
Address: testUdpServerAddress, cfg Config
Timeout: defaultTimeout, wantConnectErr bool
TLSConf: nil, wantCommandErr bool
} }{
"tcp": {
var unixConfig = Config{ srv: newTCPServer(testServerAddress),
Address: testUnixServerAddress, cfg: Config{
Timeout: defaultTimeout, Address: testServerAddress,
TLSConf: nil, Timeout: defaultTimeout,
} },
},
var tcpTlsConfig = Config{ "udp": {
Address: testServerAddress, srv: newUDPServer(testUdpServerAddress),
Timeout: defaultTimeout, cfg: Config{
TLSConf: &tls.Config{}, Address: testUdpServerAddress,
} Timeout: defaultTimeout,
},
func Test_clientCommand(t *testing.T) { },
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} "unix": {
go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }() srv: newUnixServer(testUnixServerAddress),
cfg: Config{
time.Sleep(time.Millisecond * 100) Address: testUnixServerAddress,
sock := New(tcpConfig) Timeout: defaultTimeout,
require.NoError(t, sock.Connect()) },
err := sock.Command("ping\n", func(bytes []byte) bool { },
assert.Equal(t, "pong", string(bytes)) }
return true
}) for name, test := range tests {
require.NoError(t, sock.Disconnect()) t.Run(name, func(t *testing.T) {
require.NoError(t, err) go func() {
} defer func() { _ = test.srv.Close() }()
require.NoError(t, test.srv.Run())
func Test_clientTimeout(t *testing.T) { }()
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} time.Sleep(time.Millisecond * 500)
go func() { _ = srv.Run() }()
sock := New(test.cfg)
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig) err := sock.Connect()
require.NoError(t, sock.Connect())
sock.Timeout = 0 if test.wantConnectErr {
err := sock.Command("ping\n", func(bytes []byte) bool { require.Error(t, err)
assert.Equal(t, "pong", string(bytes)) return
return true }
}) require.NoError(t, err)
require.NoError(t, err)
} defer sock.Disconnect()
func Test_clientIncompleteSSL(t *testing.T) { var resp string
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1} err = sock.Command("ping\n", func(bytes []byte) (bool, error) {
go func() { _ = srv.Run() }() resp = string(bytes)
return false, nil
time.Sleep(time.Millisecond * 100) })
sock := New(tcpTlsConfig)
err := sock.Connect() if test.wantCommandErr {
require.Error(t, err) require.Error(t, err)
} } else {
require.NoError(t, err)
func Test_clientCommandStopProcessing(t *testing.T) { require.Equal(t, "pong", resp)
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 2} }
go func() { _ = srv.Run() }() })
}
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, sock.Disconnect())
require.NoError(t, err)
}
func Test_clientUDPCommand(t *testing.T) {
srv := &udpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }()
time.Sleep(time.Millisecond * 100)
sock := New(udpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, sock.Disconnect())
require.NoError(t, err)
}
func Test_clientTCPAddress(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
tcpConfig.Address = "tcp://" + tcpConfig.Address
sock = New(tcpConfig)
require.NoError(t, sock.Connect())
}
func Test_clientUnixCommand(t *testing.T) {
srv := &unixServer{addr: testUnixServerAddress, rowsNumResp: 1}
// cleanup previous file descriptors
_ = srv.Close()
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 200)
sock := New(unixConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, err)
require.NoError(t, sock.Disconnect())
}
func Test_clientEmptyProcessFunc(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", nil)
require.Error(t, err, "nil process func should return an error")
} }

View file

@ -0,0 +1,257 @@
// SPDX-License-Identifier: GPL-3.0-or-later
package socket
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
func newTCPServer(addr string) *tcpServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &tcpServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type tcpServer struct {
addr string
listener net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func (t *tcpServer) Run() error {
var err error
t.listener, err = net.Listen("tcp", t.addr)
if err != nil {
return fmt.Errorf("failed to start TCP server: %w", err)
}
return t.handleConnections()
}
func (t *tcpServer) Close() (err error) {
t.cancel()
if t.listener != nil {
if err := t.listener.Close(); err != nil {
return fmt.Errorf("failed to close TCP server: %w", err)
}
}
t.wg.Wait()
return nil
}
func (t *tcpServer) handleConnections() (err error) {
for {
select {
case <-t.ctx.Done():
return nil
default:
conn, err := t.listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return fmt.Errorf("could not accept connection: %v", err)
}
t.wg.Add(1)
go func() {
defer t.wg.Done()
t.handleConnection(conn)
}()
}
}
}
func (t *tcpServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
return
}
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
if _, err := rw.ReadString('\n'); err != nil {
writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err))
} else {
writeResponse(rw, "pong\n")
}
}
func newUDPServer(addr string) *udpServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &udpServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type udpServer struct {
addr string
conn *net.UDPConn
ctx context.Context
cancel context.CancelFunc
}
func (u *udpServer) Run() error {
addr, err := net.ResolveUDPAddr("udp", u.addr)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %w", err)
}
u.conn, err = net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to start UDP server: %w", err)
}
return u.handleConnections()
}
func (u *udpServer) Close() (err error) {
u.cancel()
if u.conn != nil {
if err := u.conn.Close(); err != nil {
return fmt.Errorf("failed to close UDP server: %w", err)
}
}
return nil
}
func (u *udpServer) handleConnections() error {
buffer := make([]byte, 8192)
for {
select {
case <-u.ctx.Done():
return nil
default:
if err := u.conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
continue
}
_, addr, err := u.conn.ReadFromUDP(buffer[0:])
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
return fmt.Errorf("failed to read UDP packet: %w", err)
}
continue
}
if _, err := u.conn.WriteToUDP([]byte("pong\n"), addr); err != nil {
return fmt.Errorf("failed to write UDP response: %w", err)
}
}
}
}
func newUnixServer(addr string) *unixServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &unixServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type unixServer struct {
addr string
listener *net.UnixListener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func (u *unixServer) Run() error {
if err := os.Remove(u.addr); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to clean up existing socket: %w", err)
}
addr, err := net.ResolveUnixAddr("unix", u.addr)
if err != nil {
return fmt.Errorf("failed to resolve Unix address: %w", err)
}
u.listener, err = net.ListenUnix("unix", addr)
if err != nil {
return fmt.Errorf("failed to start Unix server: %w", err)
}
return u.handleConnections()
}
func (u *unixServer) Close() error {
u.cancel()
if u.listener != nil {
if err := u.listener.Close(); err != nil {
return fmt.Errorf("failed to close Unix server: %w", err)
}
}
u.wg.Wait()
_ = os.Remove(u.addr)
return nil
}
func (u *unixServer) handleConnections() error {
for {
select {
case <-u.ctx.Done():
return nil
default:
if err := u.listener.SetDeadline(time.Now().Add(time.Second)); err != nil {
continue
}
conn, err := u.listener.AcceptUnix()
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
return err
}
continue
}
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.handleConnection(conn)
}()
}
}
}
func (u *unixServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
return
}
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
if _, err := rw.ReadString('\n'); err != nil {
writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err))
} else {
writeResponse(rw, "pong\n")
}
}
func writeResponse(rw *bufio.ReadWriter, response string) {
_, _ = rw.WriteString(response)
_ = rw.Flush()
}

View file

@ -1,139 +0,0 @@
// SPDX-License-Identifier: GPL-3.0-or-later
package socket
import (
"bufio"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
type tcpServer struct {
addr string
server net.Listener
rowsNumResp int
}
func (t *tcpServer) Run() (err error) {
t.server, err = net.Listen("tcp", t.addr)
if err != nil {
return
}
return t.handleConnections()
}
func (t *tcpServer) Close() (err error) {
return t.server.Close()
}
func (t *tcpServer) handleConnections() (err error) {
for {
conn, err := t.server.Accept()
if err != nil || conn == nil {
return errors.New("could not accept connection")
}
t.handleConnection(conn)
}
}
func (t *tcpServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
_ = conn.SetDeadline(time.Now().Add(time.Millisecond * 100))
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
_, err := rw.ReadString('\n')
if err != nil {
_, _ = rw.WriteString("failed to read input")
_ = rw.Flush()
} else {
resp := strings.Repeat("pong\n", t.rowsNumResp)
_, _ = rw.WriteString(resp)
_ = rw.Flush()
}
}
type udpServer struct {
addr string
conn *net.UDPConn
rowsNumResp int
}
func (u *udpServer) Run() (err error) {
addr, err := net.ResolveUDPAddr("udp", u.addr)
if err != nil {
return err
}
u.conn, err = net.ListenUDP("udp", addr)
if err != nil {
return
}
u.handleConnections()
return nil
}
func (u *udpServer) Close() (err error) {
return u.conn.Close()
}
func (u *udpServer) handleConnections() {
for {
var buf [2048]byte
_, addr, _ := u.conn.ReadFromUDP(buf[0:])
resp := strings.Repeat("pong\n", u.rowsNumResp)
_, _ = u.conn.WriteToUDP([]byte(resp), addr)
}
}
type unixServer struct {
addr string
conn *net.UnixListener
rowsNumResp int
}
func (u *unixServer) Run() (err error) {
_, _ = os.CreateTemp("/tmp", "testSocketFD")
addr, err := net.ResolveUnixAddr("unix", u.addr)
if err != nil {
return err
}
u.conn, err = net.ListenUnix("unix", addr)
if err != nil {
return
}
go u.handleConnections()
return nil
}
func (u *unixServer) Close() (err error) {
_ = os.Remove(testUnixServerAddress)
return u.conn.Close()
}
func (u *unixServer) handleConnections() {
var conn net.Conn
var err error
conn, err = u.conn.AcceptUnix()
if err != nil {
panic(fmt.Errorf("could not accept connection: %v", err))
}
u.handleConnection(conn)
}
func (u *unixServer) handleConnection(conn net.Conn) {
_ = conn.SetDeadline(time.Now().Add(time.Second))
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
_, err := rw.ReadString('\n')
if err != nil {
_, _ = rw.WriteString("failed to read input")
_ = rw.Flush()
} else {
resp := strings.Repeat("pong\n", u.rowsNumResp)
_, _ = rw.WriteString(resp)
_ = rw.Flush()
}
}

View file

@ -12,7 +12,7 @@ func IsUdpSocket(address string) bool {
return strings.HasPrefix(address, "udp://") return strings.HasPrefix(address, "udp://")
} }
func networkType(address string) (string, string) { func parseAddress(address string) (string, string) {
switch { switch {
case IsUnixSocket(address): case IsUnixSocket(address):
address = strings.TrimPrefix(address, "unix://") address = strings.TrimPrefix(address, "unix://")