]> git.baikalelectronics.ru Git - kernel.git/commitdiff
ksmbd: add validation in smb2 negotiate
authorNamjae Jeon <linkinjeon@kernel.org>
Wed, 29 Sep 2021 06:44:32 +0000 (15:44 +0900)
committerSteve French <stfrench@microsoft.com>
Thu, 30 Sep 2021 14:58:07 +0000 (09:58 -0500)
This patch add validation to check request buffer check in smb2
negotiate and fix null pointer deferencing oops in smb3_preauth_hash_rsp()
that found from manual test.

Cc: Tom Talpey <tom@talpey.com>
Cc: Ronnie Sahlberg <ronniesahlberg@gmail.com>
Cc: Ralph Böhme <slow@samba.org>
Cc: Hyunchul Lee <hyc.lee@gmail.com>
Cc: Sergey Senozhatsky <senozhatsky@chromium.org>
Reviewed-by: Ralph Boehme <slow@samba.org>
Signed-off-by: Namjae Jeon <linkinjeon@kernel.org>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/ksmbd/smb2pdu.c
fs/ksmbd/smb_common.c

index 0d915554f5326bb9a2c19dba3c9b8cd805b27d75..40882fd47febd4e719db3fe09a33ea99fbb206b5 100644 (file)
@@ -1067,6 +1067,7 @@ int smb2_handle_negotiate(struct ksmbd_work *work)
        struct smb2_negotiate_req *req = work->request_buf;
        struct smb2_negotiate_rsp *rsp = work->response_buf;
        int rc = 0;
+       unsigned int smb2_buf_len, smb2_neg_size;
        __le32 status;
 
        ksmbd_debug(SMB, "Received negotiate request\n");
@@ -1084,6 +1085,44 @@ int smb2_handle_negotiate(struct ksmbd_work *work)
                goto err_out;
        }
 
+       smb2_buf_len = get_rfc1002_len(work->request_buf);
+       smb2_neg_size = offsetof(struct smb2_negotiate_req, Dialects) - 4;
+       if (smb2_neg_size > smb2_buf_len) {
+               rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+               rc = -EINVAL;
+               goto err_out;
+       }
+
+       if (conn->dialect == SMB311_PROT_ID) {
+               unsigned int nego_ctxt_off = le32_to_cpu(req->NegotiateContextOffset);
+
+               if (smb2_buf_len < nego_ctxt_off) {
+                       rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+                       rc = -EINVAL;
+                       goto err_out;
+               }
+
+               if (smb2_neg_size > nego_ctxt_off) {
+                       rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+                       rc = -EINVAL;
+                       goto err_out;
+               }
+
+               if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+                   nego_ctxt_off) {
+                       rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+                       rc = -EINVAL;
+                       goto err_out;
+               }
+       } else {
+               if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+                   smb2_buf_len) {
+                       rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+                       rc = -EINVAL;
+                       goto err_out;
+               }
+       }
+
        conn->cli_cap = le32_to_cpu(req->Capabilities);
        switch (conn->dialect) {
        case SMB311_PROT_ID:
@@ -8244,7 +8283,8 @@ void smb3_preauth_hash_rsp(struct ksmbd_work *work)
 
        WORK_BUFFERS(work, req, rsp);
 
-       if (le16_to_cpu(req->Command) == SMB2_NEGOTIATE_HE)
+       if (le16_to_cpu(req->Command) == SMB2_NEGOTIATE_HE &&
+           conn->preauth_info)
                ksmbd_gen_preauth_integrity_hash(conn, (char *)rsp,
                                                 conn->preauth_info->Preauth_HashValue);
 
index 5901b2884c602fd7ffb67533a8652e3a43d45062..db8042a173d09e49ee11c88e4b2c2a7237efa8fc 100644 (file)
@@ -169,10 +169,12 @@ static bool supported_protocol(int idx)
                idx <= server_conf.max_protocol);
 }
 
-static char *next_dialect(char *dialect, int *next_off)
+static char *next_dialect(char *dialect, int *next_off, int bcount)
 {
        dialect = dialect + *next_off;
-       *next_off = strlen(dialect);
+       *next_off = strnlen(dialect, bcount);
+       if (dialect[*next_off] != '\0')
+               return NULL;
        return dialect;
 }
 
@@ -187,7 +189,9 @@ static int ksmbd_lookup_dialect_by_name(char *cli_dialects, __le16 byte_count)
                dialect = cli_dialects;
                bcount = le16_to_cpu(byte_count);
                do {
-                       dialect = next_dialect(dialect, &next);
+                       dialect = next_dialect(dialect, &next, bcount);
+                       if (!dialect)
+                               break;
                        ksmbd_debug(SMB, "client requested dialect %s\n",
                                    dialect);
                        if (!strcmp(dialect, smb1_protos[i].name)) {
@@ -235,13 +239,22 @@ int ksmbd_lookup_dialect_by_id(__le16 *cli_dialects, __le16 dialects_count)
 
 static int ksmbd_negotiate_smb_dialect(void *buf)
 {
-       __le32 proto;
+       int smb_buf_length = get_rfc1002_len(buf);
+       __le32 proto = ((struct smb2_hdr *)buf)->ProtocolId;
 
-       proto = ((struct smb2_hdr *)buf)->ProtocolId;
        if (proto == SMB2_PROTO_NUMBER) {
                struct smb2_negotiate_req *req;
+               int smb2_neg_size =
+                       offsetof(struct smb2_negotiate_req, Dialects) - 4;
 
                req = (struct smb2_negotiate_req *)buf;
+               if (smb2_neg_size > smb_buf_length)
+                       goto err_out;
+
+               if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+                   smb_buf_length)
+                       goto err_out;
+
                return ksmbd_lookup_dialect_by_id(req->Dialects,
                                                  req->DialectCount);
        }
@@ -251,10 +264,19 @@ static int ksmbd_negotiate_smb_dialect(void *buf)
                struct smb_negotiate_req *req;
 
                req = (struct smb_negotiate_req *)buf;
+               if (le16_to_cpu(req->ByteCount) < 2)
+                       goto err_out;
+
+               if (offsetof(struct smb_negotiate_req, DialectsArray) - 4 +
+                       le16_to_cpu(req->ByteCount) > smb_buf_length) {
+                       goto err_out;
+               }
+
                return ksmbd_lookup_dialect_by_name(req->DialectsArray,
                                                    req->ByteCount);
        }
 
+err_out:
        return BAD_PROT_ID;
 }