encode.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589
  1. package pq
  2. import (
  3. "bytes"
  4. "database/sql/driver"
  5. "encoding/binary"
  6. "encoding/hex"
  7. "errors"
  8. "fmt"
  9. "math"
  10. "strconv"
  11. "strings"
  12. "sync"
  13. "time"
  14. "github.com/lib/pq/oid"
  15. )
  16. func binaryEncode(parameterStatus *parameterStatus, x interface{}) []byte {
  17. switch v := x.(type) {
  18. case []byte:
  19. return v
  20. default:
  21. return encode(parameterStatus, x, oid.T_unknown)
  22. }
  23. }
  24. func encode(parameterStatus *parameterStatus, x interface{}, pgtypOid oid.Oid) []byte {
  25. switch v := x.(type) {
  26. case int64:
  27. return strconv.AppendInt(nil, v, 10)
  28. case float64:
  29. return strconv.AppendFloat(nil, v, 'f', -1, 64)
  30. case []byte:
  31. if pgtypOid == oid.T_bytea {
  32. return encodeBytea(parameterStatus.serverVersion, v)
  33. }
  34. return v
  35. case string:
  36. if pgtypOid == oid.T_bytea {
  37. return encodeBytea(parameterStatus.serverVersion, []byte(v))
  38. }
  39. return []byte(v)
  40. case bool:
  41. return strconv.AppendBool(nil, v)
  42. case time.Time:
  43. return formatTs(v)
  44. default:
  45. errorf("encode: unknown type for %T", v)
  46. }
  47. panic("not reached")
  48. }
  49. func decode(parameterStatus *parameterStatus, s []byte, typ oid.Oid, f format) interface{} {
  50. switch f {
  51. case formatBinary:
  52. return binaryDecode(parameterStatus, s, typ)
  53. case formatText:
  54. return textDecode(parameterStatus, s, typ)
  55. default:
  56. panic("not reached")
  57. }
  58. }
  59. func binaryDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
  60. switch typ {
  61. case oid.T_bytea:
  62. return s
  63. case oid.T_int8:
  64. return int64(binary.BigEndian.Uint64(s))
  65. case oid.T_int4:
  66. return int64(int32(binary.BigEndian.Uint32(s)))
  67. case oid.T_int2:
  68. return int64(int16(binary.BigEndian.Uint16(s)))
  69. default:
  70. errorf("don't know how to decode binary parameter of type %d", uint32(typ))
  71. }
  72. panic("not reached")
  73. }
  74. func textDecode(parameterStatus *parameterStatus, s []byte, typ oid.Oid) interface{} {
  75. switch typ {
  76. case oid.T_char, oid.T_varchar, oid.T_text:
  77. return string(s)
  78. case oid.T_bytea:
  79. b, err := parseBytea(s)
  80. if err != nil {
  81. errorf("%s", err)
  82. }
  83. return b
  84. case oid.T_timestamptz:
  85. return parseTs(parameterStatus.currentLocation, string(s))
  86. case oid.T_timestamp, oid.T_date:
  87. return parseTs(nil, string(s))
  88. case oid.T_time:
  89. return mustParse("15:04:05", typ, s)
  90. case oid.T_timetz:
  91. return mustParse("15:04:05-07", typ, s)
  92. case oid.T_bool:
  93. return s[0] == 't'
  94. case oid.T_int8, oid.T_int4, oid.T_int2:
  95. i, err := strconv.ParseInt(string(s), 10, 64)
  96. if err != nil {
  97. errorf("%s", err)
  98. }
  99. return i
  100. case oid.T_float4, oid.T_float8:
  101. bits := 64
  102. if typ == oid.T_float4 {
  103. bits = 32
  104. }
  105. f, err := strconv.ParseFloat(string(s), bits)
  106. if err != nil {
  107. errorf("%s", err)
  108. }
  109. return f
  110. }
  111. return s
  112. }
  113. // appendEncodedText encodes item in text format as required by COPY
  114. // and appends to buf
  115. func appendEncodedText(parameterStatus *parameterStatus, buf []byte, x interface{}) []byte {
  116. switch v := x.(type) {
  117. case int64:
  118. return strconv.AppendInt(buf, v, 10)
  119. case float64:
  120. return strconv.AppendFloat(buf, v, 'f', -1, 64)
  121. case []byte:
  122. encodedBytea := encodeBytea(parameterStatus.serverVersion, v)
  123. return appendEscapedText(buf, string(encodedBytea))
  124. case string:
  125. return appendEscapedText(buf, v)
  126. case bool:
  127. return strconv.AppendBool(buf, v)
  128. case time.Time:
  129. return append(buf, formatTs(v)...)
  130. case nil:
  131. return append(buf, "\\N"...)
  132. default:
  133. errorf("encode: unknown type for %T", v)
  134. }
  135. panic("not reached")
  136. }
  137. func appendEscapedText(buf []byte, text string) []byte {
  138. escapeNeeded := false
  139. startPos := 0
  140. var c byte
  141. // check if we need to escape
  142. for i := 0; i < len(text); i++ {
  143. c = text[i]
  144. if c == '\\' || c == '\n' || c == '\r' || c == '\t' {
  145. escapeNeeded = true
  146. startPos = i
  147. break
  148. }
  149. }
  150. if !escapeNeeded {
  151. return append(buf, text...)
  152. }
  153. // copy till first char to escape, iterate the rest
  154. result := append(buf, text[:startPos]...)
  155. for i := startPos; i < len(text); i++ {
  156. c = text[i]
  157. switch c {
  158. case '\\':
  159. result = append(result, '\\', '\\')
  160. case '\n':
  161. result = append(result, '\\', 'n')
  162. case '\r':
  163. result = append(result, '\\', 'r')
  164. case '\t':
  165. result = append(result, '\\', 't')
  166. default:
  167. result = append(result, c)
  168. }
  169. }
  170. return result
  171. }
  172. func mustParse(f string, typ oid.Oid, s []byte) time.Time {
  173. str := string(s)
  174. // check for a 30-minute-offset timezone
  175. if (typ == oid.T_timestamptz || typ == oid.T_timetz) &&
  176. str[len(str)-3] == ':' {
  177. f += ":00"
  178. }
  179. t, err := time.Parse(f, str)
  180. if err != nil {
  181. errorf("decode: %s", err)
  182. }
  183. return t
  184. }
  185. var errInvalidTimestamp = errors.New("invalid timestamp")
  186. type timestampParser struct {
  187. err error
  188. }
  189. func (p *timestampParser) expect(str string, char byte, pos int) {
  190. if p.err != nil {
  191. return
  192. }
  193. if pos+1 > len(str) {
  194. p.err = errInvalidTimestamp
  195. return
  196. }
  197. if c := str[pos]; c != char && p.err == nil {
  198. p.err = fmt.Errorf("expected '%v' at position %v; got '%v'", char, pos, c)
  199. }
  200. }
  201. func (p *timestampParser) mustAtoi(str string, begin int, end int) int {
  202. if p.err != nil {
  203. return 0
  204. }
  205. if begin < 0 || end < 0 || begin > end || end > len(str) {
  206. p.err = errInvalidTimestamp
  207. return 0
  208. }
  209. result, err := strconv.Atoi(str[begin:end])
  210. if err != nil {
  211. if p.err == nil {
  212. p.err = fmt.Errorf("expected number; got '%v'", str)
  213. }
  214. return 0
  215. }
  216. return result
  217. }
  218. // The location cache caches the time zones typically used by the client.
  219. type locationCache struct {
  220. cache map[int]*time.Location
  221. lock sync.Mutex
  222. }
  223. // All connections share the same list of timezones. Benchmarking shows that
  224. // about 5% speed could be gained by putting the cache in the connection and
  225. // losing the mutex, at the cost of a small amount of memory and a somewhat
  226. // significant increase in code complexity.
  227. var globalLocationCache = newLocationCache()
  228. func newLocationCache() *locationCache {
  229. return &locationCache{cache: make(map[int]*time.Location)}
  230. }
  231. // Returns the cached timezone for the specified offset, creating and caching
  232. // it if necessary.
  233. func (c *locationCache) getLocation(offset int) *time.Location {
  234. c.lock.Lock()
  235. defer c.lock.Unlock()
  236. location, ok := c.cache[offset]
  237. if !ok {
  238. location = time.FixedZone("", offset)
  239. c.cache[offset] = location
  240. }
  241. return location
  242. }
  243. var infinityTsEnabled = false
  244. var infinityTsNegative time.Time
  245. var infinityTsPositive time.Time
  246. const (
  247. infinityTsEnabledAlready = "pq: infinity timestamp enabled already"
  248. infinityTsNegativeMustBeSmaller = "pq: infinity timestamp: negative value must be smaller (before) than positive"
  249. )
  250. // EnableInfinityTs controls the handling of Postgres' "-infinity" and
  251. // "infinity" "timestamp"s.
  252. //
  253. // If EnableInfinityTs is not called, "-infinity" and "infinity" will return
  254. // []byte("-infinity") and []byte("infinity") respectively, and potentially
  255. // cause error "sql: Scan error on column index 0: unsupported driver -> Scan
  256. // pair: []uint8 -> *time.Time", when scanning into a time.Time value.
  257. //
  258. // Once EnableInfinityTs has been called, all connections created using this
  259. // driver will decode Postgres' "-infinity" and "infinity" for "timestamp",
  260. // "timestamp with time zone" and "date" types to the predefined minimum and
  261. // maximum times, respectively. When encoding time.Time values, any time which
  262. // equals or precedes the predefined minimum time will be encoded to
  263. // "-infinity". Any values at or past the maximum time will similarly be
  264. // encoded to "infinity".
  265. //
  266. // If EnableInfinityTs is called with negative >= positive, it will panic.
  267. // Calling EnableInfinityTs after a connection has been established results in
  268. // undefined behavior. If EnableInfinityTs is called more than once, it will
  269. // panic.
  270. func EnableInfinityTs(negative time.Time, positive time.Time) {
  271. if infinityTsEnabled {
  272. panic(infinityTsEnabledAlready)
  273. }
  274. if !negative.Before(positive) {
  275. panic(infinityTsNegativeMustBeSmaller)
  276. }
  277. infinityTsEnabled = true
  278. infinityTsNegative = negative
  279. infinityTsPositive = positive
  280. }
  281. /*
  282. * Testing might want to toggle infinityTsEnabled
  283. */
  284. func disableInfinityTs() {
  285. infinityTsEnabled = false
  286. }
  287. // This is a time function specific to the Postgres default DateStyle
  288. // setting ("ISO, MDY"), the only one we currently support. This
  289. // accounts for the discrepancies between the parsing available with
  290. // time.Parse and the Postgres date formatting quirks.
  291. func parseTs(currentLocation *time.Location, str string) interface{} {
  292. switch str {
  293. case "-infinity":
  294. if infinityTsEnabled {
  295. return infinityTsNegative
  296. }
  297. return []byte(str)
  298. case "infinity":
  299. if infinityTsEnabled {
  300. return infinityTsPositive
  301. }
  302. return []byte(str)
  303. }
  304. t, err := ParseTimestamp(currentLocation, str)
  305. if err != nil {
  306. panic(err)
  307. }
  308. return t
  309. }
  310. // ParseTimestamp parses Postgres' text format. It returns a time.Time in
  311. // currentLocation iff that time's offset agrees with the offset sent from the
  312. // Postgres server. Otherwise, ParseTimestamp returns a time.Time with the
  313. // fixed offset offset provided by the Postgres server.
  314. func ParseTimestamp(currentLocation *time.Location, str string) (time.Time, error) {
  315. p := timestampParser{}
  316. monSep := strings.IndexRune(str, '-')
  317. // this is Gregorian year, not ISO Year
  318. // In Gregorian system, the year 1 BC is followed by AD 1
  319. year := p.mustAtoi(str, 0, monSep)
  320. daySep := monSep + 3
  321. month := p.mustAtoi(str, monSep+1, daySep)
  322. p.expect(str, '-', daySep)
  323. timeSep := daySep + 3
  324. day := p.mustAtoi(str, daySep+1, timeSep)
  325. var hour, minute, second int
  326. if len(str) > monSep+len("01-01")+1 {
  327. p.expect(str, ' ', timeSep)
  328. minSep := timeSep + 3
  329. p.expect(str, ':', minSep)
  330. hour = p.mustAtoi(str, timeSep+1, minSep)
  331. secSep := minSep + 3
  332. p.expect(str, ':', secSep)
  333. minute = p.mustAtoi(str, minSep+1, secSep)
  334. secEnd := secSep + 3
  335. second = p.mustAtoi(str, secSep+1, secEnd)
  336. }
  337. remainderIdx := monSep + len("01-01 00:00:00") + 1
  338. // Three optional (but ordered) sections follow: the
  339. // fractional seconds, the time zone offset, and the BC
  340. // designation. We set them up here and adjust the other
  341. // offsets if the preceding sections exist.
  342. nanoSec := 0
  343. tzOff := 0
  344. if remainderIdx < len(str) && str[remainderIdx] == '.' {
  345. fracStart := remainderIdx + 1
  346. fracOff := strings.IndexAny(str[fracStart:], "-+ ")
  347. if fracOff < 0 {
  348. fracOff = len(str) - fracStart
  349. }
  350. fracSec := p.mustAtoi(str, fracStart, fracStart+fracOff)
  351. nanoSec = fracSec * (1000000000 / int(math.Pow(10, float64(fracOff))))
  352. remainderIdx += fracOff + 1
  353. }
  354. if tzStart := remainderIdx; tzStart < len(str) && (str[tzStart] == '-' || str[tzStart] == '+') {
  355. // time zone separator is always '-' or '+' (UTC is +00)
  356. var tzSign int
  357. switch c := str[tzStart]; c {
  358. case '-':
  359. tzSign = -1
  360. case '+':
  361. tzSign = +1
  362. default:
  363. return time.Time{}, fmt.Errorf("expected '-' or '+' at position %v; got %v", tzStart, c)
  364. }
  365. tzHours := p.mustAtoi(str, tzStart+1, tzStart+3)
  366. remainderIdx += 3
  367. var tzMin, tzSec int
  368. if remainderIdx < len(str) && str[remainderIdx] == ':' {
  369. tzMin = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
  370. remainderIdx += 3
  371. }
  372. if remainderIdx < len(str) && str[remainderIdx] == ':' {
  373. tzSec = p.mustAtoi(str, remainderIdx+1, remainderIdx+3)
  374. remainderIdx += 3
  375. }
  376. tzOff = tzSign * ((tzHours * 60 * 60) + (tzMin * 60) + tzSec)
  377. }
  378. var isoYear int
  379. if remainderIdx+3 <= len(str) && str[remainderIdx:remainderIdx+3] == " BC" {
  380. isoYear = 1 - year
  381. remainderIdx += 3
  382. } else {
  383. isoYear = year
  384. }
  385. if remainderIdx < len(str) {
  386. return time.Time{}, fmt.Errorf("expected end of input, got %v", str[remainderIdx:])
  387. }
  388. t := time.Date(isoYear, time.Month(month), day,
  389. hour, minute, second, nanoSec,
  390. globalLocationCache.getLocation(tzOff))
  391. if currentLocation != nil {
  392. // Set the location of the returned Time based on the session's
  393. // TimeZone value, but only if the local time zone database agrees with
  394. // the remote database on the offset.
  395. lt := t.In(currentLocation)
  396. _, newOff := lt.Zone()
  397. if newOff == tzOff {
  398. t = lt
  399. }
  400. }
  401. return t, p.err
  402. }
  403. // formatTs formats t into a format postgres understands.
  404. func formatTs(t time.Time) []byte {
  405. if infinityTsEnabled {
  406. // t <= -infinity : ! (t > -infinity)
  407. if !t.After(infinityTsNegative) {
  408. return []byte("-infinity")
  409. }
  410. // t >= infinity : ! (!t < infinity)
  411. if !t.Before(infinityTsPositive) {
  412. return []byte("infinity")
  413. }
  414. }
  415. return FormatTimestamp(t)
  416. }
  417. // FormatTimestamp formats t into Postgres' text format for timestamps.
  418. func FormatTimestamp(t time.Time) []byte {
  419. // Need to send dates before 0001 A.D. with " BC" suffix, instead of the
  420. // minus sign preferred by Go.
  421. // Beware, "0000" in ISO is "1 BC", "-0001" is "2 BC" and so on
  422. bc := false
  423. if t.Year() <= 0 {
  424. // flip year sign, and add 1, e.g: "0" will be "1", and "-10" will be "11"
  425. t = t.AddDate((-t.Year())*2+1, 0, 0)
  426. bc = true
  427. }
  428. b := []byte(t.Format(time.RFC3339Nano))
  429. _, offset := t.Zone()
  430. offset = offset % 60
  431. if offset != 0 {
  432. // RFC3339Nano already printed the minus sign
  433. if offset < 0 {
  434. offset = -offset
  435. }
  436. b = append(b, ':')
  437. if offset < 10 {
  438. b = append(b, '0')
  439. }
  440. b = strconv.AppendInt(b, int64(offset), 10)
  441. }
  442. if bc {
  443. b = append(b, " BC"...)
  444. }
  445. return b
  446. }
  447. // Parse a bytea value received from the server. Both "hex" and the legacy
  448. // "escape" format are supported.
  449. func parseBytea(s []byte) (result []byte, err error) {
  450. if len(s) >= 2 && bytes.Equal(s[:2], []byte("\\x")) {
  451. // bytea_output = hex
  452. s = s[2:] // trim off leading "\\x"
  453. result = make([]byte, hex.DecodedLen(len(s)))
  454. _, err := hex.Decode(result, s)
  455. if err != nil {
  456. return nil, err
  457. }
  458. } else {
  459. // bytea_output = escape
  460. for len(s) > 0 {
  461. if s[0] == '\\' {
  462. // escaped '\\'
  463. if len(s) >= 2 && s[1] == '\\' {
  464. result = append(result, '\\')
  465. s = s[2:]
  466. continue
  467. }
  468. // '\\' followed by an octal number
  469. if len(s) < 4 {
  470. return nil, fmt.Errorf("invalid bytea sequence %v", s)
  471. }
  472. r, err := strconv.ParseInt(string(s[1:4]), 8, 9)
  473. if err != nil {
  474. return nil, fmt.Errorf("could not parse bytea value: %s", err.Error())
  475. }
  476. result = append(result, byte(r))
  477. s = s[4:]
  478. } else {
  479. // We hit an unescaped, raw byte. Try to read in as many as
  480. // possible in one go.
  481. i := bytes.IndexByte(s, '\\')
  482. if i == -1 {
  483. result = append(result, s...)
  484. break
  485. }
  486. result = append(result, s[:i]...)
  487. s = s[i:]
  488. }
  489. }
  490. }
  491. return result, nil
  492. }
  493. func encodeBytea(serverVersion int, v []byte) (result []byte) {
  494. if serverVersion >= 90000 {
  495. // Use the hex format if we know that the server supports it
  496. result = make([]byte, 2+hex.EncodedLen(len(v)))
  497. result[0] = '\\'
  498. result[1] = 'x'
  499. hex.Encode(result[2:], v)
  500. } else {
  501. // .. or resort to "escape"
  502. for _, b := range v {
  503. if b == '\\' {
  504. result = append(result, '\\', '\\')
  505. } else if b < 0x20 || b > 0x7e {
  506. result = append(result, []byte(fmt.Sprintf("\\%03o", b))...)
  507. } else {
  508. result = append(result, b)
  509. }
  510. }
  511. }
  512. return result
  513. }
  514. // NullTime represents a time.Time that may be null. NullTime implements the
  515. // sql.Scanner interface so it can be used as a scan destination, similar to
  516. // sql.NullString.
  517. type NullTime struct {
  518. Time time.Time
  519. Valid bool // Valid is true if Time is not NULL
  520. }
  521. // Scan implements the Scanner interface.
  522. func (nt *NullTime) Scan(value interface{}) error {
  523. nt.Time, nt.Valid = value.(time.Time)
  524. return nil
  525. }
  526. // Value implements the driver Valuer interface.
  527. func (nt NullTime) Value() (driver.Value, error) {
  528. if !nt.Valid {
  529. return nil, nil
  530. }
  531. return nt.Time, nil
  532. }