251 lines
6.8 KiB
C++
251 lines
6.8 KiB
C++
#include "tls.h"
|
|
|
|
#include <gnutls/gnutls.h>
|
|
#include <gnutls/x509.h>
|
|
|
|
#include "io_helpers.h"
|
|
|
|
#include <cassert>
|
|
|
|
#include <iostream>
|
|
|
|
namespace saw {
|
|
|
|
/**
 * Private state of Tls: the process-wide gnutls initialisation and the X.509
 * credentials that client sessions are configured with.
 *
 * Owns raw gnutls resources, so it is neither copyable nor movable (a copy
 * would free the credentials and deinit gnutls twice).
 */
class Tls::Impl {
public:
	/// Certificate credentials shared by sessions created through this Tls.
	/// Null if gnutls could not allocate them.
	gnutls_certificate_credentials_t xcred = nullptr;

public:
	Impl() {
		gnutls_global_init();

		if (gnutls_certificate_allocate_credentials(&xcred) == GNUTLS_E_SUCCESS) {
			// Trust the system CA store. The return value (number of loaded
			// certificates or a negative error) is currently not inspected.
			gnutls_certificate_set_x509_system_trust(xcred);
		} else {
			// Never hand an invalid handle to later gnutls calls.
			xcred = nullptr;
		}
	}

	~Impl() {
		if (xcred) {
			gnutls_certificate_free_credentials(xcred);
		}
		// Balances the gnutls_global_init() in the constructor.
		gnutls_global_deinit();
	}

	Impl(const Impl &) = delete;
	Impl &operator=(const Impl &) = delete;
	Impl(Impl &&) = delete;
	Impl &operator=(Impl &&) = delete;
};
|
|
|
|
// gnutls transport callbacks, defined further down in this file.
// `transport` is the pointer registered with gnutls_transport_set_ptr().
static ssize_t forst_tls_push_func(gnutls_transport_ptr_t transport, const void *src, size_t len);
static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t transport, void *dst, size_t len);
|
|
|
|
// Tls special members and accessor. They are defined in this translation unit,
// where Tls::Impl is a complete type.
Tls::Tls() : impl{heap<Tls::Impl>()} {}

Tls::~Tls() = default;

