From fac9e8bec1983fa9dff8f447fef106e427dfec26 Mon Sep 17 00:00:00 2001 From: "Claudius \"keldu\" Holeksa" Date: Thu, 20 Jul 2023 17:02:05 +0200 Subject: c++: Renamed src to c++ --- c++/io-tls/tls.cpp | 252 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 252 insertions(+) create mode 100644 c++/io-tls/tls.cpp (limited to 'c++/io-tls/tls.cpp') diff --git a/c++/io-tls/tls.cpp b/c++/io-tls/tls.cpp new file mode 100644 index 0000000..9fa143c --- /dev/null +++ b/c++/io-tls/tls.cpp @@ -0,0 +1,252 @@ +#include "tls.h" + +#include +#include + +#include + +#include + +#include + +namespace saw { + +class tls::impl { +public: + gnutls_certificate_credentials_t xcred; + +public: + impl() { + gnutls_global_init(); + gnutls_certificate_allocate_credentials(&xcred); + gnutls_certificate_set_x509_system_trust(xcred); + } + + ~impl() { + gnutls_certificate_free_credentials(xcred); + gnutls_global_deinit(); + } +}; + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size); +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size); + +tls::tls() : impl_{heap()} {} + +tls::~tls() {} + +tls::impl &tls::get_impl() { return *impl_; } + +class tls_io_stream final : public io_stream { +private: + own internal; + gnutls_session_t session_handle; + +public: + tls_io_stream(own internal_) : internal{std::move(internal_)} {} + + ~tls_io_stream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + + error_or read(void *buffer, size_t length) override { + ssize_t size = gnutls_record_recv(session_handle, buffer, length); + if (size < 0) { + if(gnutls_error_is_fatal(size) == 0){ + return make_error("Recoverable error on read in gnutls. TODO better error msg handling"); + // Leaving proper message handling done in previous error framework + //return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r"); + }else{ + return make_error("Fatal error on read in gnutls. TODO better error msg handling"); + } + }else if(size == 0){ + return make_error(); + } + + return static_cast(length); + } + + conveyor read_ready() override { return internal->read_ready(); } + + conveyor on_read_disconnected() override { + return internal->on_read_disconnected(); + } + + error_or write(const void *buffer, size_t length) override { + ssize_t size = gnutls_record_send(session_handle, buffer, length); + if(size < 0){ + if(gnutls_error_is_fatal(size) == 0){ + return make_error("Recoverable error on write in gnutls. TODO better error msg handling"); + }else{ + return make_error("Fatal error on write in gnutls. TODO better error msg handling"); + } + } + + return static_cast(size); + } + + conveyor write_ready() override { return internal->write_ready(); } + + gnutls_session_t &session() { return session_handle; } +}; + +tls_server::tls_server(own srv) : internal{std::move(srv)} {} + +conveyor> tls_server::accept() { + SAW_ASSERT(internal) { return conveyor>{fix_void>{nullptr}}; } + return internal->accept().then([](own stream) -> own { + /// @todo handshake + + + return heap(std::move(stream)); + }); +} + +namespace { +/* +* Small helper for setting up the nonblocking connection handshake +*/ +struct tls_client_stream_helper { +public: + own>> feeder; + conveyor_sink connection_sink; + conveyor_sink stream_reader; + conveyor_sink stream_writer; + + own stream = nullptr; +public: + tls_client_stream_helper(own>> f): + feeder{std::move(f)} + {} + + void setupTurn(){ + SAW_ASSERT(stream){ + return; + } + + stream_reader = stream->read_ready().then([this](){ + turn(); + }).sink(); + + stream_writer = stream->write_ready().then([this](){ + turn(); + }).sink(); + } + + void turn(){ + if(stream){ + // Guarantee that the receiving end is already setup + SAW_ASSERT(feeder){ + return; + } + + auto &session = stream->session(); + + int ret; + do { + ret = gnutls_handshake(session); + } while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0); + + if(gnutls_error_is_fatal(ret)){ + feeder->fail(make_error("Couldn't create Tls connection")); + stream = nullptr; + }else if(ret == GNUTLS_E_SUCCESS){ + feeder->feed(std::move(stream)); + } + } + } +}; +} + +own tls_network::listen(network_address& address) { + return heap(internal.listen(address)); +} + +conveyor> tls_network::connect(network_address& address) { + // Helper setups + auto caf = new_conveyor_and_feeder>(); + own helper = heap(std::move(caf.feeder)); + tls_client_stream_helper* hlp_ptr = helper.get(); + + // Conveyor entangled structure + auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()]( + own stream) -> error_or { + io_stream* inner_stream = stream.get(); + auto tls_stream = heap(std::move(stream)); + + auto &session = tls_stream->session(); + + gnutls_init(&session, GNUTLS_CLIENT); + + gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(), + addr.size()); + + gnutls_set_default_priority(session); + gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, + tls_.get_impl().xcred); + gnutls_session_set_verify_cert(session, addr.c_str(), 0); + + gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); + gnutls_transport_set_push_function(session, forst_tls_push_func); + gnutls_transport_set_pull_function(session, forst_tls_pull_func); + + // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + + hlp_ptr->stream = std::move(tls_stream); + hlp_ptr->setupTurn(); + hlp_ptr->turn(); + + return void_t{}; + }); + + helper->connection_sink = prim_conv.sink(); + + return caf.conveyor.attach(std::move(helper)); +} + +own tls_network::datagram(network_address& address){ + ///@unimplemented + return nullptr; +} + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size) { + io_stream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + error_or length = stream->write(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast(length.get_value()); +} + +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { + io_stream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + error_or length = stream->read(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast(length.get_value()); +} + +tls_network::tls_network(tls& tls_, network &network) : tls_{tls_},internal{network} {} + +conveyor> tls_network::resolve_address(const std::string &addr, + uint16_t port) { + /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But + /// it's better to find the error source sooner rather than later + return internal.resolve_address(addr, port); +} + +std::optional> setup_tls_network(network &network) { + return std::nullopt; +} +} // namespace saw -- cgit v1.2.3