/*
 * NOTE(review): in this copy of the file many template parameter / argument
 * lists appear to have been stripped (e.g. "template friend struct
 * ProtoKelEncodeImpl;", "template > Error encode(...)",
 * "ProtoKelEncodeImpl>::", "std::enable_if::type"). As shown the header
 * cannot compile; restore those lists from version control rather than
 * guessing them. Comments below describe only what the remaining code shows.
 */
#pragma once
#include "buffer.h"
#include "message.h"
#include "stream_endian.h"

namespace saw {
/// @todo replace types with these
// NOTE(review): the three aliases below are already used for union ids,
// array lengths and the packet header; the string length prefix further
// down still uses size_t.
/*
 * A union id is stored as uint32_t: a union with more than UINT32_MAX
 * alternatives is not expected to exist.
 */
using msg_union_id_t = uint32_t;
using msg_array_length_t = uint64_t;
using msg_packet_length_t = uint64_t;

/**
 * Binary ("ProtoKel") codec for saw messages.
 *
 * A packet on the wire is a msg_packet_length_t header holding the payload
 * length, followed by the payload produced by the ProtoKelEncodeImpl
 * specialisations below.
 */
class ProtoKelCodec {
private:
	// NOTE(review): not referenced anywhere in this file; possibly dead code.
	struct ReadContext {
		Buffer &buffer;
		size_t offset = 0;
	};

	template friend struct ProtoKelEncodeImpl;
	template friend struct ProtoKelDecodeImpl;

public:
	/// Limits enforced by decode().
	struct Limits {
		/// Largest payload length (the packet header value) decode() accepts.
		/// Defaults to 4096.
		msg_packet_length_t packet_size;

		Limits() : packet_size{4096} {}
		Limits(msg_packet_length_t ps) : packet_size{ps} {}
	};

	/// Codec version triple.
	struct Version {
		size_t major;
		size_t minor;
		size_t security;
	};

	/// Returns the codec version; currently always {0, 0, 0}.
	// NOTE(review): a top-level const on a by-value return has no useful
	// effect for callers and blocks moving from the result.
	const Version version() const { return Version{0, 0, 0}; }

	/**
	 * Serialises `reader` into `buffer` as one length-prefixed packet.
	 * Returns a failed Error if space cannot be reserved or a member encoder
	 * fails. The buffer's write position is advanced only after the whole
	 * packet has been written.
	 */
	template >
	Error encode(typename Message::Reader reader, Buffer &buffer);

	/**
	 * Parses one length-prefixed packet from `buffer` into `builder`.
	 * Fails if the header length exceeds `limits.packet_size`, if a member
	 * decoder fails, or if the decoded message's encoded size differs from
	 * the header length. The buffer's read position is advanced only on
	 * success.
	 */
	template >
	Error decode(typename Message::Builder builder, Buffer &buffer,
	             const Limits &limits = Limits{});
};

/*
 * Encode implementations.
 *
 * Each specialisation provides
 *   static Error  encode(Reader, Buffer &) - append the wire form
 *   static size_t size(Reader)             - bytes encode() will write
 * ProtoKelCodec::encode writes size() as the packet header and
 * ProtoKelCodec::decode re-checks it, so size() must match encode() exactly.
 */
template struct ProtoKelEncodeImpl;

// Primitive values (presumably schema primitives; the template arguments are
// stripped here): delegated to StreamValue for the underlying native type,
// with StreamValue's size() as the encoded size.
template
struct ProtoKelEncodeImpl, Container>> {
	static Error encode(typename Message, Container>::Reader data,
	                    Buffer &buffer) {
		Error error = StreamValue>::Type>::encode(data.get(), buffer);
		return error;
	}
	static size_t size(typename Message, Container>::Reader) {
		return StreamValue>::Type>::size();
	}
};

// Strings: a length prefix (sizeof(size_t) bytes) followed by the raw bytes.
// NOTE(review): the prefix width is sizeof(size_t), so the wire format
// differs between 32- and 64-bit builds. The fixed-width msg_* aliases above
// suggest a fixed-width type was intended; confirm.
template
struct ProtoKelEncodeImpl> {
	static Error encode(typename Message::Reader data, Buffer &buffer) {
		std::string_view view = data.get();
		size_t size = view.size();
		// Check capacity for prefix + payload before writing anything.
		Error error = buffer.writeRequireLength(sizeof(size) + size);
		if (error.failed()) {
			return error;
		}
		error = StreamValue::encode(size, buffer);
		if (error.failed()) {
			return error;
		}
		// Copy the payload at the current write position, then commit it.
		for (size_t i = 0; i < view.size(); ++i) {
			buffer.write(i) = view[i];
		}
		buffer.writeAdvance(view.size());
		return noError();
	}
	static size_t size(typename Message::Reader reader) {
		return sizeof(size_t) + reader.get().size();
	}
};

// Composite over the parameter pack T... (template arguments stripped here;
// presumably the tuple specialisation, confirm).
// Layout: members are written back to back in index order, with no
// per-member header.
// encodeMembers/sizeMembers recurse over the member index i. The overload
// whose enable_if condition was stripped (presumably i == sizeof...(T))
// ends the recursion.
template
struct ProtoKelEncodeImpl, Container>> {
	template
	static typename std::enable_if::type
	encodeMembers(typename Message, Container>::Reader, Buffer &) {
		return noError();
	}
	template
	static typename std::enable_if<(i < sizeof...(T)), Error>::type
	encodeMembers(typename Message, Container>::Reader data,
	              Buffer &buffer) {
		Error error = ProtoKelEncodeImpl>::
		    encode(data.template get(), buffer);
		if (error.failed()) {
			return error;
		}
		return encodeMembers(data, buffer);
	}

	static Error encode(typename Message, Container>::Reader data,
	                    Buffer &buffer) {
		return encodeMembers<0>(data, buffer);
	}

	template
	static typename std::enable_if::type
	sizeMembers(typename Message, Container>::Reader) {
		return 0;
	}

	template
	static typename std::enable_if < i::type sizeMembers(
	    typename Message, Container>::Reader reader) {
		return ProtoKelEncodeImpl>::
		           size(reader.template get()) +
		       sizeMembers(reader);
	}

	// Total size is the sum of the member sizes.
	static size_t size(typename Message, Container>::Reader reader) {
		return sizeMembers<0>(reader);
	}
};

// Composite whose schema argument expands a member pack ("Message...>";
// presumably the struct specialisation with named members, confirm).
// Same back-to-back, index-ordered layout as the specialisation above.
template
struct ProtoKelEncodeImpl<
    Message...>, Container>> {
	template
	static typename std::enable_if::type
	encodeMembers(typename Message...>, Container>::Reader, Buffer &) {
		return noError();
	}
	template
	static typename std::enable_if < i::type encodeMembers(
	    typename Message...>, Container>::Reader data, Buffer &buffer) {
		Error error = ProtoKelEncodeImpl>::
		    encode(data.template get(), buffer);
		if (error.failed()) {
			return error;
		}
		return encodeMembers(data, buffer);
	}

	static Error encode(typename Message...>, Container>::Reader data,
	                    Buffer &buffer) {
		return encodeMembers<0>(data, buffer);
	}

	template
	static typename std::enable_if::type
	sizeMembers(typename Message...>, Container>::Reader) {
		return 0;
	}

	template
	static typename std::enable_if < i::type sizeMembers(
	    typename Message...>, Container>::Reader reader) {
		return ProtoKelEncodeImpl>::
		           size(reader.template get()) +
		       sizeMembers(reader);
	}

	static size_t size(typename Message...>, Container>::Reader reader) {
		return sizeMembers<0>(reader);
	}
};

// Unions: the active alternative's index followed by that alternative's
// encoding. encodeMembers walks i until it equals reader.index().
// NOTE(review): if reader.index() matches no alternative (e.g. an unset
// union, if that can occur; confirm), encode() writes nothing and returns
// noError(), while size() still reports sizeof(msg_union_id_t). The packet
// header would then disagree with the payload.
template
struct ProtoKelEncodeImpl<
    Message...>,
    Container>> {
	template
	static typename std::enable_if::type
	encodeMembers(typename Message...>, Container>::Reader, Buffer &) {
		return noError();
	}
	template
	static typename std::enable_if < i::type encodeMembers(
	    typename Message...>, Container>::Reader reader, Buffer &buffer) {
		if (reader.index() == i) {
			// Alternative id first, then the payload of that alternative.
			Error error = StreamValue::encode(i, buffer);
			if (error.failed()) {
				return error;
			}
			return ProtoKelEncodeImpl>::encode(reader.template get(), buffer);
		}
		return encodeMembers(reader, buffer);
	}

	static Error encode(typename Message...>, Container>::Reader reader,
	                    Buffer &buffer) {
		return encodeMembers<0>(reader, buffer);
	}

	template
	static typename std::enable_if::type
	sizeMembers(typename Message...>, Container>::Reader) {
		return 0;
	}

	template
	static typename std::enable_if < i::type sizeMembers(
	    typename Message...>, Container>::Reader reader) {
		if (reader.index() == i) {
			return ProtoKelEncodeImpl>::size(reader.template get());
		}
		return sizeMembers(reader);
	}

	/*
	 * Size of the union id (msg_union_id_t) plus the size of the active
	 * alternative.
	 */
	static size_t size(typename Message...>, Container>::Reader reader) {
		return sizeof(msg_union_id_t) + sizeMembers<0>(reader);
	}
};

// Arrays: the element count as msg_array_length_t, followed by each element
// in order.
template
struct ProtoKelEncodeImpl, Container>> {
	static Error encode(typename Message, Container>::Reader data,
	                    Buffer &buffer) {
		msg_array_length_t array_length = data.size();
		{
			Error error = StreamValue::encode(array_length, buffer);
			if (error.failed()) {
				return error;
			}
		}
		for (size_t i = 0; i < array_length; ++i) {
			Error error = ProtoKelEncodeImpl::encode(
			    data.get(i), buffer);
			if (error.failed()) {
				return error;
			}
		}
		return noError();
	}

	/*
	 * Length header plus the sum of all element sizes. Runs in linear time
	 * in the number of elements.
	 */
	static size_t size(typename Message, Container>::Reader data) {
		size_t members = sizeof(msg_array_length_t);
		for (size_t i = 0; i < data.size(); ++i) {
			members += ProtoKelEncodeImpl::size(
			    data.get(i));
		}
		return members;
	}
};

/*
 * Decode implementations. They mirror the encoders above: each
 * decode(Builder, Buffer &) reads the wire form written by the matching
 * encode() and fills the builder.
 */
template struct ProtoKelDecodeImpl;

// Primitive values: decoded via StreamValue into a zero-initialised local.
// NOTE(review): data.set(val) runs even when decode failed, so the builder
// is assigned before the error is returned. The assigned value is
// presumably 0 or a partial value, depending on StreamValue.
template
struct ProtoKelDecodeImpl, Container>> {
	static Error decode(typename Message, Container>::Builder data,
	                    Buffer &buffer) {
		typename PrimitiveTypeHelper>::Type val = 0;
		Error error = StreamValue>::Type>::decode(val, buffer);
		data.set(val);
		return error;
	}
};

// Strings: read the length prefix, then copy that many bytes.
// Short input yields a recoverableError("Buffer too small").
template
struct ProtoKelDecodeImpl> {
	static Error decode(typename Message::Builder data, Buffer &buffer) {
		size_t size = 0;
		// Enough bytes for the length prefix?
		if (sizeof(size) > buffer.readCompositeLength()) {
			return recoverableError("Buffer too small");
		}
		Error error = StreamValue::decode(size, buffer);
		if (error.failed()) {
			return error;
		}
		// Enough bytes for the payload? This check runs before allocating
		// `value`, so the declared size is bounded by the readable length.
		if (size > buffer.readCompositeLength()) {
			return recoverableError("Buffer too small");
		}
		std::string value;
		value.resize(size);
		// NOTE(review): identical to the previous check. Nothing between
		// them touches `buffer`, so this branch can never be taken.
		if (size > buffer.readCompositeLength()) {
			return recoverableError("Buffer too small");
		}
		for (size_t i = 0; i < value.size(); ++i) {
			value[i] = buffer.read(i);
		}
		buffer.readAdvance(value.size());
		data.set(std::move(value));
		return noError();
	}
};

// Composite over the parameter pack T... (see the matching encoder):
// each member is initialised through the builder and decoded in index order.
template
struct ProtoKelDecodeImpl, Container>> {
	template
	static typename std::enable_if::type
	decodeMembers(typename Message, Container>::Builder, Buffer &) {
		return noError();
	}
	template
	static typename std::enable_if < i::type decodeMembers(
	    typename Message, Container>::Builder builder, Buffer &buffer) {
		Error error = ProtoKelDecodeImpl>::
		    decode(builder.template init(), buffer);
		if (error.failed()) {
			return error;
		}
		return decodeMembers(builder, buffer);
	}

	static Error decode(typename Message, Container>::Builder builder,
	                    Buffer &buffer) {
		return decodeMembers<0>(builder, buffer);
	}
};

// Composite with an expanded member pack (see the matching encoder):
// members are decoded in index order.
template
struct ProtoKelDecodeImpl<
    Message...>, Container>> {
	template
	static typename std::enable_if::type
	decodeMembers(typename Message...>, Container>::Builder, Buffer &) {
		return noError();
	}
	template
	static typename std::enable_if < i::type decodeMembers(
	    typename Message...>, Container>::Builder builder, Buffer &buffer) {
		Error error = ProtoKelDecodeImpl>::
		    decode(builder.template init(), buffer);
		if (error.failed()) {
			return error;
		}
		return decodeMembers(builder, buffer);
	}

	static Error decode(typename Message...>, Container>::Builder
	                        builder,
	                    Buffer &buffer) {
		return decodeMembers<0>(builder, buffer);
	}
};

// Unions: read the msg_union_id_t alternative id and reject ids
// >= sizeof...(V) with a criticalError. Then decode only the alternative
// whose index equals the id; the remaining recursion steps do nothing.
template
struct ProtoKelDecodeImpl<
    Message...>, Container>> {
	template
	static typename std::enable_if::type
	decodeMembers(typename Message...>, Container>::Builder, Buffer &,
	              msg_union_id_t) {
		return noError();
	}
	template
	static typename std::enable_if < i::type decodeMembers(
	    typename Message...>, Container>::Builder builder, Buffer &buffer,
	    msg_union_id_t id) {
		if (id == i) {
			Error error = ProtoKelDecodeImpl>::decode(builder.template init(), buffer);
			if (error.failed()) {
				return error;
			}
		}
		return decodeMembers(builder, buffer, id);
	}

	static Error decode(typename Message...>, Container>::Builder builder,
	                    Buffer &buffer) {
		msg_union_id_t id = 0;
		Error error = StreamValue::decode(id, buffer);
		if (error.failed()) {
			return error;
		}
		if (id >= sizeof...(V)) {
			return criticalError("Union doesn't have this many id's");
		}

		return decodeMembers<0>(builder, buffer, id);
	}
};

// Arrays: read the msg_array_length_t element count, size the builder, then
// decode each element in order.
template
struct ProtoKelDecodeImpl, Container>> {
	static Error decode(typename Message, Container>::Builder data,
	                    Buffer &buffer) {
		msg_array_length_t array_length = 0;
		{
			Error error = StreamValue::decode(array_length, buffer);
			if (error.failed()) {
				return error;
			}
		}
		// NOTE(review): array_length comes straight from the input and is
		// not bounded here; Limits::packet_size is not passed down. A
		// malformed packet can request a resize to an enormous element count
		// before any element fails to decode. Consider bounding it before
		// resizing.
		data.resize(array_length);
		for (size_t i = 0; i < array_length; ++i) {
			Error error = ProtoKelDecodeImpl::decode(
			    data.init(i), buffer);
			if (error.failed()) {
				return error;
			}
		}
		return noError();
	}
};

template
Error ProtoKelCodec::encode(typename Message::Reader reader, Buffer &buffer) {
	// All writes go through a view; `buffer` is only advanced (below) after
	// header and payload have both been written successfully.
	BufferView view{buffer};

	// Payload size; also the value of the packet header.
	msg_packet_length_t packet_length =
	    ProtoKelEncodeImpl>::size(reader);
	// Make sure the buffer can hold the length header plus the payload
	// before writing anything.
	Error error =
	    view.writeRequireLength(packet_length + sizeof(msg_packet_length_t));
	if (error.failed()) {
		return error;
	}

	{
		Error error = StreamValue::encode(packet_length, view);
		if (error.failed()) {
			return error;
		}
	}
	{
		Error error = ProtoKelEncodeImpl>::encode(
		    reader, view);
		if (error.failed()) {
			return error;
		}
	}

	// Commit everything written through the view.
	buffer.writeAdvance(view.writeOffset());
	return noError();
}

template
Error ProtoKelCodec::decode(
    typename Message::Builder builder, Buffer &buffer,
    const Limits &limits) {
	// All reads go through a view; `buffer` is only advanced (below) once the
	// packet has been decoded and validated.
	BufferView view{buffer};

	msg_packet_length_t packet_length = 0;
	{
		Error error = StreamValue::decode(packet_length, view);
		if (error.failed()) {
			return error;
		}
	}

	// Reject oversized packets before decoding the payload.
	if (packet_length > limits.packet_size) {
		return criticalError(
		    [packet_length]() {
			    return std::string{"Packet size too big: "} +
			           std::to_string(packet_length);
		    },
		    "Packet size too big");
	}

	// NOTE(review): the member decoders check lengths against the view
	// (e.g. readCompositeLength), not against packet_length. A payload that
	// runs past the declared length is presumably only caught by the size
	// cross-check below; confirm BufferView's bounds.
	{
		Error error = ProtoKelDecodeImpl>::decode(
		    builder, view);
		if (error.failed()) {
			return error;
		}
	}
	// Cross-check: re-encoding the decoded message must yield exactly the
	// header length, otherwise the packet is malformed.
	{
		if (ProtoKelEncodeImpl>::size(
		        builder.asReader()) != packet_length) {
			return criticalError("Bad packet format");
		}
	}

	// Consume the packet from the underlying buffer.
	buffer.readAdvance(view.readOffset());
	return noError();
}

} // namespace saw