Commit d295108f authored by Han-Wen Nienhuys's avatar Han-Wen Nienhuys

go.crypto/ssh: put version exchange in function

R=golang-dev, dave, jpsugar, agl
CC=golang-dev
https://codereview.appspot.com/14641044

Committer: Dave Cheney <dave@cheney.net>
parent 74d09edc
......@@ -14,9 +14,6 @@ import (
"sync"
)
// clientVersion is the default identification string that the client will use.
var clientVersion = []byte("SSH-2.0-Go")
// ClientConn represents the client side of an SSH connection.
type ClientConn struct {
*transport
......@@ -59,22 +56,12 @@ func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientCo
// handshake performs the client side key exchange. See RFC 4253 Section 7.
func (c *ClientConn) handshake() error {
var myVersion []byte
if len(c.config.ClientVersion) > 0 {
myVersion = []byte(c.config.ClientVersion)
} else {
myVersion = clientVersion
}
if _, err := c.Write(append(myVersion, '\r', '\n')); err != nil {
return err
}
if err := c.Flush(); err != nil {
return err
clientVersion := []byte(packageVersion)
if c.config.ClientVersion != "" {
clientVersion = []byte(c.config.ClientVersion)
}
// read remote server version
serverVersion, err := readVersion(c)
serverVersion, err := exchangeVersions(c.transport.Conn, clientVersion)
if err != nil {
return err
}
......@@ -123,7 +110,7 @@ func (c *ClientConn) handshake() error {
}
magics := handshakeMagics{
clientVersion: myVersion,
clientVersion: clientVersion,
serverVersion: serverVersion,
clientKexInit: kexInitPacket,
serverKexInit: packet,
......
......@@ -30,5 +30,5 @@ func TestCustomClientVersion(t *testing.T) {
}
func TestDefaultClientVersion(t *testing.T) {
testClientVersion(t, &ClientConfig{}, string(clientVersion))
testClientVersion(t, &ClientConfig{}, packageVersion)
}
......@@ -121,6 +121,9 @@ type ServerConn struct {
// ClientVersion is the client's version, populated after
// Handshake is called. It should not be modified.
ClientVersion []byte
// Our version.
serverVersion []byte
}
// Server returns a new SSH server connection
......@@ -144,33 +147,25 @@ func signAndMarshal(k Signer, rand io.Reader, data []byte) ([]byte, error) {
return serializeSignature(k.PublicKey().PrivateKeyAlgo(), sig), nil
}
// serverVersion is the fixed identification string that Server will use.
var serverVersion = []byte("SSH-2.0-Go\r\n")
// Handshake performs an SSH transport and client authentication on the given ServerConn.
func (s *ServerConn) Handshake() (err error) {
if _, err = s.Write(serverVersion); err != nil {
return
}
if err := s.Flush(); err != nil {
return err
}
s.ClientVersion, err = readVersion(s)
func (s *ServerConn) Handshake() error {
var err error
s.serverVersion = []byte(packageVersion)
s.ClientVersion, err = exchangeVersions(s.transport.Conn, s.serverVersion)
if err != nil {
return
return err
}
if err = s.clientInitHandshake(nil, nil); err != nil {
return
if err := s.clientInitHandshake(nil, nil); err != nil {
return err
}
var packet []byte
if packet, err = s.readPacket(); err != nil {
return
return err
}
var serviceRequest serviceRequestMsg
if err = unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
return
if err := unmarshal(&serviceRequest, packet, msgServiceRequest); err != nil {
return err
}
if serviceRequest.Service != serviceUserAuth {
return errors.New("ssh: requested service '" + serviceRequest.Service + "' before authenticating")
......@@ -178,14 +173,14 @@ func (s *ServerConn) Handshake() (err error) {
serviceAccept := serviceAcceptMsg{
Service: serviceUserAuth,
}
if err = s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return
if err := s.writePacket(marshal(msgServiceAccept, serviceAccept)); err != nil {
return err
}
if err = s.authenticate(s.transport.sessionID); err != nil {
return
if err := s.authenticate(s.transport.sessionID); err != nil {
return err
}
return
return err
}
func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexInitPacket []byte) (err error) {
......@@ -244,7 +239,7 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
}
magics := handshakeMagics{
serverVersion: serverVersion[:len(serverVersion)-2],
serverVersion: s.serverVersion,
clientVersion: s.ClientVersion,
serverKexInit: marshal(msgKexInit, serverKexInit),
clientKexInit: clientKexInitPacket,
......
......@@ -358,18 +358,41 @@ func generateKeyMaterial(out, tag []byte, K, H, sessionId []byte, h hash.Hash) {
}
}
// maxVersionStringBytes is the maximum number of bytes that we'll accept as a
// version string. In the event that the client is talking a different protocol
// we need to set a limit otherwise we will keep using more and more memory
// while searching for the end of the version handshake.
const maxVersionStringBytes = 1024
const packageVersion = "SSH-2.0-Go"
// Sends and receives a version line. The versionLine string should
// be US ASCII, start with "SSH-2.0-", and should not include a
// newline. exchangeVersions returns the other side's version line.
func exchangeVersions(rw io.ReadWriter, versionLine []byte) (them []byte, err error) {
// Contrary to the RFC, we do not ignore lines that don't
// start with "SSH-2.0-" to make the library usable with
// nonconforming servers.
for _, c := range versionLine {
// The spec disallows non US-ASCII chars, and
// specifically forbids null chars.
if c < 32 {
return nil, errors.New("ssh: junk character in version line")
}
}
if _, err = rw.Write(append(versionLine, '\r', '\n')); err != nil {
return
}
them, err = readVersion(rw)
return them, err
}
// maxVersionStringBytes is the maximum number of bytes that we'll
// accept as a version string. RFC 4253 section 4.2 limits this at 255
// chars
const maxVersionStringBytes = 255
// Read version string as specified by RFC 4253, section 4.2.
func readVersion(r io.Reader) ([]byte, error) {
versionString := make([]byte, 0, 64)
var ok bool
var buf [1]byte
forEachByte:
for len(versionString) < maxVersionStringBytes {
_, err := io.ReadFull(r, buf[:])
if err != nil {
......@@ -379,13 +402,20 @@ forEachByte:
// but several SSH servers actually only send a \n.
if buf[0] == '\n' {
ok = true
break forEachByte
break
}
// non ASCII chars are disallowed, but we are lenient,
// since Go doesn't use null-terminated strings.
// The RFC allows a comment after a space, however,
// all of it (version and comments) goes into the
// session hash.
versionString = append(versionString, buf[0])
}
if !ok {
return nil, errors.New("ssh: failed to read version string")
return nil, errors.New("ssh: overflow reading version string")
}
// There might be a '\r' on the end which we should remove.
......
......@@ -5,47 +5,65 @@
package ssh
import (
"bufio"
"bytes"
"strings"
"testing"
)
func TestReadVersion(t *testing.T) {
buf := serverVersion
result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := map[string]string{
"SSH-2.0-bla\r\n": "SSH-2.0-bla",
"SSH-2.0-bla\n": "SSH-2.0-bla",
longversion + "\r\n": longversion,
}
for in, want := range cases {
result, err := readVersion(bytes.NewBufferString(in))
if err != nil {
t.Errorf("readVersion didn't read version correctly: %s", err)
t.Errorf("readVersion(%q): %s", in, err)
}
got := string(result)
if got != want {
t.Errorf("got %q, want %q", got, want)
}
if !bytes.Equal(buf[:len(buf)-2], result) {
t.Error("version read did not match expected")
}
}
func TestReadVersionWithJustLF(t *testing.T) {
var buf []byte
buf = append(buf, serverVersion...)
buf = buf[:len(buf)-1]
buf[len(buf)-1] = '\n'
result, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf)))
if err != nil {
t.Error("readVersion failed to handle just a \n")
func TestReadVersionError(t *testing.T) {
longversion := strings.Repeat("SSH-2.0-bla", 50)[:253]
cases := []string{
longversion + "too-long\r\n",
}
for _, in := range cases {
if _, err := readVersion(bytes.NewBufferString(in)); err == nil {
t.Errorf("readVersion(%q) should have failed", in)
}
if !bytes.Equal(buf[:len(buf)-1], result) {
t.Errorf("version read did not match expected: got %x, want %x", result, buf[:len(buf)-1])
}
}
func TestReadVersionTooLong(t *testing.T) {
buf := make([]byte, maxVersionStringBytes+1)
if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Errorf("readVersion consumed %d bytes without error", len(buf))
func TestExchangeVersionsBasic(t *testing.T) {
v := "SSH-2.0-bla"
buf := bytes.NewBufferString(v + "\r\n")
them, err := exchangeVersions(buf, []byte("xyz"))
if err != nil {
t.Errorf("exchangeVersions: %v", err)
}
if want := "SSH-2.0-bla"; string(them) != want {
t.Errorf("got %q want %q for our version", them, want)
}
}
func TestReadVersionWithoutCRLF(t *testing.T) {
buf := serverVersion
buf = buf[:len(buf)-1]
if _, err := readVersion(bufio.NewReader(bytes.NewBuffer(buf))); err == nil {
t.Error("readVersion did not notice \\n was missing")
func TestExchangeVersions(t *testing.T) {
cases := []string{
"not\x000allowed",
"not allowed\n",
}
for _, c := range cases {
buf := bytes.NewBufferString("SSH-2.0-bla\r\n")
if _, err := exchangeVersions(buf, []byte(c)); err == nil {
t.Errorf("exchangeVersions(%q): should have failed", c)
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment