moving helpers, restructuring for tls abstraction

fb-windows
keldu.magnus 2021-05-30 00:46:07 +02:00
parent 00fac70f2e
commit 0b58504233
18 changed files with 218 additions and 227 deletions

View File

@ -40,6 +40,9 @@ env.sources = []
env.headers = []
env.objects = []
env.driver_sources = []
env.driver_headers = []
Export('env')
SConscript('source/kelgin/SConscript')
SConscript('driver/SConscript')
@ -49,11 +52,11 @@ SConscript('driver/SConscript')
env_library = env.Clone()
env.objects_shared = []
env_library.add_source_files(env.objects_shared, env.sources, shared=True)
env_library.add_source_files(env.objects_shared, env.sources + env.driver_sources, shared=True)
env.library_shared = env_library.SharedLibrary('#bin/kelgin', [env.objects_shared])
env.objects_static = []
env_library.add_source_files(env.objects_static, env.sources)
env_library.add_source_files(env.objects_static, env.sources + env.driver_sources)
env.library_static = env_library.StaticLibrary('#bin/kelgin', [env.objects_static])
env.Alias('library', [env.library_shared, env.library_static])
@ -75,7 +78,7 @@ def format_iter(env,files):
env.format_actions.append(env.AlwaysBuild(env.ClangFormat(target=f+"-clang-format",source=f)))
pass
format_iter(env,env.sources + env.headers)
format_iter(env,env.sources + env.driver_sources + env.headers + env.driver_headers)
env.Alias('format', env.format_actions)

View File

@ -9,5 +9,7 @@ Import('env')
dir_path = Dir('.').abspath
env.sources += sorted(glob.glob(dir_path + "/*.cpp"))
env.headers += sorted(glob.glob(dir_path + "/*.h"))
env.driver_sources += sorted(glob.glob(dir_path + "/tls/*.cpp"))
env.driver_sources += sorted(glob.glob(dir_path + "/*.cpp"))
env.driver_headers += sorted(glob.glob(dir_path + "/*.h"))

View File

@ -3,6 +3,7 @@
#include <sstream>
namespace gin {
namespace unix {
IFdOwner::IFdOwner(UnixEventPort &event_port, int file_descriptor, int fd_flags,
uint32_t event_mask)
: event_port{event_port}, file_descriptor{file_descriptor},
@ -17,86 +18,21 @@ IFdOwner::~IFdOwner() {
}
}
ssize_t UnixIoStream::dataRead(void *buffer, size_t length) {
return ::read(fd(), buffer, length);
ssize_t unixRead(int fd, void* buffer, size_t length){
return ::read(fd, buffer, length);
}
ssize_t UnixIoStream::dataWrite(const void *buffer, size_t length) {
return ::write(fd(), buffer, length);
}
/*
void UnixIoStream::readStep() {
if (read_ready) {
read_ready->feed();
}
while (!read_tasks.empty()) {
ReadIoTask &task = read_tasks.front();
ssize_t n = ::read(fd(), task.buffer, task.max_length);
if (n <= 0) {
if (n == 0) {
if (on_read_disconnect) {
on_read_disconnect->feed();
}
break;
}
int error = errno;
if (error == EAGAIN || error == EWOULDBLOCK) {
break;
} else {
if (read_done) {
read_done->fail(criticalError("Read failed"));
}
read_tasks.pop();
}
} else if (static_cast<size_t>(n) >= task.min_length &&
static_cast<size_t>(n) <= task.max_length) {
if (read_done) {
read_done->feed(static_cast<size_t>(n));
}
size_t max_len = task.max_length;
read_tasks.pop();
} else {
task.buffer = reinterpret_cast<uint8_t *>(task.buffer) + n;
task.min_length -= static_cast<size_t>(n);
task.max_length -= static_cast<size_t>(n);
}
}
ssize_t unixWrite(int fd, const void* buffer, size_t length){
return ::write(fd, buffer, length);
}
void UnixIoStream::writeStep() {
if (write_ready) {
write_ready->feed();
}
while (!write_tasks.empty()) {
WriteIoTask &task = write_tasks.front();
ssize_t n = ::write(fd(), task.buffer, task.length);
if (n < 0) {
int error = errno;
if (error == EAGAIN || error == EWOULDBLOCK) {
break;
} else {
if (write_done) {
write_done->fail(criticalError("Write failed"));
}
write_tasks.pop();
}
} else if (static_cast<size_t>(n) == task.length) {
if (write_done) {
write_done->feed(static_cast<size_t>(task.length));
}
write_tasks.pop();
} else {
task.buffer = reinterpret_cast<const uint8_t *>(task.buffer) +
static_cast<size_t>(n);
task.length -= static_cast<size_t>(n);
}
}
ssize_t UnixIoStream::readStream(void* buffer, size_t length){
return unixRead(fd(), buffer, length);
}
ssize_t UnixIoStream::writeStream(const void* buffer, size_t length) {
return unixWrite(fd(), buffer, length);
}
*/
UnixIoStream::UnixIoStream(UnixEventPort &event_port, int file_descriptor,
int fd_flags, uint32_t event_mask)
@ -104,9 +40,9 @@ UnixIoStream::UnixIoStream(UnixEventPort &event_port, int file_descriptor,
}
void UnixIoStream::read(void *buffer, size_t min_length, size_t max_length) {
bool is_ready = read_helper.read_tasks.empty();
read_helper.read_tasks.push(
ReadTaskAndStepHelper::ReadIoTask{buffer, min_length, max_length});
bool is_ready = !read_helper.read_task.has_value();
read_helper.read_task =
ReadTaskAndStepHelper::ReadIoTask{buffer, min_length, max_length};
if (is_ready) {
read_helper.readStep(*this);
}
@ -131,9 +67,9 @@ Conveyor<void> UnixIoStream::onReadDisconnected() {
}
void UnixIoStream::write(const void *buffer, size_t length) {
bool is_ready = write_helper.write_tasks.empty();
write_helper.write_tasks.push(
WriteTaskAndStepHelper::WriteIoTask{buffer, length});
bool is_ready = !write_helper.write_task.has_value();
write_helper.write_task =
WriteTaskAndStepHelper::WriteIoTask{buffer, length};
if (is_ready) {
write_helper.writeStep(*this);
}
@ -153,11 +89,11 @@ Conveyor<void> UnixIoStream::writeReady() {
void UnixIoStream::notify(uint32_t mask) {
if (mask & EPOLLOUT) {
writeStep();
write_helper.writeStep(*this);
}
if (mask & EPOLLIN) {
readStep();
read_helper.readStep(*this);
}
if (mask & EPOLLRDHUP) {
@ -286,6 +222,8 @@ std::string UnixNetworkAddress::toString() const {
}
}
UnixNetwork::UnixNetwork(UnixEventPort &event) : event_port{event} {}
Conveyor<Own<NetworkAddress>> UnixNetwork::parseAddress(const std::string &path,
uint16_t port_hint) {
std::string_view addr_view{path};
@ -333,4 +271,5 @@ ErrorOr<AsyncIoContext> setupAsyncIo() {
return criticalError("Out of memory");
}
}
}
} // namespace gin

View File

@ -26,10 +26,11 @@
#include <unordered_map>
#include <vector>
#include "io.h"
#include "io_helpers.h"
#include "kelgin/io.h"
#include "./io.h"
namespace gin {
namespace unix {
constexpr int MAX_EPOLL_EVENTS = 256;
class UnixEventPort;
@ -262,20 +263,18 @@ public:
}
};
ssize_t unixRead(int fd, void* buffer, size_t length);
ssize_t unixWrite(int fd, const void* buffer, size_t length);
class UnixIoStream final : public IoStream,
public IFdOwner,
public DataReaderAndWriter {
public IFdOwner, public StreamReaderAndWriter {
private:
WriteTaskAndStepHelper write_helper;
ReadTaskAndStepHelper read_helper;
private:
// Interface impl for the helpers above
ssize_t dataRead(void *buffer, size_t length) override;
ssize_t dataWrite(const void *buffer, size_t length) override;
void readStep();
void writeStep();
ssize_t readStream(void* buffer, size_t len) override;
ssize_t writeStream(const void* buffer, size_t len) override;
public:
UnixIoStream(UnixEventPort &event_port, int file_descriptor, int fd_flags,
@ -426,4 +425,5 @@ public:
EventLoop &eventLoop();
};
}
} // namespace gin

View File

@ -1,30 +1,30 @@
#include "io_helpers.h"
#include "io.h"
namespace gin {
void ReadTaskAndStepHelper::readStep(DataReader &reader) {
void ReadTaskAndStepHelper::readStep(StreamReader &reader) {
if (read_ready) {
read_ready->feed();
}
while (!read_tasks.empty()) {
ReadIoTask &task = read_tasks.front();
if (read_task.has_value()) {
ReadIoTask &task = *read_task;
ssize_t n = reader.dataRead(task.buffer, task.max_length);
ssize_t n = reader.readStream(task.buffer, task.max_length);
if (n <= 0) {
if (n == 0) {
if (on_read_disconnect) {
on_read_disconnect->feed();
}
break;
return;
}
int error = errno;
if (error == EAGAIN || error == EWOULDBLOCK) {
break;
return;
} else {
if (read_done) {
read_done->fail(criticalError("Read failed"));
}
read_tasks.pop();
read_task = std::nullopt;
}
} else if (static_cast<size_t>(n) >= task.min_length &&
static_cast<size_t>(n) <= task.max_length) {
@ -32,7 +32,7 @@ void ReadTaskAndStepHelper::readStep(DataReader &reader) {
read_done->feed(static_cast<size_t>(n));
}
size_t max_len = task.max_length;
read_tasks.pop();
read_task = std::nullopt;
} else {
task.buffer = reinterpret_cast<uint8_t *>(task.buffer) + n;
task.min_length -= static_cast<size_t>(n);
@ -41,30 +41,30 @@ void ReadTaskAndStepHelper::readStep(DataReader &reader) {
}
}
void WriteTaskAndStepHelper::writeStep(DataWriter &writer) {
void WriteTaskAndStepHelper::writeStep(StreamWriter &writer) {
if (write_ready) {
write_ready->feed();
}
while (!write_tasks.empty()) {
WriteIoTask &task = write_tasks.front();
if (write_task.has_value()) {
WriteIoTask &task = *write_task;
ssize_t n = writer.dataWrite(task.buffer, task.length);
ssize_t n = writer.writeStream(task.buffer, task.length);
if (n < 0) {
int error = errno;
if (error == EAGAIN || error == EWOULDBLOCK) {
break;
return;
} else {
if (write_done) {
write_done->fail(criticalError("Write failed"));
}
write_tasks.pop();
write_task = std::nullopt;
}
} else if (static_cast<size_t>(n) == task.length) {
if (write_done) {
write_done->feed(static_cast<size_t>(task.length));
}
write_tasks.pop();
write_task = std::nullopt;
} else {
task.buffer = reinterpret_cast<const uint8_t *>(task.buffer) +
static_cast<size_t>(n);

View File

@ -1,10 +1,11 @@
#pragma once
#include <kelgin/async.h>
#include <kelgin/io.h>
#include <kelgin/common.h>
#include <cstdint>
#include <queue>
#include <optional>
namespace gin {
/*
@ -17,11 +18,11 @@ namespace gin {
* of strange abstraction. This may also be reusable for windows/macOS though.
*/
class DataReader {
class StreamReader {
protected:
~StreamReader() = default;
public:
virtual ~DataReader() = default;
virtual ssize_t dataRead(void *buffer, size_t length) = 0;
virtual ssize_t readStream(void* buffer, size_t length) = 0;
};
class ReadTaskAndStepHelper {
@ -31,21 +32,21 @@ public:
size_t min_length;
size_t max_length;
};
std::queue<ReadIoTask> read_tasks;
std::optional<ReadIoTask> read_task;
Own<ConveyorFeeder<size_t>> read_done = nullptr;
Own<ConveyorFeeder<void>> read_ready = nullptr;
Own<ConveyorFeeder<void>> on_read_disconnect = nullptr;
public:
void readStep(DataReader &reader);
void readStep(StreamReader &reader);
};
class DataWriter {
class StreamWriter {
protected:
~StreamWriter() = default;
public:
virtual ~DataWriter() = default;
virtual ssize_t dataWrite(const void *buffer, size_t length) = 0;
virtual ssize_t writeStream(const void* buffer, size_t length) = 0;
};
class WriteTaskAndStepHelper {
@ -54,16 +55,16 @@ public:
const void *buffer;
size_t length;
};
std::queue<WriteIoTask> write_tasks;
std::optional<WriteIoTask> write_task;
Own<ConveyorFeeder<size_t>> write_done = nullptr;
Own<ConveyorFeeder<void>> write_ready = nullptr;
public:
void writeStep(DataWriter &writer);
void writeStep(StreamWriter &writer);
};
class DataReaderAndWriter : public DataReader, public DataWriter {
public:
virtual ~DataReaderAndWriter() = default;
class StreamReaderAndWriter : public StreamReader, public StreamWriter {
protected:
~StreamReaderAndWriter() = default;
};
} // namespace gin

View File

@ -1,3 +0,0 @@
#ifdef GIN_UNIX
namespace gin {}
#endif

18
driver/tls/tls-unix.h Normal file
View File

@ -0,0 +1,18 @@
#pragma once
#ifndef GIN_UNIX
#error "Don't include this"
#endif
#include "../io-unix.h"
#include "./tls.h"
namespace gin {
class TlsUnixIoStream final : public IoStream, public IFdOwner, public StreamReaderAndWriter {
public:
TlsUnixIoStream(UnixEventPort &event_port, int file_descriptor, int fd_flags,
uint32_t event_mask);
};
}

24
driver/tls/tls.cpp Normal file
View File

@ -0,0 +1,24 @@
#include "tls.h"
namespace gin {
TlsContext::TlsContext(){
gnutls_global_init();
gnutls_certificate_allocate_credentials(&xcred);
gnutls_certificate_set_x509_system_trust(xcred);
}
TlsContext::~TlsContext(){
gnutls_certificate_free_credentials(xcred);
gnutls_global_deinit();
}
TlsNetwork::TlsNetwork(Network& net):
network{net}
{}
std::optional<Own<Network>> setupTlsNetwork(Network& network){
return std::nullopt;
}
}

33
driver/tls/tls.h Normal file
View File

@ -0,0 +1,33 @@
#pragma once
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include "common.h"
#include <kelgin/tls/tls.h>
namespace gin {
class TlsContext {
public:
gnutls_certificate_credentials_t xcred;
public:
TlsContext();
~TlsContext();
GIN_FORBID_COPY(TlsContext);
};
class TlsNetwork : public Network {
private:
Network& network;
TlsContext context;
public:
TlsNetwork(Network& net);
Conveyor<Own<NetworkAddress>>
parseAddress(const std::string &addr, uint16_t port_hint = 0) override;
};
}

View File

@ -14,9 +14,7 @@ EventLoop &currentEventLoop() {
}
} // namespace
ConveyorNode::ConveyorNode() : child{nullptr} {}
ConveyorNode::ConveyorNode(Own<ConveyorNode> &&node) : child{std::move(node)} {}
ConveyorNode::ConveyorNode() {}
void ConveyorStorage::setParent(ConveyorStorage *p) {
/*
@ -307,7 +305,7 @@ void ConveyorSinks::fire() {
}
ConvertConveyorNodeBase::ConvertConveyorNodeBase(Own<ConveyorNode> &&dep)
: ConveyorNode{std::move(dep)} {}
: child{std::move(dep)} {}
void ConvertConveyorNodeBase::getResult(ErrorOrValue &err_or_val) {
getImpl(err_or_val);

View File

@ -12,12 +12,8 @@
namespace gin {
class ConveyorNode {
protected:
Own<ConveyorNode> child;
public:
ConveyorNode();
ConveyorNode(Own<ConveyorNode> &&child);
virtual ~ConveyorNode() = default;
virtual void getResult(ErrorOrValue &err_or_val) = 0;
@ -208,9 +204,9 @@ public:
/*
* Join Conveyors into a single one
*/
// template<typename... Args>
// Conveyor<std::tuple<Args...>> joinConveyors(std::tuple<Conveyor<Args...>>&
// conveyors);
template <typename... Args>
Conveyor<std::tuple<Args...>>
joinConveyors(std::tuple<Conveyor<Args>...> &conveyors);
template <typename T> class ConveyorFeeder {
public:
@ -412,6 +408,9 @@ public:
template <typename T>
class AdaptConveyorNode final : public ConveyorNode, public ConveyorStorage {
protected:
Own<ConveyorNode> child;
private:
AdaptConveyorFeeder<T> *feeder = nullptr;
@ -459,6 +458,9 @@ public:
template <typename T>
class OneTimeConveyorNode final : public ConveyorNode, public ConveyorStorage {
protected:
Own<ConveyorNode> child;
private:
OneTimeConveyorFeeder<T> *feeder = nullptr;
@ -488,9 +490,12 @@ public:
class QueueBufferConveyorNodeBase : public ConveyorNode,
public ConveyorStorage {
protected:
Own<ConveyorNode> child;
public:
QueueBufferConveyorNodeBase(Own<ConveyorNode> &&dep)
: ConveyorNode(std::move(dep)) {}
: child(std::move(dep)) {}
virtual ~QueueBufferConveyorNodeBase() = default;
};
@ -545,9 +550,11 @@ public:
};
class AttachConveyorNodeBase : public ConveyorNode {
protected:
Own<ConveyorNode> child;
public:
AttachConveyorNodeBase(Own<ConveyorNode> &&dep)
: ConveyorNode(std::move(dep)) {}
AttachConveyorNodeBase(Own<ConveyorNode> &&dep) : child(std::move(dep)) {}
virtual ~AttachConveyorNodeBase() = default;
@ -566,6 +573,9 @@ public:
};
class ConvertConveyorNodeBase : public ConveyorNode {
protected:
Own<ConveyorNode> child;
public:
ConvertConveyorNodeBase(Own<ConveyorNode> &&dep);
virtual ~ConvertConveyorNodeBase() = default;
@ -616,14 +626,15 @@ public:
class SinkConveyorNode final : public ConveyorNode, public ConveyorStorage {
private:
Own<ConveyorNode> child;
ConveyorSinks *conveyor_sink;
public:
SinkConveyorNode(Own<ConveyorNode> &&node, ConveyorSinks &conv_sink)
: ConveyorNode(std::move(node)), conveyor_sink{&conv_sink} {}
: child(std::move(node)), conveyor_sink{&conv_sink} {}
SinkConveyorNode(Own<ConveyorNode> &&node)
: ConveyorNode(std::move(node)), conveyor_sink{nullptr} {}
: child(std::move(node)), conveyor_sink{nullptr} {}
// Event only queued if a critical error occured
void fire() override {
@ -702,6 +713,20 @@ public:
void fire() override;
};
class JoinConveyorNodeBase : public ConveyorNode, public ConveyorStorage {
public:
virtual ~JoinConveyorNodeBase() = default;
};
template <typename T> class JoinConveyorNode : public JoinConveyorNodeBase {
public:
};
template <typename... Args> class JoinConveyorMerger : public ConveyorStorage {
private:
std::tuple<JoinConveyorNode<Args>...> joined;
};
} // namespace gin
#include "async.tmpl.h"

View File

@ -1,3 +1,4 @@
#include "io.h"
namespace gin {}
namespace gin {
}

View File

@ -6,7 +6,6 @@
#include <string>
namespace gin {
/*
* Input stream
*/

View File

@ -1,44 +0,0 @@
#include "tls.h"
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include "io_helpers.h"
namespace gin {
class TlsContext::Impl {
private:
gnutls_certificate_credentials_t xcred;
public:
Impl() {
gnutls_global_init();
gnutls_certificate_allocate_credentials(&xcred);
gnutls_certificate_set_x509_system_trust(xcred);
}
~Impl() {
gnutls_certificate_free_credentials(xcred);
gnutls_global_deinit();
}
};
TlsContext::TlsContext() : impl{heap<TlsContext::Impl>()} {}
TlsContext::~TlsContext() {}
class TlsIoStream final : public IoStream, public DataReaderAndWriter {
public:
};
class TlsNetworkAddress : public NetworkAddress {
public:
Own<Server> listen() override { return nullptr; }
Conveyor<Own<IoStream>> connect() override { return {nullptr, nullptr}; }
std::string toString() const override { return {}; }
};
Conveyor<Own<NetworkAddress>> TlsNetwork::parseAddress(const std::string &addr,
uint16_t port_hint) {
return {nullptr, nullptr};
}
} // namespace gin

View File

@ -1,32 +0,0 @@
#pragma once
#include <kelgin/common.h>
#include <kelgin/io.h>
namespace gin {
class TlsContext {
public:
struct Options {};
private:
/*
* Pimpl pattern to hide GnuTls includes
*/
class Impl;
Own<Impl> impl;
public:
TlsContext();
~TlsContext();
};
class TlsNetwork final : public Network {
private:
TlsContext context;
public:
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &,
uint16_t port_hint = 0) override;
};
} // namespace gin

View File

@ -0,0 +1,9 @@
#include "tls.h"
#include <gnutls/gnutls.h>
#include <gnutls/x509.h>
#include "io_helpers.h"
namespace gin {
} // namespace gin

18
source/kelgin/tls/tls.h Normal file
View File

@ -0,0 +1,18 @@
#pragma once
#include <kelgin/common.h>
#include <kelgin/io.h>
#include <optional>
namespace gin {
class Tls {
public:
class Options {
public:
};
};
std::optional<Own<Network>> setupTlsNetwork(Network& network);
} // namespace gin