#include #include #include #include namespace saw { class tls::impl { public: gnutls_certificate_credentials_t xcred; public: impl() { gnutls_global_init(); gnutls_certificate_allocate_credentials(&xcred); gnutls_certificate_set_x509_system_trust(xcred); } ~impl() { gnutls_certificate_free_credentials(xcred); gnutls_global_deinit(); } }; template ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, size_t size); template ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size); tls::tls() : impl_{heap()} {} tls::~tls() {} tls::impl &tls::get_impl() { return *impl_; } template class tls_io_stream final : public io_stream> { private: own> internal_; gnutls_certificate_credentials_t xcred_; gnutls_session_t session_handle_; public: /* tls_io_stream(own> internal__, gnutls_certificate_credentials_t xcred__, gnutls_session_t session_handle__): internal_{std::move(internal__)}, xcred_{xcred__}, session_handle_{session_handle__} {} */ tls_io_stream(own> internal__): internal_{std::move(internal__)}, xcred_{}, session_handle_{} {} ~tls_io_stream() { gnutls_bye(session_handle_, GNUTLS_SHUT_RDWR); } error_or read(void *buffer, size_t length) override { ssize_t size = gnutls_record_recv(session_handle_, buffer, length); if (size < 0) { if(gnutls_error_is_fatal(size) == 0){ return make_error("Recoverable error on read in gnutls. TODO better error msg handling"); // Leaving proper message handling done in previous error framework //return recoverable_error([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r"); }else{ return make_error("Fatal error on read in gnutls. TODO better error msg handling"); } }else if(size == 0){ return make_error(); } return static_cast(length); } conveyor read_ready() override { return internal_->read_ready(); } conveyor on_read_disconnected() override { return internal_->on_read_disconnected(); } error_or write(const void *buffer, size_t length) override { ssize_t size = gnutls_record_send(session_handle_, buffer, length); if(size < 0){ if(gnutls_error_is_fatal(size) == 0){ return make_error("Recoverable error on write in gnutls. TODO better error msg handling"); }else{ return make_error("Fatal error on write in gnutls. TODO better error msg handling"); } } return static_cast(size); } conveyor write_ready() override { return internal_->write_ready(); } gnutls_session_t &session() { return session_handle_; } }; template class tls_server final : public server> { private: own> internal_; gnutls_certificate_credentials_t xcred_; gnutls_session_t session_handle_; public: tls_server(own> internal__, gnutls_certificate_credentials_t xcred__): internal_{std::move(internal__)} {} ~tls_server() { gnutls_bye(session_handle_, GNUTLS_SHUT_RDWR); gnutls_certificate_free_credentials(xcred_); } conveyor>>> accept() override; }; template class tls_network final : public network> { private: own tls_; ref> internal_; public: tls_network(own tls_, network &network_); conveyor>>> resolve_address(const std::string &addr, uint16_t port = 0) override; error_or>>> parse_address(const std::string &addr, uint16_t port = 0) override { return make_error(); } error_or>>> listen(network_address>& address) override; conveyor>>> connect(network_address>& address) override; }; namespace { /* * Small helper for setting up the nonblocking connection handshake */ template struct tls_client_stream_helper { public: own>>>> feeder; conveyor_sink connection_sink; conveyor_sink stream_reader; conveyor_sink stream_writer; own> stream = nullptr; public: tls_client_stream_helper(own>>>> f): feeder{std::move(f)} {} void setupTurn(){ SAW_ASSERT(stream){ return; } stream_reader = stream->read_ready().then([this](){ turn(); }).sink(); stream_writer = stream->write_ready().then([this](){ turn(); }).sink(); } void turn(){ if(stream){ // Guarantee that the receiving end is already setup SAW_ASSERT(feeder){ return; } auto &session = stream->session(); int ret; do { ret = gnutls_handshake(session); } while ( (ret == GNUTLS_E_AGAIN || ret == GNUTLS_E_INTERRUPTED) && gnutls_error_is_fatal(ret) == 0); if(gnutls_error_is_fatal(ret)){ feeder->fail(make_error("Couldn't create Tls connection")); stream = nullptr; }else if(ret == GNUTLS_E_SUCCESS){ feeder->feed(std::move(stream)); } } } }; } template error_or>>> tls_network::listen(network_address>& address) { gnutls_certificate_credentials_t x509_cred; gnutls_certificate_allocate_credentials(&x509_cred); std::string_view KEYFILE = "key.pem"; std::string_view CERTFILE = "cert.pem"; std::string_view CAFILE = "/etc/ssl/certs/ca-certificates.crt"; std::string_view CRLFILE = "crl.pem"; gnutls_certificate_set_x509_trust_file(x509_cred, CAFILE, GNUTLS_X509_FMT_PEM); gnutls_certificate_set_x509_crl_file(x509_cred, CRLFILE, GNUTLS_X509_FMT_PEM); gnutls_certificate_set_x509_key_file(x509_cred, CERTFILE, KEYFILE, GNUTLS_X509_FMT_PEM); gnutls_certificate_set_x509_ocsp_status_request_file(x509_cred, OCSP_STATUS_FILE, 0); auto int_srv = internal_().listen(address.get_handle()); own>> tls_srv = heap>(std::move(int_srv), x509_cred); return tls_srv; } template conveyor>>> tls_server::accept() { SAW_ASSERT(internal_) { return conveyor>>>{fix_void>>>{nullptr}}; } auto caf = new_conveyor_and_feeder>>>(); own> helper = heap>(std::move(caf.feeder)); tls_client_stream_helper* hlp_ptr = helper.get(); auto prim_conv = internal_->accept().then([&](own> stream) -> error_or { io_stream* inner_stream = stream.get(); auto tls_stream = heap>(std::move(stream)); auto &session = tls_stream->session(); gnutls_init(&session, GNUTLS_SERVER); gnutls_certificate_server_set_request(session, GNUTLS_CERT_IGNORE); gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, xcred_); gnutls_set_default_priority(session); gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); gnutls_transport_set_push_function(session, forst_tls_push_func); gnutls_transport_set_pull_function(session, forst_tls_pull_func); // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); hlp_ptr->stream = std::move(tls_stream); hlp_ptr->setupTurn(); hlp_ptr->turn(); return void_t{}; }); helper->connection_sink = prim_conv.sink(); return caf.conveyor.attach(std::move(helper)); } template conveyor>>> tls_network::connect(network_address>& address) { // Helper setups auto caf = new_conveyor_and_feeder>>>(); own> helper = heap>(std::move(caf.feeder)); tls_client_stream_helper* hlp_ptr = helper.get(); // Conveyor entangled structure auto prim_conv = internal_().connect(address.get_handle()).then([this, hlp_ptr, addr = address.get_handle().address()]( own> stream) -> error_or { io_stream* inner_stream = stream.get(); auto tls_stream = heap>(std::move(stream)); auto &session = tls_stream->session(); gnutls_init(&session, GNUTLS_CLIENT); gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(), addr.size()); gnutls_set_default_priority(session); gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE, tls_->get_impl().xcred); gnutls_session_set_verify_cert(session, addr.c_str(), 0); gnutls_transport_set_ptr(session, reinterpret_cast(inner_stream)); gnutls_transport_set_push_function(session, forst_tls_push_func); gnutls_transport_set_pull_function(session, forst_tls_pull_func); // gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT); hlp_ptr->stream = std::move(tls_stream); hlp_ptr->setupTurn(); hlp_ptr->turn(); return void_t{}; }); helper->connection_sink = prim_conv.sink(); return caf.conveyor.attach(std::move(helper)); } template ssize_t forst_tls_push_func(gnutls_transport_ptr_t p, const void *data, size_t size) { io_stream *stream = reinterpret_cast*>(p); if (!stream) { return -1; } error_or length = stream->write(data, size); if (length.is_error() || !length.is_value()) { return -1; } return static_cast(length.get_value()); } template ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) { io_stream *stream = reinterpret_cast*>(p); if (!stream) { return -1; } error_or length = stream->read(data, size); if (length.is_error() || !length.is_value()) { return -1; } return static_cast(length.get_value()); } template tls_network::tls_network(own tls__, network &network) : tls_{std::move(tls__)},internal_{network} {} template conveyor>>> tls_network::resolve_address(const std::string &addr, uint16_t port) { /// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But /// it's better to find the error source sooner rather than later return internal_().resolve_address(addr, port).then([](auto net_addr){ return heap>>(std::move(net_addr)); }); } template error_or>>> setup_tls_network(network &net) { auto tls_ctx = heap(); own>> tls_net = heap>(std::move(tls_ctx), net); return tls_net; } } // namespace saw