From c742bc3f57cb00d84e2df034f757d4a39e3ade7e Mon Sep 17 00:00:00 2001 From: Claudius Holeksa Date: Sat, 29 Apr 2023 19:06:53 +0200 Subject: Added io tls with gnutls --- default.nix | 7 ++ forstio/io-tls/.nix/derivation.nix | 35 ++++++ forstio/io-tls/SConscript | 38 ++++++ forstio/io-tls/SConstruct | 66 ++++++++++ forstio/io-tls/tls.cpp | 250 +++++++++++++++++++++++++++++++++++++ forstio/io-tls/tls.h | 70 +++++++++++ 6 files changed, 466 insertions(+) create mode 100644 forstio/io-tls/.nix/derivation.nix create mode 100644 forstio/io-tls/SConscript create mode 100644 forstio/io-tls/SConstruct create mode 100644 forstio/io-tls/tls.cpp create mode 100644 forstio/io-tls/tls.h diff --git a/default.nix b/default.nix index 334b692..aff0363 100644 --- a/default.nix +++ b/default.nix @@ -31,5 +31,12 @@ in rec { clang = pkgs.clang_15; clang-tools = pkgs.clang-tools_15; }; + + io-tls = pkgs.callPackage forstio/io-tls/.nix/derivation.nix { + inherit version; + inherit forstio; + clang = pkgs.clang_15; + clang-tools = pkgs.clang-tools_15; + }; }; } diff --git a/forstio/io-tls/.nix/derivation.nix b/forstio/io-tls/.nix/derivation.nix new file mode 100644 index 0000000..7f142fb --- /dev/null +++ b/forstio/io-tls/.nix/derivation.nix @@ -0,0 +1,35 @@ +{ lib +, stdenvNoCC +, scons +, clang +, clang-tools +, version +, forstio +, gnutls +}: + +let + +in stdenvNoCC.mkDerivation { + pname = "forstio-io"; + inherit version; + + src = ./..; + + enableParallelBuilding = true; + + nativeBuildInputs = [ + scons + clang + clang-tools + ]; + + buildInputs = [ + forstio.core + forstio.async + forstio.io + gnutls + ]; + + outputs = ["out" "dev"]; +} diff --git a/forstio/io-tls/SConscript b/forstio/io-tls/SConscript new file mode 100644 index 0000000..4f88f37 --- /dev/null +++ b/forstio/io-tls/SConscript @@ -0,0 +1,38 @@ +#!/bin/false + +import os +import os.path +import glob + + +Import('env') + +dir_path = Dir('.').abspath + +# Environment for base library +io_tls_env = env.Clone(); + +io_tls_env.sources = sorted(glob.glob(dir_path + "/*.cpp")) +io_tls_env.headers = sorted(glob.glob(dir_path + "/*.h")) + +env.sources += io_tls_env.sources; +env.headers += io_tls_env.headers; + +## Shared lib +objects_shared = [] +io_tls_env.add_source_files(objects_shared, io_tls_env.sources, shared=True); +io_tls_env.library_shared = io_tls_env.SharedLibrary('#build/forstio-io-tls', [objects_shared]); + +## Static lib +objects_static = [] +io_tls_env.add_source_files(objects_static, io_tls_env.sources, shared=False); +io_tls_env.library_static = io_tls_env.StaticLibrary('#build/forstio-io-tls', [objects_static]); + +# Set Alias +env.Alias('library_io_tls', [io_tls_env.library_shared, io_tls_env.library_static]); + +env.targets += ['library_io_tls']; + +# Install +env.Install('$prefix/lib/', [io_tls_env.library_shared, io_tls_env.library_static]); +env.Install('$prefix/include/forstio/io/tls/', [io_tls_env.headers]); diff --git a/forstio/io-tls/SConstruct b/forstio/io-tls/SConstruct new file mode 100644 index 0000000..fbd8657 --- /dev/null +++ b/forstio/io-tls/SConstruct @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +import sys +import os +import os.path +import glob +import re + + +if sys.version_info < (3,): + def isbasestring(s): + return isinstance(s,basestring) +else: + def isbasestring(s): + return isinstance(s, (str,bytes)) + +def add_kel_source_files(self, sources, filetype, lib_env=None, shared=False, target_post=""): + + if isbasestring(filetype): + dir_path = self.Dir('.').abspath + filetype = sorted(glob.glob(dir_path+"/"+filetype)) + + for path in filetype: + target_name = re.sub( r'(.*?)(\.cpp|\.c\+\+)', r'\1' + target_post, path ) + if shared: + target_name+='.os' + sources.append( self.SharedObject( target=target_name, source=path ) ) + else: + target_name+='.o' + sources.append( self.StaticObject( target=target_name, source=path ) ) + pass + +def isAbsolutePath(key, dirname, env): + assert os.path.isabs(dirname), "%r must have absolute path syntax" % (key,) + +env_vars = Variables( + args=ARGUMENTS +) + +env_vars.Add('prefix', + help='Installation target location of build results and headers', + default='/usr/local/', + validator=isAbsolutePath +) + +env=Environment(ENV=os.environ, variables=env_vars, CPPPATH=[], + CPPDEFINES=['SAW_UNIX'], + CXXFLAGS=['-std=c++20','-g','-Wall','-Wextra'], + LIBS=['gnutls','forstio-io']) +env.__class__.add_source_files = add_kel_source_files +env.Tool('compilation_db'); +env.cdb = env.CompilationDatabase('compile_commands.json'); + +env.objects = []; +env.sources = []; +env.headers = []; +env.targets = []; + +Export('env') +SConscript('SConscript') + +env.Alias('cdb', env.cdb); +env.Alias('all', [env.targets]); +env.Default('all'); + +env.Alias('install', '$prefix') diff --git a/forstio/io-tls/tls.cpp b/forstio/io-tls/tls.cpp new file mode 100644 index 0000000..c1497bc --- /dev/null +++ b/forstio/io-tls/tls.cpp @@ -0,0 +1,250 @@ +#include "tls.h" + +#include +#include + +#include + +#include + +#include + +namespace saw { + +class Tls::Impl { +public: + gnutls_certificate_credentials_t xcred; + +public: + Impl() { + gnutls_global_init(); + gnutls_certificate_allocate_credentials(&xcred); + gnutls_certificate_set_x509_system_trust(xcred); + } + + ~Impl() { + gnutls_certificate_free_credentials(xcred); + gnutls_global_deinit(); + } +}; + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size); +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size); + +Tls::Tls() : impl{heap()} {} + +Tls::~Tls() {} + +Tls::Impl &Tls::getImpl() { return *impl; } + +class TlsIoStream final : public io_stream { +private: + own internal; + gnutls_session_t session_handle; + +public: + TlsIoStream(own internal_) : internal{std::move(internal_)} {} + + ~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); } + + error_or read(void *buffer, size_t length) override { + ssize_t size = gnutls_record_recv(session_handle, buffer, length); + if (size < 0) { + if(gnutls_error_is_fatal(size) == 0){ + return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r"); + }else{ + return critical_error([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c"); + } + }else if(size == 0){ + return critical_error("Disconnected"); + } + + return static_cast(length); + } + + conveyor read_ready() override { return internal->read_ready(); } + + conveyor on_read_disconnected() override { + return internal->on_read_disconnected(); + } + + error_or write(const void *buffer, size_t length) override { + ssize_t size = gnutls_record_send(session_handle, buffer, length); + if(size < 0){ + if(gnutls_error_is_fatal(size) == 0){ + return recoverable_error([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r"); + }else{ + return critical_error([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c"); + } + } + + return static_cast(size); + } + + conveyor write_ready() override { return internal->write_ready(); } + + gnutls_session_t &session() { return session_handle; } +}; + +TlsServer::TlsServer(own srv) : internal{std::move(srv)} {} + +conveyor> TlsServer::accept() { + SAW_ASSERT(internal) { return conveyor>{fix_void>{nullptr}}; } + return internal->accept().then([](own stream) -> own { + /// @todo handshake + + + return heap(std::move(stream)); + }); +} + +namespace { +/* +* Small helper for setting up the nonblocking connection handshake +*/ +struct TlsClientStreamHelper { +public: + own>> feeder; + conveyor_sink connection_sink; + conveyor_sink stream_reader; + conveyor_sink stream_writer; + + own stream = nullptr; +public: + TlsClientStreamHelper(own>> f): + feeder{std::move(f)} + {} + + void setupTurn(){ + SAW_ASSERT(stream){ + return; + } + + stream_reader = stream->read_ready().then([this](){ + turn(); + }).sink(); + + stream_writer = stream->write_ready().then([this](){ + turn(); + }).sink(); + } + + void turn(){ + if(stream){ + // Guarantee that the receiving end is already setup + SAW_ASSERT(feeder){ + return; + } + + auto &session = stream->session(); + + int ret; + do { + ret = gnutls_handshake(session); + } while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0); + + if(gnutls_error_is_fatal(ret)){ + feeder->fail(critical_error("Couldn't create Tls connection")); + stream = nullptr; + }else if(ret == GNUTLS_E_SUCCESS){ + feeder->feed(std::move(stream)); + } + } + } +}; +} + +own TlsNetwork::listen(network_address& address) { + return heap(internal.listen(address)); +} + +conveyor> TlsNetwork::connect(network_address& address) { + // Helper setups + auto caf = new_conveyor_and_feeder>(); + own helper = heap(std::move(caf.feeder)); + TlsClientStreamHelper* hlp_ptr = helper.get(); + + // Conveyor entangled structure + auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()]( + own stream) -> error_or { + io_stream* inner_stream = stream.get(); + auto tls_stream = heap(std::move(stream)); + + auto &session = tls_stream->session(); + + gnutls_init(&session, GNUTLS_CLIENT); + + gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(), + addr.size()); + + gnutls_set_default_priority(session); + gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, + tls.getImpl().xcred); + gnutls_session_set_verify_cert(session, addr.c_str(), 0); + + gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); + gnutls_transport_set_push_function(session, forst_tls_push_func); + gnutls_transport_set_pull_function(session, forst_tls_pull_func); + + // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); + + hlp_ptr->stream = std::move(tls_stream); + hlp_ptr->setupTurn(); + hlp_ptr->turn(); + + return void_t{}; + }); + + helper->connection_sink = prim_conv.sink(); + + return caf.conveyor.attach(std::move(helper)); +} + +own TlsNetwork::datagram(network_address& address){ + ///@unimplemented + return nullptr; +} + +static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, + size_t size) { + io_stream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + error_or length = stream->write(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast(length.value()); +} + +static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { + io_stream *stream = reinterpret_cast(p); + if (!stream) { + return -1; + } + + error_or length = stream->read(data, size); + if (length.is_error() || !length.is_value()) { + return -1; + } + + return static_cast(length.value()); +} + +TlsNetwork::TlsNetwork(Tls& tls_, network &network) : tls{tls_},internal{network} {} + +conveyor> TlsNetwork::resolve_address(const std::string &addr, + uint16_t port) { + /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But + /// it's better to find the error source sooner rather than later + return internal.resolve_address(addr, port); +} + +std::optional> setupTlsNetwork(network &network) { + return std::nullopt; +} +} // namespace saw diff --git a/forstio/io-tls/tls.h b/forstio/io-tls/tls.h new file mode 100644 index 0000000..8a31c1d --- /dev/null +++ b/forstio/io-tls/tls.h @@ -0,0 +1,70 @@ +#pragma once + +#include +#include + +#include +#include + +namespace saw { +class Tls; + +class TlsServer final : public server { +private: + own internal; + +public: + TlsServer(own srv); + + conveyor> accept() override; +}; + +class TlsNetwork final : public network { +private: + Tls& tls; + network &internal; +public: + TlsNetwork(Tls& tls_, network &network_); + + conveyor> resolve_address(const std::string &addr, uint16_t port = 0) override; + + own listen(network_address& address) override; + + conveyor> connect(network_address& address) override; + + own datagram(network_address& address) override; +}; + +/** +* Tls context class. +* Provides tls network class which ensures the usage of tls encrypted connections +*/ +class Tls { +private: + class Impl; + own impl; +public: + Tls(); + ~Tls(); + + struct Version { + struct Tls_1_0{}; + struct Tls_1_1{}; + struct Tls_1_2{}; + }; + + struct Options { + public: + Version version; + }; + + network& tlsNetwork(); + + Impl &getImpl(); +private: + Options options; +}; + +std::optional> setupTlsNetwork(network &network); + +} // namespace saw -- cgit v1.2.3