[Date Prev][Date Next][Thread Prev][Thread Next][Date Index][Thread Index]

[PATCH 07/20] chacha: packet decryption


From: Aris Adamantiadis <aris@xxxxxxxxxxxx>

---
 include/libssh/crypto.h |  4 ++
 include/libssh/packet.h |  5 ++-
 src/chachapoly.c        | 52 ++++++++++++++++++++++++++
 src/kex.c               |  2 +-
 src/packet.c            | 99 +++++++++++++++++++++----------------------------
 src/packet_crypt.c      | 91 ++++++++++++++++++++++++++++++---------------
 src/wrapper.c           | 20 ++++++++--
 7 files changed, 181 insertions(+), 92 deletions(-)

diff --git a/include/libssh/crypto.h b/include/libssh/crypto.h
index 83391b69..243c4801 100644
--- a/include/libssh/crypto.h
+++ b/include/libssh/crypto.h
@@ -156,6 +156,10 @@ struct ssh_cipher_struct {
         unsigned long len);
     void (*aead_encrypt)(struct ssh_cipher_struct *cipher, void *in, void *out,
         size_t len, uint8_t *mac, uint64_t seq);
+    int (*aead_decrypt_length)(struct ssh_cipher_struct *cipher, void *in,
+        uint8_t *out, size_t len, uint64_t seq);
+    int (*aead_decrypt)(struct ssh_cipher_struct *cipher, void *complete_packet, uint8_t *out,
+        size_t encrypted_size, uint64_t seq);
     void (*cleanup)(struct ssh_cipher_struct *cipher);
 };
 
diff --git a/include/libssh/packet.h b/include/libssh/packet.h
index 3a84eb70..b10308f7 100644
--- a/include/libssh/packet.h
+++ b/include/libssh/packet.h
@@ -78,8 +78,9 @@ void ssh_packet_set_default_callbacks(ssh_session session);
 void ssh_packet_process(ssh_session session, uint8_t type);
 
 /* PACKET CRYPT */
-uint32_t ssh_packet_decrypt_len(ssh_session session, char *crypted);
-int ssh_packet_decrypt(ssh_session session, void *packet, unsigned int len);
+uint32_t ssh_packet_decrypt_len(ssh_session session, uint8_t *destination, uint8_t *source);
+int ssh_packet_decrypt(ssh_session session, uint8_t *destination, uint8_t *source,
+        size_t start, size_t encrypted_size);
 unsigned char *ssh_packet_encrypt(ssh_session session,
                                   void *packet,
                                   unsigned int len);
diff --git a/src/chachapoly.c b/src/chachapoly.c
index eeb12da8..1d921f1e 100644
--- a/src/chachapoly.c
+++ b/src/chachapoly.c
@@ -89,8 +89,58 @@ static void chacha20_poly1305_aead_encrypt(struct ssh_cipher_struct *cipher,
     chacha_encrypt_bytes(&keys->k2, in_packet->payload, out_packet->payload,
                 len - sizeof(uint32_t));
 
+    /* ssh_print_hexa("poly1305_ctx", poly1305_ctx, sizeof(poly1305_ctx)); */
     /* step 4, compute the MAC */
     poly1305_auth(tag, (uint8_t *)out_packet, len, poly1305_ctx);
+    /* ssh_print_hexa("poly1305 src", (uint8_t *)out_packet, len);
+    ssh_print_hexa("poly1305 tag", tag, POLY1305_TAGLEN); */
+}
+
+static int chacha20_poly1305_aead_decrypt_length(struct ssh_cipher_struct *cipher,
+        void *in, uint8_t *out, size_t len, uint64_t seq){
+    struct chacha20_poly1305_keysched *keys = cipher->chacha20_schedule;
+
+    if (len < sizeof(uint32_t)){
+        return SSH_ERROR;
+    }
+    seq = htonll(seq);
+
+    chacha_ivsetup(&keys->k1, (uint8_t *)&seq, zero_block_counter);
+    chacha_encrypt_bytes(&keys->k1, in,
+            (uint8_t *)out, sizeof(uint32_t));
+    return SSH_OK;
+}
+
+static int chacha20_poly1305_aead_decrypt(struct ssh_cipher_struct *cipher,
+        void *complete_packet, uint8_t *out, size_t encrypted_size,
+        uint64_t seq){
+    uint8_t poly1305_ctx[POLY1305_KEYLEN];
+    uint8_t tag[POLY1305_TAGLEN];
+    struct chacha20_poly1305_keysched *keys = cipher->chacha20_schedule;
+    uint8_t *mac = (uint8_t *)complete_packet + sizeof(uint32_t) + encrypted_size;
+
+    seq = htonll(seq);
+
+    ZERO_STRUCT(poly1305_ctx);
+    chacha_ivsetup(&keys->k2, (uint8_t *)&seq, zero_block_counter);
+    chacha_encrypt_bytes(&keys->k2, poly1305_ctx, poly1305_ctx, POLY1305_KEYLEN);
+    /* ssh_print_hexa("poly1305_ctx", poly1305_ctx, sizeof(poly1305_ctx)); */
+
+    poly1305_auth(tag, (uint8_t *)complete_packet, encrypted_size +
+            sizeof(uint32_t), poly1305_ctx);
+    /* ssh_print_hexa("poly1305 src", (uint8_t*)complete_packet, encrypted_size + 4);
+    ssh_print_hexa("poly1305 tag", tag, POLY1305_TAGLEN);
+    ssh_print_hexa("received tag", mac, POLY1305_TAGLEN); */
+
+    if(memcmp(tag, mac, POLY1305_TAGLEN) != 0){
+        /* mac error */
+        SSH_LOG(SSH_LOG_PACKET,"poly1305 verify error");
+        return SSH_ERROR;
+    }
+    chacha_ivsetup(&keys->k2, (uint8_t *)&seq, payload_block_counter);
+    chacha_encrypt_bytes(&keys->k2, (uint8_t *)complete_packet +
+            sizeof(uint32_t), out, encrypted_size);
+    return SSH_OK;
 }
 
 const struct ssh_cipher_struct chacha20poly1305_cipher = {
@@ -103,4 +153,6 @@ const struct ssh_cipher_struct chacha20poly1305_cipher = {
     .set_encrypt_key = chacha20_set_encrypt_key,
     .set_decrypt_key = chacha20_set_encrypt_key,
     .aead_encrypt = chacha20_poly1305_aead_encrypt,
+    .aead_decrypt_length = chacha20_poly1305_aead_decrypt_length,
+    .aead_decrypt = chacha20_poly1305_aead_decrypt
 };
diff --git a/src/kex.c b/src/kex.c
index 20044c32..00f4e00f 100644
--- a/src/kex.c
+++ b/src/kex.c
@@ -120,7 +120,7 @@ static const char *supported_methods[] = {
   KEY_EXCHANGE,
   HOSTKEYS,
   CHACHA20 AES BLOWFISH DES_SUPPORTED,
-  AES BLOWFISH DES_SUPPORTED,
+  CHACHA20 AES BLOWFISH DES_SUPPORTED,
   "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
   "hmac-sha2-256,hmac-sha2-512,hmac-sha1",
   ZLIB,
diff --git a/src/packet.c b/src/packet.c
index e989fa28..216b86b2 100644
--- a/src/packet.c
+++ b/src/packet.c
@@ -144,20 +144,26 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
     ssh_session session= (ssh_session) user;
     unsigned int blocksize = (session->current_crypto ?
                               session->current_crypto->in_cipher->blocksize : 8);
-    unsigned char mac[DIGEST_MAX_LEN] = {0};
-    char buffer[16] = {0};
+    unsigned int lenfield_blocksize = (session->current_crypto ?
+                                  session->current_crypto->in_cipher->lenfield_blocksize : 8);
     size_t current_macsize = 0;
-    const uint8_t *packet;
+    uint8_t *ptr;
     int to_be_read;
     int rc;
-    uint32_t len, compsize, payloadsize;
+    uint8_t *cleartext_packet;
+    uint8_t *packet_second_block;
+    uint8_t *mac;
+    size_t packet_remaining;
+    uint32_t packet_len, compsize, payloadsize;
     uint8_t padding;
     size_t processed = 0; /* number of byte processed from the callback */
 
     if(session->current_crypto != NULL) {
       current_macsize = hmac_digest_len(session->current_crypto->in_hmac);
     }
-
+    if (lenfield_blocksize == 0){
+        lenfield_blocksize = blocksize;
+    }
     if (data == NULL) {
         goto error;
     }
@@ -173,13 +179,13 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
 #endif
     switch(session->packet_state) {
         case PACKET_STATE_INIT:
-            if (receivedlen < blocksize) {
+            if (receivedlen < lenfield_blocksize) {
                 /*
                  * We didn't receive enough data to read at least one
                  * block size, give up
                  */
 #ifdef DEBUG_PACKET
-                SSH_LOG(SSH_LOG_PACKET, "Waiting for more data (%u < %u)", (unsigned)receivedlen, (unsigned)blocksize);
+                SSH_LOG(SSH_LOG_PACKET, "Waiting for more data (%u < %u)", (unsigned)receivedlen, (unsigned)lenfield_blocksize);
 #endif
                 return 0;
             }
@@ -198,24 +204,21 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                 }
             }
 
-            memcpy(buffer, data, blocksize);
-            processed += blocksize;
-            len = ssh_packet_decrypt_len(session, buffer);
-
-            rc = ssh_buffer_add_data(session->in_buffer, buffer, blocksize);
-            if (rc < 0) {
+            ptr = ssh_buffer_allocate(session->in_buffer, lenfield_blocksize);
+            if (ptr == NULL) {
                 goto error;
             }
+            processed += lenfield_blocksize;
+            packet_len = ssh_packet_decrypt_len(session, ptr, (uint8_t *)data);
 
-            if (len > MAX_PACKET_LEN) {
+            if (packet_len > MAX_PACKET_LEN) {
                 ssh_set_error(session,
                               SSH_FATAL,
                               "read_packet(): Packet len too high(%u %.4x)",
-                              len, len);
+                              packet_len, packet_len);
                 goto error;
             }
-
-            to_be_read = len - blocksize + sizeof(uint32_t);
+            to_be_read = packet_len - lenfield_blocksize + sizeof(uint32_t);
             if (to_be_read < 0) {
                 /* remote sshd sends invalid sizes? */
                 ssh_set_error(session,
@@ -225,59 +228,41 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                 goto error;
             }
 
-            /* Saves the status of the current operations */
-            session->in_packet.len = len;
+            session->in_packet.len = packet_len;
             session->packet_state = PACKET_STATE_SIZEREAD;
             FALL_THROUGH;
         case PACKET_STATE_SIZEREAD:
-            len = session->in_packet.len;
-            to_be_read = len - blocksize + sizeof(uint32_t) + current_macsize;
+            packet_len = session->in_packet.len;
+            processed = lenfield_blocksize;
+            to_be_read = packet_len + sizeof(uint32_t) + current_macsize;
             /* if to_be_read is zero, the whole packet was blocksize bytes. */
             if (to_be_read != 0) {
-                if (receivedlen - processed < (unsigned int)to_be_read) {
+                if (receivedlen  < (unsigned int)to_be_read) {
                     /* give up, not enough data in buffer */
-                    SSH_LOG(SSH_LOG_PACKET,"packet: partial packet (read len) [len=%d]",len);
-                    return processed;
+                    SSH_LOG(SSH_LOG_PACKET,"packet: partial packet (read len) [len=%d, receivedlen=%d, to_be_read=%d]",packet_len, (int)receivedlen, to_be_read);
+                    return 0;
                 }
 
-                packet = ((uint8_t*)data) + processed;
-#if 0
-                ssh_socket_read(session->socket,
-                                packet,
-                                to_be_read - current_macsize);
-#endif
-
-                rc = ssh_buffer_add_data(session->in_buffer,
-                                     packet,
-                                     to_be_read - current_macsize);
-                if (rc < 0) {
-                    goto error;
-                }
-                processed += to_be_read - current_macsize;
+                packet_second_block = (uint8_t*)data + lenfield_blocksize;
+                processed = to_be_read - current_macsize;
             }
 
+            /* remaining encrypted bytes from the packet, MAC not included */
+            packet_remaining = packet_len - (lenfield_blocksize - sizeof(uint32_t));
+            cleartext_packet = ssh_buffer_allocate(session->in_buffer, packet_remaining);
             if (session->current_crypto) {
                 /*
-                 * Decrypt the rest of the packet (blocksize bytes already
+                 * Decrypt the rest of the packet (lenfield_blocksize bytes already
                  * have been decrypted)
                  */
-                uint32_t buffer_len = ssh_buffer_get_len(session->in_buffer);
-
-                /* The following check avoids decrypting zero bytes */
-                if (buffer_len > blocksize) {
-                    uint8_t *payload = ((uint8_t*)ssh_buffer_get(session->in_buffer) + blocksize);
-                    uint32_t plen = buffer_len - blocksize;
-
-                    rc = ssh_packet_decrypt(session, payload, plen);
+                if (packet_remaining > 0) {
+                    rc = ssh_packet_decrypt(session, cleartext_packet, (uint8_t *)data, lenfield_blocksize, processed - lenfield_blocksize);
                     if (rc < 0) {
-                        ssh_set_error(session, SSH_FATAL, "Decrypt error");
+                        ssh_set_error(session, SSH_FATAL, "Decryption error");
                         goto error;
                     }
                 }
-
-                /* copy the last part from the incoming buffer */
-                packet = ((uint8_t *)data) + processed;
-                memcpy(mac, packet, current_macsize);
+                mac = packet_second_block + packet_remaining;
 
                 rc = ssh_packet_hmac_verify(session, session->in_buffer, mac, session->current_crypto->in_hmac);
                 if (rc < 0) {
@@ -285,6 +270,8 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                     goto error;
                 }
                 processed += current_macsize;
+            } else {
+                memcpy(cleartext_packet, packet_second_block, packet_remaining);
             }
 
             /* skip the size field which has been processed before */
@@ -334,7 +321,7 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
             ssh_packet_parse_type(session);
             SSH_LOG(SSH_LOG_PACKET,
                     "packet: read type %hhd [len=%d,padding=%hhd,comp=%d,payload=%d]",
-                    session->in_packet.type, len, padding, compsize, payloadsize);
+                    session->in_packet.type, packet_len, padding, compsize, payloadsize);
 
             /* Execute callbacks */
             ssh_packet_process(session, session->in_packet.type);
@@ -345,9 +332,9 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
                         "Processing %" PRIdS " bytes left in socket buffer",
                         receivedlen-processed);
 
-                packet = ((uint8_t*)data) + processed;
+                ptr = ((uint8_t*)data) + processed;
 
-                rc = ssh_packet_socket_callback(packet, receivedlen - processed,user);
+                rc = ssh_packet_socket_callback(ptr, receivedlen - processed,user);
                 processed += rc;
             }
 
@@ -364,7 +351,7 @@ int ssh_packet_socket_callback(const void *data, size_t receivedlen, void *user)
 
 error:
     session->session_state= SSH_SESSION_STATE_ERROR;
-
+    SSH_LOG(SSH_LOG_PACKET,"Packet: processed %" PRIdS " bytes", processed);
     return processed;
 }
 
diff --git a/src/packet_crypt.c b/src/packet_crypt.c
index 5bcb6d65..a0f7635a 100644
--- a/src/packet_crypt.c
+++ b/src/packet_crypt.c
@@ -44,40 +44,68 @@
 #include "libssh/crypto.h"
 #include "libssh/buffer.h"
 
-uint32_t ssh_packet_decrypt_len(ssh_session session, char *crypted){
-  uint32_t decrypted;
-
-  if (session->current_crypto) {
-    if (ssh_packet_decrypt(session, crypted,
-          session->current_crypto->in_cipher->blocksize) < 0) {
-      return 0;
+/** @internal
+ * @brief Decrypts the packet length from a raw encrypted packet and stores the
+ * decrypted leading bytes of the packet in destination.
+ * @returns native byte-ordered decrypted length of the upcoming packet
+ */
+uint32_t ssh_packet_decrypt_len(ssh_session session, uint8_t *destination, uint8_t *source){
+    uint32_t decrypted;
+
+    if (session->current_crypto) {
+        if(session->current_crypto->in_cipher->aead_decrypt_length != NULL){
+            if (session->current_crypto->in_cipher->set_decrypt_key(session->current_crypto->in_cipher,
+                    session->current_crypto->decryptkey,
+                    session->current_crypto->decryptIV) < 0) {
+                return (uint32_t)-1;
+            }
+            session->current_crypto->in_cipher->aead_decrypt_length(
+                    session->current_crypto->in_cipher, source, destination,
+                    session->current_crypto->in_cipher->lenfield_blocksize,
+                    session->recv_seq);
+        } else {
+            if (ssh_packet_decrypt(session, destination, source, 0,
+                    session->current_crypto->in_cipher->blocksize) < 0) {
+                return 0;
+            }
+        }
+    } else {
+        memcpy(destination, source, 8);
     }
-  }
-  memcpy(&decrypted,crypted,sizeof(decrypted));
-  return ntohl(decrypted);
+    memcpy(&decrypted,destination,sizeof(decrypted));
+    return ntohl(decrypted);
 }
 
-int ssh_packet_decrypt(ssh_session session, void *data,uint32_t len) {
-  struct ssh_cipher_struct *crypto = session->current_crypto->in_cipher;
-  char *out = NULL;
-
-  assert(len);
+/** @internal
+ * @brief decrypts the content of an SSH packet.
+ * @param[in] source source packet, including the encrypted length field
+ * @param[in] start index in the packet that was not decrypted yet.
+ * @param[in] encrypted_size size of the encrypted data to be decrypted after start.
+ */
+int ssh_packet_decrypt(ssh_session session, uint8_t *destination, uint8_t *source,
+        size_t start, size_t encrypted_size) {
+    struct ssh_cipher_struct *crypto = session->current_crypto->in_cipher;
 
-  if(len % session->current_crypto->in_cipher->blocksize != 0){
-    ssh_set_error(session, SSH_FATAL, "Cryptographic functions must be set on at least one blocksize (received %d)",len);
-    return SSH_ERROR;
-  }
-  out = malloc(len);
-  if (out == NULL) {
-    return -1;
-  }
+    if (encrypted_size <= 0){
+        return SSH_ERROR;
+    }
+    if(encrypted_size % session->current_crypto->in_cipher->blocksize != 0){
+        ssh_set_error(session, SSH_FATAL, "Cryptographic functions must be used on multiple of blocksize (received %" PRIdS ")",encrypted_size);
+        return SSH_ERROR;
+    }
 
-  crypto->decrypt(crypto,data,out,len);
+    if (crypto->set_decrypt_key(crypto, session->current_crypto->decryptkey,
+            session->current_crypto->decryptIV) < 0) {
+        return -1;
+    }
 
-  memcpy(data,out,len);
-  explicit_bzero(out, len);
-  SAFE_FREE(out);
-  return 0;
+    if (crypto->aead_decrypt != NULL){
+        return crypto->aead_decrypt(crypto, source, destination, encrypted_size,
+                session->recv_seq);
+    } else {
+        crypto->decrypt(crypto, source + start, destination, encrypted_size);
+    }
+    return 0;
 }
 
 unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len) {
@@ -160,12 +188,17 @@ unsigned char *ssh_packet_encrypt(ssh_session session, void *data, uint32_t len)
  *                      occurred.
  */
 int ssh_packet_hmac_verify(ssh_session session, ssh_buffer buffer,
-    unsigned char *mac, enum ssh_hmac_e type) {
+    uint8_t *mac, enum ssh_hmac_e type) {
   unsigned char hmacbuf[DIGEST_MAX_LEN] = {0};
   HMACCTX ctx;
   unsigned int len;
   uint32_t seq;
 
+  /* AEAD types have no separate MAC check here */
+  if (type == SSH_HMAC_AEAD_POLY1305){
+      return SSH_OK;
+  }
+
   ctx = hmac_init(session->current_crypto->decryptMAC, hmac_digest_len(type), type);
   if (ctx == NULL) {
     return -1;
diff --git a/src/wrapper.c b/src/wrapper.c
index 71d64bf3..381fc16d 100644
--- a/src/wrapper.c
+++ b/src/wrapper.c
@@ -298,8 +298,13 @@ static int crypt_set_algorithms2(ssh_session session){
   }
   i = 0;
 
-  /* we must scan the kex entries to find hmac algorithms and set their appropriate structure */
-  wanted = session->next_crypto->kex_methods[SSH_MAC_S_C];
+  if (session->next_crypto->in_cipher->aead_encrypt != NULL){
+      /* this cipher has integrated MAC */
+      wanted = "aead-poly1305";
+  } else {
+      /* we must scan the kex entries to find hmac algorithms and set their appropriate structure */
+      wanted = session->next_crypto->kex_methods[SSH_MAC_S_C];
+  }
   while (ssh_hmactab[i].name && strcmp(wanted, ssh_hmactab[i].name)) {
     i++;
   }
@@ -310,7 +315,7 @@ static int crypt_set_algorithms2(ssh_session session){
         wanted);
       return SSH_ERROR;
   }
-  SSH_LOG(SSH_LOG_PACKET, "Set HMAC output algorithm to %s", wanted);
+  SSH_LOG(SSH_LOG_PACKET, "Set HMAC input algorithm to %s", wanted);
 
   session->next_crypto->in_hmac = ssh_hmactab[i].hmac_type;
   i = 0;
@@ -442,7 +447,14 @@ int crypt_set_algorithms_server(ssh_session session){
     }
     i=0;
 
-    method = session->next_crypto->kex_methods[SSH_MAC_C_S];
+    if (session->next_crypto->in_cipher->aead_encrypt != NULL){
+        /* this cipher has integrated MAC */
+        method = "aead-poly1305";
+    } else {
+        /* we must scan the kex entries to find hmac algorithms and set their appropriate structure */
+        method = session->next_crypto->kex_methods[SSH_MAC_C_S];
+    }
+
     while (ssh_hmactab[i].name && strcmp(method, ssh_hmactab[i].name)) {
       i++;
     }
-- 
2.14.1


References:
[PATCH 00/20] Add chacha20-poly1305 support, Alberto Aguirre <albaguirre@xxxxxxxxx>
Archive administrator: postmaster@lists.cynapses.org