summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorClaudius 'keldu' Holeksa <mail@keldu.de>2024-10-17 18:56:11 +0200
committerClaudius 'keldu' Holeksa <mail@keldu.de>2024-10-17 18:56:11 +0200
commit17e22f10026068990595941eeb503fc2adb476a8 (patch)
treeb8e92692ae7a6dc770bd7d81aeb55869ce162a98
parentb048b02732cbfcfbb95bb8e16dec71aca0e977f4 (diff)
Changing impl things
-rw-r--r--modules/io-tls/tls.cpp120
-rw-r--r--modules/io-tls/tls.hpp29
2 files changed, 91 insertions, 58 deletions
diff --git a/modules/io-tls/tls.cpp b/modules/io-tls/tls.cpp
index c9c71f4..1c42215 100644
--- a/modules/io-tls/tls.cpp
+++ b/modules/io-tls/tls.cpp
@@ -38,23 +38,24 @@ tls::~tls() {}
tls::impl &tls::get_impl() { return *impl_; }
-class tls_io_stream final : public io_stream {
+template<typename T>
+class tls_io_stream final : public io_stream<net::Tls<T>> {
private:
- own<io_stream> internal;
+ own<io_stream<T>> internal_;
gnutls_certificate_credentials_t xcred_;
- gnutls_session_t session_handle;
+ gnutls_session_t session_handle_;
public:
- tls_io_stream(own<io_stream> internal_, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__):
- internal{std::move(internal_)},
+ tls_io_stream(own<io_stream<T>> internal__, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__):
+ internal_{std::move(internal__)},
xcred_{xcred__},
session_handle_{session_handle__}
{}
- ~tls_io_stream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
+ ~tls_io_stream() { gnutls_bye(session_handle_, GNUTLS_SHUT_RDWR); }
error_or<size_t> read(void *buffer, size_t length) override {
- ssize_t size = gnutls_record_recv(session_handle, buffer, length);
+ ssize_t size = gnutls_record_recv(session_handle_, buffer, length);
if (size < 0) {
if(gnutls_error_is_fatal(size) == 0){
return make_error<err::recoverable>("Recoverable error on read in gnutls. TODO better error msg handling");
@@ -70,14 +71,14 @@ public:
return static_cast<size_t>(length);
}
- conveyor<void> read_ready() override { return internal->read_ready(); }
+ conveyor<void> read_ready() override { return internal_->read_ready(); }
conveyor<void> on_read_disconnected() override {
- return internal->on_read_disconnected();
+ return internal_->on_read_disconnected();
}
error_or<size_t> write(const void *buffer, size_t length) override {
- ssize_t size = gnutls_record_send(session_handle, buffer, length);
+ ssize_t size = gnutls_record_send(session_handle_, buffer, length);
if(size < 0){
if(gnutls_error_is_fatal(size) == 0){
return make_error<err::recoverable>("Recoverable error on write in gnutls. TODO better error msg handling");
@@ -89,19 +90,20 @@ public:
return static_cast<size_t>(size);
}
- conveyor<void> write_ready() override { return internal->write_ready(); }
+ conveyor<void> write_ready() override { return internal_->write_ready(); }
- gnutls_session_t &session() { return session_handle; }
+ gnutls_session_t &session() { return session_handle_; }
};
-class tls_server final : public server {
+template<typename T>
+class tls_server final : public server<net::Tls<T>> {
private:
- own<server> internal_;
+ own<server<T>> internal_;
gnutls_certificate_credentials_t xcred_;
gnutls_session_t session_handle_;
public:
- tls_server(own<server> internal__, gnutls_certificate_credentials_t xcred__):
+ tls_server(own<server<T>> internal__, gnutls_certificate_credentials_t xcred__):
internal_{std::move(internal__)}
{}
@@ -110,36 +112,33 @@ public:
gnutls_certificate_free_credentials(xcred_);
}
- conveyor<own<io_stream>> accept() override {
- return make_error<err::not_implemented>();
- }
+ conveyor<own<io_stream<net::Tls<T>>>> accept() override;
};
-class tls_network final : public network {
+template<typename T>
+class tls_network final : public network<net::Tls<T>> {
private:
- tls& tls_;
- network &internal;
+ ref<tls> tls_;
+ ref<network<T>> internal_;
public:
- tls_network(tls& tls_, network &network_);
+ tls_network(tls& tls_, network<T> &network_);
- conveyor<own<network_address>> resolve_address(const std::string &addr, uint16_t port = 0) override;
+ conveyor<own<network_address<net::Tls<T>>>> resolve_address(const std::string &addr, uint16_t port = 0) override;
- own<server> listen(network_address& address) override;
+ own<server<net::Tls<T>>> listen(const network_address<net::Tls<T>>& address) override;
- conveyor<own<io_stream>> connect(network_address& address) override;
+ conveyor<own<io_stream<net::Tls<T>>>> connect(const network_address<net::Tls<T>>& address) override;
- own<class datagram> datagram(network_address& address) override;
+ own<datagram<net::Tls<T>>> bind_datagram(const network_address<net::Tls<T>>& address) override;
};
-tls_server::tls_server(own<server> srv) : internal{std::move(srv)} {}
-
-conveyor<own<io_stream>> tls_server::accept() {
- SAW_ASSERT(internal) { return conveyor<own<io_stream>>{fix_void<own<io_stream>>{nullptr}}; }
- return internal->accept().then([](own<io_stream> stream) -> own<io_stream> {
+template<typename T>
+conveyor<own<io_stream<net::Tls<T>>>> tls_server<T>::accept() {
+ SAW_ASSERT(internal_) { return conveyor<own<io_stream<net::Tls<T>>>>{fix_void<own<io_stream<net::Tls<T>>>>{nullptr}}; }
+ return internal_->accept().then([](own<io_stream<T>> stream) -> own<io_stream<net::Tls<T>>> {
/// @todo handshake
-
- return heap<tls_io_stream>(std::move(stream));
+ return heap<tls_io_stream<T>>(std::move(stream));
});
}
@@ -147,16 +146,17 @@ namespace {
/*
* Small helper for setting up the nonblocking connection handshake
*/
+template<typename T>
struct tls_client_stream_helper {
public:
- own<conveyor_feeder<own<io_stream>>> feeder;
+ own<conveyor_feeder<own<io_stream<net::Tls<T>>>>> feeder;
conveyor_sink connection_sink;
conveyor_sink stream_reader;
conveyor_sink stream_writer;
- own<tls_io_stream> stream = nullptr;
+ own<tls_io_stream<T>> stream = nullptr;
public:
- tls_client_stream_helper(own<conveyor_feeder<own<io_stream>>> f):
+ tls_client_stream_helper(own<conveyor_feeder<own<io_stream<net::Tls<T>>>>> f):
feeder{std::move(f)}
{}
@@ -199,25 +199,27 @@ public:
};
}
-own<server> tls_network::listen(const network_address& address) {
+template<typename T>
+own<server<net::Tls<T>>> tls_network<T>::listen(const network_address<net::Tls<T>>& address) {
gnutls_certificate_credentials_t x509_cred;
gnutls_certificate_allocate_credentials(&x509_cred);
- auto int_srv = internal.listen(address);
+ auto int_srv = internal_.listen(address);
- return heap<tls_server>(int_srv, x509_cred);
+ return heap<tls_server>(std::move(int_srv), x509_cred);
}
-conveyor<own<io_stream>> tls_network::connect(network_address& address) {
+template<typename T>
+conveyor<own<io_stream<net::Tls<T>>>> tls_network<T>::connect(const network_address<net::Tls<T>>& address) {
// Helper setups
- auto caf = new_conveyor_and_feeder<own<io_stream>>();
- own<tls_client_stream_helper> helper = heap<tls_client_stream_helper>(std::move(caf.feeder));
- tls_client_stream_helper* hlp_ptr = helper.get();
+ auto caf = new_conveyor_and_feeder<own<io_stream<net::Tls<T>>>>();
+ own<tls_client_stream_helper<T>> helper = heap<tls_client_stream_helper<T>>(std::move(caf.feeder));
+ tls_client_stream_helper<T>* hlp_ptr = helper.get();
// Conveyor entangled structure
- auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()](
- own<io_stream> stream) -> error_or<void> {
- io_stream* inner_stream = stream.get();
- auto tls_stream = heap<tls_io_stream>(std::move(stream));
+ auto prim_conv = internal_.connect(address).then([this, hlp_ptr, addr = address.address()](
+ own<io_stream<T>> stream) -> error_or<void> {
+ io_stream<T>* inner_stream = stream.get();
+ auto tls_stream = heap<tls_io_stream<T>>(std::move(stream));
auto &session = tls_stream->session();
@@ -228,7 +230,7 @@ conveyor<own<io_stream>> tls_network::connect(network_address& address) {
gnutls_set_default_priority(session);
gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
- tls_.get_impl().xcred);
+ tls_().get_impl().xcred);
gnutls_session_set_verify_cert(session, addr.c_str(), 0);
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
@@ -249,14 +251,16 @@ conveyor<own<io_stream>> tls_network::connect(network_address& address) {
return caf.conveyor.attach(std::move(helper));
}
-own<datagram> tls_network::datagram(network_address& address){
+template<typename T>
+own<datagram<net::Tls<T>>> tls_network<T>::bind_datagram(const network_address<net::Tls<T>>& address){
///@unimplemented
return nullptr;
}
+template<typename T>
static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data,
size_t size) {
- io_stream *stream = reinterpret_cast<io_stream *>(p);
+ io_stream<T> *stream = reinterpret_cast<io_stream<T>*>(p);
if (!stream) {
return -1;
}
@@ -269,8 +273,9 @@ static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data,
return static_cast<ssize_t>(length.get_value());
}
+template<typename T>
static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
- io_stream *stream = reinterpret_cast<io_stream *>(p);
+ io_stream<T> *stream = reinterpret_cast<io_stream<T>*>(p);
if (!stream) {
return -1;
}
@@ -283,16 +288,19 @@ static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t
return static_cast<ssize_t>(length.get_value());
}
-tls_network::tls_network(tls& tls_, network &network) : tls_{tls_},internal{network} {}
+template<typename T>
+tls_network<T>::tls_network(tls& tls_, network<T> &network) : tls_{tls_},internal_{network} {}
-conveyor<own<network_address>> tls_network::resolve_address(const std::string &addr,
+template<typename T>
+conveyor<own<network_address<net::Tls<T>>>> tls_network<T>::resolve_address(const std::string &addr,
uint16_t port) {
/// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But
/// it's better to find the error source sooner rather than later
- return internal.resolve_address(addr, port);
+ return internal_.resolve_address(addr, port);
}
-error_or<own<network<net::Tls>>> setup_tls_network(network &network) {
- return std::nullopt;
+template<typename T>
+error_or<own<network<net::Tls<T>>>> setup_tls_network(network<net::Tls<T>> &network) {
+ return make_error<err::not_implemented>();
}
} // namespace saw
diff --git a/modules/io-tls/tls.hpp b/modules/io-tls/tls.hpp
index e2202f4..c5c3da1 100644
--- a/modules/io-tls/tls.hpp
+++ b/modules/io-tls/tls.hpp
@@ -42,8 +42,33 @@ private:
options options_;
};
-template<>
-class network<net::Tls> {
+template<typename T>
+class network<net::Tls<T>> {
+public:
+ virtual ~network() = default;
+
+ /**
+ * Resolve the provided string and uint16 to the preferred storage method
+ */
+ virtual conveyor<own<network_address<net::Tls<T>>>>
+ resolve_address(const std::string &addr, uint16_t port_hint = 0) = 0;
+
+ /**
+ * Parse the provided string and uint16 to the preferred storage method
+ * Since no dns request is made here, no async conveyors have to be used.
+ */
+ virtual error_or<own<network_address<net::Tls<T>>>>
+ parse_address(const std::string &addr, uint16_t port_hint = 0) = 0;
+
+ /**
+ * Set up a listener on this address
+ */
+ virtual error_or<own<server<T>>> listen(network_address<T> &bind_addr) = 0;
+
+ /**
+ * Connect to a remote address
+ */
+ virtual conveyor<own<io_stream<T>>> connect(network_address<T> &address) = 0;
};
template<typename T = net::Os>