123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257 |
- package mssql
- import (
- "bytes"
- "io"
- "strconv"
- )
// parser holds the state shared by the state functions that rewrite a query's
// placeholders into SQL Server parameter syntax.
type parser struct {
	r *bytes.Reader // remaining input (the original query text)
	w bytes.Buffer  // rewritten query produced so far

	// paramCount numbers the placeholders written without an explicit index;
	// paramMax is the largest ordinal index emitted so far.
	paramCount, paramMax int

	// namedParams is the set of distinct named parameters seen (map used as a set).
	namedParams map[string]bool
}
- func (p *parser) next() (rune, bool) {
- ch, _, err := p.r.ReadRune()
- if err != nil {
- if err != io.EOF {
- panic(err)
- }
- return 0, false
- }
- return ch, true
- }
- func (p *parser) unread() {
- err := p.r.UnreadRune()
- if err != nil {
- panic(err)
- }
- }
- func (p *parser) write(ch rune) {
- p.w.WriteRune(ch)
- }
- type stateFunc func(*parser) stateFunc
- func parseParams(query string) (string, int) {
- p := &parser{
- r: bytes.NewReader([]byte(query)),
- namedParams: map [string]bool{},
- }
- state := parseNormal
- for state != nil {
- state = state(p)
- }
- return p.w.String(), p.paramMax + len(p.namedParams)
- }
- func parseNormal(p *parser) stateFunc {
- for {
- ch, ok := p.next()
- if !ok {
- return nil
- }
- if ch == '?' {
- return parseOrdinalParameter
- } else if ch == '$' || ch == ':' {
- ch2, ok := p.next()
- if !ok {
- p.write(ch)
- return nil
- }
- p.unread()
- if ch2 >= '0' && ch2 <= '9' {
- return parseOrdinalParameter
- } else if 'a' <= ch2 && ch2 <= 'z' || 'A' <= ch2 && ch2 <= 'Z' {
- return parseNamedParameter
- }
- }
- p.write(ch)
- switch ch {
- case '\'':
- return parseQuote
- case '"':
- return parseDoubleQuote
- case '[':
- return parseBracket
- case '-':
- return parseLineComment
- case '/':
- return parseComment
- }
- }
- }
- func parseOrdinalParameter(p *parser) stateFunc {
- var paramN int
- var ok bool
- for {
- var ch rune
- ch, ok = p.next()
- if ok && ch >= '0' && ch <= '9' {
- paramN = paramN*10 + int(ch-'0')
- } else {
- break
- }
- }
- if ok {
- p.unread()
- }
- if paramN == 0 {
- p.paramCount++
- paramN = p.paramCount
- }
- if paramN > p.paramMax {
- p.paramMax = paramN
- }
- p.w.WriteString("@p")
- p.w.WriteString(strconv.Itoa(paramN))
- if !ok {
- return nil
- }
- return parseNormal
- }
- func parseNamedParameter(p *parser) stateFunc {
- var paramName string
- var ok bool
- for {
- var ch rune
- ch, ok = p.next()
- if ok && (ch >= '0' && ch <= '9' || 'a' <= ch && ch <= 'z' || 'A' <= ch && ch <= 'Z') {
- paramName = paramName + string(ch)
- } else {
- break
- }
- }
- if ok {
- p.unread()
- }
- p.namedParams[paramName] = true
- p.w.WriteString("@")
- p.w.WriteString(paramName)
- if !ok {
- return nil
- }
- return parseNormal
- }
- func parseQuote(p *parser) stateFunc {
- for {
- ch, ok := p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == '\'' {
- return parseNormal
- }
- }
- }
- func parseDoubleQuote(p *parser) stateFunc {
- for {
- ch, ok := p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == '"' {
- return parseNormal
- }
- }
- }
- func parseBracket(p *parser) stateFunc {
- for {
- ch, ok := p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == ']' {
- ch, ok = p.next()
- if !ok {
- return nil
- }
- if ch != ']' {
- p.unread()
- return parseNormal
- }
- p.write(ch)
- }
- }
- }
- func parseLineComment(p *parser) stateFunc {
- ch, ok := p.next()
- if !ok {
- return nil
- }
- if ch != '-' {
- p.unread()
- return parseNormal
- }
- p.write(ch)
- for {
- ch, ok = p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == '\n' {
- return parseNormal
- }
- }
- }
- func parseComment(p *parser) stateFunc {
- var nested int
- ch, ok := p.next()
- if !ok {
- return nil
- }
- if ch != '*' {
- p.unread()
- return parseNormal
- }
- p.write(ch)
- for {
- ch, ok = p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- for ch == '*' {
- ch, ok = p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == '/' {
- if nested == 0 {
- return parseNormal
- } else {
- nested--
- }
- }
- }
- for ch == '/' {
- ch, ok = p.next()
- if !ok {
- return nil
- }
- p.write(ch)
- if ch == '*' {
- nested++
- }
- }
- }
- }
|