[Libreoffice-commits] online.git: common/Unit.hpp net/DelaySocket.cpp net/loolnb.cpp net/ServerSocket.hpp net/Socket.cpp net/Socket.hpp net/WebSocketHandler.hpp test/UnitFuzz.cpp wsd/ClientSession.cpp wsd/ClientSession.hpp wsd/LOOLWSD.cpp

Michael Meeks michael.meeks at collabora.com
Fri May 5 17:37:48 UTC 2017


 common/Unit.hpp          |    1 
 net/DelaySocket.cpp      |    7 +-
 net/ServerSocket.hpp     |    7 +-
 net/Socket.cpp           |   10 ++++
 net/Socket.hpp           |  116 +++++++++++++++++++++++++----------------------
 net/WebSocketHandler.hpp |    4 -
 net/loolnb.cpp           |    8 +--
 test/UnitFuzz.cpp        |    1 
 wsd/ClientSession.cpp    |    8 +--
 wsd/ClientSession.hpp    |    2 
 wsd/LOOLWSD.cpp          |   61 ++++++++++++------------
 11 files changed, 121 insertions(+), 104 deletions(-)

New commits:
commit 9e45fb30d7f33b57fda9f615447ae8ac9b920fc1
Author: Michael Meeks <michael.meeks at collabora.com>
Date:   Fri May 5 11:51:43 2017 +0100

    SocketDisposition: push it down the stack, and clean up around that.
    
    Clear out the overlapping return enumerations. Move more work into 'move'
    callbacks, which run at a safer time, etc.
    
    Change-Id: I62ba5a35f12073b7b9c8de4674be9dae519a8aca

diff --git a/common/Unit.hpp b/common/Unit.hpp
index e8197fd1..5f8d20ea 100644
--- a/common/Unit.hpp
+++ b/common/Unit.hpp
@@ -177,6 +177,7 @@ public:
     /// Intercept incoming requests, so unit tests can silently communicate
     virtual bool filterHandleRequest(
         TestRequest /* type */,
+        SocketDisposition & /* disposition */,
         WebSocketHandler & /* handler */)
     {
         return false;
diff --git a/net/DelaySocket.cpp b/net/DelaySocket.cpp
index 723357c1..20990e5c 100644
--- a/net/DelaySocket.cpp
+++ b/net/DelaySocket.cpp
@@ -122,7 +122,8 @@ public:
         _state = newState;
     }
 
-    HandleResult handlePoll(std::chrono::steady_clock::time_point now, int events) override
+    void handlePoll(SocketDisposition &disposition,
+                    std::chrono::steady_clock::time_point now, int events) override
     {
         if (_state == ReadWrite && (events & POLLIN))
         {
@@ -215,9 +216,7 @@ public:
         }
 
         if (_state == Closed)
-            return HandleResult::SOCKET_CLOSED;
-        else
-            return HandleResult::CONTINUE;
+            disposition.setClosed();
     }
 };
 
diff --git a/net/ServerSocket.hpp b/net/ServerSocket.hpp
index 805430ea..4d4bb353 100644
--- a/net/ServerSocket.hpp
+++ b/net/ServerSocket.hpp
@@ -88,8 +88,9 @@ public:
 
     void dumpState(std::ostream& os) override;
 
-    HandleResult handlePoll(std::chrono::steady_clock::time_point /* now */,
-                            int events) override
+    void handlePoll(SocketDisposition &,
+                    std::chrono::steady_clock::time_point /* now */,
+                    int events) override
     {
         if (events & POLLIN)
         {
@@ -103,8 +104,6 @@ public:
             LOG_DBG("Accepted client #" << clientSocket->getFD());
             _clientPoller.insertNewSocket(clientSocket);
         }
-
-        return Socket::HandleResult::CONTINUE;
     }
 
 private:
diff --git a/net/Socket.cpp b/net/Socket.cpp
index ad912880..5faeced7 100644
--- a/net/Socket.cpp
+++ b/net/Socket.cpp
@@ -122,6 +122,16 @@ void ServerSocket::dumpState(std::ostream& os)
     os << "\t" << getFD() << "\t<accept>\n";
 }
 
+
+void SocketDisposition::execute()
+{
+    // We should have hard ownership of this socket.
+    assert(_socket->getThreadOwner() == std::this_thread::get_id());
+    if (_socketMove)
+        _socketMove(_socket);
+    _socketMove = nullptr;
+}
+
 namespace {
 
 void dump_hex (const char *legend, const char *prefix, std::vector<char> buffer)
diff --git a/net/Socket.hpp b/net/Socket.hpp
index 694e82a5..0468eb9c 100644
--- a/net/Socket.hpp
+++ b/net/Socket.hpp
@@ -44,6 +44,48 @@ namespace Poco
     }
 }
 
+class Socket;
+
+/// Helper to allow us to easily defer the movement of a socket
+/// between polls to clarify thread ownership.
+class SocketDisposition
+{
+    typedef std::function<void(const std::shared_ptr<Socket> &)> MoveFunction;
+    enum class Type { CONTINUE, CLOSED, MOVE };
+
+    Type _disposition;
+    MoveFunction _socketMove;
+    std::shared_ptr<Socket> _socket;
+
+public:
+    SocketDisposition(const std::shared_ptr<Socket> &socket) :
+        _disposition(Type::CONTINUE),
+        _socket(socket)
+    {}
+    ~SocketDisposition()
+    {
+        assert (!_socketMove);
+    }
+    void setMove()
+    {
+        _disposition = Type::MOVE;
+    }
+    void setMove(MoveFunction moveFn)
+    {
+        _socketMove = moveFn;
+        _disposition = Type::MOVE;
+    }
+    void setClosed()
+    {
+        _disposition = Type::CLOSED;
+    }
+    bool isMove() { return _disposition == Type::MOVE; }
+    bool isClosed() { return _disposition == Type::CLOSED; }
+
+    /// Perform the queued up work.
+    void execute();
+};
+
 /// A non-blocking, streaming socket.
 class Socket
 {
@@ -86,8 +128,9 @@ public:
                               int &timeoutMaxMs) = 0;
 
     /// Handle results of events returned from poll
-    enum class HandleResult { CONTINUE, SOCKET_CLOSED, MOVED };
-    virtual HandleResult handlePoll(std::chrono::steady_clock::time_point now, int events) = 0;
+    virtual void handlePoll(SocketDisposition &disposition,
+                            std::chrono::steady_clock::time_point now,
+                            int events) = 0;
 
     /// manage latency issues around packet aggregation
     virtual void setNoDelay()
@@ -411,26 +454,30 @@ public:
         // Fire the poll callbacks and remove dead fds.
         std::chrono::steady_clock::time_point newNow =
             std::chrono::steady_clock::now();
+
         for (int i = static_cast<int>(size) - 1; i >= 0; --i)
         {
-            Socket::HandleResult res = Socket::HandleResult::SOCKET_CLOSED;
+            SocketDisposition disposition(_pollSockets[i]);
             try
             {
-                res = _pollSockets[i]->handlePoll(newNow, _pollFds[i].revents);
+                _pollSockets[i]->handlePoll(disposition, newNow,
+                                            _pollFds[i].revents);
             }
             catch (const std::exception& exc)
             {
                 LOG_ERR("Error while handling poll for socket #" <<
                         _pollFds[i].fd << " in " << _name << ": " << exc.what());
+                disposition.setClosed();
             }
 
-            if (res == Socket::HandleResult::SOCKET_CLOSED ||
-                res == Socket::HandleResult::MOVED)
+            if (disposition.isMove() || disposition.isClosed())
             {
                 LOG_DBG("Removing socket #" << _pollFds[i].fd << " (of " <<
                         _pollSockets.size() << ") from " << _name);
                 _pollSockets.erase(_pollSockets.begin() + i);
             }
+
+            disposition.execute();
         }
     }
 
