mux_test.go 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "io"
  7. "io/ioutil"
  8. "sync"
  9. "testing"
  10. )
  11. func muxPair() (*mux, *mux) {
  12. a, b := memPipe()
  13. s := newMux(a)
  14. c := newMux(b)
  15. return s, c
  16. }
  17. // Returns both ends of a channel, and the mux for the the 2nd
  18. // channel.
  19. func channelPair(t *testing.T) (*channel, *channel, *mux) {
  20. c, s := muxPair()
  21. res := make(chan *channel, 1)
  22. go func() {
  23. newCh, ok := <-s.incomingChannels
  24. if !ok {
  25. t.Fatalf("No incoming channel")
  26. }
  27. if newCh.ChannelType() != "chan" {
  28. t.Fatalf("got type %q want chan", newCh.ChannelType())
  29. }
  30. ch, _, err := newCh.Accept()
  31. if err != nil {
  32. t.Fatalf("Accept %v", err)
  33. }
  34. res <- ch.(*channel)
  35. }()
  36. ch, err := c.openChannel("chan", nil)
  37. if err != nil {
  38. t.Fatalf("OpenChannel: %v", err)
  39. }
  40. return <-res, ch, c
  41. }
  42. // Test that stderr and stdout can be addressed from different
  43. // goroutines. This is intended for use with the race detector.
  44. func TestMuxChannelExtendedThreadSafety(t *testing.T) {
  45. writer, reader, mux := channelPair(t)
  46. defer writer.Close()
  47. defer reader.Close()
  48. defer mux.Close()
  49. var wr, rd sync.WaitGroup
  50. magic := "hello world"
  51. wr.Add(2)
  52. go func() {
  53. io.WriteString(writer, magic)
  54. wr.Done()
  55. }()
  56. go func() {
  57. io.WriteString(writer.Stderr(), magic)
  58. wr.Done()
  59. }()
  60. rd.Add(2)
  61. go func() {
  62. c, err := ioutil.ReadAll(reader)
  63. if string(c) != magic {
  64. t.Fatalf("stdout read got %q, want %q (error %s)", c, magic, err)
  65. }
  66. rd.Done()
  67. }()
  68. go func() {
  69. c, err := ioutil.ReadAll(reader.Stderr())
  70. if string(c) != magic {
  71. t.Fatalf("stderr read got %q, want %q (error %s)", c, magic, err)
  72. }
  73. rd.Done()
  74. }()
  75. wr.Wait()
  76. writer.CloseWrite()
  77. rd.Wait()
  78. }
  79. func TestMuxReadWrite(t *testing.T) {
  80. s, c, mux := channelPair(t)
  81. defer s.Close()
  82. defer c.Close()
  83. defer mux.Close()
  84. magic := "hello world"
  85. magicExt := "hello stderr"
  86. go func() {
  87. _, err := s.Write([]byte(magic))
  88. if err != nil {
  89. t.Fatalf("Write: %v", err)
  90. }
  91. _, err = s.Extended(1).Write([]byte(magicExt))
  92. if err != nil {
  93. t.Fatalf("Write: %v", err)
  94. }
  95. err = s.Close()
  96. if err != nil {
  97. t.Fatalf("Close: %v", err)
  98. }
  99. }()
  100. var buf [1024]byte
  101. n, err := c.Read(buf[:])
  102. if err != nil {
  103. t.Fatalf("server Read: %v", err)
  104. }
  105. got := string(buf[:n])
  106. if got != magic {
  107. t.Fatalf("server: got %q want %q", got, magic)
  108. }
  109. n, err = c.Extended(1).Read(buf[:])
  110. if err != nil {
  111. t.Fatalf("server Read: %v", err)
  112. }
  113. got = string(buf[:n])
  114. if got != magicExt {
  115. t.Fatalf("server: got %q want %q", got, magic)
  116. }
  117. }
  118. func TestMuxChannelOverflow(t *testing.T) {
  119. reader, writer, mux := channelPair(t)
  120. defer reader.Close()
  121. defer writer.Close()
  122. defer mux.Close()
  123. wDone := make(chan int, 1)
  124. go func() {
  125. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  126. t.Errorf("could not fill window: %v", err)
  127. }
  128. writer.Write(make([]byte, 1))
  129. wDone <- 1
  130. }()
  131. writer.remoteWin.waitWriterBlocked()
  132. // Send 1 byte.
  133. packet := make([]byte, 1+4+4+1)
  134. packet[0] = msgChannelData
  135. marshalUint32(packet[1:], writer.remoteId)
  136. marshalUint32(packet[5:], uint32(1))
  137. packet[9] = 42
  138. if err := writer.mux.conn.writePacket(packet); err != nil {
  139. t.Errorf("could not send packet")
  140. }
  141. if _, err := reader.SendRequest("hello", true, nil); err == nil {
  142. t.Errorf("SendRequest succeeded.")
  143. }
  144. <-wDone
  145. }
  146. func TestMuxChannelCloseWriteUnblock(t *testing.T) {
  147. reader, writer, mux := channelPair(t)
  148. defer reader.Close()
  149. defer writer.Close()
  150. defer mux.Close()
  151. wDone := make(chan int, 1)
  152. go func() {
  153. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  154. t.Errorf("could not fill window: %v", err)
  155. }
  156. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  157. t.Errorf("got %v, want EOF for unblock write", err)
  158. }
  159. wDone <- 1
  160. }()
  161. writer.remoteWin.waitWriterBlocked()
  162. reader.Close()
  163. <-wDone
  164. }
  165. func TestMuxConnectionCloseWriteUnblock(t *testing.T) {
  166. reader, writer, mux := channelPair(t)
  167. defer reader.Close()
  168. defer writer.Close()
  169. defer mux.Close()
  170. wDone := make(chan int, 1)
  171. go func() {
  172. if _, err := writer.Write(make([]byte, channelWindowSize)); err != nil {
  173. t.Errorf("could not fill window: %v", err)
  174. }
  175. if _, err := writer.Write(make([]byte, 1)); err != io.EOF {
  176. t.Errorf("got %v, want EOF for unblock write", err)
  177. }
  178. wDone <- 1
  179. }()
  180. writer.remoteWin.waitWriterBlocked()
  181. mux.Close()
  182. <-wDone
  183. }
  184. func TestMuxReject(t *testing.T) {
  185. client, server := muxPair()
  186. defer server.Close()
  187. defer client.Close()
  188. go func() {
  189. ch, ok := <-server.incomingChannels
  190. if !ok {
  191. t.Fatalf("Accept")
  192. }
  193. if ch.ChannelType() != "ch" || string(ch.ExtraData()) != "extra" {
  194. t.Fatalf("unexpected channel: %q, %q", ch.ChannelType(), ch.ExtraData())
  195. }
  196. ch.Reject(RejectionReason(42), "message")
  197. }()
  198. ch, err := client.openChannel("ch", []byte("extra"))
  199. if ch != nil {
  200. t.Fatal("openChannel not rejected")
  201. }
  202. ocf, ok := err.(*OpenChannelError)
  203. if !ok {
  204. t.Errorf("got %#v want *OpenChannelError", err)
  205. } else if ocf.Reason != 42 || ocf.Message != "message" {
  206. t.Errorf("got %#v, want {Reason: 42, Message: %q}", ocf, "message")
  207. }
  208. want := "ssh: rejected: unknown reason 42 (message)"
  209. if err.Error() != want {
  210. t.Errorf("got %q, want %q", err.Error(), want)
  211. }
  212. }
  213. func TestMuxChannelRequest(t *testing.T) {
  214. client, server, mux := channelPair(t)
  215. defer server.Close()
  216. defer client.Close()
  217. defer mux.Close()
  218. var received int
  219. var wg sync.WaitGroup
  220. wg.Add(1)
  221. go func() {
  222. for r := range server.incomingRequests {
  223. received++
  224. r.Reply(r.Type == "yes", nil)
  225. }
  226. wg.Done()
  227. }()
  228. _, err := client.SendRequest("yes", false, nil)
  229. if err != nil {
  230. t.Fatalf("SendRequest: %v", err)
  231. }
  232. ok, err := client.SendRequest("yes", true, nil)
  233. if err != nil {
  234. t.Fatalf("SendRequest: %v", err)
  235. }
  236. if !ok {
  237. t.Errorf("SendRequest(yes): %v", ok)
  238. }
  239. ok, err = client.SendRequest("no", true, nil)
  240. if err != nil {
  241. t.Fatalf("SendRequest: %v", err)
  242. }
  243. if ok {
  244. t.Errorf("SendRequest(no): %v", ok)
  245. }
  246. client.Close()
  247. wg.Wait()
  248. if received != 3 {
  249. t.Errorf("got %d requests, want %d", received, 3)
  250. }
  251. }
  252. func TestMuxGlobalRequest(t *testing.T) {
  253. clientMux, serverMux := muxPair()
  254. defer serverMux.Close()
  255. defer clientMux.Close()
  256. var seen bool
  257. go func() {
  258. for r := range serverMux.incomingRequests {
  259. seen = seen || r.Type == "peek"
  260. if r.WantReply {
  261. err := r.Reply(r.Type == "yes",
  262. append([]byte(r.Type), r.Payload...))
  263. if err != nil {
  264. t.Errorf("AckRequest: %v", err)
  265. }
  266. }
  267. }
  268. }()
  269. _, _, err := clientMux.SendRequest("peek", false, nil)
  270. if err != nil {
  271. t.Errorf("SendRequest: %v", err)
  272. }
  273. ok, data, err := clientMux.SendRequest("yes", true, []byte("a"))
  274. if !ok || string(data) != "yesa" || err != nil {
  275. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  276. ok, data, err)
  277. }
  278. if ok, data, err := clientMux.SendRequest("yes", true, []byte("a")); !ok || string(data) != "yesa" || err != nil {
  279. t.Errorf("SendRequest(\"yes\", true, \"a\"): %v %v %v",
  280. ok, data, err)
  281. }
  282. if ok, data, err := clientMux.SendRequest("no", true, []byte("a")); ok || string(data) != "noa" || err != nil {
  283. t.Errorf("SendRequest(\"no\", true, \"a\"): %v %v %v",
  284. ok, data, err)
  285. }
  286. clientMux.Disconnect(0, "")
  287. if !seen {
  288. t.Errorf("never saw 'peek' request")
  289. }
  290. }
  291. func TestMuxGlobalRequestUnblock(t *testing.T) {
  292. clientMux, serverMux := muxPair()
  293. defer serverMux.Close()
  294. defer clientMux.Close()
  295. result := make(chan error, 1)
  296. go func() {
  297. _, _, err := clientMux.SendRequest("hello", true, nil)
  298. result <- err
  299. }()
  300. <-serverMux.incomingRequests
  301. serverMux.conn.Close()
  302. err := <-result
  303. if err != io.EOF {
  304. t.Errorf("want EOF, got %v", io.EOF)
  305. }
  306. }
  307. func TestMuxChannelRequestUnblock(t *testing.T) {
  308. a, b, connB := channelPair(t)
  309. defer a.Close()
  310. defer b.Close()
  311. defer connB.Close()
  312. result := make(chan error, 1)
  313. go func() {
  314. _, err := a.SendRequest("hello", true, nil)
  315. result <- err
  316. }()
  317. <-b.incomingRequests
  318. connB.conn.Close()
  319. err := <-result
  320. if err != io.EOF {
  321. t.Errorf("want EOF, got %v", err)
  322. }
  323. }
  324. func TestMuxDisconnect(t *testing.T) {
  325. a, b := muxPair()
  326. defer a.Close()
  327. defer b.Close()
  328. go func() {
  329. for r := range b.incomingRequests {
  330. r.Reply(true, nil)
  331. }
  332. }()
  333. a.Disconnect(42, "whatever")
  334. ok, _, err := a.SendRequest("hello", true, nil)
  335. if ok || err == nil {
  336. t.Errorf("got reply after disconnecting")
  337. }
  338. err = b.Wait()
  339. if d, ok := err.(*disconnectMsg); !ok || d.Reason != 42 {
  340. t.Errorf("got %#v, want disconnectMsg{Reason:42}", err)
  341. }
  342. }
  343. func TestMuxCloseChannel(t *testing.T) {
  344. r, w, mux := channelPair(t)
  345. defer mux.Close()
  346. defer r.Close()
  347. defer w.Close()
  348. result := make(chan error, 1)
  349. go func() {
  350. var b [1024]byte
  351. _, err := r.Read(b[:])
  352. result <- err
  353. }()
  354. if err := w.Close(); err != nil {
  355. t.Errorf("w.Close: %v", err)
  356. }
  357. if _, err := w.Write([]byte("hello")); err != io.EOF {
  358. t.Errorf("got err %v, want io.EOF after Close", err)
  359. }
  360. if err := <-result; err != io.EOF {
  361. t.Errorf("got %v (%T), want io.EOF", err, err)
  362. }
  363. }
  364. func TestMuxCloseWriteChannel(t *testing.T) {
  365. r, w, mux := channelPair(t)
  366. defer mux.Close()
  367. result := make(chan error, 1)
  368. go func() {
  369. var b [1024]byte
  370. _, err := r.Read(b[:])
  371. result <- err
  372. }()
  373. if err := w.CloseWrite(); err != nil {
  374. t.Errorf("w.CloseWrite: %v", err)
  375. }
  376. if _, err := w.Write([]byte("hello")); err != io.EOF {
  377. t.Errorf("got err %v, want io.EOF after CloseWrite", err)
  378. }
  379. if err := <-result; err != io.EOF {
  380. t.Errorf("got %v (%T), want io.EOF", err, err)
  381. }
  382. }
  383. func TestMuxInvalidRecord(t *testing.T) {
  384. a, b := muxPair()
  385. defer a.Close()
  386. defer b.Close()
  387. packet := make([]byte, 1+4+4+1)
  388. packet[0] = msgChannelData
  389. marshalUint32(packet[1:], 29348723 /* invalid channel id */)
  390. marshalUint32(packet[5:], 1)
  391. packet[9] = 42
  392. a.conn.writePacket(packet)
  393. go a.SendRequest("hello", false, nil)
  394. // 'a' wrote an invalid packet, so 'b' has exited.
  395. req, ok := <-b.incomingRequests
  396. if ok {
  397. t.Errorf("got request %#v after receiving invalid packet", req)
  398. }
  399. }
  400. func TestZeroWindowAdjust(t *testing.T) {
  401. a, b, mux := channelPair(t)
  402. defer a.Close()
  403. defer b.Close()
  404. defer mux.Close()
  405. go func() {
  406. io.WriteString(a, "hello")
  407. // bogus adjust.
  408. a.sendMessage(windowAdjustMsg{})
  409. io.WriteString(a, "world")
  410. a.Close()
  411. }()
  412. want := "helloworld"
  413. c, _ := ioutil.ReadAll(b)
  414. if string(c) != want {
  415. t.Errorf("got %q want %q", c, want)
  416. }
  417. }
  418. func TestMuxMaxPacketSize(t *testing.T) {
  419. a, b, mux := channelPair(t)
  420. defer a.Close()
  421. defer b.Close()
  422. defer mux.Close()
  423. large := make([]byte, a.maxRemotePayload+1)
  424. packet := make([]byte, 1+4+4+1+len(large))
  425. packet[0] = msgChannelData
  426. marshalUint32(packet[1:], a.remoteId)
  427. marshalUint32(packet[5:], uint32(len(large)))
  428. packet[9] = 42
  429. if err := a.mux.conn.writePacket(packet); err != nil {
  430. t.Errorf("could not send packet")
  431. }
  432. go a.SendRequest("hello", false, nil)
  433. _, ok := <-b.incomingRequests
  434. if ok {
  435. t.Errorf("connection still alive after receiving large packet.")
  436. }
  437. }
  438. // Don't ship code with debug=true.
  439. func TestDebug(t *testing.T) {
  440. if debugMux {
  441. t.Error("mux debug switched on")
  442. }
  443. if debugHandshake {
  444. t.Error("handshake debug switched on")
  445. }
  446. }