conn.go 45 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893
  1. package pq
  2. import (
  3. "bufio"
  4. "crypto/md5"
  5. "crypto/tls"
  6. "crypto/x509"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "io/ioutil"
  14. "net"
  15. "os"
  16. "os/user"
  17. "path"
  18. "path/filepath"
  19. "strconv"
  20. "strings"
  21. "time"
  22. "unicode"
  23. "github.com/lib/pq/oid"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less.")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly.")
  32. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  33. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  34. errNoLastInsertId = errors.New("no LastInsertId available after the empty statement")
  35. )
  36. type drv struct{}
  37. func (d *drv) Open(name string) (driver.Conn, error) {
  38. return Open(name)
  39. }
  40. func init() {
  41. sql.Register("postgres", &drv{})
  42. }
  43. type parameterStatus struct {
  44. // server version in the same format as server_version_num, or 0 if
  45. // unavailable
  46. serverVersion int
  47. // the current location based on the TimeZone value of the session, if
  48. // available
  49. currentLocation *time.Location
  50. }
  51. type transactionStatus byte
  52. const (
  53. txnStatusIdle transactionStatus = 'I'
  54. txnStatusIdleInTransaction transactionStatus = 'T'
  55. txnStatusInFailedTransaction transactionStatus = 'E'
  56. )
  57. func (s transactionStatus) String() string {
  58. switch s {
  59. case txnStatusIdle:
  60. return "idle"
  61. case txnStatusIdleInTransaction:
  62. return "idle in transaction"
  63. case txnStatusInFailedTransaction:
  64. return "in a failed transaction"
  65. default:
  66. errorf("unknown transactionStatus %d", s)
  67. }
  68. panic("not reached")
  69. }
  70. type Dialer interface {
  71. Dial(network, address string) (net.Conn, error)
  72. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  73. }
  74. type defaultDialer struct{}
  75. func (d defaultDialer) Dial(ntw, addr string) (net.Conn, error) {
  76. return net.Dial(ntw, addr)
  77. }
  78. func (d defaultDialer) DialTimeout(ntw, addr string, timeout time.Duration) (net.Conn, error) {
  79. return net.DialTimeout(ntw, addr, timeout)
  80. }
  81. type conn struct {
  82. c net.Conn
  83. buf *bufio.Reader
  84. namei int
  85. scratch [512]byte
  86. txnStatus transactionStatus
  87. parameterStatus parameterStatus
  88. saveMessageType byte
  89. saveMessageBuffer []byte
  90. // If true, this connection is bad and all public-facing functions should
  91. // return ErrBadConn.
  92. bad bool
  93. // If set, this connection should never use the binary format when
  94. // receiving query results from prepared statements. Only provided for
  95. // debugging.
  96. disablePreparedBinaryResult bool
  97. // Whether to always send []byte parameters over as binary. Enables single
  98. // round-trip mode for non-prepared Query calls.
  99. binaryParameters bool
  100. }
  101. // Handle driver-side settings in parsed connection string.
  102. func (c *conn) handleDriverSettings(o values) (err error) {
  103. boolSetting := func(key string, val *bool) error {
  104. if value := o.Get(key); value != "" {
  105. if value == "yes" {
  106. *val = true
  107. } else if value == "no" {
  108. *val = false
  109. } else {
  110. return fmt.Errorf("unrecognized value %q for %s", value, key)
  111. }
  112. }
  113. return nil
  114. }
  115. err = boolSetting("disable_prepared_binary_result", &c.disablePreparedBinaryResult)
  116. if err != nil {
  117. return err
  118. }
  119. err = boolSetting("binary_parameters", &c.binaryParameters)
  120. if err != nil {
  121. return err
  122. }
  123. return nil
  124. }
  125. func (c *conn) handlePgpass(o values) {
  126. // if a password was supplied, do not process .pgpass
  127. _, ok := o["password"]
  128. if ok {
  129. return
  130. }
  131. filename := os.Getenv("PGPASSFILE")
  132. if filename == "" {
  133. // XXX this code doesn't work on Windows where the default filename is
  134. // XXX %APPDATA%\postgresql\pgpass.conf
  135. user, err := user.Current()
  136. if err != nil {
  137. return
  138. }
  139. filename = filepath.Join(user.HomeDir, ".pgpass")
  140. }
  141. fileinfo, err := os.Stat(filename)
  142. if err != nil {
  143. return
  144. }
  145. mode := fileinfo.Mode()
  146. if mode&(0x77) != 0 {
  147. // XXX should warn about incorrect .pgpass permissions as psql does
  148. return
  149. }
  150. file, err := os.Open(filename)
  151. if err != nil {
  152. return
  153. }
  154. defer file.Close()
  155. scanner := bufio.NewScanner(io.Reader(file))
  156. hostname := o.Get("host")
  157. ntw, _ := network(o)
  158. port := o.Get("port")
  159. db := o.Get("dbname")
  160. username := o.Get("user")
  161. // From: https://github.com/tg/pgpass/blob/master/reader.go
  162. getFields := func(s string) []string {
  163. fs := make([]string, 0, 5)
  164. f := make([]rune, 0, len(s))
  165. var esc bool
  166. for _, c := range s {
  167. switch {
  168. case esc:
  169. f = append(f, c)
  170. esc = false
  171. case c == '\\':
  172. esc = true
  173. case c == ':':
  174. fs = append(fs, string(f))
  175. f = f[:0]
  176. default:
  177. f = append(f, c)
  178. }
  179. }
  180. return append(fs, string(f))
  181. }
  182. for scanner.Scan() {
  183. line := scanner.Text()
  184. if len(line) == 0 || line[0] == '#' {
  185. continue
  186. }
  187. split := getFields(line)
  188. if len(split) != 5 {
  189. continue
  190. }
  191. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  192. o["password"] = split[4]
  193. return
  194. }
  195. }
  196. }
  197. func (c *conn) writeBuf(b byte) *writeBuf {
  198. c.scratch[0] = b
  199. return &writeBuf{
  200. buf: c.scratch[:5],
  201. pos: 1,
  202. }
  203. }
  204. func Open(name string) (_ driver.Conn, err error) {
  205. return DialOpen(defaultDialer{}, name)
  206. }
  207. func DialOpen(d Dialer, name string) (_ driver.Conn, err error) {
  208. // Handle any panics during connection initialization. Note that we
  209. // specifically do *not* want to use errRecover(), as that would turn any
  210. // connection errors into ErrBadConns, hiding the real error message from
  211. // the user.
  212. defer errRecoverNoErrBadConn(&err)
  213. o := make(values)
  214. // A number of defaults are applied here, in this order:
  215. //
  216. // * Very low precedence defaults applied in every situation
  217. // * Environment variables
  218. // * Explicitly passed connection information
  219. o.Set("host", "localhost")
  220. o.Set("port", "5432")
  221. // N.B.: Extra float digits should be set to 3, but that breaks
  222. // Postgres 8.4 and older, where the max is 2.
  223. o.Set("extra_float_digits", "2")
  224. for k, v := range parseEnviron(os.Environ()) {
  225. o.Set(k, v)
  226. }
  227. if strings.HasPrefix(name, "postgres://") || strings.HasPrefix(name, "postgresql://") {
  228. name, err = ParseURL(name)
  229. if err != nil {
  230. return nil, err
  231. }
  232. }
  233. if err := parseOpts(name, o); err != nil {
  234. return nil, err
  235. }
  236. // Use the "fallback" application name if necessary
  237. if fallback := o.Get("fallback_application_name"); fallback != "" {
  238. if !o.Isset("application_name") {
  239. o.Set("application_name", fallback)
  240. }
  241. }
  242. // We can't work with any client_encoding other than UTF-8 currently.
  243. // However, we have historically allowed the user to set it to UTF-8
  244. // explicitly, and there's no reason to break such programs, so allow that.
  245. // Note that the "options" setting could also set client_encoding, but
  246. // parsing its value is not worth it. Instead, we always explicitly send
  247. // client_encoding as a separate run-time parameter, which should override
  248. // anything set in options.
  249. if enc := o.Get("client_encoding"); enc != "" && !isUTF8(enc) {
  250. return nil, errors.New("client_encoding must be absent or 'UTF8'")
  251. }
  252. o.Set("client_encoding", "UTF8")
  253. // DateStyle needs a similar treatment.
  254. if datestyle := o.Get("datestyle"); datestyle != "" {
  255. if datestyle != "ISO, MDY" {
  256. panic(fmt.Sprintf("setting datestyle must be absent or %v; got %v",
  257. "ISO, MDY", datestyle))
  258. }
  259. } else {
  260. o.Set("datestyle", "ISO, MDY")
  261. }
  262. // If a user is not provided by any other means, the last
  263. // resort is to use the current operating system provided user
  264. // name.
  265. if o.Get("user") == "" {
  266. u, err := userCurrent()
  267. if err != nil {
  268. return nil, err
  269. } else {
  270. o.Set("user", u)
  271. }
  272. }
  273. cn := &conn{}
  274. err = cn.handleDriverSettings(o)
  275. if err != nil {
  276. return nil, err
  277. }
  278. cn.handlePgpass(o)
  279. cn.c, err = dial(d, o)
  280. if err != nil {
  281. return nil, err
  282. }
  283. cn.ssl(o)
  284. cn.buf = bufio.NewReader(cn.c)
  285. cn.startup(o)
  286. // reset the deadline, in case one was set (see dial)
  287. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  288. err = cn.c.SetDeadline(time.Time{})
  289. }
  290. return cn, err
  291. }
  292. func dial(d Dialer, o values) (net.Conn, error) {
  293. ntw, addr := network(o)
  294. // SSL is not necessary or supported over UNIX domain sockets
  295. if ntw == "unix" {
  296. o["sslmode"] = "disable"
  297. }
  298. // Zero or not specified means wait indefinitely.
  299. if timeout := o.Get("connect_timeout"); timeout != "" && timeout != "0" {
  300. seconds, err := strconv.ParseInt(timeout, 10, 0)
  301. if err != nil {
  302. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  303. }
  304. duration := time.Duration(seconds) * time.Second
  305. // connect_timeout should apply to the entire connection establishment
  306. // procedure, so we both use a timeout for the TCP connection
  307. // establishment and set a deadline for doing the initial handshake.
  308. // The deadline is then reset after startup() is done.
  309. deadline := time.Now().Add(duration)
  310. conn, err := d.DialTimeout(ntw, addr, duration)
  311. if err != nil {
  312. return nil, err
  313. }
  314. err = conn.SetDeadline(deadline)
  315. return conn, err
  316. }
  317. return d.Dial(ntw, addr)
  318. }
  319. func network(o values) (string, string) {
  320. host := o.Get("host")
  321. if strings.HasPrefix(host, "/") {
  322. sockPath := path.Join(host, ".s.PGSQL."+o.Get("port"))
  323. return "unix", sockPath
  324. }
  325. return "tcp", net.JoinHostPort(host, o.Get("port"))
  326. }
  327. type values map[string]string
  328. func (vs values) Set(k, v string) {
  329. vs[k] = v
  330. }
  331. func (vs values) Get(k string) (v string) {
  332. return vs[k]
  333. }
  334. func (vs values) Isset(k string) bool {
  335. _, ok := vs[k]
  336. return ok
  337. }
  338. // scanner implements a tokenizer for libpq-style option strings.
  339. type scanner struct {
  340. s []rune
  341. i int
  342. }
  343. // newScanner returns a new scanner initialized with the option string s.
  344. func newScanner(s string) *scanner {
  345. return &scanner{[]rune(s), 0}
  346. }
  347. // Next returns the next rune.
  348. // It returns 0, false if the end of the text has been reached.
  349. func (s *scanner) Next() (rune, bool) {
  350. if s.i >= len(s.s) {
  351. return 0, false
  352. }
  353. r := s.s[s.i]
  354. s.i++
  355. return r, true
  356. }
  357. // SkipSpaces returns the next non-whitespace rune.
  358. // It returns 0, false if the end of the text has been reached.
  359. func (s *scanner) SkipSpaces() (rune, bool) {
  360. r, ok := s.Next()
  361. for unicode.IsSpace(r) && ok {
  362. r, ok = s.Next()
  363. }
  364. return r, ok
  365. }
  366. // parseOpts parses the options from name and adds them to the values.
  367. //
  368. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  369. func parseOpts(name string, o values) error {
  370. s := newScanner(name)
  371. for {
  372. var (
  373. keyRunes, valRunes []rune
  374. r rune
  375. ok bool
  376. )
  377. if r, ok = s.SkipSpaces(); !ok {
  378. break
  379. }
  380. // Scan the key
  381. for !unicode.IsSpace(r) && r != '=' {
  382. keyRunes = append(keyRunes, r)
  383. if r, ok = s.Next(); !ok {
  384. break
  385. }
  386. }
  387. // Skip any whitespace if we're not at the = yet
  388. if r != '=' {
  389. r, ok = s.SkipSpaces()
  390. }
  391. // The current character should be =
  392. if r != '=' || !ok {
  393. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  394. }
  395. // Skip any whitespace after the =
  396. if r, ok = s.SkipSpaces(); !ok {
  397. // If we reach the end here, the last value is just an empty string as per libpq.
  398. o.Set(string(keyRunes), "")
  399. break
  400. }
  401. if r != '\'' {
  402. for !unicode.IsSpace(r) {
  403. if r == '\\' {
  404. if r, ok = s.Next(); !ok {
  405. return fmt.Errorf(`missing character after backslash`)
  406. }
  407. }
  408. valRunes = append(valRunes, r)
  409. if r, ok = s.Next(); !ok {
  410. break
  411. }
  412. }
  413. } else {
  414. quote:
  415. for {
  416. if r, ok = s.Next(); !ok {
  417. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  418. }
  419. switch r {
  420. case '\'':
  421. break quote
  422. case '\\':
  423. r, _ = s.Next()
  424. fallthrough
  425. default:
  426. valRunes = append(valRunes, r)
  427. }
  428. }
  429. }
  430. o.Set(string(keyRunes), string(valRunes))
  431. }
  432. return nil
  433. }
  434. func (cn *conn) isInTransaction() bool {
  435. return cn.txnStatus == txnStatusIdleInTransaction ||
  436. cn.txnStatus == txnStatusInFailedTransaction
  437. }
  438. func (cn *conn) checkIsInTransaction(intxn bool) {
  439. if cn.isInTransaction() != intxn {
  440. cn.bad = true
  441. errorf("unexpected transaction status %v", cn.txnStatus)
  442. }
  443. }
  444. func (cn *conn) Begin() (_ driver.Tx, err error) {
  445. if cn.bad {
  446. return nil, driver.ErrBadConn
  447. }
  448. defer cn.errRecover(&err)
  449. cn.checkIsInTransaction(false)
  450. _, commandTag, err := cn.simpleExec("BEGIN")
  451. if err != nil {
  452. return nil, err
  453. }
  454. if commandTag != "BEGIN" {
  455. cn.bad = true
  456. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  457. }
  458. if cn.txnStatus != txnStatusIdleInTransaction {
  459. cn.bad = true
  460. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  461. }
  462. return cn, nil
  463. }
  464. func (cn *conn) Commit() (err error) {
  465. if cn.bad {
  466. return driver.ErrBadConn
  467. }
  468. defer cn.errRecover(&err)
  469. cn.checkIsInTransaction(true)
  470. // We don't want the client to think that everything is okay if it tries
  471. // to commit a failed transaction. However, no matter what we return,
  472. // database/sql will release this connection back into the free connection
  473. // pool so we have to abort the current transaction here. Note that you
  474. // would get the same behaviour if you issued a COMMIT in a failed
  475. // transaction, so it's also the least surprising thing to do here.
  476. if cn.txnStatus == txnStatusInFailedTransaction {
  477. if err := cn.Rollback(); err != nil {
  478. return err
  479. }
  480. return ErrInFailedTransaction
  481. }
  482. _, commandTag, err := cn.simpleExec("COMMIT")
  483. if err != nil {
  484. if cn.isInTransaction() {
  485. cn.bad = true
  486. }
  487. return err
  488. }
  489. if commandTag != "COMMIT" {
  490. cn.bad = true
  491. return fmt.Errorf("unexpected command tag %s", commandTag)
  492. }
  493. cn.checkIsInTransaction(false)
  494. return nil
  495. }
  496. func (cn *conn) Rollback() (err error) {
  497. if cn.bad {
  498. return driver.ErrBadConn
  499. }
  500. defer cn.errRecover(&err)
  501. cn.checkIsInTransaction(true)
  502. _, commandTag, err := cn.simpleExec("ROLLBACK")
  503. if err != nil {
  504. if cn.isInTransaction() {
  505. cn.bad = true
  506. }
  507. return err
  508. }
  509. if commandTag != "ROLLBACK" {
  510. return fmt.Errorf("unexpected command tag %s", commandTag)
  511. }
  512. cn.checkIsInTransaction(false)
  513. return nil
  514. }
  515. func (cn *conn) gname() string {
  516. cn.namei++
  517. return strconv.FormatInt(int64(cn.namei), 10)
  518. }
  519. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  520. b := cn.writeBuf('Q')
  521. b.string(q)
  522. cn.send(b)
  523. for {
  524. t, r := cn.recv1()
  525. switch t {
  526. case 'C':
  527. res, commandTag = cn.parseComplete(r.string())
  528. case 'Z':
  529. cn.processReadyForQuery(r)
  530. if res == nil && err == nil {
  531. err = errUnexpectedReady
  532. }
  533. // done
  534. return
  535. case 'E':
  536. err = parseError(r)
  537. case 'I':
  538. res = emptyRows
  539. case 'T', 'D':
  540. // ignore any results
  541. default:
  542. cn.bad = true
  543. errorf("unknown response for simple query: %q", t)
  544. }
  545. }
  546. }
  547. func (cn *conn) simpleQuery(q string) (res *rows, err error) {
  548. defer cn.errRecover(&err)
  549. b := cn.writeBuf('Q')
  550. b.string(q)
  551. cn.send(b)
  552. for {
  553. t, r := cn.recv1()
  554. switch t {
  555. case 'C', 'I':
  556. // We allow queries which don't return any results through Query as
  557. // well as Exec. We still have to give database/sql a rows object
  558. // the user can close, though, to avoid connections from being
  559. // leaked. A "rows" with done=true works fine for that purpose.
  560. if err != nil {
  561. cn.bad = true
  562. errorf("unexpected message %q in simple query execution", t)
  563. }
  564. if res == nil {
  565. res = &rows{
  566. cn: cn,
  567. }
  568. }
  569. res.done = true
  570. case 'Z':
  571. cn.processReadyForQuery(r)
  572. // done
  573. return
  574. case 'E':
  575. res = nil
  576. err = parseError(r)
  577. case 'D':
  578. if res == nil {
  579. cn.bad = true
  580. errorf("unexpected DataRow in simple query execution")
  581. }
  582. // the query didn't fail; kick off to Next
  583. cn.saveMessage(t, r)
  584. return
  585. case 'T':
  586. // res might be non-nil here if we received a previous
  587. // CommandComplete, but that's fine; just overwrite it
  588. res = &rows{cn: cn}
  589. res.colNames, res.colFmts, res.colTyps = parsePortalRowDescribe(r)
  590. // To work around a bug in QueryRow in Go 1.2 and earlier, wait
  591. // until the first DataRow has been received.
  592. default:
  593. cn.bad = true
  594. errorf("unknown response for simple query: %q", t)
  595. }
  596. }
  597. }
  598. type noRows struct{}
  599. var emptyRows noRows
  600. var _ driver.Result = noRows{}
  601. func (noRows) LastInsertId() (int64, error) {
  602. return 0, errNoLastInsertId
  603. }
  604. func (noRows) RowsAffected() (int64, error) {
  605. return 0, errNoRowsAffected
  606. }
  607. // Decides which column formats to use for a prepared statement. The input is
  608. // an array of type oids, one element per result column.
  609. func decideColumnFormats(colTyps []oid.Oid, forceText bool) (colFmts []format, colFmtData []byte) {
  610. if len(colTyps) == 0 {
  611. return nil, colFmtDataAllText
  612. }
  613. colFmts = make([]format, len(colTyps))
  614. if forceText {
  615. return colFmts, colFmtDataAllText
  616. }
  617. allBinary := true
  618. allText := true
  619. for i, o := range colTyps {
  620. switch o {
  621. // This is the list of types to use binary mode for when receiving them
  622. // through a prepared statement. If a type appears in this list, it
  623. // must also be implemented in binaryDecode in encode.go.
  624. case oid.T_bytea:
  625. fallthrough
  626. case oid.T_int8:
  627. fallthrough
  628. case oid.T_int4:
  629. fallthrough
  630. case oid.T_int2:
  631. colFmts[i] = formatBinary
  632. allText = false
  633. default:
  634. allBinary = false
  635. }
  636. }
  637. if allBinary {
  638. return colFmts, colFmtDataAllBinary
  639. } else if allText {
  640. return colFmts, colFmtDataAllText
  641. } else {
  642. colFmtData = make([]byte, 2+len(colFmts)*2)
  643. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  644. for i, v := range colFmts {
  645. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  646. }
  647. return colFmts, colFmtData
  648. }
  649. }
  650. func (cn *conn) prepareTo(q, stmtName string) *stmt {
  651. st := &stmt{cn: cn, name: stmtName}
  652. b := cn.writeBuf('P')
  653. b.string(st.name)
  654. b.string(q)
  655. b.int16(0)
  656. b.next('D')
  657. b.byte('S')
  658. b.string(st.name)
  659. b.next('S')
  660. cn.send(b)
  661. cn.readParseResponse()
  662. st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
  663. st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
  664. cn.readReadyForQuery()
  665. return st
  666. }
  667. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  668. if cn.bad {
  669. return nil, driver.ErrBadConn
  670. }
  671. defer cn.errRecover(&err)
  672. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  673. return cn.prepareCopyIn(q)
  674. }
  675. return cn.prepareTo(q, cn.gname()), nil
  676. }
  677. func (cn *conn) Close() (err error) {
  678. // Skip cn.bad return here because we always want to close a connection.
  679. defer cn.errRecover(&err)
  680. // Ensure that cn.c.Close is always run. Since error handling is done with
  681. // panics and cn.errRecover, the Close must be in a defer.
  682. defer func() {
  683. cerr := cn.c.Close()
  684. if err == nil {
  685. err = cerr
  686. }
  687. }()
  688. // Don't go through send(); ListenerConn relies on us not scribbling on the
  689. // scratch buffer of this connection.
  690. return cn.sendSimpleMessage('X')
  691. }
  692. // Implement the "Queryer" interface
  693. func (cn *conn) Query(query string, args []driver.Value) (_ driver.Rows, err error) {
  694. if cn.bad {
  695. return nil, driver.ErrBadConn
  696. }
  697. defer cn.errRecover(&err)
  698. // Check to see if we can use the "simpleQuery" interface, which is
  699. // *much* faster than going through prepare/exec
  700. if len(args) == 0 {
  701. return cn.simpleQuery(query)
  702. }
  703. if cn.binaryParameters {
  704. cn.sendBinaryModeQuery(query, args)
  705. cn.readParseResponse()
  706. cn.readBindResponse()
  707. rows := &rows{cn: cn}
  708. rows.colNames, rows.colFmts, rows.colTyps = cn.readPortalDescribeResponse()
  709. cn.postExecuteWorkaround()
  710. return rows, nil
  711. } else {
  712. st := cn.prepareTo(query, "")
  713. st.exec(args)
  714. return &rows{
  715. cn: cn,
  716. colNames: st.colNames,
  717. colTyps: st.colTyps,
  718. colFmts: st.colFmts,
  719. }, nil
  720. }
  721. }
  722. // Implement the optional "Execer" interface for one-shot queries
  723. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  724. if cn.bad {
  725. return nil, driver.ErrBadConn
  726. }
  727. defer cn.errRecover(&err)
  728. // Check to see if we can use the "simpleExec" interface, which is
  729. // *much* faster than going through prepare/exec
  730. if len(args) == 0 {
  731. // ignore commandTag, our caller doesn't care
  732. r, _, err := cn.simpleExec(query)
  733. return r, err
  734. }
  735. if cn.binaryParameters {
  736. cn.sendBinaryModeQuery(query, args)
  737. cn.readParseResponse()
  738. cn.readBindResponse()
  739. cn.readPortalDescribeResponse()
  740. cn.postExecuteWorkaround()
  741. res, _, err = cn.readExecuteResponse("Execute")
  742. return res, err
  743. } else {
  744. // Use the unnamed statement to defer planning until bind
  745. // time, or else value-based selectivity estimates cannot be
  746. // used.
  747. st := cn.prepareTo(query, "")
  748. r, err := st.Exec(args)
  749. if err != nil {
  750. panic(err)
  751. }
  752. return r, err
  753. }
  754. }
  755. func (cn *conn) send(m *writeBuf) {
  756. _, err := cn.c.Write(m.wrap())
  757. if err != nil {
  758. panic(err)
  759. }
  760. }
  761. func (cn *conn) sendStartupPacket(m *writeBuf) {
  762. // sanity check
  763. if m.buf[0] != 0 {
  764. panic("oops")
  765. }
  766. _, err := cn.c.Write((m.wrap())[1:])
  767. if err != nil {
  768. panic(err)
  769. }
  770. }
  771. // Send a message of type typ to the server on the other end of cn. The
  772. // message should have no payload. This method does not use the scratch
  773. // buffer.
  774. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  775. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  776. return err
  777. }
  778. // saveMessage memorizes a message and its buffer in the conn struct.
  779. // recvMessage will then return these values on the next call to it. This
  780. // method is useful in cases where you have to see what the next message is
  781. // going to be (e.g. to see whether it's an error or not) but you can't handle
  782. // the message yourself.
  783. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  784. if cn.saveMessageType != 0 {
  785. cn.bad = true
  786. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  787. }
  788. cn.saveMessageType = typ
  789. cn.saveMessageBuffer = *buf
  790. }
  791. // recvMessage receives any message from the backend, or returns an error if
  792. // a problem occurred while reading the message.
  793. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  794. // workaround for a QueryRow bug, see exec
  795. if cn.saveMessageType != 0 {
  796. t := cn.saveMessageType
  797. *r = cn.saveMessageBuffer
  798. cn.saveMessageType = 0
  799. cn.saveMessageBuffer = nil
  800. return t, nil
  801. }
  802. x := cn.scratch[:5]
  803. _, err := io.ReadFull(cn.buf, x)
  804. if err != nil {
  805. return 0, err
  806. }
  807. // read the type and length of the message that follows
  808. t := x[0]
  809. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  810. var y []byte
  811. if n <= len(cn.scratch) {
  812. y = cn.scratch[:n]
  813. } else {
  814. y = make([]byte, n)
  815. }
  816. _, err = io.ReadFull(cn.buf, y)
  817. if err != nil {
  818. return 0, err
  819. }
  820. *r = y
  821. return t, nil
  822. }
  823. // recv receives a message from the backend, but if an error happened while
  824. // reading the message or the received message was an ErrorResponse, it panics.
  825. // NoticeResponses are ignored. This function should generally be used only
  826. // during the startup sequence.
  827. func (cn *conn) recv() (t byte, r *readBuf) {
  828. for {
  829. var err error
  830. r = &readBuf{}
  831. t, err = cn.recvMessage(r)
  832. if err != nil {
  833. panic(err)
  834. }
  835. switch t {
  836. case 'E':
  837. panic(parseError(r))
  838. case 'N':
  839. // ignore
  840. default:
  841. return
  842. }
  843. }
  844. }
  845. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  846. // the caller to avoid an allocation.
  847. func (cn *conn) recv1Buf(r *readBuf) byte {
  848. for {
  849. t, err := cn.recvMessage(r)
  850. if err != nil {
  851. panic(err)
  852. }
  853. switch t {
  854. case 'A', 'N':
  855. // ignore
  856. case 'S':
  857. cn.processParameterStatus(r)
  858. default:
  859. return t
  860. }
  861. }
  862. }
  863. // recv1 receives a message from the backend, panicking if an error occurs
  864. // while attempting to read it. All asynchronous messages are ignored, with
  865. // the exception of ErrorResponse.
  866. func (cn *conn) recv1() (t byte, r *readBuf) {
  867. r = &readBuf{}
  868. t = cn.recv1Buf(r)
  869. return t, r
  870. }
  871. func (cn *conn) ssl(o values) {
  872. verifyCaOnly := false
  873. tlsConf := tls.Config{}
  874. switch mode := o.Get("sslmode"); mode {
  875. // "require" is the default.
  876. case "", "require":
  877. // We must skip TLS's own verification since it requires full
  878. // verification since Go 1.3.
  879. tlsConf.InsecureSkipVerify = true
  880. // From http://www.postgresql.org/docs/current/static/libpq-ssl.html:
  881. // Note: For backwards compatibility with earlier versions of PostgreSQL, if a
  882. // root CA file exists, the behavior of sslmode=require will be the same as
  883. // that of verify-ca, meaning the server certificate is validated against the
  884. // CA. Relying on this behavior is discouraged, and applications that need
  885. // certificate validation should always use verify-ca or verify-full.
  886. if _, err := os.Stat(o.Get("sslrootcert")); err == nil {
  887. verifyCaOnly = true
  888. } else {
  889. o.Set("sslrootcert", "")
  890. }
  891. case "verify-ca":
  892. // We must skip TLS's own verification since it requires full
  893. // verification since Go 1.3.
  894. tlsConf.InsecureSkipVerify = true
  895. verifyCaOnly = true
  896. case "verify-full":
  897. tlsConf.ServerName = o.Get("host")
  898. case "disable":
  899. return
  900. default:
  901. errorf(`unsupported sslmode %q; only "require" (default), "verify-full", "verify-ca", and "disable" supported`, mode)
  902. }
  903. cn.setupSSLClientCertificates(&tlsConf, o)
  904. cn.setupSSLCA(&tlsConf, o)
  905. w := cn.writeBuf(0)
  906. w.int32(80877103)
  907. cn.sendStartupPacket(w)
  908. b := cn.scratch[:1]
  909. _, err := io.ReadFull(cn.c, b)
  910. if err != nil {
  911. panic(err)
  912. }
  913. if b[0] != 'S' {
  914. panic(ErrSSLNotSupported)
  915. }
  916. client := tls.Client(cn.c, &tlsConf)
  917. if verifyCaOnly {
  918. cn.verifyCA(client, &tlsConf)
  919. }
  920. cn.c = client
  921. }
  922. // verifyCA carries out a TLS handshake to the server and verifies the
  923. // presented certificate against the effective CA, i.e. the one specified in
  924. // sslrootcert or the system CA if sslrootcert was not specified.
  925. func (cn *conn) verifyCA(client *tls.Conn, tlsConf *tls.Config) {
  926. err := client.Handshake()
  927. if err != nil {
  928. panic(err)
  929. }
  930. certs := client.ConnectionState().PeerCertificates
  931. opts := x509.VerifyOptions{
  932. DNSName: client.ConnectionState().ServerName,
  933. Intermediates: x509.NewCertPool(),
  934. Roots: tlsConf.RootCAs,
  935. }
  936. for i, cert := range certs {
  937. if i == 0 {
  938. continue
  939. }
  940. opts.Intermediates.AddCert(cert)
  941. }
  942. _, err = certs[0].Verify(opts)
  943. if err != nil {
  944. panic(err)
  945. }
  946. }
  947. // This function sets up SSL client certificates based on either the "sslkey"
  948. // and "sslcert" settings (possibly set via the environment variables PGSSLKEY
  949. // and PGSSLCERT, respectively), or if they aren't set, from the .postgresql
  950. // directory in the user's home directory. If the file paths are set
  951. // explicitly, the files must exist. The key file must also not be
  952. // world-readable, or this function will panic with
  953. // ErrSSLKeyHasWorldPermissions.
  954. func (cn *conn) setupSSLClientCertificates(tlsConf *tls.Config, o values) {
  955. var missingOk bool
  956. sslkey := o.Get("sslkey")
  957. sslcert := o.Get("sslcert")
  958. if sslkey != "" && sslcert != "" {
  959. // If the user has set an sslkey and sslcert, they *must* exist.
  960. missingOk = false
  961. } else {
  962. // Automatically load certificates from ~/.postgresql.
  963. user, err := user.Current()
  964. if err != nil {
  965. // user.Current() might fail when cross-compiling. We have to
  966. // ignore the error and continue without client certificates, since
  967. // we wouldn't know where to load them from.
  968. return
  969. }
  970. sslkey = filepath.Join(user.HomeDir, ".postgresql", "postgresql.key")
  971. sslcert = filepath.Join(user.HomeDir, ".postgresql", "postgresql.crt")
  972. missingOk = true
  973. }
  974. // Check that both files exist, and report the error or stop, depending on
  975. // which behaviour we want. Note that we don't do any more extensive
  976. // checks than this (such as checking that the paths aren't directories);
  977. // LoadX509KeyPair() will take care of the rest.
  978. keyfinfo, err := os.Stat(sslkey)
  979. if err != nil && missingOk {
  980. return
  981. } else if err != nil {
  982. panic(err)
  983. }
  984. _, err = os.Stat(sslcert)
  985. if err != nil && missingOk {
  986. return
  987. } else if err != nil {
  988. panic(err)
  989. }
  990. // If we got this far, the key file must also have the correct permissions
  991. kmode := keyfinfo.Mode()
  992. if kmode != kmode&0600 {
  993. panic(ErrSSLKeyHasWorldPermissions)
  994. }
  995. cert, err := tls.LoadX509KeyPair(sslcert, sslkey)
  996. if err != nil {
  997. panic(err)
  998. }
  999. tlsConf.Certificates = []tls.Certificate{cert}
  1000. }
  1001. // Sets up RootCAs in the TLS configuration if sslrootcert is set.
  1002. func (cn *conn) setupSSLCA(tlsConf *tls.Config, o values) {
  1003. if sslrootcert := o.Get("sslrootcert"); sslrootcert != "" {
  1004. tlsConf.RootCAs = x509.NewCertPool()
  1005. cert, err := ioutil.ReadFile(sslrootcert)
  1006. if err != nil {
  1007. panic(err)
  1008. }
  1009. ok := tlsConf.RootCAs.AppendCertsFromPEM(cert)
  1010. if !ok {
  1011. errorf("couldn't parse pem in sslrootcert")
  1012. }
  1013. }
  1014. }
  // isDriverSetting reports whether key configures the driver itself, rather
  // than being a run-time parameter to forward to the server in the startup
  // packet.
  func isDriverSetting(key string) bool {
  	switch key {
  	case "host", "port",
  		"password",
  		"sslmode", "sslcert", "sslkey", "sslrootcert",
  		"fallback_application_name",
  		"connect_timeout",
  		"disable_prepared_binary_result",
  		"binary_parameters":
  		return true
  	}
  	return false
  }
  1038. func (cn *conn) startup(o values) {
  1039. w := cn.writeBuf(0)
  1040. w.int32(196608)
  1041. // Send the backend the name of the database we want to connect to, and the
  1042. // user we want to connect as. Additionally, we send over any run-time
  1043. // parameters potentially included in the connection string. If the server
  1044. // doesn't recognize any of them, it will reply with an error.
  1045. for k, v := range o {
  1046. if isDriverSetting(k) {
  1047. // skip options which can't be run-time parameters
  1048. continue
  1049. }
  1050. // The protocol requires us to supply the database name as "database"
  1051. // instead of "dbname".
  1052. if k == "dbname" {
  1053. k = "database"
  1054. }
  1055. w.string(k)
  1056. w.string(v)
  1057. }
  1058. w.string("")
  1059. cn.sendStartupPacket(w)
  1060. for {
  1061. t, r := cn.recv()
  1062. switch t {
  1063. case 'K':
  1064. case 'S':
  1065. cn.processParameterStatus(r)
  1066. case 'R':
  1067. cn.auth(r, o)
  1068. case 'Z':
  1069. cn.processReadyForQuery(r)
  1070. return
  1071. default:
  1072. errorf("unknown response for startup: %q", t)
  1073. }
  1074. }
  1075. }
  1076. func (cn *conn) auth(r *readBuf, o values) {
  1077. switch code := r.int32(); code {
  1078. case 0:
  1079. // OK
  1080. case 3:
  1081. w := cn.writeBuf('p')
  1082. w.string(o.Get("password"))
  1083. cn.send(w)
  1084. t, r := cn.recv()
  1085. if t != 'R' {
  1086. errorf("unexpected password response: %q", t)
  1087. }
  1088. if r.int32() != 0 {
  1089. errorf("unexpected authentication response: %q", t)
  1090. }
  1091. case 5:
  1092. s := string(r.next(4))
  1093. w := cn.writeBuf('p')
  1094. w.string("md5" + md5s(md5s(o.Get("password")+o.Get("user"))+s))
  1095. cn.send(w)
  1096. t, r := cn.recv()
  1097. if t != 'R' {
  1098. errorf("unexpected password response: %q", t)
  1099. }
  1100. if r.int32() != 0 {
  1101. errorf("unexpected authentication response: %q", t)
  1102. }
  1103. default:
  1104. errorf("unknown authentication response: %d", code)
  1105. }
  1106. }
  // format is a wire-protocol data format code.
  type format int

  const (
  	formatText   format = 0
  	formatBinary format = 1
  )

  var (
  	// colFmtDataAllBinary encodes a result-column format list with one
  	// format code of 1 (i.e. all binary): int16 count 1, then int16 code 1.
  	colFmtDataAllBinary = []byte{0, 1, 0, 1}
  	// colFmtDataAllText encodes an empty result-column format list (i.e. all
  	// text): int16 count 0.
  	colFmtDataAllText = []byte{0, 0}
  )
  1114. type stmt struct {
  1115. cn *conn
  1116. name string
  1117. colNames []string
  1118. colFmts []format
  1119. colFmtData []byte
  1120. colTyps []oid.Oid
  1121. paramTyps []oid.Oid
  1122. closed bool
  1123. }
  1124. func (st *stmt) Close() (err error) {
  1125. if st.closed {
  1126. return nil
  1127. }
  1128. if st.cn.bad {
  1129. return driver.ErrBadConn
  1130. }
  1131. defer st.cn.errRecover(&err)
  1132. w := st.cn.writeBuf('C')
  1133. w.byte('S')
  1134. w.string(st.name)
  1135. st.cn.send(w)
  1136. st.cn.send(st.cn.writeBuf('S'))
  1137. t, _ := st.cn.recv1()
  1138. if t != '3' {
  1139. st.cn.bad = true
  1140. errorf("unexpected close response: %q", t)
  1141. }
  1142. st.closed = true
  1143. t, r := st.cn.recv1()
  1144. if t != 'Z' {
  1145. st.cn.bad = true
  1146. errorf("expected ready for query, but got: %q", t)
  1147. }
  1148. st.cn.processReadyForQuery(r)
  1149. return nil
  1150. }
  1151. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1152. if st.cn.bad {
  1153. return nil, driver.ErrBadConn
  1154. }
  1155. defer st.cn.errRecover(&err)
  1156. st.exec(v)
  1157. return &rows{
  1158. cn: st.cn,
  1159. colNames: st.colNames,
  1160. colTyps: st.colTyps,
  1161. colFmts: st.colFmts,
  1162. }, nil
  1163. }
  1164. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1165. if st.cn.bad {
  1166. return nil, driver.ErrBadConn
  1167. }
  1168. defer st.cn.errRecover(&err)
  1169. st.exec(v)
  1170. res, _, err = st.cn.readExecuteResponse("simple query")
  1171. return res, err
  1172. }
  1173. func (st *stmt) exec(v []driver.Value) {
  1174. if len(v) >= 65536 {
  1175. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1176. }
  1177. if len(v) != len(st.paramTyps) {
  1178. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1179. }
  1180. cn := st.cn
  1181. w := cn.writeBuf('B')
  1182. w.byte(0) // unnamed portal
  1183. w.string(st.name)
  1184. if cn.binaryParameters {
  1185. cn.sendBinaryParameters(w, v)
  1186. } else {
  1187. w.int16(0)
  1188. w.int16(len(v))
  1189. for i, x := range v {
  1190. if x == nil {
  1191. w.int32(-1)
  1192. } else {
  1193. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1194. w.int32(len(b))
  1195. w.bytes(b)
  1196. }
  1197. }
  1198. }
  1199. w.bytes(st.colFmtData)
  1200. w.next('E')
  1201. w.byte(0)
  1202. w.int32(0)
  1203. w.next('S')
  1204. cn.send(w)
  1205. cn.readBindResponse()
  1206. cn.postExecuteWorkaround()
  1207. }
  1208. func (st *stmt) NumInput() int {
  1209. return len(st.paramTyps)
  1210. }
  1211. // parseComplete parses the "command tag" from a CommandComplete message, and
  1212. // returns the number of rows affected (if applicable) and a string
  1213. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1214. // command tag could not be parsed, parseComplete panics.
  1215. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1216. commandsWithAffectedRows := []string{
  1217. "SELECT ",
  1218. // INSERT is handled below
  1219. "UPDATE ",
  1220. "DELETE ",
  1221. "FETCH ",
  1222. "MOVE ",
  1223. "COPY ",
  1224. }
  1225. var affectedRows *string
  1226. for _, tag := range commandsWithAffectedRows {
  1227. if strings.HasPrefix(commandTag, tag) {
  1228. t := commandTag[len(tag):]
  1229. affectedRows = &t
  1230. commandTag = tag[:len(tag)-1]
  1231. break
  1232. }
  1233. }
  1234. // INSERT also includes the oid of the inserted row in its command tag.
  1235. // Oids in user tables are deprecated, and the oid is only returned when
  1236. // exactly one row is inserted, so it's unlikely to be of value to any
  1237. // real-world application and we can ignore it.
  1238. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1239. parts := strings.Split(commandTag, " ")
  1240. if len(parts) != 3 {
  1241. cn.bad = true
  1242. errorf("unexpected INSERT command tag %s", commandTag)
  1243. }
  1244. affectedRows = &parts[len(parts)-1]
  1245. commandTag = "INSERT"
  1246. }
  1247. // There should be no affected rows attached to the tag, just return it
  1248. if affectedRows == nil {
  1249. return driver.RowsAffected(0), commandTag
  1250. }
  1251. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1252. if err != nil {
  1253. cn.bad = true
  1254. errorf("could not parse commandTag: %s", err)
  1255. }
  1256. return driver.RowsAffected(n), commandTag
  1257. }
  1258. type rows struct {
  1259. cn *conn
  1260. colNames []string
  1261. colTyps []oid.Oid
  1262. colFmts []format
  1263. done bool
  1264. rb readBuf
  1265. }
  1266. func (rs *rows) Close() error {
  1267. // no need to look at cn.bad as Next() will
  1268. for {
  1269. err := rs.Next(nil)
  1270. switch err {
  1271. case nil:
  1272. case io.EOF:
  1273. return nil
  1274. default:
  1275. return err
  1276. }
  1277. }
  1278. }
  1279. func (rs *rows) Columns() []string {
  1280. return rs.colNames
  1281. }
  1282. func (rs *rows) Next(dest []driver.Value) (err error) {
  1283. if rs.done {
  1284. return io.EOF
  1285. }
  1286. conn := rs.cn
  1287. if conn.bad {
  1288. return driver.ErrBadConn
  1289. }
  1290. defer conn.errRecover(&err)
  1291. for {
  1292. t := conn.recv1Buf(&rs.rb)
  1293. switch t {
  1294. case 'E':
  1295. err = parseError(&rs.rb)
  1296. case 'C', 'I':
  1297. continue
  1298. case 'Z':
  1299. conn.processReadyForQuery(&rs.rb)
  1300. rs.done = true
  1301. if err != nil {
  1302. return err
  1303. }
  1304. return io.EOF
  1305. case 'D':
  1306. n := rs.rb.int16()
  1307. if err != nil {
  1308. conn.bad = true
  1309. errorf("unexpected DataRow after error %s", err)
  1310. }
  1311. if n < len(dest) {
  1312. dest = dest[:n]
  1313. }
  1314. for i := range dest {
  1315. l := rs.rb.int32()
  1316. if l == -1 {
  1317. dest[i] = nil
  1318. continue
  1319. }
  1320. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i], rs.colFmts[i])
  1321. }
  1322. return
  1323. default:
  1324. errorf("unexpected message after execute: %q", t)
  1325. }
  1326. }
  1327. }
  // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  // used as part of an SQL statement. For example:
  //
  //	tblname := "my_table"
  //	data := "my_data"
  //	err = db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", pq.QuoteIdentifier(tblname)), data)
  //
  // Any double quotes in name are doubled, and the quoted identifier is case
  // sensitive when used in a query. If the input contains a zero byte, the
  // result is truncated immediately before it.
  func QuoteIdentifier(name string) string {
  	if nul := strings.IndexByte(name, 0); nul >= 0 {
  		name = name[:nul]
  	}
  	escaped := strings.Replace(name, `"`, `""`, -1)
  	return `"` + escaped + `"`
  }
  // md5s returns the lowercase hexadecimal MD5 digest of s.
  func md5s(s string) string {
  	sum := md5.Sum([]byte(s))
  	return fmt.Sprintf("%x", sum[:])
  }
  1350. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1351. // Do one pass over the parameters to see if we're going to send any of
  1352. // them over in binary. If we are, create a paramFormats array at the
  1353. // same time.
  1354. var paramFormats []int
  1355. for i, x := range args {
  1356. _, ok := x.([]byte)
  1357. if ok {
  1358. if paramFormats == nil {
  1359. paramFormats = make([]int, len(args))
  1360. }
  1361. paramFormats[i] = 1
  1362. }
  1363. }
  1364. if paramFormats == nil {
  1365. b.int16(0)
  1366. } else {
  1367. b.int16(len(paramFormats))
  1368. for _, x := range paramFormats {
  1369. b.int16(x)
  1370. }
  1371. }
  1372. b.int16(len(args))
  1373. for _, x := range args {
  1374. if x == nil {
  1375. b.int32(-1)
  1376. } else {
  1377. datum := binaryEncode(&cn.parameterStatus, x)
  1378. b.int32(len(datum))
  1379. b.bytes(datum)
  1380. }
  1381. }
  1382. }
  1383. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1384. if len(args) >= 65536 {
  1385. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1386. }
  1387. b := cn.writeBuf('P')
  1388. b.byte(0) // unnamed statement
  1389. b.string(query)
  1390. b.int16(0)
  1391. b.next('B')
  1392. b.int16(0) // unnamed portal and statement
  1393. cn.sendBinaryParameters(b, args)
  1394. b.bytes(colFmtDataAllText)
  1395. b.next('D')
  1396. b.byte('P')
  1397. b.byte(0) // unnamed portal
  1398. b.next('E')
  1399. b.byte(0)
  1400. b.int32(0)
  1401. b.next('S')
  1402. cn.send(b)
  1403. }
  1404. func (c *conn) processParameterStatus(r *readBuf) {
  1405. var err error
  1406. param := r.string()
  1407. switch param {
  1408. case "server_version":
  1409. var major1 int
  1410. var major2 int
  1411. var minor int
  1412. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1413. if err == nil {
  1414. c.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1415. }
  1416. case "TimeZone":
  1417. c.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1418. if err != nil {
  1419. c.parameterStatus.currentLocation = nil
  1420. }
  1421. default:
  1422. // ignore
  1423. }
  1424. }
  1425. func (c *conn) processReadyForQuery(r *readBuf) {
  1426. c.txnStatus = transactionStatus(r.byte())
  1427. }
  1428. func (cn *conn) readReadyForQuery() {
  1429. t, r := cn.recv1()
  1430. switch t {
  1431. case 'Z':
  1432. cn.processReadyForQuery(r)
  1433. return
  1434. default:
  1435. cn.bad = true
  1436. errorf("unexpected message %q; expected ReadyForQuery", t)
  1437. }
  1438. }
  1439. func (cn *conn) readParseResponse() {
  1440. t, r := cn.recv1()
  1441. switch t {
  1442. case '1':
  1443. return
  1444. case 'E':
  1445. err := parseError(r)
  1446. cn.readReadyForQuery()
  1447. panic(err)
  1448. default:
  1449. cn.bad = true
  1450. errorf("unexpected Parse response %q", t)
  1451. }
  1452. }
  1453. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []oid.Oid) {
  1454. for {
  1455. t, r := cn.recv1()
  1456. switch t {
  1457. case 't':
  1458. nparams := r.int16()
  1459. paramTyps = make([]oid.Oid, nparams)
  1460. for i := range paramTyps {
  1461. paramTyps[i] = r.oid()
  1462. }
  1463. case 'n':
  1464. return paramTyps, nil, nil
  1465. case 'T':
  1466. colNames, colTyps = parseStatementRowDescribe(r)
  1467. return paramTyps, colNames, colTyps
  1468. case 'E':
  1469. err := parseError(r)
  1470. cn.readReadyForQuery()
  1471. panic(err)
  1472. default:
  1473. cn.bad = true
  1474. errorf("unexpected Describe statement response %q", t)
  1475. }
  1476. }
  1477. }
  1478. func (cn *conn) readPortalDescribeResponse() (colNames []string, colFmts []format, colTyps []oid.Oid) {
  1479. t, r := cn.recv1()
  1480. switch t {
  1481. case 'T':
  1482. return parsePortalRowDescribe(r)
  1483. case 'n':
  1484. return nil, nil, nil
  1485. case 'E':
  1486. err := parseError(r)
  1487. cn.readReadyForQuery()
  1488. panic(err)
  1489. default:
  1490. cn.bad = true
  1491. errorf("unexpected Describe response %q", t)
  1492. }
  1493. panic("not reached")
  1494. }
  1495. func (cn *conn) readBindResponse() {
  1496. t, r := cn.recv1()
  1497. switch t {
  1498. case '2':
  1499. return
  1500. case 'E':
  1501. err := parseError(r)
  1502. cn.readReadyForQuery()
  1503. panic(err)
  1504. default:
  1505. cn.bad = true
  1506. errorf("unexpected Bind response %q", t)
  1507. }
  1508. }
  1509. func (cn *conn) postExecuteWorkaround() {
  1510. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1511. // any errors from rows.Next, which masks errors that happened during the
  1512. // execution of the query. To avoid the problem in common cases, we wait
  1513. // here for one more message from the database. If it's not an error the
  1514. // query will likely succeed (or perhaps has already, if it's a
  1515. // CommandComplete), so we push the message into the conn struct; recv1
  1516. // will return it as the next message for rows.Next or rows.Close.
  1517. // However, if it's an error, we wait until ReadyForQuery and then return
  1518. // the error to our caller.
  1519. for {
  1520. t, r := cn.recv1()
  1521. switch t {
  1522. case 'E':
  1523. err := parseError(r)
  1524. cn.readReadyForQuery()
  1525. panic(err)
  1526. case 'C', 'D', 'I':
  1527. // the query didn't fail, but we can't process this message
  1528. cn.saveMessage(t, r)
  1529. return
  1530. default:
  1531. cn.bad = true
  1532. errorf("unexpected message during extended query execution: %q", t)
  1533. }
  1534. }
  1535. }
  1536. // Only for Exec(), since we ignore the returned data
  1537. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1538. for {
  1539. t, r := cn.recv1()
  1540. switch t {
  1541. case 'C':
  1542. if err != nil {
  1543. cn.bad = true
  1544. errorf("unexpected CommandComplete after error %s", err)
  1545. }
  1546. res, commandTag = cn.parseComplete(r.string())
  1547. case 'Z':
  1548. cn.processReadyForQuery(r)
  1549. if res == nil && err == nil {
  1550. err = errUnexpectedReady
  1551. }
  1552. return res, commandTag, err
  1553. case 'E':
  1554. err = parseError(r)
  1555. case 'T', 'D', 'I':
  1556. if err != nil {
  1557. cn.bad = true
  1558. errorf("unexpected %q after error %s", t, err)
  1559. }
  1560. if t == 'I' {
  1561. res = emptyRows
  1562. }
  1563. // ignore any results
  1564. default:
  1565. cn.bad = true
  1566. errorf("unknown %s response: %q", protocolState, t)
  1567. }
  1568. }
  1569. }
  1570. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []oid.Oid) {
  1571. n := r.int16()
  1572. colNames = make([]string, n)
  1573. colTyps = make([]oid.Oid, n)
  1574. for i := range colNames {
  1575. colNames[i] = r.string()
  1576. r.next(6)
  1577. colTyps[i] = r.oid()
  1578. r.next(6)
  1579. // format code not known when describing a statement; always 0
  1580. r.next(2)
  1581. }
  1582. return
  1583. }
  1584. func parsePortalRowDescribe(r *readBuf) (colNames []string, colFmts []format, colTyps []oid.Oid) {
  1585. n := r.int16()
  1586. colNames = make([]string, n)
  1587. colFmts = make([]format, n)
  1588. colTyps = make([]oid.Oid, n)
  1589. for i := range colNames {
  1590. colNames[i] = r.string()
  1591. r.next(6)
  1592. colTyps[i] = r.oid()
  1593. r.next(6)
  1594. colFmts[i] = format(r.int16())
  1595. }
  1596. return
  1597. }
  // parseEnviron tries to mimic some of libpq's environment handling.
  //
  // To ease testing, it does not read os.Environ directly, but it is designed
  // to accept its output: "KEY=VALUE" strings. Entries without "=" are not
  // assignments and are skipped. They used to make recognized keys index past
  // the split and panic with index out of range.
  //
  // Environment-set connection information is meant to have higher
  // precedence than a library default but lower than any explicitly passed
  // information (such as in the URL or connection string).
  //
  // Unsupported but well-defined PG* variables panic; they should be unset
  // before execution. Unrelated variables are ignored.
  func parseEnviron(env []string) (out map[string]string) {
  	out = make(map[string]string)
  	for _, v := range env {
  		parts := strings.SplitN(v, "=", 2)
  		if len(parts) != 2 {
  			continue
  		}
  		key, value := parts[0], parts[1]

  		// The order of these is the same as is seen in the PostgreSQL 9.1
  		// manual.
  		var option string
  		switch key {
  		case "PGHOST":
  			option = "host"
  		case "PGHOSTADDR":
  			panic(fmt.Sprintf("setting %v not supported", key))
  		case "PGPORT":
  			option = "port"
  		case "PGDATABASE":
  			option = "dbname"
  		case "PGUSER":
  			option = "user"
  		case "PGPASSWORD":
  			option = "password"
  		case "PGSERVICE", "PGSERVICEFILE", "PGREALM":
  			panic(fmt.Sprintf("setting %v not supported", key))
  		case "PGOPTIONS":
  			option = "options"
  		case "PGAPPNAME":
  			option = "application_name"
  		case "PGSSLMODE":
  			option = "sslmode"
  		case "PGSSLCERT":
  			option = "sslcert"
  		case "PGSSLKEY":
  			option = "sslkey"
  		case "PGSSLROOTCERT":
  			option = "sslrootcert"
  		case "PGREQUIRESSL", "PGSSLCRL", "PGREQUIREPEER", "PGKRBSRVNAME", "PGGSSLIB":
  			panic(fmt.Sprintf("setting %v not supported", key))
  		case "PGCONNECT_TIMEOUT":
  			option = "connect_timeout"
  		case "PGCLIENTENCODING":
  			option = "client_encoding"
  		case "PGDATESTYLE":
  			option = "datestyle"
  		case "PGTZ":
  			option = "timezone"
  		case "PGGEQO":
  			option = "geqo"
  		case "PGSYSCONFDIR", "PGLOCALEDIR":
  			panic(fmt.Sprintf("setting %v not supported", key))
  		default:
  			continue
  		}
  		out[option] = value
  	}
  	return out
  }
  1671. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1672. func isUTF8(name string) bool {
  1673. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1674. s := strings.Map(alnumLowerASCII, name)
  1675. return s == "utf8" || s == "unicode"
  1676. }
  // alnumLowerASCII is a strings.Map mapping: it lower-cases ASCII letters,
  // keeps ASCII digits, and returns -1 (drop) for every other rune.
  func alnumLowerASCII(ch rune) rune {
  	switch {
  	case ch >= 'A' && ch <= 'Z':
  		return ch - 'A' + 'a'
  	case ch >= 'a' && ch <= 'z', ch >= '0' && ch <= '9':
  		return ch
  	default:
  		return -1 // discard
  	}
  }