[Spice-devel] [PATCH spice-server 16/16] red-stream: Encapsulate all authentication state in RedSASLAuth

Frediano Ziglio fziglio at redhat.com
Thu Dec 14 10:07:48 UTC 2017


Instead of keeping half of the state in RedSASL and half in RedSASLAuth,
move everything into RedSASLAuth. This also reduces memory usage when
SASL is in use, as this state is freed once authentication finishes.

Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
---
 server/red-stream.c | 135 +++++++++++++++++++++++++++-------------------------
 1 file changed, 70 insertions(+), 65 deletions(-)

diff --git a/server/red-stream.c b/server/red-stream.c
index e498c4ff8..cc449036b 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -67,13 +67,6 @@ typedef struct RedSASL {
     unsigned int encodedOffset;
 
     SpiceBuffer inbuffer;
-
-    const char *mechlist;
-    char *mechname;
-
-    /* temporary data during authentication */
-    unsigned int len;
-    char *data;
 } RedSASL;
 #endif
 
@@ -351,11 +344,8 @@ void red_stream_free(RedStream *s)
 #if HAVE_SASL
     if (s->priv->sasl.conn) {
         s->priv->sasl.runSSF = s->priv->sasl.wantSSF = 0;
-        s->priv->sasl.len = 0;
         s->priv->sasl.encodedLength = s->priv->sasl.encodedOffset = 0;
         s->priv->sasl.encoded = NULL;
-        g_free(s->priv->sasl.mechname);
-        g_free(s->priv->sasl.data);
         sasl_dispose(&s->priv->sasl.conn);
         s->priv->sasl.conn = NULL;
     }
@@ -733,17 +723,35 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
 
 typedef struct RedSASLAuth {
     RedStream *stream;
+    // list of allowed mechanisms, allocated and freed by the SASL library (not owned here)
+    const char *mechlist;
+    // mechanism name received from the client; g_malloc'd, freed after use or on deinit
+    char *mechname;
+    uint32_t len; // length prefix read from the client (little-endian on the wire)
+    char *data; // client data for the current step; g_realloc'd, freed on deinit
+    // callback invoked by red_sasl_async_result with the final result (success or failure)
     RedSaslResult result_cb;
     void *result_opaque;
+    // async-read error callback installed before SASL started: restored when authentication
+    // ends, and chained from red_sasl_error so it receives result_opaque, not this struct
     AsyncReadError saved_error;
 } RedSASLAuth;
 
+static void red_sasl_async_deinit(RedSASLAuth *auth) // release step buffers, restore error cb
+{
+    g_free(auth->data);
+    auth->data = NULL;
+    g_free(auth->mechname);
+    auth->mechname = NULL;
+    auth->stream->priv->async_read.error = auth->saved_error; // undo SASL's handler swap
+}
+
 // handle SASL termination, either success or error
 // NOTE: After this function is called usually there should be a
 // return or the function should exit
 static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err)
 {
-    auth->stream->priv->async_read.error = auth->saved_error;
+    red_sasl_async_deinit(auth); // frees mechname/data and restores the saved error handler
     auth->result_cb(auth->result_opaque, err);
     g_free(auth);
 }
@@ -751,7 +759,7 @@ static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err)
 static void red_sasl_error(void *opaque, int err)
 {
     RedSASLAuth *auth = opaque;
-    auth->stream->priv->async_read.error = auth->saved_error;
+    red_sasl_async_deinit(auth);
     if (auth->saved_error) {
         auth->saved_error(auth->result_opaque, err);
     }
@@ -794,32 +802,33 @@ static void red_sasl_handle_auth_steplen(void *opaque);
 
 static void red_sasl_handle_auth_step(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
+    RedSASLAuth *auth = opaque;
+    RedStream *stream = auth->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
     char *clientdata = NULL;
     RedSASL *sasl = &stream->priv->sasl;
-    uint32_t datalen = sasl->len;
+    uint32_t datalen = auth->len;
 
     /* NB, distinction of NULL vs "" is *critical* in SASL */
     if (datalen) {
-        clientdata = sasl->data;
+        clientdata = auth->data;
         clientdata[datalen - 1] = '\0'; /* Wire includes '\0', but make sure */
         datalen--; /* Don't count NULL byte when passing to _start() */
     }
 
-    if (sasl->mechname != NULL) {
+    if (auth->mechname != NULL) {
         spice_debug("Start SASL auth with mechanism %s. Data %p (%d bytes)",
-                   sasl->mechname, clientdata, datalen);
+                    auth->mechname, clientdata, datalen);
         err = sasl_server_start(sasl->conn,
-                                sasl->mechname,
+                                auth->mechname,
                                 clientdata,
                                 datalen,
                                 &serverout,
                                 &serveroutlen);
-        g_free(sasl->mechname);
-        sasl->mechname = NULL;
+        g_free(auth->mechname);
+        auth->mechname = NULL;
     } else {
         spice_debug("Step using SASL Data %p (%d bytes)", clientdata, datalen);
         err = sasl_server_step(sasl->conn,
@@ -832,13 +841,13 @@ static void red_sasl_handle_auth_step(void *opaque)
         err != SASL_CONTINUE) {
         spice_warning("sasl step failed %d (%s)",
                     err, sasl_errdetail(sasl->conn));
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl step reply data too long %d",
                       serveroutlen);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -857,8 +866,8 @@ static void red_sasl_handle_auth_step(void *opaque)
     if (err == SASL_CONTINUE) {
         spice_debug("%s", "Authentication must continue");
         /* Wait for step length */
-        red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              red_sasl_handle_auth_steplen, opaque);
+        red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t),
+                              red_sasl_handle_auth_steplen, auth);
         return;
     } else {
         int ssf;
@@ -877,7 +886,7 @@ static void red_sasl_handle_auth_step(void *opaque)
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
 
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -885,73 +894,71 @@ authreject:
     red_stream_write_u32(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
+    red_sasl_async_result(auth, RED_SASL_ERROR_AUTH_FAILED);
 }
 
 static void red_sasl_handle_auth_steplen(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->len = GUINT32_FROM_LE(sasl->len);
-    spice_debug("Got steplen %d", sasl->len);
-    if (sasl->len > SASL_DATA_MAX_LEN) {
-        spice_warning("Too much SASL data %d", sasl->len);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+    auth->len = GUINT32_FROM_LE(auth->len); // wire is LE; the step handler reads auth->len
+    uint32_t len = auth->len;
+    spice_debug("Got steplen %u", len);
+    if (len > SASL_DATA_MAX_LEN) {
+        spice_warning("Too much SASL data %u", len);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA);
     }
 
-    if (sasl->len == 0) {
-        return red_sasl_handle_auth_step(opaque);
+    if (len == 0) { // nothing to read: step with NULL client data
+        return red_sasl_handle_auth_step(auth);
     }
 
-    sasl->data = g_realloc(sasl->data, sasl->len);
-    red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                          red_sasl_handle_auth_step, opaque);
+    auth->data = g_realloc(auth->data, len);
+    red_stream_async_read(auth->stream, (uint8_t *)auth->data, len,
+                          red_sasl_handle_auth_step, auth);
 }
 
 
 
 static void red_sasl_handle_auth_mechname(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->mechname[sasl->len] = '\0';
+    auth->mechname[auth->len] = '\0'; // mechname was allocated with len + 1 bytes
     spice_debug("Got client mechname '%s' check against '%s'",
-               sasl->mechname, sasl->mechlist);
+                auth->mechname, auth->mechlist);
 
     char quoted_mechname[SASL_MAX_MECHNAME_LEN + 4];
-    sprintf(quoted_mechname, ",%s,", sasl->mechname);
+    snprintf(quoted_mechname, sizeof(quoted_mechname), ",%s,", auth->mechname);
 
-    if (strchr(sasl->mechname, ',') || !strstr(sasl->mechlist, quoted_mechname)) {
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+    if (strchr(auth->mechname, ',') || !strstr(auth->mechlist, quoted_mechname)) {
+        return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA);
     }
 
-    spice_debug("Validated mechname '%s'", sasl->mechname);
+    spice_debug("Validated mechname '%s'", auth->mechname);
 
-    red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          red_sasl_handle_auth_steplen, opaque);
+    red_stream_async_read(auth->stream, (uint8_t *)&auth->len, sizeof(uint32_t),
+                          red_sasl_handle_auth_steplen, auth);
 }
 
 static void red_sasl_handle_auth_mechlen(void *opaque)
 {
-    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
-    RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth = opaque;
 
-    sasl->len = GUINT32_FROM_LE(sasl->len);
-    if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) {
-        spice_warning("Got bad client mechname len %d", sasl->len);
-        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
+    uint32_t len = auth->len = GUINT32_FROM_LE(auth->len); // host order; mechname uses auth->len
+    if (len < 1 || len > SASL_MAX_MECHNAME_LEN) {
+        spice_warning("Got bad client mechname len %u", len);
+        return red_sasl_async_result(auth, RED_SASL_ERROR_INVALID_DATA);
     }
 
-    sasl->mechname = g_malloc(sasl->len + 1);
+    auth->mechname = g_malloc(len + 1); // +1 for the '\0' added by the mechname handler
 
     spice_debug("Wait for client mechname");
-    red_stream_async_read(stream, (uint8_t *)sasl->mechname, sasl->len,
-                          red_sasl_handle_auth_mechname, opaque);
+    red_stream_async_read(auth->stream, (uint8_t *)auth->mechname, len,
+                          red_sasl_handle_auth_mechname, auth);
 }
 
-bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaque)
+bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque)
 {
     const char *mechlist = NULL;
     sasl_security_properties_t secprops;
@@ -1045,11 +1052,9 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaqu
 
     spice_debug("Available mechanisms for client: '%s'", mechlist);
 
-    sasl->mechlist = mechlist;
-
     mechlistlen = strlen(mechlist);
     if (!red_stream_write_u32(stream, mechlistlen)
-        || !red_stream_write_all(stream, sasl->mechlist, mechlistlen)) {
+        || !red_stream_write_all(stream, mechlist, mechlistlen)) {
         spice_warning("SASL mechanisms write error");
         goto error;
     }
@@ -1057,18 +1062,18 @@ bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaqu
     auth = g_new0(RedSASLAuth, 1);
     auth->stream = stream;
     auth->result_cb = result_cb;
-    auth->result_opaque = opaque;
+    auth->result_opaque = result_opaque;
     auth->saved_error = stream->priv->async_read.error;
-    stream->priv->async_read.error = red_sasl_error;
+    auth->mechlist = mechlist;
 
     spice_debug("Wait for client mechname length");
-    red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
+    red_stream_set_async_error_handler(stream, red_sasl_error);
+    red_stream_async_read(stream, (uint8_t *)&auth->len, sizeof(uint32_t),
                           red_sasl_handle_auth_mechlen, auth);
 
     return true;
 
 error_dispose:
-    sasl->mechlist = NULL;
     sasl_dispose(&sasl->conn);
     sasl->conn = NULL;
 error:
-- 
2.14.3



More information about the Spice-devel mailing list