diff options
author | Claudius Holeksa <mail@keldu.de> | 2023-04-29 19:06:53 +0200 |
---|---|---|
committer | Claudius Holeksa <mail@keldu.de> | 2023-04-29 19:06:53 +0200 |
commit | c742bc3f57cb00d84e2df034f757d4a39e3ade7e (patch) | |
tree | ab42b2b86f49c09c0124e5b65c226a0588c2ca05 /forstio/io-tls/tls.cpp | |
parent | f07487ce8f0f3ebd5c4d1082a9521f09588fa34a (diff) |
Added io tls with gnutls
Diffstat (limited to 'forstio/io-tls/tls.cpp')
-rw-r--r-- | forstio/io-tls/tls.cpp | 250 |
1 files changed, 250 insertions, 0 deletions
diff --git a/forstio/io-tls/tls.cpp b/forstio/io-tls/tls.cpp new file mode 100644 index 0000000..c1497bc --- /dev/null +++ b/forstio/io-tls/tls.cpp @@ -0,0 +1,250 @@ +#include "tls.h" + +#include <gnutls/gnutls.h> +#include <gnutls/x509.h> + +#include <forstio/io/io_helpers.h> + +#include <cassert> + +#include <iostream> + +namespace saw { + +class Tls::Impl { +public: + gnutls_certificate_credentials_t xcred; + +public: + Impl() { + gnutls_global_init(); + gnutls_certificate_allocate_credentials(&xcred); + gnutls_certificate_set_x509_system_trust(xcred); + } + + ~Impl() { + gnutls_certificate_free_credentials(xcred); + gnutls_global_deinit(); + } +}; + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size); +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size); + +Tls::Tls() : impl{heap<Tls::Impl>()} {} + +Tls::~Tls() {} + +Tls::Impl &Tls::getImpl() { return *impl; } + +class TlsIoStream final : public io_stream { +private: + own<io_stream> internal; + gnutls_session_t session_handle; + +public: + TlsIoStream(own<io_stream> internal_) : internal{std::move(internal_)} {} + + ~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + + error_or<size_t> read(void *buffer, size_t length) override { + ssize_t size = gnutls_record_recv(session_handle, buffer, length); + if (size < 0) { + if(gnutls_error_is_fatal(size) == 0){ + return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r"); + }else{ + return critical_error([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c"); + } + }else if(size == 0){ + return critical_error("Disconnected"); + } + + return static_cast<size_t>(length); + } + + conveyor<void> read_ready() override { return internal->read_ready(); } + + conveyor<void> on_read_disconnected() override { + return internal->on_read_disconnected(); + } + + error_or<size_t> write(const void *buffer, size_t length) override { + ssize_t size = gnutls_record_send(session_handle, buffer, length); + if(size < 0){ + if(gnutls_error_is_fatal(size) == 0){ + return recoverable_error([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r"); + }else{ + return critical_error([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c"); + } + } + + return static_cast<size_t>(size); + } + + conveyor<void> write_ready() override { return internal->write_ready(); } + + gnutls_session_t &session() { return session_handle; } +}; + +TlsServer::TlsServer(own<server> srv) : internal{std::move(srv)} {} + +conveyor<own<io_stream>> TlsServer::accept() { + SAW_ASSERT(internal) { return conveyor<own<io_stream>>{fix_void<own<io_stream>>{nullptr}}; } + return internal->accept().then([](own<io_stream> stream) -> own<io_stream> { + /// @todo handshake + + + return heap<TlsIoStream>(std::move(stream)); + }); +} + +namespace { +/* +* Small helper for setting up the nonblocking connection handshake +*/ +struct TlsClientStreamHelper { +public: + own<conveyor_feeder<own<io_stream>>> feeder; + conveyor_sink connection_sink; + conveyor_sink stream_reader; + conveyor_sink stream_writer; + + own<TlsIoStream> stream = nullptr; +public: + TlsClientStreamHelper(own<conveyor_feeder<own<io_stream>>> f): + feeder{std::move(f)} + {} + + void setupTurn(){ + SAW_ASSERT(stream){ + return; + } + + stream_reader = stream->read_ready().then([this](){ + turn(); + }).sink(); + + stream_writer = stream->write_ready().then([this](){ + turn(); + }).sink(); + } + + void turn(){ + if(stream){ + // Guarantee that the receiving end is already setup + SAW_ASSERT(feeder){ + return; + } + + auto &session = stream->session(); + + int ret; + do { + ret = gnutls_handshake(session); + } while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0); + + if(gnutls_error_is_fatal(ret)){ + feeder->fail(critical_error("Couldn't create Tls connection")); + stream = nullptr; + }else if(ret == GNUTLS_E_SUCCESS){ + feeder->feed(std::move(stream)); + } + } + } +}; +} + +own<server> TlsNetwork::listen(network_address& address) { + return heap<TlsServer>(internal.listen(address)); +} + +conveyor<own<io_stream>> TlsNetwork::connect(network_address& address) { + // Helper setups + auto caf = new_conveyor_and_feeder<own<io_stream>>(); + own<TlsClientStreamHelper> helper = heap<TlsClientStreamHelper>(std::move(caf.feeder)); + TlsClientStreamHelper* hlp_ptr = helper.get(); + + // Conveyor entangled structure + auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()]( + own<io_stream> stream) -> error_or<void> { + io_stream* inner_stream = stream.get(); + auto tls_stream = heap<TlsIoStream>(std::move(stream)); + + auto &session = tls_stream->session(); + + gnutls_init(&session, GNUTLS_CLIENT); + + gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(), + addr.size()); + + gnutls_set_default_priority(session); + gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, + tls.getImpl().xcred); + gnutls_session_set_verify_cert(session, addr.c_str(), 0); + + gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream)); + gnutls_transport_set_push_function(session, forst_tls_push_func); + gnutls_transport_set_pull_function(session, forst_tls_pull_func); + + // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + + hlp_ptr->stream = std::move(tls_stream); + hlp_ptr->setupTurn(); + hlp_ptr->turn(); + + return void_t{}; + }); + + helper->connection_sink = prim_conv.sink(); + + return caf.conveyor.attach(std::move(helper)); +} + +own<datagram> TlsNetwork::datagram(network_address& address){ + ///@unimplemented + return nullptr; +} + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size) { + io_stream *stream = reinterpret_cast<io_stream *>(p); + if (!stream) { + return -1; + } + + error_or<size_t> length = stream->write(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast<ssize_t>(length.value()); +} + +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { + io_stream *stream = reinterpret_cast<io_stream *>(p); + if (!stream) { + return -1; + } + + error_or<size_t> length = stream->read(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast<ssize_t>(length.value()); +} + +TlsNetwork::TlsNetwork(Tls& tls_, network &network) : tls{tls_},internal{network} {} + +conveyor<own<network_address>> TlsNetwork::resolve_address(const std::string &addr, + uint16_t port) { + /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But + /// it's better to find the error source sooner rather than later + return internal.resolve_address(addr, port); +} + +std::optional<own<TlsNetwork>> setupTlsNetwork(network &network) { + return std::nullopt; +} +} // namespace saw |