From f1698a337eb2b511835cd6cc38bf76d8e776ffa1 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Sun, 28 Apr 2019 13:26:22 +0000 Subject: widgets/msglist: fix MessageList.store race This field could be written to in the middle of a Draw call, which reads it multiple times. Use an atomic variable instead. --- widgets/msglist.go | 55 +++++++++++++++++++++++++++++++----------------------- 1 file changed, 32 insertions(+), 23 deletions(-) (limited to 'widgets') diff --git a/widgets/msglist.go b/widgets/msglist.go index c4b5d82..4f853f3 100644 --- a/widgets/msglist.go +++ b/widgets/msglist.go @@ -2,6 +2,7 @@ package widgets import ( "log" + "sync/atomic" "github.com/gdamore/tcell" @@ -19,7 +20,7 @@ type MessageList struct { scroll int selected int spinner *Spinner - store *lib.MessageStore + store atomic.Value // *lib.MessageStore } // TODO: fish in config @@ -29,6 +30,7 @@ func NewMessageList(logger *log.Logger) *MessageList { selected: 0, spinner: NewSpinner(), } + ml.store.Store((*lib.MessageStore)(nil)) ml.spinner.OnInvalidate(func(_ ui.Drawable) { ml.Invalidate() }) @@ -45,7 +47,8 @@ func (ml *MessageList) Draw(ctx *ui.Context) { ml.height = ctx.Height() ctx.Fill(0, 0, ctx.Width(), ctx.Height(), ' ', tcell.StyleDefault) - if ml.store == nil { + store := ml.Store() + if store == nil { ml.spinner.Draw(ctx) return } @@ -55,9 +58,9 @@ func (ml *MessageList) Draw(ctx *ui.Context) { row int = 0 ) - for i := len(ml.store.Uids) - 1 - ml.scroll; i >= 0; i-- { - uid := ml.store.Uids[i] - msg := ml.store.Messages[uid] + for i := len(store.Uids) - 1 - ml.scroll; i >= 0; i-- { + uid := store.Uids[i] + msg := store.Messages[uid] if row >= ctx.Height() { break @@ -74,7 +77,7 @@ func (ml *MessageList) Draw(ctx *ui.Context) { if row == ml.selected-ml.scroll { style = style.Reverse(true) } - if _, ok := ml.store.Deleted[msg.Uid]; ok { + if _, ok := store.Deleted[msg.Uid]; ok { style = style.Foreground(tcell.ColorGray) } ctx.Fill(0, row, ctx.Width(), 1, ' ', style) @@ -83,14 +86,14 @@ func (ml *MessageList) Draw(ctx *ui.Context) { row += 1 } - if len(ml.store.Uids) == 0 { + if len(store.Uids) == 0 { msg := "(no messages)" ctx.Printf((ctx.Width()/2)-(len(msg)/2), 0, tcell.StyleDefault, "%s", msg) } if len(needsHeaders) != 0 { - ml.store.FetchHeaders(needsHeaders, nil) + store.FetchHeaders(needsHeaders, nil) ml.spinner.Start() } else { ml.spinner.Stop() @@ -102,26 +105,28 @@ func (ml *MessageList) Height() int { } func (ml *MessageList) storeUpdate(store *lib.MessageStore) { - if ml.store != store { + if ml.Store() != store { return } - if len(ml.store.Uids) > 0 { - for ml.selected >= len(ml.store.Uids) { + + if len(store.Uids) > 0 { + for ml.selected >= len(store.Uids) { ml.Prev() } } + ml.Invalidate() } func (ml *MessageList) SetStore(store *lib.MessageStore) { - if ml.store == store { + if ml.Store() == store { ml.scroll = 0 ml.selected = 0 } - ml.store = store + ml.store.Store(store) if store != nil { ml.spinner.Stop() - ml.store.OnUpdate(ml.storeUpdate) + store.OnUpdate(ml.storeUpdate) } else { ml.spinner.Start() } @@ -129,23 +134,26 @@ func (ml *MessageList) SetStore(store *lib.MessageStore) { } func (ml *MessageList) Store() *lib.MessageStore { - return ml.store + return ml.store.Load().(*lib.MessageStore) } func (ml *MessageList) Empty() bool { - return ml.store == nil || len(ml.store.Uids) == 0 + store := ml.Store() + return store == nil || len(store.Uids) == 0 } func (ml *MessageList) Selected() *types.MessageInfo { - return ml.store.Messages[ml.store.Uids[len(ml.store.Uids)-ml.selected-1]] + store := ml.Store() + return store.Messages[store.Uids[len(store.Uids)-ml.selected-1]] } func (ml *MessageList) Select(index int) { + store := ml.Store() ml.selected = index - for ; ml.selected < 0; ml.selected = len(ml.store.Uids) + ml.selected { + for ; ml.selected < 0; ml.selected = len(store.Uids) + ml.selected { } - if ml.selected > len(ml.store.Uids) { - ml.selected = len(ml.store.Uids) + if ml.selected > len(store.Uids) { + ml.selected = len(store.Uids) } // I'm too lazy to do the math right now for ml.selected-ml.scroll >= ml.Height() { @@ -157,15 +165,16 @@ func (ml *MessageList) Select(index int) { } func (ml *MessageList) nextPrev(delta int) { - if ml.store == nil || len(ml.store.Uids) == 0 { + store := ml.Store() + if store == nil || len(store.Uids) == 0 { return } ml.selected += delta if ml.selected < 0 { ml.selected = 0 } - if ml.selected >= len(ml.store.Uids) { - ml.selected = len(ml.store.Uids) - 1 + if ml.selected >= len(store.Uids) { + ml.selected = len(store.Uids) - 1 } if ml.Height() != 0 { if ml.selected-ml.scroll >= ml.Height() { -- cgit v1.2.3