123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631 |
- // Copyright 2011 The Go Authors. All rights reserved.
- // Use of this source code is governed by a BSD-style
- // license that can be found in the LICENSE file.
- package ssh
- import (
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "log"
- "sync"
- )
- const (
- minPacketLength = 9
- // channelMaxPacket contains the maximum number of bytes that will be
- // sent in a single packet. As per RFC 4253, section 6.1, 32k is also
- // the minimum.
- channelMaxPacket = 1 << 15
- // We follow OpenSSH here.
- channelWindowSize = 64 * channelMaxPacket
- )
- // NewChannel represents an incoming request to a channel. It must either be
- // accepted for use by calling Accept, or rejected by calling Reject.
- type NewChannel interface {
- // Accept accepts the channel creation request. It returns the Channel
- // and a Go channel containing SSH requests. The Go channel must be
- // serviced otherwise the Channel will hang.
- Accept() (Channel, <-chan *Request, error)
- // Reject rejects the channel creation request. After calling
- // this, no other methods on the Channel may be called.
- Reject(reason RejectionReason, message string) error
- // ChannelType returns the type of the channel, as supplied by the
- // client.
- ChannelType() string
- // ExtraData returns the arbitrary payload for this channel, as supplied
- // by the client. This data is specific to the channel type.
- ExtraData() []byte
- }
- // A Channel is an ordered, reliable, flow-controlled, duplex stream
- // that is multiplexed over an SSH connection.
- type Channel interface {
- // Read reads up to len(data) bytes from the channel.
- Read(data []byte) (int, error)
- // Write writes len(data) bytes to the channel.
- Write(data []byte) (int, error)
- // Close signals end of channel use. No data may be sent after this
- // call.
- Close() error
- // CloseWrite signals the end of sending in-band
- // data. Requests may still be sent, and the other side may
- // still send data
- CloseWrite() error
- // SendRequest sends a channel request. If wantReply is true,
- // it will wait for a reply and return the result as a
- // boolean, otherwise the return value will be false. Channel
- // requests are out-of-band messages so they may be sent even
- // if the data stream is closed or blocked by flow control.
- SendRequest(name string, wantReply bool, payload []byte) (bool, error)
- // Stderr returns an io.ReadWriter that writes to this channel
- // with the extended data type set to stderr. Stderr may
- // safely be read and written from a different goroutine than
- // Read and Write respectively.
- Stderr() io.ReadWriter
- }
- // Request is a request sent outside of the normal stream of
- // data. Requests can either be specific to an SSH channel, or they
- // can be global.
- type Request struct {
- Type string
- WantReply bool
- Payload []byte
- ch *channel
- mux *mux
- }
- // Reply sends a response to a request. It must be called for all requests
- // where WantReply is true and is a no-op otherwise. The payload argument is
- // ignored for replies to channel-specific requests.
- func (r *Request) Reply(ok bool, payload []byte) error {
- if !r.WantReply {
- return nil
- }
- if r.ch == nil {
- return r.mux.ackRequest(ok, payload)
- }
- return r.ch.ackRequest(ok)
- }
- // RejectionReason is an enumeration used when rejecting channel creation
- // requests. See RFC 4254, section 5.1.
- type RejectionReason uint32
- const (
- Prohibited RejectionReason = iota + 1
- ConnectionFailed
- UnknownChannelType
- ResourceShortage
- )
- // String converts the rejection reason to human readable form.
- func (r RejectionReason) String() string {
- switch r {
- case Prohibited:
- return "administratively prohibited"
- case ConnectionFailed:
- return "connect failed"
- case UnknownChannelType:
- return "unknown channel type"
- case ResourceShortage:
- return "resource shortage"
- }
- return fmt.Sprintf("unknown reason %d", int(r))
- }
- func min(a uint32, b int) uint32 {
- if a < uint32(b) {
- return a
- }
- return uint32(b)
- }
- type channelDirection uint8
- const (
- channelInbound channelDirection = iota
- channelOutbound
- )
- // channel is an implementation of the Channel interface that works
- // with the mux class.
- type channel struct {
- // R/O after creation
- chanType string
- extraData []byte
- localId, remoteId uint32
- // maxIncomingPayload and maxRemotePayload are the maximum
- // payload sizes of normal and extended data packets for
- // receiving and sending, respectively. The wire packet will
- // be 9 or 13 bytes larger (excluding encryption overhead).
- maxIncomingPayload uint32
- maxRemotePayload uint32
- mux *mux
- // decided is set to true if an accept or reject message has been sent
- // (for outbound channels) or received (for inbound channels).
- decided bool
- // direction contains either channelOutbound, for channels created
- // locally, or channelInbound, for channels created by the peer.
- direction channelDirection
- // Pending internal channel messages.
- msg chan interface{}
- // Since requests have no ID, there can be only one request
- // with WantReply=true outstanding. This lock is held by a
- // goroutine that has such an outgoing request pending.
- sentRequestMu sync.Mutex
- incomingRequests chan *Request
- sentEOF bool
- // thread-safe data
- remoteWin window
- pending *buffer
- extPending *buffer
- // windowMu protects myWindow, the flow-control window.
- windowMu sync.Mutex
- myWindow uint32
- // writeMu serializes calls to mux.conn.writePacket() and
- // protects sentClose and packetPool. This mutex must be
- // different from windowMu, as writePacket can block if there
- // is a key exchange pending.
- writeMu sync.Mutex
- sentClose bool
- // packetPool has a buffer for each extended channel ID to
- // save allocations during writes.
- packetPool map[uint32][]byte
- }
- // writePacket sends a packet. If the packet is a channel close, it updates
- // sentClose. This method takes the lock c.writeMu.
- func (c *channel) writePacket(packet []byte) error {
- c.writeMu.Lock()
- if c.sentClose {
- c.writeMu.Unlock()
- return io.EOF
- }
- c.sentClose = (packet[0] == msgChannelClose)
- err := c.mux.conn.writePacket(packet)
- c.writeMu.Unlock()
- return err
- }
- func (c *channel) sendMessage(msg interface{}) error {
- if debugMux {
- log.Printf("send %d: %#v", c.mux.chanList.offset, msg)
- }
- p := Marshal(msg)
- binary.BigEndian.PutUint32(p[1:], c.remoteId)
- return c.writePacket(p)
- }
- // WriteExtended writes data to a specific extended stream. These streams are
- // used, for example, for stderr.
- func (c *channel) WriteExtended(data []byte, extendedCode uint32) (n int, err error) {
- if c.sentEOF {
- return 0, io.EOF
- }
- // 1 byte message type, 4 bytes remoteId, 4 bytes data length
- opCode := byte(msgChannelData)
- headerLength := uint32(9)
- if extendedCode > 0 {
- headerLength += 4
- opCode = msgChannelExtendedData
- }
- c.writeMu.Lock()
- packet := c.packetPool[extendedCode]
- // We don't remove the buffer from packetPool, so
- // WriteExtended calls from different goroutines will be
- // flagged as errors by the race detector.
- c.writeMu.Unlock()
- for len(data) > 0 {
- space := min(c.maxRemotePayload, len(data))
- if space, err = c.remoteWin.reserve(space); err != nil {
- return n, err
- }
- if want := headerLength + space; uint32(cap(packet)) < want {
- packet = make([]byte, want)
- } else {
- packet = packet[:want]
- }
- todo := data[:space]
- packet[0] = opCode
- binary.BigEndian.PutUint32(packet[1:], c.remoteId)
- if extendedCode > 0 {
- binary.BigEndian.PutUint32(packet[5:], uint32(extendedCode))
- }
- binary.BigEndian.PutUint32(packet[headerLength-4:], uint32(len(todo)))
- copy(packet[headerLength:], todo)
- if err = c.writePacket(packet); err != nil {
- return n, err
- }
- n += len(todo)
- data = data[len(todo):]
- }
- c.writeMu.Lock()
- c.packetPool[extendedCode] = packet
- c.writeMu.Unlock()
- return n, err
- }
- func (c *channel) handleData(packet []byte) error {
- headerLen := 9
- isExtendedData := packet[0] == msgChannelExtendedData
- if isExtendedData {
- headerLen = 13
- }
- if len(packet) < headerLen {
- // malformed data packet
- return parseError(packet[0])
- }
- var extended uint32
- if isExtendedData {
- extended = binary.BigEndian.Uint32(packet[5:])
- }
- length := binary.BigEndian.Uint32(packet[headerLen-4 : headerLen])
- if length == 0 {
- return nil
- }
- if length > c.maxIncomingPayload {
- // TODO(hanwen): should send Disconnect?
- return errors.New("ssh: incoming packet exceeds maximum payload size")
- }
- data := packet[headerLen:]
- if length != uint32(len(data)) {
- return errors.New("ssh: wrong packet length")
- }
- c.windowMu.Lock()
- if c.myWindow < length {
- c.windowMu.Unlock()
- // TODO(hanwen): should send Disconnect with reason?
- return errors.New("ssh: remote side wrote too much")
- }
- c.myWindow -= length
- c.windowMu.Unlock()
- if extended == 1 {
- c.extPending.write(data)
- } else if extended > 0 {
- // discard other extended data.
- } else {
- c.pending.write(data)
- }
- return nil
- }
- func (c *channel) adjustWindow(n uint32) error {
- c.windowMu.Lock()
- // Since myWindow is managed on our side, and can never exceed
- // the initial window setting, we don't worry about overflow.
- c.myWindow += uint32(n)
- c.windowMu.Unlock()
- return c.sendMessage(windowAdjustMsg{
- AdditionalBytes: uint32(n),
- })
- }
- func (c *channel) ReadExtended(data []byte, extended uint32) (n int, err error) {
- switch extended {
- case 1:
- n, err = c.extPending.Read(data)
- case 0:
- n, err = c.pending.Read(data)
- default:
- return 0, fmt.Errorf("ssh: extended code %d unimplemented", extended)
- }
- if n > 0 {
- err = c.adjustWindow(uint32(n))
- // sendWindowAdjust can return io.EOF if the remote
- // peer has closed the connection, however we want to
- // defer forwarding io.EOF to the caller of Read until
- // the buffer has been drained.
- if n > 0 && err == io.EOF {
- err = nil
- }
- }
- return n, err
- }
- func (c *channel) close() {
- c.pending.eof()
- c.extPending.eof()
- close(c.msg)
- close(c.incomingRequests)
- c.writeMu.Lock()
- // This is not necesary for a normal channel teardown, but if
- // there was another error, it is.
- c.sentClose = true
- c.writeMu.Unlock()
- // Unblock writers.
- c.remoteWin.close()
- }
- // responseMessageReceived is called when a success or failure message is
- // received on a channel to check that such a message is reasonable for the
- // given channel.
- func (c *channel) responseMessageReceived() error {
- if c.direction == channelInbound {
- return errors.New("ssh: channel response message received on inbound channel")
- }
- if c.decided {
- return errors.New("ssh: duplicate response received for channel")
- }
- c.decided = true
- return nil
- }
- func (c *channel) handlePacket(packet []byte) error {
- switch packet[0] {
- case msgChannelData, msgChannelExtendedData:
- return c.handleData(packet)
- case msgChannelClose:
- c.sendMessage(channelCloseMsg{PeersId: c.remoteId})
- c.mux.chanList.remove(c.localId)
- c.close()
- return nil
- case msgChannelEOF:
- // RFC 4254 is mute on how EOF affects dataExt messages but
- // it is logical to signal EOF at the same time.
- c.extPending.eof()
- c.pending.eof()
- return nil
- }
- decoded, err := decode(packet)
- if err != nil {
- return err
- }
- switch msg := decoded.(type) {
- case *channelOpenFailureMsg:
- if err := c.responseMessageReceived(); err != nil {
- return err
- }
- c.mux.chanList.remove(msg.PeersId)
- c.msg <- msg
- case *channelOpenConfirmMsg:
- if err := c.responseMessageReceived(); err != nil {
- return err
- }
- if msg.MaxPacketSize < minPacketLength || msg.MaxPacketSize > 1<<31 {
- return fmt.Errorf("ssh: invalid MaxPacketSize %d from peer", msg.MaxPacketSize)
- }
- c.remoteId = msg.MyId
- c.maxRemotePayload = msg.MaxPacketSize
- c.remoteWin.add(msg.MyWindow)
- c.msg <- msg
- case *windowAdjustMsg:
- if !c.remoteWin.add(msg.AdditionalBytes) {
- return fmt.Errorf("ssh: invalid window update for %d bytes", msg.AdditionalBytes)
- }
- case *channelRequestMsg:
- req := Request{
- Type: msg.Request,
- WantReply: msg.WantReply,
- Payload: msg.RequestSpecificData,
- ch: c,
- }
- c.incomingRequests <- &req
- default:
- c.msg <- msg
- }
- return nil
- }
- func (m *mux) newChannel(chanType string, direction channelDirection, extraData []byte) *channel {
- ch := &channel{
- remoteWin: window{Cond: newCond()},
- myWindow: channelWindowSize,
- pending: newBuffer(),
- extPending: newBuffer(),
- direction: direction,
- incomingRequests: make(chan *Request, 16),
- msg: make(chan interface{}, 16),
- chanType: chanType,
- extraData: extraData,
- mux: m,
- packetPool: make(map[uint32][]byte),
- }
- ch.localId = m.chanList.add(ch)
- return ch
- }
- var errUndecided = errors.New("ssh: must Accept or Reject channel")
- var errDecidedAlready = errors.New("ssh: can call Accept or Reject only once")
- type extChannel struct {
- code uint32
- ch *channel
- }
- func (e *extChannel) Write(data []byte) (n int, err error) {
- return e.ch.WriteExtended(data, e.code)
- }
- func (e *extChannel) Read(data []byte) (n int, err error) {
- return e.ch.ReadExtended(data, e.code)
- }
- func (c *channel) Accept() (Channel, <-chan *Request, error) {
- if c.decided {
- return nil, nil, errDecidedAlready
- }
- c.maxIncomingPayload = channelMaxPacket
- confirm := channelOpenConfirmMsg{
- PeersId: c.remoteId,
- MyId: c.localId,
- MyWindow: c.myWindow,
- MaxPacketSize: c.maxIncomingPayload,
- }
- c.decided = true
- if err := c.sendMessage(confirm); err != nil {
- return nil, nil, err
- }
- return c, c.incomingRequests, nil
- }
- func (ch *channel) Reject(reason RejectionReason, message string) error {
- if ch.decided {
- return errDecidedAlready
- }
- reject := channelOpenFailureMsg{
- PeersId: ch.remoteId,
- Reason: reason,
- Message: message,
- Language: "en",
- }
- ch.decided = true
- return ch.sendMessage(reject)
- }
- func (ch *channel) Read(data []byte) (int, error) {
- if !ch.decided {
- return 0, errUndecided
- }
- return ch.ReadExtended(data, 0)
- }
- func (ch *channel) Write(data []byte) (int, error) {
- if !ch.decided {
- return 0, errUndecided
- }
- return ch.WriteExtended(data, 0)
- }
- func (ch *channel) CloseWrite() error {
- if !ch.decided {
- return errUndecided
- }
- ch.sentEOF = true
- return ch.sendMessage(channelEOFMsg{
- PeersId: ch.remoteId})
- }
- func (ch *channel) Close() error {
- if !ch.decided {
- return errUndecided
- }
- return ch.sendMessage(channelCloseMsg{
- PeersId: ch.remoteId})
- }
- // Extended returns an io.ReadWriter that sends and receives data on the given,
- // SSH extended stream. Such streams are used, for example, for stderr.
- func (ch *channel) Extended(code uint32) io.ReadWriter {
- if !ch.decided {
- return nil
- }
- return &extChannel{code, ch}
- }
- func (ch *channel) Stderr() io.ReadWriter {
- return ch.Extended(1)
- }
- func (ch *channel) SendRequest(name string, wantReply bool, payload []byte) (bool, error) {
- if !ch.decided {
- return false, errUndecided
- }
- if wantReply {
- ch.sentRequestMu.Lock()
- defer ch.sentRequestMu.Unlock()
- }
- msg := channelRequestMsg{
- PeersId: ch.remoteId,
- Request: name,
- WantReply: wantReply,
- RequestSpecificData: payload,
- }
- if err := ch.sendMessage(msg); err != nil {
- return false, err
- }
- if wantReply {
- m, ok := (<-ch.msg)
- if !ok {
- return false, io.EOF
- }
- switch m.(type) {
- case *channelRequestFailureMsg:
- return false, nil
- case *channelRequestSuccessMsg:
- return true, nil
- default:
- return false, fmt.Errorf("ssh: unexpected response to channel request: %#v", m)
- }
- }
- return false, nil
- }
- // ackRequest either sends an ack or nack to the channel request.
- func (ch *channel) ackRequest(ok bool) error {
- if !ch.decided {
- return errUndecided
- }
- var msg interface{}
- if !ok {
- msg = channelRequestFailureMsg{
- PeersId: ch.remoteId,
- }
- } else {
- msg = channelRequestSuccessMsg{
- PeersId: ch.remoteId,
- }
- }
- return ch.sendMessage(msg)
- }
- func (ch *channel) ChannelType() string {
- return ch.chanType
- }
- func (ch *channel) ExtraData() []byte {
- return ch.extraData
- }
|