/*
 * libwebsockets/minimal-examples-lowlevel/secure-streams/minimal-secure-streams-smd/multi.c
 *
 * 420 lines, 11 KiB, C
 */

/*
* lws-minimal-secure-streams-smd
*
* Written in 2010-2021 by Andy Green <andy@warmcat.com>
*
* This file is made available under the Creative Commons CC0 1.0
* Universal Public Domain Dedication.
*
*
* This demonstrates a minimal http client using secure streams to access the
* SMD api. This file is only built when LWS_SS_USE_SSPC is defined.
*
* This is an alternative test implementation selected by --multi at runtime,
* it's in its own file to stop muddying up the main test sources. It's only
* available when built with SSPC / produces -client executable.
*
* We fork several times; the original process and each fork hook up to the
* proxy with an SMD SS.  Each fork waits 2s for everyone to have joined, and
* then each fork (NOT the original process) sends a bunch of user messages
* that all the forks should receive, having been distributed by SMD and the
* SS proxy.
*
* Each participant checks it received all the messages expected from everyone,
* then sends a final message indicating success and exits.  The original
* process watches for these to arrive before the timeout; if they all do, it's
* a PASS.
*/
#include <libwebsockets.h>
#include <string.h>
#include <signal.h>
/* overall verdict: only the monitor (original process) ever clears this */
static int bad = 1;
/* once set, the event loop in smd_ss_multi_test() stops */
static int interrupted;

/* number of forks */
#define FORKS 4
/* number of messages each will send, eg, 4 forks 64 message == 256 messages */
#define MSGCOUNT 64

/*
 * Per-stream user object, allocated by the SS api (see .user_alloc in the
 * lws_ss_info_t definitions below).
 *
 * MSGCOUNT must not exceed 64, since each fork's received message indexes
 * are tracked as bits in one uint64_t of seen_mask[].
 */
typedef struct myss {
	struct lws_ss_handle	*ss;
	void			*opaque_data;

	/* ... application specific state ... */

	/* bit t set: saw message "test":t from fork (index + 1) */
	uint64_t		seen_mask[FORKS];
	/*
	 * participants: per-fork rx count (exceeds MSGCOUNT on dupes);
	 * monitor: per-fork "summary already seen" flag
	 */
	int			seen_msgs[FORKS];
	/* tx pacing / startup holdoff timer */
	lws_sorted_usec_list_t	sul;
	/* number of test messages sent so far */
	int			count;
	/* we have received the full set from every fork */
	char			seen_all;
	/* our summary message is still waiting to be sent */
	char			send_seen_all;
	/* the initial holdoff has elapsed, tx may proceed */
	char			starting;
} myss_t;
/* secure streams payload interface */
/*
 * Rx handler for the forked participants (context user 1 .. FORKS).
 *
 * Each user class payload is expected to carry "fork": 1 .. FORKS and
 * "test": 0 .. MSGCOUNT - 1.  The pair is recorded in the per-fork bitmap
 * and counter.  Once every fork's complete set has arrived, with no
 * duplicates, our own summary message is scheduled for sending.
 *
 * Summary ("seen_all") messages, from anyone including ourselves, are
 * ignored here.  Malformed or out-of-range messages, and any message that
 * arrives after the full set was already seen, destroy the stream.
 */
static lws_ss_state_return_t
multi_myss_rx(void *userobj, const uint8_t *buf, size_t len, int flags)
{
	/*
	 * Bitmap a fork's seen_mask[] entry must equal once all MSGCOUNT of its
	 * messages arrived.  Derived from MSGCOUNT instead of hardcoding 64 set
	 * bits, so changing MSGCOUNT keeps the test valid.  MSGCOUNT must be
	 * 1 .. 64 (the width of seen_mask[]); outside that, the shift below is
	 * invalid and the compiler will flag it.
	 */
	const uint64_t full_mask = (~(uint64_t)0) >> (64 - MSGCOUNT);
	myss_t *m = (myss_t *)userobj;
	const char *p;
	int fk, t, n;
	size_t al;

	/* ignore our and other forks announcing their result */
	if (lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &al))
		return LWSSSSRET_OK;

	/*
	 * otherwise once we saw the expected messages, any other messages
	 * coming in this class are wrong
	 */
	if (m->seen_all) {
		lwsl_err("%s: unexpected extra messages\n", __func__);
		return LWSSSSRET_DESTROY_ME;
	}

	p = lws_json_simple_find((const char *)buf, len, "\"fork\":", &al);
	if (!p)
		return LWSSSSRET_DESTROY_ME;
	fk = atoi(p);
	if (fk < 1 || fk > FORKS)
		return LWSSSSRET_DESTROY_ME;

	p = lws_json_simple_find((const char *)buf, len, "\"test\":", &al);
	if (!p)
		return LWSSSSRET_DESTROY_ME;
	t = atoi(p);
	/* range check also keeps the shift below within the 64-bit mask */
	if (t < 0 || t >= MSGCOUNT)
		return LWSSSSRET_DESTROY_ME;

	m->seen_mask[fk - 1] |= (uint64_t)1 << t;
	m->seen_msgs[fk - 1]++; /* keep an eye on dupes */

	/*
	 * Have we seen a full set of messages from everyone?  The count must
	 * match exactly too, so duplicates prevent completion.
	 */
	for (n = 0; n < FORKS; n++)
		if (m->seen_msgs[n] != (int)MSGCOUNT ||
		    m->seen_mask[n] != full_mask)
			return LWSSSSRET_OK;

	/*
	 * Oh... so we have finished collecting messages
	 */
	lwsl_user("%s: test thread %d: %s received all messages\n", __func__,
		  (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
		  lws_ss_tag(m->ss));

	m->seen_all = m->send_seen_all = 1;

	/*
	 * Prepare to inform the original process we saw everything
	 * from everyone OK
	 */
	lws_ss_request_tx(m->ss);

	return LWSSSSRET_OK;
}
/*
 * Tx pacing timer for a participant stream.
 *
 * The first expiry ends the post-connection holdoff.  Later expiries, armed
 * after every sent message, pace out the next one.  Once the summary message
 * has gone out, the stream is destroyed instead.
 */
static void
sul_multi_tx_periodic_cb(lws_sorted_usec_list_t *sul)
{
	myss_t *m = lws_container_of(sul, myss_t, sul);
	int more_to_send;

	/* everything seen and the summary already sent: we are finished */
	if (m->seen_all && !m->send_seen_all) {
		lws_ss_destroy(&m->ss);
		return;
	}

	/* holdoff over, the tx handler may now produce messages */
	m->starting = 1;

	more_to_send = m->send_seen_all || m->count < MSGCOUNT;
	if (more_to_send)
		lws_ss_request_tx(m->ss);
}
/*
 * Tx handler for a participant stream.
 *
 * Produces, one per call, exactly MSGCOUNT user class test messages
 * ("fork": our id, "test": 0 .. MSGCOUNT - 1).  Once rx saw everything, it
 * also produces one summary message ("seen_all": true).  Each message is a
 * 16-byte SMD header (class, then zeroed bytes) followed by the JSON payload.
 * On entry *len is the space available at buf; on return it is the length
 * actually used.  After each message the pacing timer is re-armed for 25ms.
 */
static lws_ss_state_return_t
multi_myss_tx(void *userobj, lws_ss_tx_ordinal_t ord, uint8_t *buf, size_t *len,
	      int *flags)
{
	myss_t *m = (myss_t *)userobj;
	int fk = (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss));
	size_t room;

	/*
	 * We want to send exactly MSGCOUNT user class smd messages
	 */
	if (!m->starting || (m->count == MSGCOUNT && !m->send_seen_all))
		return LWSSSSRET_TX_DONT_SEND;

	/*
	 * We write the header unconditionally below, and the payload must fit
	 * after it.  A buffer that cannot hold both would be overrun, so fail
	 * loudly instead.
	 */
	if (*len <= LWS_SMD_SS_RX_HEADER_LEN) {
		lwsl_err("%s: tx buffer too small (%u)\n", __func__,
			 (unsigned int)*len);
		return LWSSSSRET_DESTROY_ME;
	}

	/* the payload only gets what is left after the header */
	room = *len - LWS_SMD_SS_RX_HEADER_LEN;

	// lwsl_notice("%s: sending SS smd\n", __func__);

	lws_ser_wu64be(buf, (uint64_t)1 << LWSSMDCL_USER_BASE_BITNUM);
	lws_ser_wu64be(buf + 8, 0); /* valgrind notices uninitialized if left */

	if (m->send_seen_all) {
		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, room,
				"{\"class\":\"user\",\"fork\": %d,\"seen_all\":true}",
				fk);
		m->send_seen_all = 0;
		lwsl_info("%s: test thread %d: sent summary message\n",
			  __func__, fk);
	} else
		*len = LWS_SMD_SS_RX_HEADER_LEN + (unsigned int)
			lws_snprintf((char *)buf + LWS_SMD_SS_RX_HEADER_LEN, room,
				"{\"class\":\"user\",\"fork\": %d,\"test\":%d}",
				fk, m->count++);

	*flags = LWSSS_FLAG_SOM | LWSSS_FLAG_EOM;

	lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
			 sul_multi_tx_periodic_cb, 25 * LWS_US_PER_MS);

	return LWSSSSRET_OK;
}
/*
 * State callback for a participant stream.
 *
 *  - CONNECTED: arm the pacing timer with a 2s holdoff before any tx
 *  - DISCONNECTED: log per-peer rx statistics, useful when the test fails
 *  - DESTROYING: cancel the timer and make this process's event loop exit
 *
 * Always returns LWSSSSRET_OK.
 */
static lws_ss_state_return_t
multi_myss_state(void *userobj, void *h_src, lws_ss_constate_t state,
		 lws_ss_tx_ordinal_t ack)
{
	myss_t *m = (myss_t *)userobj;
	int n;

	lwsl_notice("%s: %s: %s (%d), ord 0x%x\n", __func__, lws_ss_tag(m->ss),
		    lws_ss_state_name((int)state), state, (unsigned int)ack);

	switch (state) {
	case LWSSSCS_DESTROYING:
		lws_sul_cancel(&m->sul);
		interrupted = 1;
		return LWSSSSRET_OK;

	case LWSSSCS_CONNECTED:
		lwsl_notice("%s: CONNECTED: test fork %d\n", __func__,
			    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)));
		/*
		 * Because in this test everybody is watching and counting
		 * everybody else's messages from different forks, we have to
		 * hold off starting sending for 2s so all forks can join the
		 * proxy first and not miss anything
		 */
		lws_sul_schedule(lws_ss_get_context(m->ss), 0, &m->sul,
				 sul_multi_tx_periodic_cb, 2 * LWS_US_PER_SEC);
		m->starting = 0;
		return LWSSSSRET_OK;

	case LWSSSCS_DISCONNECTED:
		/* peers are printed 1-based, matching the "fork" ids in messages */
		for (n = 0; n < FORKS; n++)
			lwsl_notice("%s: testfork %d: peer %d: seen_msg = %d, "
				    "seen mask = 0x%llx\n", __func__,
				    (int)(intptr_t)lws_context_user(lws_ss_get_context(m->ss)),
				    n + 1, m->seen_msgs[n],
				    (unsigned long long)m->seen_mask[n]);
		break;

	default:
		break;
	}

	return LWSSSSRET_OK;
}
static const lws_ss_info_t ssi_multi_lws_smd = {
.handle_offset = offsetof(myss_t, ss),
.opaque_user_data_offset = offsetof(myss_t, opaque_data),
.rx = multi_myss_rx,
.tx = multi_myss_tx,
.state = multi_myss_state,
.user_alloc = sizeof(myss_t),
.streamtype = LWS_SMD_STREAMTYPENAME,
.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
};
/*
 * Rx handler for the original, monitoring process (context user 0).
 *
 * Only the forks' "seen_all" summary messages matter here; every other
 * message is ignored.  Each fork 1 .. FORKS must report exactly once.  A
 * duplicate, or a missing or out-of-range "fork" id, destroys the stream.
 * When the last outstanding fork reports, the test is marked as passed and
 * the event loop is told to stop.
 */
