12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325 |
- package mssql
- import (
- "crypto/tls"
- "crypto/x509"
- "encoding/binary"
- "errors"
- "fmt"
- "io"
- "io/ioutil"
- "net"
- "net/url"
- "os"
- "sort"
- "strconv"
- "strings"
- "time"
- "unicode"
- "unicode/utf16"
- "unicode/utf8"
- "golang.org/x/net/context" // use the "x/net/context" for backwards compatibility.
- )
// parseInstances decodes a SQL Server Browser reply (as read by getInstances)
// into a map keyed by upper-cased instance name.
//
// The reply must start with byte 5; the following two bytes are skipped and
// the remainder is read as a ";"-separated list of key;value pairs. An empty
// key closes the current instance record; an empty key while no pairs are
// pending stops parsing. Input of any other shape yields an empty, non-nil map.
func parseInstances(msg []byte) map[string]map[string]string {
	found := map[string]map[string]string{}
	if len(msg) <= 3 || msg[0] != 5 {
		return found
	}
	current := map[string]string{}
	key := ""
	expectValue := false
	for _, tok := range strings.Split(string(msg[3:]), ";") {
		if expectValue {
			current[key] = tok
			expectValue = false
			continue
		}
		key = tok
		if key != "" {
			expectValue = true
			continue
		}
		// Empty key: end of one instance record, or of the whole list.
		if len(current) == 0 {
			break
		}
		found[strings.ToUpper(current["InstanceName"])] = current
		current = map[string]string{}
	}
	return found
}
- func getInstances(address string) (map[string]map[string]string, error) {
- conn, err := net.DialTimeout("udp", address+":1434", 5*time.Second)
- if err != nil {
- return nil, err
- }
- defer conn.Close()
- conn.SetDeadline(time.Now().Add(5 * time.Second))
- _, err = conn.Write([]byte{3})
- if err != nil {
- return nil, err
- }
- var resp = make([]byte, 16*1024-1)
- read, err := conn.Read(resp)
- if err != nil {
- return nil, err
- }
- return parseInstances(resp[:read]), nil
- }
// TDS protocol versions, as sent in the LOGIN7 TDSVersion field.
// The high byte is the major/minor version; the low bytes carry the revision.
const (
	verTDS70     = 0x70000000
	verTDS71     = 0x71000000
	verTDS71rev1 = 0x71000001
	verTDS72     = 0x72090002
	verTDS73A    = 0x730A0003
	verTDS73B    = 0x730B0003
	verTDS74     = 0x74000004

	// verTDS73 is an alias for revision A of TDS 7.3.
	verTDS73 = verTDS73A
)
- // packet types
- // https://msdn.microsoft.com/en-us/library/dd304214.aspx
- const (
- packSQLBatch packetType = 1
- packRPCRequest = 3
- packReply = 4
- // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
- // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
- packAttention = 6
- packBulkLoadBCP = 7
- packTransMgrReq = 14
- packNormal = 15
- packLogin7 = 16
- packSSPIMessage = 17
- packPrelogin = 18
- )
// Option tokens of the PRELOGIN message.
// http://msdn.microsoft.com/en-us/library/dd357559.aspx
const (
	preloginVERSION = iota
	preloginENCRYPTION
	preloginINSTOPT
	preloginTHREADID
	preloginMARS
	preloginTRACEID

	preloginTERMINATOR = 0xff // ends the option table
)

// Values of the PRELOGIN ENCRYPTION option.
const (
	encryptOff    = iota // Encryption is available but off.
	encryptOn            // Encryption is available and on.
	encryptNotSup        // Encryption is not available.
	encryptReq           // Encryption is required.
)
- type tdsSession struct {
- buf *tdsBuffer
- loginAck loginAckStruct
- database string
- partner string
- columns []columnStruct
- tranid uint64
- logFlags uint64
- log optionalLogger
- routedServer string
- routedPort uint16
- }
// Bit flags selecting what a session logs; combined into the numeric
// "log" connection-string parameter.
const (
	logErrors      = 1 << iota // 1
	logMessages                // 2
	logRows                    // 4
	logSQL                     // 8
	logParams                  // 16
	logTransaction             // 32
	logDebug                   // 64
)
- type columnStruct struct {
- UserType uint32
- Flags uint16
- ColName string
- ti typeInfo
- }
// KeySlice implements sort.Interface for a slice of byte-sized keys, ordering
// them ascending. writePrelogin uses it to emit PRELOGIN options in token order.
type KeySlice []uint8

// Len returns the number of keys.
func (ks KeySlice) Len() int { return len(ks) }

// Less reports whether key a sorts before key b.
func (ks KeySlice) Less(a, b int) bool { return ks[a] < ks[b] }

// Swap exchanges keys a and b.
func (ks KeySlice) Swap(a, b int) { ks[a], ks[b] = ks[b], ks[a] }
- // http://msdn.microsoft.com/en-us/library/dd357559.aspx
- func writePrelogin(w *tdsBuffer, fields map[uint8][]byte) error {
- var err error
- w.BeginPacket(packPrelogin)
- offset := uint16(5*len(fields) + 1)
- keys := make(KeySlice, 0, len(fields))
- for k, _ := range fields {
- keys = append(keys, k)
- }
- sort.Sort(keys)
- // writing header
- for _, k := range keys {
- err = w.WriteByte(k)
- if err != nil {
- return err
- }
- err = binary.Write(w, binary.BigEndian, offset)
- if err != nil {
- return err
- }
- v := fields[k]
- size := uint16(len(v))
- err = binary.Write(w, binary.BigEndian, size)
- if err != nil {
- return err
- }
- offset += size
- }
- err = w.WriteByte(preloginTERMINATOR)
- if err != nil {
- return err
- }
- // writing values
- for _, k := range keys {
- v := fields[k]
- written, err := w.Write(v)
- if err != nil {
- return err
- }
- if written != len(v) {
- return errors.New("Write method didn't write the whole value")
- }
- }
- return w.FinishPacket()
- }
- func readPrelogin(r *tdsBuffer) (map[uint8][]byte, error) {
- packet_type, err := r.BeginRead()
- if err != nil {
- return nil, err
- }
- struct_buf, err := ioutil.ReadAll(r)
- if err != nil {
- return nil, err
- }
- if packet_type != 4 {
- return nil, errors.New("Invalid respones, expected packet type 4, PRELOGIN RESPONSE")
- }
- offset := 0
- results := map[uint8][]byte{}
- for true {
- rec_type := struct_buf[offset]
- if rec_type == preloginTERMINATOR {
- break
- }
- rec_offset := binary.BigEndian.Uint16(struct_buf[offset+1:])
- rec_len := binary.BigEndian.Uint16(struct_buf[offset+3:])
- value := struct_buf[rec_offset : rec_offset+rec_len]
- results[rec_type] = value
- offset += 5
- }
- return results, nil
- }
// Bits of the LOGIN7 OptionFlags2 byte.
// http://msdn.microsoft.com/en-us/library/dd304019.aspx
const (
	fLanguageFatal = 1 << 0
	fODBC          = 1 << 1
	fTransBoundary = 1 << 2
	fCacheConnect  = 1 << 3
	fIntSecurity   = 1 << 7 // set by connect when SSPI authentication is used
)

