[Spice-devel] [PATCH spice-server v6 07/10] Handle SASL initialisation mainly in red-stream.c

Frediano Ziglio fziglio at redhat.com
Tue Jan 9 07:45:04 UTC 2018


Asynchronous code that jumps from one file to another is tedious to
read; also, having code that handles the same thing split across two
files does not look like a good design.

Signed-off-by: Frediano Ziglio <fziglio at redhat.com>
---
 server/red-stream.c | 121 +++++++++++++++++++++++++++++++++-------------------
 server/red-stream.h |  11 +----
 server/reds.c       | 119 ++++++---------------------------------------------
 3 files changed, 93 insertions(+), 158 deletions(-)

diff --git a/server/red-stream.c b/server/red-stream.c
index fccad8b2..0176f0a9 100644
--- a/server/red-stream.c
+++ b/server/red-stream.c
@@ -732,6 +732,36 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
     return 1;
 }
 
+typedef struct RedSASLAuth {
+    RedStream *stream; // stream being authenticated
+    // callback reporting the final SASL result, success or failure
+    RedSaslResult result_cb;
+    void *result_opaque; // opaque passed to result_cb and saved_error_cb
+    // stream's previous async read error handler; read errors during SASL
+    // are chained to it so it still receives result_opaque, not this struct
+    AsyncReadError saved_error_cb;
+} RedSASLAuth;
+
+// Finish SASL authentication, successful or not: restore the saved error
+// handler, report err through result_cb and free auth. auth is freed here,
+// so callers must return right after this call and never touch auth again.
+static void red_sasl_async_result(RedSASLAuth *auth, RedSaslError err)
+{
+    red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb); // undo red_sasl_start_auth override
+    auth->result_cb(auth->result_opaque, err);
+    g_free(auth);
+}
+
+static void red_sasl_error(void *opaque, int err)
+{
+    RedSASLAuth *auth = opaque; // installed as read error handler by red_sasl_start_auth
+    red_stream_set_async_error_handler(auth->stream, auth->saved_error_cb); // undo that override
+    if (auth->saved_error_cb) {
+        auth->saved_error_cb(auth->result_opaque, err); // forward with the caller's opaque
+    }
+    g_free(auth); // result_cb is not called on this path
+}
+
 /*
  * Step Msg
  *
@@ -749,8 +779,11 @@ static int auth_sasl_check_ssf(RedSASL *sasl, int *runSSF)
 #define SASL_MAX_MECHNAME_LEN 100
 #define SASL_DATA_MAX_LEN (1024 * 1024)
 
-RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_steplen(void *opaque);
+
+static void red_sasl_handle_auth_step(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
@@ -776,13 +809,13 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         err != SASL_CONTINUE) {
         spice_warning("sasl step failed %d (%s)",
                       err, sasl_errdetail(sasl->conn));
-        return RED_SASL_ERROR_GENERIC;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl step reply data too long %d",
                       serveroutlen);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -802,8 +835,8 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         spice_debug("%s", "Authentication must continue (step)");
         /* Wait for step length */
         red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              read_cb, opaque);
-        return RED_SASL_ERROR_CONTINUE;
+                              red_sasl_handle_auth_steplen, opaque);
+        return;
     } else {
         int ssf;
 
@@ -821,7 +854,7 @@ RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb,
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
 
-        return RED_SASL_ERROR_OK;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -829,31 +862,26 @@ authreject:
     red_stream_write_u32(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    return RED_SASL_ERROR_AUTH_FAILED;
+    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
 }
 
-RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_steplen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     spice_debug("Got steplen %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (sasl->len == 0) {
-        read_cb(opaque);
-        /* FIXME: can't report potential errors correctly here,
-         * but read_cb() will have done the needed RedLinkInfo cleanups
-         * if an error occurs, so the caller should not need to do more
-         * treatment */
-        return RED_SASL_ERROR_OK;
+        red_sasl_handle_auth_step(opaque);
     } else {
         sasl->data = g_realloc(sasl->data, sasl->len);
         red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                              read_cb, opaque);
-        return RED_SASL_ERROR_OK;
+                              red_sasl_handle_auth_step, opaque);
     }
 }
 
@@ -872,8 +900,9 @@ RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_
  * u8 continue
  */
 
-RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_start(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     const char *serverout;
     unsigned int serveroutlen;
     int err;
@@ -900,13 +929,13 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
         err != SASL_CONTINUE) {
         spice_warning("sasl start failed %d (%s)",
                     err, sasl_errdetail(sasl->conn));
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     if (serveroutlen > SASL_DATA_MAX_LEN) {
         spice_warning("sasl start reply data too long %d",
-                    serveroutlen);
-        return RED_SASL_ERROR_INVALID_DATA;
+                      serveroutlen);
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     spice_debug("SASL return data %d bytes, %p", serveroutlen, serverout);
@@ -926,8 +955,8 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
         spice_debug("%s", "Authentication must continue (start)");
         /* Wait for step length */
         red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                              read_cb, opaque);
-        return RED_SASL_ERROR_CONTINUE;
+                              red_sasl_handle_auth_steplen, opaque);
+        return;
     } else {
         int ssf;
 
@@ -944,7 +973,8 @@ RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb
          */
         sasl->runSSF = ssf;
         red_stream_disable_writev(stream); /* make sure writev isn't called directly anymore */
-        return RED_SASL_ERROR_OK;
+
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_OK);
     }
 
 authreject:
@@ -952,32 +982,32 @@ authreject:
     red_stream_write_u32(stream, sizeof("Authentication failed"));
     red_stream_write_all(stream, "Authentication failed", sizeof("Authentication failed"));
 
-    return RED_SASL_ERROR_AUTH_FAILED;
+    red_sasl_async_result(opaque, RED_SASL_ERROR_AUTH_FAILED);
 }
 
-RedSaslError red_sasl_handle_auth_startlen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_startlen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     spice_debug("Got client start len %d", sasl->len);
     if (sasl->len > SASL_DATA_MAX_LEN) {
         spice_warning("Too much SASL data %d", sasl->len);
-        return RED_SASL_ERROR_INVALID_DATA;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
     }
 
     if (sasl->len == 0) {
-        return RED_SASL_ERROR_RETRY;
+        return red_sasl_handle_auth_start(opaque);
     }
 
     sasl->data = g_realloc(sasl->data, sasl->len);
     red_stream_async_read(stream, (uint8_t *)sasl->data, sasl->len,
-                          read_cb, opaque);
-
-    return RED_SASL_ERROR_OK;
+                          red_sasl_handle_auth_start, opaque);
 }
 
-bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_mechname(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     sasl->mechname[sasl->len] = '\0';
@@ -988,36 +1018,33 @@ bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, voi
     sprintf(quoted_mechname, ",%s,", sasl->mechname);
 
     if (strchr(sasl->mechname, ',') || strstr(sasl->mechlist, quoted_mechname) == NULL) {
-        return false;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_INVALID_DATA);
     }
 
     spice_debug("Validated mechname '%s'", sasl->mechname);
 
     red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          read_cb, opaque);
-
-    return true;
+                          red_sasl_handle_auth_startlen, opaque);
 }
 
-bool red_sasl_handle_auth_mechlen(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+static void red_sasl_handle_auth_mechlen(void *opaque)
 {
+    RedStream *stream = ((RedSASLAuth *)opaque)->stream;
     RedSASL *sasl = &stream->priv->sasl;
 
     if (sasl->len < 1 || sasl->len > SASL_MAX_MECHNAME_LEN) {
         spice_warning("Got bad client mechname len %d", sasl->len);
-        return false;
+        return red_sasl_async_result(opaque, RED_SASL_ERROR_GENERIC);
     }
 
     sasl->mechname = g_malloc(sasl->len + 1);
 
     spice_debug("Wait for client mechname");
     red_stream_async_read(stream, (uint8_t *)sasl->mechname, sasl->len,
-                          read_cb, opaque);
-
-    return true;
+                          red_sasl_handle_auth_mechname, opaque);
 }
 
-bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
+bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *result_opaque)
 {
     const char *mechlist = NULL;
     sasl_security_properties_t secprops;
@@ -1025,6 +1052,7 @@ bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
     char *localAddr, *remoteAddr;
     int mechlistlen;
     RedSASL *sasl = &stream->priv->sasl;
+    RedSASLAuth *auth;
 
     if (!(localAddr = red_stream_get_local_address(stream))) {
         goto error;
@@ -1119,9 +1147,16 @@ bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque)
         goto error;
     }
 
+    auth = g_new0(RedSASLAuth, 1);
+    auth->stream = stream;
+    auth->result_cb = result_cb;
+    auth->result_opaque = result_opaque;
+    auth->saved_error_cb = stream->priv->async_read.error;
+    red_stream_set_async_error_handler(stream, red_sasl_error);
+
     spice_debug("Wait for client mechname length");
     red_stream_async_read(stream, (uint8_t *)&sasl->len, sizeof(uint32_t),
-                          read_cb, opaque);
+                          red_sasl_handle_auth_mechlen, auth);
 
     return true;
 
diff --git a/server/red-stream.h b/server/red-stream.h
index a8d855c2..4d5075ed 100644
--- a/server/red-stream.h
+++ b/server/red-stream.h
@@ -73,17 +73,10 @@ typedef enum {
     RED_SASL_ERROR_OK,
     RED_SASL_ERROR_GENERIC,
     RED_SASL_ERROR_INVALID_DATA,
-    RED_SASL_ERROR_RETRY,
-    RED_SASL_ERROR_CONTINUE,
     RED_SASL_ERROR_AUTH_FAILED
 } RedSaslError;
 
-RedSaslError red_sasl_handle_auth_step(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_steplen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_start(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-RedSaslError red_sasl_handle_auth_startlen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_handle_auth_mechname(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_handle_auth_mechlen(RedStream *stream, AsyncReadDone read_cb, void *opaque);
-bool red_sasl_start_auth(RedStream *stream, AsyncReadDone read_cb, void *opaque);
+typedef void (*RedSaslResult)(void *opaque, RedSaslError err);
+bool red_sasl_start_auth(RedStream *stream, RedSaslResult result_cb, void *opaque);
 
 #endif /* RED_STREAM_H_ */
diff --git a/server/reds.c b/server/reds.c
index 9338b78b..777f5ba4 100644
--- a/server/reds.c
+++ b/server/reds.c
@@ -2100,123 +2100,30 @@ static void reds_get_spice_ticket(RedLinkInfo *link)
 }
 
 #if HAVE_SASL
-/*
- * Step Msg
- *
- * Input from client:
- *
- * u32 clientin-length
- * u8-array clientin-string
- *
- * Output to client:
- *
- * u32 serverout-length
- * u8-array serverout-strin
- * u8 continue
- */
-
-static void reds_handle_auth_sasl_steplen(void *opaque);
-
-static void reds_handle_auth_sasl_step(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_step(link->stream, reds_handle_auth_sasl_steplen, link);
-    if (status == RED_SASL_ERROR_OK) {
-        reds_handle_link(link);
-    } else if (status != RED_SASL_ERROR_CONTINUE) {
-        reds_link_free(link);
-    }
-}
-
-static void reds_handle_auth_sasl_steplen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_steplen(link->stream, reds_handle_auth_sasl_step, link);
-    if (status != RED_SASL_ERROR_OK) {
-        reds_link_free(link);
-    }
-}
-
-/*
- * Start Msg
- *
- * Input from client:
- *
- * u32 clientin-length
- * u8-array clientin-string
- *
- * Output to client:
- *
- * u32 serverout-length
- * u8-array serverout-strin
- * u8 continue
- */
-
-
-static void reds_handle_auth_sasl_start(void *opaque)
+static void reds_handle_sasl_result(void *opaque, RedSaslError status)
 {
     RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_start(link->stream, reds_handle_auth_sasl_steplen, link);
-    if (status == RED_SASL_ERROR_OK) {
-        reds_handle_link(link);
-    } else if (status != RED_SASL_ERROR_CONTINUE) {
-        reds_link_free(link);
-    }
-}
 
-static void reds_handle_auth_startlen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-    RedSaslError status;
-
-    status = red_sasl_handle_auth_startlen(link->stream, reds_handle_auth_sasl_start, link);
     switch (status) {
-        case RED_SASL_ERROR_OK:
-            break;
-        case RED_SASL_ERROR_RETRY:
-            reds_handle_auth_sasl_start(opaque);
-            break;
-        case RED_SASL_ERROR_GENERIC:
-        case RED_SASL_ERROR_INVALID_DATA:
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
-            reds_link_free(link);
-            break;
-        default:
-            g_warn_if_reached();
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
-            reds_link_free(link);
-            break;
-    }
-}
-
-static void reds_handle_auth_mechname(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-
-    if (!red_sasl_handle_auth_mechname(link->stream, reds_handle_auth_startlen, link)) {
-            reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
+    case RED_SASL_ERROR_OK:
+        reds_handle_link(link);
+        break;
+    case RED_SASL_ERROR_INVALID_DATA:
+        reds_send_link_error(link, SPICE_LINK_ERR_INVALID_DATA);
         reds_link_free(link);
-    }
-}
-
-static void reds_handle_auth_mechlen(void *opaque)
-{
-    RedLinkInfo *link = (RedLinkInfo *)opaque;
-
-    if (!red_sasl_handle_auth_mechlen(link->stream, reds_handle_auth_mechname, link)) {
+        break;
+    default:
+        // in these cases error was reported using SASL protocol
+        // (RED_SASL_ERROR_AUTH_FAILED) or we just need to close the
+        // connection
         reds_link_free(link);
+        break;
     }
 }
 
 static void reds_start_auth_sasl(RedLinkInfo *link)
 {
-    if (!red_sasl_start_auth(link->stream, reds_handle_auth_mechlen, link)) {
+    if (!red_sasl_start_auth(link->stream, reds_handle_sasl_result, link)) {
         reds_link_free(link);
     }
 }
-- 
2.14.3



More information about the Spice-devel mailing list