moved logic from networkaddress to network.

fb-udp
Claudius Holeksa 2022-02-03 00:12:59 +01:00
parent 6ba4e70778
commit d2b0178f79
5 changed files with 140 additions and 110 deletions

View File

@ -197,15 +197,58 @@ bool beginsWith(const std::string_view &viewed,
return viewed.size() >= begins.size() &&
viewed.compare(0, begins.size(), begins) == 0;
}
std::variant<UnixNetworkAddress, UnixNetworkAddress *>
translateNetworkAddressToUnixNetworkAddress(NetworkAddress &addr) {
auto addr_variant = addr.representation();
std::variant<UnixNetworkAddress, UnixNetworkAddress *> os_addr = std::visit(
[](auto &arg)
-> std::variant<UnixNetworkAddress, UnixNetworkAddress *> {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, OsNetworkAddress *>) {
return static_cast<UnixNetworkAddress *>(arg);
}
auto sock_addrs = SocketAddress::parse(
std::string_view{arg->address()}, arg->port());
return UnixNetworkAddress{arg->address(), arg->port(),
std::move(sock_addrs)};
},
addr_variant);
return os_addr;
}
UnixNetworkAddress &translateToUnixAddressRef(
std::variant<UnixNetworkAddress, UnixNetworkAddress *> &addr_variant) {
return std::visit(
[](auto &arg) -> UnixNetworkAddress & {
using T = std::decay_t<decltype(arg)>;
if constexpr (std::is_same_v<T, UnixNetworkAddress>) {
return arg;
} else if constexpr (std::is_same_v<T, UnixNetworkAddress *>) {
return *arg;
} else {
static_assert(true, "Cases exhausted");
}
},
addr_variant);
}
} // namespace
Own<Server> UnixNetworkAddress::listen() {
assert(addresses.size() > 0);
if (addresses.size() == 0) {
Own<Server> UnixNetwork::listen(NetworkAddress &addr) {
auto unix_addr_storage = translateNetworkAddressToUnixNetworkAddress(addr);
UnixNetworkAddress &address = translateToUnixAddressRef(unix_addr_storage);
assert(address.unixAddressSize() > 0);
if (address.unixAddressSize() == 0) {
return nullptr;
}
int fd = addresses.front().socket(SOCK_STREAM);
int fd = address.unixAddress(0).socket(SOCK_STREAM);
if (fd < 0) {
return nullptr;
}
@ -217,7 +260,7 @@ Own<Server> UnixNetworkAddress::listen() {
return nullptr;
}
bool failed = addresses.front().bind(fd);
bool failed = address.unixAddress(0).bind(fd);
if (failed) {
::close(fd);
return nullptr;
@ -228,13 +271,16 @@ Own<Server> UnixNetworkAddress::listen() {
return heap<UnixServer>(event_port, fd, 0);
}
Conveyor<Own<IoStream>> UnixNetworkAddress::connect() {
assert(addresses.size() > 0);
if (addresses.size() == 0) {
Conveyor<Own<IoStream>> UnixNetwork::connect(NetworkAddress &addr) {
auto unix_addr_storage = translateNetworkAddressToUnixNetworkAddress(addr);
UnixNetworkAddress &address = translateToUnixAddressRef(unix_addr_storage);
assert(address.unixAddressSize() > 0);
if (address.unixAddressSize() == 0) {
return Conveyor<Own<IoStream>>{criticalError("No address found")};
}
int fd = addresses.front().socket(SOCK_STREAM);
int fd = address.unixAddress(0).socket(SOCK_STREAM);
if (fd < 0) {
return Conveyor<Own<IoStream>>{criticalError("Couldn't open socket")};
}
@ -243,8 +289,10 @@ Conveyor<Own<IoStream>> UnixNetworkAddress::connect() {
heap<UnixIoStream>(event_port, fd, 0, EPOLLIN | EPOLLOUT);
bool success = false;
for (auto iter = addresses.begin(); iter != addresses.end(); ++iter) {
int status = ::connect(fd, iter->getRaw(), iter->getRawLength());
for (size_t i = 0; i < address.unixAddressSize(); ++i) {
SocketAddress &addr_iter = address.unixAddress(i);
int status =
::connect(fd, addr_iter.getRaw(), addr_iter.getRawLength());
if (status < 0) {
int error = errno;
/*
@ -283,10 +331,13 @@ Conveyor<Own<IoStream>> UnixNetworkAddress::connect() {
return Conveyor<Own<IoStream>>{std::move(io_stream)};
}
Own<Datagram> UnixNetworkAddress::datagram() {
SAW_ASSERT(addresses.size() > 0) { return nullptr; }
Own<Datagram> UnixNetwork::datagram(NetworkAddress &addr) {
auto unix_addr_storage = translateNetworkAddressToUnixNetworkAddress(addr);
UnixNetworkAddress &address = translateToUnixAddressRef(unix_addr_storage);
int fd = addresses.front().socket(SOCK_DGRAM);
SAW_ASSERT(address.unixAddressSize() > 0) { return nullptr; }
int fd = address.unixAddress(0).socket(SOCK_DGRAM);
int optval = 1;
int rc =
@ -296,7 +347,7 @@ Own<Datagram> UnixNetworkAddress::datagram() {
return nullptr;
}
bool failed = addresses.front().bind(fd);
bool failed = address.unixAddress(0).bind(fd);
if (failed) {
::close(fd);
return nullptr;
@ -305,18 +356,6 @@ Own<Datagram> UnixNetworkAddress::datagram() {
return heap<UnixDatagram>(event_port, fd, 0);
}
std::string UnixNetworkAddress::toString() const {
try {
std::ostringstream oss;
oss << "Address: " << path;
if (port_hint > 0) {
oss << "\nPort: " << port_hint;
}
return oss.str();
} catch (std::bad_alloc &) {
return {};
}
}
const std::string &UnixNetworkAddress::address() const { return path; }
uint16_t UnixNetworkAddress::port() const { return port_hint; }
@ -344,8 +383,8 @@ Conveyor<Own<NetworkAddress>> UnixNetwork::parseAddress(const std::string &path,
std::vector<SocketAddress> addresses =
SocketAddress::parse(addr_view, port_hint);
return Conveyor<Own<NetworkAddress>>{heap<UnixNetworkAddress>(
event_port, path, port_hint, std::move(addresses))};
return Conveyor<Own<NetworkAddress>>{
heap<UnixNetworkAddress>(path, port_hint, std::move(addresses))};
}
UnixIoProvider::UnixIoProvider(UnixEventPort &port_ref, Own<EventPort> port)

View File

@ -410,25 +410,16 @@ public:
}
};
class UnixNetworkAddress final : public NetworkAddress {
class UnixNetworkAddress final : public OsNetworkAddress {
private:
UnixEventPort &event_port;
const std::string path;
uint16_t port_hint;
std::vector<SocketAddress> addresses;
public:
UnixNetworkAddress(UnixEventPort &event_port, const std::string &path,
uint16_t port_hint, std::vector<SocketAddress> &&addr)
: event_port{event_port}, path{path}, port_hint{port_hint},
addresses{std::move(addr)} {}
Own<Server> listen() override;
Conveyor<Own<IoStream>> connect() override;
Own<Datagram> datagram() override;
std::string toString() const override;
UnixNetworkAddress(const std::string &path, uint16_t port_hint,
std::vector<SocketAddress> &&addr)
: path{path}, port_hint{port_hint}, addresses{std::move(addr)} {}
const std::string &address() const override;
@ -448,6 +439,12 @@ public:
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &address,
uint16_t port_hint = 0) override;
Own<Server> listen(NetworkAddress &addr) override;
Conveyor<Own<IoStream>> connect(NetworkAddress &addr) override;
Own<Datagram> datagram(NetworkAddress &addr) override;
};
class UnixIoProvider final : public IoProvider {

View File

@ -5,6 +5,7 @@
#include "io_helpers.h"
#include <string>
#include <variant>
namespace saw {
/*
@ -110,37 +111,67 @@ public:
virtual Conveyor<void> writeReady() = 0;
};
class OsNetworkAddress;
class StringNetworkAddress;
class NetworkAddress {
public:
using ChildVariant =
std::variant<OsNetworkAddress *, StringNetworkAddress *>;
virtual ~NetworkAddress() = default;
/**
* Set up a listener on this address
*/
virtual Own<Server> listen() = 0;
/**
* Connect to a remote address
*/
virtual Conveyor<Own<IoStream>> connect() = 0;
/**
* Bind a datagram socket at this address.
*/
virtual Own<Datagram> datagram() = 0;
virtual std::string toString() const = 0;
virtual NetworkAddress::ChildVariant representation() = 0;
virtual const std::string &address() const = 0;
virtual uint16_t port() const = 0;
};
class OsNetworkAddress : public NetworkAddress {
public:
virtual ~OsNetworkAddress() = default;
NetworkAddress::ChildVariant representation() override { return this; }
};
class StringNetworkAddress final : public NetworkAddress {
private:
std::string address_value;
uint16_t port_value;
public:
StringNetworkAddress(const std::string &address, uint16_t port);
const std::string &address() const override;
uint16_t port() const override;
NetworkAddress::ChildVariant representation() override { return this; }
};
class Network {
public:
virtual ~Network() = default;
/**
* Parse the provided string and uint16 to the preferred storage method
*/
virtual Conveyor<Own<NetworkAddress>>
parseAddress(const std::string &addr, uint16_t port_hint = 0) = 0;
/**
* Set up a listener on this address
*/
virtual Own<Server> listen(NetworkAddress &bind_addr) = 0;
/**
* Connect to a remote address
*/
virtual Conveyor<Own<IoStream>> connect(NetworkAddress &address) = 0;
/**
* Bind a datagram socket at this address.
*/
virtual Own<Datagram> datagram(NetworkAddress &address) = 0;
};
class IoProvider {

View File

@ -155,24 +155,18 @@ public:
};
}
TlsNetworkAddress::TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_)
: internal{std::move(net_addr)}, host_name{host_name_}, tls{tls_} {}
Own<Server> TlsNetworkAddress::listen() {
SAW_ASSERT(internal) { return nullptr; }
return heap<TlsServer>(internal->listen());
Own<Server> TlsNetwork::listen(NetworkAddress& address) {
return heap<TlsServer>(internal.listen(address));
}
Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
SAW_ASSERT(internal) { return Conveyor<Own<IoStream>>{nullptr, nullptr}; }
Conveyor<Own<IoStream>> TlsNetwork::connect(NetworkAddress& address) {
// Helper setups
auto caf = newConveyorAndFeeder<Own<IoStream>>();
Own<TlsClientStreamHelper> helper = heap<TlsClientStreamHelper>(std::move(caf.feeder));
TlsClientStreamHelper* hlp_ptr = helper.get();
// Conveyor entangled structure
auto prim_conv = internal->connect().then([this, hlp_ptr](
auto prim_conv = internal.connect(address).then([this, hlp_ptr, addr = address.address()](
Own<IoStream> stream) -> ErrorOr<void> {
IoStream* inner_stream = stream.get();
auto tls_stream = heap<TlsIoStream>(std::move(stream));
@ -181,8 +175,6 @@ Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
gnutls_init(&session, GNUTLS_CLIENT);
const std::string &addr = this->address();
gnutls_server_name_set(session, GNUTLS_NAME_DNS, addr.c_str(),
addr.size());
@ -209,7 +201,7 @@ Conveyor<Own<IoStream>> TlsNetworkAddress::connect() {
return caf.conveyor.attach(std::move(helper));
}
Own<Datagram> TlsNetworkAddress::datagram(){
Own<Datagram> TlsNetwork::datagram(NetworkAddress& address){
///@unimplemented
return nullptr;
}
@ -243,26 +235,13 @@ static ssize_t forst_tls_pull_func(gnutls_transport_ptr_t p, void *data, size_t
return static_cast<ssize_t>(length.value());
}
const std::string &TlsNetworkAddress::address() const {
assert(internal);
return internal->address();
}
uint16_t TlsNetworkAddress::port() const {
assert(internal);
return internal->port(); }
std::string TlsNetworkAddress::toString() const { return internal->toString(); }
TlsNetwork::TlsNetwork(Network &network) : internal{network} {}
Conveyor<Own<NetworkAddress>> TlsNetwork::parseAddress(const std::string &addr,
uint16_t port) {
return internal.parseAddress(addr, port)
.then(
[this, addr, port](Own<NetworkAddress> net) -> Own<NetworkAddress> {
assert(net);
return heap<TlsNetworkAddress>(std::move(net), addr, tls);
});
/// @todo tls server name needed. Check validity. Won't matter later on, because gnutls should fail anyway. But
/// it's better to find the error source sooner rather than later
return internal.parseAddress(addr, port);
}
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network) {

View File

@ -33,27 +33,6 @@ public:
Conveyor<Own<IoStream>> accept() override;
};
class TlsNetworkAddress final : public NetworkAddress {
private:
Own<NetworkAddress> internal;
std::string host_name;
Tls &tls;
public:
TlsNetworkAddress(Own<NetworkAddress> net_addr, const std::string& host_name_, Tls &tls_);
Own<Server> listen() override;
Conveyor<Own<IoStream>> connect() override;
Own<Datagram> datagram() override;
std::string toString() const override;
const std::string &address() const override;
uint16_t port() const override;
};
class TlsNetwork final : public Network {
private:
Tls tls;
@ -62,8 +41,13 @@ private:
public:
TlsNetwork(Network &network);
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &addr,
uint16_t port = 0) override;
Conveyor<Own<NetworkAddress>> parseAddress(const std::string &addr, uint16_t port = 0) override;
Own<Server> listen(NetworkAddress& address) override;
Conveyor<Own<IoStream>> connect(NetworkAddress& address) override;
Own<Datagram> datagram(NetworkAddress& address) override;
};
std::optional<Own<TlsNetwork>> setupTlsNetwork(Network &network);