mux.go 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "encoding/binary"
  7. "fmt"
  8. "io"
  9. "log"
  10. "sync"
  11. "sync/atomic"
  12. )
  const (
	// debugMux turns on logging of connection-protocol messages through the
	// standard log package. When set, newMux also gives every mux's
	// chanList a distinct ID offset drawn from globalOff.
	debugMux = false
)
  16. // chanList is a thread safe channel list.
  17. type chanList struct {
  18. // protects concurrent access to chans
  19. sync.Mutex
  20. // chans are indexed by the local id of the channel, which the
  21. // other side should send in the PeersId field.
  22. chans []*channel
  23. // This is a debugging aid: it offsets all IDs by this
  24. // amount. This helps distinguish otherwise identical
  25. // server/client muxes
  26. offset uint32
  27. }
  28. // Assigns a channel ID to the given channel.
  29. func (c *chanList) add(ch *channel) uint32 {
  30. c.Lock()
  31. defer c.Unlock()
  32. for i := range c.chans {
  33. if c.chans[i] == nil {
  34. c.chans[i] = ch
  35. return uint32(i) + c.offset
  36. }
  37. }
  38. c.chans = append(c.chans, ch)
  39. return uint32(len(c.chans)-1) + c.offset
  40. }
  41. // getChan returns the channel for the given ID.
  42. func (c *chanList) getChan(id uint32) *channel {
  43. id -= c.offset
  44. c.Lock()
  45. defer c.Unlock()
  46. if id < uint32(len(c.chans)) {
  47. return c.chans[id]
  48. }
  49. return nil
  50. }
  51. func (c *chanList) remove(id uint32) {
  52. id -= c.offset
  53. c.Lock()
  54. if id < uint32(len(c.chans)) {
  55. c.chans[id] = nil
  56. }
  57. c.Unlock()
  58. }
  59. // dropAll forgets all channels it knows, returning them in a slice.
  60. func (c *chanList) dropAll() []*channel {
  61. c.Lock()
  62. defer c.Unlock()
  63. var r []*channel
  64. for _, ch := range c.chans {
  65. if ch == nil {
  66. continue
  67. }
  68. r = append(r, ch)
  69. }
  70. c.chans = nil
  71. return r
  72. }
  73. // mux represents the state for the SSH connection protocol, which
  74. // multiplexes many channels onto a single packet transport.
  75. type mux struct {
  76. conn packetConn
  77. chanList chanList
  78. incomingChannels chan NewChannel
  79. globalSentMu sync.Mutex
  80. globalResponses chan interface{}
  81. incomingRequests chan *Request
  82. errCond *sync.Cond
  83. err error
  84. }
  var (
	// globalOff is a process-wide counter. When debugMux is set, newMux
	// atomically increments it to give each new chanList a distinct ID
	// offset.
	globalOff uint32
)
  88. func (m *mux) Wait() error {
  89. m.errCond.L.Lock()
  90. defer m.errCond.L.Unlock()
  91. for m.err == nil {
  92. m.errCond.Wait()
  93. }
  94. return m.err
  95. }
  96. // newMux returns a mux that runs over the given connection.
  97. func newMux(p packetConn) *mux {
  98. m := &mux{
  99. conn: p,
  100. incomingChannels: make(chan NewChannel, 16),
  101. globalResponses: make(chan interface{}, 1),
  102. incomingRequests: make(chan *Request, 16),
  103. errCond: newCond(),
  104. }
  105. if debugMux {
  106. m.chanList.offset = atomic.AddUint32(&globalOff, 1)
  107. }
  108. go m.loop()
  109. return m
  110. }
  111. func (m *mux) sendMessage(msg interface{}) error {
  112. p := Marshal(msg)
  113. return m.conn.writePacket(p)
  114. }
  115. func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
  116. if wantReply {
  117. m.globalSentMu.Lock()
  118. defer m.globalSentMu.Unlock()
  119. }
  120. if err := m.sendMessage(globalRequestMsg{
  121. Type: name,
  122. WantReply: wantReply,
  123. Data: payload,
  124. }); err != nil {
  125. return false, nil, err
  126. }
  127. if !wantReply {
  128. return false, nil, nil
  129. }
  130. msg, ok := <-m.globalResponses
  131. if !ok {
  132. return false, nil, io.EOF
  133. }
  134. switch msg := msg.(type) {
  135. case *globalRequestFailureMsg:
  136. return false, msg.Data, nil
  137. case *globalRequestSuccessMsg:
  138. return true, msg.Data, nil
  139. default:
  140. return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
  141. }
  142. }
  143. // ackRequest must be called after processing a global request that
  144. // has WantReply set.
  145. func (m *mux) ackRequest(ok bool, data []byte) error {
  146. if ok {
  147. return m.sendMessage(globalRequestSuccessMsg{Data: data})
  148. }
  149. return m.sendMessage(globalRequestFailureMsg{Data: data})
  150. }
  151. // TODO(hanwen): Disconnect is a transport layer message. We should
  152. // probably send and receive Disconnect somewhere in the transport
  153. // code.
  154. // Disconnect sends a disconnect message.
  155. func (m *mux) Disconnect(reason uint32, message string) error {
  156. return m.sendMessage(disconnectMsg{
  157. Reason: reason,
  158. Message: message,
  159. })
  160. }
  161. func (m *mux) Close() error {
  162. return m.conn.Close()
  163. }
  164. // loop runs the connection machine. It will process packets until an
  165. // error is encountered. To synchronize on loop exit, use mux.Wait.
  166. func (m *mux) loop() {
  167. var err error
  168. for err == nil {
  169. err = m.onePacket()
  170. }
  171. for _, ch := range m.chanList.dropAll() {
  172. ch.close()
  173. }
  174. close(m.incomingChannels)
  175. close(m.incomingRequests)
  176. close(m.globalResponses)
  177. m.conn.Close()
  178. m.errCond.L.Lock()
  179. m.err = err
  180. m.errCond.Broadcast()
  181. m.errCond.L.Unlock()
  182. if debugMux {
  183. log.Println("loop exit", err)
  184. }
  185. }
  186. // onePacket reads and processes one packet.
  187. func (m *mux) onePacket() error {
  188. packet, err := m.conn.readPacket()
  189. if err != nil {
  190. return err
  191. }
  192. if debugMux {
  193. if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
  194. log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
  195. } else {
  196. p, _ := decode(packet)
  197. log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
  198. }
  199. }
  200. switch packet[0] {
  201. case msgNewKeys:
  202. // Ignore notification of key change.
  203. return nil
  204. case msgDisconnect:
  205. return m.handleDisconnect(packet)
  206. case msgChannelOpen:
  207. return m.handleChannelOpen(packet)
  208. case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
  209. return m.handleGlobalPacket(packet)
  210. }
  211. // assume a channel packet.
  212. if len(packet) < 5 {
  213. return parseError(packet[0])
  214. }
  215. id := binary.BigEndian.Uint32(packet[1:])
  216. ch := m.chanList.getChan(id)
  217. if ch == nil {
  218. return fmt.Errorf("ssh: invalid channel %d", id)
  219. }
  220. return ch.handlePacket(packet)
  221. }
  222. func (m *mux) handleDisconnect(packet []byte) error {
  223. var d disconnectMsg
  224. if err := Unmarshal(packet, &d); err != nil {
  225. return err
  226. }
  227. if debugMux {
  228. log.Printf("caught disconnect: %v", d)
  229. }
  230. return &d
  231. }
  232. func (m *mux) handleGlobalPacket(packet []byte) error {
  233. msg, err := decode(packet)
  234. if err != nil {
  235. return err
  236. }
  237. switch msg := msg.(type) {
  238. case *globalRequestMsg:
  239. m.incomingRequests <- &Request{
  240. Type: msg.Type,
  241. WantReply: msg.WantReply,
  242. Payload: msg.Data,
  243. mux: m,
  244. }
  245. case *globalRequestSuccessMsg, *globalRequestFailureMsg:
  246. m.globalResponses <- msg
  247. default:
  248. panic(fmt.Sprintf("not a global message %#v", msg))
  249. }
  250. return nil
  251. }
  252. // handleChannelOpen schedules a channel to be Accept()ed.
  253. func (m *mux) handleChannelOpen(packet []byte) error {
  254. var msg channelOpenMsg
  255. if err := Unmarshal(packet, &msg); err != nil {
  256. return err
  257. }
  258. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  259. failMsg := channelOpenFailureMsg{
  260. PeersId: msg.PeersId,
  261. Reason: ConnectionFailed,
  262. Message: "invalid request",
  263. Language: "en_US.UTF-8",
  264. }
  265. return m.sendMessage(failMsg)
  266. }
  267. c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
  268. c.remoteId = msg.PeersId
  269. c.maxRemotePayload = msg.MaxPacketSize
  270. c.remoteWin.add(msg.PeersWindow)
  271. m.incomingChannels <- c
  272. return nil
  273. }
  274. func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
  275. ch, err := m.openChannel(chanType, extra)
  276. if err != nil {
  277. return nil, nil, err
  278. }
  279. return ch, ch.incomingRequests, nil
  280. }
  281. func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
  282. ch := m.newChannel(chanType, channelOutbound, extra)
  283. ch.maxIncomingPayload = channelMaxPacket
  284. open := channelOpenMsg{
  285. ChanType: chanType,
  286. PeersWindow: ch.myWindow,
  287. MaxPacketSize: ch.maxIncomingPayload,
  288. TypeSpecificData: extra,
  289. PeersId: ch.localId,
  290. }
  291. if err := m.sendMessage(open); err != nil {
  292. return nil, err
  293. }
  294. switch msg := (<-ch.msg).(type) {
  295. case *channelOpenConfirmMsg:
  296. return ch, nil
  297. case *channelOpenFailureMsg:
  298. return nil, &OpenChannelError{msg.Reason, msg.Message}
  299. default:
  300. return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
  301. }
  302. }