From 17e22f10026068990595941eeb503fc2adb476a8 Mon Sep 17 00:00:00 2001 From: Claudius 'keldu' Holeksa Date: Thu, 17 Oct 2024 18:56:11 +0200 Subject: Changing impl things --- modules/io-tls/tls.cpp | 120 ++++++++++++++++++++++++++----------------------- modules/io-tls/tls.hpp | 29 +++++++++++- 2 files changed, 91 insertions(+), 58 deletions(-) (limited to 'modules/io-tls') diff --git a/modules/io-tls/tls.cpp b/modules/io-tls/tls.cpp index c9c71f4..1c42215 100644 --- a/modules/io-tls/tls.cpp +++ b/modules/io-tls/tls.cpp @@ -38,23 +38,24 @@ tls::~tls() {} tls::impl &tls::get_impl() { return *impl_; } -class tls_io_stream final : public io_stream { +template +class tls_io_stream final : public io_stream> { private: - own internal; + own> internal_; gnutls_certificate_credentials_t xcred_; - gnutls_session_t session_handle; + gnutls_session_t session_handle_; public: - tls_io_stream(own internal_, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__): - internal{std::move(internal_)}, + tls_io_stream(own> internal__, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__): + internal_{std::move(internal__)}, xcred_{xcred__}, session_handle_{session_handle__} {} - ~tls_io_stream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + ~tls_io_stream() { gnutls_bye(session_handle_, GNUTLS_SHUT_RDWR); } error_or read(void *buffer, size_t length) override { - ssize_t size = gnutls_record_recv(session_handle, buffer, length); + ssize_t size = gnutls_record_recv(session_handle_, buffer, length); if (size < 0) { if(gnutls_error_is_fatal(size) == 0){ return make_error("Recoverable error on read in gnutls. TODO better error msg handling"); @@ -70,14 +71,14 @@ public: return static_cast(length); } - conveyor read_ready() override { return internal->read_ready(); } + conveyor read_ready() override { return internal_->read_ready(); } conveyor on_read_disconnected() override { - return internal->on_read_disconnected(); + return internal_->on_read_disconnected(); } error_or write(const void *buffer, size_t length) override { - ssize_t size = gnutls_record_send(session_handle, buffer, length); + ssize_t size = gnutls_record_send(session_handle_, buffer, length); if(size < 0){ if(gnutls_error_is_fatal(size) == 0){ return make_error("Recoverable error on write in gnutls. TODO better error msg handling"); @@ -89,19 +90,20 @@ public: return static_cast(size); } - conveyor write_ready() override { return internal->write_ready(); } + conveyor write_ready() override { return internal_->write_ready(); } - gnutls_session_t &session() { return session_handle; } + gnutls_session_t &session() { return session_handle_; } }; -class tls_server final : public server { +template +class tls_server final : public server> { private: - own internal_; + own> internal_; gnutls_certificate_credentials_t xcred_; gnutls_session_t session_handle_; public: - tls_server(own internal__, gnutls_certificate_credentials_t xcred__): + tls_server(own> internal__, gnutls_certificate_credentials_t xcred__): internal_{std::move(internal__)} {} @@ -110,36 +112,33 @@ public: gnutls_certificate_free_credentials(xcred_); } - conveyor> accept() override { - return make_error(); - } + conveyor>>> accept() override; }; -class tls_network final : public network { +template +class tls_network final : public network> { private: - tls& tls_; - network &internal; + ref tls_; + ref> internal_; public: - tls_network(tls& tls_, network &network_); + tls_network(tls& tls_, network &network_); - conveyor> resolve_address(const std::string &addr, uint16_t port = 0) override; + conveyor>>> resolve_address(const std::string &addr, uint16_t port = 0) override; - own listen(network_address& address) override; + own>> listen(const network_address>& address) override; - conveyor> connect(network_address& address) override; + conveyor>>> connect(const network_address>& address) override; - own datagram(network_address& address) override; + own>> bind_datagram(const network_address>& address) override; }; -tls_server::tls_server(own srv) : internal{std::move(srv)} {} - -conveyor> tls_server::accept() { - SAW_ASSERT(internal) { return conveyor>{fix_void>{nullptr}}; } - return internal->accept().then([](own stream) -> own { +template +conveyor>>> tls_server::accept() { + SAW_ASSERT(internal_) { return conveyor>>>{fix_void>>>{nullptr}}; } + return internal_->accept().then([](own> stream) -> own>> { /// @todo handshake - - return heap(std::move(stream)); + return heap>(std::move(stream)); }); } @@ -147,16 +146,17 @@ namespace { /* * Small helper for setting up the nonblocking connection handshake */ +template struct tls_client_stream_helper { public: - own>> feeder; + own>>>> feeder; conveyor_sink connection_sink; conveyor_sink stream_reader; conveyor_sink stream_writer; - own stream = nullptr; + own> stream = nullptr; public: - tls_client_stream_helper(own>> f): + tls_client_stream_helper(own>>>> f): feeder{std::move(f)} {} @@ -199,25 +199,27 @@ public: }; } -own tls_network::listen(const network_address& address) { +template +own>> tls_network::listen(const network_address>& address) { gnutls_certificate_credentials_t x509_cred; gnutls_certificate_allocate_credentials(&x509_cred); - auto int_srv = internal.listen(address); + auto int_srv = internal_.listen(address); - return heap(int_srv, x509_cred); + return heap(std::move(int_srv), x509_cred); } -conveyor> tls_network::connect(network_address& address) { +template +conveyor>>> tls_network::connect(const network_address>& address) { // Helper setups - auto caf = new_conveyor_and_feeder>(); - own helper = heap(std::move(caf.feeder)); - tls_client_stream_helper* hlp_ptr = helper.get(); + auto caf = new_conveyor_and_feeder>>>(); + own> helper = heap>(std::move(caf.feeder)); + tls_client_stream_helper* hlp_ptr = helper.get(); // Conveyor entangled structure - auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()]( - own stream) -> error_or { - io_stream* inner_stream = stream.get(); - auto tls_stream = heap(std::move(stream)); + auto prim_conv = internal_.connect(address).then([this, hlp_ptr, addr = address.address()]( + own> stream) -> error_or { + io_stream* inner_stream = stream.get(); + auto tls_stream = heap>(std::move(stream)); auto &session = tls_stream->session(); @@ -228,7 +230,7 @@ conveyor> tls_network::connect(network_address& address) { gnutls_set_default_priority(session); gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, - tls_.get_impl().xcred); + tls_().get_impl().xcred); gnutls_session_set_verify_cert(session, addr.c_str(), 0); gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); @@ -249,14 +251,16 @@ conveyor> tls_network::connect(network_address& address) { return caf.conveyor.attach(std::move(helper)); } -own tls_network::datagram(network_address& address){ +template +own>> tls_network::bind_datagram(const network_address>& address){ ///@unimplemented return nullptr; } +template static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, size_t size) { - io_stream *stream = reinterpret_cast(p); + io_stream *stream = reinterpret_cast*>(p); if (!stream) { return -1; } @@ -269,8 +273,9 @@ static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, return static_cast(length.get_value()); } +template static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { - io_stream *stream = reinterpret_cast(p); + io_stream *stream = reinterpret_cast*>(p); if (!stream) { return -1; } @@ -283,16 +288,19 @@ static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t return static_cast(length.get_value()); } -tls_network::tls_network(tls& tls_, network &network) : tls_{tls_},internal{network} {} +template +tls_network::tls_network(tls& tls_, network &network) : tls_{tls_},internal_{network} {} -conveyor> tls_network::resolve_address(const std::string &addr, +template +conveyor>>> tls_network::resolve_address(const std::string &addr, uint16_t port) { /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But /// it's better to find the error source sooner rather than later - return internal.resolve_address(addr, port); + return internal_.resolve_address(addr, port); } -error_or>> setup_tls_network(network &network) { - return std::nullopt; +template +error_or>>> setup_tls_network(network> &network) { + return make_error(); } } // namespace saw diff --git a/modules/io-tls/tls.hpp b/modules/io-tls/tls.hpp index e2202f4..c5c3da1 100644 --- a/modules/io-tls/tls.hpp +++ b/modules/io-tls/tls.hpp @@ -42,8 +42,33 @@ private: options options_; }; -template<> -class network { +template +class network> { +public: + virtual ~network() = default; + + /** + * Resolve the provided string and uint16 to the preferred storage method + */ + virtual conveyor>>> + resolve_address(const std::string &addr, uint16_t port_hint = 0) = 0; + + /** + * Parse the provided string and uint16 to the preferred storage method + * Since no dns request is made here, no async conveyors have to be used. + */ + virtual error_or>>> + parse_address(const std::string &addr, uint16_t port_hint = 0) = 0; + + /** + * Set up a listener on this address + */ + virtual error_or>> listen(network_address &bind_addr) = 0; + + /** + * Connect to a remote address + */ + virtual conveyor>> connect(network_address &address) = 0; }; template -- cgit v1.2.3