mssql.go 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609
  1. package mssql
  2. import (
  3. "database/sql"
  4. "database/sql/driver"
  5. "encoding/binary"
  6. "errors"
  7. "fmt"
  8. "io"
  9. "math"
  10. "net"
  11. "reflect"
  12. "strings"
  13. "time"
  14. "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
  15. )
  16. var driverInstance = &MssqlDriver{processQueryText: true}
  17. var driverInstanceNoProcess = &MssqlDriver{processQueryText: false}
  18. func init() {
  19. sql.Register("mssql", driverInstance)
  20. sql.Register("sqlserver", driverInstanceNoProcess)
  21. }
  22. // Abstract the dialer for testing and for non-TCP based connections.
  23. type dialer interface {
  24. Dial(addr string) (net.Conn, error)
  25. }
  26. var createDialer func(p *connectParams) dialer
  27. type tcpDialer struct {
  28. nd *net.Dialer
  29. }
  30. func (d tcpDialer) Dial(addr string) (net.Conn, error) {
  31. return d.nd.Dial("tcp", addr)
  32. }
  33. type MssqlDriver struct {
  34. log optionalLogger
  35. processQueryText bool
  36. }
  37. func SetLogger(logger Logger) {
  38. driverInstance.SetLogger(logger)
  39. driverInstanceNoProcess.SetLogger(logger)
  40. }
  41. func (d *MssqlDriver) SetLogger(logger Logger) {
  42. d.log = optionalLogger{logger}
  43. }
  44. type MssqlConn struct {
  45. sess *tdsSession
  46. transactionCtx context.Context
  47. processQueryText bool
  48. }
  49. func (c *MssqlConn) simpleProcessResp(ctx context.Context) error {
  50. tokchan := make(chan tokenStruct, 5)
  51. go processResponse(ctx, c.sess, tokchan)
  52. for tok := range tokchan {
  53. switch token := tok.(type) {
  54. case doneStruct:
  55. if token.isError() {
  56. return token.getError()
  57. }
  58. case error:
  59. return token
  60. }
  61. }
  62. return nil
  63. }
  64. func (c *MssqlConn) Commit() error {
  65. if err := c.sendCommitRequest(); err != nil {
  66. return err
  67. }
  68. return c.simpleProcessResp(c.transactionCtx)
  69. }
  70. func (c *MssqlConn) sendCommitRequest() error {
  71. headers := []headerStruct{
  72. {hdrtype: dataStmHdrTransDescr,
  73. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  74. }
  75. if err := sendCommitXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
  76. if c.sess.logFlags&logErrors != 0 {
  77. c.sess.log.Printf("Failed to send CommitXact with %v", err)
  78. }
  79. return driver.ErrBadConn
  80. }
  81. return nil
  82. }
  83. func (c *MssqlConn) Rollback() error {
  84. if err := c.sendRollbackRequest(); err != nil {
  85. return err
  86. }
  87. return c.simpleProcessResp(c.transactionCtx)
  88. }
  89. func (c *MssqlConn) sendRollbackRequest() error {
  90. headers := []headerStruct{
  91. {hdrtype: dataStmHdrTransDescr,
  92. data: transDescrHdr{c.sess.tranid, 1}.pack()},
  93. }
  94. if err := sendRollbackXact(c.sess.buf, headers, "", 0, 0, ""); err != nil {
  95. if c.sess.logFlags&logErrors != 0 {
  96. c.sess.log.Printf("Failed to send RollbackXact with %v", err)
  97. }
  98. return driver.ErrBadConn
  99. }
  100. return nil
  101. }
  102. func (c *MssqlConn) Begin() (driver.Tx, error) {
  103. return c.begin(context.Background(), isolationUseCurrent)
  104. }
  105. func (c *MssqlConn) begin(ctx context.Context, tdsIsolation isoLevel) (driver.Tx, error) {
  106. err := c.sendBeginRequest(ctx, tdsIsolation)
  107. if err != nil {
  108. return nil, err
  109. }
  110. return c.processBeginResponse(ctx)
  111. }
  112. func (c *MssqlConn) sendBeginRequest(ctx context.Context, tdsIsolation isoLevel) error {
  113. c.transactionCtx = ctx
  114. headers := []headerStruct{
  115. {hdrtype: dataStmHdrTransDescr,
  116. data: transDescrHdr{0, 1}.pack()},
  117. }
  118. if err := sendBeginXact(c.sess.buf, headers, tdsIsolation, ""); err != nil {
  119. if c.sess.logFlags&logErrors != 0 {
  120. c.sess.log.Printf("Failed to send BeginXact with %v", err)
  121. }
  122. return driver.ErrBadConn
  123. }
  124. return nil
  125. }
  126. func (c *MssqlConn) processBeginResponse(ctx context.Context) (driver.Tx, error) {
  127. if err := c.simpleProcessResp(ctx); err != nil {
  128. return nil, err
  129. }
  130. // successful BEGINXACT request will return sess.tranid
  131. // for started transaction
  132. return c, nil
  133. }
  134. func (d *MssqlDriver) Open(dsn string) (driver.Conn, error) {
  135. return d.open(dsn)
  136. }
  137. func (d *MssqlDriver) open(dsn string) (*MssqlConn, error) {
  138. params, err := parseConnectParams(dsn)
  139. if err != nil {
  140. return nil, err
  141. }
  142. sess, err := connect(d.log, params)
  143. if err != nil {
  144. // main server failed, try fail-over partner
  145. if params.failOverPartner == "" {
  146. return nil, err
  147. }
  148. params.host = params.failOverPartner
  149. if params.failOverPort != 0 {
  150. params.port = params.failOverPort
  151. }
  152. sess, err = connect(d.log, params)
  153. if err != nil {
  154. // fail-over partner also failed, now fail
  155. return nil, err
  156. }
  157. }
  158. conn := &MssqlConn{sess, context.Background(), d.processQueryText}
  159. conn.sess.log = d.log
  160. return conn, nil
  161. }
  162. func (c *MssqlConn) Close() error {
  163. return c.sess.buf.transport.Close()
  164. }
  165. type MssqlStmt struct {
  166. c *MssqlConn
  167. query string
  168. paramCount int
  169. notifSub *queryNotifSub
  170. }
  171. type queryNotifSub struct {
  172. msgText string
  173. options string
  174. timeout uint32
  175. }
  176. func (c *MssqlConn) Prepare(query string) (driver.Stmt, error) {
  177. return c.prepareContext(context.Background(), query)
  178. }
  179. func (c *MssqlConn) prepareContext(ctx context.Context, query string) (*MssqlStmt, error) {
  180. paramCount := -1
  181. if c.processQueryText {
  182. query, paramCount = parseParams(query)
  183. }
  184. return &MssqlStmt{c, query, paramCount, nil}, nil
  185. }
  186. func (s *MssqlStmt) Close() error {
  187. return nil
  188. }
  189. func (s *MssqlStmt) SetQueryNotification(id, options string, timeout time.Duration) {
  190. to := uint32(timeout / time.Second)
  191. if to < 1 {
  192. to = 1
  193. }
  194. s.notifSub = &queryNotifSub{id, options, to}
  195. }
  196. func (s *MssqlStmt) NumInput() int {
  197. return s.paramCount
  198. }
  199. func (s *MssqlStmt) sendQuery(args []namedValue) (err error) {
  200. headers := []headerStruct{
  201. {hdrtype: dataStmHdrTransDescr,
  202. data: transDescrHdr{s.c.sess.tranid, 1}.pack()},
  203. }
  204. if s.notifSub != nil {
  205. headers = append(headers, headerStruct{hdrtype: dataStmHdrQueryNotif,
  206. data: queryNotifHdr{s.notifSub.msgText, s.notifSub.options, s.notifSub.timeout}.pack()})
  207. }
  208. // no need to check number of parameters here, it is checked by database/sql
  209. if s.c.sess.logFlags&logSQL != 0 {
  210. s.c.sess.log.Println(s.query)
  211. }
  212. if s.c.sess.logFlags&logParams != 0 && len(args) > 0 {
  213. for i := 0; i < len(args); i++ {
  214. s.c.sess.log.Printf("\t@p%d\t%v\n", i+1, args[i])
  215. }
  216. }
  217. if len(args) == 0 {
  218. if err = sendSqlBatch72(s.c.sess.buf, s.query, headers); err != nil {
  219. if s.c.sess.logFlags&logErrors != 0 {
  220. s.c.sess.log.Printf("Failed to send SqlBatch with %v", err)
  221. }
  222. return driver.ErrBadConn
  223. }
  224. } else {
  225. params := make([]Param, len(args)+2)
  226. decls := make([]string, len(args))
  227. params[0] = makeStrParam(s.query)
  228. for i, val := range args {
  229. params[i+2], err = s.makeParam(val.Value)
  230. if err != nil {
  231. return
  232. }
  233. var name string
  234. if len(val.Name) > 0 {
  235. name = "@" + val.Name
  236. } else {
  237. name = fmt.Sprintf("@p%d", val.Ordinal)
  238. }
  239. params[i+2].Name = name
  240. decls[i] = fmt.Sprintf("%s %s", name, makeDecl(params[i+2].ti))
  241. }
  242. params[1] = makeStrParam(strings.Join(decls, ","))
  243. if err = sendRpc(s.c.sess.buf, headers, Sp_ExecuteSql, 0, params); err != nil {
  244. if s.c.sess.logFlags&logErrors != 0 {
  245. s.c.sess.log.Printf("Failed to send Rpc with %v", err)
  246. }
  247. return driver.ErrBadConn
  248. }
  249. }
  250. return
  251. }
  252. type namedValue struct {
  253. Name string
  254. Ordinal int
  255. Value driver.Value
  256. }
  257. func convertOldArgs(args []driver.Value) []namedValue {
  258. list := make([]namedValue, len(args))
  259. for i, v := range args {
  260. list[i] = namedValue{
  261. Ordinal: i + 1,
  262. Value: v,
  263. }
  264. }
  265. return list
  266. }
  267. func (s *MssqlStmt) Query(args []driver.Value) (driver.Rows, error) {
  268. return s.queryContext(context.Background(), convertOldArgs(args))
  269. }
  270. func (s *MssqlStmt) queryContext(ctx context.Context, args []namedValue) (driver.Rows, error) {
  271. if err := s.sendQuery(args); err != nil {
  272. return nil, err
  273. }
  274. return s.processQueryResponse(ctx)
  275. }
  276. func (s *MssqlStmt) processQueryResponse(ctx context.Context) (res driver.Rows, err error) {
  277. tokchan := make(chan tokenStruct, 5)
  278. ctx, cancel := context.WithCancel(ctx)
  279. go processResponse(ctx, s.c.sess, tokchan)
  280. // process metadata
  281. var cols []columnStruct
  282. loop:
  283. for tok := range tokchan {
  284. switch token := tok.(type) {
  285. // by ignoring DONE token we effectively
  286. // skip empty result-sets
  287. // this improves results in queryes like that:
  288. // set nocount on; select 1
  289. // see TestIgnoreEmptyResults test
  290. //case doneStruct:
  291. //break loop
  292. case []columnStruct:
  293. cols = token
  294. break loop
  295. case doneStruct:
  296. if token.isError() {
  297. return nil, token.getError()
  298. }
  299. case error:
  300. return nil, token
  301. }
  302. }
  303. res = &MssqlRows{sess: s.c.sess, tokchan: tokchan, cols: cols, cancel: cancel}
  304. return
  305. }
  306. func (s *MssqlStmt) Exec(args []driver.Value) (driver.Result, error) {
  307. return s.exec(context.Background(), convertOldArgs(args))
  308. }
  309. func (s *MssqlStmt) exec(ctx context.Context, args []namedValue) (driver.Result, error) {
  310. if err := s.sendQuery(args); err != nil {
  311. return nil, err
  312. }
  313. return s.processExec(ctx)
  314. }
  315. func (s *MssqlStmt) processExec(ctx context.Context) (res driver.Result, err error) {
  316. tokchan := make(chan tokenStruct, 5)
  317. go processResponse(ctx, s.c.sess, tokchan)
  318. var rowCount int64
  319. for token := range tokchan {
  320. switch token := token.(type) {
  321. case doneInProcStruct:
  322. if token.Status&doneCount != 0 {
  323. rowCount += int64(token.RowCount)
  324. }
  325. case doneStruct:
  326. if token.Status&doneCount != 0 {
  327. rowCount += int64(token.RowCount)
  328. }
  329. if token.isError() {
  330. return nil, token.getError()
  331. }
  332. case error:
  333. return nil, token
  334. }
  335. }
  336. return &MssqlResult{s.c, rowCount}, nil
  337. }
  338. type MssqlRows struct {
  339. sess *tdsSession
  340. cols []columnStruct
  341. tokchan chan tokenStruct
  342. nextCols []columnStruct
  343. cancel func()
  344. }
  345. func (rc *MssqlRows) Close() error {
  346. rc.cancel()
  347. for _ = range rc.tokchan {
  348. }
  349. rc.tokchan = nil
  350. return nil
  351. }
  352. func (rc *MssqlRows) Columns() (res []string) {
  353. res = make([]string, len(rc.cols))
  354. for i, col := range rc.cols {
  355. res[i] = col.ColName
  356. }
  357. return
  358. }
  359. func (rc *MssqlRows) Next(dest []driver.Value) error {
  360. if rc.nextCols != nil {
  361. return io.EOF
  362. }
  363. for tok := range rc.tokchan {
  364. switch tokdata := tok.(type) {
  365. case []columnStruct:
  366. rc.nextCols = tokdata
  367. return io.EOF
  368. case []interface{}:
  369. for i := range dest {
  370. dest[i] = tokdata[i]
  371. }
  372. return nil
  373. case doneStruct:
  374. if tokdata.isError() {
  375. return tokdata.getError()
  376. }
  377. case error:
  378. return tokdata
  379. }
  380. }
  381. return io.EOF
  382. }
  383. func (rc *MssqlRows) HasNextResultSet() bool {
  384. return rc.nextCols != nil
  385. }
  386. func (rc *MssqlRows) NextResultSet() error {
  387. rc.cols = rc.nextCols
  388. rc.nextCols = nil
  389. if rc.cols == nil {
  390. return io.EOF
  391. }
  392. return nil
  393. }
  394. // It should return
  395. // the value type that can be used to scan types into. For example, the database
  396. // column type "bigint" this should return "reflect.TypeOf(int64(0))".
  397. func (r *MssqlRows) ColumnTypeScanType(index int) reflect.Type {
  398. return makeGoLangScanType(r.cols[index].ti)
  399. }
  400. // RowsColumnTypeDatabaseTypeName may be implemented by Rows. It should return the
  401. // database system type name without the length. Type names should be uppercase.
  402. // Examples of returned types: "VARCHAR", "NVARCHAR", "VARCHAR2", "CHAR", "TEXT",
  403. // "DECIMAL", "SMALLINT", "INT", "BIGINT", "BOOL", "[]BIGINT", "JSONB", "XML",
  404. // "TIMESTAMP".
  405. func (r *MssqlRows) ColumnTypeDatabaseTypeName(index int) string {
  406. return makeGoLangTypeName(r.cols[index].ti)
  407. }
  408. // RowsColumnTypeLength may be implemented by Rows. It should return the length
  409. // of the column type if the column is a variable length type. If the column is
  410. // not a variable length type ok should return false.
  411. // If length is not limited other than system limits, it should return math.MaxInt64.
  412. // The following are examples of returned values for various types:
  413. // TEXT (math.MaxInt64, true)
  414. // varchar(10) (10, true)
  415. // nvarchar(10) (10, true)
  416. // decimal (0, false)
  417. // int (0, false)
  418. // bytea(30) (30, true)
  419. func (r *MssqlRows) ColumnTypeLength(index int) (int64, bool) {
  420. return makeGoLangTypeLength(r.cols[index].ti)
  421. }
  422. // It should return
  423. // the precision and scale for decimal types. If not applicable, ok should be false.
  424. // The following are examples of returned values for various types:
  425. // decimal(38, 4) (38, 4, true)
  426. // int (0, 0, false)
  427. // decimal (math.MaxInt64, math.MaxInt64, true)
  428. func (r *MssqlRows) ColumnTypePrecisionScale(index int) (int64, int64, bool) {
  429. return makeGoLangTypePrecisionScale(r.cols[index].ti)
  430. }
  431. // The nullable value should
  432. // be true if it is known the column may be null, or false if the column is known
  433. // to be not nullable.
  434. // If the column nullability is unknown, ok should be false.
  435. func (r *MssqlRows) ColumnTypeNullable(index int) (nullable, ok bool) {
  436. nullable = r.cols[index].Flags&colFlagNullable != 0
  437. ok = true
  438. return
  439. }
  440. func makeStrParam(val string) (res Param) {
  441. res.ti.TypeId = typeNVarChar
  442. res.buffer = str2ucs2(val)
  443. res.ti.Size = len(res.buffer)
  444. return
  445. }
  446. func (s *MssqlStmt) makeParam(val driver.Value) (res Param, err error) {
  447. if val == nil {
  448. res.ti.TypeId = typeNVarChar
  449. res.buffer = nil
  450. res.ti.Size = 2
  451. return
  452. }
  453. switch val := val.(type) {
  454. case int64:
  455. res.ti.TypeId = typeIntN
  456. res.buffer = make([]byte, 8)
  457. res.ti.Size = 8
  458. binary.LittleEndian.PutUint64(res.buffer, uint64(val))
  459. case float64:
  460. res.ti.TypeId = typeFltN
  461. res.ti.Size = 8
  462. res.buffer = make([]byte, 8)
  463. binary.LittleEndian.PutUint64(res.buffer, math.Float64bits(val))
  464. case []byte:
  465. res.ti.TypeId = typeBigVarBin
  466. res.ti.Size = len(val)
  467. res.buffer = val
  468. case string:
  469. res = makeStrParam(val)
  470. case bool:
  471. res.ti.TypeId = typeBitN
  472. res.ti.Size = 1
  473. res.buffer = make([]byte, 1)
  474. if val {
  475. res.buffer[0] = 1
  476. }
  477. case time.Time:
  478. if s.c.sess.loginAck.TDSVersion >= verTDS73 {
  479. res.ti.TypeId = typeDateTimeOffsetN
  480. res.ti.Scale = 7
  481. res.ti.Size = 10
  482. buf := make([]byte, 10)
  483. res.buffer = buf
  484. days, ns := dateTime2(val)
  485. ns /= 100
  486. buf[0] = byte(ns)
  487. buf[1] = byte(ns >> 8)
  488. buf[2] = byte(ns >> 16)
  489. buf[3] = byte(ns >> 24)
  490. buf[4] = byte(ns >> 32)
  491. buf[5] = byte(days)
  492. buf[6] = byte(days >> 8)
  493. buf[7] = byte(days >> 16)
  494. _, offset := val.Zone()
  495. offset /= 60
  496. buf[8] = byte(offset)
  497. buf[9] = byte(offset >> 8)
  498. } else {
  499. res.ti.TypeId = typeDateTimeN
  500. res.ti.Size = 8
  501. res.buffer = make([]byte, 8)
  502. ref := time.Date(1900, 1, 1, 0, 0, 0, 0, time.UTC)
  503. dur := val.Sub(ref)
  504. days := dur / (24 * time.Hour)
  505. tm := (300 * (dur % (24 * time.Hour))) / time.Second
  506. binary.LittleEndian.PutUint32(res.buffer[0:4], uint32(days))
  507. binary.LittleEndian.PutUint32(res.buffer[4:8], uint32(tm))
  508. }
  509. default:
  510. err = fmt.Errorf("mssql: unknown type for %T", val)
  511. return
  512. }
  513. return
  514. }
  515. type MssqlResult struct {
  516. c *MssqlConn
  517. rowsAffected int64
  518. }
  519. func (r *MssqlResult) RowsAffected() (int64, error) {
  520. return r.rowsAffected, nil
  521. }
  522. func (r *MssqlResult) LastInsertId() (int64, error) {
  523. s, err := r.c.Prepare("select cast(@@identity as bigint)")
  524. if err != nil {
  525. return 0, err
  526. }
  527. defer s.Close()
  528. rows, err := s.Query(nil)
  529. if err != nil {
  530. return 0, err
  531. }
  532. defer rows.Close()
  533. dest := make([]driver.Value, 1)
  534. err = rows.Next(dest)
  535. if err != nil {
  536. return 0, err
  537. }
  538. if dest[0] == nil {
  539. return -1, errors.New("There is no generated identity value")
  540. }
  541. lastInsertId := dest[0].(int64)
  542. return lastInsertId, nil
  543. }