@@ -608,14 +655,8 @@ public:
     /// Will be called exactly once.
     virtual void onConnect(const std::shared_ptr<StreamSocket>& socket) = 0;
 
-    enum class SocketOwnership
-    {
-        UNCHANGED,  //< Same socket poll, business as usual.
-        MOVED       //< The socket poll is now different.
-    };
-
     /// Called after successful socket reads.
-    virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() = 0;
+    virtual void handleIncomingMessage(SocketDisposition &disposition) = 0;
 
     /// Prepare our poll record; adjust @timeoutMaxMs downwards
     /// for timeouts, based on current time @now.
@@ -773,15 +814,16 @@ protected:
 
     /// Called when a polling event is received.
     /// @events is the mask of events that triggered the wake.
-    HandleResult handlePoll(std::chrono::steady_clock::time_point now,
-                            const int events) override
+    void handlePoll(SocketDisposition &disposition,
+                    std::chrono::steady_clock::time_point now,
+                    const int events) override
     {
         assertCorrectThread();
 
         _socketHandler->checkTimeout(now);
 
         if (!events)
-            return Socket::HandleResult::CONTINUE;
+            return;
 
         // FIXME: need to close input, but not output (?)
         bool closed = (events & (POLLHUP | POLLERR | POLLNVAL));
@@ -801,8 +843,9 @@ protected:
         while (!_inBuffer.empty() && oldSize != _inBuffer.size())
         {
             oldSize = _inBuffer.size();
-            if (_socketHandler->handleIncomingMessage() == SocketHandlerInterface::SocketOwnership::MOVED)
-                return Socket::HandleResult::MOVED;
+            _socketHandler->handleIncomingMessage(disposition);
+            if (disposition.isMove())
+                return;
         }
 
         do
@@ -837,8 +880,8 @@ protected:
             _socketHandler->onDisconnect();
         }
 
-        return _closed ? HandleResult::SOCKET_CLOSED :
-                         HandleResult::CONTINUE;
+        if (_closed)
+            disposition.setClosed();
     }
 
     /// Override to write data out to socket.
@@ -917,39 +960,6 @@ protected:
     friend class SimpleResponseClient;
 };
 
-/// Helper to allow us to easily defer the movement of a socket
-/// between polls to clarify thread ownership.
-class SocketDisposition
-{
-    std::shared_ptr<StreamSocket> _socket;
-    typedef std::function<void(const std::shared_ptr<StreamSocket> &)> MoveFunction;
-    MoveFunction _socketMove;
-    SocketHandlerInterface::SocketOwnership _socketOwnership;
-public:
-    SocketDisposition(const std::shared_ptr<StreamSocket> &socket) :
-        _socket(socket),
-        _socketOwnership(SocketHandlerInterface::SocketOwnership::UNCHANGED)
-    {}
-    ~SocketDisposition()
-    {
-        assert (!_socketMove);
-    }
-    void setMove(MoveFunction moveFn)
-    {
-        _socketMove = moveFn;
-        _socketOwnership = SocketHandlerInterface::SocketOwnership::MOVED;
-    }
-    SocketHandlerInterface::SocketOwnership execute()
-    {
-        // We should have hard ownership of this socket.
-        assert(_socket->getThreadOwner() == std::this_thread::get_id());
-        if (_socketMove)
-            _socketMove(_socket);
-        _socketMove = nullptr;
-        return _socketOwnership;
-    }
-};
-
 namespace HttpHelper
 {
     /// Sends file as HTTP response.
diff --git a/net/WebSocketHandler.hpp b/net/WebSocketHandler.hpp
index 14d97f81..4ff01c36 100644
--- a/net/WebSocketHandler.hpp
+++ b/net/WebSocketHandler.hpp
@@ -250,7 +250,7 @@ public:
     }
 
     /// Implementation of the SocketHandlerInterface.
