Last active
November 30, 2020 06:01
-
-
Save zombiezen/85095566a5c025e35817a0b399fec235 to your computer and use it in GitHub Desktop.
Reading while respecting Context.Done
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// ctxReader adds timeouts and cancellation to a reader. | |
type ctxReader struct { | |
r io.Reader | |
ctx context.Context // set to change Context | |
// internal state | |
result chan readResult | |
pos, n int | |
err error | |
buf [1024]byte | |
} | |
// readResult carries the outcome of a single Read call on the underlying
// reader from the goroutine that performed it back to the ctxReader.
type readResult struct {
	n   int   // number of bytes read
	err error // error returned alongside those bytes, if any
}
// Read reads into p. It makes a best effort to respect the Done signal | |
// in cr.ctx. | |
func (cr *ctxReader) Read(p []byte) (int, error) { | |
if cr.pos < cr.n { | |
// Buffered from previous read. | |
n := copy(p, cr.buf[cr.pos:cr.n]) | |
cr.pos += n | |
if cr.pos == cr.n && cr.err != nil { | |
err := cr.err | |
cr.err = nil | |
return n, err | |
} | |
return n, nil | |
} | |
if cr.result != nil { | |
// Read in progress. | |
select { | |
case r := <-cr.result: | |
cr.result = nil | |
cr.n = r.n | |
cr.pos = copy(p, cr.buf[:cr.n]) | |
if cr.pos == cr.n && r.err != nil { | |
return cr.pos, r.err | |
} | |
cr.err = r.err | |
return cr.pos, nil | |
case <-cr.ctx.Done(): | |
return 0, cr.ctx.Err() | |
} | |
} | |
// Check for early cancel. | |
select { | |
case <-cr.ctx.Done(): | |
return 0, cr.ctx.Err() | |
default: | |
} | |
// Check for timeout support. | |
rd, ok := cr.r.(interface { | |
SetReadDeadline(time.Time) error | |
}) | |
if !ok { | |
return cr.leakyRead(p) | |
} | |
if err := rd.SetReadDeadline(time.Now()); err != nil { | |
return cr.leakyRead(p) | |
} | |
// Start separate goroutine to wait on Context.Done. | |
if d, ok := cr.ctx.Deadline(); ok { | |
rd.SetReadDeadline(d) | |
} else { | |
rd.SetReadDeadline(time.Time{}) | |
} | |
readDone := make(chan struct{}) | |
listenDone := make(chan struct{}) | |
go func() { | |
defer close(listenDone) | |
select { | |
case <-cr.ctx.Done(): | |
rd.SetReadDeadline(time.Now()) // interrupt read | |
case <-readDone: | |
} | |
}() | |
// Read from reader. | |
n, err := cr.r.Read(p) | |
close(readDone) | |
<-listenDone | |
return n, err | |
} | |
// leakyRead reads from the underlying reader in a separate goroutine. | |
// If the Context is Done before the read completes, then the goroutine | |
// will stay alive until cr.wait() is called. The result is written to | |
// cr.buf to avoid retaining p past the end of leakyRead. | |
func (cr *ctxReader) leakyRead(p []byte) (int, error) { | |
cr.result = make(chan readResult) | |
max := len(p) | |
if max > len(cr.buf) { | |
max = len(cr.buf) | |
} | |
go func() { | |
n, err := cr.r.Read(cr.buf[:max]) | |
cr.result <- readResult{n, err} | |
}() | |
select { | |
case r := <-cr.result: | |
cr.result = nil | |
copy(p, cr.buf[:r.n]) | |
return r.n, r.err | |
case <-cr.ctx.Done(): | |
return 0, cr.ctx.Err() | |
} | |
} | |
// wait waits until any goroutine started by leakyRead finishes. | |
func (cr *ctxReader) wait() { | |
if cr.result == nil { | |
return | |
} | |
r := <-cr.result | |
cr.result = nil | |
cr.pos, cr.n = 0, r.n | |
cr.err = r.err | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment