// 2022-01-01 23:50:23 +01:00
//
// 351 lines
// 10 KiB
// C++

/*
* IXSocketMbedTLS.cpp
* Author: Benjamin Sergeant, Max Weisel
* Copyright (c) 2019-2020 Machine Zone, Inc. All rights reserved.
*
* Some code taken from
* https://github.com/rottor12/WsClientLib/blob/master/lib/src/WsClientLib.cpp
* and mini_client.c example from mbedtls
*/
#ifdef IXWEBSOCKET_USE_MBED_TLS
#include "IXSocketMbedTLS.h"
#include "IXNetSystem.h"
#include "IXSocket.h"
#include "IXSocketConnect.h"
#include <string.h>
#ifdef _WIN32
// For manipulating the certificate store
#include <wincrypt.h>
#endif
namespace ix
{
/**
 * Build a TLS socket and bring every mbedTLS context into its initial state.
 *
 * @param tlsOptions TLS configuration; copied into the socket.
 * @param fd         File descriptor forwarded to the base Socket constructor.
 */
SocketMbedTLS::SocketMbedTLS(const SocketTLSOptions &tlsOptions, int fd)
    : Socket(fd)
    , _tlsOptions(tlsOptions)
{
    // close() (also run by the destructor) frees these contexts, so they
    // must be initialized up front.
    initMBedTLS();
}
/// Tear down: release all mbedTLS state and the underlying descriptor.
/// The explicitly qualified call names this class's close() directly.
SocketMbedTLS::~SocketMbedTLS()
{
    SocketMbedTLS::close();
}
/**
 * Put every mbedTLS context owned by this socket into its initial state.
 * Runs while holding _mutex.
 */
void SocketMbedTLS::initMBedTLS()
{
    std::lock_guard<std::mutex> guard(_mutex);

    // Randomness: entropy pool and the DRBG seeded from it in init().
    mbedtls_entropy_init(&_entropy);
    mbedtls_ctr_drbg_init(&_ctr_drbg);

    // TLS configuration and the per-connection SSL context.
    mbedtls_ssl_config_init(&_conf);
    mbedtls_ssl_init(&_ssl);

    // CA chain, our own certificate, and our private key.
    mbedtls_x509_crt_init(&_cacert);
    mbedtls_x509_crt_init(&_cert);
    mbedtls_pk_init(&_pkey);
}
/**
 * Append the platform's trusted root certificates to _cacert.
 *
 * On Windows, the current user's "Root" system store is enumerated and every
 * entry flagged as X509_ASN_ENCODING is handed to mbedtls_x509_crt_parse;
 * entries mbedTLS rejects are skipped. Other platforms are not supported yet.
 *
 * @param errorMsg Receives a human-readable reason when false is returned.
 * @return true when at least one certificate was added to _cacert.
 */
bool SocketMbedTLS::loadSystemCertificates(std::string &errorMsg)
{
#ifdef _WIN32
    DWORD flags = CERT_STORE_READONLY_FLAG | CERT_STORE_OPEN_EXISTING_FLAG |
                  CERT_SYSTEM_STORE_CURRENT_USER;
    HCERTSTORE systemStore =
        CertOpenStore(CERT_STORE_PROV_SYSTEM, 0, 0, flags, L"Root");
    if (!systemStore) {
        errorMsg = "CertOpenStore failed with ";
        errorMsg += std::to_string(GetLastError());
        return false;
    }

    PCCERT_CONTEXT certificateIterator = NULL;
    int certificateCount = 0;
    // Enumerate until the store reports no further certificate (NULL).
    while ((certificateIterator = CertEnumCertificatesInStore(
                systemStore, certificateIterator)) != NULL) {
        if (certificateIterator->dwCertEncodingType & X509_ASN_ENCODING) {
            int ret =
                mbedtls_x509_crt_parse(&_cacert,
                                       certificateIterator->pbCertEncoded,
                                       certificateIterator->cbCertEncoded);
            if (ret == 0) {
                ++certificateCount;
            }
        }
    }

    CertFreeCertificateContext(certificateIterator);
    CertCloseStore(systemStore, 0);

    if (certificateCount == 0) {
        errorMsg = "No certificates found";
        return false;
    }
    return true;
#else
    // On macOS we can query the system cert location from the keychain
    // On Linux we could try to fetch some local files based on the distribution
    // On Android we could use JNI to get to the system certs
    //
    // init() forwards this failure straight to its caller, so say why instead
    // of leaving errorMsg empty.
    errorMsg = "Loading system certificates is not supported on this platform";
    return false;
#endif
}
/**
 * Prepare the mbedTLS SSL context for a handshake.
 *
 * Steps, in order:
 *   1. seed the CTR-DRBG from the entropy pool;
 *   2. apply stream-transport defaults for the client or server role;
 *   3. load our own certificate/key pair when configured;
 *   4. configure peer verification and the trusted CA chain;
 *   5. bind the configuration to _ssl and set the host name.
 *
 * @param host     Passed to mbedtls_ssl_set_hostname; skipped when empty.
 * @param isClient Selects MBEDTLS_SSL_IS_CLIENT vs MBEDTLS_SSL_IS_SERVER.
 * @param errMsg   Receives a description of the first step that failed.
 * @return true on success; false as soon as any step fails.
 */
bool SocketMbedTLS::init(const std::string &host,
                         bool isClient,
                         std::string &errMsg)
{
    // initMBedTLS() takes _mutex itself, so it runs before the guard below.
    initMBedTLS();
    std::lock_guard<std::mutex> guard(_mutex);

    // 1. Seed the DRBG; the label is mixed into the seed material.
    const char *seedLabel = "IXSocketMbedTLS";
    const int seedStatus = mbedtls_ctr_drbg_seed(
        &_ctr_drbg,
        mbedtls_entropy_func,
        &_entropy,
        reinterpret_cast<const unsigned char *>(seedLabel),
        strlen(seedLabel));
    if (seedStatus != 0) {
        errMsg = "Setting entropy seed failed";
        return false;
    }

    // 2. Role-specific defaults over a stream (TCP) transport.
    const int endpoint =
        isClient ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER;
    const int defaultsStatus =
        mbedtls_ssl_config_defaults(&_conf,
                                    endpoint,
                                    MBEDTLS_SSL_TRANSPORT_STREAM,
                                    MBEDTLS_SSL_PRESET_DEFAULT);
    if (defaultsStatus != 0) {
        errMsg = "Setting config default failed";
        return false;
    }
    mbedtls_ssl_conf_rng(&_conf, mbedtls_ctr_drbg_random, &_ctr_drbg);

    // 3. Optional local identity.
    if (_tlsOptions.hasCertAndKey()) {
        if (mbedtls_x509_crt_parse_file(&_cert,
                                        _tlsOptions.certFile.c_str()) < 0) {
            errMsg = "Cannot parse cert file '" + _tlsOptions.certFile + "'";
            return false;
        }

#ifdef IXWEBSOCKET_USE_MBED_TLS_MIN_VERSION_3
        // The 3.x signature additionally takes an RNG callback.
        const int keyStatus =
            mbedtls_pk_parse_keyfile(&_pkey,
                                     _tlsOptions.keyFile.c_str(),
                                     "",
                                     mbedtls_ctr_drbg_random,
                                     &_ctr_drbg);
#else
        const int keyStatus =
            mbedtls_pk_parse_keyfile(&_pkey, _tlsOptions.keyFile.c_str(), "");
#endif
        if (keyStatus < 0) {
            errMsg = "Cannot parse key file '" + _tlsOptions.keyFile + "'";
            return false;
        }

        if (mbedtls_ssl_conf_own_cert(&_conf, &_cert, &_pkey) < 0) {
            errMsg = "Problem configuring cert '" + _tlsOptions.certFile + "'";
            return false;
        }
    }

    // 4. Peer verification and trust anchors.
    if (_tlsOptions.isPeerVerifyDisabled()) {
        mbedtls_ssl_conf_authmode(&_conf, MBEDTLS_SSL_VERIFY_NONE);
    } else {
        // FIXME: should we call mbedtls_ssl_conf_verify ?
        mbedtls_ssl_conf_authmode(&_conf, MBEDTLS_SSL_VERIFY_REQUIRED);

        if (_tlsOptions.isUsingSystemDefaults()) {
            if (!loadSystemCertificates(errMsg)) {
                return false;
            }
        } else if (_tlsOptions.isUsingInMemoryCAs()) {
            // Here caFile carries the CA data itself rather than a path
            // (presumably PEM text). The length must include the terminating
            // NUL character, otherwise mbedtls fails to parse it.
            const unsigned char *caBytes =
                reinterpret_cast<const unsigned char *>(
                    _tlsOptions.caFile.c_str());
            const size_t caLength = _tlsOptions.caFile.size() + 1;
            if (mbedtls_x509_crt_parse(&_cacert, caBytes, caLength) < 0) {
                errMsg = "Cannot parse CA from memory.";
                return false;
            }
        } else if (mbedtls_x509_crt_parse_file(&_cacert,
                                               _tlsOptions.caFile.c_str()) <
                   0) {
            errMsg = "Cannot parse CA file '" + _tlsOptions.caFile + "'";
            return false;
        }

        mbedtls_ssl_conf_ca_chain(&_conf, &_cacert, NULL);
    }

    // 5. Attach the configuration to the SSL context, then the host name.
    if (mbedtls_ssl_setup(&_ssl, &_conf) != 0) {
        errMsg = "SSL setup failed";
        return false;
    }
    if (!host.empty() && mbedtls_ssl_set_hostname(&_ssl, host.c_str()) != 0) {
        errMsg = "SNI setup failed";
        return false;
    }
    return true;
}
/**
 * Server side: configure TLS on this socket's descriptor and run the
 * handshake.
 *
 * The handshake is retried while mbedTLS reports WANT_READ or WANT_WRITE.
 * On any failure the socket is closed and errMsg is filled. For
 * MBEDTLS_ERR_X509_CERT_VERIFY_FAILED, the verification flags are rendered
 * with mbedtls_x509_crt_verify_info and appended.
 *
 * @param errMsg Receives the failure description.
 * @return true once the handshake has completed.
 */
bool SocketMbedTLS::accept(std::string &errMsg)
{
    // Servers have no host name to announce.
    if (!init(std::string(), /*isClient=*/false, errMsg)) {
        close();
        return false;
    }

    // NOTE(review): &_sockfd is used as the mbedtls_net_send/recv context;
    // this relies on mbedtls_net_context holding just the int fd.
    mbedtls_ssl_set_bio(&_ssl,
                        &_sockfd,
                        mbedtls_net_send,
                        mbedtls_net_recv,
                        NULL);

    int status = 0;
    for (;;) {
        {
            std::lock_guard<std::mutex> guard(_mutex);
            status = mbedtls_ssl_handshake(&_ssl);
        }
        const bool retry = status == MBEDTLS_ERR_SSL_WANT_READ ||
                           status == MBEDTLS_ERR_SSL_WANT_WRITE;
        if (!retry) {
            break;
        }
    }

    if (status == 0) {
        return true;
    }

    char reason[256];
    mbedtls_strerror(status, reason, sizeof(reason));
    errMsg = "error in handshake : ";
    errMsg += reason;

    if (status == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) {
        char details[512];
        const uint32_t verifyFlags = mbedtls_ssl_get_verify_result(&_ssl);
        mbedtls_x509_crt_verify_info(details,
                                     sizeof(details),
                                     " ! ",
                                     verifyFlags);
        errMsg += " : ";
        errMsg += details;
    }

    close();
    return false;
}
bool SocketMbedTLS::connect(const std::string &host,
int port,
std::string &errMsg,
const CancellationRequest &isCancellationRequested)
{
{
std::lock_guard<std::mutex> lock(_mutex);
_sockfd =
SocketConnect::connect(host, port, errMsg, isCancellationRequested);
if (_sockfd == -1)
return false;
}
bool isClient = true;
bool initialized = init(host, isClient, errMsg);
if (!initialized) {
close();
return false;
}
mbedtls_ssl_set_bio(&_ssl,
&_sockfd,
mbedtls_net_send,
mbedtls_net_recv,
NULL);
int res;
do {
{
std::lock_guard<std::mutex> lock(_mutex);
res = mbedtls_ssl_handshake(&_ssl);
}
if (isCancellationRequested()) {
errMsg = "Cancellation requested";
close();
return false;
}
} while (res == MBEDTLS_ERR_SSL_WANT_READ ||
res == MBEDTLS_ERR_SSL_WANT_WRITE);
if (res != 0) {
char buf[256];
mbedtls_strerror(res, buf, sizeof(buf));
errMsg = "error in handshake : ";
errMsg += buf;
close();
return false;
}
return true;
}
void SocketMbedTLS::close()
{
std::lock_guard<std::mutex> lock(_mutex);
mbedtls_ssl_free(&_ssl);
mbedtls_ssl_config_free(&_conf);
mbedtls_ctr_drbg_free(&_ctr_drbg);
mbedtls_entropy_free(&_entropy);
mbedtls_x509_crt_free(&_cacert);
mbedtls_x509_crt_free(&_cert);
Socket::close();
}
/**
 * Encrypt and send up to nbyte bytes from buf through the TLS session.
 *
 * @return the positive byte count reported by mbedtls_ssl_write, or -1 for
 *         any non-positive result. When mbedTLS reports WANT_READ or
 *         WANT_WRITE, errno is set to EWOULDBLOCK before returning -1.
 */
ssize_t SocketMbedTLS::send(char *buf, size_t nbyte)
{
    std::lock_guard<std::mutex> guard(_mutex);

    const ssize_t written =
        mbedtls_ssl_write(&_ssl, reinterpret_cast<unsigned char *>(buf), nbyte);
    if (written > 0) {
        return written;
    }

    const bool wouldBlock = written == MBEDTLS_ERR_SSL_WANT_READ ||
                            written == MBEDTLS_ERR_SSL_WANT_WRITE;
    if (wouldBlock) {
        errno = EWOULDBLOCK;
    }
    return -1;
}
ssize_t SocketMbedTLS::recv(void *buf, size_t nbyte)
{
while (true) {
std::lock_guard<std::mutex> lock(_mutex);
ssize_t res = mbedtls_ssl_read(&_ssl, (unsigned char *)buf, (int)nbyte);
if (res > 0) {
return res;
}
if (res == MBEDTLS_ERR_SSL_WANT_READ ||
res == MBEDTLS_ERR_SSL_WANT_WRITE) {
errno = EWOULDBLOCK;
}
return -1;
}
}
} // namespace ix
#endif // IXWEBSOCKET_USE_MBED_TLS