-    virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override
+    virtual void handleIncomingMessage(SocketDisposition&) override
     {
         auto socket = _socket.lock();
         if (socket == nullptr)
@@ -262,8 +262,6 @@ public:
             while (handleOneIncomingMessage(socket))
                 ; // can have multiple msgs in one recv'd packet.
         }
-
-        return SocketHandlerInterface::SocketOwnership::UNCHANGED;
     }
 
     int getPollEvents(std::chrono::steady_clock::time_point now,
diff --git a/net/loolnb.cpp b/net/loolnb.cpp
index a014173a..e268b067 100644
--- a/net/loolnb.cpp
+++ b/net/loolnb.cpp
@@ -45,7 +45,7 @@ public:
     {
     }
 
-    virtual SocketHandlerInterface::SocketOwnership handleIncomingMessage() override
+    virtual void handleIncomingMessage(SocketDisposition &disposition) override
     {
         LOG_TRC("incoming WebSocket message");
         if (_wsState == WSState::HTTP)
@@ -89,16 +89,16 @@ public:
 
                 std::string str = oss.str();
                 socket->_outBuffer.insert(socket->_outBuffer.end(), str.begin(), str.end());
-                return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                return;
             }
             else if (tokens.count() == 2 && tokens[1] == "ws")
             {
                 upgradeToWebSocket(req);
-                return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                return;
             }
         }
 
-        return WebSocketHandler::handleIncomingMessage();
+        WebSocketHandler::handleIncomingMessage(disposition);
     }
 
     virtual void handleMessage(const bool fin, const WSOpCode code, std::vector<char> &data) override
diff --git a/test/UnitFuzz.cpp b/test/UnitFuzz.cpp
index 49575b5d..68367884 100644
--- a/test/UnitFuzz.cpp
+++ b/test/UnitFuzz.cpp
@@ -121,6 +121,7 @@ public:
 
     virtual bool filterHandleRequest(
         TestRequest /* type */,
+        SocketDisposition & /* disposition */,
         WebSocketHandler & /* socket */) override
     {
 #if 0 // loolnb
diff --git a/wsd/ClientSession.cpp b/wsd/ClientSession.cpp
index 0bec1538..55b17d64 100644
--- a/wsd/ClientSession.cpp
+++ b/wsd/ClientSession.cpp
@@ -51,13 +51,13 @@ ClientSession::~ClientSession()
     LOG_INF("~ClientSession dtor [" << getName() << "], current number of connections: " << curConnections);
 }
 
-SocketHandlerInterface::SocketOwnership ClientSession::handleIncomingMessage()
+void ClientSession::handleIncomingMessage(SocketDisposition &disposition)
 {
     if (UnitWSD::get().filterHandleRequest(
-            UnitWSD::TestRequest::Client, *this))
-        return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+            UnitWSD::TestRequest::Client, disposition, *this))
+        return;
 
-    return Session::handleIncomingMessage();
+    Session::handleIncomingMessage(disposition);
 }
 
 bool ClientSession::_handleInput(const char *buffer, int length)
diff --git a/wsd/ClientSession.hpp b/wsd/ClientSession.hpp
index b0eefecf..22fad016 100644
--- a/wsd/ClientSession.hpp
+++ b/wsd/ClientSession.hpp
@@ -30,7 +30,7 @@ public:
 
     virtual ~ClientSession();
 
-    SocketHandlerInterface::SocketOwnership handleIncomingMessage() override;
+    void handleIncomingMessage(SocketDisposition &) override;
 
     void setReadOnly() override;
 
diff --git a/wsd/LOOLWSD.cpp b/wsd/LOOLWSD.cpp
index 82126289..f06f67ef 100644
--- a/wsd/LOOLWSD.cpp
+++ b/wsd/LOOLWSD.cpp
@@ -1380,16 +1380,17 @@ private:
     }
 
     /// Called after successful socket reads.
