tls works for clients

This commit is contained in:
keldu.magnus 2021-06-18 00:17:18 +02:00
parent 5890ff86ad
commit 9435dd1b25
10 changed files with 288 additions and 42 deletions

View File

@ -36,9 +36,12 @@ env=Environment(CPPPATH=['#source/kelgin','#source','#','#driver'],
LIBS=['gnutls'])
env.__class__.add_source_files = add_kel_source_files
env.objects = []
env.sources = []
env.headers = []
env.objects = []
env.tls_sources = []
env.tls_headers = []
env.driver_sources = []
env.driver_headers = []
@ -52,11 +55,11 @@ SConscript('driver/SConscript')
env_library = env.Clone()
env.objects_shared = []
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources, shared=True)
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources + env.tls_sources, shared=True)
env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared])
env.objects_static = []
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources)
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources + env.tls_sources)
env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static])
env.Alias('library', [env.library_shared, env.library_static])
@ -86,5 +89,7 @@ env.Alias('all', ['format', 'library_shared', 'library_static', 'test'])
env.Install('/usr/local/lib/', [env.library_shared, env.library_static])
env.Install('/usr/local/include/kelgin/', [env.headers])
env.Install('/usr/local/include/kelgin/tls/', [env.tls_headers])
env.Install('/usr/local/include/kelgin/test/', [env.test_headers])
env.Alias('install', '/usr/local/')

View File

@ -213,6 +213,9 @@ std::string UnixNetworkAddress::toString() const {
return {};
}
}
const std::string &UnixNetworkAddress::address() const { return path; }
uint16_t UnixNetworkAddress::port() const { return port_hint; }
UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {}

View File

@ -404,6 +404,10 @@ public:
Conveyor<Own<IoStream>> connect() override;
std::string toString() const override;
const std::string &address() const override;
uint16_t port() const override;
};
class UnixNetwork final : public Network {

View File

@ -11,3 +11,6 @@ dir_path = Dir('.').abspath
env.sources += sorted(glob.glob(dir_path + "/*.cpp"))
env.headers += sorted(glob.glob(dir_path + "/*.h"))
env.tls_sources += sorted(glob.glob(dir_path + "/tls/*.cpp"))
env.tls_headers += sorted(glob.glob(dir_path + "/tls/*.h"))

View File

@ -701,32 +701,61 @@ public:
// ConveyorNode
void getResult(ErrorOrValue &err_or_val) noexcept override {
if (retrieved) {
if (retrieved > 0) {
err_or_val.as<FixVoid<T>>() = criticalError("Already taken value");
} else {
err_or_val.as<FixVoid<T>>() = std::move(value);
}
++retrieved;
if(queued() > 0){
++retrieved;
}
}
// Event
void fire() override;
};
class JoinConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
class JoinConveyorNodeBase : public ConveyorStorage {
public:
virtual ~JoinConveyorNodeBase() = default;
};
template <typename T> class JoinConveyorNode : public JoinConveyorNodeBase {
template <typename T> class JoinConveyorNode final : public JoinConveyorNodeBase {
private:
T data;
public:
};
template <typename... Args> class JoinConveyorMerger : public ConveyorStorage {
class JoinConveyorMergerNodeBase : public ConveyorNode, public ConveyorStorage {
public:
};
template <typename... Args> class JoinConveyorMergerNode final : public JoinConveyorMergerBase {
private:
std::tuple<JoinConveyorNode<Args>...> joined;
public:
void getResult(ErrorOrValue &err_or_val) noexcept override {
}
void fire() override;
};
class UniteConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
public:
virtual ~UniteConveyorNodeBase() = default;
};
template <typename T> class UniteConveyorNode : public UniteConveyorNodeBase {
public:
virtual ~UniteConveyorNode() = default;
};
template <typename T> class
} // namespace gin
#include "async.tmpl.h"

View File

@ -100,6 +100,9 @@ public:
virtual Conveyor<Own<IoStream>> connect() = 0;
virtual std::string toString() const = 0;
virtual const std::string &address() const = 0;
virtual uint16_t port() const = 0;
};
class Network {

View File

@ -23,7 +23,6 @@ void ReadTaskAndStepHelper::readStep(InputStream &reader) {
if (static_cast<size_t>(n) >= task.min_length &&
static_cast<size_t>(n) <= task.max_length) {
if (read_done) {
// Accumulated bytes are not pushed
read_done->feed(n + task.already_read);
}
read_task = std::nullopt;

View File

@ -0,0 +1,18 @@
#pragma once
#include "async.h"
#include "io.h"
namespace gin {
/*
template<typename Codec, typename Incoming, typename Outgoing>
class StreamingIoPeer {
private:
Codec codec;
public:
void send(Outgoing&& outgoing);
Conveyor<Incoming> startReadPump();
};
*/
}

View File

@ -5,29 +5,203 @@
#include "io_helpers.h"
#include <cassert>
#include <iostream>
namespace gin {
class Tls::Impl {
public:
Impl(){
gnutls_certificate_credentials_t xcred;
public:
Impl() {
gnutls_global_init();
gnutls_certificate_allocate_credentials(&xcred);
gnutls_certificate_set_x509_system_trust(xcred);
}
~Impl(){
~Impl() {
gnutls_certificate_free_credentials(xcred);
gnutls_global_deinit();
}
};
Tls::Tls():
impl{heap<Tls::Impl>()}
{}
static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
size_t size);
static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size);
Tls::~Tls(){}
Tls::Tls() : impl{heap<Tls::Impl>()} {}
Tls::~Tls() {}
Tls::Impl &Tls::getImpl() { return *impl; }
class TlsIoStream final : public IoStream {
private:
Own<IoStream> internal;
gnutls_session_t session_handle;
class TlsNetworkImpl final : public TlsNetwork {
public:
TlsIoStream(Own<IoStream> internal_) : internal{std::move(internal_)} {}
~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
ErrorOr<size_t> read(void *buffer, size_t length) override {
ssize_t size = gnutls_record_recv(session_handle, buffer, length);
if (size < 0) {
if(gnutls_error_is_fatal(size) == 0){
return recoverableError([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r");
}else{
return criticalError([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c");
}
}else if(size == 0){
return criticalError("Disconnected");
}
return static_cast<size_t>(length);
}
Conveyor<void> readReady() override { return internal->readReady(); }
Conveyor<void> onReadDisconnected() override {
return internal->onReadDisconnected();
}
ErrorOr<size_t> write(const void *buffer, size_t length) override {
ssize_t size = gnutls_record_send(session_handle, buffer, length);
if(size < 0){
if(gnutls_error_is_fatal(size) == 0){
return recoverableError([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r");
}else{
return criticalError([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c");
}
}
return static_cast<size_t>(size);
}
Conveyor<void> writeReady() override { return internal->writeReady(); }
gnutls_session_t &session() { return session_handle; }
};
TlsServer::TlsServer(Own<Server> srv) : internal{std::move(srv)} {}
Conveyor<Own<IoStream>> TlsServer::accept() {
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
return internal->accept().then([](Own<IoStream> stream) -> Own<IoStream> {
return heap<TlsIoStream>(std::move(stream));
});
}
TlsNetworkAddress::TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_)
: internal{std::move(net_addr)}, host_name{host_name_}, tls{tls_} {}
Own<Server> TlsNetworkAddress::listen() {
GIN_ASSERT(internal) { return nullptr; }
return heap<TlsServer>(internal->listen());
}
Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
return internal->connect().then([this](
Own<IoStream> stream) -> ErrorOr<Own<IoStream>> {
IoStream* inner_stream = stream.get();
auto tls_stream = heap<TlsIoStream>(std::move(stream));
auto &session = tls_stream->session();
gnutls_init(&session, GNUTLS_CLIENT);
const std::string &addr = this->address();
gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
addr.size());
gnutls_set_default_priority(session);
gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
tls.getImpl().xcred);
gnutls_session_set_verify_cert(session, addr.c_str(), 0);
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
gnutls_transport_set_push_function(session, kelgin_tls_push_func);
gnutls_transport_set_pull_function(session, kelgin_tls_pull_func);
// gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
int ret;
do {
ret = gnutls_handshake(session);
} while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
if(ret < 0){
return criticalError("Couldn't create Tls connection");
}
return tls_stream;
});
}
static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
size_t size) {
IoStream *stream = reinterpret_cast<IoStream *>(p);
if (!stream) {
return -1;
}
ErrorOr<size_t> length = stream->write(data, size);
if (length.isError() || !length.isValue()) {
if(length.isError()){
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
}
return -1;
}
return static_cast<ssize_t>(length.value());
}
static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
IoStream *stream = reinterpret_cast<IoStream *>(p);
if (!stream) {
return -1;
}
ErrorOr<size_t> length = stream->read(data, size);
if (length.isError() || !length.isValue()) {
if(length.isError()){
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
}
return -1;
}
return static_cast<ssize_t>(length.value());
}
const std::string &TlsNetworkAddress::address() const {
assert(internal);
return internal->address();
}
uint16_t TlsNetworkAddress::port() const {
assert(internal);
return internal->port(); }
std::string TlsNetworkAddress::toString() const { return internal->toString(); }
TlsNetwork::TlsNetwork(Network &network) : internal{network} {}
Conveyor<Own<NetworkAddress>> TlsNetwork::parseAddress(const std::string &addr,
uint16_t port) {
return internal.parseAddress(addr, port)
.then(
[this, addr, port](Own<NetworkAddress> net) -> Own<NetworkAddress> {
assert(net);
return heap<TlsNetworkAddress>(std::move(net), tls);
});
}
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network) {
return std::nullopt;
}
} // namespace gin

View File

@ -1,61 +1,69 @@
#pragma once
#include <kelgin/common.h>
#include <kelgin/io.h>
#include "../common.h"
#include "../io.h"
#include <optional>
#include <variant>
namespace gin {
class Tls {
public:
private:
class Impl;
Own<Impl> impl;
public:
Tls();
~Tls();
class Options {
public:
};
};
class TlsIoStream final : public IoStream {
private:
Own<IoStream> stream;
public:
TlsIoStream(Own<IoStream> str);
size_t read(void* buffer, size_t length) override;
Conveyor<void> readReady() override;
Conveyor<void> onReadDisconnected() override;
size_t write(const void* buffer, size_t length) override;
Conveyor<void> writeReady() override;
Impl &getImpl();
};
class TlsServer final : public Server {
private:
Own<Server> internal;
public:
TlsServer(Own<Server> srv);
Conveyor<Own<IoStream>> accept() override;
};
class TlsNetworkAddress final : public NetworkAddress {
private:
Own<NetworkAddress> internal;
std::string host_name;
Tls &tls;
public:
TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_);
Own<Server> listen() override;
Own<IoStream> connect() override;
Conveyor<Own<IoStream>> connect() override;
std::string toString() override;
std::string toString() const override;
const std::string &address() const override;
uint16_t port() const override;
};
class TlsNetwork final : public Network {
public:
TlsNetwork(Network& network);
private:
Tls tls;
Network &internal;
Own<NetworkAddress> parseAddress(const std::string& addr, uint16_t port = 0) override;
public:
TlsNetwork(Network &network);
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &addr,
uint16_t port = 0) override;
};
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network& network);
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network);
} // namespace gin