static lws_ss_state_return_t
multi_myss_rx_monitor(void *userobj, const uint8_t *buf, size_t len, int flags)
{
	myss_t *m = (myss_t *)userobj;
	const char *val;
	size_t vlen;
	int fork_id, slot;

	/* anything that is not a fork announcing its result is not for us */
	if (!lws_json_simple_find((const char *)buf, len, "\"seen_all\":", &vlen))
		return LWSSSSRET_OK;

	val = lws_json_simple_find((const char *)buf, len, "\"fork\":", &vlen);
	if (!val)
		return LWSSSSRET_DESTROY_ME;

	fork_id = atoi(val);
	if (fork_id < 1 || fork_id > FORKS)
		return LWSSSSRET_DESTROY_ME;

	/* seen_msgs[] is used as a per-fork flag here: a second report is a dupe */
	if (m->seen_msgs[fork_id - 1])
		return LWSSSSRET_DESTROY_ME;
	m->seen_msgs[fork_id - 1] = 1;

	for (slot = 0; slot < FORKS; slot++) {
		if (!m->seen_msgs[slot])
			return LWSSSSRET_OK; /* still waiting on somebody */
	}

	/* every fork reported success: the test has succeeded */
	bad = 0;
	interrupted = 1;

	return LWSSSSRET_OK;
}
static const lws_ss_info_t ssi_multi_lws_smd_monitor = {
.handle_offset = offsetof(myss_t, ss),
.opaque_user_data_offset = offsetof(myss_t, opaque_data),
.rx = multi_myss_rx_monitor,
// .state = multi_myss_state_monitor,
.user_alloc = sizeof(myss_t),
.streamtype = LWS_SMD_STREAMTYPENAME,
.manual_initial_tx_credit = 1 << LWSSMDCL_USER_BASE_BITNUM,
};
/*
 * For comparison, this is a non-SS lws_smd participant.  It is registered as
 * the context's early SMD callback, with opaque pointing at the
 * struct lws_context * of this process.
 *
 * It only reacts to the system state message reporting "OPERATIONAL".  At
 * that point it creates this process's SMD secure stream:
 *
 *  - context user 0 (the original process): ssi_multi_lws_smd_monitor,
 *    which watches for the forks' summaries
 *  - context user 1 .. FORKS (the forks): ssi_multi_lws_smd, which sends
 *    and checks the test messages
 *
 * Both stream infos set the user class bit as their initial tx credit,
 * which is what limits the link to user class messages.
 *
 * Returns -1 if the stream could not be created, else 0.
 */
static int
direct_smd_cb(void *opaque, lws_smd_class_t _class, lws_usec_t timestamp,
	      void *buf, size_t len)
{
	struct lws_context **pctx = (struct lws_context **)opaque;
	const lws_ss_info_t *ssi;
	void *role;

	if (_class != LWSSMDCL_SYSTEM_STATE)
		return 0;

	/* any state other than OPERATIONAL is of no interest */
	if (lws_json_simple_strcmp(buf, len, "\"state\":", "OPERATIONAL"))
		return 0;

	role = lws_context_user(*pctx);

	lwsl_info("%s: starting ss for test fork %d\n", __func__,
		  (int)(intptr_t)role);

	ssi = role ? &ssi_multi_lws_smd		/* forked process send / check */
		   : &ssi_multi_lws_smd_monitor;	/* original monitors */

	if (lws_ss_create(*pctx, 0, ssi, NULL, NULL, NULL, NULL)) {
		lwsl_err("%s: failed to create secure stream\n", __func__);
		return -1;
	}

	return 0;
}
/*
 * Overall test timeout: make the event loop stop.  If the monitor has not
 * cleared bad by now, the run is reported as a FAIL.
 */