// Bits of the LOGIN7 TypeFlags byte. The low 4 bits hold fSQLType and the next
// bit fOLEDB; neither has a constant here.
const (
	fReadOnlyIntent = 1 << 5 // ApplicationIntent=ReadOnly
)
// login holds the values sendLogin serializes into a LOGIN7 packet.
type login struct {
	TDSVersion, PacketSize, ClientProgVer, ClientPID, ConnectionID uint32

	OptionFlags1, OptionFlags2, TypeFlags, OptionFlags3 uint8

	ClientTimeZone int32
	ClientLCID     uint32

	// Sent as UTF-16LE; Password is additionally obfuscated by manglePassword.
	HostName, UserName, Password, AppName, ServerName string
	CtlIntName, Language, Database                    string

	ClientID [6]byte
	SSPI     []byte // integrated-auth token, sent verbatim

	AtchDBFile, ChangePassword string // sent as UTF-16LE
}
// loginHeader is the fixed-size part of a LOGIN7 packet. sendLogin writes it
// with binary.Write in little-endian order, so the field order below IS the
// wire layout (94 bytes) and must not change. Each *Offset is a byte offset
// from the start of the LOGIN7 payload; sendLogin places the variable-length
// data directly after this header.
//
// ExtensionLenght keeps its historical spelling so existing references to the
// field keep compiling.
type loginHeader struct {
	Length uint32 // total LOGIN7 length: header plus variable data

	TDSVersion, PacketSize, ClientProgVer, ClientPID, ConnectionID uint32

	OptionFlags1, OptionFlags2, TypeFlags, OptionFlags3 uint8

	ClientTimeZone int32
	ClientLCID     uint32

	HostNameOffset, HostNameLength     uint16
	UserNameOffset, UserNameLength     uint16
	PasswordOffset, PasswordLength     uint16
	AppNameOffset, AppNameLength       uint16
	ServerNameOffset, ServerNameLength uint16
	ExtensionOffset, ExtensionLenght   uint16
	CtlIntNameOffset, CtlIntNameLength uint16
	LanguageOffset, LanguageLength     uint16
	DatabaseOffset, DatabaseLength     uint16

	ClientID [6]byte

	SSPIOffset, SSPILength                     uint16
	AtchDBFileOffset, AtchDBFileLength         uint16
	ChangePasswordOffset, ChangePasswordLength uint16

	SSPILongLength uint32
}
// str2ucs2 encodes s as UTF-16LE bytes.
//
// Runes outside the BMP become surrogate pairs, and invalid UTF-8 becomes
// U+FFFD (via the []rune conversion). The little-endian packing is done by
// hand rather than through encoding/binary, to avoid its overhead.
func str2ucs2(s string) []byte {
	units := utf16.Encode([]rune(s))
	out := make([]byte, 0, 2*len(units))
	for _, u := range units {
		out = append(out, byte(u), byte(u>>8))
	}
	return out
}
// ucs22str decodes UTF-16LE bytes into a Go string. Unpaired surrogates
// decode to U+FFFD. An odd number of bytes is rejected with an error.
func ucs22str(s []byte) (string, error) {
	if len(s)%2 == 1 {
		return "", fmt.Errorf("Illegal UCS2 string length: %d", len(s))
	}
	units := make([]uint16, 0, len(s)/2)
	for rest := s; len(rest) > 0; rest = rest[2:] {
		units = append(units, binary.LittleEndian.Uint16(rest))
	}
	return string(utf16.Decode(units)), nil
}
- func manglePassword(password string) []byte {
- var ucs2password []byte = str2ucs2(password)
- for i, ch := range ucs2password {
- ucs2password[i] = ((ch<<4)&0xff | (ch >> 4)) ^ 0xA5
- }
- return ucs2password
- }
- // http://msdn.microsoft.com/en-us/library/dd304019.aspx
- func sendLogin(w *tdsBuffer, login login) error {
- w.BeginPacket(packLogin7)
- hostname := str2ucs2(login.HostName)
- username := str2ucs2(login.UserName)
- password := manglePassword(login.Password)
- appname := str2ucs2(login.AppName)
- servername := str2ucs2(login.ServerName)
- ctlintname := str2ucs2(login.CtlIntName)
- language := str2ucs2(login.Language)
- database := str2ucs2(login.Database)
- atchdbfile := str2ucs2(login.AtchDBFile)
- changepassword := str2ucs2(login.ChangePassword)
- hdr := loginHeader{
- TDSVersion: login.TDSVersion,
- PacketSize: login.PacketSize,
- ClientProgVer: login.ClientProgVer,
- ClientPID: login.ClientPID,
- ConnectionID: login.ConnectionID,
- OptionFlags1: login.OptionFlags1,
- OptionFlags2: login.OptionFlags2,
- TypeFlags: login.TypeFlags,
- OptionFlags3: login.OptionFlags3,
- ClientTimeZone: login.ClientTimeZone,
- ClientLCID: login.ClientLCID,
- HostNameLength: uint16(utf8.RuneCountInString(login.HostName)),
- UserNameLength: uint16(utf8.RuneCountInString(login.UserName)),
- PasswordLength: uint16(utf8.RuneCountInString(login.Password)),
- AppNameLength: uint16(utf8.RuneCountInString(login.AppName)),
- ServerNameLength: uint16(utf8.RuneCountInString(login.ServerName)),
- CtlIntNameLength: uint16(utf8.RuneCountInString(login.CtlIntName)),
- LanguageLength: uint16(utf8.RuneCountInString(login.Language)),
- DatabaseLength: uint16(utf8.RuneCountInString(login.Database)),
- ClientID: login.ClientID,
- SSPILength: uint16(len(login.SSPI)),
- AtchDBFileLength: uint16(utf8.RuneCountInString(login.AtchDBFile)),
- ChangePasswordLength: uint16(utf8.RuneCountInString(login.ChangePassword)),
- }
- offset := uint16(binary.Size(hdr))
- hdr.HostNameOffset = offset
- offset += uint16(len(hostname))
- hdr.UserNameOffset = offset
- offset += uint16(len(username))
- hdr.PasswordOffset = offset
- offset += uint16(len(password))
- hdr.AppNameOffset = offset
- offset += uint16(len(appname))
- hdr.ServerNameOffset = offset
- offset += uint16(len(servername))
- hdr.CtlIntNameOffset = offset
- offset += uint16(len(ctlintname))
- hdr.LanguageOffset = offset
- offset += uint16(len(language))
- hdr.DatabaseOffset = offset
- offset += uint16(len(database))
- hdr.SSPIOffset = offset
- offset += uint16(len(login.SSPI))
- hdr.AtchDBFileOffset = offset
- offset += uint16(len(atchdbfile))
- hdr.ChangePasswordOffset = offset
- offset += uint16(len(changepassword))
- hdr.Length = uint32(offset)
- var err error
- err = binary.Write(w, binary.LittleEndian, &hdr)
- if err != nil {
- return err
- }
- _, err = w.Write(hostname)
- if err != nil {
- return err
- }
- _, err = w.Write(username)
- if err != nil {
- return err
- }
- _, err = w.Write(password)
- if err != nil {
- return err
- }
- _, err = w.Write(appname)
- if err != nil {
- return err
- }
- _, err = w.Write(servername)
- if err != nil {
- return err
- }
- _, err = w.Write(ctlintname)
- if err != nil {
- return err
- }
- _, err = w.Write(language)
- if err != nil {
- return err
- }
- _, err = w.Write(database)
- if err != nil {
- return err
- }
- _, err = w.Write(login.SSPI)
- if err != nil {
- return err
- }
- _, err = w.Write(atchdbfile)
- if err != nil {
- return err
- }
- _, err = w.Write(changepassword)
- if err != nil {
- return err
- }
- return w.FinishPacket()
- }
- func readUcs2(r io.Reader, numchars int) (res string, err error) {
- buf := make([]byte, numchars*2)
- _, err = io.ReadFull(r, buf)
- if err != nil {
- return "", err
- }
- return ucs22str(buf)
- }
- func readUsVarChar(r io.Reader) (res string, err error) {
- var numchars uint16
- err = binary.Read(r, binary.LittleEndian, &numchars)
- if err != nil {
- return "", err
- }
- return readUcs2(r, int(numchars))
- }
- func writeUsVarChar(w io.Writer, s string) (err error) {
- buf := str2ucs2(s)
- var numchars int = len(buf) / 2
- if numchars > 0xffff {
- panic("invalid size for US_VARCHAR")
- }
- err = binary.Write(w, binary.LittleEndian, uint16(numchars))
- if err != nil {
- return
- }
- _, err = w.Write(buf)
- return
- }
- func readBVarChar(r io.Reader) (res string, err error) {
- var numchars uint8
- err = binary.Read(r, binary.LittleEndian, &numchars)
- if err != nil {
- return "", err
- }
- return readUcs2(r, int(numchars))
- }
- func writeBVarChar(w io.Writer, s string) (err error) {
- buf := str2ucs2(s)
- var numchars int = len(buf) / 2
- if numchars > 0xff {
- panic("invalid size for B_VARCHAR")
- }
- err = binary.Write(w, binary.LittleEndian, uint8(numchars))
- if err != nil {
- return
- }
- _, err = w.Write(buf)
- return
- }
// readBVarByte reads a B_VARBYTE: a one-byte length followed by that many
// bytes. On a short read, the partially filled slice is returned together with
// the io.ReadFull error.
func readBVarByte(r io.Reader) (res []byte, err error) {
	var n uint8
	if err = binary.Read(r, binary.LittleEndian, &n); err != nil {
		return nil, err
	}
	res = make([]byte, n)
	_, err = io.ReadFull(r, res)
	return res, err
}
// readUshort reads a little-endian uint16 from r. It returns io.EOF when r is
// empty and io.ErrUnexpectedEOF when only one byte is available.
func readUshort(r io.Reader) (res uint16, err error) {
	var raw [2]byte
	if _, err = io.ReadFull(r, raw[:]); err != nil {
		return 0, err
	}
	return binary.LittleEndian.Uint16(raw[:]), nil
}
// readByte reads a single byte from r.
//
// io.ReadFull is used instead of a bare r.Read. Read may legally return
// (0, nil), which would yield a spurious zero byte with no error. It may also
// return the byte together with io.EOF, which would report an error for a
// byte that was read successfully. io.ReadFull handles both.
func readByte(r io.Reader) (res byte, err error) {
	var b [1]byte
	if _, err = io.ReadFull(r, b[:]); err != nil {
		return 0, err
	}
	return b[0], nil
}
// headerStruct is one entry of the ALL_HEADERS section that writeAllHeaders
// emits, e.g. ahead of the SQL text in sendSqlBatch72.
// http://msdn.microsoft.com/en-us/library/dd304953.aspx
type headerStruct struct {
	hdrtype uint16 // presumably one of the dataStmHdr* constants below
	data    []byte // already-encoded header payload
}

// ALL_HEADERS header types.
const (
	dataStmHdrQueryNotif    = iota + 1 // query notifications
	dataStmHdrTransDescr               // MARS transaction descriptor (required)
	dataStmHdrTraceActivity
)
- // Query Notifications Header
- // http://msdn.microsoft.com/en-us/library/dd304949.aspx
- type queryNotifHdr struct {
- notifyId string
- ssbDeployment string
- notifyTimeout uint32
- }
- func (hdr queryNotifHdr) pack() (res []byte) {
- notifyId := str2ucs2(hdr.notifyId)
- ssbDeployment := str2ucs2(hdr.ssbDeployment)
- res = make([]byte, 2+len(notifyId)+2+len(ssbDeployment)+4)
- b := res
- binary.LittleEndian.PutUint16(b, uint16(len(notifyId)))
- b = b[2:]
- copy(b, notifyId)
- b = b[len(notifyId):]
- binary.LittleEndian.PutUint16(b, uint16(len(ssbDeployment)))
- b = b[2:]
- copy(b, ssbDeployment)
- b = b[len(ssbDeployment):]
- binary.LittleEndian.PutUint32(b, hdr.notifyTimeout)
- return res
- }
// transDescrHdr is the MARS Transaction Descriptor header of ALL_HEADERS.
// http://msdn.microsoft.com/en-us/library/dd340515.aspx
type transDescrHdr struct {
	transDescr        uint64 // transaction descriptor returned from ENVCHANGE
	outstandingReqCnt uint32 // outstanding request count
}

// pack encodes the header payload as 12 bytes: transDescr followed by
// outstandingReqCnt, both little-endian.
func (hdr transDescrHdr) pack() (res []byte) {
	var out [12]byte
	binary.LittleEndian.PutUint64(out[:8], hdr.transDescr)
	binary.LittleEndian.PutUint32(out[8:], hdr.outstandingReqCnt)
	return out[:]
}
- func writeAllHeaders(w io.Writer, headers []headerStruct) (err error) {
- // calculatint total length
- var totallen uint32 = 4
- for _, hdr := range headers {
- totallen += 4 + 2 + uint32(len(hdr.data))
- }
- // writing
- err = binary.Write(w, binary.LittleEndian, totallen)
- if err != nil {
- return err
- }
- for _, hdr := range headers {
- var headerlen uint32 = 4 + 2 + uint32(len(hdr.data))
- err = binary.Write(w, binary.LittleEndian, headerlen)
- if err != nil {
- return err
- }
- err = binary.Write(w, binary.LittleEndian, hdr.hdrtype)
- if err != nil {
- return err
- }
- _, err = w.Write(hdr.data)
- if err != nil {
- return err
- }
- }
- return nil
- }
- func sendSqlBatch72(buf *tdsBuffer,
- sqltext string,
- headers []headerStruct) (err error) {
- buf.BeginPacket(packSQLBatch)
- if err = writeAllHeaders(buf, headers); err != nil {
- return
- }
- _, err = buf.Write(str2ucs2(sqltext))
- if err != nil {
- return
- }
- return buf.FinishPacket()
- }
- // 2.2.1.7 Attention: https://msdn.microsoft.com/en-us/library/dd341449.aspx
- // 4.19.2 Out-of-Band Attention Signal: https://msdn.microsoft.com/en-us/library/dd305167.aspx
- func sendAttention(buf *tdsBuffer) error {
- buf.BeginPacket(packAttention)
- return buf.FinishPacket()
- }
// connectParams is the parsed form of a connection string. parseConnectParams
// produces it; connect and dialConnection consume it.
type connectParams struct {
	logFlags uint64 // log* bit flags

	// Target server. A non-empty instance is resolved to a port through the
	// SQL Server Browser before dialing.
	port     uint64
	host     string
	instance string

	database, user, password string

	// dial_timeout defaults to 15s. conn_timeout defaults to 30s and wraps the
	// connection via NewTimeoutConn. keepAlive defaults to 30s.
	dial_timeout, conn_timeout, keepAlive time.Duration

	// TLS negotiation settings.
	encrypt, disableEncryption, trustServerCertificate bool
	certificate, hostInCertificate                     string

	serverSPN, workstation, appname string
	typeFlags                       uint8 // LOGIN7 TypeFlags, e.g. fReadOnlyIntent

	failOverPartner string
	failOverPort    uint64
}
// splitConnectionString parses an ADO-style "key=value;key=value" string.
//
// Keys are lower-cased and trimmed, and values are trimmed. Entries with an
// empty key are dropped, a key without "=" maps to "", and a later duplicate
// overwrites an earlier one. There is no quoting, so values cannot contain ';'.
func splitConnectionString(dsn string) (res map[string]string) {
	res = map[string]string{}
	for _, entry := range strings.Split(dsn, ";") {
		if entry == "" {
			continue
		}
		rawKey, rawValue := entry, ""
		if eq := strings.Index(entry, "="); eq >= 0 {
			rawKey, rawValue = entry[:eq], entry[eq+1:]
		}
		key := strings.TrimSpace(strings.ToLower(rawKey))
		if key == "" {
			continue
		}
		res[key] = strings.TrimSpace(rawValue)
	}
	return res
}
// splitConnectionStringOdbc parses an ODBC-style connection string such as
// "key=value;key={braced;value}" into a map with normalized keys.
//
// The state machine below implements these rules:
//   - whitespace before a key and before a value is skipped;
//   - trailing whitespace of keys and bare values is trimmed;
//   - a braced value may contain any character, and "}}" inside braces is an
//     escaped '}';
//   - after a braced value, only whitespace or ';' may follow;
//   - a key with no '=' or with nothing after '=' maps to "".
//
// A syntax error reports the byte index of the offending character. The pairs
// parsed so far are returned alongside the error.
func splitConnectionStringOdbc(dsn string) (map[string]string, error) {
	res := map[string]string{}

	type parserState int
	const (
		parserStateBeforeKey               parserState = iota // before the start of a key
		parserStateKey                                        // inside a key
		parserStateBeginValue                                 // after '=': value may be bare or braced
		parserStateBareValue                                  // inside a bare value
		parserStateBracedValue                                // inside a braced value
		parserStateBracedValueClosingBrace                    // saw '}' in braces: end of value or escape
		parserStateEndValue                                   // after a braced value
	)

	state := parserStateBeforeKey
	var key, value string

	// store records the current key with v and resets both accumulators.
	store := func(v string) {
		res[key] = v
		key, value = "", ""
	}

	for i, c := range dsn {
		switch state {
		case parserStateBeforeKey:
			if c == '=' {
				return res, fmt.Errorf("Unexpected character = at index %d. Expected start of key or semi-colon or whitespace.", i)
			}
			if c != ';' && !unicode.IsSpace(c) {
				key += string(c)
				state = parserStateKey
			}

		case parserStateKey:
			if c != '=' && c != ';' {
				key += string(c)
				break
			}
			key = normalizeOdbcKey(key)
			if len(key) == 0 {
				return res, fmt.Errorf("Unexpected end of key at index %d.", i)
			}
			if c == ';' {
				// Key without a value.
				store(value)
				state = parserStateBeforeKey
			} else {
				state = parserStateBeginValue
			}

		case parserStateBeginValue:
			switch {
			case c == '{':
				state = parserStateBracedValue
			case c == ';':
				store(value) // empty value
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Leading whitespace is skipped.
			default:
				value += string(c)
				state = parserStateBareValue
			}

		case parserStateBareValue:
			if c == ';' {
				store(strings.TrimRightFunc(value, unicode.IsSpace))
				state = parserStateBeforeKey
			} else {
				value += string(c)
			}

		case parserStateBracedValue:
			if c == '}' {
				state = parserStateBracedValueClosingBrace
			} else {
				value += string(c)
			}

		case parserStateBracedValueClosingBrace:
			if c == '}' {
				// "}}" is an escaped closing brace.
				value += string(c)
				state = parserStateBracedValue
				break
			}
			// The braced value ended; c is the first character after it and is
			// handled exactly as in parserStateEndValue.
			store(value)
			state = parserStateEndValue
			fallthrough

		case parserStateEndValue:
			switch {
			case c == ';':
				state = parserStateBeforeKey
			case unicode.IsSpace(c):
				// Whitespace after a braced value is ignored.
			default:
				return res, fmt.Errorf("Unexpected character %c at index %d. Expected semi-colon or whitespace.", c, i)
			}
		}
	}

	// Finish whatever construct the input ended in.
	switch state {
	case parserStateKey:
		// Unfinished key: treat it as a key without a value.
		key = normalizeOdbcKey(key)
		if len(key) == 0 {
			return res, fmt.Errorf("Unexpected end of key at index %d.", len(dsn))
		}
		res[key] = value
	case parserStateBeginValue, parserStateBracedValueClosingBrace:
		res[key] = value
	case parserStateBareValue:
		res[key] = strings.TrimRightFunc(value, unicode.IsSpace)
	case parserStateBracedValue:
		return res, fmt.Errorf("Unexpected end of braced value at index %d.", len(dsn))
	}
	return res, nil
}

// normalizeOdbcKey canonicalizes an ODBC key by removing trailing whitespace
// and lower-casing it. Leading whitespace is already skipped by the parser.
func normalizeOdbcKey(s string) string {
	trimmed := strings.TrimRightFunc(s, unicode.IsSpace)
	return strings.ToLower(trimmed)
}
// splitConnectionStringURL parses a URL connection string of the form
//
//	sqlserver://username:password@host:port/instance?param1=value&param2=value
//
// into the keys parseConnectParams reads:
//   - "user id" and "password" from the user info;
//   - "server" from the host, plus "\instance" when a non-empty path is given;
//   - "port" when one is given;
//   - every query parameter, stored verbatim.
//
// A query parameter given more than once is an error.
//
// IPv6 hosts appear in brackets in a URL ("[::1]" or "[::1]:1433"). The
// brackets are removed whether or not a port is present; without a port,
// net.SplitHostPort fails and the raw authority would otherwise keep them.
func splitConnectionStringURL(dsn string) (map[string]string, error) {
	res := map[string]string{}

	u, err := url.Parse(dsn)
	if err != nil {
		return res, err
	}
	if u.Scheme != "sqlserver" {
		return res, fmt.Errorf("scheme %s is not recognized", u.Scheme)
	}

	if u.User != nil {
		res["user id"] = u.User.Username()
		if p, ok := u.User.Password(); ok {
			res["password"] = p
		}
	}

	host, port, err := net.SplitHostPort(u.Host)
	if err != nil {
		// No port: the whole authority is the host.
		host, port = u.Host, ""
		if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
			host = host[1 : len(host)-1]
		}
	}
	if instance := strings.TrimPrefix(u.Path, "/"); instance != "" {
		res["server"] = host + "\\" + instance
	} else {
		res["server"] = host
	}
	if port != "" {
		res["port"] = port
	}

	for k, v := range u.Query() {
		if len(v) > 1 {
			return res, fmt.Errorf("key %s provided more than once", k)
		}
		res[k] = v[0]
	}
	return res, nil
}
- func parseConnectParams(dsn string) (connectParams, error) {
- var p connectParams
- var params map[string]string
- if strings.HasPrefix(dsn, "odbc:") {
- parameters, err := splitConnectionStringOdbc(dsn[len("odbc:"):])
- if err != nil {
- return p, err
- }
- params = parameters
- } else if strings.HasPrefix(dsn, "sqlserver://") {
- parameters, err := splitConnectionStringURL(dsn)
- if err != nil {
- return p, err
- }
- params = parameters
- } else {
- params = splitConnectionString(dsn)
- }
- strlog, ok := params["log"]
- if ok {
- var err error
- p.logFlags, err = strconv.ParseUint(strlog, 10, 0)
- if err != nil {
- return p, fmt.Errorf("Invalid log parameter '%s': %s", strlog, err.Error())
- }
- }
- server := params["server"]
- parts := strings.SplitN(server, "\\", 2)
- p.host = parts[0]
- if p.host == "." || strings.ToUpper(p.host) == "(LOCAL)" || p.host == "" {
- p.host = "localhost"
- }
- if len(parts) > 1 {
- p.instance = parts[1]
- }
- p.database = params["database"]
- p.user = params["user id"]
- p.password = params["password"]
- p.port = 1433
- strport, ok := params["port"]
- if ok {
- var err error
- p.port, err = strconv.ParseUint(strport, 0, 16)
- if err != nil {
- f := "Invalid tcp port '%v': %v"
- return p, fmt.Errorf(f, strport, err.Error())
- }
- }
- // https://msdn.microsoft.com/en-us/library/dd341108.aspx
- p.dial_timeout = 15 * time.Second
- p.conn_timeout = 30 * time.Second
- strconntimeout, ok := params["connection timeout"]
- if ok {
- timeout, err := strconv.ParseUint(strconntimeout, 0, 16)
- if err != nil {
- f := "Invalid connection timeout '%v': %v"
- return p, fmt.Errorf(f, strconntimeout, err.Error())
- }
- p.conn_timeout = time.Duration(timeout) * time.Second
- }
- strdialtimeout, ok := params["dial timeout"]
- if ok {
- timeout, err := strconv.ParseUint(strdialtimeout, 0, 16)
- if err != nil {
- f := "Invalid dial timeout '%v': %v"
- return p, fmt.Errorf(f, strdialtimeout, err.Error())
- }
- p.dial_timeout = time.Duration(timeout) * time.Second
- }
- // default keep alive should be 30 seconds according to spec:
- // https://msdn.microsoft.com/en-us/library/dd341108.aspx
- p.keepAlive = 30 * time.Second
- keepAlive, ok := params["keepalive"]
- if ok {
- timeout, err := strconv.ParseUint(keepAlive, 0, 16)
- if err != nil {
- f := "Invalid keepAlive value '%s': %s"
- return p, fmt.Errorf(f, keepAlive, err.Error())
- }
- p.keepAlive = time.Duration(timeout) * time.Second
- }
- encrypt, ok := params["encrypt"]
- if ok {
- if strings.ToUpper(encrypt) == "DISABLE" {
- p.disableEncryption = true
- } else {
- var err error
- p.encrypt, err = strconv.ParseBool(encrypt)
- if err != nil {
- f := "Invalid encrypt '%s': %s"
- return p, fmt.Errorf(f, encrypt, err.Error())
- }
- }
- } else {
- p.trustServerCertificate = true
- }
- trust, ok := params["trustservercertificate"]
- if ok {
- var err error
- p.trustServerCertificate, err = strconv.ParseBool(trust)
- if err != nil {
- f := "Invalid trust server certificate '%s': %s"
- return p, fmt.Errorf(f, trust, err.Error())
- }
- }
- p.certificate = params["certificate"]
- p.hostInCertificate, ok = params["hostnameincertificate"]
- if !ok {
- p.hostInCertificate = p.host
- }
- serverSPN, ok := params["serverspn"]
- if ok {
- p.serverSPN = serverSPN
- } else {
- p.serverSPN = fmt.Sprintf("MSSQLSvc/%s:%d", p.host, p.port)
- }
- workstation, ok := params["workstation id"]
- if ok {
- p.workstation = workstation
- } else {
- workstation, err := os.Hostname()
- if err == nil {
- p.workstation = workstation
- }
- }
- appname, ok := params["app name"]
- if !ok {
- appname = "go-mssqldb"
- }
- p.appname = appname
- appintent, ok := params["applicationintent"]
- if ok {
- if appintent == "ReadOnly" {
- p.typeFlags |= fReadOnlyIntent
- }
- }
- failOverPartner, ok := params["failoverpartner"]
- if ok {
- p.failOverPartner = failOverPartner
- }
- failOverPort, ok := params["failoverport"]
- if ok {
- var err error
- p.failOverPort, err = strconv.ParseUint(failOverPort, 0, 16)
- if err != nil {
- f := "Invalid tcp port '%v': %v"
- return p, fmt.Errorf(f, failOverPort, err.Error())
- }
- }
- return p, nil
- }
// Auth is an integrated-authentication (SSPI) provider used during login.
// connect sends InitialBytes inside the LOGIN7 packet. It passes each SSPI
// token the server returns to NextBytes and sends the result back. Free is
// deferred once login starts.
type Auth interface {
	// InitialBytes returns the first token to send to the server.
	InitialBytes() ([]byte, error)
	// NextBytes returns the reply to a token received from the server.
	NextBytes([]byte) ([]byte, error)
	// Free is called when connect returns; implementations presumably
	// release their resources here.
	Free()
}
- // SQL Server AlwaysOn Availability Group Listeners are bound by DNS to a
- // list of IP addresses. So if there is more than one, try them all and
- // use the first one that allows a connection.
- func dialConnection(p connectParams) (conn net.Conn, err error) {
- var ips []net.IP
- ips, err = net.LookupIP(p.host)
- if err != nil {
- ip := net.ParseIP(p.host)
- if ip == nil {
- return nil, err
- }
- ips = []net.IP{ip}
- }
- if len(ips) == 1 {
- d := createDialer(&p)
- addr := net.JoinHostPort(ips[0].String(), strconv.Itoa(int(p.port)))
- conn, err = d.Dial(addr)
- } else {
- //Try Dials in parallel to avoid waiting for timeouts.
- connChan := make(chan net.Conn, len(ips))
- errChan := make(chan error, len(ips))
- portStr := strconv.Itoa(int(p.port))
- for _, ip := range ips {
- go func(ip net.IP) {
- d := createDialer(&p)
- addr := net.JoinHostPort(ip.String(), portStr)
- conn, err := d.Dial(addr)
- if err == nil {
- connChan <- conn
- } else {
- errChan <- err
- }
- }(ip)
- }
- // Wait for either the *first* successful connection, or all the errors
- wait_loop:
- for i, _ := range ips {
- select {
- case conn = <-connChan:
- // Got a connection to use, close any others
- go func(n int) {
- for i := 0; i < n; i++ {
- select {
- case conn := <-connChan:
- conn.Close()
- case <-errChan:
- }
- }
- }(len(ips) - i - 1)
- // Remove any earlier errors we may have collected
- err = nil
- break wait_loop
- case err = <-errChan:
- }
- }
- }
- // Can't do the usual err != nil check, as it is possible to have gotten an error before a successful connection
- if conn == nil {
- f := "Unable to open tcp connection with host '%v:%v': %v"
- return nil, fmt.Errorf(f, p.host, p.port, err.Error())
- }
- return conn, err
- }
- func connect(log optionalLogger, p connectParams) (res *tdsSession, err error) {
- res = nil
- // if instance is specified use instance resolution service
- if p.instance != "" {
- p.instance = strings.ToUpper(p.instance)
- instances, err := getInstances(p.host)
- if err != nil {
- f := "Unable to get instances from Sql Server Browser on host %v: %v"
- return nil, fmt.Errorf(f, p.host, err.Error())
- }
- strport, ok := instances[p.instance]["tcp"]
- if !ok {
- f := "No instance matching '%v' returned from host '%v'"
- return nil, fmt.Errorf(f, p.instance, p.host)
- }
- p.port, err = strconv.ParseUint(strport, 0, 16)
- if err != nil {
- f := "Invalid tcp port returned from Sql Server Browser '%v': %v"
- return nil, fmt.Errorf(f, strport, err.Error())
- }
- }
- initiate_connection:
- conn, err := dialConnection(p)
- if err != nil {
- return nil, err
- }
- toconn := NewTimeoutConn(conn, p.conn_timeout)
- outbuf := newTdsBuffer(4096, toconn)
- sess := tdsSession{
- buf: outbuf,
- log: log,
- logFlags: p.logFlags,
- }
- instance_buf := []byte(p.instance)
- instance_buf = append(instance_buf, 0) // zero terminate instance name
- var encrypt byte
- if p.disableEncryption {
- encrypt = encryptNotSup
- } else if p.encrypt {
- encrypt = encryptOn
- } else {
- encrypt = encryptOff
- }
- fields := map[uint8][]byte{
- preloginVERSION: {0, 0, 0, 0, 0, 0},
- preloginENCRYPTION: {encrypt},
- preloginINSTOPT: instance_buf,
- preloginTHREADID: {0, 0, 0, 0},
- preloginMARS: {0}, // MARS disabled
- }
- err = writePrelogin(outbuf, fields)
- if err != nil {
- return nil, err
- }
- fields, err = readPrelogin(outbuf)
- if err != nil {
- return nil, err
- }
- encryptBytes, ok := fields[preloginENCRYPTION]
- if !ok {
- return nil, fmt.Errorf("Encrypt negotiation failed")
- }
- encrypt = encryptBytes[0]
- if p.encrypt && (encrypt == encryptNotSup || encrypt == encryptOff) {
- return nil, fmt.Errorf("Server does not support encryption")
- }
- if encrypt != encryptNotSup {
- var config tls.Config
- if p.certificate != "" {
- pem, err := ioutil.ReadFile(p.certificate)
- if err != nil {
- f := "Cannot read certificate '%s': %s"
- return nil, fmt.Errorf(f, p.certificate, err.Error())
- }
- certs := x509.NewCertPool()
- certs.AppendCertsFromPEM(pem)
- config.RootCAs = certs
- }
- if p.trustServerCertificate {
- config.InsecureSkipVerify = true
- }
- config.ServerName = p.hostInCertificate
- outbuf.transport = conn
- toconn.buf = outbuf
- tlsConn := tls.Client(toconn, &config)
- err = tlsConn.Handshake()
- toconn.buf = nil
- outbuf.transport = tlsConn
- if err != nil {
- f := "TLS Handshake failed: %s"
- return nil, fmt.Errorf(f, err.Error())
- }
- if encrypt == encryptOff {
- outbuf.afterFirst = func() {
- outbuf.transport = toconn
- }
- }
- }
- login := login{
- TDSVersion: verTDS74,
- PacketSize: outbuf.PackageSize(),
- Database: p.database,
- OptionFlags2: fODBC, // to get unlimited TEXTSIZE
- HostName: p.workstation,
- ServerName: p.host,
- AppName: p.appname,
- TypeFlags: p.typeFlags,
- }
- auth, auth_ok := getAuth(p.user, p.password, p.serverSPN, p.workstation)
- if auth_ok {
- login.SSPI, err = auth.InitialBytes()
- if err != nil {
- return nil, err
- }
- login.OptionFlags2 |= fIntSecurity
- defer auth.Free()
- } else {
- login.UserName = p.user
- login.Password = p.password
- }
- err = sendLogin(outbuf, login)
- if err != nil {
- return nil, err
- }
- // processing login response
- var sspi_msg []byte
- continue_login:
- tokchan := make(chan tokenStruct, 5)
- go processResponse(context.Background(), &sess, tokchan)
- success := false
- for tok := range tokchan {
- switch token := tok.(type) {
- case sspiMsg:
- sspi_msg, err = auth.NextBytes(token)
- if err != nil {
- return nil, err
- }
- case loginAckStruct:
- success = true
- sess.loginAck = token
- case error:
- return nil, fmt.Errorf("Login error: %s", token.Error())
- }
- }
- if sspi_msg != nil {
- outbuf.BeginPacket(packSSPIMessage)
- _, err = outbuf.Write(sspi_msg)
- if err != nil {
- return nil, err
- }
- err = outbuf.FinishPacket()
- if err != nil {
- return nil, err
- }
- sspi_msg = nil
- goto continue_login
- }
- if !success {
- return nil, fmt.Errorf("Login failed")
- }
- if sess.routedServer != "" {
- toconn.Close()
- p.host = sess.routedServer
- p.port = uint64(sess.routedPort)
- goto initiate_connection
- }
- return &sess, nil
- }
|