+// Copyright 2019 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package memoize supports memoizing the return values of functions with
+// idempotent results that are expensive to compute.
+//
+// To use this package, build a store and use it to acquire handles with the
+// Bind method.
+//
+package memoize
+
+import (
+ "context"
+ "flag"
+ "fmt"
+ "reflect"
+ "sync"
+ "sync/atomic"
+
+ "golang.org/x/tools/internal/xcontext"
+)
+
+var (
+ panicOnDestroyed = flag.Bool("memoize_panic_on_destroyed", false,
+ "Panic when a destroyed generation is read rather than returning an error. "+
+ "Panicking may make it easier to debug lifetime errors, especially when "+
+ "used with GOTRACEBACK=crash to see all running goroutines.")
+)
+
+// Store binds keys to functions, returning handles that can be used to access
+// the functions results.
+type Store struct {
+ mu sync.Mutex
+ // handles is the set of values stored.
+ handles map[interface{}]*Handle
+
+ // generations is the set of generations live in this store.
+ generations map[*Generation]struct{}
+}
+
+// Generation creates a new Generation associated with s. Destroy must be
+// called on the returned Generation once it is no longer in use. name is
+// for debugging purposes only.
+func (s *Store) Generation(name string) *Generation {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ if s.handles == nil {
+ s.handles = map[interface{}]*Handle{}
+ s.generations = map[*Generation]struct{}{}
+ }
+ g := &Generation{store: s, name: name}
+ s.generations[g] = struct{}{}
+ return g
+}
+
+// A Generation is a logical point in time of the cache life-cycle. Cache
+// entries associated with a Generation will not be removed until the
+// Generation is destroyed.
+type Generation struct {
+ // destroyed is 1 after the generation is destroyed. Atomic.
+ destroyed uint32
+ store *Store
+ name string
+ // wg tracks the reference count of this generation.
+ wg sync.WaitGroup
+}
+
+// Destroy waits for all operations referencing g to complete, then removes
+// all references to g from cache entries. Cache entries that no longer
+// reference any non-destroyed generation are removed. Destroy must be called
+// exactly once for each generation.
+func (g *Generation) Destroy() {
+ g.wg.Wait()
+ atomic.StoreUint32(&g.destroyed, 1)
+ g.store.mu.Lock()
+ defer g.store.mu.Unlock()
+ for k, e := range g.store.handles {
+ e.mu.Lock()
+ if _, ok := e.generations[g]; ok {
+ delete(e.generations, g) // delete even if it's dead, in case of dangling references to the entry.
+ if len(e.generations) == 0 {
+ delete(g.store.handles, k)
+ e.state = stateDestroyed
+ }
+ }
+ e.mu.Unlock()
+ }
+ delete(g.store.generations, g)
+}
+
+// Acquire creates a new reference to g, and returns a func to release that
+// reference.
+func (g *Generation) Acquire(ctx context.Context) func() {
+ destroyed := atomic.LoadUint32(&g.destroyed)
+ if ctx.Err() != nil {
+ return func() {}
+ }
+ if destroyed != 0 {
+ panic("acquire on destroyed generation " + g.name)
+ }
+ g.wg.Add(1)
+ return g.wg.Done
+}
+
+// Arg is a marker interface that can be embedded to indicate a type is
+// intended for use as a Function argument.
+type Arg interface{ memoizeArg() }
+
+// Function is the type for functions that can be memoized.
+// The result must be a pointer.
+type Function func(ctx context.Context, arg Arg) interface{}
+
+type state int
+
+const (
+ stateIdle = iota
+ stateRunning
+ stateCompleted
+ stateDestroyed
+)
+
+// Handle is returned from a store when a key is bound to a function.
+// It is then used to access the results of that function.
+//
+// A Handle starts out in idle state, waiting for something to demand its
+// evaluation. It then transitions into running state. While it's running,
+// waiters tracks the number of Get calls waiting for a result, and the done
+// channel is used to notify waiters of the next state transition. Once the
+// evaluation finishes, value is set, state changes to completed, and done
+// is closed, unblocking waiters. Alternatively, as Get calls are cancelled,
+// they decrement waiters. If it drops to zero, the inner context is cancelled,
+// computation is abandoned, and state resets to idle to start the process over
+// again.
+type Handle struct {
+ key interface{}
+ mu sync.Mutex
+
+ // generations is the set of generations in which this handle is valid.
+ generations map[*Generation]struct{}
+
+ state state
+ // done is set in running state, and closed when exiting it.
+ done chan struct{}
+ // cancel is set in running state. It cancels computation.
+ cancel context.CancelFunc
+ // waiters is the number of Gets outstanding.
+ waiters uint
+ // the function that will be used to populate the value
+ function Function
+ // value is set in completed state.
+ value interface{}
+}
+
+// Bind returns a handle for the given key and function.
+//
+// Each call to bind will return the same handle if it is already bound.
+// Bind will always return a valid handle, creating one if needed.
+// Each key can only have one handle at any given time.
+// The value will be held at least until the associated generation is destroyed.
+// Bind does not cause the value to be generated.
+func (g *Generation) Bind(key interface{}, function Function) *Handle {
+ // panic early if the function is nil
+ // it would panic later anyway, but in a way that was much harder to debug
+ if function == nil {
+ panic("the function passed to bind must not be nil")
+ }
+ if atomic.LoadUint32(&g.destroyed) != 0 {
+ panic("operation on destroyed generation " + g.name)
+ }
+ g.store.mu.Lock()
+ defer g.store.mu.Unlock()
+ h, ok := g.store.handles[key]
+ if !ok {
+ h := &Handle{
+ key: key,
+ function: function,
+ generations: map[*Generation]struct{}{g: {}},
+ }
+ g.store.handles[key] = h
+ return h
+ }
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if _, ok := h.generations[g]; !ok {
+ h.generations[g] = struct{}{}
+ }
+ return h
+}
+
+// Stats returns the number of each type of value in the store.
+func (s *Store) Stats() map[reflect.Type]int {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ result := map[reflect.Type]int{}
+ for k := range s.handles {
+ result[reflect.TypeOf(k)]++
+ }
+ return result
+}
+
+// DebugOnlyIterate iterates through all live cache entries and calls f on them.
+// It should only be used for debugging purposes.
+func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ for k, e := range s.handles {
+ var v interface{}
+ e.mu.Lock()
+ if e.state == stateCompleted {
+ v = e.value
+ }
+ e.mu.Unlock()
+ if v == nil {
+ continue
+ }
+ f(k, v)
+ }
+}
+
+func (g *Generation) Inherit(h *Handle) {
+ if atomic.LoadUint32(&g.destroyed) != 0 {
+ panic("inherit on destroyed generation " + g.name)
+ }
+
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if h.state == stateDestroyed {
+ panic(fmt.Sprintf("inheriting destroyed handle %#v (type %T) into generation %v", h.key, h.key, g.name))
+ }
+ h.generations[g] = struct{}{}
+}
+
+// Cached returns the value associated with a handle.
+//
+// It will never cause the value to be generated.
+// It will return the cached value, if present.
+func (h *Handle) Cached(g *Generation) interface{} {
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if _, ok := h.generations[g]; !ok {
+ return nil
+ }
+ if h.state == stateCompleted {
+ return h.value
+ }
+ return nil
+}
+
+// Get returns the value associated with a handle.
+//
+// If the value is not yet ready, the underlying function will be invoked.
+// If ctx is cancelled, Get returns nil.
+func (h *Handle) Get(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
+ release := g.Acquire(ctx)
+ defer release()
+
+ if ctx.Err() != nil {
+ return nil, ctx.Err()
+ }
+ h.mu.Lock()
+ if _, ok := h.generations[g]; !ok {
+ h.mu.Unlock()
+
+ err := fmt.Errorf("reading key %#v: generation %v is not known", h.key, g.name)
+ if *panicOnDestroyed && ctx.Err() != nil {
+ panic(err)
+ }
+ return nil, err
+ }
+ switch h.state {
+ case stateIdle:
+ return h.run(ctx, g, arg)
+ case stateRunning:
+ return h.wait(ctx)
+ case stateCompleted:
+ defer h.mu.Unlock()
+ return h.value, nil
+ case stateDestroyed:
+ h.mu.Unlock()
+ err := fmt.Errorf("Get on destroyed entry %#v (type %T) in generation %v", h.key, h.key, g.name)
+ if *panicOnDestroyed {
+ panic(err)
+ }
+ return nil, err
+ default:
+ panic("unknown state")
+ }
+}
+
+// run starts h.function and returns the result. h.mu must be locked.
+func (h *Handle) run(ctx context.Context, g *Generation, arg Arg) (interface{}, error) {
+ childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
+ h.cancel = cancel
+ h.state = stateRunning
+ h.done = make(chan struct{})
+ function := h.function // Read under the lock
+
+ // Make sure that the generation isn't destroyed while we're running in it.
+ release := g.Acquire(ctx)
+ go func() {
+ defer release()
+ // Just in case the function does something expensive without checking
+ // the context, double-check we're still alive.
+ if childCtx.Err() != nil {
+ return
+ }
+ v := function(childCtx, arg)
+ if childCtx.Err() != nil {
+ return
+ }
+
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ // It's theoretically possible that the handle has been cancelled out
+ // of the run that started us, and then started running again since we
+ // checked childCtx above. Even so, that should be harmless, since each
+ // run should produce the same results.
+ if h.state != stateRunning {
+ return
+ }
+ h.value = v
+ h.function = nil
+ h.state = stateCompleted
+ close(h.done)
+ }()
+
+ return h.wait(ctx)
+}
+
+// wait waits for the value to be computed, or ctx to be cancelled. h.mu must be locked.
+func (h *Handle) wait(ctx context.Context) (interface{}, error) {
+ h.waiters++
+ done := h.done
+ h.mu.Unlock()
+
+ select {
+ case <-done:
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ if h.state == stateCompleted {
+ return h.value, nil
+ }
+ return nil, nil
+ case <-ctx.Done():
+ h.mu.Lock()
+ defer h.mu.Unlock()
+ h.waiters--
+ if h.waiters == 0 && h.state == stateRunning {
+ h.cancel()
+ close(h.done)
+ h.state = stateIdle
+ h.done = nil
+ h.cancel = nil
+ }
+ return nil, ctx.Err()
+ }
+}