Commit a2a307ee authored by Dave Cheney's avatar Dave Cheney

go.crypto/ssh: ensure {Server,Client}Conn do not expose io.ReadWriter

Transport should not be a ReadWriter. It can only write packets, i.e. no partial reads or writes. Furthermore, you can currently do ClientConn.Write() while the connection is live, which sends raw bytes over the connection. Doing so will confuse the transports because the data is not encrypted.

As a consequence, ClientConn and ServerConn stop being a net.Conn

Finally, ensure that {Server,Client}Conn implement LocalAddr and RemoteAddr methods that previously were exposed by an embedded net.Conn field.

R=hanwen
CC=golang-dev
https://codereview.appspot.com/16610043
parent 004923d8
......@@ -16,7 +16,7 @@ import (
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
*transport
transport *transport
config *ClientConfig
chanList // channels associated with this connection
forwardList // forwarded tcpip connections from the remote side
......@@ -47,13 +47,22 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
}
if err := conn.handshake(); err != nil {
conn.Close()
conn.transport.Close()
return nil, fmt.Errorf("handshake failed: %v", err)
}
go conn.mainLoop()
return conn, nil
}
// Close closes the connection.
func (c *ClientConn) Close() error { return c.transport.Close() }
// LocalAddr returns the local network address.
func (c *ClientConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
// RemoteAddr returns the remote network address.
func (c *ClientConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() error {
clientVersion := []byte(packageVersion)
......@@ -78,10 +87,10 @@ func (c *ClientConn) handshake() error {
CompressionServerClient: supportedCompressions,
}
kexInitPacket := marshal(msgKexInit, clientKexInit)
if err := c.writePacket(kexInitPacket); err != nil {
if err := c.transport.writePacket(kexInitPacket); err != nil {
return err
}
packet, err := c.readPacket()
packet, err := c.transport.readPacket()
if err != nil {
return err
}
......@@ -99,7 +108,7 @@ func (c *ClientConn) handshake() error {
if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
// The server sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err := c.readPacket(); err != nil {
if _, err := c.transport.readPacket(); err != nil {
return err
}
}
......@@ -115,7 +124,7 @@ func (c *ClientConn) handshake() error {
clientKexInit: kexInitPacket,
serverKexInit: packet,
}
result, err := kex.Client(c, c.config.rand(), &magics)
result, err := kex.Client(c.transport, c.config.rand(), &magics)
if err != nil {
return err
}
......@@ -126,7 +135,7 @@ func (c *ClientConn) handshake() error {
}
if checker := c.config.HostKeyChecker; checker != nil {
err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
err = checker.Check(c.dialAddress, c.transport.RemoteAddr(), algs.hostKey, result.HostKey)
if err != nil {
return err
}
......@@ -134,10 +143,10 @@ func (c *ClientConn) handshake() error {
c.transport.prepareKeyChange(algs, result)
if err = c.writePacket([]byte{msgNewKeys}); err != nil {
if err = c.transport.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
if packet, err = c.readPacket(); err != nil {
if packet, err = c.transport.readPacket(); err != nil {
return err
}
if packet[0] != msgNewKeys {
......@@ -171,13 +180,13 @@ func verifyHostKeySignature(hostKeyAlgo string, hostKeyBytes []byte, data []byte
// to their respective ClientChans.
func (c *ClientConn) mainLoop() {
defer func() {
c.Close()
c.transport.Close()
c.chanList.closeAll()
c.forwardList.closeAll()
}()
for {
packet, err := c.readPacket()
packet, err := c.transport.readPacket()
if err != nil {
break
}
......@@ -298,7 +307,7 @@ func (c *ClientConn) mainLoop() {
// This handles keepalive messages and matches
// the behaviour of OpenSSH.
if msg.WantReply {
c.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
c.transport.writePacket(marshal(msgRequestFailure, globalRequestFailureMsg{}))
}
case *globalRequestSuccessMsg, *globalRequestFailureMsg:
c.globalRequest.response <- msg
......@@ -355,7 +364,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
MaxPacketSize: 1 << 15,
}
c.writePacket(marshal(msgChannelOpenConfirm, m))
c.transport.writePacket(marshal(msgChannelOpenConfirm, m))
l <- forward{ch, raddr}
default:
// unknown channel type
......@@ -365,7 +374,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
Message: fmt.Sprintf("unknown channel type: %v", msg.ChanType),
Language: "en_US.UTF-8",
}
c.writePacket(marshal(msgChannelOpenFailure, m))
c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
}
......@@ -375,7 +384,7 @@ func (c *ClientConn) handleChanOpen(msg *channelOpenMsg) {
func (c *ClientConn) sendGlobalRequest(m interface{}) (*globalRequestSuccessMsg, error) {
c.globalRequest.Lock()
defer c.globalRequest.Unlock()
if err := c.writePacket(marshal(msgGlobalRequest, m)); err != nil {
if err := c.transport.writePacket(marshal(msgGlobalRequest, m)); err != nil {
return nil, err
}
r := <-c.globalRequest.response
......@@ -394,7 +403,7 @@ func (c *ClientConn) sendConnectionFailed(remoteId uint32) error {
Message: "invalid request",
Language: "en_US.UTF-8",
}
return c.writePacket(marshal(msgChannelOpenFailure, m))
return c.transport.writePacket(marshal(msgChannelOpenFailure, m))
}
// parseTCPAddr parses the originating address from the remote into a *net.TCPAddr.
......
......@@ -14,10 +14,10 @@ import (
// authenticate authenticates with the remote server. See RFC 4252.
func (c *ClientConn) authenticate(session []byte) error {
// initiate user auth session
if err := c.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
if err := c.transport.writePacket(marshal(msgServiceRequest, serviceRequestMsg{serviceUserAuth})); err != nil {
return err
}
packet, err := c.readPacket()
packet, err := c.transport.readPacket()
if err != nil {
return err
}
......
......@@ -5,6 +5,8 @@
package ssh
import (
"io"
"net"
"testing"
)
......@@ -24,3 +26,32 @@ func TestSafeString(t *testing.T) {
}
}
}
// Make sure Read/Write are not exposed.
func TestConnHideRWMethods(t *testing.T) {
for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
if _, ok := c.(io.Reader); ok {
t.Errorf("%T implements io.Reader", c)
}
if _, ok := c.(io.Writer); ok {
t.Errorf("%T implements io.Writer", c)
}
}
}
func TestConnSupportsLocalRemoteMethods(t *testing.T) {
type LocalAddr interface {
LocalAddr() net.Addr
}
type RemoteAddr interface {
RemoteAddr() net.Addr
}
for _, c := range []interface{}{new(ServerConn), new(ClientConn)} {
if _, ok := c.(LocalAddr); !ok {
t.Errorf("%T does not implement LocalAddr", c)
}
if _, ok := c.(RemoteAddr); !ok {
t.Errorf("%T does not implement RemoteAddr", c)
}
}
}
......@@ -97,8 +97,8 @@ const maxCachedPubKeys = 16
// A ServerConn represents an incoming connection.
type ServerConn struct {
*transport
config *ServerConfig
transport *transport
config *ServerConfig
channels map[uint32]*serverChan
nextChanId uint32
......@@ -147,6 +147,15 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
// Close closes the connection.
func (s *ServerConn) Close() error { return s.transport.Close() }
// LocalAddr returns the local network address.
func (c *ServerConn) LocalAddr() net.Addr { return c.transport.LocalAddr() }
// RemoteAddr returns the remote network address.
func (c *ServerConn) RemoteAddr() net.Addr { return c.transport.RemoteAddr() }
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() error {
var err error
......@@ -160,7 +169,7 @@ func (s *ServerConn) Handshake() error {
}
var packet []byte
if packet, err = s.readPacket(); err != nil {
if packet, err = s.transport.readPacket(); err != nil {
return err
}
var serviceRequest serviceRequestMsg
......@@ -173,7 +182,7 @@ func (s *ServerConn) Handshake() error {
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
if err := s.transport.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return err
}
......@@ -199,13 +208,13 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
}
serverKexInitPacket := marshal(msgKexInit, serverKexInit)
if err = s.writePacket(serverKexInitPacket); err != nil {
if err = s.transport.writePacket(serverKexInitPacket); err != nil {
return
}
if clientKexInitPacket == nil {
clientKexInit = new(kexInitMsg)
if clientKexInitPacket, err = s.readPacket(); err != nil {
if clientKexInitPacket, err = s.transport.readPacket(); err != nil {
return
}
if err = unmarshal(clientKexInit, clientKexInitPacket, msgKexInit); err != nil {
......@@ -221,7 +230,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err = s.readPacket(); err != nil {
if _, err = s.transport.readPacket(); err != nil {
return
}
}
......@@ -244,7 +253,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
serverKexInit: marshal(msgKexInit, serverKexInit),
clientKexInit: clientKexInitPacket,
}
result, err := kex.Server(s, s.config.rand(), &magics, hostKey)
result, err := kex.Server(s.transport, s.config.rand(), &magics, hostKey)
if err != nil {
return err
}
......@@ -253,10 +262,10 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
return err
}
if err = s.writePacket([]byte{msgNewKeys}); err != nil {
if err = s.transport.writePacket([]byte{msgNewKeys}); err != nil {
return
}
if packet, err := s.readPacket(); err != nil {
if packet, err := s.transport.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
......@@ -308,7 +317,7 @@ func (s *ServerConn) authenticate(H []byte) error {
userAuthLoop:
for {
if packet, err = s.readPacket(); err != nil {
if packet, err = s.transport.readPacket(); err != nil {
return err
}
if err = unmarshal(&userAuthReq, packet, msgUserAuthRequest); err != nil {
......@@ -382,7 +391,7 @@ userAuthLoop:
Algo: algo,
PubKey: string(pubKey),
}
if err = s.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
if err = s.transport.writePacket(marshal(msgUserAuthPubKeyOk, okMsg)); err != nil {
return err
}
continue userAuthLoop
......@@ -432,13 +441,13 @@ userAuthLoop:
return errors.New("ssh: no authentication methods configured but NoClientAuth is also false")
}
if err = s.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
if err = s.transport.writePacket(marshal(msgUserAuthFailure, failureMsg)); err != nil {
return err
}
}
packet = []byte{msgUserAuthSuccess}
if err = s.writePacket(packet); err != nil {
if err = s.transport.writePacket(packet); err != nil {
return err
}
......@@ -462,7 +471,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
prompts = appendBool(prompts, echos[i])
}
if err := c.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
if err := c.transport.writePacket(marshal(msgUserAuthInfoRequest, userAuthInfoRequestMsg{
Instruction: instruction,
NumPrompts: uint32(len(questions)),
Prompts: prompts,
......@@ -470,7 +479,7 @@ func (c *sshClientKeyboardInteractive) Challenge(user, instruction string, quest
return nil, err
}
packet, err := c.readPacket()
packet, err := c.transport.readPacket()
if err != nil {
return nil, err
}
......@@ -511,7 +520,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
for {
packet, err := s.readPacket()
packet, err := s.transport.readPacket()
if err != nil {
s.lock.Lock()
......@@ -557,7 +566,7 @@ func (s *ServerConn) Accept() (Channel, error) {
}
c := &serverChan{
channel: channel{
packetConn: s,
packetConn: s.transport,
remoteId: msg.PeersId,
remoteWin: window{Cond: newCond()},
maxPacket: msg.MaxPacketSize,
......@@ -619,7 +628,7 @@ func (s *ServerConn) Accept() (Channel, error) {
case *globalRequestMsg:
if msg.WantReply {
if err := s.writePacket([]byte{msgRequestFailure}); err != nil {
if err := s.transport.writePacket([]byte{msgRequestFailure}); err != nil {
return nil, err
}
}
......
......@@ -564,7 +564,7 @@ func (s *Session) StderrPipe() (io.Reader, error) {
// NewSession returns a new interactive session on the remote host.
func (c *ClientConn) NewSession() (*Session, error) {
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenMsg{
if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenMsg{
ChanType: "session",
PeersId: ch.localId,
PeersWindow: 1 << 14,
......
......@@ -296,7 +296,7 @@ type channelOpenDirectMsg struct {
// strings and are expected to be resolvable at the remote end.
func (c *ClientConn) dial(laddr string, lport int, raddr string, rport int) (*tcpChan, error) {
ch := c.newChan(c.transport)
if err := c.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
if err := c.transport.writePacket(marshal(msgChannelOpen, channelOpenDirectMsg{
ChanType: "direct-tcpip",
PeersId: ch.localId,
PeersWindow: 1 << 14,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment