Skip to content

Commit

Permalink
http2: implement client initiated graceful shutdown
Browse files Browse the repository at this point in the history
Sends a GOAWAY frame and wait for the in-flight streams to complete.

Fixes golang/go#17292

Change-Id: I2b7dd61446f4ffd9c820fbb21d1233c3b3ad1ba8
Reviewed-on: https://go-review.googlesource.com/30076
Run-TryBot: Brad Fitzpatrick <[email protected]>
TryBot-Result: Gobot Gobot <[email protected]>
Reviewed-by: Brad Fitzpatrick <[email protected]>
  • Loading branch information
rs authored and bradfitz committed Jul 9, 2018
1 parent c4e4b2a commit b87faa7
Show file tree
Hide file tree
Showing 4 changed files with 294 additions and 7 deletions.
7 changes: 7 additions & 0 deletions http2/go17.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ type contextContext interface {
context.Context
}

var errCanceled = context.Canceled

func serverConnBaseContext(c net.Conn, opts *ServeConnOpts) (ctx contextContext, cancel func()) {
ctx, cancel = context.WithCancel(context.Background())
ctx = context.WithValue(ctx, http.LocalAddrContextKey, c.LocalAddr())
Expand Down Expand Up @@ -104,3 +106,8 @@ func requestTrace(req *http.Request) *clientTrace {
func (cc *ClientConn) Ping(ctx context.Context) error {
return cc.ping(ctx)
}

// Shutdown gracefully closes the client connection, waiting for running streams to complete.
func (cc *ClientConn) Shutdown(ctx context.Context) error {
return cc.shutdown(ctx)
}
7 changes: 7 additions & 0 deletions http2/not_go17.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ package http2

import (
"crypto/tls"
"errors"
"net"
"net/http"
"time"
Expand All @@ -18,6 +19,8 @@ type contextContext interface {
Err() error
}

var errCanceled = errors.New("canceled")

type fakeContext struct{}

func (fakeContext) Done() <-chan struct{} { return nil }
Expand Down Expand Up @@ -84,4 +87,8 @@ func (cc *ClientConn) Ping(ctx contextContext) error {
return cc.ping(ctx)
}

func (cc *ClientConn) Shutdown(ctx contextContext) error {
return cc.shutdown(ctx)
}

func (t *Transport) idleConnTimeout() time.Duration { return 0 }
85 changes: 84 additions & 1 deletion http2/transport.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,7 @@ type ClientConn struct {
cond *sync.Cond // hold mu; broadcast on flow/closed changes
flow flow // our conn-level flow control quota (cs.flow is per stream)
inflow flow // peer's conn-level flow control
closing bool
closed bool
wantSettingsAck bool // we sent a SETTINGS frame and haven't heard back
goAway *GoAwayFrame // if non-nil, the GoAwayFrame we received
Expand Down Expand Up @@ -634,7 +635,7 @@ func (cc *ClientConn) canTakeNewRequestLocked() bool {
if cc.singleUse && cc.nextStreamID > 1 {
return false
}
return cc.goAway == nil && !cc.closed &&
return cc.goAway == nil && !cc.closed && !cc.closing &&
int64(cc.nextStreamID)+int64(cc.pendingRequests) < math.MaxInt32
}

Expand Down Expand Up @@ -665,6 +666,88 @@ func (cc *ClientConn) closeIfIdle() {
cc.tconn.Close()
}

var shutdownEnterWaitStateHook = func() {}

// Shutdown gracefully close the client connection, waiting for running streams to complete.
// Public implementation is in go17.go and not_go17.go
func (cc *ClientConn) shutdown(ctx contextContext) error {
if err := cc.sendGoAway(); err != nil {
return err
}
// Wait for all in-flight streams to complete or connection to close
done := make(chan error, 1)
cancelled := false // guarded by cc.mu
go func() {
cc.mu.Lock()
defer cc.mu.Unlock()
for {
if len(cc.streams) == 0 || cc.closed {
cc.closed = true
done <- cc.tconn.Close()
break
}
if cancelled {
break
}
cc.cond.Wait()
}
}()
shutdownEnterWaitStateHook()
select {
case err := <-done:
return err
case <-ctx.Done():
cc.mu.Lock()
// Free the goroutine above
cancelled = true
cc.cond.Broadcast()
cc.mu.Unlock()
return ctx.Err()
}
}

func (cc *ClientConn) sendGoAway() error {
cc.mu.Lock()
defer cc.mu.Unlock()
cc.wmu.Lock()
defer cc.wmu.Unlock()
if cc.closing {
// GOAWAY sent already
return nil
}
// Send a graceful shutdown frame to server
maxStreamID := cc.nextStreamID
if err := cc.fr.WriteGoAway(maxStreamID, ErrCodeNo, nil); err != nil {
return err
}
if err := cc.bw.Flush(); err != nil {
return err
}
// Prevent new requests
cc.closing = true
return nil
}

// Close closes the client connection immediately.
//
// In-flight requests are interrupted. For a graceful shutdown, use Shutdown instead.
func (cc *ClientConn) Close() error {
cc.mu.Lock()
defer cc.cond.Broadcast()
defer cc.mu.Unlock()
err := errors.New("http2: client connection force closed via ClientConn.Close")
for id, cs := range cc.streams {
select {
case cs.resc <- resAndError{err: err}:
default:
}
cs.bufPipe.CloseWithError(err)
delete(cc.streams, id)
}
cc.closed = true
return cc.tconn.Close()
}

const maxAllocFrameSize = 512 << 10

// frameBuffer returns a scratch buffer suitable for writing DATA frames.
Expand Down
202 changes: 196 additions & 6 deletions http2/transport_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (
"testing"
"time"

"golang.org/x/net/context"
"golang.org/x/net/http2/hpack"
)

Expand All @@ -41,12 +42,13 @@ var (

var tlsConfigInsecure = &tls.Config{InsecureSkipVerify: true}

type testContext struct{}
var canceledCtx context.Context

func (testContext) Done() <-chan struct{} { return make(chan struct{}) }
func (testContext) Err() error { panic("should not be called") }
func (testContext) Deadline() (deadline time.Time, ok bool) { return time.Time{}, false }
func (testContext) Value(key interface{}) interface{} { return nil }
func init() {
ctx, cancel := context.WithCancel(context.Background())
cancel()
canceledCtx = ctx
}

func TestTransportExternal(t *testing.T) {
if !*extNet {
Expand Down Expand Up @@ -3054,7 +3056,7 @@ func TestClientConnPing(t *testing.T) {
if err != nil {
t.Fatal(err)
}
if err = cc.Ping(testContext{}); err != nil {
if err = cc.Ping(context.Background()); err != nil {
t.Fatal(err)
}
}
Expand Down Expand Up @@ -3856,3 +3858,191 @@ func BenchmarkClientRequestHeaders(b *testing.B) {
b.Run(" 100 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 100) })
b.Run("1000 Headers", func(b *testing.B) { benchSimpleRoundTrip(b, 1000) })
}

func activeStreams(cc *ClientConn) int {
cc.mu.Lock()
defer cc.mu.Unlock()
return len(cc.streams)
}

type closeMode int

const (
closeAtHeaders closeMode = iota
closeAtBody
shutdown
shutdownCancel
)

// See golang.org/issue/17292
func testClientConnClose(t *testing.T, closeMode closeMode) {
clientDone := make(chan struct{})
defer close(clientDone)
handlerDone := make(chan struct{})
closeDone := make(chan struct{})
beforeHeader := func() {}
bodyWrite := func(w http.ResponseWriter) {}
st := newServerTester(t, func(w http.ResponseWriter, r *http.Request) {
defer close(handlerDone)
beforeHeader()
w.WriteHeader(http.StatusOK)
w.(http.Flusher).Flush()
bodyWrite(w)
select {
case <-w.(http.CloseNotifier).CloseNotify():
// client closed connection before completion
if closeMode == shutdown || closeMode == shutdownCancel {
t.Error("expected request to complete")
}
case <-clientDone:
if closeMode == closeAtHeaders || closeMode == closeAtBody {
t.Error("expected connection closed by client")
}
}
}, optOnlyServer)
defer st.Close()
tr := &Transport{TLSClientConfig: tlsConfigInsecure}
defer tr.CloseIdleConnections()
cc, err := tr.dialClientConn(st.ts.Listener.Addr().String(), false)
req, err := http.NewRequest("GET", st.ts.URL, nil)
if err != nil {
t.Fatal(err)
}
if closeMode == closeAtHeaders {
beforeHeader = func() {
if err := cc.Close(); err != nil {
t.Error(err)
}
close(closeDone)
}
}
var sendBody chan struct{}
if closeMode == closeAtBody {
sendBody = make(chan struct{})
bodyWrite = func(w http.ResponseWriter) {
<-sendBody
b := make([]byte, 32)
w.Write(b)
w.(http.Flusher).Flush()
if err := cc.Close(); err != nil {
t.Errorf("unexpected ClientConn close error: %v", err)
}
close(closeDone)
w.Write(b)
w.(http.Flusher).Flush()
}
}
res, err := cc.RoundTrip(req)
if res != nil {
defer res.Body.Close()
}
if closeMode == closeAtHeaders {
got := fmt.Sprint(err)
want := "http2: client connection force closed via ClientConn.Close"
if got != want {
t.Fatalf("RoundTrip error = %v, want %v", got, want)
}
} else {
if err != nil {
t.Fatalf("RoundTrip: %v", err)
}
if got, want := activeStreams(cc), 1; got != want {
t.Errorf("got %d active streams, want %d", got, want)
}
}
switch closeMode {
case shutdownCancel:
if err = cc.Shutdown(canceledCtx); err != errCanceled {
t.Errorf("got %v, want %v", err, errCanceled)
}
if cc.closing == false {
t.Error("expected closing to be true")
}
if cc.CanTakeNewRequest() == true {
t.Error("CanTakeNewRequest to return false")
}
if v, want := len(cc.streams), 1; v != want {
t.Errorf("expected %d active streams, got %d", want, v)
}
clientDone <- struct{}{}
<-handlerDone
case shutdown:
wait := make(chan struct{})
shutdownEnterWaitStateHook = func() {
close(wait)
shutdownEnterWaitStateHook = func() {}
}
defer func() { shutdownEnterWaitStateHook = func() {} }()
shutdown := make(chan struct{}, 1)
go func() {
if err = cc.Shutdown(context.Background()); err != nil {
t.Error(err)
}
close(shutdown)
}()
// Let the shutdown to enter wait state
<-wait
cc.mu.Lock()
if cc.closing == false {
t.Error("expected closing to be true")
}
cc.mu.Unlock()
if cc.CanTakeNewRequest() == true {
t.Error("CanTakeNewRequest to return false")
}
if got, want := activeStreams(cc), 1; got != want {
t.Errorf("got %d active streams, want %d", got, want)
}
// Let the active request finish
clientDone <- struct{}{}
// Wait for the shutdown to end
select {
case <-shutdown:
case <-time.After(2 * time.Second):
t.Fatal("expected server connection to close")
}
case closeAtHeaders, closeAtBody:
if closeMode == closeAtBody {
go close(sendBody)
if _, err := io.Copy(ioutil.Discard, res.Body); err == nil {
t.Error("expected a Copy error, got nil")
}
}
<-closeDone
if got, want := activeStreams(cc), 0; got != want {
t.Errorf("got %d active streams, want %d", got, want)
}
// wait for server to get the connection close notice
select {
case <-handlerDone:
case <-time.After(2 * time.Second):
t.Fatal("expected server connection to close")
}
}
}

// The client closes the connection just after the server got the client's HEADERS
// frame, but before the server sends its HEADERS response back. The expected
// result is an error on RoundTrip explaining the client closed the connection.
func TestClientConnCloseAtHeaders(t *testing.T) {
testClientConnClose(t, closeAtHeaders)
}

// The client closes the connection between two server's response DATA frames.
// The expected behavior is a response body io read error on the client.
func TestClientConnCloseAtBody(t *testing.T) {
testClientConnClose(t, closeAtBody)
}

// The client sends a GOAWAY frame before the server finished processing a request.
// We expect the connection not to close until the request is completed.
func TestClientConnShutdown(t *testing.T) {
testClientConnClose(t, shutdown)
}

// The client sends a GOAWAY frame before the server finishes processing a request,
// but cancels the passed context before the request is completed. The expected
// behavior is the client closing the connection after the context is canceled.
func TestClientConnShutdownCancel(t *testing.T) {
testClientConnClose(t, shutdownCancel)
}

0 comments on commit b87faa7

Please sign in to comment.