[Libreoffice-commits] online.git: common/Session.cpp common/Session.hpp kit/ChildSession.cpp kit/ChildSession.hpp kit/Kit.cpp net/Socket.cpp net/Socket.hpp net/SslSocket.hpp net/WebSocketHandler.hpp test/UnitWOPIVersionRestore.cpp tools/WebSocketDump.cpp wsd/ClientSession.cpp wsd/ClientSession.hpp wsd/DocumentBroker.cpp wsd/DocumentBroker.hpp wsd/LOOLWSD.cpp wsd/TestStubs.cpp
Michael Meeks (via logerrit)
logerrit at kemper.freedesktop.org
Wed Mar 11 15:48:24 UTC 2020
common/Session.cpp | 50 +++++++++++-------
common/Session.hpp | 38 +++++++++++--
kit/ChildSession.cpp | 14 ++---
kit/ChildSession.hpp | 22 +++++---
kit/Kit.cpp | 12 ++--
net/Socket.cpp | 8 +-
net/Socket.hpp | 109 ++++++++++++++++++++++++++++++++++------
net/SslSocket.hpp | 2
net/WebSocketHandler.hpp | 63 +++++++++++++++++++----
test/UnitWOPIVersionRestore.cpp | 1
tools/WebSocketDump.cpp | 2
wsd/ClientSession.cpp | 44 +++++++---------
wsd/ClientSession.hpp | 21 ++++---
wsd/DocumentBroker.cpp | 20 ++++---
wsd/DocumentBroker.hpp | 11 ++--
wsd/LOOLWSD.cpp | 53 +++++++++++--------
wsd/TestStubs.cpp | 8 +-
17 files changed, 332 insertions(+), 146 deletions(-)
New commits:
commit e924625cc1af8736505f363fc525d20a6373bb95
Author: Michael Meeks <michael.meeks at collabora.com>
AuthorDate: Fri Mar 6 17:43:46 2020 +0000
Commit: Michael Meeks <michael.meeks at collabora.com>
CommitDate: Wed Mar 11 16:48:03 2020 +0100
re-factor: Socket / WebSocketHandler.
Essentially we want to be able to separate low-level socket code
(e.g. TCP vs. UDS) from protocol handling (e.g. WebSocketHandler),
and from the client sessions themselves, which handle and send
messages and now implement the simple MessageHandlerInterface.
Some helpful renaming too:
s/SocketHandlerInterface/ProtocolHandlerInterface/
Change-Id: I58092b5e0b5792fda47498fb2c875851eada461d
Reviewed-on: https://gerrit.libreoffice.org/c/online/+/90138
Tested-by: Jenkins CollaboraOffice <jenkinscollaboraoffice at gmail.com>
Reviewed-by: Michael Meeks <michael.meeks at collabora.com>
diff --git a/common/Session.cpp b/common/Session.cpp
index 4b4c563d6..15dbe86d7 100644
--- a/common/Session.cpp
+++ b/common/Session.cpp
@@ -44,7 +44,9 @@ using namespace LOOLProtocol;
using Poco::Exception;
using std::size_t;
-Session::Session(const std::string& name, const std::string& id, bool readOnly) :
+Session::Session(const std::shared_ptr<ProtocolHandlerInterface> &protocol,
+ const std::string& name, const std::string& id, bool readOnly) :
+ MessageHandlerInterface(protocol),
_id(id),
_name(name),
_disconnected(false),
@@ -65,14 +67,26 @@ Session::~Session()
bool Session::sendTextFrame(const char* buffer, const int length)
{
+ if (!_protocol)
+ {
+ LOG_TRC("ERR - missing protocol " << getName() << ": Send: [" << getAbbreviatedMessage(buffer, length) << "].");
+ return false;
+ }
+
LOG_TRC(getName() << ": Send: [" << getAbbreviatedMessage(buffer, length) << "].");
- return sendMessage(buffer, length, WSOpCode::Text) >= length;
+ return _protocol->sendTextMessage(buffer, length) >= length;
}
bool Session::sendBinaryFrame(const char *buffer, int length)
{
+ if (!_protocol)
+ {
+ LOG_TRC("ERR - missing protocol " << getName() << ": Send: " << std::to_string(length) << " binary bytes.");
+ return false;
+ }
+
LOG_TRC(getName() << ": Send: " << std::to_string(length) << " binary bytes.");
- return sendMessage(buffer, length, WSOpCode::Binary) >= length;
+ return _protocol->sendBinaryMessage(buffer, length) >= length;
}
void Session::parseDocOptions(const StringVector& tokens, int& part, std::string& timestamp, std::string& doctemplate)
@@ -196,15 +210,20 @@ void Session::disconnect()
}
}
-void Session::shutdown(const WebSocketHandler::StatusCodes statusCode, const std::string& statusMessage)
+void Session::shutdown(bool goingAway, const std::string& statusMessage)
{
- LOG_TRC("Shutting down WS [" << getName() << "] with statusCode [" <<
- static_cast<unsigned>(statusCode) << "] and reason [" << statusMessage << "].");
+ LOG_TRC("Shutting down WS [" << getName() << "] " <<
+ (goingAway ? "going" : "normal") <<
+ " and reason [" << statusMessage << "].");
// See protocol.txt for this application-level close frame.
- sendMessage("close: " + statusMessage);
-
- WebSocketHandler::shutdown(statusCode, statusMessage);
+ if (_protocol)
+ {
+ // skip the queue; FIXME: should we flush SessionClient's queue ?
+ std::string closeMsg = "close: " + statusMessage;
+ _protocol->sendTextMessage(closeMsg, closeMsg.size());
+ _protocol->shutdown(goingAway, statusMessage);
+ }
}
void Session::handleMessage(const std::vector<char> &data)
@@ -238,21 +257,12 @@ void Session::handleMessage(const std::vector<char> &data)
void Session::getIOStats(uint64_t &sent, uint64_t &recv)
{
- std::shared_ptr<StreamSocket> socket = getSocket().lock();
- if (socket)
- socket->getIOStats(sent, recv);
- else
- {
- sent = 0;
- recv = 0;
- }
+ _protocol->getIOStats(sent, recv);
}
void Session::dumpState(std::ostream& os)
{
- WebSocketHandler::dumpState(os);
-
- os << "\t\tid: " << _id
+ os << "\t\tid: " << _id
<< "\n\t\tname: " << _name
<< "\n\t\tdisconnected: " << _disconnected
<< "\n\t\tisActive: " << _isActive
diff --git a/common/Session.hpp b/common/Session.hpp
index 6b5e93322..dbf75ad2f 100644
--- a/common/Session.hpp
+++ b/common/Session.hpp
@@ -64,7 +64,7 @@ public:
};
/// Base class of a WebSocket session.
-class Session : public WebSocketHandler
+class Session : public MessageHandlerInterface
{
public:
const std::string& getId() const { return _id; }
@@ -74,8 +74,32 @@ public:
virtual void setReadOnly() { _isReadOnly = true; }
bool isReadOnly() const { return _isReadOnly; }
+ /// overridden to prepend client ids on messages by the Kit
virtual bool sendBinaryFrame(const char* buffer, int length);
virtual bool sendTextFrame(const char* buffer, const int length);
+
+ /// Get notified that the underlying transports disconnected
+ void onDisconnect() override { /* ignore */ }
+
+ bool hasQueuedMessages() const override
+ {
+ // queued in Socket output buffer
+ return false;
+ }
+
+ // By default rely on the socket buffer.
+ void writeQueuedMessages() override
+ {
+ assert(false);
+ }
+
+ /// Sends a WebSocket Text message.
+ int sendMessage(const std::string& msg)
+ {
+ return sendTextFrame(msg.data(), msg.size());
+ }
+
+ // FIXME: remove synonym - and clean from WebSocketHandler too ... (?)
bool sendTextFrame(const std::string& text)
{
return sendTextFrame(text.data(), text.size());
@@ -98,12 +122,10 @@ public:
virtual void disconnect();
/// clean & normal shutdown
- void shutdownNormal(const std::string& statusMessage = "")
- { shutdown(WebSocketHandler::StatusCodes::NORMAL_CLOSE, statusMessage); }
+ void shutdownNormal(const std::string& statusMessage = "") { shutdown(false, statusMessage); }
/// abnormal / hash shutdown end-point going away
- void shutdownGoingAway(const std::string& statusMessage = "")
- { shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, statusMessage); }
+ void shutdownGoingAway(const std::string& statusMessage = "") { shutdown(true, statusMessage); }
bool isActive() const { return _isActive; }
void setIsActive(bool active) { _isActive = active; }
@@ -165,7 +187,8 @@ public:
}
protected:
- Session(const std::string& name, const std::string& id, bool readonly);
+ Session(const std::shared_ptr<ProtocolHandlerInterface> &handler,
+ const std::string& name, const std::string& id, bool readonly);
virtual ~Session();
/// Parses the options of the "load" command,
@@ -181,8 +204,7 @@ protected:
private:
- void shutdown(const WebSocketHandler::StatusCodes statusCode = WebSocketHandler::StatusCodes::NORMAL_CLOSE,
- const std::string& statusMessage = "");
+ void shutdown(bool goingAway = false, const std::string& statusMessage = "");
virtual bool _handleInput(const char* buffer, int length) = 0;
diff --git a/kit/ChildSession.cpp b/kit/ChildSession.cpp
index 682012a50..4842b8ffe 100644
--- a/kit/ChildSession.cpp
+++ b/kit/ChildSession.cpp
@@ -19,7 +19,6 @@
#include <Poco/JSON/Object.h>
#include <Poco/JSON/Parser.h>
-#include <Poco/Net/WebSocket.h>
#include <Poco/StreamCopier.h>
#include <Poco/URI.h>
#include <Poco/BinaryReader.h>
@@ -62,10 +61,12 @@ std::vector<unsigned char> decodeBase64(const std::string & inputBase64)
}
-ChildSession::ChildSession(const std::string& id,
- const std::string& jailId,
- DocumentManagerInterface& docManager) :
- Session("ToMaster-" + id, id, false),
+ChildSession::ChildSession(
+ const std::shared_ptr<ProtocolHandlerInterface> &protocol,
+ const std::string& id,
+ const std::string& jailId,
+ DocumentManagerInterface& docManager) :
+ Session(protocol, "ToMaster-" + id, id, false),
_jailId(jailId),
_docManager(&docManager),
_viewId(-1),
@@ -98,7 +99,8 @@ void ChildSession::disconnect()
LOG_WRN("Skipping unload on incomplete view.");
}
- Session::disconnect();
+// This shuts down the shared socket, which is not what we want.
+// Session::disconnect();
}
}
diff --git a/kit/ChildSession.hpp b/kit/ChildSession.hpp
index 9bb2b7d0f..c7a248546 100644
--- a/kit/ChildSession.hpp
+++ b/kit/ChildSession.hpp
@@ -199,9 +199,11 @@ public:
/// a new view) or nullptr (when first view).
/// jailId The JailID of the jail root directory,
// used by downloadas to construct jailed path.
- ChildSession(const std::string& id,
- const std::string& jailId,
- DocumentManagerInterface& docManager);
+ ChildSession(
+ const std::shared_ptr<ProtocolHandlerInterface> &protocol,
+ const std::string& id,
+ const std::string& jailId,
+ DocumentManagerInterface& docManager);
virtual ~ChildSession();
bool getStatus(const char* buffer, int length);
@@ -219,12 +221,22 @@ public:
bool sendTextFrame(const char* buffer, int length) override
{
+ if (!_docManager)
+ {
+ LOG_TRC("ERR dropping - client-" + getId() + ' ' + std::string(buffer, length));
+ return false;
+ }
const auto msg = "client-" + getId() + ' ' + std::string(buffer, length);
return _docManager->sendFrame(msg.data(), msg.size(), WSOpCode::Text);
}
bool sendBinaryFrame(const char* buffer, int length) override
{
+ if (!_docManager)
+ {
+ LOG_TRC("ERR dropping binary - client-" + getId());
+ return false;
+ }
const auto msg = "client-" + getId() + ' ' + std::string(buffer, length);
return _docManager->sendFrame(msg.data(), msg.size(), WSOpCode::Binary);
}
@@ -235,11 +247,7 @@ public:
void resetDocManager()
{
-#if MOBILEAPP
- // I suspect this might be useful even for the non-mobile case, but
- // not 100% sure, so rather do it mobile-only for now
disconnect();
-#endif
_docManager = nullptr;
}
diff --git a/kit/Kit.cpp b/kit/Kit.cpp
index a302f6e35..54ad81647 100644
--- a/kit/Kit.cpp
+++ b/kit/Kit.cpp
@@ -781,7 +781,9 @@ public:
" session for url: " << anonymizeUrl(_url) << " for sessionId: " <<
sessionId << " on jailId: " << _jailId);
- auto session = std::make_shared<ChildSession>(sessionId, _jailId, *this);
+ auto session = std::make_shared<ChildSession>(
+ _websocketHandler,
+ sessionId, _jailId, *this);
_sessions.emplace(sessionId, session);
int viewId = session->getViewId();
@@ -2072,7 +2074,7 @@ std::shared_ptr<lok::Document> getLOKDocument()
return Document::_loKitDocument;
}
-class KitWebSocketHandler final : public WebSocketHandler, public std::enable_shared_from_this<KitWebSocketHandler>
+class KitWebSocketHandler final : public WebSocketHandler
{
std::shared_ptr<TileQueue> _queue;
std::string _socketName;
@@ -2137,7 +2139,9 @@ protected:
Util::setThreadName("kitbroker_" + docId);
if (!document)
- document = std::make_shared<Document>(_loKit, _jailId, docKey, docId, url, _queue, shared_from_this());
+ document = std::make_shared<Document>(
+ _loKit, _jailId, docKey, docId, url, _queue,
+ std::static_pointer_cast<WebSocketHandler>(shared_from_this()));
// Validate and create session.
if (!(url == document->getUrl() && document->createSession(sessionId)))
@@ -2633,7 +2637,7 @@ void lokit_main(
KitSocketPoll mainKit;
mainKit.runOnClientThread(); // We will do the polling on this thread.
- std::shared_ptr<SocketHandlerInterface> websocketHandler =
+ std::shared_ptr<ProtocolHandlerInterface> websocketHandler =
std::make_shared<KitWebSocketHandler>("child_ws", loKit, jailId);
#if !MOBILEAPP
mainKit.insertNewUnixSocket(MasterLocation, pathAndQuery, websocketHandler);
diff --git a/net/Socket.cpp b/net/Socket.cpp
index cbe5a0a52..5bb1fa250 100644
--- a/net/Socket.cpp
+++ b/net/Socket.cpp
@@ -204,7 +204,7 @@ void SocketPoll::wakeupWorld()
void SocketPoll::insertNewWebSocketSync(
const Poco::URI &uri,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler)
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler)
{
LOG_INF("Connecting to " << uri.getHost() << " : " << uri.getPort() << " : " << uri.getPath());
@@ -277,7 +277,7 @@ void SocketPoll::insertNewWebSocketSync(
// should this be a static method in the WebsocketHandler(?)
void SocketPoll::clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocket>& socket,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler,
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler,
const std::string &pathAndQuery)
{
// cf. WebSocketHandler::upgradeToWebSocket (?)
@@ -304,7 +304,7 @@ void SocketPoll::clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocke
void SocketPoll::insertNewUnixSocket(
const std::string &location,
const std::string &pathAndQuery,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler)
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler)
{
int fd = socket(AF_UNIX, SOCK_STREAM | SOCK_NONBLOCK, 0);
@@ -337,7 +337,7 @@ void SocketPoll::insertNewUnixSocket(
void SocketPoll::insertNewFakeSocket(
int peerSocket,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler)
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler)
{
LOG_INF("Connecting to " << peerSocket);
int fd = fakeSocketSocket();
diff --git a/net/Socket.hpp b/net/Socket.hpp
index c95b93dd7..99fdf259a 100644
--- a/net/Socket.hpp
+++ b/net/Socket.hpp
@@ -344,12 +344,21 @@ private:
};
class StreamSocket;
+class MessageHandlerInterface;
-/// Interface that handles the actual incoming message.
-class SocketHandlerInterface
+/// Interface that decodes the actual incoming message.
+class ProtocolHandlerInterface :
+ public std::enable_shared_from_this<ProtocolHandlerInterface>
{
+protected:
+ /// We own a message handler, after decoding the socket data we pass it on as messages.
+ std::shared_ptr<MessageHandlerInterface> _msgHandler;
public:
- virtual ~SocketHandlerInterface() {}
+ // ------------------------------------------------------------------
+ // Interface for implementing low level socket goodness from streams.
+ // ------------------------------------------------------------------
+ virtual ~ProtocolHandlerInterface() { }
+
/// Called when the socket is newly created to
/// set the socket associated with this ResponseClient.
/// Will be called exactly once.
@@ -374,10 +383,81 @@ public:
/// Will be called exactly once.
virtual void onDisconnect() {}
+ // -----------------------------------------------------------------
+ // Interface for external MessageHandlers
+ // -----------------------------------------------------------------
+public:
+ void setMessageHandler(const std::shared_ptr<MessageHandlerInterface> &msgHandler)
+ {
+ _msgHandler = msgHandler;
+ }
+
+ /// Clear all external references
+ virtual void dispose() { _msgHandler.reset(); }
+
+ virtual int sendTextMessage(const std::string &msg, const size_t len, bool flush = false) const = 0;
+ virtual int sendBinaryMessage(const char *data, const size_t len, bool flush = false) const = 0;
+ virtual void shutdown(bool goingAway = false, const std::string &statusMessage = "") = 0;
+
+ virtual void getIOStats(uint64_t &sent, uint64_t &recv) = 0;
+
/// Append pretty printed internal state to a line
virtual void dumpState(std::ostream& os) { os << "\n"; }
};
+/// A ProtocolHandlerInterface with dummy sending API.
+class SimpleSocketHandler : public ProtocolHandlerInterface
+{
+public:
+ SimpleSocketHandler() {}
+ int sendTextMessage(const std::string &, const size_t, bool) const override { return 0; }
+ int sendBinaryMessage(const char *, const size_t , bool ) const override { return 0; }
+ void shutdown(bool, const std::string &) override {}
+ void getIOStats(uint64_t &, uint64_t &) override {}
+};
+
+/// Interface that receives and sends incoming messages.
+class MessageHandlerInterface :
+ public std::enable_shared_from_this<MessageHandlerInterface>
+{
+protected:
+ std::shared_ptr<ProtocolHandlerInterface> _protocol;
+ MessageHandlerInterface(const std::shared_ptr<ProtocolHandlerInterface> &protocol) :
+ _protocol(protocol)
+ {
+ }
+ virtual ~MessageHandlerInterface() {}
+
+public:
+ /// Setup, after construction for shared_from_this
+ void initialize()
+ {
+ if (_protocol)
+ _protocol->setMessageHandler(shared_from_this());
+ }
+
+ /// Clear all external references
+ virtual void dispose()
+ {
+ if (_protocol)
+ {
+ _protocol->dispose();
+ _protocol.reset();
+ }
+ }
+
+ /// Do we have something to send ?
+ virtual bool hasQueuedMessages() const = 0;
+ /// Please send them to me then.
+ virtual void writeQueuedMessages() = 0;
+ /// We just got a message - here it is
+ virtual void handleMessage(const std::vector<char> &data) = 0;
+ /// Get notified that the underlying transports disconnected
+ virtual void onDisconnect() = 0;
+ /// Append pretty printed internal state to a line
+ virtual void dumpState(std::ostream& os) = 0;
+};
+
/// Handles non-blocking socket event polling.
/// Only polls on N-Sockets and invokes callback and
/// doesn't manage buffers or client data.
@@ -672,16 +752,16 @@ public:
/// Inserts a new remote websocket to be polled.
/// NOTE: The DNS lookup is synchronous.
void insertNewWebSocketSync(const Poco::URI &uri,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler);
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler);
void insertNewUnixSocket(
const std::string &location,
const std::string &pathAndQuery,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler);
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler);
#else
void insertNewFakeSocket(
int peerSocket,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler);
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler);
#endif
typedef std::function<void()> CallbackFn;
@@ -736,7 +816,7 @@ protected:
private:
/// Generate the request to connect & upgrade this socket to a given path
void clientRequestWebsocketUpgrade(const std::shared_ptr<StreamSocket>& socket,
- const std::shared_ptr<SocketHandlerInterface>& websocketHandler,
+ const std::shared_ptr<ProtocolHandlerInterface>& websocketHandler,
const std::string &pathAndQuery);
/// Initialize the poll fds array with the right events
@@ -791,12 +871,13 @@ private:
};
/// A plain, non-blocking, data streaming socket.
-class StreamSocket : public Socket, public std::enable_shared_from_this<StreamSocket>
+class StreamSocket : public Socket,
+ public std::enable_shared_from_this<StreamSocket>
{
public:
/// Create a StreamSocket from native FD.
StreamSocket(const int fd, bool /* isClient */,
- std::shared_ptr<SocketHandlerInterface> socketHandler) :
+ std::shared_ptr<ProtocolHandlerInterface> socketHandler) :
Socket(fd),
_socketHandler(std::move(socketHandler)),
_bytesSent(0),
@@ -933,7 +1014,7 @@ public:
}
/// Replace the existing SocketHandler with a new one.
- void setHandler(std::shared_ptr<SocketHandlerInterface> handler)
+ void setHandler(std::shared_ptr<ProtocolHandlerInterface> handler)
{
_socketHandler = std::move(handler);
_socketHandler->onConnect(shared_from_this());
@@ -944,9 +1025,9 @@ public:
/// but we can't have a shared_ptr in the ctor.
template <typename TSocket>
static
- std::shared_ptr<TSocket> create(const int fd, bool isClient, std::shared_ptr<SocketHandlerInterface> handler)
+ std::shared_ptr<TSocket> create(const int fd, bool isClient, std::shared_ptr<ProtocolHandlerInterface> handler)
{
- SocketHandlerInterface* pHandler = handler.get();
+ ProtocolHandlerInterface* pHandler = handler.get();
auto socket = std::make_shared<TSocket>(fd, isClient, std::move(handler));
pHandler->onConnect(socket);
return socket;
@@ -1157,14 +1238,14 @@ protected:
return _shutdownSignalled;
}
- const std::shared_ptr<SocketHandlerInterface>& getSocketHandler() const
+ const std::shared_ptr<ProtocolHandlerInterface>& getSocketHandler() const
{
return _socketHandler;
}
private:
/// Client handling the actual data.
- std::shared_ptr<SocketHandlerInterface> _socketHandler;
+ std::shared_ptr<ProtocolHandlerInterface> _socketHandler;
std::vector<char> _inBuffer;
std::vector<char> _outBuffer;
diff --git a/net/SslSocket.hpp b/net/SslSocket.hpp
index ba9954f56..27e075328 100644
--- a/net/SslSocket.hpp
+++ b/net/SslSocket.hpp
@@ -20,7 +20,7 @@ class SslStreamSocket final : public StreamSocket
{
public:
SslStreamSocket(const int fd, bool isClient,
- std::shared_ptr<SocketHandlerInterface> responseClient) :
+ std::shared_ptr<ProtocolHandlerInterface> responseClient) :
StreamSocket(fd, isClient, std::move(responseClient)),
_bio(nullptr),
_ssl(nullptr),
diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp
index 130f81b69..1c2977602 100644
--- a/net/WebSocketHandler.hpp
+++ b/net/WebSocketHandler.hpp
@@ -24,7 +24,7 @@
#include <Poco/Net/HTTPResponse.h>
#include <Poco/Net/WebSocket.h>
-class WebSocketHandler : public SocketHandlerInterface
+class WebSocketHandler : public ProtocolHandlerInterface
{
private:
/// The socket that owns us (we can't own it).
@@ -94,7 +94,7 @@ public:
upgradeToWebSocket(request);
}
- /// Implementation of the SocketHandlerInterface.
+ /// Implementation of the ProtocolHandlerInterface.
void onConnect(const std::shared_ptr<StreamSocket>& socket) override
{
_socket = socket;
@@ -146,6 +146,24 @@ public:
#endif
}
+ void shutdown(bool goingAway, const std::string &statusMessage) override
+ {
+ shutdown(goingAway ? WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY :
+ WebSocketHandler::StatusCodes::NORMAL_CLOSE, statusMessage);
+ }
+
+ void getIOStats(uint64_t &sent, uint64_t &recv) override
+ {
+ std::shared_ptr<StreamSocket> socket = getSocket().lock();
+ if (socket)
+ socket->getIOStats(sent, recv);
+ else
+ {
+ sent = 0;
+ recv = 0;
+ }
+ }
+
void shutdown(const StatusCodes statusCode = StatusCodes::NORMAL_CLOSE, const std::string& statusMessage = "")
{
if (!_shuttingDown)
@@ -384,7 +402,7 @@ public:
return true;
}
- /// Implementation of the SocketHandlerInterface.
+ /// Implementation of the ProtocolHandlerInterface.
virtual void handleIncomingMessage(SocketDisposition&) override
{
// LOG_TRC("***** WebSocketHandler::handleIncomingMessage()");
@@ -421,7 +439,10 @@ public:
std::chrono::duration_cast<std::chrono::milliseconds>(now - _lastPingSentTime).count();
timeoutMaxMs = std::min(timeoutMaxMs, PingFrequencyMs - timeSincePingMs);
}
- return POLLIN;
+ int events = POLLIN;
+ if (_msgHandler && _msgHandler->hasQueuedMessages())
+ events |= POLLOUT;
+ return events;
}
#if !MOBILEAPP
@@ -483,13 +504,34 @@ private:
#endif
}
public:
- /// By default rely on the socket buffer.
- void performWrites() override {}
+ void performWrites() override
+ {
+ if (_msgHandler)
+ _msgHandler->writeQueuedMessages();
+ }
+
+ void onDisconnect() override
+ {
+ if (_msgHandler)
+ _msgHandler->onDisconnect();
+ }
/// Sends a WebSocket Text message.
int sendMessage(const std::string& msg) const
{
- return sendMessage(msg.data(), msg.size(), WSOpCode::Text);
+ return sendTextMessage(msg, msg.size());
+ }
+
+ /// Implementation of the ProtocolHandlerInterface.
+ int sendTextMessage(const std::string &msg, const size_t len, bool flush = false) const override
+ {
+ return sendMessage(msg.data(), len, WSOpCode::Text, flush);
+ }
+
+ /// Implementation of the ProtocolHandlerInterface.
+ int sendBinaryMessage(const char *data, const size_t len, bool flush = false) const override
+ {
+ return sendMessage(data, len, WSOpCode::Binary, flush);
}
/// Sends a WebSocket message of WPOpCode type.
@@ -506,9 +548,7 @@ public:
std::shared_ptr<StreamSocket> socket = _socket.lock();
return sendFrame(socket, data, len, WSFrameMask::Fin | static_cast<unsigned char>(code), flush);
}
-
private:
-
/// Sends a WebSocket frame given the data, length, and flags.
/// Returns the number of bytes written (including frame overhead) on success,
/// 0 for closed/invalid socket, and -1 for other errors.
@@ -615,8 +655,10 @@ protected:
}
/// To be overriden to handle the websocket messages the way you need.
- virtual void handleMessage(const std::vector<char> &/*data*/)
+ virtual void handleMessage(const std::vector<char> &data)
{
+ if (_msgHandler)
+ _msgHandler->handleMessage(data);
}
std::weak_ptr<StreamSocket>& getSocket()
@@ -629,6 +671,7 @@ protected:
_socket = socket;
}
+ /// Implementation of the ProtocolHandlerInterface.
void dumpState(std::ostream& os) override;
private:
diff --git a/test/UnitWOPIVersionRestore.cpp b/test/UnitWOPIVersionRestore.cpp
index 3ad8dab09..16192c621 100644
--- a/test/UnitWOPIVersionRestore.cpp
+++ b/test/UnitWOPIVersionRestore.cpp
@@ -68,6 +68,7 @@ public:
{
constexpr char testName[] = "UnitWOPIVersionRestore";
+ LOG_TRC("invokeTest " << (int)_phase);
switch (_phase)
{
case Phase::Load:
diff --git a/tools/WebSocketDump.cpp b/tools/WebSocketDump.cpp
index e2fe32e54..c699a8fed 100644
--- a/tools/WebSocketDump.cpp
+++ b/tools/WebSocketDump.cpp
@@ -50,7 +50,7 @@ private:
};
/// Handles incoming connections and dispatches to the appropriate handler.
-class ClientRequestDispatcher : public SocketHandlerInterface
+class ClientRequestDispatcher : public SimpleSocketHandler
{
public:
ClientRequestDispatcher()
diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp
index 696411fbf..29e420dad 100644
--- a/wsd/ClientSession.cpp
+++ b/wsd/ClientSession.cpp
@@ -38,12 +38,14 @@ using Poco::Path;
static std::mutex GlobalSessionMapMutex;
static std::unordered_map<std::string, std::weak_ptr<ClientSession>> GlobalSessionMap;
-ClientSession::ClientSession(const std::string& id,
- const std::shared_ptr<DocumentBroker>& docBroker,
- const Poco::URI& uriPublic,
- const bool readOnly,
- const std::string& hostNoTrust) :
- Session("ToClient-" + id, id, readOnly),
+ClientSession::ClientSession(
+ const std::shared_ptr<ProtocolHandlerInterface>& ws,
+ const std::string& id,
+ const std::shared_ptr<DocumentBroker>& docBroker,
+ const Poco::URI& uriPublic,
+ const bool readOnly,
+ const std::string& hostNoTrust) :
+ Session(ws, "ToClient-" + id, id, readOnly),
_docBroker(docBroker),
_uriPublic(uriPublic),
_isDocumentOwner(false),
@@ -86,7 +88,8 @@ ClientSession::ClientSession(const std::string& id,
void ClientSession::construct()
{
std::unique_lock<std::mutex> lock(GlobalSessionMapMutex);
- GlobalSessionMap[getId()] = shared_from_this();
+ MessageHandlerInterface::initialize();
+ GlobalSessionMap[getId()] = client_from_this();
}
ClientSession::~ClientSession()
@@ -444,7 +447,7 @@ bool ClientSession::_handleInput(const char *buffer, int length)
}
else if (tokens.equals(0, "canceltiles"))
{
- docBroker->cancelTileRequests(shared_from_this());
+ docBroker->cancelTileRequests(client_from_this());
return true;
}
else if (tokens.equals(0, "commandvalues"))
@@ -678,7 +681,7 @@ bool ClientSession::_handleInput(const char *buffer, int length)
else
LOG_INF("Tileprocessed message with an unknown tile ID");
- docBroker->sendRequestedTiles(shared_from_this());
+ docBroker->sendRequestedTiles(client_from_this());
return true;
}
else if (tokens.equals(0, "removesession")) {
@@ -882,7 +885,7 @@ bool ClientSession::sendTile(const char * /*buffer*/, int /*length*/, const Stri
{
TileDesc tileDesc = TileDesc::parse(tokens);
tileDesc.setNormalizedViewId(getCanonicalViewId());
- docBroker->handleTileRequest(tileDesc, shared_from_this());
+ docBroker->handleTileRequest(tileDesc, client_from_this());
}
catch (const std::exception& exc)
{
@@ -900,7 +903,7 @@ bool ClientSession::sendCombinedTiles(const char* /*buffer*/, int /*length*/, co
{
TileCombined tileCombined = TileCombined::parse(tokens);
tileCombined.setNormalizedViewId(getCanonicalViewId());
- docBroker->handleTileCombinedRequest(tileCombined, shared_from_this());
+ docBroker->handleTileCombinedRequest(tileCombined, client_from_this());
}
catch (const std::exception& exc)
{
@@ -981,17 +984,13 @@ void ClientSession::setReadOnly()
sendTextFrame("perm: readonly");
}
-int ClientSession::getPollEvents(std::chrono::steady_clock::time_point /* now */,
- int & /* timeoutMaxMs */)
+bool ClientSession::hasQueuedMessages() const
{
- LOG_TRC(getName() << " ClientSession has " << _senderQueue.size() << " write message(s) queued.");
- int events = POLLIN;
- if (_senderQueue.size())
- events |= POLLOUT;
- return events;
+ return _senderQueue.size() > 0;
}
-void ClientSession::performWrites()
+ /// Please send them to me then.
+void ClientSession::writeQueuedMessages()
{
LOG_TRC(getName() << " ClientSession: performing writes.");
@@ -1706,11 +1705,10 @@ void ClientSession::dumpState(std::ostream& os)
<< "\n\t\tclipboardKeys[1]: " << _clipboardKeys[1]
<< "\n\t\tclip sockets: " << _clipSockets.size();
- std::shared_ptr<StreamSocket> socket = getSocket().lock();
- if (socket)
+ if (_protocol)
{
uint64_t sent, recv;
- socket->getIOStats(sent, recv);
+ _protocol->getIOStats(sent, recv);
os << "\n\t\tsent/keystroke: " << (double)sent/_keyEvents << "bytes";
}
@@ -1781,7 +1779,7 @@ void ClientSession::handleTileInvalidation(const std::string& message,
{
TileCombined tileCombined = TileCombined::create(invalidTiles);
tileCombined.setNormalizedViewId(normalizedViewId);
- docBroker->handleTileCombinedRequest(tileCombined, shared_from_this());
+ docBroker->handleTileCombinedRequest(tileCombined, client_from_this());
}
}
diff --git a/wsd/ClientSession.hpp b/wsd/ClientSession.hpp
index fe39b9e7d..b47285dd0 100644
--- a/wsd/ClientSession.hpp
+++ b/wsd/ClientSession.hpp
@@ -24,12 +24,12 @@
class DocumentBroker;
-
/// Represents a session to a LOOL client, in the WSD process.
-class ClientSession final : public Session, public std::enable_shared_from_this<ClientSession>
+class ClientSession final : public Session
{
public:
- ClientSession(const std::string& id,
+ ClientSession(const std::shared_ptr<ProtocolHandlerInterface>& ws,
+ const std::string& id,
const std::shared_ptr<DocumentBroker>& docBroker,
const Poco::URI& uriPublic,
const bool isReadOnly,
@@ -174,14 +174,19 @@ public:
void rotateClipboardKey(bool notifyClient);
private:
+ std::shared_ptr<ClientSession> client_from_this()
+ {
+ return std::static_pointer_cast<ClientSession>(shared_from_this());
+ }
+
/// SocketHandler: disconnection event.
void onDisconnect() override;
- /// Does SocketHandler: have data or timeouts to setup.
- int getPollEvents(std::chrono::steady_clock::time_point /* now */,
- int & /* timeoutMaxMs */) override;
- /// SocketHandler: write to socket.
- void performWrites() override;
+ /// Does SocketHandler: have messages to send ?
+ bool hasQueuedMessages() const override;
+
+ /// SocketHandler: send those messages
+ void writeQueuedMessages() override;
virtual bool _handleInput(const char* buffer, int length) override;
diff --git a/wsd/DocumentBroker.cpp b/wsd/DocumentBroker.cpp
index b9e7e983c..8b0c883c0 100644
--- a/wsd/DocumentBroker.cpp
+++ b/wsd/DocumentBroker.cpp
@@ -1468,6 +1468,7 @@ void DocumentBroker::finalRemoveSession(const std::string& id)
// Remove. The caller must have a reference to the session
// in question, lest we destroy from underneath them.
+ it->second->dispose();
_sessions.erase(it);
const size_t count = _sessions.size();
@@ -1497,11 +1498,12 @@ void DocumentBroker::finalRemoveSession(const std::string& id)
}
}
-std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession(const WebSocketHandler* ws,
- const std::string& id,
- const Poco::URI& uriPublic,
- const bool isReadOnly,
- const std::string& hostNoTrust)
+std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession(
+ const std::shared_ptr<ProtocolHandlerInterface> &ws,
+ const std::string& id,
+ const Poco::URI& uriPublic,
+ const bool isReadOnly,
+ const std::string& hostNoTrust)
{
try
{
@@ -1510,13 +1512,13 @@ std::shared_ptr<ClientSession> DocumentBroker::createNewClientSession(const WebS
{
const std::string statusReady = "statusindicator: ready";
LOG_TRC("Sending to Client [" << statusReady << "].");
- ws->sendMessage(statusReady);
+ ws->sendTextMessage(statusReady, statusReady.size());
}
// In case of WOPI, if this session is not set as readonly, it might be set so
// later after making a call to WOPI host which tells us the permission on files
// (UserCanWrite param).
- auto session = std::make_shared<ClientSession>(id, shared_from_this(), uriPublic, isReadOnly, hostNoTrust);
+ auto session = std::make_shared<ClientSession>(ws, id, shared_from_this(), uriPublic, isReadOnly, hostNoTrust);
session->construct();
return session;
@@ -2252,7 +2254,9 @@ bool ConvertToBroker::startConversion(SocketDisposition &disposition, const std:
// Create a session to load the document.
const bool isReadOnly = true;
- _clientSession = std::make_shared<ClientSession>(id, docBroker, getPublicUri(), isReadOnly, "nocliphost");
+ // FIXME: associate this with moveSocket (?)
+ std::shared_ptr<ProtocolHandlerInterface> nullPtr;
+ _clientSession = std::make_shared<ClientSession>(nullPtr, id, docBroker, getPublicUri(), isReadOnly, "nocliphost");
_clientSession->construct();
if (!_clientSession)
diff --git a/wsd/DocumentBroker.hpp b/wsd/DocumentBroker.hpp
index f56bd1e3f..68369d274 100644
--- a/wsd/DocumentBroker.hpp
+++ b/wsd/DocumentBroker.hpp
@@ -244,11 +244,12 @@ public:
void finalRemoveSession(const std::string& id);
/// Create new client session
- std::shared_ptr<ClientSession> createNewClientSession(const WebSocketHandler* ws,
- const std::string& id,
- const Poco::URI& uriPublic,
- const bool isReadOnly,
- const std::string& hostNoTrust);
+ std::shared_ptr<ClientSession> createNewClientSession(
+ const std::shared_ptr<ProtocolHandlerInterface> &ws,
+ const std::string& id,
+ const Poco::URI& uriPublic,
+ const bool isReadOnly,
+ const std::string& hostNoTrust);
/// Thread safe termination of this broker if it has a lingering thread
void joinThread();
diff --git a/wsd/LOOLWSD.cpp b/wsd/LOOLWSD.cpp
index b009bcaa5..241867af7 100644
--- a/wsd/LOOLWSD.cpp
+++ b/wsd/LOOLWSD.cpp
@@ -236,7 +236,7 @@ namespace
{
#if ENABLE_SUPPORT_KEY
-inline void shutdownLimitReached(WebSocketHandler& ws)
+inline void shutdownLimitReached(const std::shared_ptr<WebSocketHandler>& ws)
{
const std::string error = Poco::format(PAYLOAD_UNAVAILABLE_LIMIT_REACHED, LOOLWSD::MaxDocuments, LOOLWSD::MaxConnections);
LOG_INF("Sending client 'hardlimitreached' message: " << error);
@@ -244,10 +244,10 @@ inline void shutdownLimitReached(WebSocketHandler& ws)
try
{
// Let the client know we are shutting down.
- ws.sendMessage(error);
+ ws->sendMessage(error);
// Shutdown.
- ws.shutdown(WebSocketHandler::StatusCodes::POLICY_VIOLATION);
+ ws->shutdown(WebSocketHandler::StatusCodes::POLICY_VIOLATION);
}
catch (const std::exception& ex)
{
@@ -1728,11 +1728,12 @@ std::mutex Connection::Mutex;
/// Otherwise, creates and adds a new one to DocBrokers.
/// May return null if terminating or MaxDocuments limit is reached.
/// After returning a valid instance DocBrokers must be cleaned up after exceptions.
-static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& ws,
- const std::string& uri,
- const std::string& docKey,
- const std::string& id,
- const Poco::URI& uriPublic)
+static std::shared_ptr<DocumentBroker>
+ findOrCreateDocBroker(const std::shared_ptr<WebSocketHandler>& ws,
+ const std::string& uri,
+ const std::string& docKey,
+ const std::string& id,
+ const Poco::URI& uriPublic)
{
LOG_INF("Find or create DocBroker for docKey [" << docKey <<
"] for session [" << id << "] on url [" << LOOLWSD::anonymizeUrl(uriPublic.toString()) << "].");
@@ -1761,8 +1762,8 @@ static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& w
if (docBroker->isMarkedToDestroy())
{
LOG_WRN("DocBroker with docKey [" << docKey << "] that is marked to be destroyed. Rejecting client request.");
- ws.sendMessage("error: cmd=load kind=docunloading");
- ws.shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, "error: cmd=load kind=docunloading");
+ ws->sendMessage("error: cmd=load kind=docunloading");
+ ws->shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, "error: cmd=load kind=docunloading");
return nullptr;
}
}
@@ -1780,7 +1781,7 @@ static std::shared_ptr<DocumentBroker> findOrCreateDocBroker(WebSocketHandler& w
// Indicate to the client that we're connecting to the docbroker.
const std::string statusConnect = "statusindicator: connect";
LOG_TRC("Sending to Client [" << statusConnect << "].");
- ws.sendMessage(statusConnect);
+ ws->sendMessage(statusConnect);
if (!docBroker)
{
@@ -1932,6 +1933,11 @@ private:
addNewChild(child);
});
}
+ catch (const std::bad_weak_ptr&)
+ {
+ // Using shared_from_this() from a constructor is not good.
+ assert(false);
+ }
catch (const std::exception& exc)
{
// Probably don't have enough data just yet.
@@ -1995,7 +2001,7 @@ public:
#endif
/// Handles incoming connections and dispatches to the appropriate handler.
-class ClientRequestDispatcher : public SocketHandlerInterface
+class ClientRequestDispatcher : public SimpleSocketHandler
{
public:
ClientRequestDispatcher()
@@ -2780,7 +2786,7 @@ private:
LOG_TRC("Client WS request: " << request.getURI() << ", url: " << url << ", socket #" << socket->getFD());
// First Upgrade.
- WebSocketHandler ws(_socket, request);
+ auto ws = std::make_shared<WebSocketHandler>(_socket, request);
// Response to clients beyond this point is done via WebSocket.
try
@@ -2807,7 +2813,7 @@ private:
// Indicate to the client that document broker is searching.
const std::string status("statusindicator: find");
LOG_TRC("Sending to Client [" << status << "].");
- ws.sendMessage(status);
+ ws->sendMessage(status);
LOG_INF("Sanitized URI [" << LOOLWSD::anonymizeUrl(url) << "] to [" << LOOLWSD::anonymizeUrl(uriPublic.toString()) <<
"] and mapped to docKey [" << docKey << "] for session [" << _id << "].");
@@ -2837,11 +2843,11 @@ private:
#endif
std::shared_ptr<ClientSession> clientSession =
- docBroker->createNewClientSession(&ws, _id, uriPublic, isReadOnly, hostNoTrust);
+ docBroker->createNewClientSession(ws, _id, uriPublic, isReadOnly, hostNoTrust);
if (clientSession)
{
// Transfer the client socket to the DocumentBroker when we get back to the poll:
- disposition.setMove([docBroker, clientSession]
+ disposition.setMove([docBroker, clientSession, ws]
(const std::shared_ptr<Socket> &moveSocket)
{
// Make sure the thread is running before adding callback.
@@ -2850,16 +2856,16 @@ private:
// We no longer own this socket.
moveSocket->setThreadOwner(std::thread::id());
- docBroker->addCallback([docBroker, moveSocket, clientSession]()
+ docBroker->addCallback([docBroker, moveSocket, clientSession, ws]()
{
try
{
auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket);
- // Set the ClientSession to handle Socket events.
- streamSocket->setHandler(clientSession);
- LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName());
+ // Set WebSocketHandler's socket after its construction for shared_ptr goodness.
+ streamSocket->setHandler(ws);
+ LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName());
// Move the socket into DocBroker.
docBroker->addSocketToPoll(moveSocket);
@@ -2868,7 +2874,8 @@ private:
checkDiskSpaceAndWarnClients(true);
#if !ENABLE_SUPPORT_KEY
- // Users of development versions get just an info when reaching max documents or connections
+ // Users of development versions get just an info
+ // when reaching max documents or connections
checkSessionLimitsAndWarnClients();
#endif
}
@@ -2909,8 +2916,8 @@ private:
{
LOG_ERR("Error while handling Client WS Request: " << exc.what());
const std::string msg = "error: cmd=internal kind=load";
- ws.sendMessage(msg);
- ws.shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, msg);
+ ws->sendMessage(msg);
+ ws->shutdown(WebSocketHandler::StatusCodes::ENDPOINT_GOING_AWAY, msg);
}
}
diff --git a/wsd/TestStubs.cpp b/wsd/TestStubs.cpp
index b75499ec2..ca04416da 100644
--- a/wsd/TestStubs.cpp
+++ b/wsd/TestStubs.cpp
@@ -25,16 +25,16 @@ void ClientSession::enqueueSendMessage(const std::shared_ptr<Message>& /*data*/)
ClientSession::~ClientSession() {}
-void ClientSession::performWrites() {}
-
void ClientSession::onDisconnect() {}
+bool ClientSession::hasQueuedMessages() const { return false; }
+
+void ClientSession::writeQueuedMessages() {}
+
void ClientSession::dumpState(std::ostream& /*os*/) {}
void ClientSession::setReadOnly() {}
bool ClientSession::_handleInput(const char* /*buffer*/, int /*length*/) { return false; }
-int ClientSession::getPollEvents(std::chrono::steady_clock::time_point /* now */, int & /* timeoutMaxMs */) { return 0; }
-
/* vim:set shiftwidth=4 softtabstop=4 expandtab: */
More information about the Libreoffice-commits
mailing list