static void
sul_timeout_cb(lws_sorted_usec_list_t *sul)
{
	(void)sul; /* nothing to look up, the flag is file-global */

	interrupted = 1;
}
/*
 * Entry point of the --multi variant of the SMD secure streams test.
 *
 * Forks FORKS children, then every process (original and children) creates
 * its own lws context connected to the SS proxy.  The role is passed as the
 * context user pointer:
 *
 *  - 0: the original process, which monitors for the forks' summaries
 *  - 1 .. FORKS: a child, which sends and checks the test messages
 *
 * Command line: -p <port> uses tcp to the proxy instead of UDS, -i sets the
 * socket path / bind interface, -a sets the proxy address (with -p).  The
 * builtin lws options (eg, -d<verb>) are handled as well.
 *
 * Each process runs its event loop until interrupted is set, by stream
 * destruction, by the monitor completing, or by the 10s timeout.
 *
 * Returns 0 on PASS, nonzero otherwise.  Only the original process can
 * PASS: bad is only cleared by the monitor stream, which only the original
 * process creates.
 */
int
smd_ss_multi_test(int argc, const char **argv)
{
	struct lws_context_creation_info info;
	lws_sorted_usec_list_t sul_timeout;
	struct lws_context *context;
	pid_t pid;
	int n;

	lwsl_user("LWS Secure Streams SMD MULTI test client [-d<verb>]\n");

	for (n = 0; n < FORKS; n++) {
		pid = fork();
		if (pid < 0) {
			/*
			 * Without every fork, the monitor can never collect all
			 * the summaries, so fail now rather than waiting for the
			 * timeout.  Any children already forked stop by
			 * themselves when their own test timeout expires.
			 */
			lwsl_err("%s: failed to fork test process %d\n",
				 __func__, n + 1);
			return 1;
		}
		if (!pid) /* forked child */
			break;

		lwsl_notice("%s: forked test process %d\n", __func__, (int)pid);
	}

	if (n == FORKS)
		/* the original process */
		n = -1; /* so original ends up with context.user as 0 below */

	memset(&info, 0, sizeof info);
	memset(&sul_timeout, 0, sizeof sul_timeout);

	lws_cmdline_option_handle_builtin(argc, argv, &info);

	{
		const char *p;

		/* connect to ssproxy via UDS by default, else via
		 * tcp connection to this port */
		if ((p = lws_cmdline_option(argc, argv, "-p")))
			info.ss_proxy_port = (uint16_t)atoi(p);

		/* UDS "proxy.ss.lws" in abstract namespace, else this socket
		 * path; when -p given this can specify the network interface
		 * to bind to */
		if ((p = lws_cmdline_option(argc, argv, "-i")))
			info.ss_proxy_bind = p;

		/* if -p given, -a specifies the proxy address to connect to */
		if ((p = lws_cmdline_option(argc, argv, "-a")))
			info.ss_proxy_address = p;
	}

	info.fd_limit_per_thread	= 1 + 6 + 1;
	info.port			= CONTEXT_PORT_NO_LISTEN;
	info.protocols			= lws_sspc_protocols;
	info.options			= LWS_SERVER_OPTION_EXPLICIT_VHOSTS |
					  LWS_SERVER_OPTION_DO_SSL_GLOBAL_INIT;
	info.early_smd_cb		= direct_smd_cb;
	info.early_smd_class_filter	= 0xffffffff;
	/*
	 * direct_smd_cb() dereferences this to reach the context, so it relies
	 * on `context` being assigned before the OPERATIONAL state message is
	 * delivered.  NOTE(review): presumably that happens only once the
	 * event loop runs; confirm.
	 */
	info.early_smd_opaque		= &context;
	info.user			= (void *)(intptr_t)(n + 1);

	/* create the context */

	context = lws_create_context(&info);
	if (!context) {
		lwsl_err("lws init failed\n");
		return 1;
	}

	if (!lws_create_vhost(context, &info)) {
		lwsl_err("%s: failed to create default vhost\n", __func__);
		goto bail;
	}

	/* set up the test timeout */

	lws_sul_schedule(context, 0, &sul_timeout, sul_timeout_cb,
			 10 * LWS_US_PER_SEC);

	/* the event loop */

	while (lws_service(context, 0) >= 0 && !interrupted)
		;

bail:
	lws_context_destroy(context);

	if (n == -1)
		lwsl_user("%s: finished %s\n", __func__, bad ? "FAIL" : "PASS");

	return bad;
}