test_unix_test.go 5.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build darwin dragonfly freebsd linux netbsd openbsd plan9
  5. package test
  6. // functional test harness for unix.
  7. import (
  8. "bytes"
  9. "fmt"
  10. "io/ioutil"
  11. "log"
  12. "net"
  13. "os"
  14. "os/exec"
  15. "os/user"
  16. "path/filepath"
  17. "testing"
  18. "text/template"
  19. "golang.org/x/crypto/ssh"
  20. "golang.org/x/crypto/ssh/testdata"
  21. )
// sshd_config is the text/template source for the sshd configuration file
// that newServer writes into its temporary directory. The only template
// parameter is .Dir, the directory holding the generated key files: the
// host keys (id_rsa, id_dsa, id_ecdsa), the authorized user key
// (id_user.pub), the trusted CA key (id_ecdsa.pub) and the pid file.
const sshd_config = `
Protocol 2
HostKey {{.Dir}}/id_rsa
HostKey {{.Dir}}/id_dsa
HostKey {{.Dir}}/id_ecdsa
Pidfile {{.Dir}}/sshd.pid
#UsePrivilegeSeparation no
KeyRegenerationInterval 3600
ServerKeyBits 768
SyslogFacility AUTH
LogLevel DEBUG2
LoginGraceTime 120
PermitRootLogin no
StrictModes no
RSAAuthentication yes
PubkeyAuthentication yes
AuthorizedKeysFile {{.Dir}}/id_user.pub
TrustedUserCAKeys {{.Dir}}/id_ecdsa.pub
IgnoreRhosts yes
RhostsRSAAuthentication no
HostbasedAuthentication no
`
// configTmpl is sshd_config parsed once at package initialization.
// template.Must panics if the template text fails to parse.
var configTmpl = template.Must(template.New("").Parse(sshd_config))
// server drives a real sshd process on behalf of a single test. It is
// created by newServer, connected with Dial or TryDial, and torn down with
// Shutdown.
type server struct {
	// t is the test that owns this server; failures and skips are
	// reported through it.
	t *testing.T
	cleanup func() // executed during Shutdown
	// configfile is the path handed to sshd with the -f flag.
	configfile string
	// cmd is the sshd process; it is nil until TryDial has been called.
	cmd *exec.Cmd
	output bytes.Buffer // holds stderr from sshd process

	// Client half of the network connection.
	clientConn net.Conn
}
  54. func username() string {
  55. var username string
  56. if user, err := user.Current(); err == nil {
  57. username = user.Username
  58. } else {
  59. // user.Current() currently requires cgo. If an error is
  60. // returned attempt to get the username from the environment.
  61. log.Printf("user.Current: %v; falling back on $USER", err)
  62. username = os.Getenv("USER")
  63. }
  64. if username == "" {
  65. panic("Unable to get username")
  66. }
  67. return username
  68. }
// storedHostKey is an in-memory store of expected host keys, holding one
// key per algorithm. Keys are registered with Add and verified with Check.
type storedHostKey struct {
	// keys map from an algorithm string to binary key data.
	keys map[string][]byte

	// checkCount counts the Check calls. Used for testing
	// rekeying.
	checkCount int
}
  76. func (k *storedHostKey) Add(key ssh.PublicKey) {
  77. if k.keys == nil {
  78. k.keys = map[string][]byte{}
  79. }
  80. k.keys[key.Type()] = key.Marshal()
  81. }
  82. func (k *storedHostKey) Check(addr string, remote net.Addr, key ssh.PublicKey) error {
  83. k.checkCount++
  84. algo := key.Type()
  85. if k.keys == nil || bytes.Compare(key.Marshal(), k.keys[algo]) != 0 {
  86. return fmt.Errorf("host key mismatch. Got %q, want %q", key, k.keys[algo])
  87. }
  88. return nil
  89. }
  90. func hostKeyDB() *storedHostKey {
  91. keyChecker := &storedHostKey{}
  92. keyChecker.Add(testPublicKeys["ecdsa"])
  93. keyChecker.Add(testPublicKeys["rsa"])
  94. keyChecker.Add(testPublicKeys["dsa"])
  95. return keyChecker
  96. }
  97. func clientConfig() *ssh.ClientConfig {
  98. config := &ssh.ClientConfig{
  99. User: username(),
  100. Auth: []ssh.AuthMethod{
  101. ssh.PublicKeys(testSigners["user"]),
  102. },
  103. HostKeyCallback: hostKeyDB().Check,
  104. }
  105. return config
  106. }
  107. // unixConnection creates two halves of a connected net.UnixConn. It
  108. // is used for connecting the Go SSH client with sshd without opening
  109. // ports.
  110. func unixConnection() (*net.UnixConn, *net.UnixConn, error) {
  111. dir, err := ioutil.TempDir("", "unixConnection")
  112. if err != nil {
  113. return nil, nil, err
  114. }
  115. defer os.Remove(dir)
  116. addr := filepath.Join(dir, "ssh")
  117. listener, err := net.Listen("unix", addr)
  118. if err != nil {
  119. return nil, nil, err
  120. }
  121. defer listener.Close()
  122. c1, err := net.Dial("unix", addr)
  123. if err != nil {
  124. return nil, nil, err
  125. }
  126. c2, err := listener.Accept()
  127. if err != nil {
  128. c1.Close()
  129. return nil, nil, err
  130. }
  131. return c1.(*net.UnixConn), c2.(*net.UnixConn), nil
  132. }
  133. func (s *server) TryDial(config *ssh.ClientConfig) (*ssh.Client, error) {
  134. sshd, err := exec.LookPath("sshd")
  135. if err != nil {
  136. s.t.Skipf("skipping test: %v", err)
  137. }
  138. c1, c2, err := unixConnection()
  139. if err != nil {
  140. s.t.Fatalf("unixConnection: %v", err)
  141. }
  142. s.cmd = exec.Command(sshd, "-f", s.configfile, "-i", "-e")
  143. f, err := c2.File()
  144. if err != nil {
  145. s.t.Fatalf("UnixConn.File: %v", err)
  146. }
  147. defer f.Close()
  148. s.cmd.Stdin = f
  149. s.cmd.Stdout = f
  150. s.cmd.Stderr = &s.output
  151. if err := s.cmd.Start(); err != nil {
  152. s.t.Fail()
  153. s.Shutdown()
  154. s.t.Fatalf("s.cmd.Start: %v", err)
  155. }
  156. s.clientConn = c1
  157. conn, chans, reqs, err := ssh.NewClientConn(c1, "", config)
  158. if err != nil {
  159. return nil, err
  160. }
  161. return ssh.NewClient(conn, chans, reqs), nil
  162. }
  163. func (s *server) Dial(config *ssh.ClientConfig) *ssh.Client {
  164. conn, err := s.TryDial(config)
  165. if err != nil {
  166. s.t.Fail()
  167. s.Shutdown()
  168. s.t.Fatalf("ssh.Client: %v", err)
  169. }
  170. return conn
  171. }
  172. func (s *server) Shutdown() {
  173. if s.cmd != nil && s.cmd.Process != nil {
  174. // Don't check for errors; if it fails it's most
  175. // likely "os: process already finished", and we don't
  176. // care about that. Use os.Interrupt, so child
  177. // processes are killed too.
  178. s.cmd.Process.Signal(os.Interrupt)
  179. s.cmd.Wait()
  180. }
  181. if s.t.Failed() {
  182. // log any output from sshd process
  183. s.t.Logf("sshd: %s", s.output.String())
  184. }
  185. s.cleanup()
  186. }
  187. func writeFile(path string, contents []byte) {
  188. f, err := os.OpenFile(path, os.O_WRONLY|os.O_TRUNC|os.O_CREATE, 0600)
  189. if err != nil {
  190. panic(err)
  191. }
  192. defer f.Close()
  193. if _, err := f.Write(contents); err != nil {
  194. panic(err)
  195. }
  196. }
  197. // newServer returns a new mock ssh server.
  198. func newServer(t *testing.T) *server {
  199. if testing.Short() {
  200. t.Skip("skipping test due to -short")
  201. }
  202. dir, err := ioutil.TempDir("", "sshtest")
  203. if err != nil {
  204. t.Fatal(err)
  205. }
  206. f, err := os.Create(filepath.Join(dir, "sshd_config"))
  207. if err != nil {
  208. t.Fatal(err)
  209. }
  210. err = configTmpl.Execute(f, map[string]string{
  211. "Dir": dir,
  212. })
  213. if err != nil {
  214. t.Fatal(err)
  215. }
  216. f.Close()
  217. for k, v := range testdata.PEMBytes {
  218. filename := "id_" + k
  219. writeFile(filepath.Join(dir, filename), v)
  220. writeFile(filepath.Join(dir, filename+".pub"), ssh.MarshalAuthorizedKey(testPublicKeys[k]))
  221. }
  222. return &server{
  223. t: t,
  224. configfile: f.Name(),
  225. cleanup: func() {
  226. if err := os.RemoveAll(dir); err != nil {
  227. t.Error(err)
  228. }
  229. },
  230. }
  231. }