Commit ba451b83 authored by Dave Cheney's avatar Dave Cheney

go.crypto/ssh: introduce a circular buffer for chanReader

R=agl, gustav.paul, kardianos
CC=golang-dev
http://codereview.appspot.com/6207051
parent ceddc4d5
// Copyright 2012 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"sync"
)
// buffer provides a linked list buffer for data exchange
// between producer and consumer. Theoretically the buffer is
// of unlimited capacity as it does no allocation of its own.
type buffer struct {
// protects concurrent access to head, tail and closed
*sync.Cond
head *element // the buffer that will be read first
tail *element // the buffer that will be read last
closed bool
}
// An element represents a single link in a linked list.
type element struct {
buf []byte
next *element
}
// newBuffer returns an empty buffer that is not closed.
func newBuffer() *buffer {
e := new(element)
b := &buffer{
Cond: newCond(),
head: e,
tail: e,
}
return b
}
// write makes buf available for Read to receive.
// buf must not be modified after the call to write.
func (b *buffer) write(buf []byte) {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
e := &element{buf: buf}
b.tail.next = e
b.tail = e
b.Cond.Signal()
}
// eof closes the buffer. Reads from the buffer once all
// the data has been consumed will receive os.EOF.
func (b *buffer) eof() error {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
b.closed = true
b.Cond.Signal()
return nil
}
// Read reads data from the internal buffer in buf.
// Reads will block if no data is available, or until
// the buffer is closed.
func (b *buffer) Read(buf []byte) (n int, err error) {
b.Cond.L.Lock()
defer b.Cond.L.Unlock()
for len(buf) > 0 {
// if there is data in b.head, copy it
if len(b.head.buf) > 0 {
r := copy(buf, b.head.buf)
buf, b.head.buf = buf[r:], b.head.buf[r:]
n += r
continue
}
// if there is a next buffer, make it the head
if len(b.head.buf) == 0 && b.head != b.tail {
b.head = b.head.next
continue
}
// if at least one byte has been copied, return
if n > 0 {
break
}
// if nothing was read, and there is nothing outstanding
// check to see if the buffer is closed.
if b.closed {
err = io.EOF
break
}
// out of buffers, wait for producer
b.Cond.Wait()
}
return
}
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package ssh
import (
"io"
"testing"
)
var BYTES = []byte("abcdefghijklmnopqrstuvwxyz")
func TestBufferReadwrite(t *testing.T) {
b := newBuffer()
b.write(BYTES[:10])
r, _ := b.Read(make([]byte, 10))
if r != 10 {
t.Fatalf("Expected written == read == 10, written: 10, read %d", r)
}
b = newBuffer()
b.write(BYTES[:5])
r, _ = b.Read(make([]byte, 10))
if r != 5 {
t.Fatalf("Expected written == read == 5, written: 5, read %d", r)
}
b = newBuffer()
b.write(BYTES[:10])
r, _ = b.Read(make([]byte, 5))
if r != 5 {
t.Fatalf("Expected written == 10, read == 5, written: 10, read %d", r)
}
b = newBuffer()
b.write(BYTES[:5])
b.write(BYTES[5:15])
r, _ = b.Read(make([]byte, 10))
r2, _ := b.Read(make([]byte, 10))
if r != 10 || r2 != 5 || 15 != r+r2 {
t.Fatal("Expected written == read == 15")
}
}
func TestBufferClose(t *testing.T) {
b := newBuffer()
b.write(BYTES[:10])
b.eof()
_, err := b.Read(make([]byte, 5))
if err != nil {
t.Fatal("expected read of 5 to not return EOF")
}
b = newBuffer()
b.write(BYTES[:10])
b.eof()
r, err := b.Read(make([]byte, 5))
r2, err2 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || err != nil || err2 != nil {
t.Fatal("expected reads of 5 and 5")
}
b = newBuffer()
b.write(BYTES[:10])
b.eof()
r, err = b.Read(make([]byte, 5))
r2, err2 = b.Read(make([]byte, 10))
r3, err3 := b.Read(make([]byte, 10))
if r != 5 || r2 != 5 || r3 != 0 || err != nil || err2 != nil || err3 != io.EOF {
t.Fatal("expected reads of 5 and 5 and 0, with EOF")
}
b = newBuffer()
b.write(make([]byte, 5))
b.write(make([]byte, 10))
b.eof()
r, err = b.Read(make([]byte, 9))
r2, err2 = b.Read(make([]byte, 3))
r3, err3 = b.Read(make([]byte, 3))
r4, err4 := b.Read(make([]byte, 10))
if err != nil || err2 != nil || err3 != nil || err4 != io.EOF {
t.Fatalf("Expected EOF on forth read only, err=%v, err2=%v, err3=%v, err4=%v", err, err2, err3, err4)
}
if r != 9 || r2 != 3 || r3 != 3 || r4 != 0 {
t.Fatal("Expected written == read == 15", r, r2, r3, r4)
}
}
......@@ -224,7 +224,7 @@ func (c *ClientConn) mainLoop() {
if length != uint32(len(packet)) {
return
}
c.getChan(remoteId).stdout.handleData(packet)
c.getChan(remoteId).stdout.write(packet)
case msgChannelExtendedData:
if len(packet) < 13 {
// malformed data packet
......@@ -242,7 +242,7 @@ func (c *ClientConn) mainLoop() {
// for stderr on interactive sessions. Other data types are
// silently discarded.
if datatype == 1 {
c.getChan(remoteId).stderr.handleData(packet)
c.getChan(remoteId).stderr.write(packet)
}
default:
switch msg := decode(packet).(type) {
......@@ -448,12 +448,12 @@ func newClientChan(cc conn, id uint32) *clientChan {
channel: &c.channel,
}
c.stdout = &chanReader{
data: make(chan []byte, 16),
channel: &c.channel,
buffer: newBuffer(),
}
c.stderr = &chanReader{
data: make(chan []byte, 16),
channel: &c.channel,
buffer: newBuffer(),
}
return c
}
......@@ -579,44 +579,18 @@ func (w *chanWriter) Close() error {
// A chanReader represents stdout or stderr of a remote process.
type chanReader struct {
// TODO(dfc) a fixed size channel may not be the right data structure.
// If writes to this channel block, they will block mainLoop, making
// it unable to receive new messages from the remote side.
data chan []byte // receives data from remote
dataClosed bool // protects data from being closed twice
*channel // the channel backing this reader
buf []byte
}
// eof signals to the consumer that there is no more data to be received.
func (r *chanReader) eof() {
if !r.dataClosed {
r.dataClosed = true
close(r.data)
}
}
// handleData sends buf to the reader's consumer. If r.data is closed
// the data will be silently discarded
func (r *chanReader) handleData(buf []byte) {
if !r.dataClosed {
r.data <- buf
}
*channel // the channel backing this reader
*buffer
}
// Read reads data from the remote process's stdout or stderr.
func (r *chanReader) Read(data []byte) (int, error) {
var ok bool
for {
if len(r.buf) > 0 {
n := copy(data, r.buf)
r.buf = r.buf[n:]
return n, r.sendWindowAdj(n)
}
r.buf, ok = <-r.data
if !ok {
return 0, io.EOF
func (r *chanReader) Read(buf []byte) (int, error) {
n, err := r.buffer.Read(buf)
if err != nil {
if err == io.EOF {
return n, err
}
return 0, err
}
panic("unreachable")
return n, r.sendWindowAdj(n)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment