copy.go 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267
  1. package pq
import (
	"database/sql/driver"
	"encoding/binary"
	"errors"
	"fmt"
	"strings"
	"sync"
)
  9. var (
  10. errCopyInClosed = errors.New("pq: copyin statement has already been closed")
  11. errBinaryCopyNotSupported = errors.New("pq: only text format supported for COPY")
  12. errCopyToNotSupported = errors.New("pq: COPY TO is not supported")
  13. errCopyNotSupportedOutsideTxn = errors.New("pq: COPY is only allowed inside a transaction")
  14. )
  15. // CopyIn creates a COPY FROM statement which can be prepared with
  16. // Tx.Prepare(). The target table should be visible in search_path.
  17. func CopyIn(table string, columns ...string) string {
  18. stmt := "COPY " + QuoteIdentifier(table) + " ("
  19. for i, col := range columns {
  20. if i != 0 {
  21. stmt += ", "
  22. }
  23. stmt += QuoteIdentifier(col)
  24. }
  25. stmt += ") FROM STDIN"
  26. return stmt
  27. }
  28. // CopyInSchema creates a COPY FROM statement which can be prepared with
  29. // Tx.Prepare().
  30. func CopyInSchema(schema, table string, columns ...string) string {
  31. stmt := "COPY " + QuoteIdentifier(schema) + "." + QuoteIdentifier(table) + " ("
  32. for i, col := range columns {
  33. if i != 0 {
  34. stmt += ", "
  35. }
  36. stmt += QuoteIdentifier(col)
  37. }
  38. stmt += ") FROM STDIN"
  39. return stmt
  40. }
  41. type copyin struct {
  42. cn *conn
  43. buffer []byte
  44. rowData chan []byte
  45. done chan bool
  46. closed bool
  47. sync.Mutex // guards err
  48. err error
  49. }
  50. const ciBufferSize = 64 * 1024
  51. // flush buffer before the buffer is filled up and needs reallocation
  52. const ciBufferFlushSize = 63 * 1024
