forstio/driver/io_unix.h

471 lines
11 KiB
C++

#pragma once
#ifndef SAW_UNIX
#error "Don't include this"
#endif
#include <errno.h>
#include <fcntl.h>
#include <netdb.h>
#include <netinet/in.h>
#include <sys/epoll.h>
#include <sys/signalfd.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <sys/un.h>
#include <unistd.h>

#include <cassert>
#include <chrono>
#include <csignal>
#include <cstdint>
#include <cstring>
#include <limits>
#include <memory>
#include <queue>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include "io.h"
namespace saw {
namespace unix {
/// Capacity of the event buffer handed to a single epoll_wait() call.
constexpr int MAX_EPOLL_EVENTS{256};
// Forward declaration; the full definition follows i_fd_owner.
class unix_event_port;
/**
 * Abstract owner of a file descriptor that participates in a unix_event_port.
 *
 * The event port calls notify() with the epoll event bits reported for the
 * descriptor. The constructor and destructor are defined out of line.
 * NOTE(review): presumably they subscribe/unsubscribe the fd with the event
 * port; confirm against the implementation file.
 */
class i_fd_owner {
public:
    i_fd_owner(unix_event_port &event_port, int file_descriptor, int fd_flags,
               uint32_t event_mask);
    virtual ~i_fd_owner();

    /// Invoked by the event port with the epoll event mask that fired.
    virtual void notify(uint32_t mask) = 0;

    /// The descriptor passed to the constructor.
    int fd() const {
        return file_descriptor_;
    }

protected:
    unix_event_port &event_port_;

private:
    int file_descriptor_;
    int fd_flags_;
    uint32_t event_mask_;
};
class unix_event_port final : public event_port {
private:
int epoll_fd_;
int signal_fd_;
sigset_t signal_fd_set_;
std::unordered_multimap<Signal, own<conveyor_feeder<void>>>
signal_conveyors_;
int pipefds_[2];
std::vector<int> to_unix_signal(Signal signal) const {
switch (signal) {
case Signal::User1:
return {SIGUSR1};
case Signal::Terminate:
default:
return {SIGTERM, SIGQUIT, SIGINT};
}
}
Signal from_unix_signal(int signal) const {
switch (signal) {
case SIGUSR1:
return Signal::User1;
case SIGTERM:
case SIGINT:
case SIGQUIT:
default:
return Signal::Terminate;
}
}
void notify_signal_listener(int sig) {
Signal signal = from_unix_signal(sig);
auto equal_range = signal_conveyors_.equal_range(signal);
for (auto iter = equal_range.first; iter != equal_range.second;
++iter) {
if (iter->second) {
if (iter->second->space() > 0) {
iter->second->feed();
}
}
}
}
bool poll_impl(int time) {
epoll_event events[MAX_EPOLL_EVENTS];
int nfds = 0;
do {
nfds = epoll_wait(epoll_fd_, events, MAX_EPOLL_EVENTS, time);
if (nfds < 0) {
/// @todo error_handling
return false;
}
for (int i = 0; i < nfds; ++i) {
if (events[i].data.u64 == 0) {
while (1) {
struct ::signalfd_siginfo siginfo;
ssize_t n =
::read(signal_fd_, &siginfo, sizeof(siginfo));
if (n < 0) {
break;
}
assert(n == sizeof(siginfo));
notify_signal_listener(siginfo.ssi_signo);
}
} else if (events[i].data.u64 == 1) {
uint8_t i;
if (pipefds_[0] < 0) {
continue;
}
while (1) {
ssize_t n = ::recv(pipefds_[0], &i, sizeof(i), 0);
if (n < 0) {
break;
}
}
} else {
i_fd_owner *owner =
reinterpret_cast<i_fd_owner *>(events[i].data.ptr);
if (owner) {
owner->notify(events[i].events);
}
}
}
} while (nfds == MAX_EPOLL_EVENTS);
return true;
}
public:
unix_event_port() : epoll_fd_{-1}, signal_fd_{-1} {
::signal(SIGPIPE, SIG_IGN);
epoll_fd_ = ::epoll_create1(EPOLL_CLOEXEC);
if (epoll_fd_ < 0) {
return;
}
::sigemptyset(&signal_fd_set_);
signal_fd_ =
::signalfd(-1, &signal_fd_set_, SFD_NONBLOCK | SFD_CLOEXEC);
if (signal_fd_ < 0) {
return;
}
struct epoll_event event;
memset(&event, 0, sizeof(event));
event.events = EPOLLIN;
event.data.u64 = 0;
::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, signal_fd_, &event);
int rc = ::pipe2(pipefds_, O_NONBLOCK | O_CLOEXEC);
if (rc < 0) {
return;
}
memset(&event, 0, sizeof(event));
event.events = EPOLLIN;
event.data.u64 = 1;
::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, pipefds_[0], &event);
}
~unix_event_port() {
::close(epoll_fd_);
::close(signal_fd_);
::close(pipefds_[0]);
::close(pipefds_[1]);
}
conveyor<void> on_signal(Signal signal) override {
auto caf = new_conveyor_and_feeder<void>();
signal_conveyors_.insert(std::make_pair(signal, std::move(caf.feeder)));
std::vector<int> sig = to_unix_signal(signal);
for (auto iter = sig.begin(); iter != sig.end(); ++iter) {
::sigaddset(&signal_fd_set_, *iter);
}
::sigprocmask(SIG_BLOCK, &signal_fd_set_, nullptr);
::signalfd(signal_fd_, &signal_fd_set_, SFD_NONBLOCK | SFD_CLOEXEC);
auto node = conveyor<void>::from_conveyor(std::move(caf.conveyor));
return conveyor<void>::to_conveyor(std::move(node));
}
void poll() override { poll_impl(0); }
void wait() override { poll_impl(-1); }
void wait(const std::chrono::steady_clock::duration &duration) override {
poll_impl(
std::chrono::duration_cast<std::chrono::milliseconds>(duration)
.count());
}
void
wait(const std::chrono::steady_clock::time_point &time_point) override {
auto now = std::chrono::steady_clock::now();
if (time_point <= now) {
poll();
} else {
poll_impl(std::chrono::duration_cast<std::chrono::milliseconds>(
time_point - now)
.count());
}
}
void wake() override {
/// @todo pipe() in the beginning and write something minor into it like
/// uint8_t or sth the value itself doesn't matter
if (pipefds_[1] < 0) {
return;
}
uint8_t i = 0;
::send(pipefds_[1], &i, sizeof(i), MSG_DONTWAIT);
}
void subscribe(i_fd_owner &owner, int fd, uint32_t event_mask) {
if (epoll_fd_ < 0 || fd < 0) {
return;
}
::epoll_event event;
memset(&event, 0, sizeof(event));
event.events = event_mask | EPOLLET;
event.data.ptr = &owner;
if (::epoll_ctl(epoll_fd_, EPOLL_CTL_ADD, fd, &event) < 0) {
/// @todo error_handling
return;
}
}
void unsubscribe(int fd) {
if (epoll_fd_ < 0 || fd < 0) {
return;
}
if (::epoll_ctl(epoll_fd_, EPOLL_CTL_DEL, fd, nullptr) < 0) {
/// @todo error_handling
return;
}
}
};
// Raw-descriptor I/O helpers; the definitions live out of line.
// NOTE(review): presumably thin wrappers around ::read / ::write that return
// their result unchanged; confirm against the implementation file.
ssize_t unix_read(int fd, void *out, size_t count);
ssize_t unix_write(int fd, const void *in, size_t count);
/**
 * io_stream over a file descriptor owned through i_fd_owner.
 *
 * Readiness and disconnect conveyors are backed by the feeders below; the
 * event port reports fd activity through notify().
 */
class unix_io_stream final : public io_stream, public i_fd_owner {
public:
    unix_io_stream(unix_event_port &event_port, int file_descriptor,
                   int fd_flags, uint32_t event_mask);

    // Reading side.
    error_or<size_t> read(void *buffer, size_t length) override;
    conveyor<void> read_ready() override;
    conveyor<void> on_read_disconnected() override;

    // Writing side.
    error_or<size_t> write(const void *buffer, size_t length) override;
    conveyor<void> write_ready() override;

    // i_fd_owner callback with the epoll event mask.
    void notify(uint32_t mask) override;

private:
    // Feeders for the conveyors above. NOTE(review): presumably created
    // lazily by the corresponding accessor; they start out null.
    own<conveyor_feeder<void>> read_ready_ = nullptr;
    own<conveyor_feeder<void>> on_read_disconnect_ = nullptr;
    own<conveyor_feeder<void>> write_ready_ = nullptr;
};
/**
 * server bound to a file descriptor; accept() hands out a conveyor of
 * accepted io_streams, and notify() receives the fd's epoll events.
 */
class unix_server final : public server, public i_fd_owner {
public:
    unix_server(unix_event_port &event_port, int file_descriptor, int fd_flags);

    conveyor<own<io_stream>> accept() override;

    void notify(uint32_t mask) override;

private:
    // NOTE(review): presumably feeds the conveyor returned by accept(); null
    // until then.
    own<conveyor_feeder<own<io_stream>>> accept_feeder_ = nullptr;
};
/**
 * datagram socket over a file descriptor owned through i_fd_owner.
 * write() takes an explicit destination network_address per packet.
 */
class unix_datagram final : public datagram, public i_fd_owner {
public:
    unix_datagram(unix_event_port &event_port, int file_descriptor,
                  int fd_flags);

    // Receiving.
    error_or<size_t> read(void *buffer, size_t length) override;
    conveyor<void> read_ready() override;

    // Sending.
    error_or<size_t> write(const void *buffer, size_t length,
                           network_address &dest) override;
    conveyor<void> write_ready() override;

    // i_fd_owner callback with the epoll event mask.
    void notify(uint32_t mask) override;

private:
    // Feeders backing read_ready() / write_ready(); null initially.
    own<conveyor_feeder<void>> read_ready_ = nullptr;
    own<conveyor_feeder<void>> write_ready_ = nullptr;
};
/**
 * Helper class which provides potential addresses to network_address.
 *
 * Stores a copy of a raw sockaddr (unix, IPv4 or IPv6) plus its length. The
 * wildcard flag marks addresses resolved from "*" or "::", whose IPv6 sockets
 * are bound dual-stack.
 */
class socket_address {
private:
    union {
        struct sockaddr generic;
        struct sockaddr_un unix;
        struct sockaddr_in inet;
        struct sockaddr_in6 inet6;
        struct sockaddr_storage storage;
    } address_;
    socklen_t address_length_;
    bool wildcard_;

    socket_address() : address_length_{0}, wildcard_{false} {}

public:
    /**
     * Copies @p len bytes from @p sockaddr into internal storage.
     * @param sockaddr pointer to a sockaddr_* structure
     * @param len      byte length of that structure; must fit the internal
     *                 union (checked by assert only)
     * @param wildcard whether this is a wildcard listen address
     */
    socket_address(const void *sockaddr, socklen_t len, bool wildcard)
        : address_length_{len}, wildcard_{wildcard} {
        assert(len <= sizeof(address_));
        memcpy(&address_.generic, sockaddr, len);
    }

    /// Creates a non-blocking, close-on-exec socket of @p type matching this
    /// address's family. Returns the fd, or -1 on error.
    int socket(int type) const {
        type |= SOCK_NONBLOCK | SOCK_CLOEXEC;
        return ::socket(address_.generic.sa_family, type, 0);
    }

    /**
     * Binds @p fd to this address.
     * @return true on FAILURE, false on success. Note the inverted sense;
     *         callers presumably rely on it, so audit them before flipping it.
     */
    bool bind(int fd) const {
        if (wildcard_ && address_.generic.sa_family == AF_INET6) {
            // Let the IPv6 wildcard socket also accept IPv4(-mapped) peers.
            // The option only exists for IPv6 sockets.
            int value = 0;
            ::setsockopt(fd, IPPROTO_IPV6, IPV6_V6ONLY, &value, sizeof(value));
        }
        return ::bind(fd, &address_.generic, address_length_) < 0;
    }

