pq_driver.go 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119
  1. // Copyright 2015 The Xorm Authors. All rights reserved.
  2. // Use of this source code is governed by a BSD-style
  3. // license that can be found in the LICENSE file.
  4. package xorm
  import (
	"errors"
	"fmt"
	"net"
	"net/url"
	"runtime"
	"sort"
	"strings"

	"github.com/go-xorm/core"
)
  // pqDriver parses data source names for PostgreSQL connections (see its
// Parse method). It holds no state, so its zero value is ready to use.
type pqDriver struct{}
  // values holds connection options parsed from a data source name, keyed by
// option name.
type values map[string]string

// Set stores val under key, replacing any earlier value for that key.
func (m values) Set(key, val string) {
	m[key] = val
}

// Get returns the value stored under key, or "" when key was never set.
func (m values) Get(key string) string {
	if val, ok := m[key]; ok {
		return val
	}
	return ""
}
  // errorf formats s with args (as fmt.Sprintf does), prefixes the text with
// "pq: ", and panics with the result as an error value. It never returns
// normally.
func errorf(s string, args ...interface{}) {
	detail := fmt.Sprintf(s, args...)
	err := fmt.Errorf("pq: %s", detail)
	panic(err)
}
  // parseURL converts a PostgreSQL connection URL such as
// "postgres://user:pass@host:port/dbname?key=value" (scheme "postgres" or
// "postgresql") into the equivalent keyword/value string, e.g.
// "dbname=dbname host=host key=value password=pass port=port user=user".
//
// Each value is escaped with a backslash before spaces, single quotes and
// backslashes. Empty values are omitted. The pairs are sorted so the result
// is deterministic. An error is returned if connstr does not parse as a URL
// or uses a different scheme.
func parseURL(connstr string) (string, error) {
	u, err := url.Parse(connstr)
	if err != nil {
		return "", err
	}

	if u.Scheme != "postgresql" && u.Scheme != "postgres" {
		return "", fmt.Errorf("invalid connection protocol: %s", u.Scheme)
	}

	var kvs []string
	escaper := strings.NewReplacer(` `, `\ `, `'`, `\'`, `\`, `\\`)
	accrue := func(k, v string) {
		if v != "" {
			kvs = append(kvs, k+"="+escaper.Replace(v))
		}
	}

	if u.User != nil {
		accrue("user", u.User.Username())
		// The bool only says whether a password was present; an absent
		// password comes back as "" and is dropped by accrue anyway.
		password, _ := u.User.Password()
		accrue("password", password)
	}

	// net.SplitHostPort understands bracketed IPv6 literals ("[::1]:5432");
	// splitting on the first ':' would mangle them. It fails when there is
	// no port, in which case the whole Host is the host, minus any IPv6
	// brackets.
	host, port, splitErr := net.SplitHostPort(u.Host)
	if splitErr != nil {
		host, port = u.Host, ""
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
	}
	accrue("host", host)
	accrue("port", port)

	if u.Path != "" {
		accrue("dbname", strings.TrimPrefix(u.Path, "/"))
	}

	q := u.Query()
	for k := range q {
		accrue(k, q.Get(k))
	}

	// Map iteration order is random; sorting makes the output deterministic.
	sort.Strings(kvs)
	return strings.Join(kvs, " "), nil
}
  63. func parseOpts(name string, o values) {
  64. if len(name) == 0 {
  65. return
  66. }
  67. name = strings.TrimSpace(name)
  68. ps := strings.Split(name, " ")
  69. for _, p := range ps {
  70. kv := strings.Split(p, "=")
  71. if len(kv) < 2 {
  72. errorf("invalid option: %q", p)
  73. }
  74. o.Set(kv[0], kv[1])
  75. }
  76. }
  77. func (p *pqDriver) Parse(driverName, dataSourceName string) (*core.Uri, error) {
  78. db := &core.Uri{DbType: core.POSTGRES}
  79. o := make(values)
  80. var err error
  81. if strings.HasPrefix(dataSourceName, "postgresql://") || strings.HasPrefix(dataSourceName, "postgres://") {
  82. dataSourceName, err = parseURL(dataSourceName)
  83. if err != nil {
  84. return nil, err
  85. }
  86. }
  87. parseOpts(dataSourceName, o)
  88. db.DbName = o.Get("dbname")
  89. if db.DbName == "" {
  90. return nil, errors.New("dbname is empty")
  91. }
  92. /*db.Schema = o.Get("schema")
  93. if len(db.Schema) == 0 {
  94. db.Schema = "public"
  95. }*/
  96. return db, nil
  97. }