tls works for clients

This commit is contained in:
keldu.magnus 2021-06-18 00:17:18 +02:00
parent 5890ff86ad
commit 9435dd1b25
10 changed files with 288 additions and 42 deletions

View File

@ -36,9 +36,12 @@ env=Environment(CPPPATH=['#source/kelgin','#source','#','#driver'],
LIBS=['gnutls']) LIBS=['gnutls'])
env.__class__.add_source_files = add_kel_source_files env.__class__.add_source_files = add_kel_source_files
env.objects = []
env.sources = [] env.sources = []
env.headers = [] env.headers = []
env.objects = []
env.tls_sources = []
env.tls_headers = []
env.driver_sources = [] env.driver_sources = []
env.driver_headers = [] env.driver_headers = []
@ -52,11 +55,11 @@ SConscript('driver/SConscript')
env_library = env.Clone() env_library = env.Clone()
env.objects_shared = [] env.objects_shared = []
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources, shared=True) env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources + env.tls_sources, shared=True)
env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared]) env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared])
env.objects_static = [] env.objects_static = []
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources) env_library.add_source_files(env.objects_static, env.sources + env.driver_sources + env.tls_sources)
env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static]) env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static])
env.Alias('library', [env.library_shared, env.library_static]) env.Alias('library', [env.library_shared, env.library_static])
@ -86,5 +89,7 @@ env.Alias('all', ['format', 'library_shared', 'library_static', 'test'])
env.Install('/usr/local/lib/', [env.library_shared, env.library_static]) env.Install('/usr/local/lib/', [env.library_shared, env.library_static])
env.Install('/usr/local/include/kelgin/', [env.headers]) env.Install('/usr/local/include/kelgin/', [env.headers])
env.Install('/usr/local/include/kelgin/tls/', [env.tls_headers])
env.Install('/usr/local/include/kelgin/test/', [env.test_headers]) env.Install('/usr/local/include/kelgin/test/', [env.test_headers])
env.Alias('install', '/usr/local/') env.Alias('install', '/usr/local/')

View File

@ -213,6 +213,9 @@ std::string UnixNetworkAddress::toString() const {
return {}; return {};
} }
} }
const std::string &UnixNetworkAddress::address() const { return path; }
uint16_t UnixNetworkAddress::port() const { return port_hint; }
UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {} UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {}

View File

@ -404,6 +404,10 @@ public:
Conveyor<Own<IoStream>> connect() override; Conveyor<Own<IoStream>> connect() override;
std::string toString() const override; std::string toString() const override;
const std::string &address() const override;
uint16_t port() const override;
}; };
class UnixNetwork final : public Network { class UnixNetwork final : public Network {

View File

@ -11,3 +11,6 @@ dir_path = Dir('.').abspath
env.sources += sorted(glob.glob(dir_path + "/*.cpp")) env.sources += sorted(glob.glob(dir_path + "/*.cpp"))
env.headers += sorted(glob.glob(dir_path + "/*.h")) env.headers += sorted(glob.glob(dir_path + "/*.h"))
env.tls_sources += sorted(glob.glob(dir_path + "/tls/*.cpp"))
env.tls_headers += sorted(glob.glob(dir_path + "/tls/*.h"))

View File

@ -701,32 +701,61 @@ public:
// ConveyorNode // ConveyorNode
void getResult(ErrorOrValue &err_or_val) noexcept override { void getResult(ErrorOrValue &err_or_val) noexcept override {
if (retrieved) { if (retrieved > 0) {
err_or_val.as<FixVoid<T>>() = criticalError("Already taken value"); err_or_val.as<FixVoid<T>>() = criticalError("Already taken value");
} else { } else {
err_or_val.as<FixVoid<T>>() = std::move(value); err_or_val.as<FixVoid<T>>() = std::move(value);
} }
++retrieved; if(queued() > 0){
++retrieved;
}
} }
// Event // Event
void fire() override; void fire() override;
}; };
class JoinConveyorNodeBase : public ConveyorNode, public ConveyorStorage { class JoinConveyorNodeBase : public ConveyorStorage {
public: public:
virtual ~JoinConveyorNodeBase() = default; virtual ~JoinConveyorNodeBase() = default;
}; };
template <typename T> class JoinConveyorNode : public JoinConveyorNodeBase { template <typename T> class JoinConveyorNode final : public JoinConveyorNodeBase {
private:
T data;
public: public:
}; };
template <typename... Args> class JoinConveyorMerger : public ConveyorStorage { class JoinConveyorMergerNodeBase : public ConveyorNode, public ConveyorStorage {
public:
};
template <typename... Args> class JoinConveyorMergerNode final : public JoinConveyorMergerBase {
private: private:
std::tuple<JoinConveyorNode<Args>...> joined; std::tuple<JoinConveyorNode<Args>...> joined;
public:
void getResult(ErrorOrValue &err_or_val) noexcept override {
}
void fire() override;
}; };
class UniteConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
public:
virtual ~UniteConveyorNodeBase() = default;
};
template <typename T> class UniteConveyorNode : public UniteConveyorNodeBase {
public:
virtual ~UniteConveyorNode() = default;
};
template <typename T> class
} // namespace gin } // namespace gin
#include "async.tmpl.h" #include "async.tmpl.h"

View File

@ -100,6 +100,9 @@ public:
virtual Conveyor<Own<IoStream>> connect() = 0; virtual Conveyor<Own<IoStream>> connect() = 0;
virtual std::string toString() const = 0; virtual std::string toString() const = 0;
virtual const std::string &address() const = 0;
virtual uint16_t port() const = 0;
}; };
class Network { class Network {

View File

@ -23,7 +23,6 @@ void ReadTaskAndStepHelper::readStep(InputStream &reader) {
if (static_cast<size_t>(n) >= task.min_length && if (static_cast<size_t>(n) >= task.min_length &&
static_cast<size_t>(n) <= task.max_length) { static_cast<size_t>(n) <= task.max_length) {
if (read_done) { if (read_done) {
// Accumulated bytes are not pushed
read_done->feed(n + task.already_read); read_done->feed(n + task.already_read);
} }
read_task = std::nullopt; read_task = std::nullopt;

View File

@ -0,0 +1,18 @@
#pragma once
#include "async.h"
#include "io.h"
namespace gin {
/*
template<typename Codec, typename Incoming, typename Outgoing>
class StreamingIoPeer {
private:
Codec codec;
public:
void send(Outgoing&& outgoing);
Conveyor<Incoming> startReadPump();
};
*/
}

View File

@ -5,29 +5,203 @@
#include "io_helpers.h" #include "io_helpers.h"
#include <cassert>
#include <iostream>
namespace gin { namespace gin {
class Tls::Impl { class Tls::Impl {
public: public:
Impl(){ gnutls_certificate_credentials_t xcred;
public:
Impl() {
gnutls_global_init(); gnutls_global_init();
gnutls_certificate_allocate_credentials(&xcred); gnutls_certificate_allocate_credentials(&xcred);
gnutls_certificate_set_x509_system_trust(xcred); gnutls_certificate_set_x509_system_trust(xcred);
} }
~Impl(){ ~Impl() {
gnutls_certificate_free_credentials(xcred); gnutls_certificate_free_credentials(xcred);
gnutls_global_deinit(); gnutls_global_deinit();
} }
}; };
Tls::Tls(): static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
impl{heap<Tls::Impl>()} size_t size);
{} static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size);
Tls::~Tls(){} Tls::Tls() : impl{heap<Tls::Impl>()} {}
Tls::~Tls() {}
Tls::Impl &Tls::getImpl() { return *impl; }
class TlsIoStream final : public IoStream {
private:
Own<IoStream> internal;
gnutls_session_t session_handle;
class TlsNetworkImpl final : public TlsNetwork {
public: public:
TlsIoStream(Own<IoStream> internal_) : internal{std::move(internal_)} {}
~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
ErrorOr<size_t> read(void *buffer, size_t length) override {
ssize_t size = gnutls_record_recv(session_handle, buffer, length);
if (size < 0) {
if(gnutls_error_is_fatal(size) == 0){
return recoverableError([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r");
}else{
return criticalError([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c");
}
}else if(size == 0){
return criticalError("Disconnected");
}
return static_cast<size_t>(length);
}
Conveyor<void> readReady() override { return internal->readReady(); }
Conveyor<void> onReadDisconnected() override {
return internal->onReadDisconnected();
}
ErrorOr<size_t> write(const void *buffer, size_t length) override {
ssize_t size = gnutls_record_send(session_handle, buffer, length);
if(size < 0){
if(gnutls_error_is_fatal(size) == 0){
return recoverableError([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r");
}else{
return criticalError([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c");
}
}
return static_cast<size_t>(size);
}
Conveyor<void> writeReady() override { return internal->writeReady(); }
gnutls_session_t &session() { return session_handle; }
}; };
TlsServer::TlsServer(Own<Server> srv) : internal{std::move(srv)} {}
Conveyor<Own<IoStream>> TlsServer::accept() {
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
return internal->accept().then([](Own<IoStream> stream) -> Own<IoStream> {
return heap<TlsIoStream>(std::move(stream));
});
}
TlsNetworkAddress::TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_)
: internal{std::move(net_addr)}, host_name{host_name_}, tls{tls_} {}
Own<Server> TlsNetworkAddress::listen() {
GIN_ASSERT(internal) { return nullptr; }
return heap<TlsServer>(internal->listen());
}
Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
return internal->connect().then([this](
Own<IoStream> stream) -> ErrorOr<Own<IoStream>> {
IoStream* inner_stream = stream.get();
auto tls_stream = heap<TlsIoStream>(std::move(stream));
auto &session = tls_stream->session();
gnutls_init(&session, GNUTLS_CLIENT);
const std::string &addr = this->address();
gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
addr.size());
gnutls_set_default_priority(session);
gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
tls.getImpl().xcred);
gnutls_session_set_verify_cert(session, addr.c_str(), 0);
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
gnutls_transport_set_push_function(session, kelgin_tls_push_func);
gnutls_transport_set_pull_function(session, kelgin_tls_pull_func);
// gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
int ret;
do {
ret = gnutls_handshake(session);
} while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
if(ret < 0){
return criticalError("Couldn't create Tls connection");
}
return tls_stream;
});
}
static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
size_t size) {
IoStream *stream = reinterpret_cast<IoStream *>(p);
if (!stream) {
return -1;
}
ErrorOr<size_t> length = stream->write(data, size);
if (length.isError() || !length.isValue()) {
if(length.isError()){
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
}
return -1;
}
return static_cast<ssize_t>(length.value());
}
static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
IoStream *stream = reinterpret_cast<IoStream *>(p);
if (!stream) {
return -1;
}
ErrorOr<size_t> length = stream->read(data, size);
if (length.isError() || !length.isValue()) {
if(length.isError()){
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
}
return -1;
}
return static_cast<ssize_t>(length.value());
}
const std::string &TlsNetworkAddress::address() const {
assert(internal);
return internal->address();
}
uint16_t TlsNetworkAddress::port() const {
assert(internal);
return internal->port(); }
std::string TlsNetworkAddress::toString() const { return internal->toString(); }
TlsNetwork::TlsNetwork(Network &network) : internal{network} {}
Conveyor<Own<NetworkAddress>> TlsNetwork::parseAddress(const std::string &addr,
uint16_t port) {
return internal.parseAddress(addr, port)
.then(
[this, addr, port](Own<NetworkAddress> net) -> Own<NetworkAddress> {
assert(net);
return heap<TlsNetworkAddress>(std::move(net), tls);
});
}
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network) {
return std::nullopt;
}
} // namespace gin } // namespace gin

View File

@ -1,61 +1,69 @@
#pragma once #pragma once
#include <kelgin/common.h> #include "../common.h"
#include <kelgin/io.h> #include "../io.h"
#include <optional> #include <optional>
#include <variant>
namespace gin { namespace gin {
class Tls { class Tls {
public: private:
class Impl; class Impl;
Own<Impl> impl; Own<Impl> impl;
public:
Tls(); Tls();
~Tls(); ~Tls();
class Options { class Options {
public: public:
}; };
};
class TlsIoStream final : public IoStream { Impl &getImpl();
private:
Own<IoStream> stream;
public:
TlsIoStream(Own<IoStream> str);
size_t read(void* buffer, size_t length) override;
Conveyor<void> readReady() override;
Conveyor<void> onReadDisconnected() override;
size_t write(const void* buffer, size_t length) override;
Conveyor<void> writeReady() override;
}; };
class TlsServer final : public Server { class TlsServer final : public Server {
private:
Own<Server> internal;
public:
TlsServer(Own<Server> srv);
Conveyor<Own<IoStream>> accept() override;
}; };
class TlsNetworkAddress final : public NetworkAddress { class TlsNetworkAddress final : public NetworkAddress {
private:
Own<NetworkAddress> internal;
std::string host_name;
Tls &tls;
public: public:
TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_);
Own<Server> listen() override; Own<Server> listen() override;
Own<IoStream> connect() override; Conveyor<Own<IoStream>> connect() override;
std::string toString() override; std::string toString() const override;
const std::string &address() const override;
uint16_t port() const override;
}; };
class TlsNetwork final : public Network { class TlsNetwork final : public Network {
public: private:
TlsNetwork(Network& network); Tls tls;
Network &internal;
Own<NetworkAddress> parseAddress(const std::string& addr, uint16_t port = 0) override; public:
TlsNetwork(Network &network);
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &addr,
uint16_t port = 0) override;
}; };
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network& network); std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network);
} // namespace gin } // namespace gin