diff options
author | Claudius 'keldu' Holeksa <mail@keldu.de> | 2024-10-17 18:56:11 +0200 |
---|---|---|
committer | Claudius 'keldu' Holeksa <mail@keldu.de> | 2024-10-17 18:56:11 +0200 |
commit | 17e22f10026068990595941eeb503fc2adb476a8 (patch) | |
tree | b8e92692ae7a6dc770bd7d81aeb55869ce162a98 | |
parent | b048b02732cbfcfbb95bb8e16dec71aca0e977f4 (diff) |
Changing impl things
-rw-r--r-- | modules/io-tls/tls.cpp | 120 | ||||
-rw-r--r-- | modules/io-tls/tls.hpp | 29 |
2 files changed, 91 insertions, 58 deletions
diff --git a/modules/io-tls/tls.cpp b/modules/io-tls/tls.cpp index c9c71f4..1c42215 100644 --- a/modules/io-tls/tls.cpp +++ b/modules/io-tls/tls.cpp @@ -38,23 +38,24 @@ tls::~tls() {} tls::impl &tls::get_impl() { return *impl_; } -class tls_io_stream final : public io_stream { +template<typename T> +class tls_io_stream final : public io_stream<net::Tls<T>> { private: - own<io_stream> internal; + own<io_stream<T>> internal_; gnutls_certificate_credentials_t xcred_; - gnutls_session_t session_handle; + gnutls_session_t session_handle_; public: - tls_io_stream(own<io_stream> internal_, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__): - internal{std::move(internal_)}, + tls_io_stream(own<io_stream<T>> internal__, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__): + internal_{std::move(internal__)}, xcred_{xcred__}, session_handle_{session_handle__} {} - ~tls_io_stream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + ~tls_io_stream() { gnutls_bye(session_handle_, GNUTLS_SHUT_RDWR); } error_or<size_t> read(void *buffer, size_t length) override { - ssize_t size = gnutls_record_recv(session_handle, buffer, length); + ssize_t size = gnutls_record_recv(session_handle_, buffer, length); if (size < 0) { if(gnutls_error_is_fatal(size) == 0){ return make_error<err::recoverable>("Recoverable error on read in gnutls. TODO better error msg handling"); @@ -70,14 +71,14 @@ public: return static_cast<size_t>(length); } - conveyor<void> read_ready() override { return internal->read_ready(); } + conveyor<void> read_ready() override { return internal_->read_ready(); } conveyor<void> on_read_disconnected() override { - return internal->on_read_disconnected(); + return internal_->on_read_disconnected(); } error_or<size_t> write(const void *buffer, size_t length) override { - ssize_t size = gnutls_record_send(session_handle, buffer, length); + ssize_t size = gnutls_record_send(session_handle_, buffer, length); if(size < 0){ if(gnutls_error_is_fatal(size) == 0){ return make_error<err::recoverable>("Recoverable error on write in gnutls. TODO better error msg handling"); @@ -89,19 +90,20 @@ public: return static_cast<size_t>(size); } - conveyor<void> write_ready() override { return internal->write_ready(); } + conveyor<void> write_ready() override { return internal_->write_ready(); } - gnutls_session_t &session() { return session_handle; } + gnutls_session_t &session() { return session_handle_; } }; -class tls_server final : public server { +template<typename T> +class tls_server final : public server<net::Tls<T>> { private: - own<server> internal_; + own<server<T>> internal_; gnutls_certificate_credentials_t xcred_; gnutls_session_t session_handle_; public: - tls_server(own<server> internal__, gnutls_certificate_credentials_t xcred__): + tls_server(own<server<T>> internal__, gnutls_certificate_credentials_t xcred__): internal_{std::move(internal__)} {} @@ -110,36 +112,33 @@ public: gnutls_certificate_free_credentials(xcred_); } - conveyor<own<io_stream>> accept() override { - return make_error<err::not_implemented>(); - } + conveyor<own<io_stream<net::Tls<T>>>> accept() override; }; -class tls_network final : public network { +template<typename T> +class tls_network final : public network<net::Tls<T>> { private: - tls& tls_; - network &internal; + ref<tls> tls_; + ref<network<T>> internal_; public: - tls_network(tls& tls_, network &network_); + tls_network(tls& tls_, network<T> &network_); - conveyor<own<network_address>> resolve_address(const std::string &addr, uint16_t port = 0) override; + conveyor<own<network_address<net::Tls<T>>>> resolve_address(const std::string &addr, uint16_t port = 0) override; - own<server> listen(network_address& address) override; + own<server<net::Tls<T>>> listen(const network_address<net::Tls<T>>& address) override; - conveyor<own<io_stream>> connect(network_address& address) override; + conveyor<own<io_stream<net::Tls<T>>>> connect(const network_address<net::Tls<T>>& address) override; - own<class datagram> datagram(network_address& address) override; + own<datagram<net::Tls<T>>> bind_datagram(const network_address<net::Tls<T>>& address) override; }; -tls_server::tls_server(own<server> srv) : internal{std::move(srv)} {} - -conveyor<own<io_stream>> tls_server::accept() { - SAW_ASSERT(internal) { return conveyor<own<io_stream>>{fix_void<own<io_stream>>{nullptr}}; } - return internal->accept().then([](own<io_stream> stream) -> own<io_stream> { +template<typename T> +conveyor<own<io_stream<net::Tls<T>>>> tls_server<T>::accept() { + SAW_ASSERT(internal_) { return conveyor<own<io_stream<net::Tls<T>>>>{fix_void<own<io_stream<net::Tls<T>>>>{nullptr}}; } + return internal_->accept().then([](own<io_stream<T>> stream) -> own<io_stream<net::Tls<T>>> { /// @todo handshake - - return heap<tls_io_stream>(std::move(stream)); + return heap<tls_io_stream<T>>(std::move(stream)); }); } @@ -147,16 +146,17 @@ namespace { /* * Small helper for setting up the nonblocking connection handshake */ +template<typename T> struct tls_client_stream_helper { public: - own<conveyor_feeder<own<io_stream>>> feeder; + own<conveyor_feeder<own<io_stream<net::Tls<T>>>>> feeder; conveyor_sink connection_sink; conveyor_sink stream_reader; conveyor_sink stream_writer; - own<tls_io_stream> stream = nullptr; + own<tls_io_stream<T>> stream = nullptr; public: - tls_client_stream_helper(own<conveyor_feeder<own<io_stream>>> f): + tls_client_stream_helper(own<conveyor_feeder<own<io_stream<net::Tls<T>>>>> f): feeder{std::move(f)} {} @@ -199,25 +199,27 @@ public: }; } -own<server> tls_network::listen(const network_address& address) { +template<typename T> +own<server<net::Tls<T>>> tls_network<T>::listen(const network_address<net::Tls<T>>& address) { gnutls_certificate_credentials_t x509_cred; gnutls_certificate_allocate_credentials(&x509_cred); - auto int_srv = internal.listen(address); + auto int_srv = internal_.listen(address); - return heap<tls_server>(int_srv, x509_cred); + return heap<tls_server>(std::move(int_srv), x509_cred); } -conveyor<own<io_stream>> tls_network::connect(network_address& address) { +template<typename T> +conveyor<own<io_stream<net::Tls<T>>>> tls_network<T>::connect(const network_address<net::Tls<T>>& address) { // Helper setups - auto caf = new_conveyor_and_feeder<own<io_stream>>(); - own<tls_client_stream_helper> helper = heap<tls_client_stream_helper>(std::move(caf.feeder)); - tls_client_stream_helper* hlp_ptr = helper.get(); + auto caf = new_conveyor_and_feeder<own<io_stream<net::Tls<T>>>>(); + own<tls_client_stream_helper<T>> helper = heap<tls_client_stream_helper<T>>(std::move(caf.feeder)); + tls_client_stream_helper<T>* hlp_ptr = helper.get(); // Conveyor entangled structure - auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()]( - own<io_stream> stream) -> error_or<void> { - io_stream* inner_stream = stream.get(); - auto tls_stream = heap<tls_io_stream>(std::move(stream)); + auto prim_conv = internal_.connect(address).then([this, hlp_ptr, addr = address.address()]( + own<io_stream<T>> stream) -> error_or<void> { + io_stream<T>* inner_stream = stream.get(); + auto tls_stream = heap<tls_io_stream<T>>(std::move(stream)); auto &session = tls_stream->session(); @@ -228,7 +230,7 @@ conveyor<own<io_stream>> tls_network::connect(network_address& address) { gnutls_set_default_priority(session); gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, - tls_.get_impl().xcred); + tls_().get_impl().xcred); gnutls_session_set_verify_cert(session, addr.c_str(), 0); gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream)); @@ -249,14 +251,16 @@ conveyor<own<io_stream>> tls_network::connect(network_address& address) { return caf.conveyor.attach(std::move(helper)); } -own<datagram> tls_network::datagram(network_address& address){ +template<typename T> +own<datagram<net::Tls<T>>> tls_network<T>::bind_datagram(const network_address<net::Tls<T>>& address){ ///@unimplemented return nullptr; } +template<typename T> static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, size_t size) { - io_stream *stream = reinterpret_cast<io_stream *>(p); + io_stream<T> *stream = reinterpret_cast<io_stream<T>*>(p); if (!stream) { return -1; } @@ -269,8 +273,9 @@ static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, return static_cast<ssize_t>(length.get_value()); } +template<typename T> static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { - io_stream *stream = reinterpret_cast<io_stream *>(p); + io_stream<T> *stream = reinterpret_cast<io_stream<T>*>(p); if (!stream) { return -1; } @@ -283,16 +288,19 @@ static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t return static_cast<ssize_t>(length.get_value()); } -tls_network::tls_network(tls& tls_, network &network) : tls_{tls_},internal{network} {} +template<typename T> +tls_network<T>::tls_network(tls& tls_, network<T> &network) : tls_{tls_},internal_{network} {} -conveyor<own<network_address>> tls_network::resolve_address(const std::string &addr, +template<typename T> +conveyor<own<network_address<net::Tls<T>>>> tls_network<T>::resolve_address(const std::string &addr, uint16_t port) { /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But /// it's better to find the error source sooner rather than later - return internal.resolve_address(addr, port); + return internal_.resolve_address(addr, port); } -error_or<own<network<net::Tls>>> setup_tls_network(network &network) { - return std::nullopt; +template<typename T> +error_or<own<network<net::Tls<T>>>> setup_tls_network(network<net::Tls<T>> &network) { + return make_error<err::not_implemented>(); } } // namespace saw diff --git a/modules/io-tls/tls.hpp b/modules/io-tls/tls.hpp index e2202f4..c5c3da1 100644 --- a/modules/io-tls/tls.hpp +++ b/modules/io-tls/tls.hpp @@ -42,8 +42,33 @@ private: options options_; }; -template<> -class network<net::Tls> { +template<typename T> +class network<net::Tls<T>> { +public: + virtual ~network() = default; + + /** + * Resolve the provided string and uint16 to the preferred storage method + */ + virtual conveyor<own<network_address<net::Tls<T>>>> + resolve_address(const std::string &addr, uint16_t port_hint = 0) = 0; + + /** + * Parse the provided string and uint16 to the preferred storage method + * Since no dns request is made here, no async conveyors have to be used. + */ + virtual error_or<own<network_address<net::Tls<T>>>> + parse_address(const std::string &addr, uint16_t port_hint = 0) = 0; + + /** + * Set up a listener on this address + */ + virtual error_or<own<server<T>>> listen(network_address<T> &bind_addr) = 0; + + /** + * Connect to a remote address + */ + virtual conveyor<own<io_stream<T>>> connect(network_address<T> &address) = 0; }; template<typename T = net::Os> |