Commit 3f1bc9ed authored by Dmitry Smirnov's avatar Dmitry Smirnov

New upstream version 0.0~git20180605.35205983

parent 91098558
# Compiled Object files, Static and Dynamic libs (Shared Objects)
*.o
*.a
*.so
# Folders
_obj
_test
# Architecture specific extensions/prefixes
*.[568vq]
[568vq].out
*.cgo1.go
*.cgo2.c
_cgo_defun.c
_cgo_gotypes.go
_cgo_export.*
_testmain.go
*.exe
*.test
......@@ -79,3 +79,45 @@ func BenchmarkSendRecv(b *testing.B) {
}
<-doneCh
}
func BenchmarkSendRecvLarge(b *testing.B) {
client, server := testClientServer()
defer client.Close()
defer server.Close()
const sendSize = 512 * 1024 * 1024
const recvSize = 4 * 1024
sendBuf := make([]byte, sendSize)
recvBuf := make([]byte, recvSize)
b.ResetTimer()
recvDone := make(chan struct{})
go func() {
stream, err := server.AcceptStream()
if err != nil {
return
}
defer stream.Close()
for i := 0; i < b.N; i++ {
for j := 0; j < sendSize/recvSize; j++ {
if _, err := stream.Read(recvBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
}
close(recvDone)
}()
stream, err := client.Open()
if err != nil {
b.Fatalf("err: %v", err)
}
defer stream.Close()
for i := 0; i < b.N; i++ {
if _, err := stream.Write(sendBuf); err != nil {
b.Fatalf("err: %v", err)
}
}
<-recvDone
}
......@@ -46,8 +46,11 @@ type Session struct {
pingID uint32
pingLock sync.Mutex
// streams maps a stream id to a stream
// streams maps a stream id to a stream, and inflight has an entry
// for any outgoing stream that has not yet been established. Both are
// protected by streamLock.
streams map[uint32]*Stream
inflight map[uint32]struct{}
streamLock sync.Mutex
// synCh acts like a semaphore. It is sized to the AcceptBacklog which
......@@ -90,6 +93,7 @@ func newSession(config *Config, conn io.ReadWriteCloser, client bool) *Session {
bufRead: bufio.NewReader(conn),
pings: make(map[uint32]chan struct{}),
streams: make(map[uint32]*Stream),
inflight: make(map[uint32]struct{}),
synCh: make(chan struct{}, config.AcceptBacklog),
acceptCh: make(chan *Stream, config.AcceptBacklog),
sendCh: make(chan sendReady, 64),
......@@ -119,6 +123,12 @@ func (s *Session) IsClosed() bool {
}
}
// CloseChan returns a read-only channel which is closed as
// soon as the session is closed.
func (s *Session) CloseChan() <-chan struct{} {
return s.shutdownCh
}
// NumStreams returns the number of currently open streams
func (s *Session) NumStreams() int {
s.streamLock.Lock()
......@@ -153,7 +163,7 @@ func (s *Session) OpenStream() (*Stream, error) {
}
GET_ID:
// Get and ID, and check for stream exhaustion
// Get an ID, and check for stream exhaustion
id := atomic.LoadUint32(&s.nextStreamID)
if id >= math.MaxUint32-1 {
return nil, ErrStreamsExhausted
......@@ -166,10 +176,16 @@ GET_ID:
stream := newStream(s, id, streamInit)
s.streamLock.Lock()
s.streams[id] = stream
s.inflight[id] = struct{}{}
s.streamLock.Unlock()
// Send the window update to create
if err := stream.sendWindowUpdate(); err != nil {
select {
case <-s.synCh:
default:
s.logger.Printf("[ERR] yamux: aborted stream open without inflight syn semaphore")
}
return nil, err
}
return stream, nil
......@@ -293,8 +309,10 @@ func (s *Session) keepalive() {
case <-time.After(s.config.KeepAliveInterval):
_, err := s.Ping()
if err != nil {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
if err != ErrSessionShutdown {
s.logger.Printf("[ERR] yamux: keepalive failed: %v", err)
s.exitErr(ErrKeepAliveTimeout)
}
return
}
case <-s.shutdownCh:
......@@ -313,8 +331,17 @@ func (s *Session) waitForSend(hdr header, body io.Reader) error {
// potential shutdown. Since there's the expectation that sends can happen
// in a timely manner, we enforce the connection write timeout here.
func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()
ready := sendReady{Hdr: hdr, Body: body, Err: errCh}
select {
......@@ -339,8 +366,17 @@ func (s *Session) waitForSendErr(hdr header, body io.Reader, errCh chan error) e
// the send happens right here, we enforce the connection write timeout if we
// can't queue the header to be sent.
func (s *Session) sendNoWait(hdr header) error {
timer := time.NewTimer(s.config.ConnectionWriteTimeout)
defer timer.Stop()
t := timerPool.Get()
timer := t.(*time.Timer)
timer.Reset(s.config.ConnectionWriteTimeout)
defer func() {
timer.Stop()
select {
case <-timer.C:
default:
}
timerPool.Put(t)
}()
select {
case s.sendCh <- sendReady{Hdr: hdr}:
......@@ -398,11 +434,20 @@ func (s *Session) recv() {
}
}
// Ensure that the index of the handler (typeData/typeWindowUpdate/etc) matches the message type
var (
handlers = []func(*Session, header) error{
typeData: (*Session).handleStreamMessage,
typeWindowUpdate: (*Session).handleStreamMessage,
typePing: (*Session).handlePing,
typeGoAway: (*Session).handleGoAway,
}
)
// recvLoop continues to receive data until a fatal error is encountered
func (s *Session) recvLoop() error {
defer close(s.recvDoneCh)
hdr := header(make([]byte, headerSize))
var handler func(header) error
for {
// Read the header
if _, err := io.ReadFull(s.bufRead, hdr); err != nil {
......@@ -418,22 +463,12 @@ func (s *Session) recvLoop() error {
return ErrInvalidVersion
}
// Switch on the type
switch hdr.MsgType() {
case typeData:
handler = s.handleStreamMessage
case typeWindowUpdate:
handler = s.handleStreamMessage
case typeGoAway:
handler = s.handleGoAway
case typePing:
handler = s.handlePing
default:
mt := hdr.MsgType()
if mt < typeData || mt > typeGoAway {
return ErrInvalidMsgType
}
// Invoke the handler
if err := handler(hdr); err != nil {
if err := handlers[mt](s, hdr); err != nil {
return err
}
}
......@@ -580,19 +615,34 @@ func (s *Session) incomingStream(id uint32) error {
}
// closeStream is used to close a stream once both sides have
// issued a close.
// issued a close. If there was an in-flight SYN and the stream
// was not yet established, then this will give the credit back.
func (s *Session) closeStream(id uint32) {
s.streamLock.Lock()
if _, ok := s.inflight[id]; ok {
select {
case <-s.synCh:
default:
s.logger.Printf("[ERR] yamux: SYN tracking out of sync")
}
}
delete(s.streams, id)
s.streamLock.Unlock()
}
// establishStream is used to mark a stream that was in the
// SYN Sent state as established.
func (s *Session) establishStream() {
func (s *Session) establishStream(id uint32) {
s.streamLock.Lock()
if _, ok := s.inflight[id]; ok {
delete(s.inflight, id)
} else {
s.logger.Printf("[ERR] yamux: established stream without inflight SYN (no tracking entry)")
}
select {
case <-s.synCh:
default:
panic("established stream without inflight syn")
s.logger.Printf("[ERR] yamux: established stream without inflight SYN (didn't have semaphore)")
}
s.streamLock.Unlock()
}
......@@ -148,6 +148,47 @@ func TestPing_Timeout(t *testing.T) {
}
}
func TestCloseBeforeAck(t *testing.T) {
cfg := testConf()
cfg.AcceptBacklog = 8
client, server := testClientServerConfig(cfg)
defer client.Close()
defer server.Close()
for i := 0; i < 8; i++ {
s, err := client.OpenStream()
if err != nil {
t.Fatal(err)
}
s.Close()
}
for i := 0; i < 8; i++ {
s, err := server.AcceptStream()
if err != nil {
t.Fatal(err)
}
s.Close()
}
done := make(chan struct{})
go func() {
defer close(done)
s, err := client.OpenStream()
if err != nil {
t.Fatal(err)
}
s.Close()
}()
select {
case <-done:
case <-time.After(time.Second * 5):
t.Fatal("timed out trying to open stream")
}
}
func TestAccept(t *testing.T) {
client, server := testClientServer()
defer client.Close()
......@@ -335,7 +376,12 @@ func TestSendData_Large(t *testing.T) {
defer client.Close()
defer server.Close()
data := make([]byte, 512*1024)
const (
sendSize = 250 * 1024 * 1024
recvSize = 4 * 1024
)
data := make([]byte, sendSize)
for idx := range data {
data[idx] = byte(idx % 256)
}
......@@ -349,16 +395,17 @@ func TestSendData_Large(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
buf := make([]byte, 4*1024)
for i := 0; i < 128; i++ {
var sz int
buf := make([]byte, recvSize)
for i := 0; i < sendSize/recvSize; i++ {
n, err := stream.Read(buf)
if err != nil {
t.Fatalf("err: %v", err)
}
if n != 4*1024 {
if n != recvSize {
t.Fatalf("short read: %d", n)
}
sz += n
for idx := range buf {
if buf[idx] != byte(idx%256) {
t.Fatalf("bad: %v %v %v", i, idx, buf[idx])
......@@ -369,6 +416,8 @@ func TestSendData_Large(t *testing.T) {
if err := stream.Close(); err != nil {
t.Fatalf("err: %v", err)
}
t.Logf("cap=%d, n=%d\n", stream.recvBuf.Cap(), sz)
}()
go func() {
......@@ -398,7 +447,7 @@ func TestSendData_Large(t *testing.T) {
}()
select {
case <-doneCh:
case <-time.After(time.Second):
case <-time.After(5 * time.Second):
panic("timeout")
}
}
......@@ -578,7 +627,7 @@ func TestHalfClose(t *testing.T) {
if err != nil {
t.Fatalf("err: %v", err)
}
if _, err := stream.Write([]byte("a")); err != nil {
if _, err = stream.Write([]byte("a")); err != nil {
t.Fatalf("err: %v", err)
}
......@@ -598,7 +647,7 @@ func TestHalfClose(t *testing.T) {
}
// Send more
if _, err := stream.Write([]byte("bcd")); err != nil {
if _, err = stream.Write([]byte("bcd")); err != nil {
t.Fatalf("err: %v", err)
}
stream.Close()
......@@ -985,6 +1034,60 @@ func TestSession_WindowUpdateWriteDuringRead(t *testing.T) {
wg.Wait()
}
func TestSession_PartialReadWindowUpdate(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
defer server.Close()
var wg sync.WaitGroup
wg.Add(1)
// Choose a huge flood size that we know will result in a window update.
flood := int64(client.config.MaxStreamWindowSize)
var wr *Stream
// The server will accept a new stream and then flood data to it.
go func() {
defer wg.Done()
var err error
wr, err = server.AcceptStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer wr.Close()
if wr.sendWindow != client.config.MaxStreamWindowSize {
t.Fatalf("sendWindow: exp=%d, got=%d", client.config.MaxStreamWindowSize, wr.sendWindow)
}
n, err := wr.Write(make([]byte, flood))
if err != nil {
t.Fatalf("err: %v", err)
}
if int64(n) != flood {
t.Fatalf("short write: %d", n)
}
if wr.sendWindow != 0 {
t.Fatalf("sendWindow: exp=%d, got=%d", 0, wr.sendWindow)
}
}()
stream, err := client.OpenStream()
if err != nil {
t.Fatalf("err: %v", err)
}
defer stream.Close()
wg.Wait()
_, err = stream.Read(make([]byte, flood/2+1))
if exp := uint32(flood/2 + 1); wr.sendWindow != exp {
t.Errorf("sendWindow: exp=%d, got=%d", exp, wr.sendWindow)
}
}
func TestSession_sendNoWait_Timeout(t *testing.T) {
client, server := testClientServerConfig(testConfNoKeepAlive())
defer client.Close()
......
......@@ -22,7 +22,7 @@ Each field is described below:
## Version Field
The version field is used for future backwards compatibily. At the
The version field is used for future backward compatibility. At the
current time, the field is always set to 0, to indicate the initial
version.
......@@ -96,7 +96,7 @@ Because we are relying on the reliable stream underneath, a connection
can begin sending data once the SYN flag is sent. The corresponding
ACK does not need to be received. This is particularly well suited
for an RPC system where a client wants to open a stream and immediately
fire a request without wiating for the RTT of the ACK.
fire a request without waiting for the RTT of the ACK.
This does introduce the possibility of a connection being rejected
after data has been sent already. This is a slight semantic difference
......@@ -138,4 +138,3 @@ provide an error code:
* 0x0 Normal termination
* 0x1 Protocol error
* 0x2 Internal error
......@@ -47,8 +47,8 @@ type Stream struct {
recvNotifyCh chan struct{}
sendNotifyCh chan struct{}
readDeadline time.Time
writeDeadline time.Time
readDeadline atomic.Value // time.Time
writeDeadline atomic.Value // time.Time
}
// newStream is used to construct a new stream within
......@@ -67,6 +67,8 @@ func newStream(session *Session, id uint32, state streamState) *Stream {
recvNotifyCh: make(chan struct{}, 1),
sendNotifyCh: make(chan struct{}, 1),
}
s.readDeadline.Store(time.Time{})
s.writeDeadline.Store(time.Time{})
return s
}
......@@ -91,10 +93,13 @@ START:
case streamRemoteClose:
fallthrough
case streamClosed:
s.recvLock.Lock()
if s.recvBuf == nil || s.recvBuf.Len() == 0 {
s.recvLock.Unlock()
s.stateLock.Unlock()
return 0, io.EOF
}
s.recvLock.Unlock()
case streamReset:
s.stateLock.Unlock()
return 0, ErrConnectionReset
......@@ -118,12 +123,18 @@ START:
WAIT:
var timeout <-chan time.Time
if !s.readDeadline.IsZero() {
delay := s.readDeadline.Sub(time.Now())
timeout = time.After(delay)
var timer *time.Timer
readDeadline := s.readDeadline.Load().(time.Time)
if !readDeadline.IsZero() {
delay := readDeadline.Sub(time.Now())
timer = time.NewTimer(delay)
timeout = timer.C
}
select {
case <-s.recvNotifyCh:
if timer != nil {
timer.Stop()
}
goto START
case <-timeout:
return 0, ErrTimeout
......@@ -180,7 +191,7 @@ START:
// Send the header
s.sendHdr.encode(typeData, flags, s.id, max)
if err := s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
if err = s.session.waitForSendErr(s.sendHdr, body, s.sendErr); err != nil {
return 0, err
}
......@@ -192,8 +203,9 @@ START:
WAIT:
var timeout <-chan time.Time
if !s.writeDeadline.IsZero() {
delay := s.writeDeadline.Sub(time.Now())
writeDeadline := s.writeDeadline.Load().(time.Time)
if !writeDeadline.IsZero() {
delay := writeDeadline.Sub(time.Now())
timeout = time.After(delay)
}
select {
......@@ -230,18 +242,25 @@ func (s *Stream) sendWindowUpdate() error {
// Determine the delta update
max := s.session.config.MaxStreamWindowSize
delta := max - atomic.LoadUint32(&s.recvWindow)
var bufLen uint32
s.recvLock.Lock()
if s.recvBuf != nil {
bufLen = uint32(s.recvBuf.Len())
}
delta := (max - bufLen) - s.recvWindow
// Determine the flags if any
flags := s.sendFlags()
// Check if we can omit the update
if delta < (max/2) && flags == 0 {
s.recvLock.Unlock()
return nil
}
// Update our window
atomic.AddUint32(&s.recvWindow, delta)
s.recvWindow += delta
s.recvLock.Unlock()
// Send the header
s.controlHdr.encode(typeWindowUpdate, flags, s.id, delta)
......@@ -327,7 +346,7 @@ func (s *Stream) processFlags(flags uint16) error {
if s.state == streamSYNSent {
s.state = streamEstablished
}
s.session.establishStream()
s.session.establishStream(s.id)
}
if flags&flagFIN == flagFIN {
switch s.state {
......@@ -348,9 +367,6 @@ func (s *Stream) processFlags(flags uint16) error {
}
}
if flags&flagRST == flagRST {
if s.state == streamSYNSent {
s.session.establishStream()
}
s.state = streamReset
closeStream = true
s.notifyWaiting()
......@@ -387,16 +403,18 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
if length == 0 {
return nil
}
if remain := atomic.LoadUint32(&s.recvWindow); length > remain {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, remain, length)
return ErrRecvWindowExceeded
}
// Wrap in a limited reader
conn = &io.LimitedReader{R: conn, N: int64(length)}
// Copy into buffer
s.recvLock.Lock()
if length > s.recvWindow {
s.session.logger.Printf("[ERR] yamux: receive window exceeded (stream: %d, remain: %d, recv: %d)", s.id, s.recvWindow, length)
return ErrRecvWindowExceeded
}
if s.recvBuf == nil {
// Allocate the receive buffer just-in-time to fit the full data frame.
// This way we can read in the whole packet without further allocations.
......@@ -409,7 +427,7 @@ func (s *Stream) readData(hdr header, flags uint16, conn io.Reader) error {
}
// Decrement the receive window
atomic.AddUint32(&s.recvWindow, ^uint32(length-1))
s.recvWindow -= length
s.recvLock.Unlock()
// Unblock any readers
......@@ -430,13 +448,13 @@ func (s *Stream) SetDeadline(t time.Time) error {
// SetReadDeadline sets the deadline for future Read calls.
func (s *Stream) SetReadDeadline(t time.Time) error {
s.readDeadline = t
s.readDeadline.Store(t)
return nil
}
// SetWriteDeadline sets the deadline for future Write calls
func (s *Stream) SetWriteDeadline(t time.Time) error {
s.writeDeadline = t
s.writeDeadline.Store(t)
return nil
}
......
package yamux
import (
"sync"
"time"
)
var (
timerPool = &sync.Pool{
New: func() interface{} {
timer := time.NewTimer(time.Hour * 1e6)
timer.Stop()
return timer
},
}
)
// asyncSendErr is used to try an async send of an error
func asyncSendErr(ch chan error, err error) {
if ch == nil {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment