libwebsockets/lib/tls/tls-jit-trust.c
2024-11-03 07:59:12 +00:00

690 lines
18 KiB
C

/*
* libwebsockets - small server side websockets and web server implementation
*
* Copyright (C) 2010 - 2021 Andy Green <andy@warmcat.com>
*
* Permission is hereby granted, free of charge, to any person obtaining a copy
* of this software and associated documentation files (the "Software"), to
* deal in the Software without restriction, including without limitation the
* rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
* sell copies of the Software, and to permit persons to whom the Software is
* furnished to do so, subject to the following conditions:
*
* The above copyright notice and this permission notice shall be included in
* all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
* AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
* LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
* IN THE SOFTWARE.
*/
#include "private-lib-core.h"
void
lws_tls_kid_copy(union lws_tls_cert_info_results *ci, lws_tls_kid_t *kid)
{
/*
* KIDs all seem to be 20 bytes / SHA1 or less. If we get one that
* is bigger, treat only the first 20 bytes as significant.
*/
if ((size_t)ci->ns.len > sizeof(kid->kid))
kid->kid_len = sizeof(kid->kid);
else
kid->kid_len = (uint8_t)ci->ns.len;
memcpy(kid->kid, ci->ns.name, kid->kid_len);
}
void
lws_tls_kid_copy_kid(lws_tls_kid_t *kid, const lws_tls_kid_t *src)
{
int klen = sizeof(kid->kid);
if (src->kid_len < klen)
klen = src->kid_len;
kid->kid_len = (uint8_t)klen;
memcpy(kid->kid, src->kid, (size_t)klen);
}
int
lws_tls_kid_cmp(const lws_tls_kid_t *a, const lws_tls_kid_t *b)
{
if (a->kid_len != b->kid_len)
return 1;
return memcmp(a->kid, b->kid, a->kid_len);
}
/*
* We have the SKID and AKID for every peer cert captured, but they may be
* in any order, and eg, falsely have sent the root CA, or an attacker may
* send unresolveable self-referencing loops of KIDs.
*
* Let's sort them into the SKID -> AKID hierarchy, so the last entry is the
* server cert and the first entry is the highest parent that the server sent.
* Normally the top one will be an intermediate, and its AKID is the ID of the
* root CA cert we would need to trust to validate the chain.
*
* It's not unknown the server is misconfigured to also send the root CA, if so
* the top slot's AKID is empty and we should look for its SKID in the trust
* blob.
*
* If we return 0, we succeeded and the AKID of ch[0] is the SKID we want to see
* try to import from the trust blob.
*
* If we return nonzero, we can't identify what we want and should abandon the
* connection.
*/
int
lws_tls_jit_trust_sort_kids(struct lws *wsi, lws_tls_kid_chain_t *ch)
{
size_t hl;
lws_tls_jit_inflight_t *inf;
int n, m, sanity = 10;
const char *host = wsi->cli_hostname_copy;
char more = 1;
lwsl_info("%s\n", __func__);
if (!host) {
if (wsi->stash && wsi->stash->cis[CIS_HOST])
host = wsi->stash->cis[CIS_HOST];
#if defined(LWS_ROLE_H1) || defined(LWS_ROLE_H2)
else
host = lws_hdr_simple_ptr(wsi,
_WSI_TOKEN_CLIENT_PEER_ADDRESS);
}
#endif
if (!host)
return 1;
hl = strlen(host);
/* something to work with? */
if (!ch->count)
return 1;
/* do we need to sort? */
if (ch->count > 1) {
/* okie... */
while (more) {
if (!sanity--)
/* let's not get fooled into spinning */
return 1;
more = 0;
for (n = 0; n < ch->count - 1; n++) {
if (!lws_tls_kid_cmp(&ch->skid[n],
&ch->akid[n + 1]))
/* next belongs with this one */
continue;
/*
* next doesn't belong with this one, let's
* try to figure out where this one does belong
* then
*/
for (m = 0; m < ch->count; m++) {
if (n == m)
continue;
if (!lws_tls_kid_cmp(&ch->skid[n],
&ch->akid[m])) {
lws_tls_kid_t t;
/*
* m references us, so we
* need to go one step above m,
* swap m and n
*/
more = 1;
t = ch->akid[m];
ch->akid[m] = ch->akid[n];
ch->akid[n] = t;
t = ch->skid[m];
ch->skid[m] = ch->skid[n];
ch->skid[n] = t;
break;
}
}
if (more)
break;
}
}
/* then we should be sorted */
}
for (n = 0; n < ch->count; n++) {
lwsl_info("%s: AKID[%d]\n", __func__, n);
lwsl_hexdump_info(ch->akid[n].kid, ch->akid[n].kid_len);
lwsl_info("%s: SKID[%d]\n", __func__, n);
lwsl_hexdump_info(ch->skid[n].kid, ch->skid[n].kid_len);
}
/* to go further, user must provide a lookup helper */
if (!wsi->a.context->system_ops ||
!wsi->a.context->system_ops->jit_trust_query)
return 1;
/*
* If there's already a pending lookup for this host, let's bail and
* just wait for that to complete (since it will be done async if we
* can see it)
*/
lws_start_foreach_dll(struct lws_dll2 *, d,
wsi->a.context->jit_inflight.head) {
inf = lws_container_of(d, lws_tls_jit_inflight_t, list);
if (!strcmp((const char *)&inf[1], host))
/* already being handled */
return 1;
} lws_end_foreach_dll(d);
/*
* No... let's make an inflight entry for this host, then
*/
inf = lws_zalloc(sizeof(*inf) + hl + 1, __func__);
if (!inf)
return 1;
memcpy(&inf[1], host, hl + 1);
inf->refcount = (char)ch->count;
lws_dll2_add_tail(&inf->list, &wsi->a.context->jit_inflight);
/*
* ...kid_chain[0] AKID should indicate the right CA SKID that we want.
*
* Because of cross-signing, we check all of them and accept we may get
* multiple (the inflight accepts up to 2) CAs needed.
*/
for (n = 0; n < ch->count; n++)
wsi->a.context->system_ops->jit_trust_query(wsi->a.context,
ch->akid[n].kid, (size_t)ch->akid[n].kid_len,
(void *)inf);
return 0;
}
static void
tag_to_vh_name(char *result, size_t max, uint32_t tag)
{
lws_snprintf(result, max, "jitt-%08X", (unsigned int)tag);
}
int
lws_tls_jit_trust_vhost_bind(struct lws_context *cx, const char *address,
struct lws_vhost **pvh)
{
lws_tls_jit_cache_item_t *ci, jci;
lws_tls_jit_inflight_t *inf;
char vhtag[32];
size_t size;
int n;
if (lws_cache_item_get(cx->trust_cache, address, (const void **)&ci,
&size))
/*
* There's no cached info, we have to start from scratch on
* this one
*/
return 1;
/* gotten cache item may be evicted by jit_trust_query */
jci = *ci;
/*
* We have some trust cache information for this host already, it tells
* us the trusted CA SKIDs we found before, and the xor tag used to name
* the vhost configured for these trust CAs in its SSL_CTX.
*
* Let's check first if the correct prepared vhost already exists, if
* so, we can just bind to that and go.
*/
tag_to_vh_name(vhtag, sizeof(vhtag), jci.xor_tag);
*pvh = lws_get_vhost_by_name(cx, vhtag);
if (*pvh) {
lwsl_info("%s: %s -> existing %s\n", __func__, address, vhtag);
/* hit, let's just use that then */
return 0;
}
/*
* ... so, we know the SKIDs of the missing CAs, but we don't have the
* DERs for them, and so no configured vhost trusting them yet. We have
* had the DERs at some point, but we can't afford to cache them, so
* we will have to get them again.
*
* Let's make an inflight for this, it will create the vhost when it
* completes. If syncrhronous, then it will complete before we leave
* here, otherwise it will have a life of its own until all the
* queries use the cb to succeed or fail.
*/
size = strlen(address);
inf = lws_zalloc(sizeof(*inf) + size + 1, __func__);
if (!inf)
return 1;
memcpy(&inf[1], address, size + 1);
inf->refcount = (char)jci.count_skids;
lws_dll2_add_tail(&inf->list, &cx->jit_inflight);
/*
* ...kid_chain[0] AKID should indicate the right CA SKID that we want.
*
* Because of cross-signing, we check all of them and accept we may get
* multiple (we can handle 3) CAs needed.
*/
for (n = 0; n < jci.count_skids; n++)
cx->system_ops->jit_trust_query(cx, jci.skids[n].kid,
(size_t)jci.skids[n].kid_len,
(void *)inf);
/* ... in case synchronous and it already finished the queries */
*pvh = lws_get_vhost_by_name(cx, vhtag);
if (*pvh) {
/* hit, let's just use that then */
lwsl_info("%s: bind to created vhost %s\n", __func__, vhtag);
return 0;
} else
lwsl_err("%s: unable to bind to %s\n", __func__, vhtag);
/* right now, nothing to offer */
return 1;
}
void
lws_tls_jit_trust_inflight_destroy(lws_tls_jit_inflight_t *inf)
{
int n;
for (n = 0; n < inf->ders; n++)
lws_free_set_NULL(inf->der[n]);
lws_dll2_remove(&inf->list);
lws_free(inf);
}
static int
inflight_destroy(struct lws_dll2 *d, void *user)
{
lws_tls_jit_inflight_t *inf;
inf = lws_container_of(d, lws_tls_jit_inflight_t, list);
lws_tls_jit_trust_inflight_destroy(inf);
return 0;
}
void
lws_tls_jit_trust_inflight_destroy_all(struct lws_context *cx)
{
lws_dll2_foreach_safe(&cx->jit_inflight, cx, inflight_destroy);
}
static void
unref_vh_grace_cb(lws_sorted_usec_list_t *sul)
{
struct lws_vhost *vh = lws_container_of(sul, struct lws_vhost,
sul_unref);
lwsl_info("%s: %s\n", __func__, vh->lc.gutag);
lws_vhost_destroy(vh);
}
void
lws_tls_jit_trust_vh_start_grace(struct lws_vhost *vh)
{
lwsl_info("%s: %s: unused, grace %dms\n", __func__, vh->lc.gutag,
vh->context->vh_idle_grace_ms);
lws_sul_schedule(vh->context, 0, &vh->sul_unref, unref_vh_grace_cb,
(lws_usec_t)vh->context->vh_idle_grace_ms *
LWS_US_PER_MS);
}
#if defined(_DEBUG)
static void
lws_tls_jit_trust_cert_info(const uint8_t *der, size_t der_len)
{
struct lws_x509_cert *x;
union lws_tls_cert_info_results *u;
char p = 0, buf[192 + sizeof(*u)];
if (lws_x509_create(&x))
return;
if (!lws_x509_parse_from_pem(x, der, der_len)) {
u = (union lws_tls_cert_info_results *)buf;
if (!lws_x509_info(x, LWS_TLS_CERT_INFO_ISSUER_NAME, u, 192)) {
lwsl_info("ISS: %s\n", u->ns.name);
p = 1;
}
if (!lws_x509_info(x, LWS_TLS_CERT_INFO_COMMON_NAME, u, 192)) {
lwsl_info("CN: %s\n", u->ns.name);
p = 1;
}
if (!p) {
lwsl_err("%s: unable to get any info\n", __func__);
lwsl_hexdump_err(der, der_len);
}
} else
lwsl_err("%s: unable to load DER\n", __func__);
lws_x509_destroy(&x);
}
#endif
/*
* This processes the JIT Trust lookup results independent of the tls backend.
*/
int
lws_tls_jit_trust_got_cert_cb(struct lws_context *cx, void *got_opaque,
const uint8_t *skid, size_t skid_len,
const uint8_t *der, size_t der_len)
{
lws_tls_jit_inflight_t *inf = (lws_tls_jit_inflight_t *)got_opaque;
struct lws_context_creation_info info;
lws_tls_jit_cache_item_t jci;
struct lws_vhost *v;
char vhtag[20];
char hit = 0;
int n;
/*
* Before anything else, check the inf is still valid. In the low
* probability but possible case it was reallocated to be a different
* inflight, that may cause different CA certs to apply to a connection,
* but since mbedtls will then validate the server cert using the wrong
* trusted CA, it will just cause temporary conn fail.
*/
lws_start_foreach_dll(struct lws_dll2 *, e, cx->jit_inflight.head) {
lws_tls_jit_inflight_t *i = lws_container_of(e,
lws_tls_jit_inflight_t, list);
if (i == inf) {
hit = 1;
break;
}
} lws_end_foreach_dll(e);
if (!hit)
/* inf has already gone */
return 1;
inf->refcount--;
if (skid_len >= 4)
inf->tag ^= *((uint32_t *)skid);
if (der && inf->ders < (int)LWS_ARRAY_SIZE(inf->der) && inf->refcount) {
/*
* We have a trusted CA, but more results coming... stash it
* in heap.
*/
inf->kid[inf->ders].kid_len = (uint8_t)((skid_len >
(uint8_t)sizeof(inf->kid[inf->ders].kid)) ?
sizeof(inf->kid[inf->ders].kid) : skid_len);
memcpy(inf->kid[inf->ders].kid, skid,
inf->kid[inf->ders].kid_len);
inf->der[inf->ders] = lws_malloc(der_len, __func__);
if (!inf->der[inf->ders])
return 1;
memcpy(inf->der[inf->ders], der, der_len);
inf->der_len[inf->ders] = (short)der_len;
inf->ders++;
return 0;
}
/*
* We accept up to three valid CA, and then end the inflight early.
* Any further pending results are dropped, since we got all we could
* use. Up to two valid CA would be held in the inflight and the other
* provided in the params.
*
* If we did not already fill up the inflight, keep waiting for any
* others expected
*/
if (inf->refcount && inf->ders < (int)LWS_ARRAY_SIZE(inf->der))
return 0;
if (!der && !inf->ders) {
lwsl_warn("%s: no trusted CA certs matching\n", __func__);
goto destroy_inf;
}
tag_to_vh_name(vhtag, sizeof(vhtag), inf->tag);
/*
* We have got at least one CA, it's all the CAs we're going to get,
* or that we can handle. So we have to process and drop the inf.
*
* First let's make a cache entry with a shortish ttl, mapping the
* hostname we were trying to connect to, to the SKIDs that actually
* had trust results. This may come in handy later when we want to
* connect to the same host again, but any vhost from before has been
* removed... we can just ask for the specific CAs to regenerate the
* vhost, without having to first fail the connection attempt to get the
* server cert.
*
* The cache entry can be evicted at any time, so it is selfcontained.
* If it's also lost, we start over with the initial failing connection
* to figure out what we need to make it work.
*/
memset(&jci, 0, sizeof(jci));
jci.xor_tag = inf->tag;
/* copy the SKIDs from the inflight and params into the cache item */
for (n = 0; n < (int)LWS_ARRAY_SIZE(inf->der); n++)
if (inf->kid[n].kid_len)
lws_tls_kid_copy_kid(&jci.skids[jci.count_skids++],
&inf->kid[n]);
if (skid_len) {
if (skid_len > sizeof(inf->kid[0].kid))
skid_len = sizeof(inf->kid[0].kid);
jci.skids[jci.count_skids].kid_len = (uint8_t)skid_len;
memcpy(jci.skids[jci.count_skids++].kid, skid, skid_len);
}
lwsl_info("%s: adding cache mapping %s -> %s\n", __func__,
(const char *)&inf[1], vhtag);
if (lws_cache_write_through(cx->trust_cache, (const char *)&inf[1],
(const uint8_t *)&jci, sizeof(jci),
lws_now_usecs() + (3600ll *LWS_US_PER_SEC),
NULL))
lwsl_warn("%s: add to cache failed\n", __func__);
/* is there already a vhost for this commutative-xor SKID trust? */
if (lws_get_vhost_by_name(cx, vhtag)) {
lwsl_info("%s: tag vhost %s already exists, skipping\n",
__func__, vhtag);
goto destroy_inf;
}
/*
* We only end up here when we attempted a connection to this hostname.
*
* We have the identified CA trust DER(s) to hand, let's create the
* necessary vhost + prepared SSL_CTX for it to use on the retry, it
* will be used straight away if the retry comes before the idle vhost
* timeout.
*
* We also use this path in the case we have the cache entry but no
* matching vhost already existing, to create one.
*/
memset(&info, 0, sizeof(info));
info.vhost_name = vhtag;
info.port = CONTEXT_PORT_NO_LISTEN;
info.options = cx->options;
/*
* We have to create the vhost with the first valid trusted DER...
* if we have a params one, use that so the rest are all from inflight
*/
if (der) {
info.client_ssl_ca_mem = der;
info.client_ssl_ca_mem_len = (unsigned int)der_len;
n = 0;
} else {
info.client_ssl_ca_mem = inf->der[0];
info.client_ssl_ca_mem_len = (unsigned int)inf->der_len[0];
n = 1;
}
#if defined(_DEBUG)
lws_tls_jit_trust_cert_info(info.client_ssl_ca_mem,
info.client_ssl_ca_mem_len);
#endif
info.protocols = cx->protocols_copy;
v = lws_create_vhost(cx, &info);
if (!v)
lwsl_err("%s: failed to create vh %s\n", __func__, vhtag);
v->grace_after_unref = 1;
lws_tls_jit_trust_vh_start_grace(v);
/*
* Do we need to add more trusted certs from inflight?
*/
while (n < inf->ders) {
#if defined(_DEBUG)
lws_tls_jit_trust_cert_info(inf->der[n],
(size_t)inf->der_len[n]);
#endif
if (lws_tls_client_vhost_extra_cert_mem(v, inf->der[n],
(size_t)inf->der_len[n]))
lwsl_err("%s: add extra cert failed\n", __func__);
n++;
}
lwsl_info("%s: created jitt %s -> vh %s\n", __func__,
(const char *)&inf[1], vhtag);
destroy_inf:
lws_tls_jit_trust_inflight_destroy(inf);
return 0;
}
/*
* Refer to ./READMEs/README.jit-trust.md for blob layout specification
*/
int
lws_tls_jit_trust_blob_queury_skid(const void *_blob, size_t blen,
const uint8_t *skid, size_t skid_len,
const uint8_t **prpder, size_t *prder_len)
{
const uint8_t *pskidlen, *pskids, *pder, *blob = (uint8_t *)_blob;
const uint16_t *pderlen;
int certs;
/* sanity check blob length and magic */
if (blen < 32768 ||
lws_ser_ru32be(blob) != LWS_JIT_TRUST_MAGIC_BE ||
lws_ser_ru32be(blob + LJT_OFS_END) != blen) {
lwsl_err("%s: blob not sane\n", __func__);
return -1;
}
if (!skid_len)
return 1;
/* point into the various sub-tables */
certs = (int)lws_ser_ru16be(blob + LJT_OFS_32_COUNT_CERTS);
pderlen = (uint16_t *)(blob + lws_ser_ru32be(blob +
LJT_OFS_32_DERLEN));
pskidlen = blob + lws_ser_ru32be(blob + LJT_OFS_32_SKIDLEN);
pskids = blob + lws_ser_ru32be(blob + LJT_OFS_32_SKID);
pder = blob + LJT_OFS_DER;
/* check each cert SKID in turn, return the DER if found */
while (certs--) {
/* paranoia / sanity */
assert(pskids < blob + blen);
assert(pder < blob + blen);
assert(pskidlen < blob + blen);
assert((uint8_t *)pderlen < blob + blen);
/* we will accept to match on truncated SKIDs */
if (*pskidlen >= skid_len &&
!memcmp(skid, pskids, skid_len)) {
/*
* We found a trusted CA cert of the right SKID
*/
*prpder = pder;
*prder_len = lws_ser_ru16be((uint8_t *)pderlen);
return 0;
}
pder += lws_ser_ru16be((uint8_t *)pderlen);
pskids += *pskidlen;
pderlen++;
pskidlen++;
}
return 1;
}