Created
June 9, 2017 21:21
-
-
Save meatballhat/eda84ef33b09fc846591a48ec2330b59 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/ssh/agent/client_test.go b/ssh/agent/client_test.go | |
index a13a650..5fc47e5 100644 | |
--- a/ssh/agent/client_test.go | |
+++ b/ssh/agent/client_test.go | |
@@ -180,9 +180,12 @@ func TestCert(t *testing.T) { | |
// therefore is buffered (net.Pipe deadlocks if both sides start with | |
// a write.) | |
func netPipe() (net.Conn, net.Conn, error) { | |
- listener, err := net.Listen("tcp", ":0") | |
+ listener, err := net.Listen("tcp", "127.0.0.1:0") | |
if err != nil { | |
- return nil, nil, err | |
+ listener, err = net.Listen("tcp", "[::1]:0") | |
+ if err != nil { | |
+ return nil, nil, err | |
+ } | |
} | |
defer listener.Close() | |
c1, err := net.Dial("tcp", listener.Addr().String()) | |
@@ -200,6 +203,9 @@ func netPipe() (net.Conn, net.Conn, error) { | |
} | |
func TestAuth(t *testing.T) { | |
+ agent, _, cleanup := startAgent(t) | |
+ defer cleanup() | |
+ | |
a, b, err := netPipe() | |
if err != nil { | |
t.Fatalf("netPipe: %v", err) | |
@@ -208,9 +214,6 @@ func TestAuth(t *testing.T) { | |
defer a.Close() | |
defer b.Close() | |
- agent, _, cleanup := startAgent(t) | |
- defer cleanup() | |
- | |
if err := agent.Add(AddedKey{PrivateKey: testPrivateKeys["rsa"], Comment: "comment"}); err != nil { | |
t.Errorf("Add: %v", err) | |
} | |
@@ -233,7 +236,9 @@ func TestAuth(t *testing.T) { | |
conn.Close() | |
}() | |
- conf := ssh.ClientConfig{} | |
+ conf := ssh.ClientConfig{ | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
+ } | |
conf.Auth = append(conf.Auth, ssh.PublicKeysCallback(agent.Signers)) | |
conn, _, _, err := ssh.NewClientConn(b, "", &conf) | |
if err != nil { | |
diff --git a/ssh/agent/example_test.go b/ssh/agent/example_test.go | |
index c1130f7..8556225 100644 | |
--- a/ssh/agent/example_test.go | |
+++ b/ssh/agent/example_test.go | |
@@ -6,20 +6,20 @@ package agent_test | |
import ( | |
"log" | |
- "os" | |
"net" | |
+ "os" | |
- "golang.org/x/crypto/ssh" | |
- "golang.org/x/crypto/ssh/agent" | |
+ "golang.org/x/crypto/ssh" | |
+ "golang.org/x/crypto/ssh/agent" | |
) | |
func ExampleClientAgent() { | |
// ssh-agent has a UNIX socket under $SSH_AUTH_SOCK | |
socket := os.Getenv("SSH_AUTH_SOCK") | |
- conn, err := net.Dial("unix", socket) | |
- if err != nil { | |
- log.Fatalf("net.Dial: %v", err) | |
- } | |
+ conn, err := net.Dial("unix", socket) | |
+ if err != nil { | |
+ log.Fatalf("net.Dial: %v", err) | |
+ } | |
agentClient := agent.NewClient(conn) | |
config := &ssh.ClientConfig{ | |
User: "username", | |
@@ -29,6 +29,7 @@ func ExampleClientAgent() { | |
// wants it. | |
ssh.PublicKeysCallback(agentClient.Signers), | |
}, | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
} | |
sshc, err := ssh.Dial("tcp", "localhost:22", config) | |
diff --git a/ssh/agent/server_test.go b/ssh/agent/server_test.go | |
index ec9cdee..6b0837d 100644 | |
--- a/ssh/agent/server_test.go | |
+++ b/ssh/agent/server_test.go | |
@@ -56,7 +56,9 @@ func TestSetupForwardAgent(t *testing.T) { | |
incoming <- conn | |
}() | |
- conf := ssh.ClientConfig{} | |
+ conf := ssh.ClientConfig{ | |
+ HostKeyCallback: ssh.InsecureIgnoreHostKey(), | |
+ } | |
conn, chans, reqs, err := ssh.NewClientConn(b, "", &conf) | |
if err != nil { | |
t.Fatalf("NewClientConn: %v", err) | |
diff --git a/ssh/certs.go b/ssh/certs.go | |
index 6331c94..b1f0220 100644 | |
--- a/ssh/certs.go | |
+++ b/ssh/certs.go | |
@@ -251,10 +251,18 @@ type CertChecker struct { | |
// for user certificates. | |
SupportedCriticalOptions []string | |
- // IsAuthority should return true if the key is recognized as | |
- // an authority. This allows for certificates to be signed by other | |
- // certificates. | |
- IsAuthority func(auth PublicKey) bool | |
+ // IsUserAuthority should return true if the key is recognized as an | |
+ // authority for the given user certificate. This allows for | |
+ // certificates to be signed by other certificates. This must be set | |
+ // if this CertChecker will be checking user certificates. | |
+ IsUserAuthority func(auth PublicKey) bool | |
+ | |
+ // IsHostAuthority should report whether the key is recognized as | |
+ // an authority for this host. This allows for certificates to be | |
+ // signed by other keys, and for those other keys to only be valid | |
+ // signers for particular hostnames. This must be set if this | |
+ // CertChecker will be checking host certificates. | |
+ IsHostAuthority func(auth PublicKey, address string) bool | |
// Clock is used for verifying time stamps. If nil, time.Now | |
// is used. | |
@@ -268,7 +276,7 @@ type CertChecker struct { | |
// HostKeyFallback is called when CertChecker.CheckHostKey encounters a | |
// public key that is not a certificate. It must implement host key | |
// validation or else, if nil, all such keys are rejected. | |
- HostKeyFallback func(addr string, remote net.Addr, key PublicKey) error | |
+ HostKeyFallback HostKeyCallback | |
// IsRevoked is called for each certificate so that revocation checking | |
// can be implemented. It should return true if the given certificate | |
@@ -290,8 +298,17 @@ func (c *CertChecker) CheckHostKey(addr string, remote net.Addr, key PublicKey) | |
if cert.CertType != HostCert { | |
return fmt.Errorf("ssh: certificate presented as a host key has type %d", cert.CertType) | |
} | |
+ if !c.IsHostAuthority(cert.SignatureKey, addr) { | |
+ return fmt.Errorf("ssh: no authorities for hostname: %v", addr) | |
+ } | |
+ | |
+ hostname, _, err := net.SplitHostPort(addr) | |
+ if err != nil { | |
+ return err | |
+ } | |
- return c.CheckCert(addr, cert) | |
+ // Pass hostname only as principal for host certificates (consistent with OpenSSH) | |
+ return c.CheckCert(hostname, cert) | |
} | |
// Authenticate checks a user certificate. Authenticate can be used as | |
@@ -308,6 +325,9 @@ func (c *CertChecker) Authenticate(conn ConnMetadata, pubKey PublicKey) (*Permis | |
if cert.CertType != UserCert { | |
return nil, fmt.Errorf("ssh: cert has type %d", cert.CertType) | |
} | |
+ if !c.IsUserAuthority(cert.SignatureKey) { | |
+ return nil, fmt.Errorf("ssh: certificate signed by unrecognized authority") | |
+ } | |
if err := c.CheckCert(conn.User(), cert); err != nil { | |
return nil, err | |
@@ -356,10 +376,6 @@ func (c *CertChecker) CheckCert(principal string, cert *Certificate) error { | |
} | |
} | |
- if !c.IsAuthority(cert.SignatureKey) { | |
- return fmt.Errorf("ssh: certificate signed by unrecognized authority") | |
- } | |
- | |
clock := c.Clock | |
if clock == nil { | |
clock = time.Now | |
diff --git a/ssh/certs_test.go b/ssh/certs_test.go | |
index c5f2e53..0200531 100644 | |
--- a/ssh/certs_test.go | |
+++ b/ssh/certs_test.go | |
@@ -104,7 +104,7 @@ func TestValidateCert(t *testing.T) { | |
t.Fatalf("got %v (%T), want *Certificate", key, key) | |
} | |
checker := CertChecker{} | |
- checker.IsAuthority = func(k PublicKey) bool { | |
+ checker.IsUserAuthority = func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), validCert.SignatureKey.Marshal()) | |
} | |
@@ -142,7 +142,7 @@ func TestValidateCertTime(t *testing.T) { | |
checker := CertChecker{ | |
Clock: func() time.Time { return time.Unix(ts, 0) }, | |
} | |
- checker.IsAuthority = func(k PublicKey) bool { | |
+ checker.IsUserAuthority = func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), | |
testPublicKeys["ecdsa"].Marshal()) | |
} | |
@@ -160,7 +160,7 @@ func TestValidateCertTime(t *testing.T) { | |
func TestHostKeyCert(t *testing.T) { | |
cert := &Certificate{ | |
- ValidPrincipals: []string{"hostname", "hostname.domain"}, | |
+ ValidPrincipals: []string{"hostname", "hostname.domain", "otherhost"}, | |
Key: testPublicKeys["rsa"], | |
ValidBefore: CertTimeInfinity, | |
CertType: HostCert, | |
@@ -168,8 +168,8 @@ func TestHostKeyCert(t *testing.T) { | |
cert.SignCert(rand.Reader, testSigners["ecdsa"]) | |
checker := &CertChecker{ | |
- IsAuthority: func(p PublicKey) bool { | |
- return bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) | |
+ IsHostAuthority: func(p PublicKey, addr string) bool { | |
+ return addr == "hostname:22" && bytes.Equal(testPublicKeys["ecdsa"].Marshal(), p.Marshal()) | |
}, | |
} | |
@@ -178,7 +178,14 @@ func TestHostKeyCert(t *testing.T) { | |
t.Errorf("NewCertSigner: %v", err) | |
} | |
- for _, name := range []string{"hostname", "otherhost"} { | |
+ for _, test := range []struct { | |
+ addr string | |
+ succeed bool | |
+ }{ | |
+ {addr: "hostname:22", succeed: true}, | |
+ {addr: "otherhost:22", succeed: false}, // The certificate is valid for 'otherhost' as hostname, but we only recognize the authority of the signer for the address 'hostname:22' | |
+ {addr: "lasthost:22", succeed: false}, | |
+ } { | |
c1, c2, err := netPipe() | |
if err != nil { | |
t.Fatalf("netPipe: %v", err) | |
@@ -201,16 +208,15 @@ func TestHostKeyCert(t *testing.T) { | |
User: "user", | |
HostKeyCallback: checker.CheckHostKey, | |
} | |
- _, _, _, err = NewClientConn(c2, name, config) | |
+ _, _, _, err = NewClientConn(c2, test.addr, config) | |
- succeed := name == "hostname" | |
- if (err == nil) != succeed { | |
- t.Fatalf("NewClientConn(%q): %v", name, err) | |
+ if (err == nil) != test.succeed { | |
+ t.Fatalf("NewClientConn(%q): %v", test.addr, err) | |
} | |
err = <-errc | |
- if (err == nil) != succeed { | |
- t.Fatalf("NewServerConn(%q): %v", name, err) | |
+ if (err == nil) != test.succeed { | |
+ t.Fatalf("NewServerConn(%q): %v", test.addr, err) | |
} | |
} | |
} | |
diff --git a/ssh/client.go b/ssh/client.go | |
index c97f297..a7e3263 100644 | |
--- a/ssh/client.go | |
+++ b/ssh/client.go | |
@@ -5,6 +5,7 @@ | |
package ssh | |
import ( | |
+ "bytes" | |
"errors" | |
"fmt" | |
"net" | |
@@ -13,7 +14,7 @@ import ( | |
) | |
// Client implements a traditional SSH client that supports shells, | |
-// subprocesses, port forwarding and tunneled dialing. | |
+// subprocesses, TCP port/streamlocal forwarding and tunneled dialing. | |
type Client struct { | |
Conn | |
@@ -59,6 +60,7 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { | |
conn.forwards.closeAll() | |
}() | |
go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-tcpip")) | |
+ go conn.forwards.handleChannels(conn.HandleChannelOpen("forwarded-streamlocal@openssh.com")) | |
return conn | |
} | |
@@ -68,6 +70,11 @@ func NewClient(c Conn, chans <-chan NewChannel, reqs <-chan *Request) *Client { | |
func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan NewChannel, <-chan *Request, error) { | |
fullConf := *config | |
fullConf.SetDefaults() | |
+ if fullConf.HostKeyCallback == nil { | |
+ c.Close() | |
+ return nil, nil, nil, errors.New("ssh: must specify HostKeyCallback") | |
+ } | |
+ | |
conn := &connection{ | |
sshConn: sshConn{conn: c}, | |
} | |
@@ -173,6 +180,13 @@ func Dial(network, addr string, config *ClientConfig) (*Client, error) { | |
return NewClient(c, chans, reqs), nil | |
} | |
+// HostKeyCallback is the function type used for verifying server | |
+// keys. A HostKeyCallback must return nil if the host key is OK, or | |
+// an error to reject it. It receives the hostname as passed to Dial | |
+// or NewClientConn. The remote address is the RemoteAddr of the | |
+// net.Conn underlying the SSH connection. | |
+type HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ | |
// A ClientConfig structure is used to configure a Client. It must not be | |
// modified after having been passed to an SSH function. | |
type ClientConfig struct { | |
@@ -188,10 +202,12 @@ type ClientConfig struct { | |
// be used during authentication. | |
Auth []AuthMethod | |
- // HostKeyCallback, if not nil, is called during the cryptographic | |
- // handshake to validate the server's host key. A nil HostKeyCallback | |
- // implies that all host keys are accepted. | |
- HostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ // HostKeyCallback is called during the cryptographic | |
+ // handshake to validate the server's host key. The client | |
+ // configuration must supply this callback for the connection | |
+ // to succeed. The functions InsecureIgnoreHostKey or | |
+ // FixedHostKey can be used for simplistic host key checks. | |
+ HostKeyCallback HostKeyCallback | |
// ClientVersion contains the version identification string that will | |
// be used for the connection. If empty, a reasonable default is used. | |
@@ -209,3 +225,33 @@ type ClientConfig struct { | |
// A Timeout of zero means no timeout. | |
Timeout time.Duration | |
} | |
+ | |
+// InsecureIgnoreHostKey returns a function that can be used for | |
+// ClientConfig.HostKeyCallback to accept any host key. It should | |
+// not be used for production code. | |
+func InsecureIgnoreHostKey() HostKeyCallback { | |
+ return func(hostname string, remote net.Addr, key PublicKey) error { | |
+ return nil | |
+ } | |
+} | |
+ | |
+type fixedHostKey struct { | |
+ key PublicKey | |
+} | |
+ | |
+func (f *fixedHostKey) check(hostname string, remote net.Addr, key PublicKey) error { | |
+ if f.key == nil { | |
+ return fmt.Errorf("ssh: required host key was nil") | |
+ } | |
+ if !bytes.Equal(key.Marshal(), f.key.Marshal()) { | |
+ return fmt.Errorf("ssh: host key mismatch") | |
+ } | |
+ return nil | |
+} | |
+ | |
+// FixedHostKey returns a function for use in | |
+// ClientConfig.HostKeyCallback to accept only a specific host key. | |
+func FixedHostKey(key PublicKey) HostKeyCallback { | |
+ hk := &fixedHostKey{key} | |
+ return hk.check | |
+} | |
diff --git a/ssh/client_auth.go b/ssh/client_auth.go | |
index fd1ec5d..b882da0 100644 | |
--- a/ssh/client_auth.go | |
+++ b/ssh/client_auth.go | |
@@ -179,31 +179,26 @@ func (cb publicKeyCallback) method() string { | |
} | |
func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand io.Reader) (bool, []string, error) { | |
- // Authentication is performed in two stages. The first stage sends an | |
- // enquiry to test if each key is acceptable to the remote. The second | |
- // stage attempts to authenticate with the valid keys obtained in the | |
- // first stage. | |
+ // Authentication is performed by sending an enquiry to test if a key is | |
+ // acceptable to the remote. If the key is acceptable, the client will | |
+ // attempt to authenticate with the valid key. If not, the client will repeat | |
+ // the process with the remaining keys. | |
signers, err := cb() | |
if err != nil { | |
return false, nil, err | |
} | |
- var validKeys []Signer | |
+ var methods []string | |
for _, signer := range signers { | |
- if ok, err := validateKey(signer.PublicKey(), user, c); ok { | |
- validKeys = append(validKeys, signer) | |
- } else { | |
- if err != nil { | |
- return false, nil, err | |
- } | |
+ ok, err := validateKey(signer.PublicKey(), user, c) | |
+ if err != nil { | |
+ return false, nil, err | |
+ } | |
+ if !ok { | |
+ continue | |
} | |
- } | |
- // methods that may continue if this auth is not successful. | |
- var methods []string | |
- for _, signer := range validKeys { | |
pub := signer.PublicKey() | |
- | |
pubKey := pub.Marshal() | |
sign, err := signer.Sign(rand, buildDataSignedForAuth(session, userAuthRequestMsg{ | |
User: user, | |
@@ -236,13 +231,29 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand | |
if err != nil { | |
return false, nil, err | |
} | |
- if success { | |
+ | |
+ // If authentication succeeds or the list of available methods does not | |
+ // contain the "publickey" method, do not attempt to authenticate with any | |
+ // other keys. According to RFC 4252 Section 7, the latter can occur when | |
+ // additional authentication methods are required. | |
+ if success || !containsMethod(methods, cb.method()) { | |
return success, methods, err | |
} | |
} | |
+ | |
return false, methods, nil | |
} | |
+func containsMethod(methods []string, method string) bool { | |
+ for _, m := range methods { | |
+ if m == method { | |
+ return true | |
+ } | |
+ } | |
+ | |
+ return false | |
+} | |
+ | |
// validateKey validates the key provided is acceptable to the server. | |
func validateKey(key PublicKey, user string, c packetConn) (bool, error) { | |
pubKey := key.Marshal() | |
diff --git a/ssh/client_auth_test.go b/ssh/client_auth_test.go | |
index e384c79..bd9f8a1 100644 | |
--- a/ssh/client_auth_test.go | |
+++ b/ssh/client_auth_test.go | |
@@ -38,7 +38,7 @@ func tryAuth(t *testing.T, config *ClientConfig) error { | |
defer c2.Close() | |
certChecker := CertChecker{ | |
- IsAuthority: func(k PublicKey) bool { | |
+ IsUserAuthority: func(k PublicKey) bool { | |
return bytes.Equal(k.Marshal(), testPublicKeys["ecdsa"].Marshal()) | |
}, | |
UserKeyFallback: func(conn ConnMetadata, key PublicKey) (*Permissions, error) { | |
@@ -76,8 +76,6 @@ func tryAuth(t *testing.T, config *ClientConfig) error { | |
} | |
return nil, errors.New("keyboard-interactive failed") | |
}, | |
- AuthLogCallback: func(conn ConnMetadata, method string, err error) { | |
- }, | |
} | |
serverConfig.AddHostKey(testSigners["rsa"]) | |
@@ -92,6 +90,7 @@ func TestClientAuthPublicKey(t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("unable to dial remote side: %s", err) | |
@@ -104,6 +103,7 @@ func TestAuthMethodPassword(t *testing.T) { | |
Auth: []AuthMethod{ | |
Password(clientPassword), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -123,6 +123,7 @@ func TestAuthMethodFallback(t *testing.T) { | |
return "WRONG", nil | |
}), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -141,6 +142,7 @@ func TestAuthMethodWrongPassword(t *testing.T) { | |
Password("wrong"), | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -158,6 +160,7 @@ func TestAuthMethodKeyboardInteractive(t *testing.T) { | |
Auth: []AuthMethod{ | |
KeyboardInteractive(answers.Challenge), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -203,6 +206,7 @@ func TestAuthMethodRSAandDSA(t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["dsa"], testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("client could not authenticate with rsa key: %v", err) | |
@@ -219,6 +223,7 @@ func TestClientHMAC(t *testing.T) { | |
Config: Config{ | |
MACs: []string{mac}, | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
t.Fatalf("client could not authenticate with mac algo %s: %v", mac, err) | |
@@ -254,6 +259,7 @@ func TestClientUnsupportedKex(t *testing.T) { | |
Config: Config{ | |
KeyExchanges: []string{"diffie-hellman-group-exchange-sha256"}, // not currently supported | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err == nil || !strings.Contains(err.Error(), "common algorithm") { | |
t.Errorf("got %v, expected 'common algorithm'", err) | |
@@ -273,7 +279,8 @@ func TestClientLoginCert(t *testing.T) { | |
} | |
clientConfig := &ClientConfig{ | |
- User: "user", | |
+ User: "user", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
clientConfig.Auth = append(clientConfig.Auth, PublicKeys(certSigner)) | |
@@ -363,6 +370,7 @@ func testPermissionsPassing(withPermissions bool, t *testing.T) { | |
Auth: []AuthMethod{ | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if withPermissions { | |
clientConfig.User = "permissions" | |
@@ -409,6 +417,7 @@ func TestRetryableAuth(t *testing.T) { | |
}), 2), | |
PublicKeys(testSigners["rsa"]), | |
}, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
if err := tryAuth(t, config); err != nil { | |
@@ -430,7 +439,8 @@ func ExampleRetryableAuthMethod(t *testing.T) { | |
} | |
config := &ClientConfig{ | |
- User: user, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ User: user, | |
Auth: []AuthMethod{ | |
RetryableAuthMethod(KeyboardInteractiveChallenge(Cb), NumberOfPrompts), | |
}, | |
@@ -450,7 +460,8 @@ func TestClientAuthNone(t *testing.T) { | |
serverConfig.AddHostKey(testSigners["rsa"]) | |
clientConfig := &ClientConfig{ | |
- User: user, | |
+ User: user, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
c1, c2, err := netPipe() | |
@@ -469,3 +480,100 @@ func TestClientAuthNone(t *testing.T) { | |
t.Fatalf("server: got %q, want %q", serverConn.User(), user) | |
} | |
} | |
+ | |
+// Test if authentication attempts are limited on server when MaxAuthTries is set | |
+func TestClientAuthMaxAuthTries(t *testing.T) { | |
+ user := "testuser" | |
+ | |
+ serverConfig := &ServerConfig{ | |
+ MaxAuthTries: 2, | |
+ PasswordCallback: func(conn ConnMetadata, pass []byte) (*Permissions, error) { | |
+ if conn.User() == "testuser" && string(pass) == "right" { | |
+ return nil, nil | |
+ } | |
+ return nil, errors.New("password auth failed") | |
+ }, | |
+ } | |
+ serverConfig.AddHostKey(testSigners["rsa"]) | |
+ | |
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ }) | |
+ | |
+ for tries := 2; tries < 4; tries++ { | |
+ n := tries | |
+ clientConfig := &ClientConfig{ | |
+ User: user, | |
+ Auth: []AuthMethod{ | |
+ RetryableAuthMethod(PasswordCallback(func() (string, error) { | |
+ n-- | |
+ if n == 0 { | |
+ return "right", nil | |
+ } else { | |
+ return "wrong", nil | |
+ } | |
+ }), tries), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ | |
+ c1, c2, err := netPipe() | |
+ if err != nil { | |
+ t.Fatalf("netPipe: %v", err) | |
+ } | |
+ defer c1.Close() | |
+ defer c2.Close() | |
+ | |
+ go newServer(c1, serverConfig) | |
+ _, _, _, err = NewClientConn(c2, "", clientConfig) | |
+ if tries > 2 { | |
+ if err == nil { | |
+ t.Fatalf("client: got no error, want %s", expectedErr) | |
+ } else if err.Error() != expectedErr.Error() { | |
+ t.Fatalf("client: got %s, want %s", err, expectedErr) | |
+ } | |
+ } else { | |
+ if err != nil { | |
+ t.Fatalf("client: got %s, want no error", err) | |
+ } | |
+ } | |
+ } | |
+} | |
+ | |
+// Test if authentication attempts are correctly limited on server | |
+// when more public keys are provided than MaxAuthTries | |
+func TestClientAuthMaxAuthTriesPublicKey(t *testing.T) { | |
+ signers := []Signer{} | |
+ for i := 0; i < 6; i++ { | |
+ signers = append(signers, testSigners["dsa"]) | |
+ } | |
+ | |
+ validConfig := &ClientConfig{ | |
+ User: "testuser", | |
+ Auth: []AuthMethod{ | |
+ PublicKeys(append([]Signer{testSigners["rsa"]}, signers...)...), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ if err := tryAuth(t, validConfig); err != nil { | |
+ t.Fatalf("unable to dial remote side: %s", err) | |
+ } | |
+ | |
+ expectedErr := fmt.Errorf("ssh: handshake failed: %v", &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ }) | |
+ invalidConfig := &ClientConfig{ | |
+ User: "testuser", | |
+ Auth: []AuthMethod{ | |
+ PublicKeys(append(signers, testSigners["rsa"])...), | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ if err := tryAuth(t, invalidConfig); err == nil { | |
+ t.Fatalf("client: got no error, want %s", expectedErr) | |
+ } else if err.Error() != expectedErr.Error() { | |
+ t.Fatalf("client: got %s, want %s", err, expectedErr) | |
+ } | |
+} | |
diff --git a/ssh/client_test.go b/ssh/client_test.go | |
index 1fe790c..ccf5607 100644 | |
--- a/ssh/client_test.go | |
+++ b/ssh/client_test.go | |
@@ -6,6 +6,7 @@ package ssh | |
import ( | |
"net" | |
+ "strings" | |
"testing" | |
) | |
@@ -13,6 +14,7 @@ func testClientVersion(t *testing.T, config *ClientConfig, expected string) { | |
clientConn, serverConn := net.Pipe() | |
defer clientConn.Close() | |
receivedVersion := make(chan string, 1) | |
+ config.HostKeyCallback = InsecureIgnoreHostKey() | |
go func() { | |
version, err := readVersion(serverConn) | |
if err != nil { | |
@@ -37,3 +39,43 @@ func TestCustomClientVersion(t *testing.T) { | |
func TestDefaultClientVersion(t *testing.T) { | |
testClientVersion(t, &ClientConfig{}, packageVersion) | |
} | |
+ | |
+func TestHostKeyCheck(t *testing.T) { | |
+ for _, tt := range []struct { | |
+ name string | |
+ wantError string | |
+ key PublicKey | |
+ }{ | |
+ {"no callback", "must specify HostKeyCallback", nil}, | |
+ {"correct key", "", testSigners["rsa"].PublicKey()}, | |
+ {"mismatch", "mismatch", testSigners["ecdsa"].PublicKey()}, | |
+ } { | |
+ c1, c2, err := netPipe() | |
+ if err != nil { | |
+ t.Fatalf("netPipe: %v", err) | |
+ } | |
+ defer c1.Close() | |
+ defer c2.Close() | |
+ serverConf := &ServerConfig{ | |
+ NoClientAuth: true, | |
+ } | |
+ serverConf.AddHostKey(testSigners["rsa"]) | |
+ | |
+ go NewServerConn(c1, serverConf) | |
+ clientConf := ClientConfig{ | |
+ User: "user", | |
+ } | |
+ if tt.key != nil { | |
+ clientConf.HostKeyCallback = FixedHostKey(tt.key) | |
+ } | |
+ | |
+ _, _, _, err = NewClientConn(c2, "", &clientConf) | |
+ if err != nil { | |
+ if tt.wantError == "" || !strings.Contains(err.Error(), tt.wantError) { | |
+ t.Errorf("%s: got error %q, missing %q", tt.name, err.Error(), tt.wantError) | |
+ } | |
+ } else if tt.wantError != "" { | |
+ t.Errorf("%s: succeeded, but want error string %q", tt.name, tt.wantError) | |
+ } | |
+ } | |
+} | |
diff --git a/ssh/common.go b/ssh/common.go | |
index 8656d0f..dc39e4d 100644 | |
--- a/ssh/common.go | |
+++ b/ssh/common.go | |
@@ -9,6 +9,7 @@ import ( | |
"crypto/rand" | |
"fmt" | |
"io" | |
+ "math" | |
"sync" | |
_ "crypto/sha1" | |
@@ -40,7 +41,7 @@ var supportedKexAlgos = []string{ | |
kexAlgoDH14SHA1, kexAlgoDH1SHA1, | |
} | |
-// supportedKexAlgos specifies the supported host-key algorithms (i.e. methods | |
+// supportedHostKeyAlgos specifies the supported host-key algorithms (i.e. methods | |
// of authenticating servers) in preference order. | |
var supportedHostKeyAlgos = []string{ | |
CertAlgoRSAv01, CertAlgoDSAv01, CertAlgoECDSA256v01, | |
@@ -186,7 +187,7 @@ type Config struct { | |
// The maximum number of bytes sent or received after which a | |
// new key is negotiated. It must be at least 256. If | |
- // unspecified, 1 gigabyte is used. | |
+ // unspecified, a size suitable for the chosen cipher is used. | |
RekeyThreshold uint64 | |
// The allowed key exchanges algorithms. If unspecified then a | |
@@ -230,11 +231,12 @@ func (c *Config) SetDefaults() { | |
} | |
if c.RekeyThreshold == 0 { | |
- // RFC 4253, section 9 suggests rekeying after 1G. | |
- c.RekeyThreshold = 1 << 30 | |
- } | |
- if c.RekeyThreshold < minRekeyThreshold { | |
+ // cipher specific default | |
+ } else if c.RekeyThreshold < minRekeyThreshold { | |
c.RekeyThreshold = minRekeyThreshold | |
+ } else if c.RekeyThreshold >= math.MaxInt64 { | |
+ // Avoid weirdness if somebody uses -1 as a threshold. | |
+ c.RekeyThreshold = math.MaxInt64 | |
} | |
} | |
diff --git a/ssh/connection.go b/ssh/connection.go | |
index e786f2f..fd6b068 100644 | |
--- a/ssh/connection.go | |
+++ b/ssh/connection.go | |
@@ -25,7 +25,7 @@ type ConnMetadata interface { | |
// User returns the user ID for this connection. | |
User() string | |
- // SessionID returns the sesson hash, also denoted by H. | |
+ // SessionID returns the session hash, also denoted by H. | |
SessionID() []byte | |
// ClientVersion returns the client's version string as hashed | |
diff --git a/ssh/doc.go b/ssh/doc.go | |
index d6be894..67b7322 100644 | |
--- a/ssh/doc.go | |
+++ b/ssh/doc.go | |
@@ -14,5 +14,8 @@ others. | |
References: | |
[PROTOCOL.certkeys]: http://cvsweb.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.certkeys?rev=HEAD | |
[SSH-PARAMETERS]: http://www.iana.org/assignments/ssh-parameters/ssh-parameters.xml#ssh-parameters-1 | |
+ | |
+This package does not fall under the stability promise of the Go language itself, | |
+so its API may be changed when pressing needs arise. | |
*/ | |
package ssh // import "golang.org/x/crypto/ssh" | |
diff --git a/ssh/example_test.go b/ssh/example_test.go | |
index 4d2eabd..618398c 100644 | |
--- a/ssh/example_test.go | |
+++ b/ssh/example_test.go | |
@@ -5,12 +5,16 @@ | |
package ssh_test | |
import ( | |
+ "bufio" | |
"bytes" | |
"fmt" | |
"io/ioutil" | |
"log" | |
"net" | |
"net/http" | |
+ "os" | |
+ "path/filepath" | |
+ "strings" | |
"golang.org/x/crypto/ssh" | |
"golang.org/x/crypto/ssh/terminal" | |
@@ -91,8 +95,6 @@ func ExampleNewServerConn() { | |
go ssh.DiscardRequests(reqs) | |
// Service the incoming Channel channel. | |
- | |
- // Service the incoming Channel channel. | |
for newChannel := range chans { | |
// Channels have a type, depending on the application level | |
// protocol intended. In the case of a shell, the type is | |
@@ -131,16 +133,59 @@ func ExampleNewServerConn() { | |
} | |
} | |
+func ExampleHostKeyCheck() { | |
+ // Every client must provide a host key check. Here is a | |
+ // simple-minded parse of OpenSSH's known_hosts file | |
+ host := "hostname" | |
+ file, err := os.Open(filepath.Join(os.Getenv("HOME"), ".ssh", "known_hosts")) | |
+ if err != nil { | |
+ log.Fatal(err) | |
+ } | |
+ defer file.Close() | |
+ | |
+ scanner := bufio.NewScanner(file) | |
+ var hostKey ssh.PublicKey | |
+ for scanner.Scan() { | |
+ fields := strings.Split(scanner.Text(), " ") | |
+ if len(fields) != 3 { | |
+ continue | |
+ } | |
+ if strings.Contains(fields[0], host) { | |
+ var err error | |
+ hostKey, _, _, _, err = ssh.ParseAuthorizedKey(scanner.Bytes()) | |
+ if err != nil { | |
+ log.Fatalf("error parsing %q: %v", fields[2], err) | |
+ } | |
+ break | |
+ } | |
+ } | |
+ | |
+ if hostKey == nil { | |
+ log.Fatalf("no hostkey for %s", host) | |
+ } | |
+ | |
+ config := ssh.ClientConfig{ | |
+ User: os.Getenv("USER"), | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
+ } | |
+ | |
+ _, err = ssh.Dial("tcp", host+":22", &config) | |
+ log.Println(err) | |
+} | |
+ | |
func ExampleDial() { | |
+ var hostKey ssh.PublicKey | |
// An SSH client is represented with a ClientConn. | |
// | |
// To authenticate with the remote server you must pass at least one | |
- // implementation of AuthMethod via the Auth field in ClientConfig. | |
+ // implementation of AuthMethod via the Auth field in ClientConfig, | |
+ // and provide a HostKeyCallback. | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("yourpassword"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
client, err := ssh.Dial("tcp", "yourserver.com:22", config) | |
if err != nil { | |
@@ -166,6 +211,7 @@ func ExampleDial() { | |
} | |
func ExamplePublicKeys() { | |
+ var hostKey ssh.PublicKey | |
// A public key may be used to authenticate against the remote | |
// server by using an unencrypted PEM-encoded private key file. | |
// | |
@@ -188,6 +234,7 @@ func ExamplePublicKeys() { | |
// Use the PublicKeys method for remote authentication. | |
ssh.PublicKeys(signer), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Connect to the remote server and perform the SSH handshake. | |
@@ -199,11 +246,13 @@ func ExamplePublicKeys() { | |
} | |
func ExampleClient_Listen() { | |
+ var hostKey ssh.PublicKey | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("password"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Dial your ssh server. | |
conn, err := ssh.Dial("tcp", "localhost:22", config) | |
@@ -226,12 +275,14 @@ func ExampleClient_Listen() { | |
} | |
func ExampleSession_RequestPty() { | |
+ var hostKey ssh.PublicKey | |
// Create client config | |
config := &ssh.ClientConfig{ | |
User: "username", | |
Auth: []ssh.AuthMethod{ | |
ssh.Password("password"), | |
}, | |
+ HostKeyCallback: ssh.FixedHostKey(hostKey), | |
} | |
// Connect to ssh server | |
conn, err := ssh.Dial("tcp", "localhost:22", config) | |
diff --git a/ssh/handshake.go b/ssh/handshake.go | |
index 8de6506..932ce83 100644 | |
--- a/ssh/handshake.go | |
+++ b/ssh/handshake.go | |
@@ -74,7 +74,7 @@ type handshakeTransport struct { | |
startKex chan *pendingKex | |
// data for host key checking | |
- hostKeyCallback func(hostname string, remote net.Addr, key PublicKey) error | |
+ hostKeyCallback HostKeyCallback | |
dialAddress string | |
remoteAddr net.Addr | |
@@ -107,6 +107,8 @@ func newHandshakeTransport(conn keyingTransport, config *Config, clientVersion, | |
config: config, | |
} | |
+ t.resetReadThresholds() | |
+ t.resetWriteThresholds() | |
// We always start with a mandatory key exchange. | |
t.requestKex <- struct{}{} | |
@@ -237,6 +239,17 @@ func (t *handshakeTransport) requestKeyExchange() { | |
} | |
} | |
+func (t *handshakeTransport) resetWriteThresholds() { | |
+ t.writePacketsLeft = packetRekeyThreshold | |
+ if t.config.RekeyThreshold > 0 { | |
+ t.writeBytesLeft = int64(t.config.RekeyThreshold) | |
+ } else if t.algorithms != nil { | |
+ t.writeBytesLeft = t.algorithms.w.rekeyBytes() | |
+ } else { | |
+ t.writeBytesLeft = 1 << 30 | |
+ } | |
+} | |
+ | |
func (t *handshakeTransport) kexLoop() { | |
write: | |
@@ -285,12 +298,8 @@ write: | |
t.writeError = err | |
t.sentInitPacket = nil | |
t.sentInitMsg = nil | |
- t.writePacketsLeft = packetRekeyThreshold | |
- if t.config.RekeyThreshold > 0 { | |
- t.writeBytesLeft = int64(t.config.RekeyThreshold) | |
- } else if t.algorithms != nil { | |
- t.writeBytesLeft = t.algorithms.w.rekeyBytes() | |
- } | |
+ | |
+ t.resetWriteThresholds() | |
// we have completed the key exchange. Since the | |
// reader is still blocked, it is safe to clear out | |
@@ -344,6 +353,17 @@ write: | |
// key exchange itself. | |
const packetRekeyThreshold = (1 << 31) | |
+func (t *handshakeTransport) resetReadThresholds() { | |
+ t.readPacketsLeft = packetRekeyThreshold | |
+ if t.config.RekeyThreshold > 0 { | |
+ t.readBytesLeft = int64(t.config.RekeyThreshold) | |
+ } else if t.algorithms != nil { | |
+ t.readBytesLeft = t.algorithms.r.rekeyBytes() | |
+ } else { | |
+ t.readBytesLeft = 1 << 30 | |
+ } | |
+} | |
+ | |
func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { | |
p, err := t.conn.readPacket() | |
if err != nil { | |
@@ -391,12 +411,7 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) { | |
return nil, err | |
} | |
- t.readPacketsLeft = packetRekeyThreshold | |
- if t.config.RekeyThreshold > 0 { | |
- t.readBytesLeft = int64(t.config.RekeyThreshold) | |
- } else { | |
- t.readBytesLeft = t.algorithms.r.rekeyBytes() | |
- } | |
+ t.resetReadThresholds() | |
// By default, a key exchange is hidden from higher layers by | |
// translating it into msgIgnore. | |
@@ -574,7 +589,9 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error { | |
} | |
result.SessionID = t.sessionID | |
- t.conn.prepareKeyChange(t.algorithms, result) | |
+ if err := t.conn.prepareKeyChange(t.algorithms, result); err != nil { | |
+ return err | |
+ } | |
if err = t.conn.writePacket([]byte{msgNewKeys}); err != nil { | |
return err | |
} | |
@@ -614,11 +631,9 @@ func (t *handshakeTransport) client(kex kexAlgorithm, algs *algorithms, magics * | |
return nil, err | |
} | |
- if t.hostKeyCallback != nil { | |
- err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) | |
- if err != nil { | |
- return nil, err | |
- } | |
+ err = t.hostKeyCallback(t.dialAddress, t.remoteAddr, hostKey) | |
+ if err != nil { | |
+ return nil, err | |
} | |
return result, nil | |
diff --git a/ssh/handshake_test.go b/ssh/handshake_test.go | |
index 1b83112..91d4935 100644 | |
--- a/ssh/handshake_test.go | |
+++ b/ssh/handshake_test.go | |
@@ -40,9 +40,12 @@ func (t *testChecker) Check(dialAddr string, addr net.Addr, key PublicKey) error | |
// therefore is buffered (net.Pipe deadlocks if both sides start with | |
// a write.) | |
func netPipe() (net.Conn, net.Conn, error) { | |
- listener, err := net.Listen("tcp", ":0") | |
+ listener, err := net.Listen("tcp", "127.0.0.1:0") | |
if err != nil { | |
- return nil, nil, err | |
+ listener, err = net.Listen("tcp", "[::1]:0") | |
+ if err != nil { | |
+ return nil, nil, err | |
+ } | |
} | |
defer listener.Close() | |
c1, err := net.Dial("tcp", listener.Addr().String()) | |
@@ -436,6 +439,7 @@ func testHandshakeErrorHandlingN(t *testing.T, readLimit, writeLimit int, couple | |
clientConf.SetDefaults() | |
clientConn := newHandshakeTransport(&errorKeyingTransport{b, -1, -1}, &clientConf, []byte{'a'}, []byte{'b'}) | |
clientConn.hostKeyAlgorithms = []string{key.PublicKey().Type()} | |
+ clientConn.hostKeyCallback = InsecureIgnoreHostKey() | |
go clientConn.readLoop() | |
go clientConn.kexLoop() | |
@@ -525,3 +529,31 @@ func TestDisconnect(t *testing.T) { | |
t.Errorf("readPacket 3 succeeded") | |
} | |
} | |
+ | |
+func TestHandshakeRekeyDefault(t *testing.T) { | |
+ clientConf := &ClientConfig{ | |
+ Config: Config{ | |
+ Ciphers: []string{"aes128-ctr"}, | |
+ }, | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
+ trC, trS, err := handshakePair(clientConf, "addr", false) | |
+ if err != nil { | |
+ t.Fatalf("handshakePair: %v", err) | |
+ } | |
+ defer trC.Close() | |
+ defer trS.Close() | |
+ | |
+ trC.writePacket([]byte{msgRequestSuccess, 0, 0}) | |
+ trC.Close() | |
+ | |
+ rgb := (1024 + trC.readBytesLeft) >> 30 | |
+ wgb := (1024 + trC.writeBytesLeft) >> 30 | |
+ | |
+ if rgb != 64 { | |
+ t.Errorf("got rekey after %dG read, want 64G", rgb) | |
+ } | |
+ if wgb != 64 { | |
+ t.Errorf("got rekey after %dG write, want 64G", wgb) | |
+ } | |
+} | |
diff --git a/ssh/keys.go b/ssh/keys.go | |
index f38de98..cf68532 100644 | |
--- a/ssh/keys.go | |
+++ b/ssh/keys.go | |
@@ -824,7 +824,7 @@ func ParseDSAPrivateKey(der []byte) (*dsa.PrivateKey, error) { | |
// Implemented based on the documentation at | |
// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL.key | |
-func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
+func parseOpenSSHPrivateKey(key []byte) (crypto.PrivateKey, error) { | |
magic := append([]byte("openssh-key-v1"), 0) | |
if !bytes.Equal(magic, key[0:len(magic)]) { | |
return nil, errors.New("ssh: invalid openssh private key format") | |
@@ -844,14 +844,15 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
return nil, err | |
} | |
+ if w.KdfName != "none" || w.CipherName != "none" { | |
+ return nil, errors.New("ssh: cannot decode encrypted private keys") | |
+ } | |
+ | |
pk1 := struct { | |
Check1 uint32 | |
Check2 uint32 | |
Keytype string | |
- Pub []byte | |
- Priv []byte | |
- Comment string | |
- Pad []byte `ssh:"rest"` | |
+ Rest []byte `ssh:"rest"` | |
}{} | |
if err := Unmarshal(w.PrivKeyBlock, &pk1); err != nil { | |
@@ -862,24 +863,75 @@ func parseOpenSSHPrivateKey(key []byte) (*ed25519.PrivateKey, error) { | |
return nil, errors.New("ssh: checkint mismatch") | |
} | |
- // we only handle ed25519 keys currently | |
- if pk1.Keytype != KeyAlgoED25519 { | |
- return nil, errors.New("ssh: unhandled key type") | |
- } | |
+ // we only handle ed25519 and rsa keys currently | |
+ switch pk1.Keytype { | |
+ case KeyAlgoRSA: | |
+ // https://github.com/openssh/openssh-portable/blob/master/sshkey.c#L2760-L2773 | |
+ key := struct { | |
+ N *big.Int | |
+ E *big.Int | |
+ D *big.Int | |
+ Iqmp *big.Int | |
+ P *big.Int | |
+ Q *big.Int | |
+ Comment string | |
+ Pad []byte `ssh:"rest"` | |
+ }{} | |
+ | |
+ if err := Unmarshal(pk1.Rest, &key); err != nil { | |
+ return nil, err | |
+ } | |
- for i, b := range pk1.Pad { | |
- if int(b) != i+1 { | |
- return nil, errors.New("ssh: padding not as expected") | |
+ for i, b := range key.Pad { | |
+ if int(b) != i+1 { | |
+ return nil, errors.New("ssh: padding not as expected") | |
+ } | |
} | |
- } | |
- if len(pk1.Priv) != ed25519.PrivateKeySize { | |
- return nil, errors.New("ssh: private key unexpected length") | |
- } | |
+ pk := &rsa.PrivateKey{ | |
+ PublicKey: rsa.PublicKey{ | |
+ N: key.N, | |
+ E: int(key.E.Int64()), | |
+ }, | |
+ D: key.D, | |
+ Primes: []*big.Int{key.P, key.Q}, | |
+ } | |
- pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) | |
- copy(pk, pk1.Priv) | |
- return &pk, nil | |
+ if err := pk.Validate(); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ pk.Precompute() | |
+ | |
+ return pk, nil | |
+ case KeyAlgoED25519: | |
+ key := struct { | |
+ Pub []byte | |
+ Priv []byte | |
+ Comment string | |
+ Pad []byte `ssh:"rest"` | |
+ }{} | |
+ | |
+ if err := Unmarshal(pk1.Rest, &key); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ if len(key.Priv) != ed25519.PrivateKeySize { | |
+ return nil, errors.New("ssh: private key unexpected length") | |
+ } | |
+ | |
+ for i, b := range key.Pad { | |
+ if int(b) != i+1 { | |
+ return nil, errors.New("ssh: padding not as expected") | |
+ } | |
+ } | |
+ | |
+ pk := ed25519.PrivateKey(make([]byte, ed25519.PrivateKeySize)) | |
+ copy(pk, key.Priv) | |
+ return &pk, nil | |
+ default: | |
+ return nil, errors.New("ssh: unhandled key type") | |
+ } | |
} | |
// FingerprintLegacyMD5 returns the user presentation of the key's | |
diff --git a/ssh/knownhosts/knownhosts.go b/ssh/knownhosts/knownhosts.go | |
new file mode 100644 | |
index 0000000..ea92b29 | |
--- /dev/null | |
+++ b/ssh/knownhosts/knownhosts.go | |
@@ -0,0 +1,546 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+// Package knownhosts implements a parser for the OpenSSH | |
+// known_hosts host key database. | |
+package knownhosts | |
+ | |
+import ( | |
+ "bufio" | |
+ "bytes" | |
+ "crypto/hmac" | |
+ "crypto/rand" | |
+ "crypto/sha1" | |
+ "encoding/base64" | |
+ "errors" | |
+ "fmt" | |
+ "io" | |
+ "net" | |
+ "os" | |
+ "strings" | |
+ | |
+ "golang.org/x/crypto/ssh" | |
+) | |
+ | |
+// See the sshd manpage | |
+// (http://man.openbsd.org/sshd#SSH_KNOWN_HOSTS_FILE_FORMAT) for | |
+// background. | |
+ | |
+type addr struct{ host, port string } | |
+ | |
+func (a *addr) String() string { | |
+ h := a.host | |
+ if strings.Contains(h, ":") { | |
+ h = "[" + h + "]" | |
+ } | |
+ return h + ":" + a.port | |
+} | |
+ | |
+type matcher interface { | |
+ match([]addr) bool | |
+} | |
+ | |
+type hostPattern struct { | |
+ negate bool | |
+ addr addr | |
+} | |
+ | |
+func (p *hostPattern) String() string { | |
+ n := "" | |
+ if p.negate { | |
+ n = "!" | |
+ } | |
+ | |
+ return n + p.addr.String() | |
+} | |
+ | |
+type hostPatterns []hostPattern | |
+ | |
+func (ps hostPatterns) match(addrs []addr) bool { | |
+ matched := false | |
+ for _, p := range ps { | |
+ for _, a := range addrs { | |
+ m := p.match(a) | |
+ if !m { | |
+ continue | |
+ } | |
+ if p.negate { | |
+ return false | |
+ } | |
+ matched = true | |
+ } | |
+ } | |
+ return matched | |
+} | |
+ | |
+// See | |
+// https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/addrmatch.c | |
+// The matching of * has no regard for separators, unlike filesystem globs | |
+func wildcardMatch(pat []byte, str []byte) bool { | |
+ for { | |
+ if len(pat) == 0 { | |
+ return len(str) == 0 | |
+ } | |
+ if len(str) == 0 { | |
+ return false | |
+ } | |
+ | |
+ if pat[0] == '*' { | |
+ if len(pat) == 1 { | |
+ return true | |
+ } | |
+ | |
+ for j := range str { | |
+ if wildcardMatch(pat[1:], str[j:]) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+ } | |
+ | |
+ if pat[0] == '?' || pat[0] == str[0] { | |
+ pat = pat[1:] | |
+ str = str[1:] | |
+ } else { | |
+ return false | |
+ } | |
+ } | |
+} | |
+ | |
+func (l *hostPattern) match(a addr) bool { | |
+ return wildcardMatch([]byte(l.addr.host), []byte(a.host)) && l.addr.port == a.port | |
+} | |
+ | |
+type keyDBLine struct { | |
+ cert bool | |
+ matcher matcher | |
+ knownKey KnownKey | |
+} | |
+ | |
+func serialize(k ssh.PublicKey) string { | |
+ return k.Type() + " " + base64.StdEncoding.EncodeToString(k.Marshal()) | |
+} | |
+ | |
+func (l *keyDBLine) match(addrs []addr) bool { | |
+ return l.matcher.match(addrs) | |
+} | |
+ | |
+type hostKeyDB struct { | |
+ // Serialized version of revoked keys | |
+ revoked map[string]*KnownKey | |
+ lines []keyDBLine | |
+} | |
+ | |
+func newHostKeyDB() *hostKeyDB { | |
+ db := &hostKeyDB{ | |
+ revoked: make(map[string]*KnownKey), | |
+ } | |
+ | |
+ return db | |
+} | |
+ | |
+func keyEq(a, b ssh.PublicKey) bool { | |
+ return bytes.Equal(a.Marshal(), b.Marshal()) | |
+} | |
+ | |
+// IsHostAuthority can be used as a callback in ssh.CertChecker | |
+func (db *hostKeyDB) IsHostAuthority(remote ssh.PublicKey, address string) bool { | |
+ h, p, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ return false | |
+ } | |
+ a := addr{host: h, port: p} | |
+ | |
+ for _, l := range db.lines { | |
+ if l.cert && keyEq(l.knownKey.Key, remote) && l.match([]addr{a}) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+} | |
+ | |
+// IsRevoked can be used as a callback in ssh.CertChecker | |
+func (db *hostKeyDB) IsRevoked(key *ssh.Certificate) bool { | |
+ _, ok := db.revoked[string(key.Marshal())] | |
+ return ok | |
+} | |
+ | |
+const markerCert = "@cert-authority" | |
+const markerRevoked = "@revoked" | |
+ | |
+func nextWord(line []byte) (string, []byte) { | |
+ i := bytes.IndexAny(line, "\t ") | |
+ if i == -1 { | |
+ return string(line), nil | |
+ } | |
+ | |
+ return string(line[:i]), bytes.TrimSpace(line[i:]) | |
+} | |
+ | |
+func parseLine(line []byte) (marker, host string, key ssh.PublicKey, err error) { | |
+ if w, next := nextWord(line); w == markerCert || w == markerRevoked { | |
+ marker = w | |
+ line = next | |
+ } | |
+ | |
+ host, line = nextWord(line) | |
+ if len(line) == 0 { | |
+ return "", "", nil, errors.New("knownhosts: missing host pattern") | |
+ } | |
+ | |
+ // ignore the keytype as it's in the key blob anyway. | |
+ _, line = nextWord(line) | |
+ if len(line) == 0 { | |
+ return "", "", nil, errors.New("knownhosts: missing key type pattern") | |
+ } | |
+ | |
+ keyBlob, _ := nextWord(line) | |
+ | |
+ keyBytes, err := base64.StdEncoding.DecodeString(keyBlob) | |
+ if err != nil { | |
+ return "", "", nil, err | |
+ } | |
+ key, err = ssh.ParsePublicKey(keyBytes) | |
+ if err != nil { | |
+ return "", "", nil, err | |
+ } | |
+ | |
+ return marker, host, key, nil | |
+} | |
+ | |
+func (db *hostKeyDB) parseLine(line []byte, filename string, linenum int) error { | |
+ marker, pattern, key, err := parseLine(line) | |
+ if err != nil { | |
+ return err | |
+ } | |
+ | |
+ if marker == markerRevoked { | |
+ db.revoked[string(key.Marshal())] = &KnownKey{ | |
+ Key: key, | |
+ Filename: filename, | |
+ Line: linenum, | |
+ } | |
+ | |
+ return nil | |
+ } | |
+ | |
+ entry := keyDBLine{ | |
+ cert: marker == markerCert, | |
+ knownKey: KnownKey{ | |
+ Filename: filename, | |
+ Line: linenum, | |
+ Key: key, | |
+ }, | |
+ } | |
+ | |
+ if pattern[0] == '|' { | |
+ entry.matcher, err = newHashedHost(pattern) | |
+ } else { | |
+ entry.matcher, err = newHostnameMatcher(pattern) | |
+ } | |
+ | |
+ if err != nil { | |
+ return err | |
+ } | |
+ | |
+ db.lines = append(db.lines, entry) | |
+ return nil | |
+} | |
+ | |
+func newHostnameMatcher(pattern string) (matcher, error) { | |
+ var hps hostPatterns | |
+ for _, p := range strings.Split(pattern, ",") { | |
+ if len(p) == 0 { | |
+ continue | |
+ } | |
+ | |
+ var a addr | |
+ var negate bool | |
+ if p[0] == '!' { | |
+ negate = true | |
+ p = p[1:] | |
+ } | |
+ | |
+ if len(p) == 0 { | |
+ return nil, errors.New("knownhosts: negation without following hostname") | |
+ } | |
+ | |
+ var err error | |
+ if p[0] == '[' { | |
+ a.host, a.port, err = net.SplitHostPort(p) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ } else { | |
+ a.host, a.port, err = net.SplitHostPort(p) | |
+ if err != nil { | |
+ a.host = p | |
+ a.port = "22" | |
+ } | |
+ } | |
+ hps = append(hps, hostPattern{ | |
+ negate: negate, | |
+ addr: a, | |
+ }) | |
+ } | |
+ return hps, nil | |
+} | |
+ | |
+// KnownKey represents a key declared in a known_hosts file. | |
+type KnownKey struct { | |
+ Key ssh.PublicKey | |
+ Filename string | |
+ Line int | |
+} | |
+ | |
+func (k *KnownKey) String() string { | |
+ return fmt.Sprintf("%s:%d: %s", k.Filename, k.Line, serialize(k.Key)) | |
+} | |
+ | |
+// KeyError is returned if we did not find the key in the host key | |
+// database, or there was a mismatch. Typically, in batch | |
+// applications, this should be interpreted as failure. Interactive | |
+// applications can offer an interactive prompt to the user. | |
+type KeyError struct { | |
+ // Want holds the accepted host keys. For each key algorithm, | |
+ // there can be one hostkey. If Want is empty, the host is | |
+ // unknown. If Want is non-empty, there was a mismatch, which | |
+ // can signify a MITM attack. | |
+ Want []KnownKey | |
+} | |
+ | |
+func (u *KeyError) Error() string { | |
+ if len(u.Want) == 0 { | |
+ return "knownhosts: key is unknown" | |
+ } | |
+ return "knownhosts: key mismatch" | |
+} | |
+ | |
+// RevokedError is returned if we found a key that was revoked. | |
+type RevokedError struct { | |
+ Revoked KnownKey | |
+} | |
+ | |
+func (r *RevokedError) Error() string { | |
+ return "knownhosts: key is revoked" | |
+} | |
+ | |
+// check checks a key against the host database. This should not be | |
+// used for verifying certificates. | |
+func (db *hostKeyDB) check(address string, remote net.Addr, remoteKey ssh.PublicKey) error { | |
+ if revoked := db.revoked[string(remoteKey.Marshal())]; revoked != nil { | |
+ return &RevokedError{Revoked: *revoked} | |
+ } | |
+ | |
+ host, port, err := net.SplitHostPort(remote.String()) | |
+ if err != nil { | |
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", remote, err) | |
+ } | |
+ | |
+ addrs := []addr{ | |
+ {host, port}, | |
+ } | |
+ | |
+ if address != "" { | |
+ host, port, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ return fmt.Errorf("knownhosts: SplitHostPort(%s): %v", address, err) | |
+ } | |
+ | |
+ addrs = append(addrs, addr{host, port}) | |
+ } | |
+ | |
+ return db.checkAddrs(addrs, remoteKey) | |
+} | |
+ | |
+// checkAddrs checks if we can find the given public key for any of | |
+// the given addresses. If we only find an entry for the IP address, | |
+// or only the hostname, then this still succeeds. | |
+func (db *hostKeyDB) checkAddrs(addrs []addr, remoteKey ssh.PublicKey) error { | |
+ // TODO(hanwen): are these the right semantics? What if there | |
+ // is just a key for the IP address, but not for the | |
+ // hostname? | |
+ | |
+ // Algorithm => key. | |
+ knownKeys := map[string]KnownKey{} | |
+ for _, l := range db.lines { | |
+ if l.match(addrs) { | |
+ typ := l.knownKey.Key.Type() | |
+ if _, ok := knownKeys[typ]; !ok { | |
+ knownKeys[typ] = l.knownKey | |
+ } | |
+ } | |
+ } | |
+ | |
+ keyErr := &KeyError{} | |
+ for _, v := range knownKeys { | |
+ keyErr.Want = append(keyErr.Want, v) | |
+ } | |
+ | |
+ // Unknown remote host. | |
+ if len(knownKeys) == 0 { | |
+ return keyErr | |
+ } | |
+ | |
+ // If the remote host starts using a different, unknown key type, we | |
+ // also interpret that as a mismatch. | |
+ if known, ok := knownKeys[remoteKey.Type()]; !ok || !keyEq(known.Key, remoteKey) { | |
+ return keyErr | |
+ } | |
+ | |
+ return nil | |
+} | |
+ | |
+// The Read function parses file contents. | |
+func (db *hostKeyDB) Read(r io.Reader, filename string) error { | |
+ scanner := bufio.NewScanner(r) | |
+ | |
+ lineNum := 0 | |
+ for scanner.Scan() { | |
+ lineNum++ | |
+ line := scanner.Bytes() | |
+ line = bytes.TrimSpace(line) | |
+ if len(line) == 0 || line[0] == '#' { | |
+ continue | |
+ } | |
+ | |
+ if err := db.parseLine(line, filename, lineNum); err != nil { | |
+ return fmt.Errorf("knownhosts: %s:%d: %v", filename, lineNum, err) | |
+ } | |
+ } | |
+ return scanner.Err() | |
+} | |
+ | |
+// New creates a host key callback from the given OpenSSH host key | |
+// files. The returned callback is for use in | |
+// ssh.ClientConfig.HostKeyCallback. Hashed hostnames are supported. | |
+func New(files ...string) (ssh.HostKeyCallback, error) { | |
+ db := newHostKeyDB() | |
+ for _, fn := range files { | |
+ f, err := os.Open(fn) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ defer f.Close() | |
+ if err := db.Read(f, fn); err != nil { | |
+ return nil, err | |
+ } | |
+ } | |
+ | |
+ var certChecker ssh.CertChecker | |
+ certChecker.IsHostAuthority = db.IsHostAuthority | |
+ certChecker.IsRevoked = db.IsRevoked | |
+ certChecker.HostKeyFallback = db.check | |
+ | |
+ return certChecker.CheckHostKey, nil | |
+} | |
+ | |
+// Normalize normalizes an address into the form used in known_hosts | |
+func Normalize(address string) string { | |
+ host, port, err := net.SplitHostPort(address) | |
+ if err != nil { | |
+ host = address | |
+ port = "22" | |
+ } | |
+ entry := host | |
+ if port != "22" { | |
+ entry = "[" + entry + "]:" + port | |
+ } else if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") { | |
+ entry = "[" + entry + "]" | |
+ } | |
+ return entry | |
+} | |
+ | |
+// Line returns a line to append to the known_hosts files. | |
+func Line(addresses []string, key ssh.PublicKey) string { | |
+ var trimmed []string | |
+ for _, a := range addresses { | |
+ trimmed = append(trimmed, Normalize(a)) | |
+ } | |
+ | |
+ return strings.Join(trimmed, ",") + " " + serialize(key) | |
+} | |
+ | |
+// HashHostname hashes the given hostname. The hostname is not | |
+// normalized before hashing. | |
+func HashHostname(hostname string) string { | |
+ // TODO(hanwen): check if we can safely normalize this always. | |
+ salt := make([]byte, sha1.Size) | |
+ | |
+ _, err := rand.Read(salt) | |
+ if err != nil { | |
+ panic(fmt.Sprintf("crypto/rand failure %v", err)) | |
+ } | |
+ | |
+ hash := hashHost(hostname, salt) | |
+ return encodeHash(sha1HashType, salt, hash) | |
+} | |
+ | |
+func decodeHash(encoded string) (hashType string, salt, hash []byte, err error) { | |
+ if len(encoded) == 0 || encoded[0] != '|' { | |
+ err = errors.New("knownhosts: hashed host must start with '|'") | |
+ return | |
+ } | |
+ components := strings.Split(encoded, "|") | |
+ if len(components) != 4 { | |
+ err = fmt.Errorf("knownhosts: got %d components, want 3", len(components)) | |
+ return | |
+ } | |
+ | |
+ hashType = components[1] | |
+ if salt, err = base64.StdEncoding.DecodeString(components[2]); err != nil { | |
+ return | |
+ } | |
+ if hash, err = base64.StdEncoding.DecodeString(components[3]); err != nil { | |
+ return | |
+ } | |
+ return | |
+} | |
+ | |
+func encodeHash(typ string, salt []byte, hash []byte) string { | |
+ return strings.Join([]string{"", | |
+ typ, | |
+ base64.StdEncoding.EncodeToString(salt), | |
+ base64.StdEncoding.EncodeToString(hash), | |
+ }, "|") | |
+} | |
+ | |
+// See https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 | |
+func hashHost(hostname string, salt []byte) []byte { | |
+ mac := hmac.New(sha1.New, salt) | |
+ mac.Write([]byte(hostname)) | |
+ return mac.Sum(nil) | |
+} | |
+ | |
+type hashedHost struct { | |
+ salt []byte | |
+ hash []byte | |
+} | |
+ | |
+const sha1HashType = "1" | |
+ | |
+func newHashedHost(encoded string) (*hashedHost, error) { | |
+ typ, salt, hash, err := decodeHash(encoded) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ // The type field seems intended for future algorithm agility, but it's | |
+ // actually hardcoded in openssh currently, see | |
+ // https://android.googlesource.com/platform/external/openssh/+/ab28f5495c85297e7a597c1ba62e996416da7c7e/hostfile.c#120 | |
+ if typ != sha1HashType { | |
+ return nil, fmt.Errorf("knownhosts: got hash type %s, must be '1'", typ) | |
+ } | |
+ | |
+ return &hashedHost{salt: salt, hash: hash}, nil | |
+} | |
+ | |
+func (h *hashedHost) match(addrs []addr) bool { | |
+ for _, a := range addrs { | |
+ if bytes.Equal(hashHost(Normalize(a.String()), h.salt), h.hash) { | |
+ return true | |
+ } | |
+ } | |
+ return false | |
+} | |
diff --git a/ssh/knownhosts/knownhosts_test.go b/ssh/knownhosts/knownhosts_test.go | |
new file mode 100644 | |
index 0000000..be7cc0e | |
--- /dev/null | |
+++ b/ssh/knownhosts/knownhosts_test.go | |
@@ -0,0 +1,329 @@ | |
+// Copyright 2017 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+package knownhosts | |
+ | |
+import ( | |
+ "bytes" | |
+ "fmt" | |
+ "net" | |
+ "reflect" | |
+ "testing" | |
+ | |
+ "golang.org/x/crypto/ssh" | |
+) | |
+ | |
+const edKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIGBAarftlLeoyf+v+nVchEZII/vna2PCV8FaX4vsF5BX" | |
+const alternateEdKeyStr = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIIXffBYeYL+WVzVru8npl5JHt2cjlr4ornFTWzoij9sx" | |
+const ecKeyStr = "ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBNLCu01+wpXe3xB5olXCN4SqU2rQu0qjSRKJO4Bg+JRCPU+ENcgdA5srTU8xYDz/GEa4dzK5ldPw4J/gZgSXCMs=" | |
+ | |
+var ecKey, alternateEdKey, edKey ssh.PublicKey | |
+var testAddr = &net.TCPAddr{ | |
+ IP: net.IP{198, 41, 30, 196}, | |
+ Port: 22, | |
+} | |
+ | |
+var testAddr6 = &net.TCPAddr{ | |
+ IP: net.IP{198, 41, 30, 196, | |
+ 1, 2, 3, 4, | |
+ 1, 2, 3, 4, | |
+ 1, 2, 3, 4, | |
+ }, | |
+ Port: 22, | |
+} | |
+ | |
+func init() { | |
+ var err error | |
+ ecKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(ecKeyStr)) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ edKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(edKeyStr)) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+ alternateEdKey, _, _, _, err = ssh.ParseAuthorizedKey([]byte(alternateEdKeyStr)) | |
+ if err != nil { | |
+ panic(err) | |
+ } | |
+} | |
+ | |
+func testDB(t *testing.T, s string) *hostKeyDB { | |
+ db := newHostKeyDB() | |
+ if err := db.Read(bytes.NewBufferString(s), "testdb"); err != nil { | |
+ t.Fatalf("Read: %v", err) | |
+ } | |
+ | |
+ return db | |
+} | |
+ | |
+func TestRevoked(t *testing.T) { | |
+ db := testDB(t, "\n\n@revoked * "+edKeyStr+"\n") | |
+ want := &RevokedError{ | |
+ Revoked: KnownKey{ | |
+ Key: edKey, | |
+ Filename: "testdb", | |
+ Line: 3, | |
+ }, | |
+ } | |
+ if err := db.check("", &net.TCPAddr{ | |
+ Port: 42, | |
+ }, edKey); err == nil { | |
+ t.Fatal("no error for revoked key") | |
+ } else if !reflect.DeepEqual(want, err) { | |
+ t.Fatalf("got %#v, want %#v", want, err) | |
+ } | |
+} | |
+ | |
+func TestHostAuthority(t *testing.T) { | |
+ for _, m := range []struct { | |
+ authorityFor string | |
+ address string | |
+ | |
+ good bool | |
+ }{ | |
+ {authorityFor: "localhost", address: "localhost:22", good: true}, | |
+ {authorityFor: "localhost", address: "localhost", good: false}, | |
+ {authorityFor: "localhost", address: "localhost:1234", good: false}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost:1234", good: true}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost:22", good: false}, | |
+ {authorityFor: "[localhost]:1234", address: "localhost", good: false}, | |
+ } { | |
+ db := testDB(t, `@cert-authority `+m.authorityFor+` `+edKeyStr) | |
+ if ok := db.IsHostAuthority(db.lines[0].knownKey.Key, m.address); ok != m.good { | |
+ t.Errorf("IsHostAuthority: authority %s, address %s, wanted good = %v, got good = %v", | |
+ m.authorityFor, m.address, m.good, ok) | |
+ } | |
+ } | |
+} | |
+ | |
+func TestBracket(t *testing.T) { | |
+ db := testDB(t, `[git.eclipse.org]:29418,[198.41.30.196]:29418 `+edKeyStr) | |
+ | |
+ if err := db.check("git.eclipse.org:29418", &net.TCPAddr{ | |
+ IP: net.IP{198, 41, 30, 196}, | |
+ Port: 29418, | |
+ }, edKey); err != nil { | |
+ t.Errorf("got error %v, want none", err) | |
+ } | |
+ | |
+ if err := db.check("git.eclipse.org:29419", &net.TCPAddr{ | |
+ Port: 42, | |
+ }, edKey); err == nil { | |
+ t.Fatalf("no error for unknown address") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } else if len(ke.Want) > 0 { | |
+ t.Fatalf("got Want %v, want []", ke.Want) | |
+ } | |
+} | |
+ | |
+func TestNewKeyType(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("", testAddr, ecKey); err == nil { | |
+ t.Fatalf("no error for unknown address") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } else if len(ke.Want) == 0 { | |
+ t.Fatalf("got empty KeyError.Want") | |
+ } | |
+} | |
+ | |
+func TestSameKeyType(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("", testAddr, alternateEdKey); err == nil { | |
+ t.Fatalf("no error for unknown address") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Fatalf("got type %T, want *KeyError", err) | |
+ } else if len(ke.Want) == 0 { | |
+ t.Fatalf("got empty KeyError.Want") | |
+ } else if got, want := ke.Want[0].Key.Marshal(), edKey.Marshal(); !bytes.Equal(got, want) { | |
+ t.Fatalf("got key %q, want %q", got, want) | |
+ } | |
+} | |
+ | |
+func TestIPAddress(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("", testAddr, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+} | |
+ | |
+func TestIPv6Address(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", testAddr6, edKeyStr) | |
+ db := testDB(t, str) | |
+ | |
+ if err := db.check("", testAddr6, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+} | |
+ | |
+func TestBasic(t *testing.T) { | |
+ str := fmt.Sprintf("#comment\n\nserver.org,%s %s\notherhost %s", testAddr, edKeyStr, ecKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("server.org:22", testAddr, edKey); err != nil { | |
+ t.Errorf("got error %q, want none", err) | |
+ } | |
+ | |
+ want := KnownKey{ | |
+ Key: edKey, | |
+ Filename: "testdb", | |
+ Line: 3, | |
+ } | |
+ if err := db.check("server.org:22", testAddr, ecKey); err == nil { | |
+ t.Errorf("succeeded, want KeyError") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Errorf("got %T, want *KeyError", err) | |
+ } else if len(ke.Want) != 1 { | |
+ t.Errorf("got %v, want 1 entry", ke) | |
+ } else if !reflect.DeepEqual(ke.Want[0], want) { | |
+ t.Errorf("got %v, want %v", ke.Want[0], want) | |
+ } | |
+} | |
+ | |
+func TestNegate(t *testing.T) { | |
+ str := fmt.Sprintf("%s,!server.org %s", testAddr, edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check("server.org:22", testAddr, ecKey); err == nil { | |
+ t.Errorf("succeeded") | |
+ } else if ke, ok := err.(*KeyError); !ok { | |
+ t.Errorf("got error type %T, want *KeyError", err) | |
+ } else if len(ke.Want) != 0 { | |
+ t.Errorf("got expected keys %d (first of type %s), want []", len(ke.Want), ke.Want[0].Key.Type()) | |
+ } | |
+} | |
+ | |
+func TestWildcard(t *testing.T) { | |
+ str := fmt.Sprintf("server*.domain %s", edKeyStr) | |
+ db := testDB(t, str) | |
+ | |
+ want := &KeyError{ | |
+ Want: []KnownKey{{ | |
+ Filename: "testdb", | |
+ Line: 1, | |
+ Key: edKey, | |
+ }}, | |
+ } | |
+ | |
+ got := db.check("server.domain:22", &net.TCPAddr{}, ecKey) | |
+ if !reflect.DeepEqual(got, want) { | |
+ t.Errorf("got %s, want %s", got, want) | |
+ } | |
+} | |
+ | |
+func TestLine(t *testing.T) { | |
+ for in, want := range map[string]string{ | |
+ "server.org": "server.org " + edKeyStr, | |
+ "server.org:22": "server.org " + edKeyStr, | |
+ "server.org:23": "[server.org]:23 " + edKeyStr, | |
+ "[c629:1ec4:102:304:102:304:102:304]:22": "[c629:1ec4:102:304:102:304:102:304] " + edKeyStr, | |
+ "[c629:1ec4:102:304:102:304:102:304]:23": "[c629:1ec4:102:304:102:304:102:304]:23 " + edKeyStr, | |
+ } { | |
+ if got := Line([]string{in}, edKey); got != want { | |
+ t.Errorf("Line(%q) = %q, want %q", in, got, want) | |
+ } | |
+ } | |
+} | |
+ | |
+func TestWildcardMatch(t *testing.T) { | |
+ for _, c := range []struct { | |
+ pat, str string | |
+ want bool | |
+ }{ | |
+ {"a?b", "abb", true}, | |
+ {"ab", "abc", false}, | |
+ {"abc", "ab", false}, | |
+ {"a*b", "axxxb", true}, | |
+ {"a*b", "axbxb", true}, | |
+ {"a*b", "axbxbc", false}, | |
+ {"a*?", "axbxc", true}, | |
+ {"a*b*", "axxbxxxxxx", true}, | |
+ {"a*b*c", "axxbxxxxxxc", true}, | |
+ {"a*b*?", "axxbxxxxxxc", true}, | |
+ {"a*b*z", "axxbxxbxxxz", true}, | |
+ {"a*b*z", "axxbxxzxxxz", true}, | |
+ {"a*b*z", "axxbxxzxxx", false}, | |
+ } { | |
+ got := wildcardMatch([]byte(c.pat), []byte(c.str)) | |
+ if got != c.want { | |
+ t.Errorf("wildcardMatch(%q, %q) = %v, want %v", c.pat, c.str, got, c.want) | |
+ } | |
+ | |
+ } | |
+} | |
+ | |
+// TODO(hanwen): test coverage for certificates. | |
+ | |
+const testHostname = "hostname" | |
+ | |
+// generated with ssh-keygen -H -f | |
+const encodedTestHostnameHash = "|1|IHXZvQMvTcZTUU29+2vXFgx8Frs=|UGccIWfRVDwilMBnA3WJoRAC75Y=" | |
+ | |
+func TestHostHash(t *testing.T) { | |
+ testHostHash(t, testHostname, encodedTestHostnameHash) | |
+} | |
+ | |
+func TestHashList(t *testing.T) { | |
+ encoded := HashHostname(testHostname) | |
+ testHostHash(t, testHostname, encoded) | |
+} | |
+ | |
+func testHostHash(t *testing.T, hostname, encoded string) { | |
+ typ, salt, hash, err := decodeHash(encoded) | |
+ if err != nil { | |
+ t.Fatalf("decodeHash: %v", err) | |
+ } | |
+ | |
+ if got := encodeHash(typ, salt, hash); got != encoded { | |
+ t.Errorf("got encoding %s want %s", got, encoded) | |
+ } | |
+ | |
+ if typ != sha1HashType { | |
+ t.Fatalf("got hash type %q, want %q", typ, sha1HashType) | |
+ } | |
+ | |
+ got := hashHost(hostname, salt) | |
+ if !bytes.Equal(got, hash) { | |
+ t.Errorf("got hash %x want %x", got, hash) | |
+ } | |
+} | |
+ | |
+func TestNormalize(t *testing.T) { | |
+ for in, want := range map[string]string{ | |
+ "127.0.0.1:22": "127.0.0.1", | |
+ "[127.0.0.1]:22": "127.0.0.1", | |
+ "[127.0.0.1]:23": "[127.0.0.1]:23", | |
+ "127.0.0.1:23": "[127.0.0.1]:23", | |
+ "[a.b.c]:22": "a.b.c", | |
+ "[abcd:abcd:abcd:abcd]": "[abcd:abcd:abcd:abcd]", | |
+ "[abcd:abcd:abcd:abcd]:22": "[abcd:abcd:abcd:abcd]", | |
+ "[abcd:abcd:abcd:abcd]:23": "[abcd:abcd:abcd:abcd]:23", | |
+ } { | |
+ got := Normalize(in) | |
+ if got != want { | |
+ t.Errorf("Normalize(%q) = %q, want %q", in, got, want) | |
+ } | |
+ } | |
+} | |
+ | |
+func TestHashedHostkeyCheck(t *testing.T) { | |
+ str := fmt.Sprintf("%s %s", HashHostname(testHostname), edKeyStr) | |
+ db := testDB(t, str) | |
+ if err := db.check(testHostname+":22", testAddr, edKey); err != nil { | |
+ t.Errorf("check(%s): %v", testHostname, err) | |
+ } | |
+ want := &KeyError{ | |
+ Want: []KnownKey{{ | |
+ Filename: "testdb", | |
+ Line: 1, | |
+ Key: edKey, | |
+ }}, | |
+ } | |
+ if got := db.check(testHostname+":22", testAddr, alternateEdKey); !reflect.DeepEqual(got, want) { | |
+ t.Errorf("got error %v, want %v", got, want) | |
+ } | |
+} | |
diff --git a/ssh/server.go b/ssh/server.go | |
index 77c84d1..23b41d9 100644 | |
--- a/ssh/server.go | |
+++ b/ssh/server.go | |
@@ -45,6 +45,12 @@ type ServerConfig struct { | |
// authenticating. | |
NoClientAuth bool | |
+ // MaxAuthTries specifies the maximum number of authentication attempts | |
+ // permitted per connection. If set to a negative number, the number of | |
+ // attempts is unlimited. If set to zero, the number of attempts is limited | |
+ // to 6. | |
+ MaxAuthTries int | |
+ | |
// PasswordCallback, if non-nil, is called when a user | |
// attempts to authenticate using a password. | |
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error) | |
@@ -143,6 +149,10 @@ type ServerConn struct { | |
func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewChannel, <-chan *Request, error) { | |
fullConf := *config | |
fullConf.SetDefaults() | |
+ if fullConf.MaxAuthTries == 0 { | |
+ fullConf.MaxAuthTries = 6 | |
+ } | |
+ | |
s := &connection{ | |
sshConn: sshConn{conn: c}, | |
} | |
@@ -267,8 +277,23 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err | |
var cache pubKeyCache | |
var perms *Permissions | |
+ authFailures := 0 | |
+ | |
userAuthLoop: | |
for { | |
+ if authFailures >= config.MaxAuthTries && config.MaxAuthTries > 0 { | |
+ discMsg := &disconnectMsg{ | |
+ Reason: 2, | |
+ Message: "too many authentication failures", | |
+ } | |
+ | |
+ if err := s.transport.writePacket(Marshal(discMsg)); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return nil, discMsg | |
+ } | |
+ | |
var userAuthReq userAuthRequestMsg | |
if packet, err := s.transport.readPacket(); err != nil { | |
return nil, err | |
@@ -289,6 +314,11 @@ userAuthLoop: | |
if config.NoClientAuth { | |
authErr = nil | |
} | |
+ | |
+ // allow initial attempt of 'none' without penalty | |
+ if authFailures == 0 { | |
+ authFailures-- | |
+ } | |
case "password": | |
if config.PasswordCallback == nil { | |
authErr = errors.New("ssh: password auth not configured") | |
@@ -360,6 +390,7 @@ userAuthLoop: | |
if isQuery { | |
// The client can query if the given public key | |
// would be okay. | |
+ | |
if len(payload) > 0 { | |
return nil, parseError(msgUserAuthRequest) | |
} | |
@@ -409,6 +440,8 @@ userAuthLoop: | |
break userAuthLoop | |
} | |
+ authFailures++ | |
+ | |
var failureMsg userAuthFailureMsg | |
if config.PasswordCallback != nil { | |
failureMsg.Methods = append(failureMsg.Methods, "password") | |
diff --git a/ssh/session_test.go b/ssh/session_test.go | |
index f35a378..7dce6dd 100644 | |
--- a/ssh/session_test.go | |
+++ b/ssh/session_test.go | |
@@ -59,7 +59,8 @@ func dial(handler serverType, t *testing.T) *Client { | |
}() | |
config := &ClientConfig{ | |
- User: "testuser", | |
+ User: "testuser", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
} | |
conn, chans, reqs, err := NewClientConn(c2, "", config) | |
@@ -641,7 +642,8 @@ func TestSessionID(t *testing.T) { | |
} | |
serverConf.AddHostKey(testSigners["ecdsa"]) | |
clientConf := &ClientConfig{ | |
- User: "user", | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ User: "user", | |
} | |
go func() { | |
@@ -747,7 +749,9 @@ func TestHostKeyAlgorithms(t *testing.T) { | |
// By default, we get the preferred algorithm, which is ECDSA 256. | |
- clientConf := &ClientConfig{} | |
+ clientConf := &ClientConfig{ | |
+ HostKeyCallback: InsecureIgnoreHostKey(), | |
+ } | |
connect(clientConf, KeyAlgoECDSA256) | |
// Client asks for RSA explicitly. | |
diff --git a/ssh/streamlocal.go b/ssh/streamlocal.go | |
new file mode 100644 | |
index 0000000..a2dccc6 | |
--- /dev/null | |
+++ b/ssh/streamlocal.go | |
@@ -0,0 +1,115 @@ | |
+package ssh | |
+ | |
+import ( | |
+ "errors" | |
+ "io" | |
+ "net" | |
+) | |
+ | |
+// streamLocalChannelOpenDirectMsg is a struct used for SSH_MSG_CHANNEL_OPEN message | |
+// with "direct-streamlocal@openssh.com" string. | |
+// | |
+// See openssh-portable/PROTOCOL, section 2.4. connection: Unix domain socket forwarding | |
+// https://github.com/openssh/openssh-portable/blob/master/PROTOCOL#L235 | |
+type streamLocalChannelOpenDirectMsg struct { | |
+ socketPath string | |
+ reserved0 string | |
+ reserved1 uint32 | |
+} | |
+ | |
+// forwardedStreamLocalPayload is a struct used for SSH_MSG_CHANNEL_OPEN message | |
+// with "forwarded-streamlocal@openssh.com" string. | |
+type forwardedStreamLocalPayload struct { | |
+ SocketPath string | |
+ Reserved0 string | |
+} | |
+ | |
+// streamLocalChannelForwardMsg is a struct used for SSH2_MSG_GLOBAL_REQUEST message | |
+// with "streamlocal-forward@openssh.com"/"cancel-streamlocal-forward@openssh.com" string. | |
+type streamLocalChannelForwardMsg struct { | |
+ socketPath string | |
+} | |
+ | |
+// ListenUnix is similar to ListenTCP but uses a Unix domain socket. | |
+// It sends a "streamlocal-forward@openssh.com" global request for socketPath | |
+// and, if the peer accepts it, registers the forward and returns a listener | |
+// whose Accept yields the forwarded connections. | |
+func (c *Client) ListenUnix(socketPath string) (net.Listener, error) { | |
+ m := streamLocalChannelForwardMsg{ | |
+ socketPath, | |
+ } | |
+ // Ask the peer to start forwarding; a reply is requested so that a | |
+ // refusal can be reported to the caller. | |
+ ok, _, err := c.SendRequest("streamlocal-forward@openssh.com", true, Marshal(&m)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ if !ok { | |
+ return nil, errors.New("ssh: streamlocal-forward@openssh.com request denied by peer") | |
+ } | |
+ ch := c.forwards.add(&net.UnixAddr{Name: socketPath, Net: "unix"}) | |
+ | |
+ return &unixListener{socketPath, c, ch}, nil | |
+} | |
+ | |
+// dialStreamLocal opens a "direct-streamlocal@openssh.com" channel for | |
+// socketPath and discards any requests that arrive on that channel. | |
+func (c *Client) dialStreamLocal(socketPath string) (Channel, error) { | |
+ msg := streamLocalChannelOpenDirectMsg{ | |
+ socketPath: socketPath, | |
+ } | |
+ ch, in, err := c.OpenChannel("direct-streamlocal@openssh.com", Marshal(&msg)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ go DiscardRequests(in) | |
+ return ch, err | |
+} | |
+ | |
+type unixListener struct { | |
+ socketPath string | |
+ | |
+ conn *Client | |
+ in <-chan forward | |
+} | |
+ | |
+// Accept waits for and returns the next connection to the listener. | |
+func (l *unixListener) Accept() (net.Conn, error) { | |
+ s, ok := <-l.in | |
+ if !ok { | |
+ return nil, io.EOF | |
+ } | |
+ ch, incoming, err := s.newCh.Accept() | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ go DiscardRequests(incoming) | |
+ | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: &net.UnixAddr{ | |
+ Name: l.socketPath, | |
+ Net: "unix", | |
+ }, | |
+ raddr: &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ }, | |
+ }, nil | |
+} | |
+ | |
+// Close closes the listener and asks the peer to cancel the forward with a | |
+// "cancel-streamlocal-forward@openssh.com" global request. | |
+func (l *unixListener) Close() error { | |
+ // Removing the forward entry closes the channel that feeds Accept, | |
+ // so this also closes the listener. | |
+ l.conn.forwards.remove(&net.UnixAddr{Name: l.socketPath, Net: "unix"}) | |
+ m := streamLocalChannelForwardMsg{ | |
+ l.socketPath, | |
+ } | |
+ ok, _, err := l.conn.SendRequest("cancel-streamlocal-forward@openssh.com", true, Marshal(&m)) | |
+ if err == nil && !ok { | |
+ err = errors.New("ssh: cancel-streamlocal-forward@openssh.com failed") | |
+ } | |
+ return err | |
+} | |
+ | |
+// Addr returns the listener's network address. | |
+func (l *unixListener) Addr() net.Addr { | |
+ return &net.UnixAddr{ | |
+ Name: l.socketPath, | |
+ Net: "unix", | |
+ } | |
+} | |
diff --git a/ssh/tcpip.go b/ssh/tcpip.go | |
index 6151241..acf1717 100644 | |
--- a/ssh/tcpip.go | |
+++ b/ssh/tcpip.go | |
@@ -20,12 +20,20 @@ import ( | |
// addr. Incoming connections will be available by calling Accept on | |
// the returned net.Listener. The listener must be serviced, or the | |
// SSH connection may hang. | |
+// The network n must be "tcp", "tcp4", "tcp6", or "unix". | |
func (c *Client) Listen(n, addr string) (net.Listener, error) { | |
- laddr, err := net.ResolveTCPAddr(n, addr) | |
- if err != nil { | |
- return nil, err | |
+ switch n { | |
+ case "tcp", "tcp4", "tcp6": | |
+ laddr, err := net.ResolveTCPAddr(n, addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return c.ListenTCP(laddr) | |
+ case "unix": | |
+ return c.ListenUnix(addr) | |
+ default: | |
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) | |
} | |
- return c.ListenTCP(laddr) | |
} | |
// Automatic port allocation is broken with OpenSSH before 6.0. See | |
@@ -116,7 +124,7 @@ func (c *Client) ListenTCP(laddr *net.TCPAddr) (net.Listener, error) { | |
} | |
// Register this forward, using the port number we obtained. | |
- ch := c.forwards.add(*laddr) | |
+ ch := c.forwards.add(laddr) | |
return &tcpListener{laddr, c, ch}, nil | |
} | |
@@ -131,7 +139,7 @@ type forwardList struct { | |
// forwardEntry represents an established mapping of a laddr on a | |
// remote ssh server to a channel connected to a tcpListener. | |
type forwardEntry struct { | |
- laddr net.TCPAddr | |
+ laddr net.Addr | |
c chan forward | |
} | |
@@ -139,16 +147,16 @@ type forwardEntry struct { | |
// arguments to add/remove/lookup should be address as specified in | |
// the original forward-request. | |
type forward struct { | |
- newCh NewChannel // the ssh client channel underlying this forward | |
- raddr *net.TCPAddr // the raddr of the incoming connection | |
+ newCh NewChannel // the ssh client channel underlying this forward | |
+ raddr net.Addr // the raddr of the incoming connection | |
} | |
-func (l *forwardList) add(addr net.TCPAddr) chan forward { | |
+func (l *forwardList) add(addr net.Addr) chan forward { | |
l.Lock() | |
defer l.Unlock() | |
f := forwardEntry{ | |
- addr, | |
- make(chan forward, 1), | |
+ laddr: addr, | |
+ c: make(chan forward, 1), | |
} | |
l.entries = append(l.entries, f) | |
return f.c | |
@@ -176,44 +184,69 @@ func parseTCPAddr(addr string, port uint32) (*net.TCPAddr, error) { | |
func (l *forwardList) handleChannels(in <-chan NewChannel) { | |
for ch := range in { | |
- var payload forwardedTCPPayload | |
- if err := Unmarshal(ch.ExtraData(), &payload); err != nil { | |
- ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) | |
- continue | |
+ var ( | |
+ laddr net.Addr | |
+ raddr net.Addr | |
+ err error | |
+ ) | |
+ switch channelType := ch.ChannelType(); channelType { | |
+ case "forwarded-tcpip": | |
+ var payload forwardedTCPPayload | |
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil { | |
+ ch.Reject(ConnectionFailed, "could not parse forwarded-tcpip payload: "+err.Error()) | |
+ continue | |
+ } | |
+ | |
+ // RFC 4254 section 7.2 specifies that incoming | |
+ // addresses should list the address, in string | |
+ // format. It is implied that this should be an IP | |
+ // address, as it would be impossible to connect to it | |
+ // otherwise. | |
+ laddr, err = parseTCPAddr(payload.Addr, payload.Port) | |
+ if err != nil { | |
+ ch.Reject(ConnectionFailed, err.Error()) | |
+ continue | |
+ } | |
+ raddr, err = parseTCPAddr(payload.OriginAddr, payload.OriginPort) | |
+ if err != nil { | |
+ ch.Reject(ConnectionFailed, err.Error()) | |
+ continue | |
+ } | |
+ | |
+ case "forwarded-streamlocal@openssh.com": | |
+ var payload forwardedStreamLocalPayload | |
+ if err = Unmarshal(ch.ExtraData(), &payload); err != nil { | |
+ ch.Reject(ConnectionFailed, "could not parse forwarded-streamlocal@openssh.com payload: "+err.Error()) | |
+ continue | |
+ } | |
+ laddr = &net.UnixAddr{ | |
+ Name: payload.SocketPath, | |
+ Net: "unix", | |
+ } | |
+ raddr = &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ } | |
+ default: | |
+ panic(fmt.Errorf("ssh: unknown channel type %s", channelType)) | |
}
- | |
- // RFC 4254 section 7.2 specifies that incoming | |
- // addresses should list the address, in string | |
- // format. It is implied that this should be an IP | |
- // address, as it would be impossible to connect to it | |
- // otherwise. | |
- laddr, err := parseTCPAddr(payload.Addr, payload.Port) | |
- if err != nil { | |
- ch.Reject(ConnectionFailed, err.Error()) | |
- continue | |
- } | |
- raddr, err := parseTCPAddr(payload.OriginAddr, payload.OriginPort) | |
- if err != nil { | |
- ch.Reject(ConnectionFailed, err.Error()) | |
- continue | |
- } | |
- | |
- if ok := l.forward(*laddr, *raddr, ch); !ok { | |
+ if ok := l.forward(laddr, raddr, ch); !ok { | |
// Section 7.2, implementations MUST reject spurious incoming | |
// connections. | |
ch.Reject(Prohibited, "no forward for address") | |
continue | |
} | |
+ | |
} | |
} | |
// remove removes the forward entry, and the channel feeding its | |
// listener. | |
-func (l *forwardList) remove(addr net.TCPAddr) { | |
+func (l *forwardList) remove(addr net.Addr) { | |
l.Lock() | |
defer l.Unlock() | |
for i, f := range l.entries { | |
- if addr.IP.Equal(f.laddr.IP) && addr.Port == f.laddr.Port { | |
+ if addr.Network() == f.laddr.Network() && addr.String() == f.laddr.String() { | |
l.entries = append(l.entries[:i], l.entries[i+1:]...) | |
close(f.c) | |
return | |
@@ -231,12 +264,12 @@ func (l *forwardList) closeAll() { | |
l.entries = nil | |
} | |
-func (l *forwardList) forward(laddr, raddr net.TCPAddr, ch NewChannel) bool { | |
+func (l *forwardList) forward(laddr, raddr net.Addr, ch NewChannel) bool { | |
l.Lock() | |
defer l.Unlock() | |
for _, f := range l.entries { | |
- if laddr.IP.Equal(f.laddr.IP) && laddr.Port == f.laddr.Port { | |
- f.c <- forward{ch, &raddr} | |
+ if laddr.Network() == f.laddr.Network() && laddr.String() == f.laddr.String() { | |
+ f.c <- forward{newCh: ch, raddr: raddr} | |
return true | |
} | |
} | |
@@ -262,7 +295,7 @@ func (l *tcpListener) Accept() (net.Conn, error) { | |
} | |
go DiscardRequests(incoming) | |
- return &tcpChanConn{ | |
+ return &chanConn{ | |
Channel: ch, | |
laddr: l.laddr, | |
raddr: s.raddr, | |
@@ -277,7 +310,7 @@ func (l *tcpListener) Close() error { | |
} | |
// this also closes the listener. | |
- l.conn.forwards.remove(*l.laddr) | |
+ l.conn.forwards.remove(l.laddr) | |
ok, _, err := l.conn.SendRequest("cancel-tcpip-forward", true, Marshal(&m)) | |
if err == nil && !ok { | |
err = errors.New("ssh: cancel-tcpip-forward failed") | |
@@ -293,29 +326,52 @@ func (l *tcpListener) Addr() net.Addr { | |
// Dial initiates a connection to the addr from the remote host. | |
// The resulting connection has a zero LocalAddr() and RemoteAddr(). | |
func (c *Client) Dial(n, addr string) (net.Conn, error) { | |
- // Parse the address into host and numeric port. | |
- host, portString, err := net.SplitHostPort(addr) | |
- if err != nil { | |
- return nil, err | |
- } | |
- port, err := strconv.ParseUint(portString, 10, 16) | |
- if err != nil { | |
- return nil, err | |
- } | |
- // Use a zero address for local and remote address. | |
- zeroAddr := &net.TCPAddr{ | |
- IP: net.IPv4zero, | |
- Port: 0, | |
- } | |
- ch, err := c.dial(net.IPv4zero.String(), 0, host, int(port)) | |
- if err != nil { | |
- return nil, err | |
+ var ch Channel | |
+ switch n { | |
+ case "tcp", "tcp4", "tcp6": | |
+ // Parse the address into host and numeric port. | |
+ host, portString, err := net.SplitHostPort(addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ port, err := strconv.ParseUint(portString, 10, 16) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ ch, err = c.dial(net.IPv4zero.String(), 0, host, int(port)) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ // Use a zero address for local and remote address. | |
+ zeroAddr := &net.TCPAddr{ | |
+ IP: net.IPv4zero, | |
+ Port: 0, | |
+ } | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: zeroAddr, | |
+ raddr: zeroAddr, | |
+ }, nil | |
+ case "unix": | |
+ var err error | |
+ ch, err = c.dialStreamLocal(addr) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ return &chanConn{ | |
+ Channel: ch, | |
+ laddr: &net.UnixAddr{ | |
+ Name: "@", | |
+ Net: "unix", | |
+ }, | |
+ raddr: &net.UnixAddr{ | |
+ Name: addr, | |
+ Net: "unix", | |
+ }, | |
+ }, nil | |
+ default: | |
+ return nil, fmt.Errorf("ssh: unsupported protocol: %s", n) | |
} | |
- return &tcpChanConn{ | |
- Channel: ch, | |
- laddr: zeroAddr, | |
- raddr: zeroAddr, | |
- }, nil | |
} | |
// DialTCP connects to the remote address raddr on the network net, | |
@@ -332,7 +388,7 @@ func (c *Client) DialTCP(n string, laddr, raddr *net.TCPAddr) (net.Conn, error) | |
if err != nil { | |
return nil, err | |
} | |
- return &tcpChanConn{ | |
+ return &chanConn{ | |
Channel: ch, | |
laddr: laddr, | |
raddr: raddr, | |
@@ -366,26 +422,26 @@ type tcpChan struct { | |
Channel // the backing channel | |
} | |
-// tcpChanConn fulfills the net.Conn interface without | |
+// chanConn fulfills the net.Conn interface without | |
// the tcpChan having to hold laddr or raddr directly. | |
-type tcpChanConn struct { | |
+type chanConn struct { | |
Channel | |
laddr, raddr net.Addr | |
} | |
// LocalAddr returns the local network address. | |
-func (t *tcpChanConn) LocalAddr() net.Addr { | |
+func (t *chanConn) LocalAddr() net.Addr { | |
return t.laddr | |
} | |
// RemoteAddr returns the remote network address. | |
-func (t *tcpChanConn) RemoteAddr() net.Addr { | |
+func (t *chanConn) RemoteAddr() net.Addr { | |
return t.raddr | |
} | |
// SetDeadline sets the read and write deadlines associated | |
// with the connection. | |
-func (t *tcpChanConn) SetDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetDeadline(deadline time.Time) error { | |
if err := t.SetReadDeadline(deadline); err != nil { | |
return err | |
} | |
@@ -396,12 +452,14 @@ func (t *tcpChanConn) SetDeadline(deadline time.Time) error { | |
// A zero value for t means Read will not time out. | |
// After the deadline, the error from Read will implement net.Error | |
// with Timeout() == true. | |
-func (t *tcpChanConn) SetReadDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetReadDeadline(deadline time.Time) error { | |
+ // for compatibility with previous version, | |
+ // the error message contains "tcpChan" | |
return errors.New("ssh: tcpChan: deadline not supported") | |
} | |
// SetWriteDeadline exists to satisfy the net.Conn interface | |
// but is not implemented by this type. It always returns an error. | |
-func (t *tcpChanConn) SetWriteDeadline(deadline time.Time) error { | |
+func (t *chanConn) SetWriteDeadline(deadline time.Time) error { | |
return errors.New("ssh: tcpChan: deadline not supported") | |
} | |
diff --git a/ssh/terminal/util_solaris.go b/ssh/terminal/util_solaris.go | |
index 07eb5ed..a2e1b57 100644 | |
--- a/ssh/terminal/util_solaris.go | |
+++ b/ssh/terminal/util_solaris.go | |
@@ -14,14 +14,12 @@ import ( | |
// State contains the state of a terminal. | |
type State struct { | |
- termios syscall.Termios | |
+ state *unix.Termios | |
} | |
// IsTerminal returns true if the given file descriptor is a terminal. | |
func IsTerminal(fd int) bool { | |
- // see: http://src.illumos.org/source/xref/illumos-gate/usr/src/lib/libbc/libc/gen/common/isatty.c | |
- var termio unix.Termio | |
- err := unix.IoctlSetTermio(fd, unix.TCGETA, &termio) | |
+ _, err := unix.IoctlGetTermio(fd, unix.TCGETA) | |
return err == nil | |
} | |
@@ -71,3 +69,60 @@ func ReadPassword(fd int) ([]byte, error) { | |
return ret, nil | |
} | |
+ | |
+// MakeRaw puts the terminal connected to the given file descriptor into raw | |
+// mode and returns the previous state of the terminal so that it can be | |
+// restored. | |
+// see http://cr.illumos.org/~webrev/andy_js/1060/ | |
+func MakeRaw(fd int) (*State, error) { | |
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ oldTermios := *oldTermiosPtr | |
+ | |
+ newTermios := oldTermios | |
+ newTermios.Iflag &^= syscall.IGNBRK | syscall.BRKINT | syscall.PARMRK | syscall.ISTRIP | syscall.INLCR | syscall.IGNCR | syscall.ICRNL | syscall.IXON | |
+ newTermios.Oflag &^= syscall.OPOST | |
+ newTermios.Lflag &^= syscall.ECHO | syscall.ECHONL | syscall.ICANON | syscall.ISIG | syscall.IEXTEN | |
+ newTermios.Cflag &^= syscall.CSIZE | syscall.PARENB | |
+ newTermios.Cflag |= syscall.CS8 | |
+ newTermios.Cc[unix.VMIN] = 1 | |
+ newTermios.Cc[unix.VTIME] = 0 | |
+ | |
+ if err := unix.IoctlSetTermios(fd, unix.TCSETS, &newTermios); err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return &State{ | |
+ state: oldTermiosPtr, | |
+ }, nil | |
+} | |
+ | |
+// Restore restores the terminal connected to the given file descriptor to a | |
+// previous state. | |
+func Restore(fd int, oldState *State) error { | |
+ return unix.IoctlSetTermios(fd, unix.TCSETS, oldState.state) | |
+} | |
+ | |
+// GetState returns the current state of a terminal which may be useful to | |
+// restore the terminal after a signal. | |
+func GetState(fd int) (*State, error) { | |
+ oldTermiosPtr, err := unix.IoctlGetTermios(fd, unix.TCGETS) | |
+ if err != nil { | |
+ return nil, err | |
+ } | |
+ | |
+ return &State{ | |
+ state: oldTermiosPtr, | |
+ }, nil | |
+} | |
+ | |
+// GetSize returns the dimensions of the given terminal. | |
+func GetSize(fd int) (width, height int, err error) { | |
+ ws, err := unix.IoctlGetWinsize(fd, unix.TIOCGWINSZ) | |
+ if err != nil { | |
+ return 0, 0, err | |
+ } | |
+ return int(ws.Col), int(ws.Row), nil | |
+} | |
diff --git a/ssh/test/cert_test.go b/ssh/test/cert_test.go | |
index 364790f..b231dd8 100644 | |
--- a/ssh/test/cert_test.go | |
+++ b/ssh/test/cert_test.go | |
@@ -7,12 +7,14 @@ | |
package test | |
import ( | |
+ "bytes" | |
"crypto/rand" | |
"testing" | |
"golang.org/x/crypto/ssh" | |
) | |
+// TestCertLogin tests both logging in with a cert, and that the certificate presented by an OpenSSH host can be validated correctly. | |
func TestCertLogin(t *testing.T) { | |
s := newServer(t) | |
defer s.Shutdown() | |
@@ -37,11 +39,39 @@ func TestCertLogin(t *testing.T) { | |
conf := &ssh.ClientConfig{ | |
User: username(), | |
+ HostKeyCallback: (&ssh.CertChecker{ | |
+ IsHostAuthority: func(pk ssh.PublicKey, addr string) bool { | |
+ return bytes.Equal(pk.Marshal(), testPublicKeys["ca"].Marshal()) | |
+ }, | |
+ }).CheckHostKey, | |
} | |
conf.Auth = append(conf.Auth, ssh.PublicKeys(certSigner)) | |
- client, err := s.TryDial(conf) | |
- if err != nil { | |
- t.Fatalf("TryDial: %v", err) | |
+ | |
+ for _, test := range []struct { | |
+ addr string | |
+ succeed bool | |
+ }{ | |
+ {addr: "host.example.com:22", succeed: true}, | |
+ {addr: "host.example.com:10000", succeed: true}, // non-standard port must be OK | |
+ {addr: "host.example.com", succeed: false}, // port must be specified | |
+ {addr: "host.ex4mple.com:22", succeed: false}, // wrong host | |
+ } { | |
+ client, err := s.TryDialWithAddr(conf, test.addr) | |
+ | |
+ // Always close client if opened successfully | |
+ if err == nil { | |
+ client.Close() | |
+ } | |
+ | |
+ // Now evaluate whether the test failed or passed | |
+ if test.succeed { | |
+ if err != nil { | |
+ t.Fatalf("TryDialWithAddr: %v", err) | |
+ } | |
+ } else { | |
+ if err == nil { | |
+ t.Fatalf("TryDialWithAddr, unexpected success") | |
+ } | |
+ } | |
} | |
- client.Close() | |
} | |
diff --git a/ssh/test/dial_unix_test.go b/ssh/test/dial_unix_test.go | |
new file mode 100644 | |
index 0000000..091e48c | |
--- /dev/null | |
+++ b/ssh/test/dial_unix_test.go | |
@@ -0,0 +1,128 @@ | |
+// Copyright 2012 The Go Authors. All rights reserved. | |
+// Use of this source code is governed by a BSD-style | |
+// license that can be found in the LICENSE file. | |
+ | |
+// +build !windows | |
+ | |
+package test | |
+ | |
+// direct-tcpip and direct-streamlocal functional tests | |
+ | |
+import ( | |
+ "fmt" | |
+ "io" | |
+ "io/ioutil" | |
+ "net" | |
+ "strings" | |
+ "testing" | |
+) | |
+ | |
+type dialTester interface { | |
+ TestServerConn(t *testing.T, c net.Conn) | |
+ TestClientConn(t *testing.T, c net.Conn) | |
+} | |
+ | |
+// testDial starts a local listener on (n, listenAddr), dials it through an | |
+// SSH client connection, and checks that the data written by the listener | |
+// side is read back unchanged. x inspects both ends of the connection. | |
+func testDial(t *testing.T, n, listenAddr string, x dialTester) { | |
+ server := newServer(t) | |
+ defer server.Shutdown() | |
+ sshConn := server.Dial(clientConfig()) | |
+ defer sshConn.Close() | |
+ | |
+ l, err := net.Listen(n, listenAddr) | |
+ if err != nil { | |
+ t.Fatalf("Listen: %v", err) | |
+ } | |
+ defer l.Close() | |
+ | |
+ testData := fmt.Sprintf("hello from %s, %s", n, listenAddr) | |
+ // Serve accepted connections until the listener is closed. | |
+ go func() { | |
+ for { | |
+ c, err := l.Accept() | |
+ if err != nil { | |
+ break | |
+ } | |
+ x.TestServerConn(t, c) | |
+ | |
+ io.WriteString(c, testData) | |
+ c.Close() | |
+ } | |
+ }() | |
+ | |
+ conn, err := sshConn.Dial(n, l.Addr().String()) | |
+ if err != nil { | |
+ t.Fatalf("Dial: %v", err) | |
+ } | |
+ // Register the close before running checks that may abort the test, | |
+ // so the connection is released even when TestClientConn fails. | |
+ defer conn.Close() | |
+ x.TestClientConn(t, conn) | |
+ b, err := ioutil.ReadAll(conn) | |
+ if err != nil { | |
+ t.Fatalf("ReadAll: %v", err) | |
+ } | |
+ t.Logf("got %q", string(b)) | |
+ if string(b) != testData { | |
+ t.Fatalf("expected %q, got %q", testData, string(b)) | |
+ } | |
+} | |
+ | |
+type tcpDialTester struct { | |
+ listenAddr string | |
+} | |
+ | |
+// TestServerConn checks that both addresses of the accepted connection start | |
+// with the host part of listenAddr. It is called from the accept goroutine | |
+// started by testDial, so failures are reported with t.Errorf: t.Fatalf must | |
+// only be called from the goroutine running the test function. | |
+func (x *tcpDialTester) TestServerConn(t *testing.T, c net.Conn) { | |
+ host := strings.Split(x.listenAddr, ":")[0] | |
+ prefix := host + ":" | |
+ if !strings.HasPrefix(c.LocalAddr().String(), prefix) { | |
+ t.Errorf("expected to start with %q, got %q", prefix, c.LocalAddr().String()) | |
+ } | |
+ if !strings.HasPrefix(c.RemoteAddr().String(), prefix) { | |
+ t.Errorf("expected to start with %q, got %q", prefix, c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func (x *tcpDialTester) TestClientConn(t *testing.T, c net.Conn) { | |
+ // we use zero addresses. see *Client.Dial. | |
+ if c.LocalAddr().String() != "0.0.0.0:0" { | |
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.LocalAddr().String()) | |
+ } | |
+ if c.RemoteAddr().String() != "0.0.0.0:0" { | |
+ t.Fatalf("expected \"0.0.0.0:0\", got %q", c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func TestDialTCP(t *testing.T) { | |
+ x := &tcpDialTester{ | |
+ listenAddr: "127.0.0.1:0", | |
+ } | |
+ testDial(t, "tcp", x.listenAddr, x) | |
+} | |
+ | |
+type unixDialTester struct { | |
+ listenAddr string | |
+} | |
+ | |
+// TestServerConn checks that the accepted connection's local address is | |
+// listenAddr and its remote address is "@". It is called from the accept | |
+// goroutine started by testDial, so failures are reported with t.Errorf: | |
+// t.Fatalf must only be called from the goroutine running the test function. | |
+func (x *unixDialTester) TestServerConn(t *testing.T, c net.Conn) { | |
+ if c.LocalAddr().String() != x.listenAddr { | |
+ t.Errorf("expected %q, got %q", x.listenAddr, c.LocalAddr().String()) | |
+ } | |
+ if c.RemoteAddr().String() != "@" { | |
+ t.Errorf("expected \"@\", got %q", c.RemoteAddr().String()) | |
+ } | |
+} | |
+ | |
+func (x *unixDialTester) TestClientConn(t *testing.T, c net.Conn) { | |
+ if c.RemoteAddr().String() != x.listenAddr { | |
+ t.Fatalf("expected %q, got %q", x.listenAddr, c.RemoteAddr().String()) | |
+ } | |
+ if c.LocalAddr().String() != "@" { | |
+ t.Fatalf("expected \"@\", got %q", c.LocalAddr().String()) | |
+ } | |
+} | |
+ | |
+func TestDialUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ x := &unixDialTester{ | |
+ listenAddr: addr, | |
+ } | |
+ testDial(t, "unix", x.listenAddr, x) | |
+} | |
diff --git a/ssh/test/forward_unix_test.go b/ssh/test/forward_unix_test.go | |
index 877a88c..ea81937 100644 | |
--- a/ssh/test/forward_unix_test.go | |
+++ b/ssh/test/forward_unix_test.go | |
@@ -16,13 +16,17 @@ import ( | |
"time" | |
) | |
-func TestPortForward(t *testing.T) { | |
+type closeWriter interface { | |
+ CloseWrite() error | |
+} | |
+ | |
+func testPortForward(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
defer conn.Close() | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -41,14 +45,14 @@ func TestPortForward(t *testing.T) { | |
}() | |
forwardedAddr := sshListener.Addr().String() | |
- tcpConn, err := net.Dial("tcp", forwardedAddr) | |
+ netConn, err := net.Dial(n, forwardedAddr) | |
if err != nil { | |
- t.Fatalf("TCP dial failed: %v", err) | |
+ t.Fatalf("net dial failed: %v", err) | |
} | |
readChan := make(chan []byte) | |
go func() { | |
- data, _ := ioutil.ReadAll(tcpConn) | |
+ data, _ := ioutil.ReadAll(netConn) | |
readChan <- data | |
}() | |
@@ -62,14 +66,14 @@ func TestPortForward(t *testing.T) { | |
for len(sent) < 1000*1000 { | |
// Send random sized chunks | |
m := rand.Intn(len(data)) | |
- n, err := tcpConn.Write(data[:m]) | |
+ n, err := netConn.Write(data[:m]) | |
if err != nil { | |
break | |
} | |
sent = append(sent, data[:n]...) | |
} | |
- if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil { | |
- t.Errorf("tcpConn.CloseWrite: %v", err) | |
+ if err := netConn.(closeWriter).CloseWrite(); err != nil { | |
+ t.Errorf("netConn.CloseWrite: %v", err) | |
} | |
read := <-readChan | |
@@ -86,19 +90,29 @@ func TestPortForward(t *testing.T) { | |
} | |
// Check that the forward disappeared. | |
- tcpConn, err = net.Dial("tcp", forwardedAddr) | |
+ netConn, err = net.Dial(n, forwardedAddr) | |
if err == nil { | |
- tcpConn.Close() | |
+ netConn.Close() | |
t.Errorf("still listening to %s after closing", forwardedAddr) | |
} | |
} | |
-func TestAcceptClose(t *testing.T) { | |
+func TestPortForwardTCP(t *testing.T) { | |
+ testPortForward(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestPortForwardUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testPortForward(t, "unix", addr) | |
+} | |
+ | |
+func testAcceptClose(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -124,13 +138,23 @@ func TestAcceptClose(t *testing.T) { | |
} | |
} | |
+func TestAcceptCloseTCP(t *testing.T) { | |
+ testAcceptClose(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestAcceptCloseUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testAcceptClose(t, "unix", addr) | |
+} | |
+ | |
// Check that listeners exit if the underlying client transport dies. | |
-func TestPortForwardConnectionClose(t *testing.T) { | |
+func testPortForwardConnectionClose(t *testing.T, n, listenAddr string) { | |
server := newServer(t) | |
defer server.Shutdown() | |
conn := server.Dial(clientConfig()) | |
- sshListener, err := conn.Listen("tcp", "localhost:0") | |
+ sshListener, err := conn.Listen(n, listenAddr) | |
if err != nil { | |
t.Fatal(err) | |
} | |
@@ -158,3 +182,13 @@ func TestPortForwardConnectionClose(t *testing.T) { | |
t.Logf("quit as expected (error %v)", err) | |
} | |
} | |
+ | |
+func TestPortForwardConnectionCloseTCP(t *testing.T) { | |
+ testPortForwardConnectionClose(t, "tcp", "localhost:0") | |
+} | |
+ | |
+func TestPortForwardConnectionCloseUnix(t *testing.T) { | |
+ addr, cleanup := newTempSocket(t) | |
+ defer cleanup() | |
+ testPortForwardConnectionClose(t, "unix", addr) | |
+} | |
diff --git a/ssh/test/tcpip_test.go b/ssh/test/tcpip_test.go | |
deleted file mode 100644 | |
index a2eb935..0000000 | |
--- a/ssh/test/tcpip_test.go | |
+++ /dev/null | |
@@ -1,46 +0,0 @@ | |
-// Copyright 2012 The Go Authors. All rights reserved. | |
-// Use of this source code is governed by a BSD-style | |
-// license that can be found in the LICENSE file. | |
- | |
-// +build !windows | |
- | |
-package test | |
- | |
-// direct-tcpip functional tests | |
- | |
-import ( | |
- "io" | |
- "net" | |
- "testing" | |
-) | |
- | |
-func TestDial(t *testing.T) { | |
- server := newServer(t) | |
- defer server.Shutdown() | |
- sshConn := server.Dial(clientConfig()) | |
- defer sshConn.Close() | |
- | |
- l, err := net.Listen("tcp", "127.0.0.1:0") | |
- if err != nil { | |
- t.Fatalf("Listen: %v", err) | |
- } | |
- defer l.Close() | |
- | |
- go func() { | |
- for { | |
- c, err := l.Accept() | |
- if err != nil { | |
- break | |
- } | |
- | |
- io.WriteString(c, c.RemoteAddr().String()) | |
- c.Close() | |
- } | |
- }() | |
- | |
- conn, err := sshConn.Dial("tcp", l.Addr().String()) | |
- if err != nil { | |
- t.Fatalf("Dial: %v", err) | |
- } | |
- defer conn.Close() | |
-} | |
diff --git a/ssh/test/test_unix_test.go b/ssh/test/test_unix_test.go | |
index 3bfd881..e673536 100644 | |
--- a/ssh/test/test_unix_test.go | |
+++ b/ssh/test/test_unix_test.go | |
@@ -30,6 +30,7 @@ Protocol 2 | |
HostKey {{.Dir}}/id_rsa | |
HostKey {{.Dir}}/id_dsa | |
HostKey {{.Dir}}/id_ecdsa | |
+HostCertificate {{.Dir}}/id_rsa-cert.pub | |
Pidfile {{.Dir}}/sshd.pid | |
#UsePrivilegeSeparation no | |
KeyRegenerationInterval 3600 | |
@@ -119,6 +120,11 @@ func clientConfig() *ssh.ClientConfig { | |
ssh.PublicKeys(testSigners["user"]), | |
}, | |
HostKeyCallback: hostKeyDB().Check, | |
+ HostKeyAlgorithms: []string{ // by default, don't allow certs as this affects the hostKeyDB checker | |
+ ssh.KeyAlgoECDSA256, ssh.KeyAlgoECDSA384, ssh.KeyAlgoECDSA521, | |
+ ssh.KeyAlgoRSA, ssh.KeyAlgoDSA, | |
+ ssh.KeyAlgoED25519, | |
+ }, | |
} | |
return config | |
} | |
@@ -154,6 +160,12 @@ func unixConnection() (*net.UnixConn, *net.UnixConn, error) { | |
} | |
func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { | |
+ return s.TryDialWithAddr(config, "") | |
+} | |
+ | |
+// TryDialWithAddr is like TryDial, but addr is the user-specified host:port. | |
+// It is never actually dialed; it is only needed for host key matching. | |
+func (s *server) TryDialWithAddr(config *ssh.ClientConfig, addr string) (*ssh.Client, error) { | |
sshd, err := exec.LookPath("sshd") | |
if err != nil { | |
s.t.Skipf("skipping test: %v", err) | |
@@ -179,7 +191,7 @@ func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) { | |
s.t.Fatalf("s.cmd.Start: %v", err) | |
} | |
s.clientConn = c1 | |
- conn, chans, reqs, err := ssh.NewClientConn(c1, "", config) | |
+ conn, chans, reqs, err := ssh.NewClientConn(c1, addr, config) | |
if err != nil { | |
return nil, err | |
} | |
@@ -250,6 +262,11 @@ func newServer(t *testing.T) *server { | |
writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k])) | |
} | |
+ for k, v := range testdata.SSHCertificates { | |
+ filename := "id_" + k + "-cert.pub" | |
+ writeFile(filepath.Join(dir, filename), v) | |
+ } | |
+ | |
var authkeys bytes.Buffer | |
for k, _ := range testdata.PEMBytes { | |
authkeys.Write(ssh.MarshalAuthorizedKey(testPublicKeys[k])) | |
@@ -266,3 +283,13 @@ func newServer(t *testing.T) *server { | |
}, | |
} | |
} | |
+ | |
+func newTempSocket(t *testing.T) (string, func()) { | |
+ dir, err := ioutil.TempDir("", "socket") | |
+ if err != nil { | |
+ t.Fatal(err) | |
+ } | |
+ deferFunc := func() { os.RemoveAll(dir) } | |
+ addr := filepath.Join(dir, "sock") | |
+ return addr, deferFunc | |
+} | |
diff --git a/ssh/testdata/keys.go b/ssh/testdata/keys.go | |
index 736dad9..3b3d26c 100644 | |
--- a/ssh/testdata/keys.go | |
+++ b/ssh/testdata/keys.go | |
@@ -48,12 +48,69 @@ AAAEAaYmXltfW6nhRo3iWGglRB48lYq0z0Q3I3KyrdutEr6j7d/uFLuDlRbBc4ZVOsx+Gb | |
HKuOrPtLHFvHsjWPwO+/AAAAE2dhcnRvbm1AZ2FydG9ubS14cHMBAg== | |
-----END OPENSSH PRIVATE KEY----- | |
`), | |
+ "rsa-openssh-format": []byte(`-----BEGIN OPENSSH PRIVATE KEY----- | |
+b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAlwAAAAdzc2gtcn | |
+NhAAAAAwEAAQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5l | |
+oEuW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lz | |
+a+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAIQWL0H31i9B98AAAAH | |
+c3NoLXJzYQAAAIEAwa48yfWFi3uIdqzuf9X7C2Zxfea/Iaaw0zIwHudpF8U92WVIiC5loE | |
+uW1+OaVi3UWfIEjWMV1tHGysrHOwtwc34BPCJqJknUQO/KtDTBTJ4Pryhw1bWPC999Lza+ | |
+yrCTdNQYBzoROXKExZgPFh9pTMi5wqpHDuOQ2qZFIEI3lT0AAAADAQABAAAAgCThyTGsT4 | |
+IARDxVMhWl6eiB2ZrgFgWSeJm/NOqtppWgOebsIqPMMg4UVuVFsl422/lE3RkPhVkjGXgE | |
+pWvZAdCnmLmApK8wK12vF334lZhZT7t3Z9EzJps88PWEHo7kguf285HcnUM7FlFeissJdk | |
+kXly34y7/3X/a6Tclm+iABAAAAQE0xR/KxZ39slwfMv64Rz7WKk1PPskaryI29aHE3mKHk | |
+pY2QA+P3QlrKxT/VWUMjHUbNNdYfJm48xu0SGNMRdKMAAABBAORh2NP/06JUV3J9W/2Hju | |
+X1ViJuqqcQnJPVzpgSL826EC2xwOECTqoY8uvFpUdD7CtpksIxNVqRIhuNOlz0lqEAAABB | |
+ANkaHTTaPojClO0dKJ/Zjs7pWOCGliebBYprQ/Y4r9QLBkC/XaWMS26gFIrjgC7D2Rv+rZ | |
+wSD0v0RcmkITP1ZR0AAAAYcHF1ZXJuYUBMdWNreUh5ZHJvLmxvY2FsAQID | |
+-----END OPENSSH PRIVATE KEY-----`), | |
"user": []byte(`-----BEGIN EC PRIVATE KEY----- | |
MHcCAQEEILYCAeq8f7V4vSSypRw7pxy8yz3V5W4qg8kSC3zJhqpQoAoGCCqGSM49 | |
AwEHoUQDQgAEYcO2xNKiRUYOLEHM7VYAp57HNyKbOdYtHD83Z4hzNPVC4tM5mdGD | |
PLL8IEwvYu2wq+lpXfGQnNMbzYf9gspG0w== | |
-----END EC PRIVATE KEY----- | |
`), | |
+ "ca": []byte(`-----BEGIN RSA PRIVATE KEY----- | |
+MIIEpAIBAAKCAQEAvg9dQ9IRG59lYJb+GESfKWTch4yBpr7Ydw1jkK6vvtrx9jLo | |
+5hkA8X6+ElRPRqTAZSlN5cBm6YCAcQIOsmXDUn6Oj1lVPQAoOjTBTvsjM3NjGhvv | |
+52kHTY0nsMsBeY9q5DTtlzmlYkVUq2a6Htgf2mNi01dIw5fJ7uTTo8EbNf7O0i3u | |
+c9a8P19HaZl5NKiWN4EIZkfB2WdXYRJCVBsGgQj3dE/GrEmH9QINq1A+GkNvK96u | |
+vZm8H1jjmuqzHplWa7lFeXcx8FTVTbVb/iJrZ2Lc/JvIPitKZWhqbR59yrGjpwEp | |
+Id7bo4WhO5L3OB0fSIJYvfu+o4WYnt4f3UzecwIDAQABAoIBABRD9yHgKErVuC2Q | |
+bA+SYZY8VvdtF/X7q4EmQFORDNRA7EPgMc03JU6awRGbQ8i4kHs46EFzPoXvWcKz | |
+AXYsO6N0Myc900Tp22A5d9NAHATEbPC/wdje7hRq1KyZONMJY9BphFv3nZbY5apR | |
+Dc90JBFZP5RhXjTc3n9GjvqLAKfFEKVmPRCvqxCOZunw6XR+SgIQLJo36nsIsbhW | |
+QUXIVaCI6cXMN8bRPm8EITdBNZu06Fpu4ZHm6VaxlXN9smERCDkgBSNXNWHKxmmA | |
+c3Glo2DByUr2/JFBOrLEe9fkYgr24KNCQkHVcSaFxEcZvTggr7StjKISVHlCNEaB | |
+7Q+kPoECgYEA3zE9FmvFGoQCU4g4Nl3dpQHs6kaAW8vJlrmq3xsireIuaJoa2HMe | |
+wYdIvgCnK9DIjyxd5OWnE4jXtAEYPsyGD32B5rSLQrRO96lgb3f4bESCLUb3Bsn/ | |
+sdgeE3p1xZMA0B59htqCrvVgN9k8WxyevBxYl3/gSBm/p8OVH1RTW/ECgYEA2f9Z | |
+95OLj0KQHQtxQXf+I3VjhCw3LkLW39QZOXVI0QrCJfqqP7uxsJXH9NYX0l0GFTcR | |
+kRrlyoaSU1EGQosZh+n1MvplGBTkTSV47/bPsTzFpgK2NfEZuFm9RoWgltS+nYeH | |
+Y2k4mnAN3PhReCMwuprmJz8GRLsO3Cs2s2YylKMCgYEA2UX+uO/q7jgqZ5UJW+ue | |
+1H5+W0aMuFA3i7JtZEnvRaUVFqFGlwXin/WJ2+WY1++k/rPrJ+Rk9IBXtBUIvEGw | |
+FC5TIfsKQsJyyWgqx/jbbtJ2g4s8+W/1qfTAuqeRNOg5d2DnRDs90wJuS4//0JaY | |
+9HkHyVwkQyxFxhSA/AHEMJECgYA2MvyFR1O9bIk0D3I7GsA+xKLXa77Ua53MzIjw | |
+9i4CezBGDQpjCiFli/fI8am+jY5DnAtsDknvjoG24UAzLy5L0mk6IXMdB6SzYYut | |
+7ak5oahqW+Y9hxIj+XvLmtGQbphtxhJtLu35x75KoBpxSh6FZpmuTEccs31AVCYn | |
+eFM/DQKBgQDOPUwbLKqVi6ddFGgrV9MrWw+SWsDa43bPuyvYppMM3oqesvyaX1Dt | |
+qDvN7owaNxNM4OnfKcZr91z8YPVCFo4RbBif3DXRzjNNBlxEjHBtuMOikwvsmucN | |
+vIrbeEpjTiUMTEAr6PoTiVHjsfS8WAM6MDlF5M+2PNswDsBpa2yLgA== | |
+-----END RSA PRIVATE KEY----- | |
+`), | |
+} | |
+ | |
+var SSHCertificates = map[string][]byte{ | |
+ // The following are corresponding certificates for the private keys above, signed by the CA key | |
+ // Generated by the following commands: | |
+ // | |
+ // 1. Assumes "rsa" key above in file named "rsa", write out the public key to "rsa.pub": | |
+ // ssh-keygen -y -f rsa > rsa.pub | |
+ // | |
+ // 2. Assumes "ca" key above in file named "ca", sign a cert for "rsa.pub": | |
+ // ssh-keygen -s ca -h -n host.example.com -V +500w -I host.example.com-key rsa.pub | |
+ "rsa": []byte(`ssh-rsa-cert-v01@openssh.com AAAAHHNzaC1yc2EtY2VydC12MDFAb3BlbnNzaC5jb20AAAAgLjYqmmuTSEmjVhSfLQphBSTJMLwIZhRgmpn8FHKLiEIAAAADAQABAAAAgQC8A6FGHDiWCSREAXCq6yBfNVr0xCVG2CzvktFNRpue+RXrGs/2a6ySEJQb3IYquw7HlJgu6fg3WIWhOmHCjfpG0PrL4CRwbqQ2LaPPXhJErWYejcD8Di00cF3677+G10KMZk9RXbmHtuBFZT98wxg8j+ZsBMqGM1+7yrWUvynswQAAAAAAAAAAAAAAAgAAABRob3N0LmV4YW1wbGUuY29tLWtleQAAABQAAAAQaG9zdC5leGFtcGxlLmNvbQAAAABZHN8UAAAAAGsjIYUAAAAAAAAAAAAAAAAAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQC+D11D0hEbn2Vglv4YRJ8pZNyHjIGmvth3DWOQrq++2vH2MujmGQDxfr4SVE9GpMBlKU3lwGbpgIBxAg6yZcNSfo6PWVU9ACg6NMFO+yMzc2MaG+/naQdNjSewywF5j2rkNO2XOaViRVSrZroe2B/aY2LTV0jDl8nu5NOjwRs1/s7SLe5z1rw/X0dpmXk0qJY3gQhmR8HZZ1dhEkJUGwaBCPd0T8asSYf1Ag2rUD4aQ28r3q69mbwfWOOa6rMemVZruUV5dzHwVNVNtVv+ImtnYtz8m8g+K0plaGptHn3KsaOnASkh3tujhaE7kvc4HR9Igli9+76jhZie3h/dTN5zAAABDwAAAAdzc2gtcnNhAAABALeDea+60H6xJGhktAyosHaSY7AYzLocaqd8hJQjEIDifBwzoTlnBmcK9CxGhKuaoJFThdCLdaevCeOSuquh8HTkf+2ebZZc/G5T+2thPvPqmcuEcmMosWo+SIjYhbP3S6KD49aLC1X0kz8IBQeauFvURhkZ5ZjhA1L4aQYt9NjL73nqOl8PplRui+Ov5w8b4ldul4zOvYAFrzfcP6wnnXk3c1Zzwwf5wynD5jakO8GpYKBuhM7Z4crzkKSQjU3hla7xqgfomC5Gz4XbR2TNjcQiRrJQ0UlKtX3X3ObRCEhuvG0Kzjklhv+Ddw6txrhKjMjiSi/Yyius/AE8TmC1p4U= host.example.com | |
+`), | |
} | |
var PEMEncryptedKeys = []struct { |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment