diff options
author | Claudius "keldu" Holeksa <mail@keldu.de> | 2023-12-04 12:18:14 +0100 |
---|---|---|
committer | Claudius "keldu" Holeksa <mail@keldu.de> | 2023-12-04 12:18:14 +0100 |
commit | a14896f9ed209dd3f9597722e5a5697bd7dbf531 (patch) | |
tree | 089ca5cbbd206d1921f8f6b53292f5bc1902ca5c /modules/io | |
parent | 84ecdcbca9e55b1f57fbb832e12ff4fdbb86e7c9 (diff) |
meta: Renamed folder containing source
Diffstat (limited to 'modules/io')
-rw-r--r-- | modules/io/.nix/derivation.nix | 29 | ||||
-rw-r--r-- | modules/io/SConscript | 38 | ||||
-rw-r--r-- | modules/io/SConstruct | 66 | ||||
-rw-r--r-- | modules/io/io.cpp | 70 | ||||
-rw-r--r-- | modules/io/io.h | 219 | ||||
-rw-r--r-- | modules/io/io_helpers.cpp | 85 | ||||
-rw-r--r-- | modules/io/io_helpers.h | 53 | ||||
-rw-r--r-- | modules/io/io_unix.cpp | 894 |
8 files changed, 1454 insertions, 0 deletions
diff --git a/modules/io/.nix/derivation.nix b/modules/io/.nix/derivation.nix new file mode 100644 index 0000000..a14bd34 --- /dev/null +++ b/modules/io/.nix/derivation.nix @@ -0,0 +1,29 @@ +{ lib +, stdenv +, scons +, clang-tools +, version +, forstio +}: + +let + +in stdenv.mkDerivation { + pname = "forstio-io"; + inherit version; + src = ./..; + + enableParallelBuilding = true; + + nativeBuildInputs = [ + scons + clang-tools + ]; + + buildInputs = [ + forstio.core + forstio.async + ]; + + outputs = ["out" "dev"]; +} diff --git a/modules/io/SConscript b/modules/io/SConscript new file mode 100644 index 0000000..62ad58a --- /dev/null +++ b/modules/io/SConscript @@ -0,0 +1,38 @@ +#!/bin/false + +import os +import os.path +import glob + + +Import('env') + +dir_path = Dir('.').abspath + +# Environment for base library +io_env = env.Clone(); + +io_env.sources = sorted(glob.glob(dir_path + "/*.cpp")) +io_env.headers = sorted(glob.glob(dir_path + "/*.h")) + +env.sources += io_env.sources; +env.headers += io_env.headers; + +## Shared lib +objects_shared = [] +io_env.add_source_files(objects_shared, io_env.sources, shared=True); +io_env.library_shared = io_env.SharedLibrary('#build/forstio-io', [objects_shared]); + +## Static lib +objects_static = [] +io_env.add_source_files(objects_static, io_env.sources, shared=False); +io_env.library_static = io_env.StaticLibrary('#build/forstio-io', [objects_static]); + +# Set Alias +env.Alias('library_io', [io_env.library_shared, io_env.library_static]); + +env.targets += ['library_io']; + +# Install +env.Install('$prefix/lib/', [io_env.library_shared, io_env.library_static]); +env.Install('$prefix/include/forstio/io/', [io_env.headers]); diff --git a/modules/io/SConstruct b/modules/io/SConstruct new file mode 100644 index 0000000..4cccf82 --- /dev/null +++ b/modules/io/SConstruct @@ -0,0 +1,66 @@ +#!/usr/bin/env python3 + +import sys +import os +import os.path +import glob +import re + + +if sys.version_info < (3,): + def isbasestring(s): + return isinstance(s,basestring) +else: + def isbasestring(s): + return isinstance(s, (str,bytes)) + +def add_kel_source_files(self, sources, filetype, lib_env=None, shared=False, target_post=""): + + if isbasestring(filetype): + dir_path = self.Dir('.').abspath + filetype = sorted(glob.glob(dir_path+"/"+filetype)) + + for path in filetype: + target_name = re.sub( r'(.*?)(\.cpp|\.c\+\+)', r'\1' + target_post, path ) + if shared: + target_name+='.os' + sources.append( self.SharedObject( target=target_name, source=path ) ) + else: + target_name+='.o' + sources.append( self.StaticObject( target=target_name, source=path ) ) + pass + +def isAbsolutePath(key, dirname, env): + assert os.path.isabs(dirname), "%r must have absolute path syntax" % (key,) + +env_vars = Variables( + args=ARGUMENTS +) + +env_vars.Add('prefix', + help='Installation target location of build results and headers', + default='/usr/local/', + validator=isAbsolutePath +) + +env=Environment(ENV=os.environ, variables=env_vars, CPPPATH=[], + CPPDEFINES=['SAW_UNIX'], + CXXFLAGS=['-std=c++20','-g','-Wall','-Wextra'], + LIBS=['forstio-async']) +env.__class__.add_source_files = add_kel_source_files +env.Tool('compilation_db'); +env.cdb = env.CompilationDatabase('compile_commands.json'); + +env.objects = []; +env.sources = []; +env.headers = []; +env.targets = []; + +Export('env') +SConscript('SConscript') + +env.Alias('cdb', env.cdb); +env.Alias('all', [env.targets]); +env.Default('all'); + +env.Alias('install', '$prefix') diff --git a/modules/io/io.cpp b/modules/io/io.cpp new file mode 100644 index 0000000..f0705d2 --- /dev/null +++ b/modules/io/io.cpp @@ -0,0 +1,70 @@ +#include "io.h" + +#include <cassert> + +namespace saw { + +async_io_stream::async_io_stream(own<io_stream> str) + : stream_{std::move(str)}, + read_ready_{stream_->read_ready() + .then([this]() { read_stepper_.read_step(*stream_); }) + .sink()}, + write_ready_{stream_->write_ready() + .then([this]() { write_stepper_.write_step(*stream_); }) + .sink()}, + read_disconnected_{stream_->on_read_disconnected() + .then([this]() { + if (read_stepper_.on_read_disconnect) { + read_stepper_.on_read_disconnect->feed(); + } + }) + .sink()} {} + +void async_io_stream::read(void *buffer, size_t min_length, size_t max_length) { + SAW_ASSERT(buffer && max_length >= min_length && min_length > 0) { return; } + + SAW_ASSERT(!read_stepper_.read_task.has_value()) { return; } + + read_stepper_.read_task = read_task_and_step_helper::read_io_task{ + buffer, min_length, max_length, 0}; + read_stepper_.read_step(*stream_); +} + +conveyor<size_t> async_io_stream::read_done() { + auto caf = new_conveyor_and_feeder<size_t>(); + read_stepper_.read_done = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +conveyor<void> async_io_stream::on_read_disconnected() { + auto caf = new_conveyor_and_feeder<void>(); + read_stepper_.on_read_disconnect = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +void async_io_stream::write(const void *buffer, size_t length) { + SAW_ASSERT(buffer && length > 0) { return; } + + SAW_ASSERT(!write_stepper_.write_task.has_value()) { return; } + + write_stepper_.write_task = + write_task_and_step_helper::write_io_task{buffer, length, 0}; + write_stepper_.write_step(*stream_); +} + +conveyor<size_t> async_io_stream::write_done() { + auto caf = new_conveyor_and_feeder<size_t>(); + write_stepper_.write_done = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +string_network_address::string_network_address(const std::string &address, + uint16_t port) + : address_value_{address}, port_value_{port} {} + +const std::string &string_network_address::address() const { + return address_value_; +} + +uint16_t string_network_address::port() const { return port_value_; } +} // namespace saw diff --git a/modules/io/io.h b/modules/io/io.h new file mode 100644 index 0000000..7653ace --- /dev/null +++ b/modules/io/io.h @@ -0,0 +1,219 @@ +#pragma once + +#include <forstio/async/async.h> +#include <forstio/core/common.h> +#include "io_helpers.h" + +#include <string> +#include <variant> + +namespace saw { +/** + * Set of error common in io + */ +namespace err { +struct disconnected { + static constexpr std::string_view description = "Disconnected"; + static constexpr bool is_critical = true; +}; + +struct resource_busy { + static constexpr std::string_view description = "Resource busy"; + static constexpr bool is_critical = false; +}; +} +/* + * Input stream + */ +class input_stream { +public: + virtual ~input_stream() = default; + + virtual error_or<size_t> read(void *buffer, size_t length) = 0; + + virtual conveyor<void> read_ready() = 0; + + virtual conveyor<void> on_read_disconnected() = 0; +}; + +/* + * Output stream + */ +class output_stream { +public: + virtual ~output_stream() = default; + + virtual error_or<size_t> write(const void *buffer, size_t length) = 0; + + virtual conveyor<void> write_ready() = 0; +}; + +/* + * Io stream + */ +class io_stream : public input_stream, public output_stream { +public: + virtual ~io_stream() = default; +}; + +class async_input_stream { +public: + virtual ~async_input_stream() = default; + + virtual void read(void *buffer, size_t min_length, size_t max_length) = 0; + + virtual conveyor<size_t> read_done() = 0; + virtual conveyor<void> on_read_disconnected() = 0; +}; + +class async_output_stream { +public: + virtual ~async_output_stream() = default; + + virtual void write(const void *buffer, size_t length) = 0; + + virtual conveyor<size_t> write_done() = 0; +}; + +class async_io_stream final : public async_input_stream, + public async_output_stream { +private: + own<io_stream> stream_; + + conveyor_sink read_ready_; + conveyor_sink write_ready_; + conveyor_sink read_disconnected_; + + read_task_and_step_helper read_stepper_; + write_task_and_step_helper write_stepper_; + +public: + async_io_stream(own<io_stream> str); + + SAW_FORBID_COPY(async_io_stream); + SAW_FORBID_MOVE(async_io_stream); + + void read(void *buffer, size_t length, size_t max_length) override; + + conveyor<size_t> read_done() override; + + conveyor<void> on_read_disconnected() override; + + void write(const void *buffer, size_t length) override; + + conveyor<size_t> write_done() override; +}; + +class server { +public: + virtual ~server() = default; + + virtual conveyor<own<io_stream>> accept() = 0; +}; + +class network_address; +/** + * Datagram class. Bound to a local address it is able to receive inbound + * datagram messages and send them as well as long as an address is provided as + * well + */ +class datagram { +public: + virtual ~datagram() = default; + + virtual error_or<size_t> read(void *buffer, size_t length) = 0; + virtual conveyor<void> read_ready() = 0; + + virtual error_or<size_t> write(const void *buffer, size_t length, + network_address &dest) = 0; + virtual conveyor<void> write_ready() = 0; +}; + +class os_network_address; +class string_network_address; + +class network_address { +public: + using child_variant = + std::variant<os_network_address *, string_network_address *>; + + virtual ~network_address() = default; + + virtual network_address::child_variant representation() = 0; + + virtual const std::string &address() const = 0; + virtual uint16_t port() const = 0; +}; + +class os_network_address : public network_address { +public: + virtual ~os_network_address() = default; + + network_address::child_variant representation() override { return this; } +}; + +class string_network_address final : public network_address { +private: + std::string address_value_; + uint16_t port_value_; + +public: + string_network_address(const std::string &address, uint16_t port); + + const std::string &address() const override; + uint16_t port() const override; + + network_address::child_variant representation() override { return this; } +}; + +class network { +public: + virtual ~network() = default; + + /** + * Resolve the provided string and uint16 to the preferred storage method + */ + virtual conveyor<own<network_address>> + resolve_address(const std::string &addr, uint16_t port_hint = 0) = 0; + + /** + * Parse the provided string and uint16 to the preferred storage method + * Since no dns request is made here, no async conveyors have to be used. + */ + /// @todo implement + // virtual Own<NetworkAddress> parseAddress(const std::string& addr, + // uint16_t port_hint = 0) = 0; + + /** + * Set up a listener on this address + */ + virtual own<server> listen(network_address &bind_addr) = 0; + + /** + * Connect to a remote address + */ + virtual conveyor<own<io_stream>> connect(network_address &address) = 0; + + /** + * Bind a datagram socket at this address. + */ + virtual own<datagram> datagram(network_address &address) = 0; +}; + +class io_provider { +public: + virtual ~io_provider() = default; + + virtual own<input_stream> wrap_input_fd(int fd) = 0; + + virtual network &get_network() = 0; +}; + +struct async_io_context { + own<io_provider> io; + event_loop &event_loop; + event_port &event_port; +}; + +error_or<async_io_context> setup_async_io(); +} // namespace saw diff --git a/modules/io/io_helpers.cpp b/modules/io/io_helpers.cpp new file mode 100644 index 0000000..c2cf2be --- /dev/null +++ b/modules/io/io_helpers.cpp @@ -0,0 +1,85 @@ +#include "io_helpers.h" + +#include "io.h" + +#include <cassert> + +namespace saw { +void read_task_and_step_helper::read_step(input_stream &reader) { + while (read_task.has_value()) { + read_io_task &task = *read_task; + + error_or<size_t> n_err = reader.read(task.buffer, task.max_length); + if (n_err.is_error()) { + const error &error = n_err.get_error(); + if (error.is_critical()) { + if (read_done) { + read_done->fail(error.copy_error()); + } + read_task = std::nullopt; + } + + break; + } else if (n_err.is_value()) { + size_t n = n_err.get_value(); + if (static_cast<size_t>(n) >= task.min_length && + static_cast<size_t>(n) <= task.max_length) { + if (read_done) { + read_done->feed(n + task.already_read); + } + read_task = std::nullopt; + } else { + task.buffer = static_cast<uint8_t *>(task.buffer) + n; + task.min_length -= static_cast<size_t>(n); + task.max_length -= static_cast<size_t>(n); + task.already_read += n; + } + + } else { + if (read_done) { + read_done->fail(make_error<err::invalid_state>("Read failed")); + } + read_task = std::nullopt; + } + } +} + +void write_task_and_step_helper::write_step(output_stream &writer) { + while (write_task.has_value()) { + write_io_task &task = *write_task; + + error_or<size_t> n_err = writer.write(task.buffer, task.length); + + if (n_err.is_value()) { + + size_t n = n_err.get_value(); + assert(n <= task.length); + if (n == task.length) { + if (write_done) { + write_done->feed(n + task.already_written); + } + write_task = std::nullopt; + } else { + task.buffer = static_cast<const uint8_t *>(task.buffer) + n; + task.length -= n; + task.already_written += n; + } + } else if (n_err.is_error()) { + const error &error = n_err.get_error(); + if (error.is_critical()) { + if (write_done) { + write_done->fail(error.copy_error()); + } + write_task = std::nullopt; + } + break; + } else { + if (write_done) { + write_done->fail(make_error<err::invalid_state>("Write failed")); + } + write_task = std::nullopt; + } + } +} + +} // namespace saw diff --git a/modules/io/io_helpers.h b/modules/io/io_helpers.h new file mode 100644 index 0000000..94e37f4 --- /dev/null +++ b/modules/io/io_helpers.h @@ -0,0 +1,53 @@ +#pragma once + +#include <forstio/async/async.h> +#include <forstio/core/common.h> + +#include <cstdint> +#include <optional> + +namespace saw { +/* + * Helper classes for the specific driver implementations + */ + +/* + * Since I don't want to repeat these implementations for tls on unix systems + * and gnutls doesn't let me write or read into buffers I have to have this kind + * of strange abstraction. This may also be reusable for windows/macOS though. + */ +class input_stream; + +class read_task_and_step_helper { +public: + struct read_io_task { + void *buffer; + size_t min_length; + size_t max_length; + size_t already_read = 0; + }; + std::optional<read_io_task> read_task; + own<conveyor_feeder<size_t>> read_done = nullptr; + + own<conveyor_feeder<void>> on_read_disconnect = nullptr; + +public: + void read_step(input_stream &reader); +}; + +class output_stream; + +class write_task_and_step_helper { +public: + struct write_io_task { + const void *buffer; + size_t length; + size_t already_written = 0; + }; + std::optional<write_io_task> write_task; + own<conveyor_feeder<size_t>> write_done = nullptr; + +public: + void write_step(output_stream &writer); +}; +} // namespace saw diff --git a/modules/io/io_unix.cpp b/modules/io/io_unix.cpp new file mode 100644 index 0000000..c3b4f17 --- /dev/null +++ b/modules/io/io_unix.cpp @@ -0,0 +1,894 @@ +#ifdef SAW_UNIX + +#include <csignal> +#include <sys/signalfd.h> + +#include <fcntl.h> +#include <netdb.h> +#include <netinet/in.h> +#include <sys/epoll.h> +#include <sys/socket.h> +#include <sys/stat.h> +#include <sys/types.h> +#include <sys/un.h> + +#include <cassert> +#include <cstring> + +#include <errno.h> +#include <unistd.h> + +#include <queue> +#include <sstream> +#include <unordered_map> +#include <vector> + +#include "io.h" + +namespace saw { +namespace unix { +constexpr int MAX_EPOLL_EVENTS = 256; + +class unix_event_port; +class i_fd_owner { +protected: + unix_event_port &event_port_; + +private: + int file_descriptor_; + int fd_flags_; + uint32_t event_mask_; + +public: + i_fd_owner(unix_event_port &event_port, int file_descriptor, int fd_flags, + uint32_t event_mask); + + virtual ~i_fd_owner(); + + virtual void notify(uint32_t mask) = 0; + + int fd() const { return file_descriptor_; } +}; + +class unix_event_port final : public event_port { +private: + int epoll_fd_; + int signal_fd_; + + sigset_t signal_fd_set_; + + std::unordered_multimap<Signal, own<conveyor_feeder<void>>> + signal_conveyors_; + + int pipefds_[2]; + + std::vector<int> to_unix_signal(Signal signal) const { + switch (signal) { + case Signal::User1: + return {SIGUSR1}; + case Signal::Terminate: + default: + return {SIGTERM, SIGQUIT, SIGINT}; + } + } + + Signal from_unix_signal(int signal) const { + switch (signal) { + case SIGUSR1: + return Signal::User1; + case SIGTERM: + case SIGINT: + case SIGQUIT: + default: + return Signal::Terminate; + } + } + + void notify_signal_listener(int sig) { + Signal signal = from_unix_signal(sig); + + auto equal_range = signal_conveyors_.equal_range(signal); + for (auto iter = equal_range.first; iter != equal_range.second; + ++iter) { + + if (iter->second) { + if (iter->second->space() > 0) { + iter->second->feed(); + } + } + } + } + + bool poll_impl(int time) { + epoll_event events[MAX_EPOLL_EVENTS]; + int nfds = 0; + do { + nfds = epoll_wait(epoll_fd_, events, MAX_EPOLL_EVENTS, time); + + if (nfds < 0) { + /// @todo error_handling + return false; + } + + for (int i = 0; i < nfds; ++i) { + if (events[i].data.u64 == 0) { + while (1) { + struct ::signalfd_siginfo siginfo; + ssize_t n = + ::read(signal_fd_, &siginfo, sizeof(siginfo)); + if (n < 0) { + break; + } + assert(n == sizeof(siginfo)); + + notify_signal_listener(siginfo.ssi_signo); + } + } else if (events[i].data.u64 == 1) { + uint8_t i; + if (pipefds_[0] < 0) { + continue; + } + while (1) { + ssize_t n = ::recv(pipefds_[0], &i, sizeof(i), 0); + if (n < 0) { + break; + } + } + } else { + i_fd_owner *owner = + reinterpret_cast<i_fd_owner *>(events[i].data.ptr); + if (owner) { + owner->notify(events[i].events); + } + } + } + } while (nfds == MAX_EPOLL_EVENTS); + + return true; + } + +public: + unix_event_port() : epoll_fd_{-1}, signal_fd_{-1} { + ::signal(SIGPIPE, SIG_IGN); + + epoll_fd_ = ::epoll_create1(EPOLL_CLOEXEC); + if (epoll_fd_ < 0) { + return; + } + + ::sigemptyset(&signal_fd_set_); + signal_fd_ = + ::signalfd(-1, &signal_fd_set_, SFD_NONBLOCK | SFD_CLOEXEC); + if (signal_fd_ < 0) { + return; + } + + struct epoll_event event; + memset(&event, 0, sizeof(event)); + event.events = EPOLLIN; + event.data.u64 = 0; + ::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, signal_fd_, &event); + + int rc = ::pipe2(pipefds_, O_NONBLOCK | O_CLOEXEC); + if (rc < 0) { + return; + } + memset(&event, 0, sizeof(event)); + event.events = EPOLLIN; + event.data.u64 = 1; + ::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, pipefds_[0], &event); + } + + ~unix_event_port() { + ::close(epoll_fd_); + ::close(signal_fd_); + ::close(pipefds_[0]); + ::close(pipefds_[1]); + } + + conveyor<void> on_signal(Signal signal) override { + auto caf = new_conveyor_and_feeder<void>(); + + signal_conveyors_.insert(std::make_pair(signal, std::move(caf.feeder))); + + std::vector<int> sig = to_unix_signal(signal); + + for (auto iter = sig.begin(); iter != sig.end(); ++iter) { + ::sigaddset(&signal_fd_set_, *iter); + } + ::sigprocmask(SIG_BLOCK, &signal_fd_set_, nullptr); + ::signalfd(signal_fd_, &signal_fd_set_, SFD_NONBLOCK | SFD_CLOEXEC); + + auto node = conveyor<void>::from_conveyor(std::move(caf.conveyor)); + return conveyor<void>::to_conveyor(std::move(node)); + } + + void poll() override { poll_impl(0); } + + void wait() override { poll_impl(-1); } + + void wait(const std::chrono::steady_clock::duration &duration) override { + poll_impl( + std::chrono::duration_cast<std::chrono::milliseconds>(duration) + .count()); + } + void + wait(const std::chrono::steady_clock::time_point &time_point) override { + auto now = std::chrono::steady_clock::now(); + if (time_point <= now) { + poll(); + } else { + poll_impl(std::chrono::duration_cast<std::chrono::milliseconds>( + time_point - now) + .count()); + } + } + + void wake() override { + /// @todo pipe() in the beginning and write something minor into it like + /// uint8_t or sth the value itself doesn't matter + if (pipefds_[1] < 0) { + return; + } + uint8_t i = 0; + ::send(pipefds_[1], &i, sizeof(i), MSG_DONTWAIT); + } + + void subscribe(i_fd_owner &owner, int fd, uint32_t event_mask) { + if (epoll_fd_ < 0 || fd < 0) { + return; + } + ::epoll_event event; + memset(&event, 0, sizeof(event)); + event.events = event_mask | EPOLLET; + event.data.ptr = &owner; + + if (::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) < 0) { + /// @todo error_handling + return; + } + } + + void unsubscribe(int fd) { + if (epoll_fd_ < 0 || fd < 0) { + return; + } + if (::epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) < 0) { + /// @todo error_handling + return; + } + } +}; + +ssize_t unix_read(int fd, void *buffer, size_t length); +ssize_t unix_write(int fd, const void *buffer, size_t length); + +class unix_io_stream final : public io_stream, public i_fd_owner { +private: + own<conveyor_feeder<void>> read_ready_ = nullptr; + own<conveyor_feeder<void>> on_read_disconnect_ = nullptr; + own<conveyor_feeder<void>> write_ready_ = nullptr; + +public: + unix_io_stream(unix_event_port &event_port, int file_descriptor, + int fd_flags, uint32_t event_mask); + + error_or<size_t> read(void *buffer, size_t length) override; + + conveyor<void> read_ready() override; + + conveyor<void> on_read_disconnected() override; + + error_or<size_t> write(const void *buffer, size_t length) override; + + conveyor<void> write_ready() override; + + /* + void read(void *buffer, size_t min_length, size_t max_length) override; + Conveyor<size_t> readDone() override; + Conveyor<void> readReady() override; + + Conveyor<void> onReadDisconnected() override; + + void write(const void *buffer, size_t length) override; + Conveyor<size_t> writeDone() override; + Conveyor<void> writeReady() override; + */ + + void notify(uint32_t mask) override; +}; + +class unix_server final : public server, public i_fd_owner { +private: + own<conveyor_feeder<own<io_stream>>> accept_feeder_ = nullptr; + +public: + unix_server(unix_event_port &event_port, int file_descriptor, int fd_flags); + + conveyor<own<io_stream>> accept() override; + + void notify(uint32_t mask) override; +}; + +class unix_datagram final : public datagram, public i_fd_owner { +private: + own<conveyor_feeder<void>> read_ready_ = nullptr; + own<conveyor_feeder<void>> write_ready_ = nullptr; + +public: + unix_datagram(unix_event_port &event_port, int file_descriptor, + int fd_flags); + + error_or<size_t> read(void *buffer, size_t length) override; + conveyor<void> read_ready() override; + + error_or<size_t> write(const void *buffer, size_t length, + network_address &dest) override; + conveyor<void> write_ready() override; + + void notify(uint32_t mask) override; +}; + +/** + * Helper class which provides potential addresses to NetworkAddress + */ +class socket_address { +private: + union { + struct sockaddr generic; + struct sockaddr_un unix; + struct sockaddr_in inet; + struct sockaddr_in6 inet6; + struct sockaddr_storage storage; + } address_; + + socklen_t address_length_; + bool wildcard_; + + socket_address() : wildcard_{false} {} + +public: + socket_address(const void *sockaddr, socklen_t len, bool wildcard) + : address_length_{len}, wildcard_{wildcard} { + assert(len <= sizeof(address_)); + memcpy(&address_.generic, sockaddr, len); + } + + int socket(int type) const { + type |= SOCK_NONBLOCK | SOCK_CLOEXEC; + + int result = ::socket(address_.generic.sa_family, type, 0); + return result; + } + + bool bind(int fd) const { + if (wildcard_) { + int value = 0; + ::setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value)); + } + int error = ::bind(fd, &address_.generic, address_length_); + return error < 0; + } + + struct ::sockaddr *get_raw() { + return &address_.generic; + } + + const struct ::sockaddr *get_raw() const { return &address_.generic; } + + socklen_t get_raw_length() const { return address_length_; } + + static std::vector<socket_address> resolve(std::string_view str, + uint16_t port_hint) { + std::vector<socket_address> results; + + struct ::addrinfo *head; + struct ::addrinfo hints; + memset(&hints, 0, sizeof(hints)); + hints.ai_family = AF_UNSPEC; + + std::string port_string = std::to_string(port_hint); + bool wildcard = str == "*" || str == "::"; + std::string address_string{str}; + + int error = ::getaddrinfo(address_string.c_str(), port_string.c_str(), + &hints, &head); + + if (error) { + return {}; + } + + for (struct ::addrinfo *it = head; it != nullptr; it = it->ai_next) { + if (it->ai_addrlen > sizeof(socket_address::address_)) { + continue; + } + results.push_back({it->ai_addr, it->ai_addrlen, wildcard}); + } + ::freeaddrinfo(head); + return results; + } +}; + +class unix_network_address final : public os_network_address { +private: + const std::string path_; + uint16_t port_hint_; + std::vector<socket_address> addresses_; + +public: + unix_network_address(const std::string &path, uint16_t port_hint, + std::vector<socket_address> &&addr) + : path_{path}, port_hint_{port_hint}, addresses_{std::move(addr)} {} + + const std::string &address() const override; + + uint16_t port() const override; + + // Custom address info + socket_address &unix_address(size_t i = 0); + size_t unix_address_size() const; +}; + +class unix_network final : public network { +private: + unix_event_port &event_port_; + +public: + unix_network(unix_event_port &event_port); + + conveyor<own<network_address>> + resolve_address(const std::string &address, + uint16_t port_hint = 0) override; + + own<server> listen(network_address &addr) override; + + conveyor<own<io_stream>> connect(network_address &addr) override; + + own<class datagram> datagram(network_address &addr) override; +}; + +class unix_io_provider final : public io_provider { +private: + unix_event_port &event_port_; + class event_loop event_loop_; + + unix_network unix_network_; + +public: + unix_io_provider(unix_event_port &port_ref, own<event_port> port); + + class network &get_network() override; + + own<input_stream> wrap_input_fd(int fd) override; + + class event_loop &event_loop(); +}; + +i_fd_owner::i_fd_owner(unix_event_port &event_port, int file_descriptor, + int fd_flags, uint32_t event_mask) + : event_port_{event_port}, file_descriptor_{file_descriptor}, + fd_flags_{fd_flags}, event_mask_{event_mask} { + event_port_.subscribe(*this, file_descriptor, event_mask); +} + +i_fd_owner::~i_fd_owner() { + if (file_descriptor_ >= 0) { + event_port_.unsubscribe(file_descriptor_); + ::close(file_descriptor_); + } +} + +ssize_t unix_read(int fd, void *buffer, size_t length) { + return ::recv(fd, buffer, length, 0); +} + +ssize_t unix_write(int fd, const void *buffer, size_t length) { + return ::send(fd, buffer, length, 0); +} + +unix_io_stream::unix_io_stream(unix_event_port &event_port, int file_descriptor, + int fd_flags, uint32_t event_mask) + : i_fd_owner{event_port, file_descriptor, fd_flags, + event_mask | EPOLLRDHUP} {} + +error_or<size_t> unix_io_stream::read(void *buffer, size_t length) { + ssize_t read_bytes = unix_read(fd(), buffer, length); + if (read_bytes > 0) { + return static_cast<size_t>(read_bytes); + } else if (read_bytes == 0) { + return make_error<err::disconnected>(); + } + + return make_error<err::resource_busy>(); +} + +conveyor<void> unix_io_stream::read_ready() { + auto caf = new_conveyor_and_feeder<void>(); + read_ready_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +conveyor<void> unix_io_stream::on_read_disconnected() { + auto caf = new_conveyor_and_feeder<void>(); + on_read_disconnect_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +error_or<size_t> unix_io_stream::write(const void *buffer, size_t length) { + ssize_t write_bytes = unix_write(fd(), buffer, length); + if (write_bytes > 0) { + return static_cast<size_t>(write_bytes); + } + + int error = errno; + + if (error == EAGAIN || error == EWOULDBLOCK) { + return make_error<err::resource_busy>(); + } + + return make_error<err::disconnected>(); +} + +conveyor<void> unix_io_stream::write_ready() { + auto caf = new_conveyor_and_feeder<void>(); + write_ready_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +void unix_io_stream::notify(uint32_t mask) { + if (mask & EPOLLOUT) { + if (write_ready_) { + write_ready_->feed(); + } + } + + if (mask & EPOLLIN) { + if (read_ready_) { + read_ready_->feed(); + } + } + + if (mask & EPOLLRDHUP) { + if (on_read_disconnect_) { + on_read_disconnect_->feed(); + } + } +} + +unix_server::unix_server(unix_event_port &event_port, int file_descriptor, + int fd_flags) + : i_fd_owner{event_port, file_descriptor, fd_flags, EPOLLIN} {} + +conveyor<own<io_stream>> unix_server::accept() { + auto caf = new_conveyor_and_feeder<own<io_stream>>(); + accept_feeder_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +void unix_server::notify(uint32_t mask) { + if (mask & EPOLLIN) { + if (accept_feeder_) { + struct ::sockaddr_storage address; + socklen_t address_length = sizeof(address); + + int accept_fd = + ::accept4(fd(), reinterpret_cast<struct ::sockaddr *>(&address), + &address_length, SOCK_NONBLOCK | SOCK_CLOEXEC); + if (accept_fd < 0) { + return; + } + auto fd_stream = heap<unix_io_stream>(event_port_, accept_fd, 0, + EPOLLIN | EPOLLOUT); + accept_feeder_->feed(std::move(fd_stream)); + } + } +} + +unix_datagram::unix_datagram(unix_event_port &event_port, int file_descriptor, + int fd_flags) + : i_fd_owner{event_port, file_descriptor, fd_flags, EPOLLIN | EPOLLOUT} {} + +namespace { +ssize_t unix_read_msg(int fd, void *buffer, size_t length) { + struct ::sockaddr_storage their_addr; + socklen_t addr_len = sizeof(sockaddr_storage); + return ::recvfrom(fd, buffer, length, 0, + reinterpret_cast<struct ::sockaddr *>(&their_addr), + &addr_len); +} + +ssize_t unix_write_msg(int fd, const void *buffer, size_t length, + ::sockaddr *dest_addr, socklen_t dest_addr_len) { + + return ::sendto(fd, buffer, length, 0, dest_addr, dest_addr_len); +} +} // namespace + +error_or<size_t> unix_datagram::read(void *buffer, size_t length) { + ssize_t read_bytes = unix_read_msg(fd(), buffer, length); + if (read_bytes > 0) { + return static_cast<size_t>(read_bytes); + } + return make_error<err::resource_busy>(); +} + +conveyor<void> unix_datagram::read_ready() { + auto caf = new_conveyor_and_feeder<void>(); + read_ready_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +error_or<size_t> unix_datagram::write(const void *buffer, size_t length, + network_address &dest) { + unix_network_address &unix_dest = static_cast<unix_network_address &>(dest); + socket_address &sock_addr = unix_dest.unix_address(); + socklen_t sock_addr_length = sock_addr.get_raw_length(); + ssize_t write_bytes = unix_write_msg(fd(), buffer, length, + sock_addr.get_raw(), sock_addr_length); + if (write_bytes > 0) { + return static_cast<size_t>(write_bytes); + } + return make_error<err::resource_busy>(); +} + +conveyor<void> unix_datagram::write_ready() { + auto caf = new_conveyor_and_feeder<void>(); + write_ready_ = std::move(caf.feeder); + return std::move(caf.conveyor); +} + +void unix_datagram::notify(uint32_t mask) { + if (mask & EPOLLOUT) { + if (write_ready_) { + write_ready_->feed(); + } + } + + if (mask & EPOLLIN) { + if (read_ready_) { + read_ready_->feed(); + } + } +} + +namespace { +bool begins_with(const std::string_view &viewed, + const std::string_view &begins) { + return viewed.size() >= begins.size() && + viewed.compare(0, begins.size(), begins) == 0; +} + +std::variant<unix_network_address, unix_network_address *> +translate_network_address_to_unix_network_address(network_address &addr) { + auto addr_variant = addr.representation(); + std::variant<unix_network_address, unix_network_address *> os_addr = + std::visit( + [](auto &arg) + -> std::variant<unix_network_address, unix_network_address *> { + using T = std::decay_t<decltype(arg)>; + + if constexpr (std::is_same_v<T, os_network_address *>) { + return static_cast<unix_network_address *>(arg); + } + + auto sock_addrs = socket_address::resolve( + std::string_view{arg->address()}, arg->port()); + + return unix_network_address{arg->address(), arg->port(), + std::move(sock_addrs)}; + }, + addr_variant); + return os_addr; +} + +unix_network_address &translate_to_unix_address_ref( + std::variant<unix_network_address, unix_network_address *> &addr_variant) { + return std::visit( + [](auto &arg) -> unix_network_address & { + using T = std::decay_t<decltype(arg)>; + + if constexpr (std::is_same_v<T, unix_network_address>) { + return arg; + } else if constexpr (std::is_same_v<T, unix_network_address *>) { + return *arg; + } else { + static_assert(true, "Cases exhausted"); + } + }, + addr_variant); +} + +} // namespace + +own<server> unix_network::listen(network_address &addr) { + auto unix_addr_storage = + translate_network_address_to_unix_network_address(addr); + unix_network_address &address = + translate_to_unix_address_ref(unix_addr_storage); + + assert(address.unix_address_size() > 0); + if (address.unix_address_size() == 0) { + return nullptr; + } + + int fd = address.unix_address(0).socket(SOCK_STREAM); + if (fd < 0) { + return nullptr; + } + + int val = 1; + int rc = ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &val, sizeof(val)); + if (rc < 0) { + ::close(fd); + return nullptr; + } + + bool failed = address.unix_address(0).bind(fd); + if (failed) { + ::close(fd); + return nullptr; + } + + ::listen(fd, SOMAXCONN); + + return heap<unix_server>(event_port_, fd, 0); +} + +conveyor<own<io_stream>> unix_network::connect(network_address &addr) { + auto unix_addr_storage = + translate_network_address_to_unix_network_address(addr); + unix_network_address &address = + translate_to_unix_address_ref(unix_addr_storage); + + assert(address.unix_address_size() > 0); + if (address.unix_address_size() == 0) { + return conveyor<own<io_stream>>{make_error<err::critical>()}; + } + + int fd = address.unix_address(0).socket(SOCK_STREAM); + if (fd < 0) { + return conveyor<own<io_stream>>{make_error<err::disconnected>()}; + } + + own<unix_io_stream> io_str = + heap<unix_io_stream>(event_port_, fd, 0, EPOLLIN | EPOLLOUT); + + bool success = false; + for (size_t i = 0; i < address.unix_address_size(); ++i) { + socket_address &addr_iter = address.unix_address(i); + int status = + ::connect(fd, addr_iter.get_raw(), addr_iter.get_raw_length()); + if (status < 0) { + int error = errno; + /* + * It's not connected yet... + * But edge triggered epolling means that it'll + * be ready when the signal is triggered + */ + + /// @todo Add limit node when implemented + if (error == EINPROGRESS) { + /* + Conveyor<void> write_ready = io_stream->writeReady(); + return write_ready.then( + [ios{std::move(io_stream)}]() mutable { + ios->write_ready = nullptr; + return std::move(ios); + }); + */ + success = true; + break; + } else if (error != EINTR) { + /// @todo Push error message from + return conveyor<own<io_stream>>{make_error<err::disconnected>()}; + } + } else { + success = true; + break; + } + } + + if (!success) { + return conveyor<own<io_stream>>{make_error<err::disconnected>()}; + } + + return conveyor<own<io_stream>>{std::move(io_str)}; +} + +own<datagram> unix_network::datagram(network_address &addr) { + auto unix_addr_storage = + translate_network_address_to_unix_network_address(addr); + unix_network_address &address = + translate_to_unix_address_ref(unix_addr_storage); + + SAW_ASSERT(address.unix_address_size() > 0) { return nullptr; } + + int fd = address.unix_address(0).socket(SOCK_DGRAM); + + int optval = 1; + int rc = + ::setsockopt(fd, SOL_SOCKET, SO_REUSEADDR, &optval, sizeof(optval)); + if (rc < 0) { + ::close(fd); + return nullptr; + } + + bool failed = address.unix_address(0).bind(fd); + if (failed) { + ::close(fd); + return nullptr; + } + /// @todo + return heap<unix_datagram>(event_port_, fd, 0); +} + +const std::string &unix_network_address::address() const { return path_; } + +uint16_t unix_network_address::port() const { return port_hint_; } + +socket_address &unix_network_address::unix_address(size_t i) { + assert(i < addresses_.size()); + /// @todo change from list to vector? + return addresses_.at(i); +} + +size_t unix_network_address::unix_address_size() const { + return addresses_.size(); +} + +unix_network::unix_network(unix_event_port &event) : event_port_{event} {} + +conveyor<own<network_address>> +unix_network::resolve_address(const std::string &path, uint16_t port_hint) { + std::string_view addr_view{path}; + { + std::string_view str_begins_with = "unix:"; + if (begins_with(addr_view, str_begins_with)) { + addr_view.remove_prefix(str_begins_with.size()); + } + } + + std::vector<socket_address> addresses = + socket_address::resolve(addr_view, port_hint); + + return conveyor<own<network_address>>{ + heap<unix_network_address>(path, port_hint, std::move(addresses))}; +} + +unix_io_provider::unix_io_provider(unix_event_port &port_ref, + own<event_port> port) + : event_port_{port_ref}, event_loop_{std::move(port)}, unix_network_{ + port_ref} {} + +own<input_stream> unix_io_provider::wrap_input_fd(int fd) { + return heap<unix_io_stream>(event_port_, fd, 0, EPOLLIN); +} + +class network &unix_io_provider::get_network() { + return static_cast<class network &>(unix_network_); +} + +class event_loop &unix_io_provider::event_loop() { + return event_loop_; +} + +} // namespace unix + +error_or<async_io_context> setup_async_io() { + using namespace unix; + try { + own<unix_event_port> prt = heap<unix_event_port>(); + unix_event_port &prt_ref = *prt; + + own<unix_io_provider> io_provider = + heap<unix_io_provider>(prt_ref, std::move(prt)); + + event_loop &loop_ref = io_provider->event_loop(); + + return {{std::move(io_provider), loop_ref, prt_ref}}; + } catch (std::bad_alloc &) { + return make_error<err::out_of_memory>(); + } +} +} // namespace saw +#endif |