/// Access the private gnutls state (credentials) of this Tls instance.
Tls::Impl &Tls::getImpl() {
	return *impl;
}
|
|
|
|
class TlsIoStream final : public io_stream {
|
|
private:
|
|
own<io_stream> internal;
|
|
gnutls_session_t session_handle;
|
|
|
|
public:
|
|
TlsIoStream(own<io_stream> internal_) : internal{std::move(internal_)} {}
|
|
|
|
~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
|
|
|
|
error_or<size_t> read(void *buffer, size_t length) override {
|
|
ssize_t size = gnutls_record_recv(session_handle, buffer, length);
|
|
if (size < 0) {
|
|
if(gnutls_error_is_fatal(size) == 0){
|
|
return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r");
|
|
}else{
|
|
return critical_error([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c");
|
|
}
|
|
}else if(size == 0){
|
|
return critical_error("Disconnected");
|
|
}
|
|
|
|
return static_cast<size_t>(length);
|
|
}
|
|
|
|
conveyor<void> read_ready() override { return internal->read_ready(); }
|
|
|
|
conveyor<void> on_read_disconnected() override {
|
|
return internal->on_read_disconnected();
|
|
}
|
|
|
|
error_or<size_t> write(const void *buffer, size_t length) override {
|
|
ssize_t size = gnutls_record_send(session_handle, buffer, length);
|
|
if(size < 0){
|
|
if(gnutls_error_is_fatal(size) == 0){
|
|
return recoverable_error([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r");
|
|
}else{
|
|
return critical_error([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c");
|
|
}
|
|
}
|
|
|
|
return static_cast<size_t>(size);
|
|
}
|
|
|
|
conveyor<void> write_ready() override { return internal->write_ready(); }
|
|
|
|
gnutls_session_t &session() { return session_handle; }
|
|
};
|
|
|
|
TlsServer::TlsServer(own<server> srv) : internal{std::move(srv)} {}
|
|
|
|
conveyor<own<io_stream>> TlsServer::accept() {
|
|
SAW_ASSERT(internal) { return conveyor<own<io_stream>>{fix_void<own<io_stream>>{nullptr}}; }
|
|
return internal->accept().then([](own<io_stream> stream) -> own<io_stream> {
|
|
/// @todo handshake
|
|
|
|
|
|
return heap<TlsIoStream>(std::move(stream));
|
|
});
|
|
}
|
|
|
|
namespace {
|
|
/*
|
|
* Small helper for setting up the nonblocking connection handshake
|
|
*/
|
|
struct TlsClientStreamHelper {
|
|
public:
|
|
own<conveyor_feeder<own<io_stream>>> feeder;
|
|
conveyor_sink connection_sink;
|
|
conveyor_sink stream_reader;
|
|
conveyor_sink stream_writer;
|
|
|
|
own<TlsIoStream> stream = nullptr;
|
|
public:
|
|
TlsClientStreamHelper(own<conveyor_feeder<own<io_stream>>> f):
|
|
feeder{std::move(f)}
|
|
{}
|
|
|
|
void setupTurn(){
|
|
SAW_ASSERT(stream){
|
|
return;
|
|
}
|
|
|
|
stream_reader = stream->read_ready().then([this](){
|
|
turn();
|
|
}).sink();
|
|
|
|
stream_writer = stream->write_ready().then([this](){
|
|
turn();
|
|
}).sink();
|
|
}
|
|
|
|
void turn(){
|
|
if(stream){
|
|
// Guarantee that the receiving end is already setup
|
|
SAW_ASSERT(feeder){
|
|
return;
|
|
}
|
|
|
|
auto &session = stream->session();
|
|
|
|
int ret;
|
|
do {
|
|
ret = gnutls_handshake(session);
|
|
} while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0);
|
|
|
|
if(gnutls_error_is_fatal(ret)){
|
|
feeder->fail(critical_error("Couldn't create Tls connection"));
|
|
stream = nullptr;
|
|
}else if(ret == GNUTLS_E_SUCCESS){
|
|
feeder->feed(std::move(stream));
|
|
}
|
|
}
|
|
}
|
|
};
|
|
}
|
|
|
|
/// Listen on `address` through the wrapped network and wrap the resulting
/// server so that accepted streams are handed out as TlsIoStream objects.
own<server> TlsNetwork::listen(network_address& address) {
	auto plain_server = internal.listen(address);
	return heap<TlsServer>(std::move(plain_server));
}
|
|
|
|
/**
 * Connect to `address` and perform a client-side TLS handshake on the result.
 *
 * The returned conveyor yields the TLS stream once the handshake succeeds. It
 * fails if session setup or the handshake hits a fatal gnutls error.
 * The peer certificate is verified against `address.address()`.
 */
conveyor<own<io_stream>> TlsNetwork::connect(network_address& address) {
	// Helper setups
	auto caf = new_conveyor_and_feeder<own<io_stream>>();
	own<TlsClientStreamHelper> helper = heap<TlsClientStreamHelper>(std::move(caf.feeder));
	// Non-owning alias for the continuation below. `helper` is attached to the
	// returned conveyor and itself holds the connect chain in connection_sink.
	TlsClientStreamHelper* hlp_ptr = helper.get();

	// Conveyor entangled structure
	auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()](
		own<io_stream> stream) -> error_or<void> {
		// gnutls does its record I/O on the plain stream. Keep a raw pointer to it
		// before ownership moves into the TLS wrapper.
		io_stream* inner_stream = stream.get();
		auto tls_stream = heap<TlsIoStream>(std::move(stream));

		auto &session = tls_stream->session();

		// Configure the client session. Stop at the first failing call so a
		// null or half-configured session never reaches the handshake.
		int rc = gnutls_init(&session, GNUTLS_CLIENT);
		if(rc == GNUTLS_E_SUCCESS){
			// SNI: announce the host name we are connecting to.
			rc = gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
				addr.size());
		}
		if(rc == GNUTLS_E_SUCCESS){
			rc = gnutls_set_default_priority(session);
		}
		if(rc == GNUTLS_E_SUCCESS){
			rc = gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
				tls.getImpl().xcred);
		}
		if(rc != GNUTLS_E_SUCCESS){
			// Surface the failure to whoever waits on the returned conveyor.
			hlp_ptr->feeder->fail(critical_error("Couldn't setup Tls session"));
			return Void{};
		}

		// Have the handshake verify the peer certificate against `addr`.
		gnutls_session_set_verify_cert(session, addr.c_str(), 0);

		gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
		gnutls_transport_set_push_function(session, forst_tls_push_func);
		gnutls_transport_set_pull_function(session, forst_tls_pull_func);

		// gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);

		hlp_ptr->stream = std::move(tls_stream);
		// Re-drive the handshake on readiness events, then attempt it right away.
		hlp_ptr->setupTurn();
		hlp_ptr->turn();

		return Void{};
	});

	helper->connection_sink = prim_conv.sink();

	return caf.conveyor.attach(std::move(helper));
}
|
|
|
|
/// Not implemented: always returns nullptr.
own<datagram> TlsNetwork::datagram(network_address& address){
	///@unimplemented
	(void)address;
	return nullptr;
}
|
|
|
|
/**
 * gnutls push callback: write ciphertext to the plain io_stream registered
 * via gnutls_transport_set_ptr().
 * @return bytes written, or -1 on a null transport or any stream error.
 */
static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data,
		size_t size) {
	auto *plain = reinterpret_cast<io_stream *>(p);
	if (plain == nullptr) {
		return -1;
	}

	error_or<size_t> written = plain->write(data, size);
	const bool ok = !written.is_error() && written.is_value();
	if (!ok) {
		return -1;
	}
	return static_cast<ssize_t>(written.value());
}
|
|
|
|
/**
 * gnutls pull callback: read ciphertext from the plain io_stream registered
 * via gnutls_transport_set_ptr().
 * @return bytes read, or -1 on a null transport or any stream error.
 */
static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
	auto *plain = reinterpret_cast<io_stream *>(p);
	if (plain == nullptr) {
		return -1;
	}

	error_or<size_t> received = plain->read(data, size);
	const bool ok = !received.is_error() && received.is_value();
	if (!ok) {
		return -1;
	}
	return static_cast<ssize_t>(received.value());
}
|
|
|
|
TlsNetwork::TlsNetwork(Tls& tls_, network &network) : tls{tls_},internal{network} {}
|
|
|
|
conveyor<own<network_address>> TlsNetwork::resolve_address(const std::string &addr,
|
|
uint16_t port) {
|
|
/// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But
|
|
/// it's better to find the error source sooner rather than later
|
|
return internal.resolve_address(addr, port);
|
|
}
|
|
|
|
/// Not implemented yet: never constructs a TlsNetwork and always returns an
/// empty optional.
std::optional<own<TlsNetwork>> setupTlsNetwork(network &network) {
	(void)network;
	return {};
}
|
|
} // namespace saw
|