* SOFTWARE.
*/
+#include <linux/sched/signal.h>
#include <linux/module.h>
#include <crypto/aead.h>
+#include <net/strparser.h>
#include <net/tls.h>
+static int tls_do_decryption(struct sock *sk,
+ struct scatterlist *sgin,
+ struct scatterlist *sgout,
+ char *iv_recv,
+ size_t data_len,
+ struct sk_buff *skb,
+ gfp_t flags)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ struct strp_msg *rxm = strp_msg(skb);
+ struct aead_request *aead_req;
+
+ int ret;
+ unsigned int req_size = sizeof(struct aead_request) +
+ crypto_aead_reqsize(ctx->aead_recv);
+
+ aead_req = kzalloc(req_size, flags);
+ if (!aead_req)
+ return -ENOMEM;
+
+ aead_request_set_tfm(aead_req, ctx->aead_recv);
+ aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
+ aead_request_set_crypt(aead_req, sgin, sgout,
+ data_len + tls_ctx->rx.tag_size,
+ (u8 *)iv_recv);
+ aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
+ crypto_req_done, &ctx->async_wait);
+
+ ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
+
+ if (ret < 0)
+ goto out;
+
+ rxm->offset += tls_ctx->rx.prepend_size;
+ rxm->full_len -= tls_ctx->rx.overhead_size;
+ tls_advance_record_sn(sk, &tls_ctx->rx);
+
+ ctx->decrypted = true;
+
+ ctx->saved_data_ready(sk);
+
+out:
+ kfree(aead_req);
+ return ret;
+}
+
static void trim_sg(struct sock *sk, struct scatterlist *sg,
int *sg_num_elem, unsigned int *sg_size, int target_size)
{
return ret;
}
-void tls_sw_free_tx_resources(struct sock *sk)
+static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
+ long timeo, int *err)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ struct sk_buff *skb;
+ DEFINE_WAIT_FUNC(wait, woken_wake_function);
+
+ while (!(skb = ctx->recv_pkt)) {
+ if (sk->sk_err) {
+ *err = sock_error(sk);
+ return NULL;
+ }
+
+ if (sock_flag(sk, SOCK_DONE))
+ return NULL;
+
+ if ((flags & MSG_DONTWAIT) || !timeo) {
+ *err = -EAGAIN;
+ return NULL;
+ }
+
+ add_wait_queue(sk_sleep(sk), &wait);
+ sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+ sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
+ sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
+ remove_wait_queue(sk_sleep(sk), &wait);
+
+ /* Handle signals */
+ if (signal_pending(current)) {
+ *err = sock_intr_errno(timeo);
+ return NULL;
+ }
+ }
+
+ return skb;
+}
+
+static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+ struct scatterlist *sgout)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + tls_ctx->rx.iv_size];
+ struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2];
+ struct scatterlist *sgin = &sgin_arr[0];
+ struct strp_msg *rxm = strp_msg(skb);
+ int ret, nsg = ARRAY_SIZE(sgin_arr);
+ char aad_recv[TLS_AAD_SPACE_SIZE];
+ struct sk_buff *unused;
+
+ ret = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
+ iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+ tls_ctx->rx.iv_size);
+ if (ret < 0)
+ return ret;
+
+ memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ if (!sgout) {
+ nsg = skb_cow_data(skb, 0, &unused) + 1;
+ sgin = kmalloc_array(nsg, sizeof(*sgin), sk->sk_allocation);
+ if (!sgout)
+ sgout = sgin;
+ }
+
+ sg_init_table(sgin, nsg);
+ sg_set_buf(&sgin[0], aad_recv, sizeof(aad_recv));
+
+ nsg = skb_to_sgvec(skb, &sgin[1],
+ rxm->offset + tls_ctx->rx.prepend_size,
+ rxm->full_len - tls_ctx->rx.prepend_size);
+
+ tls_make_aad(aad_recv,
+ rxm->full_len - tls_ctx->rx.overhead_size,
+ tls_ctx->rx.rec_seq,
+ tls_ctx->rx.rec_seq_size,
+ ctx->control);
+
+ ret = tls_do_decryption(sk, sgin, sgout, iv,
+ rxm->full_len - tls_ctx->rx.overhead_size,
+ skb, sk->sk_allocation);
+
+ if (sgin != &sgin_arr[0])
+ kfree(sgin);
+
+ return ret;
+}
+
+static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
+ unsigned int len)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ struct strp_msg *rxm = strp_msg(skb);
+
+ if (len < rxm->full_len) {
+ rxm->offset += len;
+ rxm->full_len -= len;
+
+ return false;
+ }
+
+ /* Finished with message */
+ ctx->recv_pkt = NULL;
+ kfree_skb(skb);
+ strp_unpause(&ctx->strp);
+
+ return true;
+}
+
+int tls_sw_recvmsg(struct sock *sk,
+ struct msghdr *msg,
+ size_t len,
+ int nonblock,
+ int flags,
+ int *addr_len)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ unsigned char control;
+ struct strp_msg *rxm;
+ struct sk_buff *skb;
+ ssize_t copied = 0;
+ bool cmsg = false;
+ int err = 0;
+ long timeo;
+
+ flags |= nonblock;
+
+ if (unlikely(flags & MSG_ERRQUEUE))
+ return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
+
+ lock_sock(sk);
+
+ timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+ do {
+ bool zc = false;
+ int chunk = 0;
+
+ skb = tls_wait_data(sk, flags, timeo, &err);
+ if (!skb)
+ goto recv_end;
+
+ rxm = strp_msg(skb);
+ if (!cmsg) {
+ int cerr;
+
+ cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
+ sizeof(ctx->control), &ctx->control);
+ cmsg = true;
+ control = ctx->control;
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ if (cerr || msg->msg_flags & MSG_CTRUNC) {
+ err = -EIO;
+ goto recv_end;
+ }
+ }
+ } else if (control != ctx->control) {
+ goto recv_end;
+ }
+
+ if (!ctx->decrypted) {
+ int page_count;
+ int to_copy;
+
+ page_count = iov_iter_npages(&msg->msg_iter,
+ MAX_SKB_FRAGS);
+ to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
+ if (to_copy <= len && page_count < MAX_SKB_FRAGS &&
+ likely(!(flags & MSG_PEEK))) {
+ struct scatterlist sgin[MAX_SKB_FRAGS + 1];
+ char unused[21];
+ int pages = 0;
+
+ zc = true;
+ sg_init_table(sgin, MAX_SKB_FRAGS + 1);
+ sg_set_buf(&sgin[0], unused, 13);
+
+ err = zerocopy_from_iter(sk, &msg->msg_iter,
+ to_copy, &pages,
+ &chunk, &sgin[1],
+ MAX_SKB_FRAGS, false);
+ if (err < 0)
+ goto fallback_to_reg_recv;
+
+ err = decrypt_skb(sk, skb, sgin);
+ for (; pages > 0; pages--)
+ put_page(sg_page(&sgin[pages]));
+ if (err < 0) {
+ tls_err_abort(sk, EBADMSG);
+ goto recv_end;
+ }
+ } else {
+fallback_to_reg_recv:
+ err = decrypt_skb(sk, skb, NULL);
+ if (err < 0) {
+ tls_err_abort(sk, EBADMSG);
+ goto recv_end;
+ }
+ }
+ ctx->decrypted = true;
+ }
+
+ if (!zc) {
+ chunk = min_t(unsigned int, rxm->full_len, len);
+ err = skb_copy_datagram_msg(skb, rxm->offset, msg,
+ chunk);
+ if (err < 0)
+ goto recv_end;
+ }
+
+ copied += chunk;
+ len -= chunk;
+ if (likely(!(flags & MSG_PEEK))) {
+ u8 control = ctx->control;
+
+ if (tls_sw_advance_skb(sk, skb, chunk)) {
+ /* Return full control message to
+ * userspace before trying to parse
+ * another message type
+ */
+ msg->msg_flags |= MSG_EOR;
+ if (control != TLS_RECORD_TYPE_DATA)
+ goto recv_end;
+ }
+ }
+ } while (len);
+
+recv_end:
+ release_sock(sk);
+ return copied ? : err;
+}
+
+ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos,
+ struct pipe_inode_info *pipe,
+ size_t len, unsigned int flags)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ struct strp_msg *rxm = NULL;
+ struct sock *sk = sock->sk;
+ struct sk_buff *skb;
+ ssize_t copied = 0;
+ int err = 0;
+ long timeo;
+ int chunk;
+
+ lock_sock(sk);
+
+ timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
+
+ skb = tls_wait_data(sk, flags, timeo, &err);
+ if (!skb)
+ goto splice_read_end;
+
+ /* splice does not support reading control messages */
+ if (ctx->control != TLS_RECORD_TYPE_DATA) {
+ err = -ENOTSUPP;
+ goto splice_read_end;
+ }
+
+ if (!ctx->decrypted) {
+ err = decrypt_skb(sk, skb, NULL);
+
+ if (err < 0) {
+ tls_err_abort(sk, EBADMSG);
+ goto splice_read_end;
+ }
+ ctx->decrypted = true;
+ }
+ rxm = strp_msg(skb);
+
+ chunk = min_t(unsigned int, rxm->full_len, len);
+ copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
+ if (copied < 0)
+ goto splice_read_end;
+
+ if (likely(!(flags & MSG_PEEK)))
+ tls_sw_advance_skb(sk, skb, copied);
+
+splice_read_end:
+ release_sock(sk);
+ return copied ? : err;
+}
+
+unsigned int tls_sw_poll(struct file *file, struct socket *sock,
+ struct poll_table_struct *wait)
+{
+ unsigned int ret;
+ struct sock *sk = sock->sk;
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+ /* Grab POLLOUT and POLLHUP from the underlying socket */
+ ret = ctx->sk_poll(file, sock, wait);
+
+ /* Clear POLLIN bits, and set based on recv_pkt */
+ ret &= ~(POLLIN | POLLRDNORM);
+ if (ctx->recv_pkt)
+ ret |= POLLIN | POLLRDNORM;
+
+ return ret;
+}
+
+static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ char header[tls_ctx->rx.prepend_size];
+ struct strp_msg *rxm = strp_msg(skb);
+ size_t cipher_overhead;
+ size_t data_len = 0;
+ int ret;
+
+ /* Verify that we have a full TLS header, or wait for more data */
+ if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
+ return 0;
+
+ /* Linearize header to local buffer */
+ ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
+
+ if (ret < 0)
+ goto read_failure;
+
+ ctx->control = header[0];
+
+ data_len = ((header[4] & 0xFF) | (header[3] << 8));
+
+ cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
+
+ if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
+ ret = -EMSGSIZE;
+ goto read_failure;
+ }
+ if (data_len < cipher_overhead) {
+ ret = -EBADMSG;
+ goto read_failure;
+ }
+
+ if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
+ header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
+ ret = -EINVAL;
+ goto read_failure;
+ }
+
+ return data_len + TLS_HEADER_SIZE;
+
+read_failure:
+ tls_err_abort(strp->sk, ret);
+
+ return ret;
+}
+
+static void tls_queue(struct strparser *strp, struct sk_buff *skb)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+ struct strp_msg *rxm;
+
+ rxm = strp_msg(skb);
+
+ ctx->decrypted = false;
+
+ ctx->recv_pkt = skb;
+ strp_pause(strp);
+
+ strp->sk->sk_state_change(strp->sk);
+}
+
+static void tls_data_ready(struct sock *sk)
+{
+ struct tls_context *tls_ctx = tls_get_ctx(sk);
+ struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
+
+ strp_data_ready(&ctx->strp);
+}
+
+void tls_sw_free_resources(struct sock *sk)
{
struct tls_context *tls_ctx = tls_get_ctx(sk);
struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx);
if (ctx->aead_send)
crypto_free_aead(ctx->aead_send);
+ if (ctx->aead_recv) {
+ if (ctx->recv_pkt) {
+ kfree_skb(ctx->recv_pkt);
+ ctx->recv_pkt = NULL;
+ }
+ crypto_free_aead(ctx->aead_recv);
+ strp_stop(&ctx->strp);
+ write_lock_bh(&sk->sk_callback_lock);
+ sk->sk_data_ready = ctx->saved_data_ready;
+ write_unlock_bh(&sk->sk_callback_lock);
+ release_sock(sk);
+ strp_done(&ctx->strp);
+ lock_sock(sk);
+ }
tls_free_both_sg(sk);
kfree(tls_ctx);
}
-int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx)
+int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
{
char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
struct tls_crypto_info *crypto_info;
struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
struct tls_sw_context *sw_ctx;
+ struct cipher_context *cctx;
+ struct crypto_aead **aead;
+ struct strp_callbacks cb;
u16 nonce_size, tag_size, iv_size, rec_seq_size;
char *iv, *rec_seq;
int rc = 0;
goto out;
}
- if (ctx->priv_ctx) {
- rc = -EEXIST;
- goto out;
- }
-
- sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
- if (!sw_ctx) {
- rc = -ENOMEM;
- goto out;
+ if (!ctx->priv_ctx) {
+ sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL);
+ if (!sw_ctx) {
+ rc = -ENOMEM;
+ goto out;
+ }
+ crypto_init_wait(&sw_ctx->async_wait);
+ } else {
+ sw_ctx = ctx->priv_ctx;
}
- crypto_init_wait(&sw_ctx->async_wait);
-
ctx->priv_ctx = (struct tls_offload_context *)sw_ctx;
- crypto_info = &ctx->crypto_send;
+ if (tx) {
+ crypto_info = &ctx->crypto_send;
+ cctx = &ctx->tx;
+ aead = &sw_ctx->aead_send;
+ } else {
+ crypto_info = &ctx->crypto_recv;
+ cctx = &ctx->rx;
+ aead = &sw_ctx->aead_recv;
+ }
+
switch (crypto_info->cipher_type) {
case TLS_CIPHER_AES_GCM_128: {
nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
goto free_priv;
}
- ctx->tx.prepend_size = TLS_HEADER_SIZE + nonce_size;
- ctx->tx.tag_size = tag_size;
- ctx->tx.overhead_size = ctx->tx.prepend_size + ctx->tx.tag_size;
- ctx->tx.iv_size = iv_size;
- ctx->tx.iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
- GFP_KERNEL);
- if (!ctx->tx.iv) {
+ cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
+ cctx->tag_size = tag_size;
+ cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
+ cctx->iv_size = iv_size;
+ cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
+ GFP_KERNEL);
+ if (!cctx->iv) {
rc = -ENOMEM;
goto free_priv;
}
- memcpy(ctx->tx.iv, gcm_128_info->salt,
- TLS_CIPHER_AES_GCM_128_SALT_SIZE);
- memcpy(ctx->tx.iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
- ctx->tx.rec_seq_size = rec_seq_size;
- ctx->tx.rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
- if (!ctx->tx.rec_seq) {
+ memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
+ memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
+ cctx->rec_seq_size = rec_seq_size;
+ cctx->rec_seq = kmalloc(rec_seq_size, GFP_KERNEL);
+ if (!cctx->rec_seq) {
rc = -ENOMEM;
goto free_iv;
}
- memcpy(ctx->tx.rec_seq, rec_seq, rec_seq_size);
-
- sg_init_table(sw_ctx->sg_encrypted_data,
- ARRAY_SIZE(sw_ctx->sg_encrypted_data));
- sg_init_table(sw_ctx->sg_plaintext_data,
- ARRAY_SIZE(sw_ctx->sg_plaintext_data));
-
- sg_init_table(sw_ctx->sg_aead_in, 2);
- sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
- sizeof(sw_ctx->aad_space));
- sg_unmark_end(&sw_ctx->sg_aead_in[1]);
- sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
- sg_init_table(sw_ctx->sg_aead_out, 2);
- sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
- sizeof(sw_ctx->aad_space));
- sg_unmark_end(&sw_ctx->sg_aead_out[1]);
- sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
-
- if (!sw_ctx->aead_send) {
- sw_ctx->aead_send = crypto_alloc_aead("gcm(aes)", 0, 0);
- if (IS_ERR(sw_ctx->aead_send)) {
- rc = PTR_ERR(sw_ctx->aead_send);
- sw_ctx->aead_send = NULL;
+ memcpy(cctx->rec_seq, rec_seq, rec_seq_size);
+
+ if (tx) {
+ sg_init_table(sw_ctx->sg_encrypted_data,
+ ARRAY_SIZE(sw_ctx->sg_encrypted_data));
+ sg_init_table(sw_ctx->sg_plaintext_data,
+ ARRAY_SIZE(sw_ctx->sg_plaintext_data));
+
+ sg_init_table(sw_ctx->sg_aead_in, 2);
+ sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space,
+ sizeof(sw_ctx->aad_space));
+ sg_unmark_end(&sw_ctx->sg_aead_in[1]);
+ sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data);
+ sg_init_table(sw_ctx->sg_aead_out, 2);
+ sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space,
+ sizeof(sw_ctx->aad_space));
+ sg_unmark_end(&sw_ctx->sg_aead_out[1]);
+ sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data);
+ }
+
+ if (!*aead) {
+ *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
+ if (IS_ERR(*aead)) {
+ rc = PTR_ERR(*aead);
+ *aead = NULL;
goto free_rec_seq;
}
}
memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
- rc = crypto_aead_setkey(sw_ctx->aead_send, keyval,
+ rc = crypto_aead_setkey(*aead, keyval,
TLS_CIPHER_AES_GCM_128_KEY_SIZE);
if (rc)
goto free_aead;
- rc = crypto_aead_setauthsize(sw_ctx->aead_send, ctx->tx.tag_size);
- if (!rc)
- return 0;
+ rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
+ if (rc)
+ goto free_aead;
+
+ if (!tx) {
+ /* Set up strparser */
+ memset(&cb, 0, sizeof(cb));
+ cb.rcv_msg = tls_queue;
+ cb.parse_msg = tls_read_size;
+
+ strp_init(&sw_ctx->strp, sk, &cb);
+
+ write_lock_bh(&sk->sk_callback_lock);
+ sw_ctx->saved_data_ready = sk->sk_data_ready;
+ sk->sk_data_ready = tls_data_ready;
+ write_unlock_bh(&sk->sk_callback_lock);
+
+ sw_ctx->sk_poll = sk->sk_socket->ops->poll;
+
+ strp_check_rcv(&sw_ctx->strp);
+ }
+
+ goto out;
free_aead:
- crypto_free_aead(sw_ctx->aead_send);
- sw_ctx->aead_send = NULL;
+ crypto_free_aead(*aead);
+ *aead = NULL;
free_rec_seq:
- kfree(ctx->tx.rec_seq);
- ctx->tx.rec_seq = NULL;
+ kfree(cctx->rec_seq);
+ cctx->rec_seq = NULL;
free_iv:
kfree(ctx->tx.iv);
ctx->tx.iv = NULL;