    struct ::sockaddr *get_raw() {
        return &address_.generic;
    }
    const struct ::sockaddr *get_raw() const { return &address_.generic; }
    socklen_t get_raw_length() const { return address_length_; }

    /**
     * Resolves @p str with getaddrinfo().
     *
     * "*" requests the passive wildcard addresses. getaddrinfo() cannot parse
     * "*" as a host, so it is passed as a null node with AI_PASSIVE. "*" and
     * "::" both mark the results as wildcard.
     *
     * @param str       host name or numeric address, or "*"
     * @param port_hint port used as the service
     * @return every resolved address that fits the storage; empty on failure
     */
    static std::vector<socket_address> resolve(std::string_view str,
                                               uint16_t port_hint) {
        const bool any_host = str == "*";
        const bool wildcard = any_host || str == "::";

        struct ::addrinfo hints;
        memset(&hints, 0, sizeof(hints));
        hints.ai_family = AF_UNSPEC;
        if (any_host) {
            hints.ai_flags |= AI_PASSIVE;
        }

        const std::string address_string{str};
        const std::string port_string = std::to_string(port_hint);
        struct ::addrinfo *head = nullptr;
        if (::getaddrinfo(any_host ? nullptr : address_string.c_str(),
                          port_string.c_str(), &hints, &head) != 0) {
            return {};
        }
        // Free the list on every exit path, including a throwing emplace_back.
        std::unique_ptr<struct ::addrinfo, decltype(&::freeaddrinfo)> guard{
            head, &::freeaddrinfo};

        std::vector<socket_address> results;
        for (const struct ::addrinfo *it = head; it != nullptr;
             it = it->ai_next) {
            if (it->ai_addrlen > sizeof(socket_address::address_)) {
                continue;
            }
            results.emplace_back(it->ai_addr, it->ai_addrlen, wildcard);
        }
        return results;
    }
};
/**
 * os_network_address holding the socket_address list produced by resolution,
 * together with the original address text (path_) and port hint.
 */
class unix_network_address final : public os_network_address {
public:
    unix_network_address(const std::string &path, uint16_t port_hint,
                         std::vector<socket_address> &&addr)
        : path_{path},
          port_hint_{port_hint},
          addresses_{std::move(addr)} {}

    const std::string &address() const override;
    uint16_t port() const override;

    /// Access to the underlying resolved socket addresses (index @p i).
    socket_address &unix_address(size_t i = 0);
    /// Number of resolved socket addresses.
    size_t unix_address_size() const;

private:
    const std::string path_;
    uint16_t port_hint_;
    std::vector<socket_address> addresses_;
};
/**
 * network implementation for unix systems, bound to a unix_event_port for
 * readiness notifications of the sockets it creates.
 */
class unix_network final : public network {
private:
    unix_event_port &event_port_;

public:
    /// explicit: a unix_event_port must not silently convert into a network.
    explicit unix_network(unix_event_port &event_port);

    /// Resolves @p address (with @p port_hint as default port).
    conveyor<own<network_address>>
    resolve_address(const std::string &address,
                    uint16_t port_hint = 0) override;

    own<server> listen(network_address &addr) override;
    conveyor<own<io_stream>> connect(network_address &addr) override;
    // `class` is required: the member function below hides the type name.
    own<class datagram> datagram(network_address &addr) override;
};
/**
 * io_provider for unix systems. Holds a reference to the unix_event_port,
 * an event_loop and the unix_network built on that port.
 */
class unix_io_provider final : public io_provider {
public:
    unix_io_provider(unix_event_port &port_ref, own<event_port> port);

    class network &network() override;
    own<input_stream> wrap_input_fd(int fd) override;

    class event_loop &event_loop();

private:
    // Members are initialised in declaration order; keep this order intact.
    unix_event_port &event_port_;
    class event_loop event_loop_;
    unix_network unix_network_;
};
} // namespace unix
} // namespace saw