]> git.baikalelectronics.ru Git - kernel.git/commitdiff
hv_netvsc: Add (more) validation for untrusted Hyper-V values
authorAndrea Parri (Microsoft) <parri.andrea@gmail.com>
Thu, 14 Jan 2021 20:26:28 +0000 (21:26 +0100)
committerJakub Kicinski <kuba@kernel.org>
Tue, 19 Jan 2021 03:47:47 +0000 (19:47 -0800)
For additional robustness in the face of Hyper-V errors or malicious
behavior, validate all values that originate from packets that Hyper-V
has sent to the guest.  Ensure that invalid values cannot cause indexing
off the end of an array, or subvert an existing validation via integer
overflow.  Ensure that outgoing packets do not have any leftover guest
memory that has not been zeroed out.

Reported-by: Juan Vazquez <juvazq@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
Link: https://lore.kernel.org/r/20210114202628.119541-1-parri.andrea@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
drivers/net/hyperv/netvsc.c
drivers/net/hyperv/netvsc_bpf.c
drivers/net/hyperv/netvsc_drv.c
drivers/net/hyperv/rndis_filter.c

index 3a3db2f0134d8fbbe332f0e0c292b6d242f701a4..6184e99c7f31fd1c7d0ffcb6d81cd5001b881579 100644 (file)
@@ -918,6 +918,7 @@ static inline int netvsc_send_pkt(
        int ret;
        u32 ring_avail = hv_get_avail_to_write_percent(&out_channel->outbound);
 
+       memset(&nvmsg, 0, sizeof(struct nvsp_message));
        nvmsg.hdr.msg_type = NVSP_MSG1_TYPE_SEND_RNDIS_PKT;
        if (skb)
                rpkt->channel_type = 0;         /* 0 is RMC_DATA */
@@ -1337,7 +1338,7 @@ static void netvsc_send_table(struct net_device *ndev,
                         sizeof(union nvsp_6_message_uber);
 
        /* Boundary check for all versions */
-       if (offset > msglen - count * sizeof(u32)) {
+       if (msglen < count * sizeof(u32) || offset > msglen - count * sizeof(u32)) {
                netdev_err(ndev, "Received send-table offset too big:%u\n",
                           offset);
                return;
index d60dcf6c9829ec3357fe67c9ba9965db86568768..aa877da113f8e2804ddd89dc0f740caf275e4989 100644 (file)
@@ -37,6 +37,12 @@ u32 netvsc_run_xdp(struct net_device *ndev, struct netvsc_channel *nvchan,
        if (!prog)
                goto out;
 
+       /* Ensure that the below memcpy() won't overflow the page buffer. */
+       if (len > ndev->mtu + ETH_HLEN) {
+               act = XDP_DROP;
+               goto out;
+       }
+
        /* allocate page buffer for data */
        page = alloc_page(GFP_ATOMIC);
        if (!page) {
index 75b4d6703cf1eb55d0f84356114bf28ec87b11d8..ac20c432d4d8f537901fdf388c3067438028a325 100644 (file)
@@ -761,6 +761,16 @@ void netvsc_linkstatus_callback(struct net_device *net,
        if (indicate->status == RNDIS_STATUS_LINK_SPEED_CHANGE) {
                u32 speed;
 
+               /* Validate status_buf_offset */
+               if (indicate->status_buflen < sizeof(speed) ||
+                   indicate->status_buf_offset < sizeof(*indicate) ||
+                   resp->msg_len - RNDIS_HEADER_SIZE < indicate->status_buf_offset ||
+                   resp->msg_len - RNDIS_HEADER_SIZE - indicate->status_buf_offset
+                               < indicate->status_buflen) {
+                       netdev_err(net, "invalid rndis_indicate_status packet\n");
+                       return;
+               }
+
                speed = *(u32 *)((void *)indicate
                                 + indicate->status_buf_offset) / 10000;
                ndev_ctx->speed = speed;
@@ -866,8 +876,14 @@ static struct sk_buff *netvsc_alloc_recv_skb(struct net_device *net,
         */
        if (csum_info && csum_info->receive.ip_checksum_value_invalid &&
            csum_info->receive.ip_checksum_succeeded &&
-           skb->protocol == htons(ETH_P_IP))
+           skb->protocol == htons(ETH_P_IP)) {
+               /* Check that there is enough space to hold the IP header. */
+               if (skb_headlen(skb) < sizeof(struct iphdr)) {
+                       kfree_skb(skb);
+                       return NULL;
+               }
                netvsc_comp_ipcsum(skb);
+       }
 
        /* Do L4 checksum offload if enabled and present. */
        if (csum_info && (net->features & NETIF_F_RXCSUM)) {
index 598713c0d5a8753f1692a82d7f68eb0b83e19375..c8534b6619b8d16376d01e5aabbe270e7b159177 100644 (file)
@@ -131,66 +131,84 @@ static void dump_rndis_message(struct net_device *netdev,
 {
        switch (rndis_msg->ndis_msg_type) {
        case RNDIS_MSG_PACKET:
-               netdev_dbg(netdev, "RNDIS_MSG_PACKET (len %u, "
-                          "data offset %u data len %u, # oob %u, "
-                          "oob offset %u, oob len %u, pkt offset %u, "
-                          "pkt len %u\n",
-                          rndis_msg->msg_len,
-                          rndis_msg->msg.pkt.data_offset,
-                          rndis_msg->msg.pkt.data_len,
-                          rndis_msg->msg.pkt.num_oob_data_elements,
-                          rndis_msg->msg.pkt.oob_data_offset,
-                          rndis_msg->msg.pkt.oob_data_len,
-                          rndis_msg->msg.pkt.per_pkt_info_offset,
-                          rndis_msg->msg.pkt.per_pkt_info_len);
+               if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >= sizeof(struct rndis_packet)) {
+                       const struct rndis_packet *pkt = &rndis_msg->msg.pkt;
+                       netdev_dbg(netdev, "RNDIS_MSG_PACKET (len %u, "
+                                  "data offset %u data len %u, # oob %u, "
+                                  "oob offset %u, oob len %u, pkt offset %u, "
+                                  "pkt len %u\n",
+                                  rndis_msg->msg_len,
+                                  pkt->data_offset,
+                                  pkt->data_len,
+                                  pkt->num_oob_data_elements,
+                                  pkt->oob_data_offset,
+                                  pkt->oob_data_len,
+                                  pkt->per_pkt_info_offset,
+                                  pkt->per_pkt_info_len);
+               }
                break;
 
        case RNDIS_MSG_INIT_C:
-               netdev_dbg(netdev, "RNDIS_MSG_INIT_C "
-                       "(len %u, id 0x%x, status 0x%x, major %d, minor %d, "
-                       "device flags %d, max xfer size 0x%x, max pkts %u, "
-                       "pkt aligned %u)\n",
-                       rndis_msg->msg_len,
-                       rndis_msg->msg.init_complete.req_id,
-                       rndis_msg->msg.init_complete.status,
-                       rndis_msg->msg.init_complete.major_ver,
-                       rndis_msg->msg.init_complete.minor_ver,
-                       rndis_msg->msg.init_complete.dev_flags,
-                       rndis_msg->msg.init_complete.max_xfer_size,
-                       rndis_msg->msg.init_complete.
-                          max_pkt_per_msg,
-                       rndis_msg->msg.init_complete.
-                          pkt_alignment_factor);
+               if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >=
+                               sizeof(struct rndis_initialize_complete)) {
+                       const struct rndis_initialize_complete *init_complete =
+                               &rndis_msg->msg.init_complete;
+                       netdev_dbg(netdev, "RNDIS_MSG_INIT_C "
+                               "(len %u, id 0x%x, status 0x%x, major %d, minor %d, "
+                               "device flags %d, max xfer size 0x%x, max pkts %u, "
+                               "pkt aligned %u)\n",
+                               rndis_msg->msg_len,
+                               init_complete->req_id,
+                               init_complete->status,
+                               init_complete->major_ver,
+                               init_complete->minor_ver,
+                               init_complete->dev_flags,
+                               init_complete->max_xfer_size,
+                               init_complete->max_pkt_per_msg,
+                               init_complete->pkt_alignment_factor);
+               }
                break;
 
        case RNDIS_MSG_QUERY_C:
-               netdev_dbg(netdev, "RNDIS_MSG_QUERY_C "
-                       "(len %u, id 0x%x, status 0x%x, buf len %u, "
-                       "buf offset %u)\n",
-                       rndis_msg->msg_len,
-                       rndis_msg->msg.query_complete.req_id,
-                       rndis_msg->msg.query_complete.status,
-                       rndis_msg->msg.query_complete.
-                          info_buflen,
-                       rndis_msg->msg.query_complete.
-                          info_buf_offset);
+               if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >=
+                               sizeof(struct rndis_query_complete)) {
+                       const struct rndis_query_complete *query_complete =
+                               &rndis_msg->msg.query_complete;
+                       netdev_dbg(netdev, "RNDIS_MSG_QUERY_C "
+                               "(len %u, id 0x%x, status 0x%x, buf len %u, "
+                               "buf offset %u)\n",
+                               rndis_msg->msg_len,
+                               query_complete->req_id,
+                               query_complete->status,
+                               query_complete->info_buflen,
+                               query_complete->info_buf_offset);
+               }
                break;
 
        case RNDIS_MSG_SET_C:
-               netdev_dbg(netdev,
-                       "RNDIS_MSG_SET_C (len %u, id 0x%x, status 0x%x)\n",
-                       rndis_msg->msg_len,
-                       rndis_msg->msg.set_complete.req_id,
-                       rndis_msg->msg.set_complete.status);
+               if (rndis_msg->msg_len - RNDIS_HEADER_SIZE + sizeof(struct rndis_set_complete)) {
+                       const struct rndis_set_complete *set_complete =
+                               &rndis_msg->msg.set_complete;
+                       netdev_dbg(netdev,
+                               "RNDIS_MSG_SET_C (len %u, id 0x%x, status 0x%x)\n",
+                               rndis_msg->msg_len,
+                               set_complete->req_id,
+                               set_complete->status);
+               }
                break;
 
        case RNDIS_MSG_INDICATE:
-               netdev_dbg(netdev, "RNDIS_MSG_INDICATE "
-                       "(len %u, status 0x%x, buf len %u, buf offset %u)\n",
-                       rndis_msg->msg_len,
-                       rndis_msg->msg.indicate_status.status,
-                       rndis_msg->msg.indicate_status.status_buflen,
-                       rndis_msg->msg.indicate_status.status_buf_offset);
+               if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >=
+                               sizeof(struct rndis_indicate_status)) {
+                       const struct rndis_indicate_status *indicate_status =
+                               &rndis_msg->msg.indicate_status;
+                       netdev_dbg(netdev, "RNDIS_MSG_INDICATE "
+                               "(len %u, status 0x%x, buf len %u, buf offset %u)\n",
+                               rndis_msg->msg_len,
+                               indicate_status->status,
+                               indicate_status->status_buflen,
+                               indicate_status->status_buf_offset);
+               }
                break;
 
        default:
@@ -246,11 +264,20 @@ static void rndis_set_link_state(struct rndis_device *rdev,
 {
        u32 link_status;
        struct rndis_query_complete *query_complete;
+       u32 msg_len = request->response_msg.msg_len;
+
+       /* Ensure the packet is big enough to access its fields */
+       if (msg_len - RNDIS_HEADER_SIZE < sizeof(struct rndis_query_complete))
+               return;
 
        query_complete = &request->response_msg.msg.query_complete;
 
        if (query_complete->status == RNDIS_STATUS_SUCCESS &&
-           query_complete->info_buflen == sizeof(u32)) {
+           query_complete->info_buflen >= sizeof(u32) &&
+           query_complete->info_buf_offset >= sizeof(*query_complete) &&
+           msg_len - RNDIS_HEADER_SIZE >= query_complete->info_buf_offset &&
+           msg_len - RNDIS_HEADER_SIZE - query_complete->info_buf_offset
+                       >= query_complete->info_buflen) {
                memcpy(&link_status, (void *)((unsigned long)query_complete +
                       query_complete->info_buf_offset), sizeof(u32));
                rdev->link_state = link_status != 0;
@@ -343,7 +370,8 @@ static void rndis_filter_receive_response(struct net_device *ndev,
  */
 static inline void *rndis_get_ppi(struct net_device *ndev,
                                  struct rndis_packet *rpkt,
-                                 u32 rpkt_len, u32 type, u8 internal)
+                                 u32 rpkt_len, u32 type, u8 internal,
+                                 u32 ppi_size)
 {
        struct rndis_per_packet_info *ppi;
        int len;
@@ -359,7 +387,8 @@ static inline void *rndis_get_ppi(struct net_device *ndev,
                return NULL;
        }
 
-       if (rpkt->per_pkt_info_len > rpkt_len - rpkt->per_pkt_info_offset) {
+       if (rpkt->per_pkt_info_len < sizeof(*ppi) ||
+           rpkt->per_pkt_info_len > rpkt_len - rpkt->per_pkt_info_offset) {
                netdev_err(ndev, "Invalid per_pkt_info_len: %u\n",
                           rpkt->per_pkt_info_len);
                return NULL;
@@ -381,8 +410,15 @@ static inline void *rndis_get_ppi(struct net_device *ndev,
                        continue;
                }
 
-               if (ppi->type == type && ppi->internal == internal)
+               if (ppi->type == type && ppi->internal == internal) {
+                       /* ppi->size should be big enough to hold the returned object. */
+                       if (ppi->size - ppi->ppi_offset < ppi_size) {
+                               netdev_err(ndev, "Invalid ppi: size %u ppi_offset %u\n",
+                                          ppi->size, ppi->ppi_offset);
+                               continue;
+                       }
                        return (void *)((ulong)ppi + ppi->ppi_offset);
+               }
                len -= ppi->size;
                ppi = (struct rndis_per_packet_info *)((ulong)ppi + ppi->size);
        }
@@ -461,13 +497,16 @@ static int rndis_filter_receive_data(struct net_device *ndev,
                return NVSP_STAT_FAIL;
        }
 
-       vlan = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, IEEE_8021Q_INFO, 0);
+       vlan = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, IEEE_8021Q_INFO, 0, sizeof(*vlan));
 
-       csum_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, TCPIP_CHKSUM_PKTINFO, 0);
+       csum_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, TCPIP_CHKSUM_PKTINFO, 0,
+                                 sizeof(*csum_info));
 
-       hash_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, NBL_HASH_VALUE, 0);
+       hash_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, NBL_HASH_VALUE, 0,
+                                 sizeof(*hash_info));
 
-       pktinfo_id = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, RNDIS_PKTINFO_ID, 1);
+       pktinfo_id = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, RNDIS_PKTINFO_ID, 1,
+                                  sizeof(*pktinfo_id));
 
        data = (void *)msg + data_offset;
 
@@ -522,9 +561,6 @@ int rndis_filter_receive(struct net_device *ndev,
        struct net_device_context *net_device_ctx = netdev_priv(ndev);
        struct rndis_message *rndis_msg = data;
 
-       if (netif_msg_rx_status(net_device_ctx))
-               dump_rndis_message(ndev, rndis_msg);
-
        /* Validate incoming rndis_message packet */
        if (buflen < RNDIS_HEADER_SIZE || rndis_msg->msg_len < RNDIS_HEADER_SIZE ||
            buflen < rndis_msg->msg_len) {
@@ -533,6 +569,9 @@ int rndis_filter_receive(struct net_device *ndev,
                return NVSP_STAT_FAIL;
        }
 
+       if (netif_msg_rx_status(net_device_ctx))
+               dump_rndis_message(ndev, rndis_msg);
+
        switch (rndis_msg->ndis_msg_type) {
        case RNDIS_MSG_PACKET:
                return rndis_filter_receive_data(ndev, net_dev, nvchan,
@@ -567,6 +606,7 @@ static int rndis_filter_query_device(struct rndis_device *dev,
        u32 inresult_size = *result_size;
        struct rndis_query_request *query;
        struct rndis_query_complete *query_complete;
+       u32 msg_len;
        int ret = 0;
 
        if (!result)
@@ -634,8 +674,19 @@ static int rndis_filter_query_device(struct rndis_device *dev,
 
        /* Copy the response back */
        query_complete = &request->response_msg.msg.query_complete;
+       msg_len = request->response_msg.msg_len;
+
+       /* Ensure the packet is big enough to access its fields */
+       if (msg_len - RNDIS_HEADER_SIZE < sizeof(struct rndis_query_complete)) {
+               ret = -1;
+               goto cleanup;
+       }
 
-       if (query_complete->info_buflen > inresult_size) {
+       if (query_complete->info_buflen > inresult_size ||
+           query_complete->info_buf_offset < sizeof(*query_complete) ||
+           msg_len - RNDIS_HEADER_SIZE < query_complete->info_buf_offset ||
+           msg_len - RNDIS_HEADER_SIZE - query_complete->info_buf_offset
+                       < query_complete->info_buflen) {
                ret = -1;
                goto cleanup;
        }