struct smb2_negotiate_req *req = work->request_buf;
struct smb2_negotiate_rsp *rsp = work->response_buf;
int rc = 0;
+ unsigned int smb2_buf_len, smb2_neg_size;
__le32 status;
ksmbd_debug(SMB, "Received negotiate request\n");
goto err_out;
}
+ smb2_buf_len = get_rfc1002_len(work->request_buf);
+ smb2_neg_size = offsetof(struct smb2_negotiate_req, Dialects) - 4;
+ if (smb2_neg_size > smb2_buf_len) {
+ rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+ rc = -EINVAL;
+ goto err_out;
+ }
+
+ if (conn->dialect == SMB311_PROT_ID) {
+ unsigned int nego_ctxt_off = le32_to_cpu(req->NegotiateContextOffset);
+
+ if (smb2_buf_len < nego_ctxt_off) {
+ rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+ rc = -EINVAL;
+ goto err_out;
+ }
+
+ if (smb2_neg_size > nego_ctxt_off) {
+ rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+ rc = -EINVAL;
+ goto err_out;
+ }
+
+ if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+ nego_ctxt_off) {
+ rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+ rc = -EINVAL;
+ goto err_out;
+ }
+ } else {
+ if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+ smb2_buf_len) {
+ rsp->hdr.Status = STATUS_INVALID_PARAMETER;
+ rc = -EINVAL;
+ goto err_out;
+ }
+ }
+
conn->cli_cap = le32_to_cpu(req->Capabilities);
switch (conn->dialect) {
case SMB311_PROT_ID:
WORK_BUFFERS(work, req, rsp);
- if (le16_to_cpu(req->Command) == SMB2_NEGOTIATE_HE)
+ if (le16_to_cpu(req->Command) == SMB2_NEGOTIATE_HE &&
+ conn->preauth_info)
ksmbd_gen_preauth_integrity_hash(conn, (char *)rsp,
conn->preauth_info->Preauth_HashValue);
idx <= server_conf.max_protocol);
}
-static char *next_dialect(char *dialect, int *next_off)
+static char *next_dialect(char *dialect, int *next_off, int bcount)
{
dialect = dialect + *next_off;
- *next_off = strlen(dialect);
+ *next_off = strnlen(dialect, bcount);
+ if (dialect[*next_off] != '\0')
+ return NULL;
return dialect;
}
dialect = cli_dialects;
bcount = le16_to_cpu(byte_count);
do {
- dialect = next_dialect(dialect, &next);
+ dialect = next_dialect(dialect, &next, bcount);
+ if (!dialect)
+ break;
ksmbd_debug(SMB, "client requested dialect %s\n",
dialect);
if (!strcmp(dialect, smb1_protos[i].name)) {
static int ksmbd_negotiate_smb_dialect(void *buf)
{
- __le32 proto;
+ int smb_buf_length = get_rfc1002_len(buf);
+ __le32 proto = ((struct smb2_hdr *)buf)->ProtocolId;
- proto = ((struct smb2_hdr *)buf)->ProtocolId;
if (proto == SMB2_PROTO_NUMBER) {
struct smb2_negotiate_req *req;
+ int smb2_neg_size =
+ offsetof(struct smb2_negotiate_req, Dialects) - 4;
req = (struct smb2_negotiate_req *)buf;
+ if (smb2_neg_size > smb_buf_length)
+ goto err_out;
+
+ if (smb2_neg_size + le16_to_cpu(req->DialectCount) * sizeof(__le16) >
+ smb_buf_length)
+ goto err_out;
+
return ksmbd_lookup_dialect_by_id(req->Dialects,
req->DialectCount);
}
struct smb_negotiate_req *req;
req = (struct smb_negotiate_req *)buf;
+ if (le16_to_cpu(req->ByteCount) < 2)
+ goto err_out;
+
+ if (offsetof(struct smb_negotiate_req, DialectsArray) - 4 +
+ le16_to_cpu(req->ByteCount) > smb_buf_length) {
+ goto err_out;
+ }
+
return ksmbd_lookup_dialect_by_name(req->DialectsArray,
req->ByteCount);
}
+err_out:
return BAD_PROT_ID;
}