// prepareCopyIn starts the COPY ... FROM STDIN statement q on cn and, if the
// server accepts it, returns a *copyin statement that streams rows to it.
//
// COPY is refused unless cn is inside a transaction. q is sent as a simple
// Query ('Q') message, and the server's reply is handled as follows:
//   - 'G' (CopyInResponse): if the overall format byte is 0 (text), the
//     response goroutine is started and the statement is returned.
//   - 'H' (CopyOutResponse): the statement is a COPY TO, which is rejected.
//   - 'E' (ErrorResponse): the error is remembered and reading continues.
//   - 'Z' (ReadyForQuery): the remembered error is returned. A ReadyForQuery
//     with no preceding error is treated as a protocol violation.
//   - anything else marks the connection bad.
//
// For binary-format COPY or COPY TO, a CopyFail ('f') carrying the error text
// is sent, and messages are drained until ReadyForQuery.
//
// NOTE(review): errorf is defined elsewhere and this function defers no
// errRecover of its own. Presumably errorf panics and a caller defers
// cn.errRecover to turn that into an error. Confirm.
// NOTE(review): per the PostgreSQL protocol, after a CopyOutResponse the
// server streams CopyData ('d') messages. The CopyFail drain loop below
// treats those as unknown and marks the connection bad. Confirm this is the
// intended outcome.
func (cn *conn) prepareCopyIn(q string) (_ driver.Stmt, err error) {
	if !cn.isInTransaction() {
		return nil, errCopyNotSupportedOutsideTxn
	}
	ci := &copyin{
		cn:      cn,
		buffer:  make([]byte, 0, ciBufferSize),
		rowData: make(chan []byte),
		// Capacity 1 lets resploop's single send on done complete even if
		// nobody ever receives from it.
		done: make(chan bool, 1),
	}
	// add CopyData identifier + 4 bytes for message length
	ci.buffer = append(ci.buffer, 'd', 0, 0, 0, 0)
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
awaitCopyInResponse:
	for {
		t, r := cn.recv1()
		switch t {
		case 'G':
			// The first byte is the overall COPY format; only text (0) is supported.
			if r.byte() != 0 {
				err = errBinaryCopyNotSupported
				break awaitCopyInResponse
			}
			go ci.resploop()
			return ci, nil
		case 'H':
			err = errCopyToNotSupported
			break awaitCopyInResponse
		case 'E':
			// Remember the error and keep reading until ReadyForQuery.
			err = parseError(r)
		case 'Z':
			if err == nil {
				cn.bad = true
				errorf("unexpected ReadyForQuery in response to COPY")
			}
			cn.processReadyForQuery(r)
			return nil, err
		default:
			cn.bad = true
			errorf("unknown response for copy query: %q", t)
		}
	}
	// something went wrong, abort COPY before we return
	b = cn.writeBuf('f')
	b.string(err.Error())
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'c', 'C', 'E':
			// Ignored while draining; keep waiting for ReadyForQuery.
		case 'Z':
			// correctly aborted, we're done
			cn.processReadyForQuery(r)
			return nil, err
		default:
			cn.bad = true
			errorf("unknown response for CopyFail: %q", t)
		}
	}
}
  114. func (ci *copyin) flush(buf []byte) {
  115. // set message length (without message identifier)
  116. binary.BigEndian.PutUint32(buf[1:], uint32(len(buf)-1))
  117. _, err := ci.cn.c.Write(buf)
  118. if err != nil {
  119. panic(err)
  120. }
  121. }
  122. func (ci *copyin) resploop() {
  123. for {
  124. var r readBuf
  125. t, err := ci.cn.recvMessage(&r)
  126. if err != nil {
  127. ci.cn.bad = true
  128. ci.setError(err)
  129. ci.done <- true
  130. return
  131. }
  132. switch t {
  133. case 'C':
  134. // complete
  135. case 'N':
  136. // NoticeResponse
  137. case 'Z':
  138. ci.cn.processReadyForQuery(&r)
  139. ci.done <- true
  140. return
  141. case 'E':
  142. err := parseError(&r)
  143. ci.setError(err)
  144. default:
  145. ci.cn.bad = true
  146. ci.setError(fmt.Errorf("unknown response during CopyIn: %q", t))
  147. ci.done <- true
  148. return
  149. }
  150. }
  151. }
  152. func (ci *copyin) isErrorSet() bool {
  153. ci.Lock()
  154. isSet := (ci.err != nil)
  155. ci.Unlock()
  156. return isSet
  157. }
  158. // setError() sets ci.err if one has not been set already. Caller must not be
  159. // holding ci.Mutex.
  160. func (ci *copyin) setError(err error) {
  161. ci.Lock()
  162. if ci.err == nil {
  163. ci.err = err
  164. }
  165. ci.Unlock()
  166. }
  167. func (ci *copyin) NumInput() int {
  168. return -1
  169. }
  170. func (ci *copyin) Query(v []driver.Value) (r driver.Rows, err error) {
  171. return nil, ErrNotSupported
  172. }
  173. // Exec inserts values into the COPY stream. The insert is asynchronous
  174. // and Exec can return errors from previous Exec calls to the same
  175. // COPY stmt.
  176. //
  177. // You need to call Exec(nil) to sync the COPY stream and to get any
  178. // errors from pending data, since Stmt.Close() doesn't return errors
  179. // to the user.
  180. func (ci *copyin) Exec(v []driver.Value) (r driver.Result, err error) {
  181. if ci.closed {
  182. return nil, errCopyInClosed
  183. }
  184. if ci.cn.bad {
  185. return nil, driver.ErrBadConn
  186. }
  187. defer ci.cn.errRecover(&err)
  188. if ci.isErrorSet() {
  189. return nil, ci.err
  190. }
  191. if len(v) == 0 {
  192. return nil, ci.Close()
  193. }
  194. numValues := len(v)
  195. for i, value := range v {
  196. ci.buffer = appendEncodedText(&ci.cn.parameterStatus, ci.buffer, value)
  197. if i < numValues-1 {
  198. ci.buffer = append(ci.buffer, '\t')
  199. }
  200. }
  201. ci.buffer = append(ci.buffer, '\n')
  202. if len(ci.buffer) > ciBufferFlushSize {
  203. ci.flush(ci.buffer)
  204. // reset buffer, keep bytes for message identifier and length
  205. ci.buffer = ci.buffer[:5]
  206. }
  207. return driver.RowsAffected(0), nil
  208. }
  209. func (ci *copyin) Close() (err error) {
  210. if ci.closed { // Don't do anything, we're already closed
  211. return nil
  212. }
  213. ci.closed = true
  214. if ci.cn.bad {
  215. return driver.ErrBadConn
  216. }
  217. defer ci.cn.errRecover(&err)
  218. if len(ci.buffer) > 0 {
  219. ci.flush(ci.buffer)
  220. }
  221. // Avoid touching the scratch buffer as resploop could be using it.
  222. err = ci.cn.sendSimpleMessage('c')
  223. if err != nil {
  224. return err
  225. }
  226. <-ci.done
  227. if ci.isErrorSet() {
  228. err = ci.err
  229. return err
  230. }
  231. return nil
  232. }