summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--default.nix7
-rw-r--r--forstio/io-tls/.nix/derivation.nix35
-rw-r--r--forstio/io-tls/SConscript38
-rw-r--r--forstio/io-tls/SConstruct66
-rw-r--r--forstio/io-tls/tls.cpp250
-rw-r--r--forstio/io-tls/tls.h70
6 files changed, 466 insertions, 0 deletions
diff --git a/default.nix b/default.nix
index 334b692..aff0363 100644
--- a/default.nix
+++ b/default.nix
@@ -31,5 +31,12 @@ in rec {
clang = pkgs.clang_15;
clang-tools = pkgs.clang-tools_15;
};
+
+ io-tls = pkgs.callPackage forstio/io-tls/.nix/derivation.nix {
+ inherit version;
+ inherit forstio;
+ clang = pkgs.clang_15;
+ clang-tools = pkgs.clang-tools_15;
+ };
};
}
diff --git a/forstio/io-tls/.nix/derivation.nix b/forstio/io-tls/.nix/derivation.nix
new file mode 100644
index 0000000..7f142fb
--- /dev/null
+++ b/forstio/io-tls/.nix/derivation.nix
@@ -0,0 +1,35 @@
+{ lib
+, stdenvNoCC
+, scons
+, clang
+, clang-tools
+, version
+, forstio
+, gnutls
+}:
+
+let
+
+in stdenvNoCC.mkDerivation {
+ pname = "forstio-io";
+ inherit version;
+
+ src = ./..;
+
+ enableParallelBuilding = true;
+
+ nativeBuildInputs = [
+ scons
+ clang
+ clang-tools
+ ];
+
+ buildInputs = [
+ forstio.core
+ forstio.async
+ forstio.io
+ gnutls
+ ];
+
+ outputs = ["out" "dev"];
+}
diff --git a/forstio/io-tls/SConscript b/forstio/io-tls/SConscript
new file mode 100644
index 0000000..4f88f37
--- /dev/null
+++ b/forstio/io-tls/SConscript
@@ -0,0 +1,38 @@
+#!/bin/false
+
+import os
+import os.path
+import glob
+
+
+Import('env')
+
+dir_path = Dir('.').abspath
+
+# Environment for base library
+io_tls_env = env.Clone();
+
+io_tls_env.sources = sorted(glob.glob(dir_path + "/*.cpp"))
+io_tls_env.headers = sorted(glob.glob(dir_path + "/*.h"))
+
+env.sources += io_tls_env.sources;
+env.headers += io_tls_env.headers;
+
+## Shared lib
+objects_shared = []
+io_tls_env.add_source_files(objects_shared, io_tls_env.sources, shared=True);
+io_tls_env.library_shared = io_tls_env.SharedLibrary('#build/forstio-io-tls', [objects_shared]);
+
+## Static lib
+objects_static = []
+io_tls_env.add_source_files(objects_static, io_tls_env.sources, shared=False);
+io_tls_env.library_static = io_tls_env.StaticLibrary('#build/forstio-io-tls', [objects_static]);
+
+# Set Alias
+env.Alias('library_io_tls', [io_tls_env.library_shared, io_tls_env.library_static]);
+
+env.targets += ['library_io_tls'];
+
+# Install
+env.Install('$prefix/lib/', [io_tls_env.library_shared, io_tls_env.library_static]);
+env.Install('$prefix/include/forstio/io/tls/', [io_tls_env.headers]);
diff --git a/forstio/io-tls/SConstruct b/forstio/io-tls/SConstruct
new file mode 100644
index 0000000..fbd8657
--- /dev/null
+++ b/forstio/io-tls/SConstruct
@@ -0,0 +1,66 @@
+#!/usr/bin/env python3
+
+import sys
+import os
+import os.path
+import glob
+import re
+
+
+if sys.version_info < (3,):
+ def isbasestring(s):
+ return isinstance(s,basestring)
+else:
+ def isbasestring(s):
+ return isinstance(s, (str,bytes))
+
+def add_kel_source_files(self, sources, filetype, lib_env=None, shared=False, target_post=""):
+
+ if isbasestring(filetype):
+ dir_path = self.Dir('.').abspath
+ filetype = sorted(glob.glob(dir_path+"/"+filetype))
+
+ for path in filetype:
+ target_name = re.sub( r'(.*?)(\.cpp|\.c\+\+)', r'\1' + target_post, path )
+ if shared:
+ target_name+='.os'
+ sources.append( self.SharedObject( target=target_name, source=path ) )
+ else:
+ target_name+='.o'
+ sources.append( self.StaticObject( target=target_name, source=path ) )
+ pass
+
+def isAbsolutePath(key, dirname, env):
+ assert os.path.isabs(dirname), "%r must have absolute path syntax" % (key,)
+
+env_vars = Variables(
+ args=ARGUMENTS
+)
+
+env_vars.Add('prefix',
+ help='Installation target location of build results and headers',
+ default='/usr/local/',
+ validator=isAbsolutePath
+)
+
+env=Environment(ENV=os.environ, variables=env_vars, CPPPATH=[],
+ CPPDEFINES=['SAW_UNIX'],
+ CXXFLAGS=['-std=c++20','-g','-Wall','-Wextra'],
+ LIBS=['gnutls','forstio-io'])
+env.__class__.add_source_files = add_kel_source_files
+env.Tool('compilation_db');
+env.cdb = env.CompilationDatabase('compile_commands.json');
+
+env.objects = [];
+env.sources = [];
+env.headers = [];
+env.targets = [];
+
+Export('env')
+SConscript('SConscript')
+
+env.Alias('cdb', env.cdb);
+env.Alias('all', [env.targets]);
+env.Default('all');
+
+env.Alias('install', '$prefix')
diff --git a/forstio/io-tls/tls.cpp b/forstio/io-tls/tls.cpp
new file mode 100644
index 0000000..c1497bc
--- /dev/null
+++ b/forstio/io-tls/tls.cpp
@@ -0,0 +1,250 @@
+#include "tls.h"
+
+#include <gnutls/gnutls.h>
+#include <gnutls/x509.h>
+
+#include <forstio/io/io_helpers.h>
+
+#include <cassert>
+
+#include <iostream>
+
+namespace saw {
+
+class Tls::Impl {
+public:
+ gnutls_certificate_credentials_t xcred;
+
+public:
+ Impl() {
+ gnutls_global_init();
+ gnutls_certificate_allocate_credentials(&xcred);
+ gnutls_certificate_set_x509_system_trust(xcred);
+ }
+
+ ~Impl() {
+ gnutls_certificate_free_credentials(xcred);
+ gnutls_global_deinit();
+ }
+};
+
+static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data,
+ size_t size);
+static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size);
+
+Tls::Tls() : impl{heap<Tls::Impl>()} {}
+
+Tls::~Tls() {}
+
+Tls::Impl &Tls::getImpl() { return *impl; }
+
+class TlsIoStream final : public io_stream {
+private:
+ own<io_stream> internal;
+ gnutls_session_t session_handle;
+
+public:
+ TlsIoStream(own<io_stream> internal_) : internal{std::move(internal_)} {}
+
+ ~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
+
+ error_or<size_t> read(void *buffer, size_t length) override {
+ ssize_t size = gnutls_record_recv(session_handle, buffer, length);
+ if (size < 0) {
+ if(gnutls_error_is_fatal(size) == 0){
+ return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r");
+ }else{
+ return critical_error([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c");
+ }
+ }else if(size == 0){
+ return critical_error("Disconnected");
+ }
+
+ return static_cast<size_t>(length);
+ }
+
+ conveyor<void> read_ready() override { return internal->read_ready(); }
+
+ conveyor<void> on_read_disconnected() override {
+ return internal->on_read_disconnected();
+ }
+
+ error_or<size_t> write(const void *buffer, size_t length) override {
+ ssize_t size = gnutls_record_send(session_handle, buffer, length);
+ if(size < 0){
+ if(gnutls_error_is_fatal(size) == 0){
+ return recoverable_error([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r");
+ }else{
+ return critical_error([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c");
+ }
+ }
+
+ return static_cast<size_t>(size);
+ }
+
+ conveyor<void> write_ready() override { return internal->write_ready(); }
+
+ gnutls_session_t &session() { return session_handle; }
+};
+
+TlsServer::TlsServer(own<server> srv) : internal{std::move(srv)} {}
+
+conveyor<own<io_stream>> TlsServer::accept() {
+ SAW_ASSERT(internal) { return conveyor<own<io_stream>>{fix_void<own<io_stream>>{nullptr}}; }
+ return internal->accept().then([](own<io_stream> stream) -> own<io_stream> {
+ /// @todo handshake
+
+
+ return heap<TlsIoStream>(std::move(stream));
+ });
+}
+
+namespace {
+/*
+* Small helper for setting up the nonblocking connection handshake
+*/
+struct TlsClientStreamHelper {
+public:
+ own<conveyor_feeder<own<io_stream>>> feeder;
+ conveyor_sink connection_sink;
+ conveyor_sink stream_reader;
+ conveyor_sink stream_writer;
+
+ own<TlsIoStream> stream = nullptr;
+public:
+ TlsClientStreamHelper(own<conveyor_feeder<own<io_stream>>> f):
+ feeder{std::move(f)}
+ {}
+
+ void setupTurn(){
+ SAW_ASSERT(stream){
+ return;
+ }
+
+ stream_reader = stream->read_ready().then([this](){
+ turn();
+ }).sink();
+
+ stream_writer = stream->write_ready().then([this](){
+ turn();
+ }).sink();
+ }
+
+ void turn(){
+ if(stream){
+ // Guarantee that the receiving end is already setup
+ SAW_ASSERT(feeder){
+ return;
+ }
+
+ auto &session = stream->session();
+
+ int ret;
+ do {
+ ret = gnutls_handshake(session);
+ } while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0);
+
+ if(gnutls_error_is_fatal(ret)){
+ feeder->fail(critical_error("Couldn't create Tls connection"));
+ stream = nullptr;
+ }else if(ret == GNUTLS_E_SUCCESS){
+ feeder->feed(std::move(stream));
+ }
+ }
+ }
+};
+}
+
+own<server> TlsNetwork::listen(network_address& address) {
+ return heap<TlsServer>(internal.listen(address));
+}
+
+conveyor<own<io_stream>> TlsNetwork::connect(network_address& address) {
+ // Helper setups
+ auto caf = new_conveyor_and_feeder<own<io_stream>>();
+ own<TlsClientStreamHelper> helper = heap<TlsClientStreamHelper>(std::move(caf.feeder));
+ TlsClientStreamHelper* hlp_ptr = helper.get();
+
+ // Conveyor entangled structure
+ auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()](
+ own<io_stream> stream) -> error_or<void> {
+ io_stream* inner_stream = stream.get();
+ auto tls_stream = heap<TlsIoStream>(std::move(stream));
+
+ auto &session = tls_stream->session();
+
+ gnutls_init(&session, GNUTLS_CLIENT);
+
+ gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
+ addr.size());
+
+ gnutls_set_default_priority(session);
+ gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
+ tls.getImpl().xcred);
+ gnutls_session_set_verify_cert(session, addr.c_str(), 0);
+
+ gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
+ gnutls_transport_set_push_function(session, forst_tls_push_func);
+ gnutls_transport_set_pull_function(session, forst_tls_pull_func);
+
+ // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
+
+ hlp_ptr->stream = std::move(tls_stream);
+ hlp_ptr->setupTurn();
+ hlp_ptr->turn();
+
+ return void_t{};
+ });
+
+ helper->connection_sink = prim_conv.sink();
+
+ return caf.conveyor.attach(std::move(helper));
+}
+
+own<datagram> TlsNetwork::datagram(network_address& address){
+ ///@unimplemented
+ return nullptr;
+}
+
+static ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data,
+ size_t size) {
+ io_stream *stream = reinterpret_cast<io_stream *>(p);
+ if (!stream) {
+ return -1;
+ }
+
+ error_or<size_t> length = stream->write(data, size);
+ if (length.is_error() || !length.is_value()) {
+ return -1;
+ }
+
+ return static_cast<ssize_t>(length.value());
+}
+
+static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
+ io_stream *stream = reinterpret_cast<io_stream *>(p);
+ if (!stream) {
+ return -1;
+ }
+
+ error_or<size_t> length = stream->read(data, size);
+ if (length.is_error() || !length.is_value()) {
+ return -1;
+ }
+
+ return static_cast<ssize_t>(length.value());
+}
+
+TlsNetwork::TlsNetwork(Tls& tls_, network &network) : tls{tls_},internal{network} {}
+
+conveyor<own<network_address>> TlsNetwork::resolve_address(const std::string &addr,
+ uint16_t port) {
+ /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But
+ /// it's better to find the error source sooner rather than later
+ return internal.resolve_address(addr, port);
+}
+
+std::optional<own<TlsNetwork>> setupTlsNetwork(network &network) {
+ return std::nullopt;
+}
+} // namespace saw
diff --git a/forstio/io-tls/tls.h b/forstio/io-tls/tls.h
new file mode 100644
index 0000000..8a31c1d
--- /dev/null
+++ b/forstio/io-tls/tls.h
@@ -0,0 +1,70 @@
+#pragma once
+
+#include <forstio/core/common.h>
+#include <forstio/io/io.h>
+
+#include <optional>
+#include <variant>
+
+namespace saw {
+class Tls;
+
+class TlsServer final : public server {
+private:
+ own<server> internal;
+
+public:
+ TlsServer(own<server> srv);
+
+ conveyor<own<io_stream>> accept() override;
+};
+
+class TlsNetwork final : public network {
+private:
+ Tls& tls;
+ network &internal;
+public:
+ TlsNetwork(Tls& tls_, network &network_);
+
+ conveyor<own<network_address>> resolve_address(const std::string &addr, uint16_t port = 0) override;
+
+ own<server> listen(network_address& address) override;
+
+ conveyor<own<io_stream>> connect(network_address& address) override;
+
+ own<class datagram> datagram(network_address& address) override;
+};
+
+/**
+* Tls context class.
+* Provides tls network class which ensures the usage of tls encrypted connections
+*/
+class Tls {
+private:
+ class Impl;
+ own<Impl> impl;
+public:
+ Tls();
+ ~Tls();
+
+ struct Version {
+ struct Tls_1_0{};
+ struct Tls_1_1{};
+ struct Tls_1_2{};
+ };
+
+ struct Options {
+ public:
+ Version version;
+ };
+
+ network& tlsNetwork();
+
+ Impl &getImpl();
+private:
+ Options options;
+};
+
+std::optional<own<TlsNetwork>> setupTlsNetwork(network &network);
+
+} // namespace saw