-    SocketHandlerInterface::SocketOwnership handleIncomingMessage() override
+    void handleIncomingMessage(SocketDisposition &disposition) override
     {
         if (UnitWSD::get().filterHandleRequest(
-                UnitWSD::TestRequest::Prisoner, *this))
-            return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                UnitWSD::TestRequest::Prisoner, disposition, *this))
+            return;
 
         if (_childProcess.lock())
         {
             // FIXME: inelegant etc. - derogate to websocket code
-            return WebSocketHandler::handleIncomingMessage();
+            WebSocketHandler::handleIncomingMessage(disposition);
+            return;
         }
 
         auto socket = _socket.lock();
@@ -1402,7 +1403,7 @@ private:
         if (itBody == in.end())
         {
             LOG_TRC("#" << socket->getFD() << " doesn't have enough data yet.");
-            return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+            return;
         }
 
         // Skip the marker.
@@ -1434,7 +1435,7 @@ private:
             if (request.getURI().find(NEW_CHILD_URI) != 0)
             {
                 LOG_ERR("Invalid incoming URI.");
-                return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                return;
             }
 
             // New Child is spawned.
@@ -1455,7 +1456,7 @@ private:
             if (pid <= 0)
             {
                 LOG_ERR("Invalid PID in child URI [" << request.getURI() << "].");
-                return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                return;
             }
 
             in.clear();
@@ -1466,24 +1467,21 @@ private:
 
             auto child = std::make_shared<ChildProcess>(pid, socket, request);
 
-            // Drop pretentions of ownership before adding to the list.
-            socket->setThreadOwner(std::thread::id(0));
-
             _childProcess = child; // weak
-            addNewChild(child);
 
             // Remove from prisoner poll since there is no activity
             // until we attach the childProcess (with this socket)
             // to a docBroker, which will do the polling.
-            return SocketHandlerInterface::SocketOwnership::MOVED;
+            disposition.setMove([child](const std::shared_ptr<Socket> &){
+                    // Drop pretentions of ownership before adding to the list.
+                    addNewChild(child);
+                });
         }
         catch (const std::exception& exc)
         {
             // Probably don't have enough data just yet.
             // TODO: timeout if we never get enough.
         }
-
-        return SocketHandlerInterface::SocketOwnership::UNCHANGED;
     }
 
     /// Prisoner websocket fun ... (for now)
@@ -1538,7 +1536,7 @@ private:
     }
 
     /// Called after successful socket reads.
-    SocketHandlerInterface::SocketOwnership handleIncomingMessage() override
+    void handleIncomingMessage(SocketDisposition &disposition) override
     {
         auto socket = _socket.lock();
         std::vector<char>& in = socket->_inBuffer;
@@ -1551,7 +1549,7 @@ private:
         if (itBody == in.end())
         {
             LOG_DBG("#" << socket->getFD() << " doesn't have enough data yet.");
-            return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+            return;
         }
 
         // Skip the marker.
@@ -1586,17 +1584,16 @@ private:
             if (contentLength != Poco::Net::HTTPMessage::UNKNOWN_CONTENT_LENGTH && available < contentLength)
             {
                 LOG_DBG("Not enough content yet: ContentLength: " << contentLength << ", available: " << available);
-                return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+                return;
             }
         }
         catch (const std::exception& exc)
         {
             // Probably don't have enough data just yet.
             // TODO: timeout if we never get enough.
-            return SocketHandlerInterface::SocketOwnership::UNCHANGED;
+            return;
         }
 
-        SocketDisposition tailDisposition(socket);
         try
         {
             // Routing
@@ -1615,7 +1612,7 @@ private:
                 LOG_INF("Admin request: " << request.getURI());
                 if (AdminSocketHandler::handleInitialRequest(_socket, request))
                 {
-                    tailDisposition.setMove([](const std::shared_ptr<StreamSocket> &moveSocket){
+                    disposition.setMove([](const std::shared_ptr<Socket> &moveSocket){
                             // Hand the socket over to the Admin poll.
                             Admin::instance().insertNewSocket(moveSocket);
                         });
@@ -1644,12 +1641,12 @@ private:
                     reqPathTokens.count() > 0 && reqPathTokens[0] == "lool")
                 {
                     // All post requests have url prefix 'lool'.
-                    handlePostRequest(request, message, tailDisposition);
+                    handlePostRequest(request, message, disposition);
                 }
                 else if (reqPathTokens.count() > 2 && reqPathTokens[0] == "lool" && reqPathTokens[2] == "ws" &&
                          request.find("Upgrade") != request.end() && Poco::icompare(request["Upgrade"], "websocket") == 0)
                 {
-                    handleClientWsUpgrade(request, reqPathTokens[1], tailDisposition);
+                    handleClientWsUpgrade(request, reqPathTokens[1], disposition);
                 }
                 else
                 {
@@ -1678,7 +1675,6 @@ private:
         // if we succeeded - remove the request from our input buffer
         // we expect one request per socket
         in.erase(in.begin(), itBody);
-        return tailDisposition.execute();
     }
 
     int getPollEvents(std::chrono::steady_clock::time_point /* now */,
@@ -1819,7 +1815,7 @@ private:
     }
 
     void handlePostRequest(const Poco::Net::HTTPRequest& request, Poco::MemoryInputStream& message,
-                           SocketDisposition &tailDisposition)
+                           SocketDisposition &disposition)
     {
         LOG_INF("Post request: [" << request.getURI() << "]");
 
@@ -1863,8 +1859,8 @@ private:
                     auto clientSession = createNewClientSession(nullptr, _id, uriPublic, docBroker, isReadOnly);
                     if (clientSession)
                     {
-                        tailDisposition.setMove([docBroker, clientSession, format]
-                                                (const std::shared_ptr<StreamSocket> &moveSocket)
+                        disposition.setMove([docBroker, clientSession, format]
+                                            (const std::shared_ptr<Socket> &moveSocket)
                         { // Perform all of this after removing the socket
 
                         // Make sure the thread is running before adding callback.
@@ -1875,7 +1871,8 @@ private:
 
                         docBroker->addCallback([docBroker, moveSocket, clientSession, format]()
                         {
-                            clientSession->setSaveAsSocket(moveSocket);
+			    auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket);
+                            clientSession->setSaveAsSocket(streamSocket);
 
                             // Move the socket into DocBroker.
                             docBroker->addSocketToPoll(moveSocket);
@@ -2028,7 +2025,7 @@ private:
     }
 
     void handleClientWsUpgrade(const Poco::Net::HTTPRequest& request, const std::string& url,
-                               SocketDisposition &tailDisposition)
+                               SocketDisposition &disposition)
     {
         auto socket = _socket.lock();
         if (!socket)
@@ -2082,8 +2079,8 @@ private:
             if (clientSession)
             {
                 // Transfer the client socket to the DocumentBroker when we get back to the poll:
-                tailDisposition.setMove([docBroker, clientSession]
-                                        (const std::shared_ptr<StreamSocket> &moveSocket)
+                disposition.setMove([docBroker, clientSession]
+                                    (const std::shared_ptr<Socket> &moveSocket)
                 {
                     // Make sure the thread is running before adding callback.
                     docBroker->startThread();
@@ -2093,8 +2090,10 @@ private:
 
                     docBroker->addCallback([docBroker, moveSocket, clientSession]()
                     {
+			auto streamSocket = std::static_pointer_cast<StreamSocket>(moveSocket);
+
                         // Set the ClientSession to handle Socket events.
-                        moveSocket->setHandler(clientSession);
+                        streamSocket->setHandler(clientSession);
                         LOG_DBG("Socket #" << moveSocket->getFD() << " handler is " << clientSession->getName());
 
                         // Move the socket into DocBroker.


More information about the Libreoffice-commits mailing list