[Spice-commits] 10 commits - configure.ac server/red-stream.c server/red-stream.h server/reds.c server/tests

Frediano Ziglio fziglio at kemper.freedesktop.org
Tue Jan 9 17:07:53 UTC 2018


 configure.ac             |    1 
 server/red-stream.c      |  330 +++++++++--------------
 server/red-stream.h      |   11 
 server/reds.c            |  119 --------
 server/tests/.gitignore  |    1 
 server/tests/Makefile.am |    4 
 server/tests/test-sasl.c |  665 +++++++++++++++++++++++++++++++++++++++++++++++
 7 files changed, 825 insertions(+), 306 deletions(-)

New commits:
commit 6c416f50980e21657dae2899d284d43f6f2be5a0
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    red-stream: Encapsulate all authentication state in RedSASLAuth
    
    Instead of having half of the state in RedSASL and half in RedSASLAuth,
    move everything into RedSASLAuth. This also reduces memory usage
    when SASL is used, since the state is freed once the authentication
    step has finished.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/red-stream.c b/server/red-stream.c
index 242e552c..5b2c084e 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -67,13 +67,6 @@ typedef struct RedSASL {
     unsigned int encodedOffset;
 
     SpiceBuffer inbuffer;
-
-    char *mechlist;
-    char *mechname;
-
-    /* temporary data during authentication */
-    unsigned int len;
-    char *data;
 } RedSASL;
 #endif
 
@@ -351,13 +344,8 @@ void red_stream_free(RedStream *s)
 #if HAVE_SASL
     if (s->priv->sasl.conn) {
         s->priv->sasl.runSSF = s->priv->sasl.wantSSF = 0;
-        s->priv->sasl.len = 0;
         s->priv->sasl.encodedLength = s->priv->sasl.encodedOffset = 0;
         s->priv->sasl.encoded = NULL;
-        g_free(s->priv->sasl.mechlist);
-        g_free(s->priv->sasl.mechname);
-        s->priv->sasl.mechlist = NULL;
-        g_free(s->priv->sasl.data);
         sasl_dispose(&s->priv->sasl.conn);
         s->priv->sasl.conn = NULL;
     }
@@ -735,6 +723,12 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
 
 typedef struct RedSASLAuth {
     RedStream *stream;
+    // list of mechanisms allowed, allocated and freed by SASL
+    char *mechlist;
+    // mech received
+    char *mechname;
+    uint32_t len;
+    char *data;
     // callback to call if success
     RedSaslResult result_cb;
     void *result_opaque;
@@ -743,6 +737,14 @@ typedef struct RedSASLAuth {
     AsyncReadError saved_error_cb;
 } RedSASLAuth;
 
+static void red_sasl_auth_free(RedSASLAuth *auth)
+{
+    g_free(auth->data);
+    g_free(auth->mechname);
+    g_free(auth->mechlist);
+    g_free(auth);
+}
+
 // handle SASL termination, either success or error
 // NOTE: After this function is called usually there should be a
 // return or the function should exit
@@ -750,7 +752,7 @@ static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err)
 {
     red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb);
     auth->result_cb(auth->result_opaque, err);
-    g_free(auth);
+    red_sasl_auth_free(auth);
 }
 
 static void red_sasl_error(void *opaque, int err)
@@ -760,7 +762,7 @@ static void red_sasl_error(void *opaque, int err)
     if (auth->saved_error_cb) {
         auth->saved_error_cb(auth->result_opaque, err);
     }
-    g_free(auth);
+    red_sasl_auth_free(auth);
 }
 
 /*
@@ -799,32 +801,33 @@ static void red_sasl_handle_auth_steplen(void *opaque);
 
 static void red_sasl_handle_auth_step(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
+    RedSASLAuth *auth = opaque;
+    RedStream *stream = auth->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
     char *clientdata = NULL;
     RedSASL *sasl = &stream->priv->sasl;
-    uint32_t datalen = sasl->len;
+    uint32_t datalen = auth->len;
 
     /* NB, distinction of NULL vs "" is *critical* in SASL */
     if (datalen) {
-        clientdata = sasl->data;
+        clientdata = auth->data;
         clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */
         datalen--; /* Don't count NULL byte when passing to _start() */
     }
 
-    if (sasl->mechname != NULL) {
+    if (auth->mechname != NULL) {
         spice_debug("Start SASL auth with mechanism %s. Data %p (%d bytes)",
-                   sasl->mechname, clientdata, datalen);
+                    auth->mechname, clientdata, datalen);
         err = sasl_server_start(sasl->conn,
-                                sasl->mechname,
+                                auth->mechname,
                                 clientdata,
                                 datalen,
                                 &serverout,
                                 &serveroutlen);
-        g_free(sasl->mechname);
-        sasl->mechname = NULL;
+        g_free(auth->mechname);
+        auth->mechname = NULL;
     } else {
         spice_debug("Step using SASL Data %p (%d bytes)", clientdata, datalen);
         err = sasl_server_step(sasl->conn,
@@ -837,13 +840,13 @@ static void red_sasl_handle_auth_step(void *opaque)
         err != SASL_CONTINUE) {
         spice_warning("sasl step failed %d (%s)",
                     err, sasl_errdetail(sasl->conn));
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl step reply data too long %d",
                       serveroutlen);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -862,8 +865,8 @@ static void red_sasl_handle_auth_step(void *opaque)
     if (err == SASL_CONTINUE) {
         spice_debug("%s", "Authentication must continue");
         /* Wait for step length */
-        red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              red_sasl_handle_auth_steplen, opaque);
+        red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t),
+                              red_sasl_handle_auth_steplen, auth);
         return;
     } else {
         int ssf;
@@ -882,7 +885,7 @@ static void red_sasl_handle_auth_step(void *opaque)
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
 
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -890,70 +893,69 @@ authreject:
     red_stream_write_u32_le(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
+    red_sasl_async_result(auth, RED_SASL_ERROR_AUTH_FAILED);
 }
 
 static void red_sasl_handle_auth_steplen(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->len = GUINT32_FROM_LE(sasl->len);
-    spice_debug("Got steplen %d", sasl->len);
-    if (sasl->len > SASL_DATA_MAX_LEN) {
-        spice_warning("Too much SASL data %d", sasl->len);
-        return red_sasl_async_result(opaque, sasl->mechname ? RED_SASL_ERROR_INVALID_DATA : RED_SASL_ERROR_GENERIC);
+    auth->len = GUINT32_FROM_LE(auth->len);
+    uint32_t len = auth->len;
+    spice_debug("Got steplen %d", len);
+    if (len > SASL_DATA_MAX_LEN) {
+        spice_warning("Too much SASL data %d", len);
+        return red_sasl_async_result(opaque, auth->mechname ? RED_SASL_ERROR_INVALID_DATA : RED_SASL_ERROR_GENERIC);
     }
 
-    if (sasl->len == 0) {
-        return red_sasl_handle_auth_step(opaque);
+    if (len == 0) {
+        return red_sasl_handle_auth_step(auth);
     }
 
-    sasl->data = g_realloc(sasl->data, sasl->len);
-    red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                          red_sasl_handle_auth_step, opaque);
+    auth->data = g_realloc(auth->data, len);
+    red_stream_async_read(auth->stream, (uint8_t *)auth->data, len,
+                          red_sasl_handle_auth_step, auth);
 }
 
 
 
 static void red_sasl_handle_auth_mechname(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->mechname[sasl->len] = '\0';
+    auth->mechname[auth->len] = '\0';
     spice_debug("Got client mechname '%s' check against '%s'",
-               sasl->mechname, sasl->mechlist);
+                auth->mechname, auth->mechlist);
 
     char quoted_mechname[SASL_MAX_MECHNAME_LEN + 4];
-    sprintf(quoted_mechname, ",%s,", sasl->mechname);
+    sprintf(quoted_mechname, ",%s,", auth->mechname);
 
-    if (strchr(sasl->mechname, ',') || strstr(sasl->mechlist, quoted_mechname) == NULL) {
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+    if (strchr(auth->mechname, ',') || strstr(auth->mechlist, quoted_mechname) == NULL) {
+        return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA);
     }
 
-    spice_debug("Validated mechname '%s'", sasl->mechname);
+    spice_debug("Validated mechname '%s'", auth->mechname);
 
-    red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          red_sasl_handle_auth_steplen, opaque);
+    red_stream_async_read(auth->stream, (uint8_t *)&auth->len, sizeof(uint32_t),
+                          red_sasl_handle_auth_steplen, auth);
 }
 
 static void red_sasl_handle_auth_mechlen(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->len = GUINT32_FROM_LE(sasl->len);
-    if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) {
-        spice_warning("Got bad client mechname len %d", sasl->len);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
+    auth->len = GUINT32_FROM_LE(auth->len);
+    uint32_t len = auth->len;
+    if (len < 1 || len > SASL_MAX_MECHNAME_LEN) {
+        spice_warning("Got bad client mechname len %d", len);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC);
     }
 
-    sasl->mechname = g_malloc(sasl->len + 1);
+    auth->mechname = g_malloc(len + 1);
 
     spice_debug("Wait for client mechname");
-    red_stream_async_read(stream, (uint8_t *)sasl->mechname, sasl->len,
-                          red_sasl_handle_auth_mechname, opaque);
+    red_stream_async_read(auth->stream, (uint8_t *)auth->mechname, len,
+                          red_sasl_handle_auth_mechname, auth);
 }
 
 bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque)
@@ -1050,11 +1052,9 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *resul
 
     spice_debug("Available mechanisms for client: '%s'", mechlist);
 
-    sasl->mechlist = g_strdup(mechlist);
-
     mechlistlen = strlen(mechlist);
     if (!red_stream_write_u32_le(stream, mechlistlen)
-        || !red_stream_write_all(stream, sasl->mechlist, mechlistlen)) {
+        || !red_stream_write_all(stream, mechlist, mechlistlen)) {
         spice_warning("SASL mechanisms write error");
         goto error;
     }
@@ -1064,10 +1064,11 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *resul
     auth->result_cb = result_cb;
     auth->result_opaque = result_opaque;
     auth->saved_error_cb = stream->priv->async_read.error;
-    red_stream_set_async_error_handler(stream, red_sasl_error);
+    auth->mechlist = g_strdup(mechlist);
 
     spice_debug("Wait for client mechname length");
-    red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
+    red_stream_set_async_error_handler(stream, red_sasl_error);
+    red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t),
                           red_sasl_handle_auth_mechlen, auth);
 
     return true;
commit cb70583e5c6922c50544a4ba8808839ee68c2093
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    red-stream: Unify start and step passes
    
    Most of these functions are identical.
    The only differences were basically debugging messages, which are
    less important now that there are proper tests.
    The mechname field is used to differentiate between the first step
    and the following ones.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/red-stream.c b/server/red-stream.c
index 5c2c2739..242e552c 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -782,111 +782,6 @@ static void red_sasl_error(void *opaque, int err)
 
 static void red_sasl_handle_auth_steplen(void *opaque);
 
-static void red_sasl_handle_auth_step(void *opaque)
-{
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    const char *serverout;
-    unsigned int serveroutlen;
-    int err;
-    char *clientdata = NULL;
-    RedSASL *sasl = &stream->priv->sasl;
-    uint32_t datalen = sasl->len;
-
-    /* NB, distinction of NULL vs "" is *critical* in SASL */
-    if (datalen) {
-        clientdata = sasl->data;
-        clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */
-        datalen--; /* Don't count NULL byte when passing to _start() */
-    }
-
-    spice_debug("Step using SASL Data %p (%d bytes)",
-               clientdata, datalen);
-    err = sasl_server_step(sasl->conn,
-                           clientdata,
-                           datalen,
-                           &serverout,
-                           &serveroutlen);
-    if (err != SASL_OK &&
-        err != SASL_CONTINUE) {
-        spice_warning("sasl step failed %d (%s)",
-                      err, sasl_errdetail(sasl->conn));
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
-    }
-
-    if (serveroutlen > SASL_DATA_MAX_LEN) {
-        spice_warning("sasl step reply data too long %d",
-                      serveroutlen);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
-    }
-
-    spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
-
-    if (serveroutlen) {
-        serveroutlen += 1;
-        red_stream_write_u32_le(stream, serveroutlen);
-        red_stream_write_all(stream, serverout, serveroutlen);
-    } else {
-        red_stream_write_u32_le(stream, serveroutlen);
-    }
-
-    /* Whether auth is complete */
-    red_stream_write_u8(stream, err == SASL_CONTINUE ? 0 : 1);
-
-    if (err == SASL_CONTINUE) {
-        spice_debug("%s", "Authentication must continue (step)");
-        /* Wait for step length */
-        red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              red_sasl_handle_auth_steplen, opaque);
-        return;
-    } else {
-        int ssf;
-
-        if (auth_sasl_check_ssf(sasl, &ssf) == 0) {
-            spice_warning("Authentication rejected for weak SSF");
-            goto authreject;
-        }
-
-        spice_debug("Authentication successful");
-        red_stream_write_u32_le(stream, SPICE_LINK_ERR_OK); /* Accept auth */
-
-        /*
-         * Delay writing in SSF encoded until now
-         */
-        sasl->runSSF = ssf;
-        red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
-
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
-    }
-
-authreject:
-    red_stream_write_u32_le(stream, 1); /* Reject auth */
-    red_stream_write_u32_le(stream, sizeof("Authentication failed"));
-    red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
-
-    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
-}
-
-static void red_sasl_handle_auth_steplen(void *opaque)
-{
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
-
-    sasl->len = GUINT32_FROM_LE(sasl->len);
-    spice_debug("Got steplen %d", sasl->len);
-    if (sasl->len > SASL_DATA_MAX_LEN) {
-        spice_warning("Too much SASL data %d", sasl->len);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
-    }
-
-    if (sasl->len == 0) {
-        red_sasl_handle_auth_step(opaque);
-    } else {
-        sasl->data = g_realloc(sasl->data, sasl->len);
-        red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                              red_sasl_handle_auth_step, opaque);
-    }
-}
-
 /*
  * Start Msg
  *
@@ -902,7 +797,7 @@ static void red_sasl_handle_auth_steplen(void *opaque)
  * u8 continue
  */
 
-static void red_sasl_handle_auth_start(void *opaque)
+static void red_sasl_handle_auth_step(void *opaque)
 {
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     const char *serverout;
@@ -915,27 +810,38 @@ static void red_sasl_handle_auth_start(void *opaque)
     /* NB, distinction of NULL vs "" is *critical* in SASL */
     if (datalen) {
         clientdata = sasl->data;
-        clientdata[datalen - 1] = '\0'; /* Should be on wire, but make sure */
+        clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */
         datalen--; /* Don't count NULL byte when passing to _start() */
     }
 
-    spice_debug("Start SASL auth with mechanism %s. Data %p (%d bytes)",
-               sasl->mechname, clientdata, datalen);
-    err = sasl_server_start(sasl->conn,
-                            sasl->mechname,
-                            clientdata,
-                            datalen,
-                            &serverout,
-                            &serveroutlen);
+    if (sasl->mechname != NULL) {
+        spice_debug("Start SASL auth with mechanism %s. Data %p (%d bytes)",
+                   sasl->mechname, clientdata, datalen);
+        err = sasl_server_start(sasl->conn,
+                                sasl->mechname,
+                                clientdata,
+                                datalen,
+                                &serverout,
+                                &serveroutlen);
+        g_free(sasl->mechname);
+        sasl->mechname = NULL;
+    } else {
+        spice_debug("Step using SASL Data %p (%d bytes)", clientdata, datalen);
+        err = sasl_server_step(sasl->conn,
+                               clientdata,
+                               datalen,
+                               &serverout,
+                               &serveroutlen);
+    }
     if (err != SASL_OK &&
         err != SASL_CONTINUE) {
-        spice_warning("sasl start failed %d (%s)",
+        spice_warning("sasl step failed %d (%s)",
                     err, sasl_errdetail(sasl->conn));
         return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
-        spice_warning("sasl start reply data too long %d",
+        spice_warning("sasl step reply data too long %d",
                       serveroutlen);
         return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
@@ -954,7 +860,7 @@ static void red_sasl_handle_auth_start(void *opaque)
     red_stream_write_u8(stream, err == SASL_CONTINUE ? 0 : 1);
 
     if (err == SASL_CONTINUE) {
-        spice_debug("%s", "Authentication must continue (start)");
+        spice_debug("%s", "Authentication must continue");
         /* Wait for step length */
         red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
                               red_sasl_handle_auth_steplen, opaque);
@@ -987,27 +893,29 @@ authreject:
     red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
 }
 
-static void red_sasl_handle_auth_startlen(void *opaque)
+static void red_sasl_handle_auth_steplen(void *opaque)
 {
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     sasl->len = GUINT32_FROM_LE(sasl->len);
-    spice_debug("Got client start len %d", sasl->len);
+    spice_debug("Got steplen %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+        return red_sasl_async_result(opaque, sasl->mechname ? RED_SASL_ERROR_INVALID_DATA : RED_SASL_ERROR_GENERIC);
     }
 
     if (sasl->len == 0) {
-        return red_sasl_handle_auth_start(opaque);
+        return red_sasl_handle_auth_step(opaque);
     }
 
     sasl->data = g_realloc(sasl->data, sasl->len);
     red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                          red_sasl_handle_auth_start, opaque);
+                          red_sasl_handle_auth_step, opaque);
 }
 
+
+
 static void red_sasl_handle_auth_mechname(void *opaque)
 {
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
@@ -1027,7 +935,7 @@ static void red_sasl_handle_auth_mechname(void *opaque)
     spice_debug("Validated mechname '%s'", sasl->mechname);
 
     red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          red_sasl_handle_auth_startlen, opaque);
+                          red_sasl_handle_auth_steplen, opaque);
 }
 
 static void red_sasl_handle_auth_mechlen(void *opaque)
commit 5c516a6e42f149e7ed7cac15d138e2b487804d9c
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    red-stream: Handle properly endianness in SASL code
    
    The whole SPICE protocol is little endian; there is no agreement on
    any other endianness and currently we support only little endian, so
    make sure this works even when running on a big endian machine.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/red-stream.c b/server/red-stream.c
index 0176f0a9..5c2c2739 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -580,8 +580,9 @@ static bool red_stream_write_u8(RedStream *s, uint8_t n)
     return red_stream_write_all(s, &n, sizeof(uint8_t));
 }
 
-static bool red_stream_write_u32(RedStream *s, uint32_t n)
+static bool red_stream_write_u32_le(RedStream *s, uint32_t n)
 {
+    n = GUINT32_TO_LE(n);
     return red_stream_write_all(s, &n, sizeof(uint32_t));
 }
 
@@ -822,10 +823,10 @@ static void red_sasl_handle_auth_step(void *opaque)
 
     if (serveroutlen) {
         serveroutlen += 1;
-        red_stream_write_all(stream, &serveroutlen, sizeof(uint32_t));
+        red_stream_write_u32_le(stream, serveroutlen);
         red_stream_write_all(stream, serverout, serveroutlen);
     } else {
-        red_stream_write_all(stream, &serveroutlen, sizeof(uint32_t));
+        red_stream_write_u32_le(stream, serveroutlen);
     }
 
     /* Whether auth is complete */
@@ -846,7 +847,7 @@ static void red_sasl_handle_auth_step(void *opaque)
         }
 
         spice_debug("Authentication successful");
-        red_stream_write_u32(stream, SPICE_LINK_ERR_OK); /* Accept auth */
+        red_stream_write_u32_le(stream, SPICE_LINK_ERR_OK); /* Accept auth */
 
         /*
          * Delay writing in SSF encoded until now
@@ -858,8 +859,8 @@ static void red_sasl_handle_auth_step(void *opaque)
     }
 
 authreject:
-    red_stream_write_u32(stream, 1); /* Reject auth */
-    red_stream_write_u32(stream, sizeof("Authentication failed"));
+    red_stream_write_u32_le(stream, 1); /* Reject auth */
+    red_stream_write_u32_le(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
     red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
@@ -870,6 +871,7 @@ static void red_sasl_handle_auth_steplen(void *opaque)
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
+    sasl->len = GUINT32_FROM_LE(sasl->len);
     spice_debug("Got steplen %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
@@ -942,10 +944,10 @@ static void red_sasl_handle_auth_start(void *opaque)
 
     if (serveroutlen) {
         serveroutlen += 1;
-        red_stream_write_all(stream, &serveroutlen, sizeof(uint32_t));
+        red_stream_write_u32_le(stream, serveroutlen);
         red_stream_write_all(stream, serverout, serveroutlen);
     } else {
-        red_stream_write_all(stream, &serveroutlen, sizeof(uint32_t));
+        red_stream_write_u32_le(stream, serveroutlen);
     }
 
     /* Whether auth is complete */
@@ -966,7 +968,7 @@ static void red_sasl_handle_auth_start(void *opaque)
         }
 
         spice_debug("Authentication successful");
-        red_stream_write_u32(stream, SPICE_LINK_ERR_OK); /* Accept auth */
+        red_stream_write_u32_le(stream, SPICE_LINK_ERR_OK); /* Accept auth */
 
         /*
          * Delay writing in SSF encoded until now
@@ -978,8 +980,8 @@ static void red_sasl_handle_auth_start(void *opaque)
     }
 
 authreject:
-    red_stream_write_u32(stream, 1); /* Reject auth */
-    red_stream_write_u32(stream, sizeof("Authentication failed"));
+    red_stream_write_u32_le(stream, 1); /* Reject auth */
+    red_stream_write_u32_le(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
     red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
@@ -990,6 +992,7 @@ static void red_sasl_handle_auth_startlen(void *opaque)
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
+    sasl->len = GUINT32_FROM_LE(sasl->len);
     spice_debug("Got client start len %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
@@ -1032,6 +1035,7 @@ static void red_sasl_handle_auth_mechlen(void *opaque)
     RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
+    sasl->len = GUINT32_FROM_LE(sasl->len);
     if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) {
         spice_warning("Got bad client mechname len %d", sasl->len);
         return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
@@ -1141,7 +1145,7 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *resul
     sasl->mechlist = g_strdup(mechlist);
 
     mechlistlen = strlen(mechlist);
-    if (!red_stream_write_all(stream, &mechlistlen, sizeof(uint32_t))
+    if (!red_stream_write_u32_le(stream, mechlistlen)
         || !red_stream_write_all(stream, sasl->mechlist, mechlistlen)) {
         spice_warning("SASL mechanisms write error");
         goto error;
commit 5c438510cd373db9c6730231d829c757d9b74e53
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    Handle SASL initialisation mainly in red-stream.c
    
    Asynchronous code jumping from one file to another is tedious to read;
    also, having code that handles the same thing in two files does not
    look like a good design.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/red-stream.c b/server/red-stream.c
index fccad8b2..0176f0a9 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -732,6 +732,36 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
     return 1;
 }
 
+typedef struct RedSASLAuth {
+    RedStream *stream;
+    // callback to call if success
+    RedSaslResult result_cb;
+    void *result_opaque;
+    // saved Async callback, we need to call if failed as
+    // we need to chain it in order to use a different opaque data
+    AsyncReadError saved_error_cb;
+} RedSASLAuth;
+
+// handle SASL termination, either success or error
+// NOTE: After this function is called usually there should be a
+// return or the function should exit
+static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err)
+{
+    red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb);
+    auth->result_cb(auth->result_opaque, err);
+    g_free(auth);
+}
+
+static void red_sasl_error(void *opaque, int err)
+{
+    RedSASLAuth *auth = opaque;
+    red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb);
+    if (auth->saved_error_cb) {
+        auth->saved_error_cb(auth->result_opaque, err);
+    }
+    g_free(auth);
+}
+
 /*
  * Step Msg
  *
@@ -749,8 +779,11 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
 #define SASL_MAX_MECHNAME_LEN 100
 #define SASL_DATA_MAX_LEN (1024 * 1024)
 
-RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_steplen(void *opaque);
+
+static void red_sasl_handle_auth_step(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
@@ -776,13 +809,13 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         err != SASL_CONTINUE) {
         spice_warning("sasl step failed %d (%s)",
                       err, sasl_errdetail(sasl->conn));
-        return RED_SASL_ERROR_GENERIC;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl step reply data too long %d",
                       serveroutlen);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -802,8 +835,8 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         spice_debug("%s", "Authentication must continue (step)");
         /* Wait for step length */
         red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              read_cb, opaque);
-        return RED_SASL_ERROR_CONTINUE;
+                              red_sasl_handle_auth_steplen, opaque);
+        return;
     } else {
         int ssf;
 
@@ -821,7 +854,7 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
 
-        return RED_SASL_ERROR_OK;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -829,31 +862,26 @@ authreject:
     red_stream_write_u32(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    return RED_SASL_ERROR_AUTH_FAILED;
+    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
 }
 
-RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_steplen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     spice_debug("Got steplen %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (sasl->len == 0) {
-        read_cb(opaque);
-        /* FIXME: can't report potential errors correctly here,
-         * but read_cb() will have done the needed RedLinkInfo cleanups
-         * if an error occurs, so the caller should not need to do more
-         * treatment */
-        return RED_SASL_ERROR_OK;
+        red_sasl_handle_auth_step(opaque);
     } else {
         sasl->data = g_realloc(sasl->data, sasl->len);
         red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                              read_cb, opaque);
-        return RED_SASL_ERROR_OK;
+                              red_sasl_handle_auth_step, opaque);
     }
 }
 
@@ -872,8 +900,9 @@ RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_
  * u8 continue
  */
 
-RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_start(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
@@ -900,13 +929,13 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
         err != SASL_CONTINUE) {
         spice_warning("sasl start failed %d (%s)",
                     err, sasl_errdetail(sasl->conn));
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl start reply data too long %d",
-                    serveroutlen);
-        return RED_SASL_ERROR_INVALID_DATA;
+                      serveroutlen);
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -926,8 +955,8 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
         spice_debug("%s", "Authentication must continue (start)");
         /* Wait for step length */
         red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              read_cb, opaque);
-        return RED_SASL_ERROR_CONTINUE;
+                              red_sasl_handle_auth_steplen, opaque);
+        return;
     } else {
         int ssf;
 
@@ -944,7 +973,8 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
          */
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
-        return RED_SASL_ERROR_OK;
+
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -952,32 +982,32 @@ authreject:
     red_stream_write_u32(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    return RED_SASL_ERROR_AUTH_FAILED;
+    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
 }
 
-RedSaslError red_sasl_handle_auth_startlen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_startlen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     spice_debug("Got client start len %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
     }
 
     if (sasl->len == 0) {
-        return RED_SASL_ERROR_RETRY;
+        return red_sasl_handle_auth_start(opaque);
     }
 
     sasl->data = g_realloc(sasl->data, sasl->len);
     red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                          read_cb, opaque);
-
-    return RED_SASL_ERROR_OK;
+                          red_sasl_handle_auth_start, opaque);
 }
 
-bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_mechname(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     sasl->mechname[sasl->len] = '\0';
@@ -988,36 +1018,33 @@ bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, voi
     sprintf(quoted_mechname, ",%s,", sasl->mechname);
 
     if (strchr(sasl->mechname, ',') || strstr(sasl->mechlist, quoted_mechname) == NULL) {
-        return false;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
     }
 
     spice_debug("Validated mechname '%s'", sasl->mechname);
 
     red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          read_cb, opaque);
-
-    return true;
+                          red_sasl_handle_auth_startlen, opaque);
 }
 
-bool red_sasl_handle_auth_mechlen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_mechlen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) {
         spice_warning("Got bad client mechname len %d", sasl->len);
-        return false;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     sasl->mechname = g_malloc(sasl->len + 1);
 
     spice_debug("Wait for client mechname");
     red_stream_async_read(stream, (uint8_t *)sasl->mechname, sasl->len,
-                          read_cb, opaque);
-
-    return true;
+                          red_sasl_handle_auth_mechname, opaque);
 }
 
-bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque)
 {
     const char *mechlist = NULL;
     sasl_security_properties_t secprops;
@@ -1025,6 +1052,7 @@ bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
     char *localAddr, *remoteAddr;
     int mechlistlen;
     RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth;
 
     if (!(localAddr = red_stream_get_local_address(stream))) {
         goto error;
@@ -1119,9 +1147,16 @@ bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
         goto error;
     }
 
+    auth = g_new0(RedSASLAuth, 1);
+    auth->stream = stream;
+    auth->result_cb = result_cb;
+    auth->result_opaque = result_opaque;
+    auth->saved_error_cb = stream->priv->async_read.error;
+    red_stream_set_async_error_handler(stream, red_sasl_error);
+
     spice_debug("Wait for client mechname length");
     red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          read_cb, opaque);
+                          red_sasl_handle_auth_mechlen, auth);
 
     return true;
 
diff --git a/server/red-stream.h b/server/red-stream.h
index a8d855c2..4d5075ed 100644
--- a/server/red-stream.h
+++ b/server/red-stream.h
@@ -73,17 +73,10 @@ typedef enum {
     RED_SASL_ERROR_OK,
     RED_SASL_ERROR_GENERIC,
     RED_SASL_ERROR_INVALID_DATA,
-    RED_SASL_ERROR_RETRY,
-    RED_SASL_ERROR_CONTINUE,
     RED_SASL_ERROR_AUTH_FAILED
 } RedSaslError;
 
-RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_startlen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_handle_auth_mechlen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque);
+typedef void (*RedSaslResult)(void *opaque, RedSaslError err);
+bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaque);
 
 #endif /* RED_STREAM_H_ */
diff --git a/server/reds.c b/server/reds.c
index 9338b78b..777f5ba4 100644
--- a/server/reds.c
+++ b/server/reds.c
@@ -2100,123 +2100,30 @@ static void reds_get_spice_ticket(RedLinkInfo *link)
 }
 
 #if HAVE_SASL
-/*
- * Step Msg
- *
- * Input from client:
- *
- * u32 clientin-length
- * u8-array clientin-string
- *
- * Output to client:
- *
- * u32 serverout-length
- * u8-array serverout-strin
- * u8 continue
- */
-
-static void reds_handle_auth_sasl_steplen(void *opaque);
-
-static void reds_handle_auth_sasl_step(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_step(link->stream, reds_handle_auth_sasl_steplen, link);
-    if (status == RED_SASL_ERROR_OK) {
-        reds_handle_link(link);
-    } else if (status != RED_SASL_ERROR_CONTINUE) {
-        reds_link_free(link);
-    }
-}
-
-static void reds_handle_auth_sasl_steplen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_steplen(link->stream, reds_handle_auth_sasl_step, link);
-    if (status != RED_SASL_ERROR_OK) {
-        reds_link_free(link);
-    }
-}
-
-/*
- * Start Msg
- *
- * Input from client:
- *
- * u32 clientin-length
- * u8-array clientin-string
- *
- * Output to client:
- *
- * u32 serverout-length
- * u8-array serverout-strin
- * u8 continue
- */
-
-
-static void reds_handle_auth_sasl_start(void *opaque)
+static void reds_handle_sasl_result(void *opaque, RedSaslError status)
 {
     RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_start(link->stream, reds_handle_auth_sasl_steplen, link);
-    if (status == RED_SASL_ERROR_OK) {
-        reds_handle_link(link);
-    } else if (status != RED_SASL_ERROR_CONTINUE) {
-        reds_link_free(link);
-    }
-}
 
