From 9435dd1b25f6b6daac64e8a2ac900c65c440b07a Mon Sep 17 00:00:00 2001 From: "keldu.magnus" Date: Fri, 18 Jun 2021 00:17:18 +0200 Subject: [PATCH] tls works for clients --- SConstruct | 11 +- driver/io-unix.cpp | 3 + driver/io-unix.h | 4 + source/kelgin/SConscript | 3 + source/kelgin/async.h | 39 +++++++- source/kelgin/io.h | 3 + source/kelgin/io_helpers.cpp | 1 - source/kelgin/io_wrapper.h | 18 ++++ source/kelgin/tls/tls.cpp | 188 +++++++++++++++++++++++++++++++++-- source/kelgin/tls/tls.h | 60 ++++++----- 10 files changed, 288 insertions(+), 42 deletions(-) create mode 100644 source/kelgin/io_wrapper.h diff --git a/SConstruct b/SConstruct index 9fbd7a2..706e9dc 100644 --- a/SConstruct +++ b/SConstruct @@ -36,9 +36,12 @@ env=Environment(CPPPATH=['#source/kelgin','#source','#','#driver'], LIBS=['gnutls']) env.__class__.add_source_files = add_kel_source_files +env.objects = [] env.sources = [] env.headers = [] -env.objects = [] + +env.tls_sources = [] +env.tls_headers = [] env.driver_sources = [] env.driver_headers = [] @@ -52,11 +55,11 @@ SConscript('driver/SConscript') env_library = env.Clone() env.objects_shared = [] -env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources, shared=True) +env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources + env.tls_sources, shared=True) env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared]) env.objects_static = [] -env_library.add_source_files(env.objects_static, env.sources + env.driver_sources) +env_library.add_source_files(env.objects_static, env.sources + env.driver_sources + env.tls_sources) env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static]) env.Alias('library', [env.library_shared, env.library_static]) @@ -86,5 +89,7 @@ env.Alias('all', ['format', 'library_shared', 'library_static', 'test']) env.Install('/usr/local/lib/', [env.library_shared, env.library_static]) env.Install('/usr/local/include/kelgin/', [env.headers]) +env.Install('/usr/local/include/kelgin/tls/', [env.tls_headers]) + env.Install('/usr/local/include/kelgin/test/', [env.test_headers]) env.Alias('install', '/usr/local/') diff --git a/driver/io-unix.cpp b/driver/io-unix.cpp index 822025d..8ee1122 100644 --- a/driver/io-unix.cpp +++ b/driver/io-unix.cpp @@ -213,6 +213,9 @@ std::string UnixNetworkAddress::toString() const { return {}; } } +const std::string &UnixNetworkAddress::address() const { return path; } + +uint16_t UnixNetworkAddress::port() const { return port_hint; } UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {} diff --git a/driver/io-unix.h b/driver/io-unix.h index 52ff530..2bf3bc7 100644 --- a/driver/io-unix.h +++ b/driver/io-unix.h @@ -404,6 +404,10 @@ public: Conveyor> connect() override; std::string toString() const override; + + const std::string &address() const override; + + uint16_t port() const override; }; class UnixNetwork final : public Network { diff --git a/source/kelgin/SConscript b/source/kelgin/SConscript index 490e08e..dd30e4d 100644 --- a/source/kelgin/SConscript +++ b/source/kelgin/SConscript @@ -11,3 +11,6 @@ dir_path = Dir('.').abspath env.sources += sorted(glob.glob(dir_path + "/*.cpp")) env.headers += sorted(glob.glob(dir_path + "/*.h")) + +env.tls_sources += sorted(glob.glob(dir_path + "/tls/*.cpp")) +env.tls_headers += sorted(glob.glob(dir_path + "/tls/*.h")) diff --git a/source/kelgin/async.h b/source/kelgin/async.h index 227b186..e2dbe2d 100644 --- a/source/kelgin/async.h +++ b/source/kelgin/async.h @@ -701,32 +701,61 @@ public: // ConveyorNode void getResult(ErrorOrValue &err_or_val) noexcept override { - if (retrieved) { + if (retrieved > 0) { err_or_val.as>() = criticalError("Already taken value"); } else { err_or_val.as>() = std::move(value); } - ++retrieved; + if(queued() > 0){ + ++retrieved; + } } // Event void fire() override; }; -class JoinConveyorNodeBase : public ConveyorNode, public ConveyorStorage { +class JoinConveyorNodeBase : public ConveyorStorage { public: virtual ~JoinConveyorNodeBase() = default; + + }; -template class JoinConveyorNode : public JoinConveyorNodeBase { +template class JoinConveyorNode final : public JoinConveyorNodeBase { +private: + T data; public: }; -template class JoinConveyorMerger : public ConveyorStorage { +class JoinConveyorMergerNodeBase : public ConveyorNode, public ConveyorStorage { +public: + +}; + +template class JoinConveyorMergerNode final : public JoinConveyorMergerBase { private: std::tuple...> joined; +public: + void getResult(ErrorOrValue &err_or_val) noexcept override { + + } + + void fire() override; }; +class UniteConveyorNodeBase : public ConveyorNode, public ConveyorStorage { +public: + virtual ~UniteConveyorNodeBase() = default; +}; + +template class UniteConveyorNode : public UniteConveyorNodeBase { +public: + virtual ~UniteConveyorNode() = default; +}; + +template class + } // namespace gin #include "async.tmpl.h" diff --git a/source/kelgin/io.h b/source/kelgin/io.h index 7c63c2c..3f42ccd 100644 --- a/source/kelgin/io.h +++ b/source/kelgin/io.h @@ -100,6 +100,9 @@ public: virtual Conveyor> connect() = 0; virtual std::string toString() const = 0; + + virtual const std::string &address() const = 0; + virtual uint16_t port() const = 0; }; class Network { diff --git a/source/kelgin/io_helpers.cpp b/source/kelgin/io_helpers.cpp index d13a04d..5718685 100644 --- a/source/kelgin/io_helpers.cpp +++ b/source/kelgin/io_helpers.cpp @@ -23,7 +23,6 @@ void ReadTaskAndStepHelper::readStep(InputStream &reader) { if (static_cast(n) >= task.min_length && static_cast(n) <= task.max_length) { if (read_done) { - // Accumulated bytes are not pushed read_done->feed(n + task.already_read); } read_task = std::nullopt; diff --git a/source/kelgin/io_wrapper.h b/source/kelgin/io_wrapper.h new file mode 100644 index 0000000..1782cce --- /dev/null +++ b/source/kelgin/io_wrapper.h @@ -0,0 +1,18 @@ +#pragma once + +#include "async.h" +#include "io.h" + +namespace gin { +/* +template +class StreamingIoPeer { +private: + Codec codec; +public: + void send(Outgoing&& outgoing); + + Conveyor startReadPump(); +}; +*/ +} \ No newline at end of file diff --git a/source/kelgin/tls/tls.cpp b/source/kelgin/tls/tls.cpp index e324b93..a2feea4 100644 --- a/source/kelgin/tls/tls.cpp +++ b/source/kelgin/tls/tls.cpp @@ -5,29 +5,203 @@ #include "io_helpers.h" +#include + +#include + namespace gin { class Tls::Impl { public: - Impl(){ + gnutls_certificate_credentials_t xcred; + +public: + Impl() { gnutls_global_init(); gnutls_certificate_allocate_credentials(&xcred); gnutls_certificate_set_x509_system_trust(xcred); } - ~Impl(){ + ~Impl() { gnutls_certificate_free_credentials(xcred); gnutls_global_deinit(); } }; -Tls::Tls(): - impl{heap()} -{} +static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size); +static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size); -Tls::~Tls(){} +Tls::Tls() : impl{heap()} {} + +Tls::~Tls() {} + +Tls::Impl &Tls::getImpl() { return *impl; } + +class TlsIoStream final : public IoStream { +private: + Own internal; + gnutls_session_t session_handle; -class TlsNetworkImpl final : public TlsNetwork { public: + TlsIoStream(Own internal_) : internal{std::move(internal_)} {} + + ~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + + ErrorOr read(void *buffer, size_t length) override { + ssize_t size = gnutls_record_recv(session_handle, buffer, length); + if (size < 0) { + if(gnutls_error_is_fatal(size) == 0){ + return recoverableError([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r"); + }else{ + return criticalError([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c"); + } + }else if(size == 0){ + return criticalError("Disconnected"); + } + + return static_cast(length); + } + + Conveyor readReady() override { return internal->readReady(); } + + Conveyor onReadDisconnected() override { + return internal->onReadDisconnected(); + } + + ErrorOr write(const void *buffer, size_t length) override { + ssize_t size = gnutls_record_send(session_handle, buffer, length); + if(size < 0){ + if(gnutls_error_is_fatal(size) == 0){ + return recoverableError([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r"); + }else{ + return criticalError([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c"); + } + } + + return static_cast(size); + } + + Conveyor writeReady() override { return internal->writeReady(); } + + gnutls_session_t &session() { return session_handle; } }; + +TlsServer::TlsServer(Own srv) : internal{std::move(srv)} {} + +Conveyor> TlsServer::accept() { + GIN_ASSERT(internal) { return Conveyor>{nullptr, nullptr}; } + return internal->accept().then([](Own stream) -> Own { + return heap(std::move(stream)); + }); +} + +TlsNetworkAddress::TlsNetworkAddress(Own net_addr, const std::string& host_name_, Tls &tls_) + : internal{std::move(net_addr)}, host_name{host_name_}, tls{tls_} {} + +Own TlsNetworkAddress::listen() { + GIN_ASSERT(internal) { return nullptr; } + return heap(internal->listen()); +} + +Conveyor> TlsNetworkAddress::connect() { + GIN_ASSERT(internal) { return Conveyor>{nullptr, nullptr}; } + return internal->connect().then([this]( + Own stream) -> ErrorOr> { + IoStream* inner_stream = stream.get(); + auto tls_stream = heap(std::move(stream)); + + auto &session = tls_stream->session(); + + gnutls_init(&session, GNUTLS_CLIENT); + + const std::string &addr = this->address(); + + gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(), + addr.size()); + + gnutls_set_default_priority(session); + gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, + tls.getImpl().xcred); + gnutls_session_set_verify_cert(session, addr.c_str(), 0); + + gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); + gnutls_transport_set_push_function(session, kelgin_tls_push_func); + gnutls_transport_set_pull_function(session, kelgin_tls_pull_func); + + // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + + int ret; + do { + ret = gnutls_handshake(session); + } while (ret < 0 && gnutls_error_is_fatal(ret) == 0); + + if(ret < 0){ + return criticalError("Couldn't create Tls connection"); + } + + return tls_stream; + }); +} + +static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size) { + IoStream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + ErrorOr length = stream->write(data, size); + if (length.isError() || !length.isValue()) { + if(length.isError()){ + std::cerr<<"*** Error: "<(length.value()); +} + +static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { + IoStream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + ErrorOr length = stream->read(data, size); + if (length.isError() || !length.isValue()) { + if(length.isError()){ + std::cerr<<"*** Error: "<(length.value()); +} + +const std::string &TlsNetworkAddress::address() const { + assert(internal); + return internal->address(); +} +uint16_t TlsNetworkAddress::port() const { + assert(internal); + return internal->port(); } + +std::string TlsNetworkAddress::toString() const { return internal->toString(); } + +TlsNetwork::TlsNetwork(Network &network) : internal{network} {} + +Conveyor> TlsNetwork::parseAddress(const std::string &addr, + uint16_t port) { + return internal.parseAddress(addr, port) + .then( + [this, addr, port](Own net) -> Own { + assert(net); + return heap(std::move(net), tls); + }); +} + +std::optional> setupTlsNetwork(Network &network) { + return std::nullopt; +} } // namespace gin diff --git a/source/kelgin/tls/tls.h b/source/kelgin/tls/tls.h index 0324e91..2a5c727 100644 --- a/source/kelgin/tls/tls.h +++ b/source/kelgin/tls/tls.h @@ -1,61 +1,69 @@ #pragma once -#include -#include +#include "../common.h" +#include "../io.h" #include +#include namespace gin { class Tls { -public: +private: class Impl; Own impl; - + +public: Tls(); ~Tls(); class Options { public: }; -}; -class TlsIoStream final : public IoStream { -private: - Own stream; -public: - TlsIoStream(Own str); - - size_t read(void* buffer, size_t length) override; - - Conveyor readReady() override; - - Conveyor onReadDisconnected() override; - - size_t write(const void* buffer, size_t length) override; - - Conveyor writeReady() override; + Impl &getImpl(); }; class TlsServer final : public Server { +private: + Own internal; +public: + TlsServer(Own srv); + + Conveyor> accept() override; }; class TlsNetworkAddress final : public NetworkAddress { +private: + Own internal; + std::string host_name; + Tls &tls; + public: + TlsNetworkAddress(Own net_addr, const std::string& host_name_, Tls &tls_); + Own listen() override; - Own connect() override; + Conveyor> connect() override; - std::string toString() override; + std::string toString() const override; + + const std::string &address() const override; + uint16_t port() const override; }; class TlsNetwork final : public Network { -public: - TlsNetwork(Network& network); +private: + Tls tls; + Network &internal; - Own parseAddress(const std::string& addr, uint16_t port = 0) override; +public: + TlsNetwork(Network &network); + + Conveyor> parseAddress(const std::string &addr, + uint16_t port = 0) override; }; -std::optional> setupTlsNetwork(Network& network); +std::optional> setupTlsNetwork(Network &network); } // namespace gin