]> git.baikalelectronics.ru Git - kernel.git/commitdiff
cifs: protect session channel fields with chan_lock
authorShyam Prasad N <sprasad@microsoft.com>
Mon, 19 Jul 2021 10:54:46 +0000 (10:54 +0000)
committerSteve French <stfrench@microsoft.com>
Fri, 12 Nov 2021 22:22:20 +0000 (16:22 -0600)
Introducing a new spin lock to protect all the channel related
fields in a cifs_ses struct. This lock should be taken
whenever dealing with the channel fields, and should be held
only for very short intervals which will not sleep.

Currently, all channel related fields in cifs_ses structure
are protected by session_mutex. However, this mutex is held for
long periods (sometimes while waiting for a reply from server).
This makes the codepath quite tricky to change.

Signed-off-by: Shyam Prasad N <sprasad@microsoft.com>
Reviewed-by: Paulo Alcantara (SUSE) <pc@cjr.nz>
Signed-off-by: Steve French <stfrench@microsoft.com>
fs/cifs/cifs_debug.c
fs/cifs/cifsglob.h
fs/cifs/connect.c
fs/cifs/misc.c
fs/cifs/sess.c
fs/cifs/transport.c

index 905a901f7f80b9af3b68dcbedb00afc07066ba01..248a8f973cf9c11cd23f2cd68ec2a84d8fefdce0 100644 (file)
@@ -414,12 +414,14 @@ skip_rdma:
                                   from_kuid(&init_user_ns, ses->linux_uid),
                                   from_kuid(&init_user_ns, ses->cred_uid));
 
+                       spin_lock(&ses->chan_lock);
                        if (ses->chan_count > 1) {
                                seq_printf(m, "\n\n\tExtra Channels: %zu ",
                                           ses->chan_count-1);
                                for (j = 1; j < ses->chan_count; j++)
                                        cifs_dump_channel(m, j, &ses->chans[j]);
                        }
+                       spin_unlock(&ses->chan_lock);
 
                        seq_puts(m, "\n\n\tShares: ");
                        j = 0;
index 673ecbfe6d9a4eed043fce0bc45c781d61f99002..3c18c56a119ba649fcb0699b2fa170c73860090f 100644 (file)
@@ -952,16 +952,21 @@ struct cifs_ses {
         * iface_lock should be taken when accessing any of these fields
         */
        spinlock_t iface_lock;
+       /* ========= begin: protected by iface_lock ======== */
        struct cifs_server_iface *iface_list;
        size_t iface_count;
        unsigned long iface_last_update; /* jiffies */
+       /* ========= end: protected by iface_lock ======== */
 
+       spinlock_t chan_lock;
+       /* ========= begin: protected by chan_lock ======== */
 #define CIFS_MAX_CHANNELS 16
        struct cifs_chan chans[CIFS_MAX_CHANNELS];
        struct cifs_chan *binding_chan;
        size_t chan_count;
        size_t chan_max;
        atomic_t chan_seq; /* round robin state */
+       /* ========= end: protected by chan_lock ======== */
 };
 
 /*
index 498ec05ca10d588d44f199d9c51de94b365a23fe..09a532ed8fe42eb657fef0e758fb9d78cfc9329c 100644 (file)
@@ -1577,8 +1577,12 @@ static int match_session(struct cifs_ses *ses, struct smb3_fs_context *ctx)
         * If an existing session is limited to less channels than
         * requested, it should not be reused
         */
-       if (ses->chan_max < ctx->max_channels)
+       spin_lock(&ses->chan_lock);
+       if (ses->chan_max < ctx->max_channels) {
+               spin_unlock(&ses->chan_lock);
                return 0;
+       }
+       spin_unlock(&ses->chan_lock);
 
        switch (ses->sectype) {
        case Kerberos:
@@ -1713,6 +1717,7 @@ cifs_find_smb_ses(struct TCP_Server_Info *server, struct smb3_fs_context *ctx)
 void cifs_put_smb_ses(struct cifs_ses *ses)
 {
        unsigned int rc, xid;
+       unsigned int chan_count;
        struct TCP_Server_Info *server = ses->server;
        cifs_dbg(FYI, "%s: ses_count=%d\n", __func__, ses->ses_count);
 
@@ -1754,12 +1759,24 @@ void cifs_put_smb_ses(struct cifs_ses *ses)
        list_del_init(&ses->smb_ses_list);
        spin_unlock(&cifs_tcp_ses_lock);
 
+       spin_lock(&ses->chan_lock);
+       chan_count = ses->chan_count;
+       spin_unlock(&ses->chan_lock);
+
        /* close any extra channels */
-       if (ses->chan_count > 1) {
+       if (chan_count > 1) {
                int i;
 
-               for (i = 1; i < ses->chan_count; i++)
+               for (i = 1; i < chan_count; i++) {
+                       /*
+                        * note: for now, we're okay accessing ses->chans
+                        * without chan_lock. But when chans can go away, we'll
+                        * need to introduce ref counting to make sure that chan
+                        * is not freed from under us.
+                        */
                        cifs_put_tcp_session(ses->chans[i].server, 0);
+                       ses->chans[i].server = NULL;
+               }
        }
 
        sesInfoFree(ses);
@@ -2018,9 +2035,11 @@ cifs_get_smb_ses(struct TCP_Server_Info *server, struct smb3_fs_context *ctx)
        mutex_lock(&ses->session_mutex);
 
        /* add server as first channel */
+       spin_lock(&ses->chan_lock);
        ses->chans[0].server = server;
        ses->chan_count = 1;
        ses->chan_max = ctx->multichannel ? ctx->max_channels:1;
+       spin_unlock(&ses->chan_lock);
 
        rc = cifs_negotiate_protocol(xid, ses);
        if (!rc)
index a6089ea53ad7e8e0cf8e4ed6a823b2c12cf55147..5148d48d6a357e0f231ce1e278d0c67ec42bacc9 100644 (file)
@@ -75,6 +75,7 @@ sesInfoAlloc(void)
                INIT_LIST_HEAD(&ret_buf->tcon_list);
                mutex_init(&ret_buf->session_mutex);
                spin_lock_init(&ret_buf->iface_lock);
+               spin_lock_init(&ret_buf->chan_lock);
        }
        return ret_buf;
 }
index 93a1619d60e6e7716e41674d338c1d3d73d1533d..27660bffb91884005343b4451d4ea3421a62d29d 100644 (file)
@@ -54,41 +54,53 @@ bool is_ses_using_iface(struct cifs_ses *ses, struct cifs_server_iface *iface)
 {
        int i;
 
+       spin_lock(&ses->chan_lock);
        for (i = 0; i < ses->chan_count; i++) {
-               if (is_server_using_iface(ses->chans[i].server, iface))
+               if (is_server_using_iface(ses->chans[i].server, iface)) {
+                       spin_unlock(&ses->chan_lock);
                        return true;
+               }
        }
+       spin_unlock(&ses->chan_lock);
        return false;
 }
 
 /* returns number of channels added */
 int cifs_try_adding_channels(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses)
 {
-       int old_chan_count = ses->chan_count;
-       int left = ses->chan_max - ses->chan_count;
+       int old_chan_count, new_chan_count;
+       int left;
        int i = 0;
        int rc = 0;
        int tries = 0;
        struct cifs_server_iface *ifaces = NULL;
        size_t iface_count;
 
+       if (ses->server->dialect < SMB30_PROT_ID) {
+               cifs_dbg(VFS, "multichannel is not supported on this protocol version, use 3.0 or above\n");
+               return 0;
+       }
+
+       spin_lock(&ses->chan_lock);
+
+       new_chan_count = old_chan_count = ses->chan_count;
+       left = ses->chan_max - ses->chan_count;
+
        if (left <= 0) {
                cifs_dbg(FYI,
                         "ses already at max_channels (%zu), nothing to open\n",
                         ses->chan_max);
-               return 0;
-       }
-
-       if (ses->server->dialect < SMB30_PROT_ID) {
-               cifs_dbg(VFS, "multichannel is not supported on this protocol version, use 3.0 or above\n");
+               spin_unlock(&ses->chan_lock);
                return 0;
        }
 
        if (!(ses->server->capabilities & SMB2_GLOBAL_CAP_MULTI_CHANNEL)) {
                cifs_dbg(VFS, "server %s does not support multichannel\n", ses->server->hostname);
                ses->chan_max = 1;
+               spin_unlock(&ses->chan_lock);
                return 0;
        }
+       spin_unlock(&ses->chan_lock);
 
        /*
         * Make a copy of the iface list at the time and use that
@@ -142,10 +154,11 @@ int cifs_try_adding_channels(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses)
                cifs_dbg(FYI, "successfully opened new channel on iface#%d\n",
                         i);
                left--;
+               new_chan_count++;
        }
 
        kfree(ifaces);
-       return ses->chan_count - old_chan_count;
+       return new_chan_count - old_chan_count;
 }
 
 /*
@@ -157,10 +170,14 @@ cifs_ses_find_chan(struct cifs_ses *ses, struct TCP_Server_Info *server)
 {
        int i;
 
+       spin_lock(&ses->chan_lock);
        for (i = 0; i < ses->chan_count; i++) {
-               if (ses->chans[i].server == server)
+               if (ses->chans[i].server == server) {
+                       spin_unlock(&ses->chan_lock);
                        return &ses->chans[i];
+               }
        }
+       spin_unlock(&ses->chan_lock);
        return NULL;
 }
 
@@ -168,6 +185,7 @@ static int
 cifs_ses_add_channel(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses,
                     struct cifs_server_iface *iface)
 {
+       struct TCP_Server_Info *chan_server;
        struct cifs_chan *chan;
        struct smb3_fs_context ctx = {NULL};
        static const char unc_fmt[] = "\\%s\\foo";
@@ -240,15 +258,20 @@ cifs_ses_add_channel(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses,
               SMB2_CLIENT_GUID_SIZE);
        ctx.use_client_guid = true;
 
-       mutex_lock(&ses->session_mutex);
+       chan_server = cifs_get_tcp_session(&ctx);
 
+       mutex_lock(&ses->session_mutex);
+       spin_lock(&ses->chan_lock);
        chan = ses->binding_chan = &ses->chans[ses->chan_count];
-       chan->server = cifs_get_tcp_session(&ctx);
+       chan->server = chan_server;
        if (IS_ERR(chan->server)) {
                rc = PTR_ERR(chan->server);
                chan->server = NULL;
+               spin_unlock(&ses->chan_lock);
                goto out;
        }
+       spin_unlock(&ses->chan_lock);
+
        spin_lock(&cifs_tcp_ses_lock);
        chan->server->is_channel = true;
        spin_unlock(&cifs_tcp_ses_lock);
@@ -283,8 +306,11 @@ cifs_ses_add_channel(struct cifs_sb_info *cifs_sb, struct cifs_ses *ses,
         * ses to the new server.
         */
 
+       spin_lock(&ses->chan_lock);
        ses->chan_count++;
        atomic_set(&ses->chan_seq, 0);
+       spin_unlock(&ses->chan_lock);
+
 out:
        ses->binding = false;
        ses->binding_chan = NULL;
index b7379329b741c92dee666011d4d42d0b79992a9e..61ea3d3f95b4a61a5eca77c78dadf7bcebb107b9 100644 (file)
@@ -1044,14 +1044,17 @@ struct TCP_Server_Info *cifs_pick_channel(struct cifs_ses *ses)
        if (!ses)
                return NULL;
 
+       spin_lock(&ses->chan_lock);
        if (!ses->binding) {
                /* round robin */
                if (ses->chan_count > 1) {
                        index = (uint)atomic_inc_return(&ses->chan_seq);
                        index %= ses->chan_count;
                }
+               spin_unlock(&ses->chan_lock);
                return ses->chans[index].server;
        } else {
+               spin_unlock(&ses->chan_lock);
                return cifs_ses_server(ses);
        }
 }