From b503f82ed7cc740d1209f6372094d4d9f0d912ed Mon Sep 17 00:00:00 2001 From: Cristian Marussi Date: Wed, 30 Mar 2022 16:05:33 +0100 Subject: [PATCH] firmware: arm_scmi: Validate BASE_DISCOVER_LIST_PROTOCOLS response Do not blindly trust SCMI platform response about list of implemented protocols, instead validate the reported length of the list of protocols against the real payload size of the message reply. Link: https://lore.kernel.org/r/20220330150551.2573938-5-cristian.marussi@arm.com Fixes: 2f6e27266acb ("firmware: arm_scmi: add common infrastructure and support for base protocol") Signed-off-by: Cristian Marussi [sudeep.holla: Added early break if loop_num_ret = 0 and simplified calc_list_sz calculation] Signed-off-by: Sudeep Holla --- drivers/firmware/arm_scmi/base.c | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/drivers/firmware/arm_scmi/base.c b/drivers/firmware/arm_scmi/base.c index f279146f81106..dc5b0c70a700c 100644 --- a/drivers/firmware/arm_scmi/base.c +++ b/drivers/firmware/arm_scmi/base.c @@ -189,6 +189,9 @@ scmi_base_implementation_list_get(const struct scmi_protocol_handle *ph, list = t->rx.buf + sizeof(*num_ret); do { + size_t real_list_sz; + u32 calc_list_sz; + /* Set the number of protocols to be skipped/already read */ *num_skip = cpu_to_le32(tot_num_ret); @@ -197,11 +200,32 @@ scmi_base_implementation_list_get(const struct scmi_protocol_handle *ph, break; loop_num_ret = le32_to_cpu(*num_ret); + if (!loop_num_ret) + break; + if (loop_num_ret > MAX_PROTOCOLS_IMP - tot_num_ret) { dev_err(dev, "No. of Protocol > MAX_PROTOCOLS_IMP"); break; } + if (t->rx.len < (sizeof(u32) * 2)) { + dev_err(dev, "Truncated reply - rx.len:%zd\n", + t->rx.len); + ret = -EPROTO; + break; + } + + real_list_sz = t->rx.len - sizeof(u32); + calc_list_sz = (1 + (loop_num_ret - 1) / sizeof(u32)) * + sizeof(u32); + if (calc_list_sz != real_list_sz) { + dev_err(dev, + "Malformed reply - real_sz:%zd calc_sz:%u\n", + real_list_sz, calc_list_sz); + ret = -EPROTO; + break; + } + for (loop = 0; loop < loop_num_ret; loop++) protocols_imp[tot_num_ret + loop] = *(list + loop); -- 2.39.5