mux.go 7.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330
  1. // Copyright 2013 The Go Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package ssh
  5. import (
  6. "encoding/binary"
  7. "fmt"
  8. "io"
  9. "log"
  10. "sync"
  11. "sync/atomic"
  12. )
  13. // debugMux, if set, causes messages in the connection protocol to be
  14. // logged.
  15. const debugMux = false
  16. // chanList is a thread safe channel list.
  17. type chanList struct {
  18. // protects concurrent access to chans
  19. sync.Mutex
  20. // chans are indexed by the local id of the channel, which the
  21. // other side should send in the PeersId field.
  22. chans []*channel
  23. // This is a debugging aid: it offsets all IDs by this
  24. // amount. This helps distinguish otherwise identical
  25. // server/client muxes
  26. offset uint32
  27. }
  28. // Assigns a channel ID to the given channel.
  29. func (c *chanList) add(ch *channel) uint32 {
  30. c.Lock()
  31. defer c.Unlock()
  32. for i := range c.chans {
  33. if c.chans[i] == nil {
  34. c.chans[i] = ch
  35. return uint32(i) + c.offset
  36. }
  37. }
  38. c.chans = append(c.chans, ch)
  39. return uint32(len(c.chans)-1) + c.offset
  40. }
  41. // getChan returns the channel for the given ID.
  42. func (c *chanList) getChan(id uint32) *channel {
  43. id -= c.offset
  44. c.Lock()
  45. defer c.Unlock()
  46. if id < uint32(len(c.chans)) {
  47. return c.chans[id]
  48. }
  49. return nil
  50. }
  51. func (c *chanList) remove(id uint32) {
  52. id -= c.offset
  53. c.Lock()
  54. if id < uint32(len(c.chans)) {
  55. c.chans[id] = nil
  56. }
  57. c.Unlock()
  58. }
  59. // dropAll forgets all channels it knows, returning them in a slice.
  60. func (c *chanList) dropAll() []*channel {
  61. c.Lock()
  62. defer c.Unlock()
  63. var r []*channel
  64. for _, ch := range c.chans {
  65. if ch == nil {
  66. continue
  67. }
  68. r = append(r, ch)
  69. }
  70. c.chans = nil
  71. return r
  72. }
  73. // mux represents the state for the SSH connection protocol, which
  74. // multiplexes many channels onto a single packet transport.
  75. type mux struct {
  76. conn packetConn
  77. chanList chanList
  78. incomingChannels chan NewChannel
  79. globalSentMu sync.Mutex
  80. globalResponses chan interface{}
  81. incomingRequests chan *Request
  82. errCond *sync.Cond
  83. err error
  84. }
  85. // When debugging, each new chanList instantiation has a different
  86. // offset.
  87. var globalOff uint32
  88. func (m *mux) Wait() error {
  89. m.errCond.L.Lock()
  90. defer m.errCond.L.Unlock()
  91. for m.err == nil {
  92. m.errCond.Wait()
  93. }
  94. return m.err
  95. }
  96. // newMux returns a mux that runs over the given connection.
  97. func newMux(p packetConn) *mux {
  98. m := &mux{
  99. conn: p,
  100. incomingChannels: make(chan NewChannel, chanSize),
  101. globalResponses: make(chan interface{}, 1),
  102. incomingRequests: make(chan *Request, chanSize),
  103. errCond: newCond(),
  104. }
  105. if debugMux {
  106. m.chanList.offset = atomic.AddUint32(&globalOff, 1)
  107. }
  108. go m.loop()
  109. return m
  110. }
  111. func (m *mux) sendMessage(msg interface{}) error {
  112. p := Marshal(msg)
  113. if debugMux {
  114. log.Printf("send global(%d): %#v", m.chanList.offset, msg)
  115. }
  116. return m.conn.writePacket(p)
  117. }
  118. func (m *mux) SendRequest(name string, wantReply bool, payload []byte) (bool, []byte, error) {
  119. if wantReply {
  120. m.globalSentMu.Lock()
  121. defer m.globalSentMu.Unlock()
  122. }
  123. if err := m.sendMessage(globalRequestMsg{
  124. Type: name,
  125. WantReply: wantReply,
  126. Data: payload,
  127. }); err != nil {
  128. return false, nil, err
  129. }
  130. if !wantReply {
  131. return false, nil, nil
  132. }
  133. msg, ok := <-m.globalResponses
  134. if !ok {
  135. return false, nil, io.EOF
  136. }
  137. switch msg := msg.(type) {
  138. case *globalRequestFailureMsg:
  139. return false, msg.Data, nil
  140. case *globalRequestSuccessMsg:
  141. return true, msg.Data, nil
  142. default:
  143. return false, nil, fmt.Errorf("ssh: unexpected response to request: %#v", msg)
  144. }
  145. }
  146. // ackRequest must be called after processing a global request that
  147. // has WantReply set.
  148. func (m *mux) ackRequest(ok bool, data []byte) error {
  149. if ok {
  150. return m.sendMessage(globalRequestSuccessMsg{Data: data})
  151. }
  152. return m.sendMessage(globalRequestFailureMsg{Data: data})
  153. }
  154. func (m *mux) Close() error {
  155. return m.conn.Close()
  156. }
  157. // loop runs the connection machine. It will process packets until an
  158. // error is encountered. To synchronize on loop exit, use mux.Wait.
  159. func (m *mux) loop() {
  160. var err error
  161. for err == nil {
  162. err = m.onePacket()
  163. }
  164. for _, ch := range m.chanList.dropAll() {
  165. ch.close()
  166. }
  167. close(m.incomingChannels)
  168. close(m.incomingRequests)
  169. close(m.globalResponses)
  170. m.conn.Close()
  171. m.errCond.L.Lock()
  172. m.err = err
  173. m.errCond.Broadcast()
  174. m.errCond.L.Unlock()
  175. if debugMux {
  176. log.Println("loop exit", err)
  177. }
  178. }
  179. // onePacket reads and processes one packet.
  180. func (m *mux) onePacket() error {
  181. packet, err := m.conn.readPacket()
  182. if err != nil {
  183. return err
  184. }
  185. if debugMux {
  186. if packet[0] == msgChannelData || packet[0] == msgChannelExtendedData {
  187. log.Printf("decoding(%d): data packet - %d bytes", m.chanList.offset, len(packet))
  188. } else {
  189. p, _ := decode(packet)
  190. log.Printf("decoding(%d): %d %#v - %d bytes", m.chanList.offset, packet[0], p, len(packet))
  191. }
  192. }
  193. switch packet[0] {
  194. case msgChannelOpen:
  195. return m.handleChannelOpen(packet)
  196. case msgGlobalRequest, msgRequestSuccess, msgRequestFailure:
  197. return m.handleGlobalPacket(packet)
  198. }
  199. // assume a channel packet.
  200. if len(packet) < 5 {
  201. return parseError(packet[0])
  202. }
  203. id := binary.BigEndian.Uint32(packet[1:])
  204. ch := m.chanList.getChan(id)
  205. if ch == nil {
  206. return fmt.Errorf("ssh: invalid channel %d", id)
  207. }
  208. return ch.handlePacket(packet)
  209. }
  210. func (m *mux) handleGlobalPacket(packet []byte) error {
  211. msg, err := decode(packet)
  212. if err != nil {
  213. return err
  214. }
  215. switch msg := msg.(type) {
  216. case *globalRequestMsg:
  217. m.incomingRequests <- &Request{
  218. Type: msg.Type,
  219. WantReply: msg.WantReply,
  220. Payload: msg.Data,
  221. mux: m,
  222. }
  223. case *globalRequestSuccessMsg, *globalRequestFailureMsg:
  224. m.globalResponses <- msg
  225. default:
  226. panic(fmt.Sprintf("not a global message %#v", msg))
  227. }
  228. return nil
  229. }
  230. // handleChannelOpen schedules a channel to be Accept()ed.
  231. func (m *mux) handleChannelOpen(packet []byte) error {
  232. var msg channelOpenMsg
  233. if err := Unmarshal(packet, &msg); err != nil {
  234. return err
  235. }
  236. if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
  237. failMsg := channelOpenFailureMsg{
  238. PeersId: msg.PeersId,
  239. Reason: ConnectionFailed,
  240. Message: "invalid request",
  241. Language: "en_US.UTF-8",
  242. }
  243. return m.sendMessage(failMsg)
  244. }
  245. c := m.newChannel(msg.ChanType, channelInbound, msg.TypeSpecificData)
  246. c.remoteId = msg.PeersId
  247. c.maxRemotePayload = msg.MaxPacketSize
  248. c.remoteWin.add(msg.PeersWindow)
  249. m.incomingChannels <- c
  250. return nil
  251. }
  252. func (m *mux) OpenChannel(chanType string, extra []byte) (Channel, <-chan *Request, error) {
  253. ch, err := m.openChannel(chanType, extra)
  254. if err != nil {
  255. return nil, nil, err
  256. }
  257. return ch, ch.incomingRequests, nil
  258. }
  259. func (m *mux) openChannel(chanType string, extra []byte) (*channel, error) {
  260. ch := m.newChannel(chanType, channelOutbound, extra)
  261. ch.maxIncomingPayload = channelMaxPacket
  262. open := channelOpenMsg{
  263. ChanType: chanType,
  264. PeersWindow: ch.myWindow,
  265. MaxPacketSize: ch.maxIncomingPayload,
  266. TypeSpecificData: extra,
  267. PeersId: ch.localId,
  268. }
  269. if err := m.sendMessage(open); err != nil {
  270. return nil, err
  271. }
  272. switch msg := (<-ch.msg).(type) {
  273. case *channelOpenConfirmMsg:
  274. return ch, nil
  275. case *channelOpenFailureMsg:
  276. return nil, &OpenChannelError{msg.Reason, msg.Message}
  277. default:
  278. return nil, fmt.Errorf("ssh: unexpected packet in response to channel open: %T", msg)
  279. }
  280. }