0
0
Fork 0
mirror of https://github.com/netdata/netdata.git synced 2025-04-25 05:31:37 +00:00

chore(go.d/pkg/socket): add err to callback return values ()

This commit is contained in:
Ilya Mashchenko 2024-11-29 12:43:26 +02:00 committed by GitHub
parent 28ee226945
commit 61fb6d77d3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
20 changed files with 476 additions and 475 deletions

View file

@ -88,9 +88,10 @@ func newBeanstalkConn(conf Config, log *logger.Logger) beanstalkConn {
return &beanstalkClient{
Logger: log,
client: socket.New(socket.Config{
Address: conf.Address,
Timeout: conf.Timeout.Duration(),
TLSConf: nil,
Address: conf.Address,
Timeout: conf.Timeout.Duration(),
MaxReadLines: 2000,
TLSConf: nil,
}),
}
}
@ -180,43 +181,34 @@ func (c *beanstalkClient) queryStatsTube(tubeName string) (*tubeStats, error) {
}
func (c *beanstalkClient) query(command string) (string, []byte, error) {
var resp string
var length int
var body []byte
var err error
c.Debugf("executing command: %s", command)
const limitReadLines = 1000
var num int
var (
resp string
body []byte
length int
err error
)
clientErr := c.client.Command(command+"\r\n", func(line []byte) bool {
if err := c.client.Command(command+"\r\n", func(line []byte) (bool, error) {
if resp == "" {
s := string(line)
c.Debugf("command '%s' response: '%s'", command, s)
resp, length, err = parseResponseLine(s)
if err != nil {
err = fmt.Errorf("command '%s' line '%s': %v", command, s, err)
return false, fmt.Errorf("command '%s' line '%s': %v", command, s, err)
}
return err == nil && resp == "OK"
}
if num++; num >= limitReadLines {
err = fmt.Errorf("command '%s': read line limit exceeded (%d)", command, limitReadLines)
return false
return resp == "OK", nil
}
body = append(body, line...)
body = append(body, '\n')
return len(body) < length
})
if clientErr != nil {
return "", nil, fmt.Errorf("command '%s' client error: %v", command, clientErr)
}
if err != nil {
return "", nil, err
return len(body) < length, nil
}); err != nil {
return "", nil, fmt.Errorf("command '%s': %v", command, err)
}
return resp, body, nil

View file

@ -111,25 +111,20 @@ func (c *boincClient) send(req *boincRequest) (*boincReply, error) {
var b bytes.Buffer
clientErr := c.conn.Command(string(reqData), func(bs []byte) bool {
if err := c.conn.Command(string(reqData), func(bs []byte) (bool, error) {
s := strings.TrimSpace(string(bs))
if s == "" {
return true
return true, nil
}
if b.Len() == 0 && s != respStart {
err = fmt.Errorf("unexpected response first line: %s", s)
return false
return false, fmt.Errorf("unexpected response first line: %s", s)
}
b.WriteString(s)
return s != respEnd
})
if clientErr != nil {
return nil, fmt.Errorf("failed to send command: %v", clientErr)
}
if err != nil {
return s != respEnd, nil
}); err != nil {
return nil, fmt.Errorf("failed to send command: %v", err)
}

View file

@ -37,14 +37,13 @@ func (c *dovecotClient) queryExportGlobal() ([]byte, error) {
var b bytes.Buffer
var n int
err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) bool {
if err := c.conn.Command("EXPORT\tglobal\n", func(bs []byte) (bool, error) {
b.Write(bs)
b.WriteByte('\n')
n++
return n < 2
})
if err != nil {
return n < 2, nil
}); err != nil {
return nil, err
}

View file

@ -19,8 +19,9 @@ type gearmanConn interface {
func newGearmanConn(conf Config) gearmanConn {
return &gearmanClient{conn: socket.New(socket.Config{
Address: conf.Address,
Timeout: conf.Timeout.Duration(),
Address: conf.Address,
Timeout: conf.Timeout.Duration(),
MaxReadLines: 10000,
})}
}
@ -45,32 +46,20 @@ func (c *gearmanClient) queryPriorityStatus() ([]byte, error) {
}
func (c *gearmanClient) query(cmd string) ([]byte, error) {
const limitReadLines = 10000
var num int
var err error
var b bytes.Buffer
clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool {
if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s := string(bs)
if strings.HasPrefix(s, "ERR") {
err = fmt.Errorf("command '%s': %s", cmd, s)
return false
return false, fmt.Errorf("command '%s': %s", cmd, s)
}
b.WriteString(s)
b.WriteByte('\n')
if num++; num >= limitReadLines {
err = fmt.Errorf("command '%s': read line limit exceeded (%d)", cmd, limitReadLines)
return false
}
return !strings.HasPrefix(s, ".")
})
if clientErr != nil {
return nil, fmt.Errorf("command '%s' client error: %v", cmd, clientErr)
}
if err != nil {
return !strings.HasPrefix(s, "."), nil
}); err != nil {
return nil, err
}

View file

@ -25,21 +25,15 @@ type hddtempClient struct {
}
func (c *hddtempClient) queryHddTemp() (string, error) {
var i int
var s string
cfg := socket.Config{
Address: c.address,
Timeout: c.timeout,
}
err := socket.ConnectAndRead(cfg, func(bs []byte) bool {
if i++; i > 1 {
return false
}
var s string
err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) {
s = string(bs)
return true
return false, nil
})
if err != nil {
return "", err

View file

@ -36,13 +36,13 @@ func (c *memcachedClient) disconnect() {
func (c *memcachedClient) queryStats() ([]byte, error) {
var b bytes.Buffer
err := c.conn.Command("stats\r\n", func(bytes []byte) bool {
if err := c.conn.Command("stats\r\n", func(bytes []byte) (bool, error) {
s := strings.TrimSpace(string(bytes))
b.WriteString(s)
b.WriteByte('\n')
return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR"))
})
if err != nil {
return !(strings.HasPrefix(s, "END") || strings.HasPrefix(s, "ERROR")), nil
}); err != nil {
return nil, err
}
return b.Bytes(), nil

View file

@ -57,29 +57,26 @@ func (c *Client) Version() (*Version, error) {
func (c *Client) get(command string, stopRead stopReadFunc) (output []string, err error) {
var num int
var maxLinesErr error
err = c.Command(command, func(bytes []byte) bool {
if err := c.Command(command, func(bytes []byte) (bool, error) {
line := string(bytes)
num++
if num > maxLinesToRead {
maxLinesErr = fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead)
return false
return false, fmt.Errorf("read line limit exceeded (%d)", maxLinesToRead)
}
// skip real-time messages
if strings.HasPrefix(line, ">") {
return true
return true, nil
}
line = strings.Trim(line, "\r\n ")
output = append(output, line)
if stopRead != nil && stopRead(line) {
return false
return false, nil
}
return true
})
if maxLinesErr != nil {
return nil, maxLinesErr
return true, nil
}); err != nil {
return nil, err
}
return output, err
}

View file

@ -98,7 +98,9 @@ func (m *mockSocketClient) Command(command string, process socket.Processor) err
}
for s.Scan() {
process(s.Bytes())
if _, err := process(s.Bytes()); err != nil {
return err
}
}
return nil
}

View file

@ -58,9 +58,9 @@ func (c *torControlClient) authenticate() error {
}
var s string
err := c.conn.Command(cmd+"\n", func(bs []byte) bool {
err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s = string(bs)
return false
return false, nil
})
if err != nil {
return fmt.Errorf("authentication failed: %v", err)
@ -74,7 +74,7 @@ func (c *torControlClient) authenticate() error {
func (c *torControlClient) disconnect() {
// https://spec.torproject.org/control-spec/commands.html#quit
_ = c.conn.Command(cmdQuit+"\n", func(bs []byte) bool { return false })
_ = c.conn.Command(cmdQuit+"\n", func(bs []byte) (bool, error) { return false, nil })
_ = c.conn.Disconnect()
}
@ -87,27 +87,21 @@ func (c *torControlClient) getInfo(keywords ...string) ([]byte, error) {
cmd := fmt.Sprintf("%s %s", cmdGetInfo, strings.Join(keywords, " "))
var buf bytes.Buffer
var err error
clientErr := c.conn.Command(cmd+"\n", func(bs []byte) bool {
if err := c.conn.Command(cmd+"\n", func(bs []byte) (bool, error) {
s := string(bs)
switch {
case strings.HasPrefix(s, "250-"):
buf.WriteString(strings.TrimPrefix(s, "250-"))
buf.WriteByte('\n')
return true
return true, nil
case strings.HasPrefix(s, "250 "):
return false
return false, nil
default:
err = errors.New(s)
return false
return false, errors.New(s)
}
})
if clientErr != nil {
return nil, fmt.Errorf("command '%s' failed: %v", cmd, clientErr)
}
if err != nil {
}); err != nil {
return nil, fmt.Errorf("command '%s' failed: %v", cmd, err)
}

View file

@ -36,9 +36,9 @@ func (c *Collector) scrapeUnboundStats() ([]entry, error) {
}
defer func() { _ = c.client.Disconnect() }()
err := c.client.Command(command+"\n", func(bytes []byte) bool {
err := c.client.Command(command+"\n", func(bytes []byte) (bool, error) {
output = append(output, string(bytes))
return true
return true, nil
})
if err != nil {
return nil, fmt.Errorf("send command '%s': %w", command, err)

View file

@ -133,7 +133,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) {
var errMsg string
endLine := getEndLine(cmd)
err := c.conn.Command(cmd+"\n", func(bytes []byte) bool {
err := c.conn.Command(cmd+"\n", func(bytes []byte) (bool, error) {
line := string(bytes)
resp = append(resp, line)
@ -141,7 +141,7 @@ func (c *upsdClient) sendCommand(cmd string) ([]string, error) {
errMsg = strings.TrimPrefix(line, "ERR ")
}
return line != endLine && errMsg == ""
return line != endLine && errMsg == "", nil
})
if err != nil {
return nil, err

View file

@ -4,7 +4,6 @@ package uwsgi
import (
"bytes"
"fmt"
"time"
"github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket"
@ -28,30 +27,19 @@ type uwsgiClient struct {
func (c *uwsgiClient) queryStats() ([]byte, error) {
var b bytes.Buffer
var n int64
var err error
const readLineLimit = 1000 * 10
cfg := socket.Config{
Address: c.address,
Timeout: c.timeout,
Address: c.address,
Timeout: c.timeout,
MaxReadLines: 1000 * 10,
}
clientErr := socket.ConnectAndRead(cfg, func(bs []byte) bool {
if err := socket.ConnectAndRead(cfg, func(bs []byte) (bool, error) {
b.Write(bs)
b.WriteByte('\n')
if n++; n >= readLineLimit {
err = fmt.Errorf("read line limit exceeded %d", readLineLimit)
return false
}
// The server will close the connection when it has finished sending data.
return true
})
if clientErr != nil {
return nil, clientErr
}
if err != nil {
return true, nil
}); err != nil {
return nil, err
}

View file

@ -4,14 +4,11 @@ package zookeeper
import (
"bytes"
"fmt"
"unsafe"
"github.com/netdata/netdata/go/plugins/plugin/go.d/pkg/socket"
)
const limitReadLines = 2000
type fetcher interface {
fetch(command string) ([]string, error)
}
@ -26,21 +23,12 @@ func (c *zookeeperFetcher) fetch(command string) (rows []string, err error) {
}
defer func() { _ = c.Disconnect() }()
var num int
clientErr := c.Command(command, func(b []byte) bool {
if err := c.Command(command, func(b []byte) (bool, error) {
if !isZKLine(b) || isMntrLineOK(b) {
rows = append(rows, string(b))
}
if num += 1; num >= limitReadLines {
err = fmt.Errorf("read line limit exceeded (%d)", limitReadLines)
return false
}
return true
})
if clientErr != nil {
return nil, clientErr
}
if err != nil {
return true, nil
}); err != nil {
return nil, err
}

View file

@ -22,14 +22,6 @@ func Test_clientFetch(t *testing.T) {
assert.Len(t, rows, 10)
}
func Test_clientFetchReadLineLimitExceeded(t *testing.T) {
c := &zookeeperFetcher{Client: &mockSocket{rowsNumResp: limitReadLines + 1}}
rows, err := c.fetch("whatever\n")
assert.Error(t, err)
assert.Len(t, rows, 0)
}
type mockSocket struct {
rowsNumResp int
}
@ -44,7 +36,9 @@ func (m *mockSocket) Disconnect() error {
func (m *mockSocket) Command(command string, process socket.Processor) error {
for i := 0; i < m.rowsNumResp; i++ {
process([]byte(command))
if _, err := process([]byte(command)); err != nil {
return err
}
}
return nil
}

View file

@ -30,9 +30,10 @@ func (c *Collector) initZookeeperFetcher() (fetcher, error) {
}
sock := socket.New(socket.Config{
Address: c.Address,
Timeout: c.Timeout.Duration(),
TLSConf: tlsConf,
Address: c.Address,
Timeout: c.Timeout.Duration(),
TLSConf: tlsConf,
MaxReadLines: 2000,
})
return &zookeeperFetcher{Client: sock}, nil

View file

@ -4,25 +4,29 @@ package socket
import (
"bufio"
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"time"
)
// Processor function passed to the Socket.Command function.
// It is passed by the caller to process a command's response line by line.
type Processor func([]byte) bool
// Processor is a callback function passed to the Socket.Command method.
// It processes each response line received from the server.
type Processor func([]byte) (bool, error)
// Client is the interface that wraps the basic socket client operations
// and hides the implementation details from the users.
// Implementations should return TCP, UDP or Unix ready sockets.
// Client defines an interface for socket clients, abstracting the underlying implementation.
// Implementations should provide connections for various socket types such as TCP, UDP, or Unix domain sockets.
type Client interface {
Connect() error
Disconnect() error
Command(command string, process Processor) error
}
// ConnectAndRead establishes a connection using the given configuration,
// executes the provided processor function on the incoming response lines,
// and ensures the connection is properly closed after use.
func ConnectAndRead(cfg Config, process Processor) error {
sock := New(cfg)
@ -35,46 +39,33 @@ func ConnectAndRead(cfg Config, process Processor) error {
return sock.read(process)
}
// New returns a new pointer to a socket client given the socket
// type (IP, TCP, UDP, UNIX), a network address (IP/domain:port),
// a timeout and a TLS config. It supports both IPv4 and IPv6 address
// and reuses connection where possible.
// New creates and returns a new Socket instance configured with the provided settings.
// The socket supports multiple types (TCP, UDP, UNIX), addresses (IPv4, IPv6, domain names),
// and optional TLS encryption. Connections are reused where possible.
func New(cfg Config) *Socket {
return &Socket{Config: cfg}
}
// Socket is the implementation of a socket client.
// Socket is a concrete implementation of the Client interface, managing a network connection
// based on the specified configuration (address, type, timeout, and optional TLS settings).
type Socket struct {
Config
conn net.Conn
}
// Config holds the network ip v4 or v6 address, port,
// Socket type(ip, tcp, udp, unix), timeout and TLS configuration for a Socket
// Config encapsulates the settings required to establish a network connection.
type Config struct {
Address string
Timeout time.Duration
TLSConf *tls.Config
Address string
Timeout time.Duration
TLSConf *tls.Config
MaxReadLines int64
}
// Connect connects to the Socket address on the named network.
// If the address is a domain name it will also perform the DNS resolution.
// Address like :80 will attempt to connect to the localhost.
// The config timeout and TLS config will be used.
// Connect establishes a connection to the specified address using the configuration details.
func (s *Socket) Connect() error {
network, address := networkType(s.Address)
var conn net.Conn
var err error
if s.TLSConf == nil {
conn, err = net.DialTimeout(network, address, s.timeout())
} else {
var d net.Dialer
d.Timeout = s.timeout()
conn, err = tls.DialWithDialer(&d, network, address, s.TLSConf)
}
conn, err := s.dial()
if err != nil {
return err
return fmt.Errorf("socket.Connect: %w", err)
}
s.conn = conn
@ -82,22 +73,19 @@ func (s *Socket) Connect() error {
return nil
}
// Disconnect closes the connection.
// Any in-flight commands will be cancelled and return errors.
func (s *Socket) Disconnect() (err error) {
if s.conn != nil {
err = s.conn.Close()
s.conn = nil
// Disconnect terminates the active connection if one exists.
func (s *Socket) Disconnect() error {
if s.conn == nil {
return nil
}
err := s.conn.Close()
s.conn = nil
return err
}
// Command writes the command string to the connection and passed the
// response bytes line by line to the process function. It uses the
// timeout value from the Socket config and returns read, write and
// timeout errors if any. If a timeout occurs during the processing
// of the responses this function will stop processing and return a
// timeout error.
// Command sends a command string to the connected server and processes its response line by line
// using the provided Processor function. This method respects the timeout configuration
// for write and read operations. If a timeout or processing error occurs, it stops and returns the error.
func (s *Socket) Command(command string, process Processor) error {
if s.conn == nil {
return errors.New("cannot send command on nil connection")
@ -112,10 +100,10 @@ func (s *Socket) Command(command string, process Processor) error {
func (s *Socket) write(command string) error {
if s.conn == nil {
return errors.New("attempt to write on nil connection")
return errors.New("write: nil connection")
}
if err := s.conn.SetWriteDeadline(time.Now().Add(s.timeout())); err != nil {
if err := s.conn.SetWriteDeadline(s.deadline()); err != nil {
return err
}
@ -126,25 +114,53 @@ func (s *Socket) write(command string) error {
func (s *Socket) read(process Processor) error {
if process == nil {
return errors.New("process func is nil")
return errors.New("read: process func is nil")
}
if s.conn == nil {
return errors.New("attempt to read on nil connection")
return errors.New("read: nil connection")
}
if err := s.conn.SetReadDeadline(time.Now().Add(s.timeout())); err != nil {
if err := s.conn.SetReadDeadline(s.deadline()); err != nil {
return err
}
sc := bufio.NewScanner(s.conn)
for sc.Scan() && process(sc.Bytes()) {
var n int64
limit := s.MaxReadLines
for sc.Scan() {
more, err := process(sc.Bytes())
if err != nil {
return err
}
if n++; limit > 0 && n > limit {
return fmt.Errorf("read line limit exceeded (%d", limit)
}
if !more {
break
}
}
return sc.Err()
}
func (s *Socket) dial() (net.Conn, error) {
network, address := parseAddress(s.Address)
var d net.Dialer
d.Timeout = s.timeout()
if s.TLSConf != nil {
return tls.DialWithDialer(&d, network, address, s.TLSConf)
}
return d.DialContext(context.Background(), network, address)
}
func (s *Socket) deadline() time.Time {
return time.Now().Add(s.timeout())
}
func (s *Socket) timeout() time.Duration {
if s.Timeout == 0 {
return time.Second

View file

@ -3,152 +3,86 @@
package socket
import (
"crypto/tls"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
const (
testServerAddress = "127.0.0.1:9999"
testUdpServerAddress = "udp://127.0.0.1:9999"
testUnixServerAddress = "/tmp/testSocketFD"
defaultTimeout = 100 * time.Millisecond
)
func TestSocket_Command(t *testing.T) {
const (
testServerAddress = "tcp://127.0.0.1:9999"
testUdpServerAddress = "udp://127.0.0.1:9999"
testUnixServerAddress = "unix:///tmp/testSocketFD"
defaultTimeout = 1000 * time.Millisecond
)
var tcpConfig = Config{
Address: testServerAddress,
Timeout: defaultTimeout,
TLSConf: nil,
}
var udpConfig = Config{
Address: testUdpServerAddress,
Timeout: defaultTimeout,
TLSConf: nil,
}
var unixConfig = Config{
Address: testUnixServerAddress,
Timeout: defaultTimeout,
TLSConf: nil,
}
var tcpTlsConfig = Config{
Address: testServerAddress,
Timeout: defaultTimeout,
TLSConf: &tls.Config{},
}
func Test_clientCommand(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return true
})
require.NoError(t, sock.Disconnect())
require.NoError(t, err)
}
func Test_clientTimeout(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
sock.Timeout = 0
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return true
})
require.NoError(t, err)
}
func Test_clientIncompleteSSL(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpTlsConfig)
err := sock.Connect()
require.Error(t, err)
}
func Test_clientCommandStopProcessing(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 2}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, sock.Disconnect())
require.NoError(t, err)
}
func Test_clientUDPCommand(t *testing.T) {
srv := &udpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run(); defer func() { _ = srv.Close() }() }()
time.Sleep(time.Millisecond * 100)
sock := New(udpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, sock.Disconnect())
require.NoError(t, err)
}
func Test_clientTCPAddress(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
tcpConfig.Address = "tcp://" + tcpConfig.Address
sock = New(tcpConfig)
require.NoError(t, sock.Connect())
}
func Test_clientUnixCommand(t *testing.T) {
srv := &unixServer{addr: testUnixServerAddress, rowsNumResp: 1}
// cleanup previous file descriptors
_ = srv.Close()
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 200)
sock := New(unixConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", func(bytes []byte) bool {
assert.Equal(t, "pong", string(bytes))
return false
})
require.NoError(t, err)
require.NoError(t, sock.Disconnect())
}
func Test_clientEmptyProcessFunc(t *testing.T) {
srv := &tcpServer{addr: testServerAddress, rowsNumResp: 1}
go func() { _ = srv.Run() }()
time.Sleep(time.Millisecond * 100)
sock := New(tcpConfig)
require.NoError(t, sock.Connect())
err := sock.Command("ping\n", nil)
require.Error(t, err, "nil process func should return an error")
type server interface {
Run() error
Close() error
}
tests := map[string]struct {
srv server
cfg Config
wantConnectErr bool
wantCommandErr bool
}{
"tcp": {
srv: newTCPServer(testServerAddress),
cfg: Config{
Address: testServerAddress,
Timeout: defaultTimeout,
},
},
"udp": {
srv: newUDPServer(testUdpServerAddress),
cfg: Config{
Address: testUdpServerAddress,
Timeout: defaultTimeout,
},
},
"unix": {
srv: newUnixServer(testUnixServerAddress),
cfg: Config{
Address: testUnixServerAddress,
Timeout: defaultTimeout,
},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
go func() {
defer func() { _ = test.srv.Close() }()
require.NoError(t, test.srv.Run())
}()
time.Sleep(time.Millisecond * 500)
sock := New(test.cfg)
err := sock.Connect()
if test.wantConnectErr {
require.Error(t, err)
return
}
require.NoError(t, err)
defer sock.Disconnect()
var resp string
err = sock.Command("ping\n", func(bytes []byte) (bool, error) {
resp = string(bytes)
return false, nil
})
if test.wantCommandErr {
require.Error(t, err)
} else {
require.NoError(t, err)
require.Equal(t, "pong", resp)
}
})
}
}

View file

@ -0,0 +1,257 @@
// SPDX-License-Identifier: GPL-3.0-or-later
package socket
import (
"bufio"
"context"
"errors"
"fmt"
"net"
"os"
"sync"
"time"
)
func newTCPServer(addr string) *tcpServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &tcpServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type tcpServer struct {
addr string
listener net.Listener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func (t *tcpServer) Run() error {
var err error
t.listener, err = net.Listen("tcp", t.addr)
if err != nil {
return fmt.Errorf("failed to start TCP server: %w", err)
}
return t.handleConnections()
}
func (t *tcpServer) Close() (err error) {
t.cancel()
if t.listener != nil {
if err := t.listener.Close(); err != nil {
return fmt.Errorf("failed to close TCP server: %w", err)
}
}
t.wg.Wait()
return nil
}
func (t *tcpServer) handleConnections() (err error) {
for {
select {
case <-t.ctx.Done():
return nil
default:
conn, err := t.listener.Accept()
if err != nil {
if errors.Is(err, net.ErrClosed) {
return nil
}
return fmt.Errorf("could not accept connection: %v", err)
}
t.wg.Add(1)
go func() {
defer t.wg.Done()
t.handleConnection(conn)
}()
}
}
}
func (t *tcpServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
return
}
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
if _, err := rw.ReadString('\n'); err != nil {
writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err))
} else {
writeResponse(rw, "pong\n")
}
}
func newUDPServer(addr string) *udpServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &udpServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type udpServer struct {
addr string
conn *net.UDPConn
ctx context.Context
cancel context.CancelFunc
}
func (u *udpServer) Run() error {
addr, err := net.ResolveUDPAddr("udp", u.addr)
if err != nil {
return fmt.Errorf("failed to resolve UDP address: %w", err)
}
u.conn, err = net.ListenUDP("udp", addr)
if err != nil {
return fmt.Errorf("failed to start UDP server: %w", err)
}
return u.handleConnections()
}
func (u *udpServer) Close() (err error) {
u.cancel()
if u.conn != nil {
if err := u.conn.Close(); err != nil {
return fmt.Errorf("failed to close UDP server: %w", err)
}
}
return nil
}
func (u *udpServer) handleConnections() error {
buffer := make([]byte, 8192)
for {
select {
case <-u.ctx.Done():
return nil
default:
if err := u.conn.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
continue
}
_, addr, err := u.conn.ReadFromUDP(buffer[0:])
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
return fmt.Errorf("failed to read UDP packet: %w", err)
}
continue
}
if _, err := u.conn.WriteToUDP([]byte("pong\n"), addr); err != nil {
return fmt.Errorf("failed to write UDP response: %w", err)
}
}
}
}
func newUnixServer(addr string) *unixServer {
ctx, cancel := context.WithCancel(context.Background())
_, addr = parseAddress(addr)
return &unixServer{
addr: addr,
ctx: ctx,
cancel: cancel,
}
}
type unixServer struct {
addr string
listener *net.UnixListener
wg sync.WaitGroup
ctx context.Context
cancel context.CancelFunc
}
func (u *unixServer) Run() error {
if err := os.Remove(u.addr); err != nil && !os.IsNotExist(err) {
return fmt.Errorf("failed to clean up existing socket: %w", err)
}
addr, err := net.ResolveUnixAddr("unix", u.addr)
if err != nil {
return fmt.Errorf("failed to resolve Unix address: %w", err)
}
u.listener, err = net.ListenUnix("unix", addr)
if err != nil {
return fmt.Errorf("failed to start Unix server: %w", err)
}
return u.handleConnections()
}
func (u *unixServer) Close() error {
u.cancel()
if u.listener != nil {
if err := u.listener.Close(); err != nil {
return fmt.Errorf("failed to close Unix server: %w", err)
}
}
u.wg.Wait()
_ = os.Remove(u.addr)
return nil
}
func (u *unixServer) handleConnections() error {
for {
select {
case <-u.ctx.Done():
return nil
default:
if err := u.listener.SetDeadline(time.Now().Add(time.Second)); err != nil {
continue
}
conn, err := u.listener.AcceptUnix()
if err != nil {
if !errors.Is(err, os.ErrDeadlineExceeded) {
return err
}
continue
}
u.wg.Add(1)
go func() {
defer u.wg.Done()
u.handleConnection(conn)
}()
}
}
}
func (u *unixServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
if err := conn.SetDeadline(time.Now().Add(time.Second)); err != nil {
return
}
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
if _, err := rw.ReadString('\n'); err != nil {
writeResponse(rw, fmt.Sprintf("failed to read input: %v\n", err))
} else {
writeResponse(rw, "pong\n")
}
}
func writeResponse(rw *bufio.ReadWriter, response string) {
_, _ = rw.WriteString(response)
_ = rw.Flush()
}

View file

@ -1,139 +0,0 @@
// SPDX-License-Identifier: GPL-3.0-or-later
package socket
import (
"bufio"
"errors"
"fmt"
"net"
"os"
"strings"
"time"
)
type tcpServer struct {
addr string
server net.Listener
rowsNumResp int
}
func (t *tcpServer) Run() (err error) {
t.server, err = net.Listen("tcp", t.addr)
if err != nil {
return
}
return t.handleConnections()
}
func (t *tcpServer) Close() (err error) {
return t.server.Close()
}
func (t *tcpServer) handleConnections() (err error) {
for {
conn, err := t.server.Accept()
if err != nil || conn == nil {
return errors.New("could not accept connection")
}
t.handleConnection(conn)
}
}
func (t *tcpServer) handleConnection(conn net.Conn) {
defer func() { _ = conn.Close() }()
_ = conn.SetDeadline(time.Now().Add(time.Millisecond * 100))
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
_, err := rw.ReadString('\n')
if err != nil {
_, _ = rw.WriteString("failed to read input")
_ = rw.Flush()
} else {
resp := strings.Repeat("pong\n", t.rowsNumResp)
_, _ = rw.WriteString(resp)
_ = rw.Flush()
}
}
type udpServer struct {
addr string
conn *net.UDPConn
rowsNumResp int
}
func (u *udpServer) Run() (err error) {
addr, err := net.ResolveUDPAddr("udp", u.addr)
if err != nil {
return err
}
u.conn, err = net.ListenUDP("udp", addr)
if err != nil {
return
}
u.handleConnections()
return nil
}
func (u *udpServer) Close() (err error) {
return u.conn.Close()
}
func (u *udpServer) handleConnections() {
for {
var buf [2048]byte
_, addr, _ := u.conn.ReadFromUDP(buf[0:])
resp := strings.Repeat("pong\n", u.rowsNumResp)
_, _ = u.conn.WriteToUDP([]byte(resp), addr)
}
}
type unixServer struct {
addr string
conn *net.UnixListener
rowsNumResp int
}
func (u *unixServer) Run() (err error) {
_, _ = os.CreateTemp("/tmp", "testSocketFD")
addr, err := net.ResolveUnixAddr("unix", u.addr)
if err != nil {
return err
}
u.conn, err = net.ListenUnix("unix", addr)
if err != nil {
return
}
go u.handleConnections()
return nil
}
func (u *unixServer) Close() (err error) {
_ = os.Remove(testUnixServerAddress)
return u.conn.Close()
}
func (u *unixServer) handleConnections() {
var conn net.Conn
var err error
conn, err = u.conn.AcceptUnix()
if err != nil {
panic(fmt.Errorf("could not accept connection: %v", err))
}
u.handleConnection(conn)
}
func (u *unixServer) handleConnection(conn net.Conn) {
_ = conn.SetDeadline(time.Now().Add(time.Second))
rw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
_, err := rw.ReadString('\n')
if err != nil {
_, _ = rw.WriteString("failed to read input")
_ = rw.Flush()
} else {
resp := strings.Repeat("pong\n", u.rowsNumResp)
_, _ = rw.WriteString(resp)
_ = rw.Flush()
}
}

View file

@ -12,7 +12,7 @@ func IsUdpSocket(address string) bool {
return strings.HasPrefix(address, "udp://")
}
func networkType(address string) (string, string) {
func parseAddress(address string) (string, string) {
switch {
case IsUnixSocket(address):
address = strings.TrimPrefix(address, "unix://")