-static void reds_handle_auth_startlen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_startlen(link->stream, reds_handle_auth_sasl_start, link);
     switch (status) {
-        case RED_SASL_ERROR_OK:
-            break;
-        case RED_SASL_ERROR_RETRY:
-            reds_handle_auth_sasl_start(opaque);
-            break;
-        case RED_SASL_ERROR_GENERIC:
-        case RED_SASL_ERROR_INVALID_DATA:
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
-            reds_link_free(link);
-            break;
-        default:
-            g_warn_if_reached();
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
-            reds_link_free(link);
-            break;
-    }
-}
-
-static void reds_handle_auth_mechname(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-
-    if (!red_sasl_handle_auth_mechname(link->stream, reds_handle_auth_startlen, link)) {
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
+    case RED_SASL_ERROR_OK:
+        reds_handle_link(link);
+        break;
+    case RED_SASL_ERROR_INVALID_DATA:
+        reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
         reds_link_free(link);
-    }
-}
-
-static void reds_handle_auth_mechlen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-
-    if (!red_sasl_handle_auth_mechlen(link->stream, reds_handle_auth_mechname, link)) {
+        break;
+    default:
+        // in these cases error was reported using SASL protocol
+        // (RED_SASL_ERROR_AUTH_FAILED) or we just need to close the
+        // connection
         reds_link_free(link);
+        break;
     }
 }
 
 static void reds_start_auth_sasl(RedLinkInfo *link)
 {
-    if (!red_sasl_start_auth(link->stream, reds_handle_auth_mechlen, link)) {
+    if (!red_sasl_start_auth(link->stream, reds_handle_sasl_result, link)) {
         reds_link_free(link);
     }
 }
commit 6543bef0cbf7abd37c46b20ff4024435acc260a7
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Sun Jan 7 11:00:07 2018 +0000

    test-sasl: Test how the server reports the failure
    
    On failure, the server can either just disconnect the client or report
    the error. The error can be reported using the new protocol 2 or as
    just a number (as in protocol 1).
    Detect how the failure is reported so that it can be checked.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
index 726b4029..25ad79af 100644
--- a/server/tests/test-sasl.c
+++ b/server/tests/test-sasl.c
@@ -406,10 +406,19 @@ typedef enum {
     STEP_NEVER,
 } ClientEmulationSteps;
 
+typedef enum {
+    FAILURE_OK,
+    FAILURE_UNKNOWN,
+    FAILURE_DISCONNECT,
+    FAILURE_MAGIC,
+    FAILURE_AUTH,
+} FailureType;
+
 typedef struct {
     const char *mechname;
     int mechlen;
     bool success;
+    FailureType failure_type;
     ClientEmulationSteps last_step;
     unsigned flags;
     int line;
@@ -419,28 +428,30 @@ static char long_mechname[128];
 static TestData tests_data[] = {
     // these should just succeed
 #define TEST_SUCCESS(mech) \
-    { mech, -1, true, STEP_NEVER, FLAG_NONE, __LINE__ },
+    { mech, -1, true, FAILURE_OK, STEP_NEVER, FLAG_NONE, __LINE__ },
     TEST_SUCCESS("ONE")
     TEST_SUCCESS("TWO")
     TEST_SUCCESS("THREE")
 
     // these test bad mech names
 #define TEST_BAD_NAME(mech, len) \
-    { mech, len, false, STEP_NEVER, FLAG_NONE, __LINE__ },
+    { mech, len, false, FAILURE_MAGIC, STEP_NEVER, FLAG_NONE, __LINE__ },
+#define TEST_BAD_NAME_LEN(mech, len) \
+    { mech, len, false, FAILURE_DISCONNECT, STEP_NEVER, FLAG_NONE, __LINE__ },
     TEST_BAD_NAME("ON", -1)
     TEST_BAD_NAME("NE", -1)
     TEST_BAD_NAME("THRE", -1)
     TEST_BAD_NAME("HREE", -1)
     TEST_BAD_NAME("ON\x00", 3)
     TEST_BAD_NAME("O\x00\x00", 3)
-    TEST_BAD_NAME("", -1)
+    TEST_BAD_NAME_LEN("", -1)
     TEST_BAD_NAME(long_mechname, 100)
-    TEST_BAD_NAME(long_mechname, 101)
+    TEST_BAD_NAME_LEN(long_mechname, 101)
     TEST_BAD_NAME("ONE,TWO", -1)
 
     // stop before filling everything
 #define TEST_EARLY_STOP(step) \
-    { "ONE", -1, false, step, FLAG_NONE, __LINE__},
+    { "ONE", -1, false, FAILURE_UNKNOWN, step, FLAG_NONE, __LINE__},
     TEST_EARLY_STOP(STEP_READ_MECHLIST_LEN)
     TEST_EARLY_STOP(STEP_READ_MECHLIST)
     TEST_EARLY_STOP(STEP_WRITE_MECHNAME_LEN)
@@ -449,29 +460,31 @@ static TestData tests_data[] = {
     TEST_EARLY_STOP(STEP_WRITE_START)
     TEST_EARLY_STOP(STEP_WRITE_STEP_LEN)
 
-#define TEST_FLAGS(result, flags) \
-    { "ONE", -1, result, STEP_NEVER, flags, __LINE__},
-    TEST_FLAGS(false, FLAG_LOW_SSF)
-    TEST_FLAGS(false, FLAG_START_OK|FLAG_LOW_SSF)
-    TEST_FLAGS(true, FLAG_START_OK)
-    TEST_FLAGS(true, FLAG_SERVER_NULL_START)
-    TEST_FLAGS(true, FLAG_SERVER_NULL_STEP)
-    TEST_FLAGS(true, FLAG_CLIENT_NULL_START)
-    TEST_FLAGS(true, FLAG_CLIENT_NULL_STEP)
-    TEST_FLAGS(false, FLAG_SERVER_BIG_START)
-    TEST_FLAGS(false, FLAG_SERVER_BIG_STEP)
-    TEST_FLAGS(false, FLAG_CLIENT_BIG_START)
-    TEST_FLAGS(false, FLAG_CLIENT_BIG_STEP)
-    TEST_FLAGS(false, FLAG_START_ERROR)
-    TEST_FLAGS(false, FLAG_STEP_ERROR)
+#define TEST_FLAGS_OK(flags) \
+    { "ONE", -1, true, FAILURE_OK, STEP_NEVER, flags, __LINE__},
+#define TEST_FLAGS_KO(failure, flags) \
+    { "ONE", -1, false, failure, STEP_NEVER, flags, __LINE__},
+    TEST_FLAGS_KO(FAILURE_OK, FLAG_LOW_SSF)
+    TEST_FLAGS_KO(FAILURE_OK, FLAG_START_OK|FLAG_LOW_SSF)
+    TEST_FLAGS_OK(FLAG_START_OK)
+    TEST_FLAGS_OK(FLAG_SERVER_NULL_START)
+    TEST_FLAGS_OK(FLAG_SERVER_NULL_STEP)
+    TEST_FLAGS_OK(FLAG_CLIENT_NULL_START)
+    TEST_FLAGS_OK(FLAG_CLIENT_NULL_STEP)
+    TEST_FLAGS_KO(FAILURE_DISCONNECT, FLAG_SERVER_BIG_START)
+    TEST_FLAGS_KO(FAILURE_DISCONNECT, FLAG_SERVER_BIG_STEP)
+    TEST_FLAGS_KO(FAILURE_MAGIC, FLAG_CLIENT_BIG_START)
+    TEST_FLAGS_KO(FAILURE_DISCONNECT, FLAG_CLIENT_BIG_STEP)
+    TEST_FLAGS_KO(FAILURE_DISCONNECT, FLAG_START_ERROR)
+    TEST_FLAGS_KO(FAILURE_DISCONNECT, FLAG_STEP_ERROR)
 };
 
-static void
+static FailureType
 client_emulator(int sock)
 {
     const TestData *data = &tests_data[test_num];
 
-#define STOP_AT(step) if (data->last_step == STEP_ ## step) { return; }
+#define STOP_AT(step) if (data->last_step == STEP_ ## step) { return FAILURE_UNKNOWN; }
 
     // send initial message
     write_all(sock, &initial_message, sizeof(initial_message));
@@ -521,32 +534,43 @@ client_emulator(int sock)
     if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) {
         STOP_AT(WRITE_START_LEN);
         if (out) {
-            if (do_readwrite_all(sock, out, outlen, true) != outlen) {
-                return;
-            }
+            do_readwrite_all(sock, out, outlen, true);
         }
         STOP_AT(WRITE_START);
     }
 
-    uint32_t datalen;
-    if (read_u32_err(sock, &datalen) != sizeof(datalen)) {
-        return;
-    }
-    if (datalen == GUINT32_FROM_LE(SPICE_MAGIC)) {
-        return;
-    }
-    g_assert_cmpint(datalen, <=, sizeof(buf));
-    read_all(sock, buf, datalen);
+    for (;;) {
+        uint32_t datalen;
+        if (read_u32_err(sock, &datalen) != sizeof(datalen)) {
+            return FAILURE_DISCONNECT;
+        }
+        if (datalen == GUINT32_FROM_LE(SPICE_MAGIC)) {
+            return FAILURE_MAGIC;
+        }
+        g_assert_cmpint(datalen, <=, sizeof(buf));
+        read_all(sock, buf, datalen);
+
+        uint8_t is_ok;
+        read_all(sock, &is_ok, sizeof(is_ok));
+        if (is_ok) {
+            // is_ok should be 0 or 1
+            g_assert_cmpint(is_ok, ==, 1);
+            uint32_t step_result;
+            read_u32(sock, &step_result);
+            if (!step_result) {
+                return FAILURE_AUTH;
+            }
+            return FAILURE_OK;
+        }
 
-    get_step_out(&out, &outlen, "STEP", FLAG_CLIENT_NULL_STEP, FLAG_CLIENT_BIG_STEP);
-    if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) {
-        STOP_AT(WRITE_STEP_LEN);
-        if (out) {
-            if (do_readwrite_all(sock, out, outlen, true) != outlen) {
-                return;
+        get_step_out(&out, &outlen, "STEP", FLAG_CLIENT_NULL_STEP, FLAG_CLIENT_BIG_STEP);
+        if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) {
+            STOP_AT(WRITE_STEP_LEN);
+            if (out) {
+                do_readwrite_all(sock, out, outlen, true);
             }
+            STOP_AT(WRITE_STEP);
         }
-        STOP_AT(WRITE_STEP);
     }
 }
 
@@ -555,14 +579,14 @@ client_emulator_thread(void *arg)
 {
     int sock = GPOINTER_TO_INT(arg);
 
-    client_emulator(sock);
+    FailureType result = client_emulator(sock);
 
     shutdown(sock, SHUT_RDWR);
     close(sock);
 
     idle_add(idle_end_test, NULL);
 
-    return NULL;
+    return GINT_TO_POINTER(result);
 }
 
 static pthread_t
@@ -598,7 +622,7 @@ idle_end_test(void *arg)
 }
 
 static void
-check_test_results(void)
+check_test_results(FailureType res)
 {
     const TestData *data = &tests_data[test_num];
     if (data->success) {
@@ -608,6 +632,7 @@ check_test_results(void)
 
     g_assert(mechlist_called);
     g_assert(!encode_called);
+    g_assert_cmpint(res, ==, data->failure_type);
 }
 
 static void
@@ -621,9 +646,10 @@ sasl_mechs(void)
         pthread_t thread = setup_thread();
         alarm(4);
         basic_event_loop_mainloop();
-        g_assert_cmpint(pthread_join(thread, NULL), ==, 0);
+        void *thread_ret = NULL;
+        g_assert_cmpint(pthread_join(thread, &thread_ret), ==, 0);
         alarm(0);
-        check_test_results();
+        check_test_results((FailureType)GPOINTER_TO_INT(thread_ret));
         reset_test();
     }
 
commit 0e533dfec51c547f77aaf7547f3fc8e15063cd8d
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Wed Dec 13 19:52:21 2017 +0000

    test-sasl: Add tests for different failures and cases
    
    Use some flags to specify which behaviour to change, and add different
    test cases to exercise them.
    Some cases make the client stop sending data at different steps of
    the process.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
index 7c1181dc..726b4029 100644
--- a/server/tests/test-sasl.c
+++ b/server/tests/test-sasl.c
@@ -41,11 +41,29 @@ typedef struct SPICE_ATTR_PACKED SpiceInitialMessage {
 } SpiceInitialMessage;
 #include <spice/end-packed.h>
 
+typedef enum {
+    FLAG_NONE = 0,
+    FLAG_START_OK = 1,
+    FLAG_LOW_SSF = 2,
+    FLAG_SERVER_NULL_START = 4,
+    FLAG_SERVER_NULL_STEP = 8,
+    FLAG_CLIENT_NULL_START = 16,
+    FLAG_CLIENT_NULL_STEP = 32,
+    FLAG_SERVER_BIG_START = 64,
+    FLAG_SERVER_BIG_STEP = 128,
+    FLAG_CLIENT_BIG_START = 256,
+    FLAG_CLIENT_BIG_STEP = 512,
+    FLAG_START_ERROR = 1024,
+    FLAG_STEP_ERROR = 2048,
+} TestFlags;
+
 static char *mechlist;
+static char *big_data;
 static bool mechlist_called;
 static bool start_called;
 static bool step_called;
 static bool encode_called;
+static unsigned test_flags;
 
 static SpiceCoreInterface *core;
 static SpiceServer *server;
@@ -60,6 +78,22 @@ check_sasl_conn(sasl_conn_t *conn)
     g_assert_nonnull(conn);
 }
 
+static void
+get_step_out(const char **serverout, unsigned *serveroutlen,
+             const char *normal_data, unsigned null_flag, unsigned big_flag)
+{
+    if ((test_flags & big_flag) != 0) {
+        *serverout = big_data;
+        *serveroutlen = strlen(big_data);
+    } else if ((test_flags & null_flag) != 0) {
+        *serverout = NULL;
+        *serveroutlen = 0;
+    } else {
+        *serverout = normal_data;
+        *serveroutlen = strlen(normal_data);
+    }
+}
+
 int
 sasl_server_init(const sasl_callback_t *callbacks, const char *appname)
 {
@@ -113,7 +147,8 @@ sasl_getprop(sasl_conn_t *conn, int propnum,
     g_assert_nonnull(pvalue);
 
     if (propnum == SASL_SSF) {
-        static const int val = 64;
+        static int val;
+        val = (test_flags & FLAG_LOW_SSF) ? 44 : 64;
         *pvalue = &val;
     }
     return SASL_OK;
@@ -194,9 +229,11 @@ sasl_server_start(sasl_conn_t *conn,
     g_assert(!step_called);
     start_called = true;
 
-    *serverout = "foo";
-    *serveroutlen = 3;
-    return SASL_OK;
+    get_step_out(serverout, serveroutlen, "foo", FLAG_SERVER_NULL_START, FLAG_SERVER_BIG_START);
+    if (test_flags & FLAG_START_ERROR) {
+        return SASL_FAIL;
+    }
+    return (test_flags & FLAG_START_OK) ? SASL_OK : SASL_CONTINUE;
 }
 
 int
@@ -211,8 +248,10 @@ sasl_server_step(sasl_conn_t *conn,
     g_assert(start_called);
     step_called = true;
 
-    *serverout = "foo";
-    *serveroutlen = 3;
+    get_step_out(serverout, serveroutlen, "foo", FLAG_SERVER_NULL_STEP, FLAG_SERVER_BIG_STEP);
+    if (test_flags & FLAG_STEP_ERROR) {
+        return SASL_FAIL;
+    }
     return SASL_OK;
 }
 
@@ -244,11 +283,19 @@ reset_test(void)
     start_called = false;
     step_called = false;
     encode_called = false;
+    test_flags = FLAG_NONE;
 }
 
 static void
 start_test(void)
 {
+    g_assert_null(big_data);
+    big_data = g_malloc(1024 * 1024 + 10);
+    for (unsigned n = 0; n < 1024 * 1024 + 10; ++n) {
+        big_data[n] = ' ' + (n % 94);
+    }
+    big_data[1024 * 1024 + 5] = 0;
+
     g_assert_null(server);
 
     initial_message.hdr.magic = SPICE_MAGIC;
@@ -275,6 +322,9 @@ end_tests(void)
 
     g_free(mechlist);
     mechlist = NULL;
+
+    g_free(big_data);
+    big_data = NULL;
 }
 
 static size_t
@@ -343,24 +393,40 @@ idle_add(GSourceFunc func, void *arg)
     g_source_unref(source);
 }
 
+typedef enum {
+    STEP_NONE,
+    STEP_READ_MECHLIST_LEN,
+    STEP_READ_MECHLIST,
+    STEP_WRITE_MECHNAME_LEN,
+    STEP_WRITE_MECHNAME,
+    STEP_WRITE_START_LEN,
+    STEP_WRITE_START,
+    STEP_WRITE_STEP_LEN,
+    STEP_WRITE_STEP,
+    STEP_NEVER,
+} ClientEmulationSteps;
+
 typedef struct {
     const char *mechname;
     int mechlen;
     bool success;
+    ClientEmulationSteps last_step;
+    unsigned flags;
+    int line;
 } TestData;
 
 static char long_mechname[128];
 static TestData tests_data[] = {
     // these should just succeed
 #define TEST_SUCCESS(mech) \
-    { mech, -1, true },
+    { mech, -1, true, STEP_NEVER, FLAG_NONE, __LINE__ },
     TEST_SUCCESS("ONE")
     TEST_SUCCESS("TWO")
     TEST_SUCCESS("THREE")
 
     // these test bad mech names
 #define TEST_BAD_NAME(mech, len) \
-    { mech, len, false },
+    { mech, len, false, STEP_NEVER, FLAG_NONE, __LINE__ },
     TEST_BAD_NAME("ON", -1)
     TEST_BAD_NAME("NE", -1)
     TEST_BAD_NAME("THRE", -1)
@@ -371,14 +437,41 @@ static TestData tests_data[] = {
     TEST_BAD_NAME(long_mechname, 100)
     TEST_BAD_NAME(long_mechname, 101)
     TEST_BAD_NAME("ONE,TWO", -1)
+
+    // stop before filling everything
+#define TEST_EARLY_STOP(step) \
+    { "ONE", -1, false, step, FLAG_NONE, __LINE__},
+    TEST_EARLY_STOP(STEP_READ_MECHLIST_LEN)
+    TEST_EARLY_STOP(STEP_READ_MECHLIST)
+    TEST_EARLY_STOP(STEP_WRITE_MECHNAME_LEN)
+    TEST_EARLY_STOP(STEP_WRITE_MECHNAME)
+    TEST_EARLY_STOP(STEP_WRITE_START_LEN)
+    TEST_EARLY_STOP(STEP_WRITE_START)
+    TEST_EARLY_STOP(STEP_WRITE_STEP_LEN)
+
+#define TEST_FLAGS(result, flags) \
+    { "ONE", -1, result, STEP_NEVER, flags, __LINE__},
+    TEST_FLAGS(false, FLAG_LOW_SSF)
+    TEST_FLAGS(false, FLAG_START_OK|FLAG_LOW_SSF)
+    TEST_FLAGS(true, FLAG_START_OK)
+    TEST_FLAGS(true, FLAG_SERVER_NULL_START)
+    TEST_FLAGS(true, FLAG_SERVER_NULL_STEP)
+    TEST_FLAGS(true, FLAG_CLIENT_NULL_START)
+    TEST_FLAGS(true, FLAG_CLIENT_NULL_STEP)
+    TEST_FLAGS(false, FLAG_SERVER_BIG_START)
+    TEST_FLAGS(false, FLAG_SERVER_BIG_STEP)
+    TEST_FLAGS(false, FLAG_CLIENT_BIG_START)
+    TEST_FLAGS(false, FLAG_CLIENT_BIG_STEP)
+    TEST_FLAGS(false, FLAG_START_ERROR)
+    TEST_FLAGS(false, FLAG_STEP_ERROR)
 };
 
-static void *
-client_emulator(void *arg)
+static void
+client_emulator(int sock)
 {
     const TestData *data = &tests_data[test_num];
 
-    int sock = GPOINTER_TO_INT(arg);
+#define STOP_AT(step) if (data->last_step == STEP_ ## step) { return; }
 
     // send initial message
     write_all(sock, &initial_message, sizeof(initial_message));
@@ -402,23 +495,68 @@ client_emulator(void *arg)
     // mech SPICE_COMMON_CAP_AUTH_SASL)
     write_u32(sock, SPICE_COMMON_CAP_AUTH_SASL);
 
+    STOP_AT(NONE);
+
     // sasl finally start, data starts from server (mech list)
     //
     uint32_t mechlen;
     read_u32(sock, &mechlen);
+    STOP_AT(READ_MECHLIST_LEN);
+
     char buf[300];
     g_assert_cmpint(mechlen, <=, sizeof(buf));
     read_all(sock, buf, mechlen);
+    STOP_AT(READ_MECHLIST);
 
     // mech name
     write_u32(sock, data->mechlen);
+    STOP_AT(WRITE_MECHNAME_LEN);
     write_all(sock, data->mechname, data->mechlen);
+    STOP_AT(WRITE_MECHNAME);
 
     // first challenge
-    if (write_u32_err(sock, 5) == sizeof(uint32_t)) {
-        do_readwrite_all(sock, "START", 5, true);
+    const char *out;
+    unsigned outlen;
+    get_step_out(&out, &outlen, "START", FLAG_CLIENT_NULL_START, FLAG_CLIENT_BIG_START);
+    if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) {
+        STOP_AT(WRITE_START_LEN);
+        if (out) {
+            if (do_readwrite_all(sock, out, outlen, true) != outlen) {
+                return;
+            }
+        }
+        STOP_AT(WRITE_START);
     }
 
+    uint32_t datalen;
+    if (read_u32_err(sock, &datalen) != sizeof(datalen)) {
+        return;
+    }
+    if (datalen == GUINT32_FROM_LE(SPICE_MAGIC)) {
+        return;
+    }
+    g_assert_cmpint(datalen, <=, sizeof(buf));
+    read_all(sock, buf, datalen);
+
+    get_step_out(&out, &outlen, "STEP", FLAG_CLIENT_NULL_STEP, FLAG_CLIENT_BIG_STEP);
+    if (write_u32_err(sock, out ? outlen : 0) == sizeof(uint32_t)) {
+        STOP_AT(WRITE_STEP_LEN);
+        if (out) {
+            if (do_readwrite_all(sock, out, outlen, true) != outlen) {
+                return;
+            }
+        }
+        STOP_AT(WRITE_STEP);
+    }
+}
+
+static void *
+client_emulator_thread(void *arg)
+{
+    int sock = GPOINTER_TO_INT(arg);
+
+    client_emulator(sock);
+
     shutdown(sock, SHUT_RDWR);
     close(sock);
 
@@ -434,8 +572,10 @@ setup_thread(void)
     if (data->mechlen < 0) {
         data->mechlen = strlen(data->mechname);
     }
+    test_flags = data->flags;
     int len = data->mechlen;
-    printf("\nRunning test %d ('%*.*s' %d)\n", test_num, len, len, data->mechname, len);
+    printf("\nRunning test %d ('%.*s' %d) line %d\n",
+           test_num, len, data->mechname, len, data->line);
 
     int sv[2];
     g_assert_cmpint(socketpair(AF_LOCAL, SOCK_STREAM, 0, sv), ==, 0);
@@ -443,7 +583,8 @@ setup_thread(void)
     g_assert(spice_server_add_client(server, sv[0], 0) == 0);
 
     pthread_t thread;
-    g_assert_cmpint(pthread_create(&thread, NULL, client_emulator, GINT_TO_POINTER(sv[1])), ==, 0);
+    g_assert_cmpint(pthread_create(&thread, NULL, client_emulator_thread,
+                                   GINT_TO_POINTER(sv[1])), ==, 0);
     return thread;
 }
 
commit 9881702df246846746f13814f62da2d93dca9884
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Wed Dec 13 18:19:43 2017 +0000

    test-sasl: Add tests for different mechanism names
    
    Try different connections with different tricky names.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
index 41bcf7e8..7c1181dc 100644
--- a/server/tests/test-sasl.c
+++ b/server/tests/test-sasl.c
@@ -22,6 +22,7 @@
 
 #include <unistd.h>
 #include <errno.h>
+#include <string.h>
 #include <stdbool.h>
 #include <spice.h>
 #include <sasl/sasl.h>
@@ -49,6 +50,8 @@ static bool encode_called;
 static SpiceCoreInterface *core;
 static SpiceServer *server;
 
+static unsigned int test_num;
+
 static gboolean idle_end_test(void *arg);
 
 static void
@@ -340,9 +343,41 @@ idle_add(GSourceFunc func, void *arg)
     g_source_unref(source);
 }
 
+typedef struct {
+    const char *mechname;
+    int mechlen;
+    bool success;
+} TestData;
+
+static char long_mechname[128];
+static TestData tests_data[] = {
+    // these should just succeed
+#define TEST_SUCCESS(mech) \
+    { mech, -1, true },
+    TEST_SUCCESS("ONE")
+    TEST_SUCCESS("TWO")
+    TEST_SUCCESS("THREE")
+
+    // these test bad mech names
+#define TEST_BAD_NAME(mech, len) \
+    { mech, len, false },
+    TEST_BAD_NAME("ON", -1)
+    TEST_BAD_NAME("NE", -1)
+    TEST_BAD_NAME("THRE", -1)
+    TEST_BAD_NAME("HREE", -1)
+    TEST_BAD_NAME("ON\x00", 3)
+    TEST_BAD_NAME("O\x00\x00", 3)
+    TEST_BAD_NAME("", -1)
+    TEST_BAD_NAME(long_mechname, 100)
+    TEST_BAD_NAME(long_mechname, 101)
+    TEST_BAD_NAME("ONE,TWO", -1)
+};
+
 static void *
 client_emulator(void *arg)
 {
+    const TestData *data = &tests_data[test_num];
+
     int sock = GPOINTER_TO_INT(arg);
 
     // send initial message
@@ -376,12 +411,13 @@ client_emulator(void *arg)
     read_all(sock, buf, mechlen);
 
     // mech name
-    write_u32(sock, 3);
-    write_all(sock, "ONE", 3);
+    write_u32(sock, data->mechlen);
+    write_all(sock, data->mechname, data->mechlen);
 
     // first challenge
-    write_u32(sock, 5);
-    write_all(sock, "START", 5);
+    if (write_u32_err(sock, 5) == sizeof(uint32_t)) {
+        do_readwrite_all(sock, "START", 5, true);
+    }
 
     shutdown(sock, SHUT_RDWR);
     close(sock);
@@ -394,6 +430,13 @@ client_emulator(void *arg)
 static pthread_t
 setup_thread(void)
 {
+    TestData *data = &tests_data[test_num];
+    if (data->mechlen < 0) {
+        data->mechlen = strlen(data->mechname);
+    }
+    int len = data->mechlen;
+    printf("\nRunning test %d ('%*.*s' %d)\n", test_num, len, len, data->mechname, len);
+
     int sv[2];
     g_assert_cmpint(socketpair(AF_LOCAL, SOCK_STREAM, 0, sv), ==, 0);
 
@@ -414,17 +457,34 @@ idle_end_test(void *arg)
 }
 
 static void
+check_test_results(void)
+{
+    const TestData *data = &tests_data[test_num];
+    if (data->success) {
+        g_assert(encode_called);
+        return;
+    }
+
+    g_assert(mechlist_called);
+    g_assert(!encode_called);
+}
+
+static void
 sasl_mechs(void)
 {
     start_test();
 
-    pthread_t thread = setup_thread();
-    alarm(4);
-    basic_event_loop_mainloop();
-    g_assert_cmpint(pthread_join(thread, NULL), ==, 0);
-    alarm(0);
-    g_assert(encode_called);
-    reset_test();
+    memset(long_mechname, 'X', sizeof(long_mechname));
+
+    for (test_num = 0; test_num < G_N_ELEMENTS(tests_data); test_num++) {
+        pthread_t thread = setup_thread();
+        alarm(4);
+        basic_event_loop_mainloop();
+        g_assert_cmpint(pthread_join(thread, NULL), ==, 0);
+        alarm(0);
+        check_test_results();
+        reset_test();
+    }
 
     end_tests();
 }
commit cbc082ba06fb8fb6666cb608d108988f7c8f8f2f
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    test-sasl: Base test, connect using SASL
    
    Create a thread that emulates a client and starts SASL authentication
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
index 85332974..41bcf7e8 100644
--- a/server/tests/test-sasl.c
+++ b/server/tests/test-sasl.c
@@ -21,19 +21,36 @@
 #include <config.h>
 
 #include <unistd.h>
-#include <spice.h>
+#include <errno.h>
 #include <stdbool.h>
+#include <spice.h>
 #include <sasl/sasl.h>
 
+#include <spice/protocol.h>
+#include <common/macros.h>
+
 #include "test-glib-compat.h"
 #include "basic-event-loop.h"
 
+/* First message the emulated client writes: the link header immediately
+ * followed by the link message and its capability words, with no padding
+ * (the packed attribute) so it can be written to the socket as-is. */
+#include <spice/start-packed.h>
+typedef struct SPICE_ATTR_PACKED SpiceInitialMessage {
+    SpiceLinkHeader hdr;
+    SpiceLinkMess mess;
+    uint32_t caps[2];
+} SpiceInitialMessage;
+#include <spice/end-packed.h>
+
 static char *mechlist;
 static bool mechlist_called;
 static bool start_called;
 static bool step_called;
 static bool encode_called;
 
+// Event loop and server created by start_test() and released by end_tests().
+static SpiceCoreInterface *core;
+static SpiceServer *server;
+
+// Declared early because client_emulator() schedules it before its definition.
+static gboolean idle_end_test(void *arg);
+
 static void
 check_sasl_conn(sasl_conn_t *conn)
 {
@@ -196,8 +213,226 @@ sasl_server_step(sasl_conn_t *conn,
     return SASL_OK;
 }
 
+/* Link message sent by the emulated client. The magic is left as 0 here and
+ * patched to SPICE_MAGIC by start_test().
+ * caps[0] is the common-capabilities bitmask. The SPICE_COMMON_CAP_* values
+ * are bit indices, not masks, so each one must be shifted into place.
+ * OR-ing the raw indices would advertise the wrong capabilities. */
+static SpiceInitialMessage initial_message = {
+    {
+        0, // SPICE_MAGIC,
+        GUINT32_TO_LE(SPICE_VERSION_MAJOR), GUINT32_TO_LE(SPICE_VERSION_MINOR),
+        GUINT32_TO_LE(sizeof(SpiceInitialMessage) - sizeof(SpiceLinkHeader))
+    },
+    {
+        0,
+        SPICE_CHANNEL_MAIN,
+        0,
+        GUINT32_TO_LE(1),
+        GUINT32_TO_LE(1),
+        GUINT32_TO_LE(sizeof(SpiceLinkMess))
+    },
+    {
+        GUINT32_TO_LE((1u << SPICE_COMMON_CAP_PROTOCOL_AUTH_SELECTION) |
+                      (1u << SPICE_COMMON_CAP_AUTH_SASL) |
+                      (1u << SPICE_COMMON_CAP_MINI_HEADER)),
+        0
+    }
+};
+
+/* Clear every call-tracking flag set by the SASL mocks. */
+static void
+reset_test(void)
+{
+    encode_called = false;
+    step_called = false;
+    start_called = false;
+    mechlist_called = false;
+}
+
+/* Prepare a test run: patch the magic into the client's initial message,
+ * clear the mock call flags and create a SASL-enabled server on a fresh
+ * basic event loop. Asserts that no server from a previous run is still
+ * alive (end_tests() resets it to NULL). */
+static void
+start_test(void)
+{
+    g_assert_null(server);
+
+    reset_test();
+    initial_message.hdr.magic = SPICE_MAGIC;
+
+    core = basic_event_loop_init();
+    g_assert_nonnull(core);
+
+    server = spice_server_new();
+    g_assert_nonnull(server);
+
+    spice_server_set_sasl(server, true);
+    const int init_ret = spice_server_init(server, core);
+    g_assert_cmpint(init_ret, ==, 0);
+}
+
+/* Release what start_test() created, plus the mechanism list string kept
+ * by the sasl_listmech() mock. */
+static void
+end_tests(void)
+{
+    spice_server_destroy(server);
+    server = NULL;
+
+    basic_event_loop_destroy();
+    core = NULL;
+
+    g_clear_pointer(&mechlist, g_free);
+}
+
+/* Transfer exactly len bytes to fd (do_write) or from fd (!do_write),
+ * retrying on EINTR and on short transfers.
+ * Returns the number of bytes transferred. This is less than len only if
+ * EOF is hit while reading. On a read/write error it returns the negative
+ * value that call produced (errno is left as set by it).
+ * The result is signed (ssize_t) so the error value is not wrapped into a
+ * huge size_t, and the per-call result is kept in ssize_t, not int. */
+static ssize_t
+do_readwrite_all(int fd, const void *buf, const size_t len, bool do_write)
+{
+    size_t byte_count = 0;
+    while (byte_count < len) {
+        ssize_t l;
+        if (do_write) {
+            l = write(fd, (const char *) buf + byte_count, len - byte_count);
+        } else {
+            l = read(fd, (char *) buf + byte_count, len - byte_count);
+            if (l == 0) {
+                return byte_count;
+            }
+        }
+        if (l < 0 && errno == EINTR) {
+            continue;
+        }
+        if (l < 0) {
+            return l;
+        }
+        byte_count += l;
+    }
+    return byte_count;
+}
+
+// Macros rather than functions so that a failing assertion reports the
+// caller's line number. Each one asserts that exactly len bytes were
+// transferred. Note that len is evaluated twice, so pass an expression
+// without side effects.
+#define read_all(fd, buf, len) \
+    g_assert_cmpint(do_readwrite_all(fd, buf, len, false), ==, len)
+
+#define write_all(fd, buf, len) \
+    g_assert_cmpint(do_readwrite_all(fd, buf, len, true), ==, len)
+
+/* Read a little-endian 32-bit value from fd and store it in *out in host
+ * byte order (*out is written even on a short read).
+ * Returns the byte count reported by do_readwrite_all(). */
+static ssize_t
+read_u32_err(int fd, uint32_t *out)
+{
+    uint32_t le_val = 0;
+    const ssize_t got = do_readwrite_all(fd, &le_val, sizeof(le_val), false);
+    *out = GUINT32_FROM_LE(le_val);
+    return got;
+}
+#define read_u32(fd, out) \
+    g_assert_cmpint(read_u32_err(fd, out), ==, sizeof(uint32_t))
+
+/* Write val to fd as a little-endian 32-bit value.
+ * Returns the byte count reported by do_readwrite_all(). */
+static ssize_t
+write_u32_err(int fd, uint32_t val)
+{
+    const uint32_t le_val = GUINT32_TO_LE(val);
+    return do_readwrite_all(fd, &le_val, sizeof(le_val), true);
+}
+
+#define write_u32(fd, val) \
+    g_assert_cmpint(write_u32_err(fd, val), ==, sizeof(uint32_t))
+
+/* This function is similar to g_idle_add but uses our internal GLib
+ * main context. g_idle_add uses the default main context, but to make
+ * sure we can use a different main context we don't use the default
+ * one (as QEMU does).
+ * func: callback run from the basic event loop's context when idle.
+ * arg:  user data passed to func (previously dropped and replaced by NULL). */
+static void
+idle_add(GSourceFunc func, void *arg)
+{
+    GSource *source = g_idle_source_new();
+    g_source_set_callback(source, func, arg, NULL);
+    g_source_attach(source, basic_event_loop_get_context());
+    // the context holds its own reference once attached
+    g_source_unref(source);
+}
+
+/* Thread body emulating a SPICE client that starts SASL authentication on
+ * the socket passed in arg (an int encoded with GINT_TO_POINTER).
+ * It sends the link message, consumes the server's link reply and
+ * capability words, and selects SASL. It then reads the mechanism list and
+ * sends a mechanism name ("ONE") plus first client data ("START").
+ * Finally it closes the socket and schedules idle_end_test() on the basic
+ * event loop's context. The I/O helpers assert, so any short or failed
+ * transfer aborts the test. */
+static void *
+client_emulator(void *arg)
+{
+    int sock = GPOINTER_TO_INT(arg);
+
+    // send the initial link message
+    write_all(sock, &initial_message, sizeof(initial_message));
+
+    // The server replies with a link header plus SpiceLinkReply (RSA key,
+    // etc.), followed by its capability words. Only the fixed-size part is
+    // parsed; the capabilities are read and discarded.
+    struct {
+        SpiceLinkHeader header;
+        SpiceLinkReply ack;
+    } msg;
+    // compile-time check that the struct has no padding, as it is read raw
+    SPICE_VERIFY(sizeof(msg) == sizeof(SpiceLinkHeader) + sizeof(SpiceLinkReply));
+    read_all(sock, &msg, sizeof(msg));
+    uint32_t num_caps = GUINT32_FROM_LE(msg.ack.num_common_caps) +
+                        GUINT32_FROM_LE(msg.ack.num_channel_caps);
+    while (num_caps-- > 0) {
+        uint32_t cap;
+        read_all(sock, &cap, sizeof(cap));
+    }
+
+    // the client has to send a SpiceLinkAuthMechanism (just a uint32
+    // holding SPICE_COMMON_CAP_AUTH_SASL)
+    write_u32(sock, SPICE_COMMON_CAP_AUTH_SASL);
+
+    // SASL finally starts and the server speaks first: the mechanism list,
+    // as a uint32 length followed by that many bytes (asserted to fit buf)
+    uint32_t mechlen;
+    read_u32(sock, &mechlen);
+    char buf[300];
+    g_assert_cmpint(mechlen, <=, sizeof(buf));
+    read_all(sock, buf, mechlen);
+
+    // chosen mechanism name, length-prefixed
+    write_u32(sock, 3);
+    write_all(sock, "ONE", 3);
+
+    // client data for the first authentication step, length-prefixed
+    write_u32(sock, 5);
+    write_all(sock, "START", 5);
+
+    shutdown(sock, SHUT_RDWR);
+    close(sock);
+
+    // make the main loop quit from its own thread
+    idle_add(idle_end_test, NULL);
+
+    return NULL;
+}
+
+/* Create a connected socket pair. One end goes to the server as a new
+ * client connection; client_emulator() runs on the other end in a new
+ * thread. Returns that thread; the caller joins it. */
+static pthread_t
+setup_thread(void)
+{
+    int sv[2];
+    g_assert_cmpint(socketpair(AF_LOCAL, SOCK_STREAM, 0, sv), ==, 0);
+
+    // Not wrapped in g_assert(): that macro expands to nothing under
+    // G_DISABLE_ASSERT and the client would never be added.
+    // g_assert_cmpint() is always compiled in.
+    g_assert_cmpint(spice_server_add_client(server, sv[0], 0), ==, 0);
+
+    pthread_t thread;
+    g_assert_cmpint(pthread_create(&thread, NULL, client_emulator, GINT_TO_POINTER(sv[1])), ==, 0);
+    return thread;
+}
+
+/* One-shot idle callback scheduled by the client thread when it is done:
+ * stops the basic event loop so the test driver can continue. arg is
+ * unused. */
+static gboolean
+idle_end_test(void *arg)
+{
+    basic_event_loop_quit();
+    return G_SOURCE_REMOVE; // do not run again
+}
+
+/* Drive one emulated client through the start of SASL authentication and
+ * check that the sasl_encode() mock was reached. */
+static void
+sasl_mechs(void)
+{
+    start_test();
+
+    pthread_t client_thread = setup_thread();
+
+    // 4-second SIGALRM watchdog around the loop and the join
+    alarm(4);
+    basic_event_loop_mainloop();
+    const int join_ret = pthread_join(client_thread, NULL);
+    g_assert_cmpint(join_ret, ==, 0);
+    alarm(0);
+
+    g_assert(encode_called);
+    reset_test();
+
+    end_tests();
+}
+
 int
 main(int argc, char *argv[])
 {
+    // failures are reported by the g_assert* checks inside the test
+    sasl_mechs();
+
     return 0;
 }
commit 9aa26056117601fe05212f45d7c663d30a930aa1
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    test-sasl: Add code to mocking functions to test state
    
    Check that some functions are called in a given sequence.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
index 0e4c633c..85332974 100644
--- a/server/tests/test-sasl.c
+++ b/server/tests/test-sasl.c
@@ -29,6 +29,10 @@
 #include "basic-event-loop.h"
 
 static char *mechlist;
+static bool mechlist_called;
+static bool start_called;
+static bool step_called;
+static bool encode_called;
 
 static void
 check_sasl_conn(sasl_conn_t *conn)
@@ -49,6 +53,8 @@ sasl_decode(sasl_conn_t *conn,
             const char **output, unsigned *outputlen)
 {
     check_sasl_conn(conn);
+    g_assert(start_called);
+
     return SASL_NOTDONE;
 }
 
@@ -58,6 +64,9 @@ sasl_encode(sasl_conn_t *conn,
             const char **output, unsigned *outputlen)
 {
     check_sasl_conn(conn);
+    g_assert(start_called);
+
+    encode_called = true;
     return SASL_NOTDONE;
 }
 
@@ -139,6 +148,10 @@ sasl_listmech(sasl_conn_t *conn,
     g_assert_nonnull(prefix);
     g_assert_nonnull(sep);
     g_assert_nonnull(suffix);
+    g_assert(!mechlist_called);
+    g_assert(!start_called);
+    g_assert(!step_called);
+    mechlist_called = true;
 
     g_free(mechlist);
     mechlist = g_strjoin("", prefix, "ONE", sep, "TWO", sep, "THREE", suffix, NULL);
@@ -156,6 +169,10 @@ sasl_server_start(sasl_conn_t *conn,
 {
     check_sasl_conn(conn);
     g_assert_nonnull(serverout);
+    g_assert(mechlist_called);
+    g_assert(!start_called);
+    g_assert(!step_called);
+    start_called = true;
 
     *serverout = "foo";
     *serveroutlen = 3;
@@ -171,6 +188,8 @@ sasl_server_step(sasl_conn_t *conn,
 {
     check_sasl_conn(conn);
     g_assert_nonnull(serverout);
+    g_assert(start_called);
+    step_called = true;
 
     *serverout = "foo";
     *serveroutlen = 3;
commit aeb8bbe5ac88b1f56f2718222904e322252b09d1
Author: Frediano Ziglio <fziglio at redhat.com>
Date:   Tue Dec 12 17:20:39 2017 +0000

    test-sasl: Initial SASL test
    
    Not currently working; it defines the SASL functions used by the code.
    As the symbols defined in the objects have higher priority than the ones
    defined by the libraries, these functions take precedence over the
    system library.
    
    Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
    Acked-by: Christophe Fergeau <cfergeau at redhat.com>

diff --git a/configure.ac b/configure.ac
index 62d1a020..9e4868bc 100644
--- a/configure.ac
+++ b/configure.ac
@@ -133,6 +133,7 @@ AX_VALGRIND_CHECK
 
 SPICE_CHECK_LZ4
 SPICE_CHECK_SASL
+AM_CONDITIONAL(HAVE_SASL, test "x$have_sasl" = "xyes")
 
 dnl =========================================================================
 dnl Check deps
diff --git a/server/tests/.gitignore b/server/tests/.gitignore
index ad3c532f..2890cfa5 100644
--- a/server/tests/.gitignore
+++ b/server/tests/.gitignore
@@ -26,5 +26,6 @@ test-two-servers
 test-vdagent
 test-gst
 test-leaks
+test-sasl
 /test-*.log
 /test-*.trs
diff --git a/server/tests/Makefile.am b/server/tests/Makefile.am
index 8e1e479e..43b58adb 100644
--- a/server/tests/Makefile.am
+++ b/server/tests/Makefile.am
@@ -141,3 +141,7 @@ test_gst_CPPFLAGS = \
 endif
 
 EXTRA_DIST += video-encoders
+
+if HAVE_SASL
+check_PROGRAMS += test-sasl
+endif
diff --git a/server/tests/test-sasl.c b/server/tests/test-sasl.c
new file mode 100644
index 00000000..0e4c633c
--- /dev/null
+++ b/server/tests/test-sasl.c
@@ -0,0 +1,184 @@
+/* -*- Mode: C; c-basic-offset: 4; indent-tabs-mode: nil -*- */
+/*
+   Copyright (C) 2017 Red Hat, Inc.
+
+   This library is free software; you can redistribute it and/or
+   modify it under the terms of the GNU Lesser General Public
+   License as published by the Free Software Foundation; either
+   version 2.1 of the License, or (at your option) any later version.
+
+   This library is distributed in the hope that it will be useful,
+   but WITHOUT ANY WARRANTY; without even the implied warranty of
+   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
+   Lesser General Public License for more details.
+
+   You should have received a copy of the GNU Lesser General Public
+   License along with this library; if not, see <http://www.gnu.org/licenses/>.
+*/
+/*
+ * Test SASL connections.
+ */
+#include <config.h>
+
+#include <unistd.h>
+#include <spice.h>
+#include <stdbool.h>
+#include <sasl/sasl.h>
+
+#include "test-glib-compat.h"
+#include "basic-event-loop.h"
+
+// String returned by the sasl_listmech() mock; owned here and freed when
+// the mock rebuilds it.
+static char *mechlist;
+
+/* Common sanity check used by the mocks: the connection handle must be
+ * non-NULL. */
+static void
+check_sasl_conn(sasl_conn_t *conn)
+{
+    g_assert_nonnull(conn);
+}
+
+/* Mock of sasl_server_init(): the server must not pass global callbacks;
+ * appname is ignored and the call always succeeds. */
+int
+sasl_server_init(const sasl_callback_t *callbacks, const char *appname)
+{
+    g_assert_null(callbacks);
+    return SASL_OK;
+}
+
+/* Mock of sasl_decode(): validates the handle and always returns
+ * SASL_NOTDONE without touching the output parameters. */
+int
+sasl_decode(sasl_conn_t *conn,
+            const char *input, unsigned inputlen,
+            const char **output, unsigned *outputlen)
+{
+    check_sasl_conn(conn);
+    return SASL_NOTDONE;
+}
+
+/* Mock of sasl_encode(): validates the handle and always returns
+ * SASL_NOTDONE without touching the output parameters. */
+int
+sasl_encode(sasl_conn_t *conn,
+            const char *input, unsigned inputlen,
+            const char **output, unsigned *outputlen)
+{
+    check_sasl_conn(conn);
+    return SASL_NOTDONE;
+}
+
+/* Mock of sasl_errdetail(): fixed placeholder text for a valid handle. */
+const char *
+sasl_errdetail(sasl_conn_t *conn)
+{
+    check_sasl_conn(conn);
+    return "XXX";
+}
+
+const char *
+sasl_errstring(int saslerr,
+               const char *langlist,
+               const char **outlang)
+{
+    return "YYY";
+}
+
+/* Mock of sasl_getprop(): only SASL_SSF is supported, reporting a security
+ * strength factor of 64.
+ * Any other property is rejected with SASL_BADPARAM and *pvalue set to
+ * NULL. Previously the mock returned SASL_OK and left *pvalue unset, so a
+ * caller would read an uninitialized pointer. */
+int
+sasl_getprop(sasl_conn_t *conn, int propnum,
+             const void **pvalue)
+{
+    check_sasl_conn(conn);
+    g_assert_nonnull(pvalue);
+
+    if (propnum != SASL_SSF) {
+        *pvalue = NULL;
+        return SASL_BADPARAM;
+    }
+
+    static const int val = 64;
+    *pvalue = &val;
+    return SASL_OK;
+}
+
+/* Mock of sasl_setprop(): accepts any property as long as a value is
+ * supplied; nothing is stored. */
+int
+sasl_setprop(sasl_conn_t *conn,
+             int propnum,
+             const void *value)
+{
+    check_sasl_conn(conn);
+    g_assert(value);
+    return SASL_OK;
+}
+
+/* Mock of sasl_server_new(): hands back the fake handle 0xdeadbeef. The
+ * mocks in this file only compare it against NULL. Per-connection
+ * callbacks must not be supplied. */
+int
+sasl_server_new(const char *service,
+                const char *serverFQDN,
+                const char *user_realm,
+                const char *iplocalport,
+                const char *ipremoteport,
+                const sasl_callback_t *callbacks,
+                unsigned flags,
+                sasl_conn_t **pconn)
+{
+    g_assert_nonnull(pconn);
+    g_assert_null(callbacks);
+
+    sasl_conn_t *const fake_conn = GUINT_TO_POINTER(0xdeadbeef);
+    *pconn = fake_conn;
+    return SASL_OK;
+}
+
+/* Mock of sasl_dispose(): checks the handle, then clears the caller's
+ * pointer. The real library does the same, so code guarding on a non-NULL
+ * connection behaves the same with the mock. */
+void
+sasl_dispose(sasl_conn_t **pconn)
+{
+    g_assert_nonnull(pconn);
+    check_sasl_conn(*pconn);
+    *pconn = NULL;
+}
+
+/* Mock of sasl_listmech(): advertises the mechanisms "ONE", "TWO" and
+ * "THREE" using the caller's prefix, separator and suffix. The returned
+ * string is kept in mechlist and replaced on the next call.
+ * As in the real API, plen (string length) and pcount (number of
+ * mechanisms) are optional outputs, filled in only when non-NULL. */
+int
+sasl_listmech(sasl_conn_t *conn,
+              const char *user,
+              const char *prefix,
+              const char *sep,
+              const char *suffix,
+              const char **result,
+              unsigned *plen,
+              int *pcount)
+{
+    check_sasl_conn(conn);
+    g_assert_nonnull(result);
+    g_assert_nonnull(prefix);
+    g_assert_nonnull(sep);
+    g_assert_nonnull(suffix);
+
+    g_free(mechlist);
+    mechlist = g_strjoin("", prefix, "ONE", sep, "TWO", sep, "THREE", suffix, NULL);
+    *result = mechlist;
+    if (plen != NULL) {
+        *plen = (unsigned) strlen(mechlist);
+    }
+    if (pcount != NULL) {
+        *pcount = 3;
+    }
+    return SASL_OK;
+}
+
+/* Mock of sasl_server_start(): accepts any mechanism and client data and
+ * always answers with the 3-byte payload "foo". */
+int
+sasl_server_start(sasl_conn_t *conn,
+                  const char *mech,
+                  const char *clientin,
+                  unsigned clientinlen,
+                  const char **serverout,
+                  unsigned *serveroutlen)
+{
+    static const char reply[] = "foo";
+
+    check_sasl_conn(conn);
+    g_assert_nonnull(serverout);
+
+    *serverout = reply;
+    *serveroutlen = sizeof(reply) - 1;
+    return SASL_OK;
+}
+
+/* Mock of sasl_server_step(): accepts any client data and always answers
+ * with the 3-byte payload "foo". */
+int
+sasl_server_step(sasl_conn_t *conn,
+                 const char *clientin,
+                 unsigned clientinlen,
+                 const char **serverout,
+                 unsigned *serveroutlen)
+{
+    static const char reply[] = "foo";
+
+    check_sasl_conn(conn);
+    g_assert_nonnull(serverout);
+
+    *serverout = reply;
+    *serveroutlen = sizeof(reply) - 1;
+    return SASL_OK;
+}
+
+/* Placeholder entry point: this version runs no test cases yet. */
+int
+main(int argc, char *argv[])
+{
+    return 0;
+}


More information about the Spice-commits mailing list