Commit 9140d54a authored by Han-Wen Nienhuys's avatar Han-Wen Nienhuys

go.crypto/ssh: move interpretation of msgNewKeys into

transport.

Sending the msgNewKeys packet and setting up the key material
now happen under a lock, preventing races with concurrent
writers.

R=kardianos, agl, jpsugar, hanwenn
CC=golang-dev
https://codereview.appspot.com/14476043

Committer: Adam Langley <agl@golang.org>
parent 1d963b1d
......@@ -43,7 +43,7 @@ func Client(c net.Conn, config *ClientConfig) (*ClientConn, error) {
func clientWithAddress(c net.Conn, addr string, config *ClientConfig) (*ClientConn, error) {
conn := &ClientConn{
transport: newTransport(c, config.rand()),
transport: newTransport(c, config.rand(), true /* is client */),
config: config,
globalRequest: globalRequest{response: make(chan interface{}, 1)},
dialAddress: addr,
......@@ -104,12 +104,12 @@ func (c *ClientConn) handshake() error {
return err
}
kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(c.transport, &clientKexInit, &serverKexInit)
if !ok {
algs := findAgreedAlgorithms(&clientKexInit, &serverKexInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
}
if serverKexInit.FirstKexFollows && kexAlgo != serverKexInit.KexAlgos[0] {
if serverKexInit.FirstKexFollows && algs.kex != serverKexInit.KexAlgos[0] {
// The server sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err := c.readPacket(); err != nil {
......@@ -117,9 +117,9 @@ func (c *ClientConn) handshake() error {
}
}
kex, ok := kexAlgoMap[kexAlgo]
kex, ok := kexAlgoMap[algs.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
}
magics := handshakeMagics{
......@@ -133,23 +133,21 @@ func (c *ClientConn) handshake() error {
return err
}
err = verifyHostKeySignature(hostKeyAlgo, result.HostKey, result.H, result.Signature)
err = verifyHostKeySignature(algs.hostKey, result.HostKey, result.H, result.Signature)
if err != nil {
return err
}
if checker := c.config.HostKeyChecker; checker != nil {
err = checker.Check(c.dialAddress, c.RemoteAddr(), hostKeyAlgo, result.HostKey)
err = checker.Check(c.dialAddress, c.RemoteAddr(), algs.hostKey, result.HostKey)
if err != nil {
return err
}
}
if err = c.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
c.transport.prepareKeyChange(algs, result)
if err = c.transport.writer.setupKeys(clientKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
if err = c.writePacket([]byte{msgNewKeys}); err != nil {
return err
}
if packet, err = c.readPacket(); err != nil {
......@@ -158,9 +156,6 @@ func (c *ClientConn) handshake() error {
if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
if err := c.transport.reader.setupKeys(serverKeys, result.K, result.H, result.H, kex.Hash()); err != nil {
return err
}
return c.authenticate(result.H)
}
......
......@@ -90,49 +90,61 @@ func findCommonCipher(clientCiphers []string, serverCiphers []string) (commonCip
return
}
func findAgreedAlgorithms(transport *transport, clientKexInit, serverKexInit *kexInitMsg) (kexAlgo, hostKeyAlgo string, ok bool) {
kexAlgo, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
type algorithms struct {
kex string
hostKey string
wCipher string
rCipher string
rMAC string
wMAC string
rCompression string
wCompression string
}
func findAgreedAlgorithms(clientKexInit, serverKexInit *kexInitMsg) (algs *algorithms) {
var ok bool
result := &algorithms{}
result.kex, ok = findCommonAlgorithm(clientKexInit.KexAlgos, serverKexInit.KexAlgos)
if !ok {
return
}
hostKeyAlgo, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
result.hostKey, ok = findCommonAlgorithm(clientKexInit.ServerHostKeyAlgos, serverKexInit.ServerHostKeyAlgos)
if !ok {
return
}
transport.writer.cipherAlgo, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
result.wCipher, ok = findCommonCipher(clientKexInit.CiphersClientServer, serverKexInit.CiphersClientServer)
if !ok {
return
}
transport.reader.cipherAlgo, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
result.rCipher, ok = findCommonCipher(clientKexInit.CiphersServerClient, serverKexInit.CiphersServerClient)
if !ok {
return
}
transport.writer.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
result.wMAC, ok = findCommonAlgorithm(clientKexInit.MACsClientServer, serverKexInit.MACsClientServer)
if !ok {
return
}
transport.reader.macAlgo, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
result.rMAC, ok = findCommonAlgorithm(clientKexInit.MACsServerClient, serverKexInit.MACsServerClient)
if !ok {
return
}
transport.writer.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
result.wCompression, ok = findCommonAlgorithm(clientKexInit.CompressionClientServer, serverKexInit.CompressionClientServer)
if !ok {
return
}
transport.reader.compressionAlgo, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
result.rCompression, ok = findCommonAlgorithm(clientKexInit.CompressionServerClient, serverKexInit.CompressionServerClient)
if !ok {
return
}
ok = true
return
return result
}
// Cryptographic configuration common to both ServerConfig and ClientConfig.
......
......@@ -35,6 +35,15 @@ type kexResult struct {
// Signature of H
Signature []byte
// A cryptographic hash function that matches the security
// level of the key exchange algorithm. It is used for
// calculating H, and for deriving keys from H and K.
Hash crypto.Hash
// The session ID, which is the first H computed. This is used
// to signal data inside transport.
SessionID []byte
}
// handshakeMagics contains data that is always included in the
......@@ -60,12 +69,6 @@ type kexAlgorithm interface {
// Client runs the client-side key agreement. Caller is
// responsible for verifying the host key signature.
Client(p packetConn, rand io.Reader, magics *handshakeMagics) (*kexResult, error)
// Hash returns a cryptographic hash function that matches the
// security level of the key exchange algorithm. It is used
// for calculating kexResult.H, and for deriving keys from
// data in kexResult.
Hash() crypto.Hash
}
// dhGroup is a multiplicative group suitable for implementing Diffie-Hellman key agreement.
......@@ -73,10 +76,6 @@ type dhGroup struct {
g, p *big.Int
}
func (group *dhGroup) Hash() crypto.Hash {
return crypto.SHA1
}
func (group *dhGroup) diffieHellman(theirPublic, myPrivate *big.Int) (*big.Int, error) {
if theirPublic.Sign() <= 0 || theirPublic.Cmp(group.p) >= 0 {
return nil, errors.New("ssh: DH parameter out of bounds")
......@@ -128,6 +127,7 @@ func (group *dhGroup) Client(c packetConn, randSource io.Reader, magics *handsha
K: K,
HostKey: kexDHReply.HostKey,
Signature: kexDHReply.Signature,
Hash: crypto.SHA1,
}, nil
}
......@@ -187,6 +187,7 @@ func (group *dhGroup) Server(c packetConn, randSource io.Reader, magics *handsha
K: K,
HostKey: hostKeyBytes,
Signature: sig,
Hash: crypto.SHA1,
}, nil
}
......@@ -243,6 +244,7 @@ func (kex *ecdh) Client(c packetConn, rand io.Reader, magics *handshakeMagics) (
K: K,
HostKey: reply.HostKey,
Signature: reply.Signature,
Hash: ecHash(kex.curve),
}, nil
}
......@@ -354,13 +356,10 @@ func (kex *ecdh) Server(c packetConn, rand io.Reader, magics *handshakeMagics, p
K: K,
HostKey: reply.HostKey,
Signature: sig,
Hash: ecHash(kex.curve),
}, nil
}
func (kex *ecdh) Hash() crypto.Hash {
return ecHash(kex.curve)
}
var kexAlgoMap = map[string]kexAlgorithm{}
func init() {
......
......@@ -121,17 +121,13 @@ type ServerConn struct {
// ClientVersion is the client's version, populated after
// Handshake is called. It should not be modified.
ClientVersion []byte
// Initial H used for the session ID. Once assigned this must not change
// even during subsequent key exchanges.
sessionId []byte
}
// Server returns a new SSH server connection
// using c as the underlying transport.
func Server(c net.Conn, config *ServerConfig) *ServerConn {
return &ServerConn{
transport: newTransport(c, config.rand()),
transport: newTransport(c, config.rand(), false /* not client */),
channels: make(map[uint32]*serverChan),
config: config,
}
......@@ -186,7 +182,7 @@ func (s *ServerConn) Handshake() (err error) {
return
}
if err = s.authenticate(s.sessionId); err != nil {
if err = s.authenticate(s.transport.sessionID); err != nil {
return
}
return
......@@ -222,11 +218,12 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
}
}
kexAlgo, hostKeyAlgo, ok := findAgreedAlgorithms(s.transport, clientKexInit, &serverKexInit)
if !ok {
algs := findAgreedAlgorithms(clientKexInit, &serverKexInit)
if algs == nil {
return errors.New("ssh: no common algorithms")
}
if clientKexInit.FirstKexFollows && kexAlgo != clientKexInit.KexAlgos[0] {
if clientKexInit.FirstKexFollows && algs.kex != clientKexInit.KexAlgos[0] {
// The client sent a Kex message for the wrong algorithm,
// which we have to ignore.
if _, err = s.readPacket(); err != nil {
......@@ -236,14 +233,14 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
var hostKey Signer
for _, k := range s.config.hostKeys {
if hostKeyAlgo == k.PublicKey().PublicKeyAlgo() {
if algs.hostKey == k.PublicKey().PublicKeyAlgo() {
hostKey = k
}
}
kex, ok := kexAlgoMap[kexAlgo]
kex, ok := kexAlgoMap[algs.kex]
if !ok {
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", kexAlgo)
return fmt.Errorf("ssh: unexpected key exchange algorithm %v", algs.kex)
}
magics := handshakeMagics{
......@@ -257,29 +254,18 @@ func (s *ServerConn) clientInitHandshake(clientKexInit *kexInitMsg, clientKexIni
return err
}
// sessionId must only be assigned during initial handshake.
if s.sessionId == nil {
s.sessionId = result.H
if err = s.transport.prepareKeyChange(algs, result); err != nil {
return err
}
var packet []byte
if err = s.writePacket([]byte{msgNewKeys}); err != nil {
return
}
if err = s.transport.writer.setupKeys(serverKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
return
}
if packet, err = s.readPacket(); err != nil {
return
}
if packet[0] != msgNewKeys {
if packet, err := s.readPacket(); err != nil {
return err
} else if packet[0] != msgNewKeys {
return UnexpectedMessageError{msgNewKeys, packet[0]}
}
if err = s.transport.reader.setupKeys(clientKeys, result.K, result.H, s.sessionId, kex.Hash()); err != nil {
return
}
return
}
......
......@@ -6,7 +6,6 @@ package ssh
import (
"bufio"
"crypto"
"crypto/cipher"
"crypto/subtle"
"encoding/binary"
......@@ -48,6 +47,10 @@ type transport struct {
writer
net.Conn
// Initial H used for the session ID. Once assigned this does
// not change, even during subsequent key exchanges.
sessionID []byte
}
// reader represents the incoming connection state.
......@@ -64,6 +67,28 @@ type writer struct {
common
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
func (t *transport) prepareKeyChange(algs *algorithms, kexResult *kexResult) error {
t.writer.cipherAlgo = algs.wCipher
t.writer.macAlgo = algs.wMAC
t.writer.compressionAlgo = algs.wCompression
t.reader.cipherAlgo = algs.rCipher
t.reader.macAlgo = algs.rMAC
t.reader.compressionAlgo = algs.rCompression
if t.sessionID == nil {
t.sessionID = kexResult.H
}
kexResult.SessionID = t.sessionID
t.reader.pendingKeyChange <- kexResult
t.writer.pendingKeyChange <- kexResult
return nil
}
// common represents the cipher state needed to process messages in a single
// direction.
type common struct {
......@@ -74,6 +99,9 @@ type common struct {
cipherAlgo string
macAlgo string
compressionAlgo string
dir direction
pendingKeyChange chan *kexResult
}
// Read and decrypt a single packet from the remote peer.
......@@ -125,7 +153,19 @@ func (r *reader) readPacket() ([]byte, error) {
}
r.seqNum++
return packet[:length-paddingLength-1], nil
packet = packet[:length-paddingLength-1]
if len(packet) > 0 && packet[0] == msgNewKeys {
select {
case k := <-r.pendingKeyChange:
if err := r.setupKeys(r.dir, k); err != nil {
return nil, err
}
default:
return nil, errors.New("ssh: got bogus newkeys message.")
}
}
return packet, nil
}
// Read and decrypt next packet discarding debug and noop messages.
......@@ -138,6 +178,7 @@ func (t *transport) readPacket() ([]byte, error) {
if len(packet) == 0 {
return nil, errors.New("ssh: zero length packet")
}
if packet[0] != msgIgnore && packet[0] != msgDebug {
return packet, nil
}
......@@ -147,6 +188,8 @@ func (t *transport) readPacket() ([]byte, error) {
// Encrypt and send a packet of data to the remote peer.
func (w *writer) writePacket(packet []byte) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
if len(packet) > maxPacket {
return errors.New("ssh: packet too large")
}
......@@ -209,26 +252,49 @@ func (w *writer) writePacket(packet []byte) error {
}
w.seqNum++
return w.Flush()
if err = w.Flush(); err != nil {
return err
}
if changeKeys {
select {
case k := <-w.pendingKeyChange:
err = w.setupKeys(w.dir, k)
default:
panic("ssh: no key material for msgNewKeys")
}
}
return err
}
func newTransport(conn net.Conn, rand io.Reader) *transport {
return &transport{
func newTransport(conn net.Conn, rand io.Reader, isClient bool) *transport {
t := &transport{
reader: reader{
Reader: bufio.NewReader(conn),
common: common{
cipher: noneCipher{},
cipher: noneCipher{},
pendingKeyChange: make(chan *kexResult, 1),
},
},
writer: writer{
Writer: bufio.NewWriter(conn),
rand: rand,
common: common{
cipher: noneCipher{},
cipher: noneCipher{},
pendingKeyChange: make(chan *kexResult, 1),
},
},
Conn: conn,
}
if isClient {
t.reader.dir = serverKeys
t.writer.dir = clientKeys
} else {
t.reader.dir = clientKeys
t.writer.dir = serverKeys
}
return t
}
type direction struct {
......@@ -246,7 +312,7 @@ var (
// setupKeys sets the cipher and MAC keys from kex.K, kex.H and sessionId, as
// described in RFC 4253, section 6.4. direction should either be serverKeys
// (to setup server->client keys) or clientKeys (for client->server keys).
func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.Hash) error {
func (c *common) setupKeys(d direction, r *kexResult) error {
cipherMode := cipherModes[c.cipherAlgo]
macMode := macModes[c.macAlgo]
......@@ -254,10 +320,10 @@ func (c *common) setupKeys(d direction, K, H, sessionId []byte, hashFunc crypto.
key := make([]byte, cipherMode.keySize)
macKey := make([]byte, macMode.keySize)
h := hashFunc.New()
generateKeyMaterial(iv, d.ivTag, K, H, sessionId, h)
generateKeyMaterial(key, d.keyTag, K, H, sessionId, h)
generateKeyMaterial(macKey, d.macKeyTag, K, H, sessionId, h)
h := r.Hash.New()
generateKeyMaterial(iv, d.ivTag, r.K, r.H, r.SessionID, h)
generateKeyMaterial(key, d.keyTag, r.K, r.H, r.SessionID, h)
generateKeyMaterial(macKey, d.macKeyTag, r.K, r.H, r.SessionID, h)
c.mac = macMode.new(macKey)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment