forward_unix_test.go 3.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. // Copyright 2012 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. // +build darwin dragonfly freebsd linux netbsd openbsd
  5. package test
  6. import (
  7. "bytes"
  8. "io"
  9. "io/ioutil"
  10. "math/rand"
  11. "net"
  12. "testing"
  13. "time"
  14. )
  15. func TestPortForward(t *testing.T) {
  16. server := newServer(t)
  17. defer server.Shutdown()
  18. conn := server.Dial(clientConfig())
  19. defer conn.Close()
  20. sshListener, err := conn.Listen("tcp", "localhost:0")
  21. if err != nil {
  22. t.Fatal(err)
  23. }
  24. go func() {
  25. sshConn, err := sshListener.Accept()
  26. if err != nil {
  27. t.Fatalf("listen.Accept failed: %v", err)
  28. }
  29. _, err = io.Copy(sshConn, sshConn)
  30. if err != nil && err != io.EOF {
  31. t.Fatalf("ssh client copy: %v", err)
  32. }
  33. sshConn.Close()
  34. }()
  35. forwardedAddr := sshListener.Addr().String()
  36. tcpConn, err := net.Dial("tcp", forwardedAddr)
  37. if err != nil {
  38. t.Fatalf("TCP dial failed: %v", err)
  39. }
  40. readChan := make(chan []byte)
  41. go func() {
  42. data, _ := ioutil.ReadAll(tcpConn)
  43. readChan <- data
  44. }()
  45. // Invent some data.
  46. data := make([]byte, 100*1000)
  47. for i := range data {
  48. data[i] = byte(i % 255)
  49. }
  50. var sent []byte
  51. for len(sent) < 1000*1000 {
  52. // Send random sized chunks
  53. m := rand.Intn(len(data))
  54. n, err := tcpConn.Write(data[:m])
  55. if err != nil {
  56. break
  57. }
  58. sent = append(sent, data[:n]...)
  59. }
  60. if err := tcpConn.(*net.TCPConn).CloseWrite(); err != nil {
  61. t.Errorf("tcpConn.CloseWrite: %v", err)
  62. }
  63. read := <-readChan
  64. if len(sent) != len(read) {
  65. t.Fatalf("got %d bytes, want %d", len(read), len(sent))
  66. }
  67. if bytes.Compare(sent, read) != 0 {
  68. t.Fatalf("read back data does not match")
  69. }
  70. if err := sshListener.Close(); err != nil {
  71. t.Fatalf("sshListener.Close: %v", err)
  72. }
  73. // Check that the forward disappeared.
  74. tcpConn, err = net.Dial("tcp", forwardedAddr)
  75. if err == nil {
  76. tcpConn.Close()
  77. t.Errorf("still listening to %s after closing", forwardedAddr)
  78. }
  79. }
  80. func TestAcceptClose(t *testing.T) {
  81. server := newServer(t)
  82. defer server.Shutdown()
  83. conn := server.Dial(clientConfig())
  84. sshListener, err := conn.Listen("tcp", "localhost:0")
  85. if err != nil {
  86. t.Fatal(err)
  87. }
  88. quit := make(chan error, 1)
  89. go func() {
  90. for {
  91. c, err := sshListener.Accept()
  92. if err != nil {
  93. quit <- err
  94. break
  95. }
  96. c.Close()
  97. }
  98. }()
  99. sshListener.Close()
  100. select {
  101. case <-time.After(1 * time.Second):
  102. t.Errorf("timeout: listener did not close.")
  103. case err := <-quit:
  104. t.Logf("quit as expected (error %v)", err)
  105. }
  106. }
  107. // Check that listeners exit if the underlying client transport dies.
  108. func TestPortForwardConnectionClose(t *testing.T) {
  109. server := newServer(t)
  110. defer server.Shutdown()
  111. conn := server.Dial(clientConfig())
  112. sshListener, err := conn.Listen("tcp", "localhost:0")
  113. if err != nil {
  114. t.Fatal(err)
  115. }
  116. quit := make(chan error, 1)
  117. go func() {
  118. for {
  119. c, err := sshListener.Accept()
  120. if err != nil {
  121. quit <- err
  122. break
  123. }
  124. c.Close()
  125. }
  126. }()
  127. // It would be even nicer if we closed the server side, but it
  128. // is more involved as the fd for that side is dup()ed.
  129. server.clientConn.Close()
  130. select {
  131. case <-time.After(1 * time.Second):
  132. t.Errorf("timeout: listener did not close.")
  133. case err := <-quit:
  134. t.Logf("quit as expected (error %v)", err)
  135. }
  136. }