tls works for clients
This commit is contained in:
parent
5890ff86ad
commit
9435dd1b25
11
SConstruct
11
SConstruct
|
@ -36,9 +36,12 @@ env=Environment(CPPPATH=['#source/kelgin','#source','#','#driver'],
|
|||
LIBS=['gnutls'])
|
||||
env.__class__.add_source_files = add_kel_source_files
|
||||
|
||||
env.objects = []
|
||||
env.sources = []
|
||||
env.headers = []
|
||||
env.objects = []
|
||||
|
||||
env.tls_sources = []
|
||||
env.tls_headers = []
|
||||
|
||||
env.driver_sources = []
|
||||
env.driver_headers = []
|
||||
|
@ -52,11 +55,11 @@ SConscript('driver/SConscript')
|
|||
env_library = env.Clone()
|
||||
|
||||
env.objects_shared = []
|
||||
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources, shared=True)
|
||||
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources + env.tls_sources, shared=True)
|
||||
env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared])
|
||||
|
||||
env.objects_static = []
|
||||
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources)
|
||||
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources + env.tls_sources)
|
||||
env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static])
|
||||
|
||||
env.Alias('library', [env.library_shared, env.library_static])
|
||||
|
@ -86,5 +89,7 @@ env.Alias('all', ['format', 'library_shared', 'library_static', 'test'])
|
|||
|
||||
env.Install('/usr/local/lib/', [env.library_shared, env.library_static])
|
||||
env.Install('/usr/local/include/kelgin/', [env.headers])
|
||||
env.Install('/usr/local/include/kelgin/tls/', [env.tls_headers])
|
||||
|
||||
env.Install('/usr/local/include/kelgin/test/', [env.test_headers])
|
||||
env.Alias('install', '/usr/local/')
|
||||
|
|
|
@ -213,6 +213,9 @@ std::string UnixNetworkAddress::toString() const {
|
|||
return {};
|
||||
}
|
||||
}
|
||||
const std::string &UnixNetworkAddress::address() const { return path; }
|
||||
|
||||
uint16_t UnixNetworkAddress::port() const { return port_hint; }
|
||||
|
||||
UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {}
|
||||
|
||||
|
|
|
@ -404,6 +404,10 @@ public:
|
|||
Conveyor<Own<IoStream>> connect() override;
|
||||
|
||||
std::string toString() const override;
|
||||
|
||||
const std::string &address() const override;
|
||||
|
||||
uint16_t port() const override;
|
||||
};
|
||||
|
||||
class UnixNetwork final : public Network {
|
||||
|
|
|
@ -11,3 +11,6 @@ dir_path = Dir('.').abspath
|
|||
|
||||
env.sources += sorted(glob.glob(dir_path + "/*.cpp"))
|
||||
env.headers += sorted(glob.glob(dir_path + "/*.h"))
|
||||
|
||||
env.tls_sources += sorted(glob.glob(dir_path + "/tls/*.cpp"))
|
||||
env.tls_headers += sorted(glob.glob(dir_path + "/tls/*.h"))
|
||||
|
|
|
@ -701,32 +701,61 @@ public:
|
|||
|
||||
// ConveyorNode
|
||||
void getResult(ErrorOrValue &err_or_val) noexcept override {
|
||||
if (retrieved) {
|
||||
if (retrieved > 0) {
|
||||
err_or_val.as<FixVoid<T>>() = criticalError("Already taken value");
|
||||
} else {
|
||||
err_or_val.as<FixVoid<T>>() = std::move(value);
|
||||
}
|
||||
++retrieved;
|
||||
if(queued() > 0){
|
||||
++retrieved;
|
||||
}
|
||||
}
|
||||
|
||||
// Event
|
||||
void fire() override;
|
||||
};
|
||||
|
||||
class JoinConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
|
||||
class JoinConveyorNodeBase : public ConveyorStorage {
|
||||
public:
|
||||
virtual ~JoinConveyorNodeBase() = default;
|
||||
|
||||
|
||||
};
|
||||
|
||||
template <typename T> class JoinConveyorNode : public JoinConveyorNodeBase {
|
||||
template <typename T> class JoinConveyorNode final : public JoinConveyorNodeBase {
|
||||
private:
|
||||
T data;
|
||||
public:
|
||||
};
|
||||
|
||||
template <typename... Args> class JoinConveyorMerger : public ConveyorStorage {
|
||||
class JoinConveyorMergerNodeBase : public ConveyorNode, public ConveyorStorage {
|
||||
public:
|
||||
|
||||
};
|
||||
|
||||
template <typename... Args> class JoinConveyorMergerNode final : public JoinConveyorMergerBase {
|
||||
private:
|
||||
std::tuple<JoinConveyorNode<Args>...> joined;
|
||||
public:
|
||||
void getResult(ErrorOrValue &err_or_val) noexcept override {
|
||||
|
||||
}
|
||||
|
||||
void fire() override;
|
||||
};
|
||||
|
||||
class UniteConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
|
||||
public:
|
||||
virtual ~UniteConveyorNodeBase() = default;
|
||||
};
|
||||
|
||||
template <typename T> class UniteConveyorNode : public UniteConveyorNodeBase {
|
||||
public:
|
||||
virtual ~UniteConveyorNode() = default;
|
||||
};
|
||||
|
||||
template <typename T> class
|
||||
|
||||
} // namespace gin
|
||||
|
||||
#include "async.tmpl.h"
|
||||
|
|
|
@ -100,6 +100,9 @@ public:
|
|||
virtual Conveyor<Own<IoStream>> connect() = 0;
|
||||
|
||||
virtual std::string toString() const = 0;
|
||||
|
||||
virtual const std::string &address() const = 0;
|
||||
virtual uint16_t port() const = 0;
|
||||
};
|
||||
|
||||
class Network {
|
||||
|
|
|
@ -23,7 +23,6 @@ void ReadTaskAndStepHelper::readStep(InputStream &reader) {
|
|||
if (static_cast<size_t>(n) >= task.min_length &&
|
||||
static_cast<size_t>(n) <= task.max_length) {
|
||||
if (read_done) {
|
||||
// Accumulated bytes are not pushed
|
||||
read_done->feed(n + task.already_read);
|
||||
}
|
||||
read_task = std::nullopt;
|
||||
|
|
|
@ -0,0 +1,18 @@
|
|||
#pragma once
|
||||
|
||||
#include "async.h"
|
||||
#include "io.h"
|
||||
|
||||
namespace gin {
|
||||
/*
|
||||
template<typename Codec, typename Incoming, typename Outgoing>
|
||||
class StreamingIoPeer {
|
||||
private:
|
||||
Codec codec;
|
||||
public:
|
||||
void send(Outgoing&& outgoing);
|
||||
|
||||
Conveyor<Incoming> startReadPump();
|
||||
};
|
||||
*/
|
||||
}
|
|
@ -5,29 +5,203 @@
|
|||
|
||||
#include "io_helpers.h"
|
||||
|
||||
#include <cassert>
|
||||
|
||||
#include <iostream>
|
||||
|
||||
namespace gin {
|
||||
|
||||
class Tls::Impl {
|
||||
public:
|
||||
Impl(){
|
||||
gnutls_certificate_credentials_t xcred;
|
||||
|
||||
public:
|
||||
Impl() {
|
||||
gnutls_global_init();
|
||||
gnutls_certificate_allocate_credentials(&xcred);
|
||||
gnutls_certificate_set_x509_system_trust(xcred);
|
||||
}
|
||||
|
||||
~Impl(){
|
||||
~Impl() {
|
||||
gnutls_certificate_free_credentials(xcred);
|
||||
gnutls_global_deinit();
|
||||
}
|
||||
};
|
||||
|
||||
Tls::Tls():
|
||||
impl{heap<Tls::Impl>()}
|
||||
{}
|
||||
static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
|
||||
size_t size);
|
||||
static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size);
|
||||
|
||||
Tls::~Tls(){}
|
||||
Tls::Tls() : impl{heap<Tls::Impl>()} {}
|
||||
|
||||
Tls::~Tls() {}
|
||||
|
||||
Tls::Impl &Tls::getImpl() { return *impl; }
|
||||
|
||||
class TlsIoStream final : public IoStream {
|
||||
private:
|
||||
Own<IoStream> internal;
|
||||
gnutls_session_t session_handle;
|
||||
|
||||
class TlsNetworkImpl final : public TlsNetwork {
|
||||
public:
|
||||
TlsIoStream(Own<IoStream> internal_) : internal{std::move(internal_)} {}
|
||||
|
||||
~TlsIoStream() { gnutls_bye(session_handle, GNUTLS_SHUT_RDWR); }
|
||||
|
||||
ErrorOr<size_t> read(void *buffer, size_t length) override {
|
||||
ssize_t size = gnutls_record_recv(session_handle, buffer, length);
|
||||
if (size < 0) {
|
||||
if(gnutls_error_is_fatal(size) == 0){
|
||||
return recoverableError([size](){return std::string{"Read recoverable Error "}+std::string{gnutls_strerror(size)};}, "Error read r");
|
||||
}else{
|
||||
return criticalError([size](){return std::string{"Read critical Error "}+std::string{gnutls_strerror(size)};}, "Error read c");
|
||||
}
|
||||
}else if(size == 0){
|
||||
return criticalError("Disconnected");
|
||||
}
|
||||
|
||||
return static_cast<size_t>(length);
|
||||
}
|
||||
|
||||
Conveyor<void> readReady() override { return internal->readReady(); }
|
||||
|
||||
Conveyor<void> onReadDisconnected() override {
|
||||
return internal->onReadDisconnected();
|
||||
}
|
||||
|
||||
ErrorOr<size_t> write(const void *buffer, size_t length) override {
|
||||
ssize_t size = gnutls_record_send(session_handle, buffer, length);
|
||||
if(size < 0){
|
||||
if(gnutls_error_is_fatal(size) == 0){
|
||||
return recoverableError([size](){return std::string{"Write recoverable Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write r");
|
||||
}else{
|
||||
return criticalError([size](){return std::string{"Write critical Error "}+std::string{gnutls_strerror(size)} + " " + std::to_string(size);}, "Error write c");
|
||||
}
|
||||
}
|
||||
|
||||
return static_cast<size_t>(size);
|
||||
}
|
||||
|
||||
Conveyor<void> writeReady() override { return internal->writeReady(); }
|
||||
|
||||
gnutls_session_t &session() { return session_handle; }
|
||||
};
|
||||
|
||||
TlsServer::TlsServer(Own<Server> srv) : internal{std::move(srv)} {}
|
||||
|
||||
Conveyor<Own<IoStream>> TlsServer::accept() {
|
||||
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
|
||||
return internal->accept().then([](Own<IoStream> stream) -> Own<IoStream> {
|
||||
return heap<TlsIoStream>(std::move(stream));
|
||||
});
|
||||
}
|
||||
|
||||
TlsNetworkAddress::TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_)
|
||||
: internal{std::move(net_addr)}, host_name{host_name_}, tls{tls_} {}
|
||||
|
||||
Own<Server> TlsNetworkAddress::listen() {
|
||||
GIN_ASSERT(internal) { return nullptr; }
|
||||
return heap<TlsServer>(internal->listen());
|
||||
}
|
||||
|
||||
Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
|
||||
GIN_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
|
||||
return internal->connect().then([this](
|
||||
Own<IoStream> stream) -> ErrorOr<Own<IoStream>> {
|
||||
IoStream* inner_stream = stream.get();
|
||||
auto tls_stream = heap<TlsIoStream>(std::move(stream));
|
||||
|
||||
auto &session = tls_stream->session();
|
||||
|
||||
gnutls_init(&session, GNUTLS_CLIENT);
|
||||
|
||||
const std::string &addr = this->address();
|
||||
|
||||
gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
|
||||
addr.size());
|
||||
|
||||
gnutls_set_default_priority(session);
|
||||
gnutls_credentials_set(session, GNUTLS_CRD_CERTIFICATE,
|
||||
tls.getImpl().xcred);
|
||||
gnutls_session_set_verify_cert(session, addr.c_str(), 0);
|
||||
|
||||
gnutls_transport_set_ptr(session, reinterpret_cast<gnutls_transport_ptr_t>(inner_stream));
|
||||
gnutls_transport_set_push_function(session, kelgin_tls_push_func);
|
||||
gnutls_transport_set_pull_function(session, kelgin_tls_pull_func);
|
||||
|
||||
// gnutls_handshake_set_timeout(session, GNUTLS_DEFAULT_HANDSHAKE_TIMEOUT);
|
||||
|
||||
int ret;
|
||||
do {
|
||||
ret = gnutls_handshake(session);
|
||||
} while (ret < 0 && gnutls_error_is_fatal(ret) == 0);
|
||||
|
||||
if(ret < 0){
|
||||
return criticalError("Couldn't create Tls connection");
|
||||
}
|
||||
|
||||
return tls_stream;
|
||||
});
|
||||
}
|
||||
|
||||
static ssize_t kelgin_tls_push_func(gnutls_transport_ptr_t p, const void *data,
|
||||
size_t size) {
|
||||
IoStream *stream = reinterpret_cast<IoStream *>(p);
|
||||
if (!stream) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
ErrorOr<size_t> length = stream->write(data, size);
|
||||
if (length.isError() || !length.isValue()) {
|
||||
if(length.isError()){
|
||||
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
return static_cast<ssize_t>(length.value());
|
||||
}
|
||||
|
||||
static ssize_t kelgin_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t size) {
|
||||
IoStream *stream = reinterpret_cast<IoStream *>(p);
|
||||
if (!stream) {
|
||||
return -1;
|
||||
}
|
||||
|
||||
ErrorOr<size_t> length = stream->read(data, size);
|
||||
if (length.isError() || !length.isValue()) {
|
||||
if(length.isError()){
|
||||
std::cerr<<"*** Error: "<<length.error().message()<<std::endl;
|
||||
}
|
||||
return -1;
|
||||
}
|
||||
|
||||
return static_cast<ssize_t>(length.value());
|
||||
}
|
||||
|
||||
const std::string &TlsNetworkAddress::address() const {
|
||||
assert(internal);
|
||||
return internal->address();
|
||||
}
|
||||
uint16_t TlsNetworkAddress::port() const {
|
||||
assert(internal);
|
||||
return internal->port(); }
|
||||
|
||||
std::string TlsNetworkAddress::toString() const { return internal->toString(); }
|
||||
|
||||
TlsNetwork::TlsNetwork(Network &network) : internal{network} {}
|
||||
|
||||
Conveyor<Own<NetworkAddress>> TlsNetwork::parseAddress(const std::string &addr,
|
||||
uint16_t port) {
|
||||
return internal.parseAddress(addr, port)
|
||||
.then(
|
||||
[this, addr, port](Own<NetworkAddress> net) -> Own<NetworkAddress> {
|
||||
assert(net);
|
||||
return heap<TlsNetworkAddress>(std::move(net), tls);
|
||||
});
|
||||
}
|
||||
|
||||
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network) {
|
||||
return std::nullopt;
|
||||
}
|
||||
} // namespace gin
|
||||
|
|
|
@ -1,61 +1,69 @@
|
|||
#pragma once
|
||||
|
||||
#include <kelgin/common.h>
|
||||
#include <kelgin/io.h>
|
||||
#include "../common.h"
|
||||
#include "../io.h"
|
||||
|
||||
#include <optional>
|
||||
#include <variant>
|
||||
|
||||
namespace gin {
|
||||
class Tls {
|
||||
public:
|
||||
private:
|
||||
class Impl;
|
||||
Own<Impl> impl;
|
||||
|
||||
|
||||
public:
|
||||
Tls();
|
||||
~Tls();
|
||||
|
||||
class Options {
|
||||
public:
|
||||
};
|
||||
};
|
||||
|
||||
class TlsIoStream final : public IoStream {
|
||||
private:
|
||||
Own<IoStream> stream;
|
||||
public:
|
||||
TlsIoStream(Own<IoStream> str);
|
||||
|
||||
size_t read(void* buffer, size_t length) override;
|
||||
|
||||
Conveyor<void> readReady() override;
|
||||
|
||||
Conveyor<void> onReadDisconnected() override;
|
||||
|
||||
size_t write(const void* buffer, size_t length) override;
|
||||
|
||||
Conveyor<void> writeReady() override;
|
||||
Impl &getImpl();
|
||||
};
|
||||
|
||||
class TlsServer final : public Server {
|
||||
private:
|
||||
Own<Server> internal;
|
||||
|
||||
public:
|
||||
TlsServer(Own<Server> srv);
|
||||
|
||||
Conveyor<Own<IoStream>> accept() override;
|
||||
};
|
||||
|
||||
class TlsNetworkAddress final : public NetworkAddress {
|
||||
private:
|
||||
Own<NetworkAddress> internal;
|
||||
std::string host_name;
|
||||
Tls &tls;
|
||||
|
||||
public:
|
||||
TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_);
|
||||
|
||||
Own<Server> listen() override;
|
||||
|
||||
Own<IoStream> connect() override;
|
||||
Conveyor<Own<IoStream>> connect() override;
|
||||
|
||||
std::string toString() override;
|
||||
std::string toString() const override;
|
||||
|
||||
const std::string &address() const override;
|
||||
uint16_t port() const override;
|
||||
};
|
||||
|
||||
class TlsNetwork final : public Network {
|
||||
public:
|
||||
TlsNetwork(Network& network);
|
||||
private:
|
||||
Tls tls;
|
||||
Network &internal;
|
||||
|
||||
Own<NetworkAddress> parseAddress(const std::string& addr, uint16_t port = 0) override;
|
||||
public:
|
||||
TlsNetwork(Network &network);
|
||||
|
||||
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &addr,
|
||||
uint16_t port = 0) override;
|
||||
};
|
||||
|
||||
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network& network);
|
||||
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network);
|
||||
|
||||
} // namespace gin
|
||||
|
|
Loading…